├── piston ├── __init__.py ├── templates │ ├── piston │ │ └── authorize_token.html │ └── documentation.html ├── middleware.py ├── managers.py ├── forms.py ├── store.py ├── handler.py ├── models.py ├── doc.py ├── decorator.py ├── resource.py ├── utils.py ├── authentication.py ├── emitters.py └── oauth.py ├── examples └── blogserver │ ├── __init__.py │ ├── api │ ├── __init__.py │ ├── urls.py │ └── handlers.py │ ├── blog │ ├── __init__.py │ ├── urls.py │ ├── models.py │ └── views.py │ ├── urls.py │ ├── manage.py │ ├── templates │ ├── posts.html │ └── test_js.html │ ├── fixtures │ └── initial_data.xml │ ├── README.txt │ └── settings.py ├── tests ├── test_project │ ├── __init__.py │ ├── apps │ │ ├── __init__.py │ │ └── testapp │ │ │ ├── __init__.py │ │ │ ├── signals.py │ │ │ ├── forms.py │ │ │ ├── models.py │ │ │ ├── urls.py │ │ │ ├── handlers.py │ │ │ └── tests.py │ ├── templates │ │ ├── 404.html │ │ ├── 500.html │ │ └── admin │ │ │ └── login.html │ ├── urls.py │ └── settings.py ├── buildout.cfg └── bootstrap.py ├── AUTHORS.txt ├── setup.py └── ez_setup.py /piston/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/blogserver/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/test_project/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/blogserver/api/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/blogserver/blog/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /tests/test_project/apps/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/test_project/templates/404.html: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/test_project/templates/500.html: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/test_project/apps/testapp/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/test_project/templates/admin/login.html: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/test_project/apps/testapp/signals.py: -------------------------------------------------------------------------------- 1 | import django.dispatch 2 | 3 | entry_request_started = django.dispatch.Signal(providing_args=['request']) 4 | -------------------------------------------------------------------------------- /tests/test_project/apps/testapp/forms.py: -------------------------------------------------------------------------------- 1 | from django import forms 2 | 3 | 4 | class EchoForm(forms.Form): 5 | msg = forms.CharField(max_length=128) 6 | 7 | -------------------------------------------------------------------------------- /tests/test_project/urls.py: -------------------------------------------------------------------------------- 1 | from django.conf.urls.defaults import * 2 | 3 | 4 | urlpatterns = patterns('', 5 | url(r'api/', 
include('test_project.apps.testapp.urls')) 6 | ) 7 | -------------------------------------------------------------------------------- /examples/blogserver/blog/urls.py: -------------------------------------------------------------------------------- 1 | from django.conf.urls.defaults import * 2 | 3 | urlpatterns = patterns('blogserver.blog.views', 4 | url(r'^$', 'posts', name='posts'), 5 | url(r'^js$', 'test_js'), 6 | ) -------------------------------------------------------------------------------- /examples/blogserver/urls.py: -------------------------------------------------------------------------------- 1 | from django.conf.urls.defaults import * 2 | from django.contrib import admin 3 | 4 | admin.autodiscover() 5 | 6 | urlpatterns = patterns('', 7 | (r'^', include('blogserver.blog.urls')), 8 | (r'^api/', include('blogserver.api.urls')), 9 | (r'^admin/(.*)', admin.site.root), 10 | ) 11 | -------------------------------------------------------------------------------- /piston/templates/piston/authorize_token.html: -------------------------------------------------------------------------------- 1 | 3 | 4 | 5 | Authorize Token 6 | 7 | 8 |

Authorize Token

9 | 10 |
11 | {{ form.as_table }} 12 |
13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /tests/buildout.cfg: -------------------------------------------------------------------------------- 1 | [buildout] 2 | parts = django-1.1 django-1.0 3 | develop = .. 4 | eggs = django-piston 5 | 6 | [django-1.1] 7 | recipe = djangorecipe 8 | version = trunk 9 | project = test_project 10 | settings = settings 11 | test = testapp 12 | eggs = ${buildout:eggs} 13 | testrunner = test-1.1 14 | 15 | [django-1.0] 16 | recipe = djangorecipe 17 | version = 1.0.2 18 | project = test_project 19 | settings = settings 20 | test = testapp 21 | eggs = ${buildout:eggs} 22 | testrunner = test-1.0 23 | -------------------------------------------------------------------------------- /examples/blogserver/blog/models.py: -------------------------------------------------------------------------------- 1 | from django.db import models 2 | from django.contrib.auth.models import User 3 | from django.contrib import admin 4 | 5 | class Blogpost(models.Model): 6 | title = models.CharField(max_length=255) 7 | content = models.TextField() 8 | author = models.ForeignKey(User, related_name='posts') 9 | created_on = models.DateTimeField(auto_now_add=True) 10 | 11 | def __unicode__(self): 12 | return self.title 13 | 14 | admin.site.register(Blogpost) -------------------------------------------------------------------------------- /examples/blogserver/manage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from django.core.management import execute_manager 3 | try: 4 | import settings # Assumed to be in the same directory. 5 | except ImportError: 6 | import sys 7 | sys.stderr.write("Error: Can't find the file 'settings.py' in the directory containing %r. 
It appears you've customized things.\nYou'll have to run django-admin.py, passing it your settings module.\n(If the file settings.py does indeed exist, it's causing an ImportError somehow.)\n" % __file__) 8 | sys.exit(1) 9 | 10 | if __name__ == "__main__": 11 | execute_manager(settings) 12 | -------------------------------------------------------------------------------- /examples/blogserver/blog/views.py: -------------------------------------------------------------------------------- 1 | from django.http import HttpResponse, HttpResponseRedirect 2 | from django.contrib.auth.decorators import login_required 3 | from django.shortcuts import render_to_response, get_object_or_404 4 | from django.template import RequestContext 5 | 6 | from blogserver.blog.models import Blogpost 7 | 8 | def posts(request): 9 | posts = Blogpost.objects.all() 10 | 11 | return render_to_response("posts.html", { 12 | 'posts': posts }, 13 | RequestContext(request)) 14 | 15 | def test_js(request): 16 | return render_to_response('test_js.html', {}, RequestContext(request)) 17 | -------------------------------------------------------------------------------- /examples/blogserver/templates/posts.html: -------------------------------------------------------------------------------- 1 | 3 | 4 | 5 | 6 | 7 | 8 | Blogposts 9 | 10 | 11 | 12 | 13 | 14 |

Posts

15 | 16 | {% for post in posts %} 17 | 18 |

{{ post.title }}

19 | 20 |
21 | 22 | {{ post.content}} 23 | 24 |
25 | 26 |
27 | Written on {{ post.created_on }} by {{ post.author.username }}. 28 |
29 | 30 | {% endfor %} 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /examples/blogserver/api/urls.py: -------------------------------------------------------------------------------- 1 | from django.conf.urls.defaults import * 2 | from piston.resource import Resource 3 | from piston.authentication import HttpBasicAuthentication 4 | from piston.doc import documentation_view 5 | 6 | from blogserver.api.handlers import BlogpostHandler 7 | 8 | auth = HttpBasicAuthentication(realm='My sample API') 9 | 10 | blogposts = Resource(handler=BlogpostHandler, authentication=auth) 11 | 12 | urlpatterns = patterns('', 13 | url(r'^posts/$', blogposts), 14 | url(r'^posts/(?P.+)/$', blogposts), 15 | url(r'^posts\.(?P.+)', blogposts, name='blogposts'), 16 | 17 | # automated documentation 18 | url(r'^$', documentation_view), 19 | ) -------------------------------------------------------------------------------- /tests/test_project/settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | DEBUG = True 3 | DATABASE_ENGINE = 'sqlite3' 4 | DATABASE_NAME = '/tmp/piston.db' 5 | INSTALLED_APPS = ( 6 | 'django.contrib.auth', 7 | 'django.contrib.contenttypes', 8 | 'django.contrib.sessions', 9 | 'piston', 10 | 'test_project.apps.testapp', 11 | ) 12 | TEMPLATE_DIRS = ( 13 | os.path.join(os.path.dirname(__file__), 'templates'), 14 | ) 15 | ROOT_URLCONF = 'test_project.urls' 16 | 17 | MIDDLEWARE_CLASSES = ( 18 | 'piston.middleware.ConditionalMiddlewareCompatProxy', 19 | 'django.contrib.sessions.middleware.SessionMiddleware', 20 | 'piston.middleware.CommonMiddlewareCompatProxy', 21 | 'django.contrib.auth.middleware.AuthenticationMiddleware', 22 | ) 23 | -------------------------------------------------------------------------------- /piston/middleware.py: -------------------------------------------------------------------------------- 1 | from django.middleware.http import 
ConditionalGetMiddleware 2 | from django.middleware.common import CommonMiddleware 3 | 4 | def compat_middleware_factory(klass): 5 | """ 6 | Class wrapper that only executes `process_response` 7 | if `streaming` is not set on the `HttpResponse` object. 8 | Django has a bad habbit of looking at the content, 9 | which will prematurely exhaust the data source if we're 10 | using generators or buffers. 11 | """ 12 | class compatwrapper(klass): 13 | def process_response(self, req, resp): 14 | if not hasattr(resp, 'streaming'): 15 | return klass.process_response(self, req, resp) 16 | return resp 17 | return compatwrapper 18 | 19 | ConditionalMiddlewareCompatProxy = compat_middleware_factory(ConditionalGetMiddleware) 20 | CommonMiddlewareCompatProxy = compat_middleware_factory(CommonMiddleware) 21 | -------------------------------------------------------------------------------- /AUTHORS.txt: -------------------------------------------------------------------------------- 1 | django-piston was originally written by Jesper Noehr, but since 2 | its release, many people have contributed invaluable feedback and patches: 3 | 4 | Alberto Donato provided a fix for #26, contributed #35, #37 and #38 5 | Matthew Marshall provided fixes for #25 and #24 6 | Pete Karl reported #20 7 | Travis Jensen provided a fix for #24 8 | Benoit Chesneau provided some mimetype helpers 9 | Adam Lowry fixed a bug with the callmap construct 10 | Benjamin Pollack provided an ez_setup fix and more 11 | Seph Soliman for fixing a bug with AnonymousBaseHandler 12 | David Larlet for suggesting several improvements 13 | Xavier Barbosa for improving QuerySet handling and providing Template URI code 14 | Michael Richardson for contributing an improvement to form validation handling 15 | Brian McMurray for contributing a patch for #41 16 | James Emerton for making the OAuth parts more usable/friendly 17 | Anton Tsigularov for providing a patch for incorrect multipart detection 
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | try: 5 | from setuptools import setup, find_packages 6 | except ImportError: 7 | import ez_setup 8 | ez_setup.use_setuptools() 9 | from setuptools import setup, find_packages 10 | 11 | import os 12 | 13 | setup( 14 | name = "django-piston", 15 | version = "0.2.2", 16 | url = 'http://bitbucket.org/jespern/django-piston/wiki/Home', 17 | download_url = 'http://bitbucket.org/jespern/django-piston/downloads/', 18 | license = 'BSD', 19 | description = "Piston is a Django mini-framework creating APIs.", 20 | author = 'Jesper Noehr', 21 | author_email = 'jesper@noehr.org', 22 | packages = find_packages(), 23 | include_package_data = True, 24 | classifiers = [ 25 | 'Development Status :: 3 - Alpha', 26 | 'Framework :: Django', 27 | 'Intended Audience :: Developers', 28 | 'License :: OSI Approved :: BSD License', 29 | 'Operating System :: OS Independent', 30 | 'Programming Language :: Python', 31 | 'Topic :: Internet :: WWW/HTTP', 32 | ] 33 | ) 34 | -------------------------------------------------------------------------------- /tests/test_project/apps/testapp/models.py: -------------------------------------------------------------------------------- 1 | from django.db import models 2 | 3 | class TestModel(models.Model): 4 | test1 = models.CharField(max_length=1, blank=True, null=True) 5 | test2 = models.CharField(max_length=1, blank=True, null=True) 6 | 7 | class ExpressiveTestModel(models.Model): 8 | title = models.CharField(max_length=255) 9 | content = models.TextField() 10 | never_shown = models.TextField() 11 | 12 | class Comment(models.Model): 13 | parent = models.ForeignKey(ExpressiveTestModel, related_name='comments') 14 | content = models.TextField() 15 | 16 | class AbstractModel(models.Model): 17 | some_field = 
models.CharField(max_length=32, default='something here') 18 | 19 | class Meta: 20 | abstract = True 21 | 22 | class InheritedModel(AbstractModel): 23 | some_other = models.CharField(max_length=32, default='something else') 24 | 25 | class Meta: 26 | db_table = 'testing_abstracts' 27 | 28 | class PlainOldObject(object): 29 | def __emittable__(self): 30 | return {'type': 'plain', 31 | 'field': 'a field'} -------------------------------------------------------------------------------- /piston/templates/documentation.html: -------------------------------------------------------------------------------- 1 | {% load markup %} 2 | 4 | 5 | 6 | 7 | Piston generated documentation 8 | 9 | 16 | 17 | 18 |

API Documentation

19 | 20 | {% for doc in docs %} 21 | 22 |

{{ doc.name|cut:"Handler" }}:

23 | 24 |

25 | {{ doc.get_doc|default:""|restructuredtext }} 26 |

27 | 28 |

29 | URL: {{ doc.get_resource_uri_template }} 30 |

31 | 32 |

33 | Accepted methods: {% for meth in doc.allowed_methods %}{{ meth }}{% if not forloop.last %}, {% endif %}{% endfor %} 34 |

35 | 36 |
37 | {% for method in doc.get_all_methods %} 38 | 39 |
40 | method {{ method.name }}({{ method.signature }}){% if method.stale %} - inherited{% else %}:{% endif %} 41 | 42 |
43 | 44 | {% if method.get_doc %} 45 |
46 | {{ method.get_doc|default:""|restructuredtext }} 47 |
48 | {% endif %} 49 | 50 | {% endfor %} 51 |
52 | 53 | {% endfor %} 54 | 55 | 56 | -------------------------------------------------------------------------------- /tests/test_project/apps/testapp/urls.py: -------------------------------------------------------------------------------- 1 | from django.conf.urls.defaults import * 2 | from piston.resource import Resource 3 | from piston.authentication import HttpBasicAuthentication 4 | 5 | from test_project.apps.testapp.handlers import EntryHandler, ExpressiveHandler, AbstractHandler, EchoHandler, PlainOldObjectHandler 6 | 7 | auth = HttpBasicAuthentication(realm='TestApplication') 8 | 9 | entries = Resource(handler=EntryHandler, authentication=auth) 10 | expressive = Resource(handler=ExpressiveHandler, authentication=auth) 11 | abstract = Resource(handler=AbstractHandler, authentication=auth) 12 | echo = Resource(handler=EchoHandler) 13 | popo = Resource(handler=PlainOldObjectHandler) 14 | 15 | 16 | urlpatterns = patterns('', 17 | url(r'^entries/$', entries), 18 | url(r'^entries/(?P.+)/$', entries), 19 | url(r'^entries\.(?P.+)', entries), 20 | url(r'^entry-(?P.+)\.(?P.+)', entries), 21 | 22 | url(r'^expressive\.(?P.+)$', expressive), 23 | 24 | url(r'^abstract\.(?P.+)$', abstract), 25 | url(r'^abstract/(?P\d+)\.(?P.+)$', abstract), 26 | 27 | url(r'^echo$', echo), 28 | 29 | # oauth entrypoints 30 | url(r'^oauth/request_token$', 'piston.authentication.oauth_request_token'), 31 | url(r'^oauth/authorize$', 'piston.authentication.oauth_user_auth'), 32 | url(r'^oauth/access_token$', 'piston.authentication.oauth_access_token'), 33 | 34 | url(r'^popo$', popo), 35 | ) 36 | 37 | 38 | -------------------------------------------------------------------------------- /piston/managers.py: -------------------------------------------------------------------------------- 1 | from django.db import models 2 | from django.contrib.auth.models import User 3 | 4 | KEY_SIZE = 16 5 | SECRET_SIZE = 16 6 | 7 | class ConsumerManager(models.Manager): 8 | def create_consumer(self, name, 
description=None, user=None): 9 | """ 10 | Shortcut to create a consumer with random key/secret. 11 | """ 12 | consumer, created = self.get_or_create(name=name) 13 | 14 | if user: 15 | consumer.user = user 16 | 17 | if description: 18 | consumer.description = description 19 | 20 | if created: 21 | consumer.generate_random_codes() 22 | 23 | return consumer 24 | 25 | _default_consumer = None 26 | 27 | class ResourceManager(models.Manager): 28 | _default_resource = None 29 | 30 | def get_default_resource(self, name): 31 | """ 32 | Add cache if you use a default resource. 33 | """ 34 | if not self._default_resource: 35 | self._default_resource = self.get(name=name) 36 | 37 | return self._default_resource 38 | 39 | class TokenManager(models.Manager): 40 | def create_token(self, consumer, token_type, timestamp, user=None): 41 | """ 42 | Shortcut to create a token with random key/secret. 43 | """ 44 | token, created = self.get_or_create(consumer=consumer, 45 | token_type=token_type, 46 | timestamp=timestamp, 47 | user=user) 48 | 49 | if created: 50 | token.generate_random_codes() 51 | 52 | return token -------------------------------------------------------------------------------- /examples/blogserver/fixtures/initial_data.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | testuser 5 | 6 | 7 | 8 | sha1$97a04$0537e86889530ffd5c46a756fffbe0f745e4938e 9 | False 10 | True 11 | False 12 | 2009-04-27 04:55:09 13 | 2009-04-27 04:55:09 14 | 15 | 16 | 17 | 18 | example.com 19 | example.com 20 | 21 | 22 | Sample blogpost 1 23 | This is just a sample post. 24 | 1 25 | 2009-04-27 04:55:23 26 | 27 | 28 | Another sample post 29 | This is yet another sample post. 
30 | 1 31 | 2009-04-27 04:55:33 32 | 33 | 34 | -------------------------------------------------------------------------------- /examples/blogserver/api/handlers.py: -------------------------------------------------------------------------------- 1 | from piston.handler import BaseHandler, AnonymousBaseHandler 2 | from piston.utils import rc, require_mime, require_extended 3 | 4 | from blogserver.blog.models import Blogpost 5 | 6 | class AnonymousBlogpostHandler(AnonymousBaseHandler): 7 | """ 8 | Anonymous entrypoint for blogposts. 9 | """ 10 | model = Blogpost 11 | fields = ('id', 'title', 'content', 'created_on') 12 | 13 | @classmethod 14 | def resource_uri(self): 15 | return ('blogposts', [ 'format', ]) 16 | 17 | class BlogpostHandler(BaseHandler): 18 | """ 19 | Authenticated entrypoint for blogposts. 20 | """ 21 | model = Blogpost 22 | anonymous = AnonymousBlogpostHandler 23 | fields = ('title', 'content', ('author', ('username',)), 24 | 'created_on', 'content_length') 25 | 26 | def read(self, title=None): 27 | """ 28 | Returns a blogpost, if `title` is given, 29 | otherwise all the posts. 30 | 31 | Parameters: 32 | - `title`: The title of the post to retrieve. 33 | """ 34 | base = Blogpost.objects 35 | 36 | if title: 37 | return base.get(title=title) 38 | else: 39 | return base.all() 40 | 41 | def content_length(self, blogpost): 42 | return len(blogpost.content) 43 | 44 | @require_extended 45 | def create(self, request): 46 | """ 47 | Creates a new blogpost. 
48 | """ 49 | attrs = self.flatten_dict(request.POST) 50 | 51 | if self.exists(**attrs): 52 | return rc.DUPLICATE_ENTRY 53 | else: 54 | post = Blogpost(title=attrs['title'], 55 | content=attrs['content'], 56 | author=request.user) 57 | post.save() 58 | 59 | return post 60 | 61 | @classmethod 62 | def resource_uri(self): 63 | return ('blogposts', [ 'format', ]) -------------------------------------------------------------------------------- /examples/blogserver/templates/test_js.html: -------------------------------------------------------------------------------- 1 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | Test 10 | 11 | 56 | 57 | 58 | 59 |

JS Test

60 |
61 |
62 |
63 | 64 | 65 |
66 | 67 | 68 |
69 | 70 |
71 | 72 | 73 | 74 | -------------------------------------------------------------------------------- /piston/forms.py: -------------------------------------------------------------------------------- 1 | import hmac, base64 2 | 3 | from django import forms 4 | from django.conf import settings 5 | 6 | class Form(forms.Form): 7 | pass 8 | 9 | class ModelForm(forms.ModelForm): 10 | """ 11 | Subclass of `forms.ModelForm` which makes sure 12 | that the initial values are present in the form 13 | data, so you don't have to send all old values 14 | for the form to actually validate. Django does not 15 | do this on its own, which is really annoying. 16 | """ 17 | def merge_from_initial(self): 18 | self.data._mutable = True 19 | filt = lambda v: v not in self.data.keys() 20 | for field in filter(filt, getattr(self.Meta, 'fields', ())): 21 | self.data[field] = self.initial.get(field, None) 22 | 23 | 24 | class OAuthAuthenticationForm(forms.Form): 25 | oauth_token = forms.CharField(widget=forms.HiddenInput) 26 | oauth_callback = forms.URLField(widget=forms.HiddenInput) 27 | authorize_access = forms.BooleanField(required=True) 28 | csrf_signature = forms.CharField(widget=forms.HiddenInput) 29 | 30 | def __init__(self, *args, **kwargs): 31 | forms.Form.__init__(self, *args, **kwargs) 32 | 33 | self.fields['csrf_signature'].initial = self.initial_csrf_signature 34 | 35 | def clean_csrf_signature(self): 36 | sig = self.cleaned_data['csrf_signature'] 37 | token = self.cleaned_data['oauth_token'] 38 | 39 | sig1 = OAuthAuthenticationForm.get_csrf_signature(settings.SECRET_KEY, token) 40 | 41 | if sig != sig1: 42 | raise forms.ValidationError("CSRF signature is not valid") 43 | 44 | return sig 45 | 46 | def initial_csrf_signature(self): 47 | token = self.initial['oauth_token'] 48 | return OAuthAuthenticationForm.get_csrf_signature(settings.SECRET_KEY, token) 49 | 50 | @staticmethod 51 | def get_csrf_signature(key, token): 52 | # Check signature... 
53 | try: 54 | import hashlib # 2.5 55 | hashed = hmac.new(key, token, hashlib.sha1) 56 | except: 57 | import sha # deprecated 58 | hashed = hmac.new(key, token, sha) 59 | 60 | # calculate the digest base 64 61 | return base64.b64encode(hashed.digest()) 62 | 63 | -------------------------------------------------------------------------------- /examples/blogserver/README.txt: -------------------------------------------------------------------------------- 1 | This is a bare-skeleton Django application which demonstrates how you can 2 | add an API to your own applications. 3 | 4 | It's a simple blog application, with a "Blogpost" model, with an API on top 5 | of it. It has a fixture which contains a sample user (used as author and 6 | for auth) and a couple of posts. 7 | 8 | You can get started like so: 9 | 10 | $ python manage.py syncdb (answer "no" when it asks for superuser creation) 11 | $ python manage.py runserver 12 | 13 | Now, the test user has authentication info: 14 | 15 | Username: testuser 16 | Password: foobar 17 | 18 | The API is accessible via '/api/posts'. You can try it with curl: 19 | 20 | $ curl -u testuser:foobar "http://127.0.0.1:8000/api/posts/?format=yaml" 21 | - author: {absolute_uri: /users/testuser/, username: testuser} 22 | content: This is just a sample post. 23 | content_length: 27 24 | created_on: 2009-04-27 04:55:23 25 | title: Sample blogpost 1 26 | - author: {absolute_uri: /users/testuser/, username: testuser} 27 | content: This is yet another sample post. 28 | content_length: 32 29 | created_on: 2009-04-27 04:55:33 30 | title: Another sample post 31 | 32 | That's an authorized request, and the user gets back privileged information. 
33 | 34 | Anonymously: 35 | 36 | $ curl "http://127.0.0.1:8000/api/posts/?format=yaml" 37 | - {content: This is just a sample post., created_on: !!timestamp '2009-04-27 04:55:23', 38 | title: Sample blogpost 1} 39 | - {content: This is yet another sample post., created_on: !!timestamp '2009-04-27 40 | 04:55:33', title: Another sample post} 41 | 42 | Creating blog posts is also easy: 43 | 44 | $ curl -u testuser:foobar "http://127.0.0.1:8000/api/posts/?format=yaml" -F "title=Testing again" -F "content=Foobar" 45 | author: {absolute_uri: /users/testuser/, username: testuser} 46 | content: Foobar 47 | content_length: 6 48 | created_on: 2009-04-27 05:53:38.138215 49 | title: Testing again 50 | 51 | (The data returned is the blog post it created.) 52 | 53 | Anonymously, that's not allowed: 54 | 55 | $ curl -v "http://127.0.0.1:8000/api/posts/?format=yaml" -F "title=Testing again" -F "content=Foobar" 56 | * About to connect() to 127.0.0.1 port 8000 (#0) 57 | * Trying 127.0.0.1... connected 58 | * Connected to 127.0.0.1 (127.0.0.1) port 8000 (#0) 59 | > POST /api/posts/?format=yaml HTTP/1.1 60 | [snip] 61 | > 62 | < HTTP/1.0 405 METHOD NOT ALLOWED 63 | 64 | This is because, by default, AnonymousBaseHandler has 'allowed_methods' set to only 'GET'. 65 | 66 | You can check out how this is done in the 'api' directory. 67 | 68 | Also, there's plenty of documentation at http://bitbucket.org/jespern/django-piston/ 69 | 70 | Have fun!
-------------------------------------------------------------------------------- /tests/test_project/apps/testapp/handlers.py: -------------------------------------------------------------------------------- 1 | from django.core.paginator import Paginator 2 | 3 | from piston.handler import BaseHandler 4 | from piston.utils import rc, validate 5 | 6 | from models import TestModel, ExpressiveTestModel, Comment, InheritedModel, PlainOldObject 7 | from forms import EchoForm 8 | from test_project.apps.testapp import signals 9 | 10 | 11 | class EntryHandler(BaseHandler): 12 | model = TestModel 13 | allowed_methods = ['GET', 'PUT', 'POST'] 14 | 15 | def read(self, request, pk=None): 16 | signals.entry_request_started.send(sender=self, request=request) 17 | if pk is not None: 18 | return TestModel.objects.get(pk=int(pk)) 19 | paginator = Paginator(TestModel.objects.all(), 25) 20 | return paginator.page(int(request.GET.get('page', 1))).object_list 21 | 22 | def update(self, request, pk): 23 | signals.entry_request_started.send(sender=self, request=request) 24 | 25 | def create(self, request): 26 | signals.entry_request_started.send(sender=self, request=request) 27 | 28 | class ExpressiveHandler(BaseHandler): 29 | model = ExpressiveTestModel 30 | fields = ('title', 'content', ('comments', ('content',))) 31 | 32 | @classmethod 33 | def comments(cls, em): 34 | return em.comments.all() 35 | 36 | def read(self, request): 37 | inst = ExpressiveTestModel.objects.all() 38 | 39 | return inst 40 | 41 | def create(self, request): 42 | if request.content_type: 43 | data = request.data 44 | 45 | em = self.model(title=data['title'], content=data['content']) 46 | em.save() 47 | 48 | for comment in data['comments']: 49 | Comment(parent=em, content=comment['content']).save() 50 | 51 | return rc.CREATED 52 | else: 53 | super(ExpressiveTestModel, self).create(request) 54 | 55 | class AbstractHandler(BaseHandler): 56 | fields = ('id', 'some_other', 'some_field') 57 | model = InheritedModel 58 
| 59 | def read(self, request, id_=None): 60 | if id_: 61 | return self.model.objects.get(pk=id_) 62 | else: 63 | return super(AbstractHandler, self).read(request) 64 | 65 | class PlainOldObjectHandler(BaseHandler): 66 | allowed_methods = ('GET',) 67 | fields = ('type', 'field') 68 | model = PlainOldObject 69 | 70 | def read(self, request): 71 | return self.model() 72 | 73 | class EchoHandler(BaseHandler): 74 | allowed_methods = ('GET', ) 75 | 76 | @validate(EchoForm, 'GET') 77 | def read(self, request): 78 | return {'msg': request.GET['msg']} 79 | -------------------------------------------------------------------------------- /tests/bootstrap.py: -------------------------------------------------------------------------------- 1 | ############################################################################## 2 | # 3 | # Copyright (c) 2006 Zope Corporation and Contributors. 4 | # All Rights Reserved. 5 | # 6 | # This software is subject to the provisions of the Zope Public License, 7 | # Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. 8 | # THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED 9 | # WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 10 | # WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS 11 | # FOR A PARTICULAR PURPOSE. 12 | # 13 | ############################################################################## 14 | """Bootstrap a buildout-based project 15 | 16 | Simply run this script in a directory containing a buildout.cfg. 17 | The script accepts buildout command-line options, so you can 18 | use the -c option to specify an alternate configuration file. 
19 | 20 | $Id$ 21 | """ 22 | 23 | import os, shutil, sys, tempfile, urllib2 24 | 25 | tmpeggs = tempfile.mkdtemp() 26 | 27 | is_jython = sys.platform.startswith('java') 28 | 29 | try: 30 | import pkg_resources 31 | except ImportError: 32 | ez = {} 33 | exec urllib2.urlopen('http://peak.telecommunity.com/dist/ez_setup.py' 34 | ).read() in ez 35 | ez['use_setuptools'](to_dir=tmpeggs, download_delay=0) 36 | 37 | import pkg_resources 38 | 39 | if sys.platform == 'win32': 40 | def quote(c): 41 | if ' ' in c: 42 | return '"%s"' % c # work around spawn lamosity on windows 43 | else: 44 | return c 45 | else: 46 | def quote (c): 47 | return c 48 | 49 | cmd = 'from setuptools.command.easy_install import main; main()' 50 | ws = pkg_resources.working_set 51 | 52 | if len(sys.argv) > 2 and sys.argv[1] == '--version': 53 | VERSION = ' == %s' % sys.argv[2] 54 | args = sys.argv[3:] + ['bootstrap'] 55 | else: 56 | VERSION = '' 57 | args = sys.argv[1:] + ['bootstrap'] 58 | 59 | if is_jython: 60 | import subprocess 61 | 62 | assert subprocess.Popen([sys.executable] + ['-c', quote(cmd), '-mqNxd', 63 | quote(tmpeggs), 'zc.buildout' + VERSION], 64 | env=dict(os.environ, 65 | PYTHONPATH= 66 | ws.find(pkg_resources.Requirement.parse('setuptools')).location 67 | ), 68 | ).wait() == 0 69 | 70 | else: 71 | assert os.spawnle( 72 | os.P_WAIT, sys.executable, quote (sys.executable), 73 | '-c', quote (cmd), '-mqNxd', quote (tmpeggs), 'zc.buildout' + VERSION, 74 | dict(os.environ, 75 | PYTHONPATH= 76 | ws.find(pkg_resources.Requirement.parse('setuptools')).location 77 | ), 78 | ) == 0 79 | 80 | ws.add_entry(tmpeggs) 81 | ws.require('zc.buildout' + VERSION) 82 | import zc.buildout.buildout 83 | zc.buildout.buildout.main(args) 84 | shutil.rmtree(tmpeggs) 85 | -------------------------------------------------------------------------------- /piston/store.py: -------------------------------------------------------------------------------- 1 | import oauth 2 | 3 | from models import Nonce, Token, 
Consumer 4 | 5 | class DataStore(oauth.OAuthDataStore): 6 | """Layer between Python OAuth and Django database.""" 7 | def __init__(self, oauth_request): 8 | self.signature = oauth_request.parameters.get('oauth_signature', None) 9 | self.timestamp = oauth_request.parameters.get('oauth_timestamp', None) 10 | self.scope = oauth_request.parameters.get('scope', 'all') 11 | 12 | def lookup_consumer(self, key): 13 | try: 14 | self.consumer = Consumer.objects.get(key=key) 15 | return self.consumer 16 | except Consumer.DoesNotExist: 17 | return None 18 | 19 | def lookup_token(self, token_type, token): 20 | if token_type == 'request': 21 | token_type = Token.REQUEST 22 | elif token_type == 'access': 23 | token_type = Token.ACCESS 24 | try: 25 | self.request_token = Token.objects.get(key=token, 26 | token_type=token_type) 27 | return self.request_token 28 | except Token.DoesNotExist: 29 | return None 30 | 31 | def lookup_nonce(self, oauth_consumer, oauth_token, nonce): 32 | if oauth_token is None: 33 | return None 34 | nonce, created = Nonce.objects.get_or_create(consumer_key=oauth_consumer.key, 35 | token_key=oauth_token.key, 36 | key=nonce) 37 | if created: 38 | return None 39 | else: 40 | return nonce.key 41 | 42 | def fetch_request_token(self, oauth_consumer): 43 | if oauth_consumer.key == self.consumer.key: 44 | self.request_token = Token.objects.create_token(consumer=self.consumer, 45 | token_type=Token.REQUEST, 46 | timestamp=self.timestamp) 47 | return self.request_token 48 | return None 49 | 50 | def fetch_access_token(self, oauth_consumer, oauth_token): 51 | if oauth_consumer.key == self.consumer.key \ 52 | and oauth_token.key == self.request_token.key \ 53 | and self.request_token.is_approved: 54 | self.access_token = Token.objects.create_token(consumer=self.consumer, 55 | token_type=Token.ACCESS, 56 | timestamp=self.timestamp, 57 | user=self.request_token.user) 58 | return self.access_token 59 | return None 60 | 61 | def authorize_request_token(self, oauth_token, 
user): 62 | if oauth_token.key == self.request_token.key: 63 | # authorize the request token in the store 64 | self.request_token.is_approved = True 65 | self.request_token.user = user 66 | self.request_token.save() 67 | return self.request_token 68 | return None -------------------------------------------------------------------------------- /examples/blogserver/settings.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | DEBUG = True 4 | TEMPLATE_DEBUG = DEBUG 5 | 6 | ADMINS = ( 7 | # ('Your Name', 'your_email@domain.com'), 8 | ) 9 | 10 | MANAGERS = ADMINS 11 | 12 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 13 | 14 | # Fix up piston imports here. We would normally place piston in 15 | # a directory accessible via the Django app, but this is an 16 | # example and we ship it a couple of directories up. 17 | sys.path.insert(0, os.path.join(BASE_DIR, '../../')) 18 | 19 | DATABASE_ENGINE = 'sqlite3' # 'postgresql_psycopg2', 'postgresql', 'mysql', 'sqlite3' or 'oracle'. 20 | DATABASE_NAME = os.path.join(BASE_DIR, 'db') # Or path to database file if using sqlite3. 21 | #DATABASE_USER = '' # Not used with sqlite3. 22 | #DATABASE_PASSWORD = '' # Not used with sqlite3. 23 | #DATABASE_HOST = '' # Set to empty string for localhost. Not used with sqlite3. 24 | #DATABASE_PORT = '' # Set to empty string for default. Not used with sqlite3. 25 | 26 | # Local time zone for this installation. Choices can be found here: 27 | # http://en.wikipedia.org/wiki/List_of_tz_zones_by_name 28 | # although not all choices may be available on all operating systems. 29 | # If running in a Windows environment this must be set to the same as your 30 | # system time zone. 31 | TIME_ZONE = 'America/Chicago' 32 | 33 | # Language code for this installation. 
All choices can be found here: 34 | # http://www.i18nguy.com/unicode/language-identifiers.html 35 | LANGUAGE_CODE = 'en-us' 36 | 37 | SITE_ID = 1 38 | 39 | # If you set this to False, Django will make some optimizations so as not 40 | # to load the internationalization machinery. 41 | USE_I18N = True 42 | 43 | # Absolute path to the directory that holds media. 44 | # Example: "/home/media/media.lawrence.com/" 45 | MEDIA_ROOT = '' 46 | 47 | # URL that handles the media served from MEDIA_ROOT. Make sure to use a 48 | # trailing slash if there is a path component (optional in other cases). 49 | # Examples: "http://media.lawrence.com", "http://example.com/media/" 50 | MEDIA_URL = '' 51 | 52 | # URL prefix for admin media -- CSS, JavaScript and images. Make sure to use a 53 | # trailing slash. 54 | # Examples: "http://foo.com/media/", "/media/". 55 | ADMIN_MEDIA_PREFIX = '/media/' 56 | 57 | # Make this unique, and don't share it with anybody. 58 | SECRET_KEY = 'f@vhy8vuq7w70v=cnynm(am1__*zt##i2--i2p-021@-qgws%g' 59 | 60 | # List of callables that know how to import templates from various sources. 
61 | TEMPLATE_LOADERS = ( 62 | 'django.template.loaders.filesystem.load_template_source', 63 | 'django.template.loaders.app_directories.load_template_source', 64 | # 'django.template.loaders.eggs.load_template_source', 65 | ) 66 | 67 | MIDDLEWARE_CLASSES = ( 68 | 'django.middleware.common.CommonMiddleware', 69 | 'django.contrib.sessions.middleware.SessionMiddleware', 70 | 'django.contrib.auth.middleware.AuthenticationMiddleware', 71 | ) 72 | 73 | ROOT_URLCONF = 'blogserver.urls' 74 | 75 | TEMPLATE_DIRS = ( 76 | os.path.join(BASE_DIR, 'templates'), 77 | os.path.join(BASE_DIR, '../../piston/templates'), 78 | ) 79 | 80 | INSTALLED_APPS = ( 81 | 'django.contrib.auth', 82 | 'django.contrib.contenttypes', 83 | 'django.contrib.sessions', 84 | 'django.contrib.sites', 85 | 'django.contrib.admin', 86 | 'django.contrib.markup', 87 | 'blogserver.blog', 88 | 'blogserver.api', 89 | ) 90 | 91 | FIXTURE_DIRS = ( 92 | os.path.join(BASE_DIR, 'fixtures'), 93 | ) 94 | 95 | APPEND_SLASH = False 96 | -------------------------------------------------------------------------------- /piston/handler.py: -------------------------------------------------------------------------------- 1 | from utils import rc 2 | from django.core.exceptions import ObjectDoesNotExist, MultipleObjectsReturned 3 | 4 | typemapper = { } 5 | 6 | class HandlerMetaClass(type): 7 | """ 8 | Metaclass that keeps a registry of class -> handler 9 | mappings. 10 | """ 11 | def __new__(cls, name, bases, attrs): 12 | new_cls = type.__new__(cls, name, bases, attrs) 13 | 14 | if hasattr(new_cls, 'model'): 15 | typemapper[new_cls] = (new_cls.model, new_cls.is_anonymous) 16 | 17 | return new_cls 18 | 19 | class BaseHandler(object): 20 | """ 21 | Basehandler that gives you CRUD for free. 22 | You are supposed to subclass this for specific 23 | functionality. 24 | 25 | All CRUD methods (`read`/`update`/`create`/`delete`) 26 | receive a request as the first argument from the 27 | resource. Use this for checking `request.user`, etc. 
28 | """ 29 | __metaclass__ = HandlerMetaClass 30 | 31 | allowed_methods = ('GET', 'POST', 'PUT', 'DELETE') 32 | anonymous = is_anonymous = False 33 | exclude = ( 'id', ) 34 | fields = ( ) 35 | 36 | def flatten_dict(self, dct): 37 | return dict([ (str(k), dct.get(k)) for k in dct.keys() ]) 38 | 39 | def has_model(self): 40 | return hasattr(self, 'model') 41 | 42 | def value_from_tuple(tu, name): 43 | for int_, n in tu: 44 | if n == name: 45 | return int_ 46 | return None 47 | 48 | def exists(self, **kwargs): 49 | if not self.has_model(): 50 | raise NotImplementedError 51 | 52 | try: 53 | self.model.objects.get(**kwargs) 54 | return True 55 | except self.model.DoesNotExist: 56 | return False 57 | 58 | def read(self, request, *args, **kwargs): 59 | if not self.has_model(): 60 | return rc.NOT_IMPLEMENTED 61 | 62 | pkfield = self.model._meta.pk.name 63 | 64 | if pkfield in kwargs: 65 | try: 66 | return self.model.objects.get(pk=kwargs.get(pkfield)) 67 | except ObjectDoesNotExist: 68 | return rc.NOT_FOUND 69 | except MultipleObjectsReturned: # should never happen, since we're using a PK 70 | return rc.BAD_REQUEST 71 | else: 72 | return self.model.objects.filter(*args, **kwargs) 73 | 74 | def create(self, request, *args, **kwargs): 75 | if not self.has_model(): 76 | return rc.NOT_IMPLEMENTED 77 | 78 | attrs = self.flatten_dict(request.POST) 79 | 80 | try: 81 | inst = self.model.objects.get(**attrs) 82 | return rc.DUPLICATE_ENTRY 83 | except self.model.DoesNotExist: 84 | inst = self.model(**attrs) 85 | inst.save() 86 | return inst 87 | except self.model.MultipleObjectsReturned: 88 | return rc.DUPLICATE_ENTRY 89 | 90 | def update(self, request, *args, **kwargs): 91 | # TODO: This doesn't work automatically yet. 
92 | return rc.NOT_IMPLEMENTED 93 | 94 | def delete(self, request, *args, **kwargs): 95 | if not self.has_model(): 96 | raise NotImplementedError 97 | 98 | try: 99 | inst = self.model.objects.get(*args, **kwargs) 100 | 101 | inst.delete() 102 | 103 | return rc.DELETED 104 | except self.model.MultipleObjectsReturned: 105 | return rc.DUPLICATE_ENTRY 106 | except self.model.DoesNotExist: 107 | return rc.NOT_HERE 108 | 109 | class AnonymousBaseHandler(BaseHandler): 110 | """ 111 | Anonymous handler. 112 | """ 113 | is_anonymous = True 114 | allowed_methods = ('GET',) 115 | -------------------------------------------------------------------------------- /piston/models.py: -------------------------------------------------------------------------------- 1 | import urllib 2 | from django.db import models 3 | from django.contrib.auth.models import User 4 | from django.contrib import admin 5 | from django.conf import settings 6 | from django.core.mail import send_mail, mail_admins 7 | from django.template import loader 8 | 9 | from managers import TokenManager, ConsumerManager, ResourceManager 10 | 11 | KEY_SIZE = 18 12 | SECRET_SIZE = 32 13 | 14 | CONSUMER_STATES = ( 15 | ('pending', 'Pending approval'), 16 | ('accepted', 'Accepted'), 17 | ('canceled', 'Canceled'), 18 | ) 19 | 20 | class Nonce(models.Model): 21 | token_key = models.CharField(max_length=KEY_SIZE) 22 | consumer_key = models.CharField(max_length=KEY_SIZE) 23 | key = models.CharField(max_length=255) 24 | 25 | def __unicode__(self): 26 | return u"Nonce %s for %s" % (self.key, self.consumer_key) 27 | 28 | admin.site.register(Nonce) 29 | 30 | class Resource(models.Model): 31 | name = models.CharField(max_length=255) 32 | url = models.TextField(max_length=2047) 33 | is_readonly = models.BooleanField(default=True) 34 | 35 | objects = ResourceManager() 36 | 37 | def __unicode__(self): 38 | return u"Resource %s with url %s" % (self.name, self.url) 39 | 40 | admin.site.register(Resource) 41 | 42 | class 
Consumer(models.Model): 43 | name = models.CharField(max_length=255) 44 | description = models.TextField() 45 | 46 | key = models.CharField(max_length=KEY_SIZE) 47 | secret = models.CharField(max_length=SECRET_SIZE) 48 | 49 | status = models.CharField(max_length=16, choices=CONSUMER_STATES, default='pending') 50 | user = models.ForeignKey(User, null=True, blank=True, related_name='consumers') 51 | 52 | objects = ConsumerManager() 53 | 54 | def __unicode__(self): 55 | return u"Consumer %s with key %s" % (self.name, self.key) 56 | 57 | def generate_random_codes(self): 58 | key = User.objects.make_random_password(length=KEY_SIZE) 59 | 60 | secret = User.objects.make_random_password(length=SECRET_SIZE) 61 | 62 | while Consumer.objects.filter(key__exact=key, secret__exact=secret).count(): 63 | secret = User.objects.make_random_password(length=SECRET_SIZE) 64 | 65 | self.key = key 66 | self.secret = secret 67 | self.save() 68 | 69 | # -- 70 | 71 | def save(self, **kwargs): 72 | super(Consumer, self).save(**kwargs) 73 | 74 | if self.id and self.user: 75 | subject = "API Consumer" 76 | rcpt = [ self.user.email, ] 77 | 78 | if self.status == "accepted": 79 | template = "api/mails/consumer_accepted.txt" 80 | subject += " was accepted!" 
81 | elif self.status == "canceled": 82 | template = "api/mails/consumer_canceled.txt" 83 | subject += " has been canceled" 84 | else: 85 | template = "api/mails/consumer_pending.txt" 86 | subject += " application received" 87 | 88 | for admin in settings.ADMINS: 89 | bcc.append(admin[1]) 90 | 91 | body = loader.render_to_string(template, 92 | { 'consumer': self, 'user': self.user }) 93 | 94 | send_mail(subject, body, settings.DEFAULT_FROM_EMAIL, 95 | rcpt, fail_silently=True) 96 | 97 | if self.status == 'pending': 98 | mail_admins(subject, body, fail_silently=True) 99 | 100 | if settings.DEBUG: 101 | print "Mail being sent, to=%s" % rcpt 102 | print "Subject: %s" % subject 103 | print body 104 | 105 | admin.site.register(Consumer) 106 | 107 | class Token(models.Model): 108 | REQUEST = 1 109 | ACCESS = 2 110 | TOKEN_TYPES = ((REQUEST, u'Request'), (ACCESS, u'Access')) 111 | 112 | key = models.CharField(max_length=KEY_SIZE) 113 | secret = models.CharField(max_length=SECRET_SIZE) 114 | token_type = models.IntegerField(choices=TOKEN_TYPES) 115 | timestamp = models.IntegerField() 116 | is_approved = models.BooleanField(default=False) 117 | 118 | user = models.ForeignKey(User, null=True, blank=True, related_name='tokens') 119 | consumer = models.ForeignKey(Consumer) 120 | 121 | objects = TokenManager() 122 | 123 | def __unicode__(self): 124 | return u"%s Token %s for %s" % (self.get_token_type_display(), self.key, self.consumer) 125 | 126 | def to_string(self, only_key=False): 127 | token_dict = { 128 | 'oauth_token': self.key, 129 | 'oauth_token_secret': self.secret 130 | } 131 | if only_key: 132 | del token_dict['oauth_token_secret'] 133 | return urllib.urlencode(token_dict) 134 | 135 | def generate_random_codes(self): 136 | key = User.objects.make_random_password(length=KEY_SIZE) 137 | secret = User.objects.make_random_password(length=SECRET_SIZE) 138 | 139 | while Token.objects.filter(key__exact=key, secret__exact=secret).count(): 140 | secret = 
User.objects.make_random_password(length=SECRET_SIZE) 141 | 142 | self.key = key 143 | self.secret = secret 144 | self.save() 145 | 146 | admin.site.register(Token) -------------------------------------------------------------------------------- /piston/doc.py: -------------------------------------------------------------------------------- 1 | import inspect, handler 2 | 3 | from piston.handler import typemapper 4 | 5 | from django.core.urlresolvers import get_resolver, get_callable, get_script_prefix 6 | from django.shortcuts import render_to_response 7 | from django.template import RequestContext 8 | 9 | def generate_doc(handler_cls): 10 | """ 11 | Returns a `HandlerDocumentation` object 12 | for the given handler. Use this to generate 13 | documentation for your API. 14 | """ 15 | if not type(handler_cls) is handler.HandlerMetaClass: 16 | raise ValueError("Give me handler, not %s" % type(handler_cls)) 17 | 18 | return HandlerDocumentation(handler_cls) 19 | 20 | class HandlerMethod(object): 21 | def __init__(self, method, stale=False): 22 | self.method = method 23 | self.stale = stale 24 | 25 | def iter_args(self): 26 | args, _, _, defaults = inspect.getargspec(self.method) 27 | 28 | for idx, arg in enumerate(args): 29 | if arg in ('self', 'request', 'form'): 30 | continue 31 | 32 | didx = len(args)-idx 33 | 34 | if defaults and len(defaults) >= didx: 35 | yield (arg, str(defaults[-didx])) 36 | else: 37 | yield (arg, None) 38 | 39 | def get_signature(self, parse_optional=True): 40 | spec = "" 41 | 42 | for argn, argdef in self.iter_args(): 43 | spec += argn 44 | 45 | if argdef: 46 | spec += '=%s' % argdef 47 | 48 | spec += ', ' 49 | 50 | spec = spec.rstrip(", ") 51 | 52 | if parse_optional: 53 | return spec.replace("=None", "=") 54 | 55 | return spec 56 | 57 | signature = property(get_signature) 58 | 59 | def get_doc(self): 60 | return inspect.getdoc(self.method) 61 | 62 | doc = property(get_doc) 63 | 64 | def get_name(self): 65 | return self.method.__name__ 66 
| 67 | name = property(get_name) 68 | 69 | def __repr__(self): 70 | return "" % self.name 71 | 72 | class HandlerDocumentation(object): 73 | def __init__(self, handler): 74 | self.handler = handler 75 | 76 | def get_methods(self, include_default=False): 77 | for method in "read create update delete".split(): 78 | met = getattr(self.handler, method) 79 | stale = inspect.getmodule(met) is handler 80 | 81 | if not self.handler.is_anonymous: 82 | if met and (not stale or include_default): 83 | yield HandlerMethod(met, stale) 84 | else: 85 | if not stale or met.__name__ == "read" \ 86 | and 'GET' in self.allowed_methods: 87 | 88 | yield HandlerMethod(met, stale) 89 | 90 | def get_all_methods(self): 91 | return self.get_methods(include_default=True) 92 | 93 | @property 94 | def is_anonymous(self): 95 | return handler.is_anonymous 96 | 97 | def get_model(self): 98 | return getattr(self, 'model', None) 99 | 100 | def get_doc(self): 101 | return self.handler.__doc__ 102 | 103 | doc = property(get_doc) 104 | 105 | @property 106 | def name(self): 107 | return self.handler.__name__ 108 | 109 | @property 110 | def allowed_methods(self): 111 | return self.handler.allowed_methods 112 | 113 | def get_resource_uri_template(self): 114 | """ 115 | URI template processor. 
116 | 117 | See http://bitworking.org/projects/URI-Templates/ 118 | """ 119 | def _convert(template, params=[]): 120 | """URI template converter""" 121 | paths = template % dict([p, "{%s}" % p] for p in params) 122 | return u'%s%s' % (get_script_prefix(), paths) 123 | 124 | try: 125 | resource_uri = self.handler.resource_uri() 126 | 127 | components = [None, [], {}] 128 | 129 | for i, value in enumerate(resource_uri): 130 | components[i] = value 131 | 132 | lookup_view, args, kwargs = components 133 | lookup_view = get_callable(lookup_view, True) 134 | 135 | possibilities = get_resolver(None).reverse_dict.getlist(lookup_view) 136 | 137 | for possibility, pattern in possibilities: 138 | for result, params in possibility: 139 | if args: 140 | if len(args) != len(params): 141 | continue 142 | return _convert(result, params) 143 | else: 144 | if set(kwargs.keys()) != set(params): 145 | continue 146 | return _convert(result, params) 147 | except: 148 | return None 149 | 150 | resource_uri_template = property(get_resource_uri_template) 151 | 152 | def __repr__(self): 153 | return u'' % self.name 154 | 155 | def documentation_view(request): 156 | """ 157 | Generic documentation view. Generates documentation 158 | from the handlers you've defined. 159 | """ 160 | docs = [ ] 161 | 162 | for handler, (model, anonymous) in typemapper.iteritems(): 163 | docs.append(generate_doc(handler)) 164 | 165 | return render_to_response('documentation.html', 166 | { 'docs': docs }, RequestContext(request)) 167 | -------------------------------------------------------------------------------- /piston/decorator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Decorator module, see 3 | http://www.phyast.pitt.edu/~micheles/python/documentation.html 4 | for the documentation and below for the licence. 5 | """ 6 | 7 | ## The basic trick is to generate the source code for the decorated function 8 | ## with the right signature and to evaluate it. 
9 | ## Uncomment the statement 'print >> sys.stderr, src' in _decorator 10 | ## to understand what is going on. 11 | 12 | __all__ = ["decorator", "new_wrapper", "getinfo"] 13 | 14 | import inspect, sys 15 | 16 | try: 17 | set 18 | except NameError: 19 | from sets import Set as set 20 | 21 | def getinfo(func): 22 | """ 23 | Returns an info dictionary containing: 24 | - name (the name of the function : str) 25 | - argnames (the names of the arguments : list) 26 | - defaults (the values of the default arguments : tuple) 27 | - signature (the signature : str) 28 | - doc (the docstring : str) 29 | - module (the module name : str) 30 | - dict (the function __dict__ : dict) 31 | 32 | >>> def f(self, x=1, y=2, *args, **kw): pass 33 | 34 | >>> info = getinfo(f) 35 | 36 | >>> info["name"] 37 | 'f' 38 | >>> info["argnames"] 39 | ['self', 'x', 'y', 'args', 'kw'] 40 | 41 | >>> info["defaults"] 42 | (1, 2) 43 | 44 | >>> info["signature"] 45 | 'self, x, y, *args, **kw' 46 | """ 47 | assert inspect.ismethod(func) or inspect.isfunction(func) 48 | regargs, varargs, varkwargs, defaults = inspect.getargspec(func) 49 | argnames = list(regargs) 50 | if varargs: 51 | argnames.append(varargs) 52 | if varkwargs: 53 | argnames.append(varkwargs) 54 | signature = inspect.formatargspec(regargs, varargs, varkwargs, defaults, 55 | formatvalue=lambda value: "")[1:-1] 56 | return dict(name=func.__name__, argnames=argnames, signature=signature, 57 | defaults = func.func_defaults, doc=func.__doc__, 58 | module=func.__module__, dict=func.__dict__, 59 | globals=func.func_globals, closure=func.func_closure) 60 | 61 | # akin to functools.update_wrapper 62 | def update_wrapper(wrapper, model, infodict=None): 63 | infodict = infodict or getinfo(model) 64 | try: 65 | wrapper.__name__ = infodict['name'] 66 | except: # Python version < 2.4 67 | pass 68 | wrapper.__doc__ = infodict['doc'] 69 | wrapper.__module__ = infodict['module'] 70 | wrapper.__dict__.update(infodict['dict']) 71 | wrapper.func_defaults
= infodict['defaults'] 72 | wrapper.undecorated = model 73 | return wrapper 74 | 75 | def new_wrapper(wrapper, model): 76 | """ 77 | An improvement over functools.update_wrapper. The wrapper is a generic 78 | callable object. It works by generating a copy of the wrapper with the 79 | right signature and by updating the copy, not the original. 80 | Moreover, 'model' can be a dictionary with keys 'name', 'doc', 'module', 81 | 'dict', 'defaults', 'argnames' and 'signature'. 82 | """ 83 | if isinstance(model, dict): 84 | infodict = model 85 | else: # assume model is a function 86 | infodict = getinfo(model) 87 | assert not '_wrapper_' in infodict["argnames"], ( 88 | '"_wrapper_" is a reserved argument name!') 89 | src = "lambda %(signature)s: _wrapper_(%(signature)s)" % infodict 90 | funcopy = eval(src, dict(_wrapper_=wrapper)) 91 | return update_wrapper(funcopy, model, infodict) 92 | 93 | # helper used in decorator_factory 94 | def __call__(self, func): 95 | infodict = getinfo(func) 96 | for name in ('_func_', '_self_'): 97 | assert not name in infodict["argnames"], ( 98 | '%s is a reserved argument name!' % name) 99 | src = "lambda %(signature)s: _self_.call(_func_, %(signature)s)" 100 | new = eval(src % infodict, dict(_func_=func, _self_=self)) 101 | return update_wrapper(new, func, infodict) 102 | 103 | def decorator_factory(cls): 104 | """ 105 | Take a class with a ``.call`` method and return a callable decorator 106 | object. It works by adding a suitable __call__ method to the class; 107 | it raises a TypeError if the class already has a nontrivial __call__ 108 | method.
109 | """ 110 | attrs = set(dir(cls)) 111 | if '__call__' in attrs: 112 | raise TypeError('You cannot decorate a class with a nontrivial ' 113 | '__call__ method') 114 | if 'call' not in attrs: 115 | raise TypeError('You cannot decorate a class without a ' 116 | '.call method') 117 | cls.__call__ = __call__ 118 | return cls 119 | 120 | def decorator(caller): 121 | """ 122 | General purpose decorator factory: takes a caller function as 123 | input and returns a decorator with the same attributes. 124 | A caller function is any function like this:: 125 | 126 | def caller(func, *args, **kw): 127 | # do something 128 | return func(*args, **kw) 129 | 130 | Here is an example of usage: 131 | 132 | >>> @decorator 133 | ... def chatty(f, *args, **kw): 134 | ... print "Calling %r" % f.__name__ 135 | ... return f(*args, **kw) 136 | 137 | >>> chatty.__name__ 138 | 'chatty' 139 | 140 | >>> @chatty 141 | ... def f(): pass 142 | ... 143 | >>> f() 144 | Calling 'f' 145 | 146 | decorator can also take as input a class with a .call method; in this 147 | case it converts the class into a factory of callable decorator objects. 148 | See the documentation for an example.
149 | """ 150 | if inspect.isclass(caller): 151 | return decorator_factory(caller) 152 | def _decorator(func): # the real meat is here 153 | infodict = getinfo(func) 154 | argnames = infodict['argnames'] 155 | assert not ('_call_' in argnames or '_func_' in argnames), ( 156 | 'You cannot use _call_ or _func_ as argument names!') 157 | src = "lambda %(signature)s: _call_(_func_, %(signature)s)" % infodict 158 | # import sys; print >> sys.stderr, src # for debugging purposes 159 | dec_func = eval(src, dict(_func_=func, _call_=caller)) 160 | return update_wrapper(dec_func, func, infodict) 161 | return update_wrapper(_decorator, caller) 162 | 163 | if __name__ == "__main__": 164 | import doctest; doctest.testmod() 165 | 166 | ########################## LEGALESE ############################### 167 | 168 | ## Redistributions of source code must retain the above copyright 169 | ## notice, this list of conditions and the following disclaimer. 170 | ## Redistributions in bytecode form must reproduce the above copyright 171 | ## notice, this list of conditions and the following disclaimer in 172 | ## the documentation and/or other materials provided with the 173 | ## distribution. 174 | 175 | ## THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 176 | ## "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 177 | ## LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 178 | ## A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT 179 | ## HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 180 | ## INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 181 | ## BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS 182 | ## OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 183 | ## ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR 184 | ## TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE 185 | ## USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH 186 | ## DAMAGE. 187 | -------------------------------------------------------------------------------- /piston/resource.py: -------------------------------------------------------------------------------- 1 | import sys, inspect 2 | 3 | from django.http import (HttpResponse, Http404, HttpResponseNotAllowed, 4 | HttpResponseForbidden, HttpResponseServerError) 5 | from django.views.debug import ExceptionReporter 6 | from django.views.decorators.vary import vary_on_headers 7 | from django.conf import settings 8 | from django.core.mail import send_mail, EmailMessage 9 | 10 | from emitters import Emitter 11 | from handler import typemapper 12 | from doc import HandlerMethod 13 | from authentication import NoAuthentication 14 | from utils import coerce_put_post, FormValidationError, HttpStatusCode 15 | from utils import rc, format_error, translate_mime, MimerDataException 16 | 17 | class Resource(object): 18 | """ 19 | Resource. Create one for your URL mappings, just 20 | like you would with Django. Takes one argument, 21 | the handler. The second argument is optional, and 22 | is an authentication handler. If not specified, 23 | `NoAuthentication` will be used by default. 
24 | """ 25 | callmap = { 'GET': 'read', 'POST': 'create', 26 | 'PUT': 'update', 'DELETE': 'delete' } 27 | 28 | def __init__(self, handler, authentication=None): 29 | if not callable(handler): 30 | raise AttributeError, "Handler not callable." 31 | 32 | self.handler = handler() 33 | 34 | if not authentication: 35 | self.authentication = NoAuthentication() 36 | else: 37 | self.authentication = authentication 38 | 39 | # Erroring 40 | self.email_errors = getattr(settings, 'PISTON_EMAIL_ERRORS', True) 41 | self.display_errors = getattr(settings, 'PISTON_DISPLAY_ERRORS', True) 42 | self.stream = getattr(settings, 'PISTON_STREAM_OUTPUT', False) 43 | 44 | def determine_emitter(self, request, *args, **kwargs): 45 | """ 46 | Function for determening which emitter to use 47 | for output. It lives here so you can easily subclass 48 | `Resource` in order to change how emission is detected. 49 | 50 | You could also check for the `Accept` HTTP header here, 51 | since that pretty much makes sense. Refer to `Mimer` for 52 | that as well. 53 | """ 54 | em = kwargs.pop('emitter_format', None) 55 | 56 | if not em: 57 | em = request.GET.get('format', 'json') 58 | 59 | return em 60 | 61 | @vary_on_headers('Authorization') 62 | def __call__(self, request, *args, **kwargs): 63 | """ 64 | NB: Sends a `Vary` header so we don't cache requests 65 | that are different (OAuth stuff in `Authorization` header.) 66 | """ 67 | rm = request.method.upper() 68 | 69 | # Django's internal mechanism doesn't pick up 70 | # PUT request, so we trick it a little here. 
71 | if rm == "PUT": 72 | coerce_put_post(request) 73 | 74 | if not self.authentication.is_authenticated(request): 75 | if hasattr(self.handler, 'anonymous') and \ 76 | callable(self.handler.anonymous) and \ 77 | rm in self.handler.anonymous.allowed_methods: 78 | 79 | handler = self.handler.anonymous() 80 | anonymous = True 81 | else: 82 | return self.authentication.challenge() 83 | else: 84 | handler = self.handler 85 | anonymous = handler.is_anonymous 86 | 87 | # Translate nested datastructs into `request.data` here. 88 | if rm in ('POST', 'PUT'): 89 | try: 90 | translate_mime(request) 91 | except MimerDataException: 92 | return rc.BAD_REQUEST 93 | 94 | if not rm in handler.allowed_methods: 95 | return HttpResponseNotAllowed(handler.allowed_methods) 96 | 97 | meth = getattr(handler, self.callmap.get(rm), None) 98 | 99 | if not meth: 100 | raise Http404 101 | 102 | # Support emitter both through (?P) and ?format=emitter. 103 | em_format = self.determine_emitter(request, *args, **kwargs) 104 | 105 | kwargs.pop('emitter_format', None) 106 | 107 | # Clean up the request object a bit, since we might 108 | # very well have `oauth_`-headers in there, and we 109 | # don't want to pass these along to the handler. 110 | request = self.cleanup_request(request) 111 | 112 | try: 113 | result = meth(request, *args, **kwargs) 114 | except FormValidationError, e: 115 | # TODO: Use rc.BAD_REQUEST here 116 | return HttpResponse("Bad Request: %s" % e.form.errors, status=400) 117 | except TypeError, e: 118 | result = rc.BAD_REQUEST 119 | hm = HandlerMethod(meth) 120 | sig = hm.get_signature() 121 | 122 | msg = 'Method signature does not match.\n\n' 123 | 124 | if sig: 125 | msg += 'Signature should be: %s' % sig 126 | else: 127 | msg += 'Resource does not expect any parameters.' 
128 | 129 | if self.display_errors: 130 | msg += '\n\nException was: %s' % str(e) 131 | 132 | result.content = format_error(msg) 133 | except HttpStatusCode, e: 134 | #result = e ## why is this being passed on and not just dealt with now? 135 | return e.response 136 | except Exception, e: 137 | """ 138 | On errors (like code errors), we'd like to be able to 139 | give crash reports to both admins and also the calling 140 | user. There's two setting parameters for this: 141 | 142 | Parameters:: 143 | - `PISTON_EMAIL_ERRORS`: Will send a Django formatted 144 | error email to people in `settings.ADMINS`. 145 | - `PISTON_DISPLAY_ERRORS`: Will return a simple traceback 146 | to the caller, so he can tell you what error they got. 147 | 148 | If `PISTON_DISPLAY_ERRORS` is not enabled, the caller will 149 | receive a basic "500 Internal Server Error" message. 150 | """ 151 | exc_type, exc_value, tb = sys.exc_info() 152 | rep = ExceptionReporter(request, exc_type, exc_value, tb.tb_next) 153 | if self.email_errors: 154 | self.email_exception(rep) 155 | if self.display_errors: 156 | return HttpResponseServerError( 157 | format_error('\n'.join(rep.format_exception()))) 158 | else: 159 | raise 160 | 161 | emitter, ct = Emitter.get(em_format) 162 | srl = emitter(result, typemapper, handler, handler.fields, anonymous) 163 | 164 | try: 165 | """ 166 | Decide whether or not we want a generator here, 167 | or we just want to buffer up the entire result 168 | before sending it to the client. Won't matter for 169 | smaller datasets, but larger will have an impact. 
170 | """ 171 | if self.stream: stream = srl.stream_render(request) 172 | else: stream = srl.render(request) 173 | 174 | resp = HttpResponse(stream, mimetype=ct) 175 | 176 | resp.streaming = self.stream 177 | 178 | return resp 179 | except HttpStatusCode, e: 180 | return e.response 181 | 182 | @staticmethod 183 | def cleanup_request(request): 184 | """ 185 | Removes `oauth_` keys from various dicts on the 186 | request object, and returns the sanitized version. 187 | """ 188 | for method_type in ('GET', 'PUT', 'POST', 'DELETE'): 189 | block = getattr(request, method_type, { }) 190 | 191 | if True in [ k.startswith("oauth_") for k in block.keys() ]: 192 | sanitized = block.copy() 193 | 194 | for k in sanitized.keys(): 195 | if k.startswith("oauth_"): 196 | sanitized.pop(k) 197 | 198 | setattr(request, method_type, sanitized) 199 | 200 | return request 201 | 202 | # -- 203 | 204 | def email_exception(self, reporter): 205 | subject = "Piston crash report" 206 | html = reporter.get_traceback_html() 207 | 208 | message = EmailMessage(settings.EMAIL_SUBJECT_PREFIX+subject, 209 | html, settings.SERVER_EMAIL, 210 | [ admin[1] for admin in settings.ADMINS ]) 211 | 212 | message.content_subtype = 'html' 213 | message.send(fail_silently=True) 214 | -------------------------------------------------------------------------------- /piston/utils.py: -------------------------------------------------------------------------------- 1 | from django.http import HttpResponseNotAllowed, HttpResponseForbidden, HttpResponse, HttpResponseBadRequest 2 | from django.core.urlresolvers import reverse 3 | from django.core.cache import cache 4 | from django import get_version as django_version 5 | from decorator import decorator 6 | 7 | from datetime import datetime, timedelta 8 | 9 | __version__ = '0.2.2' 10 | 11 | def get_version(): 12 | return __version__ 13 | 14 | def format_error(error): 15 | return u"Piston/%s (Django %s) crash report:\n\n%s" % \ 16 | (get_version(), django_version(), 
error) 17 | 18 | class rc_factory(object): 19 | """ 20 | Status codes. 21 | """ 22 | CODES = dict(ALL_OK = ('OK', 200), 23 | CREATED = ('Created', 201), 24 | DELETED = ('', 204), # 204 says "Don't send a body!" 25 | BAD_REQUEST = ('Bad Request', 400), 26 | FORBIDDEN = ('Forbidden', 401), 27 | NOT_FOUND = ('Not Found', 404), 28 | DUPLICATE_ENTRY = ('Conflict/Duplicate', 409), 29 | NOT_HERE = ('Gone', 410), 30 | NOT_IMPLEMENTED = ('Not Implemented', 501), 31 | THROTTLED = ('Throttled', 503)) 32 | 33 | def __getattr__(self, attr): 34 | """ 35 | Returns a fresh `HttpResponse` when getting 36 | an "attribute". This is backwards compatible 37 | with 0.2, which is important. 38 | """ 39 | try: 40 | (r, c) = self.CODES.get(attr) 41 | except TypeError: 42 | raise AttributeError(attr) 43 | 44 | return HttpResponse(r, content_type='text/plain', status=c) 45 | 46 | rc = rc_factory() 47 | 48 | class FormValidationError(Exception): 49 | def __init__(self, form): 50 | self.form = form 51 | 52 | class HttpStatusCode(Exception): 53 | def __init__(self, response): 54 | self.response = response 55 | 56 | def validate(v_form, operation='POST'): 57 | @decorator 58 | def wrap(f, self, request, *a, **kwa): 59 | form = v_form(getattr(request, operation)) 60 | 61 | if form.is_valid(): 62 | return f(self, request, *a, **kwa) 63 | else: 64 | raise FormValidationError(form) 65 | return wrap 66 | 67 | def throttle(max_requests, timeout=60*60, extra=''): 68 | """ 69 | Simple throttling decorator, caches 70 | the amount of requests made in cache. 71 | 72 | If used on a view where users are required to 73 | log in, the username is used, otherwise the 74 | IP address of the originating request is used. 
75 | 76 | Parameters:: 77 | - `max_requests`: The maximum number of requests 78 | - `timeout`: The timeout for the cache entry (default: 1 hour) 79 | """ 80 | @decorator 81 | def wrap(f, self, request, *args, **kwargs): 82 | if request.user.is_authenticated(): 83 | ident = request.user.username 84 | else: 85 | ident = request.META.get('REMOTE_ADDR', None) 86 | 87 | if hasattr(request, 'throttle_extra'): 88 | """ 89 | Since we want to be able to throttle on a per- 90 | application basis, it's important that we realize 91 | that `throttle_extra` might be set on the request 92 | object. If so, append the identifier name with it. 93 | """ 94 | ident += ':%s' % str(request.throttle_extra) 95 | 96 | if ident: 97 | """ 98 | Preferrably we'd use incr/decr here, since they're 99 | atomic in memcached, but it's in django-trunk so we 100 | can't use it yet. If someone sees this after it's in 101 | stable, you can change it here. 102 | """ 103 | ident += ':%s' % extra 104 | 105 | now = datetime.now() 106 | ts_key = 'throttle:ts:%s' % ident 107 | timestamp = cache.get(ts_key) 108 | offset = now + timedelta(seconds=timeout) 109 | 110 | if timestamp and timestamp < offset: 111 | t = rc.THROTTLED 112 | wait = timeout - (offset-timestamp).seconds 113 | t.content = 'Throttled, wait %d seconds.' % wait 114 | 115 | return t 116 | 117 | count = cache.get(ident, 1) 118 | cache.set(ident, count+1) 119 | 120 | if count >= max_requests: 121 | cache.set(ts_key, offset, timeout) 122 | cache.set(ident, 1) 123 | 124 | return f(self, request, *args, **kwargs) 125 | return wrap 126 | 127 | def coerce_put_post(request): 128 | """ 129 | Django doesn't particularly understand REST. 130 | In case we send data over PUT, Django won't 131 | actually look at the data and load it. We need 132 | to twist its arm here. 133 | 134 | The try/except abominiation here is due to a bug 135 | in mod_python. This should fix it. 
136 | """ 137 | if request.method == "PUT": 138 | try: 139 | request.method = "POST" 140 | request._load_post_and_files() 141 | request.method = "PUT" 142 | except AttributeError: 143 | request.META['REQUEST_METHOD'] = 'POST' 144 | request._load_post_and_files() 145 | request.META['REQUEST_METHOD'] = 'PUT' 146 | 147 | request.PUT = request.POST 148 | 149 | 150 | class MimerDataException(Exception): 151 | """ 152 | Raised if the content_type and data don't match 153 | """ 154 | pass 155 | 156 | class Mimer(object): 157 | TYPES = dict() 158 | 159 | def __init__(self, request): 160 | self.request = request 161 | 162 | def is_multipart(self): 163 | content_type = self.content_type() 164 | if content_type is not None: 165 | return content_type.lstrip().startswith('multipart') 166 | return False 167 | 168 | def loader_for_type(self, ctype): 169 | """ 170 | Gets a function ref to deserialize content 171 | for a certain mimetype. 172 | """ 173 | for loadee, mimes in Mimer.TYPES.iteritems(): 174 | if ctype in mimes: 175 | return loadee 176 | 177 | def content_type(self): 178 | """ 179 | Returns the content type of the request in all cases where it is 180 | different than a submitted form - application/x-www-form-urlencoded 181 | """ 182 | type_formencoded = "application/x-www-form-urlencoded" 183 | 184 | ctype = self.request.META.get('CONTENT_TYPE', type_formencoded) 185 | 186 | if ctype == type_formencoded: 187 | return None 188 | 189 | return ctype 190 | 191 | 192 | def translate(self): 193 | """ 194 | Will look at the `Content-type` sent by the client, and maybe 195 | deserialize the contents into the format they sent. This will 196 | work for JSON, YAML, XML and Pickle. Since the data is not just 197 | key-value (and maybe just a list), the data will be placed on 198 | `request.data` instead, and the handler will have to read from 199 | there. 200 | 201 | It will also set `request.content_type` so the handler has an easy 202 | way to tell what's going on. 
`request.content_type` will always be 203 | None for form-encoded and/or multipart form data (what your browser sends.) 204 | """ 205 | ctype = self.content_type() 206 | self.request.content_type = ctype 207 | 208 | if not self.is_multipart() and ctype: 209 | loadee = self.loader_for_type(ctype) 210 | 211 | try: 212 | self.request.data = loadee(self.request.raw_post_data) 213 | 214 | # Reset both POST and PUT from request, as its 215 | # misleading having their presence around. 216 | self.request.POST = self.request.PUT = dict() 217 | except (TypeError, ValueError): 218 | raise MimerDataException 219 | 220 | return self.request 221 | 222 | @classmethod 223 | def register(cls, loadee, types): 224 | cls.TYPES[loadee] = types 225 | 226 | @classmethod 227 | def unregister(cls, loadee): 228 | return cls.TYPES.pop(loadee) 229 | 230 | def translate_mime(request): 231 | request = Mimer(request).translate() 232 | 233 | def require_mime(*mimes): 234 | """ 235 | Decorator requiring a certain mimetype. There's a nifty 236 | helper called `require_extended` below which requires everything 237 | we support except for post-data via form. 
238 | """ 239 | @decorator 240 | def wrap(f, self, request, *args, **kwargs): 241 | m = Mimer(request) 242 | realmimes = set() 243 | 244 | rewrite = { 'json': 'application/json', 245 | 'yaml': 'application/x-yaml', 246 | 'xml': 'text/xml', 247 | 'pickle': 'application/python-pickle' } 248 | 249 | for idx, mime in enumerate(mimes): 250 | realmimes.add(rewrite.get(mime, mime)) 251 | 252 | if not m.content_type() in realmimes: 253 | return rc.BAD_REQUEST 254 | 255 | return f(self, request, *args, **kwargs) 256 | return wrap 257 | 258 | require_extended = require_mime('json', 'yaml', 'xml', 'pickle') 259 | 260 | -------------------------------------------------------------------------------- /piston/authentication.py: -------------------------------------------------------------------------------- 1 | import oauth 2 | from django.http import HttpResponse, HttpResponseRedirect 3 | from django.contrib.auth.models import User, AnonymousUser 4 | from django.contrib.auth.decorators import login_required 5 | from django.template import loader 6 | from django.contrib.auth import authenticate 7 | from django.conf import settings 8 | from django.core.urlresolvers import get_callable 9 | from django.core.exceptions import ImproperlyConfigured 10 | from django.shortcuts import render_to_response 11 | from django.template import RequestContext 12 | 13 | from piston import forms 14 | 15 | class NoAuthentication(object): 16 | """ 17 | Authentication handler that always returns 18 | True, so no authentication is needed, nor 19 | initiated (`challenge` is missing.) 20 | """ 21 | def is_authenticated(self, request): 22 | return True 23 | 24 | class HttpBasicAuthentication(object): 25 | """ 26 | Basic HTTP authenticater. Synopsis: 27 | 28 | Authentication handlers must implement two methods: 29 | - `is_authenticated`: Will be called when checking for 30 | authentication. 
Receives a `request` object, please 31 | set your `User` object on `request.user`, otherwise 32 | return False (or something that evaluates to False.) 33 | - `challenge`: In cases where `is_authenticated` returns 34 | False, the result of this method will be returned. 35 | This will usually be a `HttpResponse` object with 36 | some kind of challenge headers and 401 code on it. 37 | """ 38 | def __init__(self, auth_func=authenticate, realm='API'): 39 | self.auth_func = auth_func 40 | self.realm = realm 41 | 42 | def is_authenticated(self, request): 43 | auth_string = request.META.get('HTTP_AUTHORIZATION', None) 44 | 45 | if not auth_string: 46 | return False 47 | 48 | (authmeth, auth) = auth_string.split(" ", 1) 49 | 50 | if not authmeth.lower() == 'basic': 51 | return False 52 | 53 | auth = auth.strip().decode('base64') 54 | (username, password) = auth.split(':', 1) 55 | 56 | request.user = self.auth_func(username=username, password=password) \ 57 | or AnonymousUser() 58 | 59 | return not request.user in (False, None, AnonymousUser()) 60 | 61 | def challenge(self): 62 | resp = HttpResponse("Authorization Required") 63 | resp['WWW-Authenticate'] = 'Basic realm="%s"' % self.realm 64 | resp.status_code = 401 65 | return resp 66 | 67 | def load_data_store(): 68 | '''Load data store for OAuth Consumers, Tokens, Nonces and Resources 69 | ''' 70 | path = getattr(settings, 'OAUTH_DATA_STORE', 'piston.store.DataStore') 71 | 72 | # stolen from django.contrib.auth.load_backend 73 | i = path.rfind('.') 74 | module, attr = path[:i], path[i+1:] 75 | 76 | try: 77 | mod = __import__(module, {}, {}, attr) 78 | except ImportError, e: 79 | raise ImproperlyConfigured, 'Error importing OAuth data store %s: "%s"' % (module, e) 80 | 81 | try: 82 | cls = getattr(mod, attr) 83 | except AttributeError: 84 | raise ImproperlyConfigured, 'Module %s does not define a "%s" OAuth data store' % (module, attr) 85 | 86 | return cls 87 | 88 | # Set the datastore here. 
89 | oauth_datastore = load_data_store() 90 | 91 | def initialize_server_request(request): 92 | """ 93 | Shortcut for initialization. 94 | """ 95 | oauth_request = oauth.OAuthRequest.from_request( 96 | request.method, request.build_absolute_uri(), 97 | headers=request.META, parameters=dict(request.REQUEST.items()), 98 | query_string=request.environ.get('QUERY_STRING', '')) 99 | 100 | if oauth_request: 101 | oauth_server = oauth.OAuthServer(oauth_datastore(oauth_request)) 102 | oauth_server.add_signature_method(oauth.OAuthSignatureMethod_PLAINTEXT()) 103 | oauth_server.add_signature_method(oauth.OAuthSignatureMethod_HMAC_SHA1()) 104 | else: 105 | oauth_server = None 106 | 107 | return oauth_server, oauth_request 108 | 109 | def send_oauth_error(err=None): 110 | """ 111 | Shortcut for sending an error. 112 | """ 113 | response = HttpResponse(err.message.encode('utf-8')) 114 | response.status_code = 401 115 | 116 | realm = 'OAuth' 117 | header = oauth.build_authenticate_header(realm=realm) 118 | 119 | for k, v in header.iteritems(): 120 | response[k] = v 121 | 122 | return response 123 | 124 | def oauth_request_token(request): 125 | oauth_server, oauth_request = initialize_server_request(request) 126 | 127 | if oauth_server is None: 128 | return INVALID_PARAMS_RESPONSE 129 | try: 130 | token = oauth_server.fetch_request_token(oauth_request) 131 | 132 | response = HttpResponse(token.to_string()) 133 | except oauth.OAuthError, err: 134 | response = send_oauth_error(err) 135 | 136 | return response 137 | 138 | def oauth_auth_view(request, token, callback, params): 139 | form = forms.OAuthAuthenticationForm(initial={ 140 | 'oauth_token': token.key, 141 | 'oauth_callback': callback, 142 | }) 143 | 144 | return render_to_response('piston/authorize_token.html', 145 | { 'form': form }, RequestContext(request)) 146 | 147 | @login_required 148 | def oauth_user_auth(request): 149 | oauth_server, oauth_request = initialize_server_request(request) 150 | 151 | if oauth_request is 
None: 152 | return INVALID_PARAMS_RESPONSE 153 | 154 | try: 155 | token = oauth_server.fetch_request_token(oauth_request) 156 | except oauth.OAuthError, err: 157 | return send_oauth_error(err) 158 | 159 | try: 160 | callback = oauth_server.get_callback(oauth_request) 161 | except: 162 | callback = None 163 | 164 | if request.method == "GET": 165 | params = oauth_request.get_normalized_parameters() 166 | 167 | oauth_view = getattr(settings, 'OAUTH_AUTH_VIEW', None) 168 | if oauth_view is None: 169 | return oauth_auth_view(request, token, callback, params) 170 | else: 171 | return get_callable(oauth_view)(request, token, callback, params) 172 | elif request.method == "POST": 173 | try: 174 | form = forms.OAuthAuthenticationForm(request.POST) 175 | if form.is_valid(): 176 | token = oauth_server.authorize_token(token, request.user) 177 | args = '?'+token.to_string(only_key=True) 178 | else: 179 | args = '?error=%s' % 'Access not granted by user.' 180 | 181 | if not callback: 182 | callback = getattr(settings, 'OAUTH_CALLBACK_VIEW') 183 | return get_callable(callback)(request, token) 184 | 185 | response = HttpResponseRedirect(callback+args) 186 | 187 | except oauth.OAuthError, err: 188 | response = send_oauth_error(err) 189 | else: 190 | response = HttpResponse('Action not allowed.') 191 | 192 | return response 193 | 194 | def oauth_access_token(request): 195 | oauth_server, oauth_request = initialize_server_request(request) 196 | 197 | if oauth_request is None: 198 | return INVALID_PARAMS_RESPONSE 199 | 200 | try: 201 | token = oauth_server.fetch_access_token(oauth_request) 202 | return HttpResponse(token.to_string()) 203 | except oauth.OAuthError, err: 204 | return send_oauth_error(err) 205 | 206 | INVALID_PARAMS_RESPONSE = send_oauth_error(oauth.OAuthError('Invalid request parameters.')) 207 | 208 | class OAuthAuthentication(object): 209 | """ 210 | OAuth authentication. Based on work by Leah Culver. 
211 | """ 212 | def __init__(self, realm='API'): 213 | self.realm = realm 214 | self.builder = oauth.build_authenticate_header 215 | 216 | def is_authenticated(self, request): 217 | """ 218 | Checks whether a means of specifying authentication 219 | is provided, and if so, if it is a valid token. 220 | 221 | Read the documentation on `HttpBasicAuthentication` 222 | for more information about what goes on here. 223 | """ 224 | if self.is_valid_request(request): 225 | try: 226 | consumer, token, parameters = self.validate_token(request) 227 | except oauth.OAuthError, err: 228 | print send_oauth_error(err) 229 | return False 230 | 231 | if consumer and token: 232 | request.user = token.user 233 | request.throttle_extra = token.consumer.id 234 | return True 235 | 236 | return False 237 | 238 | def challenge(self): 239 | """ 240 | Returns a 401 response with a small bit on 241 | what OAuth is, and where to learn more about it. 242 | 243 | When this was written, browsers did not understand 244 | OAuth authentication on the browser side, and hence 245 | the helpful template we render. Maybe some day in the 246 | future, browsers will take care of this stuff for us 247 | and understand the 401 with the realm we give it. 248 | """ 249 | response = HttpResponse() 250 | response.status_code = 401 251 | realm = 'API' 252 | 253 | for k, v in self.builder(realm=realm).iteritems(): 254 | response[k] = v 255 | 256 | tmpl = loader.render_to_string('oauth/challenge.html', 257 | { 'MEDIA_URL': settings.MEDIA_URL }) 258 | 259 | response.content = tmpl 260 | 261 | return response 262 | 263 | @staticmethod 264 | def is_valid_request(request): 265 | """ 266 | Checks whether the required parameters are either in 267 | the http-authorization header sent by some clients, 268 | which is by the way the preferred method according to 269 | OAuth spec, but otherwise fall back to `GET` and `POST`. 
270 | """ 271 | must_have = [ 'oauth_'+s for s in [ 272 | 'consumer_key', 'token', 'signature', 273 | 'signature_method', 'timestamp', 'nonce' ] ] 274 | 275 | is_in = lambda l: all([ (p in l) for p in must_have ]) 276 | 277 | auth_params = request.META.get("HTTP_AUTHORIZATION", "") 278 | req_params = request.REQUEST 279 | 280 | return is_in(auth_params) or is_in(req_params) 281 | 282 | @staticmethod 283 | def validate_token(request, check_timestamp=True, check_nonce=True): 284 | oauth_server, oauth_request = initialize_server_request(request) 285 | return oauth_server.verify_request(oauth_request) 286 | 287 | -------------------------------------------------------------------------------- /ez_setup.py: -------------------------------------------------------------------------------- 1 | #!python 2 | """Bootstrap setuptools installation 3 | 4 | If you want to use setuptools in your package's setup.py, just include this 5 | file in the same directory with it, and add this to the top of your setup.py:: 6 | 7 | from ez_setup import use_setuptools 8 | use_setuptools() 9 | 10 | If you want to require a specific version of setuptools, set a download 11 | mirror, or use an alternate download directory, you can do so by supplying 12 | the appropriate options to ``use_setuptools()``. 13 | 14 | This file can also be run as a script to install or upgrade setuptools. 
15 | """ 16 | import sys 17 | DEFAULT_VERSION = "0.6c9" 18 | DEFAULT_URL = "http://pypi.python.org/packages/%s/s/setuptools/" % sys.version[:3] 19 | 20 | md5_data = { 21 | 'setuptools-0.6b1-py2.3.egg': '8822caf901250d848b996b7f25c6e6ca', 22 | 'setuptools-0.6b1-py2.4.egg': 'b79a8a403e4502fbb85ee3f1941735cb', 23 | 'setuptools-0.6b2-py2.3.egg': '5657759d8a6d8fc44070a9d07272d99b', 24 | 'setuptools-0.6b2-py2.4.egg': '4996a8d169d2be661fa32a6e52e4f82a', 25 | 'setuptools-0.6b3-py2.3.egg': 'bb31c0fc7399a63579975cad9f5a0618', 26 | 'setuptools-0.6b3-py2.4.egg': '38a8c6b3d6ecd22247f179f7da669fac', 27 | 'setuptools-0.6b4-py2.3.egg': '62045a24ed4e1ebc77fe039aa4e6f7e5', 28 | 'setuptools-0.6b4-py2.4.egg': '4cb2a185d228dacffb2d17f103b3b1c4', 29 | 'setuptools-0.6c1-py2.3.egg': 'b3f2b5539d65cb7f74ad79127f1a908c', 30 | 'setuptools-0.6c1-py2.4.egg': 'b45adeda0667d2d2ffe14009364f2a4b', 31 | 'setuptools-0.6c2-py2.3.egg': 'f0064bf6aa2b7d0f3ba0b43f20817c27', 32 | 'setuptools-0.6c2-py2.4.egg': '616192eec35f47e8ea16cd6a122b7277', 33 | 'setuptools-0.6c3-py2.3.egg': 'f181fa125dfe85a259c9cd6f1d7b78fa', 34 | 'setuptools-0.6c3-py2.4.egg': 'e0ed74682c998bfb73bf803a50e7b71e', 35 | 'setuptools-0.6c3-py2.5.egg': 'abef16fdd61955514841c7c6bd98965e', 36 | 'setuptools-0.6c4-py2.3.egg': 'b0b9131acab32022bfac7f44c5d7971f', 37 | 'setuptools-0.6c4-py2.4.egg': '2a1f9656d4fbf3c97bf946c0a124e6e2', 38 | 'setuptools-0.6c4-py2.5.egg': '8f5a052e32cdb9c72bcf4b5526f28afc', 39 | 'setuptools-0.6c5-py2.3.egg': 'ee9fd80965da04f2f3e6b3576e9d8167', 40 | 'setuptools-0.6c5-py2.4.egg': 'afe2adf1c01701ee841761f5bcd8aa64', 41 | 'setuptools-0.6c5-py2.5.egg': 'a8d3f61494ccaa8714dfed37bccd3d5d', 42 | 'setuptools-0.6c6-py2.3.egg': '35686b78116a668847237b69d549ec20', 43 | 'setuptools-0.6c6-py2.4.egg': '3c56af57be3225019260a644430065ab', 44 | 'setuptools-0.6c6-py2.5.egg': 'b2f8a7520709a5b34f80946de5f02f53', 45 | 'setuptools-0.6c7-py2.3.egg': '209fdf9adc3a615e5115b725658e13e2', 46 | 'setuptools-0.6c7-py2.4.egg': 
'5a8f954807d46a0fb67cf1f26c55a82e', 47 | 'setuptools-0.6c7-py2.5.egg': '45d2ad28f9750e7434111fde831e8372', 48 | 'setuptools-0.6c8-py2.3.egg': '50759d29b349db8cfd807ba8303f1902', 49 | 'setuptools-0.6c8-py2.4.egg': 'cba38d74f7d483c06e9daa6070cce6de', 50 | 'setuptools-0.6c8-py2.5.egg': '1721747ee329dc150590a58b3e1ac95b', 51 | 'setuptools-0.6c9-py2.3.egg': 'a83c4020414807b496e4cfbe08507c03', 52 | 'setuptools-0.6c9-py2.4.egg': '260a2be2e5388d66bdaee06abec6342a', 53 | 'setuptools-0.6c9-py2.5.egg': 'fe67c3e5a17b12c0e7c541b7ea43a8e6', 54 | 'setuptools-0.6c9-py2.6.egg': 'ca37b1ff16fa2ede6e19383e7b59245a', 55 | } 56 | 57 | import sys, os 58 | try: from hashlib import md5 59 | except ImportError: from md5 import md5 60 | 61 | def _validate_md5(egg_name, data): 62 | if egg_name in md5_data: 63 | digest = md5(data).hexdigest() 64 | if digest != md5_data[egg_name]: 65 | print >>sys.stderr, ( 66 | "md5 validation of %s failed! (Possible download problem?)" 67 | % egg_name 68 | ) 69 | sys.exit(2) 70 | return data 71 | 72 | def use_setuptools( 73 | version=DEFAULT_VERSION, download_base=DEFAULT_URL, to_dir=os.curdir, 74 | download_delay=15 75 | ): 76 | """Automatically find/download setuptools and make it available on sys.path 77 | 78 | `version` should be a valid setuptools version number that is available 79 | as an egg for download under the `download_base` URL (which should end with 80 | a '/'). `to_dir` is the directory where setuptools will be downloaded, if 81 | it is not already available. If `download_delay` is specified, it should 82 | be the number of seconds that will be paused before initiating a download, 83 | should one be required. If an older version of setuptools is installed, 84 | this routine will print a message to ``sys.stderr`` and raise SystemExit in 85 | an attempt to abort the calling script. 
86 | """ 87 | was_imported = 'pkg_resources' in sys.modules or 'setuptools' in sys.modules 88 | def do_download(): 89 | egg = download_setuptools(version, download_base, to_dir, download_delay) 90 | sys.path.insert(0, egg) 91 | import setuptools; setuptools.bootstrap_install_from = egg 92 | try: 93 | import pkg_resources 94 | except ImportError: 95 | return do_download() 96 | try: 97 | pkg_resources.require("setuptools>="+version); return 98 | except pkg_resources.VersionConflict, e: 99 | if was_imported: 100 | print >>sys.stderr, ( 101 | "The required version of setuptools (>=%s) is not available, and\n" 102 | "can't be installed while this script is running. Please install\n" 103 | " a more recent version first, using 'easy_install -U setuptools'." 104 | "\n\n(Currently using %r)" 105 | ) % (version, e.args[0]) 106 | sys.exit(2) 107 | else: 108 | del pkg_resources, sys.modules['pkg_resources'] # reload ok 109 | return do_download() 110 | except pkg_resources.DistributionNotFound: 111 | return do_download() 112 | 113 | def download_setuptools( 114 | version=DEFAULT_VERSION, download_base=DEFAULT_URL, to_dir=os.curdir, 115 | delay = 15 116 | ): 117 | """Download setuptools from a specified location and return its filename 118 | 119 | `version` should be a valid setuptools version number that is available 120 | as an egg for download under the `download_base` URL (which should end 121 | with a '/'). `to_dir` is the directory where the egg will be downloaded. 122 | `delay` is the number of seconds to pause before an actual download attempt. 
123 | """ 124 | import urllib2, shutil 125 | egg_name = "setuptools-%s-py%s.egg" % (version,sys.version[:3]) 126 | url = download_base + egg_name 127 | saveto = os.path.join(to_dir, egg_name) 128 | src = dst = None 129 | if not os.path.exists(saveto): # Avoid repeated downloads 130 | try: 131 | from distutils import log 132 | if delay: 133 | log.warn(""" 134 | --------------------------------------------------------------------------- 135 | This script requires setuptools version %s to run (even to display 136 | help). I will attempt to download it for you (from 137 | %s), but 138 | you may need to enable firewall access for this script first. 139 | I will start the download in %d seconds. 140 | 141 | (Note: if this machine does not have network access, please obtain the file 142 | 143 | %s 144 | 145 | and place it in this directory before rerunning this script.) 146 | ---------------------------------------------------------------------------""", 147 | version, download_base, delay, url 148 | ); from time import sleep; sleep(delay) 149 | log.warn("Downloading %s", url) 150 | src = urllib2.urlopen(url) 151 | # Read/write all in one block, so we don't create a corrupt file 152 | # if the download is interrupted. 
153 | data = _validate_md5(egg_name, src.read()) 154 | dst = open(saveto,"wb"); dst.write(data) 155 | finally: 156 | if src: src.close() 157 | if dst: dst.close() 158 | return os.path.realpath(saveto) 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | def main(argv, version=DEFAULT_VERSION): 196 | """Install or upgrade setuptools and EasyInstall""" 197 | try: 198 | import setuptools 199 | except ImportError: 200 | egg = None 201 | try: 202 | egg = download_setuptools(version, delay=0) 203 | sys.path.insert(0,egg) 204 | from setuptools.command.easy_install import main 205 | return main(list(argv)+[egg]) # we're done here 206 | finally: 207 | if egg and os.path.exists(egg): 208 | os.unlink(egg) 209 | else: 210 | if setuptools.__version__ == '0.0.1': 211 | print >>sys.stderr, ( 212 | "You have an obsolete version of setuptools installed. Please\n" 213 | "remove it from your system entirely before rerunning this script." 214 | ) 215 | sys.exit(2) 216 | 217 | req = "setuptools>="+version 218 | import pkg_resources 219 | try: 220 | pkg_resources.require(req) 221 | except pkg_resources.VersionConflict: 222 | try: 223 | from setuptools.command.easy_install import main 224 | except ImportError: 225 | from easy_install import main 226 | main(list(argv)+[download_setuptools(delay=0)]) 227 | sys.exit(0) # try to force an exit 228 | else: 229 | if argv: 230 | from setuptools.command.easy_install import main 231 | main(argv) 232 | else: 233 | print "Setuptools version",version,"or greater has been installed." 
234 | print '(Run "ez_setup.py -U setuptools" to reinstall or upgrade.)' 235 | 236 | def update_md5(filenames): 237 | """Update our built-in md5 registry""" 238 | 239 | import re 240 | 241 | for name in filenames: 242 | base = os.path.basename(name) 243 | f = open(name,'rb') 244 | md5_data[base] = md5(f.read()).hexdigest() 245 | f.close() 246 | 247 | data = [" %r: %r,\n" % it for it in md5_data.items()] 248 | data.sort() 249 | repl = "".join(data) 250 | 251 | import inspect 252 | srcfile = inspect.getsourcefile(sys.modules[__name__]) 253 | f = open(srcfile, 'rb'); src = f.read(); f.close() 254 | 255 | match = re.search("\nmd5_data = {\n([^}]+)}", src) 256 | if not match: 257 | print >>sys.stderr, "Internal error!" 258 | sys.exit(2) 259 | 260 | src = src[:match.start(1)] + repl + src[match.end(1):] 261 | f = open(srcfile,'w') 262 | f.write(src) 263 | f.close() 264 | 265 | 266 | if __name__=='__main__': 267 | if len(sys.argv)>2 and sys.argv[1]=='--md5update': 268 | update_md5(sys.argv[2:]) 269 | else: 270 | main(sys.argv[1:]) 271 | 272 | 273 | 274 | 275 | 276 | 277 | -------------------------------------------------------------------------------- /tests/test_project/apps/testapp/tests.py: -------------------------------------------------------------------------------- 1 | from django.test import TestCase 2 | from django.contrib.auth.models import User 3 | from django.utils import simplejson 4 | from django.conf import settings 5 | 6 | from piston import oauth 7 | from piston.models import Consumer, Token 8 | from piston.forms import OAuthAuthenticationForm 9 | 10 | try: 11 | import yaml 12 | except ImportError: 13 | print "Can't run YAML testsuite" 14 | yaml = None 15 | 16 | import urllib, base64 17 | 18 | from test_project.apps.testapp.models import TestModel, ExpressiveTestModel, Comment, InheritedModel 19 | from test_project.apps.testapp import signals 20 | 21 | class MainTests(TestCase): 22 | def setUp(self): 23 | self.user = User.objects.create_user('admin', 
'admin@world.com', 'admin') 24 | self.user.is_staff = True 25 | self.user.is_superuser = True 26 | self.user.is_active = True 27 | self.user.save() 28 | self.auth_string = 'Basic %s' % base64.encodestring('admin:admin').rstrip() 29 | 30 | if hasattr(self, 'init_delegate'): 31 | self.init_delegate() 32 | 33 | def tearDown(self): 34 | self.user.delete() 35 | 36 | 37 | 38 | class OAuthTests(MainTests): 39 | signature_method = oauth.OAuthSignatureMethod_HMAC_SHA1() 40 | 41 | def setUp(self): 42 | super(OAuthTests, self).setUp() 43 | 44 | self.consumer = Consumer(name='Test Consumer', description='Test', status='accepted') 45 | self.consumer.generate_random_codes() 46 | self.consumer.save() 47 | 48 | def tearDown(self): 49 | super(OAuthTests, self).tearDown() 50 | self.consumer.delete() 51 | 52 | def test_handshake(self): 53 | '''Test the OAuth handshake procedure 54 | ''' 55 | oaconsumer = oauth.OAuthConsumer(self.consumer.key, self.consumer.secret) 56 | 57 | # Get a request key... 58 | request = oauth.OAuthRequest.from_consumer_and_token(oaconsumer, 59 | http_url='http://testserver/api/oauth/request_token') 60 | request.sign_request(self.signature_method, oaconsumer, None) 61 | 62 | response = self.client.get('/api/oauth/request_token', request.parameters) 63 | oatoken = oauth.OAuthToken.from_string(response.content) 64 | 65 | token = Token.objects.get(key=oatoken.key, token_type=Token.REQUEST) 66 | self.assertEqual(token.secret, oatoken.secret) 67 | 68 | # Simulate user authentication... 
69 | self.failUnless(self.client.login(username='admin', password='admin')) 70 | request = oauth.OAuthRequest.from_token_and_callback(token=oatoken, 71 | callback='http://printer.example.com/request_token_ready', 72 | http_url='http://testserver/api/oauth/authorize') 73 | request.sign_request(self.signature_method, oaconsumer, oatoken) 74 | 75 | # Request the login page 76 | # TODO: Parse the response to make sure all the fields exist 77 | # response = self.client.get('/api/oauth/authorize', { 78 | # 'oauth_token': oatoken.key, 79 | # 'oauth_callback': 'http://printer.example.com/request_token_ready', 80 | # }) 81 | 82 | response = self.client.post('/api/oauth/authorize', { 83 | 'oauth_token': oatoken.key, 84 | 'oauth_callback': 'http://printer.example.com/request_token_ready', 85 | 'csrf_signature': OAuthAuthenticationForm.get_csrf_signature(settings.SECRET_KEY, oatoken.key), 86 | 'authorize_access': 1, 87 | }) 88 | 89 | # Response should be a redirect... 90 | self.assertEqual(302, response.status_code) 91 | self.assertEqual('http://printer.example.com/request_token_ready?oauth_token='+oatoken.key, response['Location']) 92 | 93 | # Obtain access token... 
94 | request = oauth.OAuthRequest.from_consumer_and_token(oaconsumer, token=oatoken, 95 | http_url='http://testserver/api/oauth/access_token') 96 | request.sign_request(self.signature_method, oaconsumer, oatoken) 97 | response = self.client.get('/api/oauth/access_token', request.parameters) 98 | 99 | oa_atoken = oauth.OAuthToken.from_string(response.content) 100 | atoken = Token.objects.get(key=oa_atoken.key, token_type=Token.ACCESS) 101 | self.assertEqual(atoken.secret, oa_atoken.secret) 102 | 103 | 104 | class MultiXMLTests(MainTests): 105 | def init_delegate(self): 106 | self.t1_data = TestModel() 107 | self.t1_data.save() 108 | self.t2_data = TestModel() 109 | self.t2_data.save() 110 | 111 | def test_multixml(self): 112 | expected = '\nNoneNoneNoneNone' 113 | result = self.client.get('/api/entries.xml', 114 | HTTP_AUTHORIZATION=self.auth_string).content 115 | self.assertEquals(expected, result) 116 | 117 | def test_singlexml(self): 118 | obj = TestModel.objects.all()[0] 119 | expected = '\nNoneNone' 120 | result = self.client.get('/api/entry-%d.xml' % (obj.pk,), 121 | HTTP_AUTHORIZATION=self.auth_string).content 122 | self.assertEquals(expected, result) 123 | 124 | class AbstractBaseClassTests(MainTests): 125 | def init_delegate(self): 126 | self.ab1 = InheritedModel() 127 | self.ab1.save() 128 | self.ab2 = InheritedModel() 129 | self.ab2.save() 130 | 131 | def test_field_presence(self): 132 | result = self.client.get('/api/abstract.json', 133 | HTTP_AUTHORIZATION=self.auth_string).content 134 | 135 | expected = """[ 136 | { 137 | "id": 1, 138 | "some_other": "something else", 139 | "some_field": "something here" 140 | }, 141 | { 142 | "id": 2, 143 | "some_other": "something else", 144 | "some_field": "something here" 145 | } 146 | ]""" 147 | 148 | self.assertEquals(result, expected) 149 | 150 | def test_specific_id(self): 151 | ids = (1, 2) 152 | be = """{ 153 | "id": %d, 154 | "some_other": "something else", 155 | "some_field": "something here" 156 | }""" 157 
| 158 | for id_ in ids: 159 | result = self.client.get('/api/abstract/%d.json' % id_, 160 | HTTP_AUTHORIZATION=self.auth_string).content 161 | 162 | expected = be % id_ 163 | 164 | self.assertEquals(result, expected) 165 | 166 | class IncomingExpressiveTests(MainTests): 167 | def init_delegate(self): 168 | e1 = ExpressiveTestModel(title="foo", content="bar") 169 | e1.save() 170 | e2 = ExpressiveTestModel(title="foo2", content="bar2") 171 | e2.save() 172 | 173 | def test_incoming_json(self): 174 | outgoing = simplejson.dumps({ 'title': 'test', 'content': 'test', 175 | 'comments': [ { 'content': 'test1' }, 176 | { 'content': 'test2' } ] }) 177 | 178 | expected = """[ 179 | { 180 | "content": "bar", 181 | "comments": [], 182 | "title": "foo" 183 | }, 184 | { 185 | "content": "bar2", 186 | "comments": [], 187 | "title": "foo2" 188 | } 189 | ]""" 190 | 191 | result = self.client.get('/api/expressive.json', 192 | HTTP_AUTHORIZATION=self.auth_string).content 193 | 194 | self.assertEquals(result, expected) 195 | 196 | resp = self.client.post('/api/expressive.json', outgoing, content_type='application/json', 197 | HTTP_AUTHORIZATION=self.auth_string) 198 | 199 | self.assertEquals(resp.status_code, 201) 200 | 201 | expected = """[ 202 | { 203 | "content": "bar", 204 | "comments": [], 205 | "title": "foo" 206 | }, 207 | { 208 | "content": "bar2", 209 | "comments": [], 210 | "title": "foo2" 211 | }, 212 | { 213 | "content": "test", 214 | "comments": [ 215 | { 216 | "content": "test1" 217 | }, 218 | { 219 | "content": "test2" 220 | } 221 | ], 222 | "title": "test" 223 | } 224 | ]""" 225 | 226 | result = self.client.get('/api/expressive.json', 227 | HTTP_AUTHORIZATION=self.auth_string).content 228 | 229 | self.assertEquals(result, expected) 230 | 231 | def test_incoming_invalid_json(self): 232 | resp = self.client.post('/api/expressive.json', 233 | 'foo', 234 | HTTP_AUTHORIZATION=self.auth_string, 235 | content_type='application/json') 236 | self.assertEquals(resp.status_code, 
400) 237 | 238 | def test_incoming_yaml(self): 239 | if not yaml: 240 | return 241 | 242 | expected = """- comments: [] 243 | content: bar 244 | title: foo 245 | - comments: [] 246 | content: bar2 247 | title: foo2 248 | """ 249 | 250 | self.assertEquals(self.client.get('/api/expressive.yaml', 251 | HTTP_AUTHORIZATION=self.auth_string).content, expected) 252 | 253 | outgoing = yaml.dump({ 'title': 'test', 'content': 'test', 254 | 'comments': [ { 'content': 'test1' }, 255 | { 'content': 'test2' } ] }) 256 | 257 | resp = self.client.post('/api/expressive.json', outgoing, content_type='application/x-yaml', 258 | HTTP_AUTHORIZATION=self.auth_string) 259 | 260 | self.assertEquals(resp.status_code, 201) 261 | 262 | expected = """- comments: [] 263 | content: bar 264 | title: foo 265 | - comments: [] 266 | content: bar2 267 | title: foo2 268 | - comments: 269 | - {content: test1} 270 | - {content: test2} 271 | content: test 272 | title: test 273 | """ 274 | self.assertEquals(self.client.get('/api/expressive.yaml', 275 | HTTP_AUTHORIZATION=self.auth_string).content, expected) 276 | 277 | def test_incoming_invalid_yaml(self): 278 | resp = self.client.post('/api/expressive.yaml', 279 | ' 8**sad asj lja foo', 280 | HTTP_AUTHORIZATION=self.auth_string, 281 | content_type='application/yaml') 282 | self.assertEquals(resp.status_code, 400) 283 | 284 | class Issue36RegressionTests(MainTests): 285 | """ 286 | This testcase addresses #36 in django-piston where request.FILES is passed 287 | empty to the handler if the request.method is PUT. 288 | """ 289 | def fetch_request(self, sender, request, *args, **kwargs): 290 | self.request = request 291 | 292 | def setUp(self): 293 | super(self.__class__, self).setUp() 294 | self.data = TestModel() 295 | self.data.save() 296 | # Register to the WSGIRequest signals to get the latest generated 297 | # request object. 
298 | signals.entry_request_started.connect(self.fetch_request) 299 | 300 | def tearDown(self): 301 | super(self.__class__, self).tearDown() 302 | self.data.delete() 303 | signals.entry_request_started.disconnect(self.fetch_request) 304 | 305 | def test_simple(self): 306 | # First try it with POST to see if it works there 307 | if True: 308 | fp = open(__file__, 'r') 309 | try: 310 | response = self.client.post('/api/entries.xml', 311 | {'file':fp}, HTTP_AUTHORIZATION=self.auth_string) 312 | self.assertEquals(1, len(self.request.FILES), 'request.FILES on POST is empty when it should contain 1 file') 313 | finally: 314 | fp.close() 315 | 316 | if not hasattr(self.client, 'put'): 317 | import warnings 318 | warnings.warn('Issue36RegressionTest partially requires Django 1.1 or newer. Skipped.') 319 | return 320 | 321 | # ... and then with PUT 322 | fp = open(__file__, 'r') 323 | try: 324 | response = self.client.put('/api/entry-%d.xml' % self.data.pk, 325 | {'file': fp}, HTTP_AUTHORIZATION=self.auth_string) 326 | self.assertEquals(1, len(self.request.FILES), 'request.FILES on PUT is empty when it should contain 1 file') 327 | finally: 328 | fp.close() 329 | 330 | class ValidationTest(MainTests): 331 | def test_basic_validation_fails(self): 332 | resp = self.client.get('/api/echo') 333 | self.assertEquals(resp.status_code, 400) 334 | self.assertEquals(resp.content, 'Bad Request:
    ' 335 | '
  • msg
    • This field is required.
    • ' 336 | '
') 337 | 338 | def test_basic_validation_succeeds(self): 339 | data = {'msg': 'donuts!'} 340 | resp = self.client.get('/api/echo', data) 341 | self.assertEquals(resp.status_code, 200) 342 | self.assertEquals(data, simplejson.loads(resp.content)) 343 | 344 | class PlainOldObject(MainTests): 345 | def test_plain_object_serialization(self): 346 | resp = self.client.get('/api/popo') 347 | self.assertEquals(resp.status_code, 200) 348 | self.assertEquals({'type': 'plain', 'field': 'a field'}, simplejson.loads(resp.content)) 349 | -------------------------------------------------------------------------------- /piston/emitters.py: -------------------------------------------------------------------------------- 1 | from __future__ import generators 2 | 3 | import decimal, re, inspect 4 | 5 | try: 6 | # yaml isn't standard with python. It shouldn't be required if it 7 | # isn't used. 8 | import yaml 9 | except ImportError: 10 | yaml = None 11 | 12 | # Fallback since `any` isn't in Python <2.5 13 | try: 14 | any 15 | except NameError: 16 | def any(iterable): 17 | for element in iterable: 18 | if element: 19 | return True 20 | return False 21 | 22 | from django.db.models.query import QuerySet 23 | from django.db.models import Model, permalink 24 | from django.utils import simplejson 25 | from django.utils.xmlutils import SimplerXMLGenerator 26 | from django.utils.encoding import smart_unicode 27 | from django.core.serializers.json import DateTimeAwareJSONEncoder 28 | from django.http import HttpResponse 29 | from django.core import serializers 30 | 31 | from utils import HttpStatusCode, Mimer 32 | 33 | try: 34 | import cStringIO as StringIO 35 | except ImportError: 36 | import StringIO 37 | 38 | try: 39 | import cPickle as pickle 40 | except ImportError: 41 | import pickle 42 | 43 | re_choicedisp = re.compile('^get_(\w+)_display$') 44 | 45 | class Emitter(object): 46 | """ 47 | Super emitter. All other emitters should subclass 48 | this one. 
It has the `construct` method which 49 | conveniently returns a serialized `dict`. This is 50 | usually the only method you want to use in your 51 | emitter. See below for examples. 52 | """ 53 | EMITTERS = { } 54 | 55 | def __init__(self, payload, typemapper, handler, fields=(), anonymous=True): 56 | self.typemapper = typemapper 57 | self.data = payload 58 | self.handler = handler 59 | self.fields = fields 60 | self.anonymous = anonymous 61 | 62 | if isinstance(self.data, Exception): 63 | raise 64 | 65 | def method_fields(self, data, fields): 66 | if not data: 67 | return { } 68 | 69 | has = dir(data) 70 | ret = dict() 71 | 72 | for field in fields: 73 | if field in has: 74 | ret[field] = getattr(data, field) 75 | 76 | return ret 77 | 78 | def construct(self): 79 | """ 80 | Recursively serialize a lot of types, and 81 | in cases where it doesn't recognize the type, 82 | it will fall back to Django's `smart_unicode`. 83 | 84 | Returns `dict`. 85 | """ 86 | def _any(thing, fields=()): 87 | """ 88 | Dispatch, all types are routed through here. 89 | """ 90 | ret = None 91 | 92 | if isinstance(thing, QuerySet): 93 | ret = _qs(thing, fields=fields) 94 | elif isinstance(thing, (tuple, list)): 95 | ret = _list(thing) 96 | elif isinstance(thing, dict): 97 | ret = _dict(thing) 98 | elif isinstance(thing, decimal.Decimal): 99 | ret = str(thing) 100 | elif isinstance(thing, Model): 101 | ret = _model(thing, fields=fields) 102 | elif isinstance(thing, HttpResponse): 103 | raise HttpStatusCode(thing) 104 | elif inspect.isfunction(thing): 105 | if not inspect.getargspec(thing)[0]: 106 | ret = _any(thing()) 107 | elif hasattr(thing, '__emittable__'): 108 | f = thing.__emittable__ 109 | if inspect.ismethod(f) and len(inspect.getargspec(f)[0]) == 1: 110 | ret = _any(f()) 111 | else: 112 | ret = smart_unicode(thing, strings_only=True) 113 | 114 | return ret 115 | 116 | def _fk(data, field): 117 | """ 118 | Foreign keys. 
119 | """ 120 | return _any(getattr(data, field.name)) 121 | 122 | def _related(data, fields=()): 123 | """ 124 | Foreign keys. 125 | """ 126 | return [ _model(m, fields) for m in data.iterator() ] 127 | 128 | def _m2m(data, field, fields=()): 129 | """ 130 | Many to many (re-route to `_model`.) 131 | """ 132 | return [ _model(m, fields) for m in getattr(data, field.name).iterator() ] 133 | 134 | def _model(data, fields=()): 135 | """ 136 | Models. Will respect the `fields` and/or 137 | `exclude` on the handler (see `typemapper`.) 138 | """ 139 | ret = { } 140 | handler = self.in_typemapper(type(data), self.anonymous) 141 | get_absolute_uri = False 142 | 143 | if handler or fields: 144 | v = lambda f: getattr(data, f.attname) 145 | 146 | if not fields: 147 | """ 148 | Fields was not specified, try to find teh correct 149 | version in the typemapper we were sent. 150 | """ 151 | mapped = self.in_typemapper(type(data), self.anonymous) 152 | get_fields = set(mapped.fields) 153 | exclude_fields = set(mapped.exclude).difference(get_fields) 154 | 155 | if 'absolute_uri' in get_fields: 156 | get_absolute_uri = True 157 | 158 | if not get_fields: 159 | get_fields = set([ f.attname.replace("_id", "", 1) 160 | for f in data._meta.fields ]) 161 | 162 | # sets can be negated. 
163 | for exclude in exclude_fields: 164 | if isinstance(exclude, basestring): 165 | get_fields.discard(exclude) 166 | 167 | elif isinstance(exclude, re._pattern_type): 168 | for field in get_fields.copy(): 169 | if exclude.match(field): 170 | get_fields.discard(field) 171 | 172 | else: 173 | get_fields = set(fields) 174 | 175 | met_fields = self.method_fields(handler, get_fields) 176 | 177 | for f in data._meta.local_fields: 178 | if f.serialize and not any([ p in met_fields for p in [ f.attname, f.name ]]): 179 | if not f.rel: 180 | if f.attname in get_fields: 181 | ret[f.attname] = _any(v(f)) 182 | get_fields.remove(f.attname) 183 | else: 184 | if f.attname[:-3] in get_fields: 185 | ret[f.name] = _fk(data, f) 186 | get_fields.remove(f.name) 187 | 188 | for mf in data._meta.many_to_many: 189 | if mf.serialize and mf.attname not in met_fields: 190 | if mf.attname in get_fields: 191 | ret[mf.name] = _m2m(data, mf) 192 | get_fields.remove(mf.name) 193 | 194 | # try to get the remainder of fields 195 | for maybe_field in get_fields: 196 | 197 | if isinstance(maybe_field, (list, tuple)): 198 | model, fields = maybe_field 199 | inst = getattr(data, model, None) 200 | 201 | if inst: 202 | if hasattr(inst, 'all'): 203 | ret[model] = _related(inst, fields) 204 | elif callable(inst): 205 | if len(inspect.getargspec(inst)[0]) == 1: 206 | ret[model] = _any(inst(), fields) 207 | else: 208 | ret[model] = _model(inst, fields) 209 | 210 | elif maybe_field in met_fields: 211 | # Overriding normal field which has a "resource method" 212 | # so you can alter the contents of certain fields without 213 | # using different names. 
214 | ret[maybe_field] = _any(met_fields[maybe_field](data)) 215 | 216 | else: 217 | maybe = getattr(data, maybe_field, None) 218 | if maybe: 219 | if callable(maybe): 220 | if re_choicedisp.search(maybe_field) or len(inspect.getargspec(maybe)[0]) == 1: 221 | ret[maybe_field] = _any(maybe()) 222 | else: 223 | ret[maybe_field] = _any(maybe) 224 | else: 225 | handler_f = getattr(handler or self.handler, maybe_field, None) 226 | 227 | if handler_f: 228 | ret[maybe_field] = _any(handler_f(data)) 229 | 230 | else: 231 | for f in data._meta.fields: 232 | ret[f.attname] = _any(getattr(data, f.attname)) 233 | 234 | fields = dir(data.__class__) + ret.keys() 235 | add_ons = [k for k in dir(data) if k not in fields] 236 | 237 | for k in add_ons: 238 | ret[k] = _any(getattr(data, k)) 239 | 240 | # resouce uri 241 | if self.in_typemapper(type(data), self.anonymous): 242 | handler = self.in_typemapper(type(data), self.anonymous) 243 | if hasattr(handler, 'resource_uri'): 244 | url_id, fields = handler.resource_uri() 245 | ret['resource_uri'] = permalink( lambda: (url_id, 246 | (getattr(data, f) for f in fields) ) )() 247 | 248 | if hasattr(data, 'get_api_url') and 'resource_uri' not in ret: 249 | try: ret['resource_uri'] = data.get_api_url() 250 | except: pass 251 | 252 | # absolute uri 253 | if hasattr(data, 'get_absolute_url') and get_absolute_uri: 254 | try: ret['absolute_uri'] = data.get_absolute_url() 255 | except: pass 256 | 257 | return ret 258 | 259 | def _qs(data, fields=()): 260 | """ 261 | Querysets. 262 | """ 263 | return [ _any(v, fields) for v in data ] 264 | 265 | def _list(data): 266 | """ 267 | Lists. 268 | """ 269 | return [ _any(v) for v in data ] 270 | 271 | def _dict(data): 272 | """ 273 | Dictionaries. 274 | """ 275 | return dict([ (k, _any(v)) for k, v in data.iteritems() ]) 276 | 277 | # Kickstart the seralizin'. 
278 | return _any(self.data, self.fields) 279 | 280 | def in_typemapper(self, model, anonymous): 281 | for klass, (km, is_anon) in self.typemapper.iteritems(): 282 | if model is km and is_anon is anonymous: 283 | return klass 284 | 285 | def render(self): 286 | """ 287 | This super emitter does not implement `render`, 288 | this is a job for the specific emitter below. 289 | """ 290 | raise NotImplementedError("Please implement render.") 291 | 292 | def stream_render(self, request, stream=True): 293 | """ 294 | Tells our patched middleware not to look 295 | at the contents, and returns a generator 296 | rather than the buffered string. Should be 297 | more memory friendly for large datasets. 298 | """ 299 | yield self.render(request) 300 | 301 | @classmethod 302 | def get(cls, format): 303 | """ 304 | Gets an emitter, returns the class and a content-type. 305 | """ 306 | if cls.EMITTERS.has_key(format): 307 | return cls.EMITTERS.get(format) 308 | 309 | raise ValueError("No emitters found for type %s" % format) 310 | 311 | @classmethod 312 | def register(cls, name, klass, content_type='text/plain'): 313 | """ 314 | Register an emitter. 315 | 316 | Parameters:: 317 | - `name`: The name of the emitter ('json', 'xml', 'yaml', ...) 318 | - `klass`: The emitter class. 319 | - `content_type`: The content type to serve response as. 320 | """ 321 | cls.EMITTERS[name] = (klass, content_type) 322 | 323 | @classmethod 324 | def unregister(cls, name): 325 | """ 326 | Remove an emitter from the registry. Useful if you don't 327 | want to provide output in one of the built-in emitters. 
328 | """ 329 | return cls.EMITTERS.pop(name, None) 330 | 331 | class XMLEmitter(Emitter): 332 | def _to_xml(self, xml, data): 333 | if isinstance(data, (list, tuple)): 334 | for item in data: 335 | xml.startElement("resource", {}) 336 | self._to_xml(xml, item) 337 | xml.endElement("resource") 338 | elif isinstance(data, dict): 339 | for key, value in data.iteritems(): 340 | xml.startElement(key, {}) 341 | self._to_xml(xml, value) 342 | xml.endElement(key) 343 | else: 344 | xml.characters(smart_unicode(data)) 345 | 346 | def render(self, request): 347 | stream = StringIO.StringIO() 348 | 349 | xml = SimplerXMLGenerator(stream, "utf-8") 350 | xml.startDocument() 351 | xml.startElement("response", {}) 352 | 353 | self._to_xml(xml, self.construct()) 354 | 355 | xml.endElement("response") 356 | xml.endDocument() 357 | 358 | return stream.getvalue() 359 | 360 | Emitter.register('xml', XMLEmitter, 'text/xml; charset=utf-8') 361 | Mimer.register(lambda *a: None, ('text/xml',)) 362 | 363 | class JSONEmitter(Emitter): 364 | """ 365 | JSON emitter, understands timestamps. 366 | """ 367 | def render(self, request): 368 | cb = request.GET.get('callback') 369 | seria = simplejson.dumps(self.construct(), cls=DateTimeAwareJSONEncoder, ensure_ascii=False, indent=4) 370 | 371 | # Callback 372 | if cb: 373 | return '%s(%s)' % (cb, seria) 374 | 375 | return seria 376 | 377 | Emitter.register('json', JSONEmitter, 'application/json; charset=utf-8') 378 | Mimer.register(simplejson.loads, ('application/json',)) 379 | 380 | class YAMLEmitter(Emitter): 381 | """ 382 | YAML emitter, uses `safe_dump` to omit the 383 | specific types when outputting to non-Python. 384 | """ 385 | def render(self, request): 386 | return yaml.safe_dump(self.construct()) 387 | 388 | if yaml: # Only register yaml if it was import successfully. 
389 | Emitter.register('yaml', YAMLEmitter, 'application/x-yaml; charset=utf-8') 390 | Mimer.register(yaml.load, ('application/x-yaml',)) 391 | 392 | class PickleEmitter(Emitter): 393 | """ 394 | Emitter that returns Python pickled. 395 | """ 396 | def render(self, request): 397 | return pickle.dumps(self.construct()) 398 | 399 | Emitter.register('pickle', PickleEmitter, 'application/python-pickle') 400 | Mimer.register(pickle.loads, ('application/python-pickle',)) 401 | 402 | class DjangoEmitter(Emitter): 403 | """ 404 | Emitter for the Django serialized format. 405 | """ 406 | def render(self, request, format='xml'): 407 | if isinstance(self.data, HttpResponse): 408 | return self.data 409 | elif isinstance(self.data, (int, str)): 410 | response = self.data 411 | else: 412 | response = serializers.serialize(format, self.data, indent=True) 413 | 414 | return response 415 | 416 | Emitter.register('django', DjangoEmitter, 'text/xml; charset=utf-8') 417 | -------------------------------------------------------------------------------- /piston/oauth.py: -------------------------------------------------------------------------------- 1 | import cgi 2 | import urllib 3 | import time 4 | import random 5 | import urlparse 6 | import hmac 7 | import base64 8 | 9 | VERSION = '1.0' # Hi Blaine! 
HTTP_METHOD = 'GET'
SIGNATURE_METHOD = 'PLAINTEXT'

# Generic exception class
class OAuthError(RuntimeError):
    """Generic OAuth failure carrying a human-readable `message`.

    `message` is a read/write property backed by `_message`; it shadows the
    `message` attribute inherited from the built-in exception base class.
    """

    def get_message(self):
        return self._message

    def set_message(self, message):
        self._message = message

    message = property(get_message, set_message)

    def __init__(self, message='OAuth error occured.'):
        self.message = message


# optional WWW-Authenticate header (401 error)
def build_authenticate_header(realm=''):
    """Return a header dict advertising OAuth authentication for *realm*."""
    return { 'WWW-Authenticate': 'OAuth realm="%s"' % realm }


# url escape
def escape(s):
    """Percent-encode *s*, leaving only RFC 3986 unreserved characters.

    Letters, digits and '_.-' are never quoted by `urllib.quote`, and '~' is
    passed as the only extra safe character, so '/' is escaped too.

    Text (Python 2 `unicode`) input is encoded to UTF-8 octets first, as the
    OAuth 1.0 percent-encoding rules require. Byte strings are passed through
    unchanged, so callers that already pass `str` see identical output.
    """
    if not isinstance(s, str):
        # urllib.quote raises KeyError on non-ASCII unicode characters;
        # percent-encode the UTF-8 bytes instead.
        s = s.encode('utf-8')
    return urllib.quote(s, safe='~')


# util function: current timestamp
def generate_timestamp():
    """Return the current time as whole seconds since the epoch (UTC)."""
    return int(time.time())


# util function: nonce
def generate_nonce(length=8):
    """Return a string of *length* pseudorandom decimal digits.

    Uses the `random` module (not a cryptographic source); the nonce only
    needs to be unlikely to repeat, not secret.
    """
    return ''.join(str(random.randint(0, 9)) for _ in range(length))


# OAuthConsumer is a data type that represents the identity of the Consumer
# via its shared secret with the Service Provider.
class OAuthConsumer(object):
    """Plain holder for a consumer's `key` and shared `secret`."""

    key = None
    secret = None

    def __init__(self, key, secret):
        self.key = key
        self.secret = secret

# OAuthToken is a data type that represents an End User via either an access
# or request token.
57 | class OAuthToken(object): 58 | # access tokens and request tokens 59 | key = None 60 | secret = None 61 | 62 | ''' 63 | key = the token 64 | secret = the token secret 65 | ''' 66 | def __init__(self, key, secret): 67 | self.key = key 68 | self.secret = secret 69 | 70 | def to_string(self): 71 | return urllib.urlencode({'oauth_token': self.key, 'oauth_token_secret': self.secret}) 72 | 73 | # return a token from something like: 74 | # oauth_token_secret=digg&oauth_token=digg 75 | @staticmethod 76 | def from_string(s): 77 | params = cgi.parse_qs(s, keep_blank_values=False) 78 | key = params['oauth_token'][0] 79 | secret = params['oauth_token_secret'][0] 80 | return OAuthToken(key, secret) 81 | 82 | def __str__(self): 83 | return self.to_string() 84 | 85 | # OAuthRequest represents the request and can be serialized 86 | class OAuthRequest(object): 87 | ''' 88 | OAuth parameters: 89 | - oauth_consumer_key 90 | - oauth_token 91 | - oauth_signature_method 92 | - oauth_signature 93 | - oauth_timestamp 94 | - oauth_nonce 95 | - oauth_version 96 | ... any additional parameters, as defined by the Service Provider. 
97 | ''' 98 | parameters = None # oauth parameters 99 | http_method = HTTP_METHOD 100 | http_url = None 101 | version = VERSION 102 | 103 | def __init__(self, http_method=HTTP_METHOD, http_url=None, parameters=None): 104 | self.http_method = http_method 105 | self.http_url = http_url 106 | self.parameters = parameters or {} 107 | 108 | def set_parameter(self, parameter, value): 109 | self.parameters[parameter] = value 110 | 111 | def get_parameter(self, parameter): 112 | try: 113 | return self.parameters[parameter] 114 | except: 115 | raise OAuthError('Parameter not found: %s' % parameter) 116 | 117 | def _get_timestamp_nonce(self): 118 | return self.get_parameter('oauth_timestamp'), self.get_parameter('oauth_nonce') 119 | 120 | # get any non-oauth parameters 121 | def get_nonoauth_parameters(self): 122 | parameters = {} 123 | for k, v in self.parameters.iteritems(): 124 | # ignore oauth parameters 125 | if k.find('oauth_') < 0: 126 | parameters[k] = v 127 | return parameters 128 | 129 | # serialize as a header for an HTTPAuth request 130 | def to_header(self, realm=''): 131 | auth_header = 'OAuth realm="%s"' % realm 132 | # add the oauth parameters 133 | if self.parameters: 134 | for k, v in self.parameters.iteritems(): 135 | auth_header += ', %s="%s"' % (k, escape(str(v))) 136 | return {'Authorization': auth_header} 137 | 138 | # serialize as post data for a POST request 139 | def to_postdata(self): 140 | return '&'.join('%s=%s' % (escape(str(k)), escape(str(v))) for k, v in self.parameters.iteritems()) 141 | 142 | # serialize as a url for a GET request 143 | def to_url(self): 144 | return '%s?%s' % (self.get_normalized_http_url(), self.to_postdata()) 145 | 146 | # return a string that consists of all the parameters that need to be signed 147 | def get_normalized_parameters(self): 148 | params = self.parameters 149 | try: 150 | # exclude the signature if it exists 151 | del params['oauth_signature'] 152 | except: 153 | pass 154 | key_values = params.items() 155 | 
# sort lexicographically, first after key, then after value 156 | key_values.sort() 157 | # combine key value pairs in string and escape 158 | return '&'.join('%s=%s' % (escape(str(k)), escape(str(v))) for k, v in key_values) 159 | 160 | # just uppercases the http method 161 | def get_normalized_http_method(self): 162 | return self.http_method.upper() 163 | 164 | # parses the url and rebuilds it to be scheme://host/path 165 | def get_normalized_http_url(self): 166 | parts = urlparse.urlparse(self.http_url) 167 | url_string = '%s://%s%s' % (parts[0], parts[1], parts[2]) # scheme, netloc, path 168 | return url_string 169 | 170 | # set the signature parameter to the result of build_signature 171 | def sign_request(self, signature_method, consumer, token): 172 | # set the signature method 173 | self.set_parameter('oauth_signature_method', signature_method.get_name()) 174 | # set the signature 175 | self.set_parameter('oauth_signature', self.build_signature(signature_method, consumer, token)) 176 | 177 | def build_signature(self, signature_method, consumer, token): 178 | # call the build signature method within the signature method 179 | return signature_method.build_signature(self, consumer, token) 180 | 181 | @staticmethod 182 | def from_request(http_method, http_url, headers=None, parameters=None, query_string=None): 183 | # combine multiple parameter sources 184 | if parameters is None: 185 | parameters = {} 186 | 187 | # headers 188 | if headers and 'HTTP_AUTHORIZATION' in headers: 189 | auth_header = headers['HTTP_AUTHORIZATION'] 190 | # check that the authorization header is OAuth 191 | if auth_header.index('OAuth') > -1: 192 | try: 193 | # get the parameters from the header 194 | header_params = OAuthRequest._split_header(auth_header) 195 | parameters.update(header_params) 196 | except: 197 | raise OAuthError('Unable to parse OAuth parameters from Authorization header.') 198 | 199 | # GET or POST query string 200 | if query_string: 201 | query_params = 
OAuthRequest._split_url_string(query_string) 202 | parameters.update(query_params) 203 | 204 | # URL parameters 205 | param_str = urlparse.urlparse(http_url)[4] # query 206 | url_params = OAuthRequest._split_url_string(param_str) 207 | parameters.update(url_params) 208 | 209 | if parameters: 210 | return OAuthRequest(http_method, http_url, parameters) 211 | 212 | return None 213 | 214 | @staticmethod 215 | def from_consumer_and_token(oauth_consumer, token=None, http_method=HTTP_METHOD, http_url=None, parameters=None): 216 | if not parameters: 217 | parameters = {} 218 | 219 | defaults = { 220 | 'oauth_consumer_key': oauth_consumer.key, 221 | 'oauth_timestamp': generate_timestamp(), 222 | 'oauth_nonce': generate_nonce(), 223 | 'oauth_version': OAuthRequest.version, 224 | } 225 | 226 | defaults.update(parameters) 227 | parameters = defaults 228 | 229 | if token: 230 | parameters['oauth_token'] = token.key 231 | 232 | return OAuthRequest(http_method, http_url, parameters) 233 | 234 | @staticmethod 235 | def from_token_and_callback(token, callback=None, http_method=HTTP_METHOD, http_url=None, parameters=None): 236 | if not parameters: 237 | parameters = {} 238 | 239 | parameters['oauth_token'] = token.key 240 | 241 | if callback: 242 | parameters['oauth_callback'] = escape(callback) 243 | 244 | return OAuthRequest(http_method, http_url, parameters) 245 | 246 | # util function: turn Authorization: header into parameters, has to do some unescaping 247 | @staticmethod 248 | def _split_header(header): 249 | params = {} 250 | parts = header.split(',') 251 | for param in parts: 252 | # ignore realm parameter 253 | if param.find('OAuth realm') > -1: 254 | continue 255 | # remove whitespace 256 | param = param.strip() 257 | # split key-value 258 | param_parts = param.split('=', 1) 259 | # remove quotes and unescape the value 260 | params[param_parts[0]] = urllib.unquote(param_parts[1].strip('\"')) 261 | return params 262 | 263 | # util function: turn url string into 
parameters, has to do some unescaping 264 | @staticmethod 265 | def _split_url_string(param_str): 266 | parameters = cgi.parse_qs(param_str, keep_blank_values=False) 267 | for k, v in parameters.iteritems(): 268 | parameters[k] = urllib.unquote(v[0]) 269 | return parameters 270 | 271 | # OAuthServer is a worker to check a requests validity against a data store 272 | class OAuthServer(object): 273 | timestamp_threshold = 300 # in seconds, five minutes 274 | version = VERSION 275 | signature_methods = None 276 | data_store = None 277 | 278 | def __init__(self, data_store=None, signature_methods=None): 279 | self.data_store = data_store 280 | self.signature_methods = signature_methods or {} 281 | 282 | def set_data_store(self, oauth_data_store): 283 | self.data_store = data_store 284 | 285 | def get_data_store(self): 286 | return self.data_store 287 | 288 | def add_signature_method(self, signature_method): 289 | self.signature_methods[signature_method.get_name()] = signature_method 290 | return self.signature_methods 291 | 292 | # process a request_token request 293 | # returns the request token on success 294 | def fetch_request_token(self, oauth_request): 295 | try: 296 | # get the request token for authorization 297 | token = self._get_token(oauth_request, 'request') 298 | except OAuthError: 299 | # no token required for the initial token request 300 | version = self._get_version(oauth_request) 301 | consumer = self._get_consumer(oauth_request) 302 | self._check_signature(oauth_request, consumer, None) 303 | # fetch a new token 304 | token = self.data_store.fetch_request_token(consumer) 305 | return token 306 | 307 | # process an access_token request 308 | # returns the access token on success 309 | def fetch_access_token(self, oauth_request): 310 | version = self._get_version(oauth_request) 311 | consumer = self._get_consumer(oauth_request) 312 | # get the request token 313 | token = self._get_token(oauth_request, 'request') 314 | 
self._check_signature(oauth_request, consumer, token) 315 | new_token = self.data_store.fetch_access_token(consumer, token) 316 | return new_token 317 | 318 | # verify an api call, checks all the parameters 319 | def verify_request(self, oauth_request): 320 | # -> consumer and token 321 | version = self._get_version(oauth_request) 322 | consumer = self._get_consumer(oauth_request) 323 | # get the access token 324 | token = self._get_token(oauth_request, 'access') 325 | self._check_signature(oauth_request, consumer, token) 326 | parameters = oauth_request.get_nonoauth_parameters() 327 | return consumer, token, parameters 328 | 329 | # authorize a request token 330 | def authorize_token(self, token, user): 331 | return self.data_store.authorize_request_token(token, user) 332 | 333 | # get the callback url 334 | def get_callback(self, oauth_request): 335 | return oauth_request.get_parameter('oauth_callback') 336 | 337 | # optional support for the authenticate header 338 | def build_authenticate_header(self, realm=''): 339 | return {'WWW-Authenticate': 'OAuth realm="%s"' % realm} 340 | 341 | # verify the correct version request for this server 342 | def _get_version(self, oauth_request): 343 | try: 344 | version = oauth_request.get_parameter('oauth_version') 345 | except: 346 | version = VERSION 347 | if version and version != self.version: 348 | raise OAuthError('OAuth version %s not supported.' 
% str(version)) 349 | return version 350 | 351 | # figure out the signature with some defaults 352 | def _get_signature_method(self, oauth_request): 353 | try: 354 | signature_method = oauth_request.get_parameter('oauth_signature_method') 355 | except: 356 | signature_method = SIGNATURE_METHOD 357 | try: 358 | # get the signature method object 359 | signature_method = self.signature_methods[signature_method] 360 | except: 361 | signature_method_names = ', '.join(self.signature_methods.keys()) 362 | raise OAuthError('Signature method %s not supported try one of the following: %s' % (signature_method, signature_method_names)) 363 | 364 | return signature_method 365 | 366 | def _get_consumer(self, oauth_request): 367 | consumer_key = oauth_request.get_parameter('oauth_consumer_key') 368 | if not consumer_key: 369 | raise OAuthError('Invalid consumer key.') 370 | consumer = self.data_store.lookup_consumer(consumer_key) 371 | if not consumer: 372 | raise OAuthError('Invalid consumer.') 373 | return consumer 374 | 375 | # try to find the token for the provided request token key 376 | def _get_token(self, oauth_request, token_type='access'): 377 | token_field = oauth_request.get_parameter('oauth_token') 378 | token = self.data_store.lookup_token(token_type, token_field) 379 | if not token: 380 | raise OAuthError('Invalid %s token: %s' % (token_type, token_field)) 381 | return token 382 | 383 | def _check_signature(self, oauth_request, consumer, token): 384 | timestamp, nonce = oauth_request._get_timestamp_nonce() 385 | self._check_timestamp(timestamp) 386 | self._check_nonce(consumer, token, nonce) 387 | signature_method = self._get_signature_method(oauth_request) 388 | try: 389 | signature = oauth_request.get_parameter('oauth_signature') 390 | except: 391 | raise OAuthError('Missing signature.') 392 | # validate the signature 393 | valid_sig = signature_method.check_signature(oauth_request, consumer, token, signature) 394 | if not valid_sig: 395 | key, base = 
signature_method.build_signature_base_string(oauth_request, consumer, token) 396 | raise OAuthError('Invalid signature. Expected signature base string: %s' % base) 397 | built = signature_method.build_signature(oauth_request, consumer, token) 398 | 399 | def _check_timestamp(self, timestamp): 400 | # verify that timestamp is recentish 401 | timestamp = int(timestamp) 402 | now = int(time.time()) 403 | lapsed = now - timestamp 404 | if lapsed > self.timestamp_threshold: 405 | raise OAuthError('Expired timestamp: given %d and now %s has a greater difference than threshold %d' % (timestamp, now, self.timestamp_threshold)) 406 | 407 | def _check_nonce(self, consumer, token, nonce): 408 | # verify that the nonce is uniqueish 409 | nonce = self.data_store.lookup_nonce(consumer, token, nonce) 410 | if nonce: 411 | raise OAuthError('Nonce already used: %s' % str(nonce)) 412 | 413 | # OAuthClient is a worker to attempt to execute a request 414 | class OAuthClient(object): 415 | consumer = None 416 | token = None 417 | 418 | def __init__(self, oauth_consumer, oauth_token): 419 | self.consumer = oauth_consumer 420 | self.token = oauth_token 421 | 422 | def get_consumer(self): 423 | return self.consumer 424 | 425 | def get_token(self): 426 | return self.token 427 | 428 | def fetch_request_token(self, oauth_request): 429 | # -> OAuthToken 430 | raise NotImplementedError 431 | 432 | def fetch_access_token(self, oauth_request): 433 | # -> OAuthToken 434 | raise NotImplementedError 435 | 436 | def access_resource(self, oauth_request): 437 | # -> some protected resource 438 | raise NotImplementedError 439 | 440 | # OAuthDataStore is a database abstraction used to lookup consumers and tokens 441 | class OAuthDataStore(object): 442 | 443 | def lookup_consumer(self, key): 444 | # -> OAuthConsumer 445 | raise NotImplementedError 446 | 447 | def lookup_token(self, oauth_consumer, token_type, token_token): 448 | # -> OAuthToken 449 | raise NotImplementedError 450 | 451 | def 
lookup_nonce(self, oauth_consumer, oauth_token, nonce, timestamp): 452 | # -> OAuthToken 453 | raise NotImplementedError 454 | 455 | def fetch_request_token(self, oauth_consumer): 456 | # -> OAuthToken 457 | raise NotImplementedError 458 | 459 | def fetch_access_token(self, oauth_consumer, oauth_token): 460 | # -> OAuthToken 461 | raise NotImplementedError 462 | 463 | def authorize_request_token(self, oauth_token, user): 464 | # -> OAuthToken 465 | raise NotImplementedError 466 | 467 | # OAuthSignatureMethod is a strategy class that implements a signature method 468 | class OAuthSignatureMethod(object): 469 | def get_name(self): 470 | # -> str 471 | raise NotImplementedError 472 | 473 | def build_signature_base_string(self, oauth_request, oauth_consumer, oauth_token): 474 | # -> str key, str raw 475 | raise NotImplementedError 476 | 477 | def build_signature(self, oauth_request, oauth_consumer, oauth_token): 478 | # -> str 479 | raise NotImplementedError 480 | 481 | def check_signature(self, oauth_request, consumer, token, signature): 482 | built = self.build_signature(oauth_request, consumer, token) 483 | return built == signature 484 | 485 | class OAuthSignatureMethod_HMAC_SHA1(OAuthSignatureMethod): 486 | 487 | def get_name(self): 488 | return 'HMAC-SHA1' 489 | 490 | def build_signature_base_string(self, oauth_request, consumer, token): 491 | sig = ( 492 | escape(oauth_request.get_normalized_http_method()), 493 | escape(oauth_request.get_normalized_http_url()), 494 | escape(oauth_request.get_normalized_parameters()), 495 | ) 496 | 497 | key = '%s&' % escape(consumer.secret) 498 | if token: 499 | key += escape(token.secret) 500 | raw = '&'.join(sig) 501 | return key, raw 502 | 503 | def build_signature(self, oauth_request, consumer, token): 504 | # build the base signature string 505 | key, raw = self.build_signature_base_string(oauth_request, consumer, token) 506 | 507 | # hmac object 508 | try: 509 | import hashlib # 2.5 510 | hashed = hmac.new(key, raw, 
hashlib.sha1) 511 | except: 512 | import sha # deprecated 513 | hashed = hmac.new(key, raw, sha) 514 | 515 | # calculate the digest base 64 516 | return base64.b64encode(hashed.digest()) 517 | 518 | class OAuthSignatureMethod_PLAINTEXT(OAuthSignatureMethod): 519 | 520 | def get_name(self): 521 | return 'PLAINTEXT' 522 | 523 | def build_signature_base_string(self, oauth_request, consumer, token): 524 | # concatenate the consumer key and secret 525 | sig = escape(consumer.secret) + '&' 526 | if token: 527 | sig = sig + escape(token.secret) 528 | return sig 529 | 530 | def build_signature(self, oauth_request, consumer, token): 531 | return self.build_signature_base_string(oauth_request, consumer, token) 532 | --------------------------------------------------------------------------------