# Repository layout:
# ├── MANIFEST.in
# ├── .gitignore
# ├── alchemysession
# │   ├── __init__.py
# │   ├── core_mysql.py
# │   ├── core_sqlite.py
# │   ├── core_postgres.py
# │   ├── orm.py
# │   ├── core.py
# │   └── sqlalchemy.py
# ├── .editorconfig
# ├── setup.py
# ├── LICENSE
# └── README.rst
#
# ---- MANIFEST.in ----
# include README.rst
# include LICENSE
#
# ---- .gitignore ----
# build/
# dist/
# *.egg-info/
# __pycache__/

# ---- alchemysession/__init__.py ----
from .sqlalchemy import AlchemySessionContainer

__version__ = "0.2.16"
__author__ = "Tulir Asokan <tulir@maunium.net>"

# ---- .editorconfig ----
# root = true
#
# [*]
# indent_style = tab
# indent_size = 4
# end_of_line = lf
# charset = utf-8
# trim_trailing_whitespace = true
# insert_final_newline = true
#
# [*.py]
# max_line_length = 99
#
# [*.{yaml,yml,py}]
# indent_style = space

# ---- setup.py ----
import setuptools

# Read the long description with an explicit encoding and close the file
# handle instead of leaking it (the README is UTF-8 per .editorconfig).
with open("README.rst", encoding="utf-8") as readme_file:
    long_description = readme_file.read()

setuptools.setup(
    name="telethon-session-sqlalchemy",
    version="0.2.16",
    url="https://github.com/tulir/telethon-session-sqlalchemy",

    author="Tulir Asokan",
    author_email="tulir@maunium.net",

    description="SQLAlchemy backend for Telethon session storage",
    long_description=long_description,

    packages=setuptools.find_packages(),

    install_requires=[
        "SQLAlchemy>=1.2,<2",
    ],

    classifiers=[
        "Development Status :: 4 - Beta",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.6",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
    ],
    python_requires="~=3.5",
)

# ---- LICENSE ----
# MIT License
#
# Copyright (c) 2020 Tulir Asokan
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ---- alchemysession/core_mysql.py ----
from typing import Any, Union

from sqlalchemy.dialects.mysql import insert

from telethon.sessions.memory import _SentFileType
from telethon.tl.types import InputPhoto, InputDocument

from .core import AlchemyCoreSession


class AlchemyMySQLCoreSession(AlchemyCoreSession):
    """Core-mode session that writes with MySQL ``INSERT ... ON DUPLICATE KEY UPDATE``."""

    def set_update_state(self, entity_id: int, row: Any) -> None:
        """Insert or overwrite the stored update state for ``entity_id``.

        :param entity_id: Entity the state belongs to.
        :param row: Object exposing ``pts``, ``qts``, ``date`` (datetime), ``seq``
                    and ``unread_count``.
        """
        t = self.UpdateState.__table__
        values = dict(pts=row.pts, qts=row.qts, date=row.date.timestamp(),
                      seq=row.seq, unread_count=row.unread_count)
        with self.engine.begin() as conn:
            conn.execute(insert(t)
                         .values(session_id=self.session_id, entity_id=entity_id, **values)
                         .on_duplicate_key_update(**values))

    def process_entities(self, tlo: Any) -> None:
        """Upsert every entity extracted from ``tlo`` in a single statement."""
        rows = self._entities_to_rows(tlo)
        if not rows:
            return

        t = self.Entity.__table__
        ins = insert(t)
        # ``ins.inserted`` refers to the values of the row being inserted, so one
        # compiled statement can be executed for all rows (executemany) instead of
        # one round-trip per entity, mirroring the Postgres backend's ``excluded``.
        upsert = ins.on_duplicate_key_update(
            hash=ins.inserted.hash,
            username=ins.inserted.username,
            phone=ins.inserted.phone,
            name=ins.inserted.name,
        )
        with self.engine.begin() as conn:
            conn.execute(upsert, [dict(session_id=self.session_id, id=row[0], hash=row[1],
                                       username=row[2], phone=row[3], name=row[4])
                                  for row in rows])

    def cache_file(self, md5_digest: str, file_size: int,
                   instance: Union[InputDocument, InputPhoto]) -> None:
        """Remember the uploaded ``instance`` for a file with this digest and size.

        :raises TypeError: if ``instance`` is not an InputDocument or InputPhoto.
        """
        if not isinstance(instance, (InputDocument, InputPhoto)):
            raise TypeError("Cannot cache {} instance".format(type(instance)))

        t = self.SentFile.__table__
        values = dict(id=instance.id, hash=instance.access_hash)
        with self.engine.begin() as conn:
            conn.execute(insert(t)
                         .values(session_id=self.session_id, md5_digest=md5_digest,
                                 type=_SentFileType.from_type(type(instance)).value,
                                 file_size=file_size, **values)
                         .on_duplicate_key_update(**values))
