# Directory structure of the repository snapshot:
#   ├── .gitignore
#   ├── README.md
#   ├── q
#   │   ├── db.py
#   │   └── __init__.py
#   ├── test0.2.py
#   └── test0.1.py

# ---------------------------------------------------------------------------
# /.gitignore
# ---------------------------------------------------------------------------
#   /.idea
#   /__pycache__
#   test.db
#   *.pyc

# ---------------------------------------------------------------------------
# /README.md  (translated from Chinese)
# ---------------------------------------------------------------------------
#   ### Q: building a Python ORM from scratch
#
#   #### Motivation
#   + Meet a practical need
#   + Learn by doing
#
#   #### Goals
#   + https://zhuanlan.zhihu.com/p/28059817
#   + Stand-alone framework
#   + Friendly API
#   + SQL support
#   + Tooling that generates the ORM in one step
#   + Simple enough
#   + Learning and reference material
#   + A tool for operating on a database quickly
#
#   #### Versions
#   + 0.2 SQLite integration
#         python3 test0.2.py
#   + 0.1 Basic SQL generation
#         python3 test0.1.py

# ---------------------------------------------------------------------------
# /q/db.py
# ---------------------------------------------------------------------------
import sqlite3


class QDataBase(object):
    """Interface of a database backend.

    Subclasses override both hooks; the base versions only signal that the
    backend did not provide them.
    """

    # NOTE: the misspelt name ``_connent`` is kept because subclasses
    # override it under this name.
    def _connent(self, *args):
        """Open and return a connection to the database."""
        raise NotImplementedError

    def _execute_sql(self, sql):
        """Run *sql* and return the resulting rows."""
        raise NotImplementedError


class QSQLite(QDataBase):
    """SQLite backend that opens a fresh connection for every statement."""

    def __init__(self, name):
        """Remember *name*, the database path handed to ``sqlite3.connect``."""
        self.name = name

    def _connent(self):
        """Return a new ``sqlite3`` connection to ``self.name``."""
        return sqlite3.connect(self.name)

    def _execute_sql(self, sql):
        """Execute one SQL statement and return its rows as a list of dicts.

        The statement is printed before it runs.  The connection is committed
        on success and closed in every case, including when the statement
        raises (for example ``sqlite3.OperationalError``).
        """
        print('QSQLite execute sql: {}'.format(sql))
        conn = self._connent()
        try:
            # Row objects expose column names, which become the dict keys.
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()
            cursor.execute(sql)
            data = [dict(row) for row in cursor]
            conn.commit()
            return data
        finally:
            conn.close()

    def execute_sql(self, sql):
        """Public entry point; delegates to ``_execute_sql``."""
        return self._execute_sql(sql)


# ---------------------------------------------------------------------------
# /test0.2.py
# ---------------------------------------------------------------------------
import sqlite3
import random
import string
from q import QSQLite
from q import Q
from q import T
from q import CharField
from q import IntegerField


def log(models):
    """Print every item of *models* on its own line."""
    for model in models:
        print(model)


def _show(user):
    """Print the id, name, age and phone attributes of one result row."""
    print(user.id, user.name, user.age, user.phone)


def _create_user_table():
    """Create the ``user`` table in test.db, reporting if that fails."""
    sql = '''
    CREATE TABLE user
    (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    name CHAR(50) NOT NULL,
    age INT NOT NULL,
    gender INT,
    phone CHAR(50)
    );
    '''
    try:
        QSQLite('test.db').execute_sql(sql)
    except sqlite3.OperationalError:
        print('user表已经创建')


