├── docs
├── static
│ ├── .gitkeep
│ └── f2ai_architecture.png
├── templates
│ └── .gitkeep
├── requirements.txt
├── index.rst
├── f2ai_api.rst
├── Makefile
├── make.bat
└── conf.py
├── Makefile
├── f2ai
├── cmd
│ ├── __init__.py
│ └── main.py
├── __init__.py
├── __main__.py
├── example
│ └── jaffle_shop_example.py
├── offline_stores
│ └── offline_spark_store.py
├── common
│ ├── jinja.py
│ ├── time_field.py
│ ├── read_file.py
│ ├── cmd_parser.py
│ ├── get_config.py
│ ├── oss_utils.py
│ ├── collect_fn.py
│ └── utils.py
├── persist_engine
│ ├── offline_spark_persistengine.py
│ ├── online_spark_persistengine.py
│ ├── online_local_persistengine.py
│ ├── offline_file_persistengine.py
│ └── offline_pgsql_persistengine.py
├── definitions
│ ├── constants.py
│ ├── feature_view.py
│ ├── dtypes.py
│ ├── label_view.py
│ ├── base_view.py
│ ├── entities.py
│ ├── __init__.py
│ ├── backoff_time.py
│ ├── sources.py
│ ├── online_store.py
│ ├── period.py
│ ├── features.py
│ ├── services.py
│ ├── offline_store.py
│ └── persist_engine.py
├── dataset
│ ├── __init__.py
│ ├── dataset.py
│ ├── events_sampler.py
│ ├── pytorch_dataset.py
│ └── entities_sampler.py
├── models
│ ├── sequential.py
│ ├── normalizer.py
│ ├── encoder.py
│ ├── earlystop.py
│ └── nbeats
│ │ ├── model.py
│ │ └── submodules.py
└── online_stores
│ └── online_redis_store.py
├── tests
├── fixtures
│ ├── constants.py
│ ├── git_utils.py
│ ├── credit_score_fixtures.py
│ └── guizhou_traffic_fixtures.py
├── units
│ ├── offline_stores
│ │ ├── postgres_sqls
│ │ │ ├── stats_query_unique.sql
│ │ │ ├── stats_query_avg.sql
│ │ │ ├── stats_query_max.sql
│ │ │ ├── stats_query_min.sql
│ │ │ ├── stats_query_std.sql
│ │ │ ├── stats_query_mode.sql
│ │ │ ├── stats_query_median.sql
│ │ │ ├── store_stats_query_categorical.sql
│ │ │ └── store_stats_query_numeric.sql
│ │ ├── offline_postgres_store_test.py
│ │ └── offline_file_store_test.py
│ ├── definitions
│ │ ├── entities_test.py
│ │ ├── features_test.py
│ │ ├── offline_store_test.py
│ │ └── back_off_time_test.py
│ ├── common
│ │ └── utils_test.py
│ ├── service_test.py
│ ├── dataset
│ │ ├── events_sampler_test.py
│ │ └── entities_sampler_test.py
│ └── period_test.py
├── conftest.py
└── integrations
│ └── benchmarks
│ ├── offline_pgsql_benchmark_test.py
│ └── offline_file_benchmark_test.py
├── setup.cfg
├── requirements.txt
├── dev_requirements.txt
├── DockerFile
├── .readthedocs.yaml
├── .gitignore
├── setup.py
├── .github
└── workflows
│ └── python-publish.yml
├── README.rst
└── use_cases
├── guizhou_traffic_arima.py
├── .ipynb_checkpoints
└── credit_score-checkpoint.ipynb
└── credit_score
└── credit_score.ipynb
/docs/static/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/templates/.gitkeep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx-rtd-theme
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | redis:
2 | docker run -d --rm --name redis -p 6379:6379 redis
--------------------------------------------------------------------------------
/f2ai/cmd/__init__.py:
--------------------------------------------------------------------------------
1 | from .main import main
2 |
3 | __all__ = ["main"]
4 |
--------------------------------------------------------------------------------
/tests/fixtures/constants.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | TEMP_DIR = os.path.expanduser("~/.f2ai")
4 |
--------------------------------------------------------------------------------
/f2ai/__init__.py:
--------------------------------------------------------------------------------
1 | from .featurestore import FeatureStore
2 |
3 | __all__ = ["FeatureStore"]
4 |
--------------------------------------------------------------------------------
/tests/units/offline_stores/postgres_sqls/stats_query_unique.sql:
--------------------------------------------------------------------------------
1 | SELECT DISTINCT zipcode FROM "zipcode_table"
--------------------------------------------------------------------------------
/f2ai/__main__.py:
--------------------------------------------------------------------------------
1 | #! python
2 |
3 | if __name__ == "__main__":
4 | from .cmd import main
5 |
6 | main()
7 |
--------------------------------------------------------------------------------
/docs/static/f2ai_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ai-excelsior/F2AI/HEAD/docs/static/f2ai_architecture.png
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore = E203,E501,E722,E731,W503,W605
3 | max-line-length = 110
4 |
5 | [tool:pytest]
6 | log_cli=true
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | pytest_plugins = [
2 | "fixtures.credit_score_fixtures",
3 | "fixtures.guizhou_traffic_fixtures",
4 | ]
5 |
--------------------------------------------------------------------------------
/tests/units/offline_stores/postgres_sqls/stats_query_avg.sql:
--------------------------------------------------------------------------------
1 | SELECT "zipcode",AVG(population) "population" FROM "zipcode_table" GROUP BY "zipcode"
--------------------------------------------------------------------------------
/tests/units/offline_stores/postgres_sqls/stats_query_max.sql:
--------------------------------------------------------------------------------
1 | SELECT "zipcode",MAX(population) "population" FROM "zipcode_table" GROUP BY "zipcode"
--------------------------------------------------------------------------------
/tests/units/offline_stores/postgres_sqls/stats_query_min.sql:
--------------------------------------------------------------------------------
1 | SELECT "zipcode",MIN(population) "population" FROM "zipcode_table" GROUP BY "zipcode"
--------------------------------------------------------------------------------
/tests/units/offline_stores/postgres_sqls/stats_query_std.sql:
--------------------------------------------------------------------------------
1 | SELECT "zipcode",STD(population) "population" FROM "zipcode_table" GROUP BY "zipcode"
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pandas==1.4.0
2 | pyyaml==6.0
3 | psycopg2-binary==2.9.5
4 | SQLAlchemy==1.4.31
5 | pypika==0.48.9
6 | pydantic==1.10.2
7 | oss2==2.16.0
8 | Jinja2
--------------------------------------------------------------------------------
/tests/units/offline_stores/postgres_sqls/stats_query_mode.sql:
--------------------------------------------------------------------------------
1 | SELECT "zipcode",MODE() WITHIN GROUP (ORDER BY population) as population FROM "zipcode_table" GROUP BY "zipcode"
--------------------------------------------------------------------------------
/dev_requirements.txt:
--------------------------------------------------------------------------------
1 | pandas==1.4.0
2 | pyyaml==6.0
3 | psycopg2-binary==2.9.5
4 | torch==1.11.0
5 | SQLAlchemy==1.4.31
6 | pypika==0.48.9
7 | pydantic==1.10.2
8 | oss2==2.16.0
9 | Jinja2
--------------------------------------------------------------------------------
/tests/units/offline_stores/postgres_sqls/stats_query_median.sql:
--------------------------------------------------------------------------------
1 | SELECT "zipcode",PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY population) as population FROM "zipcode_table" GROUP BY "zipcode"
--------------------------------------------------------------------------------
/tests/units/definitions/entities_test.py:
--------------------------------------------------------------------------------
1 | from f2ai.definitions import Entity
2 |
3 |
4 | def test_entity_auto_join_keys():
5 | entity = Entity(name="foo")
6 | assert entity.join_keys == ["foo"]
7 |
--------------------------------------------------------------------------------
/tests/units/offline_stores/postgres_sqls/store_stats_query_categorical.sql:
--------------------------------------------------------------------------------
1 | SELECT DISTINCT zipcode FROM "zipcode_table" WHERE "event_timestamp">='2017-01-01T00:00:00' AND "event_timestamp"<='2018-01-01T00:00:00'
--------------------------------------------------------------------------------
/f2ai/example/jaffle_shop_example.py:
--------------------------------------------------------------------------------
1 | from f2ai.featurestore import FeatureStore
2 |
3 |
4 | feature_store = FeatureStore(
5 | project_folder="/Users/xuyizhou/Desktop/xyz_warehouse/gitlab/f2ai-credit-scoring"
6 | )
7 |
--------------------------------------------------------------------------------
/f2ai/offline_stores/offline_spark_store.py:
--------------------------------------------------------------------------------
1 | from ..definitions import OfflineStore, OfflineStoreType
2 |
3 |
4 | class OfflineSparkStore(OfflineStore):
5 | type: OfflineStoreType = OfflineStoreType.SPARK
6 | pass
7 |
--------------------------------------------------------------------------------
/f2ai/common/jinja.py:
--------------------------------------------------------------------------------
1 | import os
2 | from jinja2 import Environment, FileSystemLoader, select_autoescape
3 |
4 | jinja_env = Environment(
5 |     loader=FileSystemLoader(os.path.join(os.path.dirname(__file__), "../templates")),
6 |     autoescape=select_autoescape,
7 | )
--------------------------------------------------------------------------------
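A quick, hedged sketch of how the shared `jinja_env` defined above (f2ai/common/jinja.py) is typically used; the template name and render variable below are hypothetical, while `Environment.get_template` and `Template.render` are standard Jinja2 API:

from f2ai.common.jinja import jinja_env

# "stats_query.sql" and "table_name" are made-up names for illustration only.
template = jinja_env.get_template("stats_query.sql")
sql = template.render(table_name="zipcode_table")
print(sql)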
/tests/units/offline_stores/postgres_sqls/store_stats_query_numeric.sql:
--------------------------------------------------------------------------------
1 | SELECT "zipcode",AVG(population) "population" FROM "zipcode_table" WHERE "event_timestamp">='2017-01-01T00:00:00' AND "event_timestamp"<='2018-01-01T00:00:00' GROUP BY "zipcode"
--------------------------------------------------------------------------------
/DockerFile:
--------------------------------------------------------------------------------
1 | FROM python:3.8
2 |
3 | WORKDIR /aie-feast
4 | RUN pip install --upgrade pip
5 |
6 | COPY ./requirements.txt /aie-feast/
7 | RUN pip install -i https://mirrors.aliyun.com/pypi/simple -r requirements.txt
8 |
9 | COPY ./aie-feast /aie-feast
--------------------------------------------------------------------------------
/f2ai/persist_engine/offline_spark_persistengine.py:
--------------------------------------------------------------------------------
1 | from ..definitions import OfflinePersistEngine, OfflinePersistEngineType
2 |
3 |
4 | class OfflineSparkPersistEngine(OfflinePersistEngine):
 5 |     type: OfflinePersistEngineType = OfflinePersistEngineType.SPARK
6 |
--------------------------------------------------------------------------------
/f2ai/persist_engine/online_spark_persistengine.py:
--------------------------------------------------------------------------------
1 | from f2ai.definitions import OnlinePersistEngine, OnlinePersistEngineType
2 |
3 |
4 | class OnlineSparkPersistEngine(OnlinePersistEngine):
 5 |     type: OnlinePersistEngineType = OnlinePersistEngineType.DISTRIBUTE
6 |
--------------------------------------------------------------------------------
/f2ai/common/time_field.py:
--------------------------------------------------------------------------------
1 | DEFAULT_EVENT_TIMESTAMP_FIELD = "event_timestamp"
2 | ENTITY_EVENT_TIMESTAMP_FIELD = "_entity_event_timestamp_"
3 | SOURCE_EVENT_TIMESTAMP_FIELD = "_source_event_timestamp_"
4 | QUERY_COL = "query_timestamp"
5 | TIME_COL = "event_timestamp"
6 | MATERIALIZE_TIME = "materialize_time"
7 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | build:
4 | os: "ubuntu-20.04"
5 | tools:
6 | python: "3.8"
7 |
8 | sphinx:
9 | configuration: ./docs/conf.py
10 | fail_on_warning: true
11 |
12 | python:
13 | install:
14 | - requirements: requirements.txt
15 | - requirements: docs/requirements.txt
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. toctree::
2 | :hidden:
3 | :caption: Getting Started
4 |
5 | self
6 |
7 | .. toctree::
8 | :hidden:
9 | :caption: MISCELLANEOUS
10 |
11 | f2ai_api
12 | genindex
13 |
14 | Welcome to F2AI's documentation!
15 | ================================
16 |
17 | Hello World
--------------------------------------------------------------------------------
/f2ai/definitions/constants.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | from enum import Enum
3 |
4 | LOCAL_TIMEZONE = datetime.datetime.now().astimezone().tzinfo
5 |
6 |
7 | class StatsFunctions(Enum):
8 | MIN = "min"
9 | MAX = "max"
10 | STD = "std"
11 | AVG = "avg"
12 | MODE = "mode"
13 | MEDIAN = "median"
14 | UNIQUE = "unique"
15 |
--------------------------------------------------------------------------------
/tests/units/common/utils_test.py:
--------------------------------------------------------------------------------
1 | from f2ai.common.utils import batched
2 |
3 |
4 | def test_batched():
5 | xs = range(10)
6 | batches = next(batched(xs, batch_size=3))
7 | assert batches == [0, 1, 2]
8 |
9 | last_batch = []
10 | for batch in batched(xs, batch_size=3):
11 | last_batch = batch
12 | assert last_batch == [9]
13 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.log
2 | *.pyc
3 | *.bk
4 | *.pdf
5 | *.h5
6 | *.dump
7 | .env
8 | *.pk
9 | *.test.py
10 | !aiefeast
11 | **/*.pyc
12 | **/**/*.pyc
13 | .vscode
14 | .DS_Store
15 | *.pid
16 | **/*log.*
17 | f2ai-credit-scoring-main
18 | guizhou_traffic-main
19 | guizhou_traffic-ver_pgsql
20 | docs/_*
21 | /_*
22 | 1f/
23 | 3f/
24 | scripts/
25 | egg*
26 | *.egg-info
27 | /build
--------------------------------------------------------------------------------
/tests/units/definitions/features_test.py:
--------------------------------------------------------------------------------
1 | from f2ai.definitions import SchemaAnchor
2 |
3 |
4 | def test_parse_cfg_to_feature_anchor():
5 | feature_anchor = SchemaAnchor.from_str("fv1:f1")
6 | assert feature_anchor.view_name == "fv1"
7 | assert feature_anchor.schema_name == "f1"
8 | assert feature_anchor.period is None
9 |
10 | feature_anchor = SchemaAnchor.from_str("fv1:f1:1 day")
11 | assert feature_anchor.period == "1 day"
12 |
--------------------------------------------------------------------------------
/tests/units/definitions/offline_store_test.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | from f2ai.definitions import init_offline_store_from_cfg
3 |
4 | offline_file_store_yaml_string = """
5 | type: file
6 | """
7 |
8 |
9 | def test_init_file_offline_store():
10 | from f2ai.offline_stores.offline_file_store import OfflineFileStore
11 |
12 | offline_store = init_offline_store_from_cfg(yaml.safe_load(offline_file_store_yaml_string), 'test')
13 | assert isinstance(offline_store, OfflineFileStore)
14 |
--------------------------------------------------------------------------------
/tests/units/definitions/back_off_time_test.py:
--------------------------------------------------------------------------------
1 | from f2ai.definitions import BackOffTime
2 |
3 |
4 | def test_back_off_time_to_units():
5 | back_off_time = BackOffTime(start="2020-08-01 12:20", end="2020-10-30", step="1 month")
6 | units = list(back_off_time.to_units())
7 | assert len(units) == 3
8 |
9 | back_off_time = BackOffTime(start="2020-08-01 12:20", end="2020-08-02", step="2 hours")
10 | units = list(back_off_time.to_units())
11 | assert len(units) == 6
12 |
--------------------------------------------------------------------------------
/docs/f2ai_api.rst:
--------------------------------------------------------------------------------
1 | API Reference
2 | ==================
3 |
4 |
5 | featurestore
6 | ------------------------------
7 |
8 | .. automodule:: f2ai.featurestore
9 | :members:
10 | :undoc-members:
11 | :show-inheritance:
12 |
13 |
14 | dataset
15 | ------------------------------
16 |
17 | .. automodule:: f2ai.dataset
18 | :members:
19 | :undoc-members:
20 | :show-inheritance:
21 |
22 |
23 | definitions
24 | -----------------------------
25 |
26 | .. automodule:: f2ai.definitions
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
--------------------------------------------------------------------------------
/tests/units/service_test.py:
--------------------------------------------------------------------------------
1 | from f2ai.definitions import SchemaAnchor, FeatureSchema, FeatureView, Service
2 |
3 |
4 | def test_get_features_from_service():
5 | service = Service(
6 | name="foo",
7 | features=[
8 | SchemaAnchor(view_name="fv", schema_name="*"),
9 | SchemaAnchor(view_name="fv", schema_name="foo"),
10 | ],
11 | )
12 | feature_views = {"fv": FeatureView(name="fv", schema=[FeatureSchema(name="f1", dtype="string")])}
13 |
14 | features = service.get_feature_objects(feature_views)
15 |
16 | assert len(features) == 1
17 | assert list(features)[0].name == "f1"
18 |
--------------------------------------------------------------------------------
/f2ai/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset import Dataset
2 | from .pytorch_dataset import TorchIterableDataset
3 | from .events_sampler import (
4 | EventsSampler,
5 | EvenEventsSampler,
6 | RandomNEventsSampler,
7 | )
8 | from .entities_sampler import (
9 | EntitiesSampler,
10 | NoEntitiesSampler,
11 | EvenEntitiesSampler,
12 | FixedNEntitiesSampler,
13 | )
14 |
15 | __all__ = [
16 | "Dataset",
17 | "TorchIterableDataset",
18 | "EventsSampler",
19 | "EvenEventsSampler",
20 | "RandomNEventsSampler",
21 | "EntitiesSampler",
22 | "NoEntitiesSampler",
23 | "EvenEntitiesSampler",
24 | "FixedNEntitiesSampler",
25 | ]
26 |
--------------------------------------------------------------------------------
/tests/units/dataset/events_sampler_test.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from f2ai.dataset import (
3 | EvenEventsSampler,
4 | RandomNEventsSampler,
5 | )
6 |
7 |
8 | def test_even_events_sampler():
9 | sampler = EvenEventsSampler(start="2022-10-02", end="2022-12-02", period="1 day")
10 |
11 | assert len(sampler()) == 62
12 | assert next(iter(sampler)) == pd.Timestamp("2022-10-02 00:00:00")
13 |
14 |
15 | def test_random_n_events_sampler():
16 | sampler = RandomNEventsSampler(start="2022-10-02", end="2022-12-02", period="1 day", n=2, random_state=666)
17 |
18 | assert len(sampler()) == 2
19 | assert next(iter(sampler)) == pd.Timestamp("2022-10-05 00:00:00")
20 |
--------------------------------------------------------------------------------
/f2ai/definitions/feature_view.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 |
4 | from .base_view import BaseView
5 | from .features import Feature
6 |
7 |
8 | class FeatureView(BaseView):
9 | def get_feature_names(self) -> List[str]:
10 | return [feature.name for feature in self.schemas]
11 |
12 | def get_feature_objects(self, is_numeric=False) -> List[Feature]:
13 | return list(
14 | dict.fromkeys(
15 | [
16 | Feature.create_feature_from_schema(schema, self.name)
17 | for schema in self.schemas
18 | if (schema.is_numeric() if is_numeric else True)
19 | ]
20 | )
21 | )
22 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/f2ai/definitions/dtypes.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 |
4 | class FeatureDTypes(str, Enum):
5 | """Feature data type definitions which supported by F2AI. Used to convert to a certain data type for different algorithm frameworks. Not useful now, but future."""
6 |
7 | INT = "int"
8 | INT32 = "int32"
9 | INT64 = "int64"
10 | FLOAT = "float"
11 | FLOAT32 = "float32"
12 | FLOAT64 = "float64"
13 | STRING = "string"
14 | BOOLEAN = "bool"
15 | UNKNOWN = "unknown"
16 |
17 |
18 | NUMERIC_FEATURE_D_TYPES = {
19 | FeatureDTypes.INT,
20 | FeatureDTypes.INT32,
21 | FeatureDTypes.INT64,
22 | FeatureDTypes.FLOAT,
23 | FeatureDTypes.FLOAT32,
24 | FeatureDTypes.FLOAT64,
25 | }
26 |
--------------------------------------------------------------------------------
/f2ai/definitions/label_view.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List
2 |
3 | from .base_view import BaseView
4 | from .features import Feature
5 |
6 |
7 | class LabelView(BaseView):
8 | request_source: Optional[str]
9 |
10 | def get_label_names(self):
11 | return [label.name for label in self.schemas]
12 |
13 | def get_label_objects(self, is_numeric=False) -> List[Feature]:
14 | return list(
15 | dict.fromkeys(
16 | [
17 | Feature.create_label_from_schema(schema, self.name)
18 | for schema in self.schemas
19 | if (schema.is_numeric() if is_numeric else True)
20 | ]
21 | )
22 | )
23 |
--------------------------------------------------------------------------------
/tests/fixtures/git_utils.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 |
3 |
4 | def git_clone(cwd: str, repo: str, branch: str):
5 | return subprocess.run(
6 | [
7 | "git",
8 | "clone",
9 | "--branch",
10 | branch,
11 | repo,
12 | cwd,
13 | ],
14 | check=False,
15 | )
16 |
17 |
18 | def git_reset(cwd: str, branch: str, mode: str = "hard"):
19 | return subprocess.run(["git", "reset", f"--{mode}", branch], cwd=cwd, check=False)
20 |
21 |
22 | def git_clean(cwd: str):
23 | return subprocess.run(["git", "clean", "-df"], cwd=cwd, check=False)
24 |
25 |
26 | def git_pull(cwd: str, branch: str):
27 | return subprocess.run(["git", "pull", "--rebase", "origin", branch], cwd=cwd, check=False)
28 |
--------------------------------------------------------------------------------
/f2ai/definitions/base_view.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Dict, Any
2 | from pydantic import BaseModel, Field
3 |
4 | from .features import FeatureSchema
5 | from .period import Period
6 |
7 |
8 | class BaseView(BaseModel):
9 | """Abstraction of common part of FeatureView and LabelView."""
10 |
11 | name: str
12 | description: Optional[str]
13 | entities: List[str] = []
14 | schemas: List[FeatureSchema] = Field(alias="schema", default=[])
15 | batch_source: Optional[str]
16 | ttl: Optional[Period]
17 | tags: Dict[str, str] = {}
18 |
19 | def __init__(__pydantic_self__, **data: Any) -> None:
20 | if isinstance(data.get("ttl", None), str):
21 | data["ttl"] = Period.from_str(data.get("ttl", None))
22 | super().__init__(**data)
23 |
--------------------------------------------------------------------------------
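A hedged sketch of how the BaseView fields above fit together when constructing a FeatureView; the view, source, and schema names are invented for illustration, while the `schema` alias and the string-to-Period `ttl` conversion follow the class definition above:

from f2ai.definitions import FeatureView

view = FeatureView(
    name="zipcode_features",                            # hypothetical view name
    entities=["zipcode"],
    schema=[{"name": "population", "dtype": "int64"}],  # coerced into FeatureSchema via the "schema" alias
    batch_source="zipcode_table",                       # hypothetical source name
    ttl="10 years",                                     # a string is converted to a Period by BaseView.__init__
)
print(view.get_feature_names())  # ["population"]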
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_namespace_packages
2 | from pip._internal.req import parse_requirements
3 | from pip._internal.network.session import PipSession
4 |
5 | install_requirements = parse_requirements("requirements.txt", session=PipSession())
6 |
7 | setup(
8 | name="f2ai",
9 | version="0.0.4",
10 | description="A Feature Store tool focus on making retrieve features easily in machine learning.",
11 | url="https://github.com/ai-excelsior/F2AI",
12 | author="上海半见",
13 | license="MIT",
14 | zip_safe=False,
15 | packages=find_namespace_packages(exclude=["docs", "tests", "tests.*", "use_cases", "*.egg-info"]),
16 | install_requires=[str(x.requirement) for x in install_requirements],
17 | entry_points={"console_scripts": ["f2ai = f2ai.cmd:main"]},
18 | )
19 |
--------------------------------------------------------------------------------
/f2ai/definitions/entities.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 | from typing import List, Any, Optional
3 |
4 |
5 | class Entity(BaseModel):
6 | """An entity is a key connection between different feature views. Under the hook, we use join_keys to join the feature views. If join key is empty, we will take the name as a default join key"""
7 |
8 | name: str
9 | description: Optional[str]
10 | join_keys: List[str] = []
11 |
12 | def __init__(__pydantic_self__, **data: Any) -> None:
13 |
14 | join_keys = data.pop("join_keys", [])
15 | if len(join_keys) == 0:
16 | join_keys = [data.get("name")]
17 |
18 | super().__init__(**data, join_keys=join_keys)
19 |
20 | def __hash__(self) -> int:
21 | s = ",".join(self.join_keys)
22 | return hash(f"{self.name}:{s}")
23 |
--------------------------------------------------------------------------------
/f2ai/common/read_file.py:
--------------------------------------------------------------------------------
1 | # from .oss_utils import get_bucket_from_oss_url
2 | from typing import Dict
3 | from .utils import remove_prefix
4 | import yaml
5 |
6 |
7 | def read_yml(url: str) -> Dict:
8 | """read .yml file for following execute
9 |
10 | Args:
11 | url (str): url of .yml
12 | """
13 | file = _read_file(url)
14 | cfg = yaml.load(file, Loader=yaml.FullLoader)
15 | return cfg
16 |
17 |
18 | def _read_file(url):
19 | if url.startswith("file://"):
20 | with open(remove_prefix(url, "file://"), "r") as file:
21 | return file.read()
22 | # default behavior
23 | else:
24 | with open(url, "r") as file:
25 | return file.read()
26 | # elif url.startswith("oss://"): # TODO: may not be correct
27 | # bucket, key = get_bucket_from_oss_url(url)
28 | # return bucket.get_object(key).read()
29 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | %SPHINXBUILD% >NUL 2>NUL
14 | if errorlevel 9009 (
15 | echo.
16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17 | echo.installed, then set the SPHINXBUILD environment variable to point
18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
19 | echo.may add the Sphinx directory to PATH.
20 | echo.
21 | echo.If you don't have Sphinx installed, grab it from
22 | echo.https://www.sphinx-doc.org/
23 | exit /b 1
24 | )
25 |
26 | if "%1" == "" goto help
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
37 | livehtml:
38 | sphinx-autobuild "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/f2ai/common/cmd_parser.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 |
4 | def add_materialize_parser(subparsers):
5 | parser: ArgumentParser = subparsers.add_parser(
6 | "materialize", help="materialize a service to offline (default) or online"
7 | )
8 |
9 | parser.add_argument(
10 | "services",
11 | type=str,
12 | nargs="+",
13 | help="at least one service name, multi service name using space to separate.",
14 | )
15 | parser.add_argument(
16 | "--online", action="store_true", help="materialize service to online store if presents."
17 | )
18 | parser.add_argument("--fromnow", type=str, help="materialize start time point from now, egg: 7 days.")
19 | parser.add_argument(
20 | "--start", type=str, help="materialize start time point, egg: 2022-10-22, or 2022-11-22T10:12."
21 | )
22 | parser.add_argument(
23 | "--end", type=str, help="materialize end time point, egg: 2022-10-22, or 2022-11-22T10:12."
24 | )
25 | parser.add_argument("--step", type=str, default="1 day", help="how to split materialize task")
26 | parser.add_argument("--tz", type=str, default=None, help="timezone of start and end, default to None")
27 |
--------------------------------------------------------------------------------
/f2ai/dataset/dataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import TYPE_CHECKING
3 |
4 | if TYPE_CHECKING:
5 | from ..definitions.services import Service
6 | from ..featurestore import FeatureStore
7 | from .pytorch_dataset import TorchIterableDataset
8 |
9 |
10 | class Dataset:
11 | """
12 |     A dataset is an abstraction which holds a service and a sampler.
13 |     A service basically tells us where the data is. A sampler tells us which parts of the data should be retrieved.
14 | 
15 |     Note: do not construct a Dataset yourself; using `store.get_dataset()` is recommended.
16 | """
17 |
18 | def __init__(
19 | self,
20 | fs: "FeatureStore",
21 | service: "Service",
22 | sampler: callable,
23 | ):
24 | self.fs = fs
25 | self.service = service
26 | self.sampler = sampler
27 |
28 | def to_pytorch(self, chunk_size: int = 64) -> "TorchIterableDataset":
29 | """convert to iterable pytorch dataset really hold data"""
30 | from .pytorch_dataset import TorchIterableDataset
31 |
32 |         return TorchIterableDataset(
33 |             feature_store=self.fs, service=self.service, sampler=self.sampler, chunk_size=chunk_size
34 |         )
35 | 
--------------------------------------------------------------------------------
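A minimal sketch of the intended flow around the Dataset class above, assuming `FeatureStore.get_dataset(service=..., sampler=...)` takes a service name and a sampler as the docstring suggests; the service name and project location are placeholders:

from f2ai import FeatureStore
from f2ai.dataset import EvenEventsSampler, NoEntitiesSampler

store = FeatureStore("file://.")  # assumes a project with feature_store.yml in the current folder
events = EvenEventsSampler(start="2022-10-02", end="2022-12-02", period="1 day")

dataset = store.get_dataset(service="traffic_forecast", sampler=NoEntitiesSampler(events))  # hypothetical service name
torch_dataset = dataset.to_pytorch(chunk_size=64)  # TorchIterableDataset, ready for a DataLoader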
/f2ai/models/sequential.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import numpy as np
4 |
5 |
6 | class SimpleClassify(nn.Module):
7 | def __init__(self, cont_nbr, cat_nbr, emd_dim, max_types, hidden_size=16, drop_out=0.1) -> None:
8 | super().__init__()
9 | # num_embeddings not less than type
10 | self.categorical_embedding = nn.Embedding(num_embeddings=max_types, embedding_dim=emd_dim)
11 | hidden_list = 2 ** np.arange(max(np.ceil(np.log2(hidden_size)), 2) + 1)[::-1]
12 | model_list = [nn.Linear(cont_nbr + cat_nbr * emd_dim, int(hidden_list[0]))]
13 | for i in range(len(hidden_list) - 1):
14 | model_list.append(nn.Dropout(drop_out))
15 | model_list.append(nn.Linear(int(hidden_list[i]), int(hidden_list[i + 1])))
16 |
17 | self.model = nn.Sequential(*model_list, nn.Sigmoid())
18 |
19 | def forward(self, x):
20 | cat_vector = []
21 | for i in range(x["categorical_features"].shape[-1]):
22 | cat = self.categorical_embedding(x["categorical_features"][..., i])
23 | cat_vector.append(cat)
24 | cat_vector = torch.cat(cat_vector, dim=-1)
25 | input_vector = torch.cat([cat_vector, x["continous_features"]], dim=-1)
26 | return self.model(input_vector)
27 |
--------------------------------------------------------------------------------
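A small sketch showing the batch layout SimpleClassify (above) expects; the tensor sizes are illustrative and the dict keys follow the forward method, including its `continous_features` spelling:

import torch
from f2ai.models.sequential import SimpleClassify

# 3 continuous features, 2 categorical features embedded into 4 dims, category vocabulary of 10.
model = SimpleClassify(cont_nbr=3, cat_nbr=2, emd_dim=4, max_types=10)
batch = {
    "categorical_features": torch.randint(0, 10, (8, 2)),  # integer category ids
    "continous_features": torch.rand(8, 3),
}
probs = model(batch)  # sigmoid outputs of shape (8, 1)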
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | push:
13 | tags:
14 | - release
15 |
16 | permissions:
17 | contents: read
18 |
19 | jobs:
20 | deploy:
21 |
22 | runs-on: ubuntu-latest
23 |
24 | steps:
25 | - uses: actions/checkout@v3
26 | - name: Set up Python
27 | uses: actions/setup-python@v3
28 | with:
29 | python-version: '3.x'
30 | - name: Install dependencies
31 | run: |
32 | python -m pip install --upgrade pip
33 | pip install setuptools wheel twine
34 | - name: build
35 | run: |
36 | python setup.py sdist bdist_wheel
37 | - name: Publish package
38 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
39 | with:
40 | user: __token__
41 | password: ${{ secrets.PYPI_API_TOKEN }}
42 |
--------------------------------------------------------------------------------
/tests/fixtures/credit_score_fixtures.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import os
3 |
4 | from .git_utils import git_clean, git_clone, git_reset, git_pull
5 | from .constants import TEMP_DIR
6 |
7 | CREDIT_SCORE_CFG = {
8 | "repo": "git@code.unianalysis.com:f2ai-examples/f2ai-credit-scoring.git",
9 | "infra": {
10 | "file": {
11 | "cwd": os.path.join(TEMP_DIR, "f2ai-credit-scoring_file"),
12 | "branch": "main",
13 | },
14 | "pgsql": {
15 | "cwd": os.path.join(TEMP_DIR, "f2ai-credit-scoring_pgsql"),
16 | "branch": "ver_pgsql",
17 | },
18 | },
19 | }
20 |
21 |
22 | @pytest.fixture(scope="session")
23 | def make_credit_score():
24 | if not os.path.exists(TEMP_DIR):
25 | os.makedirs(TEMP_DIR)
26 |
27 | def get_credit_score(infra="file"):
28 | repo = CREDIT_SCORE_CFG["repo"]
29 | infra = CREDIT_SCORE_CFG["infra"][infra]
30 | cwd = infra["cwd"]
31 | branch = infra["branch"]
32 |
33 | # clone repo to cwd
34 | if not os.path.isdir(cwd):
35 | git_clone(cwd, repo, branch)
36 |
37 | # reset, clean and update to latest
38 | git_reset(cwd, branch)
39 | git_clean(cwd)
40 | git_pull(cwd, branch)
41 |
42 | return cwd
43 |
44 | yield get_credit_score
45 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # For the full list of built-in configuration values, see the documentation:
4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
5 |
6 | # -- Project information -----------------------------------------------------
7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
8 | import os
9 | import sys
10 |
11 | docs_dir = os.path.dirname(os.path.realpath(__file__))
12 | project_dir = os.path.realpath(os.path.join(docs_dir, ".."))
13 | sys.path.append(project_dir)
14 | print(f'Append sys path: {project_dir}')
15 |
16 | project = "F2AI"
17 | copyright = "2022, eavae"
18 | author = "eavae"
19 |
20 | # -- General configuration ---------------------------------------------------
21 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
22 |
23 | extensions = ["sphinx.ext.napoleon", "sphinx.ext.autodoc"]
24 | templates_path = ["templates"]
25 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
26 | html_sidebars = {"**": ["globaltoc.html"]}
27 |
28 | # -- Options for HTML output -------------------------------------------------
29 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
30 |
31 | html_theme = "sphinx_rtd_theme"
32 | html_static_path = ["static"]
33 |
--------------------------------------------------------------------------------
/tests/fixtures/guizhou_traffic_fixtures.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import os
3 |
4 | from .git_utils import git_clean, git_clone, git_reset, git_pull
5 | from .constants import TEMP_DIR
6 |
7 |
8 | GUIZHOU_TRAFFIC_CFG = {
9 | "repo": "git@code.unianalysis.com:f2ai-examples/guizhou_traffic.git",
10 | "infra": {
11 | "file": {
12 | "cwd": os.path.join(TEMP_DIR, "f2ai-guizhou_traffic_file"),
13 | "branch": "main",
14 | },
15 | "pgsql": {
16 | "cwd": os.path.join(TEMP_DIR, "f2ai-guizhou_traffic_pgsql"),
17 | "branch": "ver_pgsql",
18 | },
19 | },
20 | }
21 |
22 |
23 | @pytest.fixture(scope="session")
24 | def make_guizhou_traffic():
25 | if not os.path.exists(TEMP_DIR):
26 | os.makedirs(TEMP_DIR)
27 |
28 | def get_guizhou_traffic(infra="file"):
29 | repo = GUIZHOU_TRAFFIC_CFG["repo"]
30 | infra = GUIZHOU_TRAFFIC_CFG["infra"][infra]
31 | cwd = infra["cwd"]
32 | branch = infra["branch"]
33 |
34 | # clone repo to cwd
35 | if not os.path.isdir(cwd):
36 | git_clone(cwd, repo, branch)
37 |
38 | # reset, clean and update to latest
39 | git_reset(cwd, branch)
40 | git_clean(cwd)
41 | git_pull(cwd, branch)
42 |
43 | return cwd
44 |
45 | yield get_guizhou_traffic
46 |
--------------------------------------------------------------------------------
/tests/units/period_test.py:
--------------------------------------------------------------------------------
1 | from f2ai.definitions import Period
2 | from pandas import DateOffset
3 |
4 |
5 | def test_period_to_pandas_dateoffset():
6 | offset = Period(n=1, unit="day").to_pandas_dateoffset()
7 | assert offset == DateOffset(days=1)
8 |
9 |
10 | def test_period_to_sql_interval():
11 | interval = Period(n=1, unit="day").to_pgsql_interval()
12 | assert interval == "interval '1 days'"
13 |
14 |
15 | def test_period_from_str():
16 | ten_years = Period.from_str("10 years").to_pandas_dateoffset()
17 | one_day = Period.from_str("1day").to_pandas_dateoffset()
18 |
19 | assert ten_years == DateOffset(years=10)
20 | assert one_day == DateOffset(days=1)
21 |
22 |
23 | def test_period_negative():
24 | ten_years = Period.from_str("10 years")
25 | neg_ten_years = -ten_years
26 | assert neg_ten_years.n == -10
27 |
28 | neg_ten_years = Period.from_str("-10 years")
29 | assert neg_ten_years.n == -10
30 |
31 |
32 | def test_get_pandas_datetime_components():
33 | three_minutes = Period.from_str("3 minutes")
34 | components = three_minutes.get_pandas_datetime_components()
35 | assert components == ["year", "month", "day", "hour", "minute"]
36 |
37 | one_month = Period.from_str("1 month")
38 | components = one_month.get_pandas_datetime_components()
39 | assert components == ["year", "month"]
40 |
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | F2AI
2 | ====
3 |
4 | .. image:: https://readthedocs.org/projects/f2ai/badge/?version=latest
5 | :target: https://f2ai.readthedocs.io/en/latest/?badge=latest
6 | :alt: Documentation Status
7 |
 8 | The architecture image below is a working-flow demonstration of using F2AI, not a technical architecture.
9 |
10 | Getting Started
11 | ---------------
12 |
13 | 1. Install F2AI
14 |
15 |
16 | Overview
17 | -------------
18 |
19 | F2AI (Feature Store to AI) is a time-centric data productivity utility that re-uses existing infrastructure to retrieve features consistently through the different stages of AI development.
20 | 
21 | F2AI focuses on:
22 | 
23 | * **Consistent API to get features for training and serving**: powered by well-encapsulated OfflineStore and OnlineStore implementations, features are strictly managed by F2AI and keep the same structure for both training and inference.
24 | * **Preventing feature leakage** with an effective point-in-time join that reduces the cumbersome work of getting features ready when experimenting with feature combinations or AI models.
25 | * **Built-in support for period features**, which allows 2-dimensional features to be retrieved easily. This is useful for deep learning tasks such as time series forecasting.
26 | * **Infrastructure unawareness**: with different infrastructure implementations, switching between data storages is simply a matter of changing the configuration. The AI model keeps working as before.
27 |
28 | Architecture
29 | ------------
30 |
31 | .. image:: ./docs/static/f2ai_architecture.png
32 | :alt: f2ai function architecture
33 |
34 | .. note::
35 |     This is a working-flow illustration of using F2AI, not a technical architecture.
36 |
--------------------------------------------------------------------------------
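A hedged quick-start sketch for the "Getting Started" section of the README above, based on f2ai/example/jaffle_shop_example.py and the top-level exports in f2ai/__init__.py; the install command assumes the package name from setup.py is published, and the project folder path is a placeholder:

# pip install f2ai   (package name taken from setup.py; availability on PyPI is assumed)
from f2ai import FeatureStore

# Point the store at a folder containing feature_store.yml; the path below is a placeholder.
store = FeatureStore(project_folder="/path/to/your/f2ai-project")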
/f2ai/cmd/main.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | from argparse import ArgumentParser
3 | from typing import List
4 |
5 | from ..common.cmd_parser import add_materialize_parser
6 | from ..definitions import BackOffTime
7 | from ..featurestore import FeatureStore
8 |
9 |
10 | def materialize(url: str, services: List[str], back_off_time: BackOffTime, online: bool):
11 | fs = FeatureStore(url)
12 |
13 | fs.materialize(services, back_off_time, online)
14 |
15 |
16 | def main():
17 | parser = ArgumentParser(prog="F2AI")
18 | subparsers = parser.add_subparsers(title="subcommands", dest="commands")
19 | add_materialize_parser(subparsers)
20 |
21 | kwargs = vars(parser.parse_args())
22 | commands = kwargs.pop("commands", None)
23 | if commands == "materialize":
24 | if kwargs["fromnow"] is None and kwargs["start"] is None and kwargs["end"] is None:
25 | parser.error("One of fromnow or start&end is required.")
26 |
27 | if not pathlib.Path("feature_store.yml").exists():
28 | parser.error(
29 | "No feature_store.yml found in current folder, please switch to folder which feature_store.yml exists."
30 | )
31 |
32 | if commands == "materialize":
33 | from_now = kwargs.pop("fromnow", None)
34 | step = kwargs.pop("step", None)
35 | tz = kwargs.pop("tz", None)
36 |
37 | if from_now is not None:
38 | back_off_time = BackOffTime.from_now(from_now=from_now, step=step, tz=tz)
39 | else:
40 | back_off_time = BackOffTime(start=kwargs.pop("start"), end=kwargs.pop("end"), step=step, tz=tz)
41 |
42 | materialize("file://.", kwargs.pop("services"), back_off_time, kwargs.pop("online"))
43 |
--------------------------------------------------------------------------------
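A hedged sketch of how the materialize entry point above is driven, either through the `f2ai` console script declared in setup.py or programmatically; `my_service` is a hypothetical service name and the code must run from a folder containing feature_store.yml:

# Command-line equivalents:
#   f2ai materialize my_service --fromnow "7 days"
#   f2ai materialize my_service --start 2022-10-22 --end 2022-11-22 --online
from f2ai.cmd.main import materialize
from f2ai.definitions import BackOffTime

back_off_time = BackOffTime.from_now(from_now="7 days", step="1 day")
materialize("file://.", ["my_service"], back_off_time, online=False)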
/f2ai/models/normalizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 |
4 |
5 | class MinMaxNormalizer:
6 | """scale continous features to [-1,1],"""
7 |
8 | def __init__(self, feature_name=None):
9 | self.feature_name = feature_name
10 |
11 | def fit_self(self, data: pd.Series):
12 | min = data.min()
13 | max = data.max()
14 | self._state = ((min + max) / 2.0, max - min + 1e-8)
15 | return self
16 |
17 | def transform_self(self, data: pd.Series) -> pd.Series:
18 | assert self._state is not None
19 | center, scale = self._state
20 |
21 | return (data - center) / scale
22 |
23 | def inverse_transform_self(self, data: pd.Series) -> pd.Series:
24 | assert self._state is not None
25 | center, scale = self._state
26 |
27 | return data * scale + center
28 |
29 |
30 | class StandardNormalizer:
31 | def __init__(self, feature_name=None, center=True):
32 | self.feature_name = feature_name
33 | self._center = center
34 | self._eps = 1e-6
35 |
36 | def fit_self(self, data: pd.Series, source: pd.DataFrame = None, **kwargs):
37 | if self._center:
38 | self._state = (data.mean(), data.std() + self._eps)
39 | else:
40 | self._state = (0.0, data.mean() + self._eps)
41 |
42 | def transform_self(self, data: pd.Series) -> pd.Series:
43 | assert self._state is not None
44 | center, scale = self._state
45 |
46 | return (data - center) / scale
47 |
48 | def inverse_transform_self(self, data: pd.Series) -> pd.Series:
49 | assert self._state is not None
50 | center, scale = self._state
51 |
52 | return data * scale + center
53 |
--------------------------------------------------------------------------------
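A minimal sketch of MinMaxNormalizer (above) on a toy series; the values are invented, and note that the centering and scaling above place outputs roughly in [-0.5, 0.5]:

import pandas as pd
from f2ai.models.normalizer import MinMaxNormalizer

normalizer = MinMaxNormalizer(feature_name="population")
data = pd.Series([0.0, 5.0, 10.0])

scaled = normalizer.fit_self(data).transform_self(data)  # approximately [-0.5, 0.0, 0.5]
restored = normalizer.inverse_transform_self(scaled)     # back to [0.0, 5.0, 10.0]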
/f2ai/models/encoder.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 |
4 |
5 | def column_or_1d(y, warn):
6 | y = np.asarray(y)
7 | shape = np.shape(y)
8 | if len(shape) == 1:
9 | return np.ravel(y)
10 | elif len(shape) == 2 and shape[1] == 1:
11 | return np.ravel(y)
12 |
13 | raise ValueError("y should be a 1d array, got an array of shape {} instead.".format(shape))
14 |
15 |
16 | def map_to_integer(values, uniques):
17 | """Map values based on its position in uniques."""
18 | table = {val: i for i, val in enumerate(uniques)}
19 | return np.array([table[v] if v in table else table["UNKNOWN_CAT"] for v in values])
20 |
21 |
22 | class LabelEncoder:
23 | """encode categorical features into int number by their appearance order,
24 | will always set 0 to be UNKNOWN_CAT automatically
25 | """
26 |
27 | def __init__(self, feature_name=None):
28 | self.feature_name = feature_name
29 |
30 | def fit_self(self, y: pd.Series):
31 | y = column_or_1d(y, warn=True)
32 | self._state = ["UNKNOWN_CAT"] + sorted(set(y))
33 | return self
34 |
35 | def transform_self(self, y: pd.Series) -> pd.Series:
36 | y = column_or_1d(y, warn=True)
37 | if len(y) == 0:
38 | return np.array([])
39 | y = map_to_integer(y, self._state)
40 | return y
41 |
42 | def inverse_transform_self(self, y: pd.Series) -> pd.Series:
43 | y = column_or_1d(y, warn=True)
44 | if len(y) == 0:
45 | return np.array([])
46 | diff = np.setdiff1d(y, np.arange(len(self._state)))
47 | if len(diff):
48 | raise ValueError("y contains previously unseen labels: %s" % str(diff))
49 | y = np.asarray(y)
50 | return [self._state[i] for i in y]
51 |
--------------------------------------------------------------------------------
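A short sketch of LabelEncoder (above) on toy data, illustrating how unseen categories fall back to index 0, the reserved UNKNOWN_CAT slot; the fruit values are invented:

import pandas as pd
from f2ai.models.encoder import LabelEncoder

encoder = LabelEncoder(feature_name="fruit")
encoder.fit_self(pd.Series(["banana", "apple"]))  # internal state: ["UNKNOWN_CAT", "apple", "banana"]

codes = encoder.transform_self(pd.Series(["banana", "cherry"]))  # array([2, 0]); "cherry" is unseen
labels = encoder.inverse_transform_self(pd.Series(codes))        # ["banana", "UNKNOWN_CAT"]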
/f2ai/persist_engine/online_local_persistengine.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 | from ..definitions import (
4 | Period,
5 | BackOffTime,
6 | OnlinePersistEngine,
7 | OnlinePersistEngineType,
8 | PersistFeatureView,
9 | )
10 |
11 | from ..common.time_field import DEFAULT_EVENT_TIMESTAMP_FIELD
12 |
13 |
14 | class OnlineLocalPersistEngine(OnlinePersistEngine):
15 | type: OnlinePersistEngineType = OnlinePersistEngineType.LOCAL
16 |
17 | def materialize(
18 | self,
19 | prefix: str,
20 | feature_view: PersistFeatureView,
21 | back_off_time: BackOffTime,
22 | view_name: str,
23 | ):
24 | date_df = pd.DataFrame(data=[back_off_time.end], columns=[DEFAULT_EVENT_TIMESTAMP_FIELD])
25 | period = -Period.from_str(str(back_off_time.end - back_off_time.start))
26 | entities_in_range = self.offline_store.get_latest_entities(
27 | source=feature_view.source,
28 | group_keys=feature_view.join_keys,
29 | entity_df=date_df,
30 | start=back_off_time.start,
31 | ).drop(columns=DEFAULT_EVENT_TIMESTAMP_FIELD)
32 |
33 | data_to_write = self.offline_store.get_period_features(
34 | entity_df=pd.merge(entities_in_range, date_df, how="cross"),
35 | features=feature_view.features,
36 | source=feature_view.source,
37 | period=period,
38 | join_keys=feature_view.join_keys,
39 | ttl=feature_view.ttl,
40 | )
41 | self.online_store.write_batch(
42 | feature_view.name,
43 | prefix,
44 | data_to_write,
45 | feature_view.ttl,
46 | join_keys=feature_view.join_keys,
47 | tz=back_off_time.tz,
48 | )
49 | return view_name
50 |
--------------------------------------------------------------------------------
/f2ai/definitions/__init__.py:
--------------------------------------------------------------------------------
1 | from .entities import Entity
2 | from .features import Feature, FeatureSchema, SchemaAnchor
3 | from .period import Period
4 | from .base_view import BaseView
5 | from .feature_view import FeatureView
6 | from .label_view import LabelView
7 | from .services import Service
8 | from .sources import Source, FileSource, SqlSource, parse_source_yaml
9 | from .offline_store import OfflineStore, OfflineStoreType, init_offline_store_from_cfg
10 | from .online_store import OnlineStore, OnlineStoreType, init_online_store_from_cfg
11 | from .constants import LOCAL_TIMEZONE, StatsFunctions
12 | from .backoff_time import BackOffTime
13 | from .persist_engine import (
14 | PersistFeatureView,
15 | PersistLabelView,
16 | PersistEngine,
17 | OfflinePersistEngine,
18 | OnlinePersistEngine,
19 | OfflinePersistEngineType,
20 | OnlinePersistEngineType,
21 | init_persist_engine_from_cfg,
22 | )
23 | from .dtypes import FeatureDTypes
24 |
25 | __all__ = [
26 | "Entity",
27 | "Feature",
28 | "FeatureSchema",
29 | "SchemaAnchor",
30 | "Period",
31 | "BaseView",
32 | "FeatureView",
33 | "LabelView",
34 | "Service",
35 | "Source",
36 | "FileSource",
37 | "SqlSource",
38 | "OfflineStoreType",
39 | "OfflineStore",
40 | "parse_source_yaml",
41 | "init_offline_store_from_cfg",
42 | "LOCAL_TIMEZONE",
43 | "StatsFunctions",
44 | "FeatureDTypes",
45 | "OnlineStoreType",
46 | "OnlineStore",
47 | "init_online_store_from_cfg",
48 | "BackOffTime",
49 | "PersistFeatureView",
50 | "PersistLabelView",
51 | "PersistEngine",
52 | "OnlinePersistEngine",
53 | "OfflinePersistEngine",
54 | "OfflinePersistEngineType",
55 | "OnlinePersistEngineType",
56 | "init_persist_engine_from_cfg",
57 | ]
58 |
--------------------------------------------------------------------------------
/f2ai/definitions/backoff_time.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from datetime import datetime
3 | from dataclasses import dataclass
4 | from typing import Iterator, Union
5 |
6 | from f2ai.definitions.period import Period
7 |
8 |
9 | @dataclass
10 | class BackOffTime:
11 | """
12 | Useful to define how to split data by time.
13 | """
14 |
15 | start: pd.Timestamp
16 | end: pd.Timestamp
17 | step: Period
18 | tz: str
19 |
20 | def __init__(
21 | self,
22 | start: Union[str, pd.Timestamp],
23 | end: Union[str, pd.Timestamp],
24 | step: Union[str, Period] = "1 day",
25 | tz: str = None,
26 | ) -> None:
27 | # if tz is needed, then pass as '2016-01-01 00+8' to indicate 8hours offset
28 | if isinstance(start, str):
29 | start = pd.Timestamp(start, tz=tz)
30 | if isinstance(end, str):
31 | end = pd.Timestamp(end, tz=tz)
32 |
33 | if isinstance(step, str):
34 | step = Period.from_str(step)
35 |
36 | self.start = start
37 | self.end = end
38 | self.step = step
39 | self.tz = tz
40 |
41 |     def to_units(self) -> Iterator["BackOffTime"]:
42 | pd_offset = self.step.to_pandas_dateoffset()
43 | start = self.step.normalize(self.start, "floor")
44 | end = self.step.normalize(self.end, "ceil")
45 |
46 | bins = pd.date_range(
47 | start=start,
48 | end=end,
49 | freq=pd_offset,
50 | )
51 | for (start, end) in zip(bins[:-1], bins[1:]):
52 | yield BackOffTime(start=start, end=end, step=self.step, tz=self.tz)
53 |
54 | @classmethod
55 | def from_now(cls, from_now: str, step: str = None, tz: str = None):
56 | end = pd.Timestamp(datetime.now(), tz=tz)
57 | start = end - Period.from_str(from_now).to_py_timedelta()
58 | return BackOffTime(start=start, end=end, step=step, tz=tz)
59 |
--------------------------------------------------------------------------------
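A small sketch of how BackOffTime (above) splits a time range into step-sized units; the dates are illustrative:

from f2ai.definitions import BackOffTime

back_off_time = BackOffTime(start="2020-08-01", end="2020-08-03", step="1 day")
for unit in back_off_time.to_units():
    print(unit.start, "->", unit.end)  # each unit is itself a BackOffTime covering one step

# Or derive the range from the current time, e.g. the last 7 days split into daily units:
recent = BackOffTime.from_now(from_now="7 days", step="1 day")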
/f2ai/models/earlystop.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import os
4 |
5 |
6 | class EarlyStopping:
7 | """Early stops the training if validation loss doesn't improve after a given patience."""
8 |
9 | def __init__(self, save_path=".", patience=7, verbose=False, delta=0, **kwargs):
10 | """
11 | Args:
12 | patience (int): How long to wait after last time validation loss improved.
13 | 上次验证集损失值改善后等待几个epoch
14 | Default: 7
15 | verbose (bool): If True, prints a message for each validation loss improvement.
16 | 如果是True,为每个验证集损失值改善打印一条信息
17 | Default: False
18 | delta (float): Minimum change in the monitored quantity to qualify as an improvement.
19 | 监测数量的最小变化,以符合改进的要求
20 | Default: 0
21 | """
22 | self.save_path = save_path
23 | self.patience = patience
24 | self.verbose = verbose
25 | self.counter = 0
26 | self.best_score = None
27 | self.early_stop = False
28 | self.val_loss_min = np.Inf
29 | self.delta = delta
30 | self.name = kwargs.get("cpnmae", "")
31 |
32 | def __call__(self, val_loss, model):
33 |
34 | score = -val_loss
35 |
36 | if self.best_score is None:
37 | self.best_score = score
38 | self.save_checkpoint(val_loss, model)
39 | elif score < self.best_score + self.delta:
40 | self.counter += 1
41 | print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
42 | if self.counter >= self.patience:
43 | self.early_stop = True
44 | else:
45 | self.best_score = score
46 | self.save_checkpoint(val_loss, model)
47 | self.counter = 0
48 |
49 | def save_checkpoint(self, val_loss, model):
50 | """
51 |         Saves the model when the validation loss decreases.
52 |         The full model object is written under `self.save_path`.
53 | """
54 | if self.verbose:
55 | print(
56 | f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ..."
57 | )
58 | path_to_save = os.path.join(self.save_path, f"best_chekpnt_{self.name}.pk")
59 | torch.save(model, path_to_save)
60 | self.val_loss_min = val_loss
61 |
--------------------------------------------------------------------------------
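A hedged sketch of how EarlyStopping (above) is typically driven from a training loop; `train_one_epoch`, `compute_validation_loss` and `model` are hypothetical placeholders:

from f2ai.models.earlystop import EarlyStopping

early_stopping = EarlyStopping(save_path=".", patience=7, verbose=True)
for epoch in range(100):
    train_one_epoch(model)                     # hypothetical training step
    val_loss = compute_validation_loss(model)  # hypothetical validation step
    early_stopping(val_loss, model)            # saves a checkpoint when the loss improves
    if early_stopping.early_stop:
        print(f"Stopping early at epoch {epoch}")
        break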
/f2ai/persist_engine/offline_file_persistengine.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
3 | from typing import List
4 | import pandas as pd
5 | import datetime
6 |
7 | from ..offline_stores.offline_file_store import OfflineFileStore
8 | from ..definitions import (
9 | FileSource,
10 | OfflinePersistEngine,
11 | OfflinePersistEngineType,
12 | BackOffTime,
13 | PersistFeatureView,
14 | PersistLabelView,
15 | )
16 | from ..common.utils import write_df_to_dataset
17 | from ..common.time_field import TIME_COL, MATERIALIZE_TIME
18 |
19 |
20 | class OfflineFilePersistEngine(OfflinePersistEngine):
21 | type: OfflinePersistEngineType = OfflinePersistEngineType.FILE
22 | offline_store: OfflineFileStore
23 |
24 | def materialize(
25 | self,
26 | feature_views: List[PersistFeatureView],
27 | label_view: PersistLabelView,
28 | destination: FileSource,
29 | back_off_time: BackOffTime,
30 | service_name: str,
31 | ):
32 | # retrieve entity_df
33 | # TODO:
34 |         # 1. Should this be abstracted more cleanly instead of using a private function?
35 |         # 2. Restricting the time range before reading the data could improve efficiency.
36 | entity_df = self.offline_store._read_file(
37 | source=label_view.source, features=label_view.labels, join_keys=label_view.join_keys
38 | )
39 |
40 |         entity_df = entity_df.drop(columns=["created_timestamp"], errors="ignore")
41 | entity_df = entity_df[
42 | (entity_df[TIME_COL] >= back_off_time.start) & (entity_df[TIME_COL] < back_off_time.end)
43 | ]
44 |
45 | # join features recursively
46 | # TODO: this should be reimplemented to directly consume multi feature_views and do a performance test.
47 | joined_frame = entity_df
48 | for feature_view in feature_views:
49 | joined_frame = self.offline_store.get_features(
50 | entity_df=joined_frame,
51 | features=feature_view.features,
52 | source=feature_view.source,
53 | join_keys=feature_view.join_keys,
54 | ttl=feature_view.ttl,
55 | include=True,
56 | how="right",
57 | )
58 | tz = joined_frame[TIME_COL][0].tz if not joined_frame.empty else None
59 | joined_frame[MATERIALIZE_TIME] = pd.Timestamp(datetime.datetime.now(), tz=tz)
60 | write_df_to_dataset(
61 | joined_frame, destination.path, time_col=destination.timestamp_field, period=back_off_time.step
62 | )
63 | return service_name
64 |
--------------------------------------------------------------------------------
/tests/units/dataset/entities_sampler_test.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from f2ai.dataset import (
3 | NoEntitiesSampler,
4 | EvenEventsSampler,
5 | FixedNEntitiesSampler,
6 | EvenEntitiesSampler,
7 | )
8 |
9 |
10 | def test_even_entities_sampler():
11 | event_time_sampler = EvenEventsSampler(start="2022-10-02", end="2022-10-04", period="1 day")
12 | group_df = pd.DataFrame(
13 | {
14 | "fruit": ["apple", "banana", "banana"],
15 | "sauce": ["tomato", "chili", "tomato"],
16 | }
17 | )
18 | sampler = EvenEntitiesSampler(event_time_sampler, group_df)
19 |
20 | sampled_df = sampler()
21 | assert len(sampled_df) == 9
22 | assert list(sampled_df.columns) == ["fruit", "sauce", "event_timestamp"]
23 |
24 | next_sampled_item = next(iter(sampler))
25 | assert all([key in {"fruit", "sauce", "event_timestamp"} for key in next_sampled_item.keys()])
26 |
27 |
28 | def test_fixed_n_entities_sampler():
29 | event_time_sampler = EvenEventsSampler(start="2022-10-02", end="2022-10-04", period="1 day")
30 | group_df = pd.DataFrame(
31 | {
32 | "fruit": ["apple", "banana", "banana"],
33 | "sauce": ["tomato", "chili", "tomato"],
34 | }
35 | )
36 | sampler = FixedNEntitiesSampler(event_time_sampler, group_df, n=2)
37 |
38 | sampled_df = sampler()
39 | assert len(sampled_df) == 6
40 | assert list(sampled_df.columns) == ["fruit", "sauce", "event_timestamp"]
41 |
42 | next_sampled_item = next(iter(sampler))
43 | assert all([key in {"fruit", "sauce", "event_timestamp"} for key in next_sampled_item.keys()])
44 |
45 |
46 | def test_fixed_n_prob_entities_sampler():
47 | event_time_sampler = EvenEventsSampler(start="2022-10-02", end="2022-10-06", period="1 day")
48 | group_df = pd.DataFrame(
49 | {"fruit": ["apple", "banana", "banana"], "sauce": ["tomato", "chili", "tomato"], "p": [0.2, 0.6, 0.2]}
50 | )
51 | sampler = FixedNEntitiesSampler(event_time_sampler, group_df, n=3)
52 |
53 | sampled_df = sampler()
54 | assert len(sampled_df) == 9
55 | assert list(sampled_df.columns) == ["fruit", "sauce", "event_timestamp"]
56 |
57 | next_sampled_item = next(iter(sampler))
58 | assert all([key in {"fruit", "sauce", "event_timestamp"} for key in next_sampled_item.keys()])
59 |
60 |
61 | def test_sampler_properties():
62 | event_time_sampler = EvenEventsSampler(start="2022-10-02", end="2022-10-06", period="1 day")
63 | sampler = NoEntitiesSampler(event_time_sampler)
64 | assert sampler.iterable
65 | assert not sampler.iterable_only
66 |
--------------------------------------------------------------------------------
/f2ai/definitions/sources.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional, Any
2 | from pydantic import BaseModel, Field
3 | from enum import Enum
4 |
5 | from f2ai.common.utils import read_file
6 | from .offline_store import OfflineStoreType
7 | from .features import FeatureSchema
8 |
9 |
10 | class Source(BaseModel):
11 | """An abstract class which describe the common part of a Source. A source usually defines where to access data and what the time semantic it has. In F2AI, we have 2 kinds of time semantic:
12 |
13 | 1. timestamp_field: the event timestamp which represent when the record happened, which is the main part of point-in-time join.
14 | 2. created_timestamp_field: the created timestamp which represent when the record created, which usually happened in multi cycles of feature generation scenario.
15 | """
16 |
17 | name: str
18 | description: Optional[str]
19 | timestamp_field: Optional[str]
20 | created_timestamp_field: Optional[str] = Field(alias="created_timestamp_column")
21 | tags: Dict[str, str] = {}
22 |
23 |
24 | class FileFormatEnum(str, Enum):
25 | PARQUET = "parquet"
26 | TSV = "tsv"
27 | CSV = "csv"
28 | TEXT = "text"
29 |
30 |
31 | class FileSource(Source):
32 | file_format: FileFormatEnum = FileFormatEnum.CSV
33 | path: str
34 |
35 | def read_file(self, str_cols: List[str] = [], keep_cols: List[str] = []):
36 | time_columns = []
37 | if self.timestamp_field:
38 | time_columns.append(self.timestamp_field)
39 | if self.created_timestamp_field:
40 | time_columns.append(self.created_timestamp_field)
41 |
42 | return read_file(
43 | self.path,
44 | file_format=self.file_format,
45 | parse_dates=time_columns,
46 | str_cols=str_cols,
47 | keep_cols=keep_cols,
48 | )
49 |
50 |
51 | class RequestSource(Source):
52 | schemas: List[FeatureSchema] = Field(alias="schema")
53 |
54 |
55 | class SqlSource(Source):
56 | query: str
57 |
58 | def __init__(__pydantic_self__, **data: Any) -> None:
59 |
60 | query = data.pop("query", "")
61 | if query == "":
62 | query = data.get("name")
63 |
64 | super().__init__(**data, query=query)
65 |
66 |
67 | def parse_source_yaml(o: Dict, offline_store_type: OfflineStoreType) -> Source:
68 | if o.get("type", None) == "request_source":
69 | return RequestSource(**o)
70 |
71 | if offline_store_type == OfflineStoreType.FILE:
72 | return FileSource(**o)
73 | elif offline_store_type == OfflineStoreType.PGSQL:
74 | return SqlSource(**o)
75 | elif offline_store_type == OfflineStoreType.SPARK:
76 | raise Exception("spark is not supported yet!")
77 | else:
78 | return Source(**o)
79 |
--------------------------------------------------------------------------------
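A short sketch of how a parsed sources/*.yml dictionary is turned into a concrete Source subclass by parse_source_yaml above. The name and path values are hypothetical placeholders.

from f2ai.definitions import OfflineStoreType, parse_source_yaml, FileSource, SqlSource

cfg = {
    "name": "loan_table",                  # hypothetical source name
    "path": "data/loan_table.parquet",     # hypothetical file path
    "file_format": "parquet",
    "timestamp_field": "event_timestamp",
}

# With a file based offline store the config becomes a FileSource ...
source = parse_source_yaml(cfg, OfflineStoreType.FILE)
assert isinstance(source, FileSource)

# ... while a Postgres offline store maps a bare name to a SqlSource whose
# query defaults to the source name when no explicit query is given.
sql_source = parse_source_yaml({"name": "loan_table"}, OfflineStoreType.PGSQL)
assert isinstance(sql_source, SqlSource) and sql_source.query == "loan_table"
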
/f2ai/definitions/online_store.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import abc
4 | from enum import Enum
5 | from typing import TYPE_CHECKING, Any, Dict, Optional
6 |
7 | import pandas as pd
8 | from pydantic import BaseModel
9 |
10 | if TYPE_CHECKING:
11 | from .feature_view import FeatureView
12 | from .period import Period
13 | from .sources import Source
14 |
15 |
16 | class OnlineStoreType(str, Enum):
17 | """A constant numerate choices which is used to indicate how to initialize OnlineStore from configuration. If you want to add a new type of online store, you definitely want to modify this."""
18 |
19 | REDIS = "redis"
20 |
21 |
22 | class OnlineStore(BaseModel):
23 | """An abstraction of what functionalities a OnlineStore should implements. If you want to be one of the online store contributor. This is the core."""
24 |
25 | type: OnlineStoreType
26 | name: str
27 |
28 | class Config:
29 | extra = "allow"
30 |
31 | @abc.abstractmethod
32 | def write_batch(
33 |         self, feature_view: FeatureView, project_name: str, dt: pd.DataFrame, ttl: Optional[Period]
34 | ) -> Source:
35 | """materialize data on redis
36 |
37 | Args:
38 |             feature_view (FeatureView): the feature view to materialize; dt holds the data to write, project_name identifies the current project, and ttl controls expiration.
39 |
40 | Returns:
41 | Source
42 | """
43 | pass
44 |
45 | @abc.abstractmethod
46 | def read_batch(
47 | self,
48 | hkey: str,
49 | ttl: Optional[Period] = None,
50 | period: Optional[Period] = None,
51 | **kwargs,
52 | ) -> pd.DataFrame:
53 | """get data from current online store.
54 |
55 | Args:
56 |             hkey (str): hash key used to look up the stored features.
57 |             period (Optional[Period], optional): time window of features to read. Defaults to None.
58 |             ttl (Optional[Period], optional): Time to Live, if a feature's event_timestamp exceeds the ttl, it will be dropped. Defaults to None.
59 |
60 | Returns:
61 | pd.DataFrame
62 | """
63 | pass
64 |
65 |
66 | def init_online_store_from_cfg(cfg: Dict[str, Any], name: str) -> OnlineStore:
67 | """Initialize an implementation of OnlineStore from yaml config.
68 |
69 | Args:
70 |         cfg (Dict[str, Any]): a parsed config object.
71 |
72 | Returns:
73 | OnlineStore: Different types of OnlineStore.
74 | """
75 | online_store_type = OnlineStoreType(cfg["type"])
76 |
77 | if online_store_type == OnlineStoreType.REDIS:
78 | from ..online_stores.online_redis_store import OnlineRedisStore
79 |
80 | redis_conf = cfg.pop("redis_conf", {})
81 | redis_conf.update({"name": name})
82 | return OnlineRedisStore(**cfg, **redis_conf)
83 |
84 |     raise TypeError(f"online store type must be one of [{','.join(e.value for e in OnlineStoreType)}]")
85 |
--------------------------------------------------------------------------------
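The abstract OnlineStore above is the extension point for new backends. Below is a deliberately toy sketch of a subclass, not a real backend: it reuses the existing REDIS enum value because OnlineStoreType currently has no other members, ignores the declared Source return type of write_batch, and only logs what a real store would persist.

import pandas as pd
from f2ai.definitions.online_store import OnlineStore, OnlineStoreType

class LoggingOnlineStore(OnlineStore):
    """A toy store that logs writes and returns empty reads; for illustration only."""

    type: OnlineStoreType = OnlineStoreType.REDIS  # no dedicated enum value exists for this sketch
    name: str = "logging"

    def write_batch(self, feature_view, project_name: str, dt: pd.DataFrame, ttl=None):
        print(f"would write {len(dt)} rows for project '{project_name}'")

    def read_batch(self, hkey: str, ttl=None, period=None, **kwargs) -> pd.DataFrame:
        print(f"would read hash key '{hkey}'")
        return pd.DataFrame()

store = LoggingOnlineStore()
store.write_batch(feature_view=None, project_name="demo", dt=pd.DataFrame({"a": [1, 2]}))
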
/f2ai/common/get_config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | from typing import List, Dict
4 |
5 | from ..definitions import (
6 | OfflineStoreType,
7 | Entity,
8 | FeatureView,
9 | LabelView,
10 | Service,
11 | Source,
12 | parse_source_yaml,
13 | )
14 | from .read_file import read_yml
15 | from .utils import remove_prefix
16 |
17 |
18 | def listdir_with_extensions(path: str, extensions: List[str] = []) -> List[str]:
19 | path = remove_prefix(path, "file://")
20 | if os.path.isdir(path):
21 | files = []
22 | for extension in extensions:
23 | files.extend(glob.glob(f"{path}/*.{extension}"))
24 | return files
25 | return []
26 |
27 |
28 | def listdir_yamls(path: str) -> List[str]:
29 | return listdir_with_extensions(path, extensions=["yml", "yaml"])
30 |
31 |
32 | def get_service_cfg(url: str) -> Dict[str, Service]:
33 | """get forecast config like length of look_back and look_forward, features and labels
34 |
35 | Args:
36 | url (str): url of .yml
37 | """
38 | service_cfg = {}
39 | for filepath in listdir_yamls(url):
40 | service = Service.from_yaml(read_yml(filepath))
41 | service_cfg[service.name] = service
42 | return service_cfg
43 |
44 |
45 | def get_entity_cfg(url: str) -> Dict[str, Entity]:
46 | """get entity config for join
47 |
48 | Args:
49 | url (str): url of .yml
50 | """
51 | entities = {}
52 | for filepath in listdir_yamls(url):
53 | entity = Entity(**read_yml(filepath))
54 | entities[entity.name] = entity
55 | return entities
56 |
57 |
58 | def get_feature_views(url: str) -> Dict[str, FeatureView]:
59 | """get Dict(FeatureViews) from /feature_views/*.yml
60 |
61 | Args:
62 |         url (str): url of .yml
63 | """
64 | feature_views = {}
65 | for filepath in listdir_yamls(url):
66 | feature_view = FeatureView(**read_yml(filepath))
67 | feature_views[feature_view.name] = feature_view
68 | return feature_views
69 |
70 |
71 | def get_label_views(url: str) -> Dict[str, LabelView]:
72 | """get Dict(LabelViews) from /label_views/*.yml
73 |
74 | Args:
75 |         url (str): url of .yml
76 | """
77 | label_views = {}
78 | for filepath in listdir_yamls(url):
79 | label_view = LabelView(**read_yml(filepath))
80 | label_views[label_view.name] = label_view
81 | return label_views
82 |
83 |
84 | def get_source_cfg(url: str, offline_store_type: OfflineStoreType) -> Dict[str, Source]:
85 | """get Dict(LabelViews) from /sources/*.yml
86 |
87 | Args:
88 |         url (str): url of .yml
89 | """
90 |
91 | source_dict = {}
92 | for filepath in listdir_yamls(url):
93 | cfg = read_yml(filepath)
94 | source = parse_source_yaml(cfg, offline_store_type)
95 | source_dict.update({source.name: source})
96 | return source_dict
97 |
--------------------------------------------------------------------------------
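A hedged sketch of loading a project's yaml definitions with the helpers above. The folder names follow the paths hinted at in the docstrings (feature_views/, label_views/, sources/), the entities/ folder name is an assumption, and the project path is a hypothetical placeholder; non-existent folders simply yield empty dictionaries.

from f2ai.common.get_config import (
    get_entity_cfg,
    get_feature_views,
    get_label_views,
    get_source_cfg,
)
from f2ai.definitions import OfflineStoreType

project = "file:///path/to/my_f2ai_project"  # hypothetical project folder

entities = get_entity_cfg(f"{project}/entities")              # folder name assumed
feature_views = get_feature_views(f"{project}/feature_views")
label_views = get_label_views(f"{project}/label_views")
sources = get_source_cfg(f"{project}/sources", OfflineStoreType.FILE)

print(list(feature_views), list(label_views), list(sources), list(entities))
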
/use_cases/guizhou_traffic_arima.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import pmdarima as pm
3 | import matplotlib.pyplot as plt
4 | import seaborn as sns
5 | from pmdarima import pipeline
6 | from pmdarima.preprocessing import DateFeaturizer, FourierFeaturizer
7 |
8 | from f2ai.dataset import TorchIterableDataset
9 | from f2ai.featurestore import FeatureStore
10 |
11 |
12 | def fill_missing(data: pd.DataFrame, time_col="event_timestamp", query_col="query_timestamp"):
13 | min_of_time = data[time_col].min()
14 | max_of_time = data[time_col].max()
15 | query_time = data.iloc[0][query_col]
16 |
17 | data = data.set_index(time_col, drop=True)
18 | if query_time < min_of_time:
19 | date_index = pd.date_range(query_time, max_of_time, freq="2T", name=time_col)[1:]
20 | else:
21 | date_index = pd.date_range(min_of_time, query_time, freq="2T", name=time_col)
22 |
23 | df = data.reindex(date_index, method="nearest").resample("12T").sum()
24 |
25 | return df.reset_index()
26 |
27 |
28 | #! f2ai materialize travel_time_prediction_arima_v1 --start=2016-03-01 --end=2016-04-01 --step='1 day'
29 | if __name__ == "__main__":
30 | ex_vars = ["event_timestamp"]
31 | fs = FeatureStore("/Users/liyu/.f2ai/f2ai-guizhou_traffic_file")
32 |
33 |     # pick 4 road links and 4 time windows, and forecast each one separately
34 | entity_df = pd.DataFrame(
35 | {
36 | "link_id": [
37 | "4377906286525800514",
38 | "4377906285681600514",
39 | "3377906281774510514",
40 | "4377906280784800514",
41 | ],
42 | "event_timestamp": [
43 | pd.Timestamp("2016-03-31 07:00:00"),
44 | pd.Timestamp("2016-03-31 09:00:00"),
45 | pd.Timestamp("2016-03-31 12:00:00"),
46 | pd.Timestamp("2016-03-31 18:00:00"),
47 | ],
48 | }
49 | )
50 | dataset = TorchIterableDataset(fs, "travel_time_prediction_arima_v1", entity_df)
51 |
52 | fig = plt.figure(figsize=(16, 16))
53 | axes = fig.subplots(2, 2)
54 | for i, (look_back, look_forward) in enumerate(dataset):
55 | look_back = fill_missing(look_back)
56 | look_forward = fill_missing(look_forward)
57 |
58 | pipe = pipeline.Pipeline(
59 | [
60 | ("date", DateFeaturizer(column_name="event_timestamp", with_day_of_month=False)),
61 | ("fourier", FourierFeaturizer(m=24 * 5, k=4)),
62 | ("arima", pm.arima.ARIMA(order=(6, 0, 1))),
63 | ]
64 | )
65 | pipe.fit(look_back["travel_time"], X=look_back[ex_vars])
66 | look_forward["y_pred"] = pipe.predict(len(look_forward), X=look_forward[ex_vars])
67 | melted_df = pd.melt(look_forward, id_vars=["event_timestamp"], value_vars=["travel_time", "y_pred"])
68 | sns.lineplot(melted_df, x="event_timestamp", y="value", hue="variable", ax=axes[i // 2, i % 2])
69 |
70 | fig.savefig("f2ai_guizhou_traffic_arima", bbox_inches="tight")
71 |
--------------------------------------------------------------------------------
/tests/units/offline_stores/offline_postgres_store_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pytest
3 | import datetime
4 | import pandas as pd
5 | from pypika import Query
6 | from unittest.mock import MagicMock
7 |
8 | from f2ai.offline_stores.offline_postgres_store import build_stats_query, OfflinePostgresStore
9 | from f2ai.definitions import Feature, FeatureDTypes, StatsFunctions, SqlSource
10 |
11 | feature_city = Feature(name="city", dtype=FeatureDTypes.STRING, view_name="zipcode_table")
12 | feature_population = Feature(name="population", dtype=FeatureDTypes.INT, view_name="zipcode_table")
13 |
14 | FILE_DIR = os.path.dirname(__file__)
15 |
16 |
17 | def read_sql_str(*args) -> str:
18 | with open(os.path.join(FILE_DIR, "postgres_sqls", f"{'_'.join(args)}.sql"), "r") as f:
19 | return f.read()
20 |
21 |
22 | @pytest.mark.parametrize("fn", [(fn) for fn in StatsFunctions if fn != StatsFunctions.UNIQUE])
23 | def test_build_stats_query_with_group_by_numeric(fn: StatsFunctions):
24 | q = Query.from_("zipcode_table")
25 | sql = build_stats_query(
26 | q,
27 | features=[feature_population],
28 | stats_fn=fn,
29 | group_keys=["zipcode"],
30 | )
31 | expected_sql = read_sql_str("stats_query", fn.value)
32 | assert sql.get_sql() == expected_sql
33 |
34 |
35 | @pytest.mark.parametrize("fn", [StatsFunctions.UNIQUE])
36 | def test_build_stats_query_with_group_by_categorical(fn: StatsFunctions):
37 | q = Query.from_("zipcode_table")
38 | sql = build_stats_query(
39 | q,
40 | stats_fn=fn,
41 | group_keys=["zipcode"],
42 | )
43 | expected_sql = read_sql_str("stats_query", fn.value)
44 | assert sql.get_sql() == expected_sql
45 |
46 |
47 | def test_stats_numeric():
48 | source = SqlSource(name="foo", query="zipcode_table", timestamp_field="event_timestamp")
49 | store = OfflinePostgresStore(
50 | host="localhost",
51 | user="foo",
52 | password="bar",
53 | )
54 |
55 | mock = MagicMock()
56 | store._get_dataframe = mock
57 |
58 | store.stats(
59 | source=source,
60 | features=[feature_population],
61 | fn=StatsFunctions.AVG,
62 | group_keys=["zipcode"],
63 | start=datetime.datetime(year=2017, month=1, day=1),
64 | end=datetime.datetime(year=2018, month=1, day=1),
65 | )
66 | sql, columns = mock.call_args[0] # the first call
67 | assert ",".join(columns) == "zipcode,population"
68 | assert sql.get_sql() == read_sql_str("store_stats_query", "numeric")
69 |
70 |
71 | def test_stats_unique():
72 | source = SqlSource(name="foo", query="zipcode_table", timestamp_field="event_timestamp")
73 | store = OfflinePostgresStore(host="localhost", user="foo", password="bar")
74 |
75 | mock = MagicMock(return_value=pd.DataFrame({"zipcode": ["A"]}))
76 | store._get_dataframe = mock
77 |
78 | store.stats(
79 | source=source,
80 | features=[feature_city],
81 | fn=StatsFunctions.UNIQUE,
82 | group_keys=["zipcode"],
83 | start=datetime.datetime(year=2017, month=1, day=1),
84 | end=datetime.datetime(year=2018, month=1, day=1),
85 | )
86 | sql, columns = mock.call_args[0] # the first call
87 |
88 | assert ",".join(columns) == "zipcode"
89 | assert sql.get_sql() == read_sql_str("store_stats_query", "categorical")
90 |
--------------------------------------------------------------------------------
/f2ai/common/oss_utils.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import os
3 | from typing import Tuple, Mapping, Callable, Optional, IO, Any
4 | import oss2
5 | import stat
6 | import tempfile
7 | import zipfile
8 | from ignite.handlers import DiskSaver
9 | from .utils import remove_prefix
10 |
11 |
12 | @functools.lru_cache(maxsize=64)
13 | def get_bucket(bucket, endpoint=None):
14 | key_id = os.environ.get("OSS_ACCESS_KEY_ID")
15 | key_secret = os.environ.get("OSS_ACCESS_KEY_SECRET")
16 | endpoint = endpoint or os.environ.get("OSS_ENDPOINT")
17 |
18 | return oss2.Bucket(oss2.Auth(key_id, key_secret), endpoint, bucket)
19 |
20 |
21 | def parse_oss_url(url: str) -> Tuple[str, str]:
22 | """
23 | url format: oss://{bucket}/{key}
24 | """
25 | url = remove_prefix(url, "oss://")
26 | components = url.split("/")
27 | return components[0], "/".join(components[1:])
28 |
29 |
30 | def get_bucket_from_oss_url(url: str):
31 | bucket_name, key = parse_oss_url(url)
32 | return get_bucket(bucket_name), key
33 |
34 |
35 | @functools.lru_cache(maxsize=1)
36 | def get_pandas_storage_options():
37 | key_id = os.environ.get("OSS_ACCESS_KEY_ID")
38 | key_secret = os.environ.get("OSS_ACCESS_KEY_SECRET")
39 | endpoint = os.environ.get("OSS_ENDPOINT")
40 |
41 | if not endpoint.startswith("https://"):
42 | endpoint = f"https://{endpoint}"
43 |
44 | return {
45 | "key": key_id,
46 | "secret": key_secret,
47 | "client_kwargs": {
48 | "endpoint_url": endpoint,
49 | },
50 | # "config_kwargs": {"s3": {"addressing_style": "virtual"}},
51 | }
52 |
53 |
54 | class DiskAndOssSaverAdd(DiskSaver):
55 | def __init__(
56 | self,
57 | dirname: str,
58 | ossaddress: str = None,
59 | create_dir: bool = True,
60 | require_empty: bool = True,
61 | **kwargs: Any,
62 | ):
63 | super().__init__(
64 | dirname=dirname, atomic=True, create_dir=create_dir, require_empty=require_empty, **kwargs
65 | )
66 | self.ossaddress = ossaddress
67 |
68 | def _save_func(self, checkpoint: Mapping, path: str, func: Callable, rank: int = 0) -> None:
69 | tmp: Optional[IO[bytes]] = None
70 | if rank == 0:
71 | tmp = tempfile.NamedTemporaryFile(delete=False, dir=self.dirname)
72 |
73 | try:
74 | func(checkpoint, tmp.file, **self.kwargs)
75 | except BaseException:
76 | if tmp is not None:
77 | tmp.close()
78 | os.remove(tmp.name)
79 | raise
80 |
81 | if tmp is not None:
82 | tmp.close()
83 | os.replace(tmp.name, path)
84 | # append group/others read mode
85 | os.chmod(path, os.stat(path).st_mode | stat.S_IRGRP | stat.S_IROTH)
86 | if self.ossaddress:
87 | bucket, key = get_bucket_from_oss_url(self.ossaddress)
88 | state_file = f"{path.rsplit('/',maxsplit=1)[0]}{os.sep}state.json"
89 | file_path = f"{path.rsplit('/',maxsplit=1)[0]}{os.sep}model.zip"
90 | with zipfile.ZipFile(file_path, "w", compression=zipfile.ZIP_BZIP2) as archive:
91 | archive.write(path, os.path.basename(path))
92 | archive.write(state_file, os.path.basename(state_file))
93 | bucket.put_object_from_file(key, file_path)
94 | os.remove(file_path)
95 |
--------------------------------------------------------------------------------
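A quick illustration of the oss:// url helpers above; the bucket and key below are hypothetical, and importing this module requires the oss2 and pytorch-ignite packages it pulls in at module level. Note that parse_oss_url returns a (bucket, key) pair.

from f2ai.common.oss_utils import parse_oss_url

bucket_name, key = parse_oss_url("oss://my-bucket/models/best_model.zip")  # hypothetical url
print(bucket_name)  # my-bucket
print(key)          # models/best_model.zip
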
/f2ai/dataset/events_sampler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import abc
4 | from typing import Union, Any, Dict
5 |
6 | from ..definitions import Period
7 |
8 |
9 | class EventsSampler:
10 | """
11 | This sampler is used to sample event_timestamp column. Customization could be done by inheriting this class.
12 | """
13 |
14 | @abc.abstractmethod
15 | def __call__(self, *args: Any, **kwds: Any) -> pd.DatetimeIndex:
16 | """
17 | Get all sample results
18 |
19 | Returns:
20 |             pd.DatetimeIndex: a datetime index
21 | """
22 | pass
23 |
24 | @abc.abstractmethod
25 | def __iter__(self, *args: Any, **kwds: Any) -> pd.Timestamp:
26 | """
27 | Iterable way to get sample result.
28 |
29 | Returns:
30 | pd.Timestamp
31 | """
32 | pass
33 |
34 |
35 | class EvenEventsSampler(EventsSampler):
36 | """
37 |     A sampler which uses evenly spaced timestamps to generate the query entity dataframe. This sampler is generally useful when you don't have entity keys.
38 | """
39 |
40 | def __init__(self, start: str, end: str, period: Union[str, Period], **kwargs):
41 | """
42 | evenly sample from a range of time, with given period.
43 |
44 | Args:
45 | start (str): start datetime
46 | end (str): end datetime
47 |             period (str): a period string, e.g. '1 day'.
48 | **kwargs (Any): additional arguments passed to pd.date_range.
49 | """
50 | if isinstance(start, str):
51 | start = pd.to_datetime(start)
52 | if isinstance(end, str):
53 | end = pd.to_datetime(end)
54 |
55 | self._start = start
56 | self._end = end
57 |         self._period = Period.from_str(period) if isinstance(period, str) else period
58 | self._kwargs = kwargs
59 |
60 | def __call__(self) -> pd.DatetimeIndex:
61 | return self._get_date_range()
62 |
63 | def __iter__(self) -> Dict:
64 | datetime_indexes = self._get_date_range()
65 | for i in datetime_indexes:
66 | yield i
67 |
68 | def _get_date_range(self) -> pd.DatetimeIndex:
69 | return pd.date_range(self._start, self._end, freq=self._period.to_pandas_freq_str(), **self._kwargs)
70 |
71 |
72 | class RandomNEventsSampler(EventsSampler):
73 | """
74 |     Randomly sample a fixed number of event timestamps in a given time range.
75 | """
76 |
77 | def __init__(
78 | self,
79 | start: str,
80 | end: str,
81 | period: Union[str, Period],
82 | n: int = 1,
83 | random_state: int = None,
84 | ):
85 | """
86 |         Randomly sample a fixed number of event timestamps.
87 |
88 | Args:
89 | start (str): start datetime
90 | end (str): end datetime
91 |             period (str): a period string, e.g. '1 day'. n and random_state control how many timestamps are drawn and the RNG seed.
92 | """
93 | if isinstance(start, str):
94 | start = pd.to_datetime(start)
95 | if isinstance(end, str):
96 | end = pd.to_datetime(end)
97 |
98 | self._start = start
99 | self._end = end
100 |         self._period = Period.from_str(period) if isinstance(period, str) else period
101 |
102 | self._n = n
103 | self._rng = np.random.default_rng(random_state)
104 |
105 | def __call__(self) -> pd.DatetimeIndex:
106 | return self._get_date_range()
107 |
108 | def __iter__(self) -> Dict:
109 | datetime_indexes = self._get_date_range()
110 | for i in datetime_indexes:
111 | yield i
112 |
113 | def _get_date_range(self) -> pd.DatetimeIndex:
114 | datetimes = pd.date_range(self._start, self._end, freq=self._period.to_pandas_freq_str())
115 | indices = sorted(self._rng.choice(range(len(datetimes)), size=self._n, replace=False))
116 | return pd.DatetimeIndex([datetimes[i] for i in indices])
117 |
--------------------------------------------------------------------------------
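A short, self-contained example of the EvenEventsSampler above, following the same pattern as the unit tests: calling the sampler returns the full DatetimeIndex, while iterating over it yields one timestamp at a time.

from f2ai.dataset import EvenEventsSampler

sampler = EvenEventsSampler(start="2022-10-02", end="2022-10-04", period="1 day")

print(sampler())  # DatetimeIndex(['2022-10-02', '2022-10-03', '2022-10-04'], ...)
for timestamp in sampler:
    print(timestamp)  # 2022-10-02 00:00:00, 2022-10-03 00:00:00, 2022-10-04 00:00:00
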
/f2ai/dataset/pytorch_dataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import pandas as pd
3 | from typing import TYPE_CHECKING, Optional, Union
4 | from torch.utils.data import IterableDataset
5 |
6 | from ..common.utils import batched, is_iterable
7 | from ..common.time_field import DEFAULT_EVENT_TIMESTAMP_FIELD, QUERY_COL
8 | from ..definitions import Period
9 | from .entities_sampler import EntitiesSampler
10 |
11 |
12 | if TYPE_CHECKING:
13 | from ..definitions.services import Service
14 | from ..featurestore import FeatureStore
15 |
16 |
17 | def iter_rows(df: pd.DataFrame):
18 | for _, row in df.iterrows():
19 | yield row
20 |
21 |
22 | class TorchIterableDataset(IterableDataset):
23 | """A pytorch portaled dataset."""
24 |
25 | def __init__(
26 | self,
27 | feature_store: "FeatureStore",
28 | service: Union[Service, str],
29 | sampler: Union[EntitiesSampler, pd.DataFrame],
30 | chunk_size: int = 64,
31 | ):
32 | assert is_iterable(
33 | sampler
34 | ), "TorchIterableDataset only support iterable sampler or iterable DataFrame"
35 |
36 | self._feature_store = feature_store
37 | self._service = service
38 | self._sampler = sampler
39 | self._chunk_size = chunk_size
40 |
41 | if isinstance(service, str):
42 | self._service = feature_store.services.get(service, None)
43 |             assert self._service is not None, f"Service {service} is not found in feature store."
44 | else:
45 | self._service = service
46 |
47 | def get_feature_period(self) -> Optional[Period]:
48 | features = self._service.get_feature_objects(self._feature_store.feature_views)
49 | periods = [Period.from_str(x.period) for x in features if x.period is not None]
50 | return max(periods) if len(periods) > 0 else None
51 |
52 | def get_label_period(self) -> Optional[Period]:
53 | labels = self._service.get_label_objects(self._feature_store.label_views)
54 | periods = [Period.from_str(x.period) for x in labels if x.period is not None]
55 | return max(periods) if len(periods) > 0 else None
56 |
57 | def __iter__(self):
58 | feature_period = self.get_feature_period()
59 | label_period = self.get_label_period()
60 | join_keys = self._service.get_join_keys(
61 | self._feature_store.feature_views,
62 | self._feature_store.label_views,
63 | self._feature_store.entities,
64 | )
65 |
66 | if isinstance(self._sampler, pd.DataFrame):
67 | iterator = iter_rows(self._sampler)
68 | else:
69 | iterator = self._sampler
70 |
71 | for x in batched(iterator, batch_size=self._chunk_size):
72 | entity_df = pd.DataFrame(x)
73 | labels_df = None
74 |
75 | if feature_period is None:
76 | features_df = self._feature_store.get_features(self._service, entity_df)
77 | labels = self._service.get_label_names(self._feature_store.label_views)
78 | label_columns = join_keys + labels + [DEFAULT_EVENT_TIMESTAMP_FIELD]
79 | labels_df = features_df[label_columns]
80 | features_df = features_df.drop(columns=labels)
81 | else:
82 | if not feature_period.is_neg:
83 | feature_period = -feature_period
84 | features_df = self._feature_store.get_period_features(
85 | self._service, entity_df, feature_period
86 | )
87 |
88 | # get corresponding labels if not present in features_df
89 | if labels_df is None:
90 | labels_df = self._feature_store.get_period_labels(self._service, entity_df, label_period)
91 |
92 | if feature_period:
93 | group_columns = join_keys + [QUERY_COL]
94 | else:
95 | group_columns = join_keys + [DEFAULT_EVENT_TIMESTAMP_FIELD]
96 |
97 | labels_group = labels_df.groupby(group_columns)
98 | for name, x_features in features_df.groupby(group_columns):
99 | y_labels = labels_group.get_group(name)
100 | # TODO: test this with tabular dataset.
101 | yield (x_features, y_labels)
102 |
--------------------------------------------------------------------------------
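A hedged sketch of how TorchIterableDataset is usually reached in practice, mirroring the integration benchmarks elsewhere in this repository: build a sampler, ask the FeatureStore for a dataset, and convert it to a PyTorch iterable. The project path and service name below are hypothetical placeholders.

from f2ai import FeatureStore
from f2ai.dataset import EvenEventsSampler, NoEntitiesSampler

store = FeatureStore("/path/to/my_f2ai_project")  # hypothetical project folder
events_sampler = EvenEventsSampler(start="2022-01-01", end="2022-01-07", period="1 day")

ds = store.get_dataset(
    service="my_service_v1",                      # hypothetical service name
    sampler=NoEntitiesSampler(events_sampler),
)

# Each item is a (features_df, labels_df) pair of pandas DataFrames.
for features_df, labels_df in ds.to_pytorch(64):
    print(features_df.shape, labels_df.shape)
    break
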
/use_cases/.ipynb_checkpoints/credit_score-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "abc14dec",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "from f2ai.common.collecy_fn import classify_collet_fn\n",
11 | "from f2ai.models.sequential import SimpleClassify\n",
12 | "import torch\n",
13 | "from torch import nn\n",
14 | "from torch.utils.data import DataLoader\n",
15 | "from f2ai.featurestore import FeatureStore\n",
16 | "from f2ai.dataset import GroupFixednbrSampler\n",
17 | "\n",
18 | "\n",
19 | "\n",
20 | "if __name__ == \"__main__\":\n",
21 | " fs = FeatureStore(\"file:///Users/xuyizhou/Desktop/xyz_warehouse/gitlab/f2ai-credit-scoring\")\n",
22 | "\n",
23 | " ds = fs.get_dataset(\n",
24 | " service=\"credit_scoring_v1\",\n",
25 | " sampler=GroupFixednbrSampler(\n",
26 | " time_bucket=\"10 days\",\n",
27 | " stride=1,\n",
28 | " group_ids=None,\n",
29 | " group_names=None,\n",
30 | " start=\"2020-08-01\",\n",
31 | " end=\"2021-09-30\",\n",
32 | " ),\n",
33 | " )\n",
34 | " features_cat = [ # catgorical features\n",
35 | " fea\n",
36 | " for fea in fs._get_feature_to_use(fs.services[\"credit_scoring_v1\"])\n",
37 | " if fea not in fs._get_feature_to_use(fs.services[\"credit_scoring_v1\"], True)\n",
38 | " ]\n",
39 | " cat_unique = fs.stats(\n",
40 | " fs.services[\"credit_scoring_v1\"],\n",
41 | " fn=\"unique\",\n",
42 | " group_key=[],\n",
43 | " start=\"2020-08-01\",\n",
44 | " end=\"2021-09-30\",\n",
45 | " features=features_cat,\n",
46 | " ).to_dict()\n",
47 | " cat_count = {key: len(cat_unique[key]) for key in cat_unique.keys()}\n",
48 | " cont_scalar_max = fs.stats(\n",
49 | " fs.services[\"credit_scoring_v1\"], fn=\"max\", group_key=[], start=\"2020-08-01\", end=\"2021-09-30\"\n",
50 | " ).to_dict()\n",
51 | " cont_scalar_min = fs.stats(\n",
52 | " fs.services[\"credit_scoring_v1\"], fn=\"min\", group_key=[], start=\"2020-08-01\", end=\"2021-09-30\"\n",
53 | " ).to_dict()\n",
54 | " cont_scalar = {key: [cont_scalar_min[key], cont_scalar_max[key]] for key in cont_scalar_min.keys()}\n",
55 | "\n",
56 | " i_ds = ds.to_pytorch()\n",
57 | " test_data_loader = DataLoader( # `batch_siz`e and `drop_last`` do not matter now, `sampler`` set it to be None cause `test_data`` is a Iterator\n",
58 | " i_ds,\n",
59 | " collate_fn=lambda x: classify_collet_fn(\n",
60 | " x,\n",
61 | " cat_coder=cat_unique,\n",
62 | " cont_scalar=cont_scalar,\n",
63 | " label=fs._get_available_labels(fs.services[\"credit_scoring_v1\"]),\n",
64 | " ),\n",
65 | " batch_size=4,\n",
66 | " drop_last=False,\n",
67 | " sampler=None,\n",
68 | " )\n",
69 | "\n",
70 | " model = SimpleClassify(\n",
71 | " cont_nbr=len(cont_scalar_max), cat_nbr=len(cat_count), emd_dim=4, max_types=max(cat_count.values())\n",
72 | " )\n",
73 | " optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # no need to change\n",
74 | " loss_fn = nn.BCELoss() # loss function to train a classification model\n",
75 | "\n",
76 | " for epoch in range(10): # assume 10 epoch\n",
77 | " print(f\"epoch: {epoch} begin\")\n",
78 | " for x, y in test_data_loader:\n",
79 | " pred_label = model(x)\n",
80 | " true_label = y\n",
81 | " loss = loss_fn(pred_label, true_label)\n",
82 | " optimizer.zero_grad()\n",
83 | " loss.backward()\n",
84 | " optimizer.step()\n",
85 | " print(f\"epoch: {epoch} done, loss: {loss}\")\n"
86 | ]
87 | }
88 | ],
89 | "metadata": {
90 | "kernelspec": {
91 | "display_name": "autonn-3.8.7",
92 | "language": "python",
93 | "name": "python3"
94 | },
95 | "language_info": {
96 | "name": "python",
97 | "version": "3.8.7 (default, Dec 29 2021, 10:58:29) \n[Clang 13.0.0 (clang-1300.0.29.3)]"
98 | },
99 | "vscode": {
100 | "interpreter": {
101 | "hash": "5840e4ed671345474330e8fce6ab52c58896a3935f0e728b8dbef1ddfad82808"
102 | }
103 | }
104 | },
105 | "nbformat": 4,
106 | "nbformat_minor": 5
107 | }
108 |
--------------------------------------------------------------------------------
/f2ai/common/collect_fn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pandas as pd
3 |
4 | from ..models.encoder import LabelEncoder
5 | from ..models.normalizer import MinMaxNormalizer
6 |
7 |
8 | def classify_collet_fn(datas, cont_scalar={}, cat_coder={}, label=[]):
9 | """_summary_
10 |
11 | Args:
12 | datas (_type_): datas to be processed, the length equals to batch_size
13 | cont_scalar (dict, optional): normalize continous features, the key is feature name and value is scales corresponding to method using. Defaults to {}.
14 | cat_coder (dict, optional): encode categorical features, the key is feature name and value is scales corresponding to method using. Defaults to {}.
15 | label (list, optional): column name of label. Defaults to [].
16 |
17 | Returns:
18 | tuple: the first element is features and the second is label
19 | """
20 | batches = []
21 |     # corresponding to __getitem__ in Dataset
22 | for data in datas: # data[0]:features, data[1]:labels
23 | cat_features = torch.stack(
24 | [
25 | torch.tensor(
26 | LabelEncoder(cat).fit_self(pd.Series(cat_coder[cat])).transform_self(data[0][cat]),
27 | dtype=torch.int,
28 | )
29 | for cat in cat_coder.keys()
30 | ],
31 | dim=-1,
32 | )
33 | cont_features = torch.stack(
34 | [
35 | torch.tensor(
36 | MinMaxNormalizer(cont)
37 | .fit_self(pd.Series(cont_scalar[cont]))
38 | .transform_self(data[0][cont])
39 | .values,
40 | dtype=torch.float16,
41 | )
42 | for cont in cont_scalar.keys()
43 | ],
44 | dim=-1,
45 | )
46 | labels = torch.stack([torch.tensor(data[1][lab].values, dtype=torch.float) for lab in label], dim=-1)
47 | batch = (dict(categorical_features=cat_features, continous_features=cont_features), labels)
48 | batches.append(batch)
49 |
50 |     # corresponding to collate_fn in DataLoader
51 | categorical_features = torch.stack([batch[0]["categorical_features"] for batch in batches])
52 | continous_features = torch.stack([batch[0]["continous_features"] for batch in batches])
53 | labels = torch.stack([batch[1] for batch in batches])
54 | return (
55 | dict(
56 | categorical_features=categorical_features,
57 | continous_features=continous_features,
58 | ),
59 | labels,
60 | )
61 |
62 |
63 | def nbeats_collet_fn(
64 | datas,
65 | cont_scalar={},
66 | categoricals={},
67 | label={},
68 | ):
69 |
70 | batches = []
71 | for data in datas:
72 |
73 | all_cont = torch.stack(
74 | [
75 | torch.tensor(
76 | MinMaxNormalizer(cont)
77 | .fit_self(pd.Series(cont_scalar[cont]))
78 | .transform_self(data[0][cont]),
79 | dtype=torch.float16,
80 | )
81 | for cont in cont_scalar.keys()
82 | ],
83 | dim=-1,
84 | )
85 |
86 | all_categoricals = torch.stack(
87 | [
88 | torch.tensor(
89 | LabelEncoder(cat).fit_self(pd.Series(categoricals[cat])).transform_self(data[0][cat]),
90 | dtype=torch.int,
91 | )
92 | for cat in categoricals.keys()
93 | ],
94 | dim=-1,
95 | )
96 |
97 | targets = torch.stack(
98 | [torch.tensor(data[1][lab], dtype=torch.float) for lab in label],
99 | dim=-1,
100 | )
101 |
102 | batch = (
103 | dict(
104 | encoder_cont=all_cont,
105 | decoder_cont=all_cont,
106 | categoricals=all_categoricals,
107 | ),
108 | targets,
109 | )
110 | batches.append(batch)
111 |
112 | encoder_cont = torch.stack([batch[0]["encoder_cont"] for batch in batches])
113 | decoder_cont = torch.stack([batch[0]["decoder_cont"] for batch in batches])
114 | categoricals = torch.stack([batch[0]["categoricals"] for batch in batches])
115 | targets = torch.stack([batch[1] for batch in batches])
116 |
117 | return (
118 | dict(
119 | encoder_cont=encoder_cont + targets,
120 | decoder_cont=decoder_cont,
121 | x_categoricals=categoricals,
122 | ),
123 | targets,
124 | )
125 |
--------------------------------------------------------------------------------
/tests/integrations/benchmarks/offline_pgsql_benchmark_test.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import timeit
3 | from f2ai import FeatureStore
4 | from f2ai.definitions import StatsFunctions
5 | from f2ai.dataset import EvenEventsSampler, NoEntitiesSampler
6 | from f2ai.definitions import BackOffTime
7 |
8 |
9 | def get_guizhou_traffic_entities(store: FeatureStore):
10 | columns = ["link_id", "event_timestamp"]
11 | query_entities = store.offline_store._get_dataframe(
12 | f"select {', '.join(columns)} from gy_link_travel_time order by event_timestamp limit 1000",
13 | columns=columns,
14 | )
15 | return query_entities.astype({"link_id": "string"})
16 |
17 |
18 | def test_get_features_from_feature_view(make_guizhou_traffic):
19 | project_folder = make_guizhou_traffic("pgsql")
20 | store = FeatureStore(project_folder)
21 | entity_df = get_guizhou_traffic_entities(store)
22 |
23 | store.get_features("gy_link_travel_time_features", entity_df)
24 |
25 | measured_time = timeit.timeit(
26 | lambda: store.get_features("gy_link_travel_time_features", entity_df), number=10
27 | )
28 | print(f"get_features performance pgsql: {measured_time}s")
29 |
30 |
31 | def test_stats_from_feature_view(make_guizhou_traffic):
32 | project_folder = make_guizhou_traffic("pgsql")
33 | store = FeatureStore(project_folder)
34 |
35 | measured_time = timeit.timeit(
36 | lambda: store.stats("gy_link_travel_time_features", fn=StatsFunctions.AVG), number=10
37 | )
38 | print(f"stats performance pgsql: {measured_time}s")
39 |
40 |
41 | def test_unique_from_feature_view(make_guizhou_traffic):
42 | project_folder = make_guizhou_traffic("pgsql")
43 | store = FeatureStore(project_folder)
44 |
45 | measured_time = timeit.timeit(
46 | lambda: store.stats("gy_link_travel_time_features", group_keys=["link_id"]), number=10
47 | )
48 | print(f"stats performance pgsql: {measured_time}s")
49 |
50 |
51 | def test_get_latest_entities_from_feature_view_with_entity_df(make_guizhou_traffic):
52 | project_folder = make_guizhou_traffic("pgsql")
53 | store = FeatureStore(project_folder)
54 | measured_time = timeit.timeit(
55 | lambda: store.get_latest_entities(
56 | "gy_link_travel_time_features",
57 | pd.DataFrame({"link_id": ["3377906281518510514", "4377906284141600514"]}),
58 | ),
59 | number=10,
60 | )
61 | print(f"get_latest_entities with entity_df performance pgsql: {measured_time}s")
62 |
63 |
64 | def test_get_latest_entity_from_feature_view(make_guizhou_traffic):
65 | project_folder = make_guizhou_traffic("pgsql")
66 | store = FeatureStore(project_folder)
67 | measured_time = timeit.timeit(
68 | lambda: store.get_latest_entities("gy_link_travel_time_features"), number=10
69 | )
70 | print(f"get_latest_entities performance pgsql: {measured_time}s")
71 |
72 |
73 | def test_get_period_features_from_feature_view(make_guizhou_traffic):
74 | project_folder = make_guizhou_traffic("pgsql")
75 | store = FeatureStore(project_folder)
76 | entity_df = get_guizhou_traffic_entities(store)
77 | measured_time = timeit.timeit(
78 | lambda: store.get_period_features("gy_link_travel_time_features", entity_df, period="10 minutes"),
79 | number=10,
80 | )
81 | print(f"get_features performance pgsql: {measured_time}s")
82 |
83 |
84 | def test_dataset_to_pytorch_pgsql(make_guizhou_traffic):
85 | project_folder = make_guizhou_traffic("pgsql")
86 | store = FeatureStore(project_folder)
87 | events_sampler = EvenEventsSampler(start="2016-03-05 00:00:00", end="2016-03-06 00:00:00", period='1 hours')
88 |
89 | ds = store.get_dataset(
90 | service="traval_time_prediction_embedding_v1",
91 | sampler=NoEntitiesSampler(events_sampler),
92 | )
93 | measured_time = timeit.timeit(lambda: list(ds.to_pytorch(64)), number=1)
94 | print(f"dataset.to_pytorch pgsql performance: {measured_time}s")
95 |
96 |
97 | def test_offline_materialize(make_guizhou_traffic):
98 | project_folder = make_guizhou_traffic("pgsql")
99 | store = FeatureStore(project_folder)
100 | backoff_time = BackOffTime(
101 | start="2016-03-01 08:02:00+08", end="2016-03-01 08:06:00+08", step="4 minutes"
102 | )
103 | measured_time = timeit.timeit(
104 | lambda: store.materialize(
105 | service_or_views="traval_time_prediction_embedding_v1",
106 | back_off_time=backoff_time,
107 | ),
108 | number=1,
109 | )
110 |
111 | print(f"materialize performance pgsql: {measured_time}s")
112 |
--------------------------------------------------------------------------------
/f2ai/persist_engine/offline_pgsql_persistengine.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import List
3 | from pypika import Query, Parameter, Table, PostgreSQLQuery
4 |
5 | from ..definitions import (
6 | SqlSource,
7 | OfflinePersistEngine,
8 | OfflinePersistEngineType,
9 | BackOffTime,
10 | PersistFeatureView,
11 | PersistLabelView,
12 | )
13 | from ..offline_stores.offline_postgres_store import OfflinePostgresStore
14 | from ..common.time_field import (
15 | DEFAULT_EVENT_TIMESTAMP_FIELD,
16 | ENTITY_EVENT_TIMESTAMP_FIELD,
17 | MATERIALIZE_TIME,
18 | )
19 |
20 |
21 | class OfflinePgsqlPersistEngine(OfflinePersistEngine):
22 |
23 | type: OfflinePersistEngineType = OfflinePersistEngineType.PGSQL
24 |
25 | offline_store: OfflinePostgresStore
26 |
27 | def materialize(
28 | self,
29 | feature_views: List[PersistFeatureView],
30 | label_view: PersistLabelView,
31 | destination: SqlSource,
32 | back_off_time: BackOffTime,
33 | ):
34 | join_sql = self.store.read(
35 | source=label_view.source,
36 | features=label_view.labels,
37 | join_keys=label_view.join_keys,
38 | alias=ENTITY_EVENT_TIMESTAMP_FIELD,
39 | ).where(
40 | (Parameter(DEFAULT_EVENT_TIMESTAMP_FIELD) <= back_off_time.end)
41 | & (Parameter(DEFAULT_EVENT_TIMESTAMP_FIELD) >= back_off_time.start)
42 | )
43 |
44 | feature_names = []
45 | label_names = [label.name for label in label_view.labels]
46 | for feature_view in feature_views:
47 | source_sql = self.store.read(
48 | source=feature_view.source,
49 | features=feature_view.features,
50 | join_keys=feature_view.join_keys,
51 | )
52 | feature_names += [feature.name for feature in feature_view.features]
53 |
54 | keep_columns = label_view.join_keys + feature_names + label_names + [ENTITY_EVENT_TIMESTAMP_FIELD]
55 | join_sql = self.store._point_in_time_join(
56 | entity_df=join_sql,
57 | source_df=source_sql,
58 | timestamp_field=feature_view.source.timestamp_field,
59 | created_timestamp_field=feature_view.source.created_timestamp_field,
60 | ttl=feature_view.ttl,
61 | join_keys=feature_view.join_keys,
62 | include=True,
63 | how="right",
64 | ).select(Parameter(f"{', '.join(keep_columns)}"))
65 |
66 | data_columns = label_view.join_keys + feature_names + label_names
67 | unique_columns = label_view.join_keys + [DEFAULT_EVENT_TIMESTAMP_FIELD]
68 | join_sql = Query.from_(join_sql).select(
69 | Parameter(f"{', '.join(data_columns)}"),
70 | Parameter(f"{ENTITY_EVENT_TIMESTAMP_FIELD} as {DEFAULT_EVENT_TIMESTAMP_FIELD}"),
71 | Parameter(f"current_timestamp as {MATERIALIZE_TIME}"),
72 | )
73 |
74 | with self.store.psy_conn as con:
75 | with con.cursor() as cursor:
76 | cursor.execute(f"select to_regclass('{destination.query}')")
77 | (table_name,) = cursor.fetchone()
78 |                 is_table_exists = table_name is not None
79 |
80 | if not is_table_exists:
81 | # create table from select.
82 | cursor.execute(
83 | Query.create_table(destination.query).as_select(join_sql).get_sql(quote_char="")
84 | )
85 |
86 | # add unique constraint
87 | # TODO: vs unique index.
88 | cursor.execute(
89 | f"alter table {destination.query} add constraint unique_key_{destination.query.split('.')[-1]} unique ({Parameter(', '.join(unique_columns))})"
90 | )
91 | else:
92 | table = Table(destination.query)
93 | all_columns = data_columns + [DEFAULT_EVENT_TIMESTAMP_FIELD] + [MATERIALIZE_TIME]
94 |
95 | insert_sql = (
96 | PostgreSQLQuery.into(table)
97 | .columns(*all_columns)
98 | .from_(join_sql)
99 | .select(Parameter(f"{','.join(all_columns)}"))
100 | .on_conflict(*unique_columns)
101 | )
102 | for c in all_columns:
103 | insert_sql = insert_sql.do_update(table.field(c), Parameter(f"excluded.{c}"))
104 |
105 | cursor.execute(insert_sql.get_sql(quote_char=""))
106 |
--------------------------------------------------------------------------------
/f2ai/definitions/period.py:
--------------------------------------------------------------------------------
1 | import re
2 | import pandas as pd
3 | from pydantic import BaseModel
4 | from enum import Enum
5 | from typing import Any, List
6 | from datetime import timedelta
7 | from functools import reduce, total_ordering
8 |
9 |
10 | class AvailablePeriods(Enum):
11 | """Available Period definitions which supported by F2AI."""
12 |
13 | YEARS = "years"
14 | MONTHS = "months"
15 | WEEKS = "weeks"
16 | DAYS = "days"
17 | HOURS = "hours"
18 | MINUTES = "minutes"
19 | SECONDS = "seconds"
20 | MILLISECONDS = "milliseconds"
21 | MICROSECONDS = "microseconds"
22 | NANOSECONDS = "nanoseconds"
23 |
24 |
25 | PANDAS_TIME_COMPONENTS_MAP = {
26 | AvailablePeriods.YEARS: "year",
27 | AvailablePeriods.MONTHS: "month",
28 | AvailablePeriods.WEEKS: "day",
29 | AvailablePeriods.DAYS: "day",
30 | AvailablePeriods.HOURS: "hour",
31 | AvailablePeriods.MINUTES: "minute",
32 | AvailablePeriods.SECONDS: "second",
33 | AvailablePeriods.MILLISECONDS: "microsecond",
34 | AvailablePeriods.MICROSECONDS: "microsecond",
35 | AvailablePeriods.NANOSECONDS: "nanosecond",
36 | }
37 |
38 | PANDAS_FREQ_STR_MAP = {
39 | AvailablePeriods.YEARS: "YS",
40 | AvailablePeriods.MONTHS: "MS",
41 | AvailablePeriods.WEEKS: "W",
42 | AvailablePeriods.DAYS: "D",
43 | AvailablePeriods.HOURS: "H",
44 | AvailablePeriods.MINUTES: "T",
45 | AvailablePeriods.SECONDS: "S",
46 | AvailablePeriods.MILLISECONDS: "L",
47 | AvailablePeriods.MICROSECONDS: "U",
48 | AvailablePeriods.NANOSECONDS: "N",
49 | }
50 |
51 |
52 | @total_ordering
53 | class Period(BaseModel):
54 | """A wrapper of different representations of a time range. Useful to convert to underline utils like pandas DateOffset, Postgres interval strings."""
55 |
56 | n: int = 1
57 | unit: AvailablePeriods = AvailablePeriods.DAYS
58 |
59 | def __init__(__pydantic_self__, **data: Any) -> None:
60 | if not data.get("unit", "s").endswith("s"):
61 | data["unit"] = data.get("unit", "") + "s"
62 | super().__init__(**data)
63 |
64 | def __str__(self) -> str:
65 | return f"{self.n} {self.unit.value}"
66 |
67 | def __neg__(self) -> "Period":
68 | return Period(n=-self.n, unit=self.unit.value)
69 |
70 | def __eq__(self, other: "Period") -> bool:
71 | return self.n == other.n and self.unit == other.unit
72 |
73 | def __lt__(self, other: "Period") -> bool:
74 | assert self.unit == other.unit, f"Different unit of Period are not comparable, left: {self.unit.value}, right: {other.unit.value}"
75 | return abs(self.n) < abs(other.n)
76 |
77 | @property
78 | def is_neg(self):
79 | return self.n < 0
80 |
81 | def to_pandas_dateoffset(self, normalize=False):
82 | from pandas import DateOffset
83 |
84 | return DateOffset(**{self.unit.value: self.n}, normalize=normalize)
85 |
86 | def to_pgsql_interval(self):
87 | return f"interval '{self.n} {self.unit.value}'"
88 |
89 | def to_py_timedelta(self):
90 | if self.unit == AvailablePeriods.YEARS:
91 | return timedelta(days=365 * self.n)
92 | if self.unit == AvailablePeriods.MONTHS:
93 | return timedelta(days=30 * self.n)
94 | return timedelta(**{self.unit.value: self.n})
95 |
96 | def to_pandas_freq_str(self):
97 | return f'{self.n}{PANDAS_FREQ_STR_MAP[self.unit]}'
98 |
99 | @classmethod
100 | def from_str(cls, s: str):
101 | """Construct a period from str, egg: 10 years, 1day, -1 month.
102 |
103 | Args:
104 | s (str): string representation of a period
105 | """
106 |         n, unit = re.search(r"(-?\d+)\s?(\w+)", s).groups()
107 | return cls(n=int(n), unit=unit)
108 |
109 | def get_pandas_datetime_components(self) -> List[str]:
110 | index_of_period = list(PANDAS_TIME_COMPONENTS_MAP.keys()).index(self.unit)
111 | components = list(PANDAS_TIME_COMPONENTS_MAP.values())[: index_of_period + 1]
112 |
113 | return reduce(lambda xs, x: xs + [x] if x not in xs else xs, components, [])
114 |
115 | def normalize(self, dt: pd.Timestamp, norm_type: str):
116 | if self.unit in {
117 | AvailablePeriods.YEARS,
118 | AvailablePeriods.MONTHS,
119 | AvailablePeriods.WEEKS,
120 | AvailablePeriods.DAYS,
121 | }:
122 | if norm_type == "ceil":
123 | return dt.normalize() + self.to_pandas_dateoffset()
124 | else:
125 | return dt.normalize()
126 |
127 | freq = self.to_pandas_freq_str()
128 | if norm_type == "floor":
129 | return dt.floor(freq)
130 | elif norm_type == "ceil":
131 | return dt.ceil(freq)
132 | else:
133 | return dt.round(freq)
134 |
--------------------------------------------------------------------------------
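A compact example of the Period helper above and its conversions to pandas, Postgres and stdlib representations; every value shown follows directly from the code in this file.

from f2ai.definitions import Period

period = Period.from_str("-2 hours")
print(str(period))                  # -2 hours
print(period.is_neg)                # True
print(period.to_pandas_freq_str())  # -2H
print(period.to_pgsql_interval())   # interval '-2 hours'
print((-period).to_py_timedelta())  # 2:00:00
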
/f2ai/definitions/features.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from enum import Enum
3 | from typing import Optional, TYPE_CHECKING, List, Dict
4 | from pydantic import BaseModel
5 |
6 | from .dtypes import FeatureDTypes, NUMERIC_FEATURE_D_TYPES
7 |
8 | if TYPE_CHECKING:
9 | from .base_view import BaseView
10 |
11 |
12 | class SchemaType(str, Enum):
13 | """Schema used to describe a data column in a table. We only have 2 options in such context, feature or label. Label is the observation data which usually appear in supervised machine learning. In F2AI, label is treated as a special feature."""
14 |
15 | FEATURE = 0
16 | LABEL = 1
17 |
18 |
19 | class FeatureSchema(BaseModel):
20 | """A FeatureSchema is used to describe a data column but no table information included."""
21 |
22 | name: str
23 | description: Optional[str]
24 | dtype: FeatureDTypes
25 |
26 | def is_numeric(self):
27 | if self.dtype in NUMERIC_FEATURE_D_TYPES:
28 | return True
29 | return False
30 |
31 |
32 | class SchemaAnchor(BaseModel):
33 | """
34 |     SchemaAnchor links a view to a group of FeatureSchemas, with period information included if present.
35 | """
36 |
37 | view_name: str
38 | schema_name: str
39 | period: Optional[str]
40 |
41 | @classmethod
42 | def from_strs(cls, cfgs: List[str]) -> "List[SchemaAnchor]":
43 | """Construct from a list of strings.
44 |
45 | Args:
46 | cfgs (List[str])
47 |
48 | Returns:
49 | List[SchemaAnchor]
50 | """
51 | return [cls.from_str(cfg) for cfg in cfgs]
52 |
53 | @classmethod
54 | def from_str(cls, cfg: str) -> "SchemaAnchor":
55 | """Construct from a string.
56 |
57 | Args:
58 |             cfg (str): a string with a specific format, e.g. {feature_view_name}:{feature_name}:{period}
59 |
60 | Returns:
61 | SchemaAnchor
62 | """
63 | components = cfg.split(":")
64 |
65 | if len(components) < 2:
66 | raise ValueError("Please indicate features in table:feature format")
67 | elif len(components) > 3:
68 | raise ValueError("Please make sure colon not in name of table or features")
69 | elif len(components) == 2:
70 | view_name, schema_name = components
71 | return cls(view_name=view_name, schema_name=schema_name)
72 | elif len(components) == 3:
73 | view_name, schema_name, period = components
74 | return cls(view_name=view_name, schema_name=schema_name, period=period)
75 |
76 | def get_features_from_views(self, views: Dict[str, BaseView], is_numeric=False) -> List[Feature]:
77 | """With given views, construct a series of features based on this SchemaAnchor.
78 |
79 | Args:
80 | views (Dict[str, BaseView])
81 |             is_numeric (bool, optional): if True, only return numeric features. Defaults to False.
82 |
83 | Returns:
84 | List[Feature]
85 | """
86 | from .feature_view import FeatureView
87 |
88 | view: BaseView = views[self.view_name]
89 | schema_type = SchemaType.FEATURE if isinstance(view, FeatureView) else SchemaType.LABEL
90 |
91 | if self.schema_name == "*":
92 | return [
93 | Feature.create_from_schema(feature_schema, view.name, schema_type, self.period)
94 | for feature_schema in view.schemas
95 | if (feature_schema.is_numeric() if is_numeric else True)
96 | ]
97 |
98 | feature_schema = next((schema for schema in view.schemas if schema.name == self.schema_name), None)
99 | if feature_schema and (feature_schema.is_numeric() if is_numeric else True):
100 | return [Feature.create_from_schema(feature_schema, view.name, schema_type, self.period)]
101 |
102 | return []
103 |
104 |
105 | class Feature(BaseModel):
106 | """A Feature which include all necessary information which F2AI should know."""
107 |
108 | name: str
109 | dtype: FeatureDTypes
110 | period: Optional[str]
111 | schema_type: SchemaType = SchemaType.FEATURE
112 | view_name: str
113 |
114 | @classmethod
115 | def create_feature_from_schema(
116 | cls, schema: FeatureSchema, view_name: str, period: str = None
117 | ) -> "Feature":
118 | return cls.create_from_schema(schema, view_name, SchemaType.FEATURE, period)
119 |
120 | @classmethod
121 | def create_label_from_schema(cls, schema: FeatureSchema, view_name: str, period: str = None) -> "Feature":
122 | return cls.create_from_schema(schema, view_name, SchemaType.LABEL, period)
123 |
124 | @classmethod
125 | def create_from_schema(
126 | cls, schema: FeatureSchema, view_name: str, schema_type: SchemaType, period: str
127 | ) -> "Feature":
128 | return Feature(
129 | name=schema.name, dtype=schema.dtype, schema_type=schema_type, view_name=view_name, period=period
130 | )
131 |
132 | def __hash__(self) -> int:
133 | return hash(f"{self.view_name}:{self.name}:{self.period}, {self.schema_type}")
134 |
--------------------------------------------------------------------------------
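A brief example of SchemaAnchor.from_str above. The view and feature names are illustrative; the optional third component is a period string, and "*" selects every schema of the view when resolved through get_features_from_views.

from f2ai.definitions.features import SchemaAnchor

anchor = SchemaAnchor.from_str("gy_link_travel_time_features:travel_time:10 minutes")
print(anchor.view_name)    # gy_link_travel_time_features
print(anchor.schema_name)  # travel_time
print(anchor.period)       # 10 minutes

wildcard = SchemaAnchor.from_str("gy_link_travel_time_features:*")
print(wildcard.schema_name)  # *
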
/f2ai/dataset/entities_sampler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import abc
4 | from typing import Any, Dict
5 |
6 | from ..common.utils import is_iterable
7 | from .events_sampler import EventsSampler
8 |
9 |
10 | class EntitiesSampler:
11 | @property
12 | def iterable(self):
13 | return is_iterable(self)
14 |
15 | @property
16 | def iterable_only(self):
17 | is_call_abstract = getattr(self.__call__, "__isabstractmethod__", False)
18 | if self.iterable:
19 | return is_call_abstract
20 | return False
21 |
22 | @abc.abstractmethod
23 | def __call__(self, *args: Any, **kwds: Any) -> pd.DataFrame:
24 | """
25 | Get all sample results, which usually is an entity_df
26 |
27 | Returns:
28 | pd.DataFrame: An entity DataFrame used to query features or labels
29 | """
30 | pass
31 |
32 | @abc.abstractmethod
33 | def __iter__(self, *args: Any, **kwds: Any) -> Dict:
34 | """
35 | Iterable way to get sample result.
36 |
37 | Returns:
38 | Dict: get one entity item, which contains entity keys and event_timestamp.
39 | """
40 | pass
41 |
42 |
43 | # TODO: test this
44 | class NoEntitiesSampler(EntitiesSampler):
45 | """
46 |     This class will directly convert an EventsSampler to an EntitiesSampler. Useful when no entity keys exist.
47 | """
48 |
49 | def __init__(self, events_sampler: EventsSampler) -> None:
50 | super().__init__()
51 |
52 | self._events_sampler = events_sampler
53 |
54 | def __call__(self) -> pd.DataFrame:
55 | return pd.DataFrame({"event_timestamp": self._events_sampler()})
56 |
57 | def __iter__(self) -> Dict:
58 | for event_timestamp in self._events_sampler:
59 | yield {"event_timestamp": event_timestamp}
60 |
61 |
62 | class EvenEntitiesSampler(EntitiesSampler):
63 | """
64 | Sample every group with given events_sampler.
65 | """
66 |
67 | def __init__(
68 | self, events_sampler: EventsSampler, group_df: pd.DataFrame, random_state: int = None
69 | ) -> None:
70 | super().__init__()
71 |
72 | self._events_sampler = events_sampler
73 | self._rng = np.random.default_rng(random_state)
74 |
75 | self._group_df = group_df
76 |
77 | def __call__(self) -> pd.DataFrame:
78 | return pd.merge(
79 | pd.DataFrame({"event_timestamp": self._events_sampler()}), self._group_df, how="cross"
80 | )[list(self._group_df.columns) + ["event_timestamp"]]
81 |
82 | def __iter__(self) -> Dict:
83 | for event_timestamp in iter(self._events_sampler):
84 | for _, entity_row in self._group_df.iterrows():
85 | d = entity_row.to_dict()
86 | d["event_timestamp"] = event_timestamp
87 | yield d
88 |
89 |
90 | class FixedNEntitiesSampler(EntitiesSampler):
91 | """
92 |     Sample N instances from each group, optionally weighted by a per-group probability.
93 | """
94 |
95 | def __init__(self, events_sampler: EventsSampler, group_df: pd.DataFrame, n=1, random_state=None) -> None:
96 | """
97 | Args:
98 |             events_sampler (EventsSampler): an EventsSampler instance.
99 |             group_df (pd.DataFrame): a dataframe which contains the entity columns. Optionally, it may contain a column named `p`, which indicates the sampling weight of each group.
100 |             n (int, optional): how many instances per group. Defaults to 1.
101 |             random_state (int, optional): seed for the random generator. Defaults to None.
102 | """
103 |
104 | super().__init__()
105 |
106 | self._events_sampler = events_sampler
107 | self._rng = np.random.default_rng(seed=random_state)
108 |
109 | self._event_timestamps = self._events_sampler()
110 |
111 | if "p" in group_df.columns:
112 |             assert group_df["p"].sum() == 1, "the sum of all weights should be 1"
113 |
114 | take_n = (group_df["p"] * len(group_df) * n).round().astype(int).rename("_f2ai_take_n_")
115 | self._group_df = pd.concat([group_df, take_n], axis=1)
116 | else:
117 | self._group_df = group_df.assign(_f2ai_take_n_=n)
118 |
119 | def _sample_n(self, row: pd.DataFrame):
120 | event_timestamps = self._rng.choice(self._event_timestamps, size=row["_f2ai_take_n_"].iloc[0])
121 | return pd.merge(
122 | row.drop(columns=["p", "_f2ai_take_n_"], errors="ignore"),
123 | pd.DataFrame({"event_timestamp": event_timestamps}),
124 | how="cross",
125 | )
126 |
127 | def __call__(self) -> pd.DataFrame:
128 | group_keys = [column for column in self._group_df.columns if column not in {"p", "_f2ai_take_n_"}]
129 | return (
130 | self._group_df.groupby(group_keys, group_keys=False, sort=False)
131 | .apply(self._sample_n)
132 | .reset_index(drop=True)
133 | )
134 |
135 | def __iter__(self) -> Dict:
136 | for i, row in self._group_df.iterrows():
137 | sampled_df = self._sample_n(pd.DataFrame([row]))
138 | for j, row in sampled_df.iterrows():
139 | yield row.to_dict()
140 |
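
A minimal usage sketch (not part of the repository source) of the samplers defined above, combining the EvenEventsSampler used in the benchmark tests with FixedNEntitiesSampler; the zipcode values are illustrative only:

    import pandas as pd
    from f2ai.dataset import EvenEventsSampler, FixedNEntitiesSampler

    # an even 10-day grid of event timestamps, sampled twice per illustrative zipcode group
    events_sampler = EvenEventsSampler(start="2020-08-20", end="2021-08-30", period="10 days")
    group_df = pd.DataFrame({"zipcode": ["10001", "94107"]})
    sampler = FixedNEntitiesSampler(events_sampler, group_df=group_df, n=2, random_state=42)

    entity_df = sampler()   # pd.DataFrame with zipcode and event_timestamp columns
    for row in sampler:     # or iterate lazily, one dict per sampled entity row
        print(row)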
--------------------------------------------------------------------------------
/tests/integrations/benchmarks/offline_file_benchmark_test.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import timeit
3 | from os import path
4 | from f2ai import FeatureStore
5 | from f2ai.dataset import EvenEventsSampler, FixedNEntitiesSampler, NoEntitiesSampler
6 | from f2ai.definitions import BackOffTime, Period
7 |
8 | LINE_LIMIT = 1000
9 |
10 |
11 | def get_credit_score_entities(project_folder: str):
12 | query_entities = pd.read_parquet(
13 | path.join(project_folder, "row_data/loan_table.parquet"),
14 | columns=["loan_id", "dob_ssn", "zipcode", "created_timestamp", "event_timestamp"],
15 | ).iloc[:LINE_LIMIT]
16 | return query_entities.astype(
17 | {
18 | "loan_id": "string",
19 | "dob_ssn": "string",
20 | "zipcode": "string",
21 | }
22 | )
23 |
24 |
25 | def get_guizhou_traffic_entities(project_folder: str):
26 | query_entities = pd.read_csv(
27 | path.join(project_folder, "raw_data/gy_link_travel_time.csv"),
28 | usecols=["link_id", "event_timestamp"],
29 | nrows=5,
30 | )
31 | return query_entities.astype({"link_id": "string"})
32 |
33 |
34 | def test_get_features_from_feature_view(make_credit_score):
35 | project_folder = make_credit_score("file")
36 | entity_df = get_credit_score_entities(project_folder)
37 | store = FeatureStore(project_folder)
38 | store.get_features("zipcode_features", entity_df)
39 |
40 | measured_time = timeit.timeit(lambda: store.get_features("zipcode_features", entity_df), number=10)
41 | print(f"get_features performance: {measured_time}s")
42 |
43 |
44 | def test_get_labels_from_label_views(make_credit_score):
45 | project_folder = make_credit_score("file")
46 | entity_df = get_credit_score_entities(project_folder)
47 | store = FeatureStore(project_folder)
48 | store.get_labels("loan_label_view", entity_df)
49 |
50 |
51 | def test_materialize(make_credit_score):
52 | project_folder = make_credit_score("file")
53 | store = FeatureStore(project_folder)
54 | back_off_time = BackOffTime(start="2020-08-25", end="2021-08-26", step="1 month")
55 | measured_time = timeit.timeit(
56 | lambda: store.materialize(
57 | service_or_views="credit_scoring_v1", back_off_time=back_off_time, online=False
58 | ),
59 | number=1,
60 | )
61 | print(f"materialize performance: {measured_time}s")
62 |
63 |
64 | def test_get_features_from_service(make_credit_score):
65 | """this test should run after materialize"""
66 | project_folder = make_credit_score("file")
67 | store = FeatureStore(project_folder)
68 | entity_df = get_credit_score_entities(project_folder)
69 | store.get_features("credit_scoring_v1", entity_df)
70 |
71 |
72 | def test_get_period_features_from_feature_view(make_guizhou_traffic):
73 | project_folder = make_guizhou_traffic("file")
74 | entity_df = get_guizhou_traffic_entities(project_folder)
75 | store = FeatureStore(project_folder)
76 | measured_time = timeit.timeit(
77 | lambda: store.get_period_features("gy_link_travel_time_features", entity_df, period="20 minutes"),
78 | number=10,
79 | )
80 | print(f"get_period_features performance: {measured_time}s")
81 |
82 |
83 | def test_stats_from_feature_view(make_credit_score):
84 | project_folder = make_credit_score("file")
85 | store = FeatureStore(project_folder)
86 |
87 | measured_time = timeit.timeit(lambda: store.stats("loan_features", fn="avg"), number=10)
88 | print(f"stats performance: {measured_time}s")
89 |
90 |
91 | def test_get_latest_entities_from_feature_view(make_credit_score):
92 | project_folder = make_credit_score("file")
93 | store = FeatureStore(project_folder)
94 | measured_time = timeit.timeit(
95 | lambda: store.get_latest_entities("loan_features", pd.DataFrame({"dob_ssn": ["19960703_3449"]})),
96 | number=10,
97 | )
98 | print(f"get_latest_entities performance: {measured_time}s")
99 |
100 |
101 | def test_get_latest_entity_from_feature_view(make_credit_score):
102 | project_folder = make_credit_score("file")
103 | store = FeatureStore(project_folder)
104 | measured_time = timeit.timeit(lambda: store.get_latest_entities("loan_features"), number=10)
105 | print(f"get_latest_entities performance: {measured_time}s")
106 |
107 |
108 | def test_sampler_with_groups(make_credit_score):
109 | project_folder = make_credit_score("file")
110 | store = FeatureStore(project_folder)
111 | group_df = store.stats("loan_features", group_keys=["zipcode", "dob_ssn"], fn="unique")
112 | events_sampler = EvenEventsSampler(start="2020-08-20", end="2021-08-30", period="10 days")
113 | entities_sampler = FixedNEntitiesSampler(
114 | events_sampler,
115 | group_df=group_df,
116 | )
117 | measured_time = timeit.timeit(
118 | lambda: entities_sampler(),
119 | number=1,
120 | )
121 | print(f"sampler with groups performance: {measured_time}s")
122 |
123 |
124 | def test_dataset_to_pytorch(make_credit_score):
125 | project_folder = make_credit_score("file")
126 | backoff = BackOffTime(
127 | start=pd.Timestamp("2020-05-01"), end=pd.Timestamp("2020-07-01"), step=Period.from_str("1 month")
128 | )
129 | store = FeatureStore(project_folder)
130 | store.materialize(service_or_views="credit_scoring_v1", back_off_time=backoff)
131 |
132 | events_sampler = EvenEventsSampler(start="2020-08-20", end="2021-08-30", period="10 days")
133 | ds = store.get_dataset(
134 | service="credit_scoring_v1",
135 | sampler=NoEntitiesSampler(events_sampler),
136 | )
137 | measured_time = timeit.timeit(lambda: list(ds.to_pytorch()), number=1)
138 | print(f"dataset.to_pytorch performance: {measured_time}s")
139 |
140 |
141 | def test_online_materialize(make_credit_score):
142 | project_folder = make_credit_score("file")
143 | back_off_time = BackOffTime(
144 | start=pd.Timestamp("2020-05-01"), end=pd.Timestamp("2020-07-01"), step=Period.from_str("1 month")
145 | )
146 | store = FeatureStore(project_folder)
147 | store.materialize(service_or_views="credit_scoring_v1", back_off_time=back_off_time, online=True)
148 |
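
Editorial note: the benchmarks above report timings through print(), so they are usually run with pytest's output capture disabled, e.g. pytest tests/integrations/benchmarks/offline_file_benchmark_test.py -s.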
--------------------------------------------------------------------------------
/f2ai/common/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import json
3 | import os
4 | import oss2
5 | import pandas as pd
6 | from typing import List, Tuple, Iterable, Any
7 | from pathlib import Path
8 |
9 | from ..definitions import Period
10 |
11 |
12 | class DateEncoder(json.JSONEncoder):
13 | def default(self, o):
14 | if isinstance(o, datetime.datetime):
15 | return o.strftime("%Y-%m-%dT%H:%M:%S")
16 | else:
17 | return json.JSONEncoder.default(self, o)
18 |
19 |
20 | def remove_prefix(text: str, prefix: str):
21 | return text[text.startswith(prefix) and len(prefix) :]  # slice start is len(prefix) when the prefix matches, else 0
22 |
23 |
24 | def get_default_value():
25 | return None
26 |
27 |
28 | def schema_to_dict(schema):
29 | return {item["name"]: item.get("dtype", "string") for item in schema}
30 |
31 |
32 | def read_file(
33 | path,
34 | parse_dates: List[str] = [],
35 | str_cols: List[str] = [],
36 | keep_cols: List[str] = [],
37 | file_format=None,
38 | ):
39 | path = Path(remove_prefix(path, "file://"))
40 | dtypes = {en: str for en in str_cols}
41 | usecols = list(dict.fromkeys(keep_cols + parse_dates + str_cols))
42 |
43 | if file_format is None:
44 | file_format = path.parts[-1].split(".")[-1]
45 |
46 | if path.is_dir():
47 | df = read_df_from_dataset(path, usecols=usecols).astype(dtypes)
48 | elif file_format.startswith("parq"):
49 | df = pd.read_parquet(path, columns=usecols).astype(dtypes)
50 | elif file_format.startswith("tsv"):
51 | df = pd.read_csv(path, sep="\t", parse_dates=parse_dates, dtype=dtypes, usecols=usecols)
52 | elif file_format.startswith("txt"):
53 | df = pd.read_csv(path, sep=" ", parse_dates=parse_dates, dtype=dtypes, usecols=usecols)
54 | else:
55 | df = pd.read_csv(path, parse_dates=parse_dates, dtype=dtypes, usecols=usecols)
56 |
57 | return df
58 |
59 |
60 | def to_file(file: pd.DataFrame, path, type, mode="w", header=True):
61 |
62 | path = remove_prefix(path, "file://")
63 | if type.startswith("parq"):
64 | file.to_parquet(path, index=False)
65 | elif type.startswith("tsv"):
66 | file.to_csv(path, sep="\t", index=False, mode=mode, header=header)
67 | elif type.startswith("txt"):
68 | file.to_csv(path, sep=" ", index=False, mode=mode, header=header)
69 | else:
70 | file.to_csv(path, index=False, mode=mode, header=header)
71 |
72 |
73 | def get_bucket(bucket, endpoint=None):
74 | key_id = os.environ.get("OSS_ACCESS_KEY_ID")
75 | key_secret = os.environ.get("OSS_ACCESS_KEY_SECRET")
76 | endpoint = endpoint or os.environ.get("OSS_ENDPOINT")
77 |
78 | return oss2.Bucket(oss2.Auth(key_id, key_secret), endpoint, bucket)
79 |
80 |
81 | def parse_oss_url(url: str) -> Tuple[str, str]:
82 | """
83 | url format: oss://{bucket}/{key}
84 | """
85 | url = remove_prefix(url, "oss://")
86 | components = url.split("/")
87 | return components[0], "/".join(components[1:])
88 |
89 |
90 | def get_bucket_from_oss_url(url: str):
91 | bucket_name, key = parse_oss_url(url)
92 | return get_bucket(bucket_name), key
93 |
94 |
95 | # the code below copied from https://github.com/pandas-dev/pandas/blob/91111fd99898d9dcaa6bf6bedb662db4108da6e6/pandas/io/sql.py#L1155
96 | def convert_dtype_to_sqlalchemy_type(col):
97 | from pandas._libs.lib import infer_dtype
98 | from sqlalchemy.types import (
99 | TIMESTAMP,
100 | BigInteger,
101 | Boolean,
102 | Date,
103 | DateTime,
104 | Float,
105 | Integer,
106 | SmallInteger,
107 | Text,
108 | Time,
109 | )
110 |
111 | col_type = infer_dtype(col, skipna=True)
112 |
113 | if col_type == "datetime64" or col_type == "datetime":
114 | try:
115 | if col.dt.tz is not None:
116 | return TIMESTAMP(timezone=True)
117 | except AttributeError:
118 | if getattr(col, "tz", None) is not None:
119 | return TIMESTAMP(timezone=True)
120 | return DateTime
121 |
122 | if col_type == "timedelta64":
123 | return BigInteger
124 | elif col_type == "floating":
125 | if col.dtype == "float32":
126 | return Float(precision=23)
127 | else:
128 | return Float(precision=53)
129 | elif col_type == "integer":
130 | if col.dtype.name.lower() in ("int8", "uint8", "int16"):
131 | return SmallInteger
132 | elif col.dtype.name.lower() in ("uint16", "int32"):
133 | return Integer
134 | elif col.dtype.name.lower() == "uint64":
135 | raise ValueError("Unsigned 64 bit integer datatype is not supported")
136 | else:
137 | return BigInteger
138 | elif col_type == "boolean":
139 | return Boolean
140 | elif col_type == "date":
141 | return Date
142 | elif col_type == "time":
143 | return Time
144 | elif col_type == "complex":
145 | raise ValueError("Complex datatypes not supported")
146 |
147 | return Text
148 |
149 |
150 | def write_df_to_dataset(data: pd.DataFrame, root_path: str, time_col: str, period: Period):
151 | from pyarrow import parquet, Table
152 |
153 | # create partitioning cols
154 | times = pd.to_datetime(data[time_col])
155 | components = dict()
156 | for component in period.get_pandas_datetime_components():
157 | components[f"_f2ai_{component}_"] = getattr(times.dt, component)
158 |
159 | data_with_components = pd.concat(
160 | [
161 | pd.DataFrame(components),
162 | data,
163 | ],
164 | axis=1,
165 | )
166 | table = Table.from_pandas(data_with_components)
167 | parquet.write_to_dataset(table, root_path=root_path, partition_cols=components)
168 |
169 |
170 | def read_df_from_dataset(root_path: str, usecols: List[str] = []) -> pd.DataFrame:
171 | from pyarrow.parquet import ParquetDataset
172 |
173 | ds = ParquetDataset(root_path, use_legacy_dataset=True)
174 | table = ds.read(columns=usecols)
175 | df: pd.DataFrame = table.to_pandas()
176 | drop_columns = [col_name for col_name in df.columns if col_name.startswith("_f2ai_")]
177 | return df.drop(columns=drop_columns)
178 |
179 |
180 | def batched(xs: Iterable, batch_size: int, drop_last=False):
181 | batches = []
182 | for x in xs:
183 | batches.append(x)
184 |
185 | if len(batches) == batch_size:
186 | yield batches
187 | batches = []
188 |
189 | if len(batches) > 0 and not drop_last:
190 | yield batches
191 |
192 |
193 | def is_iterable(o: Any):
194 | try:
195 | iter(o)
196 | return True
197 | except TypeError:
198 | return False
199 |
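
Two quick, self-contained examples of the helpers above (a sketch; the expected results in the comments follow directly from the code):

    from f2ai.common.utils import remove_prefix, batched

    remove_prefix("file:///tmp/loans.parquet", "file://")    # '/tmp/loans.parquet'
    remove_prefix("/tmp/loans.parquet", "file://")           # unchanged: '/tmp/loans.parquet'

    list(batched(range(5), batch_size=2))                    # [[0, 1], [2, 3], [4]]
    list(batched(range(5), batch_size=2, drop_last=True))    # [[0, 1], [2, 3]]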
--------------------------------------------------------------------------------
/f2ai/definitions/services.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Dict
2 | from pydantic import BaseModel
3 |
4 | from .entities import Entity
5 | from .features import SchemaAnchor, Feature
6 | from .feature_view import FeatureView
7 | from .label_view import LabelView
8 |
9 |
10 | class Service(BaseModel):
11 | """A Service is a combination of a group of feature views and label views, which usually directly related to a certain AI model. Tbe best practice which F2AI suggested is, treating services are immutable. Egg: if you want to train different combinations of features for a specific AI model, you may want to create multiple Services, like: linear_reg_v1, linear_reg_v2."""
12 |
13 | name: str
14 | description: Optional[str]
15 | features: List[SchemaAnchor] = []
16 | labels: List[SchemaAnchor] = []
17 | ttl: Optional[str] = None
18 |
19 | @classmethod
20 | def from_yaml(cls, cfg: Dict) -> "Service":
21 | """Construct a Service from parsed yaml config file."""
22 |
23 | cfg["features"] = SchemaAnchor.from_strs(cfg.pop("features", []))
24 | cfg["labels"] = SchemaAnchor.from_strs(cfg.pop("labels", []))
25 |
26 | return cls(**cfg)
27 |
28 | def get_feature_names(self, feature_views: Dict[str, FeatureView], is_numeric=False) -> List[str]:
29 | return [feature.name for feature in self.get_feature_objects(feature_views, is_numeric)]
30 |
31 | def get_label_names(self, label_views: Dict[str, LabelView], is_numeric=False) -> List[str]:
32 | return [label.name for label in self.get_label_objects(label_views, is_numeric)]
33 |
34 | def get_feature_objects(self, feature_views: Dict[str, FeatureView], is_numeric=False) -> List[Feature]:
35 | """get all the feature objects which included in this service based on features' schema anchor.
36 |
37 | Args:
38 | feature_views (Dict[str, FeatureView]): A group of FeatureViews.
39 | is_numeric (bool, optional): If only include numeric features. Defaults to False.
40 |
41 | Returns:
42 | List[Feature]
43 | """
44 | return list(
45 | dict.fromkeys(
46 | feature
47 | for schema_anchor in self.features
48 | for feature in schema_anchor.get_features_from_views(feature_views, is_numeric)
49 | )
50 | )
51 |
52 | def get_label_objects(self, label_views: Dict[str, LabelView], is_numeric=False) -> List[Feature]:
53 | """get all the label objects which included in this service based on labels' schema anchor.
54 |
55 | Args:
56 | feature_views (Dict[str, LabelView]): A group of LabelViews.
57 | is_numeric (bool, optional): If only include numeric labels. Defaults to False.
58 |
59 | Returns:
60 | List[Feature]
61 | """
62 | return list(
63 | dict.fromkeys(
64 | label
65 | for schema_anchor in self.labels
66 | for label in schema_anchor.get_features_from_views(label_views, is_numeric)
67 | )
68 | )
69 |
70 | def get_feature_view_names(self, feature_views: Dict[str, FeatureView]) -> List[str]:
71 | """
72 | Get the names of the feature views related to this service.
73 |
74 | Args:
75 | feature_views (Dict[str, FeatureView]): list of FeatureViews to filter.
76 |
77 | Returns:
78 | List[str]: names of FeatureView
79 | """
80 | feature_view_names = list(dict.fromkeys([anchor.view_name for anchor in self.features]))
81 | return [x for x in feature_view_names if x in feature_views]
82 |
83 | def get_feature_views(self, feature_views: Dict[str, FeatureView]) -> List[FeatureView]:
84 | """Get FeatureViews of this service. This will automatically filter out the feature view not given by parameters.
85 |
86 | Args:
87 | feature_views (Dict[str, FeatureView])
88 |
89 | Returns:
90 | List[FeatureView]
91 | """
92 | feature_view_names = self.get_feature_view_names(feature_views)
93 | return [feature_views[feature_view_name] for feature_view_name in feature_view_names]
94 |
95 | def get_label_views(self, label_views: Dict[str, LabelView]) -> List[LabelView]:
96 | """Get LabelViews of this service. This will automatically filter out the label view not given by parameters.
97 |
98 | Args:
99 | label_views (Dict[str, LabelView])
100 |
101 | Returns:
102 | List[LabelView]
103 | """
104 | label_view_names = list(dict.fromkeys([anchor.view_name for anchor in self.labels]))
105 | return [label_views[label_view_name] for label_view_name in label_view_names]
106 |
107 | def get_feature_entities(self, feature_views: Dict[str, FeatureView]) -> List[Entity]:
108 | """Get all entities which appeared in related feature views to this service and without duplicate entity.
109 |
110 | Args:
111 | feature_views (Dict[str, FeatureView])
112 |
113 | Returns:
114 | List[Entity]
115 | """
116 | return list(
117 | dict.fromkeys(
118 | entity
119 | for feature_view in self.get_feature_views(feature_views)
120 | for entity in feature_view.entities
121 | )
122 | )
123 |
124 | def get_label_entities(self, label_views: Dict[str, LabelView]) -> List[str]:
125 | """Get all entities which appeared in related label views to this service and without duplicate entity.
126 |
127 | Args:
128 | label_views (Dict[str, LabelView])
129 |
130 | Returns:
131 | List[str]
132 | """
133 | return list(
134 | dict.fromkeys(
135 | entity for label_view in self.get_label_views(label_views) for entity in label_view.entities
136 | )
137 | )
138 |
139 | def get_entities(
140 | self, feature_views: Dict[str, FeatureView], label_views: Dict[str, LabelView]
141 | ) -> List[str]:
142 | """Get all entities which appeared in this service and without duplicate entity.
143 |
144 | Args:
145 | feature_views (Dict[str, FeatureView])
146 | label_views (Dict[str, LabelView])
147 |
148 | Returns:
149 | List[str]
150 | """
151 | return list(
152 | dict.fromkeys(self.get_feature_entities(feature_views) + self.get_label_entities(label_views))
153 | )
154 |
155 | def get_join_keys(
156 | self,
157 | feature_views: Dict[str, FeatureView],
158 | label_views: Dict[str, LabelView],
159 | entities: Dict[str, Entity],
160 | ) -> List[str]:
161 | return list(
162 | dict.fromkeys(
163 | [
164 | join_key
165 | for x in self.get_entities(feature_views, label_views)
166 | for join_key in entities[x].join_keys
167 | ]
168 | )
169 | )
170 |
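
A note on a recurring idiom in this module: list(dict.fromkeys(...)) is an order-preserving de-duplication, keeping the first occurrence of each feature, view name, or entity. A minimal illustration:

    # dict keys preserve insertion order (Python 3.7+), so this dedupes while keeping order
    list(dict.fromkeys(["loan", "zipcode", "loan"]))   # ['loan', 'zipcode']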
--------------------------------------------------------------------------------
/f2ai/online_stores/online_redis_store.py:
--------------------------------------------------------------------------------
1 | import json
2 | import uuid
3 | from datetime import datetime
4 | from typing import Optional, Dict, List, Union
5 | import functools
6 | import pandas as pd
7 | from pydantic import PrivateAttr
8 | from redis import Redis, ConnectionPool
9 |
10 | from ..common.utils import DateEncoder
11 | from ..definitions import OnlineStore, OnlineStoreType, Period, FeatureView, Service, Entity
12 | from ..common.time_field import DEFAULT_EVENT_TIMESTAMP_FIELD, QUERY_COL
13 |
14 |
15 | class OnlineRedisStore(OnlineStore):
16 | type: OnlineStoreType = OnlineStoreType.REDIS
17 | host: str = "localhost"
18 | port: int = 6379
19 | db: int = 0
20 | password: str = ""
21 | name: str
22 |
23 | _client: Optional[Redis] = PrivateAttr(default=None)
24 | 
25 | @property
26 | def client(self):
27 | if self._client is None:
28 | pool = ConnectionPool(
29 | host=self.host, port=self.port, db=self.db, password=self.password, decode_responses=True
30 | )
31 | self._client = Redis(connection_pool=pool)
32 | 
33 | return self._client
34 |
35 | def write_batch(
36 | self,
37 | name: str,
38 | project_name: str,
39 | dt: pd.DataFrame,
40 | ttl: Optional[Period] = None,
41 | join_keys: List[str] = None,
42 | tz: str = None,
43 | ):
44 | pipe = self.client.pipeline()
45 | if not dt.empty:
46 | for group_data in dt.groupby(join_keys):
47 | all_entities = functools.reduce(
48 | lambda x, y: f"{x},{y}",
49 | list(
50 | map(
51 | lambda x, y: x + ":" + y,
52 | join_keys,
53 | [group_data[0]] if isinstance(group_data[0], str) else list(group_data[0]),
54 | )
55 | ),
56 | )
57 | if self.client.hget(f"{project_name}:{name}", all_entities) is None:
58 | zset_key = uuid.uuid4().hex
59 | pipe.hset(name=f"{project_name}:{name}", key=all_entities, value=zset_key)
60 | else:
61 | zset_key = self.client.hget(name=f"{project_name}:{name}", key=all_entities)
62 | # remove data that has expired in `zset` according to `score`
63 | if ttl is not None:
64 | pipe.zremrangebyscore(
65 | name=zset_key,
66 | min=0,
67 | max=(pd.Timestamp(datetime.now(), tz=tz) - ttl.to_py_timedelta()).timestamp(),
68 | )
69 | zset_dict = {
70 | json.dumps(row, cls=DateEncoder): pd.to_datetime(
71 | row.get(DEFAULT_EVENT_TIMESTAMP_FIELD, pd.Timestamp(datetime.now(), tz=tz)),
72 | ).timestamp()
73 | for row in group_data[1]
74 | .drop(columns=[QUERY_COL], errors="ignore")
75 | .to_dict(orient="records")
76 | }
77 | pipe.zadd(name=zset_key, mapping=zset_dict)
78 | if ttl is not None:  # add a general expire constraint on the zset key
79 | expire_time = group_data[1][DEFAULT_EVENT_TIMESTAMP_FIELD].max() + ttl.to_py_timedelta()
80 | pipe.expireat(zset_key, expire_time)
81 | pipe.execute()
82 |
83 | def read_batch(
84 | self,
85 | entity_df: pd.DataFrame,
86 | project_name: str,
87 | view: Union[Service, FeatureView],
88 | feature_views: Dict[str, FeatureView],
89 | entities: Dict[str, Entity],
90 | join_keys: List[str],
91 | **kwargs,
92 | ):
93 | if isinstance(view, FeatureView):
94 | data = self._read_batch(
95 | hkey=f"{project_name}:{view.name}",
96 | ttl=view.ttl,
97 | period=None,
98 | entity_df=entity_df[join_keys],
99 | **kwargs,
100 | )
101 | entity_df = pd.merge(entity_df, data, on=join_keys, how="inner") if not data.empty else None
102 | elif isinstance(view, Service):
103 | for featureview in view.features:
104 | fea_entities = functools.reduce(
105 | lambda x, y: x + y,
106 | [entities[entity].join_keys for entity in feature_views[featureview.view_name].entities],
107 | )
108 | fea_join_keys = [join_key for join_key in join_keys if join_key in fea_entities]
109 | feature_view_batch = self._read_batch(
110 | hkey=f"{project_name}:{featureview.view_name}",
111 | ttl=feature_views[featureview.view_name].ttl,
112 | entity_df=entity_df[fea_join_keys],
113 | period=featureview.period,
114 | **kwargs,
115 | )
116 | if not feature_view_batch.empty:
117 | entity_df = feature_view_batch.merge(entity_df, on=fea_join_keys, how="inner")
118 | else:
119 | raise TypeError("online read only allow FeatureView and Service")
120 |
121 | return entity_df
122 |
123 | def _read_batch(
124 | self,
125 | hkey: str,
126 | ttl: Optional[Period] = None,
127 | period: Optional[Period] = None,
128 | entity_df: pd.DataFrame = None,
129 | ) -> pd.DataFrame:
130 |
131 | min_timestamp = max(
132 | (pd.Timestamp(datetime.now()) - period.to_pandas_dateoffset() if period else pd.Timestamp(0)),
133 | (pd.Timestamp(datetime.now()) - ttl.to_pandas_dateoffset() if ttl else pd.Timestamp(0)),
134 | )
135 | dt_group = entity_df.groupby(list(entity_df.columns))
136 | all_entities = [
137 | functools.reduce(
138 | lambda x, y: f"{x},{y}",
139 | list(
140 | map(
141 | lambda x, y: x + ":" + y,
142 | list(entity_df.columns),
143 | [group_data[0]] if isinstance(group_data[0], str) else list(group_data[0]),
144 | )
145 | ),
146 | )
147 | for group_data in dt_group
148 | ]
149 | all_zset_key = self.client.hmget(hkey, all_entities)
150 | if all_zset_key:
151 | data = [ # newest record
152 | self.client.zrevrangebyscore(
153 | name=zset_key,
154 | min=min_timestamp.timestamp(),
155 | max=datetime.now().timestamp(),
156 | withscores=False,
157 | start=0,
158 | num=1,
159 | )
160 | for zset_key in all_zset_key
161 | if zset_key
162 | ]
163 | columns = list(json.loads(data[0][0]).keys())
164 | batch_data_list = [[json.loads(data[i][0])[key] for key in columns] for i in range(len(data))]
165 | data = pd.DataFrame(data=batch_data_list, columns=columns)
166 | else:
167 | data = pd.DataFrame()
168 | return data
169 |
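
A sketch of the Redis layout implied by write_batch and _read_batch above (keys and values are illustrative only): each project/view pair owns one hash mapping a serialized entity key to the name of a sorted set, and each sorted set stores JSON-encoded rows scored by their event timestamp.

    # HSET  "<project>:<feature_view>"  "zipcode:10001,dob_ssn:19960703_3449"  "<uuid hex>"
    # ZADD  "<uuid hex>"  {'{"zipcode": "10001", "event_timestamp": "2021-08-25T20:16:20", ...}': 1629922580.0}
    #
    # _read_batch resolves entity strings to sorted-set names with HMGET, then takes the
    # newest row per entity via ZREVRANGEBYSCORE(..., start=0, num=1).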
--------------------------------------------------------------------------------
/f2ai/definitions/offline_store.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import abc
3 | import pandas as pd
4 | import datetime
5 | import os
6 | from enum import Enum
7 | from typing import Dict, Any, TYPE_CHECKING, Set, List, Optional, Union
8 | from pydantic import BaseModel
9 |
10 | if TYPE_CHECKING:
11 | from .services import Service
12 | from .sources import Source
13 | from .features import Feature
14 | from .period import Period
15 | from .constants import StatsFunctions
16 |
17 |
18 | class OfflineStoreType(str, Enum):
19 | """A constant numerate choices which is used to indicate how to initialize OfflineStore from configuration. If you want to add a new type of offline store, you definitely want to modify this."""
20 |
21 | FILE = "file"
22 | PGSQL = "pgsql"
23 | SPARK = "spark"
24 |
25 |
26 | class OfflineStore(BaseModel):
27 | """An abstraction of what functionalities a OfflineStore should implements. If you want to be a one of the offline store contributor. This is the core."""
28 |
29 | type: OfflineStoreType
30 | materialize_path: Optional[str]
31 |
32 | class Config:
33 | extra = "allow"
34 |
35 | @abc.abstractmethod
36 | def get_offline_source(self, service: Service) -> Source:
37 | """get offline materialized source with a specific service
38 |
39 | Args:
40 | service (Service): an instance of Service
41 |
42 | Returns:
43 | Source
44 | """
45 | pass
46 |
47 | @abc.abstractmethod
48 | def get_features(
49 | self,
50 | entity_df: pd.DataFrame,
51 | features: Set[Feature],
52 | source: Source,
53 | join_keys: List[str] = [],
54 | ttl: Optional[Period] = None,
55 | include: bool = True,
56 | **kwargs,
57 | ) -> pd.DataFrame:
58 | """get features from current offline store.
59 |
60 | Args:
61 | entity_df (pd.DataFrame): A query DataFrame which includes entities and an event_timestamp column.
62 | features (Set[Feature]): A set of Features you want to retrieve.
63 | source (Source): A specific implementation of Source. For example, OfflinePostgresStore will receive a SqlSource which points to a table with time semantics.
64 | join_keys (List[str], optional): Which columns to join the entity_df with the source. Defaults to [].
65 | ttl (Optional[Period], optional): Time to Live; if a feature's event_timestamp exceeds the ttl, it will be dropped. Defaults to None.
66 | include (bool, optional): Whether to include (<=) the event_timestamp in entity_df, otherwise strictly before (<). Defaults to True.
67 |
68 | Returns:
69 | pd.DataFrame
70 | """
71 | pass
72 |
73 | @abc.abstractmethod
74 | def get_period_features(
75 | self,
76 | entity_df: pd.DataFrame,
77 | features: List[Feature],
78 | source: Source,
79 | period: Period,
80 | join_keys: List[str] = [],
81 | ttl: Optional[Period] = None,
82 | include: bool = True,
83 | **kwargs,
84 | ) -> pd.DataFrame:
85 | """get a period of features from offline store between [event_timestamp, event_timestamp + period] if period > 0 else [event_timestamp + period, event_timestamp].
86 |
87 | Args:
88 | entity_df (pd.DataFrame): A query DataFrame which includes entities and an event_timestamp column.
89 | features (List[Feature]): A list of Features you want to retrieve.
90 | source (Source): A specific implementation of Source. For example, OfflinePostgresStore will receive a SqlSource which points to a table with time semantics.
91 | period (Period): A Period instance, wrapped by F2AI.
92 | join_keys (List[str], optional): Which columns to join the entity_df with the source. Defaults to [].
93 | ttl (Optional[Period], optional): Time to Live; if a feature's event_timestamp exceeds the ttl, it will be dropped. Defaults to None.
94 | include (bool, optional): Whether to include (<=) the event_timestamp in entity_df, otherwise strictly before (<). Defaults to True.
95 |
96 | Returns:
97 | pd.DataFrame
98 | """
99 | pass
100 |
101 | @abc.abstractmethod
102 | def get_latest_entities(
103 | self,
104 | source: Source,
105 | join_keys: List[str] = None,
106 | group_keys: List[str] = None,
107 | entity_df: pd.DataFrame = None,
108 | start: datetime.datetime = None,
109 | ) -> pd.DataFrame:
110 | """get latest unique entities from a source. Which is useful when you want to know how many entities you have, or what is the latest features appear in your data source.
111 |
112 | Args:
113 | source (Source): A specific implementation of Source. For example, OfflinePostgresStore will receive a SqlSource which point to table with time semantic.
114 | join_keys (List[str], optional): Which columns to join the entity_df with source. Defaults to None.
115 | group_keys (List[str], optional): Which columns to aggregate
116 | entity_df (pd.DataFrame, optional): A query DataFrame which include entities and event_timestamp column. Defaults to None.
117 |
118 | Returns:
119 | pd.DataFrame
120 | """
121 | pass
122 |
123 | @abc.abstractmethod
124 | def stats(
125 | self,
126 | source: Source,
127 | features: Set[Feature],
128 | fn: StatsFunctions,
129 | group_keys: List[str] = [],
130 | start: datetime.datetime = None,
131 | end: datetime.datetime = None,
132 | ) -> Union[pd.DataFrame, Dict[str, list]]:
133 | """Get statistical information with given StatsFunctions.
134 |
135 | Args:
136 | source (Source): A specific implementation of Source. For example, OfflinePostgresStore
137 | features (Set[Feature]): A set of Features you want stats on.
138 | fn (StatsFunctions): A stats function, which contains min, max, std, avg, mode, median, unique.
139 | group_keys (List[str], optional): How to group by. Defaults to [].
140 | start (datetime.datetime, optional): Defaults to None.
141 | end (datetime.datetime, optional): Defaults to None.
142 |
143 | Returns:
144 | Union[pd.DataFrame, Dict[str, list]]: Return Dict when unique, otherwise pd.DataFrame
145 | """
146 | pass
147 |
148 | @abc.abstractmethod
149 | def query(self, query: str, *args, **kwargs) -> Any:
150 | """
151 | Run a query against the specific offline store, e.g.:
152 | if you are using pgsql, this will run the query via psycopg2;
153 | if you are using spark, this will run the query via Spark SQL.
154 | """
155 | pass
156 |
157 |
158 | def init_offline_store_from_cfg(cfg: Dict[str, Any], project_name: str) -> OfflineStore:
159 | """Initialize an implementation of OfflineStore from yaml config.
160 |
161 | Args:
162 | cfg (Dict[str, Any]): a parsed config object; project_name (str): used for the default materialize path.
163 |
164 | Returns:
165 | OfflineStore: Different types of OfflineStore.
166 | """
167 | offline_store_type = OfflineStoreType(cfg["type"])
168 |
169 | if offline_store_type == OfflineStoreType.FILE:
170 | from ..offline_stores.offline_file_store import OfflineFileStore
171 |
172 | offline_store = OfflineFileStore(**cfg)
173 | if offline_store.materialize_path is None:
174 | offline_store.materialize_path = os.path.join(os.path.expanduser("~"), '.f2ai', project_name)
175 |
176 | return offline_store
177 |
178 | if offline_store_type == OfflineStoreType.PGSQL:
179 | from ..offline_stores.offline_postgres_store import OfflinePostgresStore
180 |
181 | pgsql_conf = cfg.pop("pgsql_conf", {})
182 | offline_store = OfflinePostgresStore(**cfg, **pgsql_conf)
183 | if offline_store.materialize_path is None:
184 | offline_store.materialize_path = f'{offline_store.database}.{offline_store.db_schema}'
185 |
186 | return offline_store
187 |
188 | if offline_store_type == OfflineStoreType.SPARK:
189 | from ..offline_stores.offline_spark_store import OfflineSparkStore
190 |
191 | return OfflineSparkStore(**cfg)
192 |
193 | raise TypeError(f"offline store type must be one of [{','.join(e.value for e in OfflineStoreType)}]")
194 |
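
A minimal usage sketch (assuming the file-backed store needs no extra configuration keys): initialize an OfflineStore from a parsed config dict; for the file type, materialize_path falls back to ~/.f2ai/<project_name> when not configured.

    from f2ai.definitions.offline_store import OfflineStoreType, init_offline_store_from_cfg

    store = init_offline_store_from_cfg({"type": "file"}, project_name="credit_score")
    assert store.type == OfflineStoreType.FILE
    print(store.materialize_path)   # e.g. /home/<user>/.f2ai/credit_score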
--------------------------------------------------------------------------------
/f2ai/definitions/persist_engine.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | import abc
3 | import os
4 | from typing import Dict, List, Optional
5 | from pydantic import BaseModel
6 | from enum import Enum
7 | from tqdm import tqdm
8 | from multiprocessing import Pool
9 | from f2ai.definitions import (
10 | OfflineStoreType,
11 | OfflineStore,
12 | OnlineStore,
13 | Entity,
14 | Service,
15 | FeatureView,
16 | LabelView,
17 | BackOffTime,
18 | Source,
19 | Feature,
20 | Period,
21 | )
22 |
23 |
24 | class PersistFeatureView(BaseModel):
25 | """
26 | A resolved feature view, usually used by the persist engines when finalizing (materializing) the results.
27 | """
28 |
29 | name: str
30 | source: Source
31 | features: List[Feature] = []
32 | join_keys: List[str] = []
33 | ttl: Optional[Period]
34 |
35 |
36 | class PersistLabelView(BaseModel):
37 | """
38 | A resolved label view, usually used by the persist engines when finalizing (materializing) the results.
39 | """
40 |
41 | source: Source
42 | labels: List[Feature] = []
43 | join_keys: List[str] = []
44 |
45 |
46 | class PersistEngineType(str, Enum):
47 | """A constant numerate choices which is used to indicate how to initialize PersistEngine from configuration."""
48 |
49 | OFFLINE = "offline"
50 | ONLINE = "online"
51 |
52 |
53 | class OfflinePersistEngineType(str, Enum):
54 | """A constant numerate choices which is used to indicate how to initialize PersistEngine from configuration."""
55 |
56 | FILE = "file"
57 | PGSQL = "pgsql"
58 | SPARK = "spark"
59 |
60 |
61 | class OnlinePersistEngineType(str, Enum):
62 | """A constant numerate choices which is used to indicate how to initialize PersistEngine from configuration."""
63 |
64 | LOCAL = "local"
65 | DISTRIBUTE = "distribute"
66 |
67 |
68 | class PersistEngine(BaseModel):
69 | type: PersistEngineType
70 |
71 | class Config:
72 | extra = "allow"
73 |
74 |
75 | class OfflinePersistEngine(PersistEngine):
76 | type: OfflinePersistEngineType
77 | offline_store: OfflineStore
78 |
79 | class Config:
80 | extra = "allow"
81 |
82 | @abc.abstractmethod
83 | def materialize(
84 | self,
85 | feature_views: List[PersistFeatureView],
86 | label_view: PersistLabelView,
87 | destination: Source,
88 | back_off_time: BackOffTime,
89 | ):
90 | pass
91 |
92 |
93 | class OnlinePersistEngine(PersistEngine):
94 | type: OnlinePersistEngineType
95 | online_store: OnlineStore
96 | offline_store: OfflineStore
97 |
98 | class Config:
99 | extra = "allow"
100 |
101 | @abc.abstractmethod
102 | def materialize(self, prefix: str, feature_view: PersistFeatureView, back_off_time: BackOffTime):
103 | pass
104 |
105 |
106 | class RealPersistEngine(BaseModel):
107 | offline_engine: OfflinePersistEngine
108 | online_engine: OnlinePersistEngine
109 |
110 | def materialize_offline(
111 | self,
112 | services: List[Service],
113 | label_views: Dict[str, LabelView],
114 | feature_views: Dict[str, FeatureView],
115 | entities: Dict[str, Entity],
116 | sources: Dict[str, Source],
117 | back_off_time: BackOffTime,
118 | ):
119 | cpu_ava = max(os.cpu_count() // 2, 1)
120 |
121 | # with Pool(processes=cpu_ava) as pool:
122 | service_to_list_of_args = []
123 | back_off_segments = list(back_off_time.to_units())
124 | for service in services:
125 | destination = self.offline_engine.offline_store.get_offline_source(service)
126 | label_view = service.get_label_views(label_views)[0]
127 | label_view = PersistLabelView(
128 | source=sources[label_view.batch_source],
129 | labels=label_view.get_label_objects(),
130 | join_keys=[
131 | join_key
132 | for entity_name in label_view.entities
133 | for join_key in entities[entity_name].join_keys
134 | ],
135 | )
136 | label_names = set([label.name for label in label_view.labels])
137 |
138 | persist_feature_views = [  # do not shadow the feature_views argument; it is reused for each service
139 | PersistFeatureView(
140 | name=feature_view.name,
141 | join_keys=[
142 | join_key
143 | for entity_name in feature_view.entities
144 | for join_key in entities[entity_name].join_keys
145 | ],
146 | features=[
147 | feature
148 | for feature in feature_view.get_feature_objects()
149 | if feature.name not in label_names
150 | ],
151 | source=sources[feature_view.batch_source],
152 | ttl=feature_view.ttl,
153 | )
154 | for feature_view in service.get_feature_views(feature_views)
155 | ]
156 | persist_feature_views = [feature_view for feature_view in persist_feature_views if len(feature_view.features) > 0]
157 | service_to_list_of_args += [
158 | (persist_feature_views, label_view, destination, per_backoff, service.name)
159 | for per_backoff in back_off_segments
160 | ]
161 |
162 | bars = {
163 | service.name: tqdm(
164 | total=len(back_off_segments), desc=f"materializing {service.name}", position=pos
165 | )
166 | for pos, service in enumerate(services)
167 | }
168 |
169 | pool = Pool(processes=cpu_ava)
170 |
171 | for param in service_to_list_of_args:
172 | pool.apply_async(
173 | self.offline_engine.materialize,
174 | param,
175 | callback=lambda x: bars[x].update(),
176 | error_callback=lambda x: print(x.__cause__),
177 | )
178 |
179 | pool.close()
180 | pool.join()
181 | [bars[bar].close() for bar in bars]
182 |
183 | def materialize_online(
184 | self,
185 | prefix: str,
186 | feature_views: List[FeatureView],
187 | entities: Dict[str, Entity],
188 | sources: Dict[str, Source],
189 | back_off_time: BackOffTime,
190 | ):
191 | cpu_ava = max(os.cpu_count() // 2, 1)
192 |
193 | feature_backoffs = []
194 | back_off_segments = list(back_off_time.to_units())
195 | for feature_view in feature_views:
196 | join_keys = list(
197 | {
198 | join_key
199 | for entity_name in feature_view.entities
200 | for join_key in entities[entity_name].join_keys
201 | }
202 | )
203 | feature_view = PersistFeatureView(
204 | name=feature_view.name,
205 | join_keys=join_keys,
206 | features=feature_view.get_feature_objects(),
207 | source=sources[feature_view.batch_source],
208 | ttl=feature_view.ttl,
209 | )
210 | feature_backoffs += [
211 | (prefix, feature_view, per_backoff, feature_view.name) for per_backoff in back_off_segments
212 | ]
213 |
214 | bars = {
215 | feature_view.name: tqdm(
216 | total=len(back_off_segments), desc=f"materializing {feature_view.name}", position=pos
217 | )
218 | for pos, feature_view in enumerate(feature_views)
219 | }
220 |
221 | pool = Pool(processes=cpu_ava)
222 |
223 | for param in feature_backoffs:
224 | pool.apply_async(
225 | self.online_engine.materialize,
226 | param,
227 | callback=lambda x: bars[x].update(),
228 | error_callback=lambda x: print(x.__cause__),
229 | )
230 |
231 | pool.close()
232 | pool.join()
233 | [bars[bar].close() for bar in bars]
234 |
235 |
236 | def init_persist_engine_from_cfg(offline_store: OfflineStore, online_store: OnlineStore):
237 | """Initialize an implementation of PersistEngine from yaml config.
238 |
239 | Args:
240 | cfg (Dict[Any]): a parsed config object.
241 |
242 | Returns:
243 | RealPersistEngine: Contains Offline and Online PersistEngine
244 | """
245 |
246 | if offline_store.type == OfflineStoreType.FILE:
247 | from ..persist_engine.offline_file_persistengine import OfflineFilePersistEngine
248 | from ..persist_engine.online_local_persistengine import OnlineLocalPersistEngine
249 |
250 | offline_persist_engine_cls = OfflineFilePersistEngine
251 | online_persist_engine_cls = OnlineLocalPersistEngine
252 |
253 | elif offline_store.type == OfflineStoreType.PGSQL:
254 | from ..persist_engine.offline_pgsql_persistengine import OfflinePgsqlPersistEngine
255 | from ..persist_engine.online_local_persistengine import OnlineLocalPersistEngine
256 |
257 | offline_persist_engine_cls = OfflinePgsqlPersistEngine
258 | online_persist_engine_cls = OnlineLocalPersistEngine
259 |
260 | elif offline_store.type == OfflineStoreType.SPARK:
261 | from ..persist_engine.offline_spark_persistengine import OfflineSparkPersistEngine
262 | from ..persist_engine.online_spark_persistengine import OnlineSparkPersistEngine
263 |
264 | offline_persist_engine_cls = OfflineSparkPersistEngine
265 | online_persist_engine_cls = OnlineSparkPersistEngine
266 |
267 | offline_engine = offline_persist_engine_cls(offline_store=offline_store)
268 | online_engine = online_persist_engine_cls(offline_store=offline_store, online_store=online_store)
269 | return RealPersistEngine(offline_engine=offline_engine, online_engine=online_engine)
270 |
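
A usage sketch (the store objects and definition dictionaries are assumed to have been loaded elsewhere, e.g. by FeatureStore): wire the two persist engines from already-initialized stores and materialize one service offline over a back-off window split into one-month units.

    from f2ai.definitions import BackOffTime

    engine = init_persist_engine_from_cfg(offline_store, online_store)  # assumed pre-initialized stores
    engine.materialize_offline(
        services=[services["credit_scoring_v1"]],   # service name taken from the benchmark examples
        label_views=label_views,
        feature_views=feature_views,
        entities=entities,
        sources=sources,
        back_off_time=BackOffTime(start="2020-08-25", end="2021-08-26", step="1 month"),
    )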
--------------------------------------------------------------------------------
/tests/units/offline_stores/offline_file_store_test.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from f2ai.offline_stores.offline_file_store import OfflineFileStore
3 | from f2ai.definitions import Period, FileSource, Feature, FeatureDTypes, StatsFunctions
4 | from f2ai.common.time_field import ENTITY_EVENT_TIMESTAMP_FIELD, SOURCE_EVENT_TIMESTAMP_FIELD
5 |
6 | import pytest
7 | from unittest.mock import MagicMock
8 |
9 | mock_point_in_time_filter_df = pd.DataFrame(
10 | {
11 | "join_key": ["A", "A", "A", "A"],
12 | ENTITY_EVENT_TIMESTAMP_FIELD: [
13 | pd.Timestamp("2021-08-25 20:16:20"),
14 | pd.Timestamp("2021-08-25 20:16:20"),
15 | pd.Timestamp("2021-08-25 20:16:20"),
16 | pd.Timestamp("2021-08-25 20:16:20"),
17 | ],
18 | SOURCE_EVENT_TIMESTAMP_FIELD: [
19 | pd.Timestamp("2021-08-25 20:16:18"),
20 | pd.Timestamp("2021-08-25 20:16:19"),
21 | pd.Timestamp("2021-08-25 20:16:20"),
22 | pd.Timestamp("2021-08-25 20:16:21"),
23 | ],
24 | },
25 | )
26 |
27 |
28 | def test_point_in_time_filter_simple():
29 | result_df = OfflineFileStore._point_in_time_filter(mock_point_in_time_filter_df)
30 | assert len(result_df) == 3
31 | assert result_df["_source_event_timestamp_"].max() == pd.Timestamp("2021-08-25 20:16:20")
32 |
33 |
34 | def test_point_in_time_filter_not_include():
35 | result_df = OfflineFileStore._point_in_time_filter(mock_point_in_time_filter_df, include=False)
36 | assert len(result_df) == 2
37 | assert result_df["_source_event_timestamp_"].max() == pd.Timestamp("2021-08-25 20:16:19")
38 |
39 |
40 | def test_point_in_time_filter_with_ttl():
41 | result_df = OfflineFileStore._point_in_time_filter(
42 | mock_point_in_time_filter_df, ttl=Period.from_str("2 seconds")
43 | )
44 | assert len(result_df) == 2
45 | assert result_df["_source_event_timestamp_"].max() == pd.Timestamp("2021-08-25 20:16:20")
46 | assert result_df["_source_event_timestamp_"].min() == pd.Timestamp("2021-08-25 20:16:19")
47 |
48 |
49 | def test_point_on_time_filter_simple():
50 | result_df = OfflineFileStore._point_on_time_filter(
51 | mock_point_in_time_filter_df, -Period.from_str("2 seconds"), include=True
52 | )
53 | assert len(result_df) == 2
54 | assert result_df["_source_event_timestamp_"].max() == pd.Timestamp("2021-08-25 20:16:20")
55 | assert result_df["_source_event_timestamp_"].min() == pd.Timestamp("2021-08-25 20:16:19")
56 |
57 |
58 | def test_point_on_time_filter_simple_label():
59 | result_df = OfflineFileStore._point_on_time_filter(
60 | mock_point_in_time_filter_df, Period.from_str("2 seconds"), include=True
61 | )
62 | assert len(result_df) == 2
63 | assert result_df["_source_event_timestamp_"].max() == pd.Timestamp("2021-08-25 20:16:21")
64 | assert result_df["_source_event_timestamp_"].min() == pd.Timestamp("2021-08-25 20:16:20")
65 |
66 |
67 | def test_point_on_time_filter_not_include():
68 | result_df = OfflineFileStore._point_on_time_filter(
69 | mock_point_in_time_filter_df, period=-Period.from_str("2 seconds"), include=False
70 | )
71 | assert len(result_df) == 2
72 | assert result_df["_source_event_timestamp_"].max() == pd.Timestamp("2021-08-25 20:16:19")
73 | assert result_df["_source_event_timestamp_"].min() == pd.Timestamp("2021-08-25 20:16:18")
74 |
75 |
76 | def test_point_on_time_filter_not_include_label():
77 | result_df = OfflineFileStore._point_on_time_filter(
78 | mock_point_in_time_filter_df, Period.from_str("2 seconds"), include=False
79 | )
80 | assert len(result_df) == 1
81 | assert result_df["_source_event_timestamp_"].max() == pd.Timestamp("2021-08-25 20:16:21")
82 |
83 |
84 | def test_point_on_time_filter_with_ttl():
85 | result_df = OfflineFileStore._point_on_time_filter(
86 | mock_point_in_time_filter_df,
87 | period=-Period.from_str("3 seconds"),
88 | ttl=Period.from_str("2 seconds"),
89 | include=True,
90 | )
91 | assert len(result_df) == 2
92 | assert result_df["_source_event_timestamp_"].max() == pd.Timestamp("2021-08-25 20:16:20")
93 | assert result_df["_source_event_timestamp_"].min() == pd.Timestamp("2021-08-25 20:16:19")
94 |
95 |
96 | mock_point_in_time_latest_df = pd.DataFrame(
97 | {
98 | "join_key": ["A", "A", "B", "B"],
99 | ENTITY_EVENT_TIMESTAMP_FIELD: [
100 | pd.Timestamp("2021-08-25 20:16:20"),
101 | pd.Timestamp("2021-08-25 20:16:20"),
102 | pd.Timestamp("2021-08-25 20:16:20"),
103 | pd.Timestamp("2021-08-25 20:16:20"),
104 | ],
105 | SOURCE_EVENT_TIMESTAMP_FIELD: [
106 | pd.Timestamp("2021-08-25 20:16:18"),
107 | pd.Timestamp("2021-08-25 20:16:19"),
108 | pd.Timestamp("2021-08-25 20:16:11"),
109 | pd.Timestamp("2021-08-25 20:16:20"),
110 | ],
111 | },
112 | )
113 |
114 |
115 | def test_point_in_time_latest_with_group_keys():
116 | result_df = OfflineFileStore._point_in_time_latest(mock_point_in_time_latest_df, ["join_key"])
117 |
118 | df_a = result_df[result_df["join_key"] == "A"].iloc[0]
119 | df_b = result_df[result_df["join_key"] == "B"].iloc[0]
120 |
121 | assert df_a["_source_event_timestamp_"] == pd.Timestamp("2021-08-25 20:16:19")
122 | assert df_b["_source_event_timestamp_"] == pd.Timestamp("2021-08-25 20:16:20")
123 |
124 |
125 | def test_point_in_time_latest_without_group_keys():
126 | result_df = OfflineFileStore._point_in_time_latest(mock_point_in_time_latest_df)
127 | assert result_df.loc[0]["_source_event_timestamp_"] == pd.Timestamp("2021-08-25 20:16:20")
128 |
129 |
130 | mock_source_df = pd.DataFrame(
131 | {
132 | "join_key": ["A", "A", "B", "B"],
133 | "event_timestamp": [
134 | pd.Timestamp("2021-08-25 20:16:18"),
135 | pd.Timestamp("2021-08-25 20:16:19"),
136 | pd.Timestamp("2021-08-25 20:16:11"),
137 | pd.Timestamp("2021-08-25 20:16:20"),
138 | ],
139 | "feature": [1, 2, 3, 4],
140 | },
141 | )
142 | mock_entity_df = pd.DataFrame(
143 | {
144 | "join_key": ["A", "B", "A"],
145 | "event_timestamp": [
146 | pd.Timestamp("2021-08-25 20:16:18"),
147 | pd.Timestamp("2021-08-25 20:16:19"),
148 | pd.Timestamp("2021-08-25 20:16:19"),
149 | ],
150 | "request_feature": [6, 5, 4],
151 | }
152 | )
153 |
154 |
155 | def test_point_in_time_join_with_join_keys():
156 | result_df = OfflineFileStore._point_in_time_join(
157 | mock_entity_df, mock_source_df, timestamp_field="event_timestamp", join_keys=["join_key"]
158 | )
159 | assert len(result_df) == 3
160 |
161 |
162 | def test_point_in_time_join_with_ttl():
163 | result_df = OfflineFileStore._point_in_time_join(
164 | mock_entity_df,
165 | mock_source_df,
166 | timestamp_field="event_timestamp",
167 | join_keys=["join_key"],
168 | ttl=Period.from_str("2 seconds"),
169 | )
170 | assert len(result_df) == 2
171 |
172 |
173 | def test_point_in_time_join_with_extra_entities_in_source():
174 | result_df = OfflineFileStore._point_in_time_join(
175 | pd.DataFrame(
176 | {
177 | "join_key": ["A"],
178 | "event_timestamp": [pd.Timestamp("2021-08-25 20:16:18")],
179 | "request_feature": [6],
180 | }
181 | ),
182 | mock_source_df,
183 | timestamp_field="event_timestamp",
184 | join_keys=["join_key"],
185 | )
186 | assert len(result_df) == 1
187 |
188 |
189 | def test_point_in_time_join_with_created_timestamp():
190 | result_df = OfflineFileStore._point_in_time_join(
191 | mock_entity_df,
192 | pd.DataFrame(
193 | {
194 | "join_key": ["A", "A"],
195 | "event_timestamp": [
196 | pd.Timestamp("2021-08-25 20:16:18"),
197 | pd.Timestamp("2021-08-25 20:16:18"),
198 | ],
199 | "created_timestamp": [
200 | pd.Timestamp("2021-08-25 20:16:21"),
201 | pd.Timestamp("2021-08-25 20:16:20"),
202 | ],
203 | "feature": [5, 6],
204 | },
205 | ),
206 | timestamp_field="event_timestamp",
207 | created_timestamp_field="created_timestamp",
208 | join_keys=["join_key"],
209 | ttl=Period.from_str("2 seconds"),
210 | )
211 | assert all(result_df["feature"] == [5, 5])
212 |
213 |
214 | def test_point_on_time_join_with_join_keys():
215 | result_df = OfflineFileStore._point_on_time_join(
216 | mock_entity_df,
217 | mock_source_df,
218 | period=-Period.from_str("2 seconds"),
219 | timestamp_field="event_timestamp",
220 | join_keys=["join_key"],
221 | include=False,
222 | )
223 | assert len(result_df) == 1
224 | assert result_df["join_key"].values == "A"
225 |
226 |
227 | def test_point_on_time_join_with_join_keys_label():
228 | result_df = OfflineFileStore._point_on_time_join(
229 | mock_entity_df,
230 | mock_source_df,
231 | period=Period.from_str("2 seconds"),
232 | timestamp_field="event_timestamp",
233 | join_keys=["join_key"],
234 | )
235 | assert len(result_df) == 4
236 |
237 |
238 | def test_point_on_time_join_with_ttl():
239 | result_df = OfflineFileStore._point_on_time_join(
240 | mock_entity_df,
241 | mock_source_df,
242 | timestamp_field="event_timestamp",
243 | join_keys=["join_key"],
244 | ttl=Period.from_str("2 seconds"),
245 | period=-Period.from_str("10 seconds"),
246 | )
247 | assert len(result_df) == 3
248 | assert "B" not in result_df["join_key"].values
249 |
250 |
251 | def test_point_on_time_join_with_extra_entities_in_source():
252 | result_df = OfflineFileStore._point_on_time_join(
253 | pd.DataFrame(
254 | {
255 | "join_key": ["A"],
256 | "event_timestamp": [pd.Timestamp("2021-08-25 20:16:22")],
257 | "request_feature": [6],
258 | }
259 | ),
260 | mock_source_df,
261 | period=Period.from_str("3 seconds"),
262 | timestamp_field="event_timestamp",
263 | join_keys=["join_key"],
264 | )
265 | assert len(result_df) == 0
266 |
267 |
268 | def test_point_on_time_join_with_created_timestamp():
269 | result_df = OfflineFileStore._point_on_time_join(
270 | mock_entity_df,
271 | pd.DataFrame(
272 | {
273 | "join_key": ["A", "A", "A", "A"],
274 | "event_timestamp": [
275 | pd.Timestamp("2021-08-25 20:16:16"),
276 | pd.Timestamp("2021-08-25 20:16:17"),
277 | pd.Timestamp("2021-08-25 20:16:18"),
278 | pd.Timestamp("2021-08-25 20:16:18"),
279 | ],
280 | "materialize_time": [
281 | pd.Timestamp("2021-08-25 20:16:17"),
282 | pd.Timestamp("2021-08-25 20:16:19"),
283 | pd.Timestamp("2021-08-25 20:16:21"),
284 | pd.Timestamp("2021-08-25 20:16:20"),
285 | ],
286 | "feature": [3, 4, 5, 6],
287 | },
288 | ),
289 | timestamp_field="event_timestamp",
290 | created_timestamp_field="materialize_time",
291 | join_keys=["join_key"],
292 | period=-Period.from_str("2 seconds"),
293 | )
294 | assert all(result_df["feature"] == [4, 5, 5])
295 | assert result_df[result_df["event_timestamp"] == pd.Timestamp("2021-08-25 20:16:18")].shape == (2, 5)
296 | assert pd.Timestamp("2021-08-25 20:16:20") not in result_df["event_timestamp"]
297 |
298 |
299 | mocked_stats_input_df = pd.DataFrame(
300 | {
301 | "join_key": ["A", "B", "A", "B"],
302 | "event_timestamp": [
303 | pd.Timestamp("2021-08-25 20:16:16"),
304 | pd.Timestamp("2021-08-25 20:16:17"),
305 | pd.Timestamp("2021-08-25 20:16:18"),
306 | pd.Timestamp("2021-08-25 20:16:18"),
307 | ],
308 | "F1": [1, 2, 3, 4],
309 | }
310 | )
311 |
312 |
313 | @pytest.mark.parametrize("fn", [(fn) for fn in StatsFunctions])
314 | def test_stats(fn):
315 | store = OfflineFileStore()
316 | file_source = FileSource(name="mock", path="mock")
317 | features = [Feature(name="F1", dtype=FeatureDTypes.FLOAT, view_name="hello")]
318 |
319 | store._read_file = MagicMock(return_value=mocked_stats_input_df)
320 |
321 | result_df = store.stats(
322 | file_source,
323 | features,
324 | fn,
325 | group_keys=["join_key"],
326 | )
327 | if fn == StatsFunctions.UNIQUE:
328 | assert ",".join(result_df.columns) == "join_key"
329 | else:
330 | assert ",".join(result_df.index.names) == "join_key"
331 | assert isinstance(result_df, pd.DataFrame)
332 |
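
Editorial note: the tests above exercise the point-in-time join/filter helpers directly on in-memory DataFrames, so they need no project fixtures and can be run on their own, e.g. pytest tests/units/offline_stores/offline_file_store_test.py -q.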
--------------------------------------------------------------------------------
/f2ai/models/nbeats/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torchmetrics import MeanAbsoluteError
4 | from typing import List, Dict, Tuple
5 | from torch.utils.data import DataLoader
6 |
7 | from f2ai.featurestore import FeatureStore
8 | from f2ai.dataset import EvenEventsSampler, NoEntitiesSampler
9 | from f2ai.common.collect_fn import nbeats_collet_fn
10 |
11 | from submodules import NBEATSGenericBlock, NBEATSSeasonalBlock, NBEATSTrendBlock, MultiEmbedding
12 |
13 | TIME_IDX = "_time_idx_"
14 |
15 |
16 | class NbeatsNetwork(nn.Module):
17 | def __init__(
18 | self,
19 | targets: str,
20 | model_type: str = "G", # 'I'
21 | num_stack: int = 1,
22 | num_block: int = 1,
23 | width: int = 8, # [2**9]
24 | expansion_coe: int = 5, # [2**5]
25 | num_block_layer: int = 4,
26 | prediction_length: int = 0,
27 | context_length: int = 0,
28 | dropout: float = 0.1,
29 | backcast_loss_ratio: float = 0.1,
30 | covariate_number: int = 0,
31 | encoder_cont: List[str] = [],
32 | decoder_cont: List[str] = [],
33 | embedding_sizes: Dict[str, Tuple[int, int]] = {},
34 | x_categoricals: List[str] = [],
35 | output_size=1,
36 | ):
37 |
38 | super().__init__()
39 | self._targets = targets
40 | self._encoder_cont = encoder_cont
41 | self._decoder_cont = decoder_cont
42 | self.dropout = dropout
43 | self.backcast_loss_ratio = backcast_loss_ratio
44 | self.context_length = context_length
45 | self.prediction_length = prediction_length
46 | self.target_number = output_size
47 | self.covariate_number = covariate_number
48 |
49 | self.encoder_embeddings = MultiEmbedding(
50 | embedding_sizes=embedding_sizes,
51 | embedding_paddings=[],
52 | categorical_groups={},
53 | x_categoricals=x_categoricals,
54 | )
55 |
56 | if model_type == "I":
57 | width = [2**width, 2 ** (width + 2)]
58 | self.stack_types = ["trend", "seasonality"] * num_stack
59 | self.expansion_coefficient_lengths = [item for i in range(num_stack) for item in [3, 7]]
60 | self.num_blocks = [num_block for i in range(2 * num_stack)]
61 | self.num_block_layers = [num_block_layer for i in range(2 * num_stack)]
62 | self.widths = [item for i in range(num_stack) for item in width]
63 | elif model_type == "G":
64 | self.stack_types = ["generic"] * num_stack
65 | self.expansion_coefficient_lengths = [2**expansion_coe for i in range(num_stack)]
66 | self.num_blocks = [num_block for i in range(num_stack)]
67 | self.num_block_layers = [num_block_layer for i in range(num_stack)]
68 | self.widths = [2**width for i in range(num_stack)]
69 | #
70 | # setup stacks
71 | self.net_blocks = nn.ModuleList()
72 |
73 | for stack_id, stack_type in enumerate(self.stack_types):
74 | for _ in range(self.num_blocks[stack_id]):
75 | if stack_type == "generic":
76 | net_block = NBEATSGenericBlock(
77 | units=self.widths[stack_id],
78 | thetas_dim=self.expansion_coefficient_lengths[stack_id],
79 | num_block_layers=self.num_block_layers[stack_id],
80 | backcast_length=self.context_length,
81 | forecast_length=self.prediction_length,
82 | dropout=self.dropout,
83 | tar_num=self.target_number,
84 | cov_num=self.covariate_number + self.encoder_embeddings.total_embedding_size(),
85 | tar_pos=self.target_positions,
86 | )
87 | elif stack_type == "seasonality":
88 | net_block = NBEATSSeasonalBlock(
89 | units=self.widths[stack_id],
90 | num_block_layers=self.num_block_layers[stack_id],
91 | backcast_length=self.context_length,
92 | forecast_length=self.prediction_length,
93 | min_period=self.expansion_coefficient_lengths[stack_id],
94 | dropout=self.dropout,
95 | tar_num=self.target_number,
96 | cov_num=self.covariate_number + self.encoder_embeddings.total_embedding_size(),
97 | tar_pos=self.target_positions,
98 | )
99 | elif stack_type == "trend":
100 | net_block = NBEATSTrendBlock(
101 | units=self.widths[stack_id],
102 | thetas_dim=self.expansion_coefficient_lengths[stack_id],
103 | num_block_layers=self.num_block_layers[stack_id],
104 | backcast_length=self.context_length,
105 | forecast_length=self.prediction_length,
106 | dropout=self.dropout,
107 | tar_num=self.target_number,
108 | cov_num=self.covariate_number + self.encoder_embeddings.total_embedding_size(),
109 | tar_pos=self.target_positions,
110 | )
111 | else:
112 | raise ValueError(f"Unknown stack type {stack_type}")
113 |
114 | self.net_blocks.append(net_block)
115 |
116 | @property
117 | def target_positions(self):
118 | return [self._encoder_cont.index(tar) for tar in self._targets]
119 |
120 | def forward(self, x: Dict[str, torch.Tensor], **kwargs) -> Dict[str, torch.Tensor]:
121 | """
122 |         Forward pass of the network.
123 |
124 | Args:
125 | x (Dict[str, torch.Tensor]): input from dataloader generated from
126 | :py:class:`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`.
127 |
128 | Returns:
129 | Dict[str, torch.Tensor]: output of model
130 | """
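        # Shapes of `x` assumed from the concatenations below (batch-first tensors):
        #   x["encoder_cat"]            (batch, context_length, n_categoricals)
        #   x["encoder_cont"]           (batch, context_length, n_continuous)
        #   x["encoder_time_features"]  (batch, context_length, n_time_features)
        #   x["encoder_lag_features"]   (batch, context_length, n_lag_features)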
131 | # batch_size * look_back * features
132 | encoder_cont = torch.cat(
133 | [x["encoder_cont"], x["encoder_time_features"], x["encoder_lag_features"]], dim=-1
134 | )
135 | # `target` can only be continuous, so position inside `encoder_cat` is irrelevant
136 | encoder_cat = (
137 | torch.cat([v for _, v in self.encoder_embeddings(x["encoder_cat"]).items()], dim=-1)
138 | if self.encoder_embeddings.total_embedding_size() != 0
139 | else torch.zeros(
140 | encoder_cont.size(0),
141 | self.context_length,
142 | self.encoder_embeddings.total_embedding_size(),
143 | device=encoder_cont.device,
144 | )
145 | )
146 | # self.hparams.prediction_length=gap+real_predict
147 | timesteps = self.context_length + self.prediction_length
148 | # encoder_cont.size(2) + self.encoder_embeddings.total_embedding_size(),
149 | generic_forecast = [
150 | torch.zeros(
151 | (encoder_cont.size(0), timesteps, len(self.target_positions)),
152 | dtype=torch.float32,
153 | device=encoder_cont.device,
154 | )
155 | ]
156 |
157 | trend_forecast = [
158 | torch.zeros(
159 | (encoder_cont.size(0), timesteps, len(self.target_positions)),
160 | dtype=torch.float32,
161 | device=encoder_cont.device,
162 | )
163 | ]
164 | seasonal_forecast = [
165 | torch.zeros(
166 | (encoder_cont.size(0), timesteps, len(self.target_positions)),
167 | dtype=torch.float32,
168 | device=encoder_cont.device,
169 | )
170 | ]
171 |
172 | forecast = torch.zeros(
173 | (encoder_cont.size(0), self.prediction_length, len(self.target_positions)),
174 | dtype=torch.float32,
175 | device=encoder_cont.device,
176 | )
177 |
178 | # make sure `encoder_cont` is followed by `encoder_cat`
179 |
180 | backcast = torch.cat([encoder_cont, encoder_cat], dim=-1)
181 |
182 | for i, block in enumerate(self.net_blocks):
183 | # evaluate block
184 | backcast_block, forecast_block = block(backcast)
185 | # add for interpretation
186 | full = torch.cat([backcast_block.detach(), forecast_block.detach()], dim=1)
187 | if isinstance(block, NBEATSTrendBlock):
188 | trend_forecast.append(full)
189 | elif isinstance(block, NBEATSSeasonalBlock):
190 | seasonal_forecast.append(full)
191 | else:
192 | generic_forecast.append(full)
193 | # update backcast and forecast
194 | backcast = backcast.clone()
195 | backcast[..., self.target_positions] = backcast[..., self.target_positions] - backcast_block
196 |             # do not use backcast -= backcast_block, as that would be an in-place operation
197 | forecast = forecast + forecast_block
198 |
199 | # `encoder_cat` always at the end of sequence, so it will not affect `self.target_positions`
200 | # backcast, forecast is of batch_size * context_length/prediction_length * tar_num
201 | return {
202 | "prediction": forecast,
203 | "backcast": (
204 | encoder_cont[..., self.target_positions] - backcast[..., self.target_positions],
205 | self.backcast_loss_ratio,
206 | ),
207 | }
208 |
209 |
210 | if __name__ == "__main__":
211 | TIME_COL = "event_timestamp"
212 |
213 | # data = pd.read_csv("/Users/zhao123456/Desktop/gitlab/guizhou_traffic/guizhou_traffic.csv")
214 | # data = data.iloc[:20000,:]
215 | # data.to_csv("/Users/zhao123456/Desktop/gitlab/guizhou_traffic/guizhou_traffic_20000.csv")
216 |
217 | fs = FeatureStore("file:///Users/zhao123456/Desktop/gitlab/guizhou_traffic")
218 |
219 | events_sampler = EvenEventsSampler(
220 | start=fs.get_latest_entities(fs.services["traval_time_prediction_embedding_v1"])[TIME_COL].min(),
221 | end=fs.get_latest_entities(fs.services["traval_time_prediction_embedding_v1"])[TIME_COL].max(),
222 | period="10 minutes",
223 | )
224 | dataset = fs.get_dataset(
225 | service="traval_time_prediction_embedding_v1",
226 | sampler=NoEntitiesSampler(events_sampler),
227 | )
228 |
229 | features_cat = [
230 | fea
231 | for fea in fs._get_available_features(fs.services["traval_time_prediction_embedding_v1"])
232 | if fea not in fs._get_available_features(fs.services["traval_time_prediction_embedding_v1"], True)
233 | ]
234 | cat_unique = fs.stats(
235 | fs.services["traval_time_prediction_embedding_v1"],
236 | fn="unique",
237 | group_key=[],
238 | start=fs.get_latest_entities(fs.services["traval_time_prediction_embedding_v1"])[TIME_COL].min(),
239 | end=fs.get_latest_entities(fs.services["traval_time_prediction_embedding_v1"])[TIME_COL].max(),
240 | features=features_cat,
241 | ).to_dict()
242 | cat_count = {key: len(cat_unique[key]) for key in cat_unique.keys()}
243 | cont_scalar_max = fs.stats(
244 | fs.services["traval_time_prediction_embedding_v1"],
245 | fn="max",
246 | group_key=[],
247 | start=fs.get_latest_entities(fs.services["traval_time_prediction_embedding_v1"])[TIME_COL].min(),
248 | end=fs.get_latest_entities(fs.services["traval_time_prediction_embedding_v1"])[TIME_COL].max(),
249 | ).to_dict()
250 | cont_scalar_min = fs.stats(
251 | fs.services["traval_time_prediction_embedding_v1"],
252 | fn="min",
253 | group_key=[],
254 | start=fs.get_latest_entities(fs.services["traval_time_prediction_embedding_v1"])[TIME_COL].min(),
255 | end=fs.get_latest_entities(fs.services["traval_time_prediction_embedding_v1"])[TIME_COL].max(),
256 | ).to_dict()
257 | cont_scalar = {key: [cont_scalar_min[key], cont_scalar_max[key]] for key in cont_scalar_min.keys()}
258 |
259 | label = fs._get_available_labels(fs.services["traval_time_prediction_embedding_v1"])
260 | del cont_scalar[label[0]]
261 |
262 | i_ds = dataset.to_pytorch()
263 | test_data_loader = DataLoader(
264 | i_ds,
265 | collate_fn=lambda x: nbeats_collet_fn(
266 | x,
267 | cont_scalar=cont_scalar,
268 | categoricals=cat_unique,
269 | label=label,
270 | ),
271 | batch_size=8,
272 | drop_last=False,
273 | sampler=None,
274 | )
275 |
276 | model = NbeatsNetwork(
277 | targets=label,
278 | # prediction_length= fs.services["traval_time_prediction_embedding_v1"].labels[0].period,
279 | # context_length= max([feature.period for feature in fs.services["traval_time_prediction_embedding_v1"].features if feature.period is not None]),
280 | prediction_length=30,
281 | context_length=60,
282 | covariate_number=len(cont_scalar),
283 | encoder_cont=list(cont_scalar.keys()) + label,
284 | decoder_cont=list(cont_scalar.keys()),
285 | x_categoricals=features_cat,
286 | output_size=1,
287 | )
288 |
289 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
290 | loss_fn = MeanAbsoluteError()
291 |
292 |     for epoch in range(10):  # assume 10 epochs
293 | print(f"epoch: {epoch} begin")
294 | for x, y in test_data_loader:
295 |             pred_label = model(x)["prediction"]
296 | true_label = y
297 | loss = loss_fn(pred_label, true_label)
298 | optimizer.zero_grad()
299 | loss.backward()
300 | optimizer.step()
301 | print(f"epoch: {epoch} done, loss: {loss}")
302 |
--------------------------------------------------------------------------------
/f2ai/models/nbeats/submodules.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Tuple, Union
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
  9 | def linear(input_size, output_size, bias=True, dropout: float = None):
10 | lin = nn.Linear(input_size, output_size, bias=bias)
11 | if dropout is not None:
12 | return nn.Sequential(nn.Dropout(dropout), lin)
13 | else:
14 | return lin
15 |
16 |
17 | def linspace(
18 | backcast_length: int, forecast_length: int, centered: bool = False
19 | ) -> Tuple[np.ndarray, np.ndarray]:
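    # Returns normalised time indices for the backcast and forecast windows.
    # Example (assuming backcast_length=3, forecast_length=2, centered=False):
    #   lin_space = np.linspace(0/5, 4/5, 5) = [0.0, 0.2, 0.4, 0.6, 0.8]
    #   b_ls = [0.0, 0.2, 0.4], f_ls = [0.6, 0.8]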
20 | if centered:
21 | norm = max(backcast_length, forecast_length)
22 | start = -backcast_length
23 | stop = forecast_length - 1
24 | else:
25 | norm = backcast_length + forecast_length
26 | start = 0
27 | stop = backcast_length + forecast_length - 1
28 | lin_space = np.linspace(start / norm, stop / norm, backcast_length + forecast_length, dtype=np.float32)
29 | b_ls = lin_space[:backcast_length]
30 | f_ls = lin_space[backcast_length:]
31 | return b_ls, f_ls
32 |
33 |
34 | class NBEATSBlock(nn.Module):
35 | def __init__(
36 | self,
37 | units,
38 | thetas_dim,
39 | num_block_layers=4,
40 | backcast_length=10,
41 | forecast_length=5,
42 | share_thetas=False,
43 | dropout=0.1,
44 | tar_num=1,
45 | cov_num=0,
46 | tar_pos=[],
47 | ):
48 |
49 | super().__init__()
50 | self.units = units
51 | self.thetas_dim = thetas_dim
52 | self.backcast_length = backcast_length
53 | self.forecast_length = forecast_length
54 | self.share_thetas = share_thetas
55 | self.tar_num = tar_num
56 | self.cov_num = cov_num
57 | self.tar_pos = tar_pos
58 |
59 | fc_stack = [
60 | [
61 | nn.Linear(backcast_length, units),
62 | nn.ReLU(),
63 | ]
64 | for i in range(self.tar_num + self.cov_num)
65 | ]
66 |
67 | for i in range(self.tar_num + self.cov_num):
68 | for _ in range(num_block_layers - 1):
69 | fc_stack[i].extend([linear(units, units, dropout=dropout), nn.ReLU()])
70 |
71 | self.fc = nn.ModuleList(nn.Sequential(*fc_stack[i]) for i in range(self.tar_num + self.cov_num))
72 |
73 | if share_thetas:
74 | self.theta_f_fc = self.theta_b_fc = nn.ModuleList(
75 | [nn.Linear(units, thetas_dim, bias=False) for i in range(self.tar_num + self.cov_num)]
76 | )
77 | else:
78 |
79 | self.theta_b_fc = nn.ModuleList(
80 | [nn.Linear(units, thetas_dim, bias=False) for i in range(self.tar_num + self.cov_num)]
81 | )
82 | self.theta_f_fc = nn.ModuleList(
83 | [nn.Linear(units, thetas_dim, bias=False) for i in range(self.tar_num + self.cov_num)]
84 | )
85 |
86 | def forward(self, x):
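        # x is assumed to be (batch, backcast_length, tar_num + cov_num); each channel n
        # is passed through its own fully connected stack self.fc[n], and the per-channel
        # outputs are stacked back along the last dimension, giving (batch, units, channels).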
87 | return torch.stack([self.fc[n](x[..., n]) for n in range(self.tar_num + self.cov_num)], dim=2)
88 | # return [self.fc[n](x[...,n]) for n in range(self.tar_num+self.cov_num)]
89 |
90 |
91 | class NBEATSSeasonalBlock(NBEATSBlock):
92 | def __init__(
93 | self,
94 | units,
95 | thetas_dim=None,
96 | num_block_layers=4,
97 | backcast_length=10,
98 | forecast_length=5,
99 | nb_harmonics=None,
100 | min_period=1,
101 | dropout=0.1,
102 | tar_num=1,
103 | cov_num=0,
104 | tar_pos=[],
105 | ):
106 | if nb_harmonics:
107 | thetas_dim = nb_harmonics
108 | else:
109 | thetas_dim = forecast_length
110 | self.min_period = min_period
111 |
112 | super().__init__(
113 | units=units,
114 | thetas_dim=thetas_dim,
115 | num_block_layers=num_block_layers,
116 | backcast_length=backcast_length,
117 | forecast_length=forecast_length,
118 | share_thetas=True,
119 | dropout=dropout,
120 | tar_num=tar_num,
121 | cov_num=cov_num,
122 | tar_pos=tar_pos,
123 | )
124 |
125 | backcast_linspace, forecast_linspace = linspace(backcast_length, forecast_length, centered=False)
126 |
127 | p1, p2 = (
128 | (thetas_dim // 2, thetas_dim // 2)
129 | if thetas_dim % 2 == 0
130 | else (thetas_dim // 2, thetas_dim // 2 + 1)
131 | )
132 | # seasonal_backcast_pre_p1 seasonal_backcast_pre_p2
133 | s1_b_pre = torch.tensor(
134 | [np.cos(2 * np.pi * i * backcast_linspace) for i in self.get_frequencies(p1)],
135 | dtype=torch.float32,
136 | device=backcast_linspace.device,
137 | ) # H/2-1
138 | s2_b_pre = torch.tensor(
139 | [np.sin(2 * np.pi * i * backcast_linspace) for i in self.get_frequencies(p2)],
140 | dtype=torch.float32,
141 | device=backcast_linspace.device,
142 | )
143 | # concat seasonal_backcast_pre_p1 and seasonal_backcast_pre_p2
144 | s_b_pre = torch.stack(
145 | [torch.cat([s1_b_pre, s2_b_pre]) for n in range(self.tar_num + self.cov_num)],
146 | dim=2,
147 | ) # p1+p2 * backlength * tarnum
148 |
149 | # seasonal_forecast_pre_p1
150 | s1_f_pre = torch.tensor(
151 | [np.cos(2 * np.pi * i * forecast_linspace) for i in self.get_frequencies(p1)],
152 | dtype=torch.float32,
153 | device=forecast_linspace.device,
154 | )
155 | # seasonal_forecast_pre_p2
156 | s2_f_pre = torch.tensor(
157 | [np.sin(2 * np.pi * i * forecast_linspace) for i in self.get_frequencies(p2)],
158 | dtype=torch.float32,
159 | device=forecast_linspace.device,
160 | )
161 | # concat seasonal_forecast_pre_p1 and seasonal_forecast_pre_p2
162 | s_f_pre = torch.stack(
163 | [torch.cat([s1_f_pre, s2_f_pre]) for n in range(self.tar_num + self.cov_num)],
164 | dim=2,
165 | ) # p1+p2 * forlength * tarnum
166 |
167 |         # register as buffers so they can be used later as self.S_backcast and self.S_forecast
168 | self.register_buffer("S_backcast", s_b_pre)
169 | self.register_buffer("S_forecast", s_f_pre)
170 | self.agg_layer = nn.ModuleList([nn.Linear(tar_num + cov_num, 1) for i in range(tar_num)])
171 |
172 | def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
173 | x = super().forward(x)
174 |
175 | self.S_backcast_final = [self.agg_layer[i](self.S_backcast).squeeze(-1) for i in range(self.tar_num)]
176 | self.S_forecast_final = [self.agg_layer[i](self.S_forecast).squeeze(-1) for i in range(self.tar_num)]
177 |
178 |         amplitudes_backward = torch.stack(
179 |             [self.theta_b_fc[n](x[..., n]) for n in range(self.tar_num + self.cov_num)], dim=2
180 |         )
181 |
182 |         backcast = torch.stack(
183 |             [amplitudes_backward[..., n].mm(self.S_backcast_final[n]) for n in self.tar_pos], dim=2
184 |         )
185 |
186 |         amplitudes_forward = torch.stack(
187 |             [self.theta_f_fc[n](x[..., n]) for n in range(self.tar_num + self.cov_num)], dim=2
188 |         )
189 |         forecast = torch.stack(
190 |             [amplitudes_forward[..., n].mm(self.S_forecast_final[n]) for n in self.tar_pos], dim=2
191 |         )
192 |
193 | return backcast, forecast # only target, not cov
194 |
195 | def get_frequencies(self, n):
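        # n evenly spaced frequencies from 0 up to (timesteps / min_period).
        # Example (assuming backcast_length=10, forecast_length=5, min_period=1, n=3):
        #   np.linspace(0, 15, 3) = [0.0, 7.5, 15.0]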
196 | return np.linspace(0, (self.backcast_length + self.forecast_length) / self.min_period, n)
197 |
198 |
199 | class NBEATSTrendBlock(NBEATSBlock):
200 | def __init__(
201 | self,
202 | units,
203 | thetas_dim,
204 | num_block_layers=4,
205 | backcast_length=10,
206 | forecast_length=5,
207 | dropout=0.1,
208 | tar_num=1,
209 | cov_num=0,
210 | tar_pos=[],
211 | ):
212 | super().__init__(
213 | units=units,
214 | thetas_dim=thetas_dim,
215 | num_block_layers=num_block_layers,
216 | backcast_length=backcast_length,
217 | forecast_length=forecast_length,
218 | share_thetas=True,
219 | dropout=dropout,
220 | tar_num=tar_num,
221 | cov_num=cov_num,
222 | tar_pos=tar_pos,
223 | )
224 |
225 | backcast_linspace, forecast_linspace = linspace(backcast_length, forecast_length, centered=True)
226 | norm = np.sqrt(forecast_length / thetas_dim) # ensure range of predictions is comparable to input
227 |
228 | # backcast
229 | coefficients_pre_b = torch.cat(
230 | [
231 | torch.tensor(
232 | [backcast_linspace**i for i in range(thetas_dim)],
233 | dtype=torch.float32,
234 | device=backcast_linspace.device,
235 | ).unsqueeze(2)
236 | for n in range(self.tar_num + self.cov_num)
237 | ],
238 | 2,
239 | )
240 | # forecast
241 | coefficients_pre_f = torch.cat(
242 | [
243 | torch.tensor(
244 | [forecast_linspace**i for i in range(thetas_dim)],
245 | dtype=torch.float32,
246 | device=forecast_linspace.device,
247 | ).unsqueeze(2)
248 | for n in range(self.tar_num + self.cov_num)
249 | ],
250 | 2,
251 | )
252 |
253 |         # register as buffers so they can be used later as self.T_backcast and self.T_forecast
254 | self.register_buffer("T_backcast", coefficients_pre_b * norm)
255 | self.register_buffer("T_forecast", coefficients_pre_f * norm)
256 |
257 | self.agg_layer = nn.ModuleList([nn.Linear(tar_num + cov_num, 1) for i in range(tar_num)])
258 |
259 | def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
260 | x = super().forward(x)
261 |
262 | self.T_backcast_final = [self.agg_layer[i](self.T_backcast).squeeze(-1) for i in range(self.tar_num)]
263 | self.T_forecast_final = [self.agg_layer[i](self.T_forecast).squeeze(-1) for i in range(self.tar_num)]
264 |
265 | backcast = torch.stack(
266 | [self.theta_b_fc[n](x[..., n]).mm(self.T_backcast_final[n]) for n in self.tar_pos], dim=2
267 | )
268 | forecast = torch.stack(
269 | [self.theta_f_fc[n](x[..., n]).mm(self.T_forecast_final[n]) for n in self.tar_pos], dim=2
270 | )
271 | return backcast, forecast
272 |
273 |
274 | class NBEATSGenericBlock(NBEATSBlock):
275 | def __init__(
276 | self,
277 | units,
278 | thetas_dim,
279 | num_block_layers=4,
280 | backcast_length=10,
281 | forecast_length=5,
282 | dropout=0.1,
283 | tar_num=1,
284 | cov_num=0,
285 | tar_pos=[],
286 | ):
287 |
288 | super().__init__(
289 | units=units,
290 | thetas_dim=thetas_dim,
291 | num_block_layers=num_block_layers,
292 | backcast_length=backcast_length,
293 | forecast_length=forecast_length,
294 | dropout=dropout,
295 | tar_num=tar_num,
296 | cov_num=cov_num,
297 | tar_pos=tar_pos,
298 | )
299 |
300 | self.backcast_fc = nn.ModuleList(
301 | [nn.Linear(thetas_dim, backcast_length) for i in range(self.tar_num)]
302 | )
303 | self.forecast_fc = nn.ModuleList(
304 | [nn.Linear(thetas_dim, forecast_length) for i in range(self.tar_num)]
305 | )
306 |
307 | self.agg_layer = nn.ModuleList([nn.Linear(tar_num + cov_num, 1) for i in range(tar_num)])
308 |
309 | def forward(self, x):
310 | x = super().forward(x)
311 |
312 | theta_bs = torch.cat(
313 | [F.relu(self.theta_b_fc[n](x[..., n])).unsqueeze(2) for n in range(self.tar_num + self.cov_num)],
314 | 2,
315 | ) # encode x thetas_dim x n_output(n_tar+n_cov)
316 | theta_fs = torch.cat(
317 | [F.relu(self.theta_f_fc[n](x[..., n])).unsqueeze(2) for n in range(self.tar_num + self.cov_num)],
318 | 2,
319 | ) # encode x thetas_dim x n_output
320 |
321 | # lengths = n_target
322 | theta_b = [self.agg_layer[i](theta_bs).squeeze(-1) for i in range(self.tar_num)]
323 | theta_f = [self.agg_layer[i](theta_fs).squeeze(-1) for i in range(self.tar_num)]
324 |
325 | return (
326 | torch.cat(
327 | [self.backcast_fc[i](theta_b[i]).unsqueeze(2) for i in range(self.tar_num)],
328 | 2,
329 | ),
330 | torch.cat(
331 | [self.forecast_fc[i](theta_f[i]).unsqueeze(2) for i in range(self.tar_num)],
332 | 2,
333 | ),
334 | )
335 |
336 |
337 | class TimeDistributedEmbeddingBag(nn.EmbeddingBag):
338 | def __init__(self, *args, batch_first: bool = False, **kwargs):
339 | super().__init__(*args, **kwargs)
340 | self.batch_first = batch_first
341 |
342 | def forward(self, x):
343 |
344 | if len(x.size()) <= 2:
345 | return super().forward(x)
346 |
347 | # Squash samples and timesteps into a single axis
348 | x_reshape = x.contiguous().view(-1, x.size(-1)) # (samples * timesteps, input_size)
349 |
350 | y = super().forward(x_reshape)
351 |
352 | # We have to reshape Y
353 | if self.batch_first:
354 | y = y.contiguous().view(x.size(0), -1, y.size(-1)) # (samples, timesteps, output_size)
355 | else:
356 | y = y.view(-1, x.size(1), y.size(-1)) # (timesteps, samples, output_size)
357 | return y
358 |
359 |
360 | class MultiEmbedding(nn.Module):
361 | def __init__(
362 | self,
363 | embedding_sizes: Dict[str, Tuple[int, int]],
364 | categorical_groups: Dict[str, List[str]],
365 | embedding_paddings: List[str],
366 | x_categoricals: List[str],
367 | max_embedding_size: int = None,
368 | ):
369 | super().__init__()
370 | self.embedding_sizes = embedding_sizes
371 | self.categorical_groups = categorical_groups
372 | self.embedding_paddings = embedding_paddings
373 | self.max_embedding_size = max_embedding_size
374 | self.x_categoricals = x_categoricals
375 | self.init_embeddings()
376 |
377 | def init_embeddings(self):
378 | self.embeddings = nn.ModuleDict()
379 | for name in self.embedding_sizes.keys():
380 | embedding_size = self.embedding_sizes[name][1]
381 | if self.max_embedding_size is not None:
382 | embedding_size = min(embedding_size, self.max_embedding_size)
383 | # convert to list to become mutable
384 | self.embedding_sizes[name] = list(self.embedding_sizes[name])
385 | self.embedding_sizes[name][1] = embedding_size
386 | if name in self.categorical_groups: # embedding bag if related embeddings
387 | self.embeddings[name] = TimeDistributedEmbeddingBag(
388 | self.embedding_sizes[name][0], embedding_size, mode="sum", batch_first=True
389 | )
390 | else:
391 | if name in self.embedding_paddings:
392 | padding_idx = 0
393 | else:
394 | padding_idx = None
395 | self.embeddings[name] = nn.Embedding(
396 | self.embedding_sizes[name][0],
397 | embedding_size,
398 | padding_idx=padding_idx,
399 | )
400 |
401 | def total_embedding_size(self) -> int:
402 | return sum([size[1] for size in self.embedding_sizes.values()])
403 |
404 | def names(self) -> List[str]:
405 | return list(self.keys())
406 |
407 | def items(self):
408 | return self.embeddings.items()
409 |
410 | def keys(self) -> List[str]:
411 | return self.embeddings.keys()
412 |
413 | def values(self):
414 | return self.embeddings.values()
415 |
416 | def __getitem__(self, name: str) -> Union[nn.Embedding, TimeDistributedEmbeddingBag]:
417 | return self.embeddings[name]
418 |
419 | def forward(self, x: torch.Tensor, flat: bool = False) -> Union[Dict[str, torch.Tensor], torch.Tensor]:
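        # `x` is assumed to hold encoded categoricals in the column order given by
        # `x_categoricals`. Grouped categoricals are embedded together with a
        # TimeDistributedEmbeddingBag; every other name uses a plain nn.Embedding on its
        # single column. With flat=True the per-name embeddings are concatenated along
        # the last dimension instead of being returned as a dict.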
420 | out = {}
421 | for name, emb in self.embeddings.items():
422 | if name in self.categorical_groups:
423 | out[name] = emb(
424 | x[
425 | ...,
426 | [self.x_categoricals.index(cat_name) for cat_name in self.categorical_groups[name]],
427 | ]
428 | )
429 | else:
430 | out[name] = emb(x[..., self.x_categoricals.index(name)])
431 | if flat:
432 | out = torch.cat([v for v in out.values()], dim=-1)
433 | return out
434 |
--------------------------------------------------------------------------------
/use_cases/credit_score/credit_score.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "df64f867",
6 | "metadata": {},
7 | "source": [
8 | "Example: Credit Score Classification\n"
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "id": "7e99a81f",
14 | "metadata": {},
15 | "source": [
 16 |     "`FeatureStore` is a model-agnostic tool that aims to free data scientists and algorithm engineers from tedious data storage and merging tasks.\n",
 17 |     "`FeatureStore` works not only on single-dimension data such as classification and prediction, but also on time-series data.\n",
 18 |     "After collecting data, all you need to do is configure several straightforward .yml files; then you can focus on models/algorithms and leave all the exhausting preparation to `FeatureStore`.\n",
 19 |     "Here we present a credit scoring task as a single-dimension data demo: it takes features such as wages and loan records to decide whether to grant credit or not."
20 | ]
21 | },
22 | {
23 | "cell_type": "markdown",
24 | "id": "1eebee68",
25 | "metadata": {},
26 | "source": [
27 | "Import packages"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": 1,
33 | "id": "abc14dec",
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "import torch\n",
38 | "import os\n",
39 | "import numpy as np\n",
40 | "import zipfile\n",
41 | "import tempfile\n",
42 | "from torch import nn\n",
43 | "from torch.utils.data import DataLoader\n",
44 | "from f2ai.featurestore import FeatureStore\n",
45 | "from f2ai.common.sampler import GroupFixednbrSampler\n",
46 | "from f2ai.common.collect_fn import classify_collet_fn\n",
47 | "from f2ai.common.utils import get_bucket_from_oss_url\n",
48 | "from f2ai.models.earlystop import EarlyStopping \n",
49 | "from f2ai.models.sequential import SimpleClassify"
50 | ]
51 | },
52 | {
53 | "cell_type": "markdown",
54 | "id": "6d14ac45",
55 | "metadata": {},
56 | "source": [
57 | "Download demo project files from `OSS` "
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": 2,
63 | "id": "eda81c81",
64 | "metadata": {},
65 | "outputs": [
66 | {
67 | "name": "stdout",
68 | "output_type": "stream",
69 | "text": [
70 | "Project downloaded and saved in /tmp/558u5bgc/xyz_test_data\n"
71 | ]
72 | }
73 | ],
74 | "source": [
75 | "download_from = \"oss://aiexcelsior-shanghai-test/xyz_test_data/credit-score.zip\"\n",
76 | "save_path = '/tmp/'\n",
77 | "save_dir = tempfile.mkdtemp(prefix=save_path)\n",
78 | "bucket, key = get_bucket_from_oss_url(download_from)\n",
79 | "dest_zip_filepath = os.path.join(save_dir,key)\n",
80 | "os.makedirs(os.path.dirname(dest_zip_filepath), exist_ok=True)\n",
81 | "bucket.get_object_to_file(key, dest_zip_filepath)\n",
82 | "zipfile.ZipFile(dest_zip_filepath).extractall(dest_zip_filepath.rsplit('/',1)[0])\n",
83 | "os.remove(dest_zip_filepath)\n",
84 | "print(f\"Project downloaded and saved in {dest_zip_filepath.rsplit('/',1)[0]}\")"
85 | ]
86 | },
87 | {
88 | "cell_type": "markdown",
89 | "id": "a738d2da",
90 | "metadata": {},
91 | "source": [
92 | "Initialize `FeatureStore`"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": 3,
98 | "id": "cbdaf53b",
99 | "metadata": {},
100 | "outputs": [],
101 | "source": [
102 | "TIME_COL = 'event_timestamp'\n",
103 | "fs = FeatureStore(f\"file://{save_dir}/{key.rstrip('.zip')}\")"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": 4,
109 | "id": "7cbff283",
110 | "metadata": {},
111 | "outputs": [
112 | {
113 | "name": "stdout",
114 | "output_type": "stream",
115 | "text": [
116 | "All features are: ['person_home_ownership', 'location_type', 'mortgage_due', 'tax_returns_filed', 'population', 'person_emp_length', 'loan_amnt', 'person_age', 'missed_payments_2y', 'state', 'total_wages', 'hard_pulls', 'loan_int_rate', 'missed_payments_6m', 'student_loan_due', 'city', 'person_income', 'loan_intent', 'bankruptcies', 'vehicle_loan_due', 'missed_payments_1y', 'credit_card_due']\n"
117 | ]
118 | }
119 | ],
120 | "source": [
121 | "print(f\"All features are: {fs._get_feature_to_use(fs.services['credit_scoring_v1'])}\")"
122 | ]
123 | },
124 | {
125 | "cell_type": "markdown",
126 | "id": "792ba5e1",
127 | "metadata": {},
128 | "source": [
129 | "Get the time range of available data"
130 | ]
131 | },
132 | {
133 | "cell_type": "code",
134 | "execution_count": 5,
135 | "id": "8eb8697f",
136 | "metadata": {},
137 | "outputs": [
138 | {
139 | "name": "stdout",
140 | "output_type": "stream",
141 | "text": [
142 | "Earliest timestamp: 2020-08-25 20:34:41.361000+00:00\n",
143 | "Latest timestamp: 2021-08-25 20:34:41.361000+00:00\n"
144 | ]
145 | }
146 | ],
147 | "source": [
148 | "print(f'Earliest timestamp: {fs.get_latest_entities(\"credit_scoring_v1\")[TIME_COL].min()}')\n",
149 | "print(f'Latest timestamp: {fs.get_latest_entities(\"credit_scoring_v1\")[TIME_COL].max()}')"
150 | ]
151 | },
152 | {
153 | "cell_type": "markdown",
154 | "id": "13419833",
155 | "metadata": {},
156 | "source": [
157 |     "Split the data into train / valid / test sets at an approximate 7:2:1 ratio, and use `GroupFixednbrSampler` to downsample the original data and return a `torch.IterableDataset`"
158 | ]
159 | },
160 | {
161 | "cell_type": "code",
162 | "execution_count": 6,
163 | "id": "ec10ddd5",
164 | "metadata": {},
165 | "outputs": [],
166 | "source": [
167 | "ds_train = fs.get_dataset(\n",
168 | " service=\"credit_scoring_v1\",\n",
169 | " sampler=GroupFixednbrSampler(\n",
170 | " time_bucket=\"5 days\",\n",
171 | " stride=1,\n",
172 | " group_ids=None,\n",
173 | " group_names=None,\n",
174 | " start=\"2020-08-20\",\n",
175 | " end=\"2021-04-30\",\n",
176 | " ),\n",
177 | " )\n",
178 | "ds_valid = fs.get_dataset(\n",
179 | " service=\"credit_scoring_v1\",\n",
180 | " sampler=GroupFixednbrSampler(\n",
181 | " time_bucket=\"5 days\",\n",
182 | " stride=1,\n",
183 | " group_ids=None,\n",
184 | " group_names=None,\n",
185 | " start=\"2021-04-30\",\n",
186 | " end=\"2021-07-31\",\n",
187 | " ),\n",
188 | " )\n",
189 | "ds_test= fs.get_dataset(\n",
190 | " service=\"credit_scoring_v1\",\n",
191 | " sampler=GroupFixednbrSampler(\n",
192 | " time_bucket=\"1 days\",\n",
193 | " stride=1,\n",
194 | " group_ids=None,\n",
195 | " group_names=None,\n",
196 | " start=\"2021-07-31\",\n",
197 | " end=\"2021-08-31\",\n",
198 | " ),\n",
199 | " )"
200 | ]
201 | },
202 | {
203 | "cell_type": "markdown",
204 | "id": "e4d8d35f",
205 | "metadata": {},
206 | "source": [
207 |     "Use `FeatureStore.stats` to obtain statistical results for data processing"
208 | ]
209 | },
210 | {
211 | "cell_type": "code",
212 | "execution_count": 7,
213 | "id": "9b5b9ced",
214 | "metadata": {},
215 | "outputs": [
216 | {
217 | "name": "stdout",
218 | "output_type": "stream",
219 | "text": [
220 | "Number of unique values of categorical features are: {'person_home_ownership': 4, 'location_type': 1, 'state': 51, 'city': 8166, 'loan_intent': 6}\n"
221 | ]
222 | }
223 | ],
224 | "source": [
225 |     "# categorical features\n",
226 | "features_cat = [ \n",
227 | " fea\n",
228 | " for fea in fs.services[\"credit_scoring_v1\"].get_feature_names(fs.feature_views)\n",
229 | " if fea not in fs.services[\"credit_scoring_v1\"].get_feature_names(fs.feature_views,is_numeric=True)\n",
230 | "]\n",
231 | "# get unique item number to do labelencoder\n",
232 | "cat_unique = fs.stats(\n",
233 | " \"credit_scoring_v1\",\n",
234 | " fn=\"unique\",\n",
235 | " group_key=[],\n",
236 | " start=\"2020-08-01\",\n",
237 | " end=\"2021-04-30\",\n",
238 | " features=features_cat,\n",
239 | ").to_dict()\n",
240 | "cat_count = {key: len(cat_unique[key]) for key in cat_unique.keys()}\n",
241 | "print(f\"Number of unique values of categorical features are: {cat_count}\")"
242 | ]
243 | },
244 | {
245 | "cell_type": "code",
246 | "execution_count": 8,
247 | "id": "8a0d6e00",
248 | "metadata": {},
249 | "outputs": [
250 | {
251 | "name": "stdout",
252 | "output_type": "stream",
253 | "text": [
254 | "Min-Max boundary of continuous features are: {'population': [272.0, 88503.0], 'missed_payments_6m': [0.0, 1.0], 'person_age': [20.0, 144.0], 'missed_payments_2y': [0.0, 7.0], 'student_loan_due': [0.0, 49997.0], 'missed_payments_1y': [0.0, 3.0], 'bankruptcies': [0.0, 2.0], 'mortgage_due': [33.0, 1999896.0], 'vehicle_loan_due': [1.0, 29998.0], 'total_wages': [0.0, 2132869892.0], 'tax_returns_filed': [250.0, 47778.0], 'credit_card_due': [0.0, 9998.0], 'person_emp_length': [0.0, 41.0], 'hard_pulls': [0.0, 10.0], 'loan_int_rate': [5.42, 23.22], 'loan_amnt': [500.0, 35000.0], 'person_income': [4000.0, 6000000.0]}\n"
255 | ]
256 | }
257 | ],
258 | "source": [
259 |     "# continuous features\n",
260 | "cont_scalar_max = fs.stats(\n",
261 | " \"credit_scoring_v1\", fn=\"max\", group_key=[], start=\"2020-08-01\", end=\"2021-04-30\"\n",
262 | ").to_dict()\n",
263 | "cont_scalar_min = fs.stats(\n",
264 | " \"credit_scoring_v1\", fn=\"min\", group_key=[], start=\"2020-08-01\", end=\"2021-04-30\"\n",
265 | ").to_dict()\n",
266 | "cont_scalar = {key: [cont_scalar_min[key], cont_scalar_max[key]] for key in cont_scalar_min.keys()}\n",
267 | "print(f\"Min-Max boundary of continuous features are: {cont_scalar}\")"
268 | ]
269 | },
270 | {
271 | "cell_type": "markdown",
272 | "id": "6a2a08f5",
273 | "metadata": {},
274 | "source": [
275 |     "Construct a `torch.DataLoader` from the `torch.IterableDataset` for modelling.\n",
276 |     "Here we compose the data preprocessing in `collect_fn`, so the time range of the statistical results used to `.fit()` should correspond to the `train` data only, in order to avoid information leakage."
277 | ]
278 | },
279 | {
280 | "cell_type": "code",
281 | "execution_count": 9,
282 | "id": "1379abba",
283 | "metadata": {},
284 | "outputs": [],
285 | "source": [
286 | "batch_size=16\n",
287 | "\n",
288 | "train_dataloader = DataLoader( \n",
289 | " ds_train.to_pytorch(),\n",
290 | " collate_fn=lambda x: classify_collet_fn(\n",
291 | " x,\n",
292 | " cat_coder=cat_unique,\n",
293 | " cont_scalar=cont_scalar,\n",
294 | " label=fs._get_feature_to_use(fs.services[\"credit_scoring_v1\"].get_label_view(fs.label_views)),\n",
295 | " ),\n",
296 | " batch_size=batch_size,\n",
297 | " drop_last=True,\n",
298 | ")\n",
299 | "\n",
300 | "valie_dataloader = DataLoader( \n",
301 | " ds_valid.to_pytorch(),\n",
302 | " collate_fn=lambda x: classify_collet_fn(\n",
303 | " x,\n",
304 | " cat_coder=cat_unique,\n",
305 | " cont_scalar=cont_scalar,\n",
306 | " label=fs._get_feature_to_use(fs.services[\"credit_scoring_v1\"].get_label_view(fs.label_views)),\n",
307 | " ),\n",
308 | " batch_size=batch_size,\n",
309 | " drop_last=False,\n",
310 | ")\n",
311 | "\n",
312 | "test_dataloader = DataLoader( \n",
313 |     "    ds_test.to_pytorch(),\n",
314 | " collate_fn=lambda x: classify_collet_fn(\n",
315 | " x,\n",
316 | " cat_coder=cat_unique,\n",
317 | " cont_scalar=cont_scalar,\n",
318 | " label=fs._get_feature_to_use(fs.services[\"credit_scoring_v1\"].get_label_view(fs.label_views)),\n",
319 | " ),\n",
320 | " drop_last=False,\n",
321 | ")"
322 | ]
323 | },
324 | {
325 | "cell_type": "markdown",
326 | "id": "1fd7cb71",
327 | "metadata": {},
328 | "source": [
329 | "Customize `model`, `optimizer` and `loss` function suitable to task"
330 | ]
331 | },
332 | {
333 | "cell_type": "code",
334 | "execution_count": 10,
335 | "id": "4a9eb8be",
336 | "metadata": {},
337 | "outputs": [],
338 | "source": [
339 | "model = SimpleClassify(\n",
340 | " cont_nbr=len(cont_scalar_max), cat_nbr=len(cat_count), emd_dim=8, max_types=max(cat_count.values()),hidden_size=4\n",
341 | ")\n",
342 | "optimizer = torch.optim.Adam(model.parameters(), lr=0.001) \n",
343 | "loss_fn = nn.BCELoss() "
344 | ]
345 | },
346 | {
347 | "cell_type": "markdown",
348 | "id": "bb956ec5",
349 | "metadata": {},
350 | "source": [
351 |     "Use `train_dataloader` for training and `valie_dataloader` to guide `earlystop`"
352 | ]
353 | },
354 | {
355 | "cell_type": "code",
356 | "execution_count": 12,
357 | "id": "e3321113",
358 | "metadata": {},
359 | "outputs": [
360 | {
361 | "name": "stdout",
362 | "output_type": "stream",
363 | "text": [
364 | "epoch: 0 done, train loss: 0.9228071888287862, valid loss: 0.8917452792326609\n",
365 | "epoch: 1 done, train loss: 0.8228938817977905, valid loss: 0.806285967429479\n",
366 | "epoch: 2 done, train loss: 0.7159536878267924, valid loss: 0.715582956870397\n",
367 | "epoch: 3 done, train loss: 0.6077397664388021, valid loss: 0.6279712617397308\n",
368 | "epoch: 4 done, train loss: 0.523370603720347, valid loss: 0.5640117128690084\n",
369 | "epoch: 5 done, train loss: 0.46912830471992495, valid loss: 0.530453751484553\n",
370 | "epoch: 6 done, train loss: 0.4288410405317942, valid loss: 0.5177373588085175\n",
371 | "epoch: 7 done, train loss: 0.4263931175072988, valid loss: 0.5144786983728409\n",
372 | "epoch: 8 done, train loss: 0.4069450835386912, valid loss: 0.5147957305113474\n",
373 | "EarlyStopping counter: 1 out of 5\n",
374 | "epoch: 9 done, train loss: 0.4173213442166646, valid loss: 0.5159803281227747\n",
375 | "EarlyStopping counter: 2 out of 5\n",
376 | "epoch: 10 done, train loss: 0.43354902466138207, valid loss: 0.5168851017951965\n",
377 | "EarlyStopping counter: 3 out of 5\n",
378 | "epoch: 11 done, train loss: 0.3854812800884247, valid loss: 0.5179851104815801\n",
379 | "EarlyStopping counter: 4 out of 5\n",
380 | "epoch: 12 done, train loss: 0.3868887344996134, valid loss: 0.5189551711082458\n",
381 | "EarlyStopping counter: 5 out of 5\n",
382 | "Trigger earlystop, stop epoch at 12\n"
383 | ]
384 | }
385 | ],
386 | "source": [
387 | "# you can also use any ready-to-use training frame like ignite, pytorch-lightening...\n",
388 | "early_stop = EarlyStopping(save_path=f\"{save_dir}/{key.rstrip('.zip')}\",patience=5,delta=1e-6)\n",
389 | "for epoch in range(50):\n",
390 | " train_loss = []\n",
391 | " valid_loss = []\n",
392 | " \n",
393 | " model.train()\n",
394 | " for x, y in train_dataloader:\n",
395 | " pred_label = model(x)\n",
396 | " true_label = y\n",
397 | " loss = loss_fn(pred_label, true_label)\n",
398 | " train_loss.append(loss.item())\n",
399 | " optimizer.zero_grad()\n",
400 | " loss.backward()\n",
401 | " optimizer.step()\n",
402 | " \n",
403 | " model.eval()\n",
404 | " for x, y in valie_dataloader:\n",
405 | " pred_label = model(x)\n",
406 | " true_label = y\n",
407 | " loss = loss_fn(pred_label, true_label)\n",
408 | " valid_loss.append(loss.item())\n",
409 | "\n",
410 | " print(f\"epoch: {epoch} done, train loss: {np.mean(train_loss)}, valid loss: {np.mean(valid_loss)}\")\n",
411 | " early_stop(np.mean(valid_loss),model)\n",
412 | " if early_stop.early_stop:\n",
413 | " print(f\"Trigger earlystop, stop epoch at {epoch}\")\n",
414 | " break"
415 | ]
416 | },
417 | {
418 | "cell_type": "markdown",
419 | "id": "0f15ea7e",
420 | "metadata": {},
421 | "source": [
422 | "Get prediction result of `test_dataloader`"
423 | ]
424 | },
425 | {
426 | "cell_type": "code",
427 | "execution_count": 13,
428 | "id": "d87afe0c",
429 | "metadata": {},
430 | "outputs": [],
431 | "source": [
432 | "model = torch.load(os.path.join(f\"{save_dir}/{key.rstrip('.zip')}\",'best_chekpnt.pk'))\n",
433 | "model.eval()\n",
434 | "preds=[]\n",
435 | "trues=[]\n",
436 | "for x,y in test_dataloader:\n",
437 | " pred = model(x)\n",
438 | " pred_label = 1 if pred.cpu().detach().numpy() >0.5 else 0\n",
439 | " preds.append(pred_label)\n",
440 | " trues.append(y.cpu().detach().numpy())"
441 | ]
442 | },
443 | {
444 | "cell_type": "markdown",
445 | "id": "2d939513",
446 | "metadata": {},
447 | "source": [
448 | "Model Evaluation"
449 | ]
450 | },
451 | {
452 | "cell_type": "code",
453 | "execution_count": 14,
454 | "id": "f88810eb",
455 | "metadata": {},
456 | "outputs": [
457 | {
458 | "name": "stdout",
459 | "output_type": "stream",
460 | "text": [
461 | "Accuracy: 0.7956989247311828\n"
462 | ]
463 | }
464 | ],
465 | "source": [
466 | "# accuracy\n",
467 | "acc = [1 if preds[i]==trues[i] else 0 for i in range(len(trues))]\n",
468 | "print(f\"Accuracy: {np.sum(acc) / len(acc)}\")"
469 | ]
470 | }
471 | ],
472 | "metadata": {
473 | "kernelspec": {
474 | "display_name": "autonn-3.8.7",
475 | "language": "python",
476 | "name": "python3"
477 | },
478 | "language_info": {
479 | "codemirror_mode": {
480 | "name": "ipython",
481 | "version": 3
482 | },
483 | "file_extension": ".py",
484 | "mimetype": "text/x-python",
485 | "name": "python",
486 | "nbconvert_exporter": "python",
487 | "pygments_lexer": "ipython3",
488 | "version": "3.8.12"
489 | },
490 | "vscode": {
491 | "interpreter": {
492 | "hash": "5840e4ed671345474330e8fce6ab52c58896a3935f0e728b8dbef1ddfad82808"
493 | }
494 | }
495 | },
496 | "nbformat": 4,
497 | "nbformat_minor": 5
498 | }
499 |
--------------------------------------------------------------------------------
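The credit-score notebook above delegates batch preprocessing to `classify_collet_fn`, feeding it the `unique` stats for categorical features and the `min`/`max` stats for continuous features computed on the training window only. The sketch below is a minimal, hypothetical collate function illustrating that idea; the batch layout, output keys, and function name are assumptions for illustration, not the actual f2ai implementation.

```python
# Hypothetical sketch only -- not f2ai's classify_collet_fn.
from typing import Dict, List, Tuple

import torch


def sketch_collate(
    batch: List[Tuple[dict, dict]],   # assumed layout: (features, labels) per sample
    cat_coder: Dict[str, list],       # feature name -> list of unique category values
    cont_scalar: Dict[str, list],     # feature name -> [min, max] from the train window
    label: List[str],                 # label column names
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
    cats, conts, ys = [], [], []
    for features, labels in batch:
        # label-encode categoricals against the training-window vocabulary
        cats.append([cat_coder[name].index(features[name]) for name in cat_coder])
        # min-max scale continuous features with training-window statistics
        conts.append(
            [(features[name] - lo) / (hi - lo + 1e-9) for name, (lo, hi) in cont_scalar.items()]
        )
        ys.append([float(labels[name]) for name in label])
    x = {
        "categorical": torch.tensor(cats, dtype=torch.long),
        "continuous": torch.tensor(conts, dtype=torch.float32),
    }
    return x, torch.tensor(ys, dtype=torch.float32)
```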