├── .coveragerc
├── .gitignore
├── .travis.yml
├── CHANGELOG
├── LICENSE
├── MANIFEST.in
├── README.rst
├── dev-requirements.txt
├── docs
    ├── Makefile
    ├── backends.rst
    ├── basic_usage.rst
    ├── conf.py
    ├── index.rst
    ├── query_syntax.rst
    └── storedobject.rst
├── examples
    ├── abstract_schemas.py
    └── migration.py
├── main.py
├── modularodm
    ├── __init__.py
    ├── cache.py
    ├── exceptions.py
    ├── ext
    │   ├── __init__.py
    │   └── concurrency.py
    ├── fields
    │   ├── __init__.py
    │   ├── abstractforeignfield.py
    │   ├── booleanfield.py
    │   ├── datetimefield.py
    │   ├── dictionaryfield.py
    │   ├── field.py
    │   ├── floatfield.py
    │   ├── foreign.py
    │   ├── foreignfield.py
    │   ├── integerfield.py
    │   ├── listfield.py
    │   ├── lists.py
    │   ├── objectidfield.py
    │   └── stringfield.py
    ├── frozen.py
    ├── query
    │   ├── __init__.py
    │   ├── query.py
    │   ├── querydialect.py
    │   └── queryset.py
    ├── signals.py
    ├── storage
    │   ├── __init__.py
    │   ├── base.py
    │   ├── ephemeralstorage.py
    │   ├── mongostorage.py
    │   └── picklestorage.py
    ├── storedobject.py
    ├── translators
    │   └── __init__.py
    ├── utils.py
    ├── validators
    │   └── __init__.py
    └── writequeue.py
├── pk_sandbox.py
├── requirements.txt
├── setup.cfg
├── setup.py
├── tasks.py
├── tests
    ├── __init__.py
    ├── backrefs
    │   ├── __init__.py
    │   ├── test_abstract_backrefs.py
    │   ├── test_attribute_syntax.py
    │   ├── test_ensure_backrefs.py
    │   ├── test_many_to_many.py
    │   ├── test_one_to_many.py
    │   └── test_pk_change.py
    ├── base.py
    ├── ext
    │   ├── __init__.py
    │   └── test_concurrency.py
    ├── fixtures.py
    ├── laziness
    │   ├── __init__.py
    │   └── test_lazy_load.py
    ├── queries
    │   ├── __init__.py
    │   ├── test_comparison_operators.py
    │   ├── test_foreign_queries.py
    │   ├── test_logical_operators.py
    │   ├── test_simple_queries.py
    │   ├── test_string_operators.py
    │   └── test_update_queries.py
    ├── test_fields.py
    ├── test_foreign.py
    ├── test_migration.py
    ├── test_queue.py
    ├── test_signals.py
    ├── test_storage.py
    ├── test_storedobject.py
    ├── test_subclass.py
    ├── utils.py
    └── validators
    │   ├── __init__.py
    │   ├── test_iterable_validators.py
    │   ├── test_numeric_validators.py
    │   ├── test_record_validation.py
    │   ├── test_type_validators.py
    │   ├── test_url_validation.py
    │   └── urlValidatorTest.json
└── tox.ini
/.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | include = 3 | modularodm/* 4 | [report] 5 | exclude_lines = 6 | raise NotImplementedError 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # OS generated files 2 | ###################### 3 | .DS_Store 4 | .DS_Store? 5 | ._* 6 | .Spotlight-V100 7 | .Trashes 8 | Icon? 
9 | ehthumbs.db 10 | Thumbs.db 11 | *.swp 12 | *~ 13 | .*~ 14 | 15 | # R 16 | ####################### 17 | .Rhistory 18 | 19 | # Python 20 | ####################### 21 | *.py[cod] 22 | *.so 23 | *.egg 24 | *.egg-info 25 | *.pkl 26 | dist 27 | build 28 | eggs 29 | parts 30 | bin 31 | var 32 | sdist 33 | develop-eggs 34 | .installed.cfg 35 | lib 36 | lib64 37 | __pycache__ 38 | pip-log.txt 39 | .coverage 40 | .tox 41 | nosetests.xml 42 | *.mo 43 | .idea 44 | .hg 45 | .hgignore 46 | .env 47 | .ropeproject 48 | 49 | # Coverage 50 | cover 51 | 52 | # Readme build 53 | README.html 54 | 55 | # Docs build 56 | _build 57 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | services: mongodb 3 | 4 | python: 5 | - "2.7" 6 | - "3.3" 7 | - "3.4" 8 | - "pypy" 9 | - "pypy3" 10 | 11 | install: 12 | - "pip install --upgrade ." 13 | 14 | script: "nosetests" 15 | -------------------------------------------------------------------------------- /CHANGELOG: -------------------------------------------------------------------------------- 1 | ********* 2 | ChangeLog 3 | ********* 4 | 5 | 0.4.0 (2016-09-20) 6 | ================== 7 | - Update the URLValidator to support unicode domain names. (thanks, @caspinelli and @acshi!) 8 | - Get rid of the unused FlaskStoredObject, letting us remove Flask as a dependency. 9 | 10 | 0.3.0 (2016-03-10) 11 | ================== 12 | - An Actual Release? Sure, why not! 13 | - This probably qualifies as a 1.0 from a code point of view, upgrade from 0.2.1 *carefully*. 14 | - Feature: Add support for negative indexing of query objects. (thanks, @chrisseto!) 15 | - Other Features/Fixes: git log 0.2.1..0.3.0 # i apologize 16 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include *.txt *.rst 2 | recursive-exclude tests * 3 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | *********** 2 | modular-odm 3 | *********** 4 | 5 | .. image:: https://badge.fury.io/py/modular-odm.png 6 | :target: http://badge.fury.io/py/modular-odm 7 | 8 | .. image:: https://travis-ci.org/CenterForOpenScience/modular-odm.png?branch=develop 9 | :target: https://travis-ci.org/CenterForOpenScience/modular-odm 10 | 11 | A database-agnostic Document-Object Mapper for Python. 12 | 13 | 14 | Install 15 | ======= 16 | 17 | .. code-block:: bash 18 | 19 | $ pip install modular-odm 20 | 21 | 22 | Example Usage with MongoDB 23 | ========================== 24 | 25 | Defining Models 26 | --------------- 27 | 28 | .. 
code-block:: python 29 | 30 | from modularodm import StoredObject, fields 31 | from modularodm.validators import MinLengthValidator, MaxLengthValidator 32 | 33 | class User(StoredObject): 34 | _meta = {"optimistic": True} 35 | _id = fields.StringField(primary=True, index=True) 36 | username = fields.StringField(required=True) 37 | password = fields.StringField(required=True, validate=[MinLengthValidator(8)]) 38 | 39 | def __repr__(self): 40 | return "<User {0}>".format(self.username) 41 | 42 | class Comment(StoredObject): 43 | _meta = {"optimistic": True} 44 | _id = fields.StringField(primary=True, index=True) 45 | text = fields.StringField(validate=MaxLengthValidator(500)) 46 | user = fields.ForeignField("User", backref="comments") 47 | 48 | def __repr__(self): 49 | return "<Comment {0}>".format(self.text) 50 | 51 | 52 | Setting the Storage Backend 53 | --------------------------- 54 | 55 | .. code-block:: python 56 | 57 | from pymongo import MongoClient 58 | from modularodm import storage 59 | 60 | client = MongoClient() 61 | db = client['testdb'] 62 | User.set_storage(storage.MongoStorage(db, collection="user")) 63 | Comment.set_storage(storage.MongoStorage(db, collection="comment")) 64 | 65 | Creating and Querying 66 | --------------------- 67 | 68 | .. code-block:: python 69 | 70 | >>> from modularodm.query.querydialect import DefaultQueryDialect as Q 71 | >>> u = User(username="unladenswallow", password="h0lygrai1") 72 | >>> u.save() 73 | >>> comment = Comment(text="And now for something completely different.", user=u) 74 | >>> comment2 = Comment(text="It's just a flesh wound.", user=u) 75 | >>> comment.save() 76 | True 77 | >>> comment2.save() 78 | True 79 | >>> u = User.find_one(Q("username", "eq", "unladenswallow")) 80 | >>> u.comment__comments 81 | [<Comment And now for something completely different.>, <Comment It's just a flesh wound.>] 82 | >>> c = Comment.find(Q("text", "startswith", "And now"))[0] 83 | >>> c.text 84 | 'And now for something completely different.' 85 | 86 | For more information regarding query syntax, see the query syntax page of the documentation on Read the Docs. 87 | 88 | Migrations 89 | ---------- 90 | 91 | TODO 92 | 93 | 94 | *Full documentation coming soon.* 95 | 96 | Development 97 | =========== 98 | 99 | Tests require `nose `_, `invoke `_, and MongoDB. 100 | 101 | Installing MongoDB 102 | ------------------ 103 | 104 | If you are on MacOSX with `homebrew `_, run 105 | 106 | .. code-block:: bash 107 | 108 | $ brew update 109 | $ brew install mongodb 110 | 111 | Running Tests 112 | ------------- 113 | 114 | To start MongoDB, run 115 | 116 | .. code-block:: bash 117 | 118 | $ invoke mongo 119 | 120 | Run all tests with 121 | 122 | .. code-block:: bash 123 | 124 | $ invoke test 125 | -------------------------------------------------------------------------------- /dev-requirements.txt: -------------------------------------------------------------------------------- 1 | nose 2 | mock 3 | tox 4 | wheel 5 | invoke>=0.12.0,<0.13.0 6 | sphinx 7 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = _build 9 | 10 | # User-friendly check for sphinx-build 11 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) 12 | $(error The '$(SPHINXBUILD)' command was not found. 
Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) 13 | endif 14 | 15 | # Internal variables. 16 | PAPEROPT_a4 = -D latex_paper_size=a4 17 | PAPEROPT_letter = -D latex_paper_size=letter 18 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 19 | # the i18n builder cannot share the environment and doctrees with the others 20 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 21 | 22 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest coverage gettext 23 | 24 | help: 25 | @echo "Please use \`make ' where is one of" 26 | @echo " html to make standalone HTML files" 27 | @echo " dirhtml to make HTML files named index.html in directories" 28 | @echo " singlehtml to make a single large HTML file" 29 | @echo " pickle to make pickle files" 30 | @echo " json to make JSON files" 31 | @echo " htmlhelp to make HTML files and a HTML help project" 32 | @echo " qthelp to make HTML files and a qthelp project" 33 | @echo " applehelp to make an Apple Help Book" 34 | @echo " devhelp to make HTML files and a Devhelp project" 35 | @echo " epub to make an epub" 36 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 37 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 38 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" 39 | @echo " text to make text files" 40 | @echo " man to make manual pages" 41 | @echo " texinfo to make Texinfo files" 42 | @echo " info to make Texinfo files and run them through makeinfo" 43 | @echo " gettext to make PO message catalogs" 44 | @echo " changes to make an overview of all changed/added/deprecated items" 45 | @echo " xml to make Docutils-native XML files" 46 | @echo " pseudoxml to make pseudoxml-XML files for display purposes" 47 | @echo " linkcheck to check all external links for integrity" 48 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 49 | @echo " coverage to run coverage check of the documentation (if enabled)" 50 | 51 | clean: 52 | rm -rf $(BUILDDIR)/* 53 | 54 | html: 55 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 56 | @echo 57 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 58 | 59 | dirhtml: 60 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 61 | @echo 62 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 63 | 64 | singlehtml: 65 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 66 | @echo 67 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 68 | 69 | pickle: 70 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 71 | @echo 72 | @echo "Build finished; now you can process the pickle files." 73 | 74 | json: 75 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 76 | @echo 77 | @echo "Build finished; now you can process the JSON files." 78 | 79 | htmlhelp: 80 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 81 | @echo 82 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 83 | ".hhp project file in $(BUILDDIR)/htmlhelp." 
84 | 85 | qthelp: 86 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 87 | @echo 88 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 89 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 90 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/Modular-ODM.qhcp" 91 | @echo "To view the help file:" 92 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/Modular-ODM.qhc" 93 | 94 | applehelp: 95 | $(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp 96 | @echo 97 | @echo "Build finished. The help book is in $(BUILDDIR)/applehelp." 98 | @echo "N.B. You won't be able to view it unless you put it in" \ 99 | "~/Library/Documentation/Help or install it in your application" \ 100 | "bundle." 101 | 102 | devhelp: 103 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 104 | @echo 105 | @echo "Build finished." 106 | @echo "To view the help file:" 107 | @echo "# mkdir -p $$HOME/.local/share/devhelp/Modular-ODM" 108 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/Modular-ODM" 109 | @echo "# devhelp" 110 | 111 | epub: 112 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 113 | @echo 114 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 115 | 116 | latex: 117 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 118 | @echo 119 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 120 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 121 | "(use \`make latexpdf' here to do that automatically)." 122 | 123 | latexpdf: 124 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 125 | @echo "Running LaTeX files through pdflatex..." 126 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 127 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 128 | 129 | latexpdfja: 130 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 131 | @echo "Running LaTeX files through platex and dvipdfmx..." 132 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 133 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 134 | 135 | text: 136 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 137 | @echo 138 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 139 | 140 | man: 141 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 142 | @echo 143 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 144 | 145 | texinfo: 146 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 147 | @echo 148 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 149 | @echo "Run \`make' in that directory to run these through makeinfo" \ 150 | "(use \`make info' here to do that automatically)." 151 | 152 | info: 153 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 154 | @echo "Running Texinfo files through makeinfo..." 155 | make -C $(BUILDDIR)/texinfo info 156 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 157 | 158 | gettext: 159 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 160 | @echo 161 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 162 | 163 | changes: 164 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 165 | @echo 166 | @echo "The overview file is in $(BUILDDIR)/changes." 
167 | 168 | linkcheck: 169 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 170 | @echo 171 | @echo "Link check complete; look for any errors in the above output " \ 172 | "or in $(BUILDDIR)/linkcheck/output.txt." 173 | 174 | doctest: 175 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 176 | @echo "Testing of doctests in the sources finished, look at the " \ 177 | "results in $(BUILDDIR)/doctest/output.txt." 178 | 179 | coverage: 180 | $(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage 181 | @echo "Testing of coverage in the sources finished, look at the " \ 182 | "results in $(BUILDDIR)/coverage/python.txt." 183 | 184 | xml: 185 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml 186 | @echo 187 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml." 188 | 189 | pseudoxml: 190 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml 191 | @echo 192 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." 193 | -------------------------------------------------------------------------------- /docs/backends.rst: -------------------------------------------------------------------------------- 1 | .. backends: 2 | 3 | Backends 4 | ======== 5 | 6 | .. module:: modularodm.storage.base 7 | 8 | .. autoclass:: Storage 9 | :members: 10 | :undoc-members: 11 | 12 | Ephemeral 13 | --------- 14 | 15 | .. autoclass:: modularodm.storage.ephemeralstorage.EphemeralStorage 16 | :members: 17 | 18 | Mongo 19 | ----- 20 | 21 | .. autoclass:: modularodm.storage.mongostorage.MongoStorage 22 | :members: 23 | 24 | Pickle 25 | ------ 26 | 27 | .. autoclass:: modularodm.storage.picklestorage.PickleStorage 28 | :members: 29 | 30 | -------------------------------------------------------------------------------- /docs/basic_usage.rst: -------------------------------------------------------------------------------- 1 | .. basic_usage: 2 | 3 | Basic Usage 4 | =========== 5 | 6 | 7 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Ensure we get the local copy of tornado instead of what's on the standard path 2 | import os 3 | import sys 4 | sys.path.insert(0, os.path.abspath("..")) 5 | import modularodm 6 | 7 | master_doc = "index" 8 | 9 | project = "Modular-ODM" 10 | copyright = "Center For Open Science" 11 | 12 | version = release = modularodm.__version__ 13 | 14 | extensions = [ 15 | "sphinx.ext.autodoc", 16 | "sphinx.ext.coverage", 17 | "sphinx.ext.extlinks", 18 | "sphinx.ext.intersphinx", 19 | "sphinx.ext.viewcode", 20 | ] 21 | 22 | primary_domain = 'py' 23 | default_role = 'py:obj' 24 | 25 | autodoc_member_order = "bysource" 26 | autoclass_content = "both" 27 | 28 | # Without this line sphinx includes a copy of object.__init__'s docstring 29 | # on any class that doesn't define __init__. 30 | # https://bitbucket.org/birkenfeld/sphinx/issue/1337/autoclass_content-both-uses-object__init__ 31 | autodoc_docstring_signature = False 32 | 33 | coverage_skip_undoc_in_source = True 34 | coverage_ignore_modules = [] 35 | # I wish this could go in a per-module file... 
36 | coverage_ignore_classes = [] 37 | 38 | coverage_ignore_functions = [] 39 | 40 | # html_favicon = 'favicon.ico' 41 | 42 | latex_documents = [ 43 | ('documentation', False), 44 | ] 45 | 46 | # HACK: sphinx has limited support for substitutions with the |version| 47 | # variable, but there doesn't appear to be any way to use this in a link 48 | # target. 49 | # http://stackoverflow.com/questions/1227037/substitutions-inside-links-in-rest-sphinx 50 | # The extlink extension can be used to do link substitutions, but it requires a 51 | # portion of the url to be literally contained in the document. Therefore, 52 | # this link must be referenced as :current_tarball:`z` 53 | extlinks = {} 54 | 55 | intersphinx_mapping = { 56 | 'python': ('https://docs.python.org/3.4', None), 57 | 'tornado': ('http://www.tornadoweb.org/en/stable/', None), 58 | 'aiohttp': ('https://aiohttp.readthedocs.org/en/v0.14.1/', None), 59 | } 60 | 61 | on_rtd = os.environ.get('READTHEDOCS', None) == 'True' 62 | 63 | # On RTD we can't import sphinx_rtd_theme, but it will be applied by 64 | # default anyway. This block will use the same theme when building locally 65 | # as on RTD. 66 | if not on_rtd: 67 | import sphinx_rtd_theme 68 | html_theme = 'sphinx_rtd_theme' 69 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 70 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. Modular-ODM documentation master file, created by 2 | sphinx-quickstart on Tue Apr 14 15:27:24 2015. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to Modular-ODM's documentation! 7 | ======================================= 8 | 9 | Contents: 10 | 11 | .. toctree:: 12 | :maxdepth: 3 13 | 14 | basic_usage 15 | query_syntax 16 | storedobject 17 | backends 18 | 19 | 20 | 21 | Indices and tables 22 | ================== 23 | 24 | * :ref:`genindex` 25 | * :ref:`modindex` 26 | * :ref:`search` 27 | 28 | -------------------------------------------------------------------------------- /docs/query_syntax.rst: -------------------------------------------------------------------------------- 1 | .. 
query_syntax: 2 | 3 | Query Syntax 4 | ============ 5 | 6 | 7 | 8 | 9 | 10 | Operators 11 | --------- 12 | 13 | Equality 14 | ++++++++ 15 | 16 | =========== =============== ===================================================== 17 | Keyword Operator Description 18 | =========== =============== ===================================================== 19 | eq equal to Match a single value, including inside a list field 20 | ne not equal to 21 | =========== =============== ===================================================== 22 | 23 | Comparison 24 | ++++++++++ 25 | 26 | =========== =============== ==================================== 27 | Keyword Operator Description 28 | =========== =============== ==================================== 29 | gt greater than 30 | gte greater than or 31 | equal to 32 | lt less than 33 | lte less than or 34 | equal to 35 | =========== =============== ==================================== 36 | 37 | Membership 38 | ++++++++++ 39 | 40 | =========== =============== ==================================== 41 | Keyword Operator Description 42 | =========== =============== ==================================== 43 | in in 44 | nin not in 45 | =========== =============== ==================================== 46 | 47 | String Comparison 48 | +++++++++++++++++ 49 | 50 | =========== =============== ==================================== 51 | Keyword Operator Description 52 | =========== =============== ==================================== 53 | startswith starts with 54 | endswith ends with 55 | contains contains 56 | icontains contains (case 57 | insensitive) 58 | =========== =============== ==================================== 59 | -------------------------------------------------------------------------------- /docs/storedobject.rst: -------------------------------------------------------------------------------- 1 | .. storedobject: 2 | 3 | StoredObject 4 | ============ 5 | 6 | .. module:: modularodm.storedobject 7 | 8 | .. autoclass:: StoredObject 9 | :members: 10 | -------------------------------------------------------------------------------- /examples/abstract_schemas.py: -------------------------------------------------------------------------------- 1 | """Example of using abstract schemas as superclasses of concrete schemas. The 2 | abstract BaseSchema class cannot be instantiated, but the Schema class can be. 
3 | """ 4 | 5 | from modularodm import StoredObject, fields 6 | from bson import ObjectId 7 | 8 | class BaseSchema(StoredObject): 9 | 10 | _id = fields.StringField(default=lambda: str(ObjectId())) 11 | 12 | _meta = { 13 | 'abstract': True, 14 | } 15 | 16 | class Schema(BaseSchema): 17 | 18 | data = fields.StringField() 19 | 20 | # Instantiate Schema 21 | record = Schema(data='test') 22 | print record._id # Yields an ObjectId-like primary key 23 | 24 | # Can't instantiate abstract schema 25 | try: 26 | bad_record = BaseSchema() 27 | except TypeError as error: 28 | print error 29 | -------------------------------------------------------------------------------- /examples/migration.py: -------------------------------------------------------------------------------- 1 | import os 2 | from modularodm import StoredObject, fields, storage 3 | 4 | try: 5 | os.remove('db_migrate_sandbox.pkl') 6 | except OSError: 7 | pass 8 | 9 | my_storage = storage.PickleStorage('migrate_sandbox') 10 | 11 | class Schema1(StoredObject): 12 | 13 | _id = fields.StringField(primary=True) 14 | number = fields.IntegerField() 15 | deleted = fields.FloatField() 16 | 17 | _meta = { 18 | 'optimistic': True, 19 | 'version': 1, 20 | } 21 | 22 | Schema1.set_storage(my_storage) 23 | 24 | class Schema2(StoredObject): 25 | 26 | _id = fields.StringField(primary=True) 27 | name = fields.StringField(default='name') 28 | number = fields.IntegerField() 29 | 30 | @classmethod 31 | def _migrate(self, old, new): 32 | new.number = old.number + 1 33 | return new 34 | 35 | @classmethod 36 | def _unmigrate(cls, new, old): 37 | old.number = new.number - 1 38 | return old 39 | 40 | _meta = { 41 | 'optimistic': True, 42 | 'version': 2, 43 | 'version_of': Schema1, 44 | } 45 | 46 | Schema2.set_storage(my_storage) 47 | 48 | class Schema3(StoredObject): 49 | 50 | _id = fields.StringField(primary=True) 51 | name = fields.StringField(default='eman') 52 | number = fields.IntegerField() 53 | 54 | @classmethod 55 | def _migrate(self, old, new): 56 | new.number = old.number + 1 57 | return new 58 | 59 | @classmethod 60 | def _unmigrate(cls, new, old): 61 | old.number = new.number - 1 62 | return old 63 | 64 | _meta = { 65 | 'optimistic': True, 66 | 'version': 3, 67 | 'version_of': Schema2, 68 | } 69 | 70 | Schema3.set_storage(my_storage) 71 | 72 | # Describe migration from Schema1 -> Schema2 -> Schema3 73 | Schema3.explain_migration() 74 | 75 | # Create a record of type Schema1, then migrate to Schema3 76 | record = Schema1(number=1) 77 | record.save() 78 | migrated_record = Schema3.load(record._primary_key) 79 | print 'name', migrated_record.name # Should be "eman" 80 | print 'number', migrated_record.number # Should be 3 81 | -------------------------------------------------------------------------------- /modularodm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | __version__ = "0.4.0" 4 | 5 | from .storedobject import StoredObject 6 | 7 | from .query.querydialect import DefaultQueryDialect as Q 8 | -------------------------------------------------------------------------------- /modularodm/cache.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | def set_nested(data, value, *keys): 4 | """Assign to a nested dictionary. 
5 | 6 | :param dict data: Dictionary to mutate 7 | :param value: Value to set 8 | :param list *keys: List of nested keys 9 | 10 | >>> data = {} 11 | >>> set_nested(data, 'hi', 'k0', 'k1', 'k2') 12 | >>> data 13 | {'k0': {'k1': {'k2': 'hi'}}} 14 | 15 | """ 16 | if len(keys) == 1: 17 | data[keys[0]] = value 18 | else: 19 | if keys[0] not in data: 20 | data[keys[0]] = {} 21 | set_nested(data[keys[0]], value, *keys[1:]) 22 | 23 | 24 | class Cache(object): 25 | """Simple container for storing cached data. 26 | 27 | """ 28 | def __init__(self): 29 | self.data = {} 30 | 31 | @property 32 | def raw(self): 33 | return self.data 34 | 35 | def set(self, schema, key, value): 36 | set_nested(self.data, value, schema, key) 37 | 38 | def get(self, schema, key): 39 | try: 40 | return self.data[schema][key] 41 | except KeyError: 42 | return None 43 | 44 | def pop(self, schema, key): 45 | self.data[schema].pop(key, None) 46 | 47 | def clear(self): 48 | self.__init__() 49 | 50 | def clear_schema(self, schema): 51 | self.data.pop(schema, None) 52 | 53 | def __nonzero__(self): 54 | return bool(self.data) 55 | 56 | # Python 3 57 | __bool__ = __nonzero__ 58 | -------------------------------------------------------------------------------- /modularodm/exceptions.py: -------------------------------------------------------------------------------- 1 | class ModularOdmException(Exception): 2 | """ base class from which all exceptions raised by modularodm should inherit 3 | """ 4 | pass 5 | 6 | 7 | class QueryException(ModularOdmException): 8 | """ base class for exceptions raised from query parsing or execution""" 9 | pass 10 | 11 | 12 | class KeyExistsException(QueryException): 13 | """Raised when an insert is attempted with a primary key that already exists. 14 | """ 15 | pass 16 | 17 | 18 | class MultipleResultsFound(QueryException): 19 | """ Raised when multiple results match the passed query, and only a single 20 | object may be returned """ 21 | pass 22 | 23 | 24 | class NoResultsFound(QueryException): 25 | """ Raised when no results match the passed query, but one or more results 26 | must be returned. """ 27 | pass 28 | 29 | 30 | class ValidationError(ModularOdmException): 31 | """ Base class for exceptions raised during validation. Should not be raised 32 | directly. 
""" 33 | pass 34 | 35 | 36 | class ValidationTypeError(ValidationError, TypeError): 37 | """ Raised during validation if explicit type check failed """ 38 | pass 39 | 40 | 41 | class ValidationValueError(ValidationError, ValueError): 42 | """ Raised during validation if the value of the input is unacceptable, but 43 | the type is correct """ 44 | pass 45 | 46 | 47 | class ImproperConfigurationError(ModularOdmException): 48 | """Raised if configuration options are not set correctly.""" 49 | pass 50 | 51 | class DatabaseError(ModularOdmException): 52 | '''Raised when execution of a database operation fails.''' 53 | pass 54 | -------------------------------------------------------------------------------- /modularodm/ext/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cos-archives/modular-odm/8a34891892b8af69b21fdc46701c91763a5c1cf9/modularodm/ext/__init__.py -------------------------------------------------------------------------------- /modularodm/ext/concurrency.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import weakref 4 | import collections 5 | import six 6 | from werkzeug.local import LocalProxy 7 | 8 | from modularodm.cache import Cache 9 | from modularodm.writequeue import WriteQueue 10 | 11 | 12 | # Dictionary of proxies, keyed on base schema, class variable label, and 13 | # extension-specific key (e.g. a Flask request). The final key uses a weak 14 | # reference to avoid memory leaks. 15 | proxies = collections.defaultdict( 16 | lambda: collections.defaultdict(weakref.WeakKeyDictionary) 17 | ) 18 | 19 | # Class variables on ``StoredObject`` that should be request-local under 20 | # concurrent use 21 | proxied_members = { 22 | '_cache': Cache, 23 | '_object_cache': Cache, 24 | 'queue': WriteQueue, 25 | } 26 | 27 | 28 | def proxy_factory(BaseSchema, label, ProxiedClass, get_key): 29 | """Create a proxy to a class instance stored in ``proxies``. 30 | 31 | :param class BaseSchema: Base schema (e.g. ``StoredObject``) 32 | :param str label: Name of class variable to set 33 | :param class ProxiedClass: Class to get or create 34 | :param function get_key: Extension-specific key function; may return e.g. 35 | the current Flask request 36 | 37 | """ 38 | def local(): 39 | key = get_key() 40 | try: 41 | return proxies[BaseSchema][label][key] 42 | except KeyError: 43 | proxies[BaseSchema][label][key] = ProxiedClass() 44 | return proxies[BaseSchema][label][key] 45 | return LocalProxy(local) 46 | 47 | 48 | def with_proxies(proxy_map, get_key): 49 | """Class decorator factory; adds proxy class variables to target class. 50 | 51 | :param dict proxy_map: Mapping between class variable labels and proxied 52 | classes 53 | :param function get_key: Extension-specific key function; may return e.g. 
54 | the current Flask request 55 | 56 | """ 57 | def wrapper(cls): 58 | for label, ProxiedClass in six.iteritems(proxy_map): 59 | proxy = proxy_factory(cls, label, ProxiedClass, get_key) 60 | setattr(cls, label, proxy) 61 | return cls 62 | return wrapper 63 | -------------------------------------------------------------------------------- /modularodm/fields/__init__.py: -------------------------------------------------------------------------------- 1 | from .field import Field 2 | 3 | from .booleanfield import BooleanField 4 | from .datetimefield import DateTimeField 5 | from .dictionaryfield import DictionaryField 6 | from .floatfield import FloatField 7 | from .foreignfield import ForeignField 8 | from .abstractforeignfield import AbstractForeignField 9 | from .integerfield import IntegerField 10 | from .listfield import ListField 11 | from .objectidfield import ObjectIdField 12 | from .stringfield import StringField 13 | 14 | from .lists import List, ForeignList, AbstractForeignList 15 | -------------------------------------------------------------------------------- /modularodm/fields/abstractforeignfield.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from modularodm.fields.foreign import BaseForeignField 4 | from .lists import AbstractForeignList 5 | 6 | 7 | class AbstractForeignField(BaseForeignField): 8 | 9 | _list_class = AbstractForeignList 10 | _is_foreign = True 11 | _uniform_translator = False 12 | 13 | def __init__(self, *args, **kwargs): 14 | super(AbstractForeignField, self).__init__(*args, **kwargs) 15 | self._backref_field_name = kwargs.get('backref', None) 16 | self._is_foreign = True 17 | self._is_abstract = True 18 | 19 | def get_schema_class(self, schema): 20 | return self._schema_class.get_collection(schema) 21 | 22 | def get_primary_field(self, schema): 23 | schema_class = self.get_schema_class(schema) 24 | return schema_class._fields[schema_class._primary_name] 25 | 26 | def get_foreign_object(self, value): 27 | return self.get_schema_class(value[1])\ 28 | .load(value[0]) 29 | 30 | def to_storage(self, value, translator=None): 31 | 32 | if value is None: 33 | return None 34 | if not hasattr(value, '__iter__'): 35 | value = (value._primary_key, value._name) 36 | return ( 37 | self.get_primary_field(value[1])\ 38 | .to_storage(value[0], translator), 39 | value[1] 40 | ) 41 | 42 | def from_storage(self, value, translator=None): 43 | 44 | if value is None: 45 | return None 46 | return ( 47 | self.get_primary_field(value[1])\ 48 | .from_storage(value[0], translator), 49 | value[1] 50 | ) 51 | 52 | def _to_primary_key(self, value): 53 | 54 | if value is None: 55 | return None 56 | if hasattr(value, '_primary_key'): 57 | return value._primary_key 58 | 59 | def __set__(self, instance, value, safe=False, literal=False): 60 | if hasattr(value, '_primary_key'): 61 | value = ( 62 | value._primary_key, 63 | value._name 64 | ) 65 | elif isinstance(value, tuple) or isinstance(value, list): 66 | if len(value) != 2: 67 | raise ValueError('Value must have length 2') 68 | elif value is not None: 69 | raise TypeError('Value must be StoredObject, tuple, or None') 70 | super(AbstractForeignField, self).__set__( 71 | instance, value, safe=safe, literal=literal 72 | ) 73 | 74 | def __get__(self, instance, owner, check_dirty=True): 75 | value = super(AbstractForeignField, self).__get__( 76 | instance, None, check_dirty 77 | ) 78 | if value is None: 79 | return None 80 | return self.get_foreign_object(value) 81 | 
-------------------------------------------------------------------------------- /modularodm/fields/booleanfield.py: -------------------------------------------------------------------------------- 1 | from . import Field 2 | from ..validators import validate_boolean 3 | 4 | class BooleanField(Field): 5 | 6 | validate = validate_boolean 7 | data_type = bool -------------------------------------------------------------------------------- /modularodm/fields/datetimefield.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import datetime 4 | 5 | from modularodm import signals 6 | from modularodm.fields import Field 7 | from modularodm.validators import validate_datetime 8 | 9 | 10 | DEFAULT_NOW = datetime.datetime.utcnow 11 | 12 | 13 | def default_or_callable(value): 14 | if value is True: 15 | return DEFAULT_NOW 16 | if callable(value): 17 | return value 18 | raise ValueError('Value must be True or callable') 19 | 20 | 21 | class DateTimeField(Field): 22 | 23 | validate = validate_datetime 24 | data_type = datetime.datetime 25 | mutable = True 26 | 27 | def __init__(self, *args, **kwargs): 28 | 29 | super(DateTimeField, self).__init__(*args, **kwargs) 30 | 31 | auto_now = kwargs.pop('auto_now', False) 32 | auto_now_add = kwargs.pop('auto_now_add', False) 33 | if auto_now and auto_now_add: 34 | raise ValueError('Cannot use auto_now and auto_now_add on the ' 35 | 'same field.') 36 | 37 | # 38 | if (auto_now or auto_now_add) and 'editable' not in kwargs: 39 | self._editable = False 40 | self.lazy_default = False 41 | 42 | # 43 | if auto_now: 44 | self._auto_now = default_or_callable(auto_now) 45 | elif auto_now_add: 46 | self._default = default_or_callable(auto_now_add) 47 | 48 | def subscribe(self, sender=None): 49 | self.update_auto_now = signals.before_save.connect( 50 | self.update_auto_now, 51 | sender=sender, 52 | ) 53 | 54 | def update_auto_now(self, cls, instance): 55 | if getattr(self, '_auto_now', None): 56 | self.data[instance] = self._auto_now() 57 | -------------------------------------------------------------------------------- /modularodm/fields/dictionaryfield.py: -------------------------------------------------------------------------------- 1 | from ..fields import Field 2 | 3 | 4 | class DictionaryField(Field): 5 | 6 | data_type = dict 7 | mutable = True 8 | 9 | def __init__(self, *args, **kwargs): 10 | super(DictionaryField, self).__init__(*args, **kwargs) 11 | self._default = kwargs.get('default', {}) 12 | -------------------------------------------------------------------------------- /modularodm/fields/field.py: -------------------------------------------------------------------------------- 1 | import weakref 2 | import warnings 3 | import copy 4 | import six 5 | 6 | from modularodm import exceptions 7 | from modularodm.query.querydialect import DefaultQueryDialect as Q 8 | from .lists import List 9 | 10 | 11 | def print_arg(arg): 12 | if isinstance(arg, six.string_types): 13 | return '"' + arg + '"' 14 | return arg 15 | 16 | 17 | class Field(object): 18 | 19 | default = None 20 | base_class = None 21 | _list_class = List 22 | mutable = False 23 | lazy_default = True 24 | _uniform_translator = True 25 | 26 | def __repr__(self): 27 | return '{cls}({kwargs})'.format( 28 | cls=self.__class__.__name__, 29 | kwargs=', '.join('{}={}'.format(key, print_arg(val)) for key, val in self._kwargs.items()) 30 | ) 31 | 32 | def subscribe(self, sender=None): 33 | pass 34 | 35 | def _to_comparable(self): 36 | return { 
37 | k : v 38 | for k, v in self.__dict__.items() 39 | if k not in ['data', '_translators', '_schema_class'] 40 | } 41 | 42 | def __eq__(self, other): 43 | return self._to_comparable() == other._to_comparable() 44 | 45 | def __ne__(self, other): 46 | return not self.__eq__(other) 47 | 48 | def _prepare_validators(self, _validate): 49 | 50 | if hasattr(_validate, '__iter__'): 51 | 52 | # List of callable validators 53 | validate = [] 54 | for validator in _validate: 55 | if hasattr(validator, '__call__'): 56 | validate.append(validator) 57 | else: 58 | raise TypeError('Validator lists must be lists of callables.') 59 | 60 | elif hasattr(_validate, '__call__'): 61 | 62 | # Single callable validator 63 | validate = _validate 64 | 65 | elif type(_validate) == bool: 66 | 67 | # Boolean validator 68 | validate = _validate 69 | 70 | else: 71 | 72 | # Invalid validator type 73 | raise TypeError('Validators must be callables, lists of callables, or booleans.') 74 | 75 | return _validate, validate 76 | 77 | def __init__(self, *args, **kwargs): 78 | 79 | self._args = args 80 | self._kwargs = kwargs 81 | self._translators = {} 82 | 83 | # Pointer to containing ListField 84 | # Set in StoredObject.ObjectMeta 85 | self._list_container = None 86 | 87 | self.data = weakref.WeakKeyDictionary() 88 | 89 | self._validate, self.validate = \ 90 | self._prepare_validators(kwargs.get('validate', False)) 91 | 92 | self._default = kwargs.get('default', self.default) 93 | self._is_primary = kwargs.get('primary', False) 94 | self._list = kwargs.get('list', False) 95 | self._required = kwargs.get('required', False) 96 | self._unique = kwargs.get('unique', False) 97 | self._editable = kwargs.get('editable', True) 98 | self._index = kwargs.get('index', self._is_primary) 99 | self._is_foreign = False 100 | 101 | # Fields added by ``ObjectMeta`` 102 | self._field_name = None 103 | 104 | def do_validate(self, value, obj): 105 | 106 | # Check if required 107 | if value is None: 108 | if getattr(self, '_required', None): 109 | raise exceptions.ValidationError('Value <{0}> is required.'.format(self._field_name)) 110 | return True 111 | 112 | # Check if unique 113 | if value is not None and self._unique: 114 | unique_query = Q(self._field_name, 'eq', value) 115 | # If object has primary key, don't crash if unique value is 116 | # already associated with its key 117 | if obj._is_loaded: 118 | unique_query = unique_query & Q(obj._primary_name, 'ne', obj._primary_key) 119 | if obj.find(unique_query).limit(1).count(): 120 | raise exceptions.ValidationValueError('Value must be unique') 121 | 122 | # Field-level validation 123 | cls = self.__class__ 124 | if hasattr(cls, 'validate') and self.validate is not False: 125 | cls.validate(value) 126 | 127 | # Schema-level validation 128 | if self._validate and hasattr(self, 'validate'): 129 | if hasattr(self.validate, '__iter__'): 130 | for validator in self.validate: 131 | validator(value) 132 | elif hasattr(self.validate, '__call__'): 133 | self.validate(value) 134 | 135 | # Success 136 | return True 137 | 138 | def _gen_default(self): 139 | if callable(self._default): 140 | return self._default() 141 | return copy.deepcopy(self._default) 142 | 143 | def _get_translate_func(self, translator, direction): 144 | try: 145 | return self._translators[(translator, direction)] 146 | except KeyError: 147 | method_name = '%s_%s' % (direction, self.data_type.__name__) 148 | default_name = '%s_default' % (direction,) 149 | try: 150 | method = getattr(translator, method_name) 151 | except 
AttributeError: 152 | method = getattr(translator, default_name) 153 | self._translators[(translator, direction)] = method 154 | return method 155 | 156 | def to_storage(self, value, translator=None): 157 | translator = translator or self._schema_class._translator 158 | if value is None: 159 | return translator.null_value 160 | method = self._get_translate_func(translator, 'to') 161 | value = value if method is None else method(value) 162 | if self.mutable: 163 | return copy.deepcopy(value) 164 | return value 165 | 166 | def from_storage(self, value, translator=None): 167 | translator = translator or self._schema_class._translator 168 | if value == translator.null_value: 169 | return None 170 | method = self._get_translate_func(translator, 'from') 171 | value = value if method is None else method(value) 172 | if self.mutable: 173 | return copy.deepcopy(value) 174 | return value 175 | 176 | def _pre_set(self, instance, safe=False): 177 | if not self._editable and not safe: 178 | raise AttributeError('Field cannot be edited.') 179 | if instance._detached: 180 | warnings.warn('Accessing a detached record.') 181 | 182 | def __set__(self, instance, value, safe=False, literal=False): 183 | self._pre_set(instance, safe=safe) 184 | if self.mutable: 185 | value = copy.deepcopy(value) 186 | self.data[instance] = value 187 | 188 | def _touch(self, instance): 189 | 190 | # Reload if dirty 191 | if instance._dirty: 192 | instance._dirty = False 193 | instance.reload() 194 | 195 | # Impute default and return 196 | try: 197 | self.data[instance] 198 | except KeyError: 199 | self.data[instance] = self._gen_default() 200 | 201 | def __get__(self, instance, owner, check_dirty=True): 202 | 203 | # Warn if detached 204 | if instance._detached: 205 | warnings.warn('Accessing a detached record.') 206 | 207 | # Reload if dirty 208 | self._touch(instance) 209 | 210 | # Impute default and return 211 | try: 212 | return self.data[instance] 213 | except KeyError: 214 | default = self._gen_default() 215 | self.data[instance] = default 216 | return default 217 | 218 | def _get_underlying_data(self, instance): 219 | """Return data from raw data store, rather than overridden 220 | __get__ methods. Should NOT be overwritten. 221 | """ 222 | self._touch(instance) 223 | return self.data.get(instance, None) 224 | 225 | def __delete__(self, instance): 226 | self.data.pop(instance, None) 227 | -------------------------------------------------------------------------------- /modularodm/fields/floatfield.py: -------------------------------------------------------------------------------- 1 | from . 
import Field 2 | from ..validators import validate_float 3 | 4 | class FloatField(Field): 5 | 6 | validate = validate_float 7 | data_type = float -------------------------------------------------------------------------------- /modularodm/fields/foreign.py: -------------------------------------------------------------------------------- 1 | # -*- coding utf-8 -*- 2 | 3 | import abc 4 | import six 5 | 6 | from modularodm import signals 7 | from modularodm.fields import Field 8 | 9 | 10 | @six.add_metaclass(abc.ABCMeta) 11 | class BaseForeignField(Field): 12 | 13 | @abc.abstractmethod 14 | def get_foreign_object(self, value): 15 | pass 16 | 17 | def update_backrefs(self, instance, cached_value, current_value): 18 | 19 | if self._backref_field_name is None: 20 | return 21 | 22 | if cached_value: 23 | cached_object = self.get_foreign_object(cached_value) 24 | if cached_object: 25 | cached_object._remove_backref( 26 | self._backref_field_name, 27 | instance, 28 | self._field_name 29 | ) 30 | 31 | if current_value: 32 | current_value._set_backref( 33 | self._backref_field_name, 34 | self._field_name, 35 | instance 36 | ) 37 | 38 | def update_backrefs_callback(self, cls, instance, fields_changed, cached_data): 39 | 40 | if self._field_name not in fields_changed: 41 | return 42 | 43 | cached_value = cached_data.get(self._field_name) 44 | current_value = getattr(instance, self._field_name) 45 | 46 | self.update_backrefs(instance, cached_value, current_value) 47 | 48 | def subscribe(self, sender=None): 49 | self.update_backrefs_callback = signals.save.connect( 50 | self.update_backrefs_callback, 51 | sender=sender, 52 | ) 53 | -------------------------------------------------------------------------------- /modularodm/fields/foreignfield.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from modularodm import exceptions 4 | 5 | from modularodm.fields.foreign import BaseForeignField 6 | from .lists import ForeignList 7 | 8 | 9 | class ForeignField(BaseForeignField): 10 | 11 | _list_class = ForeignList 12 | 13 | def __init__(self, *args, **kwargs): 14 | 15 | super(ForeignField, self).__init__(*args, **kwargs) 16 | 17 | self._backref_field_name = kwargs.get('backref', None) 18 | self._base_class_reference = args[0] 19 | self._base_class = None 20 | self._is_foreign = True 21 | self._is_abstract = False 22 | 23 | def get_foreign_object(self, value): 24 | return self.base_class.load(value) 25 | 26 | def to_storage(self, value, translator=None): 27 | 28 | if value is None: 29 | return value 30 | try: 31 | value_to_store = value._primary_key 32 | except AttributeError: 33 | value_to_store = value 34 | _foreign_pn = self.base_class._primary_name 35 | return self.base_class._fields[_foreign_pn].to_storage(value_to_store, translator) 36 | 37 | def from_storage(self, value, translator=None): 38 | 39 | if value is None: 40 | return None 41 | _foreign_pn = self.base_class._primary_name 42 | _foreign_pk = self.base_class._fields[_foreign_pn].from_storage(value, translator) 43 | return _foreign_pk 44 | 45 | def _to_primary_key(self, value): 46 | """ 47 | Return primary key; if value is StoredObject, verify 48 | that it is loaded. 
49 | 50 | """ 51 | if value is None: 52 | return None 53 | if isinstance(value, self.base_class): 54 | if not value._is_loaded: 55 | raise exceptions.DatabaseError('Record must be loaded.') 56 | return value._primary_key 57 | 58 | return self.base_class._to_primary_key(value) 59 | # return self.base_class._check_pk_type(value) 60 | 61 | @property 62 | def mutable(self): 63 | return self.base_class._fields[self.base_class._primary_name].mutable 64 | 65 | @property 66 | def base_class(self): 67 | if self._base_class: 68 | return self._base_class 69 | if isinstance(self._base_class_reference, type): 70 | self._base_class = self._base_class_reference 71 | else: 72 | try: 73 | self._base_class = self._schema_class.get_collection( 74 | self._base_class_reference 75 | ) 76 | except KeyError: 77 | raise exceptions.ModularOdmException( 78 | 'Unknown schema "{0}"'.format( 79 | self._base_class_reference 80 | ) 81 | ) 82 | return self._base_class 83 | 84 | def __set__(self, instance, value, safe=False, literal=False): 85 | # if instance._detached: 86 | # warnings.warn('Accessing a detached record.') 87 | value_to_set = value if literal else self._to_primary_key(value) 88 | super(ForeignField, self).__set__( 89 | instance, 90 | value_to_set, 91 | safe=safe 92 | ) 93 | 94 | def __get__(self, instance, owner): 95 | # if instance._detached: 96 | # warnings.warn('Accessing a detached record.') 97 | primary_key = super(ForeignField, self).__get__(instance, None) 98 | if primary_key is None: 99 | return 100 | return self.base_class.load(primary_key) 101 | -------------------------------------------------------------------------------- /modularodm/fields/integerfield.py: -------------------------------------------------------------------------------- 1 | from . import Field 2 | from ..validators import validate_integer 3 | 4 | class IntegerField(Field): 5 | 6 | # default = None 7 | validate = validate_integer 8 | data_type = int -------------------------------------------------------------------------------- /modularodm/fields/listfield.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import copy 4 | 5 | import six 6 | 7 | from modularodm import signals 8 | from ..fields import Field 9 | from ..validators import validate_list 10 | 11 | 12 | class ListField(Field): 13 | 14 | validate = validate_list 15 | 16 | def __init__(self, field_instance, **kwargs): 17 | 18 | super(ListField, self).__init__(**kwargs) 19 | 20 | self._list_validate, self.list_validate = self._prepare_validators(kwargs.get('list_validate', False)) 21 | 22 | # ListField is a list of the following (e.g., ForeignFields) 23 | self._field_instance = field_instance 24 | self._is_foreign = field_instance._is_foreign 25 | self._is_abstract = getattr(field_instance, '_is_abstract', False) 26 | self._uniform_translator = field_instance._uniform_translator 27 | 28 | # Descriptor data is this type of list 29 | self._list_class = self._field_instance._list_class 30 | 31 | # Descriptor data is this type of list object, instantiated as our 32 | # default 33 | if self._default: 34 | default = self._default() if six.callable(self._default) else self._default 35 | if ( 36 | not hasattr(default, '__iter__') 37 | or isinstance(default, dict) 38 | or isinstance(default, six.string_types) 39 | ): 40 | raise TypeError( 41 | 'Default value for list fields must be a list; received {0}'.format( 42 | type(self._default) 43 | ) 44 | ) 45 | else: 46 | default = None 47 | 48 | #if (self._default 49 
| # and not hasattr(self._default, '__iter__') 50 | # or isinstance(self._default, dict)): 51 | # raise TypeError( 52 | # 'Default value for list fields must be a list; received {0}'.format( 53 | # type(self._default) 54 | # ) 55 | # ) 56 | 57 | # Default is a callable that returns an empty instance of the list class 58 | # Avoids the need to deepcopy default values for lists, which will break 59 | # e.g. when validators contain (un-copyable) regular expressions. 60 | self._default = lambda: self._list_class(default, base_class=self._field_instance.base_class) 61 | 62 | # Fields added by ``ObjectMeta`` 63 | self._field_name = None 64 | 65 | def subscribe(self, sender=None): 66 | self.update_backrefs_callback = signals.save.connect( 67 | self.update_backrefs_callback, 68 | sender=sender, 69 | ) 70 | 71 | def __set__(self, instance, value, safe=False, literal=False): 72 | self._pre_set(instance, safe=safe) 73 | # if isinstance(value, self._default.__class__): 74 | # self.data[instance] = value 75 | if hasattr(value, '__iter__') and not isinstance(value, six.string_types): 76 | if literal: 77 | self.data[instance] = self._list_class(value, base_class=self._field_instance.base_class, literal=True) 78 | else: 79 | self.data[instance] = self._list_class(base_class=self._field_instance.base_class) 80 | self.data[instance].extend(value) 81 | else: 82 | self.data[instance] = value 83 | 84 | def do_validate(self, value, obj): 85 | 86 | inst = self._field_instance 87 | if getattr(inst, 'validate', False) or inst._unique or inst._required: 88 | # Child-level validation 89 | for part in value: 90 | self._field_instance.do_validate(part, obj) 91 | 92 | # Field-level list validation 93 | if hasattr(self.__class__, 'validate'): 94 | self.__class__.validate(value) 95 | 96 | # Schema-level list validation 97 | if self._list_validate: 98 | if hasattr(self.list_validate, '__iter__'): 99 | for validator in self.list_validate: 100 | validator(value) 101 | elif hasattr(self.list_validate, '__call__'): 102 | self.list_validate(value) 103 | 104 | # Success 105 | return True 106 | 107 | def _get_translate_func(self, translator, direction): 108 | try: 109 | return self._translators[(translator, direction)] 110 | except KeyError: 111 | if self._is_foreign: 112 | base_class = self._field_instance.base_class 113 | primary_field = base_class._fields[base_class._primary_name] 114 | method = primary_field._get_translate_func(translator, direction) 115 | else: 116 | method = self._field_instance._get_translate_func(translator, direction) 117 | self._translators[(translator, direction)] = method 118 | return method 119 | 120 | def to_storage(self, value, translator=None): 121 | translator = translator or self._schema_class._translator 122 | if value: 123 | if hasattr(value, '_to_data'): 124 | value = value._to_data() 125 | if self._uniform_translator: 126 | method = self._get_translate_func(translator, 'to') 127 | if method is not None or translator.null_value is not None: 128 | value = [ 129 | translator.null_value if item is None 130 | else 131 | item if method is None 132 | else 133 | method(item) 134 | for item in value 135 | ] 136 | if self._field_instance.mutable: 137 | return copy.deepcopy(value) 138 | return copy.copy(value) 139 | else: 140 | return [ 141 | self._field_instance.to_storage(item) 142 | for item in value 143 | ] 144 | return [] 145 | 146 | def from_storage(self, value, translator=None): 147 | translator = translator or self._schema_class._translator 148 | if value: 149 | if self._uniform_translator: 
150 | method = self._get_translate_func(translator, 'from') 151 | if method is not None or translator.null_value is not None: 152 | value = [ 153 | None if item is translator.null_value 154 | else 155 | item if method is None 156 | else 157 | method(item) 158 | for item in value 159 | ] 160 | if self._field_instance.mutable: 161 | return copy.deepcopy(value) 162 | return copy.copy(value) 163 | else: 164 | return [ 165 | self._field_instance.from_storage(item) 166 | for item in value 167 | ] 168 | return [] 169 | 170 | def update_backrefs(self, instance, cached_value, current_value): 171 | if self._field_instance._backref_field_name is None: 172 | return 173 | 174 | for item in current_value: 175 | if self._field_instance.to_storage(item) not in cached_value: 176 | self._field_instance.update_backrefs(instance, None, item) 177 | 178 | for item in cached_value: 179 | if self._field_instance.from_storage(item) not in current_value: 180 | self._field_instance.update_backrefs(instance, item, None) 181 | 182 | def update_backrefs_callback(self, cls, instance, fields_changed, cached_data): 183 | 184 | if not hasattr(self._field_instance, 'update_backrefs'): 185 | return 186 | 187 | if self._field_name not in fields_changed: 188 | return 189 | 190 | cached_value = cached_data.get(self._field_name, []) 191 | current_value = getattr(instance, self._field_name, []) 192 | 193 | self.update_backrefs(instance, cached_value, current_value) 194 | 195 | @property 196 | def base_class(self): 197 | if self._field_instance is None: 198 | return 199 | if not hasattr(self._field_instance, 'base_class'): 200 | return 201 | return self._field_instance.base_class 202 | -------------------------------------------------------------------------------- /modularodm/fields/lists.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import abc 4 | import six 5 | 6 | from modularodm.query.querydialect import DefaultQueryDialect as Q 7 | 8 | 9 | @six.add_metaclass(abc.ABCMeta) 10 | class List(list): 11 | 12 | def __init__(self, value=None, literal=False, **kwargs): 13 | 14 | value = value or [] 15 | self._base_class = kwargs.get('base_class', None) 16 | 17 | if literal: 18 | super(List, self).__init__(value) 19 | else: 20 | super(List, self).__init__() 21 | self.extend(value) 22 | 23 | 24 | class BaseForeignList(List): 25 | 26 | @abc.abstractmethod 27 | def _to_primary_keys(self): 28 | pass 29 | 30 | @abc.abstractmethod 31 | def _from_value(self, value): 32 | pass 33 | 34 | def _to_data(self): 35 | return list(super(BaseForeignList, self).__iter__()) 36 | 37 | def __iter__(self): 38 | if self: 39 | return (self[idx] for idx in range(len(self))) 40 | return iter([]) 41 | 42 | def __setitem__(self, key, value): 43 | super(BaseForeignList, self).__setitem__(key, self._from_value(value)) 44 | 45 | def insert(self, index, value): 46 | super(BaseForeignList, self).insert(index, self._from_value(value)) 47 | 48 | def append(self, value): 49 | super(BaseForeignList, self).append(self._from_value(value)) 50 | 51 | def extend(self, iterable): 52 | for item in iterable: 53 | self.append(item) 54 | 55 | def remove(self, value): 56 | super(BaseForeignList, self).remove(self._from_value(value)) 57 | 58 | 59 | class ForeignList(BaseForeignList): 60 | 61 | def _from_value(self, value): 62 | return self._base_class._to_primary_key(value) 63 | 64 | def _to_primary_keys(self): 65 | return self._to_data() 66 | 67 | def __reversed__(self): 68 | return ForeignList( 69 | 
super(ForeignList, self).__reversed__(), 70 | base_class=self._base_class 71 | ) 72 | 73 | def __getitem__(self, item): 74 | result = super(ForeignList, self).__getitem__(item) 75 | if isinstance(item, slice): 76 | return ForeignList(result, base_class=self._base_class) 77 | return self._base_class.load(result) 78 | 79 | def __getslice__(self, i, j): 80 | result = super(ForeignList, self).__getslice__(i, j) 81 | return ForeignList(result, base_class=self._base_class) 82 | 83 | def __contains__(self, item): 84 | keys = self._to_primary_keys() 85 | if isinstance(item, self._base_class): 86 | return item._primary_key in keys 87 | if isinstance(item, self._base_class._primary_type): 88 | return item in keys 89 | return False 90 | 91 | def index(self, value, start=None, stop=None): 92 | start = 0 if start is None else start 93 | stop = len(self) if stop is None else stop 94 | keys = self._to_primary_keys() 95 | if isinstance(value, self._base_class): 96 | return keys.index(value._primary_key, start, stop) 97 | if isinstance(value, self._base_class._primary_type): 98 | return keys.index(value, start, stop) 99 | raise ValueError('{0} is not in list'.format(value)) 100 | 101 | def find(self, query=None): 102 | combined_query = Q( 103 | self._base_class._primary_name, 104 | 'in', 105 | self._to_primary_keys() 106 | ) 107 | if query is not None: 108 | combined_query = combined_query & query 109 | return self._base_class.find(combined_query) 110 | 111 | 112 | class AbstractForeignList(BaseForeignList): 113 | 114 | def _from_value(self, value): 115 | if hasattr(value, '_primary_key'): 116 | return ( 117 | value._primary_key, 118 | value._name 119 | ) 120 | return value 121 | 122 | def _to_primary_keys(self): 123 | return [ 124 | item[0] 125 | for item in self._to_data() 126 | ] 127 | 128 | def __reversed__(self): 129 | return AbstractForeignList( 130 | super(AbstractForeignList, self).__reversed__() 131 | ) 132 | 133 | def get_foreign_object(self, value): 134 | from modularodm import StoredObject 135 | return StoredObject.get_collection(value[1])\ 136 | .load(value[0]) 137 | 138 | def __getitem__(self, item): 139 | result = super(AbstractForeignList, self).__getitem__(item) 140 | if isinstance(item, slice): 141 | return AbstractForeignList(result) 142 | return self.get_foreign_object(result) 143 | 144 | def __getslice__(self, i, j): 145 | result = super(AbstractForeignList, self).__getslice__(i, j) 146 | return AbstractForeignList(result) 147 | 148 | def __contains__(self, item): 149 | keys = self._to_primary_keys() 150 | if hasattr(item, '_primary_key'): 151 | return item._primary_key in keys 152 | elif isinstance(item, tuple): 153 | return item[0] in keys 154 | return item in keys 155 | 156 | def index(self, value, start=None, stop=None): 157 | start = 0 if start is None else start 158 | stop = len(self) if stop is None else stop 159 | keys = self._to_primary_keys() 160 | if hasattr(value, '_primary_key'): 161 | return keys.index(value._primary_key, start, stop) 162 | elif isinstance(value, tuple): 163 | return keys.index(value[0], start, stop) 164 | else: 165 | try: 166 | return keys.index(value, start, stop) 167 | except ValueError: 168 | raise ValueError('{0} not in list'.format(value)) 169 | -------------------------------------------------------------------------------- /modularodm/fields/objectidfield.py: -------------------------------------------------------------------------------- 1 | from . 
import Field 2 | from ..validators import validate_objectid 3 | 4 | from bson import ObjectId 5 | 6 | class ObjectIdField(Field): 7 | 8 | validate = validate_objectid 9 | data_type = ObjectId -------------------------------------------------------------------------------- /modularodm/fields/stringfield.py: -------------------------------------------------------------------------------- 1 | from six import string_types 2 | 3 | from . import Field 4 | from ..validators import validate_string 5 | 6 | class StringField(Field): 7 | 8 | data_type = string_types[0] 9 | validate = validate_string -------------------------------------------------------------------------------- /modularodm/frozen.py: -------------------------------------------------------------------------------- 1 | 2 | import collections 3 | 4 | def freeze(value): 5 | """ Cast value to its frozen counterpart. """ 6 | if isinstance(value, list): 7 | return FrozenList(*value) 8 | if isinstance(value, dict): 9 | return FrozenDict(**value) 10 | return value 11 | 12 | def thaw(value): 13 | if isinstance(value, FrozenList): 14 | return value.thaw() 15 | if isinstance(value, FrozenDict): 16 | return value.thaw() 17 | return value 18 | 19 | class FrozenDict(collections.Mapping): 20 | """ Immutable dictionary. """ 21 | def __init__(self, **kwargs): 22 | self.__data = {key : freeze(value) for key, value in kwargs.items()} 23 | 24 | def thaw(self): 25 | return {key : thaw(value) for key, value in self.__data.items()} 26 | 27 | def __eq__(self, other): 28 | if not isinstance(other, FrozenDict): 29 | return self.thaw() == other 30 | return super(FrozenDict, self).__eq__(self, other) 31 | 32 | def __getitem__(self, item): 33 | return self.__data[item] 34 | 35 | def __iter__(self): 36 | return iter(self.__data) 37 | 38 | def __len__(self): 39 | return len(self.__data) 40 | 41 | def __repr__(self): 42 | return repr(self.__data) 43 | 44 | class FrozenList(collections.Sequence): 45 | """ Immutable list. 
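    Instances are normally produced via :func:`freeze`; call :meth:`thaw`
    to recover a plain, mutable list.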
""" 46 | def __init__(self, *args): 47 | self.__data = [freeze(value) for value in args] 48 | 49 | def thaw(self): 50 | return [thaw(value) for value in self.__data] 51 | 52 | def __eq__(self, other): 53 | if not isinstance(other, FrozenList): 54 | return self.thaw() == other 55 | return super(FrozenList, self).__eq__(self, other) 56 | 57 | def __getitem__(self, item): 58 | return self.__data[item] 59 | 60 | def __iter__(self): 61 | return iter(self.__data) 62 | 63 | def __len__(self): 64 | return len(self.__data) 65 | 66 | def __repr__(self): 67 | return repr(self.__data) -------------------------------------------------------------------------------- /modularodm/query/__init__.py: -------------------------------------------------------------------------------- 1 | from .query import QueryBase, RawQuery, QueryGroup -------------------------------------------------------------------------------- /modularodm/query/query.py: -------------------------------------------------------------------------------- 1 | 2 | class QueryBase(object): 3 | 4 | def __or__(self, other): 5 | return QueryGroup('or', self, other) 6 | 7 | def __and__(self, other): 8 | return QueryGroup('and', self, other) 9 | 10 | def __invert__(self): 11 | return QueryGroup('not', self) 12 | 13 | class QueryGroup(QueryBase): 14 | 15 | def __init__(self, operator, *args): 16 | 17 | self.operator = operator 18 | 19 | self.nodes = [] 20 | for node in args: 21 | if not isinstance(node, QueryBase): 22 | raise TypeError('Nodes must be Query objects.') 23 | if isinstance(node, QueryGroup) and node.operator == operator: 24 | self.nodes += node.nodes 25 | else: 26 | self.nodes.append(node) 27 | 28 | def __repr__(self): 29 | 30 | return '{0}({1})'.format( 31 | self.operator.upper(), 32 | ', '.join(repr(node) for node in self.nodes) 33 | ) 34 | 35 | class RawQuery(QueryBase): 36 | 37 | def __init__(self, attribute, operator, argument): 38 | 39 | self.attribute = attribute 40 | self.operator = operator 41 | self.argument = argument 42 | 43 | def __repr__(self): 44 | 45 | return 'RawQuery({}, {}, {})'.format( 46 | self.attribute, 47 | self.operator, 48 | self.argument 49 | ) 50 | -------------------------------------------------------------------------------- /modularodm/query/querydialect.py: -------------------------------------------------------------------------------- 1 | 2 | from .query import RawQuery 3 | 4 | class QueryDialect(object): 5 | 6 | pass 7 | 8 | class DefaultQueryDialect(QueryDialect, RawQuery): 9 | 10 | pass 11 | 12 | class DictQueryDialect(QueryDialect): 13 | 14 | pass 15 | 16 | class DunderQueryDialect(QueryDialect): 17 | 18 | pass 19 | 20 | # ''' 21 | # __Q(foo='bar', baz__startswith='fez') 22 | # __Q(__or(foo='bar', baz__startswith='fez')) 23 | # __Q(__and(__or(foo='bar', baz__startswith='fez'), qux__ne='mom')) 24 | # ''' -------------------------------------------------------------------------------- /modularodm/query/queryset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import abc 4 | import six 5 | 6 | 7 | @six.add_metaclass(abc.ABCMeta) 8 | class BaseQuerySet(object): 9 | 10 | _NEGATIVE_INDEXING = False 11 | 12 | def __init__(self, schema, data=None): 13 | 14 | self.schema = schema 15 | self.primary = schema._primary_name 16 | self.data = data 17 | 18 | def __getitem__(self, index): 19 | if isinstance(index, slice): 20 | if index.step: 21 | raise IndexError('Slice steps not supported') 22 | if (index.start is not None and index.start < 0) 
or (index.stop is not None and index.stop < 0): 23 | raise IndexError('Negative indexing not supported') 24 | if index.stop is not None and index.stop < index.start: 25 | raise IndexError('Stop index must be greater than start index') 26 | elif not self.__class__._NEGATIVE_INDEXING and index < 0: 27 | raise IndexError('Negative indexing not supported') 28 | return self._do_getitem(index) 29 | 30 | @abc.abstractmethod 31 | def _do_getitem(self, index): 32 | pass 33 | 34 | @abc.abstractmethod 35 | def __iter__(self): 36 | pass 37 | 38 | @abc.abstractmethod 39 | def __len__(self): 40 | pass 41 | 42 | @abc.abstractmethod 43 | def count(self): 44 | pass 45 | 46 | @abc.abstractmethod 47 | def sort(self, *keys): 48 | pass 49 | 50 | @abc.abstractmethod 51 | def offset(self, n): 52 | pass 53 | 54 | @abc.abstractmethod 55 | def limit(self, n): 56 | pass 57 | -------------------------------------------------------------------------------- /modularodm/signals.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import blinker 4 | 5 | 6 | signals = blinker.Namespace() 7 | 8 | load = signals.signal('load') 9 | 10 | before_save = signals.signal('before_save') 11 | save = signals.signal('save') 12 | -------------------------------------------------------------------------------- /modularodm/storage/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import Storage 2 | from .mongostorage import MongoStorage 3 | from .picklestorage import PickleStorage 4 | from .ephemeralstorage import EphemeralStorage 5 | -------------------------------------------------------------------------------- /modularodm/storage/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import six 3 | import time 4 | import random 5 | import itertools 6 | from functools import wraps 7 | 8 | from ..translators import DefaultTranslator 9 | from modularodm.exceptions import KeyExistsException 10 | 11 | 12 | class Logger(object): 13 | 14 | def __init__(self): 15 | 16 | self.listening = False 17 | self.events = [] 18 | self.xtra = [] 19 | 20 | def listen(self, xtra=None): 21 | 22 | self.xtra.append(xtra) 23 | 24 | if self.listening: 25 | return False 26 | 27 | self.listening = True 28 | self.events = [] 29 | return True 30 | 31 | def record_event(self, event): 32 | 33 | if self.listening: 34 | self.events.append(event) 35 | 36 | def report(self, sort_func=None): 37 | 38 | out = {} 39 | 40 | if sort_func is None: 41 | sort_func = lambda e: e.func.__name__ 42 | 43 | heard = sorted(self.events, key=sort_func) 44 | 45 | for key, group in itertools.groupby(heard, sort_func): 46 | group = list(group) 47 | num_events = len(group) 48 | total_time = sum([event.elapsed_time for event in group]) 49 | out[key] = (num_events, total_time) 50 | 51 | return out 52 | 53 | def pop(self): 54 | 55 | self.xtra.pop() 56 | 57 | def clear(self): 58 | 59 | self.listening = False 60 | self.events = [] 61 | 62 | 63 | class LogEvent(object): 64 | 65 | def __init__(self, func, start_time, stop_time, xtra=None): 66 | 67 | self.func = func 68 | self.start_time = start_time 69 | self.stop_time = stop_time 70 | self.elapsed_time = stop_time - start_time 71 | self.xtra = xtra 72 | 73 | def __repr__(self): 74 | 75 | return 'LogEvent("{func}", {start_time}, {stop_time}, {xtra})'.format( 76 | **self.__dict__ 77 | ) 78 | 79 | 80 | def logify(func): 81 | 82 | @wraps(func) 83 | def wrapped(this, *args, 
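                # ``logify`` is applied to every public Storage method by
                # StorageMeta; it records call counts and elapsed time on the
                # shared class-level Logger, but only while the logger is listening.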
**kwargs): 84 | 85 | # Note: Copy value of `this.logger.listening` here in the event that 86 | # this value is changed externally during the decorated function call. 87 | # TODO: Verify that this produces valid output for concurrent requests 88 | listening = this.logger.listening 89 | 90 | if listening: 91 | start_time = time.time() 92 | 93 | out = func(this, *args, **kwargs) 94 | 95 | if listening: 96 | stop_time = time.time() 97 | # TODO: This is a temporary fix for a suspected concurrency issue. 98 | xtra = this.logger.xtra[-1] if this.logger.xtra else None 99 | this.logger.record_event( 100 | LogEvent( 101 | func, 102 | start_time, 103 | stop_time, 104 | xtra 105 | ) 106 | ) 107 | 108 | return out 109 | 110 | return wrapped 111 | 112 | 113 | class StorageMeta(abc.ABCMeta): 114 | 115 | def __new__(mcs, name, bases, dct): 116 | 117 | # Decorate methods 118 | for key, value in dct.items(): 119 | if hasattr(value, '__call__') \ 120 | and not isinstance(value, type) \ 121 | and not key.startswith('_'): 122 | dct[key] = logify(value) 123 | 124 | # Run super-metaclass __new__ 125 | return super(StorageMeta, mcs).__new__(mcs, name, bases, dct) 126 | 127 | 128 | @six.add_metaclass(StorageMeta) 129 | class Storage(object): 130 | """Abstract base class for storage objects. Subclasses (e.g. 131 | :class:`~modularodm.storage.picklestorage.PickleStorage`, 132 | :class:`~modularodm.storage.mongostorage.MongoStorage`, etc.) 133 | must define insert, update, get, remove, flush, and find_all methods. 134 | """ 135 | translator = DefaultTranslator() 136 | logger = Logger() 137 | 138 | def _ensure_index(self, key): 139 | pass 140 | 141 | # todo allow custom id generator 142 | # todo increment n on repeated failures 143 | def _generate_random_id(self, n=5): 144 | """Generated random alphanumeric key. 145 | 146 | :param n: Number of characters in random key 147 | """ 148 | alphabet = '23456789abcdefghijkmnpqrstuvwxyz' 149 | return ''.join(random.sample(alphabet, n)) 150 | 151 | def _optimistic_insert(self, primary_name, value, n=5): 152 | """Attempt to insert with randomly generated key until insert 153 | is successful. 154 | 155 | :param str primary_name: The name of the primary key. 156 | :param dict value: The dictionary representation of the record. 157 | :param n: Number of characters in random key 158 | """ 159 | while True: 160 | try: 161 | key = self._generate_random_id(n) 162 | value[primary_name] = key 163 | self.insert(primary_name, key, value) 164 | break 165 | except KeyExistsException: 166 | pass 167 | return key 168 | 169 | @abc.abstractmethod 170 | def insert(self, primary_name, key, value): 171 | """Insert a new record. 172 | 173 | :param str primary_name: Name of primary key 174 | :param key: The value of the primary key 175 | :param dict value: The dictionary of attribute:value pairs 176 | """ 177 | pass 178 | 179 | @abc.abstractmethod 180 | def update(self, query, data): 181 | """Update multiple records with new data. 182 | 183 | :param query: A query object. 184 | :param dict data: Dictionary of key:value pairs. 185 | """ 186 | pass 187 | 188 | @abc.abstractmethod 189 | def get(self, primary_name, key): 190 | """Get a single record. 191 | 192 | :param str primary_name: The name of the primary key. 193 | :param key: The value of the primary key. 194 | """ 195 | pass 196 | 197 | @abc.abstractmethod 198 | def remove(self, query=None): 199 | """Remove records. 
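        :param query: A query object; when None, implementations remove
            every record in the collection (mirroring ``find`` with no query).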
200 | """ 201 | pass 202 | 203 | @abc.abstractmethod 204 | def flush(self): 205 | """Flush the database.""" 206 | pass 207 | 208 | @abc.abstractmethod 209 | def find_one(self, query=None, **kwargs): 210 | """ Gets a single object from the collection. 211 | 212 | If no matching documents are found, raises `NoResultsFound`. 213 | If >1 matching documents are found, raises `MultipleResultsFound`. 214 | 215 | :params: One or more `Query` or `QuerySet` objects may be passed 216 | 217 | :returns: The selected document 218 | """ 219 | pass 220 | 221 | @abc.abstractmethod 222 | def find(self, query=None, **kwargs): 223 | """ 224 | Return a generator of query results. Takes optional `by_pk` keyword 225 | argument; if true, return keys rather than 226 | values. 227 | 228 | :param query: 229 | 230 | :return: a generator of :class:`~.storedobject.StoredObject` instances 231 | """ 232 | pass 233 | -------------------------------------------------------------------------------- /modularodm/storage/ephemeralstorage.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | from .picklestorage import PickleStorage 4 | 5 | try: 6 | import cpickle as pickle 7 | except ImportError: 8 | import pickle 9 | 10 | 11 | class EphemeralStorage(PickleStorage): 12 | def __init__(self, *args, **kwargs): 13 | self.store = {} 14 | self.fp = BytesIO() 15 | 16 | def flush(self): 17 | pickle.dump(self.store, self.fp, -1) 18 | -------------------------------------------------------------------------------- /modularodm/storage/mongostorage.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*_ 2 | 3 | import re 4 | import pymongo 5 | 6 | from .base import Storage 7 | from ..query.queryset import BaseQuerySet 8 | from ..query.query import QueryGroup 9 | from ..query.query import RawQuery 10 | from modularodm.exceptions import ( 11 | KeyExistsException, 12 | MultipleResultsFound, 13 | NoResultsFound, 14 | ) 15 | 16 | 17 | # From mongoengine.queryset.transform 18 | COMPARISON_OPERATORS = ('ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod', 19 | 'all', 'size', 'exists', 'not', 'elemMatch') 20 | # GEO_OPERATORS = ('within_distance', 'within_spherical_distance', 21 | # 'within_box', 'within_polygon', 'near', 'near_sphere', 22 | # 'max_distance', 'geo_within', 'geo_within_box', 23 | # 'geo_within_polygon', 'geo_within_center', 24 | # 'geo_within_sphere', 'geo_intersects') 25 | STRING_OPERATORS = ('contains', 'icontains', 'startswith', 26 | 'istartswith', 'endswith', 'iendswith', 27 | 'exact', 'iexact') 28 | # CUSTOM_OPERATORS = ('match',) 29 | # MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS + 30 | # STRING_OPERATORS + CUSTOM_OPERATORS) 31 | 32 | # UPDATE_OPERATORS = ('set', 'unset', 'inc', 'dec', 'pop', 'push', 33 | # 'push_all', 'pull', 'pull_all', 'add_to_set', 34 | # 'set_on_insert') 35 | 36 | 37 | # Adapted from mongoengine.fields 38 | def prepare_query_value(op, value): 39 | 40 | if op.lstrip('i') in ('startswith', 'endswith', 'contains', 'exact'): 41 | flags = 0 42 | if op.startswith('i'): 43 | flags = re.IGNORECASE 44 | op = op.lstrip('i') 45 | 46 | regex = r'%s' 47 | if op == 'startswith': 48 | regex = r'^%s' 49 | elif op == 'endswith': 50 | regex = r'%s$' 51 | elif op == 'exact': 52 | regex = r'^%s$' 53 | 54 | # escape unsafe characters which could lead to a re.error 55 | value = re.escape(value) 56 | value = re.compile(regex % value, flags) 57 | 58 | return value 59 | 60 | 61 | # TODO: Test me 62 
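# translate_query recursively converts RawQuery / QueryGroup trees into
# MongoDB filter documents, e.g. RawQuery('name', 'eq', 'foo') becomes
# {'name': 'foo'} and an 'and' group becomes {'$and': [...]}.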
| def translate_query(query=None, mongo_query=None): 63 | """ 64 | 65 | """ 66 | mongo_query = mongo_query or {} 67 | 68 | if isinstance(query, RawQuery): 69 | attribute, operator, argument = \ 70 | query.attribute, query.operator, query.argument 71 | 72 | if operator == 'eq': 73 | mongo_query[attribute] = argument 74 | 75 | elif operator in COMPARISON_OPERATORS: 76 | mongo_operator = '$' + operator 77 | if attribute not in mongo_query: 78 | mongo_query[attribute] = {} 79 | mongo_query[attribute][mongo_operator] = argument 80 | 81 | elif operator in STRING_OPERATORS: 82 | mongo_operator = '$regex' 83 | mongo_regex = prepare_query_value(operator, argument) 84 | if attribute not in mongo_query: 85 | mongo_query[attribute] = {} 86 | mongo_query[attribute][mongo_operator] = mongo_regex 87 | 88 | elif isinstance(query, QueryGroup): 89 | 90 | if query.operator == 'and': 91 | return {'$and': [translate_query(node) for node in query.nodes]} 92 | 93 | elif query.operator == 'or': 94 | return {'$or' : [translate_query(node) for node in query.nodes]} 95 | 96 | elif query.operator == 'not': 97 | # Hack: A nor A == not A 98 | subquery = translate_query(query.nodes[0]) 99 | return {'$nor' : [subquery, subquery]} 100 | 101 | else: 102 | raise ValueError('QueryGroup operator must be , , or .') 103 | 104 | elif query is None: 105 | return {} 106 | 107 | else: 108 | raise TypeError('Query must be a QueryGroup or Query object.') 109 | 110 | return mongo_query 111 | 112 | 113 | class MongoQuerySet(BaseQuerySet): 114 | 115 | _NEGATIVE_INDEXING = True 116 | 117 | def __init__(self, schema, cursor): 118 | super(MongoQuerySet, self).__init__(schema) 119 | self.data = cursor 120 | self._order = [('_id', 1)] # Default sorting 121 | 122 | def _do_getitem(self, index, raw=False): 123 | if isinstance(index, slice): 124 | return MongoQuerySet(self.schema, self.data.clone()[index]) 125 | if index < 0: 126 | clone = self.data.clone().sort([(o[0], o[1] * -1) for o in self._order]) 127 | result = clone[(index * -1) - 1] 128 | else: 129 | result = self.data[index] 130 | if raw: 131 | return result[self.primary] 132 | return self.schema.load(data=result) 133 | 134 | def __iter__(self, raw=False): 135 | cursor = self.data.clone() 136 | if raw: 137 | return [each[self.primary] for each in cursor] 138 | return (self.schema.load(data=each) for each in cursor) 139 | 140 | def __len__(self): 141 | return self.data.count(with_limit_and_skip=True) 142 | 143 | count = __len__ 144 | 145 | def get_key(self, index): 146 | return self.__getitem__(index, raw=True) 147 | 148 | def get_keys(self): 149 | return list(self.__iter__(raw=True)) 150 | 151 | def sort(self, *keys): 152 | sort_key = [] 153 | 154 | for key in keys: 155 | 156 | if key.startswith('-'): 157 | key = key.lstrip('-') 158 | sign = pymongo.DESCENDING 159 | else: 160 | sign = pymongo.ASCENDING 161 | 162 | sort_key.append((key, sign)) 163 | 164 | self._order = sort_key 165 | self.data = self.data.sort(sort_key) 166 | return self 167 | 168 | def offset(self, n): 169 | self.data = self.data.skip(n) 170 | return self 171 | 172 | def limit(self, n): 173 | self.data = self.data.limit(n) 174 | return self 175 | 176 | 177 | class MongoStorage(Storage): 178 | """Wrap a MongoDB collection. Note: `store` is a property instead of an 179 | attribute to handle passing `db` as a proxy. 
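    A minimal wiring sketch (illustrative only: it assumes a running MongoDB,
    a StoredObject subclass named ``User``, and that schemas are attached via
    ``StoredObject.set_storage``, which is defined outside this module)::

        from pymongo import MongoClient
        from modularodm.storage import MongoStorage

        client = MongoClient()
        User.set_storage(MongoStorage(client['mydb'], 'user'))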
180 | 181 | :param Database db: 182 | :param str collection: 183 | """ 184 | QuerySet = MongoQuerySet 185 | 186 | def __init__(self, db, collection): 187 | self.db = db 188 | self.collection = collection 189 | 190 | @property 191 | def store(self): 192 | return self.db[self.collection] 193 | 194 | def _ensure_index(self, key): 195 | self.store.ensure_index(key) 196 | 197 | def find(self, query=None, **kwargs): 198 | mongo_query = translate_query(query) 199 | return self.store.find(mongo_query) 200 | 201 | def find_one(self, query=None, **kwargs): 202 | mongo_query = translate_query(query) 203 | matches = self.store.find(mongo_query).limit(2) 204 | 205 | if matches.count() == 1: 206 | return matches[0] 207 | 208 | if matches.count() == 0: 209 | raise NoResultsFound() 210 | 211 | raise MultipleResultsFound( 212 | 'Query for find_one must return exactly one result; ' 213 | 'returned {0}'.format(matches.count()) 214 | ) 215 | 216 | def get(self, primary_name, key): 217 | return self.store.find_one({primary_name : key}) 218 | 219 | def insert(self, primary_name, key, value): 220 | if primary_name not in value: 221 | value = value.copy() 222 | value[primary_name] = key 223 | try: 224 | self.store.insert(value) 225 | except pymongo.errors.DuplicateKeyError: 226 | raise KeyExistsException 227 | 228 | def update(self, query, data): 229 | mongo_query = translate_query(query) 230 | 231 | # Field "_id" shouldn't appear in both search and update queries; else 232 | # MongoDB will raise a "Mod on _id not allowed" error 233 | if '_id' in mongo_query: 234 | update_data = {k: v for k, v in data.items() if k != '_id'} 235 | else: 236 | update_data = data 237 | update_query = {'$set': update_data} 238 | 239 | self.store.update( 240 | mongo_query, 241 | update_query, 242 | upsert=False, 243 | multi=True, 244 | ) 245 | 246 | def remove(self, query=None): 247 | mongo_query = translate_query(query) 248 | self.store.remove(mongo_query) 249 | 250 | def flush(self): 251 | pass 252 | -------------------------------------------------------------------------------- /modularodm/storage/picklestorage.py: -------------------------------------------------------------------------------- 1 | # -*- coding utf-8 -*- 2 | 3 | import os 4 | import copy 5 | 6 | import six 7 | 8 | from .base import Storage 9 | from ..query.queryset import BaseQuerySet 10 | from ..query.query import QueryGroup 11 | from ..query.query import RawQuery 12 | 13 | from modularodm.utils import DirtyField 14 | from modularodm.exceptions import ( 15 | KeyExistsException, 16 | MultipleResultsFound, 17 | NoResultsFound, 18 | ) 19 | 20 | try: 21 | import cpickle as pickle 22 | except ImportError: 23 | import pickle 24 | 25 | 26 | def _eq(data, test): 27 | if isinstance(data, list): 28 | return test in data 29 | return data == test 30 | 31 | operators = { 32 | 33 | 'eq': _eq, 34 | 35 | 'ne': lambda data, test: data != test, 36 | 'gt': lambda data, test: data > test, 37 | 'gte': lambda data, test: data >= test, 38 | 'lt': lambda data, test: data < test, 39 | 'lte': lambda data, test: data <= test, 40 | 'in': lambda data, test: data in test, 41 | 'nin': lambda data, test: data not in test, 42 | 43 | 'startswith': lambda data, test: data.startswith(test), 44 | 'endswith': lambda data, test: data.endswith(test), 45 | 'contains': lambda data, test: test in data, 46 | 'icontains': lambda data, test: test.lower() in data.lower(), 47 | 48 | } 49 | 50 | 51 | class PickleQuerySet(BaseQuerySet): 52 | 53 | _sort = DirtyField(None) 54 | _offset = DirtyField(None) 55 | 
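    # Assigning to any of these DirtyFields flips ``self._dirty`` via the
    # callback in modularodm/utils.py, so sort/offset/limit are applied
    # lazily in ``_eval`` the next time the queryset is read.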
_limit = DirtyField(None) 56 | 57 | def __init__(self, schema, data): 58 | 59 | super(PickleQuerySet, self).__init__(schema) 60 | 61 | self._data = list(data) 62 | self._dirty = True 63 | 64 | self.data = [] 65 | 66 | def _eval(self): 67 | 68 | if self._dirty: 69 | 70 | self.data = self._data[:] 71 | 72 | if self._sort is not None: 73 | 74 | for key in self._sort[::-1]: 75 | 76 | if key.startswith('-'): 77 | reverse = True 78 | key = key.lstrip('-') 79 | else: 80 | reverse = False 81 | 82 | self.data = sorted( 83 | self.data, 84 | key=lambda record: record[key], 85 | reverse=reverse 86 | ) 87 | 88 | if self._offset is not None: 89 | self.data = self.data[self._offset:] 90 | 91 | if self._limit is not None: 92 | self.data = self.data[:self._limit] 93 | 94 | self._dirty = False 95 | 96 | return self 97 | 98 | def _do_getitem(self, index, raw=False): 99 | self._eval() 100 | if isinstance(index, slice): 101 | return PickleQuerySet(self.schema, self.data[index]) 102 | key = self.data[index][self.primary] 103 | result = self.data[index] 104 | if raw: 105 | return result[self.primary] 106 | return self.schema.load(data=result) 107 | 108 | def __iter__(self, raw=False): 109 | self._eval() 110 | if raw: 111 | return [each[self.primary] for each in self.data] 112 | return (self.schema.load(data=each) for each in self.data) 113 | 114 | def __len__(self): 115 | self._eval() 116 | return len(self.data) 117 | 118 | count = __len__ 119 | 120 | def get_key(self, index): 121 | return self.__getitem__(index, raw=True) 122 | 123 | def get_keys(self): 124 | return list(self.__iter__(raw=True)) 125 | 126 | def sort(self, *keys): 127 | """ Iteratively sort data by keys in reverse order. """ 128 | self._sort = keys 129 | return self 130 | 131 | def offset(self, n): 132 | self._offset = n 133 | return self 134 | 135 | def limit(self, n): 136 | self._limit = n 137 | return self 138 | 139 | 140 | class PickleStorage(Storage): 141 | """ Storage backend using pickle. """ 142 | 143 | QuerySet = PickleQuerySet 144 | 145 | def __init__(self, collection_name, prefix='db_', ext='pkl'): 146 | """Build pickle file name and load data if exists. 147 | 148 | :param collection_name: Collection name 149 | :param prefix: File prefix. 150 | :param ext: File extension. 151 | 152 | """ 153 | # Build filename 154 | filename = collection_name + '.' + ext 155 | if prefix: 156 | self.filename = prefix + filename 157 | else: 158 | self.filename = filename 159 | 160 | # Initialize empty store 161 | self.store = {} 162 | 163 | # Load file if exists 164 | if os.path.exists(self.filename): 165 | with open(self.filename, 'rb') as fp: 166 | data = fp.read() 167 | self.store = pickle.loads(data) 168 | 169 | def _delete_file(self): 170 | try: 171 | os.remove(self.filename) 172 | except OSError: 173 | pass 174 | 175 | def insert(self, primary_name, key, value): 176 | if key not in self.store: 177 | self.store[key] = copy.deepcopy(value) 178 | self.flush() 179 | else: 180 | msg = 'Key ({key}) already exists'.format(key=key) 181 | raise KeyExistsException(msg) 182 | 183 | def update(self, query, data): 184 | data = copy.deepcopy(data) 185 | for pk in self.find(query, by_pk=True): 186 | for key, value in data.items(): 187 | self.store[pk][key] = value 188 | 189 | def get(self, primary_name, key): 190 | data = self.store.get(key) 191 | if data is not None: 192 | return copy.deepcopy(data) 193 | 194 | def _remove_by_pk(self, key, flush=True): 195 | """Retrieve value from store. 
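        :param flush: When true (the default), the store is written back to
            disk after the record under ``key`` is deleted.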
196 | 197 | :param key: Key 198 | 199 | """ 200 | try: 201 | del self.store[key] 202 | except Exception as error: 203 | pass 204 | if flush: 205 | self.flush() 206 | 207 | def remove(self, query=None): 208 | for key in self.find(query, by_pk=True): 209 | self._remove_by_pk(key, flush=False) 210 | self.flush() 211 | 212 | def flush(self): 213 | with open(self.filename, 'wb') as fp: 214 | pickle.dump(self.store, fp, -1) 215 | 216 | def find_one(self, query=None, **kwargs): 217 | results = list(self.find(query)) 218 | if len(results) == 1: 219 | return results[0] 220 | elif len(results) == 0: 221 | raise NoResultsFound() 222 | else: 223 | raise MultipleResultsFound( 224 | 'Query for find_one must return exactly one result; ' 225 | 'returned {0}'.format(len(results)) 226 | ) 227 | 228 | def _match(self, value, query): 229 | 230 | if isinstance(query, QueryGroup): 231 | 232 | matches = [self._match(value, node) for node in query.nodes] 233 | 234 | if query.operator == 'and': 235 | return all(matches) 236 | elif query.operator == 'or': 237 | return any(matches) 238 | elif query.operator == 'not': 239 | return not any(matches) 240 | else: 241 | raise ValueError('QueryGroup operator must be , , or .') 242 | 243 | elif isinstance(query, RawQuery): 244 | attribute, operator, argument = \ 245 | query.attribute, query.operator, query.argument 246 | 247 | return operators[operator](value[attribute], argument) 248 | 249 | else: 250 | raise TypeError('Query must be a QueryGroup or Query object.') 251 | 252 | def find(self, query=None, **kwargs): 253 | if query is None: 254 | for key, value in six.iteritems(self.store): 255 | yield value 256 | else: 257 | # TODO: Making this a generator breaks it, since it can change 258 | for key, value in list(six.iteritems(self.store)): 259 | if self._match(value, query): 260 | if kwargs.get('by_pk'): 261 | yield key 262 | else: 263 | yield value 264 | -------------------------------------------------------------------------------- /modularodm/translators/__init__.py: -------------------------------------------------------------------------------- 1 | from dateutil import parser as dateparser 2 | from bson import ObjectId 3 | 4 | class DefaultTranslator(object): 5 | 6 | null_value = None 7 | 8 | to_default = None 9 | from_default = None 10 | 11 | class JSONTranslator(DefaultTranslator): 12 | 13 | def to_datetime(self, value): 14 | return str(value) 15 | 16 | def from_datetime(self, value): 17 | return dateparser.parse(value) 18 | 19 | def to_ObjectId(self, value): 20 | return str(value) 21 | 22 | def from_ObjectId(self, value): 23 | return ObjectId(value) 24 | 25 | class StringTranslator(JSONTranslator): 26 | 27 | null_value = 'none' 28 | 29 | def to_default(self, value): 30 | return str(value) 31 | 32 | def from_int(self, value): 33 | return int(value) 34 | 35 | def from_float(self, value): 36 | return float(value) 37 | 38 | def from_bool(self, value): 39 | return bool(value) -------------------------------------------------------------------------------- /modularodm/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import weakref 4 | 5 | 6 | # TODO: Test me @jmcarp 7 | class CallbackField(object): 8 | 9 | def __init__(self, default, callback): 10 | self.data = weakref.WeakKeyDictionary() 11 | self.default = default 12 | self.callback = callback 13 | 14 | def __get__(self, instance, owner): 15 | try: 16 | return self.data[instance] 17 | except KeyError: 18 | return self.default 19 | 20 | def 
__set__(self, instance, value): 21 | current = self.__get__(instance, None) 22 | self.data[instance] = value 23 | self.callback(instance, value, current) 24 | 25 | 26 | def set_dirty_factory(field='_dirty'): 27 | def set_dirty(instance, new, current): 28 | if new != current: 29 | setattr(instance, field, True) 30 | return set_dirty 31 | 32 | 33 | class DirtyField(CallbackField): 34 | 35 | def __init__(self, default, field='_dirty'): 36 | super(DirtyField, self).__init__(default, set_dirty_factory(field)) 37 | -------------------------------------------------------------------------------- /modularodm/validators/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import six 3 | from six.moves.urllib.parse import urlsplit, urlunsplit 4 | 5 | from modularodm.exceptions import ( 6 | ValidationError, 7 | ValidationTypeError, 8 | ValidationValueError, 9 | ) 10 | 11 | from bson import ObjectId 12 | 13 | class Validator(object): 14 | 15 | def __eq__(self, other): 16 | return self.__dict__ == other.__dict__ 17 | 18 | class TypeValidator(Validator): 19 | 20 | def _as_list(self, value): 21 | 22 | if isinstance(value, (tuple, list)): 23 | return value 24 | return [value] 25 | 26 | def __init__(self, allowed_types, forbidden_types=None): 27 | self.allowed_types = self._as_list(allowed_types) 28 | self.forbidden_types = self._as_list(forbidden_types) if forbidden_types else [] 29 | 30 | def __call__(self, value): 31 | 32 | for ftype in self.forbidden_types: 33 | if isinstance(value, ftype): 34 | self._raise(value) 35 | 36 | for atype in self.allowed_types: 37 | if isinstance(value, atype): 38 | return 39 | 40 | self._raise(value) 41 | 42 | def _raise(self, value): 43 | 44 | raise ValidationTypeError( 45 | 'Received invalid value {} of type {}'.format( 46 | value, type(value) 47 | ) 48 | ) 49 | 50 | validate_string = TypeValidator(six.string_types) 51 | validate_integer = TypeValidator( 52 | allowed_types=int, 53 | forbidden_types=bool 54 | ) 55 | validate_float = TypeValidator(float) 56 | validate_boolean = TypeValidator(bool) 57 | validate_objectid = TypeValidator(ObjectId) 58 | 59 | from ..fields.lists import List 60 | validate_list = TypeValidator(List) 61 | 62 | import datetime 63 | validate_datetime = TypeValidator(datetime.datetime) 64 | 65 | # Adapted from Django RegexValidator 66 | import re 67 | class RegexValidator(Validator): 68 | 69 | def __init__(self, regex=None, flags=0): 70 | 71 | if regex is not None: 72 | self.regex = re.compile(regex, flags=flags) 73 | 74 | def __call__(self, value): 75 | 76 | if not self.regex.findall(value): 77 | raise ValidationError( 78 | u'Value must match regex {0} and flags {1}; received value <{2}>'.format( 79 | self.regex.pattern, 80 | self.regex.flags, 81 | value 82 | ) 83 | ) 84 | 85 | # Adapted from Django URLValidator v1.10 86 | # https://docs.djangoproject.com/en/1.10/_modules/django/core/validators/#URLValidator 87 | class URLValidator(RegexValidator): 88 | ul = u'\u00a1-\uffff' # unicode letters range, must be a unicode string 89 | 90 | # IP patterns 91 | ipv4_re = r'(?:25[0-5]|2[0-4]\d|[0-1]?\d?\d)(?:\.(?:25[0-5]|2[0-4]\d|[0-1]?\d?\d)){3}' 92 | ipv6_re = r'\[[0-9a-f:\.]+\]' # (simple regex, validated later) 93 | 94 | # Host patterns 95 | hostname_re = r'[a-z' + ul + r'0-9](?:[a-z' + ul + r'0-9-]{0,61}[a-z' + ul + r'0-9])?' 96 | # Max length for domain name labels is 63 characters per RFC 1034 sec. 3.1 97 | domain_re = r'(?:\.(?!-)[a-z' + ul + r'0-9-]{1,63}(? 
ACE 144 | except UnicodeError: # invalid domain part 145 | raise e 146 | url = urlunsplit((scheme, netloc, path, query, fragment)) 147 | super(URLValidator, self).__call__(url) 148 | else: 149 | raise 150 | else: 151 | # Now verify IPv6 in the netloc part 152 | host_match = re.search(r'^\[(.+)\](?::\d{2,5})?$', urlsplit(value).netloc) 153 | if host_match: 154 | potential_ip = host_match.groups()[0] 155 | try: 156 | validate_ipv6_address(potential_ip) 157 | except ValidationError: 158 | raise ValidationError(self.message, code=self.code) 159 | url = value 160 | 161 | # The maximum length of a full host name is 253 characters per RFC 1034 162 | # section 3.1. It's defined to be 255 bytes or less, but this includes 163 | # one byte for the length of the name and one byte for the trailing dot 164 | # that's used to indicate absolute names in DNS. 165 | if len(urlsplit(value).netloc) > 253: 166 | raise ValidationError(self.message, code=self.code) 167 | 168 | class BaseValidator(Validator): 169 | 170 | compare = lambda self, a, b: a is not b 171 | clean = lambda self, x: x 172 | message = 'Ensure this value is %(limit_value)s (it is %(show_value)s).' 173 | code = 'limit_value' 174 | 175 | def __init__(self, limit_value): 176 | self.limit_value = limit_value 177 | 178 | def __call__(self, value): 179 | cleaned = self.clean(value) 180 | params = {'limit_value': self.limit_value, 'show_value': cleaned} 181 | if self.compare(cleaned, self.limit_value): 182 | raise ValidationValueError(self.message.format(**params)) 183 | 184 | 185 | class MaxValueValidator(BaseValidator): 186 | 187 | compare = lambda self, a, b: a > b 188 | message = 'Ensure this value is less than or equal to {limit_value}.' 189 | code = 'max_value' 190 | 191 | 192 | class MinValueValidator(BaseValidator): 193 | 194 | compare = lambda self, a, b: a < b 195 | message = 'Ensure this value is greater than or equal to {limit_value}.' 196 | code = 'min_value' 197 | 198 | 199 | class MinLengthValidator(BaseValidator): 200 | 201 | compare = lambda self, a, b: a < b 202 | clean = lambda self, x: len(x) 203 | message = 'Ensure this value has length of at least {limit_value} (it has length {show_value}).' 204 | code = 'min_length' 205 | 206 | 207 | class MaxLengthValidator(BaseValidator): 208 | 209 | compare = lambda self, a, b: a > b 210 | clean = lambda self, x: len(x) 211 | message = 'Ensure this value has length of at most {limit_value} (it has length {show_value}).' 
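    # e.g. MaxLengthValidator(5)('abcdefgh') raises ValidationValueError,
    # since clean() maps the value to len(value) == 8 before comparing.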
212 | code = 'max_length' 213 | -------------------------------------------------------------------------------- /modularodm/writequeue.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import logging 4 | import collections 5 | 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class WriteAction(object): 11 | 12 | def __init__(self, method, *args, **kwargs): 13 | if not callable(method): 14 | raise ValueError('Argument `method` must be callable') 15 | self.method = method 16 | # Note: `args` and `kwargs` must not be mutated after an action is 17 | # enqueued and before it is committed, else awful things can happen 18 | self.args = args 19 | self.kwargs = kwargs 20 | 21 | def execute(self): 22 | return self.method(*self.args, **self.kwargs) 23 | 24 | def __repr__(self): 25 | return '{0}(*{1}, **{2})'.format( 26 | self.method.__name__, 27 | self.args, 28 | self.kwargs 29 | ) 30 | 31 | 32 | class WriteQueue(object): 33 | 34 | def __init__(self): 35 | self.active = False 36 | self.actions = collections.deque() 37 | 38 | def start(self): 39 | if self.active: 40 | logger.warn('Already working in a write queue. Further writes ' 41 | 'will be appended to the current queue.') 42 | self.active = True 43 | 44 | def push(self, action): 45 | if not self.active: 46 | raise ValueError('Cannot push unless queue is active') 47 | if not isinstance(action, WriteAction): 48 | raise TypeError('Argument `action` must be instance ' 49 | 'of `WriteAction`') 50 | self.actions.append(action) 51 | 52 | def commit(self): 53 | if not self.active: 54 | raise ValueError('Cannot commit unless queue is active') 55 | results = [] 56 | while self.actions: 57 | action = self.actions.popleft() 58 | results.append(action.execute()) 59 | return results 60 | 61 | def clear(self): 62 | self.active = False 63 | self.actions = collections.deque() 64 | 65 | def __nonzero__(self): 66 | return bool(self.actions) 67 | 68 | # Python 3 69 | __bool__ = __nonzero__ 70 | 71 | 72 | class QueueContext(object): 73 | 74 | def __init__(self, BaseSchema): 75 | self.BaseSchema = BaseSchema 76 | 77 | def __enter__(self): 78 | self.BaseSchema.start_queue() 79 | 80 | def __exit__(self, exc_type, exc_val, exc_tb): 81 | self.BaseSchema.commit_queue() 82 | -------------------------------------------------------------------------------- /pk_sandbox.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cos-archives/modular-odm/8a34891892b8af69b21fdc46701c91763a5c1cf9/pk_sandbox.py -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | blinker 2 | pymongo 3 | python-dateutil 4 | Werkzeug==0.11.11 5 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [wheel] 2 | universal = 1 3 | 4 | [flake8] 5 | ignore = E127,E128,E265,E302,N803,N804,N806 6 | max-line-length = 90 7 | exclude = .git,.ropeproject,.tox,docs,.git,build,setup.py,env,venv 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import re 3 | import sys 4 | import subprocess 5 | 6 | import pip 7 | from setuptools import setup, find_packages 8 | 9 | 10 | 
def parse_requirements(requirements): 11 | with open(requirements) as f: 12 | return [ 13 | l.strip('\n') for l 14 | in f if l.strip('\n') and not l.startswith('#') 15 | ] 16 | 17 | 18 | def find_version(fname): 19 | '''Attempts to find the version number in the file names fname. 20 | Raises RuntimeError if not found. 21 | ''' 22 | version = '' 23 | with open(fname, 'r') as fp: 24 | reg = re.compile(r'__version__ = [\'"]([^\'"]*)[\'"]') 25 | for line in fp: 26 | m = reg.match(line) 27 | if m: 28 | version = m.group(1) 29 | break 30 | if not version: 31 | raise RuntimeError('Cannot find version information') 32 | return version 33 | 34 | 35 | __version__ = find_version("modularodm/__init__.py") 36 | 37 | PUBLISH_CMD = "python setup.py register sdist bdist_wheel upload" 38 | TEST_PUBLISH_CMD = 'python setup.py register -r test sdist bdist_wheel upload -r test' 39 | TEST_CMD = 'nosetests' 40 | 41 | if 'publish' in sys.argv: 42 | try: 43 | __import__('wheel') 44 | except ImportError: 45 | print("wheel required. Run `pip install wheel`.") 46 | sys.exit(1) 47 | status = subprocess.call(PUBLISH_CMD, shell=True) 48 | sys.exit(status) 49 | 50 | if 'publish_test' in sys.argv: 51 | try: 52 | __import__('wheel') 53 | except ImportError: 54 | print("wheel required. Run `pip install wheel`.") 55 | sys.exit(1) 56 | status = subprocess.call(TEST_PUBLISH_CMD, shell=True) 57 | sys.exit() 58 | 59 | if 'run_tests' in sys.argv: 60 | try: 61 | __import__('nose') 62 | except ImportError: 63 | print('nose required. Run `pip install nose`.') 64 | sys.exit(1) 65 | 66 | status = subprocess.call(TEST_CMD, shell=True) 67 | sys.exit(status) 68 | 69 | 70 | def read(fname): 71 | with open(fname) as fp: 72 | content = fp.read() 73 | return content 74 | 75 | setup( 76 | name='modular-odm', 77 | version=__version__, 78 | classifiers=[ 79 | "Development Status :: 3 - Alpha", 80 | "Intended Audience :: Developers", 81 | "Programming Language :: Python :: 2", 82 | "Programming Language :: Python :: 2.7", 83 | ], 84 | url="https://github.com/CenterForOpenScience/modular-odm", 85 | author='Center for Open Science', 86 | author_email='contact@centerforopenscience.org', 87 | zip_safe=False, 88 | description='A Pythonic Object Data Manager', 89 | long_description=read("README.rst"), 90 | packages=find_packages(exclude=("test*",)), 91 | install_requires=parse_requirements('requirements.txt'), 92 | tests_require=["nose"], 93 | keywords=["odm", "nosql", "mongo", "mongodb"], 94 | ) 95 | -------------------------------------------------------------------------------- /tasks.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | from invoke import task, run 4 | 5 | docs_dir = 'docs' 6 | build_dir = os.path.join(docs_dir, '_build') 7 | 8 | @task 9 | def mongo(daemon=False, port=20771): 10 | '''Run the mongod process. 
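    Typically run from the shell, e.g. ``invoke mongo --daemon --port 20771``
    (Invoke turns the keyword arguments below into command-line flags).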
11 | ''' 12 | cmd = "mongod --port {0}".format(port) 13 | if daemon: 14 | cmd += " --fork" 15 | run(cmd) 16 | 17 | @task 18 | def test(coverage=False, browse=False): 19 | command = "nosetests" 20 | if coverage: 21 | command += " --with-coverage --cover-html" 22 | run(command, pty=True) 23 | if coverage and browse: 24 | run("open cover/index.html") 25 | 26 | @task 27 | def clean(): 28 | run("rm -rf build") 29 | run("rm -rf dist") 30 | run("rm -rf marshmallow.egg-info") 31 | clean_docs() 32 | print("Cleaned up.") 33 | 34 | @task 35 | def clean_docs(): 36 | run("rm -rf %s" % build_dir) 37 | 38 | @task 39 | def browse_docs(): 40 | run("open %s" % os.path.join(build_dir, 'index.html')) 41 | 42 | @task 43 | def docs(clean=False, browse=False): 44 | if clean: 45 | clean_docs() 46 | run("sphinx-build %s %s" % (docs_dir, build_dir), pty=True) 47 | if browse: 48 | browse_docs() 49 | 50 | @task 51 | def readme(): 52 | run("rst2html.py README.rst > README.html", pty=True) 53 | run("open README.html") 54 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cos-archives/modular-odm/8a34891892b8af69b21fdc46701c91763a5c1cf9/tests/__init__.py -------------------------------------------------------------------------------- /tests/backrefs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cos-archives/modular-odm/8a34891892b8af69b21fdc46701c91763a5c1cf9/tests/backrefs/__init__.py -------------------------------------------------------------------------------- /tests/backrefs/test_abstract_backrefs.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from modularodm import fields 4 | 5 | from tests.base import ModularOdmTestCase, TestObject 6 | 7 | class OneToManyFieldTestCase(ModularOdmTestCase): 8 | 9 | def define_objects(self): 10 | 11 | class Foo(TestObject): 12 | _id = fields.IntegerField() 13 | 14 | class Bar(TestObject): 15 | _id = fields.IntegerField() 16 | ref = fields.ForeignField('foo', backref='food') 17 | 18 | class Baz(TestObject): 19 | _id = fields.IntegerField() 20 | ref = fields.ForeignField('foo', backref='food') 21 | 22 | return Foo, Bar, Baz 23 | 24 | def set_up_objects(self): 25 | 26 | self.foo = self.Foo(_id=1) 27 | self.foo.save() 28 | 29 | self.bar = self.Bar(_id=2, ref=self.foo) 30 | self.baz = self.Bar(_id=3, ref=self.foo) 31 | 32 | self.bar.save() 33 | self.baz.save() 34 | 35 | def test_abstract_backrefs(self): 36 | 37 | backrefs = self.foo.food 38 | self.assertIn(self.bar, backrefs) 39 | self.assertIn(self.baz, backrefs) 40 | -------------------------------------------------------------------------------- /tests/backrefs/test_attribute_syntax.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import six 4 | 5 | from modularodm import fields 6 | from modularodm.storedobject import ContextLogger 7 | from modularodm import StoredObject 8 | from modularodm.exceptions import ModularOdmException 9 | 10 | from tests.base import ( 11 | ModularOdmTestCase 12 | ) 13 | 14 | class OneToManyFieldTestCase(ModularOdmTestCase): 15 | 16 | def define_objects(self): 17 | 18 | class Foo(StoredObject): 19 | _meta = { 20 | 'optimistic': True 21 | } 22 | _id = fields.StringField() 23 | my_bar = fields.ForeignField('Bar', 
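            # both ForeignFields point at Bar under the same backref key, so
            # bar.foo__my_foos aggregates references made through either field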
backref='my_foos') 24 | my_other_bar = fields.ForeignField('Bar', backref='my_foos') 25 | 26 | class Bar(StoredObject): 27 | _meta = {'optimistic': True} 28 | _id = fields.StringField() 29 | 30 | return Foo, Bar 31 | 32 | def set_up_objects(self): 33 | 34 | self.bar = self.Bar() 35 | self.bar.save() 36 | 37 | self.foos = [] 38 | for i in range(5): 39 | foo = self.Foo() 40 | if i > 0: 41 | foo.my_bar = self.bar 42 | else: 43 | foo.my_other_bar = self.bar 44 | foo.save() 45 | self.foos.append(foo) 46 | 47 | def test_dunder_br_returns_foreignlist(self): 48 | self.assertIs( 49 | type(self.bar.foo__my_foos), 50 | fields.ForeignList 51 | ) 52 | 53 | def test_dunder_br_returns_correct(self): 54 | self.assertEqual( 55 | len(self.bar.foo__my_foos), 56 | 5 57 | ) 58 | 59 | @unittest.skip('Behavior not defined.') 60 | def test_dunder_br_unknown_field(self): 61 | with self.assertRaises(KeyError): 62 | self.bar.foo__not_a_real_key 63 | 64 | def test_dunder_br_unknown_node(self): 65 | with self.assertRaises(ModularOdmException): 66 | self.bar.not_a_real_node__foo 67 | 68 | def test_dunder_br_parent_field_correct(self): 69 | self.assertEqual( 70 | len(self.bar.foo__my_foos__my_other_bar), 71 | 1 72 | ) 73 | 74 | def test_dunder_br_laziness(self): 75 | StoredObject._clear_caches() 76 | 77 | with ContextLogger() as c: 78 | # get the Bar object 79 | bar = self.Bar.find_one() 80 | # access the ForeignList 81 | bar.foo__my_foos 82 | 83 | # Two calls so far - .find_one() and .find() 84 | self.assertNotIn( 85 | 'foo', 86 | [k[0] for k, v in six.iteritems(c.report())], 87 | ) 88 | 89 | # access a member of the ForeignList, forcing that member to load 90 | bar.foo__my_foos[0] 91 | 92 | # now there should be a call to Foo.get() 93 | self.assertEqual( 94 | c.report()[('foo', 'get')][0], 95 | 1 96 | ) 97 | 98 | class OneToManyAbstractFieldTestCase(ModularOdmTestCase): 99 | 100 | def define_objects(self): 101 | 102 | class Foo(StoredObject): 103 | _meta = { 104 | 'optimistic': True 105 | } 106 | _id = fields.StringField() 107 | my_abstract = fields.AbstractForeignField(backref='my_foos') 108 | my_other_abstract = fields.AbstractForeignField(backref='my_foos') 109 | 110 | class Bar(StoredObject): 111 | _meta = {'optimistic': True} 112 | _id = fields.StringField() 113 | 114 | class Bob(StoredObject): 115 | _meta = {'optimistic': True} 116 | _id = fields.StringField() 117 | 118 | return Foo, Bar, Bob 119 | 120 | def set_up_objects(self): 121 | 122 | self.bar = self.Bar() 123 | self.bar.save() 124 | 125 | self.bob = self.Bob() 126 | self.bob.save() 127 | 128 | self.foos = [] 129 | for i in range(5): 130 | foo = self.Foo() 131 | if i > 0: 132 | foo.my_abstract = self.bar 133 | foo.my_other_abstract = self.bob 134 | else: 135 | foo.my_abstract = self.bob 136 | foo.my_other_abstract = self.bar 137 | foo.save() 138 | self.foos.append(foo) 139 | 140 | def test_dunder_br_returns_foreignlist(self): 141 | self.assertIs( 142 | type(self.bar.foo__my_foos), 143 | fields.ForeignList 144 | ) 145 | 146 | def test_dunder_br_returns_correct(self): 147 | self.assertEqual( 148 | len(self.bar.foo__my_foos), 149 | 5 150 | ) 151 | 152 | @unittest.skip('Behavior not defined.') 153 | def test_dunder_br_unknown_field(self): 154 | with self.assertRaises(KeyError): 155 | self.bar.foo__not_a_real_key 156 | 157 | def test_dunder_br_unknown_node(self): 158 | with self.assertRaises(ModularOdmException): 159 | self.bar.not_a_real_node__foo 160 | 161 | def test_dunder_br_parent_field_correct(self): 162 | self.assertEqual( 163 | 
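            # the extra '__' segment narrows the backrefs to the single Foo
            # whose my_other_abstract points back at this Bar, hence length 1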
len(self.bar.foo__my_foos__my_other_abstract), 164 | 1 165 | ) 166 | 167 | def test_dunder_br_laziness(self): 168 | StoredObject._clear_caches() 169 | 170 | with ContextLogger() as c: 171 | # get the Bar object 172 | bar = self.Bar.find_one() 173 | # access the ForeignList 174 | bar.foo__my_foos 175 | 176 | # Two calls so far - .find_one() and .find() 177 | self.assertNotIn( 178 | 'foo', 179 | [k[0] for k, v in six.iteritems(c.report())], 180 | ) 181 | 182 | # access a member of the ForeignList, forcing that member to load 183 | bar.foo__my_foos[0] 184 | 185 | # now there should be a call to Foo.get() 186 | self.assertEqual( 187 | c.report()[('foo', 'get')][0], 188 | 1 189 | ) 190 | -------------------------------------------------------------------------------- /tests/backrefs/test_ensure_backrefs.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from nose.tools import * # PEP8 asserts 3 | 4 | from modularodm import StoredObject, fields 5 | from modularodm.storedobject import ensure_backrefs 6 | from tests.base import ModularOdmTestCase 7 | 8 | class TestEnsureBackrefs(ModularOdmTestCase): 9 | 10 | def define_objects(self): 11 | 12 | class Foo(StoredObject): 13 | _id = fields.StringField() 14 | _meta = { 15 | 'optimistic': True, 16 | } 17 | 18 | class Bar(StoredObject): 19 | _id = fields.StringField() 20 | ref = fields.ForeignField('foo', backref='my_ref') 21 | abs_ref = fields.AbstractForeignField(backref='my_abs_ref') 22 | ref_list = fields.ForeignField('foo', backref='my_ref_list', list=True) 23 | abs_ref_list = fields.AbstractForeignField(backref='my_abs_ref_list', list=True) 24 | _meta = { 25 | 'optimistic': True, 26 | } 27 | 28 | return Foo, Bar 29 | 30 | def set_up_objects(self): 31 | 32 | self.foos = [] 33 | for _ in range(5): 34 | foo = self.Foo() 35 | foo.save() 36 | self.foos.append(foo) 37 | 38 | def test_ensure_foreign(self): 39 | 40 | bar = self.Bar(ref=self.foos[0]) 41 | bar.save() 42 | 43 | # Delete backrefs for some reason 44 | self.foos[0]._StoredObject__backrefs = {} 45 | self.foos[0].save() 46 | 47 | # Assert that backrefs are gone 48 | assert_equal( 49 | len(self.foos[0].bar__my_ref), 50 | 0 51 | ) 52 | 53 | # Restore backrefs 54 | ensure_backrefs(bar) 55 | 56 | # Assert that backrefs are correct 57 | assert_equal( 58 | len(self.foos[0].bar__my_ref), 59 | 1 60 | ) 61 | assert_equal( 62 | self.foos[0].bar__my_ref[0], 63 | bar 64 | ) 65 | 66 | def test_ensure_foreign_list(self): 67 | 68 | bar = self.Bar(ref_list=self.foos) 69 | bar.save() 70 | 71 | for foo in self.foos: 72 | 73 | # Delete backrefs for some reason 74 | foo._StoredObject__backrefs = {} 75 | foo.save() 76 | 77 | # Assert that backrefs are gone 78 | assert_equal( 79 | len(foo.bar__my_ref_list), 80 | 0 81 | ) 82 | 83 | # Restore backrefs 84 | ensure_backrefs(bar) 85 | 86 | for foo in self.foos: 87 | 88 | # Assert that backrefs are correct 89 | assert_equal( 90 | len(foo.bar__my_ref_list), 91 | 1 92 | ) 93 | assert_equal( 94 | foo.bar__my_ref_list[0], 95 | bar 96 | ) 97 | 98 | def test_missing_backref_removed_from_list(self): 99 | 100 | bar = self.Bar(ref_list=self.foos) 101 | bar.save() 102 | 103 | for foo in self.foos: 104 | 105 | # Delete backrefs for some reason 106 | foo._StoredObject__backrefs = {} 107 | foo.save() 108 | 109 | # Assert that backrefs are gone 110 | assert_equal( 111 | len(foo.bar__my_ref_list), 112 | 0 113 | ) 114 | 115 | bar.ref_list.pop() 116 | 117 | # Ensure save does not raise an uncaught KeyError from 
storedobject._remove_backref 118 | bar.save() 119 | -------------------------------------------------------------------------------- /tests/backrefs/test_many_to_many.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from modularodm import ( 4 | exceptions as exc, 5 | StoredObject, 6 | ) 7 | from modularodm.fields import ForeignField, IntegerField 8 | 9 | 10 | from tests.base import ModularOdmTestCase 11 | 12 | class ManyToManyFieldTestCase(ModularOdmTestCase): 13 | 14 | def define_objects(self): 15 | class Foo(StoredObject): 16 | _id = IntegerField() 17 | my_bar = ForeignField('Bar', list=True, backref='my_foo') 18 | 19 | class Bar(StoredObject): 20 | _id = IntegerField() 21 | 22 | return Foo, Bar 23 | 24 | def set_up_objects(self): 25 | 26 | # create a Foo and two Bars 27 | self.foo = self.Foo(_id=1) 28 | self.bar = self.Bar(_id=2) 29 | self.baz = self.Bar(_id=3) 30 | 31 | # save the bars so they're in the storage 32 | self.bar.save() 33 | self.baz.save() 34 | 35 | # add the bars to the foo 36 | self.foo.my_bar.append(self.bar) 37 | self.foo.my_bar.append(self.baz) 38 | 39 | # save foo to persist changes 40 | self.foo.save() 41 | 42 | def test_one_to_many_backref(self): 43 | 44 | # Should be a list of the object's ID 45 | self.assertEqual( 46 | list(self.foo.my_bar), 47 | [self.bar, self.baz] 48 | ) 49 | 50 | # The backreference on bar should be a dict with the necessary info 51 | self.assertEqual( 52 | self.bar.my_foo[0], 53 | self.foo 54 | ) 55 | 56 | # The backreference on baz should be the same 57 | self.assertEqual( 58 | self.baz.my_foo[0], 59 | self.foo 60 | ) 61 | 62 | # bar._backrefs should contain a dict with all backref information for 63 | # the object. 64 | self.assertEqual( 65 | self.bar._backrefs, 66 | {'my_foo': {'foo': {'my_bar': [self.foo._id]}}} 67 | ) 68 | 69 | # bar._backrefs should contain a dict with all backref information for 70 | # the object. 71 | self.assertEqual( 72 | self.baz._backrefs, 73 | {'my_foo': {'foo': {'my_bar': [self.foo._id]}}} 74 | ) 75 | 76 | def test_contains(self): 77 | """ Verify that the "in" operator works as expected """ 78 | 79 | self.assertIn( 80 | self.bar, 81 | self.foo.my_bar 82 | ) 83 | 84 | def test_delete_backref(self): 85 | """ Remove an element from a ForeignField, and verify that it was 86 | removed from the backref as well. 
87 | """ 88 | 89 | first_bar = self.foo.my_bar[0] 90 | second_bar = self.foo.my_bar[1] 91 | 92 | del self.foo.my_bar[0] 93 | self.foo.save() 94 | 95 | # The first Bar should be gone from the ForeignField 96 | self.assertEqual( 97 | list(self.foo.my_bar), 98 | [second_bar, ], 99 | ) 100 | 101 | # The first Bar should no longer have a reference to foo 102 | self.assertEqual( 103 | first_bar.my_foo, 104 | [] 105 | ) 106 | 107 | def test_remove(self): 108 | """ Remove an object from a ForeignList field using the field's 109 | .remove() method 110 | """ 111 | self.foo.my_bar.remove(self.bar) 112 | self.foo.save() 113 | 114 | # the object should be removed from .my_bar 115 | self.assertNotIn( 116 | self.bar, 117 | self.foo.my_bar 118 | ) 119 | 120 | # the backref should be removed from the object 121 | self.assertEqual( 122 | self.bar.my_foo, 123 | [] 124 | ) 125 | 126 | def test_insert(self): 127 | """ Add a new object to the middle of a ForeignList field via .insert() 128 | """ 129 | 130 | # create a new bar 131 | new_bar = self.Bar(_id=9) 132 | new_bar.save() 133 | 134 | # insert new_bar into foo's .my_bar 135 | self.foo.my_bar.insert(1, new_bar) 136 | self.foo.save() 137 | 138 | # new_bar should now be in the list 139 | self.assertIn( 140 | new_bar, 141 | self.foo.my_bar, 142 | ) 143 | 144 | # new_bar should have a backref to foo 145 | self.assertEqual( 146 | len(new_bar.my_foo), 147 | 1 148 | ) 149 | self.assertEqual( 150 | new_bar.my_foo[0], 151 | self.foo 152 | ) 153 | 154 | def test_replace_backref(self): 155 | """ Replace an existing item in the ForeignList field with another 156 | remote object. 157 | """ 158 | 159 | # create a new bar 160 | new_bar = self.Bar(_id=9) 161 | new_bar.save() 162 | 163 | # replace the first bar in the list with it. 164 | old_bar_id = self.foo.my_bar[0]._id 165 | self.foo.my_bar[0] = new_bar 166 | self.foo.save() 167 | 168 | # the old Bar should no longer have a backref to foo 169 | old_bar = self.Bar.load(old_bar_id) 170 | self.assertEqual( 171 | old_bar._backrefs, 172 | {'my_foo': {'foo': {'my_bar': []}}} 173 | ) 174 | 175 | # the new Bar should have a backref to foo 176 | self.assertEqual( 177 | new_bar._backrefs, 178 | {'my_foo': {'foo': {'my_bar': [1]}}} 179 | ) 180 | 181 | @unittest.skip('assertion fails') 182 | def test_delete_backref_attribute_from_remote_via_pop(self): 183 | """ Delete a backref from its attribute on the remote object by calling 184 | .pop(). 185 | 186 | Backref attributes on the remote object should be read-only. 187 | """ 188 | 189 | with self.assertRaises(exc.ModularOdmException): 190 | self.bar.my_foo['foo']['my_bar'].pop() 191 | 192 | @unittest.skip('assertion fails') 193 | def test_delete_backref_attribute_from_remote_via_del(self): 194 | """ Delete a backref from its attribute from the remote object directly. 195 | 196 | Backref attributes on the remote object should be read-only. 197 | """ 198 | 199 | with self.assertRaises(exc.ModularOdmException): 200 | del self.bar.my_foo['foo']['my_bar'][0] 201 | 202 | @unittest.skip('assertion fails') 203 | def test_assign_backref_attribute_from_remote(self): 204 | """ Manually assign a backref to its attribute on the remote object. 205 | 206 | Backref attributes on the remote object should be read-only. 207 | """ 208 | 209 | with self.assertRaises(exc.ModularOdmException): 210 | self.bar.my_foo = {'foo': {'my_bar': []}} 211 | 212 | def test_del_key_from_backrefs_on_remote(self): 213 | """ Manually remove a key from _backrefs on the remote object. 
214 | 215 | _backrefs on the remote object should be read-only. 216 | """ 217 | with self.assertRaises(TypeError): 218 | del self.bar._backrefs['my_foo'] 219 | 220 | def test_assign_backrefs_on_remote(self): 221 | """ Manually assign a backref on the remote object directly. 222 | 223 | _backrefs on the remote object should be read-only. 224 | """ 225 | with self.assertRaises(exc.ModularOdmException): 226 | self.bar._backrefs = {'my_foo': {'foo': {'my_bar': [self.foo._id]}}} 227 | -------------------------------------------------------------------------------- /tests/backrefs/test_one_to_many.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from nose.tools import * 4 | 5 | from modularodm import fields 6 | 7 | from tests.base import ModularOdmTestCase, TestObject 8 | 9 | 10 | class OneToManyFieldTestCase(ModularOdmTestCase): 11 | 12 | def define_objects(self): 13 | 14 | class Foo(TestObject): 15 | _id = fields.IntegerField() 16 | my_bar = fields.ForeignField('Bar', backref='my_foo') 17 | my_bar_no_backref = fields.ForeignField('Bar') 18 | 19 | class Bar(TestObject): 20 | _id = fields.IntegerField() 21 | 22 | return Foo, Bar 23 | 24 | def set_up_objects(self): 25 | 26 | self.foo = self.Foo(_id=1) 27 | self.bar = self.Bar(_id=2) 28 | 29 | self.bar.save() 30 | 31 | self.foo.my_bar = self.bar 32 | self.foo.save() 33 | 34 | def test_no_backref(self): 35 | """Regression test; ensure that backrefs are not saved when no backref 36 | key is given. 37 | """ 38 | bar = self.Bar(_id=20) 39 | bar.save() 40 | 41 | self.foo.my_bar_no_backref = bar 42 | self.foo.save() 43 | 44 | assert_equal(bar._backrefs, {}) 45 | 46 | def test_one_to_one_backref(self): 47 | 48 | # The object itself should be assigned 49 | self.assertIs( 50 | self.foo.my_bar, 51 | self.bar 52 | ) 53 | 54 | # The backreference on bar should be a dict with the necessary info 55 | self.assertEqual( 56 | self.bar.my_foo[0], 57 | self.foo 58 | ) 59 | 60 | # bar._backrefs should contain a dict with all backref information for 61 | # the object. 62 | self.assertEqual( 63 | self.bar._backrefs, 64 | {'my_foo': {'foo': {'my_bar': [self.foo._id]}}} 65 | ) 66 | 67 | def test_delete_foreign_field(self): 68 | """ Remove an element from a ForeignField, and verify that it was 69 | removed from the backref as well. 70 | """ 71 | 72 | del self.foo.my_bar 73 | self.foo.save() 74 | 75 | # The first Bar should be gone from the ForeignField 76 | self.assertEqual( 77 | self.foo.my_bar, 78 | None, 79 | ) 80 | 81 | # The first Bar should no longer have a reference to foo 82 | self.assertEqual( 83 | self.bar.my_foo, 84 | [] 85 | ) 86 | 87 | def test_assign_foreign_field_to_none(self): 88 | """ Assigning a ForeignField to None should be have the same effect as 89 | deleting it. 90 | """ 91 | 92 | # 93 | self.foo.my_bar = None 94 | self.foo.save() 95 | 96 | # The first Bar should be gone from the ForeignField 97 | self.assertEqual( 98 | self.foo.my_bar, 99 | None, 100 | ) 101 | 102 | # The first Bar should no longer have a reference to foo 103 | self.assertEqual( 104 | self.bar.my_foo, 105 | [] 106 | ) 107 | 108 | def test_assign_foreign_field_by_id(self): 109 | """ Try assigning a ForeignField to the primary key of a remote object 110 | """ 111 | 112 | # this functionality is not yet defined. 
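# In practice, assigning the remote object's primary key behaves like
# assigning the object itself: reading ``foo.my_bar`` back yields ``bar``,
# as the assertion below checks.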
113 | self.foo.my_bar = self.bar._id 114 | self.assertIs( 115 | self.foo.my_bar, 116 | self.bar 117 | ) 118 | 119 | 120 | class OneToManyAbstractFieldTestCase(ModularOdmTestCase): 121 | 122 | def define_objects(self): 123 | 124 | class Foo(TestObject): 125 | _id = fields.IntegerField() 126 | my_abstract = fields.AbstractForeignField(backref='my_foo') 127 | 128 | class Bar(TestObject): 129 | _id = fields.IntegerField() 130 | 131 | class Bob(TestObject): 132 | _id = fields.IntegerField() 133 | 134 | return Foo, Bar, Bob 135 | 136 | def set_up_objects(self): 137 | 138 | self.bar = self.Bar(_id=3) 139 | self.bob = self.Bar(_id=4) 140 | 141 | self.bar.save() 142 | self.bob.save() 143 | 144 | self.foo1 = self.Foo(_id=1, my_abstract=self.bar) 145 | self.foo1.save() 146 | 147 | self.foo2 = self.Foo(_id=2, my_abstract=self.bob) 148 | self.foo2.save() 149 | 150 | def test_one_to_one_backref(self): 151 | 152 | # The object itself should be assigned 153 | self.assertIs( 154 | self.foo1.my_abstract, 155 | self.bar 156 | ) 157 | self.assertIs( 158 | self.foo2.my_abstract, 159 | self.bob 160 | ) 161 | 162 | # The backreference on bar should be a dict with the necessary info 163 | self.assertEqual( 164 | self.bar.my_foo[0], 165 | self.foo1 166 | ) 167 | self.assertEqual( 168 | self.bob.my_foo[0], 169 | self.foo2 170 | ) 171 | 172 | # bar._backrefs should contain a dict with all backref information for 173 | # the object. 174 | self.assertEqual( 175 | self.bar._backrefs, 176 | {'my_foo': {'foo': {'my_abstract': [self.foo1._id]}}} 177 | ) 178 | self.assertEqual( 179 | self.bob._backrefs, 180 | {'my_foo': {'foo': {'my_abstract': [self.foo2._id]}}} 181 | ) 182 | 183 | def test_delete_foreign_field(self): 184 | """ Remove an element from a ForeignField, and verify that it was 185 | removed from the backref as well. 186 | """ 187 | 188 | del self.foo1.my_abstract 189 | self.foo1.save() 190 | del self.foo2.my_abstract 191 | self.foo2.save() 192 | 193 | # The first Bar should be gone from the ForeignField 194 | self.assertEqual( 195 | self.foo1.my_abstract, 196 | None, 197 | ) 198 | self.assertEqual( 199 | self.foo2.my_abstract, 200 | None, 201 | ) 202 | 203 | # The first Bar should no longer have a reference to foo 204 | self.assertEqual( 205 | self.bar.my_foo, 206 | [] 207 | ) 208 | self.assertEqual( 209 | self.bob.my_foo, 210 | [] 211 | ) 212 | 213 | def test_assign_foreign_field_to_none(self): 214 | """ Assigning a ForeignField to None should be have the same effect as 215 | deleting it. 216 | """ 217 | 218 | # 219 | self.foo1.my_abstract = None 220 | self.foo1.save() 221 | self.foo2.my_abstract = None 222 | self.foo2.save() 223 | 224 | # The first Bar should be gone from the ForeignField 225 | self.assertEqual( 226 | self.foo1.my_abstract, 227 | None, 228 | ) 229 | self.assertEqual( 230 | self.foo2.my_abstract, 231 | None, 232 | ) 233 | 234 | # The first Bar should no longer have a reference to foo 235 | self.assertEqual( 236 | self.bar.my_foo, 237 | [] 238 | ) 239 | self.assertEqual( 240 | self.bob.my_foo, 241 | [] 242 | ) 243 | 244 | def test_assign_foreign_field_by_tuple(self): 245 | """Assign a tuple of a primary key and a schema name to an abstract 246 | foreign field. Verify that the value of the abstract foreign field is 247 | the same as the object whose primary key and schema name passed in. 
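For example, ``foo1.my_abstract = (bar._id, bar._name)`` should read back as ``bar`` itself.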
248 | 249 | """ 250 | self.foo1.my_abstract = (self.bar._id, self.bar._name) 251 | 252 | self.assertIs( 253 | self.foo1.my_abstract, 254 | self.bar 255 | ) 256 | 257 | def test_assign_foreign_field_invalid_type(self): 258 | """Try to assign a value of the wrong type to an abstract foreign 259 | field; should raise a TypeError. 260 | 261 | """ 262 | with self.assertRaises(TypeError): 263 | self.foo1.my_abstract = 'some primary key' 264 | 265 | def test_assign_foreign_field_invalid_length(self): 266 | """Try to assign a tuple of the wrong length to an abstract foreign 267 | field; should raise a ValueError. 268 | 269 | """ 270 | with self.assertRaises(ValueError): 271 | self.foo1.my_abstract = ('tuple', 'of', 'wrong', 'length') 272 | -------------------------------------------------------------------------------- /tests/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | import inspect 4 | import os 5 | import pymongo 6 | import unittest 7 | import uuid 8 | import six 9 | 10 | from modularodm import StoredObject 11 | from modularodm.storage import MongoStorage, PickleStorage, EphemeralStorage 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | class TestObject(StoredObject): 16 | def __init__(self, *args, **kwargs): 17 | self.set_storage(PickleStorage('Test')) 18 | super(TestObject, self).__init__(*args, **kwargs) 19 | 20 | 21 | class EphemeralStorageMixin(object): 22 | fixture_suffix = 'Ephemeral' 23 | 24 | def make_storage(self): 25 | return EphemeralStorage() 26 | 27 | def clean_up_storage(self): 28 | pass 29 | 30 | 31 | class PickleStorageMixin(object): 32 | fixture_suffix = 'Pickle' 33 | 34 | def make_storage(self): 35 | try: 36 | self.pickle_files 37 | except AttributeError: 38 | self.pickle_files = [] 39 | 40 | filename = str(uuid.uuid4())[:8] 41 | self.pickle_files.append(filename) 42 | return PickleStorage(filename) 43 | 44 | def clean_up_storage(self): 45 | for f in self.pickle_files: 46 | try: 47 | os.remove('db_{0}.pkl'.format(f)) 48 | except OSError: 49 | pass 50 | 51 | 52 | class MongoStorageMixin(object): 53 | fixture_suffix = 'Mongo' 54 | 55 | # DB settings 56 | DB_HOST = os.environ.get('MONGO_HOST', 'localhost') 57 | DB_PORT = int(os.environ.get('MONGO_PORT', '27017')) 58 | 59 | # More efficient to set up client at the class level than to re-connect 60 | # for each test 61 | mongo_client = pymongo.MongoClient( 62 | host=DB_HOST, 63 | port=DB_PORT, 64 | ).modm_test 65 | 66 | def make_storage(self): 67 | 68 | try: 69 | self.mongo_collections 70 | except AttributeError: 71 | self.mongo_collections = [] 72 | 73 | collection = str(uuid.uuid4())[:8] 74 | self.mongo_collections.append(collection) 75 | # logger.debug(self.mongo_collections) 76 | 77 | return MongoStorage( 78 | db=self.mongo_client, 79 | collection=collection 80 | ) 81 | 82 | def clean_up_storage(self): 83 | for c in self.mongo_collections: 84 | self.mongo_client.drop_collection(c) 85 | 86 | 87 | class MultipleBackendMeta(type): 88 | def __new__(mcs, name, bases, dct): 89 | 90 | 91 | if name == 'ModularOdmTestCase': 92 | return type.__new__( 93 | mcs, 94 | name, 95 | bases, 96 | dct 97 | ) 98 | 99 | frame = inspect.currentframe().f_back 100 | 101 | for mixin in ( 102 | PickleStorageMixin, 103 | MongoStorageMixin, 104 | EphemeralStorageMixin, 105 | ): 106 | new_name = '{}{}'.format(name, mixin.fixture_suffix) 107 | frame.f_globals[new_name] = type.__new__( 108 | mcs, 109 | new_name, 110 | (mixin, ) + bases, 111 | dct 112 | ) 
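# Each generated class (e.g. SomeTestCasePickle, SomeTestCaseMongo,
# SomeTestCaseEphemeral) is injected into the module that defined the
# original test case, so the runner collects one concrete copy per storage
# backend; setting __test__ = True below overrides the __test__ = False
# inherited from ModularOdmTestCase.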
113 | frame.f_globals[new_name].__test__ = True 114 | 115 | 116 | @six.add_metaclass(MultipleBackendMeta) 117 | class ModularOdmTestCase(unittest.TestCase): 118 | __test__ = False 119 | 120 | # Setup 121 | 122 | def setUp(self): 123 | super(ModularOdmTestCase, self).setUp() 124 | test_objects = self.define_objects() or tuple() 125 | 126 | for obj in test_objects: 127 | # Create storage backend if not explicitly set 128 | if not getattr(obj, '_storage', None): 129 | obj.set_storage(self.make_storage()) 130 | self.__setattr__(obj.__name__, obj) 131 | 132 | StoredObject._clear_caches() 133 | 134 | self.set_up_objects() 135 | 136 | def set_up_storage(self): 137 | super(ModularOdmTestCase, self).set_up_storage() 138 | 139 | def define_objects(self): 140 | try: 141 | super(ModularOdmTestCase, self).define_objects() 142 | except AttributeError: 143 | pass 144 | 145 | def set_up_objects(self): 146 | try: 147 | super(ModularOdmTestCase, self).set_up_objects() 148 | except AttributeError: 149 | pass 150 | 151 | # Teardown 152 | 153 | def tearDown(self): 154 | # Avoids error when no models defined; variables like 155 | # pickle_files will not be defined. 156 | try: 157 | self.clean_up_storage() 158 | except AttributeError: 159 | pass 160 | super(ModularOdmTestCase, self).tearDown() 161 | -------------------------------------------------------------------------------- /tests/ext/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cos-archives/modular-odm/8a34891892b8af69b21fdc46701c91763a5c1cf9/tests/ext/__init__.py -------------------------------------------------------------------------------- /tests/ext/test_concurrency.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from nose.tools import * 3 | 4 | from modularodm.ext import concurrency 5 | 6 | 7 | class Key(object): 8 | pass 9 | 10 | 11 | global_key = Key() 12 | 13 | 14 | def get_key(): 15 | global global_key 16 | return global_key 17 | 18 | 19 | class ProxiedClass(object): 20 | pass 21 | 22 | 23 | @concurrency.with_proxies({'proxied': ProxiedClass}, get_key) 24 | class ParentClass(object): 25 | pass 26 | 27 | 28 | class TestConcurrency(unittest.TestCase): 29 | 30 | def setUp(self): 31 | self.key1, self.key2 = Key(), Key() 32 | 33 | def test_proxy(self): 34 | 35 | global global_key 36 | 37 | global_key = self.key1 38 | ParentClass.proxied.foo = 'bar' 39 | assert_equal(ParentClass.proxied.foo, 'bar') 40 | 41 | global_key = self.key2 42 | with assert_raises(AttributeError): 43 | ParentClass.proxied.foo 44 | ParentClass.proxied.foo = 'baz' 45 | assert_equal(ParentClass.proxied.foo, 'baz') 46 | 47 | global_key = self.key1 48 | assert_equal(ParentClass.proxied.foo, 'bar') 49 | 50 | 51 | if __name__ == '__main__': 52 | unittest.main() 53 | -------------------------------------------------------------------------------- /tests/fixtures.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cos-archives/modular-odm/8a34891892b8af69b21fdc46701c91763a5c1cf9/tests/fixtures.py -------------------------------------------------------------------------------- /tests/laziness/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cos-archives/modular-odm/8a34891892b8af69b21fdc46701c91763a5c1cf9/tests/laziness/__init__.py 
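The test modules above and below all follow the pattern that tests/base.py establishes: subclass ModularOdmTestCase, return the schemas from define_objects(), and create fixture records in set_up_objects(). A minimal sketch of that pattern, with illustrative Widget names that are not part of the repository:

    from modularodm import StoredObject, fields
    from tests.base import ModularOdmTestCase

    class WidgetTestCase(ModularOdmTestCase):

        def define_objects(self):
            # Widget is an illustrative schema, not one of the repo's models
            class Widget(StoredObject):
                _id = fields.IntegerField(primary=True)
            return Widget,  # setUp attaches a storage backend to each schema returned

        def set_up_objects(self):
            # setUp has already bound the schema as self.Widget
            self.widget = self.Widget(_id=1)
            self.widget.save()

MultipleBackendMeta then emits WidgetTestCasePickle, WidgetTestCaseMongo, and WidgetTestCaseEphemeral variants, so each case runs once per backend.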
-------------------------------------------------------------------------------- /tests/laziness/test_lazy_load.py: -------------------------------------------------------------------------------- 1 | from modularodm import StoredObject 2 | from modularodm.storedobject import ContextLogger 3 | from modularodm.fields import ForeignField, IntegerField 4 | 5 | from tests.base import ModularOdmTestCase 6 | 7 | 8 | class LazyLoadTestCase(ModularOdmTestCase): 9 | 10 | def define_objects(self): 11 | 12 | class Foo(StoredObject): 13 | _id = IntegerField() 14 | my_bar = ForeignField('Bar', list=True, backref='my_foo') 15 | 16 | class Bar(StoredObject): 17 | _id = IntegerField() 18 | 19 | return Foo, Bar 20 | 21 | def test_create_one_object(self): 22 | 23 | with ContextLogger() as context_logger: 24 | 25 | bar = self.Bar(_id=1) 26 | bar.save() 27 | 28 | report = context_logger.report() 29 | 30 | self.assertEqual(report[('bar', 'insert')][0], 1) 31 | 32 | def test_load_object_in_cache(self): 33 | 34 | bar = self.Bar(_id=1) 35 | bar.save() 36 | 37 | with ContextLogger() as context_logger: 38 | 39 | self.Bar.load(1) 40 | report = context_logger.report() 41 | 42 | self.assertNotIn(('bar', 'load'), report) 43 | 44 | def test_load_object_not_in_cache(self): 45 | 46 | bar = self.Bar(_id=1) 47 | bar.save() 48 | 49 | self.Bar._clear_caches(1) 50 | 51 | with ContextLogger() as context_logger: 52 | 53 | self.Bar.load(1) 54 | report = context_logger.report() 55 | 56 | self.assertEqual(report[('bar', 'get')][0], 1) 57 | 58 | def test_create_several_objects(self): 59 | 60 | with ContextLogger() as context_logger: 61 | 62 | bar1 = self.Bar(_id=1) 63 | bar2 = self.Bar(_id=2) 64 | bar3 = self.Bar(_id=3) 65 | bar4 = self.Bar(_id=4) 66 | bar5 = self.Bar(_id=5) 67 | 68 | bar1.save() 69 | bar2.save() 70 | bar3.save() 71 | bar4.save() 72 | bar5.save() 73 | 74 | report = context_logger.report() 75 | 76 | self.assertEqual(report[('bar', 'insert')][0], 5) 77 | 78 | def test_create_linked_objects(self): 79 | 80 | bar1 = self.Bar(_id=1) 81 | bar2 = self.Bar(_id=2) 82 | bar3 = self.Bar(_id=3) 83 | 84 | bar1.save() 85 | bar2.save() 86 | bar3.save() 87 | 88 | with ContextLogger() as context_logger: 89 | 90 | foo1 = self.Foo(_id=4) 91 | foo1.my_bar = [bar1, bar2, bar3] 92 | foo1.save() 93 | 94 | report = context_logger.report() 95 | 96 | self.assertEqual(report[('foo', 'insert')][0], 1) 97 | self.assertEqual(report[('bar', 'update')][0], 3) 98 | 99 | def test_load_linked_objects_not_in_cache(self): 100 | 101 | bar1 = self.Bar(_id=1) 102 | bar2 = self.Bar(_id=2) 103 | bar3 = self.Bar(_id=3) 104 | 105 | bar1.save() 106 | bar2.save() 107 | bar3.save() 108 | 109 | foo1 = self.Foo(_id=4) 110 | foo1.my_bar = [bar1, bar2, bar3] 111 | foo1.save() 112 | 113 | StoredObject._clear_caches() 114 | 115 | with ContextLogger() as context_logger: 116 | 117 | self.Foo.load(4) 118 | 119 | report = context_logger.report() 120 | 121 | self.assertEqual(report[('foo', 'get')][0], 1) 122 | self.assertNotIn(('bar', 'get'), report) 123 | -------------------------------------------------------------------------------- /tests/queries/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cos-archives/modular-odm/8a34891892b8af69b21fdc46701c91763a5c1cf9/tests/queries/__init__.py -------------------------------------------------------------------------------- /tests/queries/test_comparison_operators.py: -------------------------------------------------------------------------------- 1 | 
import datetime as dt 2 | 3 | 4 | from modularodm import fields, StoredObject 5 | from modularodm.query.query import RawQuery as Q 6 | 7 | from tests.base import ModularOdmTestCase 8 | 9 | # TODO: The following are defined in MongoStorage, but not PickleStorage: 10 | # 'mod', 11 | # 'all', 12 | # 'size', 13 | # 'exists', 14 | # 'not' 15 | 16 | # TODO: These tests should be applied to all default field types. Perhaps use an 17 | # iterative approach? 18 | 19 | 20 | class ComparisonOperatorsTestCase(ModularOdmTestCase): 21 | 22 | def define_objects(self): 23 | class Foo(StoredObject): 24 | _id = fields.IntegerField(primary=True) 25 | integer_field = fields.IntegerField() 26 | string_field = fields.StringField() 27 | datetime_field = fields.DateTimeField() 28 | float_field = fields.FloatField() 29 | list_field = fields.IntegerField(list=True) 30 | 31 | return Foo, 32 | 33 | def set_up_objects(self): 34 | self.foos = [] 35 | 36 | for idx in range(3): 37 | foo = self.Foo( 38 | _id=idx, 39 | integer_field = idx, 40 | string_field = 'Value: {}'.format(idx), 41 | datetime_field = dt.datetime.now() + dt.timedelta(hours=idx), 42 | float_field = float(idx + 100.0), 43 | list_field = [(10 * idx) + 1, (10 * idx) + 2, (10 * idx) + 3], 44 | ) 45 | foo.save() 46 | self.foos.append(foo) 47 | 48 | def test_eq(self): 49 | """ Finds objects with the attribute equal to the parameter.""" 50 | self.assertEqual( 51 | self.Foo.find_one( 52 | Q('integer_field', 'eq', self.foos[1].integer_field) 53 | )._id, 54 | self.foos[1]._id, 55 | ) 56 | 57 | def test_ne(self): 58 | """ Finds objects with the attribute not equal to the parameter.""" 59 | self.assertEqual( 60 | len(self.Foo.find( 61 | Q('integer_field', 'ne', self.foos[1].integer_field) 62 | )), 63 | 2 64 | ) 65 | 66 | def test_gt(self): 67 | """ Finds objects with the attribute greater than the parameter.""" 68 | result = self.Foo.find( 69 | Q('integer_field', 'gt', self.foos[1].integer_field) 70 | ) 71 | self.assertEqual(len(result), 1) 72 | 73 | def test_gte(self): 74 | """ Finds objects with the attribute greater than or equal to the 75 | parameter. 
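With integer_field values 0, 1 and 2 and the parameter taken from foos[1], the query should match two of the three records.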
76 | """ 77 | result = self.Foo.find( 78 | Q('integer_field', 'gte', self.foos[1].integer_field) 79 | ) 80 | self.assertEqual(len(result), 2) 81 | 82 | def test_lt(self): 83 | """ Finds objects with the attribute less than the parameter.""" 84 | result = self.Foo.find( 85 | Q('integer_field', 'lt', self.foos[1].integer_field) 86 | ) 87 | self.assertEqual(len(result), 1) 88 | 89 | def test_lte(self): 90 | """ Finds objects with the attribute less than or equal to the 91 | parameter.""" 92 | result = self.Foo.find( 93 | Q('integer_field', 'lte', self.foos[1].integer_field) 94 | ) 95 | self.assertEqual(len(result), 2) 96 | 97 | def test_in(self): 98 | """ Finds objects with the parameter in the attribute.""" 99 | result = self.Foo.find( 100 | Q('integer_field', 'in', [1, 11, 21, ]) 101 | ) 102 | self.assertEqual(len(result), 1) 103 | 104 | def test_nin(self): 105 | """ Finds objects with the parameter not in the attribute.""" 106 | result = self.Foo.find( 107 | Q('integer_field', 'nin', [1, 11, 21, ]) 108 | ) 109 | self.assertEqual(len(result), 2) 110 | -------------------------------------------------------------------------------- /tests/queries/test_foreign_queries.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from nose.tools import * # PEP8 asserts 3 | 4 | from modularodm import fields, StoredObject 5 | from modularodm.query.query import RawQuery as Q 6 | 7 | from tests.base import ModularOdmTestCase 8 | 9 | 10 | class TestForeignQueries(ModularOdmTestCase): 11 | 12 | def define_objects(self): 13 | 14 | class Foo(StoredObject): 15 | _id = fields.StringField(primary=True) 16 | _meta = { 17 | 'optimistic': True, 18 | } 19 | 20 | class Bar(StoredObject): 21 | _id = fields.StringField(primary=True) 22 | ref = fields.ForeignField('foo', backref='my_ref') 23 | abs_ref = fields.AbstractForeignField(backref='my_abs_ref') 24 | ref_list = fields.ForeignField('foo', backref='my_ref_list', list=True) 25 | abs_ref_list = fields.AbstractForeignField(backref='my_abs_ref_list', list=True) 26 | _meta = { 27 | 'optimistic': True, 28 | } 29 | 30 | return Foo, Bar 31 | 32 | def set_up_objects(self): 33 | 34 | self.foos = [] 35 | for _ in range(5): 36 | foo = self.Foo() 37 | foo.save() 38 | self.foos.append(foo) 39 | 40 | def test_eq_foreign(self): 41 | 42 | bar = self.Bar(ref=self.foos[0]) 43 | bar.save() 44 | 45 | result = self.Bar.find( 46 | Q('ref', 'eq', self.foos[0]) 47 | ) 48 | assert_equal(len(result), 1) 49 | 50 | result = self.Bar.find( 51 | Q('ref', 'eq', self.foos[-1]) 52 | ) 53 | assert_equal(len(result), 0) 54 | 55 | def test_eq_foreign_list(self): 56 | 57 | bar = self.Bar(ref_list=self.foos[:3]) 58 | bar.save() 59 | 60 | result = self.Bar.find( 61 | Q('ref_list', 'eq', self.foos[0]) 62 | ) 63 | assert_equal(len(result), 1) 64 | 65 | result = self.Bar.find( 66 | Q('ref_list', 'eq', self.foos[-1]) 67 | ) 68 | assert_equal(len(result), 0) 69 | 70 | def test_eq_abstract(self): 71 | 72 | bar = self.Bar(abs_ref=self.foos[0]) 73 | bar.save() 74 | 75 | result = self.Bar.find( 76 | Q('abs_ref', 'eq', self.foos[0]) 77 | ) 78 | assert_equal(len(result), 1) 79 | 80 | result = self.Bar.find( 81 | Q('abs_ref', 'eq', self.foos[-1]) 82 | ) 83 | assert_equal(len(result), 0) 84 | 85 | def test_eq_abstract_list(self): 86 | 87 | bar = self.Bar(abs_ref_list=self.foos[:3]) 88 | bar.save() 89 | 90 | result = self.Bar.find( 91 | Q('abs_ref_list', 'eq', self.foos[0]) 92 | ) 93 | assert_equal(len(result), 1) 94 | 95 | result = self.Bar.find( 96 | 
Q('abs_ref_list', 'eq', self.foos[-1]) 97 | ) 98 | assert_equal(len(result), 0) 99 | -------------------------------------------------------------------------------- /tests/queries/test_logical_operators.py: -------------------------------------------------------------------------------- 1 | from modularodm import fields, StoredObject 2 | from modularodm.query.querydialect import DefaultQueryDialect as Q 3 | 4 | from tests.base import ModularOdmTestCase 5 | 6 | class LogicalOperatorsBase(ModularOdmTestCase): 7 | def define_objects(self): 8 | class Foo(StoredObject): 9 | _id = fields.IntegerField(required=True, primary=True) 10 | a = fields.IntegerField() 11 | b = fields.IntegerField() 12 | 13 | return Foo, 14 | 15 | def set_up_objects(self): 16 | self.foos = [] 17 | 18 | for idx, f in enumerate([(a, b) for a in range(3) for b in range(3)]): 19 | self.foos.append( 20 | self.Foo( 21 | _id = idx, 22 | a = f[0], 23 | b = f[1], 24 | ) 25 | ) 26 | 27 | [x.save() for x in self.foos] 28 | 29 | def test_and(self): 30 | """Find the intersection of two or more queries.""" 31 | result = self.Foo.find(Q('a', 'eq', 0) & Q('b', 'eq', 1)) 32 | self.assertEqual( 33 | len(result), 34 | 1, 35 | ) 36 | self.assertEqual(result[0].a, 0) 37 | self.assertEqual(result[0].b, 1) 38 | 39 | def test_or(self): 40 | """Find the union of two or more queries.""" 41 | result = self.Foo.find(Q('a', 'eq', 0) | Q('a', 'eq', 1)) 42 | self.assertEqual( 43 | len(result), 44 | 6, 45 | ) 46 | 47 | def test_not(self): 48 | """Find the inverse of a query.""" 49 | result = self.Foo.find(~Q('a', 'eq', 0)) 50 | self.assertEqual( 51 | len(result), 52 | 6, 53 | ) 54 | 55 | def test_and_or(self): 56 | """Join multiple OR queries with an AND. 57 | 58 | """ 59 | result = self.Foo.find( 60 | (Q('a', 'eq', 0) | Q('a', 'eq', 1)) 61 | & (Q('b', 'eq', 1) | Q('b', 'eq', 2)) 62 | ) 63 | self.assertEqual( 64 | len(result), 65 | 4, 66 | ) 67 | 68 | -------------------------------------------------------------------------------- /tests/queries/test_simple_queries.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | 4 | from modularodm import exceptions, StoredObject 5 | from modularodm.fields import IntegerField 6 | from modularodm.query.query import RawQuery as Q 7 | 8 | from tests.base import ModularOdmTestCase 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class BasicQueryTestCase(ModularOdmTestCase): 14 | 15 | COUNT = 30 16 | 17 | def define_objects(self): 18 | class Foo(StoredObject): 19 | _id = IntegerField(primary=True) 20 | 21 | return Foo, 22 | 23 | def set_up_objects(self): 24 | self.foos = [] 25 | 26 | for idx in range(self.COUNT): 27 | foo = self.Foo(_id=idx) 28 | foo.save() 29 | self.foos.append(foo) 30 | 31 | def test_load_by_pk(self): 32 | """ Given a known primary key, ``.get(pk)`` should return the object. 33 | """ 34 | self.assertEqual( 35 | self.foos[0], 36 | self.Foo.load(0) 37 | ) 38 | 39 | def test_find_all(self): 40 | """ If no query object is passed, ``.find()`` should return all objects. 41 | """ 42 | self.assertEqual( 43 | len(self.Foo.find()), 44 | len(self.foos) 45 | ) 46 | 47 | def test_find_one(self): 48 | """ Given a query with exactly one result record, ``.find_one()`` should 49 | return that object. 
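Here ``Q('_id', 'eq', 0)`` matches exactly one of the thirty saved records.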
50 | """ 51 | self.assertEqual( 52 | self.Foo.find_one(Q('_id', 'eq', 0))._id, 53 | self.foos[0]._id 54 | ) 55 | 56 | def test_find_one_return_zero(self): 57 | """ Given a query with zero result records, ``.find_one()`` should raise 58 | an appropriate error. 59 | """ 60 | with self.assertRaises(exceptions.NoResultsFound): 61 | self.Foo.find_one(Q('_id', 'eq', -1)) 62 | 63 | def test_find_one_return_many(self): 64 | """ Given a query with >1 result record, ``.find_one()`` should raise 65 | an appropriate error. 66 | """ 67 | with self.assertRaises(exceptions.MultipleResultsFound): 68 | result = self.Foo.find_one() 69 | logger.debug(result) 70 | 71 | def test_slice(self): 72 | queryset = self.Foo.find() 73 | queryslice = queryset[1:] 74 | self.assertEqual(queryset.count(), queryslice.count() + 1) 75 | self.assertEqual(queryset[1], queryslice[0]) 76 | 77 | def test_slice_negative_index(self): 78 | queryset = self.Foo.find() 79 | 80 | if queryset._NEGATIVE_INDEXING: 81 | self.assertEqual(queryset[-1], self.Foo.find().sort('-_id')[0]) 82 | else: 83 | with self.assertRaises(IndexError): 84 | queryset[-1] 85 | 86 | def test_slice_negative_slice(self): 87 | queryset = self.Foo.find() 88 | with self.assertRaises(IndexError): 89 | queryset[-5:-1] 90 | 91 | def test_slice_step(self): 92 | queryset = self.Foo.find() 93 | with self.assertRaises(IndexError): 94 | queryset[::2] 95 | 96 | def test_slice_reverse(self): 97 | queryset = self.Foo.find() 98 | with self.assertRaises(IndexError): 99 | queryset[5:0] 100 | 101 | # individual filter tests (limit, offset, sort) 102 | 103 | def test_limit(self): 104 | """ For a query that returns > n results, `.limit(n)` should return the 105 | first n. 106 | """ 107 | self.assertEqual( 108 | len(self.Foo.find().limit(10)), 109 | 10, 110 | ) 111 | 112 | self.assertEqual( 113 | len(self.Foo.find().limit(self.COUNT+10)), 114 | self.COUNT, 115 | ) 116 | # TODO: test limit = 0 117 | 118 | 119 | def test_offset(self): 120 | """For a query that returns n results, ``.offset(m)`` should return 121 | n - m results, skipping the first m that would otherwise have been 122 | returned. 123 | """ 124 | self.assertEqual( 125 | len(self.Foo.find().offset(25)), 126 | self.COUNT - 25, 127 | ) 128 | # TODO: test offset = 0, offset > self.COUNT 129 | 130 | 131 | def test_sort(self): 132 | results = self.Foo.find().sort('-_id') 133 | self.assertListEqual( 134 | [x._id for x in results], 135 | list(range(self.COUNT))[::-1], 136 | ) 137 | 138 | 139 | # paired filter tests: 140 | # limit + {limit,offset,sort} 141 | # offset + {offset,sort} 142 | # sort + sort 143 | # each test sub tests the filters in both orders. i.e. 
limit + offset 144 | # tests .limit().offset() AND .offset().limit() 145 | 146 | def test_limit_limit(self): 147 | self.assertEqual( len(self.Foo.find().limit(5).limit(10)), 10 ) 148 | self.assertEqual( len(self.Foo.find().limit(10).limit(5)), 5 ) 149 | 150 | 151 | def test_limit_offset(self): 152 | self.assertEqual( len(self.Foo.find().limit(2).offset(2)), 2 ) 153 | self.assertEqual( len(self.Foo.find().offset(2).limit(2)), 2 ) 154 | 155 | tmp = 5 156 | limit = tmp + 5 157 | offset = self.COUNT - tmp 158 | self.assertEqual(len(self.Foo.find().limit(limit).offset(offset)), tmp) 159 | self.assertEqual(len(self.Foo.find().offset(offset).limit(limit)), tmp) 160 | 161 | 162 | def test_limit_sort(self): 163 | limit, sort, = [10, '-_id'] 164 | expect = list(range(self.COUNT-limit, self.COUNT)[::-1]) 165 | 166 | results = self.Foo.find().limit(limit).sort(sort) 167 | self.assertListEqual([x._id for x in results], expect) 168 | 169 | results = self.Foo.find().sort(sort).limit(limit) 170 | self.assertListEqual([x._id for x in results], expect) 171 | 172 | 173 | def test_offset_offset(self): 174 | self.assertEqual( 175 | len(self.Foo.find().offset(10).offset(17)), 176 | self.COUNT-17 177 | ) 178 | self.assertEqual( 179 | len(self.Foo.find().offset(17).offset(10)), 180 | self.COUNT-10 181 | ) 182 | 183 | 184 | def test_offset_sort(self): 185 | offset, sort = [27, '-_id'] 186 | expect = list(range(self.COUNT-offset)[::-1]) 187 | 188 | results = self.Foo.find().offset(offset).sort(sort) 189 | self.assertListEqual([x._id for x in results], expect) 190 | 191 | results = self.Foo.find().sort(sort).offset(offset) 192 | self.assertListEqual([x._id for x in results], expect) 193 | 194 | 195 | def test_sort_sort(self): 196 | results = self.Foo.find().sort('-_id').sort('_id') 197 | self.assertListEqual( 198 | [x._id for x in results], 199 | list(range(self.COUNT)), 200 | ) 201 | results = self.Foo.find().sort('_id').sort('-_id') 202 | self.assertListEqual( 203 | [x._id for x in results], 204 | list(range(self.COUNT)[::-1]), 205 | ) 206 | 207 | 208 | # all three filters together 209 | 210 | def test_limit_offset_sort(self): 211 | test_sets = [ 212 | # limit offset sort expect 213 | [ 10, 7, '-_id', list(range(self.COUNT-7-10, self.COUNT-7)[::-1]), ], 214 | [ 20, 17, '_id', list(range(17, self.COUNT)), ], 215 | [ 10, 5, '_id', list(range(5, 5+10)), ], 216 | ] 217 | for test in test_sets: 218 | limit, offset, sort, expect = test 219 | all_combinations = [ 220 | self.Foo.find().limit(limit).offset(offset).sort(sort), 221 | self.Foo.find().limit(limit).sort(sort).offset(offset), 222 | self.Foo.find().offset(offset).limit(limit).sort(sort), 223 | self.Foo.find().offset(offset).sort(sort).limit(limit), 224 | self.Foo.find().sort(sort).limit(limit).offset(offset), 225 | self.Foo.find().sort(sort).offset(offset).limit(limit), 226 | ] 227 | 228 | for result in all_combinations: 229 | self.assertListEqual( [x._id for x in result], expect ) 230 | -------------------------------------------------------------------------------- /tests/queries/test_string_operators.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from modularodm import fields, StoredObject 4 | from modularodm.query.query import RawQuery as Q 5 | 6 | from tests.base import ModularOdmTestCase 7 | 8 | 9 | # TODO: The following are defined in MongoStorage, but not PickleStorage: 10 | # 'istartswith' 11 | # 'iendswith', 12 | # 'exact', 13 | # 'iexact' 14 | 15 | 16 | class 
StringComparisonTestCase(ModularOdmTestCase): 17 | 18 | def define_objects(self): 19 | class Foo(StoredObject): 20 | _id = fields.IntegerField(primary=True) 21 | string_field = fields.StringField() 22 | 23 | return Foo, 24 | 25 | def set_up_objects(self): 26 | self.foos = [] 27 | 28 | field_values = ( 29 | 'first value', 30 | 'second value', 31 | 'third value', 32 | ) 33 | 34 | for idx in range(len(field_values)): 35 | foo = self.Foo( 36 | _id=idx, 37 | string_field=field_values[idx], 38 | ) 39 | foo.save() 40 | self.foos.append(foo) 41 | 42 | def tear_down_objects(self): 43 | try: 44 | os.remove('db_Test.pkl') 45 | except OSError: 46 | pass 47 | 48 | def test_contains(self): 49 | """ Finds objects with the attribute containing the substring.""" 50 | result = self.Foo.find( 51 | Q('string_field', 'contains', 'second') 52 | ) 53 | self.assertEqual(len(result), 1) 54 | 55 | def test_icontains(self): 56 | """ Operates as ``contains``, but ignores case.""" 57 | result = self.Foo.find( 58 | Q('string_field', 'icontains', 'SeCoNd') 59 | ) 60 | self.assertEqual(len(result), 1) 61 | 62 | def test_startwith(self): 63 | """ Finds objects where the attribute begins with the substring """ 64 | result = self.Foo.find( 65 | Q('string_field', 'startswith', 'second') 66 | ) 67 | self.assertEqual(len(result), 1) 68 | 69 | def test_endswith(self): 70 | """ Finds objects where the attribute ends with the substring """ 71 | result = self.Foo.find( 72 | Q('string_field', 'endswith', 'value') 73 | ) 74 | self.assertEqual(len(result), 3) 75 | -------------------------------------------------------------------------------- /tests/queries/test_update_queries.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from modularodm import exceptions, fields, StoredObject 4 | from modularodm.query.querydialect import DefaultQueryDialect as Q 5 | 6 | from tests.base import ModularOdmTestCase 7 | 8 | 9 | class UpdateQueryTestCase(ModularOdmTestCase): 10 | 11 | def define_objects(self): 12 | class Foo(StoredObject): 13 | _id = fields.IntegerField(primary=True) 14 | modified = fields.BooleanField(default=False) 15 | 16 | return Foo, 17 | 18 | def set_up_objects(self): 19 | self.foos = [] 20 | 21 | for idx in range(5): 22 | foo = self.Foo( 23 | _id=idx 24 | ) 25 | foo.save() 26 | self.foos.append(foo) 27 | 28 | def tear_down_objects(self): 29 | try: 30 | os.remove('db_Test.pkl') 31 | except OSError: 32 | pass 33 | 34 | def test_update(self): 35 | """ Given a query, and an update clause, update all (and only) object 36 | returned by query. 37 | """ 38 | self.Foo.update( 39 | query=Q('_id', 'eq', 2), 40 | data={'modified': True} 41 | ) 42 | 43 | self.assertEqual( 44 | [x.modified for x in self.foos], 45 | [False, False, True, False, False], 46 | ) 47 | 48 | def test_update_one(self): 49 | """ Given a primary key, update the referenced object according to the 50 | update clause 51 | """ 52 | self.Foo.update_one( 53 | which=Q('_id', 'eq', 2), 54 | data={'modified': True} 55 | ) 56 | 57 | self.assertEqual( 58 | [x.modified for x in self.foos], 59 | [False, False, True, False, False], 60 | ) 61 | 62 | def test_remove(self): 63 | """ Given a query, remove all (and only) object returned by query. """ 64 | self.Foo.remove(Q('_id', 'eq', 2)) 65 | 66 | self.assertEqual( 67 | self.Foo.find().count(), 68 | 4 69 | ) 70 | 71 | def test_remove_one(self): 72 | """ Given a primary key, remove the referenced object. 
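(Strictly, ``remove_one`` is given a query here; ``Q('_id', 'eq', 2)`` matches exactly one of the five records, leaving four behind.)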
""" 73 | self.Foo.remove_one(Q('_id', 'eq', 2)) 74 | 75 | self.assertEqual( 76 | self.Foo.find().count(), 77 | 4 78 | ) 79 | 80 | def test_remove_one_returns_many(self): 81 | """ Given a primary key, remove the referenced object. """ 82 | with self.assertRaises(exceptions.ModularOdmException): 83 | self.Foo.remove_one(Q('_id', 'gt', 2)) 84 | 85 | def test_remove_one_returns_none(self): 86 | """ Given a primary key, remove the referenced object. """ 87 | with self.assertRaises(exceptions.ModularOdmException): 88 | self.Foo.remove_one(Q('_id', 'eq', 100)) 89 | -------------------------------------------------------------------------------- /tests/test_fields.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import datetime 4 | import unittest 5 | from nose.tools import * # PEP8 asserts 6 | 7 | from modularodm import StoredObject, fields, storage, exceptions 8 | 9 | 10 | def set_datetime(): 11 | return datetime.datetime(1999, 1, 2, 3, 45) 12 | 13 | 14 | class User(StoredObject): 15 | _id = fields.StringField(primary=True) 16 | name = fields.StringField(required=True) 17 | date_created = fields.DateTimeField(auto_now_add=set_datetime) 18 | date_updated = fields.DateTimeField(auto_now=set_datetime) 19 | read_only = fields.StringField(editable=False) 20 | unique = fields.StringField(unique=True) 21 | 22 | _meta = {'optimistic': True} 23 | 24 | 25 | pickle_storage = storage.PickleStorage('fields', prefix='test_') 26 | User.set_storage(storage.PickleStorage('fields', prefix='test_')) 27 | 28 | 29 | class TestField(unittest.TestCase): 30 | 31 | def setUp(self): 32 | pickle_storage._delete_file() 33 | 34 | def tearDown(self): 35 | pickle_storage._delete_file() 36 | 37 | def test_update_fields(self): 38 | u = User(name='foo') 39 | u.update_fields(name="bazzle", _id=932) 40 | assert_equal(u.name, "bazzle") 41 | assert_equal(u._id, 932) 42 | 43 | def test_validators_must_be_callable(self): 44 | assert_raises(TypeError, lambda: fields.Field(validate="invalid")) 45 | assert_raises(TypeError, lambda: fields.Field(validate=["invalid"])) 46 | 47 | def test_uneditable_field(self): 48 | u = User(name='Foo') 49 | with assert_raises(AttributeError): 50 | u.read_only = 'foo' 51 | 52 | def test_required_field(self): 53 | u = User() 54 | assert_raises(exceptions.ValidationError, lambda: u.save()) 55 | 56 | def test_unique_field(self): 57 | u0 = User(name='bob', unique='foo') 58 | u1 = User(name='bob', unique='bar') 59 | u2 = User(name='bob', unique='foo') 60 | u0.save() 61 | u1.save() 62 | # Fail on saving repeated value 63 | with self.assertRaises(ValueError): 64 | u2.save() 65 | 66 | def test_unique_ignores_self(self): 67 | u0 = User(name='bob', unique='qux') 68 | u0.save() 69 | User._clear_caches() 70 | u0.save() 71 | 72 | def test_unique_ignores_none(self): 73 | u0 = User(name='bob') 74 | u1 = User(name='bob') 75 | u0.save() 76 | u1.save() 77 | 78 | 79 | class TestListField(unittest.TestCase): 80 | 81 | def test_default_must_be_list(self): 82 | assert_raises(TypeError, 83 | lambda: fields.ListField(fields.StringField(), default=3)) 84 | assert_raises(TypeError, 85 | lambda: fields.ListField(fields.StringField(), default="123")) 86 | assert_raises(TypeError, 87 | lambda: fields.ListField(fields.StringField(), default=True)) 88 | assert_raises(TypeError, 89 | lambda: fields.ListField(fields.StringField(), default={"default": (1,2)})) 90 | 91 | 92 | class TestDateTimeField(unittest.TestCase): 93 | 94 | def setUp(self): 95 | self.user 
= User(name="Foo") 96 | self.user.save() 97 | 98 | def tearDown(self): 99 | os.remove('test_fields.pkl') 100 | 101 | def test_auto_now_utcnow(self): 102 | expected = set_datetime() 103 | assert_equal(self.user.date_updated, expected) 104 | 105 | def test_auto_now_add(self): 106 | expected = set_datetime() 107 | assert_equal(self.user.date_created, expected) 108 | 109 | def test_uncallable_auto_now_param_raises_type_error(self): 110 | with assert_raises(ValueError): 111 | fields.DateTimeField(auto_now='uncallable') 112 | 113 | def test_cant_use_auto_now_and_auto_now_add(self): 114 | with assert_raises(ValueError): 115 | fields.DateTimeField( 116 | auto_now=datetime.datetime.now, 117 | auto_now_add=datetime.datetime.utcnow 118 | ) 119 | 120 | 121 | class TestForeignField(unittest.TestCase): 122 | 123 | def setUp(self): 124 | class Parent(StoredObject): 125 | _id = fields.IntegerField(primary=True) 126 | self.Parent = Parent 127 | 128 | def test_string_reference(self): 129 | class Child(StoredObject): 130 | _id = fields.IntegerField(primary=True) 131 | parent = fields.ForeignField('Parent') 132 | assert_equal( 133 | Child._fields['parent'].base_class, 134 | self.Parent 135 | ) 136 | 137 | def test_string_reference_unknown(self): 138 | class Child(StoredObject): 139 | _id = fields.IntegerField(primary=True) 140 | parent = fields.ForeignField('Grandparent') 141 | with assert_raises(exceptions.ModularOdmException): 142 | Child._fields['parent'].base_class 143 | 144 | def test_class_reference(self): 145 | class Child(StoredObject): 146 | _id = fields.IntegerField(primary=True) 147 | parent = fields.ForeignField(self.Parent) 148 | assert_equal( 149 | Child._fields['parent'].base_class, 150 | self.Parent 151 | ) 152 | 153 | 154 | if __name__ == '__main__': 155 | unittest.main() 156 | -------------------------------------------------------------------------------- /tests/test_foreign.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | 5 | from nose.tools import * 6 | 7 | from tests.base import ModularOdmTestCase, TestObject 8 | 9 | from modularodm import fields 10 | 11 | 12 | class TestForeignList(ModularOdmTestCase): 13 | 14 | def define_objects(self): 15 | 16 | class Foo(TestObject): 17 | _id = fields.IntegerField() 18 | bars = fields.ForeignField('bar', list=True) 19 | 20 | class Bar(TestObject): 21 | _id = fields.IntegerField() 22 | 23 | return Foo, Bar 24 | 25 | def set_up_objects(self): 26 | 27 | self.foo = self.Foo(_id=1) 28 | 29 | self.bars = [] 30 | for idx in range(5): 31 | self.bars.append(self.Bar(_id=idx)) 32 | self.bars[idx].save() 33 | 34 | self.foo.bars = self.bars 35 | self.foo.save() 36 | 37 | def test_get_item(self): 38 | assert_equal(self.bars[2], self.foo.bars[2]) 39 | 40 | def test_get_slice(self): 41 | assert_equal(self.bars[:3], list(self.foo.bars[:3])) 42 | 43 | def test_get_slice_extended(self): 44 | assert_equal(self.bars[::-1], list(self.foo.bars[::-1])) 45 | 46 | 47 | class TestAbstractForeignList(ModularOdmTestCase): 48 | 49 | def define_objects(self): 50 | 51 | class Foo(TestObject): 52 | _id = fields.IntegerField() 53 | bars = fields.AbstractForeignField(list=True) 54 | 55 | class Bar(TestObject): 56 | _id = fields.IntegerField() 57 | 58 | return Foo, Bar 59 | 60 | def set_up_objects(self): 61 | 62 | self.foo = self.Foo(_id=1) 63 | 64 | self.bars = [] 65 | for idx in range(5): 66 | self.bars.append(self.Bar(_id=idx)) 67 | self.bars[idx].save() 68 | 69 | self.foo.bars = self.bars 
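# bars is an AbstractForeignField list, so unlike the ForeignField used in
# TestForeignList above it is not tied to a single schema; assigning the Bar
# instances and saving below persists the references.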
70 | self.foo.save() 71 | 72 | def test_get_item(self): 73 | assert_equal(self.bars[2], self.foo.bars[2]) 74 | 75 | def test_get_slice(self): 76 | assert_equal(self.bars[:3], list(self.foo.bars[:3])) 77 | 78 | def test_get_slice_extended(self): 79 | assert_equal(self.bars[::-1], list(self.foo.bars[::-1])) 80 | 81 | -------------------------------------------------------------------------------- /tests/test_migration.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from nose.tools import * # PEP8 asserts 3 | 4 | from modularodm import StoredObject, fields, exceptions 5 | from tests.base import ModularOdmTestCase 6 | 7 | class TestMigration(ModularOdmTestCase): 8 | 9 | def define_objects(self): 10 | 11 | # Use a single storage object for both schema versions 12 | self._storage = self.make_storage() 13 | 14 | class V1(StoredObject): 15 | _id = fields.StringField(_primary_key=True, index=True) 16 | my_string = fields.StringField() 17 | my_float = fields.FloatField() 18 | my_number = fields.FloatField() 19 | my_null = fields.StringField(required=False) 20 | _meta = { 21 | 'optimistic': True, 22 | 'version': 1, 23 | 'optimistic': True 24 | } 25 | V1.set_storage(self._storage) 26 | 27 | class V2(StoredObject): 28 | _id = fields.StringField(_primary_key=True, index=True) 29 | my_string = fields.StringField() 30 | my_int = fields.IntegerField(default=5) 31 | my_number = fields.IntegerField() 32 | my_null = fields.StringField(required=False) 33 | 34 | @classmethod 35 | def _migrate(cls, old, new): 36 | if old.my_string: 37 | new.my_string = old.my_string + 'yo' 38 | if old.my_number: 39 | new.my_number = int(old.my_number) 40 | 41 | _meta = { 42 | 'optimistic': True, 43 | 'version_of': V1, 44 | 'version': 2, 45 | 'optimistic': True 46 | } 47 | V2.set_storage(self._storage) 48 | 49 | return V1, V2 50 | 51 | def set_up_objects(self): 52 | self.record = self.V1(my_string='hi', my_float=1.2, my_number=3.4) 53 | self.record.save() 54 | self.migrated_record = self.V2.load(self.record._primary_key) 55 | self.migrated_record.save() 56 | 57 | def test_version_number(self): 58 | assert_equal(self.migrated_record._version, 2) 59 | 60 | def test_new_field(self): 61 | assert_in('my_int', self.migrated_record._fields) 62 | assert_equal(self.migrated_record.my_int, 5) 63 | 64 | def test_deleted_field(self): 65 | assert_in('my_float', self.record._fields) 66 | assert_not_in('my_float', self.migrated_record._fields) 67 | 68 | def test_migrated_field(self): 69 | assert_equal(self.migrated_record.my_string, 'hiyo') 70 | 71 | def test_versions_contain_same_records(self): 72 | for i in range(5): 73 | record = self.V1(my_string="foo") 74 | record.save() 75 | assert_equal(len(self.V1.find()), len(self.V2.find())) 76 | # Primary keys are the same 77 | for old_rec, new_rec in zip(self.V1.find(), self.V2.find()): 78 | assert_equal(old_rec._primary_key, new_rec._primary_key) 79 | 80 | def test_changed_number_type_field(self): 81 | assert_true(isinstance(self.migrated_record._fields['my_number'], 82 | fields.IntegerField)) 83 | assert_true(isinstance(self.migrated_record.my_number, int)) 84 | assert_equal(self.migrated_record.my_number, int(self.record.my_number)) 85 | 86 | def test_making_field_required_without_default_raises_error(self): 87 | # TODO: This test raises a warning for setting a non-field value 88 | class V3(StoredObject): 89 | _id = fields.StringField(_primary_key=True, index=True) 90 | my_string = fields.StringField() 91 | my_int = 
fields.IntegerField(default=5) 92 | my_number = fields.IntegerField() 93 | my_null = fields.StringField(required=True) 94 | 95 | _meta = { 96 | 'optimistic': True, 97 | 'version_of': self.V2, 98 | 'version': 3, 99 | 'optimistic': True 100 | } 101 | V3.set_storage(self._storage) 102 | migrated = V3.load(self.migrated_record._primary_key) 103 | assert_raises(exceptions.ValidationError, lambda: migrated.save()) 104 | 105 | def test_making_field_required_with_default(self): 106 | class V3(StoredObject): 107 | _id = fields.StringField(_primary_key=True, index=True) 108 | my_string = fields.StringField() 109 | my_int = fields.IntegerField(default=5) 110 | my_number = fields.IntegerField() 111 | my_null = fields.StringField(required=True) 112 | @classmethod 113 | def _migrate(cls, old, new): 114 | if not old.my_null: 115 | new.my_null = 'default' 116 | _meta = { 117 | 'optimistic': True, 118 | 'version_of': self.V2, 119 | 'version': 3, 120 | 'optimistic': True 121 | } 122 | V3.set_storage(self._storage) 123 | old = self.V1() 124 | old.save() 125 | migrated = V3.load(old._primary_key) 126 | migrated.save() 127 | assert_equal(migrated.my_null, "default") 128 | 129 | def test_migrate_all(self): 130 | for i in range(5): 131 | rec = self.V1(my_string="foo{0}".format(i)) 132 | rec.save() 133 | self.V2.migrate_all() 134 | # TODO: This used to be self.V2.find() (without the "count") 135 | # WHY WOULD THIS WORK?! 136 | assert_greater_equal(self.V2.find().count(), 5) 137 | for record in self.V2.find(): 138 | assert_true(record.my_string.endswith("yo")) 139 | 140 | def test_save_migrated(self): 141 | try: 142 | self.migrated_record.save() 143 | except: 144 | assert False 145 | 146 | if __name__ == '__main__': 147 | unittest.main() 148 | -------------------------------------------------------------------------------- /tests/test_queue.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from nose.tools import * 3 | 4 | from modularodm import writequeue 5 | 6 | from tests.utils import make_model 7 | 8 | 9 | class TestWriteAction(unittest.TestCase): 10 | 11 | def test_init(self): 12 | key = lambda x: -1 * x 13 | action = writequeue.WriteAction( 14 | max, 15 | 1, 2, 3, 16 | key=key 17 | ) 18 | assert_equal(action.method, max) 19 | assert_equal(action.args, (1, 2, 3)) 20 | assert_equal(action.kwargs, {'key': key}) 21 | 22 | def test_init_method_not_callable(self): 23 | with assert_raises(ValueError): 24 | writequeue.WriteAction(None) 25 | 26 | def test_execute(self): 27 | action = writequeue.WriteAction( 28 | max, 29 | 1, 2, 3, 30 | key=lambda x: -1 * x 31 | ) 32 | assert_equal(action.execute(), 1) 33 | 34 | 35 | class TestWriteQueue(unittest.TestCase): 36 | 37 | def setUp(self): 38 | self.queue = writequeue.WriteQueue() 39 | self.queue.start() 40 | 41 | def test_init(self): 42 | assert_false(self.queue.actions) 43 | 44 | def test_push(self): 45 | action1 = writequeue.WriteAction(min) 46 | action2 = writequeue.WriteAction(max) 47 | self.queue.push(action1) 48 | self.queue.push(action2) 49 | assert_equal(list(self.queue.actions), [action1, action2]) 50 | 51 | def test_push_invalid(self): 52 | with assert_raises(TypeError): 53 | self.queue.push(None) 54 | 55 | def test_commit(self): 56 | action1 = writequeue.WriteAction(min, 1, 2) 57 | action2 = writequeue.WriteAction(max, 1, 2) 58 | self.queue.push(action1) 59 | self.queue.push(action2) 60 | results = self.queue.commit() 61 | assert_equal(results, [1, 2]) 62 | assert_false(self.queue.actions) 63 | 64 | def 
test_clear(self): 65 | action1 = writequeue.WriteAction(min) 66 | action2 = writequeue.WriteAction(max) 67 | self.queue.push(action1) 68 | self.queue.push(action2) 69 | self.queue.clear() 70 | assert_false(self.queue.actions, ) 71 | 72 | def test_bool_true(self): 73 | self.queue.push(writequeue.WriteAction(zip)) 74 | assert_true(self.queue) 75 | 76 | def test_bool_false(self): 77 | queue = writequeue.WriteQueue() 78 | assert_false(queue) 79 | 80 | 81 | class QueueTestCase(unittest.TestCase): 82 | 83 | def setUp(self): 84 | self.Model = make_model() 85 | self.Model.queue.clear() 86 | 87 | def enqueue_record(self, _id=1): 88 | record = self.Model(_id=_id) 89 | record.save() 90 | return record 91 | 92 | 93 | class TestQueueContext(QueueTestCase): 94 | 95 | def test_context(self): 96 | with writequeue.QueueContext(self.Model): 97 | self.enqueue_record() 98 | assert_false(self.Model._storage[0].insert.called) 99 | assert_true(self.Model._storage[0].insert.called) 100 | 101 | 102 | class TestModelQueue(QueueTestCase): 103 | 104 | def test_queue_initial(self): 105 | assert_false(self.Model.queue) 106 | assert_false(self.Model.queue.active) 107 | 108 | def test_start_queue(self): 109 | self.Model.start_queue() 110 | assert_true(isinstance(self.Model.queue, writequeue.WriteQueue)) 111 | assert_false(self.Model.queue.actions) 112 | 113 | def test_start_queue_again(self): 114 | self.Model.start_queue() 115 | self.enqueue_record() 116 | queue = self.Model.queue 117 | self.Model.start_queue() 118 | assert_true(self.Model.queue is queue) 119 | 120 | def test_save_no_queue(self): 121 | self.enqueue_record() 122 | assert_true(self.Model._storage[0].insert.called) 123 | 124 | def test_save_queue(self): 125 | self.Model.start_queue() 126 | self.enqueue_record() 127 | assert_false(self.Model._storage[0].insert.called) 128 | 129 | def test_clear_queue(self): 130 | self.Model.start_queue() 131 | self.Model.clear_queue() 132 | assert_false(self.Model.queue) 133 | assert_false(self.Model.queue.active) 134 | 135 | def test_cancel_queue(self): 136 | self.Model.start_queue() 137 | self.enqueue_record() 138 | self.Model.cancel_queue() 139 | assert_false(self.Model.queue) 140 | assert_false(self.Model.queue.active) 141 | assert_false(self.Model._cache) 142 | 143 | def test_cancel_queue_empty(self): 144 | self.enqueue_record() 145 | self.Model.start_queue() 146 | self.Model.cancel_queue() 147 | assert_false(self.Model.queue) 148 | assert_false(self.Model.queue.active) 149 | assert_true(self.Model._cache) 150 | 151 | def test_commit_queue(self): 152 | self.Model.start_queue() 153 | self.enqueue_record() 154 | assert_false(self.Model._storage[0].insert.called) 155 | self.Model.commit_queue() 156 | assert_true(self.Model._storage[0].insert.called) 157 | assert_false(self.Model.queue) 158 | assert_false(self.Model.queue.active) 159 | 160 | def test_commit_queue_crash(self): 161 | self.Model.start_queue() 162 | self.Model._storage[0].update.side_effect = ValueError() 163 | record = self.enqueue_record() 164 | record.value = 'error' 165 | record.save() 166 | self.Model.remove_one(record) 167 | with assert_raises(ValueError): 168 | self.Model.commit_queue() 169 | assert_true(self.Model._storage[0].insert.called) 170 | assert_true(self.Model._storage[0].update.called) 171 | assert_false(self.Model._storage[0].remove.called) 172 | assert_false(self.Model.queue) 173 | assert_false(self.Model.queue.active) 174 | 175 | def test_models_share_null_queue(self): 176 | model2 = make_model() 177 | assert_true(self.Model.queue is 
model2.queue) 178 | 179 | def test_models_share_queue(self): 180 | model2 = make_model() 181 | model2.start_queue() 182 | assert_true(self.Model.queue is model2.queue) 183 | -------------------------------------------------------------------------------- /tests/test_signals.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import mock 4 | import unittest 5 | from nose.tools import * 6 | 7 | from tests.utils import make_model 8 | 9 | from modularodm import StoredObject 10 | 11 | 12 | class TestSignals(unittest.TestCase): 13 | 14 | def setUp(self): 15 | self.model = make_model() 16 | 17 | @mock.patch('modularodm.storedobject.signals.save.connect_via') 18 | def test_subscribe_weak(self, mock_connect_via): 19 | self.model.subscribe('save', weak=True) 20 | mock_connect_via.assert_called_once_with( 21 | self.model, 22 | True, 23 | ) 24 | 25 | @mock.patch('modularodm.storedobject.signals.save.connect_via') 26 | def test_subscribe_strong(self, mock_connect_via): 27 | self.model.subscribe('save', weak=False) 28 | mock_connect_via.assert_called_once_with( 29 | self.model, 30 | False, 31 | ) 32 | 33 | @mock.patch('modularodm.storedobject.signals.save.connect_via') 34 | def test_subscribe_from_base_schema(self, mock_connect_via): 35 | StoredObject.subscribe('save', weak=False) 36 | mock_connect_via.assert_called_once_with( 37 | None, 38 | False 39 | ) 40 | 41 | def _subscribe_mock(self, signal_name, weak): 42 | """Create a mock callback function and subscribe it to the specific 43 | signal on ``self.model``. 44 | 45 | """ 46 | callback = mock.Mock() 47 | decorator = self.model.subscribe(signal_name, weak) 48 | return decorator(callback) 49 | 50 | def test_before_save(self): 51 | 52 | connected_callback = self._subscribe_mock('before_save', weak=False) 53 | 54 | record = self.model(_id=1) 55 | record.save() 56 | 57 | connected_callback.assert_called_once_with( 58 | record.__class__, instance=record, 59 | ) 60 | 61 | def test_save(self): 62 | 63 | connected_callback = self._subscribe_mock('save', weak=False) 64 | 65 | record = self.model(_id=1) 66 | record.save() 67 | 68 | connected_callback.assert_called_once_with( 69 | record.__class__, 70 | instance=record, 71 | fields_changed={'_id', 'value'}, 72 | cached_data={}, 73 | ) 74 | 75 | def test_load(self): 76 | 77 | connected_callback = self._subscribe_mock('load', weak=False) 78 | 79 | record = self.model(_id=1) 80 | record.save() 81 | self.model.load(1) 82 | 83 | connected_callback.assert_called_once_with( 84 | record.__class__, 85 | key=1, 86 | data=None, 87 | ) 88 | -------------------------------------------------------------------------------- /tests/test_storage.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import unittest 3 | from nose.tools import * # PEP8 asserts 4 | 5 | from modularodm import StoredObject, fields, storage 6 | 7 | 8 | class User(StoredObject): 9 | _id = fields.StringField(primary=True) 10 | _meta = {'optimistic': True} 11 | 12 | 13 | class TestStorage(unittest.TestCase): 14 | 15 | def test_bad_set_storage_argument(self): 16 | assert_raises(TypeError, lambda: User.set_storage("foo")) 17 | 18 | 19 | if __name__ == '__main__': 20 | unittest.main() 21 | -------------------------------------------------------------------------------- /tests/test_storedobject.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import mock 3 | 
import unittest 4 | import datetime 5 | import os 6 | from glob import glob 7 | 8 | from nose.tools import * # PEP8 asserts 9 | 10 | from modularodm import StoredObject, fields, exceptions, storage 11 | from modularodm.validators import MinLengthValidator 12 | 13 | 14 | class Tag(StoredObject): 15 | _id = fields.StringField(primary=True) 16 | date_created = fields.DateTimeField(validate=True, auto_now_add=True) 17 | date_modified = fields.DateTimeField(validate=True, auto_now=True) 18 | value = fields.StringField(default='default', validate=MinLengthValidator(5)) 19 | keywords = fields.StringField(default=['keywd1', 'keywd2'], validate=MinLengthValidator(5), list=True) 20 | _meta = {'optimistic':True} 21 | 22 | Tag.set_storage(storage.PickleStorage('tag')) 23 | 24 | 25 | class User(StoredObject): 26 | _id = fields.StringField(primary=True) 27 | _meta = {"optimistic": True} 28 | name = fields.StringField() 29 | 30 | 31 | class Comment(StoredObject): 32 | _id = fields.StringField(primary=True) 33 | text = fields.StringField() 34 | user = fields.ForeignField("user", backref="comments") 35 | _meta = {"optimistic": True} 36 | 37 | 38 | User.set_storage(storage.PickleStorage("test_user", prefix=None)) 39 | Comment.set_storage(storage.PickleStorage("test_comment", prefix=None)) 40 | 41 | 42 | class TestStoredObject(unittest.TestCase): 43 | 44 | @staticmethod 45 | def clear_pickle_files(): 46 | for fname in glob("*.pkl"): 47 | os.remove(fname) 48 | 49 | def setUp(self): 50 | self.clear_pickle_files() 51 | 52 | def tearDown(self): 53 | for fname in glob("*.pkl"): 54 | os.remove(fname) 55 | 56 | def test_string_default(self): 57 | """ Make sure the default option works for StringField fields. """ 58 | tag = Tag() 59 | self.assertEqual(tag.value, 'default') 60 | 61 | @unittest.skip('needs review') 62 | def test_stringlist_default(self): 63 | tag = Tag() 64 | self.assertEqual(tag.keywords[0], 'keywd1') 65 | self.assertEqual(tag.keywords[1], 'keywd2') 66 | 67 | def test_set_attribute(self): 68 | user = User() 69 | user.name = "Foo Bar" 70 | user.save() 71 | assert_equal(user.name, "Foo Bar") 72 | 73 | # Datetime tests 74 | 75 | def _times_approx_equal(self, first, second=None, tolerance=0.1): 76 | self.assertLess( 77 | abs((second or datetime.datetime.utcnow()) - first), 78 | datetime.timedelta(seconds=tolerance) 79 | ) 80 | 81 | def test_default_datetime(self): 82 | tag = Tag() 83 | tag.save() 84 | self._times_approx_equal(tag.date_created) 85 | 86 | @unittest.skip('needs review') 87 | def test_parse_datetime(self): 88 | tag = Tag() 89 | tag.date_created = 'october 1, 1985, 10:05 am' 90 | self.assertEqual(tag.date_created, datetime.datetime(1985, 10, 1, 10, 5)) 91 | 92 | @unittest.skip('needs review') 93 | def test_parse_bad_datetime(self): 94 | tag = Tag() 95 | def _(): 96 | tag.date_created = 'cant parse this!' 
97 | self.assertRaises(ValueError, _) 98 | 99 | def test_auto_now(self): 100 | tag = Tag() 101 | tag.save() 102 | self._times_approx_equal(tag.date_modified) 103 | 104 | @mock.patch('modularodm.storedobject.StoredObject.to_storage') 105 | def test_eq_same_object_does_not_call_serialize(self, mock_serialize): 106 | tag = Tag() 107 | assert_true(tag == tag) 108 | assert_false(tag != tag) 109 | assert_false(mock_serialize.called) 110 | 111 | @mock.patch('modularodm.storedobject.StoredObject.to_storage') 112 | def test_eq_different_object_different_keys_does_not_call_serialize(self, mock_serialize): 113 | tag1, tag2 = Tag(_id='foo'), Tag(_id='bar') 114 | assert_false(tag1 == tag2) 115 | assert_true(tag1 != tag2) 116 | assert_false(mock_serialize.called) 117 | 118 | @mock.patch('modularodm.storedobject.StoredObject.to_storage') 119 | def test_eq_different_object_same_keys_calls_serialize(self, mock_serialize): 120 | tag1, tag2 = Tag(_id='foo'), Tag(_id='foo') 121 | tag1 == tag2 122 | assert_true(mock_serialize.called) 123 | 124 | # Foreign tests 125 | def test_foreign_many_to_one_set(self): 126 | pass 127 | 128 | def test_foreign_many_to_one_replace(self): 129 | pass 130 | 131 | def test_foreign_many_to_one_del(self): 132 | pass 133 | 134 | def test_foreign_many_to_many_set(self): 135 | pass 136 | 137 | def test_foreign_many_to_many_setitem(self): 138 | pass 139 | 140 | def test_foreign_many_to_many_insert(self): 141 | pass 142 | 143 | def test_foreign_many_to_many_append(self): 144 | pass 145 | 146 | # Query tests 147 | 148 | def test_find_all(self): 149 | pass 150 | 151 | def test_find_one_no_results(self): 152 | pass 153 | 154 | def test_find_one_multiple_results(self): 155 | pass 156 | 157 | def test_find_one_one_result(self): 158 | pass 159 | 160 | def test_find_exact_match(self): 161 | pass 162 | 163 | def test_find_value_property(self): 164 | pass 165 | 166 | def test_find_operator_method(self): 167 | pass 168 | 169 | def test_infer_primary_if_none_explicit(self): 170 | """ """ 171 | class Schema(StoredObject): 172 | _id = fields.StringField() 173 | assert_true(Schema._fields['_id']._is_primary) 174 | assert_equal(Schema._primary_name, '_id') 175 | 176 | def test_dont_infer_primary_if_explicit(self): 177 | """ """ 178 | class Schema(StoredObject): 179 | _id = fields.StringField() 180 | primary = fields.StringField(primary=True) 181 | assert_true(Schema._fields['primary']._is_primary) 182 | assert_false(Schema._fields['_id']._is_primary) 183 | assert_equal(Schema._primary_name, 'primary') 184 | 185 | def test_cant_have_multiple_primary_keys(self): 186 | """Defining a schema with more than one primary key 187 | should raise an AttributeError. 
188 | 189 | """ 190 | with assert_raises(AttributeError): 191 | class BadObject(StoredObject): 192 | _id = fields.StringField(primary=True) 193 | another_id = fields.StringField(primary=True) 194 | 195 | def test_must_have_primary_key_not_abstract(self): 196 | with assert_raises(AttributeError): 197 | class NoPK(StoredObject): 198 | dummy = fields.StringField() 199 | 200 | def test_can_have_no_primary_key_if_abstract(self): 201 | try: 202 | class AbstractSchema(StoredObject): 203 | field = fields.StringField() 204 | _meta = { 205 | 'abstract': True, 206 | } 207 | except: 208 | assert False 209 | 210 | def test_cannot_instantiate_abstract_schema(self): 211 | with assert_raises(TypeError): 212 | class AbstractSchema(StoredObject): 213 | field = fields.StringField() 214 | _meta = { 215 | 'abstract': True, 216 | } 217 | abstract_instance = AbstractSchema() 218 | 219 | def test_must_be_loaded(self): 220 | user = User() 221 | assert_raises(exceptions.DatabaseError, lambda: Comment(user=user)) 222 | 223 | def test_must_be_loaded_list(self): 224 | """ Assigning an object that has not been saved to a foreign list field should throw an exception. """ 225 | pass 226 | 227 | def test_has_storage(self): 228 | """ Calling save on an object without an attached storage should throw an exception. """ 229 | class NoStorage(StoredObject): 230 | _id = fields.StringField(primary=True) 231 | obj = NoStorage() 232 | assert_raises(exceptions.ImproperConfigurationError, 233 | lambda: obj.save()) 234 | 235 | def test_storage_type(self): 236 | """ Assigning a non-Storage object in set_storage should throw an exception. """ 237 | pass 238 | 239 | def test_cannot_save_detached_object(self): 240 | user = User() 241 | user._detached = True 242 | assert_raises(exceptions.DatabaseError, lambda: user.save()) 243 | 244 | def test_eq(self): 245 | user = User(name="Foobar") 246 | user.save() 247 | same_user = User.load(user._primary_key) 248 | assert_equal(user, same_user) 249 | different = User(name="Barbaz") 250 | assert_not_equal(user, different) 251 | assert_not_equal(None, user) 252 | 253 | 254 | if __name__ == '__main__': 255 | unittest.main() 256 | -------------------------------------------------------------------------------- /tests/test_subclass.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from nose.tools import * # PEP8 asserts 3 | 4 | from modularodm import StoredObject, fields 5 | 6 | def test_subclass(): 7 | """Test that fields are inherited from superclasses. 
8 | 9 | """ 10 | class ParentSchema(StoredObject): 11 | _id = fields.StringField() 12 | parent_field = fields.StringField() 13 | 14 | class ChildSchema(ParentSchema): 15 | pass 16 | 17 | assert_in('parent_field', ParentSchema._fields) 18 | assert_in('parent_field', ChildSchema._fields) 19 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | # use unittest.mock in Python3 2 | try: 3 | from unittest import mock 4 | except ImportError: 5 | import mock 6 | 7 | from modularodm import StoredObject, fields, storage 8 | from modularodm.translators import DefaultTranslator 9 | 10 | 11 | def make_mock_storage(): 12 | mock_storage = mock.create_autospec(storage.Storage)() 13 | mock_storage.translator = DefaultTranslator() 14 | return mock_storage 15 | 16 | 17 | def make_model(): 18 | class Model(StoredObject): 19 | _id = fields.IntegerField(primary=True) 20 | value = fields.StringField() 21 | Model.set_storage(make_mock_storage()) 22 | return Model 23 | 24 | -------------------------------------------------------------------------------- /tests/validators/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cos-archives/modular-odm/8a34891892b8af69b21fdc46701c91763a5c1cf9/tests/validators/__init__.py -------------------------------------------------------------------------------- /tests/validators/test_iterable_validators.py: -------------------------------------------------------------------------------- 1 | from modularodm import StoredObject 2 | from modularodm.exceptions import ValidationValueError 3 | from modularodm.fields import IntegerField, StringField 4 | from modularodm.validators import MaxLengthValidator, MinLengthValidator 5 | 6 | from tests.base import ModularOdmTestCase 7 | 8 | class StringValidatorTestCase(ModularOdmTestCase): 9 | 10 | def define_objects(self): 11 | class Foo(StoredObject): 12 | _id = IntegerField() 13 | test_field_max = StringField( 14 | list=False, 15 | validate=[MaxLengthValidator(5), ] 16 | ) 17 | test_field_min = StringField( 18 | list=False, 19 | validate=[MinLengthValidator(5), ] 20 | ) 21 | self.test_object = Foo(_id=0) 22 | return Foo, 23 | 24 | def test_max_length_string_validator(self): 25 | 26 | self.test_object.test_field_max = 'abc' 27 | self.test_object.save() 28 | 29 | self.test_object.test_field_max = 'abcdefg' 30 | with self.assertRaises(ValidationValueError): 31 | self.test_object.save() 32 | 33 | def test_min_length_string_validator(self): 34 | 35 | self.test_object.test_field_min = 'abc' 36 | with self.assertRaises(ValidationValueError): 37 | self.test_object.save() 38 | 39 | self.test_object.test_field_min = 'abcdefg' 40 | self.test_object.save() 41 | 42 | 43 | class ListValidatorTestCase(ModularOdmTestCase): 44 | 45 | def define_objects(self): 46 | class Foo(StoredObject): 47 | _id = IntegerField() 48 | test_field_max = IntegerField( 49 | list=True, 50 | list_validate=[MaxLengthValidator(5), ] 51 | ) 52 | test_field_min = IntegerField( 53 | list=True, 54 | list_validate=[MinLengthValidator(3), ] 55 | ) 56 | self.test_object = Foo(_id=0) 57 | return Foo, 58 | 59 | def test_min_length_list_validator(self): 60 | # This test fails. 
61 | 62 | self.test_object.test_field_min = [1, 2] 63 | with self.assertRaises(ValidationValueError): 64 | self.test_object.save() 65 | 66 | self.test_object.test_field_min = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0] 67 | self.test_object.save() 68 | 69 | def test_max_length_list_validator(self): 70 | # This test fails. 71 | 72 | self.test_object.test_field_min = [1, 2, 3] 73 | self.test_object.test_field_max = [1, 2, 3] 74 | self.test_object.save() 75 | 76 | self.test_object.test_field_max = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0] 77 | with self.assertRaises(ValidationValueError): 78 | self.test_object.save() 79 | 80 | 81 | class IterableValidatorCombinationTestCase(ModularOdmTestCase): 82 | def define_objects(self): 83 | class Foo(StoredObject): 84 | _id = IntegerField() 85 | test_field = StringField( 86 | list=True, 87 | validate=MaxLengthValidator(3), 88 | list_validate=MinLengthValidator(3) 89 | ) 90 | self.test_object = Foo(_id=0) 91 | return Foo, 92 | 93 | def test_child_pass_list_fail(self): 94 | self.test_object.test_field = ['ab', 'abc'] 95 | 96 | with self.assertRaises(ValidationValueError): 97 | self.test_object.save() 98 | 99 | def test_child_fail_list_pass(self): 100 | self.test_object.test_field = ['ab', 'abcd', 'adc'] 101 | 102 | with self.assertRaises(ValidationValueError): 103 | self.test_object.save() 104 | 105 | def test_child_fail_list_fail(self): 106 | self.test_object.test_field = ['ab', 'abdc'] 107 | 108 | with self.assertRaises(ValidationValueError): 109 | self.test_object.save() 110 | 111 | -------------------------------------------------------------------------------- /tests/validators/test_numeric_validators.py: -------------------------------------------------------------------------------- 1 | from modularodm import StoredObject 2 | from modularodm.exceptions import ValidationValueError 3 | from modularodm.fields import FloatField, IntegerField 4 | from modularodm.validators import MinValueValidator, MaxValueValidator 5 | 6 | from tests.base import ModularOdmTestCase 7 | 8 | class IntValueValidatorTestCase(ModularOdmTestCase): 9 | 10 | def test_min_value_int_validator(self): 11 | 12 | class Foo(StoredObject): 13 | _id = IntegerField() 14 | int_field = IntegerField( 15 | list=False, 16 | validate=[MinValueValidator(5), ] 17 | ) 18 | Foo.set_storage(self.make_storage()) 19 | 20 | test_object = Foo() 21 | test_object.int_field = 10 22 | test_object.save() 23 | 24 | test_object.int_field = 0 25 | with self.assertRaises(ValidationValueError): 26 | test_object.save() 27 | 28 | def test_max_value_int_validator(self): 29 | 30 | class Foo(StoredObject): 31 | _id = IntegerField() 32 | int_field = IntegerField( 33 | list=False, 34 | validate=[MaxValueValidator(5), ] 35 | ) 36 | Foo.set_storage(self.make_storage()) 37 | 38 | test_object = Foo() 39 | test_object.int_field = 0 40 | test_object.save() 41 | 42 | test_object.int_field = 10 43 | with self.assertRaises(ValidationValueError): 44 | test_object.save() 45 | 46 | def test_min_value_float_validator(self): 47 | 48 | class Foo(StoredObject): 49 | _id = IntegerField() 50 | float_field = FloatField( 51 | list=False, 52 | validate=[MinValueValidator(5.), ] 53 | ) 54 | Foo.set_storage(self.make_storage()) 55 | 56 | test_object = Foo() 57 | test_object.float_field = 10. 58 | test_object.save() 59 | 60 | test_object.float_field = 0. 
61 | with self.assertRaises(ValidationValueError): 62 | test_object.save() 63 | 64 | def test_max_value_float_validator(self): 65 | 66 | class Foo(StoredObject): 67 | _id = IntegerField() 68 | float_field = FloatField( 69 | list=False, 70 | validate=[MaxValueValidator(5.), ] 71 | ) 72 | Foo.set_storage(self.make_storage()) 73 | 74 | test_object = Foo() 75 | test_object.float_field = 0. 76 | test_object.save() 77 | 78 | test_object.float_field = 10. 79 | with self.assertRaises(ValidationValueError): 80 | test_object.save() 81 | -------------------------------------------------------------------------------- /tests/validators/test_record_validation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from nose.tools import * 4 | 5 | from modularodm import StoredObject, fields 6 | from modularodm.exceptions import ValidationValueError 7 | 8 | from tests.base import ModularOdmTestCase 9 | 10 | 11 | class TestSchemaValidation(ModularOdmTestCase): 12 | 13 | def define_objects(self): 14 | 15 | def validate_schema(record): 16 | if record.value1 > record.value2: 17 | raise ValidationValueError 18 | 19 | class Schema(StoredObject): 20 | 21 | _id = fields.IntegerField(primary=True) 22 | value1 = fields.IntegerField() 23 | value2 = fields.IntegerField() 24 | 25 | _meta = { 26 | 'validators': [validate_schema], 27 | } 28 | 29 | return Schema, 30 | 31 | def test_save_valid(self): 32 | record = self.Schema(_id=1, value1=2, value2=3) 33 | try: 34 | record.save() 35 | except ValidationValueError: 36 | assert False 37 | 38 | def test_save_invalid(self): 39 | record = self.Schema(_id=1, value1=3, value2=2) 40 | with assert_raises(ValidationValueError): 41 | record.save() 42 | -------------------------------------------------------------------------------- /tests/validators/test_url_validation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os.path 4 | import json 5 | 6 | from modularodm import StoredObject 7 | from modularodm.exceptions import ValidationError 8 | from modularodm.fields import StringField, IntegerField 9 | from modularodm.validators import URLValidator 10 | 11 | from tests.base import ModularOdmTestCase 12 | 13 | class UrlValueValidatorTestCase(ModularOdmTestCase): 14 | 15 | def test_url(self): 16 | basepath = os.path.dirname(__file__) 17 | url_data_path = os.path.join(basepath, 'urlValidatorTest.json') 18 | with open(url_data_path) as url_test_data: 19 | data = json.load(url_test_data) 20 | 21 | class Foo(StoredObject): 22 | _id = IntegerField() 23 | url_field = StringField( 24 | list=False, 25 | validate=[URLValidator()] 26 | ) 27 | 28 | Foo.set_storage(self.make_storage()) 29 | test_object = Foo() 30 | 31 | for urlTrue in data['testsPositive']: 32 | test_object.url_field = urlTrue 33 | test_object.save() 34 | 35 | for urlFalse in data['testsNegative']: 36 | test_object.url_field = urlFalse 37 | try: 38 | with self.assertRaises(ValidationError): 39 | test_object.save() 40 | except AssertionError as e: 41 | e.args += (' for ', urlFalse) 42 | raise 43 | -------------------------------------------------------------------------------- /tests/validators/urlValidatorTest.json: -------------------------------------------------------------------------------- 1 | { 2 | "_comment": "These tests adapted from https://mathiasbynens.be/demo/url-regex", 3 | "testsPositive": { 4 | "definitelyawebsite.com":"should accept simple valid website", 5 | 
"https://Definitelyawebsite.com":"should accept valid website with protocol", 6 | "http://foo.com/blah_blah":"should accept valid website with path", 7 | "http://foo.com/blah_blah:5000/pathhere":"should accept valid website with port and path", 8 | "http://foo.com/blah_blah_(wikipedia)":"should accept valid website with parentheses in path", 9 | "https://userid@example.com/":"should accept valid user website", 10 | "http://userid@example.com:8080":"should accept valid user website with port", 11 | "http://userid:password@example.com/":"should accept valid user website with password", 12 | "http://userid:password@example.com:8080/":"should accept valid user and password website with path", 13 | "http://142.42.1.1":"should accept valid ipv4 website test 1", 14 | "http://10.1.1.0":"should accept valid ipv4 website test 2", 15 | "http://10.1.1.255":"should accept valid ipv4 website test 3", 16 | "http://224.1.1.1":"should accept valid ipv4 website test 4", 17 | "http://142.42.1.1:8080/":"should accept valid ipv4 website with port", 18 | "http://Bücher.de":"should accept valid website with unicode in domain", 19 | "http://heynow.ws/䨹":"should accept valid website with unicode in path", 20 | "http://localhost:5000/meetings":"should accept valid localhost website", 21 | "http://⌘.ws":"should accept valid website with only unicode in path", 22 | "http://⌘.ws/":"should accept valid website with only unicode in path and a / after domain", 23 | "http://foo.com/blah_(wikipedia)#cite-1":"should accept valid website with hashtag following parentheses in path", 24 | "http://foo.com/blah_(wikipedia)_blah#cite-1":"should accept valid website with hashtag in path", 25 | "http://foo.com/unicode_(✪)_in_parens":"should accept valid website with unicode in parentheses in path", 26 | "http://foo.com/(something)?after=parens":"should accept valid website with something after path", 27 | "http://staging.damowmow.com/":"should accept valid website with sub-domain", 28 | "http://☺.damowmow.com/":"should accept valid website with unicode in sub-domain", 29 | "http://code.google.com/events/#&product=browser":"should accept valid website with variables", 30 | "ftp://foo.bar/baz":"should acccept valid website with ftps", 31 | "http://foo.bar/?q=Test%20URL-encoded%20stuff":"should accept valid website with encoded stuff in path", 32 | "http://مثال.إختبار":"should accept valid unicode heavy website test 1", 33 | "http://例子.测试":"should accept valid unicode heavy website test 2", 34 | "http://उदाहरण.परीक्षा":"should accept valid unicode heavy website test 3", 35 | "http://-.~_!$&()*+,;=:%40:80%2f::::::@example.com":"should accept valid website with user but crazy username", 36 | "http://1337.net":"should accept valid website with just numbers in domain", 37 | "definitelyawebsite.com?real=yes&page=definitely":"should accept valid website with query", 38 | "http://a.b-c.de":"should accept valid website with dash", 39 | "http://website.com:3000/?q=400":"should accept valid website with port and query", 40 | "http://asd/asd@asd.com/":"should accept valid website with unusual username" 41 | }, 42 | "testsNegative": { 43 | "notevenclose": "should deny simple invalid website", 44 | "http://": "should deny invalid website with only http://", 45 | "http://.": "should deny invalid website with only http://.", 46 | "http://..": "should deny invalid website with only http://..", 47 | "http://../": "should deny invalid website with only http://../", 48 | "http://?": "should deny invalid website with only http://?", 49 | "http://??": 
"should deny invalid website with only http://??", 50 | "http://??/": "should deny invalid website with only http://??/", 51 | "http://#": "should deny invalid website with only http://#", 52 | "http://##": "should deny invalid website with only http://##", 53 | "http://##/": "should deny invalid website with only http://##/", 54 | "http://foo.bar?q=Spaces should be encoded": "should deny invalid website with spaces in path", 55 | "//": "should deny invalid website with only //", 56 | "//a": "should deny invalid website with only //a", 57 | "///a": "should deny invalid website with only ///a", 58 | "///": "should deny invalid website with only ///", 59 | "http:///a": "should deny invalid website with three / in protocol", 60 | "rdar://1234": "should deny invalid website with invalid protocol", 61 | "h://test": "should deny invalid website with missing letters from protocol", 62 | "http:// shouldfail.com": "should deny invalid website with space in beginning of domain", 63 | "http://should fail": "should deny invalid website with space in middle of domain", 64 | "http://-error-.invalid/": "should deny invalid website with dash at beginning and end of domain", 65 | "http://1.1.1.1.1": "should deny invalid ipv4 website with 5 numbers", 66 | "http://567.100.100.100": "should deny invalid ipv4 website with a number out of range", 67 | "http://-a.b.co": "should deny invalid website with dash at beginning of sub-domain", 68 | "http://.www.foo.bar/": "should deny invalid website with dot before sub-domain", 69 | "httsp://userid@example.com/": "should deny invalid website with username and invalid scheme" 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist =py27,py33,py34,pypy,pypy3 3 | [testenv] 4 | deps = -rdev-requirements.txt 5 | commands= 6 | nosetests 7 | --------------------------------------------------------------------------------