├── .gitattributes ├── .github └── workflows │ └── pylint.yml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── Data ├── WeatherReal-ISD-2021.nc ├── WeatherReal-ISD-2022.nc └── WeatherReal-ISD-2023.nc ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── conda.yml ├── evaluate.py ├── evaluation ├── __init__.py ├── forecast_reformat_catalog.py ├── metric_catalog.py ├── obs_reformat_catalog.py └── utils.py ├── metric_config.yml ├── pylintrc └── quality_control ├── README.md ├── __init__.py ├── algo ├── __init__.py ├── cluster.py ├── config.yaml ├── cross_variable.py ├── distributional_gap.py ├── diurnal_cycle.py ├── fine_tuning.py ├── neighbouring_stations.py ├── persistence.py ├── record_extreme.py ├── refinement.py ├── spike.py ├── time_series.py └── utils.py ├── download_ISD.py ├── quality_control.py ├── raw_ISD_to_hourly.py └── station_merging.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.nc filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.github/workflows/pylint.yml: -------------------------------------------------------------------------------- 1 | name: Pylint 2 | 3 | on: 4 | pull_request: 5 | branches: [main] 6 | 7 | jobs: 8 | build: 9 | runs-on: ubuntu-latest 10 | strategy: 11 | matrix: 12 | python-version: ["3.8", "3.9", "3.10"] 13 | steps: 14 | - uses: actions/checkout@v4 15 | - name: Set up Python ${{ matrix.python-version }} 16 | uses: actions/setup-python@v3 17 | with: 18 | python-version: ${{ matrix.python-version }} 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install --upgrade pip 22 | pip install pylint 23 | - name: Analysing the code with pylint 24 | run: | 25 | pylint $(git ls-files '*.py') 26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | .idea/ 161 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /Data/WeatherReal-ISD-2021.nc: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:b137bff773931b938f649015134d73964b12da8274066631e7f71132d44cde86 3 | size 535031735 4 | -------------------------------------------------------------------------------- /Data/WeatherReal-ISD-2022.nc: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:b5a3a11ed0c2d7b790f1adc4b48df779883b99de8da38dc0fe91d371b79addc6 3 | size 530269322 4 | -------------------------------------------------------------------------------- /Data/WeatherReal-ISD-2023.nc: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:2ae3267ea9cdbb6fa2551265f519afe06ab4f215b19a9b66117f32aede1eee45 3 | size 546501126 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WeatherReal: A Benchmark Based on In-Situ Observations for Evaluating Weather Models 2 | 3 | Welcome to the official GitHub repository and leaderboard page for the WeatherReal weather forecasting benchmark! 
4 | This repository includes code for running the evaluation of weather forecasting models on the WeatherReal benchmark dataset, 5 | as well as links to the LFS storage for the data files. 6 | If you would like to contribute a model evaluation result to this page, please open an Issue with the tag "Submission". 7 | If you use this repository or its data, please cite the following: 8 | - [WeatherReal](https://arxiv.org/html/2409.09371): A Benchmark Based on In-Situ Observations for Evaluating Weather Models 9 | - Synoptic Data (for the Synoptic data used in the benchmark) 10 | - [NCEI ISD database](https://journals.ametsoc.org/view/journals/bams/92/6/2011bams3015_1.xml) 11 | 12 | ## Why WeatherReal? 13 | 14 | Weather forecasting is a critical application for many scenarios, including disaster preparedness, agriculture, energy, 15 | and day-to-day life. Recent advances in AI applications for weather forecasting have demonstrated enormous potential, 16 | with models like GraphCast, Pangu-Weather, and FengWu achieving state-of-the-art performance on benchmarks like 17 | WeatherBench-2, even outperforming traditional numerical weather prediction models (the ECMWF IFS). However, these models 18 | have not been fully tested on real-world data collected from observing stations; rather, the evaluation has focused on 19 | reanalysis datasets such as ERA5 that are generated by NWP models and thus include those models' biases and approximations. 20 | 21 | WeatherReal is the first benchmark to provide comprehensive evaluation against observations from the 22 | [Integrated Surface Database (ISD)](https://www.ncei.noaa.gov/products/land-based-station/integrated-surface-database) 23 | and the [Synoptic Data](https://www.synopticdata.com/) API, in addition to aggregated user-reported observations collected 24 | from MSN Weather's reporting platform. We provide these datasets separately and offer separate benchmark 25 | leaderboards for each, in addition to separating leaderboards by tasks such as short-term and medium-range forecasting. 26 | We also provide evaluation code that can be run on a variety of model forecast schemas to generate scores for the benchmark. 27 | By interpolating gridded forecasts to station locations, and using nearest-neighbor interpolation to match point forecasts 28 | to station locations, we can fairly evaluate models that produce either gridded or point forecasts. 29 | 30 | ### What WeatherReal is 31 | 32 | - A benchmark dataset of quality-controlled weather observations spanning the year 2023, with ongoing updates planned 33 | - An evaluation framework to score many tasks for either grid-based or point-based forecasts 34 | - Still in its infancy - we welcome feedback on how to improve the benchmark, the evaluation code, and the leaderboards 35 | 36 | ### What WeatherReal is not 37 | 38 | - A dataset for training weather forecasting models 39 | - An inference platform for running models 40 | - An exclusive benchmark for weather forecasting 41 | 42 | ## Leaderboards 43 | 44 | ### [Provisional] Global medium-range weather forecasting 45 | 46 | Task: Forecasts are initialized twice daily at 00 and 12 UTC over the entire evaluation 47 | year 2023. Forecasts are evaluated every 6 hours of lead time up to 168 hours (7 days). Headline metric is 48 | the RMSE for each predicted variable (ETS for precipitation) averaged over all forecasts and lead 49 | times. 
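To make the headline metric concrete, the sketch below shows one way RMSE and ETS could be computed from a merged xarray Dataset holding `fc` (forecast) and `obs` (observation) variables over `(issue_time, lead_time, station)`. This is only an illustration under those assumptions, with missing-data handling and per-lead-time grouping omitted; the repository's actual metric implementations live in `evaluation/metric_catalog.py` and may differ.

```python
import numpy as np
import xarray as xr


def headline_rmse(merge_data: xr.Dataset) -> float:
    """RMSE of forecast vs. observation, averaged over all issue times, lead times and stations."""
    error = merge_data["fc"] - merge_data["obs"]
    return float(np.sqrt((error ** 2).mean()))


def headline_ets(merge_data: xr.Dataset, threshold: float = 1.0) -> float:
    """Equitable threat score for a binary event such as 6-hour precipitation > 1 mm."""
    fc_event = merge_data["fc"] >= threshold
    obs_event = merge_data["obs"] >= threshold
    hits = float((fc_event & obs_event).sum())
    misses = float((~fc_event & obs_event).sum())
    false_alarms = float((fc_event & ~obs_event).sum())
    total = float(fc_event.size)
    # Hits expected by chance, used to make the threat score "equitable"
    hits_random = (hits + misses) * (hits + false_alarms) / total
    return (hits - hits_random) / (hits + misses + false_alarms - hits_random)
```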
**Note**: The leaderboard is provisional due to incomplete forecast initializations for the provided models, and 50 | therefore is subject to change. 51 | 52 | | **WeatherReal-ISD** | 2-m temperature<br>(RMSE, K) | 10-m wind speed<br>(RMSE, m/s) | mean sea-level pressure<br>(RMSE, hPa) | total cloud cover<br>(RMSE, okta) | 6-hour precipitation > 1 mm<br>
(ETS) | 53 | |---------------------|--------------------------------|----------------------------------|------------------------------------------|-------------------------------------|----------------------------------------| 54 | | Microsoft-Point | **2.258** | **1.753** | - | **2.723** | - | 55 | | Aurora-9km | 2.417 | 2.186 | **2.939** | - | - | 56 | | ECMWF | 2.766 | 2.251 | 3.098 | 3.319 | 0.248 | 57 | | GFS | 3.168 | 2.455 | 3.480 | - | - | 58 | 59 | 60 | ### Other tasks 61 | 62 | We welcome feedback from the modeling community on how best to use the data for evaluating forecasts in a way that best 63 | reflects the end consumer experience with various forecasting models. We propose the following common tasks: 64 | 65 | - **Short-range forecasting.** Forecasts are initialized four times daily at 00, 06, 12, and 18 UTC. Forecasts are evaluated every 1 hour of lead time up to 72 hours. Headline metric is the RMSE (ETS for precipitation) for each predicted variable averaged over all forecasts and lead times. 66 | - **Nowcasting.** Forecasts are initialized every hour. Forecasts are evaluated every 1 hour of lead time up to 24 hours. Headline metric is the RMSE (ETS for precipitation) for each predicted variable averaged over all forecasts and lead times. 67 | - **Sub-seasonal-to-seasonal forecasts.** Following the schedule of ECMWF's long-range forecasts prior to June 2023, forecasts are initialized twice weekly at 00 UTC on Mondays and Thursdays. Forecasts are averaged either daily or weekly for lead times every 6 hours. Ideally forecasts should be probabilistic, enabling the use of proper scoring methods such as the continuous ranked probability score (CRPS). Headline metrics are week 3-4 and week 5-6 average scores. 68 | 69 | 70 | 71 | ## About the data 72 | 73 | WeatherReal includes several versions, all derived from global near-surface in-situ observations: 74 | (1) WeatherReal-ISD: An observational dataset based on Integrated Surface Database (ISD), which has been subjected to rigorous post-processing and quality control through our independently developed algorithms. 75 | (2) WeatherReal-Synoptic, An observational dataset from Synoptic Data PBC, a data service platform for 150,000+ in-situ surface weather stations, offering a much more densely distributed network. 76 | A quality control system is also provided as additional attributes delivered alongside the data from their API services. 77 | The following table lists available variables in WeatherReal-ISD and WeatherReal-Synoptic. 78 | 79 | | Variable | Short Name | Unit1 | Variable | Short Name | Unit | 80 | |---------------------------------------|------------|---------------------------------|-----------------------|------------|--------------------------| 81 | | 2m Temperature | t | °C | Total Cloud Cover | c | okta4 | 82 | | 2m Dewpoint Temperature | td | °C | 1-hour Precipitation | ra1 | mm | 83 | | Surface Pressure2 | sp | hPa | 3-hour Precipitation | ra3 | mm | 84 | | Mean Sea-level Pressure | msl | hPa | 6-hour Precipitation | ra6 | mm | 85 | | 10m Wind Speed | ws | m/s | 12-hour Precipitation | ra12 | mm | 86 | | 10m Wind Direction | wd | degree3 | 24-hour Precipitation | ra24 | mm | 87 | 88 | **1**: Refers to the units used in the WeatherReal-ISD we publish. For the units provided by the raw ISD and Synoptic, please consult their respective documentation. 89 | **2**: For in-situ weather stations, surface pressure is measured at the sensor's height, typically 2 meters above ground level at the weather station. 
90 | **3**: The direction is measured clockwise from true north, ranging from 1° (north-northeast) to 360° (north), with 0° indicating calm winds. 91 | **4**: Okta is a unit of measurement used to describe the amount of cloud cover, with the data range being from 0 (clear sky) to 8 (completely overcast). 92 | 93 | ### WeatherReal-ISD 94 | 95 | The data source of WeatherReal-ISD, ISD [Smith et al., 2011], is a global near-surface observation dataset compiled by the National Centers for Environmental Information (NCEI). 96 | More than 100 original data sources, including SYNOP (surface synoptic observations) and METAR (meteorological aerodrome report) weather reports, are incorporated. 97 | 98 | There are currently more than 14,000 active reporting stations in ISD and it already includes the majority of known station observation data, making it an ideal data source for WeatherReal. 99 | However, the observational data have only undergone basic quality control, resulting in numerous erroneous data points. 100 | Therefore, to improve data fidelity, we performed extensive post-processing on it, including station selection and merging, and comprehensive quality control. 101 | For more details on the data processing, please refer to the paper. 102 | 103 | ### WeatherReal-Synoptic 104 | 105 | Data of WeatherReal-Synoptic is obtained from Synoptic Data PBC, which brings together observation data from hundreds of public and private station networks worldwide, providing a comprehensive and accessible data service platform for critical environmental information. 106 | For further details, please refer to [Synoptic Data’s official site](https://synopticdata.com/solutions/ai-ml-weather/). 107 | The WeatherReal-Synoptic dataset utilized in this paper was retrieved in real-time from their Time Series API services in 2023 to address our operational requirements, and the same data is available from them as a historical dataset. 108 | For precipitation, Synoptic also supports an advanced API that allows data retrieval through custom accumulation and interval windows. 109 | WeatherReal-Synoptic encompasses a greater volume of data, a more extensive observation network, and a larger number of stations compared to ISD. 110 | Note that Synoptic provides a quality control system as an additional attribute alongside the data from their API services, thus the quality control algorithm we developed independently has not been applied to the WeatherReal-Synoptic dataset. 111 | 112 | 113 | ## Acquiring data 114 | 115 | The WeatherReal datasets are available from the following locations: 116 | - WeatherReal-ISD: A single file in netCDF format for year 2023: [GitHub LFS](https://github.com/microsoft/WeatherReal/blob/main/Data/WeatherReal-ISD-2023.nc) 117 | - WeatherReal-Synoptic: Please reach out directly to [Synoptic Data PBC](https://synopticdata.com/) for access to the data. 118 | 119 | 120 | ## Evaluation code 121 | 122 | The evaluation code is written in Python and is available in the `evaluation` directory. The code is designed to be 123 | flexible and can be used to evaluate a wide range of forecast schemas. The launch script `evaluate.py` can be used to run 124 | the evaluation. 
The following example illustrates how to evaluate temperature forecasts from gridded and point-based models: 125 | 126 | ```bash 127 | python evaluate.py \ 128 | --forecast-paths /path/to/grid_forecast_1.zarr /path/to/grid_forecast_2.zarr /path/to/point_forecast_1.zarr \ 129 | --forecast-names GridForecast1 GridForecast2 PointForecast1 \ 130 | --forecast-var-names t2m t2m t \ 131 | --forecast-reformat-funcs grid_v1 grid_v1 point_standard \ 132 | --obs-path /path/to/weatherreal-isd.nc \ 133 | --obs-var-name t \ 134 | --variable-type temperature \ 135 | --convert-fcst-temperature-k-to-c \ 136 | --output-directory /path/to/output 137 | ``` 138 | 139 | ## Submitting metrics 140 | 141 | We welcome all submissions of evaluation results using WeatherReal data! 142 | To submit your model's evaluation metrics to the leaderboard, please open an Issue with the tag "Submission" and include 143 | all metrics/tasks you would like to submit. Please also include a reference paper or link to a public repository that can 144 | be used to peer-review your results. We will review your submission and add it to the leaderboard if it meets the 145 | requirements. 146 | 147 | ## Contributing 148 | 149 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 150 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 151 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 152 | 153 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 154 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 155 | provided by the bot. You will only need to do this once across all repos using our CLA. 156 | 157 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 158 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 159 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 160 | 161 | ## Trademarks 162 | 163 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft 164 | trademarks or logos is subject to and must follow 165 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). 166 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 167 | Any use of third-party trademarks or logos are subject to those third-party's policies. 168 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). 
6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # Support 2 | 3 | ## How to file issues and get help 4 | 5 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 6 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 7 | feature request as a new Issue. 8 | 9 | ## Microsoft Support Policy 10 | 11 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 
12 | -------------------------------------------------------------------------------- /conda.yml: -------------------------------------------------------------------------------- 1 | name: weatherreal 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - ca-certificates=2024.3.11=h06a4308_0 8 | - ld_impl_linux-64=2.38=h1181459_1 9 | - libffi=3.4.4=h6a678d5_1 10 | - libgcc-ng=11.2.0=h1234567_1 11 | - libgomp=11.2.0=h1234567_1 12 | - libstdcxx-ng=11.2.0=h1234567_1 13 | - ncurses=6.4=h6a678d5_0 14 | - openssl=3.0.13=h7f8727e_1 15 | - pip=24.0=py39h06a4308_0 16 | - python=3.9.19=h955ad1f_1 17 | - readline=8.2=h5eee18b_0 18 | - setuptools=69.5.1=py39h06a4308_0 19 | - sqlite=3.45.3=h5eee18b_0 20 | - tk=8.6.14=h39e8969_0 21 | - tzdata=2024a=h04d1e81_0 22 | - wheel=0.43.0=py39h06a4308_0 23 | - xz=5.4.6=h5eee18b_1 24 | - zlib=1.2.13=h5eee18b_1 25 | - pip: 26 | - adal==1.2.7 27 | - annotated-types==0.6.0 28 | - antlr4-python3-runtime==4.9.3 29 | - applicationinsights==0.11.10 30 | - argcomplete==3.3.0 31 | - asciitree==0.3.3 32 | - attrs==23.2.0 33 | - azure-common==1.1.28 34 | - azure-core==1.30.1 35 | - azure-graphrbac==0.61.1 36 | - azure-identity==1.16.0 37 | - azure-mgmt-authorization==4.0.0 38 | - azure-mgmt-containerregistry==10.3.0 39 | - azure-mgmt-core==1.4.0 40 | - azure-mgmt-keyvault==10.3.0 41 | - azure-mgmt-network==25.3.0 42 | - azure-mgmt-resource==23.1.1 43 | - azure-mgmt-storage==21.1.0 44 | - azure-ml==0.0.1 45 | - azure-ml-component==0.9.18.post2 46 | - azure-storage-blob==12.19.0 47 | - azureml-contrib-services==1.56.0 48 | - azureml-core==1.56.0 49 | - azureml-dataprep==5.1.6 50 | - azureml-dataprep-native==41.0.0 51 | - azureml-dataprep-rslex==2.22.2 52 | - azureml-dataset-runtime==1.56.0 53 | - azureml-defaults==1.56.0.post1 54 | - azureml-inference-server-http==1.2.1 55 | - azureml-mlflow==1.56.0 56 | - azureml-telemetry==1.56.0 57 | - backports-tempfile==1.0 58 | - backports-weakref==1.0.post1 59 | - bcrypt==4.1.3 60 | - blinker==1.8.2 61 | - bytecode==0.15.1 62 | - cachetools==5.3.3 63 | - certifi==2024.2.2 64 | - cffi==1.16.0 65 | - cftime==1.6.3 66 | - charset-normalizer==3.3.2 67 | - click==8.1.7 68 | - cloudpickle==2.2.1 69 | - contextlib2==21.6.0 70 | - contourpy==1.2.1 71 | - cryptography==42.0.7 72 | - cycler==0.12.1 73 | - dask==2023.3.2 74 | - distributed==2023.3.2 75 | - docker==7.0.0 76 | - entrypoints==0.4 77 | - fasteners==0.19 78 | - flask==2.3.2 79 | - flask-cors==3.0.10 80 | - fonttools==4.51.0 81 | - fsspec==2024.3.1 82 | - fusepy==3.0.1 83 | - gitdb==4.0.11 84 | - gitpython==3.1.43 85 | - google-api-core==2.19.0 86 | - google-auth==2.29.0 87 | - googleapis-common-protos==1.63.0 88 | - gunicorn==22.0.0 89 | - humanfriendly==10.0 90 | - idna==3.7 91 | - importlib-metadata==7.1.0 92 | - importlib-resources==6.4.0 93 | - inference-schema==1.7.2 94 | - isodate==0.6.1 95 | - itsdangerous==2.2.0 96 | - jeepney==0.8.0 97 | - jinja2==3.1.4 98 | - jmespath==1.0.1 99 | - joblib==1.4.2 100 | - jsonpickle==3.0.4 101 | - jsonschema==4.22.0 102 | - jsonschema-specifications==2023.12.1 103 | - kiwisolver==1.4.5 104 | - knack==0.11.0 105 | - locket==1.0.0 106 | - markupsafe==2.1.5 107 | - matplotlib==3.7.1 108 | - metpy==1.3.1 109 | - mlflow-skinny==2.12.2 110 | - msal==1.28.0 111 | - msal-extensions==1.1.0 112 | - msgpack==1.0.8 113 | - msrest==0.7.1 114 | - msrestazure==0.6.4 115 | - ndg-httpsclient==0.5.1 116 | - netcdf4==1.6.3 117 | - numcodecs==0.12.1 118 | - numpy==1.23.5 119 | - oauthlib==3.2.2 120 | - 
omegaconf==2.3.0 121 | - opencensus==0.11.4 122 | - opencensus-context==0.1.3 123 | - opencensus-ext-azure==1.1.13 124 | - packaging==24.0 125 | - pandas==1.5.3 126 | - paramiko==3.4.0 127 | - partd==1.4.2 128 | - pathspec==0.12.1 129 | - pillow==10.3.0 130 | - pint==0.23 131 | - pkginfo==1.10.0 132 | - platformdirs==4.2.1 133 | - pooch==1.8.1 134 | - portalocker==2.8.2 135 | - proto-plus==1.23.0 136 | - protobuf==4.25.3 137 | - psutil==5.9.8 138 | - pyarrow==16.0.0 139 | - pyasn1==0.6.0 140 | - pyasn1-modules==0.4.0 141 | - pycparser==2.22 142 | - pydantic==2.7.1 143 | - pydantic-core==2.18.2 144 | - pydantic-settings==2.2.1 145 | - pydash==8.0.1 146 | - pygments==2.18.0 147 | - pyjwt==2.8.0 148 | - pynacl==1.5.0 149 | - pyopenssl==24.1.0 150 | - pyparsing==3.1.2 151 | - pyproj==3.6.1 152 | - pysocks==1.7.1 153 | - python-dateutil==2.9.0.post0 154 | - python-dotenv==1.0.1 155 | - pytz==2024.1 156 | - pyyaml==6.0.1 157 | - referencing==0.35.1 158 | - requests==2.31.0 159 | - requests-oauthlib==2.0.0 160 | - rpds-py==0.18.1 161 | - rsa==4.9 162 | - ruamel-yaml==0.17.16 163 | - ruamel-yaml-clib==0.2.8 164 | - scikit-learn==1.4.2 165 | - scipy==1.13.0 166 | - secretstorage==3.3.3 167 | - six==1.16.0 168 | - smmap==5.0.1 169 | - sortedcontainers==2.4.0 170 | - sqlparse==0.5.0 171 | - tabulate==0.9.0 172 | - tblib==3.0.0 173 | - threadpoolctl==3.5.0 174 | - toolz==0.12.1 175 | - tornado==6.4 176 | - tqdm==4.66.4 177 | - traitlets==5.14.3 178 | - typing-extensions==4.11.0 179 | - urllib3==2.2.1 180 | - werkzeug==2.3.8 181 | - wrapt==1.16.0 182 | - xarray==2023.1.0 183 | - zarr==2.14.2 184 | - zict==3.0.0 185 | - zipp==3.18.1 186 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from functools import reduce 3 | import logging 4 | import os 5 | from pathlib import Path 6 | from typing import Any, Dict, List, Optional, Tuple, Union 7 | import warnings 8 | 9 | import matplotlib.pyplot as plt 10 | import pandas as pd 11 | import xarray as xr 12 | import yaml 13 | 14 | from evaluation.forecast_reformat_catalog import reformat_forecast 15 | from evaluation.obs_reformat_catalog import get_interp_station_list, obs_to_verification, reformat_and_filter_obs 16 | from evaluation.metric_catalog import get_metric_func 17 | from evaluation.utils import configure_logging, get_metric_multiple_stations, generate_forecast_cache_path, \ 18 | cache_reformat_forecast, load_reformat_forecast, ForecastData, ForecastInfo, MetricData, get_ideal_xticks 19 | 20 | 21 | warnings.filterwarnings("ignore") 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | def intersect_all_forecast(forecast_list: List[ForecastData]) -> List[ForecastData]: 27 | """ 28 | Generates new versions of the forecast data objects where the dataframes have been aligned between all 29 | forecasts in the sequence. If only one forecast is provided, the data is returned as is. 30 | """ 31 | raise NotImplementedError("Forecast alignment is not yet implemented.") 32 | 33 | 34 | def get_forecast_data(forecast_info: ForecastInfo, cache_forecast: bool) -> ForecastData: 35 | """ 36 | Open all the forecasts from the forecast_info_list and reformat them. 37 | 38 | Parameters 39 | ---------- 40 | forecast_info: ForecastInfo 41 | The forecast information. 42 | cache_forecast: bool 43 | If true, cache the reformat forecast data, next time will load from cache. 
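        The cache location is derived from the forecast metadata via ``generate_forecast_cache_path``; per the ``--cache-forecast`` option, the interpolated forecast is stored in the output directory.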
44 | 45 | Returns 46 | ------- 47 | forecast: ForecastData 48 | Forecast data and the forecast information. 49 | """ 50 | info = forecast_info 51 | cache_path = generate_forecast_cache_path(info) 52 | if not os.path.exists(cache_path) or not cache_forecast: 53 | logger.info(f"open forecast file: {info.path}") 54 | if info.file_type is None: 55 | if Path(info.path).is_dir(): 56 | forecast = xr.open_zarr(info.path) 57 | else: 58 | forecast = xr.open_dataset(info.path, chunks={}) 59 | elif info.file_type == 'zarr': 60 | forecast = xr.open_zarr(info.path) 61 | else: 62 | forecast = xr.open_dataset(info.path, chunks={}) 63 | forecast_ds = reformat_forecast(forecast, info) 64 | if cache_forecast: 65 | cache_reformat_forecast(forecast_ds, cache_path) 66 | logger.info(f"save forecast file to cache: {cache_path}") 67 | else: 68 | forecast_ds = load_reformat_forecast(cache_path) 69 | logger.info(f"load forecast: {info.forecast_name}, from cache: {cache_path}") 70 | logger.debug(f"opened forecast dataset: {forecast_ds}") 71 | return ForecastData(info=info, forecast=forecast_ds) 72 | 73 | 74 | def get_observation_data(obs_base_path: str, obs_var_name: str, station_metadata_path: str, 75 | obs_file_type: str, obs_start_month: str, obs_end_month: str, 76 | precip_threshold: Optional[float] = None) -> xr.Dataset: 77 | """ 78 | Open the observation file and reformat it. Required fields: station, valid_time, obs_var_name. 79 | 80 | Parameters 81 | ---------- 82 | obs_base_path: str 83 | Path to the observation file. 84 | obs_var_name: str 85 | Name of the observation variable. 86 | station_metadata_path: str 87 | Path to the station metadata file. 88 | obs_file_type: str 89 | Type of the observation file. 90 | obs_start_month: str 91 | Obs start month, for multi-file netCDF data. 92 | obs_end_month: str 93 | Obs end month, for multi-file netCDF data. 94 | precip_threshold: float, optional 95 | Threshold for converting precipitation amount to binary. Default is no conversion. 96 | 97 | Returns 98 | ------- 99 | obs: xr.Dataset 100 | Observation data with required fields. 101 | """ 102 | if obs_start_month is not None and obs_end_month is not None: 103 | if obs_start_month is None or obs_end_month is None: 104 | raise ValueError("Both obs_start_month and obs_end_month must be provided.") 105 | month_list = pd.date_range(obs_start_month, obs_end_month, freq='MS') 106 | suffix = obs_file_type or 'nc' 107 | obs_path = [os.path.join(obs_base_path, month.strftime(f'%Y%m.{suffix}')) for month in month_list] 108 | obs_path_filter = [] 109 | for path in obs_path: 110 | if not os.path.exists(path): 111 | logger.warning(f"expected observation path does not exist: {path}") 112 | else: 113 | obs_path_filter.append(path) 114 | obs = xr.open_mfdataset(obs_path_filter, chunks={}) 115 | else: 116 | if obs_file_type is None: 117 | if Path(obs_base_path).is_dir(): 118 | obs = xr.open_zarr(obs_base_path) 119 | else: 120 | obs = xr.open_dataset(obs_base_path, chunks={}) 121 | elif obs_file_type == 'zarr': 122 | obs = xr.open_zarr(obs_base_path) 123 | else: 124 | obs = xr.open_dataset(obs_base_path, chunks={}) 125 | 126 | obs = reformat_and_filter_obs(obs, obs_var_name, station_metadata_path, precip_threshold) 127 | logger.debug(f"opened observation dataset: {obs}") 128 | return obs 129 | 130 | 131 | def merge_forecast_obs(forecast: ForecastData, obs: xr.Dataset) -> ForecastData: 132 | """ 133 | Merge the forecast and observation data. 
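    Observations are expanded onto the forecast's (issue_time, lead_time) layout via ``obs_to_verification``, merged with the forecast dataset, and an error field ``delta = fc - obs`` is added to the merged result.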
134 | """ 135 | new_obs = obs_to_verification( 136 | obs, 137 | steps=forecast.forecast.lead_time.values, 138 | max_lead=forecast.forecast.lead_time.values.max(), 139 | issue_times=forecast.forecast.issue_time.values 140 | ) 141 | merge_data = xr.merge([forecast.forecast, new_obs], compat='override') 142 | merge_data['delta'] = merge_data['fc'] - merge_data['obs'] 143 | result = ForecastData(info=forecast.info, forecast=forecast.forecast, merge_data=merge_data) 144 | logger.debug(f"after merge forecast and obs: {result.merge_data}") 145 | return result 146 | 147 | 148 | def filter_by_region(forecast: ForecastData, region_name: str, station_list: List[str]) \ 149 | -> ForecastData: 150 | """ 151 | Apply a selection on forecast based on the region_name and station_list. 152 | """ 153 | if region_name == 'all': 154 | return forecast 155 | merge_data = forecast.merge_data 156 | filtered_merge_data = ForecastData( 157 | merge_data=merge_data.sel(station=merge_data.station.values.isin(station_list)), 158 | info=forecast.info 159 | ) 160 | return filtered_merge_data 161 | 162 | 163 | def calculate_all_metrics(forecast_data: ForecastData, group_dim: str, metrics_params: Dict[str, Any]) \ 164 | -> MetricData: 165 | """ 166 | Calculate all the metrics together for dask graph efficiency. 167 | 168 | Parameters 169 | ---------- 170 | forecast_data: ForecastData 171 | The forecast data. 172 | group_dim: str 173 | The dimension to group the metric calculation. 174 | metrics_params: dict 175 | Dictionary containing the metrics configs. 176 | 177 | Returns 178 | ------- 179 | metric_data: MetricData 180 | Metric data and the forecast information. 181 | """ 182 | metrics = MetricData(info=forecast_data.info, metric_data=xr.Dataset()) 183 | for metric_name in metrics_params.keys(): 184 | metric_func = get_metric_func(metrics_params[metric_name]) 185 | metrics.metric_data[metric_name] = metric_func(forecast_data.merge_data, group_dim) 186 | metrics.metric_data = metrics.metric_data.compute() 187 | return metrics 188 | 189 | 190 | def get_plot_detail(forecast_data: ForecastData, group_dim: str): 191 | """ 192 | Get some added data to show on plots 193 | """ 194 | merge_data = forecast_data.merge_data 195 | counts = {key: coord.size for key, coord in merge_data.coords.items() if key not in ['lat', 'lon']} 196 | counts.update({'fc': merge_data.fc.size}) 197 | all_dims = ['valid_time', 'issue_time', 'lead_time'] 198 | all_dims.remove(group_dim) 199 | dim_info = [] 200 | for dim in all_dims: 201 | if dim == 'lead_time': 202 | vmax, vmin = merge_data[dim].max(), merge_data[dim].min() 203 | else: 204 | vmax, vmin = pd.Timestamp(merge_data[dim].values.max()).strftime("%Y-%m-%d %H:%M:%S"), \ 205 | pd.Timestamp(merge_data[dim].values.min()).strftime("%Y-%m-%d %H:%M:%S") 206 | dim_info.append(f'{dim} min: {vmin}, max: {vmax}') 207 | data_distribution = f"dim count: {str(counts)}\n{dim_info[0]}\n{dim_info[1]}" 208 | return data_distribution 209 | 210 | 211 | def plot_metric( 212 | example_data: ForecastData, 213 | metric_data_list: List[MetricData], 214 | group_dim: str, 215 | metric_name: str, 216 | base_plot_setting: Dict[str, Any], 217 | metrics_params: Dict[str, Any], 218 | output_dir: str, 219 | region_name: str, 220 | plot_save_format: Optional[str] = 'png' 221 | ) -> plt.Figure: 222 | """ 223 | A generic, basic plot for a single metric. 224 | 225 | Parameters 226 | ---------- 227 | example_data: ForecastData 228 | Example forecast data to get some extra information for the plot. 
229 | metric_data_list: list of MetricData 230 | List of MetricData objects containing the metric data and the forecast information. 231 | group_dim: str 232 | The dimension to group the metric calculation. 233 | metric_name: str 234 | The name of the metric. 235 | base_plot_setting: dict 236 | Dictionary containing the base plot settings. 237 | metrics_params: dict 238 | Dictionary containing the metric method and other kwargs. 239 | output_dir: str 240 | The output directory for the plots. 241 | region_name: str 242 | The name of the region. 243 | plot_save_format: str, optional 244 | The format to save the plot in. Default is 'png'. 245 | 246 | Returns 247 | ------- 248 | fig: plt.Figure 249 | The plot figure. 250 | """ 251 | data_distribution = get_plot_detail(example_data, group_dim) 252 | 253 | fig = plt.figure(figsize=(5.5, 6.5)) 254 | font = {'weight': 'medium', 'fontsize': 11} 255 | title = base_plot_setting['title'] 256 | xlabel = base_plot_setting['xlabel'] 257 | if 'plot_setting' in metrics_params: 258 | plot_setting = metrics_params['plot_setting'] 259 | title = plot_setting.get('title', title) 260 | xlabel = plot_setting.get('xlabel', xlabel) 261 | plt.title(title) 262 | plt.suptitle(data_distribution, fontsize=7) 263 | plt.gca().set_xlabel(xlabel[group_dim], fontdict=font) 264 | plt.gca().set_ylabel(metric_name, fontdict=font) 265 | 266 | for metrics in metric_data_list: 267 | metric_data = metrics.metric_data 268 | forecast_name = metrics.info.forecast_name 269 | plt.plot(metric_data[group_dim], metric_data[metric_name], label=forecast_name, linewidth=1.5) 270 | 271 | if group_dim == 'lead_time': 272 | plt.gca().set_xticks(get_ideal_xticks(metric_data[group_dim].min(), metric_data[group_dim].max(), 8)) 273 | plt.grid(linestyle=':') 274 | plt.legend(loc='upper center', bbox_to_anchor=(0.45, -0.14), frameon=False, ncol=3, fontsize=10) 275 | plt.tight_layout() 276 | plt.subplots_adjust(top=0.85) 277 | plot_path = os.path.join(output_dir, region_name) 278 | os.makedirs(plot_path, exist_ok=True) 279 | plt.savefig(os.path.join(plot_path, f"{metric_name}.{plot_save_format}")) 280 | return fig 281 | 282 | 283 | def metrics_to_csv( 284 | metric_data_list: List[MetricData], 285 | group_dim: str, 286 | output_dir: str, 287 | region_name: str 288 | ): 289 | """ 290 | A generic, basic function to save the metric data to a CSV file. 291 | 292 | Parameters 293 | ---------- 294 | metric_data_list: list of MetricData 295 | List of MetricData objects containing the metric data and the forecast information. 296 | group_dim: str 297 | The dimension to group the metric calculation. 298 | output_dir: str 299 | The output directory for the CSV files. 300 | region_name: str 301 | The name of the region. 302 | 303 | Returns 304 | ------- 305 | merged_df: pd.DataFrame 306 | The merged DataFrame with metric data. 
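        Metric columns are named ``{forecast_name}_{metric}`` and the merged table is also written to ``<output_dir>/<region_name>/metrics.csv``.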
307 | """ 308 | df_list = [] 309 | for metrics in metric_data_list: 310 | df_list.append(metrics.metric_data.rename( 311 | {metric: f"{metrics.info.forecast_name}_{metric}" for metric in metrics.metric_data.data_vars.keys()} 312 | ).to_dataframe()) 313 | merged_df = reduce(lambda left, right: pd.merge(left, right, on=group_dim, how='inner'), df_list) 314 | 315 | output_path = os.path.join(output_dir, region_name) 316 | os.makedirs(output_path, exist_ok=True) 317 | merged_df.to_csv(os.path.join(output_path, "metrics.csv")) 318 | return merged_df 319 | 320 | 321 | def parse_args(args: argparse.Namespace) -> Tuple[List[ForecastInfo], Any, Any, Union[ 322 | Dict[str, List[Any]], Any], Any, Any, Any, Any, Any, Any, Any, str, bool, bool, Optional[float]]: 323 | forecast_info_list = [] 324 | forecast_name_list = args.forecast_names 325 | forecast_var_name_list = args.forecast_var_names 326 | forecast_reformat_func_list = args.forecast_reformat_funcs 327 | station_metadata_path = args.station_metadata_path 328 | for index, forecast_path in enumerate(args.forecast_paths): 329 | forecast_info = ForecastInfo( 330 | path=forecast_path, 331 | forecast_name=forecast_name_list[index] if index < len(forecast_name_list) else f"forecast_{index}", 332 | fc_var_name=forecast_var_name_list[index] if index < len(forecast_var_name_list) else 333 | forecast_var_name_list[0], 334 | reformat_func=forecast_reformat_func_list[index] if index < len(forecast_reformat_func_list) else 335 | forecast_reformat_func_list[0], 336 | file_type=args.forecast_file_types[index] if index < len(args.forecast_file_types) else None, 337 | station_metadata_path=station_metadata_path, 338 | interp_station_path=station_metadata_path, 339 | output_directory=args.output_directory, 340 | start_date=args.start_date, 341 | end_date=args.end_date, 342 | issue_time_freq=args.issue_time_freq, 343 | start_lead=args.start_lead, 344 | end_lead=args.end_lead, 345 | convert_temperature=args.convert_fcst_temperature_k_to_c, 346 | convert_pressure=args.convert_fcst_pressure_pa_to_hpa, 347 | convert_cloud=args.convert_fcst_cloud_to_okta, 348 | precip_proba_threshold=args.precip_proba_threshold_conversion, 349 | ) 350 | forecast_info_list.append(forecast_info) 351 | 352 | metrics_settings_path = args.config_path if args.config_path is not None else os.path.join( 353 | os.path.dirname(os.path.abspath(__file__)), 'metric_config.yml') 354 | with open(metrics_settings_path, 'r') as fs: # pylint: disable=unspecified-encoding 355 | metrics_settings = yaml.safe_load(fs) 356 | 357 | try: 358 | metrics_settings = metrics_settings[args.variable_type] 359 | except KeyError as exc: 360 | raise ValueError(f"Unknown variable type: {args.variable_type}. 
Check config file {metrics_settings_path}") \ 361 | from exc 362 | metrics_dict = metrics_settings['metrics'] 363 | base_plot_setting = metrics_settings['base_plot_setting'] 364 | if args.eval_region_files is not None: 365 | try: 366 | region_dict = get_metric_multiple_stations(','.join(args.eval_region_files)) 367 | except Exception as e: # pylint: disable=broad-exception-caught 368 | logger.info(f"get_metric_multiple_stations failed, use default region: all {e}") 369 | region_dict = {} 370 | else: 371 | region_dict = {} 372 | region_dict['all'] = [] 373 | 374 | return (forecast_info_list, metrics_dict, base_plot_setting, 375 | region_dict, args.group_dim, args.obs_var_name, args.obs_path, args.obs_file_type, 376 | args.obs_start_month, args.obs_end_month, args.output_directory, station_metadata_path, 377 | bool(args.cache_forecast), bool(args.align_forecasts), 378 | args.precip_proba_threshold_conversion) 379 | 380 | 381 | def main(args): 382 | logger.info("===================== parse args =====================") 383 | (forecast_info_list, metrics_dict, base_plot_setting, region_dict, group_dim, obs_var_name, 384 | obs_base_path, obs_file_type, obs_start_month, obs_end_month, output_dir, station_metadata_path, 385 | cache_forecast, align_forecasts, precip_threshold) = parse_args(args) 386 | 387 | logger.info("===================== start get_observation_data =====================") 388 | obs_ds = get_observation_data(obs_base_path, obs_var_name, station_metadata_path, obs_file_type, obs_start_month, 389 | obs_end_month, precip_threshold=precip_threshold) 390 | 391 | # Get metadata and set it on all the forecast info objects 392 | if station_metadata_path is not None: 393 | metadata = get_interp_station_list(station_metadata_path) 394 | else: 395 | try: 396 | metadata = pd.DataFrame({'lat': obs_ds.lat.values, 'lon': obs_ds.lon.values, 397 | 'station': obs_ds.station.values}) 398 | except KeyError as exc: 399 | raise ValueError("--station-metadata-path is required if lat/lon/station keys are not in the observation " 400 | "file") from exc 401 | for forecast_info in forecast_info_list: 402 | forecast_info.metadata = metadata 403 | 404 | if align_forecasts: 405 | # First load all forecasts, then compute and return metrics. 406 | logger.info("===================== start get_forecast_data =====================") 407 | forecast_list = [get_forecast_data(fi, cache_forecast) for fi in forecast_info_list] 408 | logger.info("===================== start intersect_all_forecast =====================") 409 | forecast_list = intersect_all_forecast(forecast_list) 410 | else: 411 | forecast_list = [None] * len(forecast_info_list) 412 | 413 | # For each forecast, compute all its metrics in every region. 
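    # Unless --align-forecasts is set, each forecast is opened one at a time inside the loop and
    # released afterwards (forecast = None, del merged_forecast) to keep peak memory bounded.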
414 | metric_data = {r: [] for r in region_dict} 415 | for forecast, forecast_info in zip(forecast_list, forecast_info_list): 416 | try: 417 | del merged_forecast # noqa: F821 418 | except NameError: 419 | pass 420 | logger.info(f"===================== compute metrics for forecast {forecast_info.forecast_name} " 421 | f"=====================") 422 | if forecast is None: 423 | logger.info("===================== get_forecast_data =====================") 424 | forecast = get_forecast_data(forecast_info, cache_forecast) 425 | 426 | logger.info("===================== start merge_forecast_obs =====================") 427 | merged_forecast = merge_forecast_obs(forecast, obs_ds) 428 | 429 | for region_name, region_data in region_dict.items(): 430 | logger.info(f"===================== filter_by_region: {region_name} =====================") 431 | filtered_forecast = filter_by_region(merged_forecast, region_name, region_data) 432 | if region_name != 'all': 433 | logger.info(f"after filter_by_region: {region_name}; " 434 | f"stations: {filtered_forecast.merge_data.station.size}") 435 | 436 | logger.info(f"start calculate_metrics, region: {region_name}") 437 | metric_data[region_name].append( 438 | calculate_all_metrics(filtered_forecast, group_dim, metrics_dict) 439 | ) 440 | forecast = None 441 | 442 | # Plot all metrics and save data 443 | for region_name in region_dict: 444 | for metric_name in metrics_dict: 445 | logger.info(f"===================== plot_metric: {metric_name}, region: {region_name} " 446 | f"=====================") 447 | plot_metric(merged_forecast, metric_data[region_name], group_dim, metric_name, 448 | base_plot_setting, metrics_dict[metric_name], output_dir, region_name) 449 | 450 | metrics_to_csv(metric_data[region_name], group_dim, output_dir, region_name) 451 | 452 | 453 | if __name__ == '__main__': 454 | parser = argparse.ArgumentParser( 455 | description="Forecast evaluation script. Given a set of forecasts and a file of reference observations, " 456 | "computes requested metrics as specified in the `metric_catalog.yml` file. Includes the ability " 457 | "to interpret either grid-based or point-based forecasts. Grid-based forecasts are interpolated " 458 | "to observation locations. Point-based forecasts are directly compared to nearest observations." 459 | ) 460 | parser.add_argument( 461 | "--forecast-paths", 462 | nargs='+', 463 | type=str, 464 | required=True, 465 | help="List of paths containing forecasts. If a directory is provided, assumes forecast is a zarr store, " 466 | "and calls xarray's `open_zarr` method. " 467 | "Required dimensions: lead_time (or step), issue_time (or time), lat (or latitude), lon (or longitude)." 468 | ) 469 | parser.add_argument( 470 | '--forecast-names', 471 | type=str, 472 | nargs='+', 473 | default=[], 474 | help="List of names to assign to the forecasts. If there are more forecast paths than names, fills in the " 475 | "remaining names with 'forecast_{index}'" 476 | ) 477 | parser.add_argument( 478 | '--forecast-var-names', 479 | type=str, 480 | nargs='+', 481 | required=True, 482 | help="List of names (one per forecast path) of the forecast variable of interest in each file. If only one " 483 | "value is provided, assumes all forecast files have the same variable name. Raises an error if the " 484 | "number of listed values is less than the number of forecast paths." 
485 | ) 486 | parser.add_argument( 487 | '--forecast-reformat-funcs', 488 | type=str, 489 | nargs='+', 490 | required=True, 491 | help="For each forecast path, provide the name of the reformat function to apply. This function is based on " 492 | "the schema of the forecast file. Can be only a single value to apply to all forecasts. Options: " 493 | "\n - 'grid_standard': input is a grid forecast with dimensions lead_time, issue_time, lat, lon." 494 | "\n - 'point_standard': input is a point forecast with dimensions lead_time, issue_time, station." 495 | "\n - 'grid_v1': custom reformat function for grid forecasts with dims time, step, latitude, longitude." 496 | ) 497 | parser.add_argument( 498 | '--forecast-file-types', 499 | type=str, 500 | nargs='+', 501 | default=[], 502 | help="List of file types for each forecast path. Options: 'nc', 'zarr'. If not provided, or not enough " 503 | "entries, will assume zarr store if forecast is a directory, and otherwise will use xarray's " 504 | "`open_dataset` method." 505 | ) 506 | parser.add_argument( 507 | "--obs-path", 508 | type=str, 509 | required=True, 510 | help="Path to the verification folder or file" 511 | ) 512 | parser.add_argument( 513 | "--obs-file-type", 514 | type=str, 515 | default=None, 516 | help="Type of the observation file. Options: 'nc', 'zarr'. If not provided, will assume zarr store if this is " 517 | "a directory, and otherwise will use xarray's `open_dataset` method." 518 | ) 519 | parser.add_argument( 520 | "--obs-start-month", 521 | type=str, 522 | default=None, 523 | help="Option to read multiple netCDF files as a single dataset. These files are named 'YYYYMM.nc'. Provide " 524 | "the start month in the format 'YYYY-MM'. Not needed if obs-path is a single nc/zarr store." 525 | ) 526 | parser.add_argument( 527 | "--obs-end-month", 528 | type=str, 529 | default=None, 530 | help="Option to read multiple netCDF files as a single dataset. These files are named 'YYYYMM.nc'. Provide " 531 | "the end month in the format 'YYYY-MM'. Not needed if obs-path is a single nc/zarr store." 532 | ) 533 | parser.add_argument( 534 | "--obs-var-name", 535 | type=str, 536 | help="Name of the variable of interest in the observation data.", 537 | required=True 538 | ) 539 | parser.add_argument( 540 | "--station-metadata-path", 541 | type=str, 542 | help="Path to the station list containing metadata. Must include columns 'station', 'lat', 'lon'. " 543 | "If not provided, assumes the station lat/lon are coordinates in the observation file.", 544 | required=False 545 | ) 546 | parser.add_argument( 547 | "--config-path", 548 | type=str, 549 | help="Path to custom config yml file containing metric settings. Defaults to `metric_config.yml` in this " 550 | "script directory.", 551 | default=None 552 | ) 553 | parser.add_argument( 554 | "--variable-type", 555 | type=str, 556 | help="The type of the variable, as used in `--config-path` to select the appropriate metric settings. 
For " 557 | "example, 'temperature' or 'wind'.", 558 | required=True 559 | ) 560 | parser.add_argument( 561 | "--output-directory", 562 | type=str, 563 | help="Output directory for all evaluation artifacts", 564 | required=True 565 | ) 566 | parser.add_argument( 567 | "--eval-region-files", 568 | type=str, 569 | default=None, 570 | nargs='+', 571 | help="A list of files containing station lists for evaluation in certain regions" 572 | ) 573 | parser.add_argument( 574 | "--start-date", 575 | type=pd.Timestamp, 576 | default=None, 577 | help="First forecast issue time (as Timestamp) to include in evaluation" 578 | ) 579 | parser.add_argument( 580 | "--end-date", 581 | type=pd.Timestamp, 582 | default=None, 583 | help="Last forecast issue time (as Timestamp) to include in evaluation" 584 | ) 585 | parser.add_argument( 586 | "--issue-time-freq", 587 | type=str, 588 | default=None, 589 | help="Frequency of issue times (e.g., '1D') to include in evaluation. Default is None (all issue times)" 590 | ) 591 | parser.add_argument( 592 | "--start-lead", 593 | type=int, 594 | default=None, 595 | help="First lead time (in hours) to include in evaluation" 596 | ) 597 | parser.add_argument( 598 | "--end-lead", 599 | type=int, 600 | default=None, 601 | help="Last lead time (in hours) to include in evaluation" 602 | ) 603 | parser.add_argument( 604 | "--group-dim", 605 | type=str, 606 | default="lead_time", 607 | help="Group dimension for metric computation, options: lead_time, issue_time, valid_time" 608 | ) 609 | parser.add_argument( 610 | "--precip-proba-threshold-conversion", 611 | type=float, 612 | default=None, 613 | help="Convert observation and forecast fields from precipitation rate to probability of precipitation. Provide" 614 | " a threshold in mm/hr to use as positive precipitation class. Use only for evaluating precipitation!" 615 | ) 616 | parser.add_argument( 617 | "--convert-fcst-temperature-k-to-c", 618 | action='store_true', 619 | help="Convert forecast field from Kelvin to Celsius. Use only for evaluating temperature!" 620 | ) 621 | parser.add_argument( 622 | "--convert-fcst-pressure-pa-to-hpa", 623 | action='store_true', 624 | help="Convert forecast field from Pa to hPa. Use only for evaluating pressure!" 625 | ) 626 | parser.add_argument( 627 | "--convert-fcst-cloud-to-okta", 628 | action='store_true', 629 | help="Convert forecast field from cloud fraction to okta. Use only for evaluating cloud!" 630 | ) 631 | parser.add_argument( 632 | "--cache-forecast", 633 | action='store_true', 634 | help="If true, cache the intermediate interpolated forecast data in the output directory." 635 | ) 636 | parser.add_argument( 637 | "--align-forecasts", 638 | action='store_true', 639 | help="If set, load all forecasts first and then align them based on the intersection of issue/lead times. " 640 | "Note this uses substantially more memory to store all data at once." 641 | ) 642 | parser.add_argument( 643 | '--verbose', 644 | type=int, 645 | default=1, 646 | help="Verbosity level for logging. Options are 0 (WARNING), 1 (INFO), 2 (DEBUG), 3 (NOTSET). Default is 1." 
647 | ) 648 | 649 | run_args = parser.parse_args() 650 | configure_logging(run_args.verbose) 651 | main(run_args) 652 | -------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/WeatherReal-Benchmark/68b2f9293d2a0a1b1cedf396cffda34c05a21a14/evaluation/__init__.py -------------------------------------------------------------------------------- /evaluation/forecast_reformat_catalog.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import xarray as xr 6 | 7 | from .utils import convert_to_binary, ForecastInfo 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def convert_grid_to_point(grid_forecast: xr.Dataset, metadata: pd.DataFrame) -> xr.Dataset: 13 | """ 14 | Convert grid forecast to point dataframe via interpolation 15 | input data dims must be: lead_time, issue_time, lat, lon 16 | 17 | Parameters 18 | ---------- 19 | grid_forecast: xarray Dataset: grid forecast data 20 | metadata: pd.DataFrame: station metadata 21 | 22 | Returns 23 | ------- 24 | xr.Dataset: interpolated forecast data 25 | """ 26 | # grid_forecast = grid_forecast.load() 27 | grid_forecast = grid_forecast.assign_coords(lon=[lon if (lon < 180) else (lon - 360) 28 | for lon in grid_forecast['lon'].values]) 29 | # Optionally roll the longitude if the minimum longitude is not at index 0. This should make longitudes 30 | # monotonically increasing. 31 | if grid_forecast['lon'].argmin().values != 0: 32 | grid_forecast = grid_forecast.roll(lon=grid_forecast['lon'].argmin().values, roll_coords=True) 33 | 34 | # Default interpolate 35 | interp_meta = xr.Dataset.from_dataframe(metadata[['station', 'lon', 'lat']].set_index(['station'])) 36 | interp_forecast = grid_forecast.interp(lon=interp_meta['lon'], lat=interp_meta['lat'], method='linear') 37 | return interp_forecast 38 | 39 | 40 | def get_lead_time_slice(start_lead, end_lead): 41 | start = pd.Timedelta(start_lead, 'H') if start_lead is not None else None 42 | end = pd.Timedelta(end_lead, 'H') if end_lead is not None else None 43 | return slice(start, end) 44 | 45 | 46 | def update_unit_conversions(forecast: xr.Dataset, info: ForecastInfo): 47 | """ 48 | Check if the forecast data needs to be converted to the required units. Will not perform correction if 49 | not specified by user, however, if specified and data are already in converted units, then do not perform 50 | the conversion. 
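The decision is based on the 'units' (or 'unit') attribute of the 'fc' variable: for example, if Kelvin-to-Celsius conversion was requested but the attribute already contains 'C', the corresponding flag on `info` is switched off. Forecasts reformatted with 'omg_v1' are assumed to already be in the expected units and skip these checks, and the precipitation probability threshold is divided by 1000 when the attribute does not indicate millimetres.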
51 | 52 | Parameters 53 | ---------- 54 | forecast: xarray Dataset: forecast data 55 | info: ForecastInfo: forecast info 56 | """ 57 | if info.reformat_func == 'omg_v1': 58 | logger.info(f"Unit conversion not needed for forecast {info.forecast_name}") 59 | return 60 | unit = forecast['fc'].attrs.get('units', '') or forecast['fc'].attrs.get('unit', '') 61 | if info.convert_temperature: 62 | if 'C' in unit: 63 | info.convert_temperature = False 64 | logger.info(f"Temperature conversion not needed for forecast {info.forecast_name}") 65 | if info.convert_pressure: 66 | if any(u in unit.lower() for u in ['hpa', 'mb', 'millibar']): 67 | info.convert_pressure = False 68 | logger.info(f"Pressure conversion not needed for forecast {info.forecast_name}") 69 | if info.convert_cloud: 70 | if any(u in unit.lower() for u in ['okta', '0-8']): 71 | info.convert_cloud = False 72 | logger.info(f"Cloud cover conversion not needed for forecast {info.forecast_name}") 73 | if info.precip_proba_threshold is not None: 74 | if not any(u in unit.lower() for u in ['mm', 'milli']): 75 | info.precip_proba_threshold /= 1e3 76 | logger.info(f"Probability thresholding not needed for forecast {info.forecast_name}") 77 | 78 | 79 | def convert_cloud(forecast: xr.Dataset): 80 | """ 81 | Convert cloud cover inplace from percentage or fraction to okta. Assumes fraction by default unless units 82 | attribute says otherwise. 83 | 84 | Parameters 85 | ---------- 86 | forecast: xarray Dataset: forecast data 87 | 88 | Returns 89 | ------- 90 | xr.Dataset: converted forecast data 91 | """ 92 | unit = forecast['fc'].attrs.get('units', '') or forecast['fc'].attrs.get('unit', '') 93 | if any(u in unit.lower() for u in ['percent', '%', '100']): 94 | forecast['fc'] *= 8 / 100. 95 | else: 96 | forecast['fc'] *= 8 97 | return forecast 98 | 99 | 100 | def convert_precip_binary(forecast: xr.Dataset, info: ForecastInfo): 101 | """ 102 | Convert precipitation forecast to binary based on the threshold in the forecast info 103 | 104 | Parameters 105 | ---------- 106 | forecast: xarray Dataset: forecast data 107 | info: ForecastInfo: forecast info 108 | 109 | Returns 110 | ------- 111 | xr.Dataset: converted forecast data 112 | """ 113 | if 'pp' in info.fc_var_name: 114 | logger.info(f"Forecast {info.forecast_name} is already in probability (%) format. Skipping thresholding.") 115 | forecast['fc'] /= 100. 116 | return forecast 117 | forecast['fc'] = convert_to_binary(forecast['fc'], info.precip_proba_threshold) 118 | return forecast 119 | 120 | 121 | def reformat_forecast(forecast: xr.Dataset, info: ForecastInfo) -> xr.Dataset: 122 | """ 123 | Format the forecast data to the required format for evaluation, and keep only the required stations. 
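The processing order is: (1) the schema-specific reformat function named by `info.reformat_func` is applied (for grid schemas this also interpolates to the station locations in `info.metadata`), (2) any requested unit conversions are applied, and (3) a 'valid_time' coordinate is added and 'lead_time' is converted from a timedelta to hours.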
124 | 125 | Parameters 126 | ---------- 127 | forecast: xarray Dataset: forecast data 128 | info: ForecastInfo: forecast info 129 | 130 | Returns 131 | ------- 132 | xr.Dataset: formatted forecast data 133 | """ 134 | if info.reformat_func in ['grid_v1']: 135 | reformat_data = reformat_grid_v1(forecast, info) 136 | elif info.reformat_func in ['omg_v1', 'grid_v2']: 137 | reformat_data = reformat_grid_v2(forecast, info) 138 | elif info.reformat_func in ['grid_standard']: 139 | reformat_data = reformat_grid_standard(forecast, info) 140 | elif info.reformat_func in ['point_standard']: 141 | reformat_data = reformat_point_standard(forecast, info) 142 | else: 143 | raise ValueError(f"Unknown reformat method {info.reformat_func} for forecast {info.forecast_name}") 144 | 145 | # Update unit conversions 146 | update_unit_conversions(reformat_data, info) 147 | if info.convert_temperature: 148 | reformat_data['fc'] -= 273.15 149 | if info.convert_pressure: 150 | reformat_data['fc'] /= 100 151 | if info.convert_cloud: 152 | convert_cloud(reformat_data) 153 | if info.precip_proba_threshold is not None: 154 | convert_precip_binary(reformat_data, info) 155 | 156 | # Convert coordinates 157 | reformat_data = reformat_data.assign_coords(valid_time=reformat_data['issue_time'] + reformat_data['lead_time']) 158 | reformat_data = reformat_data.assign_coords(lead_time=reformat_data['lead_time'] / np.timedelta64(1, 'h')) 159 | return reformat_data 160 | 161 | 162 | def select_forecasts(forecast: xr.Dataset, info: ForecastInfo) -> xr.Dataset: 163 | """ 164 | Select the forecast data based on the issue time and lead time 165 | 166 | Parameters 167 | ---------- 168 | forecast: xarray Dataset: forecast data 169 | info: ForecastInfo: forecast info 170 | 171 | Returns 172 | ------- 173 | xr.Dataset: selected forecast data 174 | """ 175 | forecast = forecast.sel( 176 | lead_time=get_lead_time_slice(info.start_lead, info.end_lead), 177 | issue_time=slice(info.start_date, info.end_date) 178 | ) 179 | if info.issue_time_freq is not None: 180 | forecast = forecast.resample(issue_time=info.issue_time_freq).nearest() 181 | return forecast 182 | 183 | 184 | def reformat_grid_v1(grid_forecast: xr.Dataset, info: ForecastInfo) -> xr.Dataset: 185 | """ 186 | Standard grid forecast format following ECMWF schema 187 | input nc|zarr file dims must be: time, step, latitude, longitude 188 | 189 | Parameters 190 | ---------- 191 | grid_forecast: xarray Dataset: grid forecast data 192 | info: dict: forecast info 193 | 194 | Returns 195 | ------- 196 | xr.Dataset: formatted forecast data 197 | """ 198 | grid_forecast = grid_forecast.rename( 199 | { 200 | 'latitude': 'lat', 201 | 'longitude': 'lon', 202 | 'step': 'lead_time', 203 | 'time': 'issue_time', 204 | info.fc_var_name: 'fc' 205 | } 206 | ) 207 | grid_forecast = select_forecasts(grid_forecast, info) 208 | interp_forecast = convert_grid_to_point(grid_forecast[['fc']], info.metadata) 209 | return interp_forecast 210 | 211 | 212 | def reformat_grid_v2(grid_forecast: xr.Dataset, info: ForecastInfo) -> xr.Dataset: 213 | """ 214 | Grid format following OMG schema 215 | input nc|zarr file dims must be: lead_time, issue_time, y, x, index 216 | Must have lat, lon, index, var_name as coordinates 217 | 218 | Parameters 219 | ---------- 220 | grid_forecast: xarray Dataset: grid forecast data 221 | info: dict: forecast info 222 | 223 | Returns 224 | ------- 225 | xr.Dataset: formatted forecast data 226 | """ 227 | lat = grid_forecast['lat'].isel(index=0).lat.values 228 | lon = 
grid_forecast['lon'].isel(index=0).lon.values 229 | issue_time = grid_forecast['issue_time'].values 230 | grid_forecast = grid_forecast.drop_vars(['lat', 'lon', 'index', 'var_name']) 231 | grid_forecast = grid_forecast.rename({'y': 'lat', 'x': 'lon', 'index': 'issue_time'}) 232 | grid_forecast = grid_forecast.assign_coords(lat=lat, lon=lon, issue_time=issue_time) 233 | grid_forecast = grid_forecast.squeeze(dim='var_name') 234 | 235 | grid_forecast = grid_forecast.rename({info.fc_var_name: 'fc'}) 236 | grid_forecast = select_forecasts(grid_forecast, info) 237 | interp_forecast = convert_grid_to_point(grid_forecast[['fc']], info.metadata) 238 | return interp_forecast 239 | 240 | 241 | def reformat_grid_standard(grid_forecast: xr.Dataset, info: ForecastInfo) -> \ 242 | xr.Dataset: 243 | """ 244 | Grid format following standard schema 245 | input nc|zarr file dims must be: lead_time, issue_time, lat, lon 246 | 247 | Parameters 248 | ---------- 249 | grid_forecast: xarray Dataset: grid forecast data 250 | info: dict: forecast info 251 | 252 | Returns 253 | ------- 254 | xr.Dataset: formatted forecast data 255 | """ 256 | fc_var_name = info.fc_var_name 257 | grid_forecast = grid_forecast.rename({fc_var_name: 'fc'}) 258 | grid_forecast = select_forecasts(grid_forecast, info) 259 | interp_forecast = convert_grid_to_point(grid_forecast[['fc']], info.metadata) 260 | return interp_forecast 261 | 262 | 263 | def reformat_point_standard(point_forecast: xr.Dataset, info: ForecastInfo) \ 264 | -> xr.Dataset: 265 | """ 266 | Standard point forecast format 267 | input nc|zarr file dims must be: lead_time, issue_time, station 268 | 269 | Parameters 270 | ---------- 271 | point_forecast: xarray Dataset: grid forecast data 272 | info: dict: forecast info 273 | 274 | Returns 275 | ------- 276 | xr.Dataset: formatted forecast data 277 | """ 278 | fc_var_name = info.fc_var_name 279 | point_forecast = point_forecast.rename({fc_var_name: 'fc'}) 280 | point_forecast = select_forecasts(point_forecast, info) 281 | return point_forecast[['fc']] 282 | -------------------------------------------------------------------------------- /evaluation/metric_catalog.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import numpy as np 4 | import xarray as xr 5 | 6 | 7 | def _get_mean_dims(data, group_dim): 8 | """ 9 | Get the dimensions over which to compute mean. 10 | """ 11 | return [dim for dim in data.dims if dim != group_dim] 12 | 13 | 14 | def get_metric_func(settings): 15 | """ 16 | Find the appropriate metric function, specified by 'settings', and 17 | customize it according to the rest of the 'settings'. 18 | 19 | All metric functions expect two arguments: 20 | a dataset with 3 columns, 'obs', 'fc' and 'delta', 21 | and with 3 dimensions: 'station', 'issue_time' 'lead_time' 22 | a group dimension, which is one of the dimensions of the dataset. 23 | 24 | They return a dataset with metric values averaged over all dims 25 | except the group dim. 
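For example (illustrative), the settings loaded from `metric_config.yml` for the temperature metric ACCURACY_1.7 are `{'method': 'accuracy', 'threshold': 1.7}`; the returned callable is `partial(accuracy, threshold=1.7)` and can be applied as `metric_func(data, 'lead_time')`.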
26 | """ 27 | name = settings['method'] 28 | if name == 'rmse': 29 | return rmse 30 | if name == 'count': 31 | return count 32 | if name == 'mean_error': 33 | return mean_error 34 | if name == 'mae': 35 | return mae 36 | if name == 'sde': 37 | return sde 38 | if name == 'min_record_error': 39 | return min_record_error 40 | if name == 'max_record_error': 41 | return max_record_error 42 | if name == 'step_function': 43 | th = settings['threshold'] 44 | invert = settings.get('invert', False) 45 | return partial(step_function, threshold=th, invert=invert) 46 | if name == 'accuracy': 47 | th = settings['threshold'] 48 | return partial(accuracy, threshold=th) 49 | if name == 'error': 50 | th = settings['threshold'] 51 | return partial(error, threshold=th) 52 | if name == 'f1_thresholded': 53 | th = settings.get('threshold', None) 54 | return partial(f1_thresholded, threshold=th) 55 | if name == 'f1_class_averaged': 56 | return f1_class_averaged 57 | if name == 'threat_score': 58 | th = settings.get('threshold', None) 59 | equitable = settings.get('equitable', False) 60 | return partial(ts_thresholded, threshold=th, equitable=equitable) 61 | if name == 'reliability': 62 | b = settings.get('bins', None) 63 | return partial(reliability, bins=b) 64 | if name == 'brier': 65 | return mse 66 | if name == 'pod': 67 | th = settings.get('threshold') 68 | return partial(pod, threshold=th) 69 | if name == 'far': 70 | th = settings.get('threshold') 71 | return partial(far, threshold=th) 72 | if name == 'csi': 73 | th = settings.get('threshold') 74 | return partial(csi, threshold=th) 75 | raise ValueError(f'Unknown metric: {name}') 76 | 77 | 78 | def rmse(data: xr.Dataset, group_dim: str) -> xr.DataArray: 79 | src = pow(data['delta'], 2) 80 | return np.sqrt(src.mean(_get_mean_dims(src, group_dim))).rename('metric') 81 | 82 | 83 | def mse(data: xr.Dataset, group_dim: str) -> xr.DataArray: 84 | src = pow(data['delta'], 2) 85 | return src.mean(_get_mean_dims(src, group_dim)).rename('metric') 86 | 87 | 88 | def accuracy(data: xr.Dataset, group_dim: str, threshold: float) -> xr.DataArray: 89 | src = abs(data['delta']) < threshold 90 | cnt = count(data, group_dim) 91 | # Weight the average to ignore missing data 92 | r = src.sum(_get_mean_dims(src, group_dim)) 93 | r = r.rename('metric') 94 | r /= cnt 95 | return r 96 | 97 | 98 | def error(data: xr.Dataset, group_dim: str, threshold: float) -> xr.DataArray: 99 | src = abs(data['delta']) > threshold 100 | cnt = count(data, group_dim) 101 | # Weight the average to ignore missing data 102 | r = src.sum(_get_mean_dims(src, group_dim)) 103 | r = r.rename('metric') 104 | r /= cnt 105 | return r 106 | 107 | 108 | def mean_error(data: xr.Dataset, group_dim: str) -> xr.DataArray: 109 | src = data['delta'] 110 | me = src.mean(_get_mean_dims(src, group_dim)) 111 | return me.rename('metric') 112 | 113 | 114 | def mae(data: xr.Dataset, group_dim: str) -> xr.DataArray: 115 | src = abs(data['delta']) 116 | src = src.mean(_get_mean_dims(src, group_dim)) 117 | return src.rename('metric') 118 | 119 | 120 | def sde(data: xr.Dataset, group_dim: str) -> xr.DataArray: 121 | src = data['delta'] 122 | me_all = src.mean() 123 | src = pow((src-me_all), 2) 124 | src = src.mean(_get_mean_dims(src, group_dim)) 125 | return np.sqrt(src).rename('metric') 126 | 127 | 128 | def min_record_error(data: xr.Dataset, group_dim: str) -> xr.DataArray: 129 | src = data['delta'] 130 | mre = src.min(_get_mean_dims(src, group_dim)) 131 | return mre.rename('metric') 132 | 133 | 134 | def 
max_record_error(data: xr.Dataset, group_dim: str) -> xr.DataArray: 135 | src = data['delta'] 136 | mre = src.max(_get_mean_dims(src, group_dim)) 137 | return mre.rename('metric') 138 | 139 | 140 | def step_function(data: xr.Dataset, group_dim: str, threshold: float, invert: bool) -> xr.DataArray: 141 | src = abs(data['delta']) < threshold 142 | cnt = count(data, group_dim) 143 | 144 | # Weight the average to ignore missing data 145 | r = src.sum(_get_mean_dims(src, group_dim)) 146 | r = r.rename('metric') 147 | r /= cnt 148 | if invert: 149 | r = 1.0 - r 150 | return r 151 | 152 | 153 | def count(data: xr.Dataset, group_dim: str) -> xr.DataArray: 154 | src = ~data['delta'].isnull() 155 | cnt = src.sum(_get_mean_dims(src, group_dim)) 156 | return cnt.rename('metric') 157 | 158 | 159 | def _true_positives_and_false_negatives(data, labels): 160 | """ 161 | Compute the true positives and false negatives for a given set of labels. Filter keeps only where truth is True and 162 | neither prediction nor truth is NaN. 163 | """ 164 | r_list = [] 165 | for label in labels: 166 | recall_filter = xr.where(np.logical_and(~np.isnan(data['fc']), data['obs'] == label), 1, np.nan) 167 | r_list.append((data['fc'] == label) * recall_filter) 168 | return xr.concat(r_list, dim='category').rename('metric') 169 | 170 | 171 | def _true_positives_and_false_positives(data, labels): 172 | """ 173 | Compute the true positives and false positives for a given set of labels. Filter keeps only where prediction is 174 | True and neither prediction nor truth is NaN. 175 | """ 176 | p_list = [] 177 | for label in labels: 178 | precision_filter = xr.where(np.logical_and(~np.isnan(data['obs']), data['fc'] == label), 1, np.nan) 179 | p_list.append((data['obs'] == label) * precision_filter) 180 | return xr.concat(p_list, dim='category').rename('metric') 181 | 182 | 183 | def _geometric_mean(a, b): 184 | return 2 * (a * b) / (a + b) 185 | 186 | 187 | def _count_binary_matches(data): 188 | data['tp'] = (data['obs'] == 1) & (data['fc'] == 1) 189 | data['fp'] = (data['obs'] == 0) & (data['fc'] == 1) 190 | data['tn'] = (data['obs'] == 0) & (data['fc'] == 0) 191 | data['fn'] = (data['obs'] == 1) & (data['fc'] == 0) 192 | return data 193 | 194 | 195 | def _threshold_digitize(data, threshold): 196 | new_data = xr.Dataset() 197 | new_data['fc'] = xr.apply_ufunc(np.digitize, data['fc'], [threshold], dask='allowed') 198 | new_data['obs'] = xr.apply_ufunc(np.digitize, data['obs'], [threshold], dask='allowed') 199 | 200 | # Re-assign NaN where appropriate, since they are converted to 1 by digitize 201 | new_data['fc'] = xr.where(np.isnan(data['fc']), np.nan, new_data['fc']) 202 | new_data['obs'] = xr.where(np.isnan(data['obs']), np.nan, new_data['obs']) 203 | 204 | return new_data 205 | 206 | 207 | def f1_thresholded(data: xr.Dataset, group_dim: str, threshold: float) -> xr.DataArray: 208 | """ 209 | Compute the F1 score for a True/False classification based on values meeting or exceeding a defined threshold. 
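Precision and recall are first averaged over all dimensions except `group_dim`, and the score is their harmonic mean, F1 = 2 * P * R / (P + R).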
210 | :param data: Dataset 211 | :param group_dim: str: group dimension 212 | :param threshold: float threshold value 213 | :return: DataArray of scores 214 | """ 215 | new_data = _threshold_digitize(data, threshold) 216 | 217 | # Precision/recall computation 218 | r = _true_positives_and_false_negatives(new_data, [1]) 219 | p = _true_positives_and_false_positives(new_data, [1]) 220 | 221 | # Compute and return F1 222 | f1 = _geometric_mean( 223 | p.mean(_get_mean_dims(p, group_dim)), 224 | r.mean(_get_mean_dims(r, group_dim)) 225 | ).rename('metric') 226 | 227 | return f1.mean('category') # mean over the single category 228 | 229 | 230 | def ts_thresholded(data: xr.Dataset, group_dim: str, threshold: float, equitable: bool) -> xr.DataArray: 231 | """ 232 | Compute the threat score for a True/False classification based on values meeting or exceeding a defined threshold. 233 | :param data: Dataset 234 | :param group_dim: str: group dimension 235 | :param threshold: float threshold value 236 | :param equitable: bool: use ETS formulation 237 | :return: DataArray of scores 238 | """ 239 | def _count_equitable_random_chance(ds): 240 | n = ds['tp'] + ds['fp'] + ds['fn'] + ds['tn'] 241 | correction = (ds['tp'] + ds['fp']) * (ds['tp'] + ds['fn']) / n 242 | return correction 243 | 244 | new_data = _threshold_digitize(data, threshold) 245 | new_data = _count_binary_matches(new_data) 246 | 247 | new_data = new_data.sum(_get_mean_dims(new_data, group_dim)) 248 | ar = _count_equitable_random_chance(new_data) if equitable else 0 249 | ts = (new_data['tp'] - ar) / (new_data['tp'] + new_data['fp'] + new_data['fn'] - ar) 250 | return ts.rename('metric') 251 | 252 | 253 | def f1_class_averaged(data: xr.Dataset, group_dim: str) -> xr.DataArray: 254 | """ 255 | Compute the F1 score as one-vs-rest for each unique category in the data. Averages the scores for each class 256 | equally. 
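This is the macro-averaged F1: a one-vs-rest F1 is computed for every unique label found in the observations, and the per-class scores are averaged with equal weight.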
257 | :param data: Dataset 258 | :param group_dim: str: group dimension 259 | :return: DataArray of scores 260 | """ 261 | # F1 score must be aggregated in at least one dimension 262 | labels = np.unique(data['obs'].values[~np.isnan(data['obs'].values)]) 263 | 264 | # Per-label metrics 265 | r = _true_positives_and_false_negatives(data, labels) 266 | p = _true_positives_and_false_positives(data, labels) 267 | 268 | # Compute and return F1 269 | f1 = _geometric_mean( 270 | p.mean(_get_mean_dims(p, group_dim)), 271 | r.mean(_get_mean_dims(r, group_dim)) 272 | ).rename('metric') 273 | 274 | return f1.mean('category') 275 | 276 | 277 | def reliability(data: xr.Dataset, group_dim: str, bins: list) -> xr.DataArray: 278 | relative_freq = [0 for i in range(len(bins))] 279 | mean_predicted_value = [0 for i in range(len(bins))] 280 | 281 | truth = data['obs'] 282 | # replace nan with 0 in truth 283 | truth = truth.where(truth == 1, 0) 284 | prediction = data['fc'] 285 | 286 | for b, bin_ in enumerate(bins): 287 | # replace predicted prob with 0 / 1 for given range in bin 288 | pred = prediction.where((prediction <= bin_[1]) & (bin_[0] < prediction), 0) 289 | # sum up to calculate mean later 290 | pred_sum = float(np.sum(pred)) 291 | pred = pred.where(pred == 0, 1) 292 | 293 | # how many days have prediction fall in the given bin 294 | n_prediction_in_bin = np.count_nonzero(pred) 295 | 296 | # how many days rain in prediction set 297 | correct_pred = np.logical_and(pred, truth) 298 | n_corr_pred = np.count_nonzero(correct_pred) 299 | 300 | if n_prediction_in_bin != 0: 301 | mean_predicted_value[b] = 100 * pred_sum / n_prediction_in_bin 302 | relative_freq[b] = 100 * n_corr_pred / n_prediction_in_bin 303 | else: 304 | mean_predicted_value[b] = 0 305 | relative_freq[b] = 0 306 | 307 | res = xr.Dataset({'metric': ([group_dim], relative_freq)}, coords={group_dim: mean_predicted_value}) 308 | res = res.isel({group_dim: ~res.get_index(group_dim).duplicated()}) 309 | 310 | return res['metric'] 311 | 312 | 313 | def csi(data: xr.Dataset, group_dim: str, threshold: float) -> xr.DataArray: 314 | return ts_thresholded(data, group_dim, threshold, equitable=False) 315 | 316 | 317 | def pod(data: xr.Dataset, group_dim: str, threshold: float) -> xr.DataArray: 318 | new_data = _threshold_digitize(data, threshold) 319 | r = _true_positives_and_false_negatives(new_data, [1]) 320 | return r.mean(_get_mean_dims(r, group_dim)).rename('metric') 321 | 322 | 323 | def far(data: xr.Dataset, group_dim: str, threshold: float) -> xr.DataArray: 324 | new_data = _threshold_digitize(data, threshold) 325 | r = _true_positives_and_false_positives(new_data, [1]) 326 | return 1 - r.mean(_get_mean_dims(r, group_dim)).rename('metric') 327 | -------------------------------------------------------------------------------- /evaluation/obs_reformat_catalog.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Optional, Sequence, Union 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import xarray as xr 7 | try: 8 | from metpy.calc import specific_humidity_from_dewpoint 9 | from metpy.units import units 10 | except ImportError: 11 | specific_humidity_from_dewpoint = None 12 | units = None 13 | 14 | from .utils import convert_to_binary 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def _convert_time_step(dt): # pylint: disable=invalid-name 20 | return pd.Timedelta(hours=dt) if isinstance(dt, (float, int)) else pd.Timedelta(dt) 21 | 22 | 23 | def 
reformat_and_filter_obs(obs: xr.Dataset, obs_var_name: str, interp_station_path: Optional[str], 24 | precip_threshold: Optional[float] = None) -> xr.Dataset: 25 | """ 26 | Reformat and filter the observation data, and return the DataFrame with required fields. 27 | Required fields: station, valid_time, obs 28 | Filter removes the stations that are not in the interp_station list 29 | 30 | Parameters 31 | ---------- 32 | obs: xarray Dataset: observation data 33 | obs_var_name: str: observation variable name 34 | interp_station_path: str: path to the interpolation station list 35 | precip_threshold: float: threshold for precipitation 36 | 37 | Returns 38 | ------- 39 | xr.Dataset: formatted observation data 40 | """ 41 | if interp_station_path is not None: 42 | interp_station = get_interp_station_list(interp_station_path) 43 | intersect_station = np.intersect1d(interp_station['station'].values, obs['station'].values) 44 | logger.debug(f"intersect_station count: {len(intersect_station)}, \ 45 | obs_station count: {len(obs['station'].values)}, \ 46 | interp_station count: {len(interp_station['station'].values)}") 47 | obs = obs.sel(station=intersect_station) 48 | if 'valid_time' not in obs.dims: 49 | obs = obs.rename({'time': 'valid_time'}) 50 | 51 | if obs_var_name in ['u10', 'v10']: 52 | obs = calculate_u_v(obs, ws_name='ws', wd_name='wd', u_name='u10', v_name='v10') 53 | elif obs_var_name == 'q': 54 | obs['q'] = xr.apply_ufunc(calculate_q, obs['td']) 55 | elif obs_var_name == 'pp': 56 | precip_var = 'ra' if 'ra' in obs.data_vars else 'ra1' 57 | logger.info(f"User requested precipitation probability from obs. Using variable '{precip_var}' with " 58 | f"threshold of 0.1 mm/hr.") 59 | obs['pp'] = convert_to_binary(obs[precip_var], 0.1) 60 | 61 | if precip_threshold is not None: 62 | obs[obs_var_name] = convert_to_binary(obs[obs_var_name], precip_threshold) 63 | 64 | return obs[[obs_var_name]].rename({obs_var_name: 'obs'}) 65 | 66 | 67 | def obs_to_verification( 68 | obs: Union[xr.Dataset, xr.DataArray], 69 | max_lead: Union[pd.Timedelta, int] = 168, 70 | steps: Optional[Sequence[pd.Timestamp]] = None, 71 | issue_times: Optional[Sequence[pd.Timestamp]] = None 72 | ) -> Union[xr.Dataset, xr.DataArray]: 73 | """ 74 | Convert a Dataset or DataArray of continuous time-series observations 75 | according to the obs data spec into the forecast data spec for direct 76 | comparison to forecasts. 77 | 78 | Parameters 79 | ---------- 80 | obs: xarray Dataset or DataArray of observation data 81 | max_lead: maximum lead time for verification dataset. If int, interpreted 82 | as hours. 83 | steps: optional sequence of lead times to retain 84 | issue_times: issue times for the forecast result. If 85 | not specified, uses all available obs times. 
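Examples
--------
Illustrative only; assumes `obs` follows the observation spec with an hourly 'valid_time' dimension:

>>> issue_times = pd.date_range('2023-01-01', '2023-01-31', freq='1D')
>>> verif = obs_to_verification(obs, max_lead=24, issue_times=issue_times)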
86 | """ 87 | issue_dim = 'issue_time' 88 | lead_dim = 'lead_time' 89 | time_dim = 'valid_time' 90 | max_lead = _convert_time_step(max_lead) 91 | if issue_times is None: 92 | issue_times = obs[time_dim].values 93 | obs_series = [] 94 | for issue in issue_times: 95 | try: 96 | obs_series.append( 97 | obs.sel(**{time_dim: slice(issue, issue + max_lead)}).rename({time_dim: lead_dim}) 98 | ) 99 | obs_series[-1] = obs_series[-1].assign_coords( 100 | **{issue_dim: [issue], lead_dim: obs_series[-1][lead_dim] - issue}) 101 | except Exception as e: # pylint: disable=broad-exception-caught 102 | print(f'Failed to sel {issue} due to {e}') 103 | continue 104 | verification_ds = xr.concat(obs_series, dim=issue_dim) 105 | if steps is not None: 106 | steps = [_convert_time_step(s) for s in steps] 107 | verification_ds = verification_ds.sel(**{lead_dim: steps}) 108 | verification_ds = verification_ds.assign_coords({lead_dim: verification_ds[lead_dim] / np.timedelta64(1, 'h')}) 109 | 110 | return verification_ds 111 | 112 | 113 | def calculate_q(td): 114 | if specific_humidity_from_dewpoint is None: 115 | raise ImportError('metpy is not installed, specific_humidity_from_dewpoint cannot be calculated') 116 | q = specific_humidity_from_dewpoint(1013.25 * units.hPa, td * units.degC).magnitude 117 | return q 118 | 119 | 120 | def calculate_u_v(data, ws_name='ws', wd_name='wd', u_name='u10', v_name='v10'): 121 | data[u_name] = data[ws_name] * np.sin(data[wd_name] / 180 * np.pi - np.pi) 122 | data[v_name] = data[ws_name] * np.cos(data[wd_name] / 180 * np.pi - np.pi) 123 | return data 124 | 125 | 126 | def get_interp_station_list(interp_station_path: str) -> pd.DataFrame: 127 | """ 128 | Read the station metadata from the interpolation station list file, and return a dataset for interpolation. 
129 | 130 | Parameters 131 | ---------- 132 | interp_station_path: str: path to the interpolation station list 133 | 134 | Returns 135 | ------- 136 | pd.DataFrame: station metadata with columns station, lat, lon 137 | """ 138 | interp_station = pd.read_csv(interp_station_path) 139 | interp_station = interp_station.rename({c: c.lower() for c in interp_station.columns}, axis=1) 140 | interp_station['lon'] = interp_station['lon'].apply(lambda lon: lon if (lon < 180) else (lon - 360)) 141 | if 'station' not in interp_station.columns: 142 | interp_station['station'] = interp_station['id'] 143 | return interp_station 144 | -------------------------------------------------------------------------------- /evaluation/utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import logging 3 | import os 4 | from pathlib import Path 5 | from typing import Optional 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import xarray as xr 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | def configure_logging(verbose=1): 15 | verbose_levels = { 16 | 0: logging.WARNING, 17 | 1: logging.INFO, 18 | 2: logging.DEBUG, 19 | 3: logging.NOTSET 20 | } 21 | if verbose not in verbose_levels: 22 | verbose = 1 23 | logger.setLevel(verbose_levels[verbose]) 24 | handler = logging.StreamHandler() 25 | handler.setFormatter(logging.Formatter( 26 | "[%(asctime)s] [PID=%(process)d] " 27 | "[%(levelname)s %(filename)s:%(lineno)d] %(message)s")) 28 | handler.setLevel(verbose_levels[verbose]) 29 | logger.addHandler(handler) 30 | 31 | 32 | @dataclass 33 | class ForecastInfo: 34 | path: str 35 | forecast_name: str 36 | fc_var_name: str 37 | reformat_func: str 38 | file_type: str 39 | station_metadata_path: str 40 | interp_station_path: str 41 | output_directory: str 42 | start_date: Optional[pd.Timestamp] = None 43 | end_date: Optional[pd.Timestamp] = None 44 | issue_time_freq: Optional[str] = None 45 | start_lead: Optional[int] = None 46 | end_lead: Optional[int] = None 47 | convert_temperature: Optional[bool] = False 48 | convert_pressure: Optional[bool] = False 49 | convert_cloud: Optional[bool] = False 50 | precip_proba_threshold: Optional[float] = None 51 | metadata: Optional[pd.DataFrame] = None 52 | cache_path: Optional[str] = None 53 | 54 | 55 | @dataclass 56 | class ForecastData: 57 | info: ForecastInfo 58 | forecast: Optional[xr.Dataset] = None 59 | merge_data: Optional[xr.Dataset] = None 60 | 61 | 62 | @dataclass 63 | class MetricData: 64 | info: ForecastInfo 65 | metric_data: Optional[xr.Dataset] = None 66 | 67 | 68 | def get_metric_multiple_stations(files): 69 | data = {} 70 | files = files.split(',') 71 | for f in files: 72 | if os.path.exists(f): 73 | try: 74 | key = os.path.basename(f).replace(".csv", "") 75 | data[key] = [str(id) for id in pd.read_csv(f)['Station'].tolist()] 76 | except Exception: 77 | logger.error(f"Error opening {f}!") 78 | raise 79 | else: 80 | raise Warning(f'File {f} do not exist!') 81 | return data 82 | 83 | 84 | def generate_forecast_cache_path(info: ForecastInfo): 85 | if info.cache_path is not None: 86 | return info.cache_path 87 | file_name = Path(info.path).stem 88 | forecast_name = info.forecast_name 89 | fc_var_name = info.fc_var_name 90 | reformat_func = info.reformat_func 91 | interp_station = Path(info.interp_station_path).stem if info.interp_station_path is not None else 'default' 92 | cache_directory = os.path.join(info.output_directory, 'cache') 93 | os.makedirs(cache_directory, 
exist_ok=True) 94 | cache_file_name = "##".join([file_name, forecast_name, fc_var_name, reformat_func, interp_station, 'cache']) 95 | info.cache_path = os.path.join(cache_directory, cache_file_name) 96 | return info.cache_path 97 | 98 | 99 | def cache_reformat_forecast(forecast_ds, cache_path): 100 | logger.info(f"saving forecast to cache at {cache_path}") 101 | forecast_ds.to_zarr(cache_path, mode='w') 102 | 103 | 104 | def load_reformat_forecast(cache_path): 105 | forecast_ds = xr.open_zarr(cache_path) 106 | return forecast_ds 107 | 108 | 109 | def get_ideal_xticks(min_lead, max_lead, tick_count=8): 110 | """ 111 | Pick the best interval for the x axis (hours) that optimizes the number of ticks 112 | """ 113 | candidate_intervals = [1, 3, 6, 12, 24] 114 | tick_counts = [] 115 | for interval in candidate_intervals: 116 | num_ticks = (max_lead - min_lead) / interval 117 | tick_counts.append(abs(num_ticks - tick_count)) 118 | best_interval = candidate_intervals[tick_counts.index(min(tick_counts))] 119 | return np.arange(min_lead, max_lead + best_interval, best_interval) 120 | 121 | 122 | def convert_to_binary(da, threshold): 123 | result = xr.where(da >= threshold, np.float32(1.0), np.float32(0.0)) 124 | result = result.where(~np.isnan(da), np.nan) 125 | return result 126 | -------------------------------------------------------------------------------- /metric_config.yml: -------------------------------------------------------------------------------- 1 | temperature: 2 | base_plot_setting: 3 | title: Temperature 4 | xlabel: 5 | lead_time: Lead Hour (H) 6 | issue_time: Issue Time (UTC) 7 | valid_time: Valid Time (UTC) 8 | 9 | metrics: 10 | RMSE: 11 | method: rmse 12 | 13 | MAE: 14 | method: mae 15 | 16 | ACCURACY_1.7: 17 | method: accuracy 18 | threshold: 1.7 19 | 20 | ERROR_5.6: 21 | method: error 22 | threshold: 5.6 23 | 24 | cloud: 25 | base_plot_setting: 26 | title: Cloud 27 | xlabel: 28 | lead_time: Lead Hour (H) 29 | issue_time: Issue Time (UTC) 30 | valid_time: Valid Time (UTC) 31 | 32 | metrics: 33 | RMSE: 34 | method: rmse 35 | 36 | MAE: 37 | method: mae 38 | 39 | ACCURACY_2: 40 | method: accuracy 41 | threshold: 2 42 | 43 | ERROR_5: 44 | method: accuracy 45 | threshold: 5 46 | 47 | ETS_1p5: 48 | method: threat_score 49 | threshold: 1.5 50 | equitable: True 51 | 52 | ETS_6p5: 53 | method: threat_score 54 | threshold: 6.5 55 | equitable: True 56 | 57 | wind: 58 | base_plot_setting: 59 | title: wind 60 | xlabel: 61 | lead_time: Lead Hour (H) 62 | issue_time: Issue Time (UTC) 63 | valid_time: Valid Time (UTC) 64 | 65 | metrics: 66 | RMSE: 67 | method: rmse 68 | 69 | MAE: 70 | method: mae 71 | 72 | ACCURACY_1: 73 | method: accuracy 74 | threshold: 1 75 | 76 | ERROR_3: 77 | method: error 78 | threshold: 3 79 | 80 | specific_humidity: 81 | base_plot_setting: 82 | title: specific_humidity 83 | xlabel: 84 | lead_time: Lead Hour (H) 85 | issue_time: Issue Time (UTC) 86 | valid_time: Valid Time (UTC) 87 | 88 | metrics: 89 | RMSE: 90 | method: rmse 91 | 92 | MAE: 93 | method: mae 94 | 95 | precipitation: 96 | base_plot_setting: 97 | title: precipitation(mm) 98 | xlabel: 99 | lead_time: Lead Hour (H) 100 | issue_time: Issue Time (UTC) 101 | valid_time: Valid Time (UTC) 102 | 103 | metrics: 104 | RMSE: 105 | method: rmse 106 | 107 | ETS_0p1: 108 | method: threat_score 109 | threshold: 0.1 110 | equitable: True 111 | 112 | ETS_1: 113 | method: threat_score 114 | threshold: 1.0 115 | equitable: True 116 | 117 | POD_1: 118 | method: pod 119 | threshold: 1.0 120 | 121 | FAR_1: 122 | method: far 
123 | threshold: 1.0 124 | 125 | precip_proba: 126 | base_plot_setting: 127 | title: precipitation probability(%) 128 | xlabel: 129 | lead_time: Lead Hour (H) 130 | issue_time: Issue Time (UTC) 131 | valid_time: Valid Time (UTC) 132 | 133 | metrics: 134 | Brier_Score: 135 | method: brier 136 | 137 | Occurrence_Frequency(%): 138 | method: reliability 139 | bins: [[0, 0.05], [0.05, 0.15], [0.15, 0.25], [0.25, 0.35], [0.35, 0.45], [0.45, 0.55], [0.55, 0.65], [0.65, 0.75], [0.75, 0.85], [0.85, 0.95], [0.95, 1]] 140 | plot_setting: 141 | xlabel: 142 | lead_time: Bins of probability(%) 143 | 144 | ETS_0.1: 145 | method: threat_score 146 | threshold: 0.1 147 | equitable: True 148 | 149 | ETS_0.35: 150 | method: threat_score 151 | threshold: 0.35 152 | equitable: True 153 | 154 | ETS_0.4: 155 | method: threat_score 156 | threshold: 0.4 157 | equitable: True 158 | 159 | ETS_0.5: 160 | method: threat_score 161 | threshold: 0.5 162 | equitable: True 163 | 164 | precip_binary: 165 | base_plot_setting: 166 | title: precipitation binary 167 | xlabel: 168 | lead_time: Lead Hour (H) 169 | issue_time: Issue Time (UTC) 170 | valid_time: Valid Time (UTC) 171 | 172 | metrics: 173 | ETS: 174 | method: threat_score 175 | threshold: 0.5 176 | equitable: True 177 | 178 | POD: 179 | method: pod 180 | threshold: 0.5 181 | 182 | FAR: 183 | method: far 184 | threshold: 0.5 185 | 186 | Accuracy: 187 | method: accuracy 188 | threshold: 0.1 189 | 190 | pressure: 191 | base_plot_setting: 192 | title: Mean sea-level pressure 193 | xlabel: 194 | lead_time: Lead Hour (H) 195 | issue_time: Issue Time (UTC) 196 | valid_time: Valid Time (UTC) 197 | 198 | metrics: 199 | RMSE: 200 | method: rmse 201 | 202 | MAE: 203 | method: mae 204 | 205 | ACCURACY_2: 206 | method: accuracy 207 | threshold: 2 208 | 209 | ERROR_5: 210 | method: error 211 | threshold: 5 -------------------------------------------------------------------------------- /pylintrc: -------------------------------------------------------------------------------- 1 | [MAIN] 2 | 3 | max-line-length=120 4 | 5 | [MESSAGES CONTROL] 6 | 7 | disable= 8 | missing-module-docstring, 9 | missing-class-docstring, 10 | missing-function-docstring, 11 | missing-docstring, 12 | invalid-name, 13 | import-error, 14 | logging-fstring-interpolation, 15 | too-many-arguments, 16 | too-many-locals, 17 | too-many-branches, 18 | too-many-statements, 19 | too-many-return-statements, 20 | too-many-instance-attributes, 21 | too-many-positional-arguments, 22 | duplicate-code 23 | -------------------------------------------------------------------------------- /quality_control/README.md: -------------------------------------------------------------------------------- 1 | The quality control code in this directory is written in Python and is used by WeatherReal for downloading, post-processing, and quality control of ISD data. However, its modules can also be used for quality control of observation data from other sources. It includes four launch scripts: 2 | 3 | 1. `download_ISD.py`, used to download ISD data from the NCEI server; 4 | 2. `raw_ISD_to_hourly.py`, used to convert ISD data into hourly data; 5 | 3. `station_merging.py`, used to merge data from duplicate stations; 6 | 4. `quality_control.py`, used to perform quality control on hourly data. 7 | 8 | For the specific workflow, please refer to the WeatherReal paper. 
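
As noted above, the algorithm modules under `quality_control/algo` can also be applied directly to hourly station data from other sources. The sketch below is illustrative only: the file names are placeholders, it assumes the repository root is on `PYTHONPATH`, and it assumes the observations form an xarray Dataset with `station` and `time` dimensions, with ERA5 (or another reanalysis) interpolated to the same stations and times for the comparison-based checks.

```python
import xarray as xr

from quality_control.algo import cluster, distributional_gap

obs = xr.open_dataset("hourly_obs.nc")          # placeholder: hourly station observations
era5 = xr.open_dataset("era5_at_stations.nc")   # placeholder: reanalysis at the same stations/times

# DBSCAN-based cluster check on temperature ("t"), followed by the distributional
# gap check with its statistics computed only from values the cluster check kept.
cluster_flag = cluster.run(obs["t"], era5["t"], varname="t")
mask = cluster_flag != 2   # 2 is flag_error in quality_control/algo/config.yaml
gap_flag = distributional_gap.run(obs["t"], era5["t"], varname="t", mask=mask)
```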
-------------------------------------------------------------------------------- /quality_control/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/WeatherReal-Benchmark/68b2f9293d2a0a1b1cedf396cffda34c05a21a14/quality_control/__init__.py -------------------------------------------------------------------------------- /quality_control/algo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/WeatherReal-Benchmark/68b2f9293d2a0a1b1cedf396cffda34c05a21a14/quality_control/algo/__init__.py -------------------------------------------------------------------------------- /quality_control/algo/cluster.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.cluster import DBSCAN 3 | from .utils import intra_station_check, quality_control_statistics, get_config 4 | 5 | 6 | CONFIG = get_config() 7 | 8 | 9 | def _cluster_check(ts, reanalysis, min_samples_ratio, eps_scale, max_std_scale=None, min_num=None): 10 | """ 11 | Perform a cluster-based check on time series data compared to reanalysis data. 12 | 13 | This function uses DBSCAN (Density-Based Spatial Clustering of Applications with Noise) 14 | to identify clusters in the difference between the time series and reanalysis data. 15 | It then flags outliers based on cluster membership. 16 | 17 | Parameters: 18 | ----------- 19 | ts : array-like 20 | The time series data to be checked. 21 | reanalysis : array-like 22 | The corresponding reanalysis data for comparison. 23 | min_samples_ratio : float 24 | The ratio of minimum samples required to form a cluster in DBSCAN. 25 | eps_scale : float 26 | The scale factor for epsilon in DBSCAN, relative to the standard deviation of reanalysis. 27 | max_std_scale : float, optional 28 | If the ratio of standard deviations between ts and reanalysis is less than max_std_scale, 29 | the check is not performed 30 | min_num : int, optional 31 | Minimum number of valid data points required to perform the check. 
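Notes:
------
DBSCAN is fit on the one-dimensional difference series ts - reanalysis; eps is eps_scale times the standard deviation of the reanalysis and min_samples is min_samples_ratio times the number of jointly valid points. If more than one cluster is found, the cluster whose median difference from the reanalysis is smallest in absolute value is kept and all other points are flagged as erroneous.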
32 | 33 | Returns: 34 | -------- 35 | flag : np.ndarray 36 | 1D array with the same length as ts, containing flags 37 | """ 38 | flag = np.full(len(ts), CONFIG["flag_missing"], dtype=np.int8) 39 | isnan = np.isnan(ts) 40 | flag[~isnan] = CONFIG["flag_normal"] 41 | both_valid = (~isnan) & (~np.isnan(reanalysis)) 42 | 43 | if min_num is not None and both_valid.sum() < min_num: 44 | return flag 45 | 46 | if max_std_scale is not None: 47 | if np.std(ts[both_valid]) / np.std(reanalysis[both_valid]) <= max_std_scale: 48 | return flag 49 | 50 | indices = np.argwhere(both_valid).flatten() 51 | values1 = ts[indices] 52 | values2 = reanalysis[indices] 53 | 54 | cluster = DBSCAN(min_samples=int(indices.size * min_samples_ratio), eps=np.std(reanalysis) * eps_scale) 55 | labels = cluster.fit((values1 - values2).reshape(-1, 1)).labels_ 56 | 57 | # If all data points except noise are in the same cluster, just remove the noise 58 | if np.max(labels) <= 0: 59 | indices_outliers = indices[np.argwhere(labels == -1).flatten()] 60 | flag[indices_outliers] = CONFIG["flag_error"] 61 | # If there is more than one cluster, select the cluster nearest to the reanalysis 62 | else: 63 | cluster_labels = np.unique(labels) 64 | cluster_labels = list(cluster_labels >= 0) 65 | best_cluster = 0 66 | min_median = np.inf 67 | for label in cluster_labels: 68 | indices_cluster = indices[np.argwhere(labels == label).flatten()] 69 | median = np.median(ts[indices_cluster] - reanalysis[indices_cluster]) 70 | if np.abs(median) < min_median: 71 | best_cluster = label 72 | min_median = np.abs(median) 73 | indices_outliers = indices[np.argwhere(labels != best_cluster).flatten()] 74 | flag[indices_outliers] = CONFIG["flag_error"] 75 | 76 | return flag 77 | 78 | 79 | def run(da, reanalysis, varname): 80 | """ 81 | To ensure the accuracy of the parameters in the distributional gap method, a DBSCAN is used first, 82 | so that the median and MAD used for filtering are only calculated by the normal data 83 | """ 84 | flag = intra_station_check( 85 | da, 86 | reanalysis, 87 | qc_func=_cluster_check, 88 | input_core_dims=[["time"], ["time"]], 89 | kwargs=CONFIG["cluster"][varname], 90 | ) 91 | quality_control_statistics(da, flag) 92 | return flag.rename("cluster") 93 | -------------------------------------------------------------------------------- /quality_control/algo/config.yaml: -------------------------------------------------------------------------------- 1 | ############################################################ 2 | # The flag values for the data quality check 3 | ############################################################ 4 | 5 | flag_error: 2 6 | flag_suspect: 1 7 | flag_normal: 0 8 | flag_missing: -1 9 | 10 | ############################################################ 11 | # The parameters for the data quality check 12 | ############################################################ 13 | 14 | # record extreme check 15 | record: 16 | t: 17 | upper: 57.8 18 | lower: -89.2 19 | td: 20 | upper: 57.8 21 | lower: -100.0 22 | ws: 23 | upper: 113.2 24 | lower: 0 25 | wd: 26 | upper: 360 27 | lower: 0 28 | sp: 29 | upper: 1100.0 30 | lower: 300.0 31 | msl: 32 | upper: 1083.3 33 | lower: 870.0 34 | c: 35 | upper: 8 36 | lower: 0 37 | ra1: 38 | upper: 305.0 39 | lower: 0 40 | ra3: 41 | upper: 915.0 42 | lower: 0 43 | ra6: 44 | upper: 1144.0 45 | lower: 0 46 | ra12: 47 | upper: 1144.0 48 | lower: 0 49 | ra24: 50 | upper: 1825.0 51 | lower: 0 52 | 53 | # Persistence check 54 | persistence: 55 | defaults: &persistence_defaults 56 | 
min_num: 24 57 | max_window: 72 58 | min_var: 0.1 59 | error_length: 72 60 | t: 61 | <<: *persistence_defaults 62 | td: 63 | <<: *persistence_defaults 64 | sp: 65 | <<: *persistence_defaults 66 | msl: 67 | <<: *persistence_defaults 68 | ws: 69 | <<: *persistence_defaults 70 | exclude_value: 0 71 | wd: 72 | <<: *persistence_defaults 73 | exclude_value: 0 74 | c: 75 | <<: *persistence_defaults 76 | exclude_value: 77 | - 0 78 | - 8 79 | ra1: 80 | <<: *persistence_defaults 81 | exclude_value: 0 82 | ra3: 83 | <<: *persistence_defaults 84 | exclude_value: 0 85 | ra6: 86 | <<: *persistence_defaults 87 | exclude_value: 0 88 | ra12: 89 | <<: *persistence_defaults 90 | exclude_value: 0 91 | ra24: 92 | <<: *persistence_defaults 93 | exclude_value: 0 94 | 95 | # Spike check 96 | spike: 97 | t: 98 | max_change: 99 | - 6 100 | - 8 101 | - 10 102 | td: 103 | max_change: 104 | - 5 105 | - 7 106 | - 9 107 | sp: 108 | max_change: 109 | - 3 110 | - 5 111 | - 7 112 | msl: 113 | max_change: 114 | - 3 115 | - 5 116 | - 7 117 | 118 | # Distributional gap check with ERA5 119 | distribution: 120 | defaults: &distribution_defaults 121 | shift_step: 1 122 | gap_scale: 2 123 | default_mad: 1 124 | suspect_std_scale: 2.72 125 | min_num: 365 126 | t: 127 | <<: *distribution_defaults 128 | td: 129 | <<: *distribution_defaults 130 | sp: 131 | <<: *distribution_defaults 132 | default_mad: 0.5 133 | msl: 134 | <<: *distribution_defaults 135 | default_mad: 0.5 136 | 137 | # Cluster check 138 | cluster: 139 | defaults: &cluster_defaults 140 | min_samples_ratio: 0.1 141 | eps_scale: 2 142 | max_std_scale: 2 143 | min_num: 365 144 | t: 145 | <<: *cluster_defaults 146 | 147 | td: 148 | <<: *cluster_defaults 149 | sp: 150 | <<: *cluster_defaults 151 | msl: 152 | <<: *cluster_defaults 153 | 154 | # Neighbouring station check 155 | neighbouring: 156 | defaults: &neighbouring_defaults 157 | <<: *distribution_defaults 158 | max_dist: 300 159 | max_elev_diff: 500 160 | min_data_overlap: 0.3 161 | t: 162 | <<: *neighbouring_defaults 163 | td: 164 | <<: *neighbouring_defaults 165 | sp: 166 | <<: *neighbouring_defaults 167 | default_mad: 0.5 168 | msl: 169 | <<: *neighbouring_defaults 170 | default_mad: 0.5 171 | 172 | # Flag refinement 173 | refinement: 174 | t: 175 | check_monotonic: true 176 | check_ridge_trough: false 177 | td: 178 | check_monotonic: true 179 | check_ridge_trough: false 180 | sp: 181 | check_monotonic: true 182 | check_ridge_trough: true 183 | msl: 184 | check_monotonic: true 185 | check_ridge_trough: true 186 | 187 | diurnal: 188 | t: 189 | max_bias: 0.5 190 | -------------------------------------------------------------------------------- /quality_control/algo/cross_variable.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import xarray as xr 3 | from .utils import get_config 4 | 5 | 6 | CONFIG = get_config() 7 | 8 | 9 | def supersaturation(t, td): 10 | flag = xr.where(t < td, CONFIG["flag_error"], CONFIG["flag_normal"]) 11 | isnan = np.isnan(t) | np.isnan(td) 12 | flag = flag.where(~isnan, CONFIG["flag_missing"]) 13 | return flag.astype(np.int8) 14 | 15 | 16 | def wind_consistency(ws, wd): 17 | zero_ws = ws == 0 18 | zero_wd = wd == 0 19 | inconsistent = zero_ws != zero_wd 20 | flag = xr.where(inconsistent, CONFIG["flag_error"], CONFIG["flag_normal"]) 21 | isnan = np.isnan(ws) | np.isnan(wd) 22 | flag = flag.where(~isnan, CONFIG["flag_missing"]) 23 | return flag.astype(np.int8) 24 | 25 | 26 | def ra_consistency(ds): 27 | """ 28 | Precipitation in 
different period length is cross validated 29 | If there is a conflict (ra1 at 18:00 is 5mm, while ra3 at 19:00 is 0mm), 30 | both of them are flagged 31 | Parameters: 32 | ----------- 33 | ds: xarray dataset with precipitation data including ra1, ra3, ra6, ra12 and ra24 34 | Returns: 35 | -------- 36 | flag : numpy.ndarray 37 | 1D array with the same length as ts, containing flags for each value. 38 | """ 39 | flag = xr.full_like(ds, CONFIG["flag_normal"], dtype=np.int8) 40 | periods = [3, 6, 12, 24] 41 | for period in periods: 42 | da_longer = ds[f"ra{period}"] 43 | for shift in range(1, period): 44 | # Shift the precipitation in the longer period to align with the shorter period 45 | shifted = da_longer.roll(time=-shift) 46 | shifted[:, -shift:] = np.nan 47 | for target_period in [1, 3, 6, 12, 24]: 48 | # Check if the two periods are overlapping 49 | if target_period >= period or target_period + shift > period: 50 | continue 51 | # If the precipitation in the shorter period is larger than the longer period, flag both 52 | flag_indices = np.where(ds[f"ra{target_period}"].values - shifted.values > 0.101) 53 | if len(flag_indices[0]) == 0: 54 | continue 55 | flag[f"ra{target_period}"].values[flag_indices] = 1 56 | flag[f"ra{period}"].values[(flag_indices[0], flag_indices[1]+shift)] = 1 57 | flag = flag.where(ds.notnull(), CONFIG["flag_missing"]) 58 | return flag 59 | -------------------------------------------------------------------------------- /quality_control/algo/distributional_gap.py: -------------------------------------------------------------------------------- 1 | import xarray as xr 2 | from .utils import get_config, intra_station_check, quality_control_statistics 3 | from .time_series import _time_series_comparison 4 | 5 | 6 | CONFIG = get_config() 7 | 8 | 9 | def _distributional_gap( 10 | ts, 11 | reanalysis, 12 | mask, 13 | shift_step, 14 | gap_scale, 15 | default_mad, 16 | suspect_std_scale, 17 | min_mad=0.1, 18 | min_num=None, 19 | ): 20 | flag = _time_series_comparison( 21 | ts1=ts, 22 | ts2=reanalysis, 23 | shift_step=shift_step, 24 | gap_scale=gap_scale, 25 | default_mad=default_mad, 26 | suspect_std_scale=suspect_std_scale, 27 | min_mad=min_mad, 28 | min_num=min_num, 29 | mask=mask, 30 | ) 31 | return flag 32 | 33 | 34 | def run(da, reanalysis, varname, mask=None): 35 | """ 36 | Perform a distributional gap check on time series data compared to reanalysis data. 
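The comparison thresholds are taken from the 'distribution' section of algo/config.yaml for the given varname. The optional mask is intended to exclude values already flagged by an earlier check (e.g. the cluster check) from the median/MAD statistics used for filtering; the comparison itself is implemented in _time_series_comparison.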
37 | """ 38 | flag = intra_station_check( 39 | da, 40 | reanalysis, 41 | mask if mask is not None else xr.full_like(da, True, dtype=bool), 42 | qc_func=_distributional_gap, 43 | input_core_dims=[["time"], ["time"], ["time"]], 44 | kwargs=CONFIG["distribution"][varname], 45 | ) 46 | quality_control_statistics(da, flag) 47 | return flag.rename("distributional_gap") 48 | -------------------------------------------------------------------------------- /quality_control/algo/diurnal_cycle.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import xarray as xr 4 | import bottleneck as bn 5 | from suntime import Sun, SunTimeException 6 | from .utils import get_config, intra_station_check 7 | 8 | 9 | CONFIG = get_config() 10 | 11 | 12 | def _diurnal_cycle_check_daily(ts, lat, lon, date, max_bias): 13 | """ 14 | Fit the daily temperature time series with a sine curve 15 | The amplitude is estimated by the daily range of the time series while the phase is estimated by the sunrise time 16 | If the time series is well fitted by the sine curve, it is considered as normal 17 | Returns: 18 | -------- 19 | flag: int, a single flag indicating the result of the daily series 20 | """ 21 | if ts.size != 24: 22 | raise ValueError("The time series should have 24 values") 23 | # Only check stations between 60S and 60N with significant diurnal cycles 24 | if abs(lat) > 60: 25 | return CONFIG["flag_missing"] 26 | # Only check samples with at least 2 valid data points in each quartile of the day 27 | if (np.isfinite(ts.reshape(-1, 6)).sum(axis=1) < 2).any(): 28 | return CONFIG["flag_missing"] 29 | 30 | maxv = bn.nanmax(ts) 31 | minv = bn.nanmin(ts) 32 | amplitude = (maxv - minv) / 2 33 | 34 | # Only check samples with significant amplitude 35 | if amplitude < 5: 36 | return CONFIG["flag_missing"] 37 | 38 | timestep = pd.Timestamp(date) 39 | try: 40 | sunrise = Sun(float(lat), float(lon)).get_sunrise_time(timestep) 41 | except SunTimeException: 42 | return CONFIG["flag_missing"] 43 | 44 | # Normalize the time series by the max and min values 45 | normed = (ts - maxv + amplitude) / amplitude 46 | 47 | # Assume the diurnal cycle is a sine curve and the valley is 1H before the sunrise time 48 | shift = (sunrise.hour + sunrise.minute / 60) / 24 * 2 * np.pi - 20 / 12 * np.pi - timestep.hour / 12 * np.pi 49 | # A tolerance of 3 hours is allowed 50 | tolerance_values = np.arange(-np.pi / 4, np.pi / 4 + 1e-5, np.pi / 12) 51 | for tolerance in tolerance_values: 52 | # Try to find a best fitted phase of the sine curve 53 | sine_curve = np.sin((2 * np.pi / 24) * np.arange(24) - shift - tolerance) 54 | if bn.nanmax(np.abs(normed - sine_curve)) < max_bias: 55 | return CONFIG["flag_normal"] 56 | 57 | return CONFIG["flag_suspect"] 58 | 59 | 60 | def _diurnal_cycle_check(ts, flagged, lat, lon, dates, max_bias): 61 | """ 62 | Check the diurnal cycle of the temperature time series only for short period flagged by `flagged` 63 | Parameters: 64 | ----------- 65 | ts: 1D np.array, the daily temperature time series 66 | flagged: 1D np.array, the flag array from other algorithms 67 | lat: float, the latitude of the station 68 | lon: float, the longitude of the station 69 | date: 1D np.array of numpy.datetime64, the date of the time series 70 | max_bias: float, the maximum bias allowed for the sine curve fitting (suggested: 0.5-1) 71 | Returns: 72 | -------- 73 | new_flag: 1D np.array, only the checked days are flagged as either normal or suspect 74 | """ 75 
| new_flag = np.full_like(flagged, CONFIG["flag_missing"]) 76 | length = len(flagged) 77 | error_flags = np.argwhere((flagged == CONFIG["flag_error"]) | (flagged == CONFIG["flag_suspect"])).flatten() 78 | end_idx = 0 79 | for idx, start_idx in enumerate(error_flags): 80 | if start_idx <= end_idx: 81 | continue 82 | # Combine these short erroneous/suspect periods into a longer one 83 | end_idx = start_idx 84 | for next_idx in error_flags[idx+1:]: 85 | if ( 86 | (flagged[idx+1: next_idx] == CONFIG["flag_normal"]).any() or 87 | (flagged[idx+1: next_idx] == CONFIG["flag_suspect"]).any() 88 | ): 89 | break 90 | end_idx = next_idx 91 | period_length = end_idx - start_idx + 1 92 | if period_length > 12: 93 | continue 94 | # Select the daily series centered at the short erroneous/suspect period 95 | if length % 2 == 1: 96 | num_left, num_right = (24-period_length) // 2, (24-period_length) // 2 + 1 97 | else: 98 | num_left, num_right = (24-period_length) // 2, (24-period_length) // 2 99 | if (flagged[start_idx-num_left: start_idx] != CONFIG["flag_normal"]).all(): 100 | continue 101 | if (flagged[end_idx+1: end_idx+1+num_right] != CONFIG["flag_normal"]).all(): 102 | continue 103 | 104 | daily_start_idx = start_idx - num_left 105 | if daily_start_idx < 0: 106 | daily_start_idx = 0 107 | elif daily_start_idx > length - 24: 108 | daily_start_idx = length - 24 109 | daily_ts = ts[daily_start_idx: daily_start_idx + 24] 110 | daily_flag = _diurnal_cycle_check_daily(daily_ts, lat, lon, dates[daily_start_idx], max_bias=max_bias) 111 | new_flag[start_idx: end_idx+1] = daily_flag 112 | 113 | new_flag[np.isnan(ts)] = CONFIG["flag_missing"] 114 | return new_flag 115 | 116 | 117 | def adjust_by_diurnal(target_flag, diurnal_flag): 118 | """ 119 | Diurnal cycle check can be used for refine flags from other algorithms 120 | """ 121 | normal_diurnal = diurnal_flag == CONFIG["flag_normal"] 122 | error_target = target_flag == CONFIG["flag_error"] 123 | new_flag = target_flag.copy() 124 | new_flag = xr.where(normal_diurnal & error_target, CONFIG["flag_suspect"], new_flag) 125 | return new_flag 126 | 127 | 128 | def run(da, checked_flag): 129 | """ 130 | Perform a diurnal cycle check on Temperature time series 131 | Returns: 132 | -------- 133 | The adjusted flags 134 | """ 135 | flag = intra_station_check( 136 | da, 137 | checked_flag, 138 | da["lat"], 139 | da["lon"], 140 | da["time"], 141 | qc_func=_diurnal_cycle_check, 142 | input_core_dims=[["time"], ["time"], [], [], ["time"]], 143 | kwargs=CONFIG["diurnal"]["t"], 144 | ) 145 | new_flag = adjust_by_diurnal(checked_flag, flag) 146 | return new_flag.rename("diurnal_cycle") 147 | -------------------------------------------------------------------------------- /quality_control/algo/fine_tuning.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import bottleneck as bn 3 | import xarray as xr 4 | from .utils import get_config, intra_station_check 5 | 6 | 7 | CONFIG = get_config() 8 | 9 | 10 | def _suspect_upgrade(ts_flag): 11 | """ 12 | Upgrade suspect flag period if they are surrounded by error flags 13 | Returns: 14 | -------- 15 | new_flag: np.array, the upgraded flag array with only part of suspect flag upgraded to erroneous 16 | """ 17 | new_flag = ts_flag.copy() 18 | 19 | # Get the start and end indices of each continous suspect period 20 | is_suspect = (ts_flag == CONFIG["flag_suspect"]) | (ts_flag == CONFIG["flag_missing"]) 21 | is_suspect = np.insert(is_suspect, 0, False) 22 | is_suspect = 
np.append(is_suspect, False)
23 |     start = np.where(is_suspect & ~np.roll(is_suspect, 1))[0] - 1
24 |     end = np.where(is_suspect & ~np.roll(is_suspect, -1))[0] - 1
25 | 
26 |     # Filter out periods that contain only missing flags (keep those with at least one suspect flag)
27 |     periods = [item for item in list(zip(start, end)) if (ts_flag[item[0]: item[1]+1] == CONFIG["flag_suspect"]).any()]
28 | 
29 |     length = len(ts_flag)
30 |     for start_idx, end_idx in periods:
31 |         if (start_idx == 0) or (end_idx == length - 1):
32 |             continue
33 |         is_error_left = ts_flag[start_idx-1] == CONFIG["flag_error"]
34 |         is_error_right = ts_flag[end_idx+1] == CONFIG["flag_error"]
35 |         if is_error_left and is_error_right:
36 |             new_flag[start_idx:end_idx+1] = CONFIG["flag_error"]
37 | 
38 |     new_flag[ts_flag == CONFIG["flag_missing"]] = CONFIG["flag_missing"]
39 |     return new_flag
40 | 
41 | 
42 | def _upgrade_flags_window(flags, da):
43 |     """
44 |     Upgrade suspect flags based on the proportion of flagged points among valid data points in a sliding window
45 |     """
46 |     window_size = 720  # One month sliding window
47 |     threshold = 0.5  # More than half of the data points
48 | 
49 |     # Convert to numpy arrays for faster computation
50 |     flags_array = flags.values
51 |     data_array = da.values
52 | 
53 |     mask = (flags_array == CONFIG["flag_suspect"]) | (flags_array == CONFIG["flag_error"])
54 |     valid = ~np.isnan(data_array)
55 | 
56 |     # Calculate rolling sums
57 |     rolling_sum = bn.move_sum(mask.astype(np.float64), window=window_size, min_count=1, axis=1)
58 |     rolling_sum_valid = bn.move_sum(valid.astype(np.float64), window=window_size, min_count=1, axis=1)
59 |     rolling_sum_valid = np.where(rolling_sum_valid == 0, np.nan, rolling_sum_valid)
60 | 
61 |     # Calculate proportion of flagged points among valid data points
62 |     proportion_flagged = rolling_sum / rolling_sum_valid
63 | 
64 |     # Create a padded array for centered calculation
65 |     pad_width = window_size // 2
66 |     exceed_threshold = np.pad(proportion_flagged > threshold, ((0, 0), (pad_width, pad_width)), mode='edge')
67 | 
68 |     # Use move_max on the padded array
69 |     expanded_mask = bn.move_max(exceed_threshold.astype(np.float64), window=window_size, min_count=1, axis=1)
70 | 
71 |     # Remove padding
72 |     expanded_mask = expanded_mask[:, pad_width:-pad_width]
73 | 
74 |     # Upgrade suspect flags to error flags where the expanded mask is True
75 |     upgraded_flags = np.where(
76 |         (flags_array == CONFIG["flag_suspect"]) & (expanded_mask == 1),
77 |         CONFIG["flag_error"],
78 |         flags_array
79 |     )
80 | 
81 |     # Convert back to xarray DataArray
82 |     return xr.DataArray(upgraded_flags, coords=flags.coords, dims=flags.dims)
83 | 
84 | 
85 | def _upgrade_flags_all(flags, da):
86 |     """
87 |     If more than half of the data points at a station are flagged as erroneous,
88 |     all the data points at this station are flagged as erroneous
89 |     """
90 |     threshold = 0.5  # More than half of the data points
91 | 
92 |     mask = flags == CONFIG["flag_error"]
93 |     valid = da.notnull()
94 |     proportion_flagged = mask.sum(dim="time") / valid.sum(dim="time")
95 | 
96 |     upgraded_flags = flags.where(proportion_flagged < threshold, CONFIG["flag_error"])
97 |     upgraded_flags = upgraded_flags.where(valid, CONFIG["flag_missing"])
98 | 
99 |     return upgraded_flags
100 | 
101 | 
102 | def run(flag, da):
103 |     flag = intra_station_check(flag, qc_func=_suspect_upgrade)
104 |     flag = _upgrade_flags_window(flag, da)
105 |     flag = _upgrade_flags_all(flag, da)
106 |     return flag.rename("fine_tuning")
107 | 
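The three-stage fine-tuning above is easiest to see on a toy series. Below is a minimal, self-contained sketch of the first stage only (upgrading suspect runs that are bracketed by erroneous values), using placeholder flag codes (0 = normal, 1 = suspect, 2 = error, -1 = missing); the real codes come from `config.yaml` and the real implementation is `_suspect_upgrade` above.

```python
import numpy as np

FLAG_NORMAL, FLAG_SUSPECT, FLAG_ERROR, FLAG_MISSING = 0, 1, 2, -1  # placeholder codes

def upgrade_bracketed_suspects(ts_flag):
    """Upgrade suspect/missing runs to error when both neighbouring values are error flags."""
    new_flag = ts_flag.copy()
    in_run = (ts_flag == FLAG_SUSPECT) | (ts_flag == FLAG_MISSING)
    padded = np.concatenate([[False], in_run, [False]])
    starts = np.where(padded[1:-1] & ~padded[:-2])[0]
    ends = np.where(padded[1:-1] & ~padded[2:])[0]
    for start, end in zip(starts, ends):
        if start == 0 or end == len(ts_flag) - 1:
            continue  # runs touching the series boundary are left alone
        contains_suspect = (ts_flag[start:end + 1] == FLAG_SUSPECT).any()
        bracketed = ts_flag[start - 1] == FLAG_ERROR and ts_flag[end + 1] == FLAG_ERROR
        if contains_suspect and bracketed:
            new_flag[start:end + 1] = FLAG_ERROR
    new_flag[ts_flag == FLAG_MISSING] = FLAG_MISSING  # missing values stay missing
    return new_flag

flags = np.array([0, 2, 1, -1, 1, 2, 0, 1, 0])
print(upgrade_bracketed_suspects(flags))  # -> [ 0  2  2 -1  2  2  0  1  0]
```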
-------------------------------------------------------------------------------- /quality_control/algo/neighbouring_stations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import xarray as xr 3 | from .time_series import _time_series_comparison 4 | from .utils import get_config, quality_control_statistics 5 | 6 | 7 | CONFIG = get_config() 8 | 9 | 10 | def _select_neighbouring_stations(similarity, stn_data, lat, lon, data, max_dist, max_elev_diff, min_data_overlap): 11 | """ 12 | Select neighbouring stations based on distance, elevations and data overlap 13 | At most 8 stations will be selected, 2 at most for each direction (northwest, northeast, southwest, southeast) 14 | Parameters: 15 | similarity: xr.Dataset generated at the previous step with only the checked station 16 | stn_data: xr.DataArray with the data of the station 17 | lat: latitude of the station 18 | lon: longitude of the station 19 | data: xr.Dataset with the data of all stations 20 | max_dist: maximum distance in km 21 | max_elev_diff: maximum elevation difference in m 22 | min_data_overlap: minimum data overlap in fraction 23 | Return: 24 | candidates: list of neighbouring station names 25 | """ 26 | similar_position = similarity["dist"] >= np.exp(-max_dist / 25) 27 | similar_elevation = similarity["elev"] >= np.exp(-max_elev_diff / 100) 28 | candidates = similarity["target"][similar_position & similar_elevation].values 29 | candidates = [item for item in candidates if item != similarity["station"].item()] 30 | 31 | # Filter out stations with insufficient data overlap 32 | is_valid = stn_data.notnull() 33 | if not is_valid.any(): 34 | return [] 35 | is_valid_neighboring = data.sel(station=candidates).notnull() 36 | overlap = (is_valid & is_valid_neighboring).sum(dim="time") / is_valid.sum() 37 | candidates = overlap["station"].values[overlap >= min_data_overlap] 38 | 39 | similarity = similarity.sel(target=candidates).sortby("dist", ascending=False) 40 | lon_diff = similarity["lon"] - lon 41 | # In case of the longitudes cross the central meridian 42 | iswest = ((lon_diff < 0) & (np.abs(lon_diff) < 180)) | ((lon_diff > 0) & (np.abs(lon_diff) > 180)) 43 | isnorth = (similarity["lat"] - lat) > 0 44 | northwest_stations = similarity["target"][iswest & isnorth].values[:2] 45 | northeast_stations = similarity["target"][~iswest & isnorth].values[:2] 46 | southwest_stations = similarity["target"][iswest & ~isnorth].values[:2] 47 | southeast_stations = similarity["target"][~iswest & ~isnorth].values[:2] 48 | return np.concatenate([northwest_stations, northeast_stations, southwest_stations, southeast_stations]) 49 | 50 | 51 | def _load_similarity(fpath, dataset): 52 | """ 53 | Load the similarity calculated at the previous step 54 | """ 55 | similarity = xr.load_dataset(fpath) 56 | similarity = similarity.sel(station1=dataset["station"].values, station2=dataset["station"].values) 57 | similarity = similarity.assign_coords(lon=("station2", dataset["lon"].values)) 58 | similarity = similarity.assign_coords(lat=("station2", dataset["lat"].values)) 59 | similarity = similarity.rename({"station1": "station", "station2": "target"}) 60 | return similarity 61 | 62 | 63 | def _neighbouring_stations_check_base( 64 | flag, 65 | similarity, 66 | data, 67 | max_dist, 68 | max_elev_diff, 69 | min_data_overlap, 70 | shift_step, 71 | gap_scale, 72 | default_mad, 73 | suspect_std_scale, 74 | min_num, 75 | ): 76 | """ 77 | Check all stations in data by comparing them with neighbouring 
stations 78 | For each time step, when there are at least 3 neighbouring values available, 79 | and 2 / 3 of them are in agreement (either verified or dubious), save the result 80 | Parameters: 81 | ----------- 82 | flag: a initialized DataArray with the same station dimension as the data 83 | similarity: similarity matrix between stations used for selecting neighbouring stations 84 | data: DataArray to be checked 85 | max_dist/max_elev_diff/min_data_overlap: parameters for selecting neighbouring stations 86 | shift_step/gap_scale/default_mad/suspect_std_scale: parameters for the time series comparison 87 | min_num: minimum number of valid data points to perform the check 88 | Returns: 89 | -------- 90 | flag: Updated flag DataArray 91 | """ 92 | for station in flag["station"].values: 93 | neighbours = _select_neighbouring_stations( 94 | similarity.sel(station=station), 95 | data.sel(station=station), 96 | lat=data.sel(station=station)["lat"].item(), 97 | lon=data.sel(station=station)["lon"].item(), 98 | data=data, 99 | max_dist=max_dist, 100 | max_elev_diff=max_elev_diff, 101 | min_data_overlap=min_data_overlap, 102 | ) 103 | # Only apply the check to stations with at least 3 neighbouring stations 104 | if len(neighbours) < 3: 105 | continue 106 | results = [] 107 | 108 | for target in neighbours: 109 | ith_flag = _time_series_comparison( 110 | data.sel(station=station).values, 111 | data.sel(station=target).values, 112 | shift_step=shift_step, 113 | gap_scale=gap_scale, 114 | default_mad=default_mad, 115 | suspect_std_scale=suspect_std_scale, 116 | min_num=min_num, 117 | ) 118 | results.append(ith_flag) 119 | results = np.stack(results, axis=0) 120 | # For each time step, count the number of valid neighbouring data points 121 | num_neighbours = np.sum(results != CONFIG["flag_missing"], axis=0) 122 | # For each time step, still only consider the data points with at least 3 neighbouring data points 123 | num_neighbours = np.where(num_neighbours < 3, np.nan, num_neighbours) 124 | # Mask when 2 / 3 of the neighbouring data points are in agreement 125 | # Specifically, set erroneous only when there is no normal flags 126 | min_stn = num_neighbours * 2 / 3 127 | normal = (results == CONFIG["flag_normal"]).sum(axis=0) 128 | suspect = ((results == CONFIG["flag_suspect"]) | (results == CONFIG["flag_error"])).sum(axis=0) 129 | erroneous = (results == CONFIG["flag_error"]).sum(axis=0) 130 | aggregated = np.full_like(normal, CONFIG["flag_missing"]) 131 | aggregated = np.where(normal >= min_stn, CONFIG["flag_normal"], aggregated) 132 | aggregated = np.where(suspect >= min_stn, CONFIG["flag_suspect"], aggregated) 133 | aggregated = np.where((erroneous >= min_stn) & (normal == 0), CONFIG["flag_error"], aggregated) 134 | flag.loc[{"station": station}] = aggregated 135 | return flag 136 | 137 | 138 | def run(da, f_similarity, varname): 139 | """ 140 | Check the data by comparing with neighbouring stations 141 | Time series comparison is performed for each station with at least 3 neighbouring stations 142 | `map_blocks` is used to parallelize the calculation 143 | """ 144 | similarity = _load_similarity(f_similarity, da) 145 | 146 | flag = xr.DataArray( 147 | np.full(da.shape, CONFIG["flag_missing"], dtype=np.int8), 148 | dims=["station", "time"], 149 | coords={k: da.coords[k].values for k in da.dims} 150 | ) 151 | ret = xr.map_blocks( 152 | _neighbouring_stations_check_base, 153 | flag.chunk({"station": 500}), 154 | args=(similarity.chunk({"station": 500}), ), 155 | kwargs={"data": da, 
**CONFIG["neighbouring"][varname]}, 156 | template=flag.chunk({"station": 500}), 157 | ).compute(scheduler='processes') 158 | ret = ret.where(da.notnull(), CONFIG["flag_missing"]) 159 | quality_control_statistics(da, ret) 160 | return ret.rename("neighbouring_stations") 161 | -------------------------------------------------------------------------------- /quality_control/algo/persistence.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import bottleneck as bn 3 | from .utils import get_config, intra_station_check, quality_control_statistics 4 | 5 | 6 | CONFIG = get_config() 7 | 8 | 9 | def _persistence_main(ts, unset_flag, min_num, max_window, min_var, error_length, exclude_value=None): 10 | """ 11 | Perform persistence check on a time series. 12 | 13 | This function identifies periods of low variability in the time series and flags them as suspect or error. 14 | 15 | Parameters: 16 | ----------- 17 | ts (array-like): The input time series. 18 | unset_flag (array-like): Pre-existing flags for the time series. 19 | min_num (int): Minimum number of valid values required in a window. 20 | max_window (int): Maximum size of the moving window. 21 | min_var (float): Minimum allowed standard deviation in a window. 22 | error_length (int): Minimum length of a period to be flagged as error. 23 | exclude_value (float or list, optional): Value(s) to exclude from the analysis. 24 | 25 | Algorithm: 26 | ---------- 27 | 1. Scan the time series with moving windows of decreasing size (max_window to min_num). 28 | 2. Flag windows with standard deviation < min_var as suspect. 29 | 3. Merge overlapping suspect windows. 30 | 4. Flag merged windows as error if their length >= error_length. 31 | 5. Flag remaining suspect windows based on their length and pre-existing flags. 32 | 33 | Returns: 34 | -------- 35 | numpy.ndarray: 1D array of flags with the same length as the input time series. 
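Before the implementation, here is a self-contained sketch of steps 1-3 of the algorithm summarised above: scan with shrinking trailing windows via `bottleneck.move_std` (already a dependency of this module) and merge every low-variance window into a boolean mask. The window sizes and `min_var` value are illustrative defaults, not the per-variable settings from `config.yaml`.

```python
import numpy as np
import bottleneck as bn

def low_variability_mask(ts, min_num=6, max_window=24, min_var=0.1):
    """Mark every index covered by at least one moving window whose std falls below min_var."""
    ts = np.asarray(ts, dtype=np.float64)  # bottleneck's moving windows work best on float64
    flagged = np.zeros(ts.size, dtype=bool)
    for window in range(max_window, min_num - 1, -1):
        # move_std yields NaN until `min_count` valid values are inside the trailing window
        std = bn.move_std(ts, window=window, min_count=min_num)
        for end in np.flatnonzero(std < min_var):
            start = max(0, end - window + 1)  # guard the partial windows at the series head
            flagged[start: end + 1] = True    # overlapping windows merge naturally
    return flagged

rng = np.random.default_rng(0)
series = np.concatenate([rng.normal(10, 2, 48), np.full(30, 10.0), rng.normal(10, 2, 48)])
flagged = low_variability_mask(series)
print(flagged[48:78].all(), flagged[:40].any())  # True False: the stuck segment is caught, normal noise is not
```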
36 | """ 37 | flag = np.full(ts.size, CONFIG["flag_missing"], dtype=np.int8) 38 | suspect_windows = [] 39 | 40 | # Bottleneck for >100x faster moving window implementations but it only works well with float64 41 | ts = ts.astype(np.float64) 42 | if exclude_value is not None: 43 | if isinstance(exclude_value, list): 44 | for value in exclude_value: 45 | ts[ts == value] = np.nan 46 | else: 47 | ts[ts == exclude_value] = np.nan 48 | 49 | for window_size in range(max_window, min_num-1, -1): 50 | # min_count will ensure that the std is calculated only when there are enough valid values 51 | std = bn.move_std(ts, window=window_size, min_count=min_num)[window_size-1:] 52 | valid_indices = np.argwhere(~np.isnan(std)).flatten() 53 | for shift in range(window_size): 54 | # Values checked in at least one window are considered as valid 55 | flag[valid_indices + shift] = CONFIG["flag_normal"] 56 | error_index = np.argwhere(std < min_var).flatten() 57 | if error_index.size == 0: 58 | continue 59 | suspect_windows.extend([(i, i+window_size) for i in error_index]) 60 | 61 | isvalid = ~np.isnan(ts) 62 | if len(suspect_windows) == 0: 63 | flag[~isvalid] = CONFIG["flag_missing"] 64 | return flag 65 | 66 | # trim the NaNs at both ends of each window 67 | for idx, (start, end) in enumerate(suspect_windows): 68 | start = start + np.argmax(isvalid[start:end]) 69 | end = end - np.argmax(isvalid[start:end][::-1]) 70 | suspect_windows[idx] = (start, end) 71 | 72 | # Combine the overlapping windows 73 | suspect_windows.sort(key=lambda x: x[0]) 74 | suspect_windows_merged = [suspect_windows[0]] 75 | for current in suspect_windows[1:]: 76 | last = suspect_windows_merged[-1] 77 | if current[0] < last[1]: 78 | suspect_windows_merged[-1] = (last[0], max(last[1], current[1])) 79 | else: 80 | suspect_windows_merged.append(current) 81 | 82 | for start, end in suspect_windows_merged: 83 | if end - start >= error_length: 84 | flag[start:end] = CONFIG["flag_error"] 85 | else: 86 | # Set error flag only when more than 5% values are flagged by other methods 87 | num_suspend = (unset_flag[start:end] == CONFIG["flag_suspect"]).sum() 88 | num_error = (unset_flag[start:end] == CONFIG["flag_error"]).sum() 89 | num_valid = isvalid[start:end].sum() 90 | if num_suspend + num_error > num_valid * 0.05: 91 | flag[start:end] = CONFIG["flag_error"] 92 | else: 93 | flag[start:end] = CONFIG["flag_suspect"] 94 | 95 | flag[~isvalid] = CONFIG["flag_missing"] 96 | return flag 97 | 98 | 99 | def run(da, unset_flag, varname): 100 | flag = intra_station_check( 101 | da, 102 | unset_flag, 103 | qc_func=_persistence_main, 104 | input_core_dims=[["time"], ["time"]], 105 | kwargs=CONFIG["persistence"][varname], 106 | ) 107 | quality_control_statistics(da, flag) 108 | return flag.rename("persistence") 109 | -------------------------------------------------------------------------------- /quality_control/algo/record_extreme.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .utils import get_config, intra_station_check, quality_control_statistics 3 | 4 | 5 | CONFIG = get_config() 6 | 7 | 8 | def _record_extreme_main(ts, upper, lower): 9 | """ 10 | Flag outliers as erroneous based on the record extremes. 
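Several of the `run()` wrappers in this package (persistence above, and the record-extreme check below) push a plain 1-D checker through `intra_station_check`, which is built on `xarray.apply_ufunc`. The sketch below shows that dispatch pattern in isolation, with a made-up checker, toy data and placeholder flag codes; unlike the real wrapper it skips the Dask chunking.

```python
import numpy as np
import xarray as xr

FLAG_NORMAL, FLAG_ERROR, FLAG_MISSING = 0, 2, -1  # placeholder codes

def _toy_check(ts, upper):
    """A 1-D checker with the same call contract as the qc_func callables in this package."""
    flag = np.full(ts.shape, FLAG_NORMAL, dtype=np.int8)
    flag[ts > upper] = FLAG_ERROR
    flag[np.isnan(ts)] = FLAG_MISSING
    return flag

da = xr.DataArray(
    np.array([[10.0, 12.0, 99.0], [np.nan, 11.0, 13.0]]),
    dims=["station", "time"],
)
flag = xr.apply_ufunc(
    _toy_check,
    da,
    input_core_dims=[["time"]],   # each call receives one station's full time series
    output_core_dims=[["time"]],
    vectorize=True,               # loop over the remaining (station) dimension
    kwargs={"upper": 50.0},
    output_dtypes=[np.int8],
)
print(flag.values)  # [[ 0  0  2] [-1  0  0]]
```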
11 | 12 | Parameters: 13 | ----------- 14 | ts : np.ndarray 15 | 1D time series to be checked 16 | upper : float 17 | Upper bound of the record extreme 18 | lower : float 19 | Lower bound of the record extreme 20 | 21 | Returns: 22 | -------- 23 | flag : np.ndarray 24 | 1D array with the same length as ts, containing flags 25 | """ 26 | flag_upper = ts > upper 27 | flag_lower = ts < lower 28 | flag = np.full(ts.shape, CONFIG["flag_normal"], dtype=np.int8) 29 | flag[np.logical_or(flag_upper, flag_lower)] = CONFIG["flag_error"] 30 | flag[np.isnan(ts)] = CONFIG["flag_missing"] 31 | return flag 32 | 33 | 34 | def run(da, varname): 35 | flag = intra_station_check( 36 | da, 37 | qc_func=_record_extreme_main, 38 | kwargs=CONFIG["record"][varname], 39 | ) 40 | quality_control_statistics(da, flag) 41 | return flag.rename("record_extreme") 42 | -------------------------------------------------------------------------------- /quality_control/algo/refinement.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .utils import get_config, intra_station_check 3 | 4 | 5 | CONFIG = get_config() 6 | 7 | 8 | def _is_monotonic(ts, start_idx, end_idx): 9 | # Search the left and right nearest non-error values 10 | for left_idx in range(start_idx-1, max(0, start_idx-4), -1): 11 | # Note that here the found value is bound to be non-error 12 | if np.isfinite(ts[left_idx]): 13 | break 14 | else: 15 | left_idx = None 16 | for right_idx in range(end_idx, min(end_idx+3, len(ts))): 17 | if np.isfinite(ts[right_idx]): 18 | break 19 | else: 20 | right_idx = None 21 | # If these values are monotonic, downgrading the error flag to suspect 22 | if left_idx is not None and right_idx is not None: 23 | ts_period = ts[left_idx: right_idx + 1] 24 | ts_period = ts_period[~np.isnan(ts_period)] 25 | diff = np.diff(ts_period) 26 | if (diff > 0).all() or (diff < 0).all(): 27 | return True 28 | return False 29 | 30 | 31 | def _is_ridge_or_trough(ts, start_idx, end_idx): 32 | # Search the left and right nearest 3 non-error values 33 | left_values = ts[max(0, start_idx-12): max(0, start_idx)] 34 | left_values = left_values[np.isfinite(left_values)] 35 | if len(left_values) < 4: 36 | return False 37 | left_values = left_values[-4:] 38 | right_values = ts[min(len(ts), end_idx): min(len(ts), end_idx+12)] 39 | right_values = right_values[np.isfinite(right_values)] 40 | if len(right_values) < 4: 41 | return False 42 | right_values = right_values[:4] 43 | 44 | ts_period = ts[start_idx: end_idx] 45 | ts_period = np.concatenate([left_values, ts_period[np.isfinite(ts_period)], right_values]) 46 | min_idx = np.argmin(ts_period) 47 | if 3 < min_idx < len(ts_period) - 4: 48 | diff = np.diff(ts_period) 49 | # Check if it is a trough (e.g., low pressure) 50 | if (diff[:min_idx] < 0).all() and (diff[min_idx:] > 0).all(): 51 | return True 52 | max_idx = np.argmax(ts_period) 53 | if 3 < max_idx < len(ts_period) - 4: 54 | diff = np.diff(ts_period) 55 | # Check if it is a ridge (e.g., temperature peak) 56 | if (diff[:max_idx] > 0).all() and (diff[max_idx:] < 0).all(): 57 | return True 58 | return False 59 | 60 | 61 | def _refine_flag(cur_flag, ts, check_monotonic, check_ridge_trough): 62 | """ 63 | Refine the flags from other algorithms: 64 | check_monotonic: If True, downgrade error flags to suspect if they are 65 | situated in a monotonic period in-between non-error values 66 | check_ridge_trough: If True, downgrade error flags to suspect if they are 67 | situated in a ridge or trough 
in-between non-error values 68 | Returns: 69 | -------- 70 | new_flag: The refined flags 71 | """ 72 | new_flag = cur_flag.copy() 73 | if not check_monotonic and not check_ridge_trough: 74 | return new_flag 75 | length = len(cur_flag) 76 | error_flags = np.argwhere(cur_flag == CONFIG["flag_error"]).flatten() 77 | cur_idx = 0 78 | for idx in error_flags: 79 | if cur_idx > idx: 80 | continue 81 | cur_idx = idx + 1 82 | num_nan = 0 83 | num_error = 1 84 | # Search for the next non-error value 85 | while num_nan <= 2 and cur_idx < length: 86 | # If there are more than 3 consecutive missing values or error flags, do nothing 87 | if np.isnan(ts[cur_idx]): 88 | num_nan += 1 89 | cur_idx += 1 90 | continue 91 | if cur_flag[cur_idx] == CONFIG["flag_error"]: 92 | num_nan = 0 93 | num_error += 1 94 | cur_idx += 1 95 | continue 96 | # If a non-error value is found, check if it is monotonic 97 | if num_error > 3: 98 | break 99 | if check_monotonic and _is_monotonic(ts, idx, cur_idx): 100 | new_flag[idx:cur_idx] = CONFIG["flag_suspect"] 101 | if check_ridge_trough and _is_ridge_or_trough(ts, idx, cur_idx): 102 | new_flag[idx:cur_idx] = CONFIG["flag_suspect"] 103 | break 104 | new_flag[np.isnan(ts)] = CONFIG["flag_missing"] 105 | return new_flag 106 | 107 | 108 | def run(da, flag, varname): 109 | flag = intra_station_check( 110 | flag, 111 | da, 112 | qc_func=_refine_flag, 113 | input_core_dims=[["time"], ["time"]], 114 | output_core_dims=[["time"]], 115 | kwargs={ 116 | "check_monotonic": CONFIG["refinement"][varname]["check_monotonic"], 117 | "check_ridge_trough": CONFIG["refinement"][varname]["check_ridge_trough"] 118 | }, 119 | ) 120 | return flag.rename("refinement") 121 | -------------------------------------------------------------------------------- /quality_control/algo/spike.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import bottleneck as bn 3 | from .utils import get_config, intra_station_check, quality_control_statistics 4 | 5 | 6 | CONFIG = get_config() 7 | 8 | 9 | def _spike_check_forward(ts, unset_flag, max_change): 10 | """ 11 | Perform a forward spike check on a time series. 12 | 13 | A spike is defined as: 14 | 1. An abrupt change in value 15 | 2. Followed by an abrupt change back 16 | 3. With no abrupt changes during the spike 17 | 18 | This function flags values as suspect at both ends of abrupt changes and all values within a spike. 19 | Due to the high sensitivity of spike detection, if all values in a spike are at least suspect 20 | in unset_flag, they will be flagged as errors. 21 | 22 | Parameters: 23 | ----------- 24 | ts : array-like 25 | The input time series to check for spikes. 26 | unset_flag : array-like 27 | Reference flags from other methods. 28 | max_change : array-like 29 | 1D array with 3 elements, representing the maximum allowed change for 1, 2, and 3 steps. 30 | 31 | Returns: 32 | -------- 33 | flag : numpy.ndarray 34 | 1D array with the same length as ts, containing flags for each value. 
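As a much-reduced illustration of the spike definition above, the sketch below flags only single-point spikes: a jump larger than `max_change` immediately followed by an opposite jump of similar size. It uses placeholder flag codes and ignores the multi-step search, NaN gaps and bidirectional pass implemented in this module.

```python
import numpy as np

FLAG_NORMAL, FLAG_SUSPECT, FLAG_MISSING = 0, 1, -1  # placeholder codes

def flag_single_point_spikes(ts, max_change):
    """Flag isolated points that jump away by more than max_change and immediately jump back."""
    ts = np.asarray(ts, dtype=float)
    step_in = np.diff(ts, prepend=ts[0])    # change from the previous value
    step_out = np.diff(ts, append=ts[-1])   # change to the next value
    is_spike = (
        (np.abs(step_in) > max_change)
        & (np.abs(step_out) > max_change)
        & (np.sign(step_in) != np.sign(step_out))  # the series comes back towards where it was
    )
    flag = np.full(ts.size, FLAG_NORMAL, dtype=np.int8)
    flag[is_spike] = FLAG_SUSPECT
    flag[np.isnan(ts)] = FLAG_MISSING
    return flag

temperature = np.array([10.0, 10.5, 25.0, 11.0, 11.2])
print(flag_single_point_spikes(temperature, max_change=6.0))  # -> [0 0 1 0 0]
```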
35 | """ 36 | length = len(ts) 37 | flag = np.full(length, CONFIG["flag_missing"], dtype=np.int8) 38 | isnan = np.isnan(ts).astype(int) 39 | indices = [] 40 | gaps = [] 41 | # Get all the indices of potential spikes 42 | for step in range(1, 4, 1): 43 | diff = np.abs(ts[step:] - ts[:-step]) 44 | # Flag the values where the variation is checked 45 | flag[step:] = np.where( 46 | ~np.isnan(diff) & (flag[step:] == CONFIG["flag_missing"]), 47 | CONFIG["flag_normal"], 48 | flag[step:], 49 | ) 50 | condition = diff > max_change[step - 1] 51 | if step > 1: 52 | # For step larger than 1, only conditions that the in-between values are all NaN are considered 53 | allnan = bn.move_sum(isnan, window=step - 1)[step - 1: -1] == step - 1 54 | condition = condition & allnan 55 | indices_step = np.argwhere(condition).flatten() + step 56 | # flag both sides of the abrupt change as suspect 57 | flag[indices_step] = CONFIG["flag_suspect"] 58 | flag[indices_step - step] = CONFIG["flag_suspect"] 59 | # Save the gap between abnormal variations to be used later 60 | gaps.extend(np.full(len(indices_step), step)) 61 | indices.extend(indices_step) 62 | # Combine the potential spikes 63 | sorted_indices = np.argsort(indices) 64 | indices = np.array(indices)[sorted_indices] 65 | gaps = np.array(gaps)[sorted_indices] 66 | 67 | cur_idx = 0 68 | # Iterate all potential spikes 69 | for case_idx, start_idx in enumerate(indices): 70 | # To avoid checking on the end of the spike 71 | if start_idx <= cur_idx: 72 | continue 73 | # The value before the spike 74 | leftv = ts[start_idx - gaps[case_idx]] 75 | # The newest value in the spike 76 | lastv = ts[start_idx] 77 | # A threshold for detecting the end of a spike 78 | return_diff = np.abs((lastv - leftv) / 2) 79 | # The direction of the variation, positive(negative) for decreasing(increasing) 80 | change_sign = np.sign(leftv - lastv) 81 | cur_idx = start_idx + 1 82 | num_nan = 0 83 | # Start to search for the end of the spike. 
The spike should be shorter than 72 steps 84 | while cur_idx < length and cur_idx - start_idx < 72: 85 | if np.isnan(ts[cur_idx]): 86 | num_nan += 1 87 | cur_idx += 1 88 | # Consider 3 continuous NaNs as end of a spike 89 | if num_nan >= 3: 90 | # In experimental tests, if the value changes drastically and then stop recording, 91 | # this single value is considered as a spike 92 | if cur_idx - start_idx <= 4: 93 | break 94 | # Else, it is not considered as a spike 95 | cur_idx = start_idx - 1 96 | break 97 | continue 98 | num_nan = 0 99 | 100 | isabrupt = np.abs(ts[cur_idx] - lastv) > max_change[num_nan] 101 | isopposite = (lastv - ts[cur_idx]) * change_sign < 0 102 | isnear = np.abs(ts[cur_idx] - leftv) <= return_diff 103 | if not isabrupt: 104 | # if the value changes back slowly, it is considered as normal variation 105 | if isnear: 106 | cur_idx = start_idx - 1 107 | break 108 | # if there is no abrupt change, and the value is still far from the original value 109 | # continue searching for the end of the spike 110 | lastv = ts[cur_idx] 111 | cur_idx += 1 112 | continue 113 | # if there is an abrupt change, stop searching 114 | # If the value goes back with an opposite abrupt change, it is considered as the end of the spike 115 | if isopposite and isnear: 116 | break 117 | # Else, skip this case to avoid too complex situations 118 | cur_idx = start_idx - 1 119 | break 120 | else: 121 | cur_idx = start_idx - 1 122 | # Only flag the spike as erroneous when all the values are at least suspect 123 | if (unset_flag[start_idx:cur_idx] != CONFIG["flag_normal"]).all(): 124 | flag[start_idx:cur_idx] = CONFIG["flag_error"] 125 | else: 126 | flag[start_idx:cur_idx] = CONFIG["flag_suspect"] 127 | flag[isnan.astype(bool)] = CONFIG["flag_missing"] 128 | return flag 129 | 130 | 131 | def _bidirectional_spike_check(ts, unset_flag, max_change): 132 | """ 133 | Perform bidirectional spike check on the time series so that when there is an abrupt change, 134 | both directions will be checked 135 | The combined result is the higher flag of the two directions 136 | """ 137 | flag_forward = _spike_check_forward(ts, unset_flag, max_change) 138 | flag_backward = _spike_check_forward(ts[::-1], unset_flag[::-1], max_change)[::-1] 139 | flag = np.full(len(ts), CONFIG["flag_missing"], dtype=np.int8) 140 | for flag_type in ["normal", "suspect", "error"]: 141 | for direction in [flag_forward, flag_backward]: 142 | flag[direction == CONFIG[f"flag_{flag_type}"]] = CONFIG[f"flag_{flag_type}"] 143 | return flag 144 | 145 | 146 | def run(da, unset_flag, varname): 147 | flag = intra_station_check( 148 | da, 149 | unset_flag, 150 | qc_func=_bidirectional_spike_check, 151 | input_core_dims=[["time"], ["time"]], 152 | kwargs=CONFIG["spike"][varname], 153 | ) 154 | quality_control_statistics(da, flag) 155 | return flag.rename("spike") 156 | -------------------------------------------------------------------------------- /quality_control/algo/time_series.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import stats 3 | from .utils import get_config 4 | 5 | 6 | CONFIG = get_config() 7 | 8 | 9 | def _get_mean_and_mad(values): 10 | mean = np.median(values) 11 | mad = stats.median_abs_deviation(values) 12 | return mean, mad 13 | 14 | 15 | def _time_series_comparison( 16 | ts1, 17 | ts2, 18 | shift_step, 19 | gap_scale, 20 | default_mad, 21 | suspect_std_scale, 22 | min_mad=0.1, 23 | min_num=None, 24 | mask=None, 25 | ): 26 | """ 27 | Perform time series 
comparison between two datasets and flag potential errors. 28 | 29 | This function compares two time series (ts1 and ts2) and identifies suspect and erroneous values 30 | based on their differences. It uses a robust statistical approach to handle outliers and 31 | can accommodate temporal shifts between the series. 32 | 33 | Parameters: 34 | ----------- 35 | ts1 : array-like 36 | The primary time series to be checked. 37 | ts2 : array-like 38 | The reference time series for comparison. 39 | shift_step : int 40 | The maximum number of time steps to shift ts2 when comparing with ts1. 41 | gap_scale : float 42 | The scale factor applied to the median absolute deviation (MAD) to determine the gap threshold. 43 | default_mad : float 44 | The default MAD value to use when the calculated MAD is too small or when there are insufficient data points. 45 | suspect_std_scale : float 46 | The number of standard deviations from the mean to set the initial suspect threshold. 47 | min_mad : float, optional 48 | The minimum MAD value to calculate standard deviation, default is 0.1. 49 | min_num : int, optional 50 | The minimum number of valid data points required for robust statistics calculation. 51 | If the number of valid points is less than this, default values are used. 52 | mask : array-like, optional 53 | Boolean mask to select a subset of ts1 for calculating bounds. If None, all values are used. 54 | True for normal values, False for outliers 55 | 56 | Returns: 57 | -------- 58 | flag : numpy.ndarray 59 | 1D array with the same length as ts, containing flags for each value. 60 | """ 61 | diff = ts1 - ts2 62 | values = diff[~np.isnan(diff)] 63 | 64 | # Apply mask to diff if provided 65 | if mask is not None: 66 | masked_diff = diff[mask] 67 | values = masked_diff[~np.isnan(masked_diff)] 68 | else: 69 | values = diff[~np.isnan(diff)] 70 | 71 | if values.size == 0: 72 | return np.full(diff.size, CONFIG["flag_missing"], dtype=np.int8) 73 | 74 | if min_num is not None and values.size < min_num: 75 | fixed_mean = 0 76 | mad = default_mad 77 | else: 78 | # An estimate of the Gaussian distribution of the data which is calculated by the median and MAD 79 | # so that it is robust to outliers 80 | fixed_mean, mad = _get_mean_and_mad(values) 81 | mad = max(min_mad, mad) 82 | 83 | # Get the suspect threshold by the distance to the mean in the unit of standard deviation 84 | # Reference: y = 0.1, scale = 1.67; y = 0.05, scale = 2.04; y = 0.01, scale = 2.72 85 | # If the standard deviation estimated by MAD is too small, a default value is used 86 | fixed_std = max(default_mad, mad) * 1.4826 87 | init_upper_bound = fixed_mean + fixed_std * suspect_std_scale 88 | # For observations that the actual precision is integer, the upper and lower bounds are rounded up 89 | is_integer = np.nanmax(ts1 % 1) < 0.1 or np.nanmax(ts2 % 1) < 0.1 90 | if is_integer: 91 | init_upper_bound = np.ceil(init_upper_bound) 92 | # Set the erroneous threshold by find a gap larger than a multiple of the MAD 93 | larger_values = np.insert(np.sort(values[values > init_upper_bound]), 0, init_upper_bound) 94 | # Try to get the index of first value where the gap larger than min_gap 95 | gap = mad * gap_scale 96 | large_gap = np.diff(larger_values) > gap 97 | # If a gap is not found, no erroneous threshold is set 98 | upper_bound = larger_values[np.argmax(large_gap)] if large_gap.any() else np.max(values) 99 | if is_integer: 100 | upper_bound = np.ceil(upper_bound) 101 | 102 | init_lower_bound = fixed_mean - fixed_std * suspect_std_scale 103 | if 
is_integer: 104 | init_lower_bound = np.floor(init_lower_bound) 105 | smaller_values = np.insert(np.sort(values[values < init_lower_bound])[::-1], 0, init_lower_bound) 106 | small_gap = np.diff(smaller_values) < -gap 107 | lower_bound = smaller_values[np.argmax(small_gap)] if small_gap.any() else np.min(values) 108 | if is_integer: 109 | lower_bound = np.floor(lower_bound) 110 | 111 | min_diff = diff 112 | # If shift_step > 0, the values in ts1 are also compared to neighboring values in ts2 113 | # The minimum difference is kept to be inclusive of a certain degree of temporal deviation 114 | for shift in np.arange(-shift_step, shift_step+1, 1): 115 | diff_shifted = ts1 - np.roll(ts2, shift) 116 | if shift == 0: 117 | continue 118 | if shift > 0: 119 | diff_shifted[:shift] = np.nan 120 | elif shift < 0: 121 | diff_shifted[shift:] = np.nan 122 | min_diff = np.where(np.abs(diff_shifted - fixed_mean) < np.abs(min_diff - fixed_mean), diff_shifted, min_diff) 123 | 124 | flag = np.full_like(min_diff, CONFIG["flag_normal"], dtype=np.int8) 125 | flag[min_diff < init_lower_bound] = CONFIG["flag_suspect"] 126 | flag[min_diff > init_upper_bound] = CONFIG["flag_suspect"] 127 | flag[min_diff < lower_bound] = CONFIG["flag_error"] 128 | flag[min_diff > upper_bound] = CONFIG["flag_error"] 129 | flag[np.isnan(min_diff)] = CONFIG["flag_missing"] 130 | return flag 131 | -------------------------------------------------------------------------------- /quality_control/algo/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import numpy as np 4 | import xarray as xr 5 | import yaml 6 | 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def configure_logging(verbose=1): 12 | verbose_levels = { 13 | 0: logging.WARNING, 14 | 1: logging.INFO, 15 | 2: logging.DEBUG, 16 | 3: logging.NOTSET 17 | } 18 | if verbose not in verbose_levels: 19 | verbose = 1 20 | root_logger = logging.getLogger() 21 | root_logger.setLevel(verbose_levels[verbose]) 22 | handler = logging.StreamHandler() 23 | handler.setFormatter(logging.Formatter( 24 | "[%(asctime)s] [PID=%(process)d] " 25 | "[%(levelname)s %(filename)s:%(lineno)d] %(message)s")) 26 | handler.setLevel(verbose_levels[verbose]) 27 | root_logger.addHandler(handler) 28 | 29 | 30 | class Config: 31 | _instance = None 32 | 33 | def __new__(cls, config_path=None): 34 | if cls._instance is None: 35 | if config_path is None: 36 | config_path = os.path.join(os.path.dirname(__file__), "config.yaml") 37 | cls._instance = super(Config, cls).__new__(cls) 38 | cls._instance.__init__(config_path) 39 | return cls._instance 40 | 41 | def __init__(self, config_path): 42 | if not hasattr(self, 'config'): 43 | self._load_config(config_path) 44 | 45 | def _load_config(self, config_path): 46 | with open(config_path, "r", encoding="utf-8") as file: 47 | self.config = yaml.safe_load(file) 48 | 49 | def get(self, key): 50 | if key not in self.config: 51 | raise KeyError(f"Key '{key}' not found in configuration") 52 | return self.config[key] 53 | 54 | 55 | def get_config(config_path=None): 56 | return Config(config_path).config 57 | 58 | 59 | def intra_station_check( 60 | *dataarrays, 61 | qc_func=lambda da: da, 62 | input_core_dims=None, 63 | output_core_dims=None, 64 | kwargs=None, 65 | ): 66 | """ 67 | A wrapper function to apply the quality control functions to each station in a DataArray 68 | Multiprocessing is implemented by Dask 69 | Parameters: 70 | dataarrays: DataArrays to be checked, with more auxiliary 
DataArrays if needed 71 | qc_func: quality control function to be applied 72 | input_core_dims: core dimensions (to be remained for the function) of each DataArray 73 | output_core_dims: core dimensions (to be remained from the function) of each function result 74 | kwargs: keyword arguments for the quality control function 75 | Return: 76 | flag: DataArray with the same shape as the first input DataArray 77 | """ 78 | dataarrays_chunked = [] 79 | for item in dataarrays: 80 | if isinstance(item, xr.DataArray): 81 | dataarrays_chunked.append( 82 | item.chunk({k: 100 if k == "station" else -1 for k in item.dims}) 83 | ) 84 | else: 85 | dataarrays_chunked.append(item) 86 | if input_core_dims is None: 87 | input_core_dims = [["time"]] 88 | if output_core_dims is None: 89 | output_core_dims = [["time"]] 90 | if kwargs is None: 91 | kwargs = {} 92 | flag = xr.apply_ufunc( 93 | qc_func, 94 | *dataarrays_chunked, 95 | input_core_dims=input_core_dims, 96 | output_core_dims=output_core_dims, 97 | kwargs=kwargs, 98 | vectorize=True, 99 | dask="parallelized", 100 | output_dtypes=[np.int8], 101 | ).compute(scheduler="processes") 102 | return flag 103 | 104 | 105 | CONFIG = get_config() 106 | 107 | 108 | def merge_flags(flags, priority=None): 109 | """ 110 | Merge flags from different quality control functions in the order of priority 111 | Prior flags will be overwritten by subsequent flags 112 | """ 113 | ret = xr.full_like(flags[0], CONFIG["flag_missing"], dtype=np.int8) 114 | if priority is None: 115 | priority = ["normal", "suspect", "error"] 116 | for flag_type in priority: 117 | for item in flags: 118 | ret = xr.where(item == CONFIG[f"flag_{flag_type}"], CONFIG[f"flag_{flag_type}"], ret) 119 | return ret 120 | 121 | 122 | def quality_control_statistics(data, flag): 123 | num_valid = data.notnull().sum().item() 124 | num_normal = (flag == CONFIG["flag_normal"]).sum().item() 125 | num_suspect = (flag == CONFIG["flag_suspect"]).sum().item() 126 | num_error = (flag == CONFIG["flag_error"]).sum().item() 127 | num_checked = num_normal + num_suspect + num_error 128 | logger.debug(f"{num_valid / data.size:.5%} of the data are valid") 129 | logger.debug(f"{num_checked / num_valid:.5%} of the valid data are checked") 130 | logger.debug( 131 | "%s/%s/%s of the checked data are flagged as normal/suspect/erroneous", 132 | f"{num_normal / num_checked:.5%}", 133 | f"{num_suspect / num_checked:.5%}", 134 | f"{num_error / num_checked:.5%}" 135 | ) 136 | return num_valid, num_normal, num_suspect, num_error 137 | -------------------------------------------------------------------------------- /quality_control/download_ISD.py: -------------------------------------------------------------------------------- 1 | """ 2 | Download ISD data from NOAA 3 | All available stations on the server will be downloaded 4 | """ 5 | 6 | import argparse 7 | import logging 8 | import multiprocessing 9 | from functools import partial 10 | from pathlib import Path 11 | 12 | import pandas as pd 13 | import requests 14 | from bs4 import BeautifulSoup 15 | from tqdm import tqdm 16 | from algo.utils import configure_logging 17 | 18 | 19 | URL_DATA = "https://www.ncei.noaa.gov/data/global-hourly/access/" 20 | 21 | # Disable the logging from urllib3 22 | logging.getLogger("urllib3.connectionpool").setLevel(logging.CRITICAL) 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | def download_file(station, url, output_dir, overwrite=False): 27 | response = requests.get(url + f"{station}.csv", timeout=30) 28 | output_path = Path(output_dir) / 
f"{station}.csv" 29 | if output_path.exists() and not overwrite: 30 | return True 31 | if response.status_code == 200: 32 | with open(output_path, "wb") as file: 33 | file.write(response.content) 34 | return True 35 | return False 36 | 37 | 38 | def download_ISD(year, station_list, output_dir, num_proc, overwrite=False): 39 | """ 40 | Download the data from the source for the stations in station_list. 41 | """ 42 | logger.info("Start downloading") 43 | output_dir = Path(output_dir) / str(year) 44 | if not output_dir.exists(): 45 | output_dir.mkdir(parents=True, exist_ok=True) 46 | 47 | url = URL_DATA + f"/{year}/" 48 | 49 | if num_proc == 1: 50 | for item in tqdm(station_list, desc="Downloading files"): 51 | download_file(item, url, output_dir, overwrite) 52 | else: 53 | func = partial(download_file, url=url, output_dir=output_dir, overwrite=overwrite) 54 | with multiprocessing.Pool(num_proc) as pool: 55 | results = list(tqdm(pool.imap(func, station_list), total=len(station_list), desc="Downloading files")) 56 | 57 | successful_downloads = sum(results) 58 | logger.info(f"Successfully downloaded {successful_downloads} out of {len(station_list)} files") 59 | 60 | 61 | def main(): 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument( 64 | "-o", "--output-dir", type=str, required=True, help="Parent directory to save the downloaded data" 65 | ) 66 | parser.add_argument("-y", "--year", type=int, default=pd.Timestamp.now().year, help="Year to download") 67 | parser.add_argument("--overwrite", action="store_true", help="Overwrite existing files") 68 | parser.add_argument("--verbose", type=int, default=1, help="Verbosity level (int >= 0)") 69 | parser.add_argument("--num-proc", type=int, default=16, help="Number of parallel processes") 70 | args = parser.parse_args() 71 | configure_logging(args.verbose) 72 | 73 | response = requests.get(URL_DATA + f"/{args.year}/", timeout=30) 74 | soup = BeautifulSoup(response.text, "html.parser") 75 | file_list = [link.get("href") for link in soup.find_all("a") if link.get("href").endswith(".csv")] 76 | station_list = [item.split(".")[0] for item in file_list] 77 | logger.info(f"Found {len(station_list)} stations for year {args.year}") 78 | download_ISD(args.year, station_list, args.output_dir, args.num_proc, args.overwrite) 79 | 80 | 81 | if __name__ == "__main__": 82 | main() 83 | -------------------------------------------------------------------------------- /quality_control/quality_control.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script performs quality control checks on weather observation data using various algorithms. 3 | It flags and filters out erroneous or suspicious data points. 4 | 5 | Usage: 6 | python quality_control.py --obs-path OBS_PATH --rnl-path RNL_PATH --similarity-path SIM_PATH \ 7 | --output-path OUTPUT_PATH [--config-path CONFIG_PATH] [--verbose VERBOSE] 8 | 9 | Input Files: 10 | - Observation data (--obs-path): NetCDF file generated at the previous step by `station_merging.py`. 11 | You can also use your own observation data. The dimension should be 'station' and 'time' and 12 | include all the hours in a same year, 13 | e.g., the dimension of WeatherReal-ISD in 2023: {'station': 13297, 'time': 8760} 14 | Supported variables: t, td, sp, msl, c, ws, wd, ra1, ra3, ra6, ra12, ra24. 15 | - Reanalysis data (--rnl-path): NetCDF file containing reanalysis data in the same format as obs. 
16 | You can refer to `convert_grid_to_point` in `evaluation/forecast_reformat_catalog.py` to interpolations. 17 | - Similarity matrix (--similarity-path): File containing station similarity information. 18 | It is also generated at the previous step by `station_merging.py`. 19 | You can also generate your own similarity matrix with `calc_similarity` in `station_merging.py`. 20 | - Config file (--config-path): YAML file containing algorithm parameters. 21 | You can refer to the default settings - `config.yaml` in the `algo` directory as an example. 22 | Output: 23 | - Quality controlled observation data (--output-path): NetCDF file with erroneous data removed. 24 | 25 | Please refer to the WeatherReal paper for more details. 26 | """ 27 | 28 | 29 | import argparse 30 | import logging 31 | import os 32 | import numpy as np 33 | import pandas as pd 34 | import xarray as xr 35 | from algo import ( 36 | record_extreme, 37 | cluster, 38 | distributional_gap, 39 | neighbouring_stations, 40 | spike, 41 | persistence, 42 | cross_variable, 43 | refinement, 44 | diurnal_cycle, 45 | fine_tuning, 46 | ) 47 | from algo.utils import merge_flags, Config, get_config, configure_logging 48 | 49 | 50 | logger = logging.getLogger(__name__) 51 | CONFIG = get_config() 52 | 53 | 54 | def load_data(obs_path, rnl_path): 55 | """ 56 | Load observation and reanalysis data 57 | The reanalysis data should be in the same format as observation data (interpolated to the same stations) 58 | """ 59 | logger.info("Loading observation and reanalysis data") 60 | obs = xr.load_dataset(obs_path) 61 | year = obs["time"].dt.year.values[0] 62 | if not (obs["time"].dt.year == year).all(): 63 | raise ValueError("Data contains multiple years, which is not supported yet") 64 | 65 | full_hours = pd.date_range( 66 | start=pd.Timestamp(f"{year}-01-01"), end=pd.Timestamp(f"{year}-12-31 23:00:00"), freq='h') 67 | if obs["time"].size != full_hours.size or (obs["time"] != full_hours).any(): 68 | logger.warning("Reindexing observation data to match full hours in the year") 69 | obs = obs.reindex(time=full_hours) 70 | 71 | # Please prepare the Reanalysis data in the same format as obs 72 | rnl = xr.load_dataset(rnl_path) 73 | if obs.sizes != rnl.sizes: 74 | raise ValueError("The sizes of obs and rnl are different") 75 | return obs, rnl 76 | 77 | 78 | def cross_variable_check(obs): 79 | flag_cross = xr.Dataset() 80 | # Super-saturation check 81 | flag_cross["t"] = cross_variable.supersaturation(obs["t"], obs["td"]) 82 | flag_cross["td"] = flag_cross["t"].copy() 83 | # Wind consistency check 84 | flag_cross["ws"] = cross_variable.wind_consistency(obs["ws"], obs["wd"]) 85 | flag_cross["wd"] = flag_cross["ws"].copy() 86 | flag_ra = cross_variable.ra_consistency(obs[["ra1", "ra3", "ra6", "ra12", "ra24"]]) 87 | for varname in ["ra1", "ra3", "ra6", "ra12", "ra24"]: 88 | flag_cross[varname] = flag_ra[varname].copy() 89 | return flag_cross 90 | 91 | 92 | def quality_control(obs, rnl, f_similarity): 93 | varlist = obs.data_vars.keys() 94 | result = obs.copy() 95 | 96 | # Record extreme check 97 | flag_extreme = xr.Dataset() 98 | for varname in CONFIG["record"]: 99 | if varname not in varlist: 100 | continue 101 | logger.info(f"Record extreme check for {varname}...") 102 | flag_extreme[varname] = record_extreme.run(result[varname], varname) 103 | # For extreme value check, outliers are directly removed 104 | result[varname] = result[varname].where(flag_extreme[varname] != CONFIG["flag_error"]) 105 | 106 | # Cluster deviation check 107 | flag_cluster = 
xr.Dataset() 108 | for varname in CONFIG["cluster"]: 109 | if varname not in varlist: 110 | continue 111 | logger.info(f"Cluster deviation check for {varname}...") 112 | flag_cluster[varname] = cluster.run(result[varname], rnl[varname], varname) 113 | 114 | # Distributional gap check 115 | flag_distribution = xr.Dataset() 116 | for varname in CONFIG["distribution"]: 117 | if varname not in varlist: 118 | continue 119 | logger.info(f"Distributional gap check for {varname}...") 120 | # Mask from cluster deviation check is used to exclude abnormal data in the following distributional gap check 121 | mask = flag_cluster[varname] == CONFIG["flag_normal"] 122 | flag_distribution[varname] = distributional_gap.run(result[varname], rnl[varname], varname, mask) 123 | 124 | # Neighbouring station check 125 | flag_neighbour = xr.Dataset() 126 | for varname in CONFIG["neighbouring"]: 127 | if varname not in varlist: 128 | continue 129 | logger.info(f"Neighbouring station check for {varname}...") 130 | flag_neighbour[varname] = neighbouring_stations.run(result[varname], f_similarity, varname) 131 | 132 | # Spike check 133 | flag_dis_neigh = xr.Dataset() 134 | flag_spike = xr.Dataset() 135 | for varname in CONFIG["spike"]: 136 | if varname not in varlist: 137 | continue 138 | logger.info(f"Spike check for {varname}...") 139 | # Merge flags from distributional gap and neighbouring station check for spike and persistence check 140 | flag_dis_neigh[varname] = merge_flags( 141 | [flag_distribution[varname], flag_neighbour[varname]], priority=["error", "suspect", "normal"] 142 | ) 143 | flag_spike[varname] = spike.run(result[varname], flag_dis_neigh[varname], varname) 144 | 145 | # Persistence check 146 | flag_persistence = xr.Dataset() 147 | for varname in CONFIG["persistence"]: 148 | if varname not in varlist: 149 | continue 150 | logger.info(f"Persistence check for {varname}...") 151 | # Some variables are not checked by distributional gap or neighbouring station check 152 | if varname not in flag_dis_neigh: 153 | flag_dis_neigh_cur = xr.full_like(result[varname], CONFIG["flag_missing"], dtype=np.int8) 154 | else: 155 | flag_dis_neigh_cur = flag_dis_neigh[varname] 156 | flag_persistence[varname] = persistence.run(result[varname], flag_dis_neigh_cur, varname) 157 | 158 | # Cross variable check 159 | flag_cross = cross_variable_check(result) 160 | 161 | # Merge all flags 162 | flags = xr.Dataset() 163 | for varname in varlist: 164 | logger.info(f"Merging flags for {varname}...") 165 | merge_list = [ 166 | item[varname] 167 | for item in [flag_extreme, flag_dis_neigh, flag_spike, flag_persistence, flag_cross] 168 | if varname in item 169 | ] 170 | flags[varname] = merge_flags(merge_list, priority=["normal", "suspect", "error"]) 171 | 172 | # Flag refinement 173 | flags_refined = xr.Dataset() 174 | for varname in varlist: 175 | if varname not in CONFIG["refinement"].keys(): 176 | flags_refined[varname] = flags[varname].copy() 177 | else: 178 | logger.info(f"Flag refinement for {varname}...") 179 | flags_refined[varname] = refinement.run(result[varname], flags[varname], varname) 180 | if varname == "t": 181 | flags_refined[varname] = diurnal_cycle.run(result[varname], flags_refined[varname]) 182 | 183 | # Fine-tuning the flags 184 | flags_final = xr.Dataset() 185 | for varname in varlist: 186 | logger.info(f"Fine-tuning (upgrade suspect) flags for {varname}...") 187 | flags_final[varname] = fine_tuning.run(flags_refined[varname], obs[varname]) 188 | 189 | flags = { 190 | "flag_extreme": flag_extreme, 191 | 
"flag_cluster": flag_cluster, 192 | "flag_distribution": flag_distribution, 193 | "flag_neighbour": flag_neighbour, 194 | "flag_spike": flag_spike, 195 | "flag_persistence": flag_persistence, 196 | "flag_cross": flag_cross, 197 | "flags_refined": flags_refined, 198 | "flags_final": flags_final, 199 | } 200 | 201 | return flags 202 | 203 | 204 | def main(args): 205 | obs, rnl = load_data(args.obs_path, args.rnl_path) 206 | flags = quality_control(obs, rnl, args.similarity_path) 207 | if args.output_flags_dir: 208 | logger.info(f"Saving flags to {args.output_flags_dir}") 209 | for algo_name, flag_spec in flags.items(): 210 | flag_spec.to_netcdf(os.path.join(args.output_flags_dir, f"{algo_name}.nc")) 211 | flags_final = flags["flags_final"] 212 | for varname in obs.data_vars.keys(): 213 | obs[varname] = obs[varname].where(flags_final[varname] != CONFIG["flag_error"]) 214 | obs.to_netcdf(args.output_path) 215 | logger.info(f"Quality control finished. The results are saved to {args.output_path}") 216 | 217 | 218 | if __name__ == "__main__": 219 | parser = argparse.ArgumentParser() 220 | parser.add_argument( 221 | "--obs-path", type=str, required=True, help="Data path of observation to be quality controlled" 222 | ) 223 | parser.add_argument( 224 | "--rnl-path", type=str, required=True, help="Data path of reanalysis data to be used for quality control" 225 | ) 226 | parser.add_argument("--similarity-path", type=str, required=True, help="Data path of similarity matrix") 227 | parser.add_argument("--output-path", type=str, required=True, help="Data path of output data") 228 | parser.add_argument("--output-flags-dir", type=str, help="If specified, flags will also be saved") 229 | parser.add_argument( 230 | "--config-path", type=str, help="Path to the configuration file, default is config.yaml in the algo directory" 231 | ) 232 | parser.add_argument("--verbose", type=int, default=1, help="Verbosity level (int >= 0)") 233 | parsed_args = parser.parse_args() 234 | configure_logging(parsed_args.verbose) 235 | Config(parsed_args.config_path) 236 | main(parsed_args) 237 | -------------------------------------------------------------------------------- /quality_control/raw_ISD_to_hourly.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script is used for parsing ISD files and converting them to hourly data 3 | 1. Parse corresponding columns of each variable 4 | 2. Aggregate / reorganize columns if needed 5 | 3. Simple unit conversion 6 | 4. Aggregate to hourly data. Rows that represent same hour are merged by rules 7 | 5. No quality control is applied except removing records with original erroneous flags 8 | The outputs are still saved as csv files 9 | """ 10 | 11 | import argparse 12 | import multiprocessing 13 | import os 14 | from functools import partial 15 | 16 | import pandas as pd 17 | from tqdm import tqdm 18 | 19 | 20 | # Quality code for erroneous values flagged by ISD 21 | # To collect as much data as possible and avoid some values flagged by unknown codes being excluded, 22 | # We choose the strategy "only values marked with known error tags will be rejected" 23 | # instead of "only values marked with known correct tags will be accepted" 24 | ERRONEOUS_FLAGS = ["3", "7"] 25 | 26 | 27 | def parse_temperature_col(data): 28 | """ 29 | Process temperature and dew point temperature columns 30 | TMP/DEW column format: -0100,1 31 | Steps: 32 | 1. Set values flagged as erroneous/missing to NaN 33 | 2. 
Convert to float in Celsius 34 | """ 35 | if "TMP" in data.columns and data["TMP"].notnull().any(): 36 | data[["t", "t_qc"]] = data["TMP"].str.split(",", expand=True) 37 | data["t"] = data["t"].where(data["t"] != "+9999", pd.NA) 38 | data["t"] = data["t"].where(~data["t_qc"].isin(ERRONEOUS_FLAGS), pd.NA) 39 | # Scaling factor: 10 40 | data["t"] = data["t"].astype("Float32") / 10 41 | if "DEW" in data.columns and data["DEW"].notnull().any(): 42 | data[["td", "td_qc"]] = data["DEW"].str.split(",", expand=True) 43 | data["td"] = data["td"].where(data["td"] != "+9999", pd.NA).astype("Float32") / 10 44 | data["td"] = data["td"].where(~data["td_qc"].isin(ERRONEOUS_FLAGS), pd.NA) 45 | data = data.drop(columns=["TMP", "DEW", "t_qc", "td_qc"], errors="ignore") 46 | return data 47 | 48 | 49 | def parse_wind_col(data): 50 | """ 51 | Process wind speed and direction column 52 | WND column format: 267,1,N,0142,1 53 | N indicates normal (other values include Beaufort, Calm, etc.). Not used currently 54 | Steps: 55 | 1. Set values flagged as erroneous/missing to NaN 56 | Note that if one of ws or wd is missing, both are set to NaN 57 | Exception: If ws is 0 and wd is missing, wd is set to 0 (calm) 58 | 2. Convert wd to integer and ws to float in m/s 59 | """ 60 | if "WND" in data.columns and data["WND"].notnull().any(): 61 | data[["wd", "wd_qc", "wt", "ws", "ws_qc"]] = data["WND"].str.split(",", expand=True) 62 | # First, set wd to 0 if ws is valid 0 and wd is missing 63 | calm = (data["ws"] == "0000") & (~data["ws_qc"].isin(ERRONEOUS_FLAGS)) & (data["wd"] == "999") 64 | data.loc[calm, "wd"] = "000" 65 | data.loc[calm, "wd_qc"] = "1" 66 | # After that, if one of ws or wd is missing/erroneous, both are set to NaN 67 | non_missing = (data["wd"] != "999") & (data["ws"] != "9999") 68 | non_error = (~data["wd_qc"].isin(ERRONEOUS_FLAGS)) & (~data["ws_qc"].isin(ERRONEOUS_FLAGS)) 69 | valid = non_missing & non_error 70 | data["wd"] = data["wd"].where(valid, pd.NA) 71 | data["ws"] = data["ws"].where(valid, pd.NA) 72 | data["wd"] = data["wd"].astype("Int16") 73 | # Scaling factor: 10 74 | data["ws"] = data["ws"].astype("Float32") / 10 75 | data = data.drop(columns=["WND", "wd_qc", "wt", "ws_qc"], errors="ignore") 76 | return data 77 | 78 | 79 | def parse_cloud_col(data): 80 | """ 81 | Process total cloud cover column 82 | All known columns including GA1-6, GD1-6, GF1 and GG1-6 are parsed and Maximum value of them is selected 83 | 1. GA1-6 column format: 07,1,+00800,1,06,1 84 | The 1st and 2nd items are c and its quality 85 | 2. GD1-6 column format: 3,99,1,+05182,9,9 86 | The 1st item is cloud cover in 0-4 and is converted to octas by multiplying 2 87 | The 2st and 3nd items are c in octas and its quality 88 | 3. GF1 column format: 07,99,1,07,1,99,9,01000,1,99,9,99,9 89 | The 1st and 3rd items are total coverage and its quality 90 | 4. GG1-6 column format: 01,1,01200,1,06,1,99,9 91 | The 1st and 2nd items are c and its quality 92 | Cloud/sky-condition related data is very complex and worth further investigation 93 | Steps: 94 | 1. Set values flagged as erroneous/missing to NaN 95 | 2. 
Select the maximum value of all columns 96 | """ 97 | num = 0 98 | for group in ["GA", "GG"]: 99 | for col in [f"{group}{i}" for i in range(1, 7)]: 100 | if col in data.columns and data[col].notnull().any(): 101 | data[[f"c{num}", "c_qc", "remain"]] = data[col].str.split(",", n=2, expand=True) 102 | # 99 will be removed later 103 | data[f"c{num}"] = data[f"c{num}"].where(~data["c_qc"].isin(ERRONEOUS_FLAGS), pd.NA) 104 | data[f"c{num}"] = data[f"c{num}"].astype("Int16") 105 | num += 1 106 | else: 107 | break 108 | for col in [f"GD{i}" for i in range(1, 7)]: 109 | if col in data.columns and data[col].notnull().any(): 110 | data[[f"c{num}", f"c{num+1}", "c_qc", "remain"]] = data[col].str.split(",", n=3, expand=True) 111 | c_cols = [f"c{num}", f"c{num+1}"] 112 | data[c_cols] = data[c_cols].where(~data["c_qc"].isin(ERRONEOUS_FLAGS), pd.NA) 113 | data[c_cols] = data[c_cols].astype("Int16") 114 | # The first item is 5-level cloud cover and is converted to octas by multiplying 2 115 | data[f"c{num}"] = data[f"c{num}"] * 2 116 | num += 2 117 | else: 118 | break 119 | if "GF1" in data.columns and data["GF1"].notnull().any(): 120 | data[[f"c{num}", "opa", "c_qc", "remain"]] = data["GF1"].str.split(",", n=3, expand=True) 121 | data[f"c{num}"] = data[f"c{num}"].where(~data["c_qc"].isin(ERRONEOUS_FLAGS), pd.NA) 122 | data[f"c{num}"] = data[f"c{num}"].astype("Int16") 123 | num += 1 124 | c_cols = [f"c{i}" for i in range(num)] 125 | # Mask all values larger than 8 to NaN to avoid overwriting the correct values 126 | data[c_cols] = data[c_cols].where(data[c_cols] <= 8, pd.NA) 127 | # Maximum value of all columns is selected to represent the total cloud cover 128 | data["c"] = data[c_cols].max(axis=1) 129 | data = data.drop( 130 | columns=[ 131 | "GF1", 132 | *[f"GA{i}" for i in range(1, 7)], 133 | *[f"GG{i}" for i in range(1, 7)], 134 | *[f"GD{i}" for i in range(1, 7)], 135 | *[f"c{i}" for i in range(num)], 136 | "c_5", 137 | "opa", 138 | "c_qc", 139 | "remain", 140 | ], 141 | errors="ignore", 142 | ) 143 | return data 144 | 145 | 146 | def parse_surface_pressure_col(data): 147 | """ 148 | Process surface pressure (station-level pressure) column 149 | Currently MA1 column is used. Column format: 99999,9,09713,1 150 | The 3rd and 4th items are station pressure and its quality 151 | The 1st and 2nd items are altimeter setting and its quality which are not used currently 152 | Steps: 153 | 1. Set values flagged as erroneous/missing to NaN 154 | 2. Convert to float in hPa 155 | """ 156 | if "MA1" in data.columns and data["MA1"].notnull().any(): 157 | data[["MA1_remain", "sp", "sp_qc"]] = data["MA1"].str.rsplit(",", n=2, expand=True) 158 | data["sp"] = data["sp"].where(data["sp"] != "99999", pd.NA) 159 | data["sp"] = data["sp"].where(~data["sp_qc"].isin(ERRONEOUS_FLAGS), pd.NA) 160 | # Scaling factor: 10 161 | data["sp"] = data["sp"].astype("Float32") / 10 162 | data = data.drop(columns=["MA1", "MA1_remain", "sp_qc"], errors="ignore") 163 | return data 164 | 165 | 166 | def parse_sea_level_pressure_col(data): 167 | """ 168 | Process mean sea level pressure column 169 | MSL Column format: 09725,1 170 | Steps: 171 | 1. Set values flagged as erroneous/missing to NaN 172 | 2. 
146 | def parse_surface_pressure_col(data):
147 |     """
148 |     Process surface pressure (station-level pressure) column
149 |     Currently the MA1 column is used. Column format: 99999,9,09713,1
150 |     The 3rd and 4th items are station pressure and its quality
151 |     The 1st and 2nd items are altimeter setting and its quality, which are not used currently
152 |     Steps:
153 |     1. Set values flagged as erroneous/missing to NaN
154 |     2. Convert to float in hPa
155 |     """
156 |     if "MA1" in data.columns and data["MA1"].notnull().any():
157 |         data[["MA1_remain", "sp", "sp_qc"]] = data["MA1"].str.rsplit(",", n=2, expand=True)
158 |         data["sp"] = data["sp"].where(data["sp"] != "99999", pd.NA)
159 |         data["sp"] = data["sp"].where(~data["sp_qc"].isin(ERRONEOUS_FLAGS), pd.NA)
160 |         # Scaling factor: 10
161 |         data["sp"] = data["sp"].astype("Float32") / 10
162 |     data = data.drop(columns=["MA1", "MA1_remain", "sp_qc"], errors="ignore")
163 |     return data
164 | 
165 | 
166 | def parse_sea_level_pressure_col(data):
167 |     """
168 |     Process mean sea level pressure column
169 |     SLP column format: 09725,1
170 |     Steps:
171 |     1. Set values flagged as erroneous/missing to NaN
172 |     2. Convert to float in hPa
173 |     """
174 |     if "SLP" in data.columns and data["SLP"].notnull().any():
175 |         data[["msl", "msl_qc"]] = data["SLP"].str.rsplit(",", expand=True)
176 |         data["msl"] = data["msl"].where(data["msl"] != "99999", pd.NA)
177 |         data["msl"] = data["msl"].where(~data["msl_qc"].isin(ERRONEOUS_FLAGS), pd.NA)
178 |         # Scaling factor: 10
179 |         data["msl"] = data["msl"].astype("Float32") / 10
180 |     data = data.drop(columns=["SLP", "msl_qc"], errors="ignore")
181 |     return data
182 | 
183 | 
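# [Editor's note] Illustrative sketch, not part of the original module: both pressure parsers
# above decode a zero-padded integer given in tenths of hPa and mask the 99999 sentinel, so
# "09713" becomes 971.3 hPa. Sample strings and the helper name are invented; the module's
# pandas import is reused.
def _example_parse_pressure_cols():
    sample = pd.DataFrame(
        {
            "MA1": ["99999,9,09713,1"],  # altimeter setting missing, station pressure 971.3 hPa
            "SLP": ["09725,1"],          # mean sea level pressure 972.5 hPa
        }
    )
    parsed = parse_sea_level_pressure_col(parse_surface_pressure_col(sample))
    # Expected behaviour: sp = 971.3 and msl = 972.5
    return parsed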
184 | def parse_single_precipitation_col(data, col):
185 |     """
186 |     Parse one of the precipitation columns AA1-4
187 |     """
188 |     if data[col].isnull().all():
189 |         return pd.DataFrame()
190 |     datacol = data[[col]].copy()
191 |     # Split the column to get the period first
192 |     datacol[["period", f"{col}_remain"]] = datacol[col].str.split(",", n=1, expand=True)
193 |     # Remove weird periods to avoid unexpected errors
194 |     datacol = datacol[datacol["period"].isin(["01", "03", "06", "12", "24"])]
195 |     if len(datacol) == 0:
196 |         return pd.DataFrame()
197 |     # Set the period as index and unstack so that different periods are converted to different columns
198 |     datacol = datacol.set_index("period", append=True)[f"{col}_remain"]
199 |     datacol = datacol.unstack("period")
200 |     # Rename the columns according to the period, e.g., 03 -> ra3
201 |     datacol.columns = [f"ra{item.lstrip('0')}" for item in datacol.columns]
202 |     # Further split the remaining sections
203 |     for var in datacol.columns:
204 |         datacol[[var, f"{var}_cond", f"{var}_qc"]] = datacol[var].str.split(",", expand=True)
205 |     return datacol
206 | 
207 | 
208 | def parse_precipitation_col(data):
209 |     """
210 |     Process precipitation columns
211 |     Currently AA1-4 columns are used. Column format: 24,0073,3,1
212 |     The items are period, depth, condition, quality. Condition is not used currently
213 |     It is more complex than other variables as values for different periods are stored in the same columns
214 |     They need to be separated into different columns
215 |     Steps:
216 |     1. Separate and recombine columns by period
217 |     2. Set values flagged as erroneous/missing to NaN
218 |     3. Convert to float in mm
219 |     """
220 |     for col in ["AA1", "AA2", "AA3", "AA4"]:
221 |         if col in data.columns:
222 |             datacol = parse_single_precipitation_col(data, col)
223 |             # Same variable (e.g., ra24) may be stored in different original columns
224 |             # Use combine_first so that the same variables can be merged into the same columns
225 |             data = data.combine_first(datacol)
226 |         else:
227 |             # Assuming that the remaining columns are also not present
228 |             break
229 |     # Quality statuses other than 3/7 are treated as valid records; 3/7 indicate erroneous values
230 |     for col in [item for item in data.columns if item.startswith("ra") and item[2:].isdigit()]:
231 |         data[col] = data[col].where(data[col] != "9999", pd.NA)
232 |         data[col] = data[col].where(~data[f"{col}_qc"].isin(ERRONEOUS_FLAGS), pd.NA)
233 |         data[col] = data[col].astype("Float32") / 10
234 |         data = data.drop(columns=[f"{col}_cond", f"{col}_qc"])
235 |     data = data.drop(columns=["AA1", "AA2", "AA3", "AA4"], errors="ignore")
236 |     return data
237 | 
238 | 
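# [Editor's note] Illustrative sketch, not part of the original module: each AAn value bundles
# an accumulation period with a depth in tenths of mm, and parse_precipitation_col spreads one
# raw column into period-specific columns such as ra1/ra6/ra24. Sample strings and the helper
# name are invented (a real file would also carry a DATE column); the module's pandas import
# is reused.
def _example_parse_precipitation_col():
    sample = pd.DataFrame(
        {
            "AA1": ["24,0073,3,1", "06,0005,3,1"],  # 7.3 mm over 24 h; 0.5 mm over 6 h
            "AA2": ["01,0000,3,1", None],           # 0.0 mm over 1 h; nothing in the second report
        }
    )
    parsed = parse_precipitation_col(sample)
    # Expected behaviour: columns ra1, ra6 and ra24, each holding one non-NaN depth in mm
    return parsed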
239 | def parse_single_file(fpath, fpath_last_year):
240 |     """
241 |     Parse columns of each variable in a single ISD file
242 |     """
243 |     # Gxn for cloud cover, MA1 for surface pressure, AAn for precipitation
244 |     cols_var = [
245 |         "TMP",
246 |         "DEW",
247 |         "WND",
248 |         "SLP",
249 |         "MA1",
250 |         "AA1",
251 |         "AA2",
252 |         "AA3",
253 |         "AA4",
254 |         "GF1",
255 |         *[f"GA{i}" for i in range(1, 7)],
256 |         *[f"GD{i}" for i in range(1, 7)],
257 |         *[f"GG{i}" for i in range(1, 7)],
258 |     ]
259 |     cols = ["DATE"] + list(cols_var)
260 | 
261 |     def _load_csv(fpath):
262 |         return pd.read_csv(fpath, parse_dates=["DATE"], usecols=lambda c: c in set(cols), low_memory=False)
263 | 
264 |     data = _load_csv(fpath)
265 |     if fpath_last_year is not None and os.path.exists(fpath_last_year):
266 |         data_last_year = _load_csv(fpath_last_year)
267 |         # Load the last day of the last year for better hourly aggregation
268 |         data_last_year = data_last_year.loc[
269 |             (data_last_year["DATE"].dt.month == 12) & (data_last_year["DATE"].dt.day == 31)
270 |         ]
271 |         data = pd.concat([data_last_year, data], ignore_index=True)
272 |     data = data[[item for item in cols if item in data.columns]]
273 | 
274 |     data = parse_temperature_col(data)
275 |     data = parse_wind_col(data)
276 |     data = parse_cloud_col(data)
277 |     data = parse_surface_pressure_col(data)
278 |     data = parse_sea_level_pressure_col(data)
279 |     data = parse_precipitation_col(data)
280 | 
281 |     data = data.rename(columns={"DATE": "time"})
282 |     value_cols = [col for col in data.columns if col != "time"]
283 |     # drop all-NaN rows
284 |     data = data[["time"] + value_cols].sort_values("time").dropna(how="all", subset=value_cols)
285 |     return data
286 | 
287 | 
288 | def aggregate_to_hourly(data):
289 |     """
290 |     Aggregate rows that represent the same hour into one row
291 |     Order the rows from the same hour by difference from the top of the hour,
292 |     then use bfill at each hour to get the nearest valid values for each variable
293 |     Specifically, for t/td, avoid combining two records from different rows together
294 |     """
295 |     data["hour"] = data["time"].dt.round("h")
296 |     # Sort data by difference from the top of the hour so that bfill can be applied
297 |     # to give priority to the closer records
298 |     data["hour_dist"] = (data["time"] - data["hour"]).dt.total_seconds().abs() // 60
299 |     data = data.sort_values(["hour", "hour_dist"])
300 | 
301 |     if data["hour"].duplicated().any():
302 |         # Construct a new column of (t, td) tuples. Values are not NaN only when both of them are valid
303 |         data["t_td"] = data.apply(
304 |             lambda row: (row["t"], row["td"]) if row[["t", "td"]].notnull().all() else pd.NA, axis=1
305 |         )
306 |         # For the same hour, fill NaNs at the first row in the order of difference from the top of the hour
307 |         data = data.groupby("hour").apply(lambda df: df.bfill().iloc[0], include_groups=False)
308 | 
309 |         # 1st priority: for hours that have both valid t and td originally (decided by t_td),
310 |         # fill values to t_new and td_new
311 |         # Specifically, for corner cases where all t_td is NaN, we need to convert pd.NA to (pd.NA, pd.NA)
312 |         # so that to_list() will not raise an error
313 |         data["t_td"] = data["t_td"].apply(lambda item: (pd.NA, pd.NA) if pd.isna(item) else item)
314 |         data[["t_new", "td_new"]] = pd.DataFrame(data["t_td"].to_list(), index=data.index)
315 |         # 2nd priority: Remaining hours can only provide at most one of t and td. Try to fill t first
316 |         rows_to_fill = data[["t_new", "td_new"]].isnull().all(axis=1)
317 |         data.loc[rows_to_fill, "t_new"] = data.loc[rows_to_fill, "t"]
318 |         # 3rd priority: Remaining hours have no t during the time window. Try to fill td
319 |         rows_to_fill = data[["t_new", "td_new"]].isnull().all(axis=1)
320 |         data.loc[rows_to_fill, "td_new"] = data.loc[rows_to_fill, "td"]
321 | 
322 |         data = data.drop(columns=["t", "td", "t_td"]).rename(columns={"t_new": "t", "td_new": "td"})
323 | 
324 |     data = data.reset_index(drop=True)
325 |     data["time"] = data["time"].dt.round("h")
326 |     return data
327 | 
328 | 
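# [Editor's note] Illustrative sketch, not part of the original module: two reports fall into
# the same hourly bin, and aggregate_to_hourly keeps values from the records closest to the
# top of the hour while only pairing t and td when they originate from the same report.
# Timestamps, values and the helper name are invented; the module's pandas import is reused.
def _example_aggregate_to_hourly():
    sample = pd.DataFrame(
        {
            "time": pd.to_datetime(["2023-01-01 11:55", "2023-01-01 12:20"]),
            "t": [20.0, 21.0],
            "td": [pd.NA, 15.0],
        }
    )
    hourly = aggregate_to_hourly(sample)
    # Expected behaviour: a single row at 12:00 with t = 21.0 and td = 15.0, because the 11:55
    # report lacks td and the 1st-priority rule prefers the report that provides both t and td
    return hourly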
"__main__": 375 | parser = argparse.ArgumentParser() 376 | parser.add_argument("-i", "--input-dir", type=str, required=True, help="Directory of ISD csv files") 377 | parser.add_argument("-o", "--output-dir", type=str, required=True, help="Directory of output NetCDF files") 378 | parser.add_argument("-y", "--year", type=int, required=True, help="Target year") 379 | parser.add_argument("--num-proc", type=int, default=16, help="Number of parallel processes") 380 | parser.add_argument("--overwrite", action="store_true", help="Overwrite existing files") 381 | parsed_args = parser.parse_args() 382 | main(parsed_args) 383 | -------------------------------------------------------------------------------- /quality_control/station_merging.py: -------------------------------------------------------------------------------- 1 | """ 2 | 1. Load the metadata of ISD data online. Post-processing includes: 3 | 1. Drop rows with missing values. These stations are considered not trustworthy 4 | 2. Use USAF and WBAN to create a unique station ID 5 | 2. Calculate pairwise similarity between stations and saved as NetCDF files 6 | - Distance and elevation similarity is calculated using an exponential decay 7 | - Distance of two stations are calculated using great circle distance 8 | 3. Load ISD hourly station data with valid metadata, and merge stations based on both metadata and data similarity 9 | - Metadata similarity is calculated based on a weighted sum of distance, elevation, and name similarity 10 | - Data similarity is calculated based on the proportion of identical values among common data points 11 | 4. The merged data and the similarity will be saved in NetCDF format 12 | """ 13 | 14 | import argparse 15 | import logging 16 | import os 17 | 18 | import numpy as np 19 | import pandas as pd 20 | import xarray as xr 21 | from algo.utils import configure_logging 22 | from geopy.distance import great_circle as geodist 23 | from tqdm import tqdm 24 | 25 | # URL of the official ISD metadata file 26 | URL_ISD_HISTORY = "https://www.ncei.noaa.gov/pub/data/noaa/isd-history.txt" 27 | logger = logging.getLogger(__name__) 28 | 29 | 30 | def load_metadata(station_list): 31 | """ 32 | Load the metadata of ISD data 33 | """ 34 | meta = pd.read_fwf( 35 | URL_ISD_HISTORY, 36 | skiprows=20, 37 | usecols=["USAF", "WBAN", "STATION NAME", "CTRY", "CALL", "LAT", "LON", "ELEV(M)"], 38 | dtype={"USAF": str, "WBAN": str}, 39 | ) 40 | # Drop rows with missing values. 
These stations are considered not trustworthy 41 | meta = meta.dropna(how="any", subset=["LAT", "LON", "STATION NAME", "ELEV(M)", "CTRY"]) 42 | meta["STATION"] = meta["USAF"] + meta["WBAN"] 43 | meta = meta[["STATION", "CALL", "STATION NAME", "CTRY", "LAT", "LON", "ELEV(M)"]].set_index("STATION", drop=True) 44 | return meta[meta.index.isin(station_list)] 45 | 46 | 47 | def calc_distance_similarity(latlon1, latlon2, scale_dist=25): 48 | distance = geodist(latlon1, latlon2).kilometers 49 | similarity = np.exp(-distance / scale_dist) 50 | return similarity 51 | 52 | 53 | def calc_elevation_similarity(elevation1, elevation2, scale_elev=100): 54 | similarity = np.exp(-abs(elevation1 - elevation2) / scale_elev) 55 | return similarity 56 | 57 | 58 | def calc_id_similarity(ids1, ids2): 59 | """ 60 | Compare the USAF/WBAN/CALL IDs of two stations 61 | """ 62 | usaf1, wban1, call1, ctry1 = ids1 63 | usaf2, wban2, call2, ctry2 = ids2 64 | if usaf1 != "999999" and usaf1 == usaf2: 65 | return 1 66 | if wban1 != "99999" and wban1 == wban2: 67 | return 1 68 | if call1 == call2: 69 | return 1 70 | # A special case for the CALL ID, e.g., KAIO and AIO are the same stations 71 | if isinstance(call1, str) and len(call1) == 3 and ("K" + call1) == call2: 72 | return 1 73 | if isinstance(call2, str) and len(call2) == 3 and ("K" + call2) == call1: 74 | return 1 75 | # For a special case in Germany, 09xxxx and 10xxxx are the same stations 76 | # See https://gi.copernicus.org/articles/5/473/2016/gi-5-473-2016.html 77 | if usaf1.startswith("09") and usaf2.startswith("10") and usaf1[2:] == usaf2[2:] and ctry1 == ctry2 == "DE": 78 | return 1 79 | return 0 80 | 81 | 82 | def calc_name_similarity(name1, name2): 83 | """ 84 | Jaccard Index for calculating name similarity 85 | """ 86 | set1 = set(name1) 87 | set2 = set(name2) 88 | intersection = set1.intersection(set2) 89 | union = set1.union(set2) 90 | jaccard_index = len(intersection) / len(union) 91 | return jaccard_index 92 | 93 | 94 | def post_process_similarity(similarity, ids): 95 | """ 96 | Post-process the similarity matrix 97 | """ 98 | # Fill the lower triangle of the matrix 99 | similarity = similarity + similarity.T 100 | # Set the diagonal (self-similarity) to 1 101 | np.fill_diagonal(similarity, 1) 102 | similarity = xr.DataArray(similarity, dims=["station1", "station2"], coords={"station1": ids, "station2": ids}) 103 | return similarity 104 | 105 | 106 | def calc_similarity(meta, scale_dist=25, scale_elev=100): 107 | """ 108 | Calculate pairwise similarity between stations 109 | 1. Distance similarity: great circle distance between two stations using an exponential decay 110 | 2. Elevation similarity: absolute difference of elevation between two stations using an exponential decay 111 | 3. 
ID similarity: whether the USAF/WBAN/CALL IDs of two stations are the same 112 | """ 113 | latlon = meta[["LAT", "LON"]].apply(tuple, axis=1).values 114 | elev = meta["ELEV(M)"].values 115 | usaf = meta.index.str[:6].values 116 | wban = meta.index.str[6:].values 117 | name = meta["STATION NAME"].values 118 | ids = list(zip(usaf, wban, meta["CALL"].values, meta["CTRY"].values)) 119 | num = len(meta) 120 | dist_similarity = np.zeros((num, num)) 121 | elev_similarity = np.zeros((num, num)) 122 | id_similarity = np.zeros((num, num)) 123 | name_similarity = np.zeros((num, num)) 124 | for idx1 in tqdm(range(num - 1), desc="Calculating similarity"): 125 | for idx2 in range(idx1 + 1, num): 126 | dist_similarity[idx1, idx2] = calc_distance_similarity(latlon[idx1], latlon[idx2], scale_dist) 127 | elev_similarity[idx1, idx2] = calc_elevation_similarity(elev[idx1], elev[idx2], scale_elev) 128 | id_similarity[idx1, idx2] = calc_id_similarity(ids[idx1], ids[idx2]) 129 | name_similarity[idx1, idx2] = calc_name_similarity(name[idx1], name[idx2]) 130 | dist_similarity = post_process_similarity(dist_similarity, meta.index.values) 131 | elev_similarity = post_process_similarity(elev_similarity, meta.index.values) 132 | id_similarity = post_process_similarity(id_similarity, meta.index.values) 133 | name_similarity = post_process_similarity(name_similarity, meta.index.values) 134 | similarity = xr.merge( 135 | [ 136 | dist_similarity.rename("dist"), 137 | elev_similarity.rename("elev"), 138 | id_similarity.rename("id"), 139 | name_similarity.rename("name"), 140 | ] 141 | ) 142 | return similarity 143 | 144 | 145 | def load_raw_data(data_dir, station_list): 146 | """ 147 | Load raw hourly ISD data in csv files 148 | """ 149 | data = [] 150 | for stn in tqdm(station_list, desc="Loading data"): 151 | df = pd.read_csv(os.path.join(data_dir, f"{stn}.csv"), index_col="time", parse_dates=["time"]) 152 | df["station"] = stn 153 | df = df.set_index("station", append=True) 154 | data.append(df.to_xarray()) 155 | data = xr.concat(data, dim="station") 156 | data["time"] = pd.to_datetime(data["time"]) 157 | return data 158 | 159 | 160 | def calc_meta_similarity(similarity): 161 | """ 162 | Calculate the metadata similarity according to horizontal distance, elevation, and name similarity 163 | """ 164 | meta_simi = (similarity["dist"] * 9 + similarity["elev"] * 1 + similarity["name"] * 5) / 15 165 | # meta_simi is set to 1 if IDs are the same, 166 | # or it is set to a weighted sum of distance, elevation, and name similarity 167 | meta_simi = np.maximum(meta_simi, similarity["id"]) 168 | # set the diagonal and lower triangle to NaN to avoid duplicated pairs 169 | rows, cols = np.indices(meta_simi.shape) 170 | meta_simi.values[rows >= cols] = np.nan 171 | return meta_simi 172 | 173 | 174 | def need_merge(da, stn_source, stn_target, threshold=0.7): 175 | """ 176 | Distinguish whether two stations need to be merged based on the similarity of their data 177 | It is possible that one of them only has few data points 178 | In this case, it can be treated as removing low-quality stations 179 | """ 180 | ts1 = da.sel(station=stn_source) 181 | ts2 = da.sel(station=stn_target) 182 | diff = np.abs(ts1 - ts2) 183 | if (ts1.dropna(dim="time") % 1 < 1e-3).all() or (ts2.dropna(dim="time") % 1 < 1e-3).all(): 184 | max_diff = 0.5 185 | else: 186 | max_diff = 0.1 187 | data_simi = (diff <= max_diff).sum() / diff.notnull().sum() 188 | return data_simi.item() >= threshold 189 | 190 | 191 | def merge_pairs(ds1, ds2): 192 | """ 193 | Merge two 
stations. Each of the two ds should have only one station 194 | If there are only one variable, fill the missing values in ds1 with ds2 195 | If there are more than one variables, to ensure that all variables are from the same station, 196 | for each timestep, the ds with more valid variables will be selected 197 | """ 198 | if len(ds1.data_vars) == 1: 199 | return ds1.fillna(ds2) 200 | da1 = ds1.to_array() 201 | da2 = ds2.to_array() 202 | mask = da1.count(dim="variable") >= da2.count(dim="variable") 203 | return xr.where(mask, da1, da2).to_dataset(dim="variable") 204 | 205 | 206 | def merge_stations(ds, meta_simi, main_var, appendant_var=None, meta_simi_th=0.35, data_simi_th=0.7): 207 | """ 208 | For ds, merge stations based on metadata similarity and data similarity 209 | """ 210 | result = [] 211 | # Flags to avoid duplications 212 | is_merged = xr.DataArray( 213 | np.full(ds["station"].size, False), dims=["station"], coords={"station": ds["station"].values} 214 | ) 215 | for station in tqdm(ds["station"].values, desc=f"Merging {main_var}"): 216 | if is_merged.sel(station=station).item(): 217 | continue 218 | # Station list to be merged 219 | merged_stations = [station] 220 | # Candidates that pass the metadata similarity threshold 221 | candidates = meta_simi["station2"][meta_simi.sel(station1=station) >= meta_simi_th].values 222 | # Stack to store the station pairs to be checked 223 | stack = [(station, item) for item in candidates] 224 | # Search for all stations that need to be merged 225 | # If A and B should be merged, and B and C should be merged, then all of them are merged together 226 | while stack: 227 | stn_source, stn_target = stack.pop() 228 | if stn_target in merged_stations: 229 | continue 230 | if need_merge(ds[main_var], stn_source, stn_target, threshold=data_simi_th): 231 | is_merged.loc[stn_target] = True 232 | merged_stations.append(stn_target) 233 | candidates = meta_simi["station2"][meta_simi.sel(station1=stn_target) >= meta_simi_th].values 234 | stack.extend([(stn_target, item) for item in candidates]) 235 | # Merge stations according to the number of valid data points 236 | num_valid = ds[main_var].sel(station=merged_stations).notnull().sum(dim="time") 237 | sorted_stns = num_valid["station"].sortby(num_valid).values 238 | variables = [main_var] + appendant_var if appendant_var is not None else [main_var] 239 | stn_data = ds[variables].sel(station=sorted_stns[0]) 240 | for target_stn in sorted_stns[1:]: 241 | stn_data = merge_pairs(stn_data, ds[variables].sel(station=target_stn)) 242 | stn_data = stn_data.assign_coords(station=station) 243 | result.append(stn_data) 244 | result = xr.concat(result, dim="station") 245 | return result 246 | 247 | 248 | def merge_all_variables(data, meta_simi, meta_simi_th=0.35, data_simi_th=0.7): 249 | """ 250 | Merge stations for each variable 251 | The key is the main variable used to compare, and the value is the list of appendant variables 252 | """ 253 | variables = { 254 | "t": ["td"], 255 | "ws": ["wd"], 256 | "sp": [], 257 | "msl": [], 258 | "c": [], 259 | "ra1": ["ra3", "ra6", "ra12", "ra24"], 260 | } 261 | merged = [] 262 | for var, app_vars in variables.items(): 263 | ret = merge_stations(data, meta_simi, var, app_vars, meta_simi_th=meta_simi_th, data_simi_th=data_simi_th) 264 | merged.append(ret) 265 | merged = xr.merge(merged).dropna(dim="station", how="all") 266 | return merged 267 | 268 | 269 | def assign_meta_coords(ds, meta): 270 | meta = meta.loc[ds["station"].values] 271 | ds = ds.assign_coords( 272 | 
call=("station", meta["CALL"].values), 273 | name=("station", meta["STATION NAME"].values), 274 | lat=("station", meta["LAT"].values.astype(np.float32)), 275 | lon=("station", meta["LON"].values.astype(np.float32)), 276 | elev=("station", meta["ELEV(M)"].values.astype(np.float32)), 277 | ) 278 | return ds 279 | 280 | 281 | def main(args): 282 | station_list = [item.rsplit(".", 1)[0] for item in os.listdir(args.data_dir) if item.endswith(".csv")] 283 | meta = load_metadata(station_list) 284 | 285 | similarity = calc_similarity(meta) 286 | os.makedirs(args.output_dir, exist_ok=True) 287 | similarity_path = os.path.join(args.output_dir, f"similarity_{args.scale_dist}km_{args.scale_elev}m.nc") 288 | similarity.astype("float32").to_netcdf(similarity_path) 289 | logger.info(f"Saved similarity to {similarity_path}") 290 | 291 | data = load_raw_data(args.data_dir, meta.index.values) 292 | # some stations have no data, they have already been removed in load_raw_data 293 | meta = meta.loc[data["station"].values] 294 | 295 | similarity = similarity.sel(station1=meta.index.values, station2=meta.index.values) 296 | meta_simi = calc_meta_similarity(similarity) 297 | merged = merge_all_variables(data, meta_simi, meta_simi_th=args.meta_simi_th, data_simi_th=args.data_simi_th) 298 | # Save all metadata information in the NetCDF file 299 | merged = assign_meta_coords(merged, meta) 300 | data_path = os.path.join(args.output_dir, "data.nc") 301 | merged.astype(np.float32).to_netcdf(data_path) 302 | logger.info(f"Saved data to {data_path}") 303 | 304 | 305 | if __name__ == "__main__": 306 | parser = argparse.ArgumentParser() 307 | parser.add_argument( 308 | "-o", "--output-dir", type=str, required=True, help="Directory of output similarity and data files" 309 | ) 310 | parser.add_argument( 311 | "-d", "--data-dir", type=str, required=True, help="Directory of ISD csv files from the previous step" 312 | ) 313 | parser.add_argument("--scale-dist", type=int, default=25, help="e-fold scale of distance similarity") 314 | parser.add_argument("--scale-elev", type=int, default=100, help="e-fold scale of elevation similarity") 315 | parser.add_argument("--meta-simi-th", type=float, default=0.35, help="Threshold of metadata similarity") 316 | parser.add_argument("--data-simi-th", type=float, default=0.7, help="Threshold of data similarity") 317 | parser.add_argument("--verbose", type=int, default=1, help="Verbosity level (int >= 0)") 318 | parsed_args = parser.parse_args() 319 | configure_logging(parsed_args.verbose) 320 | main(parsed_args) 321 | --------------------------------------------------------------------------------