├── .github ├── FUNDING.yml └── workflows │ └── build.yml ├── .gitignore ├── .gitmodules ├── .travis.yml ├── Dockerfile ├── LICENSE ├── README.md ├── bootstrap ├── app.go ├── init.go ├── script.go └── static.go ├── build.sh ├── go.mod ├── go.sum ├── main.go ├── middleware ├── auth.go ├── auth_test.go ├── frontend.go ├── frontend_test.go ├── mock.go ├── mock_test.go ├── option.go ├── option_test.go ├── session.go ├── session_test.go ├── share.go └── share_test.go ├── models ├── download.go ├── download_test.go ├── file.go ├── file_test.go ├── folder.go ├── folder_test.go ├── group.go ├── group_test.go ├── init.go ├── migration.go ├── migration_test.go ├── policy.go ├── policy_test.go ├── scripts │ ├── invoker.go │ ├── invoker_test.go │ ├── storage.go │ └── storage_test.go ├── setting.go ├── setting_test.go ├── share.go ├── share_test.go ├── tag.go ├── tag_test.go ├── task.go ├── task_test.go ├── user.go ├── user_authn.go ├── user_authn_test.go ├── user_test.go ├── webdav.go └── webdav_test.go ├── pkg ├── aria2 │ ├── aria2.go │ ├── aria2_test.go │ ├── caller.go │ ├── caller_test.go │ ├── monitor.go │ ├── monitor_test.go │ ├── notification.go │ ├── notification_test.go │ └── rpc │ │ ├── README.md │ │ ├── call.go │ │ ├── client.go │ │ ├── const.go │ │ ├── json2.go │ │ ├── notification.go │ │ ├── proc.go │ │ ├── proto.go │ │ └── resp.go ├── auth │ ├── auth.go │ ├── auth_test.go │ ├── hmac.go │ └── hmac_test.go ├── authn │ ├── auth.go │ └── auth_test.go ├── cache │ ├── driver.go │ ├── driver_test.go │ ├── memo.go │ ├── memo_test.go │ ├── redis.go │ └── redis_test.go ├── conf │ ├── conf.go │ ├── conf_test.go │ ├── defaults.go │ └── version.go ├── crontab │ ├── collect.go │ └── init.go ├── email │ ├── init.go │ ├── mail.go │ ├── smtp.go │ └── template.go ├── filesystem │ ├── archive.go │ ├── archive_test.go │ ├── driver │ │ ├── cos │ │ │ ├── handler.go │ │ │ └── scf.go │ │ ├── local │ │ │ ├── file.go │ │ │ ├── file_test.go │ │ │ ├── handler.go │ │ │ └── handler_test.go │ │ 
├── onedrive │ │ │ ├── api.go │ │ │ ├── api_test.go │ │ │ ├── client.go │ │ │ ├── client_test.go │ │ │ ├── handler.go │ │ │ ├── handler_test.go │ │ │ ├── handller_test.go │ │ │ ├── oauth.go │ │ │ ├── oauth_test.go │ │ │ ├── options.go │ │ │ └── types.go │ │ ├── oss │ │ │ ├── callback.go │ │ │ ├── callback_test.go │ │ │ ├── handler.go │ │ │ └── handler_test.go │ │ ├── qiniu │ │ │ └── handler.go │ │ ├── remote │ │ │ ├── handler.go │ │ │ └── handler_test.go │ │ ├── s3 │ │ │ └── handler.go │ │ ├── template │ │ │ └── handler.go │ │ └── upyun │ │ │ └── handler.go │ ├── errors.go │ ├── file.go │ ├── file_test.go │ ├── filesystem.go │ ├── filesystem_test.go │ ├── fsctx │ │ └── context.go │ ├── hooks.go │ ├── hooks_test.go │ ├── image.go │ ├── image_test.go │ ├── manage.go │ ├── manage_test.go │ ├── path.go │ ├── path_test.go │ ├── response │ │ └── common.go │ ├── tests │ │ ├── file1.txt │ │ ├── file2.txt │ │ └── test.zip │ ├── upload.go │ ├── upload_test.go │ ├── validator.go │ └── validator_test.go ├── hashid │ ├── hash.go │ └── hash_test.go ├── recaptcha │ └── recaptcha.go ├── request │ ├── request.go │ ├── request_test.go │ ├── slave.go │ └── slave_test.go ├── serializer │ ├── aria2.go │ ├── aria2_test.go │ ├── auth.go │ ├── auth_test.go │ ├── error.go │ ├── setting.go │ ├── setting_test.go │ ├── share.go │ ├── share_test.go │ ├── slave.go │ ├── upload.go │ ├── upload_test.go │ ├── user.go │ └── user_test.go ├── task │ ├── compress.go │ ├── compress_test.go │ ├── decompress.go │ ├── decompress_test.go │ ├── errors.go │ ├── import.go │ ├── import_test.go │ ├── job.go │ ├── job_test.go │ ├── pool.go │ ├── pool_test.go │ ├── tranfer.go │ ├── transfer_test.go │ ├── worker.go │ └── worker_test.go ├── thumb │ ├── image.go │ └── image_test.go ├── util │ ├── common.go │ ├── common_test.go │ ├── io.go │ ├── io_test.go │ ├── logger.go │ ├── logger_test.go │ ├── path.go │ ├── path_test.go │ └── session.go └── webdav │ ├── file.go │ ├── if.go │ ├── internal │ └── xml │ │ ├── README 
│ │ ├── marshal.go │ │ ├── read.go │ │ ├── typeinfo.go │ │ └── xml.go │ ├── lock.go │ ├── prop.go │ ├── webdav.go │ └── xml.go ├── routers ├── controllers │ ├── admin.go │ ├── aria2.go │ ├── callback.go │ ├── directory.go │ ├── file.go │ ├── main.go │ ├── objects.go │ ├── share.go │ ├── site.go │ ├── slave.go │ ├── tag.go │ ├── user.go │ └── webdav.go ├── file_router_test.go ├── main_test.go ├── router.go └── router_test.go └── service ├── admin ├── aria2.go ├── file.go ├── group.go ├── list.go ├── policy.go ├── share.go ├── site.go ├── task.go └── user.go ├── aria2 ├── add.go └── manage.go ├── callback ├── oauth.go └── upload.go ├── explorer ├── directory.go ├── file.go ├── objects.go ├── search.go ├── tag.go └── upload.go ├── setting └── webdav.go ├── share ├── manage.go └── visit.go └── user ├── login.go ├── register.go └── setting.go /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | custom: ["https://cloudreve.org/buy.php"] 2 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Build 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | 7 | jobs: 8 | 9 | test: 10 | name: Test 11 | runs-on: ubuntu-latest 12 | steps: 13 | 14 | - name: Set up Go 1.13 15 | uses: actions/setup-go@v1 16 | with: 17 | go-version: 1.13 18 | id: go 19 | 20 | - name: Check out code into the Go module directory 21 | uses: actions/checkout@v2 22 | with: 23 | submodules: 'recursive' 24 | 25 | - name: Get dependencies 26 | run: | 27 | go get github.com/rakyll/statik 28 | export PATH=$PATH:~/go/bin/ 29 | statik -src=models -f 30 | 31 | - name: Test 32 | run: go test -coverprofile=coverage.txt -covermode=atomic ./... 
33 | 34 | build: 35 | name: Build 36 | needs: test 37 | runs-on: ubuntu-latest 38 | steps: 39 | 40 | - name: Set up Go 1.13 41 | uses: actions/setup-go@v1 42 | with: 43 | go-version: 1.13 44 | id: go 45 | 46 | - name: Check out code into the Go module directory 47 | uses: actions/checkout@v2 48 | with: 49 | clean: false 50 | submodules: 'recursive' 51 | - run: | 52 | git fetch --prune --unshallow --tags 53 | 54 | - name: Get dependencies and build 55 | run: | 56 | go get github.com/rakyll/statik 57 | export PATH=$PATH:~/go/bin/ 58 | statik -src=models -f 59 | sudo apt-get update 60 | sudo apt-get -y install gcc-mingw-w64-x86-64 61 | sudo apt-get -y install gcc-arm-linux-gnueabihf libc6-dev-armhf-cross 62 | sudo apt-get -y install gcc-aarch64-linux-gnu libc6-dev-arm64-cross 63 | chmod +x ./build.sh 64 | ./build.sh -r b 65 | 66 | - name: Upload binary files (windows_amd64) 67 | uses: actions/upload-artifact@v2 68 | with: 69 | name: cloudreve_windows_amd64 70 | path: release/cloudreve*windows_amd64.* 71 | 72 | - name: Upload binary files (linux_amd64) 73 | uses: actions/upload-artifact@v2 74 | with: 75 | name: cloudreve_linux_amd64 76 | path: release/cloudreve*linux_amd64.* 77 | 78 | - name: Upload binary files (linux_arm) 79 | uses: actions/upload-artifact@v2 80 | with: 81 | name: cloudreve_linux_arm 82 | path: release/cloudreve*linux_arm.* 83 | 84 | - name: Upload binary files (linux_arm64) 85 | uses: actions/upload-artifact@v2 86 | with: 87 | name: cloudreve_linux_arm64 88 | path: release/cloudreve*linux_arm64.* 89 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | cloudreve 3 | *.exe 4 | *.exe~ 5 | *.dll 6 | *.so 7 | *.dylib 8 | *.db 9 | *.bin 10 | /release/ 11 | 12 | # Test binary, build with `go test -c` 13 | *.test 14 | 15 | # Output of the go coverage tool, specifically when used with LiteIDE 
16 | *.out 17 | 18 | # Development environment 19 | .idea/* 20 | uploads/* 21 | temp 22 | 23 | # Version control 24 | version.lock 25 | 26 | # Config file 27 | *.ini 28 | conf/conf.ini 29 | /statik/ 30 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "assets"] 2 | path = assets 3 | url = https://github.com/cloudreve/frontend.git 4 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | go: 3 | - 1.13.x 4 | node_js: "12.16.3" 5 | git: 6 | depth: 1 7 | install: 8 | - go get github.com/rakyll/statik 9 | before_script: 10 | - statik -src=models -f 11 | script: 12 | - go test -coverprofile=coverage.txt -covermode=atomic ./... 13 | after_success: 14 | - bash <(curl -s https://codecov.io/bash) 15 | before_deploy: 16 | - sudo apt-get update 17 | - sudo apt-get -y install gcc-mingw-w64-x86-64 18 | - sudo apt-get -y install gcc-arm-linux-gnueabihf libc6-dev-armhf-cross 19 | - sudo apt-get -y install gcc-aarch64-linux-gnu libc6-dev-arm64-cross 20 | - chmod +x ./build.sh 21 | - ./build.sh -r b 22 | deploy: 23 | provider: releases 24 | api_key: $GITHUB_TOKEN 25 | file_glob: true 26 | file: release/* 27 | draft: true 28 | skip_cleanup: true 29 | on: 30 | tags: true -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # build frontend 2 | FROM node:lts-buster AS fe-builder 3 | 4 | COPY ./assets /assets 5 | 6 | WORKDIR /assets 7 | 8 | # The yarn registry connection is unstable, so raise the network timeout to 10 minutes.
9 | RUN set -ex \ 10 | && yarn install --network-timeout 600000 \ 11 | && yarn run build 12 | 13 | # build backend 14 | FROM golang:1.15.1-alpine3.12 AS be-builder 15 | 16 | ENV GO111MODULE on 17 | 18 | COPY . /go/src/github.com/cloudreve/Cloudreve/v3 19 | COPY --from=fe-builder /assets/build/ /go/src/github.com/cloudreve/Cloudreve/v3/assets/build/ 20 | 21 | WORKDIR /go/src/github.com/cloudreve/Cloudreve/v3 22 | 23 | RUN set -ex \ 24 | && apk upgrade \ 25 | && apk add gcc libc-dev git \ 26 | && export COMMIT_SHA=$(git rev-parse --short HEAD) \ 27 | && export VERSION=$(git describe --tags) \ 28 | && (cd && go get github.com/rakyll/statik) \ 29 | && statik -src=assets/build/ -include=*.html,*.js,*.json,*.css,*.png,*.svg,*.ico,*.ttf -f \ 30 | && go install -ldflags "-X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.BackendVersion=${VERSION}' \ 31 | -X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.LastCommit=${COMMIT_SHA}'\ 32 | -w -s" 33 | # NOTE: keep the statik -include list above in sync with build.sh. 34 | # build final image 35 | FROM alpine:3.12 AS dist 36 | 37 | LABEL maintainer="mritd " 38 | 39 | # The Asia/Shanghai timezone is used by default; override it 40 | # with `docker build --build-arg=TZ=Other_Timezone ...` 41 | ARG TZ="Asia/Shanghai" 42 | 43 | ENV TZ ${TZ} 44 | 45 | COPY --from=be-builder /go/bin/cloudreve /cloudreve/cloudreve 46 | 47 | RUN apk upgrade \ 48 | && apk add bash tzdata \ 49 | && ln -s /cloudreve/cloudreve /usr/bin/cloudreve \ 50 | && ln -sf /usr/share/zoneinfo/${TZ} /etc/localtime \ 51 | && echo ${TZ} > /etc/timezone \ 52 | && rm -rf /var/cache/apk/* 53 | 54 | # cloudreve listens on TCP port 5212 by default 55 | EXPOSE 5212/tcp 56 | 57 | # cloudreve stores all files (including the executable) in the `/cloudreve` 58 | # directory by default; users should mount the config file to `/etc/cloudreve` 59 | # themselves for persistence, and `/data` is the recommended data storage 60 | # directory.
61 | VOLUME /etc/cloudreve 62 | 63 | VOLUME /data 64 | 65 | ENTRYPOINT ["cloudreve"] 66 | -------------------------------------------------------------------------------- /bootstrap/app.go: -------------------------------------------------------------------------------- 1 | package bootstrap 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | 7 | "github.com/cloudreve/Cloudreve/v3/pkg/conf" 8 | "github.com/cloudreve/Cloudreve/v3/pkg/request" 9 | "github.com/cloudreve/Cloudreve/v3/pkg/util" 10 | "github.com/hashicorp/go-version" 11 | ) 12 | 13 | // InitApplication prints the startup banner (version, commit, Pro flag) and starts the update check in a background goroutine. 14 | func InitApplication() { 15 | fmt.Print(` 16 | ___ _ _ 17 | / __\ | ___ _ _ __| |_ __ _____ _____ 18 | / / | |/ _ \| | | |/ _ | '__/ _ \ \ / / _ \ 19 | / /___| | (_) | |_| | (_| | | | __/\ V / __/ 20 | \____/|_|\___/ \__,_|\__,_|_| \___| \_/ \___| 21 | 22 | V` + conf.BackendVersion + ` Commit #` + conf.LastCommit + ` Pro=` + conf.IsPro + ` 23 | ================================================ 24 | 25 | `) 26 | go CheckUpdate() 27 | } 28 | // GitHubRelease holds the fields of a GitHub release object that CheckUpdate reads. 29 | type GitHubRelease struct { 30 | URL string `json:"html_url"` 31 | Name string `json:"name"` 32 | Tag string `json:"tag_name"` 33 | } 34 | 35 | // CheckUpdate fetches the release list from the GitHub API and logs a notice when the first release's tag is a newer version than conf.BackendVersion; request and decode failures are only logged as warnings. 36 | func CheckUpdate() { 37 | client := request.HTTPClient{} 38 | res, err := client.Request("GET", "https://api.github.com/repos/cloudreve/cloudreve/releases", nil).GetResponse() 39 | if err != nil { 40 | util.Log().Warning("更新检查失败, %s", err) 41 | return 42 | } 43 | 44 | var list []GitHubRelease 45 | if err := json.Unmarshal([]byte(res), &list); err != nil { 46 | util.Log().Warning("更新检查失败, %s", err) 47 | return 48 | } 49 | // NOTE(review): list[0] is assumed to be the newest release; pre-releases are not filtered out. 50 | if len(list) > 0 { 51 | present, err1 := version.NewVersion(conf.BackendVersion) 52 | latest, err2 := version.NewVersion(list[0].Tag) 53 | if err1 == nil && err2 == nil && latest.GreaterThan(present) { 54 | util.Log().Info("有新的版本 [%s] 可用,下载:%s", list[0].Name, list[0].URL) 55 | } 56 | } 57 | 58 | } 59 |
-------------------------------------------------------------------------------- /bootstrap/init.go: -------------------------------------------------------------------------------- 1 | package bootstrap 2 | 3 | import ( 4 | model "github.com/cloudreve/Cloudreve/v3/models" 5 | "github.com/cloudreve/Cloudreve/v3/pkg/aria2" 6 | "github.com/cloudreve/Cloudreve/v3/pkg/auth" 7 | "github.com/cloudreve/Cloudreve/v3/pkg/cache" 8 | "github.com/cloudreve/Cloudreve/v3/pkg/conf" 9 | "github.com/cloudreve/Cloudreve/v3/pkg/crontab" 10 | "github.com/cloudreve/Cloudreve/v3/pkg/email" 11 | "github.com/cloudreve/Cloudreve/v3/pkg/task" 12 | "github.com/gin-gonic/gin" 13 | ) 14 | 15 | // Init bootstraps the application using the config file at path; in master mode it also initializes the database, task queue, aria2, email, cron jobs and static assets. 16 | func Init(path string) { 17 | InitApplication() 18 | conf.Init(path) 19 | // Switch gin to release mode unless debug is enabled. 20 | if !conf.SystemConfig.Debug { 21 | gin.SetMode(gin.ReleaseMode) 22 | } 23 | cache.Init() 24 | if conf.SystemConfig.Mode == "master" { 25 | model.Init() 26 | task.Init() 27 | aria2.Init(false) 28 | email.Init() 29 | crontab.Init() 30 | InitStatic() 31 | } 32 | auth.Init() 33 | } 34 | -------------------------------------------------------------------------------- /bootstrap/script.go: -------------------------------------------------------------------------------- 1 | package bootstrap 2 | 3 | import ( 4 | "context" 5 | "github.com/cloudreve/Cloudreve/v3/models/scripts" 6 | "github.com/cloudreve/Cloudreve/v3/pkg/util" 7 | ) 8 | // RunScript runs the built-in database script called name and logs whether it succeeded. 9 | func RunScript(name string) { 10 | ctx, cancel := context.WithCancel(context.Background()) 11 | defer cancel() 12 | if err := scripts.RunDBScript(name, ctx); err != nil { 13 | util.Log().Error("数据库脚本执行失败: %s", err) 14 | return 15 | } 16 | 17 | util.Log().Info("数据库脚本 [%s] 执行完毕", name) 18 | } 19 | -------------------------------------------------------------------------------- /build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | REPO=$(cd $(dirname $0); pwd) 4 | COMMIT_SHA=$(git rev-parse --short
HEAD) 5 | VERSION=$(git describe --tags) 6 | ASSETS="false" 7 | BINARY="false" 8 | RELEASE="false" 9 | # debugInfo prints the effective build settings (enabled with -d). 10 | debugInfo () { 11 | echo "Repo: $REPO" 12 | echo "Build assets: $ASSETS" 13 | echo "Build binary: $BINARY" 14 | echo "Release: $RELEASE" 15 | echo "Version: $VERSION" 16 | echo "Commit: $COMMIT_SHA" 17 | } 18 | # buildAssets builds the frontend in assets/ with yarn and regenerates the statik bundle from assets/build. 19 | buildAssets () { 20 | cd $REPO 21 | rm -rf assets/build 22 | rm -f statik/statik.go 23 | 24 | export CI=false 25 | 26 | cd $REPO/assets 27 | 28 | yarn install 29 | yarn run build 30 | # Install statik only when it is not already on PATH. 31 | if ! [ -x "$(command -v statik)" ]; then 32 | export CGO_ENABLED=0 33 | go get github.com/rakyll/statik 34 | fi 35 | 36 | cd $REPO 37 | statik -src=assets/build/ -include=*.html,*.js,*.json,*.css,*.png,*.svg,*.ico,*.ttf -f 38 | } 39 | # buildBinary builds ./cloudreve for the host platform, stamping version and commit via -ldflags. 40 | buildBinary () { 41 | cd $REPO 42 | go build -a -o cloudreve -ldflags " -X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.BackendVersion=$VERSION' -X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.LastCommit=$COMMIT_SHA'" 43 | } 44 | # _build OS/ARCH/CC cross-compiles one target and packages it under release/ (.zip for windows, .tar.gz otherwise). 45 | _build() { 46 | local osarch=$1 47 | IFS=/ read -r -a arr <<<"$osarch" 48 | os="${arr[0]}" 49 | arch="${arr[1]}" 50 | gcc="${arr[2]}" 51 | 52 | # Target platform and C cross-compiler for the cgo-enabled build.
53 | export GOOS=$os 54 | export GOARCH=$arch 55 | export CC=$gcc 56 | export CGO_ENABLED=1 57 | 58 | if [ -n "$VERSION" ]; then 59 | out="release/cloudreve_${VERSION}_${os}_${arch}" 60 | else 61 | out="release/cloudreve_${COMMIT_SHA}_${os}_${arch}" 62 | fi 63 | 64 | go build -a -o "${out}" -ldflags " -X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.BackendVersion=$VERSION' -X 'github.com/cloudreve/Cloudreve/v3/pkg/conf.LastCommit=$COMMIT_SHA'" 65 | 66 | if [ "$os" = "windows" ]; then 67 | mv "$out" release/cloudreve.exe 68 | zip -j -q "${out}.zip" release/cloudreve.exe 69 | rm -f "release/cloudreve.exe" 70 | else 71 | mv "$out" release/cloudreve 72 | tar -zcvf "${out}.tar.gz" -C release cloudreve 73 | rm -f "release/cloudreve" 74 | fi 75 | } 76 | # release cross-compiles and packages every target listed in SUPPORTED_OSARCH. 77 | release(){ 78 | cd $REPO 79 | ## OS/arch/C-compiler triples to cross-compile for release. 80 | SUPPORTED_OSARCH="linux/amd64/gcc linux/arm/arm-linux-gnueabihf-gcc windows/amd64/x86_64-w64-mingw32-gcc linux/arm64/aarch64-linux-gnu-gcc" 81 | 82 | echo "Release builds for OS/Arch/CC: ${SUPPORTED_OSARCH}" 83 | for each_osarch in ${SUPPORTED_OSARCH}; do 84 | _build "${each_osarch}" 85 | done 86 | } 87 | # usage prints the supported flags to stderr and exits with status 1. 88 | usage() { 89 | echo "Usage: $0 [-a] [-c] [-b] [-r ARG] [-d]" 1>&2; 90 | exit 1; 91 | } 92 | # -r requires an argument (CI runs "./build.sh -r b"); its value is not used. 93 | while getopts "bacr:d" o; do 94 | case "${o}" in 95 | b) 96 | ASSETS="true" 97 | BINARY="true" 98 | ;; 99 | a) 100 | ASSETS="true" 101 | ;; 102 | c) 103 | BINARY="true" 104 | ;; 105 | r) 106 | ASSETS="true" 107 | RELEASE="true" 108 | ;; 109 | d) 110 | DEBUG="true" 111 | ;; 112 | *) 113 | usage 114 | ;; 115 | esac 116 | done 117 | shift $((OPTIND-1)) 118 | 119 | if [ "$DEBUG" = "true" ]; then 120 | debugInfo 121 | fi 122 | 123 | if [ "$ASSETS" = "true" ]; then 124 | buildAssets 125 | fi 126 | 127 | if [ "$BINARY" = "true" ]; then 128 | buildBinary 129 | fi 130 | 131 | if [ "$RELEASE" = "true" ]; then 132 | release 133 | fi 134 | -------------------------------------------------------------------------------- /go.mod:
-------------------------------------------------------------------------------- 1 | module github.com/cloudreve/Cloudreve/v3 2 | 3 | go 1.13 4 | 5 | require ( 6 | github.com/DATA-DOG/go-sqlmock v1.3.3 7 | github.com/aliyun/aliyun-oss-go-sdk v2.0.5+incompatible 8 | github.com/aws/aws-sdk-go v1.31.5 9 | github.com/baiyubin/aliyun-sts-go-sdk v0.0.0-20180326062324-cfa1a18b161f // indirect 10 | github.com/duo-labs/webauthn v0.0.0-20191119193225-4bf9a0f776d4 11 | github.com/fatih/color v1.7.0 12 | github.com/gin-contrib/cors v1.3.0 13 | github.com/gin-contrib/gzip v0.0.2-0.20200226035851-25bef2ef21e8 14 | github.com/gin-contrib/sessions v0.0.1 15 | github.com/gin-contrib/static v0.0.0-20191128031702-f81c604d8ac2 16 | github.com/gin-gonic/gin v1.5.0 17 | github.com/go-ini/ini v1.50.0 18 | github.com/go-mail/mail v2.3.1+incompatible 19 | github.com/gomodule/redigo v2.0.0+incompatible 20 | github.com/google/go-querystring v1.0.0 21 | github.com/gorilla/websocket v1.4.1 22 | github.com/hashicorp/go-version v1.2.0 23 | github.com/jinzhu/gorm v1.9.11 24 | github.com/juju/ratelimit v1.0.1 25 | github.com/mattn/go-colorable v0.1.4 // indirect 26 | github.com/mojocn/base64Captcha v0.0.0-20190801020520-752b1cd608b2 27 | github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 28 | github.com/pkg/errors v0.9.1 29 | github.com/pquerna/otp v1.2.0 30 | github.com/qiniu/api.v7/v7 v7.4.0 31 | github.com/rafaeljusto/redigomock v0.0.0-20191117212112-00b2509252a1 32 | github.com/rakyll/statik v0.1.7 33 | github.com/robfig/cron/v3 v3.0.1 34 | github.com/smartystreets/goconvey v1.6.4 // indirect 35 | github.com/speps/go-hashids v2.0.0+incompatible 36 | github.com/stretchr/testify v1.5.1 37 | github.com/tencentcloud/tencentcloud-sdk-go v3.0.125+incompatible 38 | github.com/tencentyun/cos-go-sdk-v5 v0.0.0-20200120023323-87ff3bc489ac 39 | github.com/upyun/go-sdk v2.1.0+incompatible 40 | golang.org/x/text v0.3.2 41 | gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // 
indirect 42 | gopkg.in/go-playground/validator.v9 v9.29.1 43 | gopkg.in/ini.v1 v1.51.0 // indirect 44 | gopkg.in/mail.v2 v2.3.1 // indirect 45 | ) 46 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | 6 | "github.com/cloudreve/Cloudreve/v3/bootstrap" 7 | "github.com/cloudreve/Cloudreve/v3/pkg/conf" 8 | "github.com/cloudreve/Cloudreve/v3/pkg/util" 9 | "github.com/cloudreve/Cloudreve/v3/routers" 10 | ) 11 | // Command-line flag values, parsed in init. 12 | var ( 13 | isEject bool 14 | confPath string 15 | scriptName string 16 | ) 17 | // init parses the command-line flags and bootstraps the application from the config file. 18 | func init() { 19 | flag.StringVar(&confPath, "c", util.RelativePath("conf.ini"), "配置文件路径") 20 | flag.BoolVar(&isEject, "eject", false, "导出内置静态资源") 21 | flag.StringVar(&scriptName, "database-script", "", "运行内置数据库助手脚本") 22 | flag.Parse() 23 | bootstrap.Init(confPath) 24 | } 25 | // main runs a one-shot task (static asset export or database script) when requested; otherwise it serves HTTP, plus HTTPS and a Unix socket when configured. 26 | func main() { 27 | if isEject { 28 | // Export the built-in static assets, then exit. 29 | bootstrap.Eject() 30 | return 31 | } 32 | 33 | if scriptName != "" { 34 | // Run the named built-in database helper script, then exit. 35 | bootstrap.RunScript(scriptName) 36 | return 37 | } 38 | 39 | api := routers.InitRouter() 40 | 41 | // Serve HTTPS in the background when an SSL certificate is configured. 42 | if conf.SSLConfig.CertPath != "" { 43 | go func() { 44 | util.Log().Info("开始监听 %s", conf.SSLConfig.Listen) 45 | if err := api.RunTLS(conf.SSLConfig.Listen, 46 | conf.SSLConfig.CertPath, conf.SSLConfig.KeyPath); err != nil { 47 | util.Log().Error("无法监听[%s],%s", conf.SSLConfig.Listen, err) 48 | } 49 | }() 50 | } 51 | 52 | // Serve on a Unix socket in the background when one is configured. 53 | if conf.UnixConfig.Listen != "" { 54 | go func() { 55 | util.Log().Info("开始监听 %s", conf.UnixConfig.Listen) 56 | if err := api.RunUnix(conf.UnixConfig.Listen); err != nil { 57 | util.Log().Error("无法监听[%s],%s", conf.UnixConfig.Listen, err) 58 | } 59 | }() 60 | } 61 | 62 | util.Log().Info("开始监听 %s", conf.SystemConfig.Listen) 63 | if err := api.Run(conf.SystemConfig.Listen); err != nil { 64 | util.Log().Error("无法监听[%s],%s",
conf.SystemConfig.Listen, err) 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /middleware/frontend.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "github.com/cloudreve/Cloudreve/v3/bootstrap" 5 | model "github.com/cloudreve/Cloudreve/v3/models" 6 | "github.com/cloudreve/Cloudreve/v3/pkg/util" 7 | "github.com/gin-gonic/gin" 8 | "io/ioutil" 9 | "net/http" 10 | "strings" 11 | ) 12 | 13 | // FrontendFileHandler serves the embedded frontend: /api, /custom, /dav and /manifest.json pass through; "/", "/index.html" and paths missing from StaticFS get index.html with site settings substituted; other paths are served from StaticFS. Without StaticFS or a readable index.html the middleware is a no-op. 14 | func FrontendFileHandler() gin.HandlerFunc { 15 | ignoreFunc := func(c *gin.Context) { 16 | c.Next() 17 | } 18 | 19 | if bootstrap.StaticFS == nil { 20 | return ignoreFunc 21 | } 22 | 23 | // Read index.html once; it is rendered for the fallback routes below. 24 | file, err := bootstrap.StaticFS.Open("/index.html") 25 | if err != nil { 26 | util.Log().Warning("静态文件[index.html]不存在,可能会影响首页展示") 27 | return ignoreFunc 28 | } 29 | defer file.Close() 30 | fileContentBytes, err := ioutil.ReadAll(file) 31 | if err != nil { 32 | util.Log().Warning("静态文件[index.html]读取失败,可能会影响首页展示") 33 | return ignoreFunc 34 | } 35 | fileContent := string(fileContentBytes) 36 | 37 | fileServer := http.FileServer(bootstrap.StaticFS) 38 | return func(c *gin.Context) { 39 | path := c.Request.URL.Path 40 | 41 | // Leave API, custom, WebDAV and manifest requests to later handlers. 42 | if strings.HasPrefix(path, "/api") || strings.HasPrefix(path, "/custom") || strings.HasPrefix(path, "/dav") || path == "/manifest.json" { 43 | c.Next() 44 | return 45 | } 46 | 47 | // Unknown paths, "/" and "/index.html" all get the rendered index.html. 48 | if (path == "/index.html") || (path == "/") || !bootstrap.StaticFS.Exists("/", path) { 49 | // Load site settings and substitute them; every key read from options must be requested here. 50 | options := model.GetSettingByNames("siteName", "siteKeywords", "siteScript", 51 | "pwa_small_icon", "siteDes") 52 | finalHTML := util.Replace(map[string]string{ 53 | "{siteName}": options["siteName"], 54 | "{siteDes}": options["siteDes"], 55 | "{siteScript}": options["siteScript"], 56 | "{pwa_small_icon}": options["pwa_small_icon"], 57 | }, fileContent) 58 |
59 | c.Header("Content-Type", "text/html") 60 | c.String(200, finalHTML) 61 | c.Abort() 62 | return 63 | } 64 | 65 | // Any other path exists in StaticFS, so serve the file directly. 66 | fileServer.ServeHTTP(c.Writer, c.Request) 67 | c.Abort() 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /middleware/mock.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "github.com/cloudreve/Cloudreve/v3/pkg/util" 5 | "github.com/gin-gonic/gin" 6 | ) 7 | 8 | // SessionMock holds session values that MockHelper writes into every request's session (test helper). 9 | var SessionMock = make(map[string]interface{}) 10 | 11 | // ContextMock holds key/value pairs that MockHelper sets on every gin context (test helper). 12 | var ContextMock = make(map[string]interface{}) 13 | 14 | // MockHelper returns unit-test middleware that writes SessionMock into the session and ContextMock into the gin context before continuing the chain. 15 | func MockHelper() gin.HandlerFunc { 16 | return func(c *gin.Context) { 17 | // Write SessionMock into the session. 18 | util.SetSession(c, SessionMock) 19 | for key, value := range ContextMock { 20 | c.Set(key, value) 21 | } 22 | c.Next() 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /middleware/mock_test.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | 8 | "github.com/cloudreve/Cloudreve/v3/pkg/util" 9 | "github.com/gin-gonic/gin" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | // TestMockHelper checks that MockHelper copies SessionMock into the session and ContextMock into the context. 13 | func TestMockHelper(t *testing.T) { 14 | asserts := assert.New(t) 15 | MockHelperFunc := MockHelper() 16 | rec := httptest.NewRecorder() 17 | c, _ := gin.CreateTestContext(rec) 18 | c.Request, _ = http.NewRequest("GET", "/test", nil) 19 | 20 | // SessionMock values end up in the session. 21 | { 22 | SessionMock["test"] = "pass" 23 | Session("test")(c) 24 | MockHelperFunc(c) 25 | asserts.Equal("pass", util.GetSession(c, "test").(string)) 26 | } 27 | 28 | // ContextMock values end up in the context. 29 | { 30 | ContextMock["test"] = "pass" 31 | MockHelperFunc(c) 32 | test, exist := c.Get("test") 33 | asserts.True(exist) 34 | asserts.Equal("pass", test.(string)) 35
| 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /middleware/option.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | model "github.com/cloudreve/Cloudreve/v3/models" 5 | "github.com/cloudreve/Cloudreve/v3/pkg/hashid" 6 | "github.com/cloudreve/Cloudreve/v3/pkg/serializer" 7 | "github.com/gin-gonic/gin" 8 | ) 9 | 10 | // HashID returns middleware that decodes the "id" route parameter as a hash ID of type IDType and stores the real ID under the "object_id" context key; an undecodable ID aborts with a parameter error, and requests without "id" pass through. 11 | func HashID(IDType int) gin.HandlerFunc { 12 | return func(c *gin.Context) { 13 | if c.Param("id") != "" { 14 | id, err := hashid.DecodeHashID(c.Param("id"), IDType) 15 | if err == nil { 16 | c.Set("object_id", id) 17 | c.Next() 18 | return 19 | } 20 | c.JSON(200, serializer.ParamErr("无法解析对象ID", nil)) 21 | c.Abort() 22 | return 23 | 24 | } 25 | c.Next() 26 | } 27 | } 28 | 29 | // IsFunctionEnabled returns middleware that aborts with a no-permission error unless the setting named key is a true value. 30 | func IsFunctionEnabled(key string) gin.HandlerFunc { 31 | return func(c *gin.Context) { 32 | if !model.IsTrueVal(model.GetSettingByName(key)) { 33 | c.JSON(200, serializer.Err(serializer.CodeNoPermissionErr, "未开启此功能", nil)) 34 | c.Abort() 35 | return 36 | } 37 | 38 | c.Next() 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /middleware/option_test.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | 8 | "github.com/cloudreve/Cloudreve/v3/pkg/cache" 9 | "github.com/cloudreve/Cloudreve/v3/pkg/hashid" 10 | "github.com/gin-gonic/gin" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | // TestHashID covers HashID with no id, an undecodable id and a valid id. 14 | func TestHashID(t *testing.T) { 15 | asserts := assert.New(t) 16 | rec := httptest.NewRecorder() 17 | TestFunc := HashID(hashid.FolderID) 18 | 19 | // No "id" parameter: the request passes through. 20 | { 21 | c, _ := gin.CreateTestContext(rec) 22 | c.Params = []gin.Param{} 23 | c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil) 24 |
TestFunc(c) 25 | asserts.NoError(mock.ExpectationsWereMet()) 26 | asserts.False(c.IsAborted()) 27 | } 28 | 29 | // ID present but not decodable: the request is aborted. 30 | { 31 | c, _ := gin.CreateTestContext(rec) 32 | c.Params = []gin.Param{ 33 | {"id", "2333"}, 34 | } 35 | c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil) 36 | TestFunc(c) 37 | asserts.NoError(mock.ExpectationsWereMet()) 38 | asserts.True(c.IsAborted()) 39 | } 40 | 41 | // ID present and decodable: the request continues. 42 | { 43 | c, _ := gin.CreateTestContext(rec) 44 | c.Params = []gin.Param{ 45 | {"id", hashid.HashID(1, hashid.FolderID)}, 46 | } 47 | c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil) 48 | TestFunc(c) 49 | asserts.NoError(mock.ExpectationsWereMet()) 50 | asserts.False(c.IsAborted()) 51 | } 52 | } 53 | // TestIsFunctionEnabled checks that the middleware aborts when the cached setting is "0" and continues when it is "1". 54 | func TestIsFunctionEnabled(t *testing.T) { 55 | asserts := assert.New(t) 56 | rec := httptest.NewRecorder() 57 | TestFunc := IsFunctionEnabled("TestIsFunctionEnabled") 58 | 59 | // Feature disabled: the request is aborted. 60 | { 61 | cache.Set("setting_TestIsFunctionEnabled", "0", 0) 62 | c, _ := gin.CreateTestContext(rec) 63 | c.Params = []gin.Param{} 64 | c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil) 65 | TestFunc(c) 66 | asserts.True(c.IsAborted()) 67 | } 68 | // Feature enabled: the request continues. 69 | { 70 | cache.Set("setting_TestIsFunctionEnabled", "1", 0) 71 | c, _ := gin.CreateTestContext(rec) 72 | c.Params = []gin.Param{} 73 | c.Request, _ = http.NewRequest("POST", "/api/v3/file/dellete/1", nil) 74 | TestFunc(c) 75 | asserts.False(c.IsAborted()) 76 | } 77 | 78 | } 79 | -------------------------------------------------------------------------------- /middleware/session.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "github.com/cloudreve/Cloudreve/v3/pkg/conf" 5 | "github.com/cloudreve/Cloudreve/v3/pkg/serializer" 6 | "github.com/cloudreve/Cloudreve/v3/pkg/util" 7 | "github.com/gin-contrib/sessions" 8 | "github.com/gin-contrib/sessions/memstore" 9 |
"github.com/gin-contrib/sessions/redis" 10 | "github.com/gin-gonic/gin" 11 | ) 12 | 13 | // Store is the session store shared by the Session middleware. 14 | var Store memstore.Store 15 | 16 | // Session creates Store — backed by Redis when a Redis server is configured and gin is not in test mode, otherwise in memory — and returns the sessions middleware for the "cloudreve-session" cookie (HttpOnly, path "/", 7-day MaxAge). It panics if the Redis store cannot be created. 17 | func Session(secret string) gin.HandlerFunc { 18 | // Use Redis only when a server is configured and gin is not in test mode. 19 | if conf.RedisConfig.Server != "" && gin.Mode() != gin.TestMode { 20 | var err error 21 | Store, err = redis.NewStoreWithDB(10, conf.RedisConfig.Network, conf.RedisConfig.Server, conf.RedisConfig.Password, conf.RedisConfig.DB, []byte(secret)) 22 | if err != nil { 23 | util.Log().Panic("无法连接到 Redis:%s", err) 24 | } 25 | 26 | util.Log().Info("已连接到 Redis 服务器:%s", conf.RedisConfig.Server) 27 | } else { 28 | Store = memstore.NewStore([]byte(secret)) 29 | } 30 | 31 | // Secure: true should also be set when serving over SSL. 32 | // TODO: same-site policy 33 | Store.Options(sessions.Options{HttpOnly: true, MaxAge: 7 * 86400, Path: "/"}) 34 | return sessions.Sessions("cloudreve-session", Store) 35 | } 36 | 37 | // CSRFInit returns middleware that sets the CSRF flag in the current session. 38 | func CSRFInit() gin.HandlerFunc { 39 | return func(c *gin.Context) { 40 | util.SetSession(c, map[string]interface{}{"CSRF": true}) 41 | c.Next() 42 | } 43 | } 44 | 45 | // CSRFCheck returns middleware that aborts with a no-permission error unless the session's CSRF flag (set by CSRFInit) is true. 46 | func CSRFCheck() gin.HandlerFunc { 47 | return func(c *gin.Context) { 48 | if check, ok := util.GetSession(c, "CSRF").(bool); ok && check { 49 | c.Next() 50 | return 51 | } 52 | 53 | c.JSON(200, serializer.Err(serializer.CodeNoPermissionErr, "来源非法", nil)) 54 | c.Abort() 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /middleware/session_test.go: -------------------------------------------------------------------------------- 1 | package middleware 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "testing" 7 | 8 | "github.com/cloudreve/Cloudreve/v3/pkg/conf" 9 | "github.com/cloudreve/Cloudreve/v3/pkg/util" 10 | "github.com/gin-gonic/gin" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestSession(t *testing.T) { 15 |
asserts := assert.New(t) 16 | 17 | { 18 | handler := Session("2333") 19 | asserts.NotNil(handler) 20 | asserts.NotNil(Store) 21 | asserts.IsType(emptyFunc(), handler) 22 | } 23 | { 24 | conf.RedisConfig.Server = "123" 25 | asserts.Panics(func() { 26 | Session("2333") 27 | }) 28 | conf.RedisConfig.Server = "" 29 | } 30 | 31 | } 32 | 33 | func emptyFunc() gin.HandlerFunc { 34 | return func(c *gin.Context) {} 35 | } 36 | 37 | func TestCSRFInit(t *testing.T) { 38 | asserts := assert.New(t) 39 | rec := httptest.NewRecorder() 40 | sessionFunc := Session("233") 41 | { 42 | c, _ := gin.CreateTestContext(rec) 43 | c.Request, _ = http.NewRequest("GET", "/test", nil) 44 | sessionFunc(c) 45 | CSRFInit()(c) 46 | asserts.True(util.GetSession(c, "CSRF").(bool)) 47 | } 48 | } 49 | 50 | func TestCSRFCheck(t *testing.T) { 51 | asserts := assert.New(t) 52 | rec := httptest.NewRecorder() 53 | sessionFunc := Session("233") 54 | 55 | // 通过检查 56 | { 57 | c, _ := gin.CreateTestContext(rec) 58 | c.Request, _ = http.NewRequest("GET", "/test", nil) 59 | sessionFunc(c) 60 | CSRFInit()(c) 61 | CSRFCheck()(c) 62 | asserts.False(c.IsAborted()) 63 | } 64 | 65 | // 未通过检查 66 | { 67 | c, _ := gin.CreateTestContext(rec) 68 | c.Request, _ = http.NewRequest("GET", "/test", nil) 69 | sessionFunc(c) 70 | CSRFCheck()(c) 71 | asserts.True(c.IsAborted()) 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /models/download.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "encoding/json" 5 | 6 | "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" 7 | "github.com/cloudreve/Cloudreve/v3/pkg/util" 8 | "github.com/jinzhu/gorm" 9 | ) 10 | 11 | // Download 离线下载队列模型 12 | type Download struct { 13 | gorm.Model 14 | Status int // 任务状态 15 | Type int // 任务类型 16 | Source string `gorm:"type:text"` // 文件下载地址 17 | TotalSize uint64 // 文件大小 18 | DownloadedSize uint64 // 文件大小 19 | GID string 
`gorm:"size:32,index:gid"` // 任务ID 20 | Speed int // 下载速度 21 | Parent string `gorm:"type:text"` // 存储目录 22 | Attrs string `gorm:"size:65535"` // 任务状态属性 23 | Error string `gorm:"type:text"` // 错误描述 24 | Dst string `gorm:"type:text"` // 用户文件系统存储父目录路径 25 | UserID uint // 发起者UID 26 | TaskID uint // 对应的转存任务ID 27 | 28 | // 关联模型 29 | User *User `gorm:"PRELOAD:false,association_autoupdate:false"` 30 | 31 | // 数据库忽略字段 32 | StatusInfo rpc.StatusInfo `gorm:"-"` 33 | Task *Task `gorm:"-"` 34 | } 35 | 36 | // AfterFind 找到下载任务后的钩子,处理Status结构 37 | func (task *Download) AfterFind() (err error) { 38 | // 解析状态 39 | if task.Attrs != "" { 40 | err = json.Unmarshal([]byte(task.Attrs), &task.StatusInfo) 41 | } 42 | 43 | if task.TaskID != 0 { 44 | task.Task, _ = GetTasksByID(task.TaskID) 45 | } 46 | 47 | return err 48 | } 49 | 50 | // BeforeSave Save下载任务前的钩子 51 | func (task *Download) BeforeSave() (err error) { 52 | // 解析状态 53 | if task.Attrs != "" { 54 | err = json.Unmarshal([]byte(task.Attrs), &task.StatusInfo) 55 | } 56 | return err 57 | } 58 | 59 | // Create 创建离线下载记录 60 | func (task *Download) Create() (uint, error) { 61 | if err := DB.Create(task).Error; err != nil { 62 | util.Log().Warning("无法插入离线下载记录, %s", err) 63 | return 0, err 64 | } 65 | return task.ID, nil 66 | } 67 | 68 | // Save 更新 69 | func (task *Download) Save() error { 70 | if err := DB.Save(task).Error; err != nil { 71 | util.Log().Warning("无法更新离线下载记录, %s", err) 72 | return err 73 | } 74 | return nil 75 | } 76 | 77 | // GetDownloadsByStatus 根据状态检索下载 78 | func GetDownloadsByStatus(status ...int) []Download { 79 | var tasks []Download 80 | DB.Where("status in (?)", status).Find(&tasks) 81 | return tasks 82 | } 83 | 84 | // GetDownloadsByStatusAndUser 根据状态检索和用户ID下载 85 | // page 为 0 表示列出所有,非零时分页 86 | func GetDownloadsByStatusAndUser(page, uid uint, status ...int) []Download { 87 | var tasks []Download 88 | dbChain := DB 89 | if page > 0 { 90 | dbChain = dbChain.Limit(10).Offset((page - 1) * 10).Order("updated_at DESC") 91 
| } 92 | dbChain.Where("user_id = ? and status in (?)", uid, status).Find(&tasks) 93 | return tasks 94 | } 95 | 96 | // GetDownloadByGid 根据GID和用户ID查找下载 97 | func GetDownloadByGid(gid string, uid uint) (*Download, error) { 98 | download := &Download{} 99 | result := DB.Where("user_id = ? and g_id = ?", uid, gid).First(download) 100 | return download, result.Error 101 | } 102 | 103 | // GetOwner 获取下载任务所属用户 104 | func (task *Download) GetOwner() *User { 105 | if task.User == nil { 106 | if user, err := GetUserByID(task.UserID); err == nil { 107 | return &user 108 | } 109 | } 110 | return task.User 111 | } 112 | 113 | // Delete 删除离线下载记录 114 | func (download *Download) Delete() error { 115 | return DB.Model(download).Delete(download).Error 116 | } 117 | -------------------------------------------------------------------------------- /models/group.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "encoding/json" 5 | "github.com/jinzhu/gorm" 6 | ) 7 | 8 | // Group 用户组模型 9 | type Group struct { 10 | gorm.Model 11 | Name string 12 | Policies string 13 | MaxStorage uint64 14 | ShareEnabled bool 15 | WebDAVEnabled bool 16 | SpeedLimit int 17 | Options string `json:"-" gorm:"type:text"` // tag keys must be space-separated, otherwise reflect.StructTag cannot find "gorm" 18 | 19 | // 数据库忽略字段 20 | PolicyList []uint `gorm:"-"` 21 | OptionsSerialized GroupOption `gorm:"-"` 22 | } 23 | 24 | // GroupOption 用户组其他配置 25 | type GroupOption struct { 26 | ArchiveDownload bool `json:"archive_download,omitempty"` // 打包下载 27 | ArchiveTask bool `json:"archive_task,omitempty"` // 在线压缩 28 | CompressSize uint64 `json:"compress_size,omitempty"` // 可压缩大小 29 | DecompressSize uint64 `json:"decompress_size,omitempty"` 30 | OneTimeDownload bool `json:"one_time_download,omitempty"` 31 | ShareDownload bool `json:"share_download,omitempty"` 32 | Aria2 bool `json:"aria2,omitempty"` // 离线下载 33 | Aria2Options map[string]interface{} `json:"aria2_options,omitempty"` // 离线下载用户组配置 34 | } 35 | 36 | // GetGroupByID 用ID获取用户组
37 | func GetGroupByID(ID interface{}) (Group, error) { 38 | var group Group 39 | result := DB.First(&group, ID) 40 | return group, result.Error 41 | } 42 | 43 | // AfterFind 找到用户组后的钩子,处理Policy列表 44 | func (group *Group) AfterFind() (err error) { 45 | // 解析用户组策略列表 46 | if group.Policies != "" { 47 | err = json.Unmarshal([]byte(group.Policies), &group.PolicyList) 48 | } 49 | if err != nil { 50 | return err 51 | } 52 | 53 | // 解析用户组设置 54 | if group.Options != "" { 55 | err = json.Unmarshal([]byte(group.Options), &group.OptionsSerialized) 56 | } 57 | 58 | return err 59 | } 60 | 61 | // BeforeSave Save用户前的钩子 62 | func (group *Group) BeforeSave() (err error) { 63 | err = group.SerializePolicyList() 64 | return err 65 | } 66 | 67 | //SerializePolicyList 将序列后的可选策略列表、配置写入数据库字段 68 | // TODO 完善测试 69 | func (group *Group) SerializePolicyList() (err error) { 70 | policies, err := json.Marshal(&group.PolicyList) 71 | group.Policies = string(policies) 72 | if err != nil { 73 | return err 74 | } 75 | 76 | optionsValue, err := json.Marshal(&group.OptionsSerialized) 77 | group.Options = string(optionsValue) 78 | return err 79 | } 80 | -------------------------------------------------------------------------------- /models/group_test.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "github.com/DATA-DOG/go-sqlmock" 5 | "github.com/jinzhu/gorm" 6 | "github.com/pkg/errors" 7 | "github.com/stretchr/testify/assert" 8 | "testing" 9 | ) 10 | 11 | func TestGetGroupByID(t *testing.T) { 12 | asserts := assert.New(t) 13 | 14 | //找到用户组时 15 | groupRows := sqlmock.NewRows([]string{"id", "name", "policies"}). 
16 | AddRow(1, "管理员", "[1]") 17 | mock.ExpectQuery("^SELECT (.+)").WillReturnRows(groupRows) 18 | 19 | group, err := GetGroupByID(1) 20 | asserts.NoError(err) 21 | asserts.Equal(Group{ 22 | Model: gorm.Model{ 23 | ID: 1, 24 | }, 25 | Name: "管理员", 26 | Policies: "[1]", 27 | PolicyList: []uint{1}, 28 | }, group) 29 | 30 | //未找到用户时 31 | mock.ExpectQuery("^SELECT (.+)").WillReturnError(errors.New("not found")) 32 | group, err = GetGroupByID(1) 33 | asserts.Error(err) 34 | asserts.Equal(Group{}, group) 35 | } 36 | 37 | func TestGroup_AfterFind(t *testing.T) { 38 | asserts := assert.New(t) 39 | 40 | testCase := Group{ 41 | Model: gorm.Model{ 42 | ID: 1, 43 | }, 44 | Name: "管理员", 45 | Policies: "[1]", 46 | } 47 | err := testCase.AfterFind() 48 | asserts.NoError(err) 49 | asserts.Equal(testCase.PolicyList, []uint{1}) 50 | 51 | testCase.Policies = "[1,2,3,4,5]" 52 | err = testCase.AfterFind() 53 | asserts.NoError(err) 54 | asserts.Equal(testCase.PolicyList, []uint{1, 2, 3, 4, 5}) 55 | 56 | testCase.Policies = "[1,2,3,4,5" 57 | err = testCase.AfterFind() 58 | asserts.Error(err) 59 | 60 | testCase.Policies = "[]" 61 | err = testCase.AfterFind() 62 | asserts.NoError(err) 63 | asserts.Equal(testCase.PolicyList, []uint{}) 64 | } 65 | 66 | func TestGroup_BeforeSave(t *testing.T) { 67 | asserts := assert.New(t) 68 | group := Group{ 69 | PolicyList: []uint{1, 2, 3}, 70 | } 71 | { 72 | err := group.BeforeSave() 73 | asserts.NoError(err) 74 | asserts.Equal("[1,2,3]", group.Policies) 75 | } 76 | 77 | } 78 | -------------------------------------------------------------------------------- /models/init.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/cloudreve/Cloudreve/v3/pkg/conf" 8 | "github.com/cloudreve/Cloudreve/v3/pkg/util" 9 | "github.com/gin-gonic/gin" 10 | "github.com/jinzhu/gorm" 11 | 12 | _ "github.com/jinzhu/gorm/dialects/mysql" 13 | _ 
"github.com/jinzhu/gorm/dialects/sqlite" 14 | ) 15 | 16 | // DB 数据库链接单例 17 | var DB *gorm.DB 18 | 19 | // Init 初始化 MySQL 链接 20 | func Init() { 21 | util.Log().Info("初始化数据库连接") 22 | 23 | var ( 24 | db *gorm.DB 25 | err error 26 | ) 27 | 28 | if gin.Mode() == gin.TestMode { 29 | // 测试模式下,使用内存数据库 30 | db, err = gorm.Open("sqlite3", ":memory:") 31 | } else { 32 | switch conf.DatabaseConfig.Type { 33 | case "UNSET", "sqlite", "sqlite3": 34 | // 未指定数据库或者明确指定为 sqlite 时,使用 SQLite3 数据库 35 | db, err = gorm.Open("sqlite3", util.RelativePath(conf.DatabaseConfig.DBFile)) 36 | case "mysql": 37 | // 当前只支持 sqlite3 与 mysql 数据库 38 | // TODO: import 其他 gorm 支持的主流数据库?否则直接 Open 没有任何意义。 39 | // TODO: 数据库连接其他参数允许用户自定义?譬如编码更换为 utf8mb4 以支持表情。 40 | db, err = gorm.Open("mysql", fmt.Sprintf("%s:%s@(%s:%d)/%s?charset=utf8&parseTime=True&loc=Local", 41 | conf.DatabaseConfig.User, 42 | conf.DatabaseConfig.Password, 43 | conf.DatabaseConfig.Host, 44 | conf.DatabaseConfig.Port, 45 | conf.DatabaseConfig.Name)) 46 | default: 47 | util.Log().Panic("不支持数据库类型: %s", conf.DatabaseConfig.Type) 48 | } 49 | } 50 | 51 | //db.SetLogger(util.Log()) 52 | if err != nil { 53 | util.Log().Panic("连接数据库不成功, %s", err) 54 | } 55 | 56 | // 处理表前缀 57 | gorm.DefaultTableNameHandler = func(db *gorm.DB, defaultTableName string) string { 58 | return conf.DatabaseConfig.TablePrefix + defaultTableName 59 | } 60 | 61 | // Debug模式下,输出所有 SQL 日志 62 | if conf.SystemConfig.Debug { 63 | db.LogMode(true) 64 | } else { 65 | db.LogMode(false) 66 | } 67 | 68 | //设置连接池 69 | //空闲 70 | db.DB().SetMaxIdleConns(50) 71 | //打开 72 | db.DB().SetMaxOpenConns(100) 73 | //超时 74 | db.DB().SetConnMaxLifetime(time.Second * 30) 75 | 76 | DB = db 77 | 78 | //执行迁移 79 | migration() 80 | } 81 | -------------------------------------------------------------------------------- /models/migration_test.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "testing" 5 | 6 | 
"github.com/cloudreve/Cloudreve/v3/pkg/conf" 7 | "github.com/jinzhu/gorm" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestMigration(t *testing.T) { 12 | asserts := assert.New(t) 13 | conf.DatabaseConfig.Type = "sqlite3" 14 | DB, _ = gorm.Open("sqlite3", ":memory:") 15 | 16 | asserts.NotPanics(func() { 17 | migration() 18 | }) 19 | conf.DatabaseConfig.Type = "mysql" 20 | DB = mockDB 21 | } 22 | -------------------------------------------------------------------------------- /models/scripts/invoker.go: -------------------------------------------------------------------------------- 1 | package scripts 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | ) 7 | 8 | type DBScript interface { 9 | Run(ctx context.Context) 10 | } 11 | 12 | var availableScripts = make(map[string]DBScript) 13 | 14 | func RunDBScript(name string, ctx context.Context) error { 15 | if script, ok := availableScripts[name]; ok { 16 | script.Run(ctx) 17 | return nil 18 | } 19 | 20 | return fmt.Errorf("数据库脚本 [%s] 不存在", name) 21 | } 22 | 23 | func register(name string, script DBScript) { 24 | availableScripts[name] = script 25 | } 26 | -------------------------------------------------------------------------------- /models/scripts/invoker_test.go: -------------------------------------------------------------------------------- 1 | package scripts 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "github.com/DATA-DOG/go-sqlmock" 7 | model "github.com/cloudreve/Cloudreve/v3/models" 8 | "github.com/jinzhu/gorm" 9 | "github.com/stretchr/testify/assert" 10 | "testing" 11 | ) 12 | 13 | var mock sqlmock.Sqlmock 14 | var mockDB *gorm.DB 15 | 16 | type TestScript int 17 | 18 | func (script TestScript) Run(ctx context.Context) { 19 | 20 | } 21 | 22 | // TestMain 初始化数据库Mock 23 | func TestMain(m *testing.M) { 24 | var db *sql.DB 25 | var err error 26 | db, mock, err = sqlmock.New() 27 | if err != nil { 28 | panic("An error was not expected when opening a stub database connection") 29 | } 30 | 
model.DB, _ = gorm.Open("mysql", db) 31 | mockDB = model.DB 32 | defer db.Close() 33 | m.Run() 34 | } 35 | 36 | func TestRunDBScript(t *testing.T) { 37 | asserts := assert.New(t) 38 | register("test", TestScript(0)) 39 | 40 | // 不存在 41 | { 42 | asserts.Error(RunDBScript("else", context.Background())) 43 | } 44 | 45 | // 存在 46 | { 47 | asserts.NoError(RunDBScript("test", context.Background())) 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /models/scripts/storage.go: -------------------------------------------------------------------------------- 1 | package scripts 2 | 3 | import ( 4 | "context" 5 | model "github.com/cloudreve/Cloudreve/v3/models" 6 | "github.com/cloudreve/Cloudreve/v3/pkg/util" 7 | ) 8 | 9 | type UserStorageCalibration int 10 | 11 | func init() { 12 | register("CalibrateUserStorage", UserStorageCalibration(0)) 13 | } 14 | 15 | type storageResult struct { 16 | Total uint64 17 | } 18 | 19 | // Run 运行脚本校准所有用户容量 20 | func (script UserStorageCalibration) Run(ctx context.Context) { 21 | // 列出所有用户 22 | var res []model.User 23 | model.DB.Model(&model.User{}).Find(&res) 24 | 25 | // 逐个检查容量 26 | for _, user := range res { 27 | // 计算正确的容量 28 | var total storageResult 29 | model.DB.Model(&model.File{}).Where("user_id = ?", user.ID).Select("sum(size) as total").Scan(&total) 30 | // 更新用户的容量 31 | if user.Storage != total.Total { 32 | util.Log().Info("将用户 [%s] 的容量由 %d 校准为 %d", user.Email, 33 | user.Storage, total.Total) 34 | model.DB.Model(&user).Update("storage", total.Total) 35 | } 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /models/scripts/storage_test.go: -------------------------------------------------------------------------------- 1 | package scripts 2 | 3 | import ( 4 | "context" 5 | "github.com/DATA-DOG/go-sqlmock" 6 | "github.com/stretchr/testify/assert" 7 | "testing" 8 | ) 9 | 10 | func TestUserStorageCalibration_Run(t *testing.T) { 11 | 
asserts := assert.New(t) 12 | script := UserStorageCalibration(0) 13 | 14 | // 容量异常 15 | { 16 | mock.ExpectQuery("SELECT(.+)users(.+)"). 17 | WillReturnRows(sqlmock.NewRows([]string{"id", "email", "storage"}).AddRow(1, "a@a.com", 10)) 18 | mock.ExpectQuery("SELECT(.+)files(.+)"). 19 | WithArgs(1). 20 | WillReturnRows(sqlmock.NewRows([]string{"total"}).AddRow(11)) 21 | mock.ExpectBegin() 22 | mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) 23 | mock.ExpectCommit() 24 | script.Run(context.Background()) 25 | asserts.NoError(mock.ExpectationsWereMet()) 26 | } 27 | 28 | // 容量正常 29 | { 30 | mock.ExpectQuery("SELECT(.+)users(.+)"). 31 | WillReturnRows(sqlmock.NewRows([]string{"id", "email", "storage"}).AddRow(1, "a@a.com", 10)) 32 | mock.ExpectQuery("SELECT(.+)files(.+)"). 33 | WithArgs(1). 34 | WillReturnRows(sqlmock.NewRows([]string{"total"}).AddRow(10)) 35 | script.Run(context.Background()) 36 | asserts.NoError(mock.ExpectationsWereMet()) 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /models/setting.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "net/url" 5 | "strconv" 6 | 7 | "github.com/cloudreve/Cloudreve/v3/pkg/cache" 8 | "github.com/jinzhu/gorm" 9 | ) 10 | 11 | // Setting 系统设置模型 12 | type Setting struct { 13 | gorm.Model 14 | Type string `gorm:"not null"` 15 | Name string `gorm:"unique;not null;index:setting_key"` 16 | Value string `gorm:"size:65535"` // plain ASCII digits only: gorm parses the size with strconv.Atoi 17 | } 18 | 19 | // IsTrueVal 返回设置的值是否为真 20 | func IsTrueVal(val string) bool { 21 | return val == "1" || val == "true" 22 | } 23 | 24 | // GetSettingByName 用 Name 获取设置值 25 | func GetSettingByName(name string) string { 26 | var setting Setting 27 | 28 | // 优先从缓存中查找 29 | cacheKey := "setting_" + name 30 | if optionValue, ok := cache.Get(cacheKey); ok { 31 | return optionValue.(string) 32 | } 33 | // 尝试数据库中查找 34 | result := DB.Where("name = ?", name).First(&setting)
35 | if result.Error == nil { 36 | _ = cache.Set(cacheKey, setting.Value, -1) 37 | return setting.Value 38 | } 39 | return "" 40 | } 41 | 42 | // GetSettingByNames 用多个 Name 获取设置值 43 | func GetSettingByNames(names ...string) map[string]string { 44 | var queryRes []Setting 45 | res, miss := cache.GetSettings(names, "setting_") 46 | 47 | if len(miss) > 0 { 48 | DB.Where("name IN (?)", miss).Find(&queryRes) 49 | for _, setting := range queryRes { 50 | res[setting.Name] = setting.Value 51 | } 52 | } 53 | 54 | _ = cache.SetSettings(res, "setting_") 55 | return res 56 | } 57 | 58 | // GetSettingByType 获取一个或多个分组的所有设置值 59 | func GetSettingByType(types []string) map[string]string { 60 | var queryRes []Setting 61 | res := make(map[string]string) 62 | 63 | DB.Where("type IN (?)", types).Find(&queryRes) 64 | for _, setting := range queryRes { 65 | res[setting.Name] = setting.Value 66 | } 67 | 68 | return res 69 | } 70 | 71 | // GetSiteURL 获取站点地址 72 | func GetSiteURL() *url.URL { 73 | base, err := url.Parse(GetSettingByName("siteURL")) 74 | if err != nil { 75 | base, _ = url.Parse("https://cloudreve.org") 76 | } 77 | return base 78 | } 79 | 80 | // GetIntSetting 获取整形设置值,如果转换失败则返回默认值defaultVal 81 | func GetIntSetting(key string, defaultVal int) int { 82 | res, err := strconv.Atoi(GetSettingByName(key)) 83 | if err != nil { 84 | return defaultVal 85 | } 86 | return res 87 | } 88 | -------------------------------------------------------------------------------- /models/tag.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "github.com/cloudreve/Cloudreve/v3/pkg/util" 5 | "github.com/jinzhu/gorm" 6 | ) 7 | 8 | // Tag 用户自定义标签 9 | type Tag struct { 10 | gorm.Model 11 | Name string // 标签名 12 | Icon string // 图标标识 13 | Color string // 图标颜色 14 | Type int // 标签类型(文件分类/目录直达) 15 | Expression string `gorm:"type:text"` // 搜索表表达式/直达路径 16 | UserID uint // 创建者ID 17 | } 18 | 19 | const ( 20 | // FileTagType 文件分类标签 21 | 
FileTagType = iota 22 | // DirectoryLinkType 目录快捷方式标签 23 | DirectoryLinkType 24 | ) 25 | 26 | // Create 创建标签记录 27 | func (tag *Tag) Create() (uint, error) { 28 | if err := DB.Create(tag).Error; err != nil { 29 | util.Log().Warning("无法插入标签记录, %s", err) 30 | return 0, err 31 | } 32 | return tag.ID, nil 33 | } 34 | 35 | // DeleteTagByID 根据给定ID和用户ID删除标签 36 | func DeleteTagByID(id, uid uint) error { 37 | result := DB.Where("id = ? and user_id = ?", id, uid).Delete(&Tag{}) 38 | return result.Error 39 | } 40 | 41 | // GetTagsByUID 根据用户ID查找标签 42 | func GetTagsByUID(uid uint) ([]Tag, error) { 43 | var tag []Tag 44 | result := DB.Where("user_id = ?", uid).Find(&tag) 45 | return tag, result.Error 46 | } 47 | 48 | // GetTagsByID 根据ID查找标签 49 | func GetTagsByID(id, uid uint) (*Tag, error) { 50 | var tag Tag 51 | result := DB.Where("user_id = ? and id = ?", uid, id).First(&tag) 52 | return &tag, result.Error 53 | } 54 | -------------------------------------------------------------------------------- /models/tag_test.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "errors" 5 | "github.com/DATA-DOG/go-sqlmock" 6 | "github.com/stretchr/testify/assert" 7 | "testing" 8 | ) 9 | 10 | func TestTag_Create(t *testing.T) { 11 | asserts := assert.New(t) 12 | tag := Tag{} 13 | 14 | // 成功 15 | { 16 | mock.ExpectBegin() 17 | mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) 18 | mock.ExpectCommit() 19 | id, err := tag.Create() 20 | asserts.NoError(mock.ExpectationsWereMet()) 21 | asserts.NoError(err) 22 | asserts.EqualValues(1, id) 23 | } 24 | 25 | // 失败 26 | { 27 | mock.ExpectBegin() 28 | mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) 29 | mock.ExpectRollback() 30 | id, err := tag.Create() 31 | asserts.NoError(mock.ExpectationsWereMet()) 32 | asserts.Error(err) 33 | asserts.EqualValues(0, id) 34 | } 35 | } 36 | 37 | func TestDeleteTagByID(t *testing.T) { 38 | asserts :=
assert.New(t) 39 | mock.ExpectBegin() 40 | mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) 41 | mock.ExpectCommit() 42 | err := DeleteTagByID(1, 2) 43 | asserts.NoError(mock.ExpectationsWereMet()) 44 | asserts.NoError(err) 45 | } 46 | 47 | func TestGetTagsByUID(t *testing.T) { 48 | asserts := assert.New(t) 49 | mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) 50 | res, err := GetTagsByUID(1) 51 | asserts.NoError(mock.ExpectationsWereMet()) 52 | asserts.NoError(err) 53 | asserts.Len(res, 1) 54 | } 55 | 56 | func TestGetTagsByID(t *testing.T) { 57 | asserts := assert.New(t) 58 | mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) 59 | res, err := GetTagsByID(1, 1) // exercise the tag lookup itself, not the task lookup 60 | asserts.NoError(mock.ExpectationsWereMet()) 61 | asserts.NoError(err) 62 | asserts.EqualValues(1, res.ID) 63 | } 64 | -------------------------------------------------------------------------------- /models/task.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "github.com/cloudreve/Cloudreve/v3/pkg/util" 5 | "github.com/jinzhu/gorm" 6 | ) 7 | 8 | // Task 任务模型 9 | type Task struct { 10 | gorm.Model 11 | Status int // 任务状态 12 | Type int // 任务类型 13 | UserID uint // 发起者UID,0表示为系统发起 14 | Progress int // 进度 15 | Error string `gorm:"type:text"` // 错误信息 16 | Props string `gorm:"type:text"` // 任务属性 17 | } 18 | 19 | // Create 创建任务记录 20 | func (task *Task) Create() (uint, error) { 21 | if err := DB.Create(task).Error; err != nil { 22 | util.Log().Warning("无法插入任务记录, %s", err) 23 | return 0, err 24 | } 25 | return task.ID, nil 26 | } 27 | 28 | // SetStatus 设定任务状态 29 | func (task *Task) SetStatus(status int) error { 30 | return DB.Model(task).Select("status").Updates(map[string]interface{}{"status": status}).Error 31 | } 32 | 33 | // SetProgress 设定任务进度 34 | func (task *Task) SetProgress(progress int) error { 35 | return
DB.Model(task).Select("progress").Updates(map[string]interface{}{"progress": progress}).Error 36 | } 37 | 38 | // SetError 设定错误信息 39 | func (task *Task) SetError(err string) error { 40 | return DB.Model(task).Select("error").Updates(map[string]interface{}{"error": err}).Error 41 | } 42 | 43 | // GetTasksByStatus 根据状态检索任务 44 | func GetTasksByStatus(status ...int) []Task { 45 | var tasks []Task 46 | DB.Where("status in (?)", status).Find(&tasks) 47 | return tasks 48 | } 49 | 50 | // GetTasksByID 根据ID检索任务 51 | func GetTasksByID(id interface{}) (*Task, error) { 52 | task := &Task{} 53 | result := DB.Where("id = ?", id).First(task) 54 | return task, result.Error 55 | } 56 | 57 | // ListTasks 列出用户所属的任务 58 | func ListTasks(uid uint, page, pageSize int, order string) ([]Task, int) { 59 | var ( 60 | tasks []Task 61 | total int 62 | ) 63 | dbChain := DB 64 | dbChain = dbChain.Where("user_id = ?", uid) 65 | 66 | // 计算总数用于分页 67 | dbChain.Model(&Task{}).Count(&total) // count this user's tasks (previously counted shares by mistake) 68 | 69 | // 查询记录 70 | dbChain.Limit(pageSize).Offset((page - 1) * pageSize).Order(order).Find(&tasks) 71 | 72 | return tasks, total 73 | } 74 | -------------------------------------------------------------------------------- /models/task_test.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "errors" 5 | "github.com/DATA-DOG/go-sqlmock" 6 | "github.com/jinzhu/gorm" 7 | "github.com/stretchr/testify/assert" 8 | "testing" 9 | ) 10 | 11 | func TestTask_Create(t *testing.T) { 12 | asserts := assert.New(t) 13 | // 成功 14 | { 15 | mock.ExpectBegin() 16 | mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) 17 | mock.ExpectCommit() 18 | task := Task{Props: "1"} 19 | id, err := task.Create() 20 | asserts.NoError(mock.ExpectationsWereMet()) 21 | asserts.NoError(err) 22 | asserts.EqualValues(1, id) 23 | } 24 | 25 | // 失败 26 | { 27 | mock.ExpectBegin() 28 | mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) 29 |
mock.ExpectRollback() 30 | task := Task{Props: "1"} 31 | id, err := task.Create() 32 | asserts.NoError(mock.ExpectationsWereMet()) 33 | asserts.Error(err) 34 | asserts.EqualValues(0, id) 35 | } 36 | } 37 | 38 | func TestTask_SetError(t *testing.T) { 39 | asserts := assert.New(t) 40 | task := Task{ 41 | Model: gorm.Model{ID: 1}, 42 | } 43 | mock.ExpectBegin() 44 | mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) 45 | mock.ExpectCommit() 46 | asserts.NoError(task.SetError("error")) 47 | asserts.NoError(mock.ExpectationsWereMet()) 48 | } 49 | 50 | func TestTask_SetStatus(t *testing.T) { 51 | asserts := assert.New(t) 52 | task := Task{ 53 | Model: gorm.Model{ID: 1}, 54 | } 55 | mock.ExpectBegin() 56 | mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) 57 | mock.ExpectCommit() 58 | asserts.NoError(task.SetStatus(1)) 59 | asserts.NoError(mock.ExpectationsWereMet()) 60 | } 61 | 62 | func TestTask_SetProgress(t *testing.T) { 63 | asserts := assert.New(t) 64 | task := Task{ 65 | Model: gorm.Model{ID: 1}, 66 | } 67 | mock.ExpectBegin() 68 | mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) 69 | mock.ExpectCommit() 70 | asserts.NoError(task.SetProgress(1)) 71 | asserts.NoError(mock.ExpectationsWereMet()) 72 | } 73 | 74 | func TestGetTasksByID(t *testing.T) { 75 | asserts := assert.New(t) 76 | mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) 77 | res, err := GetTasksByID(1) 78 | asserts.NoError(mock.ExpectationsWereMet()) 79 | asserts.NoError(err) 80 | asserts.EqualValues(1, res.ID) 81 | } 82 | 83 | func TestListTasks(t *testing.T) { 84 | asserts := assert.New(t) 85 | 86 | mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(5)) 87 | mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(5)) 88 | 89 | res, total := ListTasks(1, 1, 10, "") 90 | asserts.NoError(mock.ExpectationsWereMet()) 91 | 
asserts.EqualValues(5, total) 92 | asserts.Len(res, 1) 93 | } 94 | -------------------------------------------------------------------------------- /models/user_authn.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "encoding/base64" 5 | "encoding/binary" 6 | "encoding/json" 7 | "fmt" 8 | "net/url" 9 | 10 | "github.com/cloudreve/Cloudreve/v3/pkg/hashid" 11 | "github.com/duo-labs/webauthn/webauthn" 12 | ) 13 | 14 | /* 15 | `webauthn.User` 接口的实现 16 | */ 17 | 18 | // WebAuthnID 返回用户ID 19 | func (user User) WebAuthnID() []byte { 20 | bs := make([]byte, 8) 21 | binary.LittleEndian.PutUint64(bs, uint64(user.ID)) 22 | return bs 23 | } 24 | 25 | // WebAuthnName 返回用户名 26 | func (user User) WebAuthnName() string { 27 | return user.Email 28 | } 29 | 30 | // WebAuthnDisplayName 获得用于展示的用户名 31 | func (user User) WebAuthnDisplayName() string { 32 | return user.Nick 33 | } 34 | 35 | // WebAuthnIcon 获得用户头像 36 | func (user User) WebAuthnIcon() string { 37 | avatar, _ := url.Parse("/api/v3/user/avatar/" + hashid.HashID(user.ID, hashid.UserID) + "/l") 38 | base := GetSiteURL() 39 | base.Scheme = "https" 40 | return base.ResolveReference(avatar).String() 41 | } 42 | 43 | // WebAuthnCredentials 获得已注册的验证器凭证 44 | func (user User) WebAuthnCredentials() []webauthn.Credential { 45 | var res []webauthn.Credential 46 | err := json.Unmarshal([]byte(user.Authn), &res) 47 | if err != nil { 48 | fmt.Println(err) 49 | } 50 | return res 51 | } 52 | 53 | // RegisterAuthn 添加新的验证器 54 | func (user *User) RegisterAuthn(credential *webauthn.Credential) error { 55 | exists := user.WebAuthnCredentials() 56 | exists = append(exists, *credential) 57 | res, err := json.Marshal(exists) 58 | if err != nil { 59 | return err 60 | } 61 | 62 | return DB.Model(user).Update("authn", string(res)).Error 63 | } 64 | 65 | // RemoveAuthn 删除验证器 66 | func (user *User) RemoveAuthn(id string) { 67 | exists := user.WebAuthnCredentials() 68 | for i := 0; 
i < len(exists); i++ { 69 | idEncoded := base64.StdEncoding.EncodeToString(exists[i].ID) 70 | if idEncoded == id { 71 | exists[len(exists)-1], exists[i] = exists[i], exists[len(exists)-1] 72 | exists = exists[:len(exists)-1] 73 | break 74 | } 75 | } 76 | 77 | res, _ := json.Marshal(exists) 78 | DB.Model(user).Update("authn", string(res)) 79 | } 80 | -------------------------------------------------------------------------------- /models/user_authn_test.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "github.com/DATA-DOG/go-sqlmock" 5 | "github.com/duo-labs/webauthn/webauthn" 6 | "github.com/jinzhu/gorm" 7 | "github.com/stretchr/testify/assert" 8 | "testing" 9 | ) 10 | 11 | func TestUser_RegisterAuthn(t *testing.T) { 12 | asserts := assert.New(t) 13 | credential := webauthn.Credential{} 14 | user := User{ 15 | Model: gorm.Model{ID: 1}, 16 | } 17 | 18 | { 19 | mock.ExpectBegin() 20 | mock.ExpectExec("UPDATE(.+)").
21 | WillReturnResult(sqlmock.NewResult(1, 1)) 22 | mock.ExpectCommit() 23 | user.RegisterAuthn(&credential) 24 | asserts.NoError(mock.ExpectationsWereMet()) 25 | } 26 | } 27 | 28 | func TestUser_WebAuthnCredentials(t *testing.T) { 29 | asserts := assert.New(t) 30 | user := User{ 31 | Model: gorm.Model{ID: 1}, 32 | Authn: `[{"ID":"123","PublicKey":"+4sg1vYcjg/+=","AttestationType":"packed","Authenticator":{"AAGUID":"+lg==","SignCount":0,"CloneWarning":false}}]`, 33 | } 34 | { 35 | credentials := user.WebAuthnCredentials() 36 | asserts.Len(credentials, 1) 37 | } 38 | } 39 | 40 | func TestUser_WebAuthnDisplayName(t *testing.T) { 41 | asserts := assert.New(t) 42 | user := User{ 43 | Model: gorm.Model{ID: 1}, 44 | Nick: "123", 45 | } 46 | { 47 | nick := user.WebAuthnDisplayName() 48 | asserts.Equal("123", nick) 49 | } 50 | } 51 | 52 | func TestUser_WebAuthnIcon(t *testing.T) { 53 | asserts := assert.New(t) 54 | user := User{ 55 | Model: gorm.Model{ID: 1}, 56 | } 57 | { 58 | icon := user.WebAuthnIcon() 59 | asserts.NotEmpty(icon) 60 | } 61 | } 62 | 63 | func TestUser_WebAuthnID(t *testing.T) { 64 | asserts := assert.New(t) 65 | user := User{ 66 | Model: gorm.Model{ID: 1}, 67 | } 68 | { 69 | id := user.WebAuthnID() 70 | asserts.Len(id, 8) 71 | } 72 | } 73 | 74 | func TestUser_WebAuthnName(t *testing.T) { 75 | asserts := assert.New(t) 76 | user := User{ 77 | Model: gorm.Model{ID: 1}, 78 | Email: "abslant@foxmail.com", 79 | } 80 | { 81 | name := user.WebAuthnName() 82 | asserts.Equal("abslant@foxmail.com", name) 83 | } 84 | } 85 | 86 | func TestUser_RemoveAuthn(t *testing.T) { 87 | asserts := assert.New(t) 88 | user := User{ 89 | Model: gorm.Model{ID: 1}, 90 | Authn: `[{"ID":"123","PublicKey":"+4sg1vYcjg/+=","AttestationType":"packed","Authenticator":{"AAGUID":"+lg==","SignCount":0,"CloneWarning":false}}]`, 91 | } 92 | { 93 | mock.ExpectBegin() 94 | mock.ExpectExec("UPDATE(.+)").
95 | WillReturnResult(sqlmock.NewResult(1, 1)) 96 | mock.ExpectCommit() 97 | user.RemoveAuthn("123") 98 | asserts.NoError(mock.ExpectationsWereMet()) 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /models/webdav.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "github.com/jinzhu/gorm" 5 | ) 6 | 7 | // Webdav is a WebDAV application account owned by a user 8 | type Webdav struct { 9 | gorm.Model 10 | Name string // application name 11 | Password string `gorm:"unique_index:password_only_on"` // application password, unique together with UserID 12 | UserID uint `gorm:"unique_index:password_only_on"` // owner user ID 13 | Root string `gorm:"type:text"` // root directory 14 | } 15 | 16 | // Create inserts the account and returns its ID, or 0 and the error on failure 17 | func (webdav *Webdav) Create() (uint, error) { 18 | if err := DB.Create(webdav).Error; err != nil { 19 | return 0, err 20 | } 21 | return webdav.ID, nil 22 | } 23 | 24 | // GetWebdavByPassword finds the WebDAV account of user uid that has the given password 25 | func GetWebdavByPassword(password string, uid uint) (*Webdav, error) { 26 | webdav := &Webdav{} 27 | res := DB.Where("user_id = ? and password = ?", uid, password).First(webdav) 28 | return webdav, res.Error 29 | } 30 | 31 | // ListWebDAVAccounts lists all WebDAV accounts of user uid, newest first 32 | func ListWebDAVAccounts(uid uint) []Webdav { 33 | var accounts []Webdav 34 | DB.Where("user_id = ?", uid).Order("created_at desc").Find(&accounts) 35 | return accounts 36 | } 37 | 38 | // DeleteWebDAVAccountByID deletes the account with the given id if it belongs to user uid 39 | func DeleteWebDAVAccountByID(id, uid uint) { 40 | DB.Where("user_id = ?
and id = ?", uid, id).Delete(&Webdav{}) 41 | } 42 | -------------------------------------------------------------------------------- /models/webdav_test.go: -------------------------------------------------------------------------------- 1 | package model 2 | 3 | import ( 4 | "errors" 5 | "github.com/DATA-DOG/go-sqlmock" 6 | "github.com/stretchr/testify/assert" 7 | "testing" 8 | ) 9 | 10 | func TestWebdav_Create(t *testing.T) { 11 | asserts := assert.New(t) 12 | // Success 13 | { 14 | mock.ExpectBegin() 15 | mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) 16 | mock.ExpectCommit() 17 | task := Webdav{} 18 | id, err := task.Create() 19 | asserts.NoError(mock.ExpectationsWereMet()) 20 | asserts.NoError(err) 21 | asserts.EqualValues(1, id) 22 | } 23 | 24 | // Failure 25 | { 26 | mock.ExpectBegin() 27 | mock.ExpectExec("INSERT(.+)").WillReturnError(errors.New("error")) 28 | mock.ExpectRollback() 29 | task := Webdav{} 30 | id, err := task.Create() 31 | asserts.NoError(mock.ExpectationsWereMet()) 32 | asserts.Error(err) 33 | asserts.EqualValues(0, id) 34 | } 35 | } 36 | 37 | func TestGetWebdavByPassword(t *testing.T) { 38 | asserts := assert.New(t) 39 | mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"})) 40 | _, err := GetWebdavByPassword("e", 1) 41 | asserts.NoError(mock.ExpectationsWereMet()) 42 | asserts.Error(err) 43 | } 44 | 45 | func TestListWebDAVAccounts(t *testing.T) { 46 | asserts := assert.New(t) 47 | mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id"})) 48 | res := ListWebDAVAccounts(1) 49 | asserts.NoError(mock.ExpectationsWereMet()) 50 | asserts.Len(res, 0) 51 | } 52 | 53 | func TestDeleteWebDAVAccountByID(t *testing.T) { 54 | asserts := assert.New(t) 55 | mock.ExpectBegin() 56 | mock.ExpectExec("UPDATE(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) // gorm.Model has DeletedAt, so Delete is a soft-delete UPDATE 57 | mock.ExpectCommit() 58 | DeleteWebDAVAccountByID(1, 1) // returns nothing; success is checked through the mock expectations below 59 | asserts.NoError(mock.ExpectationsWereMet()) 60 | } 61 |
-------------------------------------------------------------------------------- /pkg/aria2/aria2_test.go: -------------------------------------------------------------------------------- 1 | package aria2 2 | 3 | import ( 4 | "database/sql" 5 | "testing" 6 | 7 | "github.com/DATA-DOG/go-sqlmock" 8 | model "github.com/cloudreve/Cloudreve/v3/models" 9 | "github.com/cloudreve/Cloudreve/v3/pkg/cache" 10 | "github.com/jinzhu/gorm" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | var mock sqlmock.Sqlmock 15 | 16 | // TestMain backs model.DB with an sqlmock connection for the tests of this package 17 | func TestMain(m *testing.M) { 18 | var db *sql.DB 19 | var err error 20 | db, mock, err = sqlmock.New() 21 | if err != nil { 22 | panic("An error was not expected when opening a stub database connection") 23 | } 24 | model.DB, _ = gorm.Open("mysql", db) 25 | defer db.Close() 26 | m.Run() 27 | } 28 | 29 | func TestDummyAria2(t *testing.T) { 30 | asserts := assert.New(t) 31 | instance := DummyAria2{} 32 | asserts.Error(instance.CreateTask(nil, nil)) 33 | _, err := instance.Status(nil) 34 | asserts.Error(err) 35 | asserts.Error(instance.Cancel(nil)) 36 | asserts.Error(instance.Select(nil, nil)) 37 | } 38 | 39 | func TestInit(t *testing.T) { 40 | MAX_RETRY = 0 41 | asserts := assert.New(t) 42 | cache.Set("setting_aria2_token", "1", 0) 43 | cache.Set("setting_aria2_call_timeout", "5", 0) 44 | cache.Set("setting_aria2_options", `[]`, 0) 45 | 46 | // No RPC URL configured: skipped 47 | { 48 | cache.Set("setting_aria2_rpcurl", "", 0) 49 | Init(false) 50 | asserts.IsType(&DummyAria2{}, Instance) 51 | } 52 | 53 | // Server URL cannot be parsed 54 | { 55 | cache.Set("setting_aria2_rpcurl", string(byte(0x7f)), 0) 56 | Init(false) 57 | asserts.IsType(&DummyAria2{}, Instance) 58 | } 59 | 60 | // Global options cannot be parsed 61 | { 62 | Instance = &RPCService{} 63 | cache.Set("setting_aria2_options", "?", 0) 64 | cache.Set("setting_aria2_rpcurl", "ws://127.0.0.1:1234", 0) 65 | Init(false) 66 | asserts.IsType(&DummyAria2{}, Instance) 67 | } 68 | 69 | // Connection fails 70 | { 71 |
cache.Set("setting_aria2_options", "{}", 0) 72 | cache.Set("setting_aria2_rpcurl", "http://127.0.0.1:1234", 0) 73 | cache.Set("setting_aria2_call_timeout", "1", 0) 74 | cache.Set("setting_aria2_interval", "100", 0) 75 | mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"g_id"}).AddRow("1")) 76 | Init(false) 77 | asserts.NoError(mock.ExpectationsWereMet()) 78 | asserts.IsType(&RPCService{}, Instance) 79 | } 80 | } 81 | 82 | func TestGetStatus(t *testing.T) { 83 | asserts := assert.New(t) 84 | asserts.Equal(4, getStatus("complete")) 85 | asserts.Equal(1, getStatus("active")) 86 | asserts.Equal(0, getStatus("waiting")) 87 | asserts.Equal(2, getStatus("paused")) 88 | asserts.Equal(3, getStatus("error")) 89 | asserts.Equal(5, getStatus("removed")) 90 | asserts.Equal(6, getStatus("?")) 91 | } 92 | -------------------------------------------------------------------------------- /pkg/aria2/caller.go: -------------------------------------------------------------------------------- 1 | package aria2 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "path/filepath" 7 | "strconv" 8 | "strings" 9 | "time" 10 | 11 | model "github.com/cloudreve/Cloudreve/v3/models" 12 | "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" 13 | "github.com/cloudreve/Cloudreve/v3/pkg/util" 14 | ) 15 | 16 | // RPCService is an aria2 task manager that talks to aria2 through its RPC interface 17 | type RPCService struct { 18 | options *clientOptions 19 | Caller rpc.Client 20 | } 21 | 22 | // clientOptions holds the settings applied to every task created by an RPCService 23 | type clientOptions struct { 24 | Options map[string]interface{} // extra options added to every download when it is created 25 | } 26 | 27 | // Init (re)connects to the aria2 RPC server, closing any previous connection first; timeout is in seconds 28 | func (client *RPCService) Init(server, secret string, timeout int, options map[string]interface{}) error { 29 | // Close the previous connection if a client already exists 30 | if client.Caller != nil { 31 | client.Caller.Close() 32 | } 33 | 34 | client.options = &clientOptions{ 35 | Options: options, 36 | } 37 | caller, err := rpc.New(context.Background(), server, secret, time.Duration(timeout)*time.Second, 38 | EventNotifier) 39 | client.Caller = caller 40 | return err 41 | } 42 | 43 | //
Status queries the status of task, retrying once after 10 seconds if the first call fails 44 | func (client *RPCService) Status(task *model.Download) (rpc.StatusInfo, error) { 45 | res, err := client.Caller.TellStatus(task.GID) 46 | if err != nil { 47 | // Retry once after a failure 48 | util.Log().Debug("无法获取离线下载状态,%s,10秒钟后重试", err) 49 | time.Sleep(time.Duration(10) * time.Second) 50 | res, err = client.Caller.TellStatus(task.GID) 51 | } 52 | 53 | return res, err 54 | } 55 | 56 | // Cancel removes the download from aria2; a failure is logged and returned 57 | func (client *RPCService) Cancel(task *model.Download) error { 58 | // Remove the download task 59 | _, err := client.Caller.Remove(task.GID) 60 | if err != nil { 61 | util.Log().Warning("无法取消离线下载任务[%s], %s", task.GID, err) 62 | } 63 | 64 | //// Delete temporary files 65 | //util.Log().Debug("离线下载任务[%s]已取消,1 分钟后删除临时文件", task.GID) 66 | //go func(task *model.Download) { 67 | // select { 68 | // case <-time.After(time.Duration(60) * time.Second): 69 | // err := os.RemoveAll(task.Parent) 70 | // if err != nil { 71 | // util.Log().Warning("无法删除离线下载临时目录[%s], %s", task.Parent, err) 72 | // } 73 | // } 74 | //}(task) 75 | 76 | return err 77 | } 78 | 79 | // Select restricts the download to the given file indexes via aria2's "select-file" option 80 | func (client *RPCService) Select(task *model.Download, files []int) error { 81 | var selected = make([]string, len(files)) 82 | for i := 0; i < len(files); i++ { 83 | selected[i] = strconv.Itoa(files[i]) 84 | } 85 | _, err := client.Caller.ChangeOption(task.GID, map[string]interface{}{"select-file": strings.Join(selected, ",")}) 86 | return err 87 | } 88 | 89 | // CreateTask submits task.Source to aria2, saves task with the returned GID and starts monitoring it. 90 | // It returns an error if aria2 fails, returns an empty GID, or the task cannot be saved. 91 | func (client *RPCService) CreateTask(task *model.Download, groupOptions map[string]interface{}) error { 92 | // Download into a temporary directory named after the current UnixNano timestamp 93 | path := filepath.Join( 94 | model.GetSettingByName("aria2_temp_path"), 95 | "aria2", 96 | strconv.FormatInt(time.Now().UnixNano(), 10), 97 | ) 98 | 99 | // Build the download options; client options and then group options override earlier keys 100 | options := map[string]interface{}{ 101 | "dir": path, 102 | } 103 | for k, v := range client.options.Options { 104 | options[k] = v 105 | } 106 | for k, v := range groupOptions { 107 | options[k] = v 108 | } 109 | 110 | gid, err := client.Caller.AddURI(task.Source,
options) 111 | if err != nil { 112 | return err 113 | } 114 | // Without a GID the task can be neither saved nor monitored, so do not report success 115 | if gid == "" { 116 | return errors.New("aria2 returned an empty GID") 117 | } 118 | 119 | // Save to the database 120 | task.GID = gid 121 | _, err = task.Create() 122 | if err != nil { 123 | return err 124 | } 125 | 126 | // Start monitoring the task 127 | NewMonitor(task) 128 | 129 | return nil 130 | } 131 | -------------------------------------------------------------------------------- /pkg/aria2/caller_test.go: -------------------------------------------------------------------------------- 1 | package aria2 2 | 3 | import ( 4 | "testing" 5 | 6 | model "github.com/cloudreve/Cloudreve/v3/models" 7 | "github.com/cloudreve/Cloudreve/v3/pkg/cache" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestRPCService_Init(t *testing.T) { 12 | asserts := assert.New(t) 13 | caller := &RPCService{} 14 | asserts.Error(caller.Init("ws://", "", 1, nil)) 15 | asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil)) 16 | } 17 | 18 | func TestRPCService_Status(t *testing.T) { 19 | asserts := assert.New(t) 20 | caller := &RPCService{} 21 | asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil)) 22 | 23 | _, err := caller.Status(&model.Download{}) 24 | asserts.Error(err) 25 | } 26 | 27 | func TestRPCService_Cancel(t *testing.T) { 28 | asserts := assert.New(t) 29 | caller := &RPCService{} 30 | asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil)) 31 | 32 | err := caller.Cancel(&model.Download{Parent: "test"}) 33 | asserts.Error(err) 34 | } 35 | 36 | func TestRPCService_Select(t *testing.T) { 37 | asserts := assert.New(t) 38 | caller := &RPCService{} 39 | asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil)) 40 | 41 | err := caller.Select(&model.Download{Parent: "test"}, []int{1, 2, 3}) 42 | asserts.Error(err) 43 | } 44 | 45 | func TestRPCService_CreateTask(t *testing.T) { 46 | asserts := assert.New(t) 47 | caller := &RPCService{} 48 | asserts.NoError(caller.Init("http://127.0.0.1", "", 1, nil)) 49 | cache.Set("setting_aria2_temp_path", "test", 0) 50 | err :=
caller.CreateTask(&model.Download{Parent: "test"}, map[string]interface{}{"1": "1"}) 51 | asserts.Error(err) 52 | } 53 | -------------------------------------------------------------------------------- /pkg/aria2/notification.go: -------------------------------------------------------------------------------- 1 | package aria2 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" 7 | ) 8 | 9 | // Notifier dispatches aria2 event notifications to the channel subscribed for each GID 10 | type Notifier struct { 11 | Subscribes sync.Map 12 | } 13 | 14 | // Subscribe registers target to receive status events for gid 15 | func (notifier *Notifier) Subscribe(target chan StatusEvent, gid string) { 16 | notifier.Subscribes.Store(gid, target) 17 | } 18 | 19 | // Unsubscribe stops delivering status events for gid 20 | func (notifier *Notifier) Unsubscribe(gid string) { 21 | notifier.Subscribes.Delete(gid) 22 | } 23 | 24 | // Notify sends status to the subscriber of each event's GID; events without a subscriber are dropped, and each send blocks until the channel accepts it 25 | func (notifier *Notifier) Notify(events []rpc.Event, status int) { 26 | for _, event := range events { 27 | if target, ok := notifier.Subscribes.Load(event.Gid); ok { 28 | target.(chan StatusEvent) <- StatusEvent{ 29 | GID: event.Gid, 30 | Status: status, 31 | } 32 | } 33 | } 34 | } 35 | 36 | // OnDownloadStart reports a started download as Downloading 37 | func (notifier *Notifier) OnDownloadStart(events []rpc.Event) { 38 | notifier.Notify(events, Downloading) 39 | } 40 | 41 | // OnDownloadPause reports a paused download as Paused 42 | func (notifier *Notifier) OnDownloadPause(events []rpc.Event) { 43 | notifier.Notify(events, Paused) 44 | } 45 | 46 | // OnDownloadStop reports a stopped download as Canceled 47 | func (notifier *Notifier) OnDownloadStop(events []rpc.Event) { 48 | notifier.Notify(events, Canceled) 49 | } 50 | 51 | // OnDownloadComplete reports a finished download as Complete 52 | func (notifier *Notifier) OnDownloadComplete(events []rpc.Event) { 53 | notifier.Notify(events, Complete) 54 | } 55 | 56 | // OnDownloadError reports a failed download as Error 57 | func (notifier *Notifier) OnDownloadError(events []rpc.Event) { 58 | notifier.Notify(events, Error) 59 | } 60 | 61 | // OnBtDownloadComplete reports a finished BitTorrent download as Complete 62 | func (notifier *Notifier) OnBtDownloadComplete(events []rpc.Event) { 63 |
notifier.Notify(events, Complete) 64 | } 65 | -------------------------------------------------------------------------------- /pkg/aria2/notification_test.go: -------------------------------------------------------------------------------- 1 | package aria2 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestNotifier_Notify(t *testing.T) { 11 | asserts := assert.New(t) 12 | notifier2 := &Notifier{} 13 | notifyChan := make(chan StatusEvent, 10) 14 | notifier2.Subscribe(notifyChan, "1") 15 | 16 | // Not subscribed 17 | { 18 | notifier2.Notify([]rpc.Event{{Gid: ""}}, 1) 19 | asserts.Len(notifyChan, 0) 20 | } 21 | 22 | // Subscribed 23 | { 24 | notifier2.Notify([]rpc.Event{{Gid: "1"}}, 1) 25 | asserts.Len(notifyChan, 1) 26 | <-notifyChan 27 | 28 | notifier2.OnBtDownloadComplete([]rpc.Event{{Gid: "1"}}) 29 | asserts.Len(notifyChan, 1) 30 | <-notifyChan 31 | 32 | notifier2.OnDownloadStart([]rpc.Event{{Gid: "1"}}) 33 | asserts.Len(notifyChan, 1) 34 | <-notifyChan 35 | 36 | notifier2.OnDownloadPause([]rpc.Event{{Gid: "1"}}) 37 | asserts.Len(notifyChan, 1) 38 | <-notifyChan 39 | 40 | notifier2.OnDownloadStop([]rpc.Event{{Gid: "1"}}) 41 | asserts.Len(notifyChan, 1) 42 | <-notifyChan 43 | 44 | notifier2.OnDownloadComplete([]rpc.Event{{Gid: "1"}}) 45 | asserts.Len(notifyChan, 1) 46 | <-notifyChan 47 | 48 | notifier2.OnDownloadError([]rpc.Event{{Gid: "1"}}) 49 | asserts.Len(notifyChan, 1) 50 | <-notifyChan 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /pkg/aria2/rpc/const.go: -------------------------------------------------------------------------------- 1 | package rpc 2 | 3 | const ( 4 | aria2AddURI = "aria2.addUri" 5 | aria2AddTorrent = "aria2.addTorrent" 6 | aria2AddMetalink = "aria2.addMetalink" 7 | aria2Remove = "aria2.remove" 8 | aria2ForceRemove = "aria2.forceRemove" 9 | aria2Pause = "aria2.pause" 10 | aria2PauseAll =
"aria2.pauseAll" 11 | aria2ForcePause = "aria2.forcePause" 12 | aria2ForcePauseAll = "aria2.forcePauseAll" 13 | aria2Unpause = "aria2.unpause" 14 | aria2UnpauseAll = "aria2.unpauseAll" 15 | aria2TellStatus = "aria2.tellStatus" 16 | aria2GetURIs = "aria2.getUris" 17 | aria2GetFiles = "aria2.getFiles" 18 | aria2GetPeers = "aria2.getPeers" 19 | aria2GetServers = "aria2.getServers" 20 | aria2TellActive = "aria2.tellActive" 21 | aria2TellWaiting = "aria2.tellWaiting" 22 | aria2TellStopped = "aria2.tellStopped" 23 | aria2ChangePosition = "aria2.changePosition" 24 | aria2ChangeURI = "aria2.changeUri" 25 | aria2GetOption = "aria2.getOption" 26 | aria2ChangeOption = "aria2.changeOption" 27 | aria2GetGlobalOption = "aria2.getGlobalOption" 28 | aria2ChangeGlobalOption = "aria2.changeGlobalOption" 29 | aria2GetGlobalStat = "aria2.getGlobalStat" 30 | aria2PurgeDownloadResult = "aria2.purgeDownloadResult" 31 | aria2RemoveDownloadResult = "aria2.removeDownloadResult" 32 | aria2GetVersion = "aria2.getVersion" 33 | aria2GetSessionInfo = "aria2.getSessionInfo" 34 | aria2Shutdown = "aria2.shutdown" 35 | aria2ForceShutdown = "aria2.forceShutdown" 36 | aria2SaveSession = "aria2.saveSession" 37 | aria2Multicall = "system.multicall" 38 | aria2ListMethods = "system.listMethods" 39 | ) 40 | -------------------------------------------------------------------------------- /pkg/aria2/rpc/notification.go: -------------------------------------------------------------------------------- 1 | package rpc 2 | 3 | import ( 4 | "log" 5 | ) 6 | 7 | // Event is the payload of an aria2 notification 8 | type Event struct { 9 | Gid string `json:"gid"` // GID of the download 10 | } 11 | 12 | // The RPC server might send notifications to the client. 13 | // Notifications are unidirectional, therefore the client which receives the notification must not respond to it.
14 | // The method signature of a notification is much like a normal method request but lacks the id key. 15 | 16 | // websocketResponse is a message read from the websocket: a clientResponse plus the method and params fields carried by notifications 17 | type websocketResponse struct { 18 | clientResponse 19 | Method string `json:"method"` 20 | Params []Event `json:"params"` 21 | } 22 | 23 | // Notifier handles RPC notifications from the aria2 server 24 | type Notifier interface { 25 | // OnDownloadStart will be sent when a download is started. 26 | OnDownloadStart([]Event) 27 | // OnDownloadPause will be sent when a download is paused. 28 | OnDownloadPause([]Event) 29 | // OnDownloadStop will be sent when a download is stopped by the user. 30 | OnDownloadStop([]Event) 31 | // OnDownloadComplete will be sent when a download is complete. For BitTorrent downloads, this notification is sent when the download is complete and seeding is over. 32 | OnDownloadComplete([]Event) 33 | // OnDownloadError will be sent when a download is stopped due to an error. 34 | OnDownloadError([]Event) 35 | // OnBtDownloadComplete will be sent when a torrent download is complete but seeding is still going on.
36 | OnBtDownloadComplete([]Event) 37 | } 38 | 39 | // DummyNotifier is a Notifier that only logs each notification with the standard log package 40 | type DummyNotifier struct{} 41 | 42 | func (DummyNotifier) OnDownloadStart(events []Event) { log.Printf("%s started.", events) } 43 | func (DummyNotifier) OnDownloadPause(events []Event) { log.Printf("%s paused.", events) } 44 | func (DummyNotifier) OnDownloadStop(events []Event) { log.Printf("%s stopped.", events) } 45 | func (DummyNotifier) OnDownloadComplete(events []Event) { log.Printf("%s completed.", events) } 46 | func (DummyNotifier) OnDownloadError(events []Event) { log.Printf("%s error.", events) } 47 | func (DummyNotifier) OnBtDownloadComplete(events []Event) { log.Printf("bt %s completed.", events) } 48 | -------------------------------------------------------------------------------- /pkg/aria2/rpc/proc.go: -------------------------------------------------------------------------------- 1 | package rpc 2 | 3 | import "sync" 4 | 5 | // ResponseProcFn is a callback invoked with the response to a request 6 | type ResponseProcFn func(resp clientResponse) error 7 | 8 | // ResponseProcessor maps request ids to the callbacks waiting for their responses; mu guards cbs 9 | type ResponseProcessor struct { 10 | cbs map[uint64]ResponseProcFn 11 | mu *sync.RWMutex 12 | } 13 | 14 | // NewResponseProcessor returns an empty ResponseProcessor 15 | func NewResponseProcessor() *ResponseProcessor { 16 | return &ResponseProcessor{ 17 | make(map[uint64]ResponseProcFn), 18 | &sync.RWMutex{}, 19 | } 20 | } 21 | 22 | // Add registers fn as the callback for the response with the given id 23 | func (r *ResponseProcessor) Add(id uint64, fn ResponseProcFn) { 24 | r.mu.Lock() 25 | r.cbs[id] = fn 26 | r.mu.Unlock() 27 | } 28 | 29 | // remove drops the callback registered for id 30 | func (r *ResponseProcessor) remove(id uint64) { 31 | r.mu.Lock() 32 | delete(r.cbs, id) 33 | r.mu.Unlock() 34 | } 35 | 36 | // Process is called by the recv routine: it runs the callback registered for resp's id, if any, and unregisters it afterwards 37 | func (r *ResponseProcessor) Process(resp clientResponse) error { 38 | // Notifications carry no id, so there is no callback to look up 39 | if resp.Id == nil { 40 | return nil 41 | } 42 | id := *resp.Id 43 | r.mu.RLock() 44 | fn, ok := r.cbs[id] 45 | r.mu.RUnlock() 46 | if ok && fn != nil { 47 | defer r.remove(id) 48 | return fn(resp) 49 | } 50 | return nil 51 | } 52 | -------------------------------------------------------------------------------- /pkg/aria2/rpc/proto.go: -------------------------------------------------------------------------------- 1 | package rpc 2 |
3 | // Protocol is the set of RPC methods supported by the aria2 daemon 4 | type Protocol interface { 5 | AddURI(uri string, options ...interface{}) (gid string, err error) 6 | AddTorrent(filename string, options ...interface{}) (gid string, err error) 7 | AddMetalink(filename string, options ...interface{}) (gid []string, err error) 8 | Remove(gid string) (g string, err error) 9 | ForceRemove(gid string) (g string, err error) 10 | Pause(gid string) (g string, err error) 11 | PauseAll() (ok string, err error) 12 | ForcePause(gid string) (g string, err error) 13 | ForcePauseAll() (ok string, err error) 14 | Unpause(gid string) (g string, err error) 15 | UnpauseAll() (ok string, err error) 16 | TellStatus(gid string, keys ...string) (info StatusInfo, err error) 17 | GetURIs(gid string) (infos []URIInfo, err error) 18 | GetFiles(gid string) (infos []FileInfo, err error) 19 | GetPeers(gid string) (infos []PeerInfo, err error) 20 | GetServers(gid string) (infos []ServerInfo, err error) 21 | TellActive(keys ...string) (infos []StatusInfo, err error) 22 | TellWaiting(offset, num int, keys ...string) (infos []StatusInfo, err error) 23 | TellStopped(offset, num int, keys ...string) (infos []StatusInfo, err error) 24 | ChangePosition(gid string, pos int, how string) (p int, err error) 25 | ChangeURI(gid string, fileindex int, delUris []string, addUris []string, position ...int) (p []int, err error) 26 | GetOption(gid string) (m Option, err error) 27 | ChangeOption(gid string, option Option) (ok string, err error) 28 | GetGlobalOption() (m Option, err error) 29 | ChangeGlobalOption(options Option) (ok string, err error) 30 | GetGlobalStat() (info GlobalStatInfo, err error) 31 | PurgeDownloadResult() (ok string, err error) 32 | RemoveDownloadResult(gid string) (ok string, err error) 33 | GetVersion() (info VersionInfo, err error) 34 | GetSessionInfo() (info SessionInfo, err error) 35 | Shutdown() (ok string, err error) 36 | ForceShutdown() (ok string, err error) 37 | SaveSession()
(ok string, err error) 38 | Multicall(methods []Method) (r []interface{}, err error) 39 | ListMethods() (methods []string, err error) 40 | } 41 | -------------------------------------------------------------------------------- /pkg/auth/auth_test.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "io/ioutil" 5 | "net/http" 6 | "strings" 7 | "testing" 8 | 9 | "github.com/cloudreve/Cloudreve/v3/pkg/util" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestSignURI(t *testing.T) { 14 | asserts := assert.New(t) 15 | General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))} 16 | 17 | // Success 18 | { 19 | sign, err := SignURI(General, "/api/v3/something?id=1", 0) 20 | asserts.NoError(err) 21 | queries := sign.Query() 22 | asserts.Equal("1", queries.Get("id")) 23 | asserts.NotEmpty(queries.Get("sign")) 24 | } 25 | 26 | // URI cannot be parsed 27 | { 28 | sign, err := SignURI(General, "://dg.;'f]gh./'", 0) 29 | asserts.Error(err) 30 | asserts.Nil(sign) 31 | } 32 | } 33 | 34 | func TestCheckURI(t *testing.T) { 35 | asserts := assert.New(t) 36 | General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))} 37 | 38 | // Success 39 | { 40 | sign, err := SignURI(General, "/api/ok?if=sdf&fd=go", 10) 41 | asserts.NoError(err) 42 | asserts.NoError(CheckURI(General, sign)) 43 | } 44 | 45 | // Expired 46 | { 47 | sign, err := SignURI(General, "/api/ok?if=sdf&fd=go", -1) 48 | asserts.NoError(err) 49 | asserts.Error(CheckURI(General, sign)) 50 | } 51 | } 52 | 53 | func TestSignRequest(t *testing.T) { 54 | asserts := assert.New(t) 55 | General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))} 56 | 57 | // Non-upload request 58 | { 59 | req, err := http.NewRequest("POST", "http://127.0.0.1/api/v3/slave/upload", strings.NewReader("I am body.")) 60 | asserts.NoError(err) 61 | req = SignRequest(General, req, 0) 62 | asserts.NotEmpty(req.Header["Authorization"]) 63 | } 64 | 65 | // Upload request 66 | { 67 | req, err := http.NewRequest(
68 | "POST", 69 | "http://127.0.0.1/api/v3/slave/upload", 70 | strings.NewReader("I am body."), 71 | ) 72 | asserts.NoError(err) 73 | req.Header["X-Policy"] = []string{"I am Policy"} 74 | req = SignRequest(General, req, 10) 75 | asserts.NotEmpty(req.Header["Authorization"]) 76 | } 77 | } 78 | 79 | func TestCheckRequest(t *testing.T) { 80 | asserts := assert.New(t) 81 | General = HMACAuth{SecretKey: []byte(util.RandStringRunes(256))} 82 | 83 | // Non-upload request: verification succeeds 84 | { 85 | req, err := http.NewRequest( 86 | "POST", 87 | "http://127.0.0.1/api/v3/upload", 88 | strings.NewReader("I am body."), 89 | ) 90 | asserts.NoError(err) 91 | req = SignRequest(General, req, 0) 92 | err = CheckRequest(General, req) 93 | asserts.NoError(err) 94 | } 95 | 96 | // Upload request: verification succeeds 97 | { 98 | req, err := http.NewRequest( 99 | "POST", 100 | "http://127.0.0.1/api/v3/upload", 101 | strings.NewReader("I am body."), 102 | ) 103 | asserts.NoError(err) 104 | req.Header["X-Policy"] = []string{"I am Policy"} 105 | req = SignRequest(General, req, 0) 106 | err = CheckRequest(General, req) 107 | asserts.NoError(err) 108 | } 109 | 110 | // Non-upload request with a tampered body: verification fails 111 | { 112 | req, err := http.NewRequest( 113 | "POST", 114 | "http://127.0.0.1/api/v3/upload", 115 | strings.NewReader("I am body."), 116 | ) 117 | asserts.NoError(err) 118 | req = SignRequest(General, req, 0) 119 | req.Body = ioutil.NopCloser(strings.NewReader("2333")) 120 | err = CheckRequest(General, req) 121 | asserts.Error(err) 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /pkg/auth/hmac.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "crypto/hmac" 5 | "crypto/sha256" 6 | "encoding/base64" 7 | "io" 8 | "strconv" 9 | "strings" 10 | "time" 11 | ) 12 | 13 | // HMACAuth signs and verifies content with HMAC-SHA256 keyed by SecretKey 14 | type HMACAuth struct { 15 | SecretKey []byte 16 | } 17 | 18 | // Sign returns base64url(HMAC-SHA256(body+":"+expires))+":"+expires, a signature of body that expires at the Unix timestamp expires, 19 | // where 0 means it never expires; it returns "" if writing to the hash fails 20 | func (auth
HMACAuth) Sign(body string, expires int64) string { 21 | h := hmac.New(sha256.New, auth.SecretKey) 22 | expireTimeStamp := strconv.FormatInt(expires, 10) 23 | _, err := io.WriteString(h, body+":"+expireTimeStamp) 24 | if err != nil { 25 | return "" 26 | } 27 | 28 | return base64.URLEncoding.EncodeToString(h.Sum(nil)) + ":" + expireTimeStamp 29 | } 30 | 31 | // Check verifies that sign is a valid, unexpired signature of body; signatures are compared in constant time 32 | func (auth HMACAuth) Check(body string, sign string) error { 33 | signSlice := strings.Split(sign, ":") 34 | // The expires field is missing 35 | if signSlice[len(signSlice)-1] == "" { 36 | return ErrAuthFailed 37 | } 38 | 39 | // Parse the expiry timestamp 40 | expires, err := strconv.ParseInt(signSlice[len(signSlice)-1], 10, 64) 41 | if err != nil { 42 | return ErrAuthFailed.WithError(err) 43 | } 44 | // The signature has expired (0 means it never expires) 45 | if expires < time.Now().Unix() && expires != 0 { 46 | return ErrExpired 47 | } 48 | 49 | // Verify the signature; hmac.Equal does not leak through timing how much of the signature matched 50 | if !hmac.Equal([]byte(auth.Sign(body, expires)), []byte(sign)) { 51 | return ErrAuthFailed 52 | } 53 | return nil 54 | } 55 | -------------------------------------------------------------------------------- /pkg/auth/hmac_test.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "testing" 7 | "time" 8 | 9 | "github.com/DATA-DOG/go-sqlmock" 10 | model "github.com/cloudreve/Cloudreve/v3/models" 11 | "github.com/cloudreve/Cloudreve/v3/pkg/conf" 12 | "github.com/cloudreve/Cloudreve/v3/pkg/util" 13 | "github.com/gin-gonic/gin" 14 | "github.com/jinzhu/gorm" 15 | "github.com/stretchr/testify/assert" 16 | ) 17 | 18 | var mock sqlmock.Sqlmock 19 | 20 | func TestMain(m *testing.M) { 21 | // Put gin into test mode 22 | gin.SetMode(gin.TestMode) 23 | 24 | // Initialize sqlmock 25 | var db *sql.DB 26 | var err error 27 | db, mock, err = sqlmock.New() 28 | if err != nil { 29 | panic("An error was not expected when opening a stub database connection") 30 | } 31 | 32 | mockDB, _ := gorm.Open("mysql", db) 33 | model.DB = mockDB 34 | defer db.Close() 35 |
36 | m.Run() 37 | } 38 | 39 | func TestHMACAuth_Sign(t *testing.T) { 40 | asserts := assert.New(t) 41 | auth := HMACAuth{ 42 | SecretKey: []byte(util.RandStringRunes(256)), 43 | } 44 | 45 | asserts.NotEmpty(auth.Sign("content", 0)) 46 | } 47 | 48 | func TestHMACAuth_Check(t *testing.T) { 49 | asserts := assert.New(t) 50 | auth := HMACAuth{ 51 | SecretKey: []byte(util.RandStringRunes(256)), 52 | } 53 | 54 | // Valid, never expires 55 | { 56 | sign := auth.Sign("content", 0) 57 | asserts.NoError(auth.Check("content", sign)) 58 | } 59 | 60 | // Expired 61 | { 62 | sign := auth.Sign("content", 1) 63 | asserts.Error(auth.Check("content", sign)) 64 | } 65 | 66 | // Malformed signature: empty expires field 67 | { 68 | sign := auth.Sign("content", 1) 69 | asserts.Error(auth.Check("content", sign+":")) 70 | } 71 | 72 | // Malformed expiry timestamp 73 | { 74 | asserts.Error(auth.Check("content", "ErrAuthFailed:ErrAuthFailed")) 75 | } 76 | 77 | // Wrong signature 78 | { 79 | asserts.Error(auth.Check("content", fmt.Sprintf("sign:%d", time.Now().Unix()+10))) 80 | } 81 | } 82 | 83 | func TestInit(t *testing.T) { 84 | asserts := assert.New(t) 85 | mock.ExpectQuery("SELECT(.+)").WillReturnRows(sqlmock.NewRows([]string{"id", "value"}).AddRow(1, "12312312312312")) 86 | Init() 87 | asserts.NoError(mock.ExpectationsWereMet()) 88 | 89 | // Slave mode 90 | conf.SystemConfig.Mode = "slave" 91 | asserts.Panics(func() { 92 | Init() 93 | }) 94 | } 95 | -------------------------------------------------------------------------------- /pkg/authn/auth.go: -------------------------------------------------------------------------------- 1 | package authn 2 | 3 | import ( 4 | model "github.com/cloudreve/Cloudreve/v3/models" 5 | "github.com/duo-labs/webauthn/webauthn" 6 | ) 7 | 8 | // NewAuthnInstance creates a WebAuthn instance for the configured site name and site URL 9 | func NewAuthnInstance() (*webauthn.WebAuthn, error) { 10 | base := model.GetSiteURL() 11 | return webauthn.New(&webauthn.Config{ 12 | RPDisplayName: model.GetSettingByName("siteName"), // Display Name for your site 13 | RPID: base.Hostname(), // Generally the FQDN
for your site 14 | RPOrigin: base.String(), // The origin URL for WebAuthn requests 15 | }) 16 | } 17 | -------------------------------------------------------------------------------- /pkg/authn/auth_test.go: -------------------------------------------------------------------------------- 1 | package authn 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/cloudreve/Cloudreve/v3/pkg/cache" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestInit(t *testing.T) { 11 | asserts := assert.New(t) 12 | cache.Set("setting_siteURL", "http://cloudreve.org", 0) 13 | cache.Set("setting_siteName", "Cloudreve", 0) 14 | res, err := NewAuthnInstance() 15 | asserts.NotNil(res) 16 | asserts.NoError(err) 17 | } 18 | -------------------------------------------------------------------------------- /pkg/cache/driver.go: -------------------------------------------------------------------------------- 1 | package cache 2 | 3 | import ( 4 | "github.com/cloudreve/Cloudreve/v3/pkg/conf" 5 | "github.com/gin-gonic/gin" 6 | ) 7 | 8 | // Store is the active cache store; it defaults to the in-memory store 9 | var Store Driver = NewMemoStore() 10 | 11 | // Init switches Store to Redis when a Redis server is configured and gin is not in test mode 12 | func Init() { 13 | //Store = NewRedisStore(10, "tcp", "127.0.0.1:6379", "", "0") 14 | //return 15 | if conf.RedisConfig.Server != "" && gin.Mode() != gin.TestMode { 16 | Store = NewRedisStore( 17 | 10, 18 | conf.RedisConfig.Network, 19 | conf.RedisConfig.Server, 20 | conf.RedisConfig.Password, 21 | conf.RedisConfig.DB, 22 | ) 23 | } 24 | } 25 | 26 | // Driver is a key-value cache store 27 | type Driver interface { 28 | // Set stores value under key; ttl is the expiry in seconds 29 | Set(key string, value interface{}, ttl int) error 30 | 31 | // Get returns the value for key and whether it was found 32 | Get(key string) (interface{}, bool) 33 | 34 | // Gets fetches several keys, returning a map of the values found and the keys that were missing 35 | Gets(keys []string, prefix string) (map[string]interface{}, []string) 36 | 37 | // Sets stores several values, prefixing every key with prefix 38 | Sets(values map[string]interface{}, prefix string) error 39 | 40 | // Delete removes the given keys under prefix 41 | Delete(keys []string, prefix string) error 42 | } 43 | 44 | // Set stores a value in the current Store 45 | func Set(key string, value
interface{}, ttl int) error { 46 | return Store.Set(key, value, ttl) 47 | } 48 | 49 | // Get reads a value from the current Store 50 | func Get(key string) (interface{}, bool) { 51 | return Store.Get(key) 52 | } 53 | 54 | // Deletes removes the given keys under prefix from the current Store 55 | func Deletes(keys []string, prefix string) error { 56 | return Store.Delete(keys, prefix) 57 | } 58 | 59 | // GetSettings fetches cached settings by name and returns the found values and the missing names; every cached value must be a string, otherwise the type assertion panics 60 | func GetSettings(keys []string, prefix string) (map[string]string, []string) { 61 | raw, miss := Store.Gets(keys, prefix) 62 | 63 | res := make(map[string]string, len(raw)) 64 | for k, v := range raw { 65 | res[k] = v.(string) 66 | } 67 | 68 | return res, miss 69 | } 70 | 71 | // SetSettings caches several site settings at once, prefixing every key with prefix 72 | func SetSettings(values map[string]string, prefix string) error { 73 | var toBeSet = make(map[string]interface{}, len(values)) 74 | for key, value := range values { 75 | toBeSet[key] = value 76 | } 77 | return Store.Sets(toBeSet, prefix) 78 | } 79 | -------------------------------------------------------------------------------- /pkg/cache/driver_test.go: -------------------------------------------------------------------------------- 1 | package cache 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 5 | "testing" 6 | ) 7 | 8 | func TestSet(t *testing.T) { 9 | asserts := assert.New(t) 10 | 11 | asserts.NoError(Set("123", "321", -1)) 12 | } 13 | 14 | func TestGet(t *testing.T) { 15 | asserts := assert.New(t) 16 | asserts.NoError(Set("123", "321", -1)) 17 | 18 | value, ok := Get("123") 19 | asserts.True(ok) 20 | asserts.Equal("321", value) 21 | 22 | value, ok = Get("not_exist") 23 | asserts.False(ok) 24 | } 25 | 26 | func TestDeletes(t *testing.T) { 27 | asserts := assert.New(t) 28 | asserts.NoError(Set("123", "321", -1)) 29 | err := Deletes([]string{"123"}, "") 30 | asserts.NoError(err) 31 | _, exist := Get("123") 32 | asserts.False(exist) 33 | } 34 | 35 | func TestGetSettings(t *testing.T) { 36 | asserts := assert.New(t) 37 | asserts.NoError(Set("test_1", "1", -1)) 38 | 39 | values, missed :=
GetSettings([]string{"1", "2"}, "test_") 40 | asserts.Equal(map[string]string{"1": "1"}, values) 41 | asserts.Equal([]string{"2"}, missed) 42 | } 43 | 44 | func TestSetSettings(t *testing.T) { 45 | asserts := assert.New(t) 46 | 47 | err := SetSettings(map[string]string{"3": "3", "4": "4"}, "test_") 48 | asserts.NoError(err) 49 | value1, _ := Get("test_3") 50 | value2, _ := Get("test_4") 51 | asserts.Equal("3", value1) 52 | asserts.Equal("4", value2) 53 | } 54 | 55 | func TestInit(t *testing.T) { 56 | asserts := assert.New(t) 57 | 58 | asserts.NotPanics(func() { 59 | Init() 60 | }) 61 | } 62 | -------------------------------------------------------------------------------- /pkg/cache/memo.go: -------------------------------------------------------------------------------- 1 | package cache 2 | 3 | import ( 4 | "sync" 5 | "time" 6 | 7 | "github.com/cloudreve/Cloudreve/v3/pkg/util" 8 | ) 9 | 10 | // MemoStore is the in-memory cache driver 11 | type MemoStore struct { 12 | Store *sync.Map 13 | } 14 | 15 | // itemWithTTL is a stored value with its absolute expiry in Unix seconds; expires <= 0 means it never expires 16 | type itemWithTTL struct { 17 | expires int64 18 | value interface{} 19 | } 20 | 21 | func newItem(value interface{}, expires int) itemWithTTL { 22 | expires64 := int64(expires) 23 | if expires > 0 { 24 | expires64 = time.Now().Unix() + expires64 25 | } 26 | return itemWithTTL{ 27 | value: value, 28 | expires: expires64, 29 | } 30 | } 31 | 32 | // getValue unwraps a value loaded from the store: non-itemWithTTL values are returned as-is and expired items are reported as not found 33 | func getValue(item interface{}, ok bool) (interface{}, bool) { 34 | if !ok { 35 | return nil, ok 36 | } 37 | 38 | var itemObj itemWithTTL 39 | if itemObj, ok = item.(itemWithTTL); !ok { 40 | return item, true 41 | } 42 | 43 | if itemObj.expires > 0 && itemObj.expires < time.Now().Unix() { 44 | return nil, false 45 | } 46 | 47 | return itemObj.value, ok 48 | 49 | } 50 | 51 | // GarbageCollect reclaims expired cache entries 52 | func (store *MemoStore) GarbageCollect() { 53 | store.Store.Range(func(key, value interface{}) bool { 54 | if item, ok := value.(itemWithTTL); ok { 55 | if item.expires > 0 && item.expires <
time.Now().Unix() { 56 | util.Log().Debug("回收垃圾[%s]", key.(string)) 57 | store.Store.Delete(key) 58 | } 59 | } 60 | return true 61 | }) 62 | } 63 | 64 | // NewMemoStore 新建内存存储 65 | func NewMemoStore() *MemoStore { 66 | return &MemoStore{ 67 | Store: &sync.Map{}, 68 | } 69 | } 70 | 71 | // Set 存储值 72 | func (store *MemoStore) Set(key string, value interface{}, ttl int) error { 73 | store.Store.Store(key, newItem(value, ttl)) 74 | return nil 75 | } 76 | 77 | // Get 取值 78 | func (store *MemoStore) Get(key string) (interface{}, bool) { 79 | return getValue(store.Store.Load(key)) 80 | } 81 | 82 | // Gets 批量取值 83 | func (store *MemoStore) Gets(keys []string, prefix string) (map[string]interface{}, []string) { 84 | var res = make(map[string]interface{}) 85 | var notFound = make([]string, 0, len(keys)) 86 | 87 | for _, key := range keys { 88 | if value, ok := getValue(store.Store.Load(prefix + key)); ok { 89 | res[key] = value 90 | } else { 91 | notFound = append(notFound, key) 92 | } 93 | } 94 | 95 | return res, notFound 96 | } 97 | 98 | // Sets 批量设置值 99 | func (store *MemoStore) Sets(values map[string]interface{}, prefix string) error { 100 | for key, value := range values { 101 | store.Store.Store(prefix+key, value) 102 | } 103 | return nil 104 | } 105 | 106 | // Delete 批量删除值 107 | func (store *MemoStore) Delete(keys []string, prefix string) error { 108 | for _, key := range keys { 109 | store.Store.Delete(prefix + key) 110 | } 111 | return nil 112 | } 113 | -------------------------------------------------------------------------------- /pkg/conf/conf_test.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | import ( 4 | "io/ioutil" 5 | "os" 6 | "testing" 7 | 8 | "github.com/cloudreve/Cloudreve/v3/pkg/util" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | // 测试Init日志路径错误 13 | func TestInitPanic(t *testing.T) { 14 | asserts := assert.New(t) 15 | 16 | // 日志路径不存在时 17 | asserts.NotPanics(func() { 18 | 
Init("not/exist/path/conf.ini") 19 | }) 20 | 21 | asserts.True(util.Exists("not/exist/path/conf.ini")) 22 | 23 | } 24 | 25 | // TestInitDelimiterNotFound 日志路径存在但 Key 格式错误时 26 | func TestInitDelimiterNotFound(t *testing.T) { 27 | asserts := assert.New(t) 28 | testCase := `[Database] 29 | Type = mysql 30 | User = root 31 | Password233root 32 | Host = 127.0.0.1:3306 33 | Name = v3 34 | TablePrefix = v3_` 35 | err := ioutil.WriteFile("testConf.ini", []byte(testCase), 0644) 36 | defer func() { err = os.Remove("testConf.ini") }() 37 | if err != nil { 38 | panic(err) 39 | } 40 | asserts.Panics(func() { 41 | Init("testConf.ini") 42 | }) 43 | } 44 | 45 | // TestInitNoPanic 日志路径存在且合法时 46 | func TestInitNoPanic(t *testing.T) { 47 | asserts := assert.New(t) 48 | testCase := ` 49 | [System] 50 | Listen = 3000 51 | HashIDSalt = 1 52 | 53 | [Database] 54 | Type = mysql 55 | User = root 56 | Password = root 57 | Host = 127.0.0.1:3306 58 | Name = v3 59 | TablePrefix = v3_` 60 | err := ioutil.WriteFile("testConf.ini", []byte(testCase), 0644) 61 | defer func() { err = os.Remove("testConf.ini") }() 62 | if err != nil { 63 | panic(err) 64 | } 65 | asserts.NotPanics(func() { 66 | Init("testConf.ini") 67 | }) 68 | } 69 | 70 | func TestMapSection(t *testing.T) { 71 | asserts := assert.New(t) 72 | 73 | //正常情况 74 | testCase := ` 75 | [System] 76 | Listen = 3000 77 | HashIDSalt = 1 78 | 79 | [Database] 80 | Type = mysql 81 | User = root 82 | Password:root 83 | Host = 127.0.0.1:3306 84 | Name = v3 85 | TablePrefix = v3_` 86 | err := ioutil.WriteFile("testConf.ini", []byte(testCase), 0644) 87 | defer func() { err = os.Remove("testConf.ini") }() 88 | if err != nil { 89 | panic(err) 90 | } 91 | Init("testConf.ini") 92 | err = mapSection("Database", DatabaseConfig) 93 | asserts.NoError(err) 94 | 95 | } 96 | -------------------------------------------------------------------------------- /pkg/conf/defaults.go: -------------------------------------------------------------------------------- 1 | 
package conf 2 | 3 | import "github.com/mojocn/base64Captcha" 4 | 5 | // RedisConfig holds the Redis server settings. 6 | var RedisConfig = &redis{ 7 | Network: "tcp", 8 | Server: "", 9 | Password: "", 10 | DB: "0", 11 | } 12 | 13 | // DatabaseConfig holds the database settings. 14 | var DatabaseConfig = &database{ 15 | Type: "UNSET", 16 | DBFile: "cloudreve.db", 17 | Port: 3306, 18 | } 19 | 20 | // SystemConfig holds the general system settings. 21 | var SystemConfig = &system{ 22 | Debug: false, 23 | Mode: "master", 24 | Listen: ":5212", 25 | } 26 | 27 | // CaptchaConfig holds the captcha settings. 28 | var CaptchaConfig = &captcha{ 29 | Height: 60, 30 | Width: 240, 31 | Mode: 3, 32 | ComplexOfNoiseText: base64Captcha.CaptchaComplexLower, 33 | ComplexOfNoiseDot: base64Captcha.CaptchaComplexLower, 34 | IsShowHollowLine: false, 35 | IsShowNoiseDot: false, 36 | IsShowNoiseText: false, 37 | IsShowSlimeLine: false, 38 | IsShowSineLine: false, 39 | CaptchaLen: 6, 40 | } 41 | 42 | // CORSConfig holds the cross-origin (CORS) settings. 43 | var CORSConfig = &cors{ 44 | AllowOrigins: []string{"UNSET"}, 45 | AllowMethods: []string{"PUT", "POST", "GET", "OPTIONS"}, 46 | AllowHeaders: []string{"Cookie", "X-Policy", "Authorization", "Content-Length", "Content-Type", "X-Path", "X-FileName"}, 47 | AllowCredentials: false, 48 | ExposeHeaders: nil, 49 | } 50 | 51 | // ThumbConfig holds the thumbnail settings. 52 | var ThumbConfig = &thumb{ 53 | MaxWidth: 400, 54 | MaxHeight: 300, 55 | FileSuffix: "._thumb", 56 | } 57 | 58 | // SlaveConfig holds the settings used by a slave node. 59 | var SlaveConfig = &slave{ 60 | CallbackTimeout: 20, 61 | SignatureTTL: 60, 62 | } 63 | // SSLConfig holds the HTTPS listener settings: listen address plus certificate and key paths. 64 | var SSLConfig = &ssl{ 65 | Listen: ":443", 66 | CertPath: "", 67 | KeyPath: "", 68 | } 69 | // UnixConfig holds the settings of the unix listener section (presumably a Unix socket path; empty by default). 70 | var UnixConfig = &unix{ 71 | Listen: "", 72 | } 73 | -------------------------------------------------------------------------------- /pkg/conf/version.go: -------------------------------------------------------------------------------- 1 | package conf 2 | 3 | // BackendVersion is the current backend version. 4 | var BackendVersion = "3.2.0" 5 | 6 | // RequiredDBVersion is the database version this backend requires. 7 | var RequiredDBVersion = "3.2.0" 8 | 9
| // RequiredStaticVersion is the static-assets version this backend requires. 10 | var RequiredStaticVersion = "3.2.0" 11 | 12 | // IsPro reports whether this is the Pro edition (a string flag). 13 | var IsPro = "false" 14 | 15 | // LastCommit is the last commit ID. 16 | var LastCommit = "a11f819" 17 | -------------------------------------------------------------------------------- /pkg/crontab/collect.go: -------------------------------------------------------------------------------- 1 | package crontab 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "strings" 7 | "time" 8 | 9 | model "github.com/cloudreve/Cloudreve/v3/models" 10 | "github.com/cloudreve/Cloudreve/v3/pkg/cache" 11 | "github.com/cloudreve/Cloudreve/v3/pkg/util" 12 | ) 13 | // garbageCollect is the cron_garbage_collect job: it removes expired archive files and purges the in-memory cache. 14 | func garbageCollect() { 15 | // remove temporary files left behind by archive downloads 16 | collectArchiveFile() 17 | 18 | // purge expired entries, but only when the built-in memory cache is the active store 19 | if store, ok := cache.Store.(*cache.MemoStore); ok { 20 | collectCache(store) 21 | } 22 | 23 | util.Log().Info("定时任务 [cron_garbage_collect] 执行完毕") 24 | } 25 | // collectArchiveFile deletes archive_* files under <temp_path>/archive that are older than download_timeout seconds (default 30). 26 | func collectArchiveFile() { 27 | // read the expiry and directory settings 28 | tempPath := util.RelativePath(model.GetSettingByName("temp_path")) 29 | expires := model.GetIntSetting("download_timeout", 30) 30 | 31 | // walk the archive directory 32 | root := filepath.Join(tempPath, "archive") 33 | err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error { 34 | // Surface a failure to read the archive root itself (e.g. it does not exist) so it is logged below; 35 | // errors on individual entries are skipped so one bad entry does not stop the sweep. 36 | if err != nil && path == root { 37 | return err 38 | } 39 | if err == nil && !info.IsDir() && 40 | strings.HasPrefix(filepath.Base(path), "archive_") && 41 | time.Since(info.ModTime()).Seconds() > float64(expires) { 42 | util.Log().Debug("删除过期打包下载临时文件 [%s]", path) 43 | // delete the matching file 44 | if err := os.Remove(path); err != nil { 45 | util.Log().Debug("临时文件 [%s] 删除失败 , %s", path, err) 46 | } 47 | } 48 | return nil 49 | }) 50 | 51 | if err != nil { 52 | util.Log().Debug("[定时任务] 无法列取临时打包目录, %s", err) 53 | } 54 | 55 | } 56 | // collectCache purges expired entries from the in-memory cache store. 57 | func collectCache(store *cache.MemoStore) { 58 | util.Log().Debug("清理内存缓存") 59 | store.GarbageCollect() 60 | } 61 | -------------------------------------------------------------------------------- /pkg/crontab/init.go:
-------------------------------------------------------------------------------- 1 | package crontab 2 | 3 | import ( 4 | model "github.com/cloudreve/Cloudreve/v3/models" 5 | "github.com/cloudreve/Cloudreve/v3/pkg/util" 6 | "github.com/robfig/cron/v3" 7 | ) 8 | 9 | // Cron is the running scheduler; Reload stops it before building a new one. 10 | var Cron *cron.Cron 11 | 12 | // Reload stops the current scheduler, if any, and initializes a new one from the latest settings. 13 | func Reload() { 14 | if Cron != nil { 15 | Cron.Stop() 16 | } 17 | Init() 18 | } 19 | 20 | // Init builds a scheduler from the cron schedule settings, stores it in Cron and starts it; unknown or invalid entries are logged and skipped. 21 | func Init() { 22 | util.Log().Info("初始化定时任务...") 23 | // read the cron schedule settings 24 | options := model.GetSettingByNames("cron_garbage_collect") 25 | Cron = cron.New() // assign the package-level Cron (not :=) so Reload can stop this scheduler 26 | for k, v := range options { 27 | var handler func() 28 | switch k { 29 | case "cron_garbage_collect": 30 | handler = garbageCollect 31 | default: 32 | util.Log().Warning("未知定时任务类型 [%s],跳过", k) 33 | continue 34 | } 35 | 36 | if _, err := Cron.AddFunc(v, handler); err != nil { 37 | util.Log().Warning("无法启动定时任务 [%s] , %s", k, err) 38 | } 39 | 40 | } 41 | Cron.Start() 42 | } 43 | -------------------------------------------------------------------------------- /pkg/email/init.go: -------------------------------------------------------------------------------- 1 | package email 2 | 3 | import ( 4 | "sync" 5 | 6 | model "github.com/cloudreve/Cloudreve/v3/models" 7 | "github.com/cloudreve/Cloudreve/v3/pkg/util" 8 | ) 9 | 10 | // Client is the default mail-sending client. 11 | var Client Driver 12 | 13 | // Lock guards Client: Init holds the write lock while replacing it, Send holds the read lock. 14 | var Lock sync.RWMutex 15 | 16 | // Init (re)creates Client from the SMTP settings, closing the previous client first. 17 | func Init() { 18 | util.Log().Debug("邮件队列初始化") 19 | Lock.Lock() 20 | defer Lock.Unlock() 21 | 22 | if Client != nil { 23 | Client.Close() 24 | } 25 | 26 | // read the SMTP settings 27 | options := model.GetSettingByNames( 28 | "fromName", 29 | "fromAdress", 30 | "smtpHost", 31 | "replyTo", 32 | "smtpUser", 33 | "smtpPass", 34 | "smtpEncryption", 35 | ) 36 | port := model.GetIntSetting("smtpPort", 25) 37 | keepAlive := model.GetIntSetting("mail_keepalive", 30) 38 | 39 | client := NewSMTPClient(SMTPConfig{ 40 | Name: options["fromName"], 41 |
Address: options["fromAdress"], 42 | ReplyTo: options["replyTo"], 43 | Host: options["smtpHost"], 44 | Port: port, 45 | User: options["smtpUser"], 46 | Password: options["smtpPass"], 47 | Keepalive: keepAlive, 48 | Encryption: model.IsTrueVal(options["smtpEncryption"]), 49 | }) 50 | 51 | Client = client 52 | } 53 | -------------------------------------------------------------------------------- /pkg/email/mail.go: -------------------------------------------------------------------------------- 1 | package email 2 | 3 | import ( 4 | "errors" 5 | "strings" 6 | ) 7 | 8 | // Driver is a mail-sending backend. 9 | type Driver interface { 10 | // Close shuts the driver down. 11 | Close() 12 | // Send sends an email. 13 | Send(to, title, body string) error 14 | } 15 | 16 | var ( 17 | // ErrChanNotOpen is returned when the mail queue is not running. 18 | ErrChanNotOpen = errors.New("邮件队列未开启") 19 | // ErrNoActiveDriver is returned when no mail driver is available. 20 | ErrNoActiveDriver = errors.New("无可用邮件发送服务") 21 | ) 22 | 23 | // Send sends an email through the active Client; addresses ending in @login.qq.com are skipped silently. 24 | func Send(to, title, body string) error { 25 | // skip addresses of accounts that signed in via QQ 26 | if strings.HasSuffix(to, "@login.qq.com") { 27 | return nil 28 | } 29 | 30 | Lock.RLock() 31 | defer Lock.RUnlock() 32 | 33 | if Client == nil { 34 | return ErrNoActiveDriver 35 | } 36 | 37 | return Client.Send(to, title, body) 38 | } 39 | -------------------------------------------------------------------------------- /pkg/email/smtp.go: -------------------------------------------------------------------------------- 1 | package email 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/cloudreve/Cloudreve/v3/pkg/util" 7 | "github.com/go-mail/mail" 8 | ) 9 | 10 | // SMTP sends mail over SMTP through a buffered queue drained by a background goroutine. 11 | type SMTP struct { 12 | Config SMTPConfig 13 | ch chan *mail.Message 14 | chOpen bool 15 | } 16 | 17 | // SMTPConfig holds the SMTP sending settings. 18 | type SMTPConfig struct { 19 | Name string // sender display name 20 | Address string // sender address 21 | ReplyTo string // reply-to address 22 | Host string // server host name 23 | Port int // server port 24 | User string // user name 25 | Password string // password 26 | Encryption bool // enable SSL on the connection 27 | Keepalive int // seconds an idle SMTP connection is kept open 28 | } 29
| 30 | // NewSMTPClient creates an SMTP client with a queue of up to 30 messages and starts its sending goroutine. 31 | func NewSMTPClient(config SMTPConfig) *SMTP { 32 | client := &SMTP{ 33 | Config: config, 34 | ch: make(chan *mail.Message, 30), 35 | chOpen: false, 36 | } 37 | 38 | client.Init() 39 | 40 | return client 41 | } 42 | 43 | // Send queues an email and returns ErrChanNotOpen while the sending goroutine is not running. NOTE(review): chOpen is read here without synchronization while the goroutine writes it (data race). 44 | func (client *SMTP) Send(to, title, body string) error { 45 | if !client.chOpen { 46 | return ErrChanNotOpen 47 | } 48 | m := mail.NewMessage() 49 | m.SetAddressHeader("From", client.Config.Address, client.Config.Name) 50 | m.SetAddressHeader("Reply-To", client.Config.ReplyTo, client.Config.Name) 51 | m.SetHeader("To", to) 52 | m.SetHeader("Subject", title) 53 | m.SetBody("text/html", body) 54 | client.ch <- m 55 | return nil 56 | } 57 | 58 | // Close closes the queue, which makes the sending goroutine exit. 59 | func (client *SMTP) Close() { 60 | if client.ch != nil { 61 | close(client.ch) 62 | } 63 | } 64 | 65 | // Init starts the goroutine that drains the queue. A dial failure panics inside it; the deferred recover marks the queue closed and restarts it after 10 seconds. 66 | func (client *SMTP) Init() { 67 | go func() { 68 | defer func() { 69 | if err := recover(); err != nil { 70 | client.chOpen = false 71 | util.Log().Error("邮件发送队列出现异常, %s ,10 秒后重置", err) 72 | time.Sleep(time.Duration(10) * time.Second) 73 | client.Init() 74 | } 75 | }() 76 | 77 | d := mail.NewDialer(client.Config.Host, client.Config.Port, client.Config.User, client.Config.Password) 78 | d.Timeout = time.Duration(client.Config.Keepalive+5) * time.Second 79 | client.chOpen = true 80 | 81 | // enable SSL when encryption is configured 82 | if client.Config.Encryption { 83 | d.SSL = true 84 | } 85 | 86 | var s mail.SendCloser 87 | var err error 88 | open := false 89 | for { 90 | select { 91 | case m, ok := <-client.ch: 92 | if !ok { 93 | util.Log().Debug("邮件队列关闭") 94 | client.chOpen = false 95 | return 96 | } 97 | if !open { 98 | if s, err = d.Dial(); err != nil { 99 | panic(err) 100 | } 101 | open = true 102 | } 103 | if err := mail.Send(s, m); err != nil { 104 | util.Log().Warning("邮件发送失败, %s", err) 105 | // The connection may be broken (e.g. dropped by the server); discard it so the next message re-dials. Its Close error is irrelevant once discarded. 106 | _ = s.Close() 107 | open = false 108 | } else { 109 | util.Log().Debug("邮件已发送") 110 | } 111 | // no new mail for Keepalive seconds: close the idle SMTP connection 112 | case <-time.After(time.Duration(client.Config.Keepalive) *
time.Second): 113 | if open { 114 | if err := s.Close(); err != nil { 115 | util.Log().Warning("无法关闭 SMTP 连接 %s", err) 116 | } 117 | open = false 118 | } 119 | } 120 | } 121 | }() 122 | } 123 | -------------------------------------------------------------------------------- /pkg/email/template.go: -------------------------------------------------------------------------------- 1 | package email 2 | 3 | import ( 4 | "fmt" 5 | 6 | model "github.com/cloudreve/Cloudreve/v3/models" 7 | "github.com/cloudreve/Cloudreve/v3/pkg/util" 8 | ) 9 | 10 | // NewActivationEmail builds the account-activation email and returns its subject and body. 11 | func NewActivationEmail(userName, activateURL string) (string, string) { 12 | options := model.GetSettingByNames("siteName", "siteURL", "siteTitle", "mail_activation_template") 13 | replace := map[string]string{ 14 | "{siteTitle}": options["siteName"], 15 | "{userName}": userName, 16 | "{activationUrl}": activateURL, 17 | "{siteUrl}": options["siteURL"], 18 | "{siteSecTitle}": options["siteTitle"], 19 | } 20 | return fmt.Sprintf("【%s】注册激活", options["siteName"]), 21 | util.Replace(replace, options["mail_activation_template"]) 22 | } 23 | 24 | // NewResetEmail builds the password-reset email and returns its subject and body. 25 | func NewResetEmail(userName, resetURL string) (string, string) { 26 | options := model.GetSettingByNames("siteName", "siteURL", "siteTitle", "mail_reset_pwd_template") 27 | replace := map[string]string{ 28 | "{siteTitle}": options["siteName"], 29 | "{userName}": userName, 30 | "{resetUrl}": resetURL, 31 | "{siteUrl}": options["siteURL"], 32 | "{siteSecTitle}": options["siteTitle"], 33 | } 34 | return fmt.Sprintf("【%s】密码重置", options["siteName"]), 35 | util.Replace(replace, options["mail_reset_pwd_template"]) 36 | } 37 | -------------------------------------------------------------------------------- /pkg/filesystem/driver/local/file.go: -------------------------------------------------------------------------------- 1 | package local 2 | 3 | import ( 4 | "io" 5 | ) 6 | 7 | // FileStream is a file uploaded by the user, together with its metadata. 8 | type FileStream struct { 9 |
File io.ReadCloser 10 | Size uint64 11 | VirtualPath string 12 | Name string 13 | MIMEType string 14 | } 15 | // Read reads from the underlying File. 16 | func (file FileStream) Read(p []byte) (n int, err error) { 17 | return file.File.Read(p) 18 | } 19 | // GetMIMEType returns the file's MIME type. 20 | func (file FileStream) GetMIMEType() string { 21 | return file.MIMEType 22 | } 23 | // GetSize returns the file's declared size. 24 | func (file FileStream) GetSize() uint64 { 25 | return file.Size 26 | } 27 | // Close closes the underlying File. 28 | func (file FileStream) Close() error { 29 | return file.File.Close() 30 | } 31 | // GetFileName returns the file name. 32 | func (file FileStream) GetFileName() string { 33 | return file.Name 34 | } 35 | // GetVirtualPath returns the file's virtual path. 36 | func (file FileStream) GetVirtualPath() string { 37 | return file.VirtualPath 38 | } 39 | -------------------------------------------------------------------------------- /pkg/filesystem/driver/local/file_test.go: -------------------------------------------------------------------------------- 1 | package local 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 5 | "io/ioutil" 6 | "strings" 7 | "testing" 8 | ) 9 | 10 | func TestFileStream_GetFileName(t *testing.T) { 11 | asserts := assert.New(t) 12 | file := FileStream{Name: "123"} 13 | asserts.Equal("123", file.GetFileName()) 14 | } 15 | 16 | func TestFileStream_GetMIMEType(t *testing.T) { 17 | asserts := assert.New(t) 18 | file := FileStream{MIMEType: "123"} 19 | asserts.Equal("123", file.GetMIMEType()) 20 | } 21 | 22 | func TestFileStream_GetSize(t *testing.T) { 23 | asserts := assert.New(t) 24 | file := FileStream{Size: 123} 25 | asserts.Equal(uint64(123), file.GetSize()) 26 | } 27 | 28 | func TestFileStream_Read(t *testing.T) { 29 | asserts := assert.New(t) 30 | file := FileStream{ 31 | File: ioutil.NopCloser(strings.NewReader("123")), 32 | } 33 | var p = make([]byte, 3) 34 | { 35 | n, err := file.Read(p) 36 | asserts.Equal(3, n) 37 | asserts.NoError(err) 38 | } 39 | } 40 | 41 | func TestFileStream_Close(t *testing.T) { 42 | asserts := assert.New(t) 43 | file := FileStream{ 44 | File: ioutil.NopCloser(strings.NewReader("123")), 45 | } 46 | err :=
file.Close() 47 | asserts.NoError(err) 48 | } 49 | -------------------------------------------------------------------------------- /pkg/filesystem/driver/onedrive/client.go: -------------------------------------------------------------------------------- 1 | package onedrive 2 | 3 | import ( 4 | "errors" 5 | 6 | model "github.com/cloudreve/Cloudreve/v3/models" 7 | "github.com/cloudreve/Cloudreve/v3/pkg/request" 8 | ) 9 | 10 | var ( 11 | // ErrAuthEndpoint is returned when the OAuth endpoint address cannot be resolved. 12 | ErrAuthEndpoint = errors.New("无法解析授权端点地址") 13 | // ErrInvalidRefreshToken means the storage policy has no valid refresh token. 14 | ErrInvalidRefreshToken = errors.New("上传策略无有效的RefreshToken") 15 | // ErrDeleteFile means a file could not be deleted. 16 | ErrDeleteFile = errors.New("无法删除文件") 17 | // ErrClientCanceled means the client canceled the operation. 18 | ErrClientCanceled = errors.New("客户端取消操作") 19 | ) 20 | 21 | // Client is a OneDrive API client bound to one storage policy. 22 | type Client struct { 23 | Endpoints *Endpoints 24 | Policy *model.Policy 25 | Credential *Credential 26 | 27 | ClientID string 28 | ClientSecret string 29 | Redirect string 30 | 31 | Request request.Client 32 | } 33 | 34 | // Endpoints holds the OneDrive client's URL settings. 35 | type Endpoints struct { 36 | OAuthURL string // base URL for OAuth authentication 37 | OAuthEndpoints *oauthEndpoint 38 | EndpointURL string // base URL for API requests 39 | isInChina bool // whether this is the 21Vianet-operated (China) OneDrive 40 | } 41 | 42 | // NewClient creates a client for the given storage policy; it returns ErrAuthEndpoint if the OAuth endpoint cannot be resolved from policy.BaseURL. 43 | func NewClient(policy *model.Policy) (*Client, error) { 44 | client := &Client{ 45 | Endpoints: &Endpoints{ 46 | OAuthURL: policy.BaseURL, 47 | EndpointURL: policy.Server, 48 | }, 49 | Credential: &Credential{ 50 | RefreshToken: policy.AccessKey, 51 | }, 52 | Policy: policy, 53 | ClientID: policy.BucketName, 54 | ClientSecret: policy.SecretKey, 55 | Redirect: policy.OptionsSerialized.OdRedirect, 56 | Request: request.HTTPClient{}, 57 | } 58 | 59 | oauthBase := client.getOAuthEndpoint() 60 | if oauthBase == nil { 61 | return nil, ErrAuthEndpoint 62 | } 63 | client.Endpoints.OAuthEndpoints = oauthBase 64 | 65 | return client, nil 66 | } 67 |
-------------------------------------------------------------------------------- /pkg/filesystem/driver/onedrive/client_test.go: -------------------------------------------------------------------------------- 1 | package onedrive 2 | 3 | import ( 4 | "testing" 5 | 6 | model "github.com/cloudreve/Cloudreve/v3/models" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestNewClient(t *testing.T) { 11 | asserts := assert.New(t) 12 | // getOAuthEndpoint fails (the base URL cannot be parsed) 13 | { 14 | policy := model.Policy{ 15 | BaseURL: string([]byte{0x7f}), 16 | } 17 | res, err := NewClient(&policy) 18 | asserts.Error(err) 19 | asserts.Nil(res) 20 | } 21 | 22 | // success 23 | { 24 | policy := model.Policy{} 25 | res, err := NewClient(&policy) 26 | asserts.NoError(err) 27 | asserts.NotNil(res) 28 | asserts.NotNil(res.Credential) 29 | asserts.NotNil(res.Endpoints) 30 | asserts.NotNil(res.Endpoints.OAuthEndpoints) 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /pkg/filesystem/driver/onedrive/handller_test.go: -------------------------------------------------------------------------------- 1 | package onedrive 2 | 3 | import ( 4 | model "github.com/cloudreve/Cloudreve/v3/models" 5 | "testing" 6 | ) 7 | 8 | func TestDriver_replaceSourceHost(t *testing.T) { 9 | tests := []struct { 10 | name string 11 | origin string 12 | cdn string 13 | want string 14 | wantErr bool 15 | }{ 16 | {"TestNoReplace", "http://1dr.ms/download.aspx?123456", "", "http://1dr.ms/download.aspx?123456", false}, 17 | {"TestReplaceCorrect", "http://1dr.ms/download.aspx?123456", "https://test.com:8080", "https://test.com:8080/download.aspx?123456", false}, 18 | {"TestCdnFormatError", "http://1dr.ms/download.aspx?123456", string([]byte{0x7f}), "", true}, 19 | {"TestSrcFormatError", string([]byte{0x7f}), "https://test.com:8080", "", true}, 20 | } 21 | for _, tt := range tests { 22 | t.Run(tt.name, func(t *testing.T) { 23 | policy := &model.Policy{} 24 |
policy.OptionsSerialized.OdProxy = tt.cdn 25 | handler := Driver{ 26 | Policy: policy, 27 | } 28 | got, err := handler.replaceSourceHost(tt.origin) 29 | if (err != nil) != tt.wantErr { 30 | t.Errorf("replaceSourceHost() error = %v, wantErr %v", err, tt.wantErr) 31 | return 32 | } 33 | if got != tt.want { 34 | t.Errorf("replaceSourceHost() got = %v, want %v", got, tt.want) 35 | } 36 | }) 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /pkg/filesystem/driver/onedrive/options.go: -------------------------------------------------------------------------------- 1 | package onedrive 2 | 3 | import "time" 4 | 5 | // Option is an extra setting applied when sending a request. 6 | type Option interface { 7 | apply(*options) 8 | } 9 | // options holds per-request settings; see newDefaultOption for the defaults. 10 | type options struct { 11 | redirect string 12 | code string 13 | refreshToken string 14 | conflictBehavior string 15 | expires time.Time 16 | } 17 | // optionFunc adapts a plain function to the Option interface. 18 | type optionFunc func(*options) 19 | 20 | // WithCode sets the API code parameter. 21 | func WithCode(t string) Option { 22 | return optionFunc(func(o *options) { 23 | o.code = t 24 | }) 25 | } 26 | 27 | // WithRefreshToken sets the refresh token passed to the API. 28 | func WithRefreshToken(t string) Option { 29 | return optionFunc(func(o *options) { 30 | o.refreshToken = t 31 | }) 32 | } 33 | 34 | // WithConflictBehavior sets how a name conflict with an existing file is handled. 35 | func WithConflictBehavior(t string) Option { 36 | return optionFunc(func(o *options) { 37 | o.conflictBehavior = t 38 | }) 39 | } 40 | // apply calls f on o. 41 | func (f optionFunc) apply(o *options) { 42 | f(o) 43 | } 44 | // newDefaultOption returns the defaults: fail on name conflicts and expire one hour from now (UTC). 45 | func newDefaultOption() *options { 46 | return &options{ 47 | conflictBehavior: "fail", 48 | expires: time.Now().UTC().Add(time.Duration(1) * time.Hour), 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /pkg/filesystem/driver/oss/callback.go: -------------------------------------------------------------------------------- 1 | package oss 2 | 3 | import ( 4 | "bytes" 5 | "crypto" 6 | "crypto/md5" 7 | "crypto/rsa" 8 | "crypto/x509" 9 |
"encoding/base64" 10 | "encoding/pem" 11 | "errors" 12 | "fmt" 13 | "io/ioutil" 14 | "net/http" 15 | "net/url" 16 | "strings" 17 | 18 | "github.com/cloudreve/Cloudreve/v3/pkg/cache" 19 | "github.com/cloudreve/Cloudreve/v3/pkg/request" 20 | ) 21 | 22 | // GetPublicKey returns the public key OSS uses to sign callbacks, taken from the cache or downloaded from the URL in the x-oss-pub-key-url header. 23 | func GetPublicKey(r *http.Request) ([]byte, error) { 24 | var pubKey []byte 25 | 26 | // try the cache first 27 | pub, exist := cache.Get("oss_public_key") 28 | if exist { 29 | return pub.([]byte), nil 30 | } 31 | 32 | // otherwise read the base64-encoded key URL from the request header 33 | pubURL, err := base64.StdEncoding.DecodeString(r.Header.Get("x-oss-pub-key-url")) 34 | if err != nil { 35 | return pubKey, err 36 | } 37 | 38 | // only accept key URLs on OSS's public-key host. NOTE(review): the http:// prefix allows fetching the key over plain HTTP, where it could be tampered with; consider requiring https. 39 | if !strings.HasPrefix(string(pubURL), "http://gosspublic.alicdn.com/") && 40 | !strings.HasPrefix(string(pubURL), "https://gosspublic.alicdn.com/") { 41 | return pubKey, errors.New("公钥URL无效") 42 | } 43 | 44 | // download the key 45 | client := request.HTTPClient{} 46 | body, err := client.Request("GET", string(pubURL), nil). 47 | CheckHTTPResponse(200).
48 | GetResponse() 49 | if err != nil { 50 | return pubKey, err 51 | } 52 | 53 | // cache the key for 7 days; a failed cache write is not fatal 54 | _ = cache.Set("oss_public_key", []byte(body), 86400*7) 55 | 56 | return []byte(body), nil 57 | } 58 | // getRequestMD5 returns MD5(unescaped URL path + "\n" + body) for signature verification, restoring r.Body so it can be read again. 59 | func getRequestMD5(r *http.Request) ([]byte, error) { 60 | var byteMD5 []byte 61 | 62 | // read the request body 63 | body, err := ioutil.ReadAll(r.Body) 64 | r.Body.Close() 65 | if err != nil { 66 | return byteMD5, err 67 | } 68 | r.Body = ioutil.NopCloser(bytes.NewReader(body)) 69 | 70 | strURLPathDecode, err := url.PathUnescape(r.URL.Path) 71 | if err != nil { 72 | return byteMD5, err 73 | } 74 | 75 | strAuth := fmt.Sprintf("%s\n%s", strURLPathDecode, string(body)) 76 | md5Ctx := md5.New() 77 | md5Ctx.Write([]byte(strAuth)) 78 | byteMD5 = md5Ctx.Sum(nil) 79 | 80 | return byteMD5, nil 81 | } 82 | 83 | // VerifyCallbackSignature checks the RSA PKCS#1 v1.5 (MD5) signature carried in the Authorization header of an OSS callback request; it returns nil only when the signature is valid. 84 | func VerifyCallbackSignature(r *http.Request) error { 85 | bytePublicKey, err := GetPublicKey(r) 86 | if err != nil { 87 | return err 88 | } 89 | 90 | byteMD5, err := getRequestMD5(r) 91 | if err != nil { 92 | return err 93 | } 94 | 95 | strAuthorizationBase64 := r.Header.Get("authorization") 96 | if strAuthorizationBase64 == "" { 97 | return errors.New("no authorization field in Request header") 98 | } 99 | // a malformed (non-base64) signature header is reported as its own error 100 | authorization, err := base64.StdEncoding.DecodeString(strAuthorizationBase64) 101 | if err != nil { 102 | return err 103 | } 104 | 105 | pubBlock, _ := pem.Decode(bytePublicKey) 106 | if pubBlock == nil { 107 | return errors.New("pubBlock not exist") 108 | } 109 | pubInterface, err := x509.ParsePKIXPublicKey(pubBlock.Bytes) 110 | if err != nil { 111 | return err 112 | } 113 | // checked assertion: a non-RSA (or nil) key must fail verification instead of panicking or passing 114 | pub, ok := pubInterface.(*rsa.PublicKey) 115 | if !ok { 116 | return errors.New("public key is not an RSA key") 117 | } 118 | 119 | return rsa.VerifyPKCS1v15(pub, crypto.MD5, byteMD5, authorization) 120 | } 121 | -------------------------------------------------------------------------------- /pkg/filesystem/driver/template/handler.go:
-------------------------------------------------------------------------------- 1 | package template 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "io" 7 | "net/url" 8 | 9 | model "github.com/cloudreve/Cloudreve/v3/models" 10 | "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" 11 | "github.com/cloudreve/Cloudreve/v3/pkg/serializer" 12 | ) 13 | 14 | // Driver is a template for storage-policy adapters; every method is a stub that returns a "not implemented" error. 15 | type Driver struct { 16 | Policy *model.Policy 17 | } 18 | 19 | // Get would fetch a file; not implemented. 20 | func (handler Driver) Get(ctx context.Context, path string) (response.RSCloser, error) { 21 | return nil, errors.New("未实现") 22 | } 23 | 24 | // Put would save the file stream to dst; not implemented. 25 | func (handler Driver) Put(ctx context.Context, file io.ReadCloser, dst string, size uint64) error { 26 | return errors.New("未实现") 27 | } 28 | 29 | // Delete would delete one or more files, 30 | // returning the files not deleted and the last error encountered; not implemented. 31 | func (handler Driver) Delete(ctx context.Context, files []string) ([]string, error) { 32 | return []string{}, errors.New("未实现") 33 | } 34 | 35 | // Thumb would return a file's thumbnail; not implemented. 36 | func (handler Driver) Thumb(ctx context.Context, path string) (*response.ContentResponse, error) { 37 | return nil, errors.New("未实现") 38 | } 39 | 40 | // Source would return an external (shareable) URL for the file; not implemented. 41 | func (handler Driver) Source( 42 | ctx context.Context, 43 | path string, 44 | baseURL url.URL, 45 | ttl int64, 46 | isDownload bool, 47 | speed int, 48 | ) (string, error) { 49 | return "", errors.New("未实现") 50 | } 51 | 52 | // Token would return the upload policy and auth token; not implemented. 53 | func (handler Driver) Token(ctx context.Context, TTL int64, key string) (serializer.UploadCredential, error) { 54 | return serializer.UploadCredential{}, errors.New("未实现") 55 | } 56 | -------------------------------------------------------------------------------- /pkg/filesystem/errors.go: -------------------------------------------------------------------------------- 1 | package filesystem 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/cloudreve/Cloudreve/v3/pkg/serializer" 7 | ) 8 | // Errors returned by filesystem operations. 9 | var ( 10 | ErrUnknownPolicyType = errors.New("未知存储策略类型") 11 |
ErrFileSizeTooBig = errors.New("单个文件尺寸太大") 12 | ErrFileExtensionNotAllowed = errors.New("不允许上传此类型的文件") 13 | ErrInsufficientCapacity = errors.New("容量空间不足") 14 | ErrIllegalObjectName = errors.New("目标名称非法") 15 | ErrClientCanceled = errors.New("客户端取消操作") 16 | ErrRootProtected = errors.New("无法对根目录进行操作") 17 | ErrInsertFileRecord = serializer.NewError(serializer.CodeDBError, "无法插入文件记录", nil) 18 | ErrFileExisted = serializer.NewError(serializer.CodeObjectExist, "同名文件或目录已存在", nil) 19 | ErrFolderExisted = serializer.NewError(serializer.CodeObjectExist, "同名目录已存在", nil) 20 | ErrPathNotExist = serializer.NewError(404, "路径不存在", nil) 21 | ErrObjectNotExist = serializer.NewError(404, "文件不存在", nil) 22 | ErrIO = serializer.NewError(serializer.CodeIOFailed, "无法读取文件数据", nil) 23 | ErrDBListObjects = serializer.NewError(serializer.CodeDBError, "无法列取对象记录", nil) 24 | ErrDBDeleteObjects = serializer.NewError(serializer.CodeDBError, "无法删除对象记录", nil) 25 | ) 26 | -------------------------------------------------------------------------------- /pkg/filesystem/fsctx/context.go: -------------------------------------------------------------------------------- 1 | package fsctx 2 | 3 | type key int 4 | 5 | const ( 6 | // GinCtx Gin的上下文 7 | GinCtx key = iota 8 | // SavePathCtx 文件物理路径 9 | SavePathCtx 10 | // FileHeaderCtx 上传的文件 11 | FileHeaderCtx 12 | // PathCtx 文件或目录的虚拟路径 13 | PathCtx 14 | // FileModelCtx 文件数据库模型 15 | FileModelCtx 16 | // FolderModelCtx 目录数据库模型 17 | FolderModelCtx 18 | // HTTPCtx HTTP请求的上下文 19 | HTTPCtx 20 | // UploadPolicyCtx 上传策略,一般为slave模式下使用 21 | UploadPolicyCtx 22 | // UserCtx 用户 23 | UserCtx 24 | // ThumbSizeCtx 缩略图尺寸 25 | ThumbSizeCtx 26 | // FileSizeCtx 文件大小 27 | FileSizeCtx 28 | // ShareKeyCtx 分享文件的 HashID 29 | ShareKeyCtx 30 | // LimitParentCtx 限制父目录 31 | LimitParentCtx 32 | // IgnoreConflictCtx 忽略重名冲突 33 | IgnoreConflictCtx 34 | // RetryCtx 失败重试次数 35 | RetryCtx 36 | // ForceUsePublicEndpointCtx 强制使用公网 Endpoint 37 | ForceUsePublicEndpointCtx 38 | // CancelFuncCtx Context
取消函數 39 | CancelFuncCtx 40 | // ValidateCapacityOnceCtx 限定归还容量的操作只執行一次 41 | ValidateCapacityOnceCtx 42 | ) 43 | -------------------------------------------------------------------------------- /pkg/filesystem/image.go: -------------------------------------------------------------------------------- 1 | package filesystem 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "strconv" 7 | 8 | model "github.com/cloudreve/Cloudreve/v3/models" 9 | "github.com/cloudreve/Cloudreve/v3/pkg/conf" 10 | "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" 11 | "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" 12 | "github.com/cloudreve/Cloudreve/v3/pkg/thumb" 13 | "github.com/cloudreve/Cloudreve/v3/pkg/util" 14 | ) 15 | 16 | /* ================ 17 | 图像处理相关 18 | ================ 19 | */ 20 | 21 | // HandledExtension 可以生成缩略图的文件扩展名 22 | var HandledExtension = []string{"jpg", "jpeg", "png", "gif"} 23 | 24 | // GetThumb 获取文件的缩略图 25 | func (fs *FileSystem) GetThumb(ctx context.Context, id uint) (*response.ContentResponse, error) { 26 | // 根据 ID 查找文件 27 | err := fs.resetFileIDIfNotExist(ctx, id) 28 | if err != nil || fs.FileTarget[0].PicInfo == "" { 29 | return &response.ContentResponse{ 30 | Redirect: false, 31 | }, ErrObjectNotExist 32 | } 33 | 34 | w, h := fs.GenerateThumbnailSize(0, 0) 35 | ctx = context.WithValue(ctx, fsctx.ThumbSizeCtx, [2]uint{w, h}) 36 | ctx = context.WithValue(ctx, fsctx.FileModelCtx, fs.FileTarget[0]) 37 | res, err := fs.Handler.Thumb(ctx, fs.FileTarget[0].SourceName) 38 | if err == nil && conf.SystemConfig.Mode == "master" { 39 | res.MaxAge = model.GetIntSetting("preview_timeout", 60) 40 | } 41 | 42 | // 出错时重新生成缩略图 43 | if err != nil { 44 | fs.GenerateThumbnail(ctx, &fs.FileTarget[0]) 45 | } 46 | 47 | return res, err 48 | } 49 | 50 | // GenerateThumbnail 尝试为本地策略文件生成缩略图并获取图像原始大小 51 | // TODO 失败时,如果之前还有图像信息,则清除 52 | func (fs *FileSystem) GenerateThumbnail(ctx context.Context, file *model.File) { 53 | // 判断是否可以生成缩略图 54 | if 
!IsInExtensionList(HandledExtension, file.Name) { 55 | return 56 | } 57 | 58 | // 新建上下文 59 | newCtx, cancel := context.WithCancel(context.Background()) 60 | defer cancel() 61 | 62 | // 获取文件数据 63 | source, err := fs.Handler.Get(newCtx, file.SourceName) 64 | if err != nil { 65 | return 66 | } 67 | defer source.Close() 68 | 69 | image, err := thumb.NewThumbFromFile(source, file.Name) 70 | if err != nil { 71 | util.Log().Warning("生成缩略图时无法解析 [%s] 图像数据:%s", file.SourceName, err) 72 | return 73 | } 74 | 75 | // 获取原始图像尺寸 76 | w, h := image.GetSize() 77 | 78 | // 生成缩略图 79 | image.GetThumb(fs.GenerateThumbnailSize(w, h)) 80 | // 保存到文件 81 | err = image.Save(util.RelativePath(file.SourceName + conf.ThumbConfig.FileSuffix)) 82 | if err != nil { 83 | util.Log().Warning("无法保存缩略图:%s", err) 84 | return 85 | } 86 | 87 | // 更新文件的图像信息 88 | if file.Model.ID > 0 { 89 | err = file.UpdatePicInfo(fmt.Sprintf("%d,%d", w, h)) 90 | } else { 91 | file.PicInfo = fmt.Sprintf("%d,%d", w, h) 92 | } 93 | 94 | // 失败时删除缩略图文件 95 | if err != nil { 96 | _, _ = fs.Handler.Delete(newCtx, []string{file.SourceName + conf.ThumbConfig.FileSuffix}) 97 | } 98 | } 99 | 100 | // GenerateThumbnailSize 获取要生成的缩略图的尺寸 101 | func (fs *FileSystem) GenerateThumbnailSize(w, h int) (uint, uint) { 102 | if conf.SystemConfig.Mode == "master" { 103 | options := model.GetSettingByNames("thumb_width", "thumb_height") 104 | w, _ := strconv.ParseUint(options["thumb_width"], 10, 32) 105 | h, _ := strconv.ParseUint(options["thumb_height"], 10, 32) 106 | return uint(w), uint(h) 107 | } 108 | return conf.ThumbConfig.MaxWidth, conf.ThumbConfig.MaxHeight 109 | } 110 | -------------------------------------------------------------------------------- /pkg/filesystem/image_test.go: -------------------------------------------------------------------------------- 1 | package filesystem 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | model "github.com/cloudreve/Cloudreve/v3/models" 8 | "github.com/cloudreve/Cloudreve/v3/pkg/cache" 9 | 
"github.com/cloudreve/Cloudreve/v3/pkg/filesystem/response" 10 | "github.com/stretchr/testify/assert" 11 | testMock "github.com/stretchr/testify/mock" 12 | ) 13 | 14 | func TestFileSystem_GetThumb(t *testing.T) { 15 | asserts := assert.New(t) 16 | fs := &FileSystem{User: &model.User{}} 17 | 18 | // 非图像文件 19 | { 20 | fs.SetTargetFile(&[]model.File{{}}) 21 | _, err := fs.GetThumb(context.Background(), 1) 22 | asserts.Equal(err, ErrObjectNotExist) 23 | } 24 | 25 | // 成功 26 | { 27 | cache.Set("setting_thumb_width", "10", 0) 28 | cache.Set("setting_thumb_height", "10", 0) 29 | cache.Set("setting_preview_timeout", "50", 0) 30 | testHandller2 := new(FileHeaderMock) 31 | testHandller2.On("Thumb", testMock.Anything, "").Return(&response.ContentResponse{}, nil) 32 | fs.CleanTargets() 33 | fs.SetTargetFile(&[]model.File{{PicInfo: "1,1", Policy: model.Policy{Type: "mock"}}}) 34 | fs.FileTarget[0].Policy.ID = 1 35 | fs.Handler = testHandller2 36 | res, err := fs.GetThumb(context.Background(), 1) 37 | asserts.NoError(err) 38 | asserts.EqualValues(50, res.MaxAge) 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /pkg/filesystem/path.go: -------------------------------------------------------------------------------- 1 | package filesystem 2 | 3 | import ( 4 | "path" 5 | 6 | model "github.com/cloudreve/Cloudreve/v3/models" 7 | "github.com/cloudreve/Cloudreve/v3/pkg/util" 8 | ) 9 | 10 | /* ================= 11 | 路径/目录相关 12 | ================= 13 | */ 14 | 15 | // IsPathExist 返回给定目录是否存在 16 | // 如果存在就返回目录 17 | func (fs *FileSystem) IsPathExist(path string) (bool, *model.Folder) { 18 | pathList := util.SplitPath(path) 19 | if len(pathList) == 0 { 20 | return false, nil 21 | } 22 | 23 | // 递归步入目录 24 | // TODO:测试新增 25 | var currentFolder *model.Folder 26 | 27 | // 如果已设定跟目录对象,则从给定目录向下遍历 28 | if fs.Root != nil { 29 | currentFolder = fs.Root 30 | } 31 | 32 | for _, folderName := range pathList { 33 | var err error 34 | 35 | // 根目录 36 | if 
folderName == "/" { 37 | if currentFolder != nil { 38 | continue 39 | } 40 | currentFolder, err = fs.User.Root() 41 | if err != nil { 42 | return false, nil 43 | } 44 | } else { 45 | currentFolder, err = currentFolder.GetChild(folderName) 46 | if err != nil { 47 | return false, nil 48 | } 49 | } 50 | } 51 | 52 | return true, currentFolder 53 | } 54 | 55 | // IsFileExist 返回给定路径的文件是否存在 56 | func (fs *FileSystem) IsFileExist(fullPath string) (bool, *model.File) { 57 | basePath := path.Dir(fullPath) 58 | fileName := path.Base(fullPath) 59 | 60 | // 获得父目录 61 | exist, parent := fs.IsPathExist(basePath) 62 | if !exist { 63 | return false, nil 64 | } 65 | 66 | file, err := parent.GetChildFile(fileName) 67 | 68 | return err == nil, file 69 | } 70 | 71 | // IsChildFileExist 确定folder目录下是否有名为name的文件 72 | func (fs *FileSystem) IsChildFileExist(folder *model.Folder, name string) (bool, *model.File) { 73 | file, err := folder.GetChildFile(name) 74 | return err == nil, file 75 | } 76 | -------------------------------------------------------------------------------- /pkg/filesystem/response/common.go: -------------------------------------------------------------------------------- 1 | package response 2 | 3 | import ( 4 | "io" 5 | "time" 6 | ) 7 | 8 | // ContentResponse 获取文件内容类方法的通用返回值。 9 | // 有些上传策略需要重定向, 10 | // 有些直接写文件数据到浏览器 11 | type ContentResponse struct { 12 | Redirect bool 13 | Content RSCloser 14 | URL string 15 | MaxAge int 16 | } 17 | 18 | // RSCloser 存储策略适配器返回的文件流,有些策略需要带有Closer 19 | type RSCloser interface { 20 | io.ReadSeeker 21 | io.Closer 22 | } 23 | 24 | // Object 列出文件、目录时返回的对象 25 | type Object struct { 26 | Name string `json:"name"` 27 | RelativePath string `json:"relative_path"` 28 | Source string `json:"source"` 29 | Size uint64 `json:"size"` 30 | IsDir bool `json:"is_dir"` 31 | LastModify time.Time `json:"last_modify"` 32 | } 33 | -------------------------------------------------------------------------------- /pkg/filesystem/tests/file1.txt: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/moeYuiYui/Cloudreve/1fe212135bd576cee956132aaf70cebb879eadd9/pkg/filesystem/tests/file1.txt -------------------------------------------------------------------------------- /pkg/filesystem/tests/file2.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/moeYuiYui/Cloudreve/1fe212135bd576cee956132aaf70cebb879eadd9/pkg/filesystem/tests/file2.txt -------------------------------------------------------------------------------- /pkg/filesystem/tests/test.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/moeYuiYui/Cloudreve/1fe212135bd576cee956132aaf70cebb879eadd9/pkg/filesystem/tests/test.zip -------------------------------------------------------------------------------- /pkg/filesystem/validator.go: -------------------------------------------------------------------------------- 1 | package filesystem 2 | 3 | import ( 4 | "context" 5 | "path/filepath" 6 | "strings" 7 | 8 | "github.com/cloudreve/Cloudreve/v3/pkg/util" 9 | ) 10 | 11 | /* ========== 12 | 验证器 13 | ========== 14 | */ 15 | 16 | // 文件/路径名保留字符 17 | var reservedCharacter = []string{"\\", "?", "*", "<", "\"", ":", ">", "/", "|"} 18 | 19 | // ValidateLegalName 验证文件名/文件夹名是否合法 20 | func (fs *FileSystem) ValidateLegalName(ctx context.Context, name string) bool { 21 | // 是否包含保留字符 22 | for _, value := range reservedCharacter { 23 | if strings.Contains(name, value) { 24 | return false 25 | } 26 | } 27 | 28 | // 是否超出长度限制 29 | if len(name) >= 256 { 30 | return false 31 | } 32 | 33 | // 是否为空限制 34 | if len(name) == 0 { 35 | return false 36 | } 37 | 38 | // 结尾不能是空格 39 | if strings.HasSuffix(name, " ") { 40 | return false 41 | } 42 | 43 | return true 44 | } 45 | 46 | // ValidateFileSize 验证上传的文件大小是否超出限制 47 | func (fs *FileSystem) ValidateFileSize(ctx context.Context, size 
uint64) bool { 48 | if fs.User.Policy.MaxSize == 0 { 49 | return true 50 | } 51 | return size <= fs.User.Policy.MaxSize 52 | } 53 | 54 | // ValidateCapacity 验证并扣除用户容量 55 | func (fs *FileSystem) ValidateCapacity(ctx context.Context, size uint64) bool { 56 | return fs.User.IncreaseStorage(size) 57 | } 58 | 59 | // ValidateExtension 验证文件扩展名 60 | func (fs *FileSystem) ValidateExtension(ctx context.Context, fileName string) bool { 61 | // 不需要验证 62 | if len(fs.User.Policy.OptionsSerialized.FileType) == 0 { 63 | return true 64 | } 65 | 66 | return IsInExtensionList(fs.User.Policy.OptionsSerialized.FileType, fileName) 67 | } 68 | 69 | // IsInExtensionList 返回文件的扩展名是否在给定的列表范围内 70 | func IsInExtensionList(extList []string, fileName string) bool { 71 | ext := strings.ToLower(filepath.Ext(fileName)) 72 | // 无扩展名时 73 | if len(ext) == 0 { 74 | return false 75 | } 76 | 77 | if util.ContainsString(extList, ext[1:]) { 78 | return true 79 | } 80 | 81 | return false 82 | } 83 | -------------------------------------------------------------------------------- /pkg/hashid/hash.go: -------------------------------------------------------------------------------- 1 | package hashid 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/cloudreve/Cloudreve/v3/pkg/conf" 7 | "github.com/speps/go-hashids" 8 | ) 9 | 10 | // ID类型 11 | const ( 12 | ShareID = iota // 分享 13 | UserID // 用户 14 | FileID // 文件ID 15 | FolderID // 目录ID 16 | TagID // 标签ID 17 | PolicyID // 存储策略ID 18 | ) 19 | 20 | var ( 21 | // ErrTypeNotMatch ID类型不匹配 22 | ErrTypeNotMatch = errors.New("ID类型不匹配") 23 | ) 24 | 25 | // HashEncode 对给定数据计算HashID 26 | func HashEncode(v []int) (string, error) { 27 | hd := hashids.NewData() 28 | hd.Salt = conf.SystemConfig.HashIDSalt 29 | 30 | h, err := hashids.NewWithData(hd) 31 | if err != nil { 32 | return "", err 33 | } 34 | 35 | id, err := h.Encode(v) 36 | if err != nil { 37 | return "", err 38 | } 39 | return id, nil 40 | } 41 | 42 | // HashDecode 对给定数据计算原始数据 43 | func HashDecode(raw string) ([]int, 
error) { 44 | hd := hashids.NewData() 45 | hd.Salt = conf.SystemConfig.HashIDSalt 46 | 47 | h, err := hashids.NewWithData(hd) 48 | if err != nil { 49 | return []int{}, err 50 | } 51 | 52 | return h.DecodeWithError(raw) 53 | 54 | } 55 | 56 | // HashID 计算数据库内主键对应的HashID 57 | func HashID(id uint, t int) string { 58 | v, _ := HashEncode([]int{int(id), t}) 59 | return v 60 | } 61 | 62 | // DecodeHashID 计算HashID对应的数据库ID 63 | func DecodeHashID(id string, t int) (uint, error) { 64 | v, _ := HashDecode(id) 65 | if len(v) != 2 || v[1] != t { 66 | return 0, ErrTypeNotMatch 67 | } 68 | return uint(v[0]), nil 69 | } 70 | -------------------------------------------------------------------------------- /pkg/hashid/hash_test.go: -------------------------------------------------------------------------------- 1 | package hashid 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 5 | "testing" 6 | ) 7 | 8 | func TestHashEncode(t *testing.T) { 9 | asserts := assert.New(t) 10 | 11 | { 12 | res, err := HashEncode([]int{1, 2, 3}) 13 | asserts.NoError(err) 14 | asserts.NotEmpty(res) 15 | } 16 | 17 | { 18 | res, err := HashEncode([]int{}) 19 | asserts.Error(err) 20 | asserts.Empty(res) 21 | } 22 | 23 | } 24 | 25 | func TestHashID(t *testing.T) { 26 | asserts := assert.New(t) 27 | 28 | { 29 | res := HashID(1, ShareID) 30 | asserts.NotEmpty(res) 31 | } 32 | } 33 | 34 | func TestHashDecode(t *testing.T) { 35 | asserts := assert.New(t) 36 | 37 | // 正常 38 | { 39 | res, _ := HashEncode([]int{1, 2, 3}) 40 | decodeRes, err := HashDecode(res) 41 | asserts.NoError(err) 42 | asserts.Equal([]int{1, 2, 3}, decodeRes) 43 | } 44 | 45 | // 出错 46 | { 47 | decodeRes, err := HashDecode("233") 48 | asserts.Error(err) 49 | asserts.Len(decodeRes, 0) 50 | } 51 | } 52 | 53 | func TestDecodeHashID(t *testing.T) { 54 | asserts := assert.New(t) 55 | 56 | // 成功 57 | { 58 | uid, err := DecodeHashID(HashID(1, ShareID), ShareID) 59 | asserts.NoError(err) 60 | asserts.EqualValues(1, uid) 61 | } 62 | 63 | // 类型不匹配 
64 | { 65 | uid, err := DecodeHashID(HashID(1, ShareID), UserID) 66 | asserts.Error(err) 67 | asserts.EqualValues(0, uid) 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /pkg/request/slave.go: -------------------------------------------------------------------------------- 1 | package request 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "errors" 7 | "time" 8 | 9 | "github.com/cloudreve/Cloudreve/v3/pkg/auth" 10 | "github.com/cloudreve/Cloudreve/v3/pkg/conf" 11 | "github.com/cloudreve/Cloudreve/v3/pkg/serializer" 12 | ) 13 | 14 | // RemoteCallback 发送远程存储策略上传回调请求 15 | func RemoteCallback(url string, body serializer.UploadCallback) error { 16 | callbackBody, err := json.Marshal(struct { 17 | Data serializer.UploadCallback `json:"data"` 18 | }{ 19 | Data: body, 20 | }) 21 | if err != nil { 22 | return serializer.NewError(serializer.CodeCallbackError, "无法编码回调正文", err) 23 | } 24 | 25 | resp := GeneralClient.Request( 26 | "POST", 27 | url, 28 | bytes.NewReader(callbackBody), 29 | WithTimeout(time.Duration(conf.SlaveConfig.CallbackTimeout)*time.Second), 30 | WithCredential(auth.General, int64(conf.SlaveConfig.SignatureTTL)), 31 | ) 32 | 33 | if resp.Err != nil { 34 | return serializer.NewError(serializer.CodeCallbackError, "无法发起回调请求", resp.Err) 35 | } 36 | 37 | // 解析回调服务端响应 38 | resp = resp.CheckHTTPResponse(200) 39 | if resp.Err != nil { 40 | return serializer.NewError(serializer.CodeCallbackError, "服务器返回异常响应", resp.Err) 41 | } 42 | response, err := resp.DecodeResponse() 43 | if err != nil { 44 | return serializer.NewError(serializer.CodeCallbackError, "无法解析服务端返回的响应", err) 45 | } 46 | if response.Code != 0 { 47 | return serializer.NewError(response.Code, response.Msg, errors.New(response.Error)) 48 | } 49 | 50 | return nil 51 | } 52 | -------------------------------------------------------------------------------- /pkg/serializer/aria2_test.go: 
-------------------------------------------------------------------------------- 1 | package serializer 2 | 3 | import ( 4 | "testing" 5 | 6 | model "github.com/cloudreve/Cloudreve/v3/models" 7 | "github.com/cloudreve/Cloudreve/v3/pkg/aria2/rpc" 8 | "github.com/cloudreve/Cloudreve/v3/pkg/cache" 9 | "github.com/jinzhu/gorm" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestBuildFinishedListResponse(t *testing.T) { 14 | asserts := assert.New(t) 15 | tasks := []model.Download{ 16 | { 17 | StatusInfo: rpc.StatusInfo{ 18 | Files: []rpc.FileInfo{ 19 | { 20 | Path: "/file/name.txt", 21 | }, 22 | }, 23 | }, 24 | Task: &model.Task{ 25 | Model: gorm.Model{}, 26 | Error: "error", 27 | }, 28 | }, 29 | { 30 | StatusInfo: rpc.StatusInfo{ 31 | Files: []rpc.FileInfo{ 32 | { 33 | Path: "/file/name1.txt", 34 | }, 35 | { 36 | Path: "/file/name2.txt", 37 | }, 38 | }, 39 | }, 40 | }, 41 | } 42 | tasks[1].StatusInfo.BitTorrent.Info.Name = "name.txt" 43 | res := BuildFinishedListResponse(tasks).Data.([]FinishedListResponse) 44 | asserts.Len(res, 2) 45 | asserts.Equal("name.txt", res[1].Name) 46 | asserts.Equal("name.txt", res[0].Name) 47 | asserts.Equal("name.txt", res[0].Files[0].Path) 48 | asserts.Equal("name1.txt", res[1].Files[0].Path) 49 | asserts.Equal("name2.txt", res[1].Files[1].Path) 50 | asserts.EqualValues(0, res[0].TaskStatus) 51 | asserts.Equal("error", res[0].TaskError) 52 | } 53 | 54 | func TestBuildDownloadingResponse(t *testing.T) { 55 | asserts := assert.New(t) 56 | cache.Set("setting_aria2_interval", "10", 0) 57 | tasks := []model.Download{ 58 | { 59 | StatusInfo: rpc.StatusInfo{ 60 | Files: []rpc.FileInfo{ 61 | { 62 | Path: "/file/name.txt", 63 | }, 64 | }, 65 | }, 66 | Task: &model.Task{ 67 | Model: gorm.Model{}, 68 | Error: "error", 69 | }, 70 | }, 71 | { 72 | StatusInfo: rpc.StatusInfo{ 73 | Files: []rpc.FileInfo{ 74 | { 75 | Path: "/file/name1.txt", 76 | }, 77 | { 78 | Path: "/file/name2.txt", 79 | }, 80 | }, 81 | }, 82 | }, 83 | } 84 | 
tasks[1].StatusInfo.BitTorrent.Info.Name = "name.txt" 85 | 86 | res := BuildDownloadingResponse(tasks).Data.([]DownloadListResponse) 87 | asserts.Len(res, 2) 88 | asserts.Equal("name1.txt", res[1].Name) 89 | asserts.Equal("name.txt", res[0].Name) 90 | asserts.Equal("name.txt", res[0].Info.Files[0].Path) 91 | asserts.Equal("name1.txt", res[1].Info.Files[0].Path) 92 | asserts.Equal("name2.txt", res[1].Info.Files[1].Path) 93 | } 94 | -------------------------------------------------------------------------------- /pkg/serializer/auth.go: -------------------------------------------------------------------------------- 1 | package serializer 2 | 3 | import "encoding/json" 4 | 5 | // RequestRawSign 待签名的HTTP请求 6 | type RequestRawSign struct { 7 | Path string 8 | Policy string 9 | Body string 10 | } 11 | 12 | // NewRequestSignString 返回JSON格式的待签名字符串 13 | // TODO 测试 14 | func NewRequestSignString(path, policy, body string) string { 15 | req := RequestRawSign{ 16 | Path: path, 17 | Policy: policy, 18 | Body: body, 19 | } 20 | res, _ := json.Marshal(req) 21 | return string(res) 22 | } 23 | -------------------------------------------------------------------------------- /pkg/serializer/auth_test.go: -------------------------------------------------------------------------------- 1 | package serializer 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 5 | "testing" 6 | ) 7 | 8 | func TestNewRequestSignString(t *testing.T) { 9 | asserts := assert.New(t) 10 | 11 | sign := NewRequestSignString("1", "2", "3") 12 | asserts.NotEmpty(sign) 13 | } 14 | -------------------------------------------------------------------------------- /pkg/serializer/error.go: -------------------------------------------------------------------------------- 1 | package serializer 2 | 3 | import "github.com/gin-gonic/gin" 4 | 5 | // Response 基础序列化器 6 | type Response struct { 7 | Code int `json:"code"` 8 | Data interface{} `json:"data,omitempty"` 9 | Msg string `json:"msg"` 10 | Error string 
`json:"error,omitempty"` 11 | } 12 | 13 | // AppError 应用错误,实现了error接口 14 | type AppError struct { 15 | Code int 16 | Msg string 17 | RawError error 18 | } 19 | 20 | // NewError 返回新的错误对象 todo:测试 还有下面的 21 | func NewError(code int, msg string, err error) AppError { 22 | return AppError{ 23 | Code: code, 24 | Msg: msg, 25 | RawError: err, 26 | } 27 | } 28 | 29 | // WithError 将应用error携带标准库中的error 30 | func (err *AppError) WithError(raw error) AppError { 31 | err.RawError = raw 32 | return *err 33 | } 34 | 35 | // Error 返回业务代码确定的可读错误信息 36 | func (err AppError) Error() string { 37 | return err.Msg 38 | } 39 | 40 | // 三位数错误编码为复用http原本含义 41 | // 五位数错误编码为应用自定义错误 42 | // 五开头的五位数错误编码为服务器端错误,比如数据库操作失败 43 | // 四开头的五位数错误编码为客户端错误,有时候是客户端代码写错了,有时候是用户操作错误 44 | const ( 45 | // CodeNotFullySuccess 未完全成功 46 | CodeNotFullySuccess = 203 47 | // CodeCheckLogin 未登录 48 | CodeCheckLogin = 401 49 | // CodeNoPermissionErr 未授权访问 50 | CodeNoPermissionErr = 403 51 | // CodeNotFound 资源未找到 52 | CodeNotFound = 404 53 | // CodeUploadFailed 上传出错 54 | CodeUploadFailed = 40002 55 | // CodeCredentialInvalid 凭证无效 56 | CodeCredentialInvalid = 40001 57 | // CodeCreateFolderFailed 目录创建失败 58 | CodeCreateFolderFailed = 40003 59 | // CodeObjectExist 对象已存在 60 | CodeObjectExist = 40004 61 | // CodeSignExpired 签名过期 62 | CodeSignExpired = 40005 63 | // CodePolicyNotAllowed 当前存储策略不允许 64 | CodePolicyNotAllowed = 40006 65 | // CodeGroupNotAllowed 用户组无法进行此操作 66 | CodeGroupNotAllowed = 40007 67 | // CodeAdminRequired 非管理用户组 68 | CodeAdminRequired = 40008 69 | // CodeDBError 数据库操作失败 70 | CodeDBError = 50001 71 | // CodeEncryptError 加密失败 72 | CodeEncryptError = 50002 73 | // CodeIOFailed IO操作失败 74 | CodeIOFailed = 50004 75 | // CodeInternalSetting 内部设置参数错误 76 | CodeInternalSetting = 50005 77 | // CodeCacheOperation 缓存操作失败 78 | CodeCacheOperation = 50006 79 | // CodeCallbackError 回调失败 80 | CodeCallbackError = 50007 81 | //CodeParamErr 各种奇奇怪怪的参数错误 82 | CodeParamErr = 40001 83 | // CodeNotSet 未定错误,后续尝试从error中获取 84 | 
CodeNotSet = -1 85 | ) 86 | 87 | // DBErr 数据库操作失败 88 | func DBErr(msg string, err error) Response { 89 | if msg == "" { 90 | msg = "数据库操作失败" 91 | } 92 | return Err(CodeDBError, msg, err) 93 | } 94 | 95 | // ParamErr 各种参数错误 96 | func ParamErr(msg string, err error) Response { 97 | if msg == "" { 98 | msg = "参数错误" 99 | } 100 | return Err(CodeParamErr, msg, err) 101 | } 102 | 103 | // Err 通用错误处理 104 | func Err(errCode int, msg string, err error) Response { 105 | // 底层错误是AppError,则尝试从AppError中获取详细信息 106 | if appError, ok := err.(AppError); ok { 107 | errCode = appError.Code 108 | err = appError.RawError 109 | msg = appError.Msg 110 | } 111 | 112 | res := Response{ 113 | Code: errCode, 114 | Msg: msg, 115 | } 116 | // 生产环境隐藏底层报错 117 | if err != nil && gin.Mode() != gin.ReleaseMode { 118 | res.Error = err.Error() 119 | } 120 | return res 121 | } 122 | -------------------------------------------------------------------------------- /pkg/serializer/setting.go: -------------------------------------------------------------------------------- 1 | package serializer 2 | 3 | import model "github.com/cloudreve/Cloudreve/v3/models" 4 | 5 | // SiteConfig 站点全局设置序列 6 | type SiteConfig struct { 7 | SiteName string `json:"title"` 8 | SiteICPId string `json:"siteICPId"` 9 | LoginCaptcha bool `json:"loginCaptcha"` 10 | RegCaptcha bool `json:"regCaptcha"` 11 | ForgetCaptcha bool `json:"forgetCaptcha"` 12 | EmailActive bool `json:"emailActive"` 13 | Themes string `json:"themes"` 14 | DefaultTheme string `json:"defaultTheme"` 15 | HomepageViewMethod string `json:"home_view_method"` 16 | ShareViewMethod string `json:"share_view_method"` 17 | Authn bool `json:"authn"` 18 | User User `json:"user"` 19 | UseReCaptcha bool `json:"captcha_IsUseReCaptcha"` 20 | ReCaptchaKey string `json:"captcha_ReCaptchaKey"` 21 | } 22 | 23 | type task struct { 24 | Status int `json:"status"` 25 | Type int `json:"type"` 26 | CreateDate string `json:"create_date"` 27 | Progress int `json:"progress"` 28 | Error 
string `json:"error"` 29 | } 30 | 31 | // BuildTaskList 构建任务列表响应 32 | func BuildTaskList(tasks []model.Task, total int) Response { 33 | res := make([]task, 0, len(tasks)) 34 | for _, t := range tasks { 35 | res = append(res, task{ 36 | Status: t.Status, 37 | Type: t.Type, 38 | CreateDate: t.CreatedAt.Format("2006-01-02 15:04:05"), 39 | Progress: t.Progress, 40 | Error: t.Error, 41 | }) 42 | } 43 | 44 | return Response{Data: map[string]interface{}{ 45 | "total": total, 46 | "tasks": res, 47 | }} 48 | } 49 | 50 | func checkSettingValue(setting map[string]string, key string) string { 51 | if v, ok := setting[key]; ok { 52 | return v 53 | } 54 | return "" 55 | } 56 | 57 | // BuildSiteConfig 站点全局设置 58 | func BuildSiteConfig(settings map[string]string, user *model.User) Response { 59 | var userRes User 60 | if user != nil { 61 | userRes = BuildUser(*user) 62 | } else { 63 | userRes = BuildUser(*model.NewAnonymousUser()) 64 | } 65 | res := Response{ 66 | Data: SiteConfig{ 67 | SiteName: checkSettingValue(settings, "siteName"), 68 | SiteICPId: checkSettingValue(settings, "siteICPId"), 69 | LoginCaptcha: model.IsTrueVal(checkSettingValue(settings, "login_captcha")), 70 | RegCaptcha: model.IsTrueVal(checkSettingValue(settings, "reg_captcha")), 71 | ForgetCaptcha: model.IsTrueVal(checkSettingValue(settings, "forget_captcha")), 72 | EmailActive: model.IsTrueVal(checkSettingValue(settings, "email_active")), 73 | Themes: checkSettingValue(settings, "themes"), 74 | DefaultTheme: checkSettingValue(settings, "defaultTheme"), 75 | HomepageViewMethod: checkSettingValue(settings, "home_view_method"), 76 | ShareViewMethod: checkSettingValue(settings, "share_view_method"), 77 | Authn: model.IsTrueVal(checkSettingValue(settings, "authn_enabled")), 78 | User: userRes, 79 | UseReCaptcha: model.IsTrueVal(checkSettingValue(settings, "captcha_IsUseReCaptcha")), 80 | ReCaptchaKey: checkSettingValue(settings, "captcha_ReCaptchaKey"), 81 | }} 82 | return res 83 | } 84 | 
-------------------------------------------------------------------------------- /pkg/serializer/setting_test.go: -------------------------------------------------------------------------------- 1 | package serializer 2 | 3 | import ( 4 | "testing" 5 | 6 | model "github.com/cloudreve/Cloudreve/v3/models" 7 | "github.com/jinzhu/gorm" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestCheckSettingValue(t *testing.T) { 12 | asserts := assert.New(t) 13 | 14 | asserts.Equal("", checkSettingValue(map[string]string{}, "key")) 15 | asserts.Equal("123", checkSettingValue(map[string]string{"key": "123"}, "key")) 16 | } 17 | 18 | func TestBuildSiteConfig(t *testing.T) { 19 | asserts := assert.New(t) 20 | 21 | res := BuildSiteConfig(map[string]string{"not exist": ""}, &model.User{}) 22 | asserts.Equal("", res.Data.(SiteConfig).SiteName) 23 | 24 | res = BuildSiteConfig(map[string]string{"siteName": "123"}, &model.User{}) 25 | asserts.Equal("123", res.Data.(SiteConfig).SiteName) 26 | 27 | // non-empty user (ID is set) 28 | res = BuildSiteConfig(map[string]string{"qq_login": "1"}, &model.User{ 29 | Model: gorm.Model{ 30 | ID: 5, 31 | }, 32 | }) 33 | asserts.Len(res.Data.(SiteConfig).User.ID, 4) 34 | } 35 | 36 | func TestBuildTaskList(t *testing.T) { 37 | asserts := assert.New(t) 38 | tasks := []model.Task{{}} 39 | 40 | res := BuildTaskList(tasks, 1) 41 | asserts.NotNil(res) 42 | } 43 | -------------------------------------------------------------------------------- /pkg/serializer/share_test.go: -------------------------------------------------------------------------------- 1 | package serializer 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | model "github.com/cloudreve/Cloudreve/v3/models" 8 | "github.com/jinzhu/gorm" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestBuildShareList(t *testing.T) { 13 | asserts := assert.New(t) 14 | timeNow := time.Now() 15 | 16 | shares := []model.Share{ 17 | { 18 | Expires: &timeNow, 19 | File: model.File{ 20 | Model:
gorm.Model{ID: 1}, 21 | }, 22 | }, 23 | { 24 | Folder: model.Folder{ 25 | Model: gorm.Model{ID: 1}, 26 | }, 27 | }, 28 | } 29 | 30 | res := BuildShareList(shares, 2) 31 | asserts.Equal(0, res.Code) 32 | } 33 | 34 | func TestBuildShareResponse(t *testing.T) { 35 | asserts := assert.New(t) 36 | 37 | // locked: the share has not been unlocked 38 | { 39 | share := &model.Share{ 40 | User: model.User{Model: gorm.Model{ID: 1}}, 41 | Downloads: 1, 42 | } 43 | res := BuildShareResponse(share, false) 44 | asserts.EqualValues(0, res.Downloads) // the download count is not exposed while locked 45 | asserts.True(res.Locked) 46 | asserts.NotNil(res.Creator) 47 | } 48 | 49 | // unlocked, sharing a single file 50 | { 51 | expires := time.Now().Add(time.Duration(10) * time.Second) 52 | share := &model.Share{ 53 | User: model.User{Model: gorm.Model{ID: 1}}, 54 | Downloads: 1, 55 | Expires: &expires, 56 | File: model.File{ 57 | Model: gorm.Model{ID: 1}, 58 | }, 59 | } 60 | res := BuildShareResponse(share, true) 61 | asserts.EqualValues(1, res.Downloads) 62 | asserts.False(res.Locked) 63 | asserts.NotEmpty(res.Expire) 64 | asserts.NotNil(res.Creator) 65 | } 66 | 67 | // unlocked, sharing a directory 68 | { 69 | expires := time.Now().Add(time.Duration(10) * time.Second) 70 | share := &model.Share{ 71 | User: model.User{Model: gorm.Model{ID: 1}}, 72 | Downloads: 1, 73 | Expires: &expires, 74 | Folder: model.Folder{ 75 | Model: gorm.Model{ID: 1}, 76 | }, 77 | IsDir: true, 78 | } 79 | res := BuildShareResponse(share, true) 80 | asserts.EqualValues(1, res.Downloads) 81 | asserts.False(res.Locked) 82 | asserts.NotEmpty(res.Expire) 83 | asserts.NotNil(res.Creator) 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /pkg/serializer/slave.go: -------------------------------------------------------------------------------- 1 | package serializer 2 | 3 | // RemoteDeleteRequest is the request body of the remote-policy delete API. 4 | type RemoteDeleteRequest struct { 5 | Files []string `json:"files"` 6 | } 7 | 8 | // ListRequest is the request body of the remote-policy file-listing API. 9 | type ListRequest struct { 10 | Path string `json:"path"` 11 | Recursive
bool `json:"recursive"` 12 | } 13 | -------------------------------------------------------------------------------- /pkg/serializer/upload.go: -------------------------------------------------------------------------------- 1 | package serializer 2 | 3 | import ( 4 | "encoding/base64" 5 | "encoding/gob" 6 | "encoding/json" 7 | ) 8 | 9 | // UploadPolicy is the upload policy passed to a node running in slave mode. 10 | type UploadPolicy struct { 11 | SavePath string `json:"save_path"` 12 | FileName string `json:"file_name"` 13 | AutoRename bool `json:"auto_rename"` 14 | MaxSize uint64 `json:"max_size"` 15 | AllowedExtension []string `json:"allowed_extension"` 16 | CallbackURL string `json:"callback_url"` 17 | } 18 | 19 | // UploadCredential is the upload credential returned to the client. 20 | type UploadCredential struct { 21 | Token string `json:"token"` 22 | Policy string `json:"policy"` 23 | Path string `json:"path"` // storage path 24 | AccessKey string `json:"ak"` 25 | KeyTime string `json:"key_time,omitempty"` // validity period, used by COS 26 | Callback string `json:"callback,omitempty"` // callback URL 27 | Key string `json:"key,omitempty"` // file identifier, usually the callback key 28 | } 29 | 30 | // UploadSession holds the state of an upload session. 31 | type UploadSession struct { 32 | Key string 33 | UID uint 34 | PolicyID uint 35 | VirtualPath string 36 | Name string 37 | Size uint64 38 | SavePath string 39 | } 40 | 41 | // UploadCallback is the body of an upload callback request. 42 | type UploadCallback struct { 43 | Name string `json:"name"` 44 | SourceName string `json:"source_name"` 45 | PicInfo string `json:"pic_info"` 46 | Size uint64 `json:"size"` 47 | } 48 | 49 | // GeneralUploadCallbackFailed is the response sent when a storage policy's upload callback fails. 50 | type GeneralUploadCallbackFailed struct { 51 | Error string `json:"error"` 52 | } 53 | 54 | func init() { 55 | gob.Register(UploadSession{}) // register the concrete type so it can be gob-encoded behind an interface value 56 | } 57 | 58 | // DecodeUploadPolicy decodes an upload policy carried in a header as standard-base64-encoded JSON. 59 | func DecodeUploadPolicy(raw string) (*UploadPolicy, error) { 60 | var res UploadPolicy 61 | 62 | rawJSON, err := base64.StdEncoding.DecodeString(raw) 63 | if err != nil { 64 | return nil, err 65 | } 66 | 67 | err = json.Unmarshal(rawJSON,
&res) 68 | if err != nil { 69 | return nil, err 70 | } 71 | 72 | return &res, err // err is nil here 73 | } 74 | 75 | // EncodeUploadPolicy encodes the policy as standard-base64-encoded JSON for carrying in a header. 76 | func (policy *UploadPolicy) EncodeUploadPolicy() (string, error) { 77 | jsonRes, err := json.Marshal(policy) 78 | if err != nil { 79 | return "", err 80 | } 81 | 82 | res := base64.StdEncoding.EncodeToString(jsonRes) 83 | return res, nil 84 | 85 | } 86 | -------------------------------------------------------------------------------- /pkg/serializer/upload_test.go: -------------------------------------------------------------------------------- 1 | package serializer 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 5 | "testing" 6 | ) 7 | 8 | func TestDecodeUploadPolicy(t *testing.T) { 9 | asserts := assert.New(t) 10 | 11 | testCases := []struct { 12 | input string 13 | expectError bool 14 | expectNil bool 15 | expectRes *UploadPolicy 16 | }{ 17 | { 18 | "错误的base64字符", // not valid base64 19 | true, 20 | true, 21 | &UploadPolicy{}, 22 | }, 23 | { 24 | "6ZSZ6K+v55qESlNPTuWtl+espg==", // valid base64, but the decoded payload is not JSON 25 | true, 26 | true, 27 | &UploadPolicy{}, 28 | }, 29 | { 30 | "e30=", // base64 of {} 31 | false, 32 | false, 33 | &UploadPolicy{}, 34 | }, 35 | { 36 | "eyJjYWxsYmFja191cmwiOiJ0ZXN0In0=", // base64 of {"callback_url":"test"} 37 | false, 38 | false, 39 | &UploadPolicy{CallbackURL: "test"}, 40 | }, 41 | } 42 | 43 | for _, testCase := range testCases { 44 | res, err := DecodeUploadPolicy(testCase.input) 45 | if testCase.expectError { 46 | asserts.Error(err) 47 | } 48 | if testCase.expectNil { 49 | asserts.Nil(res) 50 | } 51 | if !testCase.expectNil { 52 | asserts.Equal(testCase.expectRes, res) 53 | } 54 | } 55 | } 56 | 57 | func TestUploadPolicy_EncodeUploadPolicy(t *testing.T) { 58 | asserts := assert.New(t) 59 | testPolicy := UploadPolicy{} 60 | res, err := testPolicy.EncodeUploadPolicy() 61 | asserts.NoError(err) 62 | asserts.NotEmpty(res) 63 | } 64 | -------------------------------------------------------------------------------- /pkg/task/decompress.go:
-------------------------------------------------------------------------------- 1 | package task 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | 7 | model "github.com/cloudreve/Cloudreve/v3/models" 8 | "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" 9 | ) 10 | 11 | // DecompressTask is a file decompression task. 12 | type DecompressTask struct { 13 | User *model.User 14 | TaskModel *model.Task 15 | TaskProps DecompressProps 16 | Err *JobError 17 | 18 | zipPath string // NOTE(review): not referenced in this file; confirm it is still needed 19 | } 20 | 21 | // DecompressProps holds the properties of a decompression task. 22 | type DecompressProps struct { 23 | Src string `json:"src"` 24 | Dst string `json:"dst"` 25 | } 26 | 27 | // Props returns the task properties serialized as JSON. 28 | func (job *DecompressTask) Props() string { 29 | res, _ := json.Marshal(job.TaskProps) // error ignored: DecompressProps has only string fields 30 | return string(res) 31 | } 32 | 33 | // Type returns the task type. 34 | func (job *DecompressTask) Type() int { 35 | return DecompressTaskType 36 | } 37 | 38 | // Creator returns the ID of the user who created the task. 39 | func (job *DecompressTask) Creator() uint { 40 | return job.User.ID 41 | } 42 | 43 | // Model returns the task's database model. 44 | func (job *DecompressTask) Model() *model.Task { 45 | return job.TaskModel 46 | } 47 | 48 | // SetStatus sets the task status on its database model. 49 | func (job *DecompressTask) SetStatus(status int) { 50 | job.TaskModel.SetStatus(status) 51 | } 52 | 53 | // SetError records the failure information on the task and, as JSON, on its database model. 54 | func (job *DecompressTask) SetError(err *JobError) { 55 | job.Err = err 56 | res, _ := json.Marshal(job.Err) 57 | job.TaskModel.SetError(string(res)) 58 | } 59 | 60 | // SetErrorMsg records a failure message plus, when err is non-nil, the error's text. 61 | func (job *DecompressTask) SetErrorMsg(msg string, err error) { 62 | jobErr := &JobError{Msg: msg} 63 | if err != nil { 64 | jobErr.Error = err.Error() 65 | } 66 | job.SetError(jobErr) 67 | } 68 | 69 | // GetError returns the recorded failure information, or nil if there is none. 70 | func (job *DecompressTask) GetError() *JobError { 71 | return job.Err 72 | } 73 | 74 | // Do decompresses TaskProps.Src into TaskProps.Dst using the user's file system. 75 | func (job *DecompressTask) Do() { 76 | // create a file system for the task's user 77 | fs, err := filesystem.NewFileSystem(job.User) 78 | if err != nil { 79 | job.SetErrorMsg("无法创建文件系统", err) 80 | return 81 | } 82 | 83 |
job.TaskModel.SetProgress(DecompressingProgress) // record that decompression has started 84 | err = fs.Decompress(context.Background(), job.TaskProps.Src, job.TaskProps.Dst) 85 | if err != nil { 86 | job.SetErrorMsg("解压缩失败", err) 87 | return 88 | } 89 | 90 | } 91 | 92 | // NewDecompressTask creates a decompression task for user and records it in the database as a queued task. 93 | func NewDecompressTask(user *model.User, src, dst string) (Job, error) { 94 | newTask := &DecompressTask{ 95 | User: user, 96 | TaskProps: DecompressProps{ 97 | Src: src, 98 | Dst: dst, 99 | }, 100 | } 101 | 102 | record, err := Record(newTask) 103 | if err != nil { 104 | return nil, err 105 | } 106 | newTask.TaskModel = record 107 | 108 | return newTask, nil 109 | } 110 | 111 | // NewDecompressTaskFromModel restores a decompression task from a database record. 112 | func NewDecompressTaskFromModel(task *model.Task) (Job, error) { 113 | user, err := model.GetActiveUserByID(task.UserID) 114 | if err != nil { 115 | return nil, err 116 | } 117 | newTask := &DecompressTask{ 118 | User: &user, 119 | TaskModel: task, 120 | } 121 | 122 | err = json.Unmarshal([]byte(task.Props), &newTask.TaskProps) 123 | if err != nil { 124 | return nil, err 125 | } 126 | 127 | return newTask, nil 128 | } 129 | -------------------------------------------------------------------------------- /pkg/task/errors.go: -------------------------------------------------------------------------------- 1 | package task 2 | 3 | import "errors" 4 | 5 | var ( 6 | // ErrUnknownTaskType is returned when a task model has an unknown task type. 7 | ErrUnknownTaskType = errors.New("未知任务类型") 8 | ) 9 | -------------------------------------------------------------------------------- /pkg/task/job.go: -------------------------------------------------------------------------------- 1 | package task 2 | 3 | import ( 4 | model "github.com/cloudreve/Cloudreve/v3/models" 5 | "github.com/cloudreve/Cloudreve/v3/pkg/util" 6 | ) 7 | 8 | // Task types 9 | const ( 10 | // CompressTaskType is a compression task 11 | CompressTaskType = iota 12 | // DecompressTaskType is a decompression task 13 | DecompressTaskType 14 | // TransferTaskType is a transfer (relay) task 15 | TransferTaskType 16 | // ImportTaskType is an import task 17
| ImportTaskType 18 | ) 19 | 20 | // Task statuses 21 | const ( 22 | // Queued means the task is waiting in the queue 23 | Queued = iota 24 | // Processing means the task is running 25 | Processing 26 | // Error means the task failed 27 | Error 28 | // Canceled means the task was canceled 29 | Canceled 30 | // Complete means the task finished successfully 31 | Complete 32 | ) 33 | 34 | // Task progress stages 35 | const ( 36 | // PendingProgress waiting 37 | PendingProgress = iota 38 | // CompressingProgress compressing 39 | CompressingProgress 40 | // DecompressingProgress decompressing 41 | DecompressingProgress 42 | // DownloadingProgress downloading 43 | DownloadingProgress 44 | // TransferringProgress transferring 45 | TransferringProgress 46 | // ListingProgress listing (indexing) 47 | ListingProgress 48 | // InsertingProgress inserting 49 | InsertingProgress 50 | ) 51 | 52 | // Job is the interface every task type implements. 53 | type Job interface { 54 | Type() int // returns the task type 55 | Creator() uint // returns the creator's user ID 56 | Props() string // returns the serialized task properties 57 | Model() *model.Task // returns the corresponding database model 58 | SetStatus(int) // sets the task status 59 | Do() // runs the task 60 | SetError(*JobError) // records failure information 61 | GetError() *JobError // returns the failure information; nil means the task completed successfully 62 | } 63 | 64 | // JobError describes why a task failed. 65 | type JobError struct { 66 | Msg string `json:"msg,omitempty"` 67 | Error string `json:"error,omitempty"` 68 | } 69 | 70 | // Record inserts a queued database record for job and returns it; the record is returned even when the insert fails. 71 | func Record(job Job) (*model.Task, error) { 72 | record := model.Task{ 73 | Status: Queued, 74 | Type: job.Type(), 75 | UserID: job.Creator(), 76 | Progress: 0, 77 | Error: "", 78 | Props: job.Props(), 79 | } 80 | _, err := record.Create() 81 | return &record, err 82 | } 83 | 84 | // Resume reloads queued and processing tasks from the database and resubmits them to TaskPoll. 85 | func Resume() { 86 | tasks := model.GetTasksByStatus(Queued, Processing) 87 | if len(tasks) == 0 { 88 | return 89 | } 90 | util.Log().Info("从数据库中恢复 %d 个未完成任务", len(tasks)) 91 | 92 | for i := 0; i < len(tasks); i++ { // index loop so &tasks[i] points into the slice, not at a loop copy 93 | job, err := GetJobFromModel(&tasks[i]) 94 | if err != nil { 95 | util.Log().Warning("无法恢复任务,%s", err) 96 | continue 97 | } 98 | 99 | TaskPoll.Submit(job) 100 | } 101 | } 102 | 103 | // GetJobFromModel rebuilds a Job from a database task model according to its type. 104 | func GetJobFromModel(task *model.Task) (Job, error) { 105 | switch task.Type { 106 | case CompressTaskType:
107 | return NewCompressTaskFromModel(task) 108 | case DecompressTaskType: 109 | return NewDecompressTaskFromModel(task) 110 | case TransferTaskType: 111 | return NewTransferTaskFromModel(task) 112 | case ImportTaskType: 113 | return NewImportTaskFromModel(task) 114 | default: 115 | return nil, ErrUnknownTaskType 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /pkg/task/job_test.go: -------------------------------------------------------------------------------- 1 | package task 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | 7 | "github.com/DATA-DOG/go-sqlmock" 8 | model "github.com/cloudreve/Cloudreve/v3/models" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestRecord(t *testing.T) { 13 | asserts := assert.New(t) 14 | job := &TransferTask{ 15 | User: &model.User{Policy: model.Policy{Type: "unknown"}}, 16 | } 17 | mock.ExpectBegin() 18 | mock.ExpectExec("INSERT(.+)").WillReturnResult(sqlmock.NewResult(1, 1)) 19 | mock.ExpectCommit() 20 | _, err := Record(job) 21 | asserts.NoError(err) 22 | } 23 | 24 | func TestResume(t *testing.T) { 25 | asserts := assert.New(t) 26 | 27 | // no unfinished tasks in the database 28 | { 29 | mock.ExpectQuery("SELECT(.+)").WithArgs(Queued).WillReturnRows(sqlmock.NewRows([]string{"type"})) 30 | Resume() 31 | asserts.NoError(mock.ExpectationsWereMet()) 32 | } 33 | } 34 | 35 | func TestGetJobFromModel(t *testing.T) { 36 | asserts := assert.New(t) 37 | 38 | // CompressTaskType 39 | { 40 | task := &model.Task{ 41 | Status: 0, 42 | Type: CompressTaskType, 43 | } 44 | mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error")) 45 | job, err := GetJobFromModel(task) 46 | asserts.NoError(mock.ExpectationsWereMet()) 47 | asserts.Nil(job) 48 | asserts.Error(err) 49 | } 50 | // DecompressTaskType 51 | { 52 | task := &model.Task{ 53 | Status: 0, 54 | Type: DecompressTaskType, 55 | } 56 | mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error")) 57 | job, err :=
GetJobFromModel(task) 58 | asserts.NoError(mock.ExpectationsWereMet()) 59 | asserts.Nil(job) 60 | asserts.Error(err) 61 | } 62 | // TransferTaskType 63 | { 64 | task := &model.Task{ 65 | Status: 0, 66 | Type: TransferTaskType, 67 | } 68 | mock.ExpectQuery("SELECT(.+)users(.+)").WillReturnError(errors.New("error")) 69 | job, err := GetJobFromModel(task) 70 | asserts.NoError(mock.ExpectationsWereMet()) 71 | asserts.Nil(job) 72 | asserts.Error(err) 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /pkg/task/pool.go: -------------------------------------------------------------------------------- 1 | package task 2 | 3 | import ( 4 | model "github.com/cloudreve/Cloudreve/v3/models" 5 | "github.com/cloudreve/Cloudreve/v3/pkg/util" 6 | ) 7 | 8 | // TaskPoll is the task pool used by the application; it is created by Init. 9 | var TaskPoll *Pool 10 | 11 | // Pool is a task pool that limits how many jobs run concurrently. 12 | type Pool struct { 13 | // capacity: each buffered token represents one idle worker slot 14 | idleWorker chan int 15 | } 16 | 17 | // Add makes num more worker slots available. 18 | func (pool *Pool) Add(num int) { 19 | for i := 0; i < num; i++ { 20 | pool.idleWorker <- 1 // blocks while the channel buffer is full 21 | } 22 | } 23 | 24 | // ObtainWorker blocks until a worker slot is free and returns a new Worker. 25 | func (pool *Pool) ObtainWorker() Worker { 26 | // Receiving a token claims one worker slot; Submit hands the 27 | // slot back through FreeWorker once the job has run. 28 | <-pool.idleWorker 29 | // a slot is free: hand out a new worker 30 | return &GeneralWorker{} 31 | } 32 | 33 | // FreeWorker returns one worker slot to the pool. 34 | func (pool *Pool) FreeWorker() { 35 | pool.Add(1) 36 | } 37 | 38 | // Submit runs job in a new goroutine as soon as a worker slot is available. 39 | func (pool *Pool) Submit(job Job) { 40 | go func() { 41 | util.Log().Debug("等待获取Worker") 42 | worker := pool.ObtainWorker() 43 | util.Log().Debug("获取到Worker") 44 | worker.Do(job) 45 | util.Log().Debug("释放Worker") 46 | pool.FreeWorker() 47 | }() 48 | } 49 | 50 | // Init creates TaskPoll sized by the max_worker_num setting (10 is the fallback passed to GetIntSetting) and resumes unfinished tasks. 51 | func Init() { 52 | maxWorker := model.GetIntSetting("max_worker_num", 10) // NOTE(review): a negative value makes make() panic and 0 leaves no workers 53 | TaskPoll = &Pool{ 54 | idleWorker: make(chan int, maxWorker), 55 | } 56 | TaskPoll.Add(maxWorker) 57 | util.Log().Info("初始化任务队列,WorkerNum = %d", maxWorker) 58 | 59 | Resume() 60 | } 61 |
-------------------------------------------------------------------------------- /pkg/task/pool_test.go: -------------------------------------------------------------------------------- 1 | package task 2 | 3 | import ( 4 | "database/sql" 5 | "testing" 6 | 7 | "github.com/DATA-DOG/go-sqlmock" 8 | model "github.com/cloudreve/Cloudreve/v3/models" 9 | "github.com/cloudreve/Cloudreve/v3/pkg/cache" 10 | "github.com/jinzhu/gorm" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | var mock sqlmock.Sqlmock 15 | 16 | // TestMain replaces model.DB with a sqlmock-backed database for this package's tests. 17 | func TestMain(m *testing.M) { 18 | var db *sql.DB 19 | var err error 20 | db, mock, err = sqlmock.New() 21 | if err != nil { 22 | panic("An error was not expected when opening a stub database connection") 23 | } 24 | model.DB, _ = gorm.Open("mysql", db) 25 | defer db.Close() 26 | m.Run() 27 | } 28 | 29 | func TestInit(t *testing.T) { 30 | asserts := assert.New(t) 31 | cache.Set("setting_max_worker_num", "10", 0) 32 | mock.ExpectQuery("SELECT(.+)").WithArgs(Queued).WillReturnRows(sqlmock.NewRows([]string{"type"}).AddRow(-1)) 33 | Init() 34 | asserts.NoError(mock.ExpectationsWereMet()) 35 | asserts.Len(TaskPoll.idleWorker, 10) 36 | } 37 | 38 | func TestPool_Submit(t *testing.T) { 39 | asserts := assert.New(t) 40 | pool := &Pool{ 41 | idleWorker: make(chan int, 1), 42 | } 43 | pool.Add(1) 44 | job := &MockJob{ 45 | DoFunc: func() { 46 | 47 | }, 48 | } 49 | asserts.NotPanics(func() { 50 | pool.Submit(job) 51 | }) 52 | } 53 | -------------------------------------------------------------------------------- /pkg/task/worker.go: -------------------------------------------------------------------------------- 1 | package task 2 | 3 | import "github.com/cloudreve/Cloudreve/v3/pkg/util" 4 | 5 | // Worker executes jobs. 6 | type Worker interface { 7 | Do(Job) // runs the job 8 | } 9 | 10 | // GeneralWorker is the default Worker implementation. 11 | type GeneralWorker struct { 12 | } 13 | 14 | // Do runs job and sets its final status to Complete or Error; a panic inside job.Do is recovered and marks the job as failed. 15 | func (worker *GeneralWorker) Do(job Job) { 16 | util.Log().Debug("开始执行任务")
17 | job.SetStatus(Processing) 18 | 19 | defer func() { 20 | // recover from a panic in job.Do, log it visibly and mark the job as failed 21 | if err := recover(); err != nil { 22 | util.Log().Warning("任务执行出错,%v", err) 23 | job.SetError(&JobError{Msg: "致命错误"}) 24 | job.SetStatus(Error) 25 | } 26 | }() 27 | 28 | // run the job 29 | job.Do() 30 | 31 | // the job reported a failure 32 | if err := job.GetError(); err != nil { 33 | util.Log().Debug("任务执行出错") 34 | job.SetStatus(Error) 35 | return 36 | } 37 | 38 | util.Log().Debug("任务执行完成") 39 | // finished successfully 40 | job.SetStatus(Complete) 41 | } 42 | -------------------------------------------------------------------------------- /pkg/task/worker_test.go: -------------------------------------------------------------------------------- 1 | package task 2 | 3 | import ( 4 | "testing" 5 | 6 | model "github.com/cloudreve/Cloudreve/v3/models" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | type MockJob struct { 11 | Err *JobError 12 | Status int 13 | DoFunc func() 14 | } 15 | 16 | func (job *MockJob) Type() int { 17 | panic("implement me") 18 | } 19 | 20 | func (job *MockJob) Creator() uint { 21 | panic("implement me") 22 | } 23 | 24 | func (job *MockJob) Props() string { 25 | panic("implement me") 26 | } 27 | 28 | func (job *MockJob) Model() *model.Task { 29 | panic("implement me") 30 | } 31 | 32 | func (job *MockJob) SetStatus(status int) { 33 | job.Status = status 34 | } 35 | 36 | func (job *MockJob) Do() { 37 | job.DoFunc() 38 | } 39 | 40 | func (job *MockJob) SetError(*JobError) { 41 | } 42 | 43 | func (job *MockJob) GetError() *JobError { 44 | return job.Err 45 | } 46 | 47 | func TestGeneralWorker_Do(t *testing.T) { 48 | asserts := assert.New(t) 49 | worker := &GeneralWorker{} 50 | job := &MockJob{} 51 | 52 | // normal run 53 | { 54 | job.DoFunc = func() { 55 | } 56 | worker.Do(job) 57 | asserts.Equal(Complete, job.Status) 58 | } 59 | 60 | // the job reports an error 61 | { 62 | job.DoFunc = func() { 63 | } 64 | job.Status = Queued 65 | job.Err = &JobError{Msg: "error"} 66 | worker.Do(job) 67 | asserts.Equal(Error, job.Status) 68 | } 69 |
70 | // the job panics 71 | { 72 | job.DoFunc = func() { 73 | panic("mock fatal error") 74 | } 75 | job.Status = Queued 76 | job.Err = nil 77 | worker.Do(job) 78 | asserts.Equal(Error, job.Status) 79 | } 80 | 81 | } 82 | -------------------------------------------------------------------------------- /pkg/thumb/image.go: -------------------------------------------------------------------------------- 1 | package thumb 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "image" 7 | "image/gif" 8 | "image/jpeg" 9 | "image/png" 10 | "io" 11 | "path/filepath" 12 | "strings" 13 | 14 | model "github.com/cloudreve/Cloudreve/v3/models" 15 | "github.com/cloudreve/Cloudreve/v3/pkg/util" 16 | 17 | "github.com/nfnt/resize" 18 | ) 19 | 20 | // Thumb holds a decoded image from which thumbnails are generated. 21 | type Thumb struct { 22 | src image.Image 23 | ext string 24 | } 25 | 26 | // NewThumbFromFile decodes file into a new Thumb, 27 | // choosing the decoder (jpg/jpeg, gif, png) from the extension of name. 28 | func NewThumbFromFile(file io.Reader, name string) (*Thumb, error) { 29 | ext := strings.ToLower(filepath.Ext(name)) 30 | // no extension 31 | if len(ext) == 0 { 32 | return nil, errors.New("未知的图像类型") 33 | } 34 | 35 | var err error 36 | var img image.Image 37 | switch ext[1:] { 38 | case "jpg": 39 | img, err = jpeg.Decode(file) 40 | case "jpeg": 41 | img, err = jpeg.Decode(file) 42 | case "gif": 43 | img, err = gif.Decode(file) 44 | case "png": 45 | img, err = png.Decode(file) 46 | default: 47 | return nil, errors.New("未知的图像类型") 48 | } 49 | if err != nil { 50 | return nil, err 51 | } 52 | 53 | return &Thumb{ 54 | src: img, 55 | ext: ext[1:], 56 | }, nil 57 | } 58 | 59 | // GetThumb shrinks the image to fit within the given maximum width and height (via resize.Thumbnail). 60 | func (image *Thumb) GetThumb(width, height uint) { 61 | image.src = resize.Thumbnail(width, height, image.src, resize.Lanczos3) 62 | } 63 | 64 | // GetSize returns the image width and height; using the bounds' size keeps this correct when the bounds do not start at the origin. 65 | func (image *Thumb) GetSize() (int, int) { 66 | b := image.src.Bounds() 67 | return b.Dx(), b.Dy() 68 | } 69 | 70 | // Save encodes the image as PNG and writes it to path, creating missing parent directories. 71 | func (image *Thumb) Save(path string) (err error) { 72 | out, err := util.CreatNestedFile(path) 73 | 74 | if
err != nil { 75 | return err 76 | } 77 | if err = png.Encode(out, image.src); err != nil { 78 | _ = out.Close() // keep the encode error; it is the more useful one 79 | return err 80 | } 81 | return out.Close() // report Close errors on the written file instead of dropping them 82 | } 83 | 84 | // CreateAvatar writes avatar_<uid>_0/1/2.png (small, medium, large) into the directory named by the avatar_path setting. 85 | func (image *Thumb) CreateAvatar(uid uint) error { 86 | // read avatar-related settings 87 | savePath := util.RelativePath(model.GetSettingByName("avatar_path")) 88 | s := model.GetIntSetting("avatar_size_s", 50) 89 | m := model.GetIntSetting("avatar_size_m", 130) 90 | l := model.GetIntSetting("avatar_size_l", 200) 91 | 92 | // resize each size from the original image; image.src is left holding the last (large) avatar 93 | src := image.src 94 | for k, size := range []int{s, m, l} { 95 | image.src = resize.Resize(uint(size), uint(size), src, resize.Lanczos3) 96 | err := image.Save(filepath.Join(savePath, fmt.Sprintf("avatar_%d_%d.png", uid, k))) 97 | if err != nil { 98 | return err 99 | } 100 | } 101 | 102 | return nil 103 | 104 | } 105 | -------------------------------------------------------------------------------- /pkg/thumb/image_test.go: -------------------------------------------------------------------------------- 1 | package thumb 2 | 3 | import ( 4 | "fmt" 5 | "image" 6 | "image/jpeg" 7 | "os" 8 | "testing" 9 | 10 | "github.com/cloudreve/Cloudreve/v3/pkg/cache" 11 | "github.com/cloudreve/Cloudreve/v3/pkg/util" 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | func CreateTestImage() *os.File { 16 | file, err := os.Create("TestNewThumbFromFile.jpeg") 17 | alpha := image.NewAlpha(image.Rect(0, 0, 500, 200)) 18 | jpeg.Encode(file, alpha, nil) 19 | if err != nil { 20 | fmt.Println(err) 21 | } 22 | _, _ = file.Seek(0, 0) 23 | return file 24 | } 25 | 26 | func TestNewThumbFromFile(t *testing.T) { 27 | asserts := assert.New(t) 28 | file := CreateTestImage() 29 | defer file.Close() 30 | 31 | // no extension 32 | { 33 | thumb, err := NewThumbFromFile(file, "123") 34 | asserts.Error(err) 35 | asserts.Nil(thumb) 36 | } 37 | 38 | { 39 | thumb, err := NewThumbFromFile(file, "123.jpg") 40 | asserts.NoError(err) 41 | asserts.NotNil(thumb) 42 | } 43 | { 44 | thumb, err := NewThumbFromFile(file, "123.jpeg") // presumably fails because the .jpg decode above already consumed the reader 45
| asserts.Error(err) 46 | asserts.Nil(thumb) 47 | } 48 | { 49 | thumb, err := NewThumbFromFile(file, "123.png") 50 | asserts.Error(err) 51 | asserts.Nil(thumb) 52 | } 53 | { 54 | thumb, err := NewThumbFromFile(file, "123.gif") 55 | asserts.Error(err) 56 | asserts.Nil(thumb) 57 | } 58 | { 59 | thumb, err := NewThumbFromFile(file, "123.3211") 60 | asserts.Error(err) 61 | asserts.Nil(thumb) 62 | } 63 | } 64 | 65 | func TestThumb_GetSize(t *testing.T) { 66 | asserts := assert.New(t) 67 | file := CreateTestImage() 68 | defer file.Close() 69 | thumb, err := NewThumbFromFile(file, "123.jpg") 70 | asserts.NoError(err) 71 | 72 | w, h := thumb.GetSize() 73 | asserts.Equal(500, w) 74 | asserts.Equal(200, h) 75 | } 76 | 77 | func TestThumb_GetThumb(t *testing.T) { 78 | asserts := assert.New(t) 79 | file := CreateTestImage() 80 | defer file.Close() 81 | thumb, err := NewThumbFromFile(file, "123.jpg") 82 | asserts.NoError(err) 83 | 84 | asserts.NotPanics(func() { 85 | thumb.GetThumb(10, 10) 86 | }) 87 | } 88 | 89 | func TestThumb_Save(t *testing.T) { 90 | asserts := assert.New(t) 91 | file := CreateTestImage() 92 | defer file.Close() 93 | thumb, err := NewThumbFromFile(file, "123.jpg") 94 | asserts.NoError(err) 95 | 96 | err = thumb.Save("/:noteexist/") 97 | asserts.Error(err) 98 | 99 | err = thumb.Save("TestThumb_Save.png") 100 | asserts.NoError(err) 101 | asserts.True(util.Exists("TestThumb_Save.png")) 102 | 103 | } 104 | 105 | func TestThumb_CreateAvatar(t *testing.T) { 106 | asserts := assert.New(t) 107 | file := CreateTestImage() 108 | defer file.Close() 109 | 110 | thumb, err := NewThumbFromFile(file, "123.jpg") 111 | asserts.NoError(err) 112 | 113 | cache.Set("setting_avatar_path", "tests", 0) 114 | cache.Set("setting_avatar_size_s", "50", 0) 115 | cache.Set("setting_avatar_size_m", "130", 0) 116 | cache.Set("setting_avatar_size_l", "200", 0) 117 | 118 | asserts.NoError(thumb.CreateAvatar(1)) 119 | asserts.True(util.Exists(util.RelativePath("tests/avatar_1_1.png"))) 120 |
asserts.True(util.Exists(util.RelativePath("tests/avatar_1_2.png"))) 121 | asserts.True(util.Exists(util.RelativePath("tests/avatar_1_0.png"))) 122 | } 123 | -------------------------------------------------------------------------------- /pkg/util/common.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "math/rand" 5 | "regexp" 6 | "strings" 7 | "time" 8 | ) 9 | 10 | func init() { 11 | rand.Seed(time.Now().UnixNano()) 12 | } 13 | 14 | // RandStringRunes returns a random alphanumeric string of n characters; it uses math/rand, so it is not cryptographically secure. 15 | func RandStringRunes(n int) string { 16 | var letterRunes = []rune("1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") 17 | 18 | b := make([]rune, n) 19 | for i := range b { 20 | b[i] = letterRunes[rand.Intn(len(letterRunes))] 21 | } 22 | return string(b) 23 | } 24 | 25 | // ContainsUint reports whether s contains e. 26 | func ContainsUint(s []uint, e uint) bool { 27 | for _, a := range s { 28 | if a == e { 29 | return true 30 | } 31 | } 32 | return false 33 | } 34 | 35 | // ContainsString reports whether s contains e. 36 | func ContainsString(s []string, e string) bool { 37 | for _, a := range s { 38 | if a == e { 39 | return true 40 | } 41 | } 42 | return false 43 | } 44 | 45 | // Replace applies every key-to-value replacement in table to s; map iteration order is random, so overlapping keys can give varying results. 46 | func Replace(table map[string]string, s string) string { 47 | for key, value := range table { 48 | s = strings.Replace(s, key, value, -1) 49 | } 50 | return s 51 | } 52 | 53 | // BuildRegexp builds a multi-condition regexp for SQL queries: each regexp-quoted search term is wrapped in prefix and suffix, and the terms are joined with condition. 54 | func BuildRegexp(search []string, prefix, suffix, condition string) string { 55 | var res string 56 | for key, value := range search { 57 | res += prefix + regexp.QuoteMeta(value) + suffix 58 | if key < len(search)-1 { 59 | res += condition 60 | } 61 | } 62 | return res 63 | } 64 | 65 | // BuildConcat builds a string-concatenation SQL expression for the given database type (CONCAT for mysql, || otherwise). 66 | func BuildConcat(str1, str2 string, DBType string) string { 67 | switch DBType { 68 | case "mysql": 69 | return "CONCAT(" + str1 + "," + str2 + ")" 70 | default: 71 | return str1 + "||" + str2 72 | } 73 | } 74 | 75 | //
SliceIntersect returns the elements of slice2 that also occur in slice1 (duplicates within slice2 are kept). 76 | func SliceIntersect(slice1, slice2 []string) []string { 77 | m := make(map[string]int) 78 | nn := make([]string, 0) 79 | for _, v := range slice1 { 80 | m[v]++ 81 | } 82 | 83 | for _, v := range slice2 { 84 | // v is shared if it occurs in slice1 at least once, even when slice1 repeats it 85 | if m[v] > 0 { 86 | nn = append(nn, v) 87 | } 88 | } 89 | return nn 90 | } 91 | 92 | // SliceDifference returns the elements of slice1 that do not occur in slice2. 93 | func SliceDifference(slice1, slice2 []string) []string { 94 | m := make(map[string]int) 95 | nn := make([]string, 0) 96 | inter := SliceIntersect(slice1, slice2) 97 | for _, v := range inter { 98 | m[v]++ 99 | } 100 | 101 | for _, value := range slice1 { 102 | // keep values that are absent from the intersection 103 | if m[value] == 0 { 104 | nn = append(nn, value) 105 | } 106 | } 107 | return nn 108 | } 109 | -------------------------------------------------------------------------------- /pkg/util/common_test.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 5 | "testing" 6 | ) 7 | 8 | func TestRandStringRunes(t *testing.T) { 9 | asserts := assert.New(t) 10 | 11 | // zero-length string 12 | randStr := RandStringRunes(0) 13 | asserts.Len(randStr, 0) 14 | 15 | // length 16 16 | randStr = RandStringRunes(16) 17 | asserts.Len(randStr, 16) 18 | 19 | // length 32 20 | randStr = RandStringRunes(32) 21 | asserts.Len(randStr, 32) 22 | 23 | // two random strings of the same length should differ 24 | sameLenStr1 := RandStringRunes(32) 25 | sameLenStr2 := RandStringRunes(32) 26 | asserts.NotEqual(sameLenStr1, sameLenStr2) 27 | } 28 | 29 | func TestContainsUint(t *testing.T) { 30 | asserts := assert.New(t) 31 | asserts.True(ContainsUint([]uint{0, 2, 3, 65, 4}, 65)) 32 | asserts.True(ContainsUint([]uint{65}, 65)) 33 | asserts.False(ContainsUint([]uint{65}, 6)) 34 | } 35 | 36 | func TestContainsString(t *testing.T) { 37 | asserts := assert.New(t) 38 | asserts.True(ContainsString([]string{"", "1"}, "")) 39 | asserts.True(ContainsString([]string{"", "1"}, "1")) 40 |
asserts.False(ContainsString([]string{"", "1"}, " ")) 41 | } 42 | 43 | func TestReplace(t *testing.T) { 44 | asserts := assert.New(t) 45 | 46 | asserts.Equal("origin", Replace(map[string]string{ 47 | "123": "321", 48 | }, "origin")) 49 | 50 | asserts.Equal("321origin321", Replace(map[string]string{ 51 | "123": "321", 52 | }, "123origin123")) 53 | asserts.Equal("321new321", Replace(map[string]string{ 54 | "123": "321", 55 | "origin": "new", 56 | }, "123origin123")) 57 | } 58 | 59 | func TestBuildRegexp(t *testing.T) { 60 | asserts := assert.New(t) 61 | 62 | asserts.Equal("^/dir/", BuildRegexp([]string{"/dir"}, "^", "/", "|")) 63 | asserts.Equal("^/dir/|^/dir/di\\*r/", BuildRegexp([]string{"/dir", "/dir/di*r"}, "^", "/", "|")) 64 | } 65 | 66 | func TestBuildConcat(t *testing.T) { 67 | asserts := assert.New(t) 68 | asserts.Equal("CONCAT(1,2)", BuildConcat("1", "2", "mysql")) 69 | asserts.Equal("1||2", BuildConcat("1", "2", "sqlite3")) 70 | } 71 | 72 | func TestSliceDifference(t *testing.T) { 73 | asserts := assert.New(t) 74 | 75 | { 76 | s1 := []string{"1", "2", "3", "4"} 77 | s2 := []string{"2", "4"} 78 | asserts.Equal([]string{"1", "3"}, SliceDifference(s1, s2)) 79 | } 80 | 81 | { 82 | s2 := []string{"1", "2", "3", "4"} 83 | s1 := []string{"2", "4"} 84 | asserts.Equal([]string{}, SliceDifference(s1, s2)) 85 | } 86 | 87 | { 88 | s1 := []string{"1", "2", "3", "4"} 89 | s2 := []string{"1", "2", "3", "4"} 90 | asserts.Equal([]string{}, SliceDifference(s1, s2)) 91 | } 92 | 93 | { 94 | s1 := []string{"1", "2", "3", "4"} 95 | s2 := []string{} 96 | asserts.Equal([]string{"1", "2", "3", "4"}, SliceDifference(s1, s2)) 97 | } 98 | 99 | // a value repeated in s1 is still removed when s2 contains it once 100 | { 101 | s1 := []string{"1", "1", "2"} 102 | s2 := []string{"1"} 103 | asserts.Equal([]string{"2"}, SliceDifference(s1, s2)) 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /pkg/util/io.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "io" 5 | "os" 6 | "path/filepath" 7 | ) 8 | 9 | // Exists reports whether the named file or directory exists.
10 | func Exists(name string) bool { 11 | if _, err := os.Stat(name); err != nil { 12 | if os.IsNotExist(err) { // only a not-exist error reports false; other Stat errors (e.g. permission denied) fall through to true 13 | return false 14 | } 15 | } 16 | return true 17 | } 18 | 19 | // CreatNestedFile creates (or truncates, via os.Create) the file at path, first creating any missing parent directories with mode 0700. 20 | func CreatNestedFile(path string) (*os.File, error) { 21 | basePath := filepath.Dir(path) 22 | if !Exists(basePath) { 23 | err := os.MkdirAll(basePath, 0700) 24 | if err != nil { 25 | Log().Warning("无法创建目录,%s", err) 26 | return nil, err 27 | } 28 | } 29 | 30 | return os.Create(path) 31 | } 32 | 33 | // IsEmpty reports whether the directory name contains no entries. 34 | func IsEmpty(name string) (bool, error) { 35 | f, err := os.Open(name) 36 | if err != nil { 37 | return false, err 38 | } 39 | defer f.Close() 40 | 41 | _, err = f.Readdirnames(1) // Or f.Readdir(1) 42 | if err == io.EOF { 43 | return true, nil 44 | } 45 | return false, err // Either not empty or error, suits both cases 46 | } 47 | -------------------------------------------------------------------------------- /pkg/util/io_test.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 5 | "testing" 6 | ) 7 | 8 | func TestExists(t *testing.T) { 9 | asserts := assert.New(t) 10 | asserts.True(Exists("io_test.go")) 11 | asserts.False(Exists("io_test.js")) 12 | } 13 | 14 | func TestCreatNestedFile(t *testing.T) { 15 | asserts := assert.New(t) 16 | 17 | // parent directory does not exist yet 18 | { 19 | file, err := CreatNestedFile("test/nest.txt") 20 | asserts.NoError(err) 21 | asserts.NoError(file.Close()) 22 | asserts.FileExists("test/nest.txt") 23 | } 24 | 25 | // parent directory already exists 26 | { 27 | file, err := CreatNestedFile("test/direct.txt") 28 | asserts.NoError(err) 29 | asserts.NoError(file.Close()) 30 | asserts.FileExists("test/direct.txt") 31 | } 32 | } 33 | 34 | func TestIsEmpty(t *testing.T) { 35 | asserts := assert.New(t) 36 | 37 | asserts.False(IsEmpty("")) 38 | asserts.False(IsEmpty("not_exist")) 39 | } 40 |
-------------------------------------------------------------------------------- /pkg/util/logger_test.go: -------------------------------------------------------------------------------- 1 | // +build !race 2 | 3 | package util 4 | 5 | import ( 6 | "github.com/stretchr/testify/assert" 7 | "testing" 8 | ) 9 | 10 | func TestBuildLogger(t *testing.T) { 11 | asserts := assert.New(t) 12 | asserts.NotPanics(func() { 13 | BuildLogger("error") 14 | }) 15 | asserts.NotPanics(func() { 16 | BuildLogger("warning") 17 | }) 18 | asserts.NotPanics(func() { 19 | BuildLogger("info") 20 | }) 21 | asserts.NotPanics(func() { 22 | BuildLogger("?") 23 | }) 24 | asserts.NotPanics(func() { 25 | BuildLogger("debug") 26 | }) 27 | } 28 | 29 | func TestLog(t *testing.T) { 30 | asserts := assert.New(t) 31 | asserts.NotNil(Log()) 32 | GloablLogger = nil 33 | asserts.NotNil(Log()) 34 | } 35 | 36 | func TestLogger_Debug(t *testing.T) { 37 | asserts := assert.New(t) 38 | l := Logger{ 39 | level: LevelDebug, 40 | } 41 | asserts.NotPanics(func() { 42 | l.Debug("123") 43 | }) 44 | l.level = LevelError 45 | asserts.NotPanics(func() { 46 | l.Debug("123") 47 | }) 48 | } 49 | 50 | func TestLogger_Info(t *testing.T) { 51 | asserts := assert.New(t) 52 | l := Logger{ 53 | level: LevelDebug, 54 | } 55 | asserts.NotPanics(func() { 56 | l.Info("123") 57 | }) 58 | l.level = LevelError 59 | asserts.NotPanics(func() { 60 | l.Info("123") 61 | }) 62 | } 63 | func TestLogger_Warning(t *testing.T) { 64 | asserts := assert.New(t) 65 | l := Logger{ 66 | level: LevelDebug, 67 | } 68 | asserts.NotPanics(func() { 69 | l.Warning("123") 70 | }) 71 | l.level = LevelError 72 | asserts.NotPanics(func() { 73 | l.Warning("123") 74 | }) 75 | } 76 | 77 | func TestLogger_Error(t *testing.T) { 78 | asserts := assert.New(t) 79 | l := Logger{ 80 | level: LevelDebug, 81 | } 82 | asserts.NotPanics(func() { 83 | l.Error("123") 84 | }) 85 | l.level = -1 86 | asserts.NotPanics(func() { 87 | l.Error("123") 88 | }) 89 | } 90 | 91 | func 
TestLogger_Panic(t *testing.T) { 92 | asserts := assert.New(t) 93 | l := Logger{ 94 | level: LevelDebug, 95 | } 96 | asserts.Panics(func() { 97 | l.Panic("123") 98 | }) 99 | l.level = -1 // NOTE(review): assumes Panic, like Error in TestLogger_Error, is silenced at level -1 100 | asserts.NotPanics(func() { 101 | l.Panic("123") 102 | }) 103 | } 104 | -------------------------------------------------------------------------------- /pkg/util/path.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "os" 5 | "path" 6 | "path/filepath" 7 | "strings" 8 | ) 9 | 10 | // DotPathToStandardPath converts a comma-separated path into a standard path rooted at "/" 11 | func DotPathToStandardPath(path string) string { 12 | return "/" + strings.ReplaceAll(path, ",", "/") 13 | } 14 | 15 | // FillSlash appends a trailing "/" to path unless path is exactly "/" 16 | func FillSlash(path string) string { 17 | if path == "/" { 18 | return path 19 | } 20 | return path + "/" 21 | } 22 | 23 | // RemoveSlash strips one trailing "/" from path; paths of length 0 or 1 (including "/") are returned unchanged 24 | func RemoveSlash(path string) string { 25 | if len(path) > 1 { 26 | return strings.TrimSuffix(path, "/") 27 | } 28 | return path 29 | } 30 | 31 | // SplitPath splits an absolute path into its components with "/" as the first element; input not starting with "/" yields an empty slice 32 | func SplitPath(path string) []string { 33 | if len(path) == 0 || path[0] != '/' { 34 | return []string{} 35 | } 36 | 37 | if path == "/" { 38 | return []string{"/"} 39 | } 40 | 41 | pathSplit := strings.Split(path, "/") 42 | pathSplit[0] = "/" 43 | return pathSplit 44 | } 45 | 46 | // FormSlash replaces every backslash '\' in old with '/' and cleans the result with path.Clean 47 | func FormSlash(old string) string { 48 | return path.Clean(strings.ReplaceAll(old, "\\", "/")) 49 | } 50 | 51 | // RelativePath resolves a relative name against the directory of the running executable; absolute names are returned unchanged 52 | func RelativePath(name string) string { 53 | if filepath.IsAbs(name) { 54 | return name 55 | } 56 | e, _ := os.Executable() // NOTE(review): error ignored; on failure e is "" and the result is relative to the working directory 57 | return filepath.Join(filepath.Dir(e), name) 58 | } 59 | -------------------------------------------------------------------------------- /pkg/util/path_test.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 5 | "testing" 6 | 
) 7 | 8 | func TestDotPathToStandardPath(t *testing.T) { 9 | asserts := assert.New(t) 10 | 11 | asserts.Equal("/", DotPathToStandardPath("")) 12 | asserts.Equal("/目录", DotPathToStandardPath("目录")) 13 | asserts.Equal("/目录/目录2", DotPathToStandardPath("目录,目录2")) 14 | } 15 | 16 | func TestFillSlash(t *testing.T) { 17 | asserts := assert.New(t) 18 | asserts.Equal("/", FillSlash("/")) 19 | asserts.Equal("/", FillSlash("")) 20 | asserts.Equal("/123/", FillSlash("/123")) 21 | } 22 | 23 | func TestRemoveSlash(t *testing.T) { 24 | asserts := assert.New(t) 25 | asserts.Equal("/", RemoveSlash("/")) 26 | asserts.Equal("/123/1236", RemoveSlash("/123/1236")) 27 | asserts.Equal("/123/1236", RemoveSlash("/123/1236/")) 28 | } 29 | 30 | func TestSplitPath(t *testing.T) { 31 | asserts := assert.New(t) 32 | asserts.Equal([]string{}, SplitPath("")) 33 | asserts.Equal([]string{}, SplitPath("1")) 34 | asserts.Equal([]string{"/"}, SplitPath("/")) 35 | asserts.Equal([]string{"/", "123", "321"}, SplitPath("/123/321")) 36 | } 37 | -------------------------------------------------------------------------------- /pkg/util/session.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "github.com/gin-contrib/sessions" 5 | "github.com/gin-gonic/gin" 6 | ) 7 | 8 | // SetSession stores every key/value pair of list in the request's default session and saves it, logging a warning if saving fails 9 | func SetSession(c *gin.Context, list map[string]interface{}) { 10 | s := sessions.Default(c) 11 | for key, value := range list { 12 | s.Set(key, value) 13 | } 14 | 15 | err := s.Save() 16 | if err != nil { 17 | Log().Warning("无法设置 Session 值:%s", err) 18 | } 19 | } 20 | 21 | // GetSession returns the value stored under key in the request's default session 22 | func GetSession(c *gin.Context, key string) interface{} { 23 | s := sessions.Default(c) 24 | return s.Get(key) 25 | } 26 | 27 | // DeleteSession removes key from the request's default session and saves it, logging a warning if saving fails 28 | func DeleteSession(c *gin.Context, key string) { 29 | s := sessions.Default(c) 30 | s.Delete(key) 31 | if err := s.Save(); err != nil { 32 | Log().Warning("无法删除 Session 值:%s", err) 33 | } 34 | } 35 | 36 | // ClearSession removes all values from the request's default session and saves it, logging a warning if saving fails 37 | func ClearSession(c 
*gin.Context) { 38 | s := sessions.Default(c) 39 | s.Clear() 40 | if err := s.Save(); err != nil { 41 | Log().Warning("无法清空 Session:%s", err) 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /pkg/webdav/internal/xml/README: -------------------------------------------------------------------------------- 1 | This is a fork of the encoding/xml package at ca1d6c4, the last commit before 2 | https://go.googlesource.com/go/+/c0d6d33 "encoding/xml: restore Go 1.4 name 3 | space behavior" made late in the lead-up to the Go 1.5 release. 4 | 5 | The list of encoding/xml changes is at 6 | https://go.googlesource.com/go/+log/master/src/encoding/xml 7 | 8 | This fork is temporary, and I (nigeltao) expect to revert it after Go 1.6 is 9 | released. 10 | 11 | See http://golang.org/issue/11841 12 | -------------------------------------------------------------------------------- /routers/controllers/aria2.go: -------------------------------------------------------------------------------- 1 | package controllers 2 | 3 | import ( 4 | "context" 5 | 6 | ariaCall "github.com/cloudreve/Cloudreve/v3/pkg/aria2" 7 | "github.com/cloudreve/Cloudreve/v3/service/aria2" 8 | "github.com/cloudreve/Cloudreve/v3/service/explorer" 9 | "github.com/gin-gonic/gin" 10 | ) 11 | 12 | // AddAria2URL binds an aria2.AddURLService from the JSON body and creates a URL offline-download task 13 | func AddAria2URL(c *gin.Context) { 14 | var addService aria2.AddURLService 15 | if err := c.ShouldBindJSON(&addService); err == nil { 16 | res := addService.Add(c, ariaCall.URLTask) 17 | c.JSON(200, res) 18 | } else { 19 | c.JSON(200, ErrorResponse(err)) 20 | } 21 | } 22 | 23 | // SelectAria2File binds file indexes from the JSON body and selects which files of a multi-file download to fetch 24 | func SelectAria2File(c *gin.Context) { 25 | var selectService aria2.SelectFileService 26 | if err := c.ShouldBindJSON(&selectService); err == nil { 27 | res := selectService.Select(c) 28 | c.JSON(200, res) 29 | } else { 30 | c.JSON(200, ErrorResponse(err)) 31 | } 32 | } 33 | 34 | // AddAria2Torrent creates an offline-download task whose source is a download URL generated for a stored torrent file 35 | func AddAria2Torrent(c *gin.Context) { 36 | // cancellable context for creating the download session 37 | ctx, cancel := 
context.WithCancel(context.Background()) 38 | defer cancel() 39 | 40 | var service explorer.FileIDService 41 | if err := c.ShouldBindUri(&service); err == nil { 42 | // obtain a download URL for the torrent file's content 43 | res := service.CreateDownloadSession(ctx, c) 44 | if res.Code != 0 { 45 | c.JSON(200, res) 46 | return 47 | } 48 | 49 | // create the download task 50 | var addService aria2.AddURLService 51 | addService.URL = res.Data.(string) // pre-filled so the binding:"required" check on URL passes when the body omits it 52 | 53 | if err := c.ShouldBindJSON(&addService); err == nil { 54 | addService.URL = res.Data.(string) // re-assigned after binding so a client-supplied url cannot override the generated one 55 | res := addService.Add(c, ariaCall.URLTask) 56 | c.JSON(200, res) 57 | } else { 58 | c.JSON(200, ErrorResponse(err)) 59 | } 60 | 61 | } else { 62 | c.JSON(200, ErrorResponse(err)) 63 | } 64 | } 65 | 66 | // CancelAria2Download cancels a running aria2 download, or deletes the record of one that has already ended 67 | func CancelAria2Download(c *gin.Context) { 68 | var selectService aria2.DownloadTaskService 69 | if err := c.ShouldBindUri(&selectService); err == nil { 70 | res := selectService.Delete(c) 71 | c.JSON(200, res) 72 | } else { 73 | c.JSON(200, ErrorResponse(err)) 74 | } 75 | } 76 | 77 | // ListDownloading lists the current user's in-progress download tasks 78 | func ListDownloading(c *gin.Context) { 79 | var service aria2.DownloadListService 80 | if err := c.ShouldBindQuery(&service); err == nil { 81 | res := service.Downloading(c, CurrentUser(c)) 82 | c.JSON(200, res) 83 | } else { 84 | c.JSON(200, ErrorResponse(err)) 85 | } 86 | } 87 | 88 | // ListFinished lists the current user's finished download tasks 89 | func ListFinished(c *gin.Context) { 90 | var service aria2.DownloadListService 91 | if err := c.ShouldBindQuery(&service); err == nil { 92 | res := service.Finished(c, CurrentUser(c)) 93 | c.JSON(200, res) 94 | } else { 95 | c.JSON(200, ErrorResponse(err)) 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /routers/controllers/directory.go: -------------------------------------------------------------------------------- 1 | package controllers 2 | 3 | import ( 4 | "github.com/cloudreve/Cloudreve/v3/service/explorer" 5 | "github.com/gin-gonic/gin" 6 | )
7 | 8 | // CreateDirectory binds the target path from the JSON body and creates that directory 9 | func CreateDirectory(c *gin.Context) { 10 | var service explorer.DirectoryService 11 | if err := c.ShouldBindJSON(&service); err == nil { 12 | res := service.CreateDirectory(c) 13 | c.JSON(200, res) 14 | } else { 15 | c.JSON(200, ErrorResponse(err)) 16 | } 17 | } 18 | 19 | // ListDirectory binds the path from the URI and lists that directory's contents 20 | func ListDirectory(c *gin.Context) { 21 | var service explorer.DirectoryService 22 | if err := c.ShouldBindUri(&service); err == nil { 23 | res := service.ListDirectory(c) 24 | c.JSON(200, res) 25 | } else { 26 | c.JSON(200, ErrorResponse(err)) 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /routers/controllers/main.go: -------------------------------------------------------------------------------- 1 | package controllers 2 | 3 | import ( 4 | "encoding/json" 5 | 6 | model "github.com/cloudreve/Cloudreve/v3/models" 7 | "github.com/cloudreve/Cloudreve/v3/pkg/serializer" 8 | "github.com/gin-gonic/gin" 9 | "gopkg.in/go-playground/validator.v9" 10 | ) 11 | 12 | // ParamErrorMsg builds a user-facing message from a validator field name and failed tag; it returns "" when either one is unmapped 13 | func ParamErrorMsg(field string, tag string) string { 14 | // form fields that can fail validation, mapped to their Chinese display names 15 | fieldMap := map[string]string{ 16 | "UserName": "邮箱", 17 | "Password": "密码", 18 | "Path": "路径", 19 | "SourceID": "原始资源", 20 | "URL": "链接", 21 | "Nick": "昵称", 22 | } 23 | // failed validation rules, mapped to their Chinese descriptions 24 | tagMap := map[string]string{ 25 | "required": "不能为空", 26 | "min": "太短", 27 | "max": "太长", 28 | "email": "格式不正确", 29 | } 30 | fieldVal, findField := fieldMap[field] 31 | tagVal, findTag := tagMap[tag] 32 | if findField && findTag { 33 | // join field name and rule description into the message 34 | return fieldVal + tagVal 35 | } 36 | return "" 37 | } 38 | 39 | // ErrorResponse converts a request-binding error into a parameter-error response with a user-facing message 40 | func ErrorResponse(err error) serializer.Response { 41 | // errors produced by the validator 42 | if ve, ok := err.(validator.ValidationErrors); ok && len(ve) > 0 { 43 | // only the first failing field is reported to the client 44 | e := ve[0] 45 | return serializer.ParamErr( 46 | ParamErrorMsg(e.Field(), e.Tag()), 47 | err, 48 | ) 49 | } 50 | 
51 | if _, ok := err.(*json.UnmarshalTypeError); ok { 52 | return serializer.ParamErr("JSON类型不匹配", err) 53 | } 54 | 55 | return serializer.ParamErr("参数错误", err) 56 | } 57 | 58 | // CurrentUser returns the *model.User stored under "user" in the gin context, or nil if it is absent or of another type 59 | func CurrentUser(c *gin.Context) *model.User { 60 | if user, _ := c.Get("user"); user != nil { 61 | if u, ok := user.(*model.User); ok { 62 | return u 63 | } 64 | } 65 | return nil 66 | } 67 | -------------------------------------------------------------------------------- /routers/controllers/objects.go: -------------------------------------------------------------------------------- 1 | package controllers 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/cloudreve/Cloudreve/v3/service/explorer" 7 | "github.com/gin-gonic/gin" 8 | ) 9 | 10 | // Delete deletes the files or directories identified in the JSON body 11 | func Delete(c *gin.Context) { 12 | // cancellable context for the operation 13 | ctx, cancel := context.WithCancel(context.Background()) 14 | defer cancel() 15 | 16 | var service explorer.ItemIDService 17 | if err := c.ShouldBindJSON(&service); err == nil { 18 | res := service.Delete(ctx, c) 19 | c.JSON(200, res) 20 | } else { 21 | c.JSON(200, ErrorResponse(err)) 22 | } 23 | } 24 | 25 | // Move moves the files or directories described in the JSON body 26 | func Move(c *gin.Context) { 27 | // cancellable context for the operation 28 | ctx, cancel := context.WithCancel(context.Background()) 29 | defer cancel() 30 | 31 | var service explorer.ItemMoveService 32 | if err := c.ShouldBindJSON(&service); err == nil { 33 | res := service.Move(ctx, c) 34 | c.JSON(200, res) 35 | } else { 36 | c.JSON(200, ErrorResponse(err)) 37 | } 38 | } 39 | 40 | // Copy copies the files or directories described in the JSON body 41 | func Copy(c *gin.Context) { 42 | // cancellable context for the operation 43 | ctx, cancel := context.WithCancel(context.Background()) 44 | defer cancel() 45 | 46 | var service explorer.ItemMoveService 47 | if err := c.ShouldBindJSON(&service); err == nil { 48 | res := service.Copy(ctx, c) 49 | c.JSON(200, res) 50 | } else { 51 | c.JSON(200, ErrorResponse(err)) 52 | } 53 | } 54 | 55 | // Rename renames the file or directory described in the JSON body 56 | func Rename(c *gin.Context) { 57 | // cancellable context for the operation 58 | ctx, cancel := 
context.WithCancel(context.Background()) 59 | defer cancel() 60 | 61 | var service explorer.ItemRenameService 62 | if err := c.ShouldBindJSON(&service); err == nil { 63 | res := service.Rename(ctx, c) 64 | c.JSON(200, res) 65 | } else { 66 | c.JSON(200, ErrorResponse(err)) 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /routers/controllers/tag.go: -------------------------------------------------------------------------------- 1 | package controllers 2 | 3 | import ( 4 | "github.com/cloudreve/Cloudreve/v3/service/explorer" 5 | "github.com/gin-gonic/gin" 6 | ) 7 | 8 | // CreateFilterTag creates a file-category (filter) tag for the current user 9 | func CreateFilterTag(c *gin.Context) { 10 | var service explorer.FilterTagCreateService 11 | if err := c.ShouldBindJSON(&service); err == nil { 12 | res := service.Create(c, CurrentUser(c)) 13 | c.JSON(200, res) 14 | } else { 15 | c.JSON(200, ErrorResponse(err)) 16 | } 17 | } 18 | 19 | // CreateLinkTag creates a directory-shortcut (link) tag for the current user 20 | func CreateLinkTag(c *gin.Context) { 21 | var service explorer.LinkTagCreateService 22 | if err := c.ShouldBindJSON(&service); err == nil { 23 | res := service.Create(c, CurrentUser(c)) 24 | c.JSON(200, res) 25 | } else { 26 | c.JSON(200, ErrorResponse(err)) 27 | } 28 | } 29 | 30 | // DeleteTag deletes the tag identified in the URI for the current user 31 | func DeleteTag(c *gin.Context) { 32 | var service explorer.TagService 33 | if err := c.ShouldBindUri(&service); err == nil { 34 | res := service.Delete(c, CurrentUser(c)) 35 | c.JSON(200, res) 36 | } else { 37 | c.JSON(200, ErrorResponse(err)) 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /routers/controllers/webdav.go: -------------------------------------------------------------------------------- 1 | package controllers 2 | 3 | import ( 4 | model "github.com/cloudreve/Cloudreve/v3/models" 5 | "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" 6 | "github.com/cloudreve/Cloudreve/v3/pkg/util" 7 | "github.com/cloudreve/Cloudreve/v3/pkg/webdav" 8 | 
"github.com/cloudreve/Cloudreve/v3/service/setting" 9 | "github.com/gin-gonic/gin" 10 | ) 11 | 12 | var handler *webdav.Handler 13 | 14 | func init() { 15 | handler = &webdav.Handler{ 16 | Prefix: "/dav", 17 | LockSystem: make(map[uint]webdav.LockSystem), 18 | } 19 | } 20 | 21 | // ServeWebDAV handles a WebDAV request using a file system built for the current request, re-rooted at the WebDAV account's root when it exists 22 | func ServeWebDAV(c *gin.Context) { 23 | fs, err := filesystem.NewFileSystemFromContext(c) 24 | if err != nil { 25 | util.Log().Warning("无法为WebDAV初始化文件系统,%s", err) 26 | c.AbortWithStatus(500) // report the failure instead of an empty default 200 response 27 | return 28 | } 29 | if webdavCtx, ok := c.Get("webdav"); ok { 30 | application := webdavCtx.(*model.Webdav) 31 | 32 | // re-root the file system at the account's root directory 33 | if application.Root != "/" { 34 | if exist, root := fs.IsPathExist(application.Root); exist { 35 | root.Position = "" 36 | root.Name = "/" 37 | fs.Root = root 38 | } 39 | } 40 | } 41 | 42 | handler.ServeHTTP(c.Writer, c.Request, fs) 43 | } 44 | 45 | // GetWebDAVAccounts lists the current user's WebDAV accounts 46 | func GetWebDAVAccounts(c *gin.Context) { 47 | var service setting.WebDAVListService 48 | if err := c.ShouldBindUri(&service); err == nil { 49 | res := service.Accounts(c, CurrentUser(c)) 50 | c.JSON(200, res) 51 | } else { 52 | c.JSON(200, ErrorResponse(err)) 53 | } 54 | } 55 | 56 | // DeleteWebDAVAccounts deletes a WebDAV account of the current user 57 | func DeleteWebDAVAccounts(c *gin.Context) { 58 | var service setting.WebDAVAccountService 59 | if err := c.ShouldBindUri(&service); err == nil { 60 | res := service.Delete(c, CurrentUser(c)) 61 | c.JSON(200, res) 62 | } else { 63 | c.JSON(200, ErrorResponse(err)) 64 | } 65 | } 66 | 67 | // CreateWebDAVAccounts creates a WebDAV account for the current user 68 | func CreateWebDAVAccounts(c *gin.Context) { 69 | var service setting.WebDAVAccountCreateService 70 | if err := c.ShouldBindJSON(&service); err == nil { 71 | res := service.Create(c, CurrentUser(c)) 72 | c.JSON(200, res) 73 | } else { 74 | c.JSON(200, ErrorResponse(err)) 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /routers/main_test.go: 
-------------------------------------------------------------------------------- 1 | package routers 2 | 3 | import ( 4 | "database/sql" 5 | "testing" 6 | 7 | "github.com/DATA-DOG/go-sqlmock" 8 | model "github.com/cloudreve/Cloudreve/v3/models" 9 | "github.com/gin-gonic/gin" 10 | "github.com/jinzhu/gorm" 11 | ) 12 | 13 | var mock sqlmock.Sqlmock 14 | var memDB *gorm.DB 15 | var mockDB *gorm.DB 16 | 17 | // TestMain puts gin in test mode and prepares a sqlmock-backed DB (mockDB) and the model.Init() DB (memDB), leaving model.DB = memDB 18 | func TestMain(m *testing.M) { 19 | // put gin into test mode 20 | gin.SetMode(gin.TestMode) 21 | 22 | // initialize sqlmock 23 | var db *sql.DB 24 | var err error 25 | db, mock, err = sqlmock.New() 26 | if err != nil { 27 | panic("An error was not expected when opening a stub database connection") 28 | } 29 | 30 | // initialize the in-memory database 31 | model.Init() 32 | memDB = model.DB 33 | 34 | mockDB, _ = gorm.Open("mysql", db) 35 | model.DB = memDB 36 | defer db.Close() 37 | 38 | m.Run() // NOTE(review): the exit code is not passed to os.Exit; failures are only reported by Go 1.15+ toolchains 39 | } 40 | 41 | func switchToMemDB() { 42 | model.DB = memDB 43 | } 44 | 45 | func switchToMockDB() { 46 | model.DB = mockDB 47 | } 48 | -------------------------------------------------------------------------------- /service/admin/aria2.go: -------------------------------------------------------------------------------- 1 | package admin 2 | 3 | import ( 4 | "net/url" 5 | 6 | "github.com/cloudreve/Cloudreve/v3/pkg/aria2" 7 | "github.com/cloudreve/Cloudreve/v3/pkg/serializer" 8 | ) 9 | 10 | // Aria2TestService holds the aria2 RPC server address and token to test 11 | type Aria2TestService struct { 12 | Server string `json:"server" binding:"required"` 13 | Token string `json:"token"` 14 | } 15 | 16 | // Test connects to the aria2 RPC server and returns its version 17 | func (service *Aria2TestService) Test() serializer.Response { 18 | testRPC := aria2.RPCService{} 19 | 20 | // parse the RPC server address; its path is always replaced with /jsonrpc 21 | server, err := url.Parse(service.Server) 22 | if err != nil { 23 | return serializer.ParamErr("无法解析 aria2 RPC 服务地址, "+err.Error(), nil) 24 | } 25 | server.Path = "/jsonrpc" 26 | 27 | if err := testRPC.Init(server.String(), service.Token, 5, map[string]interface{}{}); err != nil { 28 | 
return serializer.ParamErr("无法初始化连接, "+err.Error(), nil) 29 | } 30 | 31 | defer testRPC.Caller.Close() 32 | 33 | info, err := testRPC.Caller.GetVersion() 34 | if err != nil { 35 | return serializer.ParamErr("无法请求 RPC 服务, "+err.Error(), nil) 36 | } 37 | 38 | if info.Version == "" { 39 | return serializer.ParamErr("RPC 服务返回非预期响应", nil) 40 | } 41 | 42 | return serializer.Response{Data: info.Version} 43 | } 44 | -------------------------------------------------------------------------------- /service/admin/group.go: -------------------------------------------------------------------------------- 1 | package admin 2 | 3 | import ( 4 | "fmt" 5 | 6 | model "github.com/cloudreve/Cloudreve/v3/models" 7 | "github.com/cloudreve/Cloudreve/v3/pkg/serializer" 8 | ) 9 | 10 | // AddGroupService carries a user group to create or update 11 | type AddGroupService struct { 12 | Group model.Group `json:"group" binding:"required"` 13 | } 14 | 15 | // GroupService identifies a user group by ID 16 | type GroupService struct { 17 | ID uint `uri:"id" json:"id" binding:"required"` 18 | } 19 | 20 | // Get returns the user group with the given ID 21 | func (service *GroupService) Get() serializer.Response { 22 | group, err := model.GetGroupByID(service.ID) 23 | if err != nil { 24 | return serializer.Err(serializer.CodeNotFound, "用户组不存在", err) 25 | } 26 | 27 | return serializer.Response{Data: group} 28 | } 29 | 30 | // Delete deletes the user group unless it is a built-in group (ID <= 3) or still has users 31 | func (service *GroupService) Delete() serializer.Response { 32 | // look up the group 33 | group, err := model.GetGroupByID(service.ID) 34 | if err != nil { 35 | return serializer.Err(serializer.CodeNotFound, "用户组不存在", err) 36 | } 37 | 38 | // groups with ID <= 3 are built-in system groups 39 | if group.ID <= 3 { 40 | return serializer.Err(serializer.CodeNoPermissionErr, "系统用户组无法删除", err) 41 | } 42 | 43 | // refuse to delete a group that still has users 44 | total := 0 45 | row := model.DB.Model(&model.User{}).Where("group_id = ?", service.ID). 
46 | Select("count(id)").Row() 47 | row.Scan(&total) 48 | if total > 0 { 49 | return serializer.ParamErr(fmt.Sprintf("有 %d 位用户仍属于此用户组,请先删除这些用户或者更改用户组", total), nil) 50 | } 51 | if err := model.DB.Delete(&group).Error; err != nil { 52 | return serializer.DBErr("用户组删除失败", err) 53 | } 54 | return serializer.Response{} 55 | } 56 | 57 | // Add saves service.Group, updating it when it already has an ID and creating it otherwise, and returns its ID 58 | func (service *AddGroupService) Add() serializer.Response { 59 | if service.Group.ID > 0 { 60 | if err := model.DB.Save(&service.Group).Error; err != nil { 61 | return serializer.ParamErr("用户组保存失败", err) 62 | } 63 | } else { 64 | if err := model.DB.Create(&service.Group).Error; err != nil { 65 | return serializer.ParamErr("用户组添加失败", err) 66 | } 67 | } 68 | 69 | return serializer.Response{Data: service.Group.ID} 70 | } 71 | 72 | // Groups lists one page of user groups together with per-group user counts and the storage policies they reference 73 | func (service *AdminListService) Groups() serializer.Response { 74 | var res []model.Group 75 | total := 0 76 | 77 | tx := model.DB.Model(&model.Group{}) 78 | if service.OrderBy != "" { 79 | tx = tx.Order(service.OrderBy) 80 | } 81 | 82 | for k, v := range service.Conditions { 83 | tx = tx.Where(k+" = ?", v) // NOTE(review): k is interpolated into SQL as a column name and must come from trusted input 84 | } 85 | 86 | // count all matching rows for pagination 87 | tx.Count(&total) 88 | 89 | // fetch the current page 90 | tx.Limit(service.PageSize).Offset((service.Page - 1) * service.PageSize).Find(&res) 91 | 92 | // count the users of each listed group 93 | statics := make(map[uint]int, len(res)) 94 | for i := 0; i < len(res); i++ { 95 | count := 0 96 | row := model.DB.Model(&model.User{}).Where("group_id = ?", res[i].ID). 
97 | Select("count(id)").Row() 98 | row.Scan(&count) 99 | statics[res[i].ID] = count 100 | } 101 | 102 | // collect the storage policies referenced by the listed groups 103 | policies := make(map[uint]model.Policy) 104 | for i := 0; i < len(res); i++ { 105 | for _, p := range res[i].PolicyList { 106 | if _, ok := policies[p]; !ok { 107 | policies[p], _ = model.GetPolicyByID(p) 108 | } 109 | } 110 | } 111 | 112 | return serializer.Response{Data: map[string]interface{}{ 113 | "total": total, 114 | "items": res, 115 | "statics": statics, 116 | "policies": policies, 117 | }} 118 | } 119 | -------------------------------------------------------------------------------- /service/admin/list.go: -------------------------------------------------------------------------------- 1 | package admin 2 | 3 | import ( 4 | model "github.com/cloudreve/Cloudreve/v3/models" 5 | "github.com/cloudreve/Cloudreve/v3/pkg/serializer" 6 | ) 7 | 8 | // AdminListService holds the paging, ordering, filtering and search parameters of admin dashboard lists 9 | type AdminListService struct { 10 | Page int `json:"page" binding:"min=1,required"` 11 | PageSize int `json:"page_size" binding:"min=1,required"` 12 | OrderBy string `json:"order_by"` 13 | Conditions map[string]string `form:"conditions"` 14 | Searches map[string]string `form:"searches"` 15 | } 16 | 17 | // GroupList returns all user groups 18 | func (service *NoParamService) GroupList() serializer.Response { 19 | var res []model.Group 20 | model.DB.Model(&model.Group{}).Find(&res) 21 | return serializer.Response{Data: res} 22 | } 23 | -------------------------------------------------------------------------------- /service/admin/share.go: -------------------------------------------------------------------------------- 1 | package admin 2 | 3 | import ( 4 | "strings" 5 | 6 | model "github.com/cloudreve/Cloudreve/v3/models" 7 | "github.com/cloudreve/Cloudreve/v3/pkg/hashid" 8 | "github.com/cloudreve/Cloudreve/v3/pkg/serializer" 9 | "github.com/gin-gonic/gin" 10 | ) 11 | 12 | // ShareBatchService identifies the shares targeted by a batch operation 13 | type ShareBatchService struct { 14 | ID []uint `json:"id" binding:"min=1"` 
15 | } 16 | 17 | // Delete deletes the shares whose IDs are listed in service.ID 18 | func (service *ShareBatchService) Delete(c *gin.Context) serializer.Response { 19 | if err := model.DB.Where("id in (?)", service.ID).Delete(&model.Share{}).Error; err != nil { 20 | return serializer.DBErr("无法删除分享", err) 21 | } 22 | return serializer.Response{} 23 | } 24 | 25 | // Shares lists one page of shares with filtering and search, plus their owners and hash IDs 26 | func (service *AdminListService) Shares() serializer.Response { 27 | var res []model.Share 28 | total := 0 29 | 30 | tx := model.DB.Model(&model.Share{}) 31 | if service.OrderBy != "" { 32 | tx = tx.Order(service.OrderBy) 33 | } 34 | 35 | for k, v := range service.Conditions { 36 | tx = tx.Where(k+" = ?", v) // NOTE(review): k is interpolated into SQL as a column name and must come from trusted input 37 | } 38 | 39 | if len(service.Searches) > 0 { 40 | search, args := "", []interface{}{} 41 | for k, v := range service.Searches { 42 | search, args = search+k+" like ? OR ", append(args, "%"+v+"%") // bind the search value as a parameter to prevent SQL injection 43 | } 44 | search = strings.TrimSuffix(search, " OR ") 45 | tx = tx.Where(search, args...) 46 | } 47 | 48 | // count all matching rows for pagination 49 | tx.Count(&total) 50 | 51 | // fetch the current page 52 | tx.Limit(service.PageSize).Offset((service.Page - 1) * service.PageSize).Find(&res) 53 | 54 | // collect the owning users and compute each share's hash ID 55 | users := make(map[uint]model.User) 56 | hashIDs := make(map[uint]string, len(res)) 57 | for _, share := range res { 58 | users[share.UserID] = model.User{} 59 | hashIDs[share.ID] = hashid.HashID(share.ID, hashid.ShareID) 60 | } 61 | 62 | userIDs := make([]uint, 0, len(users)) 63 | for k := range users { 64 | userIDs = append(userIDs, k) 65 | } 66 | 67 | var userList []model.User 68 | model.DB.Where("id in (?)", userIDs).Find(&userList) 69 | 70 | for _, v := range userList { 71 | users[v.ID] = v 72 | } 73 | 74 | return serializer.Response{Data: map[string]interface{}{ 75 | "total": total, 76 | "items": res, 77 | "users": users, 78 | "ids": hashIDs, 79 | }} 80 | } 81 | -------------------------------------------------------------------------------- /service/aria2/add.go: -------------------------------------------------------------------------------- 1 | package aria2 2 | 3 | import ( 4 | model 
"github.com/cloudreve/Cloudreve/v3/models" 5 | "github.com/cloudreve/Cloudreve/v3/pkg/aria2" 6 | "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" 7 | "github.com/cloudreve/Cloudreve/v3/pkg/serializer" 8 | "github.com/gin-gonic/gin" 9 | ) 10 | 11 | // AddURLService holds the source URL and destination directory of a new offline download 12 | type AddURLService struct { 13 | URL string `json:"url" binding:"required"` 14 | Dst string `json:"dst" binding:"required,min=1"` 15 | } 16 | 17 | // Add creates an offline-download task of type taskType that fetches service.URL into service.Dst 18 | func (service *AddURLService) Add(c *gin.Context, taskType int) serializer.Response { 19 | // create the user's file system 20 | fs, err := filesystem.NewFileSystemFromContext(c) 21 | if err != nil { 22 | return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err) 23 | } 24 | defer fs.Recycle() 25 | 26 | // the user's group must be allowed to use aria2 27 | if !fs.User.Group.OptionsSerialized.Aria2 { 28 | return serializer.Err(serializer.CodeGroupNotAllowed, "当前用户组无法进行此操作", nil) 29 | } 30 | 31 | // the destination directory must exist 32 | if exist, _ := fs.IsPathExist(service.Dst); !exist { 33 | return serializer.Err(serializer.CodeNotFound, "存放路径不存在", nil) 34 | } 35 | 36 | // build the task record 37 | task := &model.Download{ 38 | Status: aria2.Ready, 39 | Type: taskType, 40 | Dst: service.Dst, 41 | UserID: fs.User.ID, 42 | Source: service.URL, 43 | } 44 | 45 | aria2.Lock.RLock() // hold the read lock only around the access to aria2.Instance 46 | err = aria2.Instance.CreateTask(task, fs.User.Group.OptionsSerialized.Aria2Options) 47 | aria2.Lock.RUnlock() 48 | if err != nil { 49 | return serializer.Err(serializer.CodeNotSet, "任务创建失败", err) 50 | } 51 | 52 | return serializer.Response{} 53 | } 54 | -------------------------------------------------------------------------------- /service/aria2/manage.go: -------------------------------------------------------------------------------- 1 | package aria2 2 | 3 | import ( 4 | model "github.com/cloudreve/Cloudreve/v3/models" 5 | "github.com/cloudreve/Cloudreve/v3/pkg/aria2" 6 | "github.com/cloudreve/Cloudreve/v3/pkg/serializer" 7 | "github.com/gin-gonic/gin" 8 | ) 9 | 10 | // SelectFileService holds the indexes of the files to download from a multi-file task 
11 | type SelectFileService struct { 12 | Indexes []int `json:"indexes" binding:"required"` 13 | } 14 | 15 | // DownloadTaskService identifies a download task by its aria2 GID 16 | type DownloadTaskService struct { 17 | GID string `uri:"gid" binding:"required"` 18 | } 19 | 20 | // DownloadListService holds the page number of a download list 21 | type DownloadListService struct { 22 | Page uint `form:"page"` 23 | } 24 | 25 | // Finished returns one page of the user's tasks in the Error, Complete, Canceled or Unknown status 26 | func (service *DownloadListService) Finished(c *gin.Context, user *model.User) serializer.Response { 27 | // query the download records 28 | downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, aria2.Error, aria2.Complete, aria2.Canceled, aria2.Unknown) 29 | return serializer.BuildFinishedListResponse(downloads) 30 | } 31 | 32 | // Downloading returns one page of the user's tasks in the Downloading, Paused or Ready status 33 | func (service *DownloadListService) Downloading(c *gin.Context, user *model.User) serializer.Response { 34 | // query the download records 35 | downloads := model.GetDownloadsByStatusAndUser(service.Page, user.ID, aria2.Downloading, aria2.Paused, aria2.Ready) 36 | return serializer.BuildDownloadingResponse(downloads) 37 | } 38 | 39 | // Delete cancels an active download task, or deletes the record of one whose status is aria2.Error or later 40 | func (service *DownloadTaskService) Delete(c *gin.Context) serializer.Response { 41 | userCtx, _ := c.Get("user") 42 | user := userCtx.(*model.User) 43 | 44 | // look up the current user's download record 45 | download, err := model.GetDownloadByGid(c.Param("gid"), user.ID) 46 | if err != nil { 47 | return serializer.Err(serializer.CodeNotFound, "下载记录不存在", err) 48 | } 49 | 50 | if download.Status >= aria2.Error { 51 | // the task has already ended: delete its record 52 | if err := download.Delete(); err != nil { 53 | return serializer.Err(serializer.CodeDBError, "任务记录删除失败", err) 54 | } 55 | return serializer.Response{} 56 | } 57 | 58 | // cancel the running task 59 | aria2.Lock.RLock() 60 | defer aria2.Lock.RUnlock() 61 | if err := aria2.Instance.Cancel(download); err != nil { 62 | return serializer.Err(serializer.CodeNotSet, "操作失败", err) 63 | } 64 | 65 | return serializer.Response{} 66 | } 67 | 68 | // Select chooses which files of a multi-file BitTorrent task to download 69 | func (service *SelectFileService) Select(c *gin.Context) 
serializer.Response { 70 | userCtx, _ := c.Get("user") 71 | user := userCtx.(*model.User) 72 | 73 | // look up the current user's download record 74 | download, err := model.GetDownloadByGid(c.Param("gid"), user.ID) 75 | if err != nil { 76 | return serializer.Err(serializer.CodeNotFound, "下载记录不存在", err) 77 | } 78 | 79 | if download.StatusInfo.BitTorrent.Mode != "multi" || (download.Status != aria2.Downloading && download.Status != aria2.Paused) { 80 | return serializer.Err(serializer.CodeNoPermissionErr, "此下载任务无法选取文件", nil) 81 | } 82 | 83 | // apply the file selection 84 | aria2.Lock.RLock() 85 | defer aria2.Lock.RUnlock() 86 | if err := aria2.Instance.Select(download, service.Indexes); err != nil { 87 | return serializer.Err(serializer.CodeNotSet, "操作失败", err) 88 | } 89 | 90 | return serializer.Response{} 91 | 92 | } 93 | -------------------------------------------------------------------------------- /service/callback/oauth.go: -------------------------------------------------------------------------------- 1 | package callback 2 | 3 | import ( 4 | "context" 5 | 6 | model "github.com/cloudreve/Cloudreve/v3/models" 7 | "github.com/cloudreve/Cloudreve/v3/pkg/cache" 8 | "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/driver/onedrive" 9 | "github.com/cloudreve/Cloudreve/v3/pkg/serializer" 10 | "github.com/cloudreve/Cloudreve/v3/pkg/util" 11 | "github.com/gin-gonic/gin" 12 | ) 13 | 14 | // OneDriveOauthService holds the query parameters of the OneDrive OAuth callback 15 | type OneDriveOauthService struct { 16 | Code string `form:"code"` 17 | Error string `form:"error"` 18 | ErrorMsg string `form:"error_description"` 19 | } 20 | 21 | // Auth exchanges the OAuth code for tokens and stores the refresh token on the storage policy recorded in the session 22 | func (service *OneDriveOauthService) Auth(c *gin.Context) serializer.Response { 23 | if service.Error != "" { 24 | return serializer.ParamErr(service.ErrorMsg, nil) 25 | } 26 | 27 | policyID, ok := util.GetSession(c, "onedrive_oauth_policy").(uint) 28 | if !ok { 29 | return serializer.Err(serializer.CodeNotFound, "授权会话不存在,请重试", nil) 30 | } 31 | 32 | util.DeleteSession(c, "onedrive_oauth_policy") 33 | 34 | policy, 
err := model.GetPolicyByID(policyID) 35 | if err != nil { 36 | return serializer.Err(serializer.CodeNotFound, "存储策略不存在", err) 37 | } 38 | 39 | client, err := onedrive.NewClient(&policy) 40 | if err != nil { 41 | return serializer.Err(serializer.CodeInternalSetting, "无法初始化 OneDrive 客户端", err) 42 | } 43 | 44 | credential, err := client.ObtainToken(context.Background(), onedrive.WithCode(service.Code)) 45 | if err != nil { 46 | return serializer.Err(serializer.CodeInternalSetting, "AccessToken 获取失败", err) 47 | } 48 | 49 | // store the new refresh token on the storage policy 50 | if err := client.Policy.UpdateAccessKey(credential.RefreshToken); err != nil { 51 | return serializer.DBErr("无法更新 RefreshToken", err) 52 | } 53 | 54 | cache.Deletes([]string{client.Policy.AccessKey}, "onedrive_") 55 | 56 | return serializer.Response{} 57 | } 58 | -------------------------------------------------------------------------------- /service/explorer/directory.go: -------------------------------------------------------------------------------- 1 | package explorer 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" 7 | "github.com/cloudreve/Cloudreve/v3/pkg/hashid" 8 | "github.com/cloudreve/Cloudreve/v3/pkg/serializer" 9 | "github.com/gin-gonic/gin" 10 | ) 11 | 12 | // DirectoryService holds the directory path used by the list and create operations 13 | type DirectoryService struct { 14 | Path string `uri:"path" json:"path" binding:"required,min=1,max=65535"` 15 | } 16 | 17 | // ListDirectory lists the objects under service.Path together with the hash ID of the listed directory 18 | func (service *DirectoryService) ListDirectory(c *gin.Context) serializer.Response { 19 | // create the user's file system 20 | fs, err := filesystem.NewFileSystemFromContext(c) 21 | if err != nil { 22 | return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err) 23 | } 24 | defer fs.Recycle() 25 | 26 | // cancellable context for the listing 27 | ctx, cancel := context.WithCancel(context.Background()) 28 | defer cancel() 29 | 30 | // list the child objects 31 | objects, err := fs.List(ctx, service.Path, nil) 32 | if err != nil { 33 | return serializer.Err(serializer.CodeNotSet, 
err.Error(), err) 34 | } 35 | 36 | var parentID uint // stays 0 when fs.DirTarget is empty 37 | if len(fs.DirTarget) > 0 { 38 | parentID = fs.DirTarget[0].ID 39 | } 40 | 41 | return serializer.Response{ 42 | Code: 0, 43 | Data: map[string]interface{}{ 44 | "parent": hashid.HashID(parentID, hashid.FolderID), 45 | "objects": objects, 46 | }, 47 | } 48 | } 49 | 50 | // CreateDirectory creates the directory at service.Path 51 | func (service *DirectoryService) CreateDirectory(c *gin.Context) serializer.Response { 52 | // create the user's file system 53 | fs, err := filesystem.NewFileSystemFromContext(c) 54 | if err != nil { 55 | return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err) 56 | } 57 | defer fs.Recycle() 58 | 59 | // cancellable context for the operation 60 | ctx, cancel := context.WithCancel(context.Background()) 61 | defer cancel() 62 | 63 | // create the directory 64 | _, err = fs.CreateDirectory(ctx, service.Path) 65 | if err != nil { 66 | return serializer.Err(serializer.CodeCreateFolderFailed, err.Error(), err) 67 | } 68 | return serializer.Response{ 69 | Code: 0, 70 | } 71 | 72 | } 73 | -------------------------------------------------------------------------------- /service/explorer/search.go: -------------------------------------------------------------------------------- 1 | package explorer 2 | 3 | import ( 4 | "context" 5 | "strings" 6 | 7 | model "github.com/cloudreve/Cloudreve/v3/models" 8 | "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" 9 | "github.com/cloudreve/Cloudreve/v3/pkg/hashid" 10 | "github.com/cloudreve/Cloudreve/v3/pkg/serializer" 11 | "github.com/gin-gonic/gin" 12 | ) 13 | 14 | // ItemSearchService holds the search type and keywords bound from the URI 15 | type ItemSearchService struct { 16 | Type string `uri:"type" binding:"required"` 17 | Keywords string `uri:"keywords" binding:"required"` 18 | } 19 | 20 | // Search dispatches on service.Type to a keyword, file-extension or tag search 21 | func (service *ItemSearchService) Search(c *gin.Context) serializer.Response { 22 | // create the user's file system 23 | fs, err := filesystem.NewFileSystemFromContext(c) 24 | if err != nil { 25 | return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err) 26 | } 27 | defer 
fs.Recycle() 28 | 29 | switch service.Type { 30 | case "keywords": 31 | return service.SearchKeywords(c, fs, "%"+service.Keywords+"%") 32 | case "image": 33 | return service.SearchKeywords(c, fs, "%.bmp", "%.iff", "%.png", "%.gif", "%.jpg", "%.jpeg", "%.psd", "%.svg", "%.webp") 34 | case "video": 35 | return service.SearchKeywords(c, fs, "%.mp4", "%.flv", "%.avi", "%.wmv", "%.mkv", "%.rm", "%.rmvb", "%.mov", "%.ogv") 36 | case "audio": 37 | return service.SearchKeywords(c, fs, "%.mp3", "%.flac", "%.ape", "%.wav", "%.aac", "%.ogg", "%.midi", "%.mid") 38 | case "doc": 39 | return service.SearchKeywords(c, fs, "%.txt", "%.md", "%.pdf", "%.doc", "%.docx", "%.ppt", "%.pptx", "%.xls", "%.xlsx", "%.pub") 40 | case "tag": 41 | if tid, err := hashid.DecodeHashID(service.Keywords, hashid.TagID); err == nil { 42 | if tag, err := model.GetTagsByID(tid, fs.User.ID); err == nil { 43 | if tag.Type == model.FileTagType { 44 | exp := strings.Split(tag.Expression, "\n") // each line of the tag expression is a separate filename pattern 45 | expInput := make([]interface{}, len(exp)) 46 | for i := 0; i < len(exp); i++ { 47 | expInput[i] = exp[i] 48 | } 49 | return service.SearchKeywords(c, fs, expInput...) 50 | } 51 | } 52 | } 53 | return serializer.Err(serializer.CodeNotFound, "标签不存在", nil) 54 | default: 55 | return serializer.ParamErr("未知搜索类型", nil) 56 | } 57 | } 58 | 59 | // SearchKeywords 根据关键字搜索文件 60 | func (service *ItemSearchService) SearchKeywords(c *gin.Context, fs *filesystem.FileSystem, keywords ...interface{}) serializer.Response { 61 | // 上下文 62 | ctx, cancel := context.WithCancel(context.Background()) 63 | defer cancel() 64 | 65 | // 获取子项目 66 | objects, err := fs.Search(ctx, keywords...)
67 | if err != nil { 68 | return serializer.Err(serializer.CodeNotSet, err.Error(), err) 69 | } 70 | 71 | return serializer.Response{ 72 | Code: 0, 73 | Data: map[string]interface{}{ 74 | "parent": 0, 75 | "objects": objects, 76 | }, 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /service/explorer/tag.go: -------------------------------------------------------------------------------- 1 | package explorer 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | model "github.com/cloudreve/Cloudreve/v3/models" 8 | "github.com/cloudreve/Cloudreve/v3/pkg/hashid" 9 | "github.com/cloudreve/Cloudreve/v3/pkg/serializer" 10 | "github.com/gin-gonic/gin" 11 | ) 12 | 13 | // FilterTagCreateService 文件分类标签创建服务 14 | type FilterTagCreateService struct { 15 | Expression string `json:"expression" binding:"required,min=1,max=65535"` 16 | Icon string `json:"icon" binding:"required,min=1,max=255"` 17 | Name string `json:"name" binding:"required,min=1,max=255"` 18 | Color string `json:"color" binding:"hexcolor|rgb|rgba|hsl"` 19 | } 20 | 21 | // LinkTagCreateService 目录快捷方式标签创建服务 22 | type LinkTagCreateService struct { 23 | Path string `json:"path" binding:"required,min=1,max=65535"` 24 | Name string `json:"name" binding:"required,min=1,max=255"` 25 | } 26 | 27 | // TagService 标签服务 28 | type TagService struct { 29 | } 30 | 31 | // Delete 删除标签 32 | func (service *TagService) Delete(c *gin.Context, user *model.User) serializer.Response { 33 | id, _ := c.Get("object_id") 34 | if err := model.DeleteTagByID(id.(uint), user.ID); err != nil { 35 | return serializer.Err(serializer.CodeDBError, "删除失败", err) 36 | } 37 | return serializer.Response{} 38 | } 39 | 40 | // Create 创建标签 41 | func (service *LinkTagCreateService) Create(c *gin.Context, user *model.User) serializer.Response { 42 | // 创建标签 43 | tag := model.Tag{ 44 | Name: service.Name, 45 | Icon: "FolderHeartOutline", 46 | Type: model.DirectoryLinkType, 47 | Expression: service.Path, 48 | UserID: 
user.ID, 49 | } 50 | id, err := tag.Create() 51 | if err != nil { 52 | return serializer.Err(serializer.CodeDBError, "标签创建失败", err) 53 | } 54 | 55 | return serializer.Response{ 56 | Data: hashid.HashID(id, hashid.TagID), 57 | } 58 | } 59 | 60 | // Create 创建标签 61 | func (service *FilterTagCreateService) Create(c *gin.Context, user *model.User) serializer.Response { 62 | // 分割表达式,将通配符转换为SQL内的% 63 | expressions := strings.Split(service.Expression, "\n") 64 | for i := 0; i < len(expressions); i++ { 65 | expressions[i] = strings.ReplaceAll(expressions[i], "*", "%") 66 | if expressions[i] == "" { 67 | return serializer.ParamErr(fmt.Sprintf("第 %d 行包含空的匹配表达式", i+1), nil) 68 | } 69 | } 70 | 71 | // 创建标签 72 | tag := model.Tag{ 73 | Name: service.Name, 74 | Icon: service.Icon, 75 | Color: service.Color, 76 | Type: model.FileTagType, 77 | Expression: strings.Join(expressions, "\n"), 78 | UserID: user.ID, 79 | } 80 | id, err := tag.Create() 81 | if err != nil { 82 | return serializer.Err(serializer.CodeDBError, "标签创建失败", err) 83 | } 84 | 85 | return serializer.Response{ 86 | Data: hashid.HashID(id, hashid.TagID), 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /service/explorer/upload.go: -------------------------------------------------------------------------------- 1 | package explorer 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/cloudreve/Cloudreve/v3/pkg/filesystem" 7 | "github.com/cloudreve/Cloudreve/v3/pkg/filesystem/fsctx" 8 | "github.com/cloudreve/Cloudreve/v3/pkg/serializer" 9 | "github.com/gin-gonic/gin" 10 | ) 11 | 12 | // UploadCredentialService 获取上传凭证服务 13 | type UploadCredentialService struct { 14 | Path string `form:"path" binding:"required"` 15 | Size uint64 `form:"size" binding:"min=0"` 16 | Name string `form:"name"` 17 | Type string `form:"type"` 18 | } 19 | 20 | // Get 获取新的上传凭证 21 | func (service *UploadCredentialService) Get(ctx context.Context, c *gin.Context) serializer.Response { 22 | // 创建文件系统 23 
| fs, err := filesystem.NewFileSystemFromContext(c) 24 | if err != nil { 25 | return serializer.Err(serializer.CodePolicyNotAllowed, err.Error(), err) 26 | } 27 | 28 | // 存储策略是否一致 29 | if service.Type != "" { 30 | if service.Type != fs.User.Policy.Type { 31 | return serializer.Err(serializer.CodePolicyNotAllowed, "存储策略已变更,请刷新页面", nil) 32 | } 33 | } 34 | 35 | ctx = context.WithValue(ctx, fsctx.GinCtx, c) 36 | credential, err := fs.GetUploadToken(ctx, service.Path, service.Size, service.Name) 37 | if err != nil { 38 | return serializer.Err(serializer.CodeNotSet, err.Error(), err) 39 | } 40 | 41 | return serializer.Response{ 42 | Code: 0, 43 | Data: credential, 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /service/setting/webdav.go: -------------------------------------------------------------------------------- 1 | package setting 2 | 3 | import ( 4 | model "github.com/cloudreve/Cloudreve/v3/models" 5 | "github.com/cloudreve/Cloudreve/v3/pkg/serializer" 6 | "github.com/cloudreve/Cloudreve/v3/pkg/util" 7 | "github.com/gin-gonic/gin" 8 | ) 9 | 10 | // WebDAVListService WebDAV 列表服务 11 | type WebDAVListService struct { 12 | } 13 | 14 | // WebDAVAccountService WebDAV 账号管理服务 15 | type WebDAVAccountService struct { 16 | ID uint `uri:"id" binding:"required,min=1"` 17 | } 18 | 19 | // WebDAVAccountCreateService WebDAV 账号创建服务 20 | type WebDAVAccountCreateService struct { 21 | Path string `json:"path" binding:"required,min=1,max=65535"` 22 | Name string `json:"name" binding:"required,min=1,max=255"` 23 | } 24 | 25 | // WebDAVMountCreateService WebDAV 挂载创建服务 26 | type WebDAVMountCreateService struct { 27 | Path string `json:"path" binding:"required,min=1,max=65535"` 28 | Policy string `json:"policy" binding:"required,min=1"` 29 | } 30 | 31 | // Create 创建WebDAV账户 32 | func (service *WebDAVAccountCreateService) Create(c *gin.Context, user *model.User) serializer.Response { 33 | account := model.Webdav{ 34 | Name: service.Name, 
35 | Password: util.RandStringRunes(32), 36 | UserID: user.ID, 37 | Root: service.Path, 38 | } 39 | 40 | if _, err := account.Create(); err != nil { 41 | return serializer.Err(serializer.CodeDBError, "创建失败", err) 42 | } 43 | 44 | return serializer.Response{ 45 | Data: map[string]interface{}{ 46 | "id": account.ID, 47 | "password": account.Password, 48 | "created_at": account.CreatedAt, 49 | }, 50 | } 51 | } 52 | 53 | // Delete 删除WebDAV账户 54 | func (service *WebDAVAccountService) Delete(c *gin.Context, user *model.User) serializer.Response { 55 | model.DeleteWebDAVAccountByID(service.ID, user.ID) 56 | return serializer.Response{} 57 | } 58 | 59 | // Accounts 列出WebDAV账号 60 | func (service *WebDAVListService) Accounts(c *gin.Context, user *model.User) serializer.Response { 61 | accounts := model.ListWebDAVAccounts(user.ID) 62 | 63 | return serializer.Response{Data: map[string]interface{}{ 64 | "accounts": accounts, 65 | }} 66 | } 67 | --------------------------------------------------------------------------------