├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.rst ├── batch_select ├── __init__.py ├── models.py ├── replay.py └── tests.py ├── setup.py └── tests ├── run_tests.sh └── test_settings.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .coverage 3 | dist 4 | django_batch_select.egg-info 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2009, John Montgomery. 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, 5 | are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, 8 | this list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright 11 | notice, this list of conditions and the following disclaimer in the 12 | documentation and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of Django nor the names of its contributors may be used 15 | to endorse or promote products derived from this software without 16 | specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 19 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 20 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 22 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 23 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 24 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 25 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 27 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.rst 2 | include LICENSE 3 | recursive-include tests * -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | =================== 2 | Django Batch Select 3 | =================== 4 | 5 | The idea of Django Batch Select is to provide an equivalent to 6 | Django's select_related_ functionality. As such it's another handy 7 | tool for avoiding the "n+1 query problem". 8 | 9 | select_related_ is handy for minimizing the number of queries that need 10 | to be made in certain situations. However it is only useful for 11 | pre-selecting ForeignKey_ relations. 12 | 13 | batch_select is handy for pre-selecting ManyToManyField_ relations and 14 | reverse ForeignKey_ relations. 15 | 16 | It works by performing a single extra SQL query after a QuerySet_ has 17 | been evaluated to stitch in the extra fields asked for. This requires 18 | the addition of a custom Manager_, which in turn returns a custom QuerySet_ 19 | with extra methods attached.
20 | 21 | Example Usage 22 | ============= 23 | 24 | Assuming we have models defined as the following:: 25 | 26 | from batch_select.models import BatchManager 27 | 28 | class Tag(models.Model): 29 | name = models.CharField(max_length=32) 30 | 31 | class Section(models.Model): 32 | name = models.CharField(max_length=32) 33 | 34 | objects = BatchManager() 35 | 36 | class Entry(models.Model): 37 | title = models.CharField(max_length=255) 38 | section = models.ForeignKey(Section, blank=True, null=True) 39 | tags = models.ManyToManyField(Tag) 40 | 41 | objects = BatchManager() 42 | 43 | I'll also define a helper function to show the SQL queries generated:: 44 | 45 | from django import db 46 | 47 | def show_queries(): 48 | for query in db.connection.queries: 49 | print query["sql"] 50 | db.reset_queries() 51 | 52 | Here are a few examples (with the generated SQL queries):: 53 | 54 | >>> Entry.objects.batch_select('tags').all() 55 | [] 56 | >>> show_queries() # no results, so no 2nd query 57 | SELECT "batch_select_entry"."id", "batch_select_entry"."title", "batch_select_entry"."section_id" FROM "batch_select_entry" 58 | >>> Entry.objects.create() 59 | >>> Entry.objects.create() 60 | >>> tag1 = Tag.objects.create(name='tag1') 61 | >>> tag2 = Tag.objects.create(name='tag2') 62 | >>> db.reset_queries() 63 | >>> entries = Entry.objects.batch_select('tags').all() 64 | >>> entry = entries[0] 65 | >>> print entry.tags_all 66 | [] 67 | >>> show_queries() 68 | SELECT "batch_select_entry"."id", "batch_select_entry"."title", "batch_select_entry"."section_id" FROM "batch_select_entry" LIMIT 1 69 | SELECT (`batch_select_entry_tags`.`entry_id`) AS "entry_id", "batch_select_tag"."id", "batch_select_tag"."name" FROM "batch_select_tag" INNER JOIN "batch_select_entry_tags" ON ("batch_select_tag"."id" = "batch_select_entry_tags"."tag_id") WHERE "batch_select_entry_tags".entry_id IN (1) 70 | >>> entry.tags.add(tag1) 71 | >>> db.reset_queries() 72 | >>> entries =
Entry.objects.batch_select('tags').all() 73 | >>> entry = entries[0] 74 | >>> print entry.tags_all 75 | [] 76 | >>> show_queries() 77 | SELECT "batch_select_entry"."id", "batch_select_entry"."title", "batch_select_entry"."section_id" FROM "batch_select_entry" LIMIT 1 78 | SELECT (`batch_select_entry_tags`.`entry_id`) AS "entry_id", "batch_select_tag"."id", "batch_select_tag"."name" FROM "batch_select_tag" INNER JOIN "batch_select_entry_tags" ON ("batch_select_tag"."id" = "batch_select_entry_tags"."tag_id") WHERE "batch_select_entry_tags".entry_id IN (1) 79 | >>> entries = Entry.objects.batch_select('tags').all() 80 | >>> for entry in entries: 81 | .... print entry.tags_all 82 | .... 83 | [] 84 | [] 85 | >>> show_queries() 86 | SELECT "batch_select_entry"."id", "batch_select_entry"."title", "batch_select_entry"."section_id" FROM "batch_select_entry" 87 | SELECT (`batch_select_entry_tags`.`entry_id`) AS "entry_id", "batch_select_tag"."id", "batch_select_tag"."name" FROM "batch_select_tag" INNER JOIN "batch_select_entry_tags" ON ("batch_select_tag"."id" = "batch_select_entry_tags"."tag_id") WHERE "batch_select_entry_tags".entry_id IN (1, 2) 88 | 89 | Re-running that same last for loop without using batch_select 90 | generates three queries instead of two (n+1 queries):: 91 | 92 | >>> entries = Entry.objects.all() 93 | >>> for entry in entries: 94 | .... print entry.tags.all() 95 | ....
96 | [] 97 | [] 98 | 99 | >>> show_queries() 100 | SELECT "batch_select_entry"."id", "batch_select_entry"."title", "batch_select_entry"."section_id" FROM "batch_select_entry" 101 | SELECT "batch_select_tag"."id", "batch_select_tag"."name" FROM "batch_select_tag" INNER JOIN "batch_select_entry_tags" ON ("batch_select_tag"."id" = "batch_select_entry_tags"."tag_id") WHERE "batch_select_entry_tags"."entry_id" = 1 102 | SELECT "batch_select_tag"."id", "batch_select_tag"."name" FROM "batch_select_tag" INNER JOIN "batch_select_entry_tags" ON ("batch_select_tag"."id" = "batch_select_entry_tags"."tag_id") WHERE "batch_select_entry_tags"."entry_id" = 2 103 | 104 | This also works with reverse foreign keys. So for example we can get 105 | the entries that belong to each section:: 106 | 107 | >>> section1 = Section.objects.create(name='section1') 108 | >>> section2 = Section.objects.create(name='section2') 109 | >>> Entry.objects.create(section=section1) 110 | >>> Entry.objects.create(section=section1) 111 | >>> Entry.objects.create(section=section2) 112 | >>> db.reset_queries() 113 | >>> Section.objects.batch_select('entry_set') 114 | [, ] 115 | >>> show_queries() 116 | SELECT "batch_select_section"."id", "batch_select_section"."name" FROM "batch_select_section" LIMIT 21 117 | SELECT ("batch_select_entry"."section_id") AS "__section_id", "batch_select_entry"."id", "batch_select_entry"."title", "batch_select_entry"."section_id", "batch_select_entry"."location_id" FROM "batch_select_entry" WHERE "batch_select_entry"."section_id" IN (1, 2) 118 | 119 | Each section object in that query will have an entry_set_all field 120 | containing the relevant entries. 121 | 122 | You need to pass batch_select the "related name" of the foreign key, 123 | in this case "entry_set". NB by default the related name for a foreign 124 | key does not actually include the _set suffix, so you can use just "entry" 125 | in this case.
I have made sure that the _set suffix version also works to 126 | try and keep the API simpler. 127 | 128 | 129 | More Advanced Usage 130 | ========================= 131 | 132 | By default the batch fields are inserted into fields named after the selected field plus an ``_all`` suffix, 133 | on each object. So:: 134 | 135 | Entry.objects.batch_select('tags').all() 136 | 137 | results in the Entry instances having fields called ``'tags_all'`` 138 | containing the Tag objects associated with that Entry. 139 | 140 | If you want to give the field a different name just use a keyword 141 | argument - in the same way as using the Aggregation_ API:: 142 | 143 | Entry.objects.batch_select(selected_tags='tags').all() 144 | 145 | Would mean the Tag objects would be assigned to fields called 146 | ``'selected_tags'``. 147 | 148 | If you want to perform filtering of the related objects you will need to 149 | use a Batch object. By doing this you can pass extra keyword arguments 150 | in the same way as when using the filter method of a QuerySet:: 151 | 152 | from batch_select.models import Batch 153 | 154 | Entry.objects.batch_select(tags_containing_blue=Batch('tags', name__contains='blue')) 155 | 156 | Would return Entry objects with fields called 'tags_containing_blue' with 157 | only those Tags whose name contains 'blue'.
158 | 159 | In addition to filtering using keyword arguments, you can also call the 160 | following methods on a Batch object, with their effects being passed on 161 | to the underlying QuerySet_ object: 162 | 163 | * filter_ 164 | * exclude_ 165 | * annotate_ 166 | * order_by_ 167 | * reverse_ 168 | * select_related_ 169 | * extra_ 170 | * defer_ 171 | * only_ 172 | * batch_select 173 | 174 | (Note that distinct(), values() etc. are not included as they would have 175 | side-effects on how the extra query is associated with the original query.) 176 | So for example to achieve the same effect as the filter above you could 177 | do the following:: 178 | 179 | from batch_select.models import Batch 180 | 181 | Entry.objects.batch_select(tags_containing_blue=Batch('tags').filter(name__contains='blue')) 182 | 183 | Whereas the following would exclude tags containing "blue" and order by name:: 184 | 185 | from batch_select.models import Batch 186 | 187 | batch = Batch('tags').exclude(name__contains='blue').order_by('name') 188 | Entry.objects.batch_select(tags_not_containing_blue=batch) 189 | 190 | 191 | Compatibility 192 | ============= 193 | 194 | Django batch select should work with Django 1.1-1.3 at least. 195 | 196 | 197 | TODOs and BUGS 198 | ============== 199 | 200 | See: http://github.com/lilspikey/django-batch-select/issues 201 | 202 | .. _select_related: http://docs.djangoproject.com/en/dev/ref/models/querysets/#id4 203 | .. _ForeignKey: http://docs.djangoproject.com/en/dev/ref/models/fields/#foreignkey 204 | .. _ManyToManyField: http://docs.djangoproject.com/en/dev/ref/models/fields/#manytomanyfield 205 | .. _QuerySet: http://docs.djangoproject.com/en/dev/ref/models/querysets/ 206 | .. _Manager: http://docs.djangoproject.com/en/dev/topics/db/managers/ 207 | .. _Aggregation: http://docs.djangoproject.com/en/dev/topics/db/aggregation/ 208 | .. _filter: http://docs.djangoproject.com/en/dev/ref/models/querysets/#filter-kwargs 209 | ..
_exclude: http://docs.djangoproject.com/en/dev/ref/models/querysets/#exclude-kwargs 210 | .. _annotate: http://docs.djangoproject.com/en/dev/ref/models/querysets/#annotate-args-kwargs 211 | .. _order_by: http://docs.djangoproject.com/en/dev/ref/models/querysets/#order-by-fields 212 | .. _reverse: http://docs.djangoproject.com/en/dev/ref/models/querysets/#reverse 213 | .. _extra: http://docs.djangoproject.com/en/dev/ref/models/querysets/#extra-select-none-where-none-params-none-tables-none-order-by-none-select-params-none 214 | .. _defer: http://docs.djangoproject.com/en/dev/ref/models/querysets/#defer-fields 215 | .. _only: http://docs.djangoproject.com/en/dev/ref/models/querysets/#only-fields 216 | -------------------------------------------------------------------------------- /batch_select/__init__.py: -------------------------------------------------------------------------------- 1 | VERSION = (0, 2, 4) 2 | __version__ = '.'.join(map(str, VERSION)) -------------------------------------------------------------------------------- /batch_select/models.py: -------------------------------------------------------------------------------- 1 | from django.db.models.query import QuerySet 2 | from django.db import models, connection 3 | from django.db.models.fields import FieldDoesNotExist 4 | 5 | from django.conf import settings 6 | 7 | from replay import Replay 8 | 9 | def _not_exists(fieldname): 10 | raise FieldDoesNotExist('"%s" is not a ManyToManyField or a reverse ForeignKey relationship' % fieldname) 11 | 12 | def _check_field_exists(model, fieldname): 13 | try: 14 | field_object, model, direct, m2m = model._meta.get_field_by_name(fieldname) 15 | except FieldDoesNotExist: 16 | # might be after reverse foreign key 17 | # which by default don't have the name we expect 18 | if fieldname.endswith('_set'): 19 | return _check_field_exists(model, fieldname[:-len('_set')]) 20 | else: 21 | raise 22 | if not m2m: 23 | if direct: # reverse foreign key relationship 24 | 
_not_exists(fieldname) 25 | return fieldname 26 | 27 | def _id_attr(id_column): 28 | # mangle the id column name, so we can make sure 29 | # the postgres doesn't complain about not quoting 30 | # field names (this helps make sure we don't clash 31 | # with the regular id column) 32 | return '__%s' % id_column.lower() 33 | 34 | def _select_related_instances(related_model, related_name, ids, db_table, id_column): 35 | id__in_filter={ ('%s__pk__in' % related_name): ids } 36 | qn = connection.ops.quote_name 37 | select = { _id_attr(id_column): '%s.%s' % (qn(db_table), qn(id_column)) } 38 | related_instances = related_model._default_manager \ 39 | .filter(**id__in_filter) \ 40 | .extra(select=select) 41 | return related_instances 42 | 43 | def batch_select(model, instances, target_field_name, fieldname, filter=None): 44 | ''' 45 | basically do an extra-query to select the many-to-many 46 | field values into the instances given. e.g. so we can get all 47 | Entries and their Tags in two queries rather than n+1 48 | 49 | returns a list of the instances with the newly attached fields 50 | 51 | batch_select(Entry, Entry.objects.all(), 'tags_all', 'tags') 52 | 53 | would return a list of Entry objects with 'tags_all' fields 54 | containing the tags for that Entry 55 | 56 | filter is a function that can be used alter the extra-query - it 57 | takes a queryset and returns a filtered version of the queryset 58 | 59 | NB: this is a semi-private API at the moment, but may be useful if you 60 | dont want to change your model/manager. 
61 | ''' 62 | 63 | fieldname = _check_field_exists(model, fieldname) 64 | 65 | instances = list(instances) 66 | ids = [instance.pk for instance in instances] 67 | 68 | field_object, model, direct, m2m = model._meta.get_field_by_name(fieldname) 69 | if m2m: 70 | if not direct: 71 | m2m_field = field_object.field 72 | related_model = field_object.model 73 | related_name = m2m_field.name 74 | id_column = m2m_field.m2m_reverse_name() 75 | db_table = m2m_field.m2m_db_table() 76 | else: 77 | m2m_field = field_object 78 | related_model = m2m_field.rel.to # model on other end of relationship 79 | related_name = m2m_field.related_query_name() 80 | id_column = m2m_field.m2m_column_name() 81 | db_table = m2m_field.m2m_db_table() 82 | elif not direct: 83 | # handle reverse foreign key relationships 84 | fk_field = field_object.field 85 | related_model = field_object.model 86 | related_name = fk_field.name 87 | id_column = fk_field.column 88 | db_table = related_model._meta.db_table 89 | 90 | related_instances = _select_related_instances(related_model, related_name, 91 | ids, db_table, id_column) 92 | 93 | if filter: 94 | related_instances = filter(related_instances) 95 | 96 | grouped = {} 97 | id_attr = _id_attr(id_column) 98 | for related_instance in related_instances: 99 | instance_id = getattr(related_instance, id_attr) 100 | group = grouped.get(instance_id, []) 101 | group.append(related_instance) 102 | grouped[instance_id] = group 103 | 104 | for instance in instances: 105 | setattr(instance, target_field_name, grouped.get(instance.pk, [])) 106 | 107 | return instances 108 | 109 | class Batch(Replay): 110 | # functions on QuerySet that we can invoke via this batch object 111 | __replayable__ = ('filter', 'exclude', 'annotate', 112 | 'order_by', 'reverse', 'select_related', 113 | 'extra', 'defer', 'only', 'batch_select') 114 | 115 | def __init__(self, m2m_fieldname, **filter): 116 | super(Batch,self).__init__() 117 | self.m2m_fieldname = m2m_fieldname 118 | 
self.target_field_name = '%s_all' % m2m_fieldname 119 | if filter: # add a filter replay method 120 | self._add_replay('filter', *(), **filter) 121 | 122 | def clone(self): 123 | cloned = super(Batch, self).clone(self.m2m_fieldname) 124 | cloned.target_field_name = self.target_field_name 125 | return cloned 126 | 127 | class BatchQuerySet(QuerySet): 128 | 129 | def _clone(self, *args, **kwargs): 130 | query = super(BatchQuerySet, self)._clone(*args, **kwargs) 131 | batches = getattr(self, '_batches', None) 132 | if batches: 133 | query._batches = set(batches) 134 | return query 135 | 136 | def _create_batch(self, batch_or_str, target_field_name=None): 137 | batch = batch_or_str 138 | if isinstance(batch_or_str, basestring): 139 | batch = Batch(batch_or_str) 140 | if target_field_name: 141 | batch.target_field_name = target_field_name 142 | 143 | _check_field_exists(self.model, batch.m2m_fieldname) 144 | return batch 145 | 146 | def batch_select(self, *batches, **named_batches): 147 | batches = getattr(self, '_batches', set()) | \ 148 | set(self._create_batch(batch) for batch in batches) | \ 149 | set(self._create_batch(batch, target_field_name) \ 150 | for target_field_name, batch in named_batches.items()) 151 | 152 | query = self._clone() 153 | query._batches = batches 154 | return query 155 | 156 | def iterator(self): 157 | result_iter = super(BatchQuerySet, self).iterator() 158 | batches = getattr(self, '_batches', None) 159 | if batches: 160 | results = list(result_iter) 161 | for batch in batches: 162 | results = batch_select(self.model, results, 163 | batch.target_field_name, 164 | batch.m2m_fieldname, 165 | batch.replay) 166 | return iter(results) 167 | return result_iter 168 | 169 | class BatchManager(models.Manager): 170 | use_for_related_fields = True 171 | 172 | def get_query_set(self): 173 | return BatchQuerySet(self.model) 174 | 175 | def batch_select(self, *batches, **named_batches): 176 | return self.all().batch_select(*batches, **named_batches) 177 
| 178 | if getattr(settings, 'TESTING_BATCH_SELECT', False): 179 | class Tag(models.Model): 180 | name = models.CharField(max_length=32) 181 | 182 | objects = BatchManager() 183 | 184 | class Section(models.Model): 185 | name = models.CharField(max_length=32) 186 | 187 | objects = BatchManager() 188 | 189 | class Location(models.Model): 190 | name = models.CharField(max_length=32) 191 | 192 | class Entry(models.Model): 193 | title = models.CharField(max_length=255) 194 | section = models.ForeignKey(Section, blank=True, null=True) 195 | location = models.ForeignKey(Location, blank=True, null=True) 196 | tags = models.ManyToManyField(Tag) 197 | 198 | objects = BatchManager() 199 | 200 | class Country(models.Model): 201 | # non id pk 202 | name = models.CharField(primary_key=True, max_length=100) 203 | locations = models.ManyToManyField(Location) 204 | 205 | objects = BatchManager() 206 | 207 | -------------------------------------------------------------------------------- /batch_select/replay.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Class that we can use to chain together several methods calls (in 3 | the same fashion as we would with a QuerySet) and then "replay" them 4 | later on a different object. 
5 | ''' 6 | 7 | def create_replay_method(name): 8 | def _replay_method(self, *args, **kwargs): 9 | cloned = self.clone() 10 | cloned._add_replay(name, *args, **kwargs) 11 | return cloned 12 | _replay_method.__name__ = name 13 | _replay_method.__doc__ = 'replay %s method on target object' % name 14 | return _replay_method 15 | 16 | class ReplayMetaClass(type): 17 | def __new__(meta, classname, bases, class_dict): 18 | replay_methods = class_dict.get('__replayable__', []) 19 | for name in replay_methods: 20 | class_dict[name] = create_replay_method(name) 21 | return type.__new__(meta, classname, bases, class_dict) 22 | 23 | class Replay(object): 24 | __metaclass__ = ReplayMetaClass 25 | 26 | def __init__(self): 27 | self._replays=[] 28 | 29 | def _add_replay(self, method_name, *args, **kwargs): 30 | self._replays.append((method_name, args, kwargs)) 31 | 32 | def clone(self, *args, **kwargs): 33 | klass = self.__class__ 34 | cloned = klass(*args, **kwargs) 35 | cloned._replays=self._replays[:] 36 | return cloned 37 | 38 | def replay(self, target): 39 | result = target 40 | for method_name, args, kwargs in self._replays: 41 | method = getattr(result, method_name) 42 | result = method(*args, **kwargs) 43 | return result 44 | -------------------------------------------------------------------------------- /batch_select/tests.py: -------------------------------------------------------------------------------- 1 | from django.conf import settings 2 | 3 | if getattr(settings, 'TESTING_BATCH_SELECT', False): 4 | from django.test import TransactionTestCase 5 | from django.db.models.fields import FieldDoesNotExist 6 | from batch_select.models import Tag, Entry, Section, Batch, Location,\ 7 | _select_related_instances, Country,\ 8 | _check_field_exists 9 | from batch_select.replay import Replay 10 | from django import db 11 | from django.db.models import Count 12 | import unittest 13 | 14 | def with_debug_queries(fn): 15 | def _decorated(*arg, **kw): 16 | db.reset_queries() 17 
| old_debug, settings.DEBUG = settings.DEBUG, True 18 | result = fn(*arg, **kw) 19 | settings.DEBUG = old_debug 20 | return result 21 | return _decorated 22 | 23 | def _create_tags(*names): 24 | return [Tag.objects.create(name=name) for name in names] 25 | 26 | def _create_entries(count): 27 | return [Entry.objects.create() for _ in xrange(count)] 28 | 29 | class TestBatchSelect(TransactionTestCase): 30 | 31 | def test_batch_select_empty(self): 32 | entries = Entry.objects.batch_select('tags') 33 | self.failUnlessEqual([], list(entries)) 34 | 35 | def test_batch_select_no_tags(self): 36 | entry = Entry.objects.create() 37 | entries = Entry.objects.batch_select('tags') 38 | self.failUnlessEqual([entry], list(entries)) 39 | 40 | def test_batch_select_default_name(self): 41 | entry = _create_entries(1)[0] 42 | tag1, tag2 = _create_tags('tag1', 'tag2') 43 | 44 | entry.tags.add(tag1, tag2) 45 | 46 | entry = Entry.objects.batch_select('tags')[0] 47 | 48 | self.failIf( getattr(entry, 'tags_all', None) is None ) 49 | self.failUnlessEqual( set([tag1, tag2]), set(entry.tags_all) ) 50 | 51 | def test_batch_select_non_default_name(self): 52 | entry = _create_entries(1)[0] 53 | tag1, tag2 = _create_tags('tag1', 'tag2') 54 | 55 | entry.tags.add(tag1, tag2) 56 | 57 | entry = Entry.objects.batch_select(batch_tags='tags')[0] 58 | 59 | self.failIf( getattr(entry, 'batch_tags', None) is None ) 60 | self.failUnlessEqual( set([tag1, tag2]), set(entry.batch_tags) ) 61 | 62 | def test_batch_select_with_tags(self): 63 | entry1, entry2, entry3, entry4 = _create_entries(4) 64 | tag1, tag2, tag3 = _create_tags('tag1', 'tag2', 'tag3') 65 | 66 | entry1.tags.add(tag1, tag2, tag3) 67 | 68 | entry2.tags.add(tag2) 69 | 70 | entry3.tags.add(tag2, tag3) 71 | 72 | entries = Entry.objects.batch_select('tags').order_by('id') 73 | entries = list(entries) 74 | 75 | self.failUnlessEqual([entry1, entry2, entry3, entry4], entries) 76 | 77 | entry1, entry2, entry3, entry4 = entries 78 | 79 | 
self.failUnlessEqual(set([tag1, tag2, tag3]), set(entry1.tags_all)) 80 | self.failUnlessEqual(set([tag2]), set(entry2.tags_all)) 81 | self.failUnlessEqual(set([tag2, tag3]), set(entry3.tags_all)) 82 | self.failUnlessEqual(set([]), set(entry4.tags_all)) 83 | 84 | def test_batch_select_get(self): 85 | entry = Entry.objects.create() 86 | tag1, tag2, tag3 = _create_tags('tag1', 'tag2', 'tag3') 87 | 88 | entry.tags.add(tag1, tag2, tag3) 89 | 90 | entry = Entry.objects.batch_select('tags').get() 91 | 92 | self.failIf( getattr(entry, 'tags_all', None) is None ) 93 | self.failUnlessEqual( set([tag1, tag2, tag3]), set(entry.tags_all) ) 94 | 95 | def test_batch_select_caching_works(self): 96 | # make sure that query set caching still 97 | # works and doesn't alter the added fields 98 | entry1, entry2, entry3, entry4 = _create_entries(4) 99 | tag1, tag2, tag3 = _create_tags('tag1', 'tag2', 'tag3') 100 | 101 | entry1.tags.add(tag1, tag2, tag3) 102 | 103 | entry2.tags.add(tag2) 104 | 105 | entry3.tags.add(tag2, tag3) 106 | 107 | qs = Entry.objects.batch_select(Batch('tags')).order_by('id') 108 | 109 | self.failUnlessEqual([entry1, entry2, entry3, entry4], list(qs)) 110 | 111 | entry1, entry2, entry3, entry4 = list(qs) 112 | 113 | self.failUnlessEqual(set([tag1, tag2, tag3]), set(entry1.tags_all)) 114 | self.failUnlessEqual(set([tag2]), set(entry2.tags_all)) 115 | self.failUnlessEqual(set([tag2, tag3]), set(entry3.tags_all)) 116 | self.failUnlessEqual(set([]), set(entry4.tags_all)) 117 | 118 | def test_no_batch_select(self): 119 | # make sure things still work when we don't do a batch select 120 | entry1, entry2, entry3, entry4 = _create_entries(4) 121 | 122 | qs = Entry.objects.all().order_by('id') 123 | 124 | self.failUnlessEqual([entry1, entry2, entry3, entry4], list(qs)) 125 | 126 | def test_batch_select_after_new_query(self): 127 | entry1, entry2, entry3, entry4 = _create_entries(4) 128 | tag1, tag2, tag3 = _create_tags('tag1', 'tag2', 'tag3') 129 | 130 | 
entry1.tags.add(tag1, tag2, tag3) 131 | 132 | entry2.tags.add(tag2) 133 | 134 | entry3.tags.add(tag2, tag3) 135 | 136 | qs = Entry.objects.batch_select(Batch('tags')).order_by('id') 137 | 138 | self.failUnlessEqual([entry1, entry2, entry3, entry4], list(qs)) 139 | 140 | entry1, entry2, entry3, entry4 = list(qs) 141 | 142 | self.failUnlessEqual(set([tag1, tag2, tag3]), set(entry1.tags_all)) 143 | self.failUnlessEqual(set([tag2]), set(entry2.tags_all)) 144 | self.failUnlessEqual(set([tag2, tag3]), set(entry3.tags_all)) 145 | self.failUnlessEqual(set([]), set(entry4.tags_all)) 146 | 147 | new_qs = qs.filter(id=entry1.id) 148 | 149 | self.failUnlessEqual([entry1], list(new_qs)) 150 | 151 | entry1 = list(new_qs)[0] 152 | self.failUnlessEqual(set([tag1, tag2, tag3]), set(entry1.tags_all)) 153 | 154 | @with_debug_queries 155 | def test_batch_select_minimal_queries(self): 156 | # make sure we are only doing the number of sql queries we intend to 157 | entry1, entry2, entry3, entry4 = _create_entries(4) 158 | tag1, tag2, tag3 = _create_tags('tag1', 'tag2', 'tag3') 159 | 160 | entry1.tags.add(tag1, tag2, tag3) 161 | entry2.tags.add(tag2) 162 | entry3.tags.add(tag2, tag3) 163 | 164 | db.reset_queries() 165 | 166 | qs = Entry.objects.batch_select(Batch('tags')).order_by('id') 167 | 168 | self.failUnlessEqual([entry1, entry2, entry3, entry4], list(qs)) 169 | 170 | # this should have resulted in only two queries 171 | self.failUnlessEqual(2, len(db.connection.queries)) 172 | 173 | # double-check result is cached, and doesn't trigger more queries 174 | self.failUnlessEqual([entry1, entry2, entry3, entry4], list(qs)) 175 | self.failUnlessEqual(2, len(db.connection.queries)) 176 | 177 | @with_debug_queries 178 | def test_no_batch_select_minimal_queries(self): 179 | # check we haven't altered the original querying behaviour 180 | entry1, entry2, entry3 = _create_entries(3) 181 | 182 | db.reset_queries() 183 | 184 | qs = Entry.objects.order_by('id') 185 | 186 | 
self.failUnlessEqual([entry1, entry2, entry3], list(qs)) 187 | 188 | # this should have resulted in only two queries 189 | self.failUnlessEqual(1, len(db.connection.queries)) 190 | 191 | # check caching still works 192 | self.failUnlessEqual([entry1, entry2, entry3], list(qs)) 193 | self.failUnlessEqual(1, len(db.connection.queries)) 194 | 195 | def test_batch_select_non_existant_field(self): 196 | try: 197 | qs = Entry.objects.batch_select(Batch('qwerty')).order_by('id') 198 | self.fail('selected field that does not exist') 199 | except FieldDoesNotExist: 200 | pass 201 | 202 | def test_batch_select_non_m2m_field(self): 203 | try: 204 | qs = Entry.objects.batch_select(Batch('title')).order_by('id') 205 | self.fail('selected field that is not m2m field') 206 | except FieldDoesNotExist: 207 | pass 208 | 209 | def test_batch_select_empty_one_to_many(self): 210 | sections = Section.objects.batch_select('entry') 211 | self.failUnlessEqual([], list(sections)) 212 | 213 | def test_batch_select_one_to_many_no_children(self): 214 | section1 = Section.objects.create(name='s1') 215 | section2 = Section.objects.create(name='s2') 216 | 217 | sections = Section.objects.batch_select('entry').order_by('id') 218 | self.failUnlessEqual([section1, section2], list(sections)) 219 | 220 | def test_batch_select_one_to_many_with_children(self): 221 | section1 = Section.objects.create(name='s1') 222 | section2 = Section.objects.create(name='s2') 223 | section3 = Section.objects.create(name='s3') 224 | 225 | entry1 = Entry.objects.create(section=section1) 226 | entry2 = Entry.objects.create(section=section1) 227 | entry3 = Entry.objects.create(section=section3) 228 | 229 | sections = Section.objects.batch_select('entry').order_by('id') 230 | self.failUnlessEqual([section1, section2, section3], list(sections)) 231 | 232 | section1, section2, section3 = list(sections) 233 | 234 | self.failUnlessEqual(set([entry1, entry2]), set(section1.entry_all)) 235 | self.failUnlessEqual(set([]), 
set(section2.entry_all)) 236 | self.failUnlessEqual(set([entry3]), set(section3.entry_all)) 237 | 238 | def test___check_field_exists_full_field_name(self): 239 | # make sure we can retrieve the "real" fieldname 240 | # given the full fieldname for a reverse foreign key 241 | # relationship 242 | # e.g. give "entry_set" we should get "entry" 243 | self.failUnlessEqual("entry", _check_field_exists(Section, "entry_set")) 244 | 245 | def test___check_field_exists_full_field_name_non_existant_field(self): 246 | try: 247 | _check_field_exists(Section, "qwerty_set") 248 | self.fail('selected field that does not exist') 249 | except FieldDoesNotExist: 250 | pass 251 | 252 | def test_batch_select_one_to_many_with_children_full_field_name(self): 253 | section1 = Section.objects.create(name='s1') 254 | section2 = Section.objects.create(name='s2') 255 | section3 = Section.objects.create(name='s3') 256 | 257 | entry1 = Entry.objects.create(section=section1) 258 | entry2 = Entry.objects.create(section=section1) 259 | entry3 = Entry.objects.create(section=section3) 260 | 261 | sections = Section.objects.batch_select('entry_set').order_by('id') 262 | self.failUnlessEqual([section1, section2, section3], list(sections)) 263 | 264 | section1, section2, section3 = list(sections) 265 | 266 | self.failUnlessEqual(set([entry1, entry2]), set(section1.entry_set_all)) 267 | self.failUnlessEqual(set([]), set(section2.entry_set_all)) 268 | self.failUnlessEqual(set([entry3]), set(section3.entry_set_all)) 269 | 270 | @with_debug_queries 271 | def test_batch_select_one_to_many_with_children_minimal_queries(self): 272 | section1 = Section.objects.create(name='s1') 273 | section2 = Section.objects.create(name='s2') 274 | section3 = Section.objects.create(name='s3') 275 | 276 | entry1 = Entry.objects.create(section=section1) 277 | entry2 = Entry.objects.create(section=section2) 278 | entry3 = Entry.objects.create(section=section3) 279 | 280 | db.reset_queries() 281 | 282 | sections = 
Section.objects.batch_select('entry').order_by('id') 283 | self.failUnlessEqual([section1, section2, section3], list(sections)) 284 | 285 | # this should have resulted in only two queries 286 | self.failUnlessEqual(2, len(db.connection.queries)) 287 | 288 | section1, section2, section3 = list(sections) 289 | 290 | self.failUnlessEqual(set([entry1]), set(section1.entry_all)) 291 | self.failUnlessEqual(set([entry2]), set(section2.entry_all)) 292 | self.failUnlessEqual(set([entry3]), set(section3.entry_all)) 293 | 294 | class TestBatchSelectQuerySetMethods(TransactionTestCase): 295 | 296 | def setUp(self): 297 | super(TransactionTestCase, self).setUp() 298 | self.entry1, self.entry2, self.entry3, self.entry4 = _create_entries(4) 299 | # put tags names in different order to id 300 | self.tag2, self.tag1, self.tag3 = _create_tags('tag2', 'tag1', 'tag3') 301 | 302 | self.entry1.tags.add(self.tag1, self.tag2, self.tag3) 303 | self.entry2.tags.add(self.tag2) 304 | self.entry3.tags.add(self.tag2, self.tag3) 305 | 306 | def test_batch_select_filtering_name_params(self): 307 | entries = Entry.objects.batch_select(Batch('tags', name='tag1')).order_by('id') 308 | entries = list(entries) 309 | 310 | self.failUnlessEqual([self.entry1, self.entry2, self.entry3, self.entry4], 311 | entries) 312 | 313 | entry1, entry2, entry3, entry4 = entries 314 | 315 | self.failUnlessEqual(set([self.tag1]), set(entry1.tags_all)) 316 | self.failUnlessEqual(set([]), set(entry2.tags_all)) 317 | self.failUnlessEqual(set([]), set(entry3.tags_all)) 318 | self.failUnlessEqual(set([]), set(entry4.tags_all)) 319 | 320 | def test_batch_select_filter(self): 321 | entries = Entry.objects.batch_select(Batch('tags').filter(name='tag2')).order_by('id') 322 | entries = list(entries) 323 | 324 | self.failUnlessEqual([self.entry1, self.entry2, self.entry3, self.entry4], 325 | entries) 326 | 327 | entry1, entry2, entry3, entry4 = entries 328 | 329 | self.failUnlessEqual(set([self.tag2]), set(entry1.tags_all)) 330 | 
self.failUnlessEqual(set([self.tag2]), set(entry2.tags_all)) 331 | self.failUnlessEqual(set([self.tag2]), set(entry3.tags_all)) 332 | self.failUnlessEqual(set([]), set(entry4.tags_all)) 333 | 334 | def test_batch_select_exclude(self): 335 | entries = Entry.objects.batch_select(Batch('tags').exclude(name='tag2')).order_by('id') 336 | entries = list(entries) 337 | 338 | self.failUnlessEqual([self.entry1, self.entry2, self.entry3, self.entry4], 339 | entries) 340 | 341 | entry1, entry2, entry3, entry4 = entries 342 | 343 | self.failUnlessEqual(set([self.tag1, self.tag3]), set(entry1.tags_all)) 344 | self.failUnlessEqual(set([]), set(entry2.tags_all)) 345 | self.failUnlessEqual(set([self.tag3]), set(entry3.tags_all)) 346 | self.failUnlessEqual(set([]), set(entry4.tags_all)) 347 | 348 | def test_batch_order_by_name(self): 349 | entries = Entry.objects.batch_select(Batch('tags').order_by('name')).order_by('id') 350 | entries = list(entries) 351 | 352 | self.failUnlessEqual([self.entry1, self.entry2, self.entry3, self.entry4], 353 | entries) 354 | 355 | entry1, entry2, entry3, entry4 = entries 356 | 357 | self.failUnlessEqual([self.tag1, self.tag2, self.tag3], entry1.tags_all) 358 | self.failUnlessEqual([self.tag2], entry2.tags_all) 359 | self.failUnlessEqual([self.tag2, self.tag3], entry3.tags_all) 360 | self.failUnlessEqual([], entry4.tags_all) 361 | 362 | def test_batch_order_by_id(self): 363 | entries = Entry.objects.batch_select(Batch('tags').order_by('id')).order_by('id') 364 | entries = list(entries) 365 | 366 | self.failUnlessEqual([self.entry1, self.entry2, self.entry3, self.entry4], 367 | entries) 368 | 369 | entry1, entry2, entry3, entry4 = entries 370 | 371 | self.failUnlessEqual([self.tag2, self.tag1, self.tag3], entry1.tags_all) 372 | self.failUnlessEqual([self.tag2], entry2.tags_all) 373 | self.failUnlessEqual([self.tag2, self.tag3], entry3.tags_all) 374 | self.failUnlessEqual([], entry4.tags_all) 375 | 376 | def test_batch_reverse(self): 377 | entries = 
Entry.objects.batch_select(Batch('tags').order_by('name').reverse()).order_by('id') 378 | entries = list(entries) 379 | 380 | self.failUnlessEqual([self.entry1, self.entry2, self.entry3, self.entry4], 381 | entries) 382 | 383 | entry1, entry2, entry3, entry4 = entries 384 | 385 | self.failUnlessEqual([self.tag3, self.tag2, self.tag1], entry1.tags_all) 386 | self.failUnlessEqual([self.tag2], entry2.tags_all) 387 | self.failUnlessEqual([self.tag3, self.tag2], entry3.tags_all) 388 | self.failUnlessEqual([], entry4.tags_all) 389 | 390 | def test_batch_annotate(self): 391 | section1 = Section.objects.create(name='s1') 392 | section2 = Section.objects.create(name='s2') 393 | section3 = Section.objects.create(name='s3') 394 | 395 | entry1 = Entry.objects.create(section=section1) 396 | entry2 = Entry.objects.create(section=section1) 397 | entry3 = Entry.objects.create(section=section3) 398 | 399 | entry1.tags.add(self.tag2, self.tag3, self.tag1) 400 | entry3.tags.add(self.tag2, self.tag3) 401 | 402 | batch = Batch('entry').order_by('id').annotate(Count('tags')) 403 | sections = Section.objects.batch_select(batch).order_by('id') 404 | sections = list(sections) 405 | self.failUnlessEqual([section1, section2, section3], sections) 406 | 407 | section1, section2, section3 = sections 408 | 409 | self.failUnlessEqual([entry1, entry2], section1.entry_all) 410 | self.failUnlessEqual([], section2.entry_all) 411 | self.failUnlessEqual([entry3], section3.entry_all) 412 | 413 | self.failUnlessEqual(3, section1.entry_all[0].tags__count) 414 | self.failUnlessEqual(0, section1.entry_all[1].tags__count) 415 | self.failUnlessEqual(2, section3.entry_all[0].tags__count) 416 | 417 | @with_debug_queries 418 | def test_batch_select_related(self): 419 | # verify using select related doesn't trigger more queries 420 | section1 = Section.objects.create(name='s1') 421 | section2 = Section.objects.create(name='s2') 422 | section3 = Section.objects.create(name='s3') 423 | 424 | location = 
Location.objects.create(name='home') 425 | 426 | entry1 = Entry.objects.create(section=section1, location=location) 427 | entry2 = Entry.objects.create(section=section1) 428 | entry3 = Entry.objects.create(section=section3) 429 | 430 | entry1.tags.add(self.tag2, self.tag3, self.tag1) 431 | entry3.tags.add(self.tag2, self.tag3) 432 | 433 | db.reset_queries() 434 | 435 | batch = Batch('entry').order_by('id').select_related('location') 436 | sections = Section.objects.batch_select(batch).order_by('id') 437 | sections = list(sections) 438 | self.failUnlessEqual([section1, section2, section3], sections) 439 | 440 | section1, section2, section3 = sections 441 | 442 | self.failUnlessEqual([entry1, entry2], section1.entry_all) 443 | self.failUnlessEqual([], section2.entry_all) 444 | self.failUnlessEqual([entry3], section3.entry_all) 445 | 446 | self.failUnlessEqual(2, len(db.connection.queries)) 447 | db.reset_queries() 448 | 449 | entry1, entry2 = section1.entry_all 450 | 451 | self.failUnlessEqual(0, len(db.connection.queries)) 452 | self.failUnlessEqual(location, entry1.location) 453 | self.failUnlessEqual(0, len(db.connection.queries)) 454 | self.failUnless( entry2.location is None ) 455 | self.failUnlessEqual(0, len(db.connection.queries)) 456 | 457 | def _check_name_deferred(self, batch): 458 | entries = Entry.objects.batch_select(batch).order_by('id') 459 | entries = list(entries) 460 | 461 | self.failUnlessEqual([self.entry1, self.entry2, self.entry3, self.entry4], 462 | entries) 463 | 464 | self.failUnlessEqual(2, len(db.connection.queries)) 465 | db.reset_queries() 466 | 467 | entry1, entry2, entry3, entry4 = entries 468 | 469 | self.failUnlessEqual(3, len(entry1.tags_all)) 470 | self.failUnlessEqual(1, len(entry2.tags_all)) 471 | self.failUnlessEqual(2, len(entry3.tags_all)) 472 | self.failUnlessEqual(0, len(entry4.tags_all)) 473 | 474 | self.failUnlessEqual(0, len(db.connection.queries)) 475 | 476 | # as name has been deferred it should trigger a query when we 
477 | # try to access it 478 | self.failUnlessEqual( self.tag2.name, entry1.tags_all[0].name ) 479 | self.failUnlessEqual(1, len(db.connection.queries)) 480 | self.failUnlessEqual( self.tag1.name, entry1.tags_all[1].name ) 481 | self.failUnlessEqual(2, len(db.connection.queries)) 482 | self.failUnlessEqual( self.tag3.name, entry1.tags_all[2].name ) 483 | self.failUnlessEqual(3, len(db.connection.queries)) 484 | 485 | @with_debug_queries 486 | def test_batch_defer(self): 487 | batch = Batch('tags').order_by('id').defer('name') 488 | self._check_name_deferred(batch) 489 | 490 | @with_debug_queries 491 | def test_batch_only(self): 492 | batch = Batch('tags').order_by('id').only('id') 493 | self._check_name_deferred(batch) 494 | 495 | def test_batch_select_reverse_m2m(self): 496 | entry1, entry2, entry3, entry4 = _create_entries(4) 497 | tag1, tag2, tag3 = _create_tags('tag1', 'tag2', 'tag3') 498 | 499 | entry1.tags.add(tag1, tag2, tag3) 500 | 501 | entry2.tags.add(tag2) 502 | 503 | entry3.tags.add(tag2, tag3) 504 | 505 | tags = Tag.objects.batch_select('entry')\ 506 | .filter(id__in=[tag1.id, tag2.id, tag3.id])\ 507 | .order_by('id') 508 | tags = list(tags) 509 | 510 | self.failUnlessEqual([tag1, tag2, tag3], tags) 511 | 512 | tag1, tag2, tag3 = tags 513 | 514 | self.failUnlessEqual(set([entry1]), set(tag1.entry_all)) 515 | self.failUnlessEqual(set([entry1, entry2, entry3]), 516 | set(tag2.entry_all)) 517 | self.failUnlessEqual(set([entry1, entry3]), set(tag3.entry_all)) 518 | 519 | def test_non_id_primary_key(self): 520 | uk = Country.objects.create(name='United Kingdom') 521 | brighton = Location.objects.create(name='Brighton') 522 | hove = Location.objects.create(name='Hove') 523 | 524 | uk.locations.add(brighton, hove) 525 | 526 | countries = Country.objects.batch_select('locations') 527 | countries = list(countries) 528 | 529 | self.failUnlessEqual([uk], countries) 530 | 531 | uk = countries[0] 532 | self.failUnlessEqual(set([brighton, hove]), 
set(uk.locations_all)) 533 | 534 | @with_debug_queries 535 | def test_batch_nested(self): 536 | section1 = Section.objects.create(name='s1') 537 | 538 | entry1 = Entry.objects.create(section=section1) 539 | entry2 = Entry.objects.create(section=section1) 540 | 541 | tag1, tag2, tag3 = _create_tags('tag1', 'tag2', 'tag3') 542 | 543 | entry1.tags.add(tag1, tag3) 544 | entry2.tags.add(tag2) 545 | 546 | db.reset_queries() 547 | 548 | entry_batch = Batch('entry_set').batch_select('tags') 549 | sections = Section.objects.batch_select(entries=entry_batch) 550 | sections = list(sections) 551 | section1 = sections[0] 552 | section1_tags = [tag for entry in section1.entries 553 | for tag in entry.tags_all] 554 | 555 | self.failUnlessEqual(set([tag1, tag2, tag3]), set(section1_tags)) 556 | self.failUnlessEqual(3, len(db.connection.queries)) 557 | 558 | 559 | class ReplayTestCase(unittest.TestCase): 560 | 561 | def setUp(self): 562 | class ReplayTest(Replay): 563 | __replayable__ = ('lower', 'upper', 'replace') 564 | self.klass = ReplayTest 565 | self.instance = ReplayTest() 566 | 567 | def test_replayable_methods_present_on_class(self): 568 | self.failIf( getattr(self.klass, 'lower', None) is None ) 569 | self.failIf( getattr(self.klass, 'upper', None) is None ) 570 | self.failIf( getattr(self.klass, 'replace', None) is None ) 571 | 572 | def test_replayable_methods_present_on_instance(self): 573 | self.failIf( getattr(self.instance, 'lower', None) is None ) 574 | self.failIf( getattr(self.instance, 'upper', None) is None ) 575 | self.failIf( getattr(self.instance, 'replace', None) is None ) 576 | 577 | def test_replay_methods_recorded(self): 578 | r = self.instance 579 | self.failUnlessEqual([], r._replays) 580 | 581 | self.failIf(r == r.upper()) 582 | 583 | self.failUnlessEqual([('upper', (), {})], r.upper()._replays) 584 | self.failUnlessEqual([('lower', (), {})], r.lower()._replays) 585 | self.failUnlessEqual([('replace', (), {})], r.replace()._replays) 586 | 587 | 
self.failUnlessEqual([('upper', (1,), {})], r.upper(1)._replays) 588 | self.failUnlessEqual([('upper', (1,), {'param': 's'})], r.upper(1, param='s')._replays) 589 | 590 | self.failUnlessEqual([('upper', (), {'name__contains': 'test'}), 591 | ('replace', ('id',), {})], 592 | r.upper(name__contains='test').replace('id')._replays) 593 | 594 | def test_replay_no_replay(self): 595 | r = self.instance 596 | s = 'gfjhGF&' 597 | self.failUnlessEqual(s, r.replay(s)) 598 | 599 | def test_replay_single_call(self): 600 | r = self.instance.upper() 601 | self.failUnlessEqual('MYWORD', r.replay('MyWord')) 602 | 603 | r = self.instance.lower() 604 | self.failUnlessEqual('myword', r.replay('MyWord')) 605 | 606 | r = self.instance.replace('a', 'b') 607 | self.failUnlessEqual('bbb', r.replay('aaa')) 608 | 609 | r = self.instance.replace('a', 'b', 1) 610 | self.failUnlessEqual('baa', r.replay('aaa')) 611 | 612 | class QuotingTestCase(TransactionTestCase): 613 | """Ensure correct quoting of table and field names in queries""" 614 | 615 | @with_debug_queries 616 | def test_uses_backend_specific_quoting(self): 617 | """Backend-specific quotes should be used 618 | 619 | Table and field names should be quoted with the quote_name 620 | function provided by the database backend. The test here 621 | is a bit trivial since a real-life test case with 622 | PostgreSQL schema tricks or other table/field name munging 623 | would be difficult. 
624 | """ 625 | qn = db.connection.ops.quote_name 626 | qs = _select_related_instances(Entry, 'section', [1], 627 | 'batch_select_entry', 'section_id') 628 | db.reset_queries() 629 | list(qs) 630 | sql = db.connection.queries[-1]['sql'] 631 | self.failUnless(sql.startswith('SELECT (%s.%s) AS ' %( 632 | qn('batch_select_entry'), qn('section_id')))) 633 | 634 | @with_debug_queries 635 | def test_batch_select_related_quoted_section_id(self): 636 | """Field names should be quoted in the WHERE clause 637 | 638 | PostgreSQL is particularly picky about quoting when table 639 | or field names contain mixed case 640 | """ 641 | section = Section.objects.create(name='s1') 642 | entry = Entry.objects.create(section=section) 643 | 644 | db.reset_queries() 645 | sections = Section.objects.batch_select('entry').all() 646 | sections[0] 647 | sql = db.connection.queries[-1]['sql'] 648 | correct_where = ' WHERE "batch_select_entry"."section_id" IN (1)' 649 | self.failUnless(sql.endswith(correct_where), 650 | '"section_id" is not correctly quoted in the WHERE ' 651 | 'clause of %r' % sql) 652 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='django-batch-select', 5 | version=__import__('batch_select').__version__, 6 | description='batch select many-to-many and one-to-many fields (to help avoid n+1 query problem)', 7 | long_description=open('README.rst').read(), 8 | author='John Montgomery', 9 | author_email='john@littlespikeyland.com', 10 | url='http://github.com/lilspikey/django-batch-select/', 11 | download_url='http://github.com/lilspikey/django-batch-select/downloads', 12 | license='BSD', 13 | packages=find_packages(exclude=['ez_setup']), 14 | include_package_data=True, 15 | zip_safe=True, 16 | classifiers=[ 17 | 'Development Status :: 4 - Beta', 18 | 'Environment :: Web Environment', 
19 | 'Framework :: Django', 20 | 'Intended Audience :: Developers', 21 | 'License :: OSI Approved :: BSD License', 22 | 'Operating System :: OS Independent', 23 | 'Programming Language :: Python', 24 | 'Topic :: Software Development :: Libraries :: Python Modules', 25 | ], 26 | ) -------------------------------------------------------------------------------- /tests/run_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # run from parent directory (e.g. tests/run_tests.sh) 3 | django-admin.py test batch_select --pythonpath=. --pythonpath=tests --settings=test_settings -------------------------------------------------------------------------------- /tests/test_settings.py: -------------------------------------------------------------------------------- 1 | 2 | DEBUG = True 3 | TEMPLATE_DEBUG = DEBUG 4 | 5 | # Django 1.2 and less 6 | DATABASE_ENGINE = 'sqlite3' 7 | DATABASE_NAME = ':memory:' 8 | # Django 1.3 and above 9 | DATABASES = { 10 | 'default': { 11 | 'NAME': DATABASE_NAME, 12 | 'ENGINE': 'django.db.backends.sqlite3', 13 | }, 14 | } 15 | 16 | INSTALLED_APPS = ( 'batch_select', ) 17 | 18 | 19 | TESTING_BATCH_SELECT=True 20 | 21 | # enable this for coverage (using django test coverage 22 | # http://pypi.python.org/pypi/django-test-coverage ) 23 | #TEST_RUNNER = 'django-test-coverage.runner.run_tests' 24 | #COVERAGE_MODULES = ('batch_select.models', 'batch_select.replay') --------------------------------------------------------------------------------