def main():
    """Exercise the ORM against SQLite: insert, select, update, aggregate."""
    _create_user_table()

    class User(Q):
        __tablename__ = 'user'
        id = IntegerField()
        name = CharField()
        age = IntegerField()
        gender = IntegerField()
        phone = CharField()

    # Insert ten rows with a random four-letter name, age and gender.
    for _ in range(10):
        form = {
            'name': ''.join(random.sample(string.ascii_letters, 4)),
            'age': random.randint(10, 30),
            'gender': random.randint(0, 2),
        }
        User.new(form).execute()

    log(User.select(User.age).where(User.name == 'sen').execute())
    log(User.select().where(User.id == 1).execute())

    # Rename row 2, then read the name back.
    User.update({'name': 'sen'}).where(User.id == 2).execute()
    log(User.select(User.name).where(User.id == 2).execute())

    log(User.select(T.max(User.age)).execute())
    log(User.select(T.distinct(User.age)).execute())

    first_sen = User.select().where(User.name == 'sen').range(1).execute()
    _show(first_sen[0])

    first_ten = User.select().range(0, 10).execute()
    _show(first_ten[0])

    for user in User.select().order_by(User.age).execute():
        _show(user)

    counted = User.select(
        T.alias(T.count(User.age), 'count'), User.age
    ).group_by(User.age).execute()
    for user in counted:
        print(user.age, user.count)


if __name__ == '__main__':
    main()
# ---------------------------------------------------------------------------
# /test0.1.py
# ---------------------------------------------------------------------------
from q import Q
from q import IntegerField
from q import CharField
from q import T
from q import _join_string


if __name__ == '__main__':
    class User(Q):
        __tablename__ = 'user'
        name = CharField()
        age = IntegerField()
        gender = CharField()
        phone = CharField()

    form = dict(
        name='sen',
        age=18,
    )

    # Each entry pairs the source text of a query with the SQL it generates.
    sqls = [
        {
            'comment': 'User.select().generate_sql()',
            'sql': User.select().generate_sql(),
        },
        {
            'comment': 'User.select(User.name).generate_sql()',
            'sql': User.select(User.name).generate_sql(),
        },
        {
            'comment': 'User.select(User.name, User.age).generate_sql()',
            'sql': User.select(User.name, User.age).generate_sql(),
        },
        {
            'comment': "User.select(User.name, User.age).where(User.phone == '136********').generate_sql()",
            'sql': User.select(User.name, User.age).where(User.phone == '136********').generate_sql(),
        },
        {
            'comment': "User.select(User.name, User.age).where((User.gender == '男') & (User.age >= 18)).generate_sql()",
            'sql': User.select(User.name, User.age).where((User.gender == '男') & (User.age >= 18)).generate_sql(),
        },
        {
            'comment': 'User.select(T.count(User.name)).generate_sql()',
            'sql': User.select(T.count(User.name)).generate_sql(),
        },
        {
            'comment': 'User.select(T.distinct(User.name)).generate_sql()',
            'sql': User.select(T.distinct(User.name)).generate_sql(),
        },
        {
            'comment': 'User.select(T.max(User.age), T.distinct(User.name)).generate_sql()',
            'sql': User.select(T.max(User.age), T.distinct(User.name)).generate_sql(),
        },
        {
            'comment': 'User.new(form).generate_sql()',
            'sql': User.new(form).generate_sql(),
        },
        {
            'comment': 'User.select(T.max(User.age), T.distinct(User.name)).generate_sql()',
            'sql': User.select(T.max(User.age), T.distinct(User.name)).generate_sql(),
        },
        {
            'comment': "User.update(form).where(User.name != '森').generate_sql()",
            'sql': User.update(form).where(User.name != '森').generate_sql(),
        },
        {
            'comment': "User.delete().where(User.name != '森').generate_sql()",
            'sql': User.delete().where(User.name != '森').generate_sql(),
        },
        {
            'comment': "User.select(T.alias(T.max(User.age), 'max_age'), T.distinct(User.name)).generate_sql()",
            'sql': User.select(T.alias(T.max(User.age), 'max_age'), T.distinct(User.name)).generate_sql(),
        },
        {
            'comment': "User.select(T.alias(User.name, 'user_name')).generate_sql()",
            'sql': User.select(T.alias(User.name, 'user_name')).generate_sql(),
        },
    ]

    template_print = '''
{comment}
>>>
{sql}
'''
    sqls = map(lambda v: template_print.format(**v), sqls)
    s = _join_string(sqls, glue='\n')

    # The printed header mirrors the User model defined above.
    template = '''
class User(Q):
    __tablename__ = 'user'
    name = CharField()
    age = IntegerField()
    gender = CharField()
    phone = CharField()

form = dict(
    name='sen',
    age=18,
)

{}
'''.format(s)
    print(template)


