├── .gitignore ├── testproject ├── __init__.py ├── dreltest │ ├── __init__.py │ ├── models.py │ └── tests.py ├── setup.py ├── settings.py ├── urls.py └── manage.py ├── drel ├── __init__.py ├── compiler.py ├── porcelain.py └── ast.py ├── setup.py └── README.markdown /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | -------------------------------------------------------------------------------- /testproject/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /testproject/dreltest/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /drel/__init__.py: -------------------------------------------------------------------------------- 1 | from drel.porcelain import * 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | 3 | setup( 4 | name='drel', 5 | version='0.0.3', 6 | author='Kevin Mahoney', 7 | author_email='git@kevinmahoney.co.uk', 8 | packages=['drel'], 9 | ) 10 | -------------------------------------------------------------------------------- /testproject/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | 3 | setup( 4 | name='drel', 5 | version='0.1', 6 | author='Kevin Mahoney', 7 | author_email='git@kevinmahoney.co.uk', 8 | packages=['drel'], 9 | ) 10 | -------------------------------------------------------------------------------- /testproject/settings.py: -------------------------------------------------------------------------------- 1 | DATABASES = { 2 | 'default': { 3 | 'ENGINE': 'django.db.backends.sqlite3', 4 
| 'NAME': 'test.db', 5 | 'USER': '', 6 | 'PASSWORD': '', 7 | 'HOST': '', 8 | 'PORT': '', 9 | } 10 | } 11 | 12 | INSTALLED_APPS = [ 13 | 'dreltest', 14 | ] 15 | -------------------------------------------------------------------------------- /testproject/urls.py: -------------------------------------------------------------------------------- 1 | from django.conf.urls.defaults import patterns, include, url 2 | 3 | # Uncomment the next two lines to enable the admin: 4 | # from django.contrib import admin 5 | # admin.autodiscover() 6 | 7 | urlpatterns = patterns('', 8 | # Examples: 9 | # url(r'^$', 'drel.views.home', name='home'), 10 | # url(r'^drel/', include('drel.foo.urls')), 11 | 12 | # Uncomment the admin/doc line below to enable admin documentation: 13 | # url(r'^admin/doc/', include('django.contrib.admindocs.urls')), 14 | 15 | # Uncomment the next line to enable the admin: 16 | # url(r'^admin/', include(admin.site.urls)), 17 | ) 18 | -------------------------------------------------------------------------------- /testproject/manage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from django.core.management import execute_manager 3 | import imp 4 | try: 5 | imp.find_module('settings') # Assumed to be in the same directory. 6 | except ImportError: 7 | import sys 8 | sys.stderr.write("Error: Can't find the file 'settings.py' in the directory containing %r. 
It appears you've customized things.\nYou'll have to run django-admin.py, passing it your settings module.\n" % __file__) 9 | sys.exit(1) 10 | 11 | import settings 12 | import sys 13 | import os 14 | 15 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 16 | 17 | if __name__ == "__main__": 18 | execute_manager(settings) 19 | -------------------------------------------------------------------------------- /testproject/dreltest/models.py: -------------------------------------------------------------------------------- 1 | from django.db import models 2 | 3 | 4 | class TestModel1(models.Model): 5 | a = models.CharField(max_length=50) 6 | b = models.IntegerField() 7 | 8 | 9 | class TestModel2(models.Model): 10 | m1 = models.ForeignKey(TestModel1) 11 | c = models.IntegerField() 12 | 13 | 14 | class TestM2M(models.Model): 15 | a = models.CharField(max_length=50) 16 | m2s = models.ManyToManyField(TestModel2) 17 | 18 | 19 | class BlogUser(models.Model): 20 | username = models.CharField(max_length=200) 21 | 22 | 23 | class BlogPost(models.Model): 24 | title = models.CharField(max_length=200) 25 | body = models.TextField() 26 | published = models.DateTimeField(auto_now_add=True) 27 | user = models.ForeignKey(BlogUser) 28 | -------------------------------------------------------------------------------- /drel/compiler.py: -------------------------------------------------------------------------------- 1 | class Compiler(object): 2 | ''' 3 | A class that maintains state during the compilation of SQL, as the 4 | AST nodes themselves are stateless and immutable. 5 | 6 | ''' 7 | def __init__(self, connection): 8 | self._aliases = {} 9 | self.values = [] 10 | self._connection = connection 11 | 12 | def q(self, s): 13 | '''Quote a name.''' 14 | return self._connection.ops.quote_name(s) 15 | 16 | def refer(self, obj): 17 | ''' 18 | Return the quoted name of an object, creating a new name if it 19 | hasn't been referred to before. 
20 | 21 | ''' 22 | try: 23 | return self._aliases[obj] 24 | except KeyError: 25 | alias = self.q("t%d" % len(self._aliases)) 26 | self._aliases[obj] = alias 27 | return alias 28 | -------------------------------------------------------------------------------- /drel/porcelain.py: -------------------------------------------------------------------------------- 1 | ''' 2 | User facing top-level functions. 3 | 4 | ''' 5 | from drel.ast import ( 6 | DjangoTable, DjangoM2MTable, Const, 7 | FunctionExpression, LabelReference, RawExpression) 8 | from django.db.models.base import ModelBase 9 | from django.db.models.fields.related import ReverseManyRelatedObjectsDescriptor 10 | 11 | 12 | def table(t): 13 | ''' 14 | Create a DRel table from a Django Model or many-to-many foreign 15 | key. 16 | 17 | ''' 18 | if isinstance(t, ReverseManyRelatedObjectsDescriptor): 19 | return DjangoM2MTable(t.field) 20 | assert isinstance(t, ModelBase), "Expected Django model." 21 | return DjangoTable(t) 22 | 23 | 24 | def const(c): 25 | '''A constant SQL value. 
Escaped by the database engine.''' 26 | return Const(c) 27 | 28 | 29 | def label(l): 30 | '''A reference to a labelled field or expression.''' 31 | return LabelReference(l) 32 | 33 | 34 | def raw_expr(expr): 35 | '''A raw SQL expression.''' 36 | return RawExpression(expr) 37 | 38 | 39 | def fn(name, *args): 40 | '''Apply a SQL function.''' 41 | return FunctionExpression(name, *args) 42 | 43 | 44 | def sum(arg): 45 | '''Sum aggregate function.''' 46 | return fn("SUM", arg) 47 | 48 | 49 | def avg(arg): 50 | '''Average aggregate function.''' 51 | return fn("AVG", arg) 52 | 53 | 54 | def max(arg): 55 | '''Max aggregate function.''' 56 | return fn("MAX", arg) 57 | 58 | 59 | def min(arg): 60 | '''Min aggregate function.''' 61 | return fn("MIN", arg) 62 | 63 | 64 | def count(arg=raw_expr("*")): 65 | '''Count aggregate function.''' 66 | return fn("COUNT", arg) 67 | -------------------------------------------------------------------------------- /README.markdown: -------------------------------------------------------------------------------- 1 | # DRel 2 | 3 | Relational algebra for Django, a la ARel, SQLAlchemy. Because 4 | sometimes you need a left join. 5 | 6 | This (as of writing) is a young project. It probably has bugs. The 7 | interface may also change drastically between versions. 8 | 9 | 10 | ## Why not just use SQLAlchemy? 11 | 12 | I highly recommend using SQLAlchemy! However, you may find this 13 | project useful if the decision to use the Django ORM is out of your 14 | control, or if you like the benefits of sticking with the Django ORM 15 | (admin etc.) but would occasionally like to dip down into something 16 | lower level that it can't express easily. 
17 | 18 | 19 | ## Basic use 20 | 21 | Given the following models: 22 | 23 | class BlogUser(models.Model): 24 | username = models.CharField(max_length=200) 25 | 26 | class BlogPost(models.Model): 27 | title = models.CharField(max_length=200) 28 | body = models.TextField() 29 | published = models.DateTimeField(auto_now_add=True) 30 | user = models.ForeignKey(BlogUser) 31 | 32 | First, create a wrapper around them: 33 | 34 | import drel as d 35 | user = d.table(BlogUser) 36 | post = d.table(BlogPost) 37 | 38 | All DRel constructs are *immutable* and *stateless*, and so can be 39 | used at the module level. 40 | 41 | Table fields can be referenced using either their Django model field name or their 42 | database column name. 43 | 44 | The basic way to construct a query from a DRel table is to use the 45 | `.project(*expressions)`, `.join(table, expression)`, 46 | `.leftjoin(table, expression)`, `.crossjoin(table)`, 47 | `.where(expression)`, `.group(*expressions)`, `.order(*expressions)`, 48 | `.limit(n)` and `.offset(n)` methods. Select queries can themselves be used as an expression or 49 | table using `.subquery`. 50 | 51 | All fields and expressions in the `.project` list must have a name -- 52 | expressions support `.label(name)` to give them one. Labelled 53 | expressions can be referred to in other parts of the query using 54 | `d.label(name)`. 55 | 56 | Expressions overload Python operators to support comparison, arithmetic, AND 57 | (`&`) and OR (`|`). Because `&` and `|` bind more tightly than comparisons in Python, wrap each comparison in parentheses, e.g. `(a == b) & (c < d)`. 58 | 59 | Insert values into your queries with `d.const(value)`. 60 | 61 | Evaluate your queries with `.all()` or `.one()`. `.all()` returns a 62 | generator yielding named tuples; `.one()` returns a single named tuple. 63 | 64 | 65 | ## Examples 66 | 67 | Some example queries using the Blog models.
68 | 69 | # All titles 70 | post.project(post.title).all() 71 | 72 | # All posts with usernames 73 | (post.join(user, user.id == post.user_id) 74 | .project(post.title, user.username) 75 | .order(post.published.desc) 76 | .all()) 77 | 78 | # Post counts 79 | (user.leftjoin(post, post.user_id == user.id) 80 | .group(user.username) 81 | .project(user.username, d.count(post.id).label("postcount")) 82 | .all()) 83 | 84 | # Total posts 85 | post.project(d.count().label("total")).one() 86 | 87 | # Latest post titles for all users 88 | # Use a second table wrapper for the self join 89 | post2 = d.table(BlogPost) 90 | 91 | (user.leftjoin(post, post.user == user.id) 92 | .leftjoin(post2, (post2.user == user.id) & (post.published < post2.published)) 93 | .where(post2.user.is_null) 94 | .project(user.username, post.title) 95 | .all()) 96 | 97 | 98 | ## TODO 99 | 100 | * Missing SQL expressiveness 101 | * Inserts 102 | * Updates 103 | * Documentation 104 | * More tests 105 | 106 | 107 | ## License 108 | 109 | Copyright (c) 2011, Kevin Mahoney 110 | 111 | All rights reserved. 112 | 113 | Redistribution and use in source and binary forms, with or without 114 | modification, are permitted provided that the following conditions are 115 | met: 116 | 117 | * Redistributions of source code must retain the above copyright notice, 118 | this list of conditions and the following disclaimer. 119 | 120 | * Redistributions in binary form must reproduce the above copyright 121 | notice, this list of conditions and the following disclaimer in the 122 | documentation and/or other materials provided with the distribution. 123 | 124 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 125 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 126 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 127 | A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT 128 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 129 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 130 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 131 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 132 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 133 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 134 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 135 | -------------------------------------------------------------------------------- /testproject/dreltest/tests.py: -------------------------------------------------------------------------------- 1 | from django.test import TestCase 2 | 3 | from dreltest.models import BlogUser, BlogPost 4 | from dreltest.models import TestModel1, TestModel2, TestM2M 5 | import drel as d 6 | 7 | 8 | class BlogTest(TestCase): 9 | def setUp(self): 10 | for i in range(10): 11 | u = BlogUser.objects.create(username="u%d" % i) 12 | for p in range(i): 13 | BlogPost.objects.create(user=u, title="p%d" % p, body="Test") 14 | 15 | def test_postcount(self): 16 | user = d.table(BlogUser) 17 | post = d.table(BlogPost) 18 | 19 | q = (user 20 | .leftjoin(post, post.user_id == user.id) 21 | .group(user.username) 22 | .project(user.username, d.count(post.id).label("posts"))) 23 | 24 | counts = dict(q.all()) 25 | 26 | for i in range(10): 27 | self.assertEqual(i, counts["u%d" % i]) 28 | 29 | def test_latesttitle(self): 30 | user = d.table(BlogUser) 31 | post1 = d.table(BlogPost) 32 | post2 = d.table(BlogPost) 33 | 34 | # An interesting query that's difficult in the ORM 35 | q = (user 36 | .leftjoin(post1, post1.user == user.id) 37 | .leftjoin(post2, 38 | (post2.user == user.id) & 39 | (post1.id < post2.id)) 40 | .where(post2.user.is_null) 41 | .project(user.username, post1.title)) 42 | 43 | latest = dict(q.all()) 44 | 45 | # u0 has no posts, so no 
latest post 46 | self.assertEqual(None, latest["u0"]) 47 | 48 | for i in range(1, 10): 49 | self.assertEqual("p%d" % (i - 1), latest["u%d" % i]) 50 | 51 | # The subquery way of doing the same thing 52 | q1 = (post1 53 | .group(post1.user) 54 | .project(post1.user, d.max(post1.id).label("post_id")) 55 | .subquery) 56 | q2 = (user 57 | .leftjoin(q1, user.id == q1.user) 58 | .leftjoin(post1, post1.id == q1.post_id) 59 | .project(user.username, post1.title)) 60 | 61 | latest2 = dict(q2.all()) 62 | self.assertEqual(latest, latest2) 63 | 64 | 65 | class DrelTest(TestCase): 66 | def setUp(self): 67 | x = TestModel1.objects.create(a="x", b=1) 68 | TestModel2.objects.create(m1=x, c=1) 69 | TestModel2.objects.create(m1=x, c=2) 70 | TestModel2.objects.create(m1=x, c=3) 71 | y = TestModel1.objects.create(a="y", b=2) 72 | TestModel2.objects.create(m1=y, c=4) 73 | t1 = TestModel2.objects.create(m1=y, c=5) 74 | t2 = TestModel2.objects.create(m1=y, c=6) 75 | 76 | m1 = TestM2M.objects.create(a="m2m1") 77 | TestM2M.objects.create(a="m2m2") 78 | 79 | m1.m2s.add(t1) 80 | m1.m2s.add(t2) 81 | 82 | def test_simple(self): 83 | t2 = d.table(TestModel2) 84 | 85 | r = t2.where(t2.c > d.const(3)).project(t2.c).all() 86 | self.assertEqual(3, len(list(r))) 87 | 88 | r = t2.where(t2.c > d.const(3) + d.const(1)).project(t2.c).all() 89 | self.assertEqual(2, len(list(r))) 90 | 91 | def test_cross(self): 92 | t1 = d.table(TestModel1) 93 | t2 = d.table(TestModel2) 94 | 95 | results = t1.crossjoin(t2).project(t1.a, t1.b, t2.c).all() 96 | 97 | def _count(x): 98 | return x.objects.all().count() 99 | 100 | self.assertEqual( 101 | _count(TestModel1) * _count(TestModel2), 102 | len(list(results))) 103 | 104 | def test_agg(self): 105 | t1 = d.table(TestModel1) 106 | t2 = d.table(TestModel2) 107 | 108 | total = t2.project(d.sum(t2.c).label("total")).one() 109 | self.assertEqual(6 + 5 + 4 + 3 + 2 + 1, total.total) 110 | 111 | grouped = list( 112 | t2 113 | .leftjoin(t1, t1.id == t2.m1) 114 | .group(t1.a) 115 
| .project(t1.a, d.sum(t2.c).label("total")) 116 | .order(d.label("total").desc) 117 | .all()) 118 | 119 | self.assertEqual("y", grouped[0].a) 120 | self.assertEqual(6 + 5 + 4, grouped[0].total) 121 | self.assertEqual("x", grouped[1].a) 122 | self.assertEqual(3 + 2 + 1, grouped[1].total) 123 | 124 | def test_order(self): 125 | t2 = d.table(TestModel2) 126 | 127 | r = list(t2.project(t2.c).order(t2.c).all()) 128 | self.assertEqual(1, r[0].c) 129 | 130 | r = list(t2.project(t2.c).order(t2.c.desc).all()) 131 | self.assertEqual(6, r[0].c) 132 | 133 | def test_select_expr(self): 134 | t2 = d.table(TestModel2) 135 | 136 | a = t2.project((d.max(t2.c) - d.const(2)).label("total")).subquery 137 | r = t2.project(t2.c).where(t2.c > a).all() 138 | self.assertEqual(2, len(list(r))) 139 | 140 | def test_derived_table(self): 141 | t1 = d.table(TestModel1) 142 | t2 = d.table(TestModel2) 143 | 144 | a = (t2 145 | .project(t2.m1, d.min(t2.c).label("m")) 146 | .group(t2.m1) 147 | .subquery) 148 | 149 | b = (t1 150 | .project(t1.a, a.m) 151 | .join(a, a.m1 == t1.id) 152 | .order(a.m)) 153 | 154 | r = list(b.all()) 155 | self.assertEqual(2, len(r)) 156 | self.assertEqual(1, r[0].m) 157 | self.assertEqual(4, r[1].m) 158 | 159 | def test_count(self): 160 | t1 = d.table(TestModel1) 161 | t2 = d.table(TestModel2) 162 | 163 | a = t1.project(d.count().label("total")).one() 164 | self.assertEqual(2, a.total) 165 | 166 | b = (t1 167 | .leftjoin(t2, t2.m1_id == t1.id) 168 | .group(t1.a) 169 | .project(t1.a, d.count().label("count"))) 170 | 171 | for i in b.all(): 172 | self.assertEqual(3, i.count) 173 | 174 | def test_m2m(self): 175 | t2 = d.table(TestModel2) 176 | tm = d.table(TestM2M) 177 | tmjoin = d.table(TestM2M.m2s) 178 | 179 | a = (tm 180 | .project(tm.a, t2.c) 181 | .join(tmjoin, tmjoin.testm2m == tm.id) 182 | .join(t2, tmjoin.testmodel2 == t2.id) 183 | .order(t2.c)) 184 | 185 | l = list(a.all()) 186 | self.assertEqual(2, len(l)) 187 | self.assertEqual(5, l[0].c) 188 | 
self.assertEqual(6, l[1].c) 189 | 190 | def test_limit(self): 191 | t2 = d.table(TestModel2) 192 | self.assertEqual(1, len(list(t2.project(t2.c).limit(1).all()))) 193 | self.assertEqual(2, len(list(t2.project(t2.c).limit(2).all()))) 194 | 195 | def test_offset(self): 196 | t2 = d.table(TestModel2) 197 | test2 = list(t2.order(t2.c).limit(1).offset(1).project(t2.c).all()) 198 | self.assertEqual(1, len(test2)) 199 | self.assertEqual(2, test2[0].c) 200 | -------------------------------------------------------------------------------- /drel/ast.py: -------------------------------------------------------------------------------- 1 | ''' 2 | The classes in this file are nodes of a tree that represent an SQL 3 | query. Each node should be both stateless and immutable -- that is, 4 | methods should return a new AST instead of modifying an existing 5 | one. This means expressions can be reused and composed in a similar 6 | fashion to Django querysets. 7 | 8 | The SQL statement is compiled by the various `_compile` 9 | methods. Different `_compile` methods are used for different contexts 10 | to help make sure nonsense SQL queries are not generated. We can 11 | probably give more informative error messages in this way than the 12 | database can. 13 | 14 | A `Compiler` object is passed through the `_compile` methods to keep 15 | track of any state. Currently escaped values are appended to a list 16 | which means the nodes have to be compiled in the order they appear in 17 | the SQL query (so the order of the '%s' values match up). Consider a 18 | more elegant solution to be on the TODO list. 19 | 20 | The `Compiler` object is also responsible for keeping track of table 21 | aliasing, making sure they get unique names. This allows self joins, 22 | for example. 
23 | 24 | ''' 25 | from collections import namedtuple 26 | 27 | from django.db import connections 28 | 29 | from drel.compiler import Compiler 30 | 31 | 32 | class InvalidQuery(Exception): 33 | pass 34 | 35 | 36 | class AST(object): 37 | '''Base class for AST nodes.''' 38 | 39 | # Stub the various _compile interfaces with more useful error 40 | # messages. 41 | 42 | def _compile(self, compiler): 43 | raise InvalidQuery("%s is not a statement" % self) 44 | 45 | def _compile_expression(self, compiler): 46 | raise InvalidQuery("%s is not an expression" % self) 47 | 48 | def _compile_projection(self, compiler): 49 | raise InvalidQuery("%s is not labeled" % self) 50 | 51 | def _compile_join(self, compiler): 52 | raise InvalidQuery("%s is not a join" % self) 53 | 54 | def _compile_table(self, compiler): 55 | raise InvalidQuery("%s is not a table" % self) 56 | 57 | # A neater (but less informative) representation of a node. 58 | # Override where suitable. 59 | 60 | def __repr__(self): 61 | return self.__class__.__name__ 62 | 63 | 64 | class ExpressionMixin(object): 65 | '''Provides operator overriding and labelling for expressions.''' 66 | 67 | def label(self, name): 68 | return LabeledProjection(name, self) 69 | 70 | @property 71 | def desc(self): 72 | return DescendingExpression(self) 73 | 74 | @property 75 | def is_null(self): 76 | return BinaryExpression("IS", self, RawExpression("NULL")) 77 | 78 | @property 79 | def is_not_null(self): 80 | return BinaryExpression("IS NOT", self, RawExpression("NULL")) 81 | 82 | def __eq__(self, other): 83 | return BinaryExpression("=", self, other) 84 | 85 | def __ne__(self, other): 86 | return BinaryExpression("<>", self, other) 87 | 88 | def __lt__(self, other): 89 | return BinaryExpression("<", self, other) 90 | 91 | def __le__(self, other): 92 | return BinaryExpression("<=", self, other) 93 | 94 | def __gt__(self, other): 95 | return BinaryExpression(">", self, other) 96 | 97 | def __ge__(self, other): 98 | return 
BinaryExpression(">=", self, other) 99 | 100 | def __or__(self, other): 101 | return BinaryExpression("OR", self, other) 102 | 103 | def __and__(self, other): 104 | return BinaryExpression("AND", self, other) 105 | 106 | def __add__(self, other): 107 | return BinaryExpression("+", self, other) 108 | 109 | def __sub__(self, other): 110 | return BinaryExpression("-", self, other) 111 | 112 | def __mul__(self, other): 113 | return BinaryExpression("*", self, other) 114 | 115 | def __mod__(self, other): 116 | return BinaryExpression("%", self, other) 117 | 118 | 119 | class TableMixin(object): 120 | '''Basic table operations.''' 121 | 122 | def project(self, *fields): 123 | return Select(self, project=fields) 124 | 125 | def join(self, table, on): 126 | return Select(self, joins=[Join(table, on)]) 127 | 128 | def leftjoin(self, table, on): 129 | return Select(self, joins=[Join(table, on, "LEFT")]) 130 | 131 | def crossjoin(self, table): 132 | return Select(self, joins=[CrossJoin(table)]) 133 | 134 | def where(self, expr): 135 | return Select(self, where=expr) 136 | 137 | def group(self, *fields): 138 | return Select(self, group=fields) 139 | 140 | def order(self, *fields): 141 | return Select(self, order=fields) 142 | 143 | def limit(self, limit): 144 | return Select(self, limit=limit) 145 | 146 | def offset(self, offset): 147 | return Select(self, offset=offset) 148 | 149 | 150 | class DescendingExpression(AST): 151 | ''' 152 | A wrapper around an expression, used to make it descending in 153 | an ORDER BY expression list. 154 | 155 | Cannot be further manipulated as an expression. 156 | 157 | ''' 158 | def __init__(self, expr): 159 | self._expr = expr 160 | 161 | def _compile_expression(self, compiler): 162 | expr = self._expr._compile_expression(compiler) 163 | return "%s DESC" % expr 164 | 165 | 166 | class LabeledProjection(AST): 167 | ''' 168 | An expression that has been assigned a label. 169 | 170 | Cannot be further manipulated as an expression. 
171 | 172 | ''' 173 | def __init__(self, label, expr): 174 | self.row_key = label 175 | self._expr = expr 176 | 177 | def _compile_projection(self, compiler): 178 | expr = self._expr._compile_expression(compiler) 179 | return "%s AS %s" % (expr, compiler.q(self.row_key)) 180 | 181 | 182 | class BinaryExpression(AST, ExpressionMixin): 183 | def __init__(self, op, a, b): 184 | self._op = op 185 | self._a = a 186 | self._b = b 187 | 188 | def _compile_expression(self, compiler): 189 | a = self._a._compile_expression(compiler) 190 | b = self._b._compile_expression(compiler) 191 | return "%s %s %s" % (a, self._op, b) 192 | 193 | 194 | class Const(AST, ExpressionMixin): 195 | '''A value to be escaped by the database engine.''' 196 | 197 | def __init__(self, value, alias=None): 198 | self._value = value 199 | 200 | def _compile_expression(self, compiler): 201 | # Compiler state is used here. This is why the SQL statement 202 | # has to be built in order: so the placeholders match up with 203 | # the list of values. 
204 | compiler.values.append(self._value) 205 | return "%s" 206 | 207 | 208 | class RawExpression(AST, ExpressionMixin): 209 | '''Pass through a string directly to the compiled SQL.''' 210 | 211 | def __init__(self, sql): 212 | self._sql = sql 213 | 214 | def _compile_expression(self, compiler): 215 | return self._sql 216 | 217 | 218 | class LabelReference(AST, ExpressionMixin): 219 | '''A reference to a labelled field/expression.''' 220 | 221 | def __init__(self, label): 222 | self._label = label 223 | 224 | def _compile_expression(self, compiler): 225 | return compiler.q(self._label) 226 | 227 | 228 | class FunctionExpression(AST, ExpressionMixin): 229 | '''SQL function application.''' 230 | 231 | def __init__(self, fn, *args): 232 | self._fn = fn 233 | self._args = args 234 | 235 | def _compile_expression(self, compiler): 236 | args = ",".join(a._compile_expression(compiler) for a in self._args) 237 | return "%s(%s)" % (self._fn, args) 238 | 239 | 240 | class Field(AST, ExpressionMixin): 241 | def __init__(self, table, column, label=None): 242 | self._table = table 243 | self._column = column 244 | self.row_key = label or column 245 | 246 | def _compile_expression(self, compiler): 247 | alias = compiler.refer(self._table) 248 | column = compiler.q(self._column) 249 | return "%s.%s" % (alias, column) 250 | 251 | def _compile_projection(self, compiler): 252 | label = compiler.q(self.row_key) 253 | expr = self._compile_expression(compiler) 254 | return "%s AS %s" % (expr, label) 255 | 256 | 257 | class Join(AST): 258 | def __init__(self, table, on, kind="INNER"): 259 | self._table = table 260 | self._on = on 261 | self._kind = kind 262 | 263 | def _compile_join(self, compiler): 264 | table = self._table._compile_table(compiler) 265 | on_expr = self._on._compile_expression(compiler) 266 | return "%s JOIN %s ON %s" % (self._kind, table, on_expr) 267 | 268 | 269 | class CrossJoin(AST): 270 | def __init__(self, table): 271 | self._table = table 272 | 273 | def 
_compile_join(self, compiler): 274 | table = self._table._compile_table(compiler) 275 | return "CROSS JOIN %s" % table 276 | 277 | 278 | class Select(AST, ExpressionMixin): 279 | '''Representation of a SELECT SQL statement.''' 280 | 281 | def __init__(self, source, project=None, joins=None, 282 | where=None, group=None, order=None, 283 | limit=None, offset=None): 284 | self._source = source 285 | self._project = project or [] 286 | self._joins = joins or [] 287 | self._where = where 288 | self._group = group 289 | self._order = order 290 | self._limit = limit 291 | self._offset = offset 292 | 293 | def project(self, *fields): 294 | return self._modified(_project=fields) 295 | 296 | def join(self, table, on): 297 | return self._add_join(Join(table, on)) 298 | 299 | def leftjoin(self, table, on): 300 | return self._add_join(Join(table, on, "LEFT")) 301 | 302 | def crossjoin(self, table): 303 | return self._add_join(CrossJoin(table)) 304 | 305 | def where(self, expr): 306 | if not self._where: 307 | return self._modified(_where=expr) 308 | where = self._where & expr 309 | return self._modified(_where=where) 310 | 311 | def group(self, *fields): 312 | return self._modified(_group=fields) 313 | 314 | def order(self, *fields): 315 | return self._modified(_order=fields) 316 | 317 | def limit(self, limit): 318 | return self._modified(_limit=limit) 319 | 320 | def offset(self, offset): 321 | return self._modified(_offset=offset) 322 | 323 | @property 324 | def subquery(self): 325 | return SubQuery(self) 326 | 327 | def _clone(self): 328 | # Note: copy.copy interacts badly with __getattr__ 329 | return Select( 330 | self._source, 331 | self._project, 332 | self._joins, 333 | self._where, 334 | self._group, 335 | self._order, 336 | self._limit, 337 | self._offset) 338 | 339 | def _modified(self, **kwargs): 340 | c = self._clone() 341 | for (k, v) in kwargs.items(): 342 | setattr(c, k, v) 343 | return c 344 | 345 | def _add_join(self, join): 346 | joins = list(self._joins) 347 
| joins.append(join) 348 | return self._modified(_joins=joins) 349 | 350 | def _execute(self, using='default'): 351 | con = connections[using] 352 | compiler = Compiler(con) 353 | sql = self._compile(compiler) 354 | 355 | cursor = con.cursor() 356 | cursor.execute(sql, compiler.values) 357 | return cursor 358 | 359 | def _sql(self, using='default'): 360 | con = connections[using] 361 | compiler = Compiler(con) 362 | return (self._compile(compiler), tuple(compiler.values)) 363 | 364 | def to_model(self, model, using='default'): 365 | sql, values = self._sql(using) 366 | return model.objects.raw(sql, values) 367 | 368 | def all(self, using='default'): 369 | '''Execute select and return all rows.''' 370 | cursor = self._execute(using) 371 | cons = namedtuple('Row', [f.row_key for f in self._project]) 372 | for row in cursor.fetchall(): 373 | yield cons(*row) 374 | 375 | def one(self, using='default'): 376 | '''Execute select and return a single row.''' 377 | cursor = self._execute(using) 378 | cons = namedtuple('Row', [f.row_key for f in self._project]) 379 | return cons(*cursor.fetchone()) 380 | 381 | def _compile(self, compiler): 382 | assert self._project, "No fields projected." 
383 | 384 | field_sql = ",".join( 385 | f._compile_projection(compiler) for f in self._project) 386 | from_sql = self._source._compile_table(compiler) 387 | sql = ["SELECT %s FROM %s" % (field_sql, from_sql)] 388 | 389 | join_sql = [j._compile_join(compiler) for j in self._joins] 390 | sql.extend(join_sql) 391 | 392 | if self._where: 393 | sql.append("WHERE") 394 | sql.append(self._where._compile_expression(compiler)) 395 | 396 | if self._group: 397 | group_sql = ",".join( 398 | f._compile_expression(compiler) for f in self._group) 399 | sql.append("GROUP BY") 400 | sql.append(group_sql) 401 | 402 | if self._order: 403 | order_sql = ",".join( 404 | f._compile_expression(compiler) for f in self._order) 405 | sql.append("ORDER BY") 406 | sql.append(order_sql) 407 | 408 | if self._limit is not None: 409 | sql.append("LIMIT %d" % self._limit) 410 | 411 | if self._offset is not None: 412 | sql.append("OFFSET %d" % self._offset) 413 | 414 | return " ".join(sql) 415 | 416 | 417 | class SubQuery(AST, ExpressionMixin, TableMixin): 418 | def __init__(self, select): 419 | self._select = select 420 | 421 | def _compile_expression(self, compiler): 422 | return "(%s)" % self._select._compile(compiler) 423 | 424 | def _compile_table(self, compiler): 425 | alias = compiler.refer(self) 426 | return "(%s) AS %s" % (self._select._compile(compiler), alias) 427 | 428 | def __getattr__(self, key): 429 | for f in self._select._project: 430 | if key == f.row_key: 431 | return Field(self, key) 432 | 433 | raise AttributeError(key) 434 | 435 | 436 | class DjangoTable(AST, TableMixin): 437 | '''A wrapper around a Django Model for building DRel queries.''' 438 | 439 | def __init__(self, model): 440 | self._model = model 441 | 442 | def _compile_table(self, compiler): 443 | alias = compiler.refer(self) 444 | table = compiler.q(self._model._meta.db_table) 445 | return "%s AS %s" % (table, alias) 446 | 447 | def __getattr__(self, key): 448 | for f in self._model._meta.fields: 449 | if key == 
f.name: 450 | return Field(self, f.column, f.name) 451 | if key == f.column: 452 | return Field(self, f.column) 453 | 454 | raise AttributeError(key) 455 | 456 | 457 | class DjangoM2MTable(AST, TableMixin): 458 | ''' 459 | A wrapper around a Django many-to-many field for building DRel 460 | queries. 461 | 462 | ''' 463 | def __init__(self, m2m): 464 | self._m2m = m2m 465 | 466 | def _compile_table(self, compiler): 467 | alias = compiler.refer(self) 468 | table = compiler.q(self._m2m.m2m_db_table()) 469 | return "%s AS %s" % (table, alias) 470 | 471 | def __getattr__(self, key): 472 | if key == self._m2m.m2m_field_name(): 473 | return Field(self, self._m2m.m2m_column_name(), key) 474 | 475 | if key == self._m2m.m2m_reverse_field_name(): 476 | return Field(self, self._m2m.m2m_reverse_name(), key) 477 | 478 | if key == self._m2m.m2m_column_name(): 479 | return Field(self, key) 480 | 481 | if key == self._m2m.m2m_reverse_name(): 482 | return Field(self, key) 483 | 484 | raise AttributeError(key) 485 | --------------------------------------------------------------------------------