├── .github ├── dependabot.yml └── workflows │ └── ci.yml ├── .gitignore ├── .golangci.yml ├── CHANGELOG.md ├── LICENSE ├── Makefile ├── README.md ├── arena.go ├── assets ├── README.md ├── go.mod └── logo.png ├── benchext ├── errtrace_test.go ├── go.mod ├── go.sum ├── pkg_errors_test.go ├── pkgerrors_example_http_test.go └── utils_for_test.go ├── cmd └── errtrace │ ├── .gitignore │ ├── main.go │ ├── main_test.go │ ├── testdata │ ├── golden │ │ ├── already_imported.go │ │ ├── already_imported.go.golden │ │ ├── closure.go │ │ ├── closure.go.golden │ │ ├── error_wrap.go │ │ ├── error_wrap.go.golden │ │ ├── imported_blank.go │ │ ├── imported_blank.go.golden │ │ ├── imported_with_alias.go │ │ ├── imported_with_alias.go.golden │ │ ├── name_already_taken.go │ │ ├── name_already_taken.go.golden │ │ ├── named_returns.go │ │ ├── named_returns.go.golden │ │ ├── nested.go │ │ ├── nested.go.golden │ │ ├── no-wrapn.go │ │ ├── no-wrapn.go.golden │ │ ├── no_import.go │ │ ├── no_import.go.golden │ │ ├── noop.go │ │ ├── noop.go.golden │ │ ├── optout.go │ │ ├── optout.go.golden │ │ ├── simple.go │ │ ├── simple.go.golden │ │ ├── tuple_rhs.go │ │ ├── tuple_rhs.go.golden │ │ ├── wrapn.go │ │ └── wrapn.go.golden │ ├── main │ │ ├── foo │ │ │ └── foo.go │ │ ├── go.mod │ │ ├── go.sum │ │ └── main.go │ └── toolexec-test │ │ ├── main.go │ │ ├── main_test.go │ │ ├── p1 │ │ └── p1.go │ │ ├── p2 │ │ └── p2.go │ │ └── p3 │ │ ├── errtrace.go │ │ └── p3.go │ ├── toolexec.go │ └── toolexec_test.go ├── codecov.yml ├── errors.go ├── errtrace.go ├── errtrace_line_test.go ├── errtrace_test.go ├── example_errhelper_test.go ├── example_http_test.go ├── example_trace_test.go ├── example_tree_test.go ├── go.mod ├── internal ├── diff │ └── diff.go ├── pc │ ├── pc_amd64.s │ ├── pc_arm64.s │ ├── pc_asm.go │ ├── pc_safe.go │ └── pc_test.go └── tracetest │ ├── clean.go │ ├── clean_2_test.go │ └── clean_test.go ├── tree.go ├── tree_test.go ├── unwrap.go ├── unwrap_test.go ├── wrap.go ├── wrap_caller.go ├── 
wrap_caller_safe_test.go └── wrap_caller_test.go /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "github-actions" 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" 7 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: ['**'] 8 | types: 9 | # On by default if types not specified: 10 | - "opened" 11 | - "reopened" 12 | - "synchronize" 13 | 14 | # For `skip changelog` handling: 15 | - "labeled" 16 | - "unlabeled" 17 | 18 | permissions: 19 | contents: read 20 | 21 | env: 22 | # Use the Go toolchain installed by setup-go 23 | # https://github.com/actions/setup-go/issues/457 24 | GOTOOLCHAIN: local 25 | 26 | jobs: 27 | 28 | test: 29 | name: Test / Go ${{ matrix.go-version }} / ${{ matrix.os }}/${{ matrix.arch }} 30 | runs-on: ${{ matrix.os }} 31 | strategy: 32 | matrix: 33 | go-version: ['1.23.x', '1.22.x', '1.21.x'] 34 | arch: ['amd64', '386', 'arm64'] 35 | os: ['ubuntu-latest'] 36 | include: 37 | - os: 'macos-latest' 38 | arch: 'amd64' 39 | go-version: '1.23.x' 40 | - os: 'windows-latest' 41 | arch: 'amd64' 42 | go-version: '1.23.x' 43 | 44 | steps: 45 | - name: Checkout code 46 | uses: actions/checkout@v4 47 | 48 | - name: Setup Go 49 | uses: actions/setup-go@v5 50 | with: 51 | go-version: ${{ matrix.go-version }} 52 | 53 | # GH runners use amd64, which also supports 386. 54 | # For other architectures, use qemu. 55 | - name: Install QEMU 56 | if: matrix.arch != 'amd64' && matrix.arch != '386' 57 | uses: docker/setup-qemu-action@v3 58 | 59 | - name: Enable race detection 60 | shell: bash 61 | run: | 62 | # Only amd64 supports data-race detection in CI. 63 | # qemu doesn't give us cgo, which is needed for arm64.
64 | if [[ "$GOARCH" == amd64 ]]; then 65 | echo "Enabling data-race detection." 66 | else 67 | echo "NO_RACE=1" >> "$GITHUB_ENV" 68 | fi 69 | env: 70 | GOARCH: ${{ matrix.arch }} 71 | 72 | - name: Test ${{ matrix.arch }} 73 | run: make cover 74 | shell: bash 75 | env: 76 | GOARCH: ${{ matrix.arch }} 77 | 78 | - name: Coverage 79 | uses: codecov/codecov-action@v5 80 | with: 81 | files: ./cover.unsafe.out,./cover.safe.out 82 | 83 | lint: 84 | name: Lint 85 | runs-on: ubuntu-latest 86 | 87 | steps: 88 | - uses: actions/checkout@v4 89 | name: Check out repository 90 | 91 | - uses: actions/setup-go@v5 92 | name: Set up Go 93 | with: 94 | # Use the Go language version in go.mod for linting. 95 | go-version-file: go.mod 96 | cache: false # managed by golangci-lint 97 | 98 | - uses: golangci/golangci-lint-action@v6 99 | name: Install golangci-lint 100 | with: 101 | version: latest 102 | args: --help # make lint will run the linter 103 | 104 | - name: Lint 105 | run: make lint GOLANGCI_LINT_ARGS=--out-format=github-actions 106 | # Write in a GitHub Actions-friendly format 107 | # to annotate lines in the PR. 108 | 109 | changelog: 110 | runs-on: ubuntu-latest 111 | steps: 112 | - name: "Check CHANGELOG is updated or PR is marked skip changelog" 113 | uses: brettcannon/check-for-changed-files@v1.2.1 114 | # Run only if PR body doesn't contain '[skip changelog]'. 115 | if: ${{ !contains(github.event.pull_request.body, '[skip changelog]') }} 116 | with: 117 | file-pattern: CHANGELOG.md 118 | skip-label: "skip changelog" 119 | token: ${{ secrets.GITHUB_TOKEN }} 120 | failure-message: >- 121 | Missing a changelog update ${file-pattern}; please update or 122 | if a changelog entry is not needed, use label ${skip-label} 123 | or add [skip changelog] to the PR description. 124 | 125 | # ci-ok is a dummy job that runs after the test, lint, and changelog jobs. 126 | # It provides a job for us to attach a Required Status Check to.
127 | ci-ok: 128 | name: OK 129 | runs-on: ubuntu-latest 130 | needs: [test, lint, changelog] 131 | steps: 132 | - name: Success 133 | run: echo "All checks passed." 134 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /bin 2 | /cover*.out 3 | /cover.html 4 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | output: 2 | # Make output more digestible with quickfix in vim/emacs/etc. 3 | sort-results: true 4 | print-issued-lines: false 5 | 6 | linters: 7 | # We'll track the golangci-lint default linters manually 8 | # instead of letting them change without our control. 9 | disable-all: true 10 | enable: 11 | # golangci-lint defaults: 12 | - errcheck 13 | - gosimple 14 | - govet 15 | - ineffassign 16 | - staticcheck 17 | - unused 18 | 19 | # Our own extras: 20 | - gofumpt 21 | - nolintlint # lints nolint directives 22 | - revive 23 | 24 | linters-settings: 25 | govet: 26 | # These govet checks are disabled by default, but they're useful. 27 | enable: 28 | - niliness 29 | - reflectvaluecompare 30 | - sortslice 31 | - unusedwrite 32 | 33 | errcheck: 34 | exclude-functions: 35 | # Writing a plain string to a fmt.State cannot fail. 36 | - io.WriteString(fmt.State) 37 | - fmt.Fprintf(fmt.State) 38 | 39 | issues: 40 | # Print all issues reported by all linters. 41 | max-issues-per-linter: 0 42 | max-same-issues: 0 43 | 44 | # Don't ignore some of the issues that golangci-lint considers okay. 45 | # This includes documenting all exported entities. 46 | exclude-use-default: false 47 | 48 | exclude-rules: 49 | # Don't warn on unused parameters. 50 | # Parameter names are useful; replacing them with '_' is undesirable. 
51 | - linters: [revive] 52 | text: 'unused-parameter: parameter \S+ seems to be unused, consider removing or renaming it as _' 53 | 54 | # staticcheck already has smarter checks for empty blocks. 55 | # revive's empty-block linter has false positives. 56 | # For example, as of writing this, the following is not allowed. 57 | # for foo() { } 58 | - linters: [revive] 59 | text: 'empty-block: this block is empty, you can remove it' 60 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), 6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 7 | 8 | ## Unreleased 9 | ### Added 10 | - Add `UnwrapFrame` function to extract a single frame from an error. 11 | You can use this to implement your own trace formatting logic. 12 | - Support extracting trace frames from custom errors. 13 | Any error value that implements `TracePC() uintptr` will now 14 | contribute to the trace. 15 | - Add `GetCaller` function for error helpers to annotate wrapped errors with 16 | their caller information instead of the helper. Example: 17 | 18 | ```go 19 | //go:noinline 20 | func Wrapf(err error, msg string, args ...any) error { 21 | caller := errtrace.GetCaller() 22 | err = ... // annotate err with msg and args 23 | return caller.Wrap(err) 24 | } 25 | ``` 26 | 27 | - cmd/errtrace: 28 | Add `-no-wrapn` option to disable wrapping with generic `WrapN` functions. 29 | This is only useful for toolexec mode due to tooling limitations. 30 | - cmd/errtrace: 31 | Experimental support for instrumenting code with errtrace automatically 32 | as part of the Go build process. 33 | Try this out with `go build -toolexec=errtrace pkg/to/build`.
34 | Automatic instrumentation only rewrites packages that import errtrace. 35 | The flag `-required-packages` can be used to specify which packages 36 | are expected to import errtrace if they require rewrites. 37 | Example: `go build -toolexec="errtrace -required-packages pkg/..." pkg/to/build` 38 | 39 | ### Changed 40 | - Update `go` directive in go.mod to 1.21, and drop compatibility with Go 1.20 and earlier. 41 | 42 | ### Fixed 43 | - cmd/errtrace: Don't exit with a non-zero status when `-h` is used. 44 | - cmd/errtrace: Don't panic on imbalanced assignments inside defer blocks. 45 | 46 | ## v0.3.0 - 2023-12-22 47 | 48 | This release adds support to the CLI for using Go package patterns like `./...` 49 | to match and transform files. 50 | You can now use `errtrace -w ./...` to instrument all files in a Go module, 51 | or `errtrace -l ./...` to list all files that would be changed. 52 | 53 | ### Added 54 | - cmd/errtrace: Support Go package patterns in addition to file paths. 55 | Use `errtrace -w ./...` to transform all files under the current package 56 | and its descendants. 57 | 58 | ### Changed 59 | - cmd/errtrace: 60 | Print a message when reading from stdin because no arguments were given. 61 | Use '-' as the file name to read from stdin without a warning. 62 | 63 | ## v0.2.0 - 2023-11-30 64 | 65 | This release contains minor improvements to the errtrace code transformer 66 | allowing it to fit more use cases. 67 | 68 | ### Added 69 | - cmd/errtrace: 70 | Add -l flag to print files that would be changed without changing them. 71 | You can use this to build a check to verify that your code is instrumented. 72 | - cmd/errtrace: Support opt-out on lines with a `//errtrace:skip` comment. 73 | Optionally, a reason may be specified alongside the comment. 74 | The command will print a warning for any unused `//errtrace:skip` comments. 
75 | 76 | ```go 77 | if err != nil { 78 | return io.EOF //errtrace:skip(io.Reader expects io.EOF) 79 | } 80 | ``` 81 | 82 | ## v0.1.1 - 2023-11-28 83 | ### Changed 84 | - Lower `go` directive in go.mod to 1.20 85 | to allow use with older versions. 86 | 87 | ### Fixed 88 | - Add a README.md to render alongside the 89 | [API reference](https://pkg.go.dev/braces.dev/errtrace). 90 | 91 | ## v0.1.0 - 2023-11-28 92 | 93 | Introducing errtrace, an experimental library 94 | that provides better stack traces for your errors. 95 | 96 | Install the library with: 97 | 98 | ```bash 99 | go get braces.dev/errtrace@v0.1.0 100 | ``` 101 | 102 | We've also included a tool 103 | that will automatically instrument your code with errtrace. 104 | In your project, run: 105 | 106 | ```bash 107 | go install braces.dev/errtrace/cmd/errtrace@v0.1.0 108 | git ls-files -- '*.go' | xargs errtrace -w 109 | ``` 110 | 111 | See [README](https://github.com/bracesdev/errtrace#readme) 112 | for more information. 113 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2023, The braces.dev maintainers 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 
18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | SHELL := /bin/bash 2 | PROJECT_ROOT = $(dir $(abspath $(lastword $(MAKEFILE_LIST)))) 3 | 4 | # 'go install' into the project's bin directory 5 | # and add it to the PATH. 6 | export GOBIN ?= $(PROJECT_ROOT)/bin 7 | export PATH := $(GOBIN):$(PATH) 8 | 9 | ERRTRACE = $(GOBIN)/errtrace 10 | 11 | # Packages to instrument with errtrace relative to the project root. 12 | ERRTRACE_PKGS = ./cmd/errtrace/... 13 | 14 | # only use -race if NO_RACE is unset. 15 | RACE=$(if $(NO_RACE),,-race) 16 | 17 | GOLANGCI_LINT_ARGS ?= 18 | 19 | .PHONY: test 20 | test: 21 | go test $(RACE) ./... 22 | go test $(RACE) -tags safe ./... 23 | go test -gcflags='-l -N' ./... # disable optimizations/inlining 24 | 25 | .PHONY: cover 26 | cover: 27 | go test -coverprofile cover.unsafe.out -coverpkg ./... $(RACE) ./... 28 | go test -coverprofile cover.safe.out -coverpkg ./... $(RACE) -tags safe ./... 29 | go test -gcflags='-l -N' ./...
# disable optimizations/inlining 30 | 31 | .PHONY: bench 32 | bench: 33 | go test -run NONE -bench . -cpu 1 34 | 35 | .PHONY: bench-parallel 36 | bench-parallel: 37 | go test -run NONE -bench . -cpu 1,2,4,8 38 | 39 | .PHONY: lint 40 | lint: golangci-lint errtrace-lint 41 | 42 | .PHONY: golangci-lint 43 | golangci-lint: 44 | @if ! command -v golangci-lint >/dev/null; then \ 45 | echo "golangci-lint not found. Installing..."; \ 46 | mkdir -p $(GOBIN); \ 47 | curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(GOBIN); \ 48 | fi; \ 49 | echo "Running golangci-lint"; \ 50 | golangci-lint run $(GOLANGCI_LINT_ARGS) ./... 51 | 52 | .PHONY: errtrace 53 | errtrace: $(ERRTRACE) 54 | $(ERRTRACE) -w $(ERRTRACE_PKGS) 55 | 56 | .PHONY: errtrace-lint 57 | errtrace-lint: $(ERRTRACE) 58 | @echo "Running errtrace"; \ 59 | changed=$$($(ERRTRACE) -l $(ERRTRACE_PKGS)); \ 60 | if [[ -n "$$changed" ]]; then \ 61 | echo "Found uninstrumented files. Please run 'make errtrace'"; \ 62 | echo "$$changed"; \ 63 | exit 1; \ 64 | fi 65 | 66 | $(ERRTRACE): 67 | go install braces.dev/errtrace/cmd/errtrace 68 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # errtrace 2 | 3 | *What if every function added its location to returned errors?* 4 | 5 | ![errtrace logo](assets/logo.png) 6 | 7 | [![CI](https://github.com/bracesdev/errtrace/actions/workflows/ci.yml/badge.svg)](https://github.com/bracesdev/errtrace/actions/workflows/ci.yml) 8 | [![Go Reference](https://pkg.go.dev/badge/braces.dev/errtrace.svg)](https://pkg.go.dev/braces.dev/errtrace) 9 | [![codecov](https://codecov.io/gh/bracesdev/errtrace/graph/badge.svg?token=KDY04XEEJ9)](https://codecov.io/gh/bracesdev/errtrace) 10 | 11 | - [Introduction](#introduction) 12 | - [Features](#features) 13 | - [Comparison with stack traces](#comparison-with-stack-traces) 14 | - [Try 
it out](#try-it-out) 15 | - [Why is this useful](#why-is-this-useful) 16 | - [Installation](#installation) 17 | - [Usage](#usage) 18 | - [Manual instrumentation](#manual-instrumentation) 19 | - [Automatic instrumentation](#automatic-instrumentation) 20 | - [Performance](#performance) 21 | - [Caveats](#caveats) 22 | - [Error wrapping](#error-wrapping) 23 | - [Safety](#safety) 24 | - [Contributing](#contributing) 25 | - [Acknowledgements](#acknowledgements) 26 | - [License](#license) 27 | 28 | ## Introduction 29 | 30 | errtrace is an **experimental** package to trace an error's return path — 31 | the return trace — through a Go program. 32 | 33 | Where a stack trace tracks the code path that led to an error, 34 | a return trace tracks the code path that the error took to get to the user. 35 | Often these are the same path, but in Go they can diverge, 36 | since errors are values that can be transported across goroutines 37 | (e.g. with channels). 38 | When that happens, a return trace can be more useful than a stack trace. 39 | 40 | This library is inspired by 41 | [Zig's error return traces](https://ziglang.org/documentation/0.11.0/#Error-Return-Traces). 42 | 43 | ### Features 44 | 45 | * **Lightweight**\ 46 | errtrace brings no other runtime dependencies with it. 47 | * **[Simple](#manual-instrumentation)**\ 48 | The library API is simple, straightforward, and idiomatic. 49 | * **[Easy](#automatic-instrumentation)**\ 50 | The errtrace CLI will automatically instrument your code. 51 | * **[Fast](#performance)**\ 52 | On popular 64-bit systems, 53 | errtrace is much faster than capturing a stack trace. 54 | 55 | ### Comparison with stack traces 56 | 57 | With stack traces, caller information for the goroutine is 58 | captured once when the error is created. 59 | 60 | In contrast, errtrace records the caller information incrementally, 61 | following the return path the error takes to get to the user.
62 | This approach works even if the error isn't propagated directly 63 | through function returns, or if it crosses goroutines. 64 | 65 | Both approaches look similar when the error flows 66 | through function calls within the same goroutine, 67 | but can differ significantly when errors are passed 68 | outside of functions and across goroutines (e.g., channels). 69 | 70 | Here's a real-world example that shows the benefits 71 | of errtrace tracing the return path 72 | by comparing a custom dial error returned for an HTTP request, 73 | for which the net/http library uses a background goroutine. 74 | 75 |
76 | errtrace compared to a stack trace 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 114 | 115 | 116 | 117 | 118 | 119 |
errtracestack trace
85 | 86 | ``` 87 | Error: connect rate limited 88 | 89 | braces.dev/errtrace_test.rateLimitDialer 90 | /path/to/errtrace/example_http_test.go:72 91 | braces.dev/errtrace_test.(*PackageStore).updateIndex 92 | /path/to/errtrace/example_http_test.go:59 93 | braces.dev/errtrace_test.(*PackageStore).Get 94 | /path/to/errtrace/example_http_test.go:49 95 | ``` 96 | 97 | 98 | 99 | ``` 100 | Error: connect rate limited 101 | braces.dev/errtrace_test.rateLimitDialer 102 | /errtrace/example_stack_test.go:81 103 | net/http.(*Transport).dial 104 | /goroot/src/net/http/transport.go:1190 105 | net/http.(*Transport).dialConn 106 | /goroot/src/net/http/transport.go:1625 107 | net/http.(*Transport).dialConnFor 108 | /goroot/src/net/http/transport.go:1467 109 | runtime.goexit 110 | /goroot/src/runtime/asm_arm64.s:1197 111 | ``` 112 | 113 |
errtrace reports the method that triggered the HTTP requeststack trace shows details of how the HTTP client creates a connection
120 | 121 |
122 | 123 | errtrace also reduces the performance impact 124 | of capturing caller information for errors 125 | that are handled rather than returned to the user, 126 | as the information is captured incrementally. 127 | Stack traces pay a fixed cost to capture caller information 128 | even if the error is handled immediately by the caller 129 | close to where the error is created. 130 | 131 | ### Try it out 132 | 133 | Try out errtrace with your own code: 134 | 135 | 1. Install the CLI. 136 | 137 | ```bash 138 | go install braces.dev/errtrace/cmd/errtrace@latest 139 | ``` 140 | 2. Switch to your Git repository and instrument your code. 141 | 142 | ```bash 143 | errtrace -w ./... 144 | ``` 145 | 3. Let `go mod tidy` install the errtrace Go module for you. 146 | 147 | ```bash 148 | go mod tidy 149 | ``` 150 | 4. Run your tests to ensure everything still works. 151 | You may see failures 152 | if you're comparing errors with `==` on critical paths 153 | or if you're type-casting errors directly. 154 | See [Error wrapping](#error-wrapping) for more details. 155 | 156 | ```bash 157 | go test ./... 158 | ``` 159 | 5. Print return traces for errors in your code. 160 | To do this, you can use the `errtrace.FormatString` function 161 | or format the error with `%+v` in `fmt.Printf`-style functions. 162 | 163 | ```go 164 | if err != nil { 165 | fmt.Fprintf(os.Stderr, "%+v", err) 166 | } 167 | ``` 168 | 169 | Return traces printed by errtrace 170 | will include the error message 171 | and the path the error took until it was printed. 172 | The output will look roughly like this: 173 | 174 | ``` 175 | error message 176 | 177 | example.com/myproject.MyFunc 178 | /home/user/myproject/myfile.go:123 179 | example.com/myproject.CallerOfMyFunc 180 | /home/user/myproject/another_file.go:456 181 | [...] 182 | ``` 183 | 184 | Here's a real-world example of errtrace in action: 185 | 186 |
187 | Example 188 | 189 | ``` 190 | doc2go: parse file: /path/to/project/example/foo.go:3:1: expected declaration, found invalid 191 | 192 | go.abhg.dev/doc2go/internal/gosrc.parseFiles 193 | /path/to/project/internal/gosrc/parser.go:85 194 | go.abhg.dev/doc2go/internal/gosrc.(*Parser).ParsePackage 195 | /path/to/project/internal/gosrc/parser.go:44 196 | main.(*Generator).renderPackage 197 | /path/to/project/generate.go:193 198 | main.(*Generator).renderTree 199 | /path/to/project/generate.go:141 200 | main.(*Generator).renderTrees 201 | /path/to/project/generate.go:118 202 | main.(*Generator).renderPackageIndex 203 | /path/to/project/generate.go:149 204 | main.(*Generator).renderTree 205 | /path/to/project/generate.go:137 206 | main.(*Generator).renderTrees 207 | /path/to/project/generate.go:118 208 | main.(*Generator).renderPackageIndex 209 | /path/to/project/generate.go:149 210 | main.(*Generator).renderTree 211 | /path/to/project/generate.go:137 212 | main.(*Generator).renderTrees 213 | /path/to/project/generate.go:118 214 | main.(*Generator).Generate 215 | /path/to/project/generate.go:110 216 | main.(*mainCmd).run 217 | /path/to/project/main.go:199 218 | ``` 219 | 220 | Note that some functions repeat in this trace 221 | because the functions are mutually recursive. 222 |
223 | 224 | 225 | ### Why is this useful? 226 | 227 | In Go, [errors are values](https://go.dev/blog/errors-are-values). 228 | This means that an error can be passed around like any other value. 229 | You can store it in a struct, pass it through a channel, etc. 230 | This level of flexibility is great, 231 | but it can also make it difficult to track down the source of an error. 232 | A stack trace stored in an error — recorded at the error site — 233 | becomes less useful as the error moves through the program. 234 | When it's eventually surfaced to the user, 235 | we've lost a lot of context about its origin. 236 | 237 | With errtrace, 238 | we instead record the path the program took from the error site 239 | to get to the user — the **return trace**. 240 | Not only can this be more useful than a stack trace, 241 | it tends to be much faster and more lightweight as well. 242 | 243 | ## Installation 244 | 245 | Install errtrace with Go modules: 246 | 247 | ```bash 248 | go get braces.dev/errtrace@latest 249 | ``` 250 | 251 | If you want to use the CLI, use `go install`. 252 | 253 | ```bash 254 | go install braces.dev/errtrace/cmd/errtrace@latest 255 | ``` 256 | 257 | ## Usage 258 | 259 | errtrace offers the following modes of usage: 260 | 261 | * [Manual instrumentation](#manual-instrumentation) 262 | * [Automatic instrumentation](#automatic-instrumentation) 263 | 264 | ### Manual instrumentation 265 | 266 | ```go 267 | import "braces.dev/errtrace" 268 | ``` 269 | 270 | Under manual instrumentation, 271 | you're expected to import errtrace, 272 | and wrap errors at all return sites like so: 273 | 274 | ```go 275 | // ... 276 | if err != nil { 277 | return errtrace.Wrap(err) 278 | } 279 | ``` 280 | 281 |
282 | Example 283 | 284 | Given a function like the following: 285 | 286 | ```go 287 | func writeToFile(path string, src io.Reader) error { 288 | dst, err := os.Create(path) 289 | if err != nil { 290 | return err 291 | } 292 | defer dst.Close() 293 | 294 | if _, err := io.Copy(dst, src); err != nil { 295 | return err 296 | } 297 | 298 | return nil 299 | } 300 | ``` 301 | 302 | With errtrace, you'd change it to: 303 | 304 | ```go 305 | func writeToFile(path string, src io.Reader) error { 306 | dst, err := os.Create(path) 307 | if err != nil { 308 | return errtrace.Wrap(err) 309 | } 310 | defer dst.Close() 311 | 312 | if _, err := io.Copy(dst, src); err != nil { 313 | return errtrace.Wrap(err) 314 | } 315 | 316 | return nil 317 | } 318 | ``` 319 | 320 | It's important that the `errtrace.Wrap` function is called 321 | inside the same function that's actually returning the error. 322 | A helper function will not suffice. 323 |
324 | 325 | ### Automatic instrumentation 326 | 327 | If manual instrumentation is too much work (we agree), 328 | we've included a tool that will automatically instrument 329 | all your code with errtrace. 330 | 331 | First, [install the tool](#installation). 332 | Then, run it on your code: 333 | 334 | ```bash 335 | errtrace -w path/to/file.go path/to/another/file.go 336 | ``` 337 | 338 | Instead of specifying individual files, 339 | you can also specify a Go package pattern. 340 | For example: 341 | 342 | ```bash 343 | errtrace -w example.com/path/to/package 344 | errtrace -w ./... 345 | ``` 346 | 347 | errtrace can be set up as a custom formatter in your editor, 348 | similar to gofmt or goimports. 349 | 350 | #### Opting-out during automatic instrumentation 351 | 352 | If you're relying on automatic instrumentation 353 | and want to exclude specific lines from instrumentation, 354 | you can add a comment in one of the following forms 355 | on relevant lines: 356 | 357 | ```go 358 | //errtrace:skip 359 | //errtrace:skip(explanation) 360 | //errtrace:skip // explanation 361 | ``` 362 | 363 | This can be especially useful if the returned error 364 | has to match another error exactly because the caller still uses `==`. 365 | 366 | For example, if you're implementing `io.Reader`, 367 | you need to return `io.EOF` when you reach the end of the input. 368 | Wrapping it will cause functions like `io.ReadAll` to misbehave. 369 | 370 | ```go 371 | type myReader struct{/* ... */} 372 | 373 | func (*myReader) Read(bs []byte) (int, error) { 374 | // ... 375 | return 0, io.EOF //errtrace:skip(io.Reader expects io.EOF) 376 | } 377 | ``` 378 | 379 | ## Performance 380 | 381 | errtrace is designed to have very low overhead 382 | on [supported systems](#supported-systems).
383 | 384 | Benchmark results for linux/amd64 on an Intel Core i5-13600 (best of 10): 385 | 386 | ``` 387 | BenchmarkFmtErrorf 11574928 103.5 ns/op 40 B/op 2 allocs/op 388 | # default build, uses Go assembly. 389 | BenchmarkWrap 78173496 14.70 ns/op 24 B/op 0 allocs/op 390 | # build with -tags safe to avoid assembly. 391 | BenchmarkWrap 5958579 198.5 ns/op 24 B/op 0 allocs/op 392 | 393 | # benchext compares capturing stacks using pkg/errors vs errtrace 394 | # both tests capture ~10 frames, 395 | BenchmarkErrtrace 6388651 188.4 ns/op 280 B/op 1 allocs/op 396 | BenchmarkPkgErrors 1673145 716.8 ns/op 304 B/op 3 allocs/op 397 | ``` 398 | 399 | Stack traces have a large initial cost, 400 | while errtrace scales with each frame that an error is returned through. 401 | 402 | ## Caveats 403 | 404 | ### Error wrapping 405 | 406 | errtrace operates by wrapping your errors to add caller information. 407 | As a result of this, 408 | error comparisons and type-casting may not work as expected. 409 | You can no longer use `==` to compare errors, or type-cast them directly. 410 | You must use the standard library's 411 | [errors.Is](https://pkg.go.dev/errors#Is) and 412 | [errors.As](https://pkg.go.dev/errors#As) functions. 413 | 414 | For example, if you have a function `readFile` 415 | that wraps an `io.EOF` error with errtrace: 416 | 417 | **Matching errors** 418 | 419 | ```go 420 | err := readFile() // returns errtrace.Wrap(io.EOF) 421 | 422 | // This will not work. 423 | fmt.Println(err == io.EOF) // false 424 | 425 | // Use errors.Is instead. 426 | fmt.Println(errors.Is(err, io.EOF)) // true 427 | ``` 428 | 429 | Similarly, if you have a function `runCmd` 430 | that wraps an `exec.ExitError` error with errtrace: 431 | 432 | **Type-casting errors** 433 | 434 | ```go 435 | err := runCmd() // returns errtrace.Wrap(&exec.ExitError{...}) 436 | 437 | // This will not work. 438 | exitErr, ok := err.(*exec.ExitError) // ok = false 439 | 440 | // Use errors.As instead. 
441 | var exitErr *exec.ExitError 442 | ok := errors.As(err, &exitErr) // ok = true 443 | ``` 444 | 445 | #### Linting 446 | 447 | You can use [go-errorlint](https://github.com/polyfloyd/go-errorlint) 448 | to find places in your code 449 | where you're comparing errors with `==` instead of using `errors.Is` 450 | or type-casting them directly instead of using `errors.As`. 451 | 452 | ### Safety 453 | 454 | To achieve the performance above on [supported systems](#supported-systems), 455 | errtrace makes use of unsafe operations using Go assembly 456 | to read the caller information directly from the stack. 457 | This is part of the reason why we have the disclaimer on top. 458 | 459 | errtrace includes an opt-in safe mode 460 | that drops these unsafe operations in exchange for poorer performance. 461 | To opt into safe mode, 462 | use the `safe` build tag when compiling code that uses errtrace. 463 | 464 | ```bash 465 | go build -tags safe 466 | ``` 467 | 468 | #### Supported systems 469 | 470 | errtrace's unsafe operations are currently implemented 471 | for `GOARCH=amd64` and `GOARCH=arm64` only. 472 | Other systems are supported but they will use safe mode, which is slower. 473 | 474 | Contributions to support unsafe mode for other architectures are welcome. 475 | 476 | ## Contributing 477 | 478 | Contributions are welcome. 479 | However, we ask that before contributing new features, 480 | you [open an issue](https://github.com/bracesdev/errtrace/issues) 481 | to discuss the feature with us. 482 | 483 | ## Acknowledgements 484 | 485 | The idea of tracing return paths instead of stack traces 486 | comes from [Zig's error return traces](https://ziglang.org/documentation/0.11.0/#Error-Return-Traces). 487 | 488 | ## License 489 | 490 | This software is made available under the BSD3 license. 491 | See LICENSE file for details. 
492 | -------------------------------------------------------------------------------- /arena.go: -------------------------------------------------------------------------------- 1 | package errtrace 2 | 3 | import ( 4 | "sync" 5 | ) 6 | 7 | // arena is a lock-free allocator for a fixed-size type. 8 | // It is intended to be used for allocating errTrace objects in batches. 9 | type arena[T any] struct { 10 | slabSize int 11 | pool sync.Pool 12 | } 13 | 14 | func newArena[T any](slabSize int) *arena[T] { 15 | return &arena[T]{ 16 | slabSize: slabSize, 17 | } 18 | } 19 | 20 | // Take returns a pointer to a new object from the arena. 21 | func (a *arena[T]) Take() *T { 22 | for { 23 | slab, ok := a.pool.Get().(*arenaSlab[T]) 24 | if !ok { 25 | slab = newArenaSlab[T](a.slabSize) 26 | } 27 | 28 | if e, ok := slab.take(); ok { 29 | a.pool.Put(slab) 30 | return e 31 | } 32 | } 33 | } 34 | 35 | // arenaSlab is a slab of objects in an arena. 36 | // 37 | // Each slab has a fixed number of objects in it. 38 | // Pointers are taken from the slab in order. 39 | type arenaSlab[T any] struct { 40 | // Full list of objects in the slab. 41 | buf []T 42 | 43 | // Index of the next object to be taken. 44 | idx int 45 | } 46 | 47 | func newArenaSlab[T any](sz int) *arenaSlab[T] { 48 | return &arenaSlab[T]{buf: make([]T, sz)} 49 | } 50 | 51 | func (a *arenaSlab[T]) take() (*T, bool) { 52 | if a.idx >= len(a.buf) { 53 | return nil, false 54 | } 55 | ptr := &a.buf[a.idx] 56 | a.idx++ 57 | return ptr, true 58 | } 59 | -------------------------------------------------------------------------------- /assets/README.md: -------------------------------------------------------------------------------- 1 | The logo for errtrace is made available under the 2 | [Creative Commons 4.0 Attribution License](https://creativecommons.org/licenses/by/4.0/). 
3 | 4 | It is based on the Go Gopher mascot originally created by Renee French, 5 | which is also licensed under the Creative Commons 4.0 Attribution License. 6 | -------------------------------------------------------------------------------- /assets/go.mod: -------------------------------------------------------------------------------- 1 | module braces.dev/errtrace/assets 2 | 3 | go 1.21 4 | 5 | // This go.mod exists to avoid having the assets directory 6 | // be shipped as part of the Go module. 7 | -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bracesdev/errtrace/5dc94991cda96f6c3ede5f9be8a00ed55933d719/assets/logo.png -------------------------------------------------------------------------------- /benchext/errtrace_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "strings" 5 | "testing" 6 | 7 | "braces.dev/errtrace" 8 | ) 9 | 10 | func recurseErrtrace(n int) error { 11 | if n == 0 { 12 | return errtrace.New("f5 failed") 13 | } 14 | return errtrace.Wrap(recurseErrtrace(n - 1)) 15 | } 16 | 17 | func BenchmarkErrtrace(b *testing.B) { 18 | var err error 19 | for i := 0; i < b.N; i++ { 20 | err = recurseErrtrace(10) 21 | } 22 | 23 | if wantMin, got := 10, strings.Count(errtrace.FormatString(err), "errtrace_test.go"); got < wantMin { 24 | b.Fatalf("missing expected stack frames, expected >%v, got %v", wantMin, got) 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /benchext/go.mod: -------------------------------------------------------------------------------- 1 | module braces.dev/errtrace/benchext 2 | 3 | go 1.21 4 | 5 | replace braces.dev/errtrace => ../ 6 | 7 | require ( 8 | braces.dev/errtrace v0.0.0-00010101000000-000000000000 9 | github.com/pkg/errors v0.9.1 10 | ) 
11 | -------------------------------------------------------------------------------- /benchext/go.sum: -------------------------------------------------------------------------------- 1 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 2 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 3 | -------------------------------------------------------------------------------- /benchext/pkg_errors_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | "testing" 7 | 8 | "github.com/pkg/errors" 9 | ) 10 | 11 | func recurseErrPkgErrors(n int) error { 12 | if n == 0 { 13 | return errors.New("error") 14 | } 15 | 16 | return recurseErrPkgErrors(n - 1) 17 | } 18 | 19 | func BenchmarkPkgErrors(b *testing.B) { 20 | var err error 21 | for i := 0; i < b.N; i++ { 22 | err = recurseErrPkgErrors(10) 23 | } 24 | 25 | if wantMin, got := 10, strings.Count(fmt.Sprintf("%+v", err), "pkg_errors_test.go"); got < wantMin { 26 | b.Fatalf("missing expected stack frames, expected >%v, got %v", wantMin, got) 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /benchext/pkgerrors_example_http_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "net" 7 | "net/http" 8 | "strings" 9 | 10 | "braces.dev/errtrace/internal/tracetest" 11 | "github.com/pkg/errors" 12 | ) 13 | 14 | func Example_http() { 15 | tp := &http.Transport{Dial: rateLimitDialer} 16 | client := &http.Client{Transport: tp} 17 | ps := &PackageStore{ 18 | client: client, 19 | } 20 | 21 | _, err := ps.Get() 22 | 23 | // Unwrap the HTTP-wrapped error, so we can print a proper stacktrace. 
24 | var stErr interface { 25 | error 26 | StackTrace() errors.StackTrace 27 | } 28 | if errors.As(err, &stErr) { 29 | err = stErr 30 | } 31 | 32 | fmt.Printf("Error fetching packages: %s\n", cleanGoRoot(tracetest.MustClean(fmt.Sprintf("%+v", err)))) 33 | // Output: 34 | // Error fetching packages: connect rate limited 35 | // braces.dev/errtrace/benchext.rateLimitDialer 36 | // /path/to/errtrace/benchext/pkgerrors_example_http_test.go:1 37 | // net/http.(*Transport).dial 38 | // /goroot/src/net/http/transport.go:0 39 | // net/http.(*Transport).dialConn 40 | // /goroot/src/net/http/transport.go:0 41 | // net/http.(*Transport).dialConnFor 42 | // /goroot/src/net/http/transport.go:0 43 | // runtime.goexit 44 | // /goroot/src/runtime/asm_amd64.s:0 45 | } 46 | 47 | type PackageStore struct { 48 | client *http.Client 49 | packagesCached []string 50 | } 51 | 52 | func (ps *PackageStore) Get() ([]string, error) { 53 | if ps.packagesCached != nil { 54 | return ps.packagesCached, nil 55 | } 56 | 57 | packages, err := ps.updateIndex() 58 | if err != nil { 59 | return nil, err 60 | } 61 | 62 | ps.packagesCached = packages 63 | return packages, nil 64 | } 65 | 66 | func (ps *PackageStore) updateIndex() ([]string, error) { 67 | resp, err := ps.client.Get("http://example.com/packages.index") 68 | if err != nil { 69 | return nil, err 70 | } 71 | 72 | contents, err := io.ReadAll(resp.Body) 73 | if err != nil { 74 | return nil, err 75 | } 76 | 77 | return strings.Split(string(contents), ","), nil 78 | } 79 | 80 | func rateLimitDialer(network, addr string) (net.Conn, error) { 81 | // for testing, always return an error. 
82 | return nil, errors.Errorf("connect rate limited") 83 | } 84 | -------------------------------------------------------------------------------- /benchext/utils_for_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "regexp" 5 | "strings" 6 | ) 7 | 8 | // cleanGoRoot is similar to tracetest, but deals with GOROOT paths. 9 | // It replaces paths+line numbers for fixed values so we can use the 10 | // output in an example test. 11 | func cleanGoRoot(s string) string { 12 | gorootPath := regexp.MustCompile("/.*/src/") 13 | s = gorootPath.ReplaceAllString(s, "/goroot/src/") 14 | 15 | fileLine := regexp.MustCompile(`/goroot/.*:[0-9]+`) 16 | return fileLine.ReplaceAllStringFunc(s, func(path string) string { 17 | file, _, ok := strings.Cut(path, ":") 18 | if !ok { 19 | return path 20 | } 21 | 22 | return file + ":0" 23 | }) 24 | } 25 | -------------------------------------------------------------------------------- /cmd/errtrace/.gitignore: -------------------------------------------------------------------------------- 1 | /errtrace 2 | -------------------------------------------------------------------------------- /cmd/errtrace/main.go: -------------------------------------------------------------------------------- 1 | // errtrace instruments Go code with error return tracing. 2 | // 3 | // # Installation 4 | // 5 | // Install errtrace with: 6 | // 7 | // go install braces.dev/errtrace/cmd/errtrace@latest 8 | // 9 | // # Usage 10 | // 11 | // errtrace [options] 12 | // 13 | // This will transform source files and write them to the standard output. 14 | // 15 | // If instead of source files, Go package patterns are given, 16 | // errtrace will transform all the files that match those patterns. 17 | // For example, 'errtrace ./...' will transform all files in the current 18 | // package and all subpackages. 
19 | // 20 | // Use the following flags to control the output: 21 | // 22 | // -format 23 | // whether to format ouput; one of: [auto, always, never]. 24 | // auto is the default and will format if the output is being written to a file. 25 | // -w write result to the given source files instead of stdout. 26 | // -l list files that would be modified without making any changes. 27 | package main 28 | 29 | // TODO 30 | // - -toolexec: run as a tool executor, fit for use with 'go build -toolexec' 31 | 32 | import ( 33 | "bytes" 34 | "encoding/json" 35 | "errors" 36 | "flag" 37 | "fmt" 38 | "go/ast" 39 | gofmt "go/format" 40 | "go/parser" 41 | "go/token" 42 | "io" 43 | "log" 44 | "os" 45 | "os/exec" 46 | "path/filepath" 47 | "regexp" 48 | "sort" 49 | "strconv" 50 | "strings" 51 | 52 | "braces.dev/errtrace" 53 | ) 54 | 55 | func main() { 56 | cmd := &mainCmd{ 57 | Stdin: os.Stdin, 58 | Stderr: os.Stderr, 59 | Stdout: os.Stdout, 60 | Getenv: os.Getenv, 61 | } 62 | 63 | os.Exit(cmd.Run(os.Args[1:])) 64 | } 65 | 66 | type mainParams struct { 67 | Write bool // -w 68 | List bool // -l 69 | Format format // -format 70 | NoWrapN bool // -no-wrapn 71 | Patterns []string // list of files to process 72 | 73 | ImplicitStdin bool // whether stdin was picked because there were no args 74 | } 75 | 76 | func (p *mainParams) shouldFormat() bool { 77 | switch p.Format { 78 | case formatAuto: 79 | return p.Write 80 | case formatAlways: 81 | return true 82 | case formatNever: 83 | return false 84 | default: 85 | panic(fmt.Sprintf("unknown format %q", p.Format)) 86 | } 87 | } 88 | 89 | func (p *mainParams) Parse(w io.Writer, args []string) error { 90 | flag := flag.NewFlagSet("errtrace", flag.ContinueOnError) 91 | flag.SetOutput(w) 92 | flag.Usage = func() { 93 | logln(w, "usage: errtrace [options] ") 94 | flag.PrintDefaults() 95 | } 96 | 97 | flag.Var(&p.Format, "format", "whether to format ouput; one of: [auto, always, never].\n"+ 98 | "auto is the default and will format if the output 
is being written to a file.") 99 | flag.BoolVar(&p.Write, "w", false, 100 | "write result to the given source files instead of stdout.") 101 | flag.BoolVar(&p.List, "l", false, 102 | "list files that would be modified without making any changes.") 103 | flag.BoolVar(&p.NoWrapN, "no-wrapn", false, 104 | "wrap multiple return values without using errtrace.WrapN", 105 | ) 106 | 107 | if err := flag.Parse(args); err != nil { 108 | return errtrace.Wrap(err) 109 | } 110 | 111 | p.Patterns = flag.Args() 112 | if len(p.Patterns) == 0 { 113 | // Read file from stdin when there's no args, similar to gofmt. 114 | p.Patterns = []string{"-"} 115 | p.ImplicitStdin = true 116 | } 117 | 118 | return nil 119 | } 120 | 121 | // format specifies whether the output should be gofmt'd. 122 | type format int 123 | 124 | var _ flag.Getter = (*format)(nil) 125 | 126 | const ( 127 | // formatAuto formats the output 128 | // if it's being written to a file 129 | // but not if it's being written to stdout. 130 | // 131 | // This is the default. 132 | formatAuto format = iota 133 | 134 | // formatAlways always formats the output. 135 | formatAlways 136 | 137 | // formatNever never formats the output. 138 | formatNever 139 | ) 140 | 141 | func (f *format) Get() interface{} { 142 | return *f 143 | } 144 | 145 | // IsBoolFlag tells the flag package that plain "-format" is a valid flag. 146 | // When "-format" is used without a value, 147 | // the flag package will call Set("true") on the flag. 
148 | func (f *format) IsBoolFlag() bool { 149 | return true 150 | } 151 | 152 | func (f *format) Set(s string) error { 153 | switch s { 154 | case "auto": 155 | *f = formatAuto 156 | case "always", "true": // "true" comes from "-format" without a value 157 | *f = formatAlways 158 | case "never": 159 | *f = formatNever 160 | default: 161 | return errtrace.Wrap(fmt.Errorf("invalid format %q is not one of [auto, always, never]", s)) 162 | } 163 | return nil 164 | } 165 | 166 | func (f *format) String() string { 167 | switch *f { 168 | case formatAuto: 169 | return "auto" 170 | case formatAlways: 171 | return "always" 172 | case formatNever: 173 | return "never" 174 | default: 175 | return fmt.Sprintf("format(%d)", *f) 176 | } 177 | } 178 | 179 | type mainCmd struct { 180 | Stdin io.Reader 181 | Stdout io.Writer 182 | Stderr io.Writer 183 | Getenv func(string) string 184 | 185 | log *log.Logger 186 | } 187 | 188 | func (cmd *mainCmd) Run(args []string) (exitCode int) { 189 | cmd.log = log.New(cmd.Stderr, "", 0) 190 | 191 | if exitCode, ok := cmd.handleToolExec(args); ok { 192 | return exitCode 193 | } 194 | 195 | var p mainParams 196 | if err := p.Parse(cmd.Stderr, args); err != nil { 197 | if errors.Is(err, flag.ErrHelp) { 198 | return 0 199 | } 200 | cmd.log.Printf("errtrace: %+v", err) 201 | return 1 202 | } 203 | 204 | files, err := expandPatterns(p.Patterns) 205 | if err != nil { 206 | cmd.log.Printf("errtrace: %+v", err) 207 | return 1 208 | } 209 | 210 | // Paths will be printed relative to CWD. 211 | // Paths outside it will be printed as-is. 212 | var workDir string 213 | if wd, err := os.Getwd(); err == nil { 214 | workDir = wd + string(filepath.Separator) 215 | } 216 | 217 | for _, file := range files { 218 | display := file 219 | if workDir != "" { 220 | // Not using filepath.Rel 221 | // because we don't want any ".."s in the path. 
222 | display = strings.TrimPrefix(file, workDir) 223 | } 224 | if display == "-" { 225 | display = "stdin" 226 | } 227 | 228 | req := fileRequest{ 229 | Format: p.shouldFormat(), 230 | Write: p.Write, 231 | List: p.List, 232 | Filename: display, 233 | Filepath: file, 234 | ImplicitStdin: p.ImplicitStdin, 235 | RewriteOpts: rewriteOpts{ 236 | NoWrapN: p.NoWrapN, 237 | }, 238 | } 239 | if err := cmd.processFile(req); err != nil { 240 | cmd.log.Printf("%s:%+v", display, err) 241 | exitCode = 1 242 | } 243 | } 244 | 245 | return exitCode 246 | } 247 | 248 | // expandPatterns turns the given list of patterns and files 249 | // into a list of paths to files. 250 | // 251 | // Arguments that are already files are returned as-is. 252 | // Arguments that are patterns are expanded using 'go list'. 253 | // As a special case for stdin, "-" is returned as-is. 254 | func expandPatterns(args []string) ([]string, error) { 255 | var files, patterns []string 256 | for _, arg := range args { 257 | if arg == "-" { 258 | files = append(files, arg) 259 | continue 260 | } 261 | 262 | if info, err := os.Stat(arg); err == nil && !info.IsDir() { 263 | files = append(files, arg) 264 | continue 265 | } 266 | 267 | patterns = append(patterns, arg) 268 | } 269 | 270 | if len(patterns) > 0 { 271 | pkgFiles, err := goListFiles(patterns) 272 | if err != nil { 273 | return nil, errtrace.Wrap(fmt.Errorf("go list: %w", err)) 274 | } 275 | 276 | files = append(files, pkgFiles...) 277 | } 278 | 279 | return files, nil 280 | } 281 | 282 | var _execCommand = exec.Command 283 | 284 | func goListFiles(patterns []string) (files []string, err error) { 285 | // The -e flag makes 'go list' include erroneous packages. 286 | // This will even include packages that have all files excluded 287 | // by build constraints if explicitly requested. 288 | // (with "path/to/pkg" instead of "./...") 289 | args := []string{"list", "-find", "-e", "-json"} 290 | args = append(args, patterns...) 
291 | 292 | var stderr bytes.Buffer 293 | cmd := _execCommand("go", args...) 294 | cmd.Stderr = &stderr 295 | 296 | stdout, err := cmd.StdoutPipe() 297 | if err != nil { 298 | return nil, errtrace.Wrap(fmt.Errorf("create stdout pipe: %w", err)) 299 | } 300 | 301 | if err := cmd.Start(); err != nil { 302 | return nil, errtrace.Wrap(fmt.Errorf("start command: %w", err)) 303 | } 304 | 305 | type packageInfo struct { 306 | Dir string 307 | GoFiles []string 308 | CgoFiles []string 309 | TestGoFiles []string 310 | XTestGoFiles []string 311 | IgnoredGoFiles []string 312 | } 313 | 314 | decoder := json.NewDecoder(stdout) 315 | for decoder.More() { 316 | var pkg packageInfo 317 | if err := decoder.Decode(&pkg); err != nil { 318 | return nil, errtrace.Wrap(fmt.Errorf("output malformed: %w", err)) 319 | } 320 | 321 | for _, pkgFiles := range [][]string{ 322 | pkg.GoFiles, 323 | pkg.CgoFiles, 324 | pkg.TestGoFiles, 325 | pkg.XTestGoFiles, 326 | pkg.IgnoredGoFiles, 327 | } { 328 | for _, f := range pkgFiles { 329 | files = append(files, filepath.Join(pkg.Dir, f)) 330 | } 331 | } 332 | } 333 | 334 | if err := cmd.Wait(); err != nil { 335 | return nil, errtrace.Wrap(fmt.Errorf("%w\n%s", err, stderr.String())) 336 | } 337 | 338 | return files, nil 339 | } 340 | 341 | type fileRequest struct { 342 | Format bool 343 | Write bool 344 | List bool 345 | RewriteOpts rewriteOpts 346 | 347 | Filename string // name displayed to the user 348 | Filepath string // actual location on disk, or "-" for stdin 349 | 350 | ImplicitStdin bool 351 | } 352 | 353 | type rewriteOpts struct { 354 | NoWrapN bool 355 | } 356 | 357 | // processFile processes a single file. 358 | // This operates in two phases: 359 | // 360 | // First, it walks the AST to find all the places that need to be modified, 361 | // extracting other information as needed. 362 | // 363 | // The collected information is used to pick a package name, 364 | // whether we need an import, etc. and *then* the edits are applied. 
365 | func (cmd *mainCmd) processFile(r fileRequest) error { 366 | src, err := cmd.readFile(r) 367 | if err != nil { 368 | return errtrace.Wrap(err) 369 | } 370 | 371 | parsed, err := cmd.parseFile(r.Filename, src, r.RewriteOpts) 372 | if err != nil { 373 | return errtrace.Wrap(err) 374 | } 375 | 376 | for _, line := range parsed.unusedOptouts { 377 | cmd.log.Printf("%s:%d:unused errtrace:skip", r.Filename, line) 378 | } 379 | if r.List { 380 | if len(parsed.inserts) > 0 { 381 | _, err = fmt.Fprintf(cmd.Stdout, "%s\n", r.Filename) 382 | } 383 | return errtrace.Wrap(err) 384 | } 385 | 386 | var out bytes.Buffer 387 | if err := cmd.rewriteFile(parsed, &out); err != nil { 388 | return errtrace.Wrap(err) 389 | } 390 | 391 | outSrc := out.Bytes() 392 | if r.Format { 393 | outSrc, err = gofmt.Source(outSrc) 394 | if err != nil { 395 | return errtrace.Wrap(fmt.Errorf("format: %w", err)) 396 | } 397 | } 398 | 399 | if r.Write { 400 | err = os.WriteFile(r.Filename, outSrc, 0o644) 401 | } else { 402 | _, err = cmd.Stdout.Write(outSrc) 403 | } 404 | return errtrace.Wrap(err) 405 | } 406 | 407 | type parsedFile struct { 408 | src []byte 409 | fset *token.FileSet 410 | file *ast.File 411 | 412 | errtracePkg string 413 | importsErrtrace bool // includes blank imports 414 | inserts []insert 415 | unusedOptouts []int // list of line numbers 416 | } 417 | 418 | func (cmd *mainCmd) parseFile(filename string, src []byte, opts rewriteOpts) (parsedFile, error) { 419 | fset := token.NewFileSet() 420 | f, err := parser.ParseFile(fset, filename, src, parser.ParseComments) 421 | if err != nil { 422 | return parsedFile{}, errtrace.Wrap(err) 423 | } 424 | 425 | errtracePkg := "errtrace" // name to use for errtrace package 426 | var importsErrtrace bool // whether there's any errtrace import, including blank imports 427 | needErrtraceImport := true // whether to add a new import. 
428 | for _, imp := range f.Imports { 429 | if imp.Path.Value == `"braces.dev/errtrace"` { 430 | importsErrtrace = true 431 | if imp.Name != nil { 432 | if imp.Name.Name == "_" { 433 | // Can't use a blank import, keep processing imports. 434 | continue 435 | } 436 | // If the file already imports errtrace, 437 | // we'll want to use the name it's imported under. 438 | errtracePkg = imp.Name.Name 439 | } 440 | needErrtraceImport = false 441 | break 442 | } 443 | } 444 | 445 | if needErrtraceImport { 446 | // If the file doesn't import errtrace already, 447 | // do a quick check to find an unused identifier name. 448 | idents := make(map[string]struct{}) 449 | ast.Inspect(f, func(n ast.Node) bool { 450 | if ident, ok := n.(*ast.Ident); ok { 451 | idents[ident.Name] = struct{}{} 452 | } 453 | return true 454 | }) 455 | 456 | // Pick a name that isn't already used. 457 | // Prefer "errtrace" if it's available. 458 | for i := 1; ; i++ { 459 | candidate := errtracePkg 460 | if i > 1 { 461 | candidate += strconv.Itoa(i) 462 | } 463 | 464 | if _, ok := idents[candidate]; !ok { 465 | errtracePkg = candidate 466 | break 467 | } 468 | } 469 | } 470 | 471 | var inserts []insert 472 | w := walker{ 473 | fset: fset, 474 | optouts: optoutLines(fset, f.Comments), 475 | errtracePkg: errtracePkg, 476 | logger: cmd.log, 477 | inserts: &inserts, 478 | opts: opts, 479 | } 480 | ast.Walk(&w, f) 481 | 482 | // Look for unused optouts and warn about them. 483 | var unusedOptouts []int 484 | if len(w.optouts) > 0 { 485 | unusedOptouts = make([]int, 0, len(w.optouts)) 486 | for line, used := range w.optouts { 487 | if used == 0 { 488 | unusedOptouts = append(unusedOptouts, line) 489 | } 490 | } 491 | sort.Ints(unusedOptouts) 492 | } 493 | 494 | // If errtrace isn't imported, but at least one insert was made, 495 | // we'll need to import errtrace. 496 | // Add an import declaration to the file. 
497 | if needErrtraceImport && len(inserts) > 0 { 498 | // We want to insert the import after the last existing import. 499 | // If the last import is part of a group, we'll make it part of the group. 500 | // 501 | // import ( 502 | // "foo" 503 | // ) 504 | // // becomes 505 | // import ( 506 | // "foo"; "brace.dev/errtrace" 507 | // ) 508 | // 509 | // Otherwise, we'll add a new import statement group. 510 | // 511 | // import "foo" 512 | // // becomes 513 | // import "foo"; import "brace.dev/errtrace" 514 | var ( 515 | lastImportSpec *ast.ImportSpec 516 | lastImportDecl *ast.GenDecl 517 | ) 518 | for _, imp := range f.Decls { 519 | decl, ok := imp.(*ast.GenDecl) 520 | if !ok || decl.Tok != token.IMPORT { 521 | break 522 | } 523 | lastImportDecl = decl 524 | if decl.Lparen.IsValid() && len(decl.Specs) > 0 { 525 | // There's an import group. 526 | lastImportSpec, _ = decl.Specs[len(decl.Specs)-1].(*ast.ImportSpec) 527 | } 528 | } 529 | 530 | var i insertImportErrtrace 531 | switch { 532 | case lastImportSpec != nil: 533 | // import ("foo") 534 | i.At = lastImportSpec.End() 535 | case lastImportDecl != nil: 536 | // import "foo" 537 | i.At = lastImportDecl.End() 538 | i.AddKeyword = true 539 | default: 540 | // package foo 541 | i.At = f.Name.End() 542 | i.AddKeyword = true 543 | } 544 | inserts = append(inserts, &i) 545 | } 546 | 547 | sort.Slice(inserts, func(i, j int) bool { 548 | return inserts[i].Pos() < inserts[j].Pos() 549 | }) 550 | 551 | return parsedFile{ 552 | src: src, 553 | fset: fset, 554 | file: f, 555 | errtracePkg: errtracePkg, 556 | importsErrtrace: importsErrtrace, 557 | inserts: inserts, 558 | unusedOptouts: unusedOptouts, 559 | }, nil 560 | } 561 | 562 | func (cmd *mainCmd) rewriteFile(f parsedFile, out *bytes.Buffer) error { 563 | var lastOffset int 564 | filePos := f.fset.File(f.file.Pos()) // position information for this file 565 | for _, it := range f.inserts { 566 | offset := filePos.Offset(it.Pos()) 567 | _, _ = 
out.Write(f.src[lastOffset:offset]) 568 | lastOffset = offset 569 | 570 | switch it := it.(type) { 571 | case *insertImportErrtrace: 572 | _, _ = io.WriteString(out, "; ") 573 | if it.AddKeyword { 574 | _, _ = io.WriteString(out, "import ") 575 | } 576 | 577 | if f.errtracePkg == "errtrace" { 578 | // Don't use named imports if we're using the default name. 579 | fmt.Fprintf(out, "%q", "braces.dev/errtrace") 580 | } else { 581 | fmt.Fprintf(out, "%s %q", f.errtracePkg, "braces.dev/errtrace") 582 | } 583 | 584 | case *insertWrapOpen: 585 | fmt.Fprintf(out, "%s.Wrap", f.errtracePkg) 586 | if it.N > 1 { 587 | fmt.Fprintf(out, "%d", it.N) 588 | } 589 | _, _ = out.WriteString("(") 590 | 591 | case *insertWrapClose: 592 | _, _ = out.WriteString(")") 593 | 594 | case *insertReturnNBlockStart: 595 | vars := nVars("r", it.N) 596 | fmt.Fprintf(out, "{ %s := ", strings.Join(vars, ", ")) 597 | 598 | // Update last offset, so we skip writing the "return", as it's 599 | // followed by the expression we want to assign to. 600 | // The "return" is added in insertReturnNBlockClose. 601 | lastOffset = filePos.Offset(it.SkipReturn) 602 | 603 | case *insertReturnNBlockClose: 604 | vars := nVars("r", it.N) // must match insertReturnNBlockStart 605 | 606 | // Last return is an error, wrap it. 
607 | last := &vars[len(vars)-1] 608 | *last = fmt.Sprintf("%s.Wrap(%v)", f.errtracePkg, *last) 609 | 610 | fmt.Fprintf(out, "; return %s }", strings.Join(vars, ", ")) 611 | 612 | case *insertWrapAssign: 613 | // Turns this: 614 | // return 615 | // Into this: 616 | // x, y = errtrace.Wrap(x), errtrace.Wrap(y); return 617 | for i, name := range it.Names { 618 | if i > 0 { 619 | _, _ = out.WriteString(", ") 620 | } 621 | fmt.Fprintf(out, "%s", name) 622 | } 623 | _, _ = out.WriteString(" = ") 624 | for i, name := range it.Names { 625 | if i > 0 { 626 | _, _ = out.WriteString(", ") 627 | } 628 | fmt.Fprintf(out, "%s.Wrap(%s)", f.errtracePkg, name) 629 | } 630 | _, _ = out.WriteString("; ") 631 | 632 | default: 633 | cmd.log.Panicf("unhandled insertion type %T", it) 634 | } 635 | } 636 | _, _ = out.Write(f.src[lastOffset:]) // flush remaining 637 | return nil 638 | } 639 | 640 | func (cmd *mainCmd) readFile(r fileRequest) ([]byte, error) { 641 | if r.Filepath != "-" { 642 | return errtrace.Wrap2(os.ReadFile(r.Filename)) 643 | } 644 | 645 | if r.Write { 646 | return nil, errtrace.Wrap(fmt.Errorf("can't use -w with stdin")) 647 | } 648 | 649 | if r.ImplicitStdin { 650 | // Running with no args reads from stdin, but this is not obvious 651 | // so print a usage hint to stderr, if we think stdin is a TTY. 652 | // Best-effort check for a TTY by looking for a character device. 
653 | type statter interface { 654 | Stat() (os.FileInfo, error) 655 | } 656 | if st, ok := cmd.Stdin.(statter); ok { 657 | if fi, err := st.Stat(); err == nil && 658 | fi.Mode()&os.ModeCharDevice == os.ModeCharDevice { 659 | cmd.log.Println("reading from stdin; use '-h' for help") 660 | } 661 | } 662 | } 663 | 664 | return errtrace.Wrap2(io.ReadAll(cmd.Stdin)) 665 | } 666 | 667 | type walker struct { 668 | // Inputs 669 | 670 | fset *token.FileSet // file set for positional information 671 | errtracePkg string // name of the errtrace package 672 | logger *log.Logger 673 | opts rewriteOpts 674 | 675 | optouts map[int]int // map from line to number of uses 676 | 677 | // Outputs 678 | 679 | // inserts is the list of inserts to make. 680 | inserts *[]insert 681 | 682 | // State 683 | 684 | // Function information: 685 | 686 | numReturns int // number of return values 687 | errorIdents []*ast.Ident // identifiers for error return values (only if unnamed returns) 688 | errorObjs map[*ast.Object]struct{} // objects for error return values (only if named returns) 689 | errorIndices []int // indices of error return values 690 | 691 | // Block information: 692 | 693 | // Errors that are wrapped in this block. 694 | alreadyWrapped map[*ast.Object]struct{} 695 | // The logic to detect re-wraps is pretty simplistic 696 | // since it doesn't do any control flow analysis. 697 | // If this becomes a necessity, we can add it later. 698 | } 699 | 700 | var _ ast.Visitor = (*walker)(nil) 701 | 702 | func (t *walker) logf(pos token.Pos, format string, args ...interface{}) { 703 | msg := fmt.Sprintf(format, args...) 
704 | t.logger.Printf("%s:%s", t.fset.Position(pos), msg) 705 | } 706 | 707 | func (t *walker) Visit(n ast.Node) ast.Visitor { 708 | switch n := n.(type) { 709 | case *ast.FuncDecl: 710 | return t.funcType(n, n.Type) 711 | 712 | case *ast.BlockStmt: 713 | newT := *t 714 | newT.alreadyWrapped = make(map[*ast.Object]struct{}) 715 | return &newT 716 | 717 | case *ast.AssignStmt: 718 | t.assignStmt(n) 719 | 720 | case *ast.DeferStmt: 721 | // This is a bit inefficient; 722 | // we'll recurse into the DeferStmt's function literal (if any) twice. 723 | t.deferStmt(n) 724 | 725 | case *ast.FuncLit: 726 | return t.funcType(n, n.Type) 727 | 728 | case *ast.ReturnStmt: 729 | return t.returnStmt(n) 730 | } 731 | 732 | return t 733 | } 734 | 735 | func (t *walker) funcType(parent ast.Node, ft *ast.FuncType) ast.Visitor { 736 | // Clear state in case we're recursing into a function literal 737 | // inside a function that returns an error. 738 | newT := *t 739 | newT.errorObjs = nil 740 | newT.errorIdents = nil 741 | newT.errorIndices = nil 742 | newT.numReturns = 0 743 | t = &newT 744 | 745 | // If the function does not return anything, 746 | // we still need to recurse into any function literals. 747 | // Just return this visitor to continue recursing. 
748 | if ft.Results == nil { 749 | return t 750 | } 751 | 752 | // If the function has return values, 753 | // we need to consider the following cases: 754 | // 755 | // - no error return value 756 | // - unnamed error return 757 | // - named error return 758 | var ( 759 | objs []*ast.Object // objects of error return values 760 | idents []*ast.Ident // identifiers of named error return values 761 | indices []int // indices of error return values 762 | count int // total number of return values 763 | // Invariants: 764 | // len(indices) <= count 765 | // len(names) == 0 || len(names) == len(indices) 766 | ) 767 | for _, field := range ft.Results.List { 768 | isError := isIdent(field.Type, "error") 769 | 770 | // field.Names is nil for unnamed return values. 771 | // Either all returns are named or none are. 772 | if len(field.Names) > 0 { 773 | for _, name := range field.Names { 774 | if isError { 775 | objs = append(objs, name.Obj) 776 | idents = append(idents, name) 777 | indices = append(indices, count) 778 | } 779 | count++ 780 | } 781 | } else { 782 | if isError { 783 | indices = append(indices, count) 784 | } 785 | count++ 786 | } 787 | } 788 | 789 | // If there are no error return values, 790 | // recurse to look for function literals. 791 | if len(indices) == 0 { 792 | return t 793 | } 794 | 795 | // If there's a single error return, 796 | // and this function is a method named "Unwrap", 797 | // don't wrap it so it plays nice with errors.Unwrap. 798 | if len(indices) == 1 { 799 | if decl, ok := parent.(*ast.FuncDecl); ok { 800 | if decl.Recv != nil && isIdent(decl.Name, "Unwrap") { 801 | return t 802 | } 803 | } 804 | } 805 | 806 | newT.errorObjs = setOf(objs) 807 | newT.errorIdents = idents 808 | newT.errorIndices = indices 809 | newT.numReturns = count 810 | return &newT 811 | } 812 | 813 | func (t *walker) returnStmt(n *ast.ReturnStmt) ast.Visitor { 814 | // Doesn't return errors. Continue recursing. 
815 | if len(t.errorIndices) == 0 { 816 | return t 817 | } 818 | 819 | // Naked return. 820 | // We want to add assignments to the named return values. 821 | if n.Results == nil { 822 | if t.optout(n.Pos()) { 823 | return nil 824 | } 825 | 826 | // Ignore errors that have already been wrapped. 827 | names := make([]string, 0, len(t.errorIndices)) 828 | for _, ident := range t.errorIdents { 829 | if _, ok := t.alreadyWrapped[ident.Obj]; ok { 830 | continue 831 | } 832 | names = append(names, ident.Name) 833 | } 834 | 835 | if len(names) > 0 { 836 | *t.inserts = append(*t.inserts, &insertWrapAssign{ 837 | Names: names, 838 | Before: n.Pos(), 839 | }) 840 | } 841 | 842 | return nil 843 | } 844 | 845 | // Return with multiple return values being automatically expanded 846 | // E.g., 847 | // func foo() (int, error) { 848 | // return bar() 849 | // } 850 | // This needs to become: 851 | // func foo() (int, error) { 852 | // return Wrap2(bar()) 853 | // } 854 | // This is only supported if numReturns <= 6 and only the last return value is an error. 855 | if len(n.Results) == 1 && t.numReturns > 1 { 856 | if _, ok := n.Results[0].(*ast.CallExpr); !ok { 857 | t.logf(n.Pos(), "skipping function with incorrect number of return values: got %d, want %d", len(n.Results), t.numReturns) 858 | return t 859 | } 860 | 861 | t.wrapReturnCall(t.numReturns, n) 862 | return t 863 | } 864 | 865 | for _, idx := range t.errorIndices { 866 | t.wrapExpr(1, n.Results[idx]) 867 | } 868 | 869 | return t 870 | } 871 | 872 | func (t *walker) assignStmt(n *ast.AssignStmt) { 873 | // Record assignments to named error return values. 874 | // We'll use this to detect re-wraps. 
875 | for i, lhs := range n.Lhs { 876 | ident, ok := lhs.(*ast.Ident) 877 | if !ok { 878 | continue // not an identifier 879 | } 880 | 881 | _, ok = t.errorObjs[ident.Obj] 882 | if !ok { 883 | continue // not an error assignment 884 | } 885 | 886 | if i < len(n.Rhs) && t.isErrtraceWrap(n.Rhs[i]) { 887 | // Assigning to a named error return value. 888 | t.alreadyWrapped[ident.Obj] = struct{}{} 889 | } 890 | } 891 | } 892 | 893 | func (t *walker) deferStmt(n *ast.DeferStmt) { 894 | // If there's a defer statement with a function literal, 895 | // *and* this function has named return values, 896 | // we'll want to watch for assignments to those return values. 897 | 898 | if len(t.errorIdents) == 0 { 899 | return // no named returns 900 | } 901 | 902 | funcLit, ok := n.Call.Fun.(*ast.FuncLit) 903 | if !ok { 904 | return // not a function literal 905 | } 906 | 907 | ast.Inspect(funcLit.Body, func(n ast.Node) bool { 908 | assign, ok := n.(*ast.AssignStmt) 909 | if !ok { 910 | return true 911 | } 912 | for i, lhs := range assign.Lhs { 913 | ident, ok := lhs.(*ast.Ident) 914 | if !ok { 915 | continue // not an identifier 916 | } 917 | 918 | if _, ok := t.errorObjs[ident.Obj]; !ok { 919 | continue // not an error assignment 920 | } 921 | 922 | // Assignment to an error return value. 923 | // This will take one of the following forms: 924 | // 925 | // (1) x, y, err = f1(), f2(), f3() 926 | // (2) x, y, err = f() // returns multiple values 927 | // (3) x, err, z = f() // returns multiple values 928 | // 929 | // For (1), we can wrap just the function 930 | // that returns the error. (f3 in this case) 931 | // 932 | // For (2), we can use a WrapN function 933 | // to wrap the entire function call. 934 | // 935 | // For (3), we could use an inline function call, 936 | // but that's not implemented yet. 937 | 938 | if i < len(assign.Rhs) && len(assign.Lhs) == len(assign.Rhs) { 939 | // Case (1): 940 | // Wrap the function that returns the error. 
941 | t.wrapExpr(1, assign.Rhs[i]) 942 | } else if i == len(assign.Lhs)-1 && len(assign.Rhs) == 1 { 943 | // Case (2): 944 | // Wrap the entire function call. 945 | t.wrapExpr(len(assign.Lhs), assign.Rhs[0]) 946 | } else { 947 | t.logf(assign.Pos(), "skipping assignment: error is not the last return value") 948 | } 949 | } 950 | 951 | return true 952 | }) 953 | } 954 | 955 | func (t *walker) wrapReturnCall(n int, ret *ast.ReturnStmt) { 956 | // Common validation 957 | switch { 958 | case len(t.errorIndices) != 1: 959 | t.logf(ret.Pos(), "skipping function with multiple error returns") 960 | return 961 | case t.errorIndices[0] != t.numReturns-1: 962 | t.logf(ret.Pos(), "skipping function with non-final error return") 963 | return 964 | case t.isErrtraceWrap(ret.Results[0]): 965 | return 966 | case t.optout(ret.Pos()): 967 | return 968 | } 969 | 970 | if t.opts.NoWrapN { 971 | *t.inserts = append(*t.inserts, 972 | &insertReturnNBlockStart{N: n, Before: ret.Pos(), SkipReturn: ret.Results[0].Pos()}, 973 | &insertReturnNBlockClose{N: n, After: ret.End()}, 974 | ) 975 | return 976 | } 977 | 978 | if n > 6 { 979 | t.logf(ret.Pos(), "skipping function with too many return values") 980 | return 981 | } 982 | 983 | t.wrapExpr(n, ret.Results[0]) 984 | } 985 | 986 | func (t *walker) wrapExpr(n int, expr ast.Expr) { 987 | switch { 988 | case t.isErrtraceWrap(expr): 989 | return // already wrapped 990 | 991 | case isIdent(expr, "nil"): 992 | // Optimization: ignore if it's "nil". 993 | return 994 | } 995 | 996 | if t.optout(expr.Pos()) { 997 | return 998 | } 999 | 1000 | *t.inserts = append(*t.inserts, 1001 | &insertWrapOpen{N: n, Before: expr.Pos()}, 1002 | &insertWrapClose{After: expr.End()}, 1003 | ) 1004 | } 1005 | 1006 | // Detects if an expression is in the form errtrace.Wrap(e) or errtrace.Wrap{n}(e). 
1007 | func (t *walker) isErrtraceWrap(expr ast.Expr) bool { 1008 | call, ok := expr.(*ast.CallExpr) 1009 | if !ok { 1010 | return false 1011 | } 1012 | 1013 | // Ignore if it's already errtrace.Wrap(...). 1014 | sel, ok := call.Fun.(*ast.SelectorExpr) 1015 | if !ok { 1016 | return false 1017 | } 1018 | 1019 | if !isIdent(sel.X, t.errtracePkg) { 1020 | return false 1021 | } 1022 | 1023 | return strings.HasPrefix(sel.Sel.Name, "Wrap") || 1024 | sel.Sel.Name == "New" || 1025 | sel.Sel.Name == "Errorf" 1026 | } 1027 | 1028 | // optout reports whether the line at the given position 1029 | // is opted out of tracing, incrementing uses if so. 1030 | func (t *walker) optout(pos token.Pos) bool { 1031 | line := t.fset.Position(pos).Line 1032 | _, ok := t.optouts[line] 1033 | if ok { 1034 | t.optouts[line]++ 1035 | } 1036 | return ok 1037 | } 1038 | 1039 | // insert is a request to add something to the source code. 1040 | type insert interface { 1041 | Pos() token.Pos // position to insert at 1042 | String() string // description for debugging 1043 | } 1044 | 1045 | // insertImportErrtrace adds an import declaration to the file 1046 | // right after the given node. 1047 | type insertImportErrtrace struct { 1048 | AddKeyword bool // whether the "import" keyword should be added 1049 | At token.Pos // position to insert at 1050 | } 1051 | 1052 | func (e *insertImportErrtrace) Pos() token.Pos { 1053 | return e.At 1054 | } 1055 | 1056 | func (e *insertImportErrtrace) String() string { 1057 | if e.AddKeyword { 1058 | return "add import statement" 1059 | } 1060 | return "add import" 1061 | } 1062 | 1063 | // insertWrapOpen adds a errtrace.Wrap call before an expression. 1064 | // 1065 | // foo() -> errtrace.Wrap(foo() 1066 | // 1067 | // This needs a corresponding insertWrapClose to close the call. 1068 | type insertWrapOpen struct { 1069 | // N specifies the number of parameters the Wrap function takes. 1070 | // Defaults to 1. 
1071 | N int 1072 | 1073 | Before token.Pos // position to insert before 1074 | } 1075 | 1076 | func (e *insertWrapOpen) Pos() token.Pos { 1077 | return e.Before 1078 | } 1079 | 1080 | func (e *insertWrapOpen) String() string { 1081 | return "" 1082 | } 1083 | 1084 | // insertWrapClose closes a errtrace.Wrap call. 1085 | // 1086 | // foo() -> foo()) 1087 | // 1088 | // This needs a corresponding insertWrapOpen to open the call. 1089 | type insertWrapClose struct { 1090 | After token.Pos // position to insert after 1091 | } 1092 | 1093 | func (e *insertWrapClose) Pos() token.Pos { 1094 | return e.After 1095 | } 1096 | 1097 | func (e *insertWrapClose) String() string { 1098 | return "" 1099 | } 1100 | 1101 | type insertReturnNBlockStart struct { 1102 | N int // number of returns 1103 | Before token.Pos // position to insert before 1104 | SkipReturn token.Pos // skipped content, used to drop "return" 1105 | } 1106 | 1107 | func (i *insertReturnNBlockStart) Pos() token.Pos { 1108 | return i.Before 1109 | } 1110 | 1111 | func (i *insertReturnNBlockStart) String() string { 1112 | return "" 1113 | } 1114 | 1115 | type insertReturnNBlockClose struct { 1116 | N int // number of returns 1117 | After token.Pos // position to insert after 1118 | } 1119 | 1120 | func (i *insertReturnNBlockClose) Pos() token.Pos { 1121 | return i.After 1122 | } 1123 | 1124 | func (i *insertReturnNBlockClose) String() string { 1125 | return "" 1126 | } 1127 | 1128 | // insertWrapAssign wraps a variable in-place with an errtrace.Wrap call. 1129 | // This is used for naked returns in functions with named return values 1130 | // 1131 | // For example, it will turn this: 1132 | // 1133 | // func foo() (err error) { 1134 | // // ... 1135 | // return 1136 | // } 1137 | // 1138 | // Into this: 1139 | // 1140 | // func foo() (err error) { 1141 | // // ... 
1142 | // err = errtrace.Wrap(err); return 1143 | // } 1144 | type insertWrapAssign struct { 1145 | Names []string // names of variables to wrap 1146 | Before token.Pos // position to insert before 1147 | } 1148 | 1149 | func (e *insertWrapAssign) Pos() token.Pos { 1150 | return e.Before 1151 | } 1152 | 1153 | func (e *insertWrapAssign) String() string { 1154 | return fmt.Sprintf("assign errors before %v", e.Names) 1155 | } 1156 | 1157 | func isIdent(expr ast.Expr, name string) bool { 1158 | ident, ok := expr.(*ast.Ident) 1159 | return ok && ident.Name == name 1160 | } 1161 | 1162 | func setOf[T comparable](xs []T) map[T]struct{} { 1163 | if len(xs) == 0 { 1164 | return nil 1165 | } 1166 | 1167 | set := make(map[T]struct{}) 1168 | for _, x := range xs { 1169 | set[x] = struct{}{} 1170 | } 1171 | return set 1172 | } 1173 | 1174 | var _errtraceSkip = regexp.MustCompile(`(^| )//errtrace:skip($|[ \(])`) 1175 | 1176 | // optoutLines returns the line numbers 1177 | // that have a comment in the form: 1178 | // 1179 | // //errtrace:skip 1180 | // 1181 | // It may be followed by other text, e.g., 1182 | // 1183 | // //errtrace:skip // for reasons 1184 | func optoutLines( 1185 | fset *token.FileSet, 1186 | comments []*ast.CommentGroup, 1187 | ) map[int]int { 1188 | lines := make(map[int]int) 1189 | for _, cg := range comments { 1190 | if len(cg.List) > 1 { 1191 | // skip multiline comments which are full line comments, not tied to a return. 
1192 | continue 1193 | } 1194 | 1195 | c := cg.List[0] 1196 | if _errtraceSkip.MatchString(c.Text) { 1197 | lineNo := fset.Position(c.Pos()).Line 1198 | lines[lineNo] = 0 1199 | } 1200 | } 1201 | return lines 1202 | } 1203 | 1204 | func nVars(prefix string, n int) []string { 1205 | vars := make([]string, n) 1206 | for i := 0; i < n; i++ { 1207 | vars[i] = fmt.Sprintf("%s%d", prefix, i+1) 1208 | } 1209 | return vars 1210 | } 1211 | 1212 | func logln(w io.Writer, s string) { 1213 | // logging writes are best-effort 1214 | _, _ = fmt.Fprintln(w, s) 1215 | } 1216 | 1217 | func logf(w io.Writer, format string, a ...any) { 1218 | // logging writes are best-effort 1219 | _, _ = fmt.Fprintf(w, format, a...) 1220 | } 1221 | -------------------------------------------------------------------------------- /cmd/errtrace/main_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "flag" 6 | "fmt" 7 | "go/parser" 8 | "go/token" 9 | "io" 10 | "os" 11 | "os/exec" 12 | "path" 13 | "path/filepath" 14 | "reflect" 15 | "sort" 16 | "strconv" 17 | "strings" 18 | "sync" 19 | "testing" 20 | 21 | "braces.dev/errtrace" 22 | "braces.dev/errtrace/internal/diff" 23 | ) 24 | 25 | func TestErrHelp(t *testing.T) { 26 | exitCode := (&mainCmd{ 27 | Stdin: strings.NewReader(""), 28 | Stdout: testWriter{t}, 29 | Stderr: testWriter{t}, 30 | }).Run([]string{"-h"}) 31 | if want := 0; exitCode != want { 32 | t.Errorf("exit code = %d, want %d", exitCode, want) 33 | } 34 | } 35 | 36 | // TestGolden runs errtrace on all .go files inside testdata/golden, 37 | // and compares the output to the corresponding .golden file. 38 | // Files must match exactly. 
39 | // 40 | // If log messages are expected associated with specific lines, 41 | // they can be included in the source and the .golden file 42 | // in the format: 43 | // 44 | // foo() // want:"log message" 45 | // 46 | // The log message will be matched against the output of errtrace on stderr. 47 | // The string must be a valid Go string literal. 48 | func TestGolden(t *testing.T) { 49 | files, err := filepath.Glob("testdata/golden/*.go") 50 | if err != nil { 51 | t.Fatal(err) 52 | } 53 | 54 | for _, file := range files { 55 | name := strings.TrimSuffix(filepath.Base(file), ".go") 56 | t.Run(name, func(t *testing.T) { 57 | testGoldenFile(t, file) 58 | }) 59 | } 60 | } 61 | 62 | func testGoldenFile(t *testing.T, file string) { 63 | giveSrc, err := os.ReadFile(file) 64 | if err != nil { 65 | t.Fatal(err) 66 | } 67 | 68 | wantSrc, err := os.ReadFile(file + ".golden") 69 | if err != nil { 70 | t.Fatal("Bad test: missing .golden file:", err) 71 | } 72 | 73 | type runTests struct { 74 | noOptions bool 75 | optNoWrapN bool 76 | } 77 | run := runTests{true, true} // by default, run all tests. 78 | if strings.Contains(string(giveSrc), "@runIf options=") { 79 | run = runTests{noOptions: true} 80 | } 81 | if strings.Contains(string(giveSrc), "@runIf options=no-wrapn") { 82 | run = runTests{optNoWrapN: true} 83 | } 84 | 85 | if run.noOptions { 86 | t.Run("no options", func(t *testing.T) { 87 | testGoldenContents(t, nil /* additionalFlags */, file, giveSrc, wantSrc) 88 | }) 89 | } 90 | 91 | if run.optNoWrapN { 92 | t.Run("option no-wrapn", func(t *testing.T) { 93 | testGoldenContents(t, []string{"-no-wrapn"}, file, giveSrc, wantSrc) 94 | }) 95 | } 96 | } 97 | 98 | func testGoldenContents(t *testing.T, additionalFlags []string, file string, giveSrc, wantSrc []byte) { 99 | wantLogs, err := extractLogs(giveSrc) 100 | if err != nil { 101 | t.Fatal(err) 102 | } 103 | 104 | // Copy into a temporary directory so that we can run with -w. 
105 | srcPath := filepath.Join(t.TempDir(), "src.go") 106 | if err := os.WriteFile(srcPath, []byte(giveSrc), 0o600); err != nil { 107 | t.Fatal(err) 108 | } 109 | 110 | // If the source is expected to change, 111 | // also verify that running with -l lists the file. 112 | // Otherwise, verify that running with -l does not list the file. 113 | t.Run("list", func(t *testing.T) { 114 | var out bytes.Buffer 115 | exitCode := (&mainCmd{ 116 | Stdout: &out, 117 | Stderr: testWriter{t}, 118 | }).Run(append(additionalFlags, "-l", srcPath)) 119 | if want := 0; exitCode != want { 120 | t.Errorf("exit code = %d, want %d", exitCode, want) 121 | } 122 | 123 | if bytes.Equal(giveSrc, wantSrc) { 124 | if want, got := "", out.String(); got != want { 125 | t.Errorf("expected no output, got:\n%s", indent(got)) 126 | } 127 | } else { 128 | if want, got := srcPath+"\n", out.String(); got != want { 129 | t.Errorf("got:\n%s\nwant:\n%s\ndiff:\n%s", indent(got), indent(want), indent(diff.Lines(want, got))) 130 | } 131 | } 132 | }) 133 | 134 | var stdout, stderr bytes.Buffer 135 | defer func() { 136 | if t.Failed() { 137 | t.Logf("stdout:\n%s", indent(stdout.String())) 138 | t.Logf("stderr:\n%s", indent(stderr.String())) 139 | } 140 | }() 141 | 142 | exitCode := (&mainCmd{ 143 | Stdout: &stdout, // We don't care about stdout. 144 | Stderr: &stderr, 145 | }).Run(append(additionalFlags, "-format=never", "-w", srcPath)) 146 | 147 | if want := 0; exitCode != want { 148 | t.Errorf("exit code = %d, want %d", exitCode, want) 149 | } 150 | 151 | gotSrc, err := os.ReadFile(srcPath) 152 | if err != nil { 153 | t.Fatal(err) 154 | } 155 | 156 | if want, got := string(wantSrc), string(gotSrc); got != want { 157 | t.Errorf("want output:\n%s\ngot:\n%s\ndiff:\n%s", indent(want), indent(got), indent(diff.Lines(want, got))) 158 | } 159 | 160 | // Check that the log messages match. 
161 | gotLogs, err := parseLogOutput(srcPath, stderr.String()) 162 | if err != nil { 163 | t.Fatal(err) 164 | } 165 | 166 | if diff := diff.Diff(wantLogs, gotLogs); diff != "" { 167 | t.Errorf("log messages differ:\n%s", indent(diff)) 168 | } 169 | 170 | // Re-run on the output of the first run. 171 | // This should be a no-op. 172 | t.Run("idempotent", func(t *testing.T) { 173 | var got bytes.Buffer 174 | exitCode := (&mainCmd{ 175 | Stderr: testWriter{t}, 176 | Stdout: &got, 177 | }).Run([]string{srcPath}) 178 | 179 | if want := 0; exitCode != want { 180 | t.Errorf("exit code = %d, want %d", exitCode, want) 181 | } 182 | 183 | gotSrc := got.String() 184 | if want, got := string(wantSrc), gotSrc; got != want { 185 | t.Errorf("want output:\n%s\ngot:\n%s\ndiff:\n%s", indent(want), indent(got), indent(diff.Lines(want, got))) 186 | } 187 | }) 188 | 189 | // Create a Go package with the source file, 190 | // and run errtrace on the package. 191 | t.Run("package", func(t *testing.T) { 192 | dir := t.TempDir() 193 | 194 | file := filepath.Join(dir, filepath.Base(file)) 195 | if err := os.WriteFile(file, giveSrc, 0o600); err != nil { 196 | t.Fatal(err) 197 | } 198 | 199 | gomod := filepath.Join(dir, "go.mod") 200 | pkgdir := strings.TrimSuffix(filepath.Base(file), ".go") 201 | importPath := path.Join("example.com/test", pkgdir) 202 | if err := os.WriteFile(gomod, []byte(fmt.Sprintf("module %s\ngo 1.21\n", importPath)), 0o600); err != nil { 203 | t.Fatal(err) 204 | } 205 | 206 | restore := chdir(t, dir) 207 | var stderr bytes.Buffer 208 | exitCode := (&mainCmd{ 209 | Stderr: &stderr, 210 | Stdout: testWriter{t}, 211 | }).Run(append(additionalFlags, "-format=never", "-w", ".")) 212 | if want := 0; exitCode != want { 213 | t.Errorf("exit code = %d, want %d", exitCode, want) 214 | } 215 | restore() 216 | 217 | gotSrc, err := os.ReadFile(file) 218 | if err != nil { 219 | t.Fatal(err) 220 | } 221 | 222 | if want, got := string(wantSrc), string(gotSrc); got != want { 223 | 
t.Errorf("want output:\n%s\ngot:\n%s\ndiff:\n%s", indent(want), indent(got), indent(diff.Lines(want, got))) 224 | } 225 | 226 | // Check that the log messages match. 227 | gotLogs, err := parseLogOutput(file, stderr.String()) 228 | if err != nil { 229 | t.Fatal(err) 230 | } 231 | 232 | if diff := diff.Diff(wantLogs, gotLogs); diff != "" { 233 | t.Errorf("log messages differ:\n%s", indent(diff)) 234 | } 235 | }) 236 | } 237 | 238 | func TestParseMainParams(t *testing.T) { 239 | tests := []struct { 240 | name string 241 | give []string 242 | want mainParams 243 | wantErr []string // non-empty if we expect an error 244 | }{ 245 | { 246 | name: "stdin", 247 | want: mainParams{ 248 | Patterns: []string{"-"}, 249 | ImplicitStdin: true, 250 | }, 251 | }, 252 | } 253 | 254 | for _, tt := range tests { 255 | t.Run(tt.name, func(t *testing.T) { 256 | var got mainParams 257 | err := got.Parse(testWriter{t}, tt.give) 258 | 259 | if len(tt.wantErr) > 0 { 260 | if err == nil { 261 | t.Fatalf("expected error, got nil") 262 | } 263 | 264 | for _, want := range tt.wantErr { 265 | if got := err.Error(); !strings.Contains(got, want) { 266 | t.Errorf("error %q does not contain %q", got, want) 267 | } 268 | } 269 | 270 | return 271 | } 272 | 273 | if want, got := tt.want, got; !reflect.DeepEqual(want, got) { 274 | t.Errorf("got %v, want %v", got, want) 275 | } 276 | }) 277 | } 278 | } 279 | 280 | func TestParseFormatFlag(t *testing.T) { 281 | tests := []struct { 282 | name string 283 | give []string 284 | want format 285 | }{ 286 | { 287 | name: "default", 288 | want: formatAuto, 289 | }, 290 | { 291 | name: "auto explicit", 292 | give: []string{"-format=auto"}, 293 | want: formatAuto, 294 | }, 295 | { 296 | name: "always", 297 | give: []string{"-format=always"}, 298 | want: formatAlways, 299 | }, 300 | { 301 | name: "always explicit", 302 | give: []string{"-format"}, 303 | want: formatAlways, 304 | }, 305 | { 306 | name: "never", 307 | give: []string{"-format=never"}, 308 | want: 
formatNever, 309 | }, 310 | } 311 | 312 | for _, tt := range tests { 313 | t.Run(tt.name, func(t *testing.T) { 314 | flag := flag.NewFlagSet(t.Name(), flag.ContinueOnError) 315 | flag.SetOutput(testWriter{t}) 316 | 317 | var got format 318 | flag.Var(&got, "format", "") 319 | if err := flag.Parse(tt.give); err != nil { 320 | t.Fatal(err) 321 | } 322 | 323 | if want, got := tt.want, got; got != want { 324 | t.Errorf("got %v, want %v", got, want) 325 | } 326 | }) 327 | } 328 | } 329 | 330 | func TestFormatFlagError(t *testing.T) { 331 | var stderr bytes.Buffer 332 | exitCode := (&mainCmd{ 333 | Stderr: &stderr, 334 | Stdout: testWriter{t}, 335 | }).Run([]string{"-format=unknown"}) 336 | if want := 1; exitCode != want { 337 | t.Errorf("exit code = %d, want %d", exitCode, want) 338 | } 339 | 340 | if want, got := `invalid format "unknown"`, stderr.String(); !strings.Contains(got, want) { 341 | t.Errorf("stderr = %q, want %q", got, want) 342 | } 343 | } 344 | 345 | func TestFormatFlagString(t *testing.T) { 346 | tests := []struct { 347 | give format 348 | want string 349 | }{ 350 | {formatAuto, "auto"}, 351 | {formatAlways, "always"}, 352 | {formatNever, "never"}, 353 | {format(999), "format(999)"}, 354 | } 355 | 356 | for _, tt := range tests { 357 | t.Run(fmt.Sprintf("%d", tt.give), func(t *testing.T) { 358 | if want, got := tt.want, tt.give.String(); got != want { 359 | t.Errorf("got %q, want %q", got, want) 360 | } 361 | }) 362 | } 363 | } 364 | 365 | func TestShouldFormat(t *testing.T) { 366 | tests := []struct { 367 | name string 368 | give mainParams 369 | want bool 370 | }{ 371 | {"auto/no write", mainParams{Format: formatAuto}, false}, 372 | {"auto/write", mainParams{Format: formatAuto, Write: true}, true}, 373 | {"always", mainParams{Format: formatAlways}, true}, 374 | {"never", mainParams{Format: formatNever}, false}, 375 | } 376 | 377 | for _, tt := range tests { 378 | t.Run(tt.name, func(t *testing.T) { 379 | if want, got := tt.want, tt.give.shouldFormat(); 
got != want { 380 | t.Errorf("got %v, want %v", got, want) 381 | } 382 | }) 383 | } 384 | 385 | t.Run("unknown", func(t *testing.T) { 386 | defer func() { 387 | if err := recover(); err == nil { 388 | t.Fatal("no panic") 389 | } 390 | }() 391 | 392 | (&mainParams{Format: format(999)}).shouldFormat() 393 | }) 394 | } 395 | 396 | // -format=auto should format the file if used with -w, 397 | // and not format the file if used without -w. 398 | func TestFormatAuto(t *testing.T) { 399 | give := strings.Join([]string{ 400 | "package foo", 401 | `import "errors"`, 402 | "func foo() error {", 403 | ` return errors.New("foo")`, 404 | "}", 405 | }, "\n") 406 | 407 | wantUnformatted := strings.Join([]string{ 408 | "package foo", 409 | `import "errors"; import "braces.dev/errtrace"`, 410 | "func foo() error {", 411 | ` return errtrace.Wrap(errors.New("foo"))`, 412 | "}", 413 | }, "\n") 414 | 415 | wantFormatted := strings.Join([]string{ 416 | "package foo", 417 | "", 418 | `import "errors"`, 419 | `import "braces.dev/errtrace"`, 420 | "", 421 | "func foo() error {", 422 | ` return errtrace.Wrap(errors.New("foo"))`, 423 | "}", 424 | "", 425 | }, "\n") 426 | 427 | t.Run("write", func(t *testing.T) { 428 | srcPath := filepath.Join(t.TempDir(), "src.go") 429 | if err := os.WriteFile(srcPath, []byte(give), 0o600); err != nil { 430 | t.Fatal(err) 431 | } 432 | 433 | exitCode := (&mainCmd{ 434 | Stdout: testWriter{t}, 435 | Stderr: testWriter{t}, 436 | }).Run([]string{"-w", srcPath}) 437 | if want := 0; exitCode != want { 438 | t.Errorf("exit code = %d, want %d", exitCode, want) 439 | } 440 | 441 | bs, err := os.ReadFile(srcPath) 442 | if err != nil { 443 | t.Fatal(err) 444 | } 445 | 446 | if want, got := wantFormatted, string(bs); got != want { 447 | t.Errorf("got:\n%s\nwant:\n%s\ndiff:\n%s", indent(got), indent(want), indent(diff.Lines(want, got))) 448 | } 449 | }) 450 | 451 | t.Run("stdout", func(t *testing.T) { 452 | srcPath := filepath.Join(t.TempDir(), "src.go") 453 | if err := 
os.WriteFile(srcPath, []byte(give), 0o600); err != nil { 454 | t.Fatal(err) 455 | } 456 | 457 | var out bytes.Buffer 458 | exitCode := (&mainCmd{ 459 | Stdout: &out, 460 | Stderr: testWriter{t}, 461 | }).Run([]string{srcPath}) 462 | if want := 0; exitCode != want { 463 | t.Errorf("exit code = %d, want %d", exitCode, want) 464 | } 465 | 466 | if want, got := wantUnformatted, out.String(); got != want { 467 | t.Errorf("got:\n%s\nwant:\n%s\ndiff:\n%s", indent(got), indent(want), indent(diff.Lines(want, got))) 468 | } 469 | }) 470 | 471 | t.Run("stdin", func(t *testing.T) { 472 | var out bytes.Buffer 473 | exitCode := (&mainCmd{ 474 | Stdin: strings.NewReader(give), 475 | Stdout: &out, 476 | Stderr: testWriter{t}, 477 | }).Run(nil /* args */) // empty args implies stdin 478 | if want := 0; exitCode != want { 479 | t.Errorf("exit code = %d, want %d", exitCode, want) 480 | } 481 | if want, got := wantUnformatted, out.String(); want != got { 482 | t.Errorf("got:\n%s\nwant:\n%s\ndiff:\n%s", indent(got), indent(want), indent(diff.Lines(want, got))) 483 | } 484 | }) 485 | 486 | t.Run("stdin incompatible with write", func(t *testing.T) { 487 | var err, out bytes.Buffer 488 | exitCode := (&mainCmd{ 489 | Stdin: strings.NewReader("unused"), 490 | Stdout: &out, 491 | Stderr: &err, 492 | }).Run([]string{"-w", "-"}) 493 | if want := 1; exitCode != want { 494 | t.Errorf("exit code = %d, want %d", exitCode, want) 495 | } 496 | if want, got := "", out.String(); want != got { 497 | t.Errorf("stdout = %q, want %q", got, want) 498 | } 499 | if want, got := "stdin:can't use -w with stdin\n", err.String(); !strings.Contains(got, want) { 500 | t.Errorf("stderr = %q, does not contain %q", got, want) 501 | } 502 | if want, got := "(*mainCmd).readFile", err.String(); !strings.Contains(got, want) { 503 | t.Errorf("stderr = %q, does not contain %q", got, want) 504 | } 505 | }) 506 | } 507 | 508 | func TestListFlag(t *testing.T) { 509 | uninstrumentedSource := strings.Join([]string{ 510 | 
"package foo", 511 | `import "errors"`, 512 | "func foo() error {", 513 | ` return errors.New("foo")`, 514 | "}", 515 | }, "\n") 516 | 517 | instrumentedSource := strings.Join([]string{ 518 | "package foo", 519 | `import "errors"; import "braces.dev/errtrace"`, 520 | "func foo() error {", 521 | ` return errtrace.Wrap(errors.New("foo"))`, 522 | "}", 523 | }, "\n") 524 | 525 | dir := t.TempDir() 526 | 527 | instrumented := filepath.Join(dir, "instrumented.go") 528 | if err := os.WriteFile(instrumented, []byte(instrumentedSource), 0o600); err != nil { 529 | t.Fatal(err) 530 | } 531 | 532 | uninstrumented := filepath.Join(dir, "uninstrumented.go") 533 | if err := os.WriteFile(uninstrumented, []byte(uninstrumentedSource), 0o600); err != nil { 534 | t.Fatal(err) 535 | } 536 | 537 | var out bytes.Buffer 538 | exitCode := (&mainCmd{ 539 | Stdout: &out, 540 | Stderr: testWriter{t}, 541 | }).Run([]string{"-l", uninstrumented, instrumented}) 542 | if want := 0; exitCode != want { 543 | t.Errorf("exit code = %d, want %d", exitCode, want) 544 | } 545 | 546 | // Only the uninstrumented file should be listed. 
547 | if want, got := uninstrumented+"\n", out.String(); got != want { 548 | t.Errorf("got:\n%s\nwant:\n%s\ndiff:\n%s", indent(got), indent(want), indent(diff.Lines(want, got))) 549 | } 550 | } 551 | 552 | func TestOptoutLines(t *testing.T) { 553 | fset := token.NewFileSet() 554 | f, err := parser.ParseFile(fset, "", `package foo 555 | func _() { 556 | _ = "line 3" //errtrace:skip 557 | _ = "this line not counted" // errtrace:skip 558 | _ = "line 5" //errtrace:skip // has a reason 559 | _ = "line 6" //nolint:somelinter //errtrace:skip // stuff 560 | }`, parser.ParseComments) 561 | if err != nil { 562 | t.Fatal(err) 563 | } 564 | 565 | var got []int 566 | for line := range optoutLines(fset, f.Comments) { 567 | got = append(got, line) 568 | } 569 | sort.Ints(got) 570 | 571 | if want := []int{3, 5, 6}; !reflect.DeepEqual(want, got) { 572 | t.Errorf("got: %v\nwant: %v\ndiff:\n%s", got, want, diff.Diff(want, got)) 573 | } 574 | } 575 | 576 | func TestExpandPatterns(t *testing.T) { 577 | dir := t.TempDir() 578 | 579 | // Temporary directories on macOS are symlinked to /private/var/folders/... 
580 | dir, err := filepath.EvalSymlinks(dir) 581 | if err != nil { 582 | t.Fatal(err) 583 | } 584 | 585 | files := map[string]string{ 586 | "go.mod": "module example.com/foo\n", 587 | "top.go": "package foo\n", 588 | "top_test.go": "package foo\n", 589 | "sub/sub.go": "package sub\n", 590 | "sub/sub_test.go": "package sub\n", 591 | "sub/sub_ext_test.go": "package sub_test\n", 592 | "testdata/ignored_by_default.go": "package testdata\n", 593 | "tagged.go": "//go:build mytag\npackage foo\n", 594 | "tagged_test.go": "//go:build mytag\npackage foo\n", 595 | } 596 | 597 | for name, src := range files { 598 | dst := filepath.Join(dir, name) 599 | if err := os.MkdirAll(filepath.Dir(dst), 0o700); err != nil { 600 | t.Fatal(err) 601 | } 602 | 603 | if err := os.WriteFile(dst, []byte(src), 0o600); err != nil { 604 | t.Fatal(err) 605 | } 606 | } 607 | 608 | tests := []struct { 609 | name string 610 | args []string 611 | want []string 612 | }{ 613 | { 614 | name: "stdin", 615 | args: []string{"-"}, 616 | want: []string{"-"}, 617 | }, 618 | { 619 | name: "all", 620 | args: []string{"./..."}, 621 | want: []string{ 622 | "top.go", 623 | "top_test.go", 624 | "sub/sub.go", 625 | "sub/sub_test.go", 626 | "sub/sub_ext_test.go", 627 | "tagged.go", 628 | "tagged_test.go", 629 | }, 630 | }, 631 | { 632 | name: "relative subpackage", 633 | args: []string{"./sub"}, 634 | want: []string{ 635 | "sub/sub.go", 636 | "sub/sub_test.go", 637 | "sub/sub_ext_test.go", 638 | }, 639 | }, 640 | { 641 | name: "absolute subpackage", 642 | args: []string{"example.com/foo/sub/..."}, 643 | want: []string{ 644 | "sub/sub.go", 645 | "sub/sub_test.go", 646 | "sub/sub_ext_test.go", 647 | }, 648 | }, 649 | { 650 | name: "relative file", 651 | args: []string{"./sub/sub.go"}, 652 | want: []string{ 653 | "sub/sub.go", 654 | }, 655 | }, 656 | { 657 | name: "file and pattern", 658 | args: []string{ 659 | "testdata/ignored_by_default.go", 660 | "./sub/...", 661 | }, 662 | want: []string{ 663 | "sub/sub.go", 664 | 
"sub/sub_test.go", 665 | "sub/sub_ext_test.go", 666 | "testdata/ignored_by_default.go", 667 | }, 668 | }, 669 | { 670 | name: "file and pattern with tags", 671 | args: []string{"./...", "testdata/ignored_by_default.go"}, 672 | want: []string{ 673 | "top.go", 674 | "top_test.go", 675 | "sub/sub.go", 676 | "sub/sub_test.go", 677 | "sub/sub_ext_test.go", 678 | "tagged.go", 679 | "tagged_test.go", 680 | "testdata/ignored_by_default.go", 681 | }, 682 | }, 683 | } 684 | 685 | for _, tt := range tests { 686 | t.Run(tt.name, func(t *testing.T) { 687 | chdir(t, dir) 688 | 689 | got, err := expandPatterns(tt.args) 690 | if err != nil { 691 | t.Fatal(err) 692 | } 693 | 694 | for i, p := range got { 695 | if filepath.IsAbs(p) { 696 | p, err = filepath.Rel(dir, p) 697 | if err != nil { 698 | t.Fatal(err) 699 | } 700 | } 701 | 702 | // Normalize slashes for cross-platform tests. 703 | got[i] = path.Clean(filepath.ToSlash(p)) 704 | } 705 | 706 | sort.Strings(got) 707 | sort.Strings(tt.want) 708 | 709 | if !reflect.DeepEqual(tt.want, got) { 710 | t.Errorf("got: %v\nwant: %v\ndiff:\n%s", got, tt.want, diff.Diff(tt.want, got)) 711 | } 712 | }) 713 | } 714 | } 715 | 716 | func TestGoListFilesCommandError(t *testing.T) { 717 | defer func(oldExecCommand func(string, ...string) *exec.Cmd) { 718 | _execCommand = oldExecCommand 719 | }(_execCommand) 720 | _execCommand = func(string, ...string) *exec.Cmd { 721 | return exec.Command("false") 722 | } 723 | 724 | var stderr bytes.Buffer 725 | exitCode := (&mainCmd{ 726 | Stderr: &stderr, 727 | Stdout: testWriter{t}, 728 | }).Run([]string{"./..."}) 729 | if want := 1; exitCode != want { 730 | t.Errorf("exit code = %d, want %d", exitCode, want) 731 | } 732 | 733 | if want, got := "go list: exit status 1", stderr.String(); !strings.Contains(got, want) { 734 | t.Errorf("stderr = %q, want %q", got, want) 735 | } 736 | } 737 | 738 | func TestGoListFilesBadJSON(t *testing.T) { 739 | defer func(oldExecCommand func(string, ...string) *exec.Cmd) { 740 
| _execCommand = oldExecCommand 741 | }(_execCommand) 742 | _execCommand = func(string, ...string) *exec.Cmd { 743 | return exec.Command("echo", "bad json") 744 | } 745 | 746 | var stderr bytes.Buffer 747 | exitCode := (&mainCmd{ 748 | Stderr: &stderr, 749 | Stdout: testWriter{t}, 750 | }).Run([]string{"./..."}) 751 | if want := 1; exitCode != want { 752 | t.Errorf("exit code = %d, want %d", exitCode, want) 753 | } 754 | 755 | if want, got := "go list: output malformed: invalid character 'b'", stderr.String(); !strings.Contains(got, want) { 756 | t.Errorf("stderr = %q, want %q", got, want) 757 | } 758 | } 759 | 760 | func TestStdinNoInputMessage(t *testing.T) { 761 | tests := []struct { 762 | name string 763 | stdin func(testing.TB) io.Reader 764 | args []string 765 | wantStderr string 766 | }{ 767 | { 768 | name: "stdin is a file", 769 | stdin: func(t testing.TB) io.Reader { 770 | f, err := os.Open("testdata/golden/noop.go") 771 | if err != nil { 772 | t.Fatal(err) 773 | } 774 | return f 775 | }, 776 | }, 777 | { 778 | name: "stdin is a pipe", 779 | stdin: func(t testing.TB) io.Reader { 780 | r, w, err := os.Pipe() 781 | if err != nil { 782 | t.Fatal(err) 783 | } 784 | 785 | go func() { 786 | if _, err := w.WriteString("package foo"); err != nil { 787 | t.Errorf("failed to write to stdin as pipe: %v", err) 788 | } 789 | if err := w.Close(); err != nil { 790 | t.Errorf("failed to close stdin pipe: %v", err) 791 | } 792 | }() 793 | 794 | return r 795 | }, 796 | }, 797 | { 798 | name: "implicit stdin with a char device", 799 | stdin: func(t testing.TB) io.Reader { 800 | return fakeTerminal{strings.NewReader("package foo")} 801 | }, 802 | wantStderr: "reading from stdin; use '-h' for help\n", 803 | }, 804 | { 805 | name: "explicit stdin with a char device", 806 | stdin: func(t testing.TB) io.Reader { 807 | return fakeTerminal{strings.NewReader("package foo")} 808 | }, 809 | args: []string{"-"}, 810 | }, 811 | } 812 | 813 | for _, tt := range tests { 814 | 
t.Run(tt.name, func(t *testing.T) { 815 | var stderr bytes.Buffer 816 | exitCode := (&mainCmd{ 817 | Stdin: tt.stdin(t), 818 | Stdout: io.Discard, 819 | Stderr: &stderr, 820 | }).Run(tt.args) 821 | 822 | if want := 0; exitCode != want { 823 | t.Errorf("exit code = %d, want %d", exitCode, want) 824 | } 825 | 826 | if want, got := tt.wantStderr, stderr.String(); got != want { 827 | t.Errorf("stderr = %q, want %q", got, want) 828 | } 829 | }) 830 | } 831 | } 832 | 833 | type fakeTerminal struct { 834 | io.Reader 835 | } 836 | 837 | func (ft fakeTerminal) Stat() (os.FileInfo, error) { 838 | return charDeviceFileInfo{}, nil 839 | } 840 | 841 | type charDeviceFileInfo struct { 842 | // embed so we implement the interface. 843 | // unimplemented methods will panic. 844 | os.FileInfo 845 | } 846 | 847 | func (fi charDeviceFileInfo) Mode() os.FileMode { 848 | return os.ModeDevice | os.ModeCharDevice 849 | } 850 | 851 | func indent(s string) string { 852 | return "\t" + strings.ReplaceAll(s, "\n", "\n\t") 853 | } 854 | 855 | type logLine struct { 856 | Line int 857 | Msg string 858 | } 859 | 860 | // extractLogs parses the "// want" comments in src 861 | // into a slice of logLine structs. 
862 | func extractLogs(src []byte) ([]logLine, error) { 863 | fset := token.NewFileSet() 864 | f, err := parser.ParseFile(fset, "", src, parser.ParseComments) 865 | if err != nil { 866 | return nil, errtrace.Wrap(fmt.Errorf("parse: %w", err)) 867 | } 868 | 869 | var logs []logLine 870 | for _, c := range f.Comments { 871 | for _, l := range c.List { 872 | _, lit, ok := strings.Cut(l.Text, "// want:") 873 | if !ok { 874 | continue 875 | } 876 | 877 | pos := fset.Position(l.Pos()) 878 | s, err := strconv.Unquote(lit) 879 | if err != nil { 880 | return nil, errtrace.Wrap(fmt.Errorf("%s:bad string literal: %s", pos, lit)) 881 | } 882 | 883 | logs = append(logs, logLine{Line: pos.Line, Msg: s}) 884 | } 885 | } 886 | 887 | sort.Slice(logs, func(i, j int) bool { 888 | return logs[i].Line < logs[j].Line 889 | }) 890 | 891 | return logs, nil 892 | } 893 | 894 | func parseLogOutput(file, s string) ([]logLine, error) { 895 | var logs []logLine 896 | for _, line := range strings.Split(s, "\n") { 897 | if line == "" { 898 | continue 899 | } 900 | 901 | // Drop the path so we can determinstically split on ":" (which is a valid character in Windows paths). 
902 | line = strings.TrimPrefix(line, file) 903 | parts := strings.SplitN(line, ":", 4) 904 | if len(parts) != 4 { 905 | return nil, errtrace.Wrap(fmt.Errorf("bad log line: %q", line)) 906 | } 907 | 908 | var msg string 909 | if len(parts) == 4 { 910 | if _, err := strconv.Atoi(parts[2]); err == nil { 911 | // file:line:column:msg 912 | msg = parts[3] 913 | } 914 | } 915 | if msg == "" && len(parts) >= 2 { 916 | // file:line:msg 917 | msg = strings.Join(parts[2:], ":") 918 | } 919 | if msg == "" { 920 | return nil, errtrace.Wrap(fmt.Errorf("bad log line: %q", line)) 921 | } 922 | 923 | lineNum, err := strconv.Atoi(parts[1]) 924 | if err != nil { 925 | return nil, errtrace.Wrap(fmt.Errorf("bad log line: %q", line)) 926 | } 927 | 928 | logs = append(logs, logLine{ 929 | Line: lineNum, 930 | Msg: msg, 931 | }) 932 | } 933 | 934 | return logs, nil 935 | } 936 | 937 | func chdir(t testing.TB, dir string) (restore func()) { 938 | t.Helper() 939 | 940 | cwd, err := os.Getwd() 941 | if err != nil { 942 | t.Fatal(err) 943 | } 944 | 945 | var once sync.Once 946 | restore = func() { 947 | once.Do(func() { 948 | if err := os.Chdir(cwd); err != nil { 949 | t.Fatal(err) 950 | } 951 | }) 952 | } 953 | 954 | t.Cleanup(restore) 955 | if err := os.Chdir(dir); err != nil { 956 | t.Fatal(err) 957 | } 958 | return restore 959 | } 960 | 961 | type testWriter struct{ T testing.TB } 962 | 963 | func (w testWriter) Write(p []byte) (int, error) { 964 | for _, line := range bytes.Split(p, []byte{'\n'}) { 965 | w.T.Logf("%s", line) 966 | } 967 | return len(p), nil 968 | } 969 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/golden/already_imported.go: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | 3 | package foo 4 | 5 | import ( 6 | "strconv" 7 | 8 | "braces.dev/errtrace" 9 | ) 10 | 11 | func Unwrapped(s string) (int, error) { 12 | i, err := strconv.Atoi(s) 13 | if err != nil 
{ 14 | return 0, err 15 | } 16 | return i + 42, nil 17 | } 18 | 19 | func AlreadyWrapped(s string) (int, error) { 20 | i, err := strconv.Atoi(s) 21 | if err != nil { 22 | return 0, errtrace.Wrap(err) 23 | } 24 | return i + 42, nil 25 | } 26 | 27 | func SkipNew() error { 28 | return errtrace.New("test") 29 | } 30 | 31 | func SkipErrorf() error { 32 | return errtrace.Errorf("foo: %v", SkipNew()) 33 | } 34 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/golden/already_imported.go.golden: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | 3 | package foo 4 | 5 | import ( 6 | "strconv" 7 | 8 | "braces.dev/errtrace" 9 | ) 10 | 11 | func Unwrapped(s string) (int, error) { 12 | i, err := strconv.Atoi(s) 13 | if err != nil { 14 | return 0, errtrace.Wrap(err) 15 | } 16 | return i + 42, nil 17 | } 18 | 19 | func AlreadyWrapped(s string) (int, error) { 20 | i, err := strconv.Atoi(s) 21 | if err != nil { 22 | return 0, errtrace.Wrap(err) 23 | } 24 | return i + 42, nil 25 | } 26 | 27 | func SkipNew() error { 28 | return errtrace.New("test") 29 | } 30 | 31 | func SkipErrorf() error { 32 | return errtrace.Errorf("foo: %v", SkipNew()) 33 | } 34 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/golden/closure.go: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | 3 | package foo 4 | 5 | import ( 6 | "errors" 7 | "fmt" 8 | ) 9 | 10 | func ClosureReturnsError() error { 11 | return func() error { 12 | return errors.New("great sadness") 13 | }() 14 | } 15 | 16 | func ClosureDoesNotReturnError() error { 17 | x := func() int { 18 | return 42 19 | }() 20 | return nil 21 | } 22 | 23 | func DeferedClosureReturnsError() error { 24 | defer func() error { 25 | // The error is ignored, 26 | // but it should still be wrapped. 
27 | return errors.New("great sadness") 28 | }() 29 | 30 | return nil 31 | } 32 | 33 | func DeferedClosureDoesNotReturnError() error { 34 | defer func() int { 35 | return 42 36 | }() 37 | 38 | return nil 39 | } 40 | 41 | func ClosureReturningErrorHasDifferentNumberOfReturns() (int, error) { 42 | x := func() error { 43 | return errors.New("great sadness") 44 | } 45 | 46 | return 42, x() 47 | } 48 | 49 | func ClosureNotReturningErrorHasDifferentNumberOfReturns() (int, error) { 50 | x := func() int { 51 | return 42 52 | } 53 | 54 | return 42, nil 55 | } 56 | 57 | func ClosureInsideAnotherClosure() error { 58 | return func() error { 59 | return func() error { 60 | return errors.New("great sadness") 61 | }() 62 | }() 63 | } 64 | 65 | func ClosureNotReturningErrorInsideAnotherClosure() (int, error) { 66 | var x int 67 | err := func() error { 68 | x := func() int { 69 | return 42 70 | }() 71 | return errors.New("great sadness") 72 | }() 73 | 74 | return x, err 75 | } 76 | 77 | func ClosureReturningAnErrorInsideADefer() error { 78 | defer func() { 79 | err := func() error { 80 | return errors.New("great sadness") 81 | }() 82 | 83 | fmt.Println(err) 84 | }() 85 | 86 | return nil 87 | } 88 | 89 | func ClosureNotReturningAnErrorInsideADefer() error { 90 | defer func() error { 91 | x := func() int { 92 | return 42 93 | }() 94 | 95 | return fmt.Errorf("great sadness: %d", x) 96 | }() 97 | 98 | return nil 99 | } 100 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/golden/closure.go.golden: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | 3 | package foo 4 | 5 | import ( 6 | "errors" 7 | "fmt"; "braces.dev/errtrace" 8 | ) 9 | 10 | func ClosureReturnsError() error { 11 | return errtrace.Wrap(func() error { 12 | return errtrace.Wrap(errors.New("great sadness")) 13 | }()) 14 | } 15 | 16 | func ClosureDoesNotReturnError() error { 17 | x := func() int { 18 | return 42 19 
| }() 20 | return nil 21 | } 22 | 23 | func DeferedClosureReturnsError() error { 24 | defer func() error { 25 | // The error is ignored, 26 | // but it should still be wrapped. 27 | return errtrace.Wrap(errors.New("great sadness")) 28 | }() 29 | 30 | return nil 31 | } 32 | 33 | func DeferedClosureDoesNotReturnError() error { 34 | defer func() int { 35 | return 42 36 | }() 37 | 38 | return nil 39 | } 40 | 41 | func ClosureReturningErrorHasDifferentNumberOfReturns() (int, error) { 42 | x := func() error { 43 | return errtrace.Wrap(errors.New("great sadness")) 44 | } 45 | 46 | return 42, errtrace.Wrap(x()) 47 | } 48 | 49 | func ClosureNotReturningErrorHasDifferentNumberOfReturns() (int, error) { 50 | x := func() int { 51 | return 42 52 | } 53 | 54 | return 42, nil 55 | } 56 | 57 | func ClosureInsideAnotherClosure() error { 58 | return errtrace.Wrap(func() error { 59 | return errtrace.Wrap(func() error { 60 | return errtrace.Wrap(errors.New("great sadness")) 61 | }()) 62 | }()) 63 | } 64 | 65 | func ClosureNotReturningErrorInsideAnotherClosure() (int, error) { 66 | var x int 67 | err := func() error { 68 | x := func() int { 69 | return 42 70 | }() 71 | return errtrace.Wrap(errors.New("great sadness")) 72 | }() 73 | 74 | return x, errtrace.Wrap(err) 75 | } 76 | 77 | func ClosureReturningAnErrorInsideADefer() error { 78 | defer func() { 79 | err := func() error { 80 | return errtrace.Wrap(errors.New("great sadness")) 81 | }() 82 | 83 | fmt.Println(err) 84 | }() 85 | 86 | return nil 87 | } 88 | 89 | func ClosureNotReturningAnErrorInsideADefer() error { 90 | defer func() error { 91 | x := func() int { 92 | return 42 93 | }() 94 | 95 | return errtrace.Wrap(fmt.Errorf("great sadness: %d", x)) 96 | }() 97 | 98 | return nil 99 | } 100 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/golden/error_wrap.go: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | 3 | package 
foo 4 | 5 | type innerError struct{} 6 | 7 | func (*innerError) Error() string { 8 | return "sadness" 9 | } 10 | 11 | type errorWrapper struct { 12 | Err error 13 | } 14 | 15 | func (e *errorWrapper) Error() string { 16 | return e.Err.Error() 17 | } 18 | 19 | // Unwrap should not be wrapped by errtrace. 20 | func (e *errorWrapper) Unwrap() error { 21 | return e.Err 22 | } 23 | 24 | func Try() error { 25 | return &errorWrapper{Err: &innerError{}} 26 | } 27 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/golden/error_wrap.go.golden: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | 3 | package foo; import "braces.dev/errtrace" 4 | 5 | type innerError struct{} 6 | 7 | func (*innerError) Error() string { 8 | return "sadness" 9 | } 10 | 11 | type errorWrapper struct { 12 | Err error 13 | } 14 | 15 | func (e *errorWrapper) Error() string { 16 | return e.Err.Error() 17 | } 18 | 19 | // Unwrap should not be wrapped by errtrace. 
20 | func (e *errorWrapper) Unwrap() error { 21 | return e.Err 22 | } 23 | 24 | func Try() error { 25 | return errtrace.Wrap(&errorWrapper{Err: &innerError{}}) 26 | } 27 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/golden/imported_blank.go: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | 3 | package foo 4 | 5 | import ( 6 | "strconv" 7 | 8 | _ "braces.dev/errtrace" 9 | ) 10 | 11 | func Unwrapped(s string) (int, error) { 12 | i, err := strconv.Atoi(s) 13 | if err != nil { 14 | return 0, err 15 | } 16 | return i + 42, nil 17 | } 18 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/golden/imported_blank.go.golden: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | 3 | package foo 4 | 5 | import ( 6 | "strconv" 7 | 8 | _ "braces.dev/errtrace"; "braces.dev/errtrace" 9 | ) 10 | 11 | func Unwrapped(s string) (int, error) { 12 | i, err := strconv.Atoi(s) 13 | if err != nil { 14 | return 0, errtrace.Wrap(err) 15 | } 16 | return i + 42, nil 17 | } 18 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/golden/imported_with_alias.go: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | 3 | package foo 4 | 5 | import ( 6 | "strconv" 7 | 8 | errtrace2 "braces.dev/errtrace" 9 | ) 10 | 11 | var _ = errtrace2.Wrap // keep import 12 | 13 | func Unwrapped(s string) (int, error) { 14 | i, err := strconv.Atoi(s) 15 | if err != nil { 16 | return 0, err 17 | } 18 | return i + 42, nil 19 | } 20 | 21 | func AlreadyWrapped(s string) (int, error) { 22 | i, err := strconv.Atoi(s) 23 | if err != nil { 24 | return 0, errtrace2.Wrap(err) 25 | } 26 | return i + 42, nil 27 | } 28 | 29 | func NakedNamedReturn(s string) (i int, err error) { 30 | i, err = 
strconv.Atoi(s) 31 | return 32 | } 33 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/golden/imported_with_alias.go.golden: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | 3 | package foo 4 | 5 | import ( 6 | "strconv" 7 | 8 | errtrace2 "braces.dev/errtrace" 9 | ) 10 | 11 | var _ = errtrace2.Wrap // keep import 12 | 13 | func Unwrapped(s string) (int, error) { 14 | i, err := strconv.Atoi(s) 15 | if err != nil { 16 | return 0, errtrace2.Wrap(err) 17 | } 18 | return i + 42, nil 19 | } 20 | 21 | func AlreadyWrapped(s string) (int, error) { 22 | i, err := strconv.Atoi(s) 23 | if err != nil { 24 | return 0, errtrace2.Wrap(err) 25 | } 26 | return i + 42, nil 27 | } 28 | 29 | func NakedNamedReturn(s string) (i int, err error) { 30 | i, err = strconv.Atoi(s) 31 | err = errtrace2.Wrap(err); return 32 | } 33 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/golden/name_already_taken.go: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | 3 | package foo 4 | 5 | import "strconv" 6 | 7 | func Unwrapped(errtrace string) (int, error) { 8 | // For some reason, the string is named errtrace. 9 | // Don't think about it too hard. 10 | i, err := strconv.Atoi(errtrace) 11 | return i + 42, err 12 | } 13 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/golden/name_already_taken.go.golden: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | 3 | package foo 4 | 5 | import "strconv"; import errtrace2 "braces.dev/errtrace" 6 | 7 | func Unwrapped(errtrace string) (int, error) { 8 | // For some reason, the string is named errtrace. 9 | // Don't think about it too hard. 
10 | i, err := strconv.Atoi(errtrace) 11 | return i + 42, errtrace2.Wrap(err) 12 | } 13 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/golden/named_returns.go: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | 3 | package foo 4 | 5 | import ( 6 | "errors" 7 | "fmt" 8 | "os" 9 | 10 | "go.uber.org/multierr" 11 | ) 12 | 13 | func NoError() (x int) { 14 | return 0 15 | } 16 | 17 | func NoErrorNaked() (x int) { 18 | return 19 | } 20 | 21 | func NakedReturn(s string) (err error) { 22 | err = errors.New("sadness: " + s) 23 | fmt.Println("Reporting sadness") 24 | return 25 | } 26 | 27 | func NamedReturn(s string) (err error) { 28 | err = errors.New("sadness: " + s) 29 | fmt.Println("Reporting sadness") 30 | return err 31 | } 32 | 33 | func MultipleErrors() (err1, err2 error, ok bool, err3, err4 error) { 34 | err1 = errors.New("a") 35 | err2 = errors.New("b") 36 | ok = false 37 | err3 = errors.New("c") 38 | err4 = errors.New("d") 39 | 40 | if !ok { 41 | // Naked 42 | return 43 | } 44 | 45 | // Named 46 | return err1, err2, ok, err3, err4 47 | } 48 | 49 | func UnderscoreNamed() (_ error) { 50 | return NamedReturn("foo") 51 | } 52 | 53 | func UnderscoreNamedMultiple() (_ bool, err error) { 54 | return false, NamedReturn("foo") 55 | } 56 | 57 | func DeferWrapNamed() (err error) { 58 | defer func() { 59 | err = fmt.Errorf("wrapped: %w", err) 60 | }() 61 | 62 | return NamedReturn("foo") 63 | } 64 | 65 | func DeferWrapNamedWithItsOwnError() (_ int, err error) { 66 | // Both, the error returned by the deferred function 67 | // and the named error wrapped by it should be wrapped. 
68 | defer func() error { 69 | err = fmt.Errorf("wrapped: %w", err) 70 | 71 | return errors.New("ignored") 72 | }() 73 | 74 | return 0, UnderscoreNamed() 75 | } 76 | 77 | func DeferToAnotherFunction() (err error) { 78 | f, err := os.Open("foo.txt") 79 | if err != nil { 80 | return err 81 | } 82 | defer multierr.AppendInto(&err, multierr.Close(f)) 83 | return nil 84 | } 85 | 86 | func NamedReturnShadowed() (err error) { 87 | defer func() { 88 | // Should not get wrapped by errtrace. 89 | err := cleanup() 90 | if err != nil { 91 | fmt.Println("cleanup failed:", err) 92 | } 93 | }() 94 | 95 | return f() 96 | } 97 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/golden/named_returns.go.golden: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | 3 | package foo 4 | 5 | import ( 6 | "errors" 7 | "fmt" 8 | "os" 9 | 10 | "go.uber.org/multierr"; "braces.dev/errtrace" 11 | ) 12 | 13 | func NoError() (x int) { 14 | return 0 15 | } 16 | 17 | func NoErrorNaked() (x int) { 18 | return 19 | } 20 | 21 | func NakedReturn(s string) (err error) { 22 | err = errors.New("sadness: " + s) 23 | fmt.Println("Reporting sadness") 24 | err = errtrace.Wrap(err); return 25 | } 26 | 27 | func NamedReturn(s string) (err error) { 28 | err = errors.New("sadness: " + s) 29 | fmt.Println("Reporting sadness") 30 | return errtrace.Wrap(err) 31 | } 32 | 33 | func MultipleErrors() (err1, err2 error, ok bool, err3, err4 error) { 34 | err1 = errors.New("a") 35 | err2 = errors.New("b") 36 | ok = false 37 | err3 = errors.New("c") 38 | err4 = errors.New("d") 39 | 40 | if !ok { 41 | // Naked 42 | err1, err2, err3, err4 = errtrace.Wrap(err1), errtrace.Wrap(err2), errtrace.Wrap(err3), errtrace.Wrap(err4); return 43 | } 44 | 45 | // Named 46 | return errtrace.Wrap(err1), errtrace.Wrap(err2), ok, errtrace.Wrap(err3), errtrace.Wrap(err4) 47 | } 48 | 49 | func UnderscoreNamed() (_ error) { 50 | return 
errtrace.Wrap(NamedReturn("foo")) 51 | } 52 | 53 | func UnderscoreNamedMultiple() (_ bool, err error) { 54 | return false, errtrace.Wrap(NamedReturn("foo")) 55 | } 56 | 57 | func DeferWrapNamed() (err error) { 58 | defer func() { 59 | err = errtrace.Wrap(fmt.Errorf("wrapped: %w", err)) 60 | }() 61 | 62 | return errtrace.Wrap(NamedReturn("foo")) 63 | } 64 | 65 | func DeferWrapNamedWithItsOwnError() (_ int, err error) { 66 | // Both, the error returned by the deferred function 67 | // and the named error wrapped by it should be wrapped. 68 | defer func() error { 69 | err = errtrace.Wrap(fmt.Errorf("wrapped: %w", err)) 70 | 71 | return errtrace.Wrap(errors.New("ignored")) 72 | }() 73 | 74 | return 0, errtrace.Wrap(UnderscoreNamed()) 75 | } 76 | 77 | func DeferToAnotherFunction() (err error) { 78 | f, err := os.Open("foo.txt") 79 | if err != nil { 80 | return errtrace.Wrap(err) 81 | } 82 | defer multierr.AppendInto(&err, multierr.Close(f)) 83 | return nil 84 | } 85 | 86 | func NamedReturnShadowed() (err error) { 87 | defer func() { 88 | // Should not get wrapped by errtrace. 
89 | err := cleanup() 90 | if err != nil { 91 | fmt.Println("cleanup failed:", err) 92 | } 93 | }() 94 | 95 | return errtrace.Wrap(f()) 96 | } 97 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/golden/nested.go: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | 3 | package foo 4 | 5 | import ( 6 | "errors" 7 | "fmt" 8 | ) 9 | 10 | func HasFunctionLiteral() { 11 | err := func() error { 12 | return errors.New("sadness") 13 | }() 14 | 15 | fmt.Println(err) 16 | } 17 | 18 | func ImmediatelyInvokedFunctionExpression() error { 19 | return func() error { 20 | return errors.New("sadness") 21 | }() 22 | } 23 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/golden/nested.go.golden: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | 3 | package foo 4 | 5 | import ( 6 | "errors" 7 | "fmt"; "braces.dev/errtrace" 8 | ) 9 | 10 | func HasFunctionLiteral() { 11 | err := func() error { 12 | return errtrace.Wrap(errors.New("sadness")) 13 | }() 14 | 15 | fmt.Println(err) 16 | } 17 | 18 | func ImmediatelyInvokedFunctionExpression() error { 19 | return errtrace.Wrap(func() error { 20 | return errtrace.Wrap(errors.New("sadness")) 21 | }()) 22 | } 23 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/golden/no-wrapn.go: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | 3 | // @runIf options=no-wrapn 4 | package foo 5 | 6 | import "example.com/bar" 7 | 8 | func hasTwo() (int, error) { 9 | // Same names as used by rewriting, with different types to verify scoping. 
10 | r1 := true 11 | r2 := false 12 | return bar.Two() 13 | } 14 | 15 | func hasThree() (string, int, error) { 16 | return bar.Three() 17 | } 18 | 19 | func hasFour() (string, int, bool, error) { 20 | return bar.Four() 21 | } 22 | 23 | func hasFive() (a int, b bool, c string, d int, e error) { 24 | return bar.Five() 25 | } 26 | 27 | func hasSix() (a int, b bool, c string, d int, e bool, f error) { 28 | return bar.Six() 29 | } 30 | 31 | func hasSeven() (a int, b bool, c string, d int, e bool, f string, g error) { 32 | return bar.Seven() 33 | } 34 | 35 | func nonFinalError() (error, bool) { 36 | return bar.NonFinalError() // want:"skipping function with non-final error return" 37 | } 38 | 39 | func multipleErrors() (x int, err1, err2 error) { 40 | return bar.MultipleErrors() // want:"skipping function with multiple error returns" 41 | } 42 | 43 | func invalid() (x int, err error) { 44 | return 42 // want:"skipping function with incorrect number of return values: got 1, want 2" 45 | } 46 | 47 | func nestedExpressions() (int, error) { 48 | return func() (int, error) { 49 | r1 := true 50 | r2 := false 51 | return bar.Two() 52 | }() 53 | } 54 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/golden/no-wrapn.go.golden: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | 3 | // @runIf options=no-wrapn 4 | package foo 5 | 6 | import "example.com/bar"; import "braces.dev/errtrace" 7 | 8 | func hasTwo() (int, error) { 9 | // Same names as used by rewriting, with different types to verify scoping. 
10 | r1 := true 11 | r2 := false 12 | { r1, r2 := bar.Two(); return r1, errtrace.Wrap(r2) } 13 | } 14 | 15 | func hasThree() (string, int, error) { 16 | { r1, r2, r3 := bar.Three(); return r1, r2, errtrace.Wrap(r3) } 17 | } 18 | 19 | func hasFour() (string, int, bool, error) { 20 | { r1, r2, r3, r4 := bar.Four(); return r1, r2, r3, errtrace.Wrap(r4) } 21 | } 22 | 23 | func hasFive() (a int, b bool, c string, d int, e error) { 24 | { r1, r2, r3, r4, r5 := bar.Five(); return r1, r2, r3, r4, errtrace.Wrap(r5) } 25 | } 26 | 27 | func hasSix() (a int, b bool, c string, d int, e bool, f error) { 28 | { r1, r2, r3, r4, r5, r6 := bar.Six(); return r1, r2, r3, r4, r5, errtrace.Wrap(r6) } 29 | } 30 | 31 | func hasSeven() (a int, b bool, c string, d int, e bool, f string, g error) { 32 | { r1, r2, r3, r4, r5, r6, r7 := bar.Seven(); return r1, r2, r3, r4, r5, r6, errtrace.Wrap(r7) } 33 | } 34 | 35 | func nonFinalError() (error, bool) { 36 | return bar.NonFinalError() // want:"skipping function with non-final error return" 37 | } 38 | 39 | func multipleErrors() (x int, err1, err2 error) { 40 | return bar.MultipleErrors() // want:"skipping function with multiple error returns" 41 | } 42 | 43 | func invalid() (x int, err error) { 44 | return 42 // want:"skipping function with incorrect number of return values: got 1, want 2" 45 | } 46 | 47 | func nestedExpressions() (int, error) { 48 | { r1, r2 := func() (int, error) { 49 | r1 := true 50 | r2 := false 51 | { r1, r2 := bar.Two(); return r1, errtrace.Wrap(r2) } 52 | }(); return r1, errtrace.Wrap(r2) } 53 | } 54 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/golden/no_import.go: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | 3 | package foo 4 | 5 | type myError struct{} 6 | 7 | func (*myError) Error() string { 8 | return "sadness" 9 | } 10 | 11 | func Try() error { 12 | return &myError{} 13 | } 14 | 
-------------------------------------------------------------------------------- /cmd/errtrace/testdata/golden/no_import.go.golden: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | 3 | package foo; import "braces.dev/errtrace" 4 | 5 | type myError struct{} 6 | 7 | func (*myError) Error() string { 8 | return "sadness" 9 | } 10 | 11 | func Try() error { 12 | return errtrace.Wrap(&myError{}) 13 | } 14 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/golden/noop.go: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | 3 | package foo 4 | 5 | import "errors" 6 | 7 | // This file should not be changed. 8 | 9 | func success() error { 10 | return nil 11 | } 12 | 13 | func failure() error { 14 | return errors.New("failure") //errtrace:skip 15 | } 16 | 17 | func deferred() (err error) { 18 | defer func() { 19 | err = errors.New("failure") //errtrace:skip 20 | }() 21 | return nil 22 | } 23 | 24 | func namedReturn() (err error) { 25 | err = errors.New("failure") 26 | return //errtrace:skip 27 | } 28 | 29 | func immediatelyInvoked() error { 30 | return func() error { //errtrace:skip 31 | return errors.New("failure") //errtrace:skip 32 | }() 33 | } 34 | 35 | func multipleReturns() (error, error) { 36 | return errors.New("a"), errors.New("b") //errtrace:skip 37 | } 38 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/golden/noop.go.golden: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | 3 | package foo 4 | 5 | import "errors" 6 | 7 | // This file should not be changed. 
8 | 9 | func success() error { 10 | return nil 11 | } 12 | 13 | func failure() error { 14 | return errors.New("failure") //errtrace:skip 15 | } 16 | 17 | func deferred() (err error) { 18 | defer func() { 19 | err = errors.New("failure") //errtrace:skip 20 | }() 21 | return nil 22 | } 23 | 24 | func namedReturn() (err error) { 25 | err = errors.New("failure") 26 | return //errtrace:skip 27 | } 28 | 29 | func immediatelyInvoked() error { 30 | return func() error { //errtrace:skip 31 | return errors.New("failure") //errtrace:skip 32 | }() 33 | } 34 | 35 | func multipleReturns() (error, error) { 36 | return errors.New("a"), errors.New("b") //errtrace:skip 37 | } 38 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/golden/optout.go: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | 3 | package foo 4 | 5 | import ( 6 | "errors" 7 | "io" 8 | 9 | "example.com/bar" 10 | ) 11 | 12 | func Try(problem bool) (int, error) { 13 | err := bar.Do(func() error { 14 | if problem { 15 | return errors.New("great sadness") 16 | } 17 | 18 | return io.EOF //nolint:errwrap //errtrace:skip(expects io.EOF) 19 | }) 20 | if err != nil { 21 | return 0, err 22 | } 23 | 24 | return bar.Baz() //errtrace:skip // caller wants unwrapped error 25 | } 26 | 27 | func unused() error { 28 | return nil //errtrace:skip // want:"unused errtrace:skip" 29 | } 30 | 31 | func multipleReturns() (a, b error) { 32 | return errors.New("a"), 33 | errors.New("b") //errtrace:skip 34 | } 35 | 36 | func multipleReturnsSkipped() (a, b error) { 37 | return errors.New("a"), //errtrace:skip 38 | errors.New("b") //errtrace:skip 39 | } 40 | 41 | // Explanation of why this function 42 | // is not using //errtrace:skip should not 43 | // trip up the warning. 
44 | func notUsingSkip() error { 45 | return nil 46 | } 47 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/golden/optout.go.golden: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | 3 | package foo 4 | 5 | import ( 6 | "errors" 7 | "io" 8 | 9 | "example.com/bar"; "braces.dev/errtrace" 10 | ) 11 | 12 | func Try(problem bool) (int, error) { 13 | err := bar.Do(func() error { 14 | if problem { 15 | return errtrace.Wrap(errors.New("great sadness")) 16 | } 17 | 18 | return io.EOF //nolint:errwrap //errtrace:skip(expects io.EOF) 19 | }) 20 | if err != nil { 21 | return 0, errtrace.Wrap(err) 22 | } 23 | 24 | return bar.Baz() //errtrace:skip // caller wants unwrapped error 25 | } 26 | 27 | func unused() error { 28 | return nil //errtrace:skip // want:"unused errtrace:skip" 29 | } 30 | 31 | func multipleReturns() (a, b error) { 32 | return errtrace.Wrap(errors.New("a")), 33 | errors.New("b") //errtrace:skip 34 | } 35 | 36 | func multipleReturnsSkipped() (a, b error) { 37 | return errors.New("a"), //errtrace:skip 38 | errors.New("b") //errtrace:skip 39 | } 40 | 41 | // Explanation of why this function 42 | // is not using //errtrace:skip should not 43 | // trip up the warning. 
44 | func notUsingSkip() error { 45 | return nil 46 | } 47 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/golden/simple.go: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | 3 | package foo 4 | 5 | import ( 6 | "io" 7 | "os" 8 | "strconv" 9 | ) 10 | 11 | func Unwrapped(s string) (int, error) { 12 | i, err := strconv.Atoi(s) 13 | if err != nil { 14 | return 0, err 15 | } 16 | return i + 42, nil 17 | } 18 | 19 | func DeferWithoutNamedReturns(s string) error { 20 | f, err := os.Open(s) 21 | if err != nil { 22 | return err 23 | } 24 | defer f.Close() 25 | 26 | bs, err := io.ReadAll(f) 27 | if err != nil { 28 | return err 29 | } 30 | 31 | return nil 32 | } 33 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/golden/simple.go.golden: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | 3 | package foo 4 | 5 | import ( 6 | "io" 7 | "os" 8 | "strconv"; "braces.dev/errtrace" 9 | ) 10 | 11 | func Unwrapped(s string) (int, error) { 12 | i, err := strconv.Atoi(s) 13 | if err != nil { 14 | return 0, errtrace.Wrap(err) 15 | } 16 | return i + 42, nil 17 | } 18 | 19 | func DeferWithoutNamedReturns(s string) error { 20 | f, err := os.Open(s) 21 | if err != nil { 22 | return errtrace.Wrap(err) 23 | } 24 | defer f.Close() 25 | 26 | bs, err := io.ReadAll(f) 27 | if err != nil { 28 | return errtrace.Wrap(err) 29 | } 30 | 31 | return nil 32 | } 33 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/golden/tuple_rhs.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "fmt" 4 | 5 | func multipleValueErrAssignment() (err error) { 6 | defer func() { 7 | _, err = fmt.Println("Hello, World!") 8 | 9 | // Handles too few lhs variables 10 | err = 
fmt.Println("Hello, World!") 11 | 12 | // Handles too many lhs variables 13 | _, err, _ = fmt.Println("Hello, World!") // want:"skipping assignment: error is not the last return value" 14 | 15 | // Handles misplaced err 16 | err, _ = fmt.Println("Hello, World!") // want:"skipping assignment: error is not the last return value" 17 | }() 18 | 19 | return nil 20 | } 21 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/golden/tuple_rhs.go.golden: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "fmt"; import "braces.dev/errtrace" 4 | 5 | func multipleValueErrAssignment() (err error) { 6 | defer func() { 7 | _, err = errtrace.Wrap2(fmt.Println("Hello, World!")) 8 | 9 | // Handles too few lhs variables 10 | err = errtrace.Wrap(fmt.Println("Hello, World!")) 11 | 12 | // Handles too many lhs variables 13 | _, err, _ = fmt.Println("Hello, World!") // want:"skipping assignment: error is not the last return value" 14 | 15 | // Handles misplaced err 16 | err, _ = fmt.Println("Hello, World!") // want:"skipping assignment: error is not the last return value" 17 | }() 18 | 19 | return nil 20 | } 21 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/golden/wrapn.go: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | 3 | // @runIf options= 4 | package foo 5 | 6 | import "example.com/bar" 7 | 8 | func hasTwo() (int, error) { 9 | return bar.Two() 10 | } 11 | 12 | func hasThree() (string, int, error) { 13 | return bar.Three() 14 | } 15 | 16 | func hasFour() (string, int, bool, error) { 17 | return bar.Four() 18 | } 19 | 20 | func hasFive() (a int, b bool, c string, d int, e error) { 21 | return bar.Five() 22 | } 23 | 24 | func hasSix() (a int, b bool, c string, d int, e bool, f error) { 25 | return bar.Six() 26 | } 27 | 28 | func hasSeven() (a int, b bool, c 
string, d int, e bool, f string, g error) { 29 | return bar.Seven() // want:"skipping function with too many return values" 30 | } 31 | 32 | func nonFinalError() (error, bool) { 33 | return bar.NonFinalError() // want:"skipping function with non-final error return" 34 | } 35 | 36 | func multipleErrors() (x int, err1, err2 error) { 37 | return bar.MultipleErrors() // want:"skipping function with multiple error returns" 38 | } 39 | 40 | func invalid() (x int, err error) { 41 | return 42 // want:"skipping function with incorrect number of return values: got 1, want 2" 42 | } 43 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/golden/wrapn.go.golden: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | 3 | // @runIf options= 4 | package foo 5 | 6 | import "example.com/bar"; import "braces.dev/errtrace" 7 | 8 | func hasTwo() (int, error) { 9 | return errtrace.Wrap2(bar.Two()) 10 | } 11 | 12 | func hasThree() (string, int, error) { 13 | return errtrace.Wrap3(bar.Three()) 14 | } 15 | 16 | func hasFour() (string, int, bool, error) { 17 | return errtrace.Wrap4(bar.Four()) 18 | } 19 | 20 | func hasFive() (a int, b bool, c string, d int, e error) { 21 | return errtrace.Wrap5(bar.Five()) 22 | } 23 | 24 | func hasSix() (a int, b bool, c string, d int, e bool, f error) { 25 | return errtrace.Wrap6(bar.Six()) 26 | } 27 | 28 | func hasSeven() (a int, b bool, c string, d int, e bool, f string, g error) { 29 | return bar.Seven() // want:"skipping function with too many return values" 30 | } 31 | 32 | func nonFinalError() (error, bool) { 33 | return bar.NonFinalError() // want:"skipping function with non-final error return" 34 | } 35 | 36 | func multipleErrors() (x int, err1, err2 error) { 37 | return bar.MultipleErrors() // want:"skipping function with multiple error returns" 38 | } 39 | 40 | func invalid() (x int, err error) { 41 | return 42 // want:"skipping function with 
incorrect number of return values: got 1, want 2" 42 | } 43 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/main/foo/foo.go: -------------------------------------------------------------------------------- 1 | package foo 2 | 3 | import "errors" 4 | 5 | func Foo() error { 6 | return errors.New("test") 7 | } 8 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/main/go.mod: -------------------------------------------------------------------------------- 1 | module braces.dev/errtrace/cmd/errtrace/testdata/main 2 | 3 | go 1.21.4 4 | 5 | require braces.dev/errtrace v0.3.0 6 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/main/go.sum: -------------------------------------------------------------------------------- 1 | braces.dev/errtrace v0.3.0 h1:pzfd6LcWgfWtXLaNFWRnxV/7NP+FSOlIjRLwDuHfPxs= 2 | braces.dev/errtrace v0.3.0/go.mod h1:YQpXdo+u5iimgQdZzFoic8AjedEDncXGpp6/2SfazzI= 3 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/main/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | 6 | "braces.dev/errtrace/cmd/errtrace/testdata/main/foo" 7 | ) 8 | 9 | func main() { 10 | if err := foo.Foo(); err != nil { 11 | fmt.Printf("%+v\n", err) 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/toolexec-test/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | 6 | _ "braces.dev/errtrace" // Opt-in to errtrace wrapping with toolexec. 
7 | "braces.dev/errtrace/cmd/errtrace/testdata/toolexec-test/p1" 8 | ) 9 | 10 | func main() { 11 | if err := callP1(); err != nil { 12 | fmt.Printf("%+v\n", err) 13 | } 14 | } 15 | 16 | func callP1() error { 17 | return p1.WrapP2() // @trace 18 | } 19 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/toolexec-test/main_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "testing" 4 | 5 | func TestFoo(t *testing.T) { 6 | t.Errorf("fail") 7 | } 8 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/toolexec-test/p1/p1.go: -------------------------------------------------------------------------------- 1 | package p1 2 | 3 | import ( 4 | "fmt" 5 | 6 | "braces.dev/errtrace/cmd/errtrace/testdata/toolexec-test/p2" 7 | ) 8 | 9 | // WrapP2 wraps an error return from p2. 10 | func WrapP2() error { 11 | return fmt.Errorf("test2: %w", p2.CallP3()) 12 | } 13 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/toolexec-test/p2/p2.go: -------------------------------------------------------------------------------- 1 | package p2 2 | 3 | import ( 4 | "braces.dev/errtrace" 5 | 6 | "braces.dev/errtrace/cmd/errtrace/testdata/toolexec-test/p3" 7 | ) 8 | 9 | // CallP3 calls p3, and wraps the error. 10 | func CallP3() error { 11 | return errtrace.Wrap(p3.ReturnErr()) // @trace 12 | } 13 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/toolexec-test/p3/errtrace.go: -------------------------------------------------------------------------------- 1 | package p3 2 | 3 | // Opt-in to errtrace wrapping with toolexec. 
4 | import _ "braces.dev/errtrace" 5 | -------------------------------------------------------------------------------- /cmd/errtrace/testdata/toolexec-test/p3/p3.go: -------------------------------------------------------------------------------- 1 | package p3 2 | 3 | import ( 4 | "errors" 5 | ) 6 | 7 | // ReturnErr returns an error. 8 | func ReturnErr() error { 9 | return errors.New("test") // @trace 10 | } 11 | -------------------------------------------------------------------------------- /cmd/errtrace/toolexec.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "crypto/md5" 6 | "encoding/hex" 7 | "flag" 8 | "fmt" 9 | "io" 10 | "os" 11 | "os/exec" 12 | "path/filepath" 13 | "runtime" 14 | "runtime/debug" 15 | "slices" 16 | "strings" 17 | 18 | "braces.dev/errtrace" 19 | ) 20 | 21 | func (cmd *mainCmd) handleToolExec(args []string) (exitCode int, handled bool) { 22 | // In toolexec mode, we're passed the original command + arguments. 23 | if len(args) == 0 { 24 | return -1, false 25 | } 26 | 27 | if cmd.Getenv == nil { 28 | cmd.Getenv = os.Getenv 29 | } 30 | 31 | // compile is run first with "-V=full" to get a version number 32 | // for caching build IDs. 33 | // No TOOLEXEC_IMPORTPATH is set in this case. 
34 | version := slices.Contains(args, "-V=full") 35 | pkg := cmd.Getenv("TOOLEXEC_IMPORTPATH") 36 | if !version && pkg == "" { 37 | return -1, false 38 | } 39 | 40 | var p toolExecParams 41 | if err := p.Parse(cmd.Stderr, args); err != nil { // Usage goes to stderr: the go command consumes the tool's stdout (e.g. the -V=full version line). 42 | cmd.log.Print(err) 43 | return 1, true 44 | } 45 | 46 | if version { 47 | return cmd.toolExecVersion(p), true 48 | } 49 | return cmd.toolExecRewrite(pkg, p), true 50 | } 51 | 52 | type toolExecParams struct { 53 | RequiredPkgSelectors []string 54 | 55 | Tool string 56 | ToolArgs []string 57 | 58 | flags *flag.FlagSet 59 | } 60 | 61 | func (p *toolExecParams) Parse(w io.Writer, args []string) error { 62 | p.flags = flag.NewFlagSet("errtrace (toolexec)", flag.ContinueOnError) 63 | p.flags.Usage = func() { // Attach usage to this FlagSet (not the global flag.CommandLine) so it runs on parse errors and lists these flags. 64 | logln(w, `usage with go build/run/test: -toolexec="errtrace [options]"`) 65 | p.flags.PrintDefaults() 66 | } 67 | var requiredPkgs string 68 | p.flags.StringVar(&requiredPkgs, "required-packages", "", "comma-separated list of package selectors "+ 69 | "that are expected to import errtrace if they return errors.") 70 | 71 | // Flag parsing stops at the first non-flag argument (no "-"). 72 | if err := p.flags.Parse(args); err != nil { 73 | return errtrace.Wrap(err) 74 | } 75 | 76 | remArgs := p.flags.Args() 77 | if len(remArgs) == 0 { 78 | return errtrace.New("toolexec expected tool arguments") 79 | } 80 | 81 | p.Tool = remArgs[0] 82 | p.ToolArgs = remArgs[1:] 83 | p.RequiredPkgSelectors = strings.Split(requiredPkgs, ",") 84 | return nil 85 | } 86 | 87 | // Options affect the generated code, so use a hash 88 | // of any options for the toolexec version.
89 | func (p *toolExecParams) versionCacheKey() string { 90 | withoutTool := *p 91 | withoutTool.flags = nil 92 | withoutTool.Tool = "" 93 | withoutTool.ToolArgs = nil 94 | 95 | optStr := fmt.Sprintf("%v", withoutTool) 96 | optHash := md5.Sum([]byte(optStr)) 97 | return hex.EncodeToString(optHash[:]) 98 | } 99 | 100 | func (p *toolExecParams) requiredPackage(pkg string) bool { 101 | for _, selector := range p.RequiredPkgSelectors { 102 | if packageSelectorMatch(selector, pkg) { 103 | return true 104 | } 105 | } 106 | return false 107 | } 108 | 109 | func (cmd *mainCmd) toolExecVersion(p toolExecParams) int { 110 | version, err := binaryVersion() 111 | if err != nil { 112 | logf(cmd.Stderr, "errtrace version failed: %v", err) 113 | } 114 | 115 | tool := exec.Command(p.Tool, p.ToolArgs...) 116 | var stdout bytes.Buffer 117 | tool.Stdout = &stdout 118 | tool.Stderr = cmd.Stderr 119 | if err := tool.Run(); err != nil { 120 | if exitErr, ok := err.(*exec.ExitError); ok { 121 | return exitErr.ExitCode() 122 | } 123 | 124 | logf(cmd.Stderr, "tool %v failed: %v", p.Tool, err) 125 | return 1 126 | } 127 | 128 | if _, err := fmt.Fprintf( 129 | cmd.Stdout, 130 | "%s-errtrace-%s%s\n", 131 | strings.TrimSpace(stdout.String()), 132 | version, 133 | p.versionCacheKey(), 134 | ); err != nil { 135 | logf(cmd.Stderr, "failed to write version to stdout: %v", err) 136 | return 1 137 | } 138 | 139 | return 0 140 | } 141 | 142 | func (cmd *mainCmd) toolExecRewrite(pkg string, p toolExecParams) (exitCode int) { 143 | // We only need to modify the arguments for "compile" calls which work with .go files. 144 | if !isCompile(p.Tool) { 145 | return cmd.runOriginal(p) 146 | } 147 | 148 | // We only modify files that import errtrace, so stdlib is never eligible.
149 | if isStdLib(p.ToolArgs) { 150 | return cmd.runOriginal(p) 151 | } 152 | 153 | exitCode, err := cmd.rewriteCompile(pkg, p) 154 | if err != nil { 155 | cmd.log.Print(err) 156 | return 1 157 | } 158 | 159 | return exitCode 160 | } 161 | 162 | func (cmd *mainCmd) rewriteCompile(pkg string, p toolExecParams) (exitCode int, _ error) { 163 | var canRewrite, needRewrite bool 164 | parsed := make(map[string]parsedFile) 165 | for _, arg := range p.ToolArgs { 166 | if !isGoFile(arg) { 167 | continue 168 | } 169 | 170 | contents, err := os.ReadFile(arg) 171 | if err != nil { 172 | return -1, errtrace.Wrap(err) 173 | } 174 | 175 | f, err := cmd.parseFile(arg, contents, rewriteOpts{}) 176 | if err != nil { 177 | return -1, errtrace.Wrap(err) 178 | } 179 | parsed[arg] = f 180 | 181 | // TODO: Support an "unsafe" mode to rewrite packages without errtrace imports. 182 | if f.importsErrtrace { 183 | canRewrite = true 184 | } 185 | if len(f.inserts) > 0 { 186 | needRewrite = true 187 | } 188 | } 189 | 190 | if !needRewrite { 191 | return cmd.runOriginal(p), nil 192 | } 193 | 194 | if !canRewrite { 195 | if p.requiredPackage(pkg) { 196 | logf(cmd.Stderr, "errtrace required package %v missing errtrace import, needs rewrite", pkg) 197 | return 1, nil 198 | } 199 | return cmd.runOriginal(p), nil 200 | } 201 | 202 | // Use a temporary directory per-package that is rewritten. 203 | tempDir, err := os.MkdirTemp("", filepath.Base(pkg)) 204 | if err != nil { 205 | return -1, errtrace.Wrap(err) 206 | } 207 | defer os.RemoveAll(tempDir) //nolint:errcheck // best-effort removal of temp files. 208 | 209 | newArgs := make([]string, 0, len(p.ToolArgs)) 210 | for _, arg := range p.ToolArgs { 211 | f, ok := parsed[arg] 212 | if !ok || len(f.inserts) == 0 { 213 | newArgs = append(newArgs, arg) 214 | continue 215 | } 216 | 217 | // Add a //line directive so the original filepath is used in errors and panics. 
218 | var out bytes.Buffer 219 | _, _ = fmt.Fprintf(&out, "//line %v:1\n", arg) 220 | 221 | if err := cmd.rewriteFile(f, &out); err != nil { 222 | return -1, errtrace.Wrap(err) 223 | } 224 | 225 | // TODO: Handle clashes with the same base name in different directories (E.g., with bazel). 226 | newFile := filepath.Join(tempDir, filepath.Base(arg)) 227 | if err := os.WriteFile(newFile, out.Bytes(), 0o666); err != nil { 228 | return -1, errtrace.Wrap(err) 229 | } 230 | 231 | newArgs = append(newArgs, newFile) 232 | } 233 | 234 | p.ToolArgs = newArgs 235 | return cmd.runOriginal(p), nil 236 | } 237 | 238 | func isCompile(arg string) bool { 239 | if runtime.GOOS == "windows" { 240 | arg = strings.TrimSuffix(arg, ".exe") 241 | } 242 | return strings.HasSuffix(arg, "compile") 243 | } 244 | 245 | func isGoFile(arg string) bool { 246 | return strings.HasSuffix(arg, ".go") 247 | } 248 | 249 | func (cmd *mainCmd) runOriginal(p toolExecParams) (exitCode int) { 250 | tool := exec.Command(p.Tool, p.ToolArgs...) 251 | tool.Stdin = cmd.Stdin 252 | tool.Stdout = cmd.Stdout 253 | tool.Stderr = cmd.Stderr 254 | 255 | if err := tool.Run(); err != nil { 256 | if exitErr, ok := err.(*exec.ExitError); ok { 257 | return exitErr.ExitCode() 258 | } 259 | logf(cmd.Stderr, "tool %v failed: %v", p.Tool, err) 260 | return 1 261 | } 262 | 263 | return 0 264 | } 265 | 266 | // binaryVersion returns a string that uniquely identifies the binary. 267 | // We prefer to use the VCS info embedded in the build if possible 268 | // falling back to the MD5 of the binary. 
269 | func binaryVersion() (string, error) { 270 | sha, ok := readBuildSHA() 271 | if ok { 272 | return sha, nil 273 | } 274 | 275 | exe, err := os.Executable() 276 | if err != nil { 277 | return "", errtrace.Wrap(err) 278 | } 279 | 280 | contents, err := os.ReadFile(exe) 281 | if err != nil { 282 | return "", errtrace.Wrap(err) 283 | } 284 | 285 | binaryHash := md5.Sum(contents) 286 | return hex.EncodeToString(binaryHash[:]), nil 287 | } 288 | 289 | // readBuildSHA returns the VCS SHA, if it's from an unmodified VCS state. 290 | func readBuildSHA() (_ string, ok bool) { 291 | buildInfo, ok := debug.ReadBuildInfo() 292 | if !ok { 293 | return "", false 294 | } 295 | 296 | var sha string 297 | for _, s := range buildInfo.Settings { 298 | switch s.Key { 299 | case "vcs.revision": 300 | sha = s.Value 301 | case "vcs.modified": 302 | if s.Value != "false" { 303 | return "", false 304 | } 305 | } 306 | } 307 | return sha, sha != "" 308 | } 309 | 310 | // isStdLib checks if the current execution is for stdlib. 311 | func isStdLib(args []string) bool { 312 | return slices.Contains(args, "-std") 313 | } 314 | 315 | func packageSelectorMatch(selector, importPath string) bool { 316 | if pkgPrefix, ok := strings.CutSuffix(selector, "..."); ok { 317 | // foo/... should match foo, but not foobar so we want 318 | // the pkgPrefix to contain the /. 
319 | if strings.TrimSuffix(pkgPrefix, "/") == importPath { 320 | return true 321 | } 322 | return strings.HasPrefix(importPath, pkgPrefix) 323 | } 324 | 325 | return selector == importPath 326 | } 327 | -------------------------------------------------------------------------------- /cmd/errtrace/toolexec_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bufio" 5 | "bytes" 6 | "fmt" 7 | "io/fs" 8 | "os" 9 | "os/exec" 10 | "path/filepath" 11 | "regexp" 12 | "runtime" 13 | "slices" 14 | "sort" 15 | "strings" 16 | "testing" 17 | 18 | "braces.dev/errtrace" 19 | "braces.dev/errtrace/internal/diff" 20 | ) 21 | 22 | func TestToolExec(t *testing.T) { 23 | const testProgDir = "./testdata/toolexec-test" 24 | const testProgPkg = "braces.dev/errtrace/cmd/errtrace/testdata/toolexec-test/" 25 | 26 | errTraceCmd := filepath.Join(t.TempDir(), "errtrace") 27 | if runtime.GOOS == "windows" { 28 | errTraceCmd += ".exe" // can't run binaries on Windows otherwise. 29 | } 30 | _, stderr, err := runGo(t, ".", "build", "-o", errTraceCmd, ".") 31 | if err != nil { 32 | t.Fatalf("compile errtrace failed: %v\nstderr: %s", err, stderr) 33 | } 34 | 35 | var wantTraces []string 36 | err = filepath.Walk(testProgDir, func(path string, info fs.FileInfo, err error) error { 37 | if err != nil { 38 | return errtrace.Wrap(err) 39 | } 40 | if info.IsDir() { 41 | return nil 42 | } 43 | 44 | for _, line := range findTraceLines(t, path) { 45 | absPath, err := filepath.Abs(path) 46 | if err != nil { 47 | t.Fatalf("abspath: %v", err) 48 | } 49 | if runtime.GOOS == "windows" { 50 | // On Windows, absPath uses windows path separators, e.g., "c:\foo" 51 | // but the paths reported in traces contain '/'. 
52 | absPath = filepath.ToSlash(absPath) 53 | } 54 | 55 | wantTraces = append(wantTraces, fmt.Sprintf("%v:%v", absPath, line)) 56 | } 57 | return nil 58 | }) 59 | if err != nil { 60 | t.Fatal("Walk failed", err) 61 | } 62 | sort.Strings(wantTraces) 63 | 64 | tests := []struct { 65 | name string 66 | goArgs func(t testing.TB) []string 67 | wantTraces []string 68 | wantErr string 69 | }{ 70 | { 71 | name: "no toolexec", 72 | goArgs: func(t testing.TB) []string { 73 | return []string{"."} 74 | }, 75 | wantTraces: nil, 76 | }, 77 | { 78 | name: "toolexec with pkg", 79 | goArgs: func(t testing.TB) []string { 80 | return []string{"-toolexec", errTraceCmd, "."} 81 | }, 82 | wantTraces: wantTraces, 83 | }, 84 | { 85 | name: "toolexec with files", 86 | goArgs: func(t testing.TB) []string { 87 | files, err := goListFiles([]string{testProgDir}) 88 | if err != nil { 89 | t.Fatalf("list go files in %v: %v", testProgDir, err) 90 | } 91 | 92 | nonTest := slices.DeleteFunc(files, func(file string) bool { 93 | return strings.HasSuffix(file, "_test.go") 94 | }) 95 | 96 | args := []string{"-toolexec", errTraceCmd} 97 | args = append(args, nonTest...) 
98 | return args 99 | }, 100 | wantTraces: wantTraces, 101 | }, 102 | { 103 | name: "toolexec with required-packages ...", 104 | goArgs: func(t testing.TB) []string { 105 | return []string{"-toolexec", errTraceCmd + " -required-packages " + testProgPkg + "...", "."} 106 | }, 107 | wantErr: "p1 missing errtrace import", 108 | }, 109 | { 110 | name: "toolexec with required-packages list", 111 | goArgs: func(t testing.TB) []string { 112 | requiredPackages := strings.Join([]string{ 113 | testProgPkg + "p2", 114 | testProgPkg + "p3", 115 | }, ",") 116 | return []string{"-toolexec", errTraceCmd + " -required-packages " + requiredPackages, "."} 117 | }, 118 | wantTraces: wantTraces, 119 | }, 120 | } 121 | 122 | for _, tt := range tests { 123 | t.Run(tt.name, func(t *testing.T) { 124 | args := tt.goArgs(t) 125 | 126 | verifyCompile := func(t testing.TB, _, stderr string, err error) { 127 | if tt.wantErr != "" { 128 | if err == nil { 129 | t.Fatalf("run expected error %v, but got no error", tt.wantErr) 130 | return 131 | } 132 | if !strings.Contains(stderr, tt.wantErr) { 133 | t.Fatalf("run unexpected error %q to contain %q", stderr, tt.wantErr) 134 | } 135 | return 136 | } 137 | 138 | if err != nil { 139 | t.Fatalf("run failed: %v\n%s", err, stderr) 140 | } 141 | } 142 | 143 | verifyTraces := func(t testing.TB, stdout string) { 144 | gotLines := fileLines(stdout) 145 | sort.Strings(gotLines) 146 | 147 | if d := diff.Diff(tt.wantTraces, gotLines); d != "" { 148 | t.Errorf("diff in traces:\n%s", d) 149 | t.Errorf("go run output:\n%s", stdout) 150 | } 151 | } 152 | 153 | t.Run("go run", func(t *testing.T) { 154 | runArgs := append([]string{"run"}, args...) 155 | stdout, stderr, err := runGo(t, testProgDir, runArgs...) 
156 | verifyCompile(t, stdout, stderr, err) 157 | verifyTraces(t, stdout) 158 | }) 159 | 160 | t.Run("go build", func(t *testing.T) { 161 | outExe := filepath.Join(t.TempDir(), "main") 162 | if runtime.GOOS == "windows" { 163 | outExe += ".exe" 164 | } 165 | 166 | runArgs := append([]string{"build", "-o", outExe}, args...) 167 | stdout, stderr, err := runGo(t, testProgDir, runArgs...) 168 | verifyCompile(t, stdout, stderr, err) 169 | if err != nil { 170 | return 171 | } 172 | 173 | cmd := exec.Command(outExe) 174 | output, err := cmd.Output() 175 | if err != nil { 176 | t.Fatalf("run built binary: %v", err) 177 | } 178 | verifyTraces(t, string(output)) 179 | }) 180 | }) 181 | } 182 | } 183 | 184 | func findTraceLines(t testing.TB, file string) []int { 185 | f, err := os.Open(file) 186 | if err != nil { 187 | t.Fatal(err) 188 | } 189 | defer f.Close() //nolint:errcheck 190 | 191 | var traces []int 192 | scanner := bufio.NewScanner(f) 193 | var lineNum int 194 | for scanner.Scan() { 195 | lineNum++ 196 | line := scanner.Text() 197 | if strings.Contains(line, "// @trace") { 198 | traces = append(traces, lineNum) 199 | } 200 | } 201 | if err := scanner.Err(); err != nil { 202 | t.Fatal(err) 203 | } 204 | 205 | return traces 206 | } 207 | 208 | var fileLineRegex = regexp.MustCompile(`^\s*(.*:[0-9]+)$`) 209 | 210 | func fileLines(out string) []string { 211 | var fileLines []string 212 | for _, line := range strings.Split(out, "\n") { 213 | if fileLineRegex.MatchString(line) { 214 | fileLines = append(fileLines, strings.TrimSpace(line)) 215 | } 216 | } 217 | return fileLines 218 | } 219 | 220 | func runGo(t testing.TB, dir string, args ...string) (stdout, stderr string, _ error) { 221 | var stdoutBuf, stderrBuf bytes.Buffer 222 | cmd := exec.Command("go", args...) 
223 | cmd.Dir = dir 224 | cmd.Stdin = nil 225 | cmd.Stdout = &stdoutBuf 226 | cmd.Stderr = &stderrBuf 227 | err := cmd.Run() 228 | return stdoutBuf.String(), stderrBuf.String(), errtrace.Wrap(err) 229 | } 230 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | # Don't post comments. 2 | comment: false 3 | 4 | ignore: 5 | - 'internal/diff/*.go' # we hit this only if tests fail 6 | -------------------------------------------------------------------------------- /errors.go: -------------------------------------------------------------------------------- 1 | package errtrace 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | 7 | "braces.dev/errtrace/internal/pc" 8 | ) 9 | 10 | // New returns an error with the supplied text. 11 | // 12 | // It's equivalent to [errors.New] followed by [Wrap] to add caller information. 13 | // 14 | //go:noinline due to GetCaller (see [Wrap] for details). 15 | func New(text string) error { 16 | return wrap(errors.New(text), pc.GetCaller()) 17 | } 18 | 19 | // Errorf creates an error message 20 | // according to a format specifier 21 | // and returns the string as a value that satisfies error. 22 | // 23 | // It's equivalent to [fmt.Errorf] followed by [Wrap] to add caller information. 24 | // 25 | //go:noinline due to GetCaller (see [Wrap] for details). 26 | func Errorf(format string, args ...any) error { 27 | return wrap(fmt.Errorf(format, args...), pc.GetCaller()) 28 | } 29 | -------------------------------------------------------------------------------- /errtrace.go: -------------------------------------------------------------------------------- 1 | // Package errtrace provides the ability to track a return trace for errors. 
2 | // This differs from a stack trace in that 3 | // it is not a snapshot of the call stack at the time of the error, 4 | // but rather a trace of the path taken by the error as it was returned 5 | // until it was finally handled. 6 | // 7 | // # Wrapping errors 8 | // 9 | // Use the [Wrap] function at a return site 10 | // to annotate it with the position of the return. 11 | // 12 | // // Before 13 | // if err != nil { 14 | // return err 15 | // } 16 | // 17 | // // After 18 | // if err != nil { 19 | // return errtrace.Wrap(err) 20 | // } 21 | // 22 | // # Formatting return traces 23 | // 24 | // errtrace provides the [Format] and [FormatString] functions 25 | // to obtain the return trace of an error. 26 | // 27 | // errtrace.Format(os.Stderr, err) 28 | // 29 | // See [Format] for details of the output format. 30 | // 31 | // Additionally, errors returned by errtrace will also print a trace 32 | // if formatted with the %+v verb when used with a Printf-style function. 33 | // 34 | // log.Printf("error: %+v", err) 35 | // 36 | // # Unwrapping errors 37 | // 38 | // Use the [UnwrapFrame] function to unwrap a single frame from an error. 39 | // 40 | // for err != nil { 41 | // frame, inner, ok := errtrace.UnwrapFrame(err) 42 | // if !ok { 43 | // break // end of trace 44 | // } 45 | // printFrame(frame) 46 | // err = inner 47 | // } 48 | // 49 | // See the [UnwrapFrame] example test for a more complete example. 50 | // 51 | // # See also 52 | // 53 | // https://github.com/bracesdev/errtrace. 54 | package errtrace 55 | 56 | import ( 57 | "fmt" 58 | "io" 59 | "strings" 60 | ) 61 | 62 | var _arena = newArena[errTrace](1024) 63 | 64 | func wrap(err error, callerPC uintptr) error { 65 | et := _arena.Take() 66 | et.err = err 67 | et.pc = callerPC 68 | return et 69 | } 70 | 71 | // Format writes the return trace for given error to the writer. 
72 | // The output takes a format similar to the following: 73 | // 74 | // <error message> 75 | // 76 | // <function> 77 | // <file>:<line> 78 | // <caller of function> 79 | // <file>:<line> 80 | // [...] 81 | // 82 | // Any error that has a method `TracePC() uintptr` will 83 | // contribute to the trace. 84 | // If the error doesn't have a return trace attached to it, 85 | // only the error message is reported. 86 | // If the error is comprised of multiple errors (e.g. with [errors.Join]), 87 | // the return trace of each error is reported as a tree. 88 | // 89 | // Returns an error if the writer fails. 90 | func Format(w io.Writer, target error) (err error) { 91 | return writeTree(w, buildTraceTree(target)) 92 | } 93 | 94 | // FormatString writes the return trace for err to a string. 95 | // Any error that has a method `TracePC() uintptr` will 96 | // contribute to the trace. 97 | // See [Format] for details of the output format. 98 | func FormatString(target error) string { 99 | var s strings.Builder 100 | _ = Format(&s, target) 101 | return s.String() 102 | } 103 | 104 | type errTrace struct { 105 | err error 106 | pc uintptr 107 | } 108 | 109 | func (e *errTrace) Error() string { 110 | return e.err.Error() 111 | } 112 | 113 | func (e *errTrace) Unwrap() error { 114 | return e.err 115 | } 116 | 117 | func (e *errTrace) Format(s fmt.State, verb rune) { 118 | if verb == 'v' && s.Flag('+') { 119 | _ = Format(s, e) 120 | return 121 | } 122 | 123 | fmt.Fprintf(s, fmt.FormatString(s, verb), e.err) 124 | } 125 | 126 | // TracePC returns the program counter for the location 127 | // in the frame that the error originated with.
128 | // 129 | // The returned PC is intended to be used with 130 | // runtime.CallersFrames or runtime.FuncForPC 131 | // to aid in generating the error return trace 132 | func (e *errTrace) TracePC() uintptr { 133 | return e.pc 134 | } 135 | 136 | // compile time tracePCprovider interface check 137 | var _ interface{ TracePC() uintptr } = &errTrace{} 138 | -------------------------------------------------------------------------------- /errtrace_line_test.go: -------------------------------------------------------------------------------- 1 | package errtrace_test 2 | 3 | import ( 4 | _ "embed" 5 | "errors" 6 | "fmt" 7 | "go/scanner" 8 | "go/token" 9 | "strconv" 10 | "strings" 11 | "testing" 12 | 13 | "braces.dev/errtrace" 14 | ) 15 | 16 | //go:embed errtrace_line_test.go 17 | var errtraceLineTestFile string 18 | 19 | // Note: The following tests verify the line, and assume that the 20 | // test names are unique, and that they are the only tests in this file. 21 | func TestWrap_Line(t *testing.T) { 22 | failed := errors.New("failed") 23 | 24 | tests := []struct { 25 | name string 26 | f func() error 27 | }{ 28 | { 29 | name: "return Wrap", // @group 30 | f: func() error { 31 | return errtrace.Wrap(failed) // @trace 32 | }, 33 | }, 34 | { 35 | name: "Wrap to intermediate and return", // @group 36 | f: func() (retErr error) { 37 | wrapped := errtrace.Wrap(failed) // @trace 38 | return wrapped 39 | }, 40 | }, 41 | { 42 | name: "Decorate error after Wrap", // @group 43 | f: func() (retErr error) { 44 | wrapped := errtrace.Wrap(failed) // @trace 45 | return fmt.Errorf("got err: %w", wrapped) 46 | }, 47 | }, 48 | { 49 | name: "defer updates errTrace", // @group 50 | f: func() (retErr error) { 51 | defer func() { 52 | retErr = errtrace.Wrap(retErr) // @trace 53 | }() 54 | 55 | return failed 56 | }, 57 | }, 58 | 59 | // Test error creation helpers. 
60 | { 61 | name: "New", // @group 62 | f: func() (retErr error) { 63 | return errtrace.New("test") // @trace 64 | }, 65 | }, 66 | { 67 | name: "Errorf with no error args", // @group 68 | f: func() (retErr error) { 69 | return errtrace.Errorf("test %v", 1) // @trace 70 | }, 71 | }, 72 | { 73 | name: "Errorf with wrapped error arg", // @group 74 | f: func() (retErr error) { 75 | err := errtrace.New("test1") // @trace 76 | return errtrace.Errorf("test2: %w", err) // @trace 77 | }, 78 | }, 79 | 80 | // Sanity testing for WrapN functions. 81 | { 82 | name: "Test Wrap2", // @group 83 | f: func() (retErr error) { 84 | nested := func() (int, error) { 85 | return errtrace.Wrap2(returnErr2()) // @trace 86 | } 87 | 88 | _, err := nested() 89 | return err 90 | }, 91 | }, 92 | { 93 | name: "Test Wrap3", // @group 94 | f: func() (retErr error) { 95 | nested := func() (int, int, error) { 96 | return errtrace.Wrap3(returnErr3()) // @trace 97 | } 98 | 99 | _, _, err := nested() 100 | return err 101 | }, 102 | }, 103 | { 104 | name: "Test Wrap4", // @group 105 | f: func() (retErr error) { 106 | nested := func() (int, int, int, error) { 107 | return errtrace.Wrap4(returnErr4()) // @trace 108 | } 109 | 110 | _, _, _, err := nested() 111 | return err 112 | }, 113 | }, 114 | { 115 | name: "Test Wrap5", // @group 116 | f: func() (retErr error) { 117 | nested := func() (int, int, int, int, error) { 118 | return errtrace.Wrap5(returnErr5()) // @trace 119 | } 120 | 121 | _, _, _, _, err := nested() 122 | return err 123 | }, 124 | }, 125 | { 126 | name: "Test Wrap6", // @group 127 | f: func() (retErr error) { 128 | nested := func() (int, int, int, int, int, error) { 129 | return errtrace.Wrap6(returnErr6()) // @trace 130 | } 131 | 132 | _, _, _, _, _, err := nested() 133 | return err 134 | }, 135 | }, 136 | 137 | // Multi-errors. 
138 | { 139 | name: "MutiError", // @group 140 | f: func() error { 141 | return errors.Join( 142 | errtrace.New("foo"), // @trace 143 | errtrace.New("bar"), // @trace 144 | ) 145 | }, 146 | }, 147 | { 148 | name: "MultiError/Wrapped", // @group 149 | f: func() error { 150 | err := errors.Join( 151 | errtrace.New("foo"), // @trace 152 | errtrace.New("bar"), // @trace 153 | ) 154 | return errtrace.Wrap(err) // @trace 155 | }, 156 | }, 157 | } 158 | 159 | testMarkers, err := parseMarkers(errtraceLineTestFile) 160 | if err != nil { 161 | t.Fatal(err) 162 | } 163 | t.Logf("parsed markers: %v", testMarkers) 164 | 165 | for _, tt := range tests { 166 | t.Run(tt.name, func(t *testing.T) { 167 | markers := testMarkers[tt.name] 168 | if len(markers) == 0 { 169 | t.Fatalf("didn't find any markers for test") 170 | } 171 | 172 | gotErr := tt.f() 173 | got := errtrace.FormatString(gotErr) 174 | 175 | for _, wantLine := range markers { 176 | wantFileLine := fmt.Sprintf("errtrace_line_test.go:%v", wantLine) 177 | if !strings.Contains(got, wantFileLine) { 178 | t.Errorf("formatted output is missing file:line %q in:\n%s", wantFileLine, got) 179 | } 180 | } 181 | }) 182 | } 183 | } 184 | 185 | func returnErr2() (int, error) { return 1, errors.New("test") } 186 | func returnErr3() (int, int, error) { return 1, 2, errors.New("test") } 187 | func returnErr4() (int, int, int, error) { return 1, 2, 3, errors.New("test") } 188 | func returnErr5() (int, int, int, int, error) { return 1, 2, 3, 4, errors.New("test") } 189 | func returnErr6() (int, int, int, int, int, error) { return 1, 2, 3, 4, 5, errors.New("test") } 190 | 191 | // parseMarkers parses the source file and returns a map 192 | // from marker group name to line numbers in that group. 193 | // 194 | // Marker groups are identified by a '@group' comment 195 | // immediately following a string literal -- ignoring operators. 
196 | // For example:
197 | //
198 | //	{
199 | //		name: "foo", // @group
200 | //		// Note that the "," is ignored as it's an operator.
201 | //	}
202 | //
203 | // Markers in the group are identified by a '@trace' comment.
204 | // For example:
205 | //
206 | //	{
207 | //		name: "foo", // @group
208 | //		f: func() error {
209 | //			return errtrace.Wrap(failed) // @trace
210 | //		},
211 | //	}
212 | //
213 | // A group ends when a new group starts or the end of the file is reached.
214 | func parseMarkers(src string) (map[string][]int, error) {
215 | 	// We don't need to parse the Go AST.
216 | 	// Just lexical analysis is enough.
217 | 	fset := token.NewFileSet()
218 | 	file := fset.AddFile("errtrace_line_test.go", fset.Base(), len(src))
219 |
220 | 	var (
221 | 		errs []error
222 | 		scan scanner.Scanner
223 | 	)
224 | 	scan.Init(
225 | 		file,
226 | 		[]byte(src),
227 | 		func(pos token.Position, msg string) {
228 | 			// This function is called for each error encountered
229 | 			// while scanning.
230 | 			errs = append(errs, fmt.Errorf("%v:%v", pos, msg))
231 | 		},
232 | 		scanner.ScanComments,
233 | 	)
234 |
235 | 	errf := func(pos token.Pos, format string, args ...any) {
236 | 		msg := fmt.Sprintf(format, args...)
237 | 		errs = append(errs, fmt.Errorf("%v:%v", file.Position(pos), msg))
238 | 	}
239 |
240 | 	markers := make(map[string][]int)
241 | 	var (
242 | 		currentMarker     string
243 | 		lastStringLiteral string
244 | 	)
245 | 	for {
246 | 		pos, tok, lit := scan.Scan()
247 |
248 | 		switch tok {
249 | 		case token.EOF:
250 | 			return markers, errors.Join(errs...)
251 |
252 | 		case token.STRING:
253 | 			s, err := strconv.Unquote(lit)
254 | 			if err != nil {
255 | 				errf(pos, "bad string literal: %v", err)
256 | 				continue
257 | 			}
258 | 			lastStringLiteral = s
259 |
260 | 		case token.COMMENT:
261 | 			switch lit {
262 | 			case "// @group":
263 | 				if lastStringLiteral == "" {
264 | 					errf(pos, "expected string literal before @group")
265 | 					continue
266 | 				}
267 |
268 | 				currentMarker = lastStringLiteral
269 |
270 | 			case "// @trace":
271 | 				if currentMarker == "" {
272 | 					errf(pos, "expected @group before @trace")
273 | 					continue
274 | 				}
275 |
276 | 				markers[currentMarker] = append(markers[currentMarker], file.Line(pos))
277 | 			}
278 |
279 | 		default:
280 | 			// For all other non-operator tokens, reset the last string literal.
281 | 			if !tok.IsOperator() {
282 | 				lastStringLiteral = ""
283 | 			}
284 | 		}
285 | 	}
286 | }
287 |
--------------------------------------------------------------------------------
/errtrace_test.go:
--------------------------------------------------------------------------------
1 | package errtrace_test
2 |
3 | import (
4 | 	"errors"
5 | 	"fmt"
6 | 	"reflect"
7 | 	"runtime"
8 | 	"strings"
9 | 	"testing"
10 |
11 | 	"braces.dev/errtrace"
12 | )
13 |
14 | func TestWrapNil(t *testing.T) {
15 | 	if err := errtrace.Wrap(nil); err != nil {
16 | 		t.Errorf("Wrap(): want nil, got %v", err)
17 | 	}
18 | }
19 |
20 | func TestWrappedError(t *testing.T) {
21 | 	orig := errors.New("foo")
22 | 	err := errtrace.Wrap(orig)
23 |
24 | 	if want, got := "foo", err.Error(); want != got {
25 | 		t.Errorf("Error(): want %q, got %q", want, got)
26 | 	}
27 | }
28 |
29 | func TestWrappedErrorIs(t *testing.T) {
30 | 	orig := errors.New("foo")
31 | 	err := errtrace.Wrap(orig)
32 |
33 | 	if !errors.Is(err, orig) {
34 | 		t.Errorf("Is(): want true, got false")
35 | 	}
36 | }
37 |
38 | type myError struct{ x int }
39 |
40 | func (m *myError) Error() string {
41 | 	return "great sadness"
42 | }
43 |
44 | func TestWrappedErrorAs(t *testing.T) {
45 | 	err := errtrace.Wrap(&myError{x: 42})
46 | 	var m *myError
47 | 	if !errors.As(err, &m) {
48 | 		t.Fatal("As(): want true, got false") // m is nil below otherwise
49 | 	}
50 |
51 | 	if want, got := 42, m.x; want != got {
52 | 		t.Errorf("As(): want %d, got %d", want, got)
53 | 	}
54 | }
55 |
56 | func TestFormatTrace(t *testing.T) {
57 | 	orig := errors.New("foo")
58 |
59 | 	f := func() error {
60 | 		return errtrace.Wrap(orig)
61 | 	}
62 | 	g := func() error {
63 | 		return errtrace.Wrap(f())
64 | 	}
65 |
66 | 	var h func(int) error
67 | 	h = func(n int) error {
68 | 		if n > 0 {
69 | 			return errtrace.Wrap(h(n - 1))
70 | 		}
71 | 		return errtrace.Wrap(g())
72 | 	}
73 |
74 | 	err := h(3)
75 |
76 | 	trace := errtrace.FormatString(err)
77 |
78 | 	// Line numbers change,
79 | 	// so verify function names and that the file name is correct.
80 | 	if want := "errtrace_test.go:"; !strings.Contains(trace, want) {
81 | 		t.Errorf("FormatString(): want trace to contain %q, got:\n%s", want, trace)
82 | 	}
83 |
84 | 	funcName := func(fn any) string {
85 | 		return runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
86 | 	}
87 |
88 | 	if fName := funcName(f) + "\n"; !strings.Contains(trace, fName) {
89 | 		t.Errorf("FormatString(): want trace to contain %q, got:\n%s", fName, trace)
90 | 	}
91 |
92 | 	if gName := funcName(g) + "\n"; !strings.Contains(trace, gName) {
93 | 		t.Errorf("FormatString(): want trace to contain %q, got:\n%s", gName, trace)
94 | 	}
95 |
96 | 	hName := funcName(h) + "\n"
97 | 	if want, got := 4, strings.Count(trace, hName); want != got {
98 | 		t.Errorf("FormatString(): want trace to contain %d instances of %q, got %d\n%s", want, hName, got, trace)
99 | 	}
100 | }
101 |
102 | func TestFormatVerbs(t *testing.T) {
103 | 	err := errors.New("error")
104 | 	wrapped := errtrace.Wrap(err)
105 |
106 | 	tests := []struct {
107 | 		name string
108 | 		fmt  string
109 | 		want string
110 | 	}{
111 | 		{
112 | 			name: "verb s",
113 | 			fmt:  "%s",
114 | 			want: "error",
115 | 		},
116 | 		{
117 | 			name: "verb v",
118 | 			fmt:  "%v",
119 | 			want: "error",
120 | 		},
121 | 		{
122 | 			name: "verb q",
123 | 			fmt:  "%q",
124 | 			want: `"error"`,
125 | 		},
126 | 		{
127 | 			name: "padded string",
128 | 			fmt:  "%10s",
129 | 			want: "     error", // width 10: five spaces + "error"
130 | 		},
131 | 	}
132 |
133 | 	for _, tt := range tests {
134 | 		t.Run(tt.name, func(t *testing.T) {
135 | 			if want, got := tt.want, fmt.Sprintf(tt.fmt, err); want != got {
136 | 				t.Errorf("unwrapped: want %q, got %q", want, got)
137 | 			}
138 |
139 | 			if want, got := tt.want, fmt.Sprintf(tt.fmt, wrapped); want != got {
140 | 				t.Errorf("wrapped: want %q, got %q", want, got)
141 | 			}
142 | 		})
143 | 	}
144 | }
145 |
146 | func TestFormat(t *testing.T) {
147 | 	e1 := errtrace.New("e1")
148 | 	e2 := errtrace.Errorf("e2: %w", e1)
149 | 	e3 := errtrace.Wrap(e2)
150 |
151 | 	tests := []struct {
152 | 		name       string
153 | 		err        error
154 | 		want       string
155 | 		wantTraces int
156 | 	}{
157 | 		{
158 | 			name:       "new error",
159 | 			err:        e1,
160 | 			want:       "e1",
161 | 			wantTraces: 1,
162 | 		},
163 | 		{
164 | 			name:       "wrapped with Errorf",
165 | 			err:        e2,
166 | 			want:       "e2: e1",
167 | 			wantTraces: 2,
168 | 		},
169 | 		{
170 | 			name:       "wrap after Errorf",
171 | 			err:        e3,
172 | 			want:       "e2: e1",
173 | 			wantTraces: 3,
174 | 		},
175 | 	}
176 |
177 | 	for _, tt := range tests {
178 | 		t.Run(tt.name, func(t *testing.T) {
179 | 			if want, got := tt.want, fmt.Sprintf("%s", tt.err); want != got {
180 | 				t.Errorf("message: want %q, got: %q", want, got)
181 | 			}
182 |
183 | 			withTrace := fmt.Sprintf("%+v", tt.err)
184 | 			if !strings.HasPrefix(withTrace, tt.want) {
185 | 				t.Errorf("expected error message %q in trace:\n%s", tt.want, withTrace)
186 | 			}
187 | 			if want, got := tt.wantTraces, strings.Count(withTrace, "errtrace_test.TestFormat"); want != got {
188 | 				t.Errorf("expected traces %v, got %v in:\n%s", want, got, withTrace)
189 | 			}
190 | 		})
191 | 	}
192 | }
193 |
194 | func BenchmarkWrap(b *testing.B) {
195 | 	err := errors.New("foo")
196 | 	b.RunParallel(func(pb *testing.PB) {
197 | 		for pb.Next() {
198 | 			_ = errtrace.Wrap(err)
199 | 		}
200 | 	})
201 | }
202 |
203 | func BenchmarkFmtErrorf(b *testing.B) {
204 | 	err := errors.New("foo")
205 | 	b.RunParallel(func(pb *testing.PB) {
206 | 		for pb.Next() {
207 | 			_ = fmt.Errorf("bar: %w", err)
208 | 		}
209 | 	})
210 | }
211 |
--------------------------------------------------------------------------------
/example_errhelper_test.go:
--------------------------------------------------------------------------------
1 | package errtrace_test
2 |
3 | import (
4 | 	"errors"
5 | 	"fmt"
6 |
7 | 	"braces.dev/errtrace"
8 | 	"braces.dev/errtrace/internal/tracetest"
9 | )
10 |
11 | func f1Wrap() error {
12 | 	return wrap(f2Wrap(), "u=1")
13 | }
14 |
15 | func f2Wrap() error {
16 | 	return wrap(f3Wrap(), "method=order")
17 | }
18 |
19 | func f3Wrap() error {
20 | 	return wrap(errors.New("failed"), "foo")
21 | }
22 |
23 | // wrap appends fields to err's message and wraps the result with errtrace.
24 | // It uses errtrace.GetCaller so the recorded frame is wrap's caller
25 | // (f1Wrap, f2Wrap, ...) rather than wrap itself.
26 | func wrap(err error, fields ...string) error {
27 | 	return errtrace.GetCaller().
28 | 		Wrap(fmt.Errorf("%w %v", err, fields))
29 | }
30 |
31 | func Example_getCaller() {
32 | 	got := errtrace.FormatString(f1Wrap())
33 |
34 | 	// make trace agnostic to environment-specific location
35 | 	// and less sensitive to line number changes.
36 | 	fmt.Println(tracetest.MustClean(got))
37 |
38 | 	// Output:
39 | 	//failed [foo] [method=order] [u=1]
40 | 	//
41 | 	//braces.dev/errtrace_test.f3Wrap
42 | 	// /path/to/errtrace/example_errhelper_test.go:3
43 | 	//braces.dev/errtrace_test.f2Wrap
44 | 	// /path/to/errtrace/example_errhelper_test.go:2
45 | 	//braces.dev/errtrace_test.f1Wrap
46 | 	// /path/to/errtrace/example_errhelper_test.go:1
47 | }
48 |
--------------------------------------------------------------------------------
/example_http_test.go:
--------------------------------------------------------------------------------
1 | package errtrace_test
2 |
3 | import (
4 | 	"fmt"
5 | 	"io"
6 | 	"net"
7 | 	"net/http"
8 | 	"strings"
9 |
10 | 	"braces.dev/errtrace"
11 | 	"braces.dev/errtrace/internal/tracetest"
12 | )
13 |
14 | func Example_http() {
15 | 	tp := &http.Transport{Dial: rateLimitDialer}
16 | 	client := &http.Client{Transport: tp}
17 | 	ps := &PackageStore{
18 | 		client: client,
19 | 	}
20 |
21 | 	_, err := ps.Get()
22 | 	fmt.Printf("Error fetching packages: %+v\n", tracetest.MustClean(errtrace.FormatString(err)))
23 | 	// Output:
24 | 	// Error fetching packages: Get "http://example.com/packages.index": connect rate limited
25 | 	//
26 | 	// braces.dev/errtrace_test.rateLimitDialer
27 | 	// /path/to/errtrace/example_http_test.go:3
28 | 	// braces.dev/errtrace_test.(*PackageStore).updateIndex
29 | 	// /path/to/errtrace/example_http_test.go:2
30 | 	// braces.dev/errtrace_test.(*PackageStore).Get
31 | 	// /path/to/errtrace/example_http_test.go:1
32 | }
33 |
34 | // PackageStore fetches the package index over HTTP and caches it
35 | // after the first successful fetch.
36 | type PackageStore struct {
37 | 	client         *http.Client
38 | 	packagesCached []string
39 | }
40 |
41 | // Get returns the cached package list, fetching it on first use.
42 | func (ps *PackageStore) Get() ([]string, error) {
43 | 	if ps.packagesCached != nil {
44 | 		return ps.packagesCached, nil
45 | 	}
46 |
47 | 	packages, err := ps.updateIndex()
48 | 	if err != nil {
49 | 		return nil, errtrace.Wrap(err)
50 | 	}
51 |
52 | 	ps.packagesCached = packages
53 | 	return packages, nil
54 | }
55 |
56 | // updateIndex downloads the package index and splits it on commas.
57 | func (ps *PackageStore) updateIndex() ([]string, error) {
58 | 	resp, err := ps.client.Get("http://example.com/packages.index")
59 | 	if err != nil {
60 | 		return nil, errtrace.Wrap(err)
61 | 	}
62 | 	// Always close the body so the transport can reuse the connection.
63 | 	defer resp.Body.Close()
64 |
65 | 	contents, err := io.ReadAll(resp.Body)
66 | 	if err != nil {
67 | 		return nil, errtrace.Wrap(err)
68 | 	}
69 |
70 | 	return strings.Split(string(contents), ","), nil
71 | }
72 |
73 | // rateLimitDialer is a stand-in dialer that always fails, so the example
74 | // deterministically produces an error trace.
75 | func rateLimitDialer(network, addr string) (net.Conn, error) {
76 | 	// for testing, always return an error.
77 | 	return nil, errtrace.New("connect rate limited")
78 | }
79 |
--------------------------------------------------------------------------------
/example_trace_test.go:
--------------------------------------------------------------------------------
1 | package errtrace_test
2 |
3 | import (
4 | 	"errors"
5 | 	"fmt"
6 | 	"runtime"
7 | 	"strings"
8 |
9 | 	"braces.dev/errtrace"
10 | 	"braces.dev/errtrace/internal/tracetest"
11 | )
12 |
13 | func f1() error {
14 | 	return errtrace.Wrap(f2())
15 | }
16 |
17 | func f2() error {
18 | 	return errtrace.Wrap(f3())
19 | }
20 |
21 | func f3() error {
22 | 	return errtrace.New("failed")
23 | }
24 |
25 | func Example_trace() {
26 | 	got := errtrace.FormatString(f1())
27 |
28 | 	// make trace agnostic to environment-specific location
29 | 	// and less sensitive to line number changes.
30 | 	fmt.Println(tracetest.MustClean(got))
31 |
32 | 	// Output:
33 | 	//failed
34 | 	//
35 | 	//braces.dev/errtrace_test.f3
36 | 	// /path/to/errtrace/example_trace_test.go:3
37 | 	//braces.dev/errtrace_test.f2
38 | 	// /path/to/errtrace/example_trace_test.go:2
39 | 	//braces.dev/errtrace_test.f1
40 | 	// /path/to/errtrace/example_trace_test.go:1
41 | }
42 |
43 | func f4() error {
44 | 	return errtrace.Wrap(fmt.Errorf("wrapped: %w", f1()))
45 | }
46 |
47 | func ExampleUnwrapFrame() {
48 | 	var frames []runtime.Frame
49 | 	current := f4()
50 | 	for current != nil {
51 | 		frame, inner, ok := errtrace.UnwrapFrame(current)
52 | 		if !ok {
53 | 			// If the error is not wrapped with errtrace,
54 | 			// unwrap it directly with errors.Unwrap.
55 | 			// Note that this example does not handle multi-errors,
56 | 			// for example those returned by errors.Join.
57 | 			// To handle those, this loop would need to also check
58 | 			// for the 'Unwrap() []error' method on the error.
59 | 			current = errors.Unwrap(current)
60 | 			continue
61 | 		}
62 | 		frames = append(frames, frame)
63 | 		current = inner
64 | 	}
65 |
66 | 	var trace strings.Builder
67 | 	for _, frame := range frames {
68 | 		fmt.Fprintf(&trace, "%s\n\t%s:%d\n", frame.Function, frame.File, frame.Line)
69 | 	}
70 | 	fmt.Println(tracetest.MustClean(trace.String()))
71 |
72 | 	// Output:
73 | 	//
74 | 	//braces.dev/errtrace_test.f4
75 | 	// /path/to/errtrace/example_trace_test.go:4
76 | 	//braces.dev/errtrace_test.f1
77 | 	// /path/to/errtrace/example_trace_test.go:1
78 | 	//braces.dev/errtrace_test.f2
79 | 	// /path/to/errtrace/example_trace_test.go:2
80 | 	//braces.dev/errtrace_test.f3
81 | 	// /path/to/errtrace/example_trace_test.go:3
82 | }
83 |
--------------------------------------------------------------------------------
/example_tree_test.go:
--------------------------------------------------------------------------------
1 | package errtrace_test
2 |
3 | import (
4 | 	"errors"
5 | 	"fmt"
6 | 	"strings"
7 |
8 | 	"braces.dev/errtrace"
9 | 	"braces.dev/errtrace/internal/tracetest"
10 | )
11 |
12 | func normalErr(i int) error {
13 | 	return fmt.Errorf("std err %v", i)
14 | }
15 |
16 | func wrapNormalErr(i int) error {
17 | 	return errtrace.Wrap(normalErr(i))
18 | }
19 |
20 | func nestedErrorList(i int) error {
21 | 	return errors.Join(
22 | 		normalErr(i),
23 | 		wrapNormalErr(i+1),
24 | 	)
25 | }
26 |
27 | func Example_tree() {
28 | 	errs := errtrace.Wrap(errors.Join(
29 | 		normalErr(1),
30 | 		wrapNormalErr(2),
31 | 		nestedErrorList(3),
32 | 	))
33 | 	got := errtrace.FormatString(errs)
34 |
35 | 	// make trace agnostic to environment-specific location
36 | 	// and less sensitive to line number changes.
37 | 	fmt.Println(trimTrailingSpaces(tracetest.MustClean(got)))
38 |
39 | 	// Output:
40 | 	// +- std err 1
41 | 	// |
42 | 	// +- std err 2
43 | 	// |
44 | 	// | braces.dev/errtrace_test.wrapNormalErr
45 | 	// | /path/to/errtrace/example_tree_test.go:1
46 | 	// |
47 | 	// | +- std err 3
48 | 	// | |
49 | 	// | +- std err 4
50 | 	// | |
51 | 	// | | braces.dev/errtrace_test.wrapNormalErr
52 | 	// | | /path/to/errtrace/example_tree_test.go:1
53 | 	// | |
54 | 	// +- std err 3
55 | 	// | std err 4
56 | 	// |
57 | 	// std err 1
58 | 	// std err 2
59 | 	// std err 3
60 | 	// std err 4
61 | 	//
62 | 	// braces.dev/errtrace_test.Example_tree
63 | 	// /path/to/errtrace/example_tree_test.go:2
64 | }
65 |
66 | // trimTrailingSpaces strips trailing spaces and tabs from every line of s.
67 | func trimTrailingSpaces(s string) string {
68 | 	lines := strings.Split(s, "\n")
69 | 	for i := range lines {
70 | 		lines[i] = strings.TrimRight(lines[i], " \t")
71 | 	}
72 | 	return strings.Join(lines, "\n")
73 | }
74 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module braces.dev/errtrace
2 |
3 | go 1.21
4 |
--------------------------------------------------------------------------------
/internal/diff/diff.go:
--------------------------------------------------------------------------------
1 | // Package diff provides utilities for comparing strings and slices
2 | // to produce a readable diff output for tests.
3 | package diff
4 |
5 | import (
6 | 	"fmt"
7 | 	"strconv"
8 | 	"strings"
9 | )
10 |
11 | // Lines returns a diff of two strings, line-by-line.
12 | func Lines(want, got string) string {
13 | 	return Diff(strings.Split(want, "\n"), strings.Split(got, "\n"))
14 | }
15 |
16 | // Diff is a silly diff implementation
17 | // that compares the provided slices and returns a diff of them.
18 | func Diff[T comparable](want, got []T) string { 19 | // We want to pad diff output with line number in the format: 20 | // 21 | // - 1 | line 1 22 | // + 2 | line 2 23 | // 24 | // To do that, we need to know the longest line number. 25 | longest := max(len(want), len(got)) 26 | lineFormat := fmt.Sprintf("%%s %%-%dd | %%v\n", len(strconv.Itoa(longest))) // e.g. "%-2d | %s%v\n" 27 | const ( 28 | minus = "-" 29 | plus = "+" 30 | equal = " " 31 | ) 32 | 33 | var buf strings.Builder 34 | writeLine := func(idx int, kind string, v T) { 35 | fmt.Fprintf(&buf, lineFormat, kind, idx+1, v) 36 | } 37 | 38 | var lastEqs []T 39 | for i := 0; i < len(want) || i < len(got); i++ { 40 | if i < len(want) && i < len(got) && want[i] == got[i] { 41 | lastEqs = append(lastEqs, want[i]) 42 | continue 43 | } 44 | 45 | // If there are any equal lines before this, show up to 3 of them. 46 | if len(lastEqs) > 0 { 47 | start := max(len(lastEqs)-3, 0) 48 | for j, eq := range lastEqs[start:] { 49 | writeLine(i-3+j, equal, eq) 50 | } 51 | } 52 | 53 | if i < len(want) { 54 | writeLine(i, minus, want[i]) 55 | } 56 | if i < len(got) { 57 | writeLine(i, plus, got[i]) 58 | } 59 | 60 | lastEqs = nil 61 | } 62 | 63 | return buf.String() 64 | } 65 | -------------------------------------------------------------------------------- /internal/pc/pc_amd64.s: -------------------------------------------------------------------------------- 1 | //go:build !safe && amd64 2 | 3 | #include "textflag.h" 4 | 5 | // func GetCaller() uintptr 6 | TEXT ·GetCaller(SB),NOSPLIT|NOFRAME,$0-8 7 | // BP is the hardware register frame pointer, as used in: 8 | // https://github.com/golang/go/blob/go1.21.4/src/runtime/asm_amd64.s#L2091-L2093 9 | // The return address sits one word above, hence we evaluate `*(BP+8)`. 
10 | MOVQ 8(BP), AX 11 | MOVQ AX, ret+0(FP) 12 | RET 13 | 14 | // func GetCallerSkip1() uintptr 15 | TEXT ·GetCallerSkip1(SB),NOSPLIT|NOFRAME,$0-8 16 | // BP contains the frame pointer, dereference it to skip a frame. 17 | MOVQ (BP), AX 18 | MOVQ 8(AX), AX 19 | MOVQ AX, ret+0(FP) 20 | RET 21 | -------------------------------------------------------------------------------- /internal/pc/pc_arm64.s: -------------------------------------------------------------------------------- 1 | //go:build !safe && arm64 2 | 3 | #include "textflag.h" 4 | 5 | // func GetCaller() uintptr 6 | TEXT ·GetCaller(SB),NOSPLIT|NOFRAME,$0-8 7 | // R29 is the frame pointer, documented in https://pkg.go.dev/cmd/internal/obj/arm64 8 | // and used in https://github.com/golang/go/blob/go1.21.4/src/runtime/asm_arm64.s#L1571 9 | // The return address sits one word above, hence we evaluate `*(R29+8)`. 10 | MOVD 8(R29), R20 11 | MOVD R20, ret+0(FP) 12 | RET 13 | 14 | 15 | // func GetCallerSkip1() uintptr 16 | TEXT ·GetCallerSkip1(SB),NOSPLIT|NOFRAME,$0-8 17 | // R29 contains the frame pointer, dereference it to skip a frame. 18 | MOVD (R29), R20 19 | MOVD 8(R20), R20 20 | MOVD R20, ret+0(FP) 21 | RET 22 | -------------------------------------------------------------------------------- /internal/pc/pc_asm.go: -------------------------------------------------------------------------------- 1 | //go:build !safe && (arm64 || amd64) 2 | 3 | // Package pc provides access to the program counter 4 | // to determine the caller of a function. 5 | package pc 6 | 7 | // GetCaller returns the program counter of the caller's caller. 8 | func GetCaller() uintptr 9 | 10 | // GetCallerSkip1 is similar to GetCaller, but skips an additional caller. 
11 | func GetCallerSkip1() uintptr
12 |
--------------------------------------------------------------------------------
/internal/pc/pc_safe.go:
--------------------------------------------------------------------------------
1 | //go:build safe || !(amd64 || arm64)
2 |
3 | // This file is the portable fallback, used when built with the "safe" tag
4 | // or on architectures without an assembly implementation: it asks
5 | // runtime.Callers for the frame instead of reading frame pointers.
6 |
7 | package pc
8 |
9 | import "runtime"
10 |
11 | // GetCaller returns the program counter of the caller's caller.
12 | func GetCaller() uintptr {
13 | 	return getCaller(0)
14 | }
15 |
16 | // GetCallerSkip1 is similar to GetCaller, but skips an additional caller.
17 | func GetCallerSkip1() uintptr {
18 | 	return getCaller(1)
19 | }
20 |
21 | // getCaller returns the PC in the function that called the errtrace entry
22 | // point (errtrace.Wrap or errtrace.GetCaller), after skipping `skip` further
23 | // frames. It returns 0 if the stack is not that deep.
24 | //
25 | // baseSkip mirrors the fixed call chain between runtime.Callers and that
26 | // function, so adding or removing a function in the chain requires updating it.
27 | func getCaller(skip int) uintptr {
28 | 	const baseSkip = 1 + // runtime.Callers
29 | 		1 + // getCaller
30 | 		1 + // GetCaller or GetCallerSkip1
31 | 		1 // errtrace.Wrap, or errtrace.GetCaller
32 |
33 | 	var callers [1]uintptr
34 | 	n := runtime.Callers(baseSkip+skip, callers[:]) // record only the single frame we need
35 | 	if n == 0 {
36 | 		return 0
37 | 	}
38 | 	return callers[0]
39 | }
40 |
--------------------------------------------------------------------------------
/internal/pc/pc_test.go:
--------------------------------------------------------------------------------
1 | package pc
2 |
3 | import (
4 | 	"errors"
5 | 	"testing"
6 | )
7 |
8 | // wrap mimics the shape of errtrace.Wrap: it takes an error and calls
9 | // GetCaller, which reports the PC in wrap's caller.
10 | //
11 | //go:noinline
12 | func wrap(err error) uintptr {
13 | 	return GetCaller()
14 | }
15 |
16 | // BenchmarkGetCaller measures GetCaller and checks that every call from
17 | // the same call site yields the same non-zero PC.
18 | func BenchmarkGetCaller(b *testing.B) {
19 | 	err := errors.New("test")
20 |
21 | 	var last uintptr
22 | 	for i := 0; i < b.N; i++ {
23 | 		cur := wrap(err)
24 | 		if cur == 0 {
25 | 			panic("invalid PC")
26 | 		}
27 | 		if last != 0 && cur != last {
28 | 			panic("inconsistent results")
29 | 		}
30 | 		last = cur
31 | 	}
32 | }
33 |
--------------------------------------------------------------------------------
/internal/tracetest/clean.go:
--------------------------------------------------------------------------------
1 | // Package tracetest provides utilities for errtrace
2 | // to test error trace output conveniently.
3 | package tracetest
4 |
5 | import (
6 | 	"fmt"
7 | 	"path/filepath"
8 | 	"regexp"
9 | 	"runtime"
10 | 	"slices"
11 | 	"strconv"
12 | 	"strings"
13 | )
14 |
15 | const _fixedDir = "/path/to/errtrace"
16 |
17 | // _fileLineMatcher matches file:line where file starts with _fixedDir.
18 | // Capture groups:
19 | //
20 | //  1. file path
21 | //  2. line number
22 | var _fileLineMatcher = regexp.MustCompile("(" + regexp.QuoteMeta(_fixedDir) + `[^:]+):(\d+)`)
23 |
24 | // MustClean makes traces more deterministic for tests by:
25 | //
26 | //   - replacing the environment-specific path to errtrace
27 | //     with the fixed path /path/to/errtrace
28 | //   - replacing line numbers with the lowest values
29 | //     that maintain relative ordering within the file
30 | //
31 | // Note that line numbers are replaced with increasing values starting at 1,
32 | // with earlier positions in the file getting lower numbers.
33 | // The relative ordering of lines within a file is maintained.
34 | //
35 | // It panics if a matched line number cannot be parsed.
36 | func MustClean(trace string) string {
37 | 	// Get deterministic file paths first.
38 | 	trace = strings.ReplaceAll(trace, getErrtraceDir(), _fixedDir)
39 |
40 | 	replacer := make(fileLineReplacer)
41 | 	for _, m := range _fileLineMatcher.FindAllStringSubmatch(trace, -1) {
42 | 		file := m[1]
43 | 		lineStr := m[2]
44 | 		line, err := strconv.Atoi(lineStr)
45 | 		if err != nil {
46 | 			panic(fmt.Sprintf("matched bad line number in %q: %v", m[0], err))
47 | 		}
48 | 		replacer.Add(file, line)
49 | 	}
50 |
51 | 	return strings.NewReplacer(replacer.Replacements()...).Replace(trace)
52 | }
53 |
54 | // getErrtraceDir returns the errtrace root directory as it appears in
55 | // runtime file paths, derived from the location of this source file.
56 | func getErrtraceDir() string {
57 | 	_, file, _, _ := runtime.Caller(0)
58 | 	// Note: Assumes specific location of this file in errtrace, strip internal/tracetest/
59 | 	dir := filepath.Dir(filepath.Dir(filepath.Dir(file)))
60 |
61 | 	// On Windows, filepath.Dir cleans the path, which modifies the separator.
62 | 	// To get back the original separator, truncate the original string.
63 | 	return file[:len(dir)]
64 | }
65 |
66 | // fileLineReplacer maintains a mapping from
67 | // file name to line numbers in that file that are referenced.
68 | // This is used to generate the replacements to be applied to the trace.
69 | type fileLineReplacer map[string][]int
70 |
71 | // Add adds a file:line pair to the replacer.
72 | func (r fileLineReplacer) Add(file string, line int) {
73 | 	r[file] = append(r[file], line)
74 | }
75 |
76 | // Replacements generates a slice of pairs of Replacements
77 | // to be applied to the trace.
78 | //
79 | // The first element in each pair is the original file:line
80 | // and the second element is the replacement file:line.
81 | // This returned slice can be fed into strings.NewReplacer.
82 | // Within a file, pairs are ordered by descending original line so that
83 | // no file:line is shadowed by a shorter file:line that is its prefix.
84 | func (r fileLineReplacer) Replacements() []string {
85 | 	var allReplacements []string
86 | 	for file, fileLines := range r {
87 | 		// Sort the lines in the file, and remove duplicates.
88 | 		// The result will be a slice of unique line numbers.
89 | 		// The index of each line in this slice + 1 will be its new line number.
90 | 		slices.Sort(fileLines)
91 | 		fileLines = slices.Compact(fileLines)
92 |
93 | 		// strings.Replacer compares old strings in argument order, so
94 | 		// "f.go:13" listed before "f.go:137" would match the start of
95 | 		// "f.go:137" and leave a stray "7". Iterating in descending order
96 | 		// puts every line number ahead of any shorter number that prefixes it.
97 | 		for idx := len(fileLines) - 1; idx >= 0; idx-- {
98 | 			origLine := fileLines[idx]
99 | 			replaceLine := idx + 1
100 | 			allReplacements = append(allReplacements,
101 | 				fmt.Sprintf("%v:%v", file, origLine),
102 | 				fmt.Sprintf("%v:%v", file, replaceLine))
103 | 		}
104 | 	}
105 | 	return allReplacements
106 | }
107 |
--------------------------------------------------------------------------------
/internal/tracetest/clean_2_test.go:
--------------------------------------------------------------------------------
1 | package tracetest
2 |
3 | import "braces.dev/errtrace"
4 |
5 | // Separate file to verify how Clean handles separate files.
6 |
7 | // f1 wraps the error from f2, adding one trace frame.
8 | func f1() error {
9 | 	return errtrace.Wrap(f2())
10 | }
11 |
12 | // f2 decorates the error from f3 with errtrace.Errorf, adding a frame.
13 | func f2() error {
14 | 	if err := f3(); err != nil {
15 | 		return errtrace.Errorf("f3: %w", err)
16 | 	}
17 |
18 | 	return nil
19 | }
20 |
21 | // f3 creates the innermost traced error.
22 | func f3() error {
23 | 	return errtrace.New("err")
24 | }
25 |
--------------------------------------------------------------------------------
/internal/tracetest/clean_test.go:
--------------------------------------------------------------------------------
1 | package tracetest
2 |
3 | import (
4 | 	"strings"
5 | 	"testing"
6 |
7 | 	"braces.dev/errtrace"
8 | )
9 |
10 | // TestClean_RealTrace checks MustClean on a real trace spanning two files:
11 | // paths must be rewritten under /path/to/errtrace and line numbers
12 | // renumbered 1, 2, 3, ... per file in source order.
13 | func TestClean_RealTrace(t *testing.T) {
14 | 	e1 := errtrace.Wrap(f1())
15 | 	// dummy comments to offset line numbers by > 1
16 | 	//
17 | 	// dummy line to make line numbers offset by > 1
18 | 	e2 := errtrace.Wrap(e1)
19 | 	//
20 | 	e3 := errtrace.Wrap(e2)
21 |
22 | 	want := strings.Join([]string{
23 | 		"f3: err",
24 | 		"",
25 | 		"braces.dev/errtrace/internal/tracetest.f3",
26 | 		" /path/to/errtrace/internal/tracetest/clean_2_test.go:3",
27 | 		"braces.dev/errtrace/internal/tracetest.f2",
28 | 		" /path/to/errtrace/internal/tracetest/clean_2_test.go:2",
29 | 		"braces.dev/errtrace/internal/tracetest.f1",
30 | 		" /path/to/errtrace/internal/tracetest/clean_2_test.go:1",
31 | 		"braces.dev/errtrace/internal/tracetest.TestClean_RealTrace",
32 | 		" /path/to/errtrace/internal/tracetest/clean_test.go:1",
33 | 		"braces.dev/errtrace/internal/tracetest.TestClean_RealTrace",
34 | 		" /path/to/errtrace/internal/tracetest/clean_test.go:2",
35 | 		"braces.dev/errtrace/internal/tracetest.TestClean_RealTrace",
36 | 		" /path/to/errtrace/internal/tracetest/clean_test.go:3",
37 | 		"",
38 | 	}, "\n")
39 |
40 | 	if got := MustClean(errtrace.FormatString(e3)); want != got {
41 | 		t.Errorf("want:\n%v\ngot:\n%v\n", want, got)
42 | 	}
43 | }
44 |
--------------------------------------------------------------------------------
/tree.go:
--------------------------------------------------------------------------------
1 | package errtrace
2 |
3 |
import (
	"errors"
	"fmt"
	"io"
	"runtime"
	"slices"
	"strings"
)

// traceTree represents an error and its traces
// as a tree structure.
//
// The root of the tree is the trace for the error itself.
// Children, if any, are the traces for each of the errors
// inside the multi-error (if the error was a multi-error).
type traceTree struct {
	// Err is the error at the root of this tree.
	Err error

	// Trace is the trace for the error down until
	// the first multi-error was encountered.
	//
	// The trace is in the reverse order of the call stack.
	// The first element is the deepest call in the stack,
	// and the last element is the shallowest call in the stack.
	Trace []runtime.Frame

	// Children are the traces for each of the errors
	// inside the multi-error.
	Children []traceTree
}

// buildTraceTree builds a trace tree from an error.
//
// All errors connected to the given error
// are considered part of its trace except:
// if a multi-error is found,
// a separate trace is built from each of its errors
// and they're all considered children of this error.
func buildTraceTree(err error) traceTree {
	current := traceTree{Err: err}
loop:
	for {
		// Every layer that UnwrapFrame accepts contributes one frame;
		// keep peeling until a layer without a frame is reached.
		if frame, inner, ok := UnwrapFrame(err); ok {
			current.Trace = append(current.Trace, frame)
			err = inner
			continue
		}

		// We unwrap errors manually instead of using errors.As
		// because we don't want to accidentally skip over multi-errors
		// or interpret them as part of a single error chain.
		switch x := err.(type) {
		case interface{ Unwrap() error }:
			err = x.Unwrap()

		case interface{ Unwrap() []error }:
			// Encountered a multi-error.
			// Everything else is a child of current.
			errs := x.Unwrap()
			current.Children = make([]traceTree, 0, len(errs))
			for _, err := range errs {
				current.Children = append(current.Children, buildTraceTree(err))
			}

			break loop

		default:
			// Reached a terminal error.
			break loop
		}
	}

	// Frames were collected outermost-first while unwrapping;
	// reverse them so the deepest call comes first, as documented
	// on traceTree.Trace.
	slices.Reverse(current.Trace)
	return current
}

// writeTree renders tree to w and returns any write errors, joined.
func writeTree(w io.Writer, tree traceTree) error {
	return (&treeWriter{W: w}).WriteTree(tree)
}

// treeWriter renders a traceTree to W.
//
// A write error does not stop rendering: errors are accumulated in e
// and reported once by WriteTree.
type treeWriter struct {
	W io.Writer
	e error
}

// WriteTree writes t, starting at the root (an empty path),
// and returns every write error encountered, joined with errors.Join.
func (p *treeWriter) WriteTree(t traceTree) error {
	p.writeTree(t, nil /* path */)
	return p.e
}

// Records the error if non-nil (errors.Join drops nil values).
// Will be returned from WriteTree, ultimately.
func (p *treeWriter) err(err error) {
	p.e = errors.Join(p.e, err)
}

// writeTree writes the tree to the writer.
//
// path is a slice of indexes leading to the current node
// in the tree.
//
// Children are written depth-first before the node's own message
// and trace, so the root is always written last.
func (p *treeWriter) writeTree(t traceTree, path []int) {
	for i, child := range t.Children {
		// append may reuse path's backing array across siblings.
		// That is safe here: each child's subtree is fully written before
		// the next sibling's path is built, and no path is retained.
		p.writeTree(child, append(path, i))
	}

	p.writeTrace(t.Err, t.Trace, path)
}

func (p *treeWriter) writeTrace(err error, trace []runtime.Frame, path []int) {
	// A trace for a single error takes
	// the same form as a stack trace:
	//
	// error message
	//
	// func1
	// path/to/file.go:12
	// func2
	// path/to/file.go:34
	//
	// (Each file:line line is written with a leading tab;
	// the examples here omit it.)
	//
	// However, when path isn't empty, we're part of a tree,
	// so we need to add prefixes around the trace
	// to indicate the tree structure.
	//
	// We print in depth-first order, so we get:
	//
	// +- error message 1
	// |
	// | func5
	// | path/to/file.go:90
	// | func6
	// | path/to/file.go:12
	// |
	// +- error message 2
	// |
	// | func7
	// | path/to/file.go:34
	// | func8
	// | path/to/file.go:56
	// |
	// +- error message 3
	// |
	// | func3
	// | path/to/file.go:57
	// | func4
	// | path/to/file.go:78
	// |
	// error message 4
	//
	// func1
	// path/to/file.go:12
	// func2
	// path/to/file.go:34

	// +- error message
	// |
	//
	// The message may have newlines in it,
	// so we need to print each line separately.
	for i, line := range strings.Split(err.Error(), "\n") {
		if i == 0 {
			p.pipes(path, "+- ")
		} else {
			p.pipes(path, "| ")
		}
		p.writeString(line)
		p.writeString("\n")
	}

	if len(trace) > 0 {
		// Empty line between the message and the trace.
		p.pipes(path, "| ")
		p.writeString("\n")

		for _, frame := range trace {
			p.pipes(path, "| ")
			p.writeString(frame.Function)
			p.writeString("\n")

			p.pipes(path, "| ")
			p.printf("\t%s:%d\n", frame.File, frame.Line)
		}
	}

	// Connecting "|" lines when ending a trace
	// This is the "empty" line between traces.
	if len(path) > 0 {
		p.pipes(path, "| ")
		p.writeString("\n")
	}
}

// pipes draws the "| | |" pipes prefix.
//
// path is a slice of 0-based indexes leading to the current node,
// outermost first. For example, the path [1, 3, 2] says that the
// current node is child 2 (the third) of child 3 (the fourth)
// of child 1 (the second) of the root.
//
// last is the last "|" component in this grouping;
// it'll normally be "| " or "+- ".
//
// In combination, path and last tell us how to draw the pipes.
// More often than not, we just draw:
//
// | | |
//
// However, for the first line of a message,
// we need to connect to the following line so we use "+- "
// which gives us:
//
// | | +- msg
// | | |
//
// Lastly, when drawing the tree,
// if any of the intermediate positions in the path are 0,
// (i.e. the first child of a parent),
// we don't draw a pipe because it won't have
// anything above it to connect to.
// For example:
//
// 0 1 2 For some x > 0
// -------
// | +- msg path = [x, 0, 0]
// | |
// | +- msg path = [x, 0, 1]
// | |
// | +- msg path = [x, 0]
// | |
// | +- msg path = [x, 1]
//
// Note that for cases where path[1] == 0,
// we don't draw a pipe if len(path) > 2.
func (p *treeWriter) pipes(path []int, last string) {
	for depth, idx := range path {
		if depth == len(path)-1 {
			p.writeString(last)
		} else if idx == 0 {
			// First child of the parent at this layer.
			// Nothing to connect to above us.
			p.writeString(" ")
		} else {
			p.writeString("| ")
		}
	}
}

// writeString writes s to W, recording (not returning) any error.
func (p *treeWriter) writeString(s string) {
	_, err := io.WriteString(p.W, s)
	p.err(err)
}

// printf formats to W like fmt.Fprintf, recording (not returning) any error.
func (p *treeWriter) printf(format string, args ...interface{}) {
	_, err := fmt.Fprintf(p.W, format, args...)
	p.err(err)
}

--------------------------------------------------------------------------------
/tree_test.go:
--------------------------------------------------------------------------------
package errtrace

import (
	"errors"
	"runtime"
	"strings"
	"testing"

	"braces.dev/errtrace/internal/diff"
)

// errorCaller returns errorCallee's error wrapped with a frame in errorCaller.
func errorCaller() error {
	return Wrap(errorCallee())
}

// errorCallee returns a new traced error created in errorCallee.
func errorCallee() error {
	return New("test error")
}

// errorMultiCaller joins two independently traced errors into a multi-error.
func errorMultiCaller() error {
	return errors.Join(
		errorCaller(),
		errorCaller(),
	)
}

// TestBuildTreeSingle checks that a single chain yields one trace,
// ordered deepest call first.
func TestBuildTreeSingle(t *testing.T) {
	tree := buildTraceTree(errorCaller())
	trace := tree.Trace

	if want, got := 2, len(trace); want != got {
		t.Fatalf("trace length mismatch, want %d, got %d", want, got)
	}

	if want, got := "braces.dev/errtrace.errorCallee", trace[0].Function; want != got {
		t.Errorf("innermost function should be first, want %q, got %q", want, got)
	}

	if want, got := "braces.dev/errtrace.errorCaller", trace[1].Function; want != got {
		t.Errorf("outermost function should be last, want %q, got %q", want, got)
	}
}

// TestBuildTreeMulti checks that an untraced errors.Join produces no root
// trace and one child per joined error, each with its own ordered trace.
func TestBuildTreeMulti(t *testing.T) {
	tree := buildTraceTree(errorMultiCaller())

	if want, got := 0, len(tree.Trace); want != got {
		t.Fatalf("unexpected trace: %v", tree.Trace)
	}

	if want, got := 2, len(tree.Children); want != got {
		t.Fatalf("children length mismatch, want %d, got %d", want, got)
	}

	for _, child := range tree.Children {
		if want, got := 2, len(child.Trace); want != got {
			t.Fatalf("trace length mismatch, want %d, got %d", want, got)
		}

		if want, got := "braces.dev/errtrace.errorCallee", child.Trace[0].Function; want != got {
			t.Errorf("innermost function should be first, want %q, got %q", want, got)
		}

		if want, got := "braces.dev/errtrace.errorCaller", child.Trace[1].Function; want != got {
			t.Errorf("outermost function should be last, want %q, got %q", want, got)
		}
	}
}

// TestWriteTree checks writeTree's rendered output for hand-built trees.
// Each expected line is listed without its trailing newline.
func TestWriteTree(t *testing.T) {
	type testFrame struct {
		Function string
		File     string
		Line     int
	}

	// Helpers to make tests more readable.
	type frames = []testFrame
	tree := func(err error, trace frames, children ...traceTree) traceTree {
		runtimeFrames := make([]runtime.Frame, len(trace))
		for i, f := range trace {
			runtimeFrames[i] = runtime.Frame{
				Function: f.Function,
				File:     f.File,
				Line:     f.Line,
			}
		}

		return traceTree{
			Err:      err,
			Trace:    runtimeFrames,
			Children: children,
		}
	}

	tests := []struct {
		name string
		give traceTree
		want []string // lines minus trailing newline
	}{
		{
			name: "top level single error",
			give: tree(
				errors.New("test error"),
				frames{
					{"foo", "foo.go", 42},
					{"bar", "bar.go", 24},
				},
			),
			want: []string{
				"test error",
				"",
				"foo",
				" foo.go:42",
				"bar",
				" bar.go:24",
			},
		},
		{
			name: "multi error without trace",
			give: tree(
				errors.Join(
					errors.New("err a"),
					errors.New("err b"),
				),
				frames{},
				tree(errors.New("err a"), frames{
					{"foo", "foo.go", 42},
					{"bar", "bar.go", 24},
				}),
				tree(errors.New("err b"), frames{
					{"baz", "baz.go", 24},
					{"qux", "qux.go", 48},
				}),
			),
			want: []string{
				"+- err a",
				"| ",
				"| foo",
				"| foo.go:42",
				"| bar",
				"| bar.go:24",
				"| ",
				"+- err b",
				"| ",
				"| baz",
				"| baz.go:24",
				"| qux",
				"| qux.go:48",
				"| ",
				"err a",
				"err b",
			},
		},
		{
			name: "multi error with trace",
			give: tree(
				errors.Join(
					errors.New("err a"),
					errors.New("err b"),
				),
				frames{
					{"foo", "foo.go", 42},
					{"bar", "bar.go", 24},
				},
				tree(
					errors.New("err a"),
					frames{
						{"baz", "baz.go", 24},
						{"qux", "qux.go", 48},
					},
				),
				tree(
					errors.New("err b"),
					frames{
						{"corge", "corge.go", 24},
						{"grault", "grault.go", 48},
					},
				),
			),
			want: []string{
				"+- err a",
				"| ",
				"| baz",
				"| baz.go:24",
				"| qux",
				"| qux.go:48",
				"| ",
				"+- err b",
				"| ",
				"| corge",
				"| corge.go:24",
				"| grault",
				"| grault.go:48",
				"| ",
				"err a",
				"err b",
				"",
				"foo",
				" foo.go:42",
				"bar",
				" bar.go:24",
			},
		},
		{
			name: "wrapped multi error with siblings",
			give: tree(
				errors.Join(
					errors.Join(
						errors.New("err a"),
						errors.New("err b"),
					),
					errors.New("err c"),
				),
				frames{
					{"foo", "foo.go", 42},
					{"bar", "bar.go", 24},
				},
				tree(
					errors.Join(
						errors.New("err a"),
						errors.New("err b"),
					),
					frames{
						{"baz", "baz.go", 24},
						{"qux", "qux.go", 48},
					},
					tree(
						errors.New("err a"),
						frames{
							{"quux", "quux.go", 24},
							{"quuz", "quuz.go", 48},
						},
					),
					tree(
						errors.New("err b"),
						frames{
							{"abc", "abc.go", 24},
							{"def", "def.go", 48},
						},
					),
				),
				tree(
					errors.New("err c"),
					frames{
						{"corge", "corge.go", 24},
						{"grault", "grault.go", 48},
					},
				),
			),
			want: []string{
				" +- err a",
				" | ",
				" | quux",
				" | quux.go:24",
				" | quuz",
				" | quuz.go:48",
				" | ",
				" +- err b",
				" | ",
				" | abc",
				" | abc.go:24",
				" | def",
				" | def.go:48",
				" | ",
				"+- err a",
				"| err b",
				"| ",
				"| baz",
				"| baz.go:24",
				"| qux",
				"| qux.go:48",
				"| ",
				"+- err c",
				"| ",
				"| corge",
				"| corge.go:24",
				"| grault",
				"| grault.go:48",
				"| ",
				"err a",
				"err b",
				"err c",
				"",
				"foo",
				" foo.go:42",
				"bar",
				" bar.go:24",
			},
		},
		{
			name: "multi error with one non-traced error",
			give: tree(
				errors.Join(
					errors.New("err a"),
					errors.New("err b"),
					errors.New("err c"),
				),
				frames{},
				tree(
					errors.New("err a"),
					frames{
						{"foo", "foo.go", 42},
						{"bar", "bar.go", 24},
					},
				),
				tree(
					errors.New("err b"),
					frames{},
				),
				tree(
					errors.New("err c"),
					frames{
						{"baz", "baz.go", 24},
						{"qux", "qux.go", 48},
					},
				),
			),
			want: []string{
				"+- err a",
				"| ",
				"| foo",
				"| foo.go:42",
				"| bar",
				"| bar.go:24",
				"| ",
				"+- err b",
				"| ",
				"+- err c",
				"| ",
				"| baz",
				"| baz.go:24",
				"| qux",
				"| qux.go:48",
				"| ",
				"err a",
				"err b",
				"err c",
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var s strings.Builder
			if err := writeTree(&s, tt.give); err != nil {
				t.Fatal(err)
			}

			if want, got := strings.Join(tt.want, "\n")+"\n", s.String(); want != got {
				t.Errorf("output mismatch:\n"+
					"want:\n%s\n"+
					"got:\n%s\n"+
					"diff:\n%s", want, got, diff.Lines(want, got))
			}
		})
	}
}

--------------------------------------------------------------------------------
/unwrap.go:
--------------------------------------------------------------------------------
package errtrace

import (
	"errors"
	"runtime"
)

// UnwrapFrame unwraps the outermost frame from the given error,
// returning it and the inner error.
//
// ok is true only if the frame was successfully extracted.
// If err has no TracePC method (for example, it is not an errtrace error),
// ok is false and inner is err itself.
// If err has a TracePC method but its PC does not resolve to a frame,
// ok is false and inner is errors.Unwrap(err).
//
// You can use this for structured access to trace information.
//
// Any error that has a method `TracePC() uintptr` will
// contribute a frame to the trace.
func UnwrapFrame(err error) (frame runtime.Frame, inner error, ok bool) { //nolint:revive // error is intentionally middle return
	e, ok := err.(interface{ TracePC() uintptr })
	if !ok {
		return runtime.Frame{}, err, false
	}

	inner = errors.Unwrap(err)
	frames := runtime.CallersFrames([]uintptr{e.TracePC()})
	f, _ := frames.Next()
	if f == (runtime.Frame{}) {
		// Unlikely, but if PC didn't yield a frame,
		// just return the inner error.
		return runtime.Frame{}, inner, false
	}

	return f, inner, true
}

--------------------------------------------------------------------------------
/unwrap_test.go:
--------------------------------------------------------------------------------
package errtrace

import (
	"errors"
	"path/filepath"
	"reflect"
	"strings"
	"testing"
)

// TestUnwrapFrame covers a plain error, an errtrace-wrapped error,
// and a custom error type that implements TracePC.
func TestUnwrapFrame(t *testing.T) {
	giveErr := errors.New("great sadness")

	t.Run("not wrapped", func(t *testing.T) {
		_, inner, ok := UnwrapFrame(giveErr)
		if got, want := ok, false; got != want {
			t.Errorf("ok: got %v, want %v", got, want)
		}

		if got, want := inner, giveErr; got != want {
			t.Errorf("inner: got %v, want %v", inner, giveErr)
		}
	})

	t.Run("wrapped", func(t *testing.T) {
		wrapped := Wrap(giveErr)
		frame, inner, ok := UnwrapFrame(wrapped)
		if got, want := ok, true; got != want {
			t.Errorf("ok: got %v, want %v", got, want)
		}

		if got, want := inner, giveErr; got != want {
			t.Errorf("inner: got %v, want %v", inner, giveErr)
		}

		// This subtest's closure is the second in TestUnwrapFrame, hence func2.
		if got, want := frame.Function, ".TestUnwrapFrame.func2"; !strings.HasSuffix(got, want) {
			t.Errorf("frame.Func: got %q, does not contain %q", got, want)
		}

		if got, want := filepath.Base(frame.File), "unwrap_test.go"; got != want {
			t.Errorf("frame.File: got %v, want %v", got, want)
		}
	})

	t.Run("custom error", func(t *testing.T) {
		wrapped := wrapCustomTrace(giveErr)
		frame, inner, ok := UnwrapFrame(wrapped)
		if got, want := ok, true; got != want {
			t.Errorf("ok: got %v, want %v", got, want)
		}

		if got, want := inner, giveErr; got != want {
			t.Errorf("inner: got %v, want %v", inner, giveErr)
		}

		if got, want := frame.Function, ".wrapCustomTrace"; !strings.HasSuffix(got, want) {
			t.Errorf("frame.Func: got %q, does not contain %q", got, want)
		}

		if got, want := filepath.Base(frame.File), "unwrap_test.go"; got != want {
			t.Errorf("frame.File: got %v, want %v", got, want)
		}
	})
}

// TestUnwrapFrame_badPC checks that a PC of 0 yields ok == false
// while still returning the wrapped inner error.
func TestUnwrapFrame_badPC(t *testing.T) {
	giveErr := errors.New("great sadness")
	_, inner, ok := UnwrapFrame(wrap(giveErr, 0))
	if got, want := ok, false; got != want {
		t.Errorf("ok: got %v, want %v", got, want)
	}

	if got, want := inner, giveErr; got != want {
		t.Errorf("inner: got %v, want %v", inner, giveErr)
	}
}

// customTraceError is a non-errtrace error type that implements TracePC,
// used to check that UnwrapFrame accepts any TracePC implementer.
type customTraceError struct {
	err error
	pc  uintptr
}

// wrapCustomTrace wraps err, recording the code pointer of
// wrapCustomTrace itself as the trace PC.
func wrapCustomTrace(err error) error {
	return &customTraceError{
		err: err,
		pc:  reflect.ValueOf(wrapCustomTrace).Pointer(),
	}
}

func (e *customTraceError) Error() string {
	return e.err.Error()
}

func (e *customTraceError) TracePC() uintptr {
	return e.pc
}

func (e *customTraceError) Unwrap() error {
	return e.err
}

--------------------------------------------------------------------------------
/wrap.go:
-------------------------------------------------------------------------------- 1 | package errtrace 2 | 3 | import ( 4 | "braces.dev/errtrace/internal/pc" 5 | ) 6 | 7 | // Wrap adds information about the program counter of the caller to the error. 8 | // This is intended to be used at all return points in a function. 9 | // If err is nil, Wrap returns nil. 10 | // 11 | //go:noinline so caller's PC is saved in the stack frame for asm GetCaller. 12 | func Wrap(err error) error { 13 | if err == nil { 14 | return nil 15 | } 16 | 17 | return wrap(err, pc.GetCaller()) 18 | } 19 | 20 | // Wrap2 is used to [Wrap] the last error return when returning 2 values. 21 | // This is useful when returning multiple returns from a function call directly: 22 | // 23 | // return Wrap2(fn()) 24 | // 25 | // Wrap2 is used by the CLI to avoid line number changes. 26 | // 27 | //go:noinline due to GetCaller (see [Wrap] for details). 28 | func Wrap2[T any](t T, err error) (T, error) { 29 | if err == nil { 30 | return t, nil 31 | } 32 | 33 | return t, wrap(err, pc.GetCaller()) 34 | } 35 | 36 | // Wrap3 is used to [Wrap] the last error return when returning 3 values. 37 | // This is useful when returning multiple returns from a function call directly: 38 | // 39 | // return Wrap3(fn()) 40 | // 41 | // Wrap3 is used by the CLI to avoid line number changes. 42 | // 43 | //go:noinline due to GetCaller (see [Wrap] for details). 44 | func Wrap3[T1, T2 any](t1 T1, t2 T2, err error) (T1, T2, error) { 45 | if err == nil { 46 | return t1, t2, nil 47 | } 48 | 49 | return t1, t2, wrap(err, pc.GetCaller()) 50 | } 51 | 52 | // Wrap4 is used to [Wrap] the last error return when returning 4 values. 53 | // This is useful when returning multiple returns from a function call directly: 54 | // 55 | // return Wrap4(fn()) 56 | // 57 | // Wrap4 is used by the CLI to avoid line number changes. 58 | // 59 | //go:noinline due to GetCaller (see [Wrap] for details). 
60 | func Wrap4[T1, T2, T3 any](t1 T1, t2 T2, t3 T3, err error) (T1, T2, T3, error) { 61 | if err == nil { 62 | return t1, t2, t3, nil 63 | } 64 | 65 | return t1, t2, t3, wrap(err, pc.GetCaller()) 66 | } 67 | 68 | // Wrap5 is used to [Wrap] the last error return when returning 5 values. 69 | // This is useful when returning multiple returns from a function call directly: 70 | // 71 | // return Wrap5(fn()) 72 | // 73 | // Wrap5 is used by the CLI to avoid line number changes. 74 | // 75 | //go:noinline due to GetCaller (see [Wrap] for details). 76 | func Wrap5[T1, T2, T3, T4 any](t1 T1, t2 T2, t3 T3, t4 T4, err error) (T1, T2, T3, T4, error) { 77 | if err == nil { 78 | return t1, t2, t3, t4, nil 79 | } 80 | 81 | return t1, t2, t3, t4, wrap(err, pc.GetCaller()) 82 | } 83 | 84 | // Wrap6 is used to [Wrap] the last error return when returning 6 values. 85 | // This is useful when returning multiple returns from a function call directly: 86 | // 87 | // return Wrap6(fn()) 88 | // 89 | // Wrap6 is used by the CLI to avoid line number changes. 90 | // 91 | //go:noinline due to GetCaller (see [Wrap] for details). 92 | func Wrap6[T1, T2, T3, T4, T5 any](t1 T1, t2 T2, t3 T3, t4 T4, t5 T5, err error) (T1, T2, T3, T4, T5, error) { 93 | if err == nil { 94 | return t1, t2, t3, t4, t5, nil 95 | } 96 | 97 | return t1, t2, t3, t4, t5, wrap(err, pc.GetCaller()) 98 | } 99 | -------------------------------------------------------------------------------- /wrap_caller.go: -------------------------------------------------------------------------------- 1 | package errtrace 2 | 3 | import "braces.dev/errtrace/internal/pc" 4 | 5 | // Caller represents a single caller frame, and is intended for error helpers 6 | // to capture caller information for wrapping. See [GetCaller] for details. 
7 | type Caller struct { 8 | callerPC uintptr 9 | } 10 | 11 | // GetCaller captures the program counter of a caller, primarily intended for 12 | // error helpers so caller information captures the helper's caller. 13 | // 14 | // Callers of this function should be marked '//go:noinline' to avoid inlining, 15 | // as GetCaller expects to skip the caller's stack frame. 16 | // 17 | // //go:noinline 18 | // func Wrapf(err error, msg string, args ...any) { 19 | // caller := errtrace.GetCaller() 20 | // err := ... 21 | // return caller.Wrap(err) 22 | // } 23 | // 24 | //go:noinline 25 | func GetCaller() Caller { 26 | return Caller{pc.GetCallerSkip1()} 27 | } 28 | 29 | // Wrap adds the program counter captured in Caller to the error, 30 | // similar to [Wrap], but relying on previously captured caller inforamtion. 31 | func (c Caller) Wrap(err error) error { 32 | return wrap(err, c.callerPC) 33 | } 34 | -------------------------------------------------------------------------------- /wrap_caller_safe_test.go: -------------------------------------------------------------------------------- 1 | //go:build safe || !(amd64 || arm64) 2 | 3 | // 4 | // Build tag must match pc_safe.go 5 | 6 | package errtrace_test 7 | 8 | func init() { 9 | safe = true 10 | } 11 | -------------------------------------------------------------------------------- /wrap_caller_test.go: -------------------------------------------------------------------------------- 1 | package errtrace_test 2 | 3 | import ( 4 | "errors" 5 | "runtime" 6 | "strings" 7 | "testing" 8 | 9 | "braces.dev/errtrace" 10 | ) 11 | 12 | var safe = false 13 | 14 | // Note: Though test tables could DRY up the below tests, they're intentionally 15 | // kept as functions calling other simple functions to test how inlining impacts 16 | // use of `GetCaller`. 
17 | 18 | func TestGetCallerWrap_ErrorsNew(t *testing.T) { 19 | err := callErrorsNew() 20 | wantErr(t, err, "callErrorsNew") 21 | } 22 | 23 | func callErrorsNew() error { 24 | return errorsNew("wrap errors.New") 25 | } 26 | 27 | func errorsNew(msg string) error { 28 | caller := errtrace.GetCaller() 29 | return caller.Wrap(errors.New(msg)) 30 | } 31 | 32 | func TestGetCallerWrap_WrapExisting(t *testing.T) { 33 | err := callWrapExisting() 34 | wantErr(t, err, "callWrapExisting") 35 | } 36 | 37 | func callWrapExisting() error { 38 | return wrapExisting() 39 | } 40 | 41 | var errFoo = errors.New("foo") 42 | 43 | func wrapExisting() error { 44 | return errtrace.GetCaller().Wrap(errFoo) 45 | } 46 | 47 | func TestGetCallerWrap_PassCaller(t *testing.T) { 48 | err := callPassCaller() 49 | wantErr(t, err, "callPassCaller") 50 | } 51 | 52 | func callPassCaller() error { 53 | return passCaller() 54 | } 55 | 56 | func passCaller() error { 57 | return passCallerInner(errtrace.GetCaller()) 58 | } 59 | 60 | func passCallerInner(caller errtrace.Caller) error { 61 | return caller.Wrap(errFoo) 62 | } 63 | 64 | func TestGetCallerWrap_RetCaller(t *testing.T) { 65 | err := callRetCaller() 66 | 67 | wantFn := "callRetCaller" 68 | if !safe { 69 | // If the function calling pc.GetCaller is inlined, there's no stack frame 70 | // so the assembly implementation can skip the correct caller. 71 | // Callers of GetCaller using `go:noinline` avoid this (as recommended in the docs). 72 | // Inlining is not consistent, hence we check the frame in !safe mode. 
73 | f, _, _ := errtrace.UnwrapFrame(err) 74 | if !strings.HasSuffix(f.Function, wantFn) { 75 | wantFn = "TestGetCallerWrap_RetCaller" 76 | } 77 | } 78 | wantErr(t, err, wantFn) 79 | } 80 | 81 | func callRetCaller() error { 82 | return retCaller().Wrap(errFoo) 83 | } 84 | 85 | func retCaller() errtrace.Caller { 86 | return errtrace.GetCaller() 87 | } 88 | 89 | func TestGetCallerWrap_RetCallerNoInline(t *testing.T) { 90 | err := callRetCallerNoInline() 91 | wantErr(t, err, "callRetCallerNoInline") 92 | } 93 | 94 | func callRetCallerNoInline() error { 95 | return retCallerNoInline().Wrap(errFoo) 96 | } 97 | 98 | //go:noinline 99 | func retCallerNoInline() errtrace.Caller { 100 | return errtrace.GetCaller() 101 | } 102 | 103 | func wantErr(t testing.TB, err error, fn string) runtime.Frame { 104 | if err == nil { 105 | t.Fatalf("expected err") 106 | } 107 | 108 | f, _, _ := errtrace.UnwrapFrame(err) 109 | if !strings.HasSuffix(f.Function, "."+fn) { 110 | t.Errorf("expected caller to be %v, got %v", fn, f.Function) 111 | } 112 | return f 113 | } 114 | --------------------------------------------------------------------------------