# ---------------------------------------------------------------------------
# /q/__init__.py
# ---------------------------------------------------------------------------
from q.db import QSQLite


class Q(object):
    """Base class of table models and entry point of the query builder.

    A model subclasses ``Q``, sets ``__tablename__`` and declares ``Field``
    instances as class attributes.  The builder class methods queue ``Node``
    objects on the class and return the class so calls can be chained;
    ``execute`` or ``generate_sql`` consumes the queue.
    """

    # Nodes queued by the builder methods for the statement being built.
    expression_list = []
    db = QSQLite('test.db')

    def __init__(self, form):
        """Copy every key/value pair of *form* onto the instance."""
        for k, v in form.items():
            setattr(self, k, v)

    @classmethod
    def _push(cls, op, *args):
        """Queue ``Node(op, *args)`` on this class and return the class."""
        # Give each model its own list the first time it queues something,
        # so models never append to the list they inherit from a base class.
        if 'expression_list' not in cls.__dict__:
            cls.expression_list = []
        cls.expression_list.append(Node(op, *args))
        return cls

    @classmethod
    def table_name(cls):
        """Return the model's ``__tablename__``."""
        return cls.__tablename__

    @classmethod
    def field(cls):
        """Return the mapping used by ``SQL`` to resolve names.

        The dict holds ``'tablename'`` plus one entry per ``Field`` class
        attribute, keyed by the field's ``uuid()`` with the attribute name
        as value.
        """
        m = {
            'tablename': cls.__tablename__
        }
        for k in dir(cls):
            f = getattr(cls, k)
            if isinstance(f, Field):
                m[f.uuid()] = k
        return m

    @classmethod
    def select(cls, *keys):
        """Queue a select of *keys* (fields or ``Function``s; none = ``*``)."""
        return cls._push('select', *keys)

    @classmethod
    def update(cls, *keys):
        """Queue an update; the first argument is a dict of column values."""
        return cls._push('update', *keys)

    @classmethod
    def insert(cls):
        """Placeholder: does nothing and returns None (use ``new``)."""
        pass

    @classmethod
    def delete(cls, *keys):
        """Queue a delete on the model's table."""
        return cls._push('delete', *keys)

    @classmethod
    def where(cls, *keys):
        """Queue a where clause; only the first expression is rendered."""
        return cls._push('where', *keys)

    @classmethod
    def order_by(cls, *keys):
        """Queue an ``order by`` over the given fields."""
        return cls._push('order_by', *keys)

    @classmethod
    def group_by(cls, *keys):
        """Queue a ``group by`` over the given fields."""
        return cls._push('group_by', *keys)

    @classmethod
    def range(cls, *keys):
        """Queue a limit: ``range(n)`` or ``range(offset, n)``."""
        return cls._push('range', *keys)

    @classmethod
    def execute(cls):
        """Run the queued statement and wrap each returned row in ``cls``."""
        sql = cls.generate_sql()
        data = cls.db.execute_sql(sql)
        return [cls(v) for v in data]

    @classmethod
    def generate_sql(cls):
        """Return the SQL text of the queued statement and clear the queue."""
        try:
            return SQL(cls.expression_list, cls.field()).generate()
        finally:
            # Clear even when generation fails, so a broken query cannot
            # leak its nodes into the next statement.
            cls.expression_list = []

    @classmethod
    def new(cls, form):
        """Validate *form* against the model's fields and queue an insert.

        Raises:
            AttributeError: a key of *form* is not a ``Field`` of the model.
            Exception: a value is rejected by its field's ``validate``.
        """
        for k, v in form.items():
            attr = getattr(cls, k, None)
            if not isinstance(attr, Field):
                raise AttributeError(
                    'key: {} is not a field of {}'.format(k, cls.__name__))
            if attr.validate(v) is False:
                raise Exception('key: {} 类型错误。value: {}'.format(k, v))
        return cls._push('insert', form)


