├── config
├── __init__.py
├── settings
│ ├── __init__.py
│ ├── local.py
│ ├── production.py
│ └── common.py
├── wsgi.py
└── urls.py
├── posts
├── __init__.py
├── tests
│ ├── __init__.py
│ ├── test_models.py
│ ├── test_serializers.py
│ └── test_views.py
├── migrations
│ ├── __init__.py
│ ├── 0002_auto_20180424_2035.py
│ └── 0001_initial.py
├── apps.py
├── admin.py
├── permissions.py
├── factories.py
├── models.py
├── serializers.py
└── views.py
├── tags
├── __init__.py
├── tests
│ ├── __init__.py
│ └── test_views.py
├── migrations
│ ├── __init__.py
│ ├── 0002_tag_creator.py
│ └── 0001_initial.py
├── admin.py
├── apps.py
├── serializers.py
├── permissions.py
├── factories.py
├── models.py
└── views.py
├── users
├── __init__.py
├── tests
│ ├── __init__.py
│ └── unit_tests
│ │ ├── __init__.py
│ │ ├── test_utils.py
│ │ ├── test_models.py
│ │ ├── test_serializers.py
│ │ └── test_views.py
├── migrations
│ ├── __init__.py
│ └── 0001_initial.py
├── apps.py
├── admin.py
├── disable_csrf_middleware.py
├── utils.py
├── permissions.py
├── factories.py
├── jwt_middleware.py
├── models.py
├── serializers.py
├── validators.py
├── mugshot.py
└── views.py
├── utils
├── __init__.py
├── mixins.py
└── rest_tools.py
├── balance
├── __init__.py
├── migrations
│ ├── __init__.py
│ ├── 0002_record_user.py
│ └── 0001_initial.py
├── views.py
├── tests.py
├── apps.py
├── admin.py
├── serializers.py
├── factories.py
├── models.py
└── permissions.py
├── db_tools
├── __init__.py
├── fake_db_fast.py
└── fake_db.py
├── replies
├── tests
│ ├── __init__.py
│ ├── test_models.py
│ ├── test_serializers.py
│ ├── test_moderation.py
│ └── test_views.py
├── migrations
│ ├── __init__.py
│ ├── 0002_reply_user.py
│ └── 0001_initial.py
├── admin.py
├── __init__.py
├── permissions.py
├── apps.py
├── urls.py
├── factories.py
├── models.py
├── moderation.py
├── views.py
└── serializers.py
├── notifications_extension
├── urls.py
├── __init__.py
├── tests
│ ├── __init__.py
│ └── test_views.py
├── migrations
│ └── __init__.py
├── models.py
├── admin.py
├── apps.py
├── filters.py
├── views.py
└── serializers.py
├── requirements
├── production.txt
├── local.txt
└── base.txt
├── .coveragerc
├── .travis.yml
├── runtests.py
├── README.md
├── templates
└── users
│ ├── account
│ ├── password_set.html
│ ├── password_change.html
│ ├── signup.html
│ ├── password_reset.html
│ ├── email_confirm.html
│ ├── login.html
│ ├── base.html
│ └── email.html
│ └── socialaccount
│ └── connections.html
├── manage.py
├── LICENSE
└── .gitignore
/config/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/posts/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tags/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/users/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/balance/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/db_tools/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/posts/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tags/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/users/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/config/settings/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/posts/migrations/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/replies/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tags/migrations/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/users/migrations/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/balance/migrations/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/notifications_extension/urls.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/replies/migrations/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/notifications_extension/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirements/production.txt:
--------------------------------------------------------------------------------
1 | -r base.txt
--------------------------------------------------------------------------------
/users/tests/unit_tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/notifications_extension/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/notifications_extension/migrations/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/balance/views.py:
--------------------------------------------------------------------------------
1 | from rest_framework import viewsets
2 |
--------------------------------------------------------------------------------
/requirements/local.txt:
--------------------------------------------------------------------------------
1 | -r base.txt
2 |
3 | isort==4.3.4
--------------------------------------------------------------------------------
/.coveragerc:
--------------------------------------------------------------------------------
1 | [run]
2 | omit =
3 | db_tools/*
4 | manage.py
5 | runtests.py
--------------------------------------------------------------------------------
/balance/tests.py:
--------------------------------------------------------------------------------
1 | from django.test import TestCase
2 |
3 | # Create your tests here.
4 |
--------------------------------------------------------------------------------
/notifications_extension/models.py:
--------------------------------------------------------------------------------
1 | from django.db import models
2 |
3 | # Create your models here.
4 |
--------------------------------------------------------------------------------
/posts/apps.py:
--------------------------------------------------------------------------------
1 | from django.apps import AppConfig
2 |
3 |
class PostsConfig(AppConfig):
    """Django application configuration for the ``posts`` app."""

    name = "posts"
6 |
--------------------------------------------------------------------------------
/tags/admin.py:
--------------------------------------------------------------------------------
1 | from django.contrib import admin
2 |
3 | from .models import Tag
4 |
5 | admin.site.register(Tag)
6 |
--------------------------------------------------------------------------------
/tags/apps.py:
--------------------------------------------------------------------------------
1 | from django.apps import AppConfig
2 |
3 |
class TagsConfig(AppConfig):
    """Django application configuration for the ``tags`` app."""

    name = "tags"
6 |
--------------------------------------------------------------------------------
/users/apps.py:
--------------------------------------------------------------------------------
1 | from django.apps import AppConfig
2 |
3 |
class UsersConfig(AppConfig):
    """Django application configuration for the ``users`` app."""

    name = "users"
6 |
--------------------------------------------------------------------------------
/posts/admin.py:
--------------------------------------------------------------------------------
1 | from django.contrib import admin
2 |
3 | from .models import Post
4 |
5 | admin.site.register(Post)
6 |
--------------------------------------------------------------------------------
/users/admin.py:
--------------------------------------------------------------------------------
1 | from django.contrib import admin
2 |
3 | from .models import User
4 |
5 | admin.site.register(User)
6 |
--------------------------------------------------------------------------------
/balance/apps.py:
--------------------------------------------------------------------------------
1 | from django.apps import AppConfig
2 |
3 |
class BalanceConfig(AppConfig):
    """Django application configuration for the ``balance`` app."""

    name = "balance"
6 |
--------------------------------------------------------------------------------
/replies/admin.py:
--------------------------------------------------------------------------------
1 | from django.contrib import admin
2 |
3 | from .models import Reply
4 |
5 | admin.site.register(Reply)
6 |
--------------------------------------------------------------------------------
/replies/__init__.py:
--------------------------------------------------------------------------------
default_app_config = 'replies.apps.RepliesConfig'


def get_model():
    """Return the ``Reply`` model class.

    The models module is imported lazily, at call time rather than when the
    package is imported. NOTE(review): this looks like the
    django-contrib-comments custom-app hook (``COMMENTS_APP``); confirm
    against the project settings.
    """
    from . import models
    return models.Reply
7 |
--------------------------------------------------------------------------------
/notifications_extension/admin.py:
--------------------------------------------------------------------------------
1 | from django.contrib import admin
2 |
3 | # from notifications.models import Notification
4 | #
5 | # admin.site.register(Notification)
6 |
--------------------------------------------------------------------------------
/notifications_extension/apps.py:
--------------------------------------------------------------------------------
1 | from django.apps import AppConfig
2 |
3 |
class NotificationExtensionConfig(AppConfig):
    """Django application configuration for the ``notifications_extension`` app."""

    name = "notifications_extension"
6 |
--------------------------------------------------------------------------------
/users/disable_csrf_middleware.py:
--------------------------------------------------------------------------------
1 | # Todo: 临时关闭 csrf 校验,需仔细评估安全问题
2 |
3 | from django.utils.deprecation import MiddlewareMixin
4 |
5 |
class DisableCSRFCheck(MiddlewareMixin):
    """Middleware that turns off CSRF enforcement for every incoming request.

    It sets the private ``_dont_enforce_csrf_checks`` flag on the request.
    Django's CSRF machinery checks this flag before validating a request.

    WARNING: this disables CSRF protection globally. It was introduced as a
    temporary measure (see the TODO at the top of this module), and its
    security impact still needs to be evaluated.
    """

    def process_request(self, request):
        # Returns None, so the request continues down the middleware chain.
        request._dont_enforce_csrf_checks = True
9 |
--------------------------------------------------------------------------------
/replies/permissions.py:
--------------------------------------------------------------------------------
1 | from rest_framework import permissions
2 |
3 |
class NotSelf(permissions.BasePermission):
    """Object-level permission: anyone may read, nobody may act on their own object.

    Safe (read-only) HTTP methods are always allowed. Any other method is
    allowed only when ``obj.user`` is not the requesting user.
    """

    def has_object_permission(self, request, view, obj):
        is_safe_method = request.method in permissions.SAFE_METHODS
        return is_safe_method or obj.user != request.user
10 |
--------------------------------------------------------------------------------
/tags/serializers.py:
--------------------------------------------------------------------------------
1 | from rest_framework import serializers
2 |
3 | from .models import Tag
4 |
5 |
class TagSerializer(serializers.HyperlinkedModelSerializer):
    """Serialize a ``Tag`` as its ``id``, hyperlinked ``url`` and ``name``."""

    class Meta:
        model = Tag
        fields = ('id', 'url', 'name')
15 |
--------------------------------------------------------------------------------
/balance/admin.py:
--------------------------------------------------------------------------------
1 | from django.contrib import admin
2 |
3 | from .models import Record
4 |
5 |
@admin.register(Record)
class RecordAdmin(admin.ModelAdmin):
    """Admin options for balance ``Record`` entries (registered on the default site)."""

    # Columns shown on the admin change-list page.
    list_display = (
        'created_time', 'reward_type', 'coin_type',
        'amount', 'description', 'user',
    )
18 |
--------------------------------------------------------------------------------
/tags/permissions.py:
--------------------------------------------------------------------------------
1 | from rest_framework import permissions
2 |
3 |
class TagPermissionOrReadOnly(permissions.IsAdminUser):
    """Only admin users may add or change tags; everyone else is read-only.

    Permission management for ordinary users may be added later.
    """

    def has_permission(self, request, view):
        granted_to_admin = super().has_permission(request, view)
        return granted_to_admin or request.method in permissions.SAFE_METHODS
11 |
--------------------------------------------------------------------------------
/tags/factories.py:
--------------------------------------------------------------------------------
1 | # factories that automatically create user data
2 | import factory
3 |
4 | from users.factories import UserFactory
5 |
6 | from .models import Tag
7 |
8 |
class TagFactory(factory.DjangoModelFactory):
    """Create ``Tag`` objects with sequential names and an auto-created creator."""

    class Meta:
        model = Tag

    # Yields tag0, tag1, tag2, ...
    name = factory.Sequence(lambda n: 'tag{}'.format(n))
    # NOTE(review): each TagFactory() call builds a *new* user named 'admin'.
    # If User.username is unique, creating a second tag without passing
    # ``creator=`` will raise an IntegrityError. Confirm against users.models.
    creator = factory.SubFactory(UserFactory, username='admin')
15 |
--------------------------------------------------------------------------------
/posts/permissions.py:
--------------------------------------------------------------------------------
1 | from rest_framework import permissions
2 |
3 |
class IsAdminAuthorOrReadOnly(permissions.BasePermission):
    """Regular users may edit their own posts; staff users may edit any post.

    Read-only (safe) methods are allowed for everyone.
    """

    def has_object_permission(self, request, view, obj):
        is_read_only = request.method in permissions.SAFE_METHODS
        return is_read_only or request.user.is_staff or request.user == obj.author
12 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 |
3 | python:
4 | - "3.5"
5 |
6 | install:
7 | - pip install -r requirements/local.txt
8 | - pip install coveralls
9 |
10 | env:
11 | - DJANGO_SETTINGS_MODULE=config.settings.local
12 |
13 | services:
14 | - mysql
15 |
16 | script:
17 | - python runtests.py
18 | - coverage run --source=. manage.py test
19 |
20 | after_success:
21 | - coveralls
22 |
23 | branches:
24 | only:
25 | - master
26 | - dev
--------------------------------------------------------------------------------
/replies/apps.py:
--------------------------------------------------------------------------------
1 | from django.apps import AppConfig
2 |
3 |
class RepliesConfig(AppConfig):
    """App config for ``replies``; registers models with activity streams and moderation."""

    name = 'replies'

    def ready(self):
        # Deferred imports: these modules touch models, so they are loaded only
        # once the app registry is ready. The original import order is kept.
        from .moderation import moderator, ReplyModerator
        from posts.models import Post
        from actstream import registry

        # Register Reply, then Post, with django-activity-stream's registry.
        for model in (self.get_model('Reply'), Post):
            registry.register(model)

        # Hook ReplyModerator up for Post (see replies/moderation.py).
        moderator.register(Post, ReplyModerator)
15 |
--------------------------------------------------------------------------------
/config/wsgi.py:
--------------------------------------------------------------------------------
1 | """
2 | WSGI config for DjangoChina project.
3 |
4 | It exposes the WSGI callable as a module-level variable named ``application``.
5 |
6 | For more information on this file, see
7 | https://docs.djangoproject.com/en/1.11/howto/deployment/wsgi/
8 | """
9 |
10 | import os
11 |
12 | from django.core.wsgi import get_wsgi_application
13 |
# Default to production settings; an already-exported DJANGO_SETTINGS_MODULE wins.
if "DJANGO_SETTINGS_MODULE" not in os.environ:
    os.environ["DJANGO_SETTINGS_MODULE"] = "config.settings.production"

# Module-level WSGI callable that application servers look up.
application = get_wsgi_application()
17 |
--------------------------------------------------------------------------------
/users/utils.py:
--------------------------------------------------------------------------------
1 | from django.conf import settings
2 | from ipware import get_client_ip
3 |
4 |
def get_ip_address_from_request(request):
    """
    Return the client IP address found in ``request``, or None.

    Note:
        With ``settings.DEBUG`` enabled, the address is returned even when it
        is not publicly routable (e.g. 127.0.0.1), which is convenient during
        development. Otherwise only a routable address is returned and
        anything else yields None.
    """
    ip, is_routable = get_client_ip(request)
    if settings.DEBUG:
        return ip
    return ip if (ip is not None and is_routable) else None
19 |
--------------------------------------------------------------------------------
/replies/urls.py:
--------------------------------------------------------------------------------
1 | from django.conf.urls import url
2 |
3 | from . import views
4 |
# URL namespace for this app, used when reversing as 'replies:<name>'.
app_name = 'replies'
# NOTE(review): every route below is commented out, so this URLconf currently
# exposes no endpoints and the ``views`` import is unused.
urlpatterns = [
    # url(r'^$', views.ReplyCreateView.as_view(), name='create_reply'),
    # url(r'^like/$', views.ReplyLikeCreateView.as_view(), name='like_reply'),
    #
    # # The following two APIs are for testing only
    # url(r'^flat-list/$', views.FlatReplyListView.as_view(), name='flat_reply_list'),
    # url(r'^tree-list/$', views.TreeReplyListView.as_view(), name='tree_reply_list'),
]
14 |
--------------------------------------------------------------------------------
/runtests.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import django
5 | from django.conf import settings
6 | from django.test.utils import get_runner
7 |
if __name__ == "__main__":
    # Always run the suite against the local (SQLite) settings, overriding
    # any DJANGO_SETTINGS_MODULE already present in the environment.
    os.environ['DJANGO_SETTINGS_MODULE'] = 'config.settings.local'
    django.setup()

    # Apps whose test suites are executed.
    app_labels = [
        'replies', 'posts', 'tags', 'users', 'balance', 'notifications_extension',
    ]
    runner = get_runner(settings)()
    failures = runner.run_tests(app_labels)
    # Exit status 1 if any test failed, 0 otherwise.
    sys.exit(bool(failures))
17 |
--------------------------------------------------------------------------------
/tags/models.py:
--------------------------------------------------------------------------------
1 | from django.conf import settings
2 | from django.db import models
3 | from model_utils.models import TimeStampedModel
4 |
5 |
class Tag(TimeStampedModel):
    """A uniquely named label that posts can carry (see ``Post.tags``)."""

    name = models.CharField(verbose_name="标签名", max_length=100, unique=True)
    # Deleting the creating user also deletes their tags (CASCADE).
    creator = models.ForeignKey(
        settings.AUTH_USER_MODEL,
        on_delete=models.CASCADE,
        verbose_name="创建者",
    )

    class Meta:
        verbose_name = "标签"
        verbose_name_plural = "标签"

    def __str__(self):
        return self.name
20 |
--------------------------------------------------------------------------------
/utils/mixins.py:
--------------------------------------------------------------------------------
class EagerLoaderMixin(object):
    """
    Mixin providing a generic helper that tells Django, via ``select_related``
    and ``prefetch_related``, which related objects to load up front.

    Any serializer that needs cross-table queries should include this mixin.
    The code that uses it (usually a view) should then call
    ``setup_eager_loading`` with the names of the related fields.
    """

    @staticmethod
    def setup_eager_loading(queryset, select_related=None, prefetch_related=None):
        """Return ``queryset`` with the requested related lookups applied.

        ``select_related`` is applied first, then ``prefetch_related``. Falsy
        arguments (None or empty) are skipped, and the queryset is returned
        unchanged when both are falsy.
        """
        loaders = (
            ('select_related', select_related),
            ('prefetch_related', prefetch_related),
        )
        for method_name, field_names in loaders:
            if field_names:
                queryset = getattr(queryset, method_name)(*field_names)
        return queryset
14 |
--------------------------------------------------------------------------------
/users/permissions.py:
--------------------------------------------------------------------------------
1 | from rest_framework import permissions
2 |
3 |
class IsVerified(permissions.BasePermission):
    """
    An email address must be verified before it can be set as the primary email.
    """

    def has_object_permission(self, request, view, obj):
        return bool(obj.verified)
13 |
14 |
class NotPrimary(permissions.BasePermission):
    """
    The primary email address cannot be deleted.
    """

    def has_object_permission(self, request, view, obj):
        return not obj.primary
24 |
--------------------------------------------------------------------------------
/balance/serializers.py:
--------------------------------------------------------------------------------
1 | from rest_framework import serializers
2 |
3 | from balance.models import Record
4 |
5 |
class BalanceSerializer(serializers.ModelSerializer):
    """Read-only representation of a balance ``Record``.

    Every exposed field is also read-only, so incoming data for these fields
    is ignored.
    """

    class Meta:
        model = Record
        # The same tuple serves as both the exposed and the read-only fields.
        fields = read_only_fields = (
            'reward_type', 'coin_type', 'amount', 'description', 'user',
        )
23 |
--------------------------------------------------------------------------------
/users/factories.py:
--------------------------------------------------------------------------------
1 | # factories that automatically create user data
2 | import factory
3 |
4 | from users.models import User
5 |
6 |
class UserFactory(factory.DjangoModelFactory):
    """Create ``User`` objects with sequential usernames for tests."""

    class Meta:
        model = User

    # Yields user0, user1, user2, ...
    username = factory.Sequence(lambda n: 'user{}'.format(n))
    # Derived from the username, e.g. user0@example.com.
    email = factory.LazyAttribute(lambda obj: '{}@example.com'.format(obj.username))
    # Passed as plain text to create_user(); presumably hashed there, as
    # Django's UserManager does. Confirm for the custom User manager.
    password = 'password'
    mugshot = factory.django.ImageField()

    @classmethod
    def _create(cls, model_class, *args, **kwargs):
        """Create through the model manager's ``create_user`` instead of ``create``."""
        return cls._get_manager(model_class).create_user(*args, **kwargs)
20 |
--------------------------------------------------------------------------------
/balance/factories.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 |
4 | import factory
5 |
6 | from users.factories import UserFactory
7 |
8 | from .models import Record
9 |
10 |
class RecordFactory(factory.DjangoModelFactory):
    """Create balance ``Record`` objects with a random positive amount."""

    class Meta:
        model = Record

    reward_type = 0  # daily check-in reward (Record.REWARD_TYPE)
    coin_type = 2    # copper coin (Record.COIN_TYPE)
    user = factory.SubFactory(UserFactory)

    @factory.lazy_attribute
    def amount(self):
        # |N(10, 5)| rounded up, clamped to at least 1 so the amount is positive.
        return max(math.ceil(abs(random.gauss(10, 5))), 1)
28 |
--------------------------------------------------------------------------------
/requirements/base.txt:
--------------------------------------------------------------------------------
1 | django==1.11
2 | django-bootstrap-form==3.4
3 | django-contrib-comments==1.8.0
4 | django-notifications-hq==1.3
5 | django-rest-auth==0.9.3
6 | django-allauth==0.35.0
7 | django-ipware==2.0.1
8 | djangorestframework-jwt==1.11.0
9 | django-mptt==0.9.0
10 | pycodestyle==2.3.1
11 | djangorestframework==3.8.0
12 | pillow==5.0.0
13 | django-filter==1.1.0
14 | markdown==2.6.11
15 | coreapi==2.3.3
16 | factory-boy==2.10.0
17 | django-model-utils==3.1.1
18 | django-cors-headers==2.2.0
19 | mysqlclient==1.3.12
20 | django-imagekit==4.0.2
21 | django-environ==0.4.5
22 | raven
23 | git+https://github.com/zmrenwu/django-activity-stream.git@master#egg=django-activity-stream
24 |
--------------------------------------------------------------------------------
/posts/factories.py:
--------------------------------------------------------------------------------
1 | import factory
2 |
3 | from users.factories import UserFactory
4 |
5 | from .models import Post
6 |
7 |
class PostFactory(factory.DjangoModelFactory):
    """Create ``Post`` objects with fake text and an auto-created author.

    Pass ``tags=[tag1, tag2, ...]`` to attach tags once the post is saved.
    """

    title = factory.Faker('sentence')
    body = factory.Faker('text')
    author = factory.SubFactory(UserFactory)

    class Meta:
        model = Post

    @factory.post_generation
    def tags(self, create, extracted, **kwargs):
        # Many-to-many rows need a saved post: do nothing for build(), and
        # nothing when no tags were supplied.
        if not (create and extracted):
            return
        for tag in extracted:
            self.tags.add(tag)
26 |
--------------------------------------------------------------------------------
/notifications_extension/filters.py:
--------------------------------------------------------------------------------
1 | import django_filters
2 |
3 | from notifications.models import Notification
4 |
5 |
class NotificationFilter(django_filters.rest_framework.FilterSet):
    """Filter notifications by read state (``?unread=``) and by ``verb``.

    Accepted values for ``unread``:
      * ``true``  -- only unread notifications
      * ``false`` -- only read notifications
      * ``all``   -- no filtering on read state
    Any other value yields an empty queryset.
    """

    unread = django_filters.filters.CharFilter(method='unread_filter')

    def unread_filter(self, queryset, name, value):
        # ``name`` is required by django-filter's ``method=`` signature; unused.
        if value == 'all':
            return queryset
        # Filter with real booleans. The previous 'True'/'False' strings only
        # worked because the model field (presumably a BooleanField) coerced them.
        if value == 'true':
            return queryset.filter(unread=True)
        if value == 'false':
            return queryset.filter(unread=False)
        return queryset.none()

    class Meta:
        model = Notification
        fields = ['unread', 'verb']
22 |
--------------------------------------------------------------------------------
/tags/migrations/0002_tag_creator.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Generated by Django 1.11 on 2018-04-24 12:35
3 | from __future__ import unicode_literals
4 |
5 | from django.conf import settings
6 | from django.db import migrations, models
7 | import django.db.models.deletion
8 |
9 |
class Migration(migrations.Migration):
    """Add the ``creator`` foreign key (to AUTH_USER_MODEL) to ``Tag``.

    Auto-generated by Django 1.11 makemigrations; regenerate rather than
    hand-edit.
    """

    # Marks this as one of the app's initial migrations (presumably generated
    # in the same makemigrations run as 0001_initial).
    initial = True

    dependencies = [
        # Depends on whichever model settings.AUTH_USER_MODEL points to.
        migrations.swappable_dependency(settings.AUTH_USER_MODEL),
        ('tags', '0001_initial'),
    ]

    operations = [
        migrations.AddField(
            model_name='tag',
            name='creator',
            field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL, verbose_name='创建者'),
        ),
    ]
25 | ]
26 |
--------------------------------------------------------------------------------
/balance/migrations/0002_record_user.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Generated by Django 1.11 on 2018-04-24 12:35
3 | from __future__ import unicode_literals
4 |
5 | from django.conf import settings
6 | from django.db import migrations, models
7 | import django.db.models.deletion
8 |
9 |
class Migration(migrations.Migration):
    """Add the ``user`` foreign key (to AUTH_USER_MODEL) to ``Record``.

    Auto-generated by Django 1.11 makemigrations; regenerate rather than
    hand-edit.
    """

    # Marks this as one of the app's initial migrations (presumably generated
    # in the same makemigrations run as 0001_initial).
    initial = True

    dependencies = [
        # Depends on whichever model settings.AUTH_USER_MODEL points to.
        migrations.swappable_dependency(settings.AUTH_USER_MODEL),
        ('balance', '0001_initial'),
    ]

    operations = [
        migrations.AddField(
            model_name='record',
            name='user',
            field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL, verbose_name='用户'),
        ),
    ]
25 | ]
26 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Django中文社区
2 |
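3 | [![Build Status](https://travis-ci.org/DjangoChinaOrg/Django-China-API.svg?branch=dev)](https://travis-ci.org/DjangoChinaOrg/Django-China-API) [![Coverage Status](https://coveralls.io/repos/github/DjangoChinaOrg/Django-China-API/badge.svg?branch=dev)](https://coveralls.io/github/DjangoChinaOrg/Django-China-API?branch=dev)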
4 |
5 | ### 开发团队
6 |
7 | 见:[组织成员列表](https://github.com/orgs/DjangoChinaOrg/people)
8 |
9 | 团队目前配置:
10 |
11 | - 后端开发 3 人
12 | - 前端开发 1 人
13 | - 产品 1 人
14 |
15 | 如果有兴趣加入,欢迎随时与我们取得联系:djangostudyteam@163.com
16 |
17 | ### 技术栈
18 |
19 | 后端使用 Python 3.5、Django 1.11、Django REST framework 开发 API
20 |
21 | 前端使用 Vue 2.0
22 |
23 | ### 第一版的样子
24 |
25 | 可能会是这样:www.pythonzh.cn
26 |
27 | ### 如何贡献?
28 |
29 | 我们将开发需求以 issue 的形式发布在项目的 issue 列表中,并指定了负责的团队开发人员。如果您有兴趣参与,可以通过 issue 的讨论功能告知您的参与计划,这样原本负责的开发人员就可以着手实现其它需求,从而大大加快项目的上线速度。
--------------------------------------------------------------------------------
/replies/migrations/0002_reply_user.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Generated by Django 1.11 on 2018-04-24 12:35
3 | from __future__ import unicode_literals
4 |
5 | from django.conf import settings
6 | from django.db import migrations, models
7 | import django.db.models.deletion
8 |
9 |
class Migration(migrations.Migration):
    """Add the nullable ``user`` foreign key (to AUTH_USER_MODEL) to ``Reply``.

    Auto-generated by Django 1.11 makemigrations; regenerate rather than
    hand-edit.
    """

    # Marks this as one of the app's initial migrations (presumably generated
    # in the same makemigrations run as 0001_initial).
    initial = True

    dependencies = [
        # Depends on whichever model settings.AUTH_USER_MODEL points to.
        migrations.swappable_dependency(settings.AUTH_USER_MODEL),
        ('replies', '0001_initial'),
    ]

    operations = [
        migrations.AddField(
            model_name='reply',
            name='user',
            # SET_NULL: deleting the user keeps the reply with user = NULL.
            field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='reply_comments', to=settings.AUTH_USER_MODEL, verbose_name='user'),
        ),
    ]
25 | ]
26 |
--------------------------------------------------------------------------------
/templates/users/account/password_set.html:
--------------------------------------------------------------------------------
1 | {% extends "account/base.html" %}
2 |
3 | {% load bootstrap %}
4 | {% load i18n %}
5 |
6 | {% block head_title %}{% trans "Set Password" %}{% endblock %}
7 |
8 | {% block content %}
9 |
27 | {% endblock %}
28 |
--------------------------------------------------------------------------------
/config/settings/local.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from .common import *
4 |
DEBUG = True
# Development only: accept requests for any Host header.
ALLOWED_HOSTS = ['*']

# Development-only key committed to the repository; never reuse it elsewhere.
SECRET_KEY = 't3l=1)%^^ftao(2_@p^j_$ordrl4rg4-0z1w@^gvvi64balvbx'

# Write outgoing e-mail to the console instead of sending it.
EMAIL_BACKEND = 'django.core.mail.backends.console.EmailBackend'

# Disabled MySQL configuration driven by environment variables; the SQLite
# database below is used instead.
# NOTE(review): the DB name/user variables are spelled MYSQL_MYSQL_DB_NAME and
# MYSQL_MYSQL_DB_USER, which looks like a typo. Fix them before re-enabling.
# NOTE(review): .travis.yml starts a MySQL service, but these local settings
# (used by runtests.py and CI) point at SQLite.
# envs
# MYSQL_HOST = os.getenv('MYSQL_HOST', '127.0.0.1')
# MYSQL_DB_NAME = os.getenv('MYSQL_MYSQL_DB_NAME', 'django')
# MYSQL_DB_USER = os.getenv('MYSQL_MYSQL_DB_USER', 'root')
# MYSQL_PASSWORD = os.getenv('MYSQL_PASSWORD', '')

# # database
# DATABASES['default'].update(
#     {'HOST': MYSQL_HOST,
#      'NAME': MYSQL_DB_NAME,
#      'USER': MYSQL_DB_USER,
#      'PASSWORD': MYSQL_PASSWORD,
#      })

# sqlite3: replaces the DATABASES setting from common.py entirely, using a
# file named dev.sqlite3 under BASE_DIR.
DATABASES = {
    'default': {
        'ENGINE': 'django.db.backends.sqlite3',
        'NAME': os.path.join(BASE_DIR, 'dev.sqlite3'),
    }
}
33 |
--------------------------------------------------------------------------------
/templates/users/account/password_change.html:
--------------------------------------------------------------------------------
1 | {% extends "account/base.html" %}
2 |
3 | {% load bootstrap %}
4 | {% load i18n %}
5 |
6 | {% block head_title %}{% trans "Change Password" %}{% endblock %}
7 |
8 | {% block content %}
9 |
27 | {% endblock %}
28 |
--------------------------------------------------------------------------------
/manage.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import os
3 | import sys
4 |
if __name__ == "__main__":
    # DJANGO_ENV selects the module under config/settings/ (default "local").
    # An explicitly exported DJANGO_SETTINGS_MODULE takes precedence.
    django_env = os.environ.get('DJANGO_ENV', 'local')
    os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings.%s" % django_env)
    try:
        from django.core.management import execute_from_command_line
    except ImportError:
        # The import may fail for reasons other than Django being missing.
        # Show the friendlier message only when Django itself cannot be
        # imported; otherwise re-raise the original ImportError.
        try:
            import django  # noqa: F401  (imported only to probe availability)
        except ImportError:
            raise ImportError(
                "Couldn't import Django. Are you sure it's installed and "
                "available on your PYTHONPATH environment variable? Did you "
                "forget to activate a virtual environment?"
            )
        raise
    execute_from_command_line(sys.argv)
25 |
--------------------------------------------------------------------------------
/posts/migrations/0002_auto_20180424_2035.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Generated by Django 1.11 on 2018-04-24 12:35
3 | from __future__ import unicode_literals
4 |
5 | from django.conf import settings
6 | from django.db import migrations, models
7 | import django.db.models.deletion
8 |
9 |
class Migration(migrations.Migration):
    """
    Add ``Post.author`` (FK to the configured user model, cascading delete)
    and ``Post.tags`` (many-to-many to ``tags.Tag``).
    """

    initial = True

    dependencies = [
        migrations.swappable_dependency(settings.AUTH_USER_MODEL),
        ('posts', '0001_initial'),
        ('tags', '0001_initial'),
    ]

    operations = [
        migrations.AddField(
            model_name='post',
            name='author',
            field=models.ForeignKey(
                to=settings.AUTH_USER_MODEL,
                on_delete=django.db.models.deletion.CASCADE,
                verbose_name='作者',
            ),
        ),
        migrations.AddField(
            model_name='post',
            name='tags',
            field=models.ManyToManyField(
                to='tags.Tag',
                verbose_name='标签',
            ),
        ),
    ]
32 |
--------------------------------------------------------------------------------
/balance/models.py:
--------------------------------------------------------------------------------
1 | from django.conf import settings
2 | from django.db import models
3 | from model_utils.fields import AutoCreatedField
4 |
5 |
class Record(models.Model):
    """
    One coin reward granted to a user, e.g. the daily check-in reward.
    """

    # Reward kinds (stored as integers).
    REWARD_TYPE = (
        (0, '每日签到奖励'),  # daily check-in reward
    )

    # Coin denominations (stored as integers).
    COIN_TYPE = (
        (0, '金币'),  # gold
        (1, '银币'),  # silver
        (2, '铜币'),  # copper
    )

    created_time = AutoCreatedField(verbose_name="创建时间")
    reward_type = models.IntegerField(choices=REWARD_TYPE, verbose_name="奖励类型")
    coin_type = models.IntegerField(choices=COIN_TYPE, verbose_name="钱币类型")
    amount = models.PositiveIntegerField(verbose_name="数额")
    description = models.CharField(max_length=300, blank=True, verbose_name="描述")
    user = models.ForeignKey(
        settings.AUTH_USER_MODEL,
        on_delete=models.CASCADE,
        verbose_name="用户",
    )

    class Meta:
        verbose_name = "奖励记录"
        verbose_name_plural = verbose_name

    def __str__(self):
        # Raw stored values: "<reward_type>:<coin_type>:<amount> -> <user>"
        return '{}:{}:{} -> {}'.format(
            self.reward_type, self.coin_type, self.amount, self.user)
30 |
--------------------------------------------------------------------------------
/replies/factories.py:
--------------------------------------------------------------------------------
1 | import factory
2 | from django.contrib.contenttypes.models import ContentType
3 | from django.contrib.sites.models import Site
4 |
5 | from posts.factories import PostFactory
6 | from replies.models import Reply
7 | from users.factories import UserFactory
8 |
9 |
class SiteFactory(factory.DjangoModelFactory):
    """
    Build ``Site`` rows with sequential names ("example_<n>") and a matching
    "<name>.com" domain.
    """

    class Meta:
        model = Site

    name = factory.Sequence('example_{}'.format)
    domain = factory.LazyAttribute(lambda site: '{}.com'.format(site.name))
16 |
17 |
class BaseReplyFactory(factory.DjangoModelFactory):
    """
    Base factory for ``Reply``.

    It does not declare ``content_object`` itself; subclasses must provide it.
    The generic-relation fields (``content_type`` and ``object_pk``) are
    derived from that object.
    """

    class Meta:
        model = Reply

    content_type = factory.LazyAttribute(
        lambda reply: ContentType.objects.get_for_model(reply.content_object)
    )
    object_pk = factory.SelfAttribute('content_object.id')
    user = factory.SubFactory(UserFactory)
    site = factory.SubFactory(SiteFactory)
    comment = 'test comment'
28 |
29 |
class PostReplyFactory(BaseReplyFactory):
    """Reply whose target object is a Post built by ``PostFactory``."""

    content_object = factory.SubFactory(PostFactory)
32 |
--------------------------------------------------------------------------------
/users/tests/unit_tests/test_utils.py:
--------------------------------------------------------------------------------
1 | from django.test import TestCase
2 | from django.test.client import RequestFactory
3 |
4 | from users.utils import get_ip_address_from_request
5 |
6 |
class GetIpAddressTests(TestCase):
    """
    Tests for ``get_ip_address_from_request``.
    """

    def setUp(self):
        """Create the request factory shared by the tests."""
        super(GetIpAddressTests, self).setUp()
        self.rf = RequestFactory()

    def _ip_for(self, **meta):
        """Build a GET request to '/' with extra META entries and extract its IP."""
        return get_ip_address_from_request(self.rf.get('/', **meta))

    def test_get_ip_address_from_request(self):
        """
        When the request carries an IP address, that address is returned.
        """
        expected_ip = '210.1.1.1'
        self.assertEqual(self._ip_for(REMOTE_ADDR=expected_ip), expected_ip)

    def test_get_ip_address_from_request__without_ip(self):
        """
        When no IP is supplied explicitly, the function still returns
        normally (with None).
        """
        self.assertEqual(self._ip_for(), None)
34 |
--------------------------------------------------------------------------------
/balance/permissions.py:
--------------------------------------------------------------------------------
1 | from django.utils import timezone
2 | from rest_framework import permissions
3 |
4 | from balance.models import Record
5 |
6 |
class OncePerDay(permissions.BasePermission):
    """
    Allow a user to check in at most once per calendar day (00:00:00-23:59:59).

    Safe (read-only) methods are always permitted. For any other method the
    permission is denied when the user's most recent ``Record`` (by
    ``created_time``) falls within the current day.
    """

    def has_object_permission(self, request, view, obj):
        from datetime import timedelta

        if request.method in permissions.SAFE_METHODS:
            return True

        try:
            # The user's latest check-in record; if there is none, the user
            # has never checked in.
            latest_record = request.user.record_set.latest('created_time')
        except Record.DoesNotExist:
            return True

        now = timezone.now()
        if timezone.is_aware(now):
            # With USE_TZ enabled now() is in UTC; "today" must be the local
            # calendar day, not the UTC one.
            now = timezone.localtime(now)

        # Truncate to midnight including microseconds. Otherwise records made
        # in the first fraction of a second of the day would fall before the
        # window start.
        today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
        tomorrow_start = today_start + timedelta(days=1)

        # Half-open window [today_start, tomorrow_start) covers the whole day.
        checked_in_today = today_start <= latest_record.created_time < tomorrow_start
        return not checked_in_today
31 |
32 |
class IsCurrentUser(permissions.BasePermission):
    """
    Object-level permission that passes only when the object being accessed
    is the requesting user.
    """

    def has_object_permission(self, request, view, obj):
        current_user = request.user
        is_same_user = obj == current_user
        return is_same_user
36 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 DjangoChinaOrg
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/users/tests/unit_tests/test_models.py:
--------------------------------------------------------------------------------
1 | from django.test import TestCase
2 | from django.test.client import RequestFactory
3 |
4 | from users.models import User, update_last_login_ip
5 |
6 |
class UserSignalTests(TestCase):
    """
    Tests for the user signal handlers.
    """

    def setUp(self):
        """Create a request factory and a test user."""
        super(UserSignalTests, self).setUp()
        self.rf = RequestFactory()
        self.user = User.objects.create(nickname='test-user', username='test-user')

    def _log_in_with(self, **meta):
        """Invoke update_last_login_ip for a GET '/' request with the given META entries."""
        request = self.rf.get('/', **meta)
        update_last_login_ip(None, self.user, request)

    def test_update_last_login_ip(self):
        """
        When the request carries an IP address, it is stored in the user's
        last_login_ip field.
        """
        login_ip = '210.1.1.1'
        self._log_in_with(REMOTE_ADDR=login_ip)
        self.assertEqual(self.user.last_login_ip, login_ip)

    def test_update_last_login_ip__without_ip(self):
        """
        Without explicit IP information the handler still completes normally
        and last_login_ip stays None.
        """
        self._log_in_with()
        self.assertEqual(self.user.last_login_ip, None)
35 |
--------------------------------------------------------------------------------
/tags/migrations/0001_initial.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Generated by Django 1.11 on 2018-04-24 12:35
3 | from __future__ import unicode_literals
4 |
5 | from django.db import migrations, models
6 | import django.utils.timezone
7 | import model_utils.fields
8 |
9 |
class Migration(migrations.Migration):
    """
    Create the ``Tag`` table: timestamp columns plus a unique ``name``.
    """

    initial = True

    dependencies = []

    operations = [
        migrations.CreateModel(
            name='Tag',
            fields=[
                ('id', models.AutoField(
                    verbose_name='ID', primary_key=True, serialize=False,
                    auto_created=True)),
                ('created', model_utils.fields.AutoCreatedField(
                    verbose_name='created', default=django.utils.timezone.now,
                    editable=False)),
                ('modified', model_utils.fields.AutoLastModifiedField(
                    verbose_name='modified', default=django.utils.timezone.now,
                    editable=False)),
                ('name', models.CharField(
                    verbose_name='标签名', max_length=100, unique=True)),
            ],
            options={
                'verbose_name': '标签',
                'verbose_name_plural': '标签',
            },
        ),
    ]
32 |
--------------------------------------------------------------------------------
/templates/users/account/signup.html:
--------------------------------------------------------------------------------
1 | {% extends "account/base.html" %}
2 |
3 | {% load bootstrap %}
4 | {% load i18n %}
5 |
6 | {% block head_title %}{% trans "Signup" %}{% endblock %}
7 |
8 | {% block content %}
9 |
10 |
11 |
12 |
13 |
16 |
17 |
{% blocktrans %}Already have an account? Then please sign in .{% endblocktrans %}
18 |
19 |
27 |
28 |
29 |
30 |
31 |
32 |
33 | {% endblock %}
34 |
--------------------------------------------------------------------------------
/replies/models.py:
--------------------------------------------------------------------------------
1 | from actstream.models import Follow
2 | from django.contrib.contenttypes.models import ContentType
3 | from django.db import models
4 | from django_comments.abstracts import CommentAbstractModel
5 | from mptt.models import MPTTModel, TreeForeignKey
6 |
7 |
class Reply(MPTTModel, CommentAbstractModel):
    """
    A comment (django_comments abstract model) arranged in an MPTT tree, so a
    reply can itself be answered through ``parent``.
    """

    # Deleting a parent reply detaches its children (SET_NULL) instead of
    # deleting them.
    parent = TreeForeignKey(
        'self',
        on_delete=models.SET_NULL,
        null=True,
        blank=True,
        related_name='children',
        verbose_name="上级回复",
    )

    class Meta(CommentAbstractModel.Meta):
        verbose_name = "回复"
        verbose_name_plural = verbose_name

    def descendants(self):
        """Return every descendant reply, oldest first (by submit_date)."""
        subtree = self.get_descendants()
        return subtree.order_by('submit_date')

    def descendants_count(self):
        """Return the number of descendant replies."""
        return self.get_descendant_count()

    @property
    def ctype(self):
        """The ContentType row for this reply's model."""
        return ContentType.objects.get_for_model(self)

    @property
    def ctype_id(self):
        """Primary key of :attr:`ctype`."""
        content_type = self.ctype
        return content_type.id

    @property
    def like_count(self):
        """Number of actstream Follow rows with flag 'like' that point at this reply."""
        likes = Follow.objects.for_object(self, flag='like')
        return likes.count()
42 |
--------------------------------------------------------------------------------
/tags/views.py:
--------------------------------------------------------------------------------
1 | from django.db.models import Count
2 | from rest_framework import viewsets
3 | from rest_framework.decorators import action
4 | from rest_framework.response import Response
5 |
6 | from utils.rest_tools import CustomPageNumberPagination
7 |
8 | from .models import Tag
9 | from .permissions import TagPermissionOrReadOnly
10 | from .serializers import TagSerializer
11 |
12 |
class TagPageNumberPagination(CustomPageNumberPagination):
    """
    Pagination for the tag list: 120 tags per page.

    120 is also the upper bound when the base class lets clients choose a
    page size.
    """

    max_page_size = page_size = 120
16 |
17 |
class TagViewSet(viewsets.ModelViewSet):
    """
    Tag API limited to GET and POST; access is governed by
    ``TagPermissionOrReadOnly``.
    """

    queryset = Tag.objects.all()
    serializer_class = TagSerializer
    permission_classes = (TagPermissionOrReadOnly,)
    http_method_names = ['get', 'post']
    pagination_class = TagPageNumberPagination

    def perform_create(self, serializer):
        """
        Save a new tag and record the requesting user as its ``creator``.

        The creator is always taken from the authenticated request, never from
        client-submitted data, so it has to be passed explicitly to save().
        """
        serializer.save(creator=self.request.user)

    @action(detail=False)
    def popular(self, request):
        """
        Return the 10 tags attached to the most posts, most-used first.

        The result is a plain (unpaginated) list.
        """
        tags = self.get_queryset().annotate(
            num_posts=Count('post')).order_by('-num_posts')[:10]
        serializer = self.get_serializer(tags, many=True)
        return Response(serializer.data)
40 |
--------------------------------------------------------------------------------
/balance/migrations/0001_initial.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Generated by Django 1.11 on 2018-04-24 12:35
3 | from __future__ import unicode_literals
4 |
5 | from django.db import migrations, models
6 | import django.utils.timezone
7 | import model_utils.fields
8 |
9 |
class Migration(migrations.Migration):
    """
    Create the ``Record`` (reward record) table.

    The ``user`` foreign key is not part of this migration.
    """

    initial = True

    dependencies = []

    operations = [
        migrations.CreateModel(
            name='Record',
            fields=[
                ('id', models.AutoField(
                    verbose_name='ID', primary_key=True, serialize=False,
                    auto_created=True)),
                ('created_time', model_utils.fields.AutoCreatedField(
                    verbose_name='创建时间', default=django.utils.timezone.now,
                    editable=False)),
                ('reward_type', models.IntegerField(
                    verbose_name='奖励类型', choices=[(0, '每日签到奖励')])),
                ('coin_type', models.IntegerField(
                    verbose_name='钱币类型',
                    choices=[(0, '金币'), (1, '银币'), (2, '铜币')])),
                ('amount', models.PositiveIntegerField(verbose_name='数额')),
                ('description', models.CharField(
                    verbose_name='描述', max_length=300, blank=True)),
            ],
            options={
                'verbose_name': '奖励记录',
                'verbose_name_plural': '奖励记录',
            },
        ),
    ]
34 |
--------------------------------------------------------------------------------
/templates/users/account/password_reset.html:
--------------------------------------------------------------------------------
1 | {% extends "account/base.html" %}
2 |
3 | {% load bootstrap %}
4 | {% load i18n %}
5 | {% load account %}
6 |
7 | {% block head_title %}{% trans "Password Reset" %}{% endblock %}
8 |
9 | {% block content %}
10 |
11 |
12 |
13 |
14 |
17 |
18 | {% if user.is_authenticated %}
19 | {% include "account/snippets/already_logged_in.html" %}
20 | {% endif %}
21 |
{% trans "Forgotten your password? Enter your e-mail address below, and we'll send you an e-mail allowing you to reset it." %}
22 |
23 |
28 |
{% blocktrans %}Please contact us if you have any trouble resetting your password.{% endblocktrans %}
29 |
30 |
31 |
32 |
33 |
34 | {% endblock %}
35 |
--------------------------------------------------------------------------------
/users/tests/unit_tests/test_serializers.py:
--------------------------------------------------------------------------------
1 | from django.test import TestCase
2 |
3 | from users.models import User
4 | from users.serializers import UserDetailsSerializer
5 |
6 |
class UserDetailsSerializerTests(TestCase):
    """
    Tests for UserDetailsSerializer.
    """

    def setUp(self):
        """Create the user to serialize."""
        super(UserDetailsSerializerTests, self).setUp()
        self.user = User.objects.create(nickname='test-user', username='test-user')

    def test_serialize_user_details(self):
        """
        Each serialized field mirrors the matching user attribute or related
        count.
        """
        data = UserDetailsSerializer(self.user).data
        user = self.user
        expected = [
            ('id', user.id),
            ('username', user.username),
            ('nickname', user.nickname),
            ('ip_joined', user.ip_joined),
            ('last_login_ip', user.last_login_ip),
            ('is_superuser', user.is_superuser),
            ('is_staff', user.is_staff),
            ('post_count', user.post_set.count()),
            ('reply_count', user.reply_comments.count()),
        ]
        for field_name, expected_value in expected:
            self.assertEqual(data[field_name], expected_value)
33 |
--------------------------------------------------------------------------------
/templates/users/account/email_confirm.html:
--------------------------------------------------------------------------------
1 | {% extends "account/base.html" %}
2 |
3 | {% load bootstrap %}
4 | {% load i18n %}
5 | {% load account %}
6 |
7 | {% block head_title %}{% trans "Confirm E-mail Address" %}{% endblock %}
8 |
9 |
10 | {% block content %}
11 |
12 |
13 |
14 |
15 |
18 |
19 | {% if confirmation %}
20 | {% user_display confirmation.email_address.user as user_display %}
21 |
{% blocktrans with confirmation.email_address.email as email %}Please confirm that {{ email }} is an e-mail address for user {{ user_display }}.{% endblocktrans %}
22 |
26 | {% else %}
27 | {% url 'account_email' as email_url %}
28 |
{% blocktrans %}This e-mail confirmation link expired or is invalid. Please issue a new e-mail confirmation request .{% endblocktrans %}
29 | {% endif %}
30 |
31 |
32 |
33 |
34 |
35 |
36 | {% endblock %}
37 |
--------------------------------------------------------------------------------
/posts/tests/test_models.py:
--------------------------------------------------------------------------------
1 | from django.test import TestCase
2 |
3 | from ..models import Post
4 | from users.models import User
5 | from tags.models import Tag
6 |
7 |
class PostModelTests(TestCase):
    """
    Tests for Post's ``objects`` and ``public`` managers.
    """

    def setUp(self):
        self.user = User.objects.create_user(
            username='test', email='test@test.com', password='test', nickname='test')
        self.tag = Tag.objects.create(name='test_tag', creator=self.user)

    def _create_tagged_post(self, **fields):
        """Create a post by self.user and attach self.tag to it."""
        post = Post.objects.create(author=self.user, **fields)
        post.tags.add(self.tag)
        return post

    def test_managers(self):
        self.post1 = self._create_tagged_post(
            title='test title first', body='first test body')
        self.post2 = self._create_tagged_post(
            title='test title second', body='second test body', hidden=True)

        # `objects` sees every post; `public` excludes the hidden one.
        self.assertEqual(Post.objects.all().count(), 2)
        self.assertEqual(Post.public.all().count(), 1)
        self.assertEqual(Post.public.get().hidden, False)
35 |
--------------------------------------------------------------------------------
/posts/migrations/0001_initial.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Generated by Django 1.11 on 2018-04-24 12:35
3 | from __future__ import unicode_literals
4 |
5 | from django.db import migrations, models
6 | import django.utils.timezone
7 | import model_utils.fields
8 |
9 |
class Migration(migrations.Migration):
    """
    Create the ``Post`` table.

    The author foreign key and tags many-to-many are not part of this
    migration.
    """

    initial = True

    dependencies = []

    operations = [
        migrations.CreateModel(
            name='Post',
            fields=[
                ('id', models.AutoField(
                    verbose_name='ID', primary_key=True, serialize=False,
                    auto_created=True)),
                ('created', model_utils.fields.AutoCreatedField(
                    verbose_name='created', default=django.utils.timezone.now,
                    editable=False)),
                ('modified', model_utils.fields.AutoLastModifiedField(
                    verbose_name='modified', default=django.utils.timezone.now,
                    editable=False)),
                ('title', models.CharField(verbose_name='标题', max_length=255)),
                ('body', models.TextField(verbose_name='正文', blank=True)),
                ('views', models.PositiveIntegerField(
                    verbose_name='浏览量', default=0, editable=False)),
                ('pinned', models.BooleanField(verbose_name='置顶', default=False)),
                ('highlighted', models.BooleanField(verbose_name='加精', default=False)),
                ('hidden', models.BooleanField(verbose_name='隐藏', default=False)),
            ],
            options={
                'verbose_name': '帖子',
                'verbose_name_plural': '帖子',
            },
        ),
    ]
37 |
--------------------------------------------------------------------------------
/posts/models.py:
--------------------------------------------------------------------------------
1 | from django.conf import settings
2 | from django.contrib.contenttypes.fields import GenericRelation
3 | from django.db import models
4 | from model_utils.models import TimeStampedModel
5 |
6 | from replies.models import Reply
7 |
8 |
class PublicManager(models.Manager):
    """
    Manager that returns only posts that are not hidden, newest first.
    """

    def get_queryset(self):
        # Post's creation timestamp comes from TimeStampedModel and is named
        # ``created``; Post has no ``created_time`` field. Ordering by
        # ``-created_time`` raised FieldError as soon as a listing was evaluated.
        return super(PublicManager, self) \
            .get_queryset().filter(hidden=False).order_by('-created')
13 |
14 |
class Post(TimeStampedModel):
    """
    A forum post.

    The ``created``/``modified`` timestamps are inherited from
    TimeStampedModel.
    """

    title = models.CharField(verbose_name="标题", max_length=255)
    body = models.TextField(verbose_name="正文", blank=True)
    # Not editable in forms/admin; incremented through increase_views().
    views = models.PositiveIntegerField(verbose_name="浏览量", default=0, editable=False)
    pinned = models.BooleanField(verbose_name="置顶", default=False)
    highlighted = models.BooleanField(verbose_name="加精", default=False)
    hidden = models.BooleanField(verbose_name="隐藏", default=False)
    tags = models.ManyToManyField('tags.Tag', verbose_name="标签")
    author = models.ForeignKey(
        settings.AUTH_USER_MODEL,
        on_delete=models.CASCADE,
        verbose_name="作者",
    )
    # Replies point at a post through their generic (content_type, object_pk) pair.
    replies = GenericRelation(
        Reply,
        content_type_field='content_type',
        object_id_field='object_pk',
        verbose_name="回复",
    )

    # Default manager: every post, hidden ones included.
    objects = models.Manager()
    # Only posts that are not hidden.
    public = PublicManager()

    class Meta:
        verbose_name = "帖子"
        verbose_name_plural = verbose_name

    def __str__(self):
        return self.title

    def increase_views(self):
        """Add one to the view counter and persist only the ``views`` column."""
        # NOTE(review): read-modify-write on the in-memory value, so concurrent
        # requests can lose increments; consider an F('views') + 1 update.
        self.views = self.views + 1
        self.save(update_fields=['views'])
49 |
--------------------------------------------------------------------------------
/posts/tests/test_serializers.py:
--------------------------------------------------------------------------------
1 | from django.test import TestCase
2 | from django.contrib.contenttypes.models import ContentType
3 |
4 | from ..models import Post
5 | from ..serializers import IndexPostListSerializer
6 | from users.models import User
7 | from tags.models import Tag
8 |
9 |
class PostSerializerTests(TestCase):
    """
    Tests for the post serializers.
    """

    def setUp(self):
        self.user = User.objects.create_user(
            username='test', email='test@test.com', password='test', nickname='test')
        self.tag1 = Tag.objects.create(name='test_tag', creator=self.user)
        self.tag2 = Tag.objects.create(name='another test_tag', creator=self.user)
        self.post = Post.objects.create(
            title='test title first', body='first test body', author=self.user)
        for tag in (self.tag1, self.tag2):
            self.post.tags.add(tag)

    def test_post_serializer_detail(self):
        data = IndexPostListSerializer(self.post, context={'request': None}).data
        # Plain model attributes are copied through unchanged.
        for field_name in ('id', 'title', 'views', 'pinned', 'highlighted'):
            self.assertEqual(data[field_name], getattr(self.post, field_name))
        self.assertEqual(data['reply_count'], self.post.replies.count())
38 |
--------------------------------------------------------------------------------
/config/settings/production.py:
--------------------------------------------------------------------------------
1 | from .common import *
2 | import environ
3 |
ALLOWED_HOSTS = ['.dj-china.org', 'localhost', '127.0.0.1', '0.0.0.0']

env = environ.Env(
    # Casting and default value per variable.
    # DEBUG must default to False in production: Django's debug pages expose
    # settings and stack traces to anyone who can trigger an error. Set
    # DEBUG=True explicitly in the environment if it is ever really needed.
    DEBUG=(bool, False),
    SOCIAL_LOGIN_GITHUB_CALLBACK_URL=(str, 'http://127.0.0.1:8000/social-auth/github/loginsuccess')
)

# Load variables from a .env file into the environment (presumably
# config/settings/.env, which .gitignore excludes).
environ.Env.read_env()
DEBUG = env('DEBUG')  # default False
SECRET_KEY = env('SECRET_KEY')

# Sentry DSN (raven client configuration)
RAVEN_CONFIG = {
    'dsn': env('SENTRY_DSN'),
}

# import sentry_sdk
# from sentry_sdk.integrations.django import DjangoIntegration
#
# sentry_sdk.init(
#     dsn=env('SENTRY_DSN'),
#     integrations=[DjangoIntegration()]
# )

# GitHub login
SOCIAL_LOGIN_GITHUB_CALLBACK_URL = env('SOCIAL_LOGIN_GITHUB_CALLBACK_URL')

# MySQL
DATABASES = {
    'default': {
        'ENGINE': 'django.db.backends.mysql',
        'NAME': env('MYSQL_NAME'),
        'USER': env('MYSQL_USER'),
        'PASSWORD': env('MYSQL_PASSWORD'),
        'HOST': 'localhost',
        'PORT': '3306',
        'OPTIONS': {
            'autocommit': True,
            'init_command': "SET sql_mode='STRICT_TRANS_TABLES'",
            'charset': 'utf8mb4',
        },
        'TEST': {
            'NAME': 'django_test',
            'CHARSET': 'utf8',
            'COLLATION': 'utf8_general_ci',
        }
    }
}

# Email: Tencent Cloud enterprise mailbox (smtp.exmail.qq.com, SSL on port 465)
EMAIL_BACKEND = 'django.core.mail.backends.smtp.EmailBackend'
EMAIL_HOST = 'smtp.exmail.qq.com'
EMAIL_PORT = 465
EMAIL_USE_SSL = True
EMAIL_USE_LOCALTIME = True
EMAIL_HOST_USER = env('EMAIL_HOST_USER')
EMAIL_HOST_PASSWORD = env('EMAIL_HOST_PASSWORD')

DEFAULT_FROM_EMAIL = 'Django中文社区 <%s>' % EMAIL_HOST_USER
SERVER_EMAIL = EMAIL_HOST_USER
65 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by .ignore support plugin (hsz.mobi)
2 | ### Python template
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | .hypothesis/
50 |
51 | # Translations
52 | *.mo
53 | *.pot
54 |
55 | # Django stuff:
56 | *.log
57 | .static_storage/
58 | .media/
59 | local_settings.py
60 |
61 | # Flask stuff:
62 | instance/
63 | .webassets-cache
64 |
65 | # Scrapy stuff:
66 | .scrapy
67 |
68 | # Sphinx documentation
69 | docs/_build/
70 |
71 | # PyBuilder
72 | target/
73 |
74 | # Jupyter Notebook
75 | .ipynb_checkpoints
76 |
77 | # pyenv
78 | .python-version
79 |
80 | # celery beat schedule file
81 | celerybeat-schedule
82 |
83 | # SageMath parsed files
84 | *.sage.py
85 |
86 | # Environments
87 | config/settings/.env
88 | .venv
89 | env/
90 | venv/
91 | ENV/
92 | env.bak/
93 | venv.bak/
94 |
95 | # Spyder project settings
96 | .spyderproject
97 | .spyproject
98 |
99 | # Rope project settings
100 | .ropeproject
101 |
102 | # mkdocs documentation
103 | /site
104 |
105 | # mypy
106 | .mypy_cache/
107 | *.sqlite3
108 | .idea/
109 | media/
110 |
111 | # vscode
112 | .vscode/
113 |
--------------------------------------------------------------------------------
/users/jwt_middleware.py:
--------------------------------------------------------------------------------
1 | from django.conf import settings
2 | from rest_framework_jwt.serializers import RefreshJSONWebTokenSerializer
3 | from rest_framework_jwt.settings import api_settings
4 |
# Token helpers configured in the REST framework JWT settings, resolved once
# at import time.
jwt_payload_handler, jwt_encode_handler = (
    api_settings.JWT_PAYLOAD_HANDLER,
    api_settings.JWT_ENCODE_HANDLER,
)
7 |
8 |
class JWTMiddleware(object):
    """
    Middleware that keeps the ``JWT`` cookie in sync with the session login state.

    Runs after the view:

    * Authenticated user, 200 response: (re)set the ``JWT`` cookie for 24 hours.
      The existing token is refreshed; a new one is issued when the cookie is
      missing, rejected by the refresh serializer, or belongs to another user.
    * Authenticated user, any other status: the response is returned untouched.
    * Unauthenticated user: a leftover ``JWT`` cookie is deleted.
    """

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        response = self.get_response(request)
        if not hasattr(request, 'user'):
            return response

        # ``is_authenticated`` is an attribute: a CallableBool on Django
        # 1.10/1.11 and a plain bool from Django 2.0, where calling it raises
        # TypeError.
        if request.user.is_authenticated:
            if response.status_code != 200:
                return response
            response.set_cookie(
                'JWT',
                value=self._get_jwt(request),
                max_age=24 * 60 * 60
            )
        elif 'JWT' in request.COOKIES:
            # The user has logged out: clear the stale JWT.
            response.delete_cookie('JWT')
        return response

    @staticmethod
    def _get_jwt(request):
        """Return a refreshed copy of the cookie's JWT, or a freshly issued one."""
        old_token = request.COOKIES.get('JWT')
        if old_token is not None:
            serializer = RefreshJSONWebTokenSerializer(data={'token': old_token})
            if serializer.is_valid():
                jwt_and_user = serializer.object
                if jwt_and_user['user'] == request.user:
                    return jwt_and_user['token']
        # No cookie, a token the serializer rejects, or a token issued for a
        # different user: mint a new one for the current user.
        return jwt_encode_handler(jwt_payload_handler(request.user))
51 |
--------------------------------------------------------------------------------
/notifications_extension/views.py:
--------------------------------------------------------------------------------
1 | from rest_framework import permissions, viewsets, mixins
2 | from notifications.models import Notification
3 | from rest_framework import filters
4 | from django_filters.rest_framework import DjangoFilterBackend
5 | from rest_framework.response import Response
6 | from rest_framework import status
7 | from rest_framework.decorators import action
8 |
9 | from .serializers import NotificationSerializer
10 | from .filters import NotificationFilter
11 |
12 |
class NotificationViewSet(mixins.RetrieveModelMixin,
                          mixins.UpdateModelMixin,
                          mixins.DestroyModelMixin,
                          mixins.ListModelMixin,
                          viewsets.GenericViewSet):
    """
    Notifications of the authenticated user.

    Supports list/retrieve, update (marks one notification as read), destroy
    (soft delete via the ``deleted`` flag), and ``mark_all_as_read``.
    """
    serializer_class = NotificationSerializer
    permission_classes = [permissions.IsAuthenticated, ]
    filter_backends = (DjangoFilterBackend, filters.OrderingFilter)
    ordering_fields = ('timestamp',)
    filter_class = NotificationFilter  # filter set

    def get_queryset(self):
        # Only the requesting user's notifications, restricted by the
        # queryset's active() helper.
        return Notification.objects.filter(recipient=self.request.user).active()

    def perform_destroy(self, instance):
        # Soft delete. ``instance`` comes from get_object(), so it is already
        # scoped to get_queryset(); re-fetching it by pk is unnecessary.
        instance.deleted = True
        instance.save()

    def update(self, request, *args, **kwargs):
        """
        Mark one notification as read.

        Returns 404 for an unknown or malformed pk, 403 if the notification
        belongs to another user, and 200 otherwise.
        """
        try:
            instance = Notification.objects.get(id=self.kwargs['pk'])
        except (Notification.DoesNotExist, ValueError):
            # Previously an unknown/non-numeric pk escaped as a 500 error.
            return Response(status=status.HTTP_404_NOT_FOUND)
        if instance.recipient != request.user:
            return Response(status=status.HTTP_403_FORBIDDEN)
        instance.unread = False
        instance.save()
        return Response(status=status.HTTP_200_OK)

    @action(methods=['post'], detail=False)
    def mark_all_as_read(self, request):
        """Mark every notification of the requesting user as read."""
        Notification.objects.filter(recipient=request.user).mark_all_as_read(recipient=request.user)
        return Response(status=status.HTTP_200_OK)
46 |
--------------------------------------------------------------------------------
/templates/users/account/login.html:
--------------------------------------------------------------------------------
1 | {% extends "account/base.html" %}
2 |
3 | {% load bootstrap %}
4 | {% load i18n %}
5 | {% load account socialaccount %}
6 |
7 | {% block head_title %}{% trans "Sign In" %}{% endblock %}
8 |
9 | {% block content %}
10 |
11 | {% get_providers as socialaccount_providers %}
12 |
13 |
14 |
15 |
33 |
34 | {% if socialaccount_providers %}
35 |
36 |
39 |
40 |
41 |
42 | {% include "socialaccount/snippets/provider_list.html" with process="login" %}
43 |
44 |
45 |
46 |
47 | {% include "socialaccount/snippets/login_extra.html" %}
48 | {% else %}
49 |
{% blocktrans %}If you have not created an account yet, then please
50 | sign up first.{% endblocktrans %}
51 | {% endif %}
52 |
53 |
54 |
55 |
56 | {% endblock %}
--------------------------------------------------------------------------------
/notifications_extension/serializers.py:
--------------------------------------------------------------------------------
1 | from rest_framework import serializers
2 | from notifications.models import Notification
3 |
4 | from users.serializers import UserSimpleDetailsSerializer
5 |
6 |
class NotificationSerializer(serializers.ModelSerializer):
    """
    Serialize a ``Notification`` for the API.

    Three kinds of notification exist at the moment:

    1. A post receives a reply; the post's author is notified:
       recipient: the post author
       actor: the replier
       target: the post
       action_object: the new reply
       verb: 'reply'

    2. Someone answers another user's reply; the replied-to user is notified:
       recipient: the user being replied to
       actor: the replier
       target: the post
       action_object: the new reply
       verb: 'respond'

    3. A reply is liked:
       recipient: the user whose reply was liked
       actor: presumably the user who liked the reply (TODO confirm)
       target: the liked reply
       action_object: the post the liked reply belongs to
       verb: 'like'

    ``post`` and ``reply`` serialize to ``None`` when the related object can
    no longer be resolved (e.g. it was deleted). ``reply`` is also ``None``
    for unknown verbs.
    """
    actor = UserSimpleDetailsSerializer()
    post = serializers.SerializerMethodField()
    reply = serializers.SerializerMethodField()

    def get_reply(self, obj):
        # 'like' keeps the liked reply in `target`; 'reply' and 'respond'
        # keep the new reply in `action_object`.
        if obj.verb == 'like':
            reply = obj.target
        elif obj.verb in ('reply', 'respond'):
            reply = obj.action_object
        else:
            return None
        if reply is None:
            # The generic relation no longer resolves to an object.
            return None
        return {
            'comment': reply.comment
        }

    def get_post(self, obj):
        # 'like' keeps the post in `action_object`; every other verb keeps
        # it in `target`.
        post = obj.action_object if obj.verb == 'like' else obj.target
        if post is None:
            # The generic relation no longer resolves to an object.
            return None
        return {
            'post_id': post.id,
            'post_title': post.title
        }

    class Meta:
        model = Notification
        fields = ('id', 'unread', 'actor', 'verb', 'timestamp', 'deleted', 'recipient', 'post', 'reply')
69 |
--------------------------------------------------------------------------------
/replies/moderation.py:
--------------------------------------------------------------------------------
1 | from django_comments.moderation import CommentModerator
2 | from django_comments.moderation import Moderator as DjangoCommentModerator
3 | from notifications.signals import notify
4 |
5 |
class Moderator(DjangoCommentModerator):
    """
    django_comments moderator whose post-save hook calls ``notify``.

    After a comment is saved, the moderation object registered for the
    comment's target model (if any) has ``notify(comment, content_object,
    request)`` called on it; comments on unregistered models are ignored.
    """

    def post_save_moderation(self, sender, comment, request, **kwargs):
        target_model = comment.content_type.model_class()
        if target_model in self._registry:
            registered_moderation = self._registry[target_model]
            registered_moderation.notify(comment, comment.content_object, request)
12 |
13 |
class ReplyModerator(CommentModerator):
    """Moderation class for replies that sends notifications for new replies."""

    def notify(self, reply, content_object, request):
        """
        Send the notifications triggered by a newly posted ``reply``.

        * Reply to another reply: the parent reply's author gets a 'respond'
          notification unless they replied to themselves. The post author
          also gets a 'reply' notification when they are neither the parent
          reply's author nor the one replying.
        * Top-level reply: the post author gets a 'reply' notification
          unless they wrote the reply themselves.

        Every notification has the replier as sender, the new reply as
        action_object and the post (``content_object``) as target.
        """
        replier = reply.user
        post_author = content_object.author

        def send(recipient, verb):
            notify.send(
                sender=replier,
                recipient=recipient,
                verb=verb,
                action_object=reply,
                target=content_object,
            )

        if not reply.parent:
            # Direct reply to the post: notify its author, unless self-reply.
            if post_author != replier:
                send(post_author, 'reply')
            return

        parent_author = reply.parent.user
        # Notify whoever was replied to; replying to yourself needs no notice.
        if parent_author != replier:
            send(parent_author, 'respond')
        # The post author was not the one replied to and did not write this
        # reply, so they still need to hear about it.
        if parent_author != post_author and post_author != replier:
            send(post_author, 'reply')
49 |
50 |
# Shared moderator instance for replies. Only models present in its registry
# get notifications (see Moderator.post_save_moderation); registration
# presumably happens elsewhere in the app — confirm.
moderator = Moderator()
52 |
--------------------------------------------------------------------------------
/users/models.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from django.core.files.base import ContentFile
4 | from django.contrib.auth.models import AbstractUser
5 | from django.contrib.auth.signals import user_logged_in
6 | from django.db import models
7 |
8 | from .mugshot import Avatar
9 | from .utils import get_ip_address_from_request
10 |
11 |
def user_mugshot_path(instance, filename):
    """
    Return the upload path for a user's mugshot.

    The path is ``mugshots/<username>/<filename>``, joined with
    ``os.path.join``.
    """
    segments = ['mugshots', instance.username, filename]
    return os.path.join(*segments)
14 |
15 | from django.db import models
16 | from imagekit.models import ImageSpecField
17 | from imagekit.processors import ResizeToFill
class User(AbstractUser):
    """
    Project user model.

    Extends ``AbstractUser`` with the last-login and registration IP
    addresses, a unique nickname, and an avatar (``mugshot``) that also has a
    100x100 JPEG thumbnail produced by imagekit.
    """
    last_login_ip = models.GenericIPAddressField(
        "最近一次登陆IP",
        unpack_ipv4=True,
        blank=True,
        null=True
    )
    ip_joined = models.GenericIPAddressField("注册IP", unpack_ipv4=True, blank=True, null=True)

    nickname = models.CharField("昵称", max_length=50, unique=True)
    mugshot = models.ImageField("头像", upload_to=user_mugshot_path)
    mugshot_thumbnail = ImageSpecField(source='mugshot',
                                       processors=[ResizeToFill(100, 100)],
                                       format='JPEG',
                                       options={'quality': 60})

    def __str__(self):
        return self.username

    def save(self, *args, **kwargs):
        """
        Fill in defaults, then save.

        * A user without a mugshot gets a generated 480x480 avatar derived
          from the username, stored as ``default_mugshot.png``.
        * On the first save, an empty nickname defaults to the username.
        """
        if not self.mugshot:
            self._attach_generated_mugshot()
        is_new = not self.pk
        if is_new and not self.nickname:
            self.nickname = self.username
        super(User, self).save(*args, **kwargs)

    def _attach_generated_mugshot(self):
        """Generate an avatar from the username and assign it without saving the row."""
        generator = Avatar(rows=10, columns=10)
        image_bytes = generator.get_image(string=self.username, width=480, height=480, pad=10)
        self.mugshot.save('default_mugshot.png', ContentFile(image_bytes), save=False)
54 |
55 |
def update_last_login_ip(sender, user, request, **kwargs):
    """
    ``user_logged_in`` receiver that records the IP address of the login.

    If no IP address can be extracted from ``request``, the user is left
    untouched. Otherwise ``last_login_ip`` is set and the user is saved.
    """
    ip_address = get_ip_address_from_request(request)
    if not ip_address:
        return
    user.last_login_ip = ip_address
    user.save()
64 |
65 |
66 | user_logged_in.connect(update_last_login_ip)
67 |
--------------------------------------------------------------------------------
/utils/rest_tools.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | from rest_framework.pagination import PageNumberPagination
4 | from rest_framework.response import Response
5 | from rest_framework.exceptions import NotFound
6 | from django.core.paginator import InvalidPage
7 | from django.utils import six
8 |
9 |
class CustomPageNumberPagination(PageNumberPagination):
    """
    Page-number pagination whose page size can be chosen by the client.

    ``page`` selects the page and ``page_size`` the number of items per page
    (default 20, ``max_page_size`` = 100). The page size actually used is
    echoed back in the paginated response.
    """
    page_size = 20
    page_size_query_param = 'page_size'
    page_query_param = 'page'
    max_page_size = 100

    def paginate_queryset(self, queryset, request, view=None):
        """
        Paginate ``queryset`` for ``request``.

        Same as ``PageNumberPagination.paginate_queryset`` except that the page
        size resolved for this request is stored on ``self.page_size`` so that
        ``get_paginated_response`` can report it.

        Returns the items of the requested page as a list, or ``None`` when
        pagination is disabled (no page size). Raises ``NotFound`` for an
        invalid page number.
        """
        size = self.get_page_size(request)
        if not size:
            return None
        # Remember the effective size; it is reported in the response body.
        self.page_size = size

        paginator = self.django_paginator_class(queryset, size)
        requested_page = request.query_params.get(self.page_query_param, 1)
        if requested_page in self.last_page_strings:
            requested_page = paginator.num_pages

        try:
            self.page = paginator.page(requested_page)
        except InvalidPage as exc:
            raise NotFound(self.invalid_page_message.format(
                page_number=requested_page, message=six.text_type(exc)
            ))

        # The browsable API should display pagination controls.
        if paginator.num_pages > 1 and self.template is not None:
            self.display_page_controls = True

        self.request = request
        return list(self.page)

    def get_paginated_response(self, data):
        """Wrap ``data`` in an ordered payload carrying pagination metadata."""
        paginator = self.page.paginator
        payload = OrderedDict()
        payload['page_size'] = self.page_size
        payload['current_page'] = self.page.number
        payload['last_page'] = paginator.num_pages
        payload['next'] = self.get_next_link()
        payload['previous'] = self.get_previous_link()
        payload['count'] = paginator.count
        payload['data'] = data
        return Response(payload)
61 |
--------------------------------------------------------------------------------
/templates/users/socialaccount/connections.html:
--------------------------------------------------------------------------------
1 | {% extends "socialaccount/base.html" %}
2 |
3 | {% load i18n %}
4 |
5 | {% block head_title %}{% trans "Account Connections" %}{% endblock %}
6 |
7 | {% block content %}
8 |
9 |
10 |
11 |
12 |
15 |
16 | {% if form.accounts %}
17 |
{% blocktrans %}You can sign in to your account using any of the following third party accounts:{% endblocktrans %}
18 |
43 |
44 | {% else %}
45 |
{% trans 'You currently have no social network accounts connected to this account.' %}
46 | {% endif %}
47 |
48 |
{% trans 'Add a 3rd Party Account' %}
49 |
50 |
51 | {% include "socialaccount/snippets/provider_list.html" with process="connect" %}
52 |
53 |
54 | {% include "socialaccount/snippets/login_extra.html" %}
55 |
56 |
57 |
58 |
59 |
60 |
61 | {% endblock %}
62 |
--------------------------------------------------------------------------------
/replies/tests/test_models.py:
--------------------------------------------------------------------------------
1 | from django.contrib.contenttypes.models import ContentType
2 | from django.contrib.sites.models import Site
3 | from django.test import TestCase
4 | from django.utils.timezone import now, timedelta
5 |
6 | from posts.models import Post
7 | from users.models import User
8 |
9 | from ..models import Reply
10 |
11 |
class RelyModelTestCase(TestCase):
    """Tests for the ``Reply`` model."""

    def setUp(self):
        self.user = User.objects.create_user(username='test',
                                             email='test@test.com',
                                             password='test',
                                             nickname='test')
        self.post = Post.objects.create(title='test title', author=self.user)
        self.site = Site.objects.create(name='test', domain='test.com')
        self.post_ct = ContentType.objects.get_for_model(self.post)
        self.post_id = self.post.id

    def _create_reply(self, comment, minutes=0, parent=None):
        """Create a reply by ``self.user`` on ``self.post``, submitted ``minutes`` from now."""
        return Reply.objects.create(
            content_type=self.post_ct,
            object_pk=self.post_id,
            site=self.site,
            user=self.user,
            comment=comment,
            parent=parent,
            submit_date=now() + timedelta(minutes=minutes),
        )

    def test_descendants(self):
        """descendants() returns every reply below the root, in the expected order."""
        self.root_reply = self._create_reply('root reply')
        self.child_reply = self._create_reply('child reply', 1, parent=self.root_reply)
        self.another_child_reply = self._create_reply('another child reply', 2, parent=self.root_reply)
        self.grandchild_reply = self._create_reply('grandchild reply', 3, parent=self.child_reply)

        expected = [self.child_reply, self.another_child_reply, self.grandchild_reply]
        self.assertQuerysetEqual(
            self.root_reply.descendants(),
            [repr(reply) for reply in expected]
        )
67 |
--------------------------------------------------------------------------------
/replies/views.py:
--------------------------------------------------------------------------------
1 | from actstream.models import Follow
2 | from django.db.utils import IntegrityError
3 | from django_comments import signals
4 | from notifications.signals import notify
5 | from rest_framework import mixins
6 | from rest_framework import permissions
7 | from rest_framework import status
8 | from rest_framework import viewsets
9 | from rest_framework.decorators import action
10 | from rest_framework.response import Response
11 |
12 | from replies.permissions import NotSelf
13 | from replies.serializers import (FollowSerializer,
14 | ReplyCreationSerializer)
15 | from replies.models import Reply
16 |
17 |
class ReplyViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet):
    """
    Reply endpoints: create a reply, and like / unlike an existing one.

    Requires an authenticated user. Only public, non-removed replies can be
    looked up.
    """
    serializer_class = ReplyCreationSerializer
    permission_classes = [permissions.IsAuthenticated, ]
    queryset = Reply.objects.filter(is_public=True, is_removed=False)

    def perform_create(self, serializer):
        """
        Save the reply as the requesting user, then send
        ``comment_was_posted`` so the corresponding notifications are created.
        """
        parent_reply = serializer.validated_data.get('parent')
        reply = serializer.save(user=self.request.user, parent=parent_reply)

        # Create the corresponding notifications.
        signals.comment_was_posted.send(
            sender=reply.__class__,
            comment=reply,
            request=self.request
        )

    @action(
        methods=['post', 'delete'],
        detail=True,
        permission_classes=[NotSelf, permissions.IsAuthenticated],
        serializer_class=FollowSerializer,
    )
    def like(self, request, pk=None):
        """
        Like (POST) or unlike (DELETE) a reply.

        POST creates a ``Follow`` with ``flag='like'`` and notifies the reply's
        author. It returns 201 with the serialized follow, or 400 when creating
        the follow raises ``IntegrityError`` (treated as "already liked").
        DELETE removes the user's like and always returns 204.
        """
        # Local import: only needed for the savepoint below.
        from django.db import transaction

        reply = self.get_object()

        if self.request.method == 'POST':
            try:
                # Run the INSERT inside its own savepoint. Catching an
                # IntegrityError outside atomic() would leave an enclosing
                # transaction (ATOMIC_REQUESTS, TestCase) unusable.
                with transaction.atomic():
                    follow = Follow.objects.create(
                        user=self.request.user,
                        content_type=reply.ctype,
                        object_id=reply.id,
                        flag='like',
                    )
            except IntegrityError:
                return Response(
                    {'detail': '已经赞过'},
                    status=status.HTTP_400_BAD_REQUEST
                )
            serializer = FollowSerializer(follow, context={'request': request})
            # Notify the author of the liked reply.
            data = {
                'recipient': reply.user,
                'verb': 'like',
                'action_object': reply.content_object,
                'target': reply
            }
            notify.send(sender=self.request.user, **data)
            return Response(serializer.data, status=status.HTTP_201_CREATED)
        elif self.request.method == 'DELETE':
            Follow.objects.filter(
                user=self.request.user,
                content_type=reply.ctype_id,
                object_id=reply.id,
                flag='like'
            ).delete()
            return Response(status=status.HTTP_204_NO_CONTENT)
74 |
--------------------------------------------------------------------------------
/replies/migrations/0001_initial.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Generated by Django 1.11 on 2018-04-24 12:35
3 | from __future__ import unicode_literals
4 |
5 | from django.db import migrations, models
6 | import django.db.models.deletion
7 | import mptt.fields
8 |
9 |
class Migration(migrations.Migration):
    """
    Initial migration creating the ``Reply`` model.

    Auto-generated by Django 1.11 (see the file header). Applied migrations
    should not be edited; add a new migration for schema changes instead.
    """

    initial = True

    dependencies = [
        ('sites', '0002_alter_domain_unique'),
        ('contenttypes', '0002_remove_content_type_name'),
    ]

    operations = [
        migrations.CreateModel(
            name='Reply',
            fields=[
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                # Together with content_type (below) this presumably identifies
                # the object the reply is attached to.
                ('object_pk', models.TextField(verbose_name='object ID')),
                ('user_name', models.CharField(blank=True, max_length=50, verbose_name="user's name")),
                ('user_email', models.EmailField(blank=True, max_length=254, verbose_name="user's email address")),
                ('user_url', models.URLField(blank=True, verbose_name="user's URL")),
                ('comment', models.TextField(max_length=3000, verbose_name='comment')),
                ('submit_date', models.DateTimeField(db_index=True, default=None, verbose_name='date/time submitted')),
                ('ip_address', models.GenericIPAddressField(blank=True, null=True, unpack_ipv4=True, verbose_name='IP address')),
                ('is_public', models.BooleanField(default=True, help_text='Uncheck this box to make the comment effectively disappear from the site.', verbose_name='is public')),
                ('is_removed', models.BooleanField(default=False, help_text='Check this box if the comment is inappropriate. A "This comment has been removed" message will be displayed instead.', verbose_name='is removed')),
                # Tree bookkeeping columns, presumably maintained by django-mptt
                # for the TreeForeignKey ``parent`` below.
                ('lft', models.PositiveIntegerField(db_index=True, editable=False)),
                ('rght', models.PositiveIntegerField(db_index=True, editable=False)),
                ('tree_id', models.PositiveIntegerField(db_index=True, editable=False)),
                ('level', models.PositiveIntegerField(db_index=True, editable=False)),
                ('content_type', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='content_type_set_for_reply', to='contenttypes.ContentType', verbose_name='content type')),
                # Deleting a parent reply sets its children's parent to NULL.
                ('parent', mptt.fields.TreeForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='children', to='replies.Reply', verbose_name='上级回复')),
                ('site', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='sites.Site')),
                # No user column here; presumably added by 0002_reply_user.
            ],
            options={
                'verbose_name': '回复',
                'verbose_name_plural': '回复',
                'ordering': ('submit_date',),
                'permissions': [('can_moderate', 'Can moderate comments')],
                'abstract': False,
            },
        ),
    ]
50 |
--------------------------------------------------------------------------------
/templates/users/account/base.html:
--------------------------------------------------------------------------------
1 |
2 |
10 |
11 |
12 |
13 |
14 |
15 |
16 | {% block head_title %}{% endblock %}
17 | {% block extra_head %}
18 | {% endblock %}
19 |
20 |
21 | {% block body %}
22 |
23 | {% if messages %}
24 |
25 | {% for message in messages %}
26 |
{{message}}
27 | {% endfor %}
28 |
29 | {% endif %}
30 |
31 |
32 | Django中国
33 |
34 |
35 |
36 |
37 |
38 |
39 | {% if user.is_authenticated %}
40 |
41 | 修改Email
42 |
43 |
44 | 登出
45 |
46 | {% else %}
47 |
48 | 登录
49 |
50 |
51 | 注册
52 |
53 | {% endif %}
54 |
55 |
56 |
57 |
58 | {% block content %}
59 | {% endblock %}
60 |
61 | {% endblock %}
62 | {% block extra_body %}
63 | {% endblock %}
64 |
65 |
66 |
--------------------------------------------------------------------------------
/config/urls.py:
--------------------------------------------------------------------------------
1 | """DjangoChina URL Configuration
2 |
3 | The `urlpatterns` list routes URLs to views. For more information please see:
4 | https://docs.djangoproject.com/en/1.11/topics/http/urls/
5 | Examples:
6 | Function views
7 | 1. Add an import: from my_app import views
8 | 2. Add a URL to urlpatterns: url(r'^$', views.home, name='home')
9 | Class-based views
10 | 1. Add an import: from other_app.views import Home
11 | 2. Add a URL to urlpatterns: url(r'^$', Home.as_view(), name='home')
12 | Including another URLconf
13 | 1. Import the include() function: from django.conf.urls import url, include
14 | 2. Add a URL to urlpatterns: url(r'^blog/', include('blog.urls'))
15 | """
16 | from django.conf import settings
17 | from django.conf.urls import include, url
18 | from django.conf.urls.static import static
19 | from django.contrib import admin
20 | from rest_framework.documentation import include_docs_urls
21 | from rest_framework.routers import DefaultRouter
22 | from rest_framework_jwt.views import refresh_jwt_token
23 | from rest_auth.registration.views import (
24 | SocialAccountListView, SocialAccountDisconnectView
25 | )
26 |
27 | from notifications_extension.views import NotificationViewSet
28 | from posts.views import PostViewSet
29 | from replies.views import ReplyViewSet
30 | from tags.views import TagViewSet
31 | from users.views import (
32 | EmailAddressViewSet,
33 | LoginViewCustom,
34 | RegisterViewCustom,
35 | ConfirmEmailView,
36 | MugshotUploadView,
37 | UserViewSets,
38 | GitHubLogin,
39 | GitHubConnect
40 | )
41 |
router = DefaultRouter()
router.register(r'posts', PostViewSet)
router.register(r'tags', TagViewSet)
router.register(r'replies', ReplyViewSet)
# Register users/email BEFORE users: the router emits URL patterns in
# registration order, and the users detail route ``^users/<lookup>/$`` would
# otherwise match ``users/email/`` first and make the email endpoints
# unreachable.
router.register(r'users/email', EmailAddressViewSet, base_name='email')
router.register(r'users', UserViewSets)
router.register(r'notifications', NotificationViewSet, base_name='notifications')
49 |
urlpatterns = [
    url(r'^admin/', admin.site.urls),
    url(r'^accounts/', include('allauth.urls')),
    # url(r'^users/mugshot/(?P[^/]+)$', MugshotUploadView.as_view()),
    url(r'^rest-auth/login/$', LoginViewCustom.as_view(), name='rest_login'),
    url(r'^rest-auth/registration/$', RegisterViewCustom.as_view(), name='rest_register'),
    # Named groups are required: an unnamed ``(?P...)`` is an invalid regex,
    # and the views read these values as the ``key`` / ``pk`` keyword arguments.
    url(r'^rest-auth/registration/account-confirm-email/(?P<key>[-:\w]+)/$',
        ConfirmEmailView.as_view(),
        name='account_confirm_email'),
    url(r'^rest-auth/github/login/$', GitHubLogin.as_view(), name='github_login'),
    url(r'^rest-auth/github/connect/$', GitHubConnect.as_view(), name='github_connect'),
    url(r'^rest-auth/socialaccounts/$',
        SocialAccountListView.as_view(),
        name='social_account_list'),
    url(r'^rest-auth/socialaccounts/(?P<pk>\d+)/disconnect/$',
        SocialAccountDisconnectView.as_view(),
        name='social_account_disconnect'),
    url(r'^rest-auth/jwt-refresh/', refresh_jwt_token),
    url(r'^rest-auth/', include('rest_auth.urls')),
    url(r'^rest-auth/registration/', include('rest_auth.registration.urls')),
    url(r'^api-auth/', include('rest_framework.urls')),  # For testing only.
    url(r'^', include(router.urls)),
    url(r'^docs/', include_docs_urls(title='Django中文社区 API'))
] + static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT)
74 |
--------------------------------------------------------------------------------
/users/serializers.py:
--------------------------------------------------------------------------------
1 | from allauth.account.adapter import get_adapter
2 | from allauth.account.models import EmailAddress
3 | from allauth.account.utils import setup_user_email
4 | from rest_auth.registration.serializers import RegisterSerializer
5 | from rest_framework import serializers
6 | from rest_framework.fields import CurrentUserDefault
7 |
8 | from .models import User
9 | from .utils import get_ip_address_from_request
10 |
11 | from .validators import FileValidator
12 |
13 |
class UserDetailsSerializer(serializers.ModelSerializer):
    """
    Detailed user serializer.

    Besides the model fields it exposes ``mugshot_url`` (URL of the avatar
    thumbnail) and the numbers of posts and replies the user has written.
    Uploaded mugshots are limited to 2 * 1024 * 1024 bytes and to the
    png/jpg/jpeg extensions.
    """
    # SerializerMethodField always reads the whole object (DRF forces
    # source='*'), so the value comes from get_mugshot_url() below.
    mugshot_url = serializers.SerializerMethodField()
    post_count = serializers.SerializerMethodField()
    reply_count = serializers.SerializerMethodField()
    mugshot = serializers.ImageField(
        validators=[FileValidator(max_size=2 * 1024 * 1024, allowed_extensions=('png', 'jpg', 'jpeg'))])

    class Meta:
        model = User
        read_only_fields = (
            'id',
            'username',
            'date_joined',
            'mugshot_url',
            'ip_joined',
            'last_login_ip',
            'is_superuser',
            'is_staff',
        )
        fields = (
            'id',
            'username',
            'nickname',
            'email',
            'mugshot',
            'date_joined',
            'mugshot_url',
            'ip_joined',
            'last_login_ip',
            'is_superuser',
            'is_staff',
            'post_count',
            'reply_count',
        )

    def get_post_count(self, obj):
        """Return the number of posts the user has submitted."""
        return obj.post_set.count()

    def get_mugshot_url(self, obj):
        """Return the URL of the user's mugshot thumbnail."""
        return obj.mugshot_thumbnail.url

    def get_reply_count(self, obj):
        """Return the number of replies the user has submitted."""
        return obj.reply_comments.count()
66 |
67 |
class UserRegistrationSerializer(RegisterSerializer):
    """
    rest_auth registration serializer that also records the sign-up IP.

    Fields and validation are inherited unchanged from ``RegisterSerializer``;
    only ``save`` is overridden.
    """

    def save(self, request):
        """
        Create and return the new user.

        Before the user is saved, ``ip_joined`` is set to the IP address taken
        from ``request`` when one can be determined.
        """
        adapter = get_adapter()
        new_user = adapter.new_user(request)
        self.cleaned_data = self.get_cleaned_data()

        client_ip = get_ip_address_from_request(request)
        if client_ip:
            new_user.ip_joined = client_ip

        adapter.save_user(request, new_user, self)
        self.custom_signup(request, new_user)
        setup_user_email(request, new_user, [])
        return new_user
87 |
88 |
class EmailAddressSerializer(serializers.ModelSerializer):
    """Serializer for allauth ``EmailAddress``; only ``email`` is writable."""

    class Meta:
        model = EmailAddress
        fields = ('id', 'user', 'email', 'verified', 'primary')
        read_only_fields = ('id', 'user', 'verified', 'primary')
100 |
101 |
class UserSimpleDetailsSerializer(serializers.ModelSerializer):
    """Minimal read-only user representation: id, username and nickname."""

    class Meta:
        model = User
        fields = ('id', 'username', 'nickname')
        # Every exposed field is read-only.
        read_only_fields = fields
115 |
--------------------------------------------------------------------------------
/templates/users/account/email.html:
--------------------------------------------------------------------------------
1 | {% extends "account/base.html" %}
2 |
3 | {% load bootstrap %}
4 | {% load i18n %}
5 |
6 | {% block head_title %}{% trans "Account" %}{% endblock %}
7 |
8 | {% block content %}
9 |
10 |
11 |
12 |
13 |
16 |
17 | {% if user.emailaddress_set.all %}
18 |
{% trans 'The following e-mail addresses are associated with your account:' %}
19 |
43 | {% else %}
44 |
{% trans 'Warning:'%} {% trans "You currently do not have any e-mail address set up. You should really add an e-mail address so you can receive notifications, reset your password, etc." %}
45 | {% endif %}
46 |
{% trans "Add E-mail Address" %}
47 |
52 |
53 |
54 |
55 |
56 |
57 | {% endblock %}
58 |
59 |
60 | {% block extra_body %}
61 |
74 | {% endblock %}
--------------------------------------------------------------------------------
/users/migrations/0001_initial.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Generated by Django 1.11 on 2018-04-24 12:35
3 | from __future__ import unicode_literals
4 |
5 | import django.contrib.auth.models
6 | import django.contrib.auth.validators
7 | from django.db import migrations, models
8 | import django.utils.timezone
9 | import users.models
10 |
11 |
class Migration(migrations.Migration):
    """
    Initial migration creating the custom ``users.User`` model.

    Auto-generated by Django 1.11 (see the file header). Applied migrations
    should not be edited; add a new migration for schema changes instead.
    No column is created for ``mugshot_thumbnail``, which is an imagekit
    ImageSpecField on the model rather than a database field.
    """

    initial = True

    dependencies = [
        ('auth', '0008_alter_user_username_max_length'),
    ]

    operations = [
        migrations.CreateModel(
            name='User',
            fields=[
                # Columns inherited from AbstractUser.
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('password', models.CharField(max_length=128, verbose_name='password')),
                ('last_login', models.DateTimeField(blank=True, null=True, verbose_name='last login')),
                ('is_superuser', models.BooleanField(default=False, help_text='Designates that this user has all permissions without explicitly assigning them.', verbose_name='superuser status')),
                ('username', models.CharField(error_messages={'unique': 'A user with that username already exists.'}, help_text='Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only.', max_length=150, unique=True, validators=[django.contrib.auth.validators.UnicodeUsernameValidator()], verbose_name='username')),
                ('first_name', models.CharField(blank=True, max_length=30, verbose_name='first name')),
                ('last_name', models.CharField(blank=True, max_length=30, verbose_name='last name')),
                ('email', models.EmailField(blank=True, max_length=254, verbose_name='email address')),
                ('is_staff', models.BooleanField(default=False, help_text='Designates whether the user can log into this admin site.', verbose_name='staff status')),
                ('is_active', models.BooleanField(default=True, help_text='Designates whether this user should be treated as active. Unselect this instead of deleting accounts.', verbose_name='active')),
                ('date_joined', models.DateTimeField(default=django.utils.timezone.now, verbose_name='date joined')),
                # Project-specific columns declared on users.models.User.
                ('last_login_ip', models.GenericIPAddressField(blank=True, null=True, unpack_ipv4=True, verbose_name='最近一次登陆IP')),
                ('ip_joined', models.GenericIPAddressField(blank=True, null=True, unpack_ipv4=True, verbose_name='注册IP')),
                ('nickname', models.CharField(max_length=50, unique=True, verbose_name='昵称')),
                # References users.models.user_mugshot_path by import path, so
                # that function must keep its name and module.
                ('mugshot', models.ImageField(upload_to=users.models.user_mugshot_path, verbose_name='头像')),
                ('groups', models.ManyToManyField(blank=True, help_text='The groups this user belongs to. A user will get all permissions granted to each of their groups.', related_name='user_set', related_query_name='user', to='auth.Group', verbose_name='groups')),
                ('user_permissions', models.ManyToManyField(blank=True, help_text='Specific permissions for this user.', related_name='user_set', related_query_name='user', to='auth.Permission', verbose_name='user permissions')),
            ],
            options={
                'verbose_name': 'user',
                'verbose_name_plural': 'users',
                'abstract': False,
            },
            managers=[
                ('objects', django.contrib.auth.models.UserManager()),
            ],
        ),
    ]
52 |
--------------------------------------------------------------------------------
/replies/tests/test_serializers.py:
--------------------------------------------------------------------------------
1 | from pprint import pprint
2 |
3 | from django.contrib.contenttypes.models import ContentType
4 | from django.contrib.sites.models import Site
5 | from django.test import RequestFactory, TestCase
6 | from django.utils.timezone import now, timedelta
7 |
8 | from posts.models import Post
9 | from users.models import User
10 |
11 | from ..models import Reply
12 | from ..serializers import TreeRepliesSerializer
13 |
14 |
class ReplySerializerTestCase(TestCase):
    """Tests for ``TreeRepliesSerializer``."""

    def setUp(self):
        self.user = User.objects.create_user(username='test',
                                             email='test@test.com',
                                             password='test',
                                             nickname='test')
        self.post = Post.objects.create(title='test title', author=self.user)
        self.site = Site.objects.create(name='test', domain='test.com')
        self.post_ct = ContentType.objects.get_for_model(self.post)
        self.post_id = self.post.id

    def _create_reply(self, comment, minutes=0, parent=None):
        """Create a reply by ``self.user`` on ``self.post``, submitted ``minutes`` from now."""
        return Reply.objects.create(
            content_type=self.post_ct,
            object_pk=self.post_id,
            site=self.site,
            user=self.user,
            comment=comment,
            parent=parent,
            submit_date=now() + timedelta(minutes=minutes),
        )

    def test_tree_reply_serializer(self):
        """Serializing the root replies yields exactly one item per root reply."""
        # Reply tree:
        # - root
        # - - child
        # - - - grand child
        # - - another_child
        # - another root
        # - - another root child
        root_reply = self._create_reply('root reply')
        child_reply = self._create_reply('child reply', 1, parent=root_reply)
        self._create_reply('another child reply', 2, parent=root_reply)
        self._create_reply('grandchild reply', 3, parent=child_reply)
        another_root_reply = self._create_reply('another root reply')
        self._create_reply('another root child reply', 1, parent=another_root_reply)

        request = RequestFactory().get('/')
        request.user = self.user
        replies = self.post.replies.filter(is_public=True, is_removed=False, parent__isnull=True)
        serializer = TreeRepliesSerializer(replies, many=True, context={'request': request})

        # Only the two root replies appear at the top level; the previous
        # version only printed the data and asserted nothing.
        self.assertEqual(len(serializer.data), 2)
98 |
--------------------------------------------------------------------------------
/users/validators.py:
--------------------------------------------------------------------------------
1 | # take from https://gist.github.com/jrosebr1/2140738
2 |
3 | # @brief
4 | # Performs file upload validation for django. The original version implemented
5 | # by dokterbob had some problems with determining the correct mimetype and
6 | # determining the size of the file uploaded (at least within my Django application
7 | # that is).
8 |
9 | # @author dokterbob
10 | # @author jrosebr1
11 |
12 | import mimetypes
13 | from os.path import splitext
14 |
15 | from django.core.exceptions import ValidationError
16 | from django.template.defaultfilters import filesizeformat
17 | from django.utils.deconstruct import deconstructible
18 | from django.utils.translation import ugettext_lazy as _
19 |
20 |
@deconstructible
class FileValidator(object):
    """
    Validator for uploaded files, checking the extension, mimetype and size.

    Initialization parameters (keyword arguments; positional arguments are
    accepted but ignored):
        allowed_extensions: iterable with allowed file extensions, written
            lower-case and without the leading dot, e.g. ('txt', 'doc').
            Empty or None disables the extension check.
        allowed_mimetypes: iterable with allowed mimetypes,
            e.g. ('image/png', ). Empty or None disables the mimetype check.
        min_size: minimum number of bytes allowed, e.g. 100 (default 0)
        max_size: maximum number of bytes allowed,
            e.g. 24*1024*1024 for 24 MB (default None, i.e. no upper limit)

    Usage example::

        class MyModel(models.Model):
            myfile = FileField(validators=[FileValidator(max_size=24*1024*1024)], ...)

    Instances compare equal when configured with the same parameters, so the
    migration autodetector does not see an unchanged validator as a change.
    """

    extension_message = _("Extension '%(extension)s' not allowed. Allowed extensions are: '%(allowed_extensions)s.'")
    mime_message = _("MIME type '%(mimetype)s' is not valid. Allowed types are: %(allowed_mimetypes)s.")
    min_size_message = _('The current file %(size)s, which is too small. The minumum file size is %(allowed_size)s.')
    max_size_message = _('The current file %(size)s, which is too large. The maximum file size is %(allowed_size)s.')

    def __init__(self, *args, **kwargs):
        self.allowed_extensions = kwargs.pop('allowed_extensions', None)
        self.allowed_mimetypes = kwargs.pop('allowed_mimetypes', None)
        self.min_size = kwargs.pop('min_size', 0)
        self.max_size = kwargs.pop('max_size', None)

    def __call__(self, value):
        """
        Check the extension, content type and file size of ``value``.

        ``value`` must expose ``name`` and support ``len()`` (presumably a
        Django ``File``/``UploadedFile``, whose ``len()`` is its size).

        Raises:
            ValidationError: for the first check that fails.
        """
        # Check the extension: compared lower-case, without the leading dot.
        ext = splitext(value.name)[1][1:].lower()
        if self.allowed_extensions and ext not in self.allowed_extensions:
            message = self.extension_message % {
                'extension': ext,
                'allowed_extensions': ', '.join(self.allowed_extensions)
            }
            raise ValidationError(message, code='invalid_extension')

        # Check the content type. The type is guessed from the file *name*
        # only; the file content is not inspected.
        mimetype = mimetypes.guess_type(value.name)[0]
        if self.allowed_mimetypes and mimetype not in self.allowed_mimetypes:
            message = self.mime_message % {
                'mimetype': mimetype,
                'allowed_mimetypes': ', '.join(self.allowed_mimetypes)
            }
            raise ValidationError(message, code='invalid_mimetype')

        # Check the file size.
        filesize = len(value)
        if self.max_size and filesize > self.max_size:
            message = self.max_size_message % {
                'size': filesizeformat(filesize),
                'allowed_size': filesizeformat(self.max_size)
            }
            raise ValidationError(message, code='max_size')

        elif filesize < self.min_size:
            message = self.min_size_message % {
                'size': filesizeformat(filesize),
                'allowed_size': filesizeformat(self.min_size)
            }
            raise ValidationError(message, code='min_size')

    def __eq__(self, other):
        # Needed together with @deconstructible when the validator is used in
        # a model field, so migrations compare validators by configuration.
        return (
            isinstance(other, self.__class__)
            and self.allowed_extensions == other.allowed_extensions
            and self.allowed_mimetypes == other.allowed_mimetypes
            and self.min_size == other.min_size
            and self.max_size == other.max_size
        )

    def __ne__(self, other):
        return not self == other
96 |
--------------------------------------------------------------------------------
/posts/serializers.py:
--------------------------------------------------------------------------------
1 | from django.contrib.contenttypes.models import ContentType
2 | from django.db.models import Prefetch
3 | from rest_framework import serializers
4 |
5 | from tags.serializers import TagSerializer
6 | from utils.mixins import EagerLoaderMixin
7 | from .models import Post, Reply
8 |
9 |
class IndexPostListSerializer(serializers.HyperlinkedModelSerializer, EagerLoaderMixin):
    """
    Serializer for the post list shown on the index (home) page.
    """
    author = serializers.SerializerMethodField()
    reply_count = serializers.SerializerMethodField()
    tags = TagSerializer(many=True, read_only=True)
    latest_reply_time = serializers.SerializerMethodField()

    # Relations to eager-load; presumably consumed by EagerLoaderMixin
    # (utils.mixins) to avoid per-post queries.
    SELECT_RELATED_FIELDS = ['author']
    PREFETCH_RELATED_FIELDS = [
        'tags',
        # Replies are prefetched newest first.
        Prefetch('replies', queryset=Reply.objects.order_by('-submit_date'))
    ]

    class Meta:
        model = Post
        fields = (
            'id',
            'url',
            'title',
            'views',
            'created',
            'modified',
            'latest_reply_time',
            'pinned',
            'highlighted',
            'tags',
            'author',
            'reply_count',
        )

    def get_author(self, obj):
        """
        Return a summary dict of the post author.

        When the serializer context holds a request, both mugshot URLs are
        made absolute. Otherwise the URLs are returned as the file fields
        report them.
        """
        author = obj.author
        request = self.context.get('request')
        url = author.mugshot.url
        thumbnail_url = author.mugshot_thumbnail.url
        return {
            'id': author.id,
            'mugshot': request.build_absolute_uri(url) if request else url,
            # Without a request, fall back to the thumbnail URL itself (the
            # previous code returned the full-size mugshot URL here).
            'mugshot_url': request.build_absolute_uri(thumbnail_url) if request else thumbnail_url,
            'nickname': author.nickname,
        }

    def get_reply_count(self, obj):
        """
        Return the number of replies to the post.
        """
        return obj.replies.count()

    def get_latest_reply_time(self, obj):
        """
        Return the submit time of the most recent reply, or None if the
        post has no replies.

        ``max()`` is used so that the result does not depend on the order of
        ``obj.replies.all()``. It is still correct when the '-submit_date'
        prefetch above has not been applied to the queryset.
        """
        submit_dates = [reply.submit_date for reply in obj.replies.all()]
        return max(submit_dates) if submit_dates else None
70 |
71 |
class PopularPostSerializer(serializers.HyperlinkedModelSerializer, EagerLoaderMixin):
    """
    Serializer for the popular (hot) posts list.
    """
    author = serializers.SerializerMethodField()

    # Relations to eager-load; presumably consumed by EagerLoaderMixin.
    SELECT_RELATED_FIELDS = ['author']

    class Meta:
        model = Post
        fields = (
            'id',
            'url',
            'title',
            'author',
        )

    def get_author(self, obj):
        """
        Return a summary dict of the post author.

        When the serializer context holds a request, both mugshot URLs are
        made absolute. Otherwise the URLs are returned as the file fields
        report them.
        """
        author = obj.author
        request = self.context.get('request')
        url = author.mugshot.url
        thumbnail_url = author.mugshot_thumbnail.url
        return {
            'id': author.id,
            'mugshot': request.build_absolute_uri(url) if request else url,
            # Without a request, fall back to the thumbnail URL itself (the
            # previous code returned the full-size mugshot URL here).
            'mugshot_url': request.build_absolute_uri(thumbnail_url) if request else thumbnail_url,
            'nickname': author.nickname,
        }
100 |
101 |
class PostDetailSerializer(IndexPostListSerializer):
    """
    Serializer for a single post's detail view; it is also the serializer
    used to create and update posts.

    ``get_author`` and ``get_reply_count`` are inherited from
    IndexPostListSerializer.
    """
    author = serializers.SerializerMethodField()
    participants_count = serializers.SerializerMethodField()
    content_type = serializers.SerializerMethodField()

    class Meta:
        model = Post
        fields = (
            'id',
            'content_type',
            'title',
            'author',
            'views',
            'created',
            'modified',
            'body',
            'tags',
            'reply_count',
            'participants_count',
        )

    def get_content_type(self, obj):
        """Return the id of the ContentType row for the post's model."""
        return ContentType.objects.get_for_model(obj).id

    def get_participants_count(self, obj):
        """Return how many distinct users have replied to the post."""
        distinct_repliers = obj.replies.values_list('user', flat=True).distinct()
        return distinct_repliers.count()
138 |
--------------------------------------------------------------------------------
/tags/tests/test_views.py:
--------------------------------------------------------------------------------
1 | from django.urls import reverse
2 | from rest_framework import status
3 | from rest_framework.test import APITestCase
4 |
5 | from tags.models import Tag
6 | from users.models import User
7 | from posts.models import Post
8 |
9 |
class TagTests(APITestCase):
    """API tests for the tag endpoints ('tag-list' and 'tag-popular')."""

    def setUp(self):
        self.user = User.objects.create_user(username='test',
                                             email='test@test.com',
                                             password='test',
                                             nickname='test')
        self.admin = User.objects.create_superuser(username='admin',
                                                   email='admin@test.com',
                                                   password='admin',
                                                   nickname='admin')

    def test_admin_can_create_tag(self):
        """
        An admin can create a tag and is recorded as its creator.
        """
        url = reverse('tag-list')
        data = {
            'name': 'test tag'
        }
        self.client.login(username='admin', password='admin')
        response = self.client.post(url, data, format='json')
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        self.assertEqual(Tag.objects.count(), 1)
        tag = Tag.objects.get()
        self.assertEqual(tag.creator, self.admin)
        self.assertEqual(tag.name, 'test tag')

    def test_unauthorized_user_cannot_create_tag(self):
        """
        A logged-in user without permission gets 403 and no tag is created.
        """
        url = reverse('tag-list')
        data = {
            'name': 'test tag'
        }
        self.client.login(username='test', password='test')
        response = self.client.post(url, data, format='json')
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
        self.assertEqual(Tag.objects.count(), 0)

    def test_anonymous_user_cannot_create_tag(self):
        """
        An anonymous user gets 401 and no tag is created.
        """
        url = reverse('tag-list')
        data = {
            'name': 'test tag'
        }
        response = self.client.post(url, data, format='json')
        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
        self.assertEqual(Tag.objects.count(), 0)

    def test_no_duplicate_tags(self):
        """
        Creating a tag whose name already exists is rejected with 400.
        """
        Tag.objects.create(name='test tag', creator=self.admin)
        url = reverse('tag-list')
        data = {
            'name': 'test tag'
        }
        self.client.login(username='admin', password='admin')
        response = self.client.post(url, data, format='json')
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
        self.assertEqual(Tag.objects.count(), 1)

    def test_popular_tags(self):
        """
        Popular tags are ordered by how many posts use them.

        tag2 is on 3 posts, tag1 on 2 and tag3 on 1.
        """
        post1 = Post.objects.create(title='this is a test1',
                                    body='this is a test',
                                    author=self.admin)
        post2 = Post.objects.create(title='this is a test2',
                                    body='this is a test',
                                    author=self.admin)
        post3 = Post.objects.create(title='this is a test3',
                                    body='this is a test',
                                    author=self.admin)
        tag1 = Tag.objects.create(name='test tag1', creator=self.admin)
        tag2 = Tag.objects.create(name='test tag2', creator=self.admin)
        tag3 = Tag.objects.create(name='test tag3', creator=self.admin)
        post1.tags.add(tag1, tag2)
        post2.tags.add(tag2, tag3, tag1)
        post3.tags.add(tag2)
        url = reverse('tag-popular')
        response = self.client.get(url, format='json')
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.data[0]['name'], 'test tag2')
        self.assertEqual(response.data[1]['name'], 'test tag1')
        self.assertEqual(response.data[2]['name'], 'test tag3')
104 |
--------------------------------------------------------------------------------
/db_tools/fake_db_fast.py:
--------------------------------------------------------------------------------
"""
Seed the development database with a small amount of fake data.

Users, tags, posts, e-mail addresses and balance records are created through
the model factories. Replies and likes are created by POSTing to the API with
a test client, so that the request-side logic (e.g. generating the matching
notifications) runs as well.

This is the lighter variant of db_tools/fake_db.py; only the volume knobs
below differ. Run it directly: ``python db_tools/fake_db_fast.py``.
"""
import os
import sys
import random

BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Make the project root importable so the app packages below resolve.
sys.path.append(BASE_DIR)
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings.local")

import django

# Django must be configured before any model/factory module is imported.
django.setup()

from users.factories import UserFactory
from users.models import User
from tags.factories import TagFactory
from tags.models import Tag
from replies.factories import PostReplyFactory, SiteFactory
from replies.models import Reply
from posts.factories import PostFactory
from posts.models import Post
from balance.factories import RecordFactory

from allauth.account.models import EmailAddress
from rest_framework.test import APIClient
from django.urls import reverse
from django.contrib.contenttypes.models import ContentType
from faker import Faker

# Volume knobs. Tuples are inclusive (low, high) ranges for random.randint.
USER_COUNT = 5                  # users created in addition to 'admin'
TAG_COUNT = 5                   # tags, all created by 'admin'
POSTS_PER_USER = (5, 10)
TAGS_PER_POST = (1, 3)
RECORDS_PER_USER = (0, 3)       # balance records ("wealth") per user
REPLIERS_PER_POST = (0, 5)
REPLY_ROUNDS = 2                # rounds of replying to existing replies
REPLIERS_PER_REPLY = (1, 3)
LIKE_ROUNDS = 2
LIKERS_PER_REPLY = (1, 3)

# Site id sent with every reply; matches SITE_ID in config/settings/common.py.
REPLY_SITE_ID = 1


def _sample(population, count_range):
    """Return randint(*count_range) distinct items, capped at len(population)."""
    return random.sample(population, min(random.randint(*count_range), len(population)))


def _should_skip():
    """Return True with probability 2/3, so roughly 1/3 of items are processed."""
    return random.randint(0, 2) != 0


def _post_as(client, user, url, data):
    """Log ``client`` in as ``user`` and POST ``data`` to ``url`` as JSON.

    A failing request is reported on stdout rather than raised, so that one
    rejected request does not abort the whole seeding run.
    """
    client.force_login(user)
    response = client.post(url, data, format='json')
    if response.status_code >= 400:
        print('POST {} as user {} failed: HTTP {}'.format(
            url, user.pk, response.status_code))
    return response


def _create_users_and_tags():
    """Create the users, an 'admin' user and the admin's tags; return (users, tags)."""
    UserFactory.create_batch(USER_COUNT)
    print('users created...')

    admin = UserFactory(username='admin')
    TagFactory.create_batch(TAG_COUNT, creator=admin)
    print('tags created...')

    return list(User.objects.all()), list(Tag.objects.all())


def _create_posts(users, tags):
    """Give every user tagged posts, a verified primary e-mail and balance records."""
    for user in users:
        for _ in range(random.randint(*POSTS_PER_USER)):
            PostFactory(author=user, tags=_sample(tags, TAGS_PER_POST))

        # Bind a verified primary e-mail address.
        EmailAddress.objects.create(
            user=user,
            email=user.email,
            verified=True,
            primary=True
        )
        # Grant some balance records.
        RecordFactory.create_batch(random.randint(*RECORDS_PER_USER), user=user)
    print('posts posted...')


def _reply_to_posts(client, fake, users):
    """Have a random sample of users reply to every post via the API."""
    url = reverse('reply-list')
    for post in list(Post.objects.all()):
        post_ct = ContentType.objects.get_for_model(post)
        for user in _sample(users, REPLIERS_PER_POST):
            data = {
                "content_type": post_ct.id,
                "object_pk": post.id,
                "site": REPLY_SITE_ID,
                "comment": fake.text(max_nb_chars=200, ext_word_list=None)
            }
            _post_as(client, user, url, data)
    print('post replies created...')


def _reply_to_replies(client, fake, users):
    """Over several rounds, reply to about a third of the existing replies."""
    url = reverse('reply-list')
    for _ in range(REPLY_ROUNDS):
        # Re-read each round so replies made in earlier rounds can be answered.
        for reply in list(Reply.objects.all()):
            if _should_skip():
                continue
            post = reply.content_object
            post_ct = ContentType.objects.get_for_model(post)
            for user in _sample(users, REPLIERS_PER_REPLY):
                data = {
                    "content_type": post_ct.id,
                    "object_pk": post.id,
                    "site": REPLY_SITE_ID,
                    "comment": fake.text(max_nb_chars=200, ext_word_list=None),
                    "parent": reply.id
                }
                _post_as(client, user, url, data)
        print('reply replies created...')


def _like_replies(client, users):
    """Over several rounds, have users like about a third of the replies."""
    # Read once up front: unlike the reply rounds, the list is not re-read.
    replies = list(Reply.objects.all())
    for _ in range(LIKE_ROUNDS):
        for reply in replies:
            if _should_skip():
                continue
            url = reverse('reply-like', kwargs={'pk': reply.id})
            for user in _sample(users, LIKERS_PER_REPLY):
                data = {
                    "content_type": reply.ctype_id,
                    "object_id": reply.id,
                    "flag": "like",
                }
                _post_as(client, user, url, data)
        print('liked!')


def main():
    """Seed the database phase by phase."""
    SiteFactory()
    client = APIClient()
    fake = Faker()

    users, tags = _create_users_and_tags()
    _create_posts(users, tags)
    _reply_to_posts(client, fake, users)
    _reply_to_replies(client, fake, users)
    _like_replies(client, users)
    print('done...!')


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/db_tools/fake_db.py:
--------------------------------------------------------------------------------
"""
Seed the development database with a larger set of fake data.

Users, tags, posts, e-mail addresses and balance records are created through
the model factories. Replies and likes are created by POSTing to the API with
a test client, so that the request-side logic (e.g. generating the matching
notifications) runs as well.

db_tools/fake_db_fast.py is a lighter variant; only the volume knobs below
differ. Run it directly: ``python db_tools/fake_db.py``.
"""
import os
import sys
import random

BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Make the project root importable so the app packages below resolve.
sys.path.append(BASE_DIR)
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings.local")

import django

# Django must be configured before any model/factory module is imported.
django.setup()

from users.factories import UserFactory
from users.models import User
from tags.factories import TagFactory
from tags.models import Tag
from replies.factories import PostReplyFactory, SiteFactory
from replies.models import Reply
from posts.factories import PostFactory
from posts.models import Post
from balance.factories import RecordFactory

from allauth.account.models import EmailAddress
from rest_framework.test import APIClient
from django.urls import reverse
from django.contrib.contenttypes.models import ContentType
from faker import Faker

# Volume knobs. Tuples are inclusive (low, high) ranges for random.randint.
USER_COUNT = 10                 # users created in addition to 'admin'
TAG_COUNT = 10                  # tags, all created by 'admin'
POSTS_PER_USER = (10, 15)
TAGS_PER_POST = (1, 3)
RECORDS_PER_USER = (0, 10)      # balance records ("wealth") per user
REPLIERS_PER_POST = (0, 10)
REPLY_ROUNDS = 3                # rounds of replying to existing replies
REPLIERS_PER_REPLY = (1, 3)
LIKE_ROUNDS = 5
LIKERS_PER_REPLY = (1, 3)

# Site id sent with every reply; matches SITE_ID in config/settings/common.py.
REPLY_SITE_ID = 1


def _sample(population, count_range):
    """Return randint(*count_range) distinct items, capped at len(population)."""
    return random.sample(population, min(random.randint(*count_range), len(population)))


def _should_skip():
    """Return True with probability 2/3, so roughly 1/3 of items are processed."""
    return random.randint(0, 2) != 0


def _post_as(client, user, url, data):
    """Log ``client`` in as ``user`` and POST ``data`` to ``url`` as JSON.

    A failing request is reported on stdout rather than raised, so that one
    rejected request does not abort the whole seeding run.
    """
    client.force_login(user)
    response = client.post(url, data, format='json')
    if response.status_code >= 400:
        print('POST {} as user {} failed: HTTP {}'.format(
            url, user.pk, response.status_code))
    return response


def _create_users_and_tags():
    """Create the users, an 'admin' user and the admin's tags; return (users, tags)."""
    UserFactory.create_batch(USER_COUNT)
    print('users created...')

    admin = UserFactory(username='admin')
    TagFactory.create_batch(TAG_COUNT, creator=admin)
    print('tags created...')

    return list(User.objects.all()), list(Tag.objects.all())


def _create_posts(users, tags):
    """Give every user tagged posts, a verified primary e-mail and balance records."""
    for user in users:
        for _ in range(random.randint(*POSTS_PER_USER)):
            PostFactory(author=user, tags=_sample(tags, TAGS_PER_POST))

        # Bind a verified primary e-mail address.
        EmailAddress.objects.create(
            user=user,
            email=user.email,
            verified=True,
            primary=True
        )
        # Grant some balance records.
        RecordFactory.create_batch(random.randint(*RECORDS_PER_USER), user=user)
    print('posts posted...')


def _reply_to_posts(client, fake, users):
    """Have a random sample of users reply to every post via the API."""
    url = reverse('reply-list')
    for post in list(Post.objects.all()):
        post_ct = ContentType.objects.get_for_model(post)
        for user in _sample(users, REPLIERS_PER_POST):
            data = {
                "content_type": post_ct.id,
                "object_pk": post.id,
                "site": REPLY_SITE_ID,
                "comment": fake.text(max_nb_chars=200, ext_word_list=None)
            }
            _post_as(client, user, url, data)
    print('post replies created...')


def _reply_to_replies(client, fake, users):
    """Over several rounds, reply to about a third of the existing replies."""
    url = reverse('reply-list')
    for _ in range(REPLY_ROUNDS):
        # Re-read each round so replies made in earlier rounds can be answered.
        for reply in list(Reply.objects.all()):
            if _should_skip():
                continue
            post = reply.content_object
            post_ct = ContentType.objects.get_for_model(post)
            for user in _sample(users, REPLIERS_PER_REPLY):
                data = {
                    "content_type": post_ct.id,
                    "object_pk": post.id,
                    "site": REPLY_SITE_ID,
                    "comment": fake.text(max_nb_chars=200, ext_word_list=None),
                    "parent": reply.id
                }
                _post_as(client, user, url, data)
        print('reply replies created...')


def _like_replies(client, users):
    """Over several rounds, have users like about a third of the replies."""
    # Read once up front: unlike the reply rounds, the list is not re-read.
    replies = list(Reply.objects.all())
    for _ in range(LIKE_ROUNDS):
        for reply in replies:
            if _should_skip():
                continue
            url = reverse('reply-like', kwargs={'pk': reply.id})
            for user in _sample(users, LIKERS_PER_REPLY):
                data = {
                    "content_type": reply.ctype_id,
                    "object_id": reply.id,
                    "flag": "like",
                }
                _post_as(client, user, url, data)
        print('liked!')


def main():
    """Seed the database phase by phase."""
    SiteFactory()
    client = APIClient()
    fake = Faker()

    users, tags = _create_users_and_tags()
    _create_posts(users, tags)
    _reply_to_posts(client, fake, users)
    _reply_to_replies(client, fake, users)
    _like_replies(client, users)
    print('done...!')


if __name__ == '__main__':
    main()
133 |
--------------------------------------------------------------------------------
/replies/tests/test_moderation.py:
--------------------------------------------------------------------------------
1 | from django.contrib.contenttypes.models import ContentType
2 | from django.contrib.sites.models import Site
3 | from django.test import TestCase
4 | from notifications.models import Notification
5 |
6 | from posts.models import Post
7 | from users.models import User
8 |
9 | from ..models import Reply
10 | from ..moderation import ReplyModerator
11 |
# A comment moderator is constructed with the model whose comments it
# moderates (here the Post model). ReplyModerator is presumably a
# django_comments CommentModerator subclass, whose __init__ takes that model.
# The previous code passed the moderator class itself by mistake.
reply_moderator = ReplyModerator(Post)
13 |
14 |
class ModerationTestCase(TestCase):
    """Tests for who receives notifications from ReplyModerator.notify()."""

    def setUp(self):
        # self.user authors self.post; self.another_user does not.
        self.user = User.objects.create_user(
            username='test',
            email='test@test.com',
            password='test',
            nickname='test'
        )
        self.another_user = User.objects.create_user(
            username='another',
            email='another@test.com',
            password='another',
            nickname='another'
        )
        self.post = Post.objects.create(
            title='test title',
            author=self.user
        )
        self.site = Site.objects.create(name='test', domain='test.com')
        self.post_ct = ContentType.objects.get_for_model(self.post)
        self.post_id = self.post.id

    def test_post_author_received_reply_notification(self):
        """When someone other than the author replies to a post, the author is notified."""
        reply = Reply.objects.create(
            content_type=self.post_ct,
            object_pk=self.post_id,
            site=self.site,
            user=self.another_user,  # replier is not the post author
            comment='reply',
        )

        reply_moderator.notify(reply=reply, content_object=self.post, request=None)
        self.assertEqual(Notification.objects.count(), 1)
        self.assertEqual(Notification.objects.get().recipient, self.post.author)

    def test_reply_self_do_not_received_notification(self):
        """Replying to your own post sends no notification."""
        reply = Reply.objects.create(
            content_type=self.post_ct,
            object_pk=self.post_id,
            site=self.site,
            user=self.user,  # replier is the post author
            comment='reply',
        )

        reply_moderator.notify(reply=reply, content_object=self.post, request=None)
        self.assertEqual(Notification.objects.count(), 0)

    def test_reply_others_reply_as_well_as_others_post(self):
        """
        Replying to someone else's reply on someone else's post notifies
        both the parent reply's author and the post author.
        """
        user = User.objects.create_user(
            username='user',
            email='user@test.com',
            password='user',
            nickname='user'
        )

        reply = Reply.objects.create(
            content_type=self.post_ct,
            object_pk=self.post_id,
            site=self.site,
            user=self.another_user,
            comment='reply',
        )

        new_reply = Reply.objects.create(
            content_type=self.post_ct,
            object_pk=self.post_id,
            site=self.site,
            user=user,  # neither the parent reply's author nor the post author
            comment='new reply',
            parent=reply,
        )

        reply_moderator.notify(reply=new_reply, content_object=self.post, request=None)
        self.assertEqual(Notification.objects.count(), 2)
        recipients = {n.recipient for n in Notification.objects.all()}
        self.assertEqual(recipients, {reply.user, self.post.author})

    def test_reply_others_reply_but_self_post(self):
        """
        Replying to someone else's reply on your own post notifies only the
        parent reply's author.
        """
        reply = Reply.objects.create(
            content_type=self.post_ct,
            object_pk=self.post_id,
            site=self.site,
            user=self.another_user,
            comment='reply',
        )

        new_reply = Reply.objects.create(
            content_type=self.post_ct,
            object_pk=self.post_id,
            site=self.site,
            user=self.user,  # replier is the post author
            comment='new reply',
            parent=reply,
        )

        reply_moderator.notify(reply=new_reply, content_object=self.post, request=None)
        self.assertEqual(Notification.objects.count(), 1)
        self.assertEqual(
            Notification.objects.get().recipient,
            reply.user
        )

    def test_reply_self_reply_as_well_as_self_post(self):
        """Replying to your own reply on your own post sends no notification."""
        reply = Reply.objects.create(
            content_type=self.post_ct,
            object_pk=self.post_id,
            site=self.site,
            user=self.user,  # replier is the post author
            comment='reply',
        )

        new_reply = Reply.objects.create(
            content_type=self.post_ct,
            object_pk=self.post_id,
            site=self.site,
            user=self.user,  # replier is the post author
            comment='new reply',
            parent=reply,
        )

        reply_moderator.notify(reply=new_reply, content_object=self.post, request=None)
        self.assertEqual(Notification.objects.count(), 0)

    def test_reply_self_reply_but_others_post(self):
        """
        Replying to your own reply on someone else's post notifies only the
        post author.
        """
        reply = Reply.objects.create(
            content_type=self.post_ct,
            object_pk=self.post_id,
            site=self.site,
            user=self.another_user,  # replier is NOT the post author
            comment='reply',
        )

        new_reply = Reply.objects.create(
            content_type=self.post_ct,
            object_pk=self.post_id,
            site=self.site,
            user=self.another_user,  # same replier, replying to their own reply
            comment='new reply',
            parent=reply,
        )

        reply_moderator.notify(reply=new_reply, content_object=self.post, request=None)
        self.assertEqual(Notification.objects.count(), 1)
        self.assertEqual(
            Notification.objects.get().recipient,
            self.post.author
        )
164 |
--------------------------------------------------------------------------------
/config/settings/common.py:
--------------------------------------------------------------------------------
1 | """
2 | Django settings for DjangoChina project.
3 | Generated by 'django-admin startproject' using Django 1.11.
4 | For more information on this file, see
5 | https://docs.djangoproject.com/en/1.11/topics/settings/
6 | For the full list of settings and their values, see
7 | https://docs.djangoproject.com/en/1.11/ref/settings/
8 | """
9 |
10 | import datetime
11 | import os
12 |
13 | # Build paths inside the project like this: os.path.join(BASE_DIR, ...)
14 | BASE_DIR = os.path.dirname(os.path.dirname(
15 | os.path.dirname(os.path.abspath(__file__))))
16 |
17 | # Quick-start development settings - unsuitable for production
18 | # See https://docs.djangoproject.com/en/1.11/howto/deployment/checklist/
19 |
20 | # SECURITY WARNING: don't run with debug turned on in production!
21 | DEBUG = True
22 |
23 | ALLOWED_HOSTS = []
24 |
25 | SITE_ID = 1
26 |
27 | LOGIN_URL = '/accounts/login/'
28 | LOGIN_REDIRECT_URL = '/'
29 |
30 | # Application definition
31 |
32 | INSTALLED_APPS = [
33 | 'django.contrib.admin',
34 | 'django.contrib.auth',
35 | 'django.contrib.contenttypes',
36 | 'django.contrib.sessions',
37 | 'django.contrib.messages',
38 | 'django.contrib.staticfiles',
39 | 'django.contrib.sites',
40 |
41 | # third-party apps
42 | 'bootstrapform',
43 | 'notifications',
44 | 'rest_framework',
45 | 'rest_framework.authtoken',
46 | 'rest_auth',
47 | 'allauth',
48 | 'allauth.account',
49 | 'allauth.socialaccount',
50 | 'allauth.socialaccount.providers.github',
51 | 'rest_auth.registration',
52 | 'django_comments',
53 | 'actstream',
54 | 'django_filters',
55 | 'corsheaders',
56 | 'raven.contrib.django.raven_compat', # sentry support
57 |
58 | # local apps
59 | 'users',
60 | 'posts',
61 | 'replies',
62 | 'tags',
63 | 'balance',
64 | 'notifications_extension',
65 | ]
66 |
67 | COMMENTS_APP = 'replies'
68 |
69 | MIDDLEWARE = [
70 | 'corsheaders.middleware.CorsMiddleware',
71 | 'django.middleware.security.SecurityMiddleware',
72 | 'django.contrib.sessions.middleware.SessionMiddleware',
73 | 'django.middleware.common.CommonMiddleware',
74 | # 'django.middleware.csrf.CsrfViewMiddleware',
75 | 'django.contrib.auth.middleware.AuthenticationMiddleware',
76 | 'users.disable_csrf_middleware.DisableCSRFCheck',
77 | 'users.jwt_middleware.JWTMiddleware',
78 | 'django.contrib.messages.middleware.MessageMiddleware',
79 | 'django.middleware.clickjacking.XFrameOptionsMiddleware',
80 | ]
81 |
82 | CORS_ORIGIN_ALLOW_ALL = True # 跨域
83 |
84 | ROOT_URLCONF = 'config.urls'
85 |
86 | TEMPLATES = [
87 | {
88 | 'BACKEND': 'django.template.backends.django.DjangoTemplates',
89 | 'DIRS': [
90 | os.path.join(BASE_DIR, 'templates'),
91 | os.path.join(BASE_DIR, 'templates', 'users')
92 | ],
93 | 'APP_DIRS': True,
94 | 'OPTIONS': {
95 | 'context_processors': [
96 | 'django.template.context_processors.debug',
97 | 'django.template.context_processors.request',
98 | 'django.contrib.auth.context_processors.auth',
99 | 'django.contrib.messages.context_processors.messages',
100 | ],
101 | },
102 | },
103 | ]
104 |
105 | WSGI_APPLICATION = 'config.wsgi.application'
106 |
107 | # Password validation
108 | # https://docs.djangoproject.com/en/1.11/ref/settings/#auth-password-validators
109 |
110 | AUTH_PASSWORD_VALIDATORS = [
111 | {
112 | 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
113 | },
114 | {
115 | 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
116 | },
117 | {
118 | 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
119 | },
120 | {
121 | 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
122 | },
123 | ]
124 |
125 | # Internationalization
126 | # https://docs.djangoproject.com/en/1.11/topics/i18n/
127 |
128 | LANGUAGE_CODE = 'zh-hans'
129 |
130 | TIME_ZONE = 'Asia/Shanghai'
131 |
132 | USE_I18N = True
133 |
134 | USE_L10N = True
135 |
136 | USE_TZ = True
137 |
138 | # Static files (CSS, JavaScript, Images)
139 | # https://docs.djangoproject.com/en/1.11/howto/static-files/
140 |
141 | STATIC_URL = '/static/'
142 | STATIC_ROOT = os.path.join(BASE_DIR, 'static')
143 |
144 | MEDIA_ROOT = os.path.join(BASE_DIR, 'media')
145 | MEDIA_URL = '/media/'
146 |
147 | AUTH_USER_MODEL = 'users.User'
148 |
149 | # django-rest-framework settings
150 | REST_FRAMEWORK = {
151 | 'DEFAULT_PERMISSION_CLASSES': (
152 | 'rest_framework.permissions.IsAuthenticated',
153 | ),
154 | 'DEFAULT_AUTHENTICATION_CLASSES': (
155 | 'rest_framework_jwt.authentication.JSONWebTokenAuthentication',
156 | 'rest_framework.authentication.SessionAuthentication',
157 | 'rest_framework.authentication.BasicAuthentication'
158 | ),
159 | 'DEFAULT_PAGINATION_CLASS': 'utils.rest_tools.CustomPageNumberPagination',
160 | }
161 |
162 | # django-allauth settings
163 | AUTHENTICATION_BACKENDS = (
164 | # Needed to login by username in Django admin, regardless of `allauth`
165 | 'django.contrib.auth.backends.ModelBackend',
166 |
167 | # `allauth` specific authentication methods, such as login by e-mail
168 | 'allauth.account.auth_backends.AuthenticationBackend',
169 | )
170 |
171 | # django-rest-auth settings
172 | REST_AUTH_SERIALIZERS = {
173 | 'USER_DETAILS_SERIALIZER': 'users.serializers.UserDetailsSerializer',
174 | }
175 |
176 | REST_AUTH_REGISTER_SERIALIZERS = {
177 | 'REGISTER_SERIALIZER': 'users.serializers.UserRegistrationSerializer',
178 | }
179 |
180 | # djangorestframework-jwt settings
181 | JWT_AUTH = {
182 | 'JWT_EXPIRATION_DELTA': datetime.timedelta(seconds=60 * 30),
183 | 'JWT_AUTH_HEADER_PREFIX': 'Bearer',
184 | 'JWT_ALLOW_REFRESH': True,
185 | }
186 |
187 | if DEBUG:
188 | JWT_AUTH['JWT_EXPIRATION_DELTA'] = datetime.timedelta(days=1)
189 |
190 | # django-all-auth setting
191 | REST_USE_JWT = True
192 | ACCOUNT_AUTHENTICATION_METHOD = 'username_email'
193 | ACCOUNT_EMAIL_REQUIRED = True
194 | ACCOUNT_EMAIL_VERIFICATION = True
195 | LOGIN_ON_EMAIL_CONFIRMATION = True
196 | SOCIALACCOUNT_EMAIL_VERIFICATION = False
197 | OLD_PASSWORD_FIELD_ENABLED = True
198 |
199 | # 软删除
200 | NOTIFICATIONS_SOFT_DELETE = True
201 |
202 | SOCIAL_LOGIN_GITHUB_CALLBACK_URL = 'dummy'
203 |
--------------------------------------------------------------------------------
/users/mugshot.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/richardasaurus/randomavatar
3 | """
4 |
5 | import random
6 | import math
7 | import hashlib
8 | from io import BytesIO
9 | from PIL import Image, ImageDraw
10 |
11 |
12 | class Avatar(object):
13 | def __init__(self, rows, columns):
14 | self.rows = rows
15 | self.cols = columns
16 | self._generate_colours()
17 |
18 | m = hashlib.md5()
19 | m.update(b"hello world")
20 | entropy = len(m.hexdigest()) / 2 * 8
21 | if self.rows > 15 or self.cols > 15:
22 | raise ValueError("Rows and columns must be valued 15 or under")
23 |
24 | self.digest = hashlib.md5
25 | self.digest_entropy = entropy
26 |
27 | def _generate_colours(self):
28 | colours_ok = False
29 |
30 | while colours_ok is False:
31 | self.fg_colour = self._get_pastel_colour()
32 | self.bg_colour = self._get_pastel_colour(lighten=80)
33 |
34 | # Get the luminance for each colour
35 | fg_lum = self._luminance(self.fg_colour) + 0.05
36 | bg_lum = self._luminance(self.bg_colour) + 0.05
37 |
38 | # Check the difference in luminance
39 | # meets the 1.25 threshold
40 | result = (fg_lum / bg_lum) \
41 | if (fg_lum / bg_lum) else (bg_lum / fg_lum)
42 | if result > 1.20:
43 | colours_ok = True
44 |
45 | def get_image(self, string, width, height, pad=0):
46 | """
47 | Byte representation of a PNG image
48 | """
49 | hex_digest_byte_list = self._string_to_byte_list(string)
50 | matrix = self._create_matrix(hex_digest_byte_list)
51 | return self._create_image(matrix, width, height, pad)
52 |
53 | def save(self, image_byte_array=None, save_location=None):
54 | if image_byte_array and save_location:
55 | with open(save_location, 'wb') as f:
56 | return f.write(image_byte_array)
57 | else:
58 | raise ValueError('image_byte_array and path must be provided')
59 |
60 | def _get_pastel_colour(self, lighten=127):
61 | """
62 | Create a pastel colour hex colour string
63 | """
64 | def r():
65 | return random.randint(0, 128) + 127
66 | return r(), r(), r() # return rgb values as a tuple
67 |
68 | def _luminance(self, rgb):
69 | """
70 | Determine the liminanace of an RGB colour
71 | """
72 | a = []
73 | for v in rgb:
74 | v = v / float(255)
75 | if v < 0.03928:
76 | result = v / 12.92
77 | else:
78 | result = math.pow(((v + 0.055) / 1.055), 2.4)
79 |
80 | a.append(result)
81 | return a[0] * 0.2126 + a[1] * 0.7152 + a[2] * 0.0722
82 |
83 | def _string_to_byte_list(self, data):
84 | """
85 | Creates a hex digest of the input string given to create the image,
86 | if it's not already hexadecimal
87 | Returns:
88 | Length 16 list of rgb value range integers
89 | (each representing a byte of the hex digest)
90 | """
91 | bytes_length = 16
92 |
93 | m = self.digest()
94 | m.update(str.encode(data))
95 | hex_digest = m.hexdigest()
96 |
97 | return list(int(hex_digest[num * 2:num * 2 + 2], bytes_length)
98 | for num in range(bytes_length))
99 |
100 | def _bit_is_one(self, n, hash_bytes):
101 | """
102 | Check if the n (index) of hash_bytes is 1 or 0.
103 | """
104 |
105 | scale = 16 # hexadecimal
106 |
107 | if not hash_bytes[int(n / (scale / 2))] >> int(
108 | (scale / 2) - ((n % (scale / 2)) + 1)) & 1 == 1:
109 | return False
110 | return True
111 |
def _create_image(self, matrix, width, height, pad):
    """
    Render a boolean block matrix as a PNG and return the encoded bytes.

    Args:
        matrix: rows of booleans; True cells are drawn in ``self.fg_colour``.
        width, height: size of the drawing area, excluding padding.
        pad: border (in pixels) added on every side in ``self.bg_colour``.

    Returns:
        The PNG-encoded image as ``bytes``.
    """
    canvas_size = (width + 2 * pad, height + 2 * pad)
    image = Image.new("RGB", canvas_size, self.bg_colour)
    draw = ImageDraw.Draw(image)

    # Size of one grid block; may be fractional.
    cell_width = width / self.cols
    cell_height = height / self.rows

    for row_index, row_cells in enumerate(matrix):
        top = pad + row_index * cell_height
        # The "- 1" presumably keeps neighbouring blocks from overlapping.
        bottom = pad + (row_index + 1) * cell_height - 1
        for col_index, filled in enumerate(row_cells):
            if not filled:
                continue
            left = pad + col_index * cell_width
            right = pad + (col_index + 1) * cell_width - 1
            draw.rectangle((left, top, right, bottom), fill=self.fg_colour)

    buffer = BytesIO()
    image.save(buffer, format="png", optimize=True)
    return buffer.getvalue()
140 |
def _create_matrix(self, byte_list):
    """
    Decide which blocks of the identicon grid get the foreground colour.

    Only the left half of the grid is driven by hash bits. That half includes
    the middle column when ``self.cols`` is odd. Every cell set there is
    mirrored onto the right half, so the image is horizontally symmetric.

    Args:
        byte_list: hash bytes as ints (0..255). ``byte_list[0]`` is skipped;
            per the original comment it is reserved for the foreground colour.
            The remaining bytes supply one bit per generated cell.

    Returns:
        A ``self.rows`` x ``self.cols`` list of lists of bools. True marks a
        foreground cell and False a background cell, e.g. for a 4x4 grid::

            [[True, True, True, True],
             [False, True, True, False],
             [True, True, True, True],
             [False, False, False, False]]
    """
    # Columns generated from bits: ceil(cols / 2). This keeps the middle column
    # of odd-width grids, so every one of its rows is covered.
    half_cols = (self.cols + 1) // 2
    cells = self.rows * half_cols

    matrix = [[False] * self.cols for _ in range(self.rows)]
    hash_bits = byte_list[1:]

    for cell_number in range(cells):
        if self._bit_is_one(cell_number, hash_bits):
            # Cells are numbered column-major: walk down a column, then
            # move right. Dividing by rows (not cols) keeps non-square grids correct.
            row = cell_number % self.rows
            col = cell_number // self.rows
            # Set the cell and its mirror image on the opposite side.
            matrix[row][self.cols - col - 1] = True
            matrix[row][col] = True
    return matrix
174 |
--------------------------------------------------------------------------------
/replies/serializers.py:
--------------------------------------------------------------------------------
1 | from actstream.models import Follow
2 | from django.contrib.contenttypes.models import ContentType
3 | from django.contrib.sites.models import Site
4 | from rest_framework import serializers
5 |
6 | from posts.models import Post
7 | from replies.models import Reply
8 |
9 |
class FlatReplySerializer(serializers.ModelSerializer):
    """
    Flat (non-nested) representation of replies, ignoring their hierarchy.

    Intended for the per-user reply list on the user detail page. Any
    ordering (e.g. newest first) comes from the queryset that is passed in.
    """
    post = serializers.SerializerMethodField()
    user = serializers.SerializerMethodField()
    parent_user = serializers.SerializerMethodField()
    is_liked = serializers.SerializerMethodField()

    class Meta:
        model = Reply
        fields = (
            'id',
            'user',
            'parent_user',
            'post',
            'submit_date',
            'comment',
            'like_count',
            'is_liked',
        )

    def _user_payload(self, user):
        """
        Public summary of ``user`` shown next to a reply.

        Avatar URLs are made absolute when a request is in the context.
        ``mugshot`` is the full image and ``mugshot_url`` the thumbnail, in
        both cases (previously the no-request branch returned the full image
        for ``mugshot_url``).
        """
        request = self.context.get('request')
        url = user.mugshot.url
        thumbnail_url = user.mugshot_thumbnail.url
        if request:
            url = request.build_absolute_uri(url)
            thumbnail_url = request.build_absolute_uri(thumbnail_url)
        return {
            'id': user.id,
            'mugshot': url,
            'mugshot_url': thumbnail_url,
            'nickname': user.nickname,
        }

    def get_post(self, obj):
        post = obj.content_object
        if post is None:
            # A GenericForeignKey returns None when its target row is gone.
            return None
        return {
            'id': post.id,
            'title': post.title,
        }

    def get_user(self, obj):
        return self._user_payload(obj.user)

    def get_parent_user(self, obj):
        parent = obj.parent
        if not parent:
            return None
        return self._user_payload(parent.user)

    def get_is_liked(self, obj):
        request = self.context.get('request')
        if request is None:
            # Without a request there is no current user who could have liked it.
            return False
        return Follow.objects.is_following(request.user, obj, flag='like')
70 |
71 |
class ReplyCreationSerializer(serializers.ModelSerializer):
    """
    Used only for creating replies.

    The client supplies ``object_pk`` (the post id), ``comment`` and an
    optional ``parent``. ``content_type`` and ``site`` are filled in on the
    server in ``create()``.
    """
    parent_user = serializers.SerializerMethodField()
    user = serializers.SerializerMethodField()

    class Meta:
        model = Reply
        fields = (
            'id',
            'object_pk',
            'comment',
            'parent',
            'submit_date',
            'ip_address',
            'is_public',
            'is_removed',
            'user',
            'parent_user',
        )
        read_only_fields = (
            'id',
            'submit_date',
            'ip_address',
            'is_public',
            'is_removed',
        )

    def _user_payload(self, user):
        """
        Public summary of ``user``, with absolute avatar URLs when a request
        is in the context. ``mugshot_url`` is always the thumbnail.
        """
        request = self.context.get('request')
        url = user.mugshot.url
        thumbnail_url = user.mugshot_thumbnail.url
        if request:
            url = request.build_absolute_uri(url)
            thumbnail_url = request.build_absolute_uri(thumbnail_url)
        return {
            'id': user.id,
            'mugshot': url,
            'mugshot_url': thumbnail_url,
            'nickname': user.nickname,
        }

    def get_parent_user(self, obj):
        parent = obj.parent
        if not parent:
            return None
        return self._user_payload(parent.user)

    def get_user(self, obj):
        return self._user_payload(obj.user)

    def validate_object_pk(self, value):
        """
        Ensure ``object_pk`` identifies an existing post.

        An unknown or non-numeric id becomes a 400 validation error instead of
        an unhandled exception inside ``create()``.
        """
        try:
            post_id = int(value)
        except (TypeError, ValueError):
            raise serializers.ValidationError('帖子不存在')
        if not Post.objects.filter(id=post_id).exists():
            raise serializers.ValidationError('帖子不存在')
        return value

    def create(self, validated_data):
        # The post was validated above; only its content type is needed here.
        validated_data['content_type'] = ContentType.objects.get_for_model(Post)
        validated_data['site'] = Site.objects.get_current()
        return super(ReplyCreationSerializer, self).create(validated_data)
137 |
138 |
class TreeRepliesSerializer(serializers.ModelSerializer):
    """
    Two-level reply tree.

    The first level is a root reply. The second level is all of that reply's
    descendants, flattened with FlatReplySerializer. Suited to the reply list
    on the post detail page.
    """
    descendants = serializers.SerializerMethodField()
    user = serializers.SerializerMethodField()
    is_liked = serializers.SerializerMethodField()

    class Meta:
        model = Reply
        fields = (
            'id',
            'content_type',
            'object_pk',
            'comment',
            'submit_date',
            'like_count',
            'user',
            'descendants',
            'descendants_count',
            'is_liked'
        )

    def _user_payload(self, user):
        """
        Public summary of ``user``, with absolute avatar URLs when a request
        is in the context. ``mugshot_url`` is always the thumbnail.
        """
        request = self.context.get('request')
        url = user.mugshot.url
        thumbnail_url = user.mugshot_thumbnail.url
        if request:
            url = request.build_absolute_uri(url)
            thumbnail_url = request.build_absolute_uri(thumbnail_url)
        return {
            'id': user.id,
            'mugshot': url,
            'mugshot_url': thumbnail_url,
            'nickname': user.nickname,
        }

    def get_is_liked(self, obj):
        request = self.context.get('request')
        if request is None:
            # Without a request there is no current user who could have liked it.
            return False
        return Follow.objects.is_following(request.user, obj, flag='like')

    def get_user(self, obj):
        return self._user_payload(obj.user)

    def get_descendants(self, obj):
        # Share this serializer's whole context (request, view, ...) with the
        # nested serializer rather than rebuilding a request-only one.
        return FlatReplySerializer(
            obj.descendants(), many=True, context=self.context
        ).data
183 |
184 |
class FollowSerializer(serializers.ModelSerializer):
    """
    Serializer for recording a "like" on a reply.

    Every field listed in ``fields`` is also in ``read_only_fields``, so no
    value is taken from request data. Anything written must be supplied by
    the caller through ``serializer.save(...)``.
    """

    class Meta:
        model = Follow
        fields = (
            'id',
            'user',
            'content_type',
            'object_id',
            'flag',
            'started',
        )
        read_only_fields = (
            'id',
            'user',
            'content_type',
            'object_id',
            'flag',
            'started',
        )

    def create(self, validated_data):
        # Resolve the liked reply's content type from the reply id.
        # NOTE(review): 'object_pk' is not a field of this serializer (Meta
        # lists 'object_id'), so it is only present if the caller passes it to
        # save(). Otherwise int(None) raises TypeError. If it *is* passed, it
        # is also forwarded unchanged to ModelSerializer.create(), and Follow
        # appears to use object_id rather than object_pk. Presumably the like
        # view does not go through this method; confirm before relying on it.
        reply_id = validated_data.get('object_pk')
        reply_ctype = ContentType.objects.get_for_model(
            Reply.objects.get(id=int(reply_id))
        )
        validated_data['content_type'] = reply_ctype
        return super(FollowSerializer, self).create(validated_data)
216 |
--------------------------------------------------------------------------------
/posts/views.py:
--------------------------------------------------------------------------------
1 | import datetime
2 |
3 | from django.db.models import Count, Max
4 | from django.db.models.functions import Coalesce
5 | from django.utils.timezone import now
6 | from django_filters import rest_framework as filters
7 | from rest_framework import permissions, serializers, status, viewsets
8 | from rest_framework.decorators import action
9 | from rest_framework.response import Response
10 |
11 | from replies.serializers import TreeRepliesSerializer
12 | from tags.models import Tag
13 | from .models import Post
14 | from .permissions import IsAdminAuthorOrReadOnly
15 | from .serializers import (
16 | IndexPostListSerializer,
17 | PopularPostSerializer,
18 | PostDetailSerializer,
19 | )
20 |
21 |
class PostViewSet(viewsets.ModelViewSet):
    """
    Post endpoints: list / retrieve / create / update plus the ``popular``
    and ``replies`` actions.

    The list uses IndexPostListSerializer. Retrieve, create and update use
    PostDetailSerializer. DELETE is not exposed (see ``http_method_names``).
    """
    queryset = IndexPostListSerializer.setup_eager_loading(
        Post.public.annotate(
            # Time of the newest reply, or the post's creation time when it
            # has no replies yet.
            latest_post_time=Coalesce(Max('replies__submit_date'), 'created')
        ).order_by('-pinned', '-latest_post_time'),
        select_related=IndexPostListSerializer.SELECT_RELATED_FIELDS,
        prefetch_related=IndexPostListSerializer.PREFETCH_RELATED_FIELDS
    )
    serializer_class = IndexPostListSerializer
    permission_classes = (permissions.IsAuthenticatedOrReadOnly,
                          IsAdminAuthorOrReadOnly)
    # Allow GET, POST, PUT and PATCH only.
    http_method_names = ['get', 'post', 'put', 'patch']
    filter_backends = (filters.DjangoFilterBackend,)
    # The post list can be filtered by tag to show the posts under one tag.
    filter_fields = ('tags',)

    @staticmethod
    def _check_tag_count(tags_data):
        """Raise a ValidationError unless between one and three tags were given."""
        if not tags_data:
            raise serializers.ValidationError(detail={'标签': '请选择至少一个标签'})
        if len(tags_data) > 3:
            raise serializers.ValidationError(detail={'标签': '最多可以选择三个标签'})

    @staticmethod
    def _resolve_tags(tags_data):
        """Map tag names to Tag objects, raising a ValidationError for unknown names."""
        tags = []
        for name in tags_data:
            try:
                tags.append(Tag.objects.get(name=name))
            except (Tag.DoesNotExist, Tag.MultipleObjectsReturned):
                raise serializers.ValidationError(detail={'标签': '标签不存在'})
        return tags

    def retrieve(self, request, *args, **kwargs):
        """
        Post detail, rendered with PostDetailSerializer rather than the default
        IndexPostListSerializer. The view counter is incremented first.
        """
        instance = self.get_object()
        instance.increase_views()
        # Reload, presumably because increase_views() updates the row in the
        # database rather than the in-memory instance (verify in models.py).
        instance.refresh_from_db()
        serializer = PostDetailSerializer(instance, context={'request': request})
        return Response(serializer.data)

    def create(self, request, *args, **kwargs):
        """Create a post, validating and rendering it with PostDetailSerializer."""
        serializer = PostDetailSerializer(data=request.data, context={'request': request})
        serializer.is_valid(raise_exception=True)
        self.perform_create(serializer)
        headers = self.get_success_headers(serializer.data)
        return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)

    def perform_create(self, serializer):
        """
        Save the post with its author and tags.

        ``tags`` and ``author`` are read-only on the serializer, so they are
        supplied here. One to three existing tag names are required.
        """
        tags_data = self.request.data.get('tags')
        self._check_tag_count(tags_data)
        tags = self._resolve_tags(tags_data)
        serializer.save(author=self.request.user, tags=tags)

    def update(self, request, *args, **kwargs):
        """
        Update a post (PUT and PATCH).

        A PATCH may omit ``tags`` entirely. Otherwise one to three tags are
        required.
        """
        partial = kwargs.pop('partial', False)
        tags_data = request.data.get('tags')
        if not (partial and tags_data is None):
            self._check_tag_count(tags_data)
        instance = self.get_object()
        serializer = PostDetailSerializer(
            instance,
            data=request.data,
            partial=partial,
            context={'request': request})
        serializer.is_valid(raise_exception=True)
        self.perform_update(serializer)

        if getattr(instance, '_prefetched_objects_cache', None):
            # If 'prefetch_related' has been applied to a queryset, we need to
            # forcibly invalidate the prefetch cache on the instance.
            instance._prefetched_objects_cache = {}

        return Response(serializer.data)

    def perform_update(self, serializer):
        """
        Save the update, including the read-only tags / hidden / pinned /
        highlighted fields, which must be passed to save() by hand.

        Only staff users may change hidden, pinned and highlighted.
        """
        data = {}
        tags_data = self.request.data.get('tags')
        if self.request.user.is_staff:
            for field in ('hidden', 'pinned', 'highlighted'):
                value = self.request.data.get(field)
                if value is not None:
                    data[field] = value
        if tags_data:
            data['tags'] = self._resolve_tags(tags_data)
        serializer.save(**data)

    @action(detail=False, serializer_class=PopularPostSerializer)
    def popular(self, request):
        """
        Return up to 10 posts with the most replies among posts whose latest
        reply falls within the last 48 hours.
        """
        # Use a single "now" so both ends of the window agree.
        current_time = now()
        popular_posts = PopularPostSerializer.setup_eager_loading(
            Post.public.annotate(
                num_replies=Count('replies'),
                latest_reply_time=Max('replies__submit_date')
            ).filter(
                num_replies__gt=0,
                latest_reply_time__gt=(current_time - datetime.timedelta(days=2)),
                latest_reply_time__lt=current_time
            ).order_by('-num_replies', '-latest_reply_time')[:10],
            select_related=PopularPostSerializer.SELECT_RELATED_FIELDS
        )

        # Return a paginated response when a paginator is configured.
        if self.paginator is not None:
            self.paginator.page_size = 10
        page = self.paginate_queryset(popular_posts)
        if page is not None:
            serializer = self.get_serializer(page, many=True)
            return self.get_paginated_response(serializer.data)

        serializer = self.get_serializer(popular_posts, many=True)
        return Response(serializer.data)

    @action(methods=['get'], detail=True, serializer_class=TreeRepliesSerializer)
    def replies(self, request, pk=None):
        """
        Public, non-removed top-level replies of this post, each serialized
        together with its descendants.
        """
        post = self.get_object()
        replies = post.replies.filter(is_public=True, is_removed=False, parent__isnull=True)
        # get_serializer() already injects the full serializer context
        # (including the request), so no explicit context is passed.
        page = self.paginate_queryset(replies)
        if page is not None:
            serializer = self.get_serializer(page, many=True)
            return self.get_paginated_response(serializer.data)

        serializer = self.get_serializer(replies, many=True)
        return Response(serializer.data)
175 |
--------------------------------------------------------------------------------
/replies/tests/test_views.py:
--------------------------------------------------------------------------------
1 | from actstream.models import Follow
2 | from django.contrib.contenttypes.models import ContentType
3 | from django.contrib.sites.models import Site
4 | from django.urls import reverse
5 | from notifications.models import Notification
6 | from rest_framework import status
7 | from rest_framework import test
8 | from notifications.models import Notification
9 | from actstream.models import Follow
10 |
11 | from posts.models import Post
12 | from users.models import User
13 |
14 | from ..models import Reply
15 |
16 |
class ReplyViewSetsTestCase(test.APITestCase):
    """
    API tests for replies: creating (child) replies, which HTTP methods each
    endpoint accepts, and liking / unliking a reply.
    """

    def setUp(self):
        # ``test`` is the acting user; ``another`` owns both the post and the reply.
        self.user = User.objects.create_user(
            username='test', email='test@test.com', password='test', nickname='test')
        self.another_user = User.objects.create_user(
            username='another', email='another@test.com', password='another',
            nickname='another')
        self.post = Post.objects.create(title='test title', author=self.another_user)
        self.post_ct = ContentType.objects.get_for_model(self.post)
        self.post_id = self.post.id

        self.site = Site.objects.create(name='test', domain='test.com')
        self.reply = Reply.objects.create(
            content_type=self.post_ct,
            object_pk=self.post_id,
            site=self.site,
            user=self.another_user,
            comment='reply',
        )

    # -- helpers (no ``test`` prefix, so not collected as tests) -----------

    def _login_as_test_user(self):
        self.client.login(username='test', password='test')

    def _reply_payload(self, **extra):
        payload = {"object_pk": self.post_id, "comment": "test comment"}
        payload.update(extra)
        return payload

    def _like_url(self):
        return reverse('reply-like', kwargs={'pk': self.reply.id})

    def _like_payload(self):
        return {"object_id": self.reply.id, "flag": "like"}

    def _assert_latest_reply_by_test_user(self):
        latest = Reply.objects.last()
        self.assertEqual(latest.user, self.user)
        self.assertEqual(latest.comment, 'test comment')
        return latest

    # -- creating replies ---------------------------------------------------

    def test_anonymous_user_can_not_create_reply(self):
        payload = {
            "content_type": self.post_ct.id,
            "object_pk": self.post_id,
            "site": 1,
            "comment": "test comment",
        }
        response = self.client.post(reverse('reply-list'), payload, format='json')
        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

    def test_authenticated_user_can_create_reply(self):
        self._login_as_test_user()
        response = self.client.post(reverse('reply-list'), self._reply_payload(), format='json')
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        self.assertEqual(Reply.objects.count(), 2)
        self._assert_latest_reply_by_test_user()

        # Creating the reply must also have produced a notification.
        self.assertEqual(Notification.objects.count(), 1)

    def test_can_create_child_reply(self):
        self._login_as_test_user()
        response = self.client.post(
            reverse('reply-list'), self._reply_payload(parent=self.reply.id), format='json')
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        self.assertEqual(Reply.objects.count(), 2)
        latest = self._assert_latest_reply_by_test_user()
        self.assertEqual(latest.parent, self.reply)

        # Creating the reply must also have produced a notification.
        self.assertEqual(Notification.objects.count(), 1)

    def test_reply_only_support_post_method(self):
        url = reverse('reply-list')
        payload = self._reply_payload()
        self._login_as_test_user()

        attempts = (
            lambda: self.client.get(url, format='json'),
            lambda: self.client.put(url, payload, format='json'),
            lambda: self.client.patch(url, payload, format='json'),
            lambda: self.client.delete(url, format='json'),
        )
        for attempt in attempts:
            self.assertEqual(attempt().status_code, status.HTTP_405_METHOD_NOT_ALLOWED)

    # -- liking replies -----------------------------------------------------

    def test_anonymous_user_can_not_like_reply(self):
        response = self.client.post(self._like_url(), self._like_payload(), format='json')
        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

    def test_user_can_not_like_self_reply(self):
        # ``another`` wrote self.reply, so liking it is forbidden.
        self.client.login(username='another', password='another')
        response = self.client.post(self._like_url(), self._like_payload(), format='json')
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

    def test_user_can_like_others_reply(self):
        self._login_as_test_user()
        response = self.client.post(self._like_url(), self._like_payload(), format='json')
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        self.assertEqual(Follow.objects.count(), 1)
        like = Follow.objects.get()
        self.assertEqual(like.user, self.user)
        self.assertEqual(like.follow_object, self.reply)
        self.assertEqual(like.flag, 'like')

        # The reply's author must have been notified.
        self.assertEqual(Notification.objects.count(), 1)
        self.assertEqual(Notification.objects.get().recipient, self.reply.user)

    def test_user_can_not_like_same_reply_twice(self):
        url, payload = self._like_url(), self._like_payload()
        self._login_as_test_user()
        first = self.client.post(url, payload, format='json')
        self.assertEqual(first.status_code, status.HTTP_201_CREATED)

        self.assertEqual(Follow.objects.count(), 1)

        second = self.client.post(url, payload, format='json')
        self.assertEqual(second.status_code, status.HTTP_400_BAD_REQUEST)
        # NOTE: follow-up checks are disabled. They would assert that the
        # Follow count is still 1 and that exactly one notification went to
        # self.reply.user. After the rejected duplicate, further queries raise
        # "An error occurred in the current transaction. You can't execute
        # queries until the end of the 'atomic' block." This presumably means
        # the duplicate fails at the database level inside the test
        # transaction. A likely fix is wrapping the like view's insert in
        # transaction.atomic(), or using APITransactionTestCase. TODO confirm.

    def test_user_can_delete_self_like(self):
        url = self._like_url()
        self._login_as_test_user()
        created = self.client.post(url, self._like_payload(), format='json')
        self.assertEqual(created.status_code, status.HTTP_201_CREATED)
        self.assertEqual(Follow.objects.count(), 1)

        deleted = self.client.delete(url)
        self.assertEqual(deleted.status_code, status.HTTP_204_NO_CONTENT)
        self.assertEqual(Follow.objects.count(), 0)

    def test_like_only_support_post_delete_methods(self):
        url, payload = self._like_url(), self._like_payload()
        self._login_as_test_user()

        attempts = (
            lambda: self.client.get(url, format='json'),
            lambda: self.client.put(url, payload, format='json'),
            lambda: self.client.patch(url, payload, format='json'),
        )
        for attempt in attempts:
            self.assertEqual(attempt().status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
203 |
--------------------------------------------------------------------------------
/users/views.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 |
4 | from allauth.account.views import ConfirmEmailView as AllAuthConfirmEmailView
5 | from allauth.socialaccount.providers.github.views import GitHubOAuth2Adapter
6 | from allauth.socialaccount.providers.oauth2.client import OAuth2Client
7 | from django.conf import settings
8 | from django.contrib.sites.shortcuts import get_current_site
9 | from django.db.models import Sum
10 | from django.utils import timezone
11 | from rest_auth.registration.views import (
12 | LoginView, RegisterView, SocialConnectView, SocialLoginView)
13 | from rest_framework import mixins, permissions, status, views, viewsets
14 | from rest_framework.decorators import action
15 | from rest_framework.parsers import FileUploadParser
16 | from rest_framework.permissions import AllowAny
17 | from rest_framework.response import Response
18 |
19 | from balance.models import Record
20 | from balance.permissions import IsCurrentUser, OncePerDay
21 | from balance.serializers import BalanceSerializer
22 | from posts.serializers import IndexPostListSerializer
23 | from replies.serializers import FlatReplySerializer
24 |
25 | from .models import User
26 | from .permissions import IsVerified, NotPrimary
27 | from .serializers import EmailAddressSerializer, UserDetailsSerializer
28 |
29 |
class RegisterViewCustom(RegisterView):
    """
    Registration endpoint with request authentication switched off.

    With no authentication classes configured, DRF does not run session
    authentication here, so no CSRF check is applied to this view.
    """
    authentication_classes = tuple()
35 |
36 |
class LoginViewCustom(LoginView):
    """
    Login endpoint with request authentication switched off.

    With no authentication classes configured, DRF does not run session
    authentication here, so no CSRF check is applied to this view.
    """
    authentication_classes = tuple()
42 |
43 |
class ConfirmEmailView(AllAuthConfirmEmailView):
    """allauth's e-mail confirmation view, rendered with the project's own template."""
    template_name = 'account/email_confirm.html'
46 |
47 |
class GitHubLogin(SocialLoginView):
    """
    Log in (or sign up) through GitHub OAuth2.

    Authentication classes are disabled, as for the other login views.
    """
    authentication_classes = tuple()
    adapter_class = GitHubOAuth2Adapter
    client_class = OAuth2Client
    # Read when the module is imported; a missing setting raises AttributeError.
    callback_url = settings.SOCIAL_LOGIN_GITHUB_CALLBACK_URL
52 | callback_url = getattr(settings, 'SOCIAL_LOGIN_GITHUB_CALLBACK_URL')
53 |
54 |
class GitHubConnect(SocialConnectView):
    """Attach a GitHub account to the currently logged-in user via OAuth2."""
    adapter_class = GitHubOAuth2Adapter
    client_class = OAuth2Client
    # Read when the module is imported; a missing setting raises AttributeError.
    callback_url = settings.SOCIAL_LOGIN_GITHUB_CALLBACK_URL
59 |
60 |
class UserViewSets(
    mixins.RetrieveModelMixin,
    mixins.UpdateModelMixin,
    viewsets.GenericViewSet
):
    """
    User profile endpoints: retrieve / update, plus the user's replies, posts,
    daily check-in, coin balance and "checked in today" status.
    """
    queryset = User.objects.all()
    # TODO: private details such as the user's email need special handling.
    permission_classes = [AllowAny, ]
    serializer_class = UserDetailsSerializer
    lookup_value_regex = '[0-9]+'

    def get_permissions(self):
        # Only the user themselves may update their profile.
        if self.action in ['update', 'partial_update']:
            return [permissions.IsAuthenticated(), IsCurrentUser()]
        return super().get_permissions()

    def _paginated_response(self, queryset):
        """Serialize ``queryset``, paginating it when a paginator is configured."""
        # get_serializer() already injects the full serializer context
        # (including the request), so no explicit context is passed.
        page = self.paginate_queryset(queryset)
        if page is not None:
            serializer = self.get_serializer(page, many=True)
            return self.get_paginated_response(serializer.data)
        serializer = self.get_serializer(queryset, many=True)
        return Response(serializer.data)

    @action(methods=['get'], detail=True, serializer_class=FlatReplySerializer)
    def replies(self, request, pk=None):
        """The user's public, non-removed replies."""
        user = self.get_object()
        return self._paginated_response(
            user.reply_comments.filter(is_public=True, is_removed=False)
        )

    @action(methods=['get'], detail=True, serializer_class=IndexPostListSerializer)
    def posts(self, request, pk=None):
        """
        The user's posts.

        ``?hidden=<truthy>`` lists the user's hidden posts instead. Only the
        user themselves may see them (401 when anonymous, 403 for anyone
        else). Values such as ``0``/``false``/``no`` count as not hidden.
        """
        user = self.get_object()
        posts = user.post_set.all()
        hidden_param = request.query_params.get('hidden')
        want_hidden = bool(hidden_param) and hidden_param.lower() not in ('0', 'false', 'no')
        if want_hidden:
            if not request.user.is_authenticated:
                return Response(status=status.HTTP_401_UNAUTHORIZED)
            # Only the owner may view their hidden posts.
            if user != request.user:
                return Response(status=status.HTTP_403_FORBIDDEN)
        return self._paginated_response(posts.filter(hidden=want_hidden))

    @action(
        methods=['post'],
        detail=True,
        permission_classes=[permissions.IsAuthenticated, OncePerDay, IsCurrentUser],
    )
    def checkin(self, request, pk=None):
        """Daily check-in: award a random coin amount and return the new record."""
        user = self.get_object()

        # Random reward: |N(10, 5)| rounded up, and never less than 1.
        random_amount = max(1, math.ceil(abs(random.gauss(10, 5))))

        record = Record.objects.create(
            reward_type=0,
            coin_type=2,
            amount=random_amount,
            description='',
            user=user
        )
        serializer = BalanceSerializer(record)
        return Response(serializer.data, status=status.HTTP_201_CREATED)

    @action(methods=['get'], detail=True)
    def balance(self, request, pk=None):
        """Total ``amount`` of the user's balance records, grouped by coin_type."""
        user = self.get_object()
        user_treasure = user.record_set.values('coin_type').annotate(Sum('amount'))
        return Response(user_treasure)

    @action(methods=['get'], detail=True,
            permission_classes=[permissions.IsAuthenticated, IsCurrentUser], )
    def checked(self, request, pk=None):
        """Whether the user has any balance record created today."""
        from datetime import timedelta

        user = self.get_object()
        # Half-open [start of today, start of tomorrow) so no instant of the
        # day (including sub-second times) is missed.
        today_start = timezone.now().replace(hour=0, minute=0, second=0, microsecond=0)
        tomorrow_start = today_start + timedelta(days=1)
        checked = user.record_set.filter(
            created_time__gte=today_start, created_time__lt=tomorrow_start
        ).exists()
        return Response({'checked': checked})
162 |
163 |
class MugshotUploadView(views.APIView):
    """Upload a new avatar (mugshot) for the authenticated user."""
    permission_classes = [permissions.IsAuthenticated]
    parser_classes = (FileUploadParser,)

    def post(self, request, filename):
        """
        Store the uploaded file as the user's mugshot under ``filename``.

        Responds 400 when no file was sent or it exceeds 2048 KB. Otherwise
        responds 200 with the new ``mugshot_url``.
        """
        if 'file' not in request.FILES:
            return Response({
                'file': 'No avatar file selected.'
            }, status=status.HTTP_400_BAD_REQUEST)
        file_obj = request.FILES['file']

        limit_kb = 2048
        if file_obj.size > limit_kb * 1024:
            return Response({
                'file': 'File size is too large.'
            }, status=status.HTTP_400_BAD_REQUEST)

        user = request.user
        # FieldFile.save() stores the file and, with its default save=True,
        # also saves the model instance, so no separate user.save() is needed.
        user.mugshot.save(filename, file_obj)
        user.refresh_from_db()
        return Response({'mugshot_url': user.mugshot.url}, status=status.HTTP_200_OK)
186 |
187 |
class EmailAddressViewSet(mixins.ListModelMixin,
                          mixins.RetrieveModelMixin,
                          mixins.CreateModelMixin,
                          mixins.DestroyModelMixin,
                          viewsets.GenericViewSet):
    """
    Manage the current user's e-mail addresses.

    Supports list, retrieve, add and delete, plus promoting an address to
    primary and resending its confirmation mail.
    """
    serializer_class = EmailAddressSerializer
    permission_classes = [permissions.IsAuthenticated]

    def get_permissions(self):
        # Deletion is additionally guarded by NotPrimary.
        if self.action != 'destroy':
            return super().get_permissions()
        return [permissions.IsAuthenticated(), NotPrimary()]

    def get_queryset(self):
        # Only the requesting user's own addresses are visible.
        return self.request.user.emailaddress_set.all()

    def perform_create(self, serializer):
        # Save the new address for the current user, then mail a confirmation.
        new_address = serializer.save(user=self.request.user)
        new_address.send_confirmation(request=self.request)

    @action(methods=['post'], detail=True,
            permission_classes=[permissions.IsAuthenticated, IsVerified])
    def set_primary(self, request, pk=None):
        """Make this (verified) address primary: 201 on success, 400 otherwise."""
        promoted = self.get_object().set_as_primary()
        return Response(
            status=status.HTTP_201_CREATED if promoted else status.HTTP_400_BAD_REQUEST
        )

    @action(methods=['get'], detail=True,
            permission_classes=[permissions.IsAuthenticated])
    def reverify(self, request, pk=None):
        """Resend the confirmation e-mail for this address."""
        self.get_object().send_confirmation(request=self.request)
        return Response(status=status.HTTP_200_OK)
227 |
--------------------------------------------------------------------------------
/notifications_extension/tests/test_views.py:
--------------------------------------------------------------------------------
1 | from actstream.models import Follow
2 | from django.contrib.contenttypes.models import ContentType
3 | from django.contrib.sites.models import Site
4 | from django.urls import reverse
5 | from rest_framework import status
6 | from rest_framework import test
7 | from notifications.models import Notification
8 | from actstream.models import Follow
9 |
10 | from posts.models import Post
11 | from users.models import User
12 |
13 | from replies.models import Reply
14 | from replies.moderation import ReplyModerator
15 |
# Module-level moderator instance; the tests below call its ``notify`` method
# directly (with request=None) to create notifications without submitting a reply.
# NOTE(review): the constructor receives the ReplyModerator class itself. If
# ReplyModerator follows django_comments' CommentModerator(model) signature,
# this should be the moderated model (presumably Post). It is likely harmless
# as long as only ``notify`` is used; confirm against replies/moderation.py.
reply_moderator = ReplyModerator(ReplyModerator)
17 |
18 |
19 | class NotificationViewSetsTestCase(test.APITestCase):
def setUp(self):
    """
    Two users, one post each, and a reply from each user on the other's post.

    The test user therefore has one incoming reply (on self.post) and one
    outgoing reply (self.another_reply).
    """
    self.user = User.objects.create_user(
        username='test', email='test@test.com', password='test', nickname='test')
    self.another_user = User.objects.create_user(
        username='another', email='another@test.com', password='another',
        nickname='another')

    # A post owned by each user.
    self.post = Post.objects.create(title='test title', author=self.user)
    self.post_ct = ContentType.objects.get_for_model(self.post)
    self.post_id = self.post.id

    self.another_post = Post.objects.create(title='another title', author=self.another_user)
    self.another_post_ct = ContentType.objects.get_for_model(self.another_post)
    self.another_post_id = self.another_post.id

    self.site = Site.objects.create(name='test', domain='test.com')

    def reply_on(content_type, object_pk, author):
        return Reply.objects.create(
            content_type=content_type,
            object_pk=object_pk,
            site=self.site,
            user=author,
            comment='reply',
        )

    # another_user replies to the test user's post ...
    self.reply = reply_on(self.post_ct, self.post_id, self.another_user)
    # ... and the test user replies to another_user's post.
    self.another_reply = reply_on(self.another_post_ct, self.another_post_id, self.user)
66 |
67 | def test_anonymous_user_can_not_get_notifications_method(self):
68 | # 未登录用户无法获取通知列表
69 |
70 | url = reverse('notifications-list')
71 | response = self.client.get(url, format='json')
72 | self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
73 |
74 | def test_authenticated_user_can_get_all_notifications_method(self):
75 | # 已登录用户可以获取到自己的通知列表
76 |
77 | url = reverse('notifications-list')
78 | data = {
79 | "unread": "all"
80 | }
81 | self.client.login(username='test', password='test')
82 | response = self.client.get(url, data, format='json')
83 | self.assertEqual(response.status_code, status.HTTP_200_OK)
84 |
85 | def test_user_get_notifications_not_others_method(self):
86 | # 已登录用户可以获取到自己的通知列表
87 | # 确认用户访问API获取的就是自己的通知列表,而不是别人的
88 |
89 | reply_moderator.notify(reply=self.reply, content_object=self.post, request=None)
90 | url = reverse('notifications-list')
91 | self.client.login(username='test', password='test')
92 | response = self.client.get(url, format='json')
93 | self.assertEqual(response.status_code, status.HTTP_200_OK)
94 | self.assertEqual(response.data['data'][0]['recipient'], self.user.id)
95 |
96 | def test_authenticated_user_can_get_single_notification_method(self):
97 | # 用户可以获取自己单条通知
98 |
99 | reply_moderator.notify(reply=self.reply, content_object=self.post, request=None)
100 | # reply = Reply.objects.first()
101 | notification = Notification.objects.first()
102 | url = reverse('notifications-detail', kwargs={'pk': notification.id})
103 | self.client.login(username='test', password='test')
104 | response = self.client.get(url, format='json')
105 | self.assertEqual(response.status_code, status.HTTP_200_OK)
106 |
107 | def test_authenticated_user_can_not_get_other_users_single_notification_method(self):
108 | # 用户无法通过API获取属于其它用户的单条通知
109 |
110 | reply_moderator.notify(
111 | reply=self.another_reply,
112 | content_object=self.another_post,
113 | request=None
114 | )
115 | notification = Notification.objects.first()
116 | url = reverse('notifications-detail', kwargs={'pk': notification.id})
117 | self.client.login(username='test', password='test')
118 | response = self.client.get(url, format='json')
119 | self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
120 |
121 | def test_authenticated_user_can_modify_single_notification_to_read_method(self):
122 | # 用户可以将自己的单条通知标为已读
123 |
124 | reply_moderator.notify(reply=self.reply, content_object=self.post, request=None)
125 | notification = Notification.objects.first()
126 | self.assertEqual(notification.unread, True)
127 | url = reverse('notifications-detail', kwargs={'pk': notification.id})
128 | self.client.login(username='test', password='test')
129 | response = self.client.put(url)
130 | self.assertEqual(response.status_code, status.HTTP_200_OK)
131 | notification = Notification.objects.first()
132 | self.assertEqual(notification.unread, False)
133 |
134 | def test_authenticated_user_can_not_modify_other_users_single_notification_to_read_method(self):
135 | # 用户无法通过 API 将其它用户的通知标为已读
136 |
137 | reply_moderator.notify(
138 | reply=self.another_reply,
139 | content_object=self.another_post,
140 | request=None
141 | )
142 | notification = Notification.objects.first()
143 | self.assertEqual(notification.unread, True)
144 | url = reverse('notifications-detail', kwargs={'pk': notification.id})
145 | self.client.login(username='test', password='test')
146 | response = self.client.put(url)
147 | self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
148 | notification = Notification.objects.first()
149 | self.assertEqual(notification.unread, True)
150 |
151 | def test_authenticated_user_can_delete_single_notification_method(self):
152 | # 用户可以将自己的单条通知删除
153 |
154 | reply_moderator.notify(reply=self.reply, content_object=self.post, request=None)
155 | notification = Notification.objects.first()
156 | self.assertEqual(notification.deleted, False)
157 | url = reverse('notifications-detail', kwargs={'pk': notification.id})
158 | self.client.login(username='test', password='test')
159 | response = self.client.delete(url)
160 | self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
161 | notification = Notification.objects.first()
162 | self.assertEqual(notification.deleted, True)
163 |
164 | def test_authenticated_user_can_delete_other_user_single_notification_method(self):
165 | # 用户无法通过 API 将其它用户的单条通知删除
166 |
167 | reply_moderator.notify(
168 | reply=self.another_reply,
169 | content_object=self.another_post,
170 | request=None
171 | )
172 | notification = Notification.objects.first()
173 | self.assertEqual(notification.deleted, False)
174 | url = reverse('notifications-detail', kwargs={'pk': notification.id})
175 | self.client.login(username='test', password='test')
176 | response = self.client.delete(url)
177 | self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
178 | notification = Notification.objects.first()
179 | self.assertEqual(notification.deleted, False)
180 |
181 | def test_authenticated_user_can_make_all_notification_to_read_method(self):
182 | # 用户可以将自己的全部通知标为已读
183 |
184 | reply_moderator.notify(reply=self.reply, content_object=self.post, request=None)
185 | notification = Notification.objects.first()
186 | self.assertEqual(notification.unread, True) # 验证未读
187 | url = reverse('notifications-mark-all-as-read')
188 | self.client.login(username='test', password='test')
189 | response = self.client.post(url)
190 | self.assertEqual(response.status_code, status.HTTP_200_OK)
191 | notification = Notification.objects.first()
192 | self.assertEqual(notification.unread, False) # 验证已读
193 |
194 | def test_authenticated_user_can_not_make_other_users_all_notification_to_read_method(self):
195 | # 用户可以将自己的全部通知标为已读
196 |
197 | reply_moderator.notify(
198 | reply=self.another_reply,
199 | content_object=self.another_post,
200 | request=None
201 | )
202 | notification = Notification.objects.first()
203 | self.assertEqual(notification.unread, True) # 其他用户的通知 验证未读
204 | url = reverse('notifications-mark-all-as-read')
205 | self.client.login(username='test', password='test')
206 | response = self.client.post(url)
207 | self.assertEqual(response.status_code, status.HTTP_200_OK)
208 | notification = Notification.objects.first()
209 | self.assertEqual(notification.unread, True) # 其他用户的通知 验证未读
210 |
--------------------------------------------------------------------------------
/posts/tests/test_views.py:
--------------------------------------------------------------------------------
1 | from django.urls import reverse
2 |
3 | from rest_framework import status
4 | from rest_framework.test import APITestCase
5 |
6 | from posts.models import Post
7 | from tags.models import Tag
8 | from users.models import User
9 |
10 |
class PostTestCase(APITestCase):
    """API tests for posts: creation, tag limits, edit permissions, lists, detail."""

    def setUp(self):
        self.user = User.objects.create_user(username='test',
                                             email='test@test.com',
                                             password='test',
                                             nickname='test')
        self.another_user = User.objects.create_user(username='test2',
                                                     email='test2@test.com',
                                                     password='test2',
                                                     nickname='test2')
        self.admin = User.objects.create_superuser(username='admin',
                                                   email='admin@admin.com',
                                                   password='admin123',
                                                   nickname='admin')
        self.tag1 = Tag.objects.create(name='test tag1', creator=self.user)
        self.tag2 = Tag.objects.create(name='test tag2', creator=self.user)
        self.tag3 = Tag.objects.create(name='test tag3', creator=self.user)
        self.tag4 = Tag.objects.create(name='test tag4', creator=self.user)

    def _create_post(self):
        """Create a post authored by ``self.user`` and tagged with ``self.tag1``."""
        post = Post.objects.create(title='this is a test',
                                   body='this is a test',
                                   author=self.user)
        post.tags.add(self.tag1)
        return post

    def test_authenticated_user_can_create_post(self):
        """A logged-in user can create a post; an incomplete payload is rejected."""
        url = reverse('post-list')
        data = {
            "title": "test title",
            "body": "test test test",
            "tags": ['test tag1', 'test tag2', 'test tag3']
        }
        incomplete_data = {
            "title": "test title",
            "body": "test test test",
        }
        self.client.login(username='test', password='test')
        response = self.client.post(url, data, format='json')
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        self.assertEqual(Post.objects.count(), 1)
        post = Post.objects.get()
        self.assertEqual(post.author, self.user)
        self.assertEqual(post.body, 'test test test')
        self.assertEqual(post.title, 'test title')
        self.assertEqual(post.tags.count(), 3)
        self.assertFalse(post.pinned)
        self.assertFalse(post.highlighted)
        self.assertFalse(post.hidden)

        # The payload is incomplete (no tags) and must be rejected.
        response = self.client.post(url, incomplete_data, format='json')
        self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

    def test_anonymous_user_cannot_create_post(self):
        """Anonymous users cannot create posts (401)."""
        url = reverse('post-list')
        data = {
            "title": "test title",
            "body": "test test test",
            "tags": ['test tag1', 'test tag2', 'test tag3']
        }
        response = self.client.post(url, data, format='json')
        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

    def test_post_tags_quantity(self):
        """Posts must carry between 1 and 3 existing tags.

        Checked for POST, PUT and PATCH: empty tag lists, more than three tags
        and unknown tag names all yield 400.
        """
        self.post = self._create_post()
        list_url = reverse('post-list')
        detail_url = reverse('post-detail', kwargs={'pk': self.post.pk})
        no_tags = {
            "title": "test title",
            "body": "test test test",
            "tags": []
        }
        too_many_tags = {
            "title": "test title",
            "body": "test test test",
            "tags": ['test tag1', 'test tag2', 'test tag3', 'test tag4']
        }
        patch_no_tags = {
            "tags": []
        }
        unknown_tag = {
            "title": "test title",
            "body": "test test test",
            "tags": ["tag does not exist"]
        }
        self.client.login(username='test', password='test')
        invalid_requests = (
            (self.client.post, list_url, no_tags),
            (self.client.post, list_url, too_many_tags),
            (self.client.post, list_url, unknown_tag),
            (self.client.put, detail_url, no_tags),
            (self.client.put, detail_url, too_many_tags),
            (self.client.patch, detail_url, patch_no_tags),
            (self.client.patch, detail_url, too_many_tags),
            (self.client.patch, detail_url, unknown_tag),
        )
        for send, url, payload in invalid_requests:
            with self.subTest(method=send.__name__, url=url, payload=payload):
                response = send(url, payload, format='json')
                self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

    def test_only_author_admin_can_edit_post(self):
        """Only admins and the author can edit a post.

        The author's edits cannot change the ``hidden``, ``highlighted`` or
        ``pinned`` fields; admins can change them.
        """
        self.post = self._create_post()
        url = reverse('post-detail', kwargs={'pk': self.post.pk})
        edit = {
            "title": "hello",
            "body": "hello world",
            "tags": ['test tag1', 'test tag2', 'test tag3']
        }
        edit_with_flags = {
            "title": "hello",
            "body": "hello world",
            "tags": ['test tag1', 'test tag2', 'test tag3'],
            "pinned": True,
            "highlighted": True
        }
        clear_flags = {
            "pinned": False,
            "highlighted": False
        }
        hide = {
            "hidden": True,
        }

        # Admin: can edit, pin and highlight the post.
        self.client.login(username='admin', password='admin123')
        response = self.client.put(url, edit, format='json')
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        response = self.client.put(url, edit_with_flags, format='json')
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.post.refresh_from_db()
        self.assertTrue(self.post.pinned)
        self.assertTrue(self.post.highlighted)
        response = self.client.patch(url, clear_flags, format='json')
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.post.refresh_from_db()
        self.assertFalse(self.post.pinned)
        self.client.logout()

        # Author: can edit, but pin / highlight / hide requests have no effect.
        self.client.login(username='test', password='test')
        response = self.client.put(url, edit, format='json')
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        response = self.client.put(url, edit_with_flags, format='json')
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.post.refresh_from_db()
        self.assertFalse(self.post.pinned)
        self.assertFalse(self.post.highlighted)
        response = self.client.patch(url, hide, format='json')
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.post.refresh_from_db()
        self.assertFalse(self.post.hidden)
        self.client.logout()

        # Any other user: every edit is forbidden.
        self.client.login(username='test2', password='test2')
        response = self.client.put(url, edit, format='json')
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
        response = self.client.put(url, edit_with_flags, format='json')
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
        response = self.client.patch(url, clear_flags, format='json')
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
        self.client.logout()

        # Admin hides the post; it disappears from the post list.
        self.client.login(username='admin', password='admin123')
        response = self.client.patch(url, hide, format='json')
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        response = self.client.get(reverse('post-list'), format='json')
        self.assertEqual(response.data['count'], 0)

    def test_index_post_list(self):
        """The index list reports the total post count and paginates."""
        for _ in range(5):
            self.post = self._create_post()
        list_url = reverse('post-list')
        self.client.login(username='test', password='test')
        response = self.client.get(list_url, format='json')
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.data['count'], 5)

        # Five posts fit on the first page, so page 2 does not exist yet.
        page2_url = list_url + '?page=2'
        response = self.client.get(page2_url, format='json')
        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

        for _ in range(16):
            self.post = self._create_post()

        response = self.client.get(page2_url, format='json')
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.data['count'], 21)

    def test_popular_post_list(self):
        """A post only appears in the popular list after it receives a reply."""
        # Local import: the file-level imports do not include ContentType.
        from django.contrib.contenttypes.models import ContentType

        for _ in range(5):
            self.post = self._create_post()
        popular_url = reverse('post-popular')
        response = self.client.get(popular_url, format='json')
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.data['count'], 0)
        self.assertEqual(response.data['page_size'], 10)

        self.client.login(username='admin', password='admin123')
        reply_data = {
            # Look the content type up instead of hard-coding its id, which
            # depends on the order in which content types were created.
            "content_type": ContentType.objects.get_for_model(Post).id,
            "object_pk": self.post.id,
            # NOTE(review): assumes pk 1 is the default Site created by
            # django.contrib.sites — confirm against settings.SITE_ID.
            "site": 1,
            "comment": "回复测试",
            "parent": None
        }
        response = self.client.post(reverse('reply-list'), reply_data, format='json')
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)

        response = self.client.get(popular_url, format='json')
        request = response.wsgi_request
        author = Post.objects.get(id=self.post.id).author
        self.assertEqual(response.data['count'], 1)
        author_data = response.data['data'][0]['author']
        self.assertEqual(author_data['id'], author.id)
        self.assertEqual(author_data['mugshot'],
                         request.build_absolute_uri(author.mugshot.url))
        self.assertEqual(author_data['nickname'], author.nickname)

    def test_post_detail(self):
        """The detail endpoint returns the post, and 404 for an unknown pk."""
        self.post = self._create_post()
        url = reverse('post-detail', kwargs={'pk': self.post.pk})
        response = self.client.get(url, format='json')
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.data['id'], self.post.id)
        self.assertEqual(response.data['title'], 'this is a test')
        self.assertEqual(response.data['body'], 'this is a test')
        self.assertEqual(response.data['author']['nickname'], 'test')

        missing_url = reverse('post-detail', kwargs={'pk': self.post.pk + 1})
        response = self.client.get(missing_url, format='json')
        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
294 |
--------------------------------------------------------------------------------
/users/tests/unit_tests/test_views.py:
--------------------------------------------------------------------------------
1 | from allauth.account.models import EmailAddress
2 | from django.contrib.contenttypes.models import ContentType
3 | from django.contrib.sites.models import Site
4 | from django.utils.timezone import timedelta
5 | from rest_framework import status, test
6 | from rest_framework.reverse import reverse
7 |
8 | from balance.models import Record
9 | from posts.models import Post
10 | from replies.models import Reply
11 | from replies.serializers import FlatReplySerializer
12 | from ...models import User
13 |
14 |
class UserViewSetTestCase(test.APITestCase):
    """Tests for the per-user endpoints: replies, checkin, posts and balance."""

    def setUp(self):
        self.user = User.objects.create_user(
            username='test',
            email='test@test.com',
            password='test',
            nickname='test'
        )

        self.another_user = User.objects.create_user(
            username='another',
            email='another@test.com',
            password='another',
            nickname='another'
        )

        self.site = Site.objects.create(name='test', domain='test.com')
        self.post = Post.objects.create(
            title='test title',
            author=self.another_user
        )
        self.post_ct = ContentType.objects.get_for_model(self.post)
        self.post_id = self.post.id

    def _create_reply(self, user, comment, **extra):
        """Create a reply by *user* on ``self.post``; *extra* is forwarded to ``create``."""
        return Reply.objects.create(
            content_type=self.post_ct,
            object_pk=self.post_id,
            site=self.site,
            user=user,
            comment=comment,
            **extra
        )

    def _checkin_url(self, user):
        """URL of the checkin action for *user*."""
        return reverse('user-checkin', kwargs={'pk': user.id})

    def test_return_user_replies(self):
        """The replies endpoint returns the user's public, non-removed replies.

        The expected payload is ``FlatReplySerializer`` applied to
        ``self.user.reply_comments`` filtered by ``is_public=True`` and
        ``is_removed=False``.
        """
        root = self._create_reply(self.user, 'reply1')
        self._create_reply(self.user, 'reply2', parent=root)
        # A reply by the other user in the same thread.
        self._create_reply(self.another_user, 'reply3', parent=root)

        url = reverse('user-replies', kwargs={'pk': self.user.id})
        response = self.client.get(url)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        expected = FlatReplySerializer(
            self.user.reply_comments.filter(is_public=True, is_removed=False),
            many=True, context={'request': response.wsgi_request}
        ).data
        self.assertEqual(response.data['data'], expected)

    def test_anonymous_user_cannot_checkin(self):
        """Anonymous users get 401 from the checkin endpoint."""
        response = self.client.post(self._checkin_url(self.user))
        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

    def test_current_request_user_can_checkin(self):
        """A user can check in for themselves; one Record is created for them."""
        self.client.login(username='test', password='test')
        response = self.client.post(self._checkin_url(self.user))

        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        self.assertEqual(Record.objects.count(), 1)
        self.assertEqual(Record.objects.get().user, self.user)

    def test_non_current_request_user_cannot_checkin(self):
        """Checking in on behalf of another user is forbidden (403)."""
        self.client.login(username='test', password='test')
        response = self.client.post(self._checkin_url(self.another_user))

        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

    def test_user_can_only_chekin_once_per_day(self):
        """A second checkin on the same day is forbidden (403)."""
        url = self._checkin_url(self.user)
        self.client.login(username='test', password='test')
        response = self.client.post(url)
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        self.assertEqual(Record.objects.count(), 1)
        self.assertEqual(Record.objects.get().user, self.user)

        # Check in a second time.
        response = self.client.post(url)
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

    def test_user_can_checkin_after_a_day(self):
        """A checkin record dated one day earlier does not block a new checkin."""
        record = Record.objects.create(
            reward_type=0,
            coin_type=2,
            amount=10,
            description='',
            user=self.user,
        )
        # Backdate with a queryset update so the timestamp is written as given:
        # update() bypasses Model.save(), including any auto_now handling.
        Record.objects.filter(pk=record.pk).update(
            created_time=record.created_time - timedelta(days=1)
        )

        self.client.login(username='test', password='test')
        response = self.client.post(self._checkin_url(self.user))
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        self.assertEqual(Record.objects.count(), 2)

    def test_user_can_get_non_hidden_posts(self):
        """Without filters, the user-posts endpoint lists only non-hidden posts."""
        for title in ('test', 'test2', 'test3'):
            Post.objects.create(title=title, author=self.user)
        Post.objects.create(title='test4', author=self.user, hidden=True)

        url = reverse('user-posts', kwargs={'pk': self.user.id})
        response = self.client.get(url)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(len(response.data['data']), 3)

    def test_user_can_only_get_self_hidden_posts(self):
        """``?hidden=true`` needs login and works only on the user's own posts.

        Anonymous -> 401; own posts -> exactly one result; another user's -> 403.
        """
        Post.objects.create(title='test', author=self.user)
        Post.objects.create(title='test', author=self.user, hidden=True)
        Post.objects.create(title='test', author=self.another_user, hidden=True)

        url = reverse('user-posts', kwargs={'pk': self.user.id})
        response = self.client.get(url, {'hidden': 'true'})
        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

        self.client.login(username='test', password='test')
        response = self.client.get(url, {'hidden': 'true'})
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(len(response.data['data']), 1)

        other_url = reverse('user-posts', kwargs={'pk': self.another_user.id})
        response = self.client.get(other_url, {'hidden': 'true'})
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

    def test_can_get_user_treasure(self):
        """The balance endpoint sums the user's record amounts per coin type."""
        # (owner, coin_type, amount); the last row belongs to another user and
        # must not be counted.
        records = (
            (self.user, 2, 10),
            (self.user, 2, 25),
            (self.user, 1, 10),
            (self.user, 1, 35),
            (self.user, 0, 35),
            (self.another_user, 0, 35),
        )
        for owner, coin_type, amount in records:
            Record.objects.create(
                reward_type=0,
                coin_type=coin_type,
                amount=amount,
                user=owner,
            )
        url = reverse('user-balance', kwargs={'pk': self.user.id})
        response = self.client.get(url)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        balance_data = list(response.data)
        self.assertIn({'coin_type': 0, 'amount__sum': 35}, balance_data)
        self.assertIn({'coin_type': 1, 'amount__sum': 45}, balance_data)
        self.assertIn({'coin_type': 2, 'amount__sum': 35}, balance_data)
228 |
229 |
class EmailAddressViewSetTestCase(test.APITestCase):
    """Tests for the e-mail address endpoints.

    Covers list, detail, create and delete, plus the set-primary and
    reverify actions. ``setUp`` gives ``self.user`` two addresses (verified
    primary + unverified) and ``self.another_user`` one verified primary
    address.
    """

    def setUp(self):
        self.user = User.objects.create_user(
            username='test',
            email='test@test.com',
            password='test',
            nickname='test'
        )

        self.another_user = User.objects.create_user(
            username='another',
            email='another@test.com',
            password='another',
            nickname='another'
        )

        self.email = EmailAddress.objects.create(
            user=self.user,
            email=self.user.email,
            verified=True,
            primary=True
        )

        self.unverified_email = EmailAddress.objects.create(
            user=self.user,
            email='unverified@test.com',
            verified=False,
            primary=False
        )

        self.another_user_email = EmailAddress.objects.create(
            user=self.another_user,
            email=self.another_user.email,
            verified=True,
            primary=True
        )

    def _login(self):
        """Log the API client in as ``self.user``."""
        self.client.login(username='test', password='test')

    def test_anonymous_user_cannot_operate_email(self):
        """Every e-mail endpoint answers 401 to anonymous users."""
        list_url = reverse('email-list')
        retrieve_url = reverse('email-detail', kwargs={'pk': self.email.id})
        set_primary_url = reverse(
            'email-set-primary', kwargs={'pk': self.email.id})
        reverify_url = reverse('email-reverify', kwargs={'pk': self.email.id})

        response = self.client.get(list_url)
        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

        response = self.client.get(retrieve_url)
        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

        response = self.client.post(list_url, data={
            'user': self.user,
            'email': 'new@email.com',
            'verified': True,
            'primary': True
        })
        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

        response = self.client.delete(
            list_url, data={'email': self.user.email})
        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

        response = self.client.post(set_primary_url)
        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

        response = self.client.get(reverify_url)
        self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

    def test_user_can_get_self_email(self):
        """The list endpoint returns exactly the requesting user's addresses."""
        url = reverse('email-list')
        self._login()
        response = self.client.get(url)

        self.assertEqual(response.status_code, status.HTTP_200_OK)
        results = response.data['data']
        # setUp gives the test user exactly two addresses.
        self.assertEqual(len(results), 2)
        self.assertTrue(all(ret['user'] == self.user.id for ret in results))

    def test_user_cannnot_get_others_email(self):
        """Another user's address is not reachable via the detail endpoint (404)."""
        url = reverse('email-detail',
                      kwargs={'pk': self.another_user_email.id})
        self._login()
        response = self.client.get(url)

        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

    def test_user_can_add_email(self):
        """A logged-in user can add an address to their own account."""
        url = reverse('email-list')
        self._login()
        response = self.client.post(url, data={
            'user': self.user,
            'email': 'new@email.com',
            'verified': True,
            'primary': True
        })
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)
        self.assertEqual(self.user.emailaddress_set.count(), 3)

    def test_user_cannot_set_unverified_email_to_primary(self):
        """set-primary on an unverified address is forbidden (403)."""
        url = reverse('email-set-primary',
                      kwargs={'pk': self.unverified_email.id})
        self._login()
        response = self.client.post(url)
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

    def test_user_can_set_verified_email_to_primary(self):
        """set-primary on a verified address promotes it and demotes the old primary."""
        verified_unprimary_email = EmailAddress.objects.create(
            user=self.user,
            email='verified_unprimary_email@test.com',
            verified=True,
            primary=False
        )
        url = reverse('email-set-primary',
                      kwargs={'pk': verified_unprimary_email.id})
        self._login()
        response = self.client.post(url)
        self.assertEqual(response.status_code, status.HTTP_201_CREATED)

        # The new address is now primary.
        verified_unprimary_email.refresh_from_db()
        self.assertTrue(verified_unprimary_email.primary)

        # The previous primary address is no longer primary.
        self.email.refresh_from_db()
        self.assertFalse(self.email.primary)

    def test_can_delete_non_primary_email(self):
        """A non-primary address can be deleted (204)."""
        url = reverse('email-detail', kwargs={'pk': self.unverified_email.id})
        self._login()
        response = self.client.delete(url)
        self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
        self.assertEqual(self.user.emailaddress_set.count(), 1)

    def test_user_cannot_delete_primary_email(self):
        """Deleting the primary address is forbidden (403) and nothing is removed."""
        url = reverse('email-detail', kwargs={'pk': self.email.id})
        self._login()
        response = self.client.delete(url, data={'email': self.email.email})
        self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
        self.assertEqual(self.user.emailaddress_set.count(), 2)

    def test_user_cannot_delete_others_email(self):
        """Deleting another user's address returns 404 and leaves it in place."""
        url = reverse('email-detail',
                      kwargs={'pk': self.another_user_email.id})
        self._login()
        response = self.client.delete(url)
        self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
        self.assertEqual(self.another_user.emailaddress_set.count(), 1)

    def test_user_can_reverify_email(self):
        """The reverify action answers 200 for the user's own address."""
        url = reverse('email-reverify', kwargs={'pk': self.email.id})
        self._login()
        response = self.client.get(url)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
385 |
--------------------------------------------------------------------------------