├── .gitattributes ├── .github └── workflows │ └── build_and_test.yml ├── .gitignore ├── LICENSE ├── README.md ├── callgraphutil ├── aliases.go ├── calls.go ├── cosmograph.go ├── cosmograph_test.go ├── csv.go ├── csv_test.go ├── doc.go ├── dot.go ├── dot_test.go ├── graph.go ├── graph_vulncheck.go ├── path.go └── ssa.go ├── check.go ├── cmd ├── logi │ └── main.go ├── sqli │ └── main.go ├── ssadump │ └── main.go ├── taint │ ├── Makefile │ ├── example │ │ └── main.go │ ├── main.go │ ├── main_test.go │ └── vhs │ │ ├── demo.gif │ │ └── demo.tape └── xss │ └── main.go ├── doc.go ├── go.mod ├── go.sum ├── log └── injection │ ├── injection.go │ ├── injection_test.go │ └── testdata │ └── src │ ├── a │ └── main.go │ ├── b │ └── main.go │ ├── c │ └── main.go │ ├── d │ └── main.go │ ├── e │ └── main.go │ ├── f │ └── main.go │ └── g │ └── main.go ├── sources_sinks.go ├── sql └── injection │ ├── injection.go │ ├── injection_test.go │ └── testdata │ └── src │ ├── a │ └── main.go │ ├── b │ └── main.go │ ├── c │ └── main.go │ ├── d │ └── main.go │ ├── e │ └── main.go │ ├── example │ └── main.go │ ├── f │ └── main.go │ ├── g │ └── main.go │ ├── github.com │ ├── jinzhu │ │ └── gorm │ │ │ └── mock.go │ └── lib │ │ └── pq │ │ └── main.go │ ├── h │ └── main.go │ └── i │ └── main.go ├── walk_ssa.go └── xss ├── testdata └── src │ ├── a │ └── main.go │ ├── b │ └── main.go │ ├── c │ └── main.go │ ├── d │ └── main.go │ ├── e │ └── main.go │ ├── f │ └── main.go │ └── g │ └── main.go ├── xss.go └── xss_test.go /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.github/workflows/build_and_test.yml: -------------------------------------------------------------------------------- 1 | # This workflow runs go build and test for each 2 | # push and pull request to the main branch. 
3 | name: Build and Test 4 | 5 | on: 6 | push: 7 | branches: [ "main" ] 8 | pull_request: 9 | branches: [ "main" ] 10 | 11 | jobs: 12 | 13 | build: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v3 17 | 18 | - name: Setup Go 19 | uses: actions/setup-go@v4 20 | with: 21 | go-version: '1.21' 22 | 23 | - name: Build 24 | run: go build -v ./... 25 | 26 | - name: Test 27 | run: go test -v ./... 28 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # If you prefer the allow list template instead of the deny list, see community template: 2 | # https://github.com/github/gitignore/blob/main/community/Golang/Go.AllowList.gitignore 3 | # 4 | # Binaries for programs and plugins 5 | *.exe 6 | *.exe~ 7 | *.dll 8 | *.so 9 | *.dylib 10 | 11 | # Test binary, built with `go test -c` 12 | *.test 13 | 14 | # Output of the go coverage tool, specifically when used with LiteIDE 15 | *.out 16 | 17 | # Dependency directories (remove the comment below to include it) 18 | # vendor/ 19 | 20 | # Go workspace file 21 | go.work 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Mozilla Public License Version 2.0 2 | ================================== 3 | 4 | 1. Definitions 5 | -------------- 6 | 7 | 1.1. "Contributor" 8 | means each individual or legal entity that creates, contributes to 9 | the creation of, or owns Covered Software. 10 | 11 | 1.2. "Contributor Version" 12 | means the combination of the Contributions of others (if any) used 13 | by a Contributor and that particular Contributor's Contribution. 14 | 15 | 1.3. "Contribution" 16 | means Covered Software of a particular Contributor. 17 | 18 | 1.4. 
"Covered Software" 19 | means Source Code Form to which the initial Contributor has attached 20 | the notice in Exhibit A, the Executable Form of such Source Code 21 | Form, and Modifications of such Source Code Form, in each case 22 | including portions thereof. 23 | 24 | 1.5. "Incompatible With Secondary Licenses" 25 | means 26 | 27 | (a) that the initial Contributor has attached the notice described 28 | in Exhibit B to the Covered Software; or 29 | 30 | (b) that the Covered Software was made available under the terms of 31 | version 1.1 or earlier of the License, but not also under the 32 | terms of a Secondary License. 33 | 34 | 1.6. "Executable Form" 35 | means any form of the work other than Source Code Form. 36 | 37 | 1.7. "Larger Work" 38 | means a work that combines Covered Software with other material, in 39 | a separate file or files, that is not Covered Software. 40 | 41 | 1.8. "License" 42 | means this document. 43 | 44 | 1.9. "Licensable" 45 | means having the right to grant, to the maximum extent possible, 46 | whether at the time of the initial grant or subsequently, any and 47 | all of the rights conveyed by this License. 48 | 49 | 1.10. "Modifications" 50 | means any of the following: 51 | 52 | (a) any file in Source Code Form that results from an addition to, 53 | deletion from, or modification of the contents of Covered 54 | Software; or 55 | 56 | (b) any new file in Source Code Form that contains any Covered 57 | Software. 58 | 59 | 1.11. "Patent Claims" of a Contributor 60 | means any patent claim(s), including without limitation, method, 61 | process, and apparatus claims, in any patent Licensable by such 62 | Contributor that would be infringed, but for the grant of the 63 | License, by the making, using, selling, offering for sale, having 64 | made, import, or transfer of either its Contributions or its 65 | Contributor Version. 66 | 67 | 1.12. 
"Secondary License" 68 | means either the GNU General Public License, Version 2.0, the GNU 69 | Lesser General Public License, Version 2.1, the GNU Affero General 70 | Public License, Version 3.0, or any later versions of those 71 | licenses. 72 | 73 | 1.13. "Source Code Form" 74 | means the form of the work preferred for making modifications. 75 | 76 | 1.14. "You" (or "Your") 77 | means an individual or a legal entity exercising rights under this 78 | License. For legal entities, "You" includes any entity that 79 | controls, is controlled by, or is under common control with You. For 80 | purposes of this definition, "control" means (a) the power, direct 81 | or indirect, to cause the direction or management of such entity, 82 | whether by contract or otherwise, or (b) ownership of more than 83 | fifty percent (50%) of the outstanding shares or beneficial 84 | ownership of such entity. 85 | 86 | 2. License Grants and Conditions 87 | -------------------------------- 88 | 89 | 2.1. Grants 90 | 91 | Each Contributor hereby grants You a world-wide, royalty-free, 92 | non-exclusive license: 93 | 94 | (a) under intellectual property rights (other than patent or trademark) 95 | Licensable by such Contributor to use, reproduce, make available, 96 | modify, display, perform, distribute, and otherwise exploit its 97 | Contributions, either on an unmodified basis, with Modifications, or 98 | as part of a Larger Work; and 99 | 100 | (b) under Patent Claims of such Contributor to make, use, sell, offer 101 | for sale, have made, import, and otherwise transfer either its 102 | Contributions or its Contributor Version. 103 | 104 | 2.2. Effective Date 105 | 106 | The licenses granted in Section 2.1 with respect to any Contribution 107 | become effective for each Contribution on the date the Contributor first 108 | distributes such Contribution. 109 | 110 | 2.3. 
Limitations on Grant Scope 111 | 112 | The licenses granted in this Section 2 are the only rights granted under 113 | this License. No additional rights or licenses will be implied from the 114 | distribution or licensing of Covered Software under this License. 115 | Notwithstanding Section 2.1(b) above, no patent license is granted by a 116 | Contributor: 117 | 118 | (a) for any code that a Contributor has removed from Covered Software; 119 | or 120 | 121 | (b) for infringements caused by: (i) Your and any other third party's 122 | modifications of Covered Software, or (ii) the combination of its 123 | Contributions with other software (except as part of its Contributor 124 | Version); or 125 | 126 | (c) under Patent Claims infringed by Covered Software in the absence of 127 | its Contributions. 128 | 129 | This License does not grant any rights in the trademarks, service marks, 130 | or logos of any Contributor (except as may be necessary to comply with 131 | the notice requirements in Section 3.4). 132 | 133 | 2.4. Subsequent Licenses 134 | 135 | No Contributor makes additional grants as a result of Your choice to 136 | distribute the Covered Software under a subsequent version of this 137 | License (see Section 10.2) or under the terms of a Secondary License (if 138 | permitted under the terms of Section 3.3). 139 | 140 | 2.5. Representation 141 | 142 | Each Contributor represents that the Contributor believes its 143 | Contributions are its original creation(s) or it has sufficient rights 144 | to grant the rights to its Contributions conveyed by this License. 145 | 146 | 2.6. Fair Use 147 | 148 | This License is not intended to limit any rights You have under 149 | applicable copyright doctrines of fair use, fair dealing, or other 150 | equivalents. 151 | 152 | 2.7. Conditions 153 | 154 | Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted 155 | in Section 2.1. 156 | 157 | 3. Responsibilities 158 | ------------------- 159 | 160 | 3.1. 
Distribution of Source Form 161 | 162 | All distribution of Covered Software in Source Code Form, including any 163 | Modifications that You create or to which You contribute, must be under 164 | the terms of this License. You must inform recipients that the Source 165 | Code Form of the Covered Software is governed by the terms of this 166 | License, and how they can obtain a copy of this License. You may not 167 | attempt to alter or restrict the recipients' rights in the Source Code 168 | Form. 169 | 170 | 3.2. Distribution of Executable Form 171 | 172 | If You distribute Covered Software in Executable Form then: 173 | 174 | (a) such Covered Software must also be made available in Source Code 175 | Form, as described in Section 3.1, and You must inform recipients of 176 | the Executable Form how they can obtain a copy of such Source Code 177 | Form by reasonable means in a timely manner, at a charge no more 178 | than the cost of distribution to the recipient; and 179 | 180 | (b) You may distribute such Executable Form under the terms of this 181 | License, or sublicense it under different terms, provided that the 182 | license for the Executable Form does not attempt to limit or alter 183 | the recipients' rights in the Source Code Form under this License. 184 | 185 | 3.3. Distribution of a Larger Work 186 | 187 | You may create and distribute a Larger Work under terms of Your choice, 188 | provided that You also comply with the requirements of this License for 189 | the Covered Software. 
If the Larger Work is a combination of Covered 190 | Software with a work governed by one or more Secondary Licenses, and the 191 | Covered Software is not Incompatible With Secondary Licenses, this 192 | License permits You to additionally distribute such Covered Software 193 | under the terms of such Secondary License(s), so that the recipient of 194 | the Larger Work may, at their option, further distribute the Covered 195 | Software under the terms of either this License or such Secondary 196 | License(s). 197 | 198 | 3.4. Notices 199 | 200 | You may not remove or alter the substance of any license notices 201 | (including copyright notices, patent notices, disclaimers of warranty, 202 | or limitations of liability) contained within the Source Code Form of 203 | the Covered Software, except that You may alter any license notices to 204 | the extent required to remedy known factual inaccuracies. 205 | 206 | 3.5. Application of Additional Terms 207 | 208 | You may choose to offer, and to charge a fee for, warranty, support, 209 | indemnity or liability obligations to one or more recipients of Covered 210 | Software. However, You may do so only on Your own behalf, and not on 211 | behalf of any Contributor. You must make it absolutely clear that any 212 | such warranty, support, indemnity, or liability obligation is offered by 213 | You alone, and You hereby agree to indemnify every Contributor for any 214 | liability incurred by such Contributor as a result of warranty, support, 215 | indemnity or liability terms You offer. You may include additional 216 | disclaimers of warranty and limitations of liability specific to any 217 | jurisdiction. 218 | 219 | 4. 
Inability to Comply Due to Statute or Regulation 220 | --------------------------------------------------- 221 | 222 | If it is impossible for You to comply with any of the terms of this 223 | License with respect to some or all of the Covered Software due to 224 | statute, judicial order, or regulation then You must: (a) comply with 225 | the terms of this License to the maximum extent possible; and (b) 226 | describe the limitations and the code they affect. Such description must 227 | be placed in a text file included with all distributions of the Covered 228 | Software under this License. Except to the extent prohibited by statute 229 | or regulation, such description must be sufficiently detailed for a 230 | recipient of ordinary skill to be able to understand it. 231 | 232 | 5. Termination 233 | -------------- 234 | 235 | 5.1. The rights granted under this License will terminate automatically 236 | if You fail to comply with any of its terms. However, if You become 237 | compliant, then the rights granted under this License from a particular 238 | Contributor are reinstated (a) provisionally, unless and until such 239 | Contributor explicitly and finally terminates Your grants, and (b) on an 240 | ongoing basis, if such Contributor fails to notify You of the 241 | non-compliance by some reasonable means prior to 60 days after You have 242 | come back into compliance. Moreover, Your grants from a particular 243 | Contributor are reinstated on an ongoing basis if such Contributor 244 | notifies You of the non-compliance by some reasonable means, this is the 245 | first time You have received notice of non-compliance with this License 246 | from such Contributor, and You become compliant prior to 30 days after 247 | Your receipt of the notice. 248 | 249 | 5.2. 
If You initiate litigation against any entity by asserting a patent 250 | infringement claim (excluding declaratory judgment actions, 251 | counter-claims, and cross-claims) alleging that a Contributor Version 252 | directly or indirectly infringes any patent, then the rights granted to 253 | You by any and all Contributors for the Covered Software under Section 254 | 2.1 of this License shall terminate. 255 | 256 | 5.3. In the event of termination under Sections 5.1 or 5.2 above, all 257 | end user license agreements (excluding distributors and resellers) which 258 | have been validly granted by You or Your distributors under this License 259 | prior to termination shall survive termination. 260 | 261 | ************************************************************************ 262 | * * 263 | * 6. Disclaimer of Warranty * 264 | * ------------------------- * 265 | * * 266 | * Covered Software is provided under this License on an "as is" * 267 | * basis, without warranty of any kind, either expressed, implied, or * 268 | * statutory, including, without limitation, warranties that the * 269 | * Covered Software is free of defects, merchantable, fit for a * 270 | * particular purpose or non-infringing. The entire risk as to the * 271 | * quality and performance of the Covered Software is with You. * 272 | * Should any Covered Software prove defective in any respect, You * 273 | * (not any Contributor) assume the cost of any necessary servicing, * 274 | * repair, or correction. This disclaimer of warranty constitutes an * 275 | * essential part of this License. No use of any Covered Software is * 276 | * authorized under this License except under this disclaimer. * 277 | * * 278 | ************************************************************************ 279 | 280 | ************************************************************************ 281 | * * 282 | * 7. 
Limitation of Liability * 283 | * -------------------------- * 284 | * * 285 | * Under no circumstances and under no legal theory, whether tort * 286 | * (including negligence), contract, or otherwise, shall any * 287 | * Contributor, or anyone who distributes Covered Software as * 288 | * permitted above, be liable to You for any direct, indirect, * 289 | * special, incidental, or consequential damages of any character * 290 | * including, without limitation, damages for lost profits, loss of * 291 | * goodwill, work stoppage, computer failure or malfunction, or any * 292 | * and all other commercial damages or losses, even if such party * 293 | * shall have been informed of the possibility of such damages. This * 294 | * limitation of liability shall not apply to liability for death or * 295 | * personal injury resulting from such party's negligence to the * 296 | * extent applicable law prohibits such limitation. Some * 297 | * jurisdictions do not allow the exclusion or limitation of * 298 | * incidental or consequential damages, so this exclusion and * 299 | * limitation may not apply to You. * 300 | * * 301 | ************************************************************************ 302 | 303 | 8. Litigation 304 | ------------- 305 | 306 | Any litigation relating to this License may be brought only in the 307 | courts of a jurisdiction where the defendant maintains its principal 308 | place of business and such litigation shall be governed by laws of that 309 | jurisdiction, without reference to its conflict-of-law provisions. 310 | Nothing in this Section shall prevent a party's ability to bring 311 | cross-claims or counter-claims. 312 | 313 | 9. Miscellaneous 314 | ---------------- 315 | 316 | This License represents the complete agreement concerning the subject 317 | matter hereof. If any provision of this License is held to be 318 | unenforceable, such provision shall be reformed only to the extent 319 | necessary to make it enforceable. 
Any law or regulation which provides 320 | that the language of a contract shall be construed against the drafter 321 | shall not be used to construe this License against a Contributor. 322 | 323 | 10. Versions of the License 324 | --------------------------- 325 | 326 | 10.1. New Versions 327 | 328 | Mozilla Foundation is the license steward. Except as provided in Section 329 | 10.3, no one other than the license steward has the right to modify or 330 | publish new versions of this License. Each version will be given a 331 | distinguishing version number. 332 | 333 | 10.2. Effect of New Versions 334 | 335 | You may distribute the Covered Software under the terms of the version 336 | of the License under which You originally received the Covered Software, 337 | or under the terms of any subsequent version published by the license 338 | steward. 339 | 340 | 10.3. Modified Versions 341 | 342 | If you create software not governed by this License, and you want to 343 | create a new license for such software, you may create and use a 344 | modified version of this License if you rename the license and remove 345 | any references to the name of the license steward (except to note that 346 | such modified license differs from this License). 347 | 348 | 10.4. Distributing Source Code Form that is Incompatible With Secondary 349 | Licenses 350 | 351 | If You choose to distribute Source Code Form that is Incompatible With 352 | Secondary Licenses under the terms of this version of the License, the 353 | notice described in Exhibit B of this License must be attached. 354 | 355 | Exhibit A - Source Code Form License Notice 356 | ------------------------------------------- 357 | 358 | This Source Code Form is subject to the terms of the Mozilla Public 359 | License, v. 2.0. If a copy of the MPL was not distributed with this 360 | file, You can obtain one at http://mozilla.org/MPL/2.0/. 
361 | 362 | If it is not possible or desirable to put the notice in a particular 363 | file, then You may include the notice in a location (such as a LICENSE 364 | file in a relevant directory) where a recipient would be likely to look 365 | for such a notice. 366 | 367 | You may add additional accurate notices of copyright ownership. 368 | 369 | Exhibit B - "Incompatible With Secondary Licenses" Notice 370 | --------------------------------------------------------- 371 | 372 | This Source Code Form is "Incompatible With Secondary Licenses", as 373 | defined by the Mozilla Public License, v. 2.0. 374 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # taint 2 | 3 | Implements static [taint analysis](https://en.wikipedia.org/wiki/Taint_checking) for Go programs. 4 | 5 | Taint analysis is a technique for identifying the flow of sensitive data through a program. 6 | It can be used to identify potential security vulnerabilities, such as SQL injection or 7 | cross-site scripting (XSS) attacks, by understanding how this data is used and transformed 8 | as it flows through the code. 9 | 10 | A "**source**" is a point in the program where sensitive data originates, typically from user 11 | input, such as data entered into a form on a web page, or data loaded from an external source. 12 | A "**sink**" is a point in the program where sensitive data is used in a way that could be exploited 13 | if it has not been validated or sanitized, such as a database query, a log message, or an HTTP response. 14 | 15 | ## Example 16 | 17 | This code generates a function call graph rooted at a program's `main` function and 18 | then runs taint analysis on it. If the program uses `database/sql`, the taint analysis 19 | will determine if the program is vulnerable to SQL injection by checking whether any of the given 20 | sources reach the given sinks. 21 | 22 | ```go 23 | cg, _ := callgraph.New(mainFn, buildSSA.SrcFuncs...)
24 | 25 | sources := taint.NewSources( 26 | "*net/http.Request", 27 | ) 28 | 29 | sinks := taint.NewSinks( 30 | "(*database/sql.DB).Query", 31 | "(*database/sql.DB).QueryContext", 32 | "(*database/sql.DB).QueryRow", 33 | "(*database/sql.DB).QueryRowContext", 34 | "(*database/sql.Tx).Query", 35 | "(*database/sql.Tx).QueryContext", 36 | "(*database/sql.Tx).QueryRow", 37 | "(*database/sql.Tx).QueryRowContext", 38 | ) 39 | 40 | results, _ := taint.Check(cg, sources, sinks) 41 | 42 | for _, result := range results { 43 | // We found a query edge that is tainted by user input, is it 44 | // doing this safely? We expect this to be safely done by 45 | // providing a prepared statement as a constant in the query 46 | // (first argument after context). 47 | queryEdge := result.Path[len(result.Path)-1] 48 | 49 | // Get the query arguments, skipping the first element, pointer to the DB. 50 | queryArgs := queryEdge.Site.Common().Args[1:] 51 | 52 | // Skip the context argument, if using a *Context query variant. 53 | if strings.HasSuffix(queryEdge.Site.Common().Value.String(), "Context") { 54 | queryArgs = queryArgs[1:] 55 | } 56 | 57 | // Get the query function parameter. 58 | query := queryArgs[0] 59 | 60 | // Ensure it is a constant (prepared statement), otherwise report 61 | // potential SQL injection. 62 | if _, isConst := query.(*ssa.Const); !isConst { 63 | pass.Reportf(result.SinkValue.Pos(), "potential sql injection") 64 | } 65 | } 66 | ``` 67 | 68 | ### `taint` 69 | 70 | The `taint` CLI is an interactive tool to find potential security vulnerabilities. It can be used 71 | to find potential SQL injections, log injections, and cross-site scripting (XSS) vulnerabilities, 72 | among other types of vulnerabilities.
73 | 74 | ```console 75 | $ go install github.com/picatz/taint/cmd/taint@latest 76 | ``` 77 | 78 | ![demo](./cmd/taint/vhs/demo.gif) 79 | 80 | ### `sqli` 81 | 82 | The `sqli` [analyzer](https://pkg.go.dev/golang.org/x/tools/go/analysis#Analyzer) finds potential SQL injections. 83 | 84 | ```console 85 | $ go install github.com/picatz/taint/cmd/sqli@latest 86 | ``` 87 | 88 | ```console 89 | $ cd sql/injection/testdata/src/example 90 | $ cat main.go 91 | package main 92 | 93 | import ( 94 | "database/sql" 95 | "net/http" 96 | ) 97 | 98 | func business(db *sql.DB, q string) { 99 | db.Query(q) // potential sql injection 100 | } 101 | 102 | func run() { 103 | db, _ := sql.Open("sqlite3", ":memory:") 104 | 105 | mux := http.NewServeMux() 106 | 107 | mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 108 | business(db, r.URL.Query().Get("sql-query")) 109 | }) 110 | 111 | http.ListenAndServe(":8080", mux) 112 | } 113 | 114 | func main() { 115 | run() 116 | } 117 | $ sqli main.go 118 | ./sql/injection/testdata/src/example/main.go:9:10: potential sql injection 119 | ``` 120 | 121 | ### `logi` 122 | 123 | The `logi` [analyzer](https://pkg.go.dev/golang.org/x/tools/go/analysis#Analyzer) finds potential log injections. 124 | 125 | ```console 126 | $ go install github.com/picatz/taint/cmd/logi@latest 127 | ``` 128 | 129 | ```console 130 | $ cd log/injection/testdata/src/a 131 | $ cat main.go 132 | package main 133 | 134 | import ( 135 | "log" 136 | "net/http" 137 | ) 138 | 139 | func main() { 140 | http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 141 | log.Println(r.URL.Query().Get("input")) 142 | }) 143 | 144 | http.ListenAndServe(":8080", nil) 145 | } 146 | $ logi main.go 147 | ./log/injection/testdata/src/a/main.go:10:14: potential log injection 148 | ``` 149 | 150 | ### `xss` 151 | 152 | The `xss` [analyzer](https://pkg.go.dev/golang.org/x/tools/go/analysis#Analyzer) finds potential cross-site scripting (XSS) vulnerabilities.
153 | 154 | ```console 155 | $ go install github.com/picatz/taint/cmd/xss@latest 156 | ``` 157 | 158 | ```console 159 | $ cd xss/testdata/src/a 160 | $ cat main.go 161 | package main 162 | 163 | import ( 164 | "net/http" 165 | ) 166 | 167 | func main() { 168 | http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 169 | w.Write([]byte(r.URL.Query().Get("input"))) // want "potential XSS" 170 | }) 171 | 172 | http.ListenAndServe(":8080", nil) 173 | } 174 | $ xss main.go 175 | ./xss/testdata/src/a/main.go:9:8: potential XSS 176 | ``` 177 | -------------------------------------------------------------------------------- /callgraphutil/aliases.go: -------------------------------------------------------------------------------- 1 | package callgraphutil 2 | 3 | import "golang.org/x/tools/go/callgraph" 4 | 5 | // Nodes is a handy alias for a slice of callgraph.Nodes. 6 | type Nodes = []*callgraph.Node 7 | 8 | // Edges is a handy alias for a slice of callgraph.Edges. 9 | type Edges = []*callgraph.Edge 10 | -------------------------------------------------------------------------------- /callgraphutil/calls.go: -------------------------------------------------------------------------------- 1 | package callgraphutil 2 | 3 | import "golang.org/x/tools/go/callgraph" 4 | 5 | // CalleesOf returns the unique nodes that are called by the caller node. 6 | // 7 | // Callees are returned in the order their first edge appears in caller.Out, 8 | // so repeated calls on an unchanged graph return the same order. 9 | func CalleesOf(caller *callgraph.Node) Nodes { 10 | seen := make(map[*callgraph.Node]bool, len(caller.Out)) 11 | callees := make(Nodes, 0, len(caller.Out)) 12 | for _, e := range caller.Out { 13 | // Several call sites may target the same callee; keep only the first. 14 | if seen[e.Callee] { 15 | continue 16 | } 17 | seen[e.Callee] = true 18 | callees = append(callees, e.Callee) 19 | } 20 | 21 | return callees 22 | } 23 | 24 | // CallersOf returns the unique nodes that call the callee node. 25 | // 26 | // Callers are returned in the order their first edge appears in callee.In, 27 | // so repeated calls on an unchanged graph return the same order. 28 | func CallersOf(callee *callgraph.Node) Nodes { 29 | seen := make(map[*callgraph.Node]bool, len(callee.In)) 30 | callers := make(Nodes, 0, len(callee.In)) 31 | for _, e := range callee.In { 32 | // Several call sites may originate in the same caller; keep only the first. 33 | if seen[e.Caller] { 34 | continue 35 | } 36 | seen[e.Caller] = true 37 | callers = append(callers, e.Caller) 38 | } 39 | 40 | return callers 41 | } 42 | --------------------------------------------------------------------------------
/callgraphutil/cosmograph.go: -------------------------------------------------------------------------------- 1 | package callgraphutil 2 | 3 | import ( 4 | "encoding/csv" 5 | "fmt" 6 | "io" 7 | 8 | "golang.org/x/tools/go/callgraph" 9 | ) 10 | 11 | // WriteCosmograph writes the given callgraph.Graph as two CSV streams that 12 | // can be used to generate a visual representation of the call graph using 13 | // Cosmograph: graph receives a "source,target" row per call edge, and 14 | // metadata receives an "id,pkg,func" row per node. 15 | // 16 | // https://cosmograph.app/run/ 17 | // 18 | // On success both writers have been flushed; any error from writing or 19 | // flushing either CSV stream is returned. 20 | func WriteCosmograph(graph, metadata io.Writer, g *callgraph.Graph) error { 21 | graphWriter := csv.NewWriter(graph) 22 | metadataWriter := csv.NewWriter(metadata) 23 | 24 | // Write header. 25 | if err := graphWriter.Write([]string{"source", "target"}); err != nil { 26 | return fmt.Errorf("failed to write header: %w", err) 27 | } 28 | 29 | // Write metadata header. 30 | if err := metadataWriter.Write([]string{"id", "pkg", "func"}); err != nil { 31 | return fmt.Errorf("failed to write metadata header: %w", err) 32 | } 33 | 34 | // Write edges. 35 | for _, n := range g.Nodes { 36 | // TODO: fix this so there's not so many "shared" functions? 37 | // 38 | // It is a bit of a hack, but it works for now. 39 | pkgPath := "shared" 40 | funcName := "<nil>" 41 | 42 | // Node.Func may be nil (e.g. the root of some call graphs), and 43 | // Func.Pkg is nil for shared functions such as wrappers; guard both 44 | // so neither is dereferenced. 45 | if n.Func != nil { 46 | funcName = n.Func.String() 47 | if n.Func.Pkg != nil { 48 | pkgPath = n.Func.Pkg.Pkg.Path() 49 | } 50 | } 51 | 52 | // Write metadata. 53 | if err := metadataWriter.Write([]string{ 54 | fmt.Sprintf("%d", n.ID), 55 | pkgPath, 56 | funcName, 57 | }); err != nil { 58 | return fmt.Errorf("failed to write metadata: %w", err) 59 | } 60 | 61 | for _, e := range n.Out { 62 | // Write edge. 63 | if err := graphWriter.Write([]string{ 64 | fmt.Sprintf("%d", n.ID), 65 | fmt.Sprintf("%d", e.Callee.ID), 66 | }); err != nil { 67 | return fmt.Errorf("failed to write edge: %w", err) 68 | } 69 | } 70 | } 71 | 72 | // csv.Writer buffers rows, so errors from the final Flush are only 73 | // visible via Error; check it instead of deferring Flush and discarding it. 74 | graphWriter.Flush() 75 | if err := graphWriter.Error(); err != nil { 76 | return fmt.Errorf("failed to flush graph: %w", err) 77 | } 78 | 79 | metadataWriter.Flush() 80 | if err := metadataWriter.Error(); err != nil { 81 | return fmt.Errorf("failed to flush metadata: %w", err) 82 | } 83 | 84 | return nil 85 | } 86 | --------------------------------------------------------------------------------
/callgraphutil/cosmograph_test.go: -------------------------------------------------------------------------------- 1 | package callgraphutil_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | "testing" 8 | 9 | "github.com/picatz/taint/callgraphutil" 10 | ) 11 | 12 | func TestWriteCosmograph(t *testing.T) { 13 | var ( 14 | ownerName = "picatz" 15 | repoName = "taint" 16 | ) 17 | 18 | repo, _, err := cloneGitHubRepository(context.Background(), ownerName, repoName) 19 | if err != nil { 20 | t.Fatal(err) 21 | } 22 | 23 | pkgs, err := loadPackages(context.Background(), repo, "./...") 24 | if err != nil { 25 | t.Fatal(err) 26 | } 27 | 28 | mainFn, srcFns, err := loadSSA(context.Background(), pkgs) 29 | if err != nil { 30 | t.Fatal(err) 31 | } 32 | 33 | cg, err := loadCallGraph(context.Background(), mainFn, srcFns) 34 | if err != nil { 35 | t.Fatal(err) 36 | } 37 | 38 | graphOutput, err := os.Create(fmt.Sprintf("%s.csv", repoName)) 39 | if err != nil { 40 | t.Fatal(err) 41 | } 42 | defer graphOutput.Close() 43 | 44 | metadataOutput, err := os.Create(fmt.Sprintf("%s-metadata.csv", repoName)) 45 | if err != nil { 46 | t.Fatal(err) 47 | } 48 | defer metadataOutput.Close() 49 | 50 | err = callgraphutil.WriteCosmograph(graphOutput, metadataOutput, cg) 51 | if err != nil { 52 | t.Fatal(err) 53 | } 54 | } 55 | --------------------------------------------------------------------------------
/callgraphutil/csv.go: -------------------------------------------------------------------------------- 1 | package callgraphutil 2 | 3 | import ( 4 | "encoding/csv" 5 | "fmt" 6 | "io" 7 | "runtime" 8 | "strings" 9 | 10 | "golang.org/x/tools/go/callgraph" 11 | ) 12 | 13 | // WriteCSV writes the given callgraph.Graph to the given io.Writer in CSV 14 | // format, one record per call edge describing its source and target nodes. 15 | // This format can be used to generate a visual representation of the call 16 | // graph using many different tools. 17 | // 18 | // On success the output has been flushed to w; any error from writing or 19 | // flushing the CSV stream is returned. 20 | func WriteCSV(w io.Writer, g *callgraph.Graph) error { 21 | cw := csv.NewWriter(w) 22 | 23 | // Write header. 24 | if err := cw.Write([]string{ 25 | "source_pkg", 26 | "source_pkg_go_version", 27 | "source_pkg_origin", 28 | "source_func", 29 | "source_func_name", 30 | "source_func_signature", 31 | "target_pkg", 32 | "target_pkg_go_version", 33 | "target_pkg_origin", 34 | "target_func", 35 | "target_func_name", 36 | "target_func_signature", 37 | }); err != nil { 38 | return fmt.Errorf("failed to write header: %w", err) 39 | } 40 | 41 | // Write edges. 42 | for _, n := range g.Nodes { 43 | source, err := getNodeInfo(n) 44 | if err != nil { 45 | return fmt.Errorf("failed to get node info: %w", err) 46 | } 47 | 48 | for _, e := range n.Out { 49 | target, err := getNodeInfo(e.Callee) 50 | if err != nil { 51 | return fmt.Errorf("failed to get node info: %w", err) 52 | } 53 | 54 | // CSV builds a new slice on every call, so appending to it is safe. 55 | record := append(source.CSV(), target.CSV()...) 56 | 57 | // Write edge. 58 | if err := cw.Write(record); err != nil { 59 | return fmt.Errorf("failed to write edge: %w", err) 60 | } 61 | } 62 | } 63 | 64 | // csv.Writer buffers rows, so errors from the final Flush are only 65 | // visible via Error; check it instead of deferring Flush and discarding it. 66 | cw.Flush() 67 | if err := cw.Error(); err != nil { 68 | return fmt.Errorf("failed to flush CSV: %w", err) 69 | } 70 | 71 | return nil 72 | } 73 | 74 | // nodeInfo is a struct that contains information about a callgraph.Node used 75 | // to generate CSV output.
76 | type nodeInfo struct { 77 | pkgPath string 78 | pkgGoVersion string 79 | pkgOrigin string 80 | pkgFunc string 81 | pkgFuncName string 82 | pkgFuncSignature string 83 | } 84 | 85 | // CSV returns a single record for the node; its columns match one half 86 | // (source or target) of the WriteCSV header. 87 | func (n *nodeInfo) CSV() []string { 88 | return []string{ 89 | n.pkgPath, 90 | n.pkgGoVersion, 91 | n.pkgOrigin, 92 | n.pkgFunc, 93 | n.pkgFuncName, 94 | n.pkgFuncSignature, 95 | } 96 | } 97 | 98 | // getNodeInfo returns a nodeInfo struct for the given callgraph.Node. 99 | // 100 | // Fields that cannot be determined keep their placeholders: "unknown" for 101 | // the package path and origin, empty for the function name and signature, 102 | // and runtime.Version() for the Go version. The error is currently always nil. 103 | func getNodeInfo(n *callgraph.Node) (*nodeInfo, error) { 104 | info := &nodeInfo{ 105 | pkgPath: "unknown", 106 | pkgGoVersion: runtime.Version(), 107 | pkgOrigin: "unknown", 108 | } 109 | 110 | // Node.Func may be nil (e.g. the root of some call graphs); report such 111 | // a node without dereferencing it. 112 | if n.Func == nil { 113 | info.pkgFunc = "<nil>" 114 | return info, nil 115 | } 116 | 117 | info.pkgFunc = n.Func.String() 118 | info.pkgFuncName = n.Func.Name() 119 | info.pkgFuncSignature = n.Func.Signature.String() 120 | 121 | // Shared functions (e.g. wrappers) have no package; keep the package 122 | // placeholders rather than deriving an origin from "unknown". 123 | if n.Func.Pkg == nil { 124 | return info, nil 125 | } 126 | 127 | info.pkgPath = n.Func.Pkg.Pkg.Path() 128 | if goVersion := n.Func.Pkg.Pkg.GoVersion(); goVersion != "" { 129 | info.pkgGoVersion = goVersion 130 | } 131 | 132 | // A path containing a dot is assumed to be third-party, with its first 133 | // element (usually a domain name) as the origin. Otherwise it's probably 134 | // a standard library package, a pattern used and seen elsewhere. 135 | if strings.Contains(info.pkgPath, ".") { 136 | info.pkgOrigin, _, _ = strings.Cut(info.pkgPath, "/") 137 | } else { 138 | info.pkgOrigin = "stdlib" 139 | } 140 | 141 | return info, nil 142 | } 143 | --------------------------------------------------------------------------------
/callgraphutil/csv_test.go: -------------------------------------------------------------------------------- 1 | package callgraphutil_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | "testing" 8 | 9 | "github.com/picatz/taint/callgraphutil" 10 | ) 11 | 12 | func TestWriteCSV(t *testing.T) { 13 | var ( 14 | ownerName = "picatz" 15 | repoName = "taint" 16 | ) 17 | 18 | repo, _, err := cloneGitHubRepository(context.Background(), ownerName, repoName) 19 | if err != nil { 20 | t.Fatal(err) 21 | } 22 | 23 | pkgs, err := loadPackages(context.Background(), repo, "./...") 24 | if err != nil { 25 | t.Fatal(err) 26 | } 27 | 28 | mainFn, srcFns, err := loadSSA(context.Background(), pkgs) 29 | if err != nil { 30 | t.Fatal(err) 31 | } 32 | 33 | cg, err := loadCallGraph(context.Background(), mainFn, srcFns) 34 | if err != nil { 35 | t.Fatal(err) 36 | } 37 | 38 | fh, err := os.Create(fmt.Sprintf("%s.csv", repoName)) 39 | if err != nil { 40 | t.Fatal(err) 41 | } 42 | defer fh.Close() 43 | 44 | err = callgraphutil.WriteCSV(fh, cg) 45 | if err != nil { 46 | t.Fatal(err) 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /callgraphutil/doc.go: -------------------------------------------------------------------------------- 1 | // Package callgraphutil implements utilities for golang.org/x/tools/go/callgraph 2 | // including path searching, graph construction, printing, and more.
3 | package callgraphutil 4 | -------------------------------------------------------------------------------- /callgraphutil/dot.go: -------------------------------------------------------------------------------- 1 | package callgraphutil 2 | 3 | import ( 4 | "bufio" 5 | "fmt" 6 | "io" 7 | "strings" 8 | 9 | "golang.org/x/tools/go/callgraph" 10 | ) 11 | 12 | // WriteDOT writes the given callgraph.Graph to the given io.Writer in the 13 | // DOT format, which can be used to generate a visual representation of the 14 | // call graph using Graphviz. 15 | func WriteDOT(w io.Writer, g *callgraph.Graph) error { 16 | b := bufio.NewWriter(w) 17 | defer b.Flush() 18 | 19 | b.WriteString("digraph callgraph {\n") 20 | b.WriteString("\tgraph [fontname=\"Helvetica\", overlap=false normalize=true];\n") 21 | b.WriteString("\tnode [fontname=\"Helvetica\" shape=box];\n") 22 | b.WriteString("\tedge [fontname=\"Helvetica\"];\n") 23 | 24 | edges := []*callgraph.Edge{} 25 | 26 | nodesByPkg := map[string][]*callgraph.Node{} 27 | 28 | addPkgNode := func(n *callgraph.Node) { 29 | // TODO: fix this so there's not so many "shared" functions? 30 | // 31 | // It is a bit of a hack, but it works for now. 32 | var pkgPath string 33 | if n.Func.Pkg != nil { 34 | pkgPath = n.Func.Pkg.Pkg.Path() 35 | } else { 36 | pkgPath = "shared" 37 | } 38 | 39 | // Check if the package already exists. 40 | if _, ok := nodesByPkg[pkgPath]; !ok { 41 | // If not, create it. 42 | nodesByPkg[pkgPath] = []*callgraph.Node{} 43 | } 44 | nodesByPkg[pkgPath] = append(nodesByPkg[pkgPath], n) 45 | } 46 | 47 | // Check if root node exists, if so, write it. 48 | if g.Root != nil { 49 | b.WriteString(fmt.Sprintf("\troot = %d;\n", g.Root.ID)) 50 | } 51 | 52 | // Process nodes and edges. 53 | for _, n := range g.Nodes { 54 | // Add node to map of nodes by package. 55 | addPkgNode(n) 56 | 57 | // Add edges 58 | edges = append(edges, n.Out...) 59 | } 60 | 61 | // Write nodes by package. 
62 | for pkg, nodes := range nodesByPkg { 63 | // Make the pkg name sugraph cluster friendly (remove dots, dashes, and slashes). 64 | clusterName := strings.Replace(pkg, ".", "_", -1) 65 | clusterName = strings.Replace(clusterName, "/", "_", -1) 66 | clusterName = strings.Replace(clusterName, "-", "_", -1) 67 | 68 | // NOTE: even if we're using a subgraph cluster, it may not be 69 | // respected by all Graphviz layout engines. For example, the 70 | // "dot" engine will respect the cluster, but the "sfdp" engine 71 | // will not. 72 | b.WriteString(fmt.Sprintf("\tsubgraph cluster_%s {\n", clusterName)) 73 | b.WriteString(fmt.Sprintf("\t\tlabel=%q;\n", pkg)) 74 | for _, n := range nodes { 75 | b.WriteString(fmt.Sprintf("\t\t%d [label=%q];\n", n.ID, n.Func)) 76 | } 77 | b.WriteString("\t}\n") 78 | } 79 | 80 | // Write edges. 81 | for _, e := range edges { 82 | b.WriteString(fmt.Sprintf("\t%d -> %d;\n", e.Caller.ID, e.Callee.ID)) 83 | } 84 | 85 | b.WriteString("}\n") 86 | 87 | return nil 88 | } 89 | -------------------------------------------------------------------------------- /callgraphutil/dot_test.go: -------------------------------------------------------------------------------- 1 | package callgraphutil_test 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "fmt" 7 | "go/ast" 8 | "go/parser" 9 | "go/token" 10 | "os" 11 | "testing" 12 | 13 | "github.com/go-git/go-git/v5" 14 | "github.com/picatz/taint/callgraphutil" 15 | "golang.org/x/tools/go/callgraph" 16 | "golang.org/x/tools/go/packages" 17 | "golang.org/x/tools/go/ssa" 18 | "golang.org/x/tools/go/ssa/ssautil" 19 | ) 20 | 21 | func cloneGitHubRepository(ctx context.Context, ownerName, repoName string) (string, string, error) { 22 | // Get the owner and repo part of the URL. 23 | ownerAndRepo := ownerName + "/" + repoName 24 | 25 | // Get the directory path. 
26 | dir, err := os.MkdirTemp(os.TempDir(), fmt.Sprintf("callgraphutil_csv-%s-%s", ownerName, repoName)) 27 | if err != nil { 28 | return "", "", fmt.Errorf("failed to create temp dir: %w", err) 29 | } 30 | 31 | // Clone the repository. 32 | repo, err := git.PlainCloneContext(ctx, dir, false, &git.CloneOptions{ 33 | URL: fmt.Sprintf("https://github.com/%s", ownerAndRepo), 34 | Depth: 1, 35 | Tags: git.NoTags, 36 | SingleBranch: true, 37 | }) 38 | if err != nil { 39 | return dir, "", fmt.Errorf("%w", err) 40 | } 41 | 42 | // Get the repository's HEAD. 43 | head, err := repo.Head() 44 | if err != nil { 45 | return dir, "", fmt.Errorf("%w", err) 46 | } 47 | 48 | return dir, head.Hash().String(), nil 49 | } 50 | 51 | func loadPackages(ctx context.Context, dir, pattern string) ([]*packages.Package, error) { 52 | loadMode := 53 | packages.NeedName | 54 | packages.NeedDeps | 55 | packages.NeedFiles | 56 | packages.NeedModule | 57 | packages.NeedTypes | 58 | packages.NeedImports | 59 | packages.NeedSyntax | 60 | packages.NeedTypesInfo 61 | // packages.NeedTypesSizes | 62 | // packages.NeedCompiledGoFiles | 63 | // packages.NeedExportFile | 64 | // packages.NeedEmbedPatterns 65 | 66 | // parseMode := parser.ParseComments 67 | parseMode := parser.SkipObjectResolution 68 | 69 | // patterns := []string{dir} 70 | patterns := []string{pattern} 71 | // patterns := []string{"all"} 72 | 73 | pkgs, err := packages.Load(&packages.Config{ 74 | Mode: loadMode, 75 | Context: ctx, 76 | Env: os.Environ(), 77 | Dir: dir, 78 | Tests: false, 79 | ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { 80 | return parser.ParseFile(fset, filename, src, parseMode) 81 | }, 82 | }, patterns...) 
83 | if err != nil { 84 | return nil, err 85 | } 86 | 87 | return pkgs, nil 88 | 89 | } 90 | 91 | func loadSSA(ctx context.Context, pkgs []*packages.Package) (mainFn *ssa.Function, srcFns []*ssa.Function, err error) { 92 | ssaBuildMode := ssa.InstantiateGenerics // ssa.SanityCheckFunctions | ssa.GlobalDebug 93 | 94 | // Analyze the package. 95 | ssaProg, ssaPkgs := ssautil.Packages(pkgs, ssaBuildMode) 96 | 97 | // It's possible that the ssaProg is nil? 98 | if ssaProg == nil { 99 | err = fmt.Errorf("failed to create new ssa program") 100 | return 101 | } 102 | 103 | ssaProg.Build() 104 | 105 | for _, pkg := range ssaPkgs { 106 | if pkg == nil { 107 | continue 108 | } 109 | pkg.Build() 110 | } 111 | 112 | // Remove nil ssaPkgs by iterating over the slice of packages 113 | // and for each nil package, we append the slice up to that 114 | // index and then append the slice from the next index to the 115 | // end of the slice. This effectively removes the nil package 116 | // from the slice without having to allocate a new slice. 117 | for i := 0; i < len(ssaPkgs); i++ { 118 | if ssaPkgs[i] == nil { 119 | ssaPkgs = append(ssaPkgs[:i], ssaPkgs[i+1:]...) 
120 | i-- 121 | } 122 | } 123 | 124 | mainPkgs := ssautil.MainPackages(ssaPkgs) 125 | 126 | mainFn = mainPkgs[0].Members["main"].(*ssa.Function) 127 | 128 | for _, pkg := range ssaPkgs { 129 | for _, fn := range pkg.Members { 130 | if fn.Object() == nil { 131 | continue 132 | } 133 | 134 | if fn.Object().Name() == "_" { 135 | continue 136 | } 137 | 138 | pkgFn := pkg.Func(fn.Object().Name()) 139 | if pkgFn == nil { 140 | continue 141 | } 142 | 143 | var addAnons func(f *ssa.Function) 144 | addAnons = func(f *ssa.Function) { 145 | srcFns = append(srcFns, f) 146 | for _, anon := range f.AnonFuncs { 147 | addAnons(anon) 148 | } 149 | } 150 | addAnons(pkgFn) 151 | } 152 | } 153 | 154 | if mainFn == nil { 155 | err = fmt.Errorf("failed to find main function") 156 | return 157 | } 158 | 159 | return 160 | } 161 | 162 | func loadCallGraph(ctx context.Context, mainFn *ssa.Function, srcFns []*ssa.Function) (*callgraph.Graph, error) { 163 | cg, err := callgraphutil.NewGraph(mainFn, srcFns...) 164 | if err != nil { 165 | return nil, fmt.Errorf("failed to create new callgraph: %w", err) 166 | } 167 | 168 | return cg, nil 169 | } 170 | 171 | func TestWriteDOT(t *testing.T) { 172 | repo, _, err := cloneGitHubRepository(context.Background(), "picatz", "taint") 173 | if err != nil { 174 | t.Fatal(err) 175 | } 176 | 177 | pkgs, err := loadPackages(context.Background(), repo, "./...") 178 | if err != nil { 179 | t.Fatal(err) 180 | } 181 | 182 | mainFn, srcFns, err := loadSSA(context.Background(), pkgs) 183 | if err != nil { 184 | t.Fatal(err) 185 | } 186 | 187 | cg, err := loadCallGraph(context.Background(), mainFn, srcFns) 188 | if err != nil { 189 | t.Fatal(err) 190 | } 191 | 192 | output := &bytes.Buffer{} 193 | 194 | err = callgraphutil.WriteDOT(output, cg) 195 | if err != nil { 196 | t.Fatal(err) 197 | } 198 | 199 | fmt.Println(output.String()) 200 | } 201 | -------------------------------------------------------------------------------- /callgraphutil/graph.go: 
-------------------------------------------------------------------------------- 1 | package callgraphutil 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "go/types" 7 | 8 | "golang.org/x/tools/go/callgraph" 9 | "golang.org/x/tools/go/ssa" 10 | "golang.org/x/tools/go/ssa/ssautil" 11 | ) 12 | 13 | // GraphString returns a string representation of the call graph, 14 | // which is a sequence of nodes separated by newlines, with the 15 | // callees of each node indented by a tab. 16 | func GraphString(g *callgraph.Graph) string { 17 | var buf bytes.Buffer 18 | 19 | for _, n := range g.Nodes { 20 | fmt.Fprintf(&buf, "%s\n", n) 21 | for _, e := range n.Out { 22 | fmt.Fprintf(&buf, "\t→ %s\n", e.Callee) 23 | } 24 | fmt.Fprintf(&buf, "\n") 25 | } 26 | 27 | return buf.String() 28 | } 29 | 30 | // NewGraph returns a new Graph with the specified root node. 31 | // 32 | // Typically, the root node is the main function of the program, and the 33 | // srcFns are the source functions that are of interest to the caller. But, the root 34 | // node can be any function, and the srcFns can be any set of functions. 35 | // 36 | // This algorithm attempts to add all source functions reachable from the root node 37 | // by traversing the SSA IR and adding edges to the graph; it handles calls 38 | // to functions, methods, closures, and interfaces. It may miss some complex 39 | // edges today, such as stucts containing function fields accessed via slice or map 40 | // indexing. This is a known limitation, but something we hope to improve in the near future. 
41 | // https://github.com/picatz/taint/issues/23 42 | func NewGraph(root *ssa.Function, srcFns ...*ssa.Function) (*callgraph.Graph, error) { 43 | g := &callgraph.Graph{ 44 | Nodes: make(map[*ssa.Function]*callgraph.Node), 45 | } 46 | 47 | g.Root = g.CreateNode(root) 48 | 49 | allFns := ssautil.AllFunctions(root.Prog) 50 | 51 | for _, srcFn := range srcFns { 52 | // debug("adding src function %d/%d: %v\n", i+1, len(srcFns), srcFn) 53 | 54 | err := AddFunction(g, srcFn, allFns) 55 | if err != nil { 56 | return g, fmt.Errorf("failed to add src function %v: %w", srcFn, err) 57 | } 58 | 59 | for _, block := range srcFn.DomPreorder() { 60 | for _, instr := range block.Instrs { 61 | checkBlockInstruction(root, allFns, g, srcFn, instr) 62 | } 63 | } 64 | } 65 | 66 | return g, nil 67 | } 68 | 69 | // checkBlockInstruction checks the given instruction for any function calls, adding 70 | // edges to the call graph as needed and recursively adding any new functions to the graph 71 | // that are discovered during the process (typically via interface methods). 72 | func checkBlockInstruction(root *ssa.Function, allFns map[*ssa.Function]bool, g *callgraph.Graph, fn *ssa.Function, instr ssa.Instruction) error { 73 | // debug("\tcheckBlockInstruction: %v\n", instr) 74 | switch instrt := instr.(type) { 75 | case *ssa.Call: 76 | var instrCall *ssa.Function 77 | 78 | switch callt := instrt.Call.Value.(type) { 79 | case *ssa.Function: 80 | instrCall = callt 81 | 82 | for _, instrtCallArg := range instrt.Call.Args { 83 | switch instrtCallArgt := instrtCallArg.(type) { 84 | case *ssa.ChangeInterface: 85 | // Track type casts through matching interface methods. 
86 | // 87 | // # Example 88 | // 89 | // func buffer(r io.Reader) io.Reader { 90 | // return bufio.NewReader(r) 91 | // } 92 | // 93 | // func mirror(w http.ResponseWriter, r *http.Request) { 94 | // _, err := io.Copy(w, buffer(r.Body)) // w is an http.ResponseWriter, convert to io.Writer for io.Copy 95 | // if err != nil { 96 | // panic(err) 97 | // } 98 | // } 99 | // 100 | // io.Copy is called with an io.Writer, but the underlying type is a net/http.ResponseWriter. 101 | // 102 | // n11:net/http.HandleFunc → n1:c.mirror → n5:io.Copy → n6:(io.Writer).Write → n7:(net/http.ResponseWriter).Write 103 | // 104 | switch argtt := instrtCallArgt.Type().Underlying().(type) { 105 | case *types.Interface: 106 | numMethods := argtt.NumMethods() 107 | 108 | for i := 0; i < numMethods; i++ { 109 | method := argtt.Method(i) 110 | 111 | methodPkg := method.Pkg() 112 | if methodPkg == nil { 113 | // Universe scope method, such as "error.Error". 114 | continue 115 | } 116 | 117 | pkg := root.Prog.ImportedPackage(method.Pkg().Path()) 118 | if pkg == nil { 119 | // This is a method from a package that is not imported, so we skip it. 
120 | continue 121 | } 122 | fn := pkg.Func(method.Name()) 123 | if fn == nil { 124 | fn = pkg.Prog.NewFunction(method.Name(), method.Type().(*types.Signature), "callgraph") 125 | } 126 | 127 | callgraph.AddEdge(g.CreateNode(instrCall), instrt, g.CreateNode(fn)) 128 | 129 | switch xType := instrtCallArgt.X.Type().(type) { 130 | case *types.Named: 131 | named := xType 132 | 133 | pkg2 := root.Prog.ImportedPackage(named.Obj().Pkg().Path()) 134 | 135 | methodSet := pkg2.Prog.MethodSets.MethodSet(named) 136 | methodSel := methodSet.Lookup(pkg2.Pkg, method.Name()) 137 | 138 | if methodSel == nil { 139 | continue 140 | } 141 | 142 | methodType := methodSel.Type().(*types.Signature) 143 | 144 | fn2 := pkg2.Func(method.Name()) 145 | if fn2 == nil { 146 | fn2 = pkg2.Prog.NewFunction(method.Name(), methodType, "callgraph") 147 | } 148 | 149 | callgraph.AddEdge(g.CreateNode(fn), instrt, g.CreateNode(fn2)) 150 | default: 151 | continue 152 | } 153 | } 154 | } 155 | } 156 | } 157 | case *ssa.MakeClosure: 158 | switch calltFn := callt.Fn.(type) { 159 | case *ssa.Function: 160 | instrCall = calltFn 161 | } 162 | case *ssa.Parameter: 163 | // This is likely a method call, so we need to 164 | // get the function from the method receiver which 165 | // is not available directly from the call instruction, 166 | // but rather from the package level function. 167 | 168 | // Skip this instruction if we could not determine 169 | // the function being called. 170 | if !instrt.Call.IsInvoke() || (instrt.Call.Method == nil) { 171 | return nil 172 | } 173 | 174 | // TODO: should we share the resulting function? 175 | instrtCallMethodPkg := instrt.Call.Method.Pkg() 176 | if instrtCallMethodPkg == nil { 177 | // This is an interface method call from the universe scope, such as "error.Error", 178 | // so we return nil to skip this instruction, which we will assume is safe. 
179 | return nil 180 | } else { 181 | pkg := root.Prog.ImportedPackage(instrt.Call.Method.Pkg().Path()) 182 | 183 | fn := pkg.Func(instrt.Call.Method.Name()) 184 | if fn == nil { 185 | fn = pkg.Prog.NewFunction(instrt.Call.Method.Name(), instrt.Call.Signature(), "callgraph") 186 | } 187 | instrCall = fn 188 | } 189 | default: 190 | // case *ssa.TypeAssert: ?? 191 | // fmt.Printf("unknown call type: %v: %[1]T\n", callt) 192 | } 193 | 194 | // If we could not determine the function being 195 | // called, skip this instruction. 196 | if instrCall == nil { 197 | return nil 198 | } 199 | 200 | callgraph.AddEdge(g.CreateNode(fn), instrt, g.CreateNode(instrCall)) 201 | 202 | err := AddFunction(g, instrCall, allFns) 203 | if err != nil { 204 | return fmt.Errorf("failed to add function %v from block instr: %w", instrCall, err) 205 | } 206 | 207 | // attempt to link function arguments that are functions 208 | for a := 0; a < len(instrt.Call.Args); a++ { 209 | arg := instrt.Call.Args[a] 210 | switch argt := arg.(type) { 211 | case *ssa.Function: 212 | // TODO: check if edge already exists? 213 | callgraph.AddEdge(g.CreateNode(instrCall), instrt, g.CreateNode(argt)) 214 | case *ssa.MakeClosure: 215 | switch argtFn := argt.Fn.(type) { 216 | case *ssa.Function: 217 | callgraph.AddEdge(g.CreateNode(instrCall), instrt, g.CreateNode(argtFn)) 218 | } 219 | } 220 | } 221 | } 222 | 223 | // Delete duplicate edges that may have been added, which is a responsibility of the caller 224 | // when using the callgraph.AddEdge function directly. 225 | for _, n := range g.Nodes { 226 | // debug("checking node %v\n", n) 227 | for i := 0; i < len(n.Out); i++ { 228 | for j := i + 1; j < len(n.Out); j++ { 229 | if n.Out[i].Callee == n.Out[j].Callee { 230 | // debug("deleting duplicate edge %v\n", n.Out[j]) 231 | n.Out = append(n.Out[:j], n.Out[j+1:]...) 
232 | j-- 233 | } 234 | } 235 | } 236 | } 237 | 238 | return nil 239 | } 240 | 241 | // AddFunction analyzes the given target SSA function, adding information to the call graph. 242 | // 243 | // Based on the implementation of golang.org/x/tools/cmd/guru/callers.go: 244 | // https://cs.opensource.google/go/x/tools/+/master:cmd/guru/callers.go;drc=3e0d083b858b3fdb7d095b5a3deb184aa0a5d35e;bpv=1;bpt=1;l=90 245 | func AddFunction(cg *callgraph.Graph, target *ssa.Function, allFns map[*ssa.Function]bool) error { 246 | // debug("\tAddFunction: %v (all funcs %d)\n", target, len(allFns)) 247 | 248 | // First check if we have already processed this function. 249 | if _, ok := cg.Nodes[target]; ok { 250 | return nil 251 | } 252 | 253 | targetNode := cg.CreateNode(target) 254 | 255 | // Find receiver type (for methods). 256 | var recvType types.Type 257 | if recv := target.Signature.Recv(); recv != nil { 258 | recvType = recv.Type() 259 | } 260 | 261 | if len(allFns) == 0 { 262 | allFns = ssautil.AllFunctions(target.Prog) 263 | } 264 | 265 | // Find all direct calls to function, 266 | // or a place where its address is taken. 267 | for progFn := range allFns { 268 | var space [32]*ssa.Value // preallocate 269 | 270 | for _, block := range progFn.DomPreorder() { 271 | for _, instr := range block.Instrs { 272 | // Is this a method (T).f of a concrete type T 273 | // whose runtime type descriptor is address-taken? 274 | // (To be fully sound, we would have to check that 275 | // the type doesn't make it to reflection as a 276 | // subelement of some other address-taken type.) 277 | if recvType != nil { 278 | if mi, ok := instr.(*ssa.MakeInterface); ok { 279 | if types.Identical(mi.X.Type(), recvType) { 280 | 281 | return nil // T is address-taken 282 | } 283 | if ptr, ok := mi.X.Type().(*types.Pointer); ok && 284 | types.Identical(ptr.Elem(), recvType) { 285 | return nil // *T is address-taken 286 | } 287 | } 288 | } 289 | 290 | // Direct call to target? 
291 | rands := instr.Operands(space[:0]) 292 | if site, ok := instr.(ssa.CallInstruction); ok && site.Common().Value == target { 293 | callgraph.AddEdge(cg.CreateNode(progFn), site, targetNode) 294 | rands = rands[1:] // skip .Value (rands[0]) 295 | } 296 | 297 | // Address-taken? 298 | for _, rand := range rands { 299 | if rand != nil && *rand == target { 300 | return nil 301 | } 302 | } 303 | } 304 | } 305 | } 306 | 307 | return nil 308 | } 309 | -------------------------------------------------------------------------------- /callgraphutil/graph_vulncheck.go: -------------------------------------------------------------------------------- 1 | package callgraphutil 2 | 3 | import ( 4 | "context" 5 | 6 | "golang.org/x/tools/go/callgraph" 7 | "golang.org/x/tools/go/callgraph/cha" 8 | "golang.org/x/tools/go/callgraph/vta" 9 | "golang.org/x/tools/go/ssa" 10 | "golang.org/x/tools/go/ssa/ssautil" 11 | ) 12 | 13 | // NewVulncheckCallGraph builds a call graph of prog based on VTA analysis, 14 | // straight from the govulncheck project. This is used to demonstrate the 15 | // difference between the call graph built by this package's algorithm and 16 | // govulncheck's algorithm (based on CHA and VTA analysis). 17 | // 18 | // This method is based on the following: 19 | // https://github.com/golang/vuln/blob/7335627909c99e391cf911fcd214badcb8aa6d7d/internal/vulncheck/utils.go#L63 20 | func NewVulncheckCallGraph(ctx context.Context, prog *ssa.Program, entries []*ssa.Function) (*callgraph.Graph, error) { 21 | entrySlice := make(map[*ssa.Function]bool) 22 | for _, e := range entries { 23 | entrySlice[e] = true 24 | } 25 | 26 | if err := ctx.Err(); err != nil { // cancelled? 27 | return nil, err 28 | } 29 | initial := cha.CallGraph(prog) 30 | allFuncs := ssautil.AllFunctions(prog) 31 | 32 | fslice := forwardSlice(entrySlice, initial) 33 | // Keep only actually linked functions. 34 | pruneSet(fslice, allFuncs) 35 | 36 | if err := ctx.Err(); err != nil { // cancelled? 
37 | return nil, err 38 | } 39 | vtaCg := vta.CallGraph(fslice, initial) 40 | 41 | // Repeat the process once more, this time using 42 | // the produced VTA call graph as the base graph. 43 | fslice = forwardSlice(entrySlice, vtaCg) 44 | pruneSet(fslice, allFuncs) 45 | 46 | if err := ctx.Err(); err != nil { // cancelled? 47 | return nil, err 48 | } 49 | cg := vta.CallGraph(fslice, vtaCg) 50 | cg.DeleteSyntheticNodes() 51 | 52 | return cg, nil 53 | } 54 | 55 | // forwardSlice computes the transitive closure of functions forward reachable 56 | // via calls in cg or referred to in an instruction starting from `sources`. 57 | // 58 | // https://github.com/golang/vuln/blob/7335627909c99e391cf911fcd214badcb8aa6d7d/internal/vulncheck/slicing.go#L14 59 | func forwardSlice(sources map[*ssa.Function]bool, cg *callgraph.Graph) map[*ssa.Function]bool { 60 | seen := make(map[*ssa.Function]bool) 61 | var visit func(f *ssa.Function) 62 | visit = func(f *ssa.Function) { 63 | if seen[f] { 64 | return 65 | } 66 | seen[f] = true 67 | 68 | if n := cg.Nodes[f]; n != nil { 69 | for _, e := range n.Out { 70 | if e.Site != nil { 71 | visit(e.Callee.Func) 72 | } 73 | } 74 | } 75 | 76 | var buf [10]*ssa.Value // avoid alloc in common case 77 | for _, b := range f.Blocks { 78 | for _, instr := range b.Instrs { 79 | for _, op := range instr.Operands(buf[:0]) { 80 | if fn, ok := (*op).(*ssa.Function); ok { 81 | visit(fn) 82 | } 83 | } 84 | } 85 | } 86 | } 87 | for source := range sources { 88 | visit(source) 89 | } 90 | return seen 91 | } 92 | 93 | // pruneSet removes functions in `set` that are in `toPrune`. 
94 | // 95 | // https://github.com/golang/vuln/blob/7335627909c99e391cf911fcd214badcb8aa6d7d/internal/vulncheck/slicing.go#L49 96 | func pruneSet(set, toPrune map[*ssa.Function]bool) { 97 | for f := range set { 98 | if !toPrune[f] { 99 | delete(set, f) 100 | } 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /callgraphutil/path.go: -------------------------------------------------------------------------------- 1 | package callgraphutil 2 | 3 | import ( 4 | "bytes" 5 | 6 | "golang.org/x/tools/go/callgraph" 7 | ) 8 | 9 | // Path is a sequence of callgraph.Edges, where each edge 10 | // represents a call from a caller to a callee, making up 11 | // a "chain" of calls, e.g.: main → foo → bar → baz. 12 | type Path []*callgraph.Edge 13 | 14 | // Empty returns true if the path is empty, false otherwise. 15 | func (p Path) Empty() bool { 16 | return len(p) == 0 17 | } 18 | 19 | // First returns the first edge in the path, or nil if the path is empty. 20 | func (p Path) First() *callgraph.Edge { 21 | if len(p) == 0 { 22 | return nil 23 | } 24 | return p[0] 25 | } 26 | 27 | // Last returns the last edge in the path, or nil if the path is empty. 28 | func (p Path) Last() *callgraph.Edge { 29 | if len(p) == 0 { 30 | return nil 31 | } 32 | return p[len(p)-1] 33 | } 34 | 35 | // String returns a string representation of the path which 36 | // is a sequence of edges separated by " → ". 37 | // 38 | // Intended to be used while debugging. 39 | func (p Path) String() string { 40 | var buf bytes.Buffer 41 | for i, e := range p { 42 | if i == 0 { 43 | buf.WriteString(e.Caller.String()) 44 | } 45 | 46 | buf.WriteString(" → ") 47 | 48 | buf.WriteString(e.Callee.String()) 49 | } 50 | return buf.String() 51 | } 52 | 53 | // Paths is a collection of paths, which may be logically grouped 54 | // together, e.g.: all paths from main to foo, or all paths from 55 | // main to bar. 
56 | type Paths []Path 57 | 58 | // Shortest returns the shortest path in the collection of paths. 59 | // 60 | // If there are no paths, this returns nil. If there are multiple 61 | // paths of the same length, this returns the first path found. 62 | func (p Paths) Shortest() Path { 63 | if len(p) == 0 { 64 | return nil 65 | } 66 | 67 | shortest := p[0] 68 | for _, path := range p { 69 | if len(path) < len(shortest) { 70 | shortest = path 71 | } 72 | } 73 | 74 | return shortest 75 | } 76 | 77 | // Longest returns the longest path in the collection of paths. 78 | // 79 | // If there are no paths, this returns nil. If there are multiple 80 | // paths of the same length, the first path found is returned. 81 | func (p Paths) Longest() Path { 82 | if len(p) == 0 { 83 | return nil 84 | } 85 | 86 | longest := p[0] 87 | for _, path := range p { 88 | if len(path) > len(longest) { 89 | longest = path 90 | } 91 | } 92 | 93 | return longest 94 | } 95 | 96 | // PathSearch returns the first path found from the start node 97 | // to a node that matches the isMatch function. This is a depth 98 | // first search, so it will return the first path found, which 99 | // may not be the shortest path. 100 | // 101 | // To find all paths, use PathsSearch, which returns a collection 102 | // of paths. 
103 | func PathSearch(start *callgraph.Node, isMatch func(*callgraph.Node) bool) Path { 104 | var ( 105 | stack = make(Path, 0, 32) 106 | seen = make(map[*callgraph.Node]bool) 107 | 108 | search func(n *callgraph.Node) Path 109 | ) 110 | 111 | search = func(n *callgraph.Node) Path { 112 | if !seen[n] { 113 | // debug("searching: %v\n", n) 114 | seen[n] = true 115 | if isMatch(n) { 116 | return stack 117 | } 118 | for _, e := range n.Out { 119 | stack = append(stack, e) // push 120 | if found := search(e.Callee); found != nil { 121 | return found 122 | } 123 | stack = stack[:len(stack)-1] // pop 124 | } 125 | } 126 | return nil 127 | } 128 | return search(start) 129 | } 130 | 131 | // PathsSearch returns all paths found from the start node 132 | // to a node that matches the isMatch function. Under the hood, 133 | // this is a depth first search. 134 | // 135 | // To find the first path (which may not be the shortest), use PathSearch. 136 | func PathsSearch(start *callgraph.Node, isMatch func(*callgraph.Node) bool) Paths { 137 | var ( 138 | paths = Paths{} 139 | 140 | stack = make(Path, 0, 32) 141 | seen = make(map[*callgraph.Node]bool) 142 | 143 | search func(n *callgraph.Node) 144 | ) 145 | 146 | search = func(n *callgraph.Node) { 147 | if n == nil { 148 | return 149 | } 150 | 151 | // debug("searching: %v\n", n) 152 | if !seen[n] { 153 | seen[n] = true 154 | if isMatch(n) { 155 | paths = append(paths, stack) 156 | 157 | stack = make(Path, 0, 32) 158 | seen = make(map[*callgraph.Node]bool) 159 | return 160 | } 161 | for _, e := range n.Out { 162 | // debug("\tout: %v\n", e) 163 | stack = append(stack, e) // push 164 | search(e.Callee) 165 | if len(stack) == 0 { 166 | continue 167 | } 168 | stack = stack[:len(stack)-1] // pop 169 | } 170 | } 171 | } 172 | search(start) 173 | 174 | return paths 175 | } 176 | 177 | // PathSearchCallTo returns the first path found from the start node 178 | // to a node that matches the function name. 
179 | func PathSearchCallTo(start *callgraph.Node, fn string) Path { 180 | return PathSearch(start, func(n *callgraph.Node) bool { 181 | fnStr := n.Func.String() 182 | return fnStr == fn 183 | }) 184 | } 185 | 186 | // PathsSearchCallTo returns the paths that call the given function name, 187 | // which uses SSA function name syntax, e.g.: "(*database/sql.DB).Query". 188 | func PathsSearchCallTo(start *callgraph.Node, fn string) Paths { 189 | return PathsSearch(start, func(n *callgraph.Node) bool { 190 | if n == nil || n.Func == nil { 191 | return false 192 | } 193 | fnStr := n.Func.String() 194 | return fnStr == fn 195 | }) 196 | } 197 | -------------------------------------------------------------------------------- /callgraphutil/ssa.go: -------------------------------------------------------------------------------- 1 | package callgraphutil 2 | 3 | import ( 4 | "golang.org/x/tools/go/callgraph" 5 | "golang.org/x/tools/go/ssa" 6 | ) 7 | 8 | // InstructionsFor returns the ssa.Instruction for the given ssa.Value using 9 | // the given node as the root of the call graph that is searched. 10 | func InstructionsFor(root *callgraph.Node, v ssa.Value) (si ssa.Instruction) { 11 | PathsSearch(root, func(n *callgraph.Node) bool { 12 | for _, b := range root.Func.Blocks { 13 | for _, instr := range b.Instrs { 14 | if instr.Pos() == v.Pos() { 15 | si = instr 16 | return true 17 | } 18 | } 19 | } 20 | return false 21 | }) 22 | return 23 | } 24 | -------------------------------------------------------------------------------- /check.go: -------------------------------------------------------------------------------- 1 | package taint 2 | 3 | import ( 4 | "golang.org/x/tools/go/callgraph" 5 | "golang.org/x/tools/go/ssa" 6 | 7 | "github.com/picatz/taint/callgraphutil" 8 | ) 9 | 10 | // Result is an individual finding from a taint check. 
//
// It contains the path within the callgraph where the source
// found its way into the sink, along with the source and sink
// type information and SSA values.
type Result struct {
	// Path is the specific path within a callgraph
	// where the source finds its way into a sink.
	Path callgraphutil.Path

	// Source type information: the source string reported as matched
	// by the given Sources (see Check).
	SourceType string
	// Source SSA value: the value identified as coming from the source.
	SourceValue ssa.Value

	// Sink information: the string form of the callgraph node that was
	// called as the sink (the callee of the path's last edge).
	SinkType string
	// Sink SSA value: the call value at the sink's call site.
	SinkValue ssa.Value
}

// Results is a collection of findings from a taint check. Check appends one
// result per tainted sink path and does not deduplicate them.
type Results []Result

// Check is the primary function users of this package will use.
//
// It returns a list of results from the callgraph, where any of the given
// sources found their way into any of the given sinks.
//
// Sources is a list of functions that return user-controlled values,
// such as HTTP request parameters. Sinks is a list of potentially dangerous
// functions that should not be called with user-controlled values.
//
// Diagram
//
//	╭───────╮   ╭───────────╮   ╭───────────────╮   ╭─────────────────────╮
//	│ Check ├──▶│ checkPath ├──▶│ checkSSAValue │◀─▶│ checkSSAInstruction │
//	╰───┬───╯   ╰───────────╯   ╰───────────────╯   ╰─────────────────────╯
//	    │
//	    ▼
//	╭─────────╮
//	│ Results │
//	╰─────────╯
//
// checkSSAValue also calls itself recursively, and checkSSAInstruction calls
// back into checkSSAValue.
//
// This is a recursive algorithm that will traverse the callgraph to identify
// if any of the given sources were used to obtain the initial SSA value (v).
59 | // We handle this value, depending on its type, where we "peel back" its 60 | // references and relevant SSA instructions to determine if any of the given 61 | // sinks were involved in the creation of the initial value. 62 | func Check(cg *callgraph.Graph, sources Sources, sinks Sinks) Results { 63 | // The results of the taint check. 64 | results := Results{} 65 | 66 | // For each sink given, identify the individual paths from 67 | // within the callgraph that those sinks can end up as 68 | // the final node path (the "sink path"). 69 | for sink := range sinks { 70 | sinkPaths := callgraphutil.PathsSearchCallTo(cg.Root, sink) 71 | 72 | // fmt.Println("sink paths:", len(sinkPaths)) 73 | 74 | for _, sinkPath := range sinkPaths { 75 | // fmt.Println("sink path:", sinkPath) 76 | // Ensure the path isn't empty (which can happen?!). 77 | // 78 | // TODO: ensure returned paths from within searched paths 79 | // are never empty. That's just silly. 80 | if sinkPath.Empty() { 81 | continue 82 | } 83 | 84 | // Check if the last edge (e.g. a SQL query) used any of the given 85 | // sources (e.g. user input in an HTTP request) to identify if it 86 | // was "tainted". 87 | tainted, src, tv := checkPath(sinkPath, sources) 88 | if tainted { 89 | // Extract the last edge from the last part of the path 90 | // to include the calle as the sink in the result. 91 | lastEdge := sinkPath.Last() 92 | 93 | // Add the result to the list of results. 94 | results = append(results, Result{ 95 | Path: sinkPath, 96 | SourceType: src, 97 | SourceValue: tv, 98 | SinkType: lastEdge.Callee.String(), 99 | SinkValue: lastEdge.Site.Value(), 100 | }) 101 | } 102 | } 103 | } 104 | 105 | // Return the results of the taint check. 106 | return results 107 | } 108 | 109 | // checkPath implements taint analysis that can be used to identify if the given 110 | // callgraph path contains information from taintable sources (typically user input). 
111 | func checkPath(path callgraphutil.Path, sources Sources) (bool, string, ssa.Value) { 112 | // Ensure the path isn't empty (which can happen?!). 113 | if path.Empty() { 114 | return false, "", nil 115 | } 116 | 117 | // Value set used to keep track of values which were already visited 118 | // during the taint analysis. This prevents cyclic calls from crashing 119 | // the program. 120 | visited := valueSet{} 121 | 122 | // Start at last call from the path to see if any of the given sources were used 123 | // along with it to perform an action (e.g. SQL query). 124 | tainted, src, tv := checkSSAValue(path, sources, path.Last().Site.Value(), visited) 125 | if tainted { 126 | return true, src, tv 127 | } 128 | 129 | return false, "", nil 130 | } 131 | 132 | // checkSSAValue implements the core taint analysis algorithm. It identifies 133 | // if the given value "v" comes from any of the given sources (user input). 134 | // 135 | // It keeps track of nodes it has previously visted/checked, and recursively 136 | // calls itself (or checkSSAInstruction) as nessecary. 137 | // 138 | // It returns true if the given SSA value is tained by any of the given sources. 139 | func checkSSAValue(path callgraphutil.Path, sources Sources, v ssa.Value, visited valueSet) (bool, string, ssa.Value) { 140 | // First, check if this value has already been visited. 141 | // 142 | // If so, we can assume it is safe. 143 | if visited.includes(v) { 144 | return false, "", nil 145 | } 146 | 147 | // If it was not previously visited, we add it to the set 148 | // of visited values. This will prevent visting cyclic 149 | // calls from crashing the program. 150 | visited.add(v) 151 | 152 | // fmt.Printf("! check SSA value %s: %[1]T\n", v) 153 | 154 | // This is the core of the algorithm. 155 | // 156 | // It handles traversing the SSA value and callgraph to identify 157 | // if any of the given sources were used to obtain the initial 158 | // SSA value (v). 
We handle this value, depending on its type, 159 | // where we "peel back" its references and relevant SSA 160 | // instructions to determine if any of the given sinks were 161 | // involved in the process. 162 | switch value := v.(type) { 163 | // We assume that constants, functions, and globals are safe. 164 | // 165 | // To be clear: functions and globals may not always safe. 166 | // Just generally speaking. So in order to support additional 167 | // analysis in the future these values may need to be considered. 168 | // 169 | // It is probably safe to consider constants are always safe. 170 | // But what if you wanted to check if a constant made it into 171 | // a sink? 172 | case *ssa.Const, *ssa.Function, *ssa.Global: 173 | return false, "", nil 174 | // Function parameters can obscure the analysis of the value, 175 | // because we need to step backwards through the callgraph path 176 | // (just one step?) to identify what actual value the caller used. 177 | case *ssa.Parameter: 178 | // Check if the parameter's type is a source. 179 | paramTypeStr := value.Type().String() 180 | if src, ok := sources.includes(paramTypeStr); ok { 181 | return true, src, value 182 | } 183 | 184 | // Check the parameter's referrers. 185 | refs := value.Referrers() 186 | if refs != nil { 187 | for _, ref := range *refs { 188 | refVal, isVal := ref.(ssa.Value) 189 | if isVal { 190 | tainted, src, tv := checkSSAValue(path, sources, refVal, visited) 191 | if tainted { 192 | return true, src, tv 193 | } 194 | continue 195 | } 196 | 197 | tainted, src, tv := checkSSAInstruction(path, sources, ref, visited) 198 | if tainted { 199 | return true, src, tv 200 | } 201 | } 202 | } 203 | 204 | // TODO: consider if we can remove the range with a single 205 | // step backwards? 206 | for _, edge := range path { 207 | // Find the caller that used the function parameter's parent (the function). 
208 | if edge.Callee.Func == v.Parent() { 209 | // Inspect the instructions of the caller's function to identify 210 | // the relevant call using the function parameter. 211 | for _, block := range edge.Caller.Func.DomPreorder() { 212 | for _, instr := range block.Instrs { 213 | callInstr, ok := instr.(*ssa.Call) 214 | if !ok { 215 | continue 216 | } 217 | if callInstr.Call.Value.Pos() == edge.Callee.Func.Pos() { 218 | tainted, src, tv := checkSSAInstruction(path, sources, instr, visited) 219 | if tainted { 220 | return true, src, tv 221 | } 222 | } 223 | } 224 | } 225 | } 226 | } 227 | // Function calls can be a little tricky. We need to check a few things. 228 | // 1. See if the call itself was a source. 229 | // 2. See if any of the arguments was a source. 230 | // 3. See if the call value calls a source (anonymous functions). 231 | case *ssa.Call: 232 | // 1. Handle the case where we finally called a source. 233 | callTypeStr := value.Call.Value.String() 234 | if src, ok := sources.includes(callTypeStr); ok { 235 | return true, src, value.Call.Value 236 | } 237 | // 2. Handle the arguments of the call. 238 | for _, arg := range value.Call.Args { 239 | tainted, src, tv := checkSSAValue(path, sources, arg, visited) 240 | if tainted { 241 | return true, src, tv 242 | } 243 | } 244 | // 3. Handle the case of a *ssa.Call from an anonymous function (*ssa.MakeClosure). 245 | tainted, src, tv := checkSSAValue(path, sources, value.Call.Value, visited) 246 | if tainted { 247 | return true, src, tv 248 | } 249 | // Memory allocations or addressing can be traversed using the value's 250 | // referrers. Each referrer is either an SSA value or instruction. 
251 | case *ssa.Alloc: 252 | refs := value.Referrers() 253 | for _, ref := range *refs { 254 | refVal, isVal := ref.(ssa.Value) 255 | if isVal { 256 | tainted, src, tv := checkSSAValue(path, sources, refVal, visited) 257 | if tainted { 258 | return true, src, tv 259 | } 260 | continue 261 | } 262 | 263 | tainted, src, tv := checkSSAInstruction(path, sources, ref, visited) 264 | if tainted { 265 | return true, src, tv 266 | } 267 | } 268 | // Free variables can be traversed using the value's referrers, or the 269 | // value's parent's referrers. Each referrer is either an SSA value or 270 | // instruction. 271 | // 272 | // These can be tricky because they can be used in a few different ways, 273 | // preventing us from just checking the value's referrers in all cases. 274 | case *ssa.FreeVar: 275 | refs := value.Referrers() 276 | for _, ref := range *refs { 277 | refVal, isVal := ref.(ssa.Value) 278 | if isVal { 279 | tainted, src, tv := checkSSAValue(path, sources, refVal, visited) 280 | if tainted { 281 | return true, src, tv 282 | } 283 | continue 284 | } 285 | 286 | tainted, src, tv := checkSSAInstruction(path, sources, ref, visited) 287 | if tainted { 288 | return true, src, tv 289 | } 290 | } 291 | 292 | // Handle the case of an anonymous function being injected into the child scope. 293 | // 294 | // Example 295 | // 296 | // ╭──────────────────────────────╮ 297 | // ↓ │ 298 | // user := r.URL.Query()["query"] │ Parent scope 299 | // func() { │ ┄┄┄┄┄┄┄┄┄┄┄┄ 300 | // userValue := user[0] ←────────╯ Child scope 301 | // business(db, func() *string { 302 | // return &userValue 303 | // }()) 304 | // }() 305 | // 306 | // TODO: consider checking parentFn params and other places? 
307 | parentFn := value.Parent().Parent() 308 | for _, block := range parentFn.DomPreorder() { 309 | for _, instr := range block.Instrs { 310 | // fmt.Printf("\t - check SSA value %s: %[1]T ~ %[2]v\n", instr, value.Name()) 311 | val, isval := instr.(ssa.Value) 312 | if !isval { 313 | continue 314 | } 315 | alloc, isalloc := val.(*ssa.Alloc) 316 | if isalloc { 317 | if alloc.Comment == value.Name() { 318 | tainted, src, tv := checkSSAValue(path, sources, val, visited) 319 | if tainted { 320 | return true, src, tv 321 | } 322 | } 323 | continue 324 | } 325 | 326 | // TODO: handle more cases like this... 327 | // 328 | // Example 329 | // 330 | // mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 331 | // var input map[string]any ←────────╮ ↑ 332 | // ╭─────────────────↓─────────╯ 333 | // json.NewDecoder(r.Body).Decode(&input) ←────╮ 334 | // │ 335 | // func() { ↓ 336 | // userValue := fmt.Sprintf("%s", input["query"]) ←────────╮ 337 | // business(db, func() *string { │ 338 | // return &userValue ←─────────────────────────────────╯ 339 | // }()) 340 | // }() 341 | // }) 342 | // 343 | tainted, src, tv := checkSSAValue(path, sources, val, valueSet{}) 344 | if tainted { 345 | return true, src, tv 346 | } 347 | } 348 | } 349 | case *ssa.IndexAddr: 350 | refs := value.Referrers() 351 | for _, ref := range *refs { 352 | refVal, isVal := ref.(ssa.Value) 353 | if isVal { 354 | tainted, src, tv := checkSSAValue(path, sources, refVal, visited) 355 | if tainted { 356 | return true, src, tv 357 | } 358 | continue 359 | } 360 | 361 | tainted, src, tv := checkSSAInstruction(path, sources, ref, visited) 362 | if tainted { 363 | return true, src, tv 364 | } 365 | } 366 | tainted, src, tv := checkSSAValue(path, sources, value.X, visited) 367 | if tainted { 368 | return true, src, tv 369 | } 370 | case *ssa.FieldAddr: 371 | /* 372 | value.String() 373 | => "&r.URL [#1]" 374 | value.Type().String() 375 | => "**net/url.URL" 376 | value.X.Type().String() 377 | =? 
"*net/http.Request" 378 | */ 379 | if src, ok := sources.includes(value.X.Type().String()); ok { 380 | return true, src, value 381 | } 382 | 383 | tainted, src, tv := checkSSAValue(path, sources, value.X, visited) 384 | if tainted { 385 | return true, src, tv 386 | } 387 | 388 | refs := value.Referrers() 389 | for _, ref := range *refs { 390 | refVal, isVal := ref.(ssa.Value) 391 | if isVal { 392 | tainted, src, tv := checkSSAValue(path, sources, refVal, visited) 393 | if tainted { 394 | return true, src, tv 395 | } 396 | continue 397 | } 398 | 399 | tainted, src, tv := checkSSAInstruction(path, sources, ref, visited) 400 | if tainted { 401 | return true, src, tv 402 | } 403 | } 404 | indexableValueRefs := value.X.Referrers() 405 | for _, ref := range *indexableValueRefs { 406 | refVal, isVal := ref.(ssa.Value) 407 | if isVal { 408 | tainted, src, tv := checkSSAValue(path, sources, refVal, visited) 409 | if tainted { 410 | return true, src, tv 411 | } 412 | continue 413 | } 414 | 415 | tainted, src, tv := checkSSAInstruction(path, sources, ref, visited) 416 | if tainted { 417 | return true, src, tv 418 | } 419 | } 420 | case *ssa.MakeClosure: 421 | tainted, src, tv := checkSSAValue(path, sources, value.Fn, visited) 422 | if tainted { 423 | return true, src, tv 424 | } 425 | for _, binding := range value.Bindings { 426 | tainted, src, tv := checkSSAValue(path, sources, binding, visited) 427 | if tainted { 428 | return true, src, tv 429 | } 430 | } 431 | case *ssa.BinOp: 432 | // Check the left hand side operands of the binary operations. 433 | tainted, src, tv := checkSSAValue(path, sources, value.X, visited) // left 434 | if tainted { 435 | return true, src, tv 436 | } 437 | tainted, src, tv = checkSSAValue(path, sources, value.Y, visited) // right 438 | if tainted { 439 | return true, src, tv 440 | } 441 | case *ssa.UnOp: 442 | // Check the single operand. 
443 | tainted, src, tv := checkSSAValue(path, sources, value.X, visited) 444 | if tainted { 445 | return true, src, tv 446 | } 447 | case *ssa.Slice: 448 | // Check the sliced value. 449 | tainted, src, tv := checkSSAValue(path, sources, value.X, visited) 450 | if tainted { 451 | return true, src, tv 452 | } 453 | case *ssa.MakeInterface: 454 | // Check the value being made into an interface. 455 | tainted, src, tv := checkSSAValue(path, sources, value.X, visited) 456 | if tainted { 457 | return true, src, tv 458 | } 459 | case *ssa.ChangeInterface: 460 | // Check the value being changed into an interface. 461 | tainted, src, tv := checkSSAValue(path, sources, value.X, visited) 462 | if tainted { 463 | return true, src, tv 464 | } 465 | 466 | // Check the value's referrers. 467 | refs := value.X.Referrers() 468 | for _, ref := range *refs { 469 | refVal, isVal := ref.(ssa.Value) 470 | if isVal { 471 | tainted, src, tv := checkSSAValue(path, sources, refVal, visited) 472 | if tainted { 473 | return true, src, tv 474 | } 475 | continue 476 | } 477 | 478 | tainted, src, tv := checkSSAInstruction(path, sources, ref, visited) 479 | if tainted { 480 | return true, src, tv 481 | } 482 | } 483 | case *ssa.TypeAssert: 484 | // Check the value being type asserted. 485 | tainted, src, tv := checkSSAValue(path, sources, value.X, visited) 486 | if tainted { 487 | return true, src, tv 488 | } 489 | case *ssa.Convert: 490 | // Check the value being converted. 491 | tainted, src, tv := checkSSAValue(path, sources, value.X, visited) 492 | if tainted { 493 | return true, src, tv 494 | } 495 | case *ssa.Extract: 496 | // Check the value being extracted. 497 | tainted, src, tv := checkSSAValue(path, sources, value.Tuple, visited) 498 | if tainted { 499 | return true, src, tv 500 | } 501 | case *ssa.Lookup: 502 | // Check the string or map value being looked up. 
503 | tainted, src, tv := checkSSAValue(path, sources, value.X, visited) 504 | if tainted { 505 | return true, src, tv 506 | } 507 | 508 | // Check the index value being looked up. 509 | refs := value.Index.Referrers() 510 | if refs != nil { 511 | for _, ref := range *refs { 512 | refVal, isVal := ref.(ssa.Value) 513 | if isVal { 514 | tainted, src, tv := checkSSAValue(path, sources, refVal, visited) 515 | if tainted { 516 | return true, src, tv 517 | } 518 | continue 519 | } 520 | 521 | tainted, src, tv := checkSSAInstruction(path, sources, ref, visited) 522 | if tainted { 523 | return true, src, tv 524 | } 525 | } 526 | } 527 | case *ssa.MakeMap: 528 | refs := value.Referrers() 529 | if refs != nil { 530 | for _, ref := range *refs { 531 | refVal, isVal := ref.(ssa.Value) 532 | if isVal { 533 | tainted, src, tv := checkSSAValue(path, sources, refVal, visited) 534 | if tainted { 535 | return true, src, tv 536 | } 537 | continue 538 | } 539 | 540 | tainted, src, tv := checkSSAInstruction(path, sources, ref, visited) 541 | if tainted { 542 | return true, src, tv 543 | } 544 | } 545 | } 546 | default: 547 | // fmt.Printf("? check SSA value %s: %[1]T\n", v) 548 | return false, "", nil 549 | } 550 | return false, "", nil 551 | } 552 | 553 | // checkSSAInstruction is used internally by checkSSAValue when it needs to traverse 554 | // SSA instructions, like the contents of a calling function. 555 | func checkSSAInstruction(path callgraphutil.Path, sources Sources, i ssa.Instruction, visited valueSet) (bool, string, ssa.Value) { 556 | // fmt.Printf("! check SSA instr %s: %[1]T\n", i) 557 | 558 | switch instr := i.(type) { 559 | case *ssa.Store: 560 | // Store instructions need to be checked for both the value being stored, 561 | // and the address being stored to. 
562 | tainted, src, tv := checkSSAValue(path, sources, instr.Val, visited) 563 | if tainted { 564 | return true, src, tv 565 | } 566 | tainted, src, tv = checkSSAValue(path, sources, instr.Addr, visited) 567 | if tainted { 568 | return true, src, tv 569 | } 570 | case *ssa.Call: 571 | // Check the operands of the call instruction. 572 | for _, instrValue := range instr.Operands(nil) { 573 | if instrValue == nil { 574 | continue 575 | } 576 | iv := *instrValue 577 | tainted, src, tv := checkSSAValue(path, sources, iv, visited) 578 | if tainted { 579 | return true, src, tv 580 | } 581 | } 582 | case *ssa.MapUpdate: 583 | // Map update instructions need to be checked for both the map being updated, 584 | // and the key and value being updated. 585 | tainted, src, tv := checkSSAValue(path, sources, instr.Key, visited) 586 | if tainted { 587 | return true, src, tv 588 | } 589 | 590 | tainted, src, tv = checkSSAValue(path, sources, instr.Value, visited) 591 | if tainted { 592 | return true, src, tv 593 | } 594 | default: 595 | // fmt.Printf("? 
check SSA instr %s: %[1]T\n", i) 596 | return false, "", nil 597 | } 598 | return false, "", nil 599 | } 600 | -------------------------------------------------------------------------------- /cmd/logi/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/picatz/taint/log/injection" 5 | "golang.org/x/tools/go/analysis/singlechecker" 6 | ) 7 | 8 | func main() { 9 | singlechecker.Main(injection.Analyzer) 10 | } 11 | -------------------------------------------------------------------------------- /cmd/sqli/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/picatz/taint/sql/injection" 5 | "golang.org/x/tools/go/analysis/singlechecker" 6 | ) 7 | 8 | func main() { 9 | singlechecker.Main(injection.Analyzer) 10 | } 11 | -------------------------------------------------------------------------------- /cmd/ssadump/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "go/ast" 7 | "go/parser" 8 | "go/token" 9 | "os" 10 | "os/signal" 11 | 12 | "golang.org/x/tools/go/packages" 13 | "golang.org/x/tools/go/ssa" 14 | "golang.org/x/tools/go/ssa/ssautil" 15 | ) 16 | 17 | func main() { 18 | ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) 19 | defer cancel() 20 | 21 | patterns := os.Args[1:] 22 | 23 | if len(patterns) == 0 { 24 | fmt.Fprintf(os.Stderr, "usage: %s \n", os.Args[0]) 25 | os.Exit(1) 26 | } 27 | 28 | loadMode := packages.NeedName | 29 | packages.NeedFiles | 30 | packages.NeedCompiledGoFiles | 31 | packages.NeedImports | 32 | packages.NeedTypes | 33 | packages.NeedTypesSizes | 34 | packages.NeedSyntax | 35 | packages.NeedTypesInfo | 36 | packages.NeedDeps 37 | 38 | parseMode := parser.SkipObjectResolution 39 | 40 | dir, err := os.Getwd() 41 | if err != nil { 42 | 
fmt.Fprintf(os.Stderr, "failed to get current working directory: %v\n", err.Error()) 43 | os.Exit(1) 44 | } 45 | 46 | cfg := &packages.Config{ 47 | Mode: loadMode, 48 | Context: ctx, 49 | Dir: dir, 50 | Env: os.Environ(), 51 | ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { 52 | return parser.ParseFile(fset, filename, src, parseMode) 53 | }, 54 | } 55 | 56 | initial, err := packages.Load(cfg, patterns...) 57 | if err != nil { 58 | fmt.Fprintf(os.Stderr, "%v\n", err.Error()) 59 | os.Exit(1) 60 | } 61 | 62 | // bubble up all loaded package errors 63 | for _, pkg := range initial { 64 | if len(pkg.Errors) != 0 { 65 | for _, err := range pkg.Errors { 66 | fmt.Fprintf(os.Stderr, "%v\n", err.Error()) 67 | } 68 | } 69 | } 70 | 71 | _, pkgs := ssautil.Packages(initial, 0) 72 | 73 | for _, pkg := range pkgs { 74 | // malformed packages will be nil 75 | if pkg != nil { 76 | pkg.Build() 77 | for _, m := range pkg.Members { 78 | if fn, ok := m.(*ssa.Function); ok { 79 | fn.WriteTo(os.Stdout) 80 | } 81 | } 82 | } 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /cmd/taint/Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: install 2 | install: 3 | @go build -o $(shell go env GOPATH)/bin/taint . 
4 | 5 | .PHONY: vhs 6 | vhs: 7 | @vhs ./vhs/demo.tape 8 | -------------------------------------------------------------------------------- /cmd/taint/example/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "database/sql" 5 | "net/http" 6 | ) 7 | 8 | func handle(db *sql.DB, q string) { 9 | db.Query(q) // want "potential sql injection" 10 | } 11 | 12 | func business(db *sql.DB, q *string) error { 13 | handle(db, *q) 14 | return nil 15 | } 16 | 17 | func main() { 18 | db, _ := sql.Open("sqlite3", ":memory:") 19 | 20 | mux := http.NewServeMux() 21 | 22 | mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 23 | user := r.URL.Query()["query"] 24 | userValue := user[0] 25 | business(db, &userValue) 26 | }) 27 | 28 | err := http.ListenAndServe(":8080", mux) 29 | if err != nil { 30 | panic(err) 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /cmd/taint/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bufio" 5 | "context" 6 | "flag" 7 | "fmt" 8 | "go/ast" 9 | "go/parser" 10 | "go/token" 11 | "io" 12 | "net/url" 13 | "os" 14 | "os/signal" 15 | "path/filepath" 16 | "sort" 17 | "strconv" 18 | "strings" 19 | 20 | "github.com/charmbracelet/lipgloss" 21 | "github.com/go-git/go-git/v5" 22 | "github.com/picatz/taint" 23 | "github.com/picatz/taint/callgraphutil" 24 | "golang.org/x/term" 25 | "golang.org/x/tools/go/callgraph" 26 | "golang.org/x/tools/go/packages" 27 | "golang.org/x/tools/go/ssa" 28 | "golang.org/x/tools/go/ssa/ssautil" 29 | ) 30 | 31 | var ( 32 | styleBold = lipgloss.NewStyle().Bold(true) 33 | styleFaint = lipgloss.NewStyle().Faint(true) 34 | styleNumber = lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("69")) 35 | styleArgument = lipgloss.NewStyle().Foreground(lipgloss.Color("68")) 36 | styleFlag = 
lipgloss.NewStyle().Foreground(lipgloss.Color("66")) 37 | styleCommand = lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("62")) 38 | ) 39 | 40 | var ( 41 | pkgs []*packages.Package 42 | ssaProg *ssa.Program 43 | ssaPkgs []*ssa.Package 44 | cg *callgraph.Graph 45 | ) 46 | 47 | // highlightNode returns a string with the node highlighted, such that 48 | // `n4:(net/http.ResponseWriter).Write` has the `n4` highlighted as a number, 49 | // and the rest of the string highlighted as a typical Go identifier. 50 | func highlightNode(node string) string { 51 | // Split the node string on the colon. 52 | parts := strings.Split(node, ":") 53 | 54 | // Get the node ID. 55 | nodeID := parts[0] 56 | 57 | // Highlight the node ID. 58 | nodeID = styleNumber.Render(nodeID) 59 | 60 | // Get the rest of the node string. 61 | nodeStr := strings.Join(parts[1:], ":") 62 | 63 | // Return the highlighted node. 64 | return nodeID + ":" + nodeStr 65 | } 66 | 67 | // makeRawTerminal returns a raw terminal and a function to restore the 68 | // terminal to its previous state, which should be called when the terminal 69 | // is no longer needed (typically in a defer). 70 | func makeRawTerminal() (*term.Terminal, func(), error) { 71 | // Set the terminal to raw mode. 72 | oldState, err := term.MakeRaw(0) 73 | if err != nil { 74 | return nil, nil, fmt.Errorf("%w", err) 75 | } 76 | 77 | termWidth, termHeight, err := term.GetSize(0) 78 | if err != nil { 79 | return nil, nil, fmt.Errorf("%w", err) 80 | } 81 | 82 | termReadWriter := struct { 83 | io.Reader 84 | io.Writer 85 | }{os.Stdin, os.Stdout} 86 | 87 | t := term.NewTerminal(termReadWriter, "") // Will set the prompt later. 88 | 89 | err = t.SetSize(termWidth, termHeight) 90 | if err != nil { 91 | return nil, nil, fmt.Errorf("%w", err) 92 | } 93 | 94 | return t, func() { term.Restore(0, oldState) }, nil 95 | } 96 | 97 | func clearScreen(bt *bufio.Writer) error { 98 | // Clear the screen. 
99 | _, err := bt.Write([]byte("\033[2J")) 100 | if err != nil { 101 | return fmt.Errorf("%w", err) 102 | } 103 | 104 | // Move to the top left. 105 | _, err = bt.Write([]byte("\033[H")) 106 | if err != nil { 107 | return fmt.Errorf("%w", err) 108 | } 109 | 110 | // Flush the buffer to the terminal. 111 | err = bt.Flush() 112 | if err != nil { 113 | return fmt.Errorf("%w", err) 114 | } 115 | 116 | return nil 117 | } 118 | 119 | type commandArg struct { 120 | name string 121 | desc string 122 | optional bool 123 | } 124 | 125 | type commandFlag struct { 126 | name string 127 | desc string 128 | } 129 | 130 | type command struct { 131 | name string 132 | desc string 133 | args []*commandArg 134 | flags []*commandFlag 135 | fn commandFn 136 | } 137 | 138 | func (c *command) nRequiredArgs() int { 139 | var n int 140 | for _, arg := range c.args { 141 | if arg.optional { 142 | continue 143 | } 144 | n++ 145 | } 146 | return n 147 | } 148 | 149 | func (c *command) help() string { 150 | var help strings.Builder 151 | 152 | help.WriteString(styleCommand.Render(c.name) + " ") 153 | 154 | for _, arg := range c.args { 155 | if arg.optional { 156 | help.WriteString(styleArgument.Render("[") + styleFaint.Render(fmt.Sprintf("<%s>", arg.name)) + styleArgument.Render("] ")) 157 | continue 158 | } 159 | help.WriteString(styleArgument.Render(fmt.Sprintf("<%s> ", arg.name))) 160 | } 161 | 162 | for _, flag := range c.flags { 163 | help.WriteString(styleFlag.Render(fmt.Sprintf("--%s ", flag.name) + styleFaint.Render(flag.desc))) 164 | } 165 | 166 | help.WriteString(styleFaint.Render(c.desc) + "\n") 167 | 168 | return help.String() 169 | } 170 | 171 | type commandFn func( 172 | ctx context.Context, 173 | bt *bufio.Writer, 174 | args []string, 175 | flags map[string]string, 176 | ) error 177 | 178 | func errorCommandFn(err error) commandFn { 179 | return func( 180 | _ context.Context, 181 | _ *bufio.Writer, 182 | _ []string, 183 | _ map[string]string, 184 | ) error { 185 | return err 
186 | } 187 | } 188 | 189 | func terminalWriteFn(fn func(bt *bufio.Writer) error) commandFn { 190 | return func( 191 | _ context.Context, 192 | bt *bufio.Writer, 193 | _ []string, 194 | _ map[string]string, 195 | ) error { 196 | return fn(bt) 197 | } 198 | } 199 | 200 | type commands []*command 201 | 202 | func (c commands) help() string { 203 | var help strings.Builder 204 | for _, cmd := range c { 205 | help.WriteString(styleFaint.Render("- ") + styleCommand.Render(cmd.name) + " ") 206 | 207 | for _, arg := range cmd.args { 208 | if arg.optional { 209 | help.WriteString(styleArgument.Render("[") + styleFaint.Render(arg.name) + styleArgument.Render("] ")) 210 | continue 211 | } 212 | help.WriteString(styleArgument.Render(fmt.Sprintf("<%s> ", arg.name))) 213 | } 214 | 215 | help.WriteString(styleFaint.Render(cmd.desc) + "\n") 216 | } 217 | 218 | help.WriteString("\n") 219 | 220 | return help.String() 221 | } 222 | 223 | func (c commands) eval(ctx context.Context, bt *bufio.Writer, input string) error { 224 | fields := strings.Fields(input) 225 | if len(fields) == 0 { 226 | return nil 227 | } 228 | 229 | cmdName := fields[0] 230 | 231 | argsAndFlags := fields[1:] 232 | 233 | // Parse flags with Go's flag package. 234 | flagSet := flag.NewFlagSet(cmdName, flag.ContinueOnError) 235 | 236 | flagSet.SetOutput(bt) 237 | 238 | flagSet.Usage = func() { 239 | // Print command help. 240 | bt.WriteString(c.help()) 241 | bt.Flush() 242 | } 243 | 244 | // Parse the flags. 245 | err := flagSet.Parse(argsAndFlags) 246 | if err != nil { 247 | return err 248 | } 249 | 250 | // Get the flags. 251 | flags := make(map[string]string) 252 | flagSet.Visit(func(f *flag.Flag) { 253 | flags[f.Name] = f.Value.String() 254 | }) 255 | 256 | for _, cmd := range c { 257 | if cmd.name == cmdName { 258 | // Check there are enough arguments. 
259 | if len(flagSet.Args()) < cmd.nRequiredArgs() { 260 | bt.WriteString("not enough arguments, expected " + styleNumber.Render(fmt.Sprintf("%d", cmd.nRequiredArgs())) + " but got " + styleNumber.Render(fmt.Sprintf("%d", len(flagSet.Args()))) + "\n") 261 | bt.WriteString("usage: " + cmd.help()) 262 | bt.Flush() 263 | return nil 264 | } 265 | 266 | return cmd.fn(ctx, bt, flagSet.Args(), flags) 267 | } 268 | } 269 | 270 | bt.WriteString("unknown command: " + cmdName + "\n") 271 | bt.Flush() 272 | 273 | return nil 274 | } 275 | 276 | var builtinCommandExit = &command{ 277 | name: "exit", 278 | desc: "exit the shell", 279 | fn: errorCommandFn(io.EOF), 280 | } 281 | 282 | var builtinCommandClear = &command{ 283 | name: "clear", 284 | desc: "clear the screen", 285 | fn: terminalWriteFn(func(bt *bufio.Writer) error { 286 | return clearScreen(bt) 287 | }), 288 | } 289 | 290 | var builtinCommandLoad = &command{ 291 | name: "load", 292 | desc: "load a program", 293 | args: []*commandArg{ 294 | { 295 | name: "target", 296 | desc: "the target to load (directory or github repository)", 297 | }, 298 | { 299 | name: "pattern", 300 | desc: "the pattern to load (default: ./...)", 301 | optional: true, 302 | }, 303 | }, 304 | fn: func(ctx context.Context, bt *bufio.Writer, args []string, flags map[string]string) error { 305 | arg := args[0] 306 | 307 | var ( 308 | pattern string = "./..." 309 | 310 | dir string 311 | head string 312 | err error 313 | ) 314 | 315 | if len(args) > 1 { 316 | pattern = args[1] 317 | } 318 | 319 | // If the argument starts with https://github.com/, then we'll try to 320 | // clone the repository and load it. 321 | if strings.HasPrefix(arg, "https://github.com/") { 322 | // Clone the repository. 
323 | dir, head, err = cloneRepository(ctx, arg) 324 | 325 | bt.WriteString("cloned " + styleNumber.Render(arg) + " to " + styleNumber.Render(dir) + " at " + styleNumber.Render(head) + "\n") 326 | bt.Flush() 327 | 328 | if err != nil { 329 | bt.WriteString(err.Error() + "\n") 330 | bt.Flush() 331 | return nil 332 | } 333 | } else { 334 | dir = arg 335 | } 336 | 337 | // Check if the directory exists. 338 | _, err = os.Stat(dir) 339 | if os.IsNotExist(err) { 340 | bt.WriteString(fmt.Sprintf("directory %q does not exist\n", dir)) 341 | bt.Flush() 342 | return nil 343 | } 344 | 345 | loadMode := 346 | packages.NeedName | 347 | packages.NeedDeps | 348 | packages.NeedFiles | 349 | packages.NeedModule | 350 | packages.NeedTypes | 351 | packages.NeedImports | 352 | packages.NeedSyntax | 353 | packages.NeedTypesInfo 354 | // packages.NeedTypesSizes | 355 | // packages.NeedCompiledGoFiles | 356 | // packages.NeedExportFile | 357 | // packages.NeedEmbedPatterns 358 | 359 | // parseMode := parser.ParseComments 360 | parseMode := parser.SkipObjectResolution 361 | 362 | // patterns := []string{dir} 363 | patterns := []string{pattern} 364 | // patterns := []string{"all"} 365 | 366 | pkgs, err = packages.Load(&packages.Config{ 367 | Mode: loadMode, 368 | Context: ctx, 369 | Env: os.Environ(), 370 | Dir: dir, 371 | Tests: false, 372 | ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { 373 | return parser.ParseFile(fset, filename, src, parseMode) 374 | }, 375 | }, patterns...) 376 | if err != nil { 377 | bt.WriteString(err.Error() + "\n") 378 | bt.Flush() 379 | return nil 380 | } 381 | 382 | ssaBuildMode := ssa.InstantiateGenerics // ssa.SanityCheckFunctions | ssa.GlobalDebug 383 | 384 | // Analyze the package. 
385 | ssaProg, ssaPkgs = ssautil.Packages(pkgs, ssaBuildMode) 386 | 387 | ssaProg.Build() 388 | 389 | for _, pkg := range ssaPkgs { 390 | pkg.Build() 391 | } 392 | 393 | mainPkgs := ssautil.MainPackages(ssaPkgs) 394 | 395 | mainFn := mainPkgs[0].Members["main"].(*ssa.Function) 396 | 397 | var srcFns []*ssa.Function 398 | 399 | for _, pkg := range ssaPkgs { 400 | for _, fn := range pkg.Members { 401 | if fn.Object() == nil { 402 | continue 403 | } 404 | 405 | if fn.Object().Name() == "_" { 406 | continue 407 | } 408 | 409 | pkgFn := pkg.Func(fn.Object().Name()) 410 | if pkgFn == nil { 411 | continue 412 | } 413 | 414 | var addAnons func(f *ssa.Function) 415 | addAnons = func(f *ssa.Function) { 416 | srcFns = append(srcFns, f) 417 | for _, anon := range f.AnonFuncs { 418 | addAnons(anon) 419 | } 420 | } 421 | addAnons(pkgFn) 422 | } 423 | } 424 | 425 | if mainFn == nil { 426 | bt.WriteString("no main function found\n") 427 | bt.Flush() 428 | return nil 429 | } 430 | 431 | cg, err = callgraphutil.NewGraph(mainFn, srcFns...) 
432 | if err != nil { 433 | bt.WriteString(err.Error() + "\n") 434 | bt.Flush() 435 | return nil 436 | } 437 | 438 | bt.WriteString("loaded " + styleNumber.Render(fmt.Sprintf("%d", len(pkgs))) + " packages\n") 439 | bt.Flush() 440 | return nil 441 | }, 442 | } 443 | 444 | var builtinCommandPkgs = &command{ 445 | name: "pkgs", 446 | desc: "list loaded packages", 447 | fn: func(ctx context.Context, bt *bufio.Writer, args []string, flags map[string]string) error { 448 | if len(pkgs) == 0 { 449 | bt.WriteString("no packages are loaded\n") 450 | bt.Flush() 451 | return nil 452 | } 453 | 454 | var pkgsStr strings.Builder 455 | 456 | for _, pkg := range pkgs { 457 | var ssaPkg *ssa.Package 458 | for _, p := range ssaPkgs { 459 | if p.Pkg.Path() == pkg.PkgPath { 460 | ssaPkg = p 461 | break 462 | } 463 | } 464 | 465 | if ssaPkg == nil { 466 | continue 467 | } 468 | 469 | pkgsStr.WriteString(pkg.PkgPath + " " + styleFaint.Render(fmt.Sprintf("%d imports", len(ssaPkg.Pkg.Imports()))) + "\n") 470 | } 471 | 472 | bt.WriteString(pkgsStr.String()) 473 | 474 | bt.Flush() 475 | return nil 476 | }, 477 | } 478 | 479 | var builtinCommandCG = &command{ 480 | name: "cg", 481 | desc: "print the callgraph", 482 | fn: func(ctx context.Context, bt *bufio.Writer, args []string, flags map[string]string) error { 483 | if cg == nil { 484 | bt.WriteString("no callgraph is loaded\n") 485 | bt.Flush() 486 | return nil 487 | } 488 | 489 | cgStr := strings.ReplaceAll(callgraphutil.GraphString(cg), "→", styleFaint.Render("→")) 490 | 491 | bt.WriteString(cgStr) 492 | bt.Flush() 493 | return nil 494 | }, 495 | } 496 | 497 | var builtinCommandRoot = &command{ 498 | name: "root", 499 | desc: "print the callgraph's root", 500 | fn: func(ctx context.Context, bt *bufio.Writer, args []string, flags map[string]string) error { 501 | if cg == nil { 502 | bt.WriteString("no callgraph is loaded\n") 503 | bt.Flush() 504 | return nil 505 | } 506 | 507 | bt.WriteString(cg.Root.String() + "\n") 508 | bt.Flush() 509 
| return nil 510 | }, 511 | } 512 | 513 | var builtinCommandNodes = &command{ 514 | name: "nodes", 515 | desc: "print the callgraph nodes", 516 | fn: func(ctx context.Context, bt *bufio.Writer, args []string, flags map[string]string) error { 517 | if cg == nil { 518 | bt.WriteString("no callgraph is loaded\n") 519 | bt.Flush() 520 | return nil 521 | } 522 | 523 | var nodesStr strings.Builder 524 | 525 | nodesStrs := make([]string, 0, len(cg.Nodes)) 526 | 527 | for _, node := range cg.Nodes { 528 | nodesStrs = append(nodesStrs, node.String()) 529 | } 530 | 531 | sort.SliceStable(nodesStrs, func(i, j int) bool { 532 | // Parse node ID to int. 533 | iID := strings.Split(nodesStrs[i], ":")[0] 534 | jID := strings.Split(nodesStrs[j], ":")[0] 535 | 536 | // Trim the leading "n" prefix. 537 | iID = strings.TrimPrefix(iID, "n") 538 | jID = strings.TrimPrefix(jID, "n") 539 | 540 | // Parse node ID to int. 541 | iN, err := strconv.Atoi(iID) 542 | if err != nil { 543 | return false 544 | } 545 | 546 | jN, err := strconv.Atoi(jID) 547 | if err != nil { 548 | return false 549 | } 550 | 551 | // Compare node IDs. 
552 | return iN < jN 553 | }) 554 | 555 | for _, nodeStr := range nodesStrs { 556 | nodesStr.WriteString(highlightNode(nodeStr) + "\n") 557 | } 558 | 559 | bt.WriteString(nodesStr.String()) 560 | bt.Flush() 561 | return nil 562 | }, 563 | } 564 | 565 | var builtinCommandsCallpath = &command{ 566 | name: "callpath", 567 | desc: "find callpaths to a function", 568 | args: []*commandArg{ 569 | { 570 | name: "function", 571 | desc: "the function to find callpaths to", 572 | }, 573 | }, 574 | fn: func(ctx context.Context, bt *bufio.Writer, args []string, flags map[string]string) error { 575 | if cg == nil { 576 | bt.WriteString("no callgraph is loaded\n") 577 | bt.Flush() 578 | return nil 579 | } 580 | 581 | if len(args) != 1 { 582 | bt.WriteString("usage: callpath \n") 583 | bt.Flush() 584 | return nil 585 | } 586 | 587 | fn := args[0] 588 | 589 | paths := callgraphutil.PathsSearchCallTo(cg.Root, fn) 590 | 591 | if len(paths) == 0 { 592 | bt.WriteString("no calls to " + fn + "\n") 593 | bt.Flush() 594 | return nil 595 | } 596 | 597 | for _, path := range paths { 598 | pathStr := path.String() 599 | 600 | // Split on " → " and highlight each node. 
601 | parts := strings.Split(pathStr, " → ") 602 | 603 | for i, part := range parts { 604 | parts[i] = highlightNode(part) 605 | } 606 | 607 | pathStr = strings.Join(parts, styleFaint.Render(" → ")) 608 | 609 | bt.WriteString(pathStr + "\n") 610 | bt.Flush() 611 | } 612 | return nil 613 | }, 614 | } 615 | 616 | var builtinCommandCheck = &command{ 617 | name: "check", 618 | desc: "perform a taint analysis check", 619 | args: []*commandArg{ 620 | { 621 | name: "source", 622 | desc: "the source to check", 623 | }, 624 | { 625 | name: "sink", 626 | desc: "the sink to check", 627 | }, 628 | }, 629 | fn: func(ctx context.Context, bt *bufio.Writer, args []string, flags map[string]string) error { 630 | if cg == nil { 631 | bt.WriteString("no callgraph is loaded\n") 632 | bt.Flush() 633 | return nil 634 | } 635 | 636 | if len(args) != 2 { 637 | bt.WriteString("usage: check \n") 638 | bt.Flush() 639 | return nil 640 | } 641 | 642 | source := args[0] 643 | 644 | sink := args[1] 645 | 646 | results := taint.Check(cg, taint.NewSources(source), taint.NewSinks(sink)) 647 | 648 | var resultsStr strings.Builder 649 | 650 | for _, result := range results { 651 | resultPathStr := result.Path.String() 652 | 653 | parts := strings.Split(resultPathStr, " → ") 654 | 655 | for i, part := range parts { 656 | parts[i] = highlightNode(part) 657 | } 658 | 659 | resultPathStr = strings.Join(parts, styleFaint.Render(" → ")) 660 | 661 | resultsStr.WriteString(resultPathStr + "\n") 662 | } 663 | 664 | bt.WriteString(resultsStr.String()) 665 | bt.Flush() 666 | return nil 667 | }, 668 | } 669 | 670 | var builtinCommands = commands{ 671 | builtinCommandExit, 672 | builtinCommandClear, 673 | builtinCommandLoad, 674 | builtinCommandPkgs, 675 | builtinCommandCG, 676 | builtinCommandRoot, 677 | builtinCommandNodes, 678 | builtinCommandsCallpath, 679 | builtinCommandCheck, 680 | } 681 | 682 | func startShell(ctx context.Context) error { 683 | // Get a raw terminal. 
684 | t, restore, err := makeRawTerminal() 685 | if err != nil { 686 | return err 687 | } 688 | 689 | // Restore the terminal on exit. 690 | defer restore() 691 | 692 | // Use buffered terminal writer. 693 | bt := bufio.NewWriter(t) 694 | 695 | // Autocomplete for commands. 696 | t.AutoCompleteCallback = func(line string, pos int, key rune) (newLine string, newPos int, ok bool) { 697 | // If the user presses tab, then autocomplete the command. 698 | if key == '\t' { 699 | for _, cmd := range builtinCommands { 700 | // If line is using the load command, then autocomplete the 701 | // directory name. 702 | if strings.HasPrefix(line, "load ") { 703 | // Get the directory name. 704 | dir := strings.TrimPrefix(line, "load ") 705 | 706 | // Check if the directory exists. 707 | _, err := os.Stat(dir) 708 | if os.IsNotExist(err) { 709 | // If the directory does not exist, check if there is a 710 | // directory with the same prefix. 711 | dirPrefix := strings.TrimSuffix(dir, "/") 712 | 713 | // Get the parent directory. 714 | parentDir := filepath.Dir(dirPrefix) 715 | 716 | // Get the directory name. 717 | dirName := filepath.Base(dirPrefix) 718 | 719 | // Open the parent directory. 720 | f, err := os.Open(parentDir) 721 | if err != nil { 722 | continue 723 | } 724 | 725 | // Get the directory entries. 726 | entries, err := f.Readdir(-1) 727 | if err != nil { 728 | // Close the parent directory. 729 | _ = f.Close() 730 | continue 731 | } 732 | 733 | // Close the parent directory. 734 | err = f.Close() 735 | if err != nil { 736 | continue 737 | } 738 | 739 | // Check if any of the directory entries match the 740 | // directory name prefix. 741 | for _, entry := range entries { 742 | if strings.HasPrefix(entry.Name(), dirName) { 743 | // If so, we'll autocomplete the directory name. 
744 | loadCmd := "load " + filepath.Join(parentDir, entry.Name()) 745 | 746 | return loadCmd, len(loadCmd), true 747 | } 748 | } 749 | 750 | return line, pos, false 751 | } 752 | 753 | // Otherwise, we'll autocomplete the directory name. 754 | return "load " + dir, len("load " + dir), true 755 | } 756 | 757 | if strings.HasPrefix(cmd.name, line) { 758 | // Return the new line and position, which must come after the 759 | // command. 760 | return cmd.name, len(cmd.name), true 761 | } 762 | } 763 | } 764 | 765 | // Otherwise, we'll just return the line. 766 | return line, pos, false 767 | } 768 | 769 | // Print welcome message. 770 | bt.WriteString(styleBold.Render("Commands") + " " + styleFaint.Render("(tab complete)") + "\n\n") 771 | 772 | // Print the commands. 773 | bt.WriteString(builtinCommands.help()) 774 | 775 | // Flush the buffer to the terminal. 776 | bt.Flush() 777 | 778 | for { 779 | // Move to left edge. 780 | bt.WriteString("\033[0G") 781 | 782 | // Set the prompt. 783 | bt.WriteString(styleBold.Render("> ")) 784 | 785 | // Flush the buffer to the terminal. 786 | bt.Flush() 787 | 788 | // Read up to line from STDIN. 789 | input, err := t.ReadLine() 790 | if err != nil { 791 | return err 792 | } 793 | 794 | // Evaluate the input. 795 | err = builtinCommands.eval(ctx, bt, input) 796 | if err != nil { 797 | return err 798 | } 799 | 800 | // Flush the buffer to the terminal. 801 | bt.Flush() 802 | } 803 | } 804 | 805 | func main() { 806 | ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) 807 | defer cancel() 808 | 809 | if err := startShell(ctx); err != nil { 810 | if err == io.EOF { 811 | os.Exit(0) 812 | } 813 | 814 | fmt.Fprintf(os.Stderr, "error: %v\n", err) 815 | os.Exit(1) 816 | } 817 | } 818 | 819 | // cloneRepository clones a repository and returns the directory it was cloned 820 | // to using go-git under the hood, which is a pure Go implementation of Git. 
821 | func cloneRepository(ctx context.Context, repoURL string) (string, string, error) { 822 | // Parse the repository URL (e.g. https://github.com/picatz/taint). 823 | u, err := url.Parse(repoURL) 824 | if err != nil { 825 | return "", "", fmt.Errorf("%w", err) 826 | } 827 | 828 | // Split the path into segments. 829 | pathSegments := strings.Split(u.Path, "/") 830 | 831 | // Ensure there are at least 2 segments for owner and repo. 832 | if len(pathSegments) < 3 { 833 | return "", "", fmt.Errorf("invalid GitHub URL: %s", repoURL) 834 | } 835 | 836 | // Get the owner and repo part of the URL. 837 | ownerAndRepo := pathSegments[1] + "/" + pathSegments[2] 838 | 839 | // Get the directory path. 840 | dir := filepath.Join(os.TempDir(), "taint", "github", ownerAndRepo) 841 | 842 | // Check if the directory exists. 843 | _, err = os.Stat(dir) 844 | if err == nil { 845 | // If the directory exists, we'll assume it's a valid repository, 846 | // and return the directory. Open the directory to 847 | repo, err := git.PlainOpen(dir) 848 | if err != nil { 849 | return dir, "", fmt.Errorf("%w", err) 850 | } 851 | 852 | // Get the repository's HEAD. 853 | head, err := repo.Head() 854 | if err != nil { 855 | return dir, "", fmt.Errorf("%w", err) 856 | } 857 | 858 | return dir, head.Hash().String(), nil 859 | } 860 | 861 | // Clone the repository. 862 | repo, err := git.PlainCloneContext(ctx, dir, false, &git.CloneOptions{ 863 | URL: repoURL, 864 | Depth: 1, 865 | Tags: git.NoTags, 866 | SingleBranch: true, 867 | }) 868 | if err != nil { 869 | return dir, "", fmt.Errorf("%w", err) 870 | } 871 | 872 | // Get the repository's HEAD. 
873 | head, err := repo.Head() 874 | if err != nil { 875 | return dir, "", fmt.Errorf("%w", err) 876 | } 877 | 878 | return dir, head.Hash().String(), nil 879 | } 880 | -------------------------------------------------------------------------------- /cmd/taint/main_test.go: -------------------------------------------------------------------------------- 1 | package main_test 2 | 3 | import ( 4 | "context" 5 | "go/ast" 6 | "go/parser" 7 | "go/token" 8 | "os" 9 | "testing" 10 | 11 | "github.com/picatz/taint/callgraphutil" 12 | "golang.org/x/tools/go/packages" 13 | "golang.org/x/tools/go/ssa" 14 | "golang.org/x/tools/go/ssa/ssautil" 15 | ) 16 | 17 | func TestLoadAndSearch(t *testing.T) { 18 | loadMode := 19 | packages.NeedName | 20 | packages.NeedDeps | 21 | packages.NeedFiles | 22 | packages.NeedCompiledGoFiles | 23 | packages.NeedModule | 24 | packages.NeedTypes | 25 | packages.NeedImports | 26 | packages.NeedSyntax | 27 | packages.NeedTypesInfo 28 | // packages.NeedTypesSizes | 29 | // packages.NeedExportFile | 30 | // packages.NeedEmbedPatterns 31 | 32 | // parseMode := parser.ParseComments 33 | parseMode := parser.SkipObjectResolution 34 | 35 | // patterns := []string{dir} 36 | patterns := []string{"./..."} 37 | // patterns := []string{"all"} 38 | 39 | pkgs, err := packages.Load(&packages.Config{ 40 | Mode: loadMode, 41 | Context: context.Background(), 42 | Env: os.Environ(), 43 | Dir: "./example", 44 | Tests: false, 45 | ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { 46 | return parser.ParseFile(fset, filename, src, parseMode) 47 | }, 48 | }, patterns...) 49 | if err != nil { 50 | t.Fatal(err) 51 | } 52 | 53 | ssaBuildMode := ssa.InstantiateGenerics // ssa.SanityCheckFunctions | ssa.GlobalDebug 54 | 55 | // Analyze the package. 
56 | ssaProg, ssaPkgs := ssautil.Packages(pkgs, ssaBuildMode) 57 | 58 | ssaProg.Build() 59 | 60 | for _, pkg := range ssaPkgs { 61 | pkg.Build() 62 | } 63 | 64 | mainPkgs := ssautil.MainPackages(ssaPkgs) 65 | 66 | mainFn := mainPkgs[0].Members["main"].(*ssa.Function) 67 | 68 | var srcFns []*ssa.Function 69 | 70 | for _, pkg := range ssaPkgs { 71 | for _, fn := range pkg.Members { 72 | if fn.Object() == nil { 73 | continue 74 | } 75 | 76 | if fn.Object().Name() == "_" { 77 | continue 78 | } 79 | 80 | pkgFn := pkg.Func(fn.Object().Name()) 81 | if pkgFn == nil { 82 | continue 83 | } 84 | 85 | var addAnons func(f *ssa.Function) 86 | addAnons = func(f *ssa.Function) { 87 | srcFns = append(srcFns, f) 88 | for _, anon := range f.AnonFuncs { 89 | addAnons(anon) 90 | } 91 | } 92 | addAnons(pkgFn) 93 | } 94 | } 95 | 96 | if mainFn == nil { 97 | t.Fatal("main function not found") 98 | } 99 | 100 | cg, err := callgraphutil.NewGraph(mainFn, srcFns...) 101 | if err != nil { 102 | t.Fatal(err) 103 | } 104 | 105 | t.Log(cg) 106 | 107 | // path := callgraph.PathSearchCallTo(cg.Root, "(*database/sql.DB).Query") 108 | 109 | // if path == nil { 110 | // t.Fatal("no path found") 111 | // } 112 | 113 | // t.Log(path) 114 | 115 | paths := callgraphutil.PathsSearchCallTo(cg.Root, "(*database/sql.DB).Query") 116 | 117 | if len(paths) == 0 { 118 | t.Fatal("no paths found") 119 | } 120 | 121 | for _, path := range paths { 122 | t.Log(path) 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /cmd/taint/vhs/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/picatz/taint/d6d87c73acf50d1d3110365698cbc98e3cb4cf0a/cmd/taint/vhs/demo.gif -------------------------------------------------------------------------------- /cmd/taint/vhs/demo.tape: -------------------------------------------------------------------------------- 1 | Output ./vhs/demo.gif 2 | 3 | Set Margin 20 4 | 
Set MarginFill "#ffffff" 5 | Set BorderRadius 10 6 | 7 | Set FontSize 20 8 | Set Width 1200 9 | Set Height 600 10 | 11 | Type "taint" 12 | 13 | Sleep 500ms 14 | 15 | Enter 16 | 17 | Sleep 3s 18 | 19 | Type "load ./example" 20 | 21 | Sleep 500ms 22 | 23 | Enter 24 | 25 | Sleep 2s 26 | 27 | Type "n" 28 | 29 | Tab@500ms 2 30 | 31 | Sleep 500ms 32 | 33 | Enter 34 | 35 | Sleep 2s 36 | 37 | Type "p" 38 | 39 | Tab@500ms 2 40 | 41 | Enter 42 | 43 | Sleep 5s 44 | 45 | Type "cg" 46 | 47 | Sleep 1s 48 | 49 | Enter 50 | 51 | Sleep 5s 52 | 53 | Type "callpath (*database/sql.DB).Query" 54 | 55 | Sleep 500ms 56 | 57 | Enter 58 | 59 | Sleep 3s 60 | 61 | Type "check *net/http.Request (*database/sql.DB).Query" 62 | 63 | Sleep 500ms 64 | 65 | Enter 66 | 67 | Sleep 5s 68 | 69 | Type "ex" 70 | 71 | Tab@500ms 2 72 | 73 | Enter 74 | 75 | Sleep 5s -------------------------------------------------------------------------------- /cmd/xss/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "github.com/picatz/taint/xss" 5 | "golang.org/x/tools/go/analysis/singlechecker" 6 | ) 7 | 8 | func main() { 9 | singlechecker.Main(xss.Analyzer) 10 | } 11 | -------------------------------------------------------------------------------- /doc.go: -------------------------------------------------------------------------------- 1 | // Package taint enables "taint checking", a static analysis technique 2 | // for identifying attacker-controlled "sources" used in dangerous 3 | // contexts "sinks". 4 | // 5 | // A classic example of this is identifying SQL injections, 6 | // where user controlled inputs, typically from an HTTP request, 7 | // finds their way into a SQL query without using a prepared statement. 
8 | package taint 9 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/picatz/taint 2 | 3 | go 1.21 4 | 5 | require ( 6 | github.com/charmbracelet/lipgloss v0.9.1 7 | github.com/go-git/go-git/v5 v5.11.0 8 | golang.org/x/term v0.18.0 9 | golang.org/x/tools v0.16.1 10 | ) 11 | 12 | require ( 13 | dario.cat/mergo v1.0.0 // indirect 14 | github.com/Microsoft/go-winio v0.6.1 // indirect 15 | github.com/ProtonMail/go-crypto v0.0.0-20230828082145-3c4c8a2d2371 // indirect 16 | github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect 17 | github.com/cloudflare/circl v1.3.7 // indirect 18 | github.com/cyphar/filepath-securejoin v0.2.4 // indirect 19 | github.com/emirpasic/gods v1.18.1 // indirect 20 | github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect 21 | github.com/go-git/go-billy/v5 v5.5.0 // indirect 22 | github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect 23 | github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect 24 | github.com/kevinburke/ssh_config v1.2.0 // indirect 25 | github.com/lucasb-eyer/go-colorful v1.2.0 // indirect 26 | github.com/mattn/go-isatty v0.0.20 // indirect 27 | github.com/mattn/go-runewidth v0.0.15 // indirect 28 | github.com/muesli/reflow v0.3.0 // indirect 29 | github.com/muesli/termenv v0.15.2 // indirect 30 | github.com/pjbgf/sha1cd v0.3.0 // indirect 31 | github.com/rivo/uniseg v0.4.4 // indirect 32 | github.com/sergi/go-diff v1.1.0 // indirect 33 | github.com/skeema/knownhosts v1.2.1 // indirect 34 | github.com/xanzy/ssh-agent v0.3.3 // indirect 35 | golang.org/x/crypto v0.21.0 // indirect 36 | golang.org/x/mod v0.14.0 // indirect 37 | golang.org/x/net v0.23.0 // indirect 38 | golang.org/x/sys v0.18.0 // indirect 39 | gopkg.in/warnings.v0 v0.1.2 // indirect 40 | ) 41 | 
-------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= 2 | dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= 3 | github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY= 4 | github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow= 5 | github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM= 6 | github.com/ProtonMail/go-crypto v0.0.0-20230828082145-3c4c8a2d2371 h1:kkhsdkhsCvIsutKu5zLMgWtgh9YxGCNAw8Ad8hjwfYg= 7 | github.com/ProtonMail/go-crypto v0.0.0-20230828082145-3c4c8a2d2371/go.mod h1:EjAoLdwvbIOoOQr3ihjnSoLZRtE8azugULFRteWMNc0= 8 | github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= 9 | github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= 10 | github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= 11 | github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= 12 | github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= 13 | github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= 14 | github.com/bwesterb/go-ristretto v1.2.3/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0= 15 | github.com/charmbracelet/lipgloss v0.9.1 h1:PNyd3jvaJbg4jRHKWXnCj1akQm4rh8dbEzN1p/u1KWg= 16 | github.com/charmbracelet/lipgloss v0.9.1/go.mod h1:1mPmG4cxScwUQALAAnacHaigiiHB9Pmr+v1VEawJl6I= 17 | github.com/cloudflare/circl v1.3.3/go.mod h1:5XYMA4rFBvNIrhs50XuiBJ15vF2pZn4nnUKZrLbUZFA= 18 | github.com/cloudflare/circl v1.3.7 
h1:qlCDlTPz2n9fu58M0Nh1J/JzcFpfgkFHHX3O35r5vcU= 19 | github.com/cloudflare/circl v1.3.7/go.mod h1:sRTcRWXGLrKw6yIGJ+l7amYJFfAXbZG0kBSc8r4zxgA= 20 | github.com/cyphar/filepath-securejoin v0.2.4 h1:Ugdm7cg7i6ZK6x3xDF1oEu1nfkyfH53EtKeQYTC3kyg= 21 | github.com/cyphar/filepath-securejoin v0.2.4/go.mod h1:aPGpWjXOXUn2NCNjFvBE6aRxGGx79pTxQpKOJNYHHl4= 22 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 23 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 24 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 25 | github.com/elazarl/goproxy v0.0.0-20230808193330-2592e75ae04a h1:mATvB/9r/3gvcejNsXKSkQ6lcIaNec2nyfOdlTBR2lU= 26 | github.com/elazarl/goproxy v0.0.0-20230808193330-2592e75ae04a/go.mod h1:Ro8st/ElPeALwNFlcTpWmkr6IoMFfkjXAvTHpevnDsM= 27 | github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= 28 | github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= 29 | github.com/gliderlabs/ssh v0.3.5 h1:OcaySEmAQJgyYcArR+gGGTHCyE7nvhEMTlYY+Dp8CpY= 30 | github.com/gliderlabs/ssh v0.3.5/go.mod h1:8XB4KraRrX39qHhT6yxPsHedjA08I/uBVwj4xC+/+z4= 31 | github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 h1:+zs/tPmkDkHx3U66DAb0lQFJrpS6731Oaa12ikc+DiI= 32 | github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376/go.mod h1:an3vInlBmSxCcxctByoQdvwPiA7DTK7jaaFDBTtu0ic= 33 | github.com/go-git/go-billy/v5 v5.5.0 h1:yEY4yhzCDuMGSv83oGxiBotRzhwhNr8VZyphhiu+mTU= 34 | github.com/go-git/go-billy/v5 v5.5.0/go.mod h1:hmexnoNsr2SJU1Ju67OaNz5ASJY3+sHgFRpCtpDCKow= 35 | github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399 h1:eMje31YglSBqCdIqdhKBW8lokaMrL3uTkpGYlE2OOT4= 36 | github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399/go.mod h1:1OCfN199q1Jm3HZlxleg+Dw/mwps2Wbk9frAWm+4FII= 37 | github.com/go-git/go-git/v5 v5.11.0 h1:XIZc1p+8YzypNr34itUfSvYJcv+eYdTnTvOZ2vD3cA4= 38 | 
github.com/go-git/go-git/v5 v5.11.0/go.mod h1:6GFcX2P3NM7FPBfpePbpLd21XxsgdAt+lKqXmCUiUCY= 39 | github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= 40 | github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= 41 | github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= 42 | github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 43 | github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A= 44 | github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99/go.mod h1:1lJo3i6rXxKeerYnT8Nvf0QmHCRC1n8sfWVwXF2Frvo= 45 | github.com/kevinburke/ssh_config v1.2.0 h1:x584FjTGwHzMwvHx18PXxbBVzfnxogHaAReU4gf13a4= 46 | github.com/kevinburke/ssh_config v1.2.0/go.mod h1:CT57kijsi8u/K/BOFA39wgDQJ9CxiF4nAY/ojJ6r6mM= 47 | github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= 48 | github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= 49 | github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= 50 | github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= 51 | github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 52 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 53 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 54 | github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= 55 | github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= 56 | github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= 57 | github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= 58 | github.com/mattn/go-runewidth v0.0.12/go.mod 
h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk= 59 | github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= 60 | github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= 61 | github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s= 62 | github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8= 63 | github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo= 64 | github.com/muesli/termenv v0.15.2/go.mod h1:Epx+iuz8sNs7mNKhxzH4fWXGNpZwUaJKRS1noLXviQ8= 65 | github.com/onsi/gomega v1.27.10 h1:naR28SdDFlqrG6kScpT8VWpu1xWY5nJRCF3XaYyBjhI= 66 | github.com/onsi/gomega v1.27.10/go.mod h1:RsS8tutOdbdgzbPtzzATp12yT7kM5I5aElG3evPbQ0M= 67 | github.com/pjbgf/sha1cd v0.3.0 h1:4D5XXmUUBUl/xQ6IjCkEAbqXskkq/4O7LmGn0AqMDs4= 68 | github.com/pjbgf/sha1cd v0.3.0/go.mod h1:nZ1rrWOcGJ5uZgEEVL1VUM9iRQiZvWdbZjkKyFzPPsI= 69 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 70 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 71 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 72 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 73 | github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= 74 | github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= 75 | github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= 76 | github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= 77 | github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= 78 | github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= 79 | github.com/sergi/go-diff v1.1.0 h1:we8PVUC3FE2uYfodKH/nBHMSetSfHDR6scGdBi+erh0= 80 | github.com/sergi/go-diff v1.1.0/go.mod 
h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= 81 | github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= 82 | github.com/skeema/knownhosts v1.2.1 h1:SHWdIUa82uGZz+F+47k8SY4QhhI291cXCpopT1lK2AQ= 83 | github.com/skeema/knownhosts v1.2.1/go.mod h1:xYbVRSPxqBZFrdmDyMmsOs+uX1UZC3nTN3ThzgDxUwo= 84 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 85 | github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 86 | github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= 87 | github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= 88 | github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= 89 | github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM= 90 | github.com/xanzy/ssh-agent v0.3.3/go.mod h1:6dzNDKs0J9rVPHPhaGCukekBHKqfl+L3KghI1Bc68Uw= 91 | github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= 92 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 93 | golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= 94 | golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= 95 | golang.org/x/crypto v0.3.1-0.20221117191849-2c476679df9a/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4= 96 | golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= 97 | golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= 98 | golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= 99 | golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= 100 | golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= 101 | 
golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0= 102 | golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= 103 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 104 | golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= 105 | golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= 106 | golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= 107 | golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= 108 | golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= 109 | golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= 110 | golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= 111 | golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= 112 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 113 | golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 114 | golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 115 | golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= 116 | golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 117 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 118 | golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 119 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 120 | golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 121 | golang.org/x/sys 
v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 122 | golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 123 | golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 124 | golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 125 | golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 126 | golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 127 | golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 128 | golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 129 | golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 130 | golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= 131 | golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 132 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 133 | golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= 134 | golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= 135 | golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= 136 | golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= 137 | golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8= 138 | golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= 139 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 140 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 141 | golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 142 | golang.org/x/text v0.3.7/go.mod 
h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= 143 | golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= 144 | golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= 145 | golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= 146 | golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= 147 | golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= 148 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 149 | golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= 150 | golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= 151 | golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= 152 | golang.org/x/tools v0.16.1 h1:TLyB3WofjdOEepBHAU20JdNC1Zbg87elYofWYAY5oZA= 153 | golang.org/x/tools v0.16.1/go.mod h1:kYVVN6I1mBNoB1OX+noeBjbRk4IUEPa7JJ+TJMEooJ0= 154 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 155 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 156 | gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 157 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 158 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= 159 | gopkg.in/warnings.v0 v0.1.2 h1:wFXVbFY8DY5/xOe1ECiWdKCzZlxgshcYVNkBHstARME= 160 | gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= 161 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 162 | gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 163 | gopkg.in/yaml.v3 v3.0.1 
h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 164 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 165 | -------------------------------------------------------------------------------- /log/injection/injection.go: -------------------------------------------------------------------------------- 1 | package injection 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "github.com/picatz/taint" 8 | "github.com/picatz/taint/callgraphutil" 9 | 10 | "golang.org/x/tools/go/analysis" 11 | "golang.org/x/tools/go/analysis/passes/buildssa" 12 | ) 13 | 14 | var userControlledValues = taint.NewSources( 15 | "*net/http.Request", 16 | ) 17 | 18 | var injectableLogFunctions = taint.NewSinks( 19 | // Note: at this time, they *must* be a function or method. 20 | "log.Fatal", 21 | "log.Fatalf", 22 | "log.Fatalln", 23 | "log.Panic", 24 | "log.Panicf", 25 | "log.Panicln", 26 | "log.Print", 27 | "log.Printf", 28 | "log.Println", 29 | "log.Output", 30 | "log.SetOutput", 31 | "log.SetPrefix", 32 | "log.Writer", 33 | "(*log.Logger).Fatal", 34 | "(*log.Logger).Fatalf", 35 | "(*log.Logger).Fatalln", 36 | "(*log.Logger).Panic", 37 | "(*log.Logger).Panicf", 38 | "(*log.Logger).Panicln", 39 | "(*log.Logger).Print", 40 | "(*log.Logger).Printf", 41 | "(*log.Logger).Println", 42 | "(*log.Logger).Output", 43 | "(*log.Logger).SetOutput", 44 | "(*log.Logger).SetPrefix", 45 | "(*log.Logger).Writer", 46 | 47 | // log/slog (structured logging) 48 | // https://pkg.go.dev/log/slog 49 | "log/slog.Debug", 50 | "log/slog.DebugContext", 51 | "log/slog.Error", 52 | "log/slog.ErrorContext", 53 | "log/slog.Info", 54 | "log/slog.InfoContext", 55 | "log/slog.Warn", 56 | "log/slog.WarnContext", 57 | "log/slog.Log", 58 | "log/slog.LogAttrs", 59 | "(*log/slog.Logger).With", 60 | "(*log/slog.Logger).Debug", 61 | "(*log/slog.Logger).DebugContext", 62 | "(*log/slog.Logger).Error", 63 | "(*log/slog.Logger).ErrorContext", 64 | "(*log/slog.Logger).Info", 65 | 
"(*log/slog.Logger).InfoContext", 66 | "(*log/slog.Logger).Warn", 67 | "(*log/slog.Logger).WarnContext", 68 | "(*log/slog.Logger).Log", 69 | "(*log/slog.Logger).LogAttrs", 70 | "log/slog.NewRecord", 71 | "(*log/slog.Record).Add", 72 | "(*log/slog.Record).AddAttrs", 73 | 74 | // TODO: consider adding the following logger packages, 75 | // and the ability to configure this list generically. 76 | // 77 | // https://pkg.go.dev/golang.org/x/exp/slog 78 | // https://pkg.go.dev/github.com/golang/glog 79 | // https://pkg.go.dev/github.com/hashicorp/go-hclog 80 | // https://pkg.go.dev/github.com/sirupsen/logrus 81 | // https://pkg.go.dev/go.uber.org/zap 82 | // ... 83 | ) 84 | 85 | // Analyzer finds potential log injection issues to demonstrate 86 | // the github.com/picatz/taint package. 87 | var Analyzer = &analysis.Analyzer{ 88 | Name: "logi", 89 | Doc: "finds potential log injection issues", 90 | Run: run, 91 | Requires: []*analysis.Analyzer{buildssa.Analyzer}, 92 | } 93 | 94 | // imports returns true if the package imports any of the given packages. 95 | func imports(pass *analysis.Pass, pkgs ...string) bool { 96 | var imported bool 97 | for _, imp := range pass.Pkg.Imports() { 98 | for _, pkg := range pkgs { 99 | if strings.HasSuffix(imp.Path(), pkg) { 100 | imported = true 101 | break 102 | } 103 | } 104 | if imported { 105 | break 106 | } 107 | } 108 | return imported 109 | } 110 | 111 | func run(pass *analysis.Pass) (interface{}, error) { 112 | // Require the log package is imported in the 113 | // program being analyzed before running the analysis. 114 | // 115 | // This prevents wasting time analyzing programs that don't log. 116 | if !imports(pass, "log", "log/slog") { 117 | return nil, nil 118 | } 119 | 120 | // Get the built SSA IR. 121 | buildSSA := pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA) 122 | 123 | // Identify the main function from the package's SSA IR. 
124 | mainFn := buildSSA.Pkg.Func("main") 125 | if mainFn == nil { 126 | return nil, nil 127 | } 128 | 129 | // Construct a callgraph, using the main function as the root, 130 | // constructed of all other functions. This returns a callgraph 131 | // we can use to identify directed paths to logging functions. 132 | cg, err := callgraphutil.NewGraph(mainFn, buildSSA.SrcFuncs...) 133 | if err != nil { 134 | return nil, fmt.Errorf("failed to create new callgraph: %w", err) 135 | } 136 | 137 | // Run taint check for user controlled values (sources) ending 138 | // up in injectable log functions (sinks). 139 | results := taint.Check(cg, userControlledValues, injectableLogFunctions) 140 | 141 | // Report every result at its sink call site. No mitigation 142 | // (e.g. sanitization of the user controlled value) is checked 143 | // here, so each tainted flow into a log sink is reported. 144 | // TODO: consider recognizing sanitizers before reporting. 145 | for _, result := range results { 146 | pass.Reportf(result.SinkValue.Pos(), "potential log injection") 147 | } 148 | 149 | return nil, nil 150 | } 151 | -------------------------------------------------------------------------------- /log/injection/injection_test.go: -------------------------------------------------------------------------------- 1 | package injection 2 | 3 | import ( 4 | "testing" 5 | 6 | "golang.org/x/tools/go/analysis/analysistest" 7 | ) 8 | 9 | var testdata = analysistest.TestData() 10 | 11 | func TestA(t *testing.T) { 12 | analysistest.Run(t, testdata, Analyzer, "a") 13 | } 14 | 15 | func TestB(t *testing.T) { 16 | analysistest.Run(t, testdata, Analyzer, "b") 17 | } 18 | 19 | func TestC(t *testing.T) { 20 | analysistest.Run(t, testdata, Analyzer, "c") 21 | } 22 | 23 | func TestD(t *testing.T) { 24 | analysistest.Run(t, testdata, Analyzer, "d") 25 | } 26 | 27 | func TestE(t *testing.T) { 28 | analysistest.Run(t, testdata, Analyzer, "e") 29 | } 30 | 31 | func TestF(t *testing.T) { 32 | analysistest.Run(t, testdata, Analyzer, "f") 33 | } 34 | 35 | func TestG(t *testing.T)
{ 36 | analysistest.Run(t, testdata, Analyzer, "g") 37 | } 38 | -------------------------------------------------------------------------------- /log/injection/testdata/src/a/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log" 5 | "net/http" 6 | ) 7 | 8 | func main() { 9 | http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 10 | log.Println(r.URL.Query().Get("input")) // want "potential log injection" 11 | }) 12 | 13 | http.ListenAndServe(":8080", nil) 14 | } 15 | -------------------------------------------------------------------------------- /log/injection/testdata/src/b/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log" 5 | "net/http" 6 | ) 7 | 8 | func l(input string) { 9 | l := log.New(nil, "", 0) 10 | l.Println(input) // want "potential log injection" 11 | } 12 | 13 | func buisness(input string) { 14 | l(input) 15 | } 16 | 17 | func main() { 18 | http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 19 | input := r.URL.Query().Get("input") 20 | 21 | buisness(input) 22 | }) 23 | 24 | http.ListenAndServe(":8080", nil) 25 | } 26 | -------------------------------------------------------------------------------- /log/injection/testdata/src/c/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log" 5 | "net/http" 6 | ) 7 | 8 | func l(input string) { 9 | log.Println(input) // want "potential log injection" 10 | } 11 | 12 | func buisness(input string) { 13 | l(input) 14 | } 15 | 16 | func main() { 17 | http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 18 | input := r.URL.Query().Get("input") 19 | 20 | f := func() { 21 | buisness(input) 22 | } 23 | 24 | f() 25 | }) 26 | 27 | http.ListenAndServe(":8080", nil) 28 | } 29 | 
-------------------------------------------------------------------------------- /log/injection/testdata/src/d/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "log/slog" 5 | "net/http" 6 | ) 7 | 8 | func l(input string) { 9 | slog.Info(input) // want "potential log injection" 10 | } 11 | 12 | func buisness(input string) { 13 | l(input) 14 | } 15 | 16 | func main() { 17 | http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 18 | input := r.URL.Query().Get("input") 19 | 20 | f := func() { 21 | buisness(input) 22 | } 23 | 24 | f() 25 | }) 26 | 27 | http.ListenAndServe(":8080", nil) 28 | } 29 | -------------------------------------------------------------------------------- /log/injection/testdata/src/e/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "log/slog" 6 | "net/http" 7 | "os" 8 | ) 9 | 10 | var logger = slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ 11 | Level: slog.LevelInfo, 12 | })) 13 | 14 | func l(input string) { 15 | logger.InfoContext(context.Background(), "l", "input", input) // want "potential log injection" 16 | } 17 | 18 | func buisness(input string) { 19 | l(input) 20 | } 21 | 22 | func main() { 23 | http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 24 | input := r.URL.Query().Get("input") 25 | 26 | f := func() { 27 | buisness(input) 28 | } 29 | 30 | f() 31 | }) 32 | 33 | http.ListenAndServe(":8080", nil) 34 | } 35 | -------------------------------------------------------------------------------- /log/injection/testdata/src/f/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "log/slog" 6 | "net/http" 7 | "os" 8 | ) 9 | 10 | var logger = slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ 11 | Level: slog.LevelInfo, 12 | })) 13 | 
14 | func l(input string) { 15 | logger.InfoContext(context.Background(), "l", "input", map[string]string{"value": input}) // want "potential log injection" 16 | } 17 | 18 | func buisness(input string) { 19 | l(input) 20 | } 21 | 22 | func main() { 23 | http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 24 | input := r.URL.Query().Get("input") 25 | 26 | f := func() { 27 | buisness(input) 28 | } 29 | 30 | f() 31 | }) 32 | 33 | http.ListenAndServe(":8080", nil) 34 | } 35 | -------------------------------------------------------------------------------- /log/injection/testdata/src/g/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "log/slog" 6 | "net/http" 7 | "os" 8 | ) 9 | 10 | func l(logger *slog.Logger, input string) { 11 | logger2 := logger.With("input", input).WithGroup("l") // want "potential log injection" 12 | 13 | logger2.InfoContext(context.Background(), "l", "input", []string{input}) // want "potential log injection" 14 | } 15 | 16 | func buisness(logger *slog.Logger, input string) { 17 | l(logger, input) 18 | } 19 | 20 | func main() { 21 | logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ 22 | Level: slog.LevelInfo, 23 | })) 24 | 25 | http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 26 | input := r.URL.Query().Get("input") 27 | 28 | f := func() { 29 | buisness(logger, input) 30 | } 31 | 32 | f() 33 | }) 34 | 35 | http.ListenAndServe(":8080", nil) 36 | } 37 | -------------------------------------------------------------------------------- /sources_sinks.go: -------------------------------------------------------------------------------- 1 | package taint 2 | 3 | import ( 4 | "golang.org/x/tools/go/ssa" 5 | ) 6 | 7 | // valueSet is a set of ssa.Values that can be used to track 8 | // the values that have been visited during a traversal. 
This 9 | // is used to prevent infinite recursion, and to prevent 10 | // visiting the same value multiple times. 11 | type valueSet map[ssa.Value]struct{} 12 | 13 | // includes returns true if the value is in the set. 14 | func (v valueSet) includes(sv ssa.Value) bool { 15 | if v == nil { 16 | return false 17 | } 18 | _, ok := v[sv] 19 | return ok 20 | } 21 | 22 | // add adds the value to the set. 23 | func (v valueSet) add(sv ssa.Value) { 24 | if v == nil { 25 | v = valueSet{} 26 | } 27 | v[sv] = struct{}{} 28 | } 29 | 30 | // stringSet is a set of unique strings that express 31 | // the types of sources and sinks that are being 32 | // tracked. 33 | type stringSet map[string]struct{} 34 | 35 | // includes returns true if the string is in the set. 36 | func (t stringSet) includes(str string) (string, bool) { 37 | if t == nil { 38 | return "", false 39 | } 40 | _, ok := t[str] 41 | return str, ok 42 | } 43 | 44 | // Sources are the types that are considered "sources" of 45 | // tainted data in the program. 46 | type Sources = stringSet 47 | 48 | // NewSources returns a new Sources set with the given 49 | // source types. 50 | func NewSources(sourceTypes ...string) Sources { 51 | srcs := Sources{} 52 | 53 | for _, src := range sourceTypes { 54 | srcs[src] = struct{}{} 55 | } 56 | 57 | return srcs 58 | } 59 | 60 | // Sinks are the types that are considered "sinks" that 61 | // tainted data in the program may flow into. 62 | type Sinks = stringSet 63 | 64 | // NewSinks returns a new Sinks set with the given 65 | // sink types. 
66 | func NewSinks(sinkTypes ...string) Sinks { 67 | snks := Sinks{} 68 | 69 | for _, snk := range sinkTypes { 70 | snks[snk] = struct{}{} 71 | } 72 | 73 | return snks 74 | } 75 | -------------------------------------------------------------------------------- /sql/injection/injection.go: -------------------------------------------------------------------------------- 1 | package injection 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "github.com/picatz/taint" 8 | "github.com/picatz/taint/callgraphutil" 9 | 10 | "golang.org/x/tools/go/analysis" 11 | "golang.org/x/tools/go/analysis/passes/buildssa" 12 | "golang.org/x/tools/go/ssa" 13 | ) 14 | 15 | // userControlledValues are the sources of user controlled values that 16 | // can be tainted and end up in a SQL query. 17 | var userControlledValues = taint.NewSources( 18 | // Function (and method) calls 19 | // "(net/url.Values).Get", 20 | // "(*net/url.URL).Query", 21 | // "(*net/url.URL).Redacted", 22 | // "(*net/url.URL).EscapedFragment", 23 | // "(*net/url.Userinfo).Username", 24 | // "(*net/url.Userinfo).Password", 25 | // "(*net/url.Userinfo).String", 26 | // "(*net/http.Request).FormFile", 27 | // "(*net/http.Request).FormValue", 28 | // "(*net/http.Request).PostFormValue", 29 | // "(*net/http.Request).Referer", 30 | // "(*net/http.Request).UserAgent", 31 | // "(*net/http.Request).GetBody", 32 | // "(net/http.Header).Get", 33 | // "(net/http.Header).Values", 34 | // 35 | // Types (and fields) 36 | "*net/http.Request", 37 | // 38 | // "google.golang.org/grpc/metadata.MD", ? 39 | // 40 | // TODO: add more, consider pointer variants and specific fields on types 41 | // TODO: consider support for protobuf defined *Request types... 42 | // TODO: consider support for gRPC request metadata (HTTP2 headers) 43 | // TODO: consider support for msgpack-rpc? 44 | ) 45 | 46 | var injectableSQLMethods = taint.NewSinks( 47 | // Note: at this time, they *must* be a function or method.
48 | "(*database/sql.DB).Query", 49 | "(*database/sql.DB).QueryContext", 50 | "(*database/sql.DB).QueryRow", 51 | "(*database/sql.DB).QueryRowContext", 52 | "(*database/sql.Tx).Query", 53 | "(*database/sql.Tx).QueryContext", 54 | "(*database/sql.Tx).QueryRow", 55 | "(*database/sql.Tx).QueryRowContext", 56 | // GORM v1 57 | // https://gorm.io/docs/security.html 58 | // https://gorm.io/docs/security.html#SQL-injection-Methods 59 | "(*github.com/jinzhu/gorm.DB).Where", 60 | "(*github.com/jinzhu/gorm.DB).Or", 61 | "(*github.com/jinzhu/gorm.DB).Not", 62 | "(*github.com/jinzhu/gorm.DB).Group", 63 | "(*github.com/jinzhu/gorm.DB).Having", 64 | "(*github.com/jinzhu/gorm.DB).Joins", 65 | "(*github.com/jinzhu/gorm.DB).Select", 66 | "(*github.com/jinzhu/gorm.DB).Distinct", 67 | "(*github.com/jinzhu/gorm.DB).Pluck", 68 | "(*github.com/jinzhu/gorm.DB).Raw", 69 | "(*github.com/jinzhu/gorm.DB).Exec", 70 | "(*github.com/jinzhu/gorm.DB).Order", 71 | // 72 | // TODO: add more, consider (non-)pointer variants? 73 | ) 74 | 75 | // Analyzer finds potential SQL injection issues to demonstrate 76 | // the github.com/picatz/taint package. 77 | var Analyzer = &analysis.Analyzer{ 78 | Name: "sqli", 79 | Doc: "finds potential SQL injection issues", 80 | Run: run, 81 | Requires: []*analysis.Analyzer{buildssa.Analyzer}, 82 | } 83 | 84 | // imports returns true if the package imports any of the given packages. 85 | func imports(pass *analysis.Pass, pkgs ...string) bool { 86 | var imported bool 87 | for _, imp := range pass.Pkg.Imports() { 88 | for _, pkg := range pkgs { 89 | if strings.HasSuffix(imp.Path(), pkg) { 90 | imported = true 91 | break 92 | } 93 | } 94 | if imported { 95 | break 96 | } 97 | } 98 | return imported 99 | } 100 | 101 | func run(pass *analysis.Pass) (interface{}, error) { 102 | // Require the database/sql or GORM v1 packages are imported in the 103 | // program being analyzed before running the analysis. 
104 | // 105 | // This prevents wasting time analyzing programs that don't use SQL. 106 | if !imports(pass, "database/sql", "github.com/jinzhu/gorm") { 107 | return nil, nil 108 | } 109 | 110 | // Get the built SSA IR. 111 | buildSSA := pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA) 112 | 113 | // Identify the main function from the package's SSA IR. 114 | mainFn := buildSSA.Pkg.Func("main") 115 | if mainFn == nil { 116 | return nil, nil 117 | } 118 | 119 | // Construct a callgraph, using the main function as the root, 120 | // constructed of all other functions. This returns a callgraph 121 | // we can use to identify directed paths to SQL queries. 122 | cg, err := callgraphutil.NewGraph(mainFn, buildSSA.SrcFuncs...) 123 | if err != nil { 124 | return nil, fmt.Errorf("failed to create new callgraph: %w", err) 125 | } 126 | 127 | // If you'd like to compare the callgraph constructed by the 128 | // callgraphutil package to the one constructed by others 129 | // (e.g. pointer analysis, rta, cha, static, etc), uncomment the 130 | // following lines and compare the output. 131 | // 132 | // Today, I believe the callgraphutil package is the most 133 | // accurate, but I'd love to be proven wrong. 
134 | 135 | // Note: this actually panics for testcase b 136 | // ptares, err := pointer.Analyze(&pointer.Config{ 137 | // Mains: []*ssa.Package{buildSSA.Pkg}, 138 | // BuildCallGraph: true, 139 | // }) 140 | // if err != nil { 141 | // return nil, fmt.Errorf("failed to create new callgraph using pointer analysis: %w", err) 142 | // } 143 | // cg := ptares.CallGraph 144 | 145 | // cg := rta.Analyze([]*ssa.Function{mainFn}, true).CallGraph 146 | // cg := cha.CallGraph(buildSSA.Pkg.Prog) 147 | // cg := static.CallGraph(buildSSA.Pkg.Prog) 148 | 149 | // https://github.com/golang/vuln/blob/7335627909c99e391cf911fcd214badcb8aa6d7d/internal/vulncheck/utils.go#L61 150 | // cg, err := callgraphutil.NewVulncheckCallGraph(context.Background(), buildSSA.Pkg.Prog, buildSSA.SrcFuncs) 151 | // if err != nil { 152 | // return nil, err 153 | // } 154 | // cg.Root = cg.CreateNode(mainFn) 155 | 156 | // fmt.Println(callgraphutil.CallGraphString(cg)) 157 | 158 | // Run taint check for user controlled values (sources) ending 159 | // up in injectable SQL methods (sinks). 160 | results := taint.Check(cg, userControlledValues, injectableSQLMethods) 161 | 162 | // For each result, check if a prepared statement is providing 163 | // a mitigation for the user controlled value. 164 | // 165 | // TODO: ensure this makes sense for all the GORM usage? 166 | for _, result := range results { 167 | // We found a query edge that is tainted by user input, is it 168 | // doing this safely? We expect this to be safely done by 169 | // providing a prepared statement as a constant in the query 170 | // (first argument after context). 171 | queryEdge := result.Path[len(result.Path)-1] 172 | 173 | // Get the query arguments, skipping the first element, pointer to the DB. 174 | queryArgs := queryEdge.Site.Common().Args[1:] 175 | 176 | // Skip the context argument, if using a *Context query variant.
177 | if strings.HasSuffix(queryEdge.Site.Common().Value.String(), "Context") { // callee names end in "Context", e.g. "(*database/sql.DB).QueryContext" 178 | queryArgs = queryArgs[1:] 179 | } 180 | 181 | // Get the query function parameter. 182 | query := queryArgs[0] 183 | 184 | // Ensure it is a constant (prepared statement), otherwise report 185 | // potential SQL injection. 186 | if _, isConst := query.(*ssa.Const); !isConst { 187 | pass.Reportf(result.SinkValue.Pos(), "potential sql injection") 188 | } 189 | } 190 | 191 | return nil, nil 192 | } 193 | -------------------------------------------------------------------------------- /sql/injection/injection_test.go: -------------------------------------------------------------------------------- 1 | package injection 2 | 3 | import ( 4 | "testing" 5 | 6 | "golang.org/x/tools/go/analysis/analysistest" 7 | ) 8 | 9 | var testdata = analysistest.TestData() 10 | 11 | func TestA(t *testing.T) { 12 | analysistest.Run(t, testdata, Analyzer, "a") 13 | } 14 | 15 | func TestB(t *testing.T) { 16 | analysistest.Run(t, testdata, Analyzer, "b") 17 | } 18 | 19 | func TestC(t *testing.T) { 20 | analysistest.Run(t, testdata, Analyzer, "c") 21 | } 22 | 23 | func TestD(t *testing.T) { 24 | analysistest.Run(t, testdata, Analyzer, "d") 25 | } 26 | 27 | func TestE(t *testing.T) { 28 | analysistest.Run(t, testdata, Analyzer, "e") 29 | } 30 | 31 | func TestF(t *testing.T) { 32 | analysistest.Run(t, testdata, Analyzer, "f") 33 | } 34 | 35 | func TestG(t *testing.T) { 36 | analysistest.Run(t, testdata, Analyzer, "g") 37 | } 38 | 39 | // TODO: this is not worked out yet 40 | func TestH(t *testing.T) { 41 | // t.Skip("skipping known failing test for now") 42 | analysistest.Run(t, testdata, Analyzer, "h") 43 | } 44 | 45 | func TestExample(t *testing.T) { 46 | analysistest.Run(t, testdata, Analyzer, "example") 47 | } 48 | 49 | func TestI(t *testing.T) { 50 | analysistest.Run(t, testdata, Analyzer, "i") 51 | } 52 | 
--------------------------------------------------------------------------------
/sql/injection/testdata/src/a/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "database/sql" 5 | "os" 6 | ) 7 | 8 | func main() { 9 | db, err := sql.Open("sqlite3", ":memory:") 10 | if err != nil { 11 | panic(err) 12 | } 13 | rows, err := db.Query("SELECT * FROM foo where name=?", os.Args[1]) 14 | if err != nil { 15 | panic(err) 16 | } 17 | defer rows.Close() 18 | } 19 | -------------------------------------------------------------------------------- /sql/injection/testdata/src/b/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "net/http" 7 | "net/url" 8 | ) 9 | 10 | func business(db *sql.DB, name string) error { 11 | q := fmt.Sprintf("SELECT * FROM voo where name='%s'", name) 12 | _, err := db.Query(q) // want "potential sql injection" 13 | if err != nil { 14 | return err 15 | } 16 | return nil 17 | } 18 | 19 | func business2(db *sql.DB, name string) error { 20 | q := "SELECT * FROM roo where name='" + name + "'" 21 | _, err := db.Query(q) // want "potential sql injection" 22 | if err != nil { 23 | return err 24 | } 25 | return nil 26 | } 27 | 28 | func business3(db *sql.DB, query string) error { 29 | _, err := db.Query(query) // want "potential sql injection" 30 | if err != nil { 31 | return err 32 | } 33 | return nil 34 | } 35 | 36 | func business4(db *sql.DB, query string) error { 37 | _, err := db.Query(query) 38 | if err != nil { 39 | return err 40 | } 41 | return nil 42 | } 43 | 44 | type logic struct { 45 | name string 46 | } 47 | 48 | func business5(db *sql.DB, l logic) error { 49 | _, err := db.Query(l.name) // want "potential sql injection" 50 | if err != nil { 51 | return err 52 | } 53 | return nil 54 | } 55 | 56 | func business6(db *sql.DB, u url.Values) error { 57 | _, err := db.Query(u.Get("query")) // want "potential sql injection" 58 | if err != nil { 59 | return 
err 60 | } 61 | return nil 62 | } 63 | 64 | func business7(db *sql.DB, q string) error { 65 | _, err := db.Query(q) // want "potential sql injection" 66 | if err != nil { 67 | return err 68 | } 69 | return nil 70 | } 71 | 72 | func realMain() { 73 | db, err := sql.Open("sqlite3", ":memory:") 74 | if err != nil { 75 | panic(err) 76 | } 77 | mux := http.NewServeMux() 78 | 79 | mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 80 | q := fmt.Sprintf("SELECT * FROM foo where nameo='%s'", r.URL.Query().Get("name")) 81 | rows, err := db.Query(q) // want "potential sql injection" 82 | if err != nil { 83 | http.Error(w, err.Error(), http.StatusInternalServerError) 84 | return 85 | } 86 | w.Write([]byte(fmt.Sprintf("%#+v", rows))) 87 | }) 88 | 89 | mux.HandleFunc("/bar", func(w http.ResponseWriter, r *http.Request) { 90 | rows, err := db.Query(fmt.Sprintf("SELECT * FROM bar where name='%s'", r.URL.Query().Get("name"))) // want "potential sql injection" 91 | if err != nil { 92 | http.Error(w, err.Error(), http.StatusInternalServerError) 93 | return 94 | } 95 | w.Write([]byte(fmt.Sprintf("%#+v", rows))) 96 | }) 97 | 98 | mux.HandleFunc("/baz", func(w http.ResponseWriter, r *http.Request) { 99 | name := r.URL.Query().Get("name") 100 | business(db, name) 101 | business2(db, r.URL.Query().Get("name2")) 102 | }) 103 | 104 | mux.HandleFunc("/boo", func(w http.ResponseWriter, r *http.Request) { 105 | name := r.URL.Query().Get("query") 106 | name2 := name 107 | business3(db, name2) 108 | 109 | if r.Form.Get("lol") != "" { 110 | business4(db, "SELECT * FROM lol where name='lol'") 111 | } else { 112 | _, err := db.Query("SELECT * FROM lol where name=?", name) 113 | if err != nil { 114 | panic(err) 115 | } 116 | } 117 | 118 | r.URL.User.Password() 119 | 120 | business5(db, logic{name: name2}) 121 | 122 | business6(db, r.URL.Query()) 123 | 124 | pass, _ := r.URL.User.Password() 125 | business7(db, pass) 126 | }) 127 | } 128 | 129 | func main() { 130 | realMain() 131 | } 
132 | -------------------------------------------------------------------------------- /sql/injection/testdata/src/c/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "database/sql" 5 | "net/http" 6 | ) 7 | 8 | func business(db *sql.DB, q *string) error { 9 | _, err := db.Query(*q) // want "potential sql injection" 10 | if err != nil { 11 | return err 12 | } 13 | return nil 14 | } 15 | 16 | func realMain() { 17 | db, err := sql.Open("sqlite3", ":memory:") 18 | if err != nil { 19 | panic(err) 20 | } 21 | mux := http.NewServeMux() 22 | 23 | mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 24 | pass, _ := r.URL.User.Password() 25 | business(db, func() *string { 26 | return &pass 27 | }()) 28 | }) 29 | } 30 | 31 | func main() { 32 | realMain() 33 | } 34 | -------------------------------------------------------------------------------- /sql/injection/testdata/src/d/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "database/sql" 5 | "net/http" 6 | ) 7 | 8 | func handle(db *sql.DB, q string) { 9 | db.Query(q) // want "potential sql injection" 10 | } 11 | 12 | func business(db *sql.DB, q *string) error { 13 | handle(db, *q) 14 | return nil 15 | } 16 | 17 | func realMain() { 18 | db, _ := sql.Open("sqlite3", ":memory:") 19 | 20 | mux := http.NewServeMux() 21 | 22 | mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 23 | user := r.URL.Query()["query"] 24 | func() { 25 | userValue := user[0] 26 | business(db, func() *string { 27 | return &userValue 28 | }()) 29 | }() 30 | }) 31 | } 32 | 33 | func main() { 34 | realMain() 35 | } 36 | -------------------------------------------------------------------------------- /sql/injection/testdata/src/e/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "database/sql" 5 | 
"net/http" 6 | ) 7 | 8 | func handle(db *sql.DB, q string) { 9 | db.Query(q) // want "potential sql injection" 10 | } 11 | 12 | func business(db *sql.DB, q *string) error { 13 | handle(db, *q) 14 | return nil 15 | } 16 | 17 | func realMain() { 18 | db, _ := sql.Open("sqlite3", ":memory:") 19 | 20 | mux := http.NewServeMux() 21 | 22 | mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 23 | user := r.FormValue("query") 24 | func() { 25 | userValue := string(user) 26 | business(db, func() *string { 27 | return &userValue 28 | }()) 29 | }() 30 | }) 31 | } 32 | 33 | func main() { 34 | realMain() 35 | } 36 | -------------------------------------------------------------------------------- /sql/injection/testdata/src/example/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "database/sql" 5 | "net/http" 6 | ) 7 | 8 | func business(db *sql.DB, q string) { 9 | db.Query(q) // want "potential sql injection" 10 | } 11 | 12 | func run() { 13 | db, _ := sql.Open("sqlite3", ":memory:") 14 | 15 | mux := http.NewServeMux() 16 | 17 | mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 18 | business(db, r.URL.Query().Get("sql-query")) 19 | }) 20 | 21 | http.ListenAndServe(":8080", mux) 22 | } 23 | 24 | func main() { 25 | run() 26 | } 27 | -------------------------------------------------------------------------------- /sql/injection/testdata/src/f/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "database/sql" 5 | "fmt" 6 | "net/http" 7 | ) 8 | 9 | func handle(db *sql.DB, u string) { 10 | q := fmt.Sprintf("SELECT * FROM voo where url='%s'", u) 11 | db.Query(q) // want "potential sql injection" 12 | } 13 | 14 | func business(db *sql.DB, q *string) error { 15 | handle(db, *q) 16 | return nil 17 | } 18 | 19 | func realMain() { 20 | db, _ := sql.Open("sqlite3", ":memory:") 21 | 22 | mux := 
http.NewServeMux() 23 | 24 | mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 25 | u := r.URL 26 | func() { 27 | userValue := fmt.Sprintf("%s", u) 28 | business(db, func() *string { 29 | return &userValue 30 | }()) 31 | }() 32 | }) 33 | } 34 | 35 | func main() { 36 | realMain() 37 | } 38 | -------------------------------------------------------------------------------- /sql/injection/testdata/src/g/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | 7 | "github.com/jinzhu/gorm" 8 | ) 9 | 10 | type User struct { 11 | gorm.Model `json:"model"` 12 | Name string `json:"name"` 13 | Email string `json:"email"` 14 | } 15 | 16 | func handle(db *gorm.DB, u string) { 17 | q := fmt.Sprintf("url='%s'", u) 18 | 19 | var users []User 20 | db.Where(q).Find(&users) // want "potential sql injection" 21 | fmt.Println() 22 | } 23 | 24 | func business(db *gorm.DB, q *string) error { 25 | handle(db, *q) 26 | return nil 27 | } 28 | 29 | func realMain() { 30 | db, _ := gorm.Open("sqlite3", ":memory:") 31 | 32 | mux := http.NewServeMux() 33 | 34 | mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 35 | u := r.URL 36 | func() { 37 | userValue := fmt.Sprintf("%s", u) 38 | business(db, func() *string { 39 | return &userValue 40 | }()) 41 | }() 42 | }) 43 | } 44 | 45 | func main() { 46 | realMain() 47 | } 48 | -------------------------------------------------------------------------------- /sql/injection/testdata/src/github.com/jinzhu/gorm/mock.go: -------------------------------------------------------------------------------- 1 | package gorm 2 | 3 | // Model is mocked from https://github.com/jinzhu/gorm/blob/5c235b72a414e448d1f441aba24a47fd6eb976f4/model.go#L9 4 | type Model struct{} 5 | 6 | // DB is mocked from https://github.com/jinzhu/gorm/blob/5c235b72a414e448d1f441aba24a47fd6eb976f4/main.go#L15 7 | type DB struct{} 8 | 9 | // Open is mocked from 
https://github.com/jinzhu/gorm/blob/5c235b72a414e448d1f441aba24a47fd6eb976f4/main.go#L58 10 | func Open(dialect string, args ...interface{}) (db *DB, err error) { 11 | return nil, nil 12 | } 13 | 14 | // Where is mocked from https://github.com/jinzhu/gorm/blob/5c235b72a414e448d1f441aba24a47fd6eb976f4/main.go#L237 15 | func (s *DB) Where(query interface{}, args ...interface{}) *DB { 16 | return nil 17 | } 18 | 19 | // Or is mocked from https://github.com/jinzhu/gorm/blob/5c235b72a414e448d1f441aba24a47fd6eb976f4/main.go#L242 20 | func (s *DB) Or(query interface{}, args ...interface{}) *DB { 21 | return nil 22 | } 23 | 24 | // Not is mocked from https://github.com/jinzhu/gorm/blob/5c235b72a414e448d1f441aba24a47fd6eb976f4/main.go#L247 25 | func (s *DB) Not(query interface{}, args ...interface{}) *DB { 26 | return nil 27 | } 28 | 29 | // Find is mocked from https://github.com/jinzhu/gorm/blob/5c235b72a414e448d1f441aba24a47fd6eb976f4/main.go#L356 30 | func (s *DB) Find(out interface{}, where ...interface{}) *DB { 31 | return nil 32 | } 33 | 34 | // Take is mocked from https://github.com/jinzhu/gorm/blob/5c235b72a414e448d1f441aba24a47fd6eb976f4/main.go#L341 35 | func (s *DB) Take(out interface{}, where ...interface{}) *DB { 36 | return nil 37 | } 38 | 39 | // First is mocked from https://github.com/jinzhu/gorm/blob/5c235b72a414e448d1f441aba24a47fd6eb976f4/main.go#L332 40 | func (s *DB) First(out interface{}, where ...interface{}) *DB { 41 | return nil 42 | } 43 | 44 | // Last is mocked from https://github.com/jinzhu/gorm/blob/5c235b72a414e448d1f441aba24a47fd6eb976f4/main.go#L348 45 | func (s *DB) Last(out interface{}, where ...interface{}) *DB { 46 | return nil 47 | } 48 | 49 | // Delete is mocked from https://github.com/jinzhu/gorm/blob/5c235b72a414e448d1f441aba24a47fd6eb976f4/main.go#L491 50 | func (s *DB) Delete(value interface{}, where ...interface{}) *DB { 51 | return nil 52 | } 53 | --------------------------------------------------------------------------------
/sql/injection/testdata/src/github.com/lib/pq/main.go: -------------------------------------------------------------------------------- 1 | package pq 2 | 3 | // QuoteIdentifier is a fake implementation to make testing possible 4 | func QuoteIdentifier(s string) string { 5 | return "" 6 | } 7 | -------------------------------------------------------------------------------- /sql/injection/testdata/src/h/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "net/http" 7 | 8 | "github.com/jinzhu/gorm" 9 | ) 10 | 11 | type User struct { 12 | gorm.Model `json:"model"` 13 | Name string `json:"name"` 14 | Email string `json:"email"` 15 | } 16 | 17 | func handle(db *gorm.DB, u string) { 18 | q := fmt.Sprintf("url='%s'", u) 19 | 20 | var users []User 21 | db.Where(q).Find(&users) // want "potential sql injection" 22 | fmt.Println() 23 | } 24 | 25 | func business(db *gorm.DB, q *string) error { 26 | handle(db, *q) 27 | return nil 28 | } 29 | 30 | func realMain() { 31 | db, _ := gorm.Open("sqlite3", ":memory:") 32 | 33 | mux := http.NewServeMux() 34 | 35 | mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 36 | var input map[string]any 37 | json.NewDecoder(r.Body).Decode(&input) 38 | 39 | func() { 40 | userValue := fmt.Sprintf("%s", input["query"]) 41 | business(db, func() *string { 42 | return &userValue 43 | }()) 44 | }() 45 | }) 46 | } 47 | 48 | func main() { 49 | realMain() 50 | } 51 | -------------------------------------------------------------------------------- /sql/injection/testdata/src/i/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "database/sql" 5 | "net/http" 6 | ) 7 | 8 | func business7(db *sql.DB, q string) error { 9 | _, err := db.Query(q) // want "potential sql injection" 10 | if err != nil { 11 | return err 12 | } 13 | return nil 14 | } 15 | 16 | func 
realMain() { 17 | db, err := sql.Open("sqlite3", ":memory:") 18 | if err != nil { 19 | panic(err) 20 | } 21 | mux := http.NewServeMux() 22 | 23 | mux.HandleFunc("/boo", func(w http.ResponseWriter, r *http.Request) { 24 | pass, _ := r.URL.User.Password() 25 | business7(db, pass) 26 | }) 27 | } 28 | 29 | func main() { 30 | realMain() 31 | } 32 | -------------------------------------------------------------------------------- /walk_ssa.go: -------------------------------------------------------------------------------- 1 | package taint 2 | 3 | import ( 4 | "fmt" 5 | 6 | "golang.org/x/tools/go/ssa" 7 | ) 8 | // ErrStopWalk is a sentinel error a visitor can return to stop WalkSSA early; it is passed back to WalkSSA's caller. 9 | var ErrStopWalk = fmt.Errorf("taint: stop walk") 10 | 11 | // WalkSSA walks the SSA IR recursively with a visitor function that 12 | // can be used to inspect each node in the graph. The visitor function 13 | // should return an error (such as ErrStopWalk) to stop the walk; that error is returned by WalkSSA. 14 | func WalkSSA(v ssa.Value, visit func(v ssa.Value) error) error { 15 | visited := make(valueSet) 16 | 17 | return walkSSA(v, visit, visited) 18 | } 19 | // walkSSA is the recursive worker for WalkSSA; a value already in visited is not visited again. 20 | func walkSSA(v ssa.Value, visit func(v ssa.Value) error, visited valueSet) error { 21 | if visited == nil { 22 | visited = make(valueSet) 23 | } 24 | // Operands(nil) may yield pointers to nil (absent optional) values; skip them, since v.Referrers() below would panic on nil. 25 | if v == nil || visited.includes(v) { 26 | return nil 27 | } 28 | 29 | visited.add(v) 30 | 31 | // fmt.Printf("walk SSA: %s: %[1]T\n", v) 32 | 33 | if err := visit(v); err != nil { 34 | return err 35 | } 36 | 37 | switch v := v.(type) { 38 | case *ssa.Call: 39 | // Check the operands of the call instruction. 40 | for _, opr := range v.Operands(nil) { 41 | if err := walkSSA(*opr, visit, visited); err != nil { 42 | return err 43 | } 44 | } 45 | 46 | // Check the arguments of the call instruction. 47 | for _, arg := range v.Common().Args { 48 | if err := walkSSA(arg, visit, visited); err != nil { 49 | return err 50 | } 51 | } 52 | 53 | // Check the function being called.
54 | if err := walkSSA(v.Call.Value, visit, visited); err != nil { 55 | return err 56 | } 57 | 58 | // For invoke-mode calls, Value is the interface receiver (the same value walked above). 59 | if v.Common().IsInvoke() { 60 | if err := walkSSA(v.Common().Value, visit, visited); err != nil { 61 | return err 62 | } 63 | } 64 | 65 | // Walk Value once more; Common() returns &v.Call, so the visited check makes this a no-op. 66 | if err := walkSSA(v.Common().Value, visit, visited); err != nil { 67 | return err 68 | } 69 | case *ssa.ChangeInterface: 70 | if err := walkSSA(v.X, visit, visited); err != nil { 71 | return err 72 | } 73 | case *ssa.Convert: 74 | if err := walkSSA(v.X, visit, visited); err != nil { 75 | return err 76 | } 77 | case *ssa.MakeInterface: 78 | if err := walkSSA(v.X, visit, visited); err != nil { 79 | return err 80 | } 81 | case *ssa.Phi: 82 | for _, edge := range v.Edges { 83 | if err := walkSSA(edge, visit, visited); err != nil { 84 | return err 85 | } 86 | } 87 | case *ssa.UnOp: 88 | if err := walkSSA(v.X, visit, visited); err != nil { 89 | return err 90 | } 91 | case *ssa.Function: 92 | for _, block := range v.Blocks { 93 | for _, instr := range block.Instrs { 94 | for _, opr := range instr.Operands(nil) { 95 | if err := walkSSA(*opr, visit, visited); err != nil { 96 | return err 97 | } 98 | } 99 | } 100 | } 101 | default: 102 | // fmt.Printf("? walk SSA %s: %[1]T\n", v) 103 | } 104 | 105 | refs := v.Referrers() 106 | if refs == nil { 107 | return nil 108 | } 109 | 110 | for _, instr := range *refs { 111 | switch instr := instr.(type) { 112 | case *ssa.Store: 113 | // Store instructions need to be checked for both the value being stored, 114 | // and the address being stored to. 115 | if err := walkSSA(instr.Val, visit, visited); err != nil { 116 | return err 117 | } 118 | 119 | if err := walkSSA(instr.Addr, visit, visited); err != nil { 120 | return err 121 | } 122 | case *ssa.Call: 123 | // Check the operands of the call instruction.
124 | for _, opr := range instr.Operands(nil) { 125 | if err := walkSSA(*opr, visit, visited); err != nil { 126 | return err 127 | } 128 | } 129 | 130 | // Check the arguments of the call instruction. 131 | for _, arg := range instr.Common().Args { 132 | if err := walkSSA(arg, visit, visited); err != nil { 133 | return err 134 | } 135 | } 136 | 137 | // Check the function being called. 138 | if err := walkSSA(instr.Call.Value, visit, visited); err != nil { 139 | return err 140 | } 141 | 142 | // For invoke-mode calls, Value is the interface receiver (the same value walked above). 143 | if instr.Common().IsInvoke() { 144 | if err := walkSSA(instr.Common().Value, visit, visited); err != nil { 145 | return err 146 | } 147 | } 148 | 149 | // Walk Value once more; Common() returns &instr.Call, so the visited check makes this a no-op. 150 | if err := walkSSA(instr.Common().Value, visit, visited); err != nil { 151 | return err 152 | } 153 | default: 154 | // fmt.Printf("? check SSA instr %s: %[1]T\n", instr) 155 | continue 156 | } 157 | } 158 | 159 | return nil 160 | } 161 | -------------------------------------------------------------------------------- /xss/testdata/src/a/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | func main() { 8 | http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 9 | w.Write([]byte(r.URL.Query().Get("input"))) // want "potential XSS" 10 | }) 11 | 12 | http.ListenAndServe(":8080", nil) 13 | } 14 | -------------------------------------------------------------------------------- /xss/testdata/src/b/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "net/http" 5 | ) 6 | 7 | func mirror(w http.ResponseWriter, r *http.Request) { 8 | input := r.URL.Query().Get("input") 9 | 10 | b := []byte(input) 11 | 12 | w.Write(b) // want "potential XSS" 13 | } 14 | 15 | func main() { 16 | http.HandleFunc("/", mirror) 17 | 18 | http.ListenAndServe(":8080",
nil) 19 | } 20 | -------------------------------------------------------------------------------- /xss/testdata/src/c/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bufio" 5 | "io" 6 | "net/http" 7 | ) 8 | 9 | func buffer(r io.Reader) io.Reader { 10 | return bufio.NewReader(r) 11 | } 12 | 13 | func mirror(w http.ResponseWriter, r *http.Request) { 14 | _, err := io.Copy(w, buffer(r.Body)) // want "potential XSS" 15 | if err != nil { 16 | panic(err) 17 | } 18 | } 19 | 20 | func mirror2(w http.ResponseWriter, r *http.Request) { 21 | _, err := io.WriteString(w, r.URL.Query().Get("q")) // want "potential XSS" 22 | if err != nil { 23 | panic(err) 24 | } 25 | } 26 | 27 | func mirror3(w http.ResponseWriter, r *http.Request) { 28 | _, err := w.Write([]byte(r.URL.Query().Get("q"))) // want "potential XSS" 29 | if err != nil { 30 | panic(err) 31 | } 32 | } 33 | 34 | func mirror4(w http.ResponseWriter, r *http.Request) { 35 | b, err := io.ReadAll(r.Body) 36 | if err != nil { 37 | panic(err) 38 | } 39 | 40 | _, err = w.Write(b) // want "potential XSS" 41 | if err != nil { 42 | panic(err) 43 | } 44 | } 45 | 46 | func main() { 47 | http.HandleFunc("/1", mirror) 48 | http.HandleFunc("/2", mirror2) 49 | http.HandleFunc("/3", mirror3) 50 | http.HandleFunc("/4", mirror4) 51 | 52 | http.ListenAndServe(":8080", nil) 53 | } 54 | -------------------------------------------------------------------------------- /xss/testdata/src/d/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "html" 5 | "io" 6 | "net/http" 7 | ) 8 | 9 | func mirrorSafe(w http.ResponseWriter, r *http.Request) { 10 | b, err := io.ReadAll(r.Body) 11 | if err != nil { 12 | panic(err) 13 | } 14 | 15 | str := html.EscapeString(string(b)) 16 | 17 | _, err = w.Write([]byte(str)) // safe 18 | if err != nil { 19 | panic(err) 20 | } 21 | } 22 | 23 | func main() { 24 | 
http.HandleFunc("/mirror-safe", mirrorSafe) 25 | 26 | http.ListenAndServe(":8080", nil) 27 | } 28 | -------------------------------------------------------------------------------- /xss/testdata/src/e/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "io" 5 | "net/http" 6 | ) 7 | 8 | // this will panic if run, because the given *http.Request is not an io.Reader 9 | // but it's fine for testing, because we don't actually run the code. 10 | func echo(w io.Writer, r any) { 11 | ior := r.(io.Reader) 12 | 13 | b, err := io.ReadAll(ior) 14 | if err != nil { 15 | panic(err) 16 | } 17 | 18 | w.Write(b) 19 | } 20 | 21 | func handler(w http.ResponseWriter, r *http.Request) { 22 | echo(w, r) // want "potential XSS" 23 | } 24 | 25 | func main() { 26 | http.HandleFunc("/mirror-safe", handler) 27 | 28 | http.ListenAndServe(":8080", nil) 29 | } 30 | -------------------------------------------------------------------------------- /xss/testdata/src/f/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "io" 5 | "net/http" 6 | ) 7 | 8 | func echo(w io.Writer, r any) { 9 | ior := r.(io.Reader) 10 | 11 | b, err := io.ReadAll(ior) 12 | if err != nil { 13 | panic(err) 14 | } 15 | 16 | w.Write(b) 17 | } 18 | 19 | func handler(w http.ResponseWriter, r *http.Request) { 20 | echo(w, r.Body) // want "potential XSS" 21 | } 22 | 23 | func main() { 24 | http.HandleFunc("/mirror-safe", handler) 25 | 26 | http.ListenAndServe(":8080", nil) 27 | } 28 | -------------------------------------------------------------------------------- /xss/testdata/src/g/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bufio" 5 | "html" 6 | "io" 7 | "net/http" 8 | ) 9 | 10 | func echoSafe(w io.Writer, r any) { 11 | ior := r.(io.Reader) 12 | 13 | b, err := io.ReadAll(ior) 14 | if err != nil 
{ 15 | panic(err) 16 | } 17 | 18 | es := html.EscapeString(string(b)) 19 | 20 | w.Write([]byte(es)) 21 | } 22 | 23 | func echoUnsafe(w io.Writer, r any) { 24 | ior := r.(io.Reader) 25 | 26 | b, err := io.ReadAll(ior) 27 | if err != nil { 28 | panic(err) 29 | } 30 | 31 | w.Write(b) 32 | } 33 | 34 | func handler(w http.ResponseWriter, r *http.Request) { 35 | b := bufio.NewWriterSize(w, 4096) 36 | defer b.Flush() 37 | 38 | switch r.URL.Path { 39 | case "/mirror-safe": 40 | echoSafe(w, r.Body) 41 | case "/mirror-unsafe": 42 | echoUnsafe(w, r.Body) // want "potential XSS" 43 | } 44 | } 45 | 46 | func main() { 47 | http.HandleFunc("/", handler) 48 | 49 | http.ListenAndServe(":8080", nil) 50 | } 51 | -------------------------------------------------------------------------------- /xss/xss.go: -------------------------------------------------------------------------------- 1 | package xss 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | 7 | "github.com/picatz/taint" 8 | "github.com/picatz/taint/callgraphutil" 9 | 10 | "golang.org/x/tools/go/analysis" 11 | "golang.org/x/tools/go/analysis/passes/buildssa" 12 | "golang.org/x/tools/go/ssa" 13 | ) 14 | 15 | var userControlledValues = taint.NewSources( 16 | "*net/http.Request", 17 | ) 18 | 19 | var injectableFunctions = taint.NewSinks( 20 | // Note: at this time, they *must* be a function or method. 21 | "(net/http.ResponseWriter).Write", 22 | "(net/http.ResponseWriter).WriteHeader", 23 | ) 24 | 25 | // Analyzer finds potential XSS issues. 26 | var Analyzer = &analysis.Analyzer{ 27 | Name: "xss", 28 | Doc: "finds potential XSS issues", 29 | Run: run, 30 | Requires: []*analysis.Analyzer{buildssa.Analyzer}, 31 | } 32 | 33 | // imports returns true if the package imports any of the given packages. 
34 | func imports(pass *analysis.Pass, pkgs ...string) bool { 35 | var imported bool 36 | for _, imp := range pass.Pkg.Imports() { 37 | for _, pkg := range pkgs { 38 | if strings.HasSuffix(imp.Path(), pkg) { 39 | imported = true 40 | break 41 | } 42 | } 43 | if imported { 44 | break 45 | } 46 | } 47 | return imported 48 | } 49 | 50 | func run(pass *analysis.Pass) (interface{}, error) { 51 | // Require that the net/http package is imported in the 52 | // program being analyzed before running the analysis. 53 | // 54 | // This prevents wasting time analyzing programs that don't use net/http. 55 | if !imports(pass, "net/http") { 56 | return nil, nil 57 | } 58 | 59 | // Get the built SSA IR. 60 | buildSSA := pass.ResultOf[buildssa.Analyzer].(*buildssa.SSA) 61 | 62 | // Identify the main function from the package's SSA IR; packages without one are skipped. 63 | mainFn := buildSSA.Pkg.Func("main") 64 | if mainFn == nil { 65 | return nil, nil 66 | } 67 | 68 | // Construct a callgraph, using the main function as the root, 69 | // constructed of all other functions. This returns a callgraph 70 | // we can use to identify directed paths to the XSS sinks. 71 | cg, err := callgraphutil.NewGraph(mainFn, buildSSA.SrcFuncs...) 72 | if err != nil { 73 | return nil, fmt.Errorf("failed to create new callgraph: %w", err) 74 | } 75 | 76 | // fmt.Println(cg) 77 | 78 | // Run taint check for user controlled values (sources) ending 79 | // up in http.ResponseWriter methods (sinks). 80 | results := taint.Check(cg, userControlledValues, injectableFunctions) 81 | 82 | for _, result := range results { 83 | // Heuristic: treat the result as escaped if a call to html.EscapeString is 84 | // reachable (via taint.WalkSSA) from an argument of any call site on the path.
85 | var escaped bool 86 | for _, edge := range result.Path { 87 | for _, arg := range edge.Site.Common().Args { 88 | taint.WalkSSA(arg, func(v ssa.Value) error { 89 | call, ok := v.(*ssa.Call) 90 | if !ok { 91 | return nil 92 | } 93 | if call.Call.Value.String() == "html.EscapeString" { 94 | escaped = true 95 | return taint.ErrStopWalk 96 | } 97 | return nil 98 | }) 99 | } 100 | if escaped { 101 | break 102 | } 103 | } 104 | 105 | if !escaped { 106 | pass.Reportf(result.SinkValue.Pos(), "potential XSS") 107 | } 108 | } 109 | 110 | return nil, nil 111 | } 112 | -------------------------------------------------------------------------------- /xss/xss_test.go: -------------------------------------------------------------------------------- 1 | package xss 2 | 3 | import ( 4 | "testing" 5 | 6 | "golang.org/x/tools/go/analysis/analysistest" 7 | ) 8 | 9 | var testdata = analysistest.TestData() 10 | 11 | func TestA(t *testing.T) { 12 | analysistest.Run(t, testdata, Analyzer, "a") 13 | } 14 | 15 | func TestB(t *testing.T) { 16 | analysistest.Run(t, testdata, Analyzer, "b") 17 | } 18 | 19 | func TestC(t *testing.T) { 20 | analysistest.Run(t, testdata, Analyzer, "c") 21 | } 22 | 23 | func TestD(t *testing.T) { 24 | analysistest.Run(t, testdata, Analyzer, "d") 25 | } 26 | 27 | func TestE(t *testing.T) { 28 | analysistest.Run(t, testdata, Analyzer, "e") 29 | } 30 | 31 | func TestF(t *testing.T) { 32 | analysistest.Run(t, testdata, Analyzer, "f") 33 | } 34 | 35 | func TestG(t *testing.T) { 36 | analysistest.Run(t, testdata, Analyzer, "g") 37 | } 38 | --------------------------------------------------------------------------------