├── src
│   └── numerai_era_data
│       ├── __init__.py
│       ├── data_sources
│       │   ├── __init__.py
│       │   ├── base_data_source.py
│       │   ├── ds_calendar.py
│       │   ├── ds_wei.py
│       │   ├── ds_markets.py
│       │   └── ds_bls.py
│       ├── date_utils.py
│       └── era_data_api.py
├── .gitignore
├── .github
│   └── workflows
│       ├── publish.yml
│       └── CI.yml
├── pyproject.toml
├── tests
│   ├── data_sources
│   │   ├── ds_calendar_test.py
│   │   ├── ds_markets_test.py
│   │   ├── ds_bls_test.py
│   │   └── ds_wei_test.py
│   ├── date_utils_test.py
│   └── era_data_api_test.py
├── LICENSE
└── README.md

--------------------------------------------------------------------------------
/src/numerai_era_data/__init__.py:
--------------------------------------------------------------------------------
1 | 

--------------------------------------------------------------------------------
/src/numerai_era_data/data_sources/__init__.py:
--------------------------------------------------------------------------------
1 | 

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | build/
2 | __pycache__
3 | venv/
4 | .coverage
5 | exception.log
6 | .vscode/
7 | src/numerai_era_data/cache/data.parquet
8 | src/numerai_era_data.egg-info/
9 | src/numerai_era_data/cache/daily.parquet

--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
 1 | name: Publish to PyPI.org
 2 | on:
 3 |   workflow_dispatch:
 4 |   release:
 5 |     types: [published]
 6 | jobs:
 7 |   pypi:
 8 |     runs-on: ubuntu-latest
 9 |     environment: "PyPI deployment"
10 |     permissions:
11 |       id-token: write
12 |     steps:
13 |       - name: Checkout
14 |         uses: actions/checkout@v3
15 |         with:
16 |           fetch-depth: 0
17 |       - run: python3 -m pip install --upgrade build && python3 -m build
18 |       - name: Publish package
19 |         uses: pypa/gh-action-pypi-publish@release/v1

--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
 1 | [project]
 2 | name = "numerai_era_data"
 3 | description = "era-level data for Numerai"
 4 | readme = "README.md"
 5 | version = "1.0.0"
 6 | authors = [
 7 |     { name = "Gregory Morse", email = "gregorymorse07@gmail.com" }
 8 | ]
 9 | dependencies = [
10 |     "pandas",
11 |     "pyarrow",
12 |     "pytz",
13 |     "requests",
14 |     "yfinance",
15 | ]
16 | 
17 | [project.optional-dependencies]
18 | dev = [
19 |     "black",
20 |     "isort",
21 |     "mock",
22 |     "pytest",
23 |     "ruff",
24 | ]
25 | 
26 | [build-system]
27 | requires = ["setuptools>=61.2", "wheel"]
28 | build-backend = "setuptools.build_meta"
29 | 
30 | [tool.black]
31 | line-length = 119
32 | 
33 | [tool.ruff]
34 | line-length = 119
35 | 
36 | [tool.coverage.run]
37 | branch = true
38 | 
39 | [tool.coverage.report]
40 | show_missing = true
41 | skip_covered = true

--------------------------------------------------------------------------------
/src/numerai_era_data/date_utils.py:
--------------------------------------------------------------------------------
 1 | from datetime import date, datetime, timedelta, timezone
 2 | 
 3 | ERA_ONE_START = date(2003, 1, 11)
 4 | 
 5 | 
 6 | def get_current_era() -> int:
 7 |     return get_era_for_date(get_current_date())
 8 | 
 9 | 
10 | def get_current_date() -> date:
11 |     # noon UTC is the cutoff for the current date; returns yesterday's date if called before noon
12 |     return (
13 |         datetime.now(timezone.utc).replace(hour=12, minute=0, second=0, microsecond=0)
14 |         - timedelta(days=datetime.now(timezone.utc).hour < 12)
15 |     ).date()
16 | 
17 | 
18 | def get_era_for_date(date: date) -> int:
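    # each era is a 7-day window: era N starts ERA_ONE_START + 7 * (N - 1) days (era 1 began Saturday, 2003-01-11)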
19 |     return (date - ERA_ONE_START).days // 7 + 1  # era is 1-indexed
20 | 
21 | 
22 | def get_date_for_era(era: int) -> date:
23 |     return ERA_ONE_START + timedelta(days=(era - 1) * 7)

--------------------------------------------------------------------------------
/.github/workflows/CI.yml:
--------------------------------------------------------------------------------
 1 | name: CI
 2 | 
 3 | on:
 4 |   pull_request:
 5 |     branches:
 6 |       - main
 7 | 
 8 | jobs:
 9 |   build-and-test:
10 |     runs-on: ubuntu-latest
11 | 
12 |     steps:
13 |       - name: Checkout repository
14 |         uses: actions/checkout@v3
15 | 
16 |       - name: Set up Python
17 |         uses: actions/setup-python@v4
18 |         with:
19 |           python-version: "3.11"
20 | 
21 |       - name: Install dependencies
22 |         run: pip install -U pip setuptools wheel build
23 | 
24 |       - name: Install dev dependencies
25 |         run: pip install .[dev]
26 | 
27 |       - name: Build package
28 |         run: python -m build
29 | 
30 |       - name: Install package
31 |         run: pip install .
32 | 
33 |       - name: Run tests
34 |         run: pytest

--------------------------------------------------------------------------------
/src/numerai_era_data/data_sources/base_data_source.py:
--------------------------------------------------------------------------------
 1 | from abc import ABC, abstractmethod
 2 | from datetime import date
 3 | 
 4 | import pandas as pd
 5 | 
 6 | 
 7 | class BaseDataSource(ABC):
 8 |     _BASE_PREFIX = "era_feature_"
 9 |     _BASE_PREFIX_RAW = "era_feature_raw_"
10 |     DATE_COL = "date"
11 |     ERA_COL = "era"
12 | 
13 |     @abstractmethod
14 |     def get_data(self, start_date: date, end_date: date) -> pd.DataFrame:  # pragma: no cover
15 |         """Return a dataframe with the following columns:
16 |         - date: datetime
17 |         - one or more feature columns
18 |         Feature values must correspond to data available by noon UTC on the given date;
19 |         start_date and end_date are inclusive."""
20 |         pass
21 | 
22 |     @abstractmethod
23 |     def get_columns(self) -> list:  # pragma: no cover
24 |         pass

--------------------------------------------------------------------------------
/tests/data_sources/ds_calendar_test.py:
--------------------------------------------------------------------------------
 1 | from datetime import date
 2 | 
 3 | from numerai_era_data.data_sources.base_data_source import BaseDataSource
 4 | from numerai_era_data.data_sources.ds_calendar import DataSourceCalendar
 5 | 
 6 | 
 7 | def test_get_data():
 8 |     ds_calendar = DataSourceCalendar()
 9 |     ds_data = ds_calendar.get_data(date(2012, 1, 1), date(2022, 1, 1))
10 | 
11 |     assert ds_data.shape[0] == 3654
12 |     assert ds_data.iloc[0][BaseDataSource.DATE_COL] == date(2012, 1, 1)
13 |     assert ds_data.iloc[-1][BaseDataSource.DATE_COL] == date(2022, 1, 1)
14 | 
15 | 
16 | def test_get_columns():
17 |     ds_calendar = DataSourceCalendar()
18 |     ds_columns = ds_calendar.get_columns()
19 |     ds_data = ds_calendar.get_data(date(2012, 1, 1), date(2012, 1, 8))
20 | 
21 |     data_columns = [column for column in ds_data.columns if column != BaseDataSource.DATE_COL]
22 |     assert ds_columns == data_columns

--------------------------------------------------------------------------------
/src/numerai_era_data/data_sources/ds_calendar.py:
--------------------------------------------------------------------------------
 1 | from datetime import date
 2 | 
 3 | import pandas as pd
 4 | 
 5 | from numerai_era_data.data_sources.base_data_source import BaseDataSource
 6 | 
 7 | 
 8 | class DataSourceCalendar(BaseDataSource):
 9 |     _PREFIX = BaseDataSource._BASE_PREFIX + "calendar_"
10 | 
11 |     # columns
12 |     COLUMN_MONTH = _PREFIX + "month"
13 | COLUMN_QUARTER = _PREFIX + "quarter" 14 | COLUMN_YEAR = _PREFIX + "year" 15 | COLUMNS = [COLUMN_MONTH, COLUMN_QUARTER, COLUMN_YEAR] 16 | 17 | def get_data(self, start_date: date, end_date: date) -> pd.DataFrame: 18 | data = pd.DataFrame() 19 | data[self.DATE_COL] = pd.date_range(start_date, end_date) 20 | data[self.COLUMN_MONTH] = data[self.DATE_COL].dt.month 21 | data[self.COLUMN_QUARTER] = data[self.DATE_COL].dt.quarter 22 | data[self.COLUMN_YEAR] = data[self.DATE_COL].dt.year 23 | data[self.DATE_COL] = data[self.DATE_COL].dt.date 24 | 25 | return data 26 | 27 | def get_columns(self) -> list: 28 | return self.COLUMNS 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Gregory Morse 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /tests/data_sources/ds_markets_test.py: -------------------------------------------------------------------------------- 1 | from datetime import date, datetime 2 | 3 | import pytz 4 | 5 | from numerai_era_data.data_sources.base_data_source import BaseDataSource 6 | from numerai_era_data.data_sources.ds_markets import DataSourceMarkets 7 | 8 | 9 | def test_get_data(): 10 | ds_markets = DataSourceMarkets() 11 | ds_data = ds_markets.get_data(date(2012, 1, 1), date(2022, 1, 1)) 12 | 13 | assert ds_data.iloc[0][BaseDataSource.DATE_COL] == date(2012, 1, 1) 14 | assert ds_data.iloc[-1][BaseDataSource.DATE_COL] == date(2022, 1, 1) 15 | assert round(ds_data.loc[ds_data[BaseDataSource.DATE_COL] == date(2012, 10, 24)] \ 16 | [DataSourceMarkets.COLUMN_SPX_CLOSE].values[0], 2) == 1413.11 17 | 18 | 19 | def test_get_data_today(): 20 | ds_markets = DataSourceMarkets() 21 | ds_data = ds_markets.get_data(date(2012, 1, 1), datetime.now(pytz.timezone('US/Eastern')).date()) 22 | 23 | assert ds_data.iloc[-1][BaseDataSource.DATE_COL] == datetime.now(pytz.timezone('US/Eastern')).date() 24 | 25 | 26 | def test_get_columns(): 27 | ds_markets = DataSourceMarkets() 28 | ds_columns = ds_markets.get_columns() 29 | ds_data = ds_markets.get_data(date(2012, 1, 1), date(2012, 1, 8)) 30 | 31 | data_columns = [column for column in ds_data.columns if column != BaseDataSource.DATE_COL] 32 | assert ds_columns == data_columns 33 | -------------------------------------------------------------------------------- /tests/date_utils_test.py: -------------------------------------------------------------------------------- 1 | from datetime import date, datetime, timezone 2 | 3 | from mock import patch 4 | 5 | from numerai_era_data.date_utils import (get_current_date, get_current_era, 6 | get_date_for_era, get_era_for_date) 7 | 8 | 9 | def test_get_current_era(): 10 | assert get_current_era() > 0 11 | 12 | def test_get_current_date_before_noon(): 13 | with patch("numerai_era_data.date_utils.datetime") as mock_datetime: 14 | mock_datetime.now.return_value = datetime(2023, 5, 27, 11, 59, 0, 0, timezone.utc) 15 | assert get_current_date() == date(2023, 5, 26) 16 | 17 | def test_get_current_date_after_noon(): 18 | with patch("numerai_era_data.date_utils.datetime") as mock_datetime: 19 | mock_datetime.now.return_value = datetime(2023, 5, 27, 12, 0, 0, 0, timezone.utc) 20 | assert get_current_date() == date(2023, 5, 27) 21 | 22 | def test_get_era_for_date_identity(): 23 | assert get_era_for_date(get_date_for_era(1)) == 1 24 | 25 | def test_get_era_for_date_day_before(): 26 | assert get_era_for_date(date(2023, 5, 26)) == 1063 27 | 28 | def test_get_era_for_date_day_after(): 29 | assert get_era_for_date(date(2023, 5, 27)) == 1064 30 | 31 | def test_get_date_for_era_1(): 32 | assert get_date_for_era(1) == date(2003, 1, 11) 33 | 34 | def test_get_date_for_era_1063(): 35 | assert get_date_for_era(1063) == date(2023, 5, 20) 36 | -------------------------------------------------------------------------------- /tests/data_sources/ds_bls_test.py: -------------------------------------------------------------------------------- 1 | from datetime import date, datetime 2 | 3 | import pytz 4 | 5 | from numerai_era_data.data_sources.base_data_source import BaseDataSource 6 | from numerai_era_data.data_sources.ds_bls import DataSourceBLS 7 | 8 | 9 | def test_get_data(): 10 | ds_bls = DataSourceBLS() 11 | ds_data = ds_bls.get_data(date(2012, 1, 1), date(2022, 1, 1)) 12 | 
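    # 2012-02-17 vs 2012-02-18 straddles a monthly CPI-U release under the fixed availability lag applied in DataSourceBLS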
13 |     assert ds_data.iloc[0][BaseDataSource.DATE_COL] == date(2012, 1, 1)
14 |     assert ds_data.iloc[-1][BaseDataSource.DATE_COL] == date(2022, 1, 1)
15 |     assert round(ds_data.loc[ds_data[BaseDataSource.DATE_COL] == date(2012, 2, 17)] \
16 |         [DataSourceBLS.COLUMN_CPI_U].values[0], 3) == 226.740
17 |     assert round(ds_data.loc[ds_data[BaseDataSource.DATE_COL] == date(2012, 2, 18)] \
18 |         [DataSourceBLS.COLUMN_CPI_U].values[0], 3) == 227.237
19 | 
20 | 
21 | def test_get_data_today():
22 |     ds_bls = DataSourceBLS()
23 |     ds_data = ds_bls.get_data(date(2012, 1, 1), datetime.now(pytz.timezone('US/Eastern')).date())
24 | 
25 |     assert ds_data.iloc[-1][BaseDataSource.DATE_COL] == datetime.now(pytz.timezone('US/Eastern')).date()
26 | 
27 | 
28 | def test_get_columns():
29 |     ds_bls = DataSourceBLS()
30 |     ds_columns = ds_bls.get_columns()
31 |     ds_data = ds_bls.get_data(date(2012, 1, 1), date(2012, 1, 8))
32 | 
33 |     data_columns = [column for column in ds_data.columns if column != BaseDataSource.DATE_COL]
34 |     assert ds_columns == data_columns

--------------------------------------------------------------------------------
/tests/data_sources/ds_wei_test.py:
--------------------------------------------------------------------------------
 1 | from datetime import date
 2 | 
 3 | from numerai_era_data.data_sources.base_data_source import BaseDataSource
 4 | from numerai_era_data.data_sources.ds_wei import DataSourceWEI
 5 | 
 6 | 
 7 | def test_get_data():
 8 |     ds_wei = DataSourceWEI()
 9 |     ds_data = ds_wei.get_data(date(2012, 1, 1), date(2022, 1, 1))
10 | 
11 |     assert ds_data.iloc[0][BaseDataSource.DATE_COL] == date(2012, 1, 1)
12 |     assert ds_data.iloc[-1][BaseDataSource.DATE_COL] == date(2022, 1, 1)
13 | 
14 | 
15 | def test_get_data_single_date():
16 |     ds_wei = DataSourceWEI()
17 |     ds_data = ds_wei.get_data(date(2012, 1, 1), date(2012, 1, 1))
18 | 
19 |     assert ds_data.shape[0] == 1
20 |     assert ds_data.iloc[0][BaseDataSource.DATE_COL] == date(2012, 1, 1)
21 |     assert ds_data[DataSourceWEI.COLUMN_WEI].iloc[0] > 0
22 | 
23 | 
24 | def test_get_data_cutoff():
25 |     ds_wei = DataSourceWEI()
26 |     ds_data = ds_wei.get_data(date(2023, 6, 1), date(2023, 6, 2))
27 | 
28 |     assert ds_data[DataSourceWEI.COLUMN_WEI].iloc[0] != ds_data[DataSourceWEI.COLUMN_WEI].iloc[1]
29 | 
30 | 
31 | def test_get_columns():
32 |     ds_wei = DataSourceWEI()
33 |     ds_columns = ds_wei.get_columns()
34 |     ds_data = ds_wei.get_data(date(2012, 1, 1), date(2012, 1, 8))
35 | 
36 |     data_columns = [column for column in ds_data.columns if column != BaseDataSource.DATE_COL]
37 |     assert ds_columns == data_columns

--------------------------------------------------------------------------------
/src/numerai_era_data/data_sources/ds_wei.py:
--------------------------------------------------------------------------------
 1 | import io
 2 | from datetime import date, timedelta
 3 | 
 4 | import pandas as pd
 5 | import requests
 6 | 
 7 | from numerai_era_data.data_sources.base_data_source import BaseDataSource
 8 | 
 9 | 
10 | class DataSourceWEI(BaseDataSource):
11 |     _PREFIX = BaseDataSource._BASE_PREFIX + "wei_"
12 | 
13 |     # columns
14 |     COLUMN_WEI = _PREFIX + "wei"
15 |     COLUMNS = [COLUMN_WEI]
16 | 
17 |     def get_data(self, start_date: date, end_date: date) -> pd.DataFrame:
18 |         # pad 13 days so at least one earlier weekly observation is available to forward-fill from
19 |         padded_start_date = start_date - timedelta(days=13)
20 | 
21 |         date_df = pd.DataFrame()
22 |         date_df[self.DATE_COL] = pd.date_range(padded_start_date, end_date)
23 |         date_df[self.DATE_COL] = date_df[self.DATE_COL].dt.date
24 | 
25 |         # URL of the weekly economic index data
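        # note: the code below assumes fredgraph.csv returns two columns named DATE and WEI, one row per week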
26 |         url = "https://fred.stlouisfed.org/graph/fredgraph.csv?id=WEI"
27 | 
28 |         # make the HTTP request once and build the DataFrame from the response body
29 |         response = requests.get(url)
30 |         response.raise_for_status()
31 |         wei_df = pd.read_csv(io.StringIO(response.text))
32 | 
33 |         # rename columns
34 |         wei_df.rename(columns={"DATE": self.DATE_COL, "WEI": self.COLUMN_WEI}, inplace=True)
35 | 
36 |         # convert date column to datetime
37 |         wei_df[self.DATE_COL] = pd.to_datetime(wei_df[self.DATE_COL])
38 | 
39 |         # data is not ready until after noon UTC on Thursday; dates are for the previous Saturday
40 |         wei_df[self.DATE_COL] = wei_df[self.DATE_COL].dt.date + timedelta(days=6)
41 | 
42 |         # merge the WEI data with the full date range to fill in missing dates
43 |         data = pd.merge(date_df, wei_df, on=self.DATE_COL, how="left").ffill()
44 | 
45 |         data = data[(data[self.DATE_COL] >= start_date)]
46 | 
47 |         return data
48 | 
49 |     def get_columns(self) -> list:
50 |         return self.COLUMNS

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # Numerai Era Data
 2 | 
 3 | Numerai Era Data is a Python project dedicated to enriching the Numerai tournament experience by providing supplemental era-level data. This data may offer valuable insights that enhance participants' modeling capabilities in the Numerai tournament.
 4 | 
 5 | ## Table of Contents
 6 | 
 7 | - [Introduction](#introduction)
 8 | - [Installation](#installation)
 9 | - [Usage](#usage)
10 | - [Data Types](#data-types)
11 | - [Extending Data Sources](#extending-data-sources)
12 | - [License](#license)
13 | 
14 | 
15 | ## Introduction
16 | 
17 | The Numerai tournament provides an innovative platform for data scientists to build predictive models for financial data. Numerai Era Data takes this a step further by offering supplemental data at the era level. These data enhancements can be seamlessly integrated into participants' modeling pipelines, enabling them to explore new approaches and potentially improve their model performance.
18 | 
19 | Era-level data could be incorporated into a model pipeline in many ways. The simplest approach is to add the supplemental columns directly to the Numerai data as new feature columns, although initial tests of this approach have not shown any benefit. Another avenue is to use the era-level data to help predict and respond to changes in regime, which have periodically plagued Numerai participants, including during the heavy drawdowns of Q2 2023. One approach along this avenue is to cluster eras based on era-level feature similarity and then train a separate model or models on each cluster. These models could then be used in an ensemble or mixture-of-experts system.
20 | 
21 | ## Installation
22 | 
23 | To start utilizing Numerai Era Data, install it from PyPI using the following command:
24 | 
25 | ```
26 | pip install numerai-era-data
27 | ```
28 | 
29 | ## Usage
30 | 
31 | Numerai Era Data can be incorporated into your Numerai modeling process to enhance your models' predictive power. Here's how you can use it:
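(In the snippet below, `all_data` and `live_data` stand in for your own Numerai training and live DataFrames; both are assumed to already carry an "era" column to merge on.)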
32 | 
33 | 
34 | ```
35 | from numerai_era_data.era_data_api import EraDataAPI
36 | 
37 | # Get data for all eras and the latest daily data for the live era
38 | era_data_api = EraDataAPI()
39 | era_data = era_data_api.get_all_eras()
40 | daily_data = era_data_api.get_current_daily()
41 | 
42 | # Exclude raw columns
43 | era_feature_columns = [f for f in era_data.columns if f != "era" and not f.startswith("era_feature_raw_")]
44 | 
45 | # Merge era data with Numerai data
46 | all_data = all_data.merge(era_data[["era"] + era_feature_columns], on="era", how="left")
47 | live_data = live_data.merge(daily_data[["era"] + era_feature_columns], on="era", how="outer")
48 | ```
49 | 
50 | ## Data Types
51 | 
52 | Numerai Era Data provides two types of columns: normal and raw. Raw features, indicated by the prefix "era_feature_raw_", require additional processing to be useful in modeling; they encompass data such as the S&P 500 closing price. Incorporating these columns can potentially contribute to more accurate and sophisticated models.
53 | 
54 | 
55 | ## Extending Data Sources
56 | 
57 | Numerai Era Data welcomes contributors to expand its capabilities by implementing new data sources. To add a new data source, follow these steps (a minimal example appears at the end of this README):
58 | 
59 | 1. Create a new class that extends `numerai_era_data.data_sources.base_data_source.BaseDataSource`.
60 | 2. Implement the `get_data()` function in the new class, returning a pandas DataFrame with a "date" column and one or more columns whose names start with `_BASE_PREFIX` or `_BASE_PREFIX_RAW`; each row should contain only values that were available by noon UTC on that date.
61 | 3. Implement the `get_columns()` function to return the list of data columns provided by the new data source.
62 | 
63 | ## License
64 | 
65 | Numerai Era Data is released under the MIT License. You are free to use, modify, and distribute the code according to the terms of the license.
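
As referenced in the Extending Data Sources section, here is a minimal sketch of a custom data source. The `DataSourceExample` class name, its feature name, and the constant feature value are illustrative assumptions, not part of the library:

```
from datetime import date

import pandas as pd

from numerai_era_data.data_sources.base_data_source import BaseDataSource


class DataSourceExample(BaseDataSource):
    _PREFIX = BaseDataSource._BASE_PREFIX + "example_"

    # columns
    COLUMN_CONSTANT = _PREFIX + "constant"
    COLUMNS = [COLUMN_CONSTANT]

    def get_data(self, start_date: date, end_date: date) -> pd.DataFrame:
        # one row per calendar day, inclusive of both endpoints
        data = pd.DataFrame()
        data[self.DATE_COL] = pd.date_range(start_date, end_date)
        data[self.DATE_COL] = data[self.DATE_COL].dt.date
        # a value that is trivially known by noon UTC on every date
        data[self.COLUMN_CONSTANT] = 1.0
        return data

    def get_columns(self) -> list:
        return self.COLUMNS
```

Saving a class like this in a module under `src/numerai_era_data/data_sources/` is enough for `EraDataAPI` to use it: `_get_data_sources()` scans that package and registers every `BaseDataSource` subclass it finds.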
66 | -------------------------------------------------------------------------------- /src/numerai_era_data/data_sources/ds_markets.py: -------------------------------------------------------------------------------- 1 | from datetime import date, datetime, timedelta 2 | 3 | import pandas as pd 4 | import pytz 5 | import yfinance as yf 6 | 7 | from numerai_era_data.data_sources.base_data_source import BaseDataSource 8 | 9 | 10 | class DataSourceMarkets(BaseDataSource): 11 | _PREFIX = BaseDataSource._BASE_PREFIX + "markets_" 12 | _PREFIX_RAW = BaseDataSource._BASE_PREFIX_RAW + "markets_" 13 | _PREFIX_SPX_SMA = _PREFIX_RAW + "spx_sma_" 14 | _PREFIX_SPX_EMA = _PREFIX_RAW + "spx_ema_" 15 | _PREFIX_SPX_RETURN = _PREFIX + "spx_return_" 16 | _TIME_WINDOWS = [10, 20, 50, 100, 200] 17 | 18 | # columns 19 | COLUMN_SPX_CLOSE = _PREFIX_RAW + "spx_close" 20 | 21 | def __init__(self): 22 | self.COLUMNS = [self.COLUMN_SPX_CLOSE] 23 | 24 | for i in self._TIME_WINDOWS: 25 | setattr(self, f"COLUMN_SPX_SMA{i}", self._PREFIX_SPX_SMA + str(i)) 26 | self.COLUMNS.append(getattr(self, f"COLUMN_SPX_SMA{i}")) 27 | for i in self._TIME_WINDOWS: 28 | setattr(self, f"COLUMN_SPX_EMA{i}", self._PREFIX_SPX_EMA + str(i)) 29 | self.COLUMNS.append(getattr(self, f"COLUMN_SPX_EMA{i}")) 30 | for i in self._TIME_WINDOWS: 31 | setattr(self, f"COLUMN_SPX_RETURN{i}", self._PREFIX_SPX_RETURN + str(i)) 32 | self.COLUMNS.append(getattr(self, f"COLUMN_SPX_RETURN{i}")) 33 | 34 | def get_data(self, start_date: date, end_date: date) -> pd.DataFrame: 35 | # adjusted close is more accurate than close 36 | CLOSE_COL = "Close" 37 | 38 | # get 300 calendar days of padding for the 200 day moving average calculation 39 | padded_start_date = start_date - timedelta(days=300) 40 | 41 | # dataframe with all dates including weekends and holidays 42 | date_df = pd.DataFrame() 43 | date_df[self.DATE_COL] = pd.date_range(padded_start_date, end_date) 44 | date_df[self.DATE_COL] = date_df[self.DATE_COL].dt.date 45 | 46 | # dataframe with only trading days 47 | data = yf.download("^SPX", start=padded_start_date, end=end_date) 48 | 49 | if isinstance(data.columns, pd.MultiIndex): 50 | # Extract just the first element of each tuple for column names 51 | data.columns = [col[0] for col in data.columns] 52 | 53 | data = data.reset_index() 54 | 55 | # calculate moving averages 56 | for i in self._TIME_WINDOWS: 57 | data[getattr(self, f"COLUMN_SPX_SMA{i}")] = data[CLOSE_COL].rolling(window=i).mean() 58 | 59 | # calculate exponential moving averages 60 | for i in self._TIME_WINDOWS: 61 | data[getattr(self, f"COLUMN_SPX_EMA{i}")] = data[CLOSE_COL].ewm(span=i, adjust=False).mean() 62 | 63 | # calculate returns 64 | for i in self._TIME_WINDOWS: 65 | data[getattr(self, f"COLUMN_SPX_RETURN{i}")] = data[CLOSE_COL].pct_change(periods=i) 66 | 67 | data.rename(columns={"Date": self.DATE_COL, CLOSE_COL: self.COLUMN_SPX_CLOSE}, inplace=True) 68 | 69 | # data is not finalized until around midnight Eastern time 70 | # add one day to date column to align with data availability 71 | data[self.DATE_COL] = data[self.DATE_COL].dt.date + timedelta(days=1) 72 | 73 | # merge market data with date data to fill in missing dates 74 | data = pd.merge(date_df, data, on=self.DATE_COL, how="left").ffill() 75 | 76 | # remove any data corresponding to future date (in eastern tz) as it may not be complete 77 | data = data[data[self.DATE_COL] <= datetime.now(pytz.timezone("US/Eastern")).date()] 78 | 79 | # filter out data outside of the requested date range 80 | data = data[(data[self.DATE_COL] 
>= start_date) & (data[self.DATE_COL] <= end_date)] 81 | 82 | # filter columns 83 | final_columns = [self.DATE_COL] + self.get_columns() 84 | data = data[final_columns] 85 | 86 | return data 87 | 88 | def get_columns(self) -> list: 89 | return self.COLUMNS 90 | -------------------------------------------------------------------------------- /src/numerai_era_data/era_data_api.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import inspect 3 | import logging 4 | import os 5 | import pkgutil 6 | 7 | import pandas as pd 8 | 9 | import numerai_era_data.date_utils as date_utils 10 | from numerai_era_data.data_sources.base_data_source import BaseDataSource 11 | 12 | 13 | class EraDataAPI: 14 | CACHE_DIRECTORY = os.path.join(os.path.dirname(__file__), 'cache') 15 | DATA_CACHE_FILE = os.path.join(CACHE_DIRECTORY, 'data.parquet') 16 | DAILY_CACHE_FILE = os.path.join(CACHE_DIRECTORY, 'daily.parquet') 17 | 18 | def __init__(self): 19 | dir_name = os.path.dirname(self.DATA_CACHE_FILE) 20 | if not os.path.exists(dir_name): 21 | os.makedirs(dir_name) 22 | 23 | if os.path.exists(self.DATA_CACHE_FILE): 24 | self.data_cache = pd.read_parquet(self.DATA_CACHE_FILE) 25 | else: 26 | self.data_cache = pd.DataFrame() 27 | 28 | if os.path.exists(self.DAILY_CACHE_FILE): 29 | self.daily_cache = pd.read_parquet(self.DAILY_CACHE_FILE) 30 | else: 31 | self.daily_cache = pd.DataFrame() 32 | 33 | self.class_cache = [] 34 | 35 | # logger config 36 | logging.basicConfig(filename="exception.log", level=logging.ERROR) 37 | 38 | def get_all_eras(self, update_if_stale=True) -> pd.DataFrame: 39 | update = False 40 | 41 | if update_if_stale: 42 | # if most current era is not in the data, update the data 43 | if self.data_cache.empty or self.data_cache[BaseDataSource.ERA_COL].astype(int).max() < date_utils.get_current_era(): 44 | update = True 45 | 46 | # if any columns have been added since the last update, update the data 47 | for data_source_class in self._get_data_sources(): 48 | data_source = data_source_class() 49 | if not set(data_source.get_columns()).issubset(set(self.data_cache.columns)): 50 | update = True 51 | break 52 | 53 | if update: 54 | self.update_data() 55 | 56 | return self.data_cache 57 | 58 | def get_current_daily(self, update_if_stale=True) -> pd.DataFrame: 59 | update = False 60 | 61 | if update_if_stale: 62 | # if most current era is not in the data, update the data 63 | if self.daily_cache.empty or self.daily_cache[BaseDataSource.DATE_COL][0] != date_utils.get_current_date(): 64 | update = True 65 | 66 | # if any columns have been added since the last update, update the data 67 | for data_source_class in self._get_data_sources(): 68 | data_source = data_source_class() 69 | if not set(data_source.get_columns()).issubset(set(self.daily_cache.columns)): 70 | update = True 71 | break 72 | 73 | if update: 74 | self.update_daily_data() 75 | 76 | return self.daily_cache 77 | 78 | def update_data(self): 79 | # update the cache 80 | new_data = pd.DataFrame() 81 | start_date = date_utils.get_date_for_era(1) 82 | end_date = date_utils.get_date_for_era(date_utils.get_current_era()) 83 | 84 | for data_source_class in self._get_data_sources(): 85 | data_source = data_source_class() 86 | 87 | try: 88 | data = data_source.get_data(start_date, end_date) 89 | except Exception as e: 90 | logging.exception( 91 | f"Error getting data from {data_source_class.__name__}: {e} on {start_date} to {end_date}" 92 | ) 93 | data = pd.DataFrame() 94 | 
                data[BaseDataSource.DATE_COL] = pd.date_range(start_date, end_date)
95 |                 data[BaseDataSource.DATE_COL] = data[BaseDataSource.DATE_COL].dt.date
96 |                 data[data_source.get_columns()] = None
97 | 
98 |             new_data = data if new_data.empty else pd.merge(new_data, data, how="outer", on=BaseDataSource.DATE_COL)
99 | 
100 |         eras = new_data[BaseDataSource.DATE_COL].apply(date_utils.get_era_for_date)
101 |         new_data[BaseDataSource.ERA_COL] = eras.astype(str).str.zfill(4)
102 |         new_data = new_data.ffill()
103 |         new_data = new_data.drop_duplicates(subset=[BaseDataSource.ERA_COL], keep="last")
104 |         new_data = new_data.reindex(columns=[BaseDataSource.ERA_COL]
105 |                                     + new_data.columns.difference([BaseDataSource.ERA_COL]).tolist())
106 |         new_data = new_data.drop(columns=[BaseDataSource.DATE_COL])
107 |         self.data_cache = new_data.reset_index(drop=True)
108 | 
109 |         # write cache to disk
110 |         self.data_cache.to_parquet(self.DATA_CACHE_FILE)
111 | 
112 |     def update_daily_data(self):
113 |         new_data = pd.DataFrame()
114 |         start_date = date_utils.get_current_date()
115 |         end_date = date_utils.get_current_date()
116 | 
117 |         for data_source_class in self._get_data_sources():
118 |             data_source = data_source_class()
119 |             try:
120 |                 data = data_source.get_data(start_date, end_date)
121 |             except Exception as e:
122 |                 logging.exception(
123 |                     f"Error getting data from {data_source_class.__name__}: {e} on {start_date} to {end_date}"
124 |                 )
125 |                 # fill with the last era value
126 |                 data = pd.DataFrame()
127 |                 data[BaseDataSource.DATE_COL] = pd.date_range(start_date, end_date)
128 |                 data[BaseDataSource.DATE_COL] = data[BaseDataSource.DATE_COL].dt.date
129 |                 data[data_source.get_columns()] = self.data_cache[data_source.get_columns()].tail(1).values
130 | 
131 |             new_data = data if new_data.empty else pd.merge(new_data, data, how="outer", on=BaseDataSource.DATE_COL)
132 | 
133 |         # add era column with the placeholder value "X" so it can be merged with the live data
134 |         new_data[BaseDataSource.ERA_COL] = "X"
135 |         self.daily_cache = new_data
136 |         self.daily_cache.to_parquet(self.DAILY_CACHE_FILE)
137 | 
138 |     def _get_data_sources(self) -> list:
139 |         if len(self.class_cache) > 0:
140 |             return self.class_cache
141 | 
142 |         full_subpackage_name = "numerai_era_data.data_sources"
143 |         module = importlib.import_module(full_subpackage_name)
144 |         classes = []
145 | 
146 |         for _, name, _ in pkgutil.iter_modules(module.__path__):
147 |             sub_module = importlib.import_module(f"{full_subpackage_name}.{name}")
148 |             for _, obj in inspect.getmembers(sub_module):
149 |                 if (
150 |                     inspect.isclass(obj)
151 |                     and inspect.getmodule(obj) == sub_module
152 |                     and obj != BaseDataSource
153 |                     and issubclass(obj, BaseDataSource)
154 |                 ):
155 |                     classes.append(obj)
156 | 
157 |         self.class_cache = classes
158 |         return classes

--------------------------------------------------------------------------------
/src/numerai_era_data/data_sources/ds_bls.py:
--------------------------------------------------------------------------------
1 | from datetime import date, timedelta
2 | import math
3 | 
4 | import pandas as pd
5 | import requests
6 | 
7 | from numerai_era_data.data_sources.base_data_source import BaseDataSource
8 | 
9 | 
10 | class DataSourceBLS(BaseDataSource):
11 |     _PREFIX = BaseDataSource._BASE_PREFIX + "bls_"
12 |     _PREFIX_RAW = BaseDataSource._BASE_PREFIX_RAW + "bls_"
13 | 
14 |     # columns
15 |     COLUMN_CPI_U = _PREFIX_RAW + "cpi_u"
16 |     COLUMN_CPI_U_ALL = _PREFIX_RAW + "cpi_u_all"
17 |     COLUMN_PPI_FINISHED = _PREFIX_RAW + "ppi_finished"
18 |     COLUMN_UE = _PREFIX + 
"unemployment" 19 | COLUMN_WEEKLY_HOURS = _PREFIX + "weekly_hours" 20 | COLUMN_HOURLY_EARNINGS = _PREFIX_RAW + "hourly_earnings" 21 | COLUMN_OUTPUT = _PREFIX + "output" 22 | COLUMN_IMPORT_INDEX = _PREFIX_RAW + "import_index" 23 | COLUMN_EXPORT_INDEX = _PREFIX_RAW + "export_index" 24 | 25 | # columns from month over month and year over year changes 26 | COLUMN_CPI_U_MOM = _PREFIX + "cpi_u_mom" 27 | COLUMN_CPI_U_YOY = _PREFIX + "cpi_u_yoy" 28 | COLUMN_CPI_U_ALL_MOM = _PREFIX + "cpi_u_all_mom" 29 | COLUMN_CPI_U_ALL_YOY = _PREFIX + "cpi_u_all_yoy" 30 | COLUMN_PPI_FINISHED_MOM = _PREFIX + "ppi_finished_mom" 31 | COLUMN_PPI_FINISHED_YOY = _PREFIX + "ppi_finished_yoy" 32 | COLUMN_UE_MOM = _PREFIX + "unemployment_mom" 33 | COLUMN_UE_YOY = _PREFIX + "unemployment_yoy" 34 | COLUMN_WEEKLY_HOURS_MOM = _PREFIX + "weekly_hours_mom" 35 | COLUMN_WEEKLY_HOURS_YOY = _PREFIX + "weekly_hours_yoy" 36 | COLUMN_HOURLY_EARNINGS_MOM = _PREFIX + "hourly_earnings_mom" 37 | COLUMN_HOURLY_EARNINGS_YOY = _PREFIX + "hourly_earnings_yoy" 38 | COLUMN_OUTPUT_MOM = _PREFIX + "output_mom" 39 | COLUMN_OUTPUT_YOY = _PREFIX + "output_yoy" 40 | COLUMN_IMPORT_INDEX_MOM = _PREFIX + "import_index_mom" 41 | COLUMN_IMPORT_INDEX_YOY = _PREFIX + "import_index_yoy" 42 | COLUMN_EXPORT_INDEX_MOM = _PREFIX + "export_index_mom" 43 | COLUMN_EXPORT_INDEX_YOY = _PREFIX + "export_index_yoy" 44 | 45 | COLUMNS = [ 46 | COLUMN_CPI_U, 47 | COLUMN_CPI_U_ALL, 48 | COLUMN_PPI_FINISHED, 49 | COLUMN_UE, 50 | COLUMN_WEEKLY_HOURS, 51 | COLUMN_HOURLY_EARNINGS, 52 | COLUMN_OUTPUT, 53 | COLUMN_IMPORT_INDEX, 54 | COLUMN_EXPORT_INDEX, 55 | COLUMN_CPI_U_MOM, 56 | COLUMN_CPI_U_YOY, 57 | COLUMN_CPI_U_ALL_MOM, 58 | COLUMN_CPI_U_ALL_YOY, 59 | COLUMN_PPI_FINISHED_MOM, 60 | COLUMN_PPI_FINISHED_YOY, 61 | COLUMN_UE_MOM, 62 | COLUMN_UE_YOY, 63 | COLUMN_WEEKLY_HOURS_MOM, 64 | COLUMN_WEEKLY_HOURS_YOY, 65 | COLUMN_HOURLY_EARNINGS_MOM, 66 | COLUMN_HOURLY_EARNINGS_YOY, 67 | COLUMN_OUTPUT_MOM, 68 | COLUMN_OUTPUT_YOY, 69 | COLUMN_IMPORT_INDEX_MOM, 70 | COLUMN_IMPORT_INDEX_YOY, 71 | COLUMN_EXPORT_INDEX_MOM, 72 | COLUMN_EXPORT_INDEX_YOY, 73 | ] 74 | 75 | # BLS series IDs 76 | SERIES_ID_CPI_U = "CUUR0000SA0L1E" 77 | SERIES_ID_CPI_U_ALL_ITEMS = "CUUR0000SA0" 78 | SERIES_ID_PPI_FINISHED_GOODS = "WPUFD49207" 79 | SERIES_ID_UNEMPLOYMENT = "LNS14000000" 80 | SERIES_ID_WEEKLY_HOURS = "CES0500000002" 81 | SERIES_ID_HOURLY_EARNINGS = "CES0500000003" 82 | SERIES_ID_OUTPUT = "PRS85006092" 83 | SERIES_ID_IMPORT_INDEX = "EIUIR" 84 | SERIES_ID_EXPORT_INDEX = "EIUIQ" 85 | 86 | def get_data(self, start_date: date, end_date: date) -> pd.DataFrame: 87 | # add 18 months of padding to the start date 88 | # accounts for delays in reporting and need to calculate 12 month changes 89 | padded_start_date = start_date - timedelta(days=549) 90 | 91 | date_df = pd.DataFrame() 92 | date_df[self.DATE_COL] = pd.date_range(padded_start_date, end_date) 93 | date_df[self.DATE_COL] = date_df[self.DATE_COL].dt.date 94 | 95 | # Define the URL for the BLS API 96 | api_url = "https://api.bls.gov/publicAPI/v2/timeseries/data/" 97 | 98 | # Define the BLS API headers 99 | headers = {"Content-type": "application/json"} 100 | 101 | # Define the BLS API request data 102 | series_ids = [ 103 | self.SERIES_ID_CPI_U, 104 | self.SERIES_ID_CPI_U_ALL_ITEMS, 105 | self.SERIES_ID_PPI_FINISHED_GOODS, 106 | self.SERIES_ID_UNEMPLOYMENT, 107 | self.SERIES_ID_WEEKLY_HOURS, 108 | self.SERIES_ID_HOURLY_EARNINGS, 109 | self.SERIES_ID_OUTPUT, 110 | self.SERIES_ID_IMPORT_INDEX, 111 | self.SERIES_ID_EXPORT_INDEX, 112 | ] 113 | 114 | 
        combined_df = pd.DataFrame()
115 | 
116 |         total_years = end_date.year - padded_start_date.year + 1
117 |         num_requests = math.ceil(total_years / 10.0)
118 | 
119 |         for i in range(num_requests):
120 |             # Calculate the start and end years for the current request (the BLS API serves at most 10 years at a time)
121 |             start_year = padded_start_date.year + i * 10
122 |             end_year = min(start_year + 9, end_date.year)
123 | 
124 | 
125 |             request_data = {
126 |                 "seriesid": series_ids,
127 |                 "startyear": str(start_year),
128 |                 "endyear": str(end_year),
129 |             }
130 | 
131 |             # Send request to the BLS API
132 |             response = requests.post(api_url, headers=headers, json=request_data)
133 | 
134 |             # Check if the request was successful
135 |             if response.status_code == 200:
136 |                 request_df = pd.DataFrame()
137 | 
138 |                 json_response = response.json()
139 | 
140 |                 for series in json_response["Results"]["series"]:
141 |                     series_id = series["seriesID"]
142 |                     series_data = series["data"]
143 | 
144 |                     # Extract values and dates
145 |                     values = [float(data_point["value"]) for data_point in series_data]
146 |                     dates = []
147 |                     for data_point in series_data:
148 |                         if data_point["period"][0] == "Q":
149 |                             quarter_num = int(data_point["period"][1:])
150 |                             year = int(data_point["year"])
151 |                             month = (quarter_num - 1) * 3 + 1
152 |                             date = pd.to_datetime(f"{year}-{month}-01")
153 |                         else:
154 |                             date = pd.to_datetime(data_point["year"] + "-" + data_point["period"][1:])
155 |                         dates.append(date)
156 | 
157 |                     # Create DataFrame for the series data
158 |                     series_df = pd.DataFrame({self.DATE_COL: dates, series_id: values})
159 |                     series_df = series_df.set_index(self.DATE_COL)
160 | 
161 |                     # Merge the series data into the combined DataFrame
162 |                     if request_df.empty:
163 |                         request_df = series_df
164 |                     else:
165 |                         request_df = pd.merge(request_df, series_df, left_index=True, right_index=True, how='outer')
166 | 
167 |                 # Merge the request data into the combined DataFrame
168 |                 if combined_df.empty:
169 |                     combined_df = request_df
170 |                 else:
171 |                     combined_df = pd.concat([combined_df, request_df])
172 | 
173 |             else:
174 |                 # a failed request covers every series in the batch, so surface the whole response
175 |                 raise Exception(f"BLS API request failed for {start_year}-{end_year}: {response.text}")
176 | 
177 |         # rename columns
178 |         combined_df.rename(
179 |             columns={
180 |                 self.SERIES_ID_CPI_U: self.COLUMN_CPI_U,
181 |                 self.SERIES_ID_CPI_U_ALL_ITEMS: self.COLUMN_CPI_U_ALL,
182 |                 self.SERIES_ID_PPI_FINISHED_GOODS: self.COLUMN_PPI_FINISHED,
183 |                 self.SERIES_ID_UNEMPLOYMENT: self.COLUMN_UE,
184 |                 self.SERIES_ID_WEEKLY_HOURS: self.COLUMN_WEEKLY_HOURS,
185 |                 self.SERIES_ID_HOURLY_EARNINGS: self.COLUMN_HOURLY_EARNINGS,
186 |                 self.SERIES_ID_OUTPUT: self.COLUMN_OUTPUT,
187 |                 self.SERIES_ID_IMPORT_INDEX: self.COLUMN_IMPORT_INDEX,
188 |                 self.SERIES_ID_EXPORT_INDEX: self.COLUMN_EXPORT_INDEX,
189 |             },
190 |             inplace=True,
191 |         )
192 | 
193 |         # calculate month-over-month and year-over-year changes
194 |         combined_df[self.COLUMN_CPI_U_MOM] = combined_df[self.COLUMN_CPI_U].pct_change(1)
195 |         combined_df[self.COLUMN_CPI_U_YOY] = combined_df[self.COLUMN_CPI_U].pct_change(12)
196 |         combined_df[self.COLUMN_CPI_U_ALL_MOM] = combined_df[self.COLUMN_CPI_U_ALL].pct_change(1)
197 |         combined_df[self.COLUMN_CPI_U_ALL_YOY] = combined_df[self.COLUMN_CPI_U_ALL].pct_change(12)
198 |         combined_df[self.COLUMN_PPI_FINISHED_MOM] = combined_df[self.COLUMN_PPI_FINISHED].pct_change(1)
199 |         combined_df[self.COLUMN_PPI_FINISHED_YOY] = combined_df[self.COLUMN_PPI_FINISHED].pct_change(12)
200 |         combined_df[self.COLUMN_UE_MOM] = combined_df[self.COLUMN_UE].pct_change(1)
201 |         combined_df[self.COLUMN_UE_YOY] = 
combined_df[self.COLUMN_UE].pct_change(12) 202 | combined_df[self.COLUMN_WEEKLY_HOURS_MOM] = combined_df[self.COLUMN_WEEKLY_HOURS].pct_change(1) 203 | combined_df[self.COLUMN_WEEKLY_HOURS_YOY] = combined_df[self.COLUMN_WEEKLY_HOURS].pct_change(12) 204 | combined_df[self.COLUMN_HOURLY_EARNINGS_MOM] = combined_df[self.COLUMN_HOURLY_EARNINGS].pct_change(1) 205 | combined_df[self.COLUMN_HOURLY_EARNINGS_YOY] = combined_df[self.COLUMN_HOURLY_EARNINGS].pct_change(12) 206 | combined_df[self.COLUMN_OUTPUT_MOM] = combined_df[self.COLUMN_OUTPUT].pct_change(1) 207 | combined_df[self.COLUMN_OUTPUT_YOY] = combined_df[self.COLUMN_OUTPUT].pct_change(12) 208 | combined_df[self.COLUMN_IMPORT_INDEX_MOM] = combined_df[self.COLUMN_IMPORT_INDEX].pct_change(1) 209 | combined_df[self.COLUMN_IMPORT_INDEX_YOY] = combined_df[self.COLUMN_IMPORT_INDEX].pct_change(12) 210 | combined_df[self.COLUMN_EXPORT_INDEX_MOM] = combined_df[self.COLUMN_EXPORT_INDEX].pct_change(1) 211 | combined_df[self.COLUMN_EXPORT_INDEX_YOY] = combined_df[self.COLUMN_EXPORT_INDEX].pct_change(12) 212 | 213 | # convert date column to datetime 214 | combined_df = combined_df.reset_index() 215 | combined_df[self.DATE_COL] = pd.to_datetime(combined_df[self.DATE_COL]) 216 | combined_df[self.DATE_COL] = combined_df[self.DATE_COL].dt.date 217 | 218 | # merge market data with date data to fill in missing dates 219 | data = pd.merge(date_df, combined_df, on=self.DATE_COL, how="left").ffill() 220 | 221 | # shift each column based on its release schedule 222 | data[self.COLUMN_CPI_U] = data[self.COLUMN_CPI_U].shift(48) 223 | data[self.COLUMN_CPI_U_ALL] = data[self.COLUMN_CPI_U_ALL].shift(48) 224 | data[self.COLUMN_PPI_FINISHED] = data[self.COLUMN_PPI_FINISHED].shift(45) 225 | data[self.COLUMN_UE] = data[self.COLUMN_UE].shift(38) 226 | data[self.COLUMN_WEEKLY_HOURS] = data[self.COLUMN_WEEKLY_HOURS].shift(38) 227 | data[self.COLUMN_HOURLY_EARNINGS] = data[self.COLUMN_HOURLY_EARNINGS].shift(38) 228 | data[self.COLUMN_OUTPUT] = data[self.COLUMN_OUTPUT].shift(134) 229 | data[self.COLUMN_IMPORT_INDEX] = data[self.COLUMN_IMPORT_INDEX].shift(45) 230 | data[self.COLUMN_EXPORT_INDEX] = data[self.COLUMN_EXPORT_INDEX].shift(45) 231 | 232 | data = data[(data[self.DATE_COL] >= start_date)][[self.DATE_COL] + self.COLUMNS] 233 | 234 | return data 235 | 236 | def get_columns(self) -> list: 237 | return self.COLUMNS 238 | -------------------------------------------------------------------------------- /tests/era_data_api_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import date, timedelta 3 | 4 | import pandas as pd 5 | import pytest 6 | from mock import MagicMock, patch 7 | 8 | from numerai_era_data import era_data_api 9 | from numerai_era_data.data_sources.base_data_source import BaseDataSource 10 | from numerai_era_data.date_utils import ERA_ONE_START, get_date_for_era 11 | 12 | 13 | class MockDataSource: 14 | def __init__(self): 15 | pass 16 | 17 | def get_columns(self): 18 | return ["column1"] 19 | 20 | def get_data(self, start_date, end_date): 21 | df = pd.DataFrame() 22 | if start_date == get_date_for_era(1): 23 | df = pd.concat([df, pd.DataFrame({BaseDataSource.DATE_COL: [ERA_ONE_START], "column1": [1]})]) 24 | if end_date >= get_date_for_era(2): 25 | df = pd.concat([df, pd.DataFrame( 26 | {BaseDataSource.DATE_COL: [ERA_ONE_START + timedelta(days=7)], "column2": [2]})]) 27 | if end_date >= get_date_for_era(3): 28 | df = pd.concat([df, pd.DataFrame( 29 | {BaseDataSource.DATE_COL: 
[ERA_ONE_START + timedelta(days=14)], "column3": [3]})]) 30 | if end_date >= get_date_for_era(4): 31 | df = pd.concat([df, pd.DataFrame( 32 | {BaseDataSource.DATE_COL: [ERA_ONE_START + timedelta(days=21)], "column4": [4]})]) 33 | if start_date == date(2001, 4, 20): 34 | df = pd.DataFrame({BaseDataSource.DATE_COL: [date(2001, 4, 20)], "column5": [5]}) 35 | 36 | return df 37 | 38 | 39 | class MockDataSourceWithException: 40 | def __init__(self): 41 | pass 42 | 43 | def get_columns(self): 44 | return ["column2", "column3"] 45 | 46 | def get_data(self, start_date, end_date): 47 | raise Exception("Test exception") 48 | 49 | 50 | @pytest.fixture 51 | def manage_cache(): 52 | instance = era_data_api.EraDataAPI() 53 | instance.DATA_CACHE_FILE = "test_data_cache.parquet" 54 | instance.DAILY_CACHE_FILE = "test_daily_cache.parquet" 55 | 56 | yield instance 57 | 58 | if os.path.exists(instance.DATA_CACHE_FILE): 59 | os.remove(instance.DATA_CACHE_FILE) 60 | if os.path.exists(instance.DAILY_CACHE_FILE): 61 | os.remove(instance.DAILY_CACHE_FILE) 62 | 63 | 64 | def test_get_data_sources(): 65 | data_api = era_data_api.EraDataAPI() 66 | data_sources = data_api._get_data_sources() 67 | 68 | assert len(data_sources) > 0 69 | for data_source_class in data_sources: 70 | assert issubclass(data_source_class, era_data_api.BaseDataSource) 71 | assert data_source_class != BaseDataSource 72 | 73 | 74 | def test_get_data_sources_returns_cached_list(): 75 | instance = era_data_api.EraDataAPI() 76 | instance.class_cache = ['cached_data'] 77 | 78 | result = instance._get_data_sources() 79 | 80 | assert result == ['cached_data'] 81 | 82 | 83 | def test_update_data_with_empty_cache(manage_cache): 84 | instance = manage_cache 85 | instance._get_data_sources = MagicMock(return_value=[MockDataSource]) 86 | instance.data_cache = pd.DataFrame() 87 | 88 | with patch("numerai_era_data.date_utils.get_current_era", return_value=3): 89 | instance.update_data() 90 | 91 | assert not instance.data_cache.empty 92 | assert instance.data_cache[BaseDataSource.ERA_COL].tolist() == ["0001", "0002", "0003"] 93 | assert instance.data_cache.columns.tolist() == [BaseDataSource.ERA_COL, "column1", "column2", "column3"] 94 | 95 | 96 | def test_update_data_with_existing_cache(manage_cache): 97 | instance = manage_cache 98 | instance._get_data_sources = MagicMock(return_value=[MockDataSource]) 99 | instance.data_cache = pd.DataFrame({BaseDataSource.ERA_COL: ["0001"], "column1": [1]}) 100 | 101 | with patch("numerai_era_data.date_utils.get_current_era", return_value=4): 102 | instance.update_data() 103 | 104 | assert not instance.data_cache.empty 105 | assert instance.data_cache[BaseDataSource.ERA_COL].tolist() == ["0001", "0002", "0003", "0004"] 106 | assert instance.data_cache.columns.tolist() == [BaseDataSource.ERA_COL, "column1", "column2", "column3", "column4"] 107 | 108 | 109 | def test_update_data_with_exception(manage_cache): 110 | instance = manage_cache 111 | instance._get_data_sources = MagicMock(return_value=[MockDataSourceWithException]) 112 | instance.data_cache = pd.DataFrame({BaseDataSource.ERA_COL: ["0001"], "column1": [1]}) 113 | 114 | with patch("numerai_era_data.date_utils.get_current_era", return_value=3): 115 | instance.update_data() 116 | 117 | assert not instance.data_cache.empty 118 | assert instance.data_cache[BaseDataSource.ERA_COL].tolist() == ["0001", "0002", "0003"] 119 | assert instance.data_cache.columns.tolist() == [BaseDataSource.ERA_COL, "column2", "column3"] 120 | 121 | def 
test_update_daily_data_with_empty_cache(manage_cache): 122 | instance = manage_cache 123 | instance._get_data_sources = MagicMock(return_value=[MockDataSource]) 124 | instance.daily_cache = pd.DataFrame() 125 | data_date = date(2001, 4, 20) 126 | 127 | with patch("numerai_era_data.date_utils.get_current_date", return_value=data_date): 128 | instance.update_daily_data() 129 | 130 | assert not instance.daily_cache.empty 131 | assert instance.daily_cache[BaseDataSource.DATE_COL].tolist() == [data_date] 132 | assert instance.daily_cache.columns.tolist() == [BaseDataSource.DATE_COL, "column5", BaseDataSource.ERA_COL] 133 | 134 | 135 | def test_update_daily_data_with_existing_cache(manage_cache): 136 | instance = manage_cache 137 | instance._get_data_sources = MagicMock(return_value=[MockDataSource]) 138 | instance.daily_cache = pd.DataFrame({BaseDataSource.DATE_COL: ["1234"], "column1": [1]}) 139 | data_date = date(2001, 4, 20) 140 | 141 | with patch("numerai_era_data.date_utils.get_current_date", return_value=data_date): 142 | instance.update_daily_data() 143 | 144 | assert not instance.daily_cache.empty 145 | assert instance.daily_cache[BaseDataSource.DATE_COL].tolist() == [data_date] 146 | assert instance.daily_cache.columns.tolist() == [BaseDataSource.DATE_COL, "column5", BaseDataSource.ERA_COL] 147 | 148 | 149 | def test_update_daily_data_with_exception(manage_cache): 150 | instance = manage_cache 151 | instance._get_data_sources = MagicMock(return_value=[MockDataSourceWithException]) 152 | instance.data_cache = pd.DataFrame({BaseDataSource.ERA_COL: ["0001"], "column2": [2], "column3": [3]}) 153 | instance.daily_cache = pd.DataFrame({BaseDataSource.DATE_COL: ["1234"], "column1": [1]}) 154 | data_date = date(2001, 4, 20) 155 | 156 | with patch("numerai_era_data.date_utils.get_current_date", return_value=data_date): 157 | instance.update_daily_data() 158 | 159 | assert not instance.daily_cache.empty 160 | assert instance.daily_cache[BaseDataSource.DATE_COL].tolist() == [data_date] 161 | assert instance.daily_cache.columns.tolist() == [BaseDataSource.DATE_COL, "column2", "column3", BaseDataSource.ERA_COL] 162 | 163 | 164 | def test_get_all_eras_no_update(manage_cache): 165 | instance = manage_cache 166 | instance._get_data_sources = MagicMock(return_value=[MockDataSource]) 167 | instance.data_cache = pd.DataFrame({BaseDataSource.ERA_COL: ["0001"], "column1": [1]}) 168 | 169 | with patch("numerai_era_data.date_utils.get_current_era", return_value=1): 170 | df = instance.get_all_eras() 171 | 172 | assert df[BaseDataSource.ERA_COL].tolist() == ["0001"] 173 | 174 | 175 | def test_get_all_eras_with_update(manage_cache): 176 | instance = manage_cache 177 | instance._get_data_sources = MagicMock(return_value=[MockDataSource]) 178 | instance.data_cache = pd.DataFrame({BaseDataSource.ERA_COL: ["0001"], "column1": [1]}) 179 | 180 | with patch("numerai_era_data.date_utils.get_current_era", return_value=2): 181 | df = instance.get_all_eras() 182 | 183 | assert df[BaseDataSource.ERA_COL].tolist() == ["0001", "0002"] 184 | 185 | 186 | def test_get_all_eras_with_update_and_no_cache(manage_cache): 187 | instance = manage_cache 188 | instance._get_data_sources = MagicMock(return_value=[MockDataSource]) 189 | instance.data_cache = pd.DataFrame() 190 | 191 | with patch("numerai_era_data.date_utils.get_current_era", return_value=2): 192 | df = instance.get_all_eras() 193 | 194 | assert df[BaseDataSource.ERA_COL].tolist() == ["0001", "0002"] 195 | 196 | 197 | def test_get_all_eras_columns_changed(manage_cache): 
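    # a column provided by the data sources but absent from the cache must trigger a refresh even though the era is current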
198 | instance = manage_cache 199 | instance._get_data_sources = MagicMock(return_value=[MockDataSource]) 200 | instance.data_cache = pd.DataFrame({BaseDataSource.ERA_COL: ["0001"], "column0": [1]}) 201 | 202 | with patch("numerai_era_data.date_utils.get_current_era", return_value=1): 203 | df = instance.get_all_eras() 204 | 205 | assert df.columns.tolist() == [BaseDataSource.ERA_COL, "column1"] 206 | 207 | 208 | def test_get_all_eras_no_update_stale(manage_cache): 209 | instance = manage_cache 210 | instance._get_data_sources = MagicMock(return_value=[MockDataSource]) 211 | instance.data_cache = pd.DataFrame({BaseDataSource.ERA_COL: ["0001"], "column1": [1]}) 212 | 213 | with patch("numerai_era_data.date_utils.get_current_era", return_value=3): 214 | df = instance.get_all_eras(False) 215 | 216 | assert df[BaseDataSource.ERA_COL].tolist() == ["0001"] 217 | 218 | 219 | def test_get_current_daily_no_update(manage_cache): 220 | instance = manage_cache 221 | instance._get_data_sources = MagicMock(return_value=[MockDataSource]) 222 | data_date = date(2022, 1, 1) 223 | instance.daily_cache = pd.DataFrame({BaseDataSource.DATE_COL: [data_date], "column1": [1]}) 224 | 225 | with patch("numerai_era_data.date_utils.get_current_date", return_value=data_date): 226 | df = instance.get_current_daily() 227 | 228 | assert df[BaseDataSource.DATE_COL].tolist() == [data_date] 229 | assert df["column1"].tolist() == [1] 230 | 231 | 232 | def test_get_current_daily_with_update(manage_cache): 233 | instance = manage_cache 234 | instance._get_data_sources = MagicMock(return_value=[MockDataSource]) 235 | cache_data_date = date(2022, 1, 1) 236 | data_date = date(2001, 4, 20) 237 | instance.daily_cache = pd.DataFrame({BaseDataSource.DATE_COL: [cache_data_date], "column1": [1]}) 238 | 239 | with patch("numerai_era_data.date_utils.get_current_date", return_value=data_date): 240 | df = instance.get_current_daily() 241 | 242 | assert df[BaseDataSource.DATE_COL].tolist() == [data_date] 243 | assert df["column5"].tolist() == [5] 244 | 245 | 246 | def test_get_current_daily_with_update_and_no_cache(manage_cache): 247 | instance = manage_cache 248 | instance._get_data_sources = MagicMock(return_value=[MockDataSource]) 249 | data_date = date(2001, 4, 20) 250 | instance.daily_cache = pd.DataFrame() 251 | 252 | with patch("numerai_era_data.date_utils.get_current_date", return_value=data_date): 253 | df = instance.get_current_daily() 254 | 255 | assert df[BaseDataSource.DATE_COL].tolist() == [data_date] 256 | assert df["column5"].tolist() == [5] 257 | 258 | 259 | def test_get_current_daily_no_update_stale(manage_cache): 260 | instance = manage_cache 261 | instance._get_data_sources = MagicMock(return_value=[MockDataSource]) 262 | data_date = date(2022, 1, 1) 263 | request_date = date(2022, 5, 1) 264 | instance.daily_cache = pd.DataFrame({BaseDataSource.DATE_COL: [data_date], "column1": [1]}) 265 | 266 | with patch("numerai_era_data.date_utils.get_current_date", return_value=request_date): 267 | df = instance.get_current_daily(False) 268 | 269 | assert df[BaseDataSource.DATE_COL].tolist() == [data_date] 270 | assert df["column1"].tolist() == [1] 271 | --------------------------------------------------------------------------------