├── LICENSE
├── README.md
├── flask_session_plus
│   ├── __init__.py
│   ├── backends.py
│   ├── core.py
│   └── session.py
├── release.py
├── requirements-dev.txt
├── requirements-test.txt
├── setup.py
└── test
    ├── __init__.py
    ├── auth.py
    ├── flask_app.py
    ├── models.py
    └── templates
        └── test_csrf.html

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Alejandro Casanovas

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Flask Multiple Sessions Interface

#### Combine multiple sessions with different backends

With Flask Session Plus you can use several different backends and choose which session variables are saved in which backend.

##### Python version:
> It works on Python >= 3.4.
> It should also work on Python 2.7, but this is not tested yet. If something does not work properly, please open a bug report.
>

##### Install it with:

`pip install flask-session-plus`

For Flask Session Plus to work, all you have to do is define all your sessions in a single configuration variable called `SESSION_CONFIG` and initialize the extension.


##### Session Configuration Example:

```python
# example using the Google Firestore backend
from google.cloud import firestore

SESSION_CONFIG = [
    # The first session stores only the csrf_token, in its own cookie.
    {
        'cookie_name': 'csrf',
        'session_type': 'secure_cookie',
        'session_fields': ['csrf_token'],
    },
    # The second session stores the logged-in user in the Firestore 'sessions' collection.
    {
        'cookie_name': 'session',
        'session_type': 'firestore',
        'session_fields': ['user_id', 'user_data'],
        'client': firestore.Client(),
        'collection': 'sessions',
    },
    # The third session stores any other values set on the Flask session in its own secure cookie.
    {
        'cookie_name': 'data',
        'session_type': 'secure_cookie',
        'session_fields': 'auto',
    },
    # ... as many sessions as you want
]
```

> Caution: `session_fields` entries with the same name in different sessions will collide. Only reuse a field name across sessions if it has the same meaning; otherwise, you must use different field names.

The above configuration defines three session interfaces:

- The first one is a secure cookie named 'csrf' that stores the 'csrf_token' field.
- The second one is a FirestoreSessionInterface that sets a cookie named 'session' containing only the session id. The 'user_id' and 'user_data' fields are stored in the Google Cloud Firestore backend.
- The third one stores any other variables set in the session in another secure cookie.

After configuring, just register it as an extension:

```python
from flask import Flask
from flask_session_plus import Session

app = Flask(__name__)

Session(app)
```

or

```python
from flask import Flask
from flask_session_plus import Session

app = Flask(__name__)

session = Session()

session.init_app(app)
```


---

### Currently available backends:

- Secure Cookies Sessions (session_type key: `'secure_cookie'`)
- Google Firestore Sessions (session_type key: `'firestore'`)
- Redis Sessions (session_type key: `'redis'`)
- MongoDB Sessions (session_type key: `'mongodb'`)
- Memcache Sessions (session_type key: `'memcache'`)


More backend session interfaces can be created by subclassing `BackendSessionInterface` and overriding the following methods:

1. `__init__`
1. `open_session`
1. `save_session`

### All possible values for Session configuration:


- Common properties for all backends:

Property name | Required | Default | Description
--- | :---: | --- | ---
`cookie_name` | `True` | | The name of the cookie to use. It also serves as the key that identifies each session.
`session_type` | `False` | `'secure_cookie'` | The session backend to use.
`session_fields` | `False` | `[]` | The fields that are owned by this session. An empty list means 'include all fields'. It can be: 1) a list of fields to include, 2) a dict with the key 'include' or 'exclude' mapping to a list of fields to include or exclude, or 3) the string 'auto' to include every field not owned by another session.

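
The `session_fields` forms listed in the table above can be illustrated with a small, hypothetical helper. This is only a sketch under stated assumptions, not the extension's own routing code (which lives in `session.py`, not shown here): `owns_field` and its `claimed_elsewhere` argument are invented for illustration, a dict holding both 'include' and 'exclude' is assumed to honor 'include', and 'auto' is assumed to take every field that no other session lists explicitly.

```python
def owns_field(session_fields, key, claimed_elsewhere=frozenset()):
    """Hypothetical helper: decide whether one session config owns `key`.

    session_fields may be a list of field names, a dict with an 'include'
    or 'exclude' list, or the string 'auto'. `claimed_elsewhere` is the set
    of fields explicitly listed by the other configured sessions (assumption).
    """
    if session_fields == 'auto':
        # 'auto': take every field that no other session claims.
        return key not in claimed_elsewhere
    if isinstance(session_fields, dict):
        # Assumption: 'include' wins if both keys are present.
        if 'include' in session_fields:
            return key in session_fields['include']
        if 'exclude' in session_fields:
            return key not in session_fields['exclude']
        return True
    # A plain list; an empty list means "include all fields".
    return not session_fields or key in session_fields


csrf_fields = ['csrf_token']
user_fields = {'include': ['user_id', 'user_data']}
claimed = set(csrf_fields) | set(user_fields['include'])

print(owns_field(csrf_fields, 'csrf_token'))     # True
print(owns_field(user_fields, 'csrf_token'))     # False
print(owns_field('auto', 'theme', claimed))      # True
print(owns_field('auto', 'user_id', claimed))    # False
print(owns_field([], 'anything'))                # True
```

Keeping the decision in one pure function makes it easy to check that every key ends up in exactly one session, which is what the caution about colliding field names is warning about.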

- Properties for SecureCookie (available for all backends):

Property name | Required | Default | Description
--- | :---: | --- | ---
`cookie_domain` | `False` | | The domain for the session cookie. If this is not set, the cookie will be valid for all subdomains of `SERVER_NAME`.
`cookie_path` | `False` | | The path for the session cookie. If this is not set, the cookie will be valid for all of `APPLICATION_ROOT`, or for `'/'` if that is not set either.
`cookie_httponly` | `False` | `True` | Whether to set the `HttpOnly` flag, which keeps the cookie out of reach of client-side JavaScript.
`cookie_secure` | `False` | `False` | Whether to send this cookie over HTTPS only.
`cookie_max_age` | `False` | `None` | The cookie expiration time in seconds. `None` means the cookie will expire when the browser closes.
`cookie_samesite` | `False` | `'Lax'` | The cookie `SameSite` attribute.

- Properties available for all backends other than SecureCookie:

Property name | Required | Default | Description
--- | :---: | --- | ---
`session_lifetime` | `False` | `timedelta(days=1)` | How long a session stays valid. Not used by the SecureCookie backend.
`session_permanent_lifetime` | `False` | `timedelta(days=31)` | How long a session stays valid when it is marked as permanent. Not used by the SecureCookie backend.
`key_prefix` | `False` | `'session'` | The prefix prepended to the session id to build the key (store_id) under which the session is stored.
`use_signer` | `False` | `False` | Whether to sign the session id cookie or not.

- Properties available for the Google Firestore backend:

Property name | Required | Default | Description
--- | :---: | --- | ---
`client` | `True` | | The engine. An instance of `firestore.Client`.
`collection` | `True` | | The Firestore collection you want to use to store sessions.

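
For backends other than SecureCookie, the stored session and its cookie get separate expiration times. The sketch below restates the arithmetic of `get_expiration_time` and `get_cookie_expiration_time` in `backends.py` as standalone functions, so the interplay of `session_lifetime`, `session_permanent_lifetime` and `cookie_max_age` is easy to see; the function names and the explicit `now` argument are hypothetical, introduced only for this illustration.

```python
from datetime import datetime, timedelta


def store_expiration(now, permanent, session_lifetime=timedelta(days=1),
                     session_permanent_lifetime=timedelta(days=31)):
    """When the server-side record expires (mirrors get_expiration_time)."""
    return now + (session_permanent_lifetime if permanent else session_lifetime)


def cookie_expiration(now, permanent, cookie_max_age=None,
                      session_permanent_lifetime=timedelta(days=31)):
    """When the browser cookie expires (mirrors get_cookie_expiration_time).

    None means a browser-session cookie that is dropped when the browser closes.
    """
    if permanent:
        return now + session_permanent_lifetime
    if cookie_max_age is not None:
        return now + timedelta(seconds=cookie_max_age)
    return None


now = datetime(2018, 1, 1)
print(store_expiration(now, permanent=False))                        # 2018-01-02 00:00:00
print(cookie_expiration(now, permanent=False))                       # None
print(cookie_expiration(now, permanent=False, cookie_max_age=3600))  # 2018-01-01 01:00:00
print(cookie_expiration(now, permanent=True))                        # 2018-02-01 00:00:00
```

With the defaults, a non-permanent session keeps its stored record for one day while its cookie, lacking `cookie_max_age`, is dropped when the browser closes.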

- Properties available for the Redis backend:

Property name | Required | Default | Description
--- | :---: | --- | ---
`client` | `True` | | The engine. An instance of `redis.Redis`.

- Properties available for the MongoDB backend:

Property name | Required | Default | Description
--- | :---: | --- | ---
`client` | `True` | | The engine. An instance of `pymongo.MongoClient`.
`db` | `True` | | The database you want to use.
`collection` | `True` | | The MongoDB collection you want to use to store sessions.

- Properties available for the Memcache backend:

Property name | Required | Default | Description
--- | :---: | --- | ---
`client` | `True` | | The engine. An instance of `memcache.Client`.

--------------------------------------------------------------------------------
/flask_session_plus/__init__.py:
--------------------------------------------------------------------------------
from flask_session_plus.session import Session

--------------------------------------------------------------------------------
/flask_session_plus/backends.py:
--------------------------------------------------------------------------------
import logging
import sys
import time
from datetime import datetime, timedelta
from uuid import uuid4
import hashlib

from flask.helpers import total_seconds
from flask.json.tag import TaggedJSONSerializer
from pytz import utc
from flask.sessions import SessionInterface as FlaskSessionInterface
from itsdangerous import BadSignature, want_bytes, Signer, URLSafeTimedSerializer
from flask_session_plus.core import MultiSession

log = logging.getLogger(__name__)

PY2 = sys.version_info[0] == 2
if not PY2:
    text_type = str
else:
    text_type = unicode


class BaseSessionInterface(FlaskSessionInterface):
    """ Base Session Interface """

    def __init__(self, cookie_name, cookie_max_age=None, cookie_domain=None,
                 cookie_path=None, cookie_httponly=True, cookie_secure=False,
                 cookie_samesite=None, session_lifetime=None,
                 session_permanent_lifetime=None,
                 refresh_on_request=True, **kwargs):
        self.cookie_name = cookie_name
        self.cookie_max_age = cookie_max_age
        self.cookie_domain = cookie_domain
        self.cookie_path = cookie_path
        self.cookie_httponly = cookie_httponly
        self.cookie_secure = cookie_secure
        self.cookie_samesite = cookie_samesite
        self.session_lifetime = session_lifetime or timedelta(days=1)
        self.session_permanent_lifetime = session_permanent_lifetime or timedelta(days=31)
        self.refresh_on_request = refresh_on_request

    def get_expiration_time(self, app, session):
        """Returns the (naive, UTC) datetime at which the stored session
        expires: now + ``session_permanent_lifetime`` if this session is
        permanent, otherwise now + ``session_lifetime``. It never returns
        ``None``.
        """
        if session.is_permanent(self.cookie_name):
            return datetime.utcnow() + self.session_permanent_lifetime
        else:
            return datetime.utcnow() + self.session_lifetime

    def get_cookie_expiration_time(self, app, session):
        """Returns the (naive, UTC) expiration datetime for the session
        cookie: now + ``session_permanent_lifetime`` if this session is
        permanent, now + ``cookie_max_age`` seconds if that is set, or
        ``None`` for a cookie that lasts until the browser is closed.
        """
        if session.is_permanent(self.cookie_name):
            return datetime.utcnow() + self.session_permanent_lifetime
        else:
            if self.cookie_max_age is not None:
                return datetime.utcnow() + timedelta(seconds=self.cookie_max_age)
            else:
                return None

    def should_set_cookie(self, app, session):
        """Used by session backends to determine if a ``Set-Cookie`` header
        should be set for this session cookie for this response. The cookie
        is set if the session has been modified, if ``refresh_on_request``
        is enabled for this interface (the default), or if the app's
        ``SESSION_REFRESH_EACH_REQUEST`` config is true.

        This check is usually skipped if the session was deleted.
        """
        return session.modified or self.refresh_on_request or app.config['SESSION_REFRESH_EACH_REQUEST']

    def open_session(self, app, request):
        raise NotImplementedError

    def save_session(self, app, session, response):
        raise NotImplementedError


session_json_serializer = TaggedJSONSerializer()


class SecureCookieSessionInterface(BaseSessionInterface):
    """ A Secure Cookie Session Interface that works with Flask-Multi-Session """

    #: the salt that should be applied on top of the secret key for the
    #: signing of cookie based sessions.
    salt = 'cookie-session'
    #: the hash function to use for the signature. Flask's default is sha1;
    #: this interface uses SHA3-256 (hashlib.sha3_256 requires Python 3.6+).
    digest_method = staticmethod(hashlib.sha3_256)
    #: the name of the itsdangerous supported key derivation. The default
    #: is hmac.
    key_derivation = 'hmac'
    #: A python serializer for the payload. The default is a compact
    #: JSON derived serializer with support for some extra Python types
    #: such as datetime objects or tuples.
    serializer = session_json_serializer
    session_class = MultiSession

    def get_signing_serializer(self, app):
        if not app.secret_key:
            return None
        signer_kwargs = dict(
            key_derivation=self.key_derivation,
            digest_method=self.digest_method
        )
        return URLSafeTimedSerializer(app.secret_key, salt=self.salt,
                                      serializer=self.serializer,
                                      signer_kwargs=signer_kwargs)

    def open_session(self, app, request):
        s = self.get_signing_serializer(app)
        if s is None:
            return None
        val = request.cookies.get(self.cookie_name)
        if not val:
            return self.session_class()
        max_age = self.cookie_max_age or None

        try:
            data = s.loads(val, max_age=max_age)
            return self.session_class(data)
        except BadSignature:
            return self.session_class()

    def save_session(self, app, session, response):
        if self.cookie_domain is not None:
            domain = self.cookie_domain if self.cookie_domain else self.get_cookie_domain(app)
        else:
            domain = self.get_cookie_domain(app)

        path = self.cookie_path or self.get_cookie_path(app)

        # If the session is modified to be empty, remove the cookie.
        # If the session is empty, return without setting the cookie.
        if not session:
            if session.modified:
                response.delete_cookie(
                    self.cookie_name,
                    domain=domain,
                    path=path
                )

            return

        # Add a "Vary: Cookie" header if the session was accessed at all.
        if session.accessed:
            response.vary.add('Cookie')

        if not self.should_set_cookie(app, session):
            return

        httponly = self.cookie_httponly or self.get_cookie_httponly(app)
        secure = self.cookie_secure or self.get_cookie_secure(app)
        samesite = self.cookie_samesite or self.get_cookie_samesite(app)
        expires = self.get_cookie_expiration_time(app, session)
        val = self.get_signing_serializer(app).dumps(dict(session))
        response.set_cookie(
            self.cookie_name,
            val,
            expires=expires,
            httponly=httponly,
            domain=domain,
            path=path,
            secure=secure,
            samesite=samesite
        )


class BackendSessionInterface(BaseSessionInterface):
    """ A common Session Interface for all backend Interfaces """

    session_class = MultiSession

    def _generate_sid(self):
        return str(uuid4())

    def _get_signer(self, app):
        if not app.secret_key:
            return None
        return Signer(app.secret_key, salt='flask-session',
                      key_derivation='hmac')

    def open_session(self, app, request):
        raise NotImplementedError

    def save_session(self, app, session, response):
        raise NotImplementedError


class FirestoreSessionInterface(BackendSessionInterface):
    """ A Session interface that uses Google Cloud Firestore as backend. """

    def __init__(self, client, collection, key_prefix='session', use_signer=False, **kwargs):
        """
        :param client: A 'firestore.Client' instance.
        :param collection: The collection you want to use.
        :param key_prefix: A prefix that is added to all session store keys.
        :param use_signer: Whether to sign the session id cookie or not.
        :param kwargs: extra params to the base class
        """
        super(FirestoreSessionInterface, self).__init__(**kwargs)
        if client is None:
            from google.cloud import firestore
            client = firestore.Client()
        self.client = client
        self.store = client.collection(collection)
        self.key_prefix = key_prefix
        self.use_signer = use_signer

    def _delete_session_from_store(self, store_id):
        """ Deletes the session from the store """
        try:
            self.store.document(store_id).delete()
        except Exception as e:
            log.error('Error while deleting expired session (session id: {}): {}'.format(store_id, str(e)))
            return False
        return True

    def open_session(self, app, request):
        sid = request.cookies.get(self.cookie_name)
        if not sid:
            sid = self._generate_sid()
            return self.session_class(sid={self.cookie_name: sid})
        if self.use_signer:
            signer = self._get_signer(app)
            if signer is None:
                return None
            try:
                sid_as_bytes = signer.unsign(sid)
                sid = sid_as_bytes.decode()
            except BadSignature:
                sid = self._generate_sid()
                return self.session_class(sid={self.cookie_name: sid})

        store_id = self.key_prefix + sid
        try:
            document = self.store.document(store_id).get()
            document = document.to_dict() if document.exists else None
        except Exception as e:
            log.error('Error while retrieving session from db (session id: {}): {}'.format(store_id, str(e)))
            # treat as session expired.
            document = None
        if document and document.pop('_expiration') <= datetime.utcnow().replace(tzinfo=utc):
            # Delete expired session
            self._delete_session_from_store(store_id)
            document = None
        if document is not None:
            try:
                data = document
                permanent = self.cookie_name if data.pop('_permanent', None) else None
                return self.session_class(data, sid={self.cookie_name: sid}, permanent=permanent)
            except:
                return self.session_class(sid={self.cookie_name: sid})
        return self.session_class(sid={self.cookie_name: sid})

    def save_session(self, app, session, response):
        if self.cookie_domain is not None:
            domain = self.cookie_domain if self.cookie_domain else self.get_cookie_domain(app)
        else:
            domain = self.get_cookie_domain(app)
        path = self.cookie_path or self.get_cookie_path(app)

        if not session:
            if session.modified:
                self._delete_session_from_store(self.key_prefix + session.get_sid(self.cookie_name))
                response.delete_cookie(self.cookie_name, domain=domain, path=path)
            return

        httponly = self.cookie_httponly or self.get_cookie_httponly(app)
        secure = self.cookie_secure or self.get_cookie_secure(app)
        expires = self.get_expiration_time(app, session)

        if session.modified:
            # The session was modified
            store_id = self.key_prefix + session.get_sid(self.cookie_name)
            val = {'_expiration': expires, '_permanent': session.is_permanent(self.cookie_name)}
            val.update(dict(session))
            try:
                self.store.document(store_id).set(val)
            except Exception as e:
                log.error('Error while updating session (session id: {}): {}'.format(store_id, str(e)))

        if self.use_signer:
            session_id = self._get_signer(app).sign(want_bytes(session.get_sid(self.cookie_name)))
        else:
            session_id = session.get_sid(self.cookie_name)
        cookie_expires = self.get_cookie_expiration_time(app, session)
        response.set_cookie(self.cookie_name, session_id,
                            expires=cookie_expires, httponly=httponly,
                            domain=domain, path=path, secure=secure)


class RedisSessionInterface(BackendSessionInterface):
    """ A Session interface that uses Redis as backend. """

    serializer = session_json_serializer

    def __init__(self, client, key_prefix='session', use_signer=False, **kwargs):
        """
        :param client: A 'redis.Redis' instance.
        :param key_prefix: A prefix that is added to all session store keys.
        :param use_signer: Whether to sign the session id cookie or not.
        :param kwargs: extra params to the base class
        """
        super(RedisSessionInterface, self).__init__(**kwargs)
        if client is None:
            from redis import Redis
            client = Redis()
        self.client = client
        self.key_prefix = key_prefix
        self.use_signer = use_signer

    def _delete_session_from_store(self, store_id):
        """ Deletes the session from the store """
        try:
            self.client.delete(store_id)
        except Exception as e:
            log.error('Error while deleting expired session (session id: {}): {}'.format(store_id, str(e)))
            return False
        return True

    def open_session(self, app, request):
        sid = request.cookies.get(self.cookie_name)
        if not sid:
            sid = self._generate_sid()
            return self.session_class(sid={self.cookie_name: sid})
        if self.use_signer:
            signer = self._get_signer(app)
            if signer is None:
                return None
            try:
                sid_as_bytes = signer.unsign(sid)
                sid = sid_as_bytes.decode()
            except BadSignature:
                sid = self._generate_sid()
                return self.session_class(sid={self.cookie_name: sid})

        if not PY2 and not isinstance(sid, text_type):
            sid = sid.decode('utf-8', 'strict')

        store_id = self.key_prefix + sid
        try:
            val = self.client.get(store_id)
        except Exception as e:
            log.error('Error while retrieving session from db (session id: {}): {}'.format(store_id, str(e)))
            # treat as session expired.
            val = None

        if val is not None:
            data = self.serializer.loads(val)
        else:
            data = None

        if data is not None:
            try:
                permanent = self.cookie_name if data.pop('_permanent', None) else None
                return self.session_class(data, sid={self.cookie_name: sid}, permanent=permanent)
            except:
                return self.session_class(sid={self.cookie_name: sid})
        return self.session_class(sid={self.cookie_name: sid})

    def save_session(self, app, session, response):
        if self.cookie_domain is not None:
            domain = self.cookie_domain if self.cookie_domain else self.get_cookie_domain(app)
        else:
            domain = self.get_cookie_domain(app)
        path = self.cookie_path or self.get_cookie_path(app)

        if not session:
            if session.modified:
                self._delete_session_from_store(self.key_prefix + session.get_sid(self.cookie_name))
                response.delete_cookie(self.cookie_name, domain=domain, path=path)
            return

        httponly = self.cookie_httponly or self.get_cookie_httponly(app)
        secure = self.cookie_secure or self.get_cookie_secure(app)
        expires = self.get_expiration_time(app, session)

        if session.modified:
            # The session was modified
            store_id = self.key_prefix + session.get_sid(self.cookie_name)
            data = {'_permanent': session.is_permanent(self.cookie_name)}
            data.update(dict(session))
            val = self.serializer.dumps(data)
            try:
                # `expires` is an absolute datetime, but setex() expects a TTL in seconds.
                ttl = int((expires - datetime.utcnow()).total_seconds())
                self.client.setex(name=store_id, value=val, time=ttl)
            except Exception as e:
                log.error('Error while updating session (session id: {}): {}'.format(store_id, str(e)))

        if self.use_signer:
            session_id = self._get_signer(app).sign(want_bytes(session.get_sid(self.cookie_name)))
        else:
            session_id = session.get_sid(self.cookie_name)
        cookie_expires = self.get_cookie_expiration_time(app, session)
        response.set_cookie(self.cookie_name, session_id,
                            expires=cookie_expires, httponly=httponly,
                            domain=domain, path=path, secure=secure)


class MongoDBSessionInterface(BackendSessionInterface):
    """ A Session interface that uses MongoDB as backend. """

    serializer = session_json_serializer

    def __init__(self, client, db, collection, key_prefix='session', use_signer=False, **kwargs):
        """
        :param client: A 'pymongo.MongoClient' instance.
        :param db: The database you want to use.
        :param collection: The collection you want to use.
        :param key_prefix: A prefix that is added to all session store keys.
        :param use_signer: Whether to sign the session id cookie or not.
        :param kwargs: extra params to the base class
        """
        super(MongoDBSessionInterface, self).__init__(**kwargs)
        if client is None:
            from pymongo import MongoClient
            client = MongoClient()
        self.client = client
        self.db = db
        self.store = client[db][collection]
        self.key_prefix = key_prefix
        self.use_signer = use_signer

    def _delete_session_from_store(self, store_id):
        """ Deletes the session from the store """
        try:
            # delete_one() replaces Collection.remove(), which PyMongo 4 dropped.
            self.store.delete_one({'id': store_id})
        except Exception as e:
            log.error('Error while deleting expired session (session id: {}): {}'.format(store_id, str(e)))
            return False
        return True

    def open_session(self, app, request):
        sid = request.cookies.get(self.cookie_name)
        if not sid:
            sid = self._generate_sid()
            return self.session_class(sid={self.cookie_name: sid})
        if self.use_signer:
            signer = self._get_signer(app)
            if signer is None:
                return None
            try:
                sid_as_bytes = signer.unsign(sid)
                sid = sid_as_bytes.decode()
            except BadSignature:
                sid = self._generate_sid()
                return self.session_class(sid={self.cookie_name: sid})

        store_id = self.key_prefix + sid
        try:
            document = self.store.find_one({'id': store_id})
        except Exception as e:
            log.error('Error while retrieving session from db (session id: {}): {}'.format(store_id, str(e)))
            # treat as session expired.
            document = None
        if document:
            expiration = document.pop('_expiration')
            # pymongo returns naive (UTC) datetimes unless the client was created with
            # tz_aware=True, so normalize before comparing with an aware "now".
            if expiration.tzinfo is None:
                expiration = expiration.replace(tzinfo=utc)
            if expiration <= datetime.utcnow().replace(tzinfo=utc):
                # Delete expired session
                self._delete_session_from_store(store_id)
                document = None
        if document is not None:
            try:
                val = document['val']
                data = self.serializer.loads(want_bytes(val))
                permanent = self.cookie_name if document.pop('_permanent', None) else None
                return self.session_class(data, sid={self.cookie_name: sid}, permanent=permanent)
            except:
                return self.session_class(sid={self.cookie_name: sid})
        return self.session_class(sid={self.cookie_name: sid})

    def save_session(self, app, session, response):
        if self.cookie_domain is not None:
            domain = self.cookie_domain if self.cookie_domain else self.get_cookie_domain(app)
        else:
            domain = self.get_cookie_domain(app)
        path = self.cookie_path or self.get_cookie_path(app)

        if not session:
            if session.modified:
                self._delete_session_from_store(self.key_prefix + session.get_sid(self.cookie_name))
                response.delete_cookie(self.cookie_name, domain=domain, path=path)
            return

        httponly = self.cookie_httponly or self.get_cookie_httponly(app)
        secure = self.cookie_secure or self.get_cookie_secure(app)
        expires = self.get_expiration_time(app, session)

        if session.modified:
            # The session was modified
            store_id = self.key_prefix + session.get_sid(self.cookie_name)
            val = self.serializer.dumps(dict(session))
            try:
                # replace_one(..., upsert=True) replaces Collection.update(), which PyMongo 4 dropped.
                self.store.replace_one({'id': store_id},
                                       {'id': store_id,
                                        'val': val,
                                        '_expiration': expires,
                                        '_permanent': session.is_permanent(self.cookie_name)},
                                       upsert=True)
            except Exception as e:
                log.error('Error while updating session (session id: {}): {}'.format(store_id, str(e)))

        if self.use_signer:
            session_id = self._get_signer(app).sign(want_bytes(session.get_sid(self.cookie_name)))
        else:
            session_id = session.get_sid(self.cookie_name)
        cookie_expires = self.get_cookie_expiration_time(app, session)
        response.set_cookie(self.cookie_name, session_id,
                            expires=cookie_expires, httponly=httponly,
                            domain=domain, path=path, secure=secure)


class MemcachedSessionInterface(BackendSessionInterface):
    """ A Session interface that uses Memcached as backend. """

    serializer = session_json_serializer

    def __init__(self, client, key_prefix='session', use_signer=False, **kwargs):
        """
        :param client: A 'memcache.Client' instance.
        :param key_prefix: A prefix that is added to all session store keys.
        :param use_signer: Whether to sign the session id cookie or not.
        :param kwargs: extra params to the base class

        """
        super(MemcachedSessionInterface, self).__init__(**kwargs)
        if client is None:
            raise ValueError('Must provide a valid memcache Client instance')
        self.client = client
        self.key_prefix = key_prefix
        self.use_signer = use_signer

    @staticmethod
    def _get_memcache_timeout(timeout):
        """
        from Flask-session:
        Memcached deals with long (> 30 days) timeouts in a special
        way. Call this function to obtain a safe value for your timeout.
        """
        if timeout > 2592000:  # 60*60*24*30, 30 days
            # See http://code.google.com/p/memcached/wiki/FAQ
            # "You can set expire times up to 30 days in the future. After that
            # memcached interprets it as a date, and will expire the item after
            # said date. This is a simple (but obscure) mechanic."
            #
            # This means that we have to switch to absolute timestamps.
            timeout += int(time.time())
        return timeout

    def open_session(self, app, request):
        sid = request.cookies.get(self.cookie_name)
        if not sid:
            sid = self._generate_sid()
            return self.session_class(sid={self.cookie_name: sid})
        if self.use_signer:
            signer = self._get_signer(app)
            if signer is None:
                return None
            try:
                sid_as_bytes = signer.unsign(sid)
                sid = sid_as_bytes.decode()
            except BadSignature:
                sid = self._generate_sid()
                return self.session_class(sid={self.cookie_name: sid})

        store_id = self.key_prefix + sid
        if PY2 and isinstance(store_id, unicode):
            store_id = store_id.encode('utf-8')
        try:
            val = self.client.get(store_id)
        except Exception as e:
            log.error('Error while retrieving session from db (session id: {}): {}'.format(store_id, str(e)))
            # treat as session expired.
            val = None
        if val is not None:
            try:
                if not PY2:
                    val = want_bytes(val)
                data = self.serializer.loads(val)
                permanent = self.cookie_name if data.pop('_permanent', None) else None
                return self.session_class(data, sid={self.cookie_name: sid}, permanent=permanent)
            except:
                return self.session_class(sid={self.cookie_name: sid})
        return self.session_class(sid={self.cookie_name: sid})

    def save_session(self, app, session, response):
        if self.cookie_domain is not None:
            domain = self.cookie_domain if self.cookie_domain else self.get_cookie_domain(app)
        else:
            domain = self.get_cookie_domain(app)
        path = self.cookie_path or self.get_cookie_path(app)

        store_id = self.key_prefix + session.get_sid(self.cookie_name)
        if PY2 and isinstance(store_id, unicode):
            store_id = store_id.encode('utf-8')
        if not session:
            if session.modified:
                try:
                    self.client.delete(store_id)
                except Exception as e:
                    log.error('Error while deleting session (session id: {}): {}'.format(store_id, str(e)))
                response.delete_cookie(self.cookie_name, domain=domain, path=path)
            return

        httponly = self.cookie_httponly or self.get_cookie_httponly(app)
        secure = self.cookie_secure or self.get_cookie_secure(app)
        expires = self.get_expiration_time(app, session)

        if session.modified:
            # The session was modified
            data = {'_permanent': session.is_permanent(self.cookie_name)}
            data.update(dict(session))
            val = self.serializer.dumps(data)

            try:
                # `expires` is an absolute datetime, so turn it into a timeout in seconds first.
                timeout = int((expires - datetime.utcnow()).total_seconds())
                self.client.set(store_id, val, self._get_memcache_timeout(timeout))
            except Exception as e:
                log.error('Error while updating session (session id: {}): {}'.format(store_id, str(e)))

        if self.use_signer:
            session_id = self._get_signer(app).sign(want_bytes(session.get_sid(self.cookie_name)))
        else:
session_id = session.get_sid(self.cookie_name) 630 | cookie_expires = self.get_cookie_expiration_time(app, session) 631 | response.set_cookie(self.cookie_name, session_id, 632 | expires=cookie_expires, httponly=httponly, 633 | domain=domain, path=path, secure=secure) 634 | -------------------------------------------------------------------------------- /flask_session_plus/core.py: -------------------------------------------------------------------------------- 1 | from flask.sessions import SessionMixin 2 | _missing = object() # sentinel for pop(); avoids importing the private werkzeug._internal module 3 | 4 | 5 | class UpdateDictMixin(object): 6 | """ Makes dicts call `self.on_update` on modifications. 7 | self.on_update receives the dict instance and the key that changed 8 | """ 9 | 10 | on_update = None 11 | 12 | def calls_update(name): 13 | def oncall(self, *args, **kw): 14 | if name == 'update': # normalise so keyword arguments and one-shot iterables both work 15 | args, kw = (dict(*args, **kw),), {} 16 | # snapshot the keys before clear() removes them 17 | keys = list(super(UpdateDictMixin, self).keys()) if name == 'clear' else None 18 | rv = getattr(super(UpdateDictMixin, self), name)(*args, **kw) 19 | if self.on_update is not None: 20 | if name in ('clear', 'update'): 21 | for key in (keys if name == 'clear' else args[0]): 22 | self.on_update(self, key) 23 | elif name == 'popitem': 24 | self.on_update(self, rv[0]) 25 | else: 26 | self.on_update(self, args[0]) 27 | return rv 28 | 29 | oncall.__name__ = name 30 | return oncall 31 | 32 | def setdefault(self, key, default=None): 33 | modified = key not in self 34 | rv = super(UpdateDictMixin, self).setdefault(key, default) 35 | if modified and self.on_update is not None: 36 | self.on_update(self, key) 37 | return rv 38 | 39 | def pop(self, key, default=_missing): 40 | modified = key in self 41 | if default is _missing: 42 | rv = super(UpdateDictMixin, self).pop(key) 43 | else: 44 | rv = super(UpdateDictMixin, self).pop(key, default) 45 | if 
modified and self.on_update is not None: 46 | self.on_update(self, key) 47 | return rv 48 | 49 | __setitem__ = calls_update('__setitem__') 50 | __delitem__ =
calls_update('__delitem__') 51 | clear = calls_update('clear') 52 | popitem = calls_update('popitem') 53 | update = calls_update('update') 54 | del calls_update 55 | 56 | 57 | class CallbackDict(UpdateDictMixin, dict): 58 | """A dict that calls `on_update` every time one of its keys changes. 59 | The callback receives the dict instance and the changed key. 60 | """ 61 | 62 | def __init__(self, initial=None, on_update=None): 63 | dict.__init__(self, initial or ()) 64 | self.on_update = on_update 65 | 66 | def __repr__(self): 67 | return '<%s %s>' % ( 68 | self.__class__.__name__, 69 | dict.__repr__(self) 70 | ) 71 | 72 | 73 | class MultiSession(CallbackDict, SessionMixin): 74 | """ Base class for a session combined from several session interfaces. 75 | Records the keys that were modified in `tracked_status` 76 | """ 77 | 78 | modified = False 79 | accessed = False 80 | 81 | def __init__(self, initial=None, sid=None, permanent=None): 82 | def on_update(self, updated_key): 83 | self.modified = True 84 | self.accessed = True 85 | self.tracked_status.add(updated_key) 86 | 87 | super(MultiSession, self).__init__(initial, on_update) 88 | sid = sid or {} # sid is a dict of {'cookie_name': sid} 89 | if not isinstance(sid, dict): 90 | raise ValueError("sid must always be a dict of {'cookie_name': sid}") 91 | self._sid = sid 92 | self._permanent = set() # store the cookie names that are permanent sessions 93 | if permanent is not None: 94 | self._permanent.add(permanent) 95 | self.modified = False 96 | self.tracked_status = set() 97 | 98 | @property 99 | def sid(self): 100 | return self._sid 101 | 102 | @sid.setter 103 | def sid(self, value): 104 | if not isinstance(value, dict): 105 | raise ValueError("sid must always be a dict of {'cookie_name': sid}") 106 | self._sid.update(value) 107 | 108 | def get_sid(self, cookie_name): 109 | return self._sid.get(cookie_name) 110 | 111 | def __getitem__(self, key): 112 | self.accessed = True 113 | return super(MultiSession, self).__getitem__(key) 114 | 115 | def 
get(self, key, default=None):
116 | self.accessed = True 117 | return super(MultiSession, self).get(key, default) 118 | 119 | def setdefault(self, key, default=None): 120 | self.accessed = True 121 | return super(MultiSession, self).setdefault(key, default) 122 | 123 | def is_permanent(self, cookie_name): 124 | return cookie_name in self._permanent 125 | 126 | def set_permanent(self, cookie_name, remove=False): 127 | if remove: 128 | if cookie_name in self._permanent: 129 | self._permanent.remove(cookie_name) 130 | else: 131 | self._permanent.add(cookie_name) 132 | -------------------------------------------------------------------------------- /flask_session_plus/session.py: -------------------------------------------------------------------------------- 1 | from flask.sessions import SessionInterface as FlaskSessionInterface 2 | from flask_session_plus.backends import SecureCookieSessionInterface, FirestoreSessionInterface 3 | from flask_session_plus.backends import RedisSessionInterface, MongoDBSessionInterface, MemcachedSessionInterface 4 | from flask_session_plus.core import MultiSession 5 | 6 | 7 | class MultiSessionInterface(FlaskSessionInterface): 8 | 9 | backends = { 10 | 'secure_cookie': SecureCookieSessionInterface, 11 | 'firestore': FirestoreSessionInterface, 12 | 'redis': RedisSessionInterface, 13 | 'mongodb': MongoDBSessionInterface, 14 | 'memcache': MemcachedSessionInterface, 15 | } 16 | 17 | def __init__(self, sessions_config): 18 | all_includes = [] # store all the defined includes 19 | check_auto = [] # store any 'session_fields' auto 20 | for i, session_conf in enumerate(sessions_config): 21 | session_fields = session_conf.get('session_fields') 22 | if session_fields is None: 23 | continue 24 | if isinstance(session_fields, dict): 25 | all_includes.extend(session_fields.get('include', [])) 26 | elif isinstance(session_fields, list): 27 | all_includes.extend(session_fields) 28 | elif isinstance(session_fields, str) and session_fields == 'auto': 29 | check_auto.append(i) 30 | 
else: 31 | raise ValueError('session_fields type is incorrect') 32 | 33 | for auto in check_auto: 34 | sessions_config[auto]['session_fields'] = {'exclude': all_includes} 35 | 36 | self.session_interfaces = [] 37 | for session_conf in sessions_config: 38 | session_fields = session_conf.get('session_fields') 39 | session_type = session_conf.pop('session_type') 40 | 41 | backend = self.backends.get(session_type) 42 | if backend: 43 | session_interface = (backend(**session_conf), session_fields) 44 | self.session_interfaces.append(session_interface) 45 | else: 46 | raise ValueError('Specified session_type not recognized as a valid one.') 47 | 48 | @staticmethod 49 | def get_session_for(session_interface, session, session_fields): 50 | """ Builds the session holding only the fields handled by session_interface """ 51 | 52 | if isinstance(session_fields, dict): 53 | include = session_fields.get('include', []) 54 | exclude = session_fields.get('exclude', []) 55 | else: 56 | include = session_fields 57 | exclude = [] 58 | 59 | new_dict = {} 60 | if len(include) == 0: 61 | new_dict = dict(session) 62 | else: 63 | for field in include: 64 | if field in session: 65 | new_dict[field] = session.get(field) 66 | 67 | for field in exclude: 68 | new_dict.pop(field, None) 69 | 70 | # A session counts as modified when any of its own fields changed. 71 | # With no include list ('auto' or the default session), every tracked 72 | # key that is not excluded counts. 73 | watched = include or [key for key in session.tracked_status if key not in exclude] 74 | modified = any(field in session.tracked_status for field in watched) 75 | 76 | new_session = session_interface.session_class(new_dict) # new session 77 | new_session.modified = modified # assign modified flag 78 | cookie_name = session_interface.cookie_name 79 | new_session.sid = {cookie_name: session.get_sid(cookie_name)} # assign session id 80 | if session.is_permanent(session_interface.cookie_name): 81 | new_session.set_permanent(session_interface.cookie_name) # assign 
permanent status 82 | return new_session 83 | 84 | def open_session(self, app, request): 85 | """ Opens all the inner session interfaces and integrates all the sessions into one """ 86 | common_dict = {} 87 |
session_sids = {} 88 | permanents = set() 89 | for si, _ in self.session_interfaces: 90 | session = si.open_session(app, request) 91 | if session is not None: 92 | # 1st: update dict values 93 | common_dict.update(dict(session)) 94 | # 2nd: integrate session sid if available 95 | session_sids[si.cookie_name] = session.get_sid(si.cookie_name) 96 | # 3rd: add permanent status 97 | if session.is_permanent(si.cookie_name): 98 | permanents.add(si.cookie_name) 99 | multi_session = MultiSession(common_dict, sid=session_sids) 100 | for cookie_name in permanents: 101 | multi_session.set_permanent(cookie_name) 102 | return multi_session 103 | 104 | def save_session(self, app, session, response): 105 | """ Saves all session info into each of the session interfaces """ 106 | for si, session_fields in self.session_interfaces: 107 | interface_session = self.get_session_for(si, session, session_fields) 108 | si.save_session(app, interface_session, response) 109 | 110 | 111 | class Session(object): 112 | 113 | def __init__(self, app=None): 114 | self.app = app 115 | if app is not None: 116 | self.init_app(app) 117 | 118 | def init_app(self, app): 119 | app.session_interface = self.create_session_interface(app) 120 | 121 | @staticmethod 122 | def create_session_interface(app): 123 | sessions_config = app.config.get('SESSION_CONFIG', []) 124 | if not sessions_config: 125 | # add the default session 126 | sessions_config.append({ 127 | 'cookie_name': app.config.get('SESSION_COOKIE_NAME'), 128 | 'cookie_domain': app.config.get('SESSION_COOKIE_DOMAIN'), 129 | 'cookie_path': app.config.get('SESSION_COOKIE_PATH'), 130 | 'cookie_httponly': app.config.get('SESSION_COOKIE_HTTPONLY'), 131 | 'cookie_secure': app.config.get('SESSION_COOKIE_SECURE'), 132 | 'cookie_max_age': None, 133 | }) 134 | 135 | for session in sessions_config: 136 | if not session.get('cookie_name'): 137 | raise ValueError('Each session configuration must define a cookie name') 138 | session.setdefault('session_type', 
'secure_cookie') # the session interface to be used 139 | session.setdefault('session_fields', []) # the list of fields used for this session 140 | 141 | return MultiSessionInterface(sessions_config) 142 | -------------------------------------------------------------------------------- /release.py: -------------------------------------------------------------------------------- 1 | """ 2 | Release script 3 | """ 4 | 5 | import os 6 | import shutil 7 | import subprocess 8 | import sys 9 | import requests 10 | from pathlib import Path 11 | 12 | # noinspection PyPackageRequirements 13 | import click 14 | 15 | 16 | PYPI_PACKAGE_NAME = 'flask-session-plus' 17 | PYPI_URL = 'https://pypi.org/pypi/{package}/json' 18 | DIST_PATH = 'dist' 19 | DIST_PATH_DELETE = 'dist_delete' 20 | CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help']) 21 | 22 | 23 | @click.group(context_settings=CONTEXT_SETTINGS) 24 | def cli(): 25 | pass 26 | 27 | 28 | @cli.command() 29 | @click.option('--force/--no-force', default=False, help='Will force a new build removing the previous ones') 30 | def build(force): 31 | """ Builds the distribution files: wheels and source.
""" 32 | dist_path = Path(DIST_PATH) 33 | if dist_path.exists() and list(dist_path.glob('*')): 34 | if force or click.confirm('{} is not empty - delete contents?'.format(dist_path)): 35 | dist_path.rename(DIST_PATH_DELETE) 36 | shutil.rmtree(Path(DIST_PATH_DELETE)) 37 | dist_path.mkdir() 38 | else: 39 | click.echo('Aborting') 40 | sys.exit(1) 41 | 42 | subprocess.check_call(['python', 'setup.py', 'bdist_wheel']) 43 | subprocess.check_call(['python', 'setup.py', 'sdist', 44 | '--formats=gztar']) 45 | 46 | 47 | @cli.command() 48 | @click.option('--release/--no-release', default=False, help='--release to upload to pypi otherwise upload to test.pypi') 49 | @click.option('--rebuild/--no-rebuild', default=True, help='Will force a rebuild of the build files (src and wheels)') 50 | @click.pass_context 51 | def upload(ctx, release, rebuild): 52 | """ Uploads distribution files to PyPI or TestPyPI. """ 53 | 54 | dist_path = Path(DIST_PATH) 55 | if rebuild is False: 56 | if not dist_path.exists() or not list(dist_path.glob('*')): 57 | print("No distribution files found. Please run 'build' command first") 58 | return 59 | else: 60 | ctx.invoke(build, force=True) 61 | 62 | if release: 63 | args = ['twine', 'upload', 'dist/*'] 64 | else: 65 | repository = 'https://test.pypi.org/legacy/' 66 | args = ['twine', 'upload', '--repository-url', repository, 'dist/*'] 67 | 68 | env = os.environ.copy() 69 | 70 | p = subprocess.Popen(args, env=env) 71 | p.wait() 72 | 73 | 74 | @cli.command() 75 | def check(): 76 | """ Checks the long description. """ 77 | dist_path = Path(DIST_PATH) 78 | if not dist_path.exists() or not list(dist_path.glob('*')): 79 | print("No distribution files found.
Please run 'build' command first") 80 | return 81 | 82 | subprocess.check_call(['twine', 'check', 'dist/*']) 83 | 84 | 85 | # noinspection PyShadowingBuiltins 86 | @cli.command(name='list') 87 | def list_releases(): 88 | """ Lists all releases published on pypi""" 89 | response = requests.get(PYPI_URL.format(package=PYPI_PACKAGE_NAME)) 90 | if response: 91 | data = response.json() 92 | 93 | releases_dict = data.get('releases', {}) 94 | 95 | if releases_dict: 96 | for version, release in releases_dict.items(): 97 | release_formats = [] 98 | published_on_date = None 99 | for fmt in release: 100 | release_formats.append(fmt.get('packagetype')) 101 | published_on_date = fmt.get('upload_time') 102 | 103 | release_formats = ' | '.join(release_formats) 104 | print('{:<10}{:>15}{:>25}'.format(version, published_on_date, release_formats)) 105 | else: 106 | print('No releases found for {}'.format(PYPI_PACKAGE_NAME)) 107 | else: 108 | print('Package "{}" not found on Pypi.org'.format(PYPI_PACKAGE_NAME)) 109 | 110 | 111 | if __name__ == "__main__": 112 | cli() 113 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | Flask==2.3.2 2 | pytz==2018.7 3 | requests==2.32.4 4 | twine==1.12.1 5 | wheel==0.38.1 6 | Click==7.0 7 | -------------------------------------------------------------------------------- /requirements-test.txt: -------------------------------------------------------------------------------- 1 | bcrypt==3.1.5 2 | bleach==3.3.0 3 | cachetools==3.0.0 4 | certifi==2024.7.4 5 | cffi==1.11.5 6 | chardet==3.0.4 7 | Click==7.0 8 | docutils==0.14 9 | Flask==2.3.2 10 | Flask-Bcrypt==0.7.1 11 | Flask-Login==0.4.1 12 | flask-session-plus==0.0.2 13 | Flask-WTF==0.14.2 14 | google-api-core==1.7.0 15 | google-auth==1.6.2 16 | google-cloud-core==0.29.1 17 | google-cloud-firestore==0.31.0 18 | googleapis-common-protos==1.5.5 19 | grpcio==1.53.2 
20 | idna==3.7 21 | itsdangerous==1.1.0 22 | Jinja2>=2.10.1 23 | MarkupSafe==1.1.0 24 | pkginfo==1.4.2 25 | protobuf==3.18.3 26 | pyasn1==0.4.4 27 | pyasn1-modules==0.2.2 28 | pycparser==2.19 29 | Pygments==2.15.0 30 | pytz==2018.7 31 | readme-renderer==24.0 32 | requests==2.32.4 33 | requests-toolbelt==0.8.0 34 | rsa==4.7 35 | six==1.12.0 36 | tqdm==4.66.3 37 | twine==1.12.1 38 | urllib3>=1.24.2 39 | webencodings==0.5.1 40 | Werkzeug==3.0.6 41 | WTForms==2.2.1 42 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup, find_packages 3 | 4 | 5 | VERSION = '0.6.2' 6 | 7 | 8 | def read(fname): 9 | """ Returns the contents of the fname file """ 10 | with open(os.path.join(os.path.dirname(__file__), fname), 'r') as file: 11 | return file.read() 12 | 13 | 14 | # Available classifiers: https://pypi.org/pypi?%3Aaction=list_classifiers 15 | CLASSIFIERS = [ 16 | 'Development Status :: 4 - Beta', 17 | 'Intended Audience :: Developers', 18 | 'License :: OSI Approved :: MIT License', 19 | 'Topic :: Internet :: WWW/HTTP :: Session', 20 | 'Topic :: Software Development :: Libraries', 21 | 'Programming Language :: Python', 22 | 'Programming Language :: Python :: 3 :: Only', 23 | 'Programming Language :: Python :: 3.4', 24 | 'Programming Language :: Python :: 3.5', 25 | 'Programming Language :: Python :: 3.6', 26 | 'Programming Language :: Python :: 3.7', 27 | 'Programming Language :: Python :: 3.8', 28 | 'Operating System :: OS Independent', 29 | ] 30 | 31 | 32 | requires = ['Flask'] 33 | 34 | setup( 35 | name='flask-session-plus', 36 | version=VERSION, 37 | packages=find_packages(), 38 | url='https://github.com/janscas/flask-session-plus', 39 | license='MIT License', 40 | author='Alejcas', 41 | author_email='alejcas@users.noreply.github.com', 42 | maintainer='Alejcas', 43 |
maintainer_email='alejcas@users.noreply.github.com', 44 | description='Flask Multiple Sessions Interface (combine multiple sessions with different backends)', 45 | long_description=read('README.md'), 46 | long_description_content_type="text/markdown", 47 | classifiers=CLASSIFIERS, 48 | python_requires=">=3.4", 49 | install_requires=requires, 50 | ) 51 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alejcas/flask-session-plus/407021cb39016e82fbefea904a241c439ecbdff2/test/__init__.py -------------------------------------------------------------------------------- /test/auth.py: -------------------------------------------------------------------------------- 1 | # CUSTOM login_user and logout_user functions to override the default ones from Flask-Login 2 | 3 | from flask import session, current_app, _request_ctx_stack, request 4 | from flask_login.config import COOKIE_NAME 5 | from flask_login.signals import user_logged_in, user_logged_out 6 | from flask_login.utils import _get_user 7 | 8 | 9 | def get_user_data(user): 10 | """ Extracts user data to be saved in the session """ 11 | extract_attrs = current_app.config.get('SESION_USER_FIELDS', []) 12 | user_data = {} 13 | for attr in extract_attrs: 14 | user_data[attr] = getattr(user, attr, None) 15 | return user_data 16 | 17 | 18 | def login_user(user, remember=False, duration=None, force=False, fresh=True): 19 | ''' 20 | Logs a user in. You should pass the actual user object to this. If the 21 | user's `is_active` property is ``False``, they will not be logged in 22 | unless `force` is ``True``. 23 | 24 | This will return ``True`` if the log in attempt succeeds, and ``False`` if 25 | it fails (i.e. because the user is inactive). 26 | 27 | :param user: The user object to log in.
28 | :type user: object 29 | :param remember: Whether to remember the user after their session expires. 30 | Defaults to ``False``. 31 | :type remember: bool 32 | :param duration: The amount of time before the remember cookie expires. If 33 | ``None`` the value set in the settings is used. Defaults to ``None``. 34 | :type duration: :class:`datetime.timedelta` 35 | :param force: If the user is inactive, setting this to ``True`` will log 36 | them in regardless. Defaults to ``False``. 37 | :type force: bool 38 | :param fresh: setting this to ``False`` will log in the user with a session 39 | marked as not "fresh". Defaults to ``True``. 40 | :type fresh: bool 41 | ''' 42 | if not force and not user.is_active: 43 | return False 44 | 45 | user_id = getattr(user, current_app.login_manager.id_attribute)() 46 | session['user_id'] = user_id 47 | session['user_data'] = get_user_data(user) 48 | session['_fresh'] = fresh 49 | session['_id'] = current_app.login_manager._session_identifier_generator() 50 | 51 | if remember: 52 | session['remember'] = 'set' 53 | if duration is not None: 54 | try: 55 | # equal to timedelta.total_seconds() but works with Python 2.6 56 | session['remember_seconds'] = (duration.microseconds + 57 | (duration.seconds + 58 | duration.days * 24 * 3600) * 59 | 10**6) / 10.0**6 60 | except AttributeError: 61 | raise Exception('duration must be a datetime.timedelta, ' 62 | 'instead got: {0}'.format(duration)) 63 | 64 | _request_ctx_stack.top.user = user 65 | user_logged_in.send(current_app._get_current_object(), user=_get_user()) 66 | return True 67 | 68 | 69 | def logout_user(): 70 | ''' 71 | Logs a user out. (You do not need to pass the actual user.) This will 72 | also clean up the remember me cookie if it exists. 
''' 74 | 75 | user = _get_user() 76 | 77 | session.pop('user_id', None) 78 | session.pop('user_data', None) 79 | session.pop('_fresh', None) 80 | 81 | cookie_name = current_app.config.get('REMEMBER_COOKIE_NAME', COOKIE_NAME) 82 | if cookie_name in request.cookies: 83 | session['remember'] = 'clear' 84 | if 'remember_seconds' in session: 85 | session.pop('remember_seconds') 86 | 87 | user_logged_out.send(current_app._get_current_object(), user=user) 88 | 89 | current_app.login_manager.reload_user() 90 | return True 91 | -------------------------------------------------------------------------------- /test/flask_app.py: -------------------------------------------------------------------------------- 1 | import os 2 | from flask import Flask, session, render_template 3 | from flask_session_plus import Session 4 | from flask_login import LoginManager, login_required, current_user 5 | from flask_wtf import FlaskForm 6 | from wtforms import StringField 7 | from test.auth import login_user, logout_user 8 | 9 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 10 | os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = os.path.join(os.path.dirname(BASE_DIR), 'firebase.json') 11 | 12 | from test.models import User, db 13 | 14 | app = Flask(__name__, template_folder='templates') 15 | app.config['SESSION_CONFIG'] = [ 16 | # First session will store the csrf_token only in its own cookie. 17 | { 18 | 'cookie_name': 'csrf', 19 | 'session_type': 'secure_cookie', 20 | 'session_fields': ['csrf_token'] 21 | }, 22 | # Second session will store the logged-in user inside the Firestore 'sessions' collection.
{ 24 | 'cookie_name': 'session', 25 | 'session_type': 'firestore', 26 | 'session_fields': ['user_id', 'user_data', '_fresh', '_id'], 27 | 'client': db, 28 | 'collection': 'sessions', 29 | }, 30 | # Third session will store any other values set on the Flask session in its own secure cookie 31 | { 32 | 'cookie_name': 'data', 33 | 'session_type': 'secure_cookie', 34 | 'session_fields': 'auto' 35 | } 36 | ] 37 | app.config['SECRET_KEY'] = 'my_secret_key' 38 | app.config['SESION_USER_FIELDS'] = ['name', 'email', 'timezone', 'language', 'active'] 39 | 40 | mses = Session(app) 41 | login_manager = LoginManager(app) 42 | 43 | 44 | # Example Form to test the csrf token 45 | class LoginForm(FlaskForm): 46 | username = StringField('name') 47 | 48 | 49 | @login_manager.user_loader 50 | def load_user(id): 51 | # Flask-Login USER loader 52 | # can't use current_user here as this method is setting the current_user 53 | if 'user_id' in session: 54 | print('Got user from SESSION') 55 | return User.get_user_from_session(session) 56 | else: 57 | print('Got user from DATABASE') 58 | return User.get_user_by_id(id) 59 | 60 | 61 | @app.route('/') 62 | def index(): 63 | # testing setting session random values 64 | session['dog'] = 'cat' 65 | 66 | if current_user.is_authenticated: 67 | return f"Hi!: {current_user.to_dict()}" 68 | else: 69 | return 'Anon User' 70 | 71 | 72 | @app.route('/login') 73 | def login(): 74 | """ Testing a normal login """ 75 | user = User.get_user_by_id('1kuU9610nMtUlqLqdjxR') 76 | 77 | login_user(user) 78 | return 'User logged in!' 79 | 80 | 81 | @app.route('/loginpermanent') 82 | def loginp(): 83 | """ Testing a permanent login """ 84 | user = User.get_user_by_id('1kuU9610nMtUlqLqdjxR') 85 | 86 | login_user(user, remember=True) 87 | session.set_permanent('session') # setting the session as permanent 88 | 89 | return 'User logged in!'
90 | 91 | 92 | @app.route('/logout') 93 | def logout(): 94 | logout_user() 95 | session.set_permanent('session', remove=True) # unsetting the session as permanent 96 | 97 | return 'User logged out!' 98 | 99 | 100 | @app.route('/protected') 101 | @login_required 102 | def protected(): 103 | return f'you have access!: {current_user.to_dict()}' 104 | 105 | 106 | @app.route('/form', methods=['GET', 'POST']) 107 | @login_required 108 | def form(): 109 | """ Testing that the csrf_token from flask-wtf is well set""" 110 | frm = LoginForm() 111 | if frm.validate_on_submit(): 112 | return render_template('test_csrf.html', form=frm, success=True) 113 | return render_template('test_csrf.html', form=frm) 114 | 115 | 116 | if __name__ == '__main__': 117 | app.run() 118 | -------------------------------------------------------------------------------- /test/models.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from flask_login import UserMixin 3 | from google.cloud import firestore 4 | 5 | db = firestore.Client() 6 | 7 | from google.cloud.exceptions import NotFound 8 | from werkzeug.security import check_password_hash, generate_password_hash 9 | 10 | log = logging.getLogger(__name__) 11 | 12 | 13 | class User(UserMixin): 14 | 15 | def __init__(self, user_id, user_data): 16 | self.user_id = user_id 17 | self.user_ref = db.collection('users').document(user_id) 18 | self.email = user_data.pop('email', '') 19 | self.password = user_data.pop('password', '') 20 | self.active = user_data.pop('active', False) 21 | 22 | self.name = user_data.pop('name', '') 23 | 24 | # other attributes... 
25 | self.country = user_data.pop('country', '') 26 | self.language = user_data.pop('language', '') 27 | self.timezone = user_data.pop('timezone', '') 28 | 29 | self.extra_data = user_data 30 | 31 | def __repr__(self): 32 | return f'name: {self.name} ({self.user_id})' 33 | 34 | def to_dict(self): 35 | return { 36 | 'user_id': self.user_id, 37 | 'email': self.email, 38 | 'password': self.password, 39 | 'active': self.active, 40 | 'name': self.name, 41 | 'country': self.country, 42 | 'language': self.language, 43 | 'timezone': self.timezone, 44 | } 45 | 46 | def get_id(self): 47 | return self.user_id 48 | 49 | @property 50 | def is_active(self): 51 | return self.active 52 | 53 | @classmethod 54 | def get_user_by_id(cls, user_id): 55 | user_ref = db.collection('users').document(user_id) 56 | try: 57 | user = user_ref.get() 58 | user = cls(user_id=user_id, user_data=user.to_dict()) if user.exists else None 59 | except NotFound: 60 | user = None 61 | return user 62 | 63 | @classmethod 64 | def get_user_by_email(cls, email): 65 | user_ref = db.collection('users').where('email', '==', email).limit(1) 66 | try: 67 | docs = user_ref.get() # run the query only once; it yields document snapshots 68 | user = list(docs) 69 | user = user[0] if user else None 70 | except Exception as e: 71 | log.error(f'Error while getting user by email ({email}): {e}') 72 | user = None 73 | 74 | if user: 75 | return cls(user.id, user.to_dict()) 76 | else: 77 | return None 78 | 79 | def set_password(self, password): 80 | new_password = generate_password_hash(password) 81 | try: 82 | self.user_ref.update({'password': new_password}) 83 | self.password = new_password 84 | except Exception as e: 85 | log.error(f'Error while setting password on User ({self.user_id}): {e}') 86 | return False 87 | return True 88 | 89 | def check_password(self, password): 90 | return check_password_hash(self.password, password) 91 | 92 | @classmethod 93 | def get_user_from_session(cls, session): 94 | # Avoids a database call 95 | if
'user_id' in session: 96 | return cls(session['user_id'], dict(session.get('user_data', {}))) # copy: __init__ pops keys from user_data 97 | else: 98 | return None 99 | -------------------------------------------------------------------------------- /test/templates/test_csrf.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | Title 8 | 9 | 10 |
11 | {{ form.csrf_token }} 12 | {{ form.username }} 13 | 14 |
15 | 16 | {% if success %} 17 |

WORKING !!!

18 | {% endif %} 19 | 20 | 21 | 22 | --------------------------------------------------------------------------------
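A minimal, Flask-free sketch of the field-routing rules in `MultiSessionInterface`: a list or `{'include': [...]}` entry picks its own fields, and every `'auto'` entry is rewritten to `{'exclude': <all explicitly included fields>}`, so it receives whatever no other session claims. The helper names `resolve_fields` and `split_session` and the sample data are hypothetical; the real interface also wraps each subset in a backend session object and tracks modified keys.

```python
def resolve_fields(configs):
    """Map cookie_name -> field spec, turning 'auto' into an exclude list
    of every field that some other session includes explicitly."""
    included = []
    for conf in configs:
        fields = conf.get('session_fields')
        if isinstance(fields, dict):
            included.extend(fields.get('include', []))
        elif isinstance(fields, list):
            included.extend(fields)
    resolved = {}
    for conf in configs:
        fields = conf.get('session_fields')
        resolved[conf['cookie_name']] = {'exclude': included} if fields == 'auto' else fields
    return resolved


def split_session(data, fields):
    """Return the subset of `data` that one cookie would store."""
    if isinstance(fields, dict):
        include, exclude = fields.get('include', []), fields.get('exclude', [])
    else:
        include, exclude = fields, []
    # An empty include list means "start from everything", as in get_session_for
    subset = dict(data) if not include else {k: data[k] for k in include if k in data}
    for key in exclude:
        subset.pop(key, None)
    return subset


config = [
    {'cookie_name': 'csrf', 'session_fields': ['csrf_token']},
    {'cookie_name': 'session', 'session_fields': ['user_id', 'user_data']},
    {'cookie_name': 'data', 'session_fields': 'auto'},
]
session_data = {'csrf_token': 't0k3n', 'user_id': 'u1', 'dog': 'cat'}
for cookie, fields in resolve_fields(config).items():
    print(cookie, split_session(session_data, fields))
```

Running it prints `csrf {'csrf_token': 't0k3n'}`, `session {'user_id': 'u1'}` and `data {'dog': 'cat'}`: the `'auto'` cookie ends up with only `dog`, the one key no other session lists.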
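The `_get_memcache_timeout` tail at the top of this excerpt switches to absolute timestamps because memcached interprets any expiration longer than 30 days (60*60*24*30 = 2,592,000 seconds) as a Unix timestamp rather than a relative number of seconds. The threshold check itself is not visible in this excerpt, so the sketch below assumes that limit; `memcache_timeout` and its `now` parameter are hypothetical names, with `now` added only so the conversion can be checked deterministically.

```python
import time

# Assumed limit: memcached treats larger expirations as absolute Unix time
MEMCACHE_MAX_RELATIVE = 60 * 60 * 24 * 30  # 2592000 seconds (30 days)


def memcache_timeout(timeout, now=None):
    """Return a value memcached reads correctly for a relative timeout in seconds."""
    if timeout > MEMCACHE_MAX_RELATIVE:
        # Too long to be sent as relative seconds: convert to an absolute timestamp
        timeout += int(time.time() if now is None else now)
    return timeout


print(memcache_timeout(3600))                           # → 3600 (sent unchanged)
print(memcache_timeout(31 * 24 * 3600, now=1000000))    # → 3678400 (2678400 + 1000000)
```

Values up to exactly 30 days pass through untouched; only longer lifetimes, such as a multi-week permanent session, need the conversion.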