├── requirements.txt ├── tests ├── jsonb │ ├── __init__.py │ ├── models.py │ └── tests.py ├── arrays │ ├── __init__.py │ ├── models.py │ └── tests.py ├── benchmarks │ ├── __init__.py │ └── models.py ├── hstores │ ├── __init__.py │ ├── models.py │ └── tests.py ├── modeladmin │ ├── __init__.py │ ├── urls.py │ ├── admin.py │ ├── models.py │ └── tests.py ├── m2m_multiple_array │ ├── __init__.py │ ├── models.py │ └── tests.py ├── m2m_recursive_array │ ├── __init__.py │ ├── models.py │ └── tests.py ├── many_to_many_array │ ├── __init__.py │ └── models.py ├── nested_form_field_widget │ ├── __init__.py │ └── tests.py ├── prefetch_related_array │ ├── __init__.py │ ├── test_uuid.py │ ├── test_prefetch_related_objects.py │ └── models.py ├── urls.py └── test_postgres.py ├── requirements_docs.txt ├── setup.cfg ├── django_postgres_extensions ├── admin │ ├── __init__.py │ └── options.py ├── backends │ ├── __init__.py │ └── postgresql │ │ ├── __init__.py │ │ ├── operations.py │ │ ├── schema.py │ │ ├── base.py │ │ └── creation.py ├── models │ ├── sql │ │ ├── __init__.py │ │ ├── updates.py │ │ ├── subqueries.py │ │ ├── datastructures.py │ │ └── compiler.py │ ├── __init__.py │ ├── fields │ │ ├── related_lookups.py │ │ ├── reverse_related.py │ │ ├── __init__.py │ │ ├── related_descriptors.py │ │ └── related.py │ ├── expressions.py │ ├── lookups.py │ ├── functions.py │ └── query.py ├── forms │ ├── __init__.py │ ├── fields.py │ └── widgets.py ├── __init__.py ├── templates │ └── django_postgres_extensions │ │ └── nested_form_widget.html ├── signals.py ├── apps.py └── utils.py ├── MANIFEST.in ├── docs ├── admin_form.jpg ├── admin_list.jpg ├── array_split.jpg ├── json_field.jpg ├── array_choice.jpg ├── hstore_field.jpg ├── queryset.rst ├── intro.rst ├── settings.py ├── Makefile ├── testing.rst ├── index.rst ├── make.bat ├── arraym2m.rst ├── json.rst ├── hstores.rst ├── arrays.rst ├── conf.py └── features.rst ├── .gitignore ├── description.rst ├── setup.py ├── LICENSE └── readme.rst 
/requirements.txt: -------------------------------------------------------------------------------- 1 | django -------------------------------------------------------------------------------- /tests/jsonb/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/arrays/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/benchmarks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/hstores/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/modeladmin/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/m2m_multiple_array/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/m2m_recursive_array/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/many_to_many_array/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements_docs.txt: -------------------------------------------------------------------------------- 1 | sphinxcontrib-fancybox -------------------------------------------------------------------------------- 
/setup.cfg: -------------------------------------------------------------------------------- 1 | [bdist_wheel] 2 | universal = 1 -------------------------------------------------------------------------------- /tests/nested_form_field_widget/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/prefetch_related_array/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /django_postgres_extensions/admin/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /django_postgres_extensions/backends/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /django_postgres_extensions/backends/postgresql/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include readme.rst 3 | include description.rst -------------------------------------------------------------------------------- /django_postgres_extensions/models/sql/__init__.py: -------------------------------------------------------------------------------- 1 | from .subqueries import UpdateQuery -------------------------------------------------------------------------------- /django_postgres_extensions/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .fields import * 2 | from .fields.related import * 
-------------------------------------------------------------------------------- /docs/admin_form.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/primal100/django_postgres_extensions/HEAD/docs/admin_form.jpg -------------------------------------------------------------------------------- /docs/admin_list.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/primal100/django_postgres_extensions/HEAD/docs/admin_list.jpg -------------------------------------------------------------------------------- /docs/array_split.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/primal100/django_postgres_extensions/HEAD/docs/array_split.jpg -------------------------------------------------------------------------------- /docs/json_field.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/primal100/django_postgres_extensions/HEAD/docs/json_field.jpg -------------------------------------------------------------------------------- /docs/array_choice.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/primal100/django_postgres_extensions/HEAD/docs/array_choice.jpg -------------------------------------------------------------------------------- /docs/hstore_field.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/primal100/django_postgres_extensions/HEAD/docs/hstore_field.jpg -------------------------------------------------------------------------------- /django_postgres_extensions/forms/__init__.py: -------------------------------------------------------------------------------- 1 | from .fields import NestedFormField 2 | from .widgets import NestedFormWidget 
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.log 3 | .idea/* 4 | build/* 5 | dist/* 6 | docs/_* 7 | django_postgres_extensions.egg-info/* 8 | *.env 9 | todo.txt -------------------------------------------------------------------------------- /django_postgres_extensions/__init__.py: -------------------------------------------------------------------------------- 1 | default_app_config = 'django_postgres_extensions.apps.PSQLExtensionsConfig' 2 | __version__ = "0.9.3" 3 | -------------------------------------------------------------------------------- /tests/modeladmin/urls.py: -------------------------------------------------------------------------------- 1 | from django.conf.urls import url 2 | from django.contrib import admin 3 | 4 | urlpatterns = [ 5 | url(r'^admin/', admin.site.urls), 6 | ] 7 | -------------------------------------------------------------------------------- /tests/urls.py: -------------------------------------------------------------------------------- 1 | """This URLconf exists because Django expects ROOT_URLCONF to exist. URLs 2 | should be added within the test folders, and use TestCase.urls to set them. 3 | This helps the tests remain isolated. 
4 | """ 5 | 6 | 7 | urlpatterns = [] -------------------------------------------------------------------------------- /tests/jsonb/models.py: -------------------------------------------------------------------------------- 1 | from django.db import models 2 | from django_postgres_extensions.models.fields import JSONField 3 | 4 | class Product(models.Model): 5 | name = models.CharField(max_length=3) 6 | description = JSONField(null=True, blank=True) -------------------------------------------------------------------------------- /django_postgres_extensions/backends/postgresql/operations.py: -------------------------------------------------------------------------------- 1 | from django.db.backends.postgresql.operations import DatabaseOperations as BaseDatabaseOperations 2 | 3 | class DatabaseOperations(BaseDatabaseOperations): 4 | compiler_module = "django_postgres_extensions.models.sql.compiler" -------------------------------------------------------------------------------- /django_postgres_extensions/templates/django_postgres_extensions/nested_form_widget.html: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/queryset.rst: -------------------------------------------------------------------------------- 1 | Querysets 2 | ========= 3 | 4 | Additional Queryset Methods 5 | --------------------------- 6 | This app adds the format method to all querysets. This will defer a field and add an annotation with a different format. 
7 | For example to return a hstorefield as json:: 8 | 9 | qs = Model.objects.all().format('description', HstoreToJSONBLoose) 10 | -------------------------------------------------------------------------------- /django_postgres_extensions/models/sql/updates.py: -------------------------------------------------------------------------------- 1 | class UpdateArrayByIndex(object): 2 | 3 | def __init__(self, indexes, value, field): 4 | self.indexes = indexes 5 | self.value = value 6 | self.base_field = field.base_field 7 | 8 | def alter_name(self, name, qn): 9 | for index in self.indexes: 10 | name += "[%s]" % index 11 | return name 12 | 13 | -------------------------------------------------------------------------------- /django_postgres_extensions/signals.py: -------------------------------------------------------------------------------- 1 | def delete_reverse_related(sender, signal, instance, using, **kwargs): 2 | for related in instance._meta.related_objects: 3 | field = related.field 4 | if getattr(field, 'many_to_many_array', False): 5 | accessor_name = field.get_reverse_accessor_name() 6 | accessor = getattr(instance, accessor_name) 7 | accessor.clear() 8 | -------------------------------------------------------------------------------- /tests/hstores/models.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals 2 | 3 | from django.db import models 4 | from django_postgres_extensions.models.fields import ArrayField, HStoreField 5 | 6 | class Product(models.Model): 7 | name = models.CharField(max_length=3) 8 | description = HStoreField(null=True, blank=True) 9 | details = HStoreField(null=True, blank=True) 10 | purchases = ArrayField(HStoreField(), null=True) -------------------------------------------------------------------------------- /description.rst: -------------------------------------------------------------------------------- 1 | Django Postgres Extensions adds a lot of 
functionality to django.contrib.postgres, specifically in relation to ArrayField, HStoreField and JSONField, including much better form fields for dealing with these field types. The app also includes an array-based many-to-many field (ArrayManyToManyField), so you can store the relationship in an array column instead of requiring an extra database table. 2 | 3 | Check out http://django-postgres-extensions.readthedocs.io/en/latest/ to get started. -------------------------------------------------------------------------------- /tests/modeladmin/admin.py: -------------------------------------------------------------------------------- 1 | from django.contrib import admin 2 | from django_postgres_extensions.admin.options import PostgresAdmin 3 | from .models import Product, Buyer 4 | 5 | class ProductAdmin(PostgresAdmin): 6 | filter_horizontal = ('buyers',) 7 | fields = ('name', 'keywords', 'sports', 'shipping', 'details', 'buyers') 8 | list_display = ('name', 'keywords', 'shipping', 'details', 'country') 9 | 10 | admin.site.register(Buyer) 11 | admin.site.register(Product, ProductAdmin) 12 | 13 | -------------------------------------------------------------------------------- /docs/intro.rst: -------------------------------------------------------------------------------- 1 | Getting started 2 | =============== 3 | 4 | .. warning:: 5 | 6 | Although it should generally work on other versions, Django Postgres Extensions has been tested with Python 2.7.12, Python 3.6 and Django 1.10.5. 7 | 8 | Installation 9 | ------------- 10 | 11 | Install with ``pip install django_postgres_extensions`` 12 | 13 | Setup project 14 | ------------- 15 | 16 | In your settings.py, add 'django.contrib.postgres' and 'django_postgres_extensions' to the INSTALLED_APPS list and configure the database to use the included backend (subclassed from the default Django Postgres backend): 17 | 18 | ..
literalinclude:: settings.py -------------------------------------------------------------------------------- /docs/settings.py: -------------------------------------------------------------------------------- 1 | INSTALLED_APPS = [ 2 | 'django.contrib.contenttypes', 3 | 'django.contrib.auth', 4 | 'django.contrib.sites', 5 | 'django.contrib.sessions', 6 | 'django.contrib.messages', 7 | 'django.contrib.admin.apps.SimpleAdminConfig', 8 | 'django.contrib.staticfiles', 9 | 'django.contrib.postgres', 10 | 'django_postgres_extensions' 11 | ] 12 | 13 | DATABASES = { 14 | 'default': { 15 | 'ENGINE': 'django_postgres_extensions.backends.postgresql', 16 | 'NAME': 'db', 17 | 'USER': 'postgres', 18 | 'PASSWORD': 'postgres', 19 | 'HOST': '127.0.0.1', 20 | 'PORT': 5432, 21 | } 22 | } -------------------------------------------------------------------------------- /tests/arrays/models.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals 2 | 3 | from django.db import models 4 | from django.contrib.postgres.fields import HStoreField 5 | from django_postgres_extensions.models.fields.related import ArrayField 6 | 7 | class Product(models.Model): 8 | name = models.CharField(max_length=3) 9 | tags = ArrayField(models.CharField(max_length=15), null=True, blank=True) 10 | moretags = ArrayField(models.CharField(max_length=15), null=True, blank=True) 11 | prices = ArrayField(models.IntegerField(), null=True, blank=True) 12 | description = HStoreField(null=True, blank=True) 13 | coordinates = ArrayField(ArrayField(models.IntegerField()), null=True) -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 
5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = DjangoPostgresExtensions 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /docs/testing.rst: -------------------------------------------------------------------------------- 1 | Running Tests 2 | ============= 3 | 4 | Running 5 | ------- 6 | 7 | To run the tests: 8 | 9 | ``$ git clone https://github.com/primal100/django_postgres_extensions.git dpe_repo`` 10 | 11 | ``$ cd dpe_repo/tests`` 12 | 13 | Configure the PostgreSQL connection details in test_postgres.py, or set the POSTGRES_DBNAME, POSTGRES_USER and POSTGRES_PASSWORD environment variables that it reads. 14 | 15 | ``$ ./runtests.py --exclude-tag=benchmark`` 16 | 17 | Benchmarks 18 | ---------- 19 | 20 | Benchmark tests are included to compare the performance of the Array M2M with the traditional Django table-based M2M. 21 | They can be quite slow, so it is recommended to exclude them when running the whole test suite, as in the example above. 22 | 23 | They can be run with: 24 | 25 | ``$ ./runtests.py benchmarks.tests`` -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. Django Postgres Extensions documentation master file, created by 2 | sphinx-quickstart on Wed Feb 8 21:57:16 2017. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to Django Postgres Extensions's documentation!
7 | ====================================================== 8 | 9 | .. toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | features 14 | intro 15 | arrays 16 | hstores 17 | json 18 | arraym2m 19 | queryset 20 | testing 21 | 22 | 23 | User Guide 24 | ========== 25 | :doc:`intro` describes how to get up and running with Django Postgres Extensions. 26 | :doc:`features` gives a basic overview of the features in Django Postgres Extensions. -------------------------------------------------------------------------------- /tests/benchmarks/models.py: -------------------------------------------------------------------------------- 1 | from django.db import models 2 | from django_postgres_extensions.models.fields.related import ArrayManyToManyField 3 | 4 | class NumberTraditional(models.Model): 5 | index = models.IntegerField() 6 | 7 | def __str__(self): 8 | return self.index 9 | 10 | class Traditional(models.Model): 11 | index = models.IntegerField() 12 | numbers = models.ManyToManyField(NumberTraditional) 13 | 14 | def __str__(self): 15 | return self.index 16 | 17 | class NumberArray(models.Model): 18 | index = models.IntegerField() 19 | 20 | def __str__(self): 21 | return self.index 22 | 23 | class Array(models.Model): 24 | index = models.IntegerField() 25 | numbers = ArrayManyToManyField(NumberArray, db_index=False) 26 | 27 | def __str__(self): 28 | return self.index -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=DjangoPostgresExtensions 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The 'sphinx-build' command was not found. 
Make sure you have Sphinx 20 | echo.installed, then set the SPHINXBUILD environment variable to point 21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 22 | echo.may add the Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /django_postgres_extensions/apps.py: -------------------------------------------------------------------------------- 1 | from django.apps import AppConfig 2 | from django.utils.translation import ugettext_lazy as _ 3 | from django.db.models import query 4 | from django.db.models.sql import datastructures 5 | from .models.query import update, _update, format, prefetch_one_level 6 | from .models.sql.datastructures import as_sql 7 | from django.db.models.signals import pre_delete 8 | from .signals import delete_reverse_related 9 | from django.conf import settings 10 | 11 | 12 | class PSQLExtensionsConfig(AppConfig): 13 | name = 'django_postgres_extensions' 14 | verbose_name = _('Extra features for PostgreSQL fields') 15 | 16 | def ready(self): 17 | query.QuerySet.format = format 18 | query.QuerySet.update = update 19 | query.QuerySet._update = _update 20 | if getattr(settings, 'ENABLE_ARRAY_M2M', False): 21 | datastructures.Join.as_sql = as_sql 22 | query.prefetch_one_level = prefetch_one_level 23 | pre_delete.connect(delete_reverse_related) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | import django_postgres_extensions 4 | 5 | setup(name='django_postgres_extensions', 6 | 
version=django_postgres_extensions.__version__, 7 | description="Extra features for django.contrib.postgres", 8 | long_description=open('description.rst').read(), 9 | author='Paul Martin', 10 | author_email='greatestloginnameever@gmail.com', 11 | url='https://github.com/primal100/django_postgres_extensions', 12 | packages=find_packages(exclude=['tests', 'tests.*']), 13 | classifiers=[ 14 | 'Development Status :: 4 - Beta', 15 | 'Environment :: Web Environment', 16 | 'Framework :: Django', 17 | 'Intended Audience :: Developers', 18 | 'License :: OSI Approved :: BSD License', 19 | 'Operating System :: OS Independent', 20 | 'Programming Language :: Python', 21 | 'Programming Language :: Python :: 2', 22 | 'Programming Language :: Python :: 3', 23 | 'Topic :: Database', 24 | 'Topic :: Software Development :: Libraries :: Python Modules', 25 | ], 26 | ) -------------------------------------------------------------------------------- /django_postgres_extensions/backends/postgresql/schema.py: -------------------------------------------------------------------------------- 1 | from django.db.backends.postgresql import schema 2 | 3 | class DatabaseSchemaEditor(schema.DatabaseSchemaEditor): 4 | sql_create_array_index = "CREATE INDEX %(name)s ON %(table)s USING GIN (%(columns)s)%(extra)s" 5 | 6 | def _model_indexes_sql(self, model): 7 | output = super(DatabaseSchemaEditor, self)._model_indexes_sql(model) 8 | if not model._meta.managed or model._meta.proxy or model._meta.swapped: 9 | return output 10 | 11 | for field in model._meta.local_fields: 12 | array_index_statement = self._create_array_index_sql(model, field) 13 | if array_index_statement is not None: 14 | output.append(array_index_statement) 15 | return output 16 | 17 | def _create_array_index_sql(self, model, field): 18 | db_type = field.db_type(connection=self.connection) 19 | if db_type is not None and '[' in db_type and db_type.endswith(']') and (field.db_index or field.unique): 20 | return 
self._create_index_sql(model, [field], suffix='_gin', sql=self.sql_create_array_index) 21 | return None -------------------------------------------------------------------------------- /tests/m2m_multiple_array/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Multiple many-to-many relationships between the same two tables 3 | 4 | In this example, an ``Article`` can have many "primary" ``Category`` objects 5 | and many "secondary" ``Category`` objects. 6 | 7 | Set ``related_name`` to designate what the reverse relationship is called. 8 | """ 9 | 10 | from django.db import models 11 | from django.utils.encoding import python_2_unicode_compatible 12 | from django_postgres_extensions.models.fields.related import ArrayManyToManyField 13 | 14 | @python_2_unicode_compatible 15 | class Category(models.Model): 16 | name = models.CharField(max_length=20) 17 | 18 | class Meta: 19 | ordering = ('name',) 20 | 21 | def __str__(self): 22 | return self.name 23 | 24 | 25 | @python_2_unicode_compatible 26 | class Article(models.Model): 27 | headline = models.CharField(max_length=50) 28 | pub_date = models.DateTimeField() 29 | primary_categories = ArrayManyToManyField(Category, related_name='primary_article_set') 30 | secondary_categories = ArrayManyToManyField(Category, related_name='secondary_article_set') 31 | 32 | class Meta: 33 | ordering = ('pub_date',) 34 | 35 | def __str__(self): 36 | return self.headline 37 | -------------------------------------------------------------------------------- /tests/m2m_recursive_array/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Many-to-many relationships between the same two tables 3 | 4 | In this example, a ``Person`` can have many friends, who are also ``Person`` 5 | objects. Friendship is a symmetrical relationship - if I am your friend, you 6 | are my friend. 
Here, ``friends`` is an example of a symmetrical 7 | ``ArrayManyToManyField``. 8 | 9 | A ``Person`` can also have many idols - but while I may idolize you, you may 10 | not think the same of me. Here, ``idols`` is an example of a non-symmetrical 11 | ``ArrayManyToManyField``. Only recursive ``ManyToManyField`` fields may be 12 | non-symmetrical, and they are symmetrical by default. 13 | 14 | These tests validate that symmetry is preserved where appropriate. Unlike 15 | ``ManyToManyField``, ``ArrayManyToManyField`` stores the relationship in an 16 | array column, so no intermediate many-to-many table is created. 17 | """ 18 | 19 | from django.db import models 20 | from django.utils.encoding import python_2_unicode_compatible 21 | from django_postgres_extensions.models.fields.related import ArrayManyToManyField 22 | 23 | @python_2_unicode_compatible 24 | class Person(models.Model): 25 | name = models.CharField(max_length=20) 26 | friends = ArrayManyToManyField('self') 27 | idols = ArrayManyToManyField('self', symmetrical=False, related_name='stalkers') 28 | 29 | def __str__(self): 30 | return self.name 31 | -------------------------------------------------------------------------------- /tests/test_postgres.py: -------------------------------------------------------------------------------- 1 | # This is an example test settings file for use with the Django test suite. 2 | # 3 | # The backend below is this app's PostgreSQL backend. The database name 4 | # and credentials are read from the POSTGRES_DBNAME (default 'db'), 5 | # POSTGRES_USER and POSTGRES_PASSWORD environment variables. See the 6 | # following section in the Django docs for more information: 7 | # 8 | # https://docs.djangoproject.com/en/dev/internals/contributing/writing-code/unit-tests/ 9 | # 10 | # This app targets PostgreSQL only, so run the tests against a PostgreSQL 11 | # database using the bundled django_postgres_extensions backend configured 12 | # below. ENABLE_ARRAY_M2M is set to True so that the array many-to-many 13 | # patches installed in apps.py are active during the tests.
14 | 15 | import os 16 | 17 | 18 | DATABASES = { 19 | 'default': { 20 | 'ENGINE': 'django_postgres_extensions.backends.postgresql', 21 | 'NAME': os.environ.get('POSTGRES_DBNAME', 'db'), 22 | 'USER': os.environ.get('POSTGRES_USER'), 23 | 'PASSWORD': os.environ.get('POSTGRES_PASSWORD'), 24 | 'PORT': 5432, 25 | } 26 | } 27 | 28 | ENABLE_ARRAY_M2M = True 29 | 30 | SECRET_KEY = "django_tests_secret_key" 31 | 32 | # Use a fast hasher to speed up tests. 33 | PASSWORD_HASHERS = [ 34 | 'django.contrib.auth.hashers.MD5PasswordHasher', 35 | ] 36 | -------------------------------------------------------------------------------- /django_postgres_extensions/backends/postgresql/base.py: -------------------------------------------------------------------------------- 1 | from django.db.backends.postgresql.base import DatabaseWrapper as BaseDatabaseWrapper 2 | from .schema import DatabaseSchemaEditor 3 | from .creation import DatabaseCreation 4 | from .operations import DatabaseOperations 5 | 6 | class DatabaseWrapper(BaseDatabaseWrapper): 7 | 8 | SchemaEditorClass = DatabaseSchemaEditor 9 | 10 | def __init__(self, *args, **kwargs): 11 | super(DatabaseWrapper, self).__init__(*args, **kwargs) 12 | self.creation = DatabaseCreation(self) 13 | self.ops = DatabaseOperations(self) 14 | 15 | self.any_operators = { 16 | 'exact': '= ANY(%s)', 17 | 'in': 'LIKE ANY(%s)', 18 | 'gt': '< ANY(%s)', 19 | 'gte': '<= ANY(%s)', 20 | 'lt': '> ANY(%s)', 21 | 'lte': '>= ANY(%s)', 22 | 'startof': 'LIKE ANY(%s)', 23 | 'endof': 'LIKE ANY(%s)', 24 | 'contains': '<@ ANY(%s)' 25 | } 26 | 27 | 28 | self.all_operators = { 29 | 'exact': '= ALL(%s)', 30 | 'in': 'LIKE ALL(%s)', 31 | 'gt': '< ALL(%s)', 32 | 'gte': '<= ALL(%s)', 33 | 'lt': '> ALL(%s)', 34 | 'lte': '>= ALL(%s)', 35 | 'startof': 'LIKE ALL(%s)', 36 | 'endof': 'LIKE ALL(%s)', 37 | 'contains': '<@ ALL(%s)' 38 | } 39 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | Copyright (c) Paul Martin and all contributors. 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, 5 | are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, 8 | this list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright 11 | notice, this list of conditions and the following disclaimer in the 12 | documentation and/or other materials provided with the distribution. 13 | 14 | 3. My name may not be used to endorse or promote products 15 | derived from this software without specific prior written permission. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 18 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 19 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 20 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 21 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 22 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 23 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 24 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 25 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-------------------------------------------------------------------------------- /tests/many_to_many_array/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Array Many-to-many relationships 3 | Tests taken from Django 1.9 Stable 4 | In this example, an ``Article`` can be published in multiple ``Publication`` 5 | objects, and a ``Publication`` has multiple ``Article`` objects. 6 | """ 7 | from __future__ import unicode_literals 8 | 9 | from django.db import models 10 | from django.utils.encoding import python_2_unicode_compatible 11 | from django_postgres_extensions.models.fields.related import ArrayManyToManyField 12 | 13 | 14 | @python_2_unicode_compatible 15 | class Publication(models.Model): 16 | title = models.CharField(max_length=30) 17 | 18 | def __str__(self): 19 | return self.title 20 | 21 | class Meta: 22 | ordering = ('title',) 23 | 24 | 25 | @python_2_unicode_compatible 26 | class Article(models.Model): 27 | headline = models.CharField(max_length=100) 28 | # Assign a unicode string as name to make sure the intermediary model is 29 | # correctly created. 
Refs #20207 30 | publications = ArrayManyToManyField(Publication, name='publications') 31 | 32 | def __str__(self): 33 | return self.headline 34 | 35 | class Meta: 36 | ordering = ('headline',) 37 | 38 | 39 | # Models to test correct related_name inheritance 40 | class AbstractArticle(models.Model): 41 | class Meta: 42 | abstract = True 43 | 44 | publications = ArrayManyToManyField(Publication, name='publications', related_name='+') 45 | 46 | 47 | class InheritedArticleA(AbstractArticle): 48 | pass 49 | 50 | 51 | class InheritedArticleB(AbstractArticle): 52 | pass -------------------------------------------------------------------------------- /tests/modeladmin/models.py: -------------------------------------------------------------------------------- 1 | from django.db import models 2 | from django_postgres_extensions.models.fields import HStoreField, JSONField, ArrayField 3 | from django_postgres_extensions.models.fields.related import ArrayManyToManyField 4 | from django import forms 5 | from django.contrib.postgres.forms import SplitArrayField 6 | from django_postgres_extensions.forms.fields import NestedFormField 7 | 8 | details_fields = ( 9 | ('Brand', NestedFormField(keys=('Name', 'Country'))), 10 | ('Type', forms.CharField(max_length=25, required=False)), 11 | ('Colours', SplitArrayField(base_field=forms.CharField(max_length=10, required=False), size=10)), 12 | ) 13 | 14 | class Buyer(models.Model): 15 | time = models.DateTimeField(auto_now_add=True) 16 | name = models.CharField(max_length=20) 17 | 18 | def __str__(self): 19 | return self.name 20 | 21 | class Product(models.Model): 22 | name = models.CharField(max_length=15) 23 | keywords = ArrayField(models.CharField(max_length=20), default=[], form_size=10, blank=True) 24 | sports = ArrayField(models.CharField(max_length=20),default=[], blank=True, choices=( 25 | ('football', 'Football'), ('tennis', 'Tennis'), ('golf', 'Golf'), ('basketball', 'Basketball'), ('hurling', 'Hurling'), ('baseball', 
'Baseball'))) 26 | shipping = HStoreField(keys=('Address', 'City', 'Region', 'Country'), blank=True, default={}) 27 | details = JSONField(fields=details_fields, blank=True, default={}) 28 | buyers = ArrayManyToManyField(Buyer) 29 | 30 | def __str__(self): 31 | return self.name 32 | 33 | @property 34 | def country(self): 35 | return self.shipping.get('Country', '') 36 | -------------------------------------------------------------------------------- /django_postgres_extensions/models/sql/subqueries.py: -------------------------------------------------------------------------------- 1 | from django.db.models.sql.subqueries import UpdateQuery as BaseUpdateQuery 2 | from django.utils import six 3 | from django.core.exceptions import FieldError 4 | 5 | class UpdateQuery(BaseUpdateQuery): 6 | def add_update_values(self, values): 7 | """ 8 | Convert a dictionary of field name to value mappings into an update 9 | query. This is the entry point for the public update() method on 10 | querysets. 11 | """ 12 | values_seq = [] 13 | for name, val in six.iteritems(values): 14 | if '__' in name: 15 | indexes = name.split('__') 16 | field_name = indexes.pop(0) 17 | field = self.get_meta().get_field(field_name) 18 | val = field.get_update_type(indexes, val) 19 | model = field.model 20 | else: 21 | field = self.get_meta().get_field(name) 22 | direct = not (field.auto_created and not field.concrete) or not field.concrete 23 | model = field.model._meta.concrete_model 24 | if not direct or (field.is_relation and field.many_to_many): 25 | raise FieldError( 26 | 'Cannot update model field %r (only non-relations and ' 27 | 'foreign keys permitted).' 
% field 28 | ) 29 | else: 30 | if model is not self.get_meta().model: 31 | self.add_related_update(model, field, val) 32 | continue 33 | values_seq.append((field, model, val)) 34 | return self.add_update_fields(values_seq) -------------------------------------------------------------------------------- /django_postgres_extensions/utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | class OrderedSet(collections.MutableSet): 4 | """ 5 | With thanks to http://code.activestate.com/recipes/576694/ 6 | """ 7 | def __init__(self, iterable=None): 8 | self.end = end = [] 9 | end += [None, end, end] # sentinel node for doubly linked list 10 | self.map = {} # key --> [key, prev, next] 11 | if iterable is not None: 12 | self |= iterable 13 | 14 | def __len__(self): 15 | return len(self.map) 16 | 17 | def __contains__(self, key): 18 | return key in self.map 19 | 20 | def add(self, key): 21 | if key not in self.map: 22 | end = self.end 23 | curr = end[1] 24 | curr[2] = end[1] = self.map[key] = [key, curr, end] 25 | 26 | def discard(self, key): 27 | if key in self.map: 28 | key, prev, next = self.map.pop(key) 29 | prev[2] = next 30 | next[1] = prev 31 | 32 | def __iter__(self): 33 | end = self.end 34 | curr = end[2] 35 | while curr is not end: 36 | yield curr[0] 37 | curr = curr[2] 38 | 39 | def __reversed__(self): 40 | end = self.end 41 | curr = end[1] 42 | while curr is not end: 43 | yield curr[0] 44 | curr = curr[1] 45 | 46 | def pop(self, last=True): 47 | if not self: 48 | raise KeyError('set is empty') 49 | key = self.end[1][0] if last else self.end[2][0] 50 | self.discard(key) 51 | return key 52 | 53 | def __repr__(self): 54 | if not self: 55 | return '%s()' % (self.__class__.__name__,) 56 | return '%s(%r)' % (self.__class__.__name__, list(self)) 57 | 58 | def __eq__(self, other): 59 | if isinstance(other, OrderedSet): 60 | return len(self) == len(other) and list(self) == list(other) 61 | return set(self) 
== set(other) -------------------------------------------------------------------------------- /docs/arraym2m.rst: -------------------------------------------------------------------------------- 1 | Array Many To Many Field 2 | ======================== 3 | 4 | Basic Usage 5 | ----------- 6 | 7 | The Array Many To Many Field is designed be a drop-in replacement for the normal Django Many To Many Field 8 | except that it uses an array instead of a separate table to store relationships, but replicates many of the same features. 9 | In general, write queries are much faster than the traditional M2M however select queries are typically slower. 10 | 11 | To use this field, it is required to set ENABLE_ARRAY_M2M = True in settings.py (to enable the required monkey-patching):: 12 | 13 | ENABLE_ARRAY_M2M = True 14 | 15 | Then in models.py:: 16 | 17 | from django.db import models 18 | from django_postgres_extensions.models.fields import ArrayManyToManyField 19 | 20 | class Publication(models.Model): 21 | title = models.CharField(max_length=30) 22 | 23 | def __str__(self): 24 | return self.title 25 | 26 | class Meta: 27 | ordering = ('title',) 28 | 29 | class Article(models.Model): 30 | headline = models.CharField(max_length=100) 31 | publications = ArrayManyToManyField(Publication, name='publications') 32 | 33 | def __str__(self): 34 | return self.headline 35 | 36 | class Meta: 37 | ordering = ('headline',) 38 | 39 | The Array Many To Many field supports the following features which replicate the API of the regular Many To Many Field: 40 | 41 | - Descriptor queryset with add, remove, clear and set for both forward and reverse relationships 42 | - Prefetch related for both forward and reverse relationships 43 | - Lookups across relationships with filter for both forward and reverse relationships 44 | - Lookups across relationships with exclude for for forward relationships only 45 | 46 | You can find more information on how these features work in the Django documentation for 
the regular Many To Many Field: 47 | 48 | https://docs.djangoproject.com/en/1.9/topics/db/examples/many_to_many/ 49 | -------------------------------------------------------------------------------- /django_postgres_extensions/models/fields/related_lookups.py: -------------------------------------------------------------------------------- 1 | from django.db.models.fields.related_lookups import RelatedLookupMixin, MultiColSource, get_normalized_value 2 | from django.contrib.postgres.fields.array import ArrayContains, ArrayContainedBy, ArrayExact, ArrayOverlap 3 | 4 | from django_postgres_extensions.models.lookups import ( 5 | AnyExact, AnyGreaterThan, AnyLessThan, AnyGreaterThanOrEqual, AnyLessThanOrEqual, ContainsItem) 6 | 7 | class RelatedArrayMixin(RelatedLookupMixin): # array lookups (exact/contains/overlap/...) whose rhs is a list of related objects or values 8 | def get_prep_lookup(self): 9 | self.lookup_name = 'contains' 10 | if not isinstance(self.lhs, MultiColSource) and self.rhs_is_direct_value(): # normalise each element of the rhs list to a prepped target-field value 11 | self.rhs = [get_normalized_value(value, self.lhs)[0] for value in self.rhs] 12 | if hasattr(self.lhs.output_field, 'get_path_info'): 13 | self.rhs = [self.lhs.output_field.get_path_info()[-1].target_fields[-1].get_prep_value(rhs) for rhs in 14 | self.rhs] 15 | self.lookup_name = 'exact' 16 | return super(RelatedLookupMixin, self).get_prep_lookup() # resolves past RelatedLookupMixin in the MRO, so its own get_prep_lookup is bypassed 17 | 18 | class RelatedAnyExact(RelatedLookupMixin, AnyExact): 19 | pass 20 | 21 | class RelatedAnyGreaterThan(RelatedLookupMixin, AnyGreaterThan): 22 | pass 23 | 24 | class RelatedAnyLessThan(RelatedLookupMixin, AnyLessThan): 25 | pass 26 | 27 | class RelatedAnyGreaterThanOrEqual(RelatedLookupMixin, AnyGreaterThanOrEqual): 28 | pass 29 | 30 | class RelatedAnyLessThanOrEqual(RelatedLookupMixin, AnyLessThanOrEqual): 31 | pass 32 | 33 | class RelatedArrayExact(RelatedArrayMixin, ArrayExact): 34 | """ 35 | More like what exact should be.
Checks the array contains the array of related objects and only those related objects 36 | """ 37 | pass 38 | 39 | class RelatedArrayContains(RelatedArrayMixin, ArrayContains): 40 | pass 41 | 42 | class RelatedContainsItem(RelatedArrayMixin, ContainsItem): 43 | pass 44 | 45 | class RelatedArrayContainedBy(RelatedArrayMixin, ArrayContainedBy): 46 | pass 47 | 48 | class RelatedArrayOverlap(RelatedArrayMixin, ArrayOverlap): 49 | pass -------------------------------------------------------------------------------- /django_postgres_extensions/models/sql/datastructures.py: -------------------------------------------------------------------------------- 1 | from django.db.models.sql.datastructures import Join as BaseJoin 2 | 3 | #class Join(BaseJoin): 4 | # NOTE(review): the subclass above is commented out, so as_sql is a bare function that takes a Join as `self`; presumably it is assigned onto Join elsewhere - confirm. 5 | def as_sql(self, compiler, connection): 6 | """ 7 | Generates the full 8 | LEFT OUTER JOIN sometable ON sometable.somecol = othertable.othercol, params 9 | clause for this join. 10 | """ 11 | join_conditions = [] 12 | params = [] 13 | qn = compiler.quote_name_unless_alias 14 | qn2 = connection.ops.quote_name 15 | 16 | # Add a join condition for each pair of joining columns. 17 | 18 | for index, (lhs_col, rhs_col) in enumerate(self.join_cols): # `index` is unused 19 | if hasattr(self.join_field, 'get_join_on'): # join fields such as ArrayManyToManyRel build their own condition (e.g. "= ANY(...)") instead of plain equality 20 | join_condition = self.join_field.get_join_on(qn(self.parent_alias), qn2(lhs_col), qn(self.table_alias), 21 | qn2(rhs_col)) 22 | join_conditions.append(join_condition) 23 | else: 24 | join_conditions.append('%s.%s = %s.%s' % ( 25 | qn(self.parent_alias), 26 | qn2(lhs_col), 27 | qn(self.table_alias), 28 | qn2(rhs_col), 29 | )) 30 | 31 | # Add a single condition inside parentheses for whatever 32 | # get_extra_restriction() returns.
33 | extra_cond = self.join_field.get_extra_restriction( 34 | compiler.query.where_class, self.table_alias, self.parent_alias) 35 | if extra_cond: 36 | extra_sql, extra_params = compiler.compile(extra_cond) 37 | join_conditions.append('(%s)' % extra_sql) 38 | params.extend(extra_params) 39 | 40 | if not join_conditions: 41 | # This might be a rel on the other end of an actual declared field. 42 | declared_field = getattr(self.join_field, 'field', self.join_field) 43 | raise ValueError( 44 | "Join generated an empty ON clause. %s did not yield either " 45 | "joining columns or extra restrictions." % declared_field.__class__ 46 | ) 47 | on_clause_sql = ' AND '.join(join_conditions) 48 | alias_str = '' if self.table_alias == self.table_name else (' %s' % self.table_alias) 49 | sql = '%s %s%s ON (%s)' % (self.join_type, qn(self.table_name), alias_str, on_clause_sql) 50 | return sql, params -------------------------------------------------------------------------------- /django_postgres_extensions/models/fields/reverse_related.py: -------------------------------------------------------------------------------- 1 | from django.db.models.fields.reverse_related import ForeignObjectRel 2 | from django.db.models.fields.related_lookups import RelatedExact, RelatedIn, RelatedGreaterThan, RelatedIsNull, \ 3 | RelatedLessThan, RelatedGreaterThanOrEqual, RelatedLessThanOrEqual 4 | from django.core import exceptions 5 | 6 | class ArrayManyToManyRel(ForeignObjectRel): 7 | 8 | """ 9 | Stores information about an array-backed many-to-many relation (presumably created by ArrayManyToManyField). 10 | 11 | ``_meta.get_fields()`` returns this class to provide access to the field 12 | flags for the reverse relation.
13 | """ 14 | def __init__(self, field, to, field_name, related_name=None, related_query_name=None, 15 | limit_choices_to=None, symmetrical=True): 16 | super(ArrayManyToManyRel, self).__init__( 17 | field, to, 18 | related_name=related_name, 19 | related_query_name=related_query_name, 20 | limit_choices_to=limit_choices_to, 21 | ) 22 | 23 | self.model_name = to # NOTE(review): stores the target model itself (`to`), not a model name - confirm the attribute name is intended 24 | self.field_name = field_name 25 | self.symmetrical = symmetrical 26 | 27 | def get_join_on(self, parent_alias, lhs_col, table_alias, rhs_col): # ON condition for Join.as_sql: the parent column must equal one element of the array column 28 | return '%s.%s = ANY(%s.%s)' % ( 29 | parent_alias, 30 | lhs_col, 31 | table_alias, 32 | rhs_col, 33 | ) 34 | 35 | def set_field_name(self): 36 | self.field_name = self.field_name or self.model._meta.pk.name 37 | 38 | def get_related_field(self): 39 | """ 40 | Return the Field in the 'to' object to which this relationship is tied. 41 | """ 42 | field = self.model._meta.get_field(self.field_name) 43 | if not field.concrete: 44 | raise exceptions.FieldDoesNotExist("No related field named '%s'" % 45 | self.field_name) 46 | return field 47 | 48 | def get_lookup(self, lookup_name): # maps lookup names to Django's related lookups; any other name raises TypeError 49 | if lookup_name == 'in': 50 | return RelatedIn 51 | elif lookup_name == 'exact': 52 | return RelatedExact 53 | elif lookup_name == 'gt': 54 | return RelatedGreaterThan 55 | elif lookup_name == 'gte': 56 | return RelatedGreaterThanOrEqual 57 | elif lookup_name == 'lt': 58 | return RelatedLessThan 59 | elif lookup_name == 'lte': 60 | return RelatedLessThanOrEqual 61 | elif lookup_name == 'isnull': 62 | return RelatedIsNull 63 | else: 64 | raise TypeError('Related Field got invalid lookup: %s' % lookup_name) -------------------------------------------------------------------------------- /tests/m2m_multiple_array/tests.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals 2 | 3 | from datetime import datetime 4 | 5 | from django.test import TestCase 6 | 7 | from .models import Article, Category 8 | 9 | 10 | 
class M2MMultipleTests(TestCase): 11 | def test_multiple(self): # two array M2M relations (primary/secondary) to the same model must stay independent, both forward and reverse 12 | c1, c2, c3, c4 = [ 13 | Category.objects.create(name=name) 14 | for name in ["Sports", "News", "Crime", "Life"] 15 | ] 16 | 17 | a1 = Article.objects.create( 18 | headline="Parrot steals", pub_date=datetime(2005, 11, 27) 19 | ) 20 | a1.primary_categories.add(c2, c3) 21 | a1.secondary_categories.add(c4) 22 | 23 | a2 = Article.objects.create( 24 | headline="Parrot runs", pub_date=datetime(2005, 11, 28) 25 | ) 26 | a2.primary_categories.add(c1, c2) 27 | a2.secondary_categories.add(c4) 28 | 29 | self.assertQuerysetEqual( 30 | a1.primary_categories.all(), [ 31 | "Crime", 32 | "News", 33 | ], 34 | lambda c: c.name 35 | ) 36 | self.assertQuerysetEqual( 37 | a2.primary_categories.all(), [ 38 | "News", 39 | "Sports", 40 | ], 41 | lambda c: c.name 42 | ) 43 | self.assertQuerysetEqual( 44 | a1.secondary_categories.all(), [ 45 | "Life", 46 | ], 47 | lambda c: c.name 48 | ) 49 | self.assertQuerysetEqual( 50 | c1.primary_article_set.all(), [ 51 | "Parrot runs", 52 | ], 53 | lambda a: a.headline 54 | ) 55 | self.assertQuerysetEqual( 56 | c1.secondary_article_set.all(), [] 57 | ) 58 | self.assertQuerysetEqual( 59 | c2.primary_article_set.all(), [ 60 | "Parrot steals", 61 | "Parrot runs", 62 | ], 63 | lambda a: a.headline 64 | ) 65 | self.assertQuerysetEqual( 66 | c2.secondary_article_set.all(), [] 67 | ) 68 | self.assertQuerysetEqual( 69 | c3.primary_article_set.all(), [ 70 | "Parrot steals", 71 | ], 72 | lambda a: a.headline 73 | ) 74 | self.assertQuerysetEqual( 75 | c3.secondary_article_set.all(), [] 76 | ) 77 | self.assertQuerysetEqual( 78 | c4.primary_article_set.all(), [] 79 | ) 80 | self.assertQuerysetEqual( 81 | c4.secondary_article_set.all(), [ 82 | "Parrot steals", 83 | "Parrot runs", 84 | ], 85 | lambda a: a.headline 86 | ) 87 | -------------------------------------------------------------------------------- /docs/json.rst: 
-------------------------------------------------------------------------------- 
1 | JSONField 2 | ========= 3 | 4 | Basic Usage 5 | ----------- 6 | 7 | To use the JSON field:: 8 | 9 | from django.db import models 10 | from django_postgres_extensions.models.fields import JSONField 11 | 12 | class Product(models.Model): 13 | description = JSONField(null=True, blank=True) 14 | 15 | Individual keys 16 | --------------- 17 | 18 | - Get JSON values by key or key path:: 19 | 20 | from django_postgres_extensions.models.expressions import Key 21 | obj = Product.objects.annotate(Key('description', 'Details')).get() 22 | obj = Product.objects.annotate(Key('description', 'Details__Rating')).get() 23 | obj = Product.objects.annotate(Key('description', 'Tags__1')).get() 24 | 25 | - Update a JSONField by specific keys, leaving any others untouched:: 26 | 27 | Product.objects.update(description__ = {'Industry': 'Movie', 'Popularity': 'Very Popular'}) 28 | 29 | - Delete from a JSONField by key or key path:: 30 | 31 | Product.objects.update(description__del ='Details') 32 | Product.objects.update(description__del = 'Details__Release') 33 | Product.objects.update(description__del='Tags__1') 34 | 35 | Database functions 36 | ------------------ 37 | 38 | Various database functions are included for interacting with JSONFields: 39 | 40 | - JSONBSet: updates individual keys in the JSONField without modifying the others. 41 | 42 | - JSONBArrayLength: returns the length of a JSON array stored in a JSONField. 43 | 44 | 45 | Check the PostgreSQL documentation for more information on these functions.
46 | These functions handle the arguments by converting them to the correct expressions automatically:: 47 | 48 | from django_postgres_extensions.models.functions import * 49 | from psycopg2.extras import Json 50 | obj = Product.objects.update(description = JSONBSet('description', ['Details', 'Genre'], Json('Heavy Metal'), True)) 51 | obj = Product.objects.update(description = JSONBSet('description', ['1', 'c'], Json('g'))) 52 | obj = Product.objects.queryset.annotate(tags_length=JSONBArrayLength('tags', 1)).get() 53 | 54 | Use With NestedFormField 55 | ------------------------ 56 | 57 | The same NestedFormField and NestedFormWidget referred in the HStore description can also be used with a JSON Field. 58 | To use it give the fields keyword argument:: 59 | 60 | details_fields = ( 61 | NestedFormField(label='Brand', keys=('Name', 'Country')), 62 | forms.CharField(label='Type', max_length=25, required=False), 63 | SplitArrayField(label='Colours', base_field=forms.CharField(max_length=10, required=False), size=10), 64 | ) 65 | 66 | class Product(models.Model): 67 | details = JSONField(fields=details_fields, blank=True, default={}) 68 | 69 | The field would look like: 70 | 71 | .. image:: json_field.jpg -------------------------------------------------------------------------------- /django_postgres_extensions/forms/fields.py: -------------------------------------------------------------------------------- 1 | from .widgets import NestedFormWidget 2 | from django.forms.fields import MultiValueField, CharField 3 | from django.core.exceptions import ValidationError 4 | 5 | 6 | class NestedFormField(MultiValueField): 7 | """ 8 | A Field that aggregates the logic of multiple Fields to create a nested form within a form. 9 | 10 | The compress method returns a dictionary of field names and values. 11 | 12 | Requires either a ``fields`` or ``keys`` argument but not both. 
13 | 14 | The ``fields`` argument is a list of tuples; each tuple consisting of a field name and a field instance. 15 | If given a nested form will be created consisting of these fields. 16 | 17 | The ``keys`` argument is a list/tuple of field names. If given, a nested form will be created consisting of 18 | django.forms.CharField instances, with the given key names. This is primarily for use with 19 | django.contrib.postgres.HStoreField. By default, all fields are not required. 20 | 21 | To make all fields required set the ``require_all_fields`` argument to True. 22 | 23 | The ``max_value_length`` is ignored if the ``fields`` argument is given. If the ``keys`` argument is given, the max 24 | length for each CharField instance will be set to this value. 25 | 26 | Uses the NestedFormWidget. 27 | """ 28 | def __init__(self, fields=(), keys=(), require_all_fields=False, max_value_length=25, *args, **kwargs): 29 | if (fields and keys) or (not fields and not keys): 30 | raise ValueError("NestedFormField requires either a tuple of fields or keys but not both") 31 | 32 | if keys: 33 | fields = [] 34 | for key in keys: 35 | field = CharField(max_length=max_value_length, required=False) 36 | fields.append((key, field)) 37 | form_fields = [] 38 | widgets = [] 39 | self.labels = [] 40 | self.names = {} 41 | for field in fields: 42 | label = field[1].label or field[0] 43 | self.names[label] = field[0] 44 | self.labels.append(label) 45 | form_fields.append(field[1]) 46 | widgets.append(field[1].widget) 47 | widget = NestedFormWidget(self.labels, widgets, self.names) 48 | super(NestedFormField, self).__init__(*args, fields=form_fields, widget=widget, 49 | require_all_fields=require_all_fields, **kwargs) 50 | 51 | def compress(self, data_list): 52 | result = {} 53 | for i, label in enumerate(self.labels): 54 | name = self.names[label] 55 | result[name] = data_list[i] 56 | return result 57 | 58 | def to_python(self, value): 59 | if not value: 60 | return {} 61 | if 
isinstance(value, dict): 62 | return value 63 | else: 64 | raise ValidationError( 65 | self.error_messages['invalid_json'], 66 | code='invalid_json', 67 | ) 68 | -------------------------------------------------------------------------------- /django_postgres_extensions/backends/postgresql/creation.py: -------------------------------------------------------------------------------- 1 | from django.contrib.postgres.signals import register_type_handlers 2 | from django.db.backends.postgresql.creation import DatabaseCreation as BaseDatabaseCreation 3 | from django.conf import settings 4 | 5 | class DatabaseCreation(BaseDatabaseCreation): 6 | def create_test_db(self, verbosity=1, autoclobber=False, serialize=True, keepdb=False): 7 | """ 8 | Creates a test database, prompting the user for confirmation if the 9 | database already exists. Returns the name of the test database created. 10 | """ 11 | # Don't import django.core.management if it isn't needed. 12 | from django.core.management import call_command 13 | 14 | test_database_name = self._get_test_db_name() 15 | 16 | if verbosity >= 1: 17 | action = 'Creating' 18 | if keepdb: 19 | action = "Using existing" 20 | 21 | print("%s test database for alias %s..." % ( 22 | action, 23 | self._get_database_display_str(verbosity, test_database_name), 24 | )) 25 | 26 | # We could skip this call if keepdb is True, but we instead 27 | # give it the keepdb param. This is to handle the case 28 | # where the test DB doesn't exist, in which case we need to 29 | # create it, then just not destroy it. If we instead skip 30 | # this, we will get an exception. 
31 | self._create_test_db(verbosity, autoclobber, keepdb) 32 | 33 | self.connection.close() 34 | settings.DATABASES[self.connection.alias]["NAME"] = test_database_name 35 | self.connection.settings_dict["NAME"] = test_database_name 36 | 37 | with self.connection.cursor() as cursor: 38 | for extension in ('hstore',): 39 | cursor.execute("CREATE EXTENSION IF NOT EXISTS %s" % extension) 40 | register_type_handlers(self.connection) 41 | 42 | # We report migrate messages at one level lower than that requested. 43 | # This ensures we don't get flooded with messages during testing 44 | # (unless you really ask to be flooded). 45 | call_command( 46 | 'migrate', 47 | verbosity=max(verbosity - 1, 0), 48 | interactive=False, 49 | database=self.connection.alias, 50 | run_syncdb=True, 51 | ) 52 | 53 | # We then serialize the current state of the database into a string 54 | # and store it on the connection. This slightly horrific process is so people 55 | # who are testing on databases without transactions or who are using 56 | # a TransactionTestCase still get a clean database on every test run. 57 | if serialize: 58 | self.connection._test_serialized_contents = self.serialize_db_to_string() 59 | 60 | call_command('createcachetable', database=self.connection.alias) 61 | 62 | # Ensure a connection for the side effect of initializing the test database. 
63 | self.connection.ensure_connection() 64 | 65 | return test_database_name 66 | -------------------------------------------------------------------------------- /django_postgres_extensions/models/sql/compiler.py: -------------------------------------------------------------------------------- 1 | from django.db.models.sql.compiler import (SQLCompiler, SQLInsertCompiler, SQLUpdateCompiler as BaseUpdateCompiler, 2 | SQLAggregateCompiler, SQLDeleteCompiler) 3 | from django.core.exceptions import FieldError 4 | 5 | def no_quote_name(name): 6 | return name 7 | 8 | class SQLUpdateCompiler(BaseUpdateCompiler): 9 | def as_sql(self): 10 | """ 11 | Creates the SQL for this query. Returns the SQL string and list of 12 | parameters. 13 | """ 14 | self.pre_sql_setup() 15 | if not self.query.values: 16 | return '', () 17 | table = self.query.base_table 18 | qn = self.quote_name_unless_alias 19 | result = ['UPDATE %s' % qn(table)] 20 | result.append('SET') 21 | values, update_params = [], [] 22 | for field, model, val in self.query.values: 23 | self.name = name = field.column 24 | if hasattr(val, 'alter_name'): 25 | self.name = name = val.alter_name(name, qn) 26 | qn = no_quote_name 27 | val = val.value 28 | if hasattr(val, 'resolve_expression'): 29 | val = val.resolve_expression(self.query, allow_joins=False, for_save=True) 30 | if val.contains_aggregate: 31 | raise FieldError("Aggregate functions are not allowed in this query") 32 | elif hasattr(val, 'prepare_database_save'): 33 | if field.remote_field: 34 | val = field.get_db_prep_save( 35 | val.prepare_database_save(field), 36 | connection=self.connection, 37 | ) 38 | else: 39 | raise TypeError( 40 | "Tried to update field %s with a model instance, %r. " 41 | "Use a value compatible with %s." 42 | % (field, val, field.__class__.__name__) 43 | ) 44 | else: 45 | val = field.get_db_prep_save(val, connection=self.connection) 46 | 47 | # Getting the placeholder for the field. 
48 | if hasattr(field, 'get_placeholder'): 49 | placeholder = field.get_placeholder(val, self, self.connection) 50 | else: 51 | placeholder = '%s' 52 | self.placeholder = placeholder 53 | if hasattr(val, 'as_sql'): 54 | sql, params = self.compile(val) 55 | values.append('%s = %s' % (qn(name), sql)) 56 | update_params.extend(params) 57 | elif val is not None: 58 | values.append('%s = %s' % (qn(name), placeholder)) 59 | update_params.append(val) 60 | else: 61 | values.append('%s = NULL' % qn(name)) 62 | if not values: 63 | return '', () 64 | result.append(', '.join(values)) 65 | where, params = self.compile(self.query.where) 66 | if where: 67 | result.append('WHERE %s' % where) 68 | return ' '.join(result), tuple(update_params + params) -------------------------------------------------------------------------------- /docs/hstores.rst: -------------------------------------------------------------------------------- 1 | HStoreField 2 | =========== 3 | 4 | Basic Usage 5 | ----------- 6 | To use the HStoreField:: 7 | 8 | from django.db import models 9 | from django_postgres_extensions.models.fields import HStoreField 10 | class Product(models.Model): 11 | description = HStoreField(null=True, blank=True) 12 | 13 | The customized Postgres HStoreField adds the following features: 14 | 15 | Individual keys 16 | --------------- 17 | 18 | - Get hstore values by key:: 19 | 20 | from django_postgres_extensions.models.expressions import Key, Keys 21 | obj = Product.objects.annotate(Key('description', 'Release')).get() 22 | obj = Product.objects.annotate(Keys('description', ['Industry', 'Release'])).get() 23 | 24 | - Update hstore by specific keys, leaving any others untouched:: 25 | 26 | Product.objects.update(description__ = {'Genre': 'Heavy Metal', 'Popularity': 'Very Popular'}) 27 | 28 | Database functions 29 | ------------------ 30 | 31 | Various database functions are included for interacting with HStores. 
32 | 33 | - Slice: Returns a dictionary with just the specified keys 34 | 35 | - Delete: Deletes a key or list of keys from the hstore. Keys can also be deleted by specifying a dictionary 36 | 37 | - AKeys: Returns the hstore keys as a list 38 | 39 | - AVals: Returns the hstore values as a list 40 | 41 | - HStoreToArray: Returns the hstore as an array 42 | 43 | - HStoreToMatrix: Returns the hstore as a matrix 44 | 45 | - HstoreToJSONB: Returns the hstore as JSONB; as with PostgreSQL's hstore_to_jsonb, every non-null value becomes a JSON string 46 | 47 | - HstoreToJSONBLoose: Same as HstoreToJSONB, but attempts to distinguish numerical and Boolean values so they are unquoted in the JSON 48 | 49 | For more information on these functions, check the PostgreSQL documentation for each one. 50 | These functions handle the arguments by converting them to the correct expressions automatically:: 51 | 52 | from django_postgres_extensions.models.functions import * 53 | obj = Product.objects.annotate(description_slice=Slice('description', ['Industry', 'Release'])).get() 54 | obj = Product.objects.update(description = Delete('description', 'Genre')) 55 | obj = Product.objects.update(description = Delete('description', ['Industry', 'Genre'])) 56 | Product.objects.update(description=Delete('description', {'Industry': 'Music', 'Release': 'Song', 'Genre': 'Rock'})) 57 | Product.objects.annotate(description_keys=AKeys('description')).get() 58 | Product.objects.annotate(description_values=AVals('description')).get() 59 | Product.objects.annotate(description_array=HStoreToArray('description')).get() 60 | Product.objects.annotate(description_matrix=HStoreToMatrix('description')).get() 61 | Product.objects.annotate(description_jsonb=HstoreToJSONB('description')).get() 62 | Product.objects.annotate(description_jsonb=HstoreToJSONBLoose('description')).get() 63 | 64 | Use With Nested Form Field 65 | -------------------------- 66 | 67 | django.contrib.postgres
 includes an HStoreField for forms, where you have to enter an hstore value programmatically. 68 | Django Postgres Extensions adds a NestedFormField and NestedFormWidget 69 | (subclassed from Django's MultiValueField and MultiWidget) for use with an HStoreField. 70 | To use them, specify either a list of fields or a list of keys as a keyword argument to the HStoreField model field, but not both:: 71 | 72 | class Product(models.Model): 73 | shipping = HStoreField(keys=('Address', 'City', 'Region', 'Country'), blank=True, default={}) 74 | 75 | The field would look like: 76 | 77 | .. image:: hstore_field.jpg -------------------------------------------------------------------------------- /django_postgres_extensions/admin/options.py: -------------------------------------------------------------------------------- 1 | from django.contrib.admin.options import ModelAdmin 2 | from django.utils.translation import string_concat, ugettext as _ 3 | from django.forms.widgets import CheckboxSelectMultiple, SelectMultiple 4 | from django_postgres_extensions.models import ArrayManyToManyField 5 | from django.contrib.admin import widgets 6 | 7 | class PostgresAdmin(ModelAdmin): 8 | 9 | def formfield_for_manytomany(self, db_field, request=None, **kwargs): 10 | """ 11 | Get a form Field for a ManyToManyField. 12 | """ 13 | # If it uses an intermediary model that isn't auto created, don't show 14 | # a field in admin.
15 | if hasattr(db_field, 'through') and not db_field.remote_field.through._meta.auto_created: 16 | return None 17 | db = kwargs.get('using') 18 | 19 | if db_field.name in self.raw_id_fields: 20 | kwargs['widget'] = widgets.ManyToManyRawIdWidget(db_field.remote_field, 21 | self.admin_site, using=db) 22 | kwargs['help_text'] = '' 23 | elif db_field.name in (list(self.filter_vertical) + list(self.filter_horizontal)): 24 | kwargs['widget'] = widgets.FilteredSelectMultiple( 25 | db_field.verbose_name, 26 | db_field.name in self.filter_vertical 27 | ) 28 | 29 | if 'queryset' not in kwargs: 30 | queryset = self.get_field_queryset(db, db_field, request) 31 | if queryset is not None: 32 | kwargs['queryset'] = queryset 33 | 34 | form_field = db_field.formfield(**kwargs) 35 | if isinstance(form_field.widget, SelectMultiple) and not isinstance(form_field.widget, CheckboxSelectMultiple): 36 | msg = _('Hold down "Control", or "Command" on a Mac, to select more than one.') 37 | help_text = form_field.help_text 38 | form_field.help_text = string_concat(help_text, ' ', msg) if help_text else msg 39 | return form_field 40 | 41 | def formfield_for_dbfield(self, db_field, request, **kwargs): 42 | 43 | # ForeignKey or ManyToManyFields 44 | if isinstance(db_field, ArrayManyToManyField): 45 | # Combine the field kwargs with any options for formfield_overrides. 46 | # Make sure the passed in **kwargs override anything in 47 | # formfield_overrides because **kwargs is more specific, and should 48 | # always win. 49 | if db_field.__class__ in self.formfield_overrides: 50 | kwargs = dict(self.formfield_overrides[db_field.__class__], **kwargs) 51 | 52 | formfield = self.formfield_for_manytomany(db_field, request, **kwargs) 53 | 54 | # For non-raw_id fields, wrap the widget with a wrapper that adds 55 | # extra HTML -- the "add other" interface -- to the end of the 56 | # rendered output. formfield can be None if it came from a 57 | # OneToOneField with parent_link=True or a M2M intermediary. 
58 | if formfield and db_field.name not in self.raw_id_fields: 59 | related_modeladmin = self.admin_site._registry.get(db_field.remote_field.model) 60 | wrapper_kwargs = {} 61 | if related_modeladmin: 62 | wrapper_kwargs.update( 63 | can_add_related=related_modeladmin.has_add_permission(request), 64 | can_change_related=related_modeladmin.has_change_permission(request), 65 | can_delete_related=related_modeladmin.has_delete_permission(request), 66 | ) 67 | formfield.widget = widgets.RelatedFieldWidgetWrapper( 68 | formfield.widget, db_field.remote_field, self.admin_site, **wrapper_kwargs 69 | ) 70 | 71 | return formfield 72 | else: 73 | return super(PostgresAdmin, self).formfield_for_dbfield(db_field, request, **kwargs) -------------------------------------------------------------------------------- /django_postgres_extensions/models/expressions.py: -------------------------------------------------------------------------------- 1 | from django.db.models.expressions import F as BaseF, Value as BaseValue, Func, Expression 2 | from django.utils import six 3 | from django.contrib.postgres.fields.array import IndexTransform 4 | from django.utils.functional import cached_property 5 | from django.db.models.lookups import Transform 6 | 7 | class OperatorMixin(object): 8 | CAT = '||' 9 | REPLACE = '#=' 10 | DELETE = '#-' 11 | KEY = '->' 12 | KEYTEXT = '->>' 13 | PATH = '#>' 14 | PATHTEXT = '#>>' 15 | 16 | def cat(self, other): 17 | return self._combine(other, self.CAT, False) 18 | 19 | def replace(self, other): 20 | return self._combine(other, self.REPLACE, False) 21 | 22 | def delete(self, other): 23 | return self._combine(other, self.DELETE, False) 24 | 25 | def key(self, other): 26 | return self._combine(other, self.KEY, False) 27 | 28 | def keytext(self, other): 29 | return self._combine(other, self.KEYTEXT, False) 30 | 31 | def path(self, other): 32 | return self._combine(other, self.PATH, False) 33 | 34 | def pathtext(self, other): 35 | return self._combine(other, 
self.PATHTEXT, False) 36 | 37 | class F(BaseF, OperatorMixin): 38 | pass 39 | 40 | class Value(BaseValue, OperatorMixin): 41 | def as_sql(self, compiler, connection): 42 | if self._output_field_or_none and any(self._output_field_or_none.get_internal_type() == fieldname for fieldname in 43 | ['ArrayField', 'MultiReferenceArrayField']): 44 | base_field = self._output_field_or_none.base_field 45 | return '%s::%s[]' % ('%s', base_field.db_type(connection)), [self.value] 46 | return super(Value, self).as_sql(compiler, connection) 47 | 48 | class Index(IndexTransform): 49 | def __init__(self, field, index, *args, **kwargs): 50 | if not isinstance(field, Expression): 51 | field = F(field) 52 | super(Index, self).__init__(index + 1, None, field, *args, **kwargs) 53 | 54 | @cached_property 55 | def default_alias(self): 56 | return '%s__%s' % (self.lhs.name, self.index - 1) 57 | 58 | @property 59 | def name(self): 60 | return self.default_alias 61 | 62 | @property 63 | def output_field(self): 64 | return self.lhs.field.base_field 65 | 66 | class SliceArray(Transform): 67 | def __init__(self, field, *indexes, **kwargs): 68 | if isinstance(field, SliceArray): 69 | self.multidimensional = True 70 | else: 71 | field = F(field) 72 | self.multidimensional = False 73 | self.indexes = [i + 1 for i in indexes] 74 | super(SliceArray, self).__init__(field, **kwargs) 75 | 76 | def as_sql(self, compiler, connection): 77 | lhs, params = compiler.compile(self.lhs) 78 | return '%s[%s:%s]' % (lhs, self.indexes[0], self.indexes[1]), params 79 | 80 | @cached_property 81 | def default_alias(self): 82 | return '%s__%s_%s' % (self.lhs.name, self.indexes[0] - 1, self.indexes[1] - 1) 83 | 84 | @property 85 | def name(self): 86 | return self.default_alias 87 | 88 | @property 89 | def output_field(self): 90 | if self.multidimensional: 91 | return self.lhs.field 92 | return self.lhs.field.base_field 93 | 94 | 95 | def Key(field, keys_string): 96 | if isinstance(keys_string, six.string_types) and '__' 
 in keys_string: 97 | keys = keys_string.split('__')  # 'a__b' -> ['a', 'b'], used with the #> path operator 98 | expression = F(field).path(Value(keys)) 99 | else: 100 | expression = F(field).key(Value(keys_string)) 101 | expression.default_alias = "%s__%s" % (field, keys_string) 102 | return expression 103 | 104 | def Keys(field, keys): 105 | expression = F(field).key(Value(keys)) 106 | expression.default_alias = "%s__selected" % field 107 | return expression 108 | -------------------------------------------------------------------------------- /django_postgres_extensions/models/lookups.py: -------------------------------------------------------------------------------- 1 | from django.contrib.postgres.fields.array import ArrayField, ArrayContains 2 | from django.contrib.postgres.fields.jsonb import JSONField 3 | from django.db.models.lookups import BuiltinLookup, In, \ 4 | Contains, StartsWith, EndsWith 5 | 6 | class BaseAnyAllLookupMixin(object): 7 | def get_rhs_op(self, connection, rhs):  # e.g. 'any_gte' -> connection.any_operators['gte'] % rhs; a bare 'any'/'all' maps to 'exact' 8 | if self.lookup_name == self.db_func: 9 | lookup_name = 'exact' 10 | else: 11 | lookup_name = self.lookup_name.split('%s_' % self.db_func)[1]  # IndexError if lookup_name does not start with '<db_func>_' 12 | operators = getattr(connection, '%s_operators' % self.db_func) 13 | return operators[lookup_name] % rhs 14 | 15 | def as_sql(self, compiler, connection): 16 | rhs_sql, rhs_params = self.process_lhs(compiler, connection)  # names swapped on purpose: the array column (lhs) is rendered inside the operator template, to the right of the lookup value 17 | lhs_sql, params = self.process_rhs(compiler, connection) 18 | params.extend(rhs_params) 19 | rhs_sql = self.get_rhs_op(connection, rhs_sql) 20 | return '%s %s' % (lhs_sql, rhs_sql), params 21 | 22 | class AnyLookupMixin(BaseAnyAllLookupMixin, BuiltinLookup): 23 | db_func = 'any' 24 | 25 | class AllLookupMixin(BaseAnyAllLookupMixin, BuiltinLookup): 26 | db_func = 'all' 27 | 28 | @ArrayField.register_lookup 29 | class Any(AnyLookupMixin, BuiltinLookup): 30 | lookup_name = 'any' 31 | 32 | @ArrayField.register_lookup 33 | class AnyExact(AnyLookupMixin, BuiltinLookup): 34 | lookup_name = 'any_exact' 35 | 36 | @ArrayField.register_lookup 37 | class
AnyGreaterThan(AnyLookupMixin, BuiltinLookup): 38 | lookup_name = 'any_gt' 39 | 40 | @ArrayField.register_lookup 41 | class AnyGreaterThanOrEqual(AnyLookupMixin, BuiltinLookup): 42 | lookup_name = 'any_gte' 43 | 44 | @ArrayField.register_lookup 45 | class AnyLessThan(AnyLookupMixin, BuiltinLookup): 46 | lookup_name = 'any_lt' 47 | 48 | @ArrayField.register_lookup 49 | class AnyLessThanOrEqual(AnyLookupMixin, BuiltinLookup): 50 | lookup_name = 'any_lte' 51 | 52 | @ArrayField.register_lookup 53 | class AnyLessThanOrEqual(AnyLookupMixin, Contains): 54 | lookup_name = 'any_in' 55 | 56 | @ArrayField.register_lookup 57 | class AnyStartOf(AnyLookupMixin, StartsWith): 58 | lookup_name = 'any_isstartof' 59 | 60 | @ArrayField.register_lookup 61 | class AnyEndOf(AnyLookupMixin, EndsWith): 62 | lookup_name = 'any_isendof' 63 | 64 | @ArrayField.register_lookup 65 | class All(AllLookupMixin): 66 | lookup_name = 'all' 67 | 68 | @ArrayField.register_lookup 69 | class AllExact(AllLookupMixin, BuiltinLookup): 70 | lookup_name = 'all_exact' 71 | 72 | @ArrayField.register_lookup 73 | class AllGreaterThan(AllLookupMixin, BuiltinLookup): 74 | lookup_name = 'all_gt' 75 | 76 | @ArrayField.register_lookup 77 | class AllGreaterThanOrEqual(AllLookupMixin, BuiltinLookup): 78 | lookup_name = 'all_gte' 79 | 80 | @ArrayField.register_lookup 81 | class AllLessThan(AllLookupMixin, BuiltinLookup): 82 | lookup_name = 'all_lt' 83 | 84 | @ArrayField.register_lookup 85 | class AllLessThanOrEqual(AllLookupMixin, BuiltinLookup): 86 | lookup_name = 'all_lte' 87 | 88 | @ArrayField.register_lookup 89 | class AllIn(AllLookupMixin, Contains): 90 | lookup_name = 'all_in' 91 | 92 | @ArrayField.register_lookup 93 | class AllStartOf(AllLookupMixin, StartsWith): 94 | lookup_name = 'all_isstartof' 95 | 96 | @ArrayField.register_lookup 97 | class AllEndOf(AnyLookupMixin, EndsWith): 98 | lookup_name = 'all_isendof' 99 | 100 | @ArrayField.register_lookup 101 | class AnyRegex(AnyLookupMixin, EndsWith): 102 | 
lookup_name = 'all_regex' 103 | 104 | @ArrayField.register_lookup 105 | class AnyContains(AnyLookupMixin, ArrayContains): 106 | lookup_name = 'any_contains' 107 | 108 | class ContainsItem(ArrayContains): 109 | lookup_name = 'contains' 110 | def __init__(self, lhs, rhs): 111 | if not isinstance(rhs, (list, tuple)): 112 | rhs = [rhs] 113 | super(ContainsItem, self).__init__(lhs, rhs) -------------------------------------------------------------------------------- /django_postgres_extensions/forms/widgets.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals 2 | 3 | from django.forms import widgets 4 | from django.utils.safestring import mark_safe 5 | from django.utils.html import format_html 6 | import copy 7 | 8 | class NestedFormWidget(widgets.MultiWidget): 9 | """ 10 | A widget that is composed of multiple widgets with labels in a list. 11 | 12 | Its render() method differs from the MultiWidget in that it adds support for labels 13 | and presents the widgets in a list. 14 | 15 | For initial values, the decompress method expects a dictionary of key names and values. 16 | The widget's value_from_datadict method returns an array of values. 17 | 18 | Its render() method differs from the MultiWidget in that it adds support for labels 19 | and presents the widgets in a list. It deals with widget labels which may be different 20 | from the key names for that widget. 21 | 22 | The ``labels`` argument is a list/tuple of label names which will also be used for css names and ids. 23 | The ``widgets`` arguments is a list of widgets. Labels will be matched to widget by index. 24 | The optional ``names`` argument is a dictionary of labels to names. The name refers to the key name in the 25 | python dictionary given to the decompress method. If this argument is not given, labels and key names 26 | will be assumed to be the same. 27 | 28 | You'll probably want to use this class with NestedFormField. 
29 | """ 30 | template_name = "django_postgres_extensions/nested_form_widget.html" 31 | 32 | def __init__(self, labels, widgets, names=None, attrs=None, template_name=None): 33 | self.template_name = template_name or self.template_name 34 | self.labels = labels 35 | if names: 36 | self.names = names 37 | else: 38 | self.names = {label: label for label in labels} 39 | self.id_names = [label.lower().replace(" ", "") for label in labels] 40 | super(NestedFormWidget, self).__init__(widgets, attrs=attrs) 41 | 42 | def render(self, name, value, attrs=None, renderer=None): 43 | if self.is_localized: 44 | for widget in self.widgets: 45 | widget.is_localized = self.is_localized 46 | # value is a list of values, each corresponding to a widget 47 | # in self.widgets. 48 | if not isinstance(value, list): 49 | value = self.decompress(value) 50 | final_attrs = self.build_attrs(attrs or {}) 51 | base_id_ = final_attrs.get('id') 52 | for i, widget in enumerate(self.widgets): 53 | try: 54 | widget_value = value[i] 55 | except IndexError: 56 | widget_value = None 57 | widget.label = self.labels[i] 58 | id_name = self.id_names[i] 59 | if base_id_: 60 | id_ = '%s_%s' % (base_id_, id_name) 61 | final_attrs = dict(final_attrs, id=id_) 62 | if widget.id_for_label: 63 | widget.label_for = id_ 64 | else: 65 | widget.label_for = '' 66 | else: 67 | widget.label_for = '' 68 | widget.html = mark_safe(widget.render('%s_%s' % (name, id_name), widget_value, final_attrs)) 69 | context = {'widgets': self.widgets} 70 | return self._render(self.template_name, context, renderer) 71 | 72 | def value_from_datadict(self, data, files, name): 73 | return [widget.value_from_datadict(data, files, "%s_%s" % (name, self.id_names[i])) for i, widget in 74 | enumerate(self.widgets)] 75 | 76 | def decompress(self, value): 77 | if not value: 78 | return [] 79 | values = [value[self.names[label]] for label in self.labels] 80 | return values 81 | 82 | def value_omitted_from_data(self, data, files, name): 83 | return 
False 84 | 85 | def __deepcopy__(self, memo): 86 | obj = super(NestedFormWidget, self).__deepcopy__(memo) 87 | obj.labels = copy.deepcopy(self.labels) 88 | return obj -------------------------------------------------------------------------------- /docs/arrays.rst: -------------------------------------------------------------------------------- 1 | ArrayField 2 | ========== 3 | 4 | Basic Usage 5 | ----------- 6 | 7 | To use the ArrayField:: 8 | 9 | from django.db import models 10 | from django_postgres_extensions.models.fields import ArrayField 11 | class Product(models.Model): 12 | tags = ArrayField(models.CharField(max_length=15), null=True, blank=True) 13 | moretags = ArrayField(models.CharField(max_length=15), null=True, blank=True) 14 | 15 | Array Indexes 16 | ------------- 17 | 18 | - Get array values by index:: 19 | 20 | from django_postgres_extensions.models.expressions import Index, SliceArray 21 | obj = Product.objects.annotate(Index('tags', 1)).get() 22 | print(obj.tags__1) 23 | obj = Product.objects.annotate(tag_1=Index('tags', 1)).get() 24 | print(obj.tag_1) 25 | obj = Product.objects.annotate(SliceArray('tags', 0, 1)).get() 26 | print(obj.tags__0_1) 27 | 28 | - Update array values by index:: 29 | 30 | Product.objects.update(tags__2='Heavy Metal') 31 | 32 | Database Functions 33 | ------------------ 34 | 35 | Various database functions are included for manipulating arrays: 36 | 37 | - ArrayLength: returns the length of an array 38 | 39 | - ArrayPosition: the position of an item in an array 40 | 41 | - ArrayPositions: all positions of an item in an array 42 | 43 | - ArrayAppend: Create an array value by adding a value to the end of an array field 44 | 45 | - ArrayPrepend: Create an array value by adding a value to the start of an array field 46 | 47 | - ArrayRemove: Create an array value by removing a value from an array field 48 | 49 | - ArrayReplace: Create an array value by replacing one value with another in an array field 50 | 51 | - ArrayCat: 
 Concatenate an ArrayField with another ArrayField or with a list of values 52 | 53 | For more information on each of these functions, check the PostgreSQL documentation. 54 | The provided arguments to each function are automatically converted to the required expressions:: 55 | 56 | from django_postgres_extensions.models.functions import * 57 | obj = Product.objects.annotate(tags_length=ArrayLength('tags', 1)).get() 58 | obj = Product.objects.annotate(position=ArrayPosition('tags', 'Rock')).get() 59 | obj = Product.objects.annotate(positions=ArrayPositions('tags', 'Rock')).get() 60 | Product.objects.update(tags = ArrayAppend('tags', 'Popular')) 61 | Product.objects.update(tags = ArrayPrepend('Popular', 'tags')) 62 | Product.objects.update(tags = ArrayRemove('tags', 'Album')) 63 | Product.objects.update(tags = ArrayReplace('tags', 'Rock', 'Heavy Metal')) 64 | Product.objects.update(tags = ArrayCat('tags', 'moretags')) 65 | Product.objects.update(tags=ArrayCat('tags', ['Popular', '8'], output_field=Product._meta.get_field('tags'))) 66 | 67 | 68 | Use in ModelForms 69 | ----------------- 70 | 71 | django.contrib.postgres includes two possible array form fields: SimpleArrayField (the default) and SplitArrayField. 72 | To use the SplitArrayField automatically when generating a ModelForm, add the form_size keyword argument to the ArrayField:: 73 | 74 | class Product(models.Model): 75 | keywords = ArrayField(models.CharField(max_length=20), default=[], form_size=10, blank=True) 76 | 77 | The field would look like: 78 | 79 | .. image:: array_split.jpg 80 | 81 | Alternatively, it is possible to use a multiple-choice form field for an ArrayField, by specifying a choices argument:: 82 | 83 | sports = ArrayField(models.CharField(max_length=20),default=[], blank=True, choices=( 84 | ('football', 'Football'), ('tennis', 'Tennis'), ('golf', 'Golf'), ('basketball', 'Basketball'), ('hurling', 'Hurling'), ('baseball', 'Baseball'))) 85 | 86 | 87 | The field would look like: 88 | 89 | ..
 image:: array_choice.jpg 90 | 91 | Array Lookups 92 | ------------- 93 | 94 | Additional lookups have been added to the ArrayField to enable queries using PostgreSQL's ANY and ALL array comparisons:: 95 | 96 | qs = Product.objects.filter(tags__any = 'Popular') 97 | qs = Product.objects.filter(tags__all_isstartof = 'Popular') 98 | 99 | Any lookups check if any value in the array meets the lookup criteria. 100 | All lookups check if all values in an array meet the lookup criteria. 101 | The full list of additional lookups is: 102 | 103 | - any 104 | - any_exact 105 | - any_gt 106 | - any_gte 107 | - any_lt 108 | - any_lte 109 | - any_in 110 | - any_isstartof 111 | - any_isendof 112 | - any_contains (for 2d arrays) 113 | - all 114 | - all_exact 115 | - all_gt 116 | - all_gte 117 | - all_lt 118 | - all_lte 119 | - all_in 120 | - all_isstartof 121 | - all_isendof 122 | - all_regex -------------------------------------------------------------------------------- /tests/prefetch_related_array/test_uuid.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals 2 | 3 | from django.test import TestCase 4 | 5 | from .models import Flea, House, Person, Pet, Room 6 | 7 | 8 | class UUIDPrefetchRelated(TestCase): 9 | 10 | def test_prefetch_related_from_uuid_model(self): 11 | Pet.objects.create(name='Fifi').people.add( 12 | Person.objects.create(name='Ellen'), 13 | Person.objects.create(name='George'), 14 | ) 15 | 16 | with self.assertNumQueries(2): 17 | pet = Pet.objects.prefetch_related('people').get(name='Fifi') 18 | with self.assertNumQueries(0): 19 | self.assertEqual(2, len(pet.people.all())) 20 | 21 | def test_prefetch_related_to_uuid_model(self): 22 | Person.objects.create(name='Bella').pets.add( 23 | Pet.objects.create(name='Socks'), 24 | Pet.objects.create(name='Coffee'), 25 | ) 26 | 27 | with self.assertNumQueries(2): 28 | person = Person.objects.prefetch_related('pets').get(name='Bella') 29 | with
 self.assertNumQueries(0): 30 | self.assertEqual(2, len(person.pets.all())) 31 | 32 | def test_prefetch_related_from_uuid_model_to_uuid_model(self): 33 | fleas = [Flea.objects.create() for i in range(3)] 34 | Pet.objects.create(name='Fifi').fleas_hosted.add(*fleas) 35 | Pet.objects.create(name='Bobo').fleas_hosted.add(*fleas) 36 | 37 | with self.assertNumQueries(2): 38 | pet = Pet.objects.prefetch_related('fleas_hosted').get(name='Fifi') 39 | with self.assertNumQueries(0): 40 | self.assertEqual(3, len(pet.fleas_hosted.all())) 41 | 42 | with self.assertNumQueries(2): 43 | flea = Flea.objects.prefetch_related('pets_visited').get(pk=fleas[0].pk) 44 | with self.assertNumQueries(0): 45 | self.assertEqual(2, len(flea.pets_visited.all())) 46 | 47 | def test_prefetch_related_from_uuid_model_to_uuid_model_with_values_flat(self): 48 | pet = Pet.objects.create(name='Fifi') 49 | pet.people.add( 50 | Person.objects.create(name='Ellen'), 51 | Person.objects.create(name='George'), 52 | ) 53 | self.assertSequenceEqual( 54 | Pet.objects.prefetch_related('fleas_hosted').values_list('id', flat=True), 55 | [pet.id] 56 | ) 57 | 58 | 59 | class UUIDPrefetchRelatedLookups(TestCase): 60 | 61 | @classmethod 62 | def setUpTestData(cls): 63 | house = House.objects.create(name='Redwood', address='Arcata') 64 | room = Room.objects.create(name='Racoon', house=house) 65 | fleas = [Flea.objects.create(current_room=room) for i in range(3)] 66 | pet = Pet.objects.create(name='Spooky') 67 | pet.fleas_hosted.add(*fleas) 68 | person = Person.objects.create(name='Bob') 69 | person.houses.add(house) 70 | person.pets.add(pet) 71 | person.fleas_hosted.add(*fleas) 72 | 73 | def test_from_uuid_pk_lookup_uuid_pk_integer_pk(self): 74 | # From uuid-pk model, prefetch <uuid-pk model>.<integer-pk model>: 75 | with self.assertNumQueries(4): 76 | spooky = Pet.objects.prefetch_related('fleas_hosted__current_room__house').get(name='Spooky') 77 | with self.assertNumQueries(0): 78 | self.assertEqual('Racoon',
 spooky.fleas_hosted.all()[0].current_room.name) 79 | 80 | def test_from_uuid_pk_lookup_integer_pk2_uuid_pk2(self): 81 | # From uuid-pk model, prefetch <integer-pk model>.<integer-pk model>.<uuid-pk model>.<uuid-pk model>: 82 | with self.assertNumQueries(5): 83 | spooky = Pet.objects.prefetch_related('people__houses__rooms__fleas').get(name='Spooky') 84 | with self.assertNumQueries(0): 85 | self.assertEqual(3, len(spooky.people.all()[0].houses.all()[0].rooms.all()[0].fleas.all())) 86 | 87 | def test_from_integer_pk_lookup_uuid_pk_integer_pk(self): 88 | # From integer-pk model, prefetch <uuid-pk model>.<integer-pk model>: 89 | with self.assertNumQueries(3): 90 | racoon = Room.objects.prefetch_related('fleas__people_visited').get(name='Racoon') 91 | with self.assertNumQueries(0): 92 | self.assertEqual('Bob', racoon.fleas.all()[0].people_visited.all()[0].name) 93 | 94 | def test_from_integer_pk_lookup_integer_pk_uuid_pk(self): 95 | # From integer-pk model, prefetch <integer-pk model>.<uuid-pk model>: 96 | with self.assertNumQueries(3): 97 | redwood = House.objects.prefetch_related('rooms__fleas').get(name='Redwood') 98 | with self.assertNumQueries(0): 99 | self.assertEqual(3, len(redwood.rooms.all()[0].fleas.all())) 100 | 101 | def test_from_integer_pk_lookup_integer_pk_uuid_pk_uuid_pk(self): 102 | # From integer-pk model, prefetch <integer-pk model>.<uuid-pk model>.<uuid-pk model>: 103 | with self.assertNumQueries(4): 104 | redwood = House.objects.prefetch_related('rooms__fleas__pets_visited').get(name='Redwood') 105 | with self.assertNumQueries(0): 106 | self.assertEqual('Spooky', redwood.rooms.all()[0].fleas.all()[0].pets_visited.all()[0].name) 107 | -------------------------------------------------------------------------------- /django_postgres_extensions/models/functions.py: -------------------------------------------------------------------------------- 1 | from django.db.models.expressions import Func, Expression 2 | from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE 3 | from django.utils import six 4 | from .expressions import F, Value as V 5 | 6 | class SimpleFunc(Func): 7 | 8 | def __init__(self, field, *values,
**extra): 9 | if not isinstance(field, Expression): 10 | field = F(field) 11 | if values and not isinstance(values[0], Expression): 12 | values = [V(v) for v in values] 13 | super(SimpleFunc, self).__init__(field, *values, **extra) 14 | 15 | class TooManyExpressionsError(Exception): 16 | pass 17 | 18 | def multi_func(func, expression, *args): 19 | if len(args) > GET_ITERATOR_CHUNK_SIZE: 20 | raise TooManyExpressionsError('Multi-func given %s args. The limit is %s due to Python recursion depth risk' % ( 21 | len(args), GET_ITERATOR_CHUNK_SIZE)) 22 | args = list(args) 23 | initial_arg = args.pop(0) 24 | query = func(expression, initial_arg) 25 | for arg in args: 26 | query = func(query, arg) 27 | return query 28 | 29 | def multi_array_remove(field, *args): 30 | return multi_func(ArrayRemove, field, *args) 31 | 32 | class ArrayAppend(SimpleFunc): 33 | function = 'ARRAY_APPEND' 34 | 35 | class ArrayPrepend(Func): 36 | function = 'ARRAY_PREPEND' 37 | 38 | def __init__(self, value, field, **extra): 39 | if not isinstance(value, Expression): 40 | value = V(value) 41 | field = F(field) 42 | super(ArrayPrepend, self).__init__(value, field, **extra) 43 | 44 | class ArrayRemove(SimpleFunc): 45 | function = 'ARRAY_REMOVE' 46 | 47 | class ArrayReplace(SimpleFunc): 48 | function = 'ARRAY_REPLACE' 49 | 50 | class ArrayPosition(SimpleFunc): 51 | function = 'ARRAY_POSITION' 52 | 53 | class ArrayPositions(SimpleFunc): 54 | function = 'ARRAY_POSITIONS' 55 | 56 | class ArrayCat(Func): 57 | function = 'ARRAY_CAT' 58 | 59 | def __init__(self, field, value, prepend=False, output_field=None, **extra): 60 | if not isinstance(field, Expression): 61 | field = F(field) 62 | if not isinstance(value, Expression): 63 | if isinstance(value, six.string_types): 64 | value = F(value) 65 | elif output_field: 66 | value = V(value, output_field = output_field) 67 | else: 68 | value = V(value) 69 | if prepend: 70 | super(ArrayCat, self).__init__(value, field, **extra) 71 | else: 72 | super(ArrayCat, 
self).__init__(field, value, **extra) 73 | 74 | class ArrayLength(SimpleFunc): 75 | function = 'ARRAY_LENGTH' 76 | 77 | class ArrayDims(SimpleFunc): 78 | function = 'ARRAY_DIMS' 79 | 80 | class ArrayUpper(SimpleFunc): 81 | function = 'ARRAY_UPPER' 82 | 83 | class ArrayLower(SimpleFunc): 84 | function = 'ARRAY_LOWER' 85 | 86 | class Cardinality(SimpleFunc): 87 | function = 'CARDINALITY' 88 | 89 | class NonFieldFunc(Func): 90 | def __init__(self, *values, **extra): 91 | values = list(values) 92 | for i, value in enumerate(values): 93 | if not isinstance(value, Expression): 94 | values[i] = V(value) 95 | super(NonFieldFunc, self).__init__(*values, **extra) 96 | 97 | class HStore(NonFieldFunc): 98 | function = 'HSTORE' 99 | 100 | class AKeys(SimpleFunc): 101 | function = 'AKEYS' 102 | 103 | class SKeys(SimpleFunc): 104 | function = 'SKEYS' 105 | 106 | class AVals(SimpleFunc): 107 | function = 'AVALS' 108 | 109 | class SVals(SimpleFunc): 110 | function = 'SVALS' 111 | 112 | class HStoreToArray(SimpleFunc): 113 | function = 'HSTORE_TO_ARRAY' 114 | 115 | class HStoreToMatrix(SimpleFunc): 116 | function = 'HSTORE_TO_MATRIX' 117 | 118 | class Slice(SimpleFunc): 119 | function = 'SLICE' 120 | 121 | class Delete(SimpleFunc): 122 | function = 'DELETE' 123 | 124 | class Each(SimpleFunc): 125 | function = 'EACH' 126 | 127 | class HstoreToJSONB(SimpleFunc): 128 | function = 'HSTORE_TO_JSONB' 129 | 130 | class HstoreToJSONBLoose(SimpleFunc): 131 | function = 'HSTORE_TO_JSONB_LOOSE' 132 | 133 | class ToJSONB(NonFieldFunc): 134 | function = 'TO_JSONB' 135 | 136 | class RowToJSON(SimpleFunc): 137 | function = 'ROW_TO_JSON' 138 | 139 | class ArrayToJSON(SimpleFunc): 140 | function = 'ARRAY_TO_JSON' 141 | 142 | class JSONBBuildArray(NonFieldFunc): 143 | function = 'JSONB_BUILD_ARRAY' 144 | 145 | class JSONBArrayElements(SimpleFunc): 146 | function = 'JSONB_ARRAY_ELEMENTS' 147 | 148 | class JSONBBuildObject(NonFieldFunc): 149 | function = 'JSONB_BUILD_OBJECT' 150 | 151 | class 
JSONBObject(NonFieldFunc): 152 | function = 'JOSNB_OBJECT' 153 | 154 | class JSONBSet(SimpleFunc): 155 | function = 'JSONB_SET' 156 | 157 | class JSONBArrayLength(SimpleFunc): 158 | function = 'JSONB_ARRAY_length' 159 | 160 | class JSONBPretty(SimpleFunc): 161 | function = 'JSONB_PRETTY' 162 | 163 | class JSONObjectKeys(SimpleFunc): 164 | function = 'JSON_OBJECT_KEYS' 165 | 166 | class JSONStripNulls(SimpleFunc): 167 | function = 'JSON_STRIP_NULLS' 168 | 169 | class JSONTypeOf(SimpleFunc): 170 | function = 'JSON_TYPE_OF' -------------------------------------------------------------------------------- /tests/prefetch_related_array/test_prefetch_related_objects.py: -------------------------------------------------------------------------------- 1 | from django.db.models import Prefetch 2 | from django.db.models.query import prefetch_related_objects 3 | from django.test import TestCase 4 | 5 | from .models import Author, Book, Reader 6 | 7 | 8 | class PrefetchRelatedObjectsTests(TestCase): 9 | """ 10 | Since prefetch_related_objects() is just the inner part of 11 | prefetch_related(), only do basic tests to ensure its API hasn't changed. 
12 | """ 13 | @classmethod 14 | def setUpTestData(cls): 15 | cls.book1 = Book.objects.create(title='Poems') 16 | cls.book2 = Book.objects.create(title='Jane Eyre') 17 | cls.book3 = Book.objects.create(title='Wuthering Heights') 18 | cls.book4 = Book.objects.create(title='Sense and Sensibility') 19 | 20 | cls.author1 = Author.objects.create(name='Charlotte', first_book=cls.book1) 21 | cls.author2 = Author.objects.create(name='Anne', first_book=cls.book1) 22 | cls.author3 = Author.objects.create(name='Emily', first_book=cls.book1) 23 | cls.author4 = Author.objects.create(name='Jane', first_book=cls.book4) 24 | 25 | cls.book1.authors.add(cls.author1, cls.author2, cls.author3) 26 | cls.book2.authors.add(cls.author1) 27 | cls.book3.authors.add(cls.author3) 28 | cls.book4.authors.add(cls.author4) 29 | 30 | cls.reader1 = Reader.objects.create(name='Amy') 31 | cls.reader2 = Reader.objects.create(name='Belinda') 32 | 33 | cls.reader1.books_read.add(cls.book1, cls.book4) 34 | cls.reader2.books_read.add(cls.book2, cls.book4) 35 | 36 | def test_unknown(self): 37 | book1 = Book.objects.get(id=self.book1.id) 38 | with self.assertRaises(AttributeError): 39 | prefetch_related_objects([book1], 'unknown_attribute') 40 | 41 | def test_m2m_forward(self): 42 | book1 = Book.objects.get(id=self.book1.id) 43 | with self.assertNumQueries(1): 44 | prefetch_related_objects([book1], 'authors') 45 | 46 | with self.assertNumQueries(0): 47 | self.assertEqual(set(book1.authors.all()), {self.author1, self.author2, self.author3}) 48 | 49 | def test_m2m_reverse(self): 50 | author1 = Author.objects.get(id=self.author1.id) 51 | with self.assertNumQueries(1): 52 | prefetch_related_objects([author1], 'books') 53 | 54 | with self.assertNumQueries(0): 55 | self.assertEqual(set(author1.books.all()), {self.book1, self.book2}) 56 | 57 | def test_foreignkey_forward(self): 58 | authors = list(Author.objects.all()) 59 | with self.assertNumQueries(1): 60 | prefetch_related_objects(authors, 'first_book') 61 | 62 
| with self.assertNumQueries(0): 63 | [author.first_book for author in authors] 64 | 65 | def test_foreignkey_reverse(self): 66 | books = list(Book.objects.all()) 67 | with self.assertNumQueries(1): 68 | prefetch_related_objects(books, 'first_time_authors') 69 | 70 | with self.assertNumQueries(0): 71 | [list(book.first_time_authors.all()) for book in books] 72 | 73 | def test_m2m_then_m2m(self): 74 | """ 75 | We can follow a m2m and another m2m. 76 | """ 77 | authors = list(Author.objects.all()) 78 | with self.assertNumQueries(2): 79 | prefetch_related_objects(authors, 'books__read_by') 80 | 81 | with self.assertNumQueries(0): 82 | self.assertEqual( 83 | [ 84 | [[str(r) for r in b.read_by.all()] for b in a.books.all()] 85 | for a in authors 86 | ], 87 | [ 88 | [['Amy'], ['Belinda']], # Charlotte - Poems, Jane Eyre 89 | [['Amy']], # Anne - Poems 90 | [['Amy'], []], # Emily - Poems, Wuthering Heights 91 | [['Amy', 'Belinda']], # Jane - Sense and Sense 92 | ] 93 | ) 94 | 95 | def test_prefetch_object(self): 96 | book1 = Book.objects.get(id=self.book1.id) 97 | with self.assertNumQueries(1): 98 | prefetch_related_objects([book1], Prefetch('authors')) 99 | 100 | with self.assertNumQueries(0): 101 | self.assertEqual(set(book1.authors.all()), {self.author1, self.author2, self.author3}) 102 | 103 | def test_prefetch_object_to_attr(self): 104 | book1 = Book.objects.get(id=self.book1.id) 105 | with self.assertNumQueries(1): 106 | prefetch_related_objects([book1], Prefetch('authors', to_attr='the_authors')) 107 | 108 | with self.assertNumQueries(0): 109 | self.assertEqual(set(book1.the_authors), {self.author1, self.author2, self.author3}) 110 | 111 | def test_prefetch_queryset(self): 112 | book1 = Book.objects.get(id=self.book1.id) 113 | with self.assertNumQueries(1): 114 | prefetch_related_objects( 115 | [book1], 116 | Prefetch('authors', queryset=Author.objects.filter(id__in=[self.author1.id, self.author2.id])) 117 | ) 118 | 119 | with self.assertNumQueries(0): 120 | 
self.assertEqual(set(book1.authors.all()), {self.author1, self.author2}) 121 | -------------------------------------------------------------------------------- /tests/jsonb/tests.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals, absolute_import 2 | 3 | from django.test import TestCase 4 | from .models import Product 5 | from django_postgres_extensions.models.functions import * 6 | from django_postgres_extensions.models.expressions import Key 7 | from psycopg2.extras import Json 8 | from django.db import transaction 9 | 10 | class JSONIndexTests(TestCase): 11 | 12 | def setUp(self): 13 | super(JSONIndexTests, self).setUp() 14 | self.product = Product(name='xyz', description={'Industry': 'Music', 'Details': {'Release': 'Album', 'Genre': 'Rock', 'Rating': 8}, 'Price': 9.99, 'Tags': ['Heavy', 'Guitar']}) 15 | self.product.save() 16 | self.queryset = Product.objects.filter(pk=self.product.pk) 17 | self.pk_queryset = self.queryset.only('id') 18 | 19 | def tearDown(self): 20 | Product.objects.all().delete() 21 | 22 | def test_json_value(self): 23 | product = self.queryset.get() 24 | self.assertDictEqual(product.description, 25 | {'Industry': 'Music', 'Details': {'Release': 'Album', 'Genre': 'Rock', 'Rating': 8}, 26 | 'Price': 9.99, 'Tags': ['Heavy', 'Guitar']}) 27 | 28 | def test_json_value_by_key(self): 29 | with transaction.atomic(): 30 | obj = self.pk_queryset.annotate(Key('description', 'Details')).get() 31 | self.assertDictEqual(obj.description__Details, {'Genre': 'Rock', 'Rating': 8, 'Release': 'Album'}) 32 | 33 | def test_json_value_by_key_path(self): 34 | with transaction.atomic(): 35 | obj = self.pk_queryset.annotate(Key('description', 'Details__Rating')).get() 36 | self.assertEqual(obj.description__Details__Rating, 8) 37 | with transaction.atomic(): 38 | obj = self.pk_queryset.annotate(Key('description', 'Tags__1')).get() 39 | self.assertEqual(obj.description__Tags__1, "Guitar") 40 | 
41 | def test_json_update_keys_values(self): 42 | with transaction.atomic(): 43 | self.queryset.update(description__ = {'Industry': 'Movie', 'Popularity': 'Very Popular'}) 44 | product = self.queryset.get() 45 | self.assertDictEqual(product.description, 46 | {'Industry': 'Movie', 'Details': {'Release': 'Album', 'Genre': 'Rock', 'Rating': 8}, 47 | 'Price': 9.99, 'Popularity': 'Very Popular', 'Tags': ['Heavy', 'Guitar']}) 48 | 49 | def test_json_update_delete_key(self): 50 | with transaction.atomic(): 51 | self.queryset.update(description__del ='Details') 52 | product = self.queryset.get() 53 | self.assertDictEqual(product.description, 54 | {'Industry': 'Music', 'Price': 9.99, 'Tags': ['Heavy', 'Guitar']}) 55 | 56 | def test_json_update_delete_key_path(self): 57 | with transaction.atomic(): 58 | self.queryset.update(description__del = 'Details__Release') 59 | product = self.queryset.get() 60 | self.assertDictEqual(product.description, 61 | {'Industry': 'Music', 'Details': {'Genre': 'Rock', 'Rating': 8}, 62 | 'Price': 9.99, 'Tags': ['Heavy', 'Guitar']}) 63 | with transaction.atomic(): 64 | self.queryset.update(description__del='Tags__1') 65 | product = self.queryset.get() 66 | self.assertDictEqual(product.description, 67 | {'Industry': 'Music', 'Details': {'Genre': 'Rock', 'Rating': 8}, 68 | 'Price': 9.99, 'Tags': ['Heavy']}) 69 | 70 | class JSONFuncTests(TestCase): 71 | def setUp(self): 72 | super(JSONFuncTests, self).setUp() 73 | self.product = Product(name='xyz', description={'Industry': 'Music', 'Details': {'Release': 'Album', 'Genre': 'Rock', 'Rating': 8}, 'Price': 9.99, 'Tags': ['Heavy', 'Guitar']}) 74 | self.product.save() 75 | self.queryset = Product.objects.filter(pk=self.product.pk) 76 | self.product2 = Product(name='xyz', description=[{'a': 'b', 'c':'d'}, {'a': 'e', 'c': 'f'}]) 77 | self.product2.save() 78 | self.queryset2 = Product.objects.filter(pk=self.product2.pk) 79 | 80 | def tearDown(self): 81 | Product.objects.all().delete() 82 | 83 | def 
test_jsonb_set(self): 84 | with transaction.atomic(): 85 | self.queryset.update(description = JSONBSet('description', ['Details', 'Genre'], Json('Heavy Metal'), True)) 86 | obj = self.queryset.get() 87 | self.assertDictEqual(obj.description, 88 | {'Price': 9.99, 'Industry': 'Music', 'Details': {'Genre': 'Heavy Metal', 'Release': 'Album', 'Rating': 8}, 'Tags': ['Heavy', 'Guitar']}) 89 | 90 | def test_jsonb_array_set(self): 91 | with transaction.atomic(): 92 | self.queryset2.update(description=JSONBSet('description', ['1', 'c'], Json('g'))) 93 | obj = self.queryset2.get() 94 | self.assertListEqual(obj.description, 95 | [{'a': 'b', 'c': 'd'}, {'a': 'e', 'c': 'g'}]) 96 | 97 | def test_jsonb_array_length(self): 98 | with transaction.atomic(): 99 | qs = self.queryset2.format('description', JSONBArrayLength, output_field='desc_length') 100 | obj = qs.get() 101 | self.assertEqual(obj.desc_length, 2) -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Django Postgres Extensions documentation build configuration file, created by 5 | # sphinx-quickstart on Wed Feb 8 21:57:16 2017. 6 | # 7 | # This file is execfile()d with the current directory set to its 8 | # containing dir. 9 | # 10 | # Note that not all possible configuration values are present in this 11 | # autogenerated file. 12 | # 13 | # All configuration values have a default; values that are commented out 14 | # serve to show the default. 15 | 16 | # If extensions (or modules to document with autodoc) are in another directory, 17 | # add these directories to sys.path here. If the directory is relative to the 18 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
19 | # 20 | # import os 21 | # import sys 22 | # sys.path.insert(0, os.path.abspath('.')) 23 | 24 | 25 | # -- General configuration ------------------------------------------------ 26 | 27 | # If your documentation needs a minimal Sphinx version, state it here. 28 | # 29 | # needs_sphinx = '1.0' 30 | 31 | # Add any Sphinx extension module names here, as strings. They can be 32 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 33 | # ones. 34 | extensions = ['sphinx.ext.autodoc', 35 | 'sphinx.ext.doctest', 36 | 'sphinxcontrib.fancybox'] 37 | 38 | # Add any paths that contain templates here, relative to this directory. 39 | templates_path = ['_templates'] 40 | 41 | # The suffix(es) of source filenames. 42 | # You can specify multiple suffix as a list of string: 43 | # 44 | # source_suffix = ['.rst', '.md'] 45 | source_suffix = '.rst' 46 | 47 | # The master toctree document. 48 | master_doc = 'index' 49 | 50 | # General information about the project. 51 | project = 'Django Postgres Extensions' 52 | copyright = '2017, Paul Martin' 53 | author = 'Paul Martin' 54 | 55 | # The version info for the project you're documenting, acts as replacement for 56 | # |version| and |release|, also used in various other places throughout the 57 | # built documents. 58 | # 59 | # The short X.Y version. 60 | version = '0.9.2' 61 | # The full version, including alpha/beta/rc tags. 62 | release = '0.9.2' 63 | 64 | # The language for content autogenerated by Sphinx. Refer to documentation 65 | # for a list of supported languages. 66 | # 67 | # This is also used if you do content translation via gettext catalogs. 68 | # Usually you set "language" from the command line for these cases. 69 | language = 'python' 70 | 71 | # List of patterns, relative to source directory, that match files and 72 | # directories to ignore when looking for source files. 
73 | # This patterns also effect to html_static_path and html_extra_path 74 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 75 | 76 | # The name of the Pygments (syntax highlighting) style to use. 77 | pygments_style = 'sphinx' 78 | 79 | # If true, `todo` and `todoList` produce output, else they produce nothing. 80 | todo_include_todos = False 81 | 82 | 83 | # -- Options for HTML output ---------------------------------------------- 84 | 85 | # The theme to use for HTML and HTML Help pages. See the documentation for 86 | # a list of builtin themes. 87 | # 88 | html_theme = 'default' 89 | 90 | # Theme options are theme-specific and customize the look and feel of a theme 91 | # further. For a list of options available for each theme, see the 92 | # documentation. 93 | # 94 | # html_theme_options = {} 95 | 96 | # Add any paths that contain custom static files (such as style sheets) here, 97 | # relative to this directory. They are copied after the builtin static files, 98 | # so a file named "default.css" will overwrite the builtin "default.css". 99 | html_static_path = ['_static'] 100 | 101 | 102 | # -- Options for HTMLHelp output ------------------------------------------ 103 | 104 | # Output file base name for HTML help builder. 105 | htmlhelp_basename = 'DjangoPostgresExtensionsdoc' 106 | 107 | 108 | # -- Options for LaTeX output --------------------------------------------- 109 | 110 | latex_elements = { 111 | # The paper size ('letterpaper' or 'a4paper'). 112 | # 113 | # 'papersize': 'letterpaper', 114 | 115 | # The font size ('10pt', '11pt' or '12pt'). 116 | # 117 | # 'pointsize': '10pt', 118 | 119 | # Additional stuff for the LaTeX preamble. 120 | # 121 | # 'preamble': '', 122 | 123 | # Latex figure (float) alignment 124 | # 125 | # 'figure_align': 'htbp', 126 | } 127 | 128 | # Grouping the document tree into LaTeX files. List of tuples 129 | # (source start file, target name, title, 130 | # author, documentclass [howto, manual, or own class]). 
131 | latex_documents = [ 132 | (master_doc, 'DjangoPostgresExtensions.tex', 'Django Postgres Extensions Documentation', 133 | 'Paul Martin', 'manual'), 134 | ] 135 | 136 | 137 | # -- Options for manual page output --------------------------------------- 138 | 139 | # One entry per manual page. List of tuples 140 | # (source start file, name, description, authors, manual section). 141 | man_pages = [ 142 | (master_doc, 'djangopostgresextensions', 'Django Postgres Extensions Documentation', 143 | [author], 1) 144 | ] 145 | 146 | 147 | # -- Options for Texinfo output ------------------------------------------- 148 | 149 | # Grouping the document tree into Texinfo files. List of tuples 150 | # (source start file, target name, title, author, 151 | # dir menu entry, description, category) 152 | texinfo_documents = [ 153 | (master_doc, 'DjangoPostgresExtensions', 'Django Postgres Extensions Documentation', 154 | author, 'DjangoPostgresExtensions', 'One line description of project.', 155 | 'Miscellaneous'), 156 | ] 157 | 158 | html_domain_indices = False -------------------------------------------------------------------------------- /tests/nested_form_field_widget/tests.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from django.test import TestCase 3 | from django.forms import CharField, Form, TextInput, FileInput 4 | from django_postgres_extensions import forms 5 | from django_postgres_extensions.models.fields import HStoreField 6 | 7 | class NestedFormWidgetTest(TestCase): 8 | 9 | def check_html(self, widget, name, value, html='', attrs=None, **kwargs): 10 | output = widget.render(name, value, attrs=attrs, **kwargs) 11 | self.assertHTMLEqual(output, html) 12 | 13 | def test_text_inputs(self): 14 | widget = forms.NestedFormWidget( 15 | ('A', 'B', 'C'), 16 | ((TextInput()), 17 | (TextInput()), 18 | (TextInput()) 19 | ) 20 | ) 21 | self.check_html(widget, 'name', ['john', 'winston', 'lennon'], html="""
    22 |
  • 23 |
  • 24 |
  • 25 |
""" 26 | ) 27 | 28 | def test_constructor_attrs(self): 29 | widget = forms.NestedFormWidget( 30 | ('A', 'B', 'C'), 31 | ((TextInput()), 32 | (TextInput()), 33 | (TextInput()) 34 | ), 35 | attrs={'id': 'bar'}, 36 | ) 37 | self.check_html(widget, 'name', ['john', 'winston', 'lennon'], attrs={'id': 'bar'}, html="""
    38 |
  • 39 |
  • 40 |
  • 41 |
42 | """ 43 | ) 44 | 45 | def test_needs_multipart_true(self): 46 | """ 47 | needs_multipart_form should be True if any widgets need it. 48 | """ 49 | widget = forms.NestedFormWidget( 50 | ('text', 'file'), 51 | (TextInput(), FileInput()) 52 | ) 53 | self.assertTrue(widget.needs_multipart_form) 54 | 55 | def test_needs_multipart_false(self): 56 | """ 57 | needs_multipart_form should be False if no widgets need it. 58 | """ 59 | widget = forms.NestedFormWidget( 60 | ('text', 'text2'), 61 | (TextInput(), TextInput()) 62 | ) 63 | self.assertFalse(widget.needs_multipart_form) 64 | 65 | def test_nested_multiwidget(self): 66 | """ 67 | NestedFormWidget can be composed of other NestedFormWidgets. 68 | """ 69 | widget = forms.NestedFormWidget( 70 | ('A', 'B'), 71 | (TextInput(), forms.NestedFormWidget( 72 | ('C', 'D'), 73 | (TextInput(), TextInput()) 74 | ) 75 | ) 76 | ) 77 | self.check_html(widget, 'name', ['Singer', ['John', 'Lennon']], html=( 78 | """ 79 |
    80 |
  • 81 |
    • 82 |
    • 83 |
    • 84 |
    85 |
  • 86 |
87 | """ 88 | )) 89 | 90 | def test_deepcopy(self): 91 | """ 92 | MultiWidget should define __deepcopy__() (#12048). 93 | """ 94 | w1 = forms.NestedFormWidget( 95 | ['A', 'B', 'C'], 96 | (TextInput(), 97 | TextInput(), 98 | TextInput() 99 | ) 100 | ) 101 | w2 = copy.deepcopy(w1) 102 | w2.labels.append('d') 103 | # w2 ought to be independent of w1, since MultiWidget ought 104 | # to make a copy of its sub-widgets when it is copied. 105 | self.assertEqual(w1.labels, ['A', 'B', 'C']) 106 | 107 | 108 | class TestNestedFormField(TestCase): 109 | 110 | def test_valid(self): 111 | field = forms.NestedFormField(keys=('a', 'b', 'c')) 112 | value = field.clean(["d", '', "f"]) 113 | self.assertEqual(value, {'a': 'd', 'b': '', 'c': 'f'}) 114 | 115 | def test_model_field_formfield_keys(self): 116 | model_field = HStoreField(keys=('a', 'b', 'c')) 117 | form_field = model_field.formfield() 118 | self.assertIsInstance(form_field, forms.NestedFormField) 119 | 120 | def test_model_field_formfield_fields(self): 121 | model_field = HStoreField(fields=( 122 | ('a', CharField(max_length=10)), 123 | ('b', CharField(max_length=10)), 124 | ('c', CharField(max_length=10)) 125 | ) 126 | ) 127 | form_field = model_field.formfield() 128 | self.assertIsInstance(form_field, forms.NestedFormField) 129 | 130 | def test_field_has_changed(self): 131 | class NestedFormTest(Form): 132 | f1 = forms.NestedFormField(keys=('a', 'b', 'c')) 133 | form_w_hstore = NestedFormTest() 134 | self.assertFalse(form_w_hstore.has_changed()) 135 | 136 | form_w_hstore = NestedFormTest({'f1_a': 'd', 'fl_b': 'e', 'f1_c': 'f'}) 137 | self.assertTrue(form_w_hstore.has_changed()) 138 | 139 | form_w_hstore = NestedFormTest({'f1_a': 'g'}, 140 | initial={'f1_a': 'd', 'fl_b': 'e', 'f1_c': 'f'}) 141 | self.assertTrue(form_w_hstore.has_changed()) 142 | -------------------------------------------------------------------------------- /tests/m2m_recursive_array/tests.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals 2 | 3 | from operator import attrgetter 4 | 5 | from django.test import TestCase 6 | 7 | from .models import Person 8 | 9 | 10 | class RecursiveM2MTests(TestCase): 11 | def setUp(self): 12 | self.a, self.b, self.c, self.d = [ 13 | Person.objects.create(name=name) 14 | for name in ["Anne", "Bill", "Chuck", "David"] 15 | ] 16 | 17 | # Anne is friends with Bill and Chuck 18 | self.a.friends.add(self.b, self.c) 19 | 20 | # David is friends with Anne and Chuck - add in reverse direction 21 | self.d.friends.add(self.a, self.c) 22 | 23 | def test_recursive_m2m_all(self): 24 | """ Test that m2m relations are reported correctly """ 25 | # Who is friends with Anne? 26 | self.assertQuerysetEqual( 27 | self.a.friends.all(), [ 28 | "Bill", 29 | "Chuck", 30 | "David" 31 | ], 32 | attrgetter("name"), 33 | ordered=False 34 | ) 35 | # Who is friends with Bill? 36 | self.assertQuerysetEqual( 37 | self.b.friends.all(), [ 38 | "Anne", 39 | ], 40 | attrgetter("name") 41 | ) 42 | # Who is friends with Chuck? 43 | self.assertQuerysetEqual( 44 | self.c.friends.all(), [ 45 | "Anne", 46 | "David" 47 | ], 48 | attrgetter("name"), 49 | ordered=False 50 | ) 51 | # Who is friends with David? 52 | self.assertQuerysetEqual( 53 | self.d.friends.all(), [ 54 | "Anne", 55 | "Chuck", 56 | ], 57 | attrgetter("name"), 58 | ordered=False 59 | ) 60 | 61 | def test_recursive_m2m_reverse_add(self): 62 | """ Test reverse m2m relation is consistent """ 63 | 64 | # Bill is already friends with Anne - add Anne again, but in the 65 | # reverse direction 66 | self.b.friends.add(self.a) 67 | 68 | # Who is friends with Anne? 69 | self.assertQuerysetEqual( 70 | self.a.friends.all(), [ 71 | "Bill", 72 | "Chuck", 73 | "David", 74 | ], 75 | attrgetter("name"), 76 | ordered=False 77 | ) 78 | # Who is friends with Bill? 
79 | self.assertQuerysetEqual( 80 | self.b.friends.all(), [ 81 | "Anne", 82 | ], 83 | attrgetter("name") 84 | ) 85 | 86 | def test_recursive_m2m_remove(self): 87 | """ Test that we can remove items from an m2m relationship """ 88 | 89 | # Remove Anne from Bill's friends 90 | self.b.friends.remove(self.a) 91 | 92 | # Who is friends with Anne? 93 | self.assertQuerysetEqual( 94 | self.a.friends.all(), [ 95 | "Chuck", 96 | "David", 97 | ], 98 | attrgetter("name"), 99 | ordered=False 100 | ) 101 | # Who is friends with Bill? 102 | self.assertQuerysetEqual( 103 | self.b.friends.all(), [] 104 | ) 105 | 106 | def test_recursive_m2m_clear(self): 107 | """ Tests the clear method works as expected on m2m fields """ 108 | 109 | # Clear Anne's group of friends 110 | self.a.friends.clear() 111 | 112 | # Who is friends with Anne? 113 | self.assertQuerysetEqual( 114 | self.a.friends.all(), [] 115 | ) 116 | 117 | # Reverse relationships should also be gone 118 | # Who is friends with Chuck? 119 | self.assertQuerysetEqual( 120 | self.c.friends.all(), [ 121 | "David", 122 | ], 123 | attrgetter("name") 124 | ) 125 | 126 | # Who is friends with David? 127 | self.assertQuerysetEqual( 128 | self.d.friends.all(), [ 129 | "Chuck", 130 | ], 131 | attrgetter("name") 132 | ) 133 | 134 | def test_recursive_m2m_add_via_related_name(self): 135 | """ Tests that we can add m2m relations via the related_name attribute """ 136 | 137 | # David is idolized by Anne and Chuck - add in reverse direction 138 | self.d.stalkers.add(self.a) 139 | 140 | # Who are Anne's idols? 141 | self.assertQuerysetEqual( 142 | self.a.idols.all(), [ 143 | "David", 144 | ], 145 | attrgetter("name"), 146 | ordered=False 147 | ) 148 | # Who is stalking Anne? 
149 | self.assertQuerysetEqual( 150 | self.a.stalkers.all(), [], 151 | attrgetter("name") 152 | ) 153 | 154 | def test_recursive_m2m_add_in_both_directions(self): 155 | """ Check that adding the same relation twice results in a single relation """ 156 | 157 | # Ann idolizes David 158 | self.a.idols.add(self.d) 159 | 160 | # David is idolized by Anne 161 | self.d.stalkers.add(self.a) 162 | 163 | # Who are Anne's idols? 164 | self.assertQuerysetEqual( 165 | self.a.idols.all(), [ 166 | "David", 167 | ], 168 | attrgetter("name"), 169 | ordered=False 170 | ) 171 | # As the assertQuerysetEqual uses a set for comparison, 172 | # check we've only got David listed once 173 | self.assertEqual(self.a.idols.all().count(), 1) 174 | 175 | def test_recursive_m2m_related_to_self(self): 176 | """ Check the expected behavior when an instance is related to itself """ 177 | 178 | # Ann idolizes herself 179 | self.a.idols.add(self.a) 180 | 181 | # Who are Anne's idols? 182 | self.assertQuerysetEqual( 183 | self.a.idols.all(), [ 184 | "Anne", 185 | ], 186 | attrgetter("name"), 187 | ordered=False 188 | ) 189 | # Who is stalking Anne? 
190 | self.assertQuerysetEqual( 191 | self.a.stalkers.all(), [ 192 | "Anne", 193 | ], 194 | attrgetter("name") 195 | ) 196 | -------------------------------------------------------------------------------- /django_postgres_extensions/models/fields/__init__.py: -------------------------------------------------------------------------------- 1 | from django.contrib.postgres import fields 2 | from django.contrib.postgres.forms import SplitArrayField as SplitArrayFormField 3 | from django.forms.fields import TypedMultipleChoiceField 4 | from psycopg2.extras import Json 5 | from django_postgres_extensions.forms.fields import NestedFormField 6 | from django_postgres_extensions.models.expressions import F, Value as V 7 | from django_postgres_extensions.models.functions import HStore, Delete, ArrayRemove 8 | from django_postgres_extensions.models.sql.updates import UpdateArrayByIndex 9 | from django.core import exceptions 10 | 11 | 12 | class ArrayField(fields.ArrayField): 13 | 14 | def __init__(self, base_field, form_size=None, **kwargs): 15 | super(ArrayField, self).__init__(base_field, **kwargs) 16 | self.form_size = form_size 17 | 18 | def get_update_type(self, indexes, value): 19 | if indexes == 'del': 20 | return ArrayRemove(self.name, value) 21 | if '__' in indexes: 22 | indexes = indexes.split('__') 23 | try: 24 | indexes = [int(index) + 1 for index in indexes] 25 | return UpdateArrayByIndex(indexes, value, self) 26 | except ValueError: 27 | raise ValueError('Update lookup type %s not found for field %s' % (indexes, self.name)) 28 | 29 | def formfield(self, **kwargs): 30 | if self.form_size or self.choices: 31 | defaults = { 32 | 'form_class': SplitArrayFormField, 33 | 'base_field': self.base_field.formfield(), 34 | 'choices_form_class': TypedMultipleChoiceField, 35 | 'size': self.form_size, 36 | 'remove_trailing_nulls': True 37 | } 38 | if self.choices: 39 | defaults['coerce'] = self.base_field.to_python 40 | defaults.update(kwargs) 41 | return 
super(fields.ArrayField, self).formfield(**defaults) 42 | return super(ArrayField, self).formfield(**kwargs) 43 | 44 | def validate(self, value, model_instance): 45 | """ 46 | Validates value and throws ValidationError. Subclasses should override 47 | this to provide validation logic. 48 | """ 49 | if not self.editable: 50 | # Skip validation for non-editable fields. 51 | return 52 | 53 | if self.choices and value not in self.empty_values: 54 | if isinstance(value, (list, tuple)): 55 | option_keys = [x[0] for x in self.choices] 56 | if all(x in option_keys for x in value): 57 | return 58 | else: 59 | for option_key, option_value in self.choices: 60 | if isinstance(option_value, (list, tuple)): 61 | # This is an optgroup, so look inside the group for 62 | # options. 63 | for optgroup_key, optgroup_value in option_value: 64 | if value == optgroup_key: 65 | return 66 | elif value == option_key: 67 | return 68 | raise exceptions.ValidationError( 69 | self.error_messages['invalid_choice'], 70 | code='invalid_choice', 71 | params={'value': value}, 72 | ) 73 | 74 | if value is None and not self.null: 75 | raise exceptions.ValidationError(self.error_messages['null'], code='null') 76 | 77 | if not self.blank and value in self.empty_values: 78 | raise exceptions.ValidationError(self.error_messages['blank'], code='blank') 79 | 80 | def deconstruct(self): 81 | name, path, args, kwargs = super(ArrayField, self).deconstruct() 82 | kwargs.update({ 83 | 'form_size': self.form_size, 84 | }) 85 | return name, path, args, kwargs 86 | 87 | class HStoreField(fields.HStoreField): 88 | 89 | def __init__(self, fields=(), keys=(), max_value_length=25, require_all_fields=False, **kwargs): 90 | super(HStoreField, self).__init__(**kwargs) 91 | self.fields = fields 92 | self.keys = keys 93 | self.max_value_length = max_value_length 94 | self.require_all_fields = require_all_fields 95 | 96 | def get_update_type(self, lookups, value): 97 | lookup = lookups[0] 98 | if lookup == '' or lookup == 
'raw': 99 | keys = list(value.keys()) 100 | values = list(value.values()) 101 | if lookup == '': 102 | values = [str(v) for v in value.values()] 103 | return F(self.name).cat(HStore(V(keys), V(values))) 104 | if lookup == 'del': 105 | return Delete(self.name, value) 106 | raise ValueError('Update lookup type %s not found for field %s' % (lookup, self.name)) 107 | 108 | def formfield(self, **kwargs): 109 | if self.fields or self.keys: 110 | defaults = { 111 | 'form_class': NestedFormField, 112 | 'fields': self.fields, 113 | 'keys': list(self.keys), 114 | 'require_all_fields': self.require_all_fields, 115 | 'max_value_length': self.max_value_length 116 | } 117 | defaults.update(kwargs) 118 | else: 119 | defaults = kwargs 120 | return super(HStoreField, self).formfield(**defaults) 121 | 122 | class JSONField(fields.JSONField): 123 | 124 | def __init__(self, fields=(), require_all_fields=False, **kwargs): 125 | super(JSONField, self).__init__(**kwargs) 126 | self.fields = fields 127 | self.require_all_fields = require_all_fields 128 | 129 | def get_update_type(self, lookups, value): 130 | lookup = lookups[0] 131 | if lookup == '': 132 | return F(self.name).cat(V(Json(value))) 133 | if lookup == 'del': 134 | if '__' in value: 135 | values = value.split('__') 136 | return F(self.name).delete(V(values)) 137 | return F(self.name) - V(value) 138 | raise ValueError('Update lookup type %s not found for field %s' % (lookup, self.name)) 139 | 140 | def formfield(self, **kwargs): 141 | if self.fields: 142 | defaults = { 143 | 'form_class': NestedFormField, 144 | 'fields': self.fields, 145 | 'require_all_fields': self.require_all_fields, 146 | } 147 | defaults.update(kwargs) 148 | else: 149 | defaults = kwargs 150 | return super(JSONField, self).formfield(**defaults) 151 | -------------------------------------------------------------------------------- /docs/features.rst: -------------------------------------------------------------------------------- 1 | Feature Overview 2 | 
================ 3 | Custom Postgres backend 4 | ----------------------- 5 | The customized Postgres backend adds the following features: 6 | 7 | - HStore Extension is automatically activated when a test database is created so you don't need to create a separate migration which is useful especially when building re-usable apps 8 | - Uses a different update compiler which adds some functionality outlined in the ArrayField section below 9 | - If db_index is set to True for an ArrayField, a GIN index will be created which is more useful than the default database index for arrays. 10 | - Adds some extra operators to enable ANY and ALL lookups 11 | 12 | ArrayField 13 | ---------- 14 | The included ArrayField has been subclassed from django.contrib.postgres.fields.ArrayField to add extra features and is a drop-in replacement. To use this ArrayField. The customized Postgres ArrayField adds the following features: 15 | 16 | - Get array values by index: 17 | - Update array values by index: 18 | - Added database functions for interacting with Arrays. These functions handle the provided arguments by automatically converting them to the required expressions. 19 | - Add an array of values to an existing field. In this case the output_field is required to tell Django what db type to use for the array: 20 | - Additional lookups have been added to the ArrayField to enable queries using the ANY and ALL database functions. 21 | - Use either a split array field or a multiple choice field in a model form 22 | 23 | HStoreField 24 | ----------- 25 | The included HStoreField has been subclassed from django.contrib.postgres.fields.HStoreField to add extra features and is a drop-in replacement. 
To use this HStoreField: 26 | 27 | The customized Postgres HStoreField adds the following features: 28 | 29 | - Get hstore values by key 30 | - Update hstore by specific keys, leaving any others untouched 31 | - Added database functions for interacting with HStores, these functions handle the arguments by converting them to the correct expressions automatically. 32 | - Nested form field for a better representation of hstore in a form, either by providing a list of keys or list of form fields. 33 | 34 | JSONField 35 | --------- 36 | The included JSONField has been subclassed from django.contrib.postgres.fields.JSONField to add extra features and is a drop-in replacement. To use this JSONField: 37 | 38 | The customized Postgres JSONField adds the following features: 39 | 40 | - Get json values by key or key path 41 | - Update JSON Field by specific keys, leaving any others untouched 42 | - Delete JSONField by key or key path 43 | - Extra database functions for interacting with JSONFields. These functions handle the arguments by converting them to the correct expressions automatically. 44 | - The same NestedFormField and NestedFormWidget referred to above for HStore can also be used with a JSON Field by providing a list of fields. 
45 | 46 | ModelAdmin 47 | ---------- 48 | 49 | For an example of how these fields can be configured in a modelform; take the following models.py:: 50 | 51 | from django.db import models 52 | from django_postgres_extensions.models.fields import HStoreField, JSONField, ArrayField 53 | from django_postgres_extensions.models.fields.related import ArrayManyToManyField 54 | from django import forms 55 | from django.contrib.postgres.forms import SplitArrayField 56 | from django_postgres_extensions.forms.fields import NestedFormField 57 | 58 | details_fields = ( 59 | ('Brand', NestedFormField(keys=('Name', 'Country'))), 60 | ('Type', forms.CharField(max_length=25, required=False)), 61 | ('Colours', SplitArrayField(base_field=forms.CharField(max_length=10, required=False), size=10)), 62 | ) 63 | 64 | class Buyer(models.Model): 65 | time = models.DateTimeField(auto_now_add=True) 66 | name = models.CharField(max_length=20) 67 | 68 | def __str__(self): 69 | return self.name 70 | 71 | class Product(models.Model): 72 | name = models.CharField(max_length=15) 73 | keywords = ArrayField(models.CharField(max_length=20), default=[], form_size=10, blank=True) 74 | sports = ArrayField(models.CharField(max_length=20),default=[], blank=True, choices=( 75 | ('football', 'Football'), ('tennis', 'Tennis'), ('golf', 'Golf'), ('basketball', 'Basketball'), ('hurling', 'Hurling'), ('baseball', 'Baseball'))) 76 | shipping = HStoreField(keys=('Address', 'City', 'Region', 'Country'), blank=True, default={}) 77 | details = JSONField(fields=details_fields, blank=True, default={}) 78 | buyers = ArrayManyToManyField(Buyer) 79 | 80 | def __str__(self): 81 | return self.name 82 | 83 | @property 84 | def country(self): 85 | return self.shipping.get('Country', '') 86 | 87 | And with admin.py:: 88 | 89 | from django.contrib import admin 90 | from django_postgres_extensions.admin.options import PostgresAdmin 91 | from models import Product, Buyer 92 | 93 | class ProductAdmin(PostgresAdmin): 94 | 
filter_horizontal = ('buyers',) 95 | fields = ('name', 'keywords', 'sports', 'shipping', 'details', 'buyers') 96 | list_display = ('name', 'keywords', 'shipping', 'details', 'country') 97 | 98 | admin.site.register(Buyer) 99 | admin.site.register(Product, ProductAdmin) 100 | 101 | The form field would look like this: 102 | 103 | .. fancybox:: admin_form.jpg 104 | :width: 100% 105 | :height: 100% 106 | 107 | The list display would look like this: 108 | 109 | .. fancybox:: admin_list.jpg 110 | :width: 100% 111 | :height: 100% 112 | 113 | Additional Queryset Methods 114 | --------------------------- 115 | The app adds a ``format`` method to all querysets. It defers the given field and adds an annotation holding that field converted by the given expression, named ``<field>__alt`` unless another name is passed. 116 | For example, to return an HStoreField as JSON:: 117 | 118 | qs = Model.objects.all().format('description', HstoreToJSONBLoose) 119 | 120 | Array Many To Many Field 121 | ------------------------ 122 | The Array Many To Many Field is designed to be a drop-in replacement for the normal Django Many To Many Field and thus replicates many of its features.
123 | 124 | The Array Many To Many field supports the following features which replicate the API of the regular Many To Many Field: 125 | 126 | - Descriptor queryset with add, remove, clear and set for both forward and reverse relationships 127 | - Prefetch related for both forward and reverse relationships 128 | - Lookups across relationships with filter for both forward and reverse relationships 129 | - Lookups across relationships with exclude for forward relationships only 130 | 131 | Indices and tables 132 | ================== 133 | 134 | * :ref:`genindex` 135 | * :ref:`modindex` 136 | * :ref:`search` 137 | -------------------------------------------------------------------------------- /tests/hstores/tests.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals 2 | 3 | from django.test import TestCase 4 | from .models import Product 5 | from django_postgres_extensions.models.functions import * 6 | from django_postgres_extensions.models.expressions import Key, Keys 7 | from django.db.utils import ProgrammingError 8 | from django.db import transaction 9 | 10 | 11 | class HStoreIndexTests(TestCase): 12 | 13 | def setUp(self): 14 | super(HStoreIndexTests, self).setUp() 15 | self.product = Product(name='xyz', description={'Industry': 'Music', 'Release': 'Album', 'Genre': 'Rock'}) 16 | self.product.save() 17 | self.queryset = Product.objects.filter(pk=self.product.pk) 18 | 19 | def tearDown(self): 20 | Product.objects.all().delete() 21 | 22 | def test_hstore_value(self): 23 | product = self.queryset.get() 24 | self.assertDictEqual(product.description, {'Industry': 'Music', 'Release': 'Album', 'Genre': 'Rock'}) 25 | 26 | def test_hstore_value_by_key(self): 27 | with transaction.atomic(): 28 | obj = self.queryset.annotate(Key('description', 'Release')).get() 29 | self.assertEqual(obj.description__Release, 'Album') 30 | 31 | def test_hstore_values_by_keys(self): 32 | with transaction.atomic(): 33 
| obj = self.queryset.annotate(Keys('description', ['Industry', 'Release'])).get() 34 | self.assertListEqual(obj.description__selected, ['Music', 'Album']) 35 | 36 | def test_array_update_keys_values(self): 37 | with transaction.atomic(): 38 | self.queryset.update(description__ = {'Genre': 'Heavy Metal', 'Popularity': 'Very Popular'}) 39 | product = self.queryset.get() 40 | self.assertDictEqual(product.description, {'Industry': 'Music', 'Release': 'Album', 'Genre': 'Heavy Metal', 41 | 'Popularity': 'Very Popular'}) 42 | 43 | def test_hstore_raw_int_raises(self): 44 | with transaction.atomic(): 45 | self.queryset.update(description__={'Popularity': 5}) 46 | self.assertRaises(ProgrammingError, self.queryset.update, 47 | description__raw={'Popularity': 5}) 48 | 49 | class HstoreFuncTests(TestCase): 50 | def setUp(self): 51 | super(HstoreFuncTests, self).setUp() 52 | self.product = Product(name='xyz', description={'Industry': 'Music', 'Release': 'Album', 'Genre': 'Rock', 'Rating': '8'}, 53 | details={'Popularity': 'Very Popular'}) 54 | self.product.save() 55 | self.queryset = Product.objects.filter(pk=self.product.pk) 56 | 57 | def tearDown(self): 58 | Product.objects.all().delete() 59 | 60 | def test_hstore_new(self): 61 | product = Product(name='xyz', description=HStore(['Industry', 'Release', 'Genre', 'Popularity'], 62 | ['Film', 'Movie', 'Horror', 'Very Good'])) 63 | product.save() 64 | queryset = Product.objects.filter(id=product.id) 65 | obj = queryset.get() 66 | self.assertDictEqual(obj.description, 67 | {'Genre': 'Horror', 'Release': 'Movie', 'Industry': 'Film', 'Popularity': 'Very Good'}) 68 | 69 | def test_hstore_slice(self): 70 | with transaction.atomic(): 71 | obj = self.queryset.annotate(description_slice=Slice('description', ['Industry', 'Release'])).get() 72 | self.assertDictEqual(obj.description_slice, {'Release': 'Album', 'Industry': 'Music'}) 73 | 74 | def test_hstore_delete_key(self): 75 | with transaction.atomic(): 76 | 
self.queryset.update(description = Delete('description', 'Genre')) 77 | product = self.queryset.get() 78 | self.assertDictEqual(product.description, {'Industry': 'Music', 'Release': 'Album', 'Rating': '8'}) 79 | 80 | def test_hstore_delete_keys(self): 81 | with transaction.atomic(): 82 | self.queryset.update(description = Delete('description', ['Industry', 'Genre'])) 83 | product = self.queryset.get() 84 | self.assertDictEqual(product.description, {'Release': 'Album', 'Rating': '8'}) 85 | 86 | def test_hstore_delete_by_dict(self): 87 | with transaction.atomic(): 88 | self.queryset.update(description=Delete('description', {'Industry': 'Music', 'Release': 'Song', 'Genre': 'Rock'})) 89 | product = self.queryset.get() 90 | self.assertDictEqual(product.description, {'Release': 'Album', 'Rating': '8'}) 91 | 92 | def test_hstore_keys_as_array(self): 93 | with transaction.atomic(): 94 | product = self.queryset.annotate(description_keys=AKeys('description')).get() 95 | keys = product.description_keys 96 | keys.sort() 97 | self.assertListEqual(keys, ['Genre', 'Industry', 'Rating', 'Release']) 98 | 99 | def test_hstore_values_as_array(self): 100 | with transaction.atomic(): 101 | product = self.queryset.annotate(description_values=AVals('description')).get() 102 | values = product.description_values 103 | values.sort() 104 | self.assertListEqual(values, ['8', 'Album', 'Music', 'Rock']) 105 | 106 | def test_hstore_to_array(self): 107 | with transaction.atomic(): 108 | product = self.queryset.annotate(description_array=HStoreToArray('description')).get() 109 | self.assertListEqual(product.description_array, ['Genre', 'Rock', 'Rating', '8', 'Release', 'Album', 'Industry', 'Music']) 110 | 111 | def test_hstore_to_matrix(self): 112 | with transaction.atomic(): 113 | product = self.queryset.annotate(description_matrix=HStoreToMatrix('description')).get() 114 | self.assertListEqual(product.description_matrix, [['Genre', 'Rock'], ['Rating', '8'], ['Release', 'Album'], ['Industry', 
'Music']]) 115 | 116 | def test_hstore_to_jsonb(self): 117 | with transaction.atomic(): 118 | product = self.queryset.annotate(description_jsonb=HstoreToJSONB('description')).get() 119 | self.assertDictEqual(product.description_jsonb, 120 | {'Genre': 'Rock', 'Release': 'Album', 'Industry': 'Music', 'Rating': "8"}) 121 | 122 | def test_hstore_to_jsonb_loose(self): 123 | with transaction.atomic(): 124 | product = self.queryset.annotate(description_jsonb=HstoreToJSONBLoose('description')).get() 125 | self.assertDictEqual(product.description_jsonb, 126 | {'Genre': 'Rock', 'Release': 'Album', 'Industry': 'Music', 'Rating': 8}) 127 | 128 | def test_queryset_format(self): 129 | with transaction.atomic(): 130 | qs = self.queryset.format('description', HstoreToJSONBLoose) 131 | product = qs.get() 132 | self.assertDictEqual(product.description__alt, 133 | {'Genre': 'Rock', 'Release': 'Album', 'Industry': 'Music', 'Rating': 8}) -------------------------------------------------------------------------------- /readme.rst: -------------------------------------------------------------------------------- 1 | Django Postgres Extensions! 2 | =========================== 3 | 4 | Django Postgres Extensions adds a lot of functionality to Django.contrib.postgres, specifically in relation to ArrayField, HStoreField and JSONField, including much better form fields for dealing with these field types. The app also includes an Array Many To Many Field, so you can store the relationship in an array column instead of requiring an extra database table. 5 | 6 | Check out http://django-postgres-extensions.readthedocs.io/en/latest/ to get started. 
7 | 8 | Latest release (0.9.3) is tested with Django 2.0.2 9 | 10 | Feature Overview 11 | ================ 12 | Custom Postgres backend 13 | ----------------------- 14 | The customized Postgres backend adds the following features: 15 | 16 | - The HStore extension is automatically activated when a test database is created, so you don't need to create a separate migration; this is especially useful when building reusable apps 17 | - Uses a different update compiler which adds some functionality outlined in the ArrayField section below 18 | - If db_index is set to True for an ArrayField, a GIN index will be created, which is more useful than the default database index for arrays. 19 | - Adds some extra operators to enable ANY and ALL lookups 20 | 21 | ArrayField 22 | ---------- 23 | The included ArrayField has been subclassed from django.contrib.postgres.fields.ArrayField to add extra features and is a drop-in replacement. To use this ArrayField, import it from django_postgres_extensions.models.fields. The customized Postgres ArrayField adds the following features: 24 | 25 | - Get array values by index 26 | - Update array values by index 27 | - Added database functions for interacting with Arrays. These functions handle the provided arguments by automatically converting them to the required expressions. 28 | - Add an array of values to an existing field. In this case the output_field is required to tell Django what db type to use for the array. 29 | - Additional lookups have been added to the ArrayField to enable queries using the ANY and ALL database functions. 30 | - Use either a split array field or a multiple choice field in a model form 31 | 32 | HStoreField 33 | ----------- 34 | The included HStoreField has been subclassed from django.contrib.postgres.fields.HStoreField to add extra features and is a drop-in replacement.
To use this HStoreField, import it from django_postgres_extensions.models.fields. 35 | 36 | The customized Postgres HStoreField adds the following features: 37 | 38 | - Get hstore values by key 39 | - Update hstore by specific keys, leaving any others untouched 40 | - Added database functions for interacting with HStores. These functions handle the arguments by converting them to the correct expressions automatically. 41 | - Nested form field for a better representation of hstore in a form, either by providing a list of keys or a list of form fields. 42 | 43 | JSONField 44 | --------- 45 | The included JSONField has been subclassed from django.contrib.postgres.fields.JSONField to add extra features and is a drop-in replacement. To use this JSONField, import it from django_postgres_extensions.models.fields. 46 | 47 | The customized Postgres JSONField adds the following features: 48 | 49 | - Get JSON values by key or key path 50 | - Update a JSONField by specific keys, leaving any others untouched 51 | - Delete JSONField entries by key or key path 52 | - Extra database functions for interacting with JSONFields. These functions handle the arguments by converting them to the correct expressions automatically. 53 | - The same NestedFormField and NestedFormWidget referred to above for HStore can also be used with a JSONField by providing a list of fields.
54 | 55 | ModelAdmin 56 | ---------- 57 | 58 | For an example of how these fields can be configured in a ModelForm, take the following models.py:: 59 | 60 | from django.db import models 61 | from django_postgres_extensions.models.fields import HStoreField, JSONField, ArrayField 62 | from django_postgres_extensions.models.fields.related import ArrayManyToManyField 63 | from django import forms 64 | from django.contrib.postgres.forms import SplitArrayField 65 | from django_postgres_extensions.forms.fields import NestedFormField 66 | 67 | details_fields = ( 68 | ('Brand', NestedFormField(keys=('Name', 'Country'))), 69 | ('Type', forms.CharField(max_length=25, required=False)), 70 | ('Colours', SplitArrayField(base_field=forms.CharField(max_length=10, required=False), size=10)), 71 | ) 72 | 73 | class Buyer(models.Model): 74 | time = models.DateTimeField(auto_now_add=True) 75 | name = models.CharField(max_length=20) 76 | 77 | def __str__(self): 78 | return self.name 79 | 80 | class Product(models.Model): 81 | name = models.CharField(max_length=15) 82 | keywords = ArrayField(models.CharField(max_length=20), default=list, form_size=10, blank=True) 83 | sports = ArrayField(models.CharField(max_length=20), default=list, blank=True, choices=( 84 | ('football', 'Football'), ('tennis', 'Tennis'), ('golf', 'Golf'), ('basketball', 'Basketball'), ('hurling', 'Hurling'), ('baseball', 'Baseball'))) 85 | shipping = HStoreField(keys=('Address', 'City', 'Region', 'Country'), blank=True, default=dict) 86 | details = JSONField(fields=details_fields, blank=True, default=dict) 87 | buyers = ArrayManyToManyField(Buyer) 88 | 89 | def __str__(self): 90 | return self.name 91 | 92 | @property 93 | def country(self): 94 | return self.shipping.get('Country', '') 95 | 96 | And with admin.py:: 97 | 98 | from django.contrib import admin 99 | from django_postgres_extensions.admin.options import PostgresAdmin 100 | from .models import Product, Buyer 101 | 102 | class ProductAdmin(PostgresAdmin): 103 |
filter_horizontal = ('buyers',) 104 | fields = ('name', 'keywords', 'sports', 'shipping', 'details', 'buyers') 105 | list_display = ('name', 'keywords', 'shipping', 'details', 'country') 106 | 107 | admin.site.register(Buyer) 108 | admin.site.register(Product, ProductAdmin) 109 | 110 | The form field would look like this: 111 | 112 | .. image:: docs/admin_form.jpg 113 | 114 | The list display would look like this: 115 | 116 | .. image:: docs/admin_list.jpg 117 | 118 | Additional Queryset Methods 119 | --------------------------- 120 | The app adds a ``format`` method to all querysets. It defers the given field and adds an annotation holding that field converted by the given expression, named ``<field>__alt`` unless another name is passed. 121 | For example, to return an HStoreField as JSON:: 122 | 123 | qs = Model.objects.all().format('description', HstoreToJSONBLoose) 124 | 125 | Array Many To Many Field 126 | ------------------------ 127 | The Array Many To Many Field is designed to be a drop-in replacement for the normal Django Many To Many Field and thus replicates many of its features.
128 | 129 | The Array Many To Many field supports the following features which replicate the API of the regular Many To Many Field: 130 | 131 | - Descriptor queryset with add, remove, clear and set for both forward and reverse relationships 132 | - Prefetch related for both forward and reverse relationships 133 | - Lookups across relationships with filter for both forward and reverse relationships 134 | - Lookups across relationships with exclude for forward relationships only 135 | -------------------------------------------------------------------------------- /django_postgres_extensions/models/query.py: -------------------------------------------------------------------------------- 1 | from django.db import transaction 2 | from django.core import exceptions 3 | from django.db.models.constants import LOOKUP_SEP 4 | from django.db.models.sql.constants import CURSOR 5 | from .sql import UpdateQuery 6 | import copy 7 | 8 | 9 | 10 | def update(self, **kwargs): 11 | """ 12 | Updates all elements in the current QuerySet, setting all the given 13 | fields to the appropriate values. 14 | """ 15 | assert self.query.can_filter(), \ 16 | "Cannot update a query once a slice has been taken." 17 | self._for_write = True 18 | query = self.query.chain(UpdateQuery) 19 | query.add_update_values(kwargs) 20 | with transaction.atomic(using=self.db, savepoint=False): 21 | rows = query.get_compiler(self.db).execute_sql(CURSOR) 22 | self._result_cache = None 23 | return rows 24 | update.alters_data = True 25 | 26 | 27 | def _update(self, values): 28 | """ 29 | A version of update that accepts field objects instead of field names. 30 | Used primarily for model saving and not intended for use by general 31 | code (it requires too much poking around at model internals to be 32 | useful at that level). 33 | """ 34 | assert self.query.can_filter(), \ 35 | "Cannot update a query once a slice has been taken." 
36 | query = self.query.chain(UpdateQuery) 37 | values = [value for value in values if not getattr(value[0], 'many_to_many_array', False)] 38 | query.add_update_fields(values) 39 | self._result_cache = None 40 | return query.get_compiler(self.db).execute_sql(CURSOR) 41 | _update.alters_data = True 42 | _update.queryset_only = False 43 | 44 | def format(self, field, expression, output_field=None, *args, **kwargs): 45 | if not output_field: 46 | output_field = field + '__alt' 47 | kwargs = {output_field: expression(field, *args, **kwargs)} 48 | qs = self.defer(field).annotate(**kwargs) 49 | return qs 50 | 51 | def prefetch_one_level(instances, prefetcher, lookup, level): 52 | """ 53 | Helper function for prefetch_related_objects(). 54 | 55 | Run prefetches on all instances using the prefetcher object, 56 | assigning results to relevant caches in instance. 57 | 58 | Return the prefetched objects along with any additional prefetches that 59 | must be done due to prefetch_related lookups found from default managers. 60 | """ 61 | # prefetcher must have a method get_prefetch_queryset() which takes a list 62 | # of instances, and returns a tuple: 63 | 64 | # (queryset of instances of self.model that are related to passed in instances, 65 | # callable that gets value to be matched for returned instances, 66 | # callable that gets value to be matched for passed in instances, 67 | # boolean that is True for singly related objects, 68 | # cache or field name to assign to, 69 | # boolean that is True when the previous argument is a cache name vs a field name). 70 | 71 | # The 'values to be matched' must be hashable as they will be used 72 | # in a dictionary. 73 | 74 | rel_qs, rel_obj_attr, instance_attr, single, cache_name, is_descriptor = ( 75 | prefetcher.get_prefetch_queryset(instances, lookup.get_current_queryset(level))) 76 | # We have to handle the possibility that the QuerySet we just got back 77 | # contains some prefetch_related lookups. 
We don't want to trigger the 78 | # prefetch_related functionality by evaluating the query. Rather, we need 79 | # to merge in the prefetch_related lookups. 80 | # Copy the lookups in case it is a Prefetch object which could be reused 81 | # later (happens in nested prefetch_related). 82 | additional_lookups = [ 83 | copy.copy(additional_lookup) for additional_lookup 84 | in getattr(rel_qs, '_prefetch_related_lookups', ()) 85 | ] 86 | if additional_lookups: 87 | # Don't need to clone because the manager should have given us a fresh 88 | # instance, so we access an internal instead of using public interface 89 | # for performance reasons. 90 | rel_qs._prefetch_related_lookups = () 91 | 92 | all_related_objects = list(rel_qs) 93 | 94 | is_multi_reference = getattr(rel_qs, 'is_multi_reference', False) 95 | 96 | rel_obj_cache = {} 97 | if not is_multi_reference: 98 | rel_obj_cache = {} 99 | for rel_obj in all_related_objects: 100 | rel_attr_val = rel_obj_attr(rel_obj) 101 | rel_obj_cache.setdefault(rel_attr_val, []).append(rel_obj) 102 | 103 | to_attr, as_attr = lookup.get_current_to_attr(level) 104 | # Make sure `to_attr` does not conflict with a field. 105 | if as_attr and instances: 106 | # We assume that objects retrieved are homogeneous (which is the premise 107 | # of prefetch_related), so what applies to first object applies to all. 108 | model = instances[0].__class__ 109 | try: 110 | model._meta.get_field(to_attr) 111 | except exceptions.FieldDoesNotExist: 112 | pass 113 | else: 114 | msg = 'to_attr={} conflicts with a field on the {} model.' 115 | raise ValueError(msg.format(to_attr, model.__name__)) 116 | 117 | # Whether or not we're prefetching the last part of the lookup. 
118 | leaf = len(lookup.prefetch_through.split(LOOKUP_SEP)) - 1 == level 119 | 120 | for obj in instances: 121 | instance_attr_val = instance_attr(obj) 122 | if is_multi_reference: 123 | vals = [rel_obj for rel_obj in all_related_objects if rel_obj_attr(rel_obj, instance_attr_val)] 124 | else: 125 | vals = rel_obj_cache.get(instance_attr_val, []) 126 | 127 | if single: 128 | val = vals[0] if vals else None 129 | if as_attr: 130 | # A to_attr has been given for the prefetch. 131 | setattr(obj, to_attr, val) 132 | elif is_descriptor: 133 | # cache_name points to a field name in obj. 134 | # This field is a descriptor for a related object. 135 | setattr(obj, cache_name, val) 136 | else: 137 | # No to_attr has been given for this prefetch operation and the 138 | # cache_name does not point to a descriptor. Store the value of 139 | # the field in the object's field cache. 140 | obj._state.fields_cache[cache_name] = val 141 | else: 142 | if as_attr: 143 | setattr(obj, to_attr, vals) 144 | else: 145 | manager = getattr(obj, to_attr) 146 | if leaf and lookup.queryset is not None: 147 | qs = manager._apply_rel_filters(lookup.queryset) 148 | else: 149 | qs = manager.get_queryset() 150 | qs._result_cache = vals 151 | # We don't want the individual qs doing prefetch_related now, 152 | # since we have merged this into the current work. 
153 | qs._prefetch_done = True 154 | obj._prefetched_objects_cache[cache_name] = qs 155 | return all_related_objects, additional_lookups -------------------------------------------------------------------------------- /tests/prefetch_related_array/models.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | from django.contrib.contenttypes.fields import ( 4 | GenericForeignKey, GenericRelation, 5 | ) 6 | from django.contrib.contenttypes.models import ContentType 7 | from django.db import models 8 | from django.utils.encoding import python_2_unicode_compatible 9 | from django_postgres_extensions.models.fields.related import ArrayManyToManyField 10 | 11 | 12 | # Basic tests 13 | 14 | @python_2_unicode_compatible 15 | class Author(models.Model): 16 | name = models.CharField(max_length=50, unique=True) 17 | first_book = models.ForeignKey('Book', models.CASCADE, related_name='first_time_authors') 18 | favorite_authors = ArrayManyToManyField( 19 | 'self', symmetrical=False, related_name='favors_me') 20 | 21 | def __str__(self): 22 | return self.name 23 | 24 | class Meta: 25 | ordering = ['id'] 26 | 27 | 28 | class AuthorWithAge(Author): 29 | author = models.OneToOneField(Author, models.CASCADE, parent_link=True) 30 | age = models.IntegerField() 31 | 32 | 33 | @python_2_unicode_compatible 34 | class AuthorAddress(models.Model): 35 | author = models.ForeignKey(Author, models.CASCADE, to_field='name', related_name='addresses') 36 | address = models.TextField() 37 | 38 | class Meta: 39 | ordering = ['id'] 40 | 41 | def __str__(self): 42 | return self.address 43 | 44 | 45 | @python_2_unicode_compatible 46 | class Book(models.Model): 47 | title = models.CharField(max_length=255) 48 | authors = ArrayManyToManyField(Author, related_name='books') 49 | 50 | def __str__(self): 51 | return self.title 52 | 53 | class Meta: 54 | ordering = ['id'] 55 | 56 | 57 | class BookWithYear(Book): 58 | book = models.OneToOneField(Book, 
models.CASCADE, parent_link=True) 59 | published_year = models.IntegerField() 60 | aged_authors = ArrayManyToManyField( 61 | AuthorWithAge, related_name='books_with_year') 62 | 63 | 64 | class Bio(models.Model): 65 | author = models.OneToOneField(Author, models.CASCADE) 66 | books = ArrayManyToManyField(Book, blank=True) 67 | 68 | 69 | @python_2_unicode_compatible 70 | class Reader(models.Model): 71 | name = models.CharField(max_length=50) 72 | books_read = ArrayManyToManyField(Book, related_name='read_by') 73 | 74 | def __str__(self): 75 | return self.name 76 | 77 | class Meta: 78 | ordering = ['id'] 79 | 80 | 81 | class BookReview(models.Model): 82 | book = models.ForeignKey(BookWithYear, models.CASCADE) 83 | notes = models.TextField(null=True, blank=True) 84 | 85 | 86 | # Models for default manager tests 87 | 88 | class Qualification(models.Model): 89 | name = models.CharField(max_length=10) 90 | 91 | class Meta: 92 | ordering = ['id'] 93 | 94 | 95 | class TeacherManager(models.Manager): 96 | def get_queryset(self): 97 | return super(TeacherManager, self).get_queryset().prefetch_related('qualifications') 98 | 99 | 100 | @python_2_unicode_compatible 101 | class Teacher(models.Model): 102 | name = models.CharField(max_length=50) 103 | qualifications = ArrayManyToManyField(Qualification) 104 | 105 | objects = TeacherManager() 106 | 107 | def __str__(self): 108 | return "%s (%s)" % (self.name, ", ".join(q.name for q in self.qualifications.all())) 109 | 110 | class Meta: 111 | ordering = ['id'] 112 | 113 | 114 | class Department(models.Model): 115 | name = models.CharField(max_length=50) 116 | teachers = ArrayManyToManyField(Teacher) 117 | 118 | class Meta: 119 | ordering = ['id'] 120 | 121 | 122 | # GenericRelation/GenericForeignKey tests 123 | 124 | @python_2_unicode_compatible 125 | class TaggedItem(models.Model): 126 | tag = models.SlugField() 127 | content_type = models.ForeignKey( 128 | ContentType, 129 | models.CASCADE, 130 | related_name="taggeditem_set2", 
131 | ) 132 | object_id = models.PositiveIntegerField() 133 | content_object = GenericForeignKey('content_type', 'object_id') 134 | created_by_ct = models.ForeignKey( 135 | ContentType, 136 | models.SET_NULL, 137 | null=True, 138 | related_name='taggeditem_set3', 139 | ) 140 | created_by_fkey = models.PositiveIntegerField(null=True) 141 | created_by = GenericForeignKey('created_by_ct', 'created_by_fkey',) 142 | favorite_ct = models.ForeignKey( 143 | ContentType, 144 | models.SET_NULL, 145 | null=True, 146 | related_name='taggeditem_set4', 147 | ) 148 | favorite_fkey = models.CharField(max_length=64, null=True) 149 | favorite = GenericForeignKey('favorite_ct', 'favorite_fkey') 150 | 151 | def __str__(self): 152 | return self.tag 153 | 154 | class Meta: 155 | ordering = ['id'] 156 | 157 | 158 | class Bookmark(models.Model): 159 | url = models.URLField() 160 | tags = GenericRelation(TaggedItem, related_query_name='bookmarks') 161 | favorite_tags = GenericRelation(TaggedItem, 162 | content_type_field='favorite_ct', 163 | object_id_field='favorite_fkey', 164 | related_query_name='favorite_bookmarks') 165 | 166 | class Meta: 167 | ordering = ['id'] 168 | 169 | 170 | class Comment(models.Model): 171 | comment = models.TextField() 172 | 173 | # Content-object field 174 | content_type = models.ForeignKey(ContentType, models.CASCADE) 175 | object_pk = models.TextField() 176 | content_object = GenericForeignKey(ct_field="content_type", fk_field="object_pk") 177 | 178 | class Meta: 179 | ordering = ['id'] 180 | 181 | 182 | # Models for lookup ordering tests 183 | 184 | class House(models.Model): 185 | name = models.CharField(max_length=50) 186 | address = models.CharField(max_length=255) 187 | owner = models.ForeignKey('Person', models.SET_NULL, null=True) 188 | main_room = models.OneToOneField('Room', models.SET_NULL, related_name='main_room_of', null=True) 189 | 190 | class Meta: 191 | ordering = ['id'] 192 | 193 | 194 | class Room(models.Model): 195 | name = 
models.CharField(max_length=50) 196 | house = models.ForeignKey(House, models.CASCADE, related_name='rooms') 197 | 198 | class Meta: 199 | ordering = ['id'] 200 | 201 | 202 | class Person(models.Model): 203 | name = models.CharField(max_length=50) 204 | houses = ArrayManyToManyField(House, related_name='occupants') 205 | 206 | @property 207 | def primary_house(self): 208 | # Assume business logic forces every person to have at least one house. 209 | return sorted(self.houses.all(), key=lambda house: -house.rooms.count())[0] 210 | 211 | @property 212 | def all_houses(self): 213 | return list(self.houses.all()) 214 | 215 | class Meta: 216 | ordering = ['id'] 217 | 218 | 219 | # Models for nullable FK tests 220 | 221 | @python_2_unicode_compatible 222 | class Employee(models.Model): 223 | name = models.CharField(max_length=50) 224 | boss = models.ForeignKey('self', models.SET_NULL, null=True, related_name='serfs') 225 | 226 | def __str__(self): 227 | return self.name 228 | 229 | class Meta: 230 | ordering = ['id'] 231 | 232 | 233 | # Ticket #19607 234 | 235 | @python_2_unicode_compatible 236 | class LessonEntry(models.Model): 237 | name1 = models.CharField(max_length=200) 238 | name2 = models.CharField(max_length=200) 239 | 240 | def __str__(self): 241 | return "%s %s" % (self.name1, self.name2) 242 | 243 | 244 | @python_2_unicode_compatible 245 | class WordEntry(models.Model): 246 | lesson_entry = models.ForeignKey(LessonEntry, models.CASCADE) 247 | name = models.CharField(max_length=200) 248 | 249 | def __str__(self): 250 | return "%s (%s)" % (self.name, self.id) 251 | 252 | 253 | # Ticket #21410: Regression when related_name="+" 254 | 255 | @python_2_unicode_compatible 256 | class Author2(models.Model): 257 | name = models.CharField(max_length=50, unique=True) 258 | first_book = models.ForeignKey('Book', models.CASCADE, related_name='first_time_authors+') 259 | favorite_books = ArrayManyToManyField('Book', related_name='+') 260 | 261 | def __str__(self): 262 | 
return self.name 263 | 264 | class Meta: 265 | ordering = ['id'] 266 | 267 | 268 | # Models for many-to-many with UUID pk test: 269 | 270 | class Pet(models.Model): 271 | id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) 272 | name = models.CharField(max_length=20) 273 | people = ArrayManyToManyField(Person, related_name='pets') 274 | 275 | 276 | class Flea(models.Model): 277 | id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False) 278 | current_room = models.ForeignKey(Room, models.SET_NULL, related_name='fleas', null=True) 279 | pets_visited = ArrayManyToManyField(Pet, related_name='fleas_hosted') 280 | people_visited = ArrayManyToManyField(Person, related_name='fleas_hosted') 281 | -------------------------------------------------------------------------------- /tests/modeladmin/tests.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals 2 | 3 | from django.test import override_settings 4 | from django.contrib.auth import get_user_model 5 | from django.contrib.admin.tests import AdminSeleniumTestCase 6 | from selenium.common.exceptions import NoSuchWindowException 7 | from .models import Product, Buyer 8 | import time 9 | import ast 10 | 11 | installed_apps = [ 12 | 'django.contrib.admin', 13 | 'django.contrib.contenttypes', 14 | 'django.contrib.auth', 15 | 'django.contrib.sites', 16 | 'django.contrib.sessions', 17 | 'django.contrib.messages', 18 | 'django.contrib.staticfiles', 19 | 'django.contrib.postgres', 20 | 'django_postgres_extensions', 21 | 'modeladmin' 22 | ] 23 | 24 | @override_settings(ROOT_URLCONF='modeladmin.urls', INSTALLED_APPS=installed_apps, DEBUG=True) 25 | class PostgresAdminTestCase(AdminSeleniumTestCase): 26 | 27 | close_manually = False 28 | 29 | browsers = ['chrome'] 30 | 31 | available_apps = installed_apps 32 | 33 | username = 'super' 34 | password = 'secret' 35 | email = 'super@example.com' 36 | app_name = 
'modeladmin' 37 | model_name = 'product' 38 | 39 | @classmethod 40 | def setUpClass(cls): 41 | super(PostgresAdminTestCase, cls).setUpClass() 42 | cls.selenium.maximize_window() 43 | cls.url = '%s/%s/%s/%s' % (cls.live_server_url, 'admin', cls.app_name, cls.model_name) 44 | 45 | def tearDown(self): 46 | time.sleep(3) 47 | while self.close_manually: 48 | try: 49 | self.wait_page_loaded() 50 | except NoSuchWindowException: 51 | self.close_manually = False 52 | super(PostgresAdminTestCase, self).tearDown() 53 | 54 | def setUp(self): 55 | super(PostgresAdminTestCase, self).setUp() 56 | User = get_user_model() 57 | if not User.objects.filter(username=self.username).exists(): 58 | User.objects.create_superuser(username=self.username, password=self.password, email=self.email) 59 | self.admin_login(self.username, self.password) 60 | 61 | def test_list(self): 62 | prod = Product(name='Pro Trainers', keywords=['fun', 'popular'], 63 | sports=['tennis', 'basketball'], 64 | shipping={ 65 | 'Address': 'Pearse Street', 66 | 'City': 'Dublin', 67 | 'Region': 'Co. Dublin', 68 | 'Country': 'Ireland' 69 | }, 70 | details={ 71 | 'brand': { 72 | 'name': 'Adidas', 73 | 'country': 'Germany', 74 | }, 75 | 'type': 'runners', 76 | 'colours': ['black', 'white', 'blue'] 77 | } 78 | ) 79 | 80 | prod.save() 81 | self.selenium.get(self.url) 82 | self.wait_page_loaded() 83 | element = self.selenium.find_elements_by_class_name('field-name')[0] 84 | self.assertEqual(element.text, "Pro Trainers") 85 | element = self.selenium.find_elements_by_class_name('field-keywords')[0] 86 | self.assertEqual(element.text, 'fun, popular') 87 | element = self.selenium.find_elements_by_class_name('field-shipping')[0] 88 | self.assertDictEqual(ast.literal_eval(element.text), {'City': 'Dublin', 'Region': 'Co. 
Dublin', 'Country': 'Ireland', 'Address': 'Pearse Street'}) 89 | element = self.selenium.find_elements_by_class_name('field-details')[0] 90 | self.assertDictEqual(ast.literal_eval(element.text), {'colours': ['black', 'white', 'blue'], 'brand': {'country': 'Germany', 'name': 'Adidas'}, 'type': 'runners'}) 91 | element = self.selenium.find_elements_by_class_name('field-country')[0] 92 | self.assertEqual(element.text, "Ireland") 93 | 94 | def fill_form(self, ids_values, select_values, many_to_many_select, replace=False): 95 | for id, value in ids_values: 96 | element = self.selenium.find_element_by_id(id) 97 | if replace: 98 | element.clear() 99 | element.send_keys(value) 100 | for select in select_values: 101 | selector, values = select 102 | for value in values: 103 | selection_box = "#id_%s" % selector 104 | self.get_select_option(selection_box, value).click() 105 | for field_name, values in many_to_many_select: 106 | from_box = '#id_%s_from' % field_name 107 | choose_link = 'id_%s_add_link' % field_name 108 | choose_elem = self.selenium.find_element_by_id(choose_link) 109 | for value in values: 110 | self.get_select_option(from_box, str(value)).click() 111 | choose_elem.click() 112 | self.selenium.find_element_by_xpath('//input[@value="Save"]').click() 113 | self.wait_page_loaded() 114 | 115 | def test_add(self): 116 | buyer1 = Buyer(name='Muhammed Ali') 117 | buyer1.save() 118 | buyer2 = Buyer(name='Conor McGregor') 119 | buyer2.save() 120 | buyer3 = Buyer(name='Floyd Mayweather') 121 | buyer3.save() 122 | self.selenium.get('%s/%s' % (self.url, 'add')) 123 | self.wait_page_loaded() 124 | ids_values = (("id_name", "Pro Trainers"), ("id_keywords_0", "fun"), ("id_keywords_1", "popular"), 125 | ("id_shipping_address", "Pearse Street"), ("id_shipping_city", "Dublin"), ("id_shipping_region", "Co.Dublin"), 126 | ("id_shipping_country", "Ireland"), ("id_details_brand_name", "Adidas"), ("id_details_brand_country", "Germany"), 127 | ("id_details_type", "Runners"), 
("id_details_colours_0", "Black"), ("id_details_colours_1", "White"), 128 | ("id_details_colours_2", "Blue")) 129 | select_values = (("sports", ("tennis", "basketball"),),) 130 | many_to_many_select = (("buyers", (buyer1.pk, buyer3.pk),),) 131 | self.fill_form(ids_values, select_values, many_to_many_select) 132 | obj = Product.objects.get() 133 | self.assertEqual(obj.name, 'Pro Trainers') 134 | self.assertDictEqual(obj.shipping, {'City': 'Dublin', 'Region': 'Co.Dublin', 'Country': 'Ireland', 135 | 'Address': 'Pearse Street'}) 136 | self.assertListEqual(obj.sports, ['tennis', 'basketball']) 137 | self.assertListEqual(obj.details['Colours'], ['Black', 'White', 'Blue', '', '', '', '', '', '', '']) 138 | self.assertDictEqual(obj.details['Brand'], {'Country': 'Germany', 'Name': 'Adidas'}) 139 | self.assertEqual(obj.details['Type'], 'Runners') 140 | self.assertListEqual(obj.keywords,['fun', 'popular']) 141 | buyers = obj.buyers.all().order_by('id') 142 | self.assertQuerysetEqual(buyers, ['', '']) 143 | 144 | def test_update(self): 145 | buyer1 = Buyer(name='Muhammed Ali') 146 | buyer1.save() 147 | buyer2 = Buyer(name='Conor McGregor') 148 | buyer2.save() 149 | buyer3 = Buyer(name='Floyd Mayweather') 150 | buyer3.save() 151 | prod = Product(name='Pro Trainers', keywords=['fun', 'popular'], 152 | sports=['tennis', 'basketball'], 153 | shipping={ 154 | 'Address': 'Pearse Street', 155 | 'City': 'Dublin', 156 | 'Region': 'Co. 
Dublin', 157 | 'Country': 'Ireland' 158 | }, 159 | details={ 160 | 'Brand': { 161 | 'Name': 'Adidas', 162 | 'Country': 'Germany', 163 | }, 164 | 'Type': 'runners', 165 | 'Colours': ['black', 'white', 'blue'] 166 | } 167 | ) 168 | 169 | prod.save() 170 | prod.buyers.add(buyer1, buyer3) 171 | self.selenium.get('%s/%s/%s' % (self.url, prod.pk, 'change')) 172 | self.wait_page_loaded() 173 | ids_values = (("id_keywords_1", "not popular"), 174 | ("id_shipping_address", "Nassau Street"), ("id_details_brand_name", "Nike"), ("id_details_brand_country", "USA"), 175 | ("id_details_colours_3", "Red")) 176 | select_values = (("sports", ("football",),),) 177 | many_to_many_select = (("buyers", (buyer2.pk, ),),) 178 | self.fill_form(ids_values, select_values, many_to_many_select, replace=True) 179 | obj = Product.objects.get() 180 | buyers = obj.buyers.all().order_by('id') 181 | self.assertQuerysetEqual(buyers, ['', '', '']) 182 | self.assertEqual(obj.name, 'Pro Trainers') 183 | self.assertDictEqual(obj.shipping, {'City': 'Dublin', 'Region': 'Co. Dublin', 'Country': 'Ireland', 184 | 'Address': 'Nassau Street'}) 185 | self.assertListEqual(obj.sports, ['football', 'tennis', 'basketball']) 186 | self.assertListEqual(obj.details['Colours'], ['black', 'white', 'blue', 'Red', '', '', '', '', '', '']) 187 | self.assertDictEqual(obj.details['Brand'], {'Country': 'USA', 'Name': 'Nike'}) 188 | self.assertEqual(obj.details['Type'], 'runners') 189 | self.assertListEqual(obj.keywords,['fun', 'not popular']) 190 | 191 | def test_delete(self): 192 | buyer1 = Buyer(name='Muhammed Ali') 193 | buyer1.save() 194 | buyer2 = Buyer(name='Conor McGregor') 195 | buyer2.save() 196 | buyer3 = Buyer(name='Floyd Mayweather') 197 | buyer3.save() 198 | prod = Product(name='Pro Trainers', keywords=['fun', 'popular'], 199 | sports=['tennis', 'basketball'], 200 | shipping={ 201 | 'Address': 'Pearse Street', 202 | 'City': 'Dublin', 203 | 'Region': 'Co. 
Dublin', 204 | 'Country': 'Ireland' 205 | }, 206 | details={ 207 | 'Brand': { 208 | 'Name': 'Adidas', 209 | 'Country': 'Germany', 210 | }, 211 | 'Type': 'runners', 212 | 'Colours': ['black', 'white', 'blue'] 213 | } 214 | ) 215 | 216 | prod.save() 217 | prod.buyers.add(buyer1, buyer3) 218 | self.selenium.get('%s/%s/%s' % (self.url, prod.pk, 'delete')) 219 | self.selenium.find_element_by_xpath('//input[@value="Yes, I\'m sure"]').click() 220 | self.assertEqual(Product.objects.count(), 0) -------------------------------------------------------------------------------- /django_postgres_extensions/models/fields/related_descriptors.py: -------------------------------------------------------------------------------- 1 | from django.db.models import signals 2 | from django.db import transaction, router 3 | from django_postgres_extensions.utils import OrderedSet 4 | from django_postgres_extensions.models.functions import ArrayCat, ArrayRemove, multi_array_remove 5 | from django.utils.functional import cached_property 6 | 7 | class MultiReferenceDescriptor(object): 8 | 9 | def __init__(self, rel, reverse=False, isJson=False): 10 | self.rel = rel 11 | self.isJson = isJson 12 | self.reverse = reverse 13 | self.through = rel.related_model 14 | 15 | def __get__(self, instance, cls=None): 16 | """ 17 | Get the manager for the many-to-many array field. 
18 | """ 19 | if instance is None: 20 | return self 21 | 22 | return self.related_manager_cls(instance) 23 | 24 | @cached_property 25 | def related_manager_cls(self): 26 | model = self.rel.related_model if self.reverse else self.rel.model 27 | db = router.db_for_read(model) 28 | if hasattr(db, 'create_array_many_to_many_manager'): 29 | create_manager_func = db.create_array_many_to_many_manager 30 | else: 31 | create_manager_func = create_array_many_to_many_manager 32 | return create_manager_func( 33 | model._default_manager.__class__, 34 | self.rel, 35 | self.reverse, 36 | self.isJson 37 | ) 38 | 39 | def create_array_many_to_many_manager(superclass, rel, reverse, IsJson): 40 | 41 | class ArrayForwardManyToManyManager(superclass): 42 | 43 | def __init__(self, instance): 44 | if instance.pk is None: 45 | raise ValueError("%r instance needs to have a primary key value before " 46 | "a many-to-many relationship can be used." % 47 | instance.__class__.__name__) 48 | super(ArrayForwardManyToManyManager, self).__init__() 49 | self.instance = instance 50 | self.field = rel.field 51 | self.target_field = rel.target_field 52 | self.fieldname = self.field.name 53 | self.column = self.field.attname 54 | self.rel = rel 55 | self.set_attributes() 56 | self.through = self.rel.related_model 57 | 58 | def __call__(self, **kwargs): 59 | # We use **kwargs rather than a kwarg argument to enforce the 60 | # `manager='manager_name'` syntax. 
61 | manager = getattr(self.model, kwargs.pop('manager')) 62 | if hasattr(self.db, 'create_array_many_to_many_manager'): 63 | create_manager_func = self.db.create_array_many_to_many_manager 64 | else: 65 | create_manager_func = create_array_many_to_many_manager 66 | manager_class = create_manager_func(manager.__class__, rel, reverse, False) 67 | return manager_class(instance=self.instance) 68 | do_not_call_in_templates = True 69 | 70 | def set_attributes(self): 71 | self.model = self.rel.model 72 | self.related_model = self.rel.related_model 73 | self.prefetch_cache_name = rel.field.name 74 | self.to_field_name = self.target_field.name 75 | self.core_filters = {'%s' % self.rel.name: self.instance} 76 | self.symmetrical = self.rel.symmetrical 77 | 78 | def _apply_rel_filters(self, queryset): 79 | """ 80 | Filter the queryset for the instance this manager is bound to. 81 | """ 82 | queryset._add_hints(instance=self.instance) 83 | if self._db: 84 | queryset = queryset.using(self._db) 85 | queryset = queryset.filter(**self.core_filters) 86 | return queryset 87 | 88 | def get_queryset(self): 89 | try: 90 | return self.instance._prefetched_objects_cache[self.prefetch_cache_name] 91 | except (AttributeError, KeyError): 92 | queryset = super(ArrayForwardManyToManyManager, self).get_queryset() 93 | return self._apply_rel_filters(queryset) 94 | 95 | def get_prefetch_filters(self, instances): 96 | pks = [] 97 | for instance in instances: 98 | pks += getattr(instance, self.column) 99 | filters = {'%s__in' % self.to_field_name:set(pks)} 100 | return filters 101 | 102 | def validate_rel_obj(self, rel_obj, fks): 103 | return getattr(rel_obj, self.to_field_name) in fks 104 | 105 | def get_instance_attr(self, instance): 106 | return getattr(instance, self.column) 107 | 108 | def get_prefetch_queryset(self, instances, queryset=None): 109 | if queryset is None: 110 | queryset = super(ArrayForwardManyToManyManager, self).get_queryset() 111 | 112 | 
queryset._add_hints(instance=instances[0]) 113 | queryset = queryset.using(queryset._db or self._db) 114 | 115 | query = self.get_prefetch_filters(instances) 116 | queryset = queryset.filter(**query) 117 | queryset.is_multi_reference = True 118 | 119 | return queryset, self.validate_rel_obj, self.get_instance_attr, False, self.prefetch_cache_name, True 120 | 121 | def _update_instance(self, **kwargs): 122 | qs = self.related_model.objects.filter(pk=self.instance.pk) 123 | qs.update(**kwargs) 124 | _update_instance.alters_data = True 125 | 126 | def _add_items(self, *objs): 127 | objs = list(objs) 128 | if len(objs) == 1: 129 | exclude = {self.column: objs[0]} 130 | kwargs = {self.column: ArrayCat(self.column, objs, output_field=self.field)} 131 | self.related_model.objects.filter(pk=self.instance.pk).exclude(**exclude).update(**kwargs) 132 | else: 133 | instance = self.related_model.objects.only(self.column).get(pk=self.instance.pk) 134 | objs = list(OrderedSet(objs) - set(getattr(instance, self.column))) 135 | kwargs = {self.column: ArrayCat(self.column, objs, output_field=self.field)} 136 | self._update_instance(**kwargs) 137 | # If this is a symmetrical m2m relation to self, add the mirror entry to the other objs array 138 | if self.symmetrical: 139 | kwargs = {self.column: ArrayCat(self.column, [self.instance.pk], output_field=self.field)} 140 | self.model.objects.filter(pk__in=objs).update(**kwargs) 141 | _add_items.alters_data = True 142 | 143 | def validate_item(self, obj): 144 | return self.field.validate_item(obj, model=self.model) 145 | 146 | def add(self, *objs, **kwargs): 147 | objs = [self.validate_item(obj) for obj in objs] 148 | signals.m2m_changed.send( 149 | sender=self.through, action='pre_add', 150 | instance=self.instance, reverse=self.reverse, 151 | model=self.model, pk_set=objs, using=self.db, 152 | ) 153 | with transaction.atomic(): 154 | self._add_items(*objs) 155 | signals.m2m_changed.send( 156 | sender=self.through, action='post_add', 157 
| instance=self.instance, reverse=self.reverse, 158 | model=self.model, pk_set=objs, using=self.db, 159 | ) 160 | 161 | def remove(self, *objs): 162 | objs = [self.validate_item(obj) for obj in objs] 163 | signals.m2m_changed.send( 164 | sender=self.through, action="pre_remove", 165 | instance=self.instance, reverse=self.reverse, 166 | model=self.model, pk_set=objs, using=self.db, 167 | ) 168 | with transaction.atomic(): 169 | self._remove_items(*objs) 170 | signals.m2m_changed.send( 171 | sender=self.through, action="post_remove", 172 | instance=self.instance, reverse=self.reverse, 173 | model=self.model, pk_set=objs, using=self.db, 174 | ) 175 | remove.alters_data = True 176 | 177 | def _remove_items(self, *objs, **kwargs): 178 | if objs: 179 | chunks = [objs[x:x + 100] for x in range(0, len(objs), 100)] 180 | for chunk in chunks: 181 | kwargs = {self.column: multi_array_remove(self.column, *chunk)} 182 | self._update_instance(**kwargs) 183 | # If this is a symmetrical m2m relation to self, add the mirror entry to the other objs array 184 | if self.symmetrical: 185 | kwargs = {self.column: ArrayRemove(self.column, self.instance.pk, output_field=self.field)} 186 | self.model.objects.filter(pk__in=list(objs)).update(**kwargs) 187 | _remove_items.alters_data = True 188 | 189 | def create(self, **kwargs): 190 | new_obj = super(ArrayForwardManyToManyManager, self).create(**kwargs) 191 | self.add(new_obj) 192 | return new_obj 193 | create.alters_data = True 194 | 195 | def get_or_create(self, **kwargs): 196 | obj, created = super(ArrayForwardManyToManyManager, self).get_or_create(**kwargs) 197 | # We only need to add() if created because if we got an object back 198 | # from get() then the relationship already exists. 
199 | if created: 200 | self.add(obj) 201 | return obj, created 202 | get_or_create.alters_data = True 203 | 204 | def update_or_create(self, **kwargs): 205 | obj, created = super(ArrayForwardManyToManyManager, self).update_or_create(**kwargs) 206 | # We only need to add() if created because if we got an object back 207 | # from get() then the relationship already exists. 208 | if created: 209 | self.add(obj) 210 | return obj, created 211 | update_or_create.alters_data = True 212 | 213 | def _clear(self): 214 | kwargs = {self.column: []} 215 | self._update_instance(**kwargs) 216 | if self.symmetrical: 217 | kwargs = {self.column: ArrayRemove(self.column, self.instance.pk, output_field=self.field)} 218 | self.model.objects.update(**kwargs) 219 | _clear.alters_data = True 220 | 221 | def clear(self, **kwargs): 222 | with transaction.atomic(): 223 | signals.m2m_changed.send( 224 | sender=self.through, action="pre_clear", 225 | instance=self.instance, reverse=reverse, 226 | model=self.model, pk_set=None, using=self.db, 227 | ) 228 | with transaction.atomic(): 229 | self._clear() 230 | signals.m2m_changed.send( 231 | sender=self.model, action="post_clear", 232 | instance=self.instance, reverse=reverse, 233 | model=self.model, pk_set=None, using=self.db, 234 | ) 235 | clear.alters_data = True 236 | 237 | def set(self, objs, **kwargs): 238 | with transaction.atomic(savepoint=False): 239 | old_ids = set(self.values_list(self.to_field_name, flat=True)) 240 | new_objs = [] 241 | for obj in objs: 242 | fk_val = (obj.pk if isinstance(obj, self.model) else obj) 243 | if fk_val in old_ids: 244 | old_ids.remove(fk_val) 245 | else: 246 | new_objs.append(obj) 247 | self.remove(*old_ids) 248 | self.add(*new_objs) 249 | set.alters_data = True 250 | 251 | class ArrayReverseManyToManyManager(ArrayForwardManyToManyManager): 252 | 253 | def set_attributes(self): 254 | self.model = rel.related_model 255 | self.related_model = rel.model 256 | self.prefetch_cache_name = 
rel.field.related_query_name() 257 | self.core_filters = {'%s' % self.fieldname: self.instance} 258 | self.prefetch_filters = {} 259 | self.to_field_name = 'pk' 260 | self.to_field_value = self.instance.pk 261 | self.symmetrical = False 262 | 263 | def validate_rel_obj(self, rel_obj, pk): 264 | return pk in getattr(rel_obj, self.column) 265 | 266 | def get_instance_attr(self, instance): 267 | return getattr(instance, self.to_field_name) 268 | 269 | def get_prefetch_filters(self, instances): 270 | filters = {'%s__overlap' % self.fieldname: instances} 271 | return filters 272 | 273 | def _add_items(self, *objs, **kwargs): 274 | exclude = {self.column: self.instance.pk} 275 | qs = self.model.objects.filter(pk__in = objs).exclude(**exclude) 276 | kwargs = {self.column: ArrayCat(self.column, [self.to_field_value])} 277 | qs.update(**kwargs) 278 | _add_items.alters_data = True 279 | 280 | def _remove_items(self, *objs, **kwargs): 281 | qs = self.filter(pk__in = objs) 282 | kwargs = {self.column: ArrayRemove(self.column, self.to_field_value)} 283 | with transaction.atomic(): 284 | qs.update(**kwargs) 285 | _remove_items.alters_data = True 286 | 287 | def _clear(self): 288 | with transaction.atomic(): 289 | kwargs = {self.column: ArrayRemove(self.column, self.to_field_value)} 290 | self.model.objects.update(**kwargs) 291 | _clear.alters_data = True 292 | 293 | if reverse: 294 | return ArrayReverseManyToManyManager 295 | return ArrayForwardManyToManyManager -------------------------------------------------------------------------------- /tests/arrays/tests.py: -------------------------------------------------------------------------------- 1 | from __future__ import unicode_literals 2 | 3 | from django.test import TestCase 4 | from .models import Product 5 | from django_postgres_extensions.models.functions import * 6 | from django_postgres_extensions.models.expressions import F, Value as V, Index, SliceArray 7 | from django.db.utils import ProgrammingError, DataError 8 | 
from django.db import transaction 9 | from unittest import skip 10 | 11 | class ArrayCharsIndexTests(TestCase): 12 | 13 | def setUp(self): 14 | super(ArrayCharsIndexTests, self).setUp() 15 | self.product = Product(name='xyz', tags=['Music', 'Album', 'Rock'], moretags=['Very Popular']) 16 | self.product.save() 17 | self.queryset = Product.objects.filter(pk=self.product.pk) 18 | obj = Product.objects.annotate(Index('tags', 1)).get() 19 | 20 | def tearDown(self): 21 | self.queryset.delete() 22 | 23 | def test_array_values(self): 24 | product = self.queryset.get() 25 | array_values = product.tags 26 | self.assertListEqual(array_values, ['Music', 'Album', 'Rock']) 27 | 28 | def test_array_index(self): 29 | with transaction.atomic(): 30 | obj = self.queryset.annotate(Index('tags', 1)).get() 31 | self.assertEqual(obj.tags__1, 'Album') 32 | 33 | def test_array_index_with_output_field(self): 34 | with transaction.atomic(): 35 | obj = self.queryset.annotate(tag_1=Index('tags', 1)).get() 36 | self.assertEqual(obj.tag_1, 'Album') 37 | 38 | def test_array_index_slice(self): 39 | with transaction.atomic(): 40 | obj = self.queryset.annotate(SliceArray('tags', 0, 1)).get() 41 | self.assertEqual(obj.tags__0_1, ['Music', 'Album']) 42 | 43 | def test_array_update_index(self): 44 | with transaction.atomic(): 45 | self.queryset.update(tags__2='Heavy Metal') 46 | product = self.queryset.get() 47 | self.assertListEqual(product.tags, ['Music', 'Album', 'Heavy Metal']) 48 | 49 | def test_array_int_converted(self): 50 | with transaction.atomic(): 51 | self.queryset.update(tags__2=1) 52 | product = self.queryset.get() 53 | self.assertListEqual(product.tags, ['Music', 'Album', '1']) 54 | 55 | class ArrayCharsFuncTests(TestCase): 56 | def setUp(self): 57 | super(ArrayCharsFuncTests, self).setUp() 58 | self.product = Product(name='xyz', tags=['Music', 'Album', 'Rock'], moretags=['Very Popular']) 59 | self.product.save() 60 | self.queryset = Product.objects.filter(id=self.product.id) 61 | 62 | 
def tearDown(self): 63 | self.queryset.delete() 64 | 65 | def test_array_length(self): 66 | with transaction.atomic(): 67 | obj = self.queryset.annotate(tags_length=ArrayLength('tags', 1)).get() 68 | self.assertEqual(obj.tags_length, 3) 69 | 70 | def test_array_append(self): 71 | with transaction.atomic(): 72 | obj = self.queryset.annotate(tags_appended=ArrayAppend('tags', 'Popular')).get() 73 | self.assertListEqual(obj.tags_appended, ['Music', 'Album', 'Rock', 'Popular']) 74 | with transaction.atomic(): 75 | self.queryset.update(tags = ArrayAppend('tags', 'Popular')) 76 | product = self.queryset.get() 77 | self.assertListEqual(product.tags, ['Music', 'Album', 'Rock', 'Popular']) 78 | 79 | def test_array_char_raises(self): 80 | with transaction.atomic(): 81 | self.assertRaises(ProgrammingError, self.queryset.update, tags=ArrayAppend('tags', 1)) 82 | 83 | def test_array_prepend(self): 84 | with transaction.atomic(): 85 | self.queryset.update(tags = ArrayPrepend('Popular', 'tags')) 86 | product = self.queryset.get() 87 | self.assertListEqual(product.tags, ['Popular', 'Music', 'Album', 'Rock']) 88 | 89 | def test_array_remove(self): 90 | with transaction.atomic(): 91 | self.queryset.update(tags = ArrayRemove('tags', 'Album')) 92 | product = self.queryset.get() 93 | self.assertListEqual(product.tags, ['Music', 'Rock']) 94 | 95 | def test_array_cat(self): 96 | with transaction.atomic(): 97 | self.queryset.update(tags = ArrayCat('tags', 'moretags')) 98 | product = self.queryset.get() 99 | self.assertListEqual(product.tags, ['Music', 'Album', 'Rock', 'Very Popular']) 100 | 101 | def test_array_cat_list(self): 102 | with transaction.atomic(): 103 | self.queryset.update(tags=ArrayCat('tags', ['Popular', '8'], output_field=Product._meta.get_field('tags'))) 104 | product = self.queryset.get() 105 | self.assertListEqual(product.tags, ['Music', 'Album', 'Rock', 'Popular', '8']) 106 | 107 | def test_array_replace(self): 108 | with transaction.atomic(): 109 | 
self.queryset.update(tags = ArrayReplace('tags', 'Rock', 'Heavy Metal')) 110 | product = self.queryset.get() 111 | self.assertListEqual(product.tags, ['Music', 'Album', 'Heavy Metal']) 112 | 113 | def test_array_position(self): 114 | with transaction.atomic(): 115 | obj = self.queryset.annotate(position=ArrayPosition('tags', 'Rock')).get() 116 | self.assertEqual(obj.position, 3) 117 | 118 | def test_array_positions(self): 119 | with transaction.atomic(): 120 | self.queryset.update(tags = ArrayPrepend('Rock', 'tags')) 121 | with transaction.atomic(): 122 | obj = self.queryset.annotate(positions=ArrayPositions('tags', 'Rock')).get() 123 | self.assertEqual(obj.positions, [1, 4]) 124 | 125 | class ArrayCharsCatTests(TestCase): 126 | def setUp(self): 127 | super(ArrayCharsCatTests, self).setUp() 128 | self.product = Product(name='xyz', tags=['Music', 'Album', 'Rock'], moretags=['Very Popular']) 129 | self.product.save() 130 | self.queryset = Product.objects.filter(id=self.product.id) 131 | self.prod1_queryset = self.queryset.filter(id=self.product.id) 132 | 133 | def tearDown(self): 134 | self.queryset.delete() 135 | 136 | def test_array_cat_append(self): 137 | with transaction.atomic(): 138 | self.queryset.update(tags=F('tags').cat(V(['Popular'], output_field = Product._meta.get_field('tags')))) 139 | product = self.queryset.get() 140 | self.assertListEqual(product.tags, ['Music', 'Album', 'Rock', 'Popular']) 141 | 142 | def test_array_cat_arrays(self): 143 | with transaction.atomic(): 144 | self.queryset.update(tags=F('tags').cat(F('moretags'))) 145 | product = self.queryset.get() 146 | self.assertListEqual(product.tags, ['Music', 'Album', 'Rock', 'Very Popular']) 147 | 148 | def test_array_char_raises(self): 149 | with transaction.atomic(): 150 | self.assertRaises((ProgrammingError), self.queryset.update, tags=F('tags').cat(V(1))) 151 | 152 | class ArrayIntTests(TestCase): 153 | def setUp(self): 154 | super(ArrayIntTests, self).setUp() 155 | self.product = 
Product(name='xyz', prices=[0, 1, 2]) 156 | self.product.save() 157 | self.queryset = Product.objects.filter(id=self.product.id) 158 | 159 | def test_array_int_index(self): 160 | with transaction.atomic(): 161 | obj = self.queryset.annotate(Index('prices', 1)).get() 162 | self.assertEqual(obj.prices__1, 1) 163 | 164 | def test_array_int_update_index(self): 165 | with transaction.atomic(): 166 | self.queryset.update(prices__2=3) 167 | product = self.queryset.get() 168 | self.assertListEqual(product.prices, [0, 1, 3]) 169 | 170 | def test_array_int_append(self): 171 | with transaction.atomic(): 172 | self.queryset.update(prices=ArrayAppend('prices', 3)) 173 | product = self.queryset.get() 174 | self.assertListEqual(product.prices, [0, 1, 2, 3]) 175 | 176 | def test_array_int_cat_append(self): 177 | with transaction.atomic(): 178 | self.queryset.update(prices=F('prices').cat(V(3))) 179 | product = self.queryset.get() 180 | self.assertListEqual(product.prices, [0,1,2,3]) 181 | 182 | def test_array_int_cat_append_list(self): 183 | with transaction.atomic(): 184 | self.queryset.update(prices=F('prices').cat(V([3, 4]))) 185 | product = self.queryset.get() 186 | self.assertListEqual(product.prices, [0,1,2, 3, 4]) 187 | 188 | def test_array_int_cat_prepend(self): 189 | with transaction.atomic(): 190 | self.queryset.update(prices=V(-1).cat(F('prices'))) 191 | product = self.queryset.get() 192 | self.assertListEqual(product.prices, [-1, 0, 1, 2]) 193 | 194 | def test_array_int_double_cat(self): 195 | with transaction.atomic(): 196 | self.queryset.update(prices=V(-1).cat(F('prices').cat(V(3)))) 197 | product = self.queryset.get() 198 | self.assertListEqual(product.prices, [-1, 0, 1, 2, 3]) 199 | 200 | def test_array_int_raises(self): 201 | self.assertRaises(DataError, self.queryset.update, prices=ArrayAppend('prices', 'test')) 202 | 203 | class ArrayMultiDimensionalTests(TestCase): 204 | def setUp(self): 205 | super(ArrayMultiDimensionalTests, self).setUp() 206 | self.product 
= Product(name='xyz', coordinates=[[0,15, 25], [15,30, 40], [45, 60, 90]]) 207 | self.product.save() 208 | self.queryset = Product.objects.filter(id=self.product.id) 209 | 210 | def tearDown(self): 211 | self.queryset.delete() 212 | 213 | def test_2d_array_values(self): 214 | product = self.queryset.get() 215 | array_values = product.coordinates 216 | self.assertListEqual(array_values, [[0,15, 25], [15, 30, 40], [45, 60, 90]]) 217 | 218 | def test_2d_array_dimensions(self): 219 | with transaction.atomic(): 220 | obj = self.queryset.annotate(coordinates_dims = ArrayDims('coordinates')).get() 221 | self.assertEqual(obj.coordinates_dims, '[1:3][1:3]') 222 | 223 | def test_2d_array_upper(self): 224 | with transaction.atomic(): 225 | obj = self.queryset.annotate(coordinates_upper_1 = ArrayUpper('coordinates', 1)).get() 226 | self.assertEqual(obj.coordinates_upper_1, 3) 227 | 228 | def test_2d_array_lower(self): 229 | with transaction.atomic(): 230 | obj = self.queryset.annotate(coordinates_lower_2 = ArrayLower('coordinates', 2)).get() 231 | self.assertEqual(obj.coordinates_lower_2, 1) 232 | 233 | def test_2d_array_length(self): 234 | with transaction.atomic(): 235 | obj = self.queryset.annotate(coordinates_length_2=ArrayLength('coordinates', 2)).get() 236 | self.assertEqual(obj.coordinates_length_2, 3) 237 | 238 | def test_2d_array_cardinality(self): 239 | with transaction.atomic(): 240 | obj = self.queryset.annotate(coordinates_cardinality=Cardinality('coordinates')).get() 241 | self.assertEqual(obj.coordinates_cardinality, 9) 242 | 243 | def test_2d_array_index(self): 244 | with transaction.atomic(): 245 | obj = self.queryset.annotate(Index(Index('coordinates', 2), 1)).get() 246 | self.assertEqual(obj.coordinates__2__1, 60) 247 | 248 | def test_2d_array_index_slice_1(self): 249 | with transaction.atomic(): 250 | obj = self.queryset.annotate(SliceArray(SliceArray('coordinates', 0, 2), 0, 0)).get() 251 | self.assertEqual(obj.coordinates__0_2__0_0, [[0], [15], [45]]) 252 
| with transaction.atomic(): 253 | obj = self.queryset.annotate(SliceArray(SliceArray('coordinates', 0, 2), 1, 1)).get() 254 | self.assertEqual(obj.coordinates__0_2__1_1, [[15], [30], [60]]) 255 | 256 | def test_2d_array_index_slice_2(self): 257 | with transaction.atomic(): 258 | obj = self.queryset.annotate(SliceArray(SliceArray('coordinates', 1, 1), 1, 2)).get() 259 | self.assertEqual(obj.coordinates__1_1__1_2, [[30, 40]]) 260 | 261 | def test_2d_array_index_slice_3(self): 262 | with transaction.atomic(): 263 | obj = self.queryset.annotate(SliceArray(SliceArray('coordinates', 1, 2), 1, 2)).get() 264 | self.assertEqual(obj.coordinates__1_2__1_2, [[30, 40], [60, 90]]) 265 | 266 | @skip("Update seems to run but change is not made '{'") 267 | def test_2d_array_update_index_1(self): 268 | with transaction.atomic(): 269 | self.queryset.update(tags__1=[70, 90, 100]) 270 | product = self.queryset.get() 271 | self.assertListEqual(product.coordinates, [[0,15, 25], [70, 90, 100], [45, 60, 90]]) 272 | 273 | @skip("Update seems to run but unable to retrieve object due to DataError: array does not start with '{'") 274 | def test_2d_array_update_index_2(self): 275 | with transaction.atomic(): 276 | self.queryset.update(tags__1__2=35) 277 | product = self.queryset.get() 278 | self.assertListEqual(product.coordinates, [[0,15, 25], [15, 35, 40], [45, 60, 90]]) 279 | 280 | @skip("Fails with 'No function matches the given name and argument types. 
You might need to add explicit type casts.'") 281 | def test_2d_array_append_1(self): 282 | with transaction.atomic(): 283 | obj = self.queryset.annotate(coordinates_appended=ArrayAppend('coordinates', [70, 90, 100])).get() 284 | self.assertListEqual(obj.coordinates_appended, [[0,15, 25], [15, 35, 40], [45, 60, 90], [70, 90, 100]]) 285 | with transaction.atomic(): 286 | self.queryset.update(coordinates=ArrayAppend('coordinates', [70, 90, 100])) 287 | product = self.queryset.get() 288 | self.assertListEqual(product.coordinates, [[0,15, 25], [15, 35, 40], [45, 60, 90], [70, 90, 100]]) 289 | 290 | @skip( 291 | "Fails with 'No function matches the given name and argument types. " 292 | "You might need to add explicit type casts.'") 293 | def test_2d_array_append_2(self): 294 | with transaction.atomic(): 295 | obj = self.queryset.update(coordinates=ArrayAppend(Index('coordinates', 1), [100])) 296 | product = self.queryset.get() 297 | self.assertListEqual(obj.coordinates, [[0, 15, 25], [15, 35, 40, 100], [45, 60, 90]]) -------------------------------------------------------------------------------- /django_postgres_extensions/models/fields/related.py: -------------------------------------------------------------------------------- 1 | from django_postgres_extensions.models.fields import ArrayField 2 | from django.db.models.fields.related import RelatedField 3 | from django.db.models.query_utils import PathInfo 4 | from .reverse_related import ArrayManyToManyRel 5 | from .related_descriptors import MultiReferenceDescriptor 6 | from django.db import models 7 | from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT, lazy_related_operation 8 | from django.forms.models import ModelMultipleChoiceField 9 | from django.utils import six 10 | from django.utils.encoding import force_text 11 | from .related_lookups import RelatedArrayContains, RelatedArrayExact, RelatedArrayContainedBy, RelatedContainsItem, \ 12 | RelatedArrayOverlap, RelatedAnyGreaterThan, 
RelatedAnyLessThanOrEqual, RelatedAnyLessThan, RelatedAnyGreaterThanOrEqual 13 | 14 | class ArrayManyToManyField(ArrayField, RelatedField): 15 | # Field flags 16 | many_to_many_array = True 17 | many_to_many = False 18 | many_to_one = False 19 | one_to_many = False 20 | one_to_one = False 21 | 22 | rel_class = ArrayManyToManyRel 23 | 24 | def __init__(self, to_model, base_field=None, size=None, related_name=None, symmetrical=None, 25 | related_query_name=None, limit_choices_to=None, to_field=None, db_constraint=False, **kwargs): 26 | 27 | try: 28 | to = to_model._meta.model_name 29 | except AttributeError: 30 | assert isinstance(to_model, six.string_types), ( 31 | "%s(%r) is invalid. First parameter to ForeignKey must be " 32 | "either a model, a model name, or the string %r" % ( 33 | self.__class__.__name__, to_model, 34 | RECURSIVE_RELATIONSHIP_CONSTANT, 35 | ) 36 | ) 37 | to = str(to_model) 38 | else: 39 | # For backwards compatibility purposes, we need to *try* and set 40 | # the to_field during FK construction. It won't be guaranteed to 41 | # be correct until contribute_to_class is called. Refs #12190. 
42 | to_field = to_field or (to_model._meta.pk and to_model._meta.pk.name) 43 | if not base_field: 44 | field = to_model._meta.get_field(to_field) 45 | if not field.is_relation: 46 | base_field_type = type(field) 47 | internal_type = field.get_internal_type() 48 | if internal_type == 'AutoField': 49 | pass 50 | elif internal_type == 'BigAutoField': 51 | base_field = models.BigIntegerField() 52 | elif hasattr(field, 'max_length'): 53 | base_field = base_field_type(max_length = field.max_length) 54 | else: 55 | base_field = base_field_type() 56 | 57 | if not base_field: 58 | base_field = models.IntegerField() 59 | 60 | if symmetrical is None: 61 | symmetrical = (to == RECURSIVE_RELATIONSHIP_CONSTANT) 62 | 63 | kwargs['rel'] = self.rel_class( 64 | self, to, to_field, 65 | related_name=related_name, 66 | related_query_name=related_query_name, 67 | limit_choices_to=limit_choices_to, 68 | symmetrical=symmetrical, 69 | ) 70 | self.has_null_arg = 'null' in kwargs 71 | 72 | self.db_constraint = db_constraint 73 | 74 | self.to = to 75 | 76 | if 'default' not in kwargs.keys(): 77 | kwargs['default'] = [] 78 | kwargs['blank'] = True 79 | 80 | self.from_fields = ['self'] 81 | self.to_fields = [to_field] 82 | 83 | super(ArrayManyToManyField, self).__init__(base_field, size=size, **kwargs) 84 | 85 | def deconstruct(self): 86 | name, path, args, kwargs = super(ArrayManyToManyField, self).deconstruct() 87 | args = (self.to,) 88 | kwargs.update({ 89 | 'base_field': self.base_field, 90 | 'size': self.size, 91 | 'related_name': self.remote_field.related_name, 92 | 'symmetrical': self.remote_field.symmetrical, 93 | 'related_query_name': self.remote_field.related_query_name, 94 | 'limit_choices_to': self.remote_field.limit_choices_to, 95 | 'to_field': self.remote_field.field, 96 | 'db_constraint': self.db_constraint 97 | }) 98 | return name, path, args, kwargs 99 | 100 | def get_attname(self): 101 | return '%s_ids' % self.name 102 | 103 | def get_attname_column(self): 104 | attname = 
self.get_attname() 105 | column = self.db_column or attname 106 | return attname, column 107 | 108 | def get_accessor_name(self): 109 | return self.remote_field.model_name + '_set' 110 | 111 | def get_reverse_accessor_name(self): 112 | return self.remote_field.get_accessor_name() 113 | 114 | def contribute_to_class(self, cls, name, **kwargs): 115 | # To support multiple relations to self, it's useful to have a non-None 116 | # related name on symmetrical relations for internal reasons. The 117 | # concept doesn't make a lot of sense externally ("you want me to 118 | # specify *what* on my non-reversible relation?!"), so we set it up 119 | # automatically. The funky name reduces the chance of an accidental 120 | # clash. 121 | if self.remote_field.symmetrical and ( 122 | self.remote_field.model == "self" or self.remote_field.model == cls._meta.object_name): 123 | self.remote_field.related_name = "%s_rel_+" % name 124 | elif self.remote_field.is_hidden(): 125 | # If the backwards relation is disabled, replace the original 126 | # related_name with one generated from the m2m field name. Django 127 | # still uses backwards relations internally and we need to avoid 128 | # clashes between multiple m2m fields with related_name == '+'. 
129 | self.remote_field.related_name = "_%s_%s_+" % (cls.__name__.lower(), name) 130 | 131 | super(ArrayManyToManyField, self).contribute_to_class(cls, name) 132 | 133 | if not cls._meta.abstract: 134 | setattr(cls, self.name, MultiReferenceDescriptor(self.remote_field, reverse=False)) 135 | 136 | self.opts = cls._meta 137 | if self.remote_field.related_name: 138 | related_name = force_text(self.remote_field.related_name) % { 139 | 'class': cls.__name__.lower(), 140 | 'app_label': cls._meta.app_label.lower() 141 | } 142 | self.remote_field.related_name = related_name 143 | 144 | def resolve_related_class(model, related, field): 145 | field.remote_field.model = related 146 | field.do_related_class(related, model) 147 | 148 | lazy_related_operation(resolve_related_class, cls, self.remote_field.model, field=self) 149 | 150 | def contribute_to_related_class(self, cls, related): 151 | # Internal M2Ms (i.e., those with a related name ending with '+') 152 | # and swapped models don't get a related descriptor. 153 | if not self.remote_field.is_hidden() and not related.related_model._meta.swapped: 154 | setattr(cls, self.get_reverse_accessor_name(), MultiReferenceDescriptor(self.remote_field, reverse=True)) 155 | 156 | def formfield(self, **kwargs): 157 | db = kwargs.pop('using', None) 158 | defaults = { 159 | 'form_class': ModelMultipleChoiceField, 160 | 'queryset': self.related_model._default_manager.using(db), 161 | } 162 | defaults.update(kwargs) 163 | # If initial is passed in, it's a list of related objects, but the 164 | # MultipleChoiceField takes a list of IDs. 
165 | if defaults.get('initial') is not None: 166 | initial = defaults['initial'] 167 | if callable(initial): 168 | initial = initial() 169 | defaults['initial'] = [i._get_pk_val() for i in initial] 170 | return super(RelatedField, self).formfield(**defaults) 171 | 172 | def get_join_on(self, parent_alias, lhs_col, table_alias, rhs_col): 173 | return '%s.%s = ANY(%s.%s)' % ( 174 | table_alias, 175 | rhs_col, 176 | parent_alias, 177 | lhs_col, 178 | ) 179 | 180 | def get_join_on2(self, parent_alias, lhs_col, table_alias, rhs_col): 181 | return "ARRAY_APPEND(ARRAY[]::integer[], %s.%s) <@ ANY(%s.%s)" % ( 182 | table_alias, 183 | rhs_col, 184 | parent_alias, 185 | lhs_col, 186 | ) 187 | 188 | def resolve_related_fields(self): 189 | if len(self.from_fields) < 1 or len(self.from_fields) != len(self.to_fields): 190 | raise ValueError('Foreign Object from and to fields must be the same non-zero length') 191 | if isinstance(self.remote_field.model, six.string_types): 192 | raise ValueError('Related model %r cannot be resolved' % self.remote_field.model) 193 | related_fields = [] 194 | for index in range(len(self.from_fields)): 195 | from_field_name = self.from_fields[index] 196 | to_field_name = self.to_fields[index] 197 | from_field = (self if from_field_name == 'self' 198 | else self.opts.get_field(from_field_name)) 199 | to_field = (self.remote_field.model._meta.pk if to_field_name is None 200 | else self.remote_field.model._meta.get_field(to_field_name)) 201 | related_fields.append((from_field, to_field)) 202 | return related_fields 203 | 204 | 205 | @property 206 | def related_fields(self): 207 | if not hasattr(self, '_related_fields'): 208 | self._related_fields = self.resolve_related_fields() 209 | return self._related_fields 210 | 211 | 212 | @property 213 | def reverse_related_fields(self): 214 | return [(rhs_field, lhs_field) for lhs_field, rhs_field in self.related_fields] 215 | 216 | 217 | @property 218 | def local_related_fields(self): 219 | return 
tuple(lhs_field for lhs_field, rhs_field in self.related_fields) 220 | 221 | 222 | @property 223 | def foreign_related_fields(self): 224 | return tuple(rhs_field for lhs_field, rhs_field in self.related_fields if rhs_field) 225 | 226 | 227 | def get_local_related_value(self, instance): 228 | return self.get_instance_value_for_fields(instance, self.local_related_fields) 229 | 230 | 231 | def get_foreign_related_value(self, instance): 232 | return self.get_instance_value_for_fields(instance, self.foreign_related_fields) 233 | 234 | 235 | @staticmethod 236 | def get_instance_value_for_fields(instance, fields): 237 | ret = [] 238 | opts = instance._meta 239 | for field in fields: 240 | # Gotcha: in some cases (like fixture loading) a model can have 241 | # different values in parent_ptr_id and parent's id. So, use 242 | # instance.pk (that is, parent_ptr_id) when asked for instance.id. 243 | if field.primary_key: 244 | possible_parent_link = opts.get_ancestor_link(field.model) 245 | if (not possible_parent_link or 246 | possible_parent_link.primary_key or 247 | possible_parent_link.model._meta.abstract): 248 | ret.append(instance.pk) 249 | continue 250 | ret.append(getattr(instance, field.attname)) 251 | return tuple(ret) 252 | 253 | def get_joining_columns(self, reverse_join=False): 254 | source = self.reverse_related_fields if reverse_join else self.related_fields 255 | columns = tuple((lhs_field.column, rhs_field.column) for lhs_field, rhs_field in source) 256 | return columns 257 | 258 | def get_reverse_joining_columns(self): 259 | return self.get_joining_columns(reverse_join=True) 260 | 261 | def get_extra_descriptor_filter(self, instance): 262 | """ 263 | Return an extra filter condition for related object fetching when 264 | user does 'instance.column', that is the extra filter is used in 265 | the descriptor of the field. 266 | 267 | The filter should be either a dict usable in .filter(**kwargs) call or 268 | a Q-object. 
The condition will be ANDed together with the relation's 269 | joining columns. 270 | 271 | A parallel method is get_extra_restriction() which is used in 272 | JOIN and subquery conditions. 273 | """ 274 | return {} 275 | 276 | def get_extra_restriction(self, where_class, alias, related_alias): 277 | """ 278 | Return a pair condition used for joining and subquery pushdown. The 279 | condition is something that responds to as_sql(compiler, connection) 280 | method. 281 | 282 | Note that currently referring both the 'alias' and 'related_alias' 283 | will not work in some conditions, like subquery pushdown. 284 | 285 | A parallel method is get_extra_descriptor_filter() which is used in 286 | instance.column related object fetching. 287 | """ 288 | return None 289 | 290 | def get_path_info(self, filtered_relation=None): 291 | """ 292 | Get path from this field to the related model. 293 | """ 294 | opts = self.remote_field.model._meta 295 | from_opts = self.model._meta 296 | return [PathInfo(from_opts, opts, self.foreign_related_fields, self, False, True, filtered_relation)] 297 | 298 | def validate_item(self, obj, model=None): 299 | if not model: 300 | model = self.remote_field.model 301 | if isinstance(obj, model): 302 | obj = getattr(obj, self.remote_field.target_field.name) 303 | elif isinstance(obj, models.Model): 304 | raise TypeError( 305 | "'%s' instance expected, got %r" % 306 | (model._meta.object_name, obj) 307 | ) 308 | return obj 309 | 310 | def save_form_data(self, instance, data): 311 | """ 312 | For newly created instances, the column is set and saved with all the 313 | other fields when model.save() is run. So m2m can be updated along with 314 | all the other fields with one sql query. 315 | For updateing instances with model.save(), Array M2M fields are ignored to avoid sending large 316 | amounts of unneccessary data in the update query. Instead the data is set with an 317 | extra sql command via the descriptor. 
318 | """ 319 | if instance.pk: 320 | getattr(instance, self.name).set(data) 321 | else: 322 | objs = [self.validate_item(obj) for obj in data] 323 | setattr(instance, self.attname, objs) 324 | 325 | def get_reverse_path_info(self, filtered_relation): 326 | """ 327 | Get path from the related model to this field's model. 328 | """ 329 | opts = self.model._meta 330 | from_opts = self.remote_field.model._meta 331 | pathinfos = [PathInfo(from_opts, opts, (from_opts.pk,), self.remote_field, not self.unique, False, filtered_relation)] 332 | return pathinfos 333 | 334 | def get_lookup(self, lookup_name): 335 | if lookup_name == 'in': 336 | return RelatedArrayOverlap 337 | elif lookup_name == 'exact': 338 | return RelatedContainsItem 339 | elif lookup_name =='exactly': 340 | return RelatedArrayExact 341 | elif lookup_name == 'contains': 342 | return RelatedArrayContains 343 | elif lookup_name == 'contained_by': 344 | return RelatedArrayContainedBy 345 | elif lookup_name == 'overlap': 346 | return RelatedArrayOverlap 347 | elif lookup_name == 'gt': 348 | return RelatedAnyGreaterThan 349 | elif lookup_name == 'gte': 350 | return RelatedAnyGreaterThanOrEqual 351 | elif lookup_name == 'lt': 352 | return RelatedAnyLessThan 353 | elif lookup_name == 'lte': 354 | return RelatedAnyLessThanOrEqual 355 | else: 356 | raise TypeError('Related Array got invalid lookup: %s' % lookup_name) --------------------------------------------------------------------------------