class T(object):
    """Factory of ``Function`` wrappers for the select list."""

    @classmethod
    def count(cls, node):
        """Wrap *node* as ``count(<column>)``."""
        return Function(node, 'count({})')

    @classmethod
    def distinct(cls, node):
        """Wrap *node* as ``distinct <column>``."""
        return Function(node, 'distinct {}')

    @classmethod
    def max(cls, node):
        """Wrap *node* as ``max(<column>)``."""
        return Function(node, 'max({})')

    @classmethod
    def alias(cls, node, name):
        """Append ``as '<name>'`` to a field or to an existing ``Function``.

        A ``Function`` argument is modified in place and returned; a plain
        field is wrapped in a new ``Function``.
        """
        if isinstance(node, Function):
            node.func = '{} as {}'.format(node.func, _escape(name))
            return node
        return Function(node, '{} as ' + _escape(name))


class Node(object):
    """One queued builder call: an operation name and its arguments."""

    def __init__(self, op, *args):
        self.op = op
        self.args = args


class Function(object):
    """A field (``n``) plus a format pattern (``func``) applied to its name."""

    def __init__(self, node, func):
        self.n = node
        self.func = func


class SQL(object):
    """Render a list of ``Node`` objects into one SQL string."""

    def __init__(self, expressions, mapper):
        """Store the nodes and the mapping produced by ``Q.field``."""
        self.expressions = expressions
        self.mapper = mapper

    def generate(self):
        """Render every node in queue order and join the parts with spaces.

        Nodes whose ``op`` is not known are skipped.
        """
        handlers = {
            'select': self._generate_select,
            'where': self._generate_where,
            'insert': self._generate_insert,
            'update': self._generate_update,
            'delete': self._generate_delete,
            'order_by': self._generate_order_by,
            'group_by': self._generate_group_by,
            'range': self._generate_limit,
        }
        sql = []
        for node in self.expressions:
            handler = handlers.get(node.op)
            if handler is not None:
                sql.append(handler(node.args))
        return _join_string(sql, ' ')

    def _column(self, field):
        """Return the column name registered for *field* in the mapper."""
        return self.mapper.get(field.uuid())

    def _generate_select(self, nodes):
        """Render ``select * from t`` or ``select a, b from t``."""
        tablename = self.mapper.get('tablename')
        if len(nodes) == 0:
            return SQLPattern.select_all.format(tablename)
        fields = []
        for node in nodes:
            if isinstance(node, Function):
                fields.append(node.func.format(self._column(node.n)))
            else:
                fields.append(self._column(node))
        return SQLPattern.select_multi.format(_join_string(fields), tablename)

    def _generate_where(self, node):
        """Render a where clause or, recursively, one condition of it.

        *node* is the argument tuple of the ``where`` call at the top level
        and an ``Expression`` or compared ``Field`` below it.
        """
        if isinstance(node, tuple):
            if len(node) == 0:
                return SQLPattern.where_no
            return SQLPattern.where_multi.format(self._generate_where(node[0]))
        if node.op is None:
            # An Expression without operator only wraps its first operand.
            return self._generate_where(node.f1)
        if isinstance(node, Field):
            return '{} {} {}'.format(
                self._column(node), node.op, _escape(node.o))
        left = self._operand(node.f1)
        right = self._operand(node.f2)
        return '{} {} {}'.format(left, node.op, right)

    def _operand(self, node):
        """Render one side of and/or, parenthesised when itself and/or."""
        rendered = self._generate_where(node)
        # Parentheses keep the grouping written in Python; SQL would
        # otherwise bind `and` tighter than `or`.
        if self._is_compound(node):
            return '({})'.format(rendered)
        return rendered

    @staticmethod
    def _is_compound(node):
        """Tell whether *node* renders as an and/or combination."""
        while not isinstance(node, Field) and node.op is None:
            node = node.f1
        return not isinstance(node, Field)

    def _generate_insert(self, node):
        """Render ``insert into t (cols) values (vals)`` from the form dict."""
        form = node[0]
        tablename = self.mapper.get('tablename')
        fields = _join_string(form.keys())
        values = _join_string(map(_escape, form.values()))
        return SQLPattern.insert.format(tablename, fields, values)

    def _generate_update(self, node):
        """Render ``update t set a = 'x', b = 'y'`` from the form dict."""
        form = node[0]
        tablename = self.mapper.get('tablename')
        es = [SQLPattern.set_element.format(k, _escape(v))
              for k, v in form.items()]
        sub_sql = [SQLPattern.update.format(tablename),
                   SQLPattern.set.format(_join_string(es))]
        return _join_string(sub_sql, ' ')

    def _generate_delete(self, node):
        """Render ``delete from t``."""
        tablename = self.mapper.get('tablename')
        return SQLPattern.delete.format(tablename)

    def _generate_order_by(self, node):
        """Render ``order by a, b``; no fields renders nothing."""
        if not node:
            return SQLPattern.empty
        columns = _join_string(self._column(f) for f in node)
        return SQLPattern.order_by.format(columns)

    def _generate_group_by(self, node):
        """Render ``group by a, b``; no fields renders nothing."""
        if not node:
            return SQLPattern.empty
        columns = _join_string(self._column(f) for f in node)
        return SQLPattern.group_by.format(columns)

    def _generate_limit(self, node):
        """Render ``limit 0,n`` for one argument, ``limit a,b`` for two."""
        if len(node) == 0:
            return SQLPattern.empty
        elif len(node) == 1:
            return SQLPattern.limit.format(0, node[0])
        else:
            return SQLPattern.limit.format(node[0], node[1])


