├── .editorconfig ├── .github └── workflows │ └── test.yml ├── .gitignore ├── .go-version ├── LICENSE ├── Makefile ├── README.md ├── agent ├── agent.go ├── agent_test.go ├── cache.go ├── cache_store.go ├── cache_test.go ├── decompress.go ├── decompress_test.go ├── html.go ├── html_test.go ├── option.go └── option_test.go ├── benchmark.go ├── benchmark_option.go ├── benchmark_result.go ├── benchmark_scenario.go ├── benchmark_scenario_test.go ├── benchmark_step.go ├── benchmark_test.go ├── demo ├── agent │ └── main.go ├── failure │ └── main.go ├── pubsub │ └── main.go └── worker │ └── main.go ├── failure ├── cleaner.go ├── cleaner_test.go ├── code.go ├── code_test.go ├── error.go ├── error_test.go ├── errors.go └── errors_test.go ├── go.mod ├── go.sum ├── parallel ├── parallel.go └── parallel_test.go ├── pubsub ├── pubsub.go └── pubsub_test.go ├── random └── useragent │ ├── browser.go │ ├── platform.go │ ├── platform_test.go │ ├── useragent.go │ └── useragent_test.go ├── score ├── score.go └── score_test.go ├── test ├── http.go └── http_test.go └── worker ├── worker.go └── worker_test.go /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | charset = utf-8 5 | indent_size = 2 6 | end_of_line = lf 7 | insert_final_newline = true 8 | trim_trailing_whitespace = true 9 | 10 | [*.go] 11 | indent_style = tab 12 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | on: 3 | push: 4 | branches: 5 | - master 6 | pull_request: 7 | branches: 8 | - master 9 | 10 | jobs: 11 | build: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - name: Checkout 15 | uses: actions/checkout@master 16 | with: 17 | fetch-depth: 1 18 | - name: Detect go version 19 | id: go-version 20 | run: echo "::set-output name=VERSION::$(cat .go-version)" 21 | - name: Setup go 22 | uses: 
actions/setup-go@master 23 | with: 24 | stable: 'false' 25 | go-version: ${{ steps.go-version.outputs.VERSION }} 26 | - name: Run test 27 | run: make test 28 | env: 29 | GOARGS: "-v -race" 30 | GOMAXPROCS: 8 31 | - name: Report coverage 32 | uses: codecov/codecov-action@v1 33 | with: 34 | token: ${{ secrets.CODECOV_TOKEN }} 35 | file: ./tmp/cover.out 36 | - name: Run demo 37 | run: make demo 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | /tmp 3 | *.test 4 | -------------------------------------------------------------------------------- /.go-version: -------------------------------------------------------------------------------- 1 | 1.18 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2020 Sho Kusano 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | GOTIMEOUT?=20s 2 | GOARGS?=-race 3 | GOMAXPROCS?=$(shell nproc) 4 | 5 | .PHONY: test 6 | test: 7 | @mkdir -p tmp 8 | @go test -cover -coverprofile=tmp/cover.out -covermode=atomic ./... 9 | @go tool cover -html=tmp/cover.out -o tmp/coverage.html 10 | 11 | .PHONY: bench 12 | bench: 13 | @for d in $(shell go list ./... | grep -v vendor | grep -v demo); do \ 14 | GOMAXPROCS=$(GOMAXPROCS) \ 15 | go test \ 16 | $(GOARGS) \ 17 | -bench=^Benchmark \ 18 | -benchmem \ 19 | "$$d" || exit 1; \ 20 | done 21 | 22 | .PHONY: demo 23 | demo: 24 | @for d in $(shell go list ./... 
| grep -v vendor | grep demo); do \ 25 | echo "===> Demo: $$d" && \ 26 | go run "$$d" || exit 1; \ 27 | done 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # isucandar 2 | 3 | [![test](https://github.com/isucon/isucandar/workflows/test/badge.svg)](https://github.com/isucon/isucandar/actions?query=workflow%3Atest) 4 | [![codecov](https://codecov.io/gh/isucon/isucandar/branch/master/graph/badge.svg?token=KO1N8H5S53)](https://codecov.io/gh/isucon/isucandar) 5 | 6 | isucandar は [ISUCON](http://isucon.net/) などの負荷試験で使える機能を集めたベンチマーカーフレームワークです。 7 | 8 | 主な機能として、ブラウザのように振る舞うエージェント、複数階層のスタックトレースを持ったエラー、スコア計算、並列数を制御しつつ外部から停止可能なワーカーなどがあります。 9 | 10 | ## 使い方 11 | 12 | ### agent 13 | 14 | `isucandar/agent` はブラウザに近い(似せた)挙動をすることを目的として作られたパッケージです。 15 | 16 | `net/http` を基礎にしつつ、いくつかの拡張が行われています。 17 | 18 | ```golang 19 | //// Agent 20 | // NewAgent の引数には可変長で func(*Agent) error な関数を渡せます。 21 | // その中で Agent の初期設定を完了させてください。 22 | // 簡易につかえるように、いくつかの AgentOption を返す関数が用意されています。 23 | agent, err := NewAgent(WithBaseURL("http://isucon.net"), WithDefaultTransport()) 24 | 25 | // 通常の http.NewRequest のように呼び出せます。 26 | req, err := agent.NewRequest(http.MethodGet, "/", nil) 27 | // あるいは、以下のような形でも作成できます。 28 | // req, _ := agent.GET("/") 29 | // req, _ := agent.POST("/", body) 30 | 31 | // タイムアウトの制御は主に context.WithTimeout で行われることを想定しています(HttpClient.Timeout にも DefaultRequestTimeout が設定されています)。 32 | // 利用している Transport や Dialer は DefaultTransport や DefaultDialer を参照してください。 33 | res, err := agent.Do(context.TODO(), req) 34 | // Agent は自動的に以下のような挙動で振る舞います。 35 | // - CookieJar を持っているので Cookie を保存している 36 | // - CacheStore を利用して Conditional GET を行い、キャッシュ残存期間次第ではリクエストを行わずキャッシュからレスポンスを復元する 37 | // - Content-Encoding で gzip, deflate, brotli のいずれかが指定されていた場合、自動的に展開する(Accept-Encoding も付与します) 38 | // - 自身の Name に応じて User-Agent を設定する 39 | // - 特に Accept が指定されていない時、自動でブラウザの送るような Accept を送信する 40 | 41 | // 取得した HTTP
レスポンスを使って、さらにブラウザのような挙動をさせることができます。 42 | resources, err := agent.ProcessHTML(context.TODO(), res, res.Body) 43 | // Agent は HTML を解析し、以下のようなルールに従って追加のリソースへリクエストを送信します。 44 | // - script, link 要素で収集対象となるもの(src が設定されている、 rel が stylesheet である、など)を取得します 45 | // - img 要素も収集しますが、ブラウザの挙動に従い、 loading="lazy" なものは無視します 46 | // - script 要素の async / defer は考慮しません 47 | // 挙動の参考としては『HTML をロードしてから onload が実行されるまでに発行されるリクエスト』を基準としています。 48 | // リクエストの順序などは考慮されていません。 49 | // 厳密な挙動が必要な場合は、外部で実装してください。 50 | 51 | //// CacheStore 52 | // Agent は CacheStore を持ち、それを利用してブラウザに似せた Conditional GET や、 53 | // キャッシュを利用して、メモリからレスポンスを復元したりします。 54 | // もし Cache が必要ないようであれば、 WithNoCache() を NewAgent の引数へ渡してください。 55 | agent, _ := NewAgent(WithNoCache(), WithDefaultTransport()) 56 | 57 | // また、なんらかの理由でキャッシュをクリアしたくなった場合は agent.CacheStore.Clear() で削除できます。 58 | agent.CacheStore.Clear() 59 | 60 | //// Transport 61 | // Agent は HttpClient とその Transport を持ちます。 62 | // TCP 接続単位で共有を拒否したい場合は WithCloneTransport(DefaultTransport) などを利用し、 63 | // 接続が共有されても構わない場合は WithDefaultTransport() を利用してください。 64 | agent, _ := NewAgent(WithDefaultTransport()) 65 | // or 66 | agent, _ := NewAgent(WithCloneTransport(DefaultTransport)) 67 | ``` 68 | 69 | #### 補足 70 | 71 | - `Agent` を複数のユーザー間で使い回さないでください。 `Agent` は1つの User-Agent として機能するように実装されています。 72 | - `ProcessHTML` は基本的に低速です。すべてのページでこれを利用しようとしてはいけません。チェックに必要な場合のみ利用してください。 73 | 74 | ### failure 75 | 76 | isucandar 独自のエラーや、それらのコレクションを扱うパッケージです。基本的には [xerrors](https://golang.org/x/xerrors) をベースに作成されていますが、以下のような点が異なります。 77 | 78 | - 取得数を指定したり、除外設定のできるコールスタック 79 | - xerrors 標準では1つしかコールスタックを保持できないためです 80 | - 複数個タグのようにつけられるエラーコード 81 | 82 | ```golang 83 | //// Code 84 | // Code はエラーコードそのものを指す interface です。 85 | // ErrorCode() string と Error() string が実装されていれば満たすことができます。 86 | // 基本的には StringCode を介して定義するのがかんたんです。 87 | var StandardErrorCode failure.StringCode = "standard" 88 | 89 | //// Error 90 | // Error はエラーコード、コールスタックの保持などを行う error 互換の構造体です。 91 | err := NewError(StandardErrorCode,
fmt.Errorf("original error message")) 92 | // NewError は基本的に渡されたエラーを Code でラッピングしますが、一部のエラーは追加で Code を付与します。 93 | // - net.Error.Timeout() == true: TimeoutErrorCode 94 | // - net.Error.Temporary() == true: TemporaryErrorCode 95 | // - context.Canceled: CanceledErrorCode 96 | 97 | // コールスタックを出力したりできます 98 | fmt.Printf("%v", err) 99 | // standard: original error message 100 | fmt.Printf("%+v", err) 101 | // standard: 102 | // github.com/isucon/isucandar/failure.TestPrint: 103 | // ~/src/github.com/isucon/isucandar/failure/failure_test.go:10 104 | // - original error message 105 | 106 | // 最も最近つけられた ErrorCode は以下のように取得できます。 107 | // Error ではない場合、自動的に UnknownErrorCode の ErrorCode が返ります。 108 | code := GetCode(err) 109 | // => "standard" 110 | 111 | // そのエラーに紐付いている ErrorCode をすべて取得します。 112 | // なんの ErrorCode も紐付いていない時は UnknownErrorCode 単体が返ります。 113 | codes := GetCodes(err) 114 | // => []string{"standard"} 115 | 116 | // 元のエラーがなんであったかなどは xerrors や errors 同様に判別できます。 117 | Is(err, context.DeadlineExceeded) 118 | 119 | //// Backtrace & BacktraceCleaner 120 | // failure はコールスタックを Error に保存しますが、その深度は変数で変更できます。 121 | CaptureBacktraceSize = 1 // default: 5 122 | 123 | // BacktraceCleaner は保存される Backtrace から除外するものを指定できます。 124 | // 例えば組み込みパッケージの Backtrace を除外する指定は以下のようにできます。 125 | BacktraceCleaner.Add(SkipGOROOT()) 126 | 127 | // Backtrace matcher はかんたんに実装できます。 128 | BacktraceCleaner.Add(func(backtrace *Backtrace) bool { 129 | return strings.HasSuffix(backtrace.File, "_test.go") 130 | }) 131 | 132 | //// Errors 133 | // Errors は Error の収集と集計を高速かつかんたんに行うための構造体です。 134 | errors := NewErrors(context.TODO()) 135 | errors.Add(err) 136 | // NewErrors に渡した Context が終了すると、エラーの収集が終わったことを伝えられます。 137 | // Errors.Add(error) は内部的に別 goroutine で処理しているため、即座には Errors 内部に追加されません。 138 | // その代わり、ロックを気にせず追加し続けることができます。 139 | 140 | // 明示的に収集が完了したことを示すために Done してもよいです。 141 | errors.Done() 142 | 143 | // 最終的に以下の関数達で集計をすることができます。 144 | // ErrorCode ごとにエラーメッセージを集計します。 145 | // ErrorCode は GetCode(error) 
で得られたものが採用されます。 146 | errors.Messages() // => map[string][]string 147 | 148 | // ErrorCode ごとに数を集計します。 149 | // ErrorCode は GetCodes(err) で得られたすべてのコードに対して加算するため、 150 | // 総数は実際の error の数より大きくなることに注意してください。 151 | errors.Count() // => map[string]int64 152 | // 例えば unknown が 1 以上なら Critical Error とする、などが考えられます。 153 | 154 | // すべての error を返却します。 155 | errors.All() // => []error 156 | ``` 157 | 158 | #### 補足 159 | 160 | - 集計系の関数(`Messages` / `Count` / `All`)は集計完了前でもいつでも取り出せます。 161 | 162 | ### score 163 | 164 | スコア集計のためのパッケージです。 165 | 166 | ```golang 167 | //// Score 168 | // スコアを集計、点数の計算までを行います。 169 | score := NewScore(context.TODO()) 170 | // Errors 同様、収集の完了を Context 経由で伝えることができます。 171 | 172 | // スコアは文字列によってタグ付けされており、各タグに得点を設定できます。 173 | score.Set("success-get", 1) 174 | score.Set("success-post", 5) 175 | 176 | // タグを指定して1ずつスコアを加算します。 177 | score.Add("success-get") 178 | score.Add("success-post") 179 | 180 | // Done で明示的にスコア収集の完了を伝えられます。 181 | score.Done() 182 | 183 | // 以下の関数達で集計結果を得ることができます。 184 | // 各タグの個数を出力します。 185 | score.Breakdown() 186 | // => map[string]int64{ "success-get": 1, "success-post": 1 } 187 | 188 | // 合計得点を出力します。 189 | score.Sum() 190 | // => 6 191 | 192 | // Done しつつ Sum をします。 193 | score.Total() 194 | // => 6 195 | ``` 196 | 197 | #### 補足 198 | 199 | - `Breakdown()` や `Sum()` はいつでも実行できます。ロックなどを外部から考慮する必要はありません。 200 | 201 | ### worker 202 | 203 | 同じ処理を複数回実行したり、並列数を抑えながら無限に実行したりする処理の制御を提供します。 204 | 205 | ```golang 206 | //// Worker 207 | // Worker はオプションによって少し挙動が変わります。 208 | // ループ回数を指定する場合は以下のように作成します。 209 | limitedWorker, err := NewWorker(f, WithLoopCount(5)) 210 | // ループ回数を指定しない場合は以下のように作成します。 211 | unlimitedWorker, err := NewWorker(f, WithInfinityLoop()) 212 | 213 | // 作成時の引数に渡す f が処理される内容です。 214 | f := func(ctx context.Context, i int) { 215 | // ctx : 渡されてきた Context 216 | // i : 何回目の実行か。ループ回数の指定がない場合は常に -1 になります。 217 | } 218 | 219 | // Worker は任意のタイミングで Context を介して停止が可能です。 220 | // 停止を通知された Worker は新たな実行をせず、なるべく素早く実行を終了します。 221 | ctx, cancel := 
context.WithTimeout(context.TODO(), 1 * time.Second) 222 | defer cancel() 223 | worker.Process(ctx) 224 | // Process は起動済みのジョブのすべての実行を待ちます。 225 | 226 | // 外部から Worker の終了を検出することもできます。 227 | worker.Wait() 228 | 229 | // Worker は作成時または後からループ回数を変更できます。 230 | worker, err := NewWorker(f, WithLoopCount(10)) 231 | worker.SetLoopCount(20) 232 | 233 | // Worker は作成時または後から並列数を変更できます。 234 | worker, err := NewWorker(f, WithMaxParallelism(10)/* あるいは WithUnlimitedParallelism() */) 235 | worker.SetParallelism(20) 236 | worker.AddParallelism(20) 237 | // 並列数の変更は実行中であっても反映されます。 238 | ``` 239 | 240 | #### 補足 241 | 242 | - ループ回数のない `Worker` を後から制限付きに変えたりすると思わぬエラーが発生する場合があります。 243 | 244 | ### parallel 245 | 246 | 同時実行数を制御しつつ、複数のジョブを実行させる処理を提供します。 247 | 248 | ```golang 249 | //// Parallel 250 | // 初期化時、あるいはあとから同時実行数を設定できます。 251 | parallel := NewParallel(10) 252 | parallel.SetParallelism(5) 253 | parallel.AddParallelism(5) 254 | // 制限値に 0 以下の値を与えると、並列数の上限を設けません。 255 | // 並列数の変更はジョブの起動中であっても構いません。 256 | 257 | // 実行可能になるまで待ってから(列に並ぶ)、ジョブを実行します。 258 | // Context を渡すことができますが、 Context が終了しても Parallel はジョブを自動停止はしません。 259 | // ジョブ側で Context の終了を検知して終了してください。 260 | // ただ、ジョブが列に並んでいる最中に Context が終了した場合、 261 | // Parallel は順番が来てもジョブを起動しません。 262 | parallel.Do(ctx, func(b context.Context) { 263 | // Do の ctx が func(b) に引き渡されます。 264 | }) 265 | 266 | // 順番待ちのジョブがいる時に、外部からすべての実行を取りやめたい場合は、 267 | // Close を用いてジョブの実行をキャンセルします。 268 | parallel.Close() 269 | 270 | // 実行中、あるいは未来実行するすべてのジョブの完了を待ちたい場合は、 271 | // Wait を利用してください。 272 | parallel.Wait() 273 | 274 | // 1度以上実行し、Close して停止した Parallel を再利用するには 275 | // Reset による再初期化が必要です。 276 | parallel.Reset() 277 | ``` 278 | 279 | #### 補足 280 | 281 | - 並列数に1を設定すると、 `Wait` 時に不安定な挙動を示す場合があります。 282 | 283 | ## Author 284 | 285 | Sho Kusano 286 | 287 | ## License 288 | 289 | See [LICENSE](https://github.com/isucon/isucandar/blob/master/LICENSE) 290 | -------------------------------------------------------------------------------- /agent/agent.go: 
-------------------------------------------------------------------------------- 1 | package agent 2 | 3 | import ( 4 | "context" 5 | "crypto/tls" 6 | "errors" 7 | "io" 8 | "net" 9 | "net/http" 10 | "net/http/cookiejar" 11 | "net/url" 12 | "strings" 13 | "time" 14 | ) 15 | 16 | var ( 17 | DefaultConnections = 10000 18 | DefaultName = "isucandar" 19 | DefaultAccept = "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8" 20 | DefaultRequestTimeout = 1 * time.Second 21 | 22 | DefaultTLSConfig = &tls.Config{ 23 | InsecureSkipVerify: true, 24 | } 25 | 26 | DefaultDialer *net.Dialer 27 | DefaultTransport *http.Transport 28 | ) 29 | 30 | var ( 31 | ErrTransportInvalid = errors.New("Specify transport option(WithCloneTransport or WithDefaultTransport)") 32 | ErrUnknownContentEncoding = errors.New("Unknown content encoding") 33 | ) 34 | 35 | func init() { 36 | DefaultDialer = &net.Dialer{ 37 | Timeout: 0, 38 | KeepAlive: 60 * time.Second, 39 | } 40 | 41 | transport := &http.Transport{ 42 | Proxy: http.ProxyFromEnvironment, 43 | Dial: DefaultDialer.Dial, 44 | DialContext: DefaultDialer.DialContext, 45 | TLSClientConfig: DefaultTLSConfig, 46 | DisableCompression: true, 47 | MaxIdleConns: 0, 48 | MaxIdleConnsPerHost: DefaultConnections, 49 | MaxConnsPerHost: 0, 50 | TLSHandshakeTimeout: 0, 51 | ResponseHeaderTimeout: 0, 52 | IdleConnTimeout: 0, 53 | ForceAttemptHTTP2: true, 54 | } 55 | 56 | DefaultTransport = transport 57 | } 58 | 59 | type AgentOption func(*Agent) error 60 | 61 | type Agent struct { 62 | Name string 63 | BaseURL *url.URL 64 | DefaultAccept string 65 | CacheStore CacheStore 66 | HttpClient *http.Client 67 | } 68 | 69 | func NewAgent(opts ...AgentOption) (*Agent, error) { 70 | jar, _ := cookiejar.New(&cookiejar.Options{}) 71 | 72 | agent := &Agent{ 73 | Name: DefaultName, 74 | BaseURL: nil, 75 | DefaultAccept: DefaultAccept, 76 | CacheStore: NewCacheStore(), 77 | HttpClient: &http.Client{ 78 | CheckRedirect: useLastResponse, 79 | Transport: nil, 80 | 
Jar: jar, 81 | Timeout: DefaultRequestTimeout, 82 | }, 83 | } 84 | 85 | for _, opt := range opts { 86 | if err := opt(agent); err != nil { 87 | return nil, err 88 | } 89 | } 90 | 91 | if agent.HttpClient.Transport == nil { 92 | return nil, ErrTransportInvalid 93 | } 94 | 95 | return agent, nil 96 | } 97 | 98 | func (a *Agent) ClearCookie() { 99 | if a.HttpClient.Jar != nil { 100 | jar, _ := cookiejar.New(&cookiejar.Options{}) 101 | a.HttpClient.Jar = jar 102 | } 103 | } 104 | 105 | func (a *Agent) Do(ctx context.Context, req *http.Request) (*http.Response, error) { 106 | req = req.WithContext(ctx) 107 | 108 | var cache *Cache 109 | if a.CacheStore != nil { 110 | cache = a.CacheStore.Get(req) 111 | } 112 | if cache != nil { 113 | cache.apply(req) 114 | } 115 | 116 | var res *http.Response 117 | var err error 118 | 119 | if cache != nil && !cache.requiresRevalidate(req) { 120 | res = cache.restoreResponse() 121 | } else { 122 | res, err = a.HttpClient.Do(req) 123 | if err != nil { 124 | if strings.Contains(err.Error(), "http2: server sent GOAWAY") && strings.Contains(err.Error(), "ErrCode=NO_ERROR") && req.Method == http.MethodGet { 125 | return a.Do(ctx, req) 126 | } 127 | return nil, err 128 | } 129 | 130 | res, err = decompress(res) 131 | if err != nil { 132 | return nil, err 133 | } 134 | } 135 | 136 | cache, err = newCache(res, cache.Body()) 137 | if err != nil { 138 | return nil, err 139 | } 140 | 141 | if cache != nil && a.CacheStore != nil { 142 | a.CacheStore.Put(req, cache) 143 | } 144 | 145 | return res, nil 146 | } 147 | 148 | func (a *Agent) NewRequest(method string, target string, body io.Reader) (*http.Request, error) { 149 | reqURL, err := url.Parse(target) 150 | if err != nil { 151 | return nil, err 152 | } 153 | 154 | if a.BaseURL != nil { 155 | reqURL = a.BaseURL.ResolveReference(reqURL) 156 | } 157 | 158 | req, err := http.NewRequest(method, reqURL.String(), body) 159 | if err != nil { 160 | return nil, err 161 | } 162 | 163 | 
req.Header.Set("User-Agent", a.Name) 164 | req.Header.Set("Accept-Encoding", "gzip, deflate, br") 165 | req.Header.Set("Connection", "keep-alive") 166 | if req.Header.Get("Accept") == "" { 167 | req.Header.Set("Accept", a.DefaultAccept) 168 | } 169 | 170 | return req, nil 171 | } 172 | 173 | func (a *Agent) GET(target string) (*http.Request, error) { 174 | return a.NewRequest(http.MethodGet, target, nil) 175 | } 176 | 177 | func (a *Agent) POST(target string, body io.Reader) (*http.Request, error) { 178 | return a.NewRequest(http.MethodPost, target, body) 179 | } 180 | 181 | func (a *Agent) PUT(target string, body io.Reader) (*http.Request, error) { 182 | return a.NewRequest(http.MethodPut, target, body) 183 | } 184 | 185 | func (a *Agent) PATCH(target string, body io.Reader) (*http.Request, error) { 186 | return a.NewRequest(http.MethodPatch, target, body) 187 | } 188 | 189 | func (a *Agent) DELETE(target string, body io.Reader) (*http.Request, error) { 190 | return a.NewRequest(http.MethodDelete, target, body) 191 | } 192 | 193 | func useLastResponse(req *http.Request, via []*http.Request) error { 194 | return http.ErrUseLastResponse 195 | } 196 | -------------------------------------------------------------------------------- /agent/agent_test.go: -------------------------------------------------------------------------------- 1 | package agent 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "net/http" 8 | "net/http/httptest" 9 | "net/http/httputil" 10 | "testing" 11 | 12 | "github.com/julienschmidt/httprouter" 13 | ) 14 | 15 | func newHTTPServer() *httptest.Server { 16 | r := httprouter.New() 17 | r.GET("/dump", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { 18 | dump, err := httputil.DumpRequest(r, true) 19 | if err != nil { 20 | w.WriteHeader(http.StatusInternalServerError) 21 | return 22 | } 23 | fmt.Printf("%s", dump) 24 | w.WriteHeader(http.StatusNoContent) 25 | }) 26 | 27 | r.GET("/not_found", func(w http.ResponseWriter, r 
*http.Request, _ httprouter.Params) { 28 | w.WriteHeader(404) 29 | }) 30 | r.GET("/301redirect", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { 31 | http.Redirect(w, r, "/301", http.StatusMovedPermanently) 32 | }) 33 | r.GET("/302redirect", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { 34 | http.Redirect(w, r, "/302", http.StatusFound) 35 | }) 36 | r.GET("/304redirect", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { 37 | http.Redirect(w, r, "/304", http.StatusNotModified) 38 | }) 39 | r.GET("/307redirect", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { 40 | http.Redirect(w, r, "/307", http.StatusTemporaryRedirect) 41 | }) 42 | r.GET("/308redirect", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { 43 | http.Redirect(w, r, "/308", http.StatusPermanentRedirect) 44 | }) 45 | 46 | return httptest.NewServer(r) 47 | } 48 | 49 | func TestAgent(t *testing.T) { 50 | errOpt := func(_ *Agent) error { 51 | return errors.New("invalid") 52 | } 53 | 54 | agent, err := NewAgent(errOpt) 55 | if err == nil || agent != nil { 56 | t.Fatal("error not occured") 57 | } 58 | } 59 | 60 | func TestAgentClearCookie(t *testing.T) { 61 | agent, err := NewAgent(WithBaseURL("http://example.com/"), WithDefaultTransport()) 62 | if err != nil { 63 | t.Fatal(err) 64 | } 65 | 66 | agent.HttpClient.Jar.SetCookies(agent.BaseURL, []*http.Cookie{ 67 | {}, 68 | }) 69 | if len(agent.HttpClient.Jar.Cookies(agent.BaseURL)) != 1 { 70 | t.Fatal("Set cookie failed") 71 | } 72 | agent.ClearCookie() 73 | if len(agent.HttpClient.Jar.Cookies(agent.BaseURL)) != 0 { 74 | t.Fatal("Clear cookie failed") 75 | } 76 | } 77 | 78 | func TestAgentNewRequest(t *testing.T) { 79 | agent, err := NewAgent(WithDefaultTransport()) 80 | if err != nil { 81 | t.Fatalf("%+v", err) 82 | } 83 | 84 | _, err = agent.NewRequest(http.MethodGet, "://invalid-uri", nil) 85 | if err == nil { 86 | t.Fatal("Not reached url parse error") 87 | } 88 
| 89 | _, err = agent.NewRequest("bad method", "/", nil) 90 | if err == nil { 91 | t.Fatalf("Not reached method name error") 92 | } 93 | } 94 | 95 | func TestAgentRequest(t *testing.T) { 96 | srv := newHTTPServer() 97 | defer srv.Close() 98 | 99 | agent, err := NewAgent(WithBaseURL(srv.URL), WithDefaultTransport()) 100 | if err != nil { 101 | t.Fatalf("%+v", err) 102 | } 103 | 104 | req, err := agent.GET("/302redirect") 105 | if err != nil { 106 | t.Fatalf("%+v", err) 107 | } 108 | 109 | ctx, cancel := context.WithCancel(context.Background()) 110 | defer cancel() 111 | 112 | res, err := agent.Do(ctx, req) 113 | if err != nil { 114 | t.Fatalf("%+v", err) 115 | } 116 | 117 | if res.StatusCode != 302 { 118 | t.Fatalf("%#v", res) 119 | } 120 | } 121 | 122 | func TestAgentMethods(t *testing.T) { 123 | agent, err := NewAgent(WithDefaultTransport()) 124 | if err != nil { 125 | t.Fatalf("%+v", err) 126 | } 127 | 128 | r, _ := agent.GET("/") 129 | if r.Method != http.MethodGet { 130 | t.Fatalf("Method missmatch: %s", r.Method) 131 | } 132 | r, _ = agent.POST("/", nil) 133 | if r.Method != http.MethodPost { 134 | t.Fatalf("Method missmatch: %s", r.Method) 135 | } 136 | r, _ = agent.PUT("/", nil) 137 | if r.Method != http.MethodPut { 138 | t.Fatalf("Method missmatch: %s", r.Method) 139 | } 140 | r, _ = agent.PATCH("/", nil) 141 | if r.Method != http.MethodPatch { 142 | t.Fatalf("Method missmatch: %s", r.Method) 143 | } 144 | r, _ = agent.DELETE("/", nil) 145 | if r.Method != http.MethodDelete { 146 | t.Fatalf("Method missmatch: %s", r.Method) 147 | } 148 | } 149 | -------------------------------------------------------------------------------- /agent/cache.go: -------------------------------------------------------------------------------- 1 | package agent 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "io/ioutil" 7 | "net/http" 8 | "sort" 9 | "strings" 10 | "time" 11 | 12 | "github.com/pquerna/cachecontrol/cacheobject" 13 | ) 14 | 15 | var ( 16 | cacheableStatusCodes = 
map[int]bool{ 17 | 200: true, 18 | 203: true, 19 | 204: true, 20 | 206: true, 21 | 300: true, 22 | 301: true, 23 | 404: true, 24 | 405: true, 25 | 410: true, 26 | 414: true, 27 | 501: true, 28 | } 29 | ) 30 | 31 | type Cache struct { 32 | now time.Time 33 | body []byte 34 | res *http.Response 35 | ReqDirectives *cacheobject.RequestCacheDirectives 36 | ResDirectives *cacheobject.ResponseCacheDirectives 37 | Expires *time.Time 38 | Date *time.Time 39 | LastModified *time.Time 40 | ETag *string 41 | Varies []string 42 | VariesKey string 43 | } 44 | 45 | func newCache(res *http.Response, cachedBody []byte) (*Cache, error) { 46 | if res.StatusCode == 304 { 47 | res.Body = ioutil.NopCloser(bytes.NewReader(cachedBody)) 48 | return nil, nil 49 | } 50 | 51 | // Do not cache request without get method 52 | if res.Request.Method != http.MethodGet { 53 | return nil, nil 54 | } 55 | 56 | // Do not cache request with authorization header 57 | if auth := res.Request.Header.Get("Authorization"); auth != "" { 58 | return nil, nil 59 | } 60 | 61 | // Do not cache specified status code 62 | if _, found := cacheableStatusCodes[res.StatusCode]; !found { 63 | return nil, nil 64 | } 65 | 66 | resDirs, err := cacheobject.ParseResponseCacheControl(res.Header.Get("Cache-Control")) 67 | if err != nil { 68 | return nil, err 69 | } 70 | 71 | if resDirs.NoStore { 72 | return nil, nil 73 | } 74 | 75 | reqDirs, err := cacheobject.ParseRequestCacheControl(res.Request.Header.Get("Cache-Control")) 76 | if err != nil { 77 | return nil, err 78 | } 79 | 80 | cache := &Cache{ 81 | now: time.Now(), 82 | ReqDirectives: reqDirs, 83 | ResDirectives: resDirs, 84 | } 85 | 86 | if t, err := http.ParseTime(res.Header.Get("Expires")); err == nil { 87 | cache.Expires = &t 88 | } 89 | 90 | if t, err := http.ParseTime(res.Header.Get("Date")); err == nil { 91 | cache.Date = &t 92 | } 93 | 94 | if t, err := http.ParseTime(res.Header.Get("Last-Modified")); err == nil { 95 | cache.LastModified = &t 96 | } 97 | 98 | if 
etag := res.Header.Get("ETag"); len(etag) > 0 { 99 | cache.ETag = &etag 100 | } 101 | 102 | if cache.Expires == nil && cache.LastModified == nil && cache.ETag == nil && cache.ResDirectives.MaxAge == -1 { 103 | return nil, nil 104 | } 105 | 106 | varies := make([]string, 0, 3) 107 | for _, v := range res.Header.Values("Vary") { 108 | for _, k := range strings.Split(v, ",") { 109 | varies = append(varies, strings.TrimSpace(k)) 110 | } 111 | } 112 | sort.Strings(varies) 113 | cache.Varies = varies 114 | 115 | key := "" 116 | for _, h := range varies { 117 | key += strings.Join(res.Request.Header.Values(h), ", ") 118 | } 119 | cache.VariesKey = key 120 | 121 | cache.res = res 122 | if res.StatusCode == 304 { 123 | cache.body = cachedBody 124 | } else { 125 | cache.body, err = ioutil.ReadAll(res.Body) 126 | if err != nil && err != io.EOF { 127 | return nil, err 128 | } 129 | res.Body.Close() 130 | cachedBody = cache.body 131 | } 132 | res.Body = ioutil.NopCloser(bytes.NewReader(cachedBody)) 133 | 134 | return cache, nil 135 | } 136 | 137 | func (c *Cache) Body() []byte { 138 | if c != nil { 139 | return c.body 140 | } 141 | return nil 142 | } 143 | 144 | func (c *Cache) apply(req *http.Request) { 145 | if c.LastModified != nil { 146 | req.Header.Set("If-Modified-Since", c.LastModified.Format(http.TimeFormat)) 147 | } 148 | if c.ETag != nil { 149 | req.Header.Set("If-None-Match", *c.ETag) 150 | } 151 | } 152 | 153 | func (c *Cache) isOutdated() bool { 154 | now := time.Now().UTC() 155 | 156 | if c.ResDirectives.MaxAge <= 0 && c.Expires == nil { 157 | return true 158 | } 159 | 160 | if c.ResDirectives.MaxAge > 0 && now.After(c.now.Add(time.Duration(c.ResDirectives.MaxAge)*time.Second)) { 161 | return true 162 | } 163 | return (c.Expires != nil && now.After(*c.Expires)) 164 | } 165 | 166 | func (c *Cache) matchVariesKey(req *http.Request) bool { 167 | key := "" 168 | for _, h := range c.Varies { 169 | key += strings.Join(req.Header.Values(h), ", ") 170 | } 171 | 172 | 
return key == c.VariesKey 173 | } 174 | 175 | func (c *Cache) requiresRevalidate(req *http.Request) bool { 176 | return c.ResDirectives.MustRevalidate || !c.matchVariesKey(req) || c.isOutdated() 177 | } 178 | 179 | func (c *Cache) restoreResponse() *http.Response { 180 | var res http.Response 181 | res = *c.res 182 | res.Body = ioutil.NopCloser(bytes.NewReader(c.body)) 183 | return &res 184 | } 185 | -------------------------------------------------------------------------------- /agent/cache_store.go: -------------------------------------------------------------------------------- 1 | package agent 2 | 3 | import ( 4 | "net/http" 5 | "sync" 6 | ) 7 | 8 | type CacheStore interface { 9 | Get(*http.Request) *Cache 10 | Put(*http.Request, *Cache) 11 | Clear() 12 | } 13 | 14 | type cacheStore struct { 15 | mu sync.RWMutex 16 | table map[string]*Cache 17 | } 18 | 19 | func NewCacheStore() CacheStore { 20 | return &cacheStore{ 21 | mu: sync.RWMutex{}, 22 | table: make(map[string]*Cache), 23 | } 24 | } 25 | 26 | func (c *cacheStore) Get(r *http.Request) *Cache { 27 | c.mu.RLock() 28 | defer c.mu.RUnlock() 29 | 30 | if c, ok := c.table[r.URL.String()]; ok && c != nil { 31 | return c 32 | } 33 | 34 | return nil 35 | } 36 | 37 | func (c *cacheStore) Put(r *http.Request, cache *Cache) { 38 | c.mu.Lock() 39 | defer c.mu.Unlock() 40 | 41 | c.table[r.URL.String()] = cache 42 | } 43 | 44 | func (c *cacheStore) Clear() { 45 | c.mu.Lock() 46 | defer c.mu.Unlock() 47 | 48 | c.table = make(map[string]*Cache) 49 | } 50 | -------------------------------------------------------------------------------- /agent/cache_test.go: -------------------------------------------------------------------------------- 1 | package agent 2 | 3 | import ( 4 | "context" 5 | "io" 6 | "io/ioutil" 7 | "net/http" 8 | "net/http/httptest" 9 | "testing" 10 | "time" 11 | ) 12 | 13 | func req(a *Agent, method string, path string) (*http.Request, *http.Response, error) { 14 | req, err := a.NewRequest(method, path, 
nil) 15 | if err != nil { 16 | return nil, nil, err 17 | } 18 | res, err := a.Do(context.Background(), req) 19 | if err != nil { 20 | return req, nil, err 21 | } 22 | 23 | return req, res, nil 24 | } 25 | 26 | func get(a *Agent, path string) (*http.Request, *http.Response, error) { 27 | return req(a, http.MethodGet, path) 28 | } 29 | 30 | func TestCacheCondition(t *testing.T) { 31 | srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 32 | switch r.URL.Path { 33 | case "/no-store": 34 | w.Header().Set("Cache-Control", "no-store, max-age=100") 35 | case "/invalid": 36 | w.Header().Set("Cache-Control", "private, max-age=-10") 37 | default: 38 | w.Header().Set("Cache-Control", "public, max-age=1000") 39 | } 40 | w.WriteHeader(200) 41 | io.WriteString(w, "OK") 42 | })) 43 | defer srv.Close() 44 | 45 | agent, err := NewAgent(WithBaseURL(srv.URL), WithDefaultTransport()) 46 | if err != nil { 47 | t.Fatal(err) 48 | } 49 | 50 | r, _, _ := req(agent, "POST", "/") 51 | if cache := agent.CacheStore.Get(r); cache != nil { 52 | t.Fatalf("Stored invalid cache: %v", cache) 53 | } 54 | 55 | r, _ = agent.GET("/") 56 | r.Header.Set("Authorization", "Bearer X-TOKEN") 57 | agent.Do(context.Background(), r) 58 | if cache := agent.CacheStore.Get(r); cache != nil { 59 | t.Fatalf("Stored invalid cache: %v", cache) 60 | } 61 | 62 | r, _, _ = get(agent, "/no-store") 63 | if cache := agent.CacheStore.Get(r); cache != nil { 64 | t.Fatalf("Stored invalid cache: %v", cache) 65 | } 66 | 67 | r, _, _ = get(agent, "/invalid") 68 | if cache := agent.CacheStore.Get(r); cache != nil { 69 | t.Fatalf("Stored invalid cache: %v", cache) 70 | } 71 | 72 | r, _ = agent.GET("/") 73 | r.Header.Set("Cache-Control", "max-age=-1") 74 | agent.Do(context.Background(), r) 75 | if cache := agent.CacheStore.Get(r); cache != nil { 76 | t.Fatalf("Stored invalid cache: %v", cache) 77 | } 78 | } 79 | 80 | func TestCacheWithLastModified(t *testing.T) { 81 | lm := time.Now().UTC() 82 | 
lm = lm.Truncate(time.Second) 83 | 84 | srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 85 | w.Header().Set("Content-Type", "text/plain") 86 | w.Header().Set("Last-Modified", lm.Format(http.TimeFormat)) 87 | ims, _ := http.ParseTime(r.Header.Get("If-Modified-Since")) 88 | 89 | if !lm.Equal(ims) { 90 | w.WriteHeader(http.StatusOK) 91 | 92 | io.WriteString(w, "Hello, World") 93 | } else { 94 | w.WriteHeader(http.StatusNotModified) 95 | } 96 | })) 97 | defer srv.Close() 98 | 99 | agent, err := NewAgent(WithBaseURL(srv.URL), WithDefaultTransport()) 100 | if err != nil { 101 | t.Fatal(err) 102 | } 103 | 104 | _, res, err := get(agent, "/") 105 | if err != nil { 106 | t.Fatal(err) 107 | } 108 | 109 | if res.StatusCode != 200 { 110 | t.Fatalf("status code missmatch: %d", res.StatusCode) 111 | } 112 | 113 | _, res, err = get(agent, "/") 114 | if err != nil { 115 | t.Fatal(err) 116 | } 117 | 118 | if res.StatusCode != 304 { 119 | t.Fatalf("status code missmatch: %d", res.StatusCode) 120 | } 121 | 122 | body, err := ioutil.ReadAll(res.Body) 123 | if err != nil { 124 | t.Fatalf("read body: %+v", err) 125 | } 126 | 127 | if string(body) != "Hello, World" { 128 | t.Fatalf("body missmatch: %x", body) 129 | } 130 | 131 | _, res, err = get(agent, "/") 132 | if err != nil { 133 | t.Fatal(err) 134 | } 135 | 136 | if res.StatusCode != 304 { 137 | t.Fatalf("status code missmatch: %d", res.StatusCode) 138 | } 139 | 140 | body, err = ioutil.ReadAll(res.Body) 141 | if err != nil { 142 | t.Fatalf("read body: %+v", err) 143 | } 144 | 145 | if string(body) != "Hello, World" { 146 | t.Fatalf("body missmatch: %x", body) 147 | } 148 | } 149 | 150 | func TestCacheWithETag(t *testing.T) { 151 | etag := "W/deadbeaf" 152 | 153 | srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 154 | w.Header().Set("Content-Type", "text/plain") 155 | w.Header().Set("ETag", etag) 156 | inm := r.Header.Get("If-None-Match") 157 | 158 | 
if etag != inm { 159 | w.WriteHeader(http.StatusOK) 160 | 161 | io.WriteString(w, "Hello, World") 162 | } else { 163 | w.WriteHeader(http.StatusNotModified) 164 | } 165 | })) 166 | defer srv.Close() 167 | 168 | agent, err := NewAgent(WithBaseURL(srv.URL), WithDefaultTransport()) 169 | if err != nil { 170 | t.Fatal(err) 171 | } 172 | 173 | _, res, err := get(agent, "/") 174 | if err != nil { 175 | t.Fatal(err) 176 | } 177 | 178 | if res.StatusCode != 200 { 179 | t.Fatalf("status code missmatch: %d", res.StatusCode) 180 | } 181 | 182 | _, res, err = get(agent, "/") 183 | if err != nil { 184 | t.Fatal(err) 185 | } 186 | 187 | if res.StatusCode != 304 { 188 | t.Fatalf("status code missmatch: %d", res.StatusCode) 189 | } 190 | 191 | _, res, err = get(agent, "/") 192 | if err != nil { 193 | t.Fatal(err) 194 | } 195 | 196 | if res.StatusCode != 304 { 197 | t.Fatalf("status code missmatch: %d", res.StatusCode) 198 | } 199 | 200 | body, err := ioutil.ReadAll(res.Body) 201 | if err != nil { 202 | t.Fatalf("read body: %+v", err) 203 | } 204 | 205 | if string(body) != "Hello, World" { 206 | t.Fatalf("body missmatch: %x", body) 207 | } 208 | 209 | _, res, err = get(agent, "/") 210 | if err != nil { 211 | t.Fatal(err) 212 | } 213 | 214 | if res.StatusCode != 304 { 215 | t.Fatalf("status code missmatch: %d", res.StatusCode) 216 | } 217 | 218 | body, err = ioutil.ReadAll(res.Body) 219 | if err != nil { 220 | t.Fatalf("read body: %+v", err) 221 | } 222 | 223 | if string(body) != "Hello, World" { 224 | t.Fatalf("body missmatch: %x", body) 225 | } 226 | 227 | } 228 | 229 | func TestCacheWithMaxAge(t *testing.T) { 230 | reqCount := 0 231 | srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 232 | w.Header().Set("Content-Type", "text/plain") 233 | w.Header().Set("Cache-Control", "max-age=2") 234 | w.WriteHeader(http.StatusOK) 235 | 236 | io.WriteString(w, "Hello, World") 237 | 238 | reqCount++ 239 | })) 240 | defer srv.Close() 241 | 242 | agent, 
err := NewAgent(WithBaseURL(srv.URL), WithDefaultTransport()) 243 | if err != nil { 244 | t.Fatal(err) 245 | } 246 | 247 | _, res, err := get(agent, "/") 248 | if err != nil { 249 | t.Fatal(err) 250 | } 251 | 252 | if res.StatusCode != 200 { 253 | t.Fatalf("status code missmatch: %d", res.StatusCode) 254 | } 255 | 256 | req, res, err := get(agent, "/") 257 | if err != nil { 258 | t.Fatal(err) 259 | } 260 | 261 | if res.StatusCode != 200 { 262 | t.Fatalf("status code missmatch: %d", res.StatusCode) 263 | } 264 | 265 | c := agent.CacheStore.Get(req) 266 | c.now = time.Now().Add(-3 * time.Second) 267 | 268 | get(agent, "/") 269 | 270 | if reqCount != 2 { 271 | t.Fatalf("missmatch req count: %d", reqCount) 272 | } 273 | } 274 | 275 | func TestCacheWithExpires(t *testing.T) { 276 | reqCount := 0 277 | srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 278 | w.Header().Set("Content-Type", "text/plain") 279 | w.Header().Set("Expires", time.Now().UTC().Add(1*time.Second).Format(http.TimeFormat)) 280 | w.WriteHeader(http.StatusOK) 281 | 282 | io.WriteString(w, "Hello, World") 283 | 284 | reqCount++ 285 | })) 286 | defer srv.Close() 287 | 288 | agent, err := NewAgent(WithBaseURL(srv.URL), WithDefaultTransport()) 289 | if err != nil { 290 | t.Fatal(err) 291 | } 292 | 293 | _, res, err := get(agent, "/") 294 | if err != nil { 295 | t.Fatal(err) 296 | } 297 | 298 | if res.StatusCode != 200 { 299 | t.Fatalf("status code missmatch: %d", res.StatusCode) 300 | } 301 | 302 | _, res, err = get(agent, "/") 303 | if err != nil { 304 | t.Fatal(err) 305 | } 306 | 307 | if res.StatusCode != 200 { 308 | t.Fatalf("status code missmatch: %d", res.StatusCode) 309 | } 310 | 311 | <-time.After(1 * time.Second) 312 | get(agent, "/") 313 | 314 | if reqCount != 2 { 315 | t.Fatalf("missmatch req count: %d", reqCount) 316 | } 317 | } 318 | 319 | func TestCacheWithVary(t *testing.T) { 320 | reqCount := 0 321 | srv := httptest.NewServer(http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { 322 | w.Header().Set("Content-Type", "text/plain") 323 | w.Header().Set("Cache-Control", "max-age=200000, public") 324 | w.Header().Add("Vary", "User-Agent") 325 | w.Header().Add("Vary", "X-Cache-Count") 326 | w.WriteHeader(http.StatusOK) 327 | 328 | io.WriteString(w, "Hello, World") 329 | 330 | reqCount++ 331 | })) 332 | defer srv.Close() 333 | 334 | agent, err := NewAgent(WithBaseURL(srv.URL), WithDefaultTransport()) 335 | if err != nil { 336 | t.Fatal(err) 337 | } 338 | 339 | req, err := agent.GET("/") 340 | if err != nil { 341 | t.Fatal(err) 342 | } 343 | 344 | ctx := context.Background() 345 | a := req.Clone(ctx) 346 | agent.Do(ctx, a) 347 | a = req.Clone(ctx) 348 | a.Header.Set("User-Agent", "Hoge") 349 | agent.Do(ctx, a) 350 | a = req.Clone(ctx) 351 | a.Header.Set("X-Cache-Count", "3") 352 | agent.Do(ctx, a) 353 | 354 | if reqCount != 3 { 355 | t.Fatalf("missmatch req count: %d", reqCount) 356 | } 357 | } 358 | 359 | func TestCacheWithClear(t *testing.T) { 360 | reqCount := 0 361 | srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 362 | w.Header().Set("Content-Type", "text/plain") 363 | w.Header().Set("Cache-Control", "max-age=20000") 364 | w.WriteHeader(http.StatusOK) 365 | 366 | io.WriteString(w, "Hello, World") 367 | 368 | reqCount++ 369 | })) 370 | defer srv.Close() 371 | 372 | agent, err := NewAgent(WithBaseURL(srv.URL), WithDefaultTransport()) 373 | if err != nil { 374 | t.Fatal(err) 375 | } 376 | 377 | get(agent, "/") 378 | agent.CacheStore.Clear() 379 | get(agent, "/") 380 | 381 | if reqCount != 2 { 382 | t.Fatalf("missmatch req count: %d", reqCount) 383 | } 384 | } 385 | 386 | func BenchmarkCacheWithMaxAge(b *testing.B) { 387 | srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 388 | w.Header().Set("Content-Type", "text/plain") 389 | w.Header().Set("Cache-Control", "max-age=20000") 390 | w.WriteHeader(http.StatusOK) 391 | 392 | 
io.WriteString(w, "Hello, World") 393 | })) 394 | defer srv.Close() 395 | 396 | agent, err := NewAgent(WithBaseURL(srv.URL), WithDefaultTransport()) 397 | if err != nil { 398 | b.Fatal(err) 399 | } 400 | 401 | _, res, err := get(agent, "/") 402 | if err != nil { 403 | b.Fatal(err) 404 | } 405 | 406 | if res.StatusCode != 200 { 407 | b.Fatalf("status code missmatch: %d", res.StatusCode) 408 | } 409 | 410 | b.ResetTimer() 411 | for i := 0; i < b.N; i++ { 412 | _, res, err := get(agent, "/") 413 | if err != nil { 414 | b.Fatal(err) 415 | } 416 | 417 | if res.StatusCode != 200 && res.StatusCode != 304 { 418 | b.Fatalf("status code missmatch: %d", res.StatusCode) 419 | } 420 | 421 | body, err := ioutil.ReadAll(res.Body) 422 | if err != nil { 423 | b.Fatal(err) 424 | } 425 | 426 | if string(body) != "Hello, World" { 427 | b.Fatal("body missmatch") 428 | } 429 | } 430 | } 431 | -------------------------------------------------------------------------------- /agent/decompress.go: -------------------------------------------------------------------------------- 1 | package agent 2 | 3 | import ( 4 | "compress/flate" 5 | "compress/gzip" 6 | "fmt" 7 | "io" 8 | "net/http" 9 | "strings" 10 | 11 | "github.com/dsnet/compress/brotli" 12 | ) 13 | 14 | func decompress(res *http.Response) (*http.Response, error) { 15 | contentEncoding := res.Header.Get("Content-Encoding") 16 | if contentEncoding == "" { 17 | return res, nil 18 | } 19 | 20 | var err error 21 | var body io.ReadCloser = res.Body 22 | 23 | encodings := strings.Split(contentEncoding, ",") 24 | for i := len(encodings) - 1; i >= 0; i-- { 25 | encoding := encodings[i] 26 | switch strings.TrimSpace(encoding) { 27 | case "br": 28 | body, err = brotli.NewReader(body, &brotli.ReaderConfig{}) 29 | case "gzip": 30 | body = &gzipReader{body: body} 31 | case "deflate": 32 | body = flate.NewReader(body) 33 | case "identity", "": 34 | // nop 35 | default: 36 | err = fmt.Errorf("unknown content encoding: %s: %w", encoding, 
ErrUnknownContentEncoding) 37 | } 38 | 39 | if err != nil { 40 | return nil, err 41 | } 42 | } 43 | 44 | res.Header.Del("Content-Length") 45 | res.ContentLength = -1 46 | res.Uncompressed = true 47 | res.Body = body 48 | 49 | return res, nil 50 | } 51 | 52 | type gzipReader struct { 53 | body io.ReadCloser 54 | zr *gzip.Reader 55 | zerr error 56 | } 57 | 58 | func (gz *gzipReader) Read(p []byte) (int, error) { 59 | if gz.zr == nil { 60 | var err error 61 | gz.zr, err = gzip.NewReader(gz.body) 62 | if err != nil { 63 | return 0, err 64 | } 65 | } 66 | 67 | return gz.zr.Read(p) 68 | } 69 | 70 | func (gz *gzipReader) Close() error { 71 | return gz.body.Close() 72 | } 73 | -------------------------------------------------------------------------------- /agent/decompress_test.go: -------------------------------------------------------------------------------- 1 | package agent 2 | 3 | import ( 4 | "bytes" 5 | "compress/flate" 6 | "compress/gzip" 7 | "context" 8 | "errors" 9 | "io" 10 | "io/ioutil" 11 | "net/http" 12 | "net/http/httptest" 13 | "testing" 14 | "time" 15 | 16 | "github.com/andybalholm/brotli" 17 | "github.com/julienschmidt/httprouter" 18 | "github.com/labstack/echo/v4" 19 | "github.com/labstack/echo/v4/middleware" 20 | ) 21 | 22 | func newCompressHTTPServer() *httptest.Server { 23 | r := httprouter.New() 24 | 25 | r.GET("/br", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { 26 | w.Header().Set("Content-Type", "text/plain") 27 | w.Header().Set("Content-Encoding", "br") 28 | w.Header().Set("X-Content-Type-Options", "nosniff") 29 | w.Header().Set("Transfer-Encoding", "chunked") 30 | w.WriteHeader(200) 31 | bw := brotli.NewWriter(w) 32 | defer bw.Close() 33 | io.WriteString(bw, "test it") 34 | }) 35 | r.GET("/broken-br", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { 36 | w.Header().Set("Content-Type", "text/plain") 37 | w.Header().Set("Content-Encoding", "br") 38 | w.Header().Set("X-Content-Type-Options", "nosniff") 39 
| w.Header().Set("Transfer-Encoding", "chunked") 40 | w.WriteHeader(200) 41 | io.WriteString(w, "test it") 42 | }) 43 | r.GET("/gzip", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { 44 | gw := gzip.NewWriter(w) 45 | defer gw.Close() 46 | 47 | w.Header().Set("Content-Type", "text/plain") 48 | w.Header().Set("Content-Encoding", "gzip") 49 | w.Header().Set("X-Content-Type-Options", "nosniff") 50 | w.Header().Set("Transfer-Encoding", "chunked") 51 | w.WriteHeader(200) 52 | io.WriteString(gw, "test it") 53 | }) 54 | r.GET("/broken-gzip", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { 55 | w.Header().Set("Content-Type", "text/plain") 56 | w.Header().Set("Content-Encoding", "gzip") 57 | w.Header().Set("X-Content-Type-Options", "nosniff") 58 | w.Header().Set("Transfer-Encoding", "chunked") 59 | w.WriteHeader(200) 60 | io.WriteString(w, "test it") 61 | }) 62 | r.GET("/deflate", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { 63 | fw, err := flate.NewWriter(w, 9) 64 | if err != nil { 65 | io.WriteString(w, err.Error()) 66 | w.WriteHeader(http.StatusInternalServerError) 67 | return 68 | } 69 | defer fw.Close() 70 | 71 | w.Header().Set("Content-Type", "text/plain") 72 | w.Header().Set("Content-Encoding", "deflate") 73 | w.Header().Set("X-Content-Type-Options", "nosniff") 74 | w.Header().Set("Transfer-Encoding", "chunked") 75 | w.WriteHeader(200) 76 | io.WriteString(fw, "test it") 77 | }) 78 | r.GET("/broken-deflate", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { 79 | w.Header().Set("Content-Type", "text/plain") 80 | w.Header().Set("Content-Encoding", "deflate") 81 | w.Header().Set("X-Content-Type-Options", "nosniff") 82 | w.Header().Set("Transfer-Encoding", "chunked") 83 | w.WriteHeader(200) 84 | io.WriteString(w, "test it") 85 | }) 86 | r.GET("/identity", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { 87 | w.Header().Set("Content-Type", "text/plain") 88 | 
w.Header().Set("Content-Encoding", "identity") 89 | w.Header().Set("X-Content-Type-Options", "nosniff") 90 | w.Header().Set("Transfer-Encoding", "chunked") 91 | w.WriteHeader(200) 92 | io.WriteString(w, "test it") 93 | }) 94 | r.GET("/multiple", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { 95 | gw := gzip.NewWriter(w) 96 | defer gw.Close() 97 | 98 | fw, err := flate.NewWriter(gw, 9) 99 | if err != nil { 100 | io.WriteString(w, err.Error()) 101 | w.WriteHeader(http.StatusInternalServerError) 102 | return 103 | } 104 | defer fw.Close() 105 | 106 | w.Header().Set("Content-Type", "text/plain") 107 | w.Header().Set("Content-Encoding", "deflate, gzip") 108 | w.Header().Set("X-Content-Type-Options", "nosniff") 109 | w.Header().Set("Transfer-Encoding", "chunked") 110 | w.WriteHeader(200) 111 | io.WriteString(fw, "test it") 112 | }) 113 | r.GET("/unknown", func(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { 114 | w.Header().Set("Content-Type", "text/plain") 115 | w.Header().Set("Content-Encoding", "unknown") 116 | w.Header().Set("X-Content-Type-Options", "nosniff") 117 | w.Header().Set("Transfer-Encoding", "chunked") 118 | w.WriteHeader(200) 119 | io.WriteString(w, "test it") 120 | }) 121 | 122 | return httptest.NewServer(r) 123 | } 124 | 125 | func TestBrotliResponse(t *testing.T) { 126 | srv := newCompressHTTPServer() 127 | defer srv.Close() 128 | 129 | agent, err := NewAgent(WithBaseURL(srv.URL), WithDefaultTransport()) 130 | if err != nil { 131 | t.Fatalf("%+v", err) 132 | } 133 | 134 | req, err := agent.GET("/br") 135 | if err != nil { 136 | t.Fatalf("%+v", err) 137 | } 138 | 139 | res, err := agent.Do(context.Background(), req) 140 | if err != nil { 141 | t.Fatalf("%+v", err) 142 | } 143 | 144 | if res.StatusCode != 200 { 145 | t.Fatalf("%#v", res) 146 | } 147 | defer res.Body.Close() 148 | 149 | body, err := ioutil.ReadAll(res.Body) 150 | if err != nil { 151 | t.Fatalf("%+v", err) 152 | } 153 | 154 | if bytes.Compare(body, 
[]byte("test it")) != 0 { 155 | t.Fatalf("%s missmatch %s", body, "test it") 156 | } 157 | 158 | _, res, err = get(agent, "/broken-br") 159 | if err != nil { 160 | t.Fatal(err) 161 | } 162 | _, err = ioutil.ReadAll(res.Body) 163 | if err == nil { 164 | t.Fatalf("Not raised error with broken encoding") 165 | } 166 | } 167 | 168 | func TestGzipResponse(t *testing.T) { 169 | srv := newCompressHTTPServer() 170 | defer srv.Close() 171 | 172 | agent, err := NewAgent(WithBaseURL(srv.URL), WithDefaultTransport()) 173 | if err != nil { 174 | t.Fatalf("%+v", err) 175 | } 176 | 177 | req, err := agent.GET("/gzip") 178 | if err != nil { 179 | t.Fatalf("%+v", err) 180 | } 181 | 182 | res, err := agent.Do(context.Background(), req) 183 | if err != nil { 184 | t.Fatalf("%+v", err) 185 | } 186 | 187 | if res.StatusCode != 200 { 188 | t.Fatalf("%#v", res) 189 | } 190 | defer res.Body.Close() 191 | 192 | body, err := ioutil.ReadAll(res.Body) 193 | if err != nil { 194 | t.Fatalf("%+v", err) 195 | } 196 | 197 | if bytes.Compare(body, []byte("test it")) != 0 { 198 | t.Fatalf("%s missmatch %s", body, "test it") 199 | } 200 | 201 | _, res, err = get(agent, "/broken-gzip") 202 | if err != nil { 203 | t.Fatal(err) 204 | } 205 | 206 | body, err = ioutil.ReadAll(res.Body) 207 | if err == nil { 208 | t.Fatalf("Not raised error with broken encoding") 209 | } 210 | } 211 | 212 | func TestDeflateResponse(t *testing.T) { 213 | srv := newCompressHTTPServer() 214 | defer srv.Close() 215 | 216 | agent, err := NewAgent(WithBaseURL(srv.URL), WithDefaultTransport()) 217 | if err != nil { 218 | t.Fatalf("%+v", err) 219 | } 220 | 221 | req, err := agent.GET("/deflate") 222 | if err != nil { 223 | t.Fatalf("%+v", err) 224 | } 225 | 226 | res, err := agent.Do(context.Background(), req) 227 | if err != nil { 228 | t.Fatalf("%+v", err) 229 | } 230 | 231 | if res.StatusCode != 200 { 232 | t.Fatalf("%#v", res) 233 | } 234 | defer res.Body.Close() 235 | 236 | body, err := ioutil.ReadAll(res.Body) 237 | if err 
!= nil { 238 | t.Fatalf("%+v", err) 239 | } 240 | 241 | if bytes.Compare(body, []byte("test it")) != 0 { 242 | t.Fatalf("%s missmatch %s", body, "test it") 243 | } 244 | 245 | _, res, err = get(agent, "/broken-deflate") 246 | if err != nil { 247 | t.Fatal(err) 248 | } 249 | 250 | body, err = ioutil.ReadAll(res.Body) 251 | if err == nil { 252 | t.Fatalf("Not raised error with broken encoding") 253 | } 254 | } 255 | 256 | func TestIdentityResponse(t *testing.T) { 257 | srv := newCompressHTTPServer() 258 | defer srv.Close() 259 | 260 | agent, err := NewAgent(WithBaseURL(srv.URL), WithDefaultTransport()) 261 | if err != nil { 262 | t.Fatalf("%+v", err) 263 | } 264 | 265 | req, err := agent.GET("/identity") 266 | if err != nil { 267 | t.Fatalf("%+v", err) 268 | } 269 | 270 | res, err := agent.Do(context.Background(), req) 271 | if err != nil { 272 | t.Fatalf("%+v", err) 273 | } 274 | 275 | if res.StatusCode != 200 { 276 | t.Fatalf("%#v", res) 277 | } 278 | defer res.Body.Close() 279 | 280 | body, err := ioutil.ReadAll(res.Body) 281 | if err != nil { 282 | t.Fatalf("%+v", err) 283 | } 284 | 285 | if bytes.Compare(body, []byte("test it")) != 0 { 286 | t.Fatalf("%s missmatch %s", body, "test it") 287 | } 288 | } 289 | 290 | func TestMultipleResponse(t *testing.T) { 291 | srv := newCompressHTTPServer() 292 | defer srv.Close() 293 | 294 | agent, err := NewAgent(WithBaseURL(srv.URL), WithDefaultTransport()) 295 | if err != nil { 296 | t.Fatalf("%+v", err) 297 | } 298 | 299 | req, err := agent.GET("/multiple") 300 | if err != nil { 301 | t.Fatalf("%+v", err) 302 | } 303 | 304 | res, err := agent.Do(context.Background(), req) 305 | if err != nil { 306 | t.Fatalf("%+v", err) 307 | } 308 | 309 | if res.StatusCode != 200 { 310 | t.Fatalf("%#v", res) 311 | } 312 | defer res.Body.Close() 313 | 314 | body, err := ioutil.ReadAll(res.Body) 315 | if err != nil { 316 | t.Fatalf("%+v", err) 317 | } 318 | 319 | if bytes.Compare(body, []byte("test it")) != 0 { 320 | t.Fatalf("%s missmatch 
%s", body, "test it") 321 | } 322 | } 323 | 324 | func TestUnknownResponse(t *testing.T) { 325 | srv := newCompressHTTPServer() 326 | defer srv.Close() 327 | 328 | agent, err := NewAgent(WithBaseURL(srv.URL), WithDefaultTransport()) 329 | if err != nil { 330 | t.Fatalf("%+v", err) 331 | } 332 | 333 | req, err := agent.GET("/unknown") 334 | if err != nil { 335 | t.Fatalf("%+v", err) 336 | } 337 | 338 | _, err = agent.Do(context.Background(), req) 339 | if err == nil { 340 | t.Fatalf("expected error but err is nil, %+v", err) 341 | } else { 342 | if !errors.Is(err, ErrUnknownContentEncoding) { 343 | t.Fatalf("%+v", err) 344 | } 345 | } 346 | } 347 | 348 | func TestWithEcho(t *testing.T) { 349 | e := echo.New() 350 | e.GET("/", func(c echo.Context) error { 351 | c.Response().Header().Set("Cache-Control", "public, max-age=10000") 352 | return c.String(200, "test it") 353 | }) 354 | e.Use(middleware.Gzip()) 355 | srv := httptest.NewServer(e) 356 | defer srv.Close() 357 | 358 | agent, err := NewAgent(WithBaseURL(srv.URL), WithDefaultTransport()) 359 | if err != nil { 360 | t.Fatalf("%+v", err) 361 | } 362 | 363 | for i := 0; i < 3; i++ { 364 | req, err := agent.GET("/") 365 | if err != nil { 366 | t.Fatalf("%+v", err) 367 | } 368 | 369 | res, err := agent.Do(context.Background(), req) 370 | if err != nil { 371 | t.Fatalf("%+v", err) 372 | } 373 | 374 | if res.StatusCode != 200 { 375 | t.Fatalf("%#v", res) 376 | } 377 | defer res.Body.Close() 378 | 379 | t.Logf("%+v", res) 380 | 381 | body, err := ioutil.ReadAll(res.Body) 382 | if err != nil { 383 | t.Fatalf("%+v", err) 384 | } 385 | 386 | if bytes.Compare(body, []byte("test it")) != 0 { 387 | t.Fatalf("%s missmatch %s", body, "test it") 388 | } 389 | <-time.After(1 * time.Second) 390 | } 391 | } 392 | -------------------------------------------------------------------------------- /agent/html.go: -------------------------------------------------------------------------------- 1 | package agent 2 | 3 | import ( 4 | 
"context" 5 | "io" 6 | "net/http" 7 | "net/url" 8 | "sync" 9 | "sync/atomic" 10 | 11 | "github.com/isucon/isucandar/failure" 12 | "golang.org/x/net/html" 13 | "golang.org/x/net/html/atom" 14 | ) 15 | 16 | type Resource struct { 17 | InitiatorType string 18 | Request *http.Request 19 | Response *http.Response 20 | Error error 21 | } 22 | 23 | type Resources map[string]*Resource 24 | 25 | func (a *Agent) ProcessHTML(ctx context.Context, r *http.Response, body io.ReadCloser) (Resources, error) { 26 | defer body.Close() 27 | 28 | wg := sync.WaitGroup{} 29 | mu := sync.Mutex{} 30 | resources := make(Resources) 31 | base := &*r.Request.URL 32 | baseChanged := false 33 | 34 | n := int32(0) 35 | favicon := &n 36 | 37 | resourceCollect := func(token html.Token) { 38 | defer wg.Done() 39 | var res *Resource 40 | switch token.DataAtom { 41 | case atom.Link: 42 | res = a.processHTMLLink(ctx, base, token) 43 | case atom.Script: 44 | res = a.processHTMLScript(ctx, base, token) 45 | case atom.Img: 46 | res = a.processHTMLImage(ctx, base, token) 47 | } 48 | 49 | if res != nil && res.Request != nil { 50 | if res.InitiatorType == "favicon" { 51 | atomic.StoreInt32(favicon, 1) 52 | } 53 | mu.Lock() 54 | resources[res.Request.URL.String()] = res 55 | mu.Unlock() 56 | } 57 | } 58 | 59 | doc := html.NewTokenizer(body) 60 | for tokenType := doc.Next(); tokenType != html.ErrorToken; tokenType = doc.Next() { 61 | token := doc.Token() 62 | if token.Type == html.StartTagToken || token.Type == html.SelfClosingTagToken { 63 | switch token.DataAtom { 64 | case atom.Base: 65 | if baseChanged { 66 | break 67 | } 68 | baseChanged = true 69 | href := "" 70 | for _, attr := range token.Attr { 71 | switch attr.Key { 72 | case "href": 73 | href = attr.Val 74 | } 75 | } 76 | if href != "" { 77 | newBaseURL, err := url.Parse(href) 78 | if err == nil { 79 | base = base.ResolveReference(newBaseURL) 80 | } 81 | } 82 | case atom.Link, atom.Script, atom.Img: 83 | wg.Add(1) 84 | go resourceCollect(token) 85 | 
} 86 | 87 | } 88 | } 89 | 90 | wg.Wait() 91 | 92 | // Automated favicon fetcher 93 | if atomic.LoadInt32(favicon) == 0 { 94 | if res := a.getResource(ctx, base, "/favicon.ico", "favicon"); res != nil && res.Request != nil { 95 | resources[res.Request.URL.String()] = res 96 | } 97 | } 98 | 99 | err := doc.Err() 100 | if failure.Is(err, io.EOF) { 101 | err = nil 102 | } 103 | return resources, err 104 | } 105 | 106 | func (a *Agent) processHTMLLink(ctx context.Context, base *url.URL, token html.Token) *Resource { 107 | rel := "" 108 | href := "" 109 | for _, attr := range token.Attr { 110 | switch attr.Key { 111 | case "rel": 112 | rel = attr.Val 113 | case "href": 114 | href = attr.Val 115 | } 116 | } 117 | 118 | switch rel { 119 | case "stylesheet": 120 | return a.getResource(ctx, base, href, "stylesheet") 121 | case "icon", "shortcut icon": 122 | return a.getResource(ctx, base, href, "favicon") 123 | case "apple-touch-icon", "apple-touch-icon-precomposed": 124 | return a.getResource(ctx, base, href, "apple-touch-icon") 125 | case "manifest": 126 | return a.getResource(ctx, base, href, "manifest") 127 | case "modulepreload": 128 | return a.getResource(ctx, base, href, "modulepreload") 129 | } 130 | 131 | return nil 132 | } 133 | 134 | func (a *Agent) processHTMLScript(ctx context.Context, base *url.URL, token html.Token) *Resource { 135 | src := "" 136 | for _, attr := range token.Attr { 137 | switch attr.Key { 138 | case "src": 139 | src = attr.Val 140 | } 141 | } 142 | 143 | if src == "" { 144 | return nil 145 | } 146 | 147 | return a.getResource(ctx, base, src, "script") 148 | } 149 | 150 | func (a *Agent) processHTMLImage(ctx context.Context, base *url.URL, token html.Token) *Resource { 151 | src := "" 152 | lazy := false // loading="lazy" 153 | for _, attr := range token.Attr { 154 | switch attr.Key { 155 | case "src": 156 | src = attr.Val 157 | case "loading": 158 | lazy = attr.Val == "lazy" 159 | } 160 | } 161 | 162 | if lazy || src == "" { 163 | return nil 
164 | } 165 | 166 | return a.getResource(ctx, base, src, "img") 167 | } 168 | 169 | func (a *Agent) getResource(ctx context.Context, base *url.URL, ref string, initiatorType string) (res *Resource) { 170 | res = &Resource{ 171 | InitiatorType: initiatorType, 172 | } 173 | 174 | refURL, err := url.Parse(ref) 175 | if err != nil { 176 | res.Error = err 177 | return 178 | } 179 | refURL = base.ResolveReference(refURL) 180 | 181 | hreq, err := a.GET(refURL.String()) 182 | if err != nil { 183 | res.Error = err 184 | return 185 | } 186 | res.Request = hreq 187 | 188 | hres, err := a.Do(ctx, hreq) 189 | if err != nil && err != io.EOF { 190 | res.Error = err 191 | return 192 | } 193 | res.Response = hres 194 | 195 | return 196 | } 197 | -------------------------------------------------------------------------------- /agent/html_test.go: -------------------------------------------------------------------------------- 1 | package agent 2 | 3 | import ( 4 | "context" 5 | "io" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | ) 10 | 11 | const exampleHTMLDoc = ` 12 | 13 | 14 | 15 | This is agent test 16 | 17 | 18 | stylesheet じゃないならロードしない 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 |

Hello, World

29 | 30 | 31 | 32 | 33 | loading=lazy ならロードしない 34 | 35 | 36 | 37 | 38 | 39 | インラインスクリプトは無視する 40 | 41 | 42 | 43 | ` 44 | 45 | func TestHTMLParse(t *testing.T) { 46 | srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 47 | w.WriteHeader(200) 48 | io.WriteString(w, exampleHTMLDoc) 49 | })) 50 | defer srv.Close() 51 | 52 | agent, err := NewAgent(WithBaseURL(srv.URL), WithDefaultTransport()) 53 | if err != nil { 54 | t.Fatal(err) 55 | } 56 | 57 | _, res, err := get(agent, "/test.html") 58 | if err != nil { 59 | t.Fatal(err) 60 | } 61 | 62 | resources, err := agent.ProcessHTML(context.Background(), res, res.Body) 63 | if err != nil { 64 | t.Fatal(err) 65 | } 66 | 67 | if len(resources) != 12 { 68 | for k := range resources { 69 | t.Log(k) 70 | } 71 | t.Fatalf("resouces count missmatch: %d", len(resources)) 72 | } 73 | 74 | expects := []string{ 75 | srv.URL + "/root.css", 76 | srv.URL + "/sub/alt.css", 77 | srv.URL + "/cute.png", 78 | srv.URL + "/sub/dir/beautiful.png", 79 | srv.URL + "/need.js", 80 | srv.URL + "/defer.js", 81 | srv.URL + "/async.js", 82 | srv.URL + "/favicon.ico", 83 | srv.URL + "/apple-icon-precomposed.png", 84 | srv.URL + "/apple-icon.png", 85 | srv.URL + "/manifest.webmanifest", 86 | srv.URL + "/modulepreload.js", 87 | } 88 | 89 | for _, eURL := range expects { 90 | if _, ok := resources[eURL]; !ok { 91 | t.Fatalf("resouce not reached: %s", eURL) 92 | } 93 | } 94 | } 95 | 96 | func BenchmarkHTMLParse(b *testing.B) { 97 | srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 98 | 99 | w.WriteHeader(200) 100 | io.WriteString(w, exampleHTMLDoc) 101 | })) 102 | defer srv.Close() 103 | 104 | agent, err := NewAgent(WithBaseURL(srv.URL), WithDefaultTransport()) 105 | if err != nil { 106 | b.Fatal(err) 107 | } 108 | 109 | for i := 0; i < b.N; i++ { 110 | _, res, err := get(agent, "/test.html") 111 | if err != nil { 112 | b.Fatal(err) 113 | } 114 | 115 | resources, err := 
agent.ProcessHTML(context.Background(), res, res.Body) 116 | if err != nil { 117 | b.Fatal(err) 118 | } 119 | 120 | if len(resources) != 12 { 121 | for k := range resources { 122 | b.Log(k) 123 | } 124 | b.Fatalf("resouces count missmatch: %d", len(resources)) 125 | } 126 | 127 | expects := []string{ 128 | srv.URL + "/root.css", 129 | srv.URL + "/sub/alt.css", 130 | srv.URL + "/cute.png", 131 | srv.URL + "/sub/dir/beautiful.png", 132 | srv.URL + "/need.js", 133 | srv.URL + "/defer.js", 134 | srv.URL + "/async.js", 135 | srv.URL + "/favicon.ico", 136 | srv.URL + "/apple-icon-precomposed.png", 137 | srv.URL + "/apple-icon.png", 138 | srv.URL + "/modulepreload.js", 139 | } 140 | 141 | for _, eURL := range expects { 142 | if _, ok := resources[eURL]; !ok { 143 | b.Fatalf("resouce not reached: %s", eURL) 144 | } 145 | } 146 | } 147 | } 148 | 149 | const exampleFaviconDoc = ` 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | ` 158 | 159 | func TestFavicon(t *testing.T) { 160 | srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 161 | w.WriteHeader(200) 162 | io.WriteString(w, exampleFaviconDoc) 163 | })) 164 | defer srv.Close() 165 | 166 | agent, err := NewAgent(WithBaseURL(srv.URL), WithDefaultTransport()) 167 | if err != nil { 168 | t.Fatal(err) 169 | } 170 | 171 | _, res, err := get(agent, "/test.html") 172 | if err != nil { 173 | t.Fatal(err) 174 | } 175 | 176 | resources, err := agent.ProcessHTML(context.Background(), res, res.Body) 177 | if err != nil { 178 | t.Fatal(err) 179 | } 180 | 181 | if len(resources) != 2 { 182 | t.Fatalf("resouces count missmatch: %d", len(resources)) 183 | } 184 | 185 | expects := []string{ 186 | srv.URL + "/x-favicon.ico", 187 | srv.URL + "/x-short-cut-favicon.ico", 188 | } 189 | 190 | for _, eURL := range expects { 191 | if _, ok := resources[eURL]; !ok { 192 | t.Fatalf("resouce not reached: %s", eURL) 193 | } 194 | } 195 | } 196 | 
-------------------------------------------------------------------------------- /agent/option.go: -------------------------------------------------------------------------------- 1 | package agent 2 | 3 | import ( 4 | "net/http" 5 | "net/url" 6 | "time" 7 | ) 8 | 9 | func WithNoCookie() AgentOption { 10 | return func(a *Agent) error { 11 | a.HttpClient.Jar = nil 12 | return nil 13 | } 14 | } 15 | 16 | func WithNoCache() AgentOption { 17 | return func(a *Agent) error { 18 | a.CacheStore = nil 19 | return nil 20 | } 21 | } 22 | 23 | func WithUserAgent(ua string) AgentOption { 24 | return func(a *Agent) error { 25 | a.Name = ua 26 | return nil 27 | } 28 | } 29 | 30 | func WithBaseURL(base string) AgentOption { 31 | return func(a *Agent) error { 32 | var err error 33 | a.BaseURL, err = url.Parse(base) 34 | return err 35 | } 36 | } 37 | 38 | func WithTimeout(d time.Duration) AgentOption { 39 | return func(a *Agent) error { 40 | a.HttpClient.Timeout = d 41 | return nil 42 | } 43 | } 44 | 45 | func WithDefaultTransport() AgentOption { 46 | return WithTransport(DefaultTransport) 47 | } 48 | 49 | func WithTransport(trs *http.Transport) AgentOption { 50 | return func(a *Agent) error { 51 | a.HttpClient.Transport = trs 52 | 53 | return nil 54 | } 55 | } 56 | 57 | func WithCloneTransport(trs *http.Transport) AgentOption { 58 | return func(a *Agent) error { 59 | a.HttpClient.Transport = trs.Clone() 60 | 61 | return nil 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /agent/option_test.go: -------------------------------------------------------------------------------- 1 | package agent 2 | 3 | import ( 4 | "io" 5 | "net" 6 | "net/http" 7 | "net/http/httptest" 8 | "testing" 9 | "time" 10 | 11 | "github.com/isucon/isucandar/failure" 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | func TestNoCookie(t *testing.T) { 16 | agent, err := NewAgent(WithNoCookie(), WithDefaultTransport()) 17 | if err != nil { 18 | 
t.Fatal(err) 19 | } 20 | 21 | if agent.HttpClient.Jar != nil { 22 | t.Fatal("Not removed cookie jar") 23 | } 24 | } 25 | 26 | func TestNoCache(t *testing.T) { 27 | agent, err := NewAgent(WithNoCache(), WithDefaultTransport()) 28 | if err != nil { 29 | t.Fatal(err) 30 | } 31 | 32 | if agent.CacheStore != nil { 33 | t.Fatal("Not removed cache store") 34 | } 35 | } 36 | 37 | func TestUserAgent(t *testing.T) { 38 | agent, err := NewAgent(WithUserAgent("Hello"), WithDefaultTransport()) 39 | if err != nil { 40 | t.Fatal(err) 41 | } 42 | 43 | if agent.Name != "Hello" { 44 | t.Fatalf("missmatch ua: %s", agent.Name) 45 | } 46 | } 47 | 48 | func TestBaseURL(t *testing.T) { 49 | agent, err := NewAgent(WithBaseURL("http://base.example.com"), WithDefaultTransport()) 50 | if err != nil { 51 | t.Fatal(err) 52 | } 53 | 54 | if agent.BaseURL.String() != "http://base.example.com" { 55 | t.Fatalf("missmatch base URL: %s", agent.BaseURL.String()) 56 | } 57 | } 58 | 59 | func TestTimeout(t *testing.T) { 60 | srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 61 | <-time.After(2 * time.Second) 62 | w.Header().Set("Content-Type", "text/plain") 63 | w.WriteHeader(http.StatusOK) 64 | 65 | io.WriteString(w, "Hello, World") 66 | })) 67 | defer func() { 68 | go srv.Close() 69 | }() 70 | 71 | agent, err := NewAgent(WithTimeout(1*time.Microsecond), WithBaseURL(srv.URL), WithDefaultTransport()) 72 | if err != nil { 73 | t.Fatal(err) 74 | } 75 | 76 | _, _, err = get(agent, "/") 77 | var nerr net.Error 78 | if ok := failure.As(err, &nerr); !ok || !nerr.Timeout() { 79 | t.Fatalf("expected timeout error: %+v", err) 80 | } 81 | } 82 | 83 | func TestWithoutTransport(t *testing.T) { 84 | _, err := NewAgent() 85 | assert.NotNil(t, err) 86 | if err != nil { 87 | assert.Same(t, ErrTransportInvalid, err) 88 | } 89 | } 90 | 91 | func TestDefaultTransport(t *testing.T) { 92 | agent1, err := NewAgent(WithDefaultTransport()) 93 | if err != nil { 94 | t.Fatal(err) 95 | } 
96 | 97 | agent2, err := NewAgent(WithDefaultTransport()) 98 | if err != nil { 99 | t.Fatal(err) 100 | } 101 | 102 | assert.Same(t, agent1.HttpClient.Transport, agent2.HttpClient.Transport) 103 | } 104 | 105 | func TestTransport(t *testing.T) { 106 | trs := DefaultTransport.Clone() 107 | 108 | agent1, err := NewAgent(WithTransport(trs)) 109 | if err != nil { 110 | t.Fatal(err) 111 | } 112 | 113 | assert.Same(t, agent1.HttpClient.Transport, trs) 114 | } 115 | 116 | func TestCloneTransport(t *testing.T) { 117 | agent1, err := NewAgent(WithCloneTransport(DefaultTransport)) 118 | if err != nil { 119 | t.Fatal(err) 120 | } 121 | 122 | agent2, err := NewAgent(WithCloneTransport(DefaultTransport)) 123 | if err != nil { 124 | t.Fatal(err) 125 | } 126 | 127 | assert.NotSame(t, agent1.HttpClient.Transport, agent2.HttpClient.Transport) 128 | } 129 | -------------------------------------------------------------------------------- /benchmark.go: -------------------------------------------------------------------------------- 1 | package isucandar 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sync" 7 | "time" 8 | 9 | "github.com/isucon/isucandar/failure" 10 | "github.com/isucon/isucandar/parallel" 11 | ) 12 | 13 | var ( 14 | ErrPanic failure.StringCode = "panic" 15 | ErrPrepare failure.StringCode = "prepare" 16 | ErrLoad failure.StringCode = "load" 17 | ErrValidation failure.StringCode = "validation" 18 | ) 19 | 20 | type BenchmarkStepFunc func(context.Context, *BenchmarkStep) error 21 | type BenchmarkErrorHook func(error, *BenchmarkStep) 22 | 23 | type Benchmark struct { 24 | mu sync.Mutex 25 | 26 | prepareSteps []BenchmarkStepFunc 27 | loadSteps []BenchmarkStepFunc 28 | validationSteps []BenchmarkStepFunc 29 | 30 | panicRecover bool 31 | prepareTimeout time.Duration 32 | loadTimeout time.Duration 33 | ignoreCodes []failure.Code 34 | errorHooks []BenchmarkErrorHook 35 | } 36 | 37 | func NewBenchmark(opts ...BenchmarkOption) (*Benchmark, error) { 38 | benchmark := &Benchmark{ 39 
| mu: sync.Mutex{}, 40 | prepareSteps: []BenchmarkStepFunc{}, 41 | loadSteps: []BenchmarkStepFunc{}, 42 | validationSteps: []BenchmarkStepFunc{}, 43 | panicRecover: true, 44 | prepareTimeout: time.Duration(0), 45 | loadTimeout: time.Duration(0), 46 | ignoreCodes: []failure.Code{}, 47 | errorHooks: []BenchmarkErrorHook{}, 48 | } 49 | 50 | for _, opt := range opts { 51 | if err := opt(benchmark); err != nil { 52 | return nil, err 53 | } 54 | } 55 | 56 | return benchmark, nil 57 | } 58 | 59 | func (b *Benchmark) Start(parent context.Context) *BenchmarkResult { 60 | ctx, cancel := context.WithCancel(parent) 61 | result := newBenchmarkResult(ctx) 62 | defer cancel() 63 | 64 | step := &BenchmarkStep{ 65 | mu: sync.RWMutex{}, 66 | result: result, 67 | cancel: cancel, 68 | } 69 | 70 | for _, hook := range b.errorHooks { 71 | func(hook BenchmarkErrorHook) { 72 | result.Errors.Hook(func(err error) { 73 | hook(err, step) 74 | }) 75 | }(hook) 76 | } 77 | 78 | var ( 79 | loadParallel *parallel.Parallel 80 | loadCtx context.Context 81 | loadCancel context.CancelFunc 82 | ) 83 | 84 | step.setErrorCode(ErrPrepare) 85 | for _, prepare := range b.prepareSteps { 86 | var ( 87 | prepareCtx context.Context 88 | prepareCancel context.CancelFunc 89 | ) 90 | 91 | if b.prepareTimeout > 0 { 92 | prepareCtx, prepareCancel = context.WithTimeout(ctx, b.prepareTimeout) 93 | } else { 94 | prepareCtx, prepareCancel = context.WithCancel(ctx) 95 | } 96 | defer prepareCancel() 97 | 98 | if err := panicWrapper(b.panicRecover, func() error { return prepare(prepareCtx, step) }); err != nil { 99 | for _, ignore := range b.ignoreCodes { 100 | if failure.IsCode(err, ignore) { 101 | goto Result 102 | } 103 | } 104 | step.AddError(err) 105 | goto Result 106 | } 107 | } 108 | 109 | result.Errors.Wait() 110 | 111 | if ctx.Err() != nil { 112 | goto Result 113 | } 114 | 115 | step.setErrorCode(ErrLoad) 116 | if b.loadTimeout > 0 { 117 | loadCtx, loadCancel = context.WithTimeout(ctx, b.loadTimeout) 118 | } else 
{ 119 | loadCtx, loadCancel = context.WithCancel(ctx) 120 | } 121 | loadParallel = parallel.NewParallel(loadCtx, -1) 122 | 123 | for _, load := range b.loadSteps { 124 | func(f BenchmarkStepFunc) { 125 | loadParallel.Do(func(c context.Context) { 126 | if err := panicWrapper(b.panicRecover, func() error { return f(c, step) }); err != nil { 127 | for _, ignore := range b.ignoreCodes { 128 | if failure.IsCode(err, ignore) { 129 | return 130 | } 131 | } 132 | step.AddError(err) 133 | } 134 | }) 135 | }(load) 136 | } 137 | loadParallel.Wait() 138 | loadCancel() 139 | 140 | result.Errors.Wait() 141 | 142 | if ctx.Err() != nil { 143 | goto Result 144 | } 145 | 146 | step.setErrorCode(ErrValidation) 147 | for _, validation := range b.validationSteps { 148 | if err := panicWrapper(b.panicRecover, func() error { return validation(ctx, step) }); err != nil { 149 | for _, ignore := range b.ignoreCodes { 150 | if failure.IsCode(err, ignore) { 151 | goto Result 152 | } 153 | } 154 | step.AddError(err) 155 | goto Result 156 | } 157 | } 158 | 159 | Result: 160 | cancel() 161 | step.wait() 162 | step.setErrorCode(nil) 163 | 164 | return result 165 | } 166 | 167 | func (b *Benchmark) OnError(f BenchmarkErrorHook) { 168 | b.mu.Lock() 169 | defer b.mu.Unlock() 170 | 171 | b.errorHooks = append(b.errorHooks, f) 172 | } 173 | 174 | func (b *Benchmark) Prepare(f BenchmarkStepFunc) { 175 | b.mu.Lock() 176 | defer b.mu.Unlock() 177 | 178 | b.prepareSteps = append(b.prepareSteps, f) 179 | } 180 | 181 | func (b *Benchmark) Load(f BenchmarkStepFunc) { 182 | b.mu.Lock() 183 | defer b.mu.Unlock() 184 | 185 | b.loadSteps = append(b.loadSteps, f) 186 | } 187 | 188 | func (b *Benchmark) Validation(f BenchmarkStepFunc) { 189 | b.mu.Lock() 190 | defer b.mu.Unlock() 191 | 192 | b.validationSteps = append(b.validationSteps, f) 193 | } 194 | 195 | func (b *Benchmark) IgnoreErrorCode(code failure.Code) { 196 | b.mu.Lock() 197 | defer b.mu.Unlock() 198 | 199 | b.ignoreCodes = append(b.ignoreCodes, code) 
200 | } 201 | 202 | func panicWrapper(on bool, f func() error) (err error) { 203 | if !on { 204 | return f() 205 | } 206 | 207 | defer func() { 208 | re := recover() 209 | if re == nil { 210 | return 211 | } 212 | 213 | if rerr, ok := re.(error); !ok { 214 | err = failure.NewError(ErrPanic, fmt.Errorf("%v", re)) 215 | } else { 216 | err = failure.NewError(ErrPanic, rerr) 217 | } 218 | }() 219 | 220 | return f() 221 | } 222 | -------------------------------------------------------------------------------- /benchmark_option.go: -------------------------------------------------------------------------------- 1 | package isucandar 2 | 3 | import ( 4 | "time" 5 | ) 6 | 7 | type BenchmarkOption func(*Benchmark) error 8 | 9 | func WithPrepareTimeout(d time.Duration) BenchmarkOption { 10 | return func(b *Benchmark) error { 11 | b.mu.Lock() 12 | defer b.mu.Unlock() 13 | b.prepareTimeout = d 14 | return nil 15 | } 16 | } 17 | 18 | func WithLoadTimeout(d time.Duration) BenchmarkOption { 19 | return func(b *Benchmark) error { 20 | b.mu.Lock() 21 | defer b.mu.Unlock() 22 | b.loadTimeout = d 23 | return nil 24 | } 25 | } 26 | 27 | func WithoutPanicRecover() BenchmarkOption { 28 | return func(b *Benchmark) error { 29 | b.panicRecover = false 30 | return nil 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /benchmark_result.go: -------------------------------------------------------------------------------- 1 | package isucandar 2 | 3 | import ( 4 | "context" 5 | "github.com/isucon/isucandar/failure" 6 | "github.com/isucon/isucandar/score" 7 | ) 8 | 9 | type BenchmarkResult struct { 10 | Score *score.Score 11 | Errors *failure.Errors 12 | } 13 | 14 | func newBenchmarkResult(ctx context.Context) *BenchmarkResult { 15 | return &BenchmarkResult{ 16 | Score: score.NewScore(ctx), 17 | Errors: failure.NewErrors(ctx), 18 | } 19 | } 20 | -------------------------------------------------------------------------------- 
/benchmark_scenario.go: -------------------------------------------------------------------------------- 1 | package isucandar 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | ) 7 | 8 | var ( 9 | ErrInvalidScenario = errors.New("Invalid scenario interface") 10 | ) 11 | 12 | type PrepareScenario interface { 13 | Prepare(context.Context, *BenchmarkStep) error 14 | } 15 | 16 | type LoadScenario interface { 17 | Load(context.Context, *BenchmarkStep) error 18 | } 19 | 20 | type ValidationScenario interface { 21 | Validation(context.Context, *BenchmarkStep) error 22 | } 23 | 24 | func (b *Benchmark) AddScenario(scenario interface{}) { 25 | match := false 26 | if p, ok := scenario.(PrepareScenario); ok { 27 | b.Prepare(p.Prepare) 28 | match = true 29 | } 30 | 31 | if l, ok := scenario.(LoadScenario); ok { 32 | b.Load(l.Load) 33 | match = true 34 | } 35 | 36 | if v, ok := scenario.(ValidationScenario); ok { 37 | b.Validation(v.Validation) 38 | match = true 39 | } 40 | 41 | if !match { 42 | panic(ErrInvalidScenario) 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /benchmark_scenario_test.go: -------------------------------------------------------------------------------- 1 | package isucandar 2 | 3 | import ( 4 | "context" 5 | "sync/atomic" 6 | "testing" 7 | ) 8 | 9 | type exampleScenario struct { 10 | prepare uint32 11 | load uint32 12 | validation uint32 13 | } 14 | 15 | func (e *exampleScenario) Prepare(_ context.Context, _ *BenchmarkStep) error { 16 | atomic.StoreUint32(&e.prepare, 1) 17 | return nil 18 | } 19 | 20 | func (e *exampleScenario) Load(_ context.Context, _ *BenchmarkStep) error { 21 | atomic.StoreUint32(&e.load, 1) 22 | return nil 23 | } 24 | 25 | func (e *exampleScenario) Validation(_ context.Context, _ *BenchmarkStep) error { 26 | atomic.StoreUint32(&e.validation, 1) 27 | return nil 28 | } 29 | 30 | func TestBenchmarkAddScenario(t *testing.T) { 31 | benchmark, err := NewBenchmark() 32 | if err != nil { 33 
| t.Fatal(err) 34 | } 35 | 36 | e := &exampleScenario{ 37 | prepare: 0, 38 | load: 0, 39 | validation: 0, 40 | } 41 | 42 | benchmark.AddScenario(e) 43 | 44 | result := benchmark.Start(context.Background()) 45 | 46 | if len(result.Errors.All()) > 0 { 47 | t.Fatal(result.Errors.All()) 48 | } 49 | 50 | if e.prepare != 1 || e.load != 1 || e.validation != 1 { 51 | t.Fatal(e) 52 | } 53 | } 54 | 55 | func TestBenchmarkAddScenarioPanic(t *testing.T) { 56 | benchmark, err := NewBenchmark() 57 | if err != nil { 58 | t.Fatal(err) 59 | } 60 | 61 | var rerr interface{} 62 | func() { 63 | defer func() { 64 | rerr = recover() 65 | }() 66 | benchmark.AddScenario(nil) 67 | }() 68 | 69 | if rerr == nil { 70 | t.Fatal("Do not register invalid scenario") 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /benchmark_step.go: -------------------------------------------------------------------------------- 1 | package isucandar 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | 7 | "github.com/isucon/isucandar/failure" 8 | "github.com/isucon/isucandar/score" 9 | ) 10 | 11 | type BenchmarkStep struct { 12 | errorCode failure.Code 13 | mu sync.RWMutex 14 | result *BenchmarkResult 15 | cancel context.CancelFunc 16 | } 17 | 18 | func (b *BenchmarkStep) setErrorCode(code failure.Code) { 19 | b.mu.Lock() 20 | defer b.mu.Unlock() 21 | 22 | b.errorCode = code 23 | } 24 | 25 | func (b *BenchmarkStep) AddError(err error) { 26 | b.mu.RLock() 27 | defer b.mu.RUnlock() 28 | 29 | if b.errorCode != nil { 30 | b.result.Errors.Add(failure.NewError(b.errorCode, err)) 31 | } else { 32 | b.result.Errors.Add(err) 33 | } 34 | } 35 | 36 | func (b *BenchmarkStep) AddScore(tag score.ScoreTag) { 37 | b.result.Score.Add(tag) 38 | } 39 | 40 | func (b *BenchmarkStep) Cancel() { 41 | b.cancel() 42 | } 43 | 44 | func (b *BenchmarkStep) Result() *BenchmarkResult { 45 | return b.result 46 | } 47 | 48 | func (b *BenchmarkStep) wait() { 49 | wg := sync.WaitGroup{} 50 | wg.Add(1) 
51 | go func() { 52 | b.result.Score.Wait() 53 | wg.Done() 54 | }() 55 | wg.Add(1) 56 | go func() { 57 | b.result.Errors.Wait() 58 | wg.Done() 59 | }() 60 | wg.Wait() 61 | } 62 | -------------------------------------------------------------------------------- /benchmark_test.go: -------------------------------------------------------------------------------- 1 | package isucandar 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "testing" 7 | "time" 8 | 9 | "github.com/isucon/isucandar/failure" 10 | ) 11 | 12 | var ( 13 | ErrIgnore failure.StringCode = "ignore" 14 | ErrBenchmarkCancel failure.StringCode = "banchmark-cancel" 15 | ) 16 | 17 | func newBenchmark(opts ...BenchmarkOption) *Benchmark { 18 | benchmark, err := NewBenchmark(opts...) 19 | if err != nil { 20 | panic(err) 21 | } 22 | 23 | benchmark.IgnoreErrorCode(ErrIgnore) 24 | 25 | benchmark.Prepare(func(ctx context.Context, s *BenchmarkStep) error { 26 | time.Sleep(1 * time.Microsecond) 27 | select { 28 | case <-ctx.Done(): 29 | return ctx.Err() 30 | default: 31 | return nil 32 | } 33 | }) 34 | 35 | benchmark.Load(func(ctx context.Context, s *BenchmarkStep) error { 36 | time.Sleep(1 * time.Microsecond) 37 | select { 38 | case <-ctx.Done(): 39 | return ctx.Err() 40 | default: 41 | return nil 42 | } 43 | }) 44 | 45 | benchmark.Validation(func(ctx context.Context, s *BenchmarkStep) error { 46 | time.Sleep(1 * time.Microsecond) 47 | select { 48 | case <-ctx.Done(): 49 | return ctx.Err() 50 | default: 51 | return nil 52 | } 53 | }) 54 | 55 | return benchmark 56 | } 57 | 58 | func TestBenchmark(t *testing.T) { 59 | ctx := context.TODO() 60 | b := newBenchmark() 61 | 62 | result := b.Start(ctx) 63 | 64 | if len(result.Errors.All()) != 0 { 65 | t.Fatal(result.Errors.All()) 66 | } 67 | } 68 | 69 | func TestBenchmarkScore(t *testing.T) { 70 | ctx := context.TODO() 71 | b := newBenchmark() 72 | 73 | b.Load(func(_ context.Context, s *BenchmarkStep) error { 74 | s.AddScore("dummy") 75 | return nil 76 | }) 77 | 78 | result 
:= b.Start(ctx) 79 | 80 | if len(result.Errors.All()) != 0 { 81 | t.Fatal(result.Errors.All()) 82 | } 83 | 84 | result.Score.Set("dummy", 1) 85 | 86 | if result.Score.Sum() != 1 { 87 | t.Fatalf("%d", result.Score.Sum()) 88 | } 89 | } 90 | 91 | func TestBenchmarkCreation(t *testing.T) { 92 | raise := errors.New("error") 93 | _, err := NewBenchmark(func(b *Benchmark) error { 94 | return raise 95 | }) 96 | 97 | if err != raise { 98 | t.Fatal(err) 99 | } 100 | } 101 | 102 | func TestBenchmarkErrorHook(t *testing.T) { 103 | ctx := context.TODO() 104 | b := newBenchmark() 105 | b.OnError(func(err error, s *BenchmarkStep) { 106 | if failure.IsCode(err, ErrBenchmarkCancel) { 107 | s.Cancel() 108 | } 109 | }) 110 | 111 | b.Prepare(func(_ context.Context, s *BenchmarkStep) error { 112 | s.AddError(failure.NewError(ErrBenchmarkCancel, errors.New("cancel"))) 113 | return nil 114 | }) 115 | 116 | loaded := false 117 | b.Load(func(_ context.Context, _ *BenchmarkStep) error { 118 | loaded = true 119 | return nil 120 | }) 121 | 122 | b.Start(ctx) 123 | 124 | if loaded { 125 | t.Fatal("error hook error") 126 | } 127 | } 128 | 129 | func TestBenchmarkPrepareTimeout(t *testing.T) { 130 | ctx := context.TODO() 131 | b := newBenchmark(WithPrepareTimeout(1)) 132 | 133 | result := b.Start(ctx) 134 | 135 | if len(result.Errors.All()) != 1 || !failure.Is(result.Errors.All()[0], context.DeadlineExceeded) { 136 | t.Fatal(result.Errors.All()) 137 | } 138 | } 139 | 140 | func TestBenchmarkPreparePanic(t *testing.T) { 141 | ctx := context.TODO() 142 | b := newBenchmark() 143 | 144 | b.Prepare(func(_ context.Context, _ *BenchmarkStep) error { 145 | panic("Prepare panic") 146 | }) 147 | 148 | result := b.Start(ctx) 149 | 150 | if len(result.Errors.All()) != 1 || !failure.IsCode(result.Errors.All()[0], ErrPanic) { 151 | t.Fatal(result.Errors.All()) 152 | } 153 | } 154 | 155 | func TestBenchmarkPreparePanicError(t *testing.T) { 156 | ctx := context.TODO() 157 | b := newBenchmark() 158 | 159 | err 
:= errors.New("Prepare panic") 160 | b.Prepare(func(_ context.Context, _ *BenchmarkStep) error { 161 | panic(err) 162 | }) 163 | 164 | result := b.Start(ctx) 165 | 166 | if len(result.Errors.All()) != 1 || !failure.Is(result.Errors.All()[0], err) { 167 | t.Fatal(result.Errors.All()) 168 | } 169 | } 170 | 171 | func TestBenchmarkPrepareIgnoredError(t *testing.T) { 172 | ctx := context.TODO() 173 | b := newBenchmark() 174 | 175 | err := failure.NewError(ErrIgnore, errors.New("Prepare panic")) 176 | b.Prepare(func(_ context.Context, _ *BenchmarkStep) error { 177 | return err 178 | }) 179 | 180 | loaded := false 181 | b.Load(func(_ context.Context, _ *BenchmarkStep) error { 182 | loaded = true 183 | return nil 184 | }) 185 | 186 | result := b.Start(ctx) 187 | 188 | if len(result.Errors.All()) != 0 { 189 | t.Fatal(result.Errors.All()) 190 | } 191 | 192 | if loaded { 193 | t.Fatal("ignore error") 194 | } 195 | } 196 | 197 | func TestBenchmarkPrepareCancel(t *testing.T) { 198 | ctx := context.TODO() 199 | b := newBenchmark() 200 | 201 | b.Prepare(func(_ context.Context, s *BenchmarkStep) error { 202 | s.Cancel() 203 | return nil 204 | }) 205 | 206 | loaded := false 207 | b.Load(func(_ context.Context, _ *BenchmarkStep) error { 208 | loaded = true 209 | return nil 210 | }) 211 | 212 | result := b.Start(ctx) 213 | 214 | if len(result.Errors.All()) > 1 { 215 | t.Fatal(result.Errors.All()) 216 | } 217 | 218 | if loaded { 219 | t.Fatal("cancel error") 220 | } 221 | } 222 | 223 | func TestBenchmarkLoadTimeout(t *testing.T) { 224 | ctx := context.TODO() 225 | b := newBenchmark(WithLoadTimeout(5 * time.Millisecond)) 226 | 227 | runAll := false 228 | b.Load(func(ctx context.Context, _ *BenchmarkStep) error { 229 | time.Sleep(100 * time.Millisecond) 230 | runAll = true 231 | return nil 232 | }) 233 | b.Start(ctx) 234 | 235 | if runAll { 236 | t.Fatal("Not timeout") 237 | } 238 | } 239 | 240 | func TestBenchmarkLoadPanic(t *testing.T) { 241 | ctx := context.TODO() 242 | b := 
newBenchmark() 243 | 244 | b.Load(func(_ context.Context, _ *BenchmarkStep) error { 245 | panic("Load panic") 246 | }) 247 | 248 | result := b.Start(ctx) 249 | 250 | if len(result.Errors.All()) != 1 || !failure.IsCode(result.Errors.All()[0], ErrPanic) { 251 | t.Fatal(result.Errors.All()) 252 | } 253 | t.Log(result.Errors.All()) 254 | } 255 | 256 | func TestBenchmarkLoadPanicError(t *testing.T) { 257 | ctx := context.TODO() 258 | b := newBenchmark() 259 | 260 | err := errors.New("Load panic") 261 | b.Load(func(_ context.Context, _ *BenchmarkStep) error { 262 | panic(err) 263 | }) 264 | 265 | result := b.Start(ctx) 266 | 267 | if len(result.Errors.All()) != 1 || !failure.Is(result.Errors.All()[0], err) { 268 | t.Fatal(result.Errors.All()) 269 | } 270 | } 271 | 272 | func TestBenchmarkLoadIgnoredError(t *testing.T) { 273 | ctx := context.TODO() 274 | b := newBenchmark() 275 | 276 | err := failure.NewError(ErrIgnore, errors.New("Prepare panic")) 277 | b.Load(func(_ context.Context, _ *BenchmarkStep) error { 278 | return err 279 | }) 280 | 281 | loaded := false 282 | b.Validation(func(_ context.Context, _ *BenchmarkStep) error { 283 | loaded = true 284 | return nil 285 | }) 286 | 287 | result := b.Start(ctx) 288 | 289 | if len(result.Errors.All()) != 0 { 290 | t.Fatal(result.Errors.All()) 291 | } 292 | 293 | if !loaded { 294 | t.Fatal("ignore error") 295 | } 296 | } 297 | 298 | func TestBenchmarkLoadCancel(t *testing.T) { 299 | ctx := context.TODO() 300 | b := newBenchmark() 301 | 302 | b.Load(func(_ context.Context, s *BenchmarkStep) error { 303 | s.Cancel() 304 | return nil 305 | }) 306 | 307 | loaded := false 308 | b.Validation(func(_ context.Context, _ *BenchmarkStep) error { 309 | loaded = true 310 | return nil 311 | }) 312 | 313 | result := b.Start(ctx) 314 | 315 | if len(result.Errors.All()) > 1 { 316 | t.Fatal(result.Errors.All()) 317 | } 318 | 319 | if loaded { 320 | t.Fatal("cancel error") 321 | } 322 | } 323 | 324 | func TestBenchmarkValidationPanic(t 
*testing.T) { 325 | ctx := context.TODO() 326 | b := newBenchmark() 327 | 328 | b.Validation(func(_ context.Context, _ *BenchmarkStep) error { 329 | panic("Validation panic") 330 | }) 331 | 332 | result := b.Start(ctx) 333 | 334 | if len(result.Errors.All()) != 1 || !failure.IsCode(result.Errors.All()[0], ErrPanic) { 335 | t.Fatal(result.Errors.All()) 336 | } 337 | } 338 | 339 | func TestBenchmarkValidationPanicError(t *testing.T) { 340 | ctx := context.TODO() 341 | b := newBenchmark() 342 | 343 | err := errors.New("Validation panic") 344 | b.Validation(func(_ context.Context, _ *BenchmarkStep) error { 345 | panic(err) 346 | }) 347 | 348 | result := b.Start(ctx) 349 | 350 | if len(result.Errors.All()) != 1 || !failure.Is(result.Errors.All()[0], err) { 351 | t.Fatal(result.Errors.All()) 352 | } 353 | } 354 | 355 | func TestBenchmarkValidationIgnoredError(t *testing.T) { 356 | ctx := context.TODO() 357 | b := newBenchmark() 358 | 359 | err := failure.NewError(ErrIgnore, errors.New("Prepare panic")) 360 | b.Validation(func(_ context.Context, _ *BenchmarkStep) error { 361 | return err 362 | }) 363 | 364 | result := b.Start(ctx) 365 | 366 | if len(result.Errors.All()) != 0 { 367 | t.Fatal(result.Errors.All()) 368 | } 369 | } 370 | 371 | func TestBenchmarWithoutPanicRecover(t *testing.T) { 372 | ctx := context.TODO() 373 | b := newBenchmark(WithoutPanicRecover()) 374 | 375 | panicErr := errors.New("panic") 376 | b.Validation(func(_ context.Context, _ *BenchmarkStep) error { 377 | panic(panicErr) 378 | }) 379 | 380 | func() { 381 | defer func() { 382 | err := recover() 383 | if err == nil { 384 | t.Fatalf("not thrown panic") 385 | } 386 | 387 | if err != panicErr { 388 | t.Fatalf("invalid panic: %+v", err) 389 | } 390 | }() 391 | 392 | b.Start(ctx) 393 | }() 394 | } 395 | -------------------------------------------------------------------------------- /demo/agent/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 
3 | import ( 4 | "context" 5 | "fmt" 6 | 7 | "github.com/isucon/isucandar/agent" 8 | ) 9 | 10 | func main() { 11 | ctx, cancel := context.WithCancel(context.Background()) 12 | defer cancel() 13 | 14 | agent, err := agent.NewAgent(agent.WithDefaultTransport(), agent.WithBaseURL("https://github.com/")) 15 | if err != nil { 16 | panic(err) 17 | } 18 | 19 | req, err := agent.GET("/") 20 | if err != nil { 21 | panic(err) 22 | } 23 | 24 | res, err := agent.Do(ctx, req) 25 | if err != nil { 26 | panic(err) 27 | } 28 | 29 | resources, err := agent.ProcessHTML(ctx, res, res.Body) 30 | if err != nil { 31 | panic(err) 32 | } 33 | 34 | for url, resource := range resources { 35 | fmt.Printf("%s: %s: %s\n", resource.InitiatorType, resource.Response.Header.Get("Content-Type"), url) 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /demo/failure/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "math/rand" 7 | 8 | "github.com/isucon/isucandar/failure" 9 | ) 10 | 11 | var ( 12 | ErrDeepCall failure.StringCode = "DEEP" 13 | ) 14 | 15 | func deepError(n int) error { 16 | if n > 0 { 17 | return deepError(n - 1) 18 | } else { 19 | return failure.NewError(ErrDeepCall, fmt.Errorf("error")) 20 | } 21 | } 22 | 23 | func main() { 24 | failure.BacktraceCleaner.Add(failure.SkipGOROOT) 25 | 26 | ctx, cancel := context.WithCancel(context.Background()) 27 | errors := failure.NewErrors(ctx) 28 | 29 | errors.Add(deepError(rand.Intn(5))) 30 | errors.Add(deepError(rand.Intn(5))) 31 | errors.Add(deepError(rand.Intn(5))) 32 | cancel() 33 | 34 | errors.Wait() 35 | 36 | for _, err := range errors.All() { 37 | fmt.Printf("%+v\n", err) 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /demo/pubsub/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 
3 | import ( 4 | "context" 5 | "fmt" 6 | "sync" 7 | "time" 8 | 9 | "github.com/isucon/isucandar/pubsub" 10 | "github.com/isucon/isucandar/worker" 11 | ) 12 | 13 | func launchWorker(ctx context.Context, pubsub *pubsub.PubSub, format string) error { 14 | worker, err := worker.NewWorker(func(_ context.Context, _ int) { 15 | fmt.Println(time.Now().Format(format)) 16 | time.Sleep(time.Second) 17 | }, worker.WithMaxParallelism(1)) 18 | if err != nil { 19 | return err 20 | } 21 | 22 | go worker.Process(ctx) 23 | 24 | <-pubsub.Subscribe(ctx, func(limit interface{}) { 25 | l := limit.(int32) 26 | fmt.Printf("Worker increase: %d\n", l) 27 | worker.AddParallelism(l) 28 | }) 29 | 30 | return nil 31 | } 32 | 33 | func main() { 34 | p := pubsub.NewPubSub() 35 | 36 | wg := sync.WaitGroup{} 37 | wg.Add(3) 38 | 39 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 40 | defer cancel() 41 | 42 | go func() { 43 | for i := 1; i < 3; i++ { 44 | time.Sleep(1 * time.Second) 45 | p.Publish(int32(i)) 46 | } 47 | }() 48 | 49 | go func() { 50 | launchWorker(ctx, p, time.RFC822) 51 | wg.Done() 52 | }() 53 | go func() { 54 | launchWorker(ctx, p, time.RFC850) 55 | wg.Done() 56 | }() 57 | go func() { 58 | launchWorker(ctx, p, time.RFC3339) 59 | wg.Done() 60 | }() 61 | 62 | wg.Wait() 63 | } 64 | -------------------------------------------------------------------------------- /demo/worker/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/isucon/isucandar/worker" 9 | ) 10 | 11 | func main() { 12 | ctx, cancel := context.WithCancel(context.Background()) 13 | defer cancel() 14 | 15 | timeWorker, err := worker.NewWorker(func(ctx context.Context, _ int) { 16 | fmt.Println(time.Now().Format(time.RFC3339)) 17 | time.Sleep(1 * time.Second) 18 | }, worker.WithInfinityLoop(), worker.WithMaxParallelism(1)) 19 | if err != nil { 20 | panic(err) 21 | } 22 | 23 | 
increaseWorker, err := worker.NewWorker(func(ctx context.Context, _ int) { 24 | time.Sleep(3 * time.Second) 25 | fmt.Println("Increase time worker!") 26 | timeWorker.AddParallelism(1) 27 | }, worker.WithLoopCount(3), worker.WithMaxParallelism(1)) 28 | if err != nil { 29 | panic(err) 30 | } 31 | 32 | go func() { 33 | time.Sleep(10 * time.Second) 34 | cancel() 35 | }() 36 | 37 | go func() { 38 | increaseWorker.Process(ctx) 39 | fmt.Println("Increase worker executed") 40 | }() 41 | timeWorker.Process(ctx) 42 | fmt.Println("Time worker executed") 43 | } 44 | -------------------------------------------------------------------------------- /failure/cleaner.go: -------------------------------------------------------------------------------- 1 | package failure 2 | 3 | import ( 4 | "fmt" 5 | "runtime" 6 | "strconv" 7 | "strings" 8 | 9 | "golang.org/x/xerrors" 10 | ) 11 | 12 | var ( 13 | BacktraceCleaner = &backtraceCleaner{} 14 | ) 15 | 16 | type Backtrace struct { 17 | Function string 18 | File string 19 | LineNo int 20 | } 21 | 22 | func (b *Backtrace) String() string { 23 | return fmt.Sprintf("%s\n %s:%d", b.Function, b.File, b.LineNo) 24 | } 25 | 26 | type backtraceCleaner struct { 27 | matcher func(Backtrace) bool 28 | } 29 | 30 | func (bc *backtraceCleaner) match(frame xerrors.Frame) bool { 31 | if bc.matcher == nil { 32 | return false 33 | } 34 | 35 | c := &frameConvertor{ 36 | Function: "", 37 | File: "", 38 | LineNo: -1, 39 | } 40 | frame.Format(c) 41 | 42 | b := c.Backtrace() 43 | return bc.matcher(b) 44 | } 45 | 46 | func (bc *backtraceCleaner) Add(matcher func(Backtrace) bool) { 47 | oldMatcher := bc.matcher 48 | m := func(b Backtrace) bool { 49 | if oldMatcher == nil { 50 | return matcher(b) 51 | } 52 | return matcher(b) || oldMatcher(b) 53 | } 54 | bc.matcher = m 55 | } 56 | 57 | type frameConvertor struct { 58 | Function string 59 | File string 60 | LineNo int 61 | } 62 | 63 | func (f *frameConvertor) Detail() bool { 64 | // frame の内容を取りたいので常に true 65 | 
return true 66 | } 67 | 68 | func (f *frameConvertor) Print(args ...interface{}) {} 69 | 70 | func (f *frameConvertor) Printf(format string, args ...interface{}) { 71 | switch format { 72 | case "%s\n ": // function name formatter 73 | f.Function = fmt.Sprintf("%s", args[0]) 74 | case "%s:%d\n": // file name formatter 75 | f.File = fmt.Sprintf("%s", args[0]) 76 | f.LineNo, _ = strconv.Atoi(fmt.Sprintf("%d", args[1])) 77 | } 78 | } 79 | 80 | func (f *frameConvertor) Backtrace() Backtrace { 81 | return Backtrace{ 82 | Function: f.Function, 83 | File: f.File, 84 | LineNo: f.LineNo, 85 | } 86 | } 87 | 88 | func SkipGOROOT(b Backtrace) bool { 89 | return strings.HasPrefix(b.File, runtime.GOROOT()) 90 | } 91 | -------------------------------------------------------------------------------- /failure/cleaner_test.go: -------------------------------------------------------------------------------- 1 | package failure 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "testing" 7 | 8 | "golang.org/x/xerrors" 9 | ) 10 | 11 | func TestBacktraceCleaner(t *testing.T) { 12 | cleaner := &backtraceCleaner{} 13 | defaultCleaner := BacktraceCleaner 14 | defaultCaptureBacktraceSize := CaptureBacktraceSize 15 | BacktraceCleaner = cleaner 16 | CaptureBacktraceSize = 100 17 | defer func() { 18 | BacktraceCleaner = defaultCleaner 19 | CaptureBacktraceSize = defaultCaptureBacktraceSize 20 | }() 21 | 22 | cleaner.Add(SkipGOROOT) 23 | cleaner.Add(func(b Backtrace) bool { 24 | return strings.HasSuffix(b.Function, "TestBacktraceCleaner") 25 | }) 26 | 27 | var code StringCode = "cleaner" 28 | var f func(int) error 29 | f = func(n int) error { 30 | if n > 0 { 31 | return f(n - 1) 32 | } 33 | return NewError(code, fmt.Errorf("invalid")) 34 | } 35 | 36 | err := f(0) 37 | 38 | details := fmt.Sprintf("%+v", err) 39 | dLines := strings.Split(details, "\n") 40 | 41 | // TestBacktraceCleaner.func3: not match 42 | // TestBacktraceCleaner: match with Name 43 | // testing.tRunner: match with GOROOT 44 |
expectLines := ((3 - 2) * 2) + 2 45 | if len(dLines) != expectLines { 46 | t.Logf("\n%+v", err) 47 | t.Fatalf("missmatch call stack size: %d / %d", len(dLines), expectLines) 48 | } 49 | } 50 | 51 | func TestFrameConvertor(t *testing.T) { 52 | convertor := &frameConvertor{} 53 | 54 | frame := xerrors.Caller(0) 55 | 56 | // No op 57 | convertor.Print(frame) 58 | 59 | frame.Format(convertor) 60 | backtrace := convertor.Backtrace() 61 | 62 | if !strings.HasSuffix(backtrace.Function, "failure.TestFrameConvertor") { 63 | t.Fatalf("Not match function: %s", backtrace.String()) 64 | } 65 | t.Logf("%s", backtrace.String()) 66 | } 67 | -------------------------------------------------------------------------------- /failure/code.go: -------------------------------------------------------------------------------- 1 | package failure 2 | 3 | import ( 4 | "golang.org/x/xerrors" 5 | ) 6 | 7 | type Code interface { 8 | Error() string 9 | ErrorCode() string 10 | } 11 | 12 | type StringCode string 13 | 14 | func (s StringCode) Error() string { 15 | return string(s) 16 | } 17 | 18 | func (s StringCode) ErrorCode() string { 19 | return string(s) 20 | } 21 | 22 | func GetErrorCode(err error) string { 23 | var code Code 24 | if As(err, &code) { 25 | return code.ErrorCode() 26 | } 27 | 28 | return UnknownErrorCode.ErrorCode() 29 | } 30 | 31 | func GetErrorCodes(err error) []string { 32 | var code Code 33 | var wrap xerrors.Wrapper 34 | 35 | unwrapped := false 36 | codes := []string{} 37 | 38 | for err != nil { 39 | if ok := As(err, &code); ok { 40 | codes = append(codes, code.ErrorCode()) 41 | } else if !unwrapped { 42 | codes = append(codes, UnknownErrorCode.ErrorCode()) 43 | } 44 | 45 | if ok := As(err, &wrap); ok { 46 | err = wrap.Unwrap() 47 | unwrapped = true 48 | } else { 49 | err = nil 50 | } 51 | } 52 | 53 | return codes 54 | } 55 | 56 | const ( 57 | UnknownErrorCode StringCode = "unknown" 58 | CanceledErrorCode StringCode = "canceled" 59 | TimeoutErrorCode 
StringCode = "timeout" 60 | TemporaryErrorCode StringCode = "temporary" 61 | ) 62 | -------------------------------------------------------------------------------- /failure/code_test.go: -------------------------------------------------------------------------------- 1 | package failure 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | ) 7 | 8 | type errorWithCode struct { 9 | code string 10 | message string 11 | } 12 | 13 | func (e *errorWithCode) Error() string { 14 | return e.message 15 | } 16 | 17 | func (e *errorWithCode) ErrorCode() string { 18 | return e.code 19 | } 20 | 21 | func TestErrorCode(t *testing.T) { 22 | err := errors.New("test") 23 | if code := GetErrorCode(err); code != "unknown" { 24 | t.Fatalf("expected unknown, got %s", code) 25 | } 26 | 27 | err = &errorWithCode{ 28 | code: "test", 29 | message: "Hello", 30 | } 31 | if code := GetErrorCode(err); code != "test" { 32 | t.Fatalf("expected test, got %s", code) 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /failure/error.go: -------------------------------------------------------------------------------- 1 | package failure 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "net" 8 | 9 | "golang.org/x/xerrors" 10 | ) 11 | 12 | var ( 13 | CaptureBacktraceSize = 5 14 | ) 15 | 16 | type Error struct { 17 | Code 18 | err error 19 | // xerrors records only a single stack frame, so several frames are captured here instead 20 | frames []xerrors.Frame 21 | } 22 | 23 | func NewError(code Code, err error) error { 24 | // Skip already wrapped 25 | if IsCode(err, code) { 26 | return err 27 | } 28 | 29 | var nerr net.Error 30 | if As(err, &nerr) { 31 | switch { 32 | case nerr.Timeout(): 33 | err = newError(TimeoutErrorCode, err) 34 | case nerr.Temporary(): 35 | err = newError(TemporaryErrorCode, err) 36 | } 37 | } else if Is(err, context.Canceled) { 38 | err = newError(CanceledErrorCode, err) 39 | } 40 | 41 | return newError(code, err) 42 | } 43 | 44 | func newError(code Code, err error) *Error { 
45 | frames := make([]xerrors.Frame, 0, CaptureBacktraceSize) 46 | skip := 2 47 | for i := 0; i < CaptureBacktraceSize; i++ { 48 | frame := xerrors.Caller(i + skip) 49 | if BacktraceCleaner.match(frame) { 50 | i-- 51 | skip++ 52 | } else { 53 | frames = append(frames, frame) 54 | } 55 | } 56 | 57 | return &Error{ 58 | Code: code, 59 | err: err, 60 | frames: frames, 61 | } 62 | } 63 | 64 | func (e *Error) Unwrap() error { // implements xerrors.Wrapper 65 | return e.err 66 | } 67 | 68 | func (e *Error) Format(f fmt.State, c rune) { // implements fmt.Formatter 69 | xerrors.FormatError(e, f, c) 70 | } 71 | 72 | func (e *Error) FormatError(p xerrors.Printer) error { // implements xerrors.Formatter 73 | p.Print(e.Error()) 74 | if p.Detail() { 75 | for _, frame := range e.frames { 76 | frame.Format(p) 77 | } 78 | } 79 | return e.err 80 | } 81 | 82 | func Is(err, target error) bool { 83 | return err == target || xerrors.Is(err, target) || errors.Is(err, target) 84 | } 85 | 86 | func As(err error, target interface{}) bool { 87 | return xerrors.As(err, target) 88 | } 89 | 90 | func IsCode(err error, code Code) bool { 91 | for _, c := range GetErrorCodes(err) { 92 | if c == code.ErrorCode() { 93 | return true 94 | } 95 | } 96 | 97 | return false 98 | } 99 | -------------------------------------------------------------------------------- /failure/error_test.go: -------------------------------------------------------------------------------- 1 | package failure 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "reflect" 7 | "strings" 8 | "testing" 9 | "time" 10 | ) 11 | 12 | const ( 13 | errApplication StringCode = "application" 14 | errTemporary StringCode = "temporary" 15 | errTest StringCode = "test" 16 | ) 17 | 18 | func TestError(t *testing.T) { 19 | berr := fmt.Errorf("Test") 20 | aerr := NewError(errApplication, berr) 21 | 22 | if m := fmt.Sprint(aerr); m != "application: Test" { 23 | t.Fatalf("missmatch: %s", m) 24 | } 25 | 26 | if m := fmt.Sprintf("%+v", aerr); 
strings.HasPrefix(m, "application: Test") { 27 | t.Fatalf("missmatch: %s", m) 28 | } 29 | 30 | if !Is(aerr, berr) { 31 | t.Fatalf("check invalid") 32 | } 33 | 34 | if GetErrorCode(aerr) != "application" { 35 | t.Fatalf("Error code is invalid: %s", GetErrorCode(aerr)) 36 | } 37 | 38 | terr := NewError(errTemporary, aerr) 39 | 40 | if m := fmt.Sprint(terr); m != "temporary: application: Test" { 41 | t.Fatalf("missmatch: %s", m) 42 | } 43 | 44 | if !Is(terr, berr) { 45 | t.Fatalf("check invalid") 46 | } 47 | 48 | if GetErrorCode(terr) != "temporary" { 49 | t.Fatalf("Error code is invalid: %s", GetErrorCode(terr)) 50 | } 51 | 52 | gotCodes := GetErrorCodes(terr) 53 | expectCodes := []string{"temporary", "application"} 54 | if !reflect.DeepEqual(gotCodes, expectCodes) { 55 | t.Fatalf("Error codes is invalid:\n %v\n %v", gotCodes, expectCodes) 56 | } 57 | } 58 | 59 | type fakeNetError struct { 60 | timeout bool 61 | temporary bool 62 | } 63 | 64 | func (f fakeNetError) Error() string { 65 | return "fake" 66 | } 67 | 68 | func (f fakeNetError) Timeout() bool { 69 | return f.timeout 70 | } 71 | 72 | func (f fakeNetError) Temporary() bool { 73 | return f.temporary 74 | } 75 | 76 | func TestErrorWrap(t *testing.T) { 77 | rctx := context.TODO() 78 | 79 | ctx, cancel := context.WithCancel(rctx) 80 | cancel() 81 | 82 | canceledError := NewError(errApplication, ctx.Err()) 83 | if GetErrorCode(canceledError) != "application" { 84 | t.Fatalf("%s", GetErrorCode(canceledError)) 85 | } 86 | codes := GetErrorCodes(canceledError) 87 | expectCodes := []string{"application", CanceledErrorCode.ErrorCode()} 88 | if !reflect.DeepEqual(codes, expectCodes) { 89 | t.Fatalf("Error codes is invalid:\n %v\n %v", codes, expectCodes) 90 | } 91 | 92 | ctx, cancel = context.WithTimeout(rctx, -1*time.Second) 93 | defer cancel() 94 | 95 | timeoutError := NewError(errApplication, ctx.Err()) 96 | if GetErrorCode(timeoutError) != "application" { 97 | t.Fatalf("%s", GetErrorCode(timeoutError)) 98 | } 99 | 
codes = GetErrorCodes(timeoutError) 100 | expectCodes = []string{"application", TimeoutErrorCode.ErrorCode()} 101 | if !reflect.DeepEqual(codes, expectCodes) { 102 | t.Fatalf("Error codes is invalid:\n %v\n %v", codes, expectCodes) 103 | } 104 | 105 | ferr := fakeNetError{timeout: false, temporary: true} 106 | temporaryError := NewError(errApplication, ferr) 107 | if GetErrorCode(temporaryError) != "application" { 108 | t.Fatalf("%s", GetErrorCode(temporaryError)) 109 | } 110 | codes = GetErrorCodes(temporaryError) 111 | expectCodes = []string{"application", TemporaryErrorCode.ErrorCode()} 112 | if !reflect.DeepEqual(codes, expectCodes) { 113 | t.Fatalf("Error codes is invalid:\n %v\n %v", codes, expectCodes) 114 | } 115 | 116 | err := NewError(errTest, fmt.Errorf("error")) 117 | nilError := NewError(errTest, err) 118 | codes = GetErrorCodes(nilError) 119 | expectCodes = []string{"test"} 120 | if !reflect.DeepEqual(codes, expectCodes) { 121 | t.Fatalf("Error codes is invalid:\n %v\n %v", codes, expectCodes) 122 | } 123 | } 124 | 125 | func TestErrorFrames(t *testing.T) { 126 | berr := fmt.Errorf("frames") 127 | 128 | var f func(int) error 129 | f = func(n int) error { 130 | if n > 0 { 131 | return f(n - 1) 132 | } else { 133 | return NewError(errApplication, berr) 134 | } 135 | } 136 | aerr := f(3) 137 | 138 | details := fmt.Sprintf("%+v", aerr) 139 | dLines := strings.Split(details, "\n") 140 | 141 | // callstack * 2 + 2 messages 142 | eLineCount := 2 + CaptureBacktraceSize*2 143 | if len(dLines) != eLineCount { 144 | t.Fatalf("expected %d but got %d", eLineCount, len(dLines)) 145 | } 146 | } 147 | 148 | func TestIsCode(t *testing.T) { 149 | err := NewError(errApplication, NewError(errTemporary, fmt.Errorf("foo"))) 150 | 151 | if !IsCode(err, errApplication) { 152 | t.Fatal(err) 153 | } 154 | 155 | if !IsCode(err, errTemporary) { 156 | t.Fatal(err) 157 | } 158 | 159 | if IsCode(err, UnknownErrorCode) { 160 | t.Fatal(err) 161 | } 162 | } 163 |
-------------------------------------------------------------------------------- /failure/errors.go: -------------------------------------------------------------------------------- 1 | package failure 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | "sync/atomic" 7 | ) 8 | 9 | type ErrorsHook func(error) 10 | 11 | type Errors struct { 12 | mu sync.RWMutex 13 | cmu sync.RWMutex 14 | count int32 15 | closed uint32 16 | errors []error 17 | queue chan error 18 | hook ErrorsHook 19 | } 20 | 21 | func NewErrors(ctx context.Context) *Errors { 22 | set := &Errors{ 23 | mu: sync.RWMutex{}, 24 | cmu: sync.RWMutex{}, 25 | count: int32(0), 26 | closed: uint32(0), 27 | errors: make([]error, 0, 0), 28 | queue: make(chan error), 29 | } 30 | 31 | set.hook = func(err error) { 32 | atomic.AddInt32(&set.count, -1) 33 | } 34 | 35 | go set.collect(ctx) 36 | 37 | return set 38 | } 39 | 40 | func (s *Errors) collect(ctx context.Context) { 41 | go func() { 42 | <-ctx.Done() 43 | s.Close() 44 | }() 45 | 46 | for err := range s.queue { 47 | s.mu.Lock() 48 | s.errors = append(s.errors, err) 49 | s.mu.Unlock() 50 | 51 | go s.hook(err) 52 | } 53 | atomic.AddInt32(&s.count, -1) 54 | } 55 | 56 | func (s *Errors) Add(err error) { 57 | defer func() { recover() }() 58 | 59 | if atomic.CompareAndSwapUint32(&s.closed, 0, 0) { 60 | s.cmu.RLock() 61 | s.queue <- err 62 | s.cmu.RUnlock() 63 | atomic.AddInt32(&s.count, 1) 64 | } 65 | } 66 | 67 | func (s *Errors) Hook(hook ErrorsHook) { 68 | oldHook := s.hook 69 | s.hook = func(err error) { 70 | defer oldHook(err) 71 | 72 | hook(err) 73 | } 74 | } 75 | 76 | func (s *Errors) Wait() { 77 | for atomic.LoadInt32(&s.count) > 0 { 78 | } 79 | } 80 | 81 | func (s *Errors) Close() { 82 | if atomic.CompareAndSwapUint32(&s.closed, 0, 1) { 83 | atomic.AddInt32(&s.count, 1) 84 | s.cmu.Lock() 85 | close(s.queue) 86 | s.cmu.Unlock() 87 | } 88 | } 89 | 90 | func (s *Errors) Done() { 91 | s.Close() 92 | s.Wait() 93 | } 94 | 95 | func (s *Errors) Messages() 
map[string][]string { 96 | s.mu.RLock() 97 | defer s.mu.RUnlock() 98 | 99 | table := make(map[string][]string) 100 | for _, err := range s.errors { 101 | code := GetErrorCode(err) 102 | if _, ok := table[code]; ok { 103 | table[code] = append(table[code], err.Error()) 104 | } else { 105 | table[code] = []string{err.Error()} 106 | } 107 | } 108 | 109 | return table 110 | } 111 | 112 | func (s *Errors) Count() map[string]int64 { 113 | s.mu.RLock() 114 | defer s.mu.RUnlock() 115 | 116 | table := make(map[string]int64) 117 | for _, err := range s.errors { 118 | codes := GetErrorCodes(err) 119 | for _, code := range codes { 120 | if _, ok := table[code]; ok { 121 | table[code]++ 122 | } else { 123 | table[code] = 1 124 | } 125 | } 126 | } 127 | 128 | return table 129 | } 130 | 131 | func (s *Errors) All() []error { 132 | s.mu.RLock() 133 | defer s.mu.RUnlock() 134 | 135 | errors := make([]error, len(s.errors)) 136 | copy(errors, s.errors) 137 | 138 | return errors 139 | } 140 | 141 | func (s *Errors) Reset() { 142 | s.mu.RLock() 143 | defer s.mu.RUnlock() 144 | 145 | s.errors = []error{} 146 | } 147 | -------------------------------------------------------------------------------- /failure/errors_test.go: -------------------------------------------------------------------------------- 1 | package failure 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sync/atomic" 7 | "testing" 8 | ) 9 | 10 | func TestErrors(t *testing.T) { 11 | ctx, cancel := context.WithCancel(context.Background()) 12 | set := NewErrors(ctx) 13 | defer cancel() 14 | 15 | for i := 0; i < 100; i++ { 16 | set.Add(fmt.Errorf("unknown error")) 17 | } 18 | 19 | set.Done() 20 | 21 | table := set.Count() 22 | if table["unknown"] != 100 { 23 | t.Errorf("missmatch unknown count: %d", table["unknown"]) 24 | } 25 | 26 | errors := set.All() 27 | if len(errors) != 100 { 28 | t.Errorf("missmatch errors count: %d", len(errors)) 29 | } 30 | 31 | set.Reset() 32 | moreErrors := set.All() 33 | if len(moreErrors) != 0 { 34 
| t.Errorf("missmatch errors count: %d", len(moreErrors)) 35 | } 36 | } 37 | 38 | func TestErrorsClosed(t *testing.T) { 39 | ctx, cancel := context.WithCancel(context.Background()) 40 | defer cancel() 41 | set := NewErrors(ctx) 42 | 43 | set.Add(fmt.Errorf("test")) 44 | set.Add(fmt.Errorf("test")) 45 | set.Add(fmt.Errorf("test")) 46 | 47 | set.Done() 48 | 49 | set.Add(fmt.Errorf("test")) 50 | 51 | table := set.Count() 52 | if table["unknown"] != 3 { 53 | t.Fatalf("missmatch unknown count: %d", table["unknown"]) 54 | } 55 | 56 | messages := set.Messages() 57 | if len(messages["unknown"]) != 3 { 58 | t.Fatalf("missmatch unknown message count: %d", len(messages["unknown"])) 59 | } 60 | } 61 | 62 | func TestErrorsHook(t *testing.T) { 63 | ctx, cancel := context.WithCancel(context.Background()) 64 | set := NewErrors(ctx) 65 | defer cancel() 66 | 67 | n := int32(0) 68 | cnt := &n 69 | set.Hook(func(err error) { 70 | atomic.AddInt32(cnt, 1) 71 | }) 72 | 73 | for i := 0; i < 10; i++ { 74 | set.Add(fmt.Errorf("unknown error")) 75 | } 76 | 77 | set.Done() 78 | 79 | table := set.Count() 80 | if table["unknown"] != 10 { 81 | t.Errorf("missmatch unknown count: %d", table["unknown"]) 82 | } 83 | 84 | set.Wait() 85 | 86 | if atomic.LoadInt32(cnt) != 10 { 87 | t.Errorf("missmatch unknown hook count: %d", atomic.LoadInt32(cnt)) 88 | } 89 | } 90 | 91 | func BenchmarkErrorsAdd(b *testing.B) { 92 | err := fmt.Errorf("test") 93 | ctx, cancel := context.WithCancel(context.Background()) 94 | defer cancel() 95 | 96 | set := NewErrors(ctx) 97 | 98 | b.ResetTimer() 99 | for i := 0; i < b.N; i++ { 100 | set.Add(err) 101 | } 102 | set.Done() 103 | b.StopTimer() 104 | } 105 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/isucon/isucandar 2 | 3 | go 1.18 4 | 5 | require ( 6 | github.com/andybalholm/brotli v1.0.0 7 | github.com/dsnet/compress v0.0.1 8 
| github.com/julienschmidt/httprouter v1.3.0 9 | github.com/labstack/echo/v4 v4.7.2 10 | github.com/pquerna/cachecontrol v0.1.0 11 | github.com/stretchr/testify v1.7.0 12 | golang.org/x/net v0.0.0-20220225172249-27dd8689420f 13 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 14 | ) 15 | 16 | require ( 17 | github.com/davecgh/go-spew v1.1.1 // indirect 18 | github.com/golang-jwt/jwt v3.2.2+incompatible // indirect 19 | github.com/labstack/gommon v0.3.1 // indirect 20 | github.com/mattn/go-colorable v0.1.11 // indirect 21 | github.com/mattn/go-isatty v0.0.14 // indirect 22 | github.com/pmezard/go-difflib v1.0.0 // indirect 23 | github.com/valyala/bytebufferpool v1.0.0 // indirect 24 | github.com/valyala/fasttemplate v1.2.1 // indirect 25 | golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 // indirect 26 | golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e // indirect 27 | golang.org/x/text v0.3.7 // indirect 28 | golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 // indirect 29 | gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect 30 | ) 31 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/andybalholm/brotli v1.0.0 h1:7UCwP93aiSfvWpapti8g88vVVGp2qqtGyePsSuDafo4= 2 | github.com/andybalholm/brotli v1.0.0/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y= 3 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 5 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 6 | github.com/dsnet/compress v0.0.1 h1:PlZu0n3Tuv04TzpfPbrnI0HW/YwodEXDS+oPKahKF0Q= 7 | github.com/dsnet/compress v0.0.1/go.mod h1:Aw8dCMJ7RioblQeTqt88akK31OvO8Dhf5JflhBbQEHo= 8 | github.com/dsnet/golib v0.0.0-20171103203638-1ea166775780/go.mod 
h1:Lj+Z9rebOhdfkVLjJ8T6VcRQv3SXugXy999NBtR9aFY= 9 | github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= 10 | github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= 11 | github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U= 12 | github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= 13 | github.com/klauspost/compress v1.4.1/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= 14 | github.com/klauspost/cpuid v1.2.0/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= 15 | github.com/labstack/echo/v4 v4.7.2 h1:Kv2/p8OaQ+M6Ex4eGimg9b9e6icoxA42JSlOR3msKtI= 16 | github.com/labstack/echo/v4 v4.7.2/go.mod h1:xkCDAdFCIf8jsFQ5NnbK7oqaF/yU1A1X20Ltm0OvSks= 17 | github.com/labstack/gommon v0.3.1 h1:OomWaJXm7xR6L1HmEtGyQf26TEn7V6X88mktX9kee9o= 18 | github.com/labstack/gommon v0.3.1/go.mod h1:uW6kP17uPlLJsD3ijUYn3/M5bAxtlZhMI6m3MFxTMTM= 19 | github.com/mattn/go-colorable v0.1.11 h1:nQ+aFkoE2TMGc0b68U2OKSexC+eq46+XwZzWXHRmPYs= 20 | github.com/mattn/go-colorable v0.1.11/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= 21 | github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= 22 | github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= 23 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 24 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 25 | github.com/pquerna/cachecontrol v0.1.0 h1:yJMy84ti9h/+OEWa752kBTKv4XC30OtVVHYv/8cTqKc= 26 | github.com/pquerna/cachecontrol v0.1.0/go.mod h1:NrUG3Z7Rdu85UNR3vm7SOsl1nFIeSiQnrHV5K9mBcUI= 27 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 28 | github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 29 | github.com/stretchr/testify v1.7.0 
h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= 30 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 31 | github.com/ulikunitz/xz v0.5.6/go.mod h1:2bypXElzHzzJZwzH67Y6wb67pO62Rzfn7BSiF4ABRW8= 32 | github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= 33 | github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= 34 | github.com/valyala/fasttemplate v1.2.1 h1:TVEnxayobAdVkhQfrfes2IzOB6o+z4roRkPF52WA1u4= 35 | github.com/valyala/fasttemplate v1.2.1/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= 36 | golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 h1:HWj/xjIHfjYU5nVXpTM0s39J9CbLn7Cc5a7IC5rwsMQ= 37 | golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= 38 | golang.org/x/net v0.0.0-20220225172249-27dd8689420f h1:oA4XRj0qtSt8Yo1Zms0CUlsT3KG69V2UGQWPBxujDmc= 39 | golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= 40 | golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 41 | golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 42 | golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 43 | golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+RRmLmmulPn5I3Y9F2EM= 44 | golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 45 | golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= 46 | golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= 47 | golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 h1:Hir2P/De0WpUhtrKGGjvSb2YxUgyZ7EFOSLIcSSpiwE= 48 | golang.org/x/time v0.0.0-20201208040808-7e3f01d25324/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= 49 | 
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= 50 | golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 51 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 52 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 53 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 54 | gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= 55 | gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 56 | -------------------------------------------------------------------------------- /parallel/parallel.go: -------------------------------------------------------------------------------- 1 | package parallel 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "sync" 7 | "sync/atomic" 8 | ) 9 | 10 | var ( 11 | ErrLimiterClosed = errors.New("limiter closed") 12 | ErrNegativeCount = errors.New("negative count") 13 | ) 14 | 15 | const ( 16 | closedFalse uint32 = iota 17 | closedTrue 18 | ) 19 | 20 | type Parallel struct { 21 | mu sync.Mutex 22 | ctx context.Context 23 | limit int32 24 | count int32 25 | closed uint32 26 | closer chan struct{} 27 | doner chan struct{} 28 | } 29 | 30 | func NewParallel(ctx context.Context, limit int32) *Parallel { 31 | var doner chan struct{} = nil 32 | if limit > 0 { 33 | doner = make(chan struct{}, limit) 34 | } 35 | 36 | p := &Parallel{ 37 | mu: sync.Mutex{}, 38 | ctx: ctx, 39 | limit: limit, 40 | count: 0, 41 | closed: closedFalse, 42 | closer: make(chan struct{}), 43 | doner: doner, 44 | } 45 | 46 | return p 47 | } 48 | 49 | func (l *Parallel) CurrentLimit() int32 { 50 | return atomic.LoadInt32(&l.limit) 51 | } 52 | 53 | func (l *Parallel) Do(f func(context.Context)) error { 54 |
atomic.AddInt32(&l.count, 1) 55 | 56 | err := l.start() 57 | if err != nil { 58 | atomic.AddInt32(&l.count, -1) 59 | return err 60 | } 61 | 62 | l.mu.Lock() 63 | doner := l.doner 64 | l.mu.Unlock() 65 | 66 | go func(doner chan struct{}) { 67 | defer l.done(doner) 68 | f(l.ctx) 69 | }(doner) 70 | 71 | return nil 72 | } 73 | 74 | func (l *Parallel) Wait() { 75 | if atomic.LoadUint32(&l.closed) != closedTrue { 76 | for { 77 | select { 78 | case <-l.ctx.Done(): 79 | l.Close() 80 | case <-l.closer: 81 | return 82 | } 83 | } 84 | } 85 | } 86 | 87 | func (l *Parallel) Close() { 88 | if atomic.CompareAndSwapUint32(&l.closed, closedFalse, closedTrue) { 89 | close(l.closer) 90 | } 91 | } 92 | 93 | func (l *Parallel) SetParallelism(limit int32) { 94 | l.mu.Lock() 95 | defer l.mu.Unlock() 96 | atomic.StoreInt32(&l.limit, limit) 97 | if l.doner != nil { 98 | close(l.doner) 99 | } 100 | 101 | if limit > 0 { 102 | l.doner = make(chan struct{}, limit) 103 | } else { 104 | l.doner = nil 105 | } 106 | } 107 | 108 | func (l *Parallel) AddParallelism(limit int32) { 109 | l.SetParallelism(atomic.LoadInt32(&l.limit) + limit) 110 | } 111 | 112 | func (l *Parallel) start() error { 113 | for l.isRunning() { 114 | if count, limit, kept := l.isLimitKept(); kept { 115 | if atomic.CompareAndSwapInt32(&l.count, count, count+1) { 116 | return nil 117 | } 118 | } else if limit > 0 { 119 | l.mu.Lock() 120 | l.doner <- struct{}{} 121 | l.mu.Unlock() 122 | } 123 | } 124 | 125 | return ErrLimiterClosed 126 | } 127 | 128 | func (l *Parallel) done(doner chan struct{}) { 129 | select { 130 | case <-doner: 131 | default: 132 | } 133 | 134 | count := atomic.AddInt32(&l.count, -2) 135 | if count < 0 { 136 | panic(ErrNegativeCount) 137 | } 138 | if count == 0 { 139 | l.Close() 140 | } 141 | } 142 | 143 | func (l *Parallel) isRunning() bool { 144 | return atomic.LoadUint32(&l.closed) == closedFalse && l.ctx.Err() == nil 145 | } 146 | 147 | func (l *Parallel) isLimitKept() (int32, int32, bool) { 148 | limit 
:= atomic.LoadInt32(&l.limit) 149 | count := atomic.LoadInt32(&l.count) 150 | return count, limit, limit < 1 || count < (limit*2) 151 | } 152 | -------------------------------------------------------------------------------- /parallel/parallel_test.go: -------------------------------------------------------------------------------- 1 | package parallel 2 | 3 | import ( 4 | "context" 5 | "sync/atomic" 6 | "testing" 7 | "time" 8 | ) 9 | 10 | func TestParallel(t *testing.T) { 11 | ctx := context.TODO() 12 | 13 | parallel := NewParallel(ctx, 2) 14 | defer parallel.Close() 15 | 16 | pcount := int32(0) 17 | pmcount := int32(0) 18 | exited := uint32(0) 19 | f := func(_ context.Context) { 20 | atomic.AddInt32(&pcount, 1) 21 | defer atomic.AddInt32(&pcount, -1) 22 | time.Sleep(10 * time.Millisecond) 23 | } 24 | 25 | parallel.Do(f) 26 | go func() { 27 | parallel.Do(f) 28 | parallel.Do(f) 29 | parallel.Do(f) 30 | }() 31 | 32 | go func() { 33 | for atomic.LoadUint32(&exited) == 0 { 34 | m := atomic.LoadInt32(&pcount) 35 | if atomic.LoadInt32(&pmcount) < m { 36 | atomic.StoreInt32(&pmcount, m) 37 | } 38 | } 39 | }() 40 | 41 | parallel.Wait() 42 | atomic.StoreUint32(&exited, 1) 43 | 44 | maxCount := atomic.LoadInt32(&pmcount) 45 | if maxCount != 2 { 46 | t.Fatalf("Invalid parallel count: %d / %d", maxCount, 2) 47 | } 48 | } 49 | 50 | func TestParallelClosed(t *testing.T) { 51 | ctx := context.TODO() 52 | 53 | parallel := NewParallel(ctx, 2) 54 | parallel.Close() 55 | 56 | called := false 57 | err := parallel.Do(func(_ context.Context) { 58 | called = true 59 | }) 60 | 61 | parallel.Wait() 62 | 63 | if err == nil || err != ErrLimiterClosed { 64 | t.Fatalf("missmatch error: %+v", err) 65 | } 66 | 67 | if called { 68 | t.Fatalf("Do not process on closed") 69 | } 70 | } 71 | 72 | func TestParallelCanceled(t *testing.T) { 73 | ctx, cancel := context.WithCancel(context.Background()) 74 | cancel() 75 | 76 | parallel := NewParallel(ctx, 0) 77 | 78 | parallel.Do(func(_ context.Context) { 
79 | t.Fatal("Do not call") 80 | }) 81 | 82 | parallel.Wait() 83 | } 84 | 85 | func TestParallelPanicOnNegative(t *testing.T) { 86 | ctx, cancel := context.WithCancel(context.Background()) 87 | defer cancel() 88 | 89 | parallel := NewParallel(ctx, 0) 90 | 91 | var err interface{} 92 | func() { 93 | defer func() { err = recover() }() 94 | parallel.done(nil) 95 | }() 96 | 97 | if err != ErrNegativeCount { 98 | t.Fatal(err) 99 | } 100 | } 101 | 102 | func TestParallelSetParallelism(t *testing.T) { 103 | check := func(paralellism int32) { 104 | ctx, cancel := context.WithCancel(context.Background()) 105 | defer cancel() 106 | 107 | parallel := NewParallel(ctx, -1) 108 | parallel.SetParallelism(paralellism) 109 | 110 | pcount := int32(0) 111 | pmcount := int32(0) 112 | exited := uint32(0) 113 | f := func(c context.Context) { 114 | atomic.AddInt32(&pcount, 1) 115 | defer atomic.AddInt32(&pcount, -1) 116 | 117 | time.Sleep(10 * time.Millisecond) 118 | } 119 | 120 | parallel.Do(f) 121 | go func() { 122 | parallel.Do(f) 123 | parallel.Do(f) 124 | parallel.Do(f) 125 | }() 126 | 127 | go func() { 128 | for atomic.LoadUint32(&exited) == 0 { 129 | m := atomic.LoadInt32(&pcount) 130 | if atomic.LoadInt32(&pmcount) < m { 131 | atomic.StoreInt32(&pmcount, m) 132 | } 133 | } 134 | }() 135 | parallel.Wait() 136 | atomic.StoreUint32(&exited, 1) 137 | 138 | maxCount := atomic.LoadInt32(&pmcount) 139 | if maxCount != parallel.CurrentLimit() && parallel.CurrentLimit() > 0 { 140 | t.Fatalf("Invalid parallel count: %d / %d", maxCount, parallel.CurrentLimit()) 141 | } 142 | 143 | parallel.Wait() 144 | } 145 | 146 | check(2) 147 | check(1) 148 | check(-1) 149 | } 150 | 151 | func TestParallelAddParallelism(t *testing.T) { 152 | ctx, cancel := context.WithCancel(context.TODO()) 153 | defer cancel() 154 | 155 | para := NewParallel(ctx, 1) 156 | para.AddParallelism(1) 157 | 158 | pcount := int32(0) 159 | pmcount := int32(0) 160 | exited := uint32(0) 161 | f := func(c context.Context) { 162 | 
atomic.AddInt32(&pcount, 1) 163 | defer atomic.AddInt32(&pcount, -1) 164 | 165 | time.Sleep(10 * time.Millisecond) 166 | } 167 | 168 | para.Do(f) 169 | go func() { 170 | para.Do(f) 171 | para.Do(f) 172 | para.Do(f) 173 | }() 174 | 175 | go func() { 176 | for atomic.LoadUint32(&exited) == 0 { 177 | m := atomic.LoadInt32(&pcount) 178 | if atomic.LoadInt32(&pmcount) < m { 179 | atomic.StoreInt32(&pmcount, m) 180 | } 181 | } 182 | }() 183 | para.Wait() 184 | atomic.StoreUint32(&exited, 1) 185 | 186 | maxCount := atomic.LoadInt32(&pmcount) 187 | if maxCount != 2 { 188 | t.Fatalf("Invalid parallel count: %d / %d", maxCount, para.CurrentLimit()) 189 | } 190 | } 191 | 192 | func BenchmarkParallel(b *testing.B) { 193 | ctx, cancel := context.WithCancel(context.TODO()) 194 | defer cancel() 195 | 196 | parallel := NewParallel(ctx, -1) 197 | nop := func(_ context.Context) {} 198 | 199 | b.ResetTimer() 200 | for i := 0; i < b.N; i++ { 201 | parallel.Do(nop) 202 | } 203 | parallel.Wait() 204 | b.StopTimer() 205 | } 206 | -------------------------------------------------------------------------------- /pubsub/pubsub.go: -------------------------------------------------------------------------------- 1 | package pubsub 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | ) 7 | 8 | type PubSub struct { 9 | Capacity int 10 | 11 | mu sync.RWMutex 12 | ch []chan interface{} 13 | } 14 | 15 | func NewPubSub() *PubSub { 16 | return &PubSub{ 17 | Capacity: 10, 18 | mu: sync.RWMutex{}, 19 | ch: []chan interface{}{}, 20 | } 21 | } 22 | 23 | func (p *PubSub) Publish(payload interface{}) { 24 | p.mu.RLock() 25 | defer p.mu.RUnlock() 26 | 27 | for _, ch := range p.ch { 28 | ch <- payload 29 | } 30 | } 31 | 32 | func (p *PubSub) Subscribe(ctx context.Context, f func(interface{})) <-chan bool { 33 | p.mu.Lock() 34 | defer p.mu.Unlock() 35 | 36 | ch := make(chan interface{}, p.Capacity) 37 | p.ch = append(p.ch, ch) 38 | 39 | sub := &Subscription{ 40 | pubsub: p, 41 | f: f, 42 | ch: ch, 43 | } 44 | 45 | 
waiter := make(chan bool) 46 | 47 | go func() { 48 | L: 49 | for ctx.Err() == nil { 50 | select { 51 | case payload, ok := <-sub.ch: 52 | if ok { 53 | sub.f(payload) 54 | } 55 | case <-ctx.Done(): 56 | sub.close() 57 | break L 58 | } 59 | } 60 | close(waiter) 61 | }() 62 | 63 | return waiter 64 | } 65 | 66 | type Subscription struct { 67 | pubsub *PubSub 68 | f func(interface{}) 69 | ch chan interface{} 70 | } 71 | 72 | func (s *Subscription) close() { 73 | s.pubsub.mu.Lock() 74 | defer s.pubsub.mu.Unlock() 75 | 76 | for idx, ch := range s.pubsub.ch { 77 | if ch == s.ch { 78 | channels := make([]chan interface{}, 0, len(s.pubsub.ch)-1) 79 | channels = append(channels, s.pubsub.ch[:idx]...) 80 | s.pubsub.ch = append(channels, s.pubsub.ch[idx+1:]...) 81 | break // fresh slice: never shift the backing array we are ranging over 82 | } 83 | } 84 | 85 | close(s.ch) 86 | 87 | s.ch = nil 88 | } 89 | -------------------------------------------------------------------------------- /pubsub/pubsub_test.go: -------------------------------------------------------------------------------- 1 | package pubsub 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | "sync/atomic" 7 | "testing" 8 | "time" 9 | ) 10 | 11 | func TestPubSub(t *testing.T) { 12 | ctx, cancel := context.WithCancel(context.Background()) 13 | defer cancel() 14 | 15 | pubsub := NewPubSub() 16 | 17 | wg := sync.WaitGroup{} 18 | 19 | result1 := int32(0) 20 | pubsub.Subscribe(ctx, func(payload interface{}) { 21 | atomic.AddInt32(&result1, int32(payload.(int))) 22 | wg.Done() 23 | }) 24 | result2 := int32(0) 25 | pubsub.Subscribe(ctx, func(payload interface{}) { 26 | atomic.AddInt32(&result2, int32(payload.(int))) 27 | wg.Done() 28 | }) 29 | 30 | wg.Add(4) 31 | pubsub.Publish(1) 32 | pubsub.Publish(2) 33 | 34 | wg.Wait() 35 | 36 | if result1 != 3 { 37 | t.Fatalf("invalid 1: %v", result1) 38 | } 39 | if result2 != 3 { 40 | t.Fatalf("invalid 2: %v", result2) 41 | } 42 | } 43 | 44 | func TestPubSubUnsubscribe(t *testing.T) { 45 | 
pubsub := NewPubSub() 46 | 47 | ctx, cancel := 
context.WithTimeout(context.Background(), 1*time.Millisecond) 48 | defer cancel() 49 | <-pubsub.Subscribe(ctx, func(payload interface{}) {}) 50 | } 51 | -------------------------------------------------------------------------------- /random/useragent/browser.go: -------------------------------------------------------------------------------- 1 | package useragent 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | ) 7 | 8 | var ( 9 | chromeVersions = []string{ 10 | "60.0.3112.113", 11 | "63.0.3239.132", 12 | "67.0.3396.99", 13 | "69.0.3497.100", 14 | "72.0.3626.121", 15 | "74.0.3729.169", 16 | "79.0.3945.88", 17 | "80.0.3987.163", 18 | "81.0.4044.138", 19 | "83.0.4103.116", 20 | "84.0.4147.135", 21 | "85.0.4183.102", 22 | } 23 | ) 24 | 25 | func Chrome() string { 26 | return fmt.Sprintf("Mozilla/5.0 (%s) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/%s Safari/537.36", Platform(), chromeVersions[rand.Intn(len(chromeVersions))]) 27 | } 28 | 29 | var ( 30 | edgeVersions = []string{ 31 | "79.0.522.52", 32 | "80.0.522.52", 33 | "81.0.522.52", 34 | "82.0.522.52", 35 | "83.0.522.52", 36 | "85.0.564.44", 37 | } 38 | ) 39 | 40 | func Edge() string { 41 | return fmt.Sprintf("%s Edg/%s", Chrome(), edgeVersions[rand.Intn(len(edgeVersions))]) 42 | } 43 | 44 | func Firefox() string { 45 | return fmt.Sprintf("Mozilla/5.0 (%s) Gecko/20100101 Firefox/%d.0", Platform(), 70+rand.Intn(10)) 46 | } 47 | -------------------------------------------------------------------------------- /random/useragent/platform.go: -------------------------------------------------------------------------------- 1 | package useragent 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | ) 7 | 8 | func Platform() string { 9 | switch rand.Intn(3) { 10 | case 1: 11 | return MacOS() 12 | case 2: 13 | return Linux() 14 | default: 15 | return Windows() 16 | } 17 | } 18 | 19 | func Windows() string { 20 | return "Windows NT 10.0; Win64; x64" 21 | } 22 | 23 | func MacOS() string { 24 | return fmt.Sprintf("Macintosh; Intel 
Mac OS X 10.%d", 11+rand.Intn(3)) 25 | } 26 | 27 | var ( 28 | linuxDistributions = []string{ 29 | "Ubuntu", 30 | "U", 31 | "Arch Linux", 32 | } 33 | ) 34 | 35 | func Linux() string { 36 | return fmt.Sprintf("X11; %s; Linux x86_64", linuxDistributions[rand.Intn(len(linuxDistributions))]) 37 | } 38 | -------------------------------------------------------------------------------- /random/useragent/platform_test.go: -------------------------------------------------------------------------------- 1 | package useragent 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestPlatform(t *testing.T) { 8 | for i := 0; i < 100; i++ { 9 | platform := Platform() 10 | if platform == "" { 11 | t.Fatal("Empty platform") 12 | } 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /random/useragent/useragent.go: -------------------------------------------------------------------------------- 1 | package useragent 2 | 3 | import ( 4 | "math/rand" 5 | ) 6 | 7 | func UserAgent() string { 8 | switch rand.Intn(3) { 9 | case 1: 10 | return Chrome() 11 | case 2: 12 | return Edge() 13 | default: 14 | return Firefox() 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /random/useragent/useragent_test.go: -------------------------------------------------------------------------------- 1 | package useragent 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestUserAgent(t *testing.T) { 8 | for i := 0; i < 100; i++ { 9 | ua := UserAgent() 10 | if ua == "" { 11 | t.Fatal("Empty user agent") 12 | } 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /score/score.go: -------------------------------------------------------------------------------- 1 | package score 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | "sync/atomic" 7 | ) 8 | 9 | type ScoreTag string 10 | type ScoreTable map[ScoreTag]int64 11 | 12 | type sumTable map[ScoreTag]*int64 13 | 14 | type 
Score struct { 15 | Table ScoreTable 16 | DefaultScoreMagnification int64 17 | 18 | mu sync.RWMutex 19 | cmu sync.RWMutex 20 | total sumTable 21 | count int32 22 | queue chan ScoreTag 23 | closed uint32 24 | } 25 | 26 | func NewScore(ctx context.Context) *Score { 27 | score := &Score{ 28 | Table: make(ScoreTable), 29 | DefaultScoreMagnification: 0, 30 | mu: sync.RWMutex{}, 31 | cmu: sync.RWMutex{}, 32 | total: make(sumTable), 33 | count: 0, 34 | queue: make(chan ScoreTag), 35 | closed: 0, 36 | } 37 | 38 | go score.collect(ctx) 39 | 40 | return score 41 | } 42 | 43 | func (s *Score) add(tag ScoreTag) { 44 | s.mu.Lock() 45 | defer s.mu.Unlock() 46 | 47 | if ptr, ok := s.total[tag]; ok { 48 | atomic.AddInt64(ptr, 1) 49 | } else { 50 | n := int64(1) 51 | s.total[tag] = &n 52 | } 53 | } 54 | 55 | func (s *Score) collect(ctx context.Context) { 56 | go func() { 57 | <-ctx.Done() 58 | s.Close() 59 | }() 60 | 61 | for tag := range s.queue { 62 | s.add(tag) 63 | atomic.AddInt32(&s.count, -1) 64 | } 65 | atomic.AddInt32(&s.count, -1) 66 | } 67 | 68 | func (s *Score) Set(tag ScoreTag, mag int64) { 69 | s.Table[tag] = mag 70 | } 71 | 72 | func (s *Score) Add(tag ScoreTag) { 73 | // Hold the read lock across the closed check and the send so Close (write lock) cannot close the queue in between. 74 | s.cmu.RLock() 75 | defer s.cmu.RUnlock() 76 | 77 | if atomic.LoadUint32(&s.closed) == 0 { 78 | atomic.AddInt32(&s.count, 1) // count before sending so collect's decrement can never run first 79 | s.queue <- tag 80 | } 81 | } 82 | 83 | func (s *Score) Close() { 84 | if atomic.CompareAndSwapUint32(&s.closed, 0, 1) { 85 | atomic.AddInt32(&s.count, 1) 86 | s.cmu.Lock() 87 | close(s.queue) 88 | s.cmu.Unlock() 89 | } 90 | } 91 | 92 | func (s *Score) Wait() { 93 | for atomic.LoadInt32(&s.count) > 0 { 94 | } 95 | } 96 | 97 | func (s *Score) Done() { 98 | s.Close() 99 | s.Wait() 100 | } 101 | 102 | func (s *Score) Breakdown() ScoreTable { 103 | s.mu.RLock() 104 | defer s.mu.RUnlock() 105 | 106 | table 
:= make(ScoreTable) 107 | for tag, ptr := range s.total { 108 | table[tag] = atomic.LoadInt64(ptr) 109 | } 110 | return table 111 | } 112 |
113 | /* Sum returns the weighted total of the counts tallied so far: each tag's count times its magnification in Table, or times DefaultScoreMagnification when the tag has no entry. It does not wait for queued Add calls (Total does). NOTE(review): Set writes Table without taking s.mu, so calling Set concurrently with Sum is a data race. */ func (s *Score) Sum() int64 { 114 | s.mu.RLock() 115 | defer s.mu.RUnlock() 116 | 117 | sum := int64(0) 118 | for tag, ptr := range s.total { 119 | if mag, found := s.Table[tag]; found { 120 | sum += atomic.LoadInt64(ptr) * mag 121 | } else { 122 | sum += atomic.LoadInt64(ptr) * s.DefaultScoreMagnification 123 | } 124 | } 125 | return sum 126 | } 127 | 128 | /* Total finishes collection via Done (close the queue, then spin until the pending counter drops to zero) and returns Sum. Afterwards the score is closed, so later Add calls are ignored. */ func (s *Score) Total() int64 { 129 | s.Done() 130 | return s.Sum() 131 | } 132 | 133 | /* Reset discards every tallied count by replacing the tally table; Table (the magnifications) and the closed state are left unchanged. */ func (s *Score) Reset() { 134 | s.mu.Lock() 135 | defer s.mu.Unlock() 136 | 137 | s.total = make(sumTable) 138 | } 139 | -------------------------------------------------------------------------------- /score/score_test.go: -------------------------------------------------------------------------------- 1 | package score 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | ) 7 | 8 | func TestScoreWithDone(t *testing.T) { 9 | score := NewScore(context.Background()) 10 | score.Set("foo", 2) 11 | score.Set("bar", 1) 12 | 13 | for i := 0; i < 1000; i++ { 14 | score.Add("foo") 15 | score.Add("bar") 16 | score.Add("baz") 17 | } 18 | 19 | score.Done() 20 | 21 | if score.Total() != 3000 { 22 | t.Fatalf("Expected 3000 but got %d", score.Total()) 23 | } 24 | 25 | score.Reset() 26 | if score.Total() != 0 { 27 | t.Fatalf("Expected 0 but got %d", score.Total()) 28 | } 29 | } 30 | 31 | func TestScoreWithContext(t *testing.T) { 32 | ctx, cancel := context.WithCancel(context.Background()) 33 | score := NewScore(ctx) 34 | score.Set("foo", 2) 35 | 36 | for i := 0; i < 1000; i++ { 37 | score.Add("foo") 38 | score.Add("bar") 39 | } 40 | 41 | cancel() 42 | 43 | score.Done() 44 | 45 | score.Add("d") 46 | } 47 | 48 | func TestScoreBreakdown(t *testing.T) { 49 | score := NewScore(context.Background()) 50 | 51 | score.Add("a") 52 | score.Add("b") 53 | score.Add("c") 54 | 55 | score.Done() 56 | 57 | breakdown := score.Breakdown() 58 | if c, ok := breakdown["a"]; !ok || c != int64(1) { 59 | t.Fatalf("Add failed of a: %d", c) 60 | } 61 | if c, ok := 
breakdown["b"]; !ok || c != int64(1) { 62 | t.Fatalf("Add failed of b: %d", c) 63 | } 64 | if c, ok := breakdown["c"]; !ok || c != int64(1) { 65 | t.Fatalf("Add failed of c: %d", c) 66 | } 67 | } 68 | 69 | func BenchmarkScoreCollection(b *testing.B) { 70 | score := NewScore(context.TODO()) 71 | for i := 0; i < b.N; i++ { 72 | score.Add("test") 73 | score.Sum() 74 | } 75 | score.Done() 76 | } 77 | -------------------------------------------------------------------------------- /test/http.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | /* IsSuccessfulResponse reports whether r's status code is in the 2xx range or is 304 Not Modified. */ func IsSuccessfulResponse(r *http.Response) bool { 8 | return (r.StatusCode >= 200 && r.StatusCode <= 299) || r.StatusCode == 304 9 | } 10 | 11 | /* HasExpectedHeader reports whether, for every key in header, r.Header holds exactly the same values in the same order. Keys present only in r.Header are ignored, and an expected key with an empty value list matches only when r has no values for it. */ func HasExpectedHeader(r *http.Response, header http.Header) bool { 12 | for key, values := range header { 13 | actual := r.Header.Values(key) 14 | if len(actual) != len(values) { 15 | return false 16 | } 17 | 18 | for i, v := range values { 19 | if v != actual[i] { 20 | return false 21 | } 22 | } 23 | } 24 | 25 | return true 26 | } 27 | -------------------------------------------------------------------------------- /test/http_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "net/http" 5 | "testing" 6 | ) 7 | 8 | func TestIsSuccessfulResponse(t *testing.T) { 9 | res := &http.Response{} 10 | 11 | expects := map[int]bool{ 12 | 200: true, 13 | 201: true, 14 | 204: true, 15 | 299: true, 16 | 300: false, 17 | 303: false, 18 | 304: true, 19 | 305: false, 20 | 404: false, 21 | } 22 | 23 | for statusCode, ok := range expects { 24 | res.StatusCode = statusCode 25 | if IsSuccessfulResponse(res) != ok { 26 | t.Fatalf("%d: %v / %v", statusCode, IsSuccessfulResponse(res), ok) 27 | } 28 | } 29 | } 30 | 31 | func TestHasExpectedHeader(t *testing.T) { 32 | res := &http.Response{ 33 | Header: make(http.Header), 34 | } 35 | 36 | 
res.Header.Set("X-Drive", "1") 37 | res.Header.Add("X-Drive", "2") 38 | 39 | expected := http.Header{ 40 | "X-Drive": []string{"1", "2"}, 41 | } 42 | 43 | if !HasExpectedHeader(res, expected) { 44 | t.Fatal("header check failed") 45 | } 46 | 47 | notFound := http.Header{ 48 | "X-Not-Found": []string{"value"}, 49 | } 50 | if HasExpectedHeader(res, notFound) { 51 | t.Fatal("header check failed") 52 | } 53 | 54 | invalidLength := http.Header{ 55 | "X-Drive": []string{"1"}, 56 | } 57 | if HasExpectedHeader(res, invalidLength) { 58 | t.Fatal("header check failed") 59 | } 60 | 61 | invalidValue := http.Header{ 62 | "X-Drive": []string{"1", "3"}, 63 | } 64 | if HasExpectedHeader(res, invalidValue) { 65 | t.Fatal("header check failed") 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /worker/worker.go: -------------------------------------------------------------------------------- 1 | package worker 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | "sync/atomic" 7 | 8 | "github.com/isucon/isucandar/parallel" 9 | ) 10 | 11 | var ( 12 | nopWorkFunc = func(_ context.Context, _ int) {} 13 | ) 14 | 15 | type WorkerFunc func(context.Context, int) 16 | type WorkerOption func(*Worker) error 17 | 18 | /* Worker calls workFunc through a parallel.Parallel, either a fixed number of times (count >= 1) or until the context is done (count < 1). count and parallelism are accessed atomically; mu guards the parallel field. */ type Worker struct { 19 | mu sync.RWMutex 20 | workFunc WorkerFunc 21 | count int32 22 | parallelism int32 23 | parallel *parallel.Parallel 24 | } 25 | 26 | /* NewWorker returns a Worker that runs f (a no-op when f is nil), with loop count and parallelism both starting at -1, then applies opts in order. The first option that returns an error aborts construction: that error is returned with a nil Worker. */ func NewWorker(f WorkerFunc, opts ...WorkerOption) (*Worker, error) { 27 | count := int32(-1) 28 | parallelism := int32(-1) 29 | 30 | if f == nil { 31 | f = nopWorkFunc 32 | } 33 | 34 | worker := &Worker{ 35 | mu: sync.RWMutex{}, 36 | workFunc: f, 37 | count: count, 38 | parallelism: parallelism, 39 | } 40 | 41 | for _, opt := range opts { 42 | err := opt(worker) 43 | if err != nil { 44 | return nil, err 45 | } 46 | } 47 | 48 | return worker, nil 49 | } 50 | 51 | /* Process runs the worker and returns after submitted work has been waited on. A loop count below 1 (including 0) means run until ctx is done; otherwise workFunc is called count times, stopping early if ctx is done. */ func (w *Worker) Process(ctx context.Context) { 52 | count := atomic.LoadInt32(&w.count) 53 | if count < 1 { 54 | 
w.processInfinity(ctx) 55 | } else { 56 | w.processLimited(ctx, int(count)) 57 | } 58 | } 59 | 60 | /* processInfinity submits workFunc(ctx, -1) repeatedly until ctx is done, then waits for submitted work; it returns immediately if ctx is already done. NOTE(review): this presumably relies on parallel.Do blocking once the parallelism limit is reached, otherwise the loop spins; confirm in the parallel package. The local variable parallel shadows the imported package from its declaration on. */ func (w *Worker) processInfinity(ctx context.Context) { 61 | if ctx.Err() != nil { 62 | return 63 | } 64 | 65 | parallel := parallel.NewParallel(ctx, atomic.LoadInt32(&w.parallelism)) 66 | defer parallel.Close() 67 | w.mu.Lock() 68 | w.parallel = parallel 69 | w.mu.Unlock() 70 | 71 | work := func(ctx context.Context) { 72 | w.workFunc(ctx, -1) 73 | } 74 | 75 | L: 76 | for { 77 | select { 78 | case <-ctx.Done(): 79 | break L 80 | default: 81 | parallel.Do(work) 82 | } 83 | } 84 | 85 | w.Wait() 86 | } 87 | 88 | /* processLimited submits workFunc(ctx, i) for i = 0 .. limit-1, stopping early once ctx is done, then waits for submitted work; it returns immediately if ctx is already done. */ func (w *Worker) processLimited(ctx context.Context, limit int) { 89 | if ctx.Err() != nil { 90 | return 91 | } 92 | 93 | parallel := parallel.NewParallel(ctx, atomic.LoadInt32(&w.parallelism)) 94 | defer parallel.Close() 95 | w.mu.Lock() 96 | w.parallel = parallel 97 | w.mu.Unlock() 98 | 99 | work := func(i int) func(context.Context) { 100 | return func(ctx context.Context) { 101 | w.workFunc(ctx, i) 102 | } 103 | } 104 | 105 | L: 106 | for i := 0; i < limit; i++ { 107 | select { 108 | case <-ctx.Done(): 109 | break L 110 | default: 111 | parallel.Do(work(i)) 112 | } 113 | } 114 | 115 | w.Wait() 116 | } 117 | 118 | /* Wait calls Wait on the Parallel created by the most recent Process run (presumably blocking until its submitted work finishes) and does nothing if Process has never run. It holds the read lock while waiting, so a concurrent Process blocks on w.mu.Lock until Wait returns. */ func (w *Worker) Wait() { 119 | w.mu.RLock() 120 | defer w.mu.RUnlock() 121 | 122 | if w.parallel != nil { 123 | w.parallel.Wait() 124 | } 125 | } 126 | 127 | /* SetLoopCount sets how many times Process calls workFunc; values below 1 mean run until cancelled. Process reads the count once when it starts, so the change applies to the next run. */ func (w *Worker) SetLoopCount(count int32) { 128 | atomic.StoreInt32(&w.count, count) 129 | } 130 | 131 | /* SetParallelism stores the parallelism used for future runs and, if a Parallel already exists, applies it to that instance as well. A read lock is enough because only the parallel field is read here; the value itself is stored atomically. */ func (w *Worker) SetParallelism(parallelism int32) { 132 | w.mu.RLock() 133 | defer w.mu.RUnlock() 134 | 135 | atomic.StoreInt32(&w.parallelism, parallelism) 136 | if w.parallel != nil { 137 | w.parallel.SetParallelism(parallelism) 138 | } 139 | } 140 | 141 | /* AddParallelism changes the parallelism by the given delta. NOTE(review): the load and the store are separate atomic operations, so concurrent calls can lose an update. */ func (w *Worker) AddParallelism(parallelism int32) { 142 | w.SetParallelism(atomic.LoadInt32(&w.parallelism) + parallelism) 143 | } 144 | 145 | /* WithLoopCount is an option that calls SetLoopCount(count). */ func WithLoopCount(count int32) WorkerOption { 146 | return func(w *Worker) error { 147 | 
w.SetLoopCount(count) 148 | return nil 149 | } 150 | } 151 | 152 | /* WithInfinityLoop is an option that sets the loop count to -1, so Process runs until its context is done. */ func WithInfinityLoop() WorkerOption { 153 | return func(w *Worker) error { 154 | w.SetLoopCount(-1) 155 | return nil 156 | } 157 | } 158 | 159 | /* WithMaxParallelism is an option that calls SetParallelism(parallelism). */ func WithMaxParallelism(parallelism int32) WorkerOption { 160 | return func(w *Worker) error { 161 | w.SetParallelism(parallelism) 162 | return nil 163 | } 164 | } 165 | 166 | /* WithUnlimitedParallelism is an option that sets the parallelism to -1 (presumably treated as unlimited by the parallel package; confirm). */ func WithUnlimitedParallelism() WorkerOption { 167 | return func(w *Worker) error { 168 | w.SetParallelism(-1) 169 | return nil 170 | } 171 | } 172 | -------------------------------------------------------------------------------- /worker/worker_test.go: -------------------------------------------------------------------------------- 1 | package worker 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "sync" 7 | "sync/atomic" 8 | "testing" 9 | "time" 10 | ) 11 | 12 | func TestWorker(t *testing.T) { 13 | errOpt := func(_ *Worker) error { 14 | return errors.New("invalid") 15 | } 16 | 17 | worker, err := NewWorker(nil, errOpt) 18 | if err == nil || worker != nil { 19 | t.Fatal("error not occured") 20 | } 21 | 22 | worker, err = NewWorker(nil, WithLoopCount(1)) 23 | if err != nil { 24 | t.Fatal(err) 25 | } 26 | 27 | worker.Process(context.Background()) 28 | } 29 | 30 | func TestWorkerLimited(t *testing.T) { 31 | pool := []int{} 32 | mu := sync.Mutex{} 33 | f := func(_ context.Context, i int) { 34 | mu.Lock() 35 | pool = append(pool, i) 36 | mu.Unlock() 37 | } 38 | 39 | worker, err := NewWorker(f, WithLoopCount(5), WithUnlimitedParallelism()) 40 | if err != nil { 41 | t.Fatal(err) 42 | } 43 | 44 | worker.Process(context.Background()) 45 | 46 | mu.Lock() 47 | defer mu.Unlock() 48 | if len(pool) != 5 { 49 | t.Fatalf("executed count is missmatch: %d", len(pool)) 50 | } 51 | } 52 | 53 | func TestWorkerLimitedCancel(t *testing.T) { 54 | f := func(_ context.Context, _ int) { 55 | <-time.After(100 * time.Millisecond) 56 | } 57 | 58 | worker, err := NewWorker(f, WithLoopCount(100), 
WithMaxParallelism(1)) 59 | if err != nil { 60 | t.Fatal(err) 61 | } 62 | 63 | ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) 64 | defer cancel() 65 | 66 | now := time.Now() 67 | worker.Process(ctx) 68 | diff := time.Now().Sub(now) 69 | 70 | if diff > 1*time.Second { 71 | t.Fatalf("Executed all with %s", diff) 72 | } 73 | } 74 | 75 | func TestWorkerLimitedCanceled(t *testing.T) { 76 | n := int32(0) 77 | count := &n 78 | f := func(_ context.Context, _ int) { 79 | atomic.AddInt32(count, 1) 80 | <-time.After(100 * time.Millisecond) 81 | } 82 | 83 | worker, err := NewWorker(f, WithLoopCount(100), WithMaxParallelism(1)) 84 | if err != nil { 85 | t.Fatal(err) 86 | } 87 | 88 | ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) 89 | cancel() 90 | 91 | worker.Process(ctx) 92 | 93 | if n := atomic.LoadInt32(count); n > 0 { 94 | t.Fatalf("Executed count: %d", n) 95 | } 96 | } 97 | 98 | func TestWorkerUnlimited(t *testing.T) { 99 | n := int32(0) 100 | count := &n 101 | f := func(_ context.Context, i int) { 102 | atomic.AddInt32(count, 1) 103 | } 104 | 105 | worker, err := NewWorker(f, WithInfinityLoop(), WithMaxParallelism(100)) 106 | if err != nil { 107 | t.Fatal(err) 108 | } 109 | 110 | ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) 111 | defer cancel() 112 | worker.Process(ctx) 113 | 114 | if atomic.LoadInt32(count) == 0 { 115 | t.Fatalf("worker not executed") 116 | } 117 | } 118 | 119 | func TestWorkerUnlimitedCanceled(t *testing.T) { 120 | n := int32(0) 121 | count := &n 122 | f := func(_ context.Context, i int) { 123 | atomic.AddInt32(count, 1) 124 | } 125 | 126 | worker, err := NewWorker(f, WithInfinityLoop(), WithMaxParallelism(100)) 127 | if err != nil { 128 | t.Fatal(err) 129 | } 130 | 131 | ctx, cancel := context.WithCancel(context.Background()) 132 | cancel() 133 | worker.Process(ctx) 134 | 135 | if n := atomic.LoadInt32(count); n > 0 { 136 | t.Fatalf("Executed count: %d", n) 137 
| } 138 | } 139 | 140 | func TestWorkerSetLoopCount(t *testing.T) { 141 | var worker *Worker 142 | 143 | count := int32(0) 144 | f := func(_ context.Context, i int) { 145 | atomic.AddInt32(&count, 1) 146 | } 147 | 148 | worker, err := NewWorker(f) 149 | if err != nil { 150 | t.Fatal(err) 151 | } 152 | worker.SetLoopCount(10) 153 | 154 | worker.Process(context.Background()) 155 | 156 | if n := atomic.LoadInt32(&count); n != 10 { 157 | t.Fatalf("Executed count: %d", n) 158 | } 159 | } 160 | 161 | func TestWorkerSetParallelism(t *testing.T) { 162 | var worker *Worker 163 | 164 | count := int32(0) 165 | f := func(_ context.Context, i int) { 166 | atomic.AddInt32(&count, 1) 167 | time.Sleep(100 * time.Millisecond) 168 | } 169 | 170 | worker, err := NewWorker(f, WithMaxParallelism(1)) 171 | if err != nil { 172 | t.Fatal(err) 173 | } 174 | worker.SetLoopCount(10) 175 | 176 | ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) 177 | defer cancel() 178 | 179 | worker.Process(ctx) 180 | worker.Wait() 181 | 182 | if n := atomic.LoadInt32(&count); n > 1 { 183 | t.Fatalf("Executed count: %d", n) 184 | } 185 | 186 | atomic.StoreInt32(&count, 0) 187 | worker.SetParallelism(2) 188 | 189 | ctx2, cancel2 := context.WithTimeout(context.Background(), 50*time.Millisecond) 190 | defer cancel2() 191 | worker.Process(ctx2) 192 | worker.Wait() 193 | 194 | if n := atomic.LoadInt32(&count); n > 2 { 195 | t.Fatalf("Executed count: %d", n) 196 | } 197 | } 198 | 199 | func TestWorkerAddParallelism(t *testing.T) { 200 | var worker *Worker 201 | 202 | count := int32(0) 203 | f := func(_ context.Context, i int) { 204 | atomic.AddInt32(&count, 1) 205 | time.Sleep(100 * time.Millisecond) 206 | } 207 | 208 | worker, err := NewWorker(f, WithMaxParallelism(1)) 209 | if err != nil { 210 | t.Fatal(err) 211 | } 212 | worker.SetLoopCount(10) 213 | 214 | ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) 215 | defer cancel() 216 | 217 | worker.Process(ctx) 
218 | worker.Wait() 219 | 220 | if n := atomic.LoadInt32(&count); n > 1 { 221 | t.Fatalf("Executed count: %d", n) 222 | } 223 | 224 | atomic.StoreInt32(&count, 0) 225 | worker.AddParallelism(1) 226 | 227 | ctx2, cancel2 := context.WithTimeout(context.Background(), 50*time.Millisecond) 228 | defer cancel2() 229 | worker.Process(ctx2) 230 | worker.Wait() 231 | 232 | if n := atomic.LoadInt32(&count); n > 2 { 233 | t.Fatalf("Executed count: %d", n) 234 | } 235 | } 236 | 237 | func BenchmarkWorker(b *testing.B) { 238 | ctx, cancel := context.WithCancel(context.TODO()) 239 | defer cancel() 240 | 241 | nop := func(_ context.Context, _ int) {} 242 | worker, err := NewWorker(nop, WithLoopCount(int32(b.N))) 243 | if err != nil { 244 | b.Fatal(err) 245 | } 246 | 247 | b.ResetTimer() 248 | worker.Process(ctx) 249 | worker.Wait() 250 | b.StopTimer() 251 | } 252 | --------------------------------------------------------------------------------