├── .gitignore ├── pkg ├── pki │ ├── test │ │ ├── .gitignore │ │ └── dir_keystorage │ │ │ ├── bad_path │ │ │ ├── exist.pem │ │ │ ├── wrong_serial │ │ │ ├── good_cert │ │ │ ├── 42.key │ │ │ └── .gitignore │ │ │ ├── bad_key │ │ │ ├── .gitignore │ │ │ └── 42.key │ │ │ │ └── .gitignore │ │ │ ├── crl.dir │ │ │ └── .gitignore │ │ │ ├── bad_cert │ │ │ └── 42.crt │ │ │ │ └── .gitignore │ │ │ └── good_crl.pem │ ├── struct.go │ ├── options.go │ ├── options_test.go │ ├── pki_test.go │ └── pki.go └── pair │ └── pair.go ├── internal └── fsStorage │ ├── test │ ├── .gitignore │ └── dir_keystorage │ │ ├── bad_path │ │ ├── exist.pem │ │ ├── wrong_serial │ │ ├── good_cert │ │ ├── 42.key │ │ └── .gitignore │ │ ├── bad_key │ │ ├── .gitignore │ │ └── 42.key │ │ │ └── .gitignore │ │ ├── crl.dir │ │ └── .gitignore │ │ ├── bad_cert │ │ └── 42.crt │ │ │ └── .gitignore │ │ └── good_crl.pem │ ├── storage.go │ └── storage_test.go ├── cmd └── easyrsa │ ├── main.go │ └── cmd.go ├── .coveralls.yml ├── .travis.yml ├── go.mod ├── readme.md ├── LICENSE ├── .github └── workflows │ └── test.yml └── go.sum /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | -------------------------------------------------------------------------------- /pkg/pki/test/.gitignore: -------------------------------------------------------------------------------- 1 | *.lock 2 | -------------------------------------------------------------------------------- /pkg/pki/test/dir_keystorage/bad_path: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pkg/pki/test/dir_keystorage/exist.pem: -------------------------------------------------------------------------------- 1 | asd -------------------------------------------------------------------------------- /internal/fsStorage/test/.gitignore: 
-------------------------------------------------------------------------------- 1 | *.lock 2 | -------------------------------------------------------------------------------- /internal/fsStorage/test/dir_keystorage/bad_path: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pkg/pki/test/dir_keystorage/wrong_serial: -------------------------------------------------------------------------------- 1 | gggg -------------------------------------------------------------------------------- /internal/fsStorage/test/dir_keystorage/exist.pem: -------------------------------------------------------------------------------- 1 | asd -------------------------------------------------------------------------------- /pkg/pki/test/dir_keystorage/good_cert/42.key: -------------------------------------------------------------------------------- 1 | keybytes -------------------------------------------------------------------------------- /internal/fsStorage/test/dir_keystorage/wrong_serial: -------------------------------------------------------------------------------- 1 | gggg -------------------------------------------------------------------------------- /internal/fsStorage/test/dir_keystorage/good_cert/42.key: -------------------------------------------------------------------------------- 1 | keybytes -------------------------------------------------------------------------------- /cmd/easyrsa/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | // main is the easyrsa CLI entry point; it delegates to Execute, which runs the cobra root command. 4 | func main() { 5 | Execute() 6 | } 7 | -------------------------------------------------------------------------------- /.coveralls.yml: -------------------------------------------------------------------------------- 1 | service_name: travis-ci 2 | # NOTE(review): repo_token is a secret credential; consider supplying it from a CI secret instead of committing it. 3 | repo_token: 1jWnoKMVSPyJPPIQ5rEZLhYtJBA4GJ5Xc
-------------------------------------------------------------------------------- /pkg/pki/test/dir_keystorage/bad_key/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | 42.crt 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /pkg/pki/test/dir_keystorage/crl.dir/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | .gitignore 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /pkg/pki/test/dir_keystorage/good_cert/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | 42.crt 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /internal/fsStorage/test/dir_keystorage/bad_key/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | 42.crt 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /internal/fsStorage/test/dir_keystorage/crl.dir/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | .gitignore 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /internal/fsStorage/test/dir_keystorage/good_cert/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | 42.crt 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- 
/pkg/pki/test/dir_keystorage/bad_cert/42.crt/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | .gitignore 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /pkg/pki/test/dir_keystorage/bad_key/42.key/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | .gitignore 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /internal/fsStorage/test/dir_keystorage/bad_cert/42.crt/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | .gitignore 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /internal/fsStorage/test/dir_keystorage/bad_key/42.key/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | .gitignore 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | dist: focal 4 | 5 | go: 6 | - "1.18" 7 | 8 | before_script: 9 | - go get github.com/golangci/golangci-lint/cmd/golangci-lint 10 | - go get golang.org/x/tools/cmd/cover 11 | - go get github.com/mattn/goveralls 12 | 13 | script: 14 | - golangci-lint run 15 | - go test -v -covermode=count -coverprofile=coverage.out ./... 
16 | - goveralls -coverprofile=coverage.out 17 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/kemsta/go-easyrsa 2 | 3 | go 1.18 4 | 5 | require ( 6 | github.com/gofrs/flock v0.8.1 7 | github.com/stretchr/testify v1.7.1 8 | ) 9 | 10 | require ( 11 | github.com/inconshreveable/mousetrap v1.0.1 // indirect 12 | github.com/spf13/pflag v1.0.5 // indirect 13 | ) 14 | 15 | require ( 16 | github.com/davecgh/go-spew v1.1.1 // indirect 17 | github.com/kr/text v0.2.0 // indirect 18 | github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect 19 | github.com/pmezard/go-difflib v1.0.0 // indirect 20 | github.com/spf13/cobra v1.5.0 21 | golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab // indirect 22 | gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect 23 | gopkg.in/yaml.v3 v3.0.1 // indirect 24 | ) 25 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # EasyRsa 2 | [![Build Status](https://github.com/kemsta/go-easyrsa/actions/workflows/test.yml/badge.svg)](https://github.com/kemsta/go-easyrsa/actions/workflows/test.yml) 3 | [![Coverage Status](https://coveralls.io/repos/github/kemsta/go-easyrsa/badge.svg?branch=master)](https://coveralls.io/github/kemsta/go-easyrsa?branch=master) 4 | [![GoDoc](https://godoc.org/github.com/kemsta/go-easyrsa?status.svg)](https://godoc.org/github.com/kemsta/go-easyrsa) 5 | 6 | A simple Go implementation of some [easy-rsa](https://github.com/OpenVPN/easy-rsa) functions. 7 | 8 | ## cli usage examples 9 | 10 | go install github.com/kemsta/go-easyrsa/cmd/easyrsa@latest 11 | 12 | ### build ca pair 13 | easyrsa -k keys build-ca 14 | 15 | ### build server pair 16 | easyrsa -k keys build-server-key some-server-name 17 | 18 | ### build client pair 19 |
easyrsa -k keys build-key some-client-name 20 | 21 | ### revoke cert 22 | easyrsa -k keys revoke-full some-client-name 23 | -------------------------------------------------------------------------------- /pkg/pki/test/dir_keystorage/good_crl.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN X509 CRL----- 2 | MIICtDCCAZwCAQEwDQYJKoZIhvcNAQELBQAwejELMAkGA1UEBhMCVVMxCzAJBgNV 3 | BAgTAkNBMRUwEwYDVQQHEwxTYW5GcmFuY2lzY28xFTATBgNVBAoTDEZvcnQtRnVu 4 | c3RvbjEdMBsGA1UECxMUTXlPcmdhbml6YXRpb25hbFVuaXQxETAPBgNVBAMTCGNp 5 | cy10ZWNoFw0xOTAzMDUyMDE3NDVaGA8yMTE4MDIwOTIwMTc0NVowgcYwEwICAMAX 6 | DTE5MDMwNTIwMTc0NFowIQIQNLLBhmHbiMpk16YGY5/WURcNMTkwMzA1MjAxNzQ0 7 | WjAiAhEAmaZnjV9u+iMdSHzCvKLSexcNMTkwMzA1MjAxNzQ0WjAhAhBQ4RT6IMlR 8 | 3Z5Wt6GzxP0cFw0xOTAzMDUyMDE3NDRaMCICEQDro9t1jxTzq69UkEGYREGNFw0x 9 | OTAzMDUyMDE3NDRaMCECEDIV2PgTW626L+aXdzDACH8XDTE5MDMwNTIwMTc0NVqg 10 | IzAhMB8GA1UdIwQYMBaAFN8cv5d1Pe+a5UGZ3gk8IRDdBsn6MA0GCSqGSIb3DQEB 11 | CwUAA4IBAQAlDfv8MJZqGVshVxSjVHED6sR0NVvsEIX20A3qyjlBZLsOg5HOYir5 12 | Ki2xP8vH06LPgc/8fRiJ4GQNOswgvCIu0MBJcQNYSoloPSFeULFb9xBLH2b2uxRW 13 | gz/giSE2k6JE12DRhkisxzpmEyBlWKglVNG1CgH71ya0+bZEGZ7ExYA58Lxemx6u 14 | WGWElvzxmJZN7xECKY+cq1V5H/Wd1BjsnDDhYqkJW3chhKy+pTnbRt/f39ff3rij 15 | rjYmPuJTYb5JEkNwhUQxxbZIgvs2tFRAbkjrmENvnDT9MLJmMa1Axl/KJTR+RdZv 16 | Mx62V3dP+mYV7PqFbF96S9VDGxzWBETb 17 | -----END X509 CRL----- 18 | -------------------------------------------------------------------------------- /internal/fsStorage/test/dir_keystorage/good_crl.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN X509 CRL----- 2 | MIICtDCCAZwCAQEwDQYJKoZIhvcNAQELBQAwejELMAkGA1UEBhMCVVMxCzAJBgNV 3 | BAgTAkNBMRUwEwYDVQQHEwxTYW5GcmFuY2lzY28xFTATBgNVBAoTDEZvcnQtRnVu 4 | c3RvbjEdMBsGA1UECxMUTXlPcmdhbml6YXRpb25hbFVuaXQxETAPBgNVBAMTCGNp 5 | cy10ZWNoFw0xOTAzMDUyMDE3NDVaGA8yMTE4MDIwOTIwMTc0NVowgcYwEwICAMAX 6 | DTE5MDMwNTIwMTc0NFowIQIQNLLBhmHbiMpk16YGY5/WURcNMTkwMzA1MjAxNzQ0 7 | 
WjAiAhEAmaZnjV9u+iMdSHzCvKLSexcNMTkwMzA1MjAxNzQ0WjAhAhBQ4RT6IMlR 8 | 3Z5Wt6GzxP0cFw0xOTAzMDUyMDE3NDRaMCICEQDro9t1jxTzq69UkEGYREGNFw0x 9 | OTAzMDUyMDE3NDRaMCECEDIV2PgTW626L+aXdzDACH8XDTE5MDMwNTIwMTc0NVqg 10 | IzAhMB8GA1UdIwQYMBaAFN8cv5d1Pe+a5UGZ3gk8IRDdBsn6MA0GCSqGSIb3DQEB 11 | CwUAA4IBAQAlDfv8MJZqGVshVxSjVHED6sR0NVvsEIX20A3qyjlBZLsOg5HOYir5 12 | Ki2xP8vH06LPgc/8fRiJ4GQNOswgvCIu0MBJcQNYSoloPSFeULFb9xBLH2b2uxRW 13 | gz/giSE2k6JE12DRhkisxzpmEyBlWKglVNG1CgH71ya0+bZEGZ7ExYA58Lxemx6u 14 | WGWElvzxmJZN7xECKY+cq1V5H/Wd1BjsnDDhYqkJW3chhKy+pTnbRt/f39ff3rij 15 | rjYmPuJTYb5JEkNwhUQxxbZIgvs2tFRAbkjrmENvnDT9MLJmMa1Axl/KJTR+RdZv 16 | Mx62V3dP+mYV7PqFbF96S9VDGxzWBETb 17 | -----END X509 CRL----- 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Stanislav Kem 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /pkg/pki/struct.go: -------------------------------------------------------------------------------- 1 | package pki 2 | 3 | import ( 4 | "crypto/x509/pkix" 5 | "github.com/kemsta/go-easyrsa/pkg/pair" 6 | "math/big" 7 | ) 8 | 9 | // KeyStorage stores X509Pair values and looks them up or deletes them by CN or serial. 10 | type KeyStorage interface { 11 | Put(pair *pair.X509Pair) error // Put new pair to KeyStorage. Overwrites it if it already exists. 12 | GetByCN(cn string) ([]*pair.X509Pair, error) // Get all keypairs by CN. 13 | GetLastByCn(cn string) (*pair.X509Pair, error) // Get last pair by CN. 14 | GetBySerial(serial *big.Int) (*pair.X509Pair, error) // Get one keypair by serial. 15 | DeleteByCn(cn string) error // Delete all keypairs by CN. 16 | DeleteBySerial(serial *big.Int) error // Delete one keypair by serial.
17 | GetAll() ([]*pair.X509Pair, error) // Get all keypairs. 18 | } 19 | 20 | // SerialProvider supplies serial numbers for new certificates. 21 | type SerialProvider interface { 22 | Next() (*big.Int, error) // Next returns the next unique serial. 23 | } 24 | 25 | // CRLHolder stores and loads the certificate revocation list. 26 | type CRLHolder interface { 27 | Put([]byte) error // Put file content for crl 28 | Get() (*pkix.CertificateList, error) // Get current revoked cert list 29 | } 30 | -------------------------------------------------------------------------------- /pkg/pair/pair.go: -------------------------------------------------------------------------------- 1 | package pair 2 | 3 | import ( 4 | "crypto/rsa" 5 | "crypto/x509" 6 | "encoding/pem" 7 | "fmt" 8 | "math/big" 9 | ) 10 | 11 | // X509Pair holds a PEM-encoded RSA private key and certificate together with the certificate's CN and serial. 12 | type X509Pair struct { 13 | KeyPemBytes []byte // pem encoded rsa.PrivateKey bytes 14 | CertPemBytes []byte // pem encoded x509.Certificate bytes 15 | CN string // common name 16 | Serial *big.Int // serial number 17 | } 18 | 19 | // Decode parses KeyPemBytes as a PKCS#1 RSA private key and CertPemBytes as an X.509 certificate. Error messages never echo the input bytes, so private key material does not leak into logs. 20 | func (pair *X509Pair) Decode() (key *rsa.PrivateKey, cert *x509.Certificate, err error) { 21 | block, _ := pem.Decode(pair.KeyPemBytes) 22 | if block == nil { 23 | return nil, nil, fmt.Errorf("can`t parse key: no PEM block found") 24 | } 25 | 26 | key, err = x509.ParsePKCS1PrivateKey(block.Bytes) 27 | if err != nil { 28 | return nil, nil, fmt.Errorf("can`t parse key: %w", err) 29 | } 30 | 31 | block, _ = pem.Decode(pair.CertPemBytes) 32 | if block == nil { 33 | return nil, nil, fmt.Errorf("can`t parse cert: no PEM block found") 34 | } 35 | cert, err = x509.ParseCertificate(block.Bytes) 36 | if err != nil { 37 | return nil, nil, fmt.Errorf("can`t parse cert: %w", err) 38 | } 39 | return 40 | } 41 | 42 | // NewX509Pair creates an X509Pair from PEM-encoded key and certificate bytes, a common name and a serial. 43 | func NewX509Pair(keyPemBytes []byte, certPemBytes []byte, CN string, serial
*big.Int) *X509Pair { 44 | return &X509Pair{KeyPemBytes: keyPemBytes, CertPemBytes: certPemBytes, CN: CN, Serial: serial} 45 | } 46 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | on: [push, pull_request] 2 | name: Test 3 | jobs: 4 | test: 5 | strategy: 6 | matrix: 7 | go-version: [1.17.x, 1.18.x] 8 | os: [ubuntu-latest, macos-latest, windows-latest] 9 | runs-on: ${{ matrix.os }} 10 | steps: 11 | - uses: actions/setup-go@v3 12 | with: 13 | go-version: ${{ matrix.go-version }} 14 | - uses: actions/checkout@v3 15 | - uses: actions/cache@v2 16 | with: 17 | path: ~/go/pkg/mod 18 | key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} 19 | restore-keys: | 20 | ${{ runner.os }}-go- 21 | - run: go test ./... 22 | golangci: 23 | name: lint 24 | runs-on: ubuntu-latest 25 | steps: 26 | - uses: actions/setup-go@v3 27 | with: 28 | go-version: 1.18 29 | - uses: actions/checkout@v3 30 | - name: golangci-lint 31 | uses: golangci/golangci-lint-action@v3 32 | with: 33 | version: latest 34 | coverage: 35 | runs-on: ubuntu-latest 36 | steps: 37 | - name: Install Go 38 | if: success() 39 | uses: actions/setup-go@v2 40 | with: 41 | go-version: 1.17.x 42 | - name: Checkout code 43 | uses: actions/checkout@v2 44 | - name: Calc coverage 45 | run: | 46 | go test -v ./...
-covermode=count -coverprofile=coverage.out 47 | - name: Convert coverage.out to coverage.lcov 48 | uses: jandelgado/gcov2lcov-action@v1.0.6 49 | - name: Coveralls 50 | uses: coverallsapp/github-action@v1.1.2 51 | with: 52 | github-token: ${{ secrets.github_token }} 53 | path-to-lcov: coverage.lcov 54 | -------------------------------------------------------------------------------- /pkg/pki/options.go: -------------------------------------------------------------------------------- 1 | package pki 2 | 3 | import ( 4 | "crypto/x509" 5 | "crypto/x509/pkix" 6 | "encoding/asn1" 7 | "net" 8 | "time" 9 | ) 10 | 11 | // Option modifies an *x509.Certificate; options are applied in order by Apply. 12 | type Option func(*x509.Certificate) 13 | 14 | // Apply calls every option in options on cert, in order. 15 | func Apply(options []Option, cert *x509.Certificate) { 16 | for _, option := range options { 17 | option(cert) 18 | } 19 | } 20 | 21 | // CN sets the certificate subject common name. 22 | func CN(cn string) Option { 23 | return func(certificate *x509.Certificate) { 24 | certificate.Subject.CommonName = cn 25 | } 26 | } 27 | 28 | // Server sets server key usages and the ServerAuth extended key usage, and appends an nsCertType extension marking a server certificate. 29 | func Server() Option { 30 | return func(certificate *x509.Certificate) { 31 | certificate.KeyUsage = x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment 32 | certificate.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth} 33 | appendNsCertType(certificate, 0x40) // setting nsCertType to Server Type 34 | } 35 | } 36 | 37 | // Client sets client key usages and the ClientAuth extended key usage, and appends an nsCertType extension marking a client certificate. 38 | func Client() Option { 39 | return func(certificate *x509.Certificate) { 40 | certificate.KeyUsage = x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement 41 | certificate.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth} 42 | appendNsCertType(certificate, 0x80) // setting nsCertType to Client Type 43 | } 44 | } 45 | 46 | // appendNsCertType appends the Netscape cert type extension (OID 2.16.840.1.113730.1.1), 47 | // encoded as a 2-bit BIT STRING whose first byte is bits. It is shared by Server and Client. 48 | func appendNsCertType(certificate *x509.Certificate, bits byte) { 49 | // Marshaling this fixed BitString is not expected to fail, so the error is ignored. 50 | val, _ := asn1.Marshal(asn1.BitString{Bytes: []byte{bits}, BitLength: 2}) 51 | certificate.ExtraExtensions = append(certificate.ExtraExtensions, pkix.Extension{Id: asn1.ObjectIdentifier{2, 16, 840, 1, 113730, 1, 1}, Value: val}) 52 | } 53 | 54 | // DNSNames sets the subject alternative DNS names; the slice is stored without copying. 55 | func DNSNames(names []string) Option { 56 | return func(certificate *x509.Certificate) { 57 | certificate.DNSNames = names 58 | } 59 | } 60 | 61 | // IPAddresses sets the subject alternative IP addresses; the slice is stored without copying. 62 | func IPAddresses(ips []net.IP) Option { 63 | return func(certificate *x509.Certificate) { 64 | certificate.IPAddresses = ips 65 | } 66 | } 67 | 68 | // ExcludedDNSDomains sets the certificate's excluded DNS domains name constraint. 69 | func ExcludedDNSDomains(names []string) Option { 70 | return func(certificate *x509.Certificate) { 71 | certificate.ExcludedDNSDomains = names 72 | } 73 | } 74 | 75 | // NotAfter sets the certificate expiry time. 76 | func NotAfter(notAfter time.Time) Option { 77 | return func(certificate *x509.Certificate) { 78 | certificate.NotAfter = notAfter 79 | } 80 | } 81 |
-------------------------------------------------------------------------------- /cmd/easyrsa/cmd.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "github.com/kemsta/go-easyrsa/pkg/pki" 6 | "github.com/spf13/cobra" 7 | "log" 8 | "net" 9 | "os" 10 | ) 11 | 12 | var keyDir string 13 | var pkiI *pki.PKI 14 | var serverDNSNames []string 15 | var serverIPs []net.IP 16 | 17 | // rootCmd initialises pkiI from --key-dir (see getPki) before any subcommand runs. 18 | var rootCmd = &cobra.Command{ 19 | Use: "easyrsa", 20 | PersistentPreRun: func(cmd *cobra.Command, args []string) { 21 | // cobra validates Args before this hook, so failures from here on are operational: don't print usage for them. 22 | cmd.SilenceUsage = true 23 | var err error 24 | pkiI, err = getPki() 25 | if err != nil { 26 | log.Fatal(err) 27 | } 28 | }, 29 | } 30 | 31 | // Execute runs the root command and exits with status 1 on failure. 32 | // cobra has already reported the error on stderr, so it is not printed a second time. 33 | func Execute() { 34 | if err := rootCmd.Execute(); err != nil { 35 | os.Exit(1) 36 | } 37 | } 38 | 39 | var buildCa = &cobra.Command{ 40 | Use: "build-ca [CN]", 41 | Short: "build ca cert/key with optional CN", 42 | RunE: func(cmd *cobra.Command, args []string) error { 43 | var options []pki.Option 44 | if len(args) > 0 { 45 | options = append(options, pki.CN(args[0])) 46 | } 47 | if _, err := pkiI.NewCa(options...); err != nil { 48 | return fmt.Errorf("can`t build ca pair: %w", err) 49 | } 50 | return nil 51 | }, 52 | } 53 | 54 | var buildServerKey = &cobra.Command{ 55 | Use: "build-server-key CN", 56 | Short: "build server cert/key with CN", 57 | Args: cobra.MinimumNArgs(1), 58 | RunE: func(cmd *cobra.Command, args []string) error { 59 | options := []pki.Option{pki.Server()} 60 | if serverDNSNames != nil { 61 | options = append(options, pki.DNSNames(serverDNSNames)) 62 | } 63 | if serverIPs != nil { 64 | options = append(options, pki.IPAddresses(serverIPs)) 65 | } 66 | if _, err := pkiI.NewCert(args[0], options...); err != nil { 67 | return fmt.Errorf("can`t build server pair: %w", err) 68 | } 69 | return nil 70 | }, 71 | } 72 | 73 | var buildKey = &cobra.Command{ 74 | Use: "build-key CN", 75 | Short: "build client cert/key with CN", 76 | Args: cobra.MinimumNArgs(1), 77 | RunE: func(cmd *cobra.Command, args []string) error { 78 | if _, err := pkiI.NewCert(args[0], pki.Client()); err != nil { 79 | return fmt.Errorf("can`t build client pair: %w", err) 80 | } 81 | return nil 82 | }, 83 | } 84 | 85 | var revokeFull = &cobra.Command{ 86 | Use: "revoke-full CN", 87 | Short: "revoke cert with CN", 88 | Args: cobra.MinimumNArgs(1), 89 | RunE: func(cmd *cobra.Command, args []string) error { 90 | if err := pkiI.RevokeAllByCN(args[0]); err != nil { 91 | return fmt.Errorf("can`t revoke cert: %w", err) 92 | } 93 | return nil 94 | }, 95 | } 96 | 97 | func init() { 98 | rootCmd.PersistentFlags().StringVarP(&keyDir, "key-dir", "k", "keys", "") 99 | buildServerKey.Flags().StringArrayVarP(&serverDNSNames, "dns", "n", nil, "server dns names") 100 | buildServerKey.Flags().IPSliceVarP(&serverIPs, "ip", "i", nil, "server ip addresses") 101 | rootCmd.AddCommand(buildCa) 102 | rootCmd.AddCommand(buildServerKey) 103 | rootCmd.AddCommand(buildKey) 104 | rootCmd.AddCommand(revokeFull) 105 | } 106 | 107 | // getPki returns pki.InitPKI(keyDir, nil) for the directory given by --key-dir. 108 | func getPki() (*pki.PKI, error) { 109 | return pki.InitPKI(keyDir, nil) 110 | } 111 |
-------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= 2 | github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= 3 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 5 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 6 | github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw= 7 | github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= 8 | github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= 9 | github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= 10 | github.com/inconshreveable/mousetrap v1.0.1 h1:U3uMjPSQEBMNp1lFxmllqCPM6P5u/Xq7Pgzkat/bFNc= 11 | github.com/inconshreveable/mousetrap v1.0.1/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= 12 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 13 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 14 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 15 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 16 | github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= 17 | github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= 18 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 19 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 20 | 
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= 21 | github.com/spf13/cobra v1.5.0 h1:X+jTBEBqF0bHN+9cSMgmfuvv2VHJ9ezmFNf9Y/XstYU= 22 | github.com/spf13/cobra v1.5.0/go.mod h1:dWXEIy2H428czQCjInthrTRUg7yKbok+2Qi/yBIJoUM= 23 | github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= 24 | github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= 25 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 26 | github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= 27 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 28 | golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a h1:dGzPydgVsqGcTRVwiLJ1jVbufYwmzD3LfVPLKsKg+0k= 29 | golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 30 | golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab h1:2QkjZIsXupsJbJIdSjjUOgWK3aEtzyuh2mPt3l/CkeU= 31 | golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 32 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 33 | gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= 34 | gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 35 | gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= 36 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 37 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 38 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 39 | -------------------------------------------------------------------------------- /pkg/pki/options_test.go: 
-------------------------------------------------------------------------------- 1 | package pki 2 | 3 | import ( 4 | "crypto/x509" 5 | "crypto/x509/pkix" 6 | "encoding/asn1" 7 | "net" 8 | "reflect" 9 | "testing" 10 | "time" 11 | ) 12 | 13 | func TestCN(t *testing.T) { 14 | type args struct { 15 | cn string 16 | } 17 | tests := []struct { 18 | name string 19 | args args 20 | want *x509.Certificate 21 | }{ 22 | { 23 | name: "change cn", 24 | args: args{ 25 | cn: "changed", 26 | }, 27 | want: &x509.Certificate{Subject: pkix.Name{CommonName: "changed"}}, 28 | }, 29 | } 30 | for _, tt := range tests { 31 | t.Run(tt.name, func(t *testing.T) { 32 | cert := &x509.Certificate{} 33 | if CN(tt.args.cn)(cert); !reflect.DeepEqual(cert, tt.want) { 34 | t.Errorf("CN() = %v, want %v", cert, tt.want) 35 | } 36 | }) 37 | } 38 | } 39 | 40 | func TestDNSNames(t *testing.T) { 41 | type args struct { 42 | names []string 43 | } 44 | tests := []struct { 45 | name string 46 | args args 47 | want *x509.Certificate 48 | }{ 49 | { 50 | name: "changed", 51 | args: args{ 52 | names: []string{"first", "second"}, 53 | }, 54 | want: &x509.Certificate{DNSNames: []string{"first", "second"}}, 55 | }, 56 | } 57 | for _, tt := range tests { 58 | t.Run(tt.name, func(t *testing.T) { 59 | cert := &x509.Certificate{} 60 | if DNSNames(tt.args.names)(cert); !reflect.DeepEqual(cert, tt.want) { 61 | t.Errorf("DNSNames() = %v, want %v", cert, tt.want) 62 | } 63 | }) 64 | } 65 | } 66 | 67 | func TestExcludedDNSDomains(t *testing.T) { 68 | type args struct { 69 | names []string 70 | } 71 | tests := []struct { 72 | name string 73 | args args 74 | want *x509.Certificate 75 | }{ 76 | { 77 | name: "changed", 78 | args: args{ 79 | names: []string{"first", "second"}, 80 | }, 81 | want: &x509.Certificate{ExcludedDNSDomains: []string{"first", "second"}}, 82 | }, 83 | } 84 | for _, tt := range tests { 85 | t.Run(tt.name, func(t *testing.T) { 86 | cert := &x509.Certificate{} 87 | if
ExcludedDNSDomains(tt.args.names)(cert); !reflect.DeepEqual(cert, tt.want) { 88 | t.Errorf("ExcludedDNSDomains() = %v, want %v", cert, tt.want) 89 | } 90 | }) 91 | } 92 | } 93 | 94 | func TestIPAddresses(t *testing.T) { 95 | type args struct { 96 | ips []net.IP 97 | } 98 | tests := []struct { 99 | name string 100 | args args 101 | want *x509.Certificate 102 | }{ 103 | { 104 | name: "changed", 105 | args: args{ 106 | ips: []net.IP{{127, 0, 0, 1}}, 107 | }, 108 | want: &x509.Certificate{IPAddresses: []net.IP{{127, 0, 0, 1}}}, 109 | }, 110 | } 111 | for _, tt := range tests { 112 | t.Run(tt.name, func(t *testing.T) { 113 | cert := &x509.Certificate{} 114 | if IPAddresses(tt.args.ips)(cert); !reflect.DeepEqual(cert, tt.want) { 115 | t.Errorf("IPAddresses() = %v, want %v", cert, tt.want) 116 | } 117 | }) 118 | } 119 | } 120 | 121 | func TestServer(t *testing.T) { 122 | want := &x509.Certificate{} 123 | want.KeyUsage = x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement | x509.KeyUsageKeyEncipherment 124 | want.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth} 125 | val, _ := asn1.Marshal(asn1.BitString{Bytes: []byte{0x40}, BitLength: 2}) // setting nsCertType to Server Type 126 | want.ExtraExtensions = []pkix.Extension{} 127 | want.ExtraExtensions = append(want.ExtraExtensions, pkix.Extension{Id: asn1.ObjectIdentifier{2, 16, 840, 1, 113730, 1, 1}, Value: val}) 128 | tests := []struct { 129 | name string 130 | want *x509.Certificate 131 | }{ 132 | { 133 | name: "changed", 134 | want: want, 135 | }, 136 | } 137 | for _, tt := range tests { 138 | t.Run(tt.name, func(t *testing.T) { 139 | cert := &x509.Certificate{} 140 | if Server()(cert); !reflect.DeepEqual(cert, tt.want) { 141 | t.Errorf("Server() = %v, want %v", cert, tt.want) 142 | } 143 | }) 144 | } 145 | } 146 | 147 | func TestClient(t *testing.T) { 148 | want := &x509.Certificate{} 149 | want.KeyUsage = x509.KeyUsageDigitalSignature | x509.KeyUsageKeyAgreement 150 | want.ExtKeyUsage =
[]x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth} 151 | val, _ := asn1.Marshal(asn1.BitString{Bytes: []byte{0x80}, BitLength: 2}) // setting nsCertType to Client Type 152 | want.ExtraExtensions = []pkix.Extension{} 153 | want.ExtraExtensions = append(want.ExtraExtensions, pkix.Extension{Id: asn1.ObjectIdentifier{2, 16, 840, 1, 113730, 1, 1}, Value: val}) 154 | tests := []struct { 155 | name string 156 | want *x509.Certificate 157 | }{ 158 | { 159 | name: "changed", 160 | want: want, 161 | }, 162 | } 163 | for _, tt := range tests { 164 | t.Run(tt.name, func(t *testing.T) { 165 | cert := &x509.Certificate{} 166 | if Client()(cert); !reflect.DeepEqual(cert, tt.want) { 167 | t.Errorf("Client() = %v, want %v", cert, tt.want) 168 | } 169 | }) 170 | } 171 | } 172 | 173 | func TestNotAfter(t *testing.T) { 174 | type args struct { 175 | time time.Time 176 | } 177 | tests := []struct { 178 | name string 179 | args args 180 | want *x509.Certificate 181 | }{ 182 | { 183 | name: "changed", 184 | args: args{ 185 | time: time.Unix(100000, 0), 186 | }, 187 | want: &x509.Certificate{NotAfter: time.Unix(100000, 0)}, 188 | }, 189 | } 190 | for _, tt := range tests { 191 | t.Run(tt.name, func(t *testing.T) { 192 | cert := &x509.Certificate{} 193 | if NotAfter(tt.args.time)(cert); !reflect.DeepEqual(cert, tt.want) { 194 | t.Errorf("NotAfter() = %v, want %v", cert, tt.want) 195 | } 196 | }) 197 | } 198 | } 199 | -------------------------------------------------------------------------------- /pkg/pki/pki_test.go: -------------------------------------------------------------------------------- 1 | package pki 2 | 3 | import ( 4 | "crypto/x509/pkix" 5 | "github.com/kemsta/go-easyrsa/internal/fsStorage" 6 | "log" 7 | "math/big" 8 | "os" 9 | "path" 10 | "path/filepath" 11 | "reflect" 12 | "testing" 13 | 14 | "github.com/stretchr/testify/assert" 15 | ) 16 | 17 | var testData = "test/pki/" 18 | 19 | func TestPki_NewCa(t *testing.T) { 20 | pki, cleanup := getTmpPki() 21 | defer cleanup() 22 |
t.Run("create ca and write", func(t *testing.T) { 23 | got, err := pki.NewCa() 24 | assert.NoError(t, err) 25 | assert.NotNil(t, got) 26 | assert.NotEmpty(t, got.CertPemBytes) 27 | assert.NotEmpty(t, got.KeyPemBytes) 28 | }) 29 | t.Run("get ca by cn", func(t *testing.T) { 30 | got, err := pki.Storage.GetByCN("ca") 31 | assert.NoError(t, err) 32 | assert.NotNil(t, got) 33 | assert.Len(t, got, 1) 34 | }) 35 | t.Run("get ca by serial", func(t *testing.T) { 36 | got, err := pki.Storage.GetBySerial(big.NewInt(1)) 37 | assert.NoError(t, err) 38 | assert.NotNil(t, got) 39 | assert.NotEmpty(t, got.CertPemBytes) 40 | assert.NotEmpty(t, got.KeyPemBytes) 41 | }) 42 | t.Run("decode ca", func(t *testing.T) { 43 | ca, _ := pki.Storage.GetByCN("ca") 44 | key, cert, err := ca[0].Decode() 45 | assert.NoError(t, err) 46 | assert.NotNil(t, key) 47 | assert.NotNil(t, cert) 48 | assert.Equal(t, cert.SerialNumber, big.NewInt(1)) 49 | assert.True(t, cert.IsCA) 50 | assert.Equal(t, cert.Subject.CommonName, "ca") 51 | }) 52 | } 53 | 54 | func TestPKI_newCert(t *testing.T) { 55 | pki, cleanup := getTmpPki() 56 | defer cleanup() 57 | _, _ = pki.NewCa() 58 | t.Run("create server cert and write", func(t *testing.T) { 59 | got, err := pki.NewCert("server", Server()) 60 | assert.NoError(t, err) 61 | assert.NotNil(t, got) 62 | assert.NotEmpty(t, got.CertPemBytes) 63 | assert.NotEmpty(t, got.KeyPemBytes) 64 | }) 65 | t.Run("get cert by cn", func(t *testing.T) { 66 | got, err := pki.Storage.GetByCN("server") 67 | assert.NoError(t, err) 68 | assert.NotNil(t, got) 69 | assert.Len(t, got, 1) 70 | }) 71 | t.Run("get cert by serial", func(t *testing.T) { 72 | got, err := pki.Storage.GetBySerial(big.NewInt(2)) 73 | assert.NoError(t, err) 74 | assert.NotNil(t, got) 75 | assert.NotEmpty(t, got.CertPemBytes) 76 | assert.NotEmpty(t, got.KeyPemBytes) 77 | }) 78 | t.Run("decode cert", func(t *testing.T) { 79 | ca, _ := pki.Storage.GetByCN("server") 80 | key, cert, err := ca[0].Decode() 81 | assert.NoError(t,
err) 82 | assert.NotNil(t, key) 83 | assert.NotNil(t, cert) 84 | assert.Equal(t, cert.SerialNumber, big.NewInt(2)) 85 | assert.Equal(t, cert.Subject.CommonName, "server") 86 | }) 87 | } 88 | 89 | // getTmpPki returns a PKI backed by fsStorage under testData plus a cleanup func that removes that directory; it aborts the test binary if the directory can't be prepared. 90 | func getTmpPki() (*PKI, func()) { 91 | storDir, err := filepath.Abs(testData) 92 | if err != nil { 93 | log.Fatalf("can`t resolve test dir %q: %v", testData, err) 94 | } 95 | if err := os.MkdirAll(storDir, 0777); err != nil { 96 | log.Fatalf("can`t create test dir %q: %v", storDir, err) 97 | } 98 | stor := fsStorage.NewDirKeyStorage(storDir) 99 | serialProvider := fsStorage.NewFileSerialProvider(filepath.Join(storDir, "serial")) 100 | crlHolder := fsStorage.NewFileCRLHolder(filepath.Join(storDir, "crl.pem")) 101 | return NewPKI(stor, serialProvider, crlHolder, pkix.Name{}), func() { 102 | _ = os.RemoveAll(storDir) 103 | } 104 | } 105 | 106 | func TestPKI_getCRL(t *testing.T) { 107 | pki, cleanup := getTmpPki() 108 | defer cleanup() 109 | t.Run("get crl", func(t *testing.T) { 110 | list, err := pki.GetCRL() 111 | assert.NoError(t, err) 112 | assert.NotNil(t, list) 113 | }) 114 | } 115 | 116 | func TestPKI_RevokeOne(t *testing.T) { 117 | pki, cleanup := getTmpPki() 118 | defer cleanup() 119 | _, _ = pki.NewCa() 120 | _, _ = pki.NewCert("server", Server()) 121 | _, _ = pki.NewCert("server", Server()) 122 | _, _ = pki.NewCert("cert") 123 | t.Run("revoke", func(t *testing.T) { 124 | err := pki.RevokeOne(big.NewInt(300)) 125 | assert.NoError(t, err) 126 | list, _ := pki.GetCRL() 127 | assert.Equal(t, list.TBSCertList.RevokedCertificates[0].SerialNumber, big.NewInt(300)) 128 | }) 129 | } 130 | 131 | func TestPKI_IsRevoked(t *testing.T) { 132 | pki, cleanup := getTmpPki() 133 | defer cleanup() 134 | _, _ = pki.NewCa() 135 | _, _ = pki.NewCert("server", Server()) 136 | _, _ = pki.NewCert("server", Server()) 137 | _, _ = pki.NewCert("cert") 138 | t.Run("revoke", func(t *testing.T) { 139 | err := pki.RevokeOne(big.NewInt(4)) 140 | assert.NoError(t, err) 141 | assert.True(t, pki.IsRevoked(big.NewInt(4))) 142 | assert.False(t,
pki.IsRevoked(big.NewInt(1))) 143 | assert.False(t, pki.IsRevoked(big.NewInt(42))) 144 | }) 145 | } 146 | 147 | func TestPKI_RevokeAllByCN(t *testing.T) { 148 | pki, cleanup := getTmpPki() 149 | defer cleanup() 150 | _, _ = pki.NewCa() 151 | _, _ = pki.NewCert("server", Server()) 152 | _, _ = pki.NewCert("server", Server()) 153 | _, _ = pki.NewCert("cert") 154 | t.Run("revoke", func(t *testing.T) { 155 | err := pki.RevokeAllByCN("server") 156 | assert.NoError(t, err) 157 | list, _ := pki.GetCRL() 158 | assert.Len(t, list.TBSCertList.RevokedCertificates, 2) 159 | assert.Equal(t, list.TBSCertList.RevokedCertificates[0].SerialNumber, big.NewInt(2)) 160 | }) 161 | } 162 | 163 | func TestPKI_GetLastCA(t *testing.T) { 164 | pki, cleanup := getTmpPki() 165 | defer cleanup() 166 | t.Run("empty ca", func(t *testing.T) { 167 | pair, err := pki.GetLastCA() 168 | assert.Error(t, err) 169 | assert.Nil(t, pair) 170 | }) 171 | t.Run("one ca", func(t *testing.T) { 172 | _, _ = pki.NewCa() 173 | pair, err := pki.GetLastCA() 174 | assert.NoError(t, err) 175 | assert.NotNil(t, pair) 176 | assert.Equal(t, pair.CN, "ca") 177 | assert.Equal(t, pair.Serial, big.NewInt(1)) 178 | }) 179 | t.Run("5 ca", func(t *testing.T) { 180 | _, _ = pki.NewCa() 181 | _, _ = pki.NewCa() 182 | _, _ = pki.NewCa() 183 | _, _ = pki.NewCa() 184 | pair, err := pki.GetLastCA() 185 | assert.NoError(t, err) 186 | assert.NotNil(t, pair) 187 | assert.Equal(t, pair.CN, "ca") 188 | assert.Equal(t, pair.Serial, big.NewInt(5)) 189 | }) 190 | } 191 | 192 | func TestInitPKI(t *testing.T) { 193 | pkiDir := "test/def_pki" 194 | defer func() { 195 | _ = os.RemoveAll(pkiDir) 196 | }() 197 | type args struct { 198 | pkiDir string 199 | } 200 | tests := []struct { 201 | name string 202 | args args 203 | want *PKI 204 | wantErr bool 205 | }{ 206 | { 207 | name: "default pki", 208 | args: args{ 209 | pkiDir: "test/def_pki", 210 | }, 211 | want: &PKI{ 212 | Storage: fsStorage.NewDirKeyStorage(pkiDir), 213 | serialProvider: 
fsStorage.NewFileSerialProvider(path.Join(pkiDir, "serial")), 214 | crlHolder: fsStorage.NewFileCRLHolder(path.Join(pkiDir, "crl.pem")), 215 | subjTemplate: pkix.Name{}, 216 | }, 217 | wantErr: false, 218 | }, 219 | } 220 | for _, tt := range tests { 221 | t.Run(tt.name, func(t *testing.T) { 222 | got, err := InitPKI(tt.args.pkiDir, nil) 223 | if (err != nil) != tt.wantErr { 224 | t.Errorf("InitPKI() error = %v, wantErr %v", err, tt.wantErr) 225 | return 226 | } 227 | if !reflect.DeepEqual(got, tt.want) { 228 | t.Errorf("InitPKI() got = %v, want %v", got, tt.want) 229 | } 230 | }) 231 | } 232 | } 233 | -------------------------------------------------------------------------------- /pkg/pki/pki.go: -------------------------------------------------------------------------------- 1 | package pki 2 | 3 | import ( 4 | "crypto/rand" 5 | "crypto/rsa" 6 | "crypto/x509" 7 | "crypto/x509/pkix" 8 | "encoding/pem" 9 | "fmt" 10 | "math/big" 11 | "os" 12 | "path" 13 | "sort" 14 | "time" 15 | 16 | "github.com/kemsta/go-easyrsa/internal/fsStorage" 17 | "github.com/kemsta/go-easyrsa/pkg/pair" 18 | ) 19 | 20 | const ( 21 | PEMCertificateBlock string = "CERTIFICATE" // pem block header for x509.Certificate 22 | PEMRSAPrivateKeyBlock = "RSA PRIVATE KEY" // pem block header for rsa.PrivateKey 23 | PEMx509CRLBlock = "X509 CRL" // pem block header for CRL 24 | DefaultKeySizeBytes int = 2048 // default RSA key size in BITS (despite the name), as passed to rsa.GenerateKey 25 | DefaultExpireYears = 99 // default lifetime in years of issued certs and of the CRL's next-update time 26 | ) 27 | 28 | // PKI issues, stores and revokes certificates through pluggable key storage, serial provider and CRL holder backends. 29 | type PKI struct { 30 | Storage KeyStorage 31 | serialProvider SerialProvider 32 | crlHolder CRLHolder 33 | subjTemplate pkix.Name // copied into each issued cert's Subject; only CommonName is overridden 34 | } 35 | 36 | // NewPKI returns a PKI wired to the given key storage, serial provider, CRL holder and subject template. 37 | func NewPKI(storage KeyStorage, sp SerialProvider, crlHolder CRLHolder, subjTemplate pkix.Name) *PKI { 38 | return &PKI{Storage: storage, serialProvider: sp, crlHolder: crlHolder, subjTemplate: subjTemplate} 39 | } 40 | 41 | // InitPKI returns a PKI backed by fsStorage files under pkiDir ("serial", "crl.pem"), creating pkiDir when it does not exist; a nil subjTemplate means an empty pkix.Name. 42 | func 
InitPKI(pkiDir string, subjTemplate *pkix.Name) (*PKI, error) { 43 | if subjTemplate == nil { 44 | subjTemplate = &pkix.Name{} 45 | } 46 | pki := NewPKI(fsStorage.NewDirKeyStorage(pkiDir), 47 | fsStorage.NewFileSerialProvider(path.Join(pkiDir, "serial")), 48 | fsStorage.NewFileCRLHolder(path.Join(pkiDir, "crl.pem")), 49 | *subjTemplate) 50 | 51 | // MkdirAll is a no-op when pkiDir is already a directory and, unlike an IsNotExist-only Stat check, 52 | // it also reports permission problems or a pkiDir that exists but is not a directory. 53 | if err := os.MkdirAll(pkiDir, 0750); err != nil { 54 | return nil, fmt.Errorf("can't create %v: %w", pkiDir, err) 55 | } 56 | return pki, nil 57 | } 58 | 59 | // NewCa generates a new self-signed CA pair (CN "ca", next serial from the serial provider), applies opts to the template via Apply, stores the pair and returns it. 60 | func (p *PKI) NewCa(opts ...Option) (*pair.X509Pair, error) { 61 | key, err := rsa.GenerateKey(rand.Reader, DefaultKeySizeBytes) 62 | if err != nil { 63 | return nil, fmt.Errorf("can`t generate key: %w", err) 64 | } 65 | 66 | subj := p.subjTemplate // struct copy; only CommonName is changed below 67 | subj.CommonName = "ca" 68 | 69 | serial, err := p.serialProvider.Next() 70 | if err != nil { 71 | return nil, fmt.Errorf("can`t get next serial: %w", err) 72 | } 73 | 74 | now := time.Now() 75 | 76 | template := x509.Certificate{ 77 | SerialNumber: serial, 78 | Subject: subj, 79 | NotBefore: now.Add(-10 * time.Minute).UTC(), // backdated 10 minutes, presumably to tolerate clock skew 80 | NotAfter: now.Add(time.Duration(24*365*DefaultExpireYears) * time.Hour).UTC(), 81 | BasicConstraintsValid: true, 82 | IsCA: true, 83 | KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign | x509.KeyUsageCRLSign, 84 | } 85 | 86 | Apply(opts, &template) 87 | 88 | certificate, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key) 89 | if err != nil { 90 | return nil, fmt.Errorf("can`t create cert: %w", err) 91 | } 92 | 93 | res := pair.NewX509Pair( 94 | pem.EncodeToMemory(&pem.Block{ 95 | Type: PEMRSAPrivateKeyBlock, 96 | Bytes: x509.MarshalPKCS1PrivateKey(key), 97 | }), 98 | pem.EncodeToMemory(&pem.Block{ 99 | Type: PEMCertificateBlock, 100 | Bytes: certificate, 101 | }), 102 | "ca", 103 | serial) 104 | err = p.Storage.Put(res) 105 | if 
err != nil { 106 | return nil, fmt.Errorf("can't put generated cert into storage: %w", err) 107 | } 108 | return res, nil 109 | } 110 | 111 | // NewCert generates a key pair for cn, signs it with the latest CA (see GetLastCA), applies opts to the template via Apply, stores the pair and returns it. 112 | func (p *PKI) NewCert(cn string, opts ...Option) (*pair.X509Pair, error) { 113 | caPair, err := p.GetLastCA() 114 | if err != nil { 115 | return nil, fmt.Errorf("can`t get ca pair: %w", err) 116 | } 117 | caKey, caCert, err := caPair.Decode() 118 | if err != nil { 119 | return nil, fmt.Errorf("can`t parse ca pair: %w", err) 120 | } 121 | 122 | key, err := rsa.GenerateKey(rand.Reader, DefaultKeySizeBytes) // same key size as NewCa 123 | if err != nil { 124 | return nil, fmt.Errorf("can`t create private key: %w", err) 125 | } 126 | 127 | serial, err := p.serialProvider.Next() 128 | if err != nil { 129 | return nil, fmt.Errorf("can`t get next serial: %w", err) 130 | } 131 | 132 | now := time.Now() 133 | subj := p.subjTemplate 134 | subj.CommonName = cn 135 | tmpl := x509.Certificate{ 136 | NotBefore: now.Add(-10 * time.Minute).UTC(), 137 | NotAfter: now.Add(time.Duration(24*365*DefaultExpireYears) * time.Hour).UTC(), 138 | SerialNumber: serial, 139 | Subject: subj, 140 | BasicConstraintsValid: true, 141 | } 142 | 143 | Apply(opts, &tmpl) 144 | 145 | // Sign with CA's private key 146 | cert, err := x509.CreateCertificate(rand.Reader, &tmpl, caCert, &key.PublicKey, caKey) 147 | if err != nil { 148 | return nil, fmt.Errorf("certificate cannot be created: %w", err) 149 | } 150 | 151 | priKeyPem := pem.EncodeToMemory(&pem.Block{ 152 | Type: PEMRSAPrivateKeyBlock, 153 | Bytes: x509.MarshalPKCS1PrivateKey(key), 154 | }) 155 | 156 | certPem := pem.EncodeToMemory(&pem.Block{ 157 | Type: PEMCertificateBlock, 158 | Bytes: cert, 159 | }) 160 | 161 | res := pair.NewX509Pair(priKeyPem, certPem, cn, serial) 162 | 163 | err = p.Storage.Put(res) 164 | if err != nil { 165 | return nil, fmt.Errorf("can't put generated cert into storage: %w", err) 166 | } 167 | return res, nil 168 | } 169 | 170 | // GetCRL returns the current certificate revocation list from the CRL holder. 171 | func (p *PKI) GetCRL() (*pkix.CertificateList, error) { 172 | return p.crlHolder.Get() 173 
| } 174 | 175 | // GetLastCA returns the latest pair stored under CN "ca" (for DirKeyStorage: the one with the highest serial). 176 | func (p *PKI) GetLastCA() (*pair.X509Pair, error) { 177 | return p.Storage.GetLastByCn("ca") 178 | } 179 | 180 | // RevokeOne appends serial to the current CRL, re-signs it with the CA of highest serial and stores it; duplicate serials are dropped by removeDups. 181 | func (p *PKI) RevokeOne(serial *big.Int) error { 182 | oldList, err := p.GetCRL() 183 | if err != nil { // never re-sign over an unreadable CRL: that would silently drop every earlier revocation 184 | return fmt.Errorf("can`t read current crl: %w", err) 185 | } 186 | caPairs, err := p.Storage.GetByCN("ca") 187 | if err != nil { 188 | return fmt.Errorf("can`t get ca certs for signing crl: %w", err) 189 | } 190 | sort.Slice(caPairs, func(i, j int) bool { // highest serial first: the newest CA signs the CRL 191 | return caPairs[i].Serial.Cmp(caPairs[j].Serial) == 1 192 | }) 193 | caKey, caCert, err := caPairs[0].Decode() 194 | if err != nil { 195 | return fmt.Errorf("can`t decode ca certs for signing crl: %w", err) 196 | } 197 | list := append(oldList.TBSCertList.RevokedCertificates, pkix.RevokedCertificate{ 198 | SerialNumber: serial, 199 | RevocationTime: time.Now(), 200 | }) 201 | crlBytes, err := caCert.CreateCRL( 202 | rand.Reader, caKey, removeDups(list), time.Now(), time.Now().Add(DefaultExpireYears*365*24*time.Hour)) 203 | if err != nil { 204 | return fmt.Errorf("can`t create crl: %w", err) 205 | } 206 | crlPem := pem.EncodeToMemory(&pem.Block{ 207 | Type: PEMx509CRLBlock, 208 | Bytes: crlBytes, 209 | }) 210 | err = p.crlHolder.Put(crlPem) 211 | if err != nil { 212 | return fmt.Errorf("can`t put new crl: %w", err) 213 | } 214 | return nil 215 | } 216 | 217 | // RevokeAllByCN revokes every stored pair whose common name is cn, stopping at the first failure. 218 | func (p *PKI) RevokeAllByCN(cn string) error { 219 | pairs, err := p.Storage.GetByCN(cn) 220 | if err != nil { 221 | return fmt.Errorf("can`t get pairs for revoke: %w", err) 222 | } 223 | for _, certPair := range pairs { 224 | err := p.RevokeOne(certPair.Serial) 225 | if err != nil { 226 | return fmt.Errorf("can`t revoke: %w", err) 227 | } 228 | } 229 | return nil 230 | } 231 | 232 | // IsRevoked reports whether serial is listed in the current CRL. NOTE(review): a CRL read error is treated as an empty list, so it fails open. 233 | func (p *PKI) IsRevoked(serial *big.Int) 
bool { 234 | revokedCerts, err := p.GetCRL() 235 | if err != nil { 236 | revokedCerts = &pkix.CertificateList{} // NOTE(review): an unreadable CRL is treated as empty, so this fails open 237 | } 238 | for _, cert := range revokedCerts.TBSCertList.RevokedCertificates { 239 | if cert.SerialNumber.Cmp(serial) == 0 { 240 | return true 241 | } 242 | } 243 | return false 244 | } 245 | // removeDups keeps the first entry for each serial number, preserving order; serials are keyed by their full decimal text so values beyond int64 cannot collide. 246 | func removeDups(list []pkix.RevokedCertificate) []pkix.RevokedCertificate { 247 | encountered := make(map[string]bool, len(list)) 248 | result := make([]pkix.RevokedCertificate, 0, len(list)) 249 | for _, cert := range list { 250 | if key := cert.SerialNumber.String(); !encountered[key] { 251 | result = append(result, cert) 252 | encountered[key] = true 253 | } 254 | } 255 | return result 256 | } 257 | -------------------------------------------------------------------------------- /internal/fsStorage/storage.go: -------------------------------------------------------------------------------- 1 | package fsStorage 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "crypto/x509" 7 | "crypto/x509/pkix" 8 | "errors" 9 | "fmt" 10 | "github.com/gofrs/flock" 11 | "github.com/kemsta/go-easyrsa/pkg/pair" 12 | "io" 13 | "io/ioutil" 14 | "math/big" 15 | "os" 16 | "path/filepath" 17 | "sort" 18 | "strconv" 19 | "strings" 20 | "time" 21 | ) 22 | 23 | const ( 24 | LockPeriod = time.Millisecond * 100 // delay between attempts while waiting for an exclusive lock file 25 | LockTimeout = time.Second * 10 // maximum time to wait for an exclusive lock file 26 | CertFileExtension = ".crt" // certificate file extension 27 | ) 28 | 29 | // FileCRLHolder is the common CRLHolder implementation. 
It's saving file on fs 30 | type FileCRLHolder struct { 31 | locker *flock.Flock 32 | path string 33 | } 34 | 35 | func NewFileCRLHolder(path string) *FileCRLHolder { 36 | return &FileCRLHolder{locker: flock.New(fmt.Sprintf("%v.lock", path)), path: path} 37 | } 38 | 39 | // Save new crl content to storage 40 | func (h *FileCRLHolder) Put(content []byte) error { 41 | ctx, cancel := context.WithTimeout(context.Background(), LockTimeout) 42 | defer cancel() 43 | locked, err := h.locker.TryLockContext(ctx, LockPeriod) 44 | if err != nil { 45 | return fmt.Errorf("there's error with saving crl to storage: %w", err) 46 | } 47 | if !locked { 48 | return fmt.Errorf("can`t lock serial file %v", h.path) 49 | } 50 | defer func() { 51 | _ = h.locker.Unlock() 52 | }() 53 | if err = writeFileAtomic(h.path, bytes.NewReader(content), 0644); err != nil { 54 | return fmt.Errorf("can't overwrite crl file %s with new content: %w", h.path, err) 55 | } 56 | 57 | return nil 58 | } 59 | 60 | // Get crl content from storage 61 | func (h *FileCRLHolder) Get() (*pkix.CertificateList, error) { 62 | err := h.locker.RLock() 63 | if err != nil { 64 | return nil, err 65 | } 66 | defer func() { 67 | _ = h.locker.Unlock() 68 | }() 69 | if stat, err := os.Stat(h.path); err != nil || stat.Size() == 0 { 70 | return &pkix.CertificateList{}, nil 71 | } 72 | fBytes, err := ioutil.ReadFile(h.path) 73 | if err != nil { 74 | return nil, fmt.Errorf("can`t read crl %v: %w", h.path, err) 75 | } 76 | list, err := x509.ParseCRL(fBytes) 77 | if err != nil { 78 | return nil, fmt.Errorf("can`t parse crl \n %v: %w", string(fBytes), err) 79 | } 80 | return list, nil 81 | } 82 | 83 | // FileSerialProvider implement SerialProvider interface with storing serial in file on fs 84 | type FileSerialProvider struct { 85 | locker *flock.Flock 86 | path string 87 | } 88 | 89 | // Get next serial and increment counter in storage 90 | func (p *FileSerialProvider) Next() (*big.Int, error) { 91 | ctx, cancel := 
context.WithTimeout(context.Background(), LockTimeout) 92 | defer cancel() 93 | locked, err := p.locker.TryLockContext(ctx, LockPeriod) 94 | if err != nil { 95 | return nil, fmt.Errorf("can`t lock serial file %v: %w", p.path, err) 96 | } 97 | if !locked { 98 | return nil, fmt.Errorf("can`t lock serial file %v", p.path) 99 | } 100 | defer func() { 101 | _ = p.locker.Unlock() 102 | }() 103 | res := big.NewInt(0) 104 | sBytes, err := ioutil.ReadFile(p.path) 105 | if err != nil && !os.IsNotExist(err) { // a missing file just means no serial has been issued yet 106 | return nil, fmt.Errorf("can`t read serial file %v: %w", p.path, err) 107 | } 108 | // a corrupt counter must not silently restart at 1: that would reissue serials already in use 109 | if s := strings.TrimSpace(string(sBytes)); s != "" { 110 | if _, ok := res.SetString(s, 16); !ok { 111 | return nil, fmt.Errorf("can`t parse serial file %v: %q is not a hex number", p.path, s) 112 | } 113 | } 114 | res.Add(big.NewInt(1), res) 115 | 116 | if err := writeFileAtomic(p.path, strings.NewReader(res.Text(16)), 0644); err != nil { 117 | return nil, fmt.Errorf("can`t write serial file %v: %w", p.path, err) 118 | } 119 | 120 | return res, nil 121 | } 122 | // NewFileSerialProvider returns a FileSerialProvider for path; its lock file is path + ".lock". 123 | func NewFileSerialProvider(path string) *FileSerialProvider { 124 | return &FileSerialProvider{ 125 | locker: flock.New(fmt.Sprintf("%v.lock", path)), 126 | path: path, 127 | } 128 | } 129 | 130 | // DirKeyStorage is a key storage implementation that keeps each pair as files under keydir/<cn>/. 131 | type DirKeyStorage struct { 132 | keydir string 133 | } 134 | // NewDirKeyStorage returns a DirKeyStorage rooted at keydir. 135 | func NewDirKeyStorage(keydir string) *DirKeyStorage { 136 | return &DirKeyStorage{keydir: keydir} 137 | } 138 | 139 | // Put writes the pair as keydir/<cn>/<hex serial>.crt and .key; the key file is readable by its owner only. 140 | func (s *DirKeyStorage) Put(pair *pair.X509Pair) error { 141 | certPath, keyPath, err := s.makePath(pair) 142 | if err != nil { 143 | return fmt.Errorf("can`t make path for cn %q serial %v: %w", pair.CN, pair.Serial, err) // identifiers only: %v on the whole pair may print the private key bytes 144 | } 145 | if err := writeFileAtomic(certPath, bytes.NewReader(pair.CertPemBytes), 0644); err != nil { 146 | return fmt.Errorf("can`t write cert %v: %w", certPath, err) 147 | } 148 | // the private key must not be world-readable 149 | if err := writeFileAtomic(keyPath, bytes.NewReader(pair.KeyPemBytes), 0600); err != nil { 150 | return 
fmt.Errorf("can`t write key %v: %w", keyPath, err) 151 | } 152 | return nil 153 | } 154 | 155 | // DeleteByCn deletes the cn directory together with every pair stored in it; a missing directory is not an error. 156 | func (s *DirKeyStorage) DeleteByCn(cn string) error { 157 | if dir := filepath.Join(s.keydir, cn); filepath.Dir(dir) != filepath.Clean(s.keydir) { // reject "", ".", ".." and nested names so RemoveAll can never wipe keydir itself or escape it 158 | return fmt.Errorf("can`t delete by cn %v in %v: invalid cn", cn, s.keydir) 159 | } else if err := os.RemoveAll(dir); err != nil { // unlike os.Remove, RemoveAll also deletes the .crt/.key files inside 160 | return fmt.Errorf("can`t delete by cn %v in %v: %w", cn, s.keydir, err) 161 | } 162 | return nil 163 | } 164 | // DeleteBySerial deletes only the .crt and .key files of the pair with serial. 165 | func (s *DirKeyStorage) DeleteBySerial(serial *big.Int) error { 166 | p, err := s.GetBySerial(serial) 167 | if err != nil { 168 | return fmt.Errorf("can`t find pair by serial %v: %w", serial, err) 169 | } 170 | certPath := filepath.Join(s.keydir, p.CN, fmt.Sprintf("%s.crt", p.Serial.Text(16))) 171 | keyPath := filepath.Join(s.keydir, p.CN, fmt.Sprintf("%s.key", p.Serial.Text(16))) 172 | err = os.Remove(certPath) 173 | if err != nil { 174 | return fmt.Errorf("can`t delete cert %v: %w", certPath, err) 175 | } 176 | err = os.Remove(keyPath) 177 | if err != nil { 178 | return fmt.Errorf("can`t delete key %v: %w", keyPath, err) 179 | } 180 | return nil 181 | } 182 | 183 | // GetByCN returns every pair stored under keydir/<cn>; .crt files whose name is not a hex int64 serial, or whose cert or .key can't be read, are skipped. 184 | func (s *DirKeyStorage) GetByCN(cn string) ([]*pair.X509Pair, error) { 185 | res := make([]*pair.X509Pair, 0) 186 | err := filepath.Walk(filepath.Join(s.keydir, cn), func(path string, info os.FileInfo, err error) error { 187 | if err != nil { 188 | return err 189 | } 190 | if filepath.Ext(path) == CertFileExtension { 191 | fileName := filepath.Base(path) 192 | serial, err := strconv.ParseInt(fileName[0:len(fileName)-len(filepath.Ext(fileName))], 16, 64) 193 | if err != nil { 194 | return nil 195 | } 196 | certBytes, err := ioutil.ReadFile(path) 197 | if err != nil { 198 | return nil 199 | } 200 | keyBytes, err := ioutil.ReadFile(fmt.Sprintf("%s.key", path[0:len(path)-len(filepath.Ext(path))])) 201 | if err != nil { 202 | return nil 203 | } 204 | res = append(res, pair.NewX509Pair(keyBytes, certBytes, cn, big.NewInt(serial))) 205 | } 
206 | return nil 207 | }) 208 | if len(res) == 0 { 209 | return nil, fmt.Errorf("%v not found", cn) 210 | } 211 | return res, err // err may be non-nil alongside partial results when the walk stopped on an error 212 | } 213 | 214 | // GetLastByCn returns the pair with the highest serial among those GetByCN finds for cn. 215 | func (s *DirKeyStorage) GetLastByCn(cn string) (*pair.X509Pair, error) { 216 | pairs, err := s.GetByCN(cn) 217 | if err != nil || len(pairs) == 0 { 218 | return nil, fmt.Errorf("can`t get cert %v: %w", cn, err) 219 | } 220 | sort.Slice(pairs, func(i, j int) bool { 221 | return pairs[i].Serial.Cmp(pairs[j].Serial) == 1 222 | }) 223 | return pairs[0], nil 224 | } 225 | 226 | // GetBySerial walks all of keydir for <hex serial>.crt and returns that pair, taking cn from the parent directory name; if several directories hold the serial, the last one walked wins. 227 | func (s *DirKeyStorage) GetBySerial(serial *big.Int) (*pair.X509Pair, error) { 228 | var res *pair.X509Pair 229 | err := filepath.Walk(s.keydir, func(path string, info os.FileInfo, err error) error { 230 | if err != nil { 231 | return nil // unreadable entries are skipped instead of aborting the walk 232 | } 233 | if filepath.Ext(path) == CertFileExtension { 234 | fileName := filepath.Base(path) 235 | ser, err := strconv.ParseInt(fileName[0:len(fileName)-len(filepath.Ext(fileName))], 16, 64) 236 | if err != nil { 237 | return nil 238 | } 239 | cn := filepath.Base(filepath.Dir(path)) 240 | if serial.Text(16) == big.NewInt(ser).Text(16) { 241 | certBytes, err := ioutil.ReadFile(path) 242 | if err != nil { 243 | return nil 244 | } 245 | keyBytes, err := ioutil.ReadFile(fmt.Sprintf("%s.key", path[0:len(path)-len(filepath.Ext(path))])) 246 | if err != nil { 247 | return nil 248 | } 249 | res = pair.NewX509Pair(keyBytes, certBytes, cn, big.NewInt(ser)) 250 | return nil 251 | } 252 | } 253 | return nil 254 | }) 255 | if res == nil { 256 | return nil, fmt.Errorf("%v not found", serial) 257 | } 258 | return res, err 259 | } 260 | 261 | // GetAll returns every readable pair under keydir; unreadable files and names that are not hex serials are skipped. 262 | func (s *DirKeyStorage) GetAll() ([]*pair.X509Pair, error) { 263 | res := make([]*pair.X509Pair, 0) 264 | err := filepath.Walk(s.keydir, func(path string, info os.FileInfo, err error) error { 265 | if err != nil { 266 | return nil 267 | } 268 | if 
filepath.Ext(path) == CertFileExtension { 269 | fileName := filepath.Base(path) 270 | ser, err := strconv.ParseInt(fileName[0:len(fileName)-len(filepath.Ext(fileName))], 16, 64) 271 | if err != nil { 272 | return nil 273 | } 274 | cn := filepath.Base(filepath.Dir(path)) 275 | certBytes, err := ioutil.ReadFile(path) 276 | if err != nil { 277 | return nil 278 | } 279 | keyBytes, err := ioutil.ReadFile(fmt.Sprintf("%s.key", path[0:len(path)-len(filepath.Ext(path))])) 280 | if err != nil { 281 | return nil 282 | } 283 | res = append(res, pair.NewX509Pair(keyBytes, certBytes, cn, big.NewInt(ser))) 284 | } 285 | return nil 286 | }) 287 | if err != nil { 288 | return nil, fmt.Errorf("can`t get all pairs: %w", err) 289 | } 290 | return res, nil 291 | } 292 | // makePath ensures keydir/<cn> exists and returns the .crt/.key paths named after the hex serial. NOTE(review): CN is joined unchecked, so a CN containing separators or ".." would land outside keydir. 293 | func (s *DirKeyStorage) makePath(pair *pair.X509Pair) (certPath, keyPath string, err error) { 294 | if pair.CN == "" || pair.Serial == nil { 295 | return "", "", errors.New("empty cn or serial") 296 | } 297 | basePath := filepath.Join(s.keydir, pair.CN) 298 | err = os.MkdirAll(basePath, 0755) 299 | if err != nil { 300 | return "", "", fmt.Errorf("can`t create dir %v for key pair: %w", basePath, err) // path only: %v on the pair may print the private key bytes 301 | } 302 | return filepath.Join(basePath, fmt.Sprintf("%s.crt", pair.Serial.Text(16))), 303 | filepath.Join(basePath, fmt.Sprintf("%s.key", pair.Serial.Text(16))), nil 304 | } 305 | // writeFileAtomic copies r into a temp file in path's directory, fsyncs it, applies mode and renames it over path. 306 | func writeFileAtomic(path string, r io.Reader, mode os.FileMode) error { 307 | dir, file := filepath.Split(path) 308 | if dir == "" { 309 | dir = "." 
310 | } 311 | fd, err := ioutil.TempFile(dir, file) 312 | if err != nil { 313 | return fmt.Errorf("cannot create temp file: %w", err) 314 | } 315 | defer func() { // drops the temp file on any early return; after a successful Rename it no longer exists and the error is ignored 316 | _ = os.Remove(fd.Name()) 317 | }() 318 | defer func(fd *os.File) { // closes fd on early returns; the repeat Close after the explicit one below only returns an ignored error 319 | _ = fd.Close() 320 | }(fd) 321 | if _, err := io.Copy(fd, r); err != nil { 322 | return fmt.Errorf("cannot write data to tempfile %q: %w", fd.Name(), err) 323 | } 324 | if err := fd.Sync(); err != nil { 325 | return fmt.Errorf("can't flush tempfile %q: %w", fd.Name(), err) 326 | } 327 | if err := fd.Close(); err != nil { 328 | return fmt.Errorf("can't close tempfile %q: %w", fd.Name(), err) 329 | } 330 | if err := os.Chmod(fd.Name(), mode); err != nil { 331 | return fmt.Errorf("can't set filemode on tempfile %q: %w", fd.Name(), err) 332 | } 333 | if err := os.Rename(fd.Name(), path); err != nil { 334 | return fmt.Errorf("cannot replace %q with tempfile %q: %w", path, fd.Name(), err) 335 | } 336 | return nil 337 | } 338 | -------------------------------------------------------------------------------- /internal/fsStorage/storage_test.go: -------------------------------------------------------------------------------- 1 | package fsStorage 2 | 3 | import ( 4 | "bytes" 5 | "crypto/x509/pkix" 6 | "fmt" 7 | "github.com/kemsta/go-easyrsa/pkg/pair" 8 | "io" 9 | "io/ioutil" 10 | "math/big" 11 | "os" 12 | "path/filepath" 13 | "reflect" 14 | "strings" 15 | "testing" 16 | 17 | "github.com/stretchr/testify/assert" 18 | ) 19 | // getTestDir returns the absolute path of this package's "test" fixture directory. 20 | func getTestDir() string { 21 | res, _ := filepath.Abs("test") 22 | return res 23 | } 24 | // TestDirKeyStorage_makePath covers makePath's rejection of an empty cn or serial, a MkdirAll failure and the hex-serial file naming. 25 | func TestDirKeyStorage_makePath(t *testing.T) { 26 | type fields struct { 27 | keydir string 28 | } 29 | type args struct { 30 | pair *pair.X509Pair 31 | } 32 | tests := []struct { 33 | name string 34 | fields fields 35 | args args 36 | wantCertPath string 37 | wantKeyPath string 38 | wantErr bool 39 | }{ 40 | { 41 | name: "empty cn", 42 | fields: fields{ 43 | keydir: filepath.Join(getTestDir(), "dir_keystorage"), 44 | 
}, 45 | args: args{ 46 | pair: &pair.X509Pair{ 47 | KeyPemBytes: nil, 48 | CertPemBytes: nil, 49 | CN: "", 50 | Serial: big.NewInt(66), 51 | }, 52 | }, 53 | wantCertPath: "", 54 | wantKeyPath: "", 55 | wantErr: true, 56 | }, 57 | { 58 | name: "empty serial", 59 | fields: fields{ 60 | keydir: filepath.Join(getTestDir(), "dir_keystorage"), 61 | }, 62 | args: args{ 63 | pair: &pair.X509Pair{ 64 | KeyPemBytes: nil, 65 | CertPemBytes: nil, 66 | CN: "good_cert", 67 | Serial: nil, 68 | }, 69 | }, 70 | wantCertPath: "", 71 | wantKeyPath: "", 72 | wantErr: true, 73 | }, 74 | { 75 | name: "can`t create dir", 76 | fields: fields{ 77 | keydir: filepath.Join(getTestDir(), "dir_keystorage"), 78 | }, 79 | args: args{ 80 | pair: &pair.X509Pair{ 81 | KeyPemBytes: nil, 82 | CertPemBytes: nil, 83 | CN: "bad_path", // test/dir_keystorage/bad_path is a regular file, so MkdirAll fails 84 | Serial: big.NewInt(66), 85 | }, 86 | }, 87 | wantCertPath: "", 88 | wantKeyPath: "", 89 | wantErr: true, 90 | }, 91 | { 92 | name: "good", 93 | fields: fields{ 94 | keydir: filepath.Join(getTestDir(), "dir_keystorage"), 95 | }, 96 | args: args{ 97 | pair: &pair.X509Pair{ 98 | KeyPemBytes: nil, 99 | CertPemBytes: nil, 100 | CN: "good_cert", 101 | Serial: big.NewInt(66), // 66 == 0x42, hence the 42.crt / 42.key names below 102 | }, 103 | }, 104 | wantCertPath: filepath.Join(getTestDir(), "dir_keystorage", "good_cert/42.crt"), 105 | wantKeyPath: filepath.Join(getTestDir(), "dir_keystorage", "good_cert/42.key"), 106 | wantErr: false, 107 | }, 108 | } 109 | for _, tt := range tests { 110 | t.Run(tt.name, func(t *testing.T) { 111 | s := &DirKeyStorage{ 112 | keydir: tt.fields.keydir, 113 | } 114 | gotCertPath, gotKeyPath, err := s.makePath(tt.args.pair) 115 | if (err != nil) != tt.wantErr { 116 | t.Errorf("DirKeyStorage.makePath() error = %v, wantErr %v", err, tt.wantErr) 117 | return 118 | } 119 | if gotCertPath != tt.wantCertPath { 120 | t.Errorf("DirKeyStorage.makePath() gotCertPath = %v, want %v", gotCertPath, tt.wantCertPath) 121 | } 122 | if gotKeyPath != tt.wantKeyPath { 123 | t.Errorf("DirKeyStorage.makePath() 
gotKeyPath = %v, want %v", gotKeyPath, tt.wantKeyPath) 124 | } 125 | }) 126 | } 127 | } 128 | // TestDirKeyStorage_Put checks Put's failure modes and that the good pair's bytes land in good_cert/42.crt and good_cert/42.key. 129 | func TestDirKeyStorage_Put(t *testing.T) { 130 | type fields struct { 131 | keydir string 132 | } 133 | type args struct { 134 | pair *pair.X509Pair 135 | } 136 | tests := []struct { 137 | name string 138 | fields fields 139 | args args 140 | wantErr bool 141 | }{ 142 | { 143 | name: "can`t make path", 144 | fields: fields{ 145 | keydir: filepath.Join(getTestDir(), "dir_keystorage"), 146 | }, 147 | args: args{ 148 | pair: &pair.X509Pair{ 149 | KeyPemBytes: nil, 150 | CertPemBytes: nil, 151 | CN: "bad_path", 152 | Serial: big.NewInt(66), 153 | }, 154 | }, 155 | wantErr: true, 156 | }, 157 | { 158 | name: "good", 159 | fields: fields{ 160 | keydir: filepath.Join(getTestDir(), "dir_keystorage"), 161 | }, 162 | args: args{ 163 | pair: &pair.X509Pair{ 164 | KeyPemBytes: []byte("keybytes"), 165 | CertPemBytes: []byte("certbytes"), 166 | CN: "good_cert", 167 | Serial: big.NewInt(66), 168 | }, 169 | }, 170 | wantErr: false, 171 | }, 172 | { 173 | name: "bad_cert", 174 | fields: fields{ 175 | keydir: filepath.Join(getTestDir(), "dir_keystorage"), 176 | }, 177 | args: args{ 178 | pair: &pair.X509Pair{ 179 | KeyPemBytes: nil, 180 | CertPemBytes: nil, 181 | CN: "bad_cert", // fixture bad_cert/42.crt is a directory, so writing the cert fails 182 | Serial: big.NewInt(66), 183 | }, 184 | }, 185 | wantErr: true, 186 | }, 187 | { 188 | name: "bad_key", 189 | fields: fields{ 190 | keydir: filepath.Join(getTestDir(), "dir_keystorage"), 191 | }, 192 | args: args{ 193 | pair: &pair.X509Pair{ 194 | KeyPemBytes: nil, 195 | CertPemBytes: nil, 196 | CN: "bad_key", // fixture bad_key/42.key is a directory, so writing the key fails 197 | Serial: big.NewInt(66), 198 | }, 199 | }, 200 | wantErr: true, 201 | }, 202 | } 203 | for _, tt := range tests { 204 | t.Run(tt.name, func(t *testing.T) { 205 | s := &DirKeyStorage{ 206 | keydir: tt.fields.keydir, 207 | } 208 | if err := s.Put(tt.args.pair); (err != nil) != tt.wantErr { 209 | t.Errorf("DirKeyStorage.Put() error = %v, wantErr %v", err, tt.wantErr) 210 | } 211 | }) 212 | } 213 | 
certBytes, _ := ioutil.ReadFile(filepath.Join(getTestDir(), "dir_keystorage", "good_cert/42.crt")) 214 | if !bytes.Equal(certBytes, []byte("certbytes")) { 215 | t.Errorf("DirKeyStorage.Put() wrong cert bytes in result file") 216 | } 217 | keyBytes, _ := ioutil.ReadFile(filepath.Join(getTestDir(), "dir_keystorage", "good_cert/42.key")) 218 | if !bytes.Equal(keyBytes, []byte("keybytes")) { 219 | t.Errorf("DirKeyStorage.Put() wrong key bytes in result file") 220 | } 221 | } 222 | // TestDirKeyStorage_DeleteByCn deletes a freshly created cn directory. NOTE(review): that directory is empty, so stored pairs inside it are not exercised. 223 | func TestDirKeyStorage_DeleteByCn(t *testing.T) { 224 | _ = os.MkdirAll(filepath.Join(getTestDir(), "dir_keystorage", "for_delete"), 0755) 225 | type fields struct { 226 | keydir string 227 | } 228 | type args struct { 229 | cn string 230 | } 231 | tests := []struct { 232 | name string 233 | fields fields 234 | args args 235 | wantErr bool 236 | }{ 237 | { 238 | name: "recurse delete", 239 | fields: fields{ 240 | keydir: filepath.Join(getTestDir(), "dir_keystorage"), 241 | }, 242 | args: args{ 243 | cn: "for_delete", 244 | }, 245 | wantErr: false, 246 | }, 247 | } 248 | for _, tt := range tests { 249 | t.Run(tt.name, func(t *testing.T) { 250 | s := &DirKeyStorage{ 251 | keydir: tt.fields.keydir, 252 | } 253 | if err := s.DeleteByCn(tt.args.cn); (err != nil) != tt.wantErr { 254 | t.Errorf("DirKeyStorage.DeleteByCn() error = %v, wantErr %v", err, tt.wantErr) 255 | } 256 | }) 257 | } 258 | } 259 | // TestDirKeyStorage_GetByCN covers a missing cn, unreadable cert/key fixtures and the good_cert pair. 260 | func TestDirKeyStorage_GetByCN(t *testing.T) { 261 | type fields struct { 262 | keydir string 263 | } 264 | type args struct { 265 | cn string 266 | } 267 | tests := []struct { 268 | name string 269 | fields fields 270 | args args 271 | want []*pair.X509Pair 272 | wantErr bool 273 | }{ 274 | { 275 | name: "not exist", 276 | fields: fields{ 277 | keydir: filepath.Join(getTestDir(), "dir_keystorage"), 278 | }, 279 | args: args{ 280 | cn: "not_exist", 281 | }, 282 | want: nil, 283 | wantErr: true, 284 | }, 285 | { 286 | name: "bad cert", 287 | fields: fields{ 288 | keydir: 
filepath.Join(getTestDir(), "dir_keystorage"), 289 | }, 290 | args: args{ 291 | cn: "bad_cert", 292 | }, 293 | want: nil, 294 | wantErr: true, 295 | }, 296 | { 297 | name: "bad key", 298 | fields: fields{ 299 | keydir: filepath.Join(getTestDir(), "dir_keystorage"), 300 | }, 301 | args: args{ 302 | cn: "bad_key", 303 | }, 304 | want: nil, 305 | wantErr: true, 306 | }, 307 | { 308 | name: "good cert", 309 | fields: fields{ 310 | keydir: filepath.Join(getTestDir(), "dir_keystorage"), 311 | }, 312 | args: args{ 313 | cn: "good_cert", 314 | }, 315 | want: []*pair.X509Pair{pair.NewX509Pair([]byte("keybytes"), []byte("certbytes"), "good_cert", big.NewInt(66))}, 316 | wantErr: false, 317 | }, 318 | } 319 | for _, tt := range tests { 320 | t.Run(tt.name, func(t *testing.T) { 321 | s := &DirKeyStorage{ 322 | keydir: tt.fields.keydir, 323 | } 324 | got, err := s.GetByCN(tt.args.cn) 325 | if (err != nil) != tt.wantErr { 326 | t.Errorf("DirKeyStorage.GetByCN() error = %v, wantErr %v", err, tt.wantErr) 327 | return 328 | } 329 | if !reflect.DeepEqual(got, tt.want) { 330 | t.Errorf("DirKeyStorage.GetByCN() = %v, want %v", got, tt.want) 331 | } 332 | }) 333 | } 334 | } 335 | // TestDirKeyStorage_GetBySerial looks up the good_cert pair by serial 66 (0x42). 336 | func TestDirKeyStorage_GetBySerial(t *testing.T) { 337 | type fields struct { 338 | keydir string 339 | } 340 | type args struct { 341 | serial *big.Int 342 | } 343 | tests := []struct { 344 | name string 345 | fields fields 346 | args args 347 | want *pair.X509Pair 348 | wantErr bool 349 | }{ 350 | { 351 | name: "42", 352 | fields: fields{ 353 | keydir: filepath.Join(getTestDir(), "dir_keystorage"), 354 | }, 355 | args: args{ 356 | serial: big.NewInt(66), 357 | }, 358 | want: pair.NewX509Pair([]byte("keybytes"), []byte("certbytes"), "good_cert", big.NewInt(66)), 359 | wantErr: false, 360 | }, 361 | } 362 | for _, tt := range tests { 363 | t.Run(tt.name, func(t *testing.T) { 364 | s := &DirKeyStorage{ 365 | keydir: tt.fields.keydir, 366 | } 367 | got, err := s.GetBySerial(tt.args.serial) 368 | if (err != 
nil) != tt.wantErr { 369 | t.Errorf("DirKeyStorage.GetBySerial() error = %v, wantErr %v", err, tt.wantErr) 370 | return 371 | } 372 | if !reflect.DeepEqual(got, tt.want) { 373 | t.Errorf("DirKeyStorage.GetBySerial() = %v, want %v", got, tt.want) 374 | } 375 | }) 376 | } 377 | } 378 | 379 | func TestDirKeyStorage_DeleteBySerial(t *testing.T) { 380 | 381 | _ = os.MkdirAll(filepath.Join(getTestDir(), "dir_keystorage", "for_delete"), 0755) 382 | _ = ioutil.WriteFile(filepath.Join(getTestDir(), "dir_keystorage", "for_delete", "a.crt"), []byte(""), 0600) 383 | _ = ioutil.WriteFile(filepath.Join(getTestDir(), "dir_keystorage", "for_delete", "a.key"), []byte(""), 0600) 384 | 385 | type fields struct { 386 | keydir string 387 | } 388 | type args struct { 389 | serial *big.Int 390 | } 391 | tests := []struct { 392 | name string 393 | fields fields 394 | args args 395 | wantErr bool 396 | }{ 397 | { 398 | name: "not exist", 399 | fields: fields{ 400 | keydir: filepath.Join(getTestDir(), "dir_keystorage"), 401 | }, 402 | args: args{ 403 | serial: big.NewInt(67), 404 | }, 405 | wantErr: true, 406 | }, 407 | { 408 | name: "exist", 409 | fields: fields{ 410 | keydir: filepath.Join(getTestDir(), "dir_keystorage"), 411 | }, 412 | args: args{ 413 | serial: big.NewInt(10), 414 | }, 415 | wantErr: false, 416 | }, 417 | } 418 | 419 | for _, tt := range tests { 420 | t.Run(tt.name, func(t *testing.T) { 421 | s := &DirKeyStorage{ 422 | keydir: tt.fields.keydir, 423 | } 424 | if err := s.DeleteBySerial(tt.args.serial); (err != nil) != tt.wantErr { 425 | t.Errorf("DirKeyStorage.DeleteBySerial() error = %v, wantErr %v", err, tt.wantErr) 426 | } 427 | }) 428 | } 429 | } 430 | 431 | func TestFileSerialProvider_Next(t *testing.T) { 432 | defer func() { 433 | _ = os.RemoveAll(filepath.Join(getTestDir(), "dir_keystorage", "new_serial")) 434 | _ = ioutil.WriteFile(filepath.Join(getTestDir(), "dir_keystorage", "wrong_serial"), []byte("gggg"), 0666) 435 | }() 436 | type fields struct { 437 | path 
string 438 | } 439 | tests := []struct { 440 | name string 441 | fields fields 442 | want *big.Int 443 | wantErr bool 444 | }{ 445 | { 446 | name: "not exist dir", 447 | fields: fields{ 448 | path: filepath.Join(getTestDir(), "dir_keystorage", "not_exist/serial"), 449 | }, 450 | want: nil, 451 | wantErr: true, 452 | }, 453 | { 454 | name: "not exist file", 455 | fields: fields{ 456 | path: filepath.Join(getTestDir(), "dir_keystorage", "new_serial"), 457 | }, 458 | want: big.NewInt(1), 459 | wantErr: false, 460 | }, 461 | { 462 | name: "broken file", 463 | fields: fields{ 464 | path: filepath.Join(getTestDir(), "dir_keystorage", "wrong_serial"), 465 | }, 466 | want: big.NewInt(1), 467 | wantErr: false, 468 | }, 469 | { 470 | name: "dir", 471 | fields: fields{ 472 | path: filepath.Join(getTestDir(), "dir_keystorage"), 473 | }, 474 | want: nil, 475 | wantErr: true, 476 | }, 477 | } 478 | for _, tt := range tests { 479 | t.Run(tt.name, func(t *testing.T) { 480 | p := NewFileSerialProvider(tt.fields.path) 481 | got, err := p.Next() 482 | if (err != nil) != tt.wantErr { 483 | t.Errorf("FileSerialProvider.Next() error = %v, wantErr %v", err, tt.wantErr) 484 | return 485 | } 486 | if !reflect.DeepEqual(got, tt.want) { 487 | t.Errorf("FileSerialProvider.Next() = %v, want %v", got, tt.want) 488 | } 489 | }) 490 | } 491 | } 492 | 493 | func TestFileCRLHolder_Put(t *testing.T) { 494 | t.Run("not exist", func(t *testing.T) { 495 | fileName := filepath.Join(getTestDir(), "dir_keystorage", "not_exist_crl.pem") 496 | content := []byte("content") 497 | defer func() { 498 | _ = os.RemoveAll(fileName) 499 | }() 500 | h := NewFileCRLHolder(fileName) 501 | err := h.Put(content) 502 | if err != nil { 503 | t.Errorf("FileCRLHolder.Put() error = %v", err) 504 | } 505 | got, _ := ioutil.ReadFile(fileName) 506 | if !bytes.Equal(got, content) { 507 | t.Errorf("FileCRLHolder.Put() got = %v, want %v", got, content) 508 | } 509 | }) 510 | t.Run("exist", func(t *testing.T) { 511 | fileName := 
filepath.Join(getTestDir(), "dir_keystorage", "exist.pem") 512 | content := []byte("content") 513 | defer func() { 514 | _ = ioutil.WriteFile(fileName, []byte("asd"), 0644) 515 | }() 516 | h := NewFileCRLHolder(fileName) 517 | err := h.Put(content) 518 | if err != nil { 519 | t.Errorf("FileCRLHolder.Put() error = %v", err) 520 | } 521 | got, _ := ioutil.ReadFile(fileName) 522 | if !bytes.Equal(got, content) { 523 | t.Errorf("FileCRLHolder.Put() got = %v, want %v", got, content) 524 | } 525 | }) 526 | t.Run("dir", func(t *testing.T) { 527 | fileName := filepath.Join(getTestDir(), "dir_keystorage", "crl.dir") 528 | content := []byte("content") 529 | defer func() { 530 | _ = ioutil.WriteFile(fileName, []byte("asd"), 0666) 531 | }() 532 | h := NewFileCRLHolder(fileName) 533 | err := h.Put(content) 534 | if err == nil { 535 | t.Errorf("FileCRLHolder.Put() error = %v", err) 536 | } 537 | }) 538 | } 539 | 540 | func TestFileCRLHolder_Get(t *testing.T) { 541 | type fields struct { 542 | path string 543 | } 544 | tests := []struct { 545 | name string 546 | fields fields 547 | want *pkix.CertificateList 548 | wantErr bool 549 | }{ 550 | { 551 | name: "not exist", 552 | fields: fields{ 553 | path: filepath.Join(getTestDir(), "dir_keystorage", "not_exist"), 554 | }, 555 | want: nil, 556 | wantErr: false, 557 | }, 558 | { 559 | name: "broken", 560 | fields: fields{ 561 | path: filepath.Join(getTestDir(), "dir_keystorage", "exist.pem"), 562 | }, 563 | want: nil, 564 | wantErr: true, 565 | }, 566 | { 567 | name: "good", 568 | fields: fields{ 569 | path: filepath.Join(getTestDir(), "dir_keystorage", "good_crl.pem"), 570 | }, 571 | want: nil, 572 | wantErr: false, 573 | }, 574 | } 575 | for _, tt := range tests { 576 | t.Run(tt.name, func(t *testing.T) { 577 | h := NewFileCRLHolder(tt.fields.path) 578 | _, err := h.Get() 579 | if (err != nil) != tt.wantErr { 580 | t.Errorf("FileCRLHolder.Get() error = %v, wantErr %v", err, tt.wantErr) 581 | return 582 | } 583 | }) 584 | } 585 | } 
586 | 587 | func TestDirKeyStorage_GetAll(t *testing.T) { 588 | storPath := filepath.Join(getTestDir(), "empty_stor") 589 | stor := NewDirKeyStorage(storPath) 590 | _ = os.MkdirAll(storPath, 0755) 591 | defer func() { 592 | _ = os.RemoveAll(storPath) 593 | }() 594 | t.Run("empty stor", func(t *testing.T) { 595 | all, err := stor.GetAll() 596 | assert.NoError(t, err) 597 | assert.NotNil(t, all) 598 | assert.Empty(t, all) 599 | }) 600 | t.Run("good stor", func(t *testing.T) { 601 | _ = stor.Put(pair.NewX509Pair([]byte("keybytes"), []byte("certbytes"), "good_cert", big.NewInt(66))) 602 | _ = stor.Put(pair.NewX509Pair([]byte("keybytes"), []byte("certbytes"), "good_cert", big.NewInt(65))) 603 | _ = stor.Put(pair.NewX509Pair([]byte("keybytes"), []byte("certbytes"), "another_cert", big.NewInt(64))) 604 | all, err := stor.GetAll() 605 | assert.NoError(t, err) 606 | assert.NotNil(t, all) 607 | assert.NotEmpty(t, all) 608 | assert.Len(t, all, 3) 609 | }) 610 | } 611 | 612 | func TestDirKeyStorage_GetLastByCn(t *testing.T) { 613 | storPath := filepath.Join(getTestDir(), "empty_stor") 614 | stor := NewDirKeyStorage(storPath) 615 | _ = os.MkdirAll(filepath.Join(storPath, "any"), 0755) 616 | defer func() { 617 | _ = os.RemoveAll(storPath) 618 | }() 619 | t.Run("empty stor", func(t *testing.T) { 620 | all, err := stor.GetLastByCn("any") 621 | assert.Error(t, err) 622 | assert.Nil(t, all) 623 | }) 624 | } 625 | 626 | func Test_writeFileAtomic(t *testing.T) { 627 | path := filepath.Join(getTestDir(), "dir_keystorage") 628 | type args struct { 629 | path string 630 | r io.Reader 631 | mode os.FileMode 632 | } 633 | tests := []struct { 634 | name string 635 | args args 636 | wantErr assert.ErrorAssertionFunc 637 | }{ 638 | { 639 | name: "not_exist", 640 | args: args{ 641 | path: filepath.Join(path, "bad_key/not_exist"), 642 | r: strings.NewReader("test"), 643 | mode: 0644, 644 | }, 645 | wantErr: assert.NoError, 646 | }, 647 | { 648 | name: "exist", 649 | args: args{ 650 | path: 
filepath.Join(path, "bad_key/42.crt"), 651 | r: strings.NewReader("test"), 652 | mode: 0644, 653 | }, 654 | wantErr: assert.NoError, 655 | }, 656 | { 657 | name: "dir", 658 | args: args{ 659 | path: filepath.Join(path, "bad_key/42.key"), 660 | r: strings.NewReader("test"), 661 | mode: 0644, 662 | }, 663 | wantErr: assert.Error, 664 | }, 665 | } 666 | defer func(name string) { 667 | _ = os.Remove(name) 668 | }(filepath.Join(path, "bad_key/not_exist")) 669 | for _, tt := range tests { 670 | t.Run(tt.name, func(t *testing.T) { 671 | tt.wantErr(t, writeFileAtomic(tt.args.path, tt.args.r, tt.args.mode), fmt.Sprintf("writeFileAtomic(%v, %v, %v)", tt.args.path, tt.args.r, tt.args.mode)) 672 | }) 673 | } 674 | } 675 | --------------------------------------------------------------------------------