├── requirements.txt
├── tests
├── jsonb
│ ├── __init__.py
│ ├── models.py
│ └── tests.py
├── arrays
│ ├── __init__.py
│ ├── models.py
│ └── tests.py
├── benchmarks
│ ├── __init__.py
│ └── models.py
├── hstores
│ ├── __init__.py
│ ├── models.py
│ └── tests.py
├── modeladmin
│ ├── __init__.py
│ ├── urls.py
│ ├── admin.py
│ ├── models.py
│ └── tests.py
├── m2m_multiple_array
│ ├── __init__.py
│ ├── models.py
│ └── tests.py
├── m2m_recursive_array
│ ├── __init__.py
│ ├── models.py
│ └── tests.py
├── many_to_many_array
│ ├── __init__.py
│ └── models.py
├── nested_form_field_widget
│ ├── __init__.py
│ └── tests.py
├── prefetch_related_array
│ ├── __init__.py
│ ├── test_uuid.py
│ ├── test_prefetch_related_objects.py
│ └── models.py
├── urls.py
└── test_postgres.py
├── requirements_docs.txt
├── setup.cfg
├── django_postgres_extensions
├── admin
│ ├── __init__.py
│ └── options.py
├── backends
│ ├── __init__.py
│ └── postgresql
│ │ ├── __init__.py
│ │ ├── operations.py
│ │ ├── schema.py
│ │ ├── base.py
│ │ └── creation.py
├── models
│ ├── sql
│ │ ├── __init__.py
│ │ ├── updates.py
│ │ ├── subqueries.py
│ │ ├── datastructures.py
│ │ └── compiler.py
│ ├── __init__.py
│ ├── fields
│ │ ├── related_lookups.py
│ │ ├── reverse_related.py
│ │ ├── __init__.py
│ │ ├── related_descriptors.py
│ │ └── related.py
│ ├── expressions.py
│ ├── lookups.py
│ ├── functions.py
│ └── query.py
├── forms
│ ├── __init__.py
│ ├── fields.py
│ └── widgets.py
├── __init__.py
├── templates
│ └── django_postgres_extensions
│ │ └── nested_form_widget.html
├── signals.py
├── apps.py
└── utils.py
├── MANIFEST.in
├── docs
├── admin_form.jpg
├── admin_list.jpg
├── array_split.jpg
├── json_field.jpg
├── array_choice.jpg
├── hstore_field.jpg
├── queryset.rst
├── intro.rst
├── settings.py
├── Makefile
├── testing.rst
├── index.rst
├── make.bat
├── arraym2m.rst
├── json.rst
├── hstores.rst
├── arrays.rst
├── conf.py
└── features.rst
├── .gitignore
├── description.rst
├── setup.py
├── LICENSE
└── readme.rst
/requirements.txt:
--------------------------------------------------------------------------------
1 | django
--------------------------------------------------------------------------------
/tests/jsonb/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/arrays/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/benchmarks/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/hstores/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/modeladmin/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/m2m_multiple_array/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/m2m_recursive_array/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/many_to_many_array/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirements_docs.txt:
--------------------------------------------------------------------------------
1 | sphinxcontrib-fancybox
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [bdist_wheel]
2 | universal = 1
--------------------------------------------------------------------------------
/tests/nested_form_field_widget/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/prefetch_related_array/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/django_postgres_extensions/admin/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/django_postgres_extensions/backends/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/django_postgres_extensions/backends/postgresql/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICENSE
2 | include readme.rst
3 | include description.rst
--------------------------------------------------------------------------------
/django_postgres_extensions/models/sql/__init__.py:
--------------------------------------------------------------------------------
1 | from .subqueries import UpdateQuery
--------------------------------------------------------------------------------
/django_postgres_extensions/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .fields import *
2 | from .fields.related import *
--------------------------------------------------------------------------------
/docs/admin_form.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/primal100/django_postgres_extensions/HEAD/docs/admin_form.jpg
--------------------------------------------------------------------------------
/docs/admin_list.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/primal100/django_postgres_extensions/HEAD/docs/admin_list.jpg
--------------------------------------------------------------------------------
/docs/array_split.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/primal100/django_postgres_extensions/HEAD/docs/array_split.jpg
--------------------------------------------------------------------------------
/docs/json_field.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/primal100/django_postgres_extensions/HEAD/docs/json_field.jpg
--------------------------------------------------------------------------------
/docs/array_choice.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/primal100/django_postgres_extensions/HEAD/docs/array_choice.jpg
--------------------------------------------------------------------------------
/docs/hstore_field.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/primal100/django_postgres_extensions/HEAD/docs/hstore_field.jpg
--------------------------------------------------------------------------------
/django_postgres_extensions/forms/__init__.py:
--------------------------------------------------------------------------------
1 | from .fields import NestedFormField
2 | from .widgets import NestedFormWidget
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.log
3 | .idea/*
4 | build/*
5 | dist/*
6 | docs/_*
7 | django_postgres_extensions.egg-info/*
8 | *.env
9 | todo.txt
--------------------------------------------------------------------------------
/django_postgres_extensions/__init__.py:
--------------------------------------------------------------------------------
# Dotted path of the AppConfig Django should use for this app.
default_app_config = 'django_postgres_extensions.apps.PSQLExtensionsConfig'
# Package version; setup.py reads this value when building a distribution.
__version__ = '0.9.3'
3 |
--------------------------------------------------------------------------------
/tests/modeladmin/urls.py:
--------------------------------------------------------------------------------
1 | from django.conf.urls import url
2 | from django.contrib import admin
3 |
# Mount the default admin site under /admin/ for the ModelAdmin tests.
urlpatterns = [url(r'^admin/', admin.site.urls)]
7 |
--------------------------------------------------------------------------------
/tests/urls.py:
--------------------------------------------------------------------------------
"""Root URLconf for the test suite.

Django requires ROOT_URLCONF to name an importable module, so this one exists
but routes nothing. Test packages that need URLs declare their own (through
TestCase.urls), which keeps the tests isolated from one another.
"""

urlpatterns = list()
--------------------------------------------------------------------------------
/tests/jsonb/models.py:
--------------------------------------------------------------------------------
1 | from django.db import models
2 | from django_postgres_extensions.models.fields import JSONField
3 |
class Product(models.Model):
    """Test model holding an optional JSONField description."""

    name = models.CharField(max_length=3)
    description = JSONField(blank=True, null=True)
--------------------------------------------------------------------------------
/django_postgres_extensions/backends/postgresql/operations.py:
--------------------------------------------------------------------------------
1 | from django.db.backends.postgresql.operations import DatabaseOperations as BaseDatabaseOperations
2 |
class DatabaseOperations(BaseDatabaseOperations):
    """PostgreSQL operations that point Django at this package's SQL compilers."""

    # Module from which Django loads the SQL compiler classes for this backend.
    compiler_module = "django_postgres_extensions.models.sql.compiler"
--------------------------------------------------------------------------------
/django_postgres_extensions/templates/django_postgres_extensions/nested_form_widget.html:
--------------------------------------------------------------------------------
1 |
{% for widget in widgets %}
2 | -
3 |
4 | {{ widget.html }}
5 |
6 | {% endfor %}
--------------------------------------------------------------------------------
/docs/queryset.rst:
--------------------------------------------------------------------------------
1 | Querysets
2 | =========
3 |
4 | Additional Queryset Methods
5 | ---------------------------
6 | This app adds a ``format`` method to all querysets. It defers a field and adds an annotation that returns that field in a different format.
7 | For example, to return an HStoreField as JSON::
8 |
9 | qs = Model.objects.all().format('description', HstoreToJSONBLoose)
10 |
--------------------------------------------------------------------------------
/django_postgres_extensions/models/sql/updates.py:
--------------------------------------------------------------------------------
class UpdateArrayByIndex(object):
    """Describe an update that targets element(s) of an array column by index.

    :param indexes: iterable of subscripts; each becomes one ``[index]``
        suffix on the column name.
    :param value: the value to be written.
    :param field: the array field; only its ``base_field`` is kept.
    """

    def __init__(self, indexes, value, field):
        self.indexes = indexes
        self.value = value
        self.base_field = field.base_field

    def alter_name(self, name, qn):
        """Return *name* followed by one ``[index]`` subscript per index.

        ``qn`` is accepted for interface compatibility but is not used.
        Indexes are interpolated with ``%s`` formatting, so callers must
        supply trusted values (normally integers).
        """
        subscripts = "".join("[%s]" % position for position in self.indexes)
        return name + subscripts
12 |
13 |
--------------------------------------------------------------------------------
/django_postgres_extensions/signals.py:
--------------------------------------------------------------------------------
def delete_reverse_related(sender, signal, instance, using, **kwargs):
    """Signal receiver that clears array many-to-many links to *instance*.

    For each reverse relation of the instance's model whose field is flagged
    ``many_to_many_array``, look up the reverse accessor on *instance* and
    call its ``clear()``. Relations without the flag are ignored.
    """
    for rel in instance._meta.related_objects:
        field = rel.field
        if not getattr(field, 'many_to_many_array', False):
            continue
        manager = getattr(instance, field.get_reverse_accessor_name())
        manager.clear()
8 |
--------------------------------------------------------------------------------
/tests/hstores/models.py:
--------------------------------------------------------------------------------
1 | from __future__ import unicode_literals
2 |
3 | from django.db import models
4 | from django_postgres_extensions.models.fields import ArrayField, HStoreField
5 |
class Product(models.Model):
    """Test model combining HStoreField columns with an array of hstores."""

    name = models.CharField(max_length=3)
    description = HStoreField(blank=True, null=True)
    details = HStoreField(blank=True, null=True)
    # Array column whose elements are hstore values.
    purchases = ArrayField(HStoreField(), null=True)
--------------------------------------------------------------------------------
/description.rst:
--------------------------------------------------------------------------------
1 | Django Postgres Extensions adds a lot of functionality to django.contrib.postgres, specifically in relation to ArrayField, HStoreField and JSONField, including much better form fields for dealing with these field types. The app also includes an array many-to-many field (ArrayManyToManyField), so you can store a relationship in an array column instead of requiring an extra database table.
2 |
3 | Check out http://django-postgres-extensions.readthedocs.io/en/latest/ to get started.
--------------------------------------------------------------------------------
/tests/modeladmin/admin.py:
--------------------------------------------------------------------------------
1 | from django.contrib import admin
2 | from django_postgres_extensions.admin.options import PostgresAdmin
3 | from .models import Product, Buyer
4 |
class ProductAdmin(PostgresAdmin):
    """ModelAdmin for Product, built on PostgresAdmin."""

    fields = ('name', 'keywords', 'sports', 'shipping', 'details', 'buyers')
    list_display = ('name', 'keywords', 'shipping', 'details', 'country')
    # Render the buyers relation with the two-pane horizontal selector.
    filter_horizontal = ('buyers',)


admin.site.register(Buyer)
admin.site.register(Product, ProductAdmin)
12 |
13 |
--------------------------------------------------------------------------------
/docs/intro.rst:
--------------------------------------------------------------------------------
1 | Getting started
2 | ===============
3 |
4 | .. warning::
5 |
6 | Django Postgres Extensions has been tested with Python 2.7.12, Python 3.6 and Django 1.10.5. It should generally work with other versions, but they have not been tested.
7 |
8 | Installation
9 | -------------
10 |
11 | Install with ``pip install django_postgres_extensions``
12 |
13 | Setup project
14 | -------------
15 |
16 | In your settings.py, add ``'django.contrib.postgres'`` and ``'django_postgres_extensions'`` to ``INSTALLED_APPS`` and configure the database to use the included backend (a subclass of Django's default PostgreSQL backend):
17 |
18 | .. literalinclude:: settings.py
--------------------------------------------------------------------------------
/docs/settings.py:
--------------------------------------------------------------------------------
# Example settings for a project using Django Postgres Extensions.
# Both 'django.contrib.postgres' and 'django_postgres_extensions' must be
# installed, and the database must use the bundled backend.
INSTALLED_APPS = [
    'django.contrib.contenttypes',
    'django.contrib.auth',
    'django.contrib.sites',
    'django.contrib.sessions',
    'django.contrib.messages',
    'django.contrib.admin.apps.SimpleAdminConfig',
    'django.contrib.staticfiles',
    'django.contrib.postgres',
    'django_postgres_extensions',
]

DATABASES = {
    'default': {
        # Subclass of Django's PostgreSQL backend shipped with this package.
        'ENGINE': 'django_postgres_extensions.backends.postgresql',
        'NAME': 'db',
        'USER': 'postgres',
        'PASSWORD': 'postgres',
        'HOST': '127.0.0.1',
        'PORT': 5432,
    },
}
--------------------------------------------------------------------------------
/tests/arrays/models.py:
--------------------------------------------------------------------------------
1 | from __future__ import unicode_literals
2 |
3 | from django.db import models
4 | from django.contrib.postgres.fields import HStoreField
5 | from django_postgres_extensions.models.fields.related import ArrayField
6 |
class Product(models.Model):
    """Test model exercising one- and two-dimensional ArrayFields."""

    name = models.CharField(max_length=3)
    tags = ArrayField(models.CharField(max_length=15), blank=True, null=True)
    moretags = ArrayField(models.CharField(max_length=15), blank=True, null=True)
    prices = ArrayField(models.IntegerField(), blank=True, null=True)
    description = HStoreField(blank=True, null=True)
    # Nested ArrayField: an array whose elements are integer arrays.
    coordinates = ArrayField(ArrayField(models.IntegerField()), null=True)
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
# Minimal makefile for Sphinx documentation
#
# Usage: "make <builder>" (for example "make html"); plain "make" prints help.

# You can set these variables from the command line.
SPHINXOPTS =
SPHINXBUILD = sphinx-build
SPHINXPROJ = DjangoPostgresExtensions
SOURCEDIR = .
BUILDDIR = _build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

# "Makefile" is phony so the catch-all rule below never tries to rebuild it.
.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/docs/testing.rst:
--------------------------------------------------------------------------------
1 | Running Tests
2 | =============
3 |
4 | Running
5 | -------
6 |
7 | To run the tests
8 |
9 | ``$ git clone https://github.com/primal100/django_postgres_extensions.git dpe_repo``
10 |
11 | ``$ cd dpe_repo/tests``
12 |
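13 | Configure the PostgreSQL connection details in ``test_postgres.py``. By default, the database name, user and password are read from the ``POSTGRES_DBNAME``, ``POSTGRES_USER`` and ``POSTGRES_PASSWORD`` environment variables.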
14 |
15 | ``$ ./runtests.py --exclude-tag=benchmark``
16 |
17 | Benchmarks
18 | ----------
19 |
20 | Benchmark tests are included to compare the performance of the array M2M field with Django's traditional table-based M2M field.
21 | They can be quite slow, so it is recommended to exclude them when running the full test suite, as in the example above.
22 |
23 | They can be run with:
24 |
25 | ``$ ./runtests.py benchmarks.tests``
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. Django Postgres Extensions documentation master file, created by
2 | sphinx-quickstart on Wed Feb 8 21:57:16 2017.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Welcome to Django Postgres Extensions's documentation!
7 | ======================================================
8 |
9 | .. toctree::
10 | :maxdepth: 2
11 | :caption: Contents:
12 |
13 | features
14 | intro
15 | arrays
16 | hstores
17 | json
18 | arraym2m
19 | queryset
20 | testing
21 |
22 |
23 | User Guide
24 | ==========
25 | :doc:`intro` describes how to get up and running with Django Postgres Extensions.
26 | :doc:`features` gives a basic overview of the features in Django Postgres Extensions.
--------------------------------------------------------------------------------
/tests/benchmarks/models.py:
--------------------------------------------------------------------------------
1 | from django.db import models
2 | from django_postgres_extensions.models.fields.related import ArrayManyToManyField
3 |
class NumberTraditional(models.Model):
    """Target model for the traditional (join-table) many-to-many benchmark."""

    index = models.IntegerField()

    def __str__(self):
        # __str__ must return a string; returning the IntegerField value
        # directly raises "TypeError: __str__ returned non-string".
        return str(self.index)
9 |
class Traditional(models.Model):
    """Benchmark model linked to NumberTraditional via a standard ManyToManyField."""

    index = models.IntegerField()
    numbers = models.ManyToManyField(NumberTraditional)

    def __str__(self):
        # __str__ must return a string, not the raw integer index.
        return str(self.index)
16 |
class NumberArray(models.Model):
    """Target model for the array-backed many-to-many benchmark."""

    index = models.IntegerField()

    def __str__(self):
        # __str__ must return a string, not the raw integer index.
        return str(self.index)
22 |
class Array(models.Model):
    """Benchmark model linked to NumberArray via ArrayManyToManyField.

    ``db_index=False`` is passed to the array relation (no index requested).
    """

    index = models.IntegerField()
    numbers = ArrayManyToManyField(NumberArray, db_index=False)

    def __str__(self):
        # __str__ must return a string, not the raw integer index.
        return str(self.index)
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
@ECHO OFF

REM Switch to the directory containing this script; popd at :end restores it.
pushd %~dp0

REM Command file for Sphinx documentation

REM Default to sphinx-build on PATH unless SPHINXBUILD is already set.
if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build
set SPHINXPROJ=DjangoPostgresExtensions

REM With no argument, show Sphinx's make-mode help.
if "%1" == "" goto help

REM Probe for sphinx-build; errorlevel 9009 means the command was not found.
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
	echo.
	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
	echo.installed, then set the SPHINXBUILD environment variable to point
	echo.to the full path of the 'sphinx-build' executable. Alternatively you
	echo.may add the Sphinx directory to PATH.
	echo.
	echo.If you don't have Sphinx installed, grab it from
	echo.http://sphinx-doc.org/
	REM NOTE(review): exits without popd, so the caller's directory stays changed.
	exit /b 1
)

REM Run the requested builder, e.g. "make.bat html".
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%

:end
popd
37 |
--------------------------------------------------------------------------------
/django_postgres_extensions/apps.py:
--------------------------------------------------------------------------------
1 | from django.apps import AppConfig
2 | from django.utils.translation import ugettext_lazy as _
3 | from django.db.models import query
4 | from django.db.models.sql import datastructures
5 | from .models.query import update, _update, format, prefetch_one_level
6 | from .models.sql.datastructures import as_sql
7 | from django.db.models.signals import pre_delete
8 | from .signals import delete_reverse_related
9 | from django.conf import settings
10 |
11 |
class PSQLExtensionsConfig(AppConfig):
    """AppConfig for django_postgres_extensions.

    ``ready()`` monkey-patches Django internals with this package's
    implementations.
    """

    name = 'django_postgres_extensions'
    verbose_name = _('Extra features for PostgreSQL fields')

    def ready(self):
        """Install the package's patches once the app registry is ready.

        Always replaces ``QuerySet.format``, ``QuerySet.update`` and
        ``QuerySet._update``. When the ``ENABLE_ARRAY_M2M`` setting is truthy it
        also replaces ``Join.as_sql`` and ``query.prefetch_one_level`` and
        connects ``delete_reverse_related`` to ``pre_delete`` (see the NOTE
        below on where that connect call belongs).
        """
        query.QuerySet.format = format
        query.QuerySet.update = update
        query.QuerySet._update = _update
        # NOTE(review): indentation was lost in the source this block was
        # recovered from; the three statements below are assumed to belong to
        # this branch (in particular pre_delete.connect) -- confirm against the
        # repository before relying on it.
        if getattr(settings, 'ENABLE_ARRAY_M2M', False):
            datastructures.Join.as_sql = as_sql
            query.prefetch_one_level = prefetch_one_level
            pre_delete.connect(delete_reverse_related)
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | import django_postgres_extensions
4 |
def _read(filename):
    """Return the text of *filename*, resolved relative to this setup.py.

    Resolving against this file (not the current working directory) lets the
    build run from any directory; the ``with`` block closes the file handle.
    """
    import io
    import os
    here = os.path.dirname(os.path.abspath(__file__))
    with io.open(os.path.join(here, filename), encoding='utf-8') as handle:
        return handle.read()


setup(name='django_postgres_extensions',
      version=django_postgres_extensions.__version__,
      description="Extra features for django.contrib.postgres",
      long_description=_read('description.rst'),
      author='Paul Martin',
      author_email='greatestloginnameever@gmail.com',
      url='https://github.com/primal100/django_postgres_extensions',
      packages=find_packages(exclude=['tests', 'tests.*']),
      # find_packages() only collects Python modules; without this the
      # nested form widget template is missing from installed copies.
      package_data={
          'django_postgres_extensions': [
              'templates/django_postgres_extensions/*.html',
          ],
      },
      classifiers=[
          'Development Status :: 4 - Beta',
          'Environment :: Web Environment',
          'Framework :: Django',
          'Intended Audience :: Developers',
          'License :: OSI Approved :: BSD License',
          'Operating System :: OS Independent',
          'Programming Language :: Python',
          'Programming Language :: Python :: 2',
          'Programming Language :: Python :: 3',
          'Topic :: Database',
          'Topic :: Software Development :: Libraries :: Python Modules',
      ],
      )
--------------------------------------------------------------------------------
/django_postgres_extensions/backends/postgresql/schema.py:
--------------------------------------------------------------------------------
1 | from django.db.backends.postgresql import schema
2 |
class DatabaseSchemaEditor(schema.DatabaseSchemaEditor):
    """Schema editor that additionally builds GIN indexes on indexed array columns."""

    sql_create_array_index = "CREATE INDEX %(name)s ON %(table)s USING GIN (%(columns)s)%(extra)s"

    def _model_indexes_sql(self, model):
        """Return the parent's index statements plus GIN statements for arrays.

        Unmanaged, proxy and swapped models get only the parent's statements.
        """
        statements = super(DatabaseSchemaEditor, self)._model_indexes_sql(model)
        meta = model._meta
        if meta.managed and not meta.proxy and not meta.swapped:
            for field in meta.local_fields:
                gin_statement = self._create_array_index_sql(model, field)
                if gin_statement is not None:
                    statements.append(gin_statement)
        return statements

    def _create_array_index_sql(self, model, field):
        """Return a GIN ``CREATE INDEX`` statement for *field*, or ``None``.

        A field qualifies when its column type contains ``[`` and ends with
        ``]`` (an array type) and it has ``db_index`` or ``unique`` set.
        """
        column_type = field.db_type(connection=self.connection)
        is_array = column_type is not None and '[' in column_type and column_type.endswith(']')
        if not (is_array and (field.db_index or field.unique)):
            return None
        return self._create_index_sql(model, [field], suffix='_gin', sql=self.sql_create_array_index)
--------------------------------------------------------------------------------
/tests/m2m_multiple_array/models.py:
--------------------------------------------------------------------------------
1 | """
2 | Multiple many-to-many relationships between the same two tables
3 |
4 | In this example, an ``Article`` can have many "primary" ``Category`` objects
5 | and many "secondary" ``Category`` objects.
6 |
7 | Set ``related_name`` to designate what the reverse relationship is called.
8 | """
9 |
10 | from django.db import models
11 | from django.utils.encoding import python_2_unicode_compatible
12 | from django_postgres_extensions.models.fields.related import ArrayManyToManyField
13 |
@python_2_unicode_compatible
class Category(models.Model):
    """A category articles can be filed under (as primary or secondary)."""

    name = models.CharField(max_length=20)

    def __str__(self):
        return self.name

    class Meta:
        ordering = ('name',)
23 |
24 |
@python_2_unicode_compatible
class Article(models.Model):
    """An article with two separate array-backed relations to Category.

    Both relations target the same model, so each sets its own related_name
    to keep the reverse accessors on Category distinct.
    """

    headline = models.CharField(max_length=50)
    pub_date = models.DateTimeField()
    primary_categories = ArrayManyToManyField(
        Category, related_name='primary_article_set')
    secondary_categories = ArrayManyToManyField(
        Category, related_name='secondary_article_set')

    class Meta:
        ordering = ('pub_date',)

    def __str__(self):
        return self.headline
37 |
--------------------------------------------------------------------------------
/tests/m2m_recursive_array/models.py:
--------------------------------------------------------------------------------
1 | """
2 | Many-to-many relationships between the same two tables
3 |
4 | In this example, a ``Person`` can have many friends, who are also ``Person``
5 | objects. Friendship is a symmetrical relationship - if I am your friend, you
6 | are my friend. Here, ``friends`` is an example of a symmetrical
7 | ``ManyToManyField``.
8 |
9 | A ``Person`` can also have many idols - but while I may idolize you, you may
10 | not think the same of me. Here, ``idols`` is an example of a non-symmetrical
11 | ``ManyToManyField``. Only recursive ``ManyToManyField`` fields may be
12 | non-symmetrical, and they are symmetrical by default.
13 |
14 | This test validates that the many-to-many table is created using a mangled name
15 | if there is a name clash, and tests that symmetry is preserved where
16 | appropriate.
17 | """
18 |
19 | from django.db import models
20 | from django.utils.encoding import python_2_unicode_compatible
21 | from django_postgres_extensions.models.fields.related import ArrayManyToManyField
22 |
@python_2_unicode_compatible
class Person(models.Model):
    """A person with recursive ``friends`` and one-way ``idols`` relations."""

    name = models.CharField(max_length=20)
    # No ``symmetrical`` argument, so the field's default applies.
    friends = ArrayManyToManyField('self')
    # Explicitly non-symmetrical; the reverse accessor is named ``stalkers``.
    idols = ArrayManyToManyField('self', related_name='stalkers', symmetrical=False)

    def __str__(self):
        return self.name
31 |
--------------------------------------------------------------------------------
/tests/test_postgres.py:
--------------------------------------------------------------------------------
1 | # Settings used to run the django_postgres_extensions test suite.
2 | #
3 | # The tests need a running PostgreSQL server. The database name, user and
4 | # password are read from the POSTGRES_DBNAME, POSTGRES_USER and
5 | # POSTGRES_PASSWORD environment variables (see DATABASES below); set those
6 | # variables, or edit the values below, to match your server. For background
7 | # on settings files used for test runs, see:
8 | #
9 | # https://docs.djangoproject.com/en/dev/internals/contributing/writing-code/unit-tests/
10 | #
11 | # ENABLE_ARRAY_M2M is turned on so that the tests cover the array-based
12 | # many-to-many support, which is only patched in when that setting is set.
13 | #
14 |
15 | import os
16 |
17 |
DATABASES = {
    'default': {
        'ENGINE': 'django_postgres_extensions.backends.postgresql',
        'NAME': os.environ.get('POSTGRES_DBNAME', 'db'),
        'USER': os.environ.get('POSTGRES_USER'),
        'PASSWORD': os.environ.get('POSTGRES_PASSWORD'),
        # Optional overrides. An empty HOST is Django's default (local
        # socket connection), and the port defaults to 5432 as before.
        'HOST': os.environ.get('POSTGRES_HOST', ''),
        'PORT': int(os.environ.get('POSTGRES_PORT', 5432)),
    }
}

# Enable the array many-to-many patches installed by the app's ready().
ENABLE_ARRAY_M2M = True

SECRET_KEY = "django_tests_secret_key"

# Use a fast hasher to speed up tests.
PASSWORD_HASHERS = [
    'django.contrib.auth.hashers.MD5PasswordHasher',
]
36 |
--------------------------------------------------------------------------------
/django_postgres_extensions/backends/postgresql/base.py:
--------------------------------------------------------------------------------
1 | from django.db.backends.postgresql.base import DatabaseWrapper as BaseDatabaseWrapper
2 | from .schema import DatabaseSchemaEditor
3 | from .creation import DatabaseCreation
4 | from .operations import DatabaseOperations
5 |
class DatabaseWrapper(BaseDatabaseWrapper):
    """PostgreSQL wrapper using this package's schema editor, creation and ops.

    Also exposes ``any_operators`` and ``all_operators``: per-lookup SQL
    fragments of the form ``<operator> ANY(%s)`` / ``<operator> ALL(%s)``.
    """

    SchemaEditorClass = DatabaseSchemaEditor

    def __init__(self, *args, **kwargs):
        super(DatabaseWrapper, self).__init__(*args, **kwargs)
        self.creation = DatabaseCreation(self)
        self.ops = DatabaseOperations(self)

        # Lookup name -> SQL comparison operator, shared by both tables below.
        comparison_sql = {
            'exact': '=',
            'in': 'LIKE',
            'gt': '<',
            'gte': '<=',
            'lt': '>',
            'lte': '>=',
            'startof': 'LIKE',
            'endof': 'LIKE',
            'contains': '<@',
        }
        self.any_operators = {
            lookup: operator + ' ANY(%s)' for lookup, operator in comparison_sql.items()
        }
        self.all_operators = {
            lookup: operator + ' ALL(%s)' for lookup, operator in comparison_sql.items()
        }
39 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) Paul Martin and all contributors.
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without modification,
5 | are permitted provided that the following conditions are met:
6 |
7 | 1. Redistributions of source code must retain the above copyright notice,
8 | this list of conditions and the following disclaimer.
9 |
10 | 2. Redistributions in binary form must reproduce the above copyright
11 | notice, this list of conditions and the following disclaimer in the
12 | documentation and/or other materials provided with the distribution.
13 |
14 | 3. My name may not be used to endorse or promote products
15 | derived from this software without specific prior written permission.
16 |
17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
18 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
19 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
20 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
21 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
22 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
23 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
24 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/tests/many_to_many_array/models.py:
--------------------------------------------------------------------------------
1 | """
2 | Array Many-to-many relationships
3 | Tests taken from Django 1.9 Stable
4 | In this example, an ``Article`` can be published in multiple ``Publication``
5 | objects, and a ``Publication`` has multiple ``Article`` objects.
6 | """
7 | from __future__ import unicode_literals
8 |
9 | from django.db import models
10 | from django.utils.encoding import python_2_unicode_compatible
11 | from django_postgres_extensions.models.fields.related import ArrayManyToManyField
12 |
13 |
@python_2_unicode_compatible
class Publication(models.Model):
    """A publication that articles can appear in; ordered by title by default."""
    title = models.CharField(max_length=30)

    def __str__(self):
        return self.title

    class Meta:
        ordering = ('title',)
23 |
24 |
@python_2_unicode_compatible
class Article(models.Model):
    """An article linked to any number of publications through an ArrayManyToManyField."""
    headline = models.CharField(max_length=100)
    # Assign a unicode string as name to make sure the intermediary model is
    # correctly created. Refs #20207
    # NOTE(review): this comment comes from Django's many_to_many tests; the
    # array-based field stores relations in an array column rather than an
    # intermediary table, so the original rationale may not apply here.
    publications = ArrayManyToManyField(Publication, name='publications')

    def __str__(self):
        return self.headline

    class Meta:
        ordering = ('headline',)
37 |
38 |
# Models to test correct related_name inheritance
class AbstractArticle(models.Model):
    """
    Abstract base declaring a ``publications`` relation with
    ``related_name='+'`` (no reverse accessor), inherited by the two
    concrete models below.
    """
    class Meta:
        abstract = True

    publications = ArrayManyToManyField(Publication, name='publications', related_name='+')


class InheritedArticleA(AbstractArticle):
    """First concrete subclass; inherits ``publications`` unchanged."""
    pass


class InheritedArticleB(AbstractArticle):
    """Second concrete subclass sharing the same inherited field definition."""
    pass
--------------------------------------------------------------------------------
/tests/modeladmin/models.py:
--------------------------------------------------------------------------------
1 | from django.db import models
2 | from django_postgres_extensions.models.fields import HStoreField, JSONField, ArrayField
3 | from django_postgres_extensions.models.fields.related import ArrayManyToManyField
4 | from django import forms
5 | from django.contrib.postgres.forms import SplitArrayField
6 | from django_postgres_extensions.forms.fields import NestedFormField
7 |
# Sub-form definition for Product.details: a sequence of (name, form field)
# tuples passed to JSONField(fields=...) below. Presumably it is forwarded to
# NestedFormField, whose ``fields`` argument expects exactly this shape and
# uses each name as a key of the resulting dict.
details_fields = (
    ('Brand', NestedFormField(keys=('Name', 'Country'))),
    ('Type', forms.CharField(max_length=25, required=False)),
    ('Colours', SplitArrayField(base_field=forms.CharField(max_length=10, required=False), size=10)),
)
13 |
class Buyer(models.Model):
    """A buyer, referenced from Product.buyers (an ArrayManyToManyField)."""
    # Set once when the row is first created.
    time = models.DateTimeField(auto_now_add=True)
    name = models.CharField(max_length=20)

    def __str__(self):
        return self.name
20 |
class Product(models.Model):
    """Test model for the modeladmin tests, combining the extension's array, hstore, JSON and array-M2M fields."""
    name = models.CharField(max_length=15)
    # NOTE(review): ``default=[]`` / ``default={}`` on the fields below are
    # single mutable instances; Django's docs recommend callables (``list`` /
    # ``dict``) so instances do not share them. A callable default presumably
    # also makes admin forms render a hidden initial input, so confirm the
    # modeladmin tests still pass before changing these.
    keywords = ArrayField(models.CharField(max_length=20), default=[], form_size=10, blank=True)
    sports = ArrayField(models.CharField(max_length=20),default=[], blank=True, choices=(
        ('football', 'Football'), ('tennis', 'Tennis'), ('golf', 'Golf'), ('basketball', 'Basketball'), ('hurling', 'Hurling'), ('baseball', 'Baseball')))
    # hstore edited as a nested form with one input per listed key.
    shipping = HStoreField(keys=('Address', 'City', 'Region', 'Country'), blank=True, default={})
    details = JSONField(fields=details_fields, blank=True, default={})
    buyers = ArrayManyToManyField(Buyer)

    def __str__(self):
        return self.name

    @property
    def country(self):
        """The 'Country' entry of the shipping hstore, or '' when it is absent."""
        return self.shipping.get('Country', '')
36 |
--------------------------------------------------------------------------------
/django_postgres_extensions/models/sql/subqueries.py:
--------------------------------------------------------------------------------
1 | from django.db.models.sql.subqueries import UpdateQuery as BaseUpdateQuery
2 | from django.utils import six
3 | from django.core.exceptions import FieldError
4 |
class UpdateQuery(BaseUpdateQuery):
    """UpdateQuery that additionally accepts ``field__part`` style keys in update()."""
    def add_update_values(self, values):
        """
        Convert a dictionary of field name to value mappings into an update
        query. This is the entry point for the public update() method on
        querysets.

        A key containing ``__`` (e.g. ``description__`` or
        ``description__del``) is split into the field name and the remaining
        parts; the field's ``get_update_type(parts, val)`` builds the value to
        store, and the result is queued directly, without the relation and
        parent-model checks applied to plain field names.

        Raises FieldError when a plain name refers to a reverse relation or a
        many-to-many field.
        """
        values_seq = []
        for name, val in six.iteritems(values):
            if '__' in name:
                # 'description__del' -> field 'description', indexes ['del']
                indexes = name.split('__')
                field_name = indexes.pop(0)
                field = self.get_meta().get_field(field_name)
                val = field.get_update_type(indexes, val)
                model = field.model
            else:
                field = self.get_meta().get_field(name)
                # Presumably mirrors Django's own UpdateQuery.add_update_values checks.
                direct = not (field.auto_created and not field.concrete) or not field.concrete
                model = field.model._meta.concrete_model
                if not direct or (field.is_relation and field.many_to_many):
                    raise FieldError(
                        'Cannot update model field %r (only non-relations and '
                        'foreign keys permitted).' % field
                    )
                else:
                    if model is not self.get_meta().model:
                        # Field belongs to another (presumably parent) model:
                        # queue a separate update for that model instead.
                        self.add_related_update(model, field, val)
                        continue
            values_seq.append((field, model, val))
        return self.add_update_fields(values_seq)
--------------------------------------------------------------------------------
/django_postgres_extensions/utils.py:
--------------------------------------------------------------------------------
import collections

try:
    from collections.abc import MutableSet
except ImportError:  # Python 2: the ABCs live directly in ``collections``
    from collections import MutableSet
2 |
class OrderedSet(MutableSet):
    """
    A mutable set that remembers insertion order.

    Membership is tracked in ``self.map`` (key -> node) and order in a
    circular doubly linked list of ``[key, prev, next]`` nodes anchored at the
    sentinel ``self.end``, so add/discard/contains are O(1) and iteration
    follows insertion order.

    With thanks to http://code.activestate.com/recipes/576694/
    """
    def __init__(self, iterable=None):
        self.end = end = []
        end += [None, end, end]  # sentinel node for doubly linked list
        self.map = {}  # key --> [key, prev, next]
        if iterable is not None:
            self |= iterable

    def __len__(self):
        return len(self.map)

    def __contains__(self, key):
        return key in self.map

    def add(self, key):
        """Append *key* unless already present (an existing key keeps its position)."""
        if key not in self.map:
            end = self.end
            curr = end[1]
            curr[2] = end[1] = self.map[key] = [key, curr, end]

    def discard(self, key):
        """Remove *key* if present; do nothing otherwise."""
        if key in self.map:
            key, prev_node, next_node = self.map.pop(key)
            prev_node[2] = next_node
            next_node[1] = prev_node

    def __iter__(self):
        end = self.end
        curr = end[2]
        while curr is not end:
            yield curr[0]
            curr = curr[2]

    def __reversed__(self):
        end = self.end
        curr = end[1]
        while curr is not end:
            yield curr[0]
            curr = curr[1]

    def pop(self, last=True):
        """
        Remove and return the newest element (or the oldest when ``last`` is false).

        Raises KeyError if the set is empty.
        """
        if not self:
            raise KeyError('set is empty')
        key = self.end[1][0] if last else self.end[2][0]
        self.discard(key)
        return key

    def __repr__(self):
        if not self:
            return '%s()' % (self.__class__.__name__,)
        return '%s(%r)' % (self.__class__.__name__, list(self))

    def __eq__(self, other):
        """
        Order-sensitive against another OrderedSet, order-insensitive against
        any other iterable. Non-iterable operands are not comparable, so
        ``NotImplemented`` lets Python fall back to its default (unequal)
        instead of raising TypeError.
        """
        if isinstance(other, OrderedSet):
            return len(self) == len(other) and list(self) == list(other)
        try:
            other_set = set(other)
        except TypeError:
            return NotImplemented
        return set(self) == other_set
--------------------------------------------------------------------------------
/docs/arraym2m.rst:
--------------------------------------------------------------------------------
1 | Array Many To Many Field
2 | ========================
3 |
4 | Basic Usage
5 | -----------
6 |
7 | The Array Many To Many Field is designed to be a drop-in replacement for the normal Django Many To Many Field
8 | except that it uses an array instead of a separate table to store relationships, but replicates many of the same features.
9 | In general, write queries are much faster than with the traditional M2M field; however, select queries are typically slower.
10 |
11 | To use this field, it is required to set ENABLE_ARRAY_M2M = True in settings.py (to enable the required monkey-patching)::
12 |
13 | ENABLE_ARRAY_M2M = True
14 |
15 | Then in models.py::
16 |
17 | from django.db import models
18 | from django_postgres_extensions.models.fields import ArrayManyToManyField
19 |
20 | class Publication(models.Model):
21 | title = models.CharField(max_length=30)
22 |
23 | def __str__(self):
24 | return self.title
25 |
26 | class Meta:
27 | ordering = ('title',)
28 |
29 | class Article(models.Model):
30 | headline = models.CharField(max_length=100)
31 | publications = ArrayManyToManyField(Publication, name='publications')
32 |
33 | def __str__(self):
34 | return self.headline
35 |
36 | class Meta:
37 | ordering = ('headline',)
38 |
39 | The Array Many To Many field supports the following features which replicate the API of the regular Many To Many Field:
40 |
41 | - Descriptor queryset with add, remove, clear and set for both forward and reverse relationships
42 | - Prefetch related for both forward and reverse relationships
43 | - Lookups across relationships with filter for both forward and reverse relationships
44 | - Lookups across relationships with exclude for forward relationships only
45 |
46 | You can find more information on how these features work in the Django documentation for the regular Many To Many Field:
47 |
48 | https://docs.djangoproject.com/en/1.9/topics/db/examples/many_to_many/
49 |
--------------------------------------------------------------------------------
/django_postgres_extensions/models/fields/related_lookups.py:
--------------------------------------------------------------------------------
1 | from django.db.models.fields.related_lookups import RelatedLookupMixin, MultiColSource, get_normalized_value
2 | from django.contrib.postgres.fields.array import ArrayContains, ArrayContainedBy, ArrayExact, ArrayOverlap
3 |
4 | from django_postgres_extensions.models.lookups import (
5 | AnyExact, AnyGreaterThan, AnyLessThan, AnyGreaterThanOrEqual, AnyLessThanOrEqual, ContainsItem)
6 |
class RelatedArrayMixin(RelatedLookupMixin):
    """
    Lookup mixin for array relations whose right-hand side is a sequence of
    related objects/values rather than a single value.
    """
    def get_prep_lookup(self):
        # Start from 'contains'; switched to 'exact' below once the elements
        # have been prepared. NOTE(review): presumably the lookup name steers
        # how the parent lookup prepares rhs — confirm against the field's
        # get_prep_lookup for the supported Django versions.
        self.lookup_name = 'contains'
        if not isinstance(self.lhs, MultiColSource) and self.rhs_is_direct_value():
            # Normalise every element (e.g. model instances) individually.
            self.rhs = [get_normalized_value(value, self.lhs)[0] for value in self.rhs]
            if hasattr(self.lhs.output_field, 'get_path_info'):
                # Prepare each element with the relation's target field.
                self.rhs = [self.lhs.output_field.get_path_info()[-1].target_fields[-1].get_prep_value(rhs) for rhs in
                            self.rhs]
                self.lookup_name = 'exact'
        # super(RelatedLookupMixin, self) resumes the MRO *after*
        # RelatedLookupMixin, deliberately bypassing its single-value
        # preparation in favour of the element-wise handling above.
        return super(RelatedLookupMixin, self).get_prep_lookup()
17 |
class RelatedAnyExact(RelatedLookupMixin, AnyExact):
    """``AnyExact`` applied across a relation (value normalisation from RelatedLookupMixin)."""


class RelatedAnyGreaterThan(RelatedLookupMixin, AnyGreaterThan):
    """``AnyGreaterThan`` applied across a relation."""


class RelatedAnyLessThan(RelatedLookupMixin, AnyLessThan):
    """``AnyLessThan`` applied across a relation."""


class RelatedAnyGreaterThanOrEqual(RelatedLookupMixin, AnyGreaterThanOrEqual):
    """``AnyGreaterThanOrEqual`` applied across a relation."""


class RelatedAnyLessThanOrEqual(RelatedLookupMixin, AnyLessThanOrEqual):
    """``AnyLessThanOrEqual`` applied across a relation."""
32 |
class RelatedArrayExact(RelatedArrayMixin, ArrayExact):
    """
    Array equality across a relation: checks that the array holds the given
    related objects and only those (closer to what ``exact`` should mean for
    a relation).
    """


class RelatedArrayContains(RelatedArrayMixin, ArrayContains):
    """``ArrayContains`` with element-wise preparation of related values."""


class RelatedContainsItem(RelatedArrayMixin, ContainsItem):
    """``ContainsItem`` with element-wise preparation of related values."""


class RelatedArrayContainedBy(RelatedArrayMixin, ArrayContainedBy):
    """``ArrayContainedBy`` with element-wise preparation of related values."""


class RelatedArrayOverlap(RelatedArrayMixin, ArrayOverlap):
    """``ArrayOverlap`` with element-wise preparation of related values."""
--------------------------------------------------------------------------------
/django_postgres_extensions/models/sql/datastructures.py:
--------------------------------------------------------------------------------
1 | from django.db.models.sql.datastructures import Join as BaseJoin
2 |
3 | #class Join(BaseJoin):
4 |
def as_sql(self, compiler, connection):
    """
    Build the ``<join type> <table> [alias] ON (...)`` clause for a join.

    ``self`` is a Join-like object (join_cols, join_field, parent_alias,
    table_alias, table_name, join_type). When the join field provides
    ``get_join_on``, each column pair is rendered by that hook instead of the
    default ``parent.lhs = table.rhs`` equality. Any extra restriction from
    ``join_field.get_extra_restriction`` is ANDed in parentheses.

    Returns ``(sql, params)``; raises ValueError when no ON condition results.
    """
    quote_alias = compiler.quote_name_unless_alias
    quote_column = connection.ops.quote_name
    join_field = self.join_field
    uses_custom_on = hasattr(join_field, 'get_join_on')

    conditions = []
    for lhs_col, rhs_col in self.join_cols:
        quoted = (
            quote_alias(self.parent_alias),
            quote_column(lhs_col),
            quote_alias(self.table_alias),
            quote_column(rhs_col),
        )
        if uses_custom_on:
            conditions.append(join_field.get_join_on(*quoted))
        else:
            conditions.append('%s.%s = %s.%s' % quoted)

    params = []
    extra_cond = join_field.get_extra_restriction(
        compiler.query.where_class, self.table_alias, self.parent_alias)
    if extra_cond:
        extra_sql, extra_params = compiler.compile(extra_cond)
        conditions.append('(%s)' % extra_sql)
        params.extend(extra_params)

    if not conditions:
        # This might be a rel on the other end of an actual declared field.
        declared_field = getattr(join_field, 'field', join_field)
        raise ValueError(
            "Join generated an empty ON clause. %s did not yield either "
            "joining columns or extra restrictions." % declared_field.__class__
        )

    if self.table_alias == self.table_name:
        alias_str = ''
    else:
        alias_str = ' %s' % self.table_alias
    sql = '%s %s%s ON (%s)' % (
        self.join_type, quote_alias(self.table_name), alias_str, ' AND '.join(conditions))
    return sql, params
--------------------------------------------------------------------------------
/django_postgres_extensions/models/fields/reverse_related.py:
--------------------------------------------------------------------------------
1 | from django.db.models.fields.reverse_related import ForeignObjectRel
2 | from django.db.models.fields.related_lookups import RelatedExact, RelatedIn, RelatedGreaterThan, RelatedIsNull, \
3 | RelatedLessThan, RelatedGreaterThanOrEqual, RelatedLessThanOrEqual
4 | from django.core import exceptions
5 |
class ArrayManyToManyRel(ForeignObjectRel):
    """
    Relation object for an array-based many-to-many field (presumably
    ``ArrayManyToManyField``).

    ``_meta.get_fields()`` returns this class to provide access to the field
    flags for the reverse relation. Joins across the relation match a column
    against any element of an array column (see ``get_join_on``).
    """

    def __init__(self, field, to, field_name, related_name=None, related_query_name=None,
                 limit_choices_to=None, symmetrical=True):
        super(ArrayManyToManyRel, self).__init__(
            field,
            to,
            related_name=related_name,
            related_query_name=related_query_name,
            limit_choices_to=limit_choices_to,
        )
        self.model_name = to
        self.field_name = field_name
        self.symmetrical = symmetrical

    def get_join_on(self, parent_alias, lhs_col, table_alias, rhs_col):
        """Return an ON condition matching ``parent.lhs`` against any element of the array ``table.rhs``."""
        lhs = '%s.%s' % (parent_alias, lhs_col)
        array_column = '%s.%s' % (table_alias, rhs_col)
        return '%s = ANY(%s)' % (lhs, array_column)

    def set_field_name(self):
        """Default ``field_name`` to the related model's primary-key name when unset."""
        if not self.field_name:
            self.field_name = self.model._meta.pk.name

    def get_related_field(self):
        """
        Return the Field in the 'to' object to which this relationship is tied.

        Raises FieldDoesNotExist when that field is not concrete.
        """
        field = self.model._meta.get_field(self.field_name)
        if field.concrete:
            return field
        raise exceptions.FieldDoesNotExist("No related field named '%s'" %
                                           self.field_name)

    def get_lookup(self, lookup_name):
        """Return the related lookup class for *lookup_name*; raise TypeError for anything unsupported."""
        supported = {
            'in': RelatedIn,
            'exact': RelatedExact,
            'gt': RelatedGreaterThan,
            'gte': RelatedGreaterThanOrEqual,
            'lt': RelatedLessThan,
            'lte': RelatedLessThanOrEqual,
            'isnull': RelatedIsNull,
        }
        lookup_class = supported.get(lookup_name)
        if lookup_class is None:
            raise TypeError('Related Field got invalid lookup: %s' % lookup_name)
        return lookup_class
--------------------------------------------------------------------------------
/tests/m2m_multiple_array/tests.py:
--------------------------------------------------------------------------------
1 | from __future__ import unicode_literals
2 |
3 | from datetime import datetime
4 |
5 | from django.test import TestCase
6 |
7 | from .models import Article, Category
8 |
9 |
class M2MMultipleTests(TestCase):
    """Two ArrayManyToMany relations from Article to the same Category model."""

    def test_multiple(self):
        sports, news, crime, life = [
            Category.objects.create(name=name)
            for name in ["Sports", "News", "Crime", "Life"]
        ]

        steals = Article.objects.create(
            headline="Parrot steals", pub_date=datetime(2005, 11, 27)
        )
        steals.primary_categories.add(news, crime)
        steals.secondary_categories.add(life)

        runs = Article.objects.create(
            headline="Parrot runs", pub_date=datetime(2005, 11, 28)
        )
        runs.primary_categories.add(sports, news)
        runs.secondary_categories.add(life)

        def category_name(category):
            return category.name

        def article_headline(article):
            return article.headline

        # (owner, related manager attribute, expected values, transform)
        expectations = [
            (steals, 'primary_categories', ["Crime", "News"], category_name),
            (runs, 'primary_categories', ["News", "Sports"], category_name),
            (steals, 'secondary_categories', ["Life"], category_name),
            (sports, 'primary_article_set', ["Parrot runs"], article_headline),
            (sports, 'secondary_article_set', [], article_headline),
            (news, 'primary_article_set', ["Parrot steals", "Parrot runs"], article_headline),
            (news, 'secondary_article_set', [], article_headline),
            (crime, 'primary_article_set', ["Parrot steals"], article_headline),
            (crime, 'secondary_article_set', [], article_headline),
            (life, 'primary_article_set', [], article_headline),
            (life, 'secondary_article_set', ["Parrot steals", "Parrot runs"], article_headline),
        ]
        for owner, attribute, expected, transform in expectations:
            self.assertQuerysetEqual(getattr(owner, attribute).all(), expected, transform)
87 |
--------------------------------------------------------------------------------
/docs/json.rst:
--------------------------------------------------------------------------------
1 | JSONField
2 | =========
3 |
4 | Basic Usage
5 | -----------
6 |
7 | To use the JSON field::
8 |
9 | from django.db import models
10 | from django_postgres_extensions.models.fields import JSONField
11 |
12 | class Product(models.Model):
13 | description = JSONField(null=True, blank=True)
14 |
15 | Individual keys
16 | ---------------
17 |
18 | - Get json values by key or key path::
19 |
20 | from django_postgres_extensions.models.expressions import Key
21 | obj = Product.objects.annotate(Key('description', 'Details')).get()
22 | obj = Product.objects.annotate(Key('description', 'Details__Rating')).get()
23 | obj = Product.objects.annotate(Key('description', 'Tags__1')).get()
24 |
25 | - Update JSON Field by specific keys, leaving any others untouched::
26 |
27 | Product.objects.update(description__ = {'Industry': 'Movie', 'Popularity': 'Very Popular'})
28 |
29 | - Delete JSONField by key or key path::
30 |
31 | Product.objects.update(description__del ='Details')
32 | Product.objects.update(description__del = 'Details__Release')
33 | Product.objects.update(description__del='Tags__1')
34 |
35 | Database functions
36 | ------------------
37 |
38 | Various database functions are included for interacting with JSONFields:
39 |
40 | - JSONBSet: updates individual keys in the JSONField without modifying the others.
41 |
42 | - JSONBArrayLength: returns the number of elements in a JSONField whose top-level value is an array.
43 |
44 |
45 | Check the postgresql documentation for more information on these functions.
46 | These functions handle the arguments by converting them to the correct expressions automatically::
47 |
48 | from django_postgres_extensions.models.functions import *
49 | from psycopg2.extras import Json
50 | obj = Product.objects.update(description = JSONBSet('description', ['Details', 'Genre'], Json('Heavy Metal'), True))
51 | obj = Product.objects.update(description = JSONBSet('description', ['1', 'c'], Json('g')))
52 | obj = Product.objects.annotate(tags_length=JSONBArrayLength('tags', 1)).get()
53 |
54 | Use With NestedFormField
55 | ------------------------
56 |
57 | The same NestedFormField and NestedFormWidget referred to in the HStore description can also be used with a JSONField.
58 | To use it, pass the ``fields`` keyword argument: a sequence of ``(name, form field)`` tuples::
59 |
60 | details_fields = (
61 | ('Brand', NestedFormField(keys=('Name', 'Country'))),
62 | ('Type', forms.CharField(max_length=25, required=False)),
63 | ('Colours', SplitArrayField(base_field=forms.CharField(max_length=10, required=False), size=10)),
64 | )
65 |
66 | class Product(models.Model):
67 | details = JSONField(fields=details_fields, blank=True, default={})
68 |
69 | The field would look like:
70 |
71 | .. image:: json_field.jpg
--------------------------------------------------------------------------------
/django_postgres_extensions/forms/fields.py:
--------------------------------------------------------------------------------
1 | from .widgets import NestedFormWidget
2 | from django.forms.fields import MultiValueField, CharField
3 | from django.core.exceptions import ValidationError
4 |
5 |
class NestedFormField(MultiValueField):
    """
    A Field that aggregates the logic of multiple Fields to create a nested form within a form.

    The compress method returns a dictionary of field names and values (an empty dictionary
    when no sub-values were submitted).

    Requires either a ``fields`` or ``keys`` argument but not both.

    The ``fields`` argument is a list of tuples; each tuple consisting of a field name and a field instance.
    If given a nested form will be created consisting of these fields.

    The ``keys`` argument is a list/tuple of field names. If given, a nested form will be created consisting of
    django.forms.CharField instances, with the given key names. This is primarily for use with
    django.contrib.postgres.HStoreField. By default, none of the fields are required.

    To make all fields required set the ``require_all_fields`` argument to True.

    The ``max_value_length`` is ignored if the ``fields`` argument is given. If the ``keys`` argument is given, the max
    length for each CharField instance will be set to this value.

    Uses the NestedFormWidget.
    """
    # to_python() raises with this code. MultiValueField does not define it,
    # so without this entry the error lookup itself would fail with KeyError.
    default_error_messages = {
        'invalid_json': 'Enter a valid dictionary of values.',
    }

    def __init__(self, fields=(), keys=(), require_all_fields=False, max_value_length=25, *args, **kwargs):
        if (fields and keys) or (not fields and not keys):
            raise ValueError("NestedFormField requires either a tuple of fields or keys but not both")

        if keys:
            # One optional CharField per key, named after the key.
            fields = [(key, CharField(max_length=max_value_length, required=False)) for key in keys]

        form_fields = []
        widgets = []
        self.labels = []  # display labels, in field order
        self.names = {}  # label -> name used as the dict key by compress()
        for item in fields:
            name, form_field = item[0], item[1]
            label = form_field.label or name
            self.names[label] = name
            self.labels.append(label)
            form_fields.append(form_field)
            widgets.append(form_field.widget)
        widget = NestedFormWidget(self.labels, widgets, self.names)
        super(NestedFormField, self).__init__(*args, fields=form_fields, widget=widget,
                                              require_all_fields=require_all_fields, **kwargs)

    def compress(self, data_list):
        """
        Combine the cleaned sub-field values into a ``{field name: value}`` dict.

        Django's MultiValueField.clean() calls ``compress([])`` when every
        sub-value is empty and the field is optional; return an empty dict
        then instead of failing with IndexError.
        """
        if not data_list:
            return {}
        result = {}
        for i, label in enumerate(self.labels):
            result[self.names[label]] = data_list[i]
        return result

    def to_python(self, value):
        """Return ``{}`` for empty input, a dict unchanged; raise ValidationError otherwise."""
        if not value:
            return {}
        if isinstance(value, dict):
            return value
        else:
            raise ValidationError(
                self.error_messages['invalid_json'],
                code='invalid_json',
            )
68 |
--------------------------------------------------------------------------------
/django_postgres_extensions/backends/postgresql/creation.py:
--------------------------------------------------------------------------------
1 | from django.contrib.postgres.signals import register_type_handlers
2 | from django.db.backends.postgresql.creation import DatabaseCreation as BaseDatabaseCreation
3 | from django.conf import settings
4 |
class DatabaseCreation(BaseDatabaseCreation):
    """PostgreSQL test-database creation that also enables the ``hstore`` extension."""
    def create_test_db(self, verbosity=1, autoclobber=False, serialize=True, keepdb=False):
        """
        Creates a test database, prompting the user for confirmation if the
        database already exists. Returns the name of the test database created.

        After creating the database and pointing the connection at it, runs
        ``CREATE EXTENSION IF NOT EXISTS hstore`` and registers the postgres
        type handlers before migrating; then optionally serializes the
        database contents, creates the cache table and reconnects.
        """
        # Don't import django.core.management if it isn't needed.
        from django.core.management import call_command

        test_database_name = self._get_test_db_name()

        if verbosity >= 1:
            action = 'Creating'
            if keepdb:
                action = "Using existing"

            print("%s test database for alias %s..." % (
                action,
                self._get_database_display_str(verbosity, test_database_name),
            ))

        # We could skip this call if keepdb is True, but we instead
        # give it the keepdb param. This is to handle the case
        # where the test DB doesn't exist, in which case we need to
        # create it, then just not destroy it. If we instead skip
        # this, we will get an exception.
        self._create_test_db(verbosity, autoclobber, keepdb)

        # Point both the global settings and this connection at the test DB.
        self.connection.close()
        settings.DATABASES[self.connection.alias]["NAME"] = test_database_name
        self.connection.settings_dict["NAME"] = test_database_name

        # Enable hstore before 'migrate' below, presumably so hstore columns
        # can be created.
        with self.connection.cursor() as cursor:
            for extension in ('hstore',):
                cursor.execute("CREATE EXTENSION IF NOT EXISTS %s" % extension)
        # Register the postgres type adapters now that the extension exists.
        register_type_handlers(self.connection)

        # We report migrate messages at one level lower than that requested.
        # This ensures we don't get flooded with messages during testing
        # (unless you really ask to be flooded).
        call_command(
            'migrate',
            verbosity=max(verbosity - 1, 0),
            interactive=False,
            database=self.connection.alias,
            run_syncdb=True,
        )

        # We then serialize the current state of the database into a string
        # and store it on the connection. This slightly horrific process is so people
        # who are testing on databases without transactions or who are using
        # a TransactionTestCase still get a clean database on every test run.
        if serialize:
            self.connection._test_serialized_contents = self.serialize_db_to_string()

        call_command('createcachetable', database=self.connection.alias)

        # Ensure a connection for the side effect of initializing the test database.
        self.connection.ensure_connection()

        return test_database_name
66 |
--------------------------------------------------------------------------------
/django_postgres_extensions/models/sql/compiler.py:
--------------------------------------------------------------------------------
1 | from django.db.models.sql.compiler import (SQLCompiler, SQLInsertCompiler, SQLUpdateCompiler as BaseUpdateCompiler,
2 | SQLAggregateCompiler, SQLDeleteCompiler)
3 | from django.core.exceptions import FieldError
4 |
def no_quote_name(name):
    """
    Identity "quoting" function: hand back *name* exactly as received.

    Stands in for ``quote_name_unless_alias`` when a column reference has
    already been rendered by an ``alter_name`` hook and must not be quoted
    a second time.
    """
    unquoted = name
    return unquoted
7 |
class SQLUpdateCompiler(BaseUpdateCompiler):
    def as_sql(self):
        """
        Creates the SQL for this query. Returns the SQL string and list of
        parameters.

        On top of Django's behaviour, an update value exposing ``alter_name``
        may rewrite the left-hand side of its own ``SET`` assignment: the name
        it returns is used verbatim (no further quoting) and the value is then
        replaced by its ``value`` attribute. While each assignment is built,
        the column name and placeholder are stored on ``self.name`` and
        ``self.placeholder`` (presumably read by custom expressions during
        ``self.compile``).

        Raises FieldError for aggregates in update values and TypeError when a
        model instance is assigned to a non-relational field.
        """
        self.pre_sql_setup()
        if not self.query.values:
            return '', ()
        table = self.query.base_table
        qn = self.quote_name_unless_alias
        result = ['UPDATE %s' % qn(table)]
        result.append('SET')
        values, update_params = [], []
        for field, model, val in self.query.values:
            # Quoting is chosen per assignment: only a value that rewrote its
            # own name via alter_name() skips quoting. Reassigning ``qn`` here
            # would leave every following column unquoted.
            quote = qn
            self.name = name = field.column
            if hasattr(val, 'alter_name'):
                self.name = name = val.alter_name(name, qn)
                quote = no_quote_name
                val = val.value
            if hasattr(val, 'resolve_expression'):
                val = val.resolve_expression(self.query, allow_joins=False, for_save=True)
                if val.contains_aggregate:
                    raise FieldError("Aggregate functions are not allowed in this query")
            elif hasattr(val, 'prepare_database_save'):
                if field.remote_field:
                    val = field.get_db_prep_save(
                        val.prepare_database_save(field),
                        connection=self.connection,
                    )
                else:
                    raise TypeError(
                        "Tried to update field %s with a model instance, %r. "
                        "Use a value compatible with %s."
                        % (field, val, field.__class__.__name__)
                    )
            else:
                val = field.get_db_prep_save(val, connection=self.connection)

            # Getting the placeholder for the field.
            if hasattr(field, 'get_placeholder'):
                placeholder = field.get_placeholder(val, self, self.connection)
            else:
                placeholder = '%s'
            self.placeholder = placeholder
            if hasattr(val, 'as_sql'):
                sql, params = self.compile(val)
                values.append('%s = %s' % (quote(name), sql))
                update_params.extend(params)
            elif val is not None:
                values.append('%s = %s' % (quote(name), placeholder))
                update_params.append(val)
            else:
                values.append('%s = NULL' % quote(name))
        if not values:
            return '', ()
        result.append(', '.join(values))
        where, params = self.compile(self.query.where)
        if where:
            result.append('WHERE %s' % where)
        return ' '.join(result), tuple(update_params + params)
--------------------------------------------------------------------------------
/docs/hstores.rst:
--------------------------------------------------------------------------------
1 | HStoreField
2 | ===========
3 |
4 | Basic Usage
5 | -----------
6 | To use the HStoreField::
7 |
8 | from django.db import models
9 | from django_postgres_extensions.models.fields import HStoreField
10 | class Product(models.Model):
11 | description = HStoreField(null=True, blank=True)
12 |
13 | The customized Postgres HStoreField adds the following features:
14 |
15 | Individual keys
16 | ---------------
17 |
18 | - Get hstore values by key::
19 |
20 | from django_postgres_extensions.models.expressions import Key, Keys
21 | obj = Product.objects.annotate(Key('description', 'Release')).get()
22 | obj = Product.objects.annotate(Keys('description', ['Industry', 'Release'])).get()
23 |
24 | - Update hstore by specific keys, leaving any others untouched::
25 |
26 | Product.objects.update(description__ = {'Genre': 'Heavy Metal', 'Popularity': 'Very Popular'})
27 |
28 | Database functions
29 | ------------------
30 |
31 | Various database functions are included for interacting with HStores.
32 |
33 | - Slice: Return a dictionary with just the specified keys
34 |
35 | - Delete: Delete a key or list of keys from the hstore. Keys can also be deleted by specifying a dictionary
36 |
37 | - AKeys: Returns the hstore keys as a list
38 |
39 | - AVals: Returns the hstore values as a list
40 |
41 | - HStoreToArray: Returns the hstore as an array
42 |
43 | - HStoreToMatrix: Returns the hstore as a matrix
44 |
45 | - HstoreToJSONB: Returns the hstore as JSONB, converting every non-null value to a JSON string (hstore stores all values as text)
46 |
47 | - HstoreToJSONBLoose: Same as HstoreToJSONB, but attempt to distinguish numerical and Boolean values so they are unquoted in the JSON
48 |
49 | For more information on these functions, check the postgresql documentation for each one.
50 | These functions handle the arguments by converting them to the correct expressions automatically::
51 |
52 | from django_postgres_extensions.models.functions import *
53 | obj = Product.objects.annotate(description_slice=Slice('description', ['Industry', 'Release'])).get()
54 | obj = Product.objects.update(description = Delete('description', 'Genre'))
55 | obj = Product.objects.update(description = Delete('description', ['Industry', 'Genre']))
56 | Product.objects.update(description=Delete('description', {'Industry': 'Music', 'Release': 'Song', 'Genre': 'Rock'}))
57 | Product.objects.annotate(description_keys=AKeys('description')).get()
58 | Product.objects.annotate(description_values=AVals('description')).get()
59 | Product.objects.annotate(description_array=HStoreToArray('description')).get()
60 | Product.objects.annotate(description_matrix=HStoreToMatrix('description')).get()
61 | Product.objects.annotate(description_jsonb=HstoreToJSONB('description')).get()
62 | Product.objects.annotate(description_jsonb=HstoreToJSONBLoose('description')).get()
63 |
64 | Use With Nested Form Field
65 | --------------------------
66 |
67 | django.contrib.postgres includes an HStoreField for forms, where you have to enter an hstore value programmatically.
68 | Django Postgres Extensions adds a NestedFormField and NestedFormWidget
69 | (subclassed from Django's MultiValueField and MultiWidget) for use with an HStoreField.
70 | To use it, pass either a list of fields or a list of keys (but not both) as a keyword argument to the HStoreField model field::
71 |
72 | class Product(models.Model):
73 | shipping = HStoreField(keys=('Address', 'City', 'Region', 'Country'), blank=True, default=dict)
74 |
75 | The field would look like:
76 |
77 | .. image:: hstore_field.jpg
--------------------------------------------------------------------------------
/django_postgres_extensions/admin/options.py:
--------------------------------------------------------------------------------
1 | from django.contrib.admin.options import ModelAdmin
2 | from django.utils.translation import string_concat, ugettext as _
3 | from django.forms.widgets import CheckboxSelectMultiple, SelectMultiple
4 | from django_postgres_extensions.models import ArrayManyToManyField
5 | from django.contrib.admin import widgets
6 |
class PostgresAdmin(ModelAdmin):
    """
    ModelAdmin that also builds admin form fields for ArrayManyToManyField.

    ``formfield_for_dbfield`` sends ArrayManyToManyField instances through
    ``formfield_for_manytomany`` and wraps the resulting widget in
    ``RelatedFieldWidgetWrapper`` (the "add/change/delete related" links),
    mirroring what Django's ModelAdmin does for its own relation fields.
    Every other field type is delegated to ``ModelAdmin``.
    """

    def formfield_for_manytomany(self, db_field, request=None, **kwargs):
        """
        Get a form Field for a ManyToManyField or ArrayManyToManyField.

        Returns None (no admin field) when the relation uses an intermediary
        ("through") model that was not auto-created. Otherwise:

        - picks a raw-id or filtered widget according to ``raw_id_fields``,
          ``filter_vertical`` and ``filter_horizontal``;
        - restricts the queryset through ``get_field_queryset``;
        - appends the "hold down Control" hint for plain multi-selects.
        """
        # Django stores the intermediary model on ``remote_field.through``;
        # a ManyToManyField does not expose ``through`` on the field itself.
        # This override also serves ordinary M2M fields (ModelAdmin's
        # formfield_for_dbfield calls self.formfield_for_manytomany), so the
        # lookup must go through remote_field. Otherwise a custom through
        # model would get an editable widget that cannot be saved.
        # NOTE(review): assumes ArrayManyToManyField's remote_field either has
        # no ``through`` or an auto-created one -- confirm in fields/related.py.
        through = getattr(db_field.remote_field, 'through', None)
        if through is not None and not through._meta.auto_created:
            return None
        db = kwargs.get('using')

        if db_field.name in self.raw_id_fields:
            kwargs['widget'] = widgets.ManyToManyRawIdWidget(db_field.remote_field,
                                                             self.admin_site, using=db)
            kwargs['help_text'] = ''
        elif db_field.name in (list(self.filter_vertical) + list(self.filter_horizontal)):
            kwargs['widget'] = widgets.FilteredSelectMultiple(
                db_field.verbose_name,
                db_field.name in self.filter_vertical
            )

        if 'queryset' not in kwargs:
            queryset = self.get_field_queryset(db, db_field, request)
            if queryset is not None:
                kwargs['queryset'] = queryset

        form_field = db_field.formfield(**kwargs)
        if isinstance(form_field.widget, SelectMultiple) and not isinstance(form_field.widget, CheckboxSelectMultiple):
            # format_lazy replaces string_concat, which is deprecated in Django
            # 1.11 and removed in 2.0. It keeps a lazy help_text lazy.
            from django.utils.text import format_lazy
            msg = _('Hold down "Control", or "Command" on a Mac, to select more than one.')
            help_text = form_field.help_text
            form_field.help_text = format_lazy('{} {}', help_text, msg) if help_text else msg
        return form_field

    def formfield_for_dbfield(self, db_field, request, **kwargs):
        """
        Return the form field for ``db_field``.

        ArrayManyToManyField is handled here; every other field goes to
        ``ModelAdmin.formfield_for_dbfield``.
        """
        if isinstance(db_field, ArrayManyToManyField):
            # Combine the field kwargs with any options for formfield_overrides.
            # The passed-in **kwargs are more specific, so they always win over
            # formfield_overrides.
            if db_field.__class__ in self.formfield_overrides:
                kwargs = dict(self.formfield_overrides[db_field.__class__], **kwargs)

            formfield = self.formfield_for_manytomany(db_field, request, **kwargs)

            # For non-raw_id fields, wrap the widget so the "add other"
            # interface is rendered after it. formfield is None when the
            # relation uses a non-auto-created intermediary model.
            if formfield and db_field.name not in self.raw_id_fields:
                related_modeladmin = self.admin_site._registry.get(db_field.remote_field.model)
                wrapper_kwargs = {}
                if related_modeladmin:
                    wrapper_kwargs.update(
                        can_add_related=related_modeladmin.has_add_permission(request),
                        can_change_related=related_modeladmin.has_change_permission(request),
                        can_delete_related=related_modeladmin.has_delete_permission(request),
                    )
                formfield.widget = widgets.RelatedFieldWidgetWrapper(
                    formfield.widget, db_field.remote_field, self.admin_site, **wrapper_kwargs
                )

            return formfield
        else:
            return super(PostgresAdmin, self).formfield_for_dbfield(db_field, request, **kwargs)
--------------------------------------------------------------------------------
/django_postgres_extensions/models/expressions.py:
--------------------------------------------------------------------------------
1 | from django.db.models.expressions import F as BaseF, Value as BaseValue, Func, Expression
2 | from django.utils import six
3 | from django.contrib.postgres.fields.array import IndexTransform
4 | from django.utils.functional import cached_property
5 | from django.db.models.lookups import Transform
6 |
class OperatorMixin(object):
    """
    Shortcut methods that combine an expression with a PostgreSQL operator.

    Each public method forwards ``other`` and the operator string held in the
    matching class constant to ``self._combine``. The class this is mixed
    into must provide that method. The reversed flag is always False, so
    ``self`` stays on the left-hand side of the operator.
    """
    # Operator strings; each is used by the lower-case method of the same name.
    CAT = '||'
    REPLACE = '#='
    DELETE = '#-'
    KEY = '->'
    KEYTEXT = '->>'
    PATH = '#>'
    PATHTEXT = '#>>'

    def _apply_operator(self, connector, other):
        """Combine ``self <connector> other`` without reversing operands."""
        return self._combine(other, connector, False)

    def cat(self, other):
        return self._apply_operator(self.CAT, other)

    def replace(self, other):
        return self._apply_operator(self.REPLACE, other)

    def delete(self, other):
        return self._apply_operator(self.DELETE, other)

    def key(self, other):
        return self._apply_operator(self.KEY, other)

    def keytext(self, other):
        return self._apply_operator(self.KEYTEXT, other)

    def path(self, other):
        return self._apply_operator(self.PATH, other)

    def pathtext(self, other):
        return self._apply_operator(self.PATHTEXT, other)
36 |
class F(BaseF, OperatorMixin):
    """Django ``F`` field reference extended with the OperatorMixin shortcuts."""
39 |
class Value(BaseValue, OperatorMixin):
    """
    Django ``Value`` with OperatorMixin shortcuts and explicit array casts.

    When the output field is an array field, the parameter placeholder is
    cast to ``<base db type>[]``. A bare list parameter otherwise has no
    array element type PostgreSQL can rely on.
    """
    # get_internal_type() values that need the explicit array cast.
    _ARRAY_INTERNAL_TYPES = ('ArrayField', 'MultiReferenceArrayField')

    def as_sql(self, compiler, connection):
        output_field = self._output_field_or_none
        if output_field is not None and output_field.get_internal_type() in self._ARRAY_INTERNAL_TYPES:
            base_db_type = output_field.base_field.db_type(connection)
            # '%%s' survives formatting as the '%s' parameter placeholder.
            return '%%s::%s[]' % base_db_type, [self.value]
        return super(Value, self).as_sql(compiler, connection)
47 |
class Index(IndexTransform):
    """
    Select a single array element using a 0-based index.

    ``field`` may be a field name or any expression. ``index`` is 0-based
    and is converted to PostgreSQL's 1-based array subscript. The default
    alias is ``<field name>__<index>``.
    """

    def __init__(self, field, index, *args, **kwargs):
        # F objects are Combinable, not Expression, so an isinstance check
        # would wrap F('x') again as F(F('x')). Test for the expression
        # protocol instead, as Django's Func does.
        if not hasattr(field, 'resolve_expression'):
            field = F(field)
        # IndexTransform(index, base_field, expression): the subscript is 1-based.
        super(Index, self).__init__(index + 1, None, field, *args, **kwargs)

    @cached_property
    def default_alias(self):
        return '%s__%s' % (self.lhs.name, self.index - 1)

    @property
    def name(self):
        return self.default_alias

    @property
    def output_field(self):
        # One element of the array has the array's base field type.
        return self.lhs.field.base_field
65 |
class SliceArray(Transform):
    """
    Slice an array with 0-based, inclusive ``start`` and ``end`` indexes.

    ``SliceArray('tags', 0, 1)`` compiles to ``tags[1:2]``. Passing another
    SliceArray as ``field`` chains a further slice onto it (multidimensional
    slicing). The default alias is ``<field name>__<start>_<end>``.

    Raises ValueError unless exactly two indexes are given.
    """

    def __init__(self, field, *indexes, **kwargs):
        # as_sql uses exactly indexes[0] and indexes[1]. Reject other counts
        # up front instead of failing at SQL compile time (too few) or
        # silently ignoring extras (too many).
        if len(indexes) != 2:
            raise ValueError(
                'SliceArray expects exactly two indexes (start, end); got %d' % len(indexes))
        if isinstance(field, SliceArray):
            self.multidimensional = True
        else:
            # F objects are not Expression subclasses; don't double-wrap them.
            if not hasattr(field, 'resolve_expression'):
                field = F(field)
            self.multidimensional = False
        # PostgreSQL array subscripts are 1-based.
        self.indexes = [i + 1 for i in indexes]
        super(SliceArray, self).__init__(field, **kwargs)

    def as_sql(self, compiler, connection):
        lhs, params = compiler.compile(self.lhs)
        return '%s[%s:%s]' % (lhs, self.indexes[0], self.indexes[1]), params

    @cached_property
    def default_alias(self):
        return '%s__%s_%s' % (self.lhs.name, self.indexes[0] - 1, self.indexes[1] - 1)

    @property
    def name(self):
        return self.default_alias

    @property
    def output_field(self):
        # NOTE(review): a slice of an array is itself an array, yet the
        # one-dimensional case reports the base field. Kept as-is because
        # result conversion may rely on it -- confirm before changing.
        if self.multidimensional:
            return self.lhs.field
        return self.lhs.field.base_field
93 |
94 |
def Key(field, keys_string):
    """
    Build an expression selecting one key, or a nested key path, of ``field``.

    A string containing ``__`` is split on it and used as a path with the
    ``#>`` operator. Anything else is used as a single key with ``->``.
    The resulting expression's default alias is ``"<field>__<keys_string>"``.
    """
    reference = F(field)
    uses_path = isinstance(keys_string, six.string_types) and '__' in keys_string
    if uses_path:
        expression = reference.path(Value(keys_string.split('__')))
    else:
        expression = reference.key(Value(keys_string))
    expression.default_alias = "%s__%s" % (field, keys_string)
    return expression
103 |
def Keys(field, keys):
    """
    Build ``field -> keys`` with ``keys`` passed as a single Value.

    The resulting expression's default alias is ``"<field>__selected"``.
    """
    selected = F(field).key(Value(keys))
    selected.default_alias = "%s__selected" % field
    return selected
108 |
--------------------------------------------------------------------------------
/django_postgres_extensions/models/lookups.py:
--------------------------------------------------------------------------------
1 | from django.contrib.postgres.fields.array import ArrayField, ArrayContains
2 | from django.contrib.postgres.fields.jsonb import JSONField
3 | from django.db.models.lookups import BuiltinLookup, In, \
4 | Contains, StartsWith, EndsWith
5 |
class BaseAnyAllLookupMixin(object):
    """
    Shared SQL generation for the ``any_*`` / ``all_*`` array lookups.

    Subclasses set ``db_func`` (``'any'`` or ``'all'``) and ``lookup_name``.
    The operator template is taken from ``connection.<db_func>_operators``,
    keyed by the part of ``lookup_name`` after ``"<db_func>_"``. A bare
    ``lookup_name == db_func`` maps to the ``'exact'`` template. The
    generated SQL is ``<value sql> <operator template applied to column>``.
    """

    def get_rhs_op(self, connection, rhs):
        prefix = '%s_' % self.db_func
        if self.lookup_name == self.db_func:
            operator_key = 'exact'
        else:
            operator_key = self.lookup_name.split(prefix)[1]
        operator_templates = getattr(connection, '%s_operators' % self.db_func)
        return operator_templates[operator_key] % rhs

    def as_sql(self, compiler, connection):
        # The array column (lhs) goes inside the operator template, and the
        # looked-up value (rhs) comes first in the SQL, so the parameters are
        # ordered value params first, then column params.
        column_sql, column_params = self.process_lhs(compiler, connection)
        value_sql, value_params = self.process_rhs(compiler, connection)
        value_params.extend(column_params)
        condition = self.get_rhs_op(connection, column_sql)
        return '%s %s' % (value_sql, condition), value_params
21 |
class AnyLookupMixin(BaseAnyAllLookupMixin, BuiltinLookup):
    """Base for ``any_*`` lookups; operator templates come from ``connection.any_operators``."""
    db_func = 'any'


class AllLookupMixin(BaseAnyAllLookupMixin, BuiltinLookup):
    """Base for ``all_*`` lookups; operator templates come from ``connection.all_operators``."""
    db_func = 'all'
27 |
# ``any_*`` lookups: each is true when at least one array element satisfies
# the operator taken from connection.any_operators[<suffix>].

@ArrayField.register_lookup
class Any(AnyLookupMixin, BuiltinLookup):
    """``field__any=value``; uses the ``'exact'`` any-operator."""
    lookup_name = 'any'

@ArrayField.register_lookup
class AnyExact(AnyLookupMixin, BuiltinLookup):
    lookup_name = 'any_exact'

@ArrayField.register_lookup
class AnyGreaterThan(AnyLookupMixin, BuiltinLookup):
    lookup_name = 'any_gt'

@ArrayField.register_lookup
class AnyGreaterThanOrEqual(AnyLookupMixin, BuiltinLookup):
    lookup_name = 'any_gte'

@ArrayField.register_lookup
class AnyLessThan(AnyLookupMixin, BuiltinLookup):
    lookup_name = 'any_lt'

@ArrayField.register_lookup
class AnyLessThanOrEqual(AnyLookupMixin, BuiltinLookup):
    lookup_name = 'any_lte'

@ArrayField.register_lookup
class AnyIn(AnyLookupMixin, Contains):
    # Previously also named AnyLessThanOrEqual, which rebound that module
    # name to this any_in lookup. Renamed to match AllIn.
    lookup_name = 'any_in'

@ArrayField.register_lookup
class AnyStartOf(AnyLookupMixin, StartsWith):
    lookup_name = 'any_isstartof'

@ArrayField.register_lookup
class AnyEndOf(AnyLookupMixin, EndsWith):
    lookup_name = 'any_isendof'
63 |
# ``all_*`` lookups: each is true when every array element satisfies the
# operator taken from connection.all_operators[<suffix>]. They must use
# AllLookupMixin: get_rhs_op strips the "<db_func>_" prefix, so an
# AnyLookupMixin with an "all_..." name raises IndexError.

@ArrayField.register_lookup
class All(AllLookupMixin):
    """``field__all=value``; uses the ``'exact'`` all-operator."""
    lookup_name = 'all'

@ArrayField.register_lookup
class AllExact(AllLookupMixin, BuiltinLookup):
    lookup_name = 'all_exact'

@ArrayField.register_lookup
class AllGreaterThan(AllLookupMixin, BuiltinLookup):
    lookup_name = 'all_gt'

@ArrayField.register_lookup
class AllGreaterThanOrEqual(AllLookupMixin, BuiltinLookup):
    lookup_name = 'all_gte'

@ArrayField.register_lookup
class AllLessThan(AllLookupMixin, BuiltinLookup):
    lookup_name = 'all_lt'

@ArrayField.register_lookup
class AllLessThanOrEqual(AllLookupMixin, BuiltinLookup):
    lookup_name = 'all_lte'

@ArrayField.register_lookup
class AllIn(AllLookupMixin, Contains):
    lookup_name = 'all_in'

@ArrayField.register_lookup
class AllStartOf(AllLookupMixin, StartsWith):
    lookup_name = 'all_isstartof'

@ArrayField.register_lookup
class AllEndOf(AllLookupMixin, EndsWith):
    # Was AnyLookupMixin: 'all_isendof'.split('any_') has no [1] element.
    lookup_name = 'all_isendof'

@ArrayField.register_lookup
class AnyRegex(AllLookupMixin, BuiltinLookup):
    # Class name kept for backward compatibility; the registered lookup is
    # 'all_regex'. Like Django's Regex lookup, the pattern is sent untouched:
    # the former EndsWith base LIKE-escaped it and prefixed '%', corrupting
    # regexes. NOTE(review): relies on connection.all_operators['regex']
    # being defined in the backend -- confirm in backends/postgresql.
    lookup_name = 'all_regex'
    prepare_rhs = False
103 |
@ArrayField.register_lookup
class AnyContains(AnyLookupMixin, ArrayContains):
    """``field__any_contains=value`` -- documented as the 2-D array variant."""
    lookup_name = 'any_contains'
107 |
class ContainsItem(ArrayContains):
    """``contains`` lookup that also accepts a single item instead of a list."""
    lookup_name = 'contains'

    def __init__(self, lhs, rhs):
        # A scalar right-hand side is promoted to a one-element list.
        normalized = rhs if isinstance(rhs, (list, tuple)) else [rhs]
        super(ContainsItem, self).__init__(lhs, normalized)
--------------------------------------------------------------------------------
/django_postgres_extensions/forms/widgets.py:
--------------------------------------------------------------------------------
1 | from __future__ import unicode_literals
2 |
3 | from django.forms import widgets
4 | from django.utils.safestring import mark_safe
5 | from django.utils.html import format_html
6 | import copy
7 |
class NestedFormWidget(widgets.MultiWidget):
    """
    A MultiWidget whose sub-widgets are rendered with labels, as a list.

    ``labels`` is a list/tuple of label strings. Each label is also turned
    into an id suffix (lower-cased, spaces removed) used for the sub-widget's
    name and id.
    ``widgets`` is a list of widgets; labels are matched to widgets by index.
    ``names`` is an optional dict mapping each label to the key used for that
    widget in the dictionary given to ``decompress``. When omitted, labels and
    key names are assumed to be the same.
    ``template_name`` optionally overrides the template used by ``render``.

    ``decompress`` turns a dictionary of key names to values into a list of
    values in label order; ``value_from_datadict`` returns a list of values,
    one per widget.

    You'll probably want to use this class with NestedFormField.
    """
    template_name = "django_postgres_extensions/nested_form_widget.html"

    def __init__(self, labels, widgets, names=None, attrs=None, template_name=None):
        self.template_name = template_name or self.template_name
        self.labels = labels
        if names:
            self.names = names
        else:
            self.names = {label: label for label in labels}
        self.id_names = [label.lower().replace(" ", "") for label in labels]
        super(NestedFormWidget, self).__init__(widgets, attrs=attrs)

    def render(self, name, value, attrs=None, renderer=None):
        """
        Render every sub-widget and then the list template.

        Each sub-widget gets ``label``, ``label_for`` and pre-rendered ``html``
        attributes, which the template reads via the ``widgets`` context entry.
        """
        if self.is_localized:
            for widget in self.widgets:
                widget.is_localized = self.is_localized
        # value is a list of values, each corresponding to a widget
        # in self.widgets.
        if not isinstance(value, list):
            value = self.decompress(value)
        # Merge the widget's own attrs with the per-render attrs. build_attrs
        # takes (base_attrs, extra_attrs) since Django 1.11; passing only
        # ``attrs`` silently dropped self.attrs.
        final_attrs = self.build_attrs(self.attrs, attrs)
        base_id_ = final_attrs.get('id')
        for i, widget in enumerate(self.widgets):
            try:
                widget_value = value[i]
            except IndexError:
                widget_value = None
            widget.label = self.labels[i]
            id_name = self.id_names[i]
            if base_id_:
                id_ = '%s_%s' % (base_id_, id_name)
                final_attrs = dict(final_attrs, id=id_)
                # id_for_label is a method, so the old ``if widget.id_for_label``
                # was always true. Call it so widgets that opt out of label
                # targeting (returning '' or None) are honoured.
                widget.label_for = widget.id_for_label(id_) or ''
            else:
                widget.label_for = ''
            widget.html = mark_safe(widget.render('%s_%s' % (name, id_name), widget_value, final_attrs))
        context = {'widgets': self.widgets}
        return self._render(self.template_name, context, renderer)

    def value_from_datadict(self, data, files, name):
        return [widget.value_from_datadict(data, files, "%s_%s" % (name, self.id_names[i])) for i, widget in
                enumerate(self.widgets)]

    def decompress(self, value):
        """
        Return the values for each label, in label order.

        Keys missing from ``value`` yield None (an empty sub-widget) instead
        of raising KeyError. That happens, for example, when a key is added
        to the field after rows were saved.
        """
        if not value:
            return []
        return [value.get(self.names[label]) for label in self.labels]

    def value_omitted_from_data(self, data, files, name):
        return False

    def __deepcopy__(self, memo):
        obj = super(NestedFormWidget, self).__deepcopy__(memo)
        obj.labels = copy.deepcopy(self.labels)
        return obj
--------------------------------------------------------------------------------
/docs/arrays.rst:
--------------------------------------------------------------------------------
1 | ArrayField
2 | ==========
3 |
4 | Basic Usage
5 | -----------
6 |
7 | To use the ArrayField::
8 |
9 | from django.db import models
10 | from django_postgres_extensions.models.fields import ArrayField
11 | class Product(models.Model):
12 | tags = ArrayField(models.CharField(max_length=15), null=True, blank=True)
13 | moretags = ArrayField(models.CharField(max_length=15), null=True, blank=True)
14 |
15 | Array Indexes
16 | -------------
17 |
18 | - Get array values by index::
19 |
20 | from django_postgres_extensions.models.expressions import Index, SliceArray
21 | obj = Product.objects.annotate(Index('tags', 1)).get()
22 | print(obj.tags__1)
23 | obj = Product.objects.annotate(tag_1=Index('tags', 1)).get()
24 | print(obj.tag_1)
25 | obj = Product.objects.annotate(SliceArray('tags', 0, 1)).get()
26 | print(obj.tags__0_1)
27 |
28 | - Update array values by index::
29 |
30 | Product.objects.update(tags__2='Heavy Metal')
31 |
32 | Database Functions
33 | ------------------
34 |
35 | Various database functions are included for manipulating arrays:
36 |
37 | - ArrayLength: returns the length of an array
38 |
39 | - ArrayPosition: the position of an item in an array
40 |
41 | - ArrayPositions: all positions of an item in an array
42 |
43 | - ArrayAppend: Create an array value by adding a value to the end of an array field
44 |
45 | - ArrayPrepend: Create an array value by adding a value to the start of an array field
46 |
47 | - ArrayRemove: Create an array value by removing a value from an array field
48 |
49 | - ArrayReplace: Create an array value by replacing one value with another in an array field
50 |
51 | - ArrayCat: Combine the values of two separate ArrayFields
52 |
53 | For more information on each of these functions, check the PostgreSQL documentation.
54 | The provided arguments to each function are automatically converted to the required expressions::
55 |
56 | from django_postgres_extensions.models.functions import *
57 | obj = Product.objects.annotate(tags_length=ArrayLength('tags', 1)).get()
58 | obj = Product.objects.annotate(position=ArrayPosition('tags', 'Rock')).get()
59 | obj = Product.objects.annotate(positions=ArrayPositions('tags', 'Rock')).get()
60 | Product.objects.update(tags = ArrayAppend('tags', 'Popular'))
61 | Product.objects.update(tags = ArrayPrepend('Popular', 'tags'))
62 | Product.objects.update(tags = ArrayRemove('tags', 'Album'))
63 | Product.objects.update(tags = ArrayReplace('tags', 'Rock', 'Heavy Metal'))
64 | Product.objects.update(tags = ArrayCat('tags', 'moretags'))
65 | Product.objects.update(tags=ArrayCat('tags', ['Popular', '8'], output_field=Product._meta.get_field('tags')))
66 |
67 |
68 | Use in ModelForms
69 | -----------------
70 |
71 | django.contrib.postgres includes two possible array form fields: SimpleArrayField (the default) and SplitArrayField.
72 | To use the SplitArrayField automatically when generating a ModelForm, add the form_size keyword argument to the ArrayField::
73 |
74 | class Product(models.Model):
75 | keywords = ArrayField(models.CharField(max_length=20), default=list, form_size=10, blank=True)
76 |
77 | The field would look like:
78 |
79 | .. image:: array_split.jpg
80 |
81 | Alternatively, it is possible to use a Multiple Choice Field for an Array, by specifying a choices argument::
82 |
83 | sports = ArrayField(models.CharField(max_length=20), default=list, blank=True, choices=(
84 | ('football', 'Football'), ('tennis', 'Tennis'), ('golf', 'Golf'), ('basketball', 'Basketball'), ('hurling', 'Hurling'), ('baseball', 'Baseball')))
85 |
86 |
87 | The field would look like:
88 |
89 | .. image:: array_choice.jpg
90 |
91 | Array Lookups
92 | -------------
93 |
94 | Additional lookups have been added to the ArrayField to enable queries using the ANY and ALL database functions::
95 |
96 | qs = Product.objects.filter(tags__any = 'Popular')
97 | qs = Product.objects.filter(tags__all_isstartof = 'Popular')
98 |
99 | Any lookups check if any value in the array meets the lookup criteria.
100 | All lookups check if all values in an array meet the lookup criteria.
101 | The full list of additional lookups is:
102 |
103 | - any
104 | - any_exact
105 | - any_gt
106 | - any_gte
107 | - any_lt
108 | - any_lte
109 | - any_in
110 | - any_isstartof
111 | - any_isendof
112 | - any_contains (for 2d arrays)
113 | - all
114 | - all_exact
115 | - all_gt
116 | - all_gte
117 | - all_lt
118 | - all_lte
119 | - all_in
120 | - all_isstartof
121 | - all_isendof
122 | - all_regex
--------------------------------------------------------------------------------
/tests/prefetch_related_array/test_uuid.py:
--------------------------------------------------------------------------------
1 | from __future__ import unicode_literals
2 |
3 | from django.test import TestCase
4 |
5 | from .models import Flea, House, Person, Pet, Room
6 |
7 |
class UUIDPrefetchRelated(TestCase):
    """prefetch_related() to and from models with UUID primary keys."""

    def test_prefetch_related_from_uuid_model(self):
        fifi = Pet.objects.create(name='Fifi')
        owners = [Person.objects.create(name=owner_name) for owner_name in ('Ellen', 'George')]
        fifi.people.add(*owners)

        # One query for the pet, one for all of its people.
        with self.assertNumQueries(2):
            fetched = Pet.objects.prefetch_related('people').get(name='Fifi')
        with self.assertNumQueries(0):
            self.assertEqual(len(fetched.people.all()), 2)

    def test_prefetch_related_to_uuid_model(self):
        bella = Person.objects.create(name='Bella')
        companions = [Pet.objects.create(name=pet_name) for pet_name in ('Socks', 'Coffee')]
        bella.pets.add(*companions)

        with self.assertNumQueries(2):
            fetched = Person.objects.prefetch_related('pets').get(name='Bella')
        with self.assertNumQueries(0):
            self.assertEqual(len(fetched.pets.all()), 2)

    def test_prefetch_related_from_uuid_model_to_uuid_model(self):
        parasites = [Flea.objects.create() for _ in range(3)]
        for host_name in ('Fifi', 'Bobo'):
            Pet.objects.create(name=host_name).fleas_hosted.add(*parasites)

        # Forward direction: pet -> fleas.
        with self.assertNumQueries(2):
            host = Pet.objects.prefetch_related('fleas_hosted').get(name='Fifi')
        with self.assertNumQueries(0):
            self.assertEqual(len(host.fleas_hosted.all()), 3)

        # Reverse direction: flea -> pets.
        with self.assertNumQueries(2):
            parasite = Flea.objects.prefetch_related('pets_visited').get(pk=parasites[0].pk)
        with self.assertNumQueries(0):
            self.assertEqual(len(parasite.pets_visited.all()), 2)

    def test_prefetch_related_from_uuid_model_to_uuid_model_with_values_flat(self):
        fifi = Pet.objects.create(name='Fifi')
        fifi.people.add(
            Person.objects.create(name='Ellen'),
            Person.objects.create(name='George'),
        )
        flat_ids = Pet.objects.prefetch_related('fleas_hosted').values_list('id', flat=True)
        self.assertSequenceEqual(flat_ids, [fifi.id])
57 |
58 |
class UUIDPrefetchRelatedLookups(TestCase):
    """Multi-level prefetch lookups that mix UUID-pk and integer-pk models."""

    @classmethod
    def setUpTestData(cls):
        redwood = House.objects.create(name='Redwood', address='Arcata')
        racoon = Room.objects.create(name='Racoon', house=redwood)
        parasites = [Flea.objects.create(current_room=racoon) for _ in range(3)]
        spooky = Pet.objects.create(name='Spooky')
        spooky.fleas_hosted.add(*parasites)
        bob = Person.objects.create(name='Bob')
        bob.houses.add(redwood)
        bob.pets.add(spooky)
        bob.fleas_hosted.add(*parasites)

    def test_from_uuid_pk_lookup_uuid_pk_integer_pk(self):
        # Start from Pet and follow fleas_hosted -> current_room -> house.
        with self.assertNumQueries(4):
            pet = Pet.objects.prefetch_related('fleas_hosted__current_room__house').get(name='Spooky')
        with self.assertNumQueries(0):
            self.assertEqual(pet.fleas_hosted.all()[0].current_room.name, 'Racoon')

    def test_from_uuid_pk_lookup_integer_pk2_uuid_pk2(self):
        # Start from Pet and follow people -> houses -> rooms -> fleas.
        with self.assertNumQueries(5):
            pet = Pet.objects.prefetch_related('people__houses__rooms__fleas').get(name='Spooky')
        with self.assertNumQueries(0):
            first_room = pet.people.all()[0].houses.all()[0].rooms.all()[0]
            self.assertEqual(len(first_room.fleas.all()), 3)

    def test_from_integer_pk_lookup_uuid_pk_integer_pk(self):
        # Start from Room and follow fleas -> people_visited.
        with self.assertNumQueries(3):
            room = Room.objects.prefetch_related('fleas__people_visited').get(name='Racoon')
        with self.assertNumQueries(0):
            self.assertEqual(room.fleas.all()[0].people_visited.all()[0].name, 'Bob')

    def test_from_integer_pk_lookup_integer_pk_uuid_pk(self):
        # Start from House and follow rooms -> fleas.
        with self.assertNumQueries(3):
            house = House.objects.prefetch_related('rooms__fleas').get(name='Redwood')
        with self.assertNumQueries(0):
            self.assertEqual(len(house.rooms.all()[0].fleas.all()), 3)

    def test_from_integer_pk_lookup_integer_pk_uuid_pk_uuid_pk(self):
        # Start from House and follow rooms -> fleas -> pets_visited.
        with self.assertNumQueries(4):
            house = House.objects.prefetch_related('rooms__fleas__pets_visited').get(name='Redwood')
        with self.assertNumQueries(0):
            first_flea = house.rooms.all()[0].fleas.all()[0]
            self.assertEqual(first_flea.pets_visited.all()[0].name, 'Spooky')
107 |
--------------------------------------------------------------------------------
/django_postgres_extensions/models/functions.py:
--------------------------------------------------------------------------------
1 | from django.db.models.expressions import Func, Expression
2 | from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE
3 | from django.utils import six
4 | from .expressions import F, Value as V
5 |
class SimpleFunc(Func):
    """
    SQL function whose first argument is a field and the rest are values.

    A non-expression ``field`` (e.g. a field name) is wrapped in F(). Each
    non-expression entry of ``values`` is wrapped in Value(), so strings are
    sent as literals instead of being treated as column names. Arguments
    that already implement ``resolve_expression`` (F, Value, Func, ...) are
    passed through unchanged.
    """

    def __init__(self, field, *values, **extra):
        # F is Combinable but not an Expression subclass, so an isinstance
        # check would double-wrap it; test for the expression protocol.
        if not hasattr(field, 'resolve_expression'):
            field = F(field)
        # Decide per value: previously only values[0] was inspected, so mixed
        # argument lists were wrapped (or left raw) all-or-nothing.
        values = [value if hasattr(value, 'resolve_expression') else V(value) for value in values]
        super(SimpleFunc, self).__init__(field, *values, **extra)
14 |
class TooManyExpressionsError(Exception):
    """Raised by multi_func when it is given more arguments than its nesting limit."""
17 |
def multi_func(func, expression, *args):
    """
    Fold ``args`` into nested ``func`` calls: ``func(func(expression, a1), a2)``.

    Raises TooManyExpressionsError when more than GET_ITERATOR_CHUNK_SIZE
    arguments are given, since each extra argument nests the expression one
    level deeper. Raises IndexError when ``args`` is empty.
    """
    arg_count = len(args)
    if arg_count > GET_ITERATOR_CHUNK_SIZE:
        raise TooManyExpressionsError('Multi-func given %s args. The limit is %s due to Python recursion depth risk' % (
            arg_count, GET_ITERATOR_CHUNK_SIZE))
    first_arg, remaining_args = args[0], args[1:]
    nested = func(expression, first_arg)
    for extra_arg in remaining_args:
        nested = func(nested, extra_arg)
    return nested
28 |
def multi_array_remove(field, *args):
    """Remove every value in ``args`` from ``field`` via nested ARRAY_REMOVE calls."""
    nested_remove = multi_func(ArrayRemove, field, *args)
    return nested_remove
31 |
class ArrayAppend(SimpleFunc):
    """SQL ``ARRAY_APPEND(field, value)``."""
    function = 'ARRAY_APPEND'
34 |
class ArrayPrepend(Func):
    """
    SQL ``ARRAY_PREPEND(value, field)`` -- note the value comes first.

    A non-expression ``value`` is wrapped in Value(); a non-expression
    ``field`` (e.g. a field name) is wrapped in F(). Existing expressions,
    including F objects, are used as-is.
    """
    function = 'ARRAY_PREPEND'

    def __init__(self, value, field, **extra):
        # F is not an Expression subclass, so test the expression protocol to
        # avoid Value(F(...)) / F(F(...)) double wrapping.
        if not hasattr(value, 'resolve_expression'):
            value = V(value)
        # Previously the field was always wrapped, even when already an expression.
        if not hasattr(field, 'resolve_expression'):
            field = F(field)
        super(ArrayPrepend, self).__init__(value, field, **extra)
43 |
class ArrayRemove(SimpleFunc):
    """SQL ``ARRAY_REMOVE(field, value)``."""
    function = 'ARRAY_REMOVE'


class ArrayReplace(SimpleFunc):
    """SQL ``ARRAY_REPLACE(field, old_value, new_value)``."""
    function = 'ARRAY_REPLACE'


class ArrayPosition(SimpleFunc):
    """SQL ``ARRAY_POSITION(field, value)``."""
    function = 'ARRAY_POSITION'


class ArrayPositions(SimpleFunc):
    """SQL ``ARRAY_POSITIONS(field, value)``."""
    function = 'ARRAY_POSITIONS'
55 |
class ArrayCat(Func):
    """
    SQL ``ARRAY_CAT(field, value)``, or ``ARRAY_CAT(value, field)`` if ``prepend``.

    ``field`` may be a field name or an expression. A plain string ``value``
    is treated as another field name (F). Any other plain value is wrapped in
    Value(), using ``output_field`` when given so the value can be cast to
    the right array type. ``output_field`` is used only for that Value; it is
    not passed on to the Func itself.
    """
    function = 'ARRAY_CAT'

    def __init__(self, field, value, prepend=False, output_field=None, **extra):
        # F is not an Expression subclass; use the expression protocol so F
        # arguments are not double-wrapped.
        if not hasattr(field, 'resolve_expression'):
            field = F(field)
        if not hasattr(value, 'resolve_expression'):
            if isinstance(value, six.string_types):
                value = F(value)
            elif output_field is not None:
                value = V(value, output_field=output_field)
            else:
                value = V(value)
        arguments = (value, field) if prepend else (field, value)
        super(ArrayCat, self).__init__(*arguments, **extra)
73 |
class ArrayLength(SimpleFunc):
    """SQL ``ARRAY_LENGTH(field, dimension)``."""
    function = 'ARRAY_LENGTH'


class ArrayDims(SimpleFunc):
    """SQL ``ARRAY_DIMS(field)``."""
    function = 'ARRAY_DIMS'


class ArrayUpper(SimpleFunc):
    """SQL ``ARRAY_UPPER(field, dimension)``."""
    function = 'ARRAY_UPPER'


class ArrayLower(SimpleFunc):
    """SQL ``ARRAY_LOWER(field, dimension)``."""
    function = 'ARRAY_LOWER'


class Cardinality(SimpleFunc):
    """SQL ``CARDINALITY(field)``."""
    function = 'CARDINALITY'
88 |
class NonFieldFunc(Func):
    """
    SQL function whose arguments are all values (no leading field).

    Every argument that is not already an expression is wrapped in Value(),
    so strings are sent as literals. Anything implementing
    ``resolve_expression`` -- including F objects, which are not Expression
    subclasses -- is passed through unchanged.
    """

    def __init__(self, *values, **extra):
        values = [value if hasattr(value, 'resolve_expression') else V(value) for value in values]
        super(NonFieldFunc, self).__init__(*values, **extra)
96 |
class HStore(NonFieldFunc):
    """SQL ``HSTORE(...)`` built from literal values."""
    function = 'HSTORE'


class AKeys(SimpleFunc):
    """SQL ``AKEYS(field)``."""
    function = 'AKEYS'


class SKeys(SimpleFunc):
    """SQL ``SKEYS(field)``."""
    function = 'SKEYS'


class AVals(SimpleFunc):
    """SQL ``AVALS(field)``."""
    function = 'AVALS'


class SVals(SimpleFunc):
    """SQL ``SVALS(field)``."""
    function = 'SVALS'


class HStoreToArray(SimpleFunc):
    """SQL ``HSTORE_TO_ARRAY(field)``."""
    function = 'HSTORE_TO_ARRAY'


class HStoreToMatrix(SimpleFunc):
    """SQL ``HSTORE_TO_MATRIX(field)``."""
    function = 'HSTORE_TO_MATRIX'


class Slice(SimpleFunc):
    """SQL ``SLICE(field, keys)``."""
    function = 'SLICE'


class Delete(SimpleFunc):
    """SQL ``DELETE(field, key_or_keys_or_hstore)``."""
    function = 'DELETE'


class Each(SimpleFunc):
    """SQL ``EACH(field)``."""
    function = 'EACH'


class HstoreToJSONB(SimpleFunc):
    """SQL ``HSTORE_TO_JSONB(field)``."""
    function = 'HSTORE_TO_JSONB'


class HstoreToJSONBLoose(SimpleFunc):
    """SQL ``HSTORE_TO_JSONB_LOOSE(field)``."""
    function = 'HSTORE_TO_JSONB_LOOSE'
132 |
class ToJSONB(NonFieldFunc):
    """SQL ``TO_JSONB(value)``."""
    function = 'TO_JSONB'


class RowToJSON(SimpleFunc):
    """SQL ``ROW_TO_JSON(field, ...)``."""
    function = 'ROW_TO_JSON'


class ArrayToJSON(SimpleFunc):
    """SQL ``ARRAY_TO_JSON(field, ...)``."""
    function = 'ARRAY_TO_JSON'


class JSONBBuildArray(NonFieldFunc):
    """SQL ``JSONB_BUILD_ARRAY(values...)``."""
    function = 'JSONB_BUILD_ARRAY'


class JSONBArrayElements(SimpleFunc):
    """SQL ``JSONB_ARRAY_ELEMENTS(field)``."""
    function = 'JSONB_ARRAY_ELEMENTS'


class JSONBBuildObject(NonFieldFunc):
    """SQL ``JSONB_BUILD_OBJECT(values...)``."""
    function = 'JSONB_BUILD_OBJECT'


class JSONBObject(NonFieldFunc):
    """SQL ``JSONB_OBJECT(values...)``."""
    # Was misspelled 'JOSNB_OBJECT', an unknown function in PostgreSQL.
    function = 'JSONB_OBJECT'


class JSONBSet(SimpleFunc):
    """SQL ``JSONB_SET(field, path, new_value, ...)``."""
    function = 'JSONB_SET'


class JSONBArrayLength(SimpleFunc):
    """SQL ``JSONB_ARRAY_LENGTH(field)``."""
    # Upper-cased for consistency; SQL function names are case-insensitive.
    function = 'JSONB_ARRAY_LENGTH'


class JSONBPretty(SimpleFunc):
    """SQL ``JSONB_PRETTY(field)``."""
    function = 'JSONB_PRETTY'


# NOTE(review): the three classes below call the json_* (not jsonb_*)
# variants, matching their class names; applying them to a jsonb column may
# need an explicit cast -- confirm intended usage.
class JSONObjectKeys(SimpleFunc):
    """SQL ``JSON_OBJECT_KEYS(field)``."""
    function = 'JSON_OBJECT_KEYS'


class JSONStripNulls(SimpleFunc):
    """SQL ``JSON_STRIP_NULLS(field)``."""
    function = 'JSON_STRIP_NULLS'


class JSONTypeOf(SimpleFunc):
    """SQL ``JSON_TYPEOF(field)``."""
    # PostgreSQL spells it json_typeof; 'JSON_TYPE_OF' does not exist.
    function = 'JSON_TYPEOF'
--------------------------------------------------------------------------------
/tests/prefetch_related_array/test_prefetch_related_objects.py:
--------------------------------------------------------------------------------
1 | from django.db.models import Prefetch
2 | from django.db.models.query import prefetch_related_objects
3 | from django.test import TestCase
4 |
5 | from .models import Author, Book, Reader
6 |
7 |
class PrefetchRelatedObjectsTests(TestCase):
    """
    Since prefetch_related_objects() is just the inner part of
    prefetch_related(), only do basic tests to ensure its API hasn't changed.
    """
    @classmethod
    def setUpTestData(cls):
        cls.book1 = Book.objects.create(title='Poems')
        cls.book2 = Book.objects.create(title='Jane Eyre')
        cls.book3 = Book.objects.create(title='Wuthering Heights')
        cls.book4 = Book.objects.create(title='Sense and Sensibility')

        cls.author1 = Author.objects.create(name='Charlotte', first_book=cls.book1)
        cls.author2 = Author.objects.create(name='Anne', first_book=cls.book1)
        cls.author3 = Author.objects.create(name='Emily', first_book=cls.book1)
        cls.author4 = Author.objects.create(name='Jane', first_book=cls.book4)

        cls.book1.authors.add(cls.author1, cls.author2, cls.author3)
        cls.book2.authors.add(cls.author1)
        cls.book3.authors.add(cls.author3)
        cls.book4.authors.add(cls.author4)

        cls.reader1 = Reader.objects.create(name='Amy')
        cls.reader2 = Reader.objects.create(name='Belinda')

        cls.reader1.books_read.add(cls.book1, cls.book4)
        cls.reader2.books_read.add(cls.book2, cls.book4)

    def test_unknown(self):
        """Prefetching a name that is not an attribute raises AttributeError."""
        book1 = Book.objects.get(id=self.book1.id)
        with self.assertRaises(AttributeError):
            prefetch_related_objects([book1], 'unknown_attribute')

    def test_m2m_forward(self):
        """A forward m2m prefetch costs one query and is then served from cache."""
        book1 = Book.objects.get(id=self.book1.id)
        with self.assertNumQueries(1):
            prefetch_related_objects([book1], 'authors')

        with self.assertNumQueries(0):
            self.assertEqual(set(book1.authors.all()), {self.author1, self.author2, self.author3})

    def test_m2m_reverse(self):
        """A reverse m2m prefetch costs one query and is then served from cache."""
        author1 = Author.objects.get(id=self.author1.id)
        with self.assertNumQueries(1):
            prefetch_related_objects([author1], 'books')

        with self.assertNumQueries(0):
            self.assertEqual(set(author1.books.all()), {self.book1, self.book2})

    def test_foreignkey_forward(self):
        """A forward FK prefetch fills every author's first_book."""
        authors = list(Author.objects.all())
        with self.assertNumQueries(1):
            prefetch_related_objects(authors, 'first_book')

        with self.assertNumQueries(0):
            # Reading each cached relation must not hit the database; also
            # check the cached values match the fixtures.
            first_books = {author.first_book for author in authors}
            self.assertEqual(first_books, {self.book1, self.book4})

    def test_foreignkey_reverse(self):
        """A reverse FK prefetch fills every book's first_time_authors."""
        books = list(Book.objects.all())
        with self.assertNumQueries(1):
            prefetch_related_objects(books, 'first_time_authors')

        with self.assertNumQueries(0):
            first_time_authors = {
                book.title: {author.name for author in book.first_time_authors.all()}
                for book in books
            }
            self.assertEqual(first_time_authors, {
                'Poems': {'Charlotte', 'Anne', 'Emily'},
                'Jane Eyre': set(),
                'Wuthering Heights': set(),
                'Sense and Sensibility': {'Jane'},
            })

    def test_m2m_then_m2m(self):
        """
        We can follow a m2m and another m2m.
        """
        authors = list(Author.objects.all())
        with self.assertNumQueries(2):
            prefetch_related_objects(authors, 'books__read_by')

        with self.assertNumQueries(0):
            self.assertEqual(
                [
                    [[str(r) for r in b.read_by.all()] for b in a.books.all()]
                    for a in authors
                ],
                [
                    [['Amy'], ['Belinda']],  # Charlotte - Poems, Jane Eyre
                    [['Amy']],               # Anne - Poems
                    [['Amy'], []],           # Emily - Poems, Wuthering Heights
                    [['Amy', 'Belinda']],    # Jane - Sense and Sensibility
                ]
            )

    def test_prefetch_object(self):
        """A Prefetch object is accepted in place of a lookup string."""
        book1 = Book.objects.get(id=self.book1.id)
        with self.assertNumQueries(1):
            prefetch_related_objects([book1], Prefetch('authors'))

        with self.assertNumQueries(0):
            self.assertEqual(set(book1.authors.all()), {self.author1, self.author2, self.author3})

    def test_prefetch_object_to_attr(self):
        """Prefetch(to_attr=...) stores the prefetched objects on that attribute."""
        book1 = Book.objects.get(id=self.book1.id)
        with self.assertNumQueries(1):
            prefetch_related_objects([book1], Prefetch('authors', to_attr='the_authors'))

        with self.assertNumQueries(0):
            self.assertEqual(set(book1.the_authors), {self.author1, self.author2, self.author3})

    def test_prefetch_queryset(self):
        """Prefetch(queryset=...) restricts the prefetched objects."""
        book1 = Book.objects.get(id=self.book1.id)
        with self.assertNumQueries(1):
            prefetch_related_objects(
                [book1],
                Prefetch('authors', queryset=Author.objects.filter(id__in=[self.author1.id, self.author2.id]))
            )

        with self.assertNumQueries(0):
            self.assertEqual(set(book1.authors.all()), {self.author1, self.author2})
121 |
--------------------------------------------------------------------------------
/tests/jsonb/tests.py:
--------------------------------------------------------------------------------
1 | from __future__ import unicode_literals, absolute_import
2 |
3 | from django.test import TestCase
4 | from .models import Product
5 | from django_postgres_extensions.models.functions import *
6 | from django_postgres_extensions.models.expressions import Key
7 | from psycopg2.extras import Json
8 | from django.db import transaction
9 |
class JSONIndexTests(TestCase):
    """Reading JSON values by key / key path and key-level JSON updates."""

    def setUp(self):
        super(JSONIndexTests, self).setUp()
        self.product = Product(name='xyz', description={
            'Industry': 'Music',
            'Details': {'Release': 'Album', 'Genre': 'Rock', 'Rating': 8},
            'Price': 9.99,
            'Tags': ['Heavy', 'Guitar'],
        })
        self.product.save()
        self.queryset = Product.objects.filter(pk=self.product.pk)
        # Load only the primary key; the tests read JSON data via Key() annotations.
        self.pk_queryset = self.queryset.only('id')

    def tearDown(self):
        Product.objects.all().delete()
        # Keep teardown symmetric with setUp, which calls its parent.
        super(JSONIndexTests, self).tearDown()

    def test_json_value(self):
        """The whole JSON document round-trips unchanged."""
        product = self.queryset.get()
        self.assertDictEqual(product.description,
                             {'Industry': 'Music', 'Details': {'Release': 'Album', 'Genre': 'Rock', 'Rating': 8},
                              'Price': 9.99, 'Tags': ['Heavy', 'Guitar']})

    def test_json_value_by_key(self):
        """Key('description', 'Details') annotates as description__Details."""
        with transaction.atomic():
            obj = self.pk_queryset.annotate(Key('description', 'Details')).get()
            self.assertDictEqual(obj.description__Details, {'Genre': 'Rock', 'Rating': 8, 'Release': 'Album'})

    def test_json_value_by_key_path(self):
        """'__'-separated key paths reach nested objects and array elements."""
        with transaction.atomic():
            obj = self.pk_queryset.annotate(Key('description', 'Details__Rating')).get()
            self.assertEqual(obj.description__Details__Rating, 8)
        with transaction.atomic():
            obj = self.pk_queryset.annotate(Key('description', 'Tags__1')).get()
            self.assertEqual(obj.description__Tags__1, "Guitar")

    def test_json_update_keys_values(self):
        """description__=dict merges keys, leaving the others untouched."""
        with transaction.atomic():
            self.queryset.update(description__={'Industry': 'Movie', 'Popularity': 'Very Popular'})
        product = self.queryset.get()
        self.assertDictEqual(product.description,
                             {'Industry': 'Movie', 'Details': {'Release': 'Album', 'Genre': 'Rock', 'Rating': 8},
                              'Price': 9.99, 'Popularity': 'Very Popular', 'Tags': ['Heavy', 'Guitar']})

    def test_json_update_delete_key(self):
        """description__del=key removes a top-level key."""
        with transaction.atomic():
            self.queryset.update(description__del='Details')
        product = self.queryset.get()
        self.assertDictEqual(product.description,
                             {'Industry': 'Music', 'Price': 9.99, 'Tags': ['Heavy', 'Guitar']})

    def test_json_update_delete_key_path(self):
        """description__del='a__b' removes a nested key or array element."""
        with transaction.atomic():
            self.queryset.update(description__del='Details__Release')
        product = self.queryset.get()
        self.assertDictEqual(product.description,
                             {'Industry': 'Music', 'Details': {'Genre': 'Rock', 'Rating': 8},
                              'Price': 9.99, 'Tags': ['Heavy', 'Guitar']})
        with transaction.atomic():
            self.queryset.update(description__del='Tags__1')
        product = self.queryset.get()
        self.assertDictEqual(product.description,
                             {'Industry': 'Music', 'Details': {'Genre': 'Rock', 'Rating': 8},
                              'Price': 9.99, 'Tags': ['Heavy']})
69 |
class JSONFuncTests(TestCase):
    """JSONB database functions used in updates and in queryset.format()."""

    def setUp(self):
        super(JSONFuncTests, self).setUp()
        # An object-valued document...
        self.product = Product(name='xyz', description={
            'Industry': 'Music',
            'Details': {'Release': 'Album', 'Genre': 'Rock', 'Rating': 8},
            'Price': 9.99,
            'Tags': ['Heavy', 'Guitar'],
        })
        self.product.save()
        self.queryset = Product.objects.filter(pk=self.product.pk)
        # ...and an array-valued one.
        self.product2 = Product(name='xyz', description=[{'a': 'b', 'c': 'd'}, {'a': 'e', 'c': 'f'}])
        self.product2.save()
        self.queryset2 = Product.objects.filter(pk=self.product2.pk)

    def tearDown(self):
        Product.objects.all().delete()
        # Keep teardown symmetric with setUp, which calls its parent.
        super(JSONFuncTests, self).tearDown()

    def test_jsonb_set(self):
        """JSONBSet replaces a value at a key path inside an object."""
        with transaction.atomic():
            self.queryset.update(description=JSONBSet('description', ['Details', 'Genre'], Json('Heavy Metal'), True))
        obj = self.queryset.get()
        self.assertDictEqual(obj.description,
                             {'Price': 9.99, 'Industry': 'Music',
                              'Details': {'Genre': 'Heavy Metal', 'Release': 'Album', 'Rating': 8},
                              'Tags': ['Heavy', 'Guitar']})

    def test_jsonb_array_set(self):
        """JSONBSet paths may start with an array index."""
        with transaction.atomic():
            self.queryset2.update(description=JSONBSet('description', ['1', 'c'], Json('g')))
        obj = self.queryset2.get()
        self.assertListEqual(obj.description,
                             [{'a': 'b', 'c': 'd'}, {'a': 'e', 'c': 'g'}])

    def test_jsonb_array_length(self):
        """queryset.format() exposes JSONBArrayLength under the given name."""
        with transaction.atomic():
            qs = self.queryset2.format('description', JSONBArrayLength, output_field='desc_length')
            obj = qs.get()
        self.assertEqual(obj.desc_length, 2)
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Django Postgres Extensions documentation build configuration file, created by
# sphinx-quickstart on Wed Feb  8 21:57:16 2017.
#
# This file is execfile()d with the current directory set to its
# containing dir.
#
# Note that not all possible configuration values are present in this
# autogenerated file.
#
# All configuration values have a default; values that are commented out
# serve to show the default.

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
# import os
# import sys
# sys.path.insert(0, os.path.abspath('.'))


# -- General configuration ------------------------------------------------

# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = ['sphinx.ext.autodoc',
              'sphinx.ext.doctest',
              'sphinxcontrib.fancybox']

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = '.rst'

# The master toctree document.
master_doc = 'index'

# General information about the project.
project = 'Django Postgres Extensions'
copyright = '2017, Paul Martin'
author = 'Paul Martin'

# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# The short X.Y version.
version = '0.9.2'
# The full version, including alpha/beta/rc tags.
release = '0.9.2'

# The natural language the documentation is written in: a locale code used
# for Sphinx's generated text (and search). It is NOT the programming language
# being documented; 'python' is not a valid locale.
language = 'en'

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# These patterns also affect html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'

# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = False


# -- Options for HTML output ----------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'default'

# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
# html_theme_options = {}

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
# NOTE(review): no docs/_static directory appears in the repository tree;
# Sphinx warns when a listed static path does not exist.
html_static_path = ['_static']


# -- Options for HTMLHelp output ------------------------------------------

# Output file base name for HTML help builder.
htmlhelp_basename = 'DjangoPostgresExtensionsdoc'


# -- Options for LaTeX output ---------------------------------------------

latex_elements = {
    # The paper size ('letterpaper' or 'a4paper').
    #
    # 'papersize': 'letterpaper',

    # The font size ('10pt', '11pt' or '12pt').
    #
    # 'pointsize': '10pt',

    # Additional stuff for the LaTeX preamble.
    #
    # 'preamble': '',

    # Latex figure (float) alignment
    #
    # 'figure_align': 'htbp',
}

# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
#  author, documentclass [howto, manual, or own class]).
latex_documents = [
    (master_doc, 'DjangoPostgresExtensions.tex', 'Django Postgres Extensions Documentation',
     'Paul Martin', 'manual'),
]


# -- Options for manual page output ---------------------------------------

# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
    (master_doc, 'djangopostgresextensions', 'Django Postgres Extensions Documentation',
     [author], 1)
]


# -- Options for Texinfo output -------------------------------------------

# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
#  dir menu entry, description, category)
texinfo_documents = [
    (master_doc, 'DjangoPostgresExtensions', 'Django Postgres Extensions Documentation',
     author, 'DjangoPostgresExtensions', 'One line description of project.',
     'Miscellaneous'),
]

# Do not generate the module index pages.
html_domain_indices = False
--------------------------------------------------------------------------------
/tests/nested_form_field_widget/tests.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from django.test import TestCase
3 | from django.forms import CharField, Form, TextInput, FileInput
4 | from django_postgres_extensions import forms
5 | from django_postgres_extensions.models.fields import HStoreField
6 |
class NestedFormWidgetTest(TestCase):
    """Rendering, multipart detection and deep-copying of NestedFormWidget."""

    def check_html(self, widget, name, value, html='', attrs=None, **kwargs):
        """Render ``widget`` and assert the output is HTML-equal to ``html``."""
        output = widget.render(name, value, attrs=attrs, **kwargs)
        self.assertHTMLEqual(output, html)

    def test_text_inputs(self):
        # NOTE(review): the expected HTML here is an empty string in this copy
        # of the file -- the fixture markup appears to have been lost -- so
        # this presumably fails against the rendered inputs. Restore the
        # expected markup from the repository.
        # Each ``(TextInput())`` is just a widget, not a 1-tuple.
        widget = forms.NestedFormWidget(
            ('A', 'B', 'C'),
            ((TextInput()),
             (TextInput()),
             (TextInput())
             )
        )
        self.check_html(widget, 'name', ['john', 'winston', 'lennon'], html=""""""
                        )

    def test_constructor_attrs(self):
        # NOTE(review): as above, the expected HTML is only a newline in this
        # copy; the fixture markup appears to be missing -- confirm and restore.
        widget = forms.NestedFormWidget(
            ('A', 'B', 'C'),
            ((TextInput()),
             (TextInput()),
             (TextInput())
             ),
            attrs={'id': 'bar'},
        )
        self.check_html(widget, 'name', ['john', 'winston', 'lennon'], attrs={'id': 'bar'}, html="""
"""
                        )

    def test_needs_multipart_true(self):
        """
        needs_multipart_form should be True if any widgets need it.
        """
        widget = forms.NestedFormWidget(
            ('text', 'file'),
            (TextInput(), FileInput())
        )
        self.assertTrue(widget.needs_multipart_form)

    def test_needs_multipart_false(self):
        """
        needs_multipart_form should be False if no widgets need it.
        """
        widget = forms.NestedFormWidget(
            ('text', 'text2'),
            (TextInput(), TextInput())
        )
        self.assertFalse(widget.needs_multipart_form)

    def test_nested_multiwidget(self):
        """
        NestedFormWidget can be composed of other NestedFormWidgets.
        """
        # NOTE(review): the expected HTML is only blank lines in this copy;
        # the fixture markup appears to be missing -- confirm and restore.
        widget = forms.NestedFormWidget(
            ('A', 'B'),
            (TextInput(), forms.NestedFormWidget(
                ('C', 'D'),
                (TextInput(), TextInput())
            )
            )
        )
        self.check_html(widget, 'name', ['Singer', ['John', 'Lennon']], html=(
            """

"""
        ))

    def test_deepcopy(self):
        """
        Deep-copying a NestedFormWidget must not share its labels list
        (mirrors Django's MultiWidget __deepcopy__ test, ticket #12048).
        """
        w1 = forms.NestedFormWidget(
            ['A', 'B', 'C'],
            (TextInput(),
             TextInput(),
             TextInput()
             )
        )
        w2 = copy.deepcopy(w1)
        w2.labels.append('d')
        # w2 ought to be independent of w1: mutating the copy's labels must
        # leave the original's labels unchanged.
        self.assertEqual(w1.labels, ['A', 'B', 'C'])
106 |
107 |
class TestNestedFormField(TestCase):
    """NestedFormField cleaning, change detection and model-field integration."""

    def test_valid(self):
        """clean() maps each sub-value to its key, keeping empty strings."""
        field = forms.NestedFormField(keys=('a', 'b', 'c'))
        value = field.clean(["d", '', "f"])
        self.assertEqual(value, {'a': 'd', 'b': '', 'c': 'f'})

    def test_model_field_formfield_keys(self):
        """HStoreField(keys=...) produces a NestedFormField."""
        model_field = HStoreField(keys=('a', 'b', 'c'))
        form_field = model_field.formfield()
        self.assertIsInstance(form_field, forms.NestedFormField)

    def test_model_field_formfield_fields(self):
        """HStoreField(fields=...) produces a NestedFormField."""
        model_field = HStoreField(fields=(
            ('a', CharField(max_length=10)),
            ('b', CharField(max_length=10)),
            ('c', CharField(max_length=10))
        ))
        form_field = model_field.formfield()
        self.assertIsInstance(form_field, forms.NestedFormField)

    def test_field_has_changed(self):
        """has_changed() reflects submitted sub-widget data."""
        class NestedFormTest(Form):
            f1 = forms.NestedFormField(keys=('a', 'b', 'c'))

        form_w_hstore = NestedFormTest()
        self.assertFalse(form_w_hstore.has_changed())

        # Sub-widget data keys follow the 'f1_<key>' pattern ('f1' with a
        # digit one); the previous 'fl_b' (letter L) never matched a sub-widget.
        form_w_hstore = NestedFormTest({'f1_a': 'd', 'f1_b': 'e', 'f1_c': 'f'})
        self.assertTrue(form_w_hstore.has_changed())

        # NOTE(review): Form.initial is keyed by field name, so these
        # sub-widget-style keys are presumably ignored and 'f1' has no initial
        # value here; consider initial={'f1': {...}} -- confirm against
        # NestedFormField's expected initial format.
        form_w_hstore = NestedFormTest({'f1_a': 'g'},
                                       initial={'f1_a': 'd', 'f1_b': 'e', 'f1_c': 'f'})
        self.assertTrue(form_w_hstore.has_changed())
142 |
--------------------------------------------------------------------------------
/tests/m2m_recursive_array/tests.py:
--------------------------------------------------------------------------------
1 | from __future__ import unicode_literals
2 |
3 | from operator import attrgetter
4 |
5 | from django.test import TestCase
6 |
7 | from .models import Person
8 |
9 |
class RecursiveM2MTests(TestCase):
    """
    Self-referential array M2M relations on Person.

    ``friends`` behaves symmetrically (adding in either direction is visible
    from both sides); ``idols`` / ``stalkers`` are the two non-symmetric sides
    of a single relation.
    """

    def setUp(self):
        self.a, self.b, self.c, self.d = [
            Person.objects.create(name=name)
            for name in ["Anne", "Bill", "Chuck", "David"]
        ]

        # Anne is friends with Bill and Chuck
        self.a.friends.add(self.b, self.c)

        # David is friends with Anne and Chuck - add in reverse direction
        self.d.friends.add(self.a, self.c)

    def test_recursive_m2m_all(self):
        """ Test that m2m relations are reported correctly """
        # Who is friends with Anne?
        self.assertQuerysetEqual(
            self.a.friends.all(), [
                "Bill",
                "Chuck",
                "David"
            ],
            attrgetter("name"),
            ordered=False
        )
        # Who is friends with Bill?
        self.assertQuerysetEqual(
            self.b.friends.all(), [
                "Anne",
            ],
            attrgetter("name")
        )
        # Who is friends with Chuck?
        self.assertQuerysetEqual(
            self.c.friends.all(), [
                "Anne",
                "David"
            ],
            attrgetter("name"),
            ordered=False
        )
        # Who is friends with David?
        self.assertQuerysetEqual(
            self.d.friends.all(), [
                "Anne",
                "Chuck",
            ],
            attrgetter("name"),
            ordered=False
        )

    def test_recursive_m2m_reverse_add(self):
        """ Test reverse m2m relation is consistent """

        # Bill is already friends with Anne - add Anne again, but in the
        # reverse direction
        self.b.friends.add(self.a)

        # Who is friends with Anne?
        self.assertQuerysetEqual(
            self.a.friends.all(), [
                "Bill",
                "Chuck",
                "David",
            ],
            attrgetter("name"),
            ordered=False
        )
        # Who is friends with Bill?
        self.assertQuerysetEqual(
            self.b.friends.all(), [
                "Anne",
            ],
            attrgetter("name")
        )

    def test_recursive_m2m_remove(self):
        """ Test that we can remove items from an m2m relationship """

        # Remove Anne from Bill's friends
        self.b.friends.remove(self.a)

        # Who is friends with Anne?
        self.assertQuerysetEqual(
            self.a.friends.all(), [
                "Chuck",
                "David",
            ],
            attrgetter("name"),
            ordered=False
        )
        # Who is friends with Bill?
        self.assertQuerysetEqual(
            self.b.friends.all(), []
        )

    def test_recursive_m2m_clear(self):
        """ Tests the clear method works as expected on m2m fields """

        # Clear Anne's group of friends
        self.a.friends.clear()

        # Who is friends with Anne?
        self.assertQuerysetEqual(
            self.a.friends.all(), []
        )

        # Reverse relationships should also be gone
        # Who is friends with Chuck?
        self.assertQuerysetEqual(
            self.c.friends.all(), [
                "David",
            ],
            attrgetter("name")
        )

        # Who is friends with David?
        self.assertQuerysetEqual(
            self.d.friends.all(), [
                "Chuck",
            ],
            attrgetter("name")
        )

    def test_recursive_m2m_add_via_related_name(self):
        """ Tests that we can add m2m relations via the related_name attribute """

        # David is idolized by Anne - add in the reverse direction, through
        # the related_name ('stalkers')
        self.d.stalkers.add(self.a)

        # Who are Anne's idols?
        self.assertQuerysetEqual(
            self.a.idols.all(), [
                "David",
            ],
            attrgetter("name"),
            ordered=False
        )
        # Who is stalking Anne? (the relation is not symmetric)
        self.assertQuerysetEqual(
            self.a.stalkers.all(), [],
            attrgetter("name")
        )

    def test_recursive_m2m_add_in_both_directions(self):
        """ Check that adding the same relation twice results in a single relation """

        # Anne idolizes David
        self.a.idols.add(self.d)

        # David is idolized by Anne
        self.d.stalkers.add(self.a)

        # Who are Anne's idols?
        self.assertQuerysetEqual(
            self.a.idols.all(), [
                "David",
            ],
            attrgetter("name"),
            ordered=False
        )
        # An unordered assertQuerysetEqual may not catch duplicates in every
        # Django version, so check explicitly that David is listed only once
        self.assertEqual(self.a.idols.all().count(), 1)

    def test_recursive_m2m_related_to_self(self):
        """ Check the expected behavior when an instance is related to itself """

        # Anne idolizes herself
        self.a.idols.add(self.a)

        # Who are Anne's idols?
        self.assertQuerysetEqual(
            self.a.idols.all(), [
                "Anne",
            ],
            attrgetter("name"),
            ordered=False
        )
        # Who is stalking Anne?
        self.assertQuerysetEqual(
            self.a.stalkers.all(), [
                "Anne",
            ],
            attrgetter("name")
        )
196 |
--------------------------------------------------------------------------------
/django_postgres_extensions/models/fields/__init__.py:
--------------------------------------------------------------------------------
1 | from django.contrib.postgres import fields
2 | from django.contrib.postgres.forms import SplitArrayField as SplitArrayFormField
3 | from django.forms.fields import TypedMultipleChoiceField
4 | from psycopg2.extras import Json
5 | from django_postgres_extensions.forms.fields import NestedFormField
6 | from django_postgres_extensions.models.expressions import F, Value as V
7 | from django_postgres_extensions.models.functions import HStore, Delete, ArrayRemove
8 | from django_postgres_extensions.models.sql.updates import UpdateArrayByIndex
9 | from django.core import exceptions
10 |
11 |
class ArrayField(fields.ArrayField):
    """
    Drop-in replacement for ``django.contrib.postgres.fields.ArrayField``.

    Adds index-based and ``del`` updates (see ``get_update_type``), an
    optional split-array / multiple-choice form field, and choice validation
    that accepts list values.
    """

    def __init__(self, base_field, form_size=None, **kwargs):
        super(ArrayField, self).__init__(base_field, **kwargs)
        # Size of the split-array form field; when falsy (and no choices are
        # set) formfield() falls back to the contrib.postgres form field.
        self.form_size = form_size

    def get_update_type(self, indexes, value):
        """
        Return the update expression for an ``<field>__<lookup>=value`` update.

        ``indexes`` is the lookup after the field name, given either as a
        ``'__'``-joined string (``'1__2'``) or as a list/tuple of parts
        (``['1', '2']``).

        - ``'del'`` removes ``value`` from the array (``ArrayRemove``).
        - Otherwise every part must be an integer index; indexes are
          zero-based and converted to PostgreSQL's one-based indexing.

        Raises ValueError if any part is not an integer.
        """
        if not isinstance(indexes, (list, tuple)):
            # Always split strings: a bare '12' must become ['12'], not be
            # iterated character by character into ['1', '2'].
            indexes = indexes.split('__')
        indexes = list(indexes)
        if indexes == ['del']:
            return ArrayRemove(self.name, value)
        try:
            # PostgreSQL arrays are 1-based.
            positions = [int(index) + 1 for index in indexes]
        except ValueError:
            raise ValueError('Update lookup type %s not found for field %s'
                             % ('__'.join(str(index) for index in indexes), self.name))
        return UpdateArrayByIndex(positions, value, self)

    def formfield(self, **kwargs):
        """
        Return the form field for this array.

        With ``form_size`` or ``choices`` set, build it through
        ``Field.formfield`` with a SplitArrayField (or, when choices are set,
        a TypedMultipleChoiceField coercing with the base field's
        ``to_python``). Otherwise defer to contrib.postgres ArrayField.
        """
        if self.form_size or self.choices:
            defaults = {
                'form_class': SplitArrayFormField,
                'base_field': self.base_field.formfield(),
                'choices_form_class': TypedMultipleChoiceField,
                'size': self.form_size,
                'remove_trailing_nulls': True
            }
            if self.choices:
                defaults['coerce'] = self.base_field.to_python
            defaults.update(kwargs)
            # Deliberately skip contrib.postgres ArrayField.formfield and
            # call the next implementation in the MRO directly.
            return super(fields.ArrayField, self).formfield(**defaults)
        return super(ArrayField, self).formfield(**kwargs)

    def validate(self, value, model_instance):
        """
        Validate ``value`` against choices, null and blank constraints.

        Non-editable fields are not validated. A list/tuple value is valid
        when every element is one of the top-level choice keys; a scalar is
        checked against choices and optgroups. Raises ValidationError with
        code 'invalid_choice', 'null' or 'blank'.

        NOTE(review): this replaces contrib.postgres ArrayField.validate
        without calling it, so per-element base_field validation does not
        run -- confirm that is intended.
        """
        if not self.editable:
            # Skip validation for non-editable fields.
            return

        if self.choices and value not in self.empty_values:
            if isinstance(value, (list, tuple)):
                option_keys = [x[0] for x in self.choices]
                if all(x in option_keys for x in value):
                    return
            else:
                for option_key, option_value in self.choices:
                    if isinstance(option_value, (list, tuple)):
                        # This is an optgroup, so look inside the group for
                        # options.
                        for optgroup_key, optgroup_value in option_value:
                            if value == optgroup_key:
                                return
                    elif value == option_key:
                        return
            raise exceptions.ValidationError(
                self.error_messages['invalid_choice'],
                code='invalid_choice',
                params={'value': value},
            )

        if value is None and not self.null:
            raise exceptions.ValidationError(self.error_messages['null'], code='null')

        if not self.blank and value in self.empty_values:
            raise exceptions.ValidationError(self.error_messages['blank'], code='blank')

    def deconstruct(self):
        """Include ``form_size`` so migrations reconstruct the field faithfully."""
        name, path, args, kwargs = super(ArrayField, self).deconstruct()
        kwargs.update({
            'form_size': self.form_size,
        })
        return name, path, args, kwargs
86 |
class HStoreField(fields.HStoreField):
    """
    Drop-in replacement for ``django.contrib.postgres.fields.HStoreField``.

    Adds key-level updates (see ``get_update_type``) and an optional
    NestedFormField form field built from ``fields`` or ``keys``.
    """

    def __init__(self, fields=(), keys=(), max_value_length=25, require_all_fields=False, **kwargs):
        super(HStoreField, self).__init__(**kwargs)
        # Form options forwarded to NestedFormField by formfield().
        self.fields = fields
        self.keys = keys
        self.max_value_length = max_value_length
        self.require_all_fields = require_all_fields

    def get_update_type(self, lookups, value):
        """
        Return the update expression for ``<field>__<lookup>=value``.

        Only ``lookups[0]`` is used:

        - ``''``: merge the dict ``value`` into the stored hstore via
          ``F(field).cat(HStore(keys, values))``, converting values to text
          (``None`` stays NULL).
        - ``'raw'``: the same merge without value conversion.
        - ``'del'``: ``Delete(field, value)``.

        Raises ValueError for any other lookup.
        """
        lookup = lookups[0]
        if lookup in ('', 'raw'):
            keys = list(value.keys())
            values = list(value.values())
            if lookup == '':
                # hstore values are text, but NULL is allowed: keep None as
                # NULL (as Django's HStoreField does when preparing values)
                # instead of storing the string 'None'.
                values = [v if v is None else str(v) for v in values]
            return F(self.name).cat(HStore(V(keys), V(values)))
        if lookup == 'del':
            return Delete(self.name, value)
        raise ValueError('Update lookup type %s not found for field %s' % (lookup, self.name))

    def formfield(self, **kwargs):
        """Use NestedFormField when ``fields`` or ``keys`` were given."""
        if self.fields or self.keys:
            defaults = {
                'form_class': NestedFormField,
                'fields': self.fields,
                'keys': list(self.keys),
                'require_all_fields': self.require_all_fields,
                'max_value_length': self.max_value_length
            }
            defaults.update(kwargs)
        else:
            defaults = kwargs
        return super(HStoreField, self).formfield(**defaults)
121 |
class JSONField(fields.JSONField):
    """
    Drop-in replacement for ``django.contrib.postgres.fields.JSONField``.

    Supports key-level updates and deletions (``<field>__`` and
    ``<field>__del``) and an optional NestedFormField built from ``fields``.
    """

    def __init__(self, fields=(), require_all_fields=False, **kwargs):
        super(JSONField, self).__init__(**kwargs)
        # Form options forwarded to NestedFormField by formfield().
        self.require_all_fields = require_all_fields
        self.fields = fields

    def get_update_type(self, lookups, value):
        """
        Build the update expression for ``<field>__<lookup>=value``.

        ``''`` merges ``value`` (wrapped in ``Json``) into the document.
        ``'del'`` removes a key; a ``'__'``-separated ``value`` is treated as
        a key path. Any other lookup raises ValueError.
        """
        action = lookups[0]
        column = F(self.name)
        if action == 'del':
            if '__' not in value:
                return column - V(value)
            return column.delete(V(value.split('__')))
        if action == '':
            return column.cat(V(Json(value)))
        raise ValueError('Update lookup type %s not found for field %s' % (action, self.name))

    def formfield(self, **kwargs):
        """Use NestedFormField when ``fields`` were given, else the default."""
        if not self.fields:
            return super(JSONField, self).formfield(**kwargs)
        options = dict(
            form_class=NestedFormField,
            fields=self.fields,
            require_all_fields=self.require_all_fields,
        )
        options.update(kwargs)
        return super(JSONField, self).formfield(**options)
151 |
--------------------------------------------------------------------------------
/docs/features.rst:
--------------------------------------------------------------------------------
1 | Feature Overview
2 | ================
3 | Custom Postgres backend
4 | -----------------------
5 | The customized Postgres backend adds the following features:
6 |
7 | - The HStore extension is automatically activated when a test database is created, so you don't need to create a separate migration; this is especially useful when building reusable apps
8 | - Uses a different update compiler which adds some functionality outlined in the ArrayField section below
9 | - If db_index is set to True for an ArrayField, a GIN index will be created which is more useful than the default database index for arrays.
10 | - Adds some extra operators to enable ANY and ALL lookups
11 |
12 | ArrayField
13 | ----------
14 | The included ArrayField has been subclassed from django.contrib.postgres.fields.ArrayField to add extra features and is a drop-in replacement: to use it, import ArrayField from django_postgres_extensions.models.fields instead of django.contrib.postgres.fields. The customized Postgres ArrayField adds the following features:
15 |
16 | - Get array values by index
17 | - Update array values by index
18 | - Added database functions for interacting with Arrays. These functions handle the provided arguments by automatically converting them to the required expressions.
19 | - Add an array of values to an existing field. In this case the output_field is required to tell Django what database type to use for the array.
20 | - Additional lookups have been added to the ArrayField to enable queries using the ANY and ALL database functions.
21 | - Use either a split array field or a multiple choice field in a model form
22 |
23 | HStoreField
24 | -----------
25 | The included HStoreField has been subclassed from django.contrib.postgres.fields.HStoreField to add extra features and is a drop-in replacement: to use it, import HStoreField from django_postgres_extensions.models.fields instead of django.contrib.postgres.fields.
26 |
27 | The customized Postgres HStoreField adds the following features:
28 |
29 | - Get hstore values by key
30 | - Update hstore by specific keys, leaving any others untouched
31 | - Added database functions for interacting with HStores, these functions handle the arguments by converting them to the correct expressions automatically.
32 | - Nested form field for a better representation of hstore in a form, by providing either a list of keys or a list of form fields.
33 |
34 | JSONField
35 | ---------
36 | The included JSONField has been subclassed from django.contrib.postgres.fields.JSONField to add extra features and is a drop-in replacement: to use it, import JSONField from django_postgres_extensions.models.fields instead of django.contrib.postgres.fields.
37 |
38 | The customized Postgres JSONField adds the following features:
39 |
40 | - Get json values by key or key path
41 | - Update a JSONField by specific keys, leaving any others untouched
42 | - Delete keys from a JSONField by key or key path
43 | - Extra database functions for interacting with JSONFields. These functions handle the arguments by converting them to the correct expressions automatically.
44 | - The same NestedFormField and NestedFormWidget referred to above for HStore can also be used with a JSONField by providing a list of fields.
45 |
46 | ModelAdmin
47 | ----------
48 |
49 | For an example of how these fields can be configured in a ModelForm, take the following models.py::
50 |
51 | from django.db import models
52 | from django_postgres_extensions.models.fields import HStoreField, JSONField, ArrayField
53 | from django_postgres_extensions.models.fields.related import ArrayManyToManyField
54 | from django import forms
55 | from django.contrib.postgres.forms import SplitArrayField
56 | from django_postgres_extensions.forms.fields import NestedFormField
57 |
58 | details_fields = (
59 | ('Brand', NestedFormField(keys=('Name', 'Country'))),
60 | ('Type', forms.CharField(max_length=25, required=False)),
61 | ('Colours', SplitArrayField(base_field=forms.CharField(max_length=10, required=False), size=10)),
62 | )
63 |
64 | class Buyer(models.Model):
65 | time = models.DateTimeField(auto_now_add=True)
66 | name = models.CharField(max_length=20)
67 |
68 | def __str__(self):
69 | return self.name
70 |
71 | class Product(models.Model):
72 | name = models.CharField(max_length=15)
73 | keywords = ArrayField(models.CharField(max_length=20), default=list, form_size=10, blank=True)
74 | sports = ArrayField(models.CharField(max_length=20), default=list, blank=True, choices=(
75 | ('football', 'Football'), ('tennis', 'Tennis'), ('golf', 'Golf'), ('basketball', 'Basketball'), ('hurling', 'Hurling'), ('baseball', 'Baseball')))
76 | shipping = HStoreField(keys=('Address', 'City', 'Region', 'Country'), blank=True, default=dict)
77 | details = JSONField(fields=details_fields, blank=True, default=dict)
78 | buyers = ArrayManyToManyField(Buyer)
79 |
80 | def __str__(self):
81 | return self.name
82 |
83 | @property
84 | def country(self):
85 | return self.shipping.get('Country', '')
86 |
87 | And with admin.py::
88 |
89 | from django.contrib import admin
90 | from django_postgres_extensions.admin.options import PostgresAdmin
91 | from .models import Product, Buyer
92 |
93 | class ProductAdmin(PostgresAdmin):
94 | filter_horizontal = ('buyers',)
95 | fields = ('name', 'keywords', 'sports', 'shipping', 'details', 'buyers')
96 | list_display = ('name', 'keywords', 'shipping', 'details', 'country')
97 |
98 | admin.site.register(Buyer)
99 | admin.site.register(Product, ProductAdmin)
100 |
101 | The form field would look like this:
102 |
103 | .. fancybox:: admin_form.jpg
104 | :width: 100%
105 | :height: 100%
106 |
107 | The list display would look like this:
108 |
109 | .. fancybox:: admin_list.jpg
110 | :width: 100%
111 | :height: 100%
112 |
113 | Additional Queryset Methods
114 | ---------------------------
115 | The app adds a format() method to all querysets. It defers the given field and adds an annotation holding that field converted by the given expression, named ``<field>__alt`` unless another alias is passed.
116 | For example, to return an HStoreField as JSON::
117 |
118 | qs = Model.objects.all().format('description', HstoreToJSONBLoose)
119 |
120 | Array Many To Many Field
121 | ------------------------
122 | The Array Many To Many Field is designed to be a drop-in replacement for the normal Django Many To Many Field and thus replicates many of its features.
123 |
124 | The Array Many To Many field supports the following features which replicate the API of the regular Many To Many Field:
125 |
126 | - Descriptor queryset with add, remove, clear and set for both forward and reverse relationships
127 | - Prefetch related for both forward and reverse relationships
128 | - Lookups across relationships with filter for both forward and reverse relationships
129 | - Lookups across relationships with exclude for forward relationships only
130 |
131 | Indices and tables
132 | ==================
133 |
134 | * :ref:`genindex`
135 | * :ref:`modindex`
136 | * :ref:`search`
137 |
--------------------------------------------------------------------------------
/tests/hstores/tests.py:
--------------------------------------------------------------------------------
1 | from __future__ import unicode_literals
2 |
3 | from django.test import TestCase
4 | from .models import Product
5 | from django_postgres_extensions.models.functions import *
6 | from django_postgres_extensions.models.expressions import Key, Keys
7 | from django.db.utils import ProgrammingError
8 | from django.db import transaction
9 |
10 |
class HStoreIndexTests(TestCase):
    """Reading single keys from, and key-wise updating of, the customized HStoreField."""

    def setUp(self):
        super(HStoreIndexTests, self).setUp()
        initial = {'Industry': 'Music', 'Release': 'Album', 'Genre': 'Rock'}
        self.product = Product(name='xyz', description=initial)
        self.product.save()
        self.queryset = Product.objects.filter(pk=self.product.pk)

    def tearDown(self):
        Product.objects.all().delete()

    def test_hstore_value(self):
        fetched = self.queryset.get()
        expected = {'Industry': 'Music', 'Release': 'Album', 'Genre': 'Rock'}
        self.assertDictEqual(fetched.description, expected)

    def test_hstore_value_by_key(self):
        # Key() exposes one hstore value as the '<field>__<key>' attribute.
        with transaction.atomic():
            row = self.queryset.annotate(Key('description', 'Release')).get()
            self.assertEqual(row.description__Release, 'Album')

    def test_hstore_values_by_keys(self):
        # Keys() exposes the requested values, in request order, as '<field>__selected'.
        with transaction.atomic():
            annotated = self.queryset.annotate(Keys('description', ['Industry', 'Release']))
            row = annotated.get()
            self.assertListEqual(row.description__selected, ['Music', 'Album'])

    def test_array_update_keys_values(self):
        # NOTE(review): despite its name this exercises the HStoreField.
        # The ``description__`` keyword merges the given pairs into the stored
        # hstore: 'Genre' is overwritten, 'Popularity' is added, the rest stay.
        merged_in = {'Genre': 'Heavy Metal', 'Popularity': 'Very Popular'}
        with transaction.atomic():
            self.queryset.update(description__=merged_in)
            refreshed = self.queryset.get()
            self.assertDictEqual(refreshed.description, {'Industry': 'Music', 'Release': 'Album', 'Genre': 'Heavy Metal',
                                                         'Popularity': 'Very Popular'})

    def test_hstore_raw_int_raises(self):
        with transaction.atomic():
            # The regular key-wise update is expected to accept an int value.
            self.queryset.update(description__={'Popularity': 5})
            # The ``__raw`` form is expected to be rejected by the database.
            # Because this happens inside the atomic block, the savepoint it
            # created is rolled back on exit and the test transaction stays
            # usable for tearDown().
            with self.assertRaises(ProgrammingError):
                self.queryset.update(description__raw={'Popularity': 5})
48 |
class HstoreFuncTests(TestCase):
    """HStore database functions (HStore, Slice, Delete, AKeys, ...) applied to a Product."""

    def setUp(self):
        super(HstoreFuncTests, self).setUp()
        self.product = Product(
            name='xyz',
            description={'Industry': 'Music', 'Release': 'Album', 'Genre': 'Rock', 'Rating': '8'},
            details={'Popularity': 'Very Popular'},
        )
        self.product.save()
        self.queryset = Product.objects.filter(pk=self.product.pk)

    def tearDown(self):
        Product.objects.all().delete()

    def test_hstore_new(self):
        # HStore(keys, values) pairs two parallel lists into a new hstore value.
        keys = ['Industry', 'Release', 'Genre', 'Popularity']
        values = ['Film', 'Movie', 'Horror', 'Very Good']
        created = Product(name='xyz', description=HStore(keys, values))
        created.save()
        fetched = Product.objects.filter(id=created.id).get()
        self.assertDictEqual(
            fetched.description,
            {'Genre': 'Horror', 'Release': 'Movie', 'Industry': 'Film', 'Popularity': 'Very Good'},
        )

    def test_hstore_slice(self):
        with transaction.atomic():
            sliced = Slice('description', ['Industry', 'Release'])
            row = self.queryset.annotate(description_slice=sliced).get()
            self.assertDictEqual(row.description_slice, {'Release': 'Album', 'Industry': 'Music'})

    def test_hstore_delete_key(self):
        with transaction.atomic():
            self.queryset.update(description=Delete('description', 'Genre'))
            remaining = self.queryset.get().description
            self.assertDictEqual(remaining, {'Industry': 'Music', 'Release': 'Album', 'Rating': '8'})

    def test_hstore_delete_keys(self):
        with transaction.atomic():
            self.queryset.update(description=Delete('description', ['Industry', 'Genre']))
            remaining = self.queryset.get().description
            self.assertDictEqual(remaining, {'Release': 'Album', 'Rating': '8'})

    def test_hstore_delete_by_dict(self):
        # Only pairs whose key AND value both match are removed, so 'Release'
        # survives ('Song' does not match the stored 'Album').
        with transaction.atomic():
            pairs = {'Industry': 'Music', 'Release': 'Song', 'Genre': 'Rock'}
            self.queryset.update(description=Delete('description', pairs))
            remaining = self.queryset.get().description
            self.assertDictEqual(remaining, {'Release': 'Album', 'Rating': '8'})

    def test_hstore_keys_as_array(self):
        # Key order is not asserted here; the result is compared sorted.
        with transaction.atomic():
            row = self.queryset.annotate(description_keys=AKeys('description')).get()
            self.assertListEqual(sorted(row.description_keys), ['Genre', 'Industry', 'Rating', 'Release'])

    def test_hstore_values_as_array(self):
        # Value order is not asserted here; the result is compared sorted.
        with transaction.atomic():
            row = self.queryset.annotate(description_values=AVals('description')).get()
            self.assertListEqual(sorted(row.description_values), ['8', 'Album', 'Music', 'Rock'])

    def test_hstore_to_array(self):
        # The expected list pins hstore's own key order, not insertion order.
        with transaction.atomic():
            row = self.queryset.annotate(description_array=HStoreToArray('description')).get()
            self.assertListEqual(row.description_array, ['Genre', 'Rock', 'Rating', '8', 'Release', 'Album', 'Industry', 'Music'])

    def test_hstore_to_matrix(self):
        # The expected matrix pins hstore's own key order, not insertion order.
        with transaction.atomic():
            row = self.queryset.annotate(description_matrix=HStoreToMatrix('description')).get()
            self.assertListEqual(row.description_matrix, [['Genre', 'Rock'], ['Rating', '8'], ['Release', 'Album'], ['Industry', 'Music']])

    def test_hstore_to_jsonb(self):
        # The strict conversion keeps 'Rating' as the string "8".
        with transaction.atomic():
            row = self.queryset.annotate(description_jsonb=HstoreToJSONB('description')).get()
            self.assertDictEqual(row.description_jsonb,
                                 {'Genre': 'Rock', 'Release': 'Album', 'Industry': 'Music', 'Rating': "8"})

    def test_hstore_to_jsonb_loose(self):
        # The loose conversion turns the numeric-looking 'Rating' into the number 8.
        with transaction.atomic():
            row = self.queryset.annotate(description_jsonb=HstoreToJSONBLoose('description')).get()
            self.assertDictEqual(row.description_jsonb,
                                 {'Genre': 'Rock', 'Release': 'Album', 'Industry': 'Music', 'Rating': 8})

    def test_queryset_format(self):
        # QuerySet.format() defers 'description' and exposes the converted
        # value under the default alias 'description__alt'.
        with transaction.atomic():
            row = self.queryset.format('description', HstoreToJSONBLoose).get()
            self.assertDictEqual(row.description__alt,
                                 {'Genre': 'Rock', 'Release': 'Album', 'Industry': 'Music', 'Rating': 8})
--------------------------------------------------------------------------------
/readme.rst:
--------------------------------------------------------------------------------
1 | Django Postgres Extensions!
2 | ===========================
3 |
4 | Django Postgres Extensions adds a lot of functionality to django.contrib.postgres, specifically in relation to ArrayField, HStoreField and JSONField, including much better form fields for dealing with these field types. The app also includes an Array Many To Many Field, so you can store the relationship in an array column instead of requiring an extra database table.
5 |
6 | Check out http://django-postgres-extensions.readthedocs.io/en/latest/ to get started.
7 |
8 | Latest release (0.9.3) tested with Django 2.0.2
9 |
10 | Feature Overview
11 | ================
12 | Custom Postgres backend
13 | -----------------------
14 | The customized Postgres backend adds the following features:
15 |
16 | - The HStore extension is automatically activated when a test database is created, so you don't need a separate migration to enable it; this is especially useful when building reusable apps.
17 | - Uses a different update compiler which adds some functionality outlined in the ArrayField section below
18 | - If db_index is set to True for an ArrayField, a GIN index will be created which is more useful than the default database index for arrays.
19 | - Adds some extra operators to enable ANY and ALL lookups
20 |
21 | ArrayField
22 | ----------
23 | The included ArrayField has been subclassed from django.contrib.postgres.fields.ArrayField to add extra features and is a drop-in replacement. To use it, import ArrayField from django_postgres_extensions.models.fields instead of django.contrib.postgres.fields. The customized Postgres ArrayField adds the following features:
24 |
25 | - Get array values by index
26 | - Update array values by index
27 | - Added database functions for interacting with Arrays. These functions handle the provided arguments by automatically converting them to the required expressions.
28 | - Add an array of values to an existing field. In this case the output_field is required to tell Django what db type to use for the array.
29 | - Additional lookups have been added to the ArrayField to enable queries using the ANY and ALL database functions.
30 | - Use either a split array field or a multiple choice field in a model form
31 |
32 | HStoreField
33 | -----------
34 | The included HStoreField has been subclassed from django.contrib.postgres.fields.HStoreField to add extra features and is a drop-in replacement. To use it, import HStoreField from django_postgres_extensions.models.fields instead of django.contrib.postgres.fields.
35 |
36 | The customized Postgres HStoreField adds the following features:
37 |
38 | - Get hstore values by key
39 | - Update hstore by specific keys, leaving any others untouched
40 | - Added database functions for interacting with HStores. These functions handle the arguments by automatically converting them to the correct expressions.
41 | - Nested form field for a better representation of hstore in a form, either by providing a list of keys or list of form fields.
42 |
43 | JSONField
44 | ---------
45 | The included JSONField has been subclassed from django.contrib.postgres.fields.JSONField to add extra features and is a drop-in replacement. To use it, import JSONField from django_postgres_extensions.models.fields instead of django.contrib.postgres.fields.
46 |
47 | The customized Postgres JSONField adds the following features:
48 |
49 | - Get json values by key or key path
50 | - Update JSON Field by specific keys, leaving any others untouched
51 | - Delete JSONField by key or key path
52 | - Extra database functions for interacting with JSONFields. These functions handle the arguments by converting them to the correct expressions automatically.
53 | - The same NestedFormField and NestedFormWidget referred to above for HStore can also be used with a JSON Field by providing a list of fields.
54 |
55 | ModelAdmin
56 | ----------
57 |
58 | For an example of how these fields can be configured in a ModelForm, take the following models.py::
59 |
60 | from django.db import models
61 | from django_postgres_extensions.models.fields import HStoreField, JSONField, ArrayField
62 | from django_postgres_extensions.models.fields.related import ArrayManyToManyField
63 | from django import forms
64 | from django.contrib.postgres.forms import SplitArrayField
65 | from django_postgres_extensions.forms.fields import NestedFormField
66 |
67 | details_fields = (
68 | ('Brand', NestedFormField(keys=('Name', 'Country'))),
69 | ('Type', forms.CharField(max_length=25, required=False)),
70 | ('Colours', SplitArrayField(base_field=forms.CharField(max_length=10, required=False), size=10)),
71 | )
72 |
73 | class Buyer(models.Model):
74 | time = models.DateTimeField(auto_now_add=True)
75 | name = models.CharField(max_length=20)
76 |
77 | def __str__(self):
78 | return self.name
79 |
80 | class Product(models.Model):
81 | name = models.CharField(max_length=15)
82 | keywords = ArrayField(models.CharField(max_length=20), default=list, form_size=10, blank=True)
83 | sports = ArrayField(models.CharField(max_length=20), default=list, blank=True, choices=(
84 | ('football', 'Football'), ('tennis', 'Tennis'), ('golf', 'Golf'), ('basketball', 'Basketball'), ('hurling', 'Hurling'), ('baseball', 'Baseball')))
85 | shipping = HStoreField(keys=('Address', 'City', 'Region', 'Country'), blank=True, default=dict)
86 | details = JSONField(fields=details_fields, blank=True, default=dict)
87 | buyers = ArrayManyToManyField(Buyer)
88 |
89 | def __str__(self):
90 | return self.name
91 |
92 | @property
93 | def country(self):
94 | return self.shipping.get('Country', '')
95 |
96 | And with admin.py::
97 |
98 | from django.contrib import admin
99 | from django_postgres_extensions.admin.options import PostgresAdmin
100 | from .models import Product, Buyer
101 |
102 | class ProductAdmin(PostgresAdmin):
103 | filter_horizontal = ('buyers',)
104 | fields = ('name', 'keywords', 'sports', 'shipping', 'details', 'buyers')
105 | list_display = ('name', 'keywords', 'shipping', 'details', 'country')
106 |
107 | admin.site.register(Buyer)
108 | admin.site.register(Product, ProductAdmin)
109 |
110 | The form field would look like this:
111 |
112 | .. image:: docs/admin_form.jpg
113 |
114 | The list display would look like this:
115 |
116 | .. image:: docs/admin_list.jpg
117 |
118 | Additional Queryset Methods
119 | ---------------------------
120 | The app adds a format() method to all querysets. It defers the given field and adds an annotation holding that field converted by the given expression, named ``<field>__alt`` unless another alias is passed.
121 | For example, to return an HStoreField as JSON::
122 |
123 | qs = Model.objects.all().format('description', HstoreToJSONBLoose)
124 |
125 | Array Many To Many Field
126 | ------------------------
127 | The Array Many To Many Field is designed to be a drop-in replacement for the normal Django Many To Many Field and thus replicates many of its features.
128 |
129 | The Array Many To Many field supports the following features which replicate the API of the regular Many To Many Field:
130 |
131 | - Descriptor queryset with add, remove, clear and set for both forward and reverse relationships
132 | - Prefetch related for both forward and reverse relationships
133 | - Lookups across relationships with filter for both forward and reverse relationships
134 | - Lookups across relationships with exclude for forward relationships only
135 |
--------------------------------------------------------------------------------
/django_postgres_extensions/models/query.py:
--------------------------------------------------------------------------------
1 | from django.db import transaction
2 | from django.core import exceptions
3 | from django.db.models.constants import LOOKUP_SEP
4 | from django.db.models.sql.constants import CURSOR
5 | from .sql import UpdateQuery
6 | import copy
7 |
8 |
9 |
def update(self, **kwargs):
    """
    Apply ``kwargs`` as an UPDATE to every row matched by this QuerySet.

    Unlike Django's version, the query is chained into this package's own
    ``UpdateQuery`` (imported from ``.sql``). That lets update keywords it
    understands, such as the key-wise ``description__={...}`` form used in
    the hstore tests, go through the normal ``QuerySet.update()`` call.

    Returns the result of ``execute_sql(CURSOR)`` on the compiler (the
    affected row count in stock Django).
    """
    assert self.query.can_filter(), \
        "Cannot update a query once a slice has been taken."
    self._for_write = True
    update_query = self.query.chain(UpdateQuery)
    update_query.add_update_values(kwargs)
    # savepoint=False: a single statement needs no savepoint of its own; on
    # failure an enclosing atomic block is marked for rollback instead.
    with transaction.atomic(using=self.db, savepoint=False):
        compiler = update_query.get_compiler(self.db)
        affected_rows = compiler.execute_sql(CURSOR)
    self._result_cache = None
    return affected_rows
update.alters_data = True
25 |
26 |
def _update(self, values):
    """
    A version of update that accepts field objects instead of field names.
    Used primarily for model saving and not intended for use by general
    code (it requires too much poking around at model internals to be
    useful at that level).

    ``values`` is a sequence whose items carry the field object first
    (``value[0]``); in stock Django these are ``(field, model, value)``
    triples built by ``Model._do_update()``. Fields flagged
    ``many_to_many_array`` are dropped before the UPDATE is built.

    Returns the number of affected rows. If every field was dropped, the
    number of rows matched by this QuerySet is returned instead, so that
    ``Model.save()`` does not mistake a no-op update for a missing row.
    """
    assert self.query.can_filter(), \
        "Cannot update a query once a slice has been taken."
    # Fields flagged many_to_many_array (presumably ArrayManyToManyField
    # columns) are never written through this path.
    values = [value for value in values if not getattr(value[0], 'many_to_many_array', False)]
    self._result_cache = None
    if not values:
        # An UPDATE with no SET clause is skipped by the compiler and reports
        # 0 rows (stock Django behaviour). Model.save() would read that as
        # "row does not exist" and either attempt an INSERT or raise
        # DatabaseError when update_fields was given. Report the matched rows
        # instead, as Django's own _do_update() does when it has nothing to
        # write.
        return self.count()
    query = self.query.chain(UpdateQuery)
    query.add_update_fields(values)
    return query.get_compiler(self.db).execute_sql(CURSOR)
_update.alters_data = True
_update.queryset_only = False
43 |
def format(self, field, expression, output_field=None, *args, **kwargs):
    """
    Return a QuerySet with ``field`` deferred and re-exposed through ``expression``.

    ``expression`` is called as ``expression(field, *args, **kwargs)`` and its
    result is annotated under ``output_field``. Despite its name, that argument
    is the annotation alias, not a Django output field. When it is falsy, the
    alias defaults to ``'<field>__alt'``.

    Example: ``qs.format('description', HstoreToJSONBLoose)`` exposes the
    hstore as JSON on ``obj.description__alt``.
    """
    alias = output_field or field + '__alt'
    annotation = {alias: expression(field, *args, **kwargs)}
    return self.defer(field).annotate(**annotation)
50 |
def prefetch_one_level(instances, prefetcher, lookup, level):
    """
    Helper function for prefetch_related_objects().

    Run prefetches on all instances using the prefetcher object,
    assigning results to relevant caches in instance.

    Return the prefetched objects along with any additional prefetches that
    must be done due to prefetch_related lookups found from default managers.

    When the related queryset has a truthy ``is_multi_reference`` attribute,
    related objects are not bucketed by key. Instead, for every instance,
    ``rel_obj_attr(rel_obj, instance_attr_val)`` is called as a predicate over
    all related objects.
    """
    # prefetcher must have a method get_prefetch_queryset() which takes a list
    # of instances, and returns a tuple:

    # (queryset of instances of self.model that are related to passed in instances,
    # callable that gets value to be matched for returned instances,
    # callable that gets value to be matched for passed in instances,
    # boolean that is True for singly related objects,
    # cache or field name to assign to,
    # boolean that is True when the previous argument is a cache name vs a field name).

    # The 'values to be matched' must be hashable as they will be used
    # in a dictionary. (This applies to the regular path only; the
    # multi-reference path matches with a predicate instead of a dict.)

    rel_qs, rel_obj_attr, instance_attr, single, cache_name, is_descriptor = (
        prefetcher.get_prefetch_queryset(instances, lookup.get_current_queryset(level)))
    # We have to handle the possibility that the QuerySet we just got back
    # contains some prefetch_related lookups. We don't want to trigger the
    # prefetch_related functionality by evaluating the query. Rather, we need
    # to merge in the prefetch_related lookups.
    # Copy the lookups in case it is a Prefetch object which could be reused
    # later (happens in nested prefetch_related).
    additional_lookups = [
        copy.copy(additional_lookup) for additional_lookup
        in getattr(rel_qs, '_prefetch_related_lookups', ())
    ]
    if additional_lookups:
        # Don't need to clone because the manager should have given us a fresh
        # instance, so we access an internal instead of using public interface
        # for performance reasons.
        rel_qs._prefetch_related_lookups = ()

    all_related_objects = list(rel_qs)

    # Presumably set by this package's array many-to-many prefetchers, where
    # one related object may belong to several instances; such results are
    # matched per instance below rather than grouped by a single key.
    is_multi_reference = getattr(rel_qs, 'is_multi_reference', False)

    rel_obj_cache = {}
    if not is_multi_reference:
        for rel_obj in all_related_objects:
            rel_attr_val = rel_obj_attr(rel_obj)
            rel_obj_cache.setdefault(rel_attr_val, []).append(rel_obj)

    to_attr, as_attr = lookup.get_current_to_attr(level)
    # Make sure `to_attr` does not conflict with a field.
    if as_attr and instances:
        # We assume that objects retrieved are homogeneous (which is the premise
        # of prefetch_related), so what applies to first object applies to all.
        model = instances[0].__class__
        try:
            model._meta.get_field(to_attr)
        except exceptions.FieldDoesNotExist:
            pass
        else:
            msg = 'to_attr={} conflicts with a field on the {} model.'
            raise ValueError(msg.format(to_attr, model.__name__))

    # Whether or not we're prefetching the last part of the lookup.
    leaf = len(lookup.prefetch_through.split(LOOKUP_SEP)) - 1 == level

    for obj in instances:
        instance_attr_val = instance_attr(obj)
        if is_multi_reference:
            # Linear scan per instance: len(instances) * len(all_related_objects)
            # predicate calls in total.
            vals = [rel_obj for rel_obj in all_related_objects if rel_obj_attr(rel_obj, instance_attr_val)]
        else:
            vals = rel_obj_cache.get(instance_attr_val, [])

        if single:
            val = vals[0] if vals else None
            if as_attr:
                # A to_attr has been given for the prefetch.
                setattr(obj, to_attr, val)
            elif is_descriptor:
                # cache_name points to a field name in obj.
                # This field is a descriptor for a related object.
                setattr(obj, cache_name, val)
            else:
                # No to_attr has been given for this prefetch operation and the
                # cache_name does not point to a descriptor. Store the value of
                # the field in the object's field cache.
                obj._state.fields_cache[cache_name] = val
        else:
            if as_attr:
                setattr(obj, to_attr, vals)
            else:
                manager = getattr(obj, to_attr)
                if leaf and lookup.queryset is not None:
                    qs = manager._apply_rel_filters(lookup.queryset)
                else:
                    qs = manager.get_queryset()
                qs._result_cache = vals
                # We don't want the individual qs doing prefetch_related now,
                # since we have merged this into the current work.
                qs._prefetch_done = True
                obj._prefetched_objects_cache[cache_name] = qs
    return all_related_objects, additional_lookups
--------------------------------------------------------------------------------
/tests/prefetch_related_array/models.py:
--------------------------------------------------------------------------------
1 | import uuid
2 |
3 | from django.contrib.contenttypes.fields import (
4 | GenericForeignKey, GenericRelation,
5 | )
6 | from django.contrib.contenttypes.models import ContentType
7 | from django.db import models
8 | from django.utils.encoding import python_2_unicode_compatible
9 | from django_postgres_extensions.models.fields.related import ArrayManyToManyField
10 |
11 |
12 | # Basic tests
13 |
@python_2_unicode_compatible
class Author(models.Model):
    """Author with a required first book and an array-backed set of favourite authors."""
    name = models.CharField(max_length=50, unique=True)
    first_book = models.ForeignKey('Book', models.CASCADE, related_name='first_time_authors')
    # Self-referential, non-symmetrical array M2M; the reverse accessor is ``favors_me``.
    favorite_authors = ArrayManyToManyField(
        'self', symmetrical=False, related_name='favors_me')

    def __str__(self):
        return self.name

    class Meta:
        ordering = ['id']
26 |
27 |
class AuthorWithAge(Author):
    """Multi-table-inheritance child of Author that adds an ``age``."""
    # Explicitly declared parent link to the Author row.
    author = models.OneToOneField(Author, models.CASCADE, parent_link=True)
    age = models.IntegerField()
31 |
32 |
@python_2_unicode_compatible
class AuthorAddress(models.Model):
    """Address attached to an Author through the author's unique ``name``."""
    # to_field='name': the FK references Author.name rather than Author's pk.
    author = models.ForeignKey(Author, models.CASCADE, to_field='name', related_name='addresses')
    address = models.TextField()

    class Meta:
        ordering = ['id']

    def __str__(self):
        return self.address
43 |
44 |
@python_2_unicode_compatible
class Book(models.Model):
    """Book whose authors are held in an ArrayManyToManyField (reverse accessor ``books``)."""
    title = models.CharField(max_length=255)
    authors = ArrayManyToManyField(Author, related_name='books')

    def __str__(self):
        return self.title

    class Meta:
        ordering = ['id']
55 |
56 |
class BookWithYear(Book):
    """Multi-table-inheritance child of Book with a year and array M2M to AuthorWithAge."""
    # Explicitly declared parent link to the Book row.
    book = models.OneToOneField(Book, models.CASCADE, parent_link=True)
    published_year = models.IntegerField()
    aged_authors = ArrayManyToManyField(
        AuthorWithAge, related_name='books_with_year')
62 |
63 |
class Bio(models.Model):
    """One-to-one companion of an Author with an optional (blank=True) array M2M to books."""
    author = models.OneToOneField(Author, models.CASCADE)
    books = ArrayManyToManyField(Book, blank=True)
67 |
68 |
@python_2_unicode_compatible
class Reader(models.Model):
    """Reader with an array M2M to the books read (reverse accessor ``read_by``)."""
    name = models.CharField(max_length=50)
    books_read = ArrayManyToManyField(Book, related_name='read_by')

    def __str__(self):
        return self.name

    class Meta:
        ordering = ['id']
79 |
80 |
class BookReview(models.Model):
    """Review attached by a plain FK to the inherited BookWithYear model."""
    book = models.ForeignKey(BookWithYear, models.CASCADE)
    # Optional free-text notes.
    notes = models.TextField(null=True, blank=True)
84 |
85 |
86 | # Models for default manager tests
87 |
class Qualification(models.Model):
    """Qualification held by teachers; part of the default-manager prefetch tests."""
    name = models.CharField(max_length=10)

    class Meta:
        ordering = ['id']
93 |
94 |
class TeacherManager(models.Manager):
    """Manager whose querysets always prefetch each teacher's ``qualifications``."""

    def get_queryset(self):
        base_queryset = super(TeacherManager, self).get_queryset()
        return base_queryset.prefetch_related('qualifications')
98 |
99 |
@python_2_unicode_compatible
class Teacher(models.Model):
    """Teacher whose default manager (TeacherManager) prefetches ``qualifications``."""
    name = models.CharField(max_length=50)
    qualifications = ArrayManyToManyField(Qualification)

    # Default manager: every queryset it returns prefetches 'qualifications'.
    objects = TeacherManager()

    def __str__(self):
        # Renders "name (qual1, qual2, ...)"; the qualifications.all() call is
        # normally served from the prefetch cache when loaded via Teacher.objects.
        return "%s (%s)" % (self.name, ", ".join(q.name for q in self.qualifications.all()))

    class Meta:
        ordering = ['id']
112 |
113 |
class Department(models.Model):
    """Department with an array M2M to Teacher, whose default manager adds its own prefetch."""
    name = models.CharField(max_length=50)
    teachers = ArrayManyToManyField(Teacher)

    class Meta:
        ordering = ['id']
120 |
121 |
122 | # GenericRelation/GenericForeignKey tests
123 |
@python_2_unicode_compatible
class TaggedItem(models.Model):
    """Tag attachable to any object via generic FKs.

    ``content_object`` is required; ``created_by`` and ``favorite`` are
    optional (their content-type and key columns are nullable).
    """
    tag = models.SlugField()
    content_type = models.ForeignKey(
        ContentType,
        models.CASCADE,
        related_name="taggeditem_set2",
    )
    object_id = models.PositiveIntegerField()
    content_object = GenericForeignKey('content_type', 'object_id')
    created_by_ct = models.ForeignKey(
        ContentType,
        models.SET_NULL,
        null=True,
        related_name='taggeditem_set3',
    )
    created_by_fkey = models.PositiveIntegerField(null=True)
    created_by = GenericForeignKey('created_by_ct', 'created_by_fkey',)
    favorite_ct = models.ForeignKey(
        ContentType,
        models.SET_NULL,
        null=True,
        related_name='taggeditem_set4',
    )
    # Character key, presumably so ``favorite`` can reference non-integer pks.
    favorite_fkey = models.CharField(max_length=64, null=True)
    favorite = GenericForeignKey('favorite_ct', 'favorite_fkey')

    def __str__(self):
        return self.tag

    class Meta:
        ordering = ['id']
156 |
157 |
class Bookmark(models.Model):
    """Bookmark with two generic reverse relations to TaggedItem."""
    url = models.URLField()
    tags = GenericRelation(TaggedItem, related_query_name='bookmarks')
    # Joins through TaggedItem's ``favorite_ct``/``favorite_fkey`` columns
    # instead of the default ``content_type``/``object_id`` pair.
    favorite_tags = GenericRelation(TaggedItem,
                                    content_type_field='favorite_ct',
                                    object_id_field='favorite_fkey',
                                    related_query_name='favorite_bookmarks')

    class Meta:
        ordering = ['id']
168 |
169 |
class Comment(models.Model):
    """Comment attached via a generic FK whose object key is stored as text."""
    comment = models.TextField()

    # Content-object field
    content_type = models.ForeignKey(ContentType, models.CASCADE)
    # Text key, presumably so the target may have a non-integer pk.
    object_pk = models.TextField()
    content_object = GenericForeignKey(ct_field="content_type", fk_field="object_pk")

    class Meta:
        ordering = ['id']
180 |
181 |
182 | # Models for lookup ordering tests
183 |
class House(models.Model):
    """House with an optional owner and an optional main room (both SET_NULL on delete)."""
    name = models.CharField(max_length=50)
    address = models.CharField(max_length=255)
    owner = models.ForeignKey('Person', models.SET_NULL, null=True)
    main_room = models.OneToOneField('Room', models.SET_NULL, related_name='main_room_of', null=True)

    class Meta:
        ordering = ['id']
192 |
193 |
class Room(models.Model):
    """Room belonging to a House (reverse accessor ``House.rooms``)."""
    name = models.CharField(max_length=50)
    house = models.ForeignKey(House, models.CASCADE, related_name='rooms')

    class Meta:
        ordering = ['id']
200 |
201 |
class Person(models.Model):
    """Person with an array M2M to houses (reverse accessor ``occupants``)."""
    name = models.CharField(max_length=50)
    houses = ArrayManyToManyField(House, related_name='occupants')

    @property
    def primary_house(self):
        """Return the house with the most rooms.

        Ties keep ``houses.all()`` order because ``sorted`` is stable. Raises
        IndexError if the person has no houses. Calls ``rooms.count()`` once
        per house.
        """
        # Assume business logic forces every person to have at least one house.
        return sorted(self.houses.all(), key=lambda house: -house.rooms.count())[0]

    @property
    def all_houses(self):
        """Return this person's houses as a list (evaluates ``houses.all()``)."""
        return list(self.houses.all())

    class Meta:
        ordering = ['id']
217 |
218 |
219 | # Models for nullable FK tests
220 |
@python_2_unicode_compatible
class Employee(models.Model):
    """Employee with an optional self-referential ``boss`` (reverse accessor ``serfs``)."""
    name = models.CharField(max_length=50)
    boss = models.ForeignKey('self', models.SET_NULL, null=True, related_name='serfs')

    def __str__(self):
        return self.name

    class Meta:
        ordering = ['id']
231 |
232 |
233 | # Ticket #19607
234 |
@python_2_unicode_compatible
class LessonEntry(models.Model):
    """Lesson entry with a two-part name (regression models for ticket #19607)."""
    name1 = models.CharField(max_length=200)
    name2 = models.CharField(max_length=200)

    def __str__(self):
        return "%s %s" % (self.name1, self.name2)
242 |
243 |
@python_2_unicode_compatible
class WordEntry(models.Model):
    """Word belonging to a LessonEntry (regression models for ticket #19607)."""
    lesson_entry = models.ForeignKey(LessonEntry, models.CASCADE)
    name = models.CharField(max_length=200)

    def __str__(self):
        # Renders "name (id)".
        return "%s (%s)" % (self.name, self.id)
251 |
252 |
253 | # Ticket #21410: Regression when related_name="+"
254 |
@python_2_unicode_compatible
class Author2(models.Model):
    """Author variant for ticket #21410: both relations to Book use a ``related_name`` ending in '+'.

    In Django, a trailing '+' means Book gets no reverse accessor for these relations.
    """
    name = models.CharField(max_length=50, unique=True)
    first_book = models.ForeignKey('Book', models.CASCADE, related_name='first_time_authors+')
    favorite_books = ArrayManyToManyField('Book', related_name='+')

    def __str__(self):
        return self.name

    class Meta:
        ordering = ['id']
266 |
267 |
268 | # Models for many-to-many with UUID pk test:
269 |
class Pet(models.Model):
    """UUID-keyed model with an array M2M to Person (reverse accessor ``pets``)."""
    # Non-integer primary key, generated client-side with uuid4.
    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
    name = models.CharField(max_length=20)
    people = ArrayManyToManyField(Person, related_name='pets')
274 |
275 |
class Flea(models.Model):
    """UUID-keyed model with an optional current room and array M2Ms to Pet and Person."""
    # Non-integer primary key, generated client-side with uuid4.
    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
    current_room = models.ForeignKey(Room, models.SET_NULL, related_name='fleas', null=True)
    # Both reverse accessors are named 'fleas_hosted', but on different models (Pet, Person).
    pets_visited = ArrayManyToManyField(Pet, related_name='fleas_hosted')
    people_visited = ArrayManyToManyField(Person, related_name='fleas_hosted')
281 |
--------------------------------------------------------------------------------
/tests/modeladmin/tests.py:
--------------------------------------------------------------------------------
1 | from __future__ import unicode_literals
2 |
3 | from django.test import override_settings
4 | from django.contrib.auth import get_user_model
5 | from django.contrib.admin.tests import AdminSeleniumTestCase
6 | from selenium.common.exceptions import NoSuchWindowException
7 | from .models import Product, Buyer
8 | import time
9 | import ast
10 |
# Apps enabled for the Selenium admin tests below. The same list is passed
# to INSTALLED_APPS and used as ``available_apps``; order is significant
# for Django's app registry, so keep it as is.
installed_apps = [
    # Django contrib apps needed by the admin.
    'django.contrib.admin',
    'django.contrib.contenttypes',
    'django.contrib.auth',
    'django.contrib.sites',
    'django.contrib.sessions',
    'django.contrib.messages',
    'django.contrib.staticfiles',
    # PostgreSQL support and the package under test.
    'django.contrib.postgres',
    'django_postgres_extensions',
    # The test app providing the Product / Buyer admin.
    'modeladmin',
]
23 |
@override_settings(ROOT_URLCONF='modeladmin.urls', INSTALLED_APPS=installed_apps, DEBUG=True)
class PostgresAdminTestCase(AdminSeleniumTestCase):
    """Selenium tests for the Product admin: changelist, add, change, delete.

    The Product fixtures combine array fields (``keywords``, ``sports``),
    dict-valued fields (``shipping``, ``details``) and an array-backed
    many-to-many relation (``buyers``).
    """

    # Set to True to keep the browser open after a test until its window is
    # closed by hand (see tearDown).
    close_manually = False

    browsers = ['chrome']

    available_apps = installed_apps

    username = 'super'
    password = 'secret'
    email = 'super@example.com'
    app_name = 'modeladmin'
    model_name = 'product'

    @classmethod
    def setUpClass(cls):
        super(PostgresAdminTestCase, cls).setUpClass()
        cls.selenium.maximize_window()
        # Base URL of the Product admin: <live server>/admin/modeladmin/product
        cls.url = '%s/%s/%s/%s' % (cls.live_server_url, 'admin', cls.app_name, cls.model_name)

    def tearDown(self):
        # Pause so the final page state is visible; with close_manually set,
        # keep polling until the browser window has been closed by hand.
        time.sleep(3)
        while self.close_manually:
            try:
                self.wait_page_loaded()
            except NoSuchWindowException:
                self.close_manually = False
        super(PostgresAdminTestCase, self).tearDown()

    def setUp(self):
        super(PostgresAdminTestCase, self).setUp()
        User = get_user_model()
        if not User.objects.filter(username=self.username).exists():
            User.objects.create_superuser(username=self.username, password=self.password, email=self.email)
        self.admin_login(self.username, self.password)

    @staticmethod
    def _create_buyers():
        """Create, save and return three buyers, in creation order."""
        buyers = []
        for name in ('Muhammed Ali', 'Conor McGregor', 'Floyd Mayweather'):
            buyer = Buyer(name=name)
            buyer.save()
            buyers.append(buyer)
        return buyers

    @staticmethod
    def _capitalised_details():
        """Return a fresh ``details`` dict using capitalised keys."""
        return {
            'Brand': {
                'Name': 'Adidas',
                'Country': 'Germany',
            },
            'Type': 'runners',
            'Colours': ['black', 'white', 'blue'],
        }

    @staticmethod
    def _create_product(details):
        """Create and save the 'Pro Trainers' fixture with the given ``details``."""
        prod = Product(name='Pro Trainers', keywords=['fun', 'popular'],
                       sports=['tennis', 'basketball'],
                       shipping={
                           'Address': 'Pearse Street',
                           'City': 'Dublin',
                           'Region': 'Co. Dublin',
                           'Country': 'Ireland',
                       },
                       details=details)
        prod.save()
        return prod

    def test_list(self):
        self._create_product({
            'brand': {
                'name': 'Adidas',
                'country': 'Germany',
            },
            'type': 'runners',
            'colours': ['black', 'white', 'blue'],
        })
        self.selenium.get(self.url)
        self.wait_page_loaded()
        element = self.selenium.find_elements_by_class_name('field-name')[0]
        self.assertEqual(element.text, "Pro Trainers")
        element = self.selenium.find_elements_by_class_name('field-keywords')[0]
        self.assertEqual(element.text, 'fun, popular')
        element = self.selenium.find_elements_by_class_name('field-shipping')[0]
        self.assertDictEqual(ast.literal_eval(element.text), {'City': 'Dublin', 'Region': 'Co. Dublin', 'Country': 'Ireland', 'Address': 'Pearse Street'})
        element = self.selenium.find_elements_by_class_name('field-details')[0]
        self.assertDictEqual(ast.literal_eval(element.text), {'colours': ['black', 'white', 'blue'], 'brand': {'country': 'Germany', 'name': 'Adidas'}, 'type': 'runners'})
        element = self.selenium.find_elements_by_class_name('field-country')[0]
        self.assertEqual(element.text, "Ireland")

    def fill_form(self, ids_values, select_values, many_to_many_select, replace=False):
        """Fill in the Product admin form and submit it with "Save".

        :param ids_values: ``(element_id, text)`` pairs typed into inputs.
        :param select_values: ``(field_name, values)`` pairs; each value is
            clicked in the ``#id_<field_name>`` select.
        :param many_to_many_select: ``(field_name, pks)`` pairs; each pk is
            clicked in ``#id_<field_name>_from`` and moved across with the
            ``id_<field_name>_add_link`` link.
        :param replace: clear each input before typing into it.
        """
        for element_id, value in ids_values:
            element = self.selenium.find_element_by_id(element_id)
            if replace:
                element.clear()
            element.send_keys(value)
        for field_name, values in select_values:
            selection_box = "#id_%s" % field_name
            for value in values:
                self.get_select_option(selection_box, value).click()
        for field_name, values in many_to_many_select:
            from_box = '#id_%s_from' % field_name
            choose_elem = self.selenium.find_element_by_id('id_%s_add_link' % field_name)
            for value in values:
                self.get_select_option(from_box, str(value)).click()
                choose_elem.click()
        self.selenium.find_element_by_xpath('//input[@value="Save"]').click()
        self.wait_page_loaded()

    def test_add(self):
        buyer1, _buyer2, buyer3 = self._create_buyers()
        self.selenium.get('%s/%s' % (self.url, 'add'))
        self.wait_page_loaded()
        ids_values = (("id_name", "Pro Trainers"), ("id_keywords_0", "fun"), ("id_keywords_1", "popular"),
                      ("id_shipping_address", "Pearse Street"), ("id_shipping_city", "Dublin"), ("id_shipping_region", "Co.Dublin"),
                      ("id_shipping_country", "Ireland"), ("id_details_brand_name", "Adidas"), ("id_details_brand_country", "Germany"),
                      ("id_details_type", "Runners"), ("id_details_colours_0", "Black"), ("id_details_colours_1", "White"),
                      ("id_details_colours_2", "Blue"))
        select_values = (("sports", ("tennis", "basketball")),)
        many_to_many_select = (("buyers", (buyer1.pk, buyer3.pk)),)
        self.fill_form(ids_values, select_values, many_to_many_select)
        obj = Product.objects.get()
        self.assertEqual(obj.name, 'Pro Trainers')
        self.assertDictEqual(obj.shipping, {'City': 'Dublin', 'Region': 'Co.Dublin', 'Country': 'Ireland',
                                            'Address': 'Pearse Street'})
        self.assertListEqual(obj.sports, ['tennis', 'basketball'])
        self.assertListEqual(obj.details['Colours'], ['Black', 'White', 'Blue', '', '', '', '', '', '', ''])
        self.assertDictEqual(obj.details['Brand'], {'Country': 'Germany', 'Name': 'Adidas'})
        self.assertEqual(obj.details['Type'], 'Runners')
        self.assertListEqual(obj.keywords, ['fun', 'popular'])
        buyers = obj.buyers.all().order_by('id')
        # Compare primary keys; the default repr() transform never matches
        # placeholder strings.
        self.assertQuerysetEqual(buyers, [buyer1.pk, buyer3.pk], transform=lambda buyer: buyer.pk)

    def test_update(self):
        buyer1, buyer2, buyer3 = self._create_buyers()
        prod = self._create_product(self._capitalised_details())
        prod.buyers.add(buyer1, buyer3)
        self.selenium.get('%s/%s/%s' % (self.url, prod.pk, 'change'))
        self.wait_page_loaded()
        ids_values = (("id_keywords_1", "not popular"),
                      ("id_shipping_address", "Nassau Street"), ("id_details_brand_name", "Nike"), ("id_details_brand_country", "USA"),
                      ("id_details_colours_3", "Red"))
        select_values = (("sports", ("football",)),)
        many_to_many_select = (("buyers", (buyer2.pk,)),)
        self.fill_form(ids_values, select_values, many_to_many_select, replace=True)
        obj = Product.objects.get()
        buyers = obj.buyers.all().order_by('id')
        self.assertQuerysetEqual(buyers, [buyer1.pk, buyer2.pk, buyer3.pk], transform=lambda buyer: buyer.pk)
        self.assertEqual(obj.name, 'Pro Trainers')
        self.assertDictEqual(obj.shipping, {'City': 'Dublin', 'Region': 'Co. Dublin', 'Country': 'Ireland',
                                            'Address': 'Nassau Street'})
        self.assertListEqual(obj.sports, ['football', 'tennis', 'basketball'])
        self.assertListEqual(obj.details['Colours'], ['black', 'white', 'blue', 'Red', '', '', '', '', '', ''])
        self.assertDictEqual(obj.details['Brand'], {'Country': 'USA', 'Name': 'Nike'})
        self.assertEqual(obj.details['Type'], 'runners')
        self.assertListEqual(obj.keywords, ['fun', 'not popular'])

    def test_delete(self):
        buyer1, _buyer2, buyer3 = self._create_buyers()
        prod = self._create_product(self._capitalised_details())
        prod.buyers.add(buyer1, buyer3)
        self.selenium.get('%s/%s/%s' % (self.url, prod.pk, 'delete'))
        self.wait_page_loaded()
        self.selenium.find_element_by_xpath('//input[@value="Yes, I\'m sure"]').click()
        # As in fill_form, wait for the resulting page before checking the DB.
        self.wait_page_loaded()
        self.assertEqual(Product.objects.count(), 0)
--------------------------------------------------------------------------------
/django_postgres_extensions/models/fields/related_descriptors.py:
--------------------------------------------------------------------------------
1 | from django.db.models import signals
2 | from django.db import transaction, router
3 | from django_postgres_extensions.utils import OrderedSet
4 | from django_postgres_extensions.models.functions import ArrayCat, ArrayRemove, multi_array_remove
5 | from django.utils.functional import cached_property
6 |
class MultiReferenceDescriptor(object):
    """Descriptor exposing an array-backed many-to-many manager.

    Accessed on the model class it returns itself; accessed on an instance
    it returns a related manager bound to that instance.
    """

    def __init__(self, rel, reverse=False, isJson=False):
        self.rel = rel
        self.reverse = reverse
        self.isJson = isJson
        self.through = rel.related_model

    def __get__(self, instance, cls=None):
        """Return this descriptor for class access, else a bound manager."""
        if instance is not None:
            return self.related_manager_cls(instance)
        return self

    @cached_property
    def related_manager_cls(self):
        """Build (once per descriptor) the manager class for this relation.

        The reverse side manages ``rel.related_model``, the forward side
        ``rel.model``; the result subclasses that model's default manager
        class.
        """
        if self.reverse:
            model = self.rel.related_model
        else:
            model = self.rel.model
        db = router.db_for_read(model)
        # Prefer a factory supplied by the resolved database object, falling
        # back to this module's implementation.
        # NOTE(review): db_for_read normally returns an alias string, so the
        # fallback is what is used in practice — confirm whether this hook is
        # still wanted.
        factory = getattr(db, 'create_array_many_to_many_manager',
                          create_array_many_to_many_manager)
        return factory(model._default_manager.__class__, self.rel,
                       self.reverse, self.isJson)
38 |
def create_array_many_to_many_manager(superclass, rel, reverse, IsJson):
    """Create a related-manager class for an array-backed many-to-many field.

    Instead of a join table, the relation is stored as an array of related
    keys in a column (``rel.field``) of one model.

    The forward manager edits that array on the owning instance. The reverse
    manager edits the arrays of the rows that reference the instance.

    :param superclass: manager class to subclass.
    :param rel: the relation object describing the array many-to-many field.
    :param reverse: ``True`` to build the reverse-side manager. It is also
        the value reported as ``reverse`` in ``m2m_changed`` signals.
    :param IsJson: passed on when building managers via ``__call__``; not
        otherwise used here.
    :returns: ``ArrayReverseManyToManyManager`` if ``reverse`` else
        ``ArrayForwardManyToManyManager``.
    """

    class ArrayForwardManyToManyManager(superclass):
        """Manager for the side of the relation that owns the array column."""

        def __init__(self, instance):
            if instance.pk is None:
                raise ValueError("%r instance needs to have a primary key value before "
                                 "a many-to-many relationship can be used." %
                                 instance.__class__.__name__)
            super(ArrayForwardManyToManyManager, self).__init__()
            self.instance = instance
            self.field = rel.field
            self.target_field = rel.target_field
            self.fieldname = self.field.name
            self.column = self.field.attname
            self.rel = rel
            self.set_attributes()
            self.through = self.rel.related_model

        def __call__(self, **kwargs):
            # We use **kwargs rather than a kwarg argument to enforce the
            # `manager='manager_name'` syntax.
            manager = getattr(self.model, kwargs.pop('manager'))
            if hasattr(self.db, 'create_array_many_to_many_manager'):
                create_manager_func = self.db.create_array_many_to_many_manager
            else:
                create_manager_func = create_array_many_to_many_manager
            manager_class = create_manager_func(manager.__class__, rel, reverse, IsJson)
            return manager_class(instance=self.instance)
        do_not_call_in_templates = True

        def set_attributes(self):
            """Set the model/filter attributes used by the forward side."""
            self.model = self.rel.model
            self.related_model = self.rel.related_model
            self.prefetch_cache_name = rel.field.name
            self.to_field_name = self.target_field.name
            self.core_filters = {'%s' % self.rel.name: self.instance}
            self.symmetrical = self.rel.symmetrical

        def _apply_rel_filters(self, queryset):
            """
            Filter the queryset for the instance this manager is bound to.
            """
            queryset._add_hints(instance=self.instance)
            if self._db:
                queryset = queryset.using(self._db)
            queryset = queryset.filter(**self.core_filters)
            return queryset

        def get_queryset(self):
            # Serve prefetched results when prefetch_related() populated them.
            try:
                return self.instance._prefetched_objects_cache[self.prefetch_cache_name]
            except (AttributeError, KeyError):
                queryset = super(ArrayForwardManyToManyManager, self).get_queryset()
                return self._apply_rel_filters(queryset)

        def get_prefetch_filters(self, instances):
            """Build the filter matching every key referenced by ``instances``."""
            pks = []
            for instance in instances:
                # Tolerate NULL arrays, which would otherwise raise TypeError.
                pks.extend(getattr(instance, self.column) or ())
            return {'%s__in' % self.to_field_name: set(pks)}

        def validate_rel_obj(self, rel_obj, fks):
            return getattr(rel_obj, self.to_field_name) in fks

        def get_instance_attr(self, instance):
            return getattr(instance, self.column)

        def get_prefetch_queryset(self, instances, queryset=None):
            if queryset is None:
                queryset = super(ArrayForwardManyToManyManager, self).get_queryset()

            queryset._add_hints(instance=instances[0])
            queryset = queryset.using(queryset._db or self._db)

            query = self.get_prefetch_filters(instances)
            queryset = queryset.filter(**query)
            # Flag read elsewhere in this package to handle array-valued keys.
            queryset.is_multi_reference = True

            return queryset, self.validate_rel_obj, self.get_instance_attr, False, self.prefetch_cache_name, True

        def _update_instance(self, **kwargs):
            """Apply ``update(**kwargs)`` to the row of the bound instance."""
            qs = self.related_model.objects.filter(pk=self.instance.pk)
            qs.update(**kwargs)
        _update_instance.alters_data = True

        def _add_items(self, *objs):
            """Append the given keys to the instance's array, skipping existing ones."""
            objs = list(objs)
            if not objs:
                return
            if len(objs) == 1:
                exclude = {self.column: objs[0]}
                updates = {self.column: ArrayCat(self.column, objs, output_field=self.field)}
                self.related_model.objects.filter(pk=self.instance.pk).exclude(**exclude).update(**updates)
            else:
                instance = self.related_model.objects.only(self.column).get(pk=self.instance.pk)
                objs = list(OrderedSet(objs) - set(getattr(instance, self.column)))
                if not objs:
                    # Every key is already present; nothing to concatenate.
                    return
                updates = {self.column: ArrayCat(self.column, objs, output_field=self.field)}
                self._update_instance(**updates)
            # If this is a symmetrical m2m relation to self, add the mirror entry to the other objs' arrays.
            if self.symmetrical:
                updates = {self.column: ArrayCat(self.column, [self.instance.pk], output_field=self.field)}
                self.model.objects.filter(pk__in=objs).update(**updates)
        _add_items.alters_data = True

        def validate_item(self, obj):
            return self.field.validate_item(obj, model=self.model)

        def add(self, *objs, **kwargs):
            objs = [self.validate_item(obj) for obj in objs]
            # `reverse` is the factory argument: the manager has no boolean
            # `reverse` attribute (self.reverse is the QuerySet.reverse proxy).
            signals.m2m_changed.send(
                sender=self.through, action='pre_add',
                instance=self.instance, reverse=reverse,
                model=self.model, pk_set=objs, using=self.db,
            )
            with transaction.atomic():
                self._add_items(*objs)
            signals.m2m_changed.send(
                sender=self.through, action='post_add',
                instance=self.instance, reverse=reverse,
                model=self.model, pk_set=objs, using=self.db,
            )
        add.alters_data = True

        def remove(self, *objs):
            objs = [self.validate_item(obj) for obj in objs]
            signals.m2m_changed.send(
                sender=self.through, action="pre_remove",
                instance=self.instance, reverse=reverse,
                model=self.model, pk_set=objs, using=self.db,
            )
            with transaction.atomic():
                self._remove_items(*objs)
            signals.m2m_changed.send(
                sender=self.through, action="post_remove",
                instance=self.instance, reverse=reverse,
                model=self.model, pk_set=objs, using=self.db,
            )
        remove.alters_data = True

        def _remove_items(self, *objs, **kwargs):
            """Remove the given keys from the instance's array."""
            if not objs:
                return
            # Removal is issued in chunks of 100 keys, presumably to bound
            # the size of each generated expression.
            for start in range(0, len(objs), 100):
                chunk = objs[start:start + 100]
                updates = {self.column: multi_array_remove(self.column, *chunk)}
                self._update_instance(**updates)
            # If this is a symmetrical m2m relation to self, also remove the mirror entry from the other objs' arrays.
            if self.symmetrical:
                updates = {self.column: ArrayRemove(self.column, self.instance.pk, output_field=self.field)}
                self.model.objects.filter(pk__in=list(objs)).update(**updates)
        _remove_items.alters_data = True

        def create(self, **kwargs):
            new_obj = super(ArrayForwardManyToManyManager, self).create(**kwargs)
            self.add(new_obj)
            return new_obj
        create.alters_data = True

        def get_or_create(self, **kwargs):
            obj, created = super(ArrayForwardManyToManyManager, self).get_or_create(**kwargs)
            # We only need to add() if created because if we got an object back
            # from get() then the relationship already exists.
            if created:
                self.add(obj)
            return obj, created
        get_or_create.alters_data = True

        def update_or_create(self, **kwargs):
            obj, created = super(ArrayForwardManyToManyManager, self).update_or_create(**kwargs)
            # We only need to add() if created because if we got an object back
            # from get() then the relationship already exists.
            if created:
                self.add(obj)
            return obj, created
        update_or_create.alters_data = True

        def _clear(self):
            """Empty the instance's array (and its mirror entries if symmetrical)."""
            self._update_instance(**{self.column: []})
            if self.symmetrical:
                updates = {self.column: ArrayRemove(self.column, self.instance.pk, output_field=self.field)}
                self.model.objects.update(**updates)
        _clear.alters_data = True

        def clear(self, **kwargs):
            with transaction.atomic():
                signals.m2m_changed.send(
                    sender=self.through, action="pre_clear",
                    instance=self.instance, reverse=reverse,
                    model=self.model, pk_set=None, using=self.db,
                )
                with transaction.atomic():
                    self._clear()
                # Same sender as every other m2m_changed signal sent here.
                signals.m2m_changed.send(
                    sender=self.through, action="post_clear",
                    instance=self.instance, reverse=reverse,
                    model=self.model, pk_set=None, using=self.db,
                )
        clear.alters_data = True

        def set(self, objs, **kwargs):
            """Make the related set exactly ``objs`` (instances or key values)."""
            with transaction.atomic(savepoint=False):
                old_ids = set(self.values_list(self.to_field_name, flat=True))
                new_objs = []
                for obj in objs:
                    fk_val = (obj.pk if isinstance(obj, self.model) else obj)
                    if fk_val in old_ids:
                        old_ids.remove(fk_val)
                    else:
                        new_objs.append(obj)
                self.remove(*old_ids)
                self.add(*new_objs)
        set.alters_data = True

    class ArrayReverseManyToManyManager(ArrayForwardManyToManyManager):
        """Manager for the side referenced by other rows' array columns."""

        def set_attributes(self):
            """Set the model/filter attributes used by the reverse side."""
            self.model = rel.related_model
            self.related_model = rel.model
            self.prefetch_cache_name = rel.field.related_query_name()
            self.core_filters = {'%s' % self.fieldname: self.instance}
            self.prefetch_filters = {}
            self.to_field_name = 'pk'
            self.to_field_value = self.instance.pk
            self.symmetrical = False

        def validate_rel_obj(self, rel_obj, pk):
            return pk in getattr(rel_obj, self.column)

        def get_instance_attr(self, instance):
            return getattr(instance, self.to_field_name)

        def get_prefetch_filters(self, instances):
            filters = {'%s__overlap' % self.fieldname: instances}
            return filters

        def _add_items(self, *objs, **kwargs):
            """Append the bound instance's key to each given row's array."""
            exclude = {self.column: self.instance.pk}
            qs = self.model.objects.filter(pk__in=objs).exclude(**exclude)
            updates = {self.column: ArrayCat(self.column, [self.to_field_value])}
            qs.update(**updates)
        _add_items.alters_data = True

        def _remove_items(self, *objs, **kwargs):
            """Remove the bound instance's key from each given row's array."""
            qs = self.filter(pk__in=objs)
            updates = {self.column: ArrayRemove(self.column, self.to_field_value)}
            with transaction.atomic():
                qs.update(**updates)
        _remove_items.alters_data = True

        def _clear(self):
            """Remove the bound instance's key from every row's array."""
            with transaction.atomic():
                updates = {self.column: ArrayRemove(self.column, self.to_field_value)}
                self.model.objects.update(**updates)
        _clear.alters_data = True

    if reverse:
        return ArrayReverseManyToManyManager
    return ArrayForwardManyToManyManager
--------------------------------------------------------------------------------
/tests/arrays/tests.py:
--------------------------------------------------------------------------------
1 | from __future__ import unicode_literals
2 |
3 | from django.test import TestCase
4 | from .models import Product
5 | from django_postgres_extensions.models.functions import *
6 | from django_postgres_extensions.models.expressions import F, Value as V, Index, SliceArray
7 | from django.db.utils import ProgrammingError, DataError
8 | from django.db import transaction
9 | from unittest import skip
10 |
class ArrayCharsIndexTests(TestCase):
    """Indexing, slicing and indexed updates on the ``tags`` character array."""

    def setUp(self):
        super(ArrayCharsIndexTests, self).setUp()
        self.product = Product(name='xyz', tags=['Music', 'Album', 'Rock'], moretags=['Very Popular'])
        self.product.save()
        # Scope every query to the product created here. The unscoped
        # ``Product.objects...get()`` previously run here was unused and
        # would raise MultipleObjectsReturned if other rows existed.
        self.queryset = Product.objects.filter(pk=self.product.pk)

    def tearDown(self):
        self.queryset.delete()

    def test_array_values(self):
        product = self.queryset.get()
        array_values = product.tags
        self.assertListEqual(array_values, ['Music', 'Album', 'Rock'])

    def test_array_index(self):
        with transaction.atomic():
            obj = self.queryset.annotate(Index('tags', 1)).get()
        self.assertEqual(obj.tags__1, 'Album')

    def test_array_index_with_output_field(self):
        with transaction.atomic():
            obj = self.queryset.annotate(tag_1=Index('tags', 1)).get()
        self.assertEqual(obj.tag_1, 'Album')

    def test_array_index_slice(self):
        with transaction.atomic():
            obj = self.queryset.annotate(SliceArray('tags', 0, 1)).get()
        self.assertEqual(obj.tags__0_1, ['Music', 'Album'])

    def test_array_update_index(self):
        with transaction.atomic():
            self.queryset.update(tags__2='Heavy Metal')
        product = self.queryset.get()
        self.assertListEqual(product.tags, ['Music', 'Album', 'Heavy Metal'])

    def test_array_int_converted(self):
        # An int written into a character array comes back as a string.
        with transaction.atomic():
            self.queryset.update(tags__2=1)
        product = self.queryset.get()
        self.assertListEqual(product.tags, ['Music', 'Album', '1'])
54 |
class ArrayCharsFuncTests(TestCase):
    """Postgres array functions applied to the ``tags`` character array.

    Covers length, append, prepend, remove, cat, replace, position and
    positions.
    """

    def setUp(self):
        super(ArrayCharsFuncTests, self).setUp()
        self.product = Product(
            name='xyz',
            tags=['Music', 'Album', 'Rock'],
            moretags=['Very Popular'],
        )
        self.product.save()
        self.queryset = Product.objects.filter(id=self.product.id)

    def tearDown(self):
        self.queryset.delete()

    def _reload(self):
        """Fetch the product under test again from the database."""
        return self.queryset.get()

    def test_array_length(self):
        with transaction.atomic():
            annotated = self.queryset.annotate(tags_length=ArrayLength('tags', 1)).get()
        self.assertEqual(annotated.tags_length, 3)

    def test_array_append(self):
        expected = ['Music', 'Album', 'Rock', 'Popular']
        with transaction.atomic():
            annotated = self.queryset.annotate(tags_appended=ArrayAppend('tags', 'Popular')).get()
        self.assertListEqual(annotated.tags_appended, expected)
        with transaction.atomic():
            self.queryset.update(tags=ArrayAppend('tags', 'Popular'))
        self.assertListEqual(self._reload().tags, expected)

    def test_array_char_raises(self):
        with transaction.atomic():
            with self.assertRaises(ProgrammingError):
                self.queryset.update(tags=ArrayAppend('tags', 1))

    def test_array_prepend(self):
        with transaction.atomic():
            self.queryset.update(tags=ArrayPrepend('Popular', 'tags'))
        self.assertListEqual(self._reload().tags, ['Popular', 'Music', 'Album', 'Rock'])

    def test_array_remove(self):
        with transaction.atomic():
            self.queryset.update(tags=ArrayRemove('tags', 'Album'))
        self.assertListEqual(self._reload().tags, ['Music', 'Rock'])

    def test_array_cat(self):
        with transaction.atomic():
            self.queryset.update(tags=ArrayCat('tags', 'moretags'))
        self.assertListEqual(self._reload().tags, ['Music', 'Album', 'Rock', 'Very Popular'])

    def test_array_cat_list(self):
        tags_field = Product._meta.get_field('tags')
        with transaction.atomic():
            self.queryset.update(tags=ArrayCat('tags', ['Popular', '8'], output_field=tags_field))
        self.assertListEqual(self._reload().tags, ['Music', 'Album', 'Rock', 'Popular', '8'])

    def test_array_replace(self):
        with transaction.atomic():
            self.queryset.update(tags=ArrayReplace('tags', 'Rock', 'Heavy Metal'))
        self.assertListEqual(self._reload().tags, ['Music', 'Album', 'Heavy Metal'])

    def test_array_position(self):
        with transaction.atomic():
            annotated = self.queryset.annotate(position=ArrayPosition('tags', 'Rock')).get()
        self.assertEqual(annotated.position, 3)

    def test_array_positions(self):
        with transaction.atomic():
            self.queryset.update(tags=ArrayPrepend('Rock', 'tags'))
        with transaction.atomic():
            annotated = self.queryset.annotate(positions=ArrayPositions('tags', 'Rock')).get()
        self.assertEqual(annotated.positions, [1, 4])
124 |
class ArrayCharsCatTests(TestCase):
    """Array concatenation through ``F(...).cat(...)`` on character arrays."""

    def setUp(self):
        super(ArrayCharsCatTests, self).setUp()
        self.product = Product(name='xyz', tags=['Music', 'Album', 'Rock'], moretags=['Very Popular'])
        self.product.save()
        # (The former ``prod1_queryset`` re-applied this same filter and was
        # never used.)
        self.queryset = Product.objects.filter(id=self.product.id)

    def tearDown(self):
        self.queryset.delete()

    def test_array_cat_append(self):
        tags_field = Product._meta.get_field('tags')
        with transaction.atomic():
            self.queryset.update(tags=F('tags').cat(V(['Popular'], output_field=tags_field)))
        product = self.queryset.get()
        self.assertListEqual(product.tags, ['Music', 'Album', 'Rock', 'Popular'])

    def test_array_cat_arrays(self):
        with transaction.atomic():
            self.queryset.update(tags=F('tags').cat(F('moretags')))
        product = self.queryset.get()
        self.assertListEqual(product.tags, ['Music', 'Album', 'Rock', 'Very Popular'])

    def test_array_char_raises(self):
        # Concatenating an integer onto a character array is expected to fail.
        with transaction.atomic():
            self.assertRaises(ProgrammingError, self.queryset.update, tags=F('tags').cat(V(1)))
151 |
class ArrayIntTests(TestCase):
    """Indexing, indexed updates and concatenation on the ``prices`` array."""

    def setUp(self):
        super(ArrayIntTests, self).setUp()
        self.product = Product(name='xyz', prices=[0, 1, 2])
        self.product.save()
        self.queryset = Product.objects.filter(id=self.product.id)

    def _assert_prices(self, expected):
        """Reload the product and compare its ``prices`` with ``expected``."""
        self.assertListEqual(self.queryset.get().prices, expected)

    def test_array_int_index(self):
        with transaction.atomic():
            annotated = self.queryset.annotate(Index('prices', 1)).get()
        self.assertEqual(annotated.prices__1, 1)

    def test_array_int_update_index(self):
        with transaction.atomic():
            self.queryset.update(prices__2=3)
        self._assert_prices([0, 1, 3])

    def test_array_int_append(self):
        with transaction.atomic():
            self.queryset.update(prices=ArrayAppend('prices', 3))
        self._assert_prices([0, 1, 2, 3])

    def test_array_int_cat_append(self):
        with transaction.atomic():
            self.queryset.update(prices=F('prices').cat(V(3)))
        self._assert_prices([0, 1, 2, 3])

    def test_array_int_cat_append_list(self):
        with transaction.atomic():
            self.queryset.update(prices=F('prices').cat(V([3, 4])))
        self._assert_prices([0, 1, 2, 3, 4])

    def test_array_int_cat_prepend(self):
        with transaction.atomic():
            self.queryset.update(prices=V(-1).cat(F('prices')))
        self._assert_prices([-1, 0, 1, 2])

    def test_array_int_double_cat(self):
        with transaction.atomic():
            self.queryset.update(prices=V(-1).cat(F('prices').cat(V(3))))
        self._assert_prices([-1, 0, 1, 2, 3])

    def test_array_int_raises(self):
        with self.assertRaises(DataError):
            self.queryset.update(prices=ArrayAppend('prices', 'test'))
202 |
203 | class ArrayMultiDimensionalTests(TestCase):
204 | def setUp(self):
205 | super(ArrayMultiDimensionalTests, self).setUp()
206 | self.product = Product(name='xyz', coordinates=[[0,15, 25], [15,30, 40], [45, 60, 90]])
207 | self.product.save()
208 | self.queryset = Product.objects.filter(id=self.product.id)
209 |
210 | def tearDown(self):
211 | self.queryset.delete()
212 |
213 | def test_2d_array_values(self):
214 | product = self.queryset.get()
215 | array_values = product.coordinates
216 | self.assertListEqual(array_values, [[0,15, 25], [15, 30, 40], [45, 60, 90]])
217 |
218 | def test_2d_array_dimensions(self):
219 | with transaction.atomic():
220 | obj = self.queryset.annotate(coordinates_dims = ArrayDims('coordinates')).get()
221 | self.assertEqual(obj.coordinates_dims, '[1:3][1:3]')
222 |
223 | def test_2d_array_upper(self):
224 | with transaction.atomic():
225 | obj = self.queryset.annotate(coordinates_upper_1 = ArrayUpper('coordinates', 1)).get()
226 | self.assertEqual(obj.coordinates_upper_1, 3)
227 |
228 | def test_2d_array_lower(self):
229 | with transaction.atomic():
230 | obj = self.queryset.annotate(coordinates_lower_2 = ArrayLower('coordinates', 2)).get()
231 | self.assertEqual(obj.coordinates_lower_2, 1)
232 |
233 | def test_2d_array_length(self):
234 | with transaction.atomic():
235 | obj = self.queryset.annotate(coordinates_length_2=ArrayLength('coordinates', 2)).get()
236 | self.assertEqual(obj.coordinates_length_2, 3)
237 |
238 | def test_2d_array_cardinality(self):
239 | with transaction.atomic():
240 | obj = self.queryset.annotate(coordinates_cardinality=Cardinality('coordinates')).get()
241 | self.assertEqual(obj.coordinates_cardinality, 9)
242 |
243 | def test_2d_array_index(self):
244 | with transaction.atomic():
245 | obj = self.queryset.annotate(Index(Index('coordinates', 2), 1)).get()
246 | self.assertEqual(obj.coordinates__2__1, 60)
247 |
248 | def test_2d_array_index_slice_1(self):
249 | with transaction.atomic():
250 | obj = self.queryset.annotate(SliceArray(SliceArray('coordinates', 0, 2), 0, 0)).get()
251 | self.assertEqual(obj.coordinates__0_2__0_0, [[0], [15], [45]])
252 | with transaction.atomic():
253 | obj = self.queryset.annotate(SliceArray(SliceArray('coordinates', 0, 2), 1, 1)).get()
254 | self.assertEqual(obj.coordinates__0_2__1_1, [[15], [30], [60]])
255 |
256 | def test_2d_array_index_slice_2(self):
257 | with transaction.atomic():
258 | obj = self.queryset.annotate(SliceArray(SliceArray('coordinates', 1, 1), 1, 2)).get()
259 | self.assertEqual(obj.coordinates__1_1__1_2, [[30, 40]])
260 |
261 | def test_2d_array_index_slice_3(self):
262 | with transaction.atomic():
263 | obj = self.queryset.annotate(SliceArray(SliceArray('coordinates', 1, 2), 1, 2)).get()
264 | self.assertEqual(obj.coordinates__1_2__1_2, [[30, 40], [60, 90]])
265 |
266 | @skip("Update seems to run but change is not made '{'")
267 | def test_2d_array_update_index_1(self):
268 | with transaction.atomic():
269 | self.queryset.update(tags__1=[70, 90, 100])
270 | product = self.queryset.get()
271 | self.assertListEqual(product.coordinates, [[0,15, 25], [70, 90, 100], [45, 60, 90]])
272 |
273 | @skip("Update seems to run but unable to retrieve object due to DataError: array does not start with '{'")
274 | def test_2d_array_update_index_2(self):
275 | with transaction.atomic():
276 | self.queryset.update(tags__1__2=35)
277 | product = self.queryset.get()
278 | self.assertListEqual(product.coordinates, [[0,15, 25], [15, 35, 40], [45, 60, 90]])
279 |
@skip("Fails with 'No function matches the given name and argument types. You might need to add explicit type casts.'")
def test_2d_array_append_1(self):
    """Append a whole row to the 2D ``coordinates`` array, as an annotation and as an update."""
    # NOTE(review): PostgreSQL's array_append adds a single element; joining
    # a row onto a 2D array needs array_cat / ``||``, which likely explains
    # the skip reason (ArrayAppend's SQL is not visible here -- confirm).
    # Appending must leave the existing rows untouched; the index/slice tests
    # in this class read row 1 as [15, 30, 40].
    expected = [[0, 15, 25], [15, 30, 40], [45, 60, 90], [70, 90, 100]]
    with transaction.atomic():
        obj = self.queryset.annotate(coordinates_appended=ArrayAppend('coordinates', [70, 90, 100])).get()
        self.assertListEqual(obj.coordinates_appended, expected)
    with transaction.atomic():
        self.queryset.update(coordinates=ArrayAppend('coordinates', [70, 90, 100]))
        product = self.queryset.get()
        self.assertListEqual(product.coordinates, expected)
289 |
@skip(
    "Fails with 'No function matches the given name and argument types. "
    "You might need to add explicit type casts.'")
def test_2d_array_append_2(self):
    """Append an element to a single row of the 2D ``coordinates`` array."""
    # NOTE(review): as written, the update assigns the appended row to the
    # whole ``coordinates`` column rather than to row 1. In addition,
    # PostgreSQL multidimensional arrays must have matching extents in every
    # dimension, so the ragged value asserted below cannot be stored. Both
    # need rethinking before the skip is removed.
    with transaction.atomic():
        # update() returns the number of matched rows, not an instance, so
        # re-fetch the object before inspecting the array.
        self.queryset.update(coordinates=ArrayAppend(Index('coordinates', 1), [100]))
        product = self.queryset.get()
        self.assertListEqual(product.coordinates, [[0, 15, 25], [15, 30, 40, 100], [45, 60, 90]])
--------------------------------------------------------------------------------
/django_postgres_extensions/models/fields/related.py:
--------------------------------------------------------------------------------
1 | from django_postgres_extensions.models.fields import ArrayField
2 | from django.db.models.fields.related import RelatedField
3 | from django.db.models.query_utils import PathInfo
4 | from .reverse_related import ArrayManyToManyRel
5 | from .related_descriptors import MultiReferenceDescriptor
6 | from django.db import models
7 | from django.db.models.fields.related import RECURSIVE_RELATIONSHIP_CONSTANT, lazy_related_operation
8 | from django.forms.models import ModelMultipleChoiceField
9 | from django.utils import six
10 | from django.utils.encoding import force_text
11 | from .related_lookups import RelatedArrayContains, RelatedArrayExact, RelatedArrayContainedBy, RelatedContainsItem, \
12 | RelatedArrayOverlap, RelatedAnyGreaterThan, RelatedAnyLessThanOrEqual, RelatedAnyLessThan, RelatedAnyGreaterThanOrEqual
13 |
class ArrayManyToManyField(ArrayField, RelatedField):
    """
    A many-to-many relation stored in an array column instead of a join table.

    The column (attribute ``<name>_ids``, see ``get_attname``) holds the
    ``to_field`` values of the related objects. SQL joins match a related row
    when its key appears in the array (``rhs = ANY(lhs_array)``, see
    ``get_join_on``). Attribute access on both sides of the relation goes
    through ``MultiReferenceDescriptor``.
    """

    # Field flags: this is neither a through-table m2m nor a foreign key, so
    # only the custom ``many_to_many_array`` flag is set.
    many_to_many_array = True
    many_to_many = False
    many_to_one = False
    one_to_many = False
    one_to_one = False

    rel_class = ArrayManyToManyRel

    # Lookup name -> lookup class used when filtering on this field.
    # ``in`` and ``overlap`` share the overlap lookup; ``exact`` tests whether
    # a single item is contained in the array.
    _related_array_lookups = {
        'in': RelatedArrayOverlap,
        'exact': RelatedContainsItem,
        'exactly': RelatedArrayExact,
        'contains': RelatedArrayContains,
        'contained_by': RelatedArrayContainedBy,
        'overlap': RelatedArrayOverlap,
        'gt': RelatedAnyGreaterThan,
        'gte': RelatedAnyGreaterThanOrEqual,
        'lt': RelatedAnyLessThan,
        'lte': RelatedAnyLessThanOrEqual,
    }

    def __init__(self, to_model, base_field=None, size=None, related_name=None, symmetrical=None,
                 related_query_name=None, limit_choices_to=None, to_field=None, db_constraint=False, **kwargs):
        """
        Build the relation and its underlying array field.

        :param to_model: the related model class, a model name string, or
            ``RECURSIVE_RELATIONSHIP_CONSTANT`` for a relation to itself.
        :param base_field: field used for the array items. When omitted and
            ``to_model`` is a model class with a non-relational target field,
            a field of the target's type is created (``BigIntegerField`` for a
            ``BigAutoField`` target). Otherwise (string reference,
            ``AutoField`` target, or relational target), ``IntegerField`` is used.
        :param size: maximum array length, passed on to ``ArrayField``.
        :param symmetrical: defaults to True only when ``to_model`` is
            ``RECURSIVE_RELATIONSHIP_CONSTANT``.
        :param to_field: name of the referenced field. Defaults to the primary
            key name when ``to_model`` is a model class.
        :param db_constraint: stored on the field and emitted by ``deconstruct``.
        :raises AssertionError: if ``to_model`` is neither a model nor a string.
        """
        try:
            to = to_model._meta.model_name
        except AttributeError:
            assert isinstance(to_model, six.string_types), (
                "%s(%r) is invalid. First parameter to ForeignKey must be "
                "either a model, a model name, or the string %r" % (
                    self.__class__.__name__, to_model,
                    RECURSIVE_RELATIONSHIP_CONSTANT,
                )
            )
            to = str(to_model)
        else:
            # For backwards compatibility purposes, we need to *try* and set
            # the to_field during FK construction. It won't be guaranteed to
            # be correct until contribute_to_class is called. Refs #12190.
            to_field = to_field or (to_model._meta.pk and to_model._meta.pk.name)
            if not base_field:
                field = to_model._meta.get_field(to_field)
                if not field.is_relation:
                    base_field_type = type(field)
                    internal_type = field.get_internal_type()
                    if internal_type == 'AutoField':
                        # Left unset: falls back to IntegerField below.
                        pass
                    elif internal_type == 'BigAutoField':
                        base_field = models.BigIntegerField()
                    elif hasattr(field, 'max_length'):
                        base_field = base_field_type(max_length=field.max_length)
                    else:
                        base_field = base_field_type()

        if not base_field:
            base_field = models.IntegerField()

        if symmetrical is None:
            symmetrical = (to == RECURSIVE_RELATIONSHIP_CONSTANT)

        kwargs['rel'] = self.rel_class(
            self, to, to_field,
            related_name=related_name,
            related_query_name=related_query_name,
            limit_choices_to=limit_choices_to,
            symmetrical=symmetrical,
        )
        self.has_null_arg = 'null' in kwargs

        self.db_constraint = db_constraint

        self.to = to

        # A callable default gives every new instance its own empty list. A
        # literal [] would be a single list object shared by all instances.
        if 'default' not in kwargs:
            kwargs['default'] = list
        kwargs['blank'] = True

        # Single-column relation: this field -> to_field on the related model
        # (None means the related model's primary key, see
        # resolve_related_fields).
        self.from_fields = ['self']
        self.to_fields = [to_field]

        super(ArrayManyToManyField, self).__init__(base_field, size=size, **kwargs)

    def deconstruct(self):
        """Return the (name, path, args, kwargs) needed to recreate this field."""
        name, path, args, kwargs = super(ArrayManyToManyField, self).deconstruct()
        args = (self.to,)
        kwargs.update({
            'base_field': self.base_field,
            'size': self.size,
            'related_name': self.remote_field.related_name,
            'symmetrical': self.remote_field.symmetrical,
            'related_query_name': self.remote_field.related_query_name,
            'limit_choices_to': self.remote_field.limit_choices_to,
            # ``self.remote_field.field`` is the relation's field object
            # (rel_class is constructed with ``self`` as its first argument),
            # not a field name. Emit the target name recorded in __init__.
            'to_field': self.to_fields[0],
            'db_constraint': self.db_constraint,
        })
        return name, path, args, kwargs

    def get_attname(self):
        """Return the model attribute holding the raw key array: ``<name>_ids``."""
        return '%s_ids' % self.name

    def get_attname_column(self):
        """Return ``(attname, column)``, where the column defaults to the attname."""
        attname = self.get_attname()
        column = self.db_column or attname
        return attname, column

    def get_accessor_name(self):
        # NOTE(review): relies on ``remote_field`` exposing ``model_name``;
        # ArrayManyToManyRel is defined elsewhere, so confirm that attribute exists.
        return self.remote_field.model_name + '_set'

    def get_reverse_accessor_name(self):
        """Return the attribute name installed on the related model."""
        return self.remote_field.get_accessor_name()

    def contribute_to_class(self, cls, name, **kwargs):
        """
        Attach the field to ``cls``: set up related_name, install the forward
        descriptor, and register a lazy operation that resolves the related
        model once it is loaded.
        """
        # To support multiple relations to self, it's useful to have a non-None
        # related name on symmetrical relations for internal reasons. The
        # concept doesn't make a lot of sense externally ("you want me to
        # specify *what* on my non-reversible relation?!"), so we set it up
        # automatically. The funky name reduces the chance of an accidental
        # clash.
        if self.remote_field.symmetrical and (
                self.remote_field.model == RECURSIVE_RELATIONSHIP_CONSTANT or
                self.remote_field.model == cls._meta.object_name):
            self.remote_field.related_name = "%s_rel_+" % name
        elif self.remote_field.is_hidden():
            # If the backwards relation is disabled, replace the original
            # related_name with one generated from the m2m field name. Django
            # still uses backwards relations internally and we need to avoid
            # clashes between multiple m2m fields with related_name == '+'.
            self.remote_field.related_name = "_%s_%s_+" % (cls.__name__.lower(), name)

        # Forward any extra options (e.g. private_only) instead of dropping them.
        super(ArrayManyToManyField, self).contribute_to_class(cls, name, **kwargs)

        if not cls._meta.abstract:
            setattr(cls, self.name, MultiReferenceDescriptor(self.remote_field, reverse=False))

        self.opts = cls._meta
        if self.remote_field.related_name:
            # Expand the %(class)s / %(app_label)s placeholders.
            related_name = force_text(self.remote_field.related_name) % {
                'class': cls.__name__.lower(),
                'app_label': cls._meta.app_label.lower()
            }
            self.remote_field.related_name = related_name

        def resolve_related_class(model, related, field):
            # Replace the (possibly string) reference with the model class.
            field.remote_field.model = related
            field.do_related_class(related, model)

        lazy_related_operation(resolve_related_class, cls, self.remote_field.model, field=self)

    def contribute_to_related_class(self, cls, related):
        """Install the reverse ``MultiReferenceDescriptor`` on the related model."""
        # Internal M2Ms (i.e., those with a related name ending with '+')
        # and swapped models don't get a related descriptor.
        if not self.remote_field.is_hidden() and not related.related_model._meta.swapped:
            setattr(cls, self.get_reverse_accessor_name(), MultiReferenceDescriptor(self.remote_field, reverse=True))

    def formfield(self, **kwargs):
        """
        Return a ``ModelMultipleChoiceField`` over the related model's default
        manager, optionally bound to the database alias given as ``using``.
        """
        db = kwargs.pop('using', None)
        defaults = {
            'form_class': ModelMultipleChoiceField,
            'queryset': self.related_model._default_manager.using(db),
        }
        defaults.update(kwargs)
        # If initial is passed in, it's a list of related objects, but the
        # MultipleChoiceField takes a list of IDs.
        if defaults.get('initial') is not None:
            initial = defaults['initial']
            if callable(initial):
                initial = initial()
            defaults['initial'] = [i._get_pk_val() for i in initial]
        # super(RelatedField, ...) resolves past both ArrayField and
        # RelatedField in the MRO, so neither of their formfield overrides runs.
        return super(RelatedField, self).formfield(**defaults)

    def get_join_on(self, parent_alias, lhs_col, table_alias, rhs_col):
        """Return the SQL join condition ``table.rhs = ANY(parent.lhs_array)``."""
        return '%s.%s = ANY(%s.%s)' % (
            table_alias,
            rhs_col,
            parent_alias,
            lhs_col,
        )

    def get_join_on2(self, parent_alias, lhs_col, table_alias, rhs_col):
        """
        Alternative join condition built with ``ARRAY_APPEND`` and ``<@ ANY``.

        NOTE(review): not called from this class and hard-codes ``integer[]``,
        so it only fits integer keys. Confirm whether it is still used.
        """
        return "ARRAY_APPEND(ARRAY[]::integer[], %s.%s) <@ ANY(%s.%s)" % (
            table_alias,
            rhs_col,
            parent_alias,
            lhs_col,
        )

    def resolve_related_fields(self):
        """
        Pair each name in ``from_fields`` with the matching name in ``to_fields``.

        ``'self'`` on the from side means this field. ``None`` on the to side
        means the related model's primary key.

        :raises ValueError: if the name lists are empty or differ in length,
            or if the related model is still an unresolved string reference.
        """
        if len(self.from_fields) < 1 or len(self.from_fields) != len(self.to_fields):
            raise ValueError('Foreign Object from and to fields must be the same non-zero length')
        if isinstance(self.remote_field.model, six.string_types):
            raise ValueError('Related model %r cannot be resolved' % self.remote_field.model)
        related_opts = self.remote_field.model._meta
        related_fields = []
        for from_field_name, to_field_name in zip(self.from_fields, self.to_fields):
            from_field = self if from_field_name == 'self' else self.opts.get_field(from_field_name)
            to_field = related_opts.pk if to_field_name is None else related_opts.get_field(to_field_name)
            related_fields.append((from_field, to_field))
        return related_fields

    @property
    def related_fields(self):
        """``(from_field, to_field)`` pairs, resolved on first access and cached."""
        if not hasattr(self, '_related_fields'):
            self._related_fields = self.resolve_related_fields()
        return self._related_fields

    @property
    def reverse_related_fields(self):
        """``related_fields`` with each pair swapped to ``(to_field, from_field)``."""
        return [(rhs_field, lhs_field) for lhs_field, rhs_field in self.related_fields]

    @property
    def local_related_fields(self):
        """Tuple of the fields on this model's side of the relation."""
        return tuple(lhs_field for lhs_field, rhs_field in self.related_fields)

    @property
    def foreign_related_fields(self):
        """Tuple of the (non-empty) fields on the related model's side."""
        return tuple(rhs_field for lhs_field, rhs_field in self.related_fields if rhs_field)

    def get_local_related_value(self, instance):
        """Return ``instance``'s values for ``local_related_fields`` as a tuple."""
        return self.get_instance_value_for_fields(instance, self.local_related_fields)

    def get_foreign_related_value(self, instance):
        """Return ``instance``'s values for ``foreign_related_fields`` as a tuple."""
        return self.get_instance_value_for_fields(instance, self.foreign_related_fields)

    @staticmethod
    def get_instance_value_for_fields(instance, fields):
        """Return a tuple of ``instance``'s attribute values for ``fields``."""
        ret = []
        opts = instance._meta
        for field in fields:
            # Gotcha: in some cases (like fixture loading) a model can have
            # different values in parent_ptr_id and parent's id. So, use
            # instance.pk (that is, parent_ptr_id) when asked for instance.id.
            if field.primary_key:
                possible_parent_link = opts.get_ancestor_link(field.model)
                if (not possible_parent_link or
                        possible_parent_link.primary_key or
                        possible_parent_link.model._meta.abstract):
                    ret.append(instance.pk)
                    continue
            ret.append(getattr(instance, field.attname))
        return tuple(ret)

    def get_joining_columns(self, reverse_join=False):
        """Return ``(lhs_column, rhs_column)`` pairs, swapped when ``reverse_join``."""
        source = self.reverse_related_fields if reverse_join else self.related_fields
        return tuple((lhs_field.column, rhs_field.column) for lhs_field, rhs_field in source)

    def get_reverse_joining_columns(self):
        """Return the joining columns for a join from the related model."""
        return self.get_joining_columns(reverse_join=True)

    def get_extra_descriptor_filter(self, instance):
        """
        Return an extra filter condition for related object fetching when
        user does 'instance.column', that is the extra filter is used in
        the descriptor of the field.

        The filter should be either a dict usable in .filter(**kwargs) call or
        a Q-object. The condition will be ANDed together with the relation's
        joining columns.

        A parallel method is get_extra_restriction() which is used in
        JOIN and subquery conditions.
        """
        return {}

    def get_extra_restriction(self, where_class, alias, related_alias):
        """
        Return a pair condition used for joining and subquery pushdown. The
        condition is something that responds to as_sql(compiler, connection)
        method.

        Note that currently referring both the 'alias' and 'related_alias'
        will not work in some conditions, like subquery pushdown.

        A parallel method is get_extra_descriptor_filter() which is used in
        instance.column related object fetching.
        """
        return None

    def get_path_info(self, filtered_relation=None):
        """
        Get path from this field to the related model.
        """
        opts = self.remote_field.model._meta
        from_opts = self.model._meta
        return [PathInfo(from_opts, opts, self.foreign_related_fields, self, False, True, filtered_relation)]

    def validate_item(self, obj, model=None):
        """
        Normalize ``obj`` to the value stored in the array.

        An instance of ``model`` (default: the related model) is replaced by
        its target-field value. Any other non-model value is returned as is.

        :raises TypeError: if ``obj`` is an instance of a different model.
        """
        if not model:
            model = self.remote_field.model
        if isinstance(obj, model):
            obj = getattr(obj, self.remote_field.target_field.name)
        elif isinstance(obj, models.Model):
            raise TypeError(
                "'%s' instance expected, got %r" %
                (model._meta.object_name, obj)
            )
        return obj

    def save_form_data(self, instance, data):
        """
        For newly created instances, the column is set and saved with all the
        other fields when model.save() is run, so the m2m can be stored along
        with all the other fields in one SQL query.
        For updating instances, Array M2M fields are ignored by model.save()
        to avoid sending large amounts of unnecessary data in the update
        query. Instead, the data is set with an extra SQL command via the
        descriptor.
        """
        if instance.pk:
            getattr(instance, self.name).set(data)
        else:
            objs = [self.validate_item(obj) for obj in data]
            setattr(instance, self.attname, objs)

    def get_reverse_path_info(self, filtered_relation=None):
        """
        Get path from the related model to this field's model.

        ``filtered_relation`` defaults to None, matching ``get_path_info``, so
        callers that pass no argument keep working.
        """
        opts = self.model._meta
        from_opts = self.remote_field.model._meta
        pathinfos = [PathInfo(from_opts, opts, (from_opts.pk,), self.remote_field, not self.unique, False, filtered_relation)]
        return pathinfos

    def get_lookup(self, lookup_name):
        """
        Return the lookup class for ``lookup_name``.

        :raises TypeError: for lookups not supported on this relation
            (rather than returning None).
        """
        try:
            return self._related_array_lookups[lookup_name]
        except KeyError:
            raise TypeError('Related Array got invalid lookup: %s' % lookup_name)
--------------------------------------------------------------------------------