├── .gitignore ├── .goreleaser.yml ├── fixtures ├── certs │ ├── ca-key.pem │ ├── client-key.pem │ ├── server-key.pem │ ├── ca-cert.pem │ ├── client-cert.pem │ └── server-cert.pem └── gen.sh ├── authorizer ├── mtls │ ├── fixtures │ │ ├── ca-ext-key.pem │ │ ├── ca-root-key.pem │ │ ├── server-a-key.pem │ │ ├── user-a-key.pem │ │ ├── user-b-key.pem │ │ ├── user-ext-key.pem │ │ ├── ca-signer-a-key.pem │ │ ├── ca-signer-b-key.pem │ │ ├── claim-test-key.pem │ │ ├── ca-intermediate-key.pem │ │ ├── ca-ext-cert.pem │ │ ├── ca-root-cert.pem │ │ ├── ca-intermediate-cert.pem │ │ ├── ca-signer-a-cert.pem │ │ ├── ca-signer-b-cert.pem │ │ ├── user-a-cert.pem │ │ ├── user-b-cert.pem │ │ ├── user-ext-cert.pem │ │ ├── server-a-cert.pem │ │ ├── claim-test-cert.pem │ │ ├── ca-chain-a.pem │ │ └── ca-chain-b.pem │ ├── doc.go │ ├── utils.go │ └── utils_test.go └── simple │ ├── doc.go │ ├── authorizer.go │ ├── pushpublisher.go │ ├── publishhandler_test.go │ ├── authenticator.go │ ├── authorizer_test.go │ └── authenticator_test.go ├── gateway ├── buffers_test.go ├── buffers.go ├── upstreamer │ └── push │ │ ├── randomizer.go │ │ ├── movingaverage_test.go │ │ ├── ping.go │ │ ├── notifier_options_test.go │ │ ├── movingaverage.go │ │ ├── services.go │ │ ├── notifier_options.go │ │ ├── upstreamer_options_test.go │ │ ├── utils.go │ │ ├── services_test.go │ │ ├── notifier.go │ │ ├── upstreamer_options.go │ │ └── upstreamer_distribution_test.go ├── listener.go ├── extractors.go ├── utils.go ├── interfaces.go ├── limiter.go ├── rewriters.go ├── listener_test.go ├── extractors_test.go └── errors.go ├── Makefile ├── .github └── workflows │ └── build-go.yaml ├── doc.go ├── job.go ├── pubsub_test.go ├── metrics.go ├── pubsub.go ├── netlimiter.go ├── README.md ├── websocket_push_session_mock_test.go ├── pinger.go ├── pinger_test.go ├── profiling_server.go ├── processor_helpers.go ├── go.mod ├── websocket_push_session_mock.go ├── job_test.go ├── health_server.go ├── utils.go ├── netlimiter_test.go ├── 
rest_server_helpers.go ├── cors.go ├── pubsub_local.go ├── meta.go ├── pubsub_nats_options_test.go ├── pubsub_local_test.go ├── config.go ├── opentracing.go ├── pubsub_nats_mocks_test.go ├── meta_test.go ├── pubsub_nats_options.go ├── context.go ├── metrics_prometheus.go └── context_mock_test.go /.gitignore: -------------------------------------------------------------------------------- 1 | testresults.xml 2 | .DS_Store 3 | .tags 4 | 5 | *.cover 6 | *.lock 7 | vendor 8 | 9 | unit_coverage.out 10 | cov.report 11 | .idea/* 12 | artifacts 13 | remod.dev 14 | .remod 15 | dist 16 | .vsession 17 | -------------------------------------------------------------------------------- /.goreleaser.yml: -------------------------------------------------------------------------------- 1 | snapshot: 2 | name_template: "{{ .Tag }}-next" 3 | changelog: 4 | sort: asc 5 | filters: 6 | exclude: 7 | - '^docs:' 8 | - '^test:' 9 | - '^examples:' 10 | builds: 11 | - skip: true 12 | -------------------------------------------------------------------------------- /fixtures/certs/ca-key.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN EC PRIVATE KEY----- 2 | MHcCAQEEIODhz57/DL62sSVeh8l/Rt1EfLeFM6ZWZXUv+JlAys76oAoGCCqGSM49 3 | AwEHoUQDQgAEN6zYuw8I0ZBaKe4gIUMw4bAA0DIqnkglVlsw1QDtKDKcyLzZvECh 4 | FDQK8MlXQZMBY40DRvNfDwPAXsf1RVf3Qw== 5 | -----END EC PRIVATE KEY----- 6 | -------------------------------------------------------------------------------- /fixtures/certs/client-key.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN EC PRIVATE KEY----- 2 | MHcCAQEEICUV8iFvxhDn8kB6rU3gIizjFm9unrVZ74lLbCwKg3f3oAoGCCqGSM49 3 | AwEHoUQDQgAEj/SijSBMrKIY7k7YDI1eWRmA/Ppl5dn2+wnA4niEOehfjGhtbrVs 4 | Jt1C1jcYJmIXCljnoeiWRP78lSs5VCQjqA== 5 | -----END EC PRIVATE KEY----- 6 | -------------------------------------------------------------------------------- /fixtures/certs/server-key.pem: 
-------------------------------------------------------------------------------- 1 | -----BEGIN EC PRIVATE KEY----- 2 | MHcCAQEEIOf7Myt3HnDYaeO4VoUhmVQd+mIMRNyNuW2whsIsvc33oAoGCCqGSM49 3 | AwEHoUQDQgAE4aG1pQMnzPgKHcTu0J3QXnn/ZempnQjn7rIGTZsKYUng6PKqYTbS 4 | aZUqdrIx7eakczLXiPAzA2H3QRSmNVk+Bw== 5 | -----END EC PRIVATE KEY----- 6 | -------------------------------------------------------------------------------- /authorizer/mtls/fixtures/ca-ext-key.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN EC PRIVATE KEY----- 2 | MHcCAQEEINkx1/iOmWRlfaY90gRBaQ5alvRUmM43CpFnLrrRWNN9oAoGCCqGSM49 3 | AwEHoUQDQgAEWUmynz+3f9PEnPWY7lBQXo31HODeCt/sBXMQB5p9hDTrjKJVDFv/ 4 | ry/+0aiz0VnE7IEjlJncWwAWJkP+QyEGGw== 5 | -----END EC PRIVATE KEY----- 6 | -------------------------------------------------------------------------------- /authorizer/mtls/fixtures/ca-root-key.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN EC PRIVATE KEY----- 2 | MHcCAQEEIKNt7pNp0ljqdwz7QjxAEZ9ZtYX75YwpoFtSGgQXsUQmoAoGCCqGSM49 3 | AwEHoUQDQgAEZrASa6uHjeyjJAeeYBIo0sF1DB8hX1Fv1TIjSNLDEyhDA+bl+LPM 4 | 5WKYEZsagD5RcGa33s4bYR2fgqBtx1eqIw== 5 | -----END EC PRIVATE KEY----- 6 | -------------------------------------------------------------------------------- /authorizer/mtls/fixtures/server-a-key.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN EC PRIVATE KEY----- 2 | MHcCAQEEINq3jRuV2SWk2qHT+F5i1nLItE3tC/hiqa2Tt1Q/sMYBoAoGCCqGSM49 3 | AwEHoUQDQgAE3tW9M/HuQAcGdMtWiZlrMfQ49ESXNWjx8M8XD521PDTKw4x57L/h 4 | lP0tjxyhoPGPmG1daELHGcVQ2ThE7t2/nA== 5 | -----END EC PRIVATE KEY----- 6 | -------------------------------------------------------------------------------- /authorizer/mtls/fixtures/user-a-key.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN EC PRIVATE KEY----- 2 | 
MHcCAQEEIDloR2qBxSNiL5i5VXknXK+rOpVCpdj7K3+qlq+VIkTPoAoGCCqGSM49 3 | AwEHoUQDQgAE6Un907SdKaQS+DYeLSQXEHe9TqXZFKMUvxoT7DNFTVAMKD4znqc0 4 | 7A0NnUyya05pRcAWYup+wvlTyEBFA7FTPw== 5 | -----END EC PRIVATE KEY----- 6 | -------------------------------------------------------------------------------- /authorizer/mtls/fixtures/user-b-key.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN EC PRIVATE KEY----- 2 | MHcCAQEEIMcaImMaB0EJceBfLrE+AKwLXDbTwaQqf7eAFxP0TGGvoAoGCCqGSM49 3 | AwEHoUQDQgAENBfD+rpkwmI3ju8zoVG6h1IGYxklR9YbLcZG1puep/rfEhct0jsL 4 | YLpp3x8OF+YDHDLlavwkdxi9zJHv6NYSLA== 5 | -----END EC PRIVATE KEY----- 6 | -------------------------------------------------------------------------------- /authorizer/mtls/fixtures/user-ext-key.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN EC PRIVATE KEY----- 2 | MHcCAQEEIKKWNvVPl8dEq3e+CwypKuGWcLbDPrI2vyJ+yaf/W6J/oAoGCCqGSM49 3 | AwEHoUQDQgAEJLJ0ddePepewGMaO+RsPE63Er/kT86qk2l7NJFFIVDvPXlsazO21 4 | 6QTh/Jggqaglq52JNYSAflmJ05ho41Sblg== 5 | -----END EC PRIVATE KEY----- 6 | -------------------------------------------------------------------------------- /authorizer/mtls/fixtures/ca-signer-a-key.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN EC PRIVATE KEY----- 2 | MHcCAQEEIBCk1/UhxiOO+9tG9P8jmkJktig2Rbcjhg64JL/GENyEoAoGCCqGSM49 3 | AwEHoUQDQgAEko9zulgiDQi80GPGV9VEsWzuUfP0DsCIb84XsRpUcYl7n7Njt+YH 4 | iYbmG/mIg2zWzAiTTNmetxBtS6hxqUa6pA== 5 | -----END EC PRIVATE KEY----- 6 | -------------------------------------------------------------------------------- /authorizer/mtls/fixtures/ca-signer-b-key.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN EC PRIVATE KEY----- 2 | MHcCAQEEILmJdMbVc+BsnvPXCeyjE8zPBAht7xDD4Taa5B9xWmWUoAoGCCqGSM49 3 | AwEHoUQDQgAEJzTiVakpH3JD3z5Ar8tYX65Sc18nqJ/jlR8hfsshCptDILVv/pX5 4 | 
NGY3C8+94PKacrP0QNyWD2SGTW0R6DbFmA== 5 | -----END EC PRIVATE KEY----- 6 | -------------------------------------------------------------------------------- /authorizer/mtls/fixtures/claim-test-key.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN EC PRIVATE KEY----- 2 | MHcCAQEEIDt08Hag1GkVZVRaMS6pOW9FbvxEQBnK+pT+B8a+j9LOoAoGCCqGSM49 3 | AwEHoUQDQgAECAE1GwRLPm+xpAtsa0l6FLEzdpqAjIsv5eElPIXNBpa0yp5D/KUU 4 | eTjq+tB/Lee+9mZZdLnD7eLDO1Ks0HYV8Q== 5 | -----END EC PRIVATE KEY----- 6 | -------------------------------------------------------------------------------- /authorizer/mtls/fixtures/ca-intermediate-key.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN EC PRIVATE KEY----- 2 | MHcCAQEEIMH6LhkQ57D5Xl4CP+mUDQWW/GoYWGgXc8tXH+uG4AgUoAoGCCqGSM49 3 | AwEHoUQDQgAEBo7+zb5KiYoyLxOxnYYbGLP+YWDROxuEHUWf47hxQnGJWhR8V14m 4 | Nw/GCYysopbgTDEZVDEK4uRbDgShj0lBJA== 5 | -----END EC PRIVATE KEY----- 6 | -------------------------------------------------------------------------------- /fixtures/gen.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" || exit 1 4 | mkdir -p ./certs 5 | tg cert --out ./certs --name "ca" --is-ca --org "Aporeto" --common-name "localhost" --force 6 | tg cert --out ./certs --name "client" --auth-client --common-name "localhost" --signing-cert ./certs/ca-cert.pem --signing-cert-key ./certs/ca-key.pem --force 7 | tg cert --out ./certs --name "server" --auth-server --common-name "localhost" --dns "localhost" --signing-cert ./certs/ca-cert.pem --signing-cert-key ./certs/ca-key.pem --force 8 | -------------------------------------------------------------------------------- /authorizer/mtls/fixtures/ca-ext-cert.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | 
MIIBPTCB5KADAgECAhEAivrvpZkRC1337X+i8EyOkTAKBggqhkjOPQQDAjAOMQww 3 | CgYDVQQDEwNleHQwHhcNMTcxMDEzMjMxMjM3WhcNMjcwODIyMjMxMjM3WjAOMQww 4 | CgYDVQQDEwNleHQwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAARZSbKfP7d/08Sc 5 | 9ZjuUFBejfUc4N4K3+wFcxAHmn2ENOuMolUMW/+vL/7RqLPRWcTsgSOUmdxbABYm 6 | Q/5DIQYboyMwITAOBgNVHQ8BAf8EBAMCAQYwDwYDVR0TAQH/BAUwAwEB/zAKBggq 7 | hkjOPQQDAgNIADBFAiEAmoYnvoG5TE15qKAWbAaOxN5c2MCtjSrP0CnwtTijOfQC 8 | IB25mGCqbVOSP0A7cs2y5uvIEPe9ntrDxlK5emaF9Ef7 9 | -----END CERTIFICATE----- 10 | -------------------------------------------------------------------------------- /authorizer/mtls/fixtures/ca-root-cert.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIBPTCB5aADAgECAhACHS8TfLjFRxcCfk7S5ko1MAoGCCqGSM49BAMCMA8xDTAL 3 | BgNVBAMTBHJvb3QwHhcNMTcxMDEzMjMwOTQ5WhcNMjcwODIyMjMwOTQ5WjAPMQ0w 4 | CwYDVQQDEwRyb290MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEZrASa6uHjeyj 5 | JAeeYBIo0sF1DB8hX1Fv1TIjSNLDEyhDA+bl+LPM5WKYEZsagD5RcGa33s4bYR2f 6 | gqBtx1eqI6MjMCEwDgYDVR0PAQH/BAQDAgEGMA8GA1UdEwEB/wQFMAMBAf8wCgYI 7 | KoZIzj0EAwIDRwAwRAIgO2YRDPkM7fa9Z6Ld77d/59EpPGjKzQKiT0n4OXo7zz0C 8 | IHzFRdemtNkpM/JXVj8IVhyY7T/h6ShYn9CtZ3kD92Hu 9 | -----END CERTIFICATE----- 10 | -------------------------------------------------------------------------------- /authorizer/mtls/fixtures/ca-intermediate-cert.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIBRjCB7aADAgECAhAnzb2sr6mCl+KiVHdScn39MAoGCCqGSM49BAMCMA8xDTAL 3 | BgNVBAMTBHJvb3QwHhcNMTcxMDEzMjMxMDU3WhcNMjcwODIyMjMxMDU3WjAXMRUw 4 | EwYDVQQDEwxpbnRlcm1lZGlhdGUwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAQG 5 | jv7NvkqJijIvE7GdhhsYs/5hYNE7G4QdRZ/juHFCcYlaFHxXXiY3D8YJjKyiluBM 6 | MRlUMQri5FsOBKGPSUEkoyMwITAOBgNVHQ8BAf8EBAMCAQYwDwYDVR0TAQH/BAUw 7 | AwEB/zAKBggqhkjOPQQDAgNIADBFAiAme2oWY9u7g6kROCwrK+u/sf9RMtQOKMVu 8 | ws/4FrWMPQIhANCgWbFbWHckIgrFA/YheRy25B1/irCGW01ziyfbFnTY 9 | -----END CERTIFICATE----- 10 | 
-------------------------------------------------------------------------------- /authorizer/mtls/fixtures/ca-signer-a-cert.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIBSTCB8aADAgECAhAiq53Rx75bI9gq0+FkXMSYMAoGCCqGSM49BAMCMBcxFTAT 3 | BgNVBAMTDGludGVybWVkaWF0ZTAeFw0xNzEwMTMyMzExNTBaFw0yNzA4MjIyMzEx 4 | NTBaMBMxETAPBgNVBAMTCHNpZ25lci1hMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcD 5 | QgAEko9zulgiDQi80GPGV9VEsWzuUfP0DsCIb84XsRpUcYl7n7Njt+YHiYbmG/mI 6 | g2zWzAiTTNmetxBtS6hxqUa6pKMjMCEwDgYDVR0PAQH/BAQDAgEGMA8GA1UdEwEB 7 | /wQFMAMBAf8wCgYIKoZIzj0EAwIDRwAwRAIgNOlZXpSQsLrWyyKUfr3BbYoFmXYl 8 | qNF0f0zpeQUGzc8CIFmKIYYNAS041NzU7M8qJvio7FqZtSk1LRZC0QR1l8vB 9 | -----END CERTIFICATE----- 10 | -------------------------------------------------------------------------------- /authorizer/mtls/fixtures/ca-signer-b-cert.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIBSTCB8aADAgECAhA8BebnE3zdVOII5Zh6/UEaMAoGCCqGSM49BAMCMBcxFTAT 3 | BgNVBAMTDGludGVybWVkaWF0ZTAeFw0xNzEwMTMyMzEyMTBaFw0yNzA4MjIyMzEy 4 | MTBaMBMxETAPBgNVBAMTCHNpZ25lci1iMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcD 5 | QgAEJzTiVakpH3JD3z5Ar8tYX65Sc18nqJ/jlR8hfsshCptDILVv/pX5NGY3C8+9 6 | 4PKacrP0QNyWD2SGTW0R6DbFmKMjMCEwDgYDVR0PAQH/BAQDAgEGMA8GA1UdEwEB 7 | /wQFMAMBAf8wCgYIKoZIzj0EAwIDRwAwRAIgBmum95kc40bgQ8YpgyVEhYDUd4e5 8 | I+w3mYNk7CGEAPkCIF5esWzNKE10TJU7mPmV7vVX463dpIvSRKWxguhiFX3R 9 | -----END CERTIFICATE----- 10 | -------------------------------------------------------------------------------- /authorizer/mtls/fixtures/user-a-cert.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIBVTCB/aADAgECAhARq0YiIWt0OVVp+RqLmtgmMAoGCCqGSM49BAMCMBMxETAP 3 | BgNVBAMTCHNpZ25lci1hMB4XDTE3MTAxMzIzMTUzMVoXDTI3MDgyMjIzMTUzMVow 4 | ETEPMA0GA1UEAxMGdXNlci1hMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE6Un9 5 | 
07SdKaQS+DYeLSQXEHe9TqXZFKMUvxoT7DNFTVAMKD4znqc07A0NnUyya05pRcAW 6 | Yup+wvlTyEBFA7FTP6M1MDMwDgYDVR0PAQH/BAQDAgeAMBMGA1UdJQQMMAoGCCsG 7 | AQUFBwMCMAwGA1UdEwEB/wQCMAAwCgYIKoZIzj0EAwIDRwAwRAIgEOun3or4nuub 8 | 1i2QgNkOOSfxAbEG/stM2nEjTemXtpECIH3KX72mnKbd8eLSYFIsbAz6B55GBeF8 9 | Tuzw3YBRYF5F 10 | -----END CERTIFICATE----- 11 | -------------------------------------------------------------------------------- /authorizer/mtls/fixtures/user-b-cert.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIBVzCB/qADAgECAhEAxEXp+z1wWArT0+U85V5BhjAKBggqhkjOPQQDAjATMREw 3 | DwYDVQQDEwhzaWduZXItYjAeFw0xNzEwMTMyMzE1NDFaFw0yNzA4MjIyMzE1NDFa 4 | MBExDzANBgNVBAMTBnVzZXItYjBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABDQX 5 | w/q6ZMJiN47vM6FRuodSBmMZJUfWGy3GRtabnqf63xIXLdI7C2C6ad8fDhfmAxwy 6 | 5Wr8JHcYvcyR7+jWEiyjNTAzMA4GA1UdDwEB/wQEAwIHgDATBgNVHSUEDDAKBggr 7 | BgEFBQcDAjAMBgNVHRMBAf8EAjAAMAoGCCqGSM49BAMCA0gAMEUCIQDtf3/E/SdO 8 | r0tQSIEAZEAXsoprmc1G3GpZLxyr56fYMQIgFTZ3y3kBi27ec8Kq25RZuVOm8fW5 9 | es4EmhhC8ezi2Jo= 10 | -----END CERTIFICATE----- 11 | -------------------------------------------------------------------------------- /authorizer/mtls/fixtures/user-ext-cert.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIBVDCB+6ADAgECAhEA3zkId/yJk4HeXlMLDXWjUDAKBggqhkjOPQQDAjAOMQww 3 | CgYDVQQDEwNleHQwHhcNMTcxMDEzMjMxNTUyWhcNMjcwODIyMjMxNTUyWjATMREw 4 | DwYDVQQDEwh1c2VyLWV4dDBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABCSydHXX 5 | j3qXsBjGjvkbDxOtxK/5E/OqpNpezSRRSFQ7z15bGszttekE4fyYIKmoJaudiTWE 6 | gH5ZidOYaONUm5ajNTAzMA4GA1UdDwEB/wQEAwIHgDATBgNVHSUEDDAKBggrBgEF 7 | BQcDAjAMBgNVHRMBAf8EAjAAMAoGCCqGSM49BAMCA0gAMEUCIFheh6Cd/cMyGZc8 8 | U3E8UljBYXwWmvEpOun9hSPEKgJwAiEAsDmmNJwUqfsOpPMbAa4u2jCEmof2nwAl 9 | behXecThWbs= 10 | -----END CERTIFICATE----- 11 | -------------------------------------------------------------------------------- 
/authorizer/mtls/fixtures/server-a-cert.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIBWjCCAQCgAwIBAgIRAMNoVvbTH0w4PFr3LcVvcFIwCgYIKoZIzj0EAwIwEzER 3 | MA8GA1UEAxMIc2lnbmVyLWEwHhcNMTcxMDEzMjMxNjA0WhcNMjcwODIyMjMxNjA0 4 | WjATMREwDwYDVQQDEwhzZXJ2ZXItYTBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IA 5 | BN7VvTPx7kAHBnTLVomZazH0OPRElzVo8fDPFw+dtTw0ysOMeey/4ZT9LY8coaDx 6 | j5htXWhCxxnFUNk4RO7dv5yjNTAzMA4GA1UdDwEB/wQEAwIFoDATBgNVHSUEDDAK 7 | BggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMAoGCCqGSM49BAMCA0gAMEUCIQCL6O76 8 | hNVHe/QkonYwvOZ3CGtpolNdugNkP0ZryyTNbgIgOkG+spKTzNhWr0CADuojGjKt 9 | qHyEtDrEAIgjgpGl/FA= 10 | -----END CERTIFICATE----- 11 | -------------------------------------------------------------------------------- /authorizer/mtls/fixtures/claim-test-cert.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIBbjCCAROgAwIBAgIRALVJ9SRXYpYyBI9MgSiCgmowCgYIKoZIzj0EAwIwJzEK 3 | MAgGA1UEChMBQTEKMAgGA1UECxMBQjENMAsGA1UEAxMEdGVzdDAeFw0xODEyMTMw 4 | MjU3MDJaFw0yODEwMjEwMjU3MDJaMCcxCjAIBgNVBAoTAUExCjAIBgNVBAsTAUIx 5 | DTALBgNVBAMTBHRlc3QwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAQIATUbBEs+ 6 | b7GkC2xrSXoUsTN2moCMiy/l4SU8hc0GlrTKnkP8pRR5OOr60H8t5772Zll0ucPt 7 | 4sM7UqzQdhXxoyAwHjAOBgNVHQ8BAf8EBAMCB4AwDAYDVR0TAQH/BAIwADAKBggq 8 | hkjOPQQDAgNJADBGAiEAk00oNksw6KtS8r6JPBSclT/tD9SfMifSXupKqSBuZpcC 9 | IQCuo+G9F6BxKGZt+hC0fw4SHeXx19EaRHyYpF1vDq8xTA== 10 | -----END CERTIFICATE----- 11 | -------------------------------------------------------------------------------- /gateway/buffers_test.go: -------------------------------------------------------------------------------- 1 | package gateway 2 | 3 | import ( 4 | "testing" 5 | 6 | // nolint:revive // Allow dot imports for readability in tests 7 | . 
"github.com/smartystreets/goconvey/convey" 8 | ) 9 | 10 | func TestBufferPool(t *testing.T) { 11 | 12 | Convey("Given I create a new buffer pool", t, func() { 13 | 14 | bp := newPool(42) 15 | Convey("Then bp should be correct", func() { 16 | So(bp, ShouldNotBeNil) 17 | }) 18 | 19 | buff := bp.Get() 20 | 21 | Convey("Then the size of a buffer should be correct", func() { 22 | So(cap(buff), ShouldEqual, 42) 23 | }) 24 | 25 | bp.Put(buff) 26 | }) 27 | } 28 | -------------------------------------------------------------------------------- /fixtures/certs/ca-cert.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIBjTCCATOgAwIBAgIRAPNBiFhHdg/OWfkNNR7E4lkwCgYIKoZIzj0EAwIwJjEQ 3 | MA4GA1UEChMHQXBvcmV0bzESMBAGA1UEAxMJbG9jYWxob3N0MB4XDTIwMDkwOTIz 4 | MjM0MloXDTMwMDcxOTIzMjM0MlowJjEQMA4GA1UEChMHQXBvcmV0bzESMBAGA1UE 5 | AxMJbG9jYWxob3N0MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEN6zYuw8I0ZBa 6 | Ke4gIUMw4bAA0DIqnkglVlsw1QDtKDKcyLzZvEChFDQK8MlXQZMBY40DRvNfDwPA 7 | Xsf1RVf3Q6NCMEAwDgYDVR0PAQH/BAQDAgEGMA8GA1UdEwEB/wQFMAMBAf8wHQYD 8 | VR0OBBYEFEs1PL+tyeV8q+MTpBo5SLqXv3vpMAoGCCqGSM49BAMCA0gAMEUCIHah 9 | unFclBT/PwFpzK3Yhb1NJjcHmY/mKfMUMvTUF61fAiEAojGxSRlLdstBXWPzwWXZ 10 | hbmm3vaRGBdCOUQHQt7i4Ng= 11 | -----END CERTIFICATE----- 12 | -------------------------------------------------------------------------------- /fixtures/certs/client-cert.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIBjjCCATSgAwIBAgIQPSli5DlPidzWqCSwooYj0jAKBggqhkjOPQQDAjAmMRAw 3 | DgYDVQQKEwdBcG9yZXRvMRIwEAYDVQQDEwlsb2NhbGhvc3QwHhcNMjAwOTA5MjMy 4 | MzQyWhcNMzAwNzE5MjMyMzQyWjAUMRIwEAYDVQQDEwlsb2NhbGhvc3QwWTATBgcq 5 | hkjOPQIBBggqhkjOPQMBBwNCAASP9KKNIEysohjuTtgMjV5ZGYD8+mXl2fb7CcDi 6 | eIQ56F+MaG1utWwm3ULWNxgmYhcKWOeh6JZE/vyVKzlUJCOoo1YwVDAOBgNVHQ8B 7 | Af8EBAMCBaAwEwYDVR0lBAwwCgYIKwYBBQUHAwIwDAYDVR0TAQH/BAIwADAfBgNV 8 | 
HSMEGDAWgBRLNTy/rcnlfKvjE6QaOUi6l7976TAKBggqhkjOPQQDAgNIADBFAiBx 9 | fHVx/NcHVkKvc3pJ8QbpqqBRdByYtEhnXvY6j2R9iQIhAL/OD4K4/6mV8Z28OCcc 10 | RCADJGGFRKPj4HQ287ZqWQBa 11 | -----END CERTIFICATE----- 12 | -------------------------------------------------------------------------------- /fixtures/certs/server-cert.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIBozCCAUqgAwIBAgIQPtbn3UPWJt5Y2kt/I1ddKzAKBggqhkjOPQQDAjAmMRAw 3 | DgYDVQQKEwdBcG9yZXRvMRIwEAYDVQQDEwlsb2NhbGhvc3QwHhcNMjAwOTA5MjMy 4 | MzQyWhcNMzAwNzE5MjMyMzQyWjAUMRIwEAYDVQQDEwlsb2NhbGhvc3QwWTATBgcq 5 | hkjOPQIBBggqhkjOPQMBBwNCAAThobWlAyfM+AodxO7QndBeef9l6amdCOfusgZN 6 | mwphSeDo8qphNtJplSp2sjHt5qRzMteI8DMDYfdBFKY1WT4Ho2wwajAOBgNVHQ8B 7 | Af8EBAMCBaAwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDAYDVR0TAQH/BAIwADAfBgNV 8 | HSMEGDAWgBRLNTy/rcnlfKvjE6QaOUi6l7976TAUBgNVHREEDTALgglsb2NhbGhv 9 | c3QwCgYIKoZIzj0EAwIDRwAwRAIgX/ZW9OcqslTo6GzYNc5qgZEJCMMgpOmzRhll 10 | nXN59RsCICyIFX6CqgnTj4I9BjJcLaXdu1bG9Pri9GVlKEHIdaXt 11 | -----END CERTIFICATE----- 12 | -------------------------------------------------------------------------------- /gateway/buffers.go: -------------------------------------------------------------------------------- 1 | package gateway 2 | 3 | import "sync" 4 | 5 | // bufferPool implements the interface of httputil.BufferPool in order 6 | // to improve memory utilization in the reverse proxy. 7 | type bufferPool struct { 8 | s sync.Pool 9 | } 10 | 11 | func newPool(size int) *bufferPool { 12 | return &bufferPool{ 13 | s: sync.Pool{ 14 | New: func() any { 15 | return make([]byte, size) 16 | }, 17 | }, 18 | } 19 | } 20 | 21 | // Get gets a buffer from the pool. 22 | func (b *bufferPool) Get() []byte { 23 | return b.s.Get().([]byte) 24 | } 25 | 26 | // Put returns the buffer to the pool. 
27 | func (b *bufferPool) Put(buf []byte) { 28 | b.s.Put(buf) // nolint 29 | } 30 | -------------------------------------------------------------------------------- /authorizer/mtls/doc.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Aporeto Inc. 2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License. 4 | // You may obtain a copy of the License at 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // Unless required by applicable law or agreed to in writing, software 7 | // distributed under the License is distributed on an "AS IS" BASIS, 8 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | // See the License for the specific language governing permissions and 10 | // limitations under the License. 11 | 12 | package mtls // import "go.aporeto.io/bahamut/authorizer/mtls" 13 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | MAKEFLAGS += --warn-undefined-variables 2 | SHELL := /bin/bash -o pipefail 3 | 4 | export GO111MODULE = on 5 | 6 | default: lint test 7 | 8 | lint: 9 | # --enable=unparam 10 | golangci-lint run \ 11 | --disable-all \ 12 | --exclude-use-default=false \ 13 | --exclude=package-comments \ 14 | --exclude=unused-parameter \ 15 | --enable=errcheck \ 16 | --enable=goimports \ 17 | --enable=ineffassign \ 18 | --enable=revive \ 19 | --enable=unused \ 20 | --enable=staticcheck \ 21 | --enable=unconvert \ 22 | --enable=misspell \ 23 | --enable=prealloc \ 24 | --enable=nakedret \ 25 | --enable=typecheck \ 26 | ./... 27 | test: 28 | go test ./... -race -cover -covermode=atomic -coverprofile=unit_coverage.out 29 | 30 | sec: 31 | gosec -quiet ./... 
32 | -------------------------------------------------------------------------------- /.github/workflows/build-go.yaml: -------------------------------------------------------------------------------- 1 | name: build-go 2 | on: 3 | push: 4 | branches: 5 | - master 6 | pull_request: 7 | 8 | defaults: 9 | run: 10 | shell: bash 11 | 12 | env: 13 | GO111MODULE: on 14 | 15 | jobs: 16 | build: 17 | runs-on: ubuntu-latest 18 | strategy: 19 | fail-fast: false 20 | matrix: 21 | go: 22 | - "1.21" 23 | - "1.22" 24 | steps: 25 | - uses: actions/checkout@f43a0e5ff2bd294095638e18286ca9a3d1956744 # v3 26 | 27 | - uses: actions/setup-go@93397bea11091df50f3d7e59dc26a7711a8bcfbe # v4 28 | with: 29 | go-version: ${{ matrix.go }} 30 | cache: true 31 | 32 | - name: setup 33 | run: | 34 | go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest 35 | 36 | - name: build 37 | run: | 38 | make 39 | -------------------------------------------------------------------------------- /authorizer/simple/doc.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Aporeto Inc. 2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License. 4 | // You may obtain a copy of the License at 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // Unless required by applicable law or agreed to in writing, software 7 | // distributed under the License is distributed on an "AS IS" BASIS, 8 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | // See the License for the specific language governing permissions and 10 | // limitations under the License. 11 | 12 | // Package simple provides implementations of bahamut.SessionAuthenticator 13 | // bahamut.RequestAuthenticator and a bahamut.Authorizer using 14 | // a given function to decide if a request should be authenticated/authorized. 
15 | package simple // import "go.aporeto.io/bahamut/authorizer/simple" 16 | -------------------------------------------------------------------------------- /gateway/upstreamer/push/randomizer.go: -------------------------------------------------------------------------------- 1 | package push 2 | 3 | import ( 4 | "math/rand" 5 | "sync" 6 | "time" 7 | ) 8 | 9 | // A Randomizer represents an interface to generate random integers and shuffle sequences. 10 | type Randomizer interface { 11 | Intn(int) int 12 | Shuffle(n int, swap func(i, j int)) 13 | } 14 | 15 | // newRandomizer returns a new Randomizer seeded with the current time. 16 | func newRandomizer() Randomizer { 17 | return &defaultRandomizer{random: rand.New(rand.NewSource(time.Now().UnixNano()))} 18 | } 19 | 20 | // defaultRandomizer is the default Randomizer. The embedded mutex serializes every access to random, since a *rand.Rand built with rand.New is not safe for concurrent use. 21 | type defaultRandomizer struct { 22 | sync.Mutex 23 | random *rand.Rand 24 | } 25 | 26 | // Intn implements the Randomizer interface. Like (*rand.Rand).Intn, it panics if n <= 0. 27 | func (r *defaultRandomizer) Intn(n int) int { 28 | r.Lock() 29 | defer r.Unlock() 30 | return r.random.Intn(n) 31 | } 32 | 33 | // Shuffle implements the Randomizer interface. 34 | func (r *defaultRandomizer) Shuffle(n int, swap func(i, j int)) { 35 | r.Lock() 36 | r.random.Shuffle(n, swap) 37 | r.Unlock() 38 | } 39 | -------------------------------------------------------------------------------- /gateway/listener.go: -------------------------------------------------------------------------------- 1 | package gateway 2 | 3 | import ( 4 | "net" 5 | 6 | "golang.org/x/time/rate" 7 | ) 8 | 9 | type limitListener struct { // wraps Listener and applies limiter to every connection it accepts 10 | net.Listener 11 | limiter *rate.Limiter 12 | metricManager LimiterMetricManager // optional: nil disables metrics recording in Accept 13 | } 14 | 15 | func newLimitedListener(l net.Listener, cps rate.Limit, burst int, metricManager LimiterMetricManager) net.Listener { 16 | 17 | return &limitListener{ 18 | Listener: l, 19 | limiter: rate.NewLimiter(cps, burst), 20 | metricManager: metricManager, 21 | } 22 | 23 | } 24 | 25 | func (l *limitListener) Accept() (net.Conn, error) { 26 | 27 | c, err := l.Listener.Accept() 28 | if
err != nil { 29 | return nil, err 30 | } 31 | 32 | if !l.limiter.Allow() { // NOTE(review): a rate-limited connection is closed here but still returned below with a nil error, so the caller receives an already-closed net.Conn. 33 | c.Close() // nolint 34 | if l.metricManager != nil { 35 | l.metricManager.RegisterLimitedConnection() 36 | } 37 | } else { 38 | if l.metricManager != nil { 39 | l.metricManager.RegisterAcceptedConnection() 40 | } 41 | } 42 | 43 | return c, nil 44 | } 45 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Aporeto Inc. 2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License. 4 | // You may obtain a copy of the License at 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // Unless required by applicable law or agreed to in writing, software 7 | // distributed under the License is distributed on an "AS IS" BASIS, 8 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | // See the License for the specific language governing permissions and 10 | // limitations under the License. 11 | 12 | // Package bahamut contains everything needed to build a fast and secure API server 13 | // based on a set of Regolithe Specifications. 14 | // 15 | // Bahamut uses an Elemental model generated from a set of Regolithe Specifications. 16 | // You will just need to write various processors to handle the business logic and storage. 17 | package bahamut // import "go.aporeto.io/bahamut" 18 | -------------------------------------------------------------------------------- /job.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Aporeto Inc. 2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License.
4 | // You may obtain a copy of the License at 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // Unless required by applicable law or agreed to in writing, software 7 | // distributed under the License is distributed on an "AS IS" BASIS, 8 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | // See the License for the specific language governing permissions and 10 | // limitations under the License. 11 | 12 | package bahamut 13 | 14 | import "context" 15 | 16 | // Job is the type of function that can be run as a Job. 17 | type Job func() error 18 | 19 | // RunJob runs a Job that can be canceled at any time according to the context. It returns (true, nil) if ctx is done first; the job itself is not interrupted and keeps running in its goroutine. Otherwise it returns false and the error returned by the job. 20 | func RunJob(ctx context.Context, job Job) (bool, error) { 21 | 22 | out := make(chan error, 1) // buffered so the job goroutine can always send and exit, even when RunJob already returned because ctx was done 23 | 24 | go func() { out <- job() }() 25 | 26 | select { 27 | case <-ctx.Done(): 28 | return true, nil 29 | case err := <-out: 30 | return false, err 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /pubsub_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Aporeto Inc. 2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License. 4 | // You may obtain a copy of the License at 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // Unless required by applicable law or agreed to in writing, software 7 | // distributed under the License is distributed on an "AS IS" BASIS, 8 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | // See the License for the specific language governing permissions and 10 | // limitations under the License. 11 | 12 | package bahamut 13 | 14 | import ( 15 | "testing" 16 | 17 | // nolint:revive // Allow dot imports for readability in tests 18 | .
"github.com/smartystreets/goconvey/convey" 19 | ) 20 | 21 | func TestPubsub_NewServer(t *testing.T) { 22 | 23 | Convey("Given I create a new localPubSubServer", t, func() { 24 | 25 | ps := NewLocalPubSubClient() 26 | 27 | Convey("Then the PubSubServer should be correctly initialized", func() { 28 | So(ps, ShouldImplement, (*PubSubClient)(nil)) 29 | }) 30 | }) 31 | } 32 | -------------------------------------------------------------------------------- /authorizer/mtls/utils.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Aporeto Inc. 2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License. 4 | // You may obtain a copy of the License at 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // Unless required by applicable law or agreed to in writing, software 7 | // distributed under the License is distributed on an "AS IS" BASIS, 8 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | // See the License for the specific language governing permissions and 10 | // limitations under the License. 11 | 12 | package mtls 13 | 14 | import "crypto/x509" 15 | 16 | func makeClaims(cert *x509.Certificate) []string { 17 | 18 | claims := []string{ 19 | "@auth:realm=certificate", 20 | "@auth:mode=internal", 21 | "@auth:serialnumber=" + cert.SerialNumber.String(), 22 | "@auth:commonname=" + cert.Subject.CommonName, 23 | } 24 | 25 | for _, o := range cert.Subject.Organization { 26 | claims = append(claims, "@auth:organization="+o) 27 | } 28 | 29 | for _, ou := range cert.Subject.OrganizationalUnit { 30 | claims = append(claims, "@auth:organizationalunit="+ou) 31 | } 32 | 33 | return claims 34 | } 35 | -------------------------------------------------------------------------------- /metrics.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Aporeto Inc. 
2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License. 4 | // You may obtain a copy of the License at 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // Unless required by applicable law or agreed to in writing, software 7 | // distributed under the License is distributed on an "AS IS" BASIS, 8 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | // See the License for the specific language governing permissions and 10 | // limitations under the License. 11 | 12 | package bahamut 13 | 14 | import ( 15 | "net/http" 16 | "time" 17 | 18 | opentracing "github.com/opentracing/opentracing-go" 19 | ) 20 | 21 | // FinishMeasurementFunc is the kind of functinon returned by MetricsManager.MeasureRequest(). 22 | type FinishMeasurementFunc func(code int, span opentracing.Span) time.Duration 23 | 24 | // A MetricsManager handles Prometheus Metrics Management 25 | type MetricsManager interface { 26 | MeasureRequest(method string, path string) FinishMeasurementFunc 27 | RegisterWSConnection() 28 | UnregisterWSConnection() 29 | RegisterTCPConnection() 30 | UnregisterTCPConnection() 31 | Write(w http.ResponseWriter, r *http.Request) 32 | } 33 | -------------------------------------------------------------------------------- /pubsub.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Aporeto Inc. 2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License. 4 | // You may obtain a copy of the License at 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // Unless required by applicable law or agreed to in writing, software 7 | // distributed under the License is distributed on an "AS IS" BASIS, 8 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
9 | // See the License for the specific language governing permissions and 10 | // limitations under the License. 11 | 12 | package bahamut 13 | 14 | import ( 15 | "context" 16 | ) 17 | 18 | // PubSubOptPublish is the type of option that can use in PubSubClient.Publish. 19 | type PubSubOptPublish func(any) 20 | 21 | // PubSubOptSubscribe is the type of option that can use in PubSubClient.Subscribe. 22 | type PubSubOptSubscribe func(any) 23 | 24 | // A PubSubClient is a structure that provides a publish/subscribe mechanism. 25 | type PubSubClient interface { 26 | Publish(publication *Publication, opts ...PubSubOptPublish) error 27 | Subscribe(pubs chan *Publication, errors chan error, topic string, opts ...PubSubOptSubscribe) func() 28 | Connect(ctx context.Context) error 29 | Disconnect() error 30 | } 31 | -------------------------------------------------------------------------------- /gateway/upstreamer/push/movingaverage_test.go: -------------------------------------------------------------------------------- 1 | package push 2 | 3 | import ( 4 | "testing" 5 | 6 | // nolint:revive // Allow dot imports for readability in tests 7 | . 
"github.com/smartystreets/goconvey/convey" 8 | ) 9 | 10 | func TestMovingAverage(t *testing.T) { 11 | 12 | Convey("Given I have a moving average of size 3", t, func() { 13 | ma := newMovingAverage(3) 14 | 15 | Convey("When I push no values the average is not available", func() { 16 | v, err := ma.average() 17 | So(v, ShouldEqual, 0) 18 | So(err, ShouldNotBeNil) 19 | }) 20 | 21 | Convey("When I push two values the average is not available", func() { 22 | ma = ma.append(1) 23 | ma = ma.append(1) 24 | v, err := ma.average() 25 | So(v, ShouldEqual, 0) 26 | So(err, ShouldNotBeNil) 27 | }) 28 | 29 | Convey("When I push a tree values the average calculated", func() { 30 | ma = ma.append(1) 31 | ma = ma.append(1) 32 | ma = ma.append(1) 33 | v, err := ma.average() 34 | So(v, ShouldEqual, 1) 35 | So(err, ShouldBeNil) 36 | }) 37 | 38 | Convey("When I push a four values the average calculated", func() { 39 | ma = ma.append(1) 40 | ma = ma.append(1) 41 | ma = ma.append(1) 42 | ma = ma.append(1) 43 | v, err := ma.average() 44 | So(v, ShouldEqual, 1) 45 | So(err, ShouldBeNil) 46 | }) 47 | 48 | }) 49 | } 50 | -------------------------------------------------------------------------------- /gateway/upstreamer/push/ping.go: -------------------------------------------------------------------------------- 1 | package push 2 | 3 | import ( 4 | "go.aporeto.io/bahamut" 5 | "golang.org/x/time/rate" 6 | ) 7 | 8 | type entityStatus int 9 | 10 | const ( 11 | entityStatusGoodbye entityStatus = 0 12 | entityStatusHello entityStatus = 1 13 | ) 14 | 15 | // An APILimiter holds the parameters of a *rate.Limiter. 16 | // It is used to announce a desired rate limit for 17 | // inconming requests. 18 | type APILimiter struct { 19 | limiter *rate.Limiter 20 | Limit rate.Limit 21 | Burst int 22 | } 23 | 24 | // IdentityToAPILimitersRegistry is a map of elemental.Identity Name 25 | // to an AnnouncedRateLimits. 
26 | type IdentityToAPILimitersRegistry map[string]*APILimiter 27 | 28 | type servicePing struct { 29 | Routes map[int][]bahamut.RouteInfo 30 | Versions map[string]any 31 | APILimiters IdentityToAPILimitersRegistry 32 | Name string 33 | Endpoint string 34 | PushEndpoint string 35 | Prefix string 36 | Status entityStatus 37 | Load float64 38 | } 39 | 40 | // Key returns the key for the service. 41 | // This is either the name or prefix/name, if any. 42 | func (s *servicePing) Key() string { 43 | if s.Prefix != "" { 44 | return s.Prefix + "/" + s.Name 45 | } 46 | 47 | return s.Name 48 | } 49 | 50 | type peerPing struct { 51 | RuntimeID string 52 | Status entityStatus 53 | } 54 | -------------------------------------------------------------------------------- /gateway/upstreamer/push/notifier_options_test.go: -------------------------------------------------------------------------------- 1 | package push 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | // nolint:revive // Allow dot imports for readability in tests 8 | . 
"github.com/smartystreets/goconvey/convey" 9 | "go.aporeto.io/elemental" 10 | "golang.org/x/time/rate" 11 | ) 12 | 13 | func Test_NotiferOptions(t *testing.T) { 14 | 15 | c := newNotifierConfig() 16 | 17 | Convey("Calling OptionNotifierAnnounceRateLimits should work", t, func() { 18 | rls := IdentityToAPILimitersRegistry{ 19 | "a": {Limit: rate.Limit(1), Burst: 2}, 20 | } 21 | OptionNotifierAnnounceRateLimits(rls)(&c) 22 | So(c.rateLimits, ShouldResemble, rls) 23 | So(c.rateLimits, ShouldNotEqual, rls) 24 | }) 25 | 26 | Convey("Calling OptionNotifierPingInterval should work", t, func() { 27 | OptionNotifierPingInterval(3 * time.Hour)(&c) 28 | So(c.pingInterval, ShouldEqual, 3*time.Hour) 29 | }) 30 | 31 | Convey("Calling OptionNotifierPrefix should work", t, func() { 32 | OptionNotifierPrefix("prefix")(&c) 33 | So(c.prefix, ShouldEqual, "prefix") 34 | }) 35 | 36 | Convey("Calling OptionNotifierAPIPrivateOverrides should work", t, func() { 37 | ov := map[elemental.Identity]bool{ 38 | elemental.MakeIdentity("thing", "things"): true, 39 | } 40 | OptionNotifierPrivateAPIOverrides(ov)(&c) 41 | So(c.privateOverrides, ShouldResemble, map[string]bool{"things": true}) 42 | }) 43 | } 44 | -------------------------------------------------------------------------------- /gateway/upstreamer/push/movingaverage.go: -------------------------------------------------------------------------------- 1 | package push 2 | 3 | import "fmt" 4 | 5 | // MovingAverage represent a moving average 6 | // give a sample size. 
7 | type movingAverage struct { 8 | ring []float64 9 | sampleSize int 10 | nextIdx int 11 | samplingComplete bool 12 | } 13 | 14 | // newMovingAverage return a new movingAverage 15 | func newMovingAverage(sampleSize int) movingAverage { 16 | 17 | if sampleSize <= 0 { 18 | panic("sampleSize must be greather than 0.") 19 | } 20 | 21 | return movingAverage{ 22 | sampleSize: sampleSize, 23 | ring: make([]float64, sampleSize), 24 | } 25 | } 26 | 27 | // average return the average of the sampleSize 28 | // If sampleSize are not compplete it returns 0 29 | func (m movingAverage) average() (float64, error) { 30 | 31 | sum := .0 32 | 33 | if !m.samplingComplete { 34 | return sum, fmt.Errorf("cannot compute average without a full sampling") 35 | } 36 | 37 | for _, value := range m.ring { 38 | sum += value 39 | } 40 | 41 | return sum / float64(m.sampleSize), nil 42 | } 43 | 44 | // append will insert a new value to the ring and return a copy 45 | // of itself 46 | func (m movingAverage) append(value float64) movingAverage { 47 | 48 | nm := newMovingAverage(m.sampleSize) 49 | nm.samplingComplete = m.samplingComplete 50 | for i := range m.ring { 51 | nm.ring[i] = m.ring[i] 52 | } 53 | 54 | nm.ring[m.nextIdx] = value 55 | nm.nextIdx = (m.nextIdx + 1) % nm.sampleSize 56 | if nm.nextIdx == 0 { 57 | nm.samplingComplete = true 58 | } 59 | 60 | return nm 61 | } 62 | -------------------------------------------------------------------------------- /netlimiter.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package bahamut 6 | 7 | import ( 8 | "net" 9 | "sync/atomic" 10 | ) 11 | 12 | type limitListener struct { 13 | net.Listener 14 | nConn int64 15 | maxConn int64 16 | } 17 | 18 | // newListener returns a Listener that uses the given semaphore to accept at most 19 | // n simultaneous connections from the provided Listener where n is the size of 20 | // the given channel. 21 | func newListener(l net.Listener, n int) *limitListener { 22 | 23 | return &limitListener{ 24 | Listener: l, 25 | maxConn: int64(n), 26 | } 27 | } 28 | 29 | func (l *limitListener) release() { 30 | 31 | atomic.AddInt64(&l.nConn, -1) 32 | } 33 | 34 | func (l *limitListener) Accept() (net.Conn, error) { 35 | 36 | for { 37 | 38 | c, err := l.Listener.Accept() 39 | if err != nil { 40 | return nil, err 41 | } 42 | 43 | var currentConn int64 44 | if l.maxConn > 0 { 45 | currentConn = atomic.AddInt64(&l.nConn, 1) 46 | } 47 | 48 | if currentConn > l.maxConn { 49 | c.Close() // nolint: errcheck 50 | l.release() 51 | continue 52 | } 53 | 54 | return &limitListenerConn{Conn: c, release: l.release}, nil 55 | } 56 | } 57 | 58 | func (l *limitListener) Close() error { 59 | return l.Listener.Close() 60 | } 61 | 62 | type limitListenerConn struct { 63 | net.Conn 64 | release func() 65 | } 66 | 67 | func (c *limitListenerConn) Close() error { 68 | c.release() 69 | return c.Conn.Close() 70 | } 71 | -------------------------------------------------------------------------------- /gateway/extractors.go: -------------------------------------------------------------------------------- 1 | package gateway 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | 7 | "github.com/cespare/xxhash" 8 | ) 9 | 10 | type defaultSourceExtractor struct { 11 | authCookieName string 12 | } 13 | 14 | // NewDefaultSourceExtractor returns a default SourceExtractor. 15 | // A source extractor will discriminate the source of a request 16 | // based on a hash of its authentication string. 
17 | // It will first use an eventual cookie with the given name, 18 | // then use then use the Authorization header. 19 | // If both are empty, the bucket key will be 'default'. 20 | // If authCookieName is empty, only the value of the Authorization 21 | // header will be taken into account. 22 | func NewDefaultSourceExtractor(authCookieName string) SourceExtractor { 23 | return defaultSourceExtractor{ 24 | authCookieName: authCookieName, 25 | } 26 | } 27 | 28 | func (f defaultSourceExtractor) ExtractSource(r *http.Request) (string, error) { 29 | 30 | var v string 31 | authHeader := r.Header.Get("Authorization") 32 | 33 | var authCookie *http.Cookie 34 | if f.authCookieName != "" { 35 | authCookie, _ = r.Cookie(f.authCookieName) 36 | } 37 | 38 | switch { 39 | case authCookie != nil && authCookie.Value != "": 40 | v = authCookie.Value 41 | case authHeader != "": 42 | v = authHeader 43 | default: 44 | return "default", nil 45 | } 46 | 47 | return fmt.Sprintf("%d", xxhash.Sum64([]byte(v))), nil 48 | } 49 | 50 | type defaultTCPSourceExtractor struct{} 51 | 52 | func (f defaultTCPSourceExtractor) ExtractSource(r *http.Request) (string, error) { 53 | 54 | return r.RemoteAddr, nil 55 | } 56 | -------------------------------------------------------------------------------- /authorizer/mtls/fixtures/ca-chain-a.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIBSTCB8aADAgECAhAiq53Rx75bI9gq0+FkXMSYMAoGCCqGSM49BAMCMBcxFTAT 3 | BgNVBAMTDGludGVybWVkaWF0ZTAeFw0xNzEwMTMyMzExNTBaFw0yNzA4MjIyMzEx 4 | NTBaMBMxETAPBgNVBAMTCHNpZ25lci1hMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcD 5 | QgAEko9zulgiDQi80GPGV9VEsWzuUfP0DsCIb84XsRpUcYl7n7Njt+YHiYbmG/mI 6 | g2zWzAiTTNmetxBtS6hxqUa6pKMjMCEwDgYDVR0PAQH/BAQDAgEGMA8GA1UdEwEB 7 | /wQFMAMBAf8wCgYIKoZIzj0EAwIDRwAwRAIgNOlZXpSQsLrWyyKUfr3BbYoFmXYl 8 | qNF0f0zpeQUGzc8CIFmKIYYNAS041NzU7M8qJvio7FqZtSk1LRZC0QR1l8vB 9 | -----END CERTIFICATE----- 10 | -----BEGIN CERTIFICATE----- 11 | 
MIIBRjCB7aADAgECAhAnzb2sr6mCl+KiVHdScn39MAoGCCqGSM49BAMCMA8xDTAL 12 | BgNVBAMTBHJvb3QwHhcNMTcxMDEzMjMxMDU3WhcNMjcwODIyMjMxMDU3WjAXMRUw 13 | EwYDVQQDEwxpbnRlcm1lZGlhdGUwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAQG 14 | jv7NvkqJijIvE7GdhhsYs/5hYNE7G4QdRZ/juHFCcYlaFHxXXiY3D8YJjKyiluBM 15 | MRlUMQri5FsOBKGPSUEkoyMwITAOBgNVHQ8BAf8EBAMCAQYwDwYDVR0TAQH/BAUw 16 | AwEB/zAKBggqhkjOPQQDAgNIADBFAiAme2oWY9u7g6kROCwrK+u/sf9RMtQOKMVu 17 | ws/4FrWMPQIhANCgWbFbWHckIgrFA/YheRy25B1/irCGW01ziyfbFnTY 18 | -----END CERTIFICATE----- 19 | -----BEGIN CERTIFICATE----- 20 | MIIBPTCB5aADAgECAhACHS8TfLjFRxcCfk7S5ko1MAoGCCqGSM49BAMCMA8xDTAL 21 | BgNVBAMTBHJvb3QwHhcNMTcxMDEzMjMwOTQ5WhcNMjcwODIyMjMwOTQ5WjAPMQ0w 22 | CwYDVQQDEwRyb290MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEZrASa6uHjeyj 23 | JAeeYBIo0sF1DB8hX1Fv1TIjSNLDEyhDA+bl+LPM5WKYEZsagD5RcGa33s4bYR2f 24 | gqBtx1eqI6MjMCEwDgYDVR0PAQH/BAQDAgEGMA8GA1UdEwEB/wQFMAMBAf8wCgYI 25 | KoZIzj0EAwIDRwAwRAIgO2YRDPkM7fa9Z6Ld77d/59EpPGjKzQKiT0n4OXo7zz0C 26 | IHzFRdemtNkpM/JXVj8IVhyY7T/h6ShYn9CtZ3kD92Hu 27 | -----END CERTIFICATE----- 28 | -------------------------------------------------------------------------------- /authorizer/mtls/fixtures/ca-chain-b.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIBSTCB8aADAgECAhA8BebnE3zdVOII5Zh6/UEaMAoGCCqGSM49BAMCMBcxFTAT 3 | BgNVBAMTDGludGVybWVkaWF0ZTAeFw0xNzEwMTMyMzEyMTBaFw0yNzA4MjIyMzEy 4 | MTBaMBMxETAPBgNVBAMTCHNpZ25lci1iMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcD 5 | QgAEJzTiVakpH3JD3z5Ar8tYX65Sc18nqJ/jlR8hfsshCptDILVv/pX5NGY3C8+9 6 | 4PKacrP0QNyWD2SGTW0R6DbFmKMjMCEwDgYDVR0PAQH/BAQDAgEGMA8GA1UdEwEB 7 | /wQFMAMBAf8wCgYIKoZIzj0EAwIDRwAwRAIgBmum95kc40bgQ8YpgyVEhYDUd4e5 8 | I+w3mYNk7CGEAPkCIF5esWzNKE10TJU7mPmV7vVX463dpIvSRKWxguhiFX3R 9 | -----END CERTIFICATE----- 10 | -----BEGIN CERTIFICATE----- 11 | MIIBRjCB7aADAgECAhAnzb2sr6mCl+KiVHdScn39MAoGCCqGSM49BAMCMA8xDTAL 12 | BgNVBAMTBHJvb3QwHhcNMTcxMDEzMjMxMDU3WhcNMjcwODIyMjMxMDU3WjAXMRUw 13 | 
EwYDVQQDEwxpbnRlcm1lZGlhdGUwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAQG 14 | jv7NvkqJijIvE7GdhhsYs/5hYNE7G4QdRZ/juHFCcYlaFHxXXiY3D8YJjKyiluBM 15 | MRlUMQri5FsOBKGPSUEkoyMwITAOBgNVHQ8BAf8EBAMCAQYwDwYDVR0TAQH/BAUw 16 | AwEB/zAKBggqhkjOPQQDAgNIADBFAiAme2oWY9u7g6kROCwrK+u/sf9RMtQOKMVu 17 | ws/4FrWMPQIhANCgWbFbWHckIgrFA/YheRy25B1/irCGW01ziyfbFnTY 18 | -----END CERTIFICATE----- 19 | -----BEGIN CERTIFICATE----- 20 | MIIBPTCB5aADAgECAhACHS8TfLjFRxcCfk7S5ko1MAoGCCqGSM49BAMCMA8xDTAL 21 | BgNVBAMTBHJvb3QwHhcNMTcxMDEzMjMwOTQ5WhcNMjcwODIyMjMwOTQ5WjAPMQ0w 22 | CwYDVQQDEwRyb290MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEZrASa6uHjeyj 23 | JAeeYBIo0sF1DB8hX1Fv1TIjSNLDEyhDA+bl+LPM5WKYEZsagD5RcGa33s4bYR2f 24 | gqBtx1eqI6MjMCEwDgYDVR0PAQH/BAQDAgEGMA8GA1UdEwEB/wQFMAMBAf8wCgYI 25 | KoZIzj0EAwIDRwAwRAIgO2YRDPkM7fa9Z6Ld77d/59EpPGjKzQKiT0n4OXo7zz0C 26 | IHzFRdemtNkpM/JXVj8IVhyY7T/h6ShYn9CtZ3kD92Hu 27 | -----END CERTIFICATE----- 28 | -------------------------------------------------------------------------------- /authorizer/simple/authorizer.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Aporeto Inc. 2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License. 4 | // You may obtain a copy of the License at 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // Unless required by applicable law or agreed to in writing, software 7 | // distributed under the License is distributed on an "AS IS" BASIS, 8 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | // See the License for the specific language governing permissions and 10 | // limitations under the License. 11 | 12 | package simple 13 | 14 | import "go.aporeto.io/bahamut" 15 | 16 | // An Authorizer is a bahamut.Authorizer compliant structure to authorize 17 | // requests using a given function.
18 | type Authorizer struct { 19 | customAuthFunc CustomAuthRequestFunc 20 | } 21 | 22 | // NewAuthorizer returns a new *Authorizer. 23 | func NewAuthorizer(customAuthFunc CustomAuthRequestFunc) *Authorizer { 24 | 25 | return &Authorizer{ 26 | customAuthFunc: customAuthFunc, 27 | } 28 | } 29 | 30 | // IsAuthorized authorizer the given context. 31 | // It will return true if the authentication is a success, false in case of failure 32 | // and an eventual error in case of error. 33 | func (a *Authorizer) IsAuthorized(ctx bahamut.Context) (bahamut.AuthAction, error) { 34 | 35 | if a.customAuthFunc == nil { 36 | return bahamut.AuthActionContinue, nil 37 | } 38 | 39 | action, err := a.customAuthFunc(ctx) 40 | if err != nil { 41 | return bahamut.AuthActionKO, err 42 | } 43 | 44 | return action, nil 45 | } 46 | -------------------------------------------------------------------------------- /authorizer/simple/pushpublisher.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Aporeto Inc. 2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License. 4 | // You may obtain a copy of the License at 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // Unless required by applicable law or agreed to in writing, software 7 | // distributed under the License is distributed on an "AS IS" BASIS, 8 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | // See the License for the specific language governing permissions and 10 | // limitations under the License. 11 | 12 | package simple 13 | 14 | import ( 15 | "go.aporeto.io/elemental" 16 | ) 17 | 18 | // CustomShouldPublishFunc is the type of function that can be used 19 | // to decide is an event should be published. 20 | type CustomShouldPublishFunc func(*elemental.Event) (bool, error) 21 | 22 | // A PublishHandler handles publish decisions. 
23 | type PublishHandler struct { 24 | shouldPublishFunc CustomShouldPublishFunc 25 | } 26 | 27 | // NewPublishHandler returns a new PushSessionsHandler. If shouldPublishFunc is nil 28 | // the publisher will dispatch all events. 29 | func NewPublishHandler(shouldPublishFunc CustomShouldPublishFunc) *PublishHandler { 30 | 31 | return &PublishHandler{ 32 | shouldPublishFunc: shouldPublishFunc, 33 | } 34 | } 35 | 36 | // ShouldPublish is part of the bahamut.PushPublishHandler interface 37 | func (g *PublishHandler) ShouldPublish(event *elemental.Event) (bool, error) { 38 | 39 | if g.shouldPublishFunc == nil { 40 | return true, nil 41 | } 42 | 43 | return g.shouldPublishFunc(event) 44 | } 45 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Bahamut 2 | 3 | [![Codacy Badge](https://app.codacy.com/project/badge/Grade/f8d3dbbc552b4c8abf8985425d25c338)](https://www.codacy.com/gh/PaloAltoNetworks/bahamut/dashboard?utm_source=github.com&utm_medium=referral&utm_content=PaloAltoNetworks/bahamut&utm_campaign=Badge_Grade) [![Codacy Badge](https://app.codacy.com/project/badge/Coverage/f8d3dbbc552b4c8abf8985425d25c338)](https://www.codacy.com/gh/PaloAltoNetworks/bahamut/dashboard?utm_source=github.com&utm_medium=referral&utm_content=PaloAltoNetworks/bahamut&utm_campaign=Badge_Coverage) 4 | 5 | > Note: this is a work in progress. 6 | 7 | Bahamut is a Go library that provides everything you need to set up a full blown 8 | API server based on an [Elemental](https://go.aporeto.io/elemental) model 9 | generated from a [Regolithe Specification](https://go.aporeto.io/regolithe). 10 | 11 | The main concept of Bahamut is to only write core business logic, and letting it 12 | handle all the boring bookkeeping. You can implement various Processors 13 | interfaces, and register them when you start a Bahamut Server. 
14 | 15 | A Bahamut Server is not directly responsible for storing and retrieving data from 16 | a database. To do so, you can use any backend library you like in your 17 | processors, but we recommend using 18 | [Manipulate](https://go.aporeto.io/manipulate), which provides a common 19 | interface for manipulating an Elemental model and multiple implementations for 20 | MongoDB (manipmongo) and MemDB (manipmemory), and can also be used to issue ReST calls using 21 | maniphttp. 22 | 23 | Manipulate is usually used by clients to interact with the API of a Bahamut service, but 24 | also by Bahamut services to talk to each other. 25 | -------------------------------------------------------------------------------- /authorizer/mtls/utils_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Aporeto Inc. 2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License. 4 | // You may obtain a copy of the License at 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // Unless required by applicable law or agreed to in writing, software 7 | // distributed under the License is distributed on an "AS IS" BASIS, 8 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | // See the License for the specific language governing permissions and 10 | // limitations under the License.
11 | 12 | package mtls 13 | 14 | import ( 15 | "crypto/x509" 16 | "encoding/pem" 17 | "os" 18 | "reflect" 19 | "testing" 20 | ) 21 | 22 | func Test_makeClaims(t *testing.T) { 23 | 24 | cdata, _ := os.ReadFile("./fixtures/claim-test-cert.pem") 25 | cblock, _ := pem.Decode(cdata) 26 | cert, _ := x509.ParseCertificate(cblock.Bytes) 27 | 28 | type args struct { 29 | cert *x509.Certificate 30 | } 31 | tests := []struct { 32 | name string 33 | args args 34 | want []string 35 | }{ 36 | { 37 | "simple", 38 | args{ 39 | cert, 40 | }, 41 | []string{ 42 | "@auth:realm=certificate", 43 | "@auth:mode=internal", 44 | "@auth:serialnumber=240974276977353940447659278772794983018", 45 | "@auth:commonname=test", 46 | "@auth:organization=A", 47 | "@auth:organizationalunit=B", 48 | }, 49 | }, 50 | } 51 | for _, tt := range tests { 52 | t.Run(tt.name, func(t *testing.T) { 53 | if got := makeClaims(tt.args.cert); !reflect.DeepEqual(got, tt.want) { 54 | t.Errorf("makeClaims() = %v, want %v", got, tt.want) 55 | } 56 | }) 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /gateway/utils.go: -------------------------------------------------------------------------------- 1 | package gateway 2 | 3 | import ( 4 | "crypto/tls" 5 | "fmt" 6 | "net" 7 | "net/http" 8 | 9 | "github.com/go-zoo/bone" 10 | "go.aporeto.io/bahamut" 11 | ) 12 | 13 | func injectGeneralHeader(h http.Header) http.Header { 14 | 15 | h.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains; preload") 16 | h.Set("X-Frame-Options", "DENY") 17 | h.Set("X-Content-Type-Options", "nosniff") 18 | h.Set("X-Xss-Protection", "1; mode=block") 19 | h.Set("Cache-Control", "private, no-transform") 20 | 21 | return h 22 | } 23 | 24 | func injectCORSHeader(h http.Header, corsOrigin string, additionalCorsOrigin []string, allowCredentials bool, origin string, method string) http.Header { 25 | 26 | a := bahamut.NewDefaultCORSController(corsOrigin, additionalCorsOrigin) 27 | ac := 
a.PolicyForRequest(nil) 28 | ac.AllowCredentials = allowCredentials 29 | ac.Inject(h, origin, method == http.MethodOptions) 30 | return h 31 | } 32 | 33 | func makeProxyProtocolSourceChecker(allowed string) (func(net.Addr) (bool, error), error) { 34 | 35 | _, allowedSubnet, err := net.ParseCIDR(allowed) 36 | if err != nil { 37 | return nil, fmt.Errorf("unable to parse CIDR: %s", err) 38 | } 39 | 40 | return func(addr net.Addr) (bool, error) { 41 | 42 | ipstr, _, err := net.SplitHostPort(addr.String()) 43 | if err != nil { 44 | return false, fmt.Errorf("unable to parse net.Addr: %s", err) 45 | } 46 | 47 | return allowedSubnet.Contains(net.ParseIP(ipstr)), nil 48 | }, nil 49 | } 50 | 51 | func makeGoodbyeServer(listen string, serverTLSConfig *tls.Config) *http.Server { 52 | 53 | mux := bone.New() 54 | mux.NotFound( 55 | http.HandlerFunc( 56 | func(w http.ResponseWriter, req *http.Request) { 57 | w.WriteHeader(503) 58 | _, _ = w.Write([]byte(`[{"code":503,"title":"Service Not Available","description":"Shutting down. Please retry your request","subject":"gateway"}]`)) 59 | }, 60 | ), 61 | ) 62 | 63 | return &http.Server{ 64 | TLSConfig: serverTLSConfig, 65 | Addr: listen, 66 | Handler: mux, 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /websocket_push_session_mock_test.go: -------------------------------------------------------------------------------- 1 | package bahamut 2 | 3 | import ( 4 | "context" 5 | "crypto/tls" 6 | "net/http" 7 | "testing" 8 | 9 | // nolint:revive // Allow dot imports for readability in tests 10 | . 
"github.com/smartystreets/goconvey/convey" 11 | "go.aporeto.io/elemental" 12 | ) 13 | 14 | func TestMockSession(t *testing.T) { 15 | 16 | Convey("MockSession should work", t, func() { 17 | 18 | s := NewMockSession() 19 | So(s.MockClaimsMap, ShouldNotBeNil) 20 | So(s.MockCookies, ShouldNotBeNil) 21 | So(s.MockHeaders, ShouldNotBeNil) 22 | So(s.MockParameters, ShouldNotBeNil) 23 | 24 | s.MockClaimsMap = map[string]string{"k": "v"} 25 | s.MockClientIP = "1.1.1.1" 26 | s.MockCookies = map[string]*http.Cookie{"c": {}} 27 | s.MockHeaders = map[string]string{"k": "v"} 28 | s.MockIdentifier = "id" 29 | s.MockParameters = map[string]string{"k": "v"} 30 | s.MockPushConfig = &elemental.PushConfig{} 31 | s.MockTLSConnectionState = &tls.ConnectionState{} 32 | s.MockToken = "token" 33 | 34 | var calledDirectPush bool 35 | s.MockDirectPush = func(evts ...*elemental.Event) { calledDirectPush = true } 36 | 37 | s.SetClaims([]string{"k=v"}) 38 | s.SetMetadata("mischief") // A beer to the one who gets the reference. 
39 | 40 | So(s.Identifier(), ShouldEqual, "id") 41 | So(s.Parameter("k"), ShouldEqual, "v") 42 | So(s.Header("k"), ShouldEqual, "v") 43 | So(s.PushConfig(), ShouldNotBeNil) 44 | So(s.Claims(), ShouldResemble, []string{"k=v"}) 45 | So(s.ClaimsMap(), ShouldResemble, map[string]string{"k": "v"}) 46 | So(s.Token(), ShouldEqual, "token") 47 | So(s.TLSConnectionState(), ShouldNotBeNil) 48 | So(s.Metadata(), ShouldEqual, "mischief") 49 | So(s.Context(), ShouldResemble, context.Background()) 50 | So(s.ClientIP(), ShouldEqual, "1.1.1.1") 51 | 52 | cc, err := s.Cookie("c") 53 | So(cc, ShouldNotBeNil) 54 | So(err, ShouldBeNil) 55 | cc, err = s.Cookie("d") 56 | So(cc, ShouldBeNil) 57 | So(err, ShouldEqual, http.ErrNoCookie) 58 | 59 | s.DirectPush() 60 | So(calledDirectPush, ShouldBeTrue) 61 | }) 62 | } 63 | -------------------------------------------------------------------------------- /pinger.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Aporeto Inc. 2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License. 4 | // You may obtain a copy of the License at 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // Unless required by applicable law or agreed to in writing, software 7 | // distributed under the License is distributed on an "AS IS" BASIS, 8 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | // See the License for the specific language governing permissions and 10 | // limitations under the License. 
11 | 12 | package bahamut 13 | 14 | import ( 15 | "sync" 16 | "time" 17 | 18 | "go.uber.org/zap" 19 | ) 20 | 21 | const ( 22 | // PingStatusOK represents the status "ok" 23 | PingStatusOK = "ok" 24 | // PingStatusTimeout represents the status "timeout" 25 | PingStatusTimeout = "timeout" 26 | // PingStatusError represents the status "error" 27 | PingStatusError = "error" 28 | ) 29 | 30 | // A Pinger is an interface for objects that implements a Ping method 31 | type Pinger interface { 32 | Ping(timeout time.Duration) error 33 | } 34 | 35 | // RetrieveHealthStatus returns the status for each Pinger. 36 | func RetrieveHealthStatus(timeout time.Duration, pingers map[string]Pinger) error { 37 | 38 | var firstError error 39 | 40 | var wg sync.WaitGroup 41 | wg.Add(len(pingers)) 42 | m := &sync.Mutex{} 43 | for name, pinger := range pingers { 44 | go func(name string, pinger Pinger) { 45 | defer wg.Done() 46 | 47 | start := time.Now() 48 | err := pinger.Ping(timeout) 49 | status := stringifyStatus(err) 50 | duration := time.Since(start) 51 | 52 | zap.L().Info("Ping", 53 | zap.String("service", name), 54 | zap.String("status", status), 55 | zap.String("duration", duration.String()), 56 | zap.Error(err), 57 | ) 58 | 59 | m.Lock() 60 | if err != nil && firstError == nil { 61 | firstError = err 62 | } 63 | m.Unlock() 64 | }(name, pinger) 65 | } 66 | 67 | wg.Wait() 68 | 69 | return firstError 70 | } 71 | 72 | // stringify status output 73 | func stringifyStatus(err error) string { 74 | if err == nil { 75 | return PingStatusOK 76 | } 77 | 78 | errMsg := err.Error() 79 | if errMsg == PingStatusTimeout { 80 | return PingStatusTimeout 81 | } 82 | 83 | return PingStatusError 84 | } 85 | -------------------------------------------------------------------------------- /pinger_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Aporeto Inc. 
2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License. 4 | // You may obtain a copy of the License at 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // Unless required by applicable law or agreed to in writing, software 7 | // distributed under the License is distributed on an "AS IS" BASIS, 8 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | // See the License for the specific language governing permissions and 10 | // limitations under the License. 11 | 12 | package bahamut 13 | 14 | import ( 15 | "fmt" 16 | "testing" 17 | "time" 18 | 19 | // nolint:revive // Allow dot imports for readability in tests 20 | . "github.com/smartystreets/goconvey/convey" 21 | ) 22 | 23 | type MockPinger struct { 24 | PingStatus error 25 | } 26 | 27 | func (m MockPinger) Ping(timeout time.Duration) error { 28 | return m.PingStatus 29 | } 30 | 31 | func Test_RetrieveHealthStatus(t *testing.T) { 32 | 33 | Convey("Given the following pingers", t, func() { 34 | pingers := map[string]Pinger{ 35 | "p1": MockPinger{PingStatus: nil}, 36 | "p2": MockPinger{PingStatus: fmt.Errorf(PingStatusTimeout)}, //nolint:staticcheck 37 | "p3": MockPinger{PingStatus: fmt.Errorf("Another status")}, //nolint:staticcheck 38 | } 39 | results := RetrieveHealthStatus(time.Second, pingers) 40 | 41 | Convey("Then I should have the following status results", func() { 42 | So(results, ShouldNotBeNil) 43 | }) 44 | }) 45 | 46 | Convey("Given the following pingers", t, func() { 47 | pingers := map[string]Pinger{ 48 | "p1": MockPinger{PingStatus: nil}, 49 | "p2": MockPinger{PingStatus: nil}, 50 | "p3": MockPinger{PingStatus: nil}, 51 | } 52 | results := RetrieveHealthStatus(time.Second, pingers) 53 | 54 | Convey("Then I should have the following status results", func() { 55 | So(results, ShouldBeNil) 56 | }) 57 | }) 58 | } 59 | 60 | func Test_stringifyStatus(t *testing.T) { 61 | 62 | 
Convey("Given the stringifyStatus method", t, func() { 63 | So(stringifyStatus(nil), ShouldEqual, PingStatusOK) 64 | So(stringifyStatus(fmt.Errorf(PingStatusTimeout)), ShouldEqual, PingStatusTimeout) //nolint:staticcheck 65 | So(stringifyStatus(fmt.Errorf("Another status")), ShouldEqual, PingStatusError) //nolint:staticcheck 66 | }) 67 | } 68 | -------------------------------------------------------------------------------- /gateway/upstreamer/push/services.go: -------------------------------------------------------------------------------- 1 | package push 2 | 3 | import ( 4 | "sync" 5 | "time" 6 | 7 | "go.aporeto.io/bahamut" 8 | "golang.org/x/time/rate" 9 | ) 10 | // endpointInfo holds the state of a single endpoint of a service. 11 | type endpointInfo struct { 12 | lastSeen time.Time 13 | lastLimiterAdjust time.Time 14 | sync.RWMutex 15 | limiters IdentityToAPILimitersRegistry 16 | address string 17 | lastLoad float64 18 | } 19 | // servicesConfig maps a key (presumably the service name) to its service. 20 | type servicesConfig map[string]*service 21 | // service holds the routes, versions and endpoints of a single service. NOTE(review): endpoints is not guarded by a lock here; callers presumably serialize access, confirm. 22 | type service struct { 23 | routes map[int][]bahamut.RouteInfo 24 | versions map[string]any 25 | endpoints map[string]*endpointInfo 26 | name string 27 | } 28 | 29 | // newService returns a new service with the given name and no endpoints. 30 | func newService(name string) *service { 31 | return &service{ 32 | name: name, 33 | endpoints: map[string]*endpointInfo{}, 34 | } 35 | } 36 | // getEndpoints returns all the endpoints of the service, in no particular order. 37 | func (b *service) getEndpoints() []*endpointInfo { 38 | 39 | out := make([]*endpointInfo, len(b.endpoints)) 40 | var i int 41 | for _, v := range b.endpoints { 42 | out[i] = v 43 | i++ 44 | } 45 | 46 | return out 47 | } 48 | // hasEndpoint reports whether ep is a registered endpoint of the service. 49 | func (b *service) hasEndpoint(ep string) bool { 50 | 51 | _, ok := b.endpoints[ep] 52 | 53 | return ok 54 | } 55 | // registerEndpoint registers the endpoint at address, replacing any existing one, and instantiates its announced rate limiters. 56 | func (b *service) registerEndpoint(address string, load float64, apilimiters IdentityToAPILimitersRegistry) { 57 | 58 | if apilimiters == nil { 59 | apilimiters = IdentityToAPILimitersRegistry{} 60 | } 61 | 62 | // Instantiate all the actual rate limiters using the values 63 | // announced by the service.
64 | for _, l := range apilimiters { 65 | l.limiter = rate.NewLimiter(l.Limit, l.Burst) 66 | } 67 | // NOTE(review): l is presumably a pointer, so the loop above sets the limiters on the caller's registry entries in place; confirm this sharing is intended. 68 | b.endpoints[address] = &endpointInfo{ 69 | lastSeen: time.Now(), 70 | lastLoad: load, 71 | address: address, 72 | limiters: apilimiters, 73 | } 74 | } 75 | // pokeEndpoint refreshes the last seen time and load of ep, if it is registered. 76 | func (b *service) pokeEndpoint(ep string, load float64) { 77 | 78 | if epi, ok := b.endpoints[ep]; ok { 79 | epi.Lock() 80 | epi.lastSeen = time.Now() 81 | epi.lastLoad = load 82 | epi.Unlock() 83 | } 84 | } 85 | // outdatedEndpoints returns the addresses of the endpoints last seen before since. 86 | func (b *service) outdatedEndpoints(since time.Time) []string { 87 | 88 | var out []string 89 | 90 | for ep, epi := range b.endpoints { 91 | epi.RLock() 92 | if epi.lastSeen.Before(since) { 93 | out = append(out, ep) 94 | } 95 | epi.RUnlock() 96 | } 97 | 98 | return out 99 | } 100 | // unregisterEndpoint removes ep from the service endpoints. 101 | func (b *service) unregisterEndpoint(ep string) { 102 | 103 | delete(b.endpoints, ep) 104 | } 105 | -------------------------------------------------------------------------------- /gateway/upstreamer/push/notifier_options.go: -------------------------------------------------------------------------------- 1 | package push 2 | 3 | import ( 4 | "time" 5 | 6 | "go.aporeto.io/elemental" 7 | ) 8 | // notifierConfig holds the settings applied by the NotifierOptions. 9 | type notifierConfig struct { 10 | rateLimits IdentityToAPILimitersRegistry 11 | privateOverrides map[string]bool 12 | prefix string 13 | pingInterval time.Duration 14 | } 15 | // newNotifierConfig returns a notifierConfig with a 5s ping interval and empty rate limits and private overrides. 16 | func newNotifierConfig() notifierConfig { 17 | return notifierConfig{ 18 | rateLimits: IdentityToAPILimitersRegistry{}, 19 | pingInterval: 5 * time.Second, 20 | privateOverrides: map[string]bool{}, 21 | } 22 | } 23 | 24 | // A NotifierOption is the kind of option that can be passed 25 | // to the notifier. 26 | type NotifierOption func(*notifierConfig) 27 | 28 | // OptionNotifierPingInterval sets the interval between sending 29 | // 2 pings. The default is 5s.
30 | func OptionNotifierPingInterval(interval time.Duration) NotifierOption { 31 | return func(c *notifierConfig) { 32 | c.pingInterval = interval 33 | } 34 | } 35 | 36 | // OptionNotifierAnnounceRateLimits can be used to set an IdentityToAPILimitersRegistry 37 | // to tell the gateways to instantiate some rate limiters for the current 38 | // instance of the service. 39 | // 40 | // It is not guaranteed that the gateway will honor the request. The given map is shallow-copied. 41 | func OptionNotifierAnnounceRateLimits(rls IdentityToAPILimitersRegistry) NotifierOption { 42 | return func(c *notifierConfig) { 43 | c.rateLimits = make(IdentityToAPILimitersRegistry, len(rls)) 44 | for k, v := range rls { 45 | c.rateLimits[k] = v 46 | } 47 | } 48 | } 49 | 50 | // OptionNotifierPrefix sets the API prefix that the gateway should 51 | // add to the API routes for that service. 52 | func OptionNotifierPrefix(prefix string) NotifierOption { 53 | return func(c *notifierConfig) { 54 | c.prefix = prefix 55 | } 56 | } 57 | 58 | // OptionNotifierPrivateAPIOverrides allows passing a map of identity to boolean 59 | // that will be used to override the specification's "private" flag. This allows 60 | // the service to force a public API to be private (or vice versa). Overrides are keyed by identity Category. 61 | // 62 | // NOTE: this does not change the internal data in bahamut's server RouteInfo. 63 | // As far as bahamut server is concerned, the route Private flag did not 64 | // change. This only affects the gateway. 65 | func OptionNotifierPrivateAPIOverrides(overrides map[elemental.Identity]bool) NotifierOption { 66 | return func(c *notifierConfig) { 67 | for k, v := range overrides { 68 | c.privateOverrides[k.Category] = v 69 | } 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /profiling_server.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Aporeto Inc.
2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License. 4 | // You may obtain a copy of the License at 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // Unless required by applicable law or agreed to in writing, software 7 | // distributed under the License is distributed on an "AS IS" BASIS, 8 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | // See the License for the specific language governing permissions and 10 | // limitations under the License. 11 | 12 | package bahamut 13 | 14 | import ( 15 | "context" 16 | "net/http" 17 | "net/http/pprof" 18 | "time" 19 | 20 | "go.uber.org/zap" 21 | ) 22 | 23 | // profilingServer serves the net/http/pprof profiling endpoints. 24 | type profilingServer struct { 25 | server *http.Server 26 | cfg config 27 | } 28 | 29 | // newProfilingServer returns a new profilingServer. 30 | func newProfilingServer(cfg config) *profilingServer { 31 | 32 | return &profilingServer{ 33 | cfg: cfg, 34 | } 35 | } 36 | 37 | // start serves the pprof endpoints on the configured listen address in the background, then blocks until ctx is done. 38 | func (s *profilingServer) start(ctx context.Context) { 39 | 40 | mux := http.NewServeMux() 41 | mux.HandleFunc("/debug/pprof/", pprof.Index) 42 | mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) 43 | mux.HandleFunc("/debug/pprof/profile", pprof.Profile) 44 | mux.HandleFunc("/debug/pprof/trace", pprof.Trace) 45 | 46 | s.server = &http.Server{ 47 | Addr: s.cfg.profilingServer.listenAddress, 48 | Handler: mux, 49 | } 50 | // Listen errors other than http.ErrServerClosed are reported with zap.L().Fatal. 51 | go func() { 52 | if err := s.server.ListenAndServe(); err != nil { 53 | if err == http.ErrServerClosed { 54 | return 55 | } 56 | zap.L().Fatal("Unable to start profiling server", zap.Error(err)) 57 | } 58 | }() 59 | 60 | zap.L().Info("Profiler server started", zap.String("address", s.cfg.profilingServer.listenAddress)) 61 | // NOTE(review): returning on ctx.Done() does not shut the HTTP server down; stop must be called for that. 62 | <-ctx.Done() 63 | } 64 | 65 | // stop shuts the profilingServer down gracefully in the background, allowing it up to 1s.
66 | func (s *profilingServer) stop() { 67 | 68 | if s.server == nil { 69 | return 70 | } 71 | 72 | ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) 73 | // Shutdown runs in a goroutine, so stop returns without waiting for it. 74 | go func() { 75 | defer cancel() 76 | if err := s.server.Shutdown(ctx); err != nil { 77 | zap.L().Error("Could not gracefully stop profiling server", zap.Error(err)) 78 | } else { 79 | zap.L().Debug("Profiling server stopped") 80 | } 81 | }() 82 | // NOTE(review): this is logged before the Shutdown above has completed. 83 | zap.L().Debug("Profile server stopped") 84 | } 85 | -------------------------------------------------------------------------------- /gateway/interfaces.go: -------------------------------------------------------------------------------- 1 | package gateway 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | "time" 7 | 8 | "golang.org/x/time/rate" 9 | ) 10 | 11 | // ErrUpstreamerTooManyRequests can be returned to 12 | // instruct the bahamut.Gateway to stop routing and 13 | // return a 429 Too Many Requests error to 14 | // the client. 15 | var ErrUpstreamerTooManyRequests = errors.New("Please retry in a moment") 16 | 17 | // An Upstreamer is the interface that can compute upstreams. 18 | type Upstreamer interface { 19 | 20 | // Upstream is called by the bahamut.Gateway for each incoming request 21 | // in order to find which upstream to forward the request to, based 22 | // on the incoming http.Request and any other details the implementation 23 | // wishes to. Needless to say, it must be fast or it would severely degrade 24 | // the performance of the bahamut.Gateway. 25 | // 26 | // The request state must not be changed from this function. 27 | // 28 | // The returned upstream is a string in the form "https://10.3.19.4". 29 | // If it is empty, the bahamut.Gateway will return a 30 | // 503 Service Unavailable error. 31 | // 32 | // If Upstream returns an error, the bahamut.Gateway will check for a 33 | // known ErrUpstreamerX and will act accordingly. Otherwise it will 34 | // return the error as a 500 Internal Server Error.
35 | Upstream(req *http.Request) (upstream string, err error) 36 | } 37 | 38 | // A SourceExtractor is used to extract a token (or key) used 39 | // to keep track of a single source. 40 | type SourceExtractor interface { 41 | 42 | // ExtractSource will be called to extract the token identifying 43 | // the source of the given request. 44 | ExtractSource(req *http.Request) (token string, err error) 45 | } 46 | 47 | // A RateExtractor is used to decide rates per token. 48 | // This allows performing advanced computation to determine how 49 | // to rate limit one unique client. 50 | type RateExtractor interface { 51 | 52 | // ExtractRates will be called to decide the rate limit and burst 53 | // to apply to the given request. 54 | ExtractRates(r *http.Request) (rate.Limit, int, error) 55 | } 56 | 57 | // A LatencyBasedUpstreamer is the interface that can circle back 58 | // response time as an input for Upstreamer decision. 59 | type LatencyBasedUpstreamer interface { 60 | CollectLatency(address string, responseTime time.Duration) 61 | Upstreamer 62 | } 63 | 64 | // A Gateway can be used as an API gateway. 65 | type Gateway interface { 66 | Start() 67 | Stop() 68 | } 69 | 70 | // A LimiterMetricManager is used to compute 71 | // metrics for the various limiters that support it. 72 | type LimiterMetricManager interface { 73 | RegisterLimitedConnection() 74 | RegisterAcceptedConnection() 75 | } 76 | -------------------------------------------------------------------------------- /authorizer/simple/publishhandler_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Aporeto Inc. 2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License.
4 | // You may obtain a copy of the License at 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // Unless required by applicable law or agreed to in writing, software 7 | // distributed under the License is distributed on an "AS IS" BASIS, 8 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | // See the License for the specific language governing permissions and 10 | // limitations under the License. 11 | 12 | package simple 13 | 14 | import ( 15 | "fmt" 16 | "testing" 17 | 18 | // nolint:revive // Allow dot imports for readability in tests 19 | . "github.com/smartystreets/goconvey/convey" 20 | "go.aporeto.io/elemental" 21 | ) 22 | 23 | func TestPublishHandler_NewPublishHandler(t *testing.T) { 24 | 25 | Convey("Given I call NewPublishHandler with one funcs", t, func() { 26 | 27 | f1 := func(*elemental.Event) (bool, error) { return true, nil } 28 | 29 | pub := NewPublishHandler(f1) 30 | 31 | Convey("Then it should be correctly initialized", func() { 32 | So(pub.shouldPublishFunc, ShouldEqual, f1) 33 | 34 | }) 35 | }) 36 | } 37 | 38 | func TestPublishHandler_ShouldPublish(t *testing.T) { 39 | 40 | Convey("Given I call NewPublishHandler and a func that says ok", t, func() { 41 | 42 | f1 := func(*elemental.Event) (bool, error) { return true, nil } 43 | 44 | pub := NewPublishHandler(f1) 45 | 46 | Convey("When I call ShouldPublish", func() { 47 | 48 | action, err := pub.ShouldPublish(nil) 49 | 50 | Convey("Then err should be nil", func() { 51 | So(err, ShouldBeNil) 52 | }) 53 | 54 | Convey("Then action should be OK", func() { 55 | So(action, ShouldEqual, true) 56 | }) 57 | }) 58 | }) 59 | 60 | Convey("Given I call NewPublishHandler and no func", t, func() { 61 | 62 | pub := NewPublishHandler(nil) 63 | 64 | Convey("When I call ShouldPublish", func() { 65 | 66 | action, err := pub.ShouldPublish(nil) 67 | 68 | Convey("Then err should be nil", func() { 69 | So(err, ShouldBeNil) 70 | }) 71 | 72 | Convey("Then action should be Continue", 
func() { 73 | So(action, ShouldEqual, true) 74 | }) 75 | }) 76 | }) 77 | 78 | Convey("Given I call NewPublishHandler and a func that returns an error", t, func() { 79 | 80 | f1 := func(*elemental.Event) (bool, error) { return false, fmt.Errorf("paf") } 81 | 82 | pub := NewPublishHandler(f1) 83 | 84 | Convey("When I call ShouldPublish", func() { 85 | 86 | action, err := pub.ShouldPublish(nil) 87 | 88 | Convey("Then err should not be nil", func() { 89 | So(err.Error(), ShouldEqual, "paf") 90 | }) 91 | 92 | Convey("Then action should be KO", func() { 93 | So(action, ShouldEqual, false) 94 | }) 95 | }) 96 | }) 97 | } 98 | -------------------------------------------------------------------------------- /processor_helpers.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Aporeto Inc. 2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License. 4 | // You may obtain a copy of the License at 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // Unless required by applicable law or agreed to in writing, software 7 | // distributed under the License is distributed on an "AS IS" BASIS, 8 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | // See the License for the specific language governing permissions and 10 | // limitations under the License. 11 | 12 | package bahamut 13 | 14 | import ( 15 | "net/http" 16 | 17 | "go.aporeto.io/elemental" 18 | ) 19 | 20 | // CheckAuthentication runs the given authenticators in order until one of them returns AuthActionOK or AuthActionKO. 21 | // 22 | // It returns nil on AuthActionOK, an Unauthorized (401) elemental error on AuthActionKO, and any error returned by an authenticator. 23 | // If no authenticator is set, or if none of them decides, it returns nil. 24 | // 25 | // This is mostly used by autogenerated code, and you should not need to use it manually.
26 | func CheckAuthentication(authenticators []RequestAuthenticator, ctx Context) (err error) { 27 | 28 | if len(authenticators) == 0 { 29 | return nil 30 | } 31 | 32 | var action AuthAction 33 | for _, authenticator := range authenticators { 34 | 35 | action, err = authenticator.AuthenticateRequest(ctx) 36 | if err != nil { 37 | return err 38 | } 39 | 40 | switch action { 41 | case AuthActionOK: 42 | return nil 43 | case AuthActionKO: 44 | return elemental.NewError("Unauthorized", "You are not authorized to access this resource.", "bahamut", http.StatusUnauthorized) 45 | case AuthActionContinue: 46 | continue 47 | } 48 | } 49 | // NOTE(review): reaching this point means no authenticator returned AuthActionOK or AuthActionKO, and the request is accepted; confirm this fail-open default is intended. 50 | return nil 51 | } 52 | 53 | // CheckAuthorization runs the given authorizers in order until one of them returns AuthActionOK or AuthActionKO. 54 | // 55 | // It returns nil on AuthActionOK, a Forbidden (403) elemental error on AuthActionKO, and any error returned by an authorizer. 56 | // If no authorizer is set, or if none of them decides, it returns nil. 57 | // 58 | // This is mostly used by autogenerated code, and you should not need to use it manually. 59 | func CheckAuthorization(authorizers []Authorizer, ctx Context) (err error) { 60 | 61 | if len(authorizers) == 0 { 62 | return nil 63 | } 64 | 65 | var action AuthAction 66 | for _, authorizer := range authorizers { 67 | 68 | action, err = authorizer.IsAuthorized(ctx) 69 | if err != nil { 70 | return err 71 | } 72 | 73 | switch action { 74 | case AuthActionOK: 75 | return nil 76 | case AuthActionKO: 77 | return elemental.NewError("Forbidden", "You are not allowed to access this resource.", "bahamut", http.StatusForbidden) 78 | case AuthActionContinue: 79 | continue 80 | } 81 | } 82 | // NOTE(review): reaching this point means no authorizer returned AuthActionOK or AuthActionKO, and the request is allowed; confirm this fail-open default is intended. 83 | return nil 84 | } 85 | -------------------------------------------------------------------------------- /authorizer/simple/authenticator.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Aporeto Inc.
2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License. 4 | // You may obtain a copy of the License at 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // Unless required by applicable law or agreed to in writing, software 7 | // distributed under the License is distributed on an "AS IS" BASIS, 8 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | // See the License for the specific language governing permissions and 10 | // limitations under the License. 11 | 12 | package simple 13 | 14 | import ( 15 | "go.aporeto.io/bahamut" 16 | ) 17 | 18 | // CustomAuthRequestFunc is the type of functions that can be used to 19 | // decide custom authentication operations for requests. It returns a bahamut.AuthAction. 20 | type CustomAuthRequestFunc func(bahamut.Context) (bahamut.AuthAction, error) 21 | 22 | // CustomAuthSessionFunc is the type of functions that can be used to 23 | // decide custom authentication operations for sessions. It returns a bahamut.AuthAction. 24 | type CustomAuthSessionFunc func(bahamut.Session) (bahamut.AuthAction, error) 25 | 26 | // An Authenticator is a bahamut.Authenticator compliant structure to authenticate 27 | // requests and sessions using the given functions. 28 | type Authenticator struct { 29 | customAuthRequestFunc CustomAuthRequestFunc 30 | customAuthSessionFunc CustomAuthSessionFunc 31 | } 32 | 33 | // NewAuthenticator returns a new *Authenticator. Either function may be nil. 34 | func NewAuthenticator(customAuthRequestFunc CustomAuthRequestFunc, customAuthSessionFunc CustomAuthSessionFunc) *Authenticator { 35 | 36 | return &Authenticator{ 37 | customAuthSessionFunc: customAuthSessionFunc, 38 | customAuthRequestFunc: customAuthRequestFunc, 39 | } 40 | } 41 | 42 | // AuthenticateSession authenticates the given session using the custom session function. 43 | // It returns AuthActionContinue if no function is set, AuthActionKO with the error if the function fails, 44 | // and the action decided by the function otherwise.
45 | func (a *Authenticator) AuthenticateSession(session bahamut.Session) (bahamut.AuthAction, error) { 46 | 47 | if a.customAuthSessionFunc == nil { 48 | return bahamut.AuthActionContinue, nil 49 | } 50 | // An error always yields AuthActionKO, whatever action the function returned. 51 | action, err := a.customAuthSessionFunc(session) 52 | if err != nil { 53 | return bahamut.AuthActionKO, err 54 | } 55 | 56 | return action, nil 57 | } 58 | 59 | // AuthenticateRequest authenticates the request from the given bahamut.Context using the custom request function. 60 | // It returns AuthActionContinue if no function is set, AuthActionKO with the error if the function fails, 61 | // and the action decided by the function otherwise. 62 | func (a *Authenticator) AuthenticateRequest(ctx bahamut.Context) (bahamut.AuthAction, error) { 63 | 64 | if a.customAuthRequestFunc == nil { 65 | return bahamut.AuthActionContinue, nil 66 | } 67 | 68 | action, err := a.customAuthRequestFunc(ctx) 69 | if err != nil { 70 | return bahamut.AuthActionKO, err 71 | } 72 | 73 | return action, nil 74 | } 75 | -------------------------------------------------------------------------------- /authorizer/simple/authorizer_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Aporeto Inc. 2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License. 4 | // You may obtain a copy of the License at 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // Unless required by applicable law or agreed to in writing, software 7 | // distributed under the License is distributed on an "AS IS" BASIS, 8 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | // See the License for the specific language governing permissions and 10 | // limitations under the License. 11 | 12 | package simple 13 | 14 | import ( 15 | "fmt" 16 | "testing" 17 | 18 | // nolint:revive // Allow dot imports for readability in tests 19 | .
"github.com/smartystreets/goconvey/convey" 20 | "go.aporeto.io/bahamut" 21 | ) 22 | 23 | func TestAuthorizer_NewAuthorizer(t *testing.T) { 24 | 25 | Convey("Given I call NewAuthorizer with one funcs", t, func() { 26 | 27 | f1 := func(bahamut.Context) (bahamut.AuthAction, error) { return bahamut.AuthActionOK, nil } 28 | 29 | auth := NewAuthorizer(f1) 30 | 31 | Convey("Then it should be correctly initialized", func() { 32 | So(auth.customAuthFunc, ShouldEqual, f1) 33 | 34 | }) 35 | }) 36 | } 37 | 38 | func TestAuthorizer_IsAuthorized(t *testing.T) { 39 | 40 | Convey("Given I call NewAuthorizer and a func that says ok", t, func() { 41 | 42 | f1 := func(bahamut.Context) (bahamut.AuthAction, error) { return bahamut.AuthActionOK, nil } 43 | 44 | auth := NewAuthorizer(f1) 45 | 46 | Convey("When I call IsAuthorized", func() { 47 | 48 | action, err := auth.IsAuthorized(nil) 49 | 50 | Convey("Then err should be nil", func() { 51 | So(err, ShouldBeNil) 52 | }) 53 | 54 | Convey("Then action should be OK", func() { 55 | So(action, ShouldEqual, bahamut.AuthActionOK) 56 | }) 57 | }) 58 | }) 59 | 60 | Convey("Given I call NewAuthorizer and no func", t, func() { 61 | 62 | auth := NewAuthorizer(nil) 63 | 64 | Convey("When I call IsAuthorized", func() { 65 | 66 | action, err := auth.IsAuthorized(nil) 67 | 68 | Convey("Then err should be nil", func() { 69 | So(err, ShouldBeNil) 70 | }) 71 | 72 | Convey("Then action should be Continue", func() { 73 | So(action, ShouldEqual, bahamut.AuthActionContinue) 74 | }) 75 | }) 76 | }) 77 | 78 | Convey("Given I call NewAuthorizer and a func that returns an error", t, func() { 79 | 80 | f1 := func(bahamut.Context) (bahamut.AuthAction, error) { return bahamut.AuthActionOK, fmt.Errorf("paf") } 81 | 82 | auth := NewAuthorizer(f1) 83 | 84 | Convey("When I call IsAuthorized", func() { 85 | 86 | action, err := auth.IsAuthorized(nil) 87 | 88 | Convey("Then err should not be nil", func() { 89 | So(err.Error(), ShouldEqual, "paf") 90 | }) 91 | 92 | 
Convey("Then action should be KO", func() { 93 | So(action, ShouldEqual, bahamut.AuthActionKO) 94 | }) 95 | }) 96 | }) 97 | } 98 | -------------------------------------------------------------------------------- /gateway/upstreamer/push/upstreamer_options_test.go: -------------------------------------------------------------------------------- 1 | package push 2 | 3 | import ( 4 | "math/rand" 5 | "testing" 6 | "time" 7 | 8 | // nolint:revive // Allow dot imports for readability in tests 9 | . "github.com/smartystreets/goconvey/convey" 10 | ) 11 | 12 | func Test_Options(t *testing.T) { 13 | 14 | c := newUpstreamConfig() 15 | 16 | Convey("Calling OptionExposePrivateAPIs should work", t, func() { 17 | OptionUpstreamerExposePrivateAPIs(true)(&c) 18 | So(c.exposePrivateAPIs, ShouldEqual, true) 19 | }) 20 | 21 | Convey("Calling OptionOverrideEndpointsAddresses should work", t, func() { 22 | OptionUpstreamerOverrideEndpointsAddresses("127.0.0.1:443")(&c) 23 | So(c.overrideEndpointAddress, ShouldEqual, "127.0.0.1:443") 24 | }) 25 | 26 | Convey("Calling OptionRegisterEventAPI should work", t, func() { 27 | OptionUpstreamerRegisterEventAPI("srva", "events")(&c) 28 | OptionUpstreamerRegisterEventAPI("srvb", "hello")(&c) 29 | So(len(c.eventsAPIs), ShouldEqual, 2) 30 | So(c.eventsAPIs["srva"], ShouldEqual, "events") 31 | So(c.eventsAPIs["srvb"], ShouldEqual, "hello") 32 | }) 33 | 34 | Convey("Calling OptionRequiredServices should work", t, func() { 35 | OptionRequiredServices([]string{"srv1"})(&c) 36 | So(c.requiredServices, ShouldResemble, []string{"srv1"}) 37 | }) 38 | 39 | Convey("Calling OptionServiceTimeout should work", t, func() { 40 | OptionUpstreamerServiceTimeout(time.Hour, time.Minute)(&c) 41 | So(c.serviceTimeout, ShouldEqual, time.Hour) 42 | So(c.serviceTimeoutCheckInterval, ShouldEqual, time.Minute) 43 | }) 44 | 45 | Convey("Calling OptionRandomizer should work", t, func() { 46 | rn := rand.New(rand.NewSource(time.Now().UnixNano())) 47 | 
OptionUpstreamerRandomizer(rn)(&c) 48 | So(c.randomizer, ShouldResemble, rn) 49 | }) 50 | 51 | Convey("Calling OptionUpstreamerPeersTimeout should work", t, func() { 52 | OptionUpstreamerPeersTimeout(time.Hour)(&c) 53 | So(c.peerTimeout, ShouldResemble, time.Hour) 54 | }) 55 | 56 | Convey("Calling OptionUpstreamerPeersCheckInterval should work", t, func() { 57 | OptionUpstreamerPeersCheckInterval(time.Hour)(&c) 58 | So(c.peerTimeoutCheckInterval, ShouldResemble, time.Hour) 59 | }) 60 | 61 | Convey("Calling OptionUpstreamerPeersPingInterval should work", t, func() { 62 | OptionUpstreamerPeersPingInterval(time.Hour)(&c) 63 | So(c.peerPingInterval, ShouldResemble, time.Hour) 64 | }) 65 | 66 | Convey("Calling OptionUpstreamerTokenRateLimiting should work", t, func() { 67 | OptionUpstreamerTokenRateLimiting(1, 2)(&c) 68 | So(c.tokenLimitingRPS, ShouldEqual, 1) 69 | So(c.tokenLimitingBurst, ShouldEqual, 2) 70 | 71 | So(func() { OptionUpstreamerTokenRateLimiting(0, 2)(&c) }, ShouldPanicWith, `rps cannot be <= 0`) 72 | So(func() { OptionUpstreamerTokenRateLimiting(1, 0)(&c) }, ShouldPanicWith, `burst cannot be <= 0`) 73 | }) 74 | 75 | Convey("Calling OptionUpstreamerGlobalServiceTopic should work", t, func() { 76 | OptionUpstreamerGlobalServiceTopic("global")(&c) 77 | So(c.globalServiceTopic, ShouldEqual, "global") 78 | }) 79 | } 80 | -------------------------------------------------------------------------------- /gateway/limiter.go: -------------------------------------------------------------------------------- 1 | package gateway 2 | 3 | import ( 4 | "errors" 5 | "net/http" 6 | "time" 7 | 8 | "github.com/karlseguin/ccache/v2" 9 | "golang.org/x/time/rate" 10 | ) 11 | 12 | const maxCacheSize = 65536 13 | 14 | var errTooManyRequest = errors.New("Please retry in a moment") 15 | 16 | type sourceLimiter struct { 17 | nextHTTP http.Handler 18 | nextWS http.Handler 19 | sourceExtractor SourceExtractor 20 | rateExtractor RateExtractor 21 | metricManager LimiterMetricManager 22 
rls *ccache.Cache 23 | errorHandler *errorHandler 24 | defaultLimit rate.Limit 25 | defaultBurst int 26 | } 27 | // newSourceLimiter returns a new sourceLimiter. It panics if errorHandler or sourceExtractor is nil. 28 | func newSourceLimiter( 29 | nextHTTP http.Handler, 30 | nextWS http.Handler, 31 | defaultLimit rate.Limit, 32 | defaultBurst int, 33 | sourceExtractor SourceExtractor, 34 | rateExtractor RateExtractor, 35 | errorHandler *errorHandler, 36 | metricManager LimiterMetricManager, 37 | ) *sourceLimiter { 38 | 39 | if errorHandler == nil { 40 | panic("errorHandler must not be nil") 41 | } 42 | 43 | if sourceExtractor == nil { 44 | panic("sourceExtractor must not be nil") 45 | } 46 | 47 | return &sourceLimiter{ 48 | nextHTTP: nextHTTP, 49 | nextWS: nextWS, 50 | defaultLimit: defaultLimit, 51 | defaultBurst: defaultBurst, 52 | sourceExtractor: sourceExtractor, 53 | rateExtractor: rateExtractor, 54 | errorHandler: errorHandler, 55 | rls: ccache.New(ccache.Configure().MaxSize(maxCacheSize)), 56 | metricManager: metricManager, 57 | } 58 | } 59 | // ServeHTTP applies the per-source rate limit to req and forwards it to nextWS or nextHTTP when allowed. 60 | func (l *sourceLimiter) ServeHTTP(w http.ResponseWriter, req *http.Request) { 61 | // NOTE(review): source and rate extraction failures are also answered with errTooManyRequest. 62 | key, err := l.sourceExtractor.ExtractSource(req) 63 | if err != nil { 64 | l.errorHandler.ServeHTTP(w, req, errTooManyRequest) 65 | return 66 | } 67 | 68 | var rl *rate.Limiter 69 | 70 | var limit rate.Limit 71 | var burst int 72 | 73 | if l.rateExtractor != nil { 74 | limit, burst, err = l.rateExtractor.ExtractRates(req) 75 | if err != nil { 76 | l.errorHandler.ServeHTTP(w, req, errTooManyRequest) 77 | return 78 | } 79 | } else { 80 | limit = l.defaultLimit 81 | burst = l.defaultBurst 82 | } 83 | // Limiters are cached per source for one hour. NOTE(review): Get and Set are not atomic, so concurrent first requests for one key may each create a limiter. 84 | if item := l.rls.Get(key); item == nil || item.Value() == nil || item.Expired() { 85 | rl = rate.NewLimiter(limit, burst) 86 | l.rls.Set(key, rl, time.Hour) 87 | } else { 88 | rl = item.Value().(*rate.Limiter) 89 | } 90 | // Apply the current rates to the cached limiter. 91 | if rl.Limit() != limit { 92 | rl.SetLimit(limit) 93 | } 94 | if rl.Burst() != burst { 95 | rl.SetBurst(burst) 96 | } 97 | 98 | if !rl.Allow() { 99 | l.errorHandler.ServeHTTP(w, req, errTooManyRequest) 100
| if l.metricManager != nil { 101 | l.metricManager.RegisterLimitedConnection() 102 | } 103 | return 104 | } 105 | 106 | if l.metricManager != nil { 107 | l.metricManager.RegisterAcceptedConnection() 108 | } 109 | // Requests marked with internalWSMarkingHeader go to the websocket handler. 110 | if req.Header.Get(internalWSMarkingHeader) != "" { 111 | l.nextWS.ServeHTTP(w, req) 112 | } else { 113 | l.nextHTTP.ServeHTTP(w, req) 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /gateway/rewriters.go: -------------------------------------------------------------------------------- 1 | package gateway 2 | 3 | import ( 4 | "encoding/pem" 5 | "fmt" 6 | "net" 7 | "net/http" 8 | "net/http/httputil" 9 | "strings" 10 | 11 | "go.aporeto.io/tg/tglib" 12 | "go.uber.org/zap" 13 | ) 14 | // internalWSMarkingHeader marks requests to be served by the websocket handler; it is removed before forwarding. 15 | const internalWSMarkingHeader = "__internal_ws__" 16 | // requestRewriter rewrites the requests proxied to the upstreams. 17 | type requestRewriter struct { 18 | customRewriter RequestRewriter 19 | blockOpenTracing bool 20 | private bool 21 | trustForwardHeader bool 22 | } 23 | // Rewrite prepares the outbound request: custom rewriting, optional tracing header removal, X-Forwarded-* headers and client certificate forwarding. 24 | func (s *requestRewriter) Rewrite(r *httputil.ProxyRequest) { 25 | 26 | if s.customRewriter != nil { 27 | if err := s.customRewriter(r, s.private); err != nil { 28 | zap.L().Error("Unable rewrite request with custom rewriter", zap.Error(err)) 29 | panic(fmt.Sprintf("unable to rewrite request with custom rewriter: %s", err)) // panics are recovered by oxy 30 | } 31 | } 32 | 33 | if s.blockOpenTracing { 34 | r.Out.Header.Del("X-B3-TraceID") 35 | r.Out.Header.Del("X-B3-SpanID") 36 | r.Out.Header.Del("X-B3-ParentSpanID") 37 | r.Out.Header.Del("X-B3-Sampled") 38 | r.Out.Header.Del("Uber-Trace-ID") 39 | r.Out.Header.Del("Jaeger-Baggage") 40 | r.Out.Header.Del("TraceParent") 41 | r.Out.Header.Del("TraceState") 42 | } 43 | 44 | // If we trust the forward headers, we backport the ones from 45 | // the inbound request to the outbound request. 46 | // Otherwise, per documentation, they have already been removed 47 | // from the outbound request.
48 | if s.trustForwardHeader { 49 | r.Out.Header["X-Forwarded-For"] = r.In.Header["X-Forwarded-For"] 50 | r.Out.Header["X-Forwarded-Proto"] = r.In.Header["X-Forwarded-Proto"] 51 | r.Out.Header["X-Forwarded-Host"] = r.In.Header["X-Forwarded-Host"] 52 | } 53 | // Now, if we reach here, and still have no X-Forwarded-For, we set them 54 | // using the inbound request. X-Forwarded-Proto gets the scheme, not r.In.Proto (the HTTP version, e.g. "HTTP/1.1"). 55 | if r.Out.Header.Get("X-Forwarded-For") == "" { 56 | if clientIP, _, err := net.SplitHostPort(r.In.RemoteAddr); err == nil { 57 | proto := "http" 58 | if r.In.TLS != nil { 59 | proto = "https" 60 | } 61 | r.Out.Header.Set("X-Forwarded-For", clientIP) 62 | r.Out.Header.Set("X-Forwarded-Host", r.In.Host) 63 | r.Out.Header.Set("X-Forwarded-Proto", proto) 64 | } 65 | } 66 | // Here we delete the internalWSMarkingHeader if it has been set. 67 | if r.In.Header.Get(internalWSMarkingHeader) != "" { 68 | r.Out.Header.Del(internalWSMarkingHeader) 69 | } 70 | // NOTE(review): the certificate is only forwarded when exactly one peer certificate was presented, and a client-supplied X-TLS-Client-Certificate header is not removed from r.Out here; confirm it is stripped elsewhere. 71 | if r.In.TLS != nil && len(r.In.TLS.PeerCertificates) == 1 { 72 | block, err := tglib.CertToPEM(r.In.TLS.PeerCertificates[0]) 73 | if err != nil { 74 | zap.L().Error("Unable to handle client TLS certificate", zap.Error(err)) 75 | panic(fmt.Sprintf("unable to handle client TLS certificate: %s", err)) // panics are recovered by oxy 76 | } 77 | 78 | r.Out.Header.Add("X-TLS-Client-Certificate", strings.ReplaceAll(string(pem.EncodeToMemory(block)), "\n", " ")) 79 | } 80 | } 81 | // circuitBreakerHandler answers every request with a 503 Service Unavailable error. 82 | type circuitBreakerHandler struct{} 83 | // ServeHTTP writes the 503 Service Unavailable error response. 84 | func (h *circuitBreakerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 85 | writeError(w, r, makeError(http.StatusServiceUnavailable, "Service Unavailable", "The service is busy handling requests.
Please retry in a moment")) 86 | } 87 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module go.aporeto.io/bahamut 2 | 3 | go 1.22 4 | 5 | toolchain go1.22.2 6 | 7 | require ( 8 | go.aporeto.io/elemental v1.123.1-0.20240822212917-6f8c7be6698c 9 | go.aporeto.io/tg v1.50.0 10 | go.aporeto.io/wsc v1.51.0 11 | ) 12 | 13 | require ( 14 | github.com/NYTimes/gziphandler v1.1.1 15 | github.com/armon/go-proxyproto v0.0.0-20210323213023-7e956b284f0a 16 | github.com/cespare/xxhash v1.1.0 17 | github.com/go-zoo/bone v1.3.0 18 | github.com/gofrs/uuid v4.4.0+incompatible 19 | github.com/golang/mock v1.6.0 20 | github.com/gorilla/websocket v1.5.0 21 | github.com/karlseguin/ccache/v2 v2.0.8 22 | github.com/mailgun/multibuf v0.1.2 23 | github.com/nats-io/nats-server/v2 v2.9.11 24 | github.com/nats-io/nats.go v1.23.0 25 | github.com/opentracing/opentracing-go v1.2.0 26 | github.com/prometheus/client_golang v1.14.0 27 | github.com/shirou/gopsutil/v3 v3.23.1 28 | github.com/smartystreets/goconvey v1.7.2 29 | github.com/valyala/tcplisten v1.0.0 30 | github.com/vulcand/oxy/v2 v2.0.0-20221121151423-d5cb734e4467 31 | go.uber.org/zap v1.24.0 32 | golang.org/x/time v0.3.0 33 | ) 34 | 35 | require ( 36 | github.com/HdrHistogram/hdrhistogram-go v1.1.2 // indirect 37 | github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de // indirect 38 | github.com/beorn7/perks v1.0.1 // indirect 39 | github.com/cespare/xxhash/v2 v2.2.0 // indirect 40 | github.com/go-ole/go-ole v1.2.6 // indirect 41 | github.com/golang/protobuf v1.5.2 // indirect 42 | github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 // indirect 43 | github.com/gravitational/trace v1.2.1 // indirect 44 | github.com/jonboulle/clockwork v0.3.0 // indirect 45 | github.com/jtolds/gls v4.20.0+incompatible // indirect 46 | github.com/klauspost/compress v1.15.11 // indirect 47 | 
github.com/lufia/plan9stats v0.0.0-20230110061619-bbe2e5e100de // indirect 48 | github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect 49 | github.com/minio/highwayhash v1.0.2 // indirect 50 | github.com/mitchellh/copystructure v1.2.0 // indirect 51 | github.com/mitchellh/reflectwalk v1.0.2 // indirect 52 | github.com/nats-io/jwt/v2 v2.3.0 // indirect 53 | github.com/nats-io/nkeys v0.3.0 // indirect 54 | github.com/nats-io/nuid v1.0.1 // indirect 55 | github.com/power-devops/perfstat v0.0.0-20221212215047-62379fc7944b // indirect 56 | github.com/prometheus/client_model v0.3.0 // indirect 57 | github.com/prometheus/common v0.39.0 // indirect 58 | github.com/prometheus/procfs v0.9.0 // indirect 59 | github.com/sirupsen/logrus v1.9.3 // indirect 60 | github.com/smartystreets/assertions v1.2.0 // indirect 61 | github.com/tklauser/go-sysconf v0.3.11 // indirect 62 | github.com/tklauser/numcpus v0.6.0 // indirect 63 | github.com/ugorji/go/codec v1.2.9 // indirect 64 | github.com/vulcand/predicate v1.2.0 // indirect 65 | github.com/yusufpapurcu/wmi v1.2.2 // indirect 66 | go.mongodb.org/mongo-driver v1.16.0 // indirect 67 | go.uber.org/atomic v1.10.0 // indirect 68 | go.uber.org/multierr v1.9.0 // indirect 69 | golang.org/x/crypto v0.22.0 // indirect 70 | golang.org/x/net v0.21.0 // indirect 71 | golang.org/x/sys v0.19.0 // indirect 72 | golang.org/x/term v0.19.0 // indirect 73 | google.golang.org/protobuf v1.28.1 // indirect 74 | ) 75 | -------------------------------------------------------------------------------- /websocket_push_session_mock.go: -------------------------------------------------------------------------------- 1 | package bahamut 2 | 3 | import ( 4 | "context" 5 | "crypto/tls" 6 | "net/http" 7 | 8 | "go.aporeto.io/elemental" 9 | ) 10 | 11 | var _ PushSession = &MockSession{} 12 | 13 | // A MockSession can be used to mock a bahamut.Session. 
14 | type MockSession struct { 15 | MockMetadata any 16 | MockClaimsMap map[string]string 17 | MockCookies map[string]*http.Cookie 18 | MockHeaders map[string]string 19 | MockParameters map[string]string 20 | MockPushConfig *elemental.PushConfig 21 | MockTLSConnectionState *tls.ConnectionState 22 | MockDirectPush func(...*elemental.Event) 23 | MockClientIP string 24 | MockIdentifier string 25 | MockToken string 26 | MockClaims []string 27 | } 28 | 29 | // NewMockSession returns a new MockSession whose claims map, cookies, headers and parameters maps are initialized and empty. 30 | func NewMockSession() *MockSession { 31 | return &MockSession{ 32 | MockClaimsMap: map[string]string{}, 33 | MockCookies: map[string]*http.Cookie{}, 34 | MockHeaders: map[string]string{}, 35 | MockParameters: map[string]string{}, 36 | } 37 | } 38 | 39 | // Cookie is part of the Session interface. It returns http.ErrNoCookie when c is not in MockCookies. 40 | func (s *MockSession) Cookie(c string) (*http.Cookie, error) { 41 | 42 | v, ok := s.MockCookies[c] 43 | if !ok { 44 | return nil, http.ErrNoCookie 45 | } 46 | 47 | return v, nil 48 | } 49 | 50 | // DirectPush is part of the PushSession interface. It forwards evts to MockDirectPush when it is set. 51 | func (s *MockSession) DirectPush(evts ...*elemental.Event) { 52 | if s.MockDirectPush != nil { 53 | s.MockDirectPush(evts...) 54 | } 55 | } 56 | 57 | // Identifier is part of the PushSession interface. 58 | func (s *MockSession) Identifier() string { return s.MockIdentifier } 59 | 60 | // Parameter is part of the PushSession interface. 61 | func (s *MockSession) Parameter(k string) string { return s.MockParameters[k] } 62 | 63 | // Header is part of the PushSession interface. 64 | func (s *MockSession) Header(k string) string { return s.MockHeaders[k] } 65 | 66 | // PushConfig is part of the PushSession interface. 67 | func (s *MockSession) PushConfig() *elemental.PushConfig { return s.MockPushConfig } 68 | 69 | // SetClaims is part of the PushSession interface. 
It only sets MockClaims; MockClaimsMap is left unchanged. 70 | func (s *MockSession) SetClaims(claims []string) { s.MockClaims = claims } 71 | 72 | // Claims is part of the PushSession interface.
73 | func (s *MockSession) Claims() []string { return s.MockClaims } 74 | 75 | // ClaimsMap is part of the PushSession interface. It returns MockClaimsMap as is. 76 | func (s *MockSession) ClaimsMap() map[string]string { return s.MockClaimsMap } 77 | 78 | // Token is part of the PushSession interface. 79 | func (s *MockSession) Token() string { return s.MockToken } 80 | 81 | // TLSConnectionState is part of the PushSession interface. 82 | func (s *MockSession) TLSConnectionState() *tls.ConnectionState { return s.MockTLSConnectionState } 83 | 84 | // Metadata is part of the PushSession interface. 85 | func (s *MockSession) Metadata() any { return s.MockMetadata } 86 | 87 | // SetMetadata is part of the PushSession interface. 88 | func (s *MockSession) SetMetadata(m any) { s.MockMetadata = m } 89 | 90 | // Context is part of the PushSession interface. It always returns context.Background(). 91 | func (s *MockSession) Context() context.Context { return context.Background() } 92 | 93 | // ClientIP is part of the PushSession interface. 94 | func (s *MockSession) ClientIP() string { return s.MockClientIP } 95 | -------------------------------------------------------------------------------- /job_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Aporeto Inc. 2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License. 4 | // You may obtain a copy of the License at 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // Unless required by applicable law or agreed to in writing, software 7 | // distributed under the License is distributed on an "AS IS" BASIS, 8 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | // See the License for the specific language governing permissions and 10 | // limitations under the License.
11 | 12 | package bahamut 13 | 14 | import ( 15 | "context" 16 | "errors" 17 | "sync" 18 | "testing" 19 | "time" 20 | 21 | // nolint:revive // Allow dot imports for readability in tests 22 | . "github.com/smartystreets/goconvey/convey" 23 | ) 24 | // TestJob_RunJob checks RunJob with a successful job, a failing job and a job whose context is canceled while it runs. 25 | func TestJob_RunJob(t *testing.T) { 26 | 27 | Convey("Given I have a context and a job func to run", t, func() { 28 | 29 | var called int 30 | l := &sync.Mutex{} 31 | 32 | ctx, cancel := context.WithCancel(context.Background()) 33 | defer cancel() 34 | 35 | j := func() error { 36 | l.Lock() 37 | called++ 38 | l.Unlock() 39 | return nil 40 | } 41 | 42 | Convey("When I call RunJob", func() { 43 | 44 | interrupted, err := RunJob(ctx, j) 45 | 46 | Convey("Then interrupted should be false", func() { 47 | l.Lock() 48 | defer l.Unlock() 49 | So(interrupted, ShouldBeFalse) 50 | So(err, ShouldBeNil) 51 | So(called, ShouldEqual, 1) 52 | 53 | }) 54 | }) 55 | }) 56 | 57 | Convey("Given I have a context and a job func to run that returns an error", t, func() { 58 | 59 | var called int 60 | 61 | l := &sync.Mutex{} 62 | ctx, cancel := context.WithCancel(context.Background()) 63 | defer cancel() 64 | 65 | j := func() error { 66 | l.Lock() 67 | called++ 68 | l.Unlock() 69 | return errors.New("oops") 70 | } 71 | 72 | Convey("When I call RunJob", func() { 73 | 74 | interrupted, err := RunJob(ctx, j) 75 | 76 | Convey("Then interrupted should be false", func() { 77 | l.Lock() 78 | defer l.Unlock() 79 | So(interrupted, ShouldBeFalse) 80 | So(err, ShouldNotBeNil) 81 | So(called, ShouldEqual, 1) 82 | }) 83 | }) 84 | }) 85 | 86 | Convey("Given I have a context and a job func to run that I cancel", t, func() { 87 | 88 | var called int 89 | l := &sync.Mutex{} 90 | l2 := &sync.Mutex{} 91 | 92 | ctx, cancel := context.WithCancel(context.Background()) 93 | 94 | j := func() error { 95 | time.Sleep(300 * time.Millisecond) 96 | l.Lock() 97 | called++ 98 | l.Unlock() 99 | return errors.New("oops") 100 | } 101 | 102 | Convey("When I call RunJob",
func() { 103 | 104 | var interrupted bool 105 | var err error 106 | 107 | go func() { 108 | l2.Lock() 109 | interrupted, err = RunJob(ctx, j) 110 | l2.Unlock() 111 | }() 112 | time.Sleep(30 * time.Millisecond) 113 | cancel() 114 | 115 | Convey("Then interrupted should be true", func() { 116 | l.Lock() 117 | l2.Lock() 118 | defer l.Unlock() 119 | defer l2.Unlock() 120 | So(interrupted, ShouldBeTrue) 121 | So(err, ShouldBeNil) 122 | So(called, ShouldEqual, 0) 123 | }) 124 | }) 125 | }) 126 | } 127 | -------------------------------------------------------------------------------- /health_server.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Aporeto Inc. 2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License. 4 | // You may obtain a copy of the License at 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // Unless required by applicable law or agreed to in writing, software 7 | // distributed under the License is distributed on an "AS IS" BASIS, 8 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | // See the License for the specific language governing permissions and 10 | // limitations under the License. 11 | 12 | package bahamut 13 | 14 | import ( 15 | "context" 16 | "net/http" 17 | "strings" 18 | "time" 19 | 20 | "go.uber.org/zap" 21 | ) 22 | 23 | // a healthServer serves the health check endpoint, the metrics endpoint and custom stats endpoints. 24 | type healthServer struct { 25 | server *http.Server 26 | cfg config 27 | } 28 | 29 | // newHealthServer returns a new healthServer.
30 | func newHealthServer(cfg config) *healthServer { 31 | 32 | s := &healthServer{ 33 | cfg: cfg, 34 | server: &http.Server{Addr: cfg.healthServer.listenAddress, ReadHeaderTimeout: 10 * time.Second}, // bound header reads so slow clients cannot hold connections open forever 35 | } 36 | 37 | s.server.Handler = s 38 | 39 | return s 40 | } 41 | // ServeHTTP only accepts GET: "/" runs the optional health handler (204 on success, 500 on error), "/metrics" delegates to the metrics manager (501 when unset) and any other path is looked up in customStats (404 when missing). 42 | func (s *healthServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { 43 | 44 | if r.Method != http.MethodGet { 45 | http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) 46 | return 47 | } 48 | 49 | switch r.URL.Path { 50 | 51 | case "/": 52 | 53 | if s.cfg.healthServer.healthHandler == nil { 54 | w.WriteHeader(http.StatusNoContent) 55 | return 56 | } 57 | 58 | if err := s.cfg.healthServer.healthHandler(); err != nil { 59 | w.WriteHeader(http.StatusInternalServerError) 60 | return 61 | } 62 | 63 | w.WriteHeader(http.StatusNoContent) 64 | 65 | case "/metrics": 66 | if s.cfg.healthServer.metricsManager == nil { 67 | w.WriteHeader(http.StatusNotImplemented) 68 | return 69 | } 70 | 71 | s.cfg.healthServer.metricsManager.Write(w, r) 72 | 73 | default: 74 | 75 | if s.cfg.healthServer.customStats == nil { 76 | http.Error(w, "Not Found", http.StatusNotFound) 77 | return 78 | } 79 | 80 | f := s.cfg.healthServer.customStats[strings.TrimPrefix(r.URL.Path, "/")] 81 | if f == nil { 82 | http.Error(w, "Not Found", http.StatusNotFound) 83 | return 84 | } 85 | 86 | f(w, r) 87 | } 88 | } 89 | // start serves in a background goroutine and blocks until ctx is done. Any listen error other than http.ErrServerClosed is fatal. 
90 | func (s *healthServer) start(ctx context.Context) { 91 | 92 | zap.L().Debug("Health server enabled", zap.String("listen", s.cfg.healthServer.listenAddress)) 93 | 94 | go func() { 95 | if err := s.server.ListenAndServe(); err != nil { 96 | if err == http.ErrServerClosed { 97 | return 98 | } 99 | zap.L().Fatal("Unable to start health server", zap.Error(err)) 100 | } 101 | }() 102 | 103 | zap.L().Info("Health server started", zap.String("address", s.cfg.healthServer.listenAddress)) 104 | 105 | <-ctx.Done() 106 | } 107 | // stop gracefully shuts the server down with a 1s deadline and returns a context that is done once the shutdown has finished or timed out. 108 | func (s *healthServer) stop() context.Context { 109 | 110 | ctx, cancel := context.WithTimeout(context.Background(),
1*time.Second) 111 | 112 | go func() { 113 | defer cancel() 114 | if err := s.server.Shutdown(ctx); err != nil { 115 | zap.L().Error("Could not gracefully stop health server", zap.Error(err)) 116 | } else { 117 | zap.L().Debug("Health server stopped") 118 | } 119 | }() 120 | 121 | return ctx 122 | } 123 | -------------------------------------------------------------------------------- /utils.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Aporeto Inc. 2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License. 4 | // You may obtain a copy of the License at 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // Unless required by applicable law or agreed to in writing, software 7 | // distributed under the License is distributed on an "AS IS" BASIS, 8 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | // See the License for the specific language governing permissions and 10 | // limitations under the License.
11 | 12 | package bahamut 13 | 14 | import ( 15 | "context" 16 | "fmt" 17 | "net/http" 18 | "os" 19 | "runtime/debug" 20 | "strings" 21 | 22 | opentracing "github.com/opentracing/opentracing-go" 23 | "github.com/opentracing/opentracing-go/log" 24 | "go.aporeto.io/elemental" 25 | ) 26 | // handleRecoveredPanic turns a recovered panic value r into an elemental 500 error, prints it with the stack trace to stderr and tags the span found in ctx. It returns nil when r is nil and re-panics with the error when disablePanicRecovery is true. 27 | func handleRecoveredPanic(ctx context.Context, r any, disablePanicRecovery bool) error { 28 | 29 | if r == nil { 30 | return nil 31 | } 32 | 33 | err := elemental.NewError("Internal Server Error", fmt.Sprintf("panic: %v", r), "bahamut", http.StatusInternalServerError) 34 | 35 | st := string(debug.Stack()) 36 | 37 | // Print the panic as it would have happened 38 | fmt.Fprintf(os.Stderr, "panic: %s\n\n%s", err, st) // nolint: errcheck 39 | 40 | sp := opentracing.SpanFromContext(ctx) 41 | if sp != nil { 42 | sp.SetTag("error", true) 43 | sp.SetTag("panic", true) 44 | sp.LogFields( 45 | log.String("panic", fmt.Sprintf("%v", r)), 46 | log.String("stack", st), 47 | ) 48 | } 49 | 50 | if disablePanicRecovery { 51 | if sp != nil { 52 | sp.Finish() 53 | } 54 | panic(err) 55 | } 56 | 57 | return err 58 | } 59 | // extractSpanID returns the part of the span's String() before the first ':', or "unknown" when the span does not implement fmt.Stringer. 60 | func extractSpanID(span opentracing.Span) string { 61 | 62 | spanID := "unknown" 63 | if stringer, ok := span.(fmt.Stringer); ok { 64 | spanID = strings.SplitN(stringer.String(), ":", 2)[0] 65 | } 66 | 67 | return spanID 68 | } 69 | // processError wraps err into elemental.Errors traced with the current span ID and records it on the span found in ctx, if any. 
70 | func processError(ctx context.Context, err error) (outError elemental.Errors) { 71 | 72 | span := opentracing.SpanFromContext(ctx) 73 | 74 | outError = elemental.NewErrors(err).Trace(extractSpanID(span)) 75 | 76 | if span != nil { 77 | span.SetTag("error", true) 78 | span.SetTag("status.code", outError.Code()) 79 | span.LogFields(log.Object("elemental.error", outError)) 80 | } 81 | 82 | return outError 83 | } 84 | // claimsToMap converts "key=value" claims into a map; later duplicates overwrite earlier ones. It panics on a malformed claim. 85 | func claimsToMap(claims []string) map[string]string { 86 | 87 | claimsMap := map[string]string{} 88 | 89 | var k, v string 90 | 91 | for _, claim := range claims { 92 | if err := splitPtr(claim, &k, &v); err != nil { 93 |
panic(err) 94 | } 95 | claimsMap[k] = v 96 | } 97 | 98 | return claimsMap 99 | } 100 | // splitPtr splits a "key=value" tag on its first '=' into *key and *value. It fails when the tag is shorter than 3 bytes, starts with '=', has nothing after the first '=' or contains no '=' at all. 101 | func splitPtr(tag string, key *string, value *string) (err error) { 102 | 103 | l := len(tag) 104 | if l < 3 { 105 | err = fmt.Errorf("invalid tag: invalid length '%s'", tag) 106 | return 107 | } 108 | 109 | if tag[0] == '=' { 110 | err = fmt.Errorf("invalid tag: missing key '%s'", tag) 111 | return 112 | } 113 | 114 | for i := 0; i < l; i++ { 115 | if tag[i] == '=' { 116 | if i+1 >= l { 117 | return fmt.Errorf("invalid tag: missing value '%s'", tag) 118 | } 119 | *key = tag[:i] 120 | *value = tag[i+1:] 121 | return 122 | } 123 | } 124 | 125 | return fmt.Errorf("invalid tag: missing equal symbol '%s'", tag) 126 | } 127 | -------------------------------------------------------------------------------- /gateway/upstreamer/push/utils.go: -------------------------------------------------------------------------------- 1 | package push 2 | 3 | import ( 4 | "regexp" 5 | "strings" 6 | ) 7 | 8 | var vregexp = regexp.MustCompile(`/v/\d+`) 9 | // getTargetIdentity returns the identity targeted by path (after removing /v/N segments) and the optional prefix given by a leading "_prefix" segment: /a and /a/id target a, /a/id/b (and longer) targets b. 10 | func getTargetIdentity(path string) (string, string) { 11 | 12 | parts := strings.Split( 13 | strings.TrimPrefix( 14 | vregexp.ReplaceAllString(path, ""), 15 | "/", 16 | ), 17 | "/", 18 | ) 19 | 20 | prefix := "" 21 | if len(parts) > 1 && strings.HasPrefix(parts[0], "_") { // HasPrefix: parts[0] may be empty (e.g. "//x"), where parts[0][0] would panic 22 | prefix = parts[0][1:] 23 | parts = append([]string{}, parts[1:]...)
24 | } 25 | 26 | switch len(parts) { 27 | 28 | case 1: 29 | return parts[0], prefix 30 | case 2: 31 | return parts[0], prefix 32 | default: 33 | return parts[2], prefix 34 | } 35 | } 36 | // pick returns two distinct indexes in [0, length) chosen by shuffling with randomizer. It panics when length < 2. 37 | func pick(randomizer Randomizer, length int) (int, int) { 38 | 39 | if length < 2 { 40 | panic("pick: length must be at least 2") 41 | } 42 | 43 | idxs := make([]int, length) 44 | for i := 0; i < length; i++ { 45 | idxs[i] = i 46 | } 47 | 48 | randomizer.Shuffle(length, func(i, j int) { idxs[i], idxs[j] = idxs[j], idxs[i] }) 49 | 50 | return idxs[0], idxs[1] 51 | } 52 | // handleAddServicePing records the endpoint announced by the hello ping sp in services, creating the service when needed. It returns true only when the endpoint was not registered yet. 53 | func handleAddServicePing(services servicesConfig, sp servicePing) bool { 54 | 55 | if sp.Status == entityStatusGoodbye { 56 | panic("handleAddServicePing received a goodbye service ping") 57 | } 58 | 59 | srv, ok := services[sp.Key()] 60 | if !ok { 61 | srv = newService(sp.Key()) 62 | services[sp.Key()] = srv 63 | } 64 | 65 | // In any case we poke the endpoint. This will 66 | // only do something if the endpoint is already 67 | // registered. Being deferred, it runs after the registration below. 68 | defer srv.pokeEndpoint(sp.Endpoint, sp.Load) 69 | 70 | if srv.hasEndpoint(sp.Endpoint) { 71 | return false 72 | } 73 | 74 | // We update the info to the latest. 75 | srv.routes = sp.Routes 76 | srv.versions = sp.Versions 77 | 78 | // We register the new endpoint.
79 | srv.registerEndpoint(sp.Endpoint, sp.Load, sp.APILimiters) 80 | 81 | return true 82 | } 83 | // handleRemoveServicePing unregisters the endpoint of the goodbye ping sp and drops the service once it has no endpoint left. It returns true when an endpoint was removed. 84 | func handleRemoveServicePing(services servicesConfig, sp servicePing) bool { 85 | 86 | if sp.Status == entityStatusHello { 87 | panic("handleRemoveServicePing received a hello service ping") 88 | } 89 | 90 | srv, ok := services[sp.Key()] 91 | if !ok { 92 | return false 93 | } 94 | 95 | if !srv.hasEndpoint(sp.Endpoint) { 96 | return false 97 | } 98 | 99 | srv.unregisterEndpoint(sp.Endpoint) 100 | 101 | if len(srv.getEndpoints()) > 0 { 102 | return true 103 | } 104 | 105 | delete(services, sp.Key()) 106 | 107 | return true 108 | } 109 | // resyncRoutes maps every "prefix/identity" served by services (private routes only when includePrivate is set), plus every event API listed in events for a service name, to a copy of that service's endpoints. 110 | func resyncRoutes(services servicesConfig, includePrivate bool, events map[string]string) map[string][]*endpointInfo { 111 | 112 | apis := map[string][]*endpointInfo{} 113 | 114 | for serviceName, config := range services { 115 | 116 | name, prefix := extractPrefix(serviceName) 117 | 118 | for _, routes := range config.routes { 119 | for _, route := range routes { 120 | if !route.Private || includePrivate { 121 | apis[prefix+"/"+route.Identity] = append([]*endpointInfo{}, config.getEndpoints()...) 122 | } 123 | } 124 | } 125 | 126 | if api, ok := events[name]; ok { 127 | apis[prefix+"/"+api] = append([]*endpointInfo{}, config.getEndpoints()...) 128 | } 129 | } 130 | 131 | return apis 132 | } 133 | // extractPrefix splits key on its first '/' into prefix and name; without '/', name is key and prefix is empty. 
134 | func extractPrefix(key string) (name string, prefix string) { 135 | 136 | name = key 137 | 138 | if parts := strings.SplitN(key, "/", 2); len(parts) == 2 { 139 | prefix = parts[0] 140 | name = parts[1] 141 | } 142 | 143 | return name, prefix 144 | } 145 | -------------------------------------------------------------------------------- /netlimiter_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2013 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file.
4 | 5 | package bahamut 6 | 7 | import ( 8 | "errors" 9 | "fmt" 10 | "io" 11 | "net" 12 | "net/http" 13 | "sync" 14 | "sync/atomic" 15 | "testing" 16 | "time" 17 | ) 18 | 19 | const defaultMaxOpenFiles = 256 20 | // TestLimitListener checks that a listener wrapped by newListener(l, maximum) never lets the server handle more than maximum requests at once. 21 | func TestLimitListener(t *testing.T) { 22 | const maximum = 5 23 | 24 | attempts := (defaultMaxOpenFiles - maximum) / 2 25 | if attempts > 256 { // maximum length of accept queue is 128 by default 26 | attempts = 256 27 | } 28 | 29 | l, err := net.Listen("tcp", "127.0.0.1:0") 30 | if err != nil { 31 | t.Fatal(err) 32 | } 33 | defer l.Close() // nolint 34 | l = newListener(l, maximum) 35 | 36 | var open int32 37 | // nolint 38 | go http.Serve(l, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 39 | if n := atomic.AddInt32(&open, 1); n > maximum { 40 | t.Errorf("%d open connections, want <= %d", n, maximum) 41 | } 42 | defer atomic.AddInt32(&open, -1) 43 | time.Sleep(10 * time.Millisecond) 44 | fmt.Fprint(w, "some body") 45 | })) 46 | 47 | var wg sync.WaitGroup 48 | var failed int32 49 | for i := 0; i < attempts; i++ { 50 | wg.Add(1) 51 | go func() { 52 | defer wg.Done() 53 | c := http.Client{Timeout: 3 * time.Second} 54 | r, err := c.Get("http://" + l.Addr().String()) 55 | if err != nil { 56 | if errors.Is(err, io.EOF) { // Get wraps transport errors in *url.Error, so a plain == never matched 57 | t.Log(err) 58 | atomic.AddInt32(&failed, 1) 59 | } 60 | return 61 | } 62 | defer r.Body.Close() // nolint 63 | io.Copy(io.Discard, r.Body) // nolint 64 | }() 65 | } 66 | wg.Wait() 67 | 68 | // We expect some Gets to fail as the kernel's accept queue is filled, 69 | // but most should succeed. 70 | if int(failed) >= attempts/2 { 71 | t.Errorf("%d requests failed within %d attempts", failed, attempts) 72 | } 73 | } 74 | 75 | type errorListener struct { 76 | net.Listener 77 | } 78 | 79 | func (errorListener) Accept() (net.Conn, error) { 80 | return nil, errFake 81 | } 82 | 83 | var errFake = errors.New("fake error from errorListener") 84 | 85 | // TestLimitListenerError checks that Accept errors from the wrapped listener do not keep slots taken. This used to hang.
86 | func TestLimitListenerError(t *testing.T) { 87 | donec := make(chan bool, 1) 88 | 89 | go func() { 90 | const n = 2 91 | ll := newListener(errorListener{}, n) 92 | for i := 0; i < n+1; i++ { 93 | _, err := ll.Accept() 94 | if err != errFake { 95 | panic(fmt.Sprintf("Accept error = %v; want errFake", err)) 96 | } 97 | } 98 | donec <- true 99 | }() 100 | select { 101 | case <-donec: 102 | case <-time.After(5 * time.Second): 103 | t.Fatal("timeout. deadlock?") 104 | } 105 | } 106 | // TestLimitListenerClose checks that closing a listener limited to one connection unblocks an Accept waiting for the slot held by another connection. 107 | func TestLimitListenerClose(t *testing.T) { 108 | ln, err := net.Listen("tcp", "127.0.0.1:0") 109 | if err != nil { 110 | t.Fatal(err) 111 | } 112 | defer ln.Close() // nolint 113 | ln = newListener(ln, 1) 114 | 115 | doneCh := make(chan struct{}) 116 | defer close(doneCh) 117 | go func() { 118 | c, err := net.Dial("tcp", ln.Addr().String()) 119 | if err != nil { 120 | panic(err) 121 | } 122 | defer c.Close() // nolint 123 | <-doneCh 124 | }() 125 | 126 | c, err := ln.Accept() 127 | if err != nil { 128 | t.Fatal(err) 129 | } 130 | defer c.Close() // nolint 131 | 132 | acceptDone := make(chan struct{}) 133 | go func() { 134 | c, err := ln.Accept() 135 | if err == nil { 136 | c.Close() // nolint 137 | t.Errorf("Unexpected successful Accept()") 138 | } 139 | close(acceptDone) 140 | }() 141 | 142 | // Wait a tiny bit to ensure the Accept() is blocking. 143 | time.Sleep(10 * time.Millisecond) 144 | ln.Close() // nolint 145 | 146 | select { 147 | case <-acceptDone: 148 | case <-time.After(5 * time.Second): 149 | t.Fatalf("Accept() still blocking") 150 | } 151 | } 152 | -------------------------------------------------------------------------------- /rest_server_helpers.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Aporeto Inc. 2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License.
4 | // You may obtain a copy of the License at 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // Unless required by applicable law or agreed to in writing, software 7 | // distributed under the License is distributed on an "AS IS" BASIS, 8 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | // See the License for the specific language governing permissions and 10 | // limitations under the License. 11 | 12 | package bahamut 13 | 14 | import ( 15 | "fmt" 16 | "net/http" 17 | "strconv" 18 | "strings" 19 | 20 | "go.aporeto.io/elemental" 21 | "go.uber.org/zap" 22 | ) 23 | 24 | // Various common errors returned by bahamut. 25 | var ( 26 | ErrNotFound = elemental.NewError("Not Found", "Unable to find the requested resource", "bahamut", http.StatusNotFound) 27 | ErrRateLimit = elemental.NewError("Rate Limit", "You have exceeded your rate limit", "bahamut", http.StatusTooManyRequests) 28 | ) 29 | // setCommonHeader sets the Accept header and a Content-Type matching encoding, with a UTF-8 charset for JSON. 30 | func setCommonHeader(w http.ResponseWriter, encoding elemental.EncodingType) { 31 | 32 | w.Header().Set("Accept", "application/msgpack,application/json") 33 | if encoding == elemental.EncodingTypeJSON { 34 | w.Header().Set("Content-Type", string(encoding)+"; charset=UTF-8") 35 | } else { 36 | w.Header().Set("Content-Type", string(encoding)) 37 | } 38 | } 39 | // makeNotFoundHandler returns a handler that answers every request with ErrNotFound, applying the CORS policy from controller when controller is not nil. 
40 | func makeNotFoundHandler(controller CORSPolicyController) func(w http.ResponseWriter, r *http.Request) { 41 | return func(w http.ResponseWriter, r *http.Request) { 42 | 43 | var corsPolicy *CORSPolicy 44 | if controller != nil { 45 | corsPolicy = controller.PolicyForRequest(r) 46 | } 47 | 48 | writeHTTPResponse( 49 | w, 50 | makeErrorResponse( 51 | r.Context(), 52 | elemental.NewResponse(elemental.NewRequest()), 53 | ErrNotFound, 54 | nil, 55 | nil, 56 | ), 57 | r.Header.Get("origin"), 58 | corsPolicy, 59 | ) 60 | } 61 | } 62 | 63 | // writeHTTPResponse writes the response into the given http.ResponseWriter and returns the status code it wrote (0 when the response is nil).
64 | func writeHTTPResponse(w http.ResponseWriter, r *elemental.Response, origin string, accessControl *CORSPolicy) int { 65 | 66 | // If r is nil, we simply stop. 67 | // It mostly means the client closed the connection and 68 | // no response is needed. 69 | if r == nil { 70 | return 0 71 | } 72 | 73 | for _, cookie := range r.Cookies { 74 | http.SetCookie(w, cookie) 75 | } 76 | 77 | setCommonHeader(w, r.Request.Accept) 78 | 79 | if accessControl != nil { 80 | accessControl.Inject(w.Header(), origin, false) 81 | } 82 | 83 | if r.Redirect != "" { 84 | w.Header().Set("Location", r.Redirect) 85 | w.WriteHeader(http.StatusFound) 86 | return http.StatusFound 87 | } 88 | 89 | w.Header().Set("X-Count-Total", strconv.Itoa(r.Total)) 90 | 91 | if r.Next != "" { 92 | w.Header().Set("X-Next", r.Next) 93 | } 94 | 95 | if len(r.Messages) > 0 { 96 | w.Header().Set("X-Messages", strings.Join(r.Messages, ";")) 97 | } 98 | 99 | w.WriteHeader(r.StatusCode) 100 | 101 | if r.Data != nil { 102 | 103 | if _, err := w.Write(r.Data); err != nil { 104 | zap.L().Debug("Unable to send http response to client", zap.Error(err)) 105 | } 106 | } 107 | 108 | return r.StatusCode 109 | } 110 | 111 | // extractAPIVersion returns the version of a path of the form /v/<int>/... and 0 when the first segment is not "v". If the first segment is "v", the next one has to be an int for the version number. 112 | func extractAPIVersion(path string) (version int, err error) { 113 | 114 | components := append(strings.SplitN(strings.TrimPrefix(path, "/"), "/", 3), "") // trailing "" guarantees components[1] exists, so "/v" yields an error instead of an index panic 115 | if components[0] == "v" { 116 | version, err = strconv.Atoi(components[1]) 117 | if err != nil { 118 | return 0, fmt.Errorf("Invalid api version number '%s'", components[1]) 119 | } 120 | } 121 | 122 | return version, nil 123 | } 124 | -------------------------------------------------------------------------------- /cors.go: -------------------------------------------------------------------------------- 1 | package bahamut 2 | 3 | import ( 4 | "net/http" 5 | "strconv" 6 | "strings" 7 | ) 8 | 9 | // CORSOriginMirror instructs to mirror any incoming origin.
10 | // This should not be used in production as this is a development 11 | // feature that is not secure. 12 | const CORSOriginMirror = "_mirror_" 13 | 14 | // CORSPolicy allows to configure 15 | // CORS Access Control header of a response. 16 | type CORSPolicy struct { 17 | additionalOrigins map[string]struct{} 18 | AllowOrigin string 19 | AllowHeaders []string 20 | AllowMethods []string 21 | ExposeHeaders []string 22 | MaxAge int 23 | AllowCredentials bool 24 | } 25 | 26 | type corsPolicyController struct { 27 | policy *CORSPolicy 28 | } 29 | 30 | // NewDefaultCORSController returns a CORSPolicyController that always returns a CORSAccessControlPolicy 31 | // with sensible defaults. 32 | func NewDefaultCORSController(origin string, additionalOrigins []string) CORSPolicyController { 33 | 34 | additionalOriginsMap := make(map[string]struct{}, len(additionalOrigins)) 35 | if len(additionalOrigins) > 0 { 36 | for _, o := range additionalOrigins { 37 | additionalOriginsMap[o] = struct{}{} 38 | } 39 | } 40 | 41 | return &corsPolicyController{ 42 | policy: &CORSPolicy{ 43 | AllowOrigin: origin, 44 | additionalOrigins: additionalOriginsMap, 45 | AllowCredentials: true, 46 | MaxAge: 1500, 47 | AllowHeaders: []string{ 48 | "Authorization", 49 | "Accept", 50 | "Content-Type", 51 | "Cache-Control", 52 | "Cookie", 53 | "If-Modified-Since", 54 | "X-Requested-With", 55 | "X-Count-Total", 56 | "X-Namespace", 57 | "X-External-Tracking-Type", 58 | "X-External-Tracking-ID", 59 | "X-TLS-Client-Certificate", 60 | "Accept-Encoding", 61 | "X-Fields", 62 | "X-Read-Consistency", 63 | "X-Write-Consistency", 64 | "Idempotency-Key", 65 | }, 66 | AllowMethods: []string{ 67 | "GET", 68 | "POST", 69 | "PUT", 70 | "DELETE", 71 | "PATCH", 72 | "HEAD", 73 | "OPTIONS", 74 | }, 75 | ExposeHeaders: []string{ 76 | "X-Requested-With", 77 | "X-Count-Total", 78 | "X-Namespace", 79 | "X-Messages", 80 | "X-Fields", 81 | "X-Next", 82 | }, 83 | }, 84 | } 85 | } 86 | 87 | func (c *corsPolicyController) 
PolicyForRequest(*http.Request) *CORSPolicy { 88 | return c.policy 89 | } 90 | 91 | // Inject injects the CORS header on the given http.Header. It will use 92 | // the given request origin to determine the allow origin policy and the method 93 | // to determine if it should inject pre-flight OPTIONS header. 94 | // If the given http.Header is nil, this function is a no op. 95 | func (a *CORSPolicy) Inject(h http.Header, origin string, preflight bool) { 96 | 97 | if h == nil { 98 | return 99 | } 100 | 101 | corsOrigin := a.AllowOrigin 102 | 103 | switch { 104 | case a.AllowOrigin == "*": 105 | corsOrigin = "*" 106 | 107 | case a.AllowOrigin == CORSOriginMirror && origin != "": 108 | corsOrigin = origin 109 | 110 | case a.AllowOrigin == CORSOriginMirror && origin == "": 111 | corsOrigin = "" 112 | 113 | case func() bool { _, ok := a.additionalOrigins[origin]; return ok }(): 114 | corsOrigin = origin 115 | } 116 | 117 | if preflight { 118 | h.Set("Access-Control-Allow-Headers", strings.Join(a.AllowHeaders, ", ")) 119 | h.Set("Access-Control-Allow-Methods", strings.Join(a.AllowMethods, ", ")) 120 | h.Set("Access-Control-Max-Age", strconv.Itoa(a.MaxAge)) 121 | } 122 | 123 | if corsOrigin != "" { 124 | h.Set("Access-Control-Allow-Origin", corsOrigin) 125 | } 126 | 127 | h.Set("Access-Control-Expose-Headers", strings.Join(a.ExposeHeaders, ", ")) 128 | 129 | if a.AllowCredentials && corsOrigin != "*" && corsOrigin != "" { 130 | h.Set("Access-Control-Allow-Credentials", "true") 131 | } 132 | } 133 | -------------------------------------------------------------------------------- /gateway/listener_test.go: -------------------------------------------------------------------------------- 1 | package gateway 2 | 3 | import ( 4 | "fmt" 5 | "net" 6 | "sync" 7 | "testing" 8 | "time" 9 | 10 | // nolint:revive // Allow dot imports for readability in tests 11 | . 
"github.com/smartystreets/goconvey/convey" 12 | ) 13 | 14 | type fakeConn struct { 15 | closed bool 16 | } 17 | 18 | func (c *fakeConn) Read(b []byte) (n int, err error) { return 0, nil } 19 | 20 | func (c *fakeConn) Write(b []byte) (n int, err error) { return 0, nil } 21 | 22 | func (c *fakeConn) Close() error { c.closed = true; return nil } 23 | 24 | func (c *fakeConn) LocalAddr() net.Addr { return nil } 25 | 26 | func (c *fakeConn) RemoteAddr() net.Addr { return nil } 27 | 28 | func (c *fakeConn) SetDeadline(t time.Time) error { return nil } 29 | 30 | func (c *fakeConn) SetReadDeadline(t time.Time) error { return nil } 31 | 32 | func (c *fakeConn) SetWriteDeadline(t time.Time) error { return nil } 33 | 34 | type fakeListener struct { 35 | conn func() net.Conn 36 | acceptError error 37 | } 38 | 39 | func (l *fakeListener) Accept() (net.Conn, error) { 40 | 41 | if l.acceptError != nil { 42 | return nil, l.acceptError 43 | } 44 | 45 | return l.conn(), nil 46 | } 47 | 48 | func (l *fakeListener) Addr() net.Addr { 49 | return nil 50 | } 51 | 52 | func (l *fakeListener) Close() error { 53 | return nil 54 | } 55 | 56 | type fakeListenerLimiterMetricManager struct { 57 | sync.Mutex 58 | accepted int 59 | rejected int 60 | total int 61 | } 62 | 63 | func (m *fakeListenerLimiterMetricManager) RegisterAcceptedConnection() { 64 | m.Lock() 65 | m.total = m.total + 1 66 | m.accepted = m.accepted + 1 67 | m.Unlock() 68 | } 69 | 70 | func (m *fakeListenerLimiterMetricManager) RegisterLimitedConnection() { 71 | m.Lock() 72 | m.total = m.total + 1 73 | m.rejected = m.rejected + 1 74 | m.Unlock() 75 | } 76 | 77 | func TestLimitLimiter(t *testing.T) { 78 | 79 | Convey("Given I call newLimitedListener", t, func() { 80 | 81 | l := &fakeListener{ 82 | conn: func() net.Conn { return &fakeConn{} }, 83 | } 84 | 85 | mm := &fakeListenerLimiterMetricManager{} 86 | 87 | ll := newLimitedListener(l, 2.0, 1, mm) 88 | 89 | Convey("When I call Accept and it works", func() { 90 | 91 | c, err := 
ll.Accept() 92 | 93 | Convey("Then err should be nil", func() { 94 | So(err, ShouldBeNil) 95 | }) 96 | 97 | Convey("Then c should be correct", func() { 98 | So(c, ShouldNotBeNil) 99 | So(mm.total, ShouldBeGreaterThanOrEqualTo, 1) 100 | So(mm.accepted, ShouldBeGreaterThanOrEqualTo, 1) 101 | So(mm.rejected, ShouldEqual, 0) 102 | }) 103 | }) 104 | 105 | Convey("When I call Accept but underlying listener is returning an error", func() { 106 | 107 | l.acceptError = fmt.Errorf("boom") 108 | 109 | c, err := ll.Accept() 110 | 111 | Convey("Then err should be nil", func() { 112 | So(err, ShouldNotBeNil) 113 | So(err.Error(), ShouldEqual, "boom") 114 | }) 115 | 116 | Convey("Then c should be correct", func() { 117 | So(c, ShouldBeNil) 118 | }) 119 | }) 120 | 121 | Convey("When I spam Accept I should get rate limited", func() { 122 | 123 | // send a bunch of Accept to excite the rate limiter 124 | go func() { _, _ = ll.Accept() }() 125 | go func() { _, _ = ll.Accept() }() 126 | go func() { _, _ = ll.Accept() }() 127 | go func() { _, _ = ll.Accept() }() 128 | go func() { _, _ = ll.Accept() }() 129 | go func() { _, _ = ll.Accept() }() 130 | go func() { _, _ = ll.Accept() }() 131 | go func() { _, _ = ll.Accept() }() 132 | go func() { _, _ = ll.Accept() }() 133 | 134 | time.Sleep(300 * time.Millisecond) 135 | 136 | // this one should be closed because rate limited 137 | conn, _ := ll.Accept() 138 | 139 | Convey("Then err should be nil", func() { 140 | So(conn.(*fakeConn).closed, ShouldBeTrue) 141 | So(mm.total, ShouldBeGreaterThanOrEqualTo, 1) 142 | So(mm.accepted, ShouldBeGreaterThanOrEqualTo, 1) 143 | So(mm.rejected, ShouldBeGreaterThanOrEqualTo, 1) 144 | }) 145 | }) 146 | }) 147 | } 148 | -------------------------------------------------------------------------------- /gateway/upstreamer/push/services_test.go: -------------------------------------------------------------------------------- 1 | package push 2 | 3 | import ( 4 | "sort" 5 | "strings" 6 | "testing" 7 | "time" 8 | 
9 | // nolint:revive // Allow dot imports for readability in tests 10 | . "github.com/smartystreets/goconvey/convey" 11 | "golang.org/x/time/rate" 12 | ) 13 | 14 | func Test_Services(t *testing.T) { 15 | 16 | Convey("Given I create a new service", t, func() { 17 | 18 | srv := newService("mysrv") 19 | 20 | Convey("Then srv should be correctly initialized", func() { 21 | So(srv.name, ShouldEqual, "mysrv") 22 | So(srv.endpoints, ShouldResemble, map[string]*endpointInfo{}) 23 | }) 24 | 25 | Convey("When I register two endpoints", func() { 26 | 27 | rls1 := IdentityToAPILimitersRegistry{ 28 | "identity-a": {Limit: 10, Burst: 20}, 29 | "identity-b": {Limit: 11, Burst: 21}, 30 | } 31 | rls2 := IdentityToAPILimitersRegistry{ 32 | "identity-c": {Limit: 100, Burst: 200}, 33 | } 34 | 35 | srv.registerEndpoint("1.1.1.1:4443", 0.3, rls1) 36 | srv.registerEndpoint("2.2.2.2:4443", 0.4, rls2) 37 | 38 | Convey("Then they should be registered", func() { 39 | 40 | eps := srv.getEndpoints() 41 | 42 | // sort since it is stored as a map 43 | sort.Slice(eps, func(i, j int) bool { 44 | return strings.Compare(eps[i].address, eps[j].address) == -1 45 | }) 46 | 47 | So(srv.hasEndpoint("1.1.1.1:4443"), ShouldBeTrue) 48 | So(srv.hasEndpoint("2.2.2.2:4443"), ShouldBeTrue) 49 | 50 | So(len(eps), ShouldEqual, 2) 51 | So(eps[0].address, ShouldEqual, "1.1.1.1:4443") 52 | So(eps[0].lastLoad, ShouldEqual, 0.3) 53 | So(eps[0].limiters, ShouldEqual, rls1) 54 | So(eps[0].lastSeen.Round(time.Second), ShouldEqual, time.Now().Round(time.Second)) 55 | So(eps[0].limiters["identity-a"].limiter, ShouldHaveSameTypeAs, &rate.Limiter{}) 56 | So(eps[0].limiters["identity-a"].limiter.Limit(), ShouldEqual, rate.Limit(10)) 57 | So(eps[0].limiters["identity-a"].limiter.Burst(), ShouldEqual, rate.Limit(20)) 58 | So(eps[0].limiters["identity-b"].limiter, ShouldHaveSameTypeAs, &rate.Limiter{}) 59 | So(eps[0].limiters["identity-b"].limiter.Limit(), ShouldEqual, rate.Limit(11)) 60 | 
So(eps[0].limiters["identity-b"].limiter.Burst(), ShouldEqual, rate.Limit(21)) 61 | 62 | So(eps[1].address, ShouldEqual, "2.2.2.2:4443") 63 | So(eps[1].lastLoad, ShouldEqual, 0.4) 64 | So(eps[1].lastSeen.Round(time.Second), ShouldEqual, time.Now().Round(time.Second)) 65 | So(eps[1].limiters, ShouldEqual, rls2) 66 | So(eps[1].limiters["identity-c"].limiter, ShouldHaveSameTypeAs, &rate.Limiter{}) 67 | So(eps[1].limiters["identity-c"].limiter.Limit(), ShouldEqual, rate.Limit(100)) 68 | So(eps[1].limiters["identity-c"].limiter.Burst(), ShouldEqual, rate.Limit(200)) 69 | 70 | }) 71 | 72 | Convey("When I poke one endpoint", func() { 73 | 74 | time.Sleep(1500 * time.Millisecond) 75 | srv.pokeEndpoint("2.2.2.2:4443", 0.6) 76 | 77 | eps := srv.getEndpoints() 78 | 79 | // sort since it is stored as a map 80 | sort.Slice(eps, func(i, j int) bool { 81 | return strings.Compare(eps[i].address, eps[j].address) == -1 82 | }) 83 | 84 | Convey("Then the endpoint should be poked", func() { 85 | So(eps[1].address, ShouldEqual, "2.2.2.2:4443") 86 | So(eps[1].lastLoad, ShouldEqual, 0.6) 87 | So(eps[1].lastSeen.Round(time.Second), ShouldEqual, time.Now().Round(time.Second)) 88 | }) 89 | 90 | Convey("Then the second one should be outdated", func() { 91 | 92 | outdated := srv.outdatedEndpoints(time.Now().Add(-1 * time.Second)) 93 | 94 | So(len(outdated), ShouldEqual, 1) 95 | So(outdated, ShouldContain, "1.1.1.1:4443") 96 | }) 97 | }) 98 | 99 | Convey("When I unregister the 2 endpoints", func() { 100 | 101 | srv.unregisterEndpoint("1.1.1.1:4443") 102 | srv.unregisterEndpoint("2.2.2.2:4443") 103 | 104 | Convey("Then all endpoints should be removed", func() { 105 | So(len(srv.getEndpoints()), ShouldEqual, 0) 106 | }) 107 | }) 108 | }) 109 | }) 110 | } 111 | -------------------------------------------------------------------------------- /gateway/extractors_test.go: -------------------------------------------------------------------------------- 1 | package gateway 2 | 3 | import ( 4 | 
"net/http" 5 | "testing" 6 | ) 7 | 8 | func Test_defaultSourceExtractor_ExtractSource(t *testing.T) { 9 | type args struct { 10 | r *http.Request 11 | } 12 | tests := []struct { 13 | name string 14 | f defaultSourceExtractor 15 | args args 16 | want string 17 | wantErr bool 18 | }{ 19 | { 20 | "no header", 21 | NewDefaultSourceExtractor("").(defaultSourceExtractor), 22 | args{ 23 | &http.Request{Header: http.Header{}}, 24 | }, 25 | "default", 26 | false, 27 | }, 28 | { 29 | "header", 30 | NewDefaultSourceExtractor("").(defaultSourceExtractor), 31 | args{ 32 | &http.Request{Header: http.Header{"Authorization": {"Bearer X"}}}, 33 | }, 34 | "763578776710062675", 35 | false, 36 | }, 37 | { 38 | "cookie", 39 | NewDefaultSourceExtractor("X-Token").(defaultSourceExtractor), 40 | args{ 41 | func() *http.Request { 42 | r := &http.Request{Header: http.Header{}} 43 | r.AddCookie(&http.Cookie{Name: "X-Token", Value: "X"}) 44 | return r 45 | }(), 46 | }, 47 | "15784302077936868069", 48 | false, 49 | }, 50 | { 51 | "cookie with no configured authCookieName", 52 | NewDefaultSourceExtractor("").(defaultSourceExtractor), 53 | args{ 54 | func() *http.Request { 55 | r := &http.Request{Header: http.Header{}} 56 | r.AddCookie(&http.Cookie{Name: "X-Token", Value: "X"}) 57 | return r 58 | }(), 59 | }, 60 | "default", 61 | false, 62 | }, 63 | { 64 | "empty cookie", 65 | NewDefaultSourceExtractor("X-Token").(defaultSourceExtractor), 66 | args{ 67 | func() *http.Request { 68 | r := &http.Request{Header: http.Header{}} 69 | r.AddCookie(&http.Cookie{}) 70 | return r 71 | }(), 72 | }, 73 | "default", 74 | false, 75 | }, 76 | { 77 | "cookie and header (cookie should have prio)", 78 | NewDefaultSourceExtractor("X-Token").(defaultSourceExtractor), 79 | args{ 80 | func() *http.Request { 81 | r := &http.Request{Header: http.Header{"Authorization": {"Bearer X"}}} 82 | r.AddCookie(&http.Cookie{Name: "X-Token", Value: "X"}) 83 | return r 84 | }(), 85 | }, 86 | "15784302077936868069", 87 | false, 88 | 
}, 89 | { 90 | "empty cookie and header (header should fall back)", 91 | NewDefaultSourceExtractor("X-Token").(defaultSourceExtractor), 92 | args{ 93 | func() *http.Request { 94 | r := &http.Request{Header: http.Header{"Authorization": {"Bearer X"}}} 95 | r.AddCookie(&http.Cookie{}) 96 | return r 97 | }(), 98 | }, 99 | "763578776710062675", 100 | false, 101 | }, 102 | } 103 | for _, tt := range tests { 104 | t.Run(tt.name, func(t *testing.T) { 105 | f := tt.f 106 | got, err := f.ExtractSource(tt.args.r) 107 | if (err != nil) != tt.wantErr { 108 | t.Errorf("defaultSourceExtractor.ExtractSource() error = %v, wantErr %v", err, tt.wantErr) 109 | return 110 | } 111 | if got != tt.want { 112 | t.Errorf("defaultSourceExtractor.ExtractSource() = %v, want %v", got, tt.want) 113 | } 114 | }) 115 | } 116 | } 117 | 118 | func Test_defaultTCPSourceExtractor_ExtractSource(t *testing.T) { 119 | type args struct { 120 | r *http.Request 121 | } 122 | tests := []struct { 123 | name string 124 | f defaultTCPSourceExtractor 125 | args args 126 | want string 127 | wantErr bool 128 | }{ 129 | { 130 | "source", 131 | defaultTCPSourceExtractor{}, 132 | args{ 133 | &http.Request{RemoteAddr: "1.1.1.1"}, 134 | }, 135 | "1.1.1.1", 136 | false, 137 | }, 138 | } 139 | for _, tt := range tests { 140 | t.Run(tt.name, func(t *testing.T) { 141 | f := defaultTCPSourceExtractor{} 142 | got, err := f.ExtractSource(tt.args.r) 143 | if (err != nil) != tt.wantErr { 144 | t.Errorf("defaultTCPSourceExtractor.ExtractSource() error = %v, wantErr %v", err, tt.wantErr) 145 | return 146 | } 147 | if got != tt.want { 148 | t.Errorf("defaultTCPSourceExtractor.ExtractSource() = %v, want %v", got, tt.want) 149 | } 150 | }) 151 | } 152 | } 153 | -------------------------------------------------------------------------------- /pubsub_local.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Aporeto Inc. 
2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License. 4 | // You may obtain a copy of the License at 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // Unless required by applicable law or agreed to in writing, software 7 | // distributed under the License is distributed on an "AS IS" BASIS, 8 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | // See the License for the specific language governing permissions and 10 | // limitations under the License. 11 | 12 | package bahamut 13 | 14 | import ( 15 | "context" 16 | "sync" 17 | ) 18 | 19 | type registration struct { 20 | ch chan *Publication 21 | topic string 22 | } 23 | 24 | // localPubSub implements a PubSubClient using local channels 25 | type localPubSub struct { 26 | subscribers map[string][]chan *Publication 27 | register chan *registration 28 | unregister chan *registration 29 | publications chan *Publication 30 | stop chan struct{} 31 | 32 | lock *sync.Mutex 33 | } 34 | 35 | // NewLocalPubSubClient returns a PubSubClient backed by local channels. 36 | func NewLocalPubSubClient() PubSubClient { 37 | 38 | return newlocalPubSub() 39 | } 40 | 41 | // newlocalPubSub returns a new localPubSub. 42 | func newlocalPubSub() *localPubSub { 43 | 44 | return &localPubSub{ 45 | subscribers: map[string][]chan *Publication{}, 46 | register: make(chan *registration), 47 | unregister: make(chan *registration), 48 | stop: make(chan struct{}), 49 | publications: make(chan *Publication, 1024), 50 | lock: &sync.Mutex{}, 51 | } 52 | } 53 | 54 | // Publish publishes a publication. 
55 | func (p *localPubSub) Publish(publication *Publication, opts ...PubSubOptPublish) error { 56 | 57 | p.publications <- publication 58 | 59 | return nil 60 | } 61 | 62 | // Subscribe will subscribe the given channel to the given topic 63 | func (p *localPubSub) Subscribe(c chan *Publication, errors chan error, topic string, opts ...PubSubOptSubscribe) func() { 64 | 65 | unsubscribe := make(chan struct{}) 66 | 67 | p.registerSubscriberChannel(c, topic) 68 | 69 | go func() { 70 | <-unsubscribe 71 | p.unregisterSubscriberChannel(c, topic) 72 | }() 73 | 74 | return func() { close(unsubscribe) } 75 | } 76 | 77 | // Connect connects the PubSubClient to the remote service. 78 | func (p *localPubSub) Connect(ctx context.Context) error { 79 | 80 | go p.listen() 81 | 82 | return nil 83 | } 84 | 85 | // Disconnect disconnects the PubSubClient from the remote service.. 86 | func (p *localPubSub) Disconnect() error { 87 | 88 | close(p.stop) 89 | 90 | return nil 91 | } 92 | 93 | func (p *localPubSub) registerSubscriberChannel(c chan *Publication, topic string) { 94 | 95 | p.register <- ®istration{ch: c, topic: topic} 96 | } 97 | 98 | func (p *localPubSub) unregisterSubscriberChannel(c chan *Publication, topic string) { 99 | 100 | p.unregister <- ®istration{ch: c, topic: topic} 101 | } 102 | 103 | func (p *localPubSub) listen() { 104 | 105 | for { 106 | select { 107 | case reg := <-p.register: 108 | p.lock.Lock() 109 | if _, ok := p.subscribers[reg.topic]; !ok { 110 | p.subscribers[reg.topic] = []chan *Publication{} 111 | } 112 | 113 | p.subscribers[reg.topic] = append(p.subscribers[reg.topic], reg.ch) 114 | p.lock.Unlock() 115 | 116 | case reg := <-p.unregister: 117 | p.lock.Lock() 118 | for i, sub := range p.subscribers[reg.topic] { 119 | if sub == reg.ch { 120 | p.subscribers[reg.topic] = append(p.subscribers[reg.topic][:i], p.subscribers[reg.topic][i+1:]...) 
121 | close(sub) 122 | break 123 | } 124 | } 125 | p.lock.Unlock() 126 | 127 | case publication := <-p.publications: 128 | 129 | p.lock.Lock() 130 | var wg sync.WaitGroup 131 | for _, sub := range p.subscribers[publication.Topic] { 132 | wg.Add(1) 133 | go func(s chan *Publication, p *Publication) { 134 | defer wg.Done() 135 | s <- p.Duplicate() 136 | }(sub, publication) 137 | } 138 | wg.Wait() 139 | p.lock.Unlock() 140 | 141 | case <-p.stop: 142 | p.lock.Lock() 143 | p.subscribers = map[string][]chan *Publication{} 144 | p.lock.Unlock() 145 | return 146 | } 147 | } 148 | } 149 | -------------------------------------------------------------------------------- /gateway/upstreamer/push/notifier.go: -------------------------------------------------------------------------------- 1 | package push 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "runtime" 7 | "time" 8 | 9 | "github.com/shirou/gopsutil/v3/process" 10 | "go.aporeto.io/bahamut" 11 | "go.uber.org/zap" 12 | ) 13 | 14 | // A Notifier sends ServicePing to the Wutai gateways. 15 | type Notifier struct { 16 | pubsub bahamut.PubSubClient 17 | limiters IdentityToAPILimitersRegistry 18 | privateOverrides map[string]bool 19 | serviceName string 20 | endpoint string 21 | serviceStatusTopic string 22 | prefix string 23 | frequency time.Duration 24 | } 25 | 26 | // NewNotifier returns a new Wutai notifier. 
27 | func NewNotifier( 28 | pubsub bahamut.PubSubClient, 29 | serviceStatusTopic string, 30 | serviceName string, 31 | endpoint string, 32 | opts ...NotifierOption, 33 | ) *Notifier { 34 | 35 | cfg := newNotifierConfig() 36 | for _, o := range opts { 37 | o(&cfg) 38 | } 39 | 40 | return &Notifier{ 41 | pubsub: pubsub, 42 | serviceName: serviceName, 43 | endpoint: endpoint, 44 | serviceStatusTopic: serviceStatusTopic, 45 | limiters: cfg.rateLimits, 46 | frequency: cfg.pingInterval, 47 | prefix: cfg.prefix, 48 | privateOverrides: cfg.privateOverrides, 49 | } 50 | } 51 | 52 | // MakeStartHook returns a bahamut start hook that sends the hello message to the Upstreamer periodically. 53 | func (w *Notifier) MakeStartHook(ctx context.Context) func(server bahamut.Server) error { 54 | 55 | return func(server bahamut.Server) error { 56 | 57 | p, err := process.NewProcess(int32(os.Getpid())) 58 | if err != nil { 59 | return err 60 | } 61 | 62 | routes := server.RoutesInfo() 63 | for _, versionedRoutes := range routes { 64 | for i, r := range versionedRoutes { 65 | priv, ok := w.privateOverrides[r.Identity] 66 | if ok { 67 | r.Private = priv 68 | versionedRoutes[i] = r 69 | } 70 | } 71 | } 72 | 73 | sp := servicePing{ 74 | Name: w.serviceName, 75 | Prefix: w.prefix, 76 | Status: entityStatusHello, 77 | Endpoint: w.endpoint, 78 | Routes: routes, 79 | Versions: server.VersionsInfo(), 80 | PushEndpoint: server.PushEndpoint(), 81 | APILimiters: w.limiters, 82 | } 83 | 84 | pct, err := p.CPUPercent() 85 | if err != nil { 86 | return err 87 | } 88 | 89 | // Use the maxproc to get a percentage between 0 and 100 90 | cores := float64(runtime.GOMAXPROCS(0)) 91 | 92 | sp.Load = pct / cores 93 | 94 | pub := bahamut.NewPublication(w.serviceStatusTopic) 95 | if err := pub.Encode(sp); err != nil { 96 | return err 97 | } 98 | 99 | if err := w.pubsub.Publish(pub); err != nil { 100 | return err 101 | } 102 | 103 | go func() { 104 | for { 105 | select { 106 | case <-time.After(w.frequency): 107 
| 108 | if pct, err = p.Percent(0); err != nil { 109 | zap.L().Error("Unable to retrieve cpu usage", zap.Error(err)) 110 | continue 111 | } 112 | 113 | sp.Load = pct / cores 114 | 115 | if err := pub.Encode(sp); err != nil { 116 | zap.L().Error("Unable to encode service ping", zap.Error(err)) 117 | continue 118 | } 119 | 120 | if err := w.pubsub.Publish(pub); err != nil { 121 | zap.L().Error("Unable to send wutai up ping", zap.Error(err)) 122 | } 123 | case <-ctx.Done(): 124 | return 125 | } 126 | } 127 | }() 128 | 129 | return nil 130 | } 131 | } 132 | 133 | // MakeStopHook returns a bahamut stop hook that sends the goodbye message to the Upstreamer. 134 | func (w *Notifier) MakeStopHook() func(server bahamut.Server) error { 135 | 136 | return func(server bahamut.Server) error { 137 | 138 | pub := bahamut.NewPublication(w.serviceStatusTopic) 139 | if err := pub.Encode(servicePing{ 140 | Name: w.serviceName, 141 | Prefix: w.prefix, 142 | Status: entityStatusGoodbye, 143 | Endpoint: w.endpoint, 144 | }); err != nil { 145 | return err 146 | } 147 | 148 | ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 149 | defer cancel() 150 | 151 | if err := w.pubsub.Publish(pub, bahamut.NATSOptPublishRequireAck(ctx)); err != nil { 152 | return err 153 | } 154 | 155 | <-time.After(time.Second) 156 | 157 | return w.pubsub.Disconnect() 158 | } 159 | } 160 | -------------------------------------------------------------------------------- /gateway/errors.go: -------------------------------------------------------------------------------- 1 | package gateway 2 | 3 | import ( 4 | "context" 5 | "crypto/tls" 6 | "crypto/x509" 7 | "fmt" 8 | "io" 9 | "net" 10 | "net/http" 11 | 12 | "github.com/mailgun/multibuf" 13 | "github.com/vulcand/oxy/v2/connlimit" 14 | "go.aporeto.io/elemental" 15 | ) 16 | 17 | var ( 18 | errLocked = elemental.NewError( 19 | "Service Locked", 20 | "The requested service is in maintenance. 
Please try again in a moment.", 21 | "gateway", 22 | http.StatusLocked, 23 | ) 24 | 25 | errServiceUnavailable = elemental.NewError( 26 | "Service Temporarily Unavailable", 27 | "The requested service is not available. Please try again in a moment.", 28 | "gateway", 29 | http.StatusServiceUnavailable, 30 | ) 31 | 32 | errGatewayTimeout = elemental.NewError( 33 | "Gateway Timeout", 34 | "The requested service took too long to respond. Please try again in a moment.", 35 | "gateway", 36 | http.StatusGatewayTimeout, 37 | ) 38 | 39 | errBadGateway = elemental.NewError( 40 | "Bad Gateway", 41 | "The requested service is not available. Please try again in a moment.", 42 | "gateway", 43 | http.StatusBadGateway, 44 | ) 45 | 46 | errClientClosedConnection = elemental.NewError( 47 | "Client Closed Connection", 48 | "The client closed the connection before it could complete.", 49 | "gateway", 50 | 499, 51 | ) 52 | 53 | errRateLimit = elemental.NewError( 54 | "Too Many Requests", 55 | "Please retry in a moment.", 56 | "gateway", 57 | http.StatusTooManyRequests, 58 | ) 59 | 60 | errConnLimit = elemental.NewError( 61 | "Too Many Connections", 62 | "Please retry in a moment.", 63 | "gateway", 64 | http.StatusTooManyRequests, 65 | ) 66 | ) 67 | 68 | func makeError(code int, title string, description string) elemental.Error { 69 | return elemental.NewError( 70 | title, 71 | description, 72 | "gateway", 73 | code, 74 | ) 75 | } 76 | 77 | type errorHeaderInjector func(w http.ResponseWriter, r *http.Request) http.Header 78 | 79 | type errorHandler struct { 80 | corsOriginInjector errorHeaderInjector 81 | } 82 | 83 | func (s *errorHandler) ServeHTTP(w http.ResponseWriter, r *http.Request, err error) { 84 | 85 | if err == nil { 86 | return 87 | } 88 | 89 | if s.corsOriginInjector != nil { 90 | s.corsOriginInjector(w, r) 91 | } 92 | 93 | switch e := err.(type) { 94 | 95 | case net.Error: 96 | if e.Timeout() { 97 | writeError(w, r, errGatewayTimeout) 98 | return 99 | } 100 | writeError(w, 
r, errBadGateway) 101 | return 102 | 103 | case *connlimit.MaxConnError: 104 | writeError(w, r, errConnLimit) 105 | return 106 | 107 | case *multibuf.MaxSizeReachedError: 108 | writeError(w, r, makeError(http.StatusRequestEntityTooLarge, "Entity Too Large", fmt.Sprintf("Payload size exceeds the maximum allowed size (%d bytes)", e.MaxSize))) 109 | return 110 | 111 | case x509.UnknownAuthorityError, x509.HostnameError, x509.CertificateInvalidError, x509.ConstraintViolationError, *tls.CertificateVerificationError: 112 | writeError(w, r, makeError(495, "TLS Error", err.Error())) 113 | return 114 | } 115 | 116 | switch err { 117 | case io.EOF: 118 | writeError(w, r, errBadGateway) 119 | case context.Canceled: 120 | writeError(w, r, errClientClosedConnection) 121 | case errTooManyRequest: 122 | writeError(w, r, errRateLimit) 123 | default: 124 | // the http package function MaxBytesReader is returning an error.erroString 125 | // so we need to check its string value. 126 | if err.Error() == "http: request body too large" { 127 | writeError(w, r, makeError(http.StatusRequestEntityTooLarge, "Entity Too Large", err.Error())) 128 | return 129 | } 130 | writeError(w, r, makeError(http.StatusInternalServerError, "Internal Server Error", err.Error())) 131 | } 132 | } 133 | 134 | func writeError(w http.ResponseWriter, r *http.Request, eerr elemental.Error) { 135 | 136 | _, encoding, err := elemental.EncodingFromHeaders(r.Header) 137 | if err != nil { 138 | encoding = elemental.EncodingTypeJSON 139 | } 140 | 141 | data, err := elemental.Encode(encoding, elemental.NewErrors(eerr)) 142 | if err != nil { 143 | http.Error(w, "Error while encoding the error", eerr.Code) 144 | } 145 | 146 | if encoding == elemental.EncodingTypeJSON { 147 | w.Header().Set("Content-Type", string(encoding)+"; charset=UTF-8") 148 | } else { 149 | w.Header().Set("Content-Type", string(encoding)) 150 | } 151 | 152 | w.WriteHeader(eerr.Code) 153 | w.Write(data) // nolint 154 | } 155 | 
-------------------------------------------------------------------------------- /meta.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Aporeto Inc. 2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License. 4 | // You may obtain a copy of the License at 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // Unless required by applicable law or agreed to in writing, software 7 | // distributed under the License is distributed on an "AS IS" BASIS, 8 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | // See the License for the specific language governing permissions and 10 | // limitations under the License. 11 | 12 | package bahamut 13 | 14 | import ( 15 | "fmt" 16 | "sort" 17 | "strings" 18 | 19 | "go.aporeto.io/elemental" 20 | ) 21 | 22 | // A RouteInfo contains basic information about an api route. 23 | type RouteInfo struct { 24 | Identity string `msgpack:"identity" json:"identity"` 25 | URL string `msgpack:"url" json:"url"` 26 | Verbs []string `msgpack:"verbs,omitempty" json:"verbs,omitempty"` 27 | Private bool `msgpack:"private,omitempty" json:"private,omitempty"` 28 | } 29 | 30 | func (r RouteInfo) String() string { 31 | return fmt.Sprintf("%s -> %s", r.URL, strings.Join(r.Verbs, ", ")) 32 | } 33 | 34 | type routeBuilder struct { 35 | verbs map[string]struct{} 36 | identity elemental.Identity 37 | private bool 38 | } 39 | 40 | func buildVersionedRoutes(modelManagers map[int]elemental.ModelManager, processorFinder processorFinderFunc) map[int][]RouteInfo { 41 | 42 | addRoute := func(routes map[string]routeBuilder, identity elemental.Identity, url string, verb string, private bool) { 43 | 44 | rb, ok := routes[url] 45 | if !ok { 46 | rb = routeBuilder{ 47 | verbs: map[string]struct{}{}, 48 | private: private, 49 | identity: identity, 50 | } 51 | routes[url] = rb 52 | } 53 | rb.verbs[verb] 
= struct{}{} 54 | } 55 | 56 | versionedRoutes := map[int][]RouteInfo{} 57 | 58 | for version, modelManager := range modelManagers { 59 | 60 | versionedRoutes[version] = []RouteInfo{} 61 | 62 | routes := map[string]routeBuilder{} 63 | 64 | for identity, relationship := range modelManager.Relationships() { 65 | 66 | // If we don't have a processor registered for the given model, we skip. 67 | if _, err := processorFinder(identity); err != nil { 68 | continue 69 | } 70 | 71 | if len(relationship.Create) > 0 { 72 | addRoute(routes, identity, fmt.Sprintf("/%s", identity.Category), "POST", identity.Private) 73 | } 74 | 75 | if len(relationship.Retrieve) > 0 { 76 | addRoute(routes, identity, fmt.Sprintf("/%s/:id", identity.Category), "GET", identity.Private) 77 | } 78 | 79 | if len(relationship.Delete) > 0 { 80 | addRoute(routes, identity, fmt.Sprintf("/%s/:id", identity.Category), "DELETE", identity.Private) 81 | } 82 | 83 | if len(relationship.Update) > 0 { 84 | addRoute(routes, identity, fmt.Sprintf("/%s/:id", identity.Category), "PUT", identity.Private) 85 | } 86 | 87 | for parent := range relationship.RetrieveMany { 88 | 89 | if parent == "root" { 90 | addRoute(routes, identity, fmt.Sprintf("/%s", identity.Category), "GET", identity.Private) 91 | } else { 92 | addRoute(routes, identity, fmt.Sprintf("/%s/:id/%s", modelManager.IdentityFromName(parent).Category, identity.Category), "GET", identity.Private) 93 | } 94 | } 95 | 96 | for parent := range relationship.Create { 97 | 98 | if parent == "root" { 99 | addRoute(routes, identity, fmt.Sprintf("/%s", identity.Category), "POST", identity.Private) 100 | } else { 101 | addRoute(routes, identity, fmt.Sprintf("/%s/:id/%s", modelManager.IdentityFromName(parent).Category, identity.Category), "POST", identity.Private) 102 | } 103 | } 104 | } 105 | 106 | for url, rb := range routes { 107 | var flatVerbs []string 108 | 109 | for v := range rb.verbs { 110 | flatVerbs = append(flatVerbs, v) 111 | } 112 | sort.Strings(flatVerbs) 
113 | 114 | versionedRoutes[version] = append( 115 | versionedRoutes[version], 116 | RouteInfo{ 117 | URL: url, 118 | Verbs: flatVerbs, 119 | Private: rb.private, 120 | Identity: rb.identity.Category, 121 | }, 122 | ) 123 | } 124 | } 125 | 126 | for _, ri := range versionedRoutes { 127 | sort.Slice(ri, func(i int, j int) bool { 128 | return strings.Compare(ri[i].URL, ri[j].URL) == -1 129 | }) 130 | } 131 | 132 | return versionedRoutes 133 | } 134 | -------------------------------------------------------------------------------- /pubsub_nats_options_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Aporeto Inc. 2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License. 4 | // You may obtain a copy of the License at 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // Unless required by applicable law or agreed to in writing, software 7 | // distributed under the License is distributed on an "AS IS" BASIS, 8 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | // See the License for the specific language governing permissions and 10 | // limitations under the License. 11 | 12 | package bahamut 13 | 14 | import ( 15 | "context" 16 | "crypto/tls" 17 | "testing" 18 | "time" 19 | 20 | nats "github.com/nats-io/nats.go" 21 | // nolint:revive // Allow dot imports for readability in tests 22 | . 
"github.com/smartystreets/goconvey/convey" 23 | ) 24 | 25 | func TestBahamut_NATSOption(t *testing.T) { 26 | 27 | n := &natsPubSub{} 28 | 29 | Convey("Calling NATSOptCredentials should work", t, func() { 30 | NATSOptCredentials("user", "pass")(n) 31 | So(n.username, ShouldEqual, "user") 32 | So(n.password, ShouldEqual, "pass") 33 | }) 34 | 35 | Convey("Calling NATSOptClusterID should work", t, func() { 36 | NATSOptClusterID("cid")(n) 37 | So(n.clusterID, ShouldEqual, "cid") 38 | }) 39 | 40 | Convey("Calling NATSOptClientID should work", t, func() { 41 | NATSOptClientID("cid")(n) 42 | So(n.clientID, ShouldEqual, "cid") 43 | }) 44 | 45 | Convey("Calling NATSOptTLS should work", t, func() { 46 | tlscfg := &tls.Config{} 47 | NATSOptTLS(tlscfg)(n) 48 | So(n.tlsConfig, ShouldEqual, tlscfg) 49 | }) 50 | 51 | Convey("Calling NATSErrorHandler should work", t, func() { 52 | f := func(*nats.Conn, *nats.Subscription, error) {} 53 | NATSErrorHandler(f)(n) 54 | So(n.errorHandleFunc, ShouldEqual, f) 55 | }) 56 | } 57 | 58 | func TestBahamut_PubSubNatsOptionsSubscribe(t *testing.T) { 59 | 60 | c := natsSubscribeConfig{} 61 | 62 | Convey("Calling NATSOptSubscribeQueue should work", t, func() { 63 | NATSOptSubscribeQueue("queueGroup")(&c) 64 | So(c.queueGroup, ShouldEqual, "queueGroup") 65 | }) 66 | 67 | Convey("Calling NATSOptSubscribeReplyTimeout should set the timeout", t, func() { 68 | duration := 15 * time.Second 69 | NATSOptSubscribeReplyTimeout(duration)(&c) 70 | So(c.replyTimeout, ShouldEqual, duration) 71 | }) 72 | 73 | } 74 | 75 | func TestBahamut_PubSubNatsOptionsPublish(t *testing.T) { 76 | 77 | Convey("Setup", t, func() { 78 | 79 | c := natsPublishConfig{} 80 | 81 | Convey("Calling NATSOptPublishRequireAck should work", func() { 82 | NATSOptPublishRequireAck(context.TODO())(&c) 83 | So(c.ctx, ShouldResemble, context.TODO()) 84 | So(c.desiredResponse, ShouldEqual, ResponseModeACK) 85 | }) 86 | 87 | Convey("Calling NATSOptPublishRequireAck should panic if requestMode has 
already been set", func() { 88 | c.desiredResponse = ResponseModePublication 89 | So(func() { 90 | NATSOptPublishRequireAck(context.TODO())(&c) 91 | }, ShouldPanic) 92 | }) 93 | 94 | Convey("Calling NATSOptPublishRequireAck should panic if supplied context is nil", func() { 95 | So(func() { 96 | // nolint - note: ignoring linter feedback as we are trying to cause a panic intentionally by passing in a `nil` context 97 | NATSOptPublishRequireAck(nil)(&c) 98 | }, ShouldPanic) 99 | }) 100 | 101 | Convey("Calling NATSOptRespondToChannel should work", func() { 102 | respCh := make(chan *Publication) 103 | NATSOptRespondToChannel(context.TODO(), respCh)(&c) 104 | So(c.ctx, ShouldResemble, context.TODO()) 105 | So(c.responseCh, ShouldEqual, respCh) 106 | So(c.desiredResponse, ShouldEqual, ResponseModePublication) 107 | }) 108 | 109 | Convey("Calling NATSOptRespondToChannel should panic if requestMode has already been set", func() { 110 | c.desiredResponse = ResponseModeACK 111 | So(func() { 112 | NATSOptRespondToChannel(context.TODO(), make(chan *Publication))(&c) 113 | }, ShouldPanic) 114 | }) 115 | 116 | Convey("Calling NATSOptRespondToChannel should panic if supplied response channel is nil", func() { 117 | So(func() { 118 | NATSOptRespondToChannel(context.TODO(), nil)(&c) 119 | }, ShouldPanic) 120 | }) 121 | 122 | Convey("Calling NATSOptRespondToChannel should panic if supplied context is nil", func() { 123 | c.desiredResponse = ResponseModeACK 124 | So(func() { 125 | // nolint - note: ignoring linter feedback as we are trying to cause a panic intentionally by passing in a `nil` context 126 | NATSOptRespondToChannel(nil, make(chan *Publication))(&c) 127 | }, ShouldPanic) 128 | }) 129 | }) 130 | } 131 | -------------------------------------------------------------------------------- /pubsub_local_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Aporeto Inc. 
2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License. 4 | // You may obtain a copy of the License at 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // Unless required by applicable law or agreed to in writing, software 7 | // distributed under the License is distributed on an "AS IS" BASIS, 8 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | // See the License for the specific language governing permissions and 10 | // limitations under the License. 11 | 12 | package bahamut 13 | 14 | import ( 15 | "context" 16 | "testing" 17 | "time" 18 | 19 | // nolint:revive // Allow dot imports for readability in tests 20 | . "github.com/smartystreets/goconvey/convey" 21 | ) 22 | // TestLocalPubSub_NewPubSubServer checks that newlocalPubSub initializes the subscribers map. 23 | func TestLocalPubSub_NewPubSubServer(t *testing.T) { 24 | 25 | Convey("Given I create a new PubSubServer", t, func() { 26 | 27 | ps := newlocalPubSub() 28 | 29 | Convey("Then the PubSubServer should be correctly initialized", func() { 30 | So(ps.subscribers, ShouldHaveSameTypeAs, map[string][]chan *Publication{}) 31 | }) 32 | }) 33 | } 34 | // TestLocalPubSub_ConnectDisconnect checks that Connect returns no error and that Disconnect can be called. 35 | func TestLocalPubSub_ConnectDisconnect(t *testing.T) { 36 | 37 | Convey("Given I create a new PubSubServer", t, func() { 38 | 39 | ps := newlocalPubSub() 40 | 41 | Convey("When I connect", func() { 42 | 43 | err := ps.Connect(context.Background()) 44 | 45 | Convey("then err should be nil", func() { 46 | So(err, ShouldBeNil) 47 | }) 48 | }) 49 | 50 | Convey("Whan I call Disconnect nothing should happen", func() { 51 | _ = ps.Disconnect() 52 | }) 53 | }) 54 | } 55 | // TestLocalPubSub_RegisterUnregister checks that a channel can be registered on a topic, then unregistered, which closes it. 56 | func TestLocalPubSub_RegisterUnregister(t *testing.T) { 57 | 58 | Convey("Given I create a new PubSubServer", t, func() { 59 | 60 | ps := newlocalPubSub() 61 | if err := ps.Connect(context.Background()); err != nil { 62 | panic(err) 63 | } 64 | defer func() { _ = ps.Disconnect() }() 65 | 66 | Convey("When I register a channel to a topic", func() { 67 | 68 | c :=
make(chan *Publication) 69 | // NOTE(review): the sleeps suggest registration is applied asynchronously; confirm against newlocalPubSub. 70 | ps.registerSubscriberChannel(c, "topic") 71 | time.Sleep(30 * time.Millisecond) 72 | 73 | Convey("Then the channel should be correctly registered", func() { 74 | ps.lock.Lock() 75 | defer ps.lock.Unlock() 76 | So(ps.subscribers["topic"][0], ShouldEqual, c) 77 | }) 78 | 79 | Convey("When I unregister it", func() { 80 | 81 | ps.unregisterSubscriberChannel(c, "topic") 82 | time.Sleep(30 * time.Millisecond) 83 | 84 | Convey("Then the channel should be correctly unregistered", func() { 85 | ps.lock.Lock() 86 | defer ps.lock.Unlock() 87 | So(len(ps.subscribers["topic"]), ShouldEqual, 0) 88 | }) 89 | 90 | Convey("Then the channel should be closed", func() { 91 | _, ok := <-c 92 | So(ok, ShouldBeFalse) 93 | }) 94 | }) 95 | }) 96 | }) 97 | } 98 | // TestLocalPubSub_PublishSubscribe checks that a publication reaches only the channels subscribed to its topic, and that unsubscribing closes them. 99 | func TestLocalPubSub_PublishSubscribe(t *testing.T) { 100 | 101 | Convey("Given I create a new PubSubServer", t, func() { 102 | 103 | ps := newlocalPubSub() 104 | if err := ps.Connect(context.Background()); err != nil { 105 | panic(err) 106 | } 107 | defer func() { _ = ps.Disconnect() }() 108 | 109 | Convey("When I register a 2 channels to a topic 'topic' and a another one to 'nottopic'", func() { 110 | 111 | c1 := make(chan *Publication) 112 | c2 := make(chan *Publication) 113 | c3 := make(chan *Publication) 114 | 115 | u1 := ps.Subscribe(c1, nil, "topic") 116 | u2 := ps.Subscribe(c2, nil, "topic") 117 | u3 := ps.Subscribe(c3, nil, "nottopic") 118 | time.Sleep(30 * time.Millisecond) 119 | 120 | Convey("When Publish somthing", func() { 121 | 122 | publ := NewPublication("topic") 123 | go func() { _ = ps.Publish(publ) }() 124 | 125 | time.Sleep(30 * time.Millisecond) 126 | // Record which channels received something, stopping once none delivers anything for 30ms. 127 | var ok1, ok2, ok3 bool 128 | LOOP: 129 | for { 130 | select { 131 | case <-c1: 132 | ok1 = true 133 | case <-c2: 134 | ok2 = true 135 | case <-c3: 136 | ok3 = true 137 | case <-time.After(30 * time.Millisecond): 138 | break LOOP 139 | } 140 | } 141 | Convey("Then the first two channels should receive the publication",
func() { 142 | So(ok1, ShouldBeTrue) 143 | So(ok2, ShouldBeTrue) 144 | }) 145 | 146 | Convey("Then the third channel should not receive anything", func() { 147 | So(ok3, ShouldBeFalse) 148 | }) 149 | 150 | Convey("When I ubsubscribe", func() { 151 | u1() 152 | u2() 153 | u3() 154 | // Each unsubscribe func is expected to close its channel. 155 | Convey("Then all channels should be closed", func() { 156 | _, ok1 := <-c1 157 | _, ok2 := <-c2 158 | _, ok3 := <-c3 159 | 160 | So(ok1, ShouldBeFalse) 161 | So(ok2, ShouldBeFalse) 162 | So(ok3, ShouldBeFalse) 163 | }) 164 | 165 | }) 166 | }) 167 | }) 168 | }) 169 | } 170 | -------------------------------------------------------------------------------- /authorizer/simple/authenticator_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Aporeto Inc. 2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License. 4 | // You may obtain a copy of the License at 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // Unless required by applicable law or agreed to in writing, software 7 | // distributed under the License is distributed on an "AS IS" BASIS, 8 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | // See the License for the specific language governing permissions and 10 | // limitations under the License. 11 | 12 | package simple 13 | 14 | import ( 15 | "fmt" 16 | "testing" 17 | 18 | // nolint:revive // Allow dot imports for readability in tests 19 | .
"github.com/smartystreets/goconvey/convey" 20 | "go.aporeto.io/bahamut" 21 | ) 22 | 23 | func TestAuththenticator_NewAuthenticator(t *testing.T) { 24 | 25 | Convey("Given I call NewAuthenticator with two funcs", t, func() { 26 | 27 | f1 := func(bahamut.Context) (bahamut.AuthAction, error) { return bahamut.AuthActionOK, nil } 28 | f2 := func(bahamut.Session) (bahamut.AuthAction, error) { return bahamut.AuthActionOK, nil } 29 | 30 | auth := NewAuthenticator(f1, f2) 31 | 32 | Convey("Then it should be correctly initialized", func() { 33 | So(auth.customAuthRequestFunc, ShouldEqual, f1) 34 | So(auth.customAuthSessionFunc, ShouldEqual, f2) 35 | }) 36 | }) 37 | } 38 | 39 | func TestAuththenticator_AuthenticateRequest(t *testing.T) { 40 | 41 | Convey("Given I call NewAuthenticator and a func that says ok", t, func() { 42 | 43 | f1 := func(bahamut.Context) (bahamut.AuthAction, error) { return bahamut.AuthActionOK, nil } 44 | 45 | auth := NewAuthenticator(f1, nil) 46 | 47 | Convey("When I call AuthenticateRequest", func() { 48 | 49 | action, err := auth.AuthenticateRequest(nil) 50 | 51 | Convey("Then err should be nil", func() { 52 | So(err, ShouldBeNil) 53 | }) 54 | 55 | Convey("Then action should be OK", func() { 56 | So(action, ShouldEqual, bahamut.AuthActionOK) 57 | }) 58 | }) 59 | }) 60 | 61 | Convey("Given I call NewAuthenticator and no func", t, func() { 62 | 63 | auth := NewAuthenticator(nil, nil) 64 | 65 | Convey("When I call AuthenticateRequest", func() { 66 | 67 | action, err := auth.AuthenticateRequest(nil) 68 | 69 | Convey("Then err should be nil", func() { 70 | So(err, ShouldBeNil) 71 | }) 72 | 73 | Convey("Then action should be Continue", func() { 74 | So(action, ShouldEqual, bahamut.AuthActionContinue) 75 | }) 76 | }) 77 | }) 78 | 79 | Convey("Given I call NewAuthenticator and a func that returns an error", t, func() { 80 | 81 | f1 := func(bahamut.Context) (bahamut.AuthAction, error) { return bahamut.AuthActionOK, fmt.Errorf("paf") } 82 | 83 | auth := 
NewAuthenticator(f1, nil) 84 | 85 | Convey("When I call AuthenticateRequest", func() { 86 | 87 | action, err := auth.AuthenticateRequest(nil) 88 | 89 | Convey("Then err should not be nil", func() { 90 | So(err.Error(), ShouldEqual, "paf") 91 | }) 92 | 93 | Convey("Then action should be KO", func() { 94 | So(action, ShouldEqual, bahamut.AuthActionKO) 95 | }) 96 | }) 97 | }) 98 | } 99 | 100 | func TestAuththenticator_AuthenticateSession(t *testing.T) { 101 | 102 | Convey("Given I call NewAuthenticator and a func that says ok", t, func() { 103 | 104 | f1 := func(bahamut.Session) (bahamut.AuthAction, error) { return bahamut.AuthActionOK, nil } 105 | 106 | auth := NewAuthenticator(nil, f1) 107 | 108 | Convey("When I call AuthenticateSession", func() { 109 | 110 | action, err := auth.AuthenticateSession(nil) 111 | 112 | Convey("Then err should be nil", func() { 113 | So(err, ShouldBeNil) 114 | }) 115 | 116 | Convey("Then action should be OK", func() { 117 | So(action, ShouldEqual, bahamut.AuthActionOK) 118 | }) 119 | }) 120 | }) 121 | 122 | Convey("Given I call NewAuthenticator and no func", t, func() { 123 | 124 | auth := NewAuthenticator(nil, nil) 125 | 126 | Convey("When I call AuthenticateSession", func() { 127 | 128 | action, err := auth.AuthenticateSession(nil) 129 | 130 | Convey("Then err should be nil", func() { 131 | So(err, ShouldBeNil) 132 | }) 133 | 134 | Convey("Then action should be Continue", func() { 135 | So(action, ShouldEqual, bahamut.AuthActionContinue) 136 | }) 137 | }) 138 | }) 139 | 140 | Convey("Given I call NewAuthenticator and a func that returns an error", t, func() { 141 | 142 | f1 := func(bahamut.Session) (bahamut.AuthAction, error) { return bahamut.AuthActionOK, fmt.Errorf("paf") } 143 | 144 | auth := NewAuthenticator(nil, f1) 145 | 146 | Convey("When I call AuthenticateSession", func() { 147 | 148 | action, err := auth.AuthenticateSession(nil) 149 | 150 | Convey("Then err should not be nil", func() { 151 | So(err.Error(), ShouldEqual, 
"paf") 152 | }) 153 | 154 | Convey("Then action should be KO", func() { 155 | So(action, ShouldEqual, bahamut.AuthActionKO) 156 | }) 157 | }) 158 | }) 159 | } 160 | -------------------------------------------------------------------------------- /config.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Aporeto Inc. 2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License. 4 | // You may obtain a copy of the License at 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // Unless required by applicable law or agreed to in writing, software 7 | // distributed under the License is distributed on an "AS IS" BASIS, 8 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | // See the License for the specific language governing permissions and 10 | // limitations under the License. 11 | 12 | package bahamut 13 | 14 | import ( 15 | "crypto/tls" 16 | "crypto/x509" 17 | "log" 18 | "net" 19 | "net/http" 20 | "time" 21 | 22 | opentracing "github.com/opentracing/opentracing-go" 23 | "go.aporeto.io/elemental" 24 | "golang.org/x/time/rate" 25 | ) 26 | 27 | // HealthServerFunc is the type used by the Health Server to check the health of the server. 28 | type HealthServerFunc func() error 29 | 30 | // HealthStatFunc is the type used by the Health Server to return additional custom health info. 31 | type HealthStatFunc func(http.ResponseWriter, *http.Request) 32 | 33 | // TraceCleaner is the type of function that can be used to clean a trace data 34 | // before it is sent to OpenTracing server. You can use this to strip passwords 35 | // or other sensitive data. 36 | type TraceCleaner func(elemental.Identity, []byte) []byte 37 | 38 | // An IdentifiableRetriever is the type of function you can use to perform transparent 39 | // patch support using elemental.SparseIdentifiable. 
40 | // If this is set in the configuration, the handler for PATCH method will use 41 | // this function to retrieve the target identifiable, will apply the patch and 42 | // treat the request as a standard update. 43 | type IdentifiableRetriever func(*elemental.Request) (elemental.Identifiable, error) 44 | // apiRateLimit pairs a rate limiter with a request predicate; presumably the limiter only applies to requests for which condition returns true (confirm at the call site). 45 | type apiRateLimit struct { 46 | limiter *rate.Limiter 47 | condition func(*elemental.Request) bool 48 | } 49 | 50 | // A config represents the configuration of Bahamut. 51 | type config struct { 52 | opentracing struct { 53 | tracer opentracing.Tracer 54 | excludedIdentities map[string]struct{} 55 | traceCleaner TraceCleaner 56 | } 57 | hooks struct { 58 | postStart func(Server) error 59 | preStop func(Server) error 60 | errorTransformer func(error) error 61 | } 62 | rateLimiting struct { 63 | rateLimiter *rate.Limiter 64 | apiRateLimiters map[elemental.Identity]apiRateLimit 65 | } 66 | security struct { 67 | auditer Auditer 68 | corsController CORSPolicyController 69 | requestAuthenticators []RequestAuthenticator 70 | sessionAuthenticators []SessionAuthenticator 71 | authorizers []Authorizer 72 | } 73 | pushServer struct { 74 | service PubSubClient 75 | dispatchHandler PushDispatchHandler 76 | publishHandler PushPublishHandler 77 | topic string 78 | endpoint string 79 | enabled bool 80 | subjectHierarchiesEnabled bool 81 | publishEnabled bool 82 | dispatchEnabled bool 83 | } 84 | meta struct { 85 | version map[string]any 86 | serviceName string 87 | serviceVersion string 88 | disableMetaRoute bool 89 | } 90 | profilingServer struct { 91 | listenAddress string 92 | enabled bool 93 | } 94 | model struct { 95 | modelManagers map[int]elemental.ModelManager 96 | unmarshallers map[elemental.Identity]CustomUmarshaller 97 | marshallers map[elemental.Identity]CustomMarshaller 98 | retriever IdentifiableRetriever 99 | readOnlyExcludedIdentities []elemental.Identity 100 | readOnly bool 101 | } 102 | tls struct { 103 | clientCAPool *x509.CertPool 104 |
serverCertificatesRetrieverFunc func(*tls.ClientHelloInfo) (*tls.Certificate, error) 105 | peerCertificateVerifyFunc func([][]byte, [][]*x509.Certificate) error 106 | serverCertificates []tls.Certificate 107 | nextProtos []string 108 | authType tls.ClientAuthType 109 | disableSessionTicket bool 110 | } 111 | healthServer struct { 112 | metricsManager MetricsManager 113 | healthHandler HealthServerFunc 114 | customStats map[string]HealthStatFunc 115 | listenAddress string 116 | readTimeout time.Duration 117 | writeTimeout time.Duration 118 | idleTimeout time.Duration 119 | enabled bool 120 | } 121 | restServer struct { 122 | customListener net.Listener 123 | customRootHandlerFunc http.HandlerFunc 124 | httpLogger *log.Logger 125 | apiPrefix string 126 | customRoutePrefix string 127 | listenAddress string 128 | maxConnection int 129 | idleTimeout time.Duration 130 | writeTimeout time.Duration 131 | readTimeout time.Duration 132 | enabled bool 133 | disableKeepalive bool 134 | disableCompression bool 135 | } 136 | general struct{ panicRecoveryDisabled bool } 137 | } 138 | -------------------------------------------------------------------------------- /opentracing.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Aporeto Inc. 2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License. 4 | // You may obtain a copy of the License at 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // Unless required by applicable law or agreed to in writing, software 7 | // distributed under the License is distributed on an "AS IS" BASIS, 8 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | // See the License for the specific language governing permissions and 10 | // limitations under the License.
11 | 12 | package bahamut 13 | 14 | import ( 15 | "context" 16 | "encoding/base64" 17 | "fmt" 18 | "net/http" 19 | "net/url" 20 | "strings" 21 | 22 | opentracing "github.com/opentracing/opentracing-go" 23 | "github.com/opentracing/opentracing-go/ext" 24 | "github.com/opentracing/opentracing-go/log" 25 | "go.aporeto.io/elemental" 26 | ) 27 | 28 | var snipSlice = []string{"[snip]"} 29 | // extractClaims returns the decoded JWT payload carried in r.Password for tracing, or "{}" when r.Password is empty. Diagnostics never echo r.Password: it is a credential and the result is logged into the span. 30 | func extractClaims(r *elemental.Request) string { 31 | 32 | if r.Password == "" { 33 | return "{}" 34 | } 35 | 36 | tokenParts := strings.SplitN(r.Password, ".", 3) 37 | if len(tokenParts) != 3 { 38 | return "invalid token format" 39 | } 40 | // JWT segments are unpadded base64url (RFC 7515), not standard base64. 41 | identity, err := base64.RawURLEncoding.DecodeString(tokenParts[1]) 42 | if err != nil { 43 | return fmt.Sprintf("invalid token encoding: %s", err) 44 | } 45 | 46 | return string(identity) 47 | } 48 | // tracingName returns the span operation name for r, built from its operation and identity category. 49 | func tracingName(r *elemental.Request) string { 50 | 51 | switch r.Operation { 52 | 53 | case elemental.OperationCreate: 54 | return fmt.Sprintf("bahamut.handle.create.%s", r.Identity.Category) 55 | 56 | case elemental.OperationRetrieveMany: 57 | return fmt.Sprintf("bahamut.handle.retrieve_many.%s", r.Identity.Category) 58 | 59 | case elemental.OperationInfo: 60 | return fmt.Sprintf("bahamut.handle.info.%s", r.Identity.Category) 61 | 62 | case elemental.OperationUpdate: 63 | return fmt.Sprintf("bahamut.handle.update.%s", r.Identity.Category) 64 | 65 | case elemental.OperationDelete: 66 | return fmt.Sprintf("bahamut.handle.delete.%s", r.Identity.Category) 67 | 68 | case elemental.OperationRetrieve: 69 | return fmt.Sprintf("bahamut.handle.retrieve.%s", r.Identity.Category) 70 | 71 | case elemental.OperationPatch: 72 | return fmt.Sprintf("bahamut.handle.patch.%s", r.Identity.Category) 73 | } 74 | 75 | return fmt.Sprintf("Unknown operation: %s", r.Operation) 76 | } 77 | 78 | // traceRequest starts a server span for r on tracer and returns a context derived from ctx that carries it. It returns ctx unchanged when tracer is nil or r's identity is excluded.
79 | func traceRequest(ctx context.Context, r *elemental.Request, tracer opentracing.Tracer, exludedIdentities map[string]struct{}, cleaner TraceCleaner) context.Context { 80 | 81 | if tracer == nil { 82 | return ctx 83 | } 84 | // Identities listed in exludedIdentities are never traced. 85 | if _, ok := exludedIdentities[r.Identity.Name]; ok { 86 | return ctx 87 | } 88 | // The Extract error is deliberately ignored so a span is always started. NOTE(review): on failure spanContext is presumably nil, which ext.RPCServerOption treats as "no parent"; confirm. 89 | spanContext, _ := tracer.Extract(opentracing.TextMap, opentracing.HTTPHeadersCarrier(r.Headers)) 90 | span := tracer.StartSpan(tracingName(r), ext.RPCServerOption(spanContext)) 91 | trackingCtx := opentracing.ContextWithSpan(ctx, span) 92 | 93 | // Remove sensitive information from parameters. 94 | safeParameters := url.Values{} 95 | for k, p := range r.Parameters { 96 | lk := strings.ToLower(k) 97 | if lk == "token" || lk == "password" { 98 | safeParameters[k] = snipSlice 99 | continue 100 | } 101 | safeParameters[k] = []string{fmt.Sprintf("%v", p.Values())} 102 | } 103 | 104 | // Remove sensitive information from headers. 105 | safeHeaders := http.Header{} 106 | for k, v := range r.Headers { 107 | lk := strings.ToLower(k) 108 | if lk == "authorization" || lk == "cookie" { 109 | safeHeaders[k] = snipSlice 110 | continue 111 | } 112 | safeHeaders[k] = v 113 | } 114 | // Request metadata is attached as span tags; the optional fields below are tagged only when set. 115 | span.SetTag("req.api_version", r.Version) 116 | span.SetTag("req.id", r.RequestID) 117 | span.SetTag("req.identity", r.Identity.Name) 118 | span.SetTag("req.recursive", r.Recursive) 119 | span.SetTag("req.operation", r.Operation) 120 | span.SetTag("req.override_protection", r.OverrideProtection) 121 | 122 | if r.ExternalTrackingID != "" { 123 | span.SetTag("req.external_tracking_id", r.ExternalTrackingID) 124 | } 125 | 126 | if r.ExternalTrackingType != "" { 127 | span.SetTag("req.external_tracking_type", r.ExternalTrackingType) 128 | } 129 | 130 | if r.Namespace != "" { 131 | span.SetTag("req.namespace", r.Namespace) 132 | } 133 | 134 | if r.ObjectID != "" { 135 | span.SetTag("req.object.id", r.ObjectID) 136 | } 137 | 138 | if r.ParentID != "" { 139 | span.SetTag("req.parent.id",
r.ParentID) 140 | } 141 | 142 | if !r.ParentIdentity.IsEmpty() { 143 | span.SetTag("req.parent.identity", r.ParentIdentity.Name) 144 | } 145 | // Copy r.Data so the cleaner cannot modify the request payload in place. 146 | data := append([]byte{}, r.Data...) 147 | if cleaner != nil { 148 | data = cleaner(r.Identity, data) 149 | } 150 | 151 | span.LogFields( 152 | log.Int("req.page.number", r.Page), 153 | log.Int("req.page.size", r.PageSize), 154 | log.Object("req.headers", safeHeaders), 155 | log.Object("req.claims", extractClaims(r)), 156 | log.Object("req.client_ip", r.ClientIP), 157 | log.Object("req.parameters", safeParameters), 158 | log.Object("req.order_by", r.Order), 159 | log.String("req.payload", string(data)), 160 | ) 161 | 162 | return trackingCtx 163 | } 164 | // finishTracing finishes the span stored in ctx, if there is one. 165 | func finishTracing(ctx context.Context) { 166 | 167 | span := opentracing.SpanFromContext(ctx) 168 | if span == nil { 169 | return 170 | } 171 | 172 | span.Finish() 173 | } 174 | -------------------------------------------------------------------------------- /gateway/upstreamer/push/upstreamer_options.go: -------------------------------------------------------------------------------- 1 | package push 2 | 3 | import ( 4 | "time" 5 | 6 | "golang.org/x/time/rate" 7 | ) 8 | 9 | // An UpstreamerOption represents a configuration option 10 | // for the Upstreamer.
11 | type UpstreamerOption func(*upstreamConfig) 12 | // upstreamConfig holds the Upstreamer settings that UpstreamerOptions modify. 13 | type upstreamConfig struct { 14 | randomizer Randomizer 15 | eventsAPIs map[string]string 16 | overrideEndpointAddress string 17 | globalServiceTopic string 18 | requiredServices []string 19 | serviceTimeoutCheckInterval time.Duration 20 | serviceTimeout time.Duration 21 | peerTimeout time.Duration 22 | peerTimeoutCheckInterval time.Duration 23 | peerPingInterval time.Duration 24 | latencySampleSize int 25 | tokenLimitingBurst int 26 | tokenLimitingRPS rate.Limit 27 | exposePrivateAPIs bool 28 | } 29 | // newUpstreamConfig returns an upstreamConfig populated with the default settings. 30 | func newUpstreamConfig() upstreamConfig { 31 | return upstreamConfig{ 32 | eventsAPIs: map[string]string{}, 33 | latencySampleSize: 20, 34 | serviceTimeout: 30 * time.Second, 35 | serviceTimeoutCheckInterval: 5 * time.Second, 36 | peerTimeout: 30 * time.Second, 37 | peerTimeoutCheckInterval: 5 * time.Second, 38 | peerPingInterval: 10 * time.Second, 39 | randomizer: newRandomizer(), 40 | tokenLimitingBurst: 2000, 41 | tokenLimitingRPS: 500, 42 | } 43 | } 44 | 45 | // OptionUpstreamerExposePrivateAPIs configures the Upstreamer to expose 46 | // the private APIs. 47 | func OptionUpstreamerExposePrivateAPIs(enabled bool) UpstreamerOption { 48 | return func(cfg *upstreamConfig) { 49 | cfg.exposePrivateAPIs = enabled 50 | } 51 | } 52 | 53 | // OptionUpstreamerOverrideEndpointsAddresses configures the Upstreamer 54 | // to always ignore what IP address the services are reporting 55 | // and always use the provided address. 56 | func OptionUpstreamerOverrideEndpointsAddresses(override string) UpstreamerOption { 57 | return func(cfg *upstreamConfig) { 58 | cfg.overrideEndpointAddress = override 59 | } 60 | } 61 | 62 | // OptionUpstreamerRegisterEventAPI registers an event API for the given serviceName 63 | // on the given endpoint.
64 | // For instance, if serviceA exposes an event API on /events, you can use 65 | // OptionUpstreamerRegisterEventAPI("serviceA", "events") 66 | func OptionUpstreamerRegisterEventAPI(serviceName string, eventEndpoint string) UpstreamerOption { 67 | return func(cfg *upstreamConfig) { 68 | cfg.eventsAPIs[serviceName] = eventEndpoint 69 | } 70 | } 71 | 72 | // OptionRequiredServices sets the list of services 73 | // that must be ready before starting the upstreamer. The slice is stored as-is, without being copied. 74 | func OptionRequiredServices(required []string) UpstreamerOption { 75 | return func(cfg *upstreamConfig) { 76 | cfg.requiredServices = required 77 | } 78 | } 79 | 80 | // OptionUpstreamerServiceTimeout sets the time to wait for the upstream 81 | // to consider a service that did not ping to be outdated and removed 82 | // in the case no goodbye was sent. Default is 30s. 83 | // The checkInterval parameter defines how often the upstream 84 | // will check for outdated services. The default is 5s. 85 | func OptionUpstreamerServiceTimeout(timeout time.Duration, checkInterval time.Duration) UpstreamerOption { 86 | return func(cfg *upstreamConfig) { 87 | cfg.serviceTimeout = timeout 88 | cfg.serviceTimeoutCheckInterval = checkInterval 89 | } 90 | } 91 | 92 | // OptionUpstreamerRandomizer sets a custom Randomizer 93 | // that must implement the Randomizer interface 94 | // and be safe for concurrent use by multiple goroutines. 95 | func OptionUpstreamerRandomizer(randomizer Randomizer) UpstreamerOption { 96 | return func(cfg *upstreamConfig) { 97 | cfg.randomizer = randomizer 98 | } 99 | } 100 | 101 | // OptionUpstreamerPeersTimeout sets for how long a peer ping 102 | // should stay valid after receiving it. 103 | // The default is 30s.
104 | func OptionUpstreamerPeersTimeout(t time.Duration) UpstreamerOption { 105 | return func(cfg *upstreamConfig) { 106 | cfg.peerTimeout = t 107 | } 108 | } 109 | 110 | // OptionUpstreamerPeersCheckInterval sets the frequency at which the upstreamer 111 | // will check for outdated peers. 112 | // The default is 5s. 113 | func OptionUpstreamerPeersCheckInterval(t time.Duration) UpstreamerOption { 114 | return func(cfg *upstreamConfig) { 115 | cfg.peerTimeoutCheckInterval = t 116 | } 117 | } 118 | 119 | // OptionUpstreamerPeersPingInterval sets how often the upstreamer will 120 | // ping its peers. 121 | // The default is 10s. 122 | func OptionUpstreamerPeersPingInterval(t time.Duration) UpstreamerOption { 123 | return func(cfg *upstreamConfig) { 124 | cfg.peerPingInterval = t 125 | } 126 | } 127 | 128 | // OptionUpstreamerTokenRateLimiting configures the per source rate limiting. 129 | // The default is rps:500/burst:2000. The returned option panics when applied if rps or burst is <= 0. 130 | func OptionUpstreamerTokenRateLimiting(rps rate.Limit, burst int) UpstreamerOption { 131 | return func(cfg *upstreamConfig) { 132 | cfg.tokenLimitingRPS = rps 133 | cfg.tokenLimitingBurst = burst 134 | if cfg.tokenLimitingRPS <= 0 { 135 | panic("rps cannot be <= 0") 136 | } 137 | if cfg.tokenLimitingBurst <= 0 { 138 | panic("burst cannot be <= 0") 139 | } 140 | } 141 | } 142 | 143 | // OptionUpstreamerGlobalServiceTopic sets the global topic that the gateway 144 | // will use to listen for service pings coming from global services. 145 | func OptionUpstreamerGlobalServiceTopic(topic string) UpstreamerOption { 146 | return func(cfg *upstreamConfig) { 147 | cfg.globalServiceTopic = topic 148 | } 149 | } 150 | -------------------------------------------------------------------------------- /pubsub_nats_mocks_test.go: -------------------------------------------------------------------------------- 1 | // Code generated by MockGen. DO NOT EDIT. 2 | // Source: pubsub_nats.go 3 | 4 | // Package bahamut is a generated GoMock package.
5 | package bahamut 6 | 7 | import ( 8 | context "context" 9 | reflect "reflect" 10 | 11 | gomock "github.com/golang/mock/gomock" 12 | nats "github.com/nats-io/nats.go" 13 | ) 14 | 15 | // MockNATSClient is a mock of natsClient interface. 16 | type MockNATSClient struct { 17 | ctrl *gomock.Controller 18 | recorder *MockNATSClientMockRecorder 19 | } 20 | 21 | // MockNATSClientMockRecorder is the mock recorder for MockNATSClient. 22 | type MockNATSClientMockRecorder struct { 23 | mock *MockNATSClient 24 | } 25 | 26 | // NewMockNATSClient creates a new mock instance. 27 | func NewMockNATSClient(ctrl *gomock.Controller) *MockNATSClient { 28 | mock := &MockNATSClient{ctrl: ctrl} 29 | mock.recorder = &MockNATSClientMockRecorder{mock} 30 | return mock 31 | } 32 | 33 | // EXPECT returns an object that allows the caller to indicate expected use. 34 | func (m *MockNATSClient) EXPECT() *MockNATSClientMockRecorder { 35 | return m.recorder 36 | } 37 | 38 | // Close mocks base method. 39 | func (m *MockNATSClient) Close() { 40 | m.ctrl.T.Helper() 41 | m.ctrl.Call(m, "Close") 42 | } 43 | 44 | // Close indicates an expected call of Close. 45 | func (mr *MockNATSClientMockRecorder) Close() *gomock.Call { 46 | mr.mock.ctrl.T.Helper() 47 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockNATSClient)(nil).Close)) 48 | } 49 | 50 | // Flush mocks base method. 51 | func (m *MockNATSClient) Flush() error { 52 | m.ctrl.T.Helper() 53 | ret := m.ctrl.Call(m, "Flush") 54 | ret0, _ := ret[0].(error) 55 | return ret0 56 | } 57 | 58 | // Flush indicates an expected call of Flush. 59 | func (mr *MockNATSClientMockRecorder) Flush() *gomock.Call { 60 | mr.mock.ctrl.T.Helper() 61 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Flush", reflect.TypeOf((*MockNATSClient)(nil).Flush)) 62 | } 63 | 64 | // IsConnected mocks base method. 
65 | func (m *MockNATSClient) IsConnected() bool { 66 | m.ctrl.T.Helper() 67 | ret := m.ctrl.Call(m, "IsConnected") 68 | ret0, _ := ret[0].(bool) 69 | return ret0 70 | } 71 | 72 | // IsConnected indicates an expected call of IsConnected. 73 | func (mr *MockNATSClientMockRecorder) IsConnected() *gomock.Call { 74 | mr.mock.ctrl.T.Helper() 75 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsConnected", reflect.TypeOf((*MockNATSClient)(nil).IsConnected)) 76 | } 77 | 78 | // IsReconnecting mocks base method. 79 | func (m *MockNATSClient) IsReconnecting() bool { 80 | m.ctrl.T.Helper() 81 | ret := m.ctrl.Call(m, "IsReconnecting") 82 | ret0, _ := ret[0].(bool) 83 | return ret0 84 | } 85 | 86 | // IsReconnecting indicates an expected call of IsReconnecting. 87 | func (mr *MockNATSClientMockRecorder) IsReconnecting() *gomock.Call { 88 | mr.mock.ctrl.T.Helper() 89 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsReconnecting", reflect.TypeOf((*MockNATSClient)(nil).IsReconnecting)) 90 | } 91 | 92 | // Publish mocks base method. 93 | func (m *MockNATSClient) Publish(subj string, data []byte) error { 94 | m.ctrl.T.Helper() 95 | ret := m.ctrl.Call(m, "Publish", subj, data) 96 | ret0, _ := ret[0].(error) 97 | return ret0 98 | } 99 | 100 | // Publish indicates an expected call of Publish. 101 | func (mr *MockNATSClientMockRecorder) Publish(subj, data any) *gomock.Call { 102 | mr.mock.ctrl.T.Helper() 103 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Publish", reflect.TypeOf((*MockNATSClient)(nil).Publish), subj, data) 104 | } 105 | 106 | // QueueSubscribe mocks base method. 
107 | func (m *MockNATSClient) QueueSubscribe(subj, queue string, cb nats.MsgHandler) (*nats.Subscription, error) { 108 | m.ctrl.T.Helper() 109 | ret := m.ctrl.Call(m, "QueueSubscribe", subj, queue, cb) 110 | ret0, _ := ret[0].(*nats.Subscription) 111 | ret1, _ := ret[1].(error) 112 | return ret0, ret1 113 | } 114 | 115 | // QueueSubscribe indicates an expected call of QueueSubscribe. 116 | func (mr *MockNATSClientMockRecorder) QueueSubscribe(subj, queue, cb any) *gomock.Call { 117 | mr.mock.ctrl.T.Helper() 118 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueSubscribe", reflect.TypeOf((*MockNATSClient)(nil).QueueSubscribe), subj, queue, cb) 119 | } 120 | 121 | // RequestWithContext mocks base method. 122 | func (m *MockNATSClient) RequestWithContext(ctx context.Context, subj string, data []byte) (*nats.Msg, error) { 123 | m.ctrl.T.Helper() 124 | ret := m.ctrl.Call(m, "RequestWithContext", ctx, subj, data) 125 | ret0, _ := ret[0].(*nats.Msg) 126 | ret1, _ := ret[1].(error) 127 | return ret0, ret1 128 | } 129 | 130 | // RequestWithContext indicates an expected call of RequestWithContext. 131 | func (mr *MockNATSClientMockRecorder) RequestWithContext(ctx, subj, data any) *gomock.Call { 132 | mr.mock.ctrl.T.Helper() 133 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestWithContext", reflect.TypeOf((*MockNATSClient)(nil).RequestWithContext), ctx, subj, data) 134 | } 135 | 136 | // Subscribe mocks base method. 137 | func (m *MockNATSClient) Subscribe(subj string, cb nats.MsgHandler) (*nats.Subscription, error) { 138 | m.ctrl.T.Helper() 139 | ret := m.ctrl.Call(m, "Subscribe", subj, cb) 140 | ret0, _ := ret[0].(*nats.Subscription) 141 | ret1, _ := ret[1].(error) 142 | return ret0, ret1 143 | } 144 | 145 | // Subscribe indicates an expected call of Subscribe. 
func (mr *MockNATSClientMockRecorder) Subscribe(subj, cb any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	// Records the expectation together with the mocked method's type so the
	// controller can check it against the real Subscribe signature.
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Subscribe", reflect.TypeOf((*MockNATSClient)(nil).Subscribe), subj, cb)
}
--------------------------------------------------------------------------------
/meta_test.go:
--------------------------------------------------------------------------------
// Copyright 2019 Aporeto Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//     http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
11 | 12 | package bahamut 13 | 14 | import ( 15 | "fmt" 16 | "reflect" 17 | "testing" 18 | 19 | "go.aporeto.io/elemental" 20 | testmodel "go.aporeto.io/elemental/test/model" 21 | ) 22 | 23 | func Test_buildVersionedRoutes(t *testing.T) { 24 | type args struct { 25 | modelManagers map[int]elemental.ModelManager 26 | processorFinder processorFinderFunc 27 | } 28 | tests := []struct { 29 | args args 30 | want map[int][]RouteInfo 31 | name string 32 | }{ 33 | { 34 | name: "simple", 35 | args: args{ 36 | map[int]elemental.ModelManager{0: testmodel.Manager(), 1: testmodel.Manager()}, 37 | func(identity elemental.Identity) (Processor, error) { 38 | return mockProcessor{}, nil 39 | }, 40 | }, 41 | want: map[int][]RouteInfo{ 42 | 0: { 43 | { 44 | URL: "/lists", 45 | Identity: "lists", 46 | Verbs: []string{ 47 | "GET", 48 | "POST", 49 | }, 50 | }, 51 | { 52 | URL: "/lists/:id", 53 | Identity: "lists", 54 | Verbs: []string{ 55 | "DELETE", 56 | "GET", 57 | "PUT", 58 | }, 59 | }, 60 | { 61 | URL: "/lists/:id/tasks", 62 | Identity: "tasks", 63 | Verbs: []string{ 64 | "GET", 65 | "POST", 66 | }, 67 | }, 68 | { 69 | URL: "/lists/:id/users", 70 | Identity: "users", 71 | Verbs: []string{ 72 | "GET", 73 | }, 74 | }, 75 | { 76 | URL: "/tasks", 77 | Identity: "tasks", 78 | Verbs: []string{ 79 | "POST", 80 | }, 81 | }, 82 | { 83 | URL: "/tasks/:id", 84 | Identity: "tasks", 85 | Verbs: []string{ 86 | "DELETE", 87 | "GET", 88 | "PUT", 89 | }, 90 | }, 91 | { 92 | URL: "/users", 93 | Identity: "users", 94 | Verbs: []string{ 95 | "GET", 96 | "POST", 97 | }, 98 | }, 99 | { 100 | URL: "/users/:id", 101 | Identity: "users", 102 | Verbs: []string{ 103 | "DELETE", 104 | "GET", 105 | "PUT", 106 | }, 107 | }, 108 | }, 109 | 1: { 110 | { 111 | URL: "/lists", 112 | Identity: "lists", 113 | Verbs: []string{ 114 | "GET", 115 | "POST", 116 | }, 117 | }, 118 | { 119 | URL: "/lists/:id", 120 | Identity: "lists", 121 | Verbs: []string{ 122 | "DELETE", 123 | "GET", 124 | "PUT", 125 | }, 126 | }, 127 | { 128 
| URL: "/lists/:id/tasks", 129 | Identity: "tasks", 130 | Verbs: []string{ 131 | "GET", 132 | "POST", 133 | }, 134 | }, 135 | { 136 | URL: "/lists/:id/users", 137 | Identity: "users", 138 | Verbs: []string{ 139 | "GET", 140 | }, 141 | }, 142 | { 143 | URL: "/tasks", 144 | Identity: "tasks", 145 | Verbs: []string{ 146 | "POST", 147 | }, 148 | }, 149 | { 150 | URL: "/tasks/:id", 151 | Identity: "tasks", 152 | Verbs: []string{ 153 | "DELETE", 154 | "GET", 155 | "PUT", 156 | }, 157 | }, 158 | { 159 | URL: "/users", 160 | Identity: "users", 161 | Verbs: []string{ 162 | "GET", 163 | "POST", 164 | }, 165 | }, 166 | { 167 | URL: "/users/:id", 168 | Identity: "users", 169 | Verbs: []string{ 170 | "DELETE", 171 | "GET", 172 | "PUT", 173 | }, 174 | }, 175 | }, 176 | }, 177 | }, 178 | { 179 | name: "error retrieving processor", 180 | args: args{ 181 | map[int]elemental.ModelManager{0: testmodel.Manager(), 1: testmodel.Manager()}, 182 | func(identity elemental.Identity) (Processor, error) { 183 | return nil, fmt.Errorf("boom") 184 | }, 185 | }, 186 | want: map[int][]RouteInfo{0: {}, 1: {}}, 187 | }, 188 | } 189 | 190 | for _, tt := range tests { 191 | t.Run(tt.name, func(t *testing.T) { 192 | if got := buildVersionedRoutes(tt.args.modelManagers, tt.args.processorFinder); !reflect.DeepEqual(got, tt.want) { 193 | t.Errorf("buildVersionedRoutes() = %v, want %v", got, tt.want) 194 | } 195 | }) 196 | } 197 | } 198 | 199 | func TestRouteInfo_String(t *testing.T) { 200 | type fields struct { 201 | URL string 202 | Verbs []string 203 | Private bool 204 | } 205 | tests := []struct { 206 | name string 207 | want string 208 | fields fields 209 | }{ 210 | { 211 | name: "simple", 212 | fields: fields{ 213 | URL: "http.com", 214 | Verbs: []string{"POST", "GET"}, 215 | }, 216 | want: "http.com -> POST, GET", 217 | }, 218 | } 219 | for _, tt := range tests { 220 | t.Run(tt.name, func(t *testing.T) { 221 | r := RouteInfo{ 222 | URL: tt.fields.URL, 223 | Verbs: tt.fields.Verbs, 224 | Private: 
tt.fields.Private, 225 | } 226 | if got := r.String(); got != tt.want { 227 | t.Errorf("RouteInfo.String() = %v, want %v", got, tt.want) 228 | } 229 | }) 230 | } 231 | } 232 | -------------------------------------------------------------------------------- /pubsub_nats_options.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Aporeto Inc. 2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License. 4 | // You may obtain a copy of the License at 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // Unless required by applicable law or agreed to in writing, software 7 | // distributed under the License is distributed on an "AS IS" BASIS, 8 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | // See the License for the specific language governing permissions and 10 | // limitations under the License. 11 | 12 | package bahamut 13 | 14 | import ( 15 | "context" 16 | "crypto/tls" 17 | "fmt" 18 | "time" 19 | 20 | nats "github.com/nats-io/nats.go" 21 | ) 22 | 23 | // A NATSOption represents an option to the pubsub backed by nats 24 | type NATSOption func(*natsPubSub) 25 | 26 | // NATSOptConnectRetryInterval sets the connection retry interval 27 | func NATSOptConnectRetryInterval(interval time.Duration) NATSOption { 28 | return func(n *natsPubSub) { 29 | n.retryInterval = interval 30 | } 31 | } 32 | 33 | // NATSOptCredentials sets the username and password to use to connect to nats. 34 | func NATSOptCredentials(username string, password string) NATSOption { 35 | return func(n *natsPubSub) { 36 | n.username = username 37 | n.password = password 38 | } 39 | } 40 | 41 | // NATSOptClusterID sets the clusterID to use to connect to nats. 
42 | func NATSOptClusterID(clusterID string) NATSOption { 43 | return func(n *natsPubSub) { 44 | n.clusterID = clusterID 45 | } 46 | } 47 | 48 | // NATSOptClientID sets the client ID to use to connect to nats. 49 | func NATSOptClientID(clientID string) NATSOption { 50 | return func(n *natsPubSub) { 51 | n.clientID = clientID 52 | } 53 | } 54 | 55 | // NATSOptTLS sets the tls config to use to connect nats. 56 | func NATSOptTLS(tlsConfig *tls.Config) NATSOption { 57 | return func(n *natsPubSub) { 58 | n.tlsConfig = tlsConfig 59 | } 60 | } 61 | 62 | // NATSErrorHandler sets the error handler to install in nats client. 63 | func NATSErrorHandler(handler func(*nats.Conn, *nats.Subscription, error)) NATSOption { 64 | return func(n *natsPubSub) { 65 | n.errorHandleFunc = handler 66 | } 67 | } 68 | 69 | // natsOptClient sets the NATS client that will be used 70 | // This is useful for unit testing as you can pass in a mocked NATS client 71 | func natsOptClient(client natsClient) NATSOption { 72 | return func(n *natsPubSub) { 73 | n.client = client 74 | } 75 | } 76 | 77 | var ackMessage = []byte("ack") 78 | 79 | type natsSubscribeConfig struct { 80 | queueGroup string 81 | replyTimeout time.Duration 82 | } 83 | 84 | func defaultSubscribeConfig() natsSubscribeConfig { 85 | return natsSubscribeConfig{ 86 | replyTimeout: 60 * time.Second, 87 | } 88 | } 89 | 90 | type natsPublishConfig struct { 91 | ctx context.Context 92 | responseCh chan *Publication 93 | desiredResponse ResponseMode 94 | } 95 | 96 | // NATSOptSubscribeQueue sets the NATS subscriber queue group. 97 | // In short, this allows to ensure only one subscriber in the 98 | // queue group with the same name will receive the publication. 
99 | // 100 | // See: https://nats.io/documentation/concepts/nats-queueing/ 101 | func NATSOptSubscribeQueue(queueGroup string) PubSubOptSubscribe { 102 | return func(c any) { 103 | c.(*natsSubscribeConfig).queueGroup = queueGroup 104 | } 105 | } 106 | 107 | // NATSOptSubscribeReplyTimeout sets the duration of time to wait before giving up 108 | // waiting for a response to publish back to the client that is expecting a response 109 | func NATSOptSubscribeReplyTimeout(t time.Duration) PubSubOptSubscribe { 110 | return func(c any) { 111 | c.(*natsSubscribeConfig).replyTimeout = t 112 | } 113 | } 114 | 115 | // NATSOptRespondToChannel will send the *Publication received to the provided channel. 116 | // 117 | // This is an advanced option which is useful in situations where you want to block until 118 | // you receive a response. The context parameter allows you to provide a deadline on how long 119 | // you should wait before considering the request as a failure: 120 | // 121 | // myCtx, _ := context.WithTimeout(context.Background(), 5*time.Second) 122 | // respCh := make(chan *Publication) 123 | // publishOption := NATSOptRespondToChannel(myCtx, respCh) 124 | // 125 | // This option CANNOT be combined with NATSOptPublishRequireAck 126 | func NATSOptRespondToChannel(ctx context.Context, resp chan *Publication) PubSubOptPublish { 127 | return func(c any) { 128 | config := c.(*natsPublishConfig) 129 | 130 | switch { 131 | case config.desiredResponse != ResponseModeNone: 132 | panic(fmt.Sprintf("illegal option: request mode has already been set to %s", config.desiredResponse)) 133 | case resp == nil: 134 | panic("illegal argument: response channel cannot be nil") 135 | case ctx == nil: 136 | panic("illegal argument: context cannot be nil") 137 | } 138 | 139 | config.ctx = ctx 140 | config.responseCh = resp 141 | config.desiredResponse = ResponseModePublication 142 | } 143 | } 144 | 145 | // NATSOptPublishRequireAck is a helper to require a ack in the limit 146 | // of 
the given context.Context. If the other side is bahamut.PubSubClient 147 | // using the Subscribe method, then it will automatically send back the expected 148 | // ack. 149 | // 150 | // This option CANNOT be combined with NATSOptRespondToChannel 151 | func NATSOptPublishRequireAck(ctx context.Context) PubSubOptPublish { 152 | return func(c any) { 153 | config := c.(*natsPublishConfig) 154 | 155 | switch { 156 | case config.desiredResponse != ResponseModeNone: 157 | panic(fmt.Sprintf("illegal option: request mode has already been set to %s", config.desiredResponse)) 158 | case ctx == nil: 159 | panic("illegal argument: context cannot be nil") 160 | } 161 | 162 | config.ctx = ctx 163 | config.desiredResponse = ResponseModeACK 164 | } 165 | } 166 | -------------------------------------------------------------------------------- /context.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Aporeto Inc. 2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License. 4 | // You may obtain a copy of the License at 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // Unless required by applicable law or agreed to in writing, software 7 | // distributed under the License is distributed on an "AS IS" BASIS, 8 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | // See the License for the specific language governing permissions and 10 | // limitations under the License. 
11 | 12 | package bahamut 13 | 14 | import ( 15 | "context" 16 | "net/http" 17 | "sync" 18 | 19 | "github.com/gofrs/uuid" 20 | "go.aporeto.io/elemental" 21 | ) 22 | 23 | type bcontext struct { 24 | outputData any 25 | ctx context.Context 26 | inputData any 27 | claimsMap map[string]string 28 | responseWriter ResponseWriter 29 | request *elemental.Request 30 | eventsLock *sync.Mutex 31 | messagesLock *sync.Mutex 32 | metadata map[any]any 33 | id string 34 | redirect string 35 | next string 36 | messages []string 37 | outputCookies []*http.Cookie 38 | claims []string 39 | events elemental.Events 40 | count int 41 | statusCode int 42 | disableOutputDataPush bool 43 | } 44 | 45 | // NewContext creates a new *Context. 46 | func NewContext(ctx context.Context, request *elemental.Request) Context { 47 | return newContext(ctx, request) 48 | } 49 | 50 | func newContext(ctx context.Context, request *elemental.Request) *bcontext { 51 | 52 | if ctx == nil { 53 | panic("nil context") 54 | } 55 | 56 | return &bcontext{ 57 | claims: nil, 58 | claimsMap: map[string]string{}, 59 | ctx: ctx, 60 | eventsLock: &sync.Mutex{}, 61 | id: uuid.Must(uuid.NewV4()).String(), 62 | messagesLock: &sync.Mutex{}, 63 | request: request, 64 | } 65 | } 66 | 67 | func (c *bcontext) Identifier() string { 68 | return c.id 69 | } 70 | 71 | func (c *bcontext) Context() context.Context { 72 | return c.ctx 73 | } 74 | 75 | func (c *bcontext) Request() *elemental.Request { 76 | return c.request 77 | } 78 | 79 | func (c *bcontext) Count() int { 80 | return c.count 81 | } 82 | 83 | func (c *bcontext) SetCount(count int) { 84 | c.count = count 85 | } 86 | 87 | func (c *bcontext) InputData() any { 88 | return c.inputData 89 | } 90 | 91 | func (c *bcontext) SetInputData(data any) { 92 | c.inputData = data 93 | } 94 | 95 | func (c *bcontext) OutputData() any { 96 | return c.outputData 97 | } 98 | 99 | func (c *bcontext) SetDisableOutputDataPush(disabled bool) { 100 | c.disableOutputDataPush = disabled 101 | } 102 
| 103 | func (c *bcontext) SetOutputData(data any) { 104 | 105 | if c.responseWriter != nil { 106 | panic("you cannot use SetOutputData after using SetResponseWriter") 107 | } 108 | 109 | c.outputData = data 110 | } 111 | 112 | func (c *bcontext) SetResponseWriter(writer ResponseWriter) { 113 | 114 | if c.outputData != nil { 115 | panic("you cannot use SetResponseWriter after using SetOutputData") 116 | } 117 | 118 | c.responseWriter = writer 119 | } 120 | 121 | func (c *bcontext) StatusCode() int { 122 | return c.statusCode 123 | } 124 | 125 | func (c *bcontext) SetStatusCode(code int) { 126 | c.statusCode = code 127 | } 128 | 129 | func (c *bcontext) Redirect() string { 130 | return c.redirect 131 | } 132 | 133 | func (c *bcontext) SetRedirect(url string) { 134 | c.redirect = url 135 | } 136 | 137 | func (c *bcontext) Metadata(key any) any { 138 | 139 | if c.metadata == nil { 140 | return nil 141 | } 142 | 143 | return c.metadata[key] 144 | } 145 | 146 | func (c *bcontext) SetMetadata(key, value any) { 147 | 148 | if c.metadata == nil { 149 | c.metadata = map[any]any{} 150 | } 151 | 152 | c.metadata[key] = value 153 | } 154 | 155 | func (c *bcontext) SetClaims(claims []string) { 156 | 157 | if claims == nil { 158 | return 159 | } 160 | 161 | c.claims = append([]string{}, claims...) 162 | c.claimsMap = claimsToMap(c.claims) 163 | } 164 | 165 | func (c *bcontext) Claims() []string { 166 | 167 | return append([]string{}, c.claims...) 168 | } 169 | 170 | func (c *bcontext) ClaimsMap() map[string]string { 171 | 172 | o := make(map[string]string, len(c.claimsMap)) 173 | 174 | for k, v := range c.claimsMap { 175 | o[k] = v 176 | } 177 | 178 | return o 179 | } 180 | 181 | func (c *bcontext) EnqueueEvents(events ...*elemental.Event) { 182 | 183 | c.eventsLock.Lock() 184 | defer c.eventsLock.Unlock() 185 | 186 | c.events = append(c.events, events...) 
187 | } 188 | 189 | func (c *bcontext) SetNext(next string) { 190 | c.next = next 191 | } 192 | 193 | func (c *bcontext) AddMessage(msg string) { 194 | c.messagesLock.Lock() 195 | c.messages = append(c.messages, msg) 196 | c.messagesLock.Unlock() 197 | } 198 | 199 | func (c *bcontext) AddOutputCookies(cookies ...*http.Cookie) { 200 | c.outputCookies = append(c.outputCookies, cookies...) 201 | } 202 | 203 | func (c *bcontext) Duplicate() Context { 204 | 205 | c2 := newContext(c.ctx, c.request.Duplicate()) 206 | 207 | c2.inputData = c.inputData 208 | c2.count = c.count 209 | c2.statusCode = c.statusCode 210 | c2.outputData = c.outputData 211 | c2.claims = append(c2.claims, c.claims...) 212 | c2.redirect = c.redirect 213 | c2.messages = append(c2.messages, c.messages...) 214 | c2.next = c.next 215 | c2.outputCookies = append(c2.outputCookies, c.outputCookies...) 216 | c2.responseWriter = c.responseWriter 217 | c2.disableOutputDataPush = c.disableOutputDataPush 218 | 219 | for k, v := range c.claimsMap { 220 | c2.claimsMap[k] = v 221 | } 222 | 223 | if c.metadata != nil { 224 | c2.metadata = map[any]any{} 225 | for k, v := range c.metadata { 226 | c2.metadata[k] = v 227 | } 228 | } 229 | 230 | return c2 231 | } 232 | -------------------------------------------------------------------------------- /gateway/upstreamer/push/upstreamer_distribution_test.go: -------------------------------------------------------------------------------- 1 | package push 2 | 3 | import ( 4 | "net/http" 5 | "net/url" 6 | "sync" 7 | "testing" 8 | "time" 9 | 10 | // nolint:revive // Allow dot imports for readability in tests 11 | . 
"github.com/smartystreets/goconvey/convey" 12 | "golang.org/x/time/rate" 13 | ) 14 | 15 | func TestUpstreamUpstreamerDistribution(t *testing.T) { 16 | 17 | Convey("Given I have an upstreamer with 3 registered apis with different loads", t, func() { 18 | 19 | u := NewUpstreamer(nil, "topic", "topic2") 20 | u.apis = map[string][]*endpointInfo{ 21 | "/cats": { 22 | { 23 | address: "1.1.1.1:1", 24 | lastLoad: 10.0, 25 | }, 26 | { 27 | address: "2.2.2.2:1", 28 | lastLoad: 10.0, 29 | }, 30 | { 31 | address: "3.3.3.3:1", 32 | lastLoad: 81.0, 33 | }, 34 | }, 35 | } 36 | 37 | Convey("When I call upstream on /cats 2k times", func() { 38 | 39 | counts := make(map[string]int) 40 | 41 | for i := 0; i <= 2000; i++ { 42 | upstream, _ := u.Upstream(&http.Request{ 43 | URL: &url.URL{Path: "/cats"}, 44 | }) 45 | counts[upstream]++ 46 | } 47 | 48 | Convey("Then the repoartition should be correct", func() { 49 | So(counts["1.1.1.1:1"], ShouldAlmostEqual, counts["2.2.2.2:1"], 200) 50 | So(counts["3.3.3.3:1"], ShouldBeLessThan, counts["1.1.1.1:1"]/2) 51 | }) 52 | }) 53 | }) 54 | 55 | Convey("Given I have an upstreamer with 1 not loaded/ratelimited and one loaded/not ratelimited", t, func() { 56 | 57 | u := NewUpstreamer(nil, "topic", "topic2") 58 | u.apis = map[string][]*endpointInfo{ 59 | "/cats": { 60 | { 61 | address: "1.1.1.1:1", 62 | lastLoad: 10.0, 63 | limiters: IdentityToAPILimitersRegistry{ 64 | "cats": {limiter: rate.NewLimiter(rate.Limit(1), 1)}, 65 | }, 66 | }, 67 | { 68 | address: "3.3.3.3:1", 69 | lastLoad: 81.0, 70 | }, 71 | }, 72 | } 73 | 74 | Convey("When I call upstream on /cats 2k times", func() { 75 | 76 | counts := make(map[string]int) 77 | 78 | for i := 0; i <= 2000; i++ { 79 | upstream, _ := u.Upstream(&http.Request{ 80 | URL: &url.URL{Path: "/cats"}, 81 | }) 82 | counts[upstream]++ 83 | } 84 | 85 | Convey("Then the repoartition should be correct", func() { 86 | So(counts["1.1.1.1:1"], ShouldAlmostEqual, 0, 10) 87 | So(counts["3.3.3.3:1"], ShouldAlmostEqual, 2000, 
10) 88 | }) 89 | }) 90 | }) 91 | 92 | Convey("Given I have an upstreamer with 1 not loaded/not ratelimited and one loaded/ratelimited", t, func() { 93 | 94 | u := NewUpstreamer(nil, "topic", "topic2") 95 | u.apis = map[string][]*endpointInfo{ 96 | "/cats": { 97 | { 98 | address: "1.1.1.1:1", 99 | lastLoad: 10.0, 100 | }, 101 | { 102 | address: "3.3.3.3:1", 103 | lastLoad: 81.0, 104 | limiters: IdentityToAPILimitersRegistry{ 105 | "cats": {limiter: rate.NewLimiter(rate.Limit(1), 1)}, 106 | }, 107 | }, 108 | }, 109 | } 110 | 111 | Convey("When I call upstream on /cats 2k times", func() { 112 | 113 | counts := make(map[string]int) 114 | 115 | for i := 0; i <= 2000; i++ { 116 | upstream, _ := u.Upstream(&http.Request{ 117 | URL: &url.URL{Path: "/cats"}, 118 | }) 119 | counts[upstream]++ 120 | } 121 | 122 | Convey("Then the repoartition should be correct", func() { 123 | So(counts["1.1.1.1:1"], ShouldAlmostEqual, 2000, 10) 124 | So(counts["3.3.3.3:1"], ShouldAlmostEqual, 0, 10) 125 | }) 126 | }) 127 | }) 128 | } 129 | 130 | func TestLatencyBasedUpstreamer(t *testing.T) { 131 | 132 | Convey("Given I have a new latency based upstreamer", t, func() { 133 | u := NewUpstreamer(nil, "topic", "topic2") 134 | u.config.latencySampleSize = 2 135 | 136 | Convey("When I there is no entries the average is not available", func() { 137 | 138 | var v float64 139 | var err error 140 | 141 | if ma, ok := u.latencies.Load("foo"); ok { 142 | v, err = ma.(movingAverage).average() 143 | } 144 | 145 | So(v, ShouldEqual, 0) 146 | So(err, ShouldBeNil) 147 | }) 148 | 149 | Convey("When I add one entry the average is not yet available", func() { 150 | u.CollectLatency("bar", 1*time.Microsecond) 151 | 152 | var v float64 153 | var err error 154 | 155 | if ma, ok := u.latencies.Load("bar"); ok { 156 | v, err = ma.(movingAverage).average() 157 | } 158 | 159 | So(v, ShouldEqual, 0) 160 | So(err, ShouldNotBeNil) 161 | }) 162 | 163 | Convey("When I add two entries the average is available", func() { 
164 | u.CollectLatency("bar", 1*time.Microsecond) 165 | u.CollectLatency("bar", 1*time.Microsecond) 166 | 167 | var v float64 168 | var err error 169 | 170 | if ma, ok := u.latencies.Load("bar"); ok { 171 | v, err = ma.(movingAverage).average() 172 | } 173 | 174 | So(v, ShouldEqual, 1) 175 | So(err, ShouldBeNil) 176 | }) 177 | 178 | Convey("When I add entries concurently there is no race", func() { 179 | 180 | u := NewUpstreamer(nil, "topic", "topic2") 181 | u.config.latencySampleSize = 100 182 | 183 | var wg sync.WaitGroup 184 | 185 | inc := func() { 186 | defer wg.Done() 187 | u.CollectLatency("bar", 1*time.Microsecond) 188 | } 189 | 190 | for i := 0; i < 100; i++ { 191 | wg.Add(1) 192 | go inc() 193 | } 194 | 195 | wg.Wait() 196 | 197 | if ma, ok := u.latencies.Load("bar"); ok { 198 | // As there is no garrantee of the result as the operation can overlap 199 | // we are not checking the result here. This is just to track races 200 | _, _ = ma.(movingAverage).average() 201 | } 202 | 203 | }) 204 | 205 | Convey("When I delete an entry a values the average is not available", func() { 206 | u.latencies.Delete("bar") 207 | var v float64 208 | var err error 209 | 210 | if ma, ok := u.latencies.Load("bar"); ok { 211 | v, err = ma.(movingAverage).average() 212 | } 213 | 214 | So(v, ShouldEqual, 0) 215 | So(err, ShouldBeNil) 216 | }) 217 | 218 | }) 219 | } 220 | -------------------------------------------------------------------------------- /metrics_prometheus.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Aporeto Inc. 2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License. 
4 | // You may obtain a copy of the License at 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // Unless required by applicable law or agreed to in writing, software 7 | // distributed under the License is distributed on an "AS IS" BASIS, 8 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | // See the License for the specific language governing permissions and 10 | // limitations under the License. 11 | 12 | package bahamut 13 | 14 | import ( 15 | "net/http" 16 | "regexp" 17 | "strconv" 18 | "strings" 19 | "time" 20 | 21 | opentracing "github.com/opentracing/opentracing-go" 22 | "github.com/prometheus/client_golang/prometheus" 23 | "github.com/prometheus/client_golang/prometheus/promhttp" 24 | ) 25 | 26 | var ( 27 | pregexp = regexp.MustCompile(`^/_[a-zA-Z0-9-_+]+`) 28 | vregexp = regexp.MustCompile(`^/v/\d+`) 29 | ) 30 | 31 | func sanitizePath(url string) string { 32 | 33 | prefix := "/" 34 | matches := pregexp.FindAllString(url, 1) 35 | if len(matches) == 1 { 36 | prefix = matches[0] + "/" 37 | url = strings.TrimPrefix(url, matches[0]) 38 | } 39 | url = vregexp.ReplaceAllString(url, "") 40 | url = strings.TrimPrefix(url, "/") 41 | 42 | parts := strings.Split(url, "/") 43 | 44 | if len(parts) <= 1 { 45 | return prefix + url 46 | } 47 | 48 | parts[1] = ":id" 49 | 50 | return prefix + strings.Join(parts, "/") 51 | } 52 | 53 | type prometheusMetricsManager struct { 54 | reqDurationMetric *prometheus.SummaryVec 55 | reqTotalMetric *prometheus.CounterVec 56 | errorMetric *prometheus.CounterVec 57 | tcpConnTotalMetric prometheus.Counter 58 | tcpConnCurrentMetric prometheus.Gauge 59 | wsConnTotalMetric prometheus.Counter 60 | wsConnCurrentMetric prometheus.Gauge 61 | 62 | handler http.Handler 63 | } 64 | 65 | // NewPrometheusMetricsManager returns a new MetricManager using the prometheus format. 
66 | func NewPrometheusMetricsManager() MetricsManager { 67 | 68 | return newPrometheusMetricsManager(prometheus.DefaultRegisterer) 69 | } 70 | 71 | func newPrometheusMetricsManager(registerer prometheus.Registerer) MetricsManager { 72 | mc := &prometheusMetricsManager{ 73 | handler: promhttp.Handler(), 74 | reqTotalMetric: prometheus.NewCounterVec( 75 | prometheus.CounterOpts{ 76 | Name: "http_requests_total", 77 | Help: "The total number of requests.", 78 | }, 79 | []string{"method", "url", "code"}, 80 | ), 81 | reqDurationMetric: prometheus.NewSummaryVec( 82 | prometheus.SummaryOpts{ 83 | Name: "http_requests_duration_seconds", 84 | Help: "The average duration of the requests", 85 | Objectives: map[float64]float64{0.5: 0.05, 0.9: 0.01, 0.99: 0.001}, 86 | }, 87 | []string{"method", "url"}, 88 | ), 89 | tcpConnTotalMetric: prometheus.NewCounter( 90 | prometheus.CounterOpts{ 91 | Name: "tcp_connections_total", 92 | Help: "The total number of TCP connection.", 93 | }, 94 | ), 95 | tcpConnCurrentMetric: prometheus.NewGauge( 96 | prometheus.GaugeOpts{ 97 | Name: "tcp_connections_current", 98 | Help: "The current number of TCP connection.", 99 | }, 100 | ), 101 | wsConnTotalMetric: prometheus.NewCounter( 102 | prometheus.CounterOpts{ 103 | Name: "http_ws_connections_total", 104 | Help: "The total number of ws connection.", 105 | }, 106 | ), 107 | wsConnCurrentMetric: prometheus.NewGauge( 108 | prometheus.GaugeOpts{ 109 | Name: "http_ws_connections_current", 110 | Help: "The current number of ws connection.", 111 | }, 112 | ), 113 | errorMetric: prometheus.NewCounterVec( 114 | prometheus.CounterOpts{ 115 | Name: "http_errors_5xx_total", 116 | Help: "The total number of 5xx errors.", 117 | }, 118 | []string{"trace", "method", "url", "code"}, 119 | ), 120 | } 121 | 122 | registerer.MustRegister(mc.tcpConnCurrentMetric) 123 | registerer.MustRegister(mc.tcpConnTotalMetric) 124 | registerer.MustRegister(mc.reqTotalMetric) 125 | registerer.MustRegister(mc.reqDurationMetric) 
126 | registerer.MustRegister(mc.wsConnTotalMetric) 127 | registerer.MustRegister(mc.wsConnCurrentMetric) 128 | registerer.MustRegister(mc.errorMetric) 129 | 130 | return mc 131 | } 132 | 133 | func (c *prometheusMetricsManager) MeasureRequest(method string, path string) FinishMeasurementFunc { 134 | 135 | surl := sanitizePath(path) 136 | 137 | timer := prometheus.NewTimer( 138 | prometheus.ObserverFunc( 139 | func(v float64) { 140 | c.reqDurationMetric.With( 141 | prometheus.Labels{ 142 | "method": method, 143 | "url": surl, 144 | }, 145 | ).Observe(v) 146 | }, 147 | ), 148 | ) 149 | 150 | return func(code int, span opentracing.Span) time.Duration { 151 | 152 | c.reqTotalMetric.With(prometheus.Labels{ 153 | "method": method, 154 | "url": surl, 155 | "code": strconv.Itoa(code), 156 | }).Inc() 157 | 158 | if code >= http.StatusInternalServerError { 159 | 160 | c.errorMetric.With(prometheus.Labels{ 161 | "trace": extractSpanID(span), 162 | "method": method, 163 | "url": surl, 164 | "code": strconv.Itoa(code), 165 | }).Inc() 166 | } 167 | 168 | return timer.ObserveDuration() 169 | } 170 | } 171 | 172 | func (c *prometheusMetricsManager) RegisterWSConnection() { 173 | c.wsConnTotalMetric.Inc() 174 | c.wsConnCurrentMetric.Inc() 175 | } 176 | 177 | func (c *prometheusMetricsManager) UnregisterWSConnection() { 178 | c.wsConnCurrentMetric.Dec() 179 | } 180 | 181 | func (c *prometheusMetricsManager) RegisterTCPConnection() { 182 | c.tcpConnTotalMetric.Inc() 183 | c.tcpConnCurrentMetric.Inc() 184 | } 185 | 186 | func (c *prometheusMetricsManager) UnregisterTCPConnection() { 187 | c.tcpConnCurrentMetric.Dec() 188 | } 189 | 190 | func (c *prometheusMetricsManager) Write(w http.ResponseWriter, r *http.Request) { 191 | c.handler.ServeHTTP(w, r) 192 | } 193 | -------------------------------------------------------------------------------- /context_mock_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2019 Aporeto Inc. 
2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License. 4 | // You may obtain a copy of the License at 5 | // http://www.apache.org/licenses/LICENSE-2.0 6 | // Unless required by applicable law or agreed to in writing, software 7 | // distributed under the License is distributed on an "AS IS" BASIS, 8 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | // See the License for the specific language governing permissions and 10 | // limitations under the License. 11 | 12 | package bahamut 13 | 14 | import ( 15 | "context" 16 | "net/http" 17 | "testing" 18 | 19 | // nolint:revive // Allow dot imports for readability in tests 20 | . "github.com/smartystreets/goconvey/convey" 21 | "go.aporeto.io/elemental" 22 | testmodel "go.aporeto.io/elemental/test/model" 23 | ) 24 | 25 | func TestMockContext_NewMockContext(t *testing.T) { 26 | 27 | Convey("Given I call NewMockContext", t, func() { 28 | 29 | c := NewMockContext(context.Background()) 30 | 31 | Convey("Then it should be correctly initialized", func() { 32 | So(c.MockCtx, ShouldResemble, context.Background()) 33 | So(c.Metadata("hello"), ShouldBeNil) 34 | So(c, ShouldImplement, (*Context)(nil)) 35 | }) 36 | }) 37 | } 38 | 39 | func TestMockContext_Identifier(t *testing.T) { 40 | 41 | Convey("Identifier should work", t, func() { 42 | ctx := NewMockContext(context.Background()) 43 | So(ctx.Identifier(), ShouldNotBeEmpty) 44 | }) 45 | } 46 | 47 | func TestMockContext_Events(t *testing.T) { 48 | 49 | Convey("Given I create a Context", t, func() { 50 | 51 | c := NewMockContext(context.Background()) 52 | 53 | Convey("When I enqueue 2 events", func() { 54 | 55 | c.EnqueueEvents( 56 | elemental.NewEvent(elemental.EventCreate, testmodel.NewList()), 57 | elemental.NewEvent(elemental.EventCreate, testmodel.NewList()), 58 | ) 59 | 60 | Convey("Then I should have 2 events in the queue", func() { 61 | 
So(len(c.MockEvents), ShouldEqual, 2) 62 | }) 63 | }) 64 | }) 65 | } 66 | 67 | func TestMockContext_Duplicate(t *testing.T) { 68 | 69 | Convey("Given I have a Context, Info, Count, and Page", t, func() { 70 | 71 | req := &elemental.Request{ 72 | Namespace: "/thens", 73 | Headers: http.Header{"header": []string{"h1"}}, 74 | Identity: elemental.EmptyIdentity, 75 | ParentID: "xxxx", 76 | ParentIdentity: elemental.EmptyIdentity, 77 | Operation: elemental.OperationCreate, 78 | } 79 | 80 | cookies := []*http.Cookie{{}, {}} 81 | rwriter := func(http.ResponseWriter) int { return 0 } 82 | 83 | ctx := NewMockContext(context.Background()) 84 | ctx.MockRequest = req 85 | ctx.SetCount(10) 86 | ctx.SetInputData("input") 87 | ctx.SetOutputData("output") 88 | ctx.SetStatusCode(42) 89 | ctx.AddMessage("a") 90 | ctx.SetRedirect("laba") 91 | ctx.AddMessage("b") 92 | ctx.SetMetadata("hello", "world") 93 | ctx.SetClaims([]string{"ouais=yes"}) 94 | ctx.SetNext("next") 95 | ctx.AddOutputCookies(cookies[0], cookies[1]) 96 | ctx.SetResponseWriter(rwriter) 97 | ctx.SetDisableOutputDataPush(true) 98 | 99 | Convey("When I call the Duplicate method", func() { 100 | 101 | ctx2 := ctx.Duplicate() 102 | 103 | Convey("Then the duplicated context should be correct", func() { 104 | So(ctx.MockCtx, ShouldResemble, ctx2.Context()) 105 | So(ctx.MockCount, ShouldEqual, ctx2.Count()) 106 | So(ctx.Metadata("hello").(string), ShouldEqual, "world") 107 | So(ctx.MockInputData, ShouldEqual, ctx2.InputData()) 108 | So(ctx.MockOutputData, ShouldEqual, ctx2.OutputData()) 109 | So(ctx.MockRequest.Namespace, ShouldEqual, ctx2.Request().Namespace) 110 | So(ctx.MockRequest.ParentID, ShouldEqual, ctx2.Request().ParentID) 111 | So(ctx.MockStatusCode, ShouldEqual, ctx2.StatusCode()) 112 | So(ctx.MockClaims, ShouldResemble, ctx2.Claims()) 113 | So(ctx.MockClaimsMap, ShouldResemble, ctx2.ClaimsMap()) 114 | So(ctx.MockRedirect, ShouldEqual, ctx2.Redirect()) 115 | So(ctx.MockNext, ShouldEqual, ctx2.(*MockContext).MockNext) 
116 | So(ctx.MockMessages, ShouldResemble, ctx2.(*MockContext).MockMessages) 117 | So(ctx.MockOutputCookies, ShouldResemble, ctx2.(*MockContext).MockOutputCookies) 118 | So(ctx.MockOutputCookies, ShouldResemble, cookies) 119 | So(ctx.MockResponseWriter, ShouldEqual, rwriter) 120 | So(ctx.MockDisableOutputDataPush, ShouldEqual, ctx.MockDisableOutputDataPush) 121 | }) 122 | }) 123 | }) 124 | } 125 | 126 | func TestMockContext_GetClaims(t *testing.T) { 127 | 128 | Convey("Given I have a Context with claims", t, func() { 129 | 130 | oc := []string{"ouais=yes"} 131 | 132 | ctx := NewMockContext(context.Background()) 133 | ctx.SetClaims(oc) 134 | 135 | Convey("When I call GetClaims", func() { 136 | 137 | claims := ctx.Claims() 138 | 139 | Convey("Then claims should be correct", func() { 140 | So(claims, ShouldResemble, oc) 141 | So(claims, ShouldNotEqual, oc) 142 | }) 143 | }) 144 | 145 | Convey("When I call GetClaimsMap", func() { 146 | 147 | claimsMap := ctx.ClaimsMap() 148 | 149 | Convey("Then claims should be correct", func() { 150 | So(claimsMap, ShouldResemble, map[string]string{"ouais": "yes"}) 151 | So(claimsMap, ShouldNotEqual, ctx.MockClaimsMap) 152 | }) 153 | }) 154 | }) 155 | 156 | Convey("Given I have a Context nil claims", t, func() { 157 | 158 | ctx := NewMockContext(context.Background()) 159 | ctx.SetClaims(nil) 160 | 161 | Convey("When I call GetClaims", func() { 162 | 163 | claims := ctx.Claims() 164 | 165 | Convey("Then claims should be correct", func() { 166 | So(len(claims), ShouldEqual, 0) 167 | }) 168 | }) 169 | 170 | Convey("When I call GetClaimsMap", func() { 171 | 172 | claimsMap := ctx.ClaimsMap() 173 | 174 | Convey("Then claims should be correct", func() { 175 | So(claimsMap, ShouldResemble, map[string]string{}) 176 | }) 177 | }) 178 | }) 179 | } 180 | --------------------------------------------------------------------------------