├── .coveragerc ├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── balance ├── __init__.py ├── admin.py ├── apps.py ├── factories.py ├── migrations │ ├── 0001_initial.py │ ├── 0002_record_user.py │ └── __init__.py ├── models.py ├── permissions.py ├── serializers.py ├── tests.py └── views.py ├── config ├── __init__.py ├── settings │ ├── __init__.py │ ├── common.py │ ├── local.py │ └── production.py ├── urls.py └── wsgi.py ├── db_tools ├── __init__.py ├── fake_db.py └── fake_db_fast.py ├── manage.py ├── notifications_extension ├── __init__.py ├── admin.py ├── apps.py ├── filters.py ├── migrations │ └── __init__.py ├── models.py ├── serializers.py ├── tests │ ├── __init__.py │ └── test_views.py ├── urls.py └── views.py ├── posts ├── __init__.py ├── admin.py ├── apps.py ├── factories.py ├── migrations │ ├── 0001_initial.py │ ├── 0002_auto_20180424_2035.py │ └── __init__.py ├── models.py ├── permissions.py ├── serializers.py ├── tests │ ├── __init__.py │ ├── test_models.py │ ├── test_serializers.py │ └── test_views.py └── views.py ├── replies ├── __init__.py ├── admin.py ├── apps.py ├── factories.py ├── migrations │ ├── 0001_initial.py │ ├── 0002_reply_user.py │ └── __init__.py ├── models.py ├── moderation.py ├── permissions.py ├── serializers.py ├── tests │ ├── __init__.py │ ├── test_models.py │ ├── test_moderation.py │ ├── test_serializers.py │ └── test_views.py ├── urls.py └── views.py ├── requirements ├── base.txt ├── local.txt └── production.txt ├── runtests.py ├── tags ├── __init__.py ├── admin.py ├── apps.py ├── factories.py ├── migrations │ ├── 0001_initial.py │ ├── 0002_tag_creator.py │ └── __init__.py ├── models.py ├── permissions.py ├── serializers.py ├── tests │ ├── __init__.py │ └── test_views.py └── views.py ├── templates └── users │ ├── account │ ├── base.html │ ├── email.html │ ├── email_confirm.html │ ├── login.html │ ├── password_change.html │ ├── password_reset.html │ ├── password_set.html │ └── signup.html │ └── socialaccount │ └── 
connections.html ├── users ├── __init__.py ├── admin.py ├── apps.py ├── disable_csrf_middleware.py ├── factories.py ├── jwt_middleware.py ├── migrations │ ├── 0001_initial.py │ └── __init__.py ├── models.py ├── mugshot.py ├── permissions.py ├── serializers.py ├── tests │ ├── __init__.py │ └── unit_tests │ │ ├── __init__.py │ │ ├── test_models.py │ │ ├── test_serializers.py │ │ ├── test_utils.py │ │ └── test_views.py ├── utils.py ├── validators.py └── views.py └── utils ├── __init__.py ├── mixins.py └── rest_tools.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = 3 | db_tools/* 4 | manage.py 5 | runtests.py -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Python template 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | .static_storage/ 58 | .media/ 59 | local_settings.py 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # Environments 87 | config/settings/.env 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | env.bak/ 93 | venv.bak/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | *.sqlite3 108 | .idea/ 109 | media/ 110 | 111 | # vscode 112 | .vscode/ 113 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | 3 | python: 4 | - "3.5" 5 | 6 | install: 7 | - pip install -r requirements/local.txt 8 | - pip install coveralls 9 | 10 | env: 11 | - DJANGO_SETTINGS_MODULE=config.settings.local 12 | 13 | services: 14 | - mysql 15 | 16 | script: 17 | - python runtests.py 18 | - coverage run --source=. 
manage.py test 19 | 20 | after_success: 21 | - coveralls 22 | 23 | branches: 24 | only: 25 | - master 26 | - dev -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [2018] [DjangoChinaOrg] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Django中文社区 2 | 3 | [![Build Status](https://travis-ci.org/DjangoChinaOrg/Django-China-API.svg?branch=dev)](https://travis-ci.org/DjangoChinaOrg/Django-China-API) [![Coverage Status](https://coveralls.io/repos/github/DjangoChinaOrg/Django-China-API/badge.svg?branch=dev)](https://coveralls.io/github/DjangoChinaOrg/Django-China-API?branch=dev) 4 | 5 | ### 开发团队 6 | 7 | 见:[组织成员列表](https://github.com/orgs/DjangoChinaOrg/people) 8 | 9 | 团队目前配置: 10 | 11 | - 后端开发 3 人 12 | - 前端开发 1 人 13 | - 产品 1 人 14 | 15 | 如果有兴趣加入,欢迎随时与我们取得联系:djangostudyteam@163.com 16 | 17 | ### 技术栈 18 | 19 | 后端使用 Python 3.5、Django 1.11、Django REST framework 开发 API 20 | 21 | 前端使用 Vue 2.0 22 | 23 | ### 第一版的样子 24 | 25 | 可能会是这样:www.pythonzh.cn 26 | 27 | ### 如何贡献? 28 | 29 | 我们将开发需求以 issue 的形式发布在项目的 issue 列表并指定了团队中的开发人员。如果您有兴趣参与,可以通过 issue 的讨论功能告知您的参与计划,开发人员就可以着手去实现其它需求,这将大大加快项目的上线速度。 -------------------------------------------------------------------------------- /balance/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DjangoChinaOrg/Django-China-API/79a5d85fe88ba7784d08d370b8e7519f7274f208/balance/__init__.py -------------------------------------------------------------------------------- /balance/admin.py: -------------------------------------------------------------------------------- 1 | from django.contrib import admin 2 | 3 | from .models import Record 4 | 5 | 6 | class RecordAdmin(admin.ModelAdmin): 7 | list_display = ( 8 | 'created_time', 9 | 'reward_type', 10 | 'coin_type', 11 | 'amount', 12 | 'description', 13 | 'user', 14 | ) 15 | 16 | 17 | admin.site.register(Record, RecordAdmin) 18 | -------------------------------------------------------------------------------- /balance/apps.py:
-------------------------------------------------------------------------------- 1 | from django.apps import AppConfig 2 | 3 | 4 | class BalanceConfig(AppConfig): 5 | name = 'balance' 6 | -------------------------------------------------------------------------------- /balance/factories.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | import factory 5 | 6 | from users.factories import UserFactory 7 | 8 | from .models import Record 9 | 10 | 11 | class RecordFactory(factory.DjangoModelFactory): 12 | class Meta: 13 | model = Record 14 | 15 | reward_type = 0 16 | coin_type = 2 17 | user = factory.SubFactory(UserFactory) 18 | 19 | @factory.lazy_attribute 20 | def amount(self): 21 | random_amount = abs(random.gauss(10, 5)) 22 | random_amount = math.ceil(random_amount) 23 | 24 | if random_amount == 0: 25 | random_amount += 1 26 | 27 | return random_amount 28 | -------------------------------------------------------------------------------- /balance/migrations/0001_initial.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by Django 1.11 on 2018-04-24 12:35 3 | from __future__ import unicode_literals 4 | 5 | from django.db import migrations, models 6 | import django.utils.timezone 7 | import model_utils.fields 8 | 9 | 10 | class Migration(migrations.Migration): 11 | 12 | initial = True 13 | 14 | dependencies = [ 15 | ] 16 | 17 | operations = [ 18 | migrations.CreateModel( 19 | name='Record', 20 | fields=[ 21 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 22 | ('created_time', model_utils.fields.AutoCreatedField(default=django.utils.timezone.now, editable=False, verbose_name='创建时间')), 23 | ('reward_type', models.IntegerField(choices=[(0, '每日签到奖励')], verbose_name='奖励类型')), 24 | ('coin_type', models.IntegerField(choices=[(0, '金币'), (1, '银币'), (2, '铜币')], 
verbose_name='钱币类型')), 25 | ('amount', models.PositiveIntegerField(verbose_name='数额')), 26 | ('description', models.CharField(blank=True, max_length=300, verbose_name='描述')), 27 | ], 28 | options={ 29 | 'verbose_name': '奖励记录', 30 | 'verbose_name_plural': '奖励记录', 31 | }, 32 | ), 33 | ] 34 | -------------------------------------------------------------------------------- /balance/migrations/0002_record_user.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by Django 1.11 on 2018-04-24 12:35 3 | from __future__ import unicode_literals 4 | 5 | from django.conf import settings 6 | from django.db import migrations, models 7 | import django.db.models.deletion 8 | 9 | 10 | class Migration(migrations.Migration): 11 | 12 | initial = True 13 | 14 | dependencies = [ 15 | migrations.swappable_dependency(settings.AUTH_USER_MODEL), 16 | ('balance', '0001_initial'), 17 | ] 18 | 19 | operations = [ 20 | migrations.AddField( 21 | model_name='record', 22 | name='user', 23 | field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL, verbose_name='用户'), 24 | ), 25 | ] 26 | -------------------------------------------------------------------------------- /balance/migrations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DjangoChinaOrg/Django-China-API/79a5d85fe88ba7784d08d370b8e7519f7274f208/balance/migrations/__init__.py -------------------------------------------------------------------------------- /balance/models.py: -------------------------------------------------------------------------------- 1 | from django.conf import settings 2 | from django.db import models 3 | from model_utils.fields import AutoCreatedField 4 | 5 | 6 | class Record(models.Model): 7 | REWARD_TYPE = ( 8 | (0, '每日签到奖励'), 9 | ) 10 | 11 | COIN_TYPE = ( 12 | (0, '金币'), 13 | (1, '银币'), 14 | (2, '铜币'), 15 | ) 16 | 17 | 
created_time = AutoCreatedField(verbose_name="创建时间") 18 | reward_type = models.IntegerField(verbose_name="奖励类型", choices=REWARD_TYPE) 19 | coin_type = models.IntegerField(verbose_name="钱币类型", choices=COIN_TYPE) 20 | amount = models.PositiveIntegerField(verbose_name="数额") 21 | description = models.CharField(verbose_name="描述", max_length=300, blank=True) 22 | user = models.ForeignKey(settings.AUTH_USER_MODEL, verbose_name="用户", on_delete=models.CASCADE) 23 | 24 | class Meta: 25 | verbose_name = "奖励记录" 26 | verbose_name_plural = "奖励记录" 27 | 28 | def __str__(self): 29 | return '%s:%s:%s -> %s' % (self.reward_type, self.coin_type, self.amount, self.user) 30 | -------------------------------------------------------------------------------- /balance/permissions.py: -------------------------------------------------------------------------------- 1 | from django.utils import timezone 2 | from rest_framework import permissions 3 | 4 | from balance.models import Record 5 | 6 | 7 | class OncePerDay(permissions.BasePermission): 8 | """ 9 | 一天内(0:00:00-23:59:59)用户只能签到一次 10 | """ 11 | 12 | def has_object_permission(self, request, view, obj): 13 | if request.method in permissions.SAFE_METHODS: 14 | return True 15 | 16 | try: 17 | # 获取用户最近一次签到记录 18 | # 不存在则说明用户从未签到 19 | latest_record = request.user.record_set.latest('created_time') 20 | except Record.DoesNotExist: 21 | return True 22 | 23 | # 获取当天的开始和结束时间 24 | today_start = timezone.now().replace(hour=0, minute=0, second=0) 25 | today_end = timezone.now().replace(hour=23, minute=59, second=59) 26 | 27 | if today_start <= latest_record.created_time <= today_end: 28 | return False 29 | 30 | return True 31 | 32 | 33 | class IsCurrentUser(permissions.BasePermission): 34 | def has_object_permission(self, request, view, obj): 35 | return obj == request.user 36 | -------------------------------------------------------------------------------- /balance/serializers.py: 
-------------------------------------------------------------------------------- 1 | from rest_framework import serializers 2 | 3 | from balance.models import Record 4 | 5 | 6 | class BalanceSerializer(serializers.ModelSerializer): 7 | class Meta: 8 | model = Record 9 | fields = ( 10 | 'reward_type', 11 | 'coin_type', 12 | 'amount', 13 | 'description', 14 | 'user', 15 | ) 16 | read_only_fields = ( 17 | 'reward_type', 18 | 'coin_type', 19 | 'amount', 20 | 'description', 21 | 'user', 22 | ) 23 | -------------------------------------------------------------------------------- /balance/tests.py: -------------------------------------------------------------------------------- 1 | from django.test import TestCase 2 | 3 | # Create your tests here. 4 | -------------------------------------------------------------------------------- /balance/views.py: -------------------------------------------------------------------------------- 1 | from rest_framework import viewsets 2 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DjangoChinaOrg/Django-China-API/79a5d85fe88ba7784d08d370b8e7519f7274f208/config/__init__.py -------------------------------------------------------------------------------- /config/settings/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DjangoChinaOrg/Django-China-API/79a5d85fe88ba7784d08d370b8e7519f7274f208/config/settings/__init__.py -------------------------------------------------------------------------------- /config/settings/common.py: -------------------------------------------------------------------------------- 1 | """ 2 | Django settings for DjangoChina project. 3 | Generated by 'django-admin startproject' using Django 1.11. 
4 | For more information on this file, see 5 | https://docs.djangoproject.com/en/1.11/topics/settings/ 6 | For the full list of settings and their values, see 7 | https://docs.djangoproject.com/en/1.11/ref/settings/ 8 | """ 9 | 10 | import datetime 11 | import os 12 | 13 | # Build paths inside the project like this: os.path.join(BASE_DIR, ...) 14 | BASE_DIR = os.path.dirname(os.path.dirname( 15 | os.path.dirname(os.path.abspath(__file__)))) 16 | 17 | # Quick-start development settings - unsuitable for production 18 | # See https://docs.djangoproject.com/en/1.11/howto/deployment/checklist/ 19 | 20 | # SECURITY WARNING: don't run with debug turned on in production! 21 | DEBUG = True 22 | 23 | ALLOWED_HOSTS = [] 24 | 25 | SITE_ID = 1 26 | 27 | LOGIN_URL = '/accounts/login/' 28 | LOGIN_REDIRECT_URL = '/' 29 | 30 | # Application definition 31 | 32 | INSTALLED_APPS = [ 33 | 'django.contrib.admin', 34 | 'django.contrib.auth', 35 | 'django.contrib.contenttypes', 36 | 'django.contrib.sessions', 37 | 'django.contrib.messages', 38 | 'django.contrib.staticfiles', 39 | 'django.contrib.sites', 40 | 41 | # third-party apps 42 | 'bootstrapform', 43 | 'notifications', 44 | 'rest_framework', 45 | 'rest_framework.authtoken', 46 | 'rest_auth', 47 | 'allauth', 48 | 'allauth.account', 49 | 'allauth.socialaccount', 50 | 'allauth.socialaccount.providers.github', 51 | 'rest_auth.registration', 52 | 'django_comments', 53 | 'actstream', 54 | 'django_filters', 55 | 'corsheaders', 56 | 'raven.contrib.django.raven_compat', # sentry support 57 | 58 | # local apps 59 | 'users', 60 | 'posts', 61 | 'replies', 62 | 'tags', 63 | 'balance', 64 | 'notifications_extension', 65 | ] 66 | 67 | COMMENTS_APP = 'replies' 68 | 69 | MIDDLEWARE = [ 70 | 'corsheaders.middleware.CorsMiddleware', 71 | 'django.middleware.security.SecurityMiddleware', 72 | 'django.contrib.sessions.middleware.SessionMiddleware', 73 | 'django.middleware.common.CommonMiddleware', 74 | # 'django.middleware.csrf.CsrfViewMiddleware', 75 | 
'django.contrib.auth.middleware.AuthenticationMiddleware', 76 | 'users.disable_csrf_middleware.DisableCSRFCheck', 77 | 'users.jwt_middleware.JWTMiddleware', 78 | 'django.contrib.messages.middleware.MessageMiddleware', 79 | 'django.middleware.clickjacking.XFrameOptionsMiddleware', 80 | ] 81 | 82 | CORS_ORIGIN_ALLOW_ALL = True # 跨域 83 | 84 | ROOT_URLCONF = 'config.urls' 85 | 86 | TEMPLATES = [ 87 | { 88 | 'BACKEND': 'django.template.backends.django.DjangoTemplates', 89 | 'DIRS': [ 90 | os.path.join(BASE_DIR, 'templates'), 91 | os.path.join(BASE_DIR, 'templates', 'users') 92 | ], 93 | 'APP_DIRS': True, 94 | 'OPTIONS': { 95 | 'context_processors': [ 96 | 'django.template.context_processors.debug', 97 | 'django.template.context_processors.request', 98 | 'django.contrib.auth.context_processors.auth', 99 | 'django.contrib.messages.context_processors.messages', 100 | ], 101 | }, 102 | }, 103 | ] 104 | 105 | WSGI_APPLICATION = 'config.wsgi.application' 106 | 107 | # Password validation 108 | # https://docs.djangoproject.com/en/1.11/ref/settings/#auth-password-validators 109 | 110 | AUTH_PASSWORD_VALIDATORS = [ 111 | { 112 | 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', 113 | }, 114 | { 115 | 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', 116 | }, 117 | { 118 | 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', 119 | }, 120 | { 121 | 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', 122 | }, 123 | ] 124 | 125 | # Internationalization 126 | # https://docs.djangoproject.com/en/1.11/topics/i18n/ 127 | 128 | LANGUAGE_CODE = 'zh-hans' 129 | 130 | TIME_ZONE = 'Asia/Shanghai' 131 | 132 | USE_I18N = True 133 | 134 | USE_L10N = True 135 | 136 | USE_TZ = True 137 | 138 | # Static files (CSS, JavaScript, Images) 139 | # https://docs.djangoproject.com/en/1.11/howto/static-files/ 140 | 141 | STATIC_URL = '/static/' 142 | STATIC_ROOT = os.path.join(BASE_DIR, 'static') 143 | 
144 | MEDIA_ROOT = os.path.join(BASE_DIR, 'media') 145 | MEDIA_URL = '/media/' 146 | 147 | AUTH_USER_MODEL = 'users.User' 148 | 149 | # django-rest-framework settings 150 | REST_FRAMEWORK = { 151 | 'DEFAULT_PERMISSION_CLASSES': ( 152 | 'rest_framework.permissions.IsAuthenticated', 153 | ), 154 | 'DEFAULT_AUTHENTICATION_CLASSES': ( 155 | 'rest_framework_jwt.authentication.JSONWebTokenAuthentication', 156 | 'rest_framework.authentication.SessionAuthentication', 157 | 'rest_framework.authentication.BasicAuthentication' 158 | ), 159 | 'DEFAULT_PAGINATION_CLASS': 'utils.rest_tools.CustomPageNumberPagination', 160 | } 161 | 162 | # django-allauth settings 163 | AUTHENTICATION_BACKENDS = ( 164 | # Needed to login by username in Django admin, regardless of `allauth` 165 | 'django.contrib.auth.backends.ModelBackend', 166 | 167 | # `allauth` specific authentication methods, such as login by e-mail 168 | 'allauth.account.auth_backends.AuthenticationBackend', 169 | ) 170 | 171 | # django-rest-auth settings 172 | REST_AUTH_SERIALIZERS = { 173 | 'USER_DETAILS_SERIALIZER': 'users.serializers.UserDetailsSerializer', 174 | } 175 | 176 | REST_AUTH_REGISTER_SERIALIZERS = { 177 | 'REGISTER_SERIALIZER': 'users.serializers.UserRegistrationSerializer', 178 | } 179 | 180 | # djangorestframework-jwt settings 181 | JWT_AUTH = { 182 | 'JWT_EXPIRATION_DELTA': datetime.timedelta(seconds=60 * 30), 183 | 'JWT_AUTH_HEADER_PREFIX': 'Bearer', 184 | 'JWT_ALLOW_REFRESH': True, 185 | } 186 | 187 | if DEBUG: 188 | JWT_AUTH['JWT_EXPIRATION_DELTA'] = datetime.timedelta(days=1) 189 | 190 | # django-all-auth setting 191 | REST_USE_JWT = True 192 | ACCOUNT_AUTHENTICATION_METHOD = 'username_email' 193 | ACCOUNT_EMAIL_REQUIRED = True 194 | ACCOUNT_EMAIL_VERIFICATION = True 195 | LOGIN_ON_EMAIL_CONFIRMATION = True 196 | SOCIALACCOUNT_EMAIL_VERIFICATION = False 197 | OLD_PASSWORD_FIELD_ENABLED = True 198 | 199 | # 软删除 200 | NOTIFICATIONS_SOFT_DELETE = True 201 | 202 | SOCIAL_LOGIN_GITHUB_CALLBACK_URL = 'dummy' 
203 | -------------------------------------------------------------------------------- /config/settings/local.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .common import * 4 | 5 | DEBUG = True 6 | ALLOWED_HOSTS = ['*'] 7 | 8 | SECRET_KEY = 't3l=1)%^^ftao(2_@p^j_$ordrl4rg4-0z1w@^gvvi64balvbx' 9 | 10 | EMAIL_BACKEND = 'django.core.mail.backends.console.EmailBackend' 11 | 12 | # envs 13 | # MYSQL_HOST = os.getenv('MYSQL_HOST', '127.0.0.1') 14 | # MYSQL_DB_NAME = os.getenv('MYSQL_MYSQL_DB_NAME', 'django') 15 | # MYSQL_DB_USER = os.getenv('MYSQL_MYSQL_DB_USER', 'root') 16 | # MYSQL_PASSWORD = os.getenv('MYSQL_PASSWORD', '') 17 | 18 | # # database 19 | # DATABASES['default'].update( 20 | # {'HOST': MYSQL_HOST, 21 | # 'NAME': MYSQL_DB_NAME, 22 | # 'USER': MYSQL_DB_USER, 23 | # 'PASSWORD': MYSQL_PASSWORD, 24 | # }) 25 | 26 | # sqlite3 27 | DATABASES = { 28 | 'default': { 29 | 'ENGINE': 'django.db.backends.sqlite3', 30 | 'NAME': os.path.join(BASE_DIR, 'dev.sqlite3'), 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /config/settings/production.py: -------------------------------------------------------------------------------- 1 | from .common import * 2 | import environ 3 | 4 | ALLOWED_HOSTS = ['.dj-china.org', 'localhost', '127.0.0.1', '0.0.0.0'] 5 | 6 | env = environ.Env( 7 | # set casting, default value 8 | DEBUG=(bool, True), 9 | SOCIAL_LOGIN_GITHUB_CALLBACK_URL=(str, 'http://127.0.0.1:8000/social-auth/github/loginsuccess') 10 | ) 11 | 12 | environ.Env.read_env() 13 | DEBUG = env('DEBUG') # default False 14 | SECRET_KEY = env('SECRET_KEY') 15 | 16 | # sentry dsn 17 | RAVEN_CONFIG = { 18 | 'dsn': env('SENTRY_DSN'), 19 | } 20 | 21 | # import sentry_sdk 22 | # from sentry_sdk.integrations.django import DjangoIntegration 23 | # 24 | # sentry_sdk.init( 25 | # dsn=env('SENTRY_DSN'), 26 | # integrations=[DjangoIntegration()] 27 | # ) 28 | 29 | # GitHub 登录 30 | 
SOCIAL_LOGIN_GITHUB_CALLBACK_URL = env('SOCIAL_LOGIN_GITHUB_CALLBACK_URL') 31 | 32 | # mysql 33 | DATABASES = { 34 | 'default': { 35 | 'ENGINE': 'django.db.backends.mysql', 36 | 'NAME': env('MYSQL_NAME'), 37 | 'USER': env('MYSQL_USER'), 38 | 'PASSWORD': env('MYSQL_PASSWORD'), 39 | 'HOST': 'localhost', 40 | 'PORT': '3306', 41 | 'OPTIONS': { 42 | 'autocommit': True, 43 | 'init_command': "SET sql_mode='STRICT_TRANS_TABLES'", 44 | 'charset': 'utf8mb4', 45 | }, 46 | 'TEST': { 47 | 'NAME': 'django_test', 48 | 'CHARSET': 'utf8', 49 | 'COLLATION': 'utf8_general_ci', 50 | } 51 | } 52 | } 53 | 54 | # 邮件配置,使用腾讯云企业邮箱 55 | EMAIL_BACKEND = 'django.core.mail.backends.smtp.EmailBackend' 56 | EMAIL_HOST = 'smtp.exmail.qq.com' 57 | EMAIL_PORT = 465 58 | EMAIL_USE_SSL = True 59 | EMAIL_USE_LOCALTIME = True 60 | EMAIL_HOST_USER = env('EMAIL_HOST_USER') 61 | EMAIL_HOST_PASSWORD = env('EMAIL_HOST_PASSWORD') 62 | 63 | DEFAULT_FROM_EMAIL = 'Django中文社区 <%s>' % EMAIL_HOST_USER 64 | SERVER_EMAIL = EMAIL_HOST_USER 65 | -------------------------------------------------------------------------------- /config/urls.py: -------------------------------------------------------------------------------- 1 | """DjangoChina URL Configuration 2 | 3 | The `urlpatterns` list routes URLs to views. For more information please see: 4 | https://docs.djangoproject.com/en/1.11/topics/http/urls/ 5 | Examples: 6 | Function views 7 | 1. Add an import: from my_app import views 8 | 2. Add a URL to urlpatterns: url(r'^$', views.home, name='home') 9 | Class-based views 10 | 1. Add an import: from other_app.views import Home 11 | 2. Add a URL to urlpatterns: url(r'^$', Home.as_view(), name='home') 12 | Including another URLconf 13 | 1. Import the include() function: from django.conf.urls import url, include 14 | 2. 
Add a URL to urlpatterns: url(r'^blog/', include('blog.urls')) 15 | """ 16 | from django.conf import settings 17 | from django.conf.urls import include, url 18 | from django.conf.urls.static import static 19 | from django.contrib import admin 20 | from rest_framework.documentation import include_docs_urls 21 | from rest_framework.routers import DefaultRouter 22 | from rest_framework_jwt.views import refresh_jwt_token 23 | from rest_auth.registration.views import ( 24 | SocialAccountListView, SocialAccountDisconnectView 25 | ) 26 | 27 | from notifications_extension.views import NotificationViewSet 28 | from posts.views import PostViewSet 29 | from replies.views import ReplyViewSet 30 | from tags.views import TagViewSet 31 | from users.views import ( 32 | EmailAddressViewSet, 33 | LoginViewCustom, 34 | RegisterViewCustom, 35 | ConfirmEmailView, 36 | MugshotUploadView, 37 | UserViewSets, 38 | GitHubLogin, 39 | GitHubConnect 40 | ) 41 | 42 | router = DefaultRouter() 43 | router.register(r'posts', PostViewSet) 44 | router.register(r'tags', TagViewSet) 45 | router.register(r'replies', ReplyViewSet) 46 | router.register(r'users', UserViewSets) 47 | router.register(r'users/email', EmailAddressViewSet, base_name='email') 48 | router.register(r'notifications', NotificationViewSet, base_name='notifications') 49 | 50 | urlpatterns = [ 51 | url(r'^admin/', admin.site.urls), 52 | url(r'^accounts/', include('allauth.urls')), 53 | # url(r'^users/mugshot/(?P[^/]+)$', MugshotUploadView.as_view()), 54 | url(r'^rest-auth/login/$', LoginViewCustom.as_view(), name='rest_login'), 55 | url(r'^rest-auth/registration/$', RegisterViewCustom.as_view(), name='rest_register'), 56 | url(r'^rest-auth/registration/account-confirm-email/(?P[-:\w]+)/$', 57 | ConfirmEmailView.as_view(), 58 | name='account_confirm_email'), 59 | url(r'^rest-auth/github/login/$', GitHubLogin.as_view(), name='github_login'), 60 | url(r'^rest-auth/github/connect/$', GitHubConnect.as_view(), name='github_connect'), 61 | 
url(r'^rest-auth/socialaccounts/$', 62 | SocialAccountListView.as_view(), 63 | name='social_account_list'), 64 | url(r'^rest-auth/socialaccounts/(?P\d+)/disconnect/$', 65 | SocialAccountDisconnectView.as_view(), 66 | name='social_account_disconnect'), 67 | url(r'^rest-auth/jwt-refresh/', refresh_jwt_token), 68 | url(r'^rest-auth/', include('rest_auth.urls')), 69 | url(r'^rest-auth/registration/', include('rest_auth.registration.urls')), 70 | url(r'^api-auth/', include('rest_framework.urls')), # 仅仅用于测试 71 | url(r'^', include(router.urls)), 72 | url(r'^docs/', include_docs_urls(title='Django中文社区 API')) 73 | ] + static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT) 74 | -------------------------------------------------------------------------------- /config/wsgi.py: -------------------------------------------------------------------------------- 1 | """ 2 | WSGI config for DjangoChina project. 3 | 4 | It exposes the WSGI callable as a module-level variable named ``application``. 5 | 6 | For more information on this file, see 7 | https://docs.djangoproject.com/en/1.11/howto/deployment/wsgi/ 8 | """ 9 | 10 | import os 11 | 12 | from django.core.wsgi import get_wsgi_application 13 | 14 | os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings.production") 15 | 16 | application = get_wsgi_application() 17 | -------------------------------------------------------------------------------- /db_tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DjangoChinaOrg/Django-China-API/79a5d85fe88ba7784d08d370b8e7519f7274f208/db_tools/__init__.py -------------------------------------------------------------------------------- /db_tools/fake_db.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import random 4 | 5 | BASE_DIR = os.path.dirname((os.path.dirname(os.path.abspath(__file__)))) 6 | sys.path.append(BASE_DIR) 7 | 
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings.local") 8 | 9 | import django 10 | 11 | django.setup() 12 | 13 | from users.factories import UserFactory 14 | from users.models import User 15 | from tags.factories import TagFactory 16 | from tags.models import Tag 17 | from replies.factories import PostReplyFactory, SiteFactory 18 | from replies.models import Reply 19 | from posts.factories import PostFactory 20 | from posts.models import Post 21 | from balance.factories import RecordFactory 22 | 23 | from allauth.account.models import EmailAddress 24 | from rest_framework.test import APIClient 25 | from django.urls import reverse 26 | from django.contrib.contenttypes.models import ContentType 27 | from faker import Faker 28 | 29 | if __name__ == '__main__': 30 | site = SiteFactory() 31 | c = APIClient() 32 | fake = Faker() 33 | # 生成 10 个用户 34 | UserFactory.create_batch(10) 35 | print('users created...') 36 | 37 | # 生成 10 个标签 38 | user = UserFactory(username='admin') 39 | TagFactory.create_batch(10, creator=user) 40 | print('tags created...') 41 | 42 | tags = list(Tag.objects.all()) 43 | users = list(User.objects.all()) 44 | 45 | # 每个用户发布一定量的帖子 46 | for user in users: 47 | post_count = random.randint(10, 15) 48 | 49 | for i in range(post_count): 50 | tag_count = random.randint(1, 3) 51 | tag_sample = random.sample(tags, tag_count) 52 | PostFactory(author=user, tags=tag_sample) 53 | 54 | # 绑定一个 email 55 | EmailAddress.objects.create( 56 | user=user, 57 | email=user.email, 58 | verified=True, 59 | primary=True 60 | ) 61 | # 获得一定的财富 62 | RecordFactory.create_batch(random.randint(0, 10), user=user) 63 | 64 | print('posts posted...') 65 | 66 | posts = list(Post.objects.all()) 67 | 68 | # 随机选择一些用户对每个帖子进行回复,使用 client 发送请求, 69 | # 这样回复时会生成相应通知 70 | for post in posts: 71 | post_ct = ContentType.objects.get_for_model(post) 72 | post_id = post.id 73 | url = reverse('reply-list') 74 | 75 | user_sample = random.sample(users, random.randint(0, 10)) 76 | for user 
in user_sample: 77 | data = { 78 | "content_type": post_ct.id, 79 | "object_pk": post_id, 80 | "site": 1, 81 | "comment": fake.text(max_nb_chars=200, ext_word_list=None) 82 | } 83 | c.force_login(user) 84 | c.post(url, data, format='json') 85 | 86 | print('post replies created...') 87 | 88 | # 再随机选择一些用户对某些回复进行回复 89 | for i in range(3): 90 | replies = list(Reply.objects.all()) 91 | for reply in replies: 92 | indicator = random.randint(0, 2) 93 | if indicator: # 以 1/3 的概率回复 94 | continue 95 | post = reply.content_object 96 | post_ct = ContentType.objects.get_for_model(post) 97 | post_id = post.id 98 | url = reverse('reply-list') 99 | user_sample = random.sample(users, random.randint(1, 3)) 100 | 101 | for user in user_sample: 102 | data = { 103 | "content_type": post_ct.id, 104 | "object_pk": post_id, 105 | "site": 1, 106 | "comment": fake.text(max_nb_chars=200, ext_word_list=None), 107 | "parent": reply.id 108 | } 109 | c.force_login(user) 110 | c.post(url, data, format='json') 111 | print('reply replies created...') 112 | 113 | # 再随机选择一些用户对回复进行点赞 114 | replies = list(Reply.objects.all()) 115 | for i in range(5): 116 | for reply in replies: 117 | indicator = random.randint(0, 2) 118 | if indicator: # 以 1/3 的概率回复 119 | continue 120 | url = reverse('reply-like', kwargs={'pk': reply.id}) 121 | user_sample = random.sample(users, random.randint(1, 3)) 122 | 123 | for user in user_sample: 124 | data = { 125 | "content_type": reply.ctype_id, 126 | "object_id": reply.id, 127 | "flag": "like", 128 | } 129 | c.force_login(user) 130 | c.post(url, data, format='json') 131 | print('liked!') 132 | print('done...!') 133 | -------------------------------------------------------------------------------- /db_tools/fake_db_fast.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import random 4 | 5 | BASE_DIR = os.path.dirname((os.path.dirname(os.path.abspath(__file__)))) 6 | sys.path.append(BASE_DIR) 7 | 
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings.local") 8 | 9 | import django 10 | 11 | django.setup() 12 | 13 | from users.factories import UserFactory 14 | from users.models import User 15 | from tags.factories import TagFactory 16 | from tags.models import Tag 17 | from replies.factories import PostReplyFactory, SiteFactory 18 | from replies.models import Reply 19 | from posts.factories import PostFactory 20 | from posts.models import Post 21 | from balance.factories import RecordFactory 22 | 23 | from allauth.account.models import EmailAddress 24 | from rest_framework.test import APIClient 25 | from django.urls import reverse 26 | from django.contrib.contenttypes.models import ContentType 27 | from faker import Faker 28 | 29 | if __name__ == '__main__': 30 | site = SiteFactory() 31 | c = APIClient() 32 | fake = Faker() 33 | # 生成 5 个用户 34 | UserFactory.create_batch(5) 35 | print('users created...') 36 | 37 | # 生成 5 个标签 38 | user = UserFactory(username='admin') 39 | TagFactory.create_batch(5, creator=user) 40 | print('tags created...') 41 | 42 | tags = list(Tag.objects.all()) 43 | users = list(User.objects.all()) 44 | 45 | # 每个用户发布一定量的帖子 46 | for user in users: 47 | post_count = random.randint(5, 10) 48 | 49 | for i in range(post_count): 50 | tag_count = random.randint(1, 3) 51 | tag_sample = random.sample(tags, tag_count) 52 | PostFactory(author=user, tags=tag_sample) 53 | 54 | # 绑定一个 email 55 | EmailAddress.objects.create( 56 | user=user, 57 | email=user.email, 58 | verified=True, 59 | primary=True 60 | ) 61 | # 获得一定的财富 62 | RecordFactory.create_batch(random.randint(0, 3), user=user) 63 | 64 | print('posts posted...') 65 | 66 | posts = list(Post.objects.all()) 67 | 68 | # 随机选择一些用户对每个帖子进行回复,使用 client 发送请求, 69 | # 这样回复时会生成相应通知 70 | for post in posts: 71 | post_ct = ContentType.objects.get_for_model(post) 72 | post_id = post.id 73 | url = reverse('reply-list') 74 | 75 | user_sample = random.sample(users, random.randint(0, 5)) 76 | for user in 
user_sample: 77 | data = { 78 | "content_type": post_ct.id, 79 | "object_pk": post_id, 80 | "site": 1, 81 | "comment": fake.text(max_nb_chars=200, ext_word_list=None) 82 | } 83 | c.force_login(user) 84 | c.post(url, data, format='json') 85 | 86 | print('post replies created...') 87 | 88 | # 再随机选择一些用户对某些回复进行回复 89 | for i in range(2): 90 | replies = list(Reply.objects.all()) 91 | for reply in replies: 92 | indicator = random.randint(0, 2) 93 | if indicator: # 以 1/3 的概率回复 94 | continue 95 | post = reply.content_object 96 | post_ct = ContentType.objects.get_for_model(post) 97 | post_id = post.id 98 | url = reverse('reply-list') 99 | user_sample = random.sample(users, random.randint(1, 3)) 100 | 101 | for user in user_sample: 102 | data = { 103 | "content_type": post_ct.id, 104 | "object_pk": post_id, 105 | "site": 1, 106 | "comment": fake.text(max_nb_chars=200, ext_word_list=None), 107 | "parent": reply.id 108 | } 109 | c.force_login(user) 110 | c.post(url, data, format='json') 111 | print('reply replies created...') 112 | 113 | # 再随机选择一些用户对回复进行点赞 114 | replies = list(Reply.objects.all()) 115 | for i in range(2): 116 | for reply in replies: 117 | indicator = random.randint(0, 2) 118 | if indicator: # 以 1/3 的概率回复 119 | continue 120 | url = reverse('reply-like', kwargs={'pk': reply.id}) 121 | user_sample = random.sample(users, random.randint(1, 3)) 122 | 123 | for user in user_sample: 124 | data = { 125 | "content_type": reply.ctype_id, 126 | "object_id": reply.id, 127 | "flag": "like", 128 | } 129 | c.force_login(user) 130 | c.post(url, data, format='json') 131 | print('liked!') 132 | print('done...!') -------------------------------------------------------------------------------- /manage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import sys 4 | 5 | if __name__ == "__main__": 6 | env = os.environ.get('DJANGO_ENV', 'local') 7 | os.environ.setdefault("DJANGO_SETTINGS_MODULE", 8 | 
"config.settings.{}".format(env)) 9 | try: 10 | from django.core.management import execute_from_command_line 11 | except ImportError: 12 | # The above import may fail for some other reason. Ensure that the 13 | # issue is really that Django is missing to avoid masking other 14 | # exceptions on Python 2. 15 | try: 16 | import django 17 | except ImportError: 18 | raise ImportError( 19 | "Couldn't import Django. Are you sure it's installed and " 20 | "available on your PYTHONPATH environment variable? Did you " 21 | "forget to activate a virtual environment?" 22 | ) 23 | raise 24 | execute_from_command_line(sys.argv) 25 | -------------------------------------------------------------------------------- /notifications_extension/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DjangoChinaOrg/Django-China-API/79a5d85fe88ba7784d08d370b8e7519f7274f208/notifications_extension/__init__.py -------------------------------------------------------------------------------- /notifications_extension/admin.py: -------------------------------------------------------------------------------- 1 | from django.contrib import admin 2 | 3 | # from notifications.models import Notification 4 | # 5 | # admin.site.register(Notification) 6 | -------------------------------------------------------------------------------- /notifications_extension/apps.py: -------------------------------------------------------------------------------- 1 | from django.apps import AppConfig 2 | 3 | 4 | class NotificationExtensionConfig(AppConfig): 5 | name = 'notifications_extension' 6 | -------------------------------------------------------------------------------- /notifications_extension/filters.py: -------------------------------------------------------------------------------- 1 | import django_filters 2 | 3 | from notifications.models import Notification 4 | 5 | 6 | class 
NotificationFilter(django_filters.rest_framework.FilterSet): 7 | unread = django_filters.filters.CharFilter(method='unread_filter') 8 | 9 | def unread_filter(self, queryset, name, value): 10 | if value == 'true': 11 | return queryset.filter(unread='True') 12 | elif value == 'false': 13 | return queryset.filter(unread='False') 14 | elif value == 'all': 15 | return queryset 16 | else: 17 | return queryset.none() 18 | 19 | class Meta: 20 | model = Notification 21 | fields = ['unread', 'verb'] 22 | -------------------------------------------------------------------------------- /notifications_extension/migrations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DjangoChinaOrg/Django-China-API/79a5d85fe88ba7784d08d370b8e7519f7274f208/notifications_extension/migrations/__init__.py -------------------------------------------------------------------------------- /notifications_extension/models.py: -------------------------------------------------------------------------------- 1 | from django.db import models 2 | 3 | # Create your models here. 4 | -------------------------------------------------------------------------------- /notifications_extension/serializers.py: -------------------------------------------------------------------------------- 1 | from rest_framework import serializers 2 | from notifications.models import Notification 3 | 4 | from users.serializers import UserSimpleDetailsSerializer 5 | 6 | 7 | class NotificationSerializer(serializers.ModelSerializer): 8 | """ 9 | 目前通知共有 3 种: 10 | 1. 帖子被评论,帖子作者收到通知,通知模型各字段含义为: 11 | recipient:帖子作者 12 | actor:回复者 13 | target:帖子 14 | action_object:新评论 15 | verb: 'reply' 16 | 17 | 2. 帖子的回复被其他人回复,即回复别人的回复,被回复者收到通知: 18 | recipient:被回复者 19 | actor:回复者 20 | target:帖子 21 | action_object:新回复 22 | verb: 'respond' 23 | 24 | 3. 
回复被点赞: 25 | recipient:被赞者 26 | actor:回复者 27 | target:被赞的回复 28 | action_object:被赞的回复所属的帖子 29 | verb: 'like' 30 | """ 31 | actor = UserSimpleDetailsSerializer() 32 | post = serializers.SerializerMethodField() 33 | reply = serializers.SerializerMethodField() 34 | 35 | def get_reply(self, obj): 36 | if obj.verb == 'like': 37 | reply = obj.target 38 | return { 39 | 'comment': reply.comment 40 | } 41 | elif obj.verb == 'reply': 42 | reply = obj.action_object 43 | return { 44 | 'comment': reply.comment 45 | } 46 | elif obj.verb == 'respond': 47 | reply = obj.action_object 48 | return { 49 | 'comment': reply.comment 50 | } 51 | 52 | def get_post(self, obj): 53 | if obj.verb == 'like': 54 | post = obj.action_object 55 | return { 56 | 'post_id': post.id, 57 | 'post_title': post.title 58 | } 59 | else: 60 | post = obj.target 61 | return { 62 | 'post_id': post.id, 63 | 'post_title': post.title 64 | } 65 | 66 | class Meta: 67 | model = Notification 68 | fields = ('id', 'unread', 'actor', 'verb', 'timestamp', 'deleted', 'recipient', 'post', 'reply') 69 | -------------------------------------------------------------------------------- /notifications_extension/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DjangoChinaOrg/Django-China-API/79a5d85fe88ba7784d08d370b8e7519f7274f208/notifications_extension/tests/__init__.py -------------------------------------------------------------------------------- /notifications_extension/tests/test_views.py: -------------------------------------------------------------------------------- 1 | from actstream.models import Follow 2 | from django.contrib.contenttypes.models import ContentType 3 | from django.contrib.sites.models import Site 4 | from django.urls import reverse 5 | from rest_framework import status 6 | from rest_framework import test 7 | from notifications.models import Notification 8 | from actstream.models import Follow 9 | 10 | from posts.models 
import Post 11 | from users.models import User 12 | 13 | from replies.models import Reply 14 | from replies.moderation import ReplyModerator 15 | 16 | reply_moderator = ReplyModerator(ReplyModerator) 17 | 18 | 19 | class NotificationViewSetsTestCase(test.APITestCase): 20 | def setUp(self): 21 | self.user = User.objects.create_user( 22 | username='test', 23 | email='test@test.com', 24 | password='test', 25 | nickname='test' 26 | ) 27 | self.another_user = User.objects.create_user( 28 | username='another', 29 | email='another@test.com', 30 | password='another', 31 | nickname='another' 32 | ) 33 | self.post = Post.objects.create( 34 | title='test title', 35 | author=self.user 36 | ) 37 | self.post_ct = ContentType.objects.get_for_model(self.post) 38 | self.post_id = self.post.id 39 | 40 | self.another_post = Post.objects.create( 41 | title='another title', 42 | author=self.another_user 43 | ) 44 | self.another_post_ct = ContentType.objects.get_for_model(self.another_post) 45 | self.another_post_id = self.another_post.id 46 | 47 | self.site = Site.objects.create(name='test', domain='test.com') 48 | 49 | # 其他用户评论测试用户的文章 50 | self.reply = Reply.objects.create( 51 | content_type=self.post_ct, 52 | object_pk=self.post_id, 53 | site=self.site, 54 | user=self.another_user, 55 | comment='reply', 56 | ) 57 | 58 | # 测试用户评论其他用户的文章 59 | self.another_reply = Reply.objects.create( 60 | content_type=self.another_post_ct, 61 | object_pk=self.another_post_id, 62 | site=self.site, 63 | user=self.user, 64 | comment='reply', 65 | ) 66 | 67 | def test_anonymous_user_can_not_get_notifications_method(self): 68 | # 未登录用户无法获取通知列表 69 | 70 | url = reverse('notifications-list') 71 | response = self.client.get(url, format='json') 72 | self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) 73 | 74 | def test_authenticated_user_can_get_all_notifications_method(self): 75 | # 已登录用户可以获取到自己的通知列表 76 | 77 | url = reverse('notifications-list') 78 | data = { 79 | "unread": "all" 80 | } 81 | 
self.client.login(username='test', password='test') 82 | response = self.client.get(url, data, format='json') 83 | self.assertEqual(response.status_code, status.HTTP_200_OK) 84 | 85 | def test_user_get_notifications_not_others_method(self): 86 | # 已登录用户可以获取到自己的通知列表 87 | # 确认用户访问API获取的就是自己的通知列表,而不是别人的 88 | 89 | reply_moderator.notify(reply=self.reply, content_object=self.post, request=None) 90 | url = reverse('notifications-list') 91 | self.client.login(username='test', password='test') 92 | response = self.client.get(url, format='json') 93 | self.assertEqual(response.status_code, status.HTTP_200_OK) 94 | self.assertEqual(response.data['data'][0]['recipient'], self.user.id) 95 | 96 | def test_authenticated_user_can_get_single_notification_method(self): 97 | # 用户可以获取自己单条通知 98 | 99 | reply_moderator.notify(reply=self.reply, content_object=self.post, request=None) 100 | # reply = Reply.objects.first() 101 | notification = Notification.objects.first() 102 | url = reverse('notifications-detail', kwargs={'pk': notification.id}) 103 | self.client.login(username='test', password='test') 104 | response = self.client.get(url, format='json') 105 | self.assertEqual(response.status_code, status.HTTP_200_OK) 106 | 107 | def test_authenticated_user_can_not_get_other_users_single_notification_method(self): 108 | # 用户无法通过API获取属于其它用户的单条通知 109 | 110 | reply_moderator.notify( 111 | reply=self.another_reply, 112 | content_object=self.another_post, 113 | request=None 114 | ) 115 | notification = Notification.objects.first() 116 | url = reverse('notifications-detail', kwargs={'pk': notification.id}) 117 | self.client.login(username='test', password='test') 118 | response = self.client.get(url, format='json') 119 | self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) 120 | 121 | def test_authenticated_user_can_modify_single_notification_to_read_method(self): 122 | # 用户可以将自己的单条通知标为已读 123 | 124 | reply_moderator.notify(reply=self.reply, content_object=self.post, request=None) 
125 | notification = Notification.objects.first() 126 | self.assertEqual(notification.unread, True) 127 | url = reverse('notifications-detail', kwargs={'pk': notification.id}) 128 | self.client.login(username='test', password='test') 129 | response = self.client.put(url) 130 | self.assertEqual(response.status_code, status.HTTP_200_OK) 131 | notification = Notification.objects.first() 132 | self.assertEqual(notification.unread, False) 133 | 134 | def test_authenticated_user_can_not_modify_other_users_single_notification_to_read_method(self): 135 | # 用户无法通过 API 将其它用户的通知标为已读 136 | 137 | reply_moderator.notify( 138 | reply=self.another_reply, 139 | content_object=self.another_post, 140 | request=None 141 | ) 142 | notification = Notification.objects.first() 143 | self.assertEqual(notification.unread, True) 144 | url = reverse('notifications-detail', kwargs={'pk': notification.id}) 145 | self.client.login(username='test', password='test') 146 | response = self.client.put(url) 147 | self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) 148 | notification = Notification.objects.first() 149 | self.assertEqual(notification.unread, True) 150 | 151 | def test_authenticated_user_can_delete_single_notification_method(self): 152 | # 用户可以将自己的单条通知删除 153 | 154 | reply_moderator.notify(reply=self.reply, content_object=self.post, request=None) 155 | notification = Notification.objects.first() 156 | self.assertEqual(notification.deleted, False) 157 | url = reverse('notifications-detail', kwargs={'pk': notification.id}) 158 | self.client.login(username='test', password='test') 159 | response = self.client.delete(url) 160 | self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) 161 | notification = Notification.objects.first() 162 | self.assertEqual(notification.deleted, True) 163 | 164 | def test_authenticated_user_can_delete_other_user_single_notification_method(self): 165 | # 用户无法通过 API 将其它用户的单条通知删除 166 | 167 | reply_moderator.notify( 168 | 
reply=self.another_reply, 169 | content_object=self.another_post, 170 | request=None 171 | ) 172 | notification = Notification.objects.first() 173 | self.assertEqual(notification.deleted, False) 174 | url = reverse('notifications-detail', kwargs={'pk': notification.id}) 175 | self.client.login(username='test', password='test') 176 | response = self.client.delete(url) 177 | self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) 178 | notification = Notification.objects.first() 179 | self.assertEqual(notification.deleted, False) 180 | 181 | def test_authenticated_user_can_make_all_notification_to_read_method(self): 182 | # 用户可以将自己的全部通知标为已读 183 | 184 | reply_moderator.notify(reply=self.reply, content_object=self.post, request=None) 185 | notification = Notification.objects.first() 186 | self.assertEqual(notification.unread, True) # 验证未读 187 | url = reverse('notifications-mark-all-as-read') 188 | self.client.login(username='test', password='test') 189 | response = self.client.post(url) 190 | self.assertEqual(response.status_code, status.HTTP_200_OK) 191 | notification = Notification.objects.first() 192 | self.assertEqual(notification.unread, False) # 验证已读 193 | 194 | def test_authenticated_user_can_not_make_other_users_all_notification_to_read_method(self): 195 | # 用户可以将自己的全部通知标为已读 196 | 197 | reply_moderator.notify( 198 | reply=self.another_reply, 199 | content_object=self.another_post, 200 | request=None 201 | ) 202 | notification = Notification.objects.first() 203 | self.assertEqual(notification.unread, True) # 其他用户的通知 验证未读 204 | url = reverse('notifications-mark-all-as-read') 205 | self.client.login(username='test', password='test') 206 | response = self.client.post(url) 207 | self.assertEqual(response.status_code, status.HTTP_200_OK) 208 | notification = Notification.objects.first() 209 | self.assertEqual(notification.unread, True) # 其他用户的通知 验证未读 210 | -------------------------------------------------------------------------------- 
/notifications_extension/urls.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DjangoChinaOrg/Django-China-API/79a5d85fe88ba7784d08d370b8e7519f7274f208/notifications_extension/urls.py -------------------------------------------------------------------------------- /notifications_extension/views.py: -------------------------------------------------------------------------------- 1 | from rest_framework import permissions, viewsets, mixins 2 | from notifications.models import Notification 3 | from rest_framework import filters 4 | from django_filters.rest_framework import DjangoFilterBackend 5 | from rest_framework.response import Response 6 | from rest_framework import status 7 | from rest_framework.decorators import action 8 | 9 | from .serializers import NotificationSerializer 10 | from .filters import NotificationFilter 11 | 12 | 13 | class NotificationViewSet(mixins.RetrieveModelMixin, 14 | mixins.UpdateModelMixin, 15 | mixins.DestroyModelMixin, 16 | mixins.ListModelMixin, 17 | viewsets.GenericViewSet): 18 | serializer_class = NotificationSerializer 19 | permission_classes = [permissions.IsAuthenticated, ] 20 | filter_backends = (DjangoFilterBackend, filters.OrderingFilter) 21 | ordering_fields = ('timestamp',) 22 | filter_class = NotificationFilter # 过滤器 23 | 24 | def get_queryset(self): 25 | return Notification.objects.filter(recipient=self.request.user).active() 26 | 27 | def perform_destroy(self, instance): 28 | pk = self.kwargs['pk'] 29 | instance = Notification.objects.get(id=pk) 30 | instance.deleted = True 31 | instance.save() 32 | 33 | def update(self, request, *args, **kwargs): 34 | pk = self.kwargs['pk'] 35 | instance = Notification.objects.get(id=pk) 36 | if instance.recipient != request.user: 37 | return Response(status=status.HTTP_403_FORBIDDEN) 38 | instance.unread = False 39 | instance.save() 40 | return Response(status=status.HTTP_200_OK) 41 | 42 | @action(methods=['post'], 
detail=False) 43 | def mark_all_as_read(self, request): 44 | Notification.objects.filter(recipient=request.user).mark_all_as_read(recipient=request.user) 45 | return Response(status=status.HTTP_200_OK) 46 | -------------------------------------------------------------------------------- /posts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DjangoChinaOrg/Django-China-API/79a5d85fe88ba7784d08d370b8e7519f7274f208/posts/__init__.py -------------------------------------------------------------------------------- /posts/admin.py: -------------------------------------------------------------------------------- 1 | from django.contrib import admin 2 | 3 | from .models import Post 4 | 5 | admin.site.register(Post) 6 | -------------------------------------------------------------------------------- /posts/apps.py: -------------------------------------------------------------------------------- 1 | from django.apps import AppConfig 2 | 3 | 4 | class PostsConfig(AppConfig): 5 | name = 'posts' 6 | -------------------------------------------------------------------------------- /posts/factories.py: -------------------------------------------------------------------------------- 1 | import factory 2 | 3 | from users.factories import UserFactory 4 | 5 | from .models import Post 6 | 7 | 8 | class PostFactory(factory.DjangoModelFactory): 9 | title = factory.Faker('sentence') 10 | body = factory.Faker('text') 11 | author = factory.SubFactory(UserFactory) 12 | 13 | class Meta: 14 | model = Post 15 | 16 | @factory.post_generation 17 | def tags(self, create, extracted, **kwargs): 18 | if not create: 19 | # Simple build, do nothing. 
20 | return 21 | 22 | if extracted: 23 | # A list of groups were passed in, use them 24 | for tag in extracted: 25 | self.tags.add(tag) 26 | -------------------------------------------------------------------------------- /posts/migrations/0001_initial.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by Django 1.11 on 2018-04-24 12:35 3 | from __future__ import unicode_literals 4 | 5 | from django.db import migrations, models 6 | import django.utils.timezone 7 | import model_utils.fields 8 | 9 | 10 | class Migration(migrations.Migration): 11 | 12 | initial = True 13 | 14 | dependencies = [ 15 | ] 16 | 17 | operations = [ 18 | migrations.CreateModel( 19 | name='Post', 20 | fields=[ 21 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 22 | ('created', model_utils.fields.AutoCreatedField(default=django.utils.timezone.now, editable=False, verbose_name='created')), 23 | ('modified', model_utils.fields.AutoLastModifiedField(default=django.utils.timezone.now, editable=False, verbose_name='modified')), 24 | ('title', models.CharField(max_length=255, verbose_name='标题')), 25 | ('body', models.TextField(blank=True, verbose_name='正文')), 26 | ('views', models.PositiveIntegerField(default=0, editable=False, verbose_name='浏览量')), 27 | ('pinned', models.BooleanField(default=False, verbose_name='置顶')), 28 | ('highlighted', models.BooleanField(default=False, verbose_name='加精')), 29 | ('hidden', models.BooleanField(default=False, verbose_name='隐藏')), 30 | ], 31 | options={ 32 | 'verbose_name': '帖子', 33 | 'verbose_name_plural': '帖子', 34 | }, 35 | ), 36 | ] 37 | -------------------------------------------------------------------------------- /posts/migrations/0002_auto_20180424_2035.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by Django 1.11 on 2018-04-24 12:35 3 | from 
__future__ import unicode_literals 4 | 5 | from django.conf import settings 6 | from django.db import migrations, models 7 | import django.db.models.deletion 8 | 9 | 10 | class Migration(migrations.Migration): 11 | 12 | initial = True 13 | 14 | dependencies = [ 15 | migrations.swappable_dependency(settings.AUTH_USER_MODEL), 16 | ('posts', '0001_initial'), 17 | ('tags', '0001_initial'), 18 | ] 19 | 20 | operations = [ 21 | migrations.AddField( 22 | model_name='post', 23 | name='author', 24 | field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL, verbose_name='作者'), 25 | ), 26 | migrations.AddField( 27 | model_name='post', 28 | name='tags', 29 | field=models.ManyToManyField(to='tags.Tag', verbose_name='标签'), 30 | ), 31 | ] 32 | -------------------------------------------------------------------------------- /posts/migrations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DjangoChinaOrg/Django-China-API/79a5d85fe88ba7784d08d370b8e7519f7274f208/posts/migrations/__init__.py -------------------------------------------------------------------------------- /posts/models.py: -------------------------------------------------------------------------------- 1 | from django.conf import settings 2 | from django.contrib.contenttypes.fields import GenericRelation 3 | from django.db import models 4 | from model_utils.models import TimeStampedModel 5 | 6 | from replies.models import Reply 7 | 8 | 9 | class PublicManager(models.Manager): 10 | def get_queryset(self): 11 | return super(PublicManager, self) \ 12 | .get_queryset().filter(hidden=False).order_by('-created_time') 13 | 14 | 15 | class Post(TimeStampedModel): 16 | title = models.CharField("标题", max_length=255) 17 | body = models.TextField("正文", blank=True) 18 | views = models.PositiveIntegerField("浏览量", default=0, editable=False) 19 | pinned = models.BooleanField("置顶", default=False) 20 | highlighted = 
models.BooleanField("加精", default=False) 21 | hidden = models.BooleanField("隐藏", default=False) 22 | tags = models.ManyToManyField('tags.Tag', verbose_name="标签") 23 | author = models.ForeignKey( 24 | settings.AUTH_USER_MODEL, 25 | verbose_name="作者", 26 | on_delete=models.CASCADE 27 | ) 28 | replies = GenericRelation( 29 | Reply, 30 | object_id_field='object_pk', 31 | content_type_field='content_type', 32 | verbose_name="回复" 33 | ) 34 | 35 | objects = models.Manager() 36 | # 未隐藏的帖子 37 | public = PublicManager() 38 | 39 | class Meta: 40 | verbose_name = "帖子" 41 | verbose_name_plural = "帖子" 42 | 43 | def __str__(self): 44 | return self.title 45 | 46 | def increase_views(self): 47 | self.views += 1 48 | self.save(update_fields=['views']) 49 | -------------------------------------------------------------------------------- /posts/permissions.py: -------------------------------------------------------------------------------- 1 | from rest_framework import permissions 2 | 3 | 4 | class IsAdminAuthorOrReadOnly(permissions.BasePermission): 5 | """ 6 | 允许普通用户编辑自己的帖子, 管理员可以编辑所有帖子 7 | """ 8 | def has_object_permission(self, request, view, obj): 9 | return (request.method in permissions.SAFE_METHODS or 10 | request.user.is_staff or 11 | request.user == obj.author) 12 | -------------------------------------------------------------------------------- /posts/serializers.py: -------------------------------------------------------------------------------- 1 | from django.contrib.contenttypes.models import ContentType 2 | from django.db.models import Prefetch 3 | from rest_framework import serializers 4 | 5 | from tags.serializers import TagSerializer 6 | from utils.mixins import EagerLoaderMixin 7 | from .models import Post, Reply 8 | 9 | 10 | class IndexPostListSerializer(serializers.HyperlinkedModelSerializer, EagerLoaderMixin): 11 | """ 12 | 首页帖子列表序列化器 13 | """ 14 | author = serializers.SerializerMethodField() 15 | reply_count = serializers.SerializerMethodField() 16 | tags = 
TagSerializer(many=True, read_only=True) 17 | latest_reply_time = serializers.SerializerMethodField() 18 | 19 | SELECT_RELATED_FIELDS = ['author'] 20 | PREFETCH_RELATED_FIELDS = [ 21 | 'tags', 22 | Prefetch('replies', queryset=Reply.objects.order_by('-submit_date')) 23 | ] 24 | 25 | class Meta: 26 | model = Post 27 | fields = ( 28 | 'id', 29 | 'url', 30 | 'title', 31 | 'views', 32 | 'created', 33 | 'modified', 34 | 'latest_reply_time', 35 | 'pinned', 36 | 'highlighted', 37 | 'tags', 38 | 'author', 39 | 'reply_count', 40 | ) 41 | 42 | def get_author(self, obj): 43 | author = obj.author 44 | request = self.context.get('request') 45 | url = author.mugshot.url 46 | thumbnail_url = author.mugshot_thumbnail.url 47 | return { 48 | 'id': author.id, 49 | 'mugshot': request.build_absolute_uri(url) if request else url, 50 | 'mugshot_url': request.build_absolute_uri(thumbnail_url) if request else url, 51 | 'nickname': author.nickname, 52 | } 53 | 54 | def get_reply_count(self, obj): 55 | """ 56 | 返回帖子的回复数量 57 | """ 58 | return obj.replies.count() 59 | 60 | def get_latest_reply_time(self, obj): 61 | """ 62 | 返回最后一次评论的时间, 63 | 如果没有评论,返回null 64 | """ 65 | replies = obj.replies.all() 66 | if replies: 67 | return replies[0].submit_date 68 | else: 69 | return None 70 | 71 | 72 | class PopularPostSerializer(serializers.HyperlinkedModelSerializer, EagerLoaderMixin): 73 | """ 74 | 热门帖子序列化器 75 | """ 76 | author = serializers.SerializerMethodField() 77 | 78 | SELECT_RELATED_FIELDS = ['author'] 79 | 80 | class Meta: 81 | model = Post 82 | fields = ( 83 | 'id', 84 | 'url', 85 | 'title', 86 | 'author', 87 | ) 88 | 89 | def get_author(self, obj): 90 | author = obj.author 91 | request = self.context.get('request') 92 | url = author.mugshot.url 93 | thumbnail_url = author.mugshot_thumbnail.url 94 | return { 95 | 'id': author.id, 96 | 'mugshot': request.build_absolute_uri(url) if request else url, 97 | 'mugshot_url': request.build_absolute_uri(thumbnail_url) if request else url, 98 | 
'nickname': author.nickname, 99 | } 100 | 101 | 102 | class PostDetailSerializer(IndexPostListSerializer): 103 | """ 104 | 用来显示帖子详情,已经用来创建、修改帖子的序列化器 105 | """ 106 | author = serializers.SerializerMethodField() 107 | participants_count = serializers.SerializerMethodField() 108 | content_type = serializers.SerializerMethodField() 109 | 110 | class Meta: 111 | model = Post 112 | fields = ( 113 | 'id', 114 | 'content_type', 115 | 'title', 116 | 'author', 117 | 'views', 118 | 'created', 119 | 'modified', 120 | 'body', 121 | 'tags', 122 | 'reply_count', 123 | 'participants_count', 124 | ) 125 | 126 | def get_content_type(self, obj): 127 | """ 128 | 帖子的content_type 129 | """ 130 | content_type = ContentType.objects.get_for_model(obj) 131 | return content_type.id 132 | 133 | def get_participants_count(self, obj): 134 | """ 135 | 返回评论参与者数量 136 | """ 137 | return obj.replies.values_list('user', flat=True).distinct().count() 138 | -------------------------------------------------------------------------------- /posts/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DjangoChinaOrg/Django-China-API/79a5d85fe88ba7784d08d370b8e7519f7274f208/posts/tests/__init__.py -------------------------------------------------------------------------------- /posts/tests/test_models.py: -------------------------------------------------------------------------------- 1 | from django.test import TestCase 2 | 3 | from ..models import Post 4 | from users.models import User 5 | from tags.models import Tag 6 | 7 | 8 | class PostModelTests(TestCase): 9 | """ 10 | 测试Post的objects和public manager 11 | """ 12 | def setUp(self): 13 | self.user = User.objects.create_user(username='test', 14 | email='test@test.com', 15 | password='test', 16 | nickname='test') 17 | self.tag = Tag.objects.create(name='test_tag', 18 | creator=self.user) 19 | 20 | def test_managers(self): 21 | self.post1 = Post.objects.create(title='test title 
first', 22 | body='first test body', 23 | author=self.user 24 | ) 25 | self.post1.tags.add(self.tag) 26 | self.post2 = Post.objects.create(title='test title second', 27 | body='second test body', 28 | author=self.user, 29 | hidden=True 30 | ) 31 | self.post2.tags.add(self.tag) 32 | self.assertEqual(Post.objects.all().count(), 2) 33 | self.assertEqual(Post.public.all().count(), 1) 34 | self.assertEqual(Post.public.get().hidden, False) 35 | -------------------------------------------------------------------------------- /posts/tests/test_serializers.py: -------------------------------------------------------------------------------- 1 | from django.test import TestCase 2 | from django.contrib.contenttypes.models import ContentType 3 | 4 | from ..models import Post 5 | from ..serializers import IndexPostListSerializer 6 | from users.models import User 7 | from tags.models import Tag 8 | 9 | 10 | class PostSerializerTests(TestCase): 11 | """ 12 | 测试帖子序列化器 13 | """ 14 | def setUp(self): 15 | self.user = User.objects.create_user(username='test', 16 | email='test@test.com', 17 | password='test', 18 | nickname='test') 19 | self.tag1 = Tag.objects.create(name='test_tag', 20 | creator=self.user) 21 | self.tag2 = Tag.objects.create(name='another test_tag', 22 | creator=self.user) 23 | self.post = Post.objects.create(title='test title first', 24 | body='first test body', 25 | author=self.user) 26 | self.post.tags.add(self.tag1) 27 | self.post.tags.add(self.tag2) 28 | 29 | def test_post_serializer_detail(self): 30 | serializer = IndexPostListSerializer(self.post, context={'request': None}) 31 | data = serializer.data 32 | self.assertEqual(data['id'], self.post.id) 33 | self.assertEqual(data['title'], self.post.title) 34 | self.assertEqual(data['views'], self.post.views) 35 | self.assertEqual(data['pinned'], self.post.pinned) 36 | self.assertEqual(data['highlighted'], self.post.highlighted) 37 | self.assertEqual(data['reply_count'], self.post.replies.count()) 38 | 
-------------------------------------------------------------------------------- /posts/tests/test_views.py: -------------------------------------------------------------------------------- 1 | from django.urls import reverse 2 | 3 | from rest_framework import status 4 | from rest_framework.test import APITestCase 5 | 6 | from posts.models import Post 7 | from tags.models import Tag 8 | from users.models import User 9 | 10 | 11 | class PostTestCase(APITestCase): 12 | def setUp(self): 13 | self.user = User.objects.create_user(username='test', 14 | email='test@test.com', 15 | password='test', 16 | nickname='test') 17 | self.another_user = User.objects.create_user(username='test2', 18 | email='test2@test.com', 19 | password='test2', 20 | nickname='test2') 21 | self.admin = User.objects.create_superuser(username='admin', 22 | email='admin@admin.com', 23 | password='admin123', 24 | nickname='admin') 25 | self.tag1 = Tag.objects.create(name='test tag1', creator=self.user) 26 | self.tag2 = Tag.objects.create(name='test tag2', creator=self.user) 27 | self.tag3 = Tag.objects.create(name='test tag3', creator=self.user) 28 | self.tag4 = Tag.objects.create(name='test tag4', creator=self.user) 29 | 30 | def test_authenticated_user_can_create_post(self): 31 | """ 32 | 测试登录用户可以发帖子, 33 | """ 34 | url = reverse('post-list') 35 | data = { 36 | "title": "test title", 37 | "body": "test test test", 38 | "tags": ['test tag1', 'test tag2', 'test tag3'] 39 | } 40 | data2 = { 41 | "title": "test title", 42 | "body": "test test test", 43 | } 44 | self.client.login(username='test', password='test') 45 | response = self.client.post(url, data, format='json') 46 | self.assertEqual(response.status_code, status.HTTP_201_CREATED) 47 | self.assertEqual(Post.objects.count(), 1) 48 | self.assertEqual(Post.objects.get().author, self.user) 49 | self.assertEqual(Post.objects.get().body, 'test test test') 50 | self.assertEqual(Post.objects.get().title, 'test title') 51 | 
self.assertEqual(Post.objects.get().tags.count(), 3) 52 | self.assertEqual(Post.objects.get().pinned, False) 53 | self.assertEqual(Post.objects.get().highlighted, False) 54 | self.assertEqual(Post.objects.get().hidden, False) 55 | 56 | # 当提交的数据不完整时 57 | response = self.client.post(url, data2, formant='json') 58 | self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) 59 | 60 | def test_anonymous_user_cannot_create_post(self): 61 | """ 62 | 测试未登录用户 不可以发帖子 63 | """ 64 | url = reverse('post-list') 65 | data = { 66 | "title": "test title", 67 | "body": "test test test", 68 | "tags": ['test tag1', 'test tag2', 'test tag3'] 69 | } 70 | response = self.client.post(url, data, format='json') 71 | self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) 72 | 73 | def test_post_tags_quantity(self): 74 | """ 75 | 测试帖子标签的数量 大于等于1, 小于等于3 76 | 以及标签不存在的情况 77 | 测试方法 包括POST, PUT, PATCH 78 | """ 79 | self.post = Post.objects.create(title='this is a test', 80 | body='this is a test', 81 | author=self.user 82 | ) 83 | self.post.tags.add(self.tag1) 84 | url1 = reverse('post-list') 85 | url2 = reverse('post-detail', kwargs={'pk': self.post.pk}) 86 | data1 = { 87 | "title": "test title", 88 | "body": "test test test", 89 | "tags": [] 90 | } 91 | data2 = { 92 | "title": "test title", 93 | "body": "test test test", 94 | "tags": ['test tag1', 'test tag2', 'test tag3', 'test tag4'] 95 | } 96 | data3 = { 97 | "tags": [] 98 | } 99 | data4 = { 100 | "title": "test title", 101 | "body": "test test test", 102 | "tags": ["tag does not exist"] 103 | } 104 | self.client.login(username='test', password='test') 105 | response = self.client.post(url1, data1, format='json') 106 | self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) 107 | 108 | response = self.client.post(url1, data2, format='json') 109 | self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) 110 | 111 | response = self.client.post(url1, data4, format='json') 112 | 
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) 113 | 114 | response = self.client.put(url2, data1, format='json') 115 | self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) 116 | 117 | response = self.client.put(url2, data2, format='json') 118 | self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) 119 | 120 | response = self.client.patch(url2, data3, format='json') 121 | self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) 122 | 123 | response = self.client.patch(url2, data2, format='json') 124 | self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) 125 | 126 | response = self.client.patch(url2, data4, format='json') 127 | self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) 128 | 129 | def test_only_author_admin_can_edit_post(self): 130 | """ 131 | 测试只有管理员和作者可以编辑帖子, 132 | 同时作者无法修改hidden, highlighted, pinned字段 133 | """ 134 | self.post = Post.objects.create(title='this is a test', 135 | body='this is a test', 136 | author=self.user 137 | ) 138 | self.post.tags.add(self.tag1) 139 | url = reverse('post-detail', kwargs={'pk': self.post.pk}) 140 | data1 = { 141 | "title": "hello", 142 | "body": "hello world", 143 | "tags": ['test tag1', 'test tag2', 'test tag3'] 144 | } 145 | data2 = { 146 | "title": "hello", 147 | "body": "hello world", 148 | "tags": ['test tag1', 'test tag2', 'test tag3'], 149 | "pinned": True, 150 | "highlighted": True 151 | } 152 | data3 = { 153 | "pinned": False, 154 | "highlighted": False 155 | } 156 | data4 = { 157 | "hidden": True, 158 | } 159 | # 管理员 160 | self.client.login(username='admin', password='admin123') 161 | response = self.client.put(url, data1, format='json') 162 | self.assertEqual(response.status_code, status.HTTP_200_OK) 163 | # 测试管理 置顶, 加精帖子 164 | response = self.client.put(url, data2, format='json') 165 | self.assertEqual(response.status_code, status.HTTP_200_OK) 166 | self.assertEqual(Post.objects.get().pinned, True) 167 | 
self.assertEqual(Post.objects.get().highlighted, True) 168 | response = self.client.patch(url, data3, format='json') 169 | self.assertEqual(response.status_code, status.HTTP_200_OK) 170 | self.assertEqual(Post.objects.get().pinned, False) 171 | self.client.logout() 172 | 173 | # 作者 174 | self.client.login(username='test', password='test') 175 | response = self.client.put(url, data1, format='json') 176 | self.assertEqual(response.status_code, status.HTTP_200_OK) 177 | # 作者加精 置顶 隐藏 178 | response = self.client.put(url, data2, format='json') 179 | self.assertEqual(response.status_code, status.HTTP_200_OK) 180 | self.assertEqual(Post.objects.get().pinned, False) 181 | self.assertEqual(Post.objects.get().highlighted, False) 182 | response = self.client.patch(url, data4, format='json') 183 | self.assertEqual(response.status_code, status.HTTP_200_OK) 184 | self.assertEqual(Post.objects.get().hidden, False) 185 | self.client.logout() 186 | 187 | # 其他用户 188 | self.client.login(username='test2', password='test2') 189 | response = self.client.put(url, data1, format='json') 190 | self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) 191 | response = self.client.put(url, data2, format='json') 192 | self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) 193 | response = self.client.patch(url, data3, format='json') 194 | self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) 195 | self.client.logout() 196 | 197 | # 管理员隐藏帖子 198 | self.client.login(username='admin', password='admin123') 199 | response = self.client.patch(url, data4, format='json') 200 | self.assertEqual(response.status_code, status.HTTP_200_OK) 201 | response = self.client.get(reverse('post-list'), format='json') 202 | self.assertEqual(response.data['count'], 0) 203 | 204 | def test_index_post_list(self): 205 | """ 206 | 测试首页列表数量,以及分页情况 207 | """ 208 | for i in range(5): 209 | self.post = Post.objects.create(title='this is a test', 210 | body='this is a test', 211 | 
author=self.user 212 | ) 213 | self.post.tags.add(self.tag1) 214 | url = reverse('post-list') 215 | self.client.login(username='test', password='test') 216 | response = self.client.get(url, format='json') 217 | self.assertEqual(response.status_code, status.HTTP_200_OK) 218 | self.assertEqual(response.data['count'], 5) 219 | 220 | url = reverse('post-list') + '?page=2' 221 | response = self.client.get(url, format='json') 222 | self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) 223 | 224 | for i in range(16): 225 | self.post = Post.objects.create(title='this is a test', 226 | body='this is a test', 227 | author=self.user 228 | ) 229 | self.post.tags.add(self.tag1) 230 | 231 | response = self.client.get(url, format='json') 232 | self.assertEqual(response.status_code, status.HTTP_200_OK) 233 | self.assertEqual(response.data['count'], 21) 234 | 235 | def test_popular_post_list(self): 236 | """ 237 | 测试热门帖子 238 | """ 239 | for i in range(5): 240 | self.post = Post.objects.create(title='this is a test', 241 | body='this is a test', 242 | author=self.user 243 | ) 244 | self.post.tags.add(self.tag1) 245 | url = reverse('post-popular') 246 | response = self.client.get(url, format='json') 247 | self.assertEqual(response.status_code, status.HTTP_200_OK) 248 | self.assertEqual(response.data['count'], 0) 249 | self.assertEqual(response.data['page_size'], 10) 250 | 251 | url = reverse('reply-list') 252 | self.client.login(username='admin', password='admin123') 253 | data = { 254 | "content_type": 19, 255 | "object_pk": self.post.id, 256 | "site": 1, 257 | "comment": "回复测试", 258 | "parent": None 259 | } 260 | response = self.client.post(url, data, format='json') 261 | self.assertEqual(response.status_code, status.HTTP_201_CREATED) 262 | 263 | url = reverse('post-popular') 264 | response = self.client.get(url, format='json') 265 | request = response.wsgi_request 266 | author = Post.objects.get(id=self.post.id).author 267 | self.assertEqual(response.data['count'], 1) 
268 | self.assertEqual(response.data['data'][0]['author']['id'], author.id) 269 | self.assertEqual(response.data['data'][0]['author']['mugshot'], 270 | request.build_absolute_uri(author.mugshot.url)) 271 | self.assertEqual(response.data['data'][0] 272 | ['author']['nickname'], author.nickname) 273 | 274 | def test_post_detail(self): 275 | """ 276 | 测试帖子详情 277 | """ 278 | self.post = Post.objects.create(title='this is a test', 279 | body='this is a test', 280 | author=self.user 281 | ) 282 | self.post.tags.add(self.tag1) 283 | url = reverse('post-detail', kwargs={'pk': self.post.pk}) 284 | response = self.client.get(url, format='json') 285 | self.assertEqual(response.status_code, status.HTTP_200_OK) 286 | self.assertEqual(response.data['id'], self.post.id) 287 | self.assertEqual(response.data['title'], 'this is a test') 288 | self.assertEqual(response.data['body'], 'this is a test') 289 | self.assertEqual(response.data['author']['nickname'], 'test') 290 | 291 | url = reverse('post-detail', kwargs={'pk': self.post.pk + 1}) 292 | response = self.client.get(url, format='json') 293 | self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) 294 | -------------------------------------------------------------------------------- /posts/views.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | from django.db.models import Count, Max 4 | from django.db.models.functions import Coalesce 5 | from django.utils.timezone import now 6 | from django_filters import rest_framework as filters 7 | from rest_framework import permissions, serializers, status, viewsets 8 | from rest_framework.decorators import action 9 | from rest_framework.response import Response 10 | 11 | from replies.serializers import TreeRepliesSerializer 12 | from tags.models import Tag 13 | from .models import Post 14 | from .permissions import IsAdminAuthorOrReadOnly 15 | from .serializers import ( 16 | IndexPostListSerializer, 17 | 
PopularPostSerializer, 18 | PostDetailSerializer, 19 | ) 20 | 21 | 22 | class PostViewSet(viewsets.ModelViewSet): 23 | queryset = IndexPostListSerializer.setup_eager_loading( 24 | Post.public.annotate( 25 | latest_post_time=Coalesce(Max('replies__submit_date'), 'created') 26 | ).order_by('-pinned', '-latest_post_time'), 27 | select_related=IndexPostListSerializer.SELECT_RELATED_FIELDS, 28 | prefetch_related=IndexPostListSerializer.PREFETCH_RELATED_FIELDS 29 | ) 30 | serializer_class = IndexPostListSerializer 31 | permission_classes = (permissions.IsAuthenticatedOrReadOnly, 32 | IsAdminAuthorOrReadOnly) 33 | # 允许get post put方法 34 | http_method_names = ['get', 'post', 'put', 'patch'] 35 | filter_backends = (filters.DjangoFilterBackend,) 36 | # 在post-list页面可以按标签字段过滤出特定标签下的帖子 37 | filter_fields = ('tags',) 38 | 39 | def retrieve(self, request, *args, **kwargs): 40 | """ 41 | 重写帖子详情页,这里使用PostDetailSerializer, 42 | 而不是默认的IndexPostListSerializer 43 | """ 44 | instance = self.get_object() 45 | instance.increase_views() 46 | instance.refresh_from_db() 47 | serializer = PostDetailSerializer(instance, context={'request': request}) 48 | return Response(serializer.data) 49 | 50 | def create(self, request, *args, **kwargs): 51 | """ 52 | 重写创建帖子方法,使用PostDetailSerializer 53 | """ 54 | serializer = PostDetailSerializer(data=request.data, context={'request': request}) 55 | serializer.is_valid(raise_exception=True) 56 | self.perform_create(serializer) 57 | headers = self.get_success_headers(serializer.data) 58 | return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) 59 | 60 | def perform_create(self, serializer): 61 | """ 62 | 保存tags和author,同时验证tag的数量 63 | tags和author在PostSerializer里是read_only 64 | """ 65 | tags_data = self.request.data.get('tags') 66 | tags = [] 67 | if not tags_data: 68 | raise serializers.ValidationError(detail={'标签': '请选择至少一个标签'}) 69 | elif len(tags_data) > 3: 70 | raise serializers.ValidationError(detail={'标签': '最多可以选择三个标签'}) 71 | for 
name in tags_data: 72 | try: 73 | tag = Tag.objects.get(name=name) 74 | tags.append(tag) 75 | except Exception: 76 | raise serializers.ValidationError(detail={'标签': '标签不存在'}) 77 | serializer.save(author=self.request.user, tags=tags) 78 | 79 | def update(self, request, *args, **kwargs): 80 | """ 81 | 更新帖子的方法,包括put和patch, 82 | """ 83 | partial = kwargs.pop('partial', False) 84 | tags_data = request.data.get('tags') 85 | if partial and tags_data is None: 86 | pass 87 | elif not tags_data: 88 | raise serializers.ValidationError(detail={'标签': '请选择至少一个标签'}) 89 | elif len(tags_data) > 3: 90 | raise serializers.ValidationError(detail={'标签': '最多可以选择三个标签'}) 91 | instance = self.get_object() 92 | serializer = PostDetailSerializer( 93 | instance, 94 | data=request.data, 95 | partial=partial, 96 | context={'request': request}) 97 | serializer.is_valid(raise_exception=True) 98 | self.perform_update(serializer) 99 | 100 | if getattr(instance, '_prefetched_objects_cache', None): 101 | # If 'prefetch_related' has been applied to a queryset, we need to 102 | # forcibly invalidate the prefetch cache on the instance. 
103 | instance._prefetched_objects_cache = {} 104 | 105 | return Response(serializer.data) 106 | 107 | def perform_update(self, serializer): 108 | """ 109 | 执行更新 110 | """ 111 | data = {} 112 | tags = [] 113 | # 标签,隐藏,置顶,加精这些字段都是read_only, 114 | # 因此这些字段在修改时需要手动来保存 115 | tags_data = self.request.data.get('tags') 116 | # 如果用户不是管理,则无需提取这些字段 117 | if self.request.user.is_staff: 118 | hidden = self.request.data.get('hidden') 119 | pinned = self.request.data.get('pinned') 120 | highlighted = self.request.data.get('highlighted') 121 | if hidden is not None: 122 | data['hidden'] = hidden 123 | if pinned is not None: 124 | data['pinned'] = pinned 125 | if highlighted is not None: 126 | data['highlighted'] = highlighted 127 | if tags_data: 128 | for name in tags_data: 129 | try: 130 | tag = Tag.objects.get(name=name) 131 | tags.append(tag) 132 | except Exception: 133 | raise serializers.ValidationError(detail={'标签': '标签不存在'}) 134 | data['tags'] = tags 135 | serializer.save(**data) 136 | 137 | @action(detail=False, serializer_class=PopularPostSerializer) 138 | def popular(self, request): 139 | """ 140 | 返回48小时内评论次数最多的帖子 141 | """ 142 | popular_posts = PopularPostSerializer.setup_eager_loading( 143 | Post.public.annotate( 144 | num_replies=Count('replies'), 145 | latest_reply_time=Max('replies__submit_date') 146 | ).filter( 147 | num_replies__gt=0, 148 | latest_reply_time__gt=(now() - datetime.timedelta(days=2)), 149 | latest_reply_time__lt=now() 150 | ).order_by('-num_replies', '-latest_reply_time')[:10], 151 | select_related=PopularPostSerializer.SELECT_RELATED_FIELDS 152 | ) 153 | 154 | # return paginated queryset as response data if paginator exists 155 | self.paginator.page_size = 10 156 | page = self.paginate_queryset(popular_posts) 157 | if page is not None: 158 | serializer = self.get_serializer(page, many=True) 159 | return self.get_paginated_response(serializer.data) 160 | 161 | serializer = self.get_serializer(popular_posts, many=True) 162 | return 
Response(serializer.data) 163 | 164 | @action(methods=['get'], detail=True, serializer_class=TreeRepliesSerializer) 165 | def replies(self, request, pk=None): 166 | post = self.get_object() 167 | replies = post.replies.filter(is_public=True, is_removed=False, parent__isnull=True) 168 | page = self.paginate_queryset(replies) 169 | if page is not None: 170 | serializer = self.get_serializer(page, many=True, context={'request': request}) 171 | return self.get_paginated_response(serializer.data) 172 | 173 | serializer = self.get_serializer(replies, many=True, context={'request': request}) 174 | return Response(serializer.data) 175 | -------------------------------------------------------------------------------- /replies/__init__.py: -------------------------------------------------------------------------------- 1 | default_app_config = 'replies.apps.RepliesConfig' 2 | 3 | 4 | def get_model(): 5 | from .models import Reply 6 | return Reply 7 | -------------------------------------------------------------------------------- /replies/admin.py: -------------------------------------------------------------------------------- 1 | from django.contrib import admin 2 | 3 | from .models import Reply 4 | 5 | admin.site.register(Reply) 6 | -------------------------------------------------------------------------------- /replies/apps.py: -------------------------------------------------------------------------------- 1 | from django.apps import AppConfig 2 | 3 | 4 | class RepliesConfig(AppConfig): 5 | name = 'replies' 6 | 7 | def ready(self): 8 | from .moderation import moderator 9 | from .moderation import ReplyModerator 10 | from posts.models import Post 11 | from actstream import registry 12 | registry.register(self.get_model('Reply')) 13 | registry.register(Post) 14 | moderator.register(Post, ReplyModerator) 15 | -------------------------------------------------------------------------------- /replies/factories.py: 
-------------------------------------------------------------------------------- 1 | import factory 2 | from django.contrib.contenttypes.models import ContentType 3 | from django.contrib.sites.models import Site 4 | 5 | from posts.factories import PostFactory 6 | from replies.models import Reply 7 | from users.factories import UserFactory 8 | 9 | 10 | class SiteFactory(factory.DjangoModelFactory): 11 | name = factory.Sequence(lambda n: 'example_%s' % n) 12 | domain = factory.LazyAttribute(lambda o: '%s.com' % o.name) 13 | 14 | class Meta: 15 | model = Site 16 | 17 | 18 | class BaseReplyFactory(factory.DjangoModelFactory): 19 | content_type = factory.LazyAttribute( 20 | lambda o: ContentType.objects.get_for_model(o.content_object)) 21 | object_pk = factory.SelfAttribute('content_object.id') 22 | user = factory.SubFactory(UserFactory) 23 | site = factory.SubFactory(SiteFactory) 24 | comment = 'test comment' 25 | 26 | class Meta: 27 | model = Reply 28 | 29 | 30 | class PostReplyFactory(BaseReplyFactory): 31 | content_object = factory.SubFactory(PostFactory) 32 | -------------------------------------------------------------------------------- /replies/migrations/0001_initial.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by Django 1.11 on 2018-04-24 12:35 3 | from __future__ import unicode_literals 4 | 5 | from django.db import migrations, models 6 | import django.db.models.deletion 7 | import mptt.fields 8 | 9 | 10 | class Migration(migrations.Migration): 11 | 12 | initial = True 13 | 14 | dependencies = [ 15 | ('sites', '0002_alter_domain_unique'), 16 | ('contenttypes', '0002_remove_content_type_name'), 17 | ] 18 | 19 | operations = [ 20 | migrations.CreateModel( 21 | name='Reply', 22 | fields=[ 23 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 24 | ('object_pk', models.TextField(verbose_name='object ID')), 25 | ('user_name', 
models.CharField(blank=True, max_length=50, verbose_name="user's name")), 26 | ('user_email', models.EmailField(blank=True, max_length=254, verbose_name="user's email address")), 27 | ('user_url', models.URLField(blank=True, verbose_name="user's URL")), 28 | ('comment', models.TextField(max_length=3000, verbose_name='comment')), 29 | ('submit_date', models.DateTimeField(db_index=True, default=None, verbose_name='date/time submitted')), 30 | ('ip_address', models.GenericIPAddressField(blank=True, null=True, unpack_ipv4=True, verbose_name='IP address')), 31 | ('is_public', models.BooleanField(default=True, help_text='Uncheck this box to make the comment effectively disappear from the site.', verbose_name='is public')), 32 | ('is_removed', models.BooleanField(default=False, help_text='Check this box if the comment is inappropriate. A "This comment has been removed" message will be displayed instead.', verbose_name='is removed')), 33 | ('lft', models.PositiveIntegerField(db_index=True, editable=False)), 34 | ('rght', models.PositiveIntegerField(db_index=True, editable=False)), 35 | ('tree_id', models.PositiveIntegerField(db_index=True, editable=False)), 36 | ('level', models.PositiveIntegerField(db_index=True, editable=False)), 37 | ('content_type', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='content_type_set_for_reply', to='contenttypes.ContentType', verbose_name='content type')), 38 | ('parent', mptt.fields.TreeForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='children', to='replies.Reply', verbose_name='上级回复')), 39 | ('site', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='sites.Site')), 40 | ], 41 | options={ 42 | 'verbose_name': '回复', 43 | 'verbose_name_plural': '回复', 44 | 'ordering': ('submit_date',), 45 | 'permissions': [('can_moderate', 'Can moderate comments')], 46 | 'abstract': False, 47 | }, 48 | ), 49 | ] 50 | 
-------------------------------------------------------------------------------- /replies/migrations/0002_reply_user.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by Django 1.11 on 2018-04-24 12:35 3 | from __future__ import unicode_literals 4 | 5 | from django.conf import settings 6 | from django.db import migrations, models 7 | import django.db.models.deletion 8 | 9 | 10 | class Migration(migrations.Migration): 11 | 12 | initial = True 13 | 14 | dependencies = [ 15 | migrations.swappable_dependency(settings.AUTH_USER_MODEL), 16 | ('replies', '0001_initial'), 17 | ] 18 | 19 | operations = [ 20 | migrations.AddField( 21 | model_name='reply', 22 | name='user', 23 | field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='reply_comments', to=settings.AUTH_USER_MODEL, verbose_name='user'), 24 | ), 25 | ] 26 | -------------------------------------------------------------------------------- /replies/migrations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DjangoChinaOrg/Django-China-API/79a5d85fe88ba7784d08d370b8e7519f7274f208/replies/migrations/__init__.py -------------------------------------------------------------------------------- /replies/models.py: -------------------------------------------------------------------------------- 1 | from actstream.models import Follow 2 | from django.contrib.contenttypes.models import ContentType 3 | from django.db import models 4 | from django_comments.abstracts import CommentAbstractModel 5 | from mptt.models import MPTTModel, TreeForeignKey 6 | 7 | 8 | class Reply(MPTTModel, CommentAbstractModel): 9 | parent = TreeForeignKey( 10 | 'self', 11 | verbose_name="上级回复", 12 | related_name='children', 13 | blank=True, 14 | null=True, 15 | on_delete=models.SET_NULL, 16 | ) 17 | 18 | class Meta(CommentAbstractModel.Meta): 19 
| verbose_name = "回复" 20 | verbose_name_plural = "回复" 21 | 22 | def descendants(self): 23 | """ 24 | 获取回复的全部子孙回复,按回复时间正序排序 25 | """ 26 | return self.get_descendants().order_by('submit_date') 27 | 28 | def descendants_count(self): 29 | return self.get_descendant_count() 30 | 31 | @property 32 | def ctype(self): 33 | return ContentType.objects.get_for_model(self) 34 | 35 | @property 36 | def ctype_id(self): 37 | return self.ctype.id 38 | 39 | @property 40 | def like_count(self): 41 | return Follow.objects.for_object(self, flag='like').count() 42 | -------------------------------------------------------------------------------- /replies/moderation.py: -------------------------------------------------------------------------------- 1 | from django_comments.moderation import CommentModerator 2 | from django_comments.moderation import Moderator as DjangoCommentModerator 3 | from notifications.signals import notify 4 | 5 | 6 | class Moderator(DjangoCommentModerator): 7 | def post_save_moderation(self, sender, comment, request, **kwargs): 8 | model = comment.content_type.model_class() 9 | if model not in self._registry: 10 | return 11 | self._registry[model].notify(comment, comment.content_object, request) 12 | 13 | 14 | class ReplyModerator(CommentModerator): 15 | def notify(self, reply, content_object, request): 16 | post_author = content_object.author 17 | 18 | if reply.parent: # 回复的回复 19 | parent_user = reply.parent.user 20 | # 通知被回复的人,自己回复自己无需通知 21 | if parent_user != reply.user: 22 | reply_data = { 23 | 'recipient': parent_user, 24 | 'verb': 'respond', 25 | 'action_object': reply, 26 | 'target': content_object, 27 | } 28 | notify.send(sender=reply.user, **reply_data) 29 | 30 | if parent_user != content_object.author and post_author != reply.user: 31 | # 如果被回复的人不是帖子作者,且不是帖子作者自己的回复,帖子作者应该收到通知 32 | comment_data = { 33 | 'recipient': post_author, 34 | 'verb': 'reply', 35 | 'action_object': reply, 36 | 'target': content_object, 37 | } 38 | notify.send(sender=reply.user, 
**comment_data) 39 | else: 40 | # 如果是直接回复,且不是帖子作者自己回复,则通知帖子作者 41 | if post_author != reply.user: 42 | comment_data = { 43 | 'recipient': post_author, 44 | 'verb': 'reply', 45 | 'action_object': reply, 46 | 'target': content_object, 47 | } 48 | notify.send(sender=reply.user, **comment_data) 49 | 50 | 51 | moderator = Moderator() 52 | -------------------------------------------------------------------------------- /replies/permissions.py: -------------------------------------------------------------------------------- 1 | from rest_framework import permissions 2 | 3 | 4 | class NotSelf(permissions.BasePermission): 5 | def has_object_permission(self, request, view, obj): 6 | if request.method in permissions.SAFE_METHODS: 7 | return True 8 | 9 | return obj.user != request.user 10 | -------------------------------------------------------------------------------- /replies/serializers.py: -------------------------------------------------------------------------------- 1 | from actstream.models import Follow 2 | from django.contrib.contenttypes.models import ContentType 3 | from django.contrib.sites.models import Site 4 | from rest_framework import serializers 5 | 6 | from posts.models import Post 7 | from replies.models import Reply 8 | 9 | 10 | class FlatReplySerializer(serializers.ModelSerializer): 11 | """ 12 | 返回一个扁平化的按发表时间倒序排序的 reply 列表,无视其层级关系。 13 | 适合用在 user 详情页面的个人回复列表中。 14 | """ 15 | post = serializers.SerializerMethodField() 16 | user = serializers.SerializerMethodField() 17 | parent_user = serializers.SerializerMethodField() 18 | is_liked = serializers.SerializerMethodField() 19 | 20 | class Meta: 21 | model = Reply 22 | fields = ( 23 | 'id', 24 | 'user', 25 | 'parent_user', 26 | 'post', 27 | 'submit_date', 28 | 'comment', 29 | 'like_count', 30 | 'is_liked', 31 | ) 32 | 33 | def get_post(self, obj): 34 | post = obj.content_object 35 | return { 36 | 'id': post.id, 37 | 'title': post.title, 38 | } 39 | 40 | def get_user(self, obj): 41 | user = obj.user 42 | 
request = self.context.get('request') 43 | url = user.mugshot.url 44 | thumbnail_url = user.mugshot_thumbnail.url 45 | return { 46 | 'id': user.id, 47 | 'mugshot': request.build_absolute_uri(url) if request else url, 48 | 'mugshot_url': request.build_absolute_uri(thumbnail_url) if request else url, 49 | 'nickname': user.nickname, 50 | } 51 | 52 | def get_parent_user(self, obj): 53 | parent = obj.parent 54 | if not parent: 55 | return None 56 | user = parent.user 57 | request = self.context.get('request') 58 | url = user.mugshot.url 59 | thumbnail_url = user.mugshot_thumbnail.url 60 | return { 61 | 'id': user.id, 62 | 'mugshot': request.build_absolute_uri(url) if request else url, 63 | 'mugshot_url': request.build_absolute_uri(thumbnail_url) if request else url, 64 | 'nickname': user.nickname, 65 | } 66 | 67 | def get_is_liked(self, obj): 68 | request = self.context.get('request') 69 | return Follow.objects.is_following(request.user, obj, flag='like') 70 | 71 | 72 | class ReplyCreationSerializer(serializers.ModelSerializer): 73 | """ 74 | 仅用于 reply 的创建 75 | """ 76 | parent_user = serializers.SerializerMethodField() 77 | user = serializers.SerializerMethodField() 78 | 79 | class Meta: 80 | model = Reply 81 | fields = ( 82 | 'id', 83 | 'object_pk', 84 | 'comment', 85 | 'parent', 86 | 'submit_date', 87 | 'ip_address', 88 | 'is_public', 89 | 'is_removed', 90 | 'user', 91 | 'parent_user', 92 | ) 93 | read_only_fields = ( 94 | 'id', 95 | 'submit_date', 96 | 'ip_address', 97 | 'is_public', 98 | 'is_removed', 99 | ) 100 | 101 | def get_parent_user(self, obj): 102 | parent = obj.parent 103 | if not parent: 104 | return None 105 | user = parent.user 106 | request = self.context.get('request') 107 | url = user.mugshot.url 108 | thumbnail_url = user.mugshot_thumbnail.url 109 | return { 110 | 'id': user.id, 111 | 'mugshot': request.build_absolute_uri(url) if request else url, 112 | 'mugshot_url': request.build_absolute_uri(thumbnail_url) if request else url, 113 | 'nickname': 
user.nickname, 114 | } 115 | 116 | def get_user(self, obj): 117 | user = obj.user 118 | request = self.context.get('request') 119 | url = user.mugshot.url 120 | thumbnail_url = user.mugshot_thumbnail.url 121 | return { 122 | 'id': user.id, 123 | 'mugshot': request.build_absolute_uri(url) if request else url, 124 | 'mugshot_url': request.build_absolute_uri(thumbnail_url) if request else url, 125 | 'nickname': user.nickname, 126 | } 127 | 128 | def create(self, validated_data): 129 | post_id = validated_data.get('object_pk') 130 | post_ctype = ContentType.objects.get_for_model( 131 | Post.objects.get(id=int(post_id)) 132 | ) 133 | site = Site.objects.get_current() 134 | validated_data['content_type'] = post_ctype 135 | validated_data['site'] = site 136 | return super(ReplyCreationSerializer, self).create(validated_data) 137 | 138 | 139 | class TreeRepliesSerializer(serializers.ModelSerializer): 140 | """ 141 | 返回两层的 reply,第一层为根 reply,第二层为这个 reply 的所有子孙 reply。 142 | 这个 Serializer 适合用于帖子详情页的 reply 列表。 143 | """ 144 | descendants = serializers.SerializerMethodField() 145 | user = serializers.SerializerMethodField() 146 | is_liked = serializers.SerializerMethodField() 147 | 148 | class Meta: 149 | model = Reply 150 | fields = ( 151 | 'id', 152 | 'content_type', 153 | 'object_pk', 154 | 'comment', 155 | 'submit_date', 156 | 'like_count', 157 | 'user', 158 | 'descendants', 159 | 'descendants_count', 160 | 'is_liked' 161 | ) 162 | 163 | def get_is_liked(self, obj): 164 | request = self.context.get('request') 165 | return Follow.objects.is_following(request.user, obj, flag='like') 166 | 167 | def get_user(self, obj): 168 | user = obj.user 169 | request = self.context.get('request') 170 | url = user.mugshot.url 171 | thumbnail_url = user.mugshot_thumbnail.url 172 | return { 173 | 'id': user.id, 174 | 'mugshot': request.build_absolute_uri(url) if request else url, 175 | 'mugshot_url': request.build_absolute_uri(thumbnail_url) if request else url, 176 | 'nickname': 
user.nickname, 177 | } 178 | 179 | def get_descendants(self, obj): 180 | qs = obj.descendants() 181 | request = self.context.get('request') 182 | return FlatReplySerializer(qs, many=True, context={'request': request}).data 183 | 184 | 185 | class FollowSerializer(serializers.ModelSerializer): 186 | """ 187 | 用于记录回复的点赞信息 188 | """ 189 | 190 | class Meta: 191 | model = Follow 192 | fields = ( 193 | 'id', 194 | 'user', 195 | 'content_type', 196 | 'object_id', 197 | 'flag', 198 | 'started', 199 | ) 200 | read_only_fields = ( 201 | 'id', 202 | 'user', 203 | 'content_type', 204 | 'object_id', 205 | 'flag', 206 | 'started', 207 | ) 208 | 209 | def create(self, validated_data): 210 | reply_id = validated_data.get('object_pk') 211 | reply_ctype = ContentType.objects.get_for_model( 212 | Reply.objects.get(id=int(reply_id)) 213 | ) 214 | validated_data['content_type'] = reply_ctype 215 | return super(FollowSerializer, self).create(validated_data) 216 | -------------------------------------------------------------------------------- /replies/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DjangoChinaOrg/Django-China-API/79a5d85fe88ba7784d08d370b8e7519f7274f208/replies/tests/__init__.py -------------------------------------------------------------------------------- /replies/tests/test_models.py: -------------------------------------------------------------------------------- 1 | from django.contrib.contenttypes.models import ContentType 2 | from django.contrib.sites.models import Site 3 | from django.test import TestCase 4 | from django.utils.timezone import now, timedelta 5 | 6 | from posts.models import Post 7 | from users.models import User 8 | 9 | from ..models import Reply 10 | 11 | 12 | class RelyModelTestCase(TestCase): 13 | def setUp(self): 14 | self.user = User.objects.create_user(username='test', 15 | email='test@test.com', 16 | password='test', 17 | nickname='test') 18 | self.post 
= Post.objects.create(title='test title', author=self.user) 19 | self.site = Site.objects.create(name='test', domain='test.com') 20 | self.post_ct = ContentType.objects.get_for_model(self.post) 21 | self.post_id = self.post.id 22 | 23 | def test_descendants(self): 24 | """测试以正确的顺序返回了全部回复""" 25 | self.root_reply = Reply.objects.create( 26 | content_type=self.post_ct, 27 | object_pk=self.post_id, 28 | site=self.site, 29 | user=self.user, 30 | comment='root reply', 31 | submit_date=now() 32 | ) 33 | self.child_reply = Reply.objects.create( 34 | content_type=self.post_ct, 35 | object_pk=self.post_id, 36 | site=self.site, 37 | user=self.user, 38 | comment='child reply', 39 | parent=self.root_reply, 40 | submit_date=now() + timedelta(minutes=1) 41 | ) 42 | 43 | self.another_child_reply = Reply.objects.create( 44 | content_type=self.post_ct, 45 | object_pk=self.post_id, 46 | site=self.site, 47 | user=self.user, 48 | comment='another child reply', 49 | parent=self.root_reply, 50 | submit_date=now() + timedelta(minutes=2) 51 | ) 52 | 53 | self.grandchild_reply = Reply.objects.create( 54 | content_type=self.post_ct, 55 | object_pk=self.post_id, 56 | site=self.site, 57 | user=self.user, 58 | comment='grandchild reply', 59 | parent=self.child_reply, 60 | submit_date=now() + timedelta(minutes=3) 61 | ) 62 | 63 | self.assertQuerysetEqual( 64 | self.root_reply.descendants(), 65 | [repr(o) for o in [self.child_reply, self.another_child_reply, self.grandchild_reply]] 66 | ) 67 | -------------------------------------------------------------------------------- /replies/tests/test_moderation.py: -------------------------------------------------------------------------------- 1 | from django.contrib.contenttypes.models import ContentType 2 | from django.contrib.sites.models import Site 3 | from django.test import TestCase 4 | from notifications.models import Notification 5 | 6 | from posts.models import Post 7 | from users.models import User 8 | 9 | from ..models import Reply 10 | from 
..moderation import ReplyModerator 11 | 12 | reply_moderator = ReplyModerator(ReplyModerator) 13 | 14 | 15 | class ModerationTestCase(TestCase): 16 | def setUp(self): 17 | self.user = User.objects.create_user( 18 | username='test', 19 | email='test@test.com', 20 | password='test', 21 | nickname='test' 22 | ) 23 | self.another_user = User.objects.create_user( 24 | username='another', 25 | email='another@test.com', 26 | password='another', 27 | nickname='another' 28 | ) 29 | self.post = Post.objects.create( 30 | title='test title', 31 | author=self.user 32 | ) 33 | self.site = Site.objects.create(name='test', domain='test.com') 34 | self.post_ct = ContentType.objects.get_for_model(self.post) 35 | self.post_id = self.post.id 36 | 37 | def test_post_author_received_reply_notification(self): 38 | """帖子被回复后,且回复者不是作者,则作者收到通知""" 39 | reply = Reply.objects.create( 40 | content_type=self.post_ct, 41 | object_pk=self.post_id, 42 | site=self.site, 43 | user=self.another_user, # 回复者不是帖子作者 44 | comment='reply', 45 | ) 46 | 47 | reply_moderator.notify(reply=reply, content_object=self.post, request=None) 48 | self.assertEqual(Notification.objects.count(), 1) 49 | 50 | def test_reply_self_do_not_received_notification(self): 51 | reply = Reply.objects.create( 52 | content_type=self.post_ct, 53 | object_pk=self.post_id, 54 | site=self.site, 55 | user=self.user, # 回复者是帖子作者 56 | comment='reply', 57 | ) 58 | 59 | reply_moderator.notify(reply=reply, content_object=self.post, request=None) 60 | self.assertEqual(Notification.objects.count(), 0) 61 | 62 | def test_reply_others_reply_as_well_as_others_post(self): 63 | """回复他人回复且不是自己的帖子,他人和帖子作者收到通知""" 64 | user = User.objects.create_user( 65 | username='user', 66 | email='user@test.com', 67 | password='user', 68 | nickname='user' 69 | ) 70 | 71 | reply = Reply.objects.create( 72 | content_type=self.post_ct, 73 | object_pk=self.post_id, 74 | site=self.site, 75 | user=self.another_user, 76 | comment='reply', 77 | ) 78 | 79 | new_reply = 
Reply.objects.create( 80 | content_type=self.post_ct, 81 | object_pk=self.post_id, 82 | site=self.site, 83 | user=user, 84 | comment='new reply', 85 | parent=reply, 86 | ) 87 | 88 | reply_moderator.notify(reply=new_reply, content_object=self.post, request=None) 89 | self.assertEqual(Notification.objects.count(), 2) 90 | 91 | def test_reply_others_reply_but_self_post(self): 92 | """回复他人回复但是自己的帖子,他人收到通知""" 93 | reply = Reply.objects.create( 94 | content_type=self.post_ct, 95 | object_pk=self.post_id, 96 | site=self.site, 97 | user=self.another_user, 98 | comment='reply', 99 | ) 100 | 101 | new_reply = Reply.objects.create( 102 | content_type=self.post_ct, 103 | object_pk=self.post_id, 104 | site=self.site, 105 | user=self.user, 106 | comment='new reply', 107 | parent=reply, 108 | ) 109 | 110 | reply_moderator.notify(reply=new_reply, content_object=self.post, request=None) 111 | self.assertEqual(Notification.objects.count(), 1) 112 | self.assertEqual( 113 | Notification.objects.get().recipient, 114 | reply.user 115 | ) 116 | 117 | def test_reply_self_reply_as_well_as_self_post(self): 118 | """回复自己的回复和自己的帖子,没有任何通知""" 119 | reply = Reply.objects.create( 120 | content_type=self.post_ct, 121 | object_pk=self.post_id, 122 | site=self.site, 123 | user=self.user, # 回复者是帖子作者 124 | comment='reply', 125 | ) 126 | 127 | new_reply = Reply.objects.create( 128 | content_type=self.post_ct, 129 | object_pk=self.post_id, 130 | site=self.site, 131 | user=self.user, # 回复者是帖子作者 132 | comment='new reply', 133 | parent=reply, 134 | ) 135 | 136 | reply_moderator.notify(reply=new_reply, content_object=self.post, request=None) 137 | self.assertEqual(Notification.objects.count(), 0) 138 | 139 | def test_reply_self_reply_but_others_post(self): 140 | """回复自己的回复和让人的帖子,帖子作者收到一条通知""" 141 | reply = Reply.objects.create( 142 | content_type=self.post_ct, 143 | object_pk=self.post_id, 144 | site=self.site, 145 | user=self.another_user, # 回复者是帖子作者 146 | comment='reply', 147 | ) 148 | 149 | new_reply = 
Reply.objects.create( 150 | content_type=self.post_ct, 151 | object_pk=self.post_id, 152 | site=self.site, 153 | user=self.another_user, # 回复者是帖子作者 154 | comment='new reply', 155 | parent=reply, 156 | ) 157 | 158 | reply_moderator.notify(reply=new_reply, content_object=self.post, request=None) 159 | self.assertEqual(Notification.objects.count(), 1) 160 | self.assertEqual( 161 | Notification.objects.get().recipient, 162 | self.post.author 163 | ) 164 | -------------------------------------------------------------------------------- /replies/tests/test_serializers.py: -------------------------------------------------------------------------------- 1 | from pprint import pprint 2 | 3 | from django.contrib.contenttypes.models import ContentType 4 | from django.contrib.sites.models import Site 5 | from django.test import RequestFactory, TestCase 6 | from django.utils.timezone import now, timedelta 7 | 8 | from posts.models import Post 9 | from users.models import User 10 | 11 | from ..models import Reply 12 | from ..serializers import TreeRepliesSerializer 13 | 14 | 15 | class ReplySerializerTestCase(TestCase): 16 | def setUp(self): 17 | self.user = User.objects.create_user(username='test', 18 | email='test@test.com', 19 | password='test', 20 | nickname='test') 21 | self.post = Post.objects.create(title='test title', author=self.user) 22 | self.site = Site.objects.create(name='test', domain='test.com') 23 | self.post_ct = ContentType.objects.get_for_model(self.post) 24 | self.post_id = self.post.id 25 | 26 | def tearDown(self): 27 | pass 28 | 29 | def test_tree_reply_serializer(self): 30 | # - root 31 | # - - child 32 | # - - - grand child 33 | # - - another_child 34 | # - another root 35 | # - - another root child 36 | self.root_reply = Reply.objects.create( 37 | content_type=self.post_ct, 38 | object_pk=self.post_id, 39 | site=self.site, 40 | user=self.user, 41 | comment='root reply', 42 | submit_date=now() 43 | ) 44 | self.child_reply = Reply.objects.create( 45 | 
content_type=self.post_ct, 46 | object_pk=self.post_id, 47 | site=self.site, 48 | user=self.user, 49 | comment='child reply', 50 | parent=self.root_reply, 51 | submit_date=now() + timedelta(minutes=1) 52 | ) 53 | 54 | self.another_child_reply = Reply.objects.create( 55 | content_type=self.post_ct, 56 | object_pk=self.post_id, 57 | site=self.site, 58 | user=self.user, 59 | comment='another child reply', 60 | parent=self.root_reply, 61 | submit_date=now() + timedelta(minutes=2) 62 | ) 63 | 64 | self.grandchild_reply = Reply.objects.create( 65 | content_type=self.post_ct, 66 | object_pk=self.post_id, 67 | site=self.site, 68 | user=self.user, 69 | comment='grandchild reply', 70 | parent=self.child_reply, 71 | submit_date=now() + timedelta(minutes=3) 72 | ) 73 | 74 | self.another_root_reply = Reply.objects.create( 75 | content_type=self.post_ct, 76 | object_pk=self.post_id, 77 | site=self.site, 78 | user=self.user, 79 | comment='another root reply', 80 | submit_date=now() 81 | ) 82 | 83 | self.another_root_child_reply = Reply.objects.create( 84 | content_type=self.post_ct, 85 | object_pk=self.post_id, 86 | site=self.site, 87 | user=self.user, 88 | comment='another root child reply', 89 | parent=self.another_root_reply, 90 | submit_date=now() + timedelta(minutes=1) 91 | ) 92 | request = RequestFactory().get('/') 93 | request.user = self.user 94 | replies = self.post.replies.filter(is_public=True, is_removed=False, parent__isnull=True) 95 | serializer = TreeRepliesSerializer(replies, many=True, context={'request': request}) 96 | 97 | pprint(serializer.data) 98 | -------------------------------------------------------------------------------- /replies/tests/test_views.py: -------------------------------------------------------------------------------- 1 | from actstream.models import Follow 2 | from django.contrib.contenttypes.models import ContentType 3 | from django.contrib.sites.models import Site 4 | from django.urls import reverse 5 | from notifications.models import 
Notification 6 | from rest_framework import status 7 | from rest_framework import test 8 | from notifications.models import Notification 9 | from actstream.models import Follow 10 | 11 | from posts.models import Post 12 | from users.models import User 13 | 14 | from ..models import Reply 15 | 16 | 17 | class ReplyViewSetsTestCase(test.APITestCase): 18 | def setUp(self): 19 | self.user = User.objects.create_user( 20 | username='test', 21 | email='test@test.com', 22 | password='test', 23 | nickname='test' 24 | ) 25 | self.another_user = User.objects.create_user( 26 | username='another', 27 | email='another@test.com', 28 | password='another', 29 | nickname='another' 30 | ) 31 | self.post = Post.objects.create( 32 | title='test title', 33 | author=self.another_user 34 | ) 35 | self.post_ct = ContentType.objects.get_for_model(self.post) 36 | self.post_id = self.post.id 37 | 38 | self.site = Site.objects.create(name='test', domain='test.com') 39 | self.reply = Reply.objects.create( 40 | content_type=self.post_ct, 41 | object_pk=self.post_id, 42 | site=self.site, 43 | user=self.another_user, 44 | comment='reply', 45 | ) 46 | 47 | def test_anonymous_user_can_not_create_reply(self): 48 | url = reverse('reply-list') 49 | data = { 50 | "content_type": self.post_ct.id, 51 | "object_pk": self.post_id, 52 | "site": 1, 53 | "comment": "test comment", 54 | } 55 | response = self.client.post(url, data, format='json') 56 | self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) 57 | 58 | def test_authenticated_user_can_create_reply(self): 59 | url = reverse('reply-list') 60 | data = { 61 | "object_pk": self.post_id, 62 | "comment": "test comment", 63 | } 64 | self.client.login(username='test', password='test') 65 | response = self.client.post(url, data, format='json') 66 | self.assertEqual(response.status_code, status.HTTP_201_CREATED) 67 | self.assertEqual(Reply.objects.count(), 2) 68 | self.assertEqual(Reply.objects.last().user, self.user) 69 | 
self.assertEqual(Reply.objects.last().comment, 'test comment') 70 | 71 | # 确定生成了通知 72 | self.assertEqual(Notification.objects.count(), 1) 73 | 74 | def test_can_create_child_reply(self): 75 | url = reverse('reply-list') 76 | data = { 77 | "object_pk": self.post_id, 78 | "comment": "test comment", 79 | "parent": self.reply.id, 80 | } 81 | self.client.login(username='test', password='test') 82 | response = self.client.post(url, data, format='json') 83 | self.assertEqual(response.status_code, status.HTTP_201_CREATED) 84 | self.assertEqual(Reply.objects.count(), 2) 85 | self.assertEqual(Reply.objects.last().user, self.user) 86 | self.assertEqual(Reply.objects.last().comment, 'test comment') 87 | self.assertEqual(Reply.objects.last().parent, self.reply) 88 | 89 | # 确定生成了通知 90 | self.assertEqual(Notification.objects.count(), 1) 91 | 92 | def test_reply_only_support_post_method(self): 93 | url = reverse('reply-list') 94 | data = { 95 | "object_pk": self.post_id, 96 | "comment": "test comment", 97 | } 98 | self.client.login(username='test', password='test') 99 | 100 | response = self.client.get(url, format='json') 101 | self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) 102 | 103 | response = self.client.put(url, data, format='json') 104 | self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) 105 | 106 | response = self.client.patch(url, data, format='json') 107 | self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) 108 | 109 | response = self.client.delete(url, format='json') 110 | self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) 111 | 112 | def test_anonymous_user_can_not_like_reply(self): 113 | url = reverse('reply-like', kwargs={'pk': self.reply.id}) 114 | data = { 115 | "object_id": self.reply.id, 116 | "flag": "like", 117 | } 118 | response = self.client.post(url, data, format='json') 119 | self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) 120 | 121 | def 
test_user_can_not_like_self_reply(self): 122 | url = reverse('reply-like', kwargs={'pk': self.reply.id}) 123 | data = { 124 | "object_id": self.reply.id, 125 | "flag": "like", 126 | } 127 | self.client.login(username='another', password='another') 128 | response = self.client.post(url, data, format='json') 129 | self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) 130 | 131 | def test_user_can_like_others_reply(self): 132 | url = reverse('reply-like', kwargs={'pk': self.reply.id}) 133 | data = { 134 | "object_id": self.reply.id, 135 | "flag": "like", 136 | } 137 | self.client.login(username='test', password='test') 138 | response = self.client.post(url, data, format='json') 139 | self.assertEqual(response.status_code, status.HTTP_201_CREATED) 140 | self.assertEqual(Follow.objects.count(), 1) 141 | self.assertEqual(Follow.objects.get().user, self.user) 142 | self.assertEqual(Follow.objects.get().follow_object, self.reply) 143 | self.assertEqual(Follow.objects.get().flag, 'like') 144 | 145 | # 确认生成了通知 146 | self.assertEqual(Notification.objects.count(), 1) 147 | self.assertEqual(Notification.objects.get().recipient, self.reply.user) 148 | 149 | def test_user_can_not_like_same_reply_twice(self): 150 | url = reverse('reply-like', kwargs={'pk': self.reply.id}) 151 | data = { 152 | "object_id": self.reply.id, 153 | "flag": "like", 154 | } 155 | self.client.login(username='test', password='test') 156 | response = self.client.post(url, data, format='json') 157 | self.assertEqual(response.status_code, status.HTTP_201_CREATED) 158 | 159 | self.assertEqual(Follow.objects.count(), 1) 160 | 161 | response = self.client.post(url, data, format='json') 162 | self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) 163 | # An error occurred in the current transaction. 164 | # You can't execute queries until the end of the 'atomic' block. 
165 | # 暂时不知道如何解决 166 | # self.assertEqual(Follow.objects.count(), 1) 167 | 168 | # 确认生成了通知 169 | # self.assertEqual(Notification.objects.count(), 1) 170 | # self.assertEqual(Notification.objects.get().recipient, self.reply.user) 171 | 172 | def test_user_can_delete_self_like(self): 173 | url = reverse('reply-like', kwargs={'pk': self.reply.id}) 174 | data = { 175 | "object_id": self.reply.id, 176 | "flag": "like", 177 | } 178 | self.client.login(username='test', password='test') 179 | response = self.client.post(url, data, format='json') 180 | self.assertEqual(response.status_code, status.HTTP_201_CREATED) 181 | self.assertEqual(Follow.objects.count(), 1) 182 | 183 | response = self.client.delete(url) 184 | self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) 185 | self.assertEqual(Follow.objects.count(), 0) 186 | 187 | def test_like_only_support_post_delete_methods(self): 188 | url = reverse('reply-like', kwargs={'pk': self.reply.id}) 189 | data = { 190 | "object_id": self.reply.id, 191 | "flag": "like", 192 | } 193 | self.client.login(username='test', password='test') 194 | 195 | response = self.client.get(url, format='json') 196 | self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) 197 | 198 | response = self.client.put(url, data, format='json') 199 | self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) 200 | 201 | response = self.client.patch(url, data, format='json') 202 | self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) 203 | -------------------------------------------------------------------------------- /replies/urls.py: -------------------------------------------------------------------------------- 1 | from django.conf.urls import url 2 | 3 | from . 
import views 4 | 5 | app_name = 'replies' 6 | urlpatterns = [ 7 | # url(r'^$', views.ReplyCreateView.as_view(), name='create_reply'), 8 | # url(r'^like/$', views.ReplyLikeCreateView.as_view(), name='like_reply'), 9 | # 10 | # # 以下两条 API 仅用于测试 11 | # url(r'^flat-list/$', views.FlatReplyListView.as_view(), name='flat_reply_list'), 12 | # url(r'^tree-list/$', views.TreeReplyListView.as_view(), name='tree_reply_list'), 13 | ] 14 | -------------------------------------------------------------------------------- /replies/views.py: -------------------------------------------------------------------------------- 1 | from actstream.models import Follow 2 | from django.db.utils import IntegrityError 3 | from django_comments import signals 4 | from notifications.signals import notify 5 | from rest_framework import mixins 6 | from rest_framework import permissions 7 | from rest_framework import status 8 | from rest_framework import viewsets 9 | from rest_framework.decorators import action 10 | from rest_framework.response import Response 11 | 12 | from replies.permissions import NotSelf 13 | from replies.serializers import (FollowSerializer, 14 | ReplyCreationSerializer) 15 | from replies.models import Reply 16 | 17 | 18 | class ReplyViewSet(mixins.CreateModelMixin, viewsets.GenericViewSet): 19 | serializer_class = ReplyCreationSerializer 20 | permission_classes = [permissions.IsAuthenticated, ] 21 | queryset = Reply.objects.filter(is_public=True, is_removed=False) 22 | 23 | def perform_create(self, serializer): 24 | parent_reply = serializer.validated_data.get('parent') 25 | reply = serializer.save(user=self.request.user, parent=parent_reply) 26 | 27 | # 创建相应的 notification 28 | signals.comment_was_posted.send( 29 | sender=reply.__class__, 30 | comment=reply, 31 | request=self.request 32 | ) 33 | 34 | @action( 35 | methods=['post', 'delete'], 36 | detail=True, 37 | permission_classes=[NotSelf, permissions.IsAuthenticated], 38 | serializer_class=FollowSerializer, 39 | ) 40 | 
def like(self, request, pk=None): 41 | reply = self.get_object() 42 | 43 | if self.request.method == 'POST': 44 | try: 45 | follow = Follow.objects.create( 46 | user=self.request.user, 47 | content_type=reply.ctype, 48 | object_id=reply.id, 49 | flag='like', 50 | ) 51 | except IntegrityError: 52 | return Response( 53 | {'detail': '已经赞过'}, 54 | status=status.HTTP_400_BAD_REQUEST 55 | ) 56 | serializer = FollowSerializer(follow, context={'request': request}) 57 | # 创建相应的 notification 58 | data = { 59 | 'recipient': reply.user, 60 | 'verb': 'like', 61 | 'action_object': reply.content_object, 62 | 'target': reply 63 | } 64 | notify.send(sender=self.request.user, **data) 65 | return Response(serializer.data, status=status.HTTP_201_CREATED) 66 | elif self.request.method == 'DELETE': 67 | Follow.objects.filter( 68 | user=self.request.user, 69 | content_type=reply.ctype_id, 70 | object_id=reply.id, 71 | flag='like' 72 | ).delete() 73 | return Response(status=status.HTTP_204_NO_CONTENT) 74 | -------------------------------------------------------------------------------- /requirements/base.txt: -------------------------------------------------------------------------------- 1 | django==1.11 2 | django-bootstrap-form==3.4 3 | django-contrib-comments==1.8.0 4 | django-notifications-hq==1.3 5 | django-rest-auth==0.9.3 6 | django-allauth==0.35.0 7 | django-ipware==2.0.1 8 | djangorestframework-jwt==1.11.0 9 | django-mptt==0.9.0 10 | pycodestyle==2.3.1 11 | djangorestframework==3.8.0 12 | pillow==5.0.0 13 | django-filter==1.1.0 14 | markdown==2.6.11 15 | coreapi==2.3.3 16 | factory-boy==2.10.0 17 | django-model-utils==3.1.1 18 | django-cors-headers==2.2.0 19 | mysqlclient==1.3.12 20 | django-imagekit==4.0.2 21 | django-environ==0.4.5 22 | raven 23 | git+https://github.com/zmrenwu/django-activity-stream.git@master#egg=django-activity-stream 24 | -------------------------------------------------------------------------------- /requirements/local.txt: 
-------------------------------------------------------------------------------- 1 | -r base.txt 2 | 3 | isort==4.3.4 -------------------------------------------------------------------------------- /requirements/production.txt: -------------------------------------------------------------------------------- 1 | -r base.txt -------------------------------------------------------------------------------- /runtests.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import django 5 | from django.conf import settings 6 | from django.test.utils import get_runner 7 | 8 | if __name__ == "__main__": 9 | os.environ['DJANGO_SETTINGS_MODULE'] = 'config.settings.local' 10 | django.setup() 11 | TestRunner = get_runner(settings) 12 | test_runner = TestRunner() 13 | failures = test_runner.run_tests([ 14 | 'replies', 'posts', 'tags', 'users', 'balance', 'notifications_extension' 15 | ]) 16 | sys.exit(bool(failures)) 17 | -------------------------------------------------------------------------------- /tags/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DjangoChinaOrg/Django-China-API/79a5d85fe88ba7784d08d370b8e7519f7274f208/tags/__init__.py -------------------------------------------------------------------------------- /tags/admin.py: -------------------------------------------------------------------------------- 1 | from django.contrib import admin 2 | 3 | from .models import Tag 4 | 5 | admin.site.register(Tag) 6 | -------------------------------------------------------------------------------- /tags/apps.py: -------------------------------------------------------------------------------- 1 | from django.apps import AppConfig 2 | 3 | 4 | class TagsConfig(AppConfig): 5 | name = 'tags' 6 | -------------------------------------------------------------------------------- /tags/factories.py: 
-------------------------------------------------------------------------------- 1 | # factories that automatically create user data 2 | import factory 3 | 4 | from users.factories import UserFactory 5 | 6 | from .models import Tag 7 | 8 | 9 | class TagFactory(factory.DjangoModelFactory): 10 | class Meta: 11 | model = Tag 12 | 13 | name = factory.Sequence(lambda n: 'tag%s' % n) 14 | creator = factory.SubFactory(UserFactory, username='admin') 15 | -------------------------------------------------------------------------------- /tags/migrations/0001_initial.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by Django 1.11 on 2018-04-24 12:35 3 | from __future__ import unicode_literals 4 | 5 | from django.db import migrations, models 6 | import django.utils.timezone 7 | import model_utils.fields 8 | 9 | 10 | class Migration(migrations.Migration): 11 | 12 | initial = True 13 | 14 | dependencies = [ 15 | ] 16 | 17 | operations = [ 18 | migrations.CreateModel( 19 | name='Tag', 20 | fields=[ 21 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 22 | ('created', model_utils.fields.AutoCreatedField(default=django.utils.timezone.now, editable=False, verbose_name='created')), 23 | ('modified', model_utils.fields.AutoLastModifiedField(default=django.utils.timezone.now, editable=False, verbose_name='modified')), 24 | ('name', models.CharField(max_length=100, unique=True, verbose_name='标签名')), 25 | ], 26 | options={ 27 | 'verbose_name': '标签', 28 | 'verbose_name_plural': '标签', 29 | }, 30 | ), 31 | ] 32 | -------------------------------------------------------------------------------- /tags/migrations/0002_tag_creator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by Django 1.11 on 2018-04-24 12:35 3 | from __future__ import unicode_literals 4 | 5 | from django.conf import 
settings 6 | from django.db import migrations, models 7 | import django.db.models.deletion 8 | 9 | 10 | class Migration(migrations.Migration): 11 | 12 | initial = True 13 | 14 | dependencies = [ 15 | migrations.swappable_dependency(settings.AUTH_USER_MODEL), 16 | ('tags', '0001_initial'), 17 | ] 18 | 19 | operations = [ 20 | migrations.AddField( 21 | model_name='tag', 22 | name='creator', 23 | field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL, verbose_name='创建者'), 24 | ), 25 | ] 26 | -------------------------------------------------------------------------------- /tags/migrations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DjangoChinaOrg/Django-China-API/79a5d85fe88ba7784d08d370b8e7519f7274f208/tags/migrations/__init__.py -------------------------------------------------------------------------------- /tags/models.py: -------------------------------------------------------------------------------- 1 | from django.conf import settings 2 | from django.db import models 3 | from model_utils.models import TimeStampedModel 4 | 5 | 6 | class Tag(TimeStampedModel): 7 | name = models.CharField("标签名", max_length=100, unique=True) 8 | creator = models.ForeignKey( 9 | settings.AUTH_USER_MODEL, 10 | verbose_name="创建者", 11 | on_delete=models.CASCADE 12 | ) 13 | 14 | class Meta: 15 | verbose_name = "标签" 16 | verbose_name_plural = "标签" 17 | 18 | def __str__(self): 19 | return self.name 20 | -------------------------------------------------------------------------------- /tags/permissions.py: -------------------------------------------------------------------------------- 1 | from rest_framework import permissions 2 | 3 | 4 | class TagPermissionOrReadOnly(permissions.IsAdminUser): 5 | """ 6 | 仅允许管理员添加标签,(以后可添加普通用户的权限管理),其他用户只读 7 | """ 8 | def has_permission(self, request, view): 9 | return (super(TagPermissionOrReadOnly, self).has_permission(request, 
view) or 10 | request.method in permissions.SAFE_METHODS) 11 | -------------------------------------------------------------------------------- /tags/serializers.py: -------------------------------------------------------------------------------- 1 | from rest_framework import serializers 2 | 3 | from .models import Tag 4 | 5 | 6 | class TagSerializer(serializers.HyperlinkedModelSerializer): 7 | 8 | class Meta: 9 | model = Tag 10 | fields = ( 11 | 'id', 12 | 'url', 13 | 'name', 14 | ) 15 | -------------------------------------------------------------------------------- /tags/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DjangoChinaOrg/Django-China-API/79a5d85fe88ba7784d08d370b8e7519f7274f208/tags/tests/__init__.py -------------------------------------------------------------------------------- /tags/tests/test_views.py: -------------------------------------------------------------------------------- 1 | from django.urls import reverse 2 | from rest_framework import status 3 | from rest_framework.test import APITestCase 4 | 5 | from tags.models import Tag 6 | from users.models import User 7 | from posts.models import Post 8 | 9 | 10 | class TagTests(APITestCase): 11 | def setUp(self): 12 | self.user = User.objects.create_user(username='test', 13 | email='test@test.com', 14 | password='test', 15 | nickname='test') 16 | self.admin = User.objects.create_superuser(username='admin', 17 | email='admin@test.com', 18 | password='admin', 19 | nickname='admin') 20 | 21 | def test_admin_can_create_tag(self): 22 | """ 23 | 测试管理员可以添加标签 24 | """ 25 | url = reverse('tag-list') 26 | data = { 27 | 'name': 'test tag' 28 | } 29 | self.client.login(username='admin', password='admin') 30 | response = self.client.post(url, data, format='json') 31 | self.assertEqual(response.status_code, status.HTTP_201_CREATED) 32 | self.assertEqual(Tag.objects.count(), 1) 33 | 
self.assertEqual(Tag.objects.get().creator, self.admin) 34 | self.assertEqual(Tag.objects.get().name, 'test tag') 35 | 36 | def test_unauthorized_user_cannot_create_tag(self): 37 | """ 38 | 测试没有权限用户无法添加标签 39 | """ 40 | url = reverse('tag-list') 41 | data = { 42 | 'name': 'test tag' 43 | } 44 | self.client.login(username='test', password='test') 45 | response = self.client.post(url, data, format='json') 46 | self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) 47 | 48 | def test_anonymous_user_cannot_create_tag(self): 49 | """ 50 | 测试未登录用户无法添加标签 51 | """ 52 | url = reverse('tag-list') 53 | data = { 54 | 'name': 'test tag' 55 | } 56 | response = self.client.post(url, data, format='json') 57 | self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) 58 | 59 | def test_no_duplicate_tags(self): 60 | """ 61 | 测试没有重复的标签 62 | """ 63 | """ 64 | 测试管理员可以添加标签 65 | """ 66 | self.tag = Tag.objects.create(name='test tag', 67 | creator=self.admin 68 | ) 69 | url = reverse('tag-list') 70 | data = { 71 | 'name': 'test tag' 72 | } 73 | self.client.login(username='admin', password='admin') 74 | response = self.client.post(url, data, format='json') 75 | self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) 76 | 77 | def test_popular_tags(self): 78 | """ 79 | 测试热门标签 80 | """ 81 | self.post1 = Post.objects.create(title='this is a test1', 82 | body='this is a test', 83 | author=self.admin) 84 | self.post2 = Post.objects.create(title='this is a test2', 85 | body='this is a test', 86 | author=self.admin) 87 | self.post3 = Post.objects.create(title='this is a test3', 88 | body='this is a test', 89 | author=self.admin) 90 | self.tag1 = Tag.objects.create(name='test tag1', 91 | creator=self.admin) 92 | self.tag2 = Tag.objects.create(name='test tag2', 93 | creator=self.admin) 94 | self.tag3 = Tag.objects.create(name='test tag3', 95 | creator=self.admin) 96 | self.post1.tags.add(self.tag1, self.tag2) 97 | self.post2.tags.add(self.tag2, self.tag3, self.tag1) 
98 | self.post3.tags.add(self.tag2) 99 | url = reverse('tag-popular') 100 | response = self.client.get(url, format='json') 101 | self.assertEqual(response.data[0]['name'], 'test tag2') 102 | self.assertEqual(response.data[1]['name'], 'test tag1') 103 | self.assertEqual(response.data[2]['name'], 'test tag3') 104 | -------------------------------------------------------------------------------- /tags/views.py: -------------------------------------------------------------------------------- 1 | from django.db.models import Count 2 | from rest_framework import viewsets 3 | from rest_framework.decorators import action 4 | from rest_framework.response import Response 5 | 6 | from utils.rest_tools import CustomPageNumberPagination 7 | 8 | from .models import Tag 9 | from .permissions import TagPermissionOrReadOnly 10 | from .serializers import TagSerializer 11 | 12 | 13 | class TagPageNumberPagination(CustomPageNumberPagination): 14 | page_size = 120 15 | max_page_size = 120 16 | 17 | 18 | class TagViewSet(viewsets.ModelViewSet): 19 | queryset = Tag.objects.all() 20 | serializer_class = TagSerializer 21 | permission_classes = (TagPermissionOrReadOnly,) 22 | http_method_names = ['get', 'post'] 23 | pagination_class = TagPageNumberPagination 24 | 25 | def perform_create(self, serializer): 26 | """ 27 | 因为author字段在PostSerializer里是ReadOnly,所以这里需要手动保存 28 | """ 29 | serializer.save(creator=self.request.user) 30 | 31 | @action(detail=False) 32 | def popular(self, request): 33 | """ 34 | 返回帖子数量最多的前10标签 35 | """ 36 | tags = self.get_queryset().annotate( 37 | num_posts=Count('post')).order_by('-num_posts')[:10] 38 | serializer = self.get_serializer(tags, many=True) 39 | return Response(serializer.data) 40 | -------------------------------------------------------------------------------- /templates/users/account/base.html: -------------------------------------------------------------------------------- 1 | 2 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | {% block head_title %}{% endblock %} 
17 | {% block extra_head %} 18 | {% endblock %} 19 | 20 | 21 | {% block body %} 22 | 23 | {% if messages %} 24 |
25 | {% for message in messages %} 26 |
{{message}}
27 | {% endfor %} 28 |
29 | {% endif %} 30 | 31 | 57 |
58 | {% block content %} 59 | {% endblock %} 60 |
61 | {% endblock %} 62 | {% block extra_body %} 63 | {% endblock %} 64 | 65 | 66 | -------------------------------------------------------------------------------- /templates/users/account/email.html: -------------------------------------------------------------------------------- 1 | {% extends "account/base.html" %} 2 | 3 | {% load bootstrap %} 4 | {% load i18n %} 5 | 6 | {% block head_title %}{% trans "Account" %}{% endblock %} 7 | 8 | {% block content %} 9 |
10 |
11 |
12 |
13 |
14 | {% trans "E-mail Addresses" %} 15 |
16 |
17 | {% if user.emailaddress_set.all %} 18 |

{% trans 'The following e-mail addresses are associated with your account:' %}

19 | 43 | {% else %} 44 |

{% trans 'Warning:'%} {% trans "You currently do not have any e-mail address set up. You should really add an e-mail address so you can receive notifications, reset your password, etc." %}

45 | {% endif %} 46 |

{% trans "Add E-mail Address" %}

47 |
48 | {% csrf_token %} 49 | {{ form|bootstrap_horizontal }} 50 | 51 |
52 |
53 |
54 |
55 |
56 |
57 | {% endblock %} 58 | 59 | 60 | {% block extra_body %} 61 | 74 | {% endblock %} -------------------------------------------------------------------------------- /templates/users/account/email_confirm.html: -------------------------------------------------------------------------------- 1 | {% extends "account/base.html" %} 2 | 3 | {% load bootstrap %} 4 | {% load i18n %} 5 | {% load account %} 6 | 7 | {% block head_title %}{% trans "Confirm E-mail Address" %}{% endblock %} 8 | 9 | 10 | {% block content %} 11 |
12 |
13 |
14 |
15 |
16 | {% trans "Confirm E-mail Address" %} 17 |
18 |
19 | {% if confirmation %} 20 | {% user_display confirmation.email_address.user as user_display %} 21 |

{% blocktrans with confirmation.email_address.email as email %}Please confirm that {{ email }} is an e-mail address for user {{ user_display }}.{% endblocktrans %}

22 |
23 | {% csrf_token %} 24 | 25 |
26 | {% else %} 27 | {% url 'account_email' as email_url %} 28 |

{% blocktrans %}This e-mail confirmation link expired or is invalid. Please issue a new e-mail confirmation request.{% endblocktrans %}

29 | {% endif %} 30 |
31 |
32 |
33 |
34 |
35 | 36 | {% endblock %} 37 | -------------------------------------------------------------------------------- /templates/users/account/login.html: -------------------------------------------------------------------------------- 1 | {% extends "account/base.html" %} 2 | 3 | {% load bootstrap %} 4 | {% load i18n %} 5 | {% load account socialaccount %} 6 | 7 | {% block head_title %}{% trans "Sign In" %}{% endblock %} 8 | 9 | {% block content %} 10 | 11 | {% get_providers as socialaccount_providers %} 12 | 13 |
14 |
15 |
16 |
17 |
18 | {% trans "Sign In" %} 19 |
20 |
21 | 30 |
31 |
32 |
33 |
34 | {% if socialaccount_providers %} 35 |
36 |
37 | 使用第三方账号登陆 38 |
39 |
40 |
41 |
    42 | {% include "socialaccount/snippets/provider_list.html" with process="login" %} 43 |
44 |
45 |
46 |
47 | {% include "socialaccount/snippets/login_extra.html" %} 48 | {% else %} 49 |

{% blocktrans %}If you have not created an account yet, then please 50 | sign up first.{% endblocktrans %}

51 | {% endif %} 52 |
53 |
54 |
55 | 56 | {% endblock %} -------------------------------------------------------------------------------- /templates/users/account/password_change.html: -------------------------------------------------------------------------------- 1 | {% extends "account/base.html" %} 2 | 3 | {% load bootstrap %} 4 | {% load i18n %} 5 | 6 | {% block head_title %}{% trans "Change Password" %}{% endblock %} 7 | 8 | {% block content %} 9 |
10 |
11 |
12 |
13 |
14 | {% trans "Change Password" %} 15 |
16 |
17 |
18 | {% csrf_token %} 19 | {{ form|bootstrap_horizontal }} 20 | 21 |
22 |
23 |
24 |
25 |
26 |
27 | {% endblock %} 28 | -------------------------------------------------------------------------------- /templates/users/account/password_reset.html: -------------------------------------------------------------------------------- 1 | {% extends "account/base.html" %} 2 | 3 | {% load bootstrap %} 4 | {% load i18n %} 5 | {% load account %} 6 | 7 | {% block head_title %}{% trans "Password Reset" %}{% endblock %} 8 | 9 | {% block content %} 10 |
11 |
12 |
13 |
14 |
15 | {% trans "Password Reset" %} 16 |
17 |
18 | {% if user.is_authenticated %} 19 | {% include "account/snippets/already_logged_in.html" %} 20 | {% endif %} 21 |

{% trans "Forgotten your password? Enter your e-mail address below, and we'll send you an e-mail allowing you to reset it." %}

22 | 23 |
24 | {% csrf_token %} 25 | {{ form|bootstrap_horizontal }} 26 | 27 |
28 |

{% blocktrans %}Please contact us if you have any trouble resetting your password.{% endblocktrans %}

29 |
30 |
31 |
32 |
33 |
34 | {% endblock %} 35 | -------------------------------------------------------------------------------- /templates/users/account/password_set.html: -------------------------------------------------------------------------------- 1 | {% extends "account/base.html" %} 2 | 3 | {% load bootstrap %} 4 | {% load i18n %} 5 | 6 | {% block head_title %}{% trans "Set Password" %}{% endblock %} 7 | 8 | {% block content %} 9 |
10 |
11 |
12 |
13 |
14 | {% trans "Set Password" %} 15 |
16 |
17 |
18 | {% csrf_token %} 19 | {{ form|bootstrap_horizontal }} 20 | 21 |
22 |
23 |
24 |
25 |
26 |
27 | {% endblock %} 28 | -------------------------------------------------------------------------------- /templates/users/account/signup.html: -------------------------------------------------------------------------------- 1 | {% extends "account/base.html" %} 2 | 3 | {% load bootstrap %} 4 | {% load i18n %} 5 | 6 | {% block head_title %}{% trans "Signup" %}{% endblock %} 7 | 8 | {% block content %} 9 |
10 |
11 |
12 |
13 |
14 | {% trans "Sign Up" %} 15 |
16 |
17 |

{% blocktrans %}Already have an account? Then please sign in.{% endblocktrans %}

18 | 19 | 27 |
28 |
29 |
30 |
31 |
32 | 33 | {% endblock %} 34 | -------------------------------------------------------------------------------- /templates/users/socialaccount/connections.html: -------------------------------------------------------------------------------- 1 | {% extends "socialaccount/base.html" %} 2 | 3 | {% load i18n %} 4 | 5 | {% block head_title %}{% trans "Account Connections" %}{% endblock %} 6 | 7 | {% block content %} 8 |
9 |
10 |
11 |
12 |
13 | {% trans "Account Connections" %} 14 |
15 |
16 | {% if form.accounts %} 17 |

{% blocktrans %}You can sign in to your account using any of the following third party accounts:{% endblocktrans %}

18 |
19 | {% csrf_token %} 20 |
21 | {% if form.non_field_errors %} 22 |
{{ form.non_field_errors }}
23 | {% endif %} 24 | 25 | {% for base_account in form.accounts %} 26 | {% with base_account.get_provider_account as account %} 27 |
28 | 33 |
34 | {% endwith %} 35 | {% endfor %} 36 | 37 |
38 | 39 |
40 |
41 | 42 |
43 | 44 | {% else %} 45 |

{% trans 'You currently have no social network accounts connected to this account.' %}

46 | {% endif %} 47 | 48 |
{% trans 'Add a 3rd Party Account' %}
49 | 50 |
    51 | {% include "socialaccount/snippets/provider_list.html" with process="connect" %} 52 |
53 | 54 | {% include "socialaccount/snippets/login_extra.html" %} 55 |
56 |
57 |
58 |
59 |
60 | 61 | {% endblock %} 62 | -------------------------------------------------------------------------------- /users/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DjangoChinaOrg/Django-China-API/79a5d85fe88ba7784d08d370b8e7519f7274f208/users/__init__.py -------------------------------------------------------------------------------- /users/admin.py: -------------------------------------------------------------------------------- 1 | from django.contrib import admin 2 | 3 | from .models import User 4 | 5 | admin.site.register(User) 6 | -------------------------------------------------------------------------------- /users/apps.py: -------------------------------------------------------------------------------- 1 | from django.apps import AppConfig 2 | 3 | 4 | class UsersConfig(AppConfig): 5 | name = 'users' 6 | -------------------------------------------------------------------------------- /users/disable_csrf_middleware.py: -------------------------------------------------------------------------------- 1 | # Todo: 临时关闭 csrf 校验,需仔细评估安全问题 2 | 3 | from django.utils.deprecation import MiddlewareMixin 4 | 5 | 6 | class DisableCSRFCheck(MiddlewareMixin): 7 | def process_request(self, request): 8 | setattr(request, '_dont_enforce_csrf_checks', True) 9 | -------------------------------------------------------------------------------- /users/factories.py: -------------------------------------------------------------------------------- 1 | # factories that automatically create user data 2 | import factory 3 | 4 | from users.models import User 5 | 6 | 7 | class UserFactory(factory.DjangoModelFactory): 8 | class Meta: 9 | model = User 10 | 11 | username = factory.Sequence(lambda n: 'user%s' % n) 12 | email = factory.LazyAttribute(lambda o: '%s@example.com' % o.username) 13 | password = 'password' 14 | mugshot = factory.django.ImageField() 15 | 16 | @classmethod 17 | def _create(cls, 
model_class, *args, **kwargs): 18 | manager = cls._get_manager(model_class) 19 | return manager.create_user(*args, **kwargs) 20 | -------------------------------------------------------------------------------- /users/jwt_middleware.py: -------------------------------------------------------------------------------- 1 | from django.conf import settings 2 | from rest_framework_jwt.serializers import RefreshJSONWebTokenSerializer 3 | from rest_framework_jwt.settings import api_settings 4 | 5 | jwt_payload_handler = api_settings.JWT_PAYLOAD_HANDLER 6 | jwt_encode_handler = api_settings.JWT_ENCODE_HANDLER 7 | 8 | 9 | class JWTMiddleware(object): 10 | """ 11 | 将用户的Session状态和JWT同步的中间件 12 | """ 13 | def __init__(self, get_response): 14 | self.get_response = get_response 15 | 16 | def __call__(self, request): 17 | response = self.get_response(request) 18 | if not hasattr(request, 'user'): 19 | return response 20 | if request.user.is_authenticated(): 21 | if response.status_code != 200: 22 | return response 23 | 24 | if 'JWT' in request.COOKIES: 25 | # 刷新JWT 26 | serializer = RefreshJSONWebTokenSerializer( 27 | data={'token': request.COOKIES['JWT']}) 28 | if serializer.is_valid(): 29 | jwt_and_user = serializer.object 30 | if jwt_and_user['user'] == request.user: 31 | jwt = jwt_and_user['token'] 32 | else: 33 | jwt = jwt_encode_handler(jwt_payload_handler(request.user)) 34 | else: 35 | # 旧JWT无法解析的话,创建新的JWT 36 | jwt = jwt_encode_handler(jwt_payload_handler(request.user)) 37 | else: 38 | # JWT还不在cookie里的话,创建新的JWT 39 | jwt = jwt_encode_handler(jwt_payload_handler(request.user)) 40 | 41 | response.set_cookie( 42 | 'JWT', 43 | value=jwt, 44 | max_age=24 * 60 * 60 45 | ) 46 | else: 47 | # 用户已经登出,清理掉JWT 48 | if 'JWT' in request.COOKIES: 49 | response.delete_cookie('JWT') 50 | return response 51 | -------------------------------------------------------------------------------- /users/migrations/0001_initial.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by Django 1.11 on 2018-04-24 12:35 3 | from __future__ import unicode_literals 4 | 5 | import django.contrib.auth.models 6 | import django.contrib.auth.validators 7 | from django.db import migrations, models 8 | import django.utils.timezone 9 | import users.models 10 | 11 | 12 | class Migration(migrations.Migration): 13 | 14 | initial = True 15 | 16 | dependencies = [ 17 | ('auth', '0008_alter_user_username_max_length'), 18 | ] 19 | 20 | operations = [ 21 | migrations.CreateModel( 22 | name='User', 23 | fields=[ 24 | ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), 25 | ('password', models.CharField(max_length=128, verbose_name='password')), 26 | ('last_login', models.DateTimeField(blank=True, null=True, verbose_name='last login')), 27 | ('is_superuser', models.BooleanField(default=False, help_text='Designates that this user has all permissions without explicitly assigning them.', verbose_name='superuser status')), 28 | ('username', models.CharField(error_messages={'unique': 'A user with that username already exists.'}, help_text='Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only.', max_length=150, unique=True, validators=[django.contrib.auth.validators.UnicodeUsernameValidator()], verbose_name='username')), 29 | ('first_name', models.CharField(blank=True, max_length=30, verbose_name='first name')), 30 | ('last_name', models.CharField(blank=True, max_length=30, verbose_name='last name')), 31 | ('email', models.EmailField(blank=True, max_length=254, verbose_name='email address')), 32 | ('is_staff', models.BooleanField(default=False, help_text='Designates whether the user can log into this admin site.', verbose_name='staff status')), 33 | ('is_active', models.BooleanField(default=True, help_text='Designates whether this user should be treated as active. 
Unselect this instead of deleting accounts.', verbose_name='active')), 34 | ('date_joined', models.DateTimeField(default=django.utils.timezone.now, verbose_name='date joined')), 35 | ('last_login_ip', models.GenericIPAddressField(blank=True, null=True, unpack_ipv4=True, verbose_name='最近一次登陆IP')), 36 | ('ip_joined', models.GenericIPAddressField(blank=True, null=True, unpack_ipv4=True, verbose_name='注册IP')), 37 | ('nickname', models.CharField(max_length=50, unique=True, verbose_name='昵称')), 38 | ('mugshot', models.ImageField(upload_to=users.models.user_mugshot_path, verbose_name='头像')), 39 | ('groups', models.ManyToManyField(blank=True, help_text='The groups this user belongs to. A user will get all permissions granted to each of their groups.', related_name='user_set', related_query_name='user', to='auth.Group', verbose_name='groups')), 40 | ('user_permissions', models.ManyToManyField(blank=True, help_text='Specific permissions for this user.', related_name='user_set', related_query_name='user', to='auth.Permission', verbose_name='user permissions')), 41 | ], 42 | options={ 43 | 'verbose_name': 'user', 44 | 'verbose_name_plural': 'users', 45 | 'abstract': False, 46 | }, 47 | managers=[ 48 | ('objects', django.contrib.auth.models.UserManager()), 49 | ], 50 | ), 51 | ] 52 | -------------------------------------------------------------------------------- /users/migrations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DjangoChinaOrg/Django-China-API/79a5d85fe88ba7784d08d370b8e7519f7274f208/users/migrations/__init__.py -------------------------------------------------------------------------------- /users/models.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from django.core.files.base import ContentFile 4 | from django.contrib.auth.models import AbstractUser 5 | from django.contrib.auth.signals import user_logged_in 6 | from 
django.db import models 7 | 8 | from .mugshot import Avatar 9 | from .utils import get_ip_address_from_request 10 | 11 | 12 | def user_mugshot_path(instance, filename): 13 | return os.path.join('mugshots', instance.username, filename) 14 | 15 | from django.db import models 16 | from imagekit.models import ImageSpecField 17 | from imagekit.processors import ResizeToFill 18 | class User(AbstractUser): 19 | """ 20 | 用户模型定义 21 | """ 22 | last_login_ip = models.GenericIPAddressField( 23 | "最近一次登陆IP", 24 | unpack_ipv4=True, 25 | blank=True, 26 | null=True 27 | ) 28 | ip_joined = models.GenericIPAddressField("注册IP", unpack_ipv4=True, blank=True, null=True) 29 | 30 | nickname = models.CharField("昵称", max_length=50, unique=True) 31 | mugshot = models.ImageField("头像", upload_to=user_mugshot_path) 32 | mugshot_thumbnail = ImageSpecField(source='mugshot', 33 | processors=[ResizeToFill(100, 100)], 34 | format='JPEG', 35 | options={'quality': 60}) 36 | 37 | def __str__(self): 38 | return self.username 39 | 40 | def save(self, *args, **kwargs): 41 | if not self.mugshot: 42 | avatar = Avatar(rows=10, columns=10) 43 | image_byte_array = avatar.get_image( 44 | string=self.username, 45 | width=480, 46 | height=480, 47 | pad=10 48 | ) 49 | self.mugshot.save('default_mugshot.png', ContentFile(image_byte_array), save=False) 50 | if not self.pk and not self.nickname: 51 | # 自动将username存入到nickname域内 52 | self.nickname = self.username 53 | super(User, self).save(*args, **kwargs) 54 | 55 | 56 | def update_last_login_ip(sender, user, request, **kwargs): 57 | """ 58 | 更新用户最后一次登陆的IP地址 59 | """ 60 | ip = get_ip_address_from_request(request) 61 | if ip: 62 | user.last_login_ip = ip 63 | user.save() 64 | 65 | 66 | user_logged_in.connect(update_last_login_ip) 67 | -------------------------------------------------------------------------------- /users/mugshot.py: -------------------------------------------------------------------------------- 1 | """ 2 | 
https://github.com/richardasaurus/randomavatar 3 | """ 4 | 5 | import random 6 | import math 7 | import hashlib 8 | from io import BytesIO 9 | from PIL import Image, ImageDraw 10 | 11 | 12 | class Avatar(object): 13 | def __init__(self, rows, columns): 14 | self.rows = rows 15 | self.cols = columns 16 | self._generate_colours() 17 | 18 | m = hashlib.md5() 19 | m.update(b"hello world") 20 | entropy = len(m.hexdigest()) / 2 * 8 21 | if self.rows > 15 or self.cols > 15: 22 | raise ValueError("Rows and columns must be valued 15 or under") 23 | 24 | self.digest = hashlib.md5 25 | self.digest_entropy = entropy 26 | 27 | def _generate_colours(self): 28 | colours_ok = False 29 | 30 | while colours_ok is False: 31 | self.fg_colour = self._get_pastel_colour() 32 | self.bg_colour = self._get_pastel_colour(lighten=80) 33 | 34 | # Get the luminance for each colour 35 | fg_lum = self._luminance(self.fg_colour) + 0.05 36 | bg_lum = self._luminance(self.bg_colour) + 0.05 37 | 38 | # Check the difference in luminance 39 | # meets the 1.25 threshold 40 | result = (fg_lum / bg_lum) \ 41 | if (fg_lum / bg_lum) else (bg_lum / fg_lum) 42 | if result > 1.20: 43 | colours_ok = True 44 | 45 | def get_image(self, string, width, height, pad=0): 46 | """ 47 | Byte representation of a PNG image 48 | """ 49 | hex_digest_byte_list = self._string_to_byte_list(string) 50 | matrix = self._create_matrix(hex_digest_byte_list) 51 | return self._create_image(matrix, width, height, pad) 52 | 53 | def save(self, image_byte_array=None, save_location=None): 54 | if image_byte_array and save_location: 55 | with open(save_location, 'wb') as f: 56 | return f.write(image_byte_array) 57 | else: 58 | raise ValueError('image_byte_array and path must be provided') 59 | 60 | def _get_pastel_colour(self, lighten=127): 61 | """ 62 | Create a pastel colour hex colour string 63 | """ 64 | def r(): 65 | return random.randint(0, 128) + 127 66 | return r(), r(), r() # return rgb values as a tuple 67 | 68 | def 
_luminance(self, rgb): 69 | """ 70 | Determine the liminanace of an RGB colour 71 | """ 72 | a = [] 73 | for v in rgb: 74 | v = v / float(255) 75 | if v < 0.03928: 76 | result = v / 12.92 77 | else: 78 | result = math.pow(((v + 0.055) / 1.055), 2.4) 79 | 80 | a.append(result) 81 | return a[0] * 0.2126 + a[1] * 0.7152 + a[2] * 0.0722 82 | 83 | def _string_to_byte_list(self, data): 84 | """ 85 | Creates a hex digest of the input string given to create the image, 86 | if it's not already hexadecimal 87 | Returns: 88 | Length 16 list of rgb value range integers 89 | (each representing a byte of the hex digest) 90 | """ 91 | bytes_length = 16 92 | 93 | m = self.digest() 94 | m.update(str.encode(data)) 95 | hex_digest = m.hexdigest() 96 | 97 | return list(int(hex_digest[num * 2:num * 2 + 2], bytes_length) 98 | for num in range(bytes_length)) 99 | 100 | def _bit_is_one(self, n, hash_bytes): 101 | """ 102 | Check if the n (index) of hash_bytes is 1 or 0. 103 | """ 104 | 105 | scale = 16 # hexadecimal 106 | 107 | if not hash_bytes[int(n / (scale / 2))] >> int( 108 | (scale / 2) - ((n % (scale / 2)) + 1)) & 1 == 1: 109 | return False 110 | return True 111 | 112 | def _create_image(self, matrix, width, height, pad): 113 | """ 114 | Generates a PNG byte list 115 | """ 116 | 117 | image = Image.new("RGB", (width + (pad * 2), 118 | height + (pad * 2)), self.bg_colour) 119 | image_draw = ImageDraw.Draw(image) 120 | 121 | # Calculate the block widht and height. 122 | block_width = width / self.cols 123 | block_height = height / self.rows 124 | 125 | # Loop through blocks in matrix, draw rectangles. 
126 | for row, cols in enumerate(matrix): 127 | for col, cell in enumerate(cols): 128 | if cell: 129 | image_draw.rectangle(( 130 | pad + col * block_width, # x1 131 | pad + row * block_height, # y1 132 | pad + (col + 1) * block_width - 1, # x2 133 | pad + (row + 1) * block_height - 1 # y2 134 | ), fill=self.fg_colour) 135 | 136 | stream = BytesIO() 137 | image.save(stream, format="png", optimize=True) 138 | # return the image byte data 139 | return stream.getvalue() 140 | 141 | def _create_matrix(self, byte_list): 142 | """ 143 | This matrix decides which blocks should be filled fg/bg colour 144 | True for fg_colour 145 | False for bg_colour 146 | hash_bytes - array of hash bytes values. RGB range values in each slot 147 | Returns: 148 | List representation of the matrix 149 | [[True, True, True, True], 150 | [False, True, True, False], 151 | [True, True, True, True], 152 | [False, False, False, False]] 153 | """ 154 | 155 | # Number of rows * cols halfed and rounded 156 | # in order to fill opposite side 157 | cells = int(self.rows * self.cols / 2 + self.cols % 2) 158 | 159 | matrix = [[False] * self.cols for num in range(self.rows)] 160 | 161 | for cell_number in range(cells): 162 | 163 | # If the bit with index corresponding to this cell is 1 164 | # mark that cell as fg_colour 165 | # Skip byte 1, that's used in determining fg_colour 166 | if self._bit_is_one(cell_number, byte_list[1:]): 167 | # Find cell coordinates in matrix. 
168 | x_row = cell_number % self.rows 169 | y_col = int(cell_number / self.cols) 170 | # Set coord True and its opposite side 171 | matrix[x_row][self.cols - y_col - 1] = True 172 | matrix[x_row][y_col] = True 173 | return matrix 174 | -------------------------------------------------------------------------------- /users/permissions.py: -------------------------------------------------------------------------------- 1 | from rest_framework import permissions 2 | 3 | 4 | class IsVerified(permissions.BasePermission): 5 | """ 6 | 设置主 email 时必须为已验证的 email 7 | """ 8 | 9 | def has_object_permission(self, request, view, obj): 10 | if obj.verified: 11 | return True 12 | return False 13 | 14 | 15 | class NotPrimary(permissions.BasePermission): 16 | """ 17 | 不能删除主 email 18 | """ 19 | 20 | def has_object_permission(self, request, view, obj): 21 | if obj.primary: 22 | return False 23 | return True 24 | -------------------------------------------------------------------------------- /users/serializers.py: -------------------------------------------------------------------------------- 1 | from allauth.account.adapter import get_adapter 2 | from allauth.account.models import EmailAddress 3 | from allauth.account.utils import setup_user_email 4 | from rest_auth.registration.serializers import RegisterSerializer 5 | from rest_framework import serializers 6 | from rest_framework.fields import CurrentUserDefault 7 | 8 | from .models import User 9 | from .utils import get_ip_address_from_request 10 | 11 | from .validators import FileValidator 12 | 13 | 14 | class UserDetailsSerializer(serializers.ModelSerializer): 15 | """ 16 | 用户详细信息的序列器 17 | """ 18 | mugshot_url = serializers.SerializerMethodField(source='mugshot_thumbnail.url') 19 | post_count = serializers.SerializerMethodField() 20 | reply_count = serializers.SerializerMethodField() 21 | mugshot = serializers.ImageField( 22 | validators=[FileValidator(max_size=2 * 1024 * 1024, allowed_extensions=('png', 'jpg', 'jpeg'))]) 23 | 
24 | class Meta: 25 | model = User 26 | read_only_fields = ( 27 | 'id', 28 | 'username', 29 | 'date_joined', 30 | 'mugshot_url', 31 | 'ip_joined', 32 | 'last_login_ip', 33 | 'is_superuser', 34 | 'is_staff', 35 | ) 36 | fields = ( 37 | 'id', 38 | 'username', 39 | 'nickname', 40 | 'email', 41 | 'mugshot', 42 | 'date_joined', 43 | 'mugshot_url', 44 | 'ip_joined', 45 | 'last_login_ip', 46 | 'is_superuser', 47 | 'is_staff', 48 | 'post_count', 49 | 'reply_count', 50 | ) 51 | 52 | def get_post_count(self, obj): 53 | """ 54 | 返回用户提交的帖子数量 55 | """ 56 | return obj.post_set.count() 57 | 58 | def get_mugshot_url(self, obj): 59 | return obj.mugshot_thumbnail.url 60 | 61 | def get_reply_count(self, obj): 62 | """ 63 | 返回用户提交的回复数量 64 | """ 65 | return obj.reply_comments.count() 66 | 67 | 68 | class UserRegistrationSerializer(RegisterSerializer): 69 | """ 70 | 继承至rest_auth的默认序列器,增加了昵称 71 | """ 72 | 73 | def save(self, request): 74 | """ 75 | 改写父类的save方法,检测并存入用户的注册IP地址 76 | """ 77 | adapter = get_adapter() 78 | user = adapter.new_user(request) 79 | self.cleaned_data = self.get_cleaned_data() 80 | ip = get_ip_address_from_request(request) 81 | if ip: 82 | user.ip_joined = ip 83 | adapter.save_user(request, user, self) 84 | self.custom_signup(request, user) 85 | setup_user_email(request, user, []) 86 | return user 87 | 88 | 89 | class EmailAddressSerializer(serializers.ModelSerializer): 90 | class Meta: 91 | model = EmailAddress 92 | fields = ( 93 | 'id', 94 | 'user', 95 | 'email', 96 | 'verified', 97 | 'primary', 98 | ) 99 | read_only_fields = ('id', 'user', 'verified', 'primary') 100 | 101 | 102 | class UserSimpleDetailsSerializer(serializers.ModelSerializer): 103 | class Meta: 104 | model = User 105 | read_only_fields = ( 106 | 'id', 107 | 'username', 108 | 'nickname', 109 | ) 110 | fields = ( 111 | 'id', 112 | 'username', 113 | 'nickname', 114 | ) 115 | -------------------------------------------------------------------------------- /users/tests/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/DjangoChinaOrg/Django-China-API/79a5d85fe88ba7784d08d370b8e7519f7274f208/users/tests/__init__.py -------------------------------------------------------------------------------- /users/tests/unit_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DjangoChinaOrg/Django-China-API/79a5d85fe88ba7784d08d370b8e7519f7274f208/users/tests/unit_tests/__init__.py -------------------------------------------------------------------------------- /users/tests/unit_tests/test_models.py: -------------------------------------------------------------------------------- 1 | from django.test import TestCase 2 | from django.test.client import RequestFactory 3 | 4 | from users.models import User, update_last_login_ip 5 | 6 | 7 | class UserSignalTests(TestCase): 8 | """ 9 | 用户信号函数的测试 10 | """ 11 | def setUp(self): 12 | """ 13 | 创建request工厂,创建测试用户 14 | """ 15 | super(UserSignalTests, self).setUp() 16 | self.rf = RequestFactory() 17 | self.user = User.objects.create(nickname='test-user', username='test-user') 18 | 19 | def test_update_last_login_ip(self): 20 | """ 21 | 测试当request里包含IP信息的时候,会被成功存入用户的last_login_ip域 22 | """ 23 | test_ip = '210.1.1.1' 24 | request = self.rf.get('/', REMOTE_ADDR=test_ip) 25 | update_last_login_ip(None, self.user, request) 26 | self.assertEqual(self.user.last_login_ip, test_ip) 27 | 28 | def test_update_last_login_ip__without_ip(self): 29 | """ 30 | 测试当request里不包含IP信息的时候,这个函数仍然会正常返回 31 | """ 32 | request = self.rf.get('/') 33 | update_last_login_ip(None, self.user, request) 34 | self.assertEqual(self.user.last_login_ip, None) 35 | -------------------------------------------------------------------------------- /users/tests/unit_tests/test_serializers.py: -------------------------------------------------------------------------------- 1 | from django.test import 
TestCase 2 | 3 | from users.models import User 4 | from users.serializers import UserDetailsSerializer 5 | 6 | 7 | class UserDetailsSerializerTests(TestCase): 8 | """ 9 | 用户详细信息序列器的测试 10 | """ 11 | def setUp(self): 12 | """ 13 | 创建测试用户 14 | """ 15 | super(UserDetailsSerializerTests, self).setUp() 16 | self.user = User.objects.create(nickname='test-user', username='test-user') 17 | 18 | def test_serialize_user_details(self): 19 | """ 20 | 测试序列化用户实例 21 | """ 22 | serializer = UserDetailsSerializer(self.user) 23 | serialized_data = serializer.data 24 | self.assertEqual(serialized_data['id'], self.user.id) 25 | self.assertEqual(serialized_data['username'], self.user.username) 26 | self.assertEqual(serialized_data['nickname'], self.user.nickname) 27 | self.assertEqual(serialized_data['ip_joined'], self.user.ip_joined) 28 | self.assertEqual(serialized_data['last_login_ip'], self.user.last_login_ip) 29 | self.assertEqual(serialized_data['is_superuser'], self.user.is_superuser) 30 | self.assertEqual(serialized_data['is_staff'], self.user.is_staff) 31 | self.assertEqual(serialized_data['post_count'], self.user.post_set.count()) 32 | self.assertEqual(serialized_data['reply_count'], self.user.reply_comments.count()) 33 | -------------------------------------------------------------------------------- /users/tests/unit_tests/test_utils.py: -------------------------------------------------------------------------------- 1 | from django.test import TestCase 2 | from django.test.client import RequestFactory 3 | 4 | from users.utils import get_ip_address_from_request 5 | 6 | 7 | class GetIpAddressTests(TestCase): 8 | """ 9 | get_ip_address_from_request函数的测试 10 | """ 11 | def setUp(self): 12 | """ 13 | 创建request工厂,创建测试用户 14 | """ 15 | super(GetIpAddressTests, self).setUp() 16 | self.rf = RequestFactory() 17 | 18 | def test_get_ip_address_from_request(self): 19 | """ 20 | 测试当request里包含IP信息的时候,会正确返回IP地址 21 | """ 22 | test_ip = '210.1.1.1' 23 | request = self.rf.get('/', 
REMOTE_ADDR=test_ip) 24 | ip_address = get_ip_address_from_request(request) 25 | self.assertEqual(ip_address, test_ip) 26 | 27 | def test_get_ip_address_from_request__without_ip(self): 28 | """ 29 | 测试当request里不包含IP信息的时候,函数仍然会正常返回 30 | """ 31 | request = self.rf.get('/') 32 | ip_address = get_ip_address_from_request(request) 33 | self.assertEqual(ip_address, None) 34 | -------------------------------------------------------------------------------- /users/tests/unit_tests/test_views.py: -------------------------------------------------------------------------------- 1 | from allauth.account.models import EmailAddress 2 | from django.contrib.contenttypes.models import ContentType 3 | from django.contrib.sites.models import Site 4 | from django.utils.timezone import timedelta 5 | from rest_framework import status, test 6 | from rest_framework.reverse import reverse 7 | 8 | from balance.models import Record 9 | from posts.models import Post 10 | from replies.models import Reply 11 | from replies.serializers import FlatReplySerializer 12 | from ...models import User 13 | 14 | 15 | class UserViewSetTestCase(test.APITestCase): 16 | def setUp(self): 17 | self.user = User.objects.create_user( 18 | username='test', 19 | email='test@test.com', 20 | password='test', 21 | nickname='test' 22 | ) 23 | 24 | self.another_user = User.objects.create_user( 25 | username='another', 26 | email='another@test.com', 27 | password='another', 28 | nickname='another' 29 | ) 30 | 31 | self.site = Site.objects.create(name='test', domain='test.com') 32 | self.post = Post.objects.create( 33 | title='test title', 34 | author=self.another_user 35 | ) 36 | self.post_ct = ContentType.objects.get_for_model(self.post) 37 | self.post_id = self.post.id 38 | 39 | def test_return_user_replies(self): 40 | reply1 = Reply.objects.create( 41 | content_type=self.post_ct, 42 | object_pk=self.post_id, 43 | site=self.site, 44 | user=self.user, 45 | comment='reply1', 46 | ) 47 | reply2 = Reply.objects.create( 48 | 
content_type=self.post_ct, 49 | object_pk=self.post_id, 50 | site=self.site, 51 | user=self.user, 52 | comment='reply2', 53 | parent=reply1 54 | ) 55 | reply3 = Reply.objects.create( 56 | content_type=self.post_ct, 57 | object_pk=self.post_id, 58 | site=self.site, 59 | user=self.another_user, 60 | comment='reply3', 61 | parent=reply1 62 | ) 63 | 64 | url = reverse('user-replies', kwargs={'pk': self.user.id}) 65 | response = self.client.get(url) 66 | self.assertEqual(response.status_code, status.HTTP_200_OK) 67 | self.assertEqual( 68 | response.data['data'], 69 | FlatReplySerializer( 70 | self.user.reply_comments.filter( 71 | is_public=True, is_removed=False), 72 | many=True, context={'request': response.wsgi_request} 73 | ).data 74 | ) 75 | 76 | def test_anonymous_user_cannot_checkin(self): 77 | url = reverse('user-checkin', kwargs={'pk': self.user.id}) 78 | response = self.client.post(url) 79 | self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) 80 | 81 | def test_current_request_user_can_checkin(self): 82 | url = reverse('user-checkin', kwargs={'pk': self.user.id}) 83 | self.client.login(username='test', password='test') 84 | response = self.client.post(url) 85 | 86 | self.assertEqual(response.status_code, status.HTTP_201_CREATED) 87 | self.assertEqual(Record.objects.count(), 1) 88 | self.assertEqual(Record.objects.get().user, self.user) 89 | 90 | def test_non_current_request_user_cannot_checkin(self): 91 | url = reverse('user-checkin', kwargs={'pk': self.another_user.id}) 92 | self.client.login(username='test', password='test') 93 | response = self.client.post(url) 94 | 95 | self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) 96 | 97 | def test_user_can_only_chekin_once_per_day(self): 98 | url = reverse('user-checkin', kwargs={'pk': self.user.id}) 99 | self.client.login(username='test', password='test') 100 | response = self.client.post(url) 101 | self.assertEqual(response.status_code, status.HTTP_201_CREATED) 102 | 
self.assertEqual(Record.objects.count(), 1) 103 | self.assertEqual(Record.objects.get().user, self.user) 104 | 105 | # 再一次签到 106 | response = self.client.post(url) 107 | self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) 108 | 109 | def test_user_can_checkin_after_a_day(self): 110 | record = Record.objects.create( 111 | reward_type=0, 112 | coin_type=2, 113 | amount=10, 114 | description='', 115 | user=self.user, 116 | ) 117 | record.created_time = record.created_time - timedelta(days=1) 118 | record.save() 119 | record.refresh_from_db() 120 | 121 | url = reverse('user-checkin', kwargs={'pk': self.user.id}) 122 | self.client.login(username='test', password='test') 123 | response = self.client.post(url) 124 | self.assertEqual(response.status_code, status.HTTP_201_CREATED) 125 | self.assertEqual(Record.objects.count(), 2) 126 | 127 | def test_user_can_get_non_hidden_posts(self): 128 | Post.objects.create( 129 | title='test', 130 | author=self.user, 131 | ) 132 | 133 | Post.objects.create( 134 | title='test2', 135 | author=self.user, 136 | ) 137 | 138 | Post.objects.create( 139 | title='test3', 140 | author=self.user, 141 | ) 142 | 143 | Post.objects.create( 144 | title='test4', 145 | author=self.user, 146 | hidden=True 147 | ) 148 | 149 | url = reverse('user-posts', kwargs={'pk': self.user.id}) 150 | response = self.client.get(url) 151 | self.assertEqual(response.status_code, status.HTTP_200_OK) 152 | self.assertEqual(len(response.data['data']), 3) 153 | 154 | def test_user_can_only_get_self_hidden_posts(self): 155 | Post.objects.create( 156 | title='test', 157 | author=self.user, 158 | ) 159 | 160 | Post.objects.create( 161 | title='test', 162 | author=self.user, 163 | hidden=True 164 | ) 165 | 166 | Post.objects.create( 167 | title='test', 168 | author=self.another_user, 169 | hidden=True 170 | ) 171 | 172 | url = reverse('user-posts', kwargs={'pk': self.user.id}) 173 | response = self.client.get(url, {'hidden': 'true'}) 174 | 
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) 175 | 176 | self.client.login(username='test', password='test') 177 | response = self.client.get(url, {'hidden': 'true'}) 178 | self.assertEqual(len(response.data['data']), 1) 179 | 180 | other_url = reverse('user-posts', kwargs={'pk': self.another_user.id}) 181 | response = self.client.get(other_url, {'hidden': 'true'}) 182 | self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) 183 | 184 | def test_can_get_user_treasure(self): 185 | Record.objects.create( 186 | reward_type=0, 187 | coin_type=2, 188 | amount=10, 189 | user=self.user, 190 | ) 191 | Record.objects.create( 192 | reward_type=0, 193 | coin_type=2, 194 | amount=25, 195 | user=self.user, 196 | ) 197 | Record.objects.create( 198 | reward_type=0, 199 | coin_type=1, 200 | amount=10, 201 | user=self.user, 202 | ) 203 | Record.objects.create( 204 | reward_type=0, 205 | coin_type=1, 206 | amount=35, 207 | user=self.user, 208 | ) 209 | Record.objects.create( 210 | reward_type=0, 211 | coin_type=0, 212 | amount=35, 213 | user=self.user, 214 | ) 215 | Record.objects.create( 216 | reward_type=0, 217 | coin_type=0, 218 | amount=35, 219 | user=self.another_user, 220 | ) 221 | url = reverse('user-balance', kwargs={'pk': self.user.id}) 222 | response = self.client.get(url) 223 | self.assertEqual(response.status_code, status.HTTP_200_OK) 224 | balance_data = list(response.data) 225 | self.assertTrue({'coin_type': 0, 'amount__sum': 35} in balance_data) 226 | self.assertTrue({'coin_type': 1, 'amount__sum': 45} in balance_data) 227 | self.assertTrue({'coin_type': 2, 'amount__sum': 35} in balance_data) 228 | 229 | 230 | class EmailAddressViewSetTestCase(test.APITestCase): 231 | def setUp(self): 232 | self.user = User.objects.create_user( 233 | username='test', 234 | email='test@test.com', 235 | password='test', 236 | nickname='test' 237 | ) 238 | 239 | self.another_user = User.objects.create_user( 240 | username='another', 241 | 
email='another@test.com', 242 | password='another', 243 | nickname='another' 244 | ) 245 | 246 | self.email = EmailAddress.objects.create( 247 | user=self.user, 248 | email=self.user.email, 249 | verified=True, 250 | primary=True 251 | ) 252 | 253 | self.unverified_email = EmailAddress.objects.create( 254 | user=self.user, 255 | email='unverified@test.com', 256 | verified=False, 257 | primary=False 258 | ) 259 | 260 | self.another_user_email = EmailAddress.objects.create( 261 | user=self.another_user, 262 | email=self.another_user.email, 263 | verified=True, 264 | primary=True 265 | ) 266 | 267 | def test_anonymous_user_cannot_operate_email(self): 268 | list_url = reverse('email-list') 269 | retrieve_url = reverse('email-detail', kwargs={'pk': self.email.id}) 270 | set_primary_url = reverse( 271 | 'email-set-primary', kwargs={'pk': self.email.id}) 272 | reverify_url = reverse('email-reverify', kwargs={'pk': self.email.id}) 273 | 274 | response = self.client.get(list_url) 275 | self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) 276 | 277 | response = self.client.get(retrieve_url) 278 | self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) 279 | 280 | response = self.client.post(list_url, data={ 281 | 'user': self.user, 282 | 'email': 'new@email.com', 283 | 'verified': True, 284 | 'primary': True 285 | }) 286 | self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) 287 | 288 | response = self.client.delete( 289 | list_url, data={'email': self.user.email}) 290 | self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) 291 | 292 | response = self.client.post(set_primary_url) 293 | self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) 294 | 295 | response = self.client.get(reverify_url) 296 | self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) 297 | 298 | def test_user_can_get_self_email(self): 299 | url = reverse('email-list') 300 | self.client.login(username='test', 
password='test') 301 | response = self.client.get(url) 302 | 303 | self.assertEqual(response.status_code, status.HTTP_200_OK) 304 | self.assertEqual(len(response.data), 7) 305 | self.assertTrue( 306 | all([ret['user'] == self.user.id for ret in response.data['data']]) 307 | ) 308 | 309 | def test_user_cannnot_get_others_email(self): 310 | url = reverse('email-detail', 311 | kwargs={'pk': self.another_user_email.id}) 312 | self.client.login(username='test', password='test') 313 | response = self.client.get(url) 314 | 315 | self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) 316 | 317 | def test_user_can_add_email(self): 318 | url = reverse('email-list') 319 | self.client.login(username='test', password='test') 320 | response = self.client.post(url, data={ 321 | 'user': self.user, 322 | 'email': 'new@email.com', 323 | 'verified': True, 324 | 'primary': True 325 | }) 326 | self.assertEqual(response.status_code, status.HTTP_201_CREATED) 327 | self.assertEqual(self.user.emailaddress_set.count(), 3) 328 | 329 | def test_user_cannot_set_unverified_email_to_primary(self): 330 | url = reverse('email-set-primary', 331 | kwargs={'pk': self.unverified_email.id}) 332 | self.client.login(username='test', password='test') 333 | response = self.client.post(url) 334 | self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) 335 | 336 | def test_user_can_set_verified_email_to_primary(self): 337 | verified_unprimary_email = EmailAddress.objects.create( 338 | user=self.user, 339 | email='verified_unprimary_email@test.com', 340 | verified=True, 341 | primary=False 342 | ) 343 | url = reverse('email-set-primary', 344 | kwargs={'pk': verified_unprimary_email.id}) 345 | self.client.login(username='test', password='test') 346 | response = self.client.post(url) 347 | self.assertEqual(response.status_code, status.HTTP_201_CREATED) 348 | 349 | # 新的 primary email 设置成功 350 | new_primary_email = EmailAddress.objects.get( 351 | pk=verified_unprimary_email.id) 352 | 
self.assertTrue(new_primary_email.primary) 353 | 354 | # 旧的 primary email 被设置为非 primary email 355 | old_primary_email = EmailAddress.objects.get(pk=self.email.id) 356 | self.assertFalse(old_primary_email.primary) 357 | 358 | def test_can_delete_non_primary_email(self): 359 | url = reverse('email-detail', kwargs={'pk': self.unverified_email.id}) 360 | self.client.login(username='test', password='test') 361 | response = self.client.delete(url) 362 | self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) 363 | self.assertEqual(self.user.emailaddress_set.count(), 1) 364 | 365 | def test_user_cannot_delete_primary_email(self): 366 | url = reverse('email-detail', kwargs={'pk': self.email.id}) 367 | self.client.login(username='test', password='test') 368 | response = self.client.delete(url, data={'email': self.email.email}) 369 | self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) 370 | self.assertEqual(self.user.emailaddress_set.count(), 2) 371 | 372 | def test_user_cannot_delete_others_email(self): 373 | url = reverse('email-detail', 374 | kwargs={'pk': self.another_user_email.id}) 375 | self.client.login(username='test', password='test') 376 | response = self.client.delete(url) 377 | self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) 378 | self.assertEqual(self.another_user.emailaddress_set.count(), 1) 379 | 380 | def test_user_can_reverify_email(self): 381 | url = reverse('email-reverify', kwargs={'pk': self.email.id}) 382 | self.client.login(username='test', password='test') 383 | response = self.client.get(url) 384 | self.assertEqual(response.status_code, status.HTTP_200_OK) 385 | -------------------------------------------------------------------------------- /users/utils.py: -------------------------------------------------------------------------------- 1 | from django.conf import settings 2 | from ipware import get_client_ip 3 | 4 | 5 | def get_ip_address_from_request(request): 6 | """ 7 | 返回request里的IP地址 8 | 提示: 9 | 
为了开发方便,这个函数会返回类似127.0.0.1之类无法在公网被路由的地址, 10 | 在生产环境中,类似地址不会被返回 11 | """ 12 | ip, is_routable = get_client_ip(request) 13 | if settings.DEBUG: 14 | return ip 15 | else: 16 | if ip is not None and is_routable: 17 | return ip 18 | return None 19 | -------------------------------------------------------------------------------- /users/validators.py: -------------------------------------------------------------------------------- 1 | # take from https://gist.github.com/jrosebr1/2140738 2 | 3 | # @brief 4 | # Performs file upload validation for django. The original version implemented 5 | # by dokterbob had some problems with determining the correct mimetype and 6 | # determining the size of the file uploaded (at least within my Django application 7 | # that is). 8 | 9 | # @author dokterbob 10 | # @author jrosebr1 11 | 12 | import mimetypes 13 | from os.path import splitext 14 | 15 | from django.core.exceptions import ValidationError 16 | from django.template.defaultfilters import filesizeformat 17 | from django.utils.deconstruct import deconstructible 18 | from django.utils.translation import ugettext_lazy as _ 19 | 20 | 21 | @deconstructible 22 | class FileValidator(object): 23 | """ 24 | Validator for files, checking the size, extension and mimetype. 25 | 26 | Initialization parameters: 27 | allowed_extensions: iterable with allowed file extensions 28 | ie. ('txt', 'doc') 29 | allowd_mimetypes: iterable with allowed mimetypes 30 | ie. ('image/png', ) 31 | min_size: minimum number of bytes allowed 32 | ie. 100 33 | max_size: maximum number of bytes allowed 34 | ie. 24*1024*1024 for 24 MB 35 | 36 | Usage example:: 37 | 38 | MyModel(models.Model): 39 | myfile = FileField(validators=FileValidator(max_size=24*1024*1024), ...) 40 | 41 | """ 42 | 43 | extension_message = _("Extension '%(extension)s' not allowed. Allowed extensions are: '%(allowed_extensions)s.'") 44 | mime_message = _("MIME type '%(mimetype)s' is not valid. 
Allowed types are: %(allowed_mimetypes)s.") 45 | min_size_message = _('The current file %(size)s, which is too small. The minumum file size is %(allowed_size)s.') 46 | max_size_message = _('The current file %(size)s, which is too large. The maximum file size is %(allowed_size)s.') 47 | 48 | def __init__(self, *args, **kwargs): 49 | self.allowed_extensions = kwargs.pop('allowed_extensions', None) 50 | self.allowed_mimetypes = kwargs.pop('allowed_mimetypes', None) 51 | self.min_size = kwargs.pop('min_size', 0) 52 | self.max_size = kwargs.pop('max_size', None) 53 | 54 | def __call__(self, value): 55 | """ 56 | Check the extension, content type and file size. 57 | """ 58 | 59 | # Check the extension 60 | ext = splitext(value.name)[1][1:].lower() 61 | if self.allowed_extensions and not ext in self.allowed_extensions: 62 | message = self.extension_message % { 63 | 'extension': ext, 64 | 'allowed_extensions': ', '.join(self.allowed_extensions) 65 | } 66 | 67 | raise ValidationError(message) 68 | 69 | # Check the content type 70 | mimetype = mimetypes.guess_type(value.name)[0] 71 | if self.allowed_mimetypes and not mimetype in self.allowed_mimetypes: 72 | message = self.mime_message % { 73 | 'mimetype': mimetype, 74 | 'allowed_mimetypes': ', '.join(self.allowed_mimetypes) 75 | } 76 | 77 | raise ValidationError(message) 78 | 79 | # Check the file size 80 | filesize = len(value) 81 | if self.max_size and filesize > self.max_size: 82 | message = self.max_size_message % { 83 | 'size': filesizeformat(filesize), 84 | 'allowed_size': filesizeformat(self.max_size) 85 | } 86 | 87 | raise ValidationError(message) 88 | 89 | elif filesize < self.min_size: 90 | message = self.min_size_message % { 91 | 'size': filesizeformat(filesize), 92 | 'allowed_size': filesizeformat(self.min_size) 93 | } 94 | 95 | raise ValidationError(message) 96 | -------------------------------------------------------------------------------- /users/views.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | from allauth.account.views import ConfirmEmailView as AllAuthConfirmEmailView 5 | from allauth.socialaccount.providers.github.views import GitHubOAuth2Adapter 6 | from allauth.socialaccount.providers.oauth2.client import OAuth2Client 7 | from django.conf import settings 8 | from django.contrib.sites.shortcuts import get_current_site 9 | from django.db.models import Sum 10 | from django.utils import timezone 11 | from rest_auth.registration.views import ( 12 | LoginView, RegisterView, SocialConnectView, SocialLoginView) 13 | from rest_framework import mixins, permissions, status, views, viewsets 14 | from rest_framework.decorators import action 15 | from rest_framework.parsers import FileUploadParser 16 | from rest_framework.permissions import AllowAny 17 | from rest_framework.response import Response 18 | 19 | from balance.models import Record 20 | from balance.permissions import IsCurrentUser, OncePerDay 21 | from balance.serializers import BalanceSerializer 22 | from posts.serializers import IndexPostListSerializer 23 | from replies.serializers import FlatReplySerializer 24 | 25 | from .models import User 26 | from .permissions import IsVerified, NotPrimary 27 | from .serializers import EmailAddressSerializer, UserDetailsSerializer 28 | 29 | 30 | class RegisterViewCustom(RegisterView): 31 | """ 32 | 注册视图取消authentication_class一次避免CSRF校验 33 | """ 34 | authentication_classes = () 35 | 36 | 37 | class LoginViewCustom(LoginView): 38 | """ 39 | 登陆视图取消authentication_class以此避免CSRF校验 40 | """ 41 | authentication_classes = () 42 | 43 | 44 | class ConfirmEmailView(AllAuthConfirmEmailView): 45 | template_name = 'account/email_confirm.html' 46 | 47 | 48 | class GitHubLogin(SocialLoginView): 49 | authentication_classes = () 50 | adapter_class = GitHubOAuth2Adapter 51 | client_class = OAuth2Client 52 | callback_url = getattr(settings, 
'SOCIAL_LOGIN_GITHUB_CALLBACK_URL') 53 | 54 | 55 | class GitHubConnect(SocialConnectView): 56 | adapter_class = GitHubOAuth2Adapter 57 | client_class = OAuth2Client 58 | callback_url = getattr(settings, 'SOCIAL_LOGIN_GITHUB_CALLBACK_URL') 59 | 60 | 61 | class UserViewSets( 62 | mixins.RetrieveModelMixin, 63 | mixins.UpdateModelMixin, 64 | viewsets.GenericViewSet 65 | ): 66 | queryset = User.objects.all() 67 | # TODO: 用户的email等隐私信息需要特殊处理 68 | permission_classes = [AllowAny, ] 69 | serializer_class = UserDetailsSerializer 70 | lookup_value_regex = '[0-9]+' 71 | 72 | def get_permissions(self): 73 | if self.action in ['update', 'partial_update']: 74 | return [permissions.IsAuthenticated(), IsCurrentUser()] 75 | return super().get_permissions() 76 | 77 | @action(methods=['get'], detail=True, serializer_class=FlatReplySerializer) 78 | def replies(self, request, pk=None): 79 | user = self.get_object() 80 | replies = user.reply_comments.filter(is_public=True, is_removed=False) 81 | page = self.paginate_queryset(replies) 82 | if page is not None: 83 | serializer = self.get_serializer(page, many=True, context={'request': request}) 84 | return self.get_paginated_response(serializer.data) 85 | 86 | serializer = self.get_serializer(replies, many=True, context={'request': request}) 87 | return Response(serializer.data) 88 | 89 | @action(methods=['get'], detail=True, serializer_class=IndexPostListSerializer) 90 | def posts(self, request, pk=None): 91 | user = self.get_object() 92 | posts = user.post_set.all() 93 | hidden = self.request.query_params.get('hidden') 94 | if hidden: 95 | if not self.request.user.is_authenticated: 96 | return Response(status=status.HTTP_401_UNAUTHORIZED) 97 | # 只有用户自己可以查看被隐藏的帖子 98 | if user != self.request.user: 99 | return Response(status=status.HTTP_403_FORBIDDEN) 100 | page = self.paginate_queryset(posts.filter(hidden=True)) 101 | if page is not None: 102 | serializer = self.get_serializer(page, many=True, context={'request': request}) 103 | return 
self.get_paginated_response(serializer.data) 104 | 105 | serializer = self.get_serializer( 106 | posts.filter(hidden=True), 107 | many=True, 108 | context={'request': request} 109 | ) 110 | return Response(serializer.data) 111 | 112 | page = self.paginate_queryset(posts.filter(hidden=False)) 113 | if page is not None: 114 | serializer = self.get_serializer(page, many=True, context={'request': request}) 115 | return self.get_paginated_response(serializer.data) 116 | serializer = self.get_serializer( 117 | posts.filter(hidden=False), 118 | many=True, 119 | context={'request': request} 120 | ) 121 | return Response(serializer.data) 122 | 123 | @action( 124 | methods=['post'], 125 | detail=True, 126 | permission_classes=[permissions.IsAuthenticated, OncePerDay, IsCurrentUser], 127 | ) 128 | def checkin(self, request, pk=None): 129 | user = self.get_object() 130 | 131 | # 生成随机奖励 132 | random_amount = abs(random.gauss(10, 5)) 133 | random_amount = math.ceil(random_amount) 134 | 135 | if random_amount == 0: 136 | random_amount += 1 137 | 138 | record = Record.objects.create( 139 | reward_type=0, 140 | coin_type=2, 141 | amount=random_amount, 142 | description='', 143 | user=user 144 | ) 145 | serializer = BalanceSerializer(record) 146 | return Response(serializer.data, status=status.HTTP_201_CREATED) 147 | 148 | @action(methods=['get'], detail=True) 149 | def balance(self, request, pk=None): 150 | user = self.get_object() 151 | user_treasure = user.record_set.values('coin_type').annotate(Sum('amount')) 152 | return Response(user_treasure) 153 | 154 | @action(methods=['get'], detail=True, 155 | permission_classes=[permissions.IsAuthenticated, IsCurrentUser], ) 156 | def checked(self, request, pk=None): 157 | user = self.get_object() 158 | today_start = timezone.now().replace(hour=0, minute=0, second=0) 159 | today_end = timezone.now().replace(hour=23, minute=59, second=59) 160 | checked = user.record_set.filter(created_time__gt=today_start, 
created_time__lt=today_end).exists() 161 | return Response({'checked': checked}) 162 | 163 | 164 | class MugshotUploadView(views.APIView): 165 | permission_classes = [permissions.IsAuthenticated] 166 | parser_classes = (FileUploadParser,) 167 | 168 | def post(self, request, filename): 169 | if 'file' not in request.FILES: 170 | return Response({ 171 | 'file': 'No avatar file selected.' 172 | }, status=status.HTTP_400_BAD_REQUEST) 173 | file_obj = request.FILES['file'] 174 | 175 | limit_kb = 2048 176 | if file_obj.size > limit_kb * 1024: 177 | return Response({ 178 | 'file': 'File size is too large.' 179 | }, status=status.HTTP_400_BAD_REQUEST) 180 | user = request.user 181 | user.mugshot.name 182 | user.mugshot.save(filename, file_obj) 183 | user.save() 184 | user.refresh_from_db() 185 | return Response({'mugshot_url': user.mugshot.url}, status=200) 186 | 187 | 188 | class EmailAddressViewSet(mixins.ListModelMixin, 189 | mixins.RetrieveModelMixin, 190 | mixins.CreateModelMixin, 191 | mixins.DestroyModelMixin, 192 | viewsets.GenericViewSet): 193 | serializer_class = EmailAddressSerializer 194 | permission_classes = [permissions.IsAuthenticated] 195 | 196 | def get_permissions(self): 197 | if self.action == 'destroy': 198 | return [permissions.IsAuthenticated(), NotPrimary()] 199 | else: 200 | return super(EmailAddressViewSet, self).get_permissions() 201 | 202 | def get_queryset(self): 203 | return self.request.user.emailaddress_set.all() 204 | 205 | def perform_create(self, serializer): 206 | email = serializer.save(user=self.request.user) 207 | email.send_confirmation(request=self.request) 208 | 209 | @action(methods=['post'], detail=True, 210 | permission_classes=[permissions.IsAuthenticated, IsVerified]) 211 | def set_primary(self, request, pk=None): 212 | email = self.get_object() 213 | success = email.set_as_primary() 214 | 215 | if success: 216 | return Response(status=status.HTTP_201_CREATED) 217 | else: 218 | return 
Response(status=status.HTTP_400_BAD_REQUEST) 219 | 220 | @action(methods=['get'], detail=True, 221 | permission_classes=[permissions.IsAuthenticated]) 222 | def reverify(self, request, pk=None): 223 | email = self.get_object() 224 | email.send_confirmation(request=self.request) 225 | 226 | return Response(status=status.HTTP_200_OK) 227 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DjangoChinaOrg/Django-China-API/79a5d85fe88ba7784d08d370b8e7519f7274f208/utils/__init__.py -------------------------------------------------------------------------------- /utils/mixins.py: -------------------------------------------------------------------------------- 1 | class EagerLoaderMixin(object): 2 | """ 3 | 这个mixin包含一个通用的方法,可以通过select_related和prefetch_related提前告知Django 4 | 需要加载的外键信息,任何需要跨表查询的serializer都应该绑定这个mixin,并且在使用的场景里 5 | (通常都是view)使用这个方法加入查询的外键名 6 | """ 7 | @staticmethod 8 | def setup_eager_loading(queryset, select_related=None, prefetch_related=None): 9 | if select_related: 10 | queryset = queryset.select_related(*select_related) 11 | if prefetch_related: 12 | queryset = queryset.prefetch_related(*prefetch_related) 13 | return queryset 14 | -------------------------------------------------------------------------------- /utils/rest_tools.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | from rest_framework.pagination import PageNumberPagination 4 | from rest_framework.response import Response 5 | from rest_framework.exceptions import NotFound 6 | from django.core.paginator import InvalidPage 7 | from django.utils import six 8 | 9 | 10 | class CustomPageNumberPagination(PageNumberPagination): 11 | page_size = 20 12 | page_size_query_param = 'page_size' 13 | page_query_param = 'page' 14 | max_page_size = 100 15 | 16 | def 
paginate_queryset(self, queryset, request, view=None): 17 | """ 18 | Update self.page_size, if page_size_query_param appears in the request url 19 | 20 | Paginate a queryset if required, either returning a 21 | page object, or `None` if pagination is not configured for this view. 22 | """ 23 | page_size = self.get_page_size(request) 24 | 25 | # this is the only difference from the PageNumberPagination's paginate_queryset function 26 | if not page_size: 27 | return None 28 | else: 29 | self.page_size = page_size 30 | 31 | paginator = self.django_paginator_class(queryset, page_size) 32 | page_number = request.query_params.get(self.page_query_param, 1) 33 | if page_number in self.last_page_strings: 34 | page_number = paginator.num_pages 35 | 36 | try: 37 | self.page = paginator.page(page_number) 38 | except InvalidPage as exc: 39 | msg = self.invalid_page_message.format( 40 | page_number=page_number, message=six.text_type(exc) 41 | ) 42 | raise NotFound(msg) 43 | 44 | if paginator.num_pages > 1 and self.template is not None: 45 | # The browsable API should display pagination controls. 46 | self.display_page_controls = True 47 | 48 | self.request = request 49 | return list(self.page) 50 | 51 | def get_paginated_response(self, data): 52 | return Response(OrderedDict([ 53 | ('page_size', self.page_size), 54 | ('current_page', self.page.number), 55 | ('last_page', self.page.paginator.num_pages), 56 | ('next', self.get_next_link()), 57 | ('previous', self.get_previous_link()), 58 | ('count', self.page.paginator.count), 59 | ('data', data) 60 | ])) 61 | --------------------------------------------------------------------------------