├── AUTHORS
├── LICENSE
├── MANIFEST.in
├── README.rst
├── django_elasticsearch
│   ├── __init__.py
│   ├── base.py
│   ├── compiler.py
│   ├── creation.py
│   ├── fields.py
│   ├── manager.py
│   ├── mapping.py
│   ├── models.py
│   ├── router.py
│   ├── serializer.py
│   ├── south.py
│   └── utils.py
├── setup.py
└── tests
    └── testproj
        ├── __init__.py
        ├── manage.py
        ├── mixed
        │   ├── __init__.py
        │   ├── models.py
        │   ├── tests.py
        │   └── views.py
        ├── myapp
        │   ├── __init__.py
        │   ├── models.py
        │   ├── tests.py
        │   └── views.py
        ├── settings.py
        └── tests.py
/AUTHORS:
--------------------------------------------------------------------------------
1 | Alberto Paro
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2009, Ask Solem
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 |     * Redistributions of source code must retain the above copyright notice,
8 |       this list of conditions and the following disclaimer.
9 |     * Redistributions in binary form must reproduce the above copyright
10 |       notice, this list of conditions and the following disclaimer in the
11 |       documentation and/or other materials provided with the distribution.
12 |
13 | Neither the name of Ask Solem nor the names of its contributors may be used
14 | to endorse or promote products derived from this software without specific
15 | prior written permission.
16 |
17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
19 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
20 | PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS
21 | BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27 | POSSIBILITY OF SUCH DAMAGE.
28 |
29 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include MANIFEST.in
2 | include README.rst
3 | include AUTHORS
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | ===========================
2 | Django ElasticSearch Engine
3 | ===========================
4 | :Info: A Django database backend that adds ElasticSearch support (built on django-nonrel).
5 | :Author: Alberto [aparo] Paro (http://github.com/aparo)
6 |
7 | Requirements
8 | ------------
9 |
10 | - Django-nonrel http://github.com/aparo/django-nonrel
11 | - Djangotoolbox http://github.com/aparo/djangotoolbox
12 | - pyes http://github.com/aparo/pyes
13 |
14 |
15 | About Django
16 | ============
17 |
18 | Django is a high-level Python Web framework that encourages rapid development and clean, pragmatic design.
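Configuration
=============

The settings are not documented in this repository; the snippet below is only a
sketch inferred from the keys read in ``django_elasticsearch/base.py`` (``NAME`` is
used as the ElasticSearch index name, ``PORT`` must be convertible to an integer,
and ``fields.py`` expects the ``ENGINE`` string to contain ``elasticsearch``). The
host, port and index name shown are placeholders.

::

    DATABASES = {
        'default': {
            'ENGINE': 'django_elasticsearch',
            'NAME': 'myindex',       # becomes the index name
            'HOST': '127.0.0.1',
            'PORT': '9200',          # int() is applied at connection time
        },
    }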
19 | 20 | About ElasticSearch 21 | =================== 22 | 23 | TODO 24 | 25 | Infographics 26 | ============ 27 | :: 28 | - Django Nonrel branch 29 | - Manager 30 | - Compiler (ElasticSearch Engine one) 31 | - ElasticSearch 32 | 33 | django-elasticsearch uses the new django1.2 multi-database support and sets to the model the database using the "django_elasticsearch". 34 | 35 | Examples 36 | ======== 37 | 38 | :: 39 | 40 | class Person(models.Model): 41 | name = models.CharField(max_length=20) 42 | surname = models.CharField(max_length=20) 43 | age = models.IntegerField(null=True, blank=True) 44 | 45 | def __unicode__(self): 46 | return u"Person: %s %s" % (self.name, self.surname) 47 | 48 | >> p, created = Person.objects.get_or_create(name="John", defaults={'surname' : 'Doe'}) 49 | >> print created 50 | True 51 | >> p.age = 22 52 | >> p.save() 53 | 54 | === Querying === 55 | >> p = Person.objects.get(name__istartswith="JOH", age=22) 56 | >> p.pk 57 | u'4bd212d9ccdec2510f000000' 58 | -------------------------------------------------------------------------------- /django_elasticsearch/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | import logging 5 | logger = logging.getLogger(__name__) 6 | 7 | VERSION = (0, 1, 0) 8 | 9 | __version__ = ".".join(map(str, VERSION[0:3])) + "".join(VERSION[3:]) 10 | __author__ = "Alberto Paro" 11 | __contact__ = "alberto.paro@gmail.com" 12 | __homepage__ = "http://github.com/aparo/django-elasticsearch/" 13 | __docformat__ = "restructuredtext" 14 | 15 | from django.conf import settings 16 | 17 | if not "django_elasticsearch" in settings.INSTALLED_APPS: 18 | settings.INSTALLED_APPS.insert(0, "django_elasticsearch") -------------------------------------------------------------------------------- /django_elasticsearch/base.py: -------------------------------------------------------------------------------- 1 | from django.core.exceptions import ImproperlyConfigured 2 | 3 | from .creation import DatabaseCreation 4 | from .serializer import Decoder, Encoder 5 | from pyes import ES 6 | 7 | from djangotoolbox.db.base import NonrelDatabaseFeatures, \ 8 | NonrelDatabaseWrapper, NonrelDatabaseClient, \ 9 | NonrelDatabaseValidation, NonrelDatabaseIntrospection 10 | 11 | from djangotoolbox.db.base import NonrelDatabaseOperations 12 | 13 | class DatabaseOperations(NonrelDatabaseOperations): 14 | compiler_module = __name__.rsplit('.', 1)[0] + '.compiler' 15 | 16 | def sql_flush(self, style, tables, sequence_list): 17 | for table in tables: 18 | self.connection.db_connection.delete_mapping(self.connection.db_name, table) 19 | return [] 20 | 21 | def check_aggregate_support(self, aggregate): 22 | """ 23 | This function is meant to raise exception if backend does 24 | not support aggregation. 
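        In this ElasticSearch backend the check is currently a no-op: nothing is
        raised, so every aggregate is passed through to the compiler unchecked.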
25 | """ 26 | pass 27 | 28 | class DatabaseFeatures(NonrelDatabaseFeatures): 29 | string_based_auto_field = True 30 | 31 | class DatabaseClient(NonrelDatabaseClient): 32 | pass 33 | 34 | class DatabaseValidation(NonrelDatabaseValidation): 35 | pass 36 | 37 | class DatabaseIntrospection(NonrelDatabaseIntrospection): 38 | def table_names(self): 39 | """ 40 | Show defined models 41 | """ 42 | # TODO: get indices 43 | return [] 44 | 45 | def sequence_list(self): 46 | # TODO: check if it's necessary to implement that 47 | pass 48 | 49 | class DatabaseWrapper(NonrelDatabaseWrapper): 50 | def _cursor(self): 51 | self._ensure_is_connected() 52 | return self._connection 53 | 54 | def __init__(self, *args, **kwds): 55 | super(DatabaseWrapper, self).__init__(*args, **kwds) 56 | self.features = DatabaseFeatures(self) 57 | self.ops = DatabaseOperations(self) 58 | self.client = DatabaseClient(self) 59 | self.creation = DatabaseCreation(self) 60 | self.validation = DatabaseValidation(self) 61 | self.introspection = DatabaseIntrospection(self) 62 | self._is_connected = False 63 | 64 | @property 65 | def db_connection(self): 66 | self._ensure_is_connected() 67 | return self._db_connection 68 | 69 | def _ensure_is_connected(self): 70 | if not self._is_connected: 71 | try: 72 | port = int(self.settings_dict['PORT']) 73 | except ValueError: 74 | raise ImproperlyConfigured("PORT must be an integer") 75 | 76 | self.db_name = self.settings_dict['NAME'] 77 | 78 | self._connection = ES("%s:%s" % (self.settings_dict['HOST'], port), 79 | decoder=Decoder, 80 | encoder=Encoder, 81 | autorefresh=True, 82 | default_indices=[self.db_name]) 83 | 84 | self._db_connection = self._connection 85 | #auto index creation: check if to remove 86 | try: 87 | self._connection.create_index(self.db_name) 88 | except: 89 | pass 90 | # We're done! 
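            # Note: the bare except above swallows *any* failure from create_index
            # (connection errors included), not only the "index already exists" case.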
91 | self._is_connected = True 92 | -------------------------------------------------------------------------------- /django_elasticsearch/compiler.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import re 3 | 4 | from datetime import datetime 5 | from functools import wraps 6 | 7 | from django.conf import settings 8 | from django.db import models 9 | from django.db.models.sql import aggregates as sqlaggregates 10 | from django.db.models.sql.compiler import SQLCompiler 11 | from django.db.models.sql import aggregates as sqlaggregates 12 | from django.db.models.sql.constants import LOOKUP_SEP, MULTI, SINGLE 13 | from django.db.models.sql.where import AND, OR 14 | from django.db.utils import DatabaseError, IntegrityError 15 | from django.db.models.sql.where import WhereNode 16 | from django.db.models.fields import NOT_PROVIDED 17 | from django.utils.tree import Node 18 | from pyes import MatchAllQuery, FilteredQuery, BoolQuery, StringQuery, \ 19 | WildcardQuery, RegexTermQuery, RangeQuery, ESRange, \ 20 | TermQuery, ConstantScoreQuery, TermFilter, TermsFilter, NotFilter, RegexTermFilter 21 | from djangotoolbox.db.basecompiler import NonrelQuery, NonrelCompiler, \ 22 | NonrelInsertCompiler, NonrelUpdateCompiler, NonrelDeleteCompiler 23 | from django.db.models.fields import AutoField 24 | import logging 25 | 26 | TYPE_MAPPING_FROM_DB = { 27 | 'unicode': lambda val: unicode(val), 28 | 'int': lambda val: int(val), 29 | 'float': lambda val: float(val), 30 | 'bool': lambda val: bool(val), 31 | } 32 | 33 | TYPE_MAPPING_TO_DB = { 34 | 'unicode': lambda val: unicode(val), 35 | 'int': lambda val: int(val), 36 | 'float': lambda val: float(val), 37 | 'bool': lambda val: bool(val), 38 | 'date': lambda val: datetime(val.year, val.month, val.day), 39 | 'time': lambda val: datetime(2000, 1, 1, val.hour, val.minute, 40 | val.second, val.microsecond), 41 | } 42 | 43 | OPERATORS_MAP = { 44 | 'exact': lambda val: val, 45 | 'iexact': lambda val: val, #tofix 46 | 'startswith': lambda val: r'^%s' % re.escape(val), 47 | 'istartswith': lambda val: r'^%s' % re.escape(val), 48 | 'endswith': lambda val: r'%s$' % re.escape(val), 49 | 'iendswith': lambda val: r'%s$' % re.escape(val), 50 | 'contains': lambda val: r'%s' % re.escape(val), 51 | 'icontains': lambda val: r'%s' % re.escape(val), 52 | 'regex': lambda val: val, 53 | 'iregex': lambda val: re.compile(val, re.IGNORECASE), 54 | 'gt': lambda val: {"_from" : val, "include_lower" : False}, 55 | 'gte': lambda val: {"_from" : val, "include_lower" : True}, 56 | 'lt': lambda val: {"_to" : val, "include_upper": False}, 57 | 'lte': lambda val: {"_to" : val, "include_upper": True}, 58 | 'range': lambda val: {"_from" : val[0], "_to" : val[1], "include_lower" : True, "include_upper": True}, 59 | 'year': lambda val: {"_from" : val[0], "_to" : val[1], "include_lower" : True, "include_upper": False}, 60 | 'isnull': lambda val: None if val else {'$ne': None}, 61 | 'in': lambda val: val, 62 | } 63 | 64 | NEGATED_OPERATORS_MAP = { 65 | 'exact': lambda val: {'$ne': val}, 66 | 'gt': lambda val: {"_to" : val, "include_upper": True}, 67 | 'gte': lambda val: {"_to" : val, "include_upper": False}, 68 | 'lt': lambda val: {"_from" : val, "include_lower" : True}, 69 | 'lte': lambda val: {"_from" : val, "include_lower" : False}, 70 | 'isnull': lambda val: {'$ne': None} if val else None, 71 | 'in': lambda val: {'$nin': val}, 72 | } 73 | 74 | def _get_mapping(db_type, value, mapping): 75 | # TODO - comments. 
lotsa comments 76 | 77 | if value == NOT_PROVIDED: 78 | return None 79 | 80 | if value is None: 81 | return None 82 | 83 | if db_type in mapping: 84 | _func = mapping[db_type] 85 | else: 86 | _func = lambda val: val 87 | # TODO - what if the data is represented as list on the python side? 88 | if isinstance(value, list): 89 | return map(_func, value) 90 | 91 | return _func(value) 92 | 93 | def python2db(db_type, value): 94 | return _get_mapping(db_type, value, TYPE_MAPPING_TO_DB) 95 | 96 | def db2python(db_type, value): 97 | return _get_mapping(db_type, value, TYPE_MAPPING_FROM_DB) 98 | 99 | def safe_call(func): 100 | @wraps(func) 101 | def _func(*args, **kwargs): 102 | try: 103 | return func(*args, **kwargs) 104 | except Exception, e: 105 | import traceback 106 | traceback.print_exc() 107 | raise DatabaseError, DatabaseError(str(e)), sys.exc_info()[2] 108 | return _func 109 | 110 | class DBQuery(NonrelQuery): 111 | # ---------------------------------------------- 112 | # Public API 113 | # ---------------------------------------------- 114 | def __init__(self, compiler, fields): 115 | super(DBQuery, self).__init__(compiler, fields) 116 | self._connection = self.connection.db_connection 117 | self._ordering = [] 118 | self.db_query = ConstantScoreQuery() 119 | 120 | # This is needed for debugging 121 | def __repr__(self): 122 | return '' % (self.db_query, self._ordering) 123 | 124 | @safe_call 125 | def fetch(self, low_mark, high_mark): 126 | results = self._get_results() 127 | 128 | if low_mark > 0: 129 | results = results[low_mark:] 130 | if high_mark is not None: 131 | results = results[low_mark:high_mark - low_mark] 132 | 133 | for hit in results: 134 | entity = hit.get_data() 135 | entity['id'] = hit.meta.id 136 | yield entity 137 | 138 | @safe_call 139 | def count(self, limit=None): 140 | query = self.db_query 141 | if self.db_query.is_empty(): 142 | query = MatchAllQuery() 143 | 144 | res = self._connection.count(query, doc_types=self.query.model._meta.db_table) 145 | return res["count"] 146 | 147 | @safe_call 148 | def delete(self): 149 | self._collection.remove(self.db_query) 150 | 151 | @safe_call 152 | def order_by(self, ordering): 153 | for order in ordering: 154 | if order.startswith('-'): 155 | order, direction = order[1:], {"reverse" : True} 156 | else: 157 | direction = 'desc' 158 | self._ordering.append({order: direction}) 159 | 160 | # This function is used by the default add_filters() implementation 161 | @safe_call 162 | def add_filter(self, column, lookup_type, negated, db_type, value): 163 | if column == self.query.get_meta().pk.column: 164 | column = '_id' 165 | # Emulated/converted lookups 166 | 167 | if negated and lookup_type in NEGATED_OPERATORS_MAP: 168 | op = NEGATED_OPERATORS_MAP[lookup_type] 169 | negated = False 170 | else: 171 | op = OPERATORS_MAP[lookup_type] 172 | value = op(self.convert_value_for_db(db_type, value)) 173 | 174 | queryf = self._get_query_type(column, lookup_type, db_type, value) 175 | 176 | if negated: 177 | self.db_query.add([NotFilter(queryf)]) 178 | else: 179 | self.db_query.add([queryf]) 180 | 181 | def _get_query_type(self, column, lookup_type, db_type, value): 182 | if db_type == "unicode": 183 | if (lookup_type == "exact" or lookup_type == "iexact"): 184 | q = TermQuery(column, value) 185 | return q 186 | if (lookup_type == "startswith" or lookup_type == "istartswith"): 187 | return RegexTermFilter(column, value) 188 | if (lookup_type == "endswith" or lookup_type == "iendswith"): 189 | return RegexTermFilter(column, value) 190 | if 
(lookup_type == "contains" or lookup_type == "icontains"): 191 | return RegexTermFilter(column, value) 192 | if (lookup_type == "regex" or lookup_type == "iregex"): 193 | return RegexTermFilter(column, value) 194 | 195 | if db_type == "datetime" or db_type == "date": 196 | if (lookup_type == "exact" or lookup_type == "iexact"): 197 | return TermFilter(column, value) 198 | 199 | #TermFilter, TermsFilter 200 | if lookup_type in ["gt", "gte", "lt", "lte", "range", "year"]: 201 | value['field'] = column 202 | return RangeQuery(ESRange(**value)) 203 | if lookup_type == "in": 204 | # terms = [TermQuery(column, val) for val in value] 205 | # if len(terms) == 1: 206 | # return terms[0] 207 | # return BoolQuery(should=terms) 208 | return TermsFilter(field=column, values=value) 209 | raise NotImplemented 210 | 211 | def _get_results(self): 212 | """ 213 | @returns: elasticsearch iterator over results 214 | defined by self.query 215 | """ 216 | query = self.db_query 217 | if self.db_query.is_empty(): 218 | query = MatchAllQuery() 219 | if self._ordering: 220 | query.sort = self._ordering 221 | #print "query", self.query.tables, query 222 | return self._connection.search(query, indices=[self.connection.db_name], doc_types=self.query.model._meta.db_table) 223 | 224 | class SQLCompiler(NonrelCompiler): 225 | """ 226 | A simple query: no joins, no distinct, etc. 227 | """ 228 | query_class = DBQuery 229 | 230 | def convert_value_from_db(self, db_type, value): 231 | # Handle list types 232 | if db_type is not None and \ 233 | isinstance(value, (list, tuple)) and len(value) and \ 234 | db_type.startswith('ListField:'): 235 | db_sub_type = db_type.split(':', 1)[1] 236 | value = [self.convert_value_from_db(db_sub_type, subvalue) 237 | for subvalue in value] 238 | else: 239 | value = db2python(db_type, value) 240 | return value 241 | 242 | # This gets called for each field type when you insert() an entity. 
243 | # db_type is the string that you used in the DatabaseCreation mapping 244 | def convert_value_for_db(self, db_type, value): 245 | if db_type is not None and \ 246 | isinstance(value, (list, tuple)) and len(value) and \ 247 | db_type.startswith('ListField:'): 248 | db_sub_type = db_type.split(':', 1)[1] 249 | value = [self.convert_value_for_db(db_sub_type, subvalue) 250 | for subvalue in value] 251 | else: 252 | value = python2db(db_type, value) 253 | return value 254 | 255 | def insert_params(self): 256 | conn = self.connection 257 | 258 | params = { 259 | 'safe': conn.safe_inserts, 260 | } 261 | 262 | if conn.w: 263 | params['w'] = conn.w 264 | 265 | return params 266 | 267 | def _get_ordering(self): 268 | if not self.query.default_ordering: 269 | ordering = self.query.order_by 270 | else: 271 | ordering = self.query.order_by or self.query.get_meta().ordering 272 | result = [] 273 | for order in ordering: 274 | if LOOKUP_SEP in order: 275 | #raise DatabaseError("Ordering can't span tables on non-relational backends (%s)" % order) 276 | print "Ordering can't span tables on non-relational backends (%s):skipping" % order 277 | continue 278 | if order == '?': 279 | raise DatabaseError("Randomized ordering isn't supported by the backend") 280 | 281 | order = order.lstrip('+') 282 | 283 | descending = order.startswith('-') 284 | name = order.lstrip('-') 285 | if name == 'pk': 286 | name = self.query.get_meta().pk.name 287 | order = '-' + name if descending else name 288 | 289 | if self.query.standard_ordering: 290 | result.append(order) 291 | else: 292 | if descending: 293 | result.append(name) 294 | else: 295 | result.append('-' + name) 296 | return result 297 | 298 | 299 | class SQLInsertCompiler(NonrelInsertCompiler, SQLCompiler): 300 | @safe_call 301 | def insert(self, data, return_id=False): 302 | pk_column = self.query.get_meta().pk.column 303 | pk = None 304 | if pk_column in data: 305 | pk = data[pk_column] 306 | db_table = self.query.get_meta().db_table 307 | logging.debug("Insert data %s: %s" % (db_table, data)) 308 | #print("Insert data %s: %s" % (db_table, data)) 309 | res = self.connection.db_connection.index(data, self.connection.db_name, db_table, id=pk) 310 | #print "Insert result", res 311 | return res['_id'] 312 | 313 | # TODO: Define a common nonrel API for updates and add it to the nonrel 314 | # backend base classes and port this code to that API 315 | class SQLUpdateCompiler(SQLCompiler): 316 | def execute_sql(self, return_id=False): 317 | """ 318 | self.query - the data that should be inserted 319 | """ 320 | data = {} 321 | for (field, value), column in zip(self.query.values, self.query.columns): 322 | data[column] = python2db(field.db_type(connection=self.connection), value) 323 | # every object should have a unique pk 324 | pk_field = self.query.model._meta.pk 325 | pk_name = pk_field.attname 326 | 327 | db_table = self.query.get_meta().db_table 328 | res = self.connection.db_connection.index(data, self.connection.db_name, db_table, id=pk) 329 | 330 | return res['_id'] 331 | 332 | class SQLDeleteCompiler(NonrelDeleteCompiler, SQLCompiler): 333 | def execute_sql(self, return_id=False): 334 | """ 335 | self.query - the data that should be inserted 336 | """ 337 | db_table = self.query.get_meta().db_table 338 | if len(self.query.where.children) == 1 and isinstance(self.query.where.children[0][0].field, AutoField) and self.query.where.children[0][1] == "in": 339 | for pk in self.query.where.children[0][3]: 340 | 
self.connection.db_connection.delete(self.connection.db_name, db_table, pk) 341 | return 342 | -------------------------------------------------------------------------------- /django_elasticsearch/creation.py: -------------------------------------------------------------------------------- 1 | from djangotoolbox.db.base import NonrelDatabaseCreation 2 | from pyes.exceptions import NotFoundException 3 | TEST_DATABASE_PREFIX = 'test_' 4 | 5 | class DatabaseCreation(NonrelDatabaseCreation): 6 | data_types = { 7 | 'DateTimeField': 'datetime', 8 | 'DateField': 'date', 9 | 'TimeField': 'time', 10 | 'FloatField': 'float', 11 | 'EmailField': 'unicode', 12 | 'URLField': 'unicode', 13 | 'BooleanField': 'bool', 14 | 'NullBooleanField': 'bool', 15 | 'CharField': 'unicode', 16 | 'CommaSeparatedIntegerField': 'unicode', 17 | 'IPAddressField': 'unicode', 18 | 'SlugField': 'unicode', 19 | 'FileField': 'unicode', 20 | 'FilePathField': 'unicode', 21 | 'TextField': 'unicode', 22 | 'XMLField': 'unicode', 23 | 'IntegerField': 'int', 24 | 'SmallIntegerField': 'int', 25 | 'PositiveIntegerField': 'int', 26 | 'PositiveSmallIntegerField': 'int', 27 | 'BigIntegerField': 'int', 28 | 'GenericAutoField': 'unicode', 29 | 'StringForeignKey': 'unicode', 30 | 'AutoField': 'unicode', 31 | 'RelatedAutoField': 'unicode', 32 | 'OneToOneField': 'int', 33 | 'DecimalField': 'float', 34 | } 35 | 36 | def sql_indexes_for_field(self, model, f, style): 37 | """Not required. In ES all is index!!""" 38 | return [] 39 | 40 | def index_fields_group(self, model, group, style): 41 | """Not required. In ES all is index!!""" 42 | return [] 43 | 44 | def sql_indexes_for_model(self, model, style): 45 | """Not required. In ES all is index!!""" 46 | return [] 47 | 48 | def sql_create_model(self, model, style, known_models=set()): 49 | from mapping import model_to_mapping 50 | mappings = model_to_mapping(model) 51 | self.connection.db_connection.put_mapping(model._meta.db_table, {mappings.name:mappings.as_dict()}) 52 | return [], {} 53 | 54 | def set_autocommit(self): 55 | "Make sure a connection is in autocommit mode." 56 | pass 57 | 58 | def create_test_db(self, verbosity=1, autoclobber=False): 59 | # No need to create databases in mongoDB :) 60 | # but we can make sure that if the database existed is emptied 61 | from django.core.management import call_command 62 | if self.connection.settings_dict.get('TEST_NAME'): 63 | test_database_name = self.connection.settings_dict['TEST_NAME'] 64 | elif 'NAME' in self.connection.settings_dict: 65 | test_database_name = TEST_DATABASE_PREFIX + self.connection.settings_dict['NAME'] 66 | elif 'DATABASE_NAME' in self.connection.settings_dict: 67 | if self.connection.settings_dict['DATABASE_NAME'].startswith(TEST_DATABASE_PREFIX): 68 | # already been set up 69 | # must be because this is called from a setUp() instead of something formal. 70 | # suspect this Django 1.1 71 | test_database_name = self.connection.settings_dict['DATABASE_NAME'] 72 | else: 73 | test_database_name = TEST_DATABASE_PREFIX + \ 74 | self.connection.settings_dict['DATABASE_NAME'] 75 | else: 76 | raise ValueError("Name for test database not defined") 77 | 78 | self.connection.settings_dict['NAME'] = test_database_name 79 | # This is important. Here we change the settings so that all other code 80 | # things that the chosen database is now the test database. This means 81 | # that nothing needs to change in the test code for working with 82 | # connections, databases and collections. 
It will appear the same as 83 | # when working with non-test code. 84 | 85 | # In this phase it will only drop the database if it already existed 86 | # which could potentially happen if the test database was created but 87 | # was never dropped at the end of the tests 88 | try: 89 | self._drop_database(test_database_name) 90 | except NotFoundException: 91 | pass 92 | 93 | self.connection.db_connection.create_index(test_database_name) 94 | self.connection.db_connection.cluster_health(wait_for_status='green') 95 | 96 | call_command('syncdb', verbosity=max(verbosity - 1, 0), interactive=False, database=self.connection.alias) 97 | 98 | 99 | def destroy_test_db(self, old_database_name, verbosity=1): 100 | """ 101 | Destroy a test database, prompting the user for confirmation if the 102 | database already exists. Returns the name of the test database created. 103 | """ 104 | if verbosity >= 1: 105 | print "Destroying test database '%s'..." % self.connection.alias 106 | test_database_name = self.connection.settings_dict['NAME'] 107 | self._drop_database(test_database_name) 108 | self.connection.settings_dict['NAME'] = old_database_name 109 | 110 | def _drop_database(self, database_name): 111 | try: 112 | self.connection.db_connection.delete_index(database_name) 113 | except NotFoundException: 114 | pass 115 | self.connection.db_connection.cluster_health(wait_for_status='green') 116 | 117 | def sql_destroy_model(self, model, references_to_delete, style): 118 | print model 119 | -------------------------------------------------------------------------------- /django_elasticsearch/fields.py: -------------------------------------------------------------------------------- 1 | import django 2 | from django.conf import settings 3 | from django.db import models 4 | from django.core import exceptions, serializers 5 | from django.db.models import Field, CharField 6 | from django.db.models.fields import FieldDoesNotExist 7 | from django.utils.translation import ugettext_lazy as _ 8 | from django.db.models.fields import AutoField as DJAutoField 9 | from django.db.models import signals 10 | import uuid 11 | from .manager import Manager 12 | __all__ = ["EmbeddedModel"] 13 | __doc__ = "ES special fields" 14 | 15 | class EmbeddedModel(models.Model): 16 | _embedded_in = None 17 | class Meta: 18 | abstract = True 19 | 20 | def save(self, *args, **kwargs): 21 | if self.pk is None: 22 | self.pk = str(uuid.uuid4()) 23 | if self._embedded_in is None: 24 | raise RuntimeError("Invalid save") 25 | self._embedded_in.save() 26 | 27 | def serialize(self): 28 | if self.pk is None: 29 | self.pk = "TODO" 30 | self.id = self.pk 31 | result = {'_app':self._meta.app_label, 32 | '_model':self._meta.module_name, 33 | '_id':self.pk} 34 | for field in self._meta.fields: 35 | result[field.attname] = getattr(self, field.attname) 36 | return result 37 | 38 | class ElasticField(CharField): 39 | 40 | def __init__(self, *args, **kwargs): 41 | self.doc_type = kwargs.pop("doc_type", None) 42 | 43 | # This field stores the document id and has to be unique 44 | kwargs["unique"] = True 45 | 46 | # Let's force the field as db_index so we can get its value faster. 
47 | kwargs["db_index"] = True 48 | kwargs["max_length"] = 255 49 | 50 | super(ElasticField, self).__init__(*args, **kwargs) 51 | 52 | def contribute_to_class(self, cls, name): 53 | super(ElasticField, self).contribute_to_class(cls, name) 54 | 55 | 56 | index = cls._meta.db_table 57 | doc_type = self.doc_type 58 | att_id_name = "_%s_id" % name 59 | att_cache_name = "_%s_cache" % name 60 | att_val_name = "_%s_val" % name 61 | 62 | def _get(self): 63 | """ 64 | self is the model instance not the field instance 65 | """ 66 | from django.db import connections 67 | elst = connections[self._meta.elst_connection] 68 | if not hasattr(self, att_cache_name) and not getattr(self, att_val_name, None) and getattr(self, att_id_name, None): 69 | elst = ES('http://127.0.0.1:9200/') 70 | val = elst.get(index, doc_type, id=getattr(self, att_id_name)).get("_source", None) 71 | setattr(self, att_cache_name, val) 72 | setattr(self, att_val_name, val) 73 | return getattr(self, att_val_name, None) 74 | 75 | def _set(self, val): 76 | """ 77 | self is the model instance not the field instance 78 | """ 79 | if isinstance(val, basestring) and not hasattr(self, att_id_name): 80 | setattr(self, att_id_name, val) 81 | else: 82 | setattr(self, att_val_name, val or None) 83 | 84 | setattr(cls, self.attname, property(_get, _set)) 85 | 86 | 87 | # def db_type(self, connection): 88 | # return "elst" 89 | 90 | def pre_save(self, model_instance, add): 91 | from django.db import connections 92 | elst = connections[model_instance._meta.elst_connection] 93 | 94 | id = getattr(model_instance, "_%s_id" % self.attname, None) 95 | value = getattr(model_instance, "_%s_val" % self.attname, None) 96 | index = model_instance._meta.db_table 97 | doc_type = self.doc_type 98 | 99 | if value == getattr(model_instance, "_%s_cache" % self.attname, None) and id: 100 | return id 101 | 102 | if value: 103 | # elst = ES('http://127.0.0.1:9200/') 104 | result = elst.index(doc=value, index=index, doc_type=doc_type, id=id or None) 105 | setattr(model_instance, "_%s_id" % self.attname, result["_id"]) 106 | setattr(model_instance, "_%s_cache" % self.attname, value) 107 | return getattr(model_instance, "_%s_id" % self.attname, u"") 108 | 109 | # 110 | # Fix standard models to work with elasticsearch 111 | # 112 | 113 | def autofield_to_python(value): 114 | if value is None: 115 | return value 116 | try: 117 | return str(value) 118 | except (TypeError, ValueError): 119 | raise exceptions.ValidationError(self.error_messages['invalid']) 120 | 121 | def autofield_get_prep_value(value): 122 | if value is None: 123 | return None 124 | return unicode(value) 125 | 126 | def pre_init_mongodb_signal(sender, args, **kwargs): 127 | if sender._meta.abstract: 128 | return 129 | 130 | from django.conf import settings 131 | 132 | database = settings.DATABASES[sender.objects.db] 133 | if not 'elasticsearch' in database['ENGINE']: 134 | return 135 | 136 | if not hasattr(django, 'MODIFIED') and isinstance(sender._meta.pk, DJAutoField): 137 | pk = sender._meta.pk 138 | setattr(pk, "to_python", autofield_to_python) 139 | setattr(pk, "get_prep_value", autofield_get_prep_value) 140 | 141 | class ESMeta(object): 142 | pass 143 | 144 | def add_elasticsearch_manager(sender, **kwargs): 145 | """ 146 | Fix autofield 147 | """ 148 | from django.conf import settings 149 | 150 | cls = sender 151 | database = settings.DATABASES[cls.objects.db] 152 | if 'elasticsearch' in database['ENGINE']: 153 | if cls._meta.abstract: 154 | return 155 | 156 | if getattr(cls, 'es', None) is None: 157 
| # Create the default manager, if needed. 158 | try: 159 | cls._meta.get_field('es') 160 | raise ValueError("Model %s must specify a custom Manager, because it has a field named 'objects'" % cls.__name__) 161 | except FieldDoesNotExist: 162 | pass 163 | setattr(cls, 'es', Manager()) 164 | 165 | es_meta = getattr(cls, "ESMeta", ESMeta).__dict__.copy() 166 | # setattr(cls, "_meta", ESMeta()) 167 | for attr in es_meta: 168 | if attr.startswith("_"): 169 | continue 170 | setattr(cls._meta, attr, es_meta[attr]) 171 | 172 | -------------------------------------------------------------------------------- /django_elasticsearch/manager.py: -------------------------------------------------------------------------------- 1 | from django.db import connections 2 | from django.db.models.manager import Manager as DJManager 3 | 4 | import re 5 | import copy 6 | from .utils import dict_keys_to_str 7 | try: 8 | from django.core.exceptions import MultipleObjectsReturned, ObjectDoesNotExist 9 | except ImportError: 10 | class ObjectDoesNotExist(Exception): 11 | pass 12 | class MultipleObjectsReturned(Exception): 13 | pass 14 | 15 | DoesNotExist = ObjectDoesNotExist 16 | 17 | __all__ = ['queryset_manager', 'Q', 'InvalidQueryError', 18 | 'InvalidCollectionError'] 19 | 20 | # The maximum number of items to display in a QuerySet.__repr__ 21 | REPR_OUTPUT_SIZE = 20 22 | 23 | class InvalidQueryError(Exception): 24 | pass 25 | 26 | 27 | class OperationError(Exception): 28 | pass 29 | 30 | class InvalidCollectionError(Exception): 31 | pass 32 | 33 | DoesNotExist = ObjectDoesNotExist 34 | RE_TYPE = type(re.compile('')) 35 | 36 | 37 | class Q(object): 38 | 39 | OR = '||' 40 | AND = '&&' 41 | OPERATORS = { 42 | 'eq': 'this.%(field)s == %(value)s', 43 | 'ne': 'this.%(field)s != %(value)s', 44 | 'gt': 'this.%(field)s > %(value)s', 45 | 'gte': 'this.%(field)s >= %(value)s', 46 | 'lt': 'this.%(field)s < %(value)s', 47 | 'lte': 'this.%(field)s <= %(value)s', 48 | 'lte': 'this.%(field)s <= %(value)s', 49 | 'in': '%(value)s.indexOf(this.%(field)s) != -1', 50 | 'nin': '%(value)s.indexOf(this.%(field)s) == -1', 51 | 'mod': '%(field)s %% %(value)s', 52 | 'all': ('%(value)s.every(function(a){' 53 | 'return this.%(field)s.indexOf(a) != -1 })'), 54 | 'size': 'this.%(field)s.length == %(value)s', 55 | 'exists': 'this.%(field)s != null', 56 | 'regex_eq': '%(value)s.test(this.%(field)s)', 57 | 'regex_ne': '!%(value)s.test(this.%(field)s)', 58 | } 59 | 60 | def __init__(self, **query): 61 | self.query = [query] 62 | 63 | def _combine(self, other, op): 64 | obj = Q() 65 | obj.query = ['('] + copy.deepcopy(self.query) + [op] 66 | obj.query += copy.deepcopy(other.query) + [')'] 67 | return obj 68 | 69 | def __or__(self, other): 70 | return self._combine(other, self.OR) 71 | 72 | def __and__(self, other): 73 | return self._combine(other, self.AND) 74 | 75 | def as_js(self, document): 76 | js = [] 77 | js_scope = {} 78 | for i, item in enumerate(self.query): 79 | if isinstance(item, dict): 80 | item_query = QuerySet._transform_query(document, **item) 81 | # item_query will values will either be a value or a dict 82 | js.append(self._item_query_as_js(item_query, js_scope, i)) 83 | else: 84 | js.append(item) 85 | return pymongo.code.Code(' '.join(js), js_scope) 86 | 87 | def _item_query_as_js(self, item_query, js_scope, item_num): 88 | # item_query will be in one of the following forms 89 | # {'age': 25, 'name': 'Test'} 90 | # {'age': {'$lt': 25}, 'name': {'$in': ['Test', 'Example']} 91 | # {'age': {'$lt': 25, '$gt': 18}} 92 | js = [] 93 | for 
i, (key, value) in enumerate(item_query.items()): 94 | op = 'eq' 95 | # Construct a variable name for the value in the JS 96 | value_name = 'i%sf%s' % (item_num, i) 97 | if isinstance(value, dict): 98 | # Multiple operators for this field 99 | for j, (op, value) in enumerate(value.items()): 100 | # Create a custom variable name for this operator 101 | op_value_name = '%so%s' % (value_name, j) 102 | # Construct the JS that uses this op 103 | value, operation_js = self._build_op_js(op, key, value, 104 | op_value_name) 105 | # Update the js scope with the value for this op 106 | js_scope[op_value_name] = value 107 | js.append(operation_js) 108 | else: 109 | # Construct the JS for this field 110 | value, field_js = self._build_op_js(op, key, value, value_name) 111 | js_scope[value_name] = value 112 | js.append(field_js) 113 | return ' && '.join(js) 114 | 115 | def _build_op_js(self, op, key, value, value_name): 116 | """Substitute the values in to the correct chunk of Javascript. 117 | """ 118 | if isinstance(value, RE_TYPE): 119 | # Regexes are handled specially 120 | if op.strip('$') == 'ne': 121 | op_js = Q.OPERATORS['regex_ne'] 122 | else: 123 | op_js = Q.OPERATORS['regex_eq'] 124 | else: 125 | op_js = Q.OPERATORS[op.strip('$')] 126 | 127 | # Perform the substitution 128 | operation_js = op_js % { 129 | 'field': key, 130 | 'value': value_name 131 | } 132 | return value, operation_js 133 | 134 | class InternalMetadata: 135 | def __init__(self, meta): 136 | self.object_name = meta["object_name"] 137 | 138 | class InternalModel: 139 | """ 140 | An internal queryset model to be embedded in a query set for django compatibility. 141 | """ 142 | def __init__(self, document): 143 | self.document = document 144 | self._meta = InternalMetadata(document._meta) 145 | self.DoesNotExist = ObjectDoesNotExist 146 | 147 | class QuerySet(object): 148 | """A set of results returned from a query. Wraps a ES cursor, 149 | providing :class:`~mongoengine.Document` objects as the results. 150 | """ 151 | 152 | def __init__(self, document, collection): 153 | self._document = document 154 | self._collection_obj = collection 155 | self._accessed_collection = False 156 | self._query = {} 157 | self._where_clause = None 158 | self._loaded_fields = [] 159 | self._ordering = [] 160 | self.transform = TransformDjango() 161 | 162 | # If inheritance is allowed, only return instances and instances of 163 | # subclasses of the class being used 164 | #if document._meta.get('allow_inheritance'): 165 | #self._query = {'_types': self._document._class_name} 166 | self._cursor_obj = None 167 | self._limit = None 168 | self._skip = None 169 | 170 | #required for compatibility with django 171 | #self.model = InternalModel(document) 172 | 173 | def __call__(self, q_obj=None, **query): 174 | """Filter the selected documents by calling the 175 | :class:`~mongoengine.queryset.QuerySet` with a query. 
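        For example (illustrative, assuming the ``es`` manager attached by
        ``add_elasticsearch_manager`` in ``fields.py``)::

            Person.es.filter(age__gte=18, name__istartswith='jo')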
176 | 177 | :param q_obj: a :class:`~mongoengine.queryset.Q` object to be used in 178 | the query; the :class:`~mongoengine.queryset.QuerySet` is filtered 179 | multiple times with different :class:`~mongoengine.queryset.Q` 180 | objects, only the last one will be used 181 | :param query: Django-style query keyword arguments 182 | """ 183 | if q_obj: 184 | self._where_clause = q_obj.as_js(self._document) 185 | query = QuerySet._transform_query(_doc_cls=self._document, **query) 186 | self._query.update(query) 187 | return self 188 | 189 | def filter(self, *q_objs, **query): 190 | """An alias of :meth:`~mongoengine.queryset.QuerySet.__call__` 191 | """ 192 | return self.__call__(*q_objs, **query) 193 | 194 | def find(self, query): 195 | self._query.update(self.transform.transform_incoming(query, self._collection)) 196 | return self 197 | 198 | def exclude(self, *q_objs, **query): 199 | """An alias of :meth:`~mongoengine.queryset.QuerySet.__call__` 200 | """ 201 | query["not"] = True 202 | return self.__call__(*q_objs, **query) 203 | 204 | def all(self): 205 | """An alias of :meth:`~mongoengine.queryset.QuerySet.__call__` 206 | """ 207 | return self.__call__() 208 | 209 | def distinct(self, *args, **kwargs): 210 | """ 211 | Distinct method 212 | """ 213 | return self._cursor.distinct(*args, **kwargs) 214 | 215 | @property 216 | def _collection(self): 217 | """Property that returns the collection object. This allows us to 218 | perform operations only if the collection is accessed. 219 | """ 220 | return self._collection_obj 221 | 222 | def values(self, *args): 223 | return (args and [dict(zip(args,[getattr(doc, key) for key in args])) for doc in self]) or [obj for obj in self._cursor.clone()] 224 | 225 | def values_list(self, *args, **kwargs): 226 | flat = kwargs.pop("flat", False) 227 | if flat and len(args) != 1: 228 | raise Exception("args len must be 1 when flat=True") 229 | 230 | return (flat and self.distinct(args[0] if not args[0] in ["id", "pk"] else "_id")) or zip(*[self.distinct(field if not field in ["id", "pk"] else "_id") for field in args]) 231 | 232 | @property 233 | def _cursor(self): 234 | if self._cursor_obj is None: 235 | cursor_args = {} 236 | if self._loaded_fields: 237 | cursor_args = {'fields': self._loaded_fields} 238 | self._cursor_obj = self._collection.find(self._query, 239 | **cursor_args) 240 | # Apply where clauses to cursor 241 | if self._where_clause: 242 | self._cursor_obj.where(self._where_clause) 243 | 244 | # apply default ordering 245 | # if self._document._meta['ordering']: 246 | # self.order_by(*self._document._meta['ordering']) 247 | 248 | return self._cursor_obj.clone() 249 | 250 | @classmethod 251 | def _lookup_field(cls, document, fields): 252 | """ 253 | Looks for "field" in "document" 254 | """ 255 | if isinstance(fields, (tuple, list)): 256 | return [document._meta.get_field_by_name((field == "pk" and "id") or field)[0] for field in fields] 257 | return document._meta.get_field_by_name((fields == "pk" and "id") or fields)[0] 258 | 259 | @classmethod 260 | def _translate_field_name(cls, doc_cls, field, sep='.'): 261 | """Translate a field attribute name to a database field name. 262 | """ 263 | parts = field.split(sep) 264 | parts = [f.attname for f in QuerySet._lookup_field(doc_cls, parts)] 265 | return '.'.join(parts) 266 | 267 | @classmethod 268 | def _transform_query(self, _doc_cls=None, **parameters): 269 | """ 270 | Converts parameters to elasticsearch queries. 
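        For example (illustrative): ``age__gte=18`` becomes ``{'age': {'$gte': 18}}``,
        and ``name__istartswith='jo'`` becomes a compiled ``^jo`` regex with
        ``re.IGNORECASE``.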
271 | """ 272 | spec = {} 273 | operators = ['ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', 'all', 'size', 'exists'] 274 | match_operators = ['contains', 'icontains', 'startswith', 'istartswith', 'endswith', 'iendswith', 'exact', 'iexact'] 275 | exclude = parameters.pop("not", False) 276 | 277 | for key, value in parameters.items(): 278 | 279 | 280 | parts = key.split("__") 281 | lookup_type = (len(parts)>=2) and ( parts[-1] in operators + match_operators and parts.pop()) or "" 282 | 283 | # Let's get the right field and be sure that it exists 284 | parts[0] = QuerySet._lookup_field(_doc_cls, parts[0]).attname 285 | 286 | if not lookup_type and len(parts)==1: 287 | if exclude: 288 | value = {"$ne" : value} 289 | spec.update({parts[0] : value}) 290 | continue 291 | 292 | if parts[0] == "id": 293 | parts[0] = "_id" 294 | value = [isinstance(par, basestring) or par for par in value] 295 | 296 | if lookup_type in ['contains', 'icontains', 297 | 'startswith', 'istartswith', 298 | 'endswith', 'iendswith', 299 | 'exact', 'iexact']: 300 | flags = 0 301 | if lookup_type.startswith('i'): 302 | flags = re.IGNORECASE 303 | lookup_type = lookup_type.lstrip('i') 304 | 305 | regex = r'%s' 306 | if lookup_type == 'startswith': 307 | regex = r'^%s' 308 | elif lookup_type == 'endswith': 309 | regex = r'%s$' 310 | elif lookup_type == 'exact': 311 | regex = r'^%s$' 312 | 313 | value = re.compile(regex % value, flags) 314 | 315 | elif lookup_type in operators: 316 | value = { "$" + lookup_type : value} 317 | elif lookup_type and len(parts)==1: 318 | raise DatabaseError("Unsupported lookup type: %r" % lookup_type) 319 | 320 | key = '.'.join(parts) 321 | if exclude: 322 | value = {"$ne" : value} 323 | spec.update({key : value}) 324 | 325 | return spec 326 | 327 | def get(self, *q_objs, **query): 328 | """Retrieve the the matching object raising id django is available 329 | :class:`~django.core.exceptions.MultipleObjectsReturned` or 330 | :class:`~django.core.exceptions.ObjectDoesNotExist` exceptions if multiple or 331 | no results are found. 332 | If django is not available: 333 | :class:`~mongoengine.queryset.MultipleObjectsReturned` or 334 | `DocumentName.MultipleObjectsReturned` exception if multiple results and 335 | :class:`~mongoengine.queryset.DoesNotExist` or `DocumentName.DoesNotExist` 336 | if no results are found. 337 | 338 | .. versionadded:: 0.3 339 | """ 340 | self.__call__(*q_objs, **query) 341 | count = self.count() 342 | if count == 1: 343 | return self[0] 344 | elif count > 1: 345 | message = u'%d items returned, instead of 1' % count 346 | raise self._document.MultipleObjectsReturned(message) 347 | else: 348 | raise self._document.DoesNotExist("%s matching query does not exist." 349 | % self._document._meta.object_name) 350 | 351 | def get_or_create(self, *q_objs, **query): 352 | """Retrieve unique object or create, if it doesn't exist. Returns a tuple of 353 | ``(object, created)``, where ``object`` is the retrieved or created object 354 | and ``created`` is a boolean specifying whether a new object was created. Raises 355 | :class:`~mongoengine.queryset.MultipleObjectsReturned` or 356 | `DocumentName.MultipleObjectsReturned` if multiple results are found. 357 | A new document will be created if the document doesn't exists; a 358 | dictionary of default values for the new document may be provided as a 359 | keyword argument called :attr:`defaults`. 360 | 361 | .. 
versionadded:: 0.3 362 | """ 363 | defaults = query.get('defaults', {}) 364 | if 'defaults' in query: 365 | del query['defaults'] 366 | 367 | self.__call__(*q_objs, **query) 368 | count = self.count() 369 | if count == 0: 370 | query.update(defaults) 371 | doc = self._document(**query) 372 | doc.save() 373 | return doc, True 374 | elif count == 1: 375 | return self.first(), False 376 | else: 377 | message = u'%d items returned, instead of 1' % count 378 | raise self._document.MultipleObjectsReturned(message) 379 | 380 | def first(self): 381 | """Retrieve the first object matching the query. 382 | """ 383 | try: 384 | result = self[0] 385 | except IndexError: 386 | result = None 387 | return result 388 | 389 | def with_id(self, object_id): 390 | """Retrieve the object matching the id provided. 391 | 392 | :param object_id: the value for the id of the document to look up 393 | """ 394 | id_field = self._document._meta['id_field'] 395 | object_id = self._document._fields[id_field].to_mongo(object_id) 396 | 397 | result = self._collection.find_one({'_id': (not isinstance(object_id, ObjectId) and ObjectId(object_id)) or object_id}) 398 | if result is not None: 399 | result = self._document(**dict_keys_to_str(result)) 400 | return result 401 | 402 | def in_bulk(self, object_ids): 403 | """Retrieve a set of documents by their ids. 404 | 405 | :param object_ids: a list or tuple of id's 406 | :rtype: dict of ids as keys and collection-specific 407 | Document subclasses as values. 408 | 409 | .. versionadded:: 0.3 410 | """ 411 | doc_map = {} 412 | 413 | docs = self._collection.find({'_id': {'$in': [ (not isinstance(id, ObjectId) and ObjectId(id)) or id for id in object_ids]}}) 414 | for doc in docs: 415 | doc_map[str(doc['id'])] = self._document(**dict_keys_to_str(doc)) 416 | 417 | return doc_map 418 | 419 | def count(self): 420 | """Count the selected elements in the query. 421 | """ 422 | if self._limit == 0: 423 | return 0 424 | return self._cursor.count(with_limit_and_skip=False) 425 | 426 | def __len__(self): 427 | return self.count() 428 | 429 | def map_reduce(self, map_f, reduce_f, finalize_f=None, limit=None, 430 | scope=None, keep_temp=False): 431 | """Perform a map/reduce query using the current query spec 432 | and ordering. While ``map_reduce`` respects ``QuerySet`` chaining, 433 | it must be the last call made, as it does not return a maleable 434 | ``QuerySet``. 435 | 436 | See the :meth:`~mongoengine.tests.QuerySetTest.test_map_reduce` 437 | and :meth:`~mongoengine.tests.QuerySetTest.test_map_advanced` 438 | tests in ``tests.queryset.QuerySetTest`` for usage examples. 439 | 440 | :param map_f: map function, as :class:`~pymongo.code.Code` or string 441 | :param reduce_f: reduce function, as 442 | :class:`~pymongo.code.Code` or string 443 | :param finalize_f: finalize function, an optional function that 444 | performs any post-reduction processing. 445 | :param scope: values to insert into map/reduce global scope. Optional. 446 | :param limit: number of objects from current query to provide 447 | to map/reduce method 448 | :param keep_temp: keep temporary table (boolean, default ``True``) 449 | 450 | Returns an iterator yielding 451 | :class:`~mongoengine.document.MapReduceDocument`. 452 | 453 | .. note:: Map/Reduce requires server version **>= 1.1.1**. The PyMongo 454 | :meth:`~pymongo.collection.Collection.map_reduce` helper requires 455 | PyMongo version **>= 1.2**. 456 | 457 | .. 
versionadded:: 0.3 458 | """ 459 | #from document import MapReduceDocument 460 | 461 | if not hasattr(self._collection, "map_reduce"): 462 | raise NotImplementedError("Requires MongoDB >= 1.1.1") 463 | 464 | map_f_scope = {} 465 | if isinstance(map_f, pymongo.code.Code): 466 | map_f_scope = map_f.scope 467 | map_f = unicode(map_f) 468 | # map_f = pymongo.code.Code(self._sub_js_fields(map_f), map_f_scope) 469 | map_f = pymongo.code.Code(map_f, map_f_scope) 470 | 471 | reduce_f_scope = {} 472 | if isinstance(reduce_f, pymongo.code.Code): 473 | reduce_f_scope = reduce_f.scope 474 | reduce_f = unicode(reduce_f) 475 | # reduce_f_code = self._sub_js_fields(reduce_f) 476 | reduce_f_code = reduce_f 477 | reduce_f = pymongo.code.Code(reduce_f_code, reduce_f_scope) 478 | 479 | mr_args = {'query': self._query, 'keeptemp': keep_temp} 480 | 481 | if finalize_f: 482 | finalize_f_scope = {} 483 | if isinstance(finalize_f, pymongo.code.Code): 484 | finalize_f_scope = finalize_f.scope 485 | finalize_f = unicode(finalize_f) 486 | # finalize_f_code = self._sub_js_fields(finalize_f) 487 | finalize_f_code = finalize_f 488 | finalize_f = pymongo.code.Code(finalize_f_code, finalize_f_scope) 489 | mr_args['finalize'] = finalize_f 490 | 491 | if scope: 492 | mr_args['scope'] = scope 493 | 494 | if limit: 495 | mr_args['limit'] = limit 496 | 497 | results = self._collection.map_reduce(map_f, reduce_f, **mr_args) 498 | results = results.find() 499 | 500 | if self._ordering: 501 | results = results.sort(self._ordering) 502 | 503 | for doc in results: 504 | yield self._document.objects.with_id(doc['value']) 505 | 506 | def limit(self, n): 507 | """Limit the number of returned documents to `n`. This may also be 508 | achieved using array-slicing syntax (e.g. ``User.objects[:5]``). 509 | 510 | :param n: the maximum number of objects to return 511 | """ 512 | if n == 0: 513 | self._cursor.limit(1) 514 | else: 515 | self._cursor.limit(n) 516 | self._limit = n 517 | 518 | # Return self to allow chaining 519 | return self 520 | 521 | def skip(self, n): 522 | """Skip `n` documents before returning the results. This may also be 523 | achieved using array-slicing syntax (e.g. ``User.objects[5:]``). 524 | 525 | :param n: the number of objects to skip before returning results 526 | """ 527 | self._cursor.skip(n) 528 | self._skip = n 529 | return self 530 | 531 | def __getitem__(self, key): 532 | """Support skip and limit using getitem and slicing syntax. 533 | """ 534 | # Slice provided 535 | if isinstance(key, slice): 536 | try: 537 | self._cursor_obj = self._cursor[key] 538 | self._skip, self._limit = key.start, key.stop 539 | except IndexError, err: 540 | # PyMongo raises an error if key.start == key.stop, catch it, 541 | # bin it, kill it. 542 | start = key.start or 0 543 | if start >= 0 and key.stop >= 0 and key.step is None: 544 | if start == key.stop: 545 | self.limit(0) 546 | self._skip, self._limit = key.start, key.stop - start 547 | return self 548 | raise err 549 | # Allow further QuerySet modifications to be performed 550 | return self 551 | # Integer index provided 552 | elif isinstance(key, int): 553 | return self._document(**dict_keys_to_str(self._cursor[key])) 554 | 555 | def only(self, *fields): 556 | """Load only a subset of this document's fields. :: 557 | 558 | post = BlogPost.objects(...).only("title") 559 | 560 | :param fields: fields to include 561 | 562 | .. versionadded:: 0.3 563 | """ 564 | self._loaded_fields = [] 565 | for field in fields: 566 | if '.' 
in field: 567 | raise InvalidQueryError('Subfields cannot be used as ' 568 | 'arguments to QuerySet.only') 569 | # Translate field name 570 | field = QuerySet._lookup_field(self._document, field)[-1].db_field 571 | self._loaded_fields.append(field) 572 | 573 | # _cls is needed for polymorphism 574 | if self._document._meta.get('allow_inheritance'): 575 | self._loaded_fields += ['_cls'] 576 | return self 577 | 578 | def order_by(self, *args): 579 | """Order the :class:`~mongoengine.queryset.QuerySet` by the keys. The 580 | order may be specified by prepending each of the keys by a + or a -. 581 | Ascending order is assumed. 582 | 583 | :param keys: fields to order the query results by; keys may be 584 | prefixed with **+** or **-** to determine the ordering direction 585 | """ 586 | 587 | self._ordering = [] 588 | for col in args: 589 | self._ordering.append(( (col.startswith("-") and col[1:]) or col, (col.startswith("-") and -1) or 1 )) 590 | 591 | self._cursor.sort(self._ordering) 592 | return self 593 | 594 | def explain(self, format=False): 595 | """Return an explain plan record for the 596 | :class:`~mongoengine.queryset.QuerySet`\ 's cursor. 597 | 598 | :param format: format the plan before returning it 599 | """ 600 | 601 | plan = self._cursor.explain() 602 | if format: 603 | import pprint 604 | plan = pprint.pformat(plan) 605 | return plan 606 | 607 | def delete(self, safe=False): 608 | """Delete the documents matched by the query. 609 | 610 | :param safe: check if the operation succeeded before returning 611 | """ 612 | self._collection.remove(self._query, safe=safe) 613 | 614 | @classmethod 615 | def _transform_update(cls, _doc_cls=None, **update): 616 | """Transform an update spec from Django-style format to Mongo format. 617 | """ 618 | operators = ['set', 'unset', 'inc', 'dec', 'push', 'push_all', 'pull', 619 | 'pull_all'] 620 | 621 | mongo_update = {} 622 | for key, value in update.items(): 623 | parts = key.split('__') 624 | # Check for an operator and transform to mongo-style if there is 625 | op = None 626 | if parts[0] in operators: 627 | op = parts.pop(0) 628 | # Convert Pythonic names to Mongo equivalents 629 | if op in ('push_all', 'pull_all'): 630 | op = op.replace('_all', 'All') 631 | elif op == 'dec': 632 | # Support decrement by flipping a positive value's sign 633 | # and using 'inc' 634 | op = 'inc' 635 | if value > 0: 636 | value = -value 637 | 638 | if _doc_cls: 639 | # Switch field names to proper names [set in Field(name='foo')] 640 | fields = QuerySet._lookup_field(_doc_cls, parts) 641 | parts = [field.db_field for field in fields] 642 | 643 | # Convert value to proper value 644 | field = fields[-1] 645 | if op in (None, 'set', 'unset', 'push', 'pull'): 646 | value = field.prepare_query_value(op, value) 647 | elif op in ('pushAll', 'pullAll'): 648 | value = [field.prepare_query_value(op, v) for v in value] 649 | 650 | key = '.'.join(parts) 651 | 652 | if op: 653 | value = {key: value} 654 | key = '$' + op 655 | 656 | if op is None or key not in mongo_update: 657 | mongo_update[key] = value 658 | elif key in mongo_update and isinstance(mongo_update[key], dict): 659 | mongo_update[key].update(value) 660 | 661 | return mongo_update 662 | 663 | def update(self, safe_update=True, upsert=False, **update): 664 | """Perform an atomic update on the fields matched by the query. 665 | 666 | :param safe: check if the operation succeeded before returning 667 | :param update: Django-style update keyword arguments 668 | 669 | .. 
versionadded:: 0.2 670 | """ 671 | if pymongo.version < '1.1.1': 672 | raise OperationError('update() method requires PyMongo 1.1.1+') 673 | 674 | update = QuerySet._transform_update(self._document, **update) 675 | try: 676 | self._collection.update(self._query, update, safe=safe_update, 677 | upsert=upsert, multi=True) 678 | except pymongo.errors.OperationFailure, err: 679 | if unicode(err) == u'multi not coded yet': 680 | message = u'update() method requires MongoDB 1.1.3+' 681 | raise OperationError(message) 682 | raise OperationError(u'Update failed (%s)' % unicode(err)) 683 | 684 | def update_one(self, safe_update=True, upsert=False, **update): 685 | """Perform an atomic update on first field matched by the query. 686 | 687 | :param safe: check if the operation succeeded before returning 688 | :param update: Django-style update keyword arguments 689 | 690 | .. versionadded:: 0.2 691 | """ 692 | update = QuerySet._transform_update(self._document, **update) 693 | try: 694 | # Explicitly provide 'multi=False' to newer versions of PyMongo 695 | # as the default may change to 'True' 696 | if pymongo.version >= '1.1.1': 697 | self._collection.update(self._query, update, safe=safe_update, 698 | upsert=upsert, multi=False) 699 | else: 700 | # Older versions of PyMongo don't support 'multi' 701 | self._collection.update(self._query, update, safe=safe_update) 702 | except pymongo.errors.OperationFailure, e: 703 | raise OperationError(u'Update failed [%s]' % unicode(e)) 704 | 705 | def __iter__(self, *args, **kwargs): 706 | for obj in self._cursor: 707 | data = dict_keys_to_str(obj) 708 | if '_id' in data: 709 | data['id']=data.pop('_id') 710 | yield self._document(**data) 711 | 712 | def _sub_js_fields(self, code): 713 | """When fields are specified with [~fieldname] syntax, where 714 | *fieldname* is the Python name of a field, *fieldname* will be 715 | substituted for the MongoDB name of the field (specified using the 716 | :attr:`name` keyword argument in a field's constructor). 717 | """ 718 | def field_sub(match): 719 | # Extract just the field name, and look up the field objects 720 | field_name = match.group(1).split('.') 721 | fields = QuerySet._lookup_field(self._document, field_name) 722 | # Substitute the correct name for the field into the javascript 723 | return u'["%s"]' % fields[-1].db_field 724 | 725 | return re.sub(u'\[\s*~([A-z_][A-z_0-9.]+?)\s*\]', field_sub, code) 726 | 727 | def exec_js(self, code, *fields, **options): 728 | """ 729 | Execute a Javascript function on the server. A list of fields may be 730 | provided, which will be translated to their correct names and supplied 731 | as the arguments to the function. A few extra variables are added to 732 | the function's scope: ``collection``, which is the name of the 733 | collection in use; ``query``, which is an object representing the 734 | current query; and ``options``, which is an object containing any 735 | options specified as keyword arguments. 736 | 737 | As fields in MongoEngine may use different names in the database (set 738 | using the :attr:`db_field` keyword argument to a :class:`Field` 739 | constructor), a mechanism exists for replacing MongoEngine field names 740 | with the database field names in Javascript code. When accessing a 741 | field, use square-bracket notation, and prefix the MongoEngine field 742 | name with a tilde (~). 
743 | 744 | :param code: a string of Javascript code to execute 745 | :param fields: fields that you will be using in your function, which 746 | will be passed in to your function as arguments 747 | :param options: options that you want available to the function 748 | (accessed in Javascript through the ``options`` object) 749 | """ 750 | # code = self._sub_js_fields(code) 751 | 752 | fields = [QuerySet._translate_field_name(self._document, f) for f in fields] 753 | collection = self._collection 754 | 755 | scope = { 756 | 'collection': collection.name, 757 | 'options': options or {}, 758 | } 759 | 760 | query = self._query 761 | if self._where_clause: 762 | query['$where'] = self._where_clause 763 | 764 | scope['query'] = query 765 | code = pymongo.code.Code(code, scope=scope) 766 | 767 | return collection.database.eval(code, *fields) 768 | 769 | def sum(self, field): 770 | """Sum over the values of the specified field. 771 | 772 | :param field: the field to sum over; use dot-notation to refer to 773 | embedded document fields 774 | """ 775 | sum_func = """ 776 | function(sumField) { 777 | var total = 0.0; 778 | db[collection].find(query).forEach(function(doc) { 779 | total += (doc[sumField] || 0.0); 780 | }); 781 | return total; 782 | } 783 | """ 784 | return self.exec_js(sum_func, field) 785 | 786 | def average(self, field): 787 | """Average over the values of the specified field. 788 | 789 | :param field: the field to average over; use dot-notation to refer to 790 | embedded document fields 791 | """ 792 | average_func = """ 793 | function(averageField) { 794 | var total = 0.0; 795 | var num = 0; 796 | db[collection].find(query).forEach(function(doc) { 797 | if (doc[averageField]) { 798 | total += doc[averageField]; 799 | num += 1; 800 | } 801 | }); 802 | return total / num; 803 | } 804 | """ 805 | return self.exec_js(average_func, field) 806 | 807 | def item_frequencies(self, list_field, normalize=False): 808 | """Returns a dictionary of all items present in a list field across 809 | the whole queried set of documents, and their corresponding frequency. 810 | This is useful for generating tag clouds, or searching documents. 811 | 812 | :param list_field: the list field to use 813 | :param normalize: normalize the results so they add to 1.0 814 | """ 815 | freq_func = """ 816 | function(listField) { 817 | if (options.normalize) { 818 | var total = 0.0; 819 | db[collection].find(query).forEach(function(doc) { 820 | total += doc[listField].length; 821 | }); 822 | } 823 | 824 | var frequencies = {}; 825 | var inc = 1.0; 826 | if (options.normalize) { 827 | inc /= total; 828 | } 829 | db[collection].find(query).forEach(function(doc) { 830 | doc[listField].forEach(function(item) { 831 | frequencies[item] = inc + (frequencies[item] || 0); 832 | }); 833 | }); 834 | return frequencies; 835 | } 836 | """ 837 | return self.exec_js(freq_func, list_field, normalize=normalize) 838 | 839 | def __repr__(self): 840 | limit = REPR_OUTPUT_SIZE + 1 841 | if self._limit is not None and self._limit < limit: 842 | limit = self._limit 843 | data = list(self[self._skip:limit]) 844 | if len(data) > REPR_OUTPUT_SIZE: 845 | data[-1] = "...(remaining elements truncated)..." 
846 | return repr(data) 847 | 848 | def _clone(self): 849 | return self 850 | 851 | 852 | class Manager(DJManager): 853 | 854 | def __init__(self, manager_func=None): 855 | super(Manager, self).__init__() 856 | self._manager_func = manager_func 857 | self._collection = None 858 | 859 | def contribute_to_class(self, model, name): 860 | # TODO: Use weakref because of possible memory leak / circular reference. 861 | self.model = model 862 | # setattr(model, name, ManagerDescriptor(self)) 863 | if model._meta.abstract or (self._inherited and not self.model._meta.proxy): 864 | model._meta.abstract_managers.append((self.creation_counter, name, 865 | self)) 866 | else: 867 | model._meta.concrete_managers.append((self.creation_counter, name, 868 | self)) 869 | 870 | def __get__(self, instance, owner): 871 | """Descriptor for instantiating a new QuerySet object when 872 | Document.objects is accessed. 873 | """ 874 | self.model = owner #We need to set the model to get the db 875 | 876 | if instance is not None: 877 | # Document class being used rather than a document object 878 | return self 879 | 880 | if self._collection is None: 881 | self._collection = connections[self.db].db_connection[owner._meta.db_table] 882 | 883 | # owner is the document that contains the QuerySetManager 884 | queryset = QuerySet(owner, self._collection) 885 | if self._manager_func: 886 | if self._manager_func.func_code.co_argcount == 1: 887 | queryset = self._manager_func(queryset) 888 | else: 889 | queryset = self._manager_func(owner, queryset) 890 | return queryset 891 | 892 | 893 | def queryset_manager(func): 894 | """Decorator that allows you to define custom QuerySet managers on 895 | :class:`~mongoengine.Document` classes. The manager must be a function that 896 | accepts a :class:`~mongoengine.Document` class as its first argument, and a 897 | :class:`~mongoengine.queryset.QuerySet` as its second argument. The method 898 | function should return a :class:`~mongoengine.queryset.QuerySet`, probably 899 | the same one that was passed in, but modified in some way. 
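For example, a manager that narrows the default queryset might be declared on the document class like this (a sketch; ``published`` is a hypothetical field)::

    @queryset_manager
    def live_objects(doc_cls, queryset):
        return queryset.filter(published=True)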
900 | """ 901 | if func.func_code.co_argcount == 1: 902 | import warnings 903 | msg = 'Methods decorated with queryset_manager should take 2 arguments' 904 | warnings.warn(msg, DeprecationWarning) 905 | return QuerySetManager(func) 906 | -------------------------------------------------------------------------------- /django_elasticsearch/mapping.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from pyes import mappings 5 | from django.conf import settings 6 | import time 7 | from django.db.models.manager import Manager 8 | 9 | def model_to_mapping(model, depth=1): 10 | """ 11 | Given a model return a mapping 12 | """ 13 | meta = model._meta 14 | indexoptions = getattr(model, "indexeroptions", {}) 15 | ignore = indexoptions.get("ignore", []) 16 | fields_options = indexoptions.get("fields", {}) 17 | extra_fields = indexoptions.get("extra_fields", {}) 18 | mapper = mappings.ObjectField(meta.module_name) 19 | for field in meta.fields + meta.many_to_many: 20 | name = field.name 21 | if name in ignore: 22 | continue 23 | mapdata = get_mapping_for_field(field, depth=depth, **fields_options.get(name, {})) 24 | if mapdata: 25 | mapper.add_property(mapdata) 26 | for name, options in extra_fields.items(): 27 | type = options.pop("type", "string") 28 | if type == "string": 29 | data = dict(name=name, store=True, 30 | index="analyzed", 31 | term_vector="with_positions_offsets" 32 | ) 33 | data.update(options) 34 | 35 | if data['index'] == 'not_analyzed': 36 | del data['term_vector'] 37 | 38 | mapper.add_property(mappings.StringField(**data)) 39 | continue 40 | 41 | return mapper 42 | 43 | def get_mapping_for_field(field, depth=1, **options): 44 | """Given a field returns a mapping""" 45 | ntype = type(field).__name__ 46 | if ntype in ["AutoField"]: 47 | # return mappings.MultiField(name=field.name, 48 | # fields={field.name:mappings.StringField(name=field.name, store=True), 49 | # "int":mappings.IntegerField(name="int", store=True)} 50 | # ) 51 | return mappings.StringField(name=field.name, store=True) 52 | elif ntype in ["IntegerField", 53 | "PositiveSmallIntegerField", 54 | "SmallIntegerField", 55 | "PositiveIntegerField", 56 | "PositionField", 57 | ]: 58 | return mappings.IntegerField(name=field.name, store=True) 59 | elif ntype in ["FloatField", 60 | "DecimalField", 61 | ]: 62 | return mappings.DoubleField(name=field.name, store=True) 63 | elif ntype in ["BooleanField", 64 | "NullBooleanField", 65 | ]: 66 | return mappings.BooleanField(name=field.name, store=True) 67 | elif ntype in ["DateField", 68 | "DateTimeField", 69 | "CreationDateTimeField", 70 | "ModificationDateTimeField", 71 | "AddedDateTimeField", 72 | "ModifiedDateTimeField", 73 | "brainaetic.djangoutils.db.fields.CreationDateTimeField", 74 | "brainaetic.djangoutils.db.fields.ModificationDateTimeField", 75 | ]: 76 | return mappings.DateField(name=field.name, store=True) 77 | elif ntype in ["SlugField", 78 | "EmailField", 79 | "TagField", 80 | "URLField", 81 | "CharField", 82 | "ImageField", 83 | "FileField", 84 | ]: 85 | return mappings.MultiField(name=field.name, 86 | fields={field.name:mappings.StringField(name=field.name, index="not_analyzed", store=True), 87 | "tk":mappings.StringField(name="tk", store=True, 88 | index="analyzed", 89 | term_vector="with_positions_offsets")} 90 | 91 | ) 92 | elif ntype in ["TextField", 93 | ]: 94 | data = dict(name=field.name, store=True, 95 | index="analyzed", 96 | term_vector="with_positions_offsets" 
97 | ) 98 | if field.unique: 99 | data['index'] = 'not_analyzed' 100 | 101 | data.update(options) 102 | 103 | if data['index'] == 'not_analyzed': 104 | del data['term_vector'] 105 | 106 | return mappings.StringField(**data) 107 | elif ntype in ["ForeignKey", 108 | "TaggableManager", 109 | "GenericRelation", 110 | ]: 111 | if depth >= 0: 112 | mapper = model_to_mapping(field.rel.to, depth - 1) 113 | if mapper: 114 | mapper.name = field.name 115 | return mapper 116 | return None 117 | return get_mapping_for_field(field.rel.to._meta.pk, depth - 1) 118 | 119 | # "IPAddressField", 120 | # 'PickledObjectField' 121 | 122 | elif ntype in ["ManyToManyField", 123 | ]: 124 | if depth > 0: 125 | mapper = model_to_mapping(field.rel.to, depth - 1) 126 | mapper.name = field.name 127 | return mapper 128 | if depth == 0: 129 | mapper = get_mapping_for_field(field.rel.to._meta.pk, depth - 1) 130 | if mapper: 131 | mapper.name = field.name 132 | return mapper 133 | return None 134 | if depth < 0: 135 | return None 136 | print ntype 137 | return None 138 | 139 | -------------------------------------------------------------------------------- /django_elasticsearch/models.py: -------------------------------------------------------------------------------- 1 | from django.db.models import signals 2 | from .fields import add_elasticsearch_manager 3 | 4 | signals.class_prepared.connect(add_elasticsearch_manager) -------------------------------------------------------------------------------- /django_elasticsearch/router.py: -------------------------------------------------------------------------------- 1 | 2 | class ESRouter(object): 3 | """A router to control all database operations on models in 4 | the myapp application""" 5 | def __init__(self): 6 | from django.conf import settings 7 | self.managed_apps = [app.split('.')[-1] for app in getattr(settings, "ELASTICSEARCH_MANAGED_APPS", [])] 8 | self.managed_models = getattr(settings, "ELASTICSEARCH_MANAGED_MODELS", []) 9 | self.elasticsearch_database = None 10 | self.elasticsearch_databases = [] 11 | for name, databaseopt in settings.DATABASES.items(): 12 | if databaseopt["ENGINE"]=='django_elasticsearch': 13 | self.elasticsearch_database = name 14 | self.elasticsearch_databases.append(name) 15 | if self.elasticsearch_database is None: 16 | raise RuntimeError("A elasticsearch database must be set") 17 | 18 | def db_for_read(self, model, **hints): 19 | "Point all operations on elasticsearch models to a elasticsearch database" 20 | if model._meta.app_label in self.managed_apps: 21 | return self.elasticsearch_database 22 | key = "%s.%s"%(model._meta.app_label, model._meta.module_name) 23 | if key in self.managed_models: 24 | return self.elasticsearch_database 25 | return None 26 | 27 | def db_for_write(self, model, **hints): 28 | "Point all operations on elasticsearch models to a elasticsearch database" 29 | if model._meta.app_label in self.managed_apps: 30 | return self.elasticsearch_database 31 | key = "%s.%s"%(model._meta.app_label, model._meta.module_name) 32 | if key in self.managed_models: 33 | return self.elasticsearch_database 34 | return None 35 | 36 | def allow_relation(self, obj1, obj2, **hints): 37 | "Allow any relation if a model in myapp is involved" 38 | 39 | #key1 = "%s.%s"%(obj1._meta.app_label, obj1._meta.module_name) 40 | key2 = "%s.%s"%(obj2._meta.app_label, obj2._meta.module_name) 41 | 42 | # obj2 is the model instance so, mongo_serializer should take care 43 | # of the related object. 
We keep trac of the obj1 db so, don't worry 44 | # about the multi-database management 45 | if obj2._meta.app_label in self.managed_apps or key2 in self.managed_models: 46 | return True 47 | 48 | return None 49 | 50 | def allow_syncdb(self, db, model): 51 | "Make sure that a elasticsearch model appears on a elasticsearch database" 52 | key = "%s.%s"%(model._meta.app_label, model._meta.module_name) 53 | if db in self.elasticsearch_databases: 54 | return model._meta.app_label in self.managed_apps or key in self.managed_models 55 | elif model._meta.app_label in self.managed_apps or key in self.managed_models: 56 | if db in self.elasticsearch_databases: 57 | return True 58 | else: 59 | return False 60 | return None 61 | 62 | def valid_for_db_engine(self, driver, model): 63 | "Make sure that a model is valid for a database provider" 64 | if driver!="elasticsearch": 65 | return None 66 | if model._meta.app_label in self.managed_apps: 67 | return True 68 | key = "%s.%s"%(model._meta.app_label, model._meta.module_name) 69 | if key in self.managed_models: 70 | return True 71 | return None 72 | 73 | -------------------------------------------------------------------------------- /django_elasticsearch/serializer.py: -------------------------------------------------------------------------------- 1 | from django.utils.importlib import import_module 2 | from datetime import datetime, date, time 3 | #TODO Add content type cache 4 | from utils import ModelLazyObject 5 | from json import JSONDecoder, JSONEncoder 6 | import uuid 7 | 8 | class Decoder(JSONDecoder): 9 | """Extends the base simplejson JSONDecoder for Dejavu.""" 10 | def __init__(self, arena=None, encoding=None, object_hook=None, **kwargs): 11 | JSONDecoder.__init__(self, encoding, object_hook, **kwargs) 12 | if not self.object_hook: 13 | self.object_hook = self.json_to_python 14 | self.arena = arena 15 | 16 | def json_to_python(self, son): 17 | 18 | if isinstance(son, dict): 19 | if "_type" in son and son["_type"] in [u"django", u'emb']: 20 | son = self.decode_django(son) 21 | else: 22 | for (key, value) in son.items(): 23 | if isinstance(value, dict): 24 | if "_type" in value and value["_type"] in [u"django", u'emb']: 25 | son[key] = self.decode_django(value) 26 | else: 27 | son[key] = self.json_to_python(value) 28 | elif hasattr(value, "__iter__"): # Make sure we recurse into sub-docs 29 | son[key] = [self.json_to_python(item) for item in value] 30 | else: # Again, make sure to recurse into sub-docs 31 | son[key] = self.json_to_python(value) 32 | elif hasattr(son, "__iter__"): # Make sure we recurse into sub-docs 33 | son = [self.json_to_python(item) for item in son] 34 | return son 35 | 36 | def decode_django(self, data): 37 | from django.contrib.contenttypes.models import ContentType 38 | if data['_type']=="django": 39 | model = ContentType.objects.get(app_label=data['_app'], model=data['_model']) 40 | return ModelLazyObject(model.model_class(), data['pk']) 41 | elif data['_type']=="emb": 42 | try: 43 | model = ContentType.objects.get(app_label=data['_app'], model=data['_model']).model_class() 44 | except: 45 | module = import_module(data['_app']) 46 | model = getattr(module, data['_model']) 47 | 48 | del data['_type'] 49 | del data['_app'] 50 | del data['_model'] 51 | data.pop('_id', None) 52 | values = {} 53 | for k,v in data.items(): 54 | values[str(k)] = self.json_to_python(v) 55 | return model(**values) 56 | 57 | class Encoder(JSONEncoder): 58 | def __init__(self, *args, **kwargs): 59 | JSONEncoder.__init__(self, *args, **kwargs) 60 
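A sketch of the round trip this Decoder/Encoder pair implements for references to saved Django models (``myapp`` and ``Blog`` are the test-project names used later in this repository): the encoder emits a small type-tagged dict, and the decoder turns it back into a lazy reference that is resolved only on first attribute access.

    # emitted by Encoder.encode_django() for a saved Django model instance
    {'_type': 'django', '_app': 'myapp', '_model': 'blog', 'pk': 1}

    # rebuilt by Decoder.json_to_python() from that dict
    ModelLazyObject(Blog, 1)   # loads via Blog.objects.get(pk=1) when first touched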
| 61 | 62 | def encode_django(self, model): 63 | """ 64 | Encode ricorsive embedded models and django models 65 | """ 66 | from django_elasticsearch.fields import EmbeddedModel 67 | if isinstance(model, EmbeddedModel): 68 | if model.pk is None: 69 | model.pk = str(uuid.uuid4()) 70 | res = {'_app':model._meta.app_label, 71 | '_model':model._meta.module_name, 72 | '_id':model.pk} 73 | for field in model._meta.fields: 74 | res[field.attname] = self.default(getattr(model, field.attname)) 75 | res["_type"] = "emb" 76 | from django.contrib.contenttypes.models import ContentType 77 | try: 78 | ContentType.objects.get(app_label=res['_app'], model=res['_model']) 79 | except: 80 | res['_app'] = model.__class__.__module__ 81 | res['_model'] = model._meta.object_name 82 | 83 | return res 84 | if not model.pk: 85 | model.save() 86 | return {'_app':model._meta.app_label, 87 | '_model':model._meta.module_name, 88 | 'pk':model.pk, 89 | '_type':"django"} 90 | 91 | def default(self, value): 92 | """Convert rogue and mysterious data types. 93 | Conversion notes: 94 | 95 | - ``datetime.date`` and ``datetime.datetime`` objects are 96 | converted into datetime strings. 97 | """ 98 | from django.db.models import Model 99 | from django_elasticsearch.fields import EmbeddedModel 100 | 101 | if isinstance(value, datetime): 102 | return value.strftime("%Y-%m-%dT%H:%M:%S") 103 | elif isinstance(value, date): 104 | dt = datetime(value.year, value.month, value.day, 0, 0, 0) 105 | return dt.strftime("%Y-%m-%dT%H:%M:%S") 106 | # elif isinstance(value, dict): 107 | # for (key, value) in value.items(): 108 | # if isinstance(value, (str, unicode)): 109 | # continue 110 | # if isinstance(value, (Model, EmbeddedModel)): 111 | # value[key] = self.encode_django(value, collection) 112 | # elif isinstance(value, dict): # Make sure we recurse into sub-docs 113 | # value[key] = self.transform_incoming(value) 114 | # elif hasattr(value, "__iter__"): # Make sure we recurse into sub-docs 115 | # value[key] = [self.transform_incoming(item) for item in value] 116 | elif isinstance(value, (str, unicode)): 117 | pass 118 | elif hasattr(value, "__iter__"): # Make sure we recurse into sub-docs 119 | value = [self.transform_incoming(item, collection) for item in value] 120 | elif isinstance(value, (Model, EmbeddedModel)): 121 | value = self.encode_django(value) 122 | return value 123 | -------------------------------------------------------------------------------- /django_elasticsearch/south.py: -------------------------------------------------------------------------------- 1 | class DatabaseOperations(object): 2 | """ 3 | ES implementation of database operations. 
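All of the schema-migration hooks below are no-ops. A sketch of pointing South at this module, mirroring the commented-out line in the test project's settings::

    SOUTH_DATABASE_ADAPTERS = {"default": "django_elasticsearch.south"}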
4 | """ 5 | 6 | backend_name = "django.db.backends.elasticsearch" 7 | 8 | supports_foreign_keys = False 9 | has_check_constraints = False 10 | 11 | def __init__(self, db_alias): 12 | pass 13 | 14 | def add_column(self, table_name, name, field, *args, **kwds): 15 | pass 16 | 17 | def alter_column(self, table_name, name, field, explicit_name=True): 18 | pass 19 | 20 | def delete_column(self, table_name, column_name): 21 | pass 22 | 23 | def rename_column(self, table_name, old, new): 24 | pass 25 | 26 | def create_unique(self, table_name, columns): 27 | pass 28 | 29 | def delete_unique(self, table_name, columns): 30 | pass 31 | 32 | def delete_primary_key(self, table_name): 33 | pass 34 | 35 | def delete_table(self, table_name, cascade=True): 36 | pass 37 | 38 | def connection_init(self): 39 | pass -------------------------------------------------------------------------------- /django_elasticsearch/utils.py: -------------------------------------------------------------------------------- 1 | from django.utils.functional import SimpleLazyObject 2 | 3 | def dict_keys_to_str(dictionary, recursive=False): 4 | res = dict([(str(k), (not isinstance(v, dict) and v) or (recursive and dict_keys_to_str(v)) or v) for k,v in dictionary.items()]) 5 | if '_id' in res: 6 | res["id"] = res.pop("_id") 7 | return res 8 | 9 | class ModelLazyObject(SimpleLazyObject): 10 | """ 11 | A lazy object initialised a model. 12 | """ 13 | def __init__(self, model, pk): 14 | """ 15 | Pass in a callable that returns the object to be wrapped. 16 | 17 | If copies are made of the resulting SimpleLazyObject, which can happen 18 | in various circumstances within Django, then you must ensure that the 19 | callable can be safely run more than once and will return the same 20 | value. 21 | """ 22 | # For some reason, we have to inline LazyObject.__init__ here to avoid 23 | # recursion 24 | self._wrapped = None 25 | self.__dict__['_pk'] = pk 26 | self.__dict__['_model'] = model 27 | super(ModelLazyObject, self).__init__(self._load_data) 28 | 29 | def _load_data(self): 30 | return self._model.objects.get(pk=self._pk) 31 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | import os 3 | 4 | DESCRIPTION = "A ES backend standing outside django (>= 1.2)" 5 | 6 | LONG_DESCRIPTION = None 7 | try: 8 | LONG_DESCRIPTION = open('README.rst').read() 9 | except: 10 | pass 11 | 12 | def get_version(version_tuple): 13 | version = '%s.%s' % (version_tuple[0], version_tuple[1]) 14 | if version_tuple[2]: 15 | version = '%s.%s' % (version, version_tuple[2]) 16 | return version 17 | 18 | init = os.path.join(os.path.dirname(__file__), 'django_elasticsearch', '__init__.py') 19 | print init 20 | version_line = filter(lambda l: l.startswith('VERSION'), open(init))[0] 21 | VERSION = get_version(eval(version_line.split('=')[-1])) 22 | print VERSION 23 | 24 | CLASSIFIERS = [ 25 | 'Development Status :: 4 - Beta', 26 | 'Intended Audience :: Developers', 27 | 'License :: OSI Approved :: MIT License', 28 | 'Operating System :: OS Independent', 29 | 'Programming Language :: Python', 30 | 'Topic :: Database', 31 | 'Topic :: Software Development :: Libraries :: Python Modules', 32 | ] 33 | 34 | setup(name='django_elasticsearch', 35 | version=VERSION, 36 | packages=find_packages(), 37 | author='Alberto Paro', 38 | author_email='alberto@{nospam}ingparo.it', 39 | 
url='http://github.com/aparo/django-elasticsearch', 40 | license='MIT', 41 | include_package_data=True, 42 | description=DESCRIPTION, 43 | long_description=LONG_DESCRIPTION, 44 | platforms=['any'], 45 | classifiers=CLASSIFIERS, 46 | install_requires=['pyes'], 47 | test_suite='tests', 48 | ) 49 | -------------------------------------------------------------------------------- /tests/testproj/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aparo/django-elasticsearch/8fd25bd86b58cfc0d6490cfac08e4846ab4ddf97/tests/testproj/__init__.py -------------------------------------------------------------------------------- /tests/testproj/manage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os, sys 3 | from django.core.management import execute_manager 4 | # dirty hack to get the backend working. 5 | #sys.path.insert(0, os.path.abspath('./..')) 6 | #sys.path.insert(0, os.path.abspath('./../..')) 7 | #example_dir = os.path.dirname(__file__) 8 | #sys.path.insert(0, os.path.join(example_dir, '..')) 9 | try: 10 | import settings # Assumed to be in the same directory. 11 | except ImportError: 12 | import sys 13 | sys.stderr.write("Error: Can't find the file 'settings.py' in the directory containing %r. It appears you've customized things.\nYou'll have to run django-admin.py, passing it your settings module.\n(If the file settings.py does indeed exist, it's causing an ImportError somehow.)\n" % __file__) 14 | sys.exit(1) 15 | 16 | if __name__ == "__main__": 17 | execute_manager(settings) 18 | -------------------------------------------------------------------------------- /tests/testproj/mixed/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aparo/django-elasticsearch/8fd25bd86b58cfc0d6490cfac08e4846ab4ddf97/tests/testproj/mixed/__init__.py -------------------------------------------------------------------------------- /tests/testproj/mixed/models.py: -------------------------------------------------------------------------------- 1 | from django.db import models 2 | 3 | class Post(models.Model): 4 | title = models.CharField(max_length=200, db_index=True) 5 | 6 | def __unicode__(self): 7 | return "Post: %s" % self.title 8 | 9 | class Record(models.Model): 10 | title = models.CharField(max_length=200, db_index=True) 11 | 12 | def __unicode__(self): 13 | return "Record: %s" % self.title 14 | -------------------------------------------------------------------------------- /tests/testproj/mixed/tests.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file demonstrates two different styles of tests (one doctest and one 3 | unittest). These will both pass when you run "manage.py test". 4 | 5 | Replace these with more appropriate tests for your application. 6 | """ 7 | 8 | from django.test import TestCase 9 | 10 | class SimpleTest(TestCase): 11 | def test_basic_addition(self): 12 | """ 13 | Tests that 1 + 1 always equals 2. 14 | """ 15 | self.failUnlessEqual(1 + 1, 2) 16 | 17 | __test__ = {"doctest": """ 18 | Another way to test that 1 + 1 is equal to 2. 19 | 20 | >>> 1 + 1 == 2 21 | True 22 | """} 23 | 24 | -------------------------------------------------------------------------------- /tests/testproj/mixed/views.py: -------------------------------------------------------------------------------- 1 | # Create your views here. 
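Of the two models in this ``mixed`` app, only ``Record`` is listed in ``ELASTICSEARCH_MANAGED_MODELS`` in the test settings further below. A sketch of how the ``ESRouter`` splits traffic in a mixed setup (assuming the router is enabled and ``default`` is a relational database; the shipped test settings instead use the elasticsearch engine as ``default``):

    # matched by ESRouter via the 'mixed.record' key, so routed to the
    # elasticsearch database alias
    Record(title="indexed record").save()
    Record.objects.filter(title="indexed record")

    # not matched by the router, so left on the relational 'default' database
    Post(title="relational post").save()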
2 | -------------------------------------------------------------------------------- /tests/testproj/myapp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aparo/django-elasticsearch/8fd25bd86b58cfc0d6490cfac08e4846ab4ddf97/tests/testproj/myapp/__init__.py -------------------------------------------------------------------------------- /tests/testproj/myapp/models.py: -------------------------------------------------------------------------------- 1 | from django.db import models 2 | from django.utils.translation import ugettext_lazy as _ 3 | from django_elasticsearch.fields import EmbeddedModel 4 | from djangotoolbox.fields import ListField, DictField 5 | 6 | class Blog(models.Model): 7 | title = models.CharField(max_length=200, db_index=True) 8 | 9 | def __unicode__(self): 10 | return "Blog: %s" % self.title 11 | 12 | class Entry(models.Model): 13 | title = models.CharField(max_length=200, db_index=True, unique=True) 14 | content = models.CharField(max_length=1000) 15 | date_published = models.DateTimeField(null=True, blank=True) 16 | blog = models.ForeignKey(Blog, null=True, blank=True) 17 | 18 | class MongoMeta: 19 | descending_indexes = ['title'] 20 | 21 | def __unicode__(self): 22 | return "Entry: %s" % (self.title) 23 | 24 | class Person(models.Model): 25 | name = models.CharField(max_length=20) 26 | surname = models.CharField(max_length=20) 27 | age = models.IntegerField(null=True, blank=True) 28 | 29 | def __unicode__(self): 30 | return u"Person: %s %s" % (self.name, self.surname) 31 | 32 | class StandardAutoFieldModel(models.Model): 33 | title = models.CharField(max_length=200) 34 | 35 | def __unicode__(self): 36 | return "Standard model: %s" % (self.title) 37 | 38 | class EModel(EmbeddedModel): 39 | title = models.CharField(max_length=200) 40 | pos = models.IntegerField(default = 10) 41 | 42 | def test_func(self): 43 | return self.pos 44 | 45 | class TestFieldModel(models.Model): 46 | title = models.CharField(max_length=200) 47 | mlist = ListField() 48 | mlist_default = ListField(default=["a", "b"]) 49 | mdict = DictField() 50 | mdict_default = DictField(default={"a": "a", 'b':1}) 51 | 52 | class MongoMeta: 53 | index_together = [{ 54 | 'fields' : [ ('title', False), 'mlist'] 55 | }] 56 | def __unicode__(self): 57 | return "Test special field model: %s" % (self.title) 58 | -------------------------------------------------------------------------------- /tests/testproj/myapp/tests.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test suite for django-elasticsearch. 
3 | """ 4 | 5 | from django.test import TestCase 6 | from testproj.myapp.models import Entry, Blog, StandardAutoFieldModel, Person, TestFieldModel, EModel 7 | import datetime 8 | import time 9 | 10 | class DjangoESTest(TestCase): 11 | # multi_db = True 12 | 13 | # def test_add_and_delete_blog(self): 14 | # blog1 = Blog(title="blog1") 15 | # blog1.save() 16 | # self.assertEqual(Blog.objects.count(), 1) 17 | # blog2 = Blog(title="blog2") 18 | # self.assertEqual(blog2.pk, None) 19 | # blog2.save() 20 | # self.assertNotEqual(blog2.pk, None) 21 | # self.assertEqual(Blog.objects.count(), 2) 22 | # blog2.delete() 23 | # self.assertEqual(Blog.objects.count(), 1) 24 | # blog1.delete() 25 | # self.assertEqual(Blog.objects.count(), 0) 26 | 27 | def test_simple_get(self): 28 | blog1 = Blog(title="blog1") 29 | blog1.save() 30 | blog2 = Blog(title="blog2") 31 | blog2.save() 32 | self.assertEqual(Blog.objects.count(), 2) 33 | self.assertEqual( 34 | Blog.objects.get(title="blog2"), 35 | blog2 36 | ) 37 | self.assertEqual( 38 | Blog.objects.get(title="blog1"), 39 | blog1 40 | ) 41 | 42 | def test_simple_filter(self): 43 | blog1 = Blog(title="same title") 44 | blog1.save() 45 | blog2 = Blog(title="same title") 46 | blog2.save() 47 | blog3 = Blog(title="another title") 48 | blog3.save() 49 | self.assertEqual(Blog.objects.count(), 3) 50 | blog4 = Blog.objects.get(pk=blog1.pk) 51 | self.assertEqual(blog4, blog1) 52 | self.assertEqual( 53 | Blog.objects.filter(title="same title").count(), 54 | 2 55 | ) 56 | self.assertEqual( 57 | Blog.objects.filter(title="same title", pk=blog1.pk).count(), 58 | 1 59 | ) 60 | self.assertEqual( 61 | Blog.objects.filter(title__startswith="same").count(), 62 | 2 63 | ) 64 | self.assertEqual( 65 | Blog.objects.filter(title__istartswith="SAME").count(), 66 | 2 67 | ) 68 | self.assertEqual( 69 | Blog.objects.filter(title__endswith="title").count(), 70 | 3 71 | ) 72 | self.assertEqual( 73 | Blog.objects.filter(title__iendswith="Title").count(), 74 | 3 75 | ) 76 | self.assertEqual( 77 | Blog.objects.filter(title__icontains="same").count(), 78 | 2 79 | ) 80 | self.assertEqual( 81 | Blog.objects.filter(title__contains="same").count(), 82 | 2 83 | ) 84 | self.assertEqual( 85 | Blog.objects.filter(title__iexact="same Title").count(), 86 | 2 87 | ) 88 | 89 | self.assertEqual( 90 | Blog.objects.filter(title__regex="s.me.*").count(), 91 | 2 92 | ) 93 | self.assertEqual( 94 | Blog.objects.filter(title__iregex="S.me.*").count(), 95 | 2 96 | ) 97 | 98 | def test_change_model(self): 99 | blog1 = Blog(title="blog 1") 100 | blog1.save() 101 | self.assertEqual(Blog.objects.count(), 1) 102 | blog1.title = "new title" 103 | blog1.save() 104 | self.assertEqual(Blog.objects.count(), 1) 105 | bl = Blog.objects.all()[0] 106 | self.assertEqual(blog1.title, bl.title) 107 | bl.delete() 108 | 109 | # def test_dates_ordering(self): 110 | # now = datetime.datetime.now() 111 | # before = now - datetime.timedelta(days=1) 112 | # 113 | # entry1 = Entry(title="entry 1", date_published=now) 114 | # entry1.save() 115 | # 116 | # entry2 = Entry(title="entry 2", date_published=before) 117 | # entry2.save() 118 | # 119 | # self.assertEqual( 120 | # list(Entry.objects.order_by('-date_published')), 121 | # [entry1, entry2] 122 | # ) 123 | # 124 | ## self.assertEqual( 125 | ## list(Entry.objects.order_by('date_published')), 126 | ## [entry2, entry1] 127 | ## ) 128 | # 129 | # 130 | ## def test_dates_less_and_more_than(self): 131 | ## now = datetime.datetime.now() 132 | ## before = now + datetime.timedelta(days=1) 133 | ## 
after = now - datetime.timedelta(days=1) 134 | ## 135 | ## entry1 = Entry(title="entry 1", date_published=now) 136 | ## entry1.save() 137 | ## 138 | ## entry2 = Entry(title="entry 2", date_published=before) 139 | ## entry2.save() 140 | ## 141 | ## entry3 = Entry(title="entry 3", date_published=after) 142 | ## entry3.save() 143 | ## 144 | ## a = list(Entry.objects.filter(date_published=now)) 145 | ## self.assertEqual( 146 | ## list(Entry.objects.filter(date_published=now)), 147 | ## [entry1] 148 | ## ) 149 | ## self.assertEqual( 150 | ## list(Entry.objects.filter(date_published__lt=now)), 151 | ## [entry3] 152 | ## ) 153 | ## self.assertEqual( 154 | ## list(Entry.objects.filter(date_published__gt=now)), 155 | ## [entry2] 156 | ## ) 157 | # 158 | # def test_complex_queries(self): 159 | # p1 = Person(name="igor", surname="duck", age=39) 160 | # p1.save() 161 | # p2 = Person(name="andrea", surname="duck", age=29) 162 | # p2.save() 163 | # self.assertEqual( 164 | # Person.objects.filter(name="igor", surname="duck").count(), 165 | # 1 166 | # ) 167 | # self.assertEqual( 168 | # Person.objects.filter(age__gte=20, surname="duck").count(), 169 | # 2 170 | # ) 171 | # 172 | # def test_fields(self): 173 | # t1 = TestFieldModel(title="p1", 174 | # mlist=["ab", "bc"], 175 | # mdict = {'a':23, "b":True }, 176 | # ) 177 | # t1.save() 178 | # 179 | # t = TestFieldModel.objects.get(id=t1.id) 180 | # self.assertEqual(t.mlist, ["ab", "bc"]) 181 | # self.assertEqual(t.mlist_default, ["a", "b"]) 182 | # self.assertEqual(t.mdict, {'a':23, "b":True }) 183 | # self.assertEqual(t.mdict_default, {"a": "a", 'b':1}) 184 | # 185 | # 186 | # def test_embedded_model(self): 187 | # em = EModel(title="1", pos = 1) 188 | # em2 = EModel(title="2", pos = 2) 189 | # t1 = TestFieldModel(title="p1", 190 | # mlist=[em, em2], 191 | # mdict = {'a':em, "b":em2 }, 192 | # ) 193 | # t1.save() 194 | # 195 | # t = TestFieldModel.objects.get(id=t1.id) 196 | # self.assertEqual(len(t.mlist), 2) 197 | # self.assertEqual(t.mlist[0].test_func(), 1) 198 | # self.assertEqual(t.mlist[1].test_func(), 2) 199 | # 200 | # def test_simple_foreign_keys(self): 201 | # now = datetime.datetime.now() 202 | # 203 | # blog1 = Blog(title="Blog") 204 | # blog1.save() 205 | # entry1 = Entry(title="entry 1", blog=blog1) 206 | # entry1.save() 207 | # entry2 = Entry(title="entry 2", blog=blog1) 208 | # entry2.save() 209 | # self.assertEqual(Entry.objects.count(), 2) 210 | # 211 | # for entry in Entry.objects.all(): 212 | # self.assertEqual( 213 | # blog1, 214 | # entry.blog 215 | # ) 216 | # 217 | # blog2 = Blog(title="Blog") 218 | # blog2.save() 219 | # entry3 = Entry(title="entry 3", blog=blog2) 220 | # entry3.save() 221 | # self.assertEqual( 222 | # # it's' necessary to explicitly state the pk here 223 | # len( list(Entry.objects.filter(blog=blog1.pk))), 224 | # len([entry1, entry2]) 225 | # ) 226 | # 227 | # 228 | ## def test_foreign_keys_bug(self): 229 | ## blog1 = Blog(title="Blog") 230 | ## blog1.save() 231 | ## entry1 = Entry(title="entry 1", blog=blog1) 232 | ## entry1.save() 233 | ## self.assertEqual( 234 | ## # this should work too 235 | ## list(Entry.objects.filter(blog=blog1)), 236 | ## [entry1] 237 | ## ) 238 | # 239 | ## def test_standard_autofield(self): 240 | ## 241 | ## sam1 = StandardAutoFieldModel(title="title 1") 242 | ## sam1.save() 243 | ## sam2 = StandardAutoFieldModel(title="title 2") 244 | ## sam2.save() 245 | ## 246 | ## self.assertEqual( 247 | ## StandardAutoFieldModel.objects.count(), 248 | ## 2 249 | ## ) 250 | ## 251 | ## 
sam1_query = StandardAutoFieldModel.objects.get(title="title 1") 252 | ## self.assertEqual( 253 | ## sam1_query.pk, 254 | ## sam1.pk 255 | ## ) 256 | ## 257 | ## sam1_query = StandardAutoFieldModel.objects.get(pk=sam1.pk) 258 | ## 259 | # 260 | -------------------------------------------------------------------------------- /tests/testproj/myapp/views.py: -------------------------------------------------------------------------------- 1 | # Create your views here. 2 | -------------------------------------------------------------------------------- /tests/testproj/settings.py: -------------------------------------------------------------------------------- 1 | # Django settings for testproj2 project. 2 | 3 | DEBUG = True 4 | TEMPLATE_DEBUG = DEBUG 5 | 6 | ADMINS = ( 7 | # ('Your Name', 'your_email@domain.com'), 8 | ) 9 | 10 | MANAGERS = ADMINS 11 | 12 | DATABASES = { 13 | # 'default': { 14 | # 'ENGINE': 'sqlite3', # Add 'postgresql_psycopg2', 'postgresql', 'mysql', 'sqlite3' or 'oracle'. 15 | # 'NAME': 'test.db', # Or path to database file if using sqlite3. 16 | # 'USER': '', # Not used with sqlite3. 17 | # 'PASSWORD': '', # Not used with sqlite3. 18 | # 'HOST': '', # Set to empty string for localhost. Not used with sqlite3. 19 | # 'PORT': '', # Set to empty string for default. Not used with sqlite3. 20 | # }, 21 | 'default': { 22 | 'ENGINE': 'django_elasticsearch', 23 | 'NAME': 'test', 24 | 'USER': '', 25 | 'PASSWORD': '', 26 | 'HOST': 'localhost', 27 | 'PORT': '9200', 28 | 'SUPPORTS_TRANSACTIONS': False, 29 | }, 30 | } 31 | 32 | # Local time zone for this installation. Choices can be found here: 33 | # http://en.wikipedia.org/wiki/List_of_tz_zones_by_name 34 | # although not all choices may be available on all operating systems. 35 | # On Unix systems, a value of None will cause Django to use the same 36 | # timezone as the operating system. 37 | # If running in a Windows environment this must be set to the same as your 38 | # system time zone. 39 | TIME_ZONE = 'America/Chicago' 40 | 41 | # Language code for this installation. All choices can be found here: 42 | # http://www.i18nguy.com/unicode/language-identifiers.html 43 | LANGUAGE_CODE = 'en-us' 44 | 45 | SITE_ID = 1 46 | 47 | # If you set this to False, Django will make some optimizations so as not 48 | # to load the internationalization machinery. 49 | USE_I18N = True 50 | 51 | # If you set this to False, Django will not format dates, numbers and 52 | # calendars according to the current locale 53 | USE_L10N = True 54 | 55 | # Absolute path to the directory that holds media. 56 | # Example: "/home/media/media.lawrence.com/" 57 | MEDIA_ROOT = '' 58 | 59 | # URL that handles the media served from MEDIA_ROOT. Make sure to use a 60 | # trailing slash if there is a path component (optional in other cases). 61 | # Examples: "http://media.lawrence.com", "http://example.com/media/" 62 | MEDIA_URL = '' 63 | 64 | # URL prefix for admin media -- CSS, JavaScript and images. Make sure to use a 65 | # trailing slash. 66 | # Examples: "http://foo.com/media/", "/media/". 67 | ADMIN_MEDIA_PREFIX = '/media/' 68 | 69 | # Make this unique, and don't share it with anybody. 70 | SECRET_KEY = 'ju^y4b6j4w%)346pf8oxbw=po8)-)hd3ugq=jjw4x38ugf#_0c' 71 | 72 | # List of callables that know how to import templates from various sources. 
73 | TEMPLATE_LOADERS = ( 74 | 'django.template.loaders.filesystem.Loader', 75 | 'django.template.loaders.app_directories.Loader', 76 | # 'django.template.loaders.eggs.Loader', 77 | ) 78 | 79 | MIDDLEWARE_CLASSES = ( 80 | 'django.middleware.common.CommonMiddleware', 81 | 'django.contrib.sessions.middleware.SessionMiddleware', 82 | 'django.middleware.csrf.CsrfViewMiddleware', 83 | 'django.contrib.auth.middleware.AuthenticationMiddleware', 84 | 'django.contrib.messages.middleware.MessageMiddleware', 85 | ) 86 | 87 | ROOT_URLCONF = 'testproj.urls' 88 | 89 | TEMPLATE_DIRS = ( 90 | # Put strings here, like "/home/html/django_templates" or "C:/www/django/templates". 91 | # Always use forward slashes, even on Windows. 92 | # Don't forget to use absolute paths, not relative paths. 93 | ) 94 | 95 | INSTALLED_APPS = ( 96 | 'django.contrib.auth', 97 | 'django.contrib.contenttypes', 98 | 'django.contrib.sessions', 99 | 'django.contrib.sites', 100 | 'django.contrib.messages', 101 | 'testproj.myapp', 102 | 'testproj.mixed', 103 | #'south', 104 | # Uncomment the next line to enable the admin: 105 | 'django.contrib.admin', 106 | ) 107 | 108 | #DATABASE_ROUTERS = ['django_elasticsearch.router.ESRouter'] 109 | ELASTICSEARCH_MANAGED_APPS = ['testproj.myapp', ] 110 | ELASTICSEARCH_MANAGED_MODELS = ['mixed.record', ] 111 | 112 | #SOUTH_DATABASE_ADAPTERS = { "default" : "django_elasticsearch.south"} 113 | -------------------------------------------------------------------------------- /tests/testproj/tests.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["DJANGO_SETTINGS_MODULE"] = "notsqltestproj.settings" 3 | 4 | from myapp.models import Entry, Person 5 | 6 | doc, c = Person.elasticsearch.get_or_create(name="Pippo", defaults={'surname' : "Pluto", 'age' : 10}) 7 | print doc.pk 8 | print doc.surname 9 | print doc.age 10 | 11 | cursor = Person.elasticsearch.filter(age=10) 12 | print cursor[0] --------------------------------------------------------------------------------
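Putting the settings shown above together, a sketch of a mixed relational/elasticsearch configuration (the sqlite ``default`` stanza and the alias name are illustrative; the shipped test project simply uses the elasticsearch engine as its ``default`` database):

    DATABASES = {
        'default': {
            'ENGINE': 'sqlite3',
            'NAME': 'test.db',
        },
        'elasticsearch': {
            'ENGINE': 'django_elasticsearch',
            'NAME': 'test',
            'HOST': 'localhost',
            'PORT': '9200',
            'SUPPORTS_TRANSACTIONS': False,
        },
    }

    DATABASE_ROUTERS = ['django_elasticsearch.router.ESRouter']
    ELASTICSEARCH_MANAGED_APPS = ['testproj.myapp']
    ELASTICSEARCH_MANAGED_MODELS = ['mixed.record']

With this in place the router sends every model in ``myapp`` plus ``mixed.Record`` to the elasticsearch alias, while everything else stays on the relational ``default`` database.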