├── docs ├── static │ ├── .gitkeep │ └── f2ai_architecture.png ├── templates │ └── .gitkeep ├── requirements.txt ├── index.rst ├── f2ai_api.rst ├── Makefile ├── make.bat └── conf.py ├── Makefile ├── f2ai ├── cmd │ ├── __init__.py │ └── main.py ├── __init__.py ├── __main__.py ├── example │ └── jaffle_shop_example.py ├── offline_stores │ └── offline_spark_store.py ├── common │ ├── jinja.py │ ├── time_field.py │ ├── read_file.py │ ├── cmd_parser.py │ ├── get_config.py │ ├── oss_utils.py │ ├── collect_fn.py │ └── utils.py ├── persist_engine │ ├── offline_spark_persistengine.py │ ├── online_spark_persistengine.py │ ├── online_local_persistengine.py │ ├── offline_file_persistengine.py │ └── offline_pgsql_persistengine.py ├── definitions │ ├── constants.py │ ├── feature_view.py │ ├── dtypes.py │ ├── label_view.py │ ├── base_view.py │ ├── entities.py │ ├── __init__.py │ ├── backoff_time.py │ ├── sources.py │ ├── online_store.py │ ├── period.py │ ├── features.py │ ├── services.py │ ├── offline_store.py │ └── persist_engine.py ├── dataset │ ├── __init__.py │ ├── dataset.py │ ├── events_sampler.py │ ├── pytorch_dataset.py │ └── entities_sampler.py ├── models │ ├── sequential.py │ ├── normalizer.py │ ├── encoder.py │ ├── earlystop.py │ └── nbeats │ │ ├── model.py │ │ └── submodules.py └── online_stores │ └── online_redis_store.py ├── tests ├── fixtures │ ├── constants.py │ ├── git_utils.py │ ├── credit_score_fixtures.py │ └── guizhou_traffic_fixtures.py ├── units │ ├── offline_stores │ │ ├── postgres_sqls │ │ │ ├── stats_query_unique.sql │ │ │ ├── stats_query_avg.sql │ │ │ ├── stats_query_max.sql │ │ │ ├── stats_query_min.sql │ │ │ ├── stats_query_std.sql │ │ │ ├── stats_query_mode.sql │ │ │ ├── stats_query_median.sql │ │ │ ├── store_stats_query_categorical.sql │ │ │ └── store_stats_query_numeric.sql │ │ ├── offline_postgres_store_test.py │ │ └── offline_file_store_test.py │ ├── definitions │ │ ├── entities_test.py │ │ ├── features_test.py │ │ ├── offline_store_test.py │ │ └── back_off_time_test.py │ ├── common │ │ └── utils_test.py │ ├── service_test.py │ ├── dataset │ │ ├── events_sampler_test.py │ │ └── entities_sampler_test.py │ └── period_test.py ├── conftest.py └── integrations │ └── benchmarks │ ├── offline_pgsql_benchmark_test.py │ └── offline_file_benchmark_test.py ├── setup.cfg ├── requirements.txt ├── dev_requirements.txt ├── DockerFile ├── .readthedocs.yaml ├── .gitignore ├── setup.py ├── .github └── workflows │ └── python-publish.yml ├── README.rst └── use_cases ├── guizhou_traffic_arima.py ├── .ipynb_checkpoints └── credit_score-checkpoint.ipynb └── credit_score └── credit_score.ipynb /docs/static/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/templates/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx-rtd-theme -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | redis: 2 | docker run -d --rm --name redis -p 6379:6379 redis -------------------------------------------------------------------------------- /f2ai/cmd/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .main import main 2 | 3 | __all__ = ["main"] 4 | -------------------------------------------------------------------------------- /tests/fixtures/constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | TEMP_DIR = os.path.expanduser("~/.f2ai") 4 | -------------------------------------------------------------------------------- /f2ai/__init__.py: -------------------------------------------------------------------------------- 1 | from .featurestore import FeatureStore 2 | 3 | __all__ = ["FeatureStore"] 4 | -------------------------------------------------------------------------------- /tests/units/offline_stores/postgres_sqls/stats_query_unique.sql: -------------------------------------------------------------------------------- 1 | SELECT DISTINCT zipcode FROM "zipcode_table" -------------------------------------------------------------------------------- /f2ai/__main__.py: -------------------------------------------------------------------------------- 1 | #! python 2 | 3 | if __name__ == "__main__": 4 | from .cmd import main 5 | 6 | main() 7 | -------------------------------------------------------------------------------- /docs/static/f2ai_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-excelsior/F2AI/HEAD/docs/static/f2ai_architecture.png -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E203,E501,E722,E731,W503,W605 3 | max-line-length = 110 4 | 5 | [tool:pytest] 6 | log_cli=true -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | pytest_plugins = [ 2 | "fixtures.credit_score_fixtures", 3 | "fixtures.guizhou_traffic_fixtures", 4 | ] 5 | -------------------------------------------------------------------------------- /tests/units/offline_stores/postgres_sqls/stats_query_avg.sql: -------------------------------------------------------------------------------- 1 | SELECT "zipcode",AVG(population) "population" FROM "zipcode_table" GROUP BY "zipcode" -------------------------------------------------------------------------------- /tests/units/offline_stores/postgres_sqls/stats_query_max.sql: -------------------------------------------------------------------------------- 1 | SELECT "zipcode",MAX(population) "population" FROM "zipcode_table" GROUP BY "zipcode" -------------------------------------------------------------------------------- /tests/units/offline_stores/postgres_sqls/stats_query_min.sql: -------------------------------------------------------------------------------- 1 | SELECT "zipcode",MIN(population) "population" FROM "zipcode_table" GROUP BY "zipcode" -------------------------------------------------------------------------------- /tests/units/offline_stores/postgres_sqls/stats_query_std.sql: -------------------------------------------------------------------------------- 1 | SELECT "zipcode",STD(population) "population" FROM "zipcode_table" GROUP BY "zipcode" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas==1.4.0 2 | 
pyyaml==6.0 3 | psycopg2-binary==2.9.5 4 | SQLAlchemy==1.4.31 5 | pypika==0.48.9 6 | pydantic==1.10.2 7 | oss2==2.16.0 8 | Jinja2 -------------------------------------------------------------------------------- /tests/units/offline_stores/postgres_sqls/stats_query_mode.sql: -------------------------------------------------------------------------------- 1 | SELECT "zipcode",MODE() WITHIN GROUP (ORDER BY population) as population FROM "zipcode_table" GROUP BY "zipcode" -------------------------------------------------------------------------------- /dev_requirements.txt: -------------------------------------------------------------------------------- 1 | pandas==1.4.0 2 | pyyaml==6.0 3 | psycopg2-binary==2.9.5 4 | torch==1.11.0 5 | SQLAlchemy==1.4.31 6 | pypika==0.48.9 7 | pydantic==1.10.2 8 | oss2==2.16.0 9 | Jinja2 -------------------------------------------------------------------------------- /tests/units/offline_stores/postgres_sqls/stats_query_median.sql: -------------------------------------------------------------------------------- 1 | SELECT "zipcode",PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY population) as population FROM "zipcode_table" GROUP BY "zipcode" -------------------------------------------------------------------------------- /tests/units/definitions/entities_test.py: -------------------------------------------------------------------------------- 1 | from f2ai.definitions import Entity 2 | 3 | 4 | def test_entity_auto_join_keys(): 5 | entity = Entity(name="foo") 6 | assert entity.join_keys == ["foo"] 7 | -------------------------------------------------------------------------------- /tests/units/offline_stores/postgres_sqls/store_stats_query_categorical.sql: -------------------------------------------------------------------------------- 1 | SELECT DISTINCT zipcode FROM "zipcode_table" WHERE "event_timestamp">='2017-01-01T00:00:00' AND "event_timestamp"<='2018-01-01T00:00:00' -------------------------------------------------------------------------------- /f2ai/example/jaffle_shop_example.py: -------------------------------------------------------------------------------- 1 | from f2ai.featurestore import FeatureStore 2 | 3 | 4 | feature_store = FeatureStore( 5 | project_folder="/Users/xuyizhou/Desktop/xyz_warehouse/gitlab/f2ai-credit-scoring" 6 | ) 7 | -------------------------------------------------------------------------------- /f2ai/offline_stores/offline_spark_store.py: -------------------------------------------------------------------------------- 1 | from ..definitions import OfflineStore, OfflineStoreType 2 | 3 | 4 | class OfflineSparkStore(OfflineStore): 5 | type: OfflineStoreType = OfflineStoreType.SPARK 6 | pass 7 | -------------------------------------------------------------------------------- /f2ai/common/jinja.py: -------------------------------------------------------------------------------- 1 | import os 2 | from jinja2 import Environment, FileSystemLoader, select_autoescape 3 | 4 | jinja_env = Environment(loader=FileSystemLoader(os.path.join(os.path.dirname(__file__), "../templates")), autoescape=select_autoescape) -------------------------------------------------------------------------------- /tests/units/offline_stores/postgres_sqls/store_stats_query_numeric.sql: -------------------------------------------------------------------------------- 1 | SELECT "zipcode",AVG(population) "population" FROM "zipcode_table" WHERE "event_timestamp">='2017-01-01T00:00:00' AND "event_timestamp"<='2018-01-01T00:00:00' GROUP BY "zipcode" 
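The store-stats fixtures above describe the SQL the Postgres offline store is expected to emit: restrict rows to an event-time window, then aggregate per entity key. For readers more familiar with the file-based setup, a rough pandas equivalent of store_stats_query_numeric.sql is sketched below. This is illustrative only — the table values are invented and this is not the F2AI implementation:

    import pandas as pd

    # Toy stand-in for "zipcode_table"; values are made up for illustration.
    zipcode_table = pd.DataFrame(
        {
            "zipcode": ["10001", "10001", "94105"],
            "population": [100, 120, 80],
            "event_timestamp": pd.to_datetime(["2017-03-01", "2017-06-01", "2019-01-01"]),
        }
    )

    # Same semantics as the WHERE clause plus GROUP BY in the SQL fixture.
    start, end = pd.Timestamp("2017-01-01"), pd.Timestamp("2018-01-01")
    in_window = zipcode_table[
        (zipcode_table["event_timestamp"] >= start) & (zipcode_table["event_timestamp"] <= end)
    ]
    print(in_window.groupby("zipcode", as_index=False)["population"].mean())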
-------------------------------------------------------------------------------- /DockerFile: -------------------------------------------------------------------------------- 1 | FROM python:3.8 2 | 3 | WORKDIR /aie-feast 4 | RUN pip install --upgrade pip 5 | 6 | COPY ./requirements.txt /aie-feast/ 7 | RUN pip install -i https://mirrors.aliyun.com/pypi/simple -r requirements.txt 8 | 9 | COPY ./aie-feast /aie-feast -------------------------------------------------------------------------------- /f2ai/persist_engine/offline_spark_persistengine.py: -------------------------------------------------------------------------------- 1 | from ..definitions import OfflinePersistEngine, OfflinePersistEngineType 2 | 3 | 4 | class OfflineSparkPersistEngine(OfflinePersistEngine): 5 | type: OfflinePersistEngine = OfflinePersistEngineType.SPARK 6 | -------------------------------------------------------------------------------- /f2ai/persist_engine/online_spark_persistengine.py: -------------------------------------------------------------------------------- 1 | from f2ai.definitions import OnlinePersistEngine, OnlinePersistEngineType 2 | 3 | 4 | class OnlineSparkPersistEngine(OnlinePersistEngine): 5 | type: OnlinePersistEngine = OnlinePersistEngineType.DISTRIBUTE 6 | -------------------------------------------------------------------------------- /f2ai/common/time_field.py: -------------------------------------------------------------------------------- 1 | DEFAULT_EVENT_TIMESTAMP_FIELD = "event_timestamp" 2 | ENTITY_EVENT_TIMESTAMP_FIELD = "_entity_event_timestamp_" 3 | SOURCE_EVENT_TIMESTAMP_FIELD = "_source_event_timestamp_" 4 | QUERY_COL = "query_timestamp" 5 | TIME_COL = "event_timestamp" 6 | MATERIALIZE_TIME = "materialize_time" 7 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: "ubuntu-20.04" 5 | tools: 6 | python: "3.8" 7 | 8 | sphinx: 9 | configuration: ./docs/conf.py 10 | fail_on_warning: true 11 | 12 | python: 13 | install: 14 | - requirements: requirements.txt 15 | - requirements: docs/requirements.txt -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. toctree:: 2 | :hidden: 3 | :caption: Getting Started 4 | 5 | self 6 | 7 | .. toctree:: 8 | :hidden: 9 | :caption: MISCELLANEOUS 10 | 11 | f2ai_api 12 | genindex 13 | 14 | Welcome to F2AI's documentation! 
15 | ================================ 16 | 17 | Hello World -------------------------------------------------------------------------------- /f2ai/definitions/constants.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from enum import Enum 3 | 4 | LOCAL_TIMEZONE = datetime.datetime.now().astimezone().tzinfo 5 | 6 | 7 | class StatsFunctions(Enum): 8 | MIN = "min" 9 | MAX = "max" 10 | STD = "std" 11 | AVG = "avg" 12 | MODE = "mode" 13 | MEDIAN = "median" 14 | UNIQUE = "unique" 15 | -------------------------------------------------------------------------------- /tests/units/common/utils_test.py: -------------------------------------------------------------------------------- 1 | from f2ai.common.utils import batched 2 | 3 | 4 | def test_batched(): 5 | xs = range(10) 6 | batches = next(batched(xs, batch_size=3)) 7 | assert batches == [0, 1, 2] 8 | 9 | last_batch = [] 10 | for batch in batched(xs, batch_size=3): 11 | last_batch = batch 12 | assert last_batch == [9] 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.log 2 | *.pyc 3 | *.bk 4 | *.pdf 5 | *.h5 6 | *.dump 7 | .env 8 | *.pk 9 | *.test.py 10 | !aiefeast 11 | **/*.pyc 12 | **/**/*.pyc 13 | .vscode 14 | .DS_Store 15 | *.pid 16 | **/*log.* 17 | f2ai-credit-scoring-main 18 | guizhou_traffic-main 19 | guizhou_traffic-ver_pgsql 20 | docs/_* 21 | /_* 22 | 1f/ 23 | 3f/ 24 | scripts/ 25 | egg* 26 | *.egg-info 27 | /build -------------------------------------------------------------------------------- /tests/units/definitions/features_test.py: -------------------------------------------------------------------------------- 1 | from f2ai.definitions import SchemaAnchor 2 | 3 | 4 | def test_parse_cfg_to_feature_anchor(): 5 | feature_anchor = SchemaAnchor.from_str("fv1:f1") 6 | assert feature_anchor.view_name == "fv1" 7 | assert feature_anchor.schema_name == "f1" 8 | assert feature_anchor.period is None 9 | 10 | feature_anchor = SchemaAnchor.from_str("fv1:f1:1 day") 11 | assert feature_anchor.period == "1 day" 12 | -------------------------------------------------------------------------------- /tests/units/definitions/offline_store_test.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from f2ai.definitions import init_offline_store_from_cfg 3 | 4 | offline_file_store_yaml_string = """ 5 | type: file 6 | """ 7 | 8 | 9 | def test_init_file_offline_store(): 10 | from f2ai.offline_stores.offline_file_store import OfflineFileStore 11 | 12 | offline_store = init_offline_store_from_cfg(yaml.safe_load(offline_file_store_yaml_string), 'test') 13 | assert isinstance(offline_store, OfflineFileStore) 14 | -------------------------------------------------------------------------------- /tests/units/definitions/back_off_time_test.py: -------------------------------------------------------------------------------- 1 | from f2ai.definitions import BackOffTime 2 | 3 | 4 | def test_back_off_time_to_units(): 5 | back_off_time = BackOffTime(start="2020-08-01 12:20", end="2020-10-30", step="1 month") 6 | units = list(back_off_time.to_units()) 7 | assert len(units) == 3 8 | 9 | back_off_time = BackOffTime(start="2020-08-01 12:20", end="2020-08-02", step="2 hours") 10 | units = list(back_off_time.to_units()) 11 | assert len(units) == 6 12 | -------------------------------------------------------------------------------- 
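The two BackOffTime tests above pin down how a time range is normalized and split into fixed steps. Based on the BackOffTime definition in f2ai/definitions/backoff_time.py (shown later in this listing), the resulting units are typically consumed as in this minimal sketch; the dates are placeholders:

    from f2ai.definitions import BackOffTime

    # Split a two-day window into daily materialization units.
    back_off_time = BackOffTime(start="2021-01-01", end="2021-01-03", step="1 day")
    for unit in back_off_time.to_units():
        # Each unit is itself a BackOffTime covering one step of the range.
        print(unit.start, "->", unit.end)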
/docs/f2ai_api.rst: -------------------------------------------------------------------------------- 1 | API Reference 2 | ================== 3 | 4 | 5 | featurestore 6 | ------------------------------ 7 | 8 | .. automodule:: f2ai.featurestore 9 | :members: 10 | :undoc-members: 11 | :show-inheritance: 12 | 13 | 14 | dataset 15 | ------------------------------ 16 | 17 | .. automodule:: f2ai.dataset 18 | :members: 19 | :undoc-members: 20 | :show-inheritance: 21 | 22 | 23 | definitions 24 | ----------------------------- 25 | 26 | .. automodule:: f2ai.definitions 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | -------------------------------------------------------------------------------- /tests/units/service_test.py: -------------------------------------------------------------------------------- 1 | from f2ai.definitions import SchemaAnchor, FeatureSchema, FeatureView, Service 2 | 3 | 4 | def test_get_features_from_service(): 5 | service = Service( 6 | name="foo", 7 | features=[ 8 | SchemaAnchor(view_name="fv", schema_name="*"), 9 | SchemaAnchor(view_name="fv", schema_name="foo"), 10 | ], 11 | ) 12 | feature_views = {"fv": FeatureView(name="fv", schema=[FeatureSchema(name="f1", dtype="string")])} 13 | 14 | features = service.get_feature_objects(feature_views) 15 | 16 | assert len(features) == 1 17 | assert list(features)[0].name == "f1" 18 | -------------------------------------------------------------------------------- /f2ai/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import Dataset 2 | from .pytorch_dataset import TorchIterableDataset 3 | from .events_sampler import ( 4 | EventsSampler, 5 | EvenEventsSampler, 6 | RandomNEventsSampler, 7 | ) 8 | from .entities_sampler import ( 9 | EntitiesSampler, 10 | NoEntitiesSampler, 11 | EvenEntitiesSampler, 12 | FixedNEntitiesSampler, 13 | ) 14 | 15 | __all__ = [ 16 | "Dataset", 17 | "TorchIterableDataset", 18 | "EventsSampler", 19 | "EvenEventsSampler", 20 | "RandomNEventsSampler", 21 | "EntitiesSampler", 22 | "NoEntitiesSampler", 23 | "EvenEntitiesSampler", 24 | "FixedNEntitiesSampler", 25 | ] 26 | -------------------------------------------------------------------------------- /tests/units/dataset/events_sampler_test.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from f2ai.dataset import ( 3 | EvenEventsSampler, 4 | RandomNEventsSampler, 5 | ) 6 | 7 | 8 | def test_even_events_sampler(): 9 | sampler = EvenEventsSampler(start="2022-10-02", end="2022-12-02", period="1 day") 10 | 11 | assert len(sampler()) == 62 12 | assert next(iter(sampler)) == pd.Timestamp("2022-10-02 00:00:00") 13 | 14 | 15 | def test_random_n_events_sampler(): 16 | sampler = RandomNEventsSampler(start="2022-10-02", end="2022-12-02", period="1 day", n=2, random_state=666) 17 | 18 | assert len(sampler()) == 2 19 | assert next(iter(sampler)) == pd.Timestamp("2022-10-05 00:00:00") 20 | -------------------------------------------------------------------------------- /f2ai/definitions/feature_view.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | 4 | from .base_view import BaseView 5 | from .features import Feature 6 | 7 | 8 | class FeatureView(BaseView): 9 | def get_feature_names(self) -> List[str]: 10 | return [feature.name for feature in self.schemas] 11 | 12 | def get_feature_objects(self, is_numeric=False) -> List[Feature]: 13 | return list( 14 | 
dict.fromkeys( 15 | [ 16 | Feature.create_feature_from_schema(schema, self.name) 17 | for schema in self.schemas 18 | if (schema.is_numeric() if is_numeric else True) 19 | ] 20 | ) 21 | ) 22 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /f2ai/definitions/dtypes.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class FeatureDTypes(str, Enum): 5 | """Feature data type definitions which supported by F2AI. Used to convert to a certain data type for different algorithm frameworks. Not useful now, but future.""" 6 | 7 | INT = "int" 8 | INT32 = "int32" 9 | INT64 = "int64" 10 | FLOAT = "float" 11 | FLOAT32 = "float32" 12 | FLOAT64 = "float64" 13 | STRING = "string" 14 | BOOLEAN = "bool" 15 | UNKNOWN = "unknown" 16 | 17 | 18 | NUMERIC_FEATURE_D_TYPES = { 19 | FeatureDTypes.INT, 20 | FeatureDTypes.INT32, 21 | FeatureDTypes.INT64, 22 | FeatureDTypes.FLOAT, 23 | FeatureDTypes.FLOAT32, 24 | FeatureDTypes.FLOAT64, 25 | } 26 | -------------------------------------------------------------------------------- /f2ai/definitions/label_view.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | 3 | from .base_view import BaseView 4 | from .features import Feature 5 | 6 | 7 | class LabelView(BaseView): 8 | request_source: Optional[str] 9 | 10 | def get_label_names(self): 11 | return [label.name for label in self.schemas] 12 | 13 | def get_label_objects(self, is_numeric=False) -> List[Feature]: 14 | return list( 15 | dict.fromkeys( 16 | [ 17 | Feature.create_label_from_schema(schema, self.name) 18 | for schema in self.schemas 19 | if (schema.is_numeric() if is_numeric else True) 20 | ] 21 | ) 22 | ) 23 | -------------------------------------------------------------------------------- /tests/fixtures/git_utils.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | 4 | def git_clone(cwd: str, repo: str, branch: str): 5 | return subprocess.run( 6 | [ 7 | "git", 8 | "clone", 9 | "--branch", 10 | branch, 11 | repo, 12 | cwd, 13 | ], 14 | check=False, 15 | ) 16 | 17 | 18 | def git_reset(cwd: str, branch: str, mode: str = "hard"): 19 | return subprocess.run(["git", "reset", f"--{mode}", branch], cwd=cwd, check=False) 20 | 21 | 22 | def git_clean(cwd: str): 23 | return subprocess.run(["git", "clean", "-df"], cwd=cwd, check=False) 24 | 25 | 26 | def git_pull(cwd: str, branch: str): 27 | return subprocess.run(["git", "pull", "--rebase", "origin", branch], cwd=cwd, check=False) 28 | 
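The git helpers above are thin subprocess wrappers; the session fixtures later in this listing (e.g. tests/fixtures/credit_score_fixtures.py) chain them to keep a test-data repository cloned, clean and up to date. A minimal sketch of that pattern, with a hypothetical repository URL and checkout directory (the import path depends on how the tests are run):

    import os
    from git_utils import git_clone, git_reset, git_clean, git_pull  # import path is illustrative

    cwd = os.path.expanduser("~/.f2ai/example-data")        # hypothetical checkout directory
    repo = "git@example.com:example-org/example-data.git"   # hypothetical repository
    branch = "main"

    if not os.path.isdir(cwd):
        git_clone(cwd, repo, branch)   # first run: clone the branch into cwd
    git_reset(cwd, branch)             # discard local changes
    git_clean(cwd)                     # remove untracked files
    git_pull(cwd, branch)              # fast-forward to the latest commit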
-------------------------------------------------------------------------------- /f2ai/definitions/base_view.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Dict, Any 2 | from pydantic import BaseModel, Field 3 | 4 | from .features import FeatureSchema 5 | from .period import Period 6 | 7 | 8 | class BaseView(BaseModel): 9 | """Abstraction of the common parts of FeatureView and LabelView.""" 10 | 11 | name: str 12 | description: Optional[str] 13 | entities: List[str] = [] 14 | schemas: List[FeatureSchema] = Field(alias="schema", default=[]) 15 | batch_source: Optional[str] 16 | ttl: Optional[Period] 17 | tags: Dict[str, str] = {} 18 | 19 | def __init__(__pydantic_self__, **data: Any) -> None: 20 | if isinstance(data.get("ttl", None), str): 21 | data["ttl"] = Period.from_str(data.get("ttl", None)) 22 | super().__init__(**data) 23 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_namespace_packages 2 | from pip._internal.req import parse_requirements 3 | from pip._internal.network.session import PipSession 4 | 5 | install_requirements = parse_requirements("requirements.txt", session=PipSession()) 6 | 7 | setup( 8 | name="f2ai", 9 | version="0.0.4", 10 | description="A feature store tool focused on making it easy to retrieve features in machine learning.", 11 | url="https://github.com/ai-excelsior/F2AI", 12 | author="上海半见", 13 | license="MIT", 14 | zip_safe=False, 15 | packages=find_namespace_packages(exclude=["docs", "tests", "tests.*", "use_cases", "*.egg-info"]), 16 | install_requires=[str(x.requirement) for x in install_requirements], 17 | entry_points={"console_scripts": ["f2ai = f2ai.cmd:main"]}, 18 | ) 19 | -------------------------------------------------------------------------------- /f2ai/definitions/entities.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | from typing import List, Any, Optional 3 | 4 | 5 | class Entity(BaseModel): 6 | """An entity is a key connection between different feature views. Under the hood, we use join_keys to join the feature views.
If join key is empty, we will take the name as a default join key""" 7 | 8 | name: str 9 | description: Optional[str] 10 | join_keys: List[str] = [] 11 | 12 | def __init__(__pydantic_self__, **data: Any) -> None: 13 | 14 | join_keys = data.pop("join_keys", []) 15 | if len(join_keys) == 0: 16 | join_keys = [data.get("name")] 17 | 18 | super().__init__(**data, join_keys=join_keys) 19 | 20 | def __hash__(self) -> int: 21 | s = ",".join(self.join_keys) 22 | return hash(f"{self.name}:{s}") 23 | -------------------------------------------------------------------------------- /f2ai/common/read_file.py: -------------------------------------------------------------------------------- 1 | # from .oss_utils import get_bucket_from_oss_url 2 | from typing import Dict 3 | from .utils import remove_prefix 4 | import yaml 5 | 6 | 7 | def read_yml(url: str) -> Dict: 8 | """read .yml file for following execute 9 | 10 | Args: 11 | url (str): url of .yml 12 | """ 13 | file = _read_file(url) 14 | cfg = yaml.load(file, Loader=yaml.FullLoader) 15 | return cfg 16 | 17 | 18 | def _read_file(url): 19 | if url.startswith("file://"): 20 | with open(remove_prefix(url, "file://"), "r") as file: 21 | return file.read() 22 | # default behavior 23 | else: 24 | with open(url, "r") as file: 25 | return file.read() 26 | # elif url.startswith("oss://"): # TODO: may not be correct 27 | # bucket, key = get_bucket_from_oss_url(url) 28 | # return bucket.get_object(key).read() 29 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | 37 | livehtml: 38 | sphinx-autobuild "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /f2ai/common/cmd_parser.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | 4 | def add_materialize_parser(subparsers): 5 | parser: ArgumentParser = subparsers.add_parser( 6 | "materialize", help="materialize a service to offline (default) or online" 7 | ) 8 | 9 | parser.add_argument( 10 | "services", 11 | type=str, 12 | nargs="+", 13 | help="at least one service name, multi service name using space to separate.", 14 | ) 15 | parser.add_argument( 16 | "--online", action="store_true", help="materialize service to online store if presents." 
17 | ) 18 | parser.add_argument("--fromnow", type=str, help="materialize start time point from now, e.g. 7 days.") 19 | parser.add_argument( 20 | "--start", type=str, help="materialize start time point, e.g. 2022-10-22 or 2022-11-22T10:12." 21 | ) 22 | parser.add_argument( 23 | "--end", type=str, help="materialize end time point, e.g. 2022-10-22 or 2022-11-22T10:12." 24 | ) 25 | parser.add_argument("--step", type=str, default="1 day", help="how to split the materialization task") 26 | parser.add_argument("--tz", type=str, default=None, help="timezone of start and end, defaults to None") 27 | -------------------------------------------------------------------------------- /f2ai/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import TYPE_CHECKING 3 | 4 | if TYPE_CHECKING: 5 | from ..definitions.services import Service 6 | from ..featurestore import FeatureStore 7 | from .pytorch_dataset import TorchIterableDataset 8 | 9 | 10 | class Dataset: 11 | """ 12 | A dataset is an abstraction which holds a service and a sampler. 13 | A service basically tells us where the data is. A sampler tells us which parts of the data should be retrieved. 14 | 15 | Note: We should not construct a Dataset ourselves; using `store.get_dataset()` is recommended. 16 | """ 17 | 18 | def __init__( 19 | self, 20 | fs: "FeatureStore", 21 | service: "Service", 22 | sampler: callable, 23 | ): 24 | self.fs = fs 25 | self.service = service 26 | self.sampler = sampler 27 | 28 | def to_pytorch(self, chunk_size: int = 64) -> "TorchIterableDataset": 29 | """convert to an iterable pytorch dataset which really holds the data""" 30 | from .pytorch_dataset import TorchIterableDataset 31 | 32 | return TorchIterableDataset(feature_store=self.fs, service=self.service, sampler=self.sampler, chunk_size=chunk_size) 33 | -------------------------------------------------------------------------------- /f2ai/models/sequential.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | 5 | 6 | class SimpleClassify(nn.Module): 7 | def __init__(self, cont_nbr, cat_nbr, emd_dim, max_types, hidden_size=16, drop_out=0.1) -> None: 8 | super().__init__() 9 | # num_embeddings must be no less than the number of category types 10 | self.categorical_embedding = nn.Embedding(num_embeddings=max_types, embedding_dim=emd_dim) 11 | hidden_list = 2 ** np.arange(max(np.ceil(np.log2(hidden_size)), 2) + 1)[::-1] 12 | model_list = [nn.Linear(cont_nbr + cat_nbr * emd_dim, int(hidden_list[0]))] 13 | for i in range(len(hidden_list) - 1): 14 | model_list.append(nn.Dropout(drop_out)) 15 | model_list.append(nn.Linear(int(hidden_list[i]), int(hidden_list[i + 1]))) 16 | 17 | self.model = nn.Sequential(*model_list, nn.Sigmoid()) 18 | 19 | def forward(self, x): 20 | cat_vector = [] 21 | for i in range(x["categorical_features"].shape[-1]): 22 | cat = self.categorical_embedding(x["categorical_features"][..., i]) 23 | cat_vector.append(cat) 24 | cat_vector = torch.cat(cat_vector, dim=-1) 25 | input_vector = torch.cat([cat_vector, x["continous_features"]], dim=-1) 26 | return self.model(input_vector) 27 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see:
https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | push: 13 | tags: 14 | - release 15 | 16 | permissions: 17 | contents: read 18 | 19 | jobs: 20 | deploy: 21 | 22 | runs-on: ubuntu-latest 23 | 24 | steps: 25 | - uses: actions/checkout@v3 26 | - name: Set up Python 27 | uses: actions/setup-python@v3 28 | with: 29 | python-version: '3.x' 30 | - name: Install dependencies 31 | run: | 32 | python -m pip install --upgrade pip 33 | pip install setuptools wheel twine 34 | - name: build 35 | run: | 36 | python setup.py sdist bdist_wheel 37 | - name: Publish package 38 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 39 | with: 40 | user: __token__ 41 | password: ${{ secrets.PYPI_API_TOKEN }} 42 | -------------------------------------------------------------------------------- /tests/fixtures/credit_score_fixtures.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import os 3 | 4 | from .git_utils import git_clean, git_clone, git_reset, git_pull 5 | from .constants import TEMP_DIR 6 | 7 | CREDIT_SCORE_CFG = { 8 | "repo": "git@code.unianalysis.com:f2ai-examples/f2ai-credit-scoring.git", 9 | "infra": { 10 | "file": { 11 | "cwd": os.path.join(TEMP_DIR, "f2ai-credit-scoring_file"), 12 | "branch": "main", 13 | }, 14 | "pgsql": { 15 | "cwd": os.path.join(TEMP_DIR, "f2ai-credit-scoring_pgsql"), 16 | "branch": "ver_pgsql", 17 | }, 18 | }, 19 | } 20 | 21 | 22 | @pytest.fixture(scope="session") 23 | def make_credit_score(): 24 | if not os.path.exists(TEMP_DIR): 25 | os.makedirs(TEMP_DIR) 26 | 27 | def get_credit_score(infra="file"): 28 | repo = CREDIT_SCORE_CFG["repo"] 29 | infra = CREDIT_SCORE_CFG["infra"][infra] 30 | cwd = infra["cwd"] 31 | branch = infra["branch"] 32 | 33 | # clone repo to cwd 34 | if not os.path.isdir(cwd): 35 | git_clone(cwd, repo, branch) 36 | 37 | # reset, clean and update to latest 38 | git_reset(cwd, branch) 39 | git_clean(cwd) 40 | git_pull(cwd, branch) 41 | 42 | return cwd 43 | 44 | yield get_credit_score 45 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 
2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 | import os 9 | import sys 10 | 11 | docs_dir = os.path.dirname(os.path.realpath(__file__)) 12 | project_dir = os.path.realpath(os.path.join(docs_dir, "..")) 13 | sys.path.append(project_dir) 14 | print(f'Append sys path: {project_dir}') 15 | 16 | project = "F2AI" 17 | copyright = "2022, eavae" 18 | author = "eavae" 19 | 20 | # -- General configuration --------------------------------------------------- 21 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 22 | 23 | extensions = ["sphinx.ext.napoleon", "sphinx.ext.autodoc"] 24 | templates_path = ["templates"] 25 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 26 | html_sidebars = {"**": ["globaltoc.html"]} 27 | 28 | # -- Options for HTML output ------------------------------------------------- 29 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 30 | 31 | html_theme = "sphinx_rtd_theme" 32 | html_static_path = ["static"] 33 | -------------------------------------------------------------------------------- /tests/fixtures/guizhou_traffic_fixtures.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import os 3 | 4 | from .git_utils import git_clean, git_clone, git_reset, git_pull 5 | from .constants import TEMP_DIR 6 | 7 | 8 | GUIZHOU_TRAFFIC_CFG = { 9 | "repo": "git@code.unianalysis.com:f2ai-examples/guizhou_traffic.git", 10 | "infra": { 11 | "file": { 12 | "cwd": os.path.join(TEMP_DIR, "f2ai-guizhou_traffic_file"), 13 | "branch": "main", 14 | }, 15 | "pgsql": { 16 | "cwd": os.path.join(TEMP_DIR, "f2ai-guizhou_traffic_pgsql"), 17 | "branch": "ver_pgsql", 18 | }, 19 | }, 20 | } 21 | 22 | 23 | @pytest.fixture(scope="session") 24 | def make_guizhou_traffic(): 25 | if not os.path.exists(TEMP_DIR): 26 | os.makedirs(TEMP_DIR) 27 | 28 | def get_guizhou_traffic(infra="file"): 29 | repo = GUIZHOU_TRAFFIC_CFG["repo"] 30 | infra = GUIZHOU_TRAFFIC_CFG["infra"][infra] 31 | cwd = infra["cwd"] 32 | branch = infra["branch"] 33 | 34 | # clone repo to cwd 35 | if not os.path.isdir(cwd): 36 | git_clone(cwd, repo, branch) 37 | 38 | # reset, clean and update to latest 39 | git_reset(cwd, branch) 40 | git_clean(cwd) 41 | git_pull(cwd, branch) 42 | 43 | return cwd 44 | 45 | yield get_guizhou_traffic 46 | -------------------------------------------------------------------------------- /tests/units/period_test.py: -------------------------------------------------------------------------------- 1 | from f2ai.definitions import Period 2 | from pandas import DateOffset 3 | 4 | 5 | def test_period_to_pandas_dateoffset(): 6 | offset = Period(n=1, unit="day").to_pandas_dateoffset() 7 | assert offset == DateOffset(days=1) 8 | 9 | 10 | def test_period_to_sql_interval(): 11 | interval = Period(n=1, unit="day").to_pgsql_interval() 12 | assert interval == "interval '1 days'" 13 | 14 | 15 | def test_period_from_str(): 16 | ten_years = Period.from_str("10 years").to_pandas_dateoffset() 17 | one_day = Period.from_str("1day").to_pandas_dateoffset() 18 | 19 | assert ten_years == DateOffset(years=10) 20 | assert one_day == DateOffset(days=1) 21 | 22 | 23 | def test_period_negative(): 24 | ten_years = 
Period.from_str("10 years") 25 | neg_ten_years = -ten_years 26 | assert neg_ten_years.n == -10 27 | 28 | neg_ten_years = Period.from_str("-10 years") 29 | assert neg_ten_years.n == -10 30 | 31 | 32 | def test_get_pandas_datetime_components(): 33 | three_minutes = Period.from_str("3 minutes") 34 | components = three_minutes.get_pandas_datetime_components() 35 | assert components == ["year", "month", "day", "hour", "minute"] 36 | 37 | one_month = Period.from_str("1 month") 38 | components = one_month.get_pandas_datetime_components() 39 | assert components == ["year", "month"] 40 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | F2AI 2 | ==== 3 | 4 | .. image:: https://readthedocs.org/projects/f2ai/badge/?version=latest 5 | :target: https://f2ai.readthedocs.io/en/latest/?badge=latest 6 | :alt: Documentation Status 7 | 8 | The architecture above is a working flow demonstration when using F2AI but not technical architecture. 9 | 10 | Getting Started 11 | --------------- 12 | 13 | 1. Install F2AI 14 | 15 | 16 | Overview 17 | ------------- 18 | 19 | F2AI (Feature Store to AI) is a time centric productivity data utils that re-uses existing infrastructure to get features more consistently though different stages of AI development. 20 | 21 | F2AI is focusing on: 22 | 23 | * **Consistent API to get features for training and serving**: Powered by well encapsulated OfflineStore and OnlineStore, the features are strictly managed by F2AI, and keep the same structure not only when training models, but also inference. 24 | * **Prevent feature leakage** by effective point-in-time join that reduce the cumbersome works to get features ready when experimenting a feature combinations or AI model. 25 | * **Build-in supported period feature** that allows 2 dimensional features can be retrieved easily. This is useful when facing some deep learning tasks like, time series forecasting. 26 | * **Infrastructure unawareness** by different infrastructure implementations, switching between different data storage is simply changing the configuration. The AI model will works well like before. 27 | 28 | Architecture 29 | ------------ 30 | 31 | .. image:: ./docs/static/f2ai_architecture.png 32 | :alt: f2ai function architecture 33 | 34 | .. note:: 35 | This is a working flow when using F2AI instead of technical architecture. 
36 | -------------------------------------------------------------------------------- /f2ai/cmd/main.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from argparse import ArgumentParser 3 | from typing import List 4 | 5 | from ..common.cmd_parser import add_materialize_parser 6 | from ..definitions import BackOffTime 7 | from ..featurestore import FeatureStore 8 | 9 | 10 | def materialize(url: str, services: List[str], back_off_time: BackOffTime, online: bool): 11 | fs = FeatureStore(url) 12 | 13 | fs.materialize(services, back_off_time, online) 14 | 15 | 16 | def main(): 17 | parser = ArgumentParser(prog="F2AI") 18 | subparsers = parser.add_subparsers(title="subcommands", dest="commands") 19 | add_materialize_parser(subparsers) 20 | 21 | kwargs = vars(parser.parse_args()) 22 | commands = kwargs.pop("commands", None) 23 | if commands == "materialize": 24 | if kwargs["fromnow"] is None and kwargs["start"] is None and kwargs["end"] is None: 25 | parser.error("One of fromnow or start&end is required.") 26 | 27 | if not pathlib.Path("feature_store.yml").exists(): 28 | parser.error( 29 | "No feature_store.yml found in current folder, please switch to folder which feature_store.yml exists." 30 | ) 31 | 32 | if commands == "materialize": 33 | from_now = kwargs.pop("fromnow", None) 34 | step = kwargs.pop("step", None) 35 | tz = kwargs.pop("tz", None) 36 | 37 | if from_now is not None: 38 | back_off_time = BackOffTime.from_now(from_now=from_now, step=step, tz=tz) 39 | else: 40 | back_off_time = BackOffTime(start=kwargs.pop("start"), end=kwargs.pop("end"), step=step, tz=tz) 41 | 42 | materialize("file://.", kwargs.pop("services"), back_off_time, kwargs.pop("online")) 43 | -------------------------------------------------------------------------------- /f2ai/models/normalizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | 5 | class MinMaxNormalizer: 6 | """scale continous features to [-1,1],""" 7 | 8 | def __init__(self, feature_name=None): 9 | self.feature_name = feature_name 10 | 11 | def fit_self(self, data: pd.Series): 12 | min = data.min() 13 | max = data.max() 14 | self._state = ((min + max) / 2.0, max - min + 1e-8) 15 | return self 16 | 17 | def transform_self(self, data: pd.Series) -> pd.Series: 18 | assert self._state is not None 19 | center, scale = self._state 20 | 21 | return (data - center) / scale 22 | 23 | def inverse_transform_self(self, data: pd.Series) -> pd.Series: 24 | assert self._state is not None 25 | center, scale = self._state 26 | 27 | return data * scale + center 28 | 29 | 30 | class StandardNormalizer: 31 | def __init__(self, feature_name=None, center=True): 32 | self.feature_name = feature_name 33 | self._center = center 34 | self._eps = 1e-6 35 | 36 | def fit_self(self, data: pd.Series, source: pd.DataFrame = None, **kwargs): 37 | if self._center: 38 | self._state = (data.mean(), data.std() + self._eps) 39 | else: 40 | self._state = (0.0, data.mean() + self._eps) 41 | 42 | def transform_self(self, data: pd.Series) -> pd.Series: 43 | assert self._state is not None 44 | center, scale = self._state 45 | 46 | return (data - center) / scale 47 | 48 | def inverse_transform_self(self, data: pd.Series) -> pd.Series: 49 | assert self._state is not None 50 | center, scale = self._state 51 | 52 | return data * scale + center 53 | -------------------------------------------------------------------------------- 
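A short usage sketch for the normalizers above, using the API exactly as defined in f2ai/models/normalizer.py; the sample series is made up and the import path assumes the package layout in this listing:

    import pandas as pd
    from f2ai.models.normalizer import MinMaxNormalizer

    prices = pd.Series([10.0, 12.5, 20.0, 40.0], name="price")
    normalizer = MinMaxNormalizer(feature_name="price").fit_self(prices)

    scaled = normalizer.transform_self(prices)             # values centred around zero
    restored = normalizer.inverse_transform_self(scaled)   # recovers the original values (up to ~1e-8)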
/f2ai/models/encoder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | 5 | def column_or_1d(y, warn): 6 | y = np.asarray(y) 7 | shape = np.shape(y) 8 | if len(shape) == 1: 9 | return np.ravel(y) 10 | elif len(shape) == 2 and shape[1] == 1: 11 | return np.ravel(y) 12 | 13 | raise ValueError("y should be a 1d array, got an array of shape {} instead.".format(shape)) 14 | 15 | 16 | def map_to_integer(values, uniques): 17 | """Map values based on its position in uniques.""" 18 | table = {val: i for i, val in enumerate(uniques)} 19 | return np.array([table[v] if v in table else table["UNKNOWN_CAT"] for v in values]) 20 | 21 | 22 | class LabelEncoder: 23 | """encode categorical features into int number by their appearance order, 24 | will always set 0 to be UNKNOWN_CAT automatically 25 | """ 26 | 27 | def __init__(self, feature_name=None): 28 | self.feature_name = feature_name 29 | 30 | def fit_self(self, y: pd.Series): 31 | y = column_or_1d(y, warn=True) 32 | self._state = ["UNKNOWN_CAT"] + sorted(set(y)) 33 | return self 34 | 35 | def transform_self(self, y: pd.Series) -> pd.Series: 36 | y = column_or_1d(y, warn=True) 37 | if len(y) == 0: 38 | return np.array([]) 39 | y = map_to_integer(y, self._state) 40 | return y 41 | 42 | def inverse_transform_self(self, y: pd.Series) -> pd.Series: 43 | y = column_or_1d(y, warn=True) 44 | if len(y) == 0: 45 | return np.array([]) 46 | diff = np.setdiff1d(y, np.arange(len(self._state))) 47 | if len(diff): 48 | raise ValueError("y contains previously unseen labels: %s" % str(diff)) 49 | y = np.asarray(y) 50 | return [self._state[i] for i in y] 51 | -------------------------------------------------------------------------------- /f2ai/persist_engine/online_local_persistengine.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | from ..definitions import ( 4 | Period, 5 | BackOffTime, 6 | OnlinePersistEngine, 7 | OnlinePersistEngineType, 8 | PersistFeatureView, 9 | ) 10 | 11 | from ..common.time_field import DEFAULT_EVENT_TIMESTAMP_FIELD 12 | 13 | 14 | class OnlineLocalPersistEngine(OnlinePersistEngine): 15 | type: OnlinePersistEngineType = OnlinePersistEngineType.LOCAL 16 | 17 | def materialize( 18 | self, 19 | prefix: str, 20 | feature_view: PersistFeatureView, 21 | back_off_time: BackOffTime, 22 | view_name: str, 23 | ): 24 | date_df = pd.DataFrame(data=[back_off_time.end], columns=[DEFAULT_EVENT_TIMESTAMP_FIELD]) 25 | period = -Period.from_str(str(back_off_time.end - back_off_time.start)) 26 | entities_in_range = self.offline_store.get_latest_entities( 27 | source=feature_view.source, 28 | group_keys=feature_view.join_keys, 29 | entity_df=date_df, 30 | start=back_off_time.start, 31 | ).drop(columns=DEFAULT_EVENT_TIMESTAMP_FIELD) 32 | 33 | data_to_write = self.offline_store.get_period_features( 34 | entity_df=pd.merge(entities_in_range, date_df, how="cross"), 35 | features=feature_view.features, 36 | source=feature_view.source, 37 | period=period, 38 | join_keys=feature_view.join_keys, 39 | ttl=feature_view.ttl, 40 | ) 41 | self.online_store.write_batch( 42 | feature_view.name, 43 | prefix, 44 | data_to_write, 45 | feature_view.ttl, 46 | join_keys=feature_view.join_keys, 47 | tz=back_off_time.tz, 48 | ) 49 | return view_name 50 | -------------------------------------------------------------------------------- /f2ai/definitions/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .entities import Entity 2 | from .features import Feature, FeatureSchema, SchemaAnchor 3 | from .period import Period 4 | from .base_view import BaseView 5 | from .feature_view import FeatureView 6 | from .label_view import LabelView 7 | from .services import Service 8 | from .sources import Source, FileSource, SqlSource, parse_source_yaml 9 | from .offline_store import OfflineStore, OfflineStoreType, init_offline_store_from_cfg 10 | from .online_store import OnlineStore, OnlineStoreType, init_online_store_from_cfg 11 | from .constants import LOCAL_TIMEZONE, StatsFunctions 12 | from .backoff_time import BackOffTime 13 | from .persist_engine import ( 14 | PersistFeatureView, 15 | PersistLabelView, 16 | PersistEngine, 17 | OfflinePersistEngine, 18 | OnlinePersistEngine, 19 | OfflinePersistEngineType, 20 | OnlinePersistEngineType, 21 | init_persist_engine_from_cfg, 22 | ) 23 | from .dtypes import FeatureDTypes 24 | 25 | __all__ = [ 26 | "Entity", 27 | "Feature", 28 | "FeatureSchema", 29 | "SchemaAnchor", 30 | "Period", 31 | "BaseView", 32 | "FeatureView", 33 | "LabelView", 34 | "Service", 35 | "Source", 36 | "FileSource", 37 | "SqlSource", 38 | "OfflineStoreType", 39 | "OfflineStore", 40 | "parse_source_yaml", 41 | "init_offline_store_from_cfg", 42 | "LOCAL_TIMEZONE", 43 | "StatsFunctions", 44 | "FeatureDTypes", 45 | "OnlineStoreType", 46 | "OnlineStore", 47 | "init_online_store_from_cfg", 48 | "BackOffTime", 49 | "PersistFeatureView", 50 | "PersistLabelView", 51 | "PersistEngine", 52 | "OnlinePersistEngine", 53 | "OfflinePersistEngine", 54 | "OfflinePersistEngineType", 55 | "OnlinePersistEngineType", 56 | "init_persist_engine_from_cfg", 57 | ] 58 | -------------------------------------------------------------------------------- /f2ai/definitions/backoff_time.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from datetime import datetime 3 | from dataclasses import dataclass 4 | from typing import Iterator, Union 5 | 6 | from f2ai.definitions.period import Period 7 | 8 | 9 | @dataclass 10 | class BackOffTime: 11 | """ 12 | Useful to define how to split data by time. 
13 | """ 14 | 15 | start: pd.Timestamp 16 | end: pd.Timestamp 17 | step: Period 18 | tz: str 19 | 20 | def __init__( 21 | self, 22 | start: Union[str, pd.Timestamp], 23 | end: Union[str, pd.Timestamp], 24 | step: Union[str, Period] = "1 day", 25 | tz: str = None, 26 | ) -> None: 27 | # if tz is needed, then pass as '2016-01-01 00+8' to indicate 8hours offset 28 | if isinstance(start, str): 29 | start = pd.Timestamp(start, tz=tz) 30 | if isinstance(end, str): 31 | end = pd.Timestamp(end, tz=tz) 32 | 33 | if isinstance(step, str): 34 | step = Period.from_str(step) 35 | 36 | self.start = start 37 | self.end = end 38 | self.step = step 39 | self.tz = tz 40 | 41 | def to_units(self) -> Iterator[Period]: 42 | pd_offset = self.step.to_pandas_dateoffset() 43 | start = self.step.normalize(self.start, "floor") 44 | end = self.step.normalize(self.end, "ceil") 45 | 46 | bins = pd.date_range( 47 | start=start, 48 | end=end, 49 | freq=pd_offset, 50 | ) 51 | for (start, end) in zip(bins[:-1], bins[1:]): 52 | yield BackOffTime(start=start, end=end, step=self.step, tz=self.tz) 53 | 54 | @classmethod 55 | def from_now(cls, from_now: str, step: str = None, tz: str = None): 56 | end = pd.Timestamp(datetime.now(), tz=tz) 57 | start = end - Period.from_str(from_now).to_py_timedelta() 58 | return BackOffTime(start=start, end=end, step=step, tz=tz) 59 | -------------------------------------------------------------------------------- /f2ai/models/earlystop.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os 4 | 5 | 6 | class EarlyStopping: 7 | """Early stops the training if validation loss doesn't improve after a given patience.""" 8 | 9 | def __init__(self, save_path=".", patience=7, verbose=False, delta=0, **kwargs): 10 | """ 11 | Args: 12 | patience (int): How long to wait after last time validation loss improved. 13 | 上次验证集损失值改善后等待几个epoch 14 | Default: 7 15 | verbose (bool): If True, prints a message for each validation loss improvement. 16 | 如果是True,为每个验证集损失值改善打印一条信息 17 | Default: False 18 | delta (float): Minimum change in the monitored quantity to qualify as an improvement. 19 | 监测数量的最小变化,以符合改进的要求 20 | Default: 0 21 | """ 22 | self.save_path = save_path 23 | self.patience = patience 24 | self.verbose = verbose 25 | self.counter = 0 26 | self.best_score = None 27 | self.early_stop = False 28 | self.val_loss_min = np.Inf 29 | self.delta = delta 30 | self.name = kwargs.get("cpnmae", "") 31 | 32 | def __call__(self, val_loss, model): 33 | 34 | score = -val_loss 35 | 36 | if self.best_score is None: 37 | self.best_score = score 38 | self.save_checkpoint(val_loss, model) 39 | elif score < self.best_score + self.delta: 40 | self.counter += 1 41 | print(f"EarlyStopping counter: {self.counter} out of {self.patience}") 42 | if self.counter >= self.patience: 43 | self.early_stop = True 44 | else: 45 | self.best_score = score 46 | self.save_checkpoint(val_loss, model) 47 | self.counter = 0 48 | 49 | def save_checkpoint(self, val_loss, model): 50 | """ 51 | Saves model when validation loss decrease. 52 | 验证损失减少时保存模型。 53 | """ 54 | if self.verbose: 55 | print( 56 | f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ..." 
57 | ) 58 | path_to_save = os.path.join(self.save_path, f"best_chekpnt_{self.name}.pk") 59 | torch.save(model, path_to_save) 60 | self.val_loss_min = val_loss 61 | -------------------------------------------------------------------------------- /f2ai/persist_engine/offline_file_persistengine.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import List 4 | import pandas as pd 5 | import datetime 6 | 7 | from ..offline_stores.offline_file_store import OfflineFileStore 8 | from ..definitions import ( 9 | FileSource, 10 | OfflinePersistEngine, 11 | OfflinePersistEngineType, 12 | BackOffTime, 13 | PersistFeatureView, 14 | PersistLabelView, 15 | ) 16 | from ..common.utils import write_df_to_dataset 17 | from ..common.time_field import TIME_COL, MATERIALIZE_TIME 18 | 19 | 20 | class OfflineFilePersistEngine(OfflinePersistEngine): 21 | type: OfflinePersistEngineType = OfflinePersistEngineType.FILE 22 | offline_store: OfflineFileStore 23 | 24 | def materialize( 25 | self, 26 | feature_views: List[PersistFeatureView], 27 | label_view: PersistLabelView, 28 | destination: FileSource, 29 | back_off_time: BackOffTime, 30 | service_name: str, 31 | ): 32 | # retrieve entity_df 33 | # TODO: 34 | # 1. Should this be abstracted more cleanly instead of calling a private method? 35 | # 2. Bounding the time range before reading the data could improve efficiency. 36 | entity_df = self.offline_store._read_file( 37 | source=label_view.source, features=label_view.labels, join_keys=label_view.join_keys 38 | ) 39 | 40 | entity_df = entity_df.drop(columns=["created_timestamp"], errors="ignore") 41 | entity_df = entity_df[ 42 | (entity_df[TIME_COL] >= back_off_time.start) & (entity_df[TIME_COL] < back_off_time.end) 43 | ] 44 | 45 | # join features recursively 46 | # TODO: this should be reimplemented to directly consume multiple feature_views, with a performance test.
47 | joined_frame = entity_df 48 | for feature_view in feature_views: 49 | joined_frame = self.offline_store.get_features( 50 | entity_df=joined_frame, 51 | features=feature_view.features, 52 | source=feature_view.source, 53 | join_keys=feature_view.join_keys, 54 | ttl=feature_view.ttl, 55 | include=True, 56 | how="right", 57 | ) 58 | tz = joined_frame[TIME_COL][0].tz if not joined_frame.empty else None 59 | joined_frame[MATERIALIZE_TIME] = pd.Timestamp(datetime.datetime.now(), tz=tz) 60 | write_df_to_dataset( 61 | joined_frame, destination.path, time_col=destination.timestamp_field, period=back_off_time.step 62 | ) 63 | return service_name 64 | -------------------------------------------------------------------------------- /tests/units/dataset/entities_sampler_test.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from f2ai.dataset import ( 3 | NoEntitiesSampler, 4 | EvenEventsSampler, 5 | FixedNEntitiesSampler, 6 | EvenEntitiesSampler, 7 | ) 8 | 9 | 10 | def test_even_entities_sampler(): 11 | event_time_sampler = EvenEventsSampler(start="2022-10-02", end="2022-10-04", period="1 day") 12 | group_df = pd.DataFrame( 13 | { 14 | "fruit": ["apple", "banana", "banana"], 15 | "sauce": ["tomato", "chili", "tomato"], 16 | } 17 | ) 18 | sampler = EvenEntitiesSampler(event_time_sampler, group_df) 19 | 20 | sampled_df = sampler() 21 | assert len(sampled_df) == 9 22 | assert list(sampled_df.columns) == ["fruit", "sauce", "event_timestamp"] 23 | 24 | next_sampled_item = next(iter(sampler)) 25 | assert all([key in {"fruit", "sauce", "event_timestamp"} for key in next_sampled_item.keys()]) 26 | 27 | 28 | def test_fixed_n_entities_sampler(): 29 | event_time_sampler = EvenEventsSampler(start="2022-10-02", end="2022-10-04", period="1 day") 30 | group_df = pd.DataFrame( 31 | { 32 | "fruit": ["apple", "banana", "banana"], 33 | "sauce": ["tomato", "chili", "tomato"], 34 | } 35 | ) 36 | sampler = FixedNEntitiesSampler(event_time_sampler, group_df, n=2) 37 | 38 | sampled_df = sampler() 39 | assert len(sampled_df) == 6 40 | assert list(sampled_df.columns) == ["fruit", "sauce", "event_timestamp"] 41 | 42 | next_sampled_item = next(iter(sampler)) 43 | assert all([key in {"fruit", "sauce", "event_timestamp"} for key in next_sampled_item.keys()]) 44 | 45 | 46 | def test_fixed_n_prob_entities_sampler(): 47 | event_time_sampler = EvenEventsSampler(start="2022-10-02", end="2022-10-06", period="1 day") 48 | group_df = pd.DataFrame( 49 | {"fruit": ["apple", "banana", "banana"], "sauce": ["tomato", "chili", "tomato"], "p": [0.2, 0.6, 0.2]} 50 | ) 51 | sampler = FixedNEntitiesSampler(event_time_sampler, group_df, n=3) 52 | 53 | sampled_df = sampler() 54 | assert len(sampled_df) == 9 55 | assert list(sampled_df.columns) == ["fruit", "sauce", "event_timestamp"] 56 | 57 | next_sampled_item = next(iter(sampler)) 58 | assert all([key in {"fruit", "sauce", "event_timestamp"} for key in next_sampled_item.keys()]) 59 | 60 | 61 | def test_sampler_properties(): 62 | event_time_sampler = EvenEventsSampler(start="2022-10-02", end="2022-10-06", period="1 day") 63 | sampler = NoEntitiesSampler(event_time_sampler) 64 | assert sampler.iterable 65 | assert not sampler.iterable_only 66 | -------------------------------------------------------------------------------- /f2ai/definitions/sources.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Any 2 | from pydantic import BaseModel, Field 3 | from 
enum import Enum 4 | 5 | from f2ai.common.utils import read_file 6 | from .offline_store import OfflineStoreType 7 | from .features import FeatureSchema 8 | 9 | 10 | class Source(BaseModel): 11 | """An abstract class which describe the common part of a Source. A source usually defines where to access data and what the time semantic it has. In F2AI, we have 2 kinds of time semantic: 12 | 13 | 1. timestamp_field: the event timestamp which represent when the record happened, which is the main part of point-in-time join. 14 | 2. created_timestamp_field: the created timestamp which represent when the record created, which usually happened in multi cycles of feature generation scenario. 15 | """ 16 | 17 | name: str 18 | description: Optional[str] 19 | timestamp_field: Optional[str] 20 | created_timestamp_field: Optional[str] = Field(alias="created_timestamp_column") 21 | tags: Dict[str, str] = {} 22 | 23 | 24 | class FileFormatEnum(str, Enum): 25 | PARQUET = "parquet" 26 | TSV = "tsv" 27 | CSV = "csv" 28 | TEXT = "text" 29 | 30 | 31 | class FileSource(Source): 32 | file_format: FileFormatEnum = FileFormatEnum.CSV 33 | path: str 34 | 35 | def read_file(self, str_cols: List[str] = [], keep_cols: List[str] = []): 36 | time_columns = [] 37 | if self.timestamp_field: 38 | time_columns.append(self.timestamp_field) 39 | if self.created_timestamp_field: 40 | time_columns.append(self.created_timestamp_field) 41 | 42 | return read_file( 43 | self.path, 44 | file_format=self.file_format, 45 | parse_dates=time_columns, 46 | str_cols=str_cols, 47 | keep_cols=keep_cols, 48 | ) 49 | 50 | 51 | class RequestSource(Source): 52 | schemas: List[FeatureSchema] = Field(alias="schema") 53 | 54 | 55 | class SqlSource(Source): 56 | query: str 57 | 58 | def __init__(__pydantic_self__, **data: Any) -> None: 59 | 60 | query = data.pop("query", "") 61 | if query == "": 62 | query = data.get("name") 63 | 64 | super().__init__(**data, query=query) 65 | 66 | 67 | def parse_source_yaml(o: Dict, offline_store_type: OfflineStoreType) -> Source: 68 | if o.get("type", None) == "request_source": 69 | return RequestSource(**o) 70 | 71 | if offline_store_type == OfflineStoreType.FILE: 72 | return FileSource(**o) 73 | elif offline_store_type == OfflineStoreType.PGSQL: 74 | return SqlSource(**o) 75 | elif offline_store_type == OfflineStoreType.SPARK: 76 | raise Exception("spark is not supported yet!") 77 | else: 78 | return Source(**o) 79 | -------------------------------------------------------------------------------- /f2ai/definitions/online_store.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import abc 4 | from enum import Enum 5 | from typing import TYPE_CHECKING, Any, Dict, Optional 6 | 7 | import pandas as pd 8 | from pydantic import BaseModel 9 | 10 | if TYPE_CHECKING: 11 | from .feature_view import FeatureView 12 | from .period import Period 13 | from .sources import Source 14 | 15 | 16 | class OnlineStoreType(str, Enum): 17 | """A constant numerate choices which is used to indicate how to initialize OnlineStore from configuration. If you want to add a new type of online store, you definitely want to modify this.""" 18 | 19 | REDIS = "redis" 20 | 21 | 22 | class OnlineStore(BaseModel): 23 | """An abstraction of what functionalities a OnlineStore should implements. If you want to be one of the online store contributor. 
This is the core.""" 24 | 25 | type: OnlineStoreType 26 | name: str 27 | 28 | class Config: 29 | extra = "allow" 30 | 31 | @abc.abstractmethod 32 | def write_batch( 33 | self, featrue_view: FeatureView, project_name: str, dt: pd.DataFrame, ttl: Optional[Period] 34 | ) -> Source: 35 | """materialize data on redis 36 | 37 | Args: 38 | service (Service): an instance of Service 39 | 40 | Returns: 41 | Source 42 | """ 43 | pass 44 | 45 | @abc.abstractmethod 46 | def read_batch( 47 | self, 48 | hkey: str, 49 | ttl: Optional[Period] = None, 50 | period: Optional[Period] = None, 51 | **kwargs, 52 | ) -> pd.DataFrame: 53 | """get data from current online store. 54 | 55 | Args: 56 | entity_df (pd.DataFrame): A query DataFrame which include entities and event_timestamp column. 57 | hkey: hash key. 58 | ttl (Optional[Period], optional): Time to Live, if feature's event_timestamp exceeds the ttl, it will be dropped. Defaults to None. 59 | 60 | Returns: 61 | pd.DataFrame 62 | """ 63 | pass 64 | 65 | 66 | def init_online_store_from_cfg(cfg: Dict[Any], name: str) -> OnlineStore: 67 | """Initialize an implementation of OnlineStore from yaml config. 68 | 69 | Args: 70 | cfg (Dict[Any]): a parsed config object. 71 | 72 | Returns: 73 | OnlineStore: Different types of OnlineStore. 74 | """ 75 | online_store_type = OnlineStoreType(cfg["type"]) 76 | 77 | if online_store_type == OnlineStoreType.REDIS: 78 | from ..online_stores.online_redis_store import OnlineRedisStore 79 | 80 | redis_conf = cfg.pop("redis_conf", {}) 81 | redis_conf.update({"name": name}) 82 | return OnlineRedisStore(**cfg, **redis_conf) 83 | 84 | raise TypeError(f"offline store type must be one of [{','.join(e.value for e in OnlineStore)}]") 85 | -------------------------------------------------------------------------------- /f2ai/common/get_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | from typing import List, Dict 4 | 5 | from ..definitions import ( 6 | OfflineStoreType, 7 | Entity, 8 | FeatureView, 9 | LabelView, 10 | Service, 11 | Source, 12 | parse_source_yaml, 13 | ) 14 | from .read_file import read_yml 15 | from .utils import remove_prefix 16 | 17 | 18 | def listdir_with_extensions(path: str, extensions: List[str] = []) -> List[str]: 19 | path = remove_prefix(path, "file://") 20 | if os.path.isdir(path): 21 | files = [] 22 | for extension in extensions: 23 | files.extend(glob.glob(f"{path}/*.{extension}")) 24 | return files 25 | return [] 26 | 27 | 28 | def listdir_yamls(path: str) -> List[str]: 29 | return listdir_with_extensions(path, extensions=["yml", "yaml"]) 30 | 31 | 32 | def get_service_cfg(url: str) -> Dict[str, Service]: 33 | """get forecast config like length of look_back and look_forward, features and labels 34 | 35 | Args: 36 | url (str): url of .yml 37 | """ 38 | service_cfg = {} 39 | for filepath in listdir_yamls(url): 40 | service = Service.from_yaml(read_yml(filepath)) 41 | service_cfg[service.name] = service 42 | return service_cfg 43 | 44 | 45 | def get_entity_cfg(url: str) -> Dict[str, Entity]: 46 | """get entity config for join 47 | 48 | Args: 49 | url (str): url of .yml 50 | """ 51 | entities = {} 52 | for filepath in listdir_yamls(url): 53 | entity = Entity(**read_yml(filepath)) 54 | entities[entity.name] = entity 55 | return entities 56 | 57 | 58 | def get_feature_views(url: str) -> Dict[str, FeatureView]: 59 | """get Dict(FeatureViews) from /feature_views/*.yml 60 | 61 | Args: 62 | url (str): rl of .yml 63 | """ 64 | feature_views = 
{}
65 |     for filepath in listdir_yamls(url):
66 |         feature_view = FeatureView(**read_yml(filepath))
67 |         feature_views[feature_view.name] = feature_view
68 |     return feature_views
69 |
70 |
71 | def get_label_views(url: str) -> Dict[str, LabelView]:
72 |     """get Dict(LabelViews) from /label_views/*.yml
73 |
74 |     Args:
75 |         url (str): url of .yml
76 |     """
77 |     label_views = {}
78 |     for filepath in listdir_yamls(url):
79 |         label_view = LabelView(**read_yml(filepath))
80 |         label_views[label_view.name] = label_view
81 |     return label_views
82 |
83 |
84 | def get_source_cfg(url: str, offline_store_type: OfflineStoreType) -> Dict[str, Source]:
85 |     """get Dict(Source) from /sources/*.yml
86 |
87 |     Args:
88 |         url (str): url of .yml
89 |     """
90 |
91 |     source_dict = {}
92 |     for filepath in listdir_yamls(url):
93 |         cfg = read_yml(filepath)
94 |         source = parse_source_yaml(cfg, offline_store_type)
95 |         source_dict.update({source.name: source})
96 |     return source_dict
97 |
--------------------------------------------------------------------------------
/use_cases/guizhou_traffic_arima.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import pmdarima as pm
3 | import matplotlib.pyplot as plt
4 | import seaborn as sns
5 | from pmdarima import pipeline
6 | from pmdarima.preprocessing import DateFeaturizer, FourierFeaturizer
7 |
8 | from f2ai.dataset import TorchIterableDataset
9 | from f2ai.featurestore import FeatureStore
10 |
11 |
12 | def fill_missing(data: pd.DataFrame, time_col="event_timestamp", query_col="query_timestamp"):
13 |     min_of_time = data[time_col].min()
14 |     max_of_time = data[time_col].max()
15 |     query_time = data.iloc[0][query_col]
16 |
17 |     data = data.set_index(time_col, drop=True)
18 |     if query_time < min_of_time:
19 |         date_index = pd.date_range(query_time, max_of_time, freq="2T", name=time_col)[1:]
20 |     else:
21 |         date_index = pd.date_range(min_of_time, query_time, freq="2T", name=time_col)
22 |
23 |     df = data.reindex(date_index, method="nearest").resample("12T").sum()
24 |
25 |     return df.reset_index()
26 |
27 |
28 | #!
f2ai materialize travel_time_prediction_arima_v1 --start=2016-03-01 --end=2016-04-01 --step='1 day' 29 | if __name__ == "__main__": 30 | ex_vars = ["event_timestamp"] 31 | fs = FeatureStore("/Users/liyu/.f2ai/f2ai-guizhou_traffic_file") 32 | 33 | # 选择4条路的4个时间段分别进行预测 34 | entity_df = pd.DataFrame( 35 | { 36 | "link_id": [ 37 | "4377906286525800514", 38 | "4377906285681600514", 39 | "3377906281774510514", 40 | "4377906280784800514", 41 | ], 42 | "event_timestamp": [ 43 | pd.Timestamp("2016-03-31 07:00:00"), 44 | pd.Timestamp("2016-03-31 09:00:00"), 45 | pd.Timestamp("2016-03-31 12:00:00"), 46 | pd.Timestamp("2016-03-31 18:00:00"), 47 | ], 48 | } 49 | ) 50 | dataset = TorchIterableDataset(fs, "travel_time_prediction_arima_v1", entity_df) 51 | 52 | fig = plt.figure(figsize=(16, 16)) 53 | axes = fig.subplots(2, 2) 54 | for i, (look_back, look_forward) in enumerate(dataset): 55 | look_back = fill_missing(look_back) 56 | look_forward = fill_missing(look_forward) 57 | 58 | pipe = pipeline.Pipeline( 59 | [ 60 | ("date", DateFeaturizer(column_name="event_timestamp", with_day_of_month=False)), 61 | ("fourier", FourierFeaturizer(m=24 * 5, k=4)), 62 | ("arima", pm.arima.ARIMA(order=(6, 0, 1))), 63 | ] 64 | ) 65 | pipe.fit(look_back["travel_time"], X=look_back[ex_vars]) 66 | look_forward["y_pred"] = pipe.predict(len(look_forward), X=look_forward[ex_vars]) 67 | melted_df = pd.melt(look_forward, id_vars=["event_timestamp"], value_vars=["travel_time", "y_pred"]) 68 | sns.lineplot(melted_df, x="event_timestamp", y="value", hue="variable", ax=axes[i // 2, i % 2]) 69 | 70 | fig.savefig("f2ai_guizhou_traffic_arima", bbox_inches="tight") 71 | -------------------------------------------------------------------------------- /tests/units/offline_stores/offline_postgres_store_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | import datetime 4 | import pandas as pd 5 | from pypika import Query 6 | from unittest.mock import MagicMock 7 | 8 | from f2ai.offline_stores.offline_postgres_store import build_stats_query, OfflinePostgresStore 9 | from f2ai.definitions import Feature, FeatureDTypes, StatsFunctions, SqlSource 10 | 11 | feature_city = Feature(name="city", dtype=FeatureDTypes.STRING, view_name="zipcode_table") 12 | feature_population = Feature(name="population", dtype=FeatureDTypes.INT, view_name="zipcode_table") 13 | 14 | FILE_DIR = os.path.dirname(__file__) 15 | 16 | 17 | def read_sql_str(*args) -> str: 18 | with open(os.path.join(FILE_DIR, "postgres_sqls", f"{'_'.join(args)}.sql"), "r") as f: 19 | return f.read() 20 | 21 | 22 | @pytest.mark.parametrize("fn", [(fn) for fn in StatsFunctions if fn != StatsFunctions.UNIQUE]) 23 | def test_build_stats_query_with_group_by_numeric(fn: StatsFunctions): 24 | q = Query.from_("zipcode_table") 25 | sql = build_stats_query( 26 | q, 27 | features=[feature_population], 28 | stats_fn=fn, 29 | group_keys=["zipcode"], 30 | ) 31 | expected_sql = read_sql_str("stats_query", fn.value) 32 | assert sql.get_sql() == expected_sql 33 | 34 | 35 | @pytest.mark.parametrize("fn", [StatsFunctions.UNIQUE]) 36 | def test_build_stats_query_with_group_by_categorical(fn: StatsFunctions): 37 | q = Query.from_("zipcode_table") 38 | sql = build_stats_query( 39 | q, 40 | stats_fn=fn, 41 | group_keys=["zipcode"], 42 | ) 43 | expected_sql = read_sql_str("stats_query", fn.value) 44 | assert sql.get_sql() == expected_sql 45 | 46 | 47 | def test_stats_numeric(): 48 | source = SqlSource(name="foo", query="zipcode_table", 
timestamp_field="event_timestamp") 49 | store = OfflinePostgresStore( 50 | host="localhost", 51 | user="foo", 52 | password="bar", 53 | ) 54 | 55 | mock = MagicMock() 56 | store._get_dataframe = mock 57 | 58 | store.stats( 59 | source=source, 60 | features=[feature_population], 61 | fn=StatsFunctions.AVG, 62 | group_keys=["zipcode"], 63 | start=datetime.datetime(year=2017, month=1, day=1), 64 | end=datetime.datetime(year=2018, month=1, day=1), 65 | ) 66 | sql, columns = mock.call_args[0] # the first call 67 | assert ",".join(columns) == "zipcode,population" 68 | assert sql.get_sql() == read_sql_str("store_stats_query", "numeric") 69 | 70 | 71 | def test_stats_unique(): 72 | source = SqlSource(name="foo", query="zipcode_table", timestamp_field="event_timestamp") 73 | store = OfflinePostgresStore(host="localhost", user="foo", password="bar") 74 | 75 | mock = MagicMock(return_value=pd.DataFrame({"zipcode": ["A"]})) 76 | store._get_dataframe = mock 77 | 78 | store.stats( 79 | source=source, 80 | features=[feature_city], 81 | fn=StatsFunctions.UNIQUE, 82 | group_keys=["zipcode"], 83 | start=datetime.datetime(year=2017, month=1, day=1), 84 | end=datetime.datetime(year=2018, month=1, day=1), 85 | ) 86 | sql, columns = mock.call_args[0] # the first call 87 | 88 | assert ",".join(columns) == "zipcode" 89 | assert sql.get_sql() == read_sql_str("store_stats_query", "categorical") 90 | -------------------------------------------------------------------------------- /f2ai/common/oss_utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os 3 | from typing import Tuple, Mapping, Callable, Optional, IO, Any 4 | import oss2 5 | import stat 6 | import tempfile 7 | import zipfile 8 | from ignite.handlers import DiskSaver 9 | from .utils import remove_prefix 10 | 11 | 12 | @functools.lru_cache(maxsize=64) 13 | def get_bucket(bucket, endpoint=None): 14 | key_id = os.environ.get("OSS_ACCESS_KEY_ID") 15 | key_secret = os.environ.get("OSS_ACCESS_KEY_SECRET") 16 | endpoint = endpoint or os.environ.get("OSS_ENDPOINT") 17 | 18 | return oss2.Bucket(oss2.Auth(key_id, key_secret), endpoint, bucket) 19 | 20 | 21 | def parse_oss_url(url: str) -> Tuple[str, str, str]: 22 | """ 23 | url format: oss://{bucket}/{key} 24 | """ 25 | url = remove_prefix(url, "oss://") 26 | components = url.split("/") 27 | return components[0], "/".join(components[1:]) 28 | 29 | 30 | def get_bucket_from_oss_url(url: str): 31 | bucket_name, key = parse_oss_url(url) 32 | return get_bucket(bucket_name), key 33 | 34 | 35 | @functools.lru_cache(maxsize=1) 36 | def get_pandas_storage_options(): 37 | key_id = os.environ.get("OSS_ACCESS_KEY_ID") 38 | key_secret = os.environ.get("OSS_ACCESS_KEY_SECRET") 39 | endpoint = os.environ.get("OSS_ENDPOINT") 40 | 41 | if not endpoint.startswith("https://"): 42 | endpoint = f"https://{endpoint}" 43 | 44 | return { 45 | "key": key_id, 46 | "secret": key_secret, 47 | "client_kwargs": { 48 | "endpoint_url": endpoint, 49 | }, 50 | # "config_kwargs": {"s3": {"addressing_style": "virtual"}}, 51 | } 52 | 53 | 54 | class DiskAndOssSaverAdd(DiskSaver): 55 | def __init__( 56 | self, 57 | dirname: str, 58 | ossaddress: str = None, 59 | create_dir: bool = True, 60 | require_empty: bool = True, 61 | **kwargs: Any, 62 | ): 63 | super().__init__( 64 | dirname=dirname, atomic=True, create_dir=create_dir, require_empty=require_empty, **kwargs 65 | ) 66 | self.ossaddress = ossaddress 67 | 68 | def _save_func(self, checkpoint: Mapping, path: str, func: Callable, 
rank: int = 0) -> None: 69 | tmp: Optional[IO[bytes]] = None 70 | if rank == 0: 71 | tmp = tempfile.NamedTemporaryFile(delete=False, dir=self.dirname) 72 | 73 | try: 74 | func(checkpoint, tmp.file, **self.kwargs) 75 | except BaseException: 76 | if tmp is not None: 77 | tmp.close() 78 | os.remove(tmp.name) 79 | raise 80 | 81 | if tmp is not None: 82 | tmp.close() 83 | os.replace(tmp.name, path) 84 | # append group/others read mode 85 | os.chmod(path, os.stat(path).st_mode | stat.S_IRGRP | stat.S_IROTH) 86 | if self.ossaddress: 87 | bucket, key = get_bucket_from_oss_url(self.ossaddress) 88 | state_file = f"{path.rsplit('/',maxsplit=1)[0]}{os.sep}state.json" 89 | file_path = f"{path.rsplit('/',maxsplit=1)[0]}{os.sep}model.zip" 90 | with zipfile.ZipFile(file_path, "w", compression=zipfile.ZIP_BZIP2) as archive: 91 | archive.write(path, os.path.basename(path)) 92 | archive.write(state_file, os.path.basename(state_file)) 93 | bucket.put_object_from_file(key, file_path) 94 | os.remove(file_path) 95 | -------------------------------------------------------------------------------- /f2ai/dataset/events_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import abc 4 | from typing import Union, Any, Dict 5 | 6 | from ..definitions import Period 7 | 8 | 9 | class EventsSampler: 10 | """ 11 | This sampler is used to sample event_timestamp column. Customization could be done by inheriting this class. 12 | """ 13 | 14 | @abc.abstractmethod 15 | def __call__(self, *args: Any, **kwds: Any) -> pd.DatetimeIndex: 16 | """ 17 | Get all sample results 18 | 19 | Returns: 20 | pd.DatetimeIndex: An datetime index 21 | """ 22 | pass 23 | 24 | @abc.abstractmethod 25 | def __iter__(self, *args: Any, **kwds: Any) -> pd.Timestamp: 26 | """ 27 | Iterable way to get sample result. 28 | 29 | Returns: 30 | pd.Timestamp 31 | """ 32 | pass 33 | 34 | 35 | class EvenEventsSampler(EventsSampler): 36 | """ 37 | A sampler which using time to generate query entity dataframe. This sampler generally useful when you don't have entity keys. 38 | """ 39 | 40 | def __init__(self, start: str, end: str, period: Union[str, Period], **kwargs): 41 | """ 42 | evenly sample from a range of time, with given period. 43 | 44 | Args: 45 | start (str): start datetime 46 | end (str): end datetime 47 | period (str): a period string, egg: '1 day'. 48 | **kwargs (Any): additional arguments passed to pd.date_range. 49 | """ 50 | if isinstance(start, str): 51 | start = pd.to_datetime(start) 52 | if isinstance(end, str): 53 | end = pd.to_datetime(end) 54 | 55 | self._start = start 56 | self._end = end 57 | self._period = Period.from_str(period) 58 | self._kwargs = kwargs 59 | 60 | def __call__(self) -> pd.DatetimeIndex: 61 | return self._get_date_range() 62 | 63 | def __iter__(self) -> Dict: 64 | datetime_indexes = self._get_date_range() 65 | for i in datetime_indexes: 66 | yield i 67 | 68 | def _get_date_range(self) -> pd.DatetimeIndex: 69 | return pd.date_range(self._start, self._end, freq=self._period.to_pandas_freq_str(), **self._kwargs) 70 | 71 | 72 | class RandomNEventsSampler(EventsSampler): 73 | """ 74 | Randomly sample a fixed number of event timestamp in a given time range. 75 | """ 76 | 77 | def __init__( 78 | self, 79 | start: str, 80 | end: str, 81 | period: Union[str, Period], 82 | n: int = 1, 83 | random_state: int = None, 84 | ): 85 | """ 86 | randomly sample fixed number of event timestamp. 
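
        Example (illustrative): calling the sampler built by
            RandomNEventsSampler(start="2022-10-02", end="2022-10-06", period="1 day", n=3)()
        returns a pd.DatetimeIndex of 3 distinct daily timestamps drawn from that range, in ascending order.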
87 | 88 | Args: 89 | start (str): start datetime 90 | end (str): end datetime 91 | period (str): a period string, egg: '1 day'. 92 | """ 93 | if isinstance(start, str): 94 | start = pd.to_datetime(start) 95 | if isinstance(end, str): 96 | end = pd.to_datetime(end) 97 | 98 | self._start = start 99 | self._end = end 100 | self._period = Period.from_str(period) 101 | 102 | self._n = n 103 | self._rng = np.random.default_rng(random_state) 104 | 105 | def __call__(self) -> pd.DatetimeIndex: 106 | return self._get_date_range() 107 | 108 | def __iter__(self) -> Dict: 109 | datetime_indexes = self._get_date_range() 110 | for i in datetime_indexes: 111 | yield i 112 | 113 | def _get_date_range(self) -> pd.DatetimeIndex: 114 | datetimes = pd.date_range(self._start, self._end, freq=self._period.to_pandas_freq_str()) 115 | indices = sorted(self._rng.choice(range(len(datetimes)), size=self._n, replace=False)) 116 | return pd.DatetimeIndex([datetimes[i] for i in indices]) 117 | -------------------------------------------------------------------------------- /f2ai/dataset/pytorch_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import pandas as pd 3 | from typing import TYPE_CHECKING, Optional, Union 4 | from torch.utils.data import IterableDataset 5 | 6 | from ..common.utils import batched, is_iterable 7 | from ..common.time_field import DEFAULT_EVENT_TIMESTAMP_FIELD, QUERY_COL 8 | from ..definitions import Period 9 | from .entities_sampler import EntitiesSampler 10 | 11 | 12 | if TYPE_CHECKING: 13 | from ..definitions.services import Service 14 | from ..featurestore import FeatureStore 15 | 16 | 17 | def iter_rows(df: pd.DataFrame): 18 | for _, row in df.iterrows(): 19 | yield row 20 | 21 | 22 | class TorchIterableDataset(IterableDataset): 23 | """A pytorch portaled dataset.""" 24 | 25 | def __init__( 26 | self, 27 | feature_store: "FeatureStore", 28 | service: Union[Service, str], 29 | sampler: Union[EntitiesSampler, pd.DataFrame], 30 | chunk_size: int = 64, 31 | ): 32 | assert is_iterable( 33 | sampler 34 | ), "TorchIterableDataset only support iterable sampler or iterable DataFrame" 35 | 36 | self._feature_store = feature_store 37 | self._service = service 38 | self._sampler = sampler 39 | self._chunk_size = chunk_size 40 | 41 | if isinstance(service, str): 42 | self._service = feature_store.services.get(service, None) 43 | assert service is not None, f"Service {service} is not found in feature store." 
44 | else: 45 | self._service = service 46 | 47 | def get_feature_period(self) -> Optional[Period]: 48 | features = self._service.get_feature_objects(self._feature_store.feature_views) 49 | periods = [Period.from_str(x.period) for x in features if x.period is not None] 50 | return max(periods) if len(periods) > 0 else None 51 | 52 | def get_label_period(self) -> Optional[Period]: 53 | labels = self._service.get_label_objects(self._feature_store.label_views) 54 | periods = [Period.from_str(x.period) for x in labels if x.period is not None] 55 | return max(periods) if len(periods) > 0 else None 56 | 57 | def __iter__(self): 58 | feature_period = self.get_feature_period() 59 | label_period = self.get_label_period() 60 | join_keys = self._service.get_join_keys( 61 | self._feature_store.feature_views, 62 | self._feature_store.label_views, 63 | self._feature_store.entities, 64 | ) 65 | 66 | if isinstance(self._sampler, pd.DataFrame): 67 | iterator = iter_rows(self._sampler) 68 | else: 69 | iterator = self._sampler 70 | 71 | for x in batched(iterator, batch_size=self._chunk_size): 72 | entity_df = pd.DataFrame(x) 73 | labels_df = None 74 | 75 | if feature_period is None: 76 | features_df = self._feature_store.get_features(self._service, entity_df) 77 | labels = self._service.get_label_names(self._feature_store.label_views) 78 | label_columns = join_keys + labels + [DEFAULT_EVENT_TIMESTAMP_FIELD] 79 | labels_df = features_df[label_columns] 80 | features_df = features_df.drop(columns=labels) 81 | else: 82 | if not feature_period.is_neg: 83 | feature_period = -feature_period 84 | features_df = self._feature_store.get_period_features( 85 | self._service, entity_df, feature_period 86 | ) 87 | 88 | # get corresponding labels if not present in features_df 89 | if labels_df is None: 90 | labels_df = self._feature_store.get_period_labels(self._service, entity_df, label_period) 91 | 92 | if feature_period: 93 | group_columns = join_keys + [QUERY_COL] 94 | else: 95 | group_columns = join_keys + [DEFAULT_EVENT_TIMESTAMP_FIELD] 96 | 97 | labels_group = labels_df.groupby(group_columns) 98 | for name, x_features in features_df.groupby(group_columns): 99 | y_labels = labels_group.get_group(name) 100 | # TODO: test this with tabular dataset. 
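                # Each yielded item pairs the feature rows (x_features) with the label rows that
                # share the same group key (the join keys plus the query/event timestamp), i.e. one
                # sample per sampled entity and event time.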
101 | yield (x_features, y_labels) 102 | -------------------------------------------------------------------------------- /use_cases/.ipynb_checkpoints/credit_score-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "abc14dec", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from f2ai.common.collecy_fn import classify_collet_fn\n", 11 | "from f2ai.models.sequential import SimpleClassify\n", 12 | "import torch\n", 13 | "from torch import nn\n", 14 | "from torch.utils.data import DataLoader\n", 15 | "from f2ai.featurestore import FeatureStore\n", 16 | "from f2ai.dataset import GroupFixednbrSampler\n", 17 | "\n", 18 | "\n", 19 | "\n", 20 | "if __name__ == \"__main__\":\n", 21 | " fs = FeatureStore(\"file:///Users/xuyizhou/Desktop/xyz_warehouse/gitlab/f2ai-credit-scoring\")\n", 22 | "\n", 23 | " ds = fs.get_dataset(\n", 24 | " service=\"credit_scoring_v1\",\n", 25 | " sampler=GroupFixednbrSampler(\n", 26 | " time_bucket=\"10 days\",\n", 27 | " stride=1,\n", 28 | " group_ids=None,\n", 29 | " group_names=None,\n", 30 | " start=\"2020-08-01\",\n", 31 | " end=\"2021-09-30\",\n", 32 | " ),\n", 33 | " )\n", 34 | " features_cat = [ # catgorical features\n", 35 | " fea\n", 36 | " for fea in fs._get_feature_to_use(fs.services[\"credit_scoring_v1\"])\n", 37 | " if fea not in fs._get_feature_to_use(fs.services[\"credit_scoring_v1\"], True)\n", 38 | " ]\n", 39 | " cat_unique = fs.stats(\n", 40 | " fs.services[\"credit_scoring_v1\"],\n", 41 | " fn=\"unique\",\n", 42 | " group_key=[],\n", 43 | " start=\"2020-08-01\",\n", 44 | " end=\"2021-09-30\",\n", 45 | " features=features_cat,\n", 46 | " ).to_dict()\n", 47 | " cat_count = {key: len(cat_unique[key]) for key in cat_unique.keys()}\n", 48 | " cont_scalar_max = fs.stats(\n", 49 | " fs.services[\"credit_scoring_v1\"], fn=\"max\", group_key=[], start=\"2020-08-01\", end=\"2021-09-30\"\n", 50 | " ).to_dict()\n", 51 | " cont_scalar_min = fs.stats(\n", 52 | " fs.services[\"credit_scoring_v1\"], fn=\"min\", group_key=[], start=\"2020-08-01\", end=\"2021-09-30\"\n", 53 | " ).to_dict()\n", 54 | " cont_scalar = {key: [cont_scalar_min[key], cont_scalar_max[key]] for key in cont_scalar_min.keys()}\n", 55 | "\n", 56 | " i_ds = ds.to_pytorch()\n", 57 | " test_data_loader = DataLoader( # `batch_siz`e and `drop_last`` do not matter now, `sampler`` set it to be None cause `test_data`` is a Iterator\n", 58 | " i_ds,\n", 59 | " collate_fn=lambda x: classify_collet_fn(\n", 60 | " x,\n", 61 | " cat_coder=cat_unique,\n", 62 | " cont_scalar=cont_scalar,\n", 63 | " label=fs._get_available_labels(fs.services[\"credit_scoring_v1\"]),\n", 64 | " ),\n", 65 | " batch_size=4,\n", 66 | " drop_last=False,\n", 67 | " sampler=None,\n", 68 | " )\n", 69 | "\n", 70 | " model = SimpleClassify(\n", 71 | " cont_nbr=len(cont_scalar_max), cat_nbr=len(cat_count), emd_dim=4, max_types=max(cat_count.values())\n", 72 | " )\n", 73 | " optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # no need to change\n", 74 | " loss_fn = nn.BCELoss() # loss function to train a classification model\n", 75 | "\n", 76 | " for epoch in range(10): # assume 10 epoch\n", 77 | " print(f\"epoch: {epoch} begin\")\n", 78 | " for x, y in test_data_loader:\n", 79 | " pred_label = model(x)\n", 80 | " true_label = y\n", 81 | " loss = loss_fn(pred_label, true_label)\n", 82 | " optimizer.zero_grad()\n", 83 | " loss.backward()\n", 84 | " optimizer.step()\n", 85 | " 
print(f\"epoch: {epoch} done, loss: {loss}\")\n" 86 | ] 87 | } 88 | ], 89 | "metadata": { 90 | "kernelspec": { 91 | "display_name": "autonn-3.8.7", 92 | "language": "python", 93 | "name": "python3" 94 | }, 95 | "language_info": { 96 | "name": "python", 97 | "version": "3.8.7 (default, Dec 29 2021, 10:58:29) \n[Clang 13.0.0 (clang-1300.0.29.3)]" 98 | }, 99 | "vscode": { 100 | "interpreter": { 101 | "hash": "5840e4ed671345474330e8fce6ab52c58896a3935f0e728b8dbef1ddfad82808" 102 | } 103 | } 104 | }, 105 | "nbformat": 4, 106 | "nbformat_minor": 5 107 | } 108 | -------------------------------------------------------------------------------- /f2ai/common/collect_fn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | 4 | from ..models.encoder import LabelEncoder 5 | from ..models.normalizer import MinMaxNormalizer 6 | 7 | 8 | def classify_collet_fn(datas, cont_scalar={}, cat_coder={}, label=[]): 9 | """_summary_ 10 | 11 | Args: 12 | datas (_type_): datas to be processed, the length equals to batch_size 13 | cont_scalar (dict, optional): normalize continous features, the key is feature name and value is scales corresponding to method using. Defaults to {}. 14 | cat_coder (dict, optional): encode categorical features, the key is feature name and value is scales corresponding to method using. Defaults to {}. 15 | label (list, optional): column name of label. Defaults to []. 16 | 17 | Returns: 18 | tuple: the first element is features and the second is label 19 | """ 20 | batches = [] 21 | # corresspondint to __get_item__ in Dataset 22 | for data in datas: # data[0]:features, data[1]:labels 23 | cat_features = torch.stack( 24 | [ 25 | torch.tensor( 26 | LabelEncoder(cat).fit_self(pd.Series(cat_coder[cat])).transform_self(data[0][cat]), 27 | dtype=torch.int, 28 | ) 29 | for cat in cat_coder.keys() 30 | ], 31 | dim=-1, 32 | ) 33 | cont_features = torch.stack( 34 | [ 35 | torch.tensor( 36 | MinMaxNormalizer(cont) 37 | .fit_self(pd.Series(cont_scalar[cont])) 38 | .transform_self(data[0][cont]) 39 | .values, 40 | dtype=torch.float16, 41 | ) 42 | for cont in cont_scalar.keys() 43 | ], 44 | dim=-1, 45 | ) 46 | labels = torch.stack([torch.tensor(data[1][lab].values, dtype=torch.float) for lab in label], dim=-1) 47 | batch = (dict(categorical_features=cat_features, continous_features=cont_features), labels) 48 | batches.append(batch) 49 | 50 | # corresspondint to _collect_fn_ in Dataset 51 | categorical_features = torch.stack([batch[0]["categorical_features"] for batch in batches]) 52 | continous_features = torch.stack([batch[0]["continous_features"] for batch in batches]) 53 | labels = torch.stack([batch[1] for batch in batches]) 54 | return ( 55 | dict( 56 | categorical_features=categorical_features, 57 | continous_features=continous_features, 58 | ), 59 | labels, 60 | ) 61 | 62 | 63 | def nbeats_collet_fn( 64 | datas, 65 | cont_scalar={}, 66 | categoricals={}, 67 | label={}, 68 | ): 69 | 70 | batches = [] 71 | for data in datas: 72 | 73 | all_cont = torch.stack( 74 | [ 75 | torch.tensor( 76 | MinMaxNormalizer(cont) 77 | .fit_self(pd.Series(cont_scalar[cont])) 78 | .transform_self(data[0][cont]), 79 | dtype=torch.float16, 80 | ) 81 | for cont in cont_scalar.keys() 82 | ], 83 | dim=-1, 84 | ) 85 | 86 | all_categoricals = torch.stack( 87 | [ 88 | torch.tensor( 89 | LabelEncoder(cat).fit_self(pd.Series(categoricals[cat])).transform_self(data[0][cat]), 90 | dtype=torch.int, 91 | ) 92 | for cat in categoricals.keys() 93 | ], 94 | 
dim=-1, 95 | ) 96 | 97 | targets = torch.stack( 98 | [torch.tensor(data[1][lab], dtype=torch.float) for lab in label], 99 | dim=-1, 100 | ) 101 | 102 | batch = ( 103 | dict( 104 | encoder_cont=all_cont, 105 | decoder_cont=all_cont, 106 | categoricals=all_categoricals, 107 | ), 108 | targets, 109 | ) 110 | batches.append(batch) 111 | 112 | encoder_cont = torch.stack([batch[0]["encoder_cont"] for batch in batches]) 113 | decoder_cont = torch.stack([batch[0]["decoder_cont"] for batch in batches]) 114 | categoricals = torch.stack([batch[0]["categoricals"] for batch in batches]) 115 | targets = torch.stack([batch[1] for batch in batches]) 116 | 117 | return ( 118 | dict( 119 | encoder_cont=encoder_cont + targets, 120 | decoder_cont=decoder_cont, 121 | x_categoricals=categoricals, 122 | ), 123 | targets, 124 | ) 125 | -------------------------------------------------------------------------------- /tests/integrations/benchmarks/offline_pgsql_benchmark_test.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import timeit 3 | from f2ai import FeatureStore 4 | from f2ai.definitions import StatsFunctions 5 | from f2ai.dataset import EvenEventsSampler, NoEntitiesSampler 6 | from f2ai.definitions import BackOffTime 7 | 8 | 9 | def get_guizhou_traffic_entities(store: FeatureStore): 10 | columns = ["link_id", "event_timestamp"] 11 | query_entities = store.offline_store._get_dataframe( 12 | f"select {', '.join(columns)} from gy_link_travel_time order by event_timestamp limit 1000", 13 | columns=columns, 14 | ) 15 | return query_entities.astype({"link_id": "string"}) 16 | 17 | 18 | def test_get_features_from_feature_view(make_guizhou_traffic): 19 | project_folder = make_guizhou_traffic("pgsql") 20 | store = FeatureStore(project_folder) 21 | entity_df = get_guizhou_traffic_entities(store) 22 | 23 | store.get_features("gy_link_travel_time_features", entity_df) 24 | 25 | measured_time = timeit.timeit( 26 | lambda: store.get_features("gy_link_travel_time_features", entity_df), number=10 27 | ) 28 | print(f"get_features performance pgsql: {measured_time}s") 29 | 30 | 31 | def test_stats_from_feature_view(make_guizhou_traffic): 32 | project_folder = make_guizhou_traffic("pgsql") 33 | store = FeatureStore(project_folder) 34 | 35 | measured_time = timeit.timeit( 36 | lambda: store.stats("gy_link_travel_time_features", fn=StatsFunctions.AVG), number=10 37 | ) 38 | print(f"stats performance pgsql: {measured_time}s") 39 | 40 | 41 | def test_unique_from_feature_view(make_guizhou_traffic): 42 | project_folder = make_guizhou_traffic("pgsql") 43 | store = FeatureStore(project_folder) 44 | 45 | measured_time = timeit.timeit( 46 | lambda: store.stats("gy_link_travel_time_features", group_keys=["link_id"]), number=10 47 | ) 48 | print(f"stats performance pgsql: {measured_time}s") 49 | 50 | 51 | def test_get_latest_entities_from_feature_view_with_entity_df(make_guizhou_traffic): 52 | project_folder = make_guizhou_traffic("pgsql") 53 | store = FeatureStore(project_folder) 54 | measured_time = timeit.timeit( 55 | lambda: store.get_latest_entities( 56 | "gy_link_travel_time_features", 57 | pd.DataFrame({"link_id": ["3377906281518510514", "4377906284141600514"]}), 58 | ), 59 | number=10, 60 | ) 61 | print(f"get_latest_entities with entity_df performance pgsql: {measured_time}s") 62 | 63 | 64 | def test_get_latest_entity_from_feature_view(make_guizhou_traffic): 65 | project_folder = make_guizhou_traffic("pgsql") 66 | store = FeatureStore(project_folder) 67 | 
measured_time = timeit.timeit( 68 | lambda: store.get_latest_entities("gy_link_travel_time_features"), number=10 69 | ) 70 | print(f"get_latest_entities performance pgsql: {measured_time}s") 71 | 72 | 73 | def test_get_period_features_from_feature_view(make_guizhou_traffic): 74 | project_folder = make_guizhou_traffic("pgsql") 75 | store = FeatureStore(project_folder) 76 | entity_df = get_guizhou_traffic_entities(store) 77 | measured_time = timeit.timeit( 78 | lambda: store.get_period_features("gy_link_travel_time_features", entity_df, period="10 minutes"), 79 | number=10, 80 | ) 81 | print(f"get_features performance pgsql: {measured_time}s") 82 | 83 | 84 | def test_dataset_to_pytorch_pgsql(make_guizhou_traffic): 85 | project_folder = make_guizhou_traffic("pgsql") 86 | store = FeatureStore(project_folder) 87 | events_sampler = EvenEventsSampler(start="2016-03-05 00:00:00", end="2016-03-06 00:00:00", period='1 hours') 88 | 89 | ds = store.get_dataset( 90 | service="traval_time_prediction_embedding_v1", 91 | sampler=NoEntitiesSampler(events_sampler), 92 | ) 93 | measured_time = timeit.timeit(lambda: list(ds.to_pytorch(64)), number=1) 94 | print(f"dataset.to_pytorch pgsql performance: {measured_time}s") 95 | 96 | 97 | def test_offline_materialize(make_guizhou_traffic): 98 | project_folder = make_guizhou_traffic("pgsql") 99 | store = FeatureStore(project_folder) 100 | backoff_time = BackOffTime( 101 | start="2016-03-01 08:02:00+08", end="2016-03-01 08:06:00+08", step="4 minutes" 102 | ) 103 | measured_time = timeit.timeit( 104 | lambda: store.materialize( 105 | service_or_views="traval_time_prediction_embedding_v1", 106 | back_off_time=backoff_time, 107 | ), 108 | number=1, 109 | ) 110 | 111 | print(f"materialize performance pgsql: {measured_time}s") 112 | -------------------------------------------------------------------------------- /f2ai/persist_engine/offline_pgsql_persistengine.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import List 3 | from pypika import Query, Parameter, Table, PostgreSQLQuery 4 | 5 | from ..definitions import ( 6 | SqlSource, 7 | OfflinePersistEngine, 8 | OfflinePersistEngineType, 9 | BackOffTime, 10 | PersistFeatureView, 11 | PersistLabelView, 12 | ) 13 | from ..offline_stores.offline_postgres_store import OfflinePostgresStore 14 | from ..common.time_field import ( 15 | DEFAULT_EVENT_TIMESTAMP_FIELD, 16 | ENTITY_EVENT_TIMESTAMP_FIELD, 17 | MATERIALIZE_TIME, 18 | ) 19 | 20 | 21 | class OfflinePgsqlPersistEngine(OfflinePersistEngine): 22 | 23 | type: OfflinePersistEngineType = OfflinePersistEngineType.PGSQL 24 | 25 | offline_store: OfflinePostgresStore 26 | 27 | def materialize( 28 | self, 29 | feature_views: List[PersistFeatureView], 30 | label_view: PersistLabelView, 31 | destination: SqlSource, 32 | back_off_time: BackOffTime, 33 | ): 34 | join_sql = self.store.read( 35 | source=label_view.source, 36 | features=label_view.labels, 37 | join_keys=label_view.join_keys, 38 | alias=ENTITY_EVENT_TIMESTAMP_FIELD, 39 | ).where( 40 | (Parameter(DEFAULT_EVENT_TIMESTAMP_FIELD) <= back_off_time.end) 41 | & (Parameter(DEFAULT_EVENT_TIMESTAMP_FIELD) >= back_off_time.start) 42 | ) 43 | 44 | feature_names = [] 45 | label_names = [label.name for label in label_view.labels] 46 | for feature_view in feature_views: 47 | source_sql = self.store.read( 48 | source=feature_view.source, 49 | features=feature_view.features, 50 | join_keys=feature_view.join_keys, 51 | ) 52 | feature_names += 
[feature.name for feature in feature_view.features] 53 | 54 | keep_columns = label_view.join_keys + feature_names + label_names + [ENTITY_EVENT_TIMESTAMP_FIELD] 55 | join_sql = self.store._point_in_time_join( 56 | entity_df=join_sql, 57 | source_df=source_sql, 58 | timestamp_field=feature_view.source.timestamp_field, 59 | created_timestamp_field=feature_view.source.created_timestamp_field, 60 | ttl=feature_view.ttl, 61 | join_keys=feature_view.join_keys, 62 | include=True, 63 | how="right", 64 | ).select(Parameter(f"{', '.join(keep_columns)}")) 65 | 66 | data_columns = label_view.join_keys + feature_names + label_names 67 | unique_columns = label_view.join_keys + [DEFAULT_EVENT_TIMESTAMP_FIELD] 68 | join_sql = Query.from_(join_sql).select( 69 | Parameter(f"{', '.join(data_columns)}"), 70 | Parameter(f"{ENTITY_EVENT_TIMESTAMP_FIELD} as {DEFAULT_EVENT_TIMESTAMP_FIELD}"), 71 | Parameter(f"current_timestamp as {MATERIALIZE_TIME}"), 72 | ) 73 | 74 | with self.store.psy_conn as con: 75 | with con.cursor() as cursor: 76 | cursor.execute(f"select to_regclass('{destination.query}')") 77 | (table_name,) = cursor.fetchone() 78 | is_table_exists = table_name in destination.query 79 | 80 | if not is_table_exists: 81 | # create table from select. 82 | cursor.execute( 83 | Query.create_table(destination.query).as_select(join_sql).get_sql(quote_char="") 84 | ) 85 | 86 | # add unique constraint 87 | # TODO: vs unique index. 88 | cursor.execute( 89 | f"alter table {destination.query} add constraint unique_key_{destination.query.split('.')[-1]} unique ({Parameter(', '.join(unique_columns))})" 90 | ) 91 | else: 92 | table = Table(destination.query) 93 | all_columns = data_columns + [DEFAULT_EVENT_TIMESTAMP_FIELD] + [MATERIALIZE_TIME] 94 | 95 | insert_sql = ( 96 | PostgreSQLQuery.into(table) 97 | .columns(*all_columns) 98 | .from_(join_sql) 99 | .select(Parameter(f"{','.join(all_columns)}")) 100 | .on_conflict(*unique_columns) 101 | ) 102 | for c in all_columns: 103 | insert_sql = insert_sql.do_update(table.field(c), Parameter(f"excluded.{c}")) 104 | 105 | cursor.execute(insert_sql.get_sql(quote_char="")) 106 | -------------------------------------------------------------------------------- /f2ai/definitions/period.py: -------------------------------------------------------------------------------- 1 | import re 2 | import pandas as pd 3 | from pydantic import BaseModel 4 | from enum import Enum 5 | from typing import Any, List 6 | from datetime import timedelta 7 | from functools import reduce, total_ordering 8 | 9 | 10 | class AvailablePeriods(Enum): 11 | """Available Period definitions which supported by F2AI.""" 12 | 13 | YEARS = "years" 14 | MONTHS = "months" 15 | WEEKS = "weeks" 16 | DAYS = "days" 17 | HOURS = "hours" 18 | MINUTES = "minutes" 19 | SECONDS = "seconds" 20 | MILLISECONDS = "milliseconds" 21 | MICROSECONDS = "microseconds" 22 | NANOSECONDS = "nanoseconds" 23 | 24 | 25 | PANDAS_TIME_COMPONENTS_MAP = { 26 | AvailablePeriods.YEARS: "year", 27 | AvailablePeriods.MONTHS: "month", 28 | AvailablePeriods.WEEKS: "day", 29 | AvailablePeriods.DAYS: "day", 30 | AvailablePeriods.HOURS: "hour", 31 | AvailablePeriods.MINUTES: "minute", 32 | AvailablePeriods.SECONDS: "second", 33 | AvailablePeriods.MILLISECONDS: "microsecond", 34 | AvailablePeriods.MICROSECONDS: "microsecond", 35 | AvailablePeriods.NANOSECONDS: "nanosecond", 36 | } 37 | 38 | PANDAS_FREQ_STR_MAP = { 39 | AvailablePeriods.YEARS: "YS", 40 | AvailablePeriods.MONTHS: "MS", 41 | AvailablePeriods.WEEKS: "W", 42 | AvailablePeriods.DAYS: "D", 43 | 
AvailablePeriods.HOURS: "H", 44 | AvailablePeriods.MINUTES: "T", 45 | AvailablePeriods.SECONDS: "S", 46 | AvailablePeriods.MILLISECONDS: "L", 47 | AvailablePeriods.MICROSECONDS: "U", 48 | AvailablePeriods.NANOSECONDS: "N", 49 | } 50 | 51 | 52 | @total_ordering 53 | class Period(BaseModel): 54 | """A wrapper of different representations of a time range. Useful to convert to underline utils like pandas DateOffset, Postgres interval strings.""" 55 | 56 | n: int = 1 57 | unit: AvailablePeriods = AvailablePeriods.DAYS 58 | 59 | def __init__(__pydantic_self__, **data: Any) -> None: 60 | if not data.get("unit", "s").endswith("s"): 61 | data["unit"] = data.get("unit", "") + "s" 62 | super().__init__(**data) 63 | 64 | def __str__(self) -> str: 65 | return f"{self.n} {self.unit.value}" 66 | 67 | def __neg__(self) -> "Period": 68 | return Period(n=-self.n, unit=self.unit.value) 69 | 70 | def __eq__(self, other: "Period") -> bool: 71 | return self.n == other.n and self.unit == other.unit 72 | 73 | def __lt__(self, other: "Period") -> bool: 74 | assert self.unit == other.unit, f"Different unit of Period are not comparable, left: {self.unit.value}, right: {other.unit.value}" 75 | return abs(self.n) < abs(other.n) 76 | 77 | @property 78 | def is_neg(self): 79 | return self.n < 0 80 | 81 | def to_pandas_dateoffset(self, normalize=False): 82 | from pandas import DateOffset 83 | 84 | return DateOffset(**{self.unit.value: self.n}, normalize=normalize) 85 | 86 | def to_pgsql_interval(self): 87 | return f"interval '{self.n} {self.unit.value}'" 88 | 89 | def to_py_timedelta(self): 90 | if self.unit == AvailablePeriods.YEARS: 91 | return timedelta(days=365 * self.n) 92 | if self.unit == AvailablePeriods.MONTHS: 93 | return timedelta(days=30 * self.n) 94 | return timedelta(**{self.unit.value: self.n}) 95 | 96 | def to_pandas_freq_str(self): 97 | return f'{self.n}{PANDAS_FREQ_STR_MAP[self.unit]}' 98 | 99 | @classmethod 100 | def from_str(cls, s: str): 101 | """Construct a period from str, egg: 10 years, 1day, -1 month. 
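
        Example (illustrative): Period.from_str("-1 month") yields Period(n=-1, unit=AvailablePeriods.MONTHS),
        since a missing trailing "s" is appended to the unit automatically.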
102 | 103 | Args: 104 | s (str): string representation of a period 105 | """ 106 | n, unit = re.search("(-?\d+)\s?(\w+)", s).groups() 107 | return cls(n=int(n), unit=unit) 108 | 109 | def get_pandas_datetime_components(self) -> List[str]: 110 | index_of_period = list(PANDAS_TIME_COMPONENTS_MAP.keys()).index(self.unit) 111 | components = list(PANDAS_TIME_COMPONENTS_MAP.values())[: index_of_period + 1] 112 | 113 | return reduce(lambda xs, x: xs + [x] if x not in xs else xs, components, []) 114 | 115 | def normalize(self, dt: pd.Timestamp, norm_type: str): 116 | if self.unit in { 117 | AvailablePeriods.YEARS, 118 | AvailablePeriods.MONTHS, 119 | AvailablePeriods.WEEKS, 120 | AvailablePeriods.DAYS, 121 | }: 122 | if norm_type == "ceil": 123 | return dt.normalize() + self.to_pandas_dateoffset() 124 | else: 125 | return dt.normalize() 126 | 127 | freq = self.to_pandas_freq_str() 128 | if norm_type == "floor": 129 | return dt.floor(freq) 130 | elif norm_type == "ceil": 131 | return dt.ceil(freq) 132 | else: 133 | return dt.round(freq) 134 | -------------------------------------------------------------------------------- /f2ai/definitions/features.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from enum import Enum 3 | from typing import Optional, TYPE_CHECKING, List, Dict 4 | from pydantic import BaseModel 5 | 6 | from .dtypes import FeatureDTypes, NUMERIC_FEATURE_D_TYPES 7 | 8 | if TYPE_CHECKING: 9 | from .base_view import BaseView 10 | 11 | 12 | class SchemaType(str, Enum): 13 | """Schema used to describe a data column in a table. We only have 2 options in such context, feature or label. Label is the observation data which usually appear in supervised machine learning. In F2AI, label is treated as a special feature.""" 14 | 15 | FEATURE = 0 16 | LABEL = 1 17 | 18 | 19 | class FeatureSchema(BaseModel): 20 | """A FeatureSchema is used to describe a data column but no table information included.""" 21 | 22 | name: str 23 | description: Optional[str] 24 | dtype: FeatureDTypes 25 | 26 | def is_numeric(self): 27 | if self.dtype in NUMERIC_FEATURE_D_TYPES: 28 | return True 29 | return False 30 | 31 | 32 | class SchemaAnchor(BaseModel): 33 | """ 34 | SchemaAnchor links a view to a group of FeatureSchemas with period information included if it has. 35 | """ 36 | 37 | view_name: str 38 | schema_name: str 39 | period: Optional[str] 40 | 41 | @classmethod 42 | def from_strs(cls, cfgs: List[str]) -> "List[SchemaAnchor]": 43 | """Construct from a list of strings. 44 | 45 | Args: 46 | cfgs (List[str]) 47 | 48 | Returns: 49 | List[SchemaAnchor] 50 | """ 51 | return [cls.from_str(cfg) for cfg in cfgs] 52 | 53 | @classmethod 54 | def from_str(cls, cfg: str) -> "SchemaAnchor": 55 | """Construct from a string. 
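
        Example (illustrative, hypothetical names): SchemaAnchor.from_str("loan_features:credit_amount:30 days")
        gives a SchemaAnchor with view_name="loan_features", schema_name="credit_amount" and period="30 days";
        omitting the last component leaves period unset.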
56 | 57 | Args: 58 | cfg (str): a string with specific format, egg: {feature_view_name}:{feature_name}:{period} 59 | 60 | Returns: 61 | SchemaAnchor 62 | """ 63 | components = cfg.split(":") 64 | 65 | if len(components) < 2: 66 | raise ValueError("Please indicate features in table:feature format") 67 | elif len(components) > 3: 68 | raise ValueError("Please make sure colon not in name of table or features") 69 | elif len(components) == 2: 70 | view_name, schema_name = components 71 | return cls(view_name=view_name, schema_name=schema_name) 72 | elif len(components) == 3: 73 | view_name, schema_name, period = components 74 | return cls(view_name=view_name, schema_name=schema_name, period=period) 75 | 76 | def get_features_from_views(self, views: Dict[str, BaseView], is_numeric=False) -> List[Feature]: 77 | """With given views, construct a series of features based on this SchemaAnchor. 78 | 79 | Args: 80 | views (Dict[str, BaseView]) 81 | is_numeric (bool, optional): If only return numeric features. Defaults to False. 82 | 83 | Returns: 84 | List[Feature] 85 | """ 86 | from .feature_view import FeatureView 87 | 88 | view: BaseView = views[self.view_name] 89 | schema_type = SchemaType.FEATURE if isinstance(view, FeatureView) else SchemaType.LABEL 90 | 91 | if self.schema_name == "*": 92 | return [ 93 | Feature.create_from_schema(feature_schema, view.name, schema_type, self.period) 94 | for feature_schema in view.schemas 95 | if (feature_schema.is_numeric() if is_numeric else True) 96 | ] 97 | 98 | feature_schema = next((schema for schema in view.schemas if schema.name == self.schema_name), None) 99 | if feature_schema and (feature_schema.is_numeric() if is_numeric else True): 100 | return [Feature.create_from_schema(feature_schema, view.name, schema_type, self.period)] 101 | 102 | return [] 103 | 104 | 105 | class Feature(BaseModel): 106 | """A Feature which include all necessary information which F2AI should know.""" 107 | 108 | name: str 109 | dtype: FeatureDTypes 110 | period: Optional[str] 111 | schema_type: SchemaType = SchemaType.FEATURE 112 | view_name: str 113 | 114 | @classmethod 115 | def create_feature_from_schema( 116 | cls, schema: FeatureSchema, view_name: str, period: str = None 117 | ) -> "Feature": 118 | return cls.create_from_schema(schema, view_name, SchemaType.FEATURE, period) 119 | 120 | @classmethod 121 | def create_label_from_schema(cls, schema: FeatureSchema, view_name: str, period: str = None) -> "Feature": 122 | return cls.create_from_schema(schema, view_name, SchemaType.LABEL, period) 123 | 124 | @classmethod 125 | def create_from_schema( 126 | cls, schema: FeatureSchema, view_name: str, schema_type: SchemaType, period: str 127 | ) -> "Feature": 128 | return Feature( 129 | name=schema.name, dtype=schema.dtype, schema_type=schema_type, view_name=view_name, period=period 130 | ) 131 | 132 | def __hash__(self) -> int: 133 | return hash(f"{self.view_name}:{self.name}:{self.period}, {self.schema_type}") 134 | -------------------------------------------------------------------------------- /f2ai/dataset/entities_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import abc 4 | from typing import Any, Dict 5 | 6 | from ..common.utils import is_iterable 7 | from .events_sampler import EventsSampler 8 | 9 | 10 | class EntitiesSampler: 11 | @property 12 | def iterable(self): 13 | return is_iterable(self) 14 | 15 | @property 16 | def iterable_only(self): 17 | is_call_abstract = 
getattr(self.__call__, "__isabstractmethod__", False) 18 | if self.iterable: 19 | return is_call_abstract 20 | return False 21 | 22 | @abc.abstractmethod 23 | def __call__(self, *args: Any, **kwds: Any) -> pd.DataFrame: 24 | """ 25 | Get all sample results, which usually is an entity_df 26 | 27 | Returns: 28 | pd.DataFrame: An entity DataFrame used to query features or labels 29 | """ 30 | pass 31 | 32 | @abc.abstractmethod 33 | def __iter__(self, *args: Any, **kwds: Any) -> Dict: 34 | """ 35 | Iterable way to get sample result. 36 | 37 | Returns: 38 | Dict: get one entity item, which contains entity keys and event_timestamp. 39 | """ 40 | pass 41 | 42 | 43 | # TODO: test this 44 | class NoEntitiesSampler(EntitiesSampler): 45 | """ 46 | This class will directly convert an EventsSampler to EntitiesSampler. Useful when no entity keys are exists. 47 | """ 48 | 49 | def __init__(self, events_sampler: EventsSampler) -> None: 50 | super().__init__() 51 | 52 | self._events_sampler = events_sampler 53 | 54 | def __call__(self) -> pd.DataFrame: 55 | return pd.DataFrame({"event_timestamp": self._events_sampler()}) 56 | 57 | def __iter__(self) -> Dict: 58 | for event_timestamp in self._events_sampler: 59 | yield {"event_timestamp": event_timestamp} 60 | 61 | 62 | class EvenEntitiesSampler(EntitiesSampler): 63 | """ 64 | Sample every group with given events_sampler. 65 | """ 66 | 67 | def __init__( 68 | self, events_sampler: EventsSampler, group_df: pd.DataFrame, random_state: int = None 69 | ) -> None: 70 | super().__init__() 71 | 72 | self._events_sampler = events_sampler 73 | self._rng = np.random.default_rng(random_state) 74 | 75 | self._group_df = group_df 76 | 77 | def __call__(self) -> pd.DataFrame: 78 | return pd.merge( 79 | pd.DataFrame({"event_timestamp": self._events_sampler()}), self._group_df, how="cross" 80 | )[list(self._group_df.columns) + ["event_timestamp"]] 81 | 82 | def __iter__(self) -> Dict: 83 | for event_timestamp in iter(self._events_sampler): 84 | for _, entity_row in self._group_df.iterrows(): 85 | d = entity_row.to_dict() 86 | d["event_timestamp"] = event_timestamp 87 | yield d 88 | 89 | 90 | class FixedNEntitiesSampler(EntitiesSampler): 91 | """ 92 | Sample N instance from each group with given probability. 93 | """ 94 | 95 | def __init__(self, events_sampler: EventsSampler, group_df: pd.DataFrame, n=1, random_state=None) -> None: 96 | """ 97 | Args: 98 | events_sampler (EventTimestampSampler): an EventTimestampSampler instance. 99 | group_df (pd.DataFrame): a group_df is a dataframe with contains some entity columns. Optionally, it may contains a columns named `p`, which indicates the probability of this this group. 100 | n (int, optional): how many instance per group. Defaults to 1. 101 | random_state (_type_, optional): Defaults to None. 
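
        Example (a minimal sketch; see tests/units/dataset/entities_sampler_test.py for full usage):
            events = EvenEventsSampler(start="2022-10-02", end="2022-10-04", period="1 day")
            groups = pd.DataFrame({"fruit": ["apple", "banana"], "sauce": ["tomato", "chili"]})
            sampler = FixedNEntitiesSampler(events, groups, n=2)
            entity_df = sampler()  # 2 sampled event_timestamps per group row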
102 | """ 103 | 104 | super().__init__() 105 | 106 | self._events_sampler = events_sampler 107 | self._rng = np.random.default_rng(seed=random_state) 108 | 109 | self._event_timestamps = self._events_sampler() 110 | 111 | if "p" in group_df.columns: 112 | assert group_df["p"].sum() == 1, "sum all weights should be 1" 113 | 114 | take_n = (group_df["p"] * len(group_df) * n).round().astype(int).rename("_f2ai_take_n_") 115 | self._group_df = pd.concat([group_df, take_n], axis=1) 116 | else: 117 | self._group_df = group_df.assign(_f2ai_take_n_=n) 118 | 119 | def _sample_n(self, row: pd.DataFrame): 120 | event_timestamps = self._rng.choice(self._event_timestamps, size=row["_f2ai_take_n_"].iloc[0]) 121 | return pd.merge( 122 | row.drop(columns=["p", "_f2ai_take_n_"], errors="ignore"), 123 | pd.DataFrame({"event_timestamp": event_timestamps}), 124 | how="cross", 125 | ) 126 | 127 | def __call__(self) -> pd.DataFrame: 128 | group_keys = [column for column in self._group_df.columns if column not in {"p", "_f2ai_take_n_"}] 129 | return ( 130 | self._group_df.groupby(group_keys, group_keys=False, sort=False) 131 | .apply(self._sample_n) 132 | .reset_index(drop=True) 133 | ) 134 | 135 | def __iter__(self) -> Dict: 136 | for i, row in self._group_df.iterrows(): 137 | sampled_df = self._sample_n(pd.DataFrame([row])) 138 | for j, row in sampled_df.iterrows(): 139 | yield row.to_dict() 140 | -------------------------------------------------------------------------------- /tests/integrations/benchmarks/offline_file_benchmark_test.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import timeit 3 | from os import path 4 | from f2ai import FeatureStore 5 | from f2ai.dataset import EvenEventsSampler, FixedNEntitiesSampler, NoEntitiesSampler 6 | from f2ai.definitions import BackOffTime, Period 7 | 8 | LINE_LIMIT = 1000 9 | 10 | 11 | def get_credit_score_entities(project_folder: str): 12 | query_entities = pd.read_parquet( 13 | path.join(project_folder, "row_data/loan_table.parquet"), 14 | columns=["loan_id", "dob_ssn", "zipcode", "created_timestamp", "event_timestamp"], 15 | ).iloc[:LINE_LIMIT] 16 | return query_entities.astype( 17 | { 18 | "loan_id": "string", 19 | "dob_ssn": "string", 20 | "zipcode": "string", 21 | } 22 | ) 23 | 24 | 25 | def get_guizhou_traffic_entities(project_folder: str): 26 | query_entities = pd.read_csv( 27 | path.join(project_folder, "raw_data/gy_link_travel_time.csv"), 28 | usecols=["link_id", "event_timestamp"], 29 | nrows=5, 30 | ) 31 | return query_entities.astype({"link_id": "string"}) 32 | 33 | 34 | def test_get_features_from_feature_view(make_credit_score): 35 | project_folder = make_credit_score("file") 36 | entity_df = get_credit_score_entities(project_folder) 37 | store = FeatureStore(project_folder) 38 | store.get_features("zipcode_features", entity_df) 39 | 40 | measured_time = timeit.timeit(lambda: store.get_features("zipcode_features", entity_df), number=10) 41 | print(f"get_features performance: {measured_time}s") 42 | 43 | 44 | def test_get_labels_from_label_views(make_credit_score): 45 | project_folder = make_credit_score("file") 46 | entity_df = get_credit_score_entities(project_folder) 47 | store = FeatureStore(project_folder) 48 | store.get_labels("loan_label_view", entity_df) 49 | 50 | 51 | def test_materialize(make_credit_score): 52 | project_folder = make_credit_score("file") 53 | store = FeatureStore(project_folder) 54 | back_off_time = BackOffTime(start="2020-08-25", end="2021-08-26", step="1 
month") 55 | measured_time = timeit.timeit( 56 | lambda: store.materialize( 57 | service_or_views="credit_scoring_v1", back_off_time=back_off_time, online=False 58 | ), 59 | number=1, 60 | ) 61 | print(f"materialize performance: {measured_time}s") 62 | 63 | 64 | def test_get_features_from_service(make_credit_score): 65 | """this test should run after materialize""" 66 | project_folder = make_credit_score("file") 67 | store = FeatureStore(project_folder) 68 | entity_df = get_credit_score_entities(project_folder) 69 | store.get_features("credit_scoring_v1", entity_df) 70 | 71 | 72 | def test_get_period_features_from_feature_view(make_guizhou_traffic): 73 | project_folder = make_guizhou_traffic("file") 74 | entity_df = get_guizhou_traffic_entities(project_folder) 75 | store = FeatureStore(project_folder) 76 | measured_time = timeit.timeit( 77 | lambda: store.get_period_features("gy_link_travel_time_features", entity_df, period="20 minutes"), 78 | number=10, 79 | ) 80 | print(f"get_period_features performance: {measured_time}s") 81 | 82 | 83 | def test_stats_from_feature_view(make_credit_score): 84 | project_folder = make_credit_score("file") 85 | store = FeatureStore(project_folder) 86 | 87 | measured_time = timeit.timeit(lambda: store.stats("loan_features", fn="avg"), number=10) 88 | print(f"stats performance: {measured_time}s") 89 | 90 | 91 | def test_get_latest_entities_from_feature_view(make_credit_score): 92 | project_folder = make_credit_score("file") 93 | store = FeatureStore(project_folder) 94 | measured_time = timeit.timeit( 95 | lambda: store.get_latest_entities("loan_features", pd.DataFrame({"dob_ssn": ["19960703_3449"]})), 96 | number=10, 97 | ) 98 | print(f"get_latest_entities performance: {measured_time}s") 99 | 100 | 101 | def test_get_latest_entity_from_feature_view(make_credit_score): 102 | project_folder = make_credit_score("file") 103 | store = FeatureStore(project_folder) 104 | measured_time = timeit.timeit(lambda: store.get_latest_entities("loan_features"), number=10) 105 | print(f"get_latest_entities performance: {measured_time}s") 106 | 107 | 108 | def test_sampler_with_groups(make_credit_score): 109 | project_folder = make_credit_score("file") 110 | store = FeatureStore(project_folder) 111 | group_df = store.stats("loan_features", group_keys=["zipcode", "dob_ssn"], fn="unique") 112 | events_sampler = EvenEventsSampler(start="2020-08-20", end="2021-08-30", period="10 days") 113 | entities_sampler = FixedNEntitiesSampler( 114 | events_sampler, 115 | group_df=group_df, 116 | ) 117 | measured_time = timeit.timeit( 118 | lambda: entities_sampler(), 119 | number=1, 120 | ) 121 | print(f"sampler with groups performance: {measured_time}s") 122 | 123 | 124 | def test_dataset_to_pytorch(make_credit_score): 125 | project_folder = make_credit_score("file") 126 | backoff = BackOffTime( 127 | start=pd.Timestamp("2020-05-01"), end=pd.Timestamp("2020-07-01"), step=Period.from_str("1 month") 128 | ) 129 | store = FeatureStore(project_folder) 130 | store.materialize(service_or_views="credit_scoring_v1", back_off_time=backoff) 131 | 132 | events_sampler = EvenEventsSampler(start="2020-08-20", end="2021-08-30", period="10 days") 133 | ds = store.get_dataset( 134 | service="credit_scoring_v1", 135 | sampler=NoEntitiesSampler(events_sampler), 136 | ) 137 | measured_time = timeit.timeit(lambda: list(ds.to_pytorch()), number=1) 138 | print(f"dataset.to_pytorch performance: {measured_time}s") 139 | 140 | 141 | def test_online_materialize(make_credit_score): 142 | project_folder = 
make_credit_score("file") 143 | back_off_time = BackOffTime( 144 | start=pd.Timestamp("2020-05-01"), end=pd.Timestamp("2020-07-01"), step=Period.from_str("1 month") 145 | ) 146 | store = FeatureStore(project_folder) 147 | store.materialize(service_or_views="credit_scoring_v1", back_off_time=back_off_time, online=True) 148 | -------------------------------------------------------------------------------- /f2ai/common/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import json 3 | import os 4 | import oss2 5 | import pandas as pd 6 | from typing import List, Tuple, Iterable, Any 7 | from pathlib import Path 8 | 9 | from ..definitions import Period 10 | 11 | 12 | class DateEncoder(json.JSONEncoder): 13 | def default(self, o): 14 | if isinstance(o, datetime.datetime): 15 | return o.strftime("%Y-%m-%dT%H:%M:%S") 16 | else: 17 | return json.JSONEncoder.default(self, o) 18 | 19 | 20 | def remove_prefix(text: str, prefix: str): 21 | return text[text.startswith(prefix) and len(prefix) :] 22 | 23 | 24 | def get_default_value(): 25 | return None 26 | 27 | 28 | def schema_to_dict(schema): 29 | return {item["name"]: item.get("dtype", "string") for item in schema} 30 | 31 | 32 | def read_file( 33 | path, 34 | parse_dates: List[str] = [], 35 | str_cols: List[str] = [], 36 | keep_cols: List[str] = [], 37 | file_format=None, 38 | ): 39 | path = Path(remove_prefix(path, "file://")) 40 | dtypes = {en: str for en in str_cols} 41 | usecols = list(dict.fromkeys(keep_cols + parse_dates + str_cols)) 42 | 43 | if file_format is None: 44 | file_format = path.parts[-1].split(".")[-1] 45 | 46 | if path.is_dir(): 47 | df = read_df_from_dataset(path, usecols=usecols).astype(dtypes) 48 | elif file_format.startswith("parq"): 49 | df = pd.read_parquet(path, columns=usecols).astype(dtypes) 50 | elif file_format.startswith("tsv"): 51 | df = pd.read_csv(path, sep="\t", parse_dates=parse_dates, dtype=dtypes, usecols=usecols) 52 | elif file_format.startswith("txt"): 53 | df = pd.read_csv(path, sep=" ", parse_dates=parse_dates, dtype=dtypes, usecols=usecols) 54 | else: 55 | df = pd.read_csv(path, parse_dates=parse_dates, dtype=dtypes, usecols=usecols) 56 | 57 | return df 58 | 59 | 60 | def to_file(file: pd.DataFrame, path, type, mode="w", header=True): 61 | 62 | path = remove_prefix(path, "file://") 63 | if type.startswith("parq"): 64 | file.to_parquet(path, index=False) 65 | elif type.startswith("tsv"): 66 | file.to_csv(path, sep="\t", index=False, mode=mode, header=header) 67 | elif type.startswith("txt"): 68 | file.to_csv(path, sep=" ", index=False, mode=mode, header=header) 69 | else: 70 | file.to_csv(path, index=False, mode=mode, header=header) 71 | 72 | 73 | def get_bucket(bucket, endpoint=None): 74 | key_id = os.environ.get("OSS_ACCESS_KEY_ID") 75 | key_secret = os.environ.get("OSS_ACCESS_KEY_SECRET") 76 | endpoint = endpoint or os.environ.get("OSS_ENDPOINT") 77 | 78 | return oss2.Bucket(oss2.Auth(key_id, key_secret), endpoint, bucket) 79 | 80 | 81 | def parse_oss_url(url: str) -> Tuple[str, str, str]: 82 | """ 83 | url format: oss://{bucket}/{key} 84 | """ 85 | url = remove_prefix(url, "oss://") 86 | components = url.split("/") 87 | return components[0], "/".join(components[1:]) 88 | 89 | 90 | def get_bucket_from_oss_url(url: str): 91 | bucket_name, key = parse_oss_url(url) 92 | return get_bucket(bucket_name), key 93 | 94 | 95 | # the code below copied from 
https://github.com/pandas-dev/pandas/blob/91111fd99898d9dcaa6bf6bedb662db4108da6e6/pandas/io/sql.py#L1155 96 | def convert_dtype_to_sqlalchemy_type(col): 97 | from pandas._libs.lib import infer_dtype 98 | from sqlalchemy.types import ( 99 | TIMESTAMP, 100 | BigInteger, 101 | Boolean, 102 | Date, 103 | DateTime, 104 | Float, 105 | Integer, 106 | SmallInteger, 107 | Text, 108 | Time, 109 | ) 110 | 111 | col_type = infer_dtype(col, skipna=True) 112 | 113 | if col_type == "datetime64" or col_type == "datetime": 114 | try: 115 | if col.dt.tz is not None: 116 | return TIMESTAMP(timezone=True) 117 | except AttributeError: 118 | if getattr(col, "tz", None) is not None: 119 | return TIMESTAMP(timezone=True) 120 | return DateTime 121 | 122 | if col_type == "timedelta64": 123 | return BigInteger 124 | elif col_type == "floating": 125 | if col.dtype == "float32": 126 | return Float(precision=23) 127 | else: 128 | return Float(precision=53) 129 | elif col_type == "integer": 130 | if col.dtype.name.lower() in ("int8", "uint8", "int16"): 131 | return SmallInteger 132 | elif col.dtype.name.lower() in ("uint16", "int32"): 133 | return Integer 134 | elif col.dtype.name.lower() == "uint64": 135 | raise ValueError("Unsigned 64 bit integer datatype is not supported") 136 | else: 137 | return BigInteger 138 | elif col_type == "boolean": 139 | return Boolean 140 | elif col_type == "date": 141 | return Date 142 | elif col_type == "time": 143 | return Time 144 | elif col_type == "complex": 145 | raise ValueError("Complex datatypes not supported") 146 | 147 | return Text 148 | 149 | 150 | def write_df_to_dataset(data: pd.DataFrame, root_path: str, time_col: str, period: Period): 151 | from pyarrow import parquet, Table 152 | 153 | # create partitioning cols 154 | times = pd.to_datetime(data[time_col]) 155 | components = dict() 156 | for component in period.get_pandas_datetime_components(): 157 | components[f"_f2ai_{component}_"] = getattr(times.dt, component) 158 | 159 | data_with_components = pd.concat( 160 | [ 161 | pd.DataFrame(components), 162 | data, 163 | ], 164 | axis=1, 165 | ) 166 | table = Table.from_pandas(data_with_components) 167 | parquet.write_to_dataset(table, root_path=root_path, partition_cols=components) 168 | 169 | 170 | def read_df_from_dataset(root_path: str, usecols: List[str] = []) -> pd.DataFrame: 171 | from pyarrow.parquet import ParquetDataset 172 | 173 | ds = ParquetDataset(root_path, use_legacy_dataset=True) 174 | table = ds.read(columns=usecols) 175 | df: pd.DataFrame = table.to_pandas() 176 | drop_columns = [col_name for col_name in df.columns if col_name.startswith("_f2ai_")] 177 | return df.drop(columns=drop_columns) 178 | 179 | 180 | def batched(xs: Iterable, batch_size: int, drop_last=False): 181 | batches = [] 182 | for x in xs: 183 | batches.append(x) 184 | 185 | if len(batches) == batch_size: 186 | yield batches 187 | batches = [] 188 | 189 | if len(batches) > 0 and not drop_last: 190 | yield batches 191 | 192 | 193 | def is_iterable(o: Any): 194 | try: 195 | iter(o) 196 | return True 197 | except TypeError: 198 | return False 199 | -------------------------------------------------------------------------------- /f2ai/definitions/services.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Dict 2 | from pydantic import BaseModel 3 | 4 | from .entities import Entity 5 | from .features import SchemaAnchor, Feature 6 | from .feature_view import FeatureView 7 | from .label_view import LabelView 8 | 9 | 10 | 
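# Illustrative sketch (the service/view names and the anchor-string syntax below are
# assumptions; the exact anchor syntax accepted here is whatever SchemaAnchor.from_strs
# parses). A Service is normally declared in the project yaml and built via
# Service.from_yaml, roughly like:
#
#   service = Service.from_yaml({
#       "name": "credit_scoring_v1",
#       "description": "features and labels used by the credit scoring model",
#       "features": ["zipcode_features:*", "loan_features:*"],
#       "labels": ["loan_label_view:*"],
#   })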
class Service(BaseModel): 11 | """A Service is a combination of a group of feature views and label views, which is usually directly related to a certain AI model. The best practice F2AI suggests is to treat services as immutable. E.g. if you want to train different combinations of features for a specific AI model, you may want to create multiple Services, like: linear_reg_v1, linear_reg_v2.""" 12 | 13 | name: str 14 | description: Optional[str] 15 | features: List[SchemaAnchor] = [] 16 | labels: List[SchemaAnchor] = [] 17 | ttl: Optional[str] = None 18 | 19 | @classmethod 20 | def from_yaml(cls, cfg: Dict) -> "Service": 21 | """Construct a Service from a parsed yaml config file.""" 22 | 23 | cfg["features"] = SchemaAnchor.from_strs(cfg.pop("features", [])) 24 | cfg["labels"] = SchemaAnchor.from_strs(cfg.pop("labels", [])) 25 | 26 | return cls(**cfg) 27 | 28 | def get_feature_names(self, feature_views: Dict[str, FeatureView], is_numeric=False) -> List[str]: 29 | return [feature.name for feature in self.get_feature_objects(feature_views, is_numeric)] 30 | 31 | def get_label_names(self, label_views: Dict[str, LabelView], is_numeric=False) -> List[str]: 32 | return [label.name for label in self.get_label_objects(label_views, is_numeric)] 33 | 34 | def get_feature_objects(self, feature_views: Dict[str, FeatureView], is_numeric=False) -> List[Feature]: 35 | """Get all the feature objects included in this service, based on the features' schema anchors. 36 | 37 | Args: 38 | feature_views (Dict[str, FeatureView]): A group of FeatureViews. 39 | is_numeric (bool, optional): Whether to only include numeric features. Defaults to False. 40 | 41 | Returns: 42 | List[Feature] 43 | """ 44 | return list( 45 | dict.fromkeys( 46 | feature 47 | for schema_anchor in self.features 48 | for feature in schema_anchor.get_features_from_views(feature_views, is_numeric) 49 | ) 50 | ) 51 | 52 | def get_label_objects(self, label_views: Dict[str, LabelView], is_numeric=False) -> List[Feature]: 53 | """Get all the label objects included in this service, based on the labels' schema anchors. 54 | 55 | Args: 56 | label_views (Dict[str, LabelView]): A group of LabelViews. 57 | is_numeric (bool, optional): Whether to only include numeric labels. Defaults to False. 58 | 59 | Returns: 60 | List[Feature] 61 | """ 62 | return list( 63 | dict.fromkeys( 64 | label 65 | for schema_anchor in self.labels 66 | for label in schema_anchor.get_features_from_views(label_views, is_numeric) 67 | ) 68 | ) 69 | 70 | def get_feature_view_names(self, feature_views: Dict[str, FeatureView]) -> List[str]: 71 | """ 72 | Get the names of the feature views related to this service. 73 | 74 | Args: 75 | feature_views (Dict[str, FeatureView]): FeatureViews to filter, keyed by name. 76 | 77 | Returns: 78 | List[str]: names of FeatureView 79 | """ 80 | feature_view_names = list(dict.fromkeys([anchor.view_name for anchor in self.features])) 81 | return [x for x in feature_view_names if x in feature_views] 82 | 83 | def get_feature_views(self, feature_views: Dict[str, FeatureView]) -> List[FeatureView]: 84 | """Get the FeatureViews of this service. Feature views not present in the given dict are automatically filtered out.
85 | 86 | Args: 87 | feature_views (Dict[str, FeatureView]) 88 | 89 | Returns: 90 | List[FeatureView] 91 | """ 92 | feature_view_names = self.get_feature_view_names(feature_views) 93 | return [feature_views[feature_view_name] for feature_view_name in feature_view_names] 94 | 95 | def get_label_views(self, label_views: Dict[str, LabelView]) -> List[LabelView]: 96 | """Get LabelViews of this service. This will automatically filter out the label view not given by parameters. 97 | 98 | Args: 99 | label_views (Dict[str, LabelView]) 100 | 101 | Returns: 102 | List[LabelView] 103 | """ 104 | label_view_names = list(dict.fromkeys([anchor.view_name for anchor in self.labels])) 105 | return [label_views[label_view_name] for label_view_name in label_view_names] 106 | 107 | def get_feature_entities(self, feature_views: Dict[str, FeatureView]) -> List[Entity]: 108 | """Get all entities which appeared in related feature views to this service and without duplicate entity. 109 | 110 | Args: 111 | feature_views (Dict[str, FeatureView]) 112 | 113 | Returns: 114 | List[Entity] 115 | """ 116 | return list( 117 | dict.fromkeys( 118 | entity 119 | for feature_view in self.get_feature_views(feature_views) 120 | for entity in feature_view.entities 121 | ) 122 | ) 123 | 124 | def get_label_entities(self, label_views: Dict[str, LabelView]) -> List[str]: 125 | """Get all entities which appeared in related label views to this service and without duplicate entity. 126 | 127 | Args: 128 | label_views (Dict[str, LabelView]) 129 | 130 | Returns: 131 | List[str] 132 | """ 133 | return list( 134 | dict.fromkeys( 135 | entity for label_view in self.get_label_views(label_views) for entity in label_view.entities 136 | ) 137 | ) 138 | 139 | def get_entities( 140 | self, feature_views: Dict[str, FeatureView], label_views: Dict[str, LabelView] 141 | ) -> List[str]: 142 | """Get all entities which appeared in this service and without duplicate entity. 
143 | 144 | Args: 145 | feature_views (Dict[str, FeatureView]) 146 | label_views (Dict[str, LabelView]) 147 | 148 | Returns: 149 | List[str] 150 | """ 151 | return list( 152 | dict.fromkeys(self.get_feature_entities(feature_views) + self.get_label_entities(label_views)) 153 | ) 154 | 155 | def get_join_keys( 156 | self, 157 | feature_views: Dict[str, FeatureView], 158 | label_views: Dict[str, FeatureView], 159 | entities: Dict[str, Entity], 160 | ) -> List[str]: 161 | return list( 162 | dict.fromkeys( 163 | [ 164 | join_key 165 | for x in self.get_entities(feature_views, label_views) 166 | for join_key in entities[x].join_keys 167 | ] 168 | ) 169 | ) 170 | -------------------------------------------------------------------------------- /f2ai/online_stores/online_redis_store.py: -------------------------------------------------------------------------------- 1 | import json 2 | import uuid 3 | from datetime import datetime 4 | from typing import Optional, Dict, List, Union 5 | import functools 6 | import pandas as pd 7 | from pydantic import PrivateAttr 8 | from redis import Redis, ConnectionPool 9 | 10 | from ..common.utils import DateEncoder 11 | from ..definitions import OnlineStore, OnlineStoreType, Period, FeatureView, Service, Entity 12 | from ..common.time_field import DEFAULT_EVENT_TIMESTAMP_FIELD, QUERY_COL 13 | 14 | 15 | class OnlineRedisStore(OnlineStore): 16 | type: OnlineStoreType = OnlineStoreType.REDIS 17 | host: str = "localhost" 18 | port: int = 6379 19 | db: int = 0 20 | password: str = "" 21 | name: str 22 | 23 | _cilent: Optional[Redis] = PrivateAttr(default=None) 24 | 25 | @property 26 | def client(self): 27 | if self._cilent is None: 28 | pool = ConnectionPool( 29 | host=self.host, port=self.port, db=self.db, password=self.password, decode_responses=True 30 | ) 31 | self._cilent = Redis(connection_pool=pool) 32 | 33 | return self._cilent 34 | 35 | def write_batch( 36 | self, 37 | name: str, 38 | project_name: str, 39 | dt: pd.DataFrame, 40 | ttl: Optional[Period] = None, 41 | join_keys: List[str] = None, 42 | tz: str = None, 43 | ): 44 | pipe = self.client.pipeline() 45 | if not dt.empty: 46 | for group_data in dt.groupby(join_keys): 47 | all_entities = functools.reduce( 48 | lambda x, y: f"{x},{y}", 49 | list( 50 | map( 51 | lambda x, y: x + ":" + y, 52 | join_keys, 53 | [group_data[0]] if isinstance(group_data[0], str) else list(group_data[0]), 54 | ) 55 | ), 56 | ) 57 | if self.client.hget(f"{project_name}:{name}", all_entities) is None: 58 | zset_key = uuid.uuid4().hex 59 | pipe.hset(name=f"{project_name}:{name}", key=all_entities, value=zset_key) 60 | else: 61 | zset_key = self.client.hget(name=f"{project_name}:{name}", key=all_entities) 62 | # remove data that has expired in `zset`` according to `score` 63 | if ttl is not None: 64 | pipe.zremrangebyscore( 65 | name=zset_key, 66 | min=0, 67 | max=(pd.Timestamp(datetime.now(), tz=tz) - ttl.to_py_timedelta()).timestamp(), 68 | ) 69 | zset_dict = { 70 | json.dumps(row, cls=DateEncoder): pd.to_datetime( 71 | row.get(DEFAULT_EVENT_TIMESTAMP_FIELD, pd.Timestamp(datetime.now(), tz=tz)), 72 | ).timestamp() 73 | for row in group_data[1] 74 | .drop(columns=[QUERY_COL], errors="ignore") 75 | .to_dict(orient="records") 76 | } 77 | pipe.zadd(name=zset_key, mapping=zset_dict) 78 | if ttl is not None: # add a general expire constrains on hash-key 79 | expir_time = group_data[1][DEFAULT_EVENT_TIMESTAMP_FIELD].max() + ttl.to_py_timedelta() 80 | pipe.expireat(zset_key, expir_time) 81 | pipe.execute() 82 | 83 | def read_batch( 84 | 
self, 85 | entity_df: pd.DataFrame, 86 | project_name: str, 87 | view: Union[Service, FeatureView], 88 | feature_views: Dict[str, FeatureView], 89 | entities: Dict[str, Entity], 90 | join_keys: List[str], 91 | **kwargs, 92 | ): 93 | if isinstance(view, FeatureView): 94 | data = self._read_batch( 95 | hkey=f"{project_name}:{view.name}", 96 | ttl=view.ttl, 97 | period=None, 98 | entity_df=entity_df[join_keys], 99 | **kwargs, 100 | ) 101 | entity_df = pd.merge(entity_df, data, on=join_keys, how="inner") if not data.empty else None 102 | elif isinstance(view, Service): 103 | for featureview in view.features: 104 | fea_entities = functools.reduce( 105 | lambda x, y: x + y, 106 | [entities[entity].join_keys for entity in feature_views[featureview.view_name].entities], 107 | ) 108 | fea_join_keys = [join_key for join_key in join_keys if join_key in fea_entities] 109 | feature_view_batch = self._read_batch( 110 | hkey=f"{project_name}:{featureview.view_name}", 111 | ttl=feature_views[featureview.view_name].ttl, 112 | entity_df=entity_df[fea_join_keys], 113 | period=featureview.period, 114 | **kwargs, 115 | ) 116 | if not feature_view_batch.empty: 117 | entity_df = feature_view_batch.merge(entity_df, on=fea_join_keys, how="inner") 118 | else: 119 | raise TypeError("online read only allow FeatureView and Service") 120 | 121 | return entity_df 122 | 123 | def _read_batch( 124 | self, 125 | hkey: str, 126 | ttl: Optional[Period] = None, 127 | period: Optional[Period] = None, 128 | entity_df: pd.DataFrame = None, 129 | ) -> pd.DataFrame: 130 | 131 | min_timestamp = max( 132 | (pd.Timestamp(datetime.now()) - period.to_pandas_dateoffset() if period else pd.Timestamp(0)), 133 | (pd.Timestamp(datetime.now()) - ttl.to_pandas_dateoffset() if ttl else pd.Timestamp(0)), 134 | ) 135 | dt_group = entity_df.groupby(list(entity_df.columns)) 136 | all_entities = [ 137 | functools.reduce( 138 | lambda x, y: f"{x},{y}", 139 | list( 140 | map( 141 | lambda x, y: x + ":" + y, 142 | list(entity_df.columns), 143 | [group_data[0]] if isinstance(group_data[0], str) else list(group_data[0]), 144 | ) 145 | ), 146 | ) 147 | for group_data in dt_group 148 | ] 149 | all_zset_key = self.client.hmget(hkey, all_entities) 150 | if all_zset_key: 151 | data = [ # newest record 152 | self.client.zrevrangebyscore( 153 | name=zset_key, 154 | min=min_timestamp.timestamp(), 155 | max=datetime.now().timestamp(), 156 | withscores=False, 157 | start=0, 158 | num=1, 159 | ) 160 | for zset_key in all_zset_key 161 | if zset_key 162 | ] 163 | columns = list(json.loads(data[0][0]).keys()) 164 | batch_data_list = [[json.loads(data[i][0])[key] for key in columns] for i in range(len(data))] 165 | data = pd.DataFrame(data=batch_data_list, columns=columns) 166 | else: 167 | data = pd.DataFrame() 168 | return data 169 | -------------------------------------------------------------------------------- /f2ai/definitions/offline_store.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import abc 3 | import pandas as pd 4 | import datetime 5 | import os 6 | from enum import Enum 7 | from typing import Dict, Any, TYPE_CHECKING, Set, List, Optional, Union 8 | from pydantic import BaseModel 9 | 10 | if TYPE_CHECKING: 11 | from .services import Service 12 | from .sources import Source 13 | from .features import Feature 14 | from .period import Period 15 | from .constants import StatsFunctions 16 | 17 | 18 | class OfflineStoreType(str, Enum): 19 | """A constant numerate choices which is 
used to indicate how to initialize OfflineStore from configuration. If you want to add a new type of offline store, you definitely want to modify this.""" 20 | 21 | FILE = "file" 22 | PGSQL = "pgsql" 23 | SPARK = "spark" 24 | 25 | 26 | class OfflineStore(BaseModel): 27 | """An abstraction of what functionalities a OfflineStore should implements. If you want to be a one of the offline store contributor. This is the core.""" 28 | 29 | type: OfflineStoreType 30 | materialize_path: Optional[str] 31 | 32 | class Config: 33 | extra = "allow" 34 | 35 | @abc.abstractmethod 36 | def get_offline_source(self, service: Service) -> Source: 37 | """get offline materialized source with a specific service 38 | 39 | Args: 40 | service (Service): an instance of Service 41 | 42 | Returns: 43 | Source 44 | """ 45 | pass 46 | 47 | @abc.abstractmethod 48 | def get_features( 49 | self, 50 | entity_df: pd.DataFrame, 51 | features: Set[Feature], 52 | source: Source, 53 | join_keys: List[str] = [], 54 | ttl: Optional[Period] = None, 55 | include: bool = True, 56 | **kwargs, 57 | ) -> pd.DataFrame: 58 | """get features from current offline store. 59 | 60 | Args: 61 | entity_df (pd.DataFrame): A query DataFrame which include entities and event_timestamp column. 62 | features (Set[Feature]): A set of Features you want to retrieve. 63 | source (Source): A specific implementation of Source. For example, OfflinePostgresStore will receive a SqlSource which point to table with time semantic. 64 | join_keys (List[str], optional): Which columns to join the entity_df with source. Defaults to []. 65 | ttl (Optional[Period], optional): Time to Live, if feature's event_timestamp exceeds the ttl, it will be dropped. Defaults to None. 66 | include (bool, optional): If include (<=) the event_timestamp in entity_df, else (<). Defaults to True. 67 | 68 | Returns: 69 | pd.DataFrame 70 | """ 71 | pass 72 | 73 | @abc.abstractmethod 74 | def get_period_features( 75 | self, 76 | entity_df: pd.DataFrame, 77 | features: List[Feature], 78 | source: Source, 79 | period: Period, 80 | join_keys: List[str] = [], 81 | ttl: Optional[Period] = None, 82 | include: bool = True, 83 | **kwargs, 84 | ) -> pd.DataFrame: 85 | """get a period of features from offline store between [event_timestamp, event_timestamp + period] if period > 0 else [event_timestamp + period, event_timestamp]. 86 | 87 | Args: 88 | entity_df (pd.DataFrame): A query DataFrame which include entities and event_timestamp column. 89 | features (List[Feature]): A list of Features you want to retrieve. 90 | source (Source): A specific implementation of Source. For example, OfflinePostgresStore will receive a SqlSource which point to table with time semantic. 91 | period (Period): A Period instance, which wrapped by F2AI. 92 | join_keys (List[str], optional): Which columns to join the entity_df with source. Defaults to [].. Defaults to []. 93 | ttl (Optional[Period], optional): Time to Live, if feature's event_timestamp exceeds the ttl, it will be dropped. Defaults to None. 94 | include (bool, optional): If include (<=) the event_timestamp in entity_df, else (<). Defaults to True. 95 | 96 | Returns: 97 | pd.DataFrame 98 | """ 99 | pass 100 | 101 | @abc.abstractmethod 102 | def get_latest_entities( 103 | self, 104 | source: Source, 105 | join_keys: List[str] = None, 106 | group_keys: List[str] = None, 107 | entity_df: pd.DataFrame = None, 108 | start: datetime.datetime = None, 109 | ) -> pd.DataFrame: 110 | """get latest unique entities from a source. 
Which is useful when you want to know how many entities you have, or what is the latest features appear in your data source. 111 | 112 | Args: 113 | source (Source): A specific implementation of Source. For example, OfflinePostgresStore will receive a SqlSource which point to table with time semantic. 114 | join_keys (List[str], optional): Which columns to join the entity_df with source. Defaults to None. 115 | group_keys (List[str], optional): Which columns to aggregate 116 | entity_df (pd.DataFrame, optional): A query DataFrame which include entities and event_timestamp column. Defaults to None. 117 | 118 | Returns: 119 | pd.DataFrame 120 | """ 121 | pass 122 | 123 | @abc.abstractmethod 124 | def stats( 125 | self, 126 | source: Source, 127 | features: Set[Feature], 128 | fn: StatsFunctions, 129 | group_keys: List[str] = [], 130 | start: datetime.datetime = None, 131 | end: datetime.datetime = None, 132 | ) -> Union[pd.DataFrame, Dict[str, list]]: 133 | """Get statistical information with given StatsFunctions. 134 | 135 | Args: 136 | source (Source): A specific implementation of Source. For example, OfflinePostgresStore 137 | features (Set[Feature]): A set of Features you want stats on. 138 | fn (StatsFunctions): A stats function, which contains min, max, std, avg, mode, median, unique. 139 | group_keys (List[str], optional): How to group by. Defaults to []. 140 | start (datetime.datetime, optional): Defaults to None. 141 | end (datetime.datetime, optional): Defaults to None. 142 | 143 | Returns: 144 | Union[pd.DataFrame, Dict[str, list]]: Return Dict when unique, otherwise pd.DataFrame 145 | """ 146 | pass 147 | 148 | @abc.abstractmethod 149 | def query(self, query: str, *args, **kwargs) -> Any: 150 | """ 151 | Run a query with specific offline store. egg: 152 | if you are using pgsql, this will run a query via psycopg2 153 | if you are using spark, this will run a query via sparksql 154 | """ 155 | pass 156 | 157 | 158 | def init_offline_store_from_cfg(cfg: Dict[Any], project_name: str) -> OfflineStore: 159 | """Initialize an implementation of OfflineStore from yaml config. 160 | 161 | Args: 162 | cfg (Dict[Any]): a parsed config object. 163 | 164 | Returns: 165 | OfflineStore: Different types of OfflineStore. 
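    Example: a minimal sketch assuming a plain file-based configuration with no extra settings (the project name is hypothetical):

        >>> store = init_offline_store_from_cfg({"type": "file"}, project_name="my_project")
        >>> path = store.materialize_path  # defaults to ~/.f2ai/<project_name> when not configured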
166 | """ 167 | offline_store_type = OfflineStoreType(cfg["type"]) 168 | 169 | if offline_store_type == OfflineStoreType.FILE: 170 | from ..offline_stores.offline_file_store import OfflineFileStore 171 | 172 | offline_store = OfflineFileStore(**cfg) 173 | if offline_store.materialize_path is None: 174 | offline_store.materialize_path = os.path.join(os.path.expanduser("~"), '.f2ai', project_name) 175 | 176 | return offline_store 177 | 178 | if offline_store_type == OfflineStoreType.PGSQL: 179 | from ..offline_stores.offline_postgres_store import OfflinePostgresStore 180 | 181 | pgsql_conf = cfg.pop("pgsql_conf", {}) 182 | offline_store = OfflinePostgresStore(**cfg, **pgsql_conf) 183 | if offline_store.materialize_path is None: 184 | offline_store.materialize_path = f'{offline_store.database}.{offline_store.db_schema}' 185 | 186 | return offline_store 187 | 188 | if offline_store_type == OfflineStoreType.SPARK: 189 | from ..offline_stores.offline_spark_store import OfflineSparkStore 190 | 191 | return OfflineSparkStore(**cfg) 192 | 193 | raise TypeError(f"offline store type must be one of [{','.join(e.value for e in OfflineStoreType)}]") 194 | -------------------------------------------------------------------------------- /f2ai/definitions/persist_engine.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import abc 3 | import os 4 | from typing import Dict, List, Optional 5 | from pydantic import BaseModel 6 | from enum import Enum 7 | from tqdm import tqdm 8 | from multiprocessing import Pool 9 | from f2ai.definitions import ( 10 | OfflineStoreType, 11 | OfflineStore, 12 | OnlineStore, 13 | Entity, 14 | Service, 15 | FeatureView, 16 | LabelView, 17 | BackOffTime, 18 | Source, 19 | Feature, 20 | Period, 21 | ) 22 | 23 | 24 | class PersistFeatureView(BaseModel): 25 | """ 26 | Another feature view that usually used when finalizing the results. 27 | """ 28 | 29 | name: str 30 | source: Source 31 | features: List[Feature] = [] 32 | join_keys: List[str] = [] 33 | ttl: Optional[Period] 34 | 35 | 36 | class PersistLabelView(BaseModel): 37 | """ 38 | Another label view that usually used when finalizing the results. 
39 | """ 40 | 41 | source: Source 42 | labels: List[Feature] = [] 43 | join_keys: List[str] = [] 44 | 45 | 46 | class PersistEngineType(str, Enum): 47 | """A constant numerate choices which is used to indicate how to initialize PersistEngine from configuration.""" 48 | 49 | OFFLINE = "offline" 50 | ONLINE = "online" 51 | 52 | 53 | class OfflinePersistEngineType(str, Enum): 54 | """A constant numerate choices which is used to indicate how to initialize PersistEngine from configuration.""" 55 | 56 | FILE = "file" 57 | PGSQL = "pgsql" 58 | SPARK = "spark" 59 | 60 | 61 | class OnlinePersistEngineType(str, Enum): 62 | """A constant numerate choices which is used to indicate how to initialize PersistEngine from configuration.""" 63 | 64 | LOCAL = "local" 65 | DISTRIBUTE = "distribute" 66 | 67 | 68 | class PersistEngine(BaseModel): 69 | type: PersistEngineType 70 | 71 | class Config: 72 | extra = "allow" 73 | 74 | 75 | class OfflinePersistEngine(PersistEngine): 76 | type: OfflinePersistEngineType 77 | offline_store: OfflineStore 78 | 79 | class Config: 80 | extra = "allow" 81 | 82 | @abc.abstractmethod 83 | def materialize( 84 | self, 85 | feature_views: List[PersistFeatureView], 86 | label_view: PersistLabelView, 87 | destination: Source, 88 | back_off_time: BackOffTime, 89 | ): 90 | pass 91 | 92 | 93 | class OnlinePersistEngine(PersistEngine): 94 | type: OnlinePersistEngineType 95 | online_store: OnlineStore 96 | offline_store: OfflineStore 97 | 98 | class Config: 99 | extra = "allow" 100 | 101 | @abc.abstractmethod 102 | def materialize(self, prefix: str, feature_view: PersistFeatureView, back_off_time: BackOffTime): 103 | pass 104 | 105 | 106 | class RealPersistEngine(BaseModel): 107 | offline_engine: OfflinePersistEngine 108 | online_engine: OnlinePersistEngine 109 | 110 | def materialize_offline( 111 | self, 112 | services: List[Service], 113 | label_views: Dict[str, LabelView], 114 | feature_views: Dict[str, FeatureView], 115 | entities: Dict[str, Entity], 116 | sources: Dict[str, Source], 117 | back_off_time: BackOffTime, 118 | ): 119 | cpu_ava = max(os.cpu_count() // 2, 1) 120 | 121 | # with Pool(processes=cpu_ava) as pool: 122 | service_to_list_of_args = [] 123 | back_off_segments = list(back_off_time.to_units()) 124 | for service in services: 125 | destination = self.offline_engine.offline_store.get_offline_source(service) 126 | label_view = service.get_label_views(label_views)[0] 127 | label_view = PersistLabelView( 128 | source=sources[label_view.batch_source], 129 | labels=label_view.get_label_objects(), 130 | join_keys=[ 131 | join_key 132 | for entity_name in label_view.entities 133 | for join_key in entities[entity_name].join_keys 134 | ], 135 | ) 136 | label_names = set([label.name for label in label_view.labels]) 137 | 138 | feature_views = [ 139 | PersistFeatureView( 140 | name=feature_view.name, 141 | join_keys=[ 142 | join_key 143 | for entity_name in feature_view.entities 144 | for join_key in entities[entity_name].join_keys 145 | ], 146 | features=[ 147 | feature 148 | for feature in feature_view.get_feature_objects() 149 | if feature.name not in label_names 150 | ], 151 | source=sources[feature_view.batch_source], 152 | ttl=feature_view.ttl, 153 | ) 154 | for feature_view in service.get_feature_views(feature_views) 155 | ] 156 | feature_views = [feature_view for feature_view in feature_views if len(feature_view.features) > 0] 157 | service_to_list_of_args += [ 158 | (feature_views, label_view, destination, per_backoff, service.name) 159 | for per_backoff in 
back_off_segments 160 | ] 161 | 162 | bars = { 163 | service.name: tqdm( 164 | total=len(back_off_segments), desc=f"materializing {service.name}", position=pos 165 | ) 166 | for pos, service in enumerate(services) 167 | } 168 | 169 | pool = Pool(processes=cpu_ava) 170 | 171 | for param in service_to_list_of_args: 172 | pool.apply_async( 173 | self.offline_engine.materialize, 174 | param, 175 | callback=lambda x: bars[x].update(), 176 | error_callback=lambda x: print(x.__cause__), 177 | ) 178 | 179 | pool.close() 180 | pool.join() 181 | [bars[bar].close() for bar in bars] 182 | 183 | def materialize_online( 184 | self, 185 | prefix: str, 186 | feature_views: List[FeatureView], 187 | entities: Dict[str, Entity], 188 | sources: Dict[str, Source], 189 | back_off_time: BackOffTime, 190 | ): 191 | cpu_ava = max(os.cpu_count() // 2, 1) 192 | 193 | feature_backoffs = [] 194 | back_off_segments = list(back_off_time.to_units()) 195 | for feature_view in feature_views: 196 | join_keys = list( 197 | { 198 | join_key 199 | for entity_name in feature_view.entities 200 | for join_key in entities[entity_name].join_keys 201 | } 202 | ) 203 | feature_view = PersistFeatureView( 204 | name=feature_view.name, 205 | join_keys=join_keys, 206 | features=feature_view.get_feature_objects(), 207 | source=sources[feature_view.batch_source], 208 | ttl=feature_view.ttl, 209 | ) 210 | feature_backoffs += [ 211 | (prefix, feature_view, per_backoff, feature_view.name) for per_backoff in back_off_segments 212 | ] 213 | 214 | bars = { 215 | feature_view.name: tqdm( 216 | total=len(back_off_segments), desc=f"materializing {feature_view.name}", position=pos 217 | ) 218 | for pos, feature_view in enumerate(feature_views) 219 | } 220 | 221 | pool = Pool(processes=cpu_ava) 222 | 223 | for param in feature_backoffs: 224 | pool.apply_async( 225 | self.online_engine.materialize, 226 | param, 227 | callback=lambda x: bars[x].update(), 228 | error_callback=lambda x: print(x.__cause__), 229 | ) 230 | 231 | pool.close() 232 | pool.join() 233 | [bars[bar].close() for bar in bars] 234 | 235 | 236 | def init_persist_engine_from_cfg(offline_store: OfflineStore, online_store: OnlineStore): 237 | """Initialize an implementation of PersistEngine from yaml config. 238 | 239 | Args: 240 | cfg (Dict[Any]): a parsed config object. 
241 | 242 | Returns: 243 | RealPersistEngine: Contains Offline and Online PersistEngine 244 | """ 245 | 246 | if offline_store.type == OfflineStoreType.FILE: 247 | from ..persist_engine.offline_file_persistengine import OfflineFilePersistEngine 248 | from ..persist_engine.online_local_persistengine import OnlineLocalPersistEngine 249 | 250 | offline_persist_engine_cls = OfflineFilePersistEngine 251 | online_persist_engine_cls = OnlineLocalPersistEngine 252 | 253 | elif offline_store.type == OfflineStoreType.PGSQL: 254 | from ..persist_engine.offline_pgsql_persistengine import OfflinePgsqlPersistEngine 255 | from ..persist_engine.online_local_persistengine import OnlineLocalPersistEngine 256 | 257 | offline_persist_engine_cls = OfflinePgsqlPersistEngine 258 | online_persist_engine_cls = OnlineLocalPersistEngine 259 | 260 | elif offline_store.type == OfflineStoreType.SPARK: 261 | from ..persist_engine.offline_spark_persistengine import OfflineSparkPersistEngine 262 | from ..persist_engine.online_spark_persistengine import OnlineSparkPersistEngine 263 | 264 | offline_persist_engine_cls = OfflineSparkPersistEngine 265 | online_persist_engine_cls = OnlineSparkPersistEngine 266 | 267 | offline_engine = offline_persist_engine_cls(offline_store=offline_store) 268 | online_engine = online_persist_engine_cls(offline_store=offline_store, online_store=online_store) 269 | return RealPersistEngine(offline_engine=offline_engine, online_engine=online_engine) 270 | -------------------------------------------------------------------------------- /tests/units/offline_stores/offline_file_store_test.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from f2ai.offline_stores.offline_file_store import OfflineFileStore 3 | from f2ai.definitions import Period, FileSource, Feature, FeatureDTypes, StatsFunctions 4 | from f2ai.common.time_field import ENTITY_EVENT_TIMESTAMP_FIELD, SOURCE_EVENT_TIMESTAMP_FIELD 5 | 6 | import pytest 7 | from unittest.mock import MagicMock 8 | 9 | mock_point_in_time_filter_df = pd.DataFrame( 10 | { 11 | "join_key": ["A", "A", "A", "A"], 12 | ENTITY_EVENT_TIMESTAMP_FIELD: [ 13 | pd.Timestamp("2021-08-25 20:16:20"), 14 | pd.Timestamp("2021-08-25 20:16:20"), 15 | pd.Timestamp("2021-08-25 20:16:20"), 16 | pd.Timestamp("2021-08-25 20:16:20"), 17 | ], 18 | SOURCE_EVENT_TIMESTAMP_FIELD: [ 19 | pd.Timestamp("2021-08-25 20:16:18"), 20 | pd.Timestamp("2021-08-25 20:16:19"), 21 | pd.Timestamp("2021-08-25 20:16:20"), 22 | pd.Timestamp("2021-08-25 20:16:21"), 23 | ], 24 | }, 25 | ) 26 | 27 | 28 | def test_point_in_time_filter_simple(): 29 | result_df = OfflineFileStore._point_in_time_filter(mock_point_in_time_filter_df) 30 | assert len(result_df) == 3 31 | assert result_df["_source_event_timestamp_"].max() == pd.Timestamp("2021-08-25 20:16:20") 32 | 33 | 34 | def test_point_in_time_filter_not_include(): 35 | result_df = OfflineFileStore._point_in_time_filter(mock_point_in_time_filter_df, include=False) 36 | assert len(result_df) == 2 37 | assert result_df["_source_event_timestamp_"].max() == pd.Timestamp("2021-08-25 20:16:19") 38 | 39 | 40 | def test_point_in_time_filter_with_ttl(): 41 | result_df = OfflineFileStore._point_in_time_filter( 42 | mock_point_in_time_filter_df, ttl=Period.from_str("2 seconds") 43 | ) 44 | assert len(result_df) == 2 45 | assert result_df["_source_event_timestamp_"].max() == pd.Timestamp("2021-08-25 20:16:20") 46 | assert result_df["_source_event_timestamp_"].min() == pd.Timestamp("2021-08-25 20:16:19") 47 | 48 
| 49 | def test_point_on_time_filter_simple(): 50 | result_df = OfflineFileStore._point_on_time_filter( 51 | mock_point_in_time_filter_df, -Period.from_str("2 seconds"), include=True 52 | ) 53 | assert len(result_df) == 2 54 | assert result_df["_source_event_timestamp_"].max() == pd.Timestamp("2021-08-25 20:16:20") 55 | assert result_df["_source_event_timestamp_"].min() == pd.Timestamp("2021-08-25 20:16:19") 56 | 57 | 58 | def test_point_on_time_filter_simple_label(): 59 | result_df = OfflineFileStore._point_on_time_filter( 60 | mock_point_in_time_filter_df, Period.from_str("2 seconds"), include=True 61 | ) 62 | assert len(result_df) == 2 63 | assert result_df["_source_event_timestamp_"].max() == pd.Timestamp("2021-08-25 20:16:21") 64 | assert result_df["_source_event_timestamp_"].min() == pd.Timestamp("2021-08-25 20:16:20") 65 | 66 | 67 | def test_point_on_time_filter_not_include(): 68 | result_df = OfflineFileStore._point_on_time_filter( 69 | mock_point_in_time_filter_df, period=-Period.from_str("2 seconds"), include=False 70 | ) 71 | assert len(result_df) == 2 72 | assert result_df["_source_event_timestamp_"].max() == pd.Timestamp("2021-08-25 20:16:19") 73 | assert result_df["_source_event_timestamp_"].min() == pd.Timestamp("2021-08-25 20:16:18") 74 | 75 | 76 | def test_point_on_time_filter_not_include_label(): 77 | result_df = OfflineFileStore._point_on_time_filter( 78 | mock_point_in_time_filter_df, Period.from_str("2 seconds"), include=False 79 | ) 80 | assert len(result_df) == 1 81 | assert result_df["_source_event_timestamp_"].max() == pd.Timestamp("2021-08-25 20:16:21") 82 | 83 | 84 | def test_point_on_time_filter_with_ttl(): 85 | result_df = OfflineFileStore._point_on_time_filter( 86 | mock_point_in_time_filter_df, 87 | period=-Period.from_str("3 seconds"), 88 | ttl=Period.from_str("2 seconds"), 89 | include=True, 90 | ) 91 | assert len(result_df) == 2 92 | assert result_df["_source_event_timestamp_"].max() == pd.Timestamp("2021-08-25 20:16:20") 93 | assert result_df["_source_event_timestamp_"].min() == pd.Timestamp("2021-08-25 20:16:19") 94 | 95 | 96 | mock_point_in_time_latest_df = pd.DataFrame( 97 | { 98 | "join_key": ["A", "A", "B", "B"], 99 | ENTITY_EVENT_TIMESTAMP_FIELD: [ 100 | pd.Timestamp("2021-08-25 20:16:20"), 101 | pd.Timestamp("2021-08-25 20:16:20"), 102 | pd.Timestamp("2021-08-25 20:16:20"), 103 | pd.Timestamp("2021-08-25 20:16:20"), 104 | ], 105 | SOURCE_EVENT_TIMESTAMP_FIELD: [ 106 | pd.Timestamp("2021-08-25 20:16:18"), 107 | pd.Timestamp("2021-08-25 20:16:19"), 108 | pd.Timestamp("2021-08-25 20:16:11"), 109 | pd.Timestamp("2021-08-25 20:16:20"), 110 | ], 111 | }, 112 | ) 113 | 114 | 115 | def test_point_in_time_latest_with_group_keys(): 116 | result_df = OfflineFileStore._point_in_time_latest(mock_point_in_time_latest_df, ["join_key"]) 117 | 118 | df_a = result_df[result_df["join_key"] == "A"].iloc[0] 119 | df_b = result_df[result_df["join_key"] == "B"].iloc[0] 120 | 121 | assert df_a["_source_event_timestamp_"] == pd.Timestamp("2021-08-25 20:16:19") 122 | assert df_b["_source_event_timestamp_"] == pd.Timestamp("2021-08-25 20:16:20") 123 | 124 | 125 | def test_point_in_time_latest_without_group_keys(): 126 | result_df = OfflineFileStore._point_in_time_latest(mock_point_in_time_latest_df) 127 | assert result_df.loc[0]["_source_event_timestamp_"] == pd.Timestamp("2021-08-25 20:16:20") 128 | 129 | 130 | mock_source_df = pd.DataFrame( 131 | { 132 | "join_key": ["A", "A", "B", "B"], 133 | "event_timestamp": [ 134 | pd.Timestamp("2021-08-25 20:16:18"), 135 | 
pd.Timestamp("2021-08-25 20:16:19"), 136 | pd.Timestamp("2021-08-25 20:16:11"), 137 | pd.Timestamp("2021-08-25 20:16:20"), 138 | ], 139 | "feature": [1, 2, 3, 4], 140 | }, 141 | ) 142 | mock_entity_df = pd.DataFrame( 143 | { 144 | "join_key": ["A", "B", "A"], 145 | "event_timestamp": [ 146 | pd.Timestamp("2021-08-25 20:16:18"), 147 | pd.Timestamp("2021-08-25 20:16:19"), 148 | pd.Timestamp("2021-08-25 20:16:19"), 149 | ], 150 | "request_feature": [6, 5, 4], 151 | } 152 | ) 153 | 154 | 155 | def test_point_in_time_join_with_join_keys(): 156 | result_df = OfflineFileStore._point_in_time_join( 157 | mock_entity_df, mock_source_df, timestamp_field="event_timestamp", join_keys=["join_key"] 158 | ) 159 | assert len(result_df) == 3 160 | 161 | 162 | def test_point_in_time_join_with_ttl(): 163 | result_df = OfflineFileStore._point_in_time_join( 164 | mock_entity_df, 165 | mock_source_df, 166 | timestamp_field="event_timestamp", 167 | join_keys=["join_key"], 168 | ttl=Period.from_str("2 seconds"), 169 | ) 170 | assert len(result_df) == 2 171 | 172 | 173 | def test_point_in_time_join_with_extra_entities_in_source(): 174 | result_df = OfflineFileStore._point_in_time_join( 175 | pd.DataFrame( 176 | { 177 | "join_key": ["A"], 178 | "event_timestamp": [pd.Timestamp("2021-08-25 20:16:18")], 179 | "request_feature": [6], 180 | } 181 | ), 182 | mock_source_df, 183 | timestamp_field="event_timestamp", 184 | join_keys=["join_key"], 185 | ) 186 | assert len(result_df) == 1 187 | 188 | 189 | def test_point_in_time_join_with_created_timestamp(): 190 | result_df = OfflineFileStore._point_in_time_join( 191 | mock_entity_df, 192 | pd.DataFrame( 193 | { 194 | "join_key": ["A", "A"], 195 | "event_timestamp": [ 196 | pd.Timestamp("2021-08-25 20:16:18"), 197 | pd.Timestamp("2021-08-25 20:16:18"), 198 | ], 199 | "created_timestamp": [ 200 | pd.Timestamp("2021-08-25 20:16:21"), 201 | pd.Timestamp("2021-08-25 20:16:20"), 202 | ], 203 | "feature": [5, 6], 204 | }, 205 | ), 206 | timestamp_field="event_timestamp", 207 | created_timestamp_field="created_timestamp", 208 | join_keys=["join_key"], 209 | ttl=Period.from_str("2 seconds"), 210 | ) 211 | assert all(result_df["feature"] == [5, 5]) 212 | 213 | 214 | def test_point_on_time_join_with_join_keys(): 215 | result_df = OfflineFileStore._point_on_time_join( 216 | mock_entity_df, 217 | mock_source_df, 218 | period=-Period.from_str("2 seconds"), 219 | timestamp_field="event_timestamp", 220 | join_keys=["join_key"], 221 | include=False, 222 | ) 223 | assert len(result_df) == 1 224 | assert result_df["join_key"].values == "A" 225 | 226 | 227 | def test_point_on_time_join_with_join_keys_label(): 228 | result_df = OfflineFileStore._point_on_time_join( 229 | mock_entity_df, 230 | mock_source_df, 231 | period=Period.from_str("2 seconds"), 232 | timestamp_field="event_timestamp", 233 | join_keys=["join_key"], 234 | ) 235 | assert len(result_df) == 4 236 | 237 | 238 | def test_point_on_time_join_with_ttl(): 239 | result_df = OfflineFileStore._point_on_time_join( 240 | mock_entity_df, 241 | mock_source_df, 242 | timestamp_field="event_timestamp", 243 | join_keys=["join_key"], 244 | ttl=Period.from_str("2 seconds"), 245 | period=-Period.from_str("10 seconds"), 246 | ) 247 | assert len(result_df) == 3 248 | assert "B" not in result_df["join_key"].values 249 | 250 | 251 | def test_point_on_time_join_with_extra_entities_in_source(): 252 | result_df = OfflineFileStore._point_on_time_join( 253 | pd.DataFrame( 254 | { 255 | "join_key": ["A"], 256 | "event_timestamp": [pd.Timestamp("2021-08-25 
20:16:22")], 257 | "request_feature": [6], 258 | } 259 | ), 260 | mock_source_df, 261 | period=Period.from_str("3 seconds"), 262 | timestamp_field="event_timestamp", 263 | join_keys=["join_key"], 264 | ) 265 | assert len(result_df) == 0 266 | 267 | 268 | def test_point_on_time_join_with_created_timestamp(): 269 | result_df = OfflineFileStore._point_on_time_join( 270 | mock_entity_df, 271 | pd.DataFrame( 272 | { 273 | "join_key": ["A", "A", "A", "A"], 274 | "event_timestamp": [ 275 | pd.Timestamp("2021-08-25 20:16:16"), 276 | pd.Timestamp("2021-08-25 20:16:17"), 277 | pd.Timestamp("2021-08-25 20:16:18"), 278 | pd.Timestamp("2021-08-25 20:16:18"), 279 | ], 280 | "materialize_time": [ 281 | pd.Timestamp("2021-08-25 20:16:17"), 282 | pd.Timestamp("2021-08-25 20:16:19"), 283 | pd.Timestamp("2021-08-25 20:16:21"), 284 | pd.Timestamp("2021-08-25 20:16:20"), 285 | ], 286 | "feature": [3, 4, 5, 6], 287 | }, 288 | ), 289 | timestamp_field="event_timestamp", 290 | created_timestamp_field="materialize_time", 291 | join_keys=["join_key"], 292 | period=-Period.from_str("2 seconds"), 293 | ) 294 | assert all(result_df["feature"] == [4, 5, 5]) 295 | assert result_df[result_df["event_timestamp"] == pd.Timestamp("2021-08-25 20:16:18")].shape == (2, 5) 296 | assert pd.Timestamp("2021-08-25 20:16:20") not in result_df["event_timestamp"] 297 | 298 | 299 | mocked_stats_input_df = pd.DataFrame( 300 | { 301 | "join_key": ["A", "B", "A", "B"], 302 | "event_timestamp": [ 303 | pd.Timestamp("2021-08-25 20:16:16"), 304 | pd.Timestamp("2021-08-25 20:16:17"), 305 | pd.Timestamp("2021-08-25 20:16:18"), 306 | pd.Timestamp("2021-08-25 20:16:18"), 307 | ], 308 | "F1": [1, 2, 3, 4], 309 | } 310 | ) 311 | 312 | 313 | @pytest.mark.parametrize("fn", [(fn) for fn in StatsFunctions]) 314 | def test_stats(fn): 315 | store = OfflineFileStore() 316 | file_source = FileSource(name="mock", path="mock") 317 | features = [Feature(name="F1", dtype=FeatureDTypes.FLOAT, view_name="hello")] 318 | 319 | store._read_file = MagicMock(return_value=mocked_stats_input_df) 320 | 321 | result_df = store.stats( 322 | file_source, 323 | features, 324 | fn, 325 | group_keys=["join_key"], 326 | ) 327 | if fn == StatsFunctions.UNIQUE: 328 | assert ",".join(result_df.columns) == "join_key" 329 | else: 330 | assert ",".join(result_df.index.names) == "join_key" 331 | assert isinstance(result_df, pd.DataFrame) 332 | -------------------------------------------------------------------------------- /f2ai/models/nbeats/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torchmetrics import MeanAbsoluteError 4 | from typing import List, Dict, Tuple 5 | from torch.utils.data import DataLoader 6 | 7 | from f2ai.featurestore import FeatureStore 8 | from f2ai.dataset import EvenEventsSampler, NoEntitiesSampler 9 | from f2ai.common.collect_fn import nbeats_collet_fn 10 | 11 | from submodules import NBEATSGenericBlock, NBEATSSeasonalBlock, NBEATSTrendBlock, MultiEmbedding 12 | 13 | TIME_IDX = "_time_idx_" 14 | 15 | 16 | class NbeatsNetwork(nn.Module): 17 | def __init__( 18 | self, 19 | targets: str, 20 | model_type: str = "G", # 'I' 21 | num_stack: int = 1, 22 | num_block: int = 1, 23 | width: int = 8, # [2**9] 24 | expansion_coe: int = 5, # [2**5] 25 | num_block_layer: int = 4, 26 | prediction_length: int = 0, 27 | context_length: int = 0, 28 | dropout: float = 0.1, 29 | backcast_loss_ratio: float = 0.1, 30 | covariate_number: int = 0, 31 | encoder_cont: List[str] = [], 32 | 
decoder_cont: List[str] = [], 33 | embedding_sizes: Dict[str, Tuple[int, int]] = {}, 34 | x_categoricals: List[str] = [], 35 | output_size=1, 36 | ): 37 | 38 | super().__init__() 39 | self._targets = targets 40 | self._encoder_cont = encoder_cont 41 | self._decoder_cont = decoder_cont 42 | self.dropout = dropout 43 | self.backcast_loss_ratio = backcast_loss_ratio 44 | self.context_length = context_length 45 | self.prediction_length = prediction_length 46 | self.target_number = output_size 47 | self.covariate_number = covariate_number 48 | 49 | self.encoder_embeddings = MultiEmbedding( 50 | embedding_sizes=embedding_sizes, 51 | embedding_paddings=[], 52 | categorical_groups={}, 53 | x_categoricals=x_categoricals, 54 | ) 55 | 56 | if model_type == "I": 57 | width = [2**width, 2 ** (width + 2)] 58 | self.stack_types = ["trend", "seasonality"] * num_stack 59 | self.expansion_coefficient_lengths = [item for i in range(num_stack) for item in [3, 7]] 60 | self.num_blocks = [num_block for i in range(2 * num_stack)] 61 | self.num_block_layers = [num_block_layer for i in range(2 * num_stack)] 62 | self.widths = [item for i in range(num_stack) for item in width] 63 | elif model_type == "G": 64 | self.stack_types = ["generic"] * num_stack 65 | self.expansion_coefficient_lengths = [2**expansion_coe for i in range(num_stack)] 66 | self.num_blocks = [num_block for i in range(num_stack)] 67 | self.num_block_layers = [num_block_layer for i in range(num_stack)] 68 | self.widths = [2**width for i in range(num_stack)] 69 | # 70 | # setup stacks 71 | self.net_blocks = nn.ModuleList() 72 | 73 | for stack_id, stack_type in enumerate(self.stack_types): 74 | for _ in range(self.num_blocks[stack_id]): 75 | if stack_type == "generic": 76 | net_block = NBEATSGenericBlock( 77 | units=self.widths[stack_id], 78 | thetas_dim=self.expansion_coefficient_lengths[stack_id], 79 | num_block_layers=self.num_block_layers[stack_id], 80 | backcast_length=self.context_length, 81 | forecast_length=self.prediction_length, 82 | dropout=self.dropout, 83 | tar_num=self.target_number, 84 | cov_num=self.covariate_number + self.encoder_embeddings.total_embedding_size(), 85 | tar_pos=self.target_positions, 86 | ) 87 | elif stack_type == "seasonality": 88 | net_block = NBEATSSeasonalBlock( 89 | units=self.widths[stack_id], 90 | num_block_layers=self.num_block_layers[stack_id], 91 | backcast_length=self.context_length, 92 | forecast_length=self.prediction_length, 93 | min_period=self.expansion_coefficient_lengths[stack_id], 94 | dropout=self.dropout, 95 | tar_num=self.target_number, 96 | cov_num=self.covariate_number + self.encoder_embeddings.total_embedding_size(), 97 | tar_pos=self.target_positions, 98 | ) 99 | elif stack_type == "trend": 100 | net_block = NBEATSTrendBlock( 101 | units=self.widths[stack_id], 102 | thetas_dim=self.expansion_coefficient_lengths[stack_id], 103 | num_block_layers=self.num_block_layers[stack_id], 104 | backcast_length=self.context_length, 105 | forecast_length=self.prediction_length, 106 | dropout=self.dropout, 107 | tar_num=self.target_number, 108 | cov_num=self.covariate_number + self.encoder_embeddings.total_embedding_size(), 109 | tar_pos=self.target_positions, 110 | ) 111 | else: 112 | raise ValueError(f"Unknown stack type {stack_type}") 113 | 114 | self.net_blocks.append(net_block) 115 | 116 | @property 117 | def target_positions(self): 118 | return [self._encoder_cont.index(tar) for tar in self._targets] 119 | 120 | def forward(self, x: Dict[str, torch.Tensor], **kwargs) -> Dict[str, torch.Tensor]: 121 | 
""" 122 | Pass forward of network. 123 | 124 | Args: 125 | x (Dict[str, torch.Tensor]): input from dataloader generated from 126 | :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`. 127 | 128 | Returns: 129 | Dict[str, torch.Tensor]: output of model 130 | """ 131 | # batch_size * look_back * features 132 | encoder_cont = torch.cat( 133 | [x["encoder_cont"], x["encoder_time_features"], x["encoder_lag_features"]], dim=-1 134 | ) 135 | # `target` can only be continuous, so position inside `encoder_cat` is irrelevant 136 | encoder_cat = ( 137 | torch.cat([v for _, v in self.encoder_embeddings(x["encoder_cat"]).items()], dim=-1) 138 | if self.encoder_embeddings.total_embedding_size() != 0 139 | else torch.zeros( 140 | encoder_cont.size(0), 141 | self.context_length, 142 | self.encoder_embeddings.total_embedding_size(), 143 | device=encoder_cont.device, 144 | ) 145 | ) 146 | # self.hparams.prediction_length=gap+real_predict 147 | timesteps = self.context_length + self.prediction_length 148 | # encoder_cont.size(2) + self.encoder_embeddings.total_embedding_size(), 149 | generic_forecast = [ 150 | torch.zeros( 151 | (encoder_cont.size(0), timesteps, len(self.target_positions)), 152 | dtype=torch.float32, 153 | device=encoder_cont.device, 154 | ) 155 | ] 156 | 157 | trend_forecast = [ 158 | torch.zeros( 159 | (encoder_cont.size(0), timesteps, len(self.target_positions)), 160 | dtype=torch.float32, 161 | device=encoder_cont.device, 162 | ) 163 | ] 164 | seasonal_forecast = [ 165 | torch.zeros( 166 | (encoder_cont.size(0), timesteps, len(self.target_positions)), 167 | dtype=torch.float32, 168 | device=encoder_cont.device, 169 | ) 170 | ] 171 | 172 | forecast = torch.zeros( 173 | (encoder_cont.size(0), self.prediction_length, len(self.target_positions)), 174 | dtype=torch.float32, 175 | device=encoder_cont.device, 176 | ) 177 | 178 | # make sure `encoder_cont` is followed by `encoder_cat` 179 | 180 | backcast = torch.cat([encoder_cont, encoder_cat], dim=-1) 181 | 182 | for i, block in enumerate(self.net_blocks): 183 | # evaluate block 184 | backcast_block, forecast_block = block(backcast) 185 | # add for interpretation 186 | full = torch.cat([backcast_block.detach(), forecast_block.detach()], dim=1) 187 | if isinstance(block, NBEATSTrendBlock): 188 | trend_forecast.append(full) 189 | elif isinstance(block, NBEATSSeasonalBlock): 190 | seasonal_forecast.append(full) 191 | else: 192 | generic_forecast.append(full) 193 | # update backcast and forecast 194 | backcast = backcast.clone() 195 | backcast[..., self.target_positions] = backcast[..., self.target_positions] - backcast_block 196 | # do not use backcast -= backcast_block as this signifies an inline operation 197 | forecast = forecast + forecast_block 198 | 199 | # `encoder_cat` always at the end of sequence, so it will not affect `self.target_positions` 200 | # backcast, forecast is of batch_size * context_length/prediction_length * tar_num 201 | return { 202 | "prediction": forecast, 203 | "backcast": ( 204 | encoder_cont[..., self.target_positions] - backcast[..., self.target_positions], 205 | self.backcast_loss_ratio, 206 | ), 207 | } 208 | 209 | 210 | if __name__ == "__main__": 211 | TIME_COL = "event_timestamp" 212 | 213 | # data = pd.read_csv("/Users/zhao123456/Desktop/gitlab/guizhou_traffic/guizhou_traffic.csv") 214 | # data = data.iloc[:20000,:] 215 | # data.to_csv("/Users/zhao123456/Desktop/gitlab/guizhou_traffic/guizhou_traffic_20000.csv") 216 | 217 | fs = 
FeatureStore("file:///Users/zhao123456/Desktop/gitlab/guizhou_traffic") 218 | 219 | events_sampler = EvenEventsSampler( 220 | start=fs.get_latest_entities(fs.services["traval_time_prediction_embedding_v1"])[TIME_COL].min(), 221 | end=fs.get_latest_entities(fs.services["traval_time_prediction_embedding_v1"])[TIME_COL].max(), 222 | period="10 minutes", 223 | ) 224 | dataset = fs.get_dataset( 225 | service="traval_time_prediction_embedding_v1", 226 | sampler=NoEntitiesSampler(events_sampler), 227 | ) 228 | 229 | features_cat = [ 230 | fea 231 | for fea in fs._get_available_features(fs.services["traval_time_prediction_embedding_v1"]) 232 | if fea not in fs._get_available_features(fs.services["traval_time_prediction_embedding_v1"], True) 233 | ] 234 | cat_unique = fs.stats( 235 | fs.services["traval_time_prediction_embedding_v1"], 236 | fn="unique", 237 | group_key=[], 238 | start=fs.get_latest_entities(fs.services["traval_time_prediction_embedding_v1"])[TIME_COL].min(), 239 | end=fs.get_latest_entities(fs.services["traval_time_prediction_embedding_v1"])[TIME_COL].max(), 240 | features=features_cat, 241 | ).to_dict() 242 | cat_count = {key: len(cat_unique[key]) for key in cat_unique.keys()} 243 | cont_scalar_max = fs.stats( 244 | fs.services["traval_time_prediction_embedding_v1"], 245 | fn="max", 246 | group_key=[], 247 | start=fs.get_latest_entities(fs.services["traval_time_prediction_embedding_v1"])[TIME_COL].min(), 248 | end=fs.get_latest_entities(fs.services["traval_time_prediction_embedding_v1"])[TIME_COL].max(), 249 | ).to_dict() 250 | cont_scalar_min = fs.stats( 251 | fs.services["traval_time_prediction_embedding_v1"], 252 | fn="min", 253 | group_key=[], 254 | start=fs.get_latest_entities(fs.services["traval_time_prediction_embedding_v1"])[TIME_COL].min(), 255 | end=fs.get_latest_entities(fs.services["traval_time_prediction_embedding_v1"])[TIME_COL].max(), 256 | ).to_dict() 257 | cont_scalar = {key: [cont_scalar_min[key], cont_scalar_max[key]] for key in cont_scalar_min.keys()} 258 | 259 | label = fs._get_available_labels(fs.services["traval_time_prediction_embedding_v1"]) 260 | del cont_scalar[label[0]] 261 | 262 | i_ds = dataset.to_pytorch() 263 | test_data_loader = DataLoader( 264 | i_ds, 265 | collate_fn=lambda x: nbeats_collet_fn( 266 | x, 267 | cont_scalar=cont_scalar, 268 | categoricals=cat_unique, 269 | label=label, 270 | ), 271 | batch_size=8, 272 | drop_last=False, 273 | sampler=None, 274 | ) 275 | 276 | model = NbeatsNetwork( 277 | targets=label, 278 | # prediction_length= fs.services["traval_time_prediction_embedding_v1"].labels[0].period, 279 | # context_length= max([feature.period for feature in fs.services["traval_time_prediction_embedding_v1"].features if feature.period is not None]), 280 | prediction_length=30, 281 | context_length=60, 282 | covariate_number=len(cont_scalar), 283 | encoder_cont=list(cont_scalar.keys()) + label, 284 | decoder_cont=list(cont_scalar.keys()), 285 | x_categoricals=features_cat, 286 | output_size=1, 287 | ) 288 | 289 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 290 | loss_fn = MeanAbsoluteError() 291 | 292 | for epoch in range(10): # assume 10 epoch 293 | print(f"epoch: {epoch} begin") 294 | for x, y in test_data_loader: 295 | pred_label = model(x) 296 | true_label = y 297 | loss = loss_fn(pred_label, true_label) 298 | optimizer.zero_grad() 299 | loss.backward() 300 | optimizer.step() 301 | print(f"epoch: {epoch} done, loss: {loss}") 302 | -------------------------------------------------------------------------------- 
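The residual loop at the end of `NbeatsNetwork.forward` above is the doubly residual stacking of N-BEATS: each block explains part of the history window (its backcast is subtracted from the remaining input on the target channels) and contributes an additive piece of the prediction (its forecast is accumulated). A minimal standalone sketch of that update rule, using a toy stand-in block rather than the real `NBEATSGenericBlock`/`NBEATSSeasonalBlock`/`NBEATSTrendBlock` classes:

import torch

def toy_block(window):
    # stand-in for an N-BEATS block: returns (backcast, forecast) for a 3-step horizon
    backcast = 0.5 * window                                      # part of the history it "explains"
    forecast = window.mean(dim=1, keepdim=True).repeat(1, 3, 1)  # crude constant forecast
    return backcast, forecast

batch, context_length, n_targets = 8, 12, 1
backcast = torch.randn(batch, context_length, n_targets)  # history of the target channels
forecast = torch.zeros(batch, 3, n_targets)                # accumulated prediction

for _ in range(4):                                         # four stacked blocks
    backcast_block, forecast_block = toy_block(backcast)
    backcast = backcast - backcast_block                   # residual handed to the next block
    forecast = forecast + forecast_block                   # block forecasts add up

In the real model only the columns in `self.target_positions` are updated this way, and the final backcast residual is returned together with `backcast_loss_ratio` so a backcast loss can be mixed into training.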
/f2ai/models/nbeats/submodules.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple, Union 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | def linear(input_size, output_size, bias=True, dropout: int = None): 10 | lin = nn.Linear(input_size, output_size, bias=bias) 11 | if dropout is not None: 12 | return nn.Sequential(nn.Dropout(dropout), lin) 13 | else: 14 | return lin 15 | 16 | 17 | def linspace( 18 | backcast_length: int, forecast_length: int, centered: bool = False 19 | ) -> Tuple[np.ndarray, np.ndarray]: 20 | if centered: 21 | norm = max(backcast_length, forecast_length) 22 | start = -backcast_length 23 | stop = forecast_length - 1 24 | else: 25 | norm = backcast_length + forecast_length 26 | start = 0 27 | stop = backcast_length + forecast_length - 1 28 | lin_space = np.linspace(start / norm, stop / norm, backcast_length + forecast_length, dtype=np.float32) 29 | b_ls = lin_space[:backcast_length] 30 | f_ls = lin_space[backcast_length:] 31 | return b_ls, f_ls 32 | 33 | 34 | class NBEATSBlock(nn.Module): 35 | def __init__( 36 | self, 37 | units, 38 | thetas_dim, 39 | num_block_layers=4, 40 | backcast_length=10, 41 | forecast_length=5, 42 | share_thetas=False, 43 | dropout=0.1, 44 | tar_num=1, 45 | cov_num=0, 46 | tar_pos=[], 47 | ): 48 | 49 | super().__init__() 50 | self.units = units 51 | self.thetas_dim = thetas_dim 52 | self.backcast_length = backcast_length 53 | self.forecast_length = forecast_length 54 | self.share_thetas = share_thetas 55 | self.tar_num = tar_num 56 | self.cov_num = cov_num 57 | self.tar_pos = tar_pos 58 | 59 | fc_stack = [ 60 | [ 61 | nn.Linear(backcast_length, units), 62 | nn.ReLU(), 63 | ] 64 | for i in range(self.tar_num + self.cov_num) 65 | ] 66 | 67 | for i in range(self.tar_num + self.cov_num): 68 | for _ in range(num_block_layers - 1): 69 | fc_stack[i].extend([linear(units, units, dropout=dropout), nn.ReLU()]) 70 | 71 | self.fc = nn.ModuleList(nn.Sequential(*fc_stack[i]) for i in range(self.tar_num + self.cov_num)) 72 | 73 | if share_thetas: 74 | self.theta_f_fc = self.theta_b_fc = nn.ModuleList( 75 | [nn.Linear(units, thetas_dim, bias=False) for i in range(self.tar_num + self.cov_num)] 76 | ) 77 | else: 78 | 79 | self.theta_b_fc = nn.ModuleList( 80 | [nn.Linear(units, thetas_dim, bias=False) for i in range(self.tar_num + self.cov_num)] 81 | ) 82 | self.theta_f_fc = nn.ModuleList( 83 | [nn.Linear(units, thetas_dim, bias=False) for i in range(self.tar_num + self.cov_num)] 84 | ) 85 | 86 | def forward(self, x): 87 | return torch.stack([self.fc[n](x[..., n]) for n in range(self.tar_num + self.cov_num)], dim=2) 88 | # return [self.fc[n](x[...,n]) for n in range(self.tar_num+self.cov_num)] 89 | 90 | 91 | class NBEATSSeasonalBlock(NBEATSBlock): 92 | def __init__( 93 | self, 94 | units, 95 | thetas_dim=None, 96 | num_block_layers=4, 97 | backcast_length=10, 98 | forecast_length=5, 99 | nb_harmonics=None, 100 | min_period=1, 101 | dropout=0.1, 102 | tar_num=1, 103 | cov_num=0, 104 | tar_pos=[], 105 | ): 106 | if nb_harmonics: 107 | thetas_dim = nb_harmonics 108 | else: 109 | thetas_dim = forecast_length 110 | self.min_period = min_period 111 | 112 | super().__init__( 113 | units=units, 114 | thetas_dim=thetas_dim, 115 | num_block_layers=num_block_layers, 116 | backcast_length=backcast_length, 117 | forecast_length=forecast_length, 118 | share_thetas=True, 119 | dropout=dropout, 120 | tar_num=tar_num, 121 | cov_num=cov_num, 
122 | tar_pos=tar_pos, 123 | ) 124 | 125 | backcast_linspace, forecast_linspace = linspace(backcast_length, forecast_length, centered=False) 126 | 127 | p1, p2 = ( 128 | (thetas_dim // 2, thetas_dim // 2) 129 | if thetas_dim % 2 == 0 130 | else (thetas_dim // 2, thetas_dim // 2 + 1) 131 | ) 132 | # seasonal_backcast_pre_p1 seasonal_backcast_pre_p2 133 | s1_b_pre = torch.tensor( 134 | [np.cos(2 * np.pi * i * backcast_linspace) for i in self.get_frequencies(p1)], 135 | dtype=torch.float32, 136 | device=backcast_linspace.device, 137 | ) # H/2-1 138 | s2_b_pre = torch.tensor( 139 | [np.sin(2 * np.pi * i * backcast_linspace) for i in self.get_frequencies(p2)], 140 | dtype=torch.float32, 141 | device=backcast_linspace.device, 142 | ) 143 | # concat seasonal_backcast_pre_p1 and seasonal_backcast_pre_p2 144 | s_b_pre = torch.stack( 145 | [torch.cat([s1_b_pre, s2_b_pre]) for n in range(self.tar_num + self.cov_num)], 146 | dim=2, 147 | ) # p1+p2 * backlength * tarnum 148 | 149 | # seasonal_forecast_pre_p1 150 | s1_f_pre = torch.tensor( 151 | [np.cos(2 * np.pi * i * forecast_linspace) for i in self.get_frequencies(p1)], 152 | dtype=torch.float32, 153 | device=forecast_linspace.device, 154 | ) 155 | # seasonal_forecast_pre_p2 156 | s2_f_pre = torch.tensor( 157 | [np.sin(2 * np.pi * i * forecast_linspace) for i in self.get_frequencies(p2)], 158 | dtype=torch.float32, 159 | device=forecast_linspace.device, 160 | ) 161 | # concat seasonal_forecast_pre_p1 and seasonal_forecast_pre_p2 162 | s_f_pre = torch.stack( 163 | [torch.cat([s1_f_pre, s2_f_pre]) for n in range(self.tar_num + self.cov_num)], 164 | dim=2, 165 | ) # p1+p2 * forlength * tarnum 166 | 167 | # register, then can be applied as self.S_backcast and self.S_forecast 168 | self.register_buffer("S_backcast", s_b_pre) 169 | self.register_buffer("S_forecast", s_f_pre) 170 | self.agg_layer = nn.ModuleList([nn.Linear(tar_num + cov_num, 1) for i in range(tar_num)]) 171 | 172 | def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]: 173 | x = super().forward(x) 174 | 175 | self.S_backcast_final = [self.agg_layer[i](self.S_backcast).squeeze(-1) for i in range(self.tar_num)] 176 | self.S_forecast_final = [self.agg_layer[i](self.S_forecast).squeeze(-1) for i in range(self.tar_num)] 177 | 178 | amplitudes_backward = torch.stack( 179 | [self.theta_b_fc[n](x[..., n]) for n in (self.tar_num + self.cov_num)], dim=2 180 | ) 181 | 182 | backcast = torch.stack( 183 | [amplitudes_backward[..., n].mm(self.S_backcast_final) for n in self.tar_pos], dim=2 184 | ) 185 | 186 | amplitudes_forward = torch.stack( 187 | [self.theta_f_fc[n](x[..., n]) for n in (self.tar_num + self.cov_num)], dim=2 188 | ) 189 | forecast = torch.stack( 190 | [amplitudes_forward[..., n].mm(self.S_forecast_final) for n in self.tar_pos], dim=2 191 | ) 192 | 193 | return backcast, forecast # only target, not cov 194 | 195 | def get_frequencies(self, n): 196 | return np.linspace(0, (self.backcast_length + self.forecast_length) / self.min_period, n) 197 | 198 | 199 | class NBEATSTrendBlock(NBEATSBlock): 200 | def __init__( 201 | self, 202 | units, 203 | thetas_dim, 204 | num_block_layers=4, 205 | backcast_length=10, 206 | forecast_length=5, 207 | dropout=0.1, 208 | tar_num=1, 209 | cov_num=0, 210 | tar_pos=[], 211 | ): 212 | super().__init__( 213 | units=units, 214 | thetas_dim=thetas_dim, 215 | num_block_layers=num_block_layers, 216 | backcast_length=backcast_length, 217 | forecast_length=forecast_length, 218 | share_thetas=True, 219 | dropout=dropout, 220 | tar_num=tar_num, 221 | 
cov_num=cov_num, 222 | tar_pos=tar_pos, 223 | ) 224 | 225 | backcast_linspace, forecast_linspace = linspace(backcast_length, forecast_length, centered=True) 226 | norm = np.sqrt(forecast_length / thetas_dim) # ensure range of predictions is comparable to input 227 | 228 | # backcast 229 | coefficients_pre_b = torch.cat( 230 | [ 231 | torch.tensor( 232 | [backcast_linspace**i for i in range(thetas_dim)], 233 | dtype=torch.float32, 234 | device=backcast_linspace.device, 235 | ).unsqueeze(2) 236 | for n in range(self.tar_num + self.cov_num) 237 | ], 238 | 2, 239 | ) 240 | # forecast 241 | coefficients_pre_f = torch.cat( 242 | [ 243 | torch.tensor( 244 | [forecast_linspace**i for i in range(thetas_dim)], 245 | dtype=torch.float32, 246 | device=forecast_linspace.device, 247 | ).unsqueeze(2) 248 | for n in range(self.tar_num + self.cov_num) 249 | ], 250 | 2, 251 | ) 252 | 253 | # register, then can be applied as self.T_backcast and self.T_forecast 254 | self.register_buffer("T_backcast", coefficients_pre_b * norm) 255 | self.register_buffer("T_forecast", coefficients_pre_f * norm) 256 | 257 | self.agg_layer = nn.ModuleList([nn.Linear(tar_num + cov_num, 1) for i in range(tar_num)]) 258 | 259 | def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]: 260 | x = super().forward(x) 261 | 262 | self.T_backcast_final = [self.agg_layer[i](self.T_backcast).squeeze(-1) for i in range(self.tar_num)] 263 | self.T_forecast_final = [self.agg_layer[i](self.T_forecast).squeeze(-1) for i in range(self.tar_num)] 264 | 265 | backcast = torch.stack( 266 | [self.theta_b_fc[n](x[..., n]).mm(self.T_backcast_final[n]) for n in self.tar_pos], dim=2 267 | ) 268 | forecast = torch.stack( 269 | [self.theta_f_fc[n](x[..., n]).mm(self.T_forecast_final[n]) for n in self.tar_pos], dim=2 270 | ) 271 | return backcast, forecast 272 | 273 | 274 | class NBEATSGenericBlock(NBEATSBlock): 275 | def __init__( 276 | self, 277 | units, 278 | thetas_dim, 279 | num_block_layers=4, 280 | backcast_length=10, 281 | forecast_length=5, 282 | dropout=0.1, 283 | tar_num=1, 284 | cov_num=0, 285 | tar_pos=[], 286 | ): 287 | 288 | super().__init__( 289 | units=units, 290 | thetas_dim=thetas_dim, 291 | num_block_layers=num_block_layers, 292 | backcast_length=backcast_length, 293 | forecast_length=forecast_length, 294 | dropout=dropout, 295 | tar_num=tar_num, 296 | cov_num=cov_num, 297 | tar_pos=tar_pos, 298 | ) 299 | 300 | self.backcast_fc = nn.ModuleList( 301 | [nn.Linear(thetas_dim, backcast_length) for i in range(self.tar_num)] 302 | ) 303 | self.forecast_fc = nn.ModuleList( 304 | [nn.Linear(thetas_dim, forecast_length) for i in range(self.tar_num)] 305 | ) 306 | 307 | self.agg_layer = nn.ModuleList([nn.Linear(tar_num + cov_num, 1) for i in range(tar_num)]) 308 | 309 | def forward(self, x): 310 | x = super().forward(x) 311 | 312 | theta_bs = torch.cat( 313 | [F.relu(self.theta_b_fc[n](x[..., n])).unsqueeze(2) for n in range(self.tar_num + self.cov_num)], 314 | 2, 315 | ) # encode x thetas_dim x n_output(n_tar+n_cov) 316 | theta_fs = torch.cat( 317 | [F.relu(self.theta_f_fc[n](x[..., n])).unsqueeze(2) for n in range(self.tar_num + self.cov_num)], 318 | 2, 319 | ) # encode x thetas_dim x n_output 320 | 321 | # lengths = n_target 322 | theta_b = [self.agg_layer[i](theta_bs).squeeze(-1) for i in range(self.tar_num)] 323 | theta_f = [self.agg_layer[i](theta_fs).squeeze(-1) for i in range(self.tar_num)] 324 | 325 | return ( 326 | torch.cat( 327 | [self.backcast_fc[i](theta_b[i]).unsqueeze(2) for i in range(self.tar_num)], 328 | 2, 329 | ), 330 | 
torch.cat( 331 | [self.forecast_fc[i](theta_f[i]).unsqueeze(2) for i in range(self.tar_num)], 332 | 2, 333 | ), 334 | ) 335 | 336 | 337 | class TimeDistributedEmbeddingBag(nn.EmbeddingBag): 338 | def __init__(self, *args, batch_first: bool = False, **kwargs): 339 | super().__init__(*args, **kwargs) 340 | self.batch_first = batch_first 341 | 342 | def forward(self, x): 343 | 344 | if len(x.size()) <= 2: 345 | return super().forward(x) 346 | 347 | # Squash samples and timesteps into a single axis 348 | x_reshape = x.contiguous().view(-1, x.size(-1)) # (samples * timesteps, input_size) 349 | 350 | y = super().forward(x_reshape) 351 | 352 | # We have to reshape Y 353 | if self.batch_first: 354 | y = y.contiguous().view(x.size(0), -1, y.size(-1)) # (samples, timesteps, output_size) 355 | else: 356 | y = y.view(-1, x.size(1), y.size(-1)) # (timesteps, samples, output_size) 357 | return y 358 | 359 | 360 | class MultiEmbedding(nn.Module): 361 | def __init__( 362 | self, 363 | embedding_sizes: Dict[str, Tuple[int, int]], 364 | categorical_groups: Dict[str, List[str]], 365 | embedding_paddings: List[str], 366 | x_categoricals: List[str], 367 | max_embedding_size: int = None, 368 | ): 369 | super().__init__() 370 | self.embedding_sizes = embedding_sizes 371 | self.categorical_groups = categorical_groups 372 | self.embedding_paddings = embedding_paddings 373 | self.max_embedding_size = max_embedding_size 374 | self.x_categoricals = x_categoricals 375 | self.init_embeddings() 376 | 377 | def init_embeddings(self): 378 | self.embeddings = nn.ModuleDict() 379 | for name in self.embedding_sizes.keys(): 380 | embedding_size = self.embedding_sizes[name][1] 381 | if self.max_embedding_size is not None: 382 | embedding_size = min(embedding_size, self.max_embedding_size) 383 | # convert to list to become mutable 384 | self.embedding_sizes[name] = list(self.embedding_sizes[name]) 385 | self.embedding_sizes[name][1] = embedding_size 386 | if name in self.categorical_groups: # embedding bag if related embeddings 387 | self.embeddings[name] = TimeDistributedEmbeddingBag( 388 | self.embedding_sizes[name][0], embedding_size, mode="sum", batch_first=True 389 | ) 390 | else: 391 | if name in self.embedding_paddings: 392 | padding_idx = 0 393 | else: 394 | padding_idx = None 395 | self.embeddings[name] = nn.Embedding( 396 | self.embedding_sizes[name][0], 397 | embedding_size, 398 | padding_idx=padding_idx, 399 | ) 400 | 401 | def total_embedding_size(self) -> int: 402 | return sum([size[1] for size in self.embedding_sizes.values()]) 403 | 404 | def names(self) -> List[str]: 405 | return list(self.keys()) 406 | 407 | def items(self): 408 | return self.embeddings.items() 409 | 410 | def keys(self) -> List[str]: 411 | return self.embeddings.keys() 412 | 413 | def values(self): 414 | return self.embeddings.values() 415 | 416 | def __getitem__(self, name: str) -> Union[nn.Embedding, TimeDistributedEmbeddingBag]: 417 | return self.embeddings[name] 418 | 419 | def forward(self, x: torch.Tensor, flat: bool = False) -> Union[Dict[str, torch.Tensor], torch.Tensor]: 420 | out = {} 421 | for name, emb in self.embeddings.items(): 422 | if name in self.categorical_groups: 423 | out[name] = emb( 424 | x[ 425 | ..., 426 | [self.x_categoricals.index(cat_name) for cat_name in self.categorical_groups[name]], 427 | ] 428 | ) 429 | else: 430 | out[name] = emb(x[..., self.x_categoricals.index(name)]) 431 | if flat: 432 | out = torch.cat([v for v in out.values()], dim=-1) 433 | return out 434 | 
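`MultiEmbedding` is what `NbeatsNetwork` plugs in as `encoder_embeddings`: every categorical column named in `x_categoricals` gets its own `nn.Embedding` (or a `TimeDistributedEmbeddingBag` when columns are grouped), and `total_embedding_size()` is the extra width concatenated onto the continuous inputs. A small usage sketch, assuming the `MultiEmbedding` class above is in scope and using a made-up column name and sizes:

import torch

emb = MultiEmbedding(
    embedding_sizes={"link_id": (100, 8)},  # hypothetical: 100 categories -> 8-dim embedding
    categorical_groups={},
    embedding_paddings=[],
    x_categoricals=["link_id"],
)

x_cat = torch.randint(0, 100, (4, 12, 1))  # (batch, timesteps, n_categorical_columns)
out = emb(x_cat)                           # dict of column name -> embedded tensor
assert out["link_id"].shape == (4, 12, 8)
assert emb.total_embedding_size() == 8     # width added next to the continuous features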
-------------------------------------------------------------------------------- /use_cases/credit_score/credit_score.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "df64f867", 6 | "metadata": {}, 7 | "source": [ 8 | "Example: Credit Score Classification\n" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "7e99a81f", 14 | "metadata": {}, 15 | "source": [ 16 | "`FeatureStore` is a model-agnostic tool aiming to help data scientists and algorithm engineers get rid of tiring data storing and merging tasks.\n", 17 | "
`FeatureStore` not only works on single-dimension data such as classification and prediction, but also on time-series data.\n", 18 | "
After collecting data, all you need to do is configure several straightforward .yml files; then you can focus on models/algorithms and leave all the exhausting preparation to `FeatureStore`.
\n", 19 | "

Here we present a credit scoring task as a single-dimension data demo: it uses features such as wages and loan records to decide whether or not to grant credit.
" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "id": "1eebee68", 25 | "metadata": {}, 26 | "source": [ 27 | "Import packages" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 1, 33 | "id": "abc14dec", 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "import torch\n", 38 | "import os\n", 39 | "import numpy as np\n", 40 | "import zipfile\n", 41 | "import tempfile\n", 42 | "from torch import nn\n", 43 | "from torch.utils.data import DataLoader\n", 44 | "from f2ai.featurestore import FeatureStore\n", 45 | "from f2ai.common.sampler import GroupFixednbrSampler\n", 46 | "from f2ai.common.collect_fn import classify_collet_fn\n", 47 | "from f2ai.common.utils import get_bucket_from_oss_url\n", 48 | "from f2ai.models.earlystop import EarlyStopping \n", 49 | "from f2ai.models.sequential import SimpleClassify" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "id": "6d14ac45", 55 | "metadata": {}, 56 | "source": [ 57 | "Download demo project files from `OSS` " 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 2, 63 | "id": "eda81c81", 64 | "metadata": {}, 65 | "outputs": [ 66 | { 67 | "name": "stdout", 68 | "output_type": "stream", 69 | "text": [ 70 | "Project downloaded and saved in /tmp/558u5bgc/xyz_test_data\n" 71 | ] 72 | } 73 | ], 74 | "source": [ 75 | "download_from = \"oss://aiexcelsior-shanghai-test/xyz_test_data/credit-score.zip\"\n", 76 | "save_path = '/tmp/'\n", 77 | "save_dir = tempfile.mkdtemp(prefix=save_path)\n", 78 | "bucket, key = get_bucket_from_oss_url(download_from)\n", 79 | "dest_zip_filepath = os.path.join(save_dir,key)\n", 80 | "os.makedirs(os.path.dirname(dest_zip_filepath), exist_ok=True)\n", 81 | "bucket.get_object_to_file(key, dest_zip_filepath)\n", 82 | "zipfile.ZipFile(dest_zip_filepath).extractall(dest_zip_filepath.rsplit('/',1)[0])\n", 83 | "os.remove(dest_zip_filepath)\n", 84 | "print(f\"Project downloaded and saved in {dest_zip_filepath.rsplit('/',1)[0]}\")" 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "id": "a738d2da", 90 | "metadata": {}, 91 | "source": [ 92 | "Initialize `FeatureStore`" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 3, 98 | "id": "cbdaf53b", 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "TIME_COL = 'event_timestamp'\n", 103 | "fs = FeatureStore(f\"file://{save_dir}/{key.rstrip('.zip')}\")" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 4, 109 | "id": "7cbff283", 110 | "metadata": {}, 111 | "outputs": [ 112 | { 113 | "name": "stdout", 114 | "output_type": "stream", 115 | "text": [ 116 | "All features are: ['person_home_ownership', 'location_type', 'mortgage_due', 'tax_returns_filed', 'population', 'person_emp_length', 'loan_amnt', 'person_age', 'missed_payments_2y', 'state', 'total_wages', 'hard_pulls', 'loan_int_rate', 'missed_payments_6m', 'student_loan_due', 'city', 'person_income', 'loan_intent', 'bankruptcies', 'vehicle_loan_due', 'missed_payments_1y', 'credit_card_due']\n" 117 | ] 118 | } 119 | ], 120 | "source": [ 121 | "print(f\"All features are: {fs._get_feature_to_use(fs.services['credit_scoring_v1'])}\")" 122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "id": "792ba5e1", 127 | "metadata": {}, 128 | "source": [ 129 | "Get the time range of available data" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 5, 135 | "id": "8eb8697f", 136 | "metadata": {}, 137 | "outputs": [ 138 | { 139 | "name": "stdout", 140 | "output_type": "stream", 141 | 
"text": [ 142 | "Earliest timestamp: 2020-08-25 20:34:41.361000+00:00\n", 143 | "Latest timestamp: 2021-08-25 20:34:41.361000+00:00\n" 144 | ] 145 | } 146 | ], 147 | "source": [ 148 | "print(f'Earliest timestamp: {fs.get_latest_entities(\"credit_scoring_v1\")[TIME_COL].min()}')\n", 149 | "print(f'Latest timestamp: {fs.get_latest_entities(\"credit_scoring_v1\")[TIME_COL].max()}')" 150 | ] 151 | }, 152 | { 153 | "cell_type": "markdown", 154 | "id": "13419833", 155 | "metadata": {}, 156 | "source": [ 157 | "Split the train / valid / test data at approximately 7/2/1, use `GroupFixednbrSampler` to downsample original data and return a `torch.IterableDataset`" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 6, 163 | "id": "ec10ddd5", 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "ds_train = fs.get_dataset(\n", 168 | " service=\"credit_scoring_v1\",\n", 169 | " sampler=GroupFixednbrSampler(\n", 170 | " time_bucket=\"5 days\",\n", 171 | " stride=1,\n", 172 | " group_ids=None,\n", 173 | " group_names=None,\n", 174 | " start=\"2020-08-20\",\n", 175 | " end=\"2021-04-30\",\n", 176 | " ),\n", 177 | " )\n", 178 | "ds_valid = fs.get_dataset(\n", 179 | " service=\"credit_scoring_v1\",\n", 180 | " sampler=GroupFixednbrSampler(\n", 181 | " time_bucket=\"5 days\",\n", 182 | " stride=1,\n", 183 | " group_ids=None,\n", 184 | " group_names=None,\n", 185 | " start=\"2021-04-30\",\n", 186 | " end=\"2021-07-31\",\n", 187 | " ),\n", 188 | " )\n", 189 | "ds_test= fs.get_dataset(\n", 190 | " service=\"credit_scoring_v1\",\n", 191 | " sampler=GroupFixednbrSampler(\n", 192 | " time_bucket=\"1 days\",\n", 193 | " stride=1,\n", 194 | " group_ids=None,\n", 195 | " group_names=None,\n", 196 | " start=\"2021-07-31\",\n", 197 | " end=\"2021-08-31\",\n", 198 | " ),\n", 199 | " )" 200 | ] 201 | }, 202 | { 203 | "cell_type": "markdown", 204 | "id": "e4d8d35f", 205 | "metadata": {}, 206 | "source": [ 207 | "Using `FeatureStore.stats` to obtain `statistical results` for data processing" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": 7, 213 | "id": "9b5b9ced", 214 | "metadata": {}, 215 | "outputs": [ 216 | { 217 | "name": "stdout", 218 | "output_type": "stream", 219 | "text": [ 220 | "Number of unique values of categorical features are: {'person_home_ownership': 4, 'location_type': 1, 'state': 51, 'city': 8166, 'loan_intent': 6}\n" 221 | ] 222 | } 223 | ], 224 | "source": [ 225 | "# catgorical features\n", 226 | "features_cat = [ \n", 227 | " fea\n", 228 | " for fea in fs.services[\"credit_scoring_v1\"].get_feature_names(fs.feature_views)\n", 229 | " if fea not in fs.services[\"credit_scoring_v1\"].get_feature_names(fs.feature_views,is_numeric=True)\n", 230 | "]\n", 231 | "# get unique item number to do labelencoder\n", 232 | "cat_unique = fs.stats(\n", 233 | " \"credit_scoring_v1\",\n", 234 | " fn=\"unique\",\n", 235 | " group_key=[],\n", 236 | " start=\"2020-08-01\",\n", 237 | " end=\"2021-04-30\",\n", 238 | " features=features_cat,\n", 239 | ").to_dict()\n", 240 | "cat_count = {key: len(cat_unique[key]) for key in cat_unique.keys()}\n", 241 | "print(f\"Number of unique values of categorical features are: {cat_count}\")" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": 8, 247 | "id": "8a0d6e00", 248 | "metadata": {}, 249 | "outputs": [ 250 | { 251 | "name": "stdout", 252 | "output_type": "stream", 253 | "text": [ 254 | "Min-Max boundary of continuous features are: {'population': [272.0, 88503.0], 'missed_payments_6m': 
[0.0, 1.0], 'person_age': [20.0, 144.0], 'missed_payments_2y': [0.0, 7.0], 'student_loan_due': [0.0, 49997.0], 'missed_payments_1y': [0.0, 3.0], 'bankruptcies': [0.0, 2.0], 'mortgage_due': [33.0, 1999896.0], 'vehicle_loan_due': [1.0, 29998.0], 'total_wages': [0.0, 2132869892.0], 'tax_returns_filed': [250.0, 47778.0], 'credit_card_due': [0.0, 9998.0], 'person_emp_length': [0.0, 41.0], 'hard_pulls': [0.0, 10.0], 'loan_int_rate': [5.42, 23.22], 'loan_amnt': [500.0, 35000.0], 'person_income': [4000.0, 6000000.0]}\n" 255 | ] 256 | } 257 | ], 258 | "source": [ 259 | "# contiouns features \n", 260 | "cont_scalar_max = fs.stats(\n", 261 | " \"credit_scoring_v1\", fn=\"max\", group_key=[], start=\"2020-08-01\", end=\"2021-04-30\"\n", 262 | ").to_dict()\n", 263 | "cont_scalar_min = fs.stats(\n", 264 | " \"credit_scoring_v1\", fn=\"min\", group_key=[], start=\"2020-08-01\", end=\"2021-04-30\"\n", 265 | ").to_dict()\n", 266 | "cont_scalar = {key: [cont_scalar_min[key], cont_scalar_max[key]] for key in cont_scalar_min.keys()}\n", 267 | "print(f\"Min-Max boundary of continuous features are: {cont_scalar}\")" 268 | ] 269 | }, 270 | { 271 | "cell_type": "markdown", 272 | "id": "6a2a08f5", 273 | "metadata": {}, 274 | "source": [ 275 | "Construct `torch.DataLoader` from `torch.IterableDataset` for modelling\n", 276 | "
Here we perform the data preprocessing inside `collect_fn`, so the time range of the `statistical results` used to `.fit()` must correspond to the `train` data only, so as to avoid information leakage.
" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": 9, 282 | "id": "1379abba", 283 | "metadata": {}, 284 | "outputs": [], 285 | "source": [ 286 | "batch_size=16\n", 287 | "\n", 288 | "train_dataloader = DataLoader( \n", 289 | " ds_train.to_pytorch(),\n", 290 | " collate_fn=lambda x: classify_collet_fn(\n", 291 | " x,\n", 292 | " cat_coder=cat_unique,\n", 293 | " cont_scalar=cont_scalar,\n", 294 | " label=fs._get_feature_to_use(fs.services[\"credit_scoring_v1\"].get_label_view(fs.label_views)),\n", 295 | " ),\n", 296 | " batch_size=batch_size,\n", 297 | " drop_last=True,\n", 298 | ")\n", 299 | "\n", 300 | "valie_dataloader = DataLoader( \n", 301 | " ds_valid.to_pytorch(),\n", 302 | " collate_fn=lambda x: classify_collet_fn(\n", 303 | " x,\n", 304 | " cat_coder=cat_unique,\n", 305 | " cont_scalar=cont_scalar,\n", 306 | " label=fs._get_feature_to_use(fs.services[\"credit_scoring_v1\"].get_label_view(fs.label_views)),\n", 307 | " ),\n", 308 | " batch_size=batch_size,\n", 309 | " drop_last=False,\n", 310 | ")\n", 311 | "\n", 312 | "test_dataloader = DataLoader( \n", 313 | " ds_valid.to_pytorch(),\n", 314 | " collate_fn=lambda x: classify_collet_fn(\n", 315 | " x,\n", 316 | " cat_coder=cat_unique,\n", 317 | " cont_scalar=cont_scalar,\n", 318 | " label=fs._get_feature_to_use(fs.services[\"credit_scoring_v1\"].get_label_view(fs.label_views)),\n", 319 | " ),\n", 320 | " drop_last=False,\n", 321 | ")" 322 | ] 323 | }, 324 | { 325 | "cell_type": "markdown", 326 | "id": "1fd7cb71", 327 | "metadata": {}, 328 | "source": [ 329 | "Customize `model`, `optimizer` and `loss` function suitable to task" 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": 10, 335 | "id": "4a9eb8be", 336 | "metadata": {}, 337 | "outputs": [], 338 | "source": [ 339 | "model = SimpleClassify(\n", 340 | " cont_nbr=len(cont_scalar_max), cat_nbr=len(cat_count), emd_dim=8, max_types=max(cat_count.values()),hidden_size=4\n", 341 | ")\n", 342 | "optimizer = torch.optim.Adam(model.parameters(), lr=0.001) \n", 343 | "loss_fn = nn.BCELoss() " 344 | ] 345 | }, 346 | { 347 | "cell_type": "markdown", 348 | "id": "bb956ec5", 349 | "metadata": {}, 350 | "source": [ 351 | "Use `train_dataloader` to train while `valie_dataloader` to guide `earlystop`" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": 12, 357 | "id": "e3321113", 358 | "metadata": {}, 359 | "outputs": [ 360 | { 361 | "name": "stdout", 362 | "output_type": "stream", 363 | "text": [ 364 | "epoch: 0 done, train loss: 0.9228071888287862, valid loss: 0.8917452792326609\n", 365 | "epoch: 1 done, train loss: 0.8228938817977905, valid loss: 0.806285967429479\n", 366 | "epoch: 2 done, train loss: 0.7159536878267924, valid loss: 0.715582956870397\n", 367 | "epoch: 3 done, train loss: 0.6077397664388021, valid loss: 0.6279712617397308\n", 368 | "epoch: 4 done, train loss: 0.523370603720347, valid loss: 0.5640117128690084\n", 369 | "epoch: 5 done, train loss: 0.46912830471992495, valid loss: 0.530453751484553\n", 370 | "epoch: 6 done, train loss: 0.4288410405317942, valid loss: 0.5177373588085175\n", 371 | "epoch: 7 done, train loss: 0.4263931175072988, valid loss: 0.5144786983728409\n", 372 | "epoch: 8 done, train loss: 0.4069450835386912, valid loss: 0.5147957305113474\n", 373 | "EarlyStopping counter: 1 out of 5\n", 374 | "epoch: 9 done, train loss: 0.4173213442166646, valid loss: 0.5159803281227747\n", 375 | "EarlyStopping counter: 2 out of 5\n", 376 | "epoch: 10 done, train loss: 0.43354902466138207, valid 
loss: 0.5168851017951965\n", 377 | "EarlyStopping counter: 3 out of 5\n", 378 | "epoch: 11 done, train loss: 0.3854812800884247, valid loss: 0.5179851104815801\n", 379 | "EarlyStopping counter: 4 out of 5\n", 380 | "epoch: 12 done, train loss: 0.3868887344996134, valid loss: 0.5189551711082458\n", 381 | "EarlyStopping counter: 5 out of 5\n", 382 | "Trigger earlystop, stop epoch at 12\n" 383 | ] 384 | } 385 | ], 386 | "source": [ 387 | "# you can also use any ready-to-use training frame like ignite, pytorch-lightening...\n", 388 | "early_stop = EarlyStopping(save_path=f\"{save_dir}/{key.rstrip('.zip')}\",patience=5,delta=1e-6)\n", 389 | "for epoch in range(50):\n", 390 | " train_loss = []\n", 391 | " valid_loss = []\n", 392 | " \n", 393 | " model.train()\n", 394 | " for x, y in train_dataloader:\n", 395 | " pred_label = model(x)\n", 396 | " true_label = y\n", 397 | " loss = loss_fn(pred_label, true_label)\n", 398 | " train_loss.append(loss.item())\n", 399 | " optimizer.zero_grad()\n", 400 | " loss.backward()\n", 401 | " optimizer.step()\n", 402 | " \n", 403 | " model.eval()\n", 404 | " for x, y in valie_dataloader:\n", 405 | " pred_label = model(x)\n", 406 | " true_label = y\n", 407 | " loss = loss_fn(pred_label, true_label)\n", 408 | " valid_loss.append(loss.item())\n", 409 | "\n", 410 | " print(f\"epoch: {epoch} done, train loss: {np.mean(train_loss)}, valid loss: {np.mean(valid_loss)}\")\n", 411 | " early_stop(np.mean(valid_loss),model)\n", 412 | " if early_stop.early_stop:\n", 413 | " print(f\"Trigger earlystop, stop epoch at {epoch}\")\n", 414 | " break" 415 | ] 416 | }, 417 | { 418 | "cell_type": "markdown", 419 | "id": "0f15ea7e", 420 | "metadata": {}, 421 | "source": [ 422 | "Get prediction result of `test_dataloader`" 423 | ] 424 | }, 425 | { 426 | "cell_type": "code", 427 | "execution_count": 13, 428 | "id": "d87afe0c", 429 | "metadata": {}, 430 | "outputs": [], 431 | "source": [ 432 | "model = torch.load(os.path.join(f\"{save_dir}/{key.rstrip('.zip')}\",'best_chekpnt.pk'))\n", 433 | "model.eval()\n", 434 | "preds=[]\n", 435 | "trues=[]\n", 436 | "for x,y in test_dataloader:\n", 437 | " pred = model(x)\n", 438 | " pred_label = 1 if pred.cpu().detach().numpy() >0.5 else 0\n", 439 | " preds.append(pred_label)\n", 440 | " trues.append(y.cpu().detach().numpy())" 441 | ] 442 | }, 443 | { 444 | "cell_type": "markdown", 445 | "id": "2d939513", 446 | "metadata": {}, 447 | "source": [ 448 | "Model Evaluation" 449 | ] 450 | }, 451 | { 452 | "cell_type": "code", 453 | "execution_count": 14, 454 | "id": "f88810eb", 455 | "metadata": {}, 456 | "outputs": [ 457 | { 458 | "name": "stdout", 459 | "output_type": "stream", 460 | "text": [ 461 | "Accuracy: 0.7956989247311828\n" 462 | ] 463 | } 464 | ], 465 | "source": [ 466 | "# accuracy\n", 467 | "acc = [1 if preds[i]==trues[i] else 0 for i in range(len(trues))]\n", 468 | "print(f\"Accuracy: {np.sum(acc) / len(acc)}\")" 469 | ] 470 | } 471 | ], 472 | "metadata": { 473 | "kernelspec": { 474 | "display_name": "autonn-3.8.7", 475 | "language": "python", 476 | "name": "python3" 477 | }, 478 | "language_info": { 479 | "codemirror_mode": { 480 | "name": "ipython", 481 | "version": 3 482 | }, 483 | "file_extension": ".py", 484 | "mimetype": "text/x-python", 485 | "name": "python", 486 | "nbconvert_exporter": "python", 487 | "pygments_lexer": "ipython3", 488 | "version": "3.8.12" 489 | }, 490 | "vscode": { 491 | "interpreter": { 492 | "hash": "5840e4ed671345474330e8fce6ab52c58896a3935f0e728b8dbef1ddfad82808" 493 | } 494 | } 495 | }, 496 | "nbformat": 4, 
497 | "nbformat_minor": 5 498 | } 499 | --------------------------------------------------------------------------------
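The statistics computed in the notebook above (`cat_unique` from the `unique` stat and the `[min, max]` pairs in `cont_scalar`, both taken over the train range only) are exactly what the collate function needs to encode and scale every split consistently. Conceptually the preprocessing looks like the sketch below; the row layout, dictionary keys and return format here are illustrative assumptions, not necessarily what `classify_collet_fn` in `f2ai.common.collect_fn` actually does:

import torch

def toy_classify_collate(rows, cat_coder, cont_scalar, label):
    # rows: list of dicts mapping feature name -> raw value (assumed layout)
    cats, conts, ys = [], [], []
    for row in rows:
        # label-encode categoricals against the unique values seen in the train range
        cats.append([list(cat_coder[c]).index(row[c]) for c in cat_coder])
        # min-max scale continuous features with the train-range bounds
        conts.append([(row[c] - lo) / (hi - lo + 1e-9) for c, (lo, hi) in cont_scalar.items()])
        ys.append([float(row[c]) for c in label])
    x = {
        "categorical": torch.tensor(cats, dtype=torch.long),
        "continuous": torch.tensor(conts, dtype=torch.float32),
    }
    return x, torch.tensor(ys, dtype=torch.float32)

Fitting these encoders on the full history and reusing them for validation or test would leak future information into training, which is why the notebook restricts every `fs.stats` call to the 2020-08-01 / 2021-04-30 window.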