├── .gitignore ├── .requires ├── .test-requires ├── LICENSE ├── MANIFEST.in ├── README.rst ├── bin ├── compactor_daemon ├── dump_limits ├── remote_daemon ├── setup_limits └── turnstile_command ├── setup.py ├── tests ├── __init__.py └── unit │ ├── __init__.py │ ├── test_compactor.py │ ├── test_config.py │ ├── test_control.py │ ├── test_database.py │ ├── test_limits.py │ ├── test_middleware.py │ ├── test_remote.py │ ├── test_tools.py │ ├── test_utils.py │ └── utils.py ├── tox.ini └── turnstile ├── __init__.py ├── compactor.py ├── config.py ├── control.py ├── database.py ├── limits.py ├── middleware.py ├── remote.py ├── tools.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | /build 2 | /.coverage 3 | /cov_html 4 | /dist 5 | /.tox 6 | /turnstile.egg-info 7 | *.log 8 | *.pyc 9 | -------------------------------------------------------------------------------- /.requires: -------------------------------------------------------------------------------- 1 | argparse 2 | eventlet 3 | lxml>=2.3 4 | metatools 5 | msgpack-python 6 | redis 7 | routes 8 | setuptools 9 | -------------------------------------------------------------------------------- /.test-requires: -------------------------------------------------------------------------------- 1 | mock>=1.0b1 2 | nose 3 | unittest2>=0.5.1 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 
12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 
176 | 177 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE README.rst .requires .test-requires 2 | include tests/*.py 3 | graft bin 4 | -------------------------------------------------------------------------------- /bin/compactor_daemon: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2012 Rackspace 4 | # All Rights Reserved. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); you may 7 | # not use this file except in compliance with the License. You may obtain 8 | # a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 14 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 15 | # License for the specific language governing permissions and limitations 16 | # under the License. 17 | # 18 | # 19 | # Note: This executable is provided as a convenience for running this 20 | # tool out of the source tree. It is not, nor should it be, directly 21 | # installed by setup.py. Turnstile uses console_scripts entry points 22 | # to advertise executable commands. You can find the implementation 23 | # of this script in the turnstile/tools.py file. 
import os
import sys


# We need the tools module from turnstile.  If this script is being run
# out of a source checkout (i.e., the directory two levels up from this
# executable contains the turnstile package), prepend that directory to
# sys.path so the local tree is imported in preference to any installed
# copy of turnstile.
poss_topdir = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]),
                                            os.pardir,
                                            os.pardir))
if os.path.exists(os.path.join(poss_topdir, 'turnstile', '__init__.py')):
    sys.path.insert(0, poss_topdir)


from turnstile import tools


if __name__ == '__main__':
    # The real implementation lives in turnstile.tools; this wrapper just
    # invokes its console entry point.
    tools.compactor_daemon.console()
--------------------------------------------------------------------------------
/bin/dump_limits:
--------------------------------------------------------------------------------
#!/usr/bin/python
#
# Copyright 2012 Rackspace
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
#
# Note: This executable is provided as a convenience for running this
# tool out of the source tree. It is not, nor should it be, directly
# installed by setup.py. Turnstile uses console_scripts entry points
# to advertise executable commands. You can find the implementation
# of this script in the turnstile/tools.py file.

import os
import sys


# We need the tools module from turnstile.  If this script is being run
# out of a source checkout (i.e., the directory two levels up from this
# executable contains the turnstile package), prepend that directory to
# sys.path so the local tree is imported in preference to any installed
# copy of turnstile.
poss_topdir = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]),
                                            os.pardir,
                                            os.pardir))
if os.path.exists(os.path.join(poss_topdir, 'turnstile', '__init__.py')):
    sys.path.insert(0, poss_topdir)


from turnstile import tools


if __name__ == '__main__':
    # The real implementation lives in turnstile.tools; this wrapper just
    # invokes its console entry point.
    tools.dump_limits.console()
--------------------------------------------------------------------------------
/bin/remote_daemon:
--------------------------------------------------------------------------------
#!/usr/bin/python
#
# Copyright 2012 Rackspace
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
#
# Note: This executable is provided as a convenience for running this
# tool out of the source tree. It is not, nor should it be, directly
# installed by setup.py. Turnstile uses console_scripts entry points
# to advertise executable commands. You can find the implementation
# of this script in the turnstile/tools.py file.

import os
import sys


# We need the tools module from turnstile.  If this script is being run
# out of a source checkout (i.e., the directory two levels up from this
# executable contains the turnstile package), prepend that directory to
# sys.path so the local tree is imported in preference to any installed
# copy of turnstile.
poss_topdir = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]),
                                            os.pardir,
                                            os.pardir))
if os.path.exists(os.path.join(poss_topdir, 'turnstile', '__init__.py')):
    sys.path.insert(0, poss_topdir)


from turnstile import tools


if __name__ == '__main__':
    # The real implementation lives in turnstile.tools; this wrapper just
    # invokes its console entry point.
    tools.remote_daemon.console()
--------------------------------------------------------------------------------
/bin/setup_limits:
--------------------------------------------------------------------------------
#!/usr/bin/python
#
# Copyright 2012 Rackspace
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
#
# Note: This executable is provided as a convenience for running this
# tool out of the source tree. It is not, nor should it be, directly
# installed by setup.py. Turnstile uses console_scripts entry points
# to advertise executable commands. You can find the implementation
# of this script in the turnstile/tools.py file.

import os
import sys


# We need the tools module from turnstile.  If this script is being run
# out of a source checkout (i.e., the directory two levels up from this
# executable contains the turnstile package), prepend that directory to
# sys.path so the local tree is imported in preference to any installed
# copy of turnstile.
poss_topdir = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]),
                                            os.pardir,
                                            os.pardir))
if os.path.exists(os.path.join(poss_topdir, 'turnstile', '__init__.py')):
    sys.path.insert(0, poss_topdir)


from turnstile import tools


if __name__ == '__main__':
    # The real implementation lives in turnstile.tools; this wrapper just
    # invokes its console entry point.
    tools.setup_limits.console()
--------------------------------------------------------------------------------
/bin/turnstile_command:
--------------------------------------------------------------------------------
#!/usr/bin/python
#
# Copyright 2012 Rackspace
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
#
# Note: This executable is provided as a convenience for running this
# tool out of the source tree. It is not, nor should it be, directly
# installed by setup.py. Turnstile uses console_scripts entry points
# to advertise executable commands. You can find the implementation
# of this script in the turnstile/tools.py file.
24 | 25 | import os 26 | import sys 27 | 28 | 29 | # We need the tools module from turnstile 30 | poss_topdir = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]), 31 | os.pardir, 32 | os.pardir)) 33 | if os.path.exists(os.path.join(poss_topdir, 'turnstile', '__init__.py')): 34 | sys.path.insert(0, poss_topdir) 35 | 36 | 37 | from turnstile import tools 38 | 39 | 40 | if __name__ == '__main__': 41 | tools.turnstile_command.console() 42 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | 5 | from setuptools import setup 6 | 7 | 8 | def readreq(filename): 9 | result = [] 10 | with open(filename) as f: 11 | for req in f: 12 | req = req.partition('#')[0].strip() 13 | if not req: 14 | continue 15 | result.append(req) 16 | return result 17 | 18 | 19 | def readfile(filename): 20 | with open(filename) as f: 21 | return f.read() 22 | 23 | 24 | setup( 25 | name='turnstile', 26 | version='0.7.0b2', 27 | author='Kevin L. 
Mitchell', 28 | author_email='kevin.mitchell@rackspace.com', 29 | url='https://github.com/klmitch/turnstile', 30 | description="Distributed rate-limiting middleware", 31 | long_description=readfile('README.rst'), 32 | license='Apache License (2.0)', 33 | classifiers=[ 34 | 'Development Status :: 4 - Beta', 35 | 'Environment :: Web Environment', 36 | 'Framework :: Paste', 37 | 'Intended Audience :: System Administrators', 38 | 'License :: OSI Approved :: Apache Software License', 39 | 'Programming Language :: Python', 40 | 'Topic :: Internet :: WWW/HTTP :: WSGI :: Middleware', 41 | ], 42 | packages=['turnstile'], 43 | install_requires=readreq('.requires'), 44 | tests_require=readreq('.test-requires'), 45 | entry_points={ 46 | 'paste.filter_factory': [ 47 | 'turnstile = turnstile.middleware:turnstile_filter', 48 | ], 49 | 'console_scripts': [ 50 | 'setup_limits = turnstile.tools:setup_limits.console', 51 | 'dump_limits = turnstile.tools:dump_limits.console', 52 | 'remote_daemon = turnstile.tools:remote_daemon.console', 53 | 'turnstile_command = turnstile.tools:turnstile_command.console', 54 | 'compactor_daemon = turnstile.tools:compactor.console', 55 | ], 56 | 'turnstile.redis_client': [ 57 | 'redis = redis:StrictRedis', 58 | ], 59 | 'turnstile.connection_class': [ 60 | 'redis = redis:Connection', 61 | 'unix_domain = redis:UnixDomainSocketConnection', 62 | ], 63 | 'turnstile.connection_pool': [ 64 | 'redis = redis:ConnectionPool', 65 | ], 66 | 'turnstile.limit': [ 67 | 'limit = turnstile.limits:Limit', 68 | ], 69 | 'turnstile.middleware': [ 70 | 'turnstile = turnstile.middleware:TurnstileMiddleware', 71 | ], 72 | }, 73 | ) 74 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2013 Rackspace 2 | # All Rights Reserved. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); you may 5 | # not use this file except in compliance with the License. You may obtain 6 | # a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 12 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 13 | # License for the specific language governing permissions and limitations 14 | # under the License. 15 | -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2013 Rackspace 2 | # All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); you may 5 | # not use this file except in compliance with the License. You may obtain 6 | # a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 12 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 13 | # License for the specific language governing permissions and limitations 14 | # under the License. 15 | -------------------------------------------------------------------------------- /tests/unit/test_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2013 Rackspace 2 | # All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); you may 5 | # not use this file except in compliance with the License. 
# You may obtain
# a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import mock
import unittest2

from turnstile import config
from turnstile import database


class TestConfig(unittest2.TestCase):
    """Tests for turnstile.config.Config.

    ConfigParser.SafeConfigParser is mocked throughout so that no test
    touches the filesystem; most tests then inspect the private
    cfg._config dict, which maps section name (None for the top-level
    "turnstile" section) to a dict of options.
    """

    @mock.patch('ConfigParser.SafeConfigParser')
    def test_init_empty(self, mock_SafeConfigParser):
        """With no arguments, only the default status is set and no
        config file is read."""
        cfg = config.Config()

        self.assertEqual(cfg._config, {
            None: {
                'status': '413 Request Entity Too Large',
            },
        })
        self.assertFalse(mock_SafeConfigParser.called)

    @mock.patch('ConfigParser.SafeConfigParser')
    def test_init_dict(self, mock_SafeConfigParser):
        """Dotted keys in conf_dict are split into sections; undotted
        keys land in the top-level (None) section."""
        local_conf = {
            'preprocess': 'foo:bar',
            'redis.host': '10.0.0.1',
            'control.channel': 'control_channel',
            'control.connection_pool.connection': 'FoobarConnection',
        }

        cfg = config.Config(conf_dict=local_conf)

        self.assertEqual(cfg._config, {
            None: {
                'status': '413 Request Entity Too Large',
                'preprocess': 'foo:bar',
            },
            'redis': {
                'host': '10.0.0.1',
            },
            'control': {
                # Only the first dot splits section from option name
                'connection_pool.connection': 'FoobarConnection',
                'channel': 'control_channel',
            },
        })
        self.assertFalse(mock_SafeConfigParser.called)

    @mock.patch('ConfigParser.SafeConfigParser', return_value=mock.Mock(**{
        'sections.return_value': [],
    }))
    def test_init_files(self, mock_SafeConfigParser):
        """Both the conf_file argument and a 'config' key in conf_dict
        are passed to the parser's read()."""
        local_conf = {
            'config': 'file_from_dict',
        }

        cfg = config.Config(conf_dict=local_conf, conf_file='file_from_args')

        self.assertEqual(cfg._config, {
            None: {
                'status': '413 Request Entity Too Large',
                'config': 'file_from_dict',
            },
        })
        mock_SafeConfigParser.assert_called_once_with()
        mock_SafeConfigParser.return_value.read.assert_called_once_with(
            ['file_from_dict', 'file_from_args'])

    @mock.patch('ConfigParser.SafeConfigParser', return_value=mock.Mock())
    def test_init_from_files(self, mock_SafeConfigParser):
        """Options read from files merge into _config; conf_dict values
        take precedence, and the [turnstile] section maps to None."""
        items = {
            'turnstile': [
                ('preprocess', 'foo:bar'),
            ],
            'redis': [
                ('password', 'spampass'),
            ],
            'control': [
                ('channel', 'control_channel'),
                ('connection_pool', 'FoobarConnectionPool'),
            ],
        }
        mock_SafeConfigParser.return_value.sections.return_value = \
            ['turnstile', 'redis', 'control']
        mock_SafeConfigParser.return_value.items.side_effect = \
            lambda x: items[x]
        local_conf = {
            'config': 'file_from_dict',
            'status': '500 Internal Error',
            'redis.host': '10.0.0.1',
        }

        cfg = config.Config(conf_dict=local_conf)

        self.assertEqual(cfg._config, {
            None: {
                # conf_dict's status overrides the built-in default
                'status': '500 Internal Error',
                'config': 'file_from_dict',
                'preprocess': 'foo:bar',
            },
            'redis': {
                'host': '10.0.0.1',
                'password': 'spampass',
            },
            'control': {
                'channel': 'control_channel',
                'connection_pool': 'FoobarConnectionPool',
            },
        })
        mock_SafeConfigParser.assert_called_once_with()
        mock_SafeConfigParser.return_value.assert_has_calls([
            mock.call.read(['file_from_dict']),
            mock.call.sections(),
            mock.call.items('turnstile'),
            mock.call.items('redis'),
            mock.call.items('control'),
        ])

    @mock.patch('ConfigParser.SafeConfigParser')
    def test_getitem(self, mock_SafeConfigParser):
        """cfg[section] returns the section dict, or {} if unknown."""
        local_conf = {
            'preprocess': 'foo:bar',
            'redis.host': '10.0.0.1',
            'control.channel': 'control_channel',
            'control.connection_pool.connection': 'FoobarConnection',
        }
        cfg = config.Config(conf_dict=local_conf)

        self.assertEqual(cfg['redis'], dict(host='10.0.0.1'))
        self.assertEqual(cfg['nosuch'], {})

    @mock.patch('ConfigParser.SafeConfigParser')
    def test_contains(self, mock_SafeConfigParser):
        """'section in cfg' reflects whether the section exists."""
        local_conf = {
            'preprocess': 'foo:bar',
            'redis.host': '10.0.0.1',
            'control.channel': 'control_channel',
            'control.connection_pool.connection': 'FoobarConnection',
        }
        cfg = config.Config(conf_dict=local_conf)

        self.assertTrue('redis' in cfg)
        self.assertFalse('nosuch' in cfg)

    @mock.patch('ConfigParser.SafeConfigParser')
    def test_getattr(self, mock_SafeConfigParser):
        """Top-level options are attributes; missing ones raise
        AttributeError."""
        local_conf = {
            'preprocess': 'foo:bar',
            'redis.host': '10.0.0.1',
            'control.channel': 'control_channel',
            'control.connection_pool.connection': 'FoobarConnection',
        }
        cfg = config.Config(conf_dict=local_conf)

        self.assertEqual(cfg.preprocess, 'foo:bar')
        with self.assertRaises(AttributeError):
            dummy = cfg.nosuch

    @mock.patch('ConfigParser.SafeConfigParser')
    def test_get(self, mock_SafeConfigParser):
        """get() mirrors dict.get() on the top-level section."""
        local_conf = {
            'preprocess': 'foo:bar',
            'redis.host': '10.0.0.1',
            'control.channel': 'control_channel',
            'control.connection_pool.connection': 'FoobarConnection',
        }
        cfg = config.Config(conf_dict=local_conf)

        self.assertEqual(cfg.get('preprocess'), 'foo:bar')
        self.assertEqual(cfg.get('nosuch'), None)
        self.assertEqual(cfg.get('nosuch', 'other'), 'other')

    @mock.patch('ConfigParser.SafeConfigParser')
    @mock.patch.object(database, 'initialize', return_value='db_handle')
    def test_get_database_basic(self, mock_initialize, mock_SafeConfigParser):
        """Without an override, get_database() passes only the [redis]
        options to database.initialize() and leaves _config untouched."""
        local_conf = {
            'redis.host': '10.0.0.1',
            'redis.password': 'spampass',
            'redis.db': '3',
            'control.host': '10.0.0.2',
            'control.redis.host': '10.0.0.11',
            'control.redis.password': 'passspam',
            'control.redis.port': '1234',
        }
        cfg = config.Config(conf_dict=local_conf)

        result = cfg.get_database()

        self.assertEqual(result, 'db_handle')
        mock_initialize.assert_called_once_with({
            'host': '10.0.0.1',
            'password': 'spampass',
            'db': '3',
        })
        self.assertEqual(cfg._config, {
            None: {
                'status': '413 Request Entity Too Large',
            },
            'redis': {
                'host': '10.0.0.1',
                'password': 'spampass',
                'db': '3',
            },
            'control': {
                'host': '10.0.0.2',
                'redis.host': '10.0.0.11',
                'redis.password': 'passspam',
                'redis.port': '1234',
            },
        })

    @mock.patch('ConfigParser.SafeConfigParser')
    @mock.patch.object(database, 'initialize', return_value='db_handle')
    def test_get_database_override(self, mock_initialize,
                                   mock_SafeConfigParser):
        """With override='control', the [control] section's redis.*
        options replace the base [redis] options; an empty-string value
        (redis.db) drops the corresponding base option entirely."""
        local_conf = {
            'redis.host': '10.0.0.1',
            'redis.password': 'spampass',
            'redis.db': '3',
            'control.host': '10.0.0.2',
            'control.redis.host': '10.0.0.11',
            'control.redis.port': '1234',
            'control.redis.password': 'passspam',
            'control.redis.db': '',
        }
        cfg = config.Config(conf_dict=local_conf)

        result = cfg.get_database(override='control')

        self.assertEqual(result, 'db_handle')
        mock_initialize.assert_called_once_with({
            'host': '10.0.0.11',
            'port': '1234',
            'password': 'passspam',
        })
        self.assertEqual(cfg._config, {
            None: {
                'status': '413 Request Entity Too Large',
            },
            'redis': {
                'host': '10.0.0.1',
                'password': 'spampass',
                'db': '3',
            },
            'control': {
                'host': '10.0.0.2',
                'redis.host': '10.0.0.11',
                'redis.password': 'passspam',
                'redis.port': '1234',
                'redis.db': '',
            },
        })

    def test_to_bool_integers(self):
        """Numeric strings follow C truthiness: '0' is False, anything
        else is True."""
        self.assertEqual(config.Config.to_bool('0'), False)
        self.assertEqual(config.Config.to_bool('1'), True)
        self.assertEqual(config.Config.to_bool('123412341234'), True)

    def test_to_bool_true(self):
        """All the conventional true-ish words map to True."""
        self.assertEqual(config.Config.to_bool('t'), True)
        self.assertEqual(config.Config.to_bool('true'), True)
        self.assertEqual(config.Config.to_bool('on'), True)
        self.assertEqual(config.Config.to_bool('y'), True)
        self.assertEqual(config.Config.to_bool('yes'), True)

    def test_to_bool_false(self):
        """All the conventional false-ish words map to False."""
        self.assertEqual(config.Config.to_bool('f'), False)
        self.assertEqual(config.Config.to_bool('false'), False)
        self.assertEqual(config.Config.to_bool('off'), False)
        self.assertEqual(config.Config.to_bool('n'), False)
        self.assertEqual(config.Config.to_bool('no'), False)

    def test_to_bool_invalid(self):
        """Unrecognized values raise ValueError by default."""
        self.assertRaises(ValueError, config.Config.to_bool, 'invalid')

    def test_to_bool_invalid_noraise(self):
        """With do_raise=False, unrecognized values return False."""
        self.assertEqual(config.Config.to_bool('invalid', False), False)
--------------------------------------------------------------------------------
/tests/unit/test_control.py:
--------------------------------------------------------------------------------
# Copyright 2013 Rackspace
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
15 | 16 | import eventlet.semaphore 17 | import mock 18 | import unittest2 19 | 20 | from turnstile import config 21 | from turnstile import control 22 | from turnstile import utils 23 | 24 | from tests.unit import utils as test_utils 25 | 26 | 27 | class TestLimitData(unittest2.TestCase): 28 | # MD5 sum of '' 29 | EMPTY_CHECKSUM = 'd41d8cd98f00b204e9800998ecf8427e' 30 | 31 | # Test data 32 | TEST_DATA = ['Nobody', 'inspects', 'the', 'spammish', 'repetition'] 33 | 34 | # MD5 sum of the test data 35 | TEST_DATA_CHECKSUM = '2c79f652a24d3f4438b6b4034ae120cb' 36 | 37 | def test_init(self): 38 | ld = control.LimitData() 39 | 40 | self.assertEqual(ld.limit_data, []) 41 | self.assertEqual(ld.limit_sum, self.EMPTY_CHECKSUM) 42 | self.assertIsInstance(ld.limit_lock, eventlet.semaphore.Semaphore) 43 | 44 | @mock.patch.object(eventlet.semaphore, 'Semaphore', 45 | return_value=mock.MagicMock()) 46 | @mock.patch('msgpack.loads', side_effect=lambda x: x) 47 | def test_set_limits_nochange(self, mock_loads, mock_Semaphore): 48 | ld = control.LimitData() 49 | ld.limit_sum = self.TEST_DATA_CHECKSUM 50 | 51 | ld.set_limits(self.TEST_DATA) 52 | 53 | mock_Semaphore.return_value.assert_has_calls([ 54 | mock.call.__enter__(), 55 | mock.call.__exit__(None, None, None), 56 | ]) 57 | self.assertFalse(mock_loads.called) 58 | self.assertEqual(ld.limit_data, []) 59 | self.assertEqual(ld.limit_sum, self.TEST_DATA_CHECKSUM) 60 | 61 | @mock.patch.object(eventlet.semaphore, 'Semaphore', 62 | return_value=mock.MagicMock()) 63 | @mock.patch('msgpack.loads', side_effect=lambda x: x) 64 | def test_set_limits(self, mock_loads, mock_Semaphore): 65 | ld = control.LimitData() 66 | 67 | ld.set_limits(self.TEST_DATA) 68 | 69 | mock_Semaphore.return_value.assert_has_calls([ 70 | mock.call.__enter__(), 71 | mock.call.__exit__(None, None, None), 72 | ]) 73 | mock_loads.assert_has_calls([mock.call(x) for x in self.TEST_DATA]) 74 | self.assertEqual(ld.limit_data, self.TEST_DATA) 75 | 
self.assertEqual(ld.limit_sum, self.TEST_DATA_CHECKSUM) 76 | 77 | @mock.patch.object(eventlet.semaphore, 'Semaphore', 78 | return_value=mock.MagicMock()) 79 | def test_get_limits_nosum(self, mock_Semaphore): 80 | ld = control.LimitData() 81 | ld.limit_data = self.TEST_DATA 82 | ld.limit_sum = self.TEST_DATA_CHECKSUM 83 | 84 | result = ld.get_limits() 85 | 86 | mock_Semaphore.return_value.assert_has_calls([ 87 | mock.call.__enter__(), 88 | mock.call.__exit__(None, None, None), 89 | ]) 90 | self.assertEqual(result, (self.TEST_DATA_CHECKSUM, self.TEST_DATA)) 91 | 92 | @mock.patch.object(eventlet.semaphore, 'Semaphore', 93 | return_value=mock.MagicMock()) 94 | def test_get_limits_nochange(self, mock_Semaphore): 95 | ld = control.LimitData() 96 | ld.limit_data = self.TEST_DATA 97 | ld.limit_sum = self.TEST_DATA_CHECKSUM 98 | 99 | self.assertRaises(control.NoChangeException, ld.get_limits, 100 | self.TEST_DATA_CHECKSUM) 101 | mock_Semaphore.return_value.assert_has_calls([ 102 | mock.call.__enter__(), 103 | mock.call.__exit__(control.NoChangeException, mock.ANY, mock.ANY), 104 | ]) 105 | 106 | @mock.patch.object(eventlet.semaphore, 'Semaphore', 107 | return_value=mock.MagicMock()) 108 | def test_get_limits(self, mock_Semaphore): 109 | ld = control.LimitData() 110 | ld.limit_data = self.TEST_DATA 111 | ld.limit_sum = self.TEST_DATA_CHECKSUM 112 | 113 | result = ld.get_limits(self.EMPTY_CHECKSUM) 114 | 115 | mock_Semaphore.return_value.assert_has_calls([ 116 | mock.call.__enter__(), 117 | mock.call.__exit__(None, None, None), 118 | ]) 119 | self.assertEqual(result, (self.TEST_DATA_CHECKSUM, self.TEST_DATA)) 120 | 121 | 122 | class TestControlDaemon(unittest2.TestCase): 123 | @mock.patch.dict(control.ControlDaemon._commands) 124 | def test_register(self): 125 | self.assertEqual(control.ControlDaemon._commands, { 126 | 'ping': control.ping, 127 | 'reload': control.reload, 128 | }) 129 | 130 | control.ControlDaemon._register('spam', 'ni') 131 | 132 | 
self.assertEqual(control.ControlDaemon._commands, { 133 | 'ping': control.ping, 134 | 'reload': control.reload, 135 | 'spam': 'ni', 136 | }) 137 | 138 | def test_init(self): 139 | cd = control.ControlDaemon('middleware', 'config') 140 | 141 | self.assertEqual(cd._db, None) 142 | self.assertEqual(cd.middleware, 'middleware') 143 | self.assertEqual(cd.config, 'config') 144 | self.assertIsInstance(cd.limits, control.LimitData) 145 | self.assertIsInstance(cd.pending, eventlet.semaphore.Semaphore) 146 | self.assertEqual(cd.listen_thread, None) 147 | 148 | @mock.patch.object(eventlet, 'spawn_n', return_value='listen_thread') 149 | @mock.patch.object(control.ControlDaemon, 'reload') 150 | def test_start(self, mock_reload, mock_spawn_n): 151 | cd = control.ControlDaemon('middleware', 'config') 152 | 153 | cd.start() 154 | 155 | mock_spawn_n.assert_called_once_with(cd.listen) 156 | self.assertEqual(cd.listen_thread, 'listen_thread') 157 | mock_reload.assert_called_once_with() 158 | 159 | @mock.patch.dict(control.ControlDaemon._commands, clear=True, 160 | ping=mock.Mock(), _ping=mock.Mock(), 161 | fail=mock.Mock(side_effect=test_utils.TestException)) 162 | @mock.patch.object(utils, 'find_entrypoint') 163 | @mock.patch.object(config.Config, 'get_database') 164 | @mock.patch.object(control.LOG, 'error') 165 | @mock.patch.object(control.LOG, 'exception') 166 | def test_listen(self, mock_exception, mock_error, mock_get_database, 167 | mock_find_entrypoint): 168 | entrypoints = dict(discovered=mock.Mock()) 169 | mock_find_entrypoint.side_effect = (lambda x, y, compat: 170 | entrypoints.get(y)) 171 | pubsub = mock.Mock(**{'listen.return_value': [ 172 | { 173 | 'type': 'other', 174 | 'channel': 'control', 175 | 'data': 'ping', 176 | }, 177 | { 178 | 'type': 'pmessage', 179 | 'channel': 'other', 180 | 'data': 'ping', 181 | }, 182 | { 183 | 'type': 'message', 184 | 'channel': 'other', 185 | 'data': 'ping', 186 | }, 187 | { 188 | 'type': 'pmessage', 189 | 'channel': 'control', 190 | 
'data': '', 191 | }, 192 | { 193 | 'type': 'message', 194 | 'channel': 'control', 195 | 'data': '', 196 | }, 197 | { 198 | 'type': 'pmessage', 199 | 'channel': 'control', 200 | 'data': '_ping', 201 | }, 202 | { 203 | 'type': 'message', 204 | 'channel': 'control', 205 | 'data': '_ping', 206 | }, 207 | { 208 | 'type': 'pmessage', 209 | 'channel': 'control', 210 | 'data': 'nosuch', 211 | }, 212 | { 213 | 'type': 'message', 214 | 'channel': 'control', 215 | 'data': 'nosuch', 216 | }, 217 | { 218 | 'type': 'pmessage', 219 | 'channel': 'control', 220 | 'data': 'fail', 221 | }, 222 | { 223 | 'type': 'message', 224 | 'channel': 'control', 225 | 'data': 'fail', 226 | }, 227 | { 228 | 'type': 'pmessage', 229 | 'channel': 'control', 230 | 'data': 'fail:arg1:arg2', 231 | }, 232 | { 233 | 'type': 'message', 234 | 'channel': 'control', 235 | 'data': 'fail:arg1:arg2', 236 | }, 237 | { 238 | 'type': 'pmessage', 239 | 'channel': 'control', 240 | 'data': 'ping', 241 | }, 242 | { 243 | 'type': 'message', 244 | 'channel': 'control', 245 | 'data': 'ping', 246 | }, 247 | { 248 | 'type': 'pmessage', 249 | 'channel': 'control', 250 | 'data': 'ping:arg1:arg2', 251 | }, 252 | { 253 | 'type': 'message', 254 | 'channel': 'control', 255 | 'data': 'ping:arg1:arg2', 256 | }, 257 | { 258 | 'type': 'pmessage', 259 | 'channel': 'control', 260 | 'data': 'discovered', 261 | }, 262 | { 263 | 'type': 'message', 264 | 'channel': 'control', 265 | 'data': 'discovered', 266 | }, 267 | { 268 | 'type': 'pmessage', 269 | 'channel': 'control', 270 | 'data': 'discovered:arg1:arg2', 271 | }, 272 | { 273 | 'type': 'message', 274 | 'channel': 'control', 275 | 'data': 'discovered:arg1:arg2', 276 | }, 277 | ]}) 278 | db = mock.Mock(**{'pubsub.return_value': pubsub}) 279 | mock_get_database.return_value = db 280 | cd = control.ControlDaemon('middleware', config.Config()) 281 | 282 | cd.listen() 283 | 284 | mock_get_database.assert_called_once_with('control') 285 | db.pubsub.assert_called_once_with() 286 | 
pubsub.assert_has_calls([ 287 | mock.call.subscribe('control'), 288 | mock.call.listen(), 289 | ]) 290 | mock_error.assert_has_calls([ 291 | mock.call("Cannot call internal command '_ping'"), 292 | mock.call("Cannot call internal command '_ping'"), 293 | mock.call("No such command 'nosuch'"), 294 | mock.call("No such command 'nosuch'"), 295 | ]) 296 | mock_exception.assert_has_calls([ 297 | mock.call("Failed to execute command 'fail' arguments []"), 298 | mock.call("Failed to execute command 'fail' arguments []"), 299 | mock.call("Failed to execute command 'fail' arguments " 300 | "['arg1', 'arg2']"), 301 | mock.call("Failed to execute command 'fail' arguments " 302 | "['arg1', 'arg2']"), 303 | ]) 304 | control.ControlDaemon._commands['ping'].assert_has_calls([ 305 | mock.call(cd), 306 | mock.call(cd), 307 | mock.call(cd, 'arg1', 'arg2'), 308 | mock.call(cd, 'arg1', 'arg2'), 309 | ]) 310 | self.assertFalse(control.ControlDaemon._commands['_ping'].called) 311 | control.ControlDaemon._commands['fail'].assert_has_calls([ 312 | mock.call(cd), 313 | mock.call(cd), 314 | mock.call(cd, 'arg1', 'arg2'), 315 | mock.call(cd, 'arg1', 'arg2'), 316 | ]) 317 | entrypoints['discovered'].assert_has_calls([ 318 | mock.call(cd), 319 | mock.call(cd), 320 | mock.call(cd, 'arg1', 'arg2'), 321 | mock.call(cd, 'arg1', 'arg2'), 322 | ]) 323 | mock_find_entrypoint.assert_has_calls([ 324 | mock.call('turnstile.command', 'nosuch', compat=False), 325 | mock.call('turnstile.command', 'discovered', compat=False), 326 | ]) 327 | self.assertEqual(len(mock_find_entrypoint.mock_calls), 2) 328 | self.assertEqual(control.ControlDaemon._commands['discovered'], 329 | entrypoints['discovered']) 330 | self.assertEqual(control.ControlDaemon._commands['nosuch'], None) 331 | 332 | @mock.patch.dict(control.ControlDaemon._commands, clear=True, 333 | ping=mock.Mock(), _ping=mock.Mock(), 334 | fail=mock.Mock(side_effect=test_utils.TestException)) 335 | @mock.patch.object(utils, 'find_entrypoint', 
return_value=None) 336 | @mock.patch.object(config.Config, 'get_database') 337 | @mock.patch.object(control.LOG, 'error') 338 | @mock.patch.object(control.LOG, 'exception') 339 | def test_listen_altchan(self, mock_exception, mock_error, 340 | mock_get_database, mock_find_entrypoint): 341 | pubsub = mock.Mock(**{'listen.return_value': [ 342 | { 343 | 'type': 'other', 344 | 'channel': 'control', 345 | 'data': 'ping', 346 | }, 347 | { 348 | 'type': 'pmessage', 349 | 'channel': 'other', 350 | 'data': 'ping', 351 | }, 352 | { 353 | 'type': 'message', 354 | 'channel': 'other', 355 | 'data': 'ping', 356 | }, 357 | { 358 | 'type': 'pmessage', 359 | 'channel': 'control', 360 | 'data': '', 361 | }, 362 | { 363 | 'type': 'message', 364 | 'channel': 'control', 365 | 'data': '', 366 | }, 367 | { 368 | 'type': 'pmessage', 369 | 'channel': 'control', 370 | 'data': '_ping', 371 | }, 372 | { 373 | 'type': 'message', 374 | 'channel': 'control', 375 | 'data': '_ping', 376 | }, 377 | { 378 | 'type': 'pmessage', 379 | 'channel': 'control', 380 | 'data': 'nosuch', 381 | }, 382 | { 383 | 'type': 'message', 384 | 'channel': 'control', 385 | 'data': 'nosuch', 386 | }, 387 | { 388 | 'type': 'pmessage', 389 | 'channel': 'control', 390 | 'data': 'fail', 391 | }, 392 | { 393 | 'type': 'message', 394 | 'channel': 'control', 395 | 'data': 'fail', 396 | }, 397 | { 398 | 'type': 'pmessage', 399 | 'channel': 'control', 400 | 'data': 'fail:arg1:arg2', 401 | }, 402 | { 403 | 'type': 'message', 404 | 'channel': 'control', 405 | 'data': 'fail:arg1:arg2', 406 | }, 407 | { 408 | 'type': 'pmessage', 409 | 'channel': 'control', 410 | 'data': 'ping', 411 | }, 412 | { 413 | 'type': 'message', 414 | 'channel': 'control', 415 | 'data': 'ping', 416 | }, 417 | { 418 | 'type': 'pmessage', 419 | 'channel': 'control', 420 | 'data': 'ping:arg1:arg2', 421 | }, 422 | { 423 | 'type': 'message', 424 | 'channel': 'control', 425 | 'data': 'ping:arg1:arg2', 426 | }, 427 | ]}) 428 | db = mock.Mock(**{'pubsub.return_value': 
pubsub}) 429 | mock_get_database.return_value = db 430 | cd = control.ControlDaemon('middleware', config.Config(conf_dict={ 431 | 'control.channel': 'other', 432 | })) 433 | 434 | cd.listen() 435 | 436 | mock_get_database.assert_called_once_with('control') 437 | db.pubsub.assert_called_once_with() 438 | pubsub.assert_has_calls([ 439 | mock.call.subscribe('other'), 440 | mock.call.listen(), 441 | ]) 442 | self.assertFalse(mock_error.called) 443 | self.assertFalse(mock_exception.called) 444 | control.ControlDaemon._commands['ping'].assert_has_calls([ 445 | mock.call(cd), 446 | mock.call(cd), 447 | ]) 448 | self.assertFalse(control.ControlDaemon._commands['_ping'].called) 449 | self.assertFalse(control.ControlDaemon._commands['fail'].called) 450 | self.assertFalse(mock_find_entrypoint.called) 451 | 452 | @mock.patch.object(config.Config, 'get_database') 453 | def test_listen_shardhint(self, mock_get_database): 454 | pubsub = mock.Mock(**{'listen.return_value': []}) 455 | db = mock.Mock(**{'pubsub.return_value': pubsub}) 456 | mock_get_database.return_value = db 457 | cd = control.ControlDaemon('middleware', config.Config(conf_dict={ 458 | 'control.shard_hint': 'shard', 459 | })) 460 | 461 | cd.listen() 462 | 463 | mock_get_database.assert_called_once_with('control') 464 | db.pubsub.assert_called_once_with(shard_hint='shard') 465 | pubsub.assert_has_calls([ 466 | mock.call.subscribe('control'), 467 | mock.call.listen(), 468 | ]) 469 | 470 | def test_get_limits(self): 471 | cd = control.ControlDaemon('middleware', 'config') 472 | cd.limits = 'limits' 473 | 474 | self.assertEqual(cd.get_limits(), 'limits') 475 | 476 | @mock.patch.object(control.LOG, 'exception') 477 | @mock.patch('traceback.format_exc', return_value='') 478 | def test_reload_noacquire(self, mock_format_exc, mock_exception): 479 | cd = control.ControlDaemon('middleware', config.Config()) 480 | cd.pending = mock.Mock(**{'acquire.return_value': False}) 481 | cd.limits = mock.Mock() 482 | cd._db = 
mock.Mock() 483 | 484 | cd.reload() 485 | 486 | cd.pending.assert_has_calls([ 487 | mock.call.acquire(False), 488 | ]) 489 | self.assertEqual(len(cd.pending.method_calls), 1) 490 | self.assertEqual(len(cd.limits.method_calls), 0) 491 | self.assertEqual(len(cd._db.method_calls), 0) 492 | self.assertFalse(mock_exception.called) 493 | self.assertFalse(mock_format_exc.called) 494 | 495 | @mock.patch.object(control.LOG, 'exception') 496 | @mock.patch('traceback.format_exc', return_value='') 497 | def test_reload(self, mock_format_exc, mock_exception): 498 | cd = control.ControlDaemon('middleware', config.Config()) 499 | cd.pending = mock.Mock(**{'acquire.return_value': True}) 500 | cd.limits = mock.Mock() 501 | cd._db = mock.Mock(**{'zrange.return_value': ['limit1', 'limit2']}) 502 | 503 | cd.reload() 504 | 505 | cd.pending.assert_has_calls([ 506 | mock.call.acquire(False), 507 | mock.call.release(), 508 | ]) 509 | self.assertEqual(len(cd.pending.method_calls), 2) 510 | cd.limits.set_limits.assert_called_once_with(['limit1', 'limit2']) 511 | cd._db.assert_has_calls([ 512 | mock.call.zrange('limits', 0, -1), 513 | ]) 514 | self.assertEqual(len(cd._db.method_calls), 1) 515 | self.assertFalse(mock_exception.called) 516 | self.assertFalse(mock_format_exc.called) 517 | 518 | @mock.patch.object(control.LOG, 'exception') 519 | @mock.patch('traceback.format_exc', return_value='') 520 | def test_reload_altlimits(self, mock_format_exc, mock_exception): 521 | cd = control.ControlDaemon('middleware', config.Config(conf_dict={ 522 | 'control.limits_key': 'other', 523 | })) 524 | cd.pending = mock.Mock(**{'acquire.return_value': True}) 525 | cd.limits = mock.Mock() 526 | cd._db = mock.Mock(**{'zrange.return_value': ['limit1', 'limit2']}) 527 | 528 | cd.reload() 529 | 530 | cd.pending.assert_has_calls([ 531 | mock.call.acquire(False), 532 | mock.call.release(), 533 | ]) 534 | self.assertEqual(len(cd.pending.method_calls), 2) 535 | 
cd.limits.set_limits.assert_called_once_with(['limit1', 'limit2']) 536 | cd._db.assert_has_calls([ 537 | mock.call.zrange('other', 0, -1), 538 | ]) 539 | self.assertEqual(len(cd._db.method_calls), 1) 540 | self.assertFalse(mock_exception.called) 541 | self.assertFalse(mock_format_exc.called) 542 | 543 | @mock.patch.object(control.LOG, 'exception') 544 | @mock.patch('traceback.format_exc', return_value='') 545 | def test_reload_exception(self, mock_format_exc, mock_exception): 546 | cd = control.ControlDaemon('middleware', config.Config()) 547 | cd.pending = mock.Mock(**{'acquire.return_value': True}) 548 | cd.limits = mock.Mock(**{ 549 | 'set_limits.side_effect': test_utils.TestException, 550 | }) 551 | cd._db = mock.Mock(**{'zrange.return_value': ['limit1', 'limit2']}) 552 | 553 | cd.reload() 554 | 555 | cd.pending.assert_has_calls([ 556 | mock.call.acquire(False), 557 | mock.call.release(), 558 | ]) 559 | self.assertEqual(len(cd.pending.method_calls), 2) 560 | cd.limits.set_limits.assert_called_once_with(['limit1', 'limit2']) 561 | cd._db.assert_has_calls([ 562 | mock.call.zrange('limits', 0, -1), 563 | mock.call.sadd('errors', 'Failed to load limits: '), 564 | mock.call.publish('errors', 'Failed to load limits: '), 565 | ]) 566 | self.assertEqual(len(cd._db.method_calls), 3) 567 | mock_exception.assert_called_once_with('Could not load limits') 568 | mock_format_exc.assert_called_once_with() 569 | 570 | @mock.patch.object(control.LOG, 'exception') 571 | @mock.patch('traceback.format_exc', return_value='') 572 | def test_reload_exception_altkeys(self, mock_format_exc, mock_exception): 573 | cd = control.ControlDaemon('middleware', config.Config(conf_dict={ 574 | 'control.errors_key': 'alt_err', 575 | 'control.errors_channel': 'alt_chan', 576 | })) 577 | cd.pending = mock.Mock(**{'acquire.return_value': True}) 578 | cd.limits = mock.Mock(**{ 579 | 'set_limits.side_effect': test_utils.TestException, 580 | }) 581 | cd._db = mock.Mock(**{'zrange.return_value': 
['limit1', 'limit2']}) 582 | 583 | cd.reload() 584 | 585 | cd.pending.assert_has_calls([ 586 | mock.call.acquire(False), 587 | mock.call.release(), 588 | ]) 589 | self.assertEqual(len(cd.pending.method_calls), 2) 590 | cd.limits.set_limits.assert_called_once_with(['limit1', 'limit2']) 591 | cd._db.assert_has_calls([ 592 | mock.call.zrange('limits', 0, -1), 593 | mock.call.sadd('alt_err', 'Failed to load limits: '), 594 | mock.call.publish('alt_chan', 595 | 'Failed to load limits: '), 596 | ]) 597 | self.assertEqual(len(cd._db.method_calls), 3) 598 | mock_exception.assert_called_once_with('Could not load limits') 599 | mock_format_exc.assert_called_once_with() 600 | 601 | def test_db_present(self): 602 | middleware = mock.Mock(db='midware_db') 603 | cd = control.ControlDaemon(middleware, config.Config()) 604 | cd._db = 'cached_db' 605 | 606 | self.assertEqual(cd.db, 'cached_db') 607 | 608 | def test_db_middleware(self): 609 | middleware = mock.Mock(db='midware_db') 610 | cd = control.ControlDaemon(middleware, config.Config()) 611 | 612 | self.assertEqual(cd.db, 'midware_db') 613 | 614 | 615 | class TestRegister(unittest2.TestCase): 616 | @mock.patch.object(control.ControlDaemon, '_register') 617 | def test_as_function(self, mock_register): 618 | control.register('spam', 'func') 619 | 620 | mock_register.assert_called_once_with('spam', 'func') 621 | 622 | @mock.patch.object(control.ControlDaemon, '_register') 623 | def test_as_decorator(self, mock_register): 624 | @control.register('spam') 625 | def func(): 626 | pass 627 | 628 | mock_register.assert_called_once_with('spam', func) 629 | 630 | 631 | class TestPing(unittest2.TestCase): 632 | def test_ping_no_channel(self): 633 | conf = config.Config() 634 | db = mock.Mock() 635 | daemon = mock.Mock(config=conf, db=db) 636 | 637 | control.ping(daemon, '') 638 | 639 | self.assertFalse(db.publish.called) 640 | 641 | def test_ping_no_data_no_nodename(self): 642 | conf = config.Config() 643 | db = mock.Mock() 644 | daemon = 
mock.Mock(config=conf, db=db) 645 | 646 | control.ping(daemon, 'reply') 647 | 648 | db.publish.assert_called_once_with('reply', 'pong') 649 | 650 | def test_ping_with_data_no_nodename(self): 651 | conf = config.Config() 652 | db = mock.Mock() 653 | daemon = mock.Mock(config=conf, db=db) 654 | 655 | control.ping(daemon, 'reply', 'data') 656 | 657 | db.publish.assert_called_once_with('reply', 'pong::data') 658 | 659 | def test_ping_no_data_with_nodename(self): 660 | conf = config.Config(conf_dict={ 661 | 'control.node_name': 'node', 662 | }) 663 | db = mock.Mock() 664 | daemon = mock.Mock(config=conf, db=db) 665 | 666 | control.ping(daemon, 'reply') 667 | 668 | db.publish.assert_called_once_with('reply', 'pong:node') 669 | 670 | def test_ping_with_data_with_nodename(self): 671 | conf = config.Config(conf_dict={ 672 | 'control.node_name': 'node', 673 | }) 674 | db = mock.Mock() 675 | daemon = mock.Mock(config=conf, db=db) 676 | 677 | control.ping(daemon, 'reply', 'data') 678 | 679 | db.publish.assert_called_once_with('reply', 'pong:node:data') 680 | 681 | 682 | class TestReload(unittest2.TestCase): 683 | @mock.patch.object(eventlet, 'spawn_after') 684 | @mock.patch.object(eventlet, 'spawn_n') 685 | @mock.patch('random.random', return_value=0.5) 686 | def test_basic(self, mock_random, mock_spawn_n, mock_spawn_after): 687 | daemon = mock.Mock(reload='reload', config=config.Config()) 688 | 689 | control.reload(daemon) 690 | 691 | self.assertFalse(mock_random.called) 692 | self.assertFalse(mock_spawn_after.called) 693 | mock_spawn_n.assert_called_once_with('reload') 694 | 695 | @mock.patch.object(eventlet, 'spawn_after') 696 | @mock.patch.object(eventlet, 'spawn_n') 697 | @mock.patch('random.random', return_value=0.5) 698 | def test_configured_spread(self, mock_random, mock_spawn_n, 699 | mock_spawn_after): 700 | daemon = mock.Mock(reload='reload', config=config.Config(conf_dict={ 701 | 'control.reload_spread': '20.4', 702 | })) 703 | 704 | control.reload(daemon) 705 | 
706 | mock_random.assert_called_once_with() 707 | mock_spawn_after.assert_called_once_with(10.2, 'reload') 708 | self.assertFalse(mock_spawn_n.called) 709 | 710 | @mock.patch.object(eventlet, 'spawn_after') 711 | @mock.patch.object(eventlet, 'spawn_n') 712 | @mock.patch('random.random', return_value=0.5) 713 | def test_configured_spread_bad(self, mock_random, mock_spawn_n, 714 | mock_spawn_after): 715 | daemon = mock.Mock(reload='reload', config=config.Config(conf_dict={ 716 | 'control.reload_spread': '20.4.3', 717 | })) 718 | 719 | control.reload(daemon) 720 | 721 | self.assertFalse(mock_random.called) 722 | self.assertFalse(mock_spawn_after.called) 723 | mock_spawn_n.assert_called_once_with('reload') 724 | 725 | @mock.patch.object(eventlet, 'spawn_after') 726 | @mock.patch.object(eventlet, 'spawn_n') 727 | @mock.patch('random.random', return_value=0.5) 728 | def test_configured_spread_override(self, mock_random, mock_spawn_n, 729 | mock_spawn_after): 730 | daemon = mock.Mock(reload='reload', config=config.Config(conf_dict={ 731 | 'control.reload_spread': '20.4', 732 | })) 733 | 734 | control.reload(daemon, 'immediate') 735 | 736 | self.assertFalse(mock_random.called) 737 | self.assertFalse(mock_spawn_after.called) 738 | mock_spawn_n.assert_called_once_with('reload') 739 | 740 | @mock.patch.object(eventlet, 'spawn_after') 741 | @mock.patch.object(eventlet, 'spawn_n') 742 | @mock.patch('random.random', return_value=0.5) 743 | def test_forced_spread(self, mock_random, mock_spawn_n, mock_spawn_after): 744 | daemon = mock.Mock(reload='reload', config=config.Config()) 745 | 746 | control.reload(daemon, 'spread', '20.4') 747 | 748 | mock_random.assert_called_once_with() 749 | mock_spawn_after.assert_called_once_with(10.2, 'reload') 750 | self.assertFalse(mock_spawn_n.called) 751 | 752 | @mock.patch.object(eventlet, 'spawn_after') 753 | @mock.patch.object(eventlet, 'spawn_n') 754 | @mock.patch('random.random', return_value=0.5) 755 | def test_bad_spread_fallback(self, 
mock_random, mock_spawn_n, 756 | mock_spawn_after): 757 | daemon = mock.Mock(reload='reload', config=config.Config(conf_dict={ 758 | 'control.reload_spread': '40.8', 759 | })) 760 | 761 | control.reload(daemon, 'spread', '20.4.3') 762 | 763 | mock_random.assert_called_once_with() 764 | mock_spawn_after.assert_called_once_with(20.4, 'reload') 765 | self.assertFalse(mock_spawn_n.called) 766 | -------------------------------------------------------------------------------- /tests/unit/test_database.py: -------------------------------------------------------------------------------- 1 | # Copyright 2013 Rackspace 2 | # All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); you may 5 | # not use this file except in compliance with the License. You may obtain 6 | # a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 12 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 13 | # License for the specific language governing permissions and limitations 14 | # under the License. 

import mock
import redis
import unittest2

from turnstile import database
from turnstile import limits
from turnstile import utils


class TestInitialize(unittest2.TestCase):
    def make_entrypoints(self, mock_find_entrypoint, **entrypoints):
        # Wire find_entrypoint so only the named entrypoints resolve;
        # anything else raises ImportError, as the real helper would
        def fake_find_entrypoint(group, name, compat=True, required=False):
            try:
                return entrypoints[name]
            except KeyError:
                raise ImportError(name)

        mock_find_entrypoint.side_effect = fake_find_entrypoint

        return entrypoints

    @mock.patch.object(redis, 'StrictRedis', return_value='db_handle')
    @mock.patch.object(redis, 'ConnectionPool', return_value='conn_pool')
    @mock.patch.object(utils, 'find_entrypoint')
    def test_empty_config(self, mock_find_entrypoint, mock_ConnectionPool,
                          mock_StrictRedis):
        # Neither a host nor a socket path: nothing to connect to
        self.assertRaises(redis.ConnectionError, database.initialize, {})

        self.assertFalse(mock_ConnectionPool.called)
        self.assertFalse(mock_StrictRedis.called)

    @mock.patch.object(redis, 'StrictRedis', return_value='db_handle')
    @mock.patch.object(redis, 'ConnectionPool', return_value='conn_pool')
    @mock.patch.object(utils, 'find_entrypoint')
    def test_host_only(self, mock_find_entrypoint, mock_ConnectionPool,
                       mock_StrictRedis):
        handle = database.initialize(dict(host='10.0.0.1'))

        self.assertEqual(handle, 'db_handle')
        self.assertFalse(mock_ConnectionPool.called)
        mock_StrictRedis.assert_called_once_with(host='10.0.0.1')

    @mock.patch.object(redis, 'StrictRedis', return_value='db_handle')
    @mock.patch.object(redis, 'ConnectionPool', return_value='conn_pool')
    @mock.patch.object(utils, 'find_entrypoint')
    def test_unixpath_only(self, mock_find_entrypoint, mock_ConnectionPool,
                           mock_StrictRedis):
        handle = database.initialize(dict(unix_socket_path='/tmp/socket'))

        self.assertEqual(handle, 'db_handle')
        self.assertFalse(mock_ConnectionPool.called)
        mock_StrictRedis.assert_called_once_with(
            unix_socket_path='/tmp/socket')

    @mock.patch.object(redis, 'StrictRedis', return_value='db_handle')
    @mock.patch.object(redis, 'ConnectionPool', return_value='conn_pool')
    @mock.patch.object(utils, 'find_entrypoint')
    def test_alt_client(self, mock_find_entrypoint, mock_ConnectionPool,
                        mock_StrictRedis):
        # An alternate client entrypoint is used in place of StrictRedis
        entrypoints = self.make_entrypoints(
            mock_find_entrypoint,
            client=mock.Mock(return_value='alt_handle'),
        )

        handle = database.initialize(dict(host='10.0.0.1',
                                          redis_client='client'))

        self.assertEqual(handle, 'alt_handle')
        self.assertFalse(mock_ConnectionPool.called)
        self.assertFalse(mock_StrictRedis.called)
        entrypoints['client'].assert_called_once_with(host='10.0.0.1')

    @mock.patch.object(redis, 'StrictRedis', return_value='db_handle')
    @mock.patch.object(redis, 'ConnectionPool', return_value='conn_pool')
    @mock.patch.object(utils, 'find_entrypoint')
    def test_all_options(self, mock_find_entrypoint, mock_ConnectionPool,
                         mock_StrictRedis):
        # Known numeric options are converted; unknown keys pass through
        handle = database.initialize({
            'host': '10.0.0.1',
            'port': '1234',
            'db': '5',
            'password': 'spampass',
            'socket_timeout': '600',
            'unix_socket_path': '/tmp/redis',
            'extra_config': 'extra_value',
        })

        self.assertEqual(handle, 'db_handle')
        self.assertFalse(mock_ConnectionPool.called)
        mock_StrictRedis.assert_called_once_with(
            host='10.0.0.1',
            port=1234,
            db=5,
            password='spampass',
            socket_timeout=600,
            unix_socket_path='/tmp/redis',
            extra_config='extra_value',
        )

    @mock.patch.object(redis, 'StrictRedis', return_value='db_handle')
    @mock.patch.object(redis, 'ConnectionPool', return_value='conn_pool')
    @mock.patch.object(utils, 'find_entrypoint')
    def test_cpool_options(self, mock_find_entrypoint, mock_ConnectionPool,
                           mock_StrictRedis):
        # Register entrypoints for their side effect on find_entrypoint
        self.make_entrypoints(
            mock_find_entrypoint,
            connection='connection_fake',
            parser='parser_fake',
        )

        handle = database.initialize({
            'host': '10.0.0.1',
            'port': '1234',
            'db': '5',
            'password': 'spampass',
            'socket_timeout': '600',
            'unix_socket_path': '/tmp/redis',
            'connection_pool.connection_class': 'connection',
            'connection_pool.max_connections': '50',
            'connection_pool.parser_class': 'parser',
            'connection_pool.other': 'value',
        })

        self.assertEqual(handle, 'db_handle')
        mock_StrictRedis.assert_called_once_with(connection_pool='conn_pool')
        mock_ConnectionPool.assert_called_once_with(
            host='10.0.0.1',
            port=1234,
            db=5,
            password='spampass',
            socket_timeout=600,
            unix_socket_path='/tmp/redis',
            connection_class='connection_fake',
            max_connections=50,
            parser_class='parser_fake',
            other='value',
        )

    @mock.patch.object(redis, 'StrictRedis', return_value='db_handle')
    @mock.patch.object(redis, 'ConnectionPool', return_value='conn_pool')
    @mock.patch.object(utils, 'find_entrypoint')
    def test_cpool_options_altpool(self, mock_find_entrypoint,
                                   mock_ConnectionPool, mock_StrictRedis):
        entrypoints = self.make_entrypoints(
            mock_find_entrypoint,
            connection='connection_fake',
            parser='parser_fake',
            pool=mock.Mock(return_value='pool_fake'),
        )

        handle = database.initialize({
            'host': '10.0.0.1',
            'port': '1234',
            'db': '5',
            'password': 'spampass',
            'socket_timeout': '600',
            'unix_socket_path': '/tmp/redis',
            'connection_pool': 'pool',
            'connection_pool.connection_class': 'connection',
            'connection_pool.max_connections': '50',
            'connection_pool.parser_class': 'parser',
            'connection_pool.other': 'value',
        })

        self.assertEqual(handle, 'db_handle')
        mock_StrictRedis.assert_called_once_with(connection_pool='pool_fake')
        self.assertFalse(mock_ConnectionPool.called)
        entrypoints['pool'].assert_called_once_with(
            host='10.0.0.1',
            port=1234,
            db=5,
            password='spampass',
            socket_timeout=600,
            unix_socket_path='/tmp/redis',
            connection_class='connection_fake',
            max_connections=50,
            parser_class='parser_fake',
            other='value',
        )

    @mock.patch.object(redis, 'StrictRedis', return_value='db_handle')
    @mock.patch.object(redis, 'ConnectionPool', return_value='conn_pool')
    @mock.patch.object(utils, 'find_entrypoint')
    def test_cpool_unixsock(self, mock_find_entrypoint, mock_ConnectionPool,
                            mock_StrictRedis):
        # A unix socket path wins over host/port for the pool's connection
        handle = database.initialize({
            'host': '10.0.0.1',
            'port': '1234',
            'unix_socket_path': '/tmp/redis',
            'connection_pool.other': 'value',
        })

        self.assertEqual(handle, 'db_handle')
        mock_StrictRedis.assert_called_once_with(connection_pool='conn_pool')
        mock_ConnectionPool.assert_called_once_with(
            path='/tmp/redis',
            other='value',
            connection_class=redis.UnixDomainSocketConnection,
        )

    @mock.patch.object(redis, 'StrictRedis', return_value='db_handle')
    @mock.patch.object(redis, 'ConnectionPool', return_value='conn_pool')
    @mock.patch.object(utils, 'find_entrypoint')
    def test_cpool_host(self, mock_find_entrypoint, mock_ConnectionPool,
                        mock_StrictRedis):
        handle = database.initialize({
            'host': '10.0.0.1',
            'port': '1234',
            'connection_pool.other': 'value',
        })

        self.assertEqual(handle, 'db_handle')
        mock_StrictRedis.assert_called_once_with(connection_pool='conn_pool')
        mock_ConnectionPool.assert_called_once_with(
            host='10.0.0.1',
            port=1234,
            other='value',
            connection_class=redis.Connection,
        )
235 | 236 | 237 | class TestLimitsHydrate(unittest2.TestCase): 238 | @mock.patch.object(limits.Limit, 'hydrate', 239 | side_effect=lambda x, y: "limit:%s" % y) 240 | def test_hydrate(self, mock_hydrate): 241 | result = database.limits_hydrate('db', ['lim1', 'lim2', 'lim3']) 242 | 243 | self.assertEqual(result, ['limit:lim1', 'limit:lim2', 'limit:lim3']) 244 | mock_hydrate.assert_has_calls([ 245 | mock.call('db', 'lim1'), 246 | mock.call('db', 'lim2'), 247 | mock.call('db', 'lim3'), 248 | ]) 249 | 250 | 251 | class TestLimitUpdate(unittest2.TestCase): 252 | @mock.patch('msgpack.dumps', side_effect=lambda x: x) 253 | def test_limit_update(self, mock_dumps): 254 | limits = [ 255 | mock.Mock(**{'dehydrate.return_value': 'limit1'}), 256 | mock.Mock(**{'dehydrate.return_value': 'limit2'}), 257 | mock.Mock(**{'dehydrate.return_value': 'limit3'}), 258 | mock.Mock(**{'dehydrate.return_value': 'limit4'}), 259 | mock.Mock(**{'dehydrate.return_value': 'limit5'}), 260 | mock.Mock(**{'dehydrate.return_value': 'limit6'}), 261 | ] 262 | pipe = mock.MagicMock(**{ 263 | 'zrange.return_value': [ 264 | 'limit2', 265 | 'limit4', 266 | 'limit6', 267 | 'limit8', 268 | ], 269 | }) 270 | pipe.__enter__.return_value = pipe 271 | pipe.__exit__.return_value = False 272 | db = mock.Mock(**{'pipeline.return_value': pipe}) 273 | 274 | database.limit_update(db, 'limit_key', limits) 275 | for lim in limits: 276 | lim.dehydrate.assert_called_once_with() 277 | mock_dumps.assert_has_calls([ 278 | mock.call('limit1'), 279 | mock.call('limit2'), 280 | mock.call('limit3'), 281 | mock.call('limit4'), 282 | mock.call('limit5'), 283 | mock.call('limit6'), 284 | ]) 285 | db.pipeline.assert_called_once_with() 286 | pipe.assert_has_calls([ 287 | mock.call.__enter__(), 288 | mock.call.watch('limit_key'), 289 | mock.call.zrange('limit_key', 0, -1), 290 | mock.call.multi(), 291 | mock.call.zrem('limit_key', 'limit8'), 292 | mock.call.zadd('limit_key', 10, 'limit1'), 293 | mock.call.zadd('limit_key', 20, 
'limit2'), 294 | mock.call.zadd('limit_key', 30, 'limit3'), 295 | mock.call.zadd('limit_key', 40, 'limit4'), 296 | mock.call.zadd('limit_key', 50, 'limit5'), 297 | mock.call.zadd('limit_key', 60, 'limit6'), 298 | mock.call.execute(), 299 | mock.call.__exit__(None, None, None), 300 | ]) 301 | 302 | @mock.patch('msgpack.dumps', side_effect=lambda x: x) 303 | def test_limit_update_retry(self, mock_dumps): 304 | limits = [ 305 | mock.Mock(**{'dehydrate.return_value': 'limit1'}), 306 | mock.Mock(**{'dehydrate.return_value': 'limit2'}), 307 | mock.Mock(**{'dehydrate.return_value': 'limit3'}), 308 | mock.Mock(**{'dehydrate.return_value': 'limit4'}), 309 | mock.Mock(**{'dehydrate.return_value': 'limit5'}), 310 | mock.Mock(**{'dehydrate.return_value': 'limit6'}), 311 | ] 312 | pipe = mock.MagicMock(**{ 313 | 'zrange.return_value': [ 314 | 'limit2', 315 | 'limit4', 316 | 'limit6', 317 | 'limit8', 318 | ], 319 | 'execute.side_effect': [redis.WatchError, None], 320 | }) 321 | pipe.__enter__.return_value = pipe 322 | pipe.__exit__.return_value = False 323 | db = mock.Mock(**{'pipeline.return_value': pipe}) 324 | 325 | database.limit_update(db, 'limit_key', limits) 326 | for lim in limits: 327 | lim.dehydrate.assert_called_once_with() 328 | mock_dumps.assert_has_calls([ 329 | mock.call('limit1'), 330 | mock.call('limit2'), 331 | mock.call('limit3'), 332 | mock.call('limit4'), 333 | mock.call('limit5'), 334 | mock.call('limit6'), 335 | ]) 336 | db.pipeline.assert_called_once_with() 337 | pipe.assert_has_calls([ 338 | mock.call.__enter__(), 339 | mock.call.watch('limit_key'), 340 | mock.call.zrange('limit_key', 0, -1), 341 | mock.call.multi(), 342 | mock.call.zrem('limit_key', 'limit8'), 343 | mock.call.zadd('limit_key', 10, 'limit1'), 344 | mock.call.zadd('limit_key', 20, 'limit2'), 345 | mock.call.zadd('limit_key', 30, 'limit3'), 346 | mock.call.zadd('limit_key', 40, 'limit4'), 347 | mock.call.zadd('limit_key', 50, 'limit5'), 348 | mock.call.zadd('limit_key', 60, 'limit6'), 349 
| mock.call.execute(), 350 | mock.call.watch('limit_key'), 351 | mock.call.zrange('limit_key', 0, -1), 352 | mock.call.multi(), 353 | mock.call.zrem('limit_key', 'limit8'), 354 | mock.call.zadd('limit_key', 10, 'limit1'), 355 | mock.call.zadd('limit_key', 20, 'limit2'), 356 | mock.call.zadd('limit_key', 30, 'limit3'), 357 | mock.call.zadd('limit_key', 40, 'limit4'), 358 | mock.call.zadd('limit_key', 50, 'limit5'), 359 | mock.call.zadd('limit_key', 60, 'limit6'), 360 | mock.call.execute(), 361 | mock.call.__exit__(None, None, None), 362 | ]) 363 | 364 | 365 | class TestCommand(unittest2.TestCase): 366 | def test_command(self): 367 | db = mock.Mock() 368 | 369 | database.command(db, 'channel', 'command', 'one', 2, 3.14) 370 | 371 | db.publish.assert_called_once_with('channel', 'command:one:2:3.14') 372 | -------------------------------------------------------------------------------- /tests/unit/test_middleware.py: -------------------------------------------------------------------------------- 1 | # Copyright 2013 Rackspace 2 | # All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); you may 5 | # not use this file except in compliance with the License. You may obtain 6 | # a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 12 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 13 | # License for the specific language governing permissions and limitations 14 | # under the License. 

import eventlet.semaphore
import mock
import unittest2

from turnstile import config
from turnstile import control
from turnstile import database
from turnstile import middleware
from turnstile import remote
from turnstile import utils

from tests.unit import utils as test_utils


class TestHeadersDict(unittest2.TestCase):
    """Tests for middleware.HeadersDict, a case-insensitive header mapping.

    Every test checks that header names are canonicalized to lower case
    regardless of how they were supplied (pair sequence, dict, kwargs,
    item access), while header values are kept verbatim.
    """

    def test_init_sequence(self):
        """Keys from a sequence of (name, value) pairs are lower-cased."""
        hd = middleware.HeadersDict([('Foo', 'value'), ('bAR', 'VALUE')])

        self.assertEqual(hd.headers, dict(foo='value', bar='VALUE'))

    def test_init_dict(self):
        """Keys from a plain dict are lower-cased."""
        hd = middleware.HeadersDict(dict(Foo='value', bAR='VALUE'))

        self.assertEqual(hd.headers, dict(foo='value', bar='VALUE'))

    def test_init_kwargs(self):
        """Keys from keyword arguments are lower-cased."""
        hd = middleware.HeadersDict(Foo='value', bAR='VALUE')

        self.assertEqual(hd.headers, dict(foo='value', bar='VALUE'))

    def test_get_item(self):
        """Lookup is case-insensitive; missing keys raise KeyError."""
        hd = middleware.HeadersDict(Foo='value')

        self.assertEqual(hd['foo'], 'value')
        self.assertEqual(hd['Foo'], 'value')
        with self.assertRaises(KeyError):
            foo = hd['bar']

    def test_set_item(self):
        """Assignment is case-insensitive and can add new headers."""
        hd = middleware.HeadersDict(Foo='value')

        hd['fOO'] = 'bar'
        self.assertEqual(hd.headers, dict(foo='bar'))
        hd['bAr'] = 'blah'
        self.assertEqual(hd.headers, dict(foo='bar', bar='blah'))

    def test_del_item(self):
        """Deletion is case-insensitive; missing keys raise KeyError."""
        hd = middleware.HeadersDict(Foo='value', bAR='VALUE')

        del hd['fOO']
        self.assertEqual(hd.headers, dict(bar='VALUE'))
        del hd['bar']
        self.assertEqual(hd.headers, {})
        with self.assertRaises(KeyError):
            del hd['baz']

    def test_contains(self):
        """Membership tests are case-insensitive."""
        hd = middleware.HeadersDict(Foo='value')

        self.assertTrue('foo' in hd)
        self.assertTrue('fOO' in hd)
        self.assertFalse('bAR' in hd)

    def test_iter(self):
        """Iteration yields the lower-cased header names."""
        hd = middleware.HeadersDict(Foo='value', bAR='VALUE')

        result = sorted(list(iter(hd)))
        self.assertEqual(result, ['bar', 'foo'])

    def test_len(self):
        """len() reflects the number of stored headers."""
        hd = middleware.HeadersDict(Foo='value')

        self.assertEqual(len(hd), 1)
        hd['bAR'] = 'VALUE'
        self.assertEqual(len(hd), 2)

    def test_iterkeys(self):
        """iterkeys() yields lower-cased names (Python 2 dict protocol)."""
        hd = middleware.HeadersDict(Foo='value', bAR='VALUE')

        result = sorted(list(hd.iterkeys()))
        self.assertEqual(result, ['bar', 'foo'])

    def test_iteritems(self):
        """iteritems() yields (lower-cased name, original value) pairs."""
        hd = middleware.HeadersDict(Foo='value', bAR='VALUE')

        result = sorted(list(hd.iteritems()))
        self.assertEqual(result, [('bar', 'VALUE'), ('foo', 'value')])

    def test_itervalues(self):
        """itervalues() yields the unmodified header values."""
        hd = middleware.HeadersDict(Foo='value', bAR='VALUE')

        result = sorted(list(hd.itervalues()))
        self.assertEqual(result, ['VALUE', 'value'])

    def test_keys(self):
        """keys() returns the lower-cased names."""
        hd = middleware.HeadersDict(Foo='value', bAR='VALUE')

        result = sorted(list(hd.keys()))
        self.assertEqual(result, ['bar', 'foo'])

    def test_items(self):
        """items() returns (lower-cased name, original value) pairs."""
        hd = middleware.HeadersDict(Foo='value', bAR='VALUE')

        result = sorted(list(hd.items()))
        self.assertEqual(result, [('bar', 'VALUE'), ('foo', 'value')])

    def test_values(self):
        """values() returns the unmodified header values."""
        hd = middleware.HeadersDict(Foo='value', bAR='VALUE')

        result = sorted(list(hd.values()))
        self.assertEqual(result, ['VALUE', 'value'])


class TestTurnstileFilter(unittest2.TestCase):
    """Tests for the middleware.turnstile_filter paste filter factory."""

    @mock.patch.object(middleware, 'TurnstileMiddleware',
                       return_value='middleware')
    @mock.patch.object(utils, 'find_entrypoint')
    def test_filter_basic(self, mock_find_entrypoint,
                          mock_TurnstileMiddleware):
        """Default class is used; nothing is built until the app is wrapped."""
        midware_class = middleware.turnstile_filter({})

        # Construction is deferred until the returned factory is called.
        self.assertFalse(mock_find_entrypoint.called)
        self.assertFalse(mock_TurnstileMiddleware.called)

        midware = midware_class('app')

        mock_TurnstileMiddleware.assert_called_once_with('app', {})
        self.assertEqual(midware, 'middleware')

    @mock.patch.object(middleware, 'TurnstileMiddleware')
    @mock.patch.object(utils, 'find_entrypoint',
                       return_value=mock.Mock(return_value='middleware'))
    def test_filter_alt_middleware(self, mock_find_entrypoint,
                                   mock_TurnstileMiddleware):
        """The 'turnstile' option selects an alternate middleware class."""
        midware_class = middleware.turnstile_filter({}, turnstile='spam')

        # The entrypoint is resolved eagerly, but not instantiated yet.
        mock_find_entrypoint.assert_called_once_with(
            'turnstile.middleware', 'spam', required=True)
        self.assertFalse(mock_find_entrypoint.return_value.called)
        self.assertFalse(mock_TurnstileMiddleware.called)

        midware = midware_class('app')

        mock_find_entrypoint.return_value.assert_called_once_with(
            'app', dict(turnstile='spam'))
        self.assertFalse(mock_TurnstileMiddleware.called)
        self.assertEqual(midware, 'middleware')


class TestTurnstileMiddleware(unittest2.TestCase):
    """Tests for middleware.TurnstileMiddleware."""

    @mock.patch.object(utils, 'find_entrypoint')
    @mock.patch.object(control, 'ControlDaemon')
    @mock.patch.object(remote, 'RemoteControlDaemon')
    @mock.patch.object(middleware.LOG, 'info')
    def test_init_basic(self, mock_info, mock_RemoteControlDaemon,
                        mock_ControlDaemon, mock_find_entrypoint):
        """Empty config yields defaults and a local ControlDaemon."""
        midware = middleware.TurnstileMiddleware('app', {})

        self.assertEqual(midware.app, 'app')
        self.assertEqual(midware.limits, [])
        self.assertEqual(midware.limit_sum, None)
        self.assertEqual(midware.mapper, None)
        self.assertIsInstance(midware.mapper_lock,
                              eventlet.semaphore.Semaphore)
        self.assertEqual(midware.conf._config, {
            None: dict(status='413 Request Entity Too Large'),
        })
        self.assertEqual(midware._db, None)
        self.assertEqual(midware.preprocessors, [])
        self.assertEqual(midware.postprocessors, [])
        self.assertEqual(midware.formatter, midware.format_delay)
        self.assertFalse(mock_RemoteControlDaemon.called)
        mock_ControlDaemon.assert_has_calls([
            mock.call(midware, midware.conf),
            mock.call().start(),
        ])
        mock_info.assert_called_once_with("Turnstile middleware initialized")

    @mock.patch.object(utils, 'find_entrypoint', return_value=mock.Mock())
    @mock.patch.object(control, 'ControlDaemon')
    @mock.patch.object(remote, 'RemoteControlDaemon')
    @mock.patch.object(middleware.LOG, 'info')
    def test_init_formatter(self, mock_info, mock_RemoteControlDaemon,
                            mock_ControlDaemon, mock_find_entrypoint):
        """A configured formatter is resolved and wrapped with the status."""
        fake_formatter = mock_find_entrypoint.return_value
        midware = middleware.TurnstileMiddleware('app',
                                                 dict(formatter='formatter'))

        self.assertEqual(midware.app, 'app')
        self.assertEqual(midware.limits, [])
        self.assertEqual(midware.limit_sum, None)
        self.assertEqual(midware.mapper, None)
        self.assertIsInstance(midware.mapper_lock,
                              eventlet.semaphore.Semaphore)
        self.assertEqual(midware.conf._config, {
            None: {
                'status': '413 Request Entity Too Large',
                'formatter': 'formatter',
            },
        })
        self.assertEqual(midware._db, None)
        self.assertEqual(midware.preprocessors, [])
        self.assertEqual(midware.postprocessors, [])
        mock_find_entrypoint.assert_called_once_with(
            'turnstile.formatter', 'formatter', required=True)
        self.assertFalse(mock_RemoteControlDaemon.called)
        mock_ControlDaemon.assert_has_calls([
            mock.call(midware, midware.conf),
            mock.call().start(),
        ])
        mock_info.assert_called_once_with("Turnstile middleware initialized")
        # The wrapper must prepend the configured status to the call.
        midware.formatter('delay', 'limit', 'bucket', 'environ', 'start')
        fake_formatter.assert_called_once_with(
            '413 Request Entity Too Large', 'delay', 'limit', 'bucket',
            'environ', 'start')

    @mock.patch.object(utils, 'find_entrypoint')
    @mock.patch.object(control, 'ControlDaemon')
    @mock.patch.object(remote, 'RemoteControlDaemon')
    @mock.patch.object(middleware.LOG, 'info')
    def test_init_remote(self, mock_info, mock_RemoteControlDaemon,
                         mock_ControlDaemon, mock_find_entrypoint):
        """control.remote selects RemoteControlDaemon over ControlDaemon."""
        midware = middleware.TurnstileMiddleware('app', {
            'control.remote': 'yes',
        })

        self.assertEqual(midware.app, 'app')
        self.assertEqual(midware.limits, [])
        self.assertEqual(midware.limit_sum, None)
        self.assertEqual(midware.mapper, None)
        self.assertIsInstance(midware.mapper_lock,
                              eventlet.semaphore.Semaphore)
        self.assertEqual(midware.conf._config, {
            None: dict(status='413 Request Entity Too Large'),
            'control': dict(remote='yes'),
        })
        self.assertEqual(midware._db, None)
        self.assertEqual(midware.preprocessors, [])
        self.assertEqual(midware.postprocessors, [])
        self.assertEqual(midware.formatter, midware.format_delay)
        self.assertFalse(mock_ControlDaemon.called)
        mock_RemoteControlDaemon.assert_has_calls([
            mock.call(midware, midware.conf),
            mock.call().start(),
        ])
        mock_info.assert_called_once_with("Turnstile middleware initialized")

    @mock.patch.object(utils, 'find_entrypoint')
    @mock.patch.object(control, 'ControlDaemon')
    @mock.patch.object(remote, 'RemoteControlDaemon')
    @mock.patch.object(middleware.LOG, 'info')
    def test_init_enable(self, mock_info, mock_RemoteControlDaemon,
                         mock_ControlDaemon, mock_find_entrypoint):
        """'enable' loads pre/postprocessors; unknown names are skipped.

        Preprocessors run in listed order, postprocessors in reverse.
        """
        entrypoints = {
            'turnstile.preprocessor': {
                'ep1': 'preproc1',
                'ep3': 'preproc3',
                'ep4': 'preproc4',
                'ep6': 'preproc6',
            },
            'turnstile.postprocessor': {
                'ep2': 'postproc2',
                'ep4': 'postproc4',
                'ep6': 'postproc6',
            },
        }

        mock_find_entrypoint.side_effect = \
            lambda x, y, compat=True: entrypoints[x].get(y)

        midware = middleware.TurnstileMiddleware('app', {
            'enable': 'ep1 ep2 ep3 ep4 ep5 ep6',
        })

        self.assertEqual(midware.app, 'app')
        self.assertEqual(midware.limits, [])
        self.assertEqual(midware.limit_sum, None)
        self.assertEqual(midware.mapper, None)
        self.assertIsInstance(midware.mapper_lock,
                              eventlet.semaphore.Semaphore)
        self.assertEqual(midware.conf._config, {
            None: dict(status='413 Request Entity Too Large',
                       enable='ep1 ep2 ep3 ep4 ep5 ep6'),
        })
        self.assertEqual(midware._db, None)
        # 'ep5' resolves to no entrypoint and is silently dropped.
        self.assertEqual(midware.preprocessors, [
            'preproc1',
            'preproc3',
            'preproc4',
            'preproc6',
        ])
        self.assertEqual(midware.postprocessors, [
            'postproc6',
            'postproc4',
            'postproc2',
        ])
        self.assertEqual(midware.formatter, midware.format_delay)
        self.assertFalse(mock_RemoteControlDaemon.called)
        mock_ControlDaemon.assert_has_calls([
            mock.call(midware, midware.conf),
            mock.call().start(),
        ])
        mock_info.assert_called_once_with("Turnstile middleware initialized")

    @mock.patch.object(utils, 'find_entrypoint')
    @mock.patch.object(control, 'ControlDaemon')
    @mock.patch.object(remote, 'RemoteControlDaemon')
    @mock.patch.object(middleware.LOG, 'info')
    def test_init_processors(self, mock_info, mock_RemoteControlDaemon,
                             mock_ControlDaemon, mock_find_entrypoint):
        """'preprocess'/'postprocess' load explicit processor lists.

        Unlike 'enable', names may carry a 'group:name' prefix and the
        postprocessor order is exactly as listed.
        """
        entrypoints = {
            'turnstile.preprocessor': {
                'ep1': 'preproc1',
                'ep3': 'preproc3',
                'ep4': 'preproc4',
                'ep6': 'preproc6',
                'preproc:ep5': 'preproc5',
            },
            'turnstile.postprocessor': {
                'ep2': 'postproc2',
                'ep4': 'postproc4',
                'ep6': 'postproc6',
                'postproc:ep5': 'postproc5',
            },
        }

        mock_find_entrypoint.side_effect = \
            lambda x, y, required=False: entrypoints[x].get(y)

        midware = middleware.TurnstileMiddleware('app', {
            'preprocess': 'ep1 ep3 ep4 preproc:ep5 ep6',
            'postprocess': 'ep6 postproc:ep5 ep4 ep2',
        })

        self.assertEqual(midware.app, 'app')
        self.assertEqual(midware.limits, [])
        self.assertEqual(midware.limit_sum, None)
        self.assertEqual(midware.mapper, None)
        self.assertIsInstance(midware.mapper_lock,
                              eventlet.semaphore.Semaphore)
        self.assertEqual(midware.conf._config, {
            None: dict(status='413 Request Entity Too Large',
                       preprocess='ep1 ep3 ep4 preproc:ep5 ep6',
                       postprocess='ep6 postproc:ep5 ep4 ep2'),
        })
        self.assertEqual(midware._db, None)
        self.assertEqual(midware.preprocessors, [
            'preproc1',
            'preproc3',
            'preproc4',
            'preproc5',
            'preproc6',
        ])
        self.assertEqual(midware.postprocessors, [
            'postproc6',
            'postproc5',
            'postproc4',
            'postproc2',
        ])
        self.assertEqual(midware.formatter, midware.format_delay)
        self.assertFalse(mock_RemoteControlDaemon.called)
        mock_ControlDaemon.assert_has_calls([
            mock.call(midware, midware.conf),
            mock.call().start(),
        ])
        mock_info.assert_called_once_with("Turnstile middleware initialized")

    @mock.patch('traceback.format_exc', return_value='')
    @mock.patch.object(control, 'ControlDaemon')
    @mock.patch.object(middleware.LOG, 'info')
    @mock.patch.object(middleware.LOG, 'exception')
    @mock.patch.object(database, 'limits_hydrate', return_value=[
        mock.Mock(),
        mock.Mock(),
    ])
    @mock.patch('routes.Mapper', return_value='mapper')
    def test_recheck_limits_basic(self, mock_Mapper, mock_limits_hydrate,
                                  mock_exception, mock_info,
                                  mock_ControlDaemon, mock_format_exc):
        """New limits are hydrated, routed, and installed with a new sum."""
        limit_data = mock.Mock(**{
            'get_limits.return_value': ('new_sum', ['limit1', 'limit2']),
        })
        mock_ControlDaemon.return_value = mock.Mock(**{
            'get_limits.return_value': limit_data,
        })
        midware = middleware.TurnstileMiddleware('app', {})
        midware.limits = ['old_limit1', 'old_limit2']
        midware.limit_sum = 'old_sum'
        midware.mapper = 'old_mapper'
        midware._db = mock.Mock()

        midware.recheck_limits()

        mock_ControlDaemon.return_value.get_limits.assert_called_once_with()
        limit_data.get_limits.assert_called_once_with('old_sum')
        mock_limits_hydrate.assert_called_once_with(midware._db,
                                                    ['limit1', 'limit2'])
        mock_Mapper.assert_called_once_with(register=False)
        for lim in mock_limits_hydrate.return_value:
            lim._route.assert_called_once_with('mapper')
        self.assertEqual(midware.limits, mock_limits_hydrate.return_value)
        self.assertEqual(midware.limit_sum, 'new_sum')
        self.assertEqual(midware.mapper, 'mapper')
        self.assertFalse(mock_exception.called)
        self.assertFalse(mock_format_exc.called)
        # No error reporting traffic to the database on success.
        self.assertEqual(len(midware._db.method_calls), 0)

    @mock.patch('traceback.format_exc', return_value='')
    @mock.patch.object(control, 'ControlDaemon')
    @mock.patch.object(middleware.LOG, 'info')
    @mock.patch.object(middleware.LOG, 'exception')
    @mock.patch.object(database, 'limits_hydrate', return_value=[
        mock.Mock(),
        mock.Mock(),
    ])
    @mock.patch('routes.Mapper', return_value='mapper')
    def test_recheck_limits_unchanged(self, mock_Mapper, mock_limits_hydrate,
                                      mock_exception, mock_info,
                                      mock_ControlDaemon, mock_format_exc):
        """NoChangeException leaves limits, sum, and mapper untouched."""
        limit_data = mock.Mock(**{
            'get_limits.side_effect': control.NoChangeException,
        })
        mock_ControlDaemon.return_value = mock.Mock(**{
            'get_limits.return_value': limit_data,
        })
        midware = middleware.TurnstileMiddleware('app', {})
        midware.limits = ['old_limit1', 'old_limit2']
        midware.limit_sum = 'old_sum'
        midware.mapper = 'old_mapper'
        midware._db = mock.Mock()

        midware.recheck_limits()

        mock_ControlDaemon.return_value.get_limits.assert_called_once_with()
        limit_data.get_limits.assert_called_once_with('old_sum')
        self.assertFalse(mock_limits_hydrate.called)
        self.assertFalse(mock_Mapper.called)
        for lim in mock_limits_hydrate.return_value:
            self.assertFalse(lim._route.called)
        self.assertEqual(midware.limits, ['old_limit1', 'old_limit2'])
        self.assertEqual(midware.limit_sum, 'old_sum')
        self.assertEqual(midware.mapper, 'old_mapper')
        # NoChangeException is the normal "nothing to do" signal, not
        # an error: no logging and no database traffic.
        self.assertFalse(mock_exception.called)
        self.assertFalse(mock_format_exc.called)
        self.assertEqual(len(midware._db.method_calls), 0)

    @mock.patch('traceback.format_exc', return_value='')
    @mock.patch.object(control, 'ControlDaemon')
    @mock.patch.object(middleware.LOG, 'info')
    @mock.patch.object(middleware.LOG, 'exception')
    @mock.patch.object(database, 'limits_hydrate', return_value=[
        mock.Mock(),
        mock.Mock(),
    ])
    @mock.patch('routes.Mapper', return_value='mapper')
    def test_recheck_limits_exception(self, mock_Mapper, mock_limits_hydrate,
                                      mock_exception, mock_info,
                                      mock_ControlDaemon, mock_format_exc):
        """Unexpected errors are logged and published to default keys."""
        limit_data = mock.Mock(**{
            'get_limits.side_effect': test_utils.TestException,
        })
        mock_ControlDaemon.return_value = mock.Mock(**{
            'get_limits.return_value': limit_data,
        })
        midware = middleware.TurnstileMiddleware('app', {})
        midware.limits = ['old_limit1', 'old_limit2']
        midware.limit_sum = 'old_sum'
        midware.mapper = 'old_mapper'
        midware._db = mock.Mock()

        midware.recheck_limits()

        mock_ControlDaemon.return_value.get_limits.assert_called_once_with()
        limit_data.get_limits.assert_called_once_with('old_sum')
        self.assertFalse(mock_limits_hydrate.called)
        self.assertFalse(mock_Mapper.called)
        for lim in mock_limits_hydrate.return_value:
            self.assertFalse(lim._route.called)
        self.assertEqual(midware.limits, ['old_limit1', 'old_limit2'])
        self.assertEqual(midware.limit_sum, 'old_sum')
        self.assertEqual(midware.mapper, 'old_mapper')
        mock_exception.assert_called_once_with("Could not load limits")
        mock_format_exc.assert_called_once_with()
        # format_exc is mocked to return the empty string, so the
        # reported messages end with "limits: " and nothing after.
        midware._db.assert_has_calls([
            mock.call.sadd('errors', 'Failed to load limits: '),
            mock.call.publish('errors', 'Failed to load limits: '),
        ])

    @mock.patch('traceback.format_exc', return_value='')
    @mock.patch.object(control, 'ControlDaemon')
    @mock.patch.object(middleware.LOG, 'info')
    @mock.patch.object(middleware.LOG, 'exception')
    @mock.patch.object(database, 'limits_hydrate', return_value=[
        mock.Mock(),
        mock.Mock(),
    ])
    @mock.patch('routes.Mapper', return_value='mapper')
    def test_recheck_limits_exception_altkeys(self, mock_Mapper,
                                              mock_limits_hydrate,
                                              mock_exception, mock_info,
                                              mock_ControlDaemon,
                                              mock_format_exc):
        """Error reports honor configured errors_key/errors_channel."""
        limit_data = mock.Mock(**{
            'get_limits.side_effect': test_utils.TestException,
        })
        mock_ControlDaemon.return_value = mock.Mock(**{
            'get_limits.return_value': limit_data,
        })
        midware = middleware.TurnstileMiddleware('app', {
            'control.errors_key': 'eset',
            'control.errors_channel': 'epub',
        })
        midware.limits = ['old_limit1', 'old_limit2']
        midware.limit_sum = 'old_sum'
        midware.mapper = 'old_mapper'
        midware._db = mock.Mock()

        midware.recheck_limits()

        mock_ControlDaemon.return_value.get_limits.assert_called_once_with()
        limit_data.get_limits.assert_called_once_with('old_sum')
        self.assertFalse(mock_limits_hydrate.called)
        self.assertFalse(mock_Mapper.called)
        for lim in mock_limits_hydrate.return_value:
            self.assertFalse(lim._route.called)
        self.assertEqual(midware.limits, ['old_limit1', 'old_limit2'])
        self.assertEqual(midware.limit_sum, 'old_sum')
        self.assertEqual(midware.mapper, 'old_mapper')
        mock_exception.assert_called_once_with("Could not load limits")
        mock_format_exc.assert_called_once_with()
        midware._db.assert_has_calls([
            mock.call.sadd('eset', 'Failed to load limits: '),
            mock.call.publish('epub', 'Failed to load limits: '),
        ])

    @mock.patch.object(control, 'ControlDaemon')
    @mock.patch.object(middleware.LOG, 'info')
    @mock.patch.object(middleware.TurnstileMiddleware, 'recheck_limits')
    @mock.patch.object(middleware.TurnstileMiddleware, 'format_delay',
                       return_value='formatted delay')
    def test_call_basic(self, mock_format_delay, mock_recheck_limits,
                        mock_info, mock_ControlDaemon):
        """With no delay set, the request falls through to the app."""
        app = mock.Mock(return_value='app response')
        midware = middleware.TurnstileMiddleware(app, {})
        midware.mapper = mock.Mock()
        environ = {}

        result = midware(environ, 'start_response')

        self.assertEqual(result, 'app response')
        mock_recheck_limits.assert_called_once_with()
        midware.mapper.routematch.assert_called_once_with(environ=environ)
        self.assertFalse(mock_format_delay.called)
        app.assert_called_once_with(environ, 'start_response')
        # The middleware exposes its config to downstream components.
        self.assertEqual(environ, {
            'turnstile.conf': midware.conf,
        })

    @mock.patch.object(control, 'ControlDaemon')
    @mock.patch.object(middleware.LOG, 'info')
    @mock.patch.object(middleware.TurnstileMiddleware, 'recheck_limits')
    @mock.patch.object(middleware.TurnstileMiddleware, 'format_delay',
                       return_value='formatted delay')
    def test_call_processors(self, mock_format_delay, mock_recheck_limits,
                             mock_info, mock_ControlDaemon):
        """Pre- and postprocessors run around an undelayed request."""
        app = mock.Mock(return_value='app response')
        midware = middleware.TurnstileMiddleware(app, {})
        midware.mapper = mock.Mock()
        midware.preprocessors = [mock.Mock(), mock.Mock()]
        midware.postprocessors = [mock.Mock(), mock.Mock()]
        environ = {}

        result = midware(environ, 'start_response')

        self.assertEqual(result, 'app response')
        mock_recheck_limits.assert_called_once_with()
        for proc in midware.preprocessors:
            proc.assert_called_once_with(midware, environ)
        midware.mapper.routematch.assert_called_once_with(environ=environ)
        self.assertFalse(mock_format_delay.called)
        for proc in midware.postprocessors:
            proc.assert_called_once_with(midware, environ)
        app.assert_called_once_with(environ, 'start_response')
        self.assertEqual(environ, {
            'turnstile.conf': midware.conf,
        })

    @mock.patch.object(control, 'ControlDaemon')
    @mock.patch.object(middleware.LOG, 'info')
    @mock.patch.object(middleware.TurnstileMiddleware, 'recheck_limits')
    @mock.patch.object(middleware.TurnstileMiddleware, 'format_delay',
                       return_value='formatted delay')
    def test_call_delay(self, mock_format_delay, mock_recheck_limits,
                        mock_info, mock_ControlDaemon):
        """A turnstile.delay entry short-circuits the request.

        The largest delay in the list wins; postprocessors and the
        wrapped app are never invoked.
        """
        app = mock.Mock(return_value='app response')
        midware = middleware.TurnstileMiddleware(app, {})
        midware.mapper = mock.Mock()
        midware.preprocessors = [mock.Mock(), mock.Mock()]
        midware.postprocessors = [mock.Mock(), mock.Mock()]
        environ = {
            'turnstile.delay': [
                (30, 'limit1', 'bucket1'),
                (20, 'limit2', 'bucket2'),
                (60, 'limit3', 'bucket3'),
                (10, 'limit4', 'bucket4'),
            ],
        }

        result = midware(environ, 'start_response')

        self.assertEqual(result, 'formatted delay')
        mock_recheck_limits.assert_called_once_with()
        for proc in midware.preprocessors:
            proc.assert_called_once_with(midware, environ)
        midware.mapper.routematch.assert_called_once_with(environ=environ)
        # The 60-second entry is the maximum delay and is the one used.
        mock_format_delay.assert_called_once_with(60, 'limit3', 'bucket3',
                                                  environ, 'start_response')
        for proc in midware.postprocessors:
            self.assertFalse(proc.called)
        self.assertFalse(app.called)
        self.assertEqual(environ, {
            'turnstile.delay': [
                (30, 'limit1', 'bucket1'),
                (20, 'limit2', 'bucket2'),
                (60, 'limit3', 'bucket3'),
                (10, 'limit4', 'bucket4'),
            ],
            'turnstile.conf': midware.conf,
        })

    @mock.patch.object(control, 'ControlDaemon')
    @mock.patch.object(middleware.LOG, 'info')
    @mock.patch.object(middleware, 'HeadersDict', return_value=mock.Mock(**{
        'items.return_value': 'header items',
    }))
    def test_format_delay(self, mock_HeadersDict, mock_info,
                          mock_ControlDaemon):
        """format_delay sets Retry-After and delegates to limit.format."""
        midware = middleware.TurnstileMiddleware('app', {})
        limit = mock.Mock(**{
            'format.return_value': ('limit status', 'limit entity'),
        })
        start_response = mock.Mock()

        result = midware.format_delay(10.1, limit, 'bucket', 'environ',
                                      start_response)

        self.assertEqual(result, 'limit entity')
        # Retry-After is the delay rounded up to a whole second.
        mock_HeadersDict.assert_called_once_with([('Retry-After', '11')])
        limit.format.assert_called_once_with(
            '413 Request Entity Too Large', mock_HeadersDict.return_value,
            'environ', 'bucket', 10.1)
        start_response.assert_called_once_with(
            'limit status', 'header items')

    @mock.patch.object(control, 'ControlDaemon')
    @mock.patch.object(middleware.LOG, 'info')
    @mock.patch.object(middleware, 'HeadersDict', return_value=mock.Mock(**{
        'items.return_value': 'header items',
    }))
    def test_format_delay_altstatus(self, mock_HeadersDict, mock_info,
                                    mock_ControlDaemon):
        """A configured status string overrides the 413 default."""
        midware = middleware.TurnstileMiddleware('app', {
            'status': 'some other status',
        })
        limit = mock.Mock(**{
            'format.return_value': ('limit status', 'limit entity'),
        })
        start_response = mock.Mock()

        result = midware.format_delay(10.1, limit, 'bucket', 'environ',
                                      start_response)

        self.assertEqual(result, 'limit entity')
        mock_HeadersDict.assert_called_once_with([('Retry-After', '11')])
        limit.format.assert_called_once_with(
            'some other status', mock_HeadersDict.return_value,
            'environ', 'bucket', 10.1)
        start_response.assert_called_once_with(
            'limit status', 'header items')

    @mock.patch.object(control, 'ControlDaemon')
    @mock.patch.object(middleware.LOG, 'info')
    @mock.patch.object(config.Config, 'get_database', return_value='database')
    def test_db(self, mock_get_database, mock_info, mock_ControlDaemon):
        """The db property lazily obtains a handle from the config."""
        midware = middleware.TurnstileMiddleware('app', {})

        db = midware.db

        self.assertEqual(db, 'database')
        mock_get_database.assert_called_once_with()

    @mock.patch.object(control, 'ControlDaemon')
    @mock.patch.object(middleware.LOG, 'info')
    @mock.patch.object(config.Config, 'get_database', return_value='database')
    def test_db_cached(self, mock_get_database, mock_info, mock_ControlDaemon):
        """An already-cached handle is returned without hitting config."""
        midware = middleware.TurnstileMiddleware('app', {})
        midware._db = 'cached'

        db = midware.db

        self.assertEqual(db, 'cached')
        self.assertFalse(mock_get_database.called)
--------------------------------------------------------------------------------
/tests/unit/test_utils.py:
--------------------------------------------------------------------------------
# Copyright 2013 Rackspace
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the 13 | # License for the specific language governing permissions and limitations 14 | # under the License. 15 | 16 | import mock 17 | import pkg_resources 18 | import unittest2 19 | 20 | from turnstile import utils 21 | 22 | from tests.unit import utils as test_utils 23 | 24 | 25 | class TestFindEntryPoint(unittest2.TestCase): 26 | @mock.patch.object(pkg_resources, 'iter_entry_points', 27 | return_value=[mock.Mock(**{ 28 | 'load.return_value': 'ep1', 29 | })]) 30 | def test_straight_load(self, mock_iter_entry_points): 31 | result = utils.find_entrypoint('test.group', 'endpoint') 32 | 33 | self.assertEqual(result, 'ep1') 34 | mock_iter_entry_points.assert_called_once_with( 35 | 'test.group', 'endpoint') 36 | mock_iter_entry_points.return_value[0].load.assert_called_once_with() 37 | 38 | @mock.patch.object(pkg_resources, 'iter_entry_points', 39 | return_value=[ 40 | mock.Mock(**{ 41 | 'load.side_effect': ImportError, 42 | }), 43 | mock.Mock(**{ 44 | 'load.side_effect': pkg_resources.UnknownExtra, 45 | }), 46 | mock.Mock(**{ 47 | 'load.return_value': 'ep3', 48 | }), 49 | mock.Mock(**{ 50 | 'load.return_value': 'ep4', 51 | }), 52 | ]) 53 | def test_skip_errors(self, mock_iter_entry_points): 54 | result = utils.find_entrypoint('test.group', 'endpoint') 55 | 56 | self.assertEqual(result, 'ep3') 57 | mock_iter_entry_points.assert_called_once_with( 58 | 'test.group', 'endpoint') 59 | mock_iter_entry_points.return_value[0].load.assert_called_once_with() 60 | mock_iter_entry_points.return_value[1].load.assert_called_once_with() 61 | mock_iter_entry_points.return_value[2].load.assert_called_once_with() 62 | self.assertFalse(mock_iter_entry_points.return_value[3].load.called) 63 | 64 | @mock.patch.object(pkg_resources, 'iter_entry_points', return_value=[]) 65 | def test_no_endpoints(self, mock_iter_entry_points): 66 | result = utils.find_entrypoint('test.group', 'endpoint') 67 | 68 | self.assertEqual(result, None) 69 | mock_iter_entry_points.assert_called_once_with( 70 | 
'test.group', 'endpoint') 71 | 72 | @mock.patch.object(pkg_resources, 'iter_entry_points', return_value=[]) 73 | def test_no_endpoints_required(self, mock_iter_entry_points): 74 | self.assertRaises(ImportError, utils.find_entrypoint, 75 | 'test.group', 'endpoint', required=True) 76 | 77 | mock_iter_entry_points.assert_called_once_with( 78 | 'test.group', 'endpoint') 79 | 80 | @mock.patch.object(pkg_resources.EntryPoint, 'parse', 81 | return_value=mock.Mock(**{ 82 | 'load.return_value': 'class', 83 | })) 84 | @mock.patch.object(pkg_resources, 'iter_entry_points', return_value=[]) 85 | def test_no_compat(self, mock_iter_entry_points, mock_parse): 86 | result = utils.find_entrypoint('test.group', 'spam:ni', compat=False) 87 | 88 | self.assertEqual(result, None) 89 | mock_iter_entry_points.assert_called_once_with( 90 | 'test.group', 'spam:ni') 91 | self.assertFalse(mock_parse.called) 92 | self.assertFalse(mock_parse.return_value.load.called) 93 | 94 | @mock.patch.object(pkg_resources.EntryPoint, 'parse', 95 | return_value=mock.Mock(**{ 96 | 'load.return_value': 'class', 97 | })) 98 | @mock.patch.object(pkg_resources, 'iter_entry_points', return_value=[]) 99 | def test_no_group(self, mock_iter_entry_points, mock_parse): 100 | result = utils.find_entrypoint(None, 'spam:ni', compat=False) 101 | 102 | self.assertEqual(result, 'class') 103 | self.assertFalse(mock_iter_entry_points.called) 104 | mock_parse.assert_called_once_with('x=spam:ni') 105 | mock_parse.return_value.load.assert_called_once_with(False) 106 | 107 | @mock.patch.object(pkg_resources.EntryPoint, 'parse', 108 | return_value=mock.Mock(**{ 109 | 'load.return_value': 'class', 110 | })) 111 | @mock.patch.object(pkg_resources, 'iter_entry_points', return_value=[]) 112 | def test_with_compat(self, mock_iter_entry_points, mock_parse): 113 | result = utils.find_entrypoint('test.group', 'spam:ni') 114 | 115 | self.assertEqual(result, 'class') 116 | self.assertFalse(mock_iter_entry_points.called) 117 | 
mock_parse.assert_called_once_with('x=spam:ni') 118 | mock_parse.return_value.load.assert_called_once_with(False) 119 | 120 | @mock.patch.object(pkg_resources.EntryPoint, 'parse', 121 | return_value=mock.Mock(**{ 122 | 'load.side_effect': ImportError, 123 | })) 124 | @mock.patch.object(pkg_resources, 'iter_entry_points', return_value=[]) 125 | def test_with_compat_importerror(self, mock_iter_entry_points, mock_parse): 126 | result = utils.find_entrypoint('test.group', 'spam:ni') 127 | 128 | self.assertEqual(result, None) 129 | self.assertFalse(mock_iter_entry_points.called) 130 | mock_parse.assert_called_once_with('x=spam:ni') 131 | mock_parse.return_value.load.assert_called_once_with(False) 132 | 133 | @mock.patch.object(pkg_resources.EntryPoint, 'parse', 134 | return_value=mock.Mock(**{ 135 | 'load.side_effect': pkg_resources.UnknownExtra, 136 | })) 137 | @mock.patch.object(pkg_resources, 'iter_entry_points', return_value=[]) 138 | def test_with_compat_unknownextra(self, mock_iter_entry_points, 139 | mock_parse): 140 | result = utils.find_entrypoint('test.group', 'spam:ni') 141 | 142 | self.assertEqual(result, None) 143 | self.assertFalse(mock_iter_entry_points.called) 144 | mock_parse.assert_called_once_with('x=spam:ni') 145 | mock_parse.return_value.load.assert_called_once_with(False) 146 | 147 | 148 | class TestIgnoreExcept(unittest2.TestCase): 149 | def test_ignore_except(self): 150 | step = 0 151 | with utils.ignore_except(): 152 | step += 1 153 | raise test_utils.TestException() 154 | step += 2 155 | 156 | self.assertEqual(step, 1) 157 | -------------------------------------------------------------------------------- /tests/unit/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2013 Rackspace 2 | # All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); you may 5 | # not use this file except in compliance with the License. 
You may obtain 6 | # a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 12 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 13 | # License for the specific language governing permissions and limitations 14 | # under the License. 15 | 16 | from lxml import etree 17 | 18 | 19 | class TestException(Exception): 20 | pass 21 | 22 | 23 | class Halt(BaseException): 24 | pass 25 | 26 | 27 | class XMLMatchState(object): 28 | """ 29 | Maintain some state for matching. 30 | 31 | Tracks the XML node path and saves the expected and actual full 32 | XML text, for use by the XMLMismatch subclasses. 33 | """ 34 | 35 | def __init__(self): 36 | self.path = [] 37 | 38 | def __enter__(self): 39 | pass 40 | 41 | def __exit__(self, exc_type, exc_value, exc_tb): 42 | self.path.pop() 43 | return False 44 | 45 | def __str__(self): 46 | return '/' + '/'.join(self.path) 47 | 48 | def node(self, tag, idx): 49 | """ 50 | Adds tag and index to the path; they will be popped off when 51 | the corresponding 'with' statement exits. 52 | 53 | :param tag: The element tag 54 | :param idx: If not None, the integer index of the element 55 | within its parent. Not included in the path 56 | element if None. 
57 | """ 58 | 59 | if idx is not None: 60 | self.path.append("%s[%d]" % (tag, idx)) 61 | else: 62 | self.path.append(tag) 63 | return self 64 | 65 | 66 | def _compare_node(expected, actual, state, idx): 67 | """Recursively compares nodes within the XML tree.""" 68 | 69 | # Start by comparing the tags 70 | if expected.tag != actual.tag: 71 | raise AssertionError("s: XML tag mismatch at index %d: " 72 | "expected tag <%s>; actual tag <%s>" % 73 | (state, idx, expected.tag, actual.tag)) 74 | 75 | with state.node(expected.tag, idx): 76 | # Compare the attribute keys 77 | expected_attrs = set(expected.attrib.keys()) 78 | actual_attrs = set(actual.attrib.keys()) 79 | if expected_attrs != actual_attrs: 80 | expected = ', '.join(sorted(expected_attrs - actual_attrs)) 81 | actual = ', '.join(sorted(actual_attrs - expected_attrs)) 82 | raise AssertionError("%s: XML attributes mismatch: " 83 | "keys only in expected: %s; " 84 | "keys only in actual: %s" % 85 | (state, expected, actual)) 86 | 87 | # Compare the attribute values 88 | for key in expected_attrs: 89 | expected_value = expected.attrib[key] 90 | actual_value = actual.attrib[key] 91 | 92 | if 'DONTCARE' in (expected_value, actual_value): 93 | continue 94 | elif expected_value != actual_value: 95 | raise AssertionError("%s: XML attribute value mismatch: " 96 | "expected value of attribute %s: %r; " 97 | "actual value: %r" % 98 | (state, key, expected_value, 99 | actual_value)) 100 | 101 | # Compare the contents of the node 102 | if len(expected) == 0 and len(actual) == 0: 103 | # No children, compare text values 104 | if ('DONTCARE' not in (expected.text, actual.text) and 105 | expected.text != actual.text): 106 | raise AssertionError("%s: XML text value mismatch: " 107 | "expected text value: %r; " 108 | "actual value: %r" % 109 | (state, expected.text, actual.text)) 110 | else: 111 | expected_idx = 0 112 | actual_idx = 0 113 | while (expected_idx < len(expected) and 114 | actual_idx < len(actual)): 115 | # Ignore 
comments and processing instructions 116 | # TODO(Vek): may interpret PIs in the future, to 117 | # allow for, say, arbitrary ordering of some 118 | # elements 119 | if (expected[expected_idx].tag in 120 | (etree.Comment, etree.ProcessingInstruction)): 121 | expected_idx += 1 122 | continue 123 | 124 | # Compare the nodes 125 | result = _compare_node(expected[expected_idx], 126 | actual[actual_idx], state, 127 | actual_idx) 128 | if result is not True: 129 | return result 130 | 131 | # Step on to comparing the next nodes... 132 | expected_idx += 1 133 | actual_idx += 1 134 | 135 | # Make sure we consumed all nodes in actual 136 | if actual_idx < len(actual): 137 | raise AssertionError("%s: XML unexpected child element " 138 | "<%s> present at index %d" % 139 | (state, actual[actual_idx].tag, 140 | actual_idx)) 141 | 142 | # Make sure we consumed all nodes in expected 143 | if expected_idx < len(expected): 144 | for node in expected[expected_idx:]: 145 | if (node.tag in 146 | (etree.Comment, etree.ProcessingInstruction)): 147 | continue 148 | 149 | raise AssertionError("%s: XML expected child element " 150 | "<%s> not present at index %d" % 151 | (state, node.tag, actual_idx)) 152 | 153 | # The nodes match 154 | return True 155 | 156 | 157 | def compare_xml(expected, actual): 158 | """Compare two XML strings.""" 159 | 160 | expected = etree.fromstring(expected) 161 | if isinstance(actual, basestring): 162 | actual = etree.fromstring(actual) 163 | 164 | state = XMLMatchState() 165 | result = _compare_node(expected, actual, state, None) 166 | 167 | if result is False: 168 | raise AssertionError("%s: XML does not match" % state) 169 | elif result is not True: 170 | return result 171 | 172 | 173 | class TimeIncrementor(object): 174 | def __init__(self, interval, start=1000000.0): 175 | self.time = start - interval 176 | self.interval = interval 177 | 178 | def __call__(self): 179 | self.time += self.interval 180 | return self.time 181 | 
-------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py26,py27,pep8 3 | 4 | [testenv] 5 | setenv = LANG=en_US.UTF-8 6 | LANGUAGE=en_US:en 7 | LC_ALL=C 8 | 9 | deps = -r{toxinidir}/.requires 10 | -r{toxinidir}/.test-requires 11 | commands = nosetests -v {posargs} 12 | 13 | [testenv:pep8] 14 | deps = pep8 15 | commands = pep8 --repeat --show-source turnstile tests 16 | 17 | [testenv:cover] 18 | deps = -r{toxinidir}/.requires 19 | -r{toxinidir}/.test-requires 20 | coverage 21 | commands = nosetests -v --with-coverage --cover-package=turnstile \ 22 | --cover-html --cover-html-dir=cov_html 23 | -------------------------------------------------------------------------------- /turnstile/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2012 Rackspace 2 | # All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); you may 5 | # not use this file except in compliance with the License. You may obtain 6 | # a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 12 | # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 13 | # License for the specific language governing permissions and limitations 14 | # under the License. 15 | -------------------------------------------------------------------------------- /turnstile/compactor.py: -------------------------------------------------------------------------------- 1 | # Copyright 2013 Rackspace 2 | # All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); you may 5 | # not use this file except in compliance with the License. 
def version_greater(minimum, version):
    """
    Compare two dotted-integer version strings.

    :param minimum: The minimum valid version.
    :param version: The version to compare to.

    :returns: True if version is greater than or equal to minimum,
              False otherwise.  (The original docstring claimed a
              strictly-greater comparison, but equal versions return
              True, which is what callers such as
              GetBucketKey.factory() rely on.)  Comparison proceeds
              element by element over the shorter of the two
              versions, so version_greater('2.6', '2.6.1') is True.
    """

    # Chop up the version strings into lists of integers
    minimum = [int(i) for i in minimum.split('.')]
    version = [int(i) for i in version.split('.')]

    # Compare the versions element by element; the first inequality
    # decides the result
    for mini, vers in zip(minimum, version):
        if vers < mini:
            # If it's less than, we definitely don't match
            return False
        elif vers > mini:
            # If it's greater than, we definitely match
            return True

        # OK, the elements are equal; loop around and check out the
        # next element

    # All compared elements are equal
    return True
def get_int(config, key, default):
    """
    A helper to retrieve an integer value from a given dictionary
    containing string values.  If the requested value is not present
    in the dictionary, or if it cannot be converted to an integer, a
    default value will be returned instead.

    :param config: The dictionary containing the desired value.
    :param key: The dictionary key for the desired value.
    :param default: The default value to return, if the key isn't set
                    in the dictionary, or if the value set isn't a
                    legal integer value.

    :returns: The desired integer value.
    """

    try:
        return int(config[key])
    except (KeyError, TypeError, ValueError):
        # KeyError: key is absent; ValueError: value is not a legal
        # integer string; TypeError: value is not a string or number
        # at all (e.g. None), which previously escaped this handler
        return default
121 | info = db.info() 122 | if version_greater('2.6', info['redis_version']): 123 | LOG.debug("Redis server supports register_script()") 124 | return GetBucketKeyByScript(config, db) 125 | 126 | # OK, use our fallback... 127 | LOG.debug("Redis server does not support register_script()") 128 | return GetBucketKeyByLock(config, db) 129 | 130 | def __init__(self, config, db): 131 | """ 132 | Initialize a GetBucketKey instance. 133 | 134 | :param config: A dictionary of compactor options. 135 | :param db: A database handle for the Redis database. 136 | """ 137 | 138 | self.db = db 139 | self.key = config.get('compactor_key', 'compactor') 140 | self.max_age = get_int(config, 'max_age', 600) 141 | self.min_age = get_int(config, 'min_age', 30) 142 | self.idle_sleep = get_int(config, 'sleep', 5) 143 | 144 | def __call__(self): 145 | """ 146 | Retrieve the next bucket key to compact. If no buckets are 147 | available for compacting, sleeps for a given period of time 148 | and tries again. 149 | 150 | :returns: The bucket key to compact. 151 | """ 152 | 153 | while True: 154 | now = time.time() 155 | 156 | # Drop all items older than max_age; they're no longer 157 | # quiesced, since the compactor logic will cause new 158 | # summarize records to be generated. No lock is needed... 159 | self.db.zremrangebyscore(self.key, 0, now - self.max_age) 160 | 161 | # Get an item and return it 162 | item = self.get(now) 163 | if item: 164 | LOG.debug("Next bucket to compact: %s" % item) 165 | return item 166 | 167 | # If we didn't get one, idle 168 | LOG.debug("No buckets to compact; sleeping for %s seconds" % 169 | self.idle_sleep) 170 | time.sleep(self.idle_sleep) 171 | 172 | def get(self, now): 173 | """ 174 | Get a bucket key to compact. If none are available, returns 175 | None. 176 | 177 | :param now: The current time, as a float. Used to ensure the 178 | bucket key has been aged sufficiently to be 179 | quiescent. 
class GetBucketKeyByLock(GetBucketKey):
    """
    Retrieve a bucket key to compact using a lock.
    """

    def __init__(self, config, db):
        """
        Initialize a GetBucketKeyByLock instance.

        :param config: A dictionary of compactor options.
        :param db: A database handle for the Redis database.
        """

        super(GetBucketKeyByLock, self).__init__(config, db)

        lock_key = config.get('compactor_lock', 'compactor_lock')
        timeout = get_int(config, 'compactor_timeout', 30)
        self.lock = db.lock(lock_key, timeout=timeout)

        LOG.debug("Using GetBucketKeyByLock as bucket key getter")

    def get(self, now):
        """
        Get a bucket key to compact.  If none are available, returns
        None.  This uses a configured lock to ensure that the bucket
        key is popped off the sorted set in an atomic fashion.

        :param now: The current time, as a float.  Used to ensure the
                    bucket key has been aged sufficiently to be
                    quiescent.

        :returns: A bucket key ready for compaction, or None if no
                  bucket keys are available or none have aged
                  sufficiently.
        """

        with self.lock:
            items = self.db.zrangebyscore(self.key, 0, now - self.min_age,
                                          start=0, num=1)
            # Did we get any items?
            if not items:
                return None

            # Drop the item we got.  Bug fix: zrem() requires the
            # sorted set key as its first argument; without it, the
            # popped item was never actually removed from the set.
            item = items[0]
            self.db.zrem(self.key, item)

        return item


class GetBucketKeyByScript(GetBucketKey):
    """
    Retrieve a bucket key to compact using a Lua script.
    """

    def __init__(self, config, db):
        """
        Initialize a GetBucketKeyByScript instance.

        :param config: A dictionary of compactor options.
        :param db: A database handle for the Redis database.
        """

        super(GetBucketKeyByScript, self).__init__(config, db)

        # Bug fix: the 'zrem' call must name the sorted set
        # (KEYS[1]); without it, Redis raises an error and the popped
        # item was never removed from the set.
        self.script = db.register_script("""
local res
res = redis.call('zrangebyscore', KEYS[1], 0, ARGV[1], 'limit', 0, 1)
if #res > 0 then
    redis.call('zrem', KEYS[1], res[1])
end
return res
""")

        LOG.debug("Using GetBucketKeyByScript as bucket key getter")

    def get(self, now):
        """
        Get a bucket key to compact.  If none are available, returns
        None.  This uses a Lua script to ensure that the bucket key is
        popped off the sorted set in an atomic fashion.

        :param now: The current time, as a float.  Used to ensure the
                    bucket key has been aged sufficiently to be
                    quiescent.

        :returns: A bucket key ready for compaction, or None if no
                  bucket keys are available or none have aged
                  sufficiently.
        """

        items = self.script(keys=[self.key], args=[now - self.min_age])
        return items[0] if items else None
305 | :param db: A database handle for the Redis database. 306 | """ 307 | 308 | self.conf = conf 309 | self.db = db 310 | self.limits = [] 311 | self.limit_map = {} 312 | self.limit_sum = None 313 | 314 | # Initialize the control daemon 315 | if conf.to_bool(conf['control'].get('remote', 'no'), False): 316 | self.control_daemon = remote.RemoteControlDaemon(self, conf) 317 | else: 318 | self.control_daemon = control.ControlDaemon(self, conf) 319 | 320 | # Now start the control daemon 321 | self.control_daemon.start() 322 | 323 | def __getitem__(self, key): 324 | """ 325 | Obtain the limit with the given UUID. Ensures that the 326 | current limit list is loaded. 327 | 328 | :param key: The UUID of the desired limit. 329 | """ 330 | 331 | self.recheck_limits() 332 | return self.limit_map[key] 333 | 334 | def recheck_limits(self): 335 | """ 336 | Re-check that the cached limits are the current limits. 337 | """ 338 | 339 | limit_data = self.control_daemon.get_limits() 340 | 341 | try: 342 | # Get the new checksum and list of limits 343 | new_sum, new_limits = limit_data.get_limits(self.limit_sum) 344 | 345 | # Convert the limits list into a list of objects 346 | lims = database.limits_hydrate(self.db, new_limits) 347 | 348 | # Save the new data 349 | self.limits = lims 350 | self.limit_map = dict((lim.uuid, lim) for lim in lims) 351 | self.limit_sum = new_sum 352 | except control.NoChangeException: 353 | # No changes to process; just keep going... 354 | return 355 | except Exception: 356 | # Log an error 357 | LOG.exception("Could not load limits") 358 | 359 | # Get our error set and publish channel 360 | control_args = self.conf['control'] 361 | error_key = control_args.get('errors_key', 'errors') 362 | error_channel = control_args.get('errors_channel', 'errors') 363 | 364 | # Get an informative message 365 | msg = "Failed to load limits: " + traceback.format_exc() 366 | 367 | # Store the message into the error set. 
def compact_bucket(db, buck_key, limit):
    """
    Perform the compaction operation.  This reads in the bucket
    information from the database, builds a compacted bucket record,
    inserts that record in the appropriate place in the database, then
    removes outdated updates.

    :param db: A database handle for the Redis database.
    :param buck_key: A turnstile.limits.BucketKey instance containing
                     the bucket key.
    :param limit: The turnstile.limits.Limit object corresponding to
                  the bucket.
    """

    key = str(buck_key)

    # Read all the update records and replay them through the bucket
    # loader, stopping at the last summarize record
    loader = limits.BucketLoader(limit.bucket_class, db, limit, key,
                                 db.lrange(key, 0, -1),
                                 stop_summarize=True)

    # Serialize the fully loaded bucket as a compacted 'bucket' record
    compacted = msgpack.dumps(dict(bucket=loader.bucket.dehydrate(),
                                   uuid=str(uuid.uuid4())))

    # Insert the compacted record immediately after the summarize
    # record that triggered this compaction
    if db.linsert(key, 'after', loader.last_summarize_rec, compacted) < 0:
        # Insert failed; we'll try again when max_age is hit
        LOG.warning("Bucket compaction on %s failed; will retry" % buck_key)
        return

    # The compacted bucket record is confirmed in place; trim off the
    # now-outdated update records it summarizes
    db.ltrim(key, loader.last_summarize_idx + 1, -1)
Enable it by " 445 | "setting a positive integer value for " 446 | "'compactor.max_updates' in the configuration.") 447 | 448 | # Select the bucket key getter 449 | key_getter = GetBucketKey.factory(config, db) 450 | 451 | LOG.info("Compactor initialized") 452 | 453 | # Now enter our loop 454 | while True: 455 | # Get a bucket key to compact 456 | try: 457 | buck_key = limits.BucketKey.decode(key_getter()) 458 | except ValueError as exc: 459 | # Warn about invalid bucket keys 460 | LOG.warning("Error interpreting bucket key: %s" % exc) 461 | continue 462 | 463 | # Ignore version 1 keys--they can't be compacted 464 | if buck_key.version < 2: 465 | continue 466 | 467 | # Get the corresponding limit class 468 | try: 469 | limit = limit_map[buck_key.uuid] 470 | except KeyError: 471 | # Warn about missing limits 472 | LOG.warning("Unable to compact bucket for limit %s" % 473 | buck_key.uuid) 474 | continue 475 | 476 | LOG.debug("Compacting bucket %s" % buck_key) 477 | 478 | # OK, we now have the limit (which we really only need for 479 | # the bucket class); let's compact the bucket 480 | try: 481 | compact_bucket(db, buck_key, limit) 482 | except Exception: 483 | LOG.exception("Failed to compact bucket %s" % buck_key) 484 | else: 485 | LOG.debug("Finished compacting bucket %s" % buck_key) 486 | -------------------------------------------------------------------------------- /turnstile/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2012 Rackspace 2 | # All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); you may 5 | # not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import ConfigParser

from turnstile import database


# String literals recognized by Config.to_bool() as boolean values.
_str_true = set(['t', 'true', 'on', 'y', 'yes'])
_str_false = set(['f', 'false', 'off', 'n', 'no'])


class Config(object):
    """
    Container for turnstile configuration data.

    Configuration may come from a paste file (keys of the form
    "<section>.<option> = <value>") and/or from a standard INI file.
    Prefix-less paste keys--and, for INI files, options from the
    "[turnstile]" section--become attributes of the object.  Options
    from every other section are collected into per-section
    dictionaries, reachable by subscripting the Config with the
    section name.  The special prefix-less "config" key names an
    additional INI file to load.

    Example paste.ini:

        [filter:turnstile]
        paste.filter_factory = turnstile.middleware:turnstile_filter
        preprocess = my_preproc:preproc
        redis.host = 10.0.0.1
        config = /etc/my_turnstile.conf

    With /etc/my_turnstile.conf containing:

        [turnstile]
        status = 500 Internal Error

        [redis]
        password = s3cureM3!

        [control]
        node_name = mynode

    the resulting object behaves like so:

        >>> config.preprocess
        'my_preproc:preproc'
        >>> config.status
        '500 Internal Error'
        >>> config['redis']
        {'host': '10.0.0.1', 'password': 's3cureM3!'}
        >>> config['control']
        {'node_name': 'mynode'}
    """

    def __init__(self, conf_dict=None, conf_file=None):
        """
        Initialize a Config.  A default is provided for the "status"
        option.

        :param conf_dict: Optional dictionary of configuration drawn
                          from the paste.ini file.  If it contains a
                          'config' key, the named INI file is also
                          read, and its values override those from
                          the dictionary.
        :param conf_file: Optional name of an INI file to read.  Its
                          values override both conf_dict and any file
                          named by the 'config' key.

        In INI files, options in the '[turnstile]' section map to the
        prefix-less values, except that a 'config' option there is
        ignored.
        """

        # Start from the built-in defaults; the None section holds
        # the prefix-less options.
        self._config = {
            None: {
                'status': '413 Request Entity Too Large',
            },
        }

        # Fold in the paste-style dictionary, if one was given
        if conf_dict:
            for key, value in conf_dict.items():
                section, _sep, option = key.partition('.')

                # Keys without a '.' land in the None section
                if not option:
                    section, option = None, section

                self._config.setdefault(section, {})[option] = value

        # Collect the INI files to read, in override order
        files = []
        if 'config' in self._config[None]:
            files.append(self._config[None]['config'])
        if conf_file:
            files.append(conf_file)

        if files:
            parser = ConfigParser.SafeConfigParser()
            parser.read(files)

            # Each INI section becomes a top-level config section;
            # '[turnstile]' maps onto the prefix-less options
            for sect in parser.sections():
                section = None if sect == 'turnstile' else sect

                self._config.setdefault(section, {})
                self._config[section].update(dict(parser.items(sect)))

    def __getitem__(self, key):
        """
        Return the dictionary for the named configuration section.
        Unknown sections yield an empty dictionary, for convenience.
        """

        return self._config.get(key, {})

    def __contains__(self, key):
        """
        Return True if the named section exists in the configuration.
        Note that __getitem__() returns an empty dictionary where
        this returns False.
        """

        return key in self._config

    def __getattr__(self, key):
        """
        Return the named prefix-less option (paste.ini keys without a
        prefix, or '[turnstile]' section options).  Raises
        AttributeError if the option does not exist.
        """

        try:
            return self._config.get(None, {})[key]
        except KeyError:
            raise AttributeError('%r object has no attribute %r' %
                                 (self.__class__.__name__, key))

    def get(self, key, default=None):
        """
        Return the named prefix-less option (paste.ini keys without a
        prefix, or '[turnstile]' section options), or the given
        default (None if unspecified) if the option does not exist.
        """

        return self._config.get(None, {}).get(key, default)

    def get_database(self, override=None):
        """
        Convenience method returning a handle to the Redis database.

        By default the connection options come from the '[redis]'
        section.  If override names another section, any options
        there prefixed with 'redis.' override the corresponding
        '[redis]' options; an empty override value removes the
        option.  For example, given:

            [redis]
            host = 10.0.0.1
            password = s3cureM3!

            [control]
            redis.host = 127.0.0.1

        get_database() connects to 10.0.0.1 and
        get_database('control') connects to 127.0.0.1, both using the
        password 's3cureM3!'.
        """

        redis_args = self['redis']

        if override:
            # Work on a copy so the stored section is untouched
            redis_args = redis_args.copy()
            prefix = 'redis.'
            for key, value in self[override].items():
                if not key.startswith(prefix):
                    continue
                option = key[len(prefix):]
                if value:
                    redis_args[option] = value
                else:
                    # Empty value means "drop the base option"
                    redis_args.pop(option, None)

        return database.initialize(redis_args)

    @staticmethod
    def to_bool(value, do_raise=True):
        """Convert a string to a boolean value.

        A string of digits is coerced through int() to bool.
        Otherwise, "t", "true", "on", "y", and "yes" are True, and
        "f", "false", "off", "n", and "no" are False (case
        insensitive).  Any other value raises ValueError, unless
        do_raise is False, in which case False is returned.
        """

        text = value.lower()

        # Numeric strings follow int truthiness
        if text.isdigit():
            return bool(int(text))

        if text in _str_true:
            return True
        if text in _str_false:
            return False

        if do_raise:
            raise ValueError("invalid literal for to_bool(): %r" % text)

        return False
# Copyright 2012 Rackspace
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import hashlib
import logging
import random
import traceback

import eventlet
import msgpack

from turnstile import utils


LOG = logging.getLogger('turnstile')


class NoChangeException(Exception):
    """
    Indicates that there are no limit data changes to be applied.
    Raised by LimitData.get_limits().
    """

    pass


class LimitData(object):
    """
    Stores limit data.  Provides a common depot between the
    ControlDaemon and the middleware which contains the raw limit data
    (as msgpack'd strings).
    """

    def __init__(self):
        """
        Initialize the LimitData.  The limit data is initialized to
        the empty list, with the checksum of an empty input.
        """

        # Checksum corresponding to "no limits"
        chksum = hashlib.md5()
        chksum.update('')

        self.limit_data = []
        self.limit_sum = chksum.hexdigest()
        # Guards concurrent readers/writers of the limit data
        self.limit_lock = eventlet.semaphore.Semaphore()

    def set_limits(self, limits):
        """
        Set the limit data to the given list of limits.  Limits are
        specified as the raw msgpack string representing the limit.
        Computes the checksum of the limits; if the checksum is
        identical to the current one, no action is taken.
        """

        # First task, build the checksum of the new limits
        chksum = hashlib.md5()  # sufficient for our purposes
        for lim in limits:
            chksum.update(lim)
        new_sum = chksum.hexdigest()

        # Now install it
        with self.limit_lock:
            if self.limit_sum == new_sum:
                # No changes
                return
            self.limit_data = [msgpack.loads(lim) for lim in limits]
            self.limit_sum = new_sum

    def get_limits(self, limit_sum=None):
        """
        Gets the current limit data if it is different from the data
        indicated by limit_sum.  Raises a NoChangeException if the
        limit_sum represents no change, otherwise returns a tuple
        consisting of the current limit_sum and a list of msgpack'd
        limit representations.
        """

        with self.limit_lock:
            # Any changes?
            if limit_sum and self.limit_sum == limit_sum:
                raise NoChangeException()

            # Return a tuple of the limits and limit sum
            return (self.limit_sum, self.limit_data)


class ControlDaemon(object):
    """
    A daemon thread which listens for control messages and can reload
    the limit configuration from the database.
    """

    _commands = {}

    @classmethod
    def _register(cls, name, func):
        """
        Register func as a recognized control command with the given
        name.
        """

        cls._commands[name] = func

    def __init__(self, middleware, conf):
        """
        Initialize the ControlDaemon.

        :param middleware: The TurnstileMiddleware instance; used to
                           obtain the lazily initialized database
                           handle.
        :param conf: The turnstile Config object.
        """

        # Save some relevant information
        self._db = None
        self.middleware = middleware
        self.config = conf
        self.limits = LimitData()

        # Need a semaphore to cover reloads in action
        self.pending = eventlet.semaphore.Semaphore()

        # Initialize the listening thread
        self.listen_thread = None

    def start(self):
        """
        Starts the ControlDaemon by launching the listening thread and
        triggering the initial limits load.
        """

        # Spawn the listening thread
        self.listen_thread = eventlet.spawn_n(self.listen)

        # Now do the initial load
        self.reload()

    def listen(self):
        """
        Listen for incoming control messages.

        If the 'shard_hint' option is set in the '[control]' section
        of the configuration, its value will be passed to the pubsub()
        method when setting up the subscription.  The control channel
        to subscribe to is specified by the 'channel' option of the
        '[control]' section ('control' by default).
        """

        # Use a specific database handle, with override.  This allows
        # the long-lived listen thread to be configured to use a
        # different database or different database options.
        db = self.config.get_database('control')

        # Need a pub-sub object
        kwargs = {}
        if 'shard_hint' in self.config['control']:
            kwargs['shard_hint'] = self.config['control']['shard_hint']
        pubsub = db.pubsub(**kwargs)

        # Subscribe to the right channel(s)...
        channel = self.config['control'].get('channel', 'control')
        pubsub.subscribe(channel)

        # Now we listen...
        for msg in pubsub.listen():
            # Only interested in messages to our reload channel
            if (msg['type'] in ('pmessage', 'message') and
                    msg['channel'] == channel):
                # Figure out what kind of message this is
                command, _sep, args = msg['data'].partition(':')

                # We must have some command...
                if not command:
                    continue

                # Don't do anything with internal commands
                if command[0] == '_':
                    LOG.error("Cannot call internal command %r" % command)
                    continue

                # Look up the command; cache entrypoint lookups
                # (including misses) in the class-level registry
                if command in self._commands:
                    func = self._commands[command]
                else:
                    # Try an entrypoint
                    func = utils.find_entrypoint('turnstile.command', command,
                                                 compat=False)
                    self._commands[command] = func

                # Don't do anything with missing commands
                if not func:
                    LOG.error("No such command %r" % command)
                    continue

                # Execute the desired command
                arglist = args.split(':') if args else []
                try:
                    func(self, *arglist)
                except Exception:
                    LOG.exception("Failed to execute command %r arguments %r" %
                                  (command, arglist))
                    continue

    def get_limits(self):
        """
        Retrieve the LimitData object the middleware will use for
        getting the limits.  This is broken out into a function so
        that it can be overridden in multi-process configurations to
        return a LimitData subclass which will query the master
        LimitData in the ControlDaemon process.
        """

        return self.limits

    def reload(self):
        """
        Reloads the limits configuration from the database.

        If an error occurs loading the configuration, an error-level
        log message will be emitted.  Additionally, the error message
        will be added to the set specified by the 'errors_key' option
        of the '[control]' section ('errors' by default) and sent to
        the publishing channel specified by the 'errors_channel'
        option of the '[control]' section ('errors' by default).
        """

        # Acquire the pending semaphore.  If we fail, exit--someone
        # else is already doing the reload
        if not self.pending.acquire(False):
            return

        # Do the remaining steps in a try/finally block so we make
        # sure to release the semaphore
        control_args = self.config['control']
        try:
            # Load all the limits
            key = control_args.get('limits_key', 'limits')
            self.limits.set_limits(self.db.zrange(key, 0, -1))
        except Exception:
            # Log an error
            LOG.exception("Could not load limits")

            # Get our error set and publish channel
            error_key = control_args.get('errors_key', 'errors')
            error_channel = control_args.get('errors_channel', 'errors')

            # Get an informative message
            msg = "Failed to load limits: " + traceback.format_exc()

            # Store the message into the error set.  We use a set here
            # because it's likely that more than one node will
            # generate the same message if there is an error, and this
            # avoids an explosion in the size of the set.
            with utils.ignore_except():
                self.db.sadd(error_key, msg)

            # Publish the message to a channel
            with utils.ignore_except():
                self.db.publish(error_channel, msg)
        finally:
            self.pending.release()

    @property
    def db(self):
        """
        Obtain a handle for the database.  This allows lazy
        initialization of the database handle.
        """

        # Initialize the database handle from the middleware's copy of
        # it
        if not self._db:
            self._db = self.middleware.db

        return self._db


def register(name, func=None):
    """
    Function or decorator which registers a given function as a
    recognized control command.
    """

    def decorator(func):
        # Perform the registration
        ControlDaemon._register(name, func)
        return func

    # If func was given, call the decorator, otherwise, return the
    # decorator
    if func:
        return decorator(func)
    else:
        return decorator


@register('ping')
def ping(daemon, channel, data=None):
    """
    Process the 'ping' control message.

    :param daemon: The control daemon; used to get at the
                   configuration and the database.
    :param channel: The publish channel to which to send the
                    response.
    :param data: Optional extra data.  Will be returned as the
                 second argument of the response.

    Responds to the named channel with a command of 'pong' and
    with the node_name (if configured in the '[control]' section)
    and provided data as arguments.
    """

    if not channel:
        # No place to reply to
        return

    # Get our configured node name
    node_name = daemon.config['control'].get('node_name')

    # Format the response
    reply = ['pong']
    if node_name or data:
        reply.append(node_name or '')
    if data:
        reply.append(data)

    # And send it
    with utils.ignore_except():
        daemon.db.publish(channel, ':'.join(reply))


@register('reload')
def reload(daemon, load_type=None, spread=None):
    """
    Process the 'reload' control message.

    :param daemon: The control daemon; used to get at the
                   configuration and call the actual reload.
    :param load_type: Optional type of reload.  If given as
                      'immediate', reload is triggered
                      immediately.  If given as 'spread', reload
                      is triggered after a random period of time
                      in the interval (0.0, spread).  Otherwise,
                      reload will be as configured.
    :param spread: Optional argument for 'spread' load_type.  Must
                   be a float giving the maximum length of the
                   interval, in seconds, over which the reload
                   should be scheduled.  If not provided, falls
                   back to configuration.

    If a recognized load_type is not given, or is given as
    'spread' but the spread parameter is not a valid float, the
    configuration will be checked for the 'reload_spread' option
    of the '[control]' section.  If that is a valid value, the
    reload will be randomly scheduled for some time within the
    interval (0.0, reload_spread).
    """

    # Figure out what type of reload this needs to be
    if load_type == 'immediate':
        spread = None
    elif load_type == 'spread':
        try:
            spread = float(spread)
        except (TypeError, ValueError):
            # Not a valid float; use the configured spread value
            load_type = None
    else:
        load_type = None

    if load_type is None:
        # Use configured set-up; see if we have a spread
        # configured
        try:
            spread = float(daemon.config['control']['reload_spread'])
        except (TypeError, ValueError, KeyError):
            # No valid configuration
            spread = None

    if spread:
        # Apply a randomization to spread the load around
        eventlet.spawn_after(random.random() * spread, daemon.reload)
    else:
        # Spawn in immediate mode
        eventlet.spawn_n(daemon.reload)

import msgpack
import redis

from turnstile import limits
from turnstile import utils


# Recognized Redis connection options, mapped to the callable used to
# coerce the configured string into the proper type.
REDIS_CONFIGS = {
    'host': str,
    'port': int,
    'db': int,
    'password': str,
    'socket_timeout': int,
    'unix_socket_path': str,
}

# Configuration keys consumed by initialize() itself and never passed
# on to the Redis client.
REDIS_EXCLUDES = set(['connection_pool', 'redis_client'])


def initialize(config):
    """
    Initialize a connection to the Redis database.

    :param config: A dictionary of Redis connection options, such as
                   the '[redis]' section of a Config.

    Recognizes the 'redis_client' and 'connection_pool' keys as
    entrypoint names for alternate client and connection pool classes,
    and 'connection_pool.*' keys as connection pool constructor
    arguments; all other unrecognized keys are passed through to the
    client constructor.
    """

    # Determine the client class to use
    if 'redis_client' in config:
        client = utils.find_entrypoint('turnstile.redis_client',
                                       config['redis_client'], required=True)
    else:
        client = redis.StrictRedis

    # Extract relevant connection information from the configuration,
    # coercing each value to its expected type
    kwargs = dict((option, cast(config[option]))
                  for option, cast in REDIS_CONFIGS.items()
                  if option in config)

    # Make sure we have at a minimum the hostname
    if 'host' not in kwargs and 'unix_socket_path' not in kwargs:
        raise redis.ConnectionError("No host specified for redis database")

    # Look up the connection pool configuration
    cpool_class = None
    cpool = {}
    extra_kwargs = {}
    for key, value in config.items():
        if key.startswith('connection_pool.'):
            _dummy, _sep, varname = key.partition('.')
            if varname == 'connection_class':
                cpool[varname] = utils.find_entrypoint(
                    'turnstile.connection_class', value, required=True)
            elif varname == 'max_connections':
                cpool[varname] = int(value)
            elif varname == 'parser_class':
                cpool[varname] = utils.find_entrypoint(
                    'turnstile.parser_class', value, required=True)
            else:
                cpool[varname] = value
        elif key not in REDIS_CONFIGS and key not in REDIS_EXCLUDES:
            # Unrecognized keys go straight to the client constructor
            extra_kwargs[key] = value

    # Any pool options force use of a connection pool
    if cpool:
        cpool_class = redis.ConnectionPool

    # Use custom connection pool class if requested...
    if 'connection_pool' in config:
        cpool_class = utils.find_entrypoint('turnstile.connection_pool',
                                            config['connection_pool'],
                                            required=True)

    if cpool_class:
        # If we're using a connection pool, the connection keyword
        # arguments go to the pool instead of to redis
        cpool.update(kwargs)

        # Use a custom connection class?
        if 'connection_class' not in cpool:
            if 'unix_socket_path' in cpool:
                # Unix socket connections take a 'path' argument and
                # ignore host/port
                cpool.pop('host', None)
                cpool.pop('port', None)

                cpool['path'] = cpool.pop('unix_socket_path')

                cpool['connection_class'] = redis.UnixDomainSocketConnection
            else:
                cpool['connection_class'] = redis.Connection

        # Build the connection pool to use and set up to pass it into
        # the redis constructor...
        kwargs = dict(connection_pool=cpool_class(**cpool))

    # Build and return the database
    kwargs.update(extra_kwargs)
    return client(**kwargs)


def limits_hydrate(db, lims):
    """
    Helper function to hydrate a list of limits.

    :param db: A database handle.
    :param lims: A list of limit strings, as retrieved from the
                 database.
    """

    return [limits.Limit.hydrate(db, lim) for lim in lims]


def limit_update(db, key, limits):
    """
    Safely updates the list of limits in the database.

    :param db: The database handle.
    :param key: The key the limits are stored under.
    :param limits: A list or sequence of limit objects, each
                   understanding the dehydrate() method.

    The limits list currently in the database will be atomically
    changed to match the new list.  This is done using the pipeline()
    method with WATCH, retrying on concurrent modification.
    """

    # Start by dehydrating all the limits
    desired = [msgpack.dumps(l.dehydrate()) for l in limits]
    desired_set = set(desired)

    # Now, let's update the limits
    with db.pipeline() as pipe:
        while True:
            try:
                # Watch for changes to the key
                pipe.watch(key)

                # Look up the existing limits
                existing = set(pipe.zrange(key, 0, -1))

                # Start the transaction...
                pipe.multi()

                # Remove limits we no longer have
                for lim in existing - desired_set:
                    pipe.zrem(key, lim)

                # Update or add all our desired limits; scores are
                # spaced by 10 to preserve the configured ordering
                for idx, lim in enumerate(desired):
                    pipe.zadd(key, (idx + 1) * 10, lim)

                # Execute the transaction
                pipe.execute()
            except redis.WatchError:
                # Someone else touched the key; try again...
                continue
            else:
                # We're all done!
                break


def command(db, channel, command, *args):
    """
    Utility function to issue a command to all Turnstile instances.

    :param db: The database handle.
    :param channel: The control channel all Turnstile instances are
                    listening on.
    :param command: The command, as plain text.  Currently, only
                    'reload' and 'ping' are recognized.

    All remaining arguments are treated as arguments for the command;
    they will be stringified and sent along with the command to the
    control channel.  Note that ':' is an illegal character in
    arguments, but no warnings will be issued if it is used.
    """

    # Build the command we're sending: "command[:arg[:arg...]]"
    cmd = [command]
    cmd.extend(str(a) for a in args)

    # Send it out
    db.publish(channel, ':'.join(cmd))
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import collections
import logging
import math
import traceback

import eventlet
import routes

from turnstile import config
from turnstile import control
from turnstile import database
from turnstile import remote
from turnstile import utils


LOG = logging.getLogger('turnstile')


class HeadersDict(collections.MutableMapping):
    """
    A dictionary class for HTTP headers.  Keys are treated
    case-insensitively by mapping every key to lowercase.
    """

    def __init__(self, *args, **kwargs):
        """
        Initialize HeadersDict.  Uses update() to process additional
        arguments.
        """

        self.headers = {}
        self.update(*args, **kwargs)

    def __getitem__(self, name):
        """
        Retrieve the value of a header, case-insensitively.
        """

        return self.headers[name.lower()]

    def __setitem__(self, name, value):
        """
        Set the value of a header, case-insensitively.
        """

        self.headers[name.lower()] = value

    def __delitem__(self, name):
        """
        Delete a header, case-insensitively.
        """

        del self.headers[name.lower()]

    def __contains__(self, name):
        """
        Test if the headers dictionary contains a given header,
        case-insensitively.
        """

        return name.lower() in self.headers

    def __iter__(self):
        """
        Iterate through the (lowercased) header names.
        """

        return iter(self.headers)

    def __len__(self):
        """
        Retrieve the number of headers.
        """

        return len(self.headers)

    def iterkeys(self):
        """
        Iterate through header names.
        """

        return self.headers.iterkeys()

    def iteritems(self):
        """
        Iterate through header items.
        """

        return self.headers.iteritems()

    def itervalues(self):
        """
        Iterate through header values.
        """

        return self.headers.itervalues()

    def keys(self):
        """
        Return a list of header names.
        """

        return self.headers.keys()

    def items(self):
        """
        Return a list of header items.
        """

        return self.headers.items()

    def values(self):
        """
        Return a list of header values.
        """

        return self.headers.values()


def turnstile_filter(global_conf, **local_conf):
    """
    Factory function for turnstile.

    Returns a function which, when passed the application, returns an
    instance of the TurnstileMiddleware (or of the middleware class
    named by the 'turnstile' key of the local configuration).
    """

    # Select the appropriate middleware class to return
    if 'turnstile' in local_conf:
        midware_class = utils.find_entrypoint('turnstile.middleware',
                                              local_conf['turnstile'],
                                              required=True)
    else:
        midware_class = TurnstileMiddleware

    def wrapper(app):
        return midware_class(app, local_conf)

    return wrapper
        """

        # Save the application we are wrapping
        self.app = app
        self.limits = []
        self.limit_sum = None
        self.mapper = None
        # Guards limit/mapper swaps against concurrent requests
        self.mapper_lock = eventlet.semaphore.Semaphore()

        # Save the configuration
        self.conf = config.Config(conf_dict=local_conf)

        # We will lazy-load the database (see the db property below)
        self._db = None

        # Set up request pre- and post-processors
        self.preprocessors = []
        self.postprocessors = []
        enable = self.conf.get('enable')
        if enable is not None:
            # Use the enabler syntax: each name may provide a
            # preprocessor, a postprocessor, or both
            for proc in enable.split():
                # Try the preprocessor
                preproc = utils.find_entrypoint('turnstile.preprocessor',
                                                proc, compat=False)
                if preproc:
                    self.preprocessors.append(preproc)

                # Now the postprocessor
                postproc = utils.find_entrypoint('turnstile.postprocessor',
                                                 proc, compat=False)
                if postproc:
                    # Note the reversed order: postprocessors run in
                    # the mirror order of the preprocessors
                    self.postprocessors.insert(0, postproc)
        else:
            # Using the classic syntax; grab preprocessors...
            for preproc in self.conf.get('preprocess', '').split():
                klass = utils.find_entrypoint('turnstile.preprocessor',
                                              preproc, required=True)
                self.preprocessors.append(klass)

            # And now the postprocessors...
            for postproc in self.conf.get('postprocess', '').split():
                klass = utils.find_entrypoint('turnstile.postprocessor',
                                              postproc, required=True)
                self.postprocessors.append(klass)

        # Set up the alternative formatter, used to render the
        # over-limit response; falls back on format_delay()
        formatter = self.conf.get('formatter')
        if formatter:
            formatter = utils.find_entrypoint('turnstile.formatter',
                                              formatter, required=True)
            self.formatter = lambda a, b, c, d, e: formatter(
                self.conf.status, a, b, c, d, e)
        else:
            self.formatter = self.format_delay

        # Initialize the control daemon; a remote daemon runs the
        # limit data in a separate process, reached over RPC
        if self.conf.to_bool(self.conf['control'].get('remote', 'no'), False):
            self.control_daemon = remote.RemoteControlDaemon(self, self.conf)
        else:
            self.control_daemon = control.ControlDaemon(self, self.conf)

        # Now start the control daemon
        self.control_daemon.start()

        # Emit a log message to indicate that we're running
        LOG.info("Turnstile middleware initialized")

    def recheck_limits(self):
        """
        Re-check that the cached limits are the current limits.

        On success, atomically (under the caller-held mapper_lock)
        replaces self.limits, self.limit_sum and self.mapper.  On
        failure, logs the error and best-effort publishes it to the
        configured Redis error set and channel, keeping the previous
        limits in place.
        """

        limit_data = self.control_daemon.get_limits()

        try:
            # Get the new checksum and list of limits; raises
            # NoChangeException if limit_sum is still current
            new_sum, new_limits = limit_data.get_limits(self.limit_sum)

            # Convert the limits list into a list of objects
            lims = database.limits_hydrate(self.db, new_limits)

            # Build a new mapper
            mapper = routes.Mapper(register=False)
            for lim in lims:
                lim._route(mapper)

            # Save the new data
            self.limits = lims
            self.limit_sum = new_sum
            self.mapper = mapper
        except control.NoChangeException:
            # No changes to process; just keep going...
            return
        except Exception:
            # Log an error
            LOG.exception("Could not load limits")

            # Get our error set and publish channel
            control_args = self.conf['control']
            error_key = control_args.get('errors_key', 'errors')
            error_channel = control_args.get('errors_channel', 'errors')

            # Get an informative message
            msg = "Failed to load limits: " + traceback.format_exc()

            # Store the message into the error set.  We use a set
            # here because it's likely that more than one node
            # will generate the same message if there is an error,
            # and this avoids an explosion in the size of the set.
            with utils.ignore_except():
                self.db.sadd(error_key, msg)

            # Publish the message to a channel
            with utils.ignore_except():
                self.db.publish(error_channel, msg)

    def __call__(self, environ, start_response):
        """
        Implements the processing of the turnstile middleware.  Walks
        the list of limit filters, invoking their filters, then
        returns an appropriate response for the limit filter returning
        the longest delay.  If no limit filter indicates that a delay
        is needed, the request is passed on to the application.
        """

        with self.mapper_lock:
            # Check for updates to the limits
            self.recheck_limits()

            # Grab the current mapper under the lock so a concurrent
            # recheck cannot swap it mid-request
            mapper = self.mapper

            # Run the request preprocessors; some may want to refer to
            # the limit data, so protect this in the mapper_lock
            for preproc in self.preprocessors:
                # Preprocessors are expected to modify the environment;
                # they are helpers to set up variables expected by the
                # limit classes.
                preproc(self, environ)

        # Make configuration available to the limit classes as well
        environ['turnstile.conf'] = self.conf

        # Now, if we have a mapper, run through it; limit filters
        # record any required delays in environ['turnstile.delay']
        if mapper:
            mapper.routematch(environ=environ)

        # If there were any delays, deal with them
        if 'turnstile.delay' in environ and environ['turnstile.delay']:
            # Find the longest delay
            delay, limit, bucket = sorted(environ['turnstile.delay'],
                                          key=lambda x: x[0])[-1]

            return self.formatter(delay, limit, bucket,
                                  environ, start_response)

        with self.mapper_lock:
            # Run the request postprocessors; some may want to refer
            # to the limit data, so protect this in the mapper_lock
            for postproc in self.postprocessors:
                # Postprocessors are expected to modify the
                # environment; they are helpers to set up variables
                # expected by the limit classes.  They run after the
                # limits are evaluated, to support reporting the
                # limits to the caller.
                postproc(self, environ)

        return self.app(environ, start_response)

    def format_delay(self, delay, limit, bucket, environ, start_response):
        """
        Formats the over-limit response for the request.  May be
        overridden in subclasses to allow alternate responses.

        :param delay: The delay, in seconds, until the request could
                      be retried.
        :param limit: The limit object that was tripped.
        :param bucket: The bucket containing the current limit state.
        :param environ: The WSGI environment for the request.
        :param start_response: The WSGI start_response callable.

        :returns: The response entity.
        """

        # Set up the default status
        status = self.conf.status

        # Set up the retry-after header...
        headers = HeadersDict([('Retry-After', "%d" % math.ceil(delay))])

        # Let format fiddle with the headers
        status, entity = limit.format(status, headers, environ, bucket,
                                      delay)

        # Return the response
        start_response(status, headers.items())
        return entity

    @property
    def db(self):
        """
        Obtain a handle for the database.  This allows lazy
        initialization of the database handle.
        """

        # Initialize the database handle on first access
        if not self._db:
            self._db = self.conf.get_database()

        return self._db

# ---------------------------------------------------------------------------
# turnstile/remote.py
# ---------------------------------------------------------------------------
# Copyright 2012 Rackspace
# All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License.  You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
#    License for the specific language governing permissions and limitations
#    under the License.

import functools
import json
import logging
import socket
import sys
import time
import warnings

import eventlet

from turnstile import control
from turnstile import utils


LOG = logging.getLogger('turnstile')


class ConnectionClosed(Exception):
    """Raised to indicate a connection has been closed."""

    pass


class Connection(object):
    """
    Buffered network connection.  Frames messages as
    newline-terminated JSON documents of the form
    ``{"cmd": ..., "payload": [...]}``.
    """

    def __init__(self, sock):
        """
        Initialize a Connection object.

        :param sock: The underlying connected socket.
        """

        self._sock = sock
        # Parsed-but-unconsumed messages (or ValueError instances for
        # unparseable frames, re-raised on pop)
        self._recvbuf = []
        # Raw bytes of a partially received frame
        self._recvbuf_partial = ''

    def close(self):
        """
        Close the connection.  The underlying socket is closed and the
        receive buffers are always purged.
        """

        # Close the underlying socket
        if self._sock:
            with utils.ignore_except():
                self._sock.close()
            self._sock = None

        # Purge the message buffers
        self._recvbuf = []
        self._recvbuf_partial = ''

    def send(self, cmd, *payload):
        """
        Send a command message to the other end.

        :param cmd: The command to send to the other end.
        :param payload: The command payload.  Note that all elements
                        of the payload must be serializable to JSON.

        :raises ConnectionClosed: If the connection is already closed.
        """

        # If it's closed, raise an error up front
        if not self._sock:
            raise ConnectionClosed("Connection closed")

        # Construct the outgoing message
        msg = json.dumps(dict(cmd=cmd, payload=payload)) + '\n'

        # Send it
        try:
            self._sock.sendall(msg)
        except socket.error:
            # We'll need to re-raise
            e_type, e_value, e_tb = sys.exc_info()

            # Make sure the socket is closed
            self.close()

            # Re-raise (Python 2 three-expression form, preserving the
            # original traceback)
            raise e_type, e_value, e_tb

    def _recvbuf_pop(self):
        """
        Internal helper to pop a message off the receive buffer.  If
        the message is an Exception, that exception will be raised;
        otherwise, a tuple of command and payload will be returned.
        """

        # Pop a message off the recv buffer and return (or raise) it
        msg = self._recvbuf.pop(0)
        if isinstance(msg, Exception):
            raise msg
        return msg['cmd'], msg['payload']

    def recv(self):
        """
        Receive a message from the other end.  Returns a tuple of the
        command (a string) and payload (a list).

        :raises ConnectionClosed: If the connection is closed or
                                  closes while reading.
        :raises ValueError: If a received frame was not valid JSON.
        """

        # See if we have a message to process...
        if self._recvbuf:
            return self._recvbuf_pop()

        # If it's closed, don't try to read more data
        if not self._sock:
            raise ConnectionClosed("Connection closed")

        # OK, get some data from the socket
        while True:
            try:
                data = self._sock.recv(4096)
            except socket.error:
                # We'll need to re-raise
                e_type, e_value, e_tb = sys.exc_info()

                # Make sure the socket is closed
                self.close()

                # Re-raise (Python 2 three-expression form)
                raise e_type, e_value, e_tb

            # Did the connection get closed?
            if not data:
                # There can never be anything in the buffer here
                self.close()
                raise ConnectionClosed("Connection closed")

            # Begin parsing the read-in data, one newline-delimited
            # frame at a time
            partial = self._recvbuf_partial + data
            self._recvbuf_partial = ''
            while partial:
                msg, sep, partial = partial.partition('\n')

                # If we have no sep, then it's not a complete message,
                # and the remainder is in msg
                if not sep:
                    self._recvbuf_partial = msg
                    break

                # Parse the message
                try:
                    self._recvbuf.append(json.loads(msg))
                except ValueError as exc:
                    # Error parsing the message; save the exception,
                    # which we will re-raise when it is popped
                    self._recvbuf.append(exc)

            # Make sure we have a message to return
            if self._recvbuf:
                return self._recvbuf_pop()

            # We have no complete messages; loop around and try to
            # read more data


def remote(func):
    """
    Decorator to mark a function as invoking a remote procedure call.
    When invoked in server mode, the function will be called; when
    invoked in client mode, an RPC will be initiated.
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        if self.mode == 'server':
            # In server mode, call the function directly
            return func(self, *args, **kwargs)

        # Make sure we're connected
        if not self.conn:
            self.connect()

        # Call the remote function
        self.conn.send('CALL', func.__name__, args, kwargs)

        # Receive the response
        cmd, payload = self.conn.recv()
        if cmd == 'ERR':
            # Protocol-level failure; the connection is unusable
            self.close()
            raise Exception("Catastrophic error from server: %s" %
                            payload[0])
        elif cmd == 'EXC':
            # The remote call raised; re-raise the same exception type
            exc_type = utils.find_entrypoint(None, payload[0])
            raise exc_type(payload[1])
        elif cmd != 'RES':
            self.close()
            raise Exception("Invalid command response from server: %s" % cmd)

        return payload[0]

    # Mark it a callable eligible for remote dispatch (checked by
    # _get_remote_method())
    wrapper._remote = True

    # Return the wrapped function
    return wrapper


def _create_server(host, port):
    """
    Helper function.  Creates a listening socket on the designated
    host and port.  Modeled on the socket.create_connection()
    function.

    :raises socket.error: If no address family yields a usable
                          listening socket.
    """

    exc = socket.error("getaddrinfo returns an empty list")
    for res in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM):
        af, socktype, proto, canonname, sa = res
        sock = None
        try:
            # Create the listening socket
            sock = socket.socket(af, socktype, proto)
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind(sa)
            sock.listen(1024)
            return sock

        except socket.error as exc:
            # Clean up after ourselves; try the next address
            if sock is not None:
                sock.close()

    # Couldn't create a listening socket; re-raise the last error
    raise exc


class SimpleRPC(object):
    """
    Implements simple remote procedure call.  When run in client mode
    (by calling the connect() method), designated remote calls will be
    submitted to the server (specified by the arguments to the
    constructor).  When run in server mode (by calling the listen()
    method), client connections are accepted, and client requests
    handled by calling the requested functions.

    Note: The connection is not secured through cryptographic means.
    Clients authenticate by sending an authkey to the server, which
    must match the authkey used by the server.  It is strongly
    recommended that this class only be used on the local host.
    """

    # Class of the framed-message connection wrapper
    connection_class = Connection
    # Consecutive accept() failures tolerated before listen() gives up
    max_err_thresh = 10

    def __init__(self, host, port, authkey):
        """
        Initialize a SimpleRPC object.

        :param host: The host the server will listen on.  It is
                     strongly recommended that this always be
                     "127.0.0.1", since no cryptography is used.
        :param port: The TCP port the server will listen on.
        :param authkey: An authentication key.  The server and all
                        clients must use the same authentication key.
        """

        self.host = host
        self.port = port
        self.authkey = authkey

        # One of None (unset), 'client' or 'server'
        self.mode = None
        self.conn = None

    def close(self):
        """
        Close the connection to the server.
        """

        # Close the connection
        if self.conn:
            self.conn.close()
            self.conn = None

    def ping(self):
        """
        Ping the server.  Returns the time interval, in seconds,
        required for the server to respond to the PING message.
        """

        # Make sure we're connected
        if not self.conn:
            self.connect()

        # Send the ping and wait for the response
        self.conn.send('PING', time.time())
        cmd, payload = self.conn.recv()
        recv_ts = time.time()

        # Make sure the response was a PONG
        if cmd != 'PONG':
            raise Exception("Invalid response from server")

        # Return the RTT (the PONG echoes back our send timestamp)
        return recv_ts - payload[0]

    def connect(self):
        """
        Connect to the server.  This method causes the SimpleRPC
        object to switch to client mode.  Note some methods, such as
        the ping() method, implicitly call this method.

        :raises ValueError: If the object is already in server mode.
        """

        # Make sure we're in client mode
        if self.mode and self.mode != 'client':
            raise ValueError("%s is not in client mode" %
                             self.__class__.__name__)
        self.mode = 'client'

        # If we're connected, nothing to do
        if self.conn:
            return

        # OK, attempt the connection
        fd = socket.create_connection((self.host, self.port))

        # Initialize the connection object
        self.conn = self.connection_class(fd)

        # Authenticate
        try:
            self.conn.send('AUTH', self.authkey)
            cmd, payload = self.conn.recv()
            if cmd != 'OK':
                # NOTE(review): auth rejection logs and closes but does
                # not raise; a subsequent RPC will fail on the closed
                # connection instead
                LOG.error("Failed to authenticate to %s port %s: %s" %
                          (self.host, self.port, payload[0]))
                self.close()
        except Exception:
            exc_type, exc_value, exc_tb = sys.exc_info()

            # Log the error
            if exc_type == ValueError:
                LOG.error("Received bogus response from server: %s" %
                          str(exc_value))
            elif exc_type == ConnectionClosed:
                LOG.error("%s while authenticating to server" %
                          str(exc_value))
            else:
                LOG.exception("Failed to authenticate to server")

            # Close the connection
            self.close()

            # Re-raise the exception (Python 2 three-expression form)
            raise exc_type, exc_value, exc_tb

    def listen(self):
        """
        Listen for clients.  This method causes the SimpleRPC object
        to switch to server mode.  One thread will be created for each
        client.

        :raises ValueError: If the object is already in client mode.
        """

        # Make sure we're in server mode
        if self.mode and self.mode != 'server':
            raise ValueError("%s is not in server mode" %
                             self.__class__.__name__)
        self.mode = 'server'

        # Obtain a listening socket
        serv = _create_server(self.host, self.port)

        # If we have too many errors, we want to bail out
        err_thresh = 0
        while True:
            # Accept a connection
            try:
                sock, addr = serv.accept()
            except Exception as exc:
                err_thresh += 1
                if err_thresh >= self.max_err_thresh:
                    LOG.exception("Too many errors accepting "
                                  "connections: %s" % str(exc))
                    break
                continue  # Pragma: nocover

            # Decrement error count on successful connections
            err_thresh = max(err_thresh - 1, 0)

            # Log the connection attempt
            LOG.info("Accepted connection from %s port %s" %
                     (addr[0], addr[1]))

            # And handle the connection in a green thread
            eventlet.spawn_n(self.serve, self.connection_class(sock), addr)

        # Close the listening socket
        with utils.ignore_except():
            serv.close()

    def _get_remote_method(self, funcname):
        """
        Look up the named remote method.  Broken out from serve() for
        testing purposes.

        :param funcname: The name of the function to look up.

        :raises AttributeError: If no such @remote-decorated method
                                exists; this prevents clients from
                                calling arbitrary attributes.
        """

        func = getattr(self, funcname)
        if not callable(func) or not getattr(func, '_remote', False):
            raise AttributeError("%r object has no attribute %r" %
                                 (self.__class__.__name__, funcname))

        return func

    def serve(self, conn, addr, auth=False):
        """
        Handle a single client.

        :param conn: The Connection instance.
        :param addr: The address of the client, for logging purposes.
        :param auth: A boolean specifying whether the connection
                     should be considered authenticated or not.
                     Provided for debugging.
        """

        try:
            # Handle data from the client
            while True:
                # Get the command
                try:
                    cmd, payload = conn.recv()
                except ValueError as exc:
                    # Tell the client about the error
                    conn.send('ERR', "Failed to parse command: %s" % str(exc))

                    # If they haven't successfully authenticated yet,
                    # disconnect them
                    if not auth:
                        return
                    continue  # Pragma: nocover

                # Log the command and payload, for debugging purposes
                LOG.debug("Received command %r from %s port %s; payload: %r" %
                          (cmd, addr[0], addr[1], payload))

                # Handle authentication
                if cmd == 'AUTH':
                    if auth:
                        conn.send('ERR', "Already authenticated")
                    elif payload[0] != self.authkey:
                        # Don't give them a second chance
                        conn.send('ERR', "Invalid authentication key")
                        return
                    else:
                        # Authentication successful
                        conn.send('OK')
                        auth = True

                # Handle unauthenticated connections
                elif not auth:
                    # No second chances
                    conn.send('ERR', "Not authenticated")
                    return

                # Handle aliveness test
                elif cmd == 'PING':
                    conn.send('PONG', *payload)

                # Handle a function call command
                elif cmd == 'CALL':
                    try:
                        # Get the call parameters
                        try:
                            funcname, args, kwargs = payload
                        except ValueError as exc:
                            conn.send('ERR', "Invalid payload for 'CALL' "
                                      "command: %s" % str(exc))
                            continue

                        # Look up the function
                        func = self._get_remote_method(funcname)

                        # Call the function
                        result = func(*args, **kwargs)
                    except Exception as exc:
                        # Report the exception by entrypoint-style name
                        # so the client can re-raise the same type
                        exc_name = '%s:%s' % (exc.__class__.__module__,
                                              exc.__class__.__name__)
                        conn.send('EXC', exc_name, str(exc))
                    else:
                        # Return the result
                        conn.send('RES', result)

                # Handle all other commands by returning an ERR
                else:
                    conn.send('ERR', "Unrecognized command %r" % cmd)

        except ConnectionClosed:
            # Ignore the connection closed error
            pass
        except Exception as exc:
            # Log other exceptions
            LOG.exception("Error serving client at %s port %s: %s" %
                          (addr[0], addr[1], str(exc)))

        finally:
            LOG.info("Closing connection from %s port %s" %
                     (addr[0], addr[1]))

            # Make sure the socket gets closed
            conn.close()


class ControlDaemonRPC(SimpleRPC):
    """
    A SimpleRPC subclass for use by the Turnstile control daemon.
    """

    def __init__(self, host, port, authkey, daemon):
        """
        Initialize a ControlDaemonRPC object.

        :param host: The host the server will listen on.  It is
                     strongly recommended that this always be
                     "127.0.0.1", since no cryptography is used.
        :param port: The TCP port the server will listen on.
        :param authkey: An authentication key.  The server and all
                        clients must use the same authentication key.
        :param daemon: The control daemon instance.
        """

        super(ControlDaemonRPC, self).__init__(host, port, authkey)
        self.daemon = daemon

    @remote
    def get_limits(self, limit_sum):
        """
        Retrieve a list of msgpack'd limit strings if the checksum is
        not the one given.  Raises turnstile.control.NoChangeException
        if the checksums match.
        """

        return self.daemon.limits.get_limits(limit_sum)


class RemoteLimitData(object):
    """
    Provides remote access to limit data stored in another process.
    This uses an RPC to obtain limit data maintained by the
    RemoteControlDaemon process.
    """

    def __init__(self, rpc):
        """
        Initialize RemoteLimitData.
        Stores a reference to the RPC client object.

        :param rpc: The SimpleRPC client used to reach the
                    RemoteControlDaemon process.
        """

        self.limit_rpc = rpc
        # Serializes RPC access across green threads
        self.limit_lock = eventlet.semaphore.Semaphore()

    def set_limits(self, limits):
        """
        Remote limit data is treated as read-only (with external
        update).

        :raises ValueError: Always; remote limit data cannot be set.
        """

        raise ValueError("Cannot set remote limit data")

    def get_limits(self, limit_sum=None):
        """
        Gets the current limit data if it is different from the data
        indicated by limit_sum.  The db argument is used for hydrating
        the limit objects.  Raises a NoChangeException if the
        limit_sum represents no change, otherwise returns a tuple
        consisting of the current limit_sum and a list of Limit
        objects.
        """

        with self.limit_lock:
            # Grab the checksum and limit list
            try:
                return self.limit_rpc.get_limits(limit_sum)
            except control.NoChangeException:
                # Expected possibility
                raise
            except Exception:
                # Something happened; maybe the server isn't running.
                # Pretend that there's no change...
                raise control.NoChangeException()


class RemoteControlDaemon(control.ControlDaemon):
    """
    A daemon process which listens for control messages and can reload
    the limit configuration from the database.  Based on the
    ControlDaemon, but starts an RPC server to enable access to the
    limit data from multiple processes.
    """

    def __init__(self, middleware, conf):
        """
        Initialize the RemoteControlDaemon.

        :param middleware: The turnstile middleware instance.
        :param conf: The turnstile configuration.

        :raises ValueError: If any of the required control.remote.*
                            configuration values is missing or
                            invalid.
        """

        # Grab required configuration values
        required = {
            'remote.host': lambda x: x,
            'remote.port': int,
            'remote.authkey': lambda x: x,
        }
        values = {}
        # NOTE: deleting from `required` while iterating items() is
        # safe only on Python 2, where items() returns a list
        for conf_key, xform in required.items():
            try:
                values[conf_key[len('remote.'):]] = \
                    xform(conf['control'][conf_key])
            except KeyError:
                warnings.warn("Missing value for configuration key "
                              "'control.%s'" % conf_key)
            except ValueError:
                warnings.warn("Invalid value for configuration key "
                              "'control.%s'" % conf_key)
            else:
                del required[conf_key]

        # Error out if we're missing something critical
        if required:
            raise ValueError("Missing required configuration for "
                             "RemoteControlDaemon.  Missing or invalid "
                             "configuration keys: %s" %
                             ', '.join(['control.%s' % k
                                        for k in sorted(required.keys())]))

        super(RemoteControlDaemon, self).__init__(middleware, conf)

        # Set up the RPC object
        self.remote = ControlDaemonRPC(daemon=self, **values)
        self.remote_limits = None

    def get_limits(self):
        """
        Retrieve the LimitData object the middleware will use for
        getting the limits.  This implementation returns a
        RemoteLimitData instance that can access the LimitData stored
        in the RemoteControlDaemon process.
        """

        # Set one up if we don't already have it
        if not self.remote_limits:
            self.remote_limits = RemoteLimitData(self.remote)
        return self.remote_limits

    def start(self):
        """
        Starts the RemoteControlDaemon.  Deliberately a no-op in the
        middleware process; the daemon runs via serve() in its own
        process.
        """

        # Don't connect the client yet, to avoid problems if we fork
        pass  # Pragma: nocover

    def serve(self):
        """
        Starts the RemoteControlDaemon process.  Forks a thread for
        listening to the Redis database, then initializes and starts
        the RPC server.
        """

        # Start the listening thread and load the limits
        super(RemoteControlDaemon, self).start()

        # Start the RPC server in this thread
        self.remote.listen()

    @property
    def db(self):
        """
        Obtain a handle for the database.  This allows lazy
        initialization of the database handle.
        """

        # Initialize the database handle; we're running in a separate
        # process, so we need to get_database() ourself
        # NOTE(review): this reads self.config while the middleware's
        # db property reads self.conf -- confirm ControlDaemon defines
        # `config`
        if not self._db:
            self._db = self.config.get_database()

        return self._db

# ---------------------------------------------------------------------------
# turnstile/utils.py
# ---------------------------------------------------------------------------
# Copyright 2012 Rackspace
# All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License.  You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
#    License for the specific language governing permissions and limitations
#    under the License.

import sys

import pkg_resources


def find_entrypoint(group, name, compat=True, required=False):
    """
    Finds the first available entrypoint with the given name in the
    given group.

    :param group: The entrypoint group the name can be found in.  If
                  None, the name is not presumed to be an entrypoint.
    :param name: The name of the entrypoint.
    :param compat: If True, and if the name parameter contains a ':',
                   the name will be interpreted as a module name and
                   an object name, separated by a colon.  This is
                   provided for compatibility.
    :param required: If True, and no corresponding entrypoint can be
                     found, an ImportError will be raised.  If False
                     (the default), None will be returned instead.

    :returns: The entrypoint object, or None if one could not be
              loaded.
    """

    if group is None or (compat and ':' in name):
        try:
            # Treat the name as "module:object" and load it directly
            return pkg_resources.EntryPoint.parse("x=" + name).load(False)
        except (ImportError, pkg_resources.UnknownExtra) as exc:
            pass
    else:
        for ep in pkg_resources.iter_entry_points(group, name):
            try:
                # Load and return the object
                return ep.load()
            except (ImportError, pkg_resources.UnknownExtra):
                # Couldn't load it; try the next one
                continue

    # Raise an ImportError if requested
    if required:
        raise ImportError("Cannot import %r entrypoint %r" % (group, name))

    # Couldn't find one...
    return None


class ignore_except(object):
    """Context manager to ignore all exceptions."""

    def __enter__(self):
        """Entry does nothing."""

        pass

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Return True to mark the exception as handled."""

        return True