├── .gitignore ├── go.mod ├── CHANGELOG.md ├── Makefile ├── line_count.go ├── parse_bias_test.go ├── split_files.go ├── go.sum ├── junit.go ├── circleci.go ├── junit_update.go ├── README.md └── main.go /.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | vendor 3 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/leonid-shevtsov/split_tests 2 | 3 | go 1.14 4 | 5 | require ( 6 | github.com/bmatcuk/doublestar v1.3.0 7 | github.com/stretchr/testify v1.10.0 8 | ) 9 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # 0.3.0 2 | 3 | - add `exclude-glob` flag to remove certain test files 4 | 5 | # 0.2.0 6 | 7 | - fix inconsistent splitting for files with the same or missing time, or same line count 8 | - allow globbing for JUnit report files 9 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | release: release-linux release-osx 2 | 3 | release-linux: 4 | $(call release-base,linux) 5 | 6 | release-osx: 7 | $(call release-base,darwin) 8 | 9 | 10 | release-base = \ 11 | mkdir -p build; \ 12 | GOOS=$(1) GOARCH=amd64 go build -ldflags="-s -w" -o build/split_tests; \ 13 | gzip -S .$(1).gz build/split_tests 14 | 15 | clean: 16 | rm -rf build 17 | -------------------------------------------------------------------------------- /line_count.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "os" 7 | ) 8 | 9 | func estimateFileTimesByLineCount(currentFileSet map[string]bool, fileTimes map[string]float64) { 10 | for fileName := range 
currentFileSet { 11 | file, err := os.Open(fileName) 12 | if err != nil { 13 | printMsg("failed to count lines in file %s: %v\n", fileName, err) 14 | continue 15 | } 16 | lineCount, err := lineCounter(file) 17 | file.Close() // close right away: a deferred Close would keep every file open until the whole loop finishes 18 | if err != nil { 19 | printMsg("failed to count lines in file %s: %v\n", fileName, err) 20 | continue 21 | } 22 | fileTimes[fileName] = float64(lineCount) 23 | } 24 | } 25 | 26 | // lineCounter counts the '\n' bytes read from r (a last line without a trailing newline is not counted). Credit to http://stackoverflow.com/a/24563853/6678 27 | func lineCounter(r io.Reader) (int, error) { 28 | buf := make([]byte, 32*1024) 29 | count := 0 30 | lineSep := []byte{'\n'} 31 | 32 | for { 33 | c, err := r.Read(buf) 34 | count += bytes.Count(buf[:c], lineSep) 35 | 36 | switch { 37 | case err == io.EOF: 38 | return count, nil 39 | 40 | case err != nil: 41 | return count, err 42 | } 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /parse_bias_test.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | "github.com/stretchr/testify/require" 8 | ) 9 | 10 | func TestParseBias(t *testing.T) { 11 | examples := []struct { 12 | name string 13 | input string 14 | output []float64 15 | err string 16 | }{ 17 | { 18 | "happy case", 19 | "0=1,1=2.5", 20 | []float64{1, 2.5}, 21 | "", 22 | }, 23 | { 24 | "bad format", 25 | "bad", 26 | nil, 27 | "not a valid bias declaration: bad", 28 | }, 29 | { 30 | "bad pair", 31 | "0=1, bad", 32 | nil, 33 | "not a valid bias declaration: bad", 34 | }, 35 | { 36 | "bad index", 37 | "bad=0.5", 38 | nil, 39 | "failed to parse bias index: strconv.Atoi: parsing \"bad\": invalid syntax", 40 | }, 41 | { 42 | "bad time", 43 | "0=bad", 44 | nil, 45 | "failed to parse bias time: strconv.ParseFloat: parsing \"bad\": invalid syntax", 46 | }, 47 | { 48 | "index out of range", 49 | "3=0.5", 50 | nil, 51 | "bias index is not within the split number: 3", 52 | }, 53 | }
54 | 55 | for _, example := range examples { 56 | t.Run(example.name, func(t *testing.T) { 57 | output, err := parseBias(example.input, 2) 58 | if example.err != "" { 59 | require.EqualError(t, err, example.err) 60 | } else { 61 | require.NoError(t, err) 62 | assert.Equal(t, example.output, output) 63 | } 64 | }) 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /split_files.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "sort" 4 | // splitFiles greedily assigns files to splitTotal buckets in descending time order, each file going to the bucket with the lowest time+bias. It returns each bucket's files and their summed file times (bias not included). 5 | func splitFiles(biases []float64, fileTimesMap map[string]float64, splitTotal int) ([][]string, []float64) { 6 | buckets := make([][]string, splitTotal) 7 | bucketTimes := make([]float64, splitTotal) 8 | 9 | // Build a sorted list of files. Start at length 0 (capacity n): make(fileTimesList, n) plus append would add n zero-value items, i.e. empty file names in the output. 10 | sortedFiles := make(fileTimesList, 0, len(fileTimesMap)) 11 | for file, time := range fileTimesMap { 12 | sortedFiles = append(sortedFiles, fileTimesListItem{file, time}) 13 | } 14 | sort.Sort(sortedFiles) 15 | 16 | for _, file := range sortedFiles { 17 | // find bucket with min weight 18 | minBucket := 0 19 | for bucket := 1; bucket < splitTotal; bucket++ { 20 | if bucketTimes[bucket]+biases[bucket] < bucketTimes[minBucket]+biases[minBucket] { 21 | minBucket = bucket 22 | } 23 | } 24 | // add file to bucket 25 | buckets[minBucket] = append(buckets[minBucket], file.name) 26 | bucketTimes[minBucket] += file.time 27 | } 28 | 29 | return buckets, bucketTimes 30 | } 31 | 32 | type fileTimesListItem struct { 33 | name string 34 | time float64 35 | } 36 | 37 | type fileTimesList []fileTimesListItem 38 | 39 | func (l fileTimesList) Len() int { return len(l) } 40 | 41 | // Sorts by time descending, then by name ascending 42 | // Sort by name is required for deterministic order across machines 43 | func (l fileTimesList) Less(i, j int) bool { 44 | return l[i].time > l[j].time || 45 | (l[i].time == l[j].time && l[i].name < l[j].name) 46 | } 47 | 48 | func (l fileTimesList)
Swap(i, j int) { 49 | temp := l[i] 50 | l[i] = l[j] 51 | l[j] = temp 52 | } 53 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/bmatcuk/doublestar v1.3.0 h1:1jLE2y0VpSrOn/QR9G4f2RmrCtkM3AuATcWradjHUvM= 2 | github.com/bmatcuk/doublestar v1.3.0/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE= 3 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 5 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 6 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 7 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 8 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 9 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 10 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 11 | github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= 12 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 13 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 14 | github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= 15 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= 16 | github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 17 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 18 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 19 | gopkg.in/yaml.v3 
v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 20 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 21 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 22 | -------------------------------------------------------------------------------- /junit.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/xml" 5 | "io" 6 | "os" 7 | "path" 8 | 9 | "github.com/bmatcuk/doublestar" 10 | ) 11 | 12 | type junitXML struct { 13 | TestCases []struct { 14 | File string `xml:"file,attr"` 15 | Time float64 `xml:"time,attr"` 16 | } `xml:"testcase"` 17 | } 18 | 19 | func loadJUnitXML(reader io.Reader) *junitXML { 20 | var junitXML junitXML 21 | 22 | decoder := xml.NewDecoder(reader) 23 | err := decoder.Decode(&junitXML) 24 | if err != nil { 25 | fatalMsg("failed to parse junit xml: %v\n", err) 26 | } 27 | 28 | return &junitXML 29 | } 30 | 31 | func addFileTimesFromIOReader(fileTimes map[string]float64, reader io.Reader) { 32 | junitXML := loadJUnitXML(reader) 33 | for _, testCase := range junitXML.TestCases { 34 | filePath := path.Clean(testCase.File) 35 | fileTimes[filePath] += testCase.Time 36 | } 37 | } 38 | 39 | // loadJUnitTimingsFromGlob loads test timings from JUnit XML files matching a glob pattern 40 | func loadJUnitTimingsFromGlob(globPattern string) map[string]float64 { 41 | fileTimes := make(map[string]float64) 42 | 43 | if globPattern == "" { 44 | return fileTimes 45 | } 46 | 47 | filenames, err := doublestar.Glob(globPattern) 48 | if err != nil { 49 | fatalMsg("failed to match jUnit filename pattern: %v\n", err) 50 | } 51 | 52 | if len(filenames) == 0 { 53 | printMsg("warning: no files matched pattern %s\n", globPattern) 54 | return fileTimes 55 | } 56 | 57 | for _, junitFilename := range filenames { 58 | file, err := os.Open(junitFilename) 59 | if err != nil { 60 | fatalMsg("failed to open junit
xml: %v\n", err) 61 | } 62 | printMsg("loaded test times from %s\n", junitFilename) 63 | addFileTimesFromIOReader(fileTimes, file) 64 | file.Close() 65 | } 66 | 67 | return fileTimes 68 | } 69 | 70 | func getFileTimesFromJUnitXML(fileTimes map[string]float64) { 71 | if junitXMLPath != "" { 72 | loadedTimes := loadJUnitTimingsFromGlob(junitXMLPath) 73 | for file, time := range loadedTimes { 74 | fileTimes[file] += time 75 | } 76 | } else { 77 | printMsg("using test times from JUnit report at stdin\n") 78 | addFileTimesFromIOReader(fileTimes, os.Stdin) 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /circleci.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "net/http" 7 | "net/url" 8 | "path" 9 | "time" 10 | ) 11 | 12 | // circleCIHTTPClient is shared by all CircleCI API calls; its timeout keeps an unresponsive 13 | // API from hanging the CI job forever (a zero-value http.Client never times out). 14 | var circleCIHTTPClient = &http.Client{Timeout: 30 * time.Second} 15 | 16 | // getCircleAPIJSON GETs apiURL and decodes the JSON response body into destination. 17 | // Every failure (bad request, transport error, non-200 status, malformed JSON) is fatal. 18 | func getCircleAPIJSON(apiURL string, destination interface{}) { 19 | req, err := http.NewRequest("GET", apiURL, nil) 20 | if err != nil { 21 | fatalMsg("error building CircleCI API request for %v: %v\n", apiURL, err) 22 | } 23 | req.Header.Add("Accept", "application/json") 24 | resp, err := circleCIHTTPClient.Do(req) 25 | if err != nil { 26 | fatalMsg("error calling CircleCI API at %v: %v\n", apiURL, err) 27 | } 28 | defer resp.Body.Close() 29 | // An error response is not the expected JSON shape; report the status instead of a confusing decode error 30 | if resp.StatusCode != http.StatusOK { 31 | fatalMsg("error calling CircleCI API at %v: unexpected status %s\n", apiURL, resp.Status) 32 | } 33 | if err := json.NewDecoder(resp.Body).Decode(destination); err != nil { 34 | fatalMsg("error parsing CircleCI JSON at %v: %v\n", apiURL, err) 35 | } 36 | } 37 | 38 | type circleCIBranchList []struct { 39 | BuildNum int `json:"build_num"` 40 | } 41 | 42 | type circleCITestResults struct { 43 | Tests []struct { 44 | File string `json:"file"` 45 | RunTime float64 `json:"run_time"` 46 | } `json:"tests"` 47 | } 48 | 49 | // circleCIAPIURL returns the v1.1 API base URL for the configured project prefix. 50 | func circleCIAPIURL() string { 51 | return fmt.Sprintf("https://circleci.com/api/v1.1/project/%s", circleCIProjectPrefix) 52 | } 53 | 54 | // getCircleCIBranchBuilds lists successful builds of branchName. The branch is path-escaped so 55 | // names containing "/" (e.g. feature/x) stay a single URL path segment. 56 | func getCircleCIBranchBuilds(branchName string) circleCIBranchList { 57 | buildsURL := fmt.Sprintf("%s/tree/%s?filter=successful&circle-token=%s", circleCIAPIURL(), url.PathEscape(branchName), url.QueryEscape(circleCIAPIKey)) 58 | var branchList
circleCIBranchList 59 | getCircleAPIJSON(buildsURL, &branchList) 60 | return branchList 61 | } 62 | 63 | // getCircleCITestResults fetches the per-test results recorded for build buildNum. 64 | func getCircleCITestResults(buildNum int) circleCITestResults { 65 | testResultsURL := fmt.Sprintf("%s/%d/tests?circle-token=%s", circleCIAPIURL(), buildNum, url.QueryEscape(circleCIAPIKey)) 66 | var testResults circleCITestResults 67 | getCircleAPIJSON(testResultsURL, &testResults) 68 | return testResults 69 | } 70 | 71 | // getFileTimesFromCircleCI adds per-file run times from builds[0] of the current branch's successful 72 | // builds, falling back to "master". NOTE(review): assumes the API lists the newest build first. 73 | func getFileTimesFromCircleCI(fileTimes map[string]float64) { 74 | builds := getCircleCIBranchBuilds(circleCIBranchName) 75 | if len(builds) == 0 { 76 | builds = getCircleCIBranchBuilds("master") 77 | } 78 | if len(builds) > 0 { 79 | buildNum := builds[0].BuildNum 80 | printMsg("using test timings from CircleCI build %d\n", buildNum) 81 | 82 | testResults := getCircleCITestResults(buildNum) 83 | 84 | for _, test := range testResults.Tests { 85 | filePath := path.Clean(test.File) 86 | fileTimes[filePath] += test.RunTime 87 | } 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /junit_update.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/xml" 5 | "os" 6 | "sort" 7 | "strconv" 8 | ) 9 | 10 | var junitUpdateOldGlob string 11 | var junitUpdateNewGlob string 12 | var junitUpdateOutPath string 13 | 14 | const slidingWindowOldWeight = 0.9 15 | 16 | // applySlidingWindow applies an exponential moving average to smooth out timing fluctuations 17 | // Uses exponential moving average: oldWeight * old + (1 - oldWeight) * new 18 | // This gives more weight to historical data while still incorporating recent changes 19 | func applySlidingWindow(oldTime, newTime float64) float64 { 20 | return slidingWindowOldWeight*oldTime + (1-slidingWindowOldWeight)*newTime 21 | } 22 | 23 | // updateJUnitTimings merges old and new JUnit timings using a sliding window algorithm 24 | func updateJUnitTimings() { 25 | if junitUpdateOldGlob == "" || junitUpdateNewGlob
== "" || junitUpdateOutPath == "" { 26 | fatalMsg("junit-update requires -junit-update, -junit-new, and -junit-out flags\n") 27 | } 28 | 29 | // Load old timings 30 | oldTimings := loadJUnitTimingsFromGlob(junitUpdateOldGlob) 31 | printMsg("loaded %d test files from old timings\n", len(oldTimings)) 32 | 33 | // Load new timings 34 | newTimings := loadJUnitTimingsFromGlob(junitUpdateNewGlob) 35 | printMsg("loaded %d test files from new timings\n", len(newTimings)) 36 | 37 | // Merge timings using sliding window algorithm 38 | mergedTimings := make(map[string]float64) 39 | 40 | // Process all tests from new timings 41 | for file, newTime := range newTimings { 42 | oldTime, exists := oldTimings[file] 43 | if exists { 44 | // Test exists in both: apply sliding window 45 | mergedTimings[file] = applySlidingWindow(oldTime, newTime) 46 | } else { 47 | // Test not in old: use new timing 48 | mergedTimings[file] = newTime 49 | } 50 | } 51 | 52 | // Tests not in new are automatically excluded (not added to mergedTimings) 53 | 54 | printMsg("merged %d test files (removed tests not in new, used sliding window for existing tests)\n", len(mergedTimings)) 55 | 56 | // Write output JUnit XML 57 | writeJUnitXML(mergedTimings, junitUpdateOutPath) 58 | printMsg("wrote updated timings to %s\n", junitUpdateOutPath) 59 | } 60 | 61 | // writeJUnitXML writes one testcase per file to outputPath, sorted by file name; any I/O error is fatal 62 | func writeJUnitXML(timings map[string]float64, outputPath string) { 63 | // JUnit XML structure for writing (with testsuite root element) 64 | type testCase struct { 65 | File string `xml:"file,attr"` 66 | Time string `xml:"time,attr"` 67 | } 68 | 69 | type testSuite struct { 70 | XMLName xml.Name `xml:"testsuite"` 71 | Name string `xml:"name,attr"` 72 | Tests int `xml:"tests,attr"` 73 | TestCases []testCase `xml:"testcase"` 74 | } 75 | 76 | // Convert map to slice, then sort by file name: map iteration order is random, so the sort is what makes the output consistent 77 | testCases := make([]testCase, 0, len(timings)) 78 | for file, time := range timings { 79 | // Format as
decimal without scientific notation 80 | timeStr := strconv.FormatFloat(time, 'f', -1, 64) 81 | testCases = append(testCases, testCase{ 82 | File: file, 83 | Time: timeStr, 84 | }) 85 | } 86 | sort.Slice(testCases, func(i, j int) bool { return testCases[i].File < testCases[j].File }) 87 | 88 | suite := testSuite{ 89 | Name: "rspec", 90 | Tests: len(testCases), 91 | TestCases: testCases, 92 | } 93 | 94 | // Create output file 95 | file, err := os.Create(outputPath) 96 | if err != nil { 97 | fatalMsg("failed to create output file: %v\n", err) 98 | } 99 | 100 | // Write XML header 101 | if _, err := file.WriteString(xml.Header); err != nil { 102 | fatalMsg("failed to write JUnit XML: %v\n", err) 103 | } 104 | 105 | // Create encoder 106 | encoder := xml.NewEncoder(file) 107 | encoder.Indent("", " ") 108 | 109 | // Encode XML (Encode flushes the encoder's buffer before returning) 110 | if err := encoder.Encode(suite); err != nil { 111 | fatalMsg("failed to encode JUnit XML: %v\n", err) 112 | } 113 | 114 | if _, err := file.WriteString("\n"); err != nil { 115 | fatalMsg("failed to write JUnit XML: %v\n", err) 116 | } 117 | // Close explicitly rather than via defer so that a failed final write-back is reported 118 | if err := file.Close(); err != nil { 119 | fatalMsg("failed to close output file: %v\n", err) 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # split_tests 2 | 3 | Splits a test suite into groups of equal time, based on previous tests timings. 4 | 5 | This is necessary for running the tests in parallel. As the execution time of test files might vary drastically, you will not get the best split by simply dividing them into even groups. 6 | 7 | ## Compatibility 8 | 9 | This tool was written for Ruby and CircleCI, but it can be used with any file-based test suite on any CI. 10 | Since then, CircleCI has introduced built-in test splitting. Also since then, the tool has been applied on 11 | GitHub Actions, that does not provide native test splitting. 12 | 13 | There is a [split-tests GitHub Action](https://github.com/marketplace/actions/split-tests) using this tool available on the Actions Marketplace. 14 | 15 | It is written in Golang, released as a binary, and has no external dependencies. 16 | 17 | ## Usage 18 | 19 | Download and extract the latest build from the releases page.
20 | 21 | ### Using the CircleCI API 22 | 23 | Get an API key and set `CIRCLECI_API_KEY` in the project config. The branch and the split index/total are read from CircleCI's `CIRCLE_BRANCH`, `CIRCLE_NODE_INDEX`, and `CIRCLE_NODE_TOTAL` environment variables. 24 | 25 | ``` 26 | rspec $(split_tests -circleci-project github/leonid-shevtsov/split_tests) 27 | ``` 28 | 29 | (The tool returns the set of files for the current split, joined by spaces.) 30 | 31 | ### Using a JUnit report 32 | 33 | ``` 34 | rspec $(split_tests -junit -junit-path=report.xml -split-index=$CI_NODE_INDEX -split-total=$CI_NODE_TOTAL) 35 | ``` 36 | 37 | Or, if it's easier to pipe the report file: 38 | 39 | ``` 40 | rspec $(curl http://my.junit.url | split_tests -junit -split-index=$CI_NODE_INDEX -split-total=$CI_NODE_TOTAL) 41 | ``` 42 | 43 | #### Stabilizing test timings 44 | 45 | Sometimes test timings fluctuate, so always relying on the latest timings might not produce the best split. For such cases, `split_tests` has an update mode that blends prior timings with current ones using an exponential moving average (90% prior timing, 10% current timing): 46 | 47 | ``` 48 | split_tests -junit-update=old_glob -junit-new=new_glob -junit-out=out.xml 49 | ``` 50 | 51 | Then use out.xml as the timing report for the next test run. 52 | 53 | Note that updating also cleans up the files: it keeps only one "test case" per file, which is enough for this tool's purpose, and drops files that do not appear in the new reports. 54 | 55 | ### Naive split by line count 56 | 57 | If you don't have test times, it might be reasonable for your project to assume runtime proportional to test length. 58 | 59 | ``` 60 | rspec $(split_tests -line-count) 61 | ``` 62 | 63 | ### Apply bias 64 | 65 | Often a specific split will not just run the test suite, but also a linter or some other quicker checks. In this case you can use the `-bias` argument to balance the split better: 66 | 67 | ``` 68 | # account for 20-second linter run in split 0 69 | split_tests -bias 0=20 -junit ... 70 | ``` 71 | 72 | The effect is that the split algorithm will assume an external delay of 20 seconds for the 0th split, and will reduce its assigned load by 20 seconds (as best it can).
Bias can be negative, too. 73 | 74 | Don't forget to specify the same bias configuration on all runners, not just the ones that have bias. 75 | 76 | This works best when you have real test timings (JUnit or CircleCI mode.) For splits by line count, you can still find the right bias empirically - although splits by line count are never perfectly balanced anyway. 77 | 78 | ### Naive split by file count 79 | 80 | In the absence of prior test times, `split_tests` can still split files into even groups by count. 81 | 82 | ``` 83 | rspec $(split_tests) 84 | ``` 85 | 86 | ## Arguments 87 | 88 | ````plain 89 | $./split_tests -help 90 | 91 | -bias string 92 | Set bias for specific splits (if one split is doing extra work like running a linter). 93 | Format: [split_index]=[bias_in_seconds],[another_index]=[another_bias],... 94 | -circleci-branch string 95 | Current branch for CircleCI (or set CIRCLE_BRANCH) - required to use CircleCI 96 | -circleci-key string 97 | CircleCI API key (or set CIRCLECI_API_KEY environment variable) - required to use CircleCI 98 | -circleci-project string 99 | CircleCI project name (e.g. github/leonid-shevtsov/split_tests) - required to use CircleCI 100 | -exclude-glob string 101 | Glob pattern to exclude test files. Make sure to single-quote. 102 | -glob string 103 | Glob pattern to find test files. Make sure to single-quote to avoid shell expansion. 
(default "spec/**/*_spec.rb") 104 | -help 105 | Show this help text 106 | -junit 107 | Use a JUnit XML report for test times 108 | -junit-path string 109 | Path to a JUnit XML report (leave empty to read from stdin; use glob pattern to load multiple files) 110 | -junit-new string 111 | Glob pattern for new JUnit XML files (for updating timings with sliding window) 112 | -junit-out string 113 | Output path for updated JUnit XML file (for updating timings with sliding window) 114 | -junit-update string 115 | Glob pattern for old JUnit XML files (for updating timings with sliding window) 116 | -line-count 117 | Use line count to estimate test times 118 | -split-index int 119 | This test container's index (or set CIRCLE_NODE_INDEX) (default -1) 120 | -split-total int 121 | Total number of containers (or set CIRCLE_NODE_TOTAL) (default -1) 122 | ```` 123 | 124 | ## Compilation 125 | 126 | This tool is written in Go and uses Go modules. 127 | 128 | - Install Go 129 | - Check out the code 130 | - `make` 131 | 132 | --- 133 | 134 | (c) [Leonid Shevtsov](https://leonid.shevtsov.me) 2017-2020 135 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "os" 7 | "strconv" 8 | "strings" 9 | 10 | "github.com/bmatcuk/doublestar" 11 | ) 12 | 13 | var useCircleCI bool 14 | var useJUnitXML bool 15 | var useLineCount bool 16 | var junitXMLPath string 17 | var testFilePattern = "" 18 | var excludeFilePattern = "" 19 | var circleCIProjectPrefix = "" 20 | var circleCIBranchName string 21 | var splitIndex int 22 | var splitTotal int 23 | var circleCIAPIKey string 24 | var bias string 25 | 26 | func printMsg(msg string, args ...interface{}) { 27 | if len(args) == 0 { 28 | fmt.Fprint(os.Stderr, msg) 29 | } else { 30 | fmt.Fprintf(os.Stderr, msg, args...)
31 | } 32 | } 33 | 34 | func fatalMsg(msg string, args ...interface{}) { 35 | printMsg(msg, args...) 36 | os.Exit(1) 37 | } 38 | 39 | func removeDeletedFiles(fileTimes map[string]float64, currentFileSet map[string]bool) { 40 | for file := range fileTimes { 41 | if !currentFileSet[file] { 42 | delete(fileTimes, file) 43 | } 44 | } 45 | } 46 | 47 | func addNewFiles(fileTimes map[string]float64, currentFileSet map[string]bool) { 48 | averageFileTime := 0.0 49 | if len(fileTimes) > 0 { 50 | for _, time := range fileTimes { 51 | averageFileTime += time 52 | } 53 | averageFileTime /= float64(len(fileTimes)) 54 | } else { 55 | averageFileTime = 1.0 56 | } 57 | 58 | for file := range currentFileSet { 59 | if _, isSet := fileTimes[file]; isSet { 60 | continue 61 | } 62 | if useCircleCI || useJUnitXML { 63 | printMsg("missing file time for %s\n", file) 64 | } 65 | fileTimes[file] = averageFileTime 66 | } 67 | } 68 | 69 | func parseFlags() { 70 | flag.StringVar(&testFilePattern, "glob", "spec/**/*_spec.rb", "Glob pattern to find test files. Make sure to single-quote to avoid shell expansion.") 71 | flag.StringVar(&excludeFilePattern, "exclude-glob", "", "Glob pattern to exclude test files. Make sure to single-quote.") 72 | 73 | flag.IntVar(&splitIndex, "split-index", -1, "This test container's index (or set CIRCLE_NODE_INDEX)") 74 | flag.IntVar(&splitTotal, "split-total", -1, "Total number of containers (or set CIRCLE_NODE_TOTAL)") 75 | 76 | flag.StringVar(&circleCIAPIKey, "circleci-key", "", "CircleCI API key (or set CIRCLECI_API_KEY environment variable) - required to use CircleCI") 77 | flag.StringVar(&circleCIProjectPrefix, "circleci-project", "", "CircleCI project name (e.g. 
github/leonid-shevtsov/split_tests) - required to use CircleCI") 78 | flag.StringVar(&circleCIBranchName, "circleci-branch", "", "Current branch for CircleCI (or set CIRCLE_BRANCH) - required to use CircleCI") 79 | 80 | flag.BoolVar(&useJUnitXML, "junit", false, "Use a JUnit XML report for test times") 81 | flag.StringVar(&junitXMLPath, "junit-path", "", "Path to a JUnit XML report (leave empty to read from stdin; use glob pattern to load multiple files)") 82 | 83 | flag.BoolVar(&useLineCount, "line-count", false, "Use line count to estimate test times") 84 | 85 | flag.StringVar(&junitUpdateOldGlob, "junit-update", "", "Glob pattern for old JUnit XML files (for updating timings with sliding window)") 86 | flag.StringVar(&junitUpdateNewGlob, "junit-new", "", "Glob pattern for new JUnit XML files (for updating timings with sliding window)") 87 | flag.StringVar(&junitUpdateOutPath, "junit-out", "", "Output path for updated JUnit XML file (for updating timings with sliding window)") 88 | 89 | var showHelp bool 90 | flag.BoolVar(&showHelp, "help", false, "Show this help text") 91 | 92 | flag.StringVar(&bias, "bias", "", "Set bias for specific splits (if one split is doing extra work like running a linter).\nFormat: [split_index]=[bias_in_seconds],[another_index]=[another_bias],...") 93 | 94 | flag.Parse() 95 | 96 | var err error 97 | if circleCIAPIKey == "" { 98 | circleCIAPIKey = os.Getenv("CIRCLECI_API_KEY") 99 | } 100 | if circleCIBranchName == "" { 101 | circleCIBranchName = os.Getenv("CIRCLE_BRANCH") 102 | } 103 | if splitTotal == -1 { 104 | splitTotal, err = strconv.Atoi(os.Getenv("CIRCLE_NODE_TOTAL")) 105 | if err != nil { 106 | splitTotal = -1 // restore the "unset" sentinel (Atoi yields 0 when the variable is unset) 107 | } 108 | } 109 | if splitIndex == -1 { 110 | splitIndex, err = strconv.Atoi(os.Getenv("CIRCLE_NODE_INDEX")) 111 | if err != nil { 112 | splitIndex = -1 113 | } 114 | } 115 | 116 | useCircleCI = circleCIAPIKey != "" 117 | 118 | if showHelp { 119 | printMsg("Splits test files into containers of even duration\n\n") 120 |
flag.PrintDefaults() 121 | os.Exit(1) 122 | } 123 | 124 | if useCircleCI && (circleCIProjectPrefix == "" || circleCIBranchName == "") { 125 | fatalMsg("Incomplete CircleCI configuration (set -circleci-key, -circleci-project, and -circleci-branch)\n") 126 | } 127 | } 128 | 129 | func main() { 130 | parseFlags() 131 | 132 | // If JUnit update mode is enabled, handle it separately and exit 133 | if junitUpdateOldGlob != "" || junitUpdateNewGlob != "" || junitUpdateOutPath != "" { 134 | updateJUnitTimings() 135 | return 136 | } 137 | 138 | // Validate split parameters (not needed in update mode); splitIndex is 0-based and indexes buckets, so it must be < splitTotal 139 | if splitTotal <= 0 || splitIndex < 0 || splitIndex >= splitTotal { 140 | fatalMsg("-split-index and -split-total (and environment variables) are missing or invalid\n") 141 | } 142 | 143 | // We are not using filepath.Glob, 144 | // because it doesn't support '**' (to match all files in all nested directories) 145 | currentFiles, err := doublestar.Glob(testFilePattern) 146 | if err != nil { 147 | fatalMsg("failed to enumerate current file set: %v\n", err) 148 | } 149 | currentFileSet := make(map[string]bool) 150 | for _, file := range currentFiles { 151 | currentFileSet[file] = true 152 | } 153 | 154 | if excludeFilePattern != "" { 155 | excludedFiles, err := doublestar.Glob(excludeFilePattern) 156 | if err != nil { 157 | fatalMsg("failed to enumerate excluded file set: %v\n", err) 158 | } 159 | for _, file := range excludedFiles { 160 | delete(currentFileSet, file) 161 | } 162 | } 163 | 164 | fileTimes := make(map[string]float64) 165 | if useLineCount { 166 | estimateFileTimesByLineCount(currentFileSet, fileTimes) 167 | } else if useJUnitXML { 168 | getFileTimesFromJUnitXML(fileTimes) 169 | } else if useCircleCI { 170 | getFileTimesFromCircleCI(fileTimes) 171 | } 172 | 173 | removeDeletedFiles(fileTimes, currentFileSet) 174 | addNewFiles(fileTimes, currentFileSet) 175 | 176 | var biases []float64 177 | if bias != "" { 178 | biases, err = parseBias(bias, splitTotal) 179 | if err
!= nil { 180 | fatalMsg("failed to parse bias: %v\n", err) 181 | } 182 | } else { 183 | biases = make([]float64, splitTotal) 184 | } 185 | 186 | buckets, bucketTimes := splitFiles(biases, fileTimes, splitTotal) 187 | if useCircleCI || useJUnitXML { 188 | printMsg("expected test time: %0.1fs\n", bucketTimes[splitIndex]) 189 | } 190 | 191 | fmt.Println(strings.Join(buckets[splitIndex], " ")) 192 | } 193 | 194 | // parseBias parses a -bias value such as "0=20,1=-5" into a slice of length splitTotal holding 195 | // each split's bias in seconds (0 for splits not mentioned). Whitespace around each pair and 196 | // around the index and seconds is ignored, so "0=1, 1=2" is accepted. 197 | func parseBias(bias string, splitTotal int) ([]float64, error) { 198 | declarations := strings.Split(bias, ",") 199 | biases := make([]float64, splitTotal) 200 | for _, declaration := range declarations { 201 | declaration = strings.TrimSpace(declaration) 202 | parts := strings.Split(declaration, "=") 203 | if len(parts) != 2 { 204 | return nil, fmt.Errorf("not a valid bias declaration: %s", declaration) 205 | } 206 | index, err := strconv.Atoi(strings.TrimSpace(parts[0])) 207 | if err != nil { 208 | return nil, fmt.Errorf("failed to parse bias index: %w", err) 209 | } 210 | if index < 0 || index >= splitTotal { 211 | return nil, fmt.Errorf("bias index is not within the split number: %d", index) 212 | } 213 | biasSeconds, err := strconv.ParseFloat(strings.TrimSpace(parts[1]), 64) 214 | if err != nil { 215 | return nil, fmt.Errorf("failed to parse bias time: %w", err) 216 | } 217 | biases[index] = biasSeconds 218 | } 219 | return biases, nil 220 | } 221 | --------------------------------------------------------------------------------