├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── entity_framework ├── __init__.py ├── abstract_entity_tree.py ├── entity.py ├── registry.py ├── repository.py ├── storages │ ├── __init__.py │ └── sqlalchemy │ │ ├── __init__.py │ │ ├── constructing_model │ │ ├── __init__.py │ │ ├── raw_model.py │ │ └── visitor.py │ │ ├── native_type_to_column.py │ │ ├── populating_aggregates │ │ ├── __init__.py │ │ └── visitor.py │ │ ├── populating_model │ │ ├── __init__.py │ │ └── visitor.py │ │ ├── querying │ │ ├── __init__.py │ │ └── visitor.py │ │ ├── registry.py │ │ └── types.py └── tests │ ├── __init__.py │ ├── abstract_entity_tree │ ├── __init__.py │ ├── test_build.py │ └── test_visitor.py │ ├── conftest.py │ ├── storages │ ├── __init__.py │ ├── conftest.py │ └── sqlalchemy │ │ ├── __init__.py │ │ ├── conftest.py │ │ ├── test_nested_value_objects.py │ │ └── test_one_level_nested_entities.py │ └── test_entity.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | ve36/ 2 | .mypy_cache/ 3 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "3.6" 4 | - "3.7-dev" 5 | services: 6 | - postgresql 7 | before_script: 8 | - psql -c 'create database travis_ci_test;' -U postgres 9 | install: 10 | - pip install -r requirements.txt 11 | script: 12 | - black --check -l 120 ./entity_framework/ 13 | - flake8 --max-line-length 120 ./entity_framework/ 14 | - pytest --sqlalchemy-postgres-url="postgresql://postgres:@localhost:5432/travis_ci_test" entity_framework/tests/ 15 | 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Sebastian Buczyński 4 | 5 | Permission is hereby granted, free of 
charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Python Entity Framework 2 | Persistence for attr-based Python classes. 3 | 4 | *WARNING* Many things here do not work yet, and parts of the code are messy. The project is still in the prototyping phase.
5 | 6 | # Example 7 | ```python 8 | from entity_framework import Entity, Identity, Repository 9 | from entity_framework.storages.sqlalchemy import SqlAlchemyRepo 10 | 11 | 12 | # Write your business entities in almost-plain Python 13 | class Customer(Entity): 14 | id: Identity[int] 15 | full_name: str 16 | 17 | 18 | # Get the abstract base class for the repo 19 | CustomerRepo = Repository[Customer, int] # first argument - entity class, second - Identity field type 20 | 21 | 22 | # Set up your persistence as usual 23 | from sqlalchemy.ext.declarative import declarative_base 24 | 25 | Base = declarative_base() 26 | 27 | 28 | # Set up the registry and get a concrete class 29 | from entity_framework.storages.sqlalchemy import SqlAlchemyRepo 30 | from entity_framework.storages.sqlalchemy.registry import SaRegistry 31 | 32 | 33 | Registry = SaRegistry() 34 | 35 | 36 | class SaCustomerRepo(SqlAlchemyRepo, CustomerRepo): 37 | base = Base 38 | registry = Registry 39 | ``` 40 | Voilà. Python Entity Framework will generate SQLAlchemy models for you. *They are properly detected by Alembic (yay!)* Additionally, `SaCustomerRepo` will have two methods, `save` and `get`, which respectively persist and fetch your entity. 41 | 42 | # WORK IN PROGRESS 43 | Everything is subject to change, including the name of the library and the address of this repository. Code inside may be inconsistent and undergoes significant refactoring all the time. 44 | 45 | ## Aim 46 | To eliminate the need to manually write code for persisting attr-based business entities, AKA aggregates. 47 | 48 | 49 | ## Roadmap 50 | * Support SQLAlchemy with the possibility of overriding the implementation partially (e.g.
single column definitions) or entirely 51 | * Use SQLAlchemy's Session as UnitOfWork 52 | * Optimize code for populating models/entities 53 | * Support MongoDB with the same set of functionalities as SQLAlchemy storage 54 | -------------------------------------------------------------------------------- /entity_framework/__init__.py: -------------------------------------------------------------------------------- 1 | from entity_framework.entity import Entity, Identity, ValueObject 2 | from entity_framework.repository import Repository 3 | 4 | 5 | __all__ = ["Entity", "Identity", "ValueObject", "Repository"] 6 | -------------------------------------------------------------------------------- /entity_framework/abstract_entity_tree.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import inspect 3 | import typing 4 | from collections import deque 5 | 6 | import attr 7 | import inflection 8 | 9 | from entity_framework.entity import Entity, Identity, ValueObject, EntityOrVoType 10 | 11 | 12 | def _is_generic(field_type: typing.Type) -> bool: 13 | return hasattr(field_type, "__origin__") 14 | 15 | 16 | def _get_wrapped_type(wrapped_type: typing.Optional) -> typing.Type: 17 | return wrapped_type.__args__[0] 18 | 19 | 20 | def _is_field_optional(field_type: typing.Type) -> bool: 21 | return field_type.__origin__ == typing.Union and isinstance(None, field_type.__args__[1]) 22 | 23 | 24 | def _is_nested_entity_or_vo(field_type: typing.Type) -> bool: 25 | return issubclass(field_type, (Entity, ValueObject)) or _is_optional_entity_or_vo(field_type) 26 | 27 | 28 | def _is_optional_entity_or_vo(field_type: typing.Type) -> bool: 29 | if _is_generic(field_type) and _is_field_optional(field_type): 30 | return issubclass(_get_wrapped_type(field_type), (Entity, ValueObject)) 31 | return False 32 | 33 | 34 | def _is_identity(field_type: typing.Type) -> bool: 35 | return getattr(field_type, "__origin__", None) == Identity 36 | 37 | 
38 | def _is_list_of_entities_or_vos(field_type: typing.Type) -> bool: 39 | return ( 40 | _is_generic(field_type) 41 | and field_type.__origin__ == typing.List 42 | and issubclass(_get_wrapped_type(field_type), (Entity, ValueObject)) 43 | ) 44 | 45 | 46 | class Visitor: 47 | def traverse_from(self, node: "Node") -> None: 48 | node.accept(self) 49 | for child in node.children: 50 | self.traverse_from(child) 51 | node.farewell(self) 52 | 53 | def visit_field(self, field: "FieldNode") -> None: 54 | pass 55 | 56 | def leave_field(self, field: "FieldNode") -> None: 57 | pass 58 | 59 | def visit_entity(self, entity: "EntityNode") -> None: 60 | pass 61 | 62 | def leave_entity(self, entity: "EntityNode") -> None: 63 | pass 64 | 65 | def visit_value_object(self, value_object: "ValueObjectNode") -> None: 66 | pass 67 | 68 | def leave_value_object(self, value_object: "ValueObjectNode") -> None: 69 | pass 70 | 71 | def visit_list_of_entities(self, list_of_entities: "ListOfEntitiesNode") -> None: 72 | pass 73 | 74 | def leave_list_of_entities(self, list_of_entities: "ListOfEntitiesNode") -> None: 75 | pass 76 | 77 | def visit_list_of_value_objects(self, list_of_value_objects: "ListOfValueObjectsNode") -> None: 78 | pass 79 | 80 | def leave_list_of_value_objects(self, list_of_value_objects: "ListOfValueObjectsNode") -> None: 81 | pass 82 | 83 | def noop(self, _: "Node") -> None: 84 | pass 85 | 86 | 87 | class NodeMeta(type): 88 | def __new__(mcs, name: str, bases: tuple, namespace: dict) -> typing.Type: 89 | cls = super().__new__(mcs, name, bases, namespace) 90 | if inspect.isabstract(cls): 91 | return cls 92 | return attr.s(auto_attribs=True)(cls) 93 | 94 | 95 | class Node(metaclass=NodeMeta): 96 | name: str 97 | type: typing.Type 98 | optional: bool = False 99 | children: typing.Tuple["Node", ...] 
= attr.Factory(list) 100 | 101 | @abc.abstractmethod 102 | def accept(self, visitor: Visitor) -> None: 103 | pass 104 | 105 | @abc.abstractmethod 106 | def farewell(self, visitor: Visitor) -> None: 107 | pass 108 | 109 | 110 | class FieldNode(Node): 111 | is_identity: bool = False 112 | 113 | def accept(self, visitor: Visitor) -> None: 114 | visitor.visit_field(self) 115 | 116 | def farewell(self, visitor: Visitor) -> None: 117 | visitor.leave_field(self) 118 | 119 | 120 | class EntityNode(Node): 121 | def accept(self, visitor: Visitor) -> None: 122 | visitor.visit_entity(self) 123 | 124 | def farewell(self, visitor: Visitor) -> None: 125 | visitor.leave_entity(self) 126 | 127 | 128 | class ValueObjectNode(Node): 129 | def accept(self, visitor: Visitor) -> None: 130 | visitor.visit_value_object(self) 131 | 132 | def farewell(self, visitor: Visitor) -> None: 133 | visitor.leave_value_object(self) 134 | 135 | 136 | class ListOfEntitiesNode(Node): 137 | def accept(self, visitor: Visitor) -> None: 138 | visitor.visit_list_of_entities(self) 139 | 140 | def farewell(self, visitor: Visitor) -> None: 141 | visitor.leave_list_of_entities(self) 142 | 143 | 144 | class ListOfValueObjectsNode(Node): 145 | def accept(self, visitor: Visitor) -> None: 146 | visitor.visit_list_of_value_objects(self) 147 | 148 | def farewell(self, visitor: Visitor) -> None: 149 | visitor.leave_list_of_value_objects(self) 150 | 151 | 152 | @attr.s(auto_attribs=True) 153 | class AbstractEntityTree: 154 | root: EntityNode 155 | 156 | def __iter__(self) -> typing.Generator[Node, None, None]: 157 | def iterate_dfs() -> typing.Generator[Node, None, None]: 158 | nodes_left: typing.Deque[Node] = deque([self.root]) 159 | 160 | while nodes_left: 161 | current = nodes_left.pop() 162 | yield current 163 | nodes_left.extend(current.children[::-1]) 164 | 165 | return iterate_dfs() 166 | 167 | 168 | def build(root: typing.Type[Entity]) -> AbstractEntityTree: 169 | # TODO: children could be tuple, not list. 
Then, Nodes would be hashable. 170 | def parse_node(current_root: EntityOrVoType, name: str) -> Node: 171 | node_name = name 172 | is_list = False 173 | if _is_list_of_entities_or_vos(current_root): 174 | node_optional = False 175 | node_type = _get_wrapped_type(current_root) 176 | is_list = True 177 | elif _is_optional_entity_or_vo(current_root): 178 | node_optional = True 179 | node_type = _get_wrapped_type(current_root) 180 | else: 181 | node_optional = False 182 | node_type = current_root 183 | node_children = [] 184 | 185 | for field in attr.fields(node_type): 186 | field_type = field.type 187 | field_name = field.name 188 | 189 | if _is_nested_entity_or_vo(field_type) or _is_list_of_entities_or_vos(field_type): 190 | node_children.append(parse_node(field_type, field_name)) 191 | continue 192 | 193 | field_optional = False 194 | is_identity = False 195 | 196 | if _is_generic(field.type): 197 | if _is_identity(field_type): 198 | field_type = _get_wrapped_type(field.type) 199 | is_identity = True 200 | elif _is_field_optional(field.type): 201 | field_type = _get_wrapped_type(field.type) 202 | field_optional = True 203 | else: 204 | raise Exception(f"Unhandled Generic type - {field_type}") 205 | 206 | node_children.append(FieldNode(field_name, field_type, field_optional, (), is_identity)) 207 | 208 | node_children = tuple(node_children) 209 | if issubclass(node_type, Entity): 210 | if is_list: 211 | return ListOfEntitiesNode(node_name, node_type, node_optional, node_children) 212 | return EntityNode(node_name, node_type, node_optional, node_children) 213 | 214 | if is_list: 215 | return ListOfValueObjectsNode(node_name, node_type, node_optional, node_children) 216 | return ValueObjectNode(node_name, node_type, node_optional, node_children) 217 | 218 | root_node = parse_node(root, inflection.underscore(root.__name__)) 219 | return AbstractEntityTree(root_node) 220 | -------------------------------------------------------------------------------- 
/entity_framework/entity.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import typing 3 | 4 | import attr 5 | 6 | 7 | class EntityWithoutIdentity(TypeError): 8 | pass 9 | 10 | 11 | class ValueObjectWithIdentity(TypeError): 12 | pass 13 | 14 | 15 | class EntityNestedInValueObject(TypeError): 16 | pass 17 | 18 | 19 | T = typing.TypeVar("T") 20 | 21 | 22 | class Identity(typing.Generic[T]): 23 | @classmethod 24 | def is_identity(cls, field: attr.Attribute) -> None: 25 | return getattr(field.type, "__origin__", None) == cls 26 | 27 | 28 | class EntityMeta(abc.ABCMeta): 29 | def __new__(mcs, name: str, bases: tuple, namespace: dict): 30 | cls = super().__new__(mcs, name, bases, namespace) 31 | if name == "Entity": 32 | return cls 33 | attr_cls = attr.s(auto_attribs=True)(cls) 34 | if not any(Identity.is_identity(field) for field in attr.fields(attr_cls)): 35 | raise EntityWithoutIdentity 36 | return attr_cls 37 | 38 | 39 | class Entity(metaclass=EntityMeta): 40 | pass 41 | 42 | 43 | class ValueObjectMeta(abc.ABCMeta): 44 | def __new__(mcs, name: str, bases: tuple, namespace: dict): 45 | cls = super().__new__(mcs, name, bases, namespace) 46 | if name == "ValueObject": 47 | return cls 48 | attr_cls = attr.s(auto_attribs=True)(cls) 49 | fields = attr.fields(attr_cls) 50 | if any(Identity.is_identity(field) for field in fields): 51 | raise ValueObjectWithIdentity 52 | if any(_is_nested_entity(field.type) for field in fields): 53 | raise EntityNestedInValueObject 54 | return attr_cls 55 | 56 | 57 | def _is_nested_entity(field_type: typing.Type) -> bool: 58 | try: 59 | return issubclass(field_type, Entity) or ( 60 | field_type.__origin__ == typing.Union 61 | and isinstance(None, field_type.__args__[1]) 62 | and issubclass(field_type.__args__[0], Entity) 63 | ) 64 | except AttributeError: 65 | return False 66 | 67 | 68 | class ValueObject(metaclass=ValueObjectMeta): 69 | pass 70 | 71 | 72 | EntityOrVo = 
typing.Union[Entity, ValueObject] 73 | EntityOrVoType = typing.Union[typing.Type[Entity], typing.Type[ValueObject]] 74 | -------------------------------------------------------------------------------- /entity_framework/registry.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Type 2 | 3 | import attr 4 | 5 | from entity_framework.abstract_entity_tree import AbstractEntityTree 6 | from entity_framework.entity import Entity 7 | 8 | 9 | @attr.s(auto_attribs=True) 10 | class Registry: 11 | entities_to_aets: Dict[Type[Entity], AbstractEntityTree] = attr.Factory(dict) 12 | -------------------------------------------------------------------------------- /entity_framework/repository.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import inspect 3 | import typing 4 | 5 | from entity_framework.abstract_entity_tree import build 6 | from entity_framework.registry import Registry 7 | 8 | 9 | EntityType = typing.TypeVar("EntityType") 10 | IdentityType = typing.TypeVar("IdentityType") 11 | 12 | 13 | class RepositoryMeta(typing.GenericMeta): 14 | def __new__( 15 | mcs, 16 | name: str, 17 | bases: typing.Tuple[typing.Type, ...], 18 | namespace: dict, 19 | tvars=None, 20 | args=None, 21 | origin=None, 22 | extra=None, 23 | orig_bases=None, 24 | ) -> typing.Type: 25 | cls = super().__new__( 26 | mcs, name, bases, namespace, tvars=tvars, args=args, origin=origin, extra=extra, orig_bases=orig_bases 27 | ) 28 | if not inspect.isabstract(cls): 29 | assert isinstance(getattr(cls, "registry", None), Registry) 30 | last_base_class_origin = getattr(bases[-1], "__origin__", None) 31 | assert ( 32 | last_base_class_origin is ReadOnlyRepository or last_base_class_origin is Repository 33 | ) # TODO: komunikat? 
34 | args = getattr(bases[-1], "__args__", None) 35 | if args: 36 | entity_cls, _identity_cls = args 37 | if entity_cls not in cls.registry.entities_to_aets: 38 | cls.registry.entities_to_aets[entity_cls] = build(entity_cls) 39 | 40 | if hasattr(cls, "prepare"): 41 | entity_cls, _identity_cls = bases[-1].__args__ 42 | cls.prepare(entity_cls) 43 | 44 | return cls 45 | 46 | 47 | class ReadOnlyRepository(typing.Generic[EntityType, IdentityType], metaclass=RepositoryMeta): 48 | @classmethod 49 | @abc.abstractmethod 50 | def prepare(self) -> None: 51 | pass 52 | 53 | @abc.abstractmethod 54 | def get(self, identity: IdentityType) -> EntityType: 55 | pass 56 | 57 | 58 | class Repository(typing.Generic[EntityType, IdentityType], metaclass=RepositoryMeta): 59 | @classmethod 60 | @abc.abstractmethod 61 | def prepare(self) -> None: 62 | pass 63 | 64 | @abc.abstractmethod 65 | def get(self, identity: IdentityType) -> EntityType: 66 | pass 67 | 68 | @abc.abstractmethod 69 | def save(self, entity: EntityType) -> None: 70 | pass 71 | -------------------------------------------------------------------------------- /entity_framework/storages/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Enforcer/python_entity_framework/015221c068c23834aa3663c6239947c01e38f40a/entity_framework/storages/__init__.py -------------------------------------------------------------------------------- /entity_framework/storages/sqlalchemy/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Type 2 | 3 | from sqlalchemy.orm import Session, Query, exc 4 | from sqlalchemy.ext.declarative import DeclarativeMeta 5 | 6 | from entity_framework.repository import EntityType, IdentityType 7 | from entity_framework.storages.sqlalchemy.populating_aggregates.visitor import PopulatingAggregateVisitor 8 | from 
entity_framework.storages.sqlalchemy.constructing_model.visitor import ModelConstructingVisitor 9 | from entity_framework.storages.sqlalchemy.populating_model.visitor import ModelPopulatingVisitor 10 | from entity_framework.storages.sqlalchemy.querying.visitor import QueryBuildingVisitor 11 | from entity_framework.storages.sqlalchemy.registry import SaRegistry 12 | 13 | 14 | class SqlAlchemyRepo: 15 | base: DeclarativeMeta = None 16 | registry: SaRegistry = None 17 | 18 | _query: Optional[Query] = None 19 | 20 | def __init__(self, session: Session) -> None: 21 | self._session = session 22 | 23 | @classmethod 24 | def prepare(cls, entity_cls: Type[EntityType]) -> None: 25 | assert cls.base, "Must set cls base to an instance of DeclarativeMeta!" 26 | if not getattr(cls, "entity", None): 27 | cls.entity = entity_cls 28 | aet = cls.registry.entities_to_aets[entity_cls] 29 | ModelConstructingVisitor(cls.base, cls.registry).traverse_from(aet.root) 30 | 31 | @property 32 | def query(self) -> Query: 33 | if not getattr(self.__class__, "_query", None): 34 | aet = self.registry.entities_to_aets[self.entity] 35 | visitor = QueryBuildingVisitor(self.registry) 36 | visitor.traverse_from(aet.root) 37 | setattr(self.__class__, "_query", visitor.query) 38 | 39 | return self.__class__._query 40 | 41 | # TODO: sqlalchemy class could have an utility for creating IDS 42 | # Or it could be put into a separate utility function that would accept repo, then would get descendant classes 43 | # and got the new id. 
44 | 45 | def get(self, identity: IdentityType) -> EntityType: 46 | # TODO: memoize populating func 47 | result = self.query.with_session(self._session).get(identity) 48 | if not result: 49 | # TODO: Raise more specialized exception 50 | raise exc.NoResultFound 51 | 52 | converting_visitor = PopulatingAggregateVisitor(result) 53 | aet = self.registry.entities_to_aets[self.entity] 54 | converting_visitor.traverse_from(aet.root) 55 | return converting_visitor.result 56 | 57 | def save(self, entity: EntityType) -> None: 58 | visitor = ModelPopulatingVisitor(entity, self.registry) 59 | visitor.traverse_from(self.registry.entities_to_aets[self.entity].root) 60 | self._session.merge(visitor.result) 61 | self._session.flush() 62 | -------------------------------------------------------------------------------- /entity_framework/storages/sqlalchemy/constructing_model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Enforcer/python_entity_framework/015221c068c23834aa3663c6239947c01e38f40a/entity_framework/storages/sqlalchemy/constructing_model/__init__.py -------------------------------------------------------------------------------- /entity_framework/storages/sqlalchemy/constructing_model/raw_model.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple, Type 2 | 3 | import attr 4 | from sqlalchemy import Column 5 | from sqlalchemy.orm import relationship 6 | 7 | 8 | @attr.s(auto_attribs=True) 9 | class RawModel: 10 | name: str 11 | bases: Tuple[Type, ...] 
12 | namespace: Dict 13 | 14 | def append_column(self, name: str, column: Column) -> None: 15 | self.namespace[name] = column 16 | 17 | def append_relationship(self, name: str, related_model_name: str, nullable: bool) -> None: 18 | self.namespace[name] = relationship(related_model_name, innerjoin=not nullable) 19 | 20 | def materialize(self) -> Type: 21 | return type(self.name, self.bases, self.namespace) 22 | -------------------------------------------------------------------------------- /entity_framework/storages/sqlalchemy/constructing_model/visitor.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Type, Optional 2 | 3 | import inflection 4 | from sqlalchemy import Column, ForeignKey 5 | from sqlalchemy.ext.declarative import DeclarativeMeta 6 | 7 | from entity_framework.abstract_entity_tree import ( 8 | Visitor, 9 | EntityNode, 10 | ValueObjectNode, 11 | FieldNode, 12 | ListOfEntitiesNode, 13 | ListOfValueObjectsNode, 14 | ) 15 | from entity_framework.entity import Entity 16 | from entity_framework.storages.sqlalchemy import native_type_to_column 17 | from entity_framework.storages.sqlalchemy.registry import SaRegistry 18 | from entity_framework.storages.sqlalchemy.constructing_model.raw_model import RawModel 19 | 20 | 21 | class ModelConstructingVisitor(Visitor): 22 | EMPTY_PREFIX = "" 23 | 24 | def __init__(self, base: DeclarativeMeta, registry: SaRegistry) -> None: 25 | self._base = base 26 | self._registry = registry 27 | self._entities_stack: List[EntityNode] = [] 28 | self._entities_raw_models: Dict[Type[Entity], RawModel] = {} 29 | self._last_optional_vo_node: Optional[ValueObjectNode] = None 30 | self._stacked_vo: List[ValueObjectNode] = [] 31 | 32 | @property 33 | def _prefix(self) -> str: 34 | if not self._stacked_vo: 35 | return self.EMPTY_PREFIX 36 | return "_".join(vo.name for vo in self._stacked_vo) + "_" 37 | 38 | @property 39 | def current_entity(self) -> EntityNode: 40 | 
return self._entities_stack[-1] 41 | 42 | def visit_field(self, field: FieldNode) -> None: 43 | kwargs = {"primary_key": field.is_identity, "nullable": field.optional or self._last_optional_vo_node} 44 | raw_model: RawModel = self._entities_raw_models[self.current_entity.type] 45 | raw_model.append_column( 46 | f"{self._prefix}{field.name}", Column(native_type_to_column.convert(field.type), **kwargs) 47 | ) 48 | 49 | def visit_entity(self, entity: EntityNode) -> None: 50 | if entity.type in self._entities_raw_models: 51 | raise NotImplementedError("Probably recursive, not supported") 52 | 53 | model_name = f"{entity.type.__name__}Model" 54 | table_name = inflection.pluralize(inflection.underscore(entity.type.__name__)) 55 | 56 | if self._entities_stack: # nested, include foreign key 57 | identity_nodes: List[FieldNode] = [node for node in entity.children if getattr(node, "is_identity", None)] 58 | assert len(identity_nodes) == 1, "Multiple primary keys not supported" 59 | identity_node = identity_nodes.pop() 60 | raw_model: RawModel = self._entities_raw_models[self.current_entity.type] 61 | raw_model.append_column( 62 | f"{entity.name}_{identity_node.name}", 63 | Column( 64 | native_type_to_column.convert(identity_node.type), 65 | ForeignKey(f"{table_name}.{identity_node.name}"), 66 | nullable=entity.optional, 67 | ), 68 | ) 69 | raw_model.append_relationship(entity.name, model_name, entity.optional) 70 | 71 | self._entities_stack.append(entity) 72 | self._entities_raw_models[entity.type] = RawModel( 73 | name=model_name, bases=(self._base,), namespace={"__tablename__": table_name} 74 | ) 75 | 76 | def leave_entity(self, entity: EntityNode) -> None: 77 | entity_node = self._entities_stack.pop() 78 | raw_model: RawModel = self._entities_raw_models[entity_node.type] 79 | self._registry.entities_models[entity.type] = raw_model.materialize() 80 | 81 | def visit_value_object(self, value_object: ValueObjectNode) -> None: 82 | # value objects' fields are embedded into 
entity above it 83 | self._stacked_vo.append(value_object) 84 | if not self._last_optional_vo_node and value_object.optional: 85 | self._last_optional_vo_node = value_object 86 | 87 | def leave_value_object(self, value_object: ValueObjectNode) -> None: 88 | self._stacked_vo.pop() 89 | if self._last_optional_vo_node == value_object: 90 | self._last_optional_vo_node = None 91 | 92 | def visit_list_of_entities(self, list_of_entities: ListOfEntitiesNode) -> None: 93 | raise NotImplementedError 94 | 95 | def leave_list_of_entities(self, list_of_entities: ListOfEntitiesNode) -> None: 96 | raise NotImplementedError 97 | 98 | def visit_list_of_value_objects(self, list_of_entities: ListOfValueObjectsNode) -> None: 99 | raise NotImplementedError 100 | 101 | def leave_list_of_value_objects(self, list_of_entities: ListOfValueObjectsNode) -> None: 102 | raise NotImplementedError 103 | -------------------------------------------------------------------------------- /entity_framework/storages/sqlalchemy/native_type_to_column.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | import typing 3 | from datetime import datetime 4 | 5 | from sqlalchemy import Integer, String, Float, DateTime 6 | from sqlalchemy.dialects import postgresql as postgresql_dialect 7 | 8 | 9 | # TODO: Support other dialects, not only PostgreSQL 10 | mapping = {int: Integer, str: String(255), uuid.UUID: postgresql_dialect.UUID, float: Float, datetime: DateTime} 11 | 12 | 13 | def convert(arg: typing.Type) -> typing.Any: 14 | try: 15 | return mapping[arg] 16 | except KeyError: 17 | raise TypeError(f"Unsupported type - {arg}") 18 | -------------------------------------------------------------------------------- /entity_framework/storages/sqlalchemy/populating_aggregates/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Enforcer/python_entity_framework/015221c068c23834aa3663c6239947c01e38f40a/entity_framework/storages/sqlalchemy/populating_aggregates/__init__.py -------------------------------------------------------------------------------- /entity_framework/storages/sqlalchemy/populating_aggregates/visitor.py: -------------------------------------------------------------------------------- 1 | from typing import List, Any, Union 2 | 3 | from entity_framework.abstract_entity_tree import ( 4 | Visitor, 5 | FieldNode, 6 | EntityNode, 7 | ValueObjectNode, 8 | ListOfEntitiesNode, 9 | ListOfValueObjectsNode, 10 | ) 11 | 12 | 13 | class PopulatingAggregateVisitor(Visitor): 14 | EMPTY_PREFIX = "" 15 | 16 | def __init__(self, db_result: object) -> None: 17 | self._db_result = db_result 18 | self._entities_stack: List[EntityNode] = [] 19 | self._ef_objects_stack: List[Union[EntityNode, ValueObjectNode]] = [] 20 | self._ef_dicts_stack: List[dict] = [] 21 | self._result: Any = None 22 | self._stacked_vo: List[ValueObjectNode] = [] 23 | 24 | @property 25 | def _prefix(self) -> str: 26 | if not self._stacked_vo: 27 | return self.EMPTY_PREFIX 28 | return "_".join(vo.name for vo in self._stacked_vo) + "_" 29 | 30 | @property 31 | def result(self) -> Any: 32 | return self._result 33 | 34 | def visit_field(self, field: FieldNode) -> None: 35 | if isinstance(self._ef_objects_stack[-1], EntityNode): 36 | field_name = field.name 37 | else: 38 | field_name = f"{self._prefix}{field.name}" 39 | 40 | db_object = self._db_result 41 | for index, entity in enumerate(self._entities_stack[:-1]): 42 | next_entity = self._entities_stack[index + 1] 43 | db_object = getattr(db_object, next_entity.name) 44 | 45 | self._ef_dicts_stack[-1][field.name] = getattr(db_object, field_name) 46 | 47 | def visit_entity(self, entity: EntityNode) -> None: 48 | self._entities_stack.append(entity) 49 | self._stack_complex_object(entity) 50 | 51 | def leave_entity(self, entity: EntityNode) -> 
None: 52 | self._entities_stack.pop() 53 | self._construct_complex_object(entity) 54 | 55 | def visit_value_object(self, value_object: ValueObjectNode) -> None: 56 | self._stacked_vo.append(value_object) 57 | self._stack_complex_object(value_object) 58 | 59 | def leave_value_object(self, value_object: ValueObjectNode) -> None: 60 | self._stacked_vo.pop() 61 | self._construct_complex_object(value_object) 62 | 63 | def _stack_complex_object(self, vo_or_entity: Union[ValueObjectNode, EntityNode]) -> None: 64 | self._ef_objects_stack.append(vo_or_entity) 65 | self._ef_dicts_stack.append({}) 66 | 67 | def _construct_complex_object(self, vo_or_entity: Union[ValueObjectNode, EntityNode]) -> None: 68 | self._ef_objects_stack.pop() 69 | entity_dict = self._ef_dicts_stack.pop() 70 | if vo_or_entity.optional and entity_dict and all(v is None for v in entity_dict.values()): 71 | # One is not able to tell the difference between optional object with all its fields = None or 72 | # an absence of entire vo_or_entity 73 | instance = None 74 | else: 75 | instance = vo_or_entity.type(**entity_dict) 76 | if self._ef_dicts_stack: 77 | self._ef_dicts_stack[-1][vo_or_entity.name] = instance 78 | else: 79 | self._result = instance 80 | 81 | def visit_list_of_entities(self, list_of_entities: ListOfEntitiesNode) -> None: 82 | raise NotImplementedError 83 | 84 | def leave_list_of_entities(self, list_of_entities: ListOfEntitiesNode) -> None: 85 | raise NotImplementedError 86 | 87 | def visit_list_of_value_objects(self, list_of_value_objects: ListOfValueObjectsNode) -> None: 88 | raise NotImplementedError 89 | 90 | def leave_list_of_value_objects(self, list_of_value_objects: ListOfValueObjectsNode) -> None: 91 | raise NotImplementedError 92 | -------------------------------------------------------------------------------- /entity_framework/storages/sqlalchemy/populating_model/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Enforcer/python_entity_framework/015221c068c23834aa3663c6239947c01e38f40a/entity_framework/storages/sqlalchemy/populating_model/__init__.py -------------------------------------------------------------------------------- /entity_framework/storages/sqlalchemy/populating_model/visitor.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union, Any 2 | 3 | from entity_framework.abstract_entity_tree import ( 4 | Visitor, 5 | FieldNode, 6 | EntityNode, 7 | ValueObjectNode, 8 | ListOfEntitiesNode, 9 | ListOfValueObjectsNode, 10 | ) 11 | from entity_framework.entity import EntityOrVo 12 | from entity_framework.storages.sqlalchemy.registry import SaRegistry 13 | 14 | 15 | class ModelPopulatingVisitor(Visitor): 16 | EMPTY_PREFIX = "" 17 | 18 | def __init__(self, aggregate: EntityOrVo, registry: SaRegistry) -> None: 19 | self._aggregate = aggregate 20 | self._registry = registry 21 | self._complex_objects_stack: List[EntityOrVo] = [] 22 | self._entities_stack: List[EntityNode] = [] 23 | self._ef_objects_stack: List[Union[EntityNode, ValueObjectNode]] = [] 24 | self._models_dicts_stack: List[dict] = [] 25 | self._result: Any = None 26 | self._stacked_vo: List[ValueObjectNode] = [] 27 | 28 | @property 29 | def _prefix(self) -> str: 30 | if not self._stacked_vo: 31 | return self.EMPTY_PREFIX 32 | return "_".join(vo.name for vo in self._stacked_vo) + "_" 33 | 34 | @property 35 | def result(self) -> Any: 36 | return self._result 37 | 38 | def visit_field(self, field: FieldNode) -> None: 39 | if isinstance(self._ef_objects_stack[-1], EntityNode): 40 | field_name = field.name 41 | else: 42 | field_name = f"{self._prefix}{field.name}" 43 | 44 | if self._complex_objects_stack[-1]: # may be none if optional 45 | self._models_dicts_stack[-1][field_name] = getattr(self._complex_objects_stack[-1], field.name) 46 | 47 | def visit_entity(self, entity: EntityNode) -> None: 48 | 
self._entities_stack.append(entity) 49 | self._stack_complex_object(entity) 50 | self._models_dicts_stack.append({}) 51 | 52 | def leave_entity(self, entity: EntityNode) -> None: 53 | self._entities_stack.pop() 54 | self._construct_model(entity) 55 | self._complex_objects_stack.pop() 56 | 57 | def visit_value_object(self, value_object: ValueObjectNode) -> None: 58 | self._stacked_vo.append(value_object) 59 | self._stack_complex_object(value_object) 60 | 61 | def leave_value_object(self, value_object: ValueObjectNode) -> None: 62 | self._stacked_vo.pop() 63 | self._complex_objects_stack.pop() 64 | 65 | def _stack_complex_object(self, vo_or_entity: Union[ValueObjectNode, EntityNode]) -> None: 66 | self._ef_objects_stack.append(vo_or_entity) 67 | if not self._complex_objects_stack: 68 | self._complex_objects_stack.append(self._aggregate) 69 | else: 70 | current = self._complex_objects_stack[-1] 71 | if current is None: 72 | another = None 73 | else: 74 | another = getattr(current, vo_or_entity.name) 75 | 76 | self._complex_objects_stack.append(another) 77 | 78 | def _construct_model(self, entity: EntityNode) -> None: 79 | self._ef_objects_stack.pop() 80 | entity_dict = self._models_dicts_stack.pop() 81 | if entity.optional and entity_dict and all(v is None for v in entity_dict.values()): 82 | # One is not able to tell the difference between optional object with all its fields = None or 83 | # an absence of entire vo_or_entity 84 | instance = None 85 | else: 86 | model_cls = self._registry.entities_models[entity.type] 87 | instance = model_cls(**entity_dict) 88 | if self._models_dicts_stack: 89 | self._models_dicts_stack[-1][entity.name] = instance 90 | else: 91 | self._result = instance 92 | 93 | def visit_list_of_entities(self, list_of_entities: ListOfEntitiesNode) -> None: 94 | raise NotImplementedError 95 | 96 | def leave_list_of_entities(self, list_of_entities: ListOfEntitiesNode) -> None: 97 | raise NotImplementedError 98 | 99 | def 
visit_list_of_value_objects(self, list_of_value_objects: ListOfValueObjectsNode) -> None: 100 | raise NotImplementedError 101 | 102 | def leave_list_of_value_objects(self, list_of_value_objects: ListOfValueObjectsNode) -> None: 103 | raise NotImplementedError 104 | -------------------------------------------------------------------------------- /entity_framework/storages/sqlalchemy/querying/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Enforcer/python_entity_framework/015221c068c23834aa3663c6239947c01e38f40a/entity_framework/storages/sqlalchemy/querying/__init__.py -------------------------------------------------------------------------------- /entity_framework/storages/sqlalchemy/querying/visitor.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, DefaultDict, Set, Type 2 | from collections import defaultdict 3 | 4 | from sqlalchemy.orm import Query, joinedload 5 | 6 | from entity_framework.abstract_entity_tree import Visitor, EntityNode 7 | from entity_framework.storages.sqlalchemy.registry import SaRegistry 8 | 9 | 10 | class QueryBuildingVisitor(Visitor): 11 | def __init__(self, registry: SaRegistry) -> None: 12 | self._registry = registry 13 | self._root_model: Optional[Type] = None 14 | self._models_stack: List[Type] = [] 15 | self._models_to_join: DefaultDict[Type, List[str]] = defaultdict(list) 16 | self._all_models: Set[Type] = set() 17 | self._query: Optional[Query] = None 18 | 19 | @property 20 | def query(self) -> Query: 21 | if not self._root_model: 22 | raise Exception("No root model") 23 | 24 | # TODO: support more nesting levels than just one 25 | return Query(self._root_model).options( 26 | joinedload(getattr(self._root_model, rel_name)) for rel_name in self._models_to_join[self._root_model] 27 | ) 28 | 29 | def visit_entity(self, entity: EntityNode) -> None: 30 | # TODO: decide what to do with 
fields used magically, like entity.name which is really just a node name 31 | model = self._registry.entities_models[entity.type] 32 | if not self._root_model: 33 | self._root_model = model 34 | elif self._models_stack: 35 | self._models_to_join[self._models_stack[-1]].append(entity.name) 36 | 37 | self._models_stack.append(model) 38 | self._all_models.add(model) 39 | 40 | def leave_entity(self, entity: EntityNode) -> None: 41 | self._models_stack.pop() 42 | -------------------------------------------------------------------------------- /entity_framework/storages/sqlalchemy/registry.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Type 2 | 3 | import attr 4 | from sqlalchemy.ext.declarative import DeclarativeMeta 5 | 6 | from entity_framework.entity import Entity 7 | from entity_framework.registry import Registry 8 | 9 | 10 | @attr.s(auto_attribs=True) 11 | class SaRegistry(Registry): 12 | # TODO: Think of refactoring, so that this does not have semantics of a global variable 13 | entities_models: Dict[Type[Entity], Type[DeclarativeMeta]] = attr.Factory(dict) 14 | -------------------------------------------------------------------------------- /entity_framework/storages/sqlalchemy/types.py: -------------------------------------------------------------------------------- 1 | import typing 2 | import uuid 3 | from functools import singledispatch 4 | 5 | import attr 6 | 7 | from entity_framework.entity import Identity 8 | from entity_framework.abstract_entity_tree import _is_generic, _get_wrapped_type 9 | 10 | 11 | @singledispatch 12 | def to_storage(argument: typing.Any) -> typing.Any: 13 | return argument 14 | 15 | 16 | @to_storage.register(uuid.UUID) 17 | def _(argument: uuid.UUID) -> str: 18 | return str(argument) 19 | 20 | 21 | mapping = {uuid.UUID: uuid.UUID} 22 | 23 | 24 | def from_storage(argument: typing.Any, field: attr.Attribute) -> typing.Any: 25 | field_type = field.type 26 | if 
_is_generic(field): 27 | if Identity.is_identity(field): 28 | field_type = _get_wrapped_type(field) 29 | else: 30 | raise Exception("Unhandled branch") 31 | 32 | try: 33 | return mapping[field_type](argument) 34 | except KeyError: 35 | return argument 36 | -------------------------------------------------------------------------------- /entity_framework/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Enforcer/python_entity_framework/015221c068c23834aa3663c6239947c01e38f40a/entity_framework/tests/__init__.py -------------------------------------------------------------------------------- /entity_framework/tests/abstract_entity_tree/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Enforcer/python_entity_framework/015221c068c23834aa3663c6239947c01e38f40a/entity_framework/tests/abstract_entity_tree/__init__.py -------------------------------------------------------------------------------- /entity_framework/tests/abstract_entity_tree/test_build.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from decimal import Decimal 3 | from enum import Enum 4 | from uuid import UUID 5 | 6 | from entity_framework import abstract_entity_tree 7 | from entity_framework.abstract_entity_tree import ( 8 | AbstractEntityTree, 9 | EntityNode, 10 | FieldNode, 11 | ValueObjectNode, 12 | ListOfValueObjectsNode, 13 | ) 14 | from entity_framework import Entity, Identity, ValueObject 15 | 16 | 17 | class DummyEnum(Enum): 18 | FIRST_VALUE = "FIRST_VALUE" 19 | ANOTHER_VALUE = "ANOTHER_VALUE" 20 | 21 | 22 | class SimpleFlat(Entity): 23 | guid: Identity[UUID] 24 | name: typing.Optional[str] 25 | score: int 26 | enumerated: DummyEnum 27 | balance: typing.Optional[Decimal] 28 | 29 | 30 | def test_builds_simple_flat_entity(): 31 | result = abstract_entity_tree.build(SimpleFlat) 32 | 
33 | assert result == AbstractEntityTree( 34 | root=EntityNode( 35 | name="simple_flat", 36 | type=SimpleFlat, 37 | optional=False, 38 | children=( 39 | FieldNode(name="guid", type=UUID, optional=False, children=(), is_identity=True), 40 | FieldNode(name="name", type=str, optional=True, children=(), is_identity=False), 41 | FieldNode(name="score", type=int, optional=False, children=(), is_identity=False), 42 | FieldNode(name="enumerated", type=DummyEnum, optional=False, children=(), is_identity=False), 43 | FieldNode(name="balance", type=Decimal, optional=True, children=(), is_identity=False), 44 | ), 45 | ) 46 | ) 47 | 48 | 49 | class NestedValueObject(ValueObject): 50 | amount: Decimal 51 | currency: str 52 | 53 | 54 | class NestedEntity(Entity): 55 | nested_guid: Identity[UUID] 56 | name: typing.Optional[str] 57 | 58 | 59 | class SomeAggregate(Entity): 60 | guid: Identity[UUID] 61 | nested: NestedEntity 62 | balance: NestedValueObject 63 | 64 | 65 | def test_builds_aggregate_with_embedded_entity_and_value_object(): 66 | result = abstract_entity_tree.build(SomeAggregate) 67 | 68 | assert result == AbstractEntityTree( 69 | root=EntityNode( 70 | name="some_aggregate", 71 | type=SomeAggregate, 72 | optional=False, 73 | children=( 74 | FieldNode(name="guid", type=UUID, optional=False, children=(), is_identity=True), 75 | EntityNode( 76 | name="nested", 77 | type=NestedEntity, 78 | children=( 79 | FieldNode(name="nested_guid", type=UUID, optional=False, children=(), is_identity=True), 80 | FieldNode(name="name", type=str, optional=True, children=(), is_identity=False), 81 | ), 82 | ), 83 | ValueObjectNode( 84 | name="balance", 85 | type=NestedValueObject, 86 | optional=False, 87 | children=( 88 | FieldNode(name="amount", type=Decimal, optional=False, children=(), is_identity=False), 89 | FieldNode(name="currency", type=str, optional=False, children=(), is_identity=False), 90 | ), 91 | ), 92 | ), 93 | ) 94 | ) 95 | 96 | 97 | class AggregateWithValueObjectList(Entity): 
98 | id: Identity[int] 99 | wallets: typing.List[NestedValueObject] 100 | 101 | 102 | def test_builds_aggregate_with_list_of_value_objects(): 103 | result = abstract_entity_tree.build(AggregateWithValueObjectList) 104 | 105 | assert result == AbstractEntityTree( 106 | root=EntityNode( 107 | name="aggregate_with_value_object_list", 108 | type=AggregateWithValueObjectList, 109 | optional=False, 110 | children=( 111 | FieldNode(name="id", type=int, optional=False, children=(), is_identity=True), 112 | ListOfValueObjectsNode( 113 | name="wallets", 114 | type=NestedValueObject, 115 | optional=False, 116 | children=( 117 | FieldNode(name="amount", type=Decimal, optional=False, children=(), is_identity=False), 118 | FieldNode(name="currency", type=str, optional=False, children=(), is_identity=False), 119 | ), 120 | ), 121 | ), 122 | ) 123 | ) 124 | 125 | 126 | class AggregateWithOptionalNested(Entity): 127 | id: Identity[int] 128 | wallet: typing.Optional[NestedValueObject] 129 | 130 | 131 | def test_optional_nested_entity_makes_all_its_fields_optional(): 132 | result = abstract_entity_tree.build(AggregateWithOptionalNested) 133 | 134 | assert result == AbstractEntityTree( 135 | root=EntityNode( 136 | name="aggregate_with_optional_nested", 137 | type=AggregateWithOptionalNested, 138 | children=( 139 | FieldNode(name="id", type=int, optional=False, children=(), is_identity=True), 140 | ValueObjectNode( 141 | name="wallet", 142 | type=NestedValueObject, 143 | optional=True, 144 | children=( 145 | FieldNode(name="amount", type=Decimal, optional=False, children=(), is_identity=False), 146 | FieldNode(name="currency", type=str, optional=False, children=(), is_identity=False), 147 | ), 148 | ), 149 | ), 150 | ) 151 | ) 152 | -------------------------------------------------------------------------------- /entity_framework/tests/abstract_entity_tree/test_visitor.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import 
pytest 4 | 5 | from entity_framework.entity import Entity, Identity, ValueObject 6 | from entity_framework.abstract_entity_tree import ( 7 | AbstractEntityTree, 8 | EntityNode, 9 | ValueObjectNode, 10 | FieldNode, 11 | Visitor, 12 | Node, 13 | ) 14 | 15 | 16 | class Scribe(Visitor): 17 | def __init__(self) -> None: 18 | self.visits_log: typing.List[typing.Tuple[str, str]] = [] 19 | 20 | def visit_field(self, node: "Node") -> None: 21 | self.visits_log.append(("visit", node.name)) 22 | 23 | def leave_field(self, node: "Node") -> None: 24 | self.visits_log.append(("leave", node.name)) 25 | 26 | visit_entity = visit_value_object = visit_list_of_entities = visit_list_of_value_objects = visit_field 27 | leave_entity = leave_value_object = leave_list_of_entities = leave_list_of_value_objects = leave_field 28 | 29 | 30 | class Skill(ValueObject): 31 | skill_name: str 32 | damage: int 33 | 34 | 35 | class Dragon(Entity): 36 | name: Identity[str] 37 | skill: Skill 38 | age: int 39 | 40 | 41 | @pytest.fixture() 42 | def tree() -> AbstractEntityTree: 43 | return AbstractEntityTree( 44 | root=EntityNode( 45 | name="dragon", 46 | type=Dragon, 47 | children=[ 48 | FieldNode(name="name", type=str, is_identity=True), 49 | ValueObjectNode( 50 | name="skill", 51 | type=Skill, 52 | children=[FieldNode(name="skill_name", type=str), FieldNode(name="damage", type=int)], 53 | ), 54 | FieldNode(name="age", type=int), 55 | ], 56 | ) 57 | ) 58 | 59 | 60 | def test_iterates_tree(tree: AbstractEntityTree) -> None: 61 | visitor = Scribe() 62 | visitor.traverse_from(tree.root) 63 | 64 | assert visitor.visits_log == [ 65 | ("visit", "dragon"), 66 | ("visit", "name"), 67 | ("leave", "name"), 68 | ("visit", "skill"), 69 | ("visit", "skill_name"), 70 | ("leave", "skill_name"), 71 | ("visit", "damage"), 72 | ("leave", "damage"), 73 | ("leave", "skill"), 74 | ("visit", "age"), 75 | ("leave", "age"), 76 | ("leave", "dragon"), 77 | ] 78 | 
-------------------------------------------------------------------------------- /entity_framework/tests/conftest.py: -------------------------------------------------------------------------------- 1 | from _pytest.config.argparsing import Parser 2 | 3 | 4 | def pytest_addoption(parser: Parser) -> None: 5 | parser.addoption("--sqlalchemy-postgres-url", action="store", default=None) 6 | -------------------------------------------------------------------------------- /entity_framework/tests/storages/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Enforcer/python_entity_framework/015221c068c23834aa3663c6239947c01e38f40a/entity_framework/tests/storages/__init__.py -------------------------------------------------------------------------------- /entity_framework/tests/storages/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from _pytest.fixtures import SubRequest 3 | from sqlalchemy.engine import Engine, create_engine 4 | 5 | 6 | @pytest.fixture() 7 | def engine(request: SubRequest) -> Engine: 8 | connection_url = request.config.getoption("--sqlalchemy-postgres-url") 9 | assert connection_url, "You have to define --sqlalchemy-postgres-url cmd line option!" 
10 | return create_engine(connection_url) 11 | -------------------------------------------------------------------------------- /entity_framework/tests/storages/sqlalchemy/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Enforcer/python_entity_framework/015221c068c23834aa3663c6239947c01e38f40a/entity_framework/tests/storages/sqlalchemy/__init__.py -------------------------------------------------------------------------------- /entity_framework/tests/storages/sqlalchemy/conftest.py: -------------------------------------------------------------------------------- 1 | from typing import Generator 2 | 3 | import pytest 4 | from sqlalchemy.engine import Engine 5 | from sqlalchemy.orm import sessionmaker, Session 6 | from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base 7 | 8 | 9 | @pytest.fixture() 10 | def sa_base() -> DeclarativeMeta: 11 | return declarative_base() 12 | 13 | 14 | @pytest.fixture() 15 | def session(sa_base: DeclarativeMeta, engine: Engine) -> Generator[Session, None, None]: 16 | sa_base.metadata.drop_all(engine) 17 | sa_base.metadata.create_all(engine) 18 | session_factory = sessionmaker(engine) 19 | yield session_factory() 20 | session_factory.close_all() 21 | sa_base.metadata.drop_all(engine) 22 | -------------------------------------------------------------------------------- /entity_framework/tests/storages/sqlalchemy/test_nested_value_objects.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | from typing import Optional, Union, Type, List, Dict 3 | 4 | import pytest 5 | from sqlalchemy.orm import Session 6 | from sqlalchemy.ext.declarative import DeclarativeMeta 7 | 8 | from entity_framework import Entity, Identity, ValueObject, Repository 9 | from entity_framework.storages.sqlalchemy import SqlAlchemyRepo 10 | from entity_framework.storages.sqlalchemy.registry import SaRegistry 
11 | 12 | 13 | class Deadline(ValueObject): 14 | datetime: datetime 15 | penalty: int 16 | 17 | 18 | class Goal(ValueObject): 19 | assignee: str 20 | deadline: Optional[Deadline] = None 21 | 22 | 23 | class Board(Entity): 24 | id: Identity[int] 25 | goal: Optional[Goal] = None 26 | 27 | 28 | BoardRepo = Repository[Board, int] 29 | 30 | 31 | DATETIME = datetime.now() 32 | 33 | 34 | @pytest.fixture() 35 | def sa_repo(sa_base: DeclarativeMeta) -> Type[Union[SqlAlchemyRepo, BoardRepo]]: 36 | class SaBoardRepo(SqlAlchemyRepo, BoardRepo): 37 | base = sa_base 38 | registry = SaRegistry() 39 | 40 | return SaBoardRepo 41 | 42 | 43 | @pytest.mark.parametrize( 44 | "tables, entities, rows, expected_aggregate", 45 | [ 46 | (["boards"], [Board], [[{"id": 1}]], Board(id=1)), 47 | (["boards"], [Board], [[{"id": 1, "goal_assignee": "me"}]], Board(id=1, goal=Goal(assignee="me"))), 48 | ( 49 | ["boards"], 50 | [Board], 51 | [[{"id": 1, "goal_assignee": "me", "goal_deadline_datetime": DATETIME, "goal_deadline_penalty": 2000}]], 52 | Board(id=1, goal=Goal(assignee="me", deadline=Deadline(datetime=DATETIME, penalty=2000))), 53 | ), 54 | ], 55 | ) 56 | def test_gets_exemplary_data( 57 | sa_repo: Type[Union[SqlAlchemyRepo, BoardRepo]], 58 | session: Session, 59 | tables: List[str], 60 | entities: List[Type[Entity]], 61 | rows: List[Dict], 62 | expected_aggregate: Board, 63 | ) -> None: 64 | repo = sa_repo(session) 65 | 66 | for table_name, entity, mappings in zip(tables, entities, rows): 67 | model = sa_repo.registry.entities_models[entity] 68 | assert table_name == model.__tablename__ 69 | session.bulk_insert_mappings(model, mappings) 70 | 71 | assert repo.get(expected_aggregate.id) == expected_aggregate 72 | 73 | 74 | @pytest.mark.parametrize( 75 | "aggregate, expected_db_data", 76 | [ 77 | ( 78 | Board(id=1), 79 | {Board: [{"id": 1, "goal_assignee": None, "goal_deadline_penalty": None, "goal_deadline_datetime": None}]}, 80 | ), 81 | ( 82 | Board(id=1, goal=Goal(assignee="me")), 83 | 
{Board: [{"id": 1, "goal_assignee": "me", "goal_deadline_penalty": None, "goal_deadline_datetime": None}]}, 84 | ), 85 | ( 86 | Board(id=1, goal=Goal(assignee="me", deadline=Deadline(datetime=DATETIME, penalty=1500))), 87 | { 88 | Board: [ 89 | {"id": 1, "goal_assignee": "me", "goal_deadline_penalty": 1500, "goal_deadline_datetime": DATETIME} 90 | ] 91 | }, 92 | ), 93 | ], 94 | ) 95 | def test_saves_exemplary_data( 96 | sa_repo: Type[Union[SqlAlchemyRepo, BoardRepo]], 97 | session: Session, 98 | aggregate: Board, 99 | expected_db_data: Dict[Type[Entity], List[Dict]], 100 | ) -> None: 101 | repo = sa_repo(session) 102 | 103 | repo.save(aggregate) 104 | 105 | for entity, expected_rows in expected_db_data.items(): 106 | model = sa_repo.registry.entities_models[entity] 107 | assert [dict(row) for row in session.execute(model.__table__.select()).fetchall()] == expected_rows 108 | -------------------------------------------------------------------------------- /entity_framework/tests/storages/sqlalchemy/test_one_level_nested_entities.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, Type, List, Dict 2 | 3 | import pytest 4 | from sqlalchemy.orm import Session 5 | from sqlalchemy.ext.declarative import DeclarativeMeta 6 | 7 | from entity_framework import Entity, Identity, ValueObject, Repository 8 | from entity_framework.storages.sqlalchemy import SqlAlchemyRepo 9 | from entity_framework.storages.sqlalchemy.registry import SaRegistry 10 | 11 | 12 | SubscriberId = int 13 | PlanId = int 14 | 15 | 16 | class Plan(Entity): 17 | id: Identity[PlanId] 18 | discount: float 19 | 20 | 21 | class Subscription(ValueObject): 22 | plan_id: PlanId 23 | start_at: int 24 | 25 | 26 | class Subscriber(Entity): 27 | id: Identity[SubscriberId] 28 | plan: Plan 29 | current_subscription: Optional[Subscription] = None 30 | lifetime_subscription: Optional[Subscription] = None 31 | 32 | def subscribe(self, subscription: 
Subscription) -> None: 33 | if not self.current_subscription: 34 | self.current_subscription = subscription 35 | else: 36 | raise Exception("Impossible") 37 | 38 | 39 | SubscriberRepo = Repository[Subscriber, SubscriberId] 40 | 41 | 42 | @pytest.fixture() 43 | def sa_repo(sa_base: DeclarativeMeta) -> Type[Union[SqlAlchemyRepo, SubscriberRepo]]: 44 | class SqlSubscriberRepo(SqlAlchemyRepo, SubscriberRepo): 45 | base = sa_base 46 | registry = SaRegistry() 47 | 48 | return SqlSubscriberRepo 49 | 50 | 51 | @pytest.mark.parametrize( 52 | "tables, entities, rows, expected_aggregate", 53 | [ 54 | ( 55 | ["plans", "subscribers"], 56 | [Plan, Subscriber], 57 | [[{"id": 1, "discount": 0.5}], [{"id": 1, "plan_id": 1}]], 58 | Subscriber(id=1, plan=Plan(id=1, discount=0.5)), 59 | ), 60 | ( 61 | ["plans", "subscribers"], 62 | [Plan, Subscriber], 63 | [ 64 | [{"id": 1, "discount": 0.5}], 65 | [{"id": 1, "plan_id": 1, "current_subscription_plan_id": 1, "current_subscription_start_at": 0}], 66 | ], 67 | Subscriber(id=1, plan=Plan(id=1, discount=0.5), current_subscription=Subscription(1, 0)), 68 | ), 69 | ], 70 | ) 71 | def test_gets_exemplary_data( 72 | sa_repo: Type[Union[SqlAlchemyRepo, SubscriberRepo]], 73 | session: Session, 74 | tables: List[str], 75 | entities: List[Type[Entity]], 76 | rows: List[Dict], 77 | expected_aggregate: Subscriber, 78 | ) -> None: 79 | repo = sa_repo(session) 80 | 81 | for table_name, entity, mappings in zip(tables, entities, rows): 82 | model = sa_repo.registry.entities_models[entity] 83 | assert table_name == model.__tablename__ 84 | session.bulk_insert_mappings(model, mappings) 85 | 86 | assert repo.get(expected_aggregate.id) == expected_aggregate 87 | 88 | 89 | @pytest.mark.parametrize( 90 | "aggregate, expected_db_data", 91 | [ 92 | ( 93 | Subscriber(id=1, plan=Plan(id=1, discount=0.5), current_subscription=Subscription(1, 0)), 94 | { 95 | Plan: [{"id": 1, "discount": 0.5}], 96 | Subscriber: [ 97 | { 98 | "id": 1, 99 | 
"current_subscription_start_at": 0, 100 | "current_subscription_plan_id": 1, 101 | "lifetime_subscription_start_at": None, 102 | "lifetime_subscription_plan_id": None, 103 | "plan_id": 1, 104 | } 105 | ], 106 | }, 107 | ) 108 | ], 109 | ) 110 | def test_saves_exemplary_data( 111 | sa_repo: Type[Union[SqlAlchemyRepo, SubscriberRepo]], 112 | session: Session, 113 | aggregate: Subscriber, 114 | expected_db_data: Dict[Type[Entity], List[Dict]], 115 | ) -> None: 116 | repo = sa_repo(session) 117 | 118 | repo.save(aggregate) 119 | 120 | for entity, expected_rows in expected_db_data.items(): 121 | model = sa_repo.registry.entities_models[entity] 122 | assert [dict(row) for row in session.execute(model.__table__.select()).fetchall()] == expected_rows 123 | -------------------------------------------------------------------------------- /entity_framework/tests/test_entity.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import pytest 4 | 5 | from entity_framework.entity import Entity, EntityWithoutIdentity, Identity, ValueObject, ValueObjectWithIdentity 6 | 7 | 8 | def test_entity_allows_one_with_identity(): 9 | class WithIdentity(Entity): 10 | id: Identity[int] 11 | 12 | assert WithIdentity(1).id == 1 13 | 14 | 15 | def test_entity_enforces_identity(): 16 | with pytest.raises(EntityWithoutIdentity): 17 | 18 | class Identless(Entity): 19 | name: str 20 | 21 | Identless("Some name") 22 | 23 | 24 | def test_entity_allows_multiple_identities(): 25 | class WithDoubleIdentity(Entity): 26 | id: Identity[int] 27 | second_id: Identity[int] 28 | 29 | entity = WithDoubleIdentity(1, 2) 30 | assert entity.id == 1 31 | assert entity.second_id == 2 32 | 33 | 34 | def test_value_object_without_identity(): 35 | class SomeValueObject(ValueObject): 36 | amount: int 37 | 38 | assert SomeValueObject(1).amount == 1 39 | 40 | 41 | def test_value_object_with_identity(): 42 | with pytest.raises(ValueObjectWithIdentity): 43 | 44 | class 
IllegalValueObject(ValueObject): 45 | id: Identity[int] 46 | 47 | IllegalValueObject(1) 48 | 49 | 50 | def test_vo_can_be_nested_inside_another_vo(): 51 | class A(ValueObject): 52 | name: str 53 | 54 | class B(ValueObject): 55 | age: int 56 | nested: A 57 | 58 | assert B(1, A("John")) 59 | 60 | 61 | def test_entity_can_not_be_nested_inside_vo(): 62 | class Human(Entity): 63 | id: Identity[int] 64 | name: str 65 | 66 | with pytest.raises(Exception): 67 | 68 | class A(ValueObject): 69 | score: int 70 | person: Human 71 | 72 | 73 | def test_optional_entity_can_not_be_nested_inside_vo(): 74 | class Person(Entity): 75 | id: Identity[int] 76 | name: str 77 | 78 | with pytest.raises(Exception): 79 | 80 | class B(ValueObject): 81 | score: int 82 | person: typing.Optional[Person] 83 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | attrs==18.2.0 2 | black==18.9b0 3 | flake8==3.7.7 4 | inflection==0.3.1 5 | mypy==0.670 6 | psycopg2-binary==2.7.7 7 | pytest==4.2.1 8 | SQLAlchemy==1.2.18 9 | 10 | --------------------------------------------------------------------------------