# ---- alchemysession/core_sqlite.py ----
from typing import Any, Union

from sqlalchemy import text

from telethon.sessions.memory import _SentFileType
from telethon.tl.types import InputPhoto, InputDocument

from .core import AlchemyCoreSession


class AlchemySQLiteCoreSession(AlchemyCoreSession):
    """Core-mode session that writes with SQLite ``INSERT OR REPLACE``.

    Every statement names its columns explicitly so it does not depend on the
    physical column order of existing tables (e.g. ``unread_count`` is appended
    by ALTER TABLE on upgraded databases).
    """

    def set_update_state(self, entity_id: int, row: Any) -> None:
        """Insert or overwrite the stored update state for ``entity_id``."""
        with self.engine.begin() as conn:
            conn.execute(text("INSERT OR REPLACE INTO {} ".format(self.UpdateState.__tablename__) +
                              "(session_id, entity_id, pts, qts, date, seq, unread_count) "
                              "VALUES (:session_id, :entity_id, :pts, :qts, :date, :seq, "
                              ":unread_count)"),
                         dict(session_id=self.session_id, entity_id=entity_id, pts=row.pts,
                              qts=row.qts, date=row.date.timestamp(), seq=row.seq,
                              unread_count=row.unread_count))

    def process_entities(self, tlo: Any) -> None:
        """Upsert every entity extracted from ``tlo`` in one executemany call."""
        rows = self._entities_to_rows(tlo)
        if not rows:
            return

        with self.engine.begin() as conn:
            conn.execute(text("INSERT OR REPLACE INTO {} ".format(self.Entity.__tablename__) +
                              "(session_id, id, hash, username, phone, name) "
                              "VALUES (:session_id, :id, :hash, :username, :phone, :name)"),
                         [dict(session_id=self.session_id, id=row[0], hash=row[1],
                               username=row[2], phone=row[3], name=row[4])
                          for row in rows])

    def cache_file(self, md5_digest: str, file_size: int,
                   instance: Union[InputDocument, InputPhoto]) -> None:
        """Remember the uploaded ``instance`` for a file with this digest and size.

        :raises TypeError: if ``instance`` is not an InputDocument or InputPhoto.
        """
        if not isinstance(instance, (InputDocument, InputPhoto)):
            raise TypeError("Cannot cache {} instance".format(type(instance)))

        # The sent_files table has six columns; the id/hash of the uploaded media
        # must be stored too, otherwise the cached entry is useless (and SQLite
        # rejects a VALUES list shorter than the column count).
        with self.engine.begin() as conn:
            conn.execute(text("INSERT OR REPLACE INTO {} ".format(self.SentFile.__tablename__) +
                              "(session_id, md5_digest, file_size, type, id, hash) "
                              "VALUES (:session_id, :md5_digest, :file_size, :type, :id, :hash)"),
                         dict(session_id=self.session_id, md5_digest=md5_digest,
                              file_size=file_size,
                              type=_SentFileType.from_type(type(instance)).value,
                              id=instance.id, hash=instance.access_hash))


# ---- alchemysession/core_postgres.py ----
from typing import Any, Union

from sqlalchemy.dialects.postgresql import insert

from telethon.sessions.memory import _SentFileType
from telethon.tl.types import InputPhoto, InputDocument

from .core import AlchemyCoreSession


class AlchemyPostgresCoreSession(AlchemyCoreSession):
    """Core-mode session that writes with PostgreSQL ``INSERT ... ON CONFLICT DO UPDATE``."""

    def set_update_state(self, entity_id: int, row: Any) -> None:
        """Insert or overwrite the stored update state for ``entity_id``."""
        t = self.UpdateState.__table__
        values = dict(pts=row.pts, qts=row.qts, date=row.date.timestamp(),
                      seq=row.seq, unread_count=row.unread_count)
        with self.engine.begin() as conn:
            conn.execute(insert(t)
                         .values(session_id=self.session_id, entity_id=entity_id, **values)
                         .on_conflict_do_update(constraint=t.primary_key, set_=values))

    def process_entities(self, tlo: Any) -> None:
        """Upsert every entity extracted from ``tlo`` in one executemany call."""
        rows = self._entities_to_rows(tlo)
        if not rows:
            return

        t = self.Entity.__table__
        ins = insert(t)
        # ``excluded`` refers to the row proposed for insertion, which lets the
        # same statement be reused for every row.
        upsert = ins.on_conflict_do_update(constraint=t.primary_key, set_={
            "hash": ins.excluded.hash,
            "username": ins.excluded.username,
            "phone": ins.excluded.phone,
            "name": ins.excluded.name,
        })
        with self.engine.begin() as conn:
            conn.execute(upsert, [dict(session_id=self.session_id, id=row[0], hash=row[1],
                                       username=row[2], phone=row[3], name=row[4])
                                  for row in rows])

    def cache_file(self, md5_digest: str, file_size: int,
                   instance: Union[InputDocument, InputPhoto]) -> None:
        """Remember the uploaded ``instance`` for a file with this digest and size.

        :raises TypeError: if ``instance`` is not an InputDocument or InputPhoto.
        """
        if not isinstance(instance, (InputDocument, InputPhoto)):
            raise TypeError("Cannot cache {} instance".format(type(instance)))

        t = self.SentFile.__table__
        values = dict(id=instance.id, hash=instance.access_hash)
        with self.engine.begin() as conn:
            conn.execute(insert(t)
                         .values(session_id=self.session_id, md5_digest=md5_digest,
                                 type=_SentFileType.from_type(type(instance)).value,
                                 file_size=file_size, **values)
                         .on_conflict_do_update(constraint=t.primary_key, set_=values))


# ---- README.rst ----
# Telethon SQLAlchemy session
# ===========================
#
# A `Telethon`_ session storage implementation backed by `SQLAlchemy`_.
#
# .. _Telethon: https://github.com/LonamiWebs/Telethon
# .. _SQLAlchemy: https://www.sqlalchemy.org/
#
# Installation
# ------------
# `telethon-session-sqlalchemy`_ @ PyPI
#
# .. code-block:: shell
#
#     pip install telethon-session-sqlalchemy
#
# .. _telethon-session-sqlalchemy: https://pypi.python.org/pypi/telethon-session-sqlalchemy
#
# Usage
# -----
# This session implementation can store multiple Sessions in the same database,
# but to do this, each session instance needs to have access to the same models
# and database session.
#
# To get started, you need to create an ``AlchemySessionContainer`` which will
# contain that shared data. The simplest way to use ``AlchemySessionContainer``
# is to pass it the database URL:
#
# .. code-block:: python
#
#     from alchemysession import AlchemySessionContainer
#     container = AlchemySessionContainer('postgresql://user:pass@localhost/telethon')
#
# If you already have SQLAlchemy set up for your own project, you can also pass
# the engine separately:
#
# .. code-block:: python
#
#     my_sqlalchemy_engine = sqlalchemy.create_engine('...')
#     container = AlchemySessionContainer(engine=my_sqlalchemy_engine)
#
# By default, the session container will manage table creation/schema updates/etc
# automatically. If you want to manage everything yourself, you can pass your
# SQLAlchemy Session and ``declarative_base`` instances and set ``manage_tables``
# to ``False``:
#
# .. code-block:: python
#
#     from sqlalchemy.ext.declarative import declarative_base
#     from sqlalchemy import orm
#     import sqlalchemy
#     ...
#     session_factory = orm.sessionmaker(bind=my_sqlalchemy_engine)
#     session = session_factory()
#     my_base = declarative_base()
#     ...
#     container = AlchemySessionContainer(
#         session=session, table_base=my_base, manage_tables=False
#     )
#
# You always need to provide either ``engine`` or ``session`` to the container.
# If you set ``manage_tables=False`` and provide a ``session``, ``engine`` is not
# needed. In every other case, ``engine`` is required.
#
# After you have your ``AlchemySessionContainer`` instance created, you can
# create new sessions by calling ``new_session``:
#
# .. code-block:: python
#
#     session = container.new_session('some session id')
#     client = TelegramClient(session, api_id, api_hash)
#
# where ``some session id`` is a unique identifier for the session.
# ---- alchemysession/orm.py ----
from typing import Optional, Tuple, Any, Union, TYPE_CHECKING
import datetime

from sqlalchemy import orm

from telethon.sessions.memory import MemorySession, _SentFileType
from telethon import utils
from telethon.crypto import AuthKey
from telethon.tl.types import InputPhoto, InputDocument, PeerUser, PeerChat, PeerChannel, updates

if TYPE_CHECKING:
    from .sqlalchemy import AlchemySessionContainer


class AlchemySession(MemorySession):
    """Telethon session persisted through the SQLAlchemy ORM.

    All sessions created by one container share its models and ORM session;
    their rows are told apart by ``session_id``.
    """

    def __init__(self, container: 'AlchemySessionContainer', session_id: str) -> None:
        super().__init__()
        self.container = container
        self.db = container.db
        self.engine = container.db_engine
        self.Version, self.Session, self.Entity, self.SentFile, self.UpdateState = (
            container.Version, container.Session, container.Entity,
            container.SentFile, container.UpdateState)
        self.session_id = session_id
        self._load_session()

    def _load_session(self) -> None:
        """Populate DC info and the auth key from the stored session row, if any."""
        session = self._db_query(self.Session).first()
        if session:
            self._dc_id = session.dc_id
            self._server_address = session.server_address
            self._port = session.port
            self._auth_key = AuthKey(data=session.auth_key)

    def clone(self, to_instance=None) -> MemorySession:
        """Copy this session's in-memory state into ``to_instance``.

        A database-backed session cannot be constructed without a container, so
        a plain ``MemorySession`` is used when the caller does not supply one.
        """
        if to_instance is None:
            to_instance = MemorySession()
        return super().clone(to_instance)

    def _get_auth_key(self) -> Optional[AuthKey]:
        """Return the stored auth key, or ``None`` if there is none (or it is empty)."""
        session = self._db_query(self.Session).first()
        if session and session.auth_key:
            return AuthKey(data=session.auth_key)
        return None

    def set_dc(self, dc_id: str, server_address: str, port: int) -> None:
        super().set_dc(dc_id, server_address, port)
        self._update_session_table()
        # Reload the auth key from the session row that was just rewritten.
        self._auth_key = self._get_auth_key()

    def get_update_state(self, entity_id: int) -> Optional[updates.State]:
        """Return the stored update state for ``entity_id``, or ``None``."""
        row = self.UpdateState.query.get((self.session_id, entity_id))
        if row:
            # The date is stored as POSIX seconds. Return an aware UTC datetime so
            # ``.timestamp()`` in set_update_state round-trips it unchanged; a naive
            # value would be interpreted as local time there.
            date = datetime.datetime.fromtimestamp(row.date, tz=datetime.timezone.utc)
            return updates.State(row.pts, row.qts, date, row.seq, row.unread_count)
        return None

    def set_update_state(self, entity_id: int, row: Any) -> None:
        """Store (insert or overwrite) the update state for ``entity_id`` and commit."""
        if row:
            self.db.merge(self.UpdateState(session_id=self.session_id, entity_id=entity_id,
                                           pts=row.pts, qts=row.qts, date=row.date.timestamp(),
                                           seq=row.seq,
                                           unread_count=row.unread_count))
            self.save()

    @MemorySession.auth_key.setter
    def auth_key(self, value: AuthKey) -> None:
        self._auth_key = value
        self._update_session_table()

    def _update_session_table(self) -> None:
        """Replace this session's row with the current DC info and auth key (not committed)."""
        self.Session.query.filter(self.Session.session_id == self.session_id).delete()
        self.db.add(self.Session(session_id=self.session_id, dc_id=self._dc_id,
                                 server_address=self._server_address, port=self._port,
                                 auth_key=(self._auth_key.key if self._auth_key else b'')))

    def _db_query(self, dbclass: Any, *args: Any) -> orm.Query:
        """Query ``dbclass`` restricted to this session plus any extra ``args`` filters."""
        return dbclass.query.filter(
            dbclass.session_id == self.session_id, *args
        )

    def save(self) -> None:
        self.container.save()

    def close(self) -> None:
        # Nothing to do here, connection is managed by AlchemySessionContainer.
        pass

    def delete(self) -> None:
        """Remove every row belonging to this session and commit the deletion."""
        self._db_query(self.Session).delete()
        self._db_query(self.Entity).delete()
        self._db_query(self.SentFile).delete()
        self._db_query(self.UpdateState).delete()
        # Commit, as the core-mode implementation does, so the deletion is not
        # left pending in the shared ORM transaction.
        self.save()

    def _entity_values_to_row(self, id: int, hash: int, username: str, phone: str, name: str
                              ) -> Any:
        return self.Entity(session_id=self.session_id, id=id, hash=hash,
                           username=username, phone=phone, name=name)

    def process_entities(self, tlo: Any) -> None:
        """Merge every entity extracted from ``tlo`` and commit."""
        rows = self._entities_to_rows(tlo)
        if not rows:
            return

        for row in rows:
            self.db.merge(row)
        self.save()

    def _get_entity_row(self, *conditions: Any) -> Optional[Tuple[int, int]]:
        """Return ``(id, hash)`` of the first matching entity, or ``None``."""
        # first() instead of one_or_none(): several cached entities can share a
        # name/phone/username (or match several peer ids), which must not raise
        # MultipleResultsFound.
        row = self._db_query(self.Entity, *conditions).first()
        return (row.id, row.hash) if row else None

    def get_entity_rows_by_phone(self, key: str) -> Optional[Tuple[int, int]]:
        return self._get_entity_row(self.Entity.phone == key)

    def get_entity_rows_by_username(self, key: str) -> Optional[Tuple[int, int]]:
        return self._get_entity_row(self.Entity.username == key)

    def get_entity_rows_by_name(self, key: str) -> Optional[Tuple[int, int]]:
        return self._get_entity_row(self.Entity.name == key)

    def get_entity_rows_by_id(self, key: int, exact: bool = True) -> Optional[Tuple[int, int]]:
        if exact:
            return self._get_entity_row(self.Entity.id == key)
        # Without an exact id, try the marked ids of a user, chat and channel.
        ids = (
            utils.get_peer_id(PeerUser(key)),
            utils.get_peer_id(PeerChat(key)),
            utils.get_peer_id(PeerChannel(key))
        )
        return self._get_entity_row(self.Entity.id.in_(ids))

    def get_file(self, md5_digest: str, file_size: int, cls: Any) -> Optional[Tuple[int, int]]:
        row = self._db_query(self.SentFile,
                             self.SentFile.md5_digest == md5_digest,
                             self.SentFile.file_size == file_size,
                             self.SentFile.type == _SentFileType.from_type(
                                 cls).value).one_or_none()
        return (row.id, row.hash) if row else None

    def cache_file(self, md5_digest: str, file_size: int,
                   instance: Union[InputDocument, InputPhoto]) -> None:
        """Remember the uploaded ``instance`` for a file with this digest and size.

        :raises TypeError: if ``instance`` is not an InputDocument or InputPhoto.
        """
        if not isinstance(instance, (InputDocument, InputPhoto)):
            raise TypeError("Cannot cache {} instance".format(type(instance)))

        self.db.merge(
            self.SentFile(session_id=self.session_id, md5_digest=md5_digest, file_size=file_size,
                          type=_SentFileType.from_type(type(instance)).value,
                          id=instance.id, hash=instance.access_hash))
        self.save()


# ---- alchemysession/core.py ----
from typing import Optional, Tuple, Any, Union
import datetime

from sqlalchemy import and_, select

from telethon.sessions.memory import _SentFileType
from telethon import utils
from telethon.crypto import AuthKey
from telethon.tl.types import InputPhoto, InputDocument, PeerUser, PeerChat, PeerChannel, updates

from .orm import AlchemySession


class AlchemyCoreSession(AlchemySession):
    """Session that talks to the tables with SQLAlchemy Core statements.

    Reads use ``ResultProxy.first()``, which fetches one row and closes the
    result, so the pooled connection is released immediately.
    """

    def _load_session(self) -> None:
        t = self.Session.__table__
        row = self.engine.execute(select([t.c.dc_id, t.c.server_address, t.c.port, t.c.auth_key])
                                  .where(t.c.session_id == self.session_id)).first()
        if row is not None:
            self._dc_id, self._server_address, self._port, auth_key = row
            self._auth_key = AuthKey(data=auth_key)

    def _get_auth_key(self) -> Optional[AuthKey]:
        t = self.Session.__table__
        row = self.engine.execute(select([t.c.auth_key])
                                  .where(t.c.session_id == self.session_id)).first()
        auth_key = row[0] if row is not None else None
        return AuthKey(data=auth_key) if auth_key else None

    def get_update_state(self, entity_id: int) -> Optional[updates.State]:
        """Return the stored update state for ``entity_id``, or ``None``."""
        t = self.UpdateState.__table__
        # Select columns explicitly instead of unpacking ``select([t])`` by position.
        row = self.engine.execute(select([t.c.pts, t.c.qts, t.c.date, t.c.seq, t.c.unread_count])
                                  .where(and_(t.c.session_id == self.session_id,
                                              t.c.entity_id == entity_id))).first()
        if row is None:
            return None
        pts, qts, date, seq, unread_count = row
        # Aware UTC so ``.timestamp()`` in set_update_state round-trips the value.
        date = datetime.datetime.fromtimestamp(date, tz=datetime.timezone.utc)
        return updates.State(pts, qts, date, seq, unread_count)

    def set_update_state(self, entity_id: int, row: Any) -> None:
        """Replace the stored update state for ``entity_id``."""
        t = self.UpdateState.__table__
        # Delete + insert in one transaction, so a failure between the two
        # statements cannot lose the stored state.
        with self.engine.begin() as conn:
            conn.execute(t.delete().where(and_(t.c.session_id == self.session_id,
                                               t.c.entity_id == entity_id)))
            conn.execute(t.insert()
                         .values(session_id=self.session_id, entity_id=entity_id, pts=row.pts,
                                 qts=row.qts, date=row.date.timestamp(), seq=row.seq,
                                 unread_count=row.unread_count))

    def _update_session_table(self) -> None:
        t = self.Session.__table__
        with self.engine.begin() as conn:
            conn.execute(t.delete().where(t.c.session_id == self.session_id))
            conn.execute(t.insert().values(
                session_id=self.session_id, dc_id=self._dc_id,
                server_address=self._server_address, port=self._port,
                auth_key=(self._auth_key.key if self._auth_key else b'')))

    def save(self) -> None:
        # engine.execute() autocommits
        pass

    def delete(self) -> None:
        """Remove every row belonging to this session in one transaction."""
        with self.engine.begin() as conn:
            for table in (self.Session.__table__, self.Entity.__table__,
                          self.SentFile.__table__, self.UpdateState.__table__):
                conn.execute(table.delete().where(table.c.session_id == self.session_id))

    def _entity_values_to_row(self, id: int, hash: int, username: str, phone: str, name: str
                              ) -> Any:
        return id, hash, username, phone, name

    def process_entities(self, tlo: Any) -> None:
        """Replace the stored rows of every entity extracted from ``tlo``."""
        rows = self._entities_to_rows(tlo)
        if not rows:
            return

        t = self.Entity.__table__
        with self.engine.begin() as conn:
            conn.execute(t.delete().where(and_(t.c.session_id == self.session_id,
                                               t.c.id.in_([row[0] for row in rows]))))
            conn.execute(t.insert(), [dict(session_id=self.session_id, id=row[0], hash=row[1],
                                           username=row[2], phone=row[3], name=row[4])
                                      for row in rows])

    def get_entity_rows_by_phone(self, key: str) -> Optional[Tuple[int, int]]:
        return self._get_entity_rows_by_condition(self.Entity.__table__.c.phone == key)

    def get_entity_rows_by_username(self, key: str) -> Optional[Tuple[int, int]]:
        return self._get_entity_rows_by_condition(self.Entity.__table__.c.username == key)

    def get_entity_rows_by_name(self, key: str) -> Optional[Tuple[int, int]]:
        return self._get_entity_rows_by_condition(self.Entity.__table__.c.name == key)

    def _get_entity_rows_by_condition(self, condition: Any) -> Optional[Tuple[int, int]]:
        """Return ``(id, hash)`` of the first entity of this session matching ``condition``."""
        t = self.Entity.__table__
        return self.engine.execute(select([t.c.id, t.c.hash])
                                   .where(and_(t.c.session_id == self.session_id, condition))
                                   ).first()

    def get_entity_rows_by_id(self, key: int, exact: bool = True) -> Optional[Tuple[int, int]]:
        t = self.Entity.__table__
        if exact:
            return self._get_entity_rows_by_condition(t.c.id == key)
        ids = (
            utils.get_peer_id(PeerUser(key)),
            utils.get_peer_id(PeerChat(key)),
            utils.get_peer_id(PeerChannel(key))
        )
        return self._get_entity_rows_by_condition(t.c.id.in_(ids))

    def get_file(self, md5_digest: str, file_size: int, cls: Any) -> Optional[Tuple[int, int]]:
        t = self.SentFile.__table__
        return self.engine.execute(select([t.c.id, t.c.hash])
                                   .where(and_(t.c.session_id == self.session_id,
                                               t.c.md5_digest == md5_digest,
                                               t.c.file_size == file_size,
                                               t.c.type == _SentFileType.from_type(cls).value))
                                   ).first()

    def cache_file(self, md5_digest: str, file_size: int,
                   instance: Union[InputDocument, InputPhoto]) -> None:
        """Remember the uploaded ``instance`` for a file with this digest and size.

        :raises TypeError: if ``instance`` is not an InputDocument or InputPhoto.
        """
        if not isinstance(instance, (InputDocument, InputPhoto)):
            raise TypeError("Cannot cache {} instance".format(type(instance)))

        t = self.SentFile.__table__
        file_type = _SentFileType.from_type(type(instance)).value
        with self.engine.begin() as conn:
            # ``where()`` takes SQL expressions, not keyword arguments; match the
            # full primary key of the row being replaced.
            conn.execute(t.delete().where(and_(t.c.session_id == self.session_id,
                                               t.c.md5_digest == md5_digest,
                                               t.c.type == file_type,
                                               t.c.file_size == file_size)))
            conn.execute(t.insert().values(session_id=self.session_id, md5_digest=md5_digest,
                                           type=file_type, file_size=file_size, id=instance.id,
                                           hash=instance.access_hash))


# ---- alchemysession/sqlalchemy.py ----
from typing import Optional, Tuple, Any, Union, List

from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm.scoping import scoped_session
from sqlalchemy import Column, String, Integer, BigInteger, LargeBinary, orm, func, select, and_
import sqlalchemy as sql

from .orm import AlchemySession
from .core import AlchemyCoreSession
from .core_mysql import AlchemyMySQLCoreSession
from .core_sqlite import AlchemySQLiteCoreSession
from .core_postgres import AlchemyPostgresCoreSession

# Schema versions:
#   1 - no update_state table
#   2 - update_state table, which older releases created without unread_count
#   3 - update_state table guaranteed to contain unread_count
LATEST_VERSION = 3


class AlchemySessionContainer:
    """Holds the models and database handles shared by many Telethon sessions.

    :param engine: SQLAlchemy engine, or a database URL to create one from.
    :param session: ORM session to use; ``None`` creates a scoped session bound to
                    ``engine``, any other falsy value disables the ORM (core mode).
    :param table_prefix: Prefix prepended to every table name.
    :param table_base: Declarative base to attach the models to.
    :param manage_tables: Create/upgrade the tables automatically (needs an ORM session).
    """

    def __init__(self, engine: Optional[Union[sql.engine.Engine, str]] = None,
                 session: Optional[Union[orm.Session, scoped_session, bool]] = None,
                 table_prefix: str = "", table_base: Optional[declarative_base] = None,
                 manage_tables: bool = True) -> None:
        if isinstance(engine, str):
            engine = sql.create_engine(engine)

        self.db_engine = engine
        if session is None:
            db_factory = orm.sessionmaker(bind=self.db_engine)
            self.db = orm.scoping.scoped_session(db_factory)
        elif not session:
            self.db = None
        else:
            self.db = session

        table_base = table_base or declarative_base()
        (self.Version, self.Session, self.Entity,
         self.SentFile, self.UpdateState) = self.create_table_classes(self.db, table_prefix,
                                                                      table_base)
        self.alchemy_session_class = AlchemySession
        if not self.db:
            # Implicit core mode if there's no ORM session.
            self.core_mode = True

        if manage_tables:
            if not self.db:
                raise ValueError("Can't manage tables without an ORM session.")
            table_base.metadata.bind = self.db_engine
            if not self._table_exists(self.Version.__tablename__):
                table_base.metadata.create_all(bind=self.db_engine)
                self.db.add(self.Version(version=LATEST_VERSION))
                self.db.commit()
            else:
                self.check_and_upgrade_database()

    @property
    def core_mode(self) -> bool:
        return self.alchemy_session_class != AlchemySession

    @core_mode.setter
    def core_mode(self, val: bool) -> None:
        if val:
            if self.db_engine.dialect.name == "mysql":
                self.alchemy_session_class = AlchemyMySQLCoreSession
            elif self.db_engine.dialect.name == "postgresql":
                self.alchemy_session_class = AlchemyPostgresCoreSession
            elif self.db_engine.dialect.name == "sqlite":
                self.alchemy_session_class = AlchemySQLiteCoreSession
            else:
                self.alchemy_session_class = AlchemyCoreSession
        else:
            if not self.db:
                raise ValueError("Can't use ORM mode without an ORM session.")
            self.alchemy_session_class = AlchemySession

    @staticmethod
    def create_table_classes(db: scoped_session, prefix: str, base: declarative_base
                             ) -> Tuple[Any, Any, Any, Any, Any]:
        """Define the five models on ``base``, with table names prefixed by ``prefix``."""
        qp = db.query_property() if db else None

        class Version(base):
            query = qp
            __tablename__ = "{prefix}version".format(prefix=prefix)
            version = Column(Integer, primary_key=True)

            def __str__(self):
                return "Version('{}')".format(self.version)

        class Session(base):
            query = qp
            __tablename__ = '{prefix}sessions'.format(prefix=prefix)

            session_id = Column(String(255), primary_key=True)
            dc_id = Column(Integer, primary_key=True)
            server_address = Column(String(255))
            port = Column(Integer)
            auth_key = Column(LargeBinary)

            def __str__(self):
                return "Session('{}', {}, '{}', {}, {})".format(self.session_id, self.dc_id,
                                                                self.server_address, self.port,
                                                                self.auth_key)

        class Entity(base):
            query = qp
            __tablename__ = '{prefix}entities'.format(prefix=prefix)

            session_id = Column(String(255), primary_key=True)
            id = Column(BigInteger, primary_key=True)
            hash = Column(BigInteger, nullable=False)
            username = Column(String(32))
            phone = Column(BigInteger)
            name = Column(String(255))

            def __str__(self):
                return "Entity('{}', {}, {}, '{}', '{}', '{}')".format(self.session_id, self.id,
                                                                       self.hash, self.username,
                                                                       self.phone, self.name)

        class SentFile(base):
            query = qp
            __tablename__ = '{prefix}sent_files'.format(prefix=prefix)

            session_id = Column(String(255), primary_key=True)
            md5_digest = Column(LargeBinary, primary_key=True)
            file_size = Column(Integer, primary_key=True)
            type = Column(Integer, primary_key=True)
            id = Column(BigInteger)
            hash = Column(BigInteger)

            def __str__(self):
                return "SentFile('{}', {}, {}, {}, {}, {})".format(self.session_id,
                                                                   self.md5_digest, self.file_size,
                                                                   self.type, self.id, self.hash)

        class UpdateState(base):
            query = qp
            __tablename__ = "{prefix}update_state".format(prefix=prefix)

            session_id = Column(String(255), primary_key=True)
            entity_id = Column(BigInteger, primary_key=True)
            pts = Column(BigInteger)
            qts = Column(BigInteger)
            date = Column(BigInteger)
            seq = Column(BigInteger)
            unread_count = Column(Integer)

        return Version, Session, Entity, SentFile, UpdateState

    def _table_exists(self, table_name: str) -> bool:
        """Return whether ``table_name`` exists in the database."""
        # Dialect.has_table() expects a Connection; SQLAlchemy 1.4 rejects an Engine.
        with self.db_engine.connect() as conn:
            return self.db_engine.dialect.has_table(conn, table_name)

    def _has_column(self, table: Any, column_name: str) -> bool:
        """Return whether the live database table of ``table`` has ``column_name``."""
        columns = sql.inspect(self.db_engine).get_columns(table.__tablename__)
        return any(column["name"] == column_name for column in columns)

    def _add_column(self, table: Any, column: Column) -> None:
        column_name = column.compile(dialect=self.db_engine.dialect)
        column_type = column.type.compile(self.db_engine.dialect)
        self.db_engine.execute("ALTER TABLE {} ADD COLUMN {} {}".format(
            table.__tablename__, column_name, column_type))

    def check_and_upgrade_database(self) -> None:
        """Bring an existing database schema up to ``LATEST_VERSION``.

        Schema changes are made before the version row is modified through the
        ORM session, so the DDL is not issued while that session holds
        uncommitted changes.
        """
        row = self.Version.query.all()
        version = row[0].version if row else 1
        if version == LATEST_VERSION:
            return

        if version == 1:
            # update_state did not exist yet; creating it from the current model
            # already includes unread_count.
            self.UpdateState.__table__.create(self.db_engine, checkfirst=True)
            version = LATEST_VERSION
        elif version == 2:
            # Version 2 databases were created both with and without unread_count,
            # so only add the column when it is actually missing.
            if not self._has_column(self.UpdateState, "unread_count"):
                self._add_column(self.UpdateState, Column("unread_count", Integer))
            version = LATEST_VERSION

        self.Version.query.delete()
        self.db.add(self.Version(version=version))
        self.db.commit()

    def new_session(self, session_id: str) -> 'AlchemySession':
        return self.alchemy_session_class(self, session_id)

    def has_session(self, session_id: str) -> bool:
        """Return whether ``session_id`` has a stored, non-empty auth key.

        Both modes apply the same rule: a row written without an auth key
        (stored as ``b''``) does not count.
        """
        if self.core_mode:
            t = self.Session.__table__
            count = self.db_engine.execute(select([func.count(t.c.auth_key)])
                                           .where(and_(t.c.session_id == session_id,
                                                       t.c.auth_key != b''))).scalar()
            return (count or 0) > 0
        return self.Session.query.filter(self.Session.session_id == session_id,
                                         self.Session.auth_key != b'').count() > 0

    def list_sessions(self) -> List[str]:
        """Return the distinct session IDs present in the sessions table."""
        if self.core_mode:
            t = self.Session.__table__
            rows = self.db_engine.execute(select([t.c.session_id]).distinct())
            return [row[0] for row in rows]
        return [row[0] for row in self.db.query(self.Session.session_id).distinct()]

    def save(self) -> None:
        if self.db:
            self.db.commit()