├── .circleci └── config.yml ├── .gitignore ├── README.md ├── flask_common ├── __init__.py ├── app.py ├── asserts.py ├── client.py ├── commands.py ├── crypto.py ├── db.py ├── declenum.py ├── enum.py ├── formfields.py ├── mongo │ ├── __init__.py │ ├── documents.py │ ├── fields │ │ ├── __init__.py │ │ ├── basic.py │ │ ├── crypto.py │ │ ├── id.py │ │ ├── phone.py │ │ └── tz.py │ ├── query_counters.py │ ├── querysets.py │ └── utils.py ├── test_helpers.py └── utils │ ├── __init__.py │ ├── cache.py │ ├── decorators.py │ ├── id.py │ ├── legacy.py │ ├── lists.py │ └── objects.py ├── pyproject.toml ├── requirements.txt ├── requirements_lint.txt ├── setup.py ├── tests ├── __init__.py ├── test_client.py ├── test_crypto.py ├── test_declenum.py ├── test_enum.py ├── test_formfields.py ├── test_legacy.py ├── test_mongo │ ├── __init__.py │ ├── test_documents.py │ ├── test_fields │ │ ├── __init__.py │ │ ├── test_basic.py │ │ ├── test_crypto.py │ │ ├── test_phone.py │ │ └── test_tz.py │ └── test_utils.py ├── test_python_support.py └── test_test_helpers.py └── tox.ini /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | workflows: 4 | version: 2 5 | workflow: 6 | jobs: 7 | - test-3.5 8 | - test-3.6 9 | - test-3.7 10 | - lint 11 | - black 12 | 13 | defaults: &defaults 14 | working_directory: ~/code 15 | steps: 16 | - checkout 17 | - run: 18 | name: Install dependencies 19 | command: | 20 | GIT_SSH_COMMAND="ssh -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no" \ 21 | sudo -E pip install -r requirements.txt pytest==4.6.5 22 | - run: 23 | name: Test 24 | command: PYTHONPATH=. 
pytest tests 25 | 26 | jobs: 27 | test-3.5: 28 | <<: *defaults 29 | docker: 30 | - image: circleci/python:3.5 31 | - image: mongo:3.2.19 32 | test-3.6: 33 | <<: *defaults 34 | docker: 35 | - image: circleci/python:3.6 36 | - image: mongo:3.2.19 37 | test-3.7: 38 | <<: *defaults 39 | docker: 40 | - image: circleci/python:3.7 41 | - image: mongo:3.2.19 42 | black: 43 | docker: 44 | - image: circleci/python:3.6 45 | working_directory: ~/code 46 | steps: 47 | - checkout 48 | - run: 49 | name: Install dependencies 50 | command: sudo -E pip install black 51 | - run: 52 | name: Test 53 | command: black --check . 54 | lint: 55 | docker: 56 | - image: circleci/python:3.6 57 | working_directory: ~/code 58 | steps: 59 | - checkout 60 | - run: 61 | name: Iinstall dependencies 62 | command: sudo -E pip install -r requirements_lint.txt 63 | - run: 64 | name: Test 65 | command: | 66 | flake8 --ignore F821,E203,E402,E501,W503 --select C,E,F,W,B,B950 67 | 68 | 69 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.py[co] 2 | 3 | # Packages 4 | *.egg 5 | *.egg-info 6 | dist 7 | build 8 | eggs 9 | parts 10 | bin 11 | var 12 | sdist 13 | develop-eggs 14 | .installed.cfg 15 | 16 | # Installer logs 17 | pip-log.txt 18 | 19 | # Unit test / coverage reports 20 | .coverage 21 | .tox 22 | .pytest_cache 23 | 24 | #Translations 25 | *.mo 26 | 27 | #Mr Developer 28 | .mr.developer.cfg 29 | 30 | # Environments 31 | .env 32 | .venv 33 | env/ 34 | venv/ 35 | ENV/ 36 | env.bak/ 37 | venv.bak/ 38 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | flask-common 2 | ============ 3 | 4 | A collection of very random (but useful) stuff that's too small on its own to be published as a separate package. 
5 | 6 | -------------------------------------------------------------------------------- /flask_common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/closeio/flask-common/a3893e9f2bc1801d7e8557aef3f3e3f26811d398/flask_common/__init__.py -------------------------------------------------------------------------------- /flask_common/app.py: -------------------------------------------------------------------------------- 1 | from flask import Flask 2 | 3 | 4 | class Application(Flask): 5 | def __init__(self, name, *args, **kwargs): 6 | config = kwargs.pop('config', None) 7 | super(Application, self).__init__(name, *args, **kwargs) 8 | self.config.from_object('%s.config' % name) 9 | if config is not None: 10 | self.config.update(**config) 11 | -------------------------------------------------------------------------------- /flask_common/asserts.py: -------------------------------------------------------------------------------- 1 | def response_success(response, code=None, exception_class=None): 2 | if exception_class is None: 3 | exception_class = AssertionError 4 | 5 | if ( 6 | code is None 7 | and (response.status_code >= 300 or response.status_code < 200) 8 | ) or (code and code != response.status_code): 9 | raise exception_class( 10 | 'Received %d response: %s' % (response.status_code, response.data) 11 | ) 12 | 13 | 14 | def validation_error(response, content_type='application/json'): 15 | assert content_type in response.content_type, ( 16 | 'Invalid content-type: %s' % response.content_type 17 | ) 18 | response_error(response, code=400) 19 | 20 | 21 | def response_error(response, code=None): 22 | if code is None: 23 | assert 400 <= response.status_code < 500, 'Received %d response: %s' % ( 24 | response.status_code, 25 | response.data, 26 | ) 27 | else: 28 | assert code == response.status_code, 'Received %d response: %s' % ( 29 | response.status_code, 30 | response.data, 31 | ) 32 | 
33 | 34 | def compare_req_resp(req_obj, resp_obj): 35 | for k, v in req_obj.items(): 36 | assert k in resp_obj.keys(), 'Key %r not in response (keys are %r)' % ( 37 | k, 38 | resp_obj.keys(), 39 | ) 40 | assert resp_obj[k] == v, 'Value for key %r should be %r but is %r' % ( 41 | k, 42 | v, 43 | resp_obj[k], 44 | ) 45 | -------------------------------------------------------------------------------- /flask_common/client.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import json 3 | 4 | from flask import current_app 5 | from flask.testing import FlaskClient 6 | from six import PY3 7 | from werkzeug.datastructures import Headers 8 | 9 | 10 | class Client(FlaskClient): 11 | """ 12 | Test client that supports JSON and uses the application's response class. 13 | """ 14 | 15 | def __init__(self, app, response_wrapper=None, **kwargs): 16 | if not response_wrapper: 17 | response_wrapper = app.response_class 18 | super(Client, self).__init__(app, response_wrapper, **kwargs) 19 | 20 | def open(self, *args, **kwargs): 21 | if 'json' in kwargs and 'data' not in kwargs: 22 | kwargs['data'] = json.dumps(kwargs.pop('json')) 23 | kwargs['content_type'] = 'application/json' 24 | 25 | return super(Client, self).open(*args, **kwargs) 26 | 27 | 28 | class ApiClient(Client): 29 | """ 30 | API test client that supports JSON and uses the given API key. 
31 | """ 32 | 33 | def __init__(self, app, api_key=None): 34 | self.api_key = api_key 35 | super(ApiClient, self).__init__(app, use_cookies=False) 36 | 37 | def get_headers(self, api_key): 38 | api_key = api_key or self.api_key 39 | 40 | # Make sure we're giving bytes to b64encode 41 | auth_header = base64.b64encode(('%s:' % api_key).encode()) 42 | 43 | # PY3 gives us bytes bck, Need to decode from ASCII back to str 44 | if PY3: 45 | auth_header = auth_header.decode() 46 | return Headers([('Authorization', 'Basic %s' % auth_header)]) 47 | 48 | def open(self, *args, **kwargs): 49 | # include api_key auth header in all api calls 50 | api_key = kwargs.pop('api_key', self.api_key) 51 | if 'headers' not in kwargs: 52 | kwargs['headers'] = self.get_headers(api_key) 53 | return super(ApiClient, self).open(*args, **kwargs) 54 | 55 | 56 | def local_request( 57 | view, 58 | method='GET', 59 | data=None, 60 | view_args=None, 61 | user=None, 62 | api_key=None, 63 | meta=None, 64 | request_id=None, 65 | ): 66 | """ 67 | Performs a request to the current application's view without the network 68 | overhead and without request pre and postprocessing. Returns a tuple 69 | (response_status_code, response_json_data). 70 | 71 | Examples: 72 | 73 | # List leads for a given organization (as seen by user A) 74 | local_request(LeadView(), data={ 'organization_id': 'orga_abc' }, user=user_A) 75 | 76 | # Post a note as user B 77 | local_request(NoteView(), method='POST', data={ 'organization_id': 'orga_abc', 'note': 'hello' }, user=user_B) 78 | 79 | # Update an opportunity as a user associated with an API key "abc" 80 | local_request(OpportunityView(), method='PUT', data={ 'status': 'won' }, 81 | view_args={ 'pk': 'oppo_abcd' }, api_key='abc') 82 | """ 83 | if api_key is not None and user is not None: 84 | raise TypeError( 85 | "local_request can only take an api_key or a user, not both." 
86 | ) 87 | 88 | if not view_args: 89 | view_args = {} 90 | 91 | ctx = current_app.test_request_context() 92 | ctx.request.environ[ 93 | 'REQUEST_METHOD' 94 | ] = method # we can't directly manipulate request.method (it's immutable) 95 | ctx.user = user 96 | if api_key is not None: 97 | ctx.g.api_key = api_key 98 | if data and method == 'GET': 99 | ctx.request.args = data 100 | elif data: 101 | ctx.request.data = json.dumps(data) 102 | if meta is not None: 103 | ctx.g.meta = meta 104 | if request_id is not None: 105 | ctx.g.request_id = request_id 106 | ctx.push() 107 | 108 | try: 109 | resp = view.dispatch_request(**view_args) 110 | json_data = json.loads(resp.data) 111 | except Exception: 112 | ctx.pop() 113 | raise 114 | else: 115 | ctx.pop() 116 | 117 | return resp.status_code, json_data 118 | -------------------------------------------------------------------------------- /flask_common/commands.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module allows to run Flask-Script lazily (i.e. the app doesn't need to be 3 | created until the command is run). This allows to: 4 | 5 | - Speed up commands that don't need the app object 6 | - Allow custom app configuration in commands that need customization 7 | """ 8 | 9 | from flask import Flask 10 | import flask_script 11 | from werkzeug.local import LocalProxy 12 | 13 | 14 | __all__ = ['ContextlessCommand', 'Manager', 'Test'] 15 | 16 | 17 | class FlaskProxy(LocalProxy): 18 | pass 19 | 20 | 21 | class Manager(flask_script.Manager): 22 | """ 23 | A Flask-Script manager that supports contextless commands, i.e. commands 24 | that don't require app context. When initialized lazily, contextless 25 | commands can be executed faster since the app doesn't need to be 26 | initialized. To initialize lazily, pass either a callable that returns 27 | the app object, or a dotted string to the module and function that returns 28 | the app, e.g. 'myapp.main.setup_app'. 
29 | """ 30 | 31 | def __init__(self, app=None, *args, **kwargs): 32 | """ 33 | Like flask_script.Manager, but allows you to pass a function that 34 | creates the app (in addition to passing a regular Flask app. 35 | """ 36 | if not isinstance(app, Flask): 37 | if isinstance(app, str): 38 | pkg, func_name = app.rsplit('.', 1) 39 | 40 | def create_app(*args, **kwargs): 41 | import importlib 42 | 43 | module = importlib.import_module(pkg) 44 | return getattr(module, func_name)(*args, **kwargs) 45 | 46 | app = create_app 47 | 48 | if callable(app): 49 | self._create_app = app 50 | self._cached_app = None 51 | app = FlaskProxy(self.get_or_create_app) 52 | 53 | super(Manager, self).__init__(app=app, *args, **kwargs) 54 | 55 | def get_or_create_app(self, *args, **kwargs): 56 | if self._cached_app is None: 57 | # Create app 58 | self._cached_app = self._create_app(*args, **kwargs) 59 | return self._cached_app 60 | 61 | def __call__(self, app=None, **kwargs): 62 | if app is not None: 63 | self.app = app 64 | return self.app 65 | 66 | def contextless_command(self, func): 67 | """ 68 | Function decorator for a command that doesn't require app context. 69 | """ 70 | command = ContextlessCommand(func) 71 | self.add_command(func.__name__, command) 72 | return func 73 | 74 | 75 | class ContextlessCommand(flask_script.Command): 76 | """ 77 | A Flask-Script command that doesn't require app context. 78 | """ 79 | 80 | def __call__(self, app=None, *args, **kwargs): 81 | return self.run(*args, **kwargs) 82 | 83 | 84 | class Test(flask_script.Command): 85 | """ 86 | Management command that runs tests via pytest. Any command line arguments 87 | are passed directly to pytest. 88 | 89 | When using a Manager with a lazily created app (i.e. a callable), any 90 | args/kwargs passed to the constructor will be passed to the callable. 
91 | 92 | Example: 93 | 94 | def setup_app(config=None): 95 | app = Flask('app') 96 | if config: 97 | app.config.from_object(config) 98 | return app 99 | 100 | manager = Manager(setup_app) 101 | manager.add_command('test', Test(config='config.app_testing')) 102 | 103 | """ 104 | 105 | capture_all_args = True 106 | help = 'Run tests' 107 | 108 | def __init__(self, *args, **kwargs): 109 | super(Test, self).__init__() 110 | self.app_args = args 111 | self.app_kwargs = kwargs 112 | 113 | def __call__(self, app=None, *args, **kwargs): 114 | # By default, use the manager's app object, but if a callable was 115 | # passed we can forward all args. 116 | if self.app_args or self.app_kwargs: 117 | if isinstance(app, FlaskProxy): 118 | app = self.parent.get_or_create_app( 119 | *self.app_args, **self.app_kwargs 120 | ) 121 | else: 122 | raise Exception( 123 | 'Must use flask_common.commands.Manager when ' 124 | 'passing args to the Test() command.' 125 | ) 126 | 127 | return super(Test, self).__call__(app, *args, **kwargs) 128 | 129 | def create_parser(self, *args, **kwargs): 130 | # Override the default parser so we can pass all arguments to pytest. 131 | import argparse 132 | 133 | func_stack = kwargs.pop('func_stack', ()) 134 | parent = kwargs.pop('parent', None) 135 | parser = argparse.ArgumentParser(*args, add_help=False, **kwargs) 136 | parser.set_defaults(func_stack=func_stack + (self,)) 137 | self.parser = parser 138 | self.parent = parent 139 | return parser 140 | 141 | def run(self, args): 142 | # Keep imports inlined so they're not unnecessarily imported. 143 | import pytest 144 | import sys 145 | 146 | sys.exit(pytest.main(args)) 147 | -------------------------------------------------------------------------------- /flask_common/crypto.py: -------------------------------------------------------------------------------- 1 | """This file supports versioned encrypted information. 
2 | 3 | * Version 0: marked with first byte `\x00` (deprecated) 4 | 5 | Implemented with `pycrypto`. 6 | 7 | In this version, data is returned from `aes_encrypt` in the format: 8 | 9 | [VERSION 0 byte][IV 32 bytes][Encrypted data][HMAC 32 bytes] 10 | 11 | This format comes with an erroneous implementation that used an IV of 12 | 32 bytes when AES expects IVs of 16 bytes, and the library used at the 13 | time (`pycrypto`) silently truncated the IV for us. 14 | 15 | * Version 1: marked with first byte `\x01` 16 | 17 | Implemented with `cryptography`. 18 | 19 | In this version, data is returned from `aes_encrypt` in the format: 20 | 21 | [VERSION 1 byte][IV 16 bytes][Encrypted data][HMAC 32 bytes] 22 | 23 | This version came into existence to fix the wrong-sized IVs from 24 | version 0. 25 | 26 | Current code decrypts from and encrypts to version 1 only. 27 | 28 | In CTR mode, IV is also often called a Nonce (in `cryptography`'s 29 | public interface, for example). 30 | """ 31 | 32 | import hashlib 33 | import hmac 34 | import os 35 | 36 | from cryptography.hazmat.backends import default_backend 37 | from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes 38 | 39 | backend = default_backend() 40 | 41 | AES_KEY_SIZE = 32 # 256 bits 42 | HMAC_KEY_SIZE = 32 # 256 bits 43 | KEY_LENGTH = AES_KEY_SIZE + HMAC_KEY_SIZE 44 | 45 | IV_SIZE = 16 # 128 bits 46 | 47 | HMAC_DIGEST = hashlib.sha256 48 | HMAC_DIGEST_SIZE = hashlib.sha256().digest_size 49 | 50 | V1_MARKER = b'\x01' 51 | 52 | 53 | class EncryptionError(Exception): 54 | pass 55 | 56 | 57 | class AuthenticationError(Exception): 58 | pass 59 | 60 | 61 | """ 62 | Helper AES encryption/decryption methods. Uses AES-CTR + HMAC for authenticated 63 | encryption. The same key/iv combination must never be reused. 64 | """ 65 | 66 | 67 | # Returns a new randomly generated AES key. 68 | def aes_generate_key(): 69 | return os.urandom(KEY_LENGTH) 70 | 71 | 72 | # Encrypt + sign using a random IV. 
73 | def aes_encrypt(key, data): 74 | assert len(key) == KEY_LENGTH, 'invalid key size' 75 | iv = os.urandom(IV_SIZE) 76 | return V1_MARKER + iv + aes_encrypt_iv(key, data, iv) 77 | 78 | 79 | # Verify + decrypt data encrypted with IV. 80 | def aes_decrypt(key, data): 81 | assert len(key) == KEY_LENGTH, 'invalid key size' 82 | 83 | # In Python 3, if you extract a single byte from a bytestring, 84 | # you'll get an int. That's why we extract it using a slice. 85 | extracted_version = data[0:1] 86 | data = data[1:] 87 | 88 | if extracted_version == V1_MARKER: 89 | iv = data[:IV_SIZE] 90 | data = data[IV_SIZE:] 91 | else: 92 | raise EncryptionError( 93 | 'Found invalid version marker: {!r}'.format(extracted_version) 94 | ) 95 | 96 | return aes_decrypt_iv(key, data, iv) 97 | 98 | 99 | # Encrypt + sign using provided IV. 100 | # Note: You should normally use aes_encrypt(). 101 | def aes_encrypt_iv(key, data, iv): 102 | aes_key = key[:AES_KEY_SIZE] 103 | hmac_key = key[AES_KEY_SIZE:] 104 | encryptor = Cipher( 105 | algorithms.AES(aes_key), modes.CTR(iv), backend=backend 106 | ).encryptor() 107 | cipher = encryptor.update(data) + encryptor.finalize() 108 | sig = hmac.new(hmac_key, iv + cipher, HMAC_DIGEST).digest() 109 | return cipher + sig 110 | 111 | 112 | # Verify + decrypt using provided IV. 113 | # Note: You should normally use aes_decrypt(). 
114 | def aes_decrypt_iv(key, data, iv): 115 | aes_key = key[:AES_KEY_SIZE] 116 | hmac_key = key[AES_KEY_SIZE:] 117 | cipher = data[:-HMAC_DIGEST_SIZE] 118 | sig = data[-HMAC_DIGEST_SIZE:] 119 | if hmac.new(hmac_key, iv + cipher, HMAC_DIGEST).digest() != sig: 120 | raise AuthenticationError('message authentication failed') 121 | decryptor = Cipher( 122 | algorithms.AES(aes_key), modes.CTR(iv), backend=backend 123 | ).decryptor() 124 | return decryptor.update(cipher) + decryptor.finalize() 125 | -------------------------------------------------------------------------------- /flask_common/db.py: -------------------------------------------------------------------------------- 1 | from bson import ObjectId 2 | import uuid 3 | import sqlalchemy as db 4 | from sqlalchemy.dialects.postgresql import UUID 5 | from sqlalchemy.ext.declarative import declared_attr 6 | from sqlalchemy.orm import relationship, synonym 7 | 8 | __all__ = [ 9 | 'MongoReference', 10 | 'MongoEmbedded', 11 | 'MongoEmbeddedList', 12 | 'Base', 13 | 'UserBase', 14 | ] 15 | 16 | 17 | def MongoReference(field, ref_cls, queryset=None): 18 | """ 19 | SQLA field that represents a reference to a MongoEngine document. 20 | 21 | The value is cached until an assignment is made. 22 | 23 | To use a custom QuerySet (instead of the default `ref_cls.objects`), 24 | pass it as the `queryset` kwarg. You can also pass a function that 25 | resolves to a QuerySet. 
26 | """ 27 | 28 | def _resolve_queryset(): 29 | if queryset is None: 30 | return ref_cls.objects 31 | else: 32 | return queryset() 33 | 34 | def _get(obj): 35 | qs = _resolve_queryset() 36 | if not hasattr(obj, '_%s__cache' % field): 37 | ref_id = getattr(obj, field) 38 | if ref_id is None: 39 | ref = None 40 | else: 41 | ref = qs.get(pk=ref_id) 42 | setattr(obj, '_%s__cache' % field, ref) 43 | return getattr(obj, '_%s__cache' % field) 44 | 45 | def _set(obj, val): 46 | if hasattr(obj, '_%s__cache' % field): 47 | delattr(obj, '_%s__cache' % field) 48 | if isinstance(val, ref_cls): 49 | val = val.pk 50 | if isinstance(val, ObjectId): 51 | val = str(val) 52 | setattr(obj, field, val) 53 | 54 | return synonym(field, descriptor=property(_get, _set)) 55 | 56 | 57 | def MongoEmbedded(field, emb_cls): 58 | """ 59 | SQLA field that represents a MongoEngine embedded document. 60 | 61 | Converts the JSON value to/from an EmbeddedDocument. Note that a new 62 | instance is returned every time we access and we must reassign any changes 63 | back to the model. 
64 | """ 65 | 66 | def _get(obj): 67 | return emb_cls._from_son(getattr(obj, field)) 68 | 69 | def _set(obj, val): 70 | setattr(obj, field, val.to_mongo()) 71 | 72 | return synonym(field, descriptor=property(_get, _set)) 73 | 74 | 75 | def MongoEmbeddedList(field, emb_cls): 76 | """SQLA field that represents a list of MongoEngine embedded documents.""" 77 | 78 | def _get(obj): 79 | return [emb_cls._from_son(item) for item in getattr(obj, field)] 80 | 81 | def _set(obj, val): 82 | setattr(obj, field, [item.to_mongo() for item in val]) 83 | 84 | return synonym(field, descriptor=property(_get, _set)) 85 | 86 | 87 | # From https://code.launchpad.net/~stefanor/ibid/sqlalchemy-0.6-trunk/+merge/66033 88 | class PGSQLModeListener(object): 89 | def connect(self, dbapi_con, con_record): 90 | c = dbapi_con.cursor() 91 | c.execute("SET TIME ZONE UTC") 92 | c.close() 93 | 94 | 95 | class Base(object): 96 | id = db.Column(UUID, default=lambda: str(uuid.uuid4()), primary_key=True) 97 | created_at = db.Column(db.DateTime(), default=db.func.now()) 98 | updated_at = db.Column( 99 | db.DateTime(), default=db.func.now(), onupdate=db.func.now() 100 | ) 101 | 102 | @property 103 | def pk(self): 104 | return self.id 105 | 106 | __mapper_args__ = {'order_by': db.desc('updated_at')} 107 | 108 | 109 | class UserBase(Base): 110 | created_by_id = declared_attr( 111 | lambda cls: db.Column( 112 | UUID, db.ForeignKey('user.id'), default=cls._get_current_user 113 | ) 114 | ) 115 | created_by = declared_attr( 116 | lambda cls: relationship( 117 | 'User', primaryjoin='%s.created_by_id == User.id' % cls.__name__ 118 | ) 119 | ) 120 | updated_by_id = declared_attr( 121 | lambda cls: db.Column( 122 | UUID, 123 | db.ForeignKey('user.id'), 124 | default=cls._get_current_user, 125 | onupdate=cls._get_current_user, 126 | ) 127 | ) 128 | updated_by = declared_attr( 129 | lambda cls: relationship( 130 | 'User', primaryjoin='%s.updated_by_id == User.id' % cls.__name__ 131 | ) 132 | ) 133 | 134 | 
@classmethod 135 | def _get_current_user(cls): 136 | return None 137 | -------------------------------------------------------------------------------- /flask_common/declenum.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy.types import SchemaType, TypeDecorator, Enum 2 | import re 3 | 4 | 5 | class DeclEnumType(SchemaType, TypeDecorator): 6 | """ 7 | DeclEnumType supports object instantiation in two different ways: 8 | 9 | Passing in an enum: 10 | This is to be used in the application code. It will pull enum values straight from 11 | the DeclEnum object. A helper for this is available in DeclEnum.db_type() 12 | 13 | Passing in a tuple with enum values: 14 | In migrations the enum value list needs to be fix. It should not be pulled in from 15 | the application code, otherwise later modifications of enum values could result in 16 | those values being added in an earlier migration when re-running migrations from the 17 | beginning. Therefore DeclEnum(enum_values=('one', 'two'), enum_name='MyEnum') should 18 | be used. 
19 | 20 | """ 21 | 22 | def __init__(self, enum=None, enum_values=None, enum_name=None): 23 | self.enum = enum 24 | self.enum_values = enum_values 25 | self.enum_name = enum_name 26 | 27 | if enum: 28 | self.enum_values = enum.values() 29 | self.enum_name = enum.__name__ 30 | 31 | self.impl = Enum( 32 | *self.enum_values, 33 | name="ck%s" 34 | % re.sub( 35 | '([A-Z])', lambda m: "_" + m.group(1).lower(), self.enum_name 36 | ) 37 | ) 38 | 39 | def create(self, bind=None, checkfirst=False): 40 | """Issue CREATE ddl for this type, if applicable.""" 41 | super(DeclEnumType, self).create(bind, checkfirst) 42 | t = self.dialect_impl(bind.dialect) 43 | if t.impl.__class__ is not self.__class__ and isinstance(t, SchemaType): 44 | t.impl.create(bind=bind, checkfirst=checkfirst) 45 | 46 | def _set_table(self, table, column): 47 | self.impl._set_table(table, column) 48 | 49 | def copy(self): 50 | if self.enum: 51 | return DeclEnumType(self.enum) 52 | else: 53 | return DeclEnumType( 54 | enum_name=self.enum_name, enum_values=self.enum_values 55 | ) 56 | 57 | def process_bind_param(self, value, dialect): 58 | if value is None: 59 | return None 60 | return value.value 61 | 62 | def process_result_value(self, value, dialect): 63 | if value is None: 64 | return None 65 | return self.enum.from_string(value.strip()) 66 | 67 | 68 | class EnumSymbol(object): 69 | """Define a fixed symbol tied to a parent class.""" 70 | 71 | def __init__(self, cls_, name, value, description): 72 | self.cls_ = cls_ 73 | self.name = name 74 | self.value = value 75 | self.description = description 76 | 77 | def __reduce__(self): 78 | """Allow unpickling to return the symbol 79 | linked to the DeclEnum class.""" 80 | return getattr, (self.cls_, self.name) 81 | 82 | def __iter__(self): 83 | return iter([self.value, self.description]) 84 | 85 | def __repr__(self): 86 | return "<%s>" % self.name 87 | 88 | 89 | class EnumMeta(type): 90 | """Generate new DeclEnum classes.""" 91 | 92 | def __init__(cls, 
classname, bases, dict_): 93 | cls._reg = reg = cls._reg.copy() 94 | for k, v in dict_.items(): 95 | if isinstance(v, tuple): 96 | sym = reg[v[0]] = EnumSymbol(cls, k, *v) 97 | setattr(cls, k, sym) 98 | return type.__init__(cls, classname, bases, dict_) 99 | 100 | def __iter__(cls): 101 | return iter(cls._reg.values()) 102 | 103 | 104 | class DeclEnum(metaclass=EnumMeta): 105 | """ 106 | Declarative enumeration. 107 | --- 108 | For information on internals, see: http://techspot.zzzeek.org/2011/01/14/the-enum-recipe/ 109 | 110 | Usage: 111 | 112 | from flask_common.declenum import DeclEnum 113 | 114 | class Colors(DeclEnum): 115 | blue = 'blue', 'Blue color' 116 | red = 'red', 'Red color' 117 | 118 | 119 | color = Colors.red 120 | color == Colors.red 121 | color.value == 'red' 122 | color == Colors.from_string('red') 123 | Colors.red.description == 'Red Color' 124 | Colors.red.value = 'red' 125 | 126 | Usage in SQLAlchemy: 127 | color = sql.Column(Colors.db_type(), default=Colors.red) 128 | 129 | 130 | """ 131 | 132 | _reg = {} 133 | 134 | @classmethod 135 | def from_string(cls, value): 136 | try: 137 | return cls._reg[value] 138 | except KeyError: 139 | raise ValueError("Invalid value for %r: %r" % (cls.__name__, value)) 140 | 141 | @classmethod 142 | def values(cls): 143 | return cls._reg.keys() 144 | 145 | @classmethod 146 | def db_type(cls): 147 | return DeclEnumType(cls) 148 | -------------------------------------------------------------------------------- /flask_common/enum.py: -------------------------------------------------------------------------------- 1 | class Enum(object): 2 | """ 3 | A list of constants that can be defined in a declarative way. 4 | 5 | Example usage: 6 | 7 | class MyEnum(Enum): 8 | Choice1 = 'value1' 9 | Choice2 = 'value2' 10 | 11 | In this case, we can refer to the choices as MyEnum.Choice1 or 12 | MyEnum.Choice2, and don't have to reference the actual string value, which 13 | is prone to typos. 
14 | """ 15 | 16 | # Cached values and choices to avoid introspection on every call. 17 | __values = [] 18 | __choices = [] 19 | 20 | @classmethod 21 | def values(cls): 22 | """ 23 | Returns a list of all the values, e.g.: ('choice1', 'choice2') 24 | """ 25 | if not cls.__values: 26 | cls.__values = [ 27 | getattr(cls, v) 28 | for v in dir(cls) 29 | if not callable(getattr(cls, v)) and not v.startswith('_') 30 | ] 31 | 32 | return cls.__values 33 | 34 | @classmethod 35 | def choices(cls): 36 | """ 37 | Returns a list of choice tuples, e.g.: 38 | [('value1', 'Choice1'), ('value2', 'Choice2')] 39 | """ 40 | if not cls.__choices: 41 | cls.__choices = [ 42 | (getattr(cls, v), v) 43 | for v in dir(cls) 44 | if not callable(getattr(cls, v)) and not v.startswith('_') 45 | ] 46 | 47 | return cls.__choices 48 | -------------------------------------------------------------------------------- /flask_common/formfields.py: -------------------------------------------------------------------------------- 1 | import dateutil.parser 2 | from wtforms.fields import DateTimeField 3 | 4 | 5 | class BetterDateTimeField(DateTimeField): 6 | """ Like DateTimeField, but uses dateutil.parser to parse the date """ 7 | 8 | def process_formdata(self, valuelist): 9 | if valuelist: 10 | date_str = u' '.join(valuelist) 11 | # dateutil returns the current day if passing an empty string. 
# NOTE(review): this chunk opens with the tail of a truncated form-field
# definition (dateutil-based date parsing) whose enclosing `def` is not
# visible here; preserved verbatim as a comment rather than guessed at:
#     if date_str.strip():
#         try:
#             self.data = dateutil.parser.parse(date_str)
#         except ValueError:
#             self.data = None
#             raise
#     else:
#         self.data = None
#
# (flask_common/mongo/__init__.py is elided in this dump -- only a raw URL.)

# --- flask_common/mongo/documents.py ----------------------------------------

import datetime
import os

from mongoengine import (
    BooleanField,
    DateTimeField,
    Document,
    OperationError,
    QuerySet,
    StringField,
    ValidationError,
    queryset_manager,
)
from zbase62 import zbase62

from .querysets import NotDeletedQuerySet


class StringIdField(StringField):
    """StringField that rejects any non-`str` value (e.g. ObjectIds)."""

    def to_mongo(self, value):
        if not isinstance(value, str):
            raise ValidationError(
                errors={
                    self.name: ['StringIdField only accepts string values.']
                }
            )
        return super(StringIdField, self).to_mongo(value)


class RandomPKDocument(Document):
    """Abstract document whose PK is "<prefix>_<zbase62-encoded random bytes>"."""

    id = StringIdField(primary_key=True)

    def __repr__(self):
        return '<%s: %s>' % (self.__class__.__name__, self.id)

    @classmethod
    def get_pk_prefix(cls):
        # First four characters of the collection name.
        return cls._get_collection_name()[:4]

    @classmethod
    def _generate_pk(cls):
        return '%s_%s' % (cls.get_pk_prefix(), zbase62.b2a(os.urandom(32)))

    def save(self, *args, **kwargs):
        """Save, generating a random PK if unset; retry on PK collision."""
        old_id = self.id

        # Don't cascade saves by default.
        kwargs['cascade'] = kwargs.get('cascade', False)

        try:
            if not self.id:
                self.id = self._generate_pk()

                # Throw an exception if another object with this id already
                # exists.
                kwargs['force_insert'] = True

                # But don't do that when cascading.
                kwargs['cascade_kwargs'] = {'force_insert': False}

            return super(RandomPKDocument, self).save(*args, **kwargs)
        except OperationError as err:
            self.id = old_id

            # Use "startswith" instead of "in". Otherwise, if a free form
            # StringField had a unique constraint someone could inject that
            # string into the error message.
            if str(err).startswith(
                'Tried to save duplicate unique keys (E11000 duplicate key error index: %s.%s.$_id_ '
                % (self._get_db().name, self._get_collection_name())
            ):
                # PK collision: retry with a freshly generated id.
                return self.save(*args, **kwargs)
            else:
                raise

    meta = {'abstract': True}


class DocumentBase(Document):
    """Abstract document maintaining date_created/date_updated timestamps."""

    date_created = DateTimeField(required=True)
    date_updated = DateTimeField(required=True)

    meta = {'abstract': True}

    def _type(self):
        return str(self.__class__.__name__)

    def save(self, *args, **kwargs):
        # Pass update_date=False to leave the timestamps untouched.
        update_date = kwargs.pop('update_date', True)
        kwargs['cascade'] = kwargs.get('cascade', False)
        if update_date:
            now = datetime.datetime.utcnow()
            if not self.date_created:
                self.date_created = now
            self.date_updated = now
        return super(DocumentBase, self).save(*args, **kwargs)

    def modify(self, *args, **kwargs):
        update_date = kwargs.pop('update_date', True)
        if update_date and 'set__date_updated' not in kwargs:
            kwargs['set__date_updated'] = datetime.datetime.utcnow()
        return super(DocumentBase, self).modify(*args, **kwargs)

    def update(self, *args, **kwargs):
        update_date = kwargs.pop('update_date', True)
        if update_date and 'set__date_updated' not in kwargs:
            kwargs['set__date_updated'] = datetime.datetime.utcnow()
        # Fix: propagate MongoEngine's return value (the number of updated
        # documents), consistent with modify() above. Previously this
        # implicitly returned None.
        return super(DocumentBase, self).update(*args, **kwargs)


class SoftDeleteDocument(Document):
    """Abstract document soft-deleted via the is_deleted flag."""

    is_deleted = BooleanField(default=False, required=True)

    def modify(self, **kwargs):
        if 'set__is_deleted' in kwargs and kwargs['set__is_deleted'] is None:
            raise ValidationError('is_deleted cannot be set to None')
        return super(SoftDeleteDocument, self).modify(**kwargs)

    def update(self, **kwargs):
        if 'set__is_deleted' in kwargs and kwargs['set__is_deleted'] is None:
            raise ValidationError('is_deleted cannot be set to None')
        # Fix: propagate the update count, consistent with modify() above.
        return super(SoftDeleteDocument, self).update(**kwargs)

    def delete(self, **kwargs):
        # Soft-delete only if already saved.
        if self.pk:
            self.is_deleted = True
            self.modify(set__is_deleted=self.is_deleted)

    @queryset_manager
    def all_objects(doc_cls, queryset):
        # Queryset that *includes* soft-deleted documents (the default
        # NotDeletedQuerySet hides them); cached on the class.
        if not hasattr(doc_cls, '_all_objs_queryset'):
            doc_cls._all_objs_queryset = QuerySet(
                doc_cls, doc_cls._get_collection()
            )
        return doc_cls._all_objs_queryset

    meta = {'abstract': True, 'queryset_class': NotDeletedQuerySet}


# --- flask_common/mongo/fields/__init__.py ----------------------------------

# Basic fields that don't have any non-MongoEngine dependencies
from .basic import LowerEmailField, LowerStringField, TrimmedStringField

# Crypto fields
try:
    from .crypto import EncryptedBinaryField, EncryptedStringField
except ImportError:
    pass

# Phone numbers fields
try:
    from .phone import PhoneField
except ImportError:
    pass

# Timezone fields
try:
    from .tz import TimezoneField
except ImportError:
    pass

# UUID fields
try:
    from .id import IDField
except ImportError:
    pass

__all__ = [
    'LowerEmailField',
    'LowerStringField',
    'TrimmedStringField',
    'EncryptedBinaryField',
    'EncryptedStringField',
    'PhoneField',
    'TimezoneField',
    'IDField',
]


# --- flask_common/mongo/fields/basic.py -------------------------------------

from mongoengine.fields import EmailField, StringField


class TrimmedStringField(StringField):
    """StringField that strips surrounding whitespace before storing."""

    def __init__(self, *args, **kwargs):
        # A positive min_length implies the field is required.
        kwargs['required'] = (
            kwargs.get('required', False) or kwargs.get('min_length', 0) > 0
        )
        super(TrimmedStringField, self).__init__(*args, **kwargs)

    def validate(self, value):
        super(TrimmedStringField, self).validate(value)
        if self.required and not value:
            self.error('Value cannot be blank.')

    def from_python(self, value):
        return value and value.strip()

    def to_mongo(self, value):
        return self.from_python(value)


class LowerStringField(StringField):
    """StringField normalized to lowercase on read, write, and query."""

    def from_python(self, value):
        return value and value.lower()

    def to_python(self, value):
        return value and value.lower()

    def prepare_query_value(self, op, value):
        return super(LowerStringField, self).prepare_query_value(
            op, value and value.lower()
        )


class LowerEmailField(StringField):
    """Email field normalized to lowercase + stripped on read/write/query."""

    def from_python(self, value):
        return value and value.lower().strip()

    def to_python(self, value):
        return self.from_python(value)

    def prepare_query_value(self, op, value):
        return super(LowerEmailField, self).prepare_query_value(
            op, value and value.lower().strip()
        )

    def validate(self, value):
        if not EmailField.EMAIL_REGEX.match(value):
            self.error('Invalid email address: %s' % value)
        super(LowerEmailField, self).validate(value)
# --- flask_common/mongo/fields/crypto.py ------------------------------------

from bson import Binary
from mongoengine.fields import BinaryField

from flask_common.crypto import (
    KEY_LENGTH,
    AuthenticationError,
    aes_decrypt,
    aes_encrypt,
)


class EncryptedBinaryField(BinaryField):
    """
    Binary field that is transparently encrypted at rest.

    The caller always sees plaintext: values are decrypted on attribute
    access and encrypted when the document is saved. The underlying
    algorithm currently is AES-256 with HMAC-SHA256 authentication.
    """

    def __init__(self, key_or_list, *args, **kwargs):
        """
        key_or_list: a 512-bit binary string containing a 256-bit AES key
        followed by a 256-bit HMAC-SHA256 key, or a list of such keys.
        When a list is given, the first key is always used for encryption;
        the remaining keys are tried in order for decryption (useful for
        key migration).
        """
        if isinstance(key_or_list, (list, tuple)):
            self.key_list = list(key_or_list)
        else:
            self.key_list = [key_or_list]
        assert len(self.key_list) > 0, "No key provided"
        for candidate_key in self.key_list:
            assert len(candidate_key) == KEY_LENGTH, 'invalid key size'
        super(EncryptedBinaryField, self).__init__(*args, **kwargs)

    def _encrypt(self, data):
        # Always encrypt with the primary (first) key.
        return Binary(aes_encrypt(self.key_list[0], data))

    def _decrypt(self, data):
        # Try every configured key in order; surface an auth failure only
        # if none of them verifies the ciphertext.
        for candidate_key in self.key_list:
            try:
                return aes_decrypt(candidate_key, data)
            except AuthenticationError:
                pass

        raise AuthenticationError('message authentication failed')

    def to_python(self, value):
        return self._decrypt(value) if value else None

    def to_mongo(self, value):
        return self._encrypt(value) if value else None


class EncryptedStringField(EncryptedBinaryField):
    """
    Unicode-string variant of EncryptedBinaryField: UTF-8 encodes before
    encryption and decodes after decryption, so callers work with `str`.
    """

    def to_python(self, value):
        plaintext = super(EncryptedStringField, self).to_python(value)
        return plaintext.decode('utf-8') if plaintext else None

    def to_mongo(self, value):
        raw = value.encode('utf-8') if value else None
        return super(EncryptedStringField, self).to_mongo(raw)


# --- flask_common/mongo/fields/id.py ----------------------------------------

from mongoengine import UUIDField
import uuid

from flask_common.utils.id import id_to_uuid, uuid_to_id


class IDField(UUIDField):
    """
    MongoEngine field type representing a zbase62-encoded ID, stored as a
    UUID.

    IDs look like "<prefix>_<zbase62>": the prefix given to the constructor,
    an underscore, then the zbase62-encoded UUID.

    Pass autogenerate=True to assign a random ID by default.
    """

    def __init__(self, **kwargs):
        self.prefix = kwargs.pop('prefix')
        self.autogenerate = kwargs.pop('autogenerate', False)
        if self.autogenerate:
            # "default" and "autogenerate" are mutually exclusive.
            if 'default' in kwargs:
                raise RuntimeError('Can\'t use "default" with "autogenerate"')
            kwargs['default'] = self.generate_id
        super(IDField, self).__init__(**kwargs)

    def generate_id(self):
        return uuid_to_id(uuid.uuid4(), self.prefix)

    def to_python(self, value):
        # Mongo hands back a UUID; expose the prefixed string form.
        if isinstance(value, uuid.UUID):
            return uuid_to_id(value, self.prefix)
        return value

    def to_mongo(self, value):
        if isinstance(value, str):
            value = id_to_uuid(value)
        return super(IDField, self).to_mongo(value)

    def prepare_query_value(self, op, value):
        if isinstance(value, str):
            try:
                value = id_to_uuid(value)
            except ValueError:
                # Malformed IDs match nothing rather than raising.
                value = None
        return super(IDField, self).prepare_query_value(op, value)

    def validate(self, value):
        try:
            id_to_uuid(value)
        except Exception as exc:
            self.error('Could not convert to UUID: %s' % exc)


# --- flask_common/mongo/fields/phone.py -------------------------------------

import re

import phonenumbers

from mongoengine.fields import StringField


class PhoneField(StringField):
    """
    Field that performs phone number validation.
    Values are stored in the format "+14151231234x123" in MongoDB and
    displayed in the format "+1 415-123-1234 ext. 123" in Python.
    """

    def __init__(self, *args, **kwargs):
        self._strict_validation = kwargs.pop('strict', False)
        super(PhoneField, self).__init__(*args, **kwargs)

    @classmethod
    def _parse(cls, value, region=None):
        # Valid numbers don't start with the same digit(s) as their country
        # code, so strip a leading country code if present.
        country_code = phonenumbers.country_code_for_region(region)
        if country_code and value.startswith(str(country_code)):
            value = value[len(str(country_code)) :]

        parsed = phonenumbers.parse(value, region)

        # Strip an empty trailing extension marker ("ext", "x", ...).
        if parsed.country_code == 1 and len(str(parsed.national_number)) > 10:
            ext_marker = re.compile(r'.+\s*e?xt?\.?\s*$')
            if ext_marker.match(value):
                value = re.sub(r'\s*e?xt?\.?\s*$', '', value)
                reparsed = phonenumbers.parse(value, region)
                # NOTE(review): this compares the length of the PhoneNumber
                # object's repr (always >= 10), so the condition is
                # effectively always true. Presumably
                # len(str(reparsed.national_number)) >= 10 was intended --
                # confirm before changing, as a fix alters behavior.
                if len(str(reparsed)) >= 10:
                    parsed = reparsed

        return parsed

    def validate(self, value):
        # Empty values are fine for optional fields.
        if not self.required and not value:
            return None

        error_msg = 'Phone number is not valid. Please use the international format like +16505551234'
        try:
            number = PhoneField._parse(value)

            if self._strict_validation and not phonenumbers.is_valid_number(
                number
            ):
                raise phonenumbers.NumberParseException(
                    phonenumbers.NumberParseException.NOT_A_NUMBER, error_msg
                )

        except phonenumbers.NumberParseException:
            self.error(error_msg)

    def from_python(self, value):
        return PhoneField.to_raw_phone(value)

    def _get_formatted_phone(self, value, form):
        # Best effort: unparseable values pass through unchanged.
        if isinstance(value, str) and value != '':
            try:
                phone = PhoneField._parse(value)
                value = phonenumbers.format_number(phone, form)
            except phonenumbers.NumberParseException:
                pass
        return value

    def to_formatted_phone(self, value):
        return self._get_formatted_phone(
            value, phonenumbers.PhoneNumberFormat.INTERNATIONAL
        )

    def to_local_formatted_phone(self, value):
        return self._get_formatted_phone(
            value, phonenumbers.PhoneNumberFormat.NATIONAL
        )

    @classmethod
    def to_raw_phone(cls, value, region=None):
        # Normalize to E.164 plus "x<ext>" for storage; best effort.
        if isinstance(value, str) and value != '':
            try:
                number = value
                phone = PhoneField._parse(number, region)
                number = phonenumbers.format_number(
                    phone, phonenumbers.PhoneNumberFormat.E164
                )
                if phone.extension:
                    number += 'x%s' % phone.extension
                return number
            except phonenumbers.NumberParseException:
                pass
        return value

    def prepare_query_value(self, op, value):
        return super(PhoneField, self).prepare_query_value(
            op, PhoneField.to_raw_phone(value)
        )


# --- flask_common/mongo/fields/tz.py ----------------------------------------

import pytz

from mongoengine.fields import StringField

# (TimezoneField's definition continues in the next chunk of this dump.)
# --- flask_common/mongo/fields/tz.py (continued) ----------------------------

class TimezoneField(StringField):
    # Stores a pytz timezone by its string name; exposes a tzinfo object.
    def __init__(self, *args, **kwargs):
        defaults = {
            'default': 'UTC',
            'choices': tuple(zip(pytz.all_timezones, pytz.all_timezones)),
        }
        defaults.update(kwargs)
        super(TimezoneField, self).__init__(*args, **defaults)

    def to_python(self, value):
        return pytz.timezone(value)

    def to_mongo(self, value):
        return str(value)


# --- flask_common/mongo/query_counters.py -----------------------------------

from mongoengine.context_managers import query_counter


class custom_query_counter(query_counter):
    """
    Subclass of MongoEngine's query_counter context manager that also lets
    you ignore some of the collections (just extend `get_ignored_collections`).

    Initialize with `custom_query_counter(verbose=True)` for debugging.
    """

    def __init__(self, verbose=False):
        super(custom_query_counter, self).__init__()
        self.verbose = verbose

    def get_ignored_collections(self):
        # System/profiling namespaces whose queries shouldn't be counted.
        return [
            "{0}.system.indexes".format(self.db.name),
            "{0}.system.namespaces".format(self.db.name),
            "{0}.system.profile".format(self.db.name),
            "{0}.$cmd".format(self.db.name),
        ]

    def _get_queries(self):
        # Count everything except ignored collections and killcursors ops,
        # but *do* count findAndModify commands (which show up under $cmd).
        filter_query = {
            "$or": [
                {
                    "ns": {"$nin": self.get_ignored_collections()},
                    "op": {"$ne": "killcursors"},
                },
                {
                    "ns": "{0}.$cmd".format(self.db.name),
                    "command.findAndModify": {"$exists": True},
                },
            ]
        }
        return self.db.system.profile.find(filter_query)

    def _get_count(self):
        """ Get the number of queries. """
        queries = self._get_queries()
        if self.verbose:
            print('-' * 80)
            for query in queries:
                # findAndModify appear in $cmd -- we'll make them more readable
                if query['ns'].endswith('.$cmd'):
                    if 'findAndModify' in query['command']:
                        ns = '.'.join(
                            [
                                query['ns'].split('.')[0],
                                query['command']['findAndModify'],
                            ]
                        )
                        op = 'findAndModify'
                        query = query['command'].get('query')
                    else:
                        ns = query['ns']
                        op = query['op']
                        query = query['command']
                else:
                    ns = query['ns']
                    op = query['op']
                    query = query.get('query')
                print('{} [{}] {}'.format(ns, op, query))
                # NOTE(review): blank-line print assumed to sit inside the
                # loop (indentation was lost in this dump) -- confirm.
                print()
            print('-' * 80)
        # NOTE(review): Cursor.count() was removed in PyMongo 4; this code
        # presumably targets an older pymongo -- confirm before upgrading.
        count = queries.count()
        return count


# --- flask_common/mongo/querysets.py ----------------------------------------

from flask import current_app

from mongoengine import Q, QuerySet


class NotDeletedQuerySet(QuerySet):
    """QuerySet that doesn't return soft-deleted documents by default."""

    def __call__(
        self,
        q_obj=None,
        class_check=True,
        slave_okay=False,
        read_preference=None,
        **query
    ):
        # We don't use __ne=True here, because $ne isn't a selective query and
        # doesn't utilize an index in the most efficient manner. See
        # http://docs.mongodb.org/manual/faq/indexes/#using-ne-and-nin-in-a-query-is-slow-why.
        extra_q_obj = Q(is_deleted=False)
        q_obj = q_obj & extra_q_obj if q_obj else extra_q_obj
        return super(NotDeletedQuerySet, self).__call__(
            q_obj, class_check, slave_okay, read_preference, **query
        )

    def count(self, *args, **kwargs):
        # we need this hack for doc.objects.count() to exclude deleted objects
        if not getattr(self, '_not_deleted_query_applied', False):
            self = self.all()
        return super(NotDeletedQuerySet, self).count(*args, **kwargs)


class ForbiddenQueryException(Exception):
    """Exception raised by ForbiddenQueriesQuerySet"""


class ForbiddenQueriesQuerySet(QuerySet):
    """
    A queryset you can use to block some potentially dangerous queries
    just before they're sent to MongoDB. Override this queryset with a list
    of forbidden queries and then use the overridden class in a Document's
    meta['queryset_class'].

    `forbidden_queries` should be a list of dicts in the form of:
    {
        # shape of a query, e.g. `{"_cls": {"$in": 1}}`
        'query_shape': {...},

        # optional, forbids *all* orderings by default
        'orderings': [{key: direction, ...}, None, etc.]

        # optional, defaults to 0. Even if the query matches the shape and
        # the ordering, we allow queries with limit < `max_allowed_limit`.
        'max_allowed_limit': int or None
    }

    You can mark *any* queryset as safe with `mark_as_safe`.
    """

    forbidden_queries = None  # override this in a subclass

    _marked_as_safe = False

    def _check_for_forbidden_queries(self, idx_key=None):
        # idx_key can be a slice or an int from Doc.objects[idx_key]
        is_testing = False
        try:
            is_testing = current_app.testing
        except RuntimeError:
            # Outside of a Flask app context -- enforce the checks.
            pass

        if self._marked_as_safe or self._none or is_testing:
            return

        query_shape = self._get_query_shape(self._query)
        for forbidden in self.forbidden_queries:
            if query_shape == forbidden['query_shape'] and (
                not forbidden.get('orderings')
                or self._ordering in forbidden['orderings']
            ):

                # determine the real limit based on objects.limit or objects[idx_key]
                limit = self._limit
                if limit is None and idx_key is not None:
                    if isinstance(idx_key, slice):
                        limit = idx_key.stop
                    else:
                        limit = idx_key

                if limit is None or limit > forbidden.get(
                    'max_allowed_limit', 0
                ):
                    raise ForbiddenQueryException(
                        'Forbidden query used! Query: %s, Ordering: %s, Limit: %s'
                        % (self._query, self._ordering, limit)
                    )

    def __next__(self):
        self._check_for_forbidden_queries()
        try:
            return super(ForbiddenQueriesQuerySet, self).__next__()
        except AttributeError:
            # Older MongoEngine versions expose next() instead of __next__().
            return super(ForbiddenQueriesQuerySet, self).next()

    def __getitem__(self, key):
        self._check_for_forbidden_queries(key)
        return super(ForbiddenQueriesQuerySet, self).__getitem__(key)

    def mark_as_safe(self):
        """
        If you call Doc.objects.filter(...).mark_as_safe(), you can query by
        whatever you want (including the forbidden queries).
        """
        self._marked_as_safe = True
        return self

    def _get_query_shape(self, query):
        """
        Convert a query into a query shape, e.g.:
        * `{"_cls": "whatever"}` into `{"_cls": 1}`
        * `{"date": {"$gte": '2015-01-01', "$lte": "2015-01-31"}` into
          `{"date": {"$gte": 1, "$lte": 1}}`
        * `{"_cls": {"$in": ["a", "b", "c"]}}` into `{"_cls": {"$in": []}}`
        """
        if not query:
            return query

        query_shape = {}
        for key, val in query.items():
            if isinstance(val, dict):
                query_shape[key] = self._get_query_shape(val)
            elif isinstance(val, (list, tuple)):
                query_shape[key] = []
            else:
                query_shape[key] = 1
        return query_shape


# --- flask_common/mongo/utils.py --------------------------------------------

from flask_common.utils import grouper
from mongoengine import ListField, ReferenceField, SafeReferenceField


def iter_no_cache(query_set):
    """Iterate over a MongoEngine QuerySet without caching it.

    Useful for iterating over large result sets / bulk actions.

    If a batch size is not set, apply a sensible default of 1000
    that's better than what Mongo server is doing (101 first and
    then as many as it can fit in 4MB) to avoid cursor timeouts.
    """
    if query_set._batch_size is None:
        query_set = query_set.batch_size(1000)

    # Bind the bound method once (note: shadows the builtin `next` locally).
    next = query_set.__next__

    while True:
        try:
            yield next()
        except StopIteration:
            return


def fetch_related(
    objs,
    field_dict,
    cache_map=None,
    extra_filters=None,
    batch_size=100,
    filter_funcs=None,
):
    """
    Recursively fetches related objects for the given document instances.
    Sample usage:

    fetch_related(objs, {
        'user': True,
        'lead': {
            'created_by': True,
            'updated_by': True,
        },
        'contact': ['id'],
    })

    In this sample, users and leads for all objs will be fetched and attached.
    Then, lead.created_by and lead.updated_by users are fetched in one query
    and attached. Finally, a contact will be pulled in, only fetching the ID
    from the database.

    Note that the function doesn't merge queries for the same document class
    across multiple (recursive) function calls, but it never fetches the same
    related object twice.

    Be *very* cautious when pulling in only specific fields for a related
    object. Accessing fields that haven't been pulled will falsely show None
    even if a value for that field exists in the database.

    Given how fragile partially pulled objects are, we don't cache them in the
    cache map and hence the same related object may be fetched more than once.

    If you need to call fetch_related multiple times, it's worth passing a
    cache_map (initially it can be an empty dictionary). It will be extended
    during each call to include all the objects fetched up until the current
    call. This way we ensure that the same objects aren't fetched more than
    once across multiple fetch_related calls. Cache map has a form of:
    { DocumentClass: { id_of_fetched_obj: obj, id_of_fetched_obj2: obj2 } }.

    The function takes an optional dict extra_filters in the form
    {document_class: filters} which will be passed as filters to the QuerySet.
    This can be useful to pass a shard key filter. For example, if the Contact
    model uses organization_id as a shard key, and all contacts are expected to
    be in the same organization, you can pass:
    {Contact: {'organization_id': organization.pk}}

    The function takes an optional dict filter_funcs in the form
    {document_class: filter_func} which represents the function that is used
    to fetch and filter documents (defaults to document_class.objects.filter).
    """
    if not objs:
        return

    if extra_filters is None:
        extra_filters = {}

    if filter_funcs is None:
        filter_funcs = {}

    # Cache map holds a map of pks to objs for objects we fetched, over all
    # iterations / from previous calls, by document class (doesn't include
    # partially fetched objects)
    if cache_map is None:
        cache_map = {}

    # Cache map for partial fetches (i.e. ones where only specific fields
    # were requested). is only temporary since we don't want to cache partial
    # data through subsequent calls of this function
    partial_cache_map = {}

    # Helper mapping: field_name -> (
    #   field instance,
    #   name of the field in the db,
    #   document class,
    #   fields to fetch (or None if the whole related obj should be fetched)
    # )
    field_info = {}

    # IDs to fetch and their fetch options, by document class
    fetch_map = {}

    def id_from_value(field, val):
        # DBRef-style references wrap the pk in a DBRef object.
        if field.dbref:
            return val.id
        else:
            return val

    def get_instance_for_each_type(objs):
        # One representative instance per concrete class in objs.
        instances = []
        types = []
        for obj in objs:
            if type(obj) not in types:
                instances.append(obj)
                types.append(type(obj))
        return instances

    def setattr_unchanged(obj, key, val):
        """
        Sets an attribute on the given document object without changing the
        _changed_fields set. This is because we don't actually modify the
        related objects.
        """
        changed = key in obj._changed_fields
        setattr(obj, key, val)
        if not changed and key in obj._changed_fields:
            obj._changed_fields.remove(key)

    # Populate the field_info
    instances = get_instance_for_each_type(objs)
    for field_name, sub_field_dict in field_dict.items():

        instance = [
            instance
            for instance in instances
            if instance and field_name in instance.__class__._fields
        ]
        if not instance:
            continue  # None of the objects contains this field

        instance = instance[0]
        field = instance.__class__._fields[field_name]
        db_field = instance._db_field_map.get(field_name, field_name)
        if isinstance(field, ReferenceField):  # includes SafeReferenceListField
            document_class = field.document_type
        elif isinstance(field, ListField) and isinstance(
            field.field, ReferenceField
        ):
            document_class = field.field.document_type
        else:
            raise NotImplementedError(
                '%s class not supported for fetch_related'
                % field.__class__.__name__
            )
        fields_to_fetch = (
            sub_field_dict
            if isinstance(sub_field_dict, (list, tuple))
            else None
        )
        field_info[field_name] = (
            field,
            db_field,
            document_class,
            fields_to_fetch,
        )

    # Determine what IDs we want to fetch
    for field_name, sub_field_dict in field_dict.items():
        field, db_field, document_class, fields_to_fetch = field_info.get(
            field_name
        ) or (None, None, None, None)
        if not field:
            continue

        # we need to use _db_data for safe references because touching their
        # pks triggers a query
        if isinstance(field, SafeReferenceField):
            ids = {
                id_from_value(field, obj._db_data.get(db_field, None))
                for obj in objs
                if field_name not in obj._internal_data
                and obj._db_data.get(db_field, None)
            }
        elif isinstance(field, ListField):
            ids = [
                obj._db_data.get(db_field, [])
                for obj in objs
                if field_name not in obj._internal_data
            ]
            ids = {
                id_from_value(field.field, item)
                for sublist in ids
                for item in sublist
            }  # flatten the list of lists
        elif isinstance(field, ReferenceField):
            ids = {
                getattr(obj, field_name).pk
                for obj in objs
                if getattr(obj, field_name, None)
                and getattr(getattr(obj, field_name), '_lazy', False)
            }

        # remove ids of objects that are already in the cache map
        if document_class in cache_map:
            ids -= set(cache_map[document_class])

        # no point setting up the data structures for fields where there's nothing to fetch
        if not ids:
            continue

        # set up cache maps for the newly seen document class
        if document_class not in cache_map:
            cache_map[document_class] = {}
        if document_class not in partial_cache_map:
            partial_cache_map[document_class] = {}

        # set up a fetch map for this document class
        if document_class in fetch_map:
            fetch_map[document_class]['ids'] |= ids

            # make sure we don't allow partial fetching if the same document class
            # has conflicting fields_to_fetch (e.g. { user: ["id"], created_by: True })
            # TODO this could be improved to fetch a union of all requested fields
            if fields_to_fetch != fetch_map[document_class]['fields_to_fetch']:
                raise RuntimeError(
                    'Cannot specify different fields_to_fetch for the same document class %s'
                    % document_class
                )
        else:
            fetch_map[document_class] = {
                'ids': ids,
                'fields_to_fetch': fields_to_fetch,
            }

    # Fetch objects and cache them
    for document_class, fetch_opts in fetch_map.items():
        cls_filters = extra_filters.get(document_class, {})

        # Fetch objects in batches. Also set the batch size so we don't do
        # multiple queries per batch.
        for id_group in grouper(batch_size, list(fetch_opts['ids'])):
            filter_func = filter_funcs.get(
                document_class, document_class.objects.filter
            )
            qs = filter_func(pk__in=id_group, **cls_filters).clear_cls_query()

            # only fetch the requested fields
            if fetch_opts['fields_to_fetch']:
                qs = qs.only(*fetch_opts['fields_to_fetch'])

            # We have to apply this at the end, or only() won't work.
            qs = qs.batch_size(batch_size)

            # update the cache map - either the persistent one with full
            # objects, or the ephemeral partial cache
            update_dict = {obj.pk: obj for obj in qs}
            if fetch_opts['fields_to_fetch'] is None:
                cache_map[document_class].update(update_dict)
            else:
                partial_cache_map[document_class].update(update_dict)

    # Assign objects
    for field_name, sub_field_dict in field_dict.items():
        field, db_field, document_class, fields_to_fetch = field_info.get(
            field_name
        ) or (None, None, None, None)

        if not field:
            continue

        # merge the permanent and temporary caches for the ease of assignment
        pk_to_obj = cache_map.get(document_class, {}).copy()
        pk_to_obj.update(partial_cache_map.get(document_class, {}))

        # if a dict of subfields was passed, go recursive
        if pk_to_obj and isinstance(sub_field_dict, dict):
            fetch_related(
                list(pk_to_obj.values()), sub_field_dict, cache_map=cache_map
            )

        # attach all the values to all the objects
        for obj in objs:
            if isinstance(field, SafeReferenceField):
                if field_name not in obj._internal_data:
                    val = obj._db_data.get(db_field, None)
                    if val:
                        setattr_unchanged(
                            obj,
                            field_name,
                            pk_to_obj.get(id_from_value(field, val)),
                        )

            elif isinstance(field, ReferenceField):
                val = getattr(obj, field_name, None)
                if val and getattr(val, '_lazy', False):
                    rel_obj = pk_to_obj.get(val.pk)
                    if rel_obj:
                        setattr_unchanged(obj, field_name, rel_obj)

            elif isinstance(field, ListField):
                if field_name not in obj._internal_data:
                    value = list(
                        filter(
                            None,
                            [
                                pk_to_obj.get(id_from_value(field.field, val))
                                for val in obj._db_data.get(db_field, [])
                            ],
                        )
                    )
                    setattr_unchanged(obj, field_name, value)


# --- flask_common/test_helpers.py -------------------------------------------

import re


class SetCompare(object):
    """
    Comparator that doesn't take ordering into account. For example, the
    following expression is True:

    SetCompare([1, 2, 3]) == [2, 3, 1]
    """

    def __init__(self, members):
        self.members = members

    def __eq__(self, other):
        return set(other) == set(self.members)

    def __ne__(self, other):
        return not self == other


class RegexSetCompare(object):
    """
    Comparator that takes a regex and a set of arguments and doesn't take
    ordering of the arguments into account. For example, the following
    expression is True:

    RegexSetCompare('(.*) OR (.*) OR (.*)', ['1', '2', '3']) == '2 OR 3 OR 1'
    """

    def __init__(self, regex, args):
        self.regex = re.compile(regex)
        self.args = args

    def __eq__(self, other):
        match = self.regex.match(other)
        if not match:
            return False
        return set(match.groups()) == set(self.args)

    def __ne__(self, other):
        return not self == other

# (class Capture continues in the next chunk of this dump.)
# --- flask_common/test_helpers.py (continued) --------------------------------

class Capture(object):
    """
    Comparator that always returns True and returns the captured object when
    called. For example:

    capture = Capture()
    capture == 'Hello' # returns True
    capture() # returns 'Hello'
    """

    # NOTE(review): calling a Capture before it has been compared raises
    # AttributeError, since self.obj is only assigned in __eq__ -- presumably
    # the intended usage order; confirm before changing.

    def __call__(self):
        return self.obj

    def __eq__(self, other):
        self.obj = other
        return True


class DictCompare(dict):
    """
    Comparator that returns True if all the items in the comparator's dict are
    contained in the other dict and match values, ignoring keys that are in the
    other dict only.

    For example, the following is true:
    DictCompare({'a': 'b'}) == {'a': 'b', 'c': 'd'}

    But the following are false:
    DictCompare({'a': 'b'}) == {'a': 'c'}
    DictCompare({'a': 'b'}) == {'b': 'c'}
    """

    def __eq__(self, other):
        for key, expected in self.items():
            if key not in other or other[key] != expected:
                return False
        return True

    def __ne__(self, other):
        return not self == other


# --- flask_common/utils/__init__.py ------------------------------------------

# UUID related helpers
from .id import id_to_uuid, uuid_to_id

# List/iterator helpers
from .lists import grouper

# TODO: split these up
try:
    from .legacy import (
        CsvReader,
        CsvWriter,
        DetailedSMTPHandler,
        FileFormatException,
        NamedCsvReader,
        Reader,
        Normalization,
        NormalizationReader,
        ThreadedTimer,
        Timeout,
        Timer,
        apply_recursively,
        build_normalization_map,
        combine,
        finite_float,
        force_unicode,
        format_locals,
        json_list_generator,
        lazylist,
        localtoday,
        make_unaware,
        parse_date_tz,
        returns_xml,
        retry,
        slugify,
        smart_unicode,
        truncate,
        unicode_csv_reader,
        uniqify,
        utctime,
        utctoday,
        utf_8_encoder,
    )
except ImportError:
    pass

__all__ = [
    'id_to_uuid',
    'uuid_to_id',
    'grouper',
    'CsvReader',
    'CsvWriter',
    'DetailedSMTPHandler',
    'FileFormatException',
    'NamedCsvReader',
    'Reader',
    'Normalization',
    'NormalizationReader',
    'ThreadedTimer',
    'Timeout',
    'Timer',
    'apply_recursively',
    'build_normalization_map',
    'combine',
    'finite_float',
    'force_unicode',
    'format_locals',
    'json_list_generator',
    'lazylist',
    'localtoday',
    'make_unaware',
    'parse_date_tz',
    'returns_xml',
    'retry',
    'slugify',
    'smart_unicode',
    'truncate',
    'unicode_csv_reader',
    'uniqify',
    'utctime',
    'utctoday',
    'utf_8_encoder',
]


# --- flask_common/utils/cache.py ---------------------------------------------

from .objects import freeze

from weakref import WeakValueDictionary


class WeakValueCache(WeakValueDictionary):
    """Local in-process cache that holds an object for as long as there's a
    live (strong) reference to it elsewhere.

    Subclass and implement the lookup method, then use indexing
    (cache[key]) to retrieve values. Keys are normalized with freeze() so
    otherwise-unhashable keys can be used.
    """

    def __contains__(self, key):
        return WeakValueDictionary.__contains__(self, freeze(key))

    def __setitem__(self, key, value):
        WeakValueDictionary.__setitem__(self, freeze(key), value)

    def __getitem__(self, key):
        frozen = freeze(key)
        if frozen in self:
            return WeakValueDictionary.__getitem__(self, frozen)
        # Cache miss: compute via lookup() and memoize under the frozen key.
        value = self.lookup(key)
        self[frozen] = value
        return value

    def lookup(self, key):
        raise NotImplementedError("Implement lookup for the actual item")


# --- flask_common/utils/decorators.py ----------------------------------------

from functools import wraps


def with_context_manager(ctx_manager):
    """Decorator syntax for context managers: the decorated function runs
    inside `with ctx_manager:`."""

    def decorator(func):
        @wraps(func)
        def decorated_function(*args, **kwargs):
            with ctx_manager:
                return func(*args, **kwargs)

        return decorated_function

    return decorator


# --- flask_common/utils/id.py -------------------------------------------------

import uuid

from zbase62 import zbase62


def uuid_to_id(uuid_obj, prefix):
    # "<prefix>_<zbase62 of the UUID's 16 bytes>"
    return '{}_{}'.format(prefix, zbase62.b2a(uuid_obj.bytes))


def id_to_uuid(id_str):
    # Everything after the first underscore is the zbase62 payload.
    uuid_bytes = zbase62.a2b(str(id_str[id_str.find('_') + 1 :]))
    return uuid.UUID(bytes=uuid_bytes)


# --- flask_common/utils/legacy.py (truncated in this chunk) -------------------

import calendar
import codecs
import csv
import io
import datetime
import itertools
import math
import re
import signal
import smtplib
import threading
# (the rest of legacy.py continues beyond this chunk of the dump)
def json_list_generator(results):
    """Given an iterable of individual JSON strings, generate a JSON array.

    Yields the opening bracket, each element with separating commas, and
    the closing bracket.

    Fixes over the previous next()-driven version:
    - An empty iterable now yields '[]' instead of leaking StopIteration
      out of the generator (which is a RuntimeError under PEP 479).
    - A falsy element (e.g. '') mid-stream no longer silently truncates
      the array.
    """
    yield '['
    first = True
    for result in results:
        if first:
            first = False
        else:
            yield ','
        yield result
    yield ']'
71 | """ 72 | try: 73 | port = self.mailport 74 | if not port: 75 | port = smtplib.SMTP_PORT 76 | smtp = smtplib.SMTP(self.mailhost, port) 77 | msg = self.format(record) 78 | msg = "From: %s\nTo: %s\nSubject: %s\nDate: %s\n\n%s\n\nRequest.url: %s\n\nRequest.headers: %s\n\nRequest.args: %s\n\nRequest.form: %s\n\nRequest.data: %s\n" % ( 79 | self.fromaddr, 80 | ",".join(self.toaddrs), 81 | self.getSubject(record), 82 | formatdate(), 83 | msg, 84 | request.url, 85 | request.headers, 86 | request.args, 87 | request.form, 88 | request.data, 89 | ) 90 | if self.username: 91 | if self.secure is not None: 92 | smtp.ehlo() 93 | smtp.starttls(*self.secure) 94 | smtp.ehlo() 95 | smtp.login(self.username, self.password) 96 | smtp.sendmail(self.fromaddr, self.toaddrs, msg) 97 | smtp.quit() 98 | except (KeyboardInterrupt, SystemExit): 99 | raise 100 | except Exception: 101 | self.handleError(record) 102 | 103 | 104 | def unicode_csv_reader(unicode_csv_data, dialect=csv.excel, **kwargs): 105 | # csv.py doesn't do Unicode; encode temporarily as UTF-8: 106 | csv_reader = csv.reader( 107 | utf_8_encoder(unicode_csv_data), dialect=dialect, **kwargs 108 | ) 109 | for row in csv_reader: 110 | yield row 111 | 112 | 113 | def utf_8_encoder(unicode_csv_data): 114 | for line in unicode_csv_data: 115 | yield line 116 | 117 | 118 | class CsvReader(object): 119 | """Wrapper around csv reader that ignores non utf-8 chars and strips the 120 | record.""" 121 | 122 | def __init__(self, file_name, delimiter=','): 123 | self.reader = csv.reader(open(file_name, 'rbU'), delimiter=delimiter) 124 | 125 | def __iter__(self): 126 | return self 127 | 128 | def __next__(self): 129 | row = next(self.reader) 130 | row = [ 131 | el.decode('utf8', errors='ignore').replace('\"', '').strip() 132 | for el in row 133 | ] 134 | return row 135 | 136 | 137 | class NamedCsvReader(CsvReader): 138 | def __init__(self, *args, **kwargs): 139 | super(NamedCsvReader, self).__init__(*args, **kwargs) 140 | self.headers = 
next(super(NamedCsvReader, self)) 141 | 142 | def __next__(self): 143 | row = next(super(NamedCsvReader, self)) 144 | return dict(zip(self.headers, row)) 145 | 146 | 147 | class CsvWriter(object): 148 | """ 149 | A CSV writer which will write rows to CSV file "f", 150 | which is encoded in the given encoding. 151 | From http://docs.python.org/2/library/csv.html 152 | """ 153 | 154 | def __init__(self, f, dialect=csv.excel, encoding="utf-8", **kwds): 155 | # Redirect output to a queue 156 | self.queue = io.StringIO() 157 | self.writer = csv.writer(self.queue, dialect=dialect, **kwds) 158 | self.stream = f 159 | self.encoder = codecs.getincrementalencoder(encoding)() 160 | 161 | def writerow(self, row): 162 | self.writer.writerow( 163 | [s.encode("utf-8") if isinstance(s, str) else s for s in row] 164 | ) 165 | # Fetch UTF-8 output from the queue ... 166 | data = self.queue.getvalue() 167 | data = data.decode("utf-8") 168 | # ... and reencode it into the target encoding 169 | data = self.encoder.encode(data) 170 | # write to the target stream 171 | self.stream.write(data) 172 | # empty queue 173 | self.queue.truncate(0) 174 | 175 | def writerows(self, rows): 176 | for row in rows: 177 | self.writerow(row) 178 | 179 | 180 | def smart_unicode(s, encoding='utf-8', errors='strict'): 181 | if isinstance(s, str): 182 | return s 183 | if not isinstance(s, str): 184 | if hasattr(s, '__unicode__'): 185 | s = str(s) 186 | else: 187 | s = str(str(s), encoding, errors) 188 | elif not isinstance(s, str): 189 | s = s.decode(encoding, errors) 190 | return s 191 | 192 | 193 | def finite_float(value): 194 | """Convert any value to a finite float or throw a ValueError if it can't be done.""" 195 | value = float(value) 196 | if math.isnan(value) or math.isinf(value): 197 | raise ValueError("Can't convert %s to a finite float" % value) 198 | return value 199 | 200 | 201 | def utctoday(): 202 | now = datetime.datetime.utcnow() 203 | today = datetime.date(*now.timetuple()[:3]) 204 | 
def utctime():
    """Return the current time as integer seconds since the epoch, like
    time.time() but derived from the UTC wall clock."""
    now = datetime.datetime.utcnow()
    return calendar.timegm(now.utctimetuple())
248 | """ 249 | 250 | # Adapted from http://stackoverflow.com/questions/1703546/parsing-date-time-string-with-timezone-abbreviated-name-in-python 251 | 252 | tz_str = '''-12 Y 253 | -11 X NUT SST 254 | -10 W CKT HAST HST TAHT TKT 255 | -9.5 MART MIT 256 | -9 V AKST GAMT GIT HADT HNY 257 | -8 U AKDT CIST HAY HNP PST PT 258 | -7 T HAP HNR MST PDT 259 | -6 S CST EAST GALT HAR HNC MDT 260 | -5 R CDT COT EASST ECT EST ET HAC HNE PET 261 | -4.5 HLV VET 262 | -4 Q AST BOT CLT COST EDT FKT GYT HAE HNA PYT 263 | -3.5 HNT NST NT 264 | -3 P ADT ART BRT CLST FKST GFT HAA PMST PYST SRT UYT WGT 265 | -2.5 HAT NDT 266 | -2 O BRST FNT PMDT UYST WGST 267 | -1 N AZOT CVT EGT 268 | 0 Z EGST GMT UTC WET WT 269 | 1 A CET DFT WAT WEDT WEST IST MEZ 270 | 2 B CAT CEDT CEST EET SAST WAST MESZ 271 | 3 C EAT EEDT EEST IDT MSK 272 | 3.5 IRST 273 | 4 D AMT AZT GET GST KUYT MSD MUT RET SAMT SCT 274 | 4.5 AFT IRDT 275 | 5 E AMST AQTT AZST HMT MAWT MVT PKT TFT TJT TMT UZT YEKT 276 | 5.5 SLT 277 | 5.75 NPT 278 | 6 F ALMT BIOT BTT IOT KGT NOVT OMST YEKST 279 | 6.5 CCT MMT 280 | 7 G CXT DAVT HOVT ICT KRAT NOVST OMSST THA WIB 281 | 8 H ACT AWST BDT BNT CAST HKT IRKT KRAST MYT PHT SGT ULAT WITA WST 282 | 9 I AWDT IRKST JST KST PWT TLT WDT WIT YAKT 283 | 9.5 ACST 284 | 10 K AEST ChST PGT VLAT YAKST YAPT 285 | 10.5 ACDT LHST 286 | 11 L AEDT LHDT MAGT NCT PONT SBT VLAST VUT 287 | 11.5 NFT 288 | 12 M ANAST ANAT FJT GILT MAGST MHT NZST PETST PETT TVT WFT 289 | 12.75 CHAST 290 | 13 FJST NZDT PHOT TOT 291 | 13.75 CHADT 292 | 14 LINT''' 293 | 294 | tzd = {} 295 | for tz_descr in (s.split() for s in tz_str.split('\n')): 296 | tz_offset = int(float(tz_descr[0]) * 3600) 297 | for tz_code in tz_descr[1:]: 298 | assert tz_code not in tzd, "duplicate TZ alias detected" 299 | tzd[tz_code] = tz_offset 300 | return tzd 301 | 302 | 303 | _tz_info_dict = _gen_tz_info_dict() 304 | 305 | 306 | def parse_date_tz(date): 307 | """ 308 | Attempts to parse the date, taking common timezone offsets into account. 
def force_unicode(s):
    """Return a text (unicode) object for any str/bytes input.

    Bytes are decoded as UTF-8 first; on failure latin1 is used, which can
    decode any byte sequence, so the conversion always succeeds for bytes.
    """
    if isinstance(s, str):
        return s
    for encoding in ('utf8', 'latin1'):
        try:
            return s.decode(encoding)
        except UnicodeDecodeError:
            # utf8 may fail; latin1 never does (every byte is valid).
            continue
374 | """ 375 | if isinstance(obj, (list, tuple)): 376 | return [apply_recursively(item, f) for item in obj] 377 | elif isinstance(obj, dict): 378 | return {k: apply_recursively(v, f) for k, v in obj.items()} 379 | elif obj is None: 380 | return None 381 | else: 382 | return f(obj) 383 | 384 | 385 | class Timeout(Exception): 386 | pass 387 | 388 | 389 | class Timer(object): 390 | """ 391 | Timer class with an optional signal timer. 392 | Raises a Timeout exception when the timeout occurs. 393 | When using timeouts, you must not nest this function nor call it in 394 | any thread other than the main thread. 395 | """ 396 | 397 | def __init__(self, timeout=None, timeout_message=''): 398 | self.timeout = timeout 399 | self.timeout_message = timeout_message 400 | 401 | if timeout: 402 | signal.signal(signal.SIGALRM, self._alarm_handler) 403 | 404 | def _alarm_handler(self, signum, frame): 405 | signal.signal(signal.SIGALRM, signal.SIG_IGN) 406 | raise Timeout(self.timeout_message) 407 | 408 | def __enter__(self): 409 | if self.timeout: 410 | signal.alarm(self.timeout) 411 | self.start = datetime.datetime.utcnow() 412 | return self 413 | 414 | def __exit__(self, *args): 415 | self.end = datetime.datetime.utcnow() 416 | delta = self.end - self.start 417 | self.interval = ( 418 | delta.days * 86400 + delta.seconds + delta.microseconds / 1000000.0 419 | ) 420 | if self.timeout: 421 | signal.alarm(0) 422 | signal.signal(signal.SIGALRM, signal.SIG_IGN) 423 | 424 | 425 | class ThreadedTimer(object): 426 | """ 427 | Timer class with an optional threaded timer. By default, interrupts the 428 | main thread with a KeyboardInterrupt. 
429 | """ 430 | 431 | def __init__(self, timeout=None, on_timeout=None): 432 | self.timeout = timeout 433 | self.on_timeout = on_timeout or self._timeout_handler 434 | 435 | def _timeout_handler(self): 436 | import _thread 437 | 438 | _thread.interrupt_main() 439 | 440 | def __enter__(self): 441 | if self.timeout: 442 | self._timer = threading.Timer(self.timeout, self.on_timeout) 443 | self._timer.start() 444 | self.start = datetime.datetime.utcnow() 445 | return self 446 | 447 | def __exit__(self, *args): 448 | if self.timeout: 449 | self._timer.cancel() 450 | self.end = datetime.datetime.utcnow() 451 | delta = self.end - self.start 452 | self.interval = ( 453 | delta.days * 86400 + delta.seconds + delta.microseconds / 1000000.0 454 | ) 455 | 456 | 457 | def uniqify(seq, key=lambda i: i): 458 | """ 459 | Given an iterable, return a list of its unique elements, preserving the 460 | original order. For example: 461 | 462 | >>> uniqify([1, 2, 3, 1, 'a', None, 'a', 'b']) 463 | [1, 2, 3, 'a', None, 'b'] 464 | 465 | >>> uniqify([ { 'a': 1 }, { 'a': 2 }, { 'a': 1 } ]) 466 | [ { 'a': 1 }, { 'a': 2 } ] 467 | 468 | You can optionally specify a callable as the 'key' parameter which 469 | can extract or otherwise obtain a key from the items to use as the test for uniqueness. 470 | 471 | For example: 472 | >>> uniqify([dict(foo='bar', baz='qux'), dict(foo='grill', baz='qux')], key=lambda item: item['baz']) 473 | [ { 'foo': 'bar', 'baz': 'qux' } ] 474 | 475 | Note: This function doesn't work with nested dicts. 
476 | """ 477 | seen = set() 478 | result = [] 479 | for x in seq: 480 | unique_key = key(x) 481 | if mongoengine and isinstance(unique_key, mongoengine.EmbeddedDocument): 482 | unique_key = unique_key.to_dict() 483 | if isinstance(unique_key, dict): 484 | unique_key = hash(frozenset(unique_key.items())) 485 | 486 | if unique_key not in seen: 487 | seen.add(unique_key) 488 | result.append(x) 489 | return result 490 | 491 | 492 | # NORMALIZATION UTILS # 493 | 494 | 495 | class FileFormatException(Exception): 496 | pass 497 | 498 | 499 | class Reader(object): 500 | """ 501 | Able to interpret files of the form: 502 | 503 | key => value1, value2 [this is the default case where one_to_many=True] 504 | OR 505 | value1, value2 => key [one_to_many=False] 506 | 507 | 508 | This is useful for cases where we want to normalize values such as: 509 | 510 | United States, United States of America, 'Merica, USA, U.S. => US 511 | 512 | Minnesota => MN 513 | 514 | Minnesota => MN, Minne 515 | 516 | This reader also can handle quoted values such as: 517 | 518 | "this => that" => "this", that 519 | 520 | """ 521 | 522 | def __init__(self, filename): 523 | self.reader = codecs.open(filename, 'r', 'utf-8') 524 | 525 | def __exit__(self): 526 | self.reader.close() 527 | 528 | def __iter__(self): 529 | return self 530 | 531 | @classmethod 532 | def split(cls, line, one_to_many=True): 533 | """ return key, values if one_to_many else return values, key """ 534 | 535 | def _get(value): 536 | one, two = value.split('=>', 1) 537 | return one.strip(), two.strip() 538 | 539 | s = io.StringIO(line) 540 | # http://stackoverflow.com/questions/6879596/why-is-the-python-csv-reader-ignoring-double-quoted-fields 541 | seq = [ 542 | x.strip() 543 | for x in next(unicode_csv_reader(s, skipinitialspace=True)) 544 | ] 545 | if not seq: 546 | raise FileFormatException("Line does not contain any valid data.") 547 | if one_to_many: 548 | key, value = _get(seq.pop(0)) 549 | seq.insert(0, value) 550 | return 
def truncate(text, size):
    """
    Truncate the given text to roughly *size* characters, extending to the
    end of the word that straddles the cut, e.g.

    >>> truncate('I can haz cheeseburgers', 9)
    'I can haz'
    >>> truncate('I can haz cheeseburgers', 10)
    'I can haz cheeseburgers'

    If no space occurs after the cut point (or text is empty/falsy), the
    text is returned unchanged.
    """
    if not text:
        return text
    cut = text[size:].find(' ')
    if cut == -1:
        return text
    return text[: size + cut]
def retry(func=None, exc=Exception, tries=1, wait=0):
    """
    Call a function and retry up to *tries* times total if it raises *exc*,
    sleeping *wait* seconds between attempts.

    Can be used directly, or as a decorator factory.

    Example Usage 1:
        retry(unreliable_function, exc=ValueError, tries=5, wait=1)

    Example Usage 2 (passing args):
        retry(lambda x, y: unreliable_function(x, y), exc=ValueError, tries=5, wait=1)

    Example Usage 3 (as a decorator generator)
        @retry(exc=ValueError, tries=10, wait=0.3)
        def unreliable_function(foo):
            # ...
        unreliable_function('boy')
    """

    def _attempt(thunk):
        remaining = tries
        while True:
            try:
                return thunk()
            except exc:
                remaining -= 1
                if remaining <= 0:
                    raise
                time.sleep(wait)

    if func is not None:
        # Called directly with a function.
        return _attempt(func)

    # Used as a decorator factory.
    def decorator(f):
        @wraps(f)
        def wrapper(*args, **kwargs):
            return _attempt(lambda: f(*args, **kwargs))

        return wrapper

    return decorator
689 | """ 690 | 691 | def __init__(self, f): 692 | self.f = f 693 | self.data = None 694 | 695 | def __getitem__(self, key): 696 | if self.data is None: 697 | self.data = list(self.f()) 698 | return self.data[key] 699 | 700 | 701 | __all__ = [ 702 | 'CsvReader', 703 | 'CsvWriter', 704 | 'DetailedSMTPHandler', 705 | 'FileFormatException', 706 | 'NamedCsvReader', 707 | 'Reader', 708 | 'Normalization', 709 | 'NormalizationReader', 710 | 'ThreadedTimer', 711 | 'Timeout', 712 | 'Timer', 713 | 'apply_recursively', 714 | 'build_normalization_map', 715 | 'combine', 716 | 'finite_float', 717 | 'force_unicode', 718 | 'format_locals', 719 | 'json_list_generator', 720 | 'lazylist', 721 | 'localtoday', 722 | 'make_unaware', 723 | 'parse_date_tz', 724 | 'returns_xml', 725 | 'retry', 726 | 'slugify', 727 | 'smart_unicode', 728 | 'truncate', 729 | 'unicode_csv_reader', 730 | 'uniqify', 731 | 'utctime', 732 | 'utctoday', 733 | 'utf_8_encoder', 734 | ] 735 | -------------------------------------------------------------------------------- /flask_common/utils/lists.py: -------------------------------------------------------------------------------- 1 | def grouper(n, iterable): 2 | # e.g. 2, [1, 2, 3, 4, 5] -> [[1, 2], [3, 4], [5]] 3 | return [iterable[i : i + n] for i in range(0, len(iterable), n)] 4 | -------------------------------------------------------------------------------- /flask_common/utils/objects.py: -------------------------------------------------------------------------------- 1 | def freeze(x): 2 | """Convert dicts and lists to frozensets of key/index, value pairs, recursively. 3 | 4 | Good for using complex python data structures in sets or dict keys. 5 | It is idempotent in a sense that freeze(freeze(x)) == freeze(x). 
6 | """ 7 | if isinstance(x, dict): 8 | return frozenset((k, freeze(v)) for k, v in x.items()) 9 | if isinstance(x, list): 10 | return frozenset(enumerate(freeze(e) for e in x)) 11 | return x 12 | 13 | 14 | def dict_with_class(obj): 15 | """Just like obj.__dict__, but includes data (non-function) class attributes.""" 16 | d = {} 17 | for cls in reversed(obj.__class__.__mro__): 18 | for k, v in cls.__dict__.items(): 19 | if not (k.startswith('_') or callable(v)): 20 | d[k] = v 21 | d.update(obj.__dict__) 22 | return d 23 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | skip-string-normalization = true 3 | line-length = 80 4 | exclude = ''' 5 | /( 6 | \.git 7 | | \.venv 8 | | venv 9 | )/ 10 | ''' 11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | python-dateutil==2.5.0 2 | pytz==2015.7 3 | flask>=1 4 | -e git+ssh://git@github.com/closeio/mongoengine.git@7f0c4b85e375f4eb756932bcb6414daf7b993247#egg=mongoengine 5 | -e git+ssh://git@github.com/closeio/flask-mongoengine.git@0c2cc30ec98154bb2eb4499efada65239050025d#egg=flask-mongoengine 6 | Flask-SQLAlchemy==2.1 7 | phonenumbers==8.8.7 8 | cryptography==3.2.1 9 | padding==0.4 10 | Unidecode==0.4.19 11 | -e git+ssh://git@github.com/closeio/zbase62.git@e13d2c748ccdb0cafe6465961a0c6a4111ee219f#egg=zbase62 12 | pymongo==3.4.0 13 | -------------------------------------------------------------------------------- /requirements_lint.txt: -------------------------------------------------------------------------------- 1 | entrypoints==0.3 2 | flake8==3.7.7 3 | pycodestyle==2.5.0 4 | mccabe==0.6.1 5 | pyflakes==2.1.1 6 | flake8-tidy-imports==2.0.0 7 | pep8-naming==0.8.2 8 | flake8_polyfill==1.0.2 9 | 10 | isort==4.3.4 11 | 12 | mypy==0.641 13 | 
mypy-extensions==0.4.1 14 | typed-ast==1.1.0 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name='flask-common', 5 | version='0.1', 6 | url='http://github.com/closeio/flask-common', 7 | license='MIT', 8 | description='Close.io internal flask helpers', 9 | platforms='any', 10 | classifiers=[ 11 | 'Intended Audience :: Developers', 12 | 'Operating System :: OS Independent', 13 | 'Topic :: Software Development :: Libraries :: Python Modules', 14 | 'Programming Language :: Python', 15 | 'Programming Language :: Python :: 2', 16 | ], 17 | packages=find_packages(), 18 | test_suite='tests', 19 | tests_require=[ 20 | 'python-dateutil', 21 | 'pytz', 22 | 'flask', 23 | 'mongoengine', 24 | 'cryptography', 25 | 'padding', 26 | 'pytest', 27 | ], 28 | ) 29 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/closeio/flask-common/a3893e9f2bc1801d7e8557aef3f3e3f26811d398/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_client.py: -------------------------------------------------------------------------------- 1 | from flask_common.client import ApiClient 2 | from flask import Flask 3 | from werkzeug.datastructures import Headers 4 | 5 | 6 | def test_api_client_basic_auth(): 7 | app = Flask('test') 8 | client = ApiClient(app, api_key='123456') 9 | 10 | headers = client.get_headers(client.api_key) 11 | assert headers == Headers( 12 | [ 13 | ( 14 | 'Authorization', 15 | 'Basic MTIzNDU2Og==', 16 | ) 17 | ] 18 | ) 19 | 20 | 21 | def test_json(): 22 | app = Flask('test') 23 | 24 | @app.route('/', methods=['GET']) 25 | def view(): 26 | return {'ok': 1} 27 | 28 | client = 
ApiClient(app) 29 | assert client.get('/').json['ok'] == 1 30 | -------------------------------------------------------------------------------- /tests/test_crypto.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from flask_common.crypto import ( 4 | AuthenticationError, 5 | EncryptionError, 6 | aes_decrypt, 7 | aes_encrypt, 8 | aes_generate_key, 9 | ) 10 | 11 | 12 | def test_with_v1_data(): 13 | data = b'test' 14 | key = aes_generate_key() 15 | encrypted_data = aes_encrypt(key, data) 16 | assert encrypted_data[0:1] == b'\x01' 17 | assert aes_decrypt(key, encrypted_data) == data 18 | 19 | 20 | def test_with_v1_corrupted_data(): 21 | data = b'test' 22 | key = aes_generate_key() 23 | encrypted_data = aes_encrypt(key, data) 24 | assert encrypted_data[0:1] == b'\x01' 25 | corrupted_encrypted_data = encrypted_data[:-3] 26 | with pytest.raises(AuthenticationError) as excinfo: 27 | aes_decrypt(key, corrupted_encrypted_data) 28 | assert str(excinfo.value) == "message authentication failed" 29 | 30 | 31 | def test_with_data_exactly_as_long_as_aes_block(): 32 | data = b'a' * 128 33 | key = aes_generate_key() 34 | assert aes_decrypt(key, aes_encrypt(key, data)) == data 35 | 36 | 37 | def test_with_data_longer_than_aes_block(): 38 | data = b'a' * 130 39 | key = aes_generate_key() 40 | assert aes_decrypt(key, aes_encrypt(key, data)) == data 41 | 42 | 43 | def test_data_encrypted_twice_is_different(): 44 | data = b'test' 45 | key = aes_generate_key() 46 | 47 | first_encryption = aes_encrypt(key, data) 48 | second_encryption = aes_encrypt(key, data) 49 | assert first_encryption != second_encryption 50 | 51 | 52 | def test_with_invalid_version(): 53 | key = b']\x1a\xa2\nW\x97\xab)\x951\xa8t\x8b\xd8\xac\x08\xebjlY\xd0S\x90d\xcc\rR\x1f\xbf\x13\xe0:\xb5\x7f\xbf\xa7\x83|\x10bQ\x03\xd3Z]\xea\x1f2\xf6tB\x13\xaeP\xcc\x8fb\xabY\xda#\xe9QE' 54 | encrypted_data = b'\x00M\xcdjP\xfd\xcc\xa1\xd7\xda\x11(Q 
\xbd\xe4w\n\x03C\x14!\x99N\xe8\xf0H\xbc\xf8\xf41\xa5\x10E\x0e\xbc\x04\x01\x85\x0b\xd5F\x1bq>\x12\x04\x11Y\x10\x8f\x0f\x06' 55 | with pytest.raises(EncryptionError) as excinfo: 56 | aes_decrypt(key, encrypted_data) 57 | assert str(excinfo.value) == "Found invalid version marker: {!r}".format( 58 | b'\x00' 59 | ) 60 | -------------------------------------------------------------------------------- /tests/test_declenum.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from flask_common.declenum import DeclEnum 4 | 5 | 6 | class DeclEnumTestCase(unittest.TestCase): 7 | # TODO pytest-ify 8 | 9 | def test_enum(self): 10 | class TestEnum(DeclEnum): 11 | alpha = 'alpha_value', 'Alpha Description' 12 | beta = 'beta_value', 'Beta Description' 13 | 14 | assert TestEnum.alpha != TestEnum.beta 15 | assert TestEnum.alpha.value == 'alpha_value' 16 | assert TestEnum.alpha.description == 'Alpha Description' 17 | assert TestEnum.from_string('alpha_value') == TestEnum.alpha 18 | 19 | db_type = TestEnum.db_type() 20 | self.assertEqual( 21 | set(db_type.enum.values()), set(['alpha_value', 'beta_value']) 22 | ) 23 | -------------------------------------------------------------------------------- /tests/test_enum.py: -------------------------------------------------------------------------------- 1 | from flask_common.enum import Enum 2 | 3 | 4 | class TestEnum(Enum): 5 | A = 'a' 6 | B = 'b' 7 | 8 | 9 | def test_enum(): 10 | # Fetch twice to ensure cache is correct 11 | assert TestEnum.values() == ['a', 'b'] 12 | assert TestEnum.values() == ['a', 'b'] 13 | assert TestEnum.choices() == [('a', 'A'), ('b', 'B')] 14 | assert TestEnum.choices() == [('a', 'A'), ('b', 'B')] 15 | -------------------------------------------------------------------------------- /tests/test_formfields.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import unittest 3 | 4 | from dateutil.tz 
import tzutc 5 | from werkzeug.datastructures import MultiDict 6 | from wtforms import Form 7 | 8 | from flask_common.formfields import BetterDateTimeField 9 | 10 | 11 | class FormFieldTestCase(unittest.TestCase): 12 | # TODO pytest-ify 13 | 14 | def test_datetime_field(self): 15 | class TestForm(Form): 16 | date = BetterDateTimeField() 17 | 18 | form = TestForm(MultiDict({'date': ''})) 19 | self.assertTrue(form.validate()) 20 | self.assertEqual(form.data['date'], None) 21 | 22 | form = TestForm(MultiDict({'date': 'invalid'})) 23 | self.assertFalse(form.validate()) 24 | 25 | form = TestForm(MultiDict({'date': '2012-09-06T01:29:14.107000+00:00'})) 26 | self.assertTrue(form.validate()) 27 | self.assertEqual( 28 | form.data['date'], 29 | datetime.datetime(2012, 9, 6, 1, 29, 14, 107000, tzinfo=tzutc()), 30 | ) 31 | 32 | form = TestForm(MultiDict({'date': '2012-09-06 01:29:14'})) 33 | self.assertTrue(form.validate()) 34 | self.assertEqual( 35 | form.data['date'], datetime.datetime(2012, 9, 6, 1, 29, 14) 36 | ) 37 | -------------------------------------------------------------------------------- /tests/test_legacy.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import unittest 4 | 5 | from flask import Flask 6 | from mongoengine import Document, ReferenceField, SafeReferenceListField 7 | 8 | from flask_mongoengine import MongoEngine 9 | from flask_common.utils import apply_recursively, slugify, uniqify 10 | 11 | 12 | app = Flask(__name__) 13 | 14 | app.config.update( 15 | DEBUG=True, 16 | TESTING=True, 17 | MONGODB_HOST='localhost', 18 | MONGODB_PORT='27017', 19 | MONGODB_DB='common_example_app', 20 | ) 21 | 22 | db = MongoEngine(app) 23 | 24 | 25 | class SafeReferenceListFieldTestCase(unittest.TestCase): 26 | # TODO this is a mongoengine field and it should be tested in that package, 27 | # not here. 
28 | 29 | def test_safe_reference_list_field(self): 30 | class Book(Document): 31 | pass 32 | 33 | class Author(Document): 34 | books = SafeReferenceListField(ReferenceField(Book)) 35 | 36 | Author.drop_collection() 37 | Book.drop_collection() 38 | 39 | b1 = Book.objects.create() 40 | b2 = Book.objects.create() 41 | 42 | a = Author.objects.create(books=[b1, b2]) 43 | a.reload() 44 | self.assertEqual(a.books, [b1, b2]) 45 | 46 | b1.delete() 47 | a.reload() 48 | self.assertEqual(a.books, [b2]) 49 | 50 | b3 = Book.objects.create() 51 | a.books.append(b3) 52 | a.save() 53 | a.reload() 54 | self.assertEqual(a.books, [b2, b3]) 55 | 56 | b2.delete() 57 | b3.delete() 58 | a.reload() 59 | self.assertEqual(a.books, []) 60 | 61 | 62 | class ApplyRecursivelyTestCase(unittest.TestCase): 63 | def test_none(self): 64 | self.assertEqual(apply_recursively(None, lambda n: n + 1), None) 65 | 66 | def test_list(self): 67 | self.assertEqual( 68 | apply_recursively([1, 2, 3], lambda n: n + 1), [2, 3, 4] 69 | ) 70 | 71 | def test_nested_tuple(self): 72 | self.assertEqual( 73 | apply_recursively([(1, 2), (3, 4)], lambda n: n + 1), 74 | [[2, 3], [4, 5]], 75 | ) 76 | 77 | def test_nested_dict(self): 78 | self.assertEqual( 79 | apply_recursively( 80 | [{'a': 1, 'b': [2, 3], 'c': {'d': 4, 'e': None}}, 5], 81 | lambda n: n + 1, 82 | ), 83 | [{'a': 2, 'b': [3, 4], 'c': {'d': 5, 'e': None}}, 6], 84 | ) 85 | 86 | 87 | class SlugifyTestCase(unittest.TestCase): 88 | def test_slugify(self): 89 | self.assertEqual(slugify(' Foo ???BAR\t\n\r'), 'foo_bar') 90 | self.assertEqual(slugify(u'äąé öóü', '-'), 'aae-oou') 91 | 92 | 93 | class UtilsTestCase(unittest.TestCase): 94 | def test_uniqify(self): 95 | self.assertEqual( 96 | uniqify([1, 2, 3, 1, 'a', None, 'a', 'b']), 97 | [1, 2, 3, 'a', None, 'b'], 98 | ) 99 | self.assertEqual( 100 | uniqify([{'a': 1}, {'a': 2}, {'a': 1}]), [{'a': 1}, {'a': 2}] 101 | ) 102 | self.assertEqual( 103 | uniqify( 104 | [{'a': 1, 'b': 3}, {'a': 2, 'b': 2}, {'a': 1, 'b': 1}], 
105 | key=lambda i: i['a'], 106 | ), 107 | [{'a': 1, 'b': 3}, {'a': 2, 'b': 2}], 108 | ) 109 | 110 | 111 | if __name__ == '__main__': 112 | unittest.main() 113 | -------------------------------------------------------------------------------- /tests/test_mongo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/closeio/flask-common/a3893e9f2bc1801d7e8557aef3f3e3f26811d398/tests/test_mongo/__init__.py -------------------------------------------------------------------------------- /tests/test_mongo/test_documents.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import time 3 | import unittest 4 | 5 | from mongoengine import Document, ReferenceField, StringField, ValidationError 6 | 7 | from flask_common.mongo.documents import ( 8 | DocumentBase, 9 | RandomPKDocument, 10 | SoftDeleteDocument, 11 | ) 12 | 13 | 14 | class DocumentBaseTestCase(unittest.TestCase): 15 | def test_cls_inheritance(self): 16 | """ 17 | Make sure _cls is not appended to queries and indexes and that 18 | allow_inheritance is disabled by default for docs inheriting from 19 | RandomPKDocument and DocumentBase. 20 | """ 21 | 22 | class Doc(DocumentBase, RandomPKDocument): 23 | text = StringField() 24 | 25 | self.assertEqual(Doc.objects.filter(text='')._query, {'text': ''}) 26 | self.assertFalse(Doc._meta['allow_inheritance']) 27 | 28 | def test_pk_validation(self): 29 | """ 30 | Make sure that you cannot save crap in a ReferenceField that 31 | references a RandomPKDocument. 32 | """ 33 | 34 | class A(RandomPKDocument): 35 | text = StringField() 36 | 37 | class B(Document): 38 | ref = ReferenceField(A) 39 | 40 | self.assertRaises(ValidationError, B.objects.create, ref={'dict': True}) 41 | 42 | def test_document_base_date_updated(self): 43 | """ 44 | Make sure a class inheriting from DocumentBase correctly handles 45 | updates to date_updated. 
46 | """ 47 | 48 | class Doc(DocumentBase, RandomPKDocument): 49 | text = StringField() 50 | 51 | doc = Doc.objects.create(text='aaa') 52 | doc.reload() 53 | last_date_created = doc.date_created 54 | last_date_updated = doc.date_updated 55 | 56 | time.sleep(0.001) # make sure some time passes between the updates 57 | doc.text = 'new' 58 | doc.save() 59 | doc.reload() 60 | 61 | self.assertEqual(doc.date_created, last_date_created) 62 | self.assertTrue(doc.date_updated > last_date_updated) 63 | last_date_updated = doc.date_updated 64 | 65 | time.sleep(0.001) # make sure some time passes between the updates 66 | doc.update(set__text='newer') 67 | doc.reload() 68 | 69 | self.assertEqual(doc.date_created, last_date_created) 70 | self.assertTrue(doc.date_updated > last_date_updated) 71 | last_date_updated = doc.date_updated 72 | 73 | time.sleep(0.001) # make sure some time passes between the updates 74 | doc.update(set__date_created=datetime.datetime.utcnow()) 75 | doc.reload() 76 | 77 | self.assertTrue(doc.date_created > last_date_created) 78 | self.assertTrue(doc.date_updated > last_date_updated) 79 | last_date_created = doc.date_created 80 | last_date_updated = doc.date_updated 81 | 82 | new_date_created = datetime.datetime(2014, 6, 12) 83 | new_date_updated = datetime.datetime(2014, 10, 12) 84 | time.sleep(0.001) # make sure some time passes between the updates 85 | doc.update( 86 | set__date_created=new_date_created, 87 | set__date_updated=new_date_updated, 88 | ) 89 | doc.reload() 90 | 91 | self.assertEqual( 92 | doc.date_created.replace(tzinfo=None), new_date_created 93 | ) 94 | self.assertEqual( 95 | doc.date_updated.replace(tzinfo=None), new_date_updated 96 | ) 97 | 98 | time.sleep(0.001) # make sure some time passes between the updates 99 | doc.update(set__text='newest', update_date=False) 100 | doc.reload() 101 | 102 | self.assertEqual(doc.text, 'newest') 103 | self.assertEqual( 104 | doc.date_created.replace(tzinfo=None), new_date_created 105 | ) 106 | 
self.assertEqual( 107 | doc.date_updated.replace(tzinfo=None), new_date_updated 108 | ) 109 | 110 | 111 | class SoftDeleteDocumentTestCase(unittest.TestCase): 112 | class Person(DocumentBase, RandomPKDocument, SoftDeleteDocument): 113 | name = StringField() 114 | 115 | meta = {'allow_inheritance': True} 116 | 117 | class Programmer(Person): 118 | language = StringField() 119 | 120 | def setUp(self): 121 | self.Person.drop_collection() 122 | self.Programmer.drop_collection() 123 | 124 | def test_default_is_deleted(self): 125 | """Make sure is_deleted is never null.""" 126 | s = self.Person.objects.create(name='Steve') 127 | self.assertEqual(s.reload()._db_data['is_deleted'], False) 128 | 129 | def _bad_update(): 130 | s.update(set__is_deleted=None) 131 | 132 | self.assertRaises(ValidationError, _bad_update) 133 | 134 | def test_queryset_manager(self): 135 | a = self.Person.objects.create(name='Anthony') 136 | 137 | # test all the ways to filter/aggregate counts 138 | self.assertEqual(len(self.Person.objects.all()), 1) 139 | self.assertEqual(self.Person.objects.all().count(), 1) 140 | self.assertEqual(self.Person.objects.filter(name='Anthony').count(), 1) 141 | self.assertEqual(self.Person.objects.count(), 1) 142 | 143 | a.delete() 144 | self.assertEqual(len(self.Person.objects.all()), 0) 145 | self.assertEqual(self.Person.objects.all().count(), 0) 146 | self.assertEqual(self.Person.objects.filter(name='Anthony').count(), 0) 147 | self.assertEqual(self.Person.objects.count(), 0) 148 | 149 | self.assertEqual(len(self.Person.objects.filter(name='Anthony')), 0) 150 | a.is_deleted = False 151 | a.save() 152 | self.assertEqual(len(self.Person.objects.filter(name='Anthony')), 1) 153 | 154 | b = self.Programmer.objects.create(name='Thomas', language='python.net') 155 | self.assertEqual(len(self.Programmer.objects.all()), 1) 156 | b.delete() 157 | self.assertEqual(len(self.Programmer.objects.all()), 0) 158 | 159 | 
self.assertEqual(len(self.Programmer.objects.filter(name='Thomas')), 0) 160 | b.is_deleted = False 161 | b.save() 162 | self.assertEqual(len(self.Programmer.objects.filter(name='Thomas')), 1) 163 | 164 | def test_date_updated(self): 165 | a = self.Person.objects.create(name='Anthony') 166 | a.reload() 167 | last_date_updated = a.date_updated 168 | 169 | time.sleep(0.001) # make sure some time passes between the updates 170 | a.update(set__name='Tony') 171 | a.reload() 172 | 173 | self.assertTrue(a.date_updated > last_date_updated) 174 | last_date_updated = a.date_updated 175 | 176 | time.sleep(0.001) # make sure some time passes between the updates 177 | a.delete() 178 | a.reload() 179 | 180 | self.assertTrue(a.date_updated > last_date_updated) 181 | self.assertEqual(a.is_deleted, True) 182 | -------------------------------------------------------------------------------- /tests/test_mongo/test_fields/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/closeio/flask-common/a3893e9f2bc1801d7e8557aef3f3e3f26811d398/tests/test_mongo/test_fields/__init__.py -------------------------------------------------------------------------------- /tests/test_mongo/test_fields/test_basic.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from mongoengine import ( 4 | Document, 5 | EmbeddedDocument, 6 | EmbeddedDocumentField, 7 | ValidationError, 8 | NotUniqueError, 9 | ) 10 | 11 | from flask_common.mongo.fields import ( 12 | LowerEmailField, 13 | LowerStringField, 14 | TrimmedStringField, 15 | ) 16 | 17 | 18 | class TrimmedStringFieldTestCase(unittest.TestCase): 19 | # TODO pytest-ify and test the field instance directly without persistence. 
20 | 21 | def test_trimmedstring_field(self): 22 | class Person(Document): 23 | name = TrimmedStringField(required=True) 24 | comment = TrimmedStringField() 25 | 26 | Person.drop_collection() 27 | 28 | person = Person(name='') 29 | self.assertRaises(ValidationError, person.save) 30 | 31 | person = Person(name=' ') 32 | self.assertRaises(ValidationError, person.save) 33 | 34 | person = Person(name=' 1', comment='') 35 | person.save() 36 | self.assertEqual(person.name, '1') 37 | self.assertEqual(person.comment, '') 38 | 39 | person = Person(name=' big name', comment=' this is a comment') 40 | person.save() 41 | self.assertEqual(person.name, 'big name') 42 | self.assertEqual(person.comment, 'this is a comment') 43 | 44 | 45 | class LowerStringFieldTestCase(unittest.TestCase): 46 | # TODO pytest-ify and test the field instance directly without persistence. 47 | 48 | def test_case_insensitive_query(self): 49 | class Test(Document): 50 | field = LowerStringField() 51 | 52 | Test.drop_collection() 53 | 54 | Test(field='whatever').save() 55 | 56 | obj1 = Test.objects.get(field='whatever') 57 | obj2 = Test.objects.get(field='WHATEVER') 58 | 59 | self.assertEqual(obj1, obj2) 60 | 61 | Test.drop_collection() 62 | 63 | def test_case_insensitive_uniqueness(self): 64 | class Test(Document): 65 | field = LowerStringField(unique=True) 66 | 67 | Test.drop_collection() 68 | Test.ensure_indexes() 69 | 70 | Test(field='whatever').save() 71 | self.assertRaises(NotUniqueError, Test(field='WHATEVER').save) 72 | 73 | 74 | class LowerEmailFieldTestCase(unittest.TestCase): 75 | # TODO pytest-ify and test the field instance directly without persistence. 
76 | 77 | def test_email_validation(self): 78 | class Test(Document): 79 | email = LowerEmailField() 80 | 81 | Test.drop_collection() 82 | 83 | Test(email='valid@email.com').save() 84 | self.assertRaises(ValidationError, Test(email='invalid email').save) 85 | 86 | def test_case_insensitive_querying(self): 87 | class Test(Document): 88 | email = LowerEmailField() 89 | 90 | Test.drop_collection() 91 | 92 | obj = Test(email='valid@email.com') 93 | obj.save() 94 | 95 | self.assertEqual(Test.objects.get(email='valid@email.com'), obj) 96 | self.assertEqual(Test.objects.get(email='VALID@EMAIL.COM'), obj) 97 | self.assertEqual(Test.objects.get(email__in=['VALID@EMAIL.COM']), obj) 98 | self.assertEqual( 99 | Test.objects.get(email__nin=['different@email.com']), obj 100 | ) 101 | self.assertEqual( 102 | Test.objects.filter(email__ne='VALID@EMAIL.COM').count(), 0 103 | ) 104 | 105 | def test_lower_field_in_embedded_doc(self): 106 | class EmbeddedDoc(EmbeddedDocument): 107 | email = LowerEmailField() 108 | 109 | class Test(Document): 110 | embedded = EmbeddedDocumentField(EmbeddedDoc) 111 | 112 | Test.drop_collection() 113 | 114 | obj = Test(embedded=EmbeddedDoc(email='valid@email.com')) 115 | obj.save() 116 | 117 | self.assertTrue( 118 | obj 119 | in Test.objects.filter( 120 | embedded__email__in=['VALID@EMAIL.COM', 'whatever'] 121 | ) 122 | ) 123 | -------------------------------------------------------------------------------- /tests/test_mongo/test_fields/test_crypto.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | 4 | import pytest 5 | 6 | from flask_common.crypto import AuthenticationError, aes_generate_key 7 | from flask_common.mongo.fields import ( 8 | EncryptedBinaryField, 9 | EncryptedStringField, 10 | ) 11 | 12 | 13 | def test_encrypted_binary_field_can_encrypt_and_decrypt(): 14 | token = EncryptedBinaryField(aes_generate_key()) 15 | assert token.to_python(token.to_mongo(b'\x00\x01')) == b'\x00\x01' 16 
| 17 | 18 | def test_encrypted_binary_field_can_rotate(): 19 | key_1 = aes_generate_key() 20 | token = EncryptedBinaryField(key_1) 21 | encrypted_data = token.to_mongo(b'\x00\x01') 22 | 23 | key_2 = aes_generate_key() 24 | token = EncryptedBinaryField([key_2, key_1]) 25 | assert token.to_python(encrypted_data) == b'\x00\x01' 26 | 27 | 28 | def test_encrypted_binary_field_will_fail_on_corrupted_data(): 29 | key_1 = aes_generate_key() 30 | token = EncryptedBinaryField(key_1) 31 | corrupted_encrypted_data = token.to_mongo(b'\x00\x01')[:3] 32 | with pytest.raises(AuthenticationError) as excinfo: 33 | token.to_python(corrupted_encrypted_data) 34 | assert str(excinfo.value) == 'message authentication failed' 35 | 36 | 37 | def test_encrypted_binary_field_with_none(): 38 | token = EncryptedBinaryField(aes_generate_key()) 39 | assert token.to_python(token.to_mongo(None)) is None 40 | 41 | 42 | def test_encrypted_string_field_works_with_unicode_data(): 43 | token = EncryptedStringField(aes_generate_key()) 44 | assert token.to_python(token.to_mongo(u'ãé')) == u'ãé' 45 | -------------------------------------------------------------------------------- /tests/test_mongo/test_fields/test_phone.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from mongoengine import Document, ValidationError 4 | 5 | from flask_common.mongo.fields import PhoneField 6 | 7 | 8 | class PhoneFieldTestCase(unittest.TestCase): 9 | # TODO pytest-ify and test the field instance directly without persistence. 
10 | 11 | def test_format_number(self): 12 | class Person(Document): 13 | phone = PhoneField() 14 | 15 | Person.drop_collection() 16 | 17 | person = Person(phone='14151231234') 18 | assert person.phone == '14151231234' 19 | 20 | person.phone = 'notaphone' 21 | assert person.phone == 'notaphone' 22 | self.assertRaises(ValidationError, person.validate) 23 | self.assertRaises(ValidationError, person.save) 24 | 25 | person.phone = '+1 (650) 618 - 1234 x 768' 26 | assert person.phone == '+16506181234x768' 27 | person.validate() 28 | person.save() 29 | 30 | assert person.id == Person.objects.get(phone='+16506181234x768').id 31 | assert ( 32 | person.id == Person.objects.get(phone='+1 650-618-1234 ext 768').id 33 | ) 34 | 35 | def test_strict_format_number(self): 36 | class Person(Document): 37 | phone = PhoneField(strict=True) 38 | 39 | Person.drop_collection() 40 | 41 | person = Person(phone='12223334444') 42 | self.assertRaises(ValidationError, person.validate) 43 | self.assertRaises(ValidationError, person.save) 44 | 45 | person = Person(phone='+6594772797') 46 | assert person.phone == '+6594772797' 47 | 48 | person.save() 49 | -------------------------------------------------------------------------------- /tests/test_mongo/test_fields/test_tz.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import pytz 4 | from mongoengine import Document 5 | 6 | from flask_common.mongo.fields import TimezoneField 7 | 8 | 9 | class TimezoneFieldTestCase(unittest.TestCase): 10 | # TODO pytest-ify and test the field instance directly without persistence. 
11 | 12 | def test_timezone_field(self): 13 | class Location(Document): 14 | timezone = TimezoneField() 15 | 16 | Location.drop_collection() 17 | 18 | location = Location() 19 | location.save() 20 | location = Location.objects.get(pk=location.pk) 21 | assert location.timezone == pytz.UTC 22 | location.timezone = 'America/Los_Angeles' 23 | location.save() 24 | location = Location.objects.get(pk=location.pk) 25 | assert location.timezone == pytz.timezone('America/Los_Angeles') 26 | -------------------------------------------------------------------------------- /tests/test_mongo/test_utils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import weakref 3 | 4 | from mongoengine import ( 5 | Document, 6 | DoesNotExist, 7 | IntField, 8 | ReferenceField, 9 | SafeReferenceField, 10 | SafeReferenceListField, 11 | StringField, 12 | ) 13 | 14 | from flask_common.mongo.query_counters import custom_query_counter 15 | from flask_common.mongo.utils import fetch_related, iter_no_cache 16 | 17 | 18 | class IterNoCacheTestCase(unittest.TestCase): 19 | def test_no_cache(self): 20 | def is_cached(qs): 21 | iterator = iter(qs) 22 | d = next(iterator) 23 | self.assertEqual(d.i, 0) 24 | w = weakref.ref(d) 25 | d = next(iterator) 26 | self.assertEqual(d.i, 1) 27 | # - If the weak reference is still valid at this point, then 28 | # iterator or queryset is holding onto the first object 29 | # - Hold reference to qs until very end just in case 30 | # Python gets smart enough to destroy it 31 | return w() is not None and qs is not None 32 | 33 | class D(Document): 34 | i = IntField() 35 | pass 36 | 37 | D.drop_collection() 38 | 39 | for i in range(10): 40 | D(i=i).save() 41 | 42 | self.assertTrue(is_cached(D.objects.all())) 43 | self.assertFalse(is_cached(iter_no_cache(D.objects.all()))) 44 | 45 | # check for correct exit behavior 46 | self.assertEqual( 47 | {d.i for d in iter_no_cache(D.objects.all())}, set(range(10)) 48 | ) 49 | 
self.assertEqual( 50 | {d.i for d in iter_no_cache(D.objects.all().batch_size(5))}, 51 | set(range(10)), 52 | ) 53 | self.assertEqual( 54 | {d.i for d in iter_no_cache(D.objects.order_by('i').limit(1))}, 55 | set(range(1)), 56 | ) 57 | 58 | 59 | class FetchRelatedTestCase(unittest.TestCase): 60 | def setUp(self): 61 | super(FetchRelatedTestCase, self).setUp() 62 | 63 | class Shard(Document): 64 | pass 65 | 66 | class A(Document): 67 | shard_a = ReferenceField(Shard) 68 | txt = StringField() 69 | 70 | class B(Document): 71 | shard_b = ReferenceField(Shard) 72 | ref = ReferenceField(A) 73 | 74 | class C(Document): 75 | shard_c = ReferenceField(Shard) 76 | ref_a = ReferenceField(A) 77 | 78 | class D(Document): 79 | shard_d = ReferenceField(Shard) 80 | ref_c = ReferenceField(C) 81 | ref_a = ReferenceField(A) 82 | 83 | class E(Document): 84 | shard_e = ReferenceField(Shard) 85 | refs_a = SafeReferenceListField(ReferenceField(A)) 86 | ref_b = SafeReferenceField(B) 87 | 88 | class F(Document): 89 | shard_f = ReferenceField(Shard) 90 | ref_a = ReferenceField(A) 91 | 92 | A.drop_collection() 93 | B.drop_collection() 94 | C.drop_collection() 95 | D.drop_collection() 96 | E.drop_collection() 97 | F.drop_collection() 98 | 99 | self.Shard = Shard 100 | self.A = A 101 | self.B = B 102 | self.C = C 103 | self.D = D 104 | self.E = E 105 | self.F = F 106 | 107 | self.shard = Shard.objects.create() 108 | self.a1 = A.objects.create(shard_a=self.shard, txt='a1') 109 | self.a2 = A.objects.create(shard_a=self.shard, txt='a2') 110 | self.a3 = A.objects.create(shard_a=self.shard, txt='a3') 111 | self.b1 = B.objects.create(shard_b=self.shard, ref=self.a1) 112 | self.b2 = B.objects.create(shard_b=self.shard, ref=self.a2) 113 | self.c1 = C.objects.create(shard_c=self.shard, ref_a=self.a3) 114 | self.d1 = D.objects.create( 115 | shard_d=self.shard, ref_c=self.c1, ref_a=self.a3 116 | ) 117 | self.e1 = E.objects.create( 118 | shard_e=self.shard, 119 | refs_a=[self.a1, self.a2, self.a3], 120 | 
ref_b=self.b1, 121 | ) 122 | self.f1 = F.objects.create(shard_f=self.shard, ref_a=None) # empty ref 123 | 124 | def test_fetch_related(self): 125 | with custom_query_counter() as q: 126 | objs = list(self.B.objects.all()) 127 | fetch_related(objs, {'ref': True}) 128 | 129 | # make sure A objs are fetched 130 | for obj in objs: 131 | self.assertTrue(obj.ref.txt in ('a1', 'a2')) 132 | 133 | # one query for B, one query for A 134 | self.assertEqual(q, 2) 135 | 136 | def test_fetch_related_multiple_objs(self): 137 | with custom_query_counter() as q: 138 | objs = list(self.B.objects.all()) + list(self.C.objects.all()) 139 | fetch_related(objs, {'ref': True, 'ref_a': True}) 140 | 141 | # make sure A objs are fetched 142 | for obj in objs: 143 | if isinstance(obj, self.B): 144 | self.assertTrue(obj.ref.txt in ('a1', 'a2')) 145 | else: 146 | self.assertEqual(obj.ref_a.txt, 'a3') 147 | 148 | # one query for B, one for C, one for A 149 | self.assertEqual(q, 3) 150 | 151 | def test_fetch_related_subdict(self): 152 | """ 153 | Make sure fetching related references works with subfields and that 154 | it uses caching properly. 155 | """ 156 | with custom_query_counter() as q: 157 | objs = list(self.D.objects.all()) 158 | fetch_related(objs, {'ref_a': True, 'ref_c': {'ref_a': True}}) 159 | 160 | # make sure A objs are fetched 161 | for obj in objs: 162 | self.assertEqual(obj.ref_a.txt, 'a3') 163 | self.assertEqual(obj.ref_c.ref_a.txt, 'a3') 164 | 165 | # one query for D, one query for C, one query for A 166 | self.assertEqual(q, 3) 167 | 168 | def test_fetch_related_subdict_broken_reference(self): 169 | """ 170 | Make sure that fetching sub-references of a broken reference works. 
171 | """ 172 | 173 | # delete the object referenced by self.d1.ref_c 174 | self.c1.delete() 175 | 176 | objs = list(self.D.objects.all()) 177 | fetch_related(objs, {'ref_c': {'ref_a': True}}) 178 | self.assertTrue( 179 | objs[0].ref_c.pk 180 | ) # pk still exists even though the reference is broken 181 | self.assertRaises(DoesNotExist, lambda: objs[0].ref_c.ref_a) 182 | 183 | def test_partial_fetch_related(self): 184 | """ 185 | Make sure we can only fetch particular fields of a reference. 186 | """ 187 | objs = list(self.B.objects.all()) 188 | fetch_related(objs, {'ref': ["id"]}) 189 | self.assertEqual(objs[0].ref.pk, self.a1.pk) 190 | 191 | # "txt" field of the referenced object shouldn't be fetched 192 | self.assertEqual(objs[0].ref.txt, None) 193 | self.assertTrue(self.a1.txt) 194 | 195 | def test_partial_fetch_fields_conflict(self): 196 | """ 197 | Fetching certain fields via fetch_related has a limitation that 198 | different fields cannot be fetched for the same document class. 199 | Make sure that contraint is respected. 200 | """ 201 | objs = list(self.B.objects.all()) + list(self.C.objects.all()) 202 | self.assertRaises( 203 | RuntimeError, fetch_related, objs, {'ref': ["id"], 'ref_a': True} 204 | ) 205 | 206 | def test_partial_fetch_cache_map(self): 207 | """ 208 | Make sure doing a partial fetch in fetch_related doesn't cache 209 | the results (it could be dangerous for any subsequent fetch_related 210 | call). 
211 | """ 212 | cache_map = {} 213 | objs = list(self.D.objects.all()) 214 | fetch_related( 215 | objs, {'ref_a': True, 'ref_c': ["id"]}, cache_map=cache_map 216 | ) 217 | self.assertEqual(objs[0].ref_c.pk, self.c1.pk) 218 | self.assertEqual(objs[0].ref_a.pk, self.a3.pk) 219 | 220 | # C reference shouldn't be cached because it was a partial fetch 221 | self.assertEqual(cache_map, {self.A: {self.a3.pk: self.a3}, self.C: {}}) 222 | 223 | def test_safe_reference_fields(self): 224 | """ 225 | Make sure SafeReferenceField and SafeReferenceListField don't fetch 226 | the entire objects if we use a partial fetch_related on them. 227 | """ 228 | objs = list(self.E.objects.all()) 229 | 230 | with custom_query_counter() as q: 231 | fetch_related(objs, {'refs_a': ["id"], 'ref_b': ["id"]}) 232 | 233 | # make sure the IDs match 234 | self.assertEqual( 235 | [a.pk for a in objs[0].refs_a], [self.a1.pk, self.a2.pk, self.a3.pk] 236 | ) 237 | self.assertEqual(objs[0].ref_b.pk, self.b1.pk) 238 | 239 | # make sure other fields are empty 240 | self.assertEqual(set([a.txt for a in objs[0].refs_a]), set([None])) 241 | self.assertEqual(objs[0].ref_b.ref, None) 242 | 243 | # make sure the queries to MongoDB only fetched the IDs 244 | queries = list( 245 | q.db.system.profile.find({'op': 'query'}, {'ns': 1, 'execStats': 1}) 246 | ) 247 | self.assertEqual({q['ns'].split('.')[1] for q in queries}, {'a', 'b'}) 248 | self.assertEqual( 249 | {q['execStats']['stage'] for q in queries}, {'PROJECTION'} 250 | ) 251 | self.assertEqual( 252 | {tuple(q['execStats']['transformBy'].keys()) for q in queries}, 253 | {('_id',)}, 254 | ) 255 | 256 | def test_fetch_field_without_refs(self): 257 | """ 258 | Make sure calling fetch_related on a field that doesn't hold any 259 | references works. 
260 | """ 261 | # full fetch 262 | objs = list(self.F.objects.all()) 263 | fetch_related(objs, {'ref_a': True}) 264 | self.assertEqual(objs[0].ref_a, None) 265 | 266 | # partial fetch 267 | objs = list(self.F.objects.all()) 268 | fetch_related(objs, {'ref_a': ["id"]}) 269 | self.assertEqual(objs[0].ref_a, None) 270 | 271 | def test_fetch_same_doc_class_multiple_times_with_cache_map(self): 272 | """ 273 | Make sure that the right documents are fetched when we reuse a cache 274 | map for the same document type and the second fetch_related is a 275 | partial fetch. 276 | """ 277 | self.b1.reload() 278 | self.c1.reload() 279 | cache_map = {} 280 | objs = [self.b1, self.c1] 281 | with custom_query_counter() as q: 282 | fetch_related(objs, {'ref': True}, cache_map=cache_map) 283 | fetch_related(objs, {'ref_a': ['id']}, cache_map=cache_map) 284 | 285 | self.assertEqual(q, 2) 286 | self.assertEqual( 287 | [ 288 | op['query']['filter']['_id']['$in'][0] 289 | for op in q.db.system.profile.find({'op': 'query'}) 290 | ], 291 | [self.a1.pk, self.a3.pk], 292 | ) 293 | 294 | def test_extra_filters(self): 295 | """ 296 | Ensure we apply extra filters by collection. 297 | """ 298 | objs = list(self.E.objects.all()) 299 | 300 | with custom_query_counter() as q: 301 | fetch_related( 302 | objs, 303 | {'refs_a': True, 'ref_b': True}, 304 | extra_filters={ 305 | self.A: {'shard_a': self.shard}, 306 | self.B: {'shard_b': self.shard}, 307 | }, 308 | ) 309 | ops = list(q.db.system.profile.find({'op': 'query'})) 310 | assert len(ops) == 2 311 | filters = {op['query']['find']: op['query']['filter'] for op in ops} 312 | assert filters['a']['shard_a'] == self.shard.pk 313 | assert filters['b']['shard_b'] == self.shard.pk 314 | 315 | def test_batch_size_1(self): 316 | """ 317 | Ensure we batch requests properly, if a batch size is given. 
318 | """ 319 | objs = list(self.B.objects.all()) 320 | 321 | with custom_query_counter() as q: 322 | fetch_related(objs, {'ref': True}, batch_size=2) 323 | 324 | # make sure A objs are fetched 325 | for obj in objs: 326 | self.assertTrue(obj.ref.txt in ('a1', 'a2', 'a3')) 327 | 328 | # We need two queries to fetch 3 objects. 329 | self.assertEqual(q, 2) 330 | 331 | def test_batch_size_2(self): 332 | """ 333 | Ensure we batch requests properly, if a batch size is given. 334 | """ 335 | objs = list(self.B.objects.all()) 336 | 337 | with custom_query_counter() as q: 338 | fetch_related(objs, {'ref': True}, batch_size=3) 339 | 340 | # make sure A objs are fetched 341 | for obj in objs: 342 | self.assertTrue(obj.ref.txt in ('a1', 'a2', 'a3')) 343 | 344 | # All 3 objects are fetched in one query. 345 | self.assertEqual(q, 1) 346 | -------------------------------------------------------------------------------- /tests/test_python_support.py: -------------------------------------------------------------------------------- 1 | def test_importing_mongo(): 2 | """Verify that we can at least *import* the `mongo` package. 3 | 4 | This is a given on Python 2, but we want to make sure that at least the 5 | syntax is parseable by Python 3. 6 | """ 7 | from flask_common import mongo 8 | 9 | assert mongo 10 | -------------------------------------------------------------------------------- /tests/test_test_helpers.py: -------------------------------------------------------------------------------- 1 | from flask_common.test_helpers import ( 2 | Capture, 3 | DictCompare, 4 | RegexSetCompare, 5 | SetCompare, 6 | ) 7 | 8 | 9 | # Note we're using "not" instead of "!=" for comparisons here since the latter 10 | # uses __ne__, which is not implemented. 
11 | 12 | 13 | def test_set_compare(): 14 | assert SetCompare([1, 2, 3]) == [2, 3, 1] 15 | assert not (SetCompare([1, 2, 3]) == [2, 2, 2]) 16 | 17 | assert not (SetCompare([1, 2, 3]) != [2, 3, 1]) 18 | assert SetCompare([1, 2, 3]) != [2, 2, 2] 19 | 20 | 21 | def test_regex_set_compare(): 22 | regex = '(.*) OR (.*) OR (.*)' 23 | assert RegexSetCompare(regex, ['1', '2', '3']) == '2 OR 3 OR 1' 24 | assert not (RegexSetCompare(regex, ['2', '2', '2']) == '2 OR 3 OR 1') 25 | 26 | assert not (RegexSetCompare(regex, ['1', '2', '3']) != '2 OR 3 OR 1') 27 | assert RegexSetCompare(regex, ['2', '2', '2']) != '2 OR 3 OR 1' 28 | 29 | 30 | def test_capture(): 31 | capture = Capture() 32 | assert capture == 'hello' 33 | assert capture() == 'hello' 34 | 35 | 36 | def test_dict_compare(): 37 | assert DictCompare({'a': 'b'}) == {'a': 'b'} 38 | assert not (DictCompare({'a': 'b'}) == {'a': 'c'}) 39 | assert DictCompare({'a': 'b'}) == {'a': 'b', 'c': 'd'} 40 | assert not (DictCompare({'a': 'b'}) == {'a': 'c', 'c': 'd'}) 41 | assert not (DictCompare({'c': 'd'}) == {'a': 'b'}) 42 | 43 | assert not (DictCompare({'a': 'b'}) != {'a': 'b'}) 44 | assert DictCompare({'a': 'b'}) != {'a': 'c'} 45 | assert not (DictCompare({'a': 'b'}) != {'a': 'b', 'c': 'd'}) 46 | assert DictCompare({'a': 'b'}) != {'a': 'c', 'c': 'd'} 47 | assert DictCompare({'c': 'd'}) != {'a': 'b'} 48 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py27,py34,py35,py36 3 | 4 | [testenv] 5 | commands=pytest {posargs} 6 | deps=-rrequirements.txt 7 | pytest 8 | --------------------------------------------------------------------------------