class Expression(object):
    """A condition tree node.

    With ``op`` unset the node only wraps ``f1``; otherwise it combines
    ``f1`` and ``f2`` with ``op`` (``'and'`` or ``'or'``).
    """

    def __init__(self, f1):
        self.f1 = f1
        self.f2 = None
        self.op = None

    def _combine(self, op, other):
        """Return a new node joining this expression and *other* with *op*."""
        # Build a fresh node rather than writing into self, so an
        # expression can be reused in several combinations.
        node = Expression(self)
        node.f2 = other
        node.op = op
        return node

    def __and__(self, o):
        return self._combine('and', o)

    def __or__(self, o):
        return self._combine('or', o)


class Field(object):
    """A table column; comparing it yields an ``Expression`` for ``where``."""

    def __init__(self, k=None):
        self.k = k
        self.o = None
        self.op = None

    def _compare(self, op, other):
        """Return an ``Expression`` holding this column, *op* and *other*.

        The comparison is recorded on a copy of the field, so the field
        declared on the model is left untouched and the same column can
        appear several times in one condition.
        """
        bound = self.__class__.__new__(self.__class__)
        bound.__dict__.update(self.__dict__)
        # The copy must resolve to the same column name as the original.
        bound.__dict__['_uuid'] = self.uuid()
        bound.op = op
        bound.o = other
        return Expression(bound)

    def __eq__(self, o):
        return self._compare('=', o)

    def __ne__(self, o):
        return self._compare('!=', o)

    def __lt__(self, o):
        return self._compare('<', o)

    def __le__(self, o):
        return self._compare('<=', o)

    def __gt__(self, o):
        return self._compare('>', o)

    def __ge__(self, o):
        return self._compare('>=', o)

    def uuid(self):
        """Return the key under which ``Q.field`` registers this column."""
        return self.__dict__.get('_uuid', id(self))

    def validate(self, v):
        """Type check hook; the base class does not implement one."""
        return NotImplemented


class CharField(Field):
    """Column accepting ``str`` values."""

    def validate(self, v):
        return isinstance(v, str)


class IntegerField(Field):
    """Column accepting ``int`` values."""

    def validate(self, v):
        return isinstance(v, int)


def _join_string(_list, glue=', '):
    """Join the ``str()`` of every item of *_list* with *glue*."""
    return glue.join(map(str, _list))


def _escape(s, template="'{}'"):
    """Return *s* as a single-quoted SQL literal.

    Embedded single quotes are doubled so the value cannot terminate the
    literal early.
    """
    return template.format(str(s).replace("'", "''"))


class SQLPattern(object):
    """Format strings for the SQL fragments emitted by ``SQL``."""

    select_all = 'select * from {}'
    select_multi = 'select {} from {}'
    where_multi = 'where {}'
    where_no = ''
    insert = 'insert into {} ({}) values ({})'
    update = 'update {}'
    set = 'set {}'
    set_element = '{} = {}'
    delete = 'delete from {}'
    empty = ''
    limit = 'limit {},{}'
    order_by = 'order by {}'
    group_by = 'group by {}'