├── .gitignore ├── sqlmapper ├── __init__.py ├── base_engine.py ├── connection.py ├── utils.py ├── aio │ ├── __init__.py │ └── amysql.py ├── sqlite.py ├── mysql.py ├── psql.py └── table.py ├── examples ├── tablelist.py ├── create_index.py ├── oncommit.py ├── context.py ├── aio_mysql0.py └── example0.py ├── setup.py ├── tests ├── test_threading.py └── test_main.py └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | info 3 | __pycache__ 4 | *.pyc 5 | MANIFEST 6 | Makefile 7 | dist/ 8 | test.py 9 | .vscode/ 10 | .pytest_cache 11 | -------------------------------------------------------------------------------- /sqlmapper/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from .connection import Connection 4 | 5 | 6 | __version__ = '0.3.5' 7 | -------------------------------------------------------------------------------- /examples/tablelist.py: -------------------------------------------------------------------------------- 1 | 2 | from sqlmapper import Connection 3 | 4 | 5 | def main(): 6 | db = Connection(host='127.0.0.1', user='root', db='example', autocreate=True) 7 | 8 | for table in db: 9 | print(table) 10 | 11 | 12 | if __name__ == '__main__': 13 | main() 14 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from distutils.core import setup 4 | 5 | setup( 6 | name='sqlmapper', 7 | version='0.3.5', 8 | description='SQL Mapper', 9 | author='Oleg Nechaev', 10 | author_email='lega911@gmail.com', 11 | url='https://github.com/lega911/sqlmapper', 12 | packages=['sqlmapper', 'sqlmapper.aio'], 13 | license='MIT' 14 | ) 15 | -------------------------------------------------------------------------------- /examples/create_index.py: -------------------------------------------------------------------------------- 1 | 2 | from sqlmapper import Connection 3 | 4 | 5 | def main(): 6 | db = Connection(host='127.0.0.1', user='root', db='example', autocreate=True) 7 | 8 | db.book.add_column('id', 'int', primary=True, auto_increment=True, exist_ok=True) 9 | db.book.add_column('name', 'varchar(32)', exist_ok=True) 10 | 11 | db.book.create_index('nameindex', 'name', unique=True, exist_ok=True) 12 | 13 | db.book.insert({'name': 'ubuntu'}) 14 | db.commit() 15 | 16 | 17 | if __name__ == '__main__': 18 | main() 19 | -------------------------------------------------------------------------------- /examples/oncommit.py: -------------------------------------------------------------------------------- 1 | 2 | from sqlmapper import Connection 3 | 4 | 5 | def main(): 6 | db = Connection(host='127.0.0.1', user='root', db='example', autocreate=True, read_commited=True) 7 | 8 | @db.on_commit 9 | def after_commit(): 10 | print('commited') 11 | 12 | @db.on_rollback 13 | def after_rollback(): 14 | print('rollback') 15 | 16 | db.book.insert({'name': 'ubuntu', 'value': 3}) 17 | 18 | print('start commit') 19 | db.commit() 20 | 21 | 22 | if __name__ == '__main__': 23 | main() 24 | -------------------------------------------------------------------------------- /examples/context.py: -------------------------------------------------------------------------------- 1 | 2 | from sqlmapper import Connection 3 | 4 | 5 | def main(): 6 | db = Connection(host='127.0.0.1', user='root', db='example', autocreate=True, 
read_commited=True) 7 | 8 | def add_row(name): 9 | with db: 10 | print('insert', name) 11 | db.book.insert({'name': name}) 12 | 13 | @db.on_commit 14 | def msg(): 15 | print('commit', name) 16 | 17 | add_row('RedHat') 18 | print() 19 | 20 | with db: 21 | add_row('Linux') 22 | add_row('Ubuntu') 23 | add_row('Debian') 24 | 25 | print('* group commit') 26 | 27 | 28 | if __name__ == '__main__': 29 | main() 30 | -------------------------------------------------------------------------------- /examples/aio_mysql0.py: -------------------------------------------------------------------------------- 1 | 2 | import asyncio 3 | from sqlmapper.aio import Connection 4 | 5 | 6 | async def main(): 7 | db = await Connection(host='127.0.0.1', user='root', db='example', autocreate=True, read_commited=True) 8 | 9 | await db.book.add_column('id', 'int', primary=True, auto_increment=True, exist_ok=True) 10 | await db.book.add_column('name', 'text', exist_ok=True) 11 | await db.book.add_column('value', 'int', exist_ok=True) 12 | 13 | await db.book.insert({'name': 'ubuntu', 'value': 16}) 14 | await db.commit() 15 | d = await db.book.find_one(1) 16 | print(d) 17 | 18 | await db.book.update(1, {'value': 18}) 19 | await db.commit() 20 | 21 | for d in await db.book.find({'name': 'ubuntu'}): 22 | print(d) 23 | 24 | await db.book.delete({'value': 18}) 25 | await db.commit() 26 | 27 | print(await db.book.count()) 28 | 29 | async for name in db: 30 | print('table', name) 31 | 32 | 33 | loop = asyncio.get_event_loop() 34 | loop.run_until_complete(main()) 35 | -------------------------------------------------------------------------------- /examples/example0.py: -------------------------------------------------------------------------------- 1 | 2 | from sqlmapper import Connection 3 | 4 | 5 | def run_mysql(): 6 | run(Connection(host='127.0.0.1', user='root', db='example', autocreate=True, read_commited=True)) 7 | 8 | 9 | def run_psql(): 10 | run(Connection(engine='postgresql', host='127.0.0.1', user='postgres', password='secret', db='example', autocreate=True)) 11 | 12 | 13 | def run_sqlite(): 14 | run(Connection(engine='sqlite')) 15 | 16 | 17 | def run(db): 18 | db.book.add_column('id', 'int', primary=True, auto_increment=True, exist_ok=True) 19 | db.book.add_column('name', 'text', exist_ok=True) 20 | db.book.add_column('value', 'int', exist_ok=True) 21 | 22 | db.book.insert({'name': 'ubuntu', 'value': 16}) 23 | db.commit() 24 | d = db.book.find_one(1) 25 | print(d) 26 | 27 | db.book.update(1, {'value': 18}) 28 | db.commit() 29 | 30 | for d in db.book.find({'name': 'ubuntu'}): 31 | print(d) 32 | 33 | db.book.delete({'value': 18}) 34 | db.commit() 35 | 36 | print(db.book.count()) 37 | 38 | 39 | if __name__ == '__main__': 40 | run_psql() 41 | run_mysql() 42 | run_sqlite() 43 | -------------------------------------------------------------------------------- /sqlmapper/base_engine.py: -------------------------------------------------------------------------------- 1 | 2 | class MultiException(Exception): 3 | def __init__(self, e): 4 | super(MultiException, self).__init__() 5 | self.exceptions = e 6 | 7 | 8 | class BaseEngine(object): 9 | def __init__(self): 10 | if not hasattr(self, 'local'): 11 | self.local = type('local', (object,), {})() 12 | self.thread_init() 13 | 14 | def thread_init(self): 15 | if hasattr(self.local, 'tables'): 16 | return 17 | self.local.tables = {} 18 | self.local.commit = [] 19 | self.local.rollback = [] 20 | 21 | def fire_event(self, success): 22 | self.thread_init() 23 | fnlist = 
self.local.commit if success else self.local.rollback 24 | self.local.commit = [] 25 | self.local.rollback = [] 26 | exceptions = [] 27 | for fn in fnlist: 28 | try: 29 | fn() 30 | except Exception as e: 31 | exceptions.append(e) 32 | if exceptions: 33 | if len(exceptions) == 1: 34 | raise exceptions[0] 35 | else: 36 | raise MultiException(exceptions) 37 | 38 | def on_commit(self, fn): 39 | self.thread_init() 40 | self.local.commit.append(fn) 41 | 42 | def on_rollback(self, fn): 43 | self.thread_init() 44 | self.local.rollback.append(fn) 45 | -------------------------------------------------------------------------------- /sqlmapper/connection.py: -------------------------------------------------------------------------------- 1 | 2 | class Connection(object): 3 | def __init__(self, **kw): 4 | engine = kw.pop('engine', None) or 'mysql' 5 | 6 | if engine == 'mysql': 7 | from .mysql import Engine as engine 8 | elif engine == 'sqlite': 9 | from .sqlite import Engine as engine 10 | elif engine == 'postgresql': 11 | from .psql import Engine as engine 12 | elif not callable(engine): 13 | raise NotImplementedError 14 | self._engine = engine(**kw) 15 | self._engine.local.contextlvl = 0 16 | 17 | def commit(self): 18 | self._engine.commit() 19 | 20 | def rollback(self): 21 | self._engine.rollback() 22 | 23 | def close(self): 24 | self._engine.close() 25 | 26 | def __getitem__(self, name): 27 | return self._engine.get_table(name) 28 | 29 | def __getattr__(self, name): 30 | return self[name] 31 | 32 | def __iter__(self): 33 | return self._engine.get_tables() 34 | 35 | def on_commit(self, fn): 36 | self._engine.on_commit(fn) 37 | 38 | def on_rollback(self, fn): 39 | self._engine.on_rollback(fn) 40 | 41 | def __enter__(self): 42 | self._engine.local.contextlvl += 1 43 | 44 | def __exit__(self, exc_type, exc_value, traceback): 45 | self._engine.local.contextlvl -= 1 46 | if not self._engine.local.contextlvl: 47 | if exc_type: 48 | self.rollback() 49 | else: 50 | self.commit() 51 | if exc_type: 52 | return False 53 | 54 | @property 55 | def cursor(self): 56 | return self._engine.get_cursor() 57 | -------------------------------------------------------------------------------- /sqlmapper/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | import re 4 | 5 | 6 | PY3 = sys.version_info.major == 3 7 | NoValue = object() 8 | 9 | if PY3: 10 | def is_int(value): 11 | return isinstance(value, int) 12 | 13 | def is_str(value): 14 | return isinstance(value, str) 15 | 16 | def is_bytes(value): 17 | return isinstance(value, bytes) 18 | 19 | else: 20 | def is_int(value): 21 | return isinstance(value, (int, long)) 22 | 23 | def is_str(value): 24 | return isinstance(value, unicode) 25 | 26 | def is_bytes(value): 27 | return isinstance(value, str) 28 | 29 | 30 | def validate_name(name): 31 | assert name 32 | assert is_str(name) or is_bytes(name), 'Wrong type' 33 | assert re.match(r'^[\w\d_]+$', name), 'Wrong name value: `{}`'.format(name) 34 | 35 | 36 | def quote_key(name, q='`'): 37 | if '.' 
in name: 38 | name = name.split('.') 39 | else: 40 | name = [name] 41 | 42 | result = [] 43 | for n in name: 44 | validate_name(n) 45 | result.append(q + n + q) 46 | return '.'.join(result) 47 | 48 | 49 | def format_func(name, q='`'): 50 | if '(' not in name: 51 | return quote_key(name, q) 52 | 53 | r = re.match(r'^(\w+)\(([^\)]+)\)\s+as\s+(\w+)$', name) 54 | if not r: 55 | r = re.match(r'^(\w+)\(([^\)]+)\)$', name) 56 | if not r: 57 | raise ValueError('Error column name: "%s"' % name) 58 | 59 | rx = r.groups(0) 60 | func = rx[0] 61 | name = rx[1] 62 | if len(rx) == 3: 63 | key = rx[2] 64 | else: 65 | key = func + '_' + name 66 | 67 | return '{}({}) as {}'.format(func, quote_key(name, q), quote_key(key, q).lower()) 68 | -------------------------------------------------------------------------------- /tests/test_threading.py: -------------------------------------------------------------------------------- 1 | 2 | import time 3 | import threading 4 | from sqlmapper import Connection 5 | 6 | 7 | def get_db(): 8 | db = Connection(host='127.0.0.1', db='unittest', user='root', autocreate=True, read_commited=True) 9 | db.book.drop() 10 | 11 | db.book.add_column('id', 'int', primary=True, auto_increment=True, exist_ok=True) 12 | db.book.add_column('value', 'int', exist_ok=True) 13 | 14 | db.book.insert({'value': 5}) 15 | db.commit() 16 | return db 17 | 18 | 19 | def test_threading0(): 20 | result = [] 21 | def add(*a): 22 | result.append('.'.join(map(str, a))) 23 | 24 | db = get_db() 25 | db.book.update(1, {'value': 20}) 26 | 27 | def run(): 28 | d = db.book.find_one(1) 29 | add('t', d['value']) 30 | 31 | d = db.book.find_one(1, for_update=True) 32 | add('t', d['value']) 33 | db.book.update(1, {'value': d['value'] + 10}) 34 | time.sleep(0.2) 35 | add('t.commit') 36 | db.commit() 37 | 38 | t = threading.Thread(target=run) 39 | t.start() 40 | time.sleep(0.5) 41 | add('m.commit') 42 | db.commit() 43 | 44 | time.sleep(0.1) 45 | d = db.book.find_one(1) 46 | add('m', d['value']) 47 | 48 | time.sleep(0.2) 49 | d = db.book.find_one(1) 50 | add('m', d['value']) 51 | 52 | t.join() 53 | assert result == ['t.5', 'm.commit', 't.20', 'm.20', 't.commit', 'm.30'] 54 | 55 | 56 | def test_threading1(): 57 | result = [] 58 | def add(*a): 59 | result.append('.'.join(map(str, a))) 60 | 61 | db = get_db() 62 | 63 | def run(n): 64 | for i in range(5): 65 | d = db.book.find_one(1, for_update=True) 66 | add(d['value']) 67 | db.book.update(1, {'value': d['value'] + 1}) 68 | db.commit() 69 | 70 | threads = [] 71 | for i in range(5): 72 | t = threading.Thread(target=run, args=(i,)) 73 | t.start() 74 | threads.append(t) 75 | for t in threads: 76 | t.join() 77 | 78 | assert db.book.find_one(1)['value'] == 30 79 | assert result == ['5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29'] 80 | -------------------------------------------------------------------------------- /sqlmapper/aio/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | import asyncio 3 | 4 | async def Connection(**kw): 5 | engine = kw.pop('engine', None) or 'mysql' 6 | loop = kw.pop('loop', None) or asyncio.get_event_loop() 7 | 8 | if engine == 'mysql': 9 | from .amysql import Engine 10 | engine = Engine() 11 | await engine.init(loop=loop, **kw) 12 | return AsyncConnection(engine) 13 | else: 14 | raise NotImplementedError() 15 | 16 | 17 | class AsyncConnection: 18 | def __init__(self, engine): 19 | self._engine = engine 20 
| 21 | async def commit(self): 22 | await self._engine.commit() 23 | 24 | async def rollback(self): 25 | await self._engine.rollback() 26 | 27 | def close(self): 28 | self._engine.close() 29 | 30 | def __getitem__(self, name): 31 | return self._engine.get_table(name) 32 | 33 | def __getattr__(self, name): 34 | return self[name] 35 | 36 | def __aiter__(self): 37 | return DBList(self._engine) 38 | 39 | def on_commit(self, fn): 40 | self._engine.on_commit(fn) 41 | 42 | def on_rollback(self, fn): 43 | self._engine.on_rollback(fn) 44 | 45 | @property 46 | def cursor(self): 47 | return self._engine.cursor 48 | 49 | ''' 50 | def __enter__(self): 51 | self._engine.local.contextlvl += 1 52 | 53 | def __exit__(self, exc_type, exc_value, traceback): 54 | self._engine.local.contextlvl -= 1 55 | if not self._engine.local.contextlvl: 56 | if exc_type: 57 | self.rollback() 58 | else: 59 | self.commit() 60 | if exc_type: 61 | return False 62 | ''' 63 | 64 | class DBList: 65 | def __init__(self, engine): 66 | self.engine = engine 67 | self.result = None 68 | self.index = 0 69 | 70 | def __aiter__(self): 71 | return self 72 | 73 | async def __anext__(self): 74 | if self.result is None: 75 | self.result = await self.engine.get_tables() 76 | 77 | if self.index >= len(self.result): 78 | raise StopAsyncIteration 79 | 80 | value = self.result[self.index] 81 | self.index += 1 82 | return value 83 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sqlmapper 2 | Easy wrapper for SQL 3 | 4 | * Supports Python 2.x, 3.x, MySQL, PostgreSQL, SQLite, asyncio + mysql 5 | * Thread-safe (you can use the same connection from different threads) 6 | * License [MIT](http://opensource.org/licenses/MIT) 7 | 8 | ### Install and update using pip 9 | 10 | ```bash 11 | pip install -U sqlmapper 12 | ``` 13 | 14 | ### Examples 15 | ```python 16 | db = Connection(db='example') 17 | 18 | db.tblname.insert({'name': 'Ubuntu', 'value': 14}) 19 | # INSERT INTO `tblname` (`name`, `value`) VALUES ('Ubuntu', 14) 20 | 21 | db.tblname.insert({'name': 'MacOS', 'value': 10}) 22 | # INSERT INTO `tblname` (`name`, `value`) VALUES ('MacOS', 10) 23 | 24 | for d in db.tblname.find({'name': 'Ubuntu'}): 25 | # SELECT tblname.* FROM `tblname` WHERE `tblname`.`name`='Ubuntu' 26 | print(d) 27 | 28 | db.tblname.update({'name': 'Ubuntu'}, {'value': 16}) 29 | # UPDATE `tblname` SET `value` = 16 WHERE `tblname`.`name`='Ubuntu' 30 | 31 | db.tblname.find_one({'Name': 'Ubuntu'}) 32 | # SELECT tblname.* FROM `tblname` WHERE `name` = 'Ubuntu' LIMIT 1 33 | 34 | db.tblname.find_one(2) 35 | # find by primary key 36 | # SELECT tblname.* FROM `tblname` WHERE `id` = 2 LIMIT 1 37 | 38 | db.tblname.delete({'name': 'MacOS'}) 39 | # DELETE FROM `tblname` WHERE `tblname`.`name`='MacOS' 40 | 41 | db.commit() 42 | ``` 43 | 44 | ### asyncio 45 | ```python 46 | from sqlmapper.aio import Connection 47 | 48 | db = await Connection(db='example') 49 | 50 | await db.book.add_column('value', 'int', exist_ok=True) 51 | await db.book.insert({'name': 'ubuntu', 'value': 16}) 52 | await db.commit() 53 | d = await db.book.find_one(1) 54 | print(d) 55 | 56 | await db.book.update(1, {'value': 18}) 57 | print(await db.book.count()) 58 | ``` 59 | 60 | ### Change schema 61 | ```python 62 | 63 | # a table is created for first column 64 | db.tblname.add_column('id', 'INT(11)', primary=True, auto_increment=True, exist_ok=True) 65 | # CREATE TABLE `tblname` (`id` INT(11) 
NOT NULL AUTO_INCREMENT, PRIMARY KEY (`id`)) ENGINE=InnoDB DEFAULT CHARSET utf8mb4 COLLATE utf8mb4_unicode_ci 66 | 67 | db.tblname.add_column('name', 'VARCHAR(32)', exist_ok=True) 68 | # ALTER TABLE `tblname` ADD COLUMN `name` VARCHAR(32) 69 | 70 | db.tblname.add_column('value', 'INT(11)', exist_ok=True) 71 | # ALTER TABLE `tblname` ADD COLUMN `value` INT(11) 72 | 73 | db.tblname.create_index('name_idx', ['name'], exist_ok=True) 74 | # ALTER TABLE `tblname` ADD INDEX `name_idx`(`name`) 75 | ``` 76 | 77 | ### Join 78 | ```python 79 | for d in db.parent.find({'name': 'Linux'}, join='child.id=child_id'): 80 | # SELECT parent.*, "" as __divider, child.* FROM `parent` JOIN child AS child ON child.id = child_id WHERE `parent`.`name`='Linux' 81 | print(d) 82 | # d == { 83 | # 'name': 'Linux', 84 | # 'child_id': 5, 85 | # 'child': { 86 | # 'id': 5, 87 | # 'name': 'Ubuntu' 88 | # } 89 | # } 90 | ``` 91 | 92 | ### Group by 93 | ```python 94 | for d in db.tblname.find(group_by='name', columns=['name', 'SUM(value)']): 95 | # SELECT `name`, SUM(`value`) as `sum_value` FROM `tblname` GROUP BY `name` 96 | print(d) # {'name': u'Ubuntu', 'sum_value': 32} 97 | ``` 98 | 99 | ### List of tables 100 | ```python 101 | for table_name in db: 102 | print(table_name) 103 | ``` 104 | -------------------------------------------------------------------------------- /tests/test_main.py: -------------------------------------------------------------------------------- 1 | 2 | import pytest 3 | from sqlmapper import Connection 4 | 5 | 6 | def test_mysql(): 7 | main(Connection(host='127.0.0.1', db='unittest', user='root', autocreate=True, read_commited=True)) 8 | 9 | 10 | def test_psql(): 11 | main(Connection(engine='postgresql', host='127.0.0.1', db='unittest', user='postgres', password='secret', autocreate=True)) 12 | 13 | 14 | def test_sqlite(): 15 | main(Connection(engine='sqlite')) 16 | 17 | 18 | def main(db): 19 | db.book.drop() 20 | db.ref.drop() 21 | 22 | db.book.add_column('id', 'int', primary=True, auto_increment=True, exist_ok=True) 23 | db.book.add_column('name', 'text', exist_ok=True) 24 | db.book.add_column('value', 'int', exist_ok=True) 25 | assert db['book'].count() == 0 26 | 27 | assert len(db.book.describe()) == 3 28 | assert db['book'].get_column('value')['name'] == 'value' 29 | assert 'book' in db 30 | 31 | db.book.insert({'name': 'ubuntu', 'value': 16}) 32 | db.book.insert({'name': 'mint', 'value': 18}) 33 | db.book.insert({'name': 'debian', 'value': 9}) 34 | db.book.insert({'name': 'debian', 'value': 8}) 35 | db.book.insert({'name': 'redhat', 'value': 0}) 36 | db.book.insert({'name': 'macos', 'value': 10}) 37 | db.book.insert({'name': 'ubuntu', 'value': 18}) 38 | db.book.insert({'name': 'ubuntu', 'value': 14}) 39 | db.commit() 40 | 41 | assert db.book.count() == 8 42 | assert db.book.count(('value > %s', 10)) == 4 43 | 44 | class dd: 45 | status = 0 46 | 47 | @db.on_commit 48 | def on_commit(): 49 | dd.status = 1 50 | 51 | @db.on_rollback 52 | def on_rollback(): 53 | dd.status = 2 54 | 55 | db.book.update({'name': 'redhat'}, {'value': 5}) 56 | db.commit() 57 | 58 | assert dd.status == 1 59 | 60 | @db.on_commit 61 | def on_commit(): 62 | dd.status = 3 63 | 64 | @db.on_rollback 65 | def on_rollback(): 66 | dd.status = 4 67 | 68 | db.book.update({'name': 'redhat'}, {'value': 25}) 69 | db.rollback() 70 | assert dd.status == 4 71 | 72 | for d in db.book.find({'name': 'redhat'}): 73 | dd.status += 1 74 | assert d['value'] == 5 75 | 76 | assert dd.status == 5 77 | 78 | assert db.book.find_one(3)['value'] == 
9 79 | 80 | db.book.delete({'name': 'macos'}) 81 | assert db.book.count() == 7 82 | db.commit() 83 | 84 | r = list(db.book.find(group_by='name', columns=['name', 'COUNT(value)'], order_by='-count_value')) 85 | assert len(r) == 4 86 | assert r[0]['name'] == 'ubuntu' 87 | assert r[0]['count_value'] == 3 88 | assert r[1]['name'] == 'debian' 89 | assert r[1]['count_value'] == 2 90 | 91 | db.book.add_column('ext', 'int', exist_ok=True) 92 | assert len(db.book.describe()) == 4 93 | 94 | assert db.book.has_index('ext_index') is False 95 | db.book.create_index('ext_index', 'ext', unique=True, exist_ok=True) 96 | assert db.book.has_index('ext_index') is True 97 | 98 | db.book.update(1, {'ext': 10}) 99 | db.book.update(2, {'ext': 20}) 100 | db.book.update(3, {'ext': 30}) 101 | db.commit() 102 | 103 | with pytest.raises(Exception): 104 | db.book.update(4, {'ext': 10}) 105 | db.commit() 106 | 107 | db.rollback() 108 | 109 | db.ref.add_column('id', 'int', primary=True, auto_increment=True) 110 | db.ref.add_column('book_id', 'int') 111 | db.ref.insert({'book_id': 1}) 112 | db.ref.insert({'book_id': 2}) 113 | db.ref.insert({'book_id': 3}) 114 | db.ref.insert({'book_id': 6}) 115 | db.ref.insert({'book_id': 1}) 116 | 117 | r = list(db.ref.find(join='book.id=book_id', order_by='ref.id')) 118 | assert len(r) == 4 119 | assert r[1]['book']['value'] == 18 120 | assert r[2]['book']['value'] == 9 121 | 122 | r = list(db.ref.find(left_join='book.id=book_id', order_by='ref.id')) 123 | assert len(r) == 5 124 | assert r[3]['book'] is None 125 | db.close() 126 | 127 | 128 | def test_context(): 129 | pass 130 | -------------------------------------------------------------------------------- /sqlmapper/sqlite.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | import sqlite3 4 | import copy 5 | from .table import Table 6 | from .utils import validate_name, NoValue, quote_key 7 | from .base_engine import BaseEngine 8 | 9 | 10 | class Engine(BaseEngine): 11 | def __init__(self, **kw): 12 | self.conn = sqlite3.connect(kw.get('db') or ':memory:') 13 | self.cursor = None 14 | super(Engine, self).__init__() 15 | 16 | def get_cursor(self): 17 | if not self.cursor: 18 | self.cursor = self.conn.cursor() 19 | return self.cursor 20 | 21 | def commit(self): 22 | self.conn.commit() 23 | self.fire_event(True) 24 | 25 | def rollback(self): 26 | self.conn.rollback() 27 | self.fire_event(False) 28 | 29 | def close(self): 30 | self.conn.close() 31 | self.conn = None 32 | 33 | def get_columns(self, table): 34 | result = self.local.tables.get(table) 35 | if not result: 36 | result = [] 37 | cursor = self.get_cursor() 38 | cursor.execute('PRAGMA table_info({})'.format(table)) 39 | for row in cursor: 40 | result.append({ 41 | 'name': row[1], 42 | 'type': row[2], 43 | 'notnull': row[3] == 1, 44 | 'default': row[4], 45 | 'primary': row[5] == 1 46 | }) 47 | self.local.tables[table] = result 48 | return copy.deepcopy(result) 49 | 50 | def get_table(self, name): 51 | return SqliteTable(name, self, keyword='?') 52 | 53 | def get_tables(self): 54 | cursor = self.get_cursor() 55 | cursor.execute('SELECT name FROM sqlite_master WHERE type = ?', ('table',)) 56 | for row in cursor: 57 | yield row[0] 58 | 59 | 60 | sqlite_types = { 61 | # integer 62 | 'INTEGER': 'INTEGER', 63 | 'INT': 'INTEGER', 64 | 65 | # text 66 | 'TEXT': 'TEXT', 67 | 'VARCHAR': 'TEXT', 68 | 69 | # none 70 | 'NONE': 'NONE', 71 | 'BLOB': 'NONE', 72 | 73 | # real 74 | 'REAL': 'REAL', 75 | 'DOUBLE': 
'REAL', 76 | 'FLOAT': 'REAL', 77 | 78 | # numeric 79 | 'NUMERIC': 'NUMERIC', 80 | 'DECIMAL': 'NUMERIC', 81 | 'BOOLEAN': 'NUMERIC', 82 | 'DATE': 'NUMERIC', 83 | 'DATETIME': 'NUMERIC' 84 | } 85 | 86 | 87 | class SqliteTable(Table): 88 | def add_column(self, name, type, default=NoValue, exist_ok=False, primary=False, auto_increment=False, not_null=False): 89 | validate_name(name) 90 | 91 | type = sqlite_types.get(type.upper()) 92 | if not type: 93 | raise ValueError('Wrong type') 94 | 95 | values = [] 96 | scolumn = '`{}` {}'.format(name, type) 97 | 98 | if primary: 99 | scolumn += ' PRIMARY KEY' 100 | if auto_increment: 101 | scolumn += ' AUTOINCREMENT' 102 | elif not_null: 103 | scolumn += ' NOT NULL' 104 | 105 | if default != NoValue: 106 | if primary: 107 | raise ValueError('Can''t have default value') 108 | scolumn += ' DEFAULT ?' 109 | values.append(default) 110 | 111 | if self.tablename in self.engine.get_tables(): 112 | if exist_ok: 113 | if self.get_column(name): 114 | return 115 | sql = 'ALTER TABLE `{}` ADD COLUMN {}'.format(self.tablename, scolumn) 116 | else: 117 | sql = 'CREATE TABLE {} ({})'.format(self.tablename, scolumn) 118 | 119 | self.cursor.execute(sql, tuple(values)) 120 | self.engine.local.tables[self.tablename] = None 121 | 122 | def has_index(self, name): 123 | self.cursor.execute('PRAGMA index_list({})'.format(self.tablename)) 124 | for row in self.cursor: 125 | if row[1] == name: 126 | return True 127 | return False 128 | 129 | def create_index(self, name, column, unique=False, exist_ok=False): 130 | if exist_ok and self.has_index(name): 131 | return 132 | 133 | if not isinstance(column, list): 134 | column = [column] 135 | 136 | column = ', '.join(map(self.cc, column)) 137 | sql = 'CREATE {}INDEX {} on {} ({})'.format( 138 | 'UNIQUE ' if unique else '', 139 | name, 140 | self.tablename, 141 | column 142 | ) 143 | self.cursor.execute(sql) 144 | 145 | def _build_filter(self, filter): 146 | s, v = super(SqliteTable, self)._build_filter(filter) 147 | if isinstance(filter, (list, tuple)) and s: 148 | s = s.replace('%s', '?') 149 | return s, v 150 | -------------------------------------------------------------------------------- /sqlmapper/mysql.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | import MySQLdb 4 | import threading 5 | import re 6 | import copy 7 | from .table import Table 8 | from .utils import NoValue, validate_name 9 | from .base_engine import BaseEngine 10 | 11 | 12 | class Engine(BaseEngine): 13 | def __init__(self, autocreate=None, read_commited=False, **kw): 14 | self.read_commited = read_commited 15 | self.local = threading.local() 16 | if 'charset' not in kw: 17 | kw['charset'] = 'utf8mb4' 18 | super(Engine, self).__init__() 19 | 20 | self.db_config = {} 21 | for k in ['host', 'port', 'user', 'password', 'db', 'charset']: 22 | if k in kw: 23 | self.db_config[k] = kw[k] 24 | 25 | self.local.conn = self.get_connection(autocreate_db=autocreate) 26 | 27 | def get_connection(self, autocreate_db=False): 28 | try: 29 | return MySQLdb.connect(**self.db_config) 30 | except MySQLdb.OperationalError as e: 31 | if autocreate_db and e.args[0] == 1049: 32 | config = self.db_config.copy() 33 | db = config.pop('db') 34 | 35 | conn = MySQLdb.connect(**config) 36 | cursor = conn.cursor() 37 | cursor.execute('CREATE DATABASE {}'.format(db)) 38 | conn.close() 39 | return MySQLdb.connect(**self.db_config) 40 | else: 41 | raise 42 | 43 | def commit(self): 44 | self.local.conn.commit() 
45 | self.fire_event(True) 46 | 47 | def rollback(self): 48 | self.local.conn.rollback() 49 | self.fire_event(False) 50 | 51 | def close(self): 52 | self.local.conn.close() 53 | self.local.cursor = None 54 | self.local.conn = None 55 | 56 | def get_cursor(self): 57 | self.thread_init() 58 | if hasattr(self.local, 'cursor'): 59 | return self.local.cursor 60 | 61 | if not hasattr(self.local, 'conn'): 62 | self.local.conn = self.get_connection() 63 | 64 | self.local.cursor = cursor = self.local.conn.cursor() 65 | if self.read_commited: 66 | cursor.execute("SET SESSION TRANSACTION ISOLATION LEVEL READ COMMITTED") 67 | return cursor 68 | 69 | def get_table(self, name): 70 | return MysqlTable(name, self) 71 | 72 | def get_tables(self): 73 | cursor = self.get_cursor() 74 | 75 | cursor.execute('SHOW TABLES') 76 | for row in cursor: 77 | yield row[0] 78 | 79 | def get_columns(self, table): 80 | self.thread_init() 81 | result = self.local.tables.get(table) 82 | if not result: 83 | result = [] 84 | cursor = self.get_cursor() 85 | cursor.execute('describe `{}`'.format(table)) 86 | for row in cursor: 87 | result.append({ 88 | 'name': row[0], 89 | 'type': row[1], 90 | 'null': row[2] == 'YES', 91 | 'default': row[4], 92 | 'primary': row[3] == 'PRI', 93 | 'auto_increment': row[5] == 'auto_increment' 94 | }) 95 | self.local.tables[table] = result 96 | return copy.deepcopy(result) 97 | 98 | 99 | class MysqlTable(Table): 100 | def add_column(self, name, column_type, not_null=False, default=NoValue, exist_ok=False, primary=False, auto_increment=False, collate=None): 101 | validate_name(name) 102 | assert re.match(r'^[\w\d\(\),]+$', column_type), 'Wrong type: {}'.format(column_type) 103 | values = [] 104 | scolumn = '`{}` {}'.format(name, column_type) 105 | 106 | if collate: 107 | charset = collate.split('_')[0] 108 | scolumn += ' CHARACTER SET {} COLLATE {}'.format(charset, collate) 109 | 110 | if primary: 111 | not_null = True 112 | 113 | if not_null: 114 | scolumn += ' NOT NULL' 115 | if auto_increment: 116 | scolumn += ' AUTO_INCREMENT' 117 | 118 | if default != NoValue: 119 | if auto_increment or primary: 120 | raise ValueError('Can''t have default value') 121 | scolumn += ' DEFAULT %s' 122 | values.append(default) 123 | 124 | if self.tablename in self.engine.get_tables(): 125 | if exist_ok: 126 | if self.get_column(name): 127 | return 128 | if primary: 129 | scolumn += ', ADD PRIMARY KEY (`{}`)'.format(name) 130 | sql = 'ALTER TABLE `{}` ADD COLUMN {}'.format(self.tablename, scolumn) 131 | else: 132 | if primary: 133 | scolumn += ', PRIMARY KEY (`{}`)'.format(name) 134 | collate = collate or 'utf8mb4_unicode_ci' 135 | charset = collate.split('_')[0] 136 | sql = 'CREATE TABLE `{}` ({}) ENGINE=InnoDB DEFAULT CHARSET {} COLLATE {}'.format(self.tablename, scolumn, charset, collate) 137 | self.cursor.execute(sql, tuple(values)) 138 | self.engine.local.tables[self.tablename] = None 139 | 140 | def has_index(self, name): 141 | self.cursor.execute('show index from ' + self.tablename) 142 | for row in self.cursor: 143 | if row[2] == name: 144 | return True 145 | return False 146 | 147 | def create_index(self, name, column, primary=False, unique=False, fulltext=False, exist_ok=False): 148 | if primary: 149 | name = 'PRIMARY' 150 | if exist_ok and self.has_index(name): 151 | return 152 | 153 | if isinstance(column, list): 154 | column = ', '.join(map(self.cc, column)) 155 | else: 156 | column = self.cc(column) 157 | 158 | index_type = 'INDEX ' 159 | if primary: 160 | index_type = 'PRIMARY KEY ' 161 | assert not 
fulltext 162 | name = '' 163 | else: 164 | name = self.cc(name) 165 | 166 | if unique and not primary: 167 | assert not fulltext 168 | index_type = 'UNIQUE ' 169 | elif fulltext: 170 | index_type = 'FULLTEXT ' 171 | 172 | sql = 'ALTER TABLE {} ADD {}{}({})'.format(self.cc(self.tablename), index_type, name, column) 173 | self.cursor.execute(sql) 174 | -------------------------------------------------------------------------------- /sqlmapper/psql.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | import psycopg2 4 | import threading 5 | import re 6 | import copy 7 | from .table import Table 8 | from .utils import NoValue, validate_name 9 | from .base_engine import BaseEngine 10 | 11 | 12 | class Engine(BaseEngine): 13 | def __init__(self, schema='public', autocreate=None, read_commited=False, **kw): 14 | self.read_commited = read_commited 15 | self.local = threading.local() 16 | self.schema = schema 17 | super(Engine, self).__init__() 18 | 19 | self.db_config = {} 20 | for k in ['host', 'port', 'user', 'password', 'dbname']: 21 | if k in kw: 22 | self.db_config[k] = kw[k] 23 | if not self.db_config.get('dbname'): 24 | self.db_config['dbname'] = kw.get('db') 25 | 26 | self.local.conn = self.get_connection(autocreate_db=autocreate) 27 | 28 | def get_connection(self, autocreate_db=False): 29 | try: 30 | return psycopg2.connect(**self.db_config) 31 | except psycopg2.OperationalError as e: 32 | if autocreate_db and 'does not exist' in str(e): 33 | config = self.db_config.copy() 34 | db = config.pop('dbname') 35 | 36 | conn = psycopg2.connect(**config) 37 | conn.autocommit = True 38 | cursor = conn.cursor() 39 | cursor.execute('CREATE DATABASE {}'.format(db)) 40 | conn.close() 41 | return psycopg2.connect(**self.db_config) 42 | else: 43 | raise 44 | 45 | def commit(self): 46 | self.local.conn.commit() 47 | self.fire_event(True) 48 | 49 | def rollback(self): 50 | self.local.conn.rollback() 51 | self.fire_event(False) 52 | 53 | def close(self): 54 | self.local.conn.close() 55 | self.local.cursor = None 56 | self.local.conn = None 57 | 58 | def get_cursor(self): 59 | self.thread_init() 60 | if hasattr(self.local, 'cursor'): 61 | return self.local.cursor 62 | 63 | if not hasattr(self.local, 'conn'): 64 | self.local.conn = self.get_connection() 65 | 66 | self.local.cursor = cursor = self.local.conn.cursor() 67 | if self.read_commited: 68 | cursor.execute("SET SESSION TRANSACTION ISOLATION LEVEL READ COMMITTED") 69 | cursor.execute('SET search_path TO ' + self.schema) 70 | return cursor 71 | 72 | def get_table(self, name): 73 | return PsqlTable(name, self) 74 | 75 | def get_tables(self): 76 | cursor = self.get_cursor() 77 | cursor.execute('SELECT tablename FROM pg_catalog.pg_tables where schemaname=%s', (self.schema,)) 78 | for row in cursor: 79 | yield row[0] 80 | 81 | def get_columns(self, table): 82 | self.thread_init() 83 | result = self.local.tables.get(table) 84 | if not result: 85 | result = [] 86 | cursor = self.get_cursor() 87 | # get primary key 88 | 89 | primary = set() 90 | cursor.execute( 91 | 'SELECT c.column_name, c.data_type FROM ' 92 | 'information_schema.table_constraints tc ' 93 | 'JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_name) ' 94 | 'JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name ' 95 | 'where constraint_type = %s and tc.table_name = %s', ('PRIMARY 
KEY', table)) 96 | for row in cursor: 97 | primary.add(row[0]) 98 | 99 | cursor.execute('select column_name, is_nullable, data_type, column_default, numeric_precision, numeric_precision_radix, * from INFORMATION_SCHEMA.COLUMNS where table_catalog=%s and table_schema=%s and table_name=%s', (self.db_config['dbname'], self.schema, table)) 100 | for row in cursor: 101 | result.append({ 102 | 'name': row[0], 103 | 'null': row[1] == 'YES', 104 | 'type': row[2], 105 | 'default': row[3], 106 | 'primary': row[0] in primary 107 | #'auto_increment': False 108 | }) 109 | self.local.tables[table] = result 110 | return copy.deepcopy(result) 111 | 112 | 113 | class PsqlTable(Table): 114 | def __init__(self, *a, **kw): 115 | super(PsqlTable, self).__init__(*a, quote='"', **kw) 116 | 117 | def add_column(self, name, column_type, not_null=False, default=NoValue, exist_ok=False, primary=False, auto_increment=False, collate=None): 118 | validate_name(name) 119 | assert re.match(r'^[\w\d\(\)]+$', column_type), 'Wrong type: {}'.format(column_type) 120 | values = [] 121 | 122 | if primary: 123 | if auto_increment: 124 | if column_type == 'bigint': 125 | column_type = 'bigserial' 126 | else: 127 | column_type = 'serial' 128 | 129 | scolumn = '{} {}'.format(name, column_type) 130 | 131 | if primary: 132 | scolumn += ' PRIMARY KEY' 133 | 134 | if not_null: 135 | scolumn += ' NOT NULL' 136 | 137 | if default != NoValue: 138 | if not_null or primary: 139 | raise ValueError('Can''t have default value') 140 | scolumn += ' DEFAULT %s' 141 | values.append(default) 142 | 143 | if self.tablename in self.engine.get_tables(): 144 | if exist_ok: 145 | if self.get_column(name): 146 | return 147 | if primary: 148 | raise NotImplementedError() 149 | sql = 'ALTER TABLE {} ADD COLUMN {}'.format(self.tablename, scolumn) 150 | else: 151 | sql = 'CREATE TABLE {} ({})'.format(self.tablename, scolumn) 152 | 153 | self.cursor.execute(sql, tuple(values)) 154 | self.engine.commit() 155 | self.engine.local.tables[self.tablename] = None 156 | 157 | def has_index(self, name): 158 | self.cursor.execute('select i.relname ' 159 | 'from pg_class t, pg_class i, pg_index ix, pg_attribute a ' 160 | 'where t.oid = ix.indrelid and i.oid = ix.indexrelid ' 161 | 'and a.attrelid = t.oid and a.attnum = ANY(ix.indkey) ' 162 | 'and t.relkind = %s and t.relname = %s', ('r', self.tablename)) 163 | for row in self.cursor: 164 | if row[0] == name: 165 | return True 166 | return False 167 | 168 | def create_index(self, name, column, unique=False, exist_ok=False): 169 | if exist_ok and self.has_index(name): 170 | return 171 | 172 | if isinstance(column, list): 173 | column = ', '.join(map(self.cc, column)) 174 | else: 175 | column = self.cc(column) 176 | 177 | if unique: 178 | index_type = 'UNIQUE INDEX' 179 | else: 180 | index_type = 'INDEX' 181 | 182 | sql = 'CREATE {} {} ON {} ({})'.format(index_type, self.cc(name), self.cc(self.tablename), column) 183 | self.cursor.execute(sql) 184 | -------------------------------------------------------------------------------- /sqlmapper/table.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | import re 4 | from .utils import NoValue, validate_name, quote_key, format_func, is_bytes, is_int, is_str 5 | 6 | 7 | class Table(object): 8 | def __init__(self, name, engine, keyword='%s', quote='`'): 9 | self.tablename = name 10 | self.engine = engine 11 | self.keyword = keyword 12 | self.quote = quote 13 | 14 | def cc(self, name): 15 | return 
quote_key(name, self.quote) 16 | 17 | @property 18 | def cursor(self): 19 | return self.engine.get_cursor() 20 | 21 | def describe(self): 22 | return self.engine.get_columns(self.tablename) 23 | 24 | def get_column(self, name): 25 | for column in self.describe(): 26 | if column['name'] == name: 27 | return column 28 | 29 | def insert(self, data): 30 | keys = [] 31 | values = [] 32 | items = [] 33 | for key, value in data.items(): 34 | keys.append(self.cc(key)) 35 | values.append(value) 36 | items.append(self.keyword) 37 | 38 | sql = 'INSERT INTO {} ({}) VALUES ({})'.format(self.cc(self.tablename), ', '.join(keys), ', '.join(items)) 39 | self.cursor.execute(sql, tuple(values)) 40 | assert self.cursor.rowcount == 1 41 | return self.cursor.lastrowid 42 | 43 | def _build_filter(self, filter): 44 | if filter is None: 45 | return None, [] 46 | elif isinstance(filter, dict): 47 | keys = [] 48 | values = [] 49 | for k, v in filter.items(): 50 | if '.' in k: 51 | k = self.cc(k) 52 | else: 53 | k = self.cc(self.tablename + '.' + k) 54 | if v is None: 55 | keys.append(k + ' is NULL') 56 | else: 57 | keys.append(k + '=' + self.keyword) 58 | values.append(v) 59 | sql = ' AND '.join(keys) 60 | return sql, values 61 | elif isinstance(filter, (list, tuple)): 62 | return filter[0], filter[1:] 63 | elif is_int(filter) or is_str(filter) or is_bytes(filter): 64 | # find by primary key 65 | key = None 66 | for column in self.describe(): 67 | if column['primary']: 68 | key = column['name'] 69 | break 70 | else: 71 | raise ValueError('No primary key') 72 | return '{} = {}'.format(self.cc(key), self.keyword), [filter] 73 | else: 74 | raise NotImplementedError 75 | 76 | def find_one(self, filter=None, join=None, left_join=None, for_update=False, columns=None, order_by=None): 77 | result = list(self.find(filter, limit=1, join=join, for_update=for_update, columns=columns, order_by=order_by)) 78 | if result: 79 | return result[0] 80 | 81 | def find(self, filter=None, limit=None, join=None, left_join=None, for_update=False, columns=None, group_by=None, order_by=None, distinct=False): 82 | """ 83 | join='subtable.id=column' 84 | join='subtable as tbl.id=column' 85 | """ 86 | 87 | if columns: 88 | assert not join 89 | if not isinstance(columns, (list, tuple)): 90 | columns = [columns] 91 | columns = ', '.join(map(lambda n: format_func(n, self.quote), columns)) 92 | else: 93 | columns = '{}.*'.format(self.tablename) 94 | 95 | joins = [] 96 | if join or left_join: 97 | assert bool(join) ^ bool(left_join) 98 | if left_join: 99 | join = left_join 100 | prefix = 'LEFT ' 101 | else: 102 | prefix = '' 103 | r = re.match(r'(\w+)\.(\w+)=(\w+)', join) 104 | if r: 105 | table2, column2, column1 = r.groups() 106 | alias = table2 107 | else: 108 | r = re.match(r'(\w+)\s+as\s+(\w+)\.(\w+)=(\w+)', join) 109 | assert r 110 | table2, alias, column2, column1 = r.groups() 111 | 112 | columns += ', \'\' as __divider, {}.*'.format(alias) 113 | join = ' {}JOIN {} AS {} ON {}.{} = {}'.format(prefix, table2, alias, alias, column2, column1) 114 | 115 | key = None 116 | if left_join: 117 | for c in self.engine.get_columns(self.tablename): 118 | if c['primary']: 119 | key = c['name'] 120 | break 121 | 122 | joins.append({ 123 | 'alias': alias, 124 | 'key': key 125 | }) 126 | 127 | sql = 'SELECT ' 128 | if distinct: 129 | sql += 'DISTINCT ' 130 | sql += '{} FROM {}'.format(columns, self.cc(self.tablename)) 131 | where, values = self._build_filter(filter) 132 | if join: 133 | sql += join 134 | if where: 135 | sql += ' WHERE ' + where 136 | if 
group_by: 137 | if not isinstance(group_by, list): 138 | group_by = [group_by] 139 | sql += ' GROUP BY ' + ', '.join(map(self.cc, group_by)) 140 | if order_by: 141 | if not isinstance(order_by, list): 142 | order_by = [order_by] 143 | oc = [] 144 | for name in order_by: 145 | if name.startswith('-'): 146 | oc.append(self.cc(name[1:]) + ' DESC') 147 | else: 148 | oc.append(self.cc(name)) 149 | sql += ' ORDER BY ' + ', '.join(oc) 150 | if limit: 151 | assert is_int(limit) 152 | sql += ' LIMIT {}'.format(limit) 153 | 154 | if for_update: 155 | sql += ' FOR UPDATE' 156 | 157 | self.cursor.execute(sql, tuple(values)) 158 | 159 | columns = self.cursor.description 160 | if self.cursor.rowcount: 161 | for row in self.cursor: 162 | join_index = -1 163 | join_alias = None 164 | join_key = None 165 | d = {} 166 | for i, value in enumerate(row): 167 | col = columns[i] 168 | column_name = col[0] 169 | if column_name == '__divider': 170 | join_index += 1 171 | join_alias = joins[join_index]['alias'] 172 | join_key = joins[join_index]['key'] 173 | d[join_alias] = {} 174 | continue 175 | if join_alias: 176 | if column_name == join_key: 177 | if value is None: 178 | d[join_alias] = None 179 | if d[join_alias] is not None: 180 | d[join_alias][column_name] = value 181 | else: 182 | d[column_name] = value 183 | yield d 184 | 185 | def update(self, filter=None, update=None, limit=None): 186 | up = [] 187 | values = [] 188 | for key, value in update.items(): 189 | up.append('{} = {}'.format(self.cc(key), self.keyword)) 190 | values.append(value) 191 | 192 | sql = 'UPDATE {} SET {}'.format(self.cc(self.tablename), ', '.join(up)) 193 | 194 | where, wvalues = self._build_filter(filter) 195 | if where: 196 | sql += ' WHERE ' + where 197 | values += wvalues 198 | 199 | if limit: 200 | assert is_int(limit) 201 | sql += ' LIMIT {}'.format(limit) 202 | 203 | self.cursor.execute(sql, tuple(values)) 204 | 205 | def update_one(self, filter=None, update=None): 206 | self.update(filter, update, limit=1) 207 | 208 | def delete(self, filter=None): 209 | where, values = self._build_filter(filter) 210 | 211 | sql = 'DELETE FROM {}'.format(self.cc(self.tablename)) 212 | if where: 213 | sql += ' WHERE {}'.format(where) 214 | self.cursor.execute(sql, tuple(values)) 215 | 216 | def count(self, filter=None): 217 | where, values = self._build_filter(filter) 218 | 219 | sql = 'SELECT COUNT(*) FROM {}'.format(self.cc(self.tablename)) 220 | if where: 221 | sql += ' WHERE {}'.format(where) 222 | self.cursor.execute(sql, tuple(values)) 223 | return self.cursor.fetchone()[0] 224 | 225 | def drop(self, exist_ok=True): 226 | sql = 'DROP TABLE ' 227 | if exist_ok: 228 | sql += 'IF EXISTS ' 229 | sql += self.tablename 230 | self.cursor.execute(sql) 231 | -------------------------------------------------------------------------------- /sqlmapper/aio/amysql.py: -------------------------------------------------------------------------------- 1 | 2 | import copy 3 | import re 4 | import aiomysql 5 | from pymysql.err import InternalError, OperationalError 6 | from ..utils import validate_name, NoValue, quote_key, format_func 7 | 8 | 9 | class Engine: 10 | def __init__(self): 11 | self.cursors = [] 12 | self.local = type('local', (object,), {'tables': {}})() 13 | 14 | async def init(self, *, loop, read_commited=False, autocreate=False, **kw): 15 | self.loop = loop 16 | self.read_commited = read_commited 17 | 18 | self._option = option = {} 19 | for k in ['db', 'host', 'port', 'user', 'password', 'charset']: 20 | if k in kw: 21 | option[k] = kw[k] 22 
| 23 | if 'charset' not in option: 24 | option['charset'] = 'utf8mb4' 25 | 26 | try: 27 | self.connection = await aiomysql.connect(loop=loop, **option) 28 | except OperationalError as e: 29 | if autocreate and e.args[0] == 2003 and isinstance(e.__cause__, InternalError) and e.__cause__.args[0] == 1049: 30 | # Unknown database 31 | connect_opt = option.copy() 32 | db = connect_opt.pop('db') 33 | connection = await aiomysql.connect(loop=loop, **connect_opt) 34 | cursor = await connection.cursor() 35 | try: 36 | await cursor.execute("CREATE DATABASE `{}`".format(db)) 37 | finally: 38 | await cursor.close() 39 | connection.close() 40 | 41 | self.connection = await aiomysql.connect(loop=loop, **option) 42 | else: 43 | raise 44 | 45 | async def commit(self): 46 | await self.connection.commit() 47 | 48 | async def rollback(self): 49 | await self.connection.rollback() 50 | 51 | async def acquare_cursor(self): 52 | if self.cursors: 53 | return self.cursors.pop() 54 | 55 | cursor = await self.connection.cursor() 56 | if self.read_commited: 57 | await cursor.execute("SET SESSION TRANSACTION ISOLATION LEVEL READ COMMITTED") 58 | return cursor 59 | 60 | def release_cursor(self, cursor): 61 | if not cursor.closed: 62 | self.cursors.append(cursor) 63 | 64 | async def reconnect(self): 65 | for cursor in self.cursors: 66 | await cursor.close() 67 | self.cursors = [] 68 | 69 | if self.connection: 70 | self.connection.close() 71 | 72 | self.connection = await aiomysql.connect(loop=self.loop, **self._option) 73 | 74 | @property 75 | def cursor(self): 76 | return CursorContext(self) 77 | 78 | def try_execute(self, query, argv=None): 79 | return TryExecuteContext(self, query, argv) 80 | 81 | def get_table(self, name): 82 | return Table(name, self) 83 | 84 | async def get_tables(self): 85 | result = [] 86 | async with self.try_execute('SHOW TABLES') as cursor: 87 | for row in await cursor.fetchall(): 88 | result.append(row[0]) 89 | return result 90 | 91 | async def get_columns(self, table): 92 | result = self.local.tables.get(table) 93 | if not result: 94 | result = [] 95 | async with self.try_execute('describe `{}`'.format(table)) as cursor: 96 | for row in await cursor.fetchall(): 97 | result.append({ 98 | 'name': row[0], 99 | 'type': row[1], 100 | 'null': row[2] == 'YES', 101 | 'default': row[4], 102 | 'primary': row[3] == 'PRI', 103 | 'auto_increment': row[5] == 'auto_increment' 104 | }) 105 | 106 | self.local.tables[table] = result 107 | return copy.deepcopy(result) 108 | 109 | 110 | class CursorContext: 111 | def __init__(self, engine): 112 | self.engine = engine 113 | 114 | async def __aenter__(self): 115 | self.cursor = await self.engine.acquare_cursor() 116 | return self.cursor 117 | 118 | async def __aexit__(self, exc_type, exc, tb): 119 | self.engine.release_cursor(self.cursor) 120 | 121 | 122 | class TryExecuteContext: 123 | def __init__(self, engine, query, argv): 124 | self.engine = engine 125 | self.query = query 126 | self.argv = argv 127 | 128 | async def run(self): 129 | cursor = await self.engine.acquare_cursor() 130 | try: 131 | await cursor.execute(self.query, self.argv) 132 | except OperationalError as e: 133 | self.engine.release_cursor(cursor) 134 | if e.args[0] == 2013: 135 | await self.engine.reconnect() 136 | cursor = await self.engine.acquare_cursor() 137 | try: 138 | await cursor.execute(self.query, self.argv) 139 | except Exception: 140 | self.engine.release_cursor(cursor) 141 | raise 142 | else: 143 | raise 144 | 145 | return cursor 146 | 147 | async def __aenter__(self): 148 | 
self.cursor = await self.run() 149 | return self.cursor 150 | 151 | async def __aexit__(self, exc_type, exc, tb): 152 | self.engine.release_cursor(self.cursor) 153 | 154 | async def __call__(self): 155 | self.engine.release_cursor(await self.run()) 156 | 157 | 158 | class Table: 159 | def __init__(self, name, engine): 160 | self.tablename = name 161 | self.engine = engine 162 | self.keyword = '%s' 163 | 164 | @property 165 | def cursor(self): 166 | return self.engine.cursor 167 | 168 | def try_execute(self, query, argv=None): 169 | return self.engine.try_execute(query, argv) 170 | 171 | async def describe(self): 172 | return await self.engine.get_columns(self.tablename) 173 | 174 | async def get_column(self, name): 175 | for column in await self.describe(): 176 | if column['name'] == name: 177 | return column 178 | 179 | async def add_column(self, name, column_type, not_null=False, default=NoValue, exist_ok=False, primary=False, auto_increment=False, collate=None): 180 | validate_name(name) 181 | assert re.match(r'^[\w\d\(\)]+$', column_type), 'Wrong type: {}'.format(column_type) 182 | values = [] 183 | scolumn = '`{}` {}'.format(name, column_type) 184 | 185 | if collate: 186 | charset = collate.split('_')[0] 187 | scolumn += ' CHARACTER SET {} COLLATE {}'.format(charset, collate) 188 | 189 | if primary: 190 | not_null = True 191 | 192 | if not_null: 193 | scolumn += ' NOT NULL' 194 | if auto_increment: 195 | scolumn += ' AUTO_INCREMENT' 196 | 197 | if default != NoValue: 198 | if not_null or primary: 199 | raise ValueError('Can''t have default value') 200 | scolumn += ' DEFAULT %s' 201 | values.append(default) 202 | 203 | if self.tablename in await self.engine.get_tables(): 204 | if exist_ok: 205 | if await self.get_column(name): 206 | return 207 | if primary: 208 | scolumn += ', ADD PRIMARY KEY (`{}`)'.format(name) 209 | sql = 'ALTER TABLE `{}` ADD COLUMN {}'.format(self.tablename, scolumn) 210 | else: 211 | if primary: 212 | scolumn += ', PRIMARY KEY (`{}`)'.format(name) 213 | collate = collate or 'utf8mb4_unicode_ci' 214 | charset = collate.split('_')[0] 215 | sql = 'CREATE TABLE `{}` ({}) ENGINE=InnoDB DEFAULT CHARSET {} COLLATE {}'.format(self.tablename, scolumn, charset, collate) 216 | 217 | await self.try_execute(sql, tuple(values))() 218 | self.engine.local.tables[self.tablename] = None 219 | 220 | async def insert(self, data): 221 | keys = [] 222 | values = [] 223 | items = [] 224 | for key, value in data.items(): 225 | keys.append(quote_key(key)) 226 | values.append(value) 227 | items.append(self.keyword) 228 | 229 | sql = 'INSERT INTO `{}` ({}) VALUES ({})'.format(self.tablename, ', '.join(keys), ', '.join(items)) 230 | async with self.try_execute(sql, tuple(values)) as cursor: 231 | assert cursor.rowcount == 1 232 | return cursor.lastrowid 233 | 234 | async def _build_filter(self, filter): 235 | if filter is None: 236 | return None, [] 237 | elif isinstance(filter, dict): 238 | keys = [] 239 | values = [] 240 | for k, v in filter.items(): 241 | if '.' not in k: 242 | k = self.tablename + '.' 
+ k 243 | k = quote_key(k) 244 | if v is None: 245 | keys.append(k + ' is NULL') 246 | else: 247 | keys.append(k + '=' + self.keyword) 248 | values.append(v) 249 | sql = ' AND '.join(keys) 250 | return sql, values 251 | elif isinstance(filter, (list, tuple)): 252 | return filter[0], filter[1:] 253 | elif isinstance(filter, (int, str, bytes)): 254 | # find by primary key 255 | key = None 256 | for column in await self.describe(): 257 | if column['primary']: 258 | key = column['name'] 259 | break 260 | else: 261 | raise ValueError('No primary key') 262 | return '`{}` = {}'.format(key, self.keyword), [filter] 263 | else: 264 | raise NotImplementedError 265 | 266 | async def find_one(self, filter=None, join=None, left_join=None, for_update=False, columns=None, order_by=None): 267 | result = list(await self.find(filter, limit=1, join=join, left_join=left_join, for_update=for_update, columns=columns, order_by=order_by)) 268 | if result: 269 | return result[0] 270 | 271 | async def find(self, filter=None, limit=None, join=None, left_join=None, for_update=False, columns=None, group_by=None, order_by=None): 272 | """ 273 | join='subtable.id=column' 274 | join='subtable as tbl.id=column' 275 | """ 276 | 277 | if columns: 278 | assert not join 279 | if not isinstance(columns, (list, tuple)): 280 | columns = [columns] 281 | columns = ', '.join(map(format_func, columns)) 282 | else: 283 | columns = '{}.*'.format(self.tablename) 284 | 285 | joins = [] 286 | if join or left_join: 287 | assert bool(join) ^ bool(left_join) 288 | if left_join: 289 | join = left_join 290 | prefix = 'LEFT ' 291 | else: 292 | prefix = '' 293 | r = re.match(r'(\w+)\.(\w+)=(\w+)', join) 294 | if r: 295 | table2, column2, column1 = r.groups() 296 | alias = table2 297 | else: 298 | r = re.match(r'(\w+)\s+as\s+(\w+)\.(\w+)=(\w+)', join) 299 | assert r 300 | table2, alias, column2, column1 = r.groups() 301 | 302 | columns += ', "" as __divider, {}.*'.format(alias) 303 | join = ' {}JOIN {} AS {} ON {}.{} = {}'.format(prefix, table2, alias, alias, column2, column1) 304 | 305 | key = None 306 | if left_join: 307 | for c in await self.engine.get_columns(self.tablename): 308 | if c['primary']: 309 | key = c['name'] 310 | break 311 | 312 | joins.append({ 313 | 'alias': alias, 314 | 'key': key 315 | }) 316 | 317 | sql = 'SELECT {} FROM `{}`'.format(columns, self.tablename) 318 | where, values = await self._build_filter(filter) 319 | if join: 320 | sql += join 321 | if where: 322 | sql += ' WHERE ' + where 323 | if group_by: 324 | sql += ' GROUP BY ' + quote_key(group_by) 325 | if order_by: 326 | if not isinstance(order_by, list): 327 | order_by = [order_by] 328 | oc = [] 329 | for name in order_by: 330 | if name.startswith('-'): 331 | oc.append(quote_key(name[1:]) + ' DESC') 332 | else: 333 | oc.append(quote_key(name)) 334 | sql += ' ORDER BY ' + ', '.join(oc) 335 | if limit: 336 | assert isinstance(limit, int) 337 | sql += ' LIMIT {}'.format(limit) 338 | 339 | if for_update: 340 | sql += ' FOR UPDATE' 341 | 342 | result = [] 343 | async with self.try_execute(sql, tuple(values)) as cursor: 344 | if cursor.rowcount: 345 | columns = cursor.description 346 | for row in await cursor.fetchall(): 347 | join_index = -1 348 | join_alias = None 349 | join_key = None 350 | d = {} 351 | for i, value in enumerate(row): 352 | col = columns[i] 353 | column_name = col[0] 354 | if column_name == '__divider': 355 | join_index += 1 356 | join_alias = joins[join_index]['alias'] 357 | join_key = joins[join_index]['key'] 358 | d[join_alias] = {} 359 | continue 360 | if join_alias: 361 | 
if column_name == join_key: 362 | if value is None: 363 | d[join_alias] = None 364 | if d[join_alias] is not None: 365 | d[join_alias][column_name] = value 366 | else: 367 | d[column_name] = value 368 | result.append(d) 369 | 370 | return result 371 | 372 | async def update(self, filter=None, update=None, limit=None): 373 | up = [] 374 | values = [] 375 | for key, value in update.items(): 376 | up.append('`{}` = {}'.format(key, self.keyword)) 377 | values.append(value) 378 | 379 | sql = 'UPDATE `{}` SET {}'.format(self.tablename, ', '.join(up)) 380 | 381 | where, wvalues = await self._build_filter(filter) 382 | if where: 383 | sql += ' WHERE ' + where 384 | values += wvalues 385 | 386 | if limit: 387 | assert isinstance(limit, int) 388 | sql += ' LIMIT {}'.format(limit) 389 | 390 | await self.try_execute(sql, tuple(values))() 391 | 392 | async def update_one(self, filter=None, update=None): 393 | await self.update(filter, update, limit=1) 394 | 395 | async def delete(self, filter=None): 396 | where, values = await self._build_filter(filter) 397 | 398 | sql = 'DELETE FROM `{}`'.format(self.tablename) 399 | if where: 400 | sql += ' WHERE {}'.format(where) 401 | await self.try_execute(sql, tuple(values))() 402 | 403 | async def create_index(self, name, column, primary=False, unique=False, fulltext=False, exist_ok=False): 404 | if primary: 405 | name = 'PRIMARY' 406 | if exist_ok and await self.has_index(name): 407 | return 408 | 409 | if isinstance(column, list): 410 | column = ', '.join(map(quote_key, column)) 411 | else: 412 | column = quote_key(column) 413 | 414 | index_type = 'INDEX ' 415 | if primary: 416 | index_type = 'PRIMARY KEY ' 417 | assert not fulltext 418 | name = '' 419 | else: 420 | name = quote_key(name) 421 | 422 | if unique and not primary: 423 | assert not fulltext 424 | index_type = 'UNIQUE ' 425 | elif fulltext: 426 | index_type = 'FULLTEXT ' 427 | 428 | sql = 'ALTER TABLE `{}` ADD {}{}({})'.format(self.tablename, index_type, name, column) 429 | await self.try_execute(sql)() 430 | 431 | async def has_index(self, name): 432 | async with self.try_execute('show index from ' + self.tablename) as cursor: 433 | for row in await cursor.fetchall(): 434 | if row[2] == name: 435 | return True 436 | return False 437 | 438 | async def count(self, filter=None): 439 | where, values = await self._build_filter(filter) 440 | 441 | sql = 'SELECT COUNT(*) FROM `{}`'.format(self.tablename) 442 | if where: 443 | sql += ' WHERE {}'.format(where) 444 | async with self.try_execute(sql, tuple(values)) as cursor: 445 | return (await cursor.fetchone())[0] 446 | 447 | async def drop(self, exist_ok=True): 448 | sql = 'DROP TABLE ' 449 | if exist_ok: 450 | sql += 'IF EXISTS ' 451 | sql += self.tablename 452 | await self.try_execute(sql)() 453 | --------------------------------------------------------------------------------