├── .github └── workflows │ └── test.yml ├── .gitignore ├── LICENSE ├── README.md ├── automemlimit.go ├── examples ├── dynamic │ ├── go.mod │ ├── go.sum │ ├── limit.txt │ └── main.go ├── logger │ ├── go.mod │ ├── go.sum │ └── main.go └── system │ ├── go.mod │ ├── go.sum │ └── main.go ├── go.mod ├── go.sum └── memlimit ├── cgroups.go ├── cgroups_linux.go ├── cgroups_linux_test.go ├── cgroups_test.go ├── cgroups_unsupported.go ├── cgroups_unsupported_test.go ├── exp_system.go ├── experiment.go ├── experiment_test.go ├── logger.go ├── memlimit.go ├── memlimit_linux_test.go ├── memlimit_test.go ├── memlimit_unsupported_test.go └── provider.go /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: [ push, pull_request ] 4 | 5 | jobs: 6 | test-ubuntu-22_04: 7 | runs-on: ubuntu-22.04 8 | 9 | steps: 10 | - uses: actions/checkout@v3 11 | 12 | - name: Docker Info 13 | run: | 14 | docker info 15 | 16 | - name: Pull golang image 17 | run: | 18 | docker pull golang:1.22 19 | 20 | - name: Run tests in Go container (1000m) 21 | run: | 22 | docker run --rm -v=$(pwd):/app -w=/app -m=1000m golang:1.22 go test -v ./... -expected=1048576000 -cgroup-version 2 23 | 24 | - name: Run tests in Go container (4321m) 25 | run: | 26 | docker run --rm -v=$(pwd):/app -w=/app -m=4321m golang:1.22 go test -v ./... -expected=4530896896 -cgroup-version 2 27 | 28 | - name: Run tests in Go container (system memory limit) 29 | run: | 30 | docker run --rm -v=$(pwd):/app -w=/app golang:1.22 go test -v ./... 
-expected-system=$(($(awk '/MemTotal/ {print $2}' /proc/meminfo) * 1024)) -cgroup-version 2 31 | 32 | test-ubuntu-24_04: 33 | runs-on: ubuntu-24.04 34 | 35 | steps: 36 | - uses: actions/checkout@v3 37 | 38 | - name: Docker Info 39 | run: | 40 | docker info 41 | 42 | - name: Pull golang image 43 | run: | 44 | docker pull golang:1.22 45 | 46 | - name: Run tests in Go container (1000m) 47 | run: | 48 | docker run --rm -v=$(pwd):/app -w=/app -m=1000m golang:1.22 go test -v ./... -expected=1048576000 -cgroup-version 2 49 | 50 | - name: Run tests in Go container (4321m) 51 | run: | 52 | docker run --rm -v=$(pwd):/app -w=/app -m=4321m golang:1.22 go test -v ./... -expected=4530896896 -cgroup-version 2 53 | 54 | - name: Run tests in Go container (system memory limit) 55 | run: | 56 | docker run --rm -v=$(pwd):/app -w=/app golang:1.22 go test -v ./... -expected-system=$(($(awk '/MemTotal/ {print $2}' /proc/meminfo) * 1024)) -cgroup-version 2 57 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, built with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | # Dependency directories (remove the comment below to include it) 15 | # vendor/ 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Geon Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, 
sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # automemlimit 2 | 3 | [![Go Reference](https://pkg.go.dev/badge/github.com/KimMachineGun/automemlimit.svg)](https://pkg.go.dev/github.com/KimMachineGun/automemlimit) 4 | [![Go Report Card](https://goreportcard.com/badge/github.com/KimMachineGun/automemlimit)](https://goreportcard.com/report/github.com/KimMachineGun/automemlimit) 5 | [![Test](https://github.com/KimMachineGun/automemlimit/actions/workflows/test.yml/badge.svg?branch=main)](https://github.com/KimMachineGun/automemlimit/actions/workflows/test.yml) 6 | 7 | Automatically set `GOMEMLIMIT` to match Linux [cgroups(7)](https://man7.org/linux/man-pages/man7/cgroups.7.html) memory limit. 8 | 9 | See more details about `GOMEMLIMIT` [here](https://tip.golang.org/doc/gc-guide#Memory_limit). 10 | 11 | ## Notice 12 | 13 | Version `v0.5.0` introduces a fallback to system memory limits as an experimental feature when cgroup limits are unavailable. Activate this by setting `AUTOMEMLIMIT_EXPERIMENT=system`. 
14 | You can also use system memory limits via the `memlimit.FromSystem` provider directly. 15 | 16 | This feature is under evaluation and might become a default or be removed based on user feedback. 17 | If you have any feedback about this feature, please open an issue. 18 | 19 | ## Installation 20 | 21 | ```shell 22 | go get github.com/KimMachineGun/automemlimit@latest 23 | ``` 24 | 25 | ## Usage 26 | 27 | ```go 28 | package main 29 | 30 | // By default, it sets `GOMEMLIMIT` to 90% of cgroup's memory limit. 31 | // This is equivalent to `memlimit.SetGoMemLimitWithOpts(memlimit.WithLogger(slog.Default()))`. 32 | // To disable logging, use `memlimit.SetGoMemLimitWithOpts` directly. 33 | import _ "github.com/KimMachineGun/automemlimit" 34 | ``` 35 | 36 | or 37 | 38 | ```go 39 | package main 40 | 41 | import ( 42 | "log/slog" 43 | "time" 44 | 45 | "github.com/KimMachineGun/automemlimit/memlimit" 46 | ) 47 | 48 | func init() { 49 | memlimit.SetGoMemLimitWithOpts( 50 | memlimit.WithRatio(0.9), 51 | memlimit.WithProvider(memlimit.FromCgroup), 52 | memlimit.WithLogger(slog.Default()), 53 | memlimit.WithRefreshInterval(1*time.Minute), 54 | ) 55 | memlimit.SetGoMemLimitWithOpts( 56 | memlimit.WithRatio(0.9), 57 | memlimit.WithProvider( 58 | memlimit.ApplyFallback( 59 | memlimit.FromCgroup, 60 | memlimit.FromSystem, 61 | ), 62 | ), 63 | memlimit.WithRefreshInterval(1*time.Minute), 64 | ) 65 | memlimit.SetGoMemLimit(0.9) 66 | memlimit.SetGoMemLimitWithProvider(memlimit.Limit(1024*1024), 0.9) 67 | memlimit.SetGoMemLimitWithProvider(memlimit.FromCgroup, 0.9) 68 | memlimit.SetGoMemLimitWithProvider(memlimit.FromCgroupV1, 0.9) 69 | memlimit.SetGoMemLimitWithProvider(memlimit.FromCgroupHybrid, 0.9) 70 | memlimit.SetGoMemLimitWithProvider(memlimit.FromCgroupV2, 0.9) 71 | } 72 | ``` 73 | -------------------------------------------------------------------------------- /automemlimit.go: -------------------------------------------------------------------------------- 1 | package automemlimit 2 | 3 | import ( 4 | "log/slog" 5 | 6 | 
	"github.com/KimMachineGun/automemlimit/memlimit"
)

// init configures GOMEMLIMIT as a side effect of importing this package
// (see the `import _` usage in the README). It calls
// memlimit.SetGoMemLimitWithOpts with slog.Default() as the logger and the
// library's default options for everything else.
func init() {
	memlimit.SetGoMemLimitWithOpts(
		memlimit.WithLogger(slog.Default()),
	)
}

--------------------------------------------------------------------------------
/examples/dynamic/go.mod:
--------------------------------------------------------------------------------
module github.com/KimMachineGun/automemlimit/examples/dynamic

go 1.22.0

toolchain go1.23.3

require github.com/KimMachineGun/automemlimit v0.0.0

require github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 // indirect

replace github.com/KimMachineGun/automemlimit => ../../

--------------------------------------------------------------------------------
/examples/dynamic/go.sum:
--------------------------------------------------------------------------------
github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 h1:onHthvaw9LFnH4t2DcNVpwGmV9E1BkGknEliJkfwQj0=
github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58/go.mod h1:DXv8WO4yhMYhSNPKjeNKa5WY9YCIEBRbNzFFPJbWO6Y=

--------------------------------------------------------------------------------
/examples/dynamic/limit.txt:
--------------------------------------------------------------------------------
4294967296

--------------------------------------------------------------------------------
/examples/dynamic/main.go:
--------------------------------------------------------------------------------
package main

import (
	"bytes"
	"errors"
	"log/slog"
	"os"
	"os/signal"
	"strconv"
	"time"

	"github.com/KimMachineGun/automemlimit/memlimit"
)

// init installs a JSON slog handler on stderr as the default logger. It then
// configures GOMEMLIMIT from the value in limit.txt (via FileProvider), with a
// refresh interval of 5 seconds.
func init() {
	slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stderr, nil)))

	memlimit.SetGoMemLimitWithOpts(
		memlimit.WithProvider(
			FileProvider("limit.txt"),
		),
		memlimit.WithRefreshInterval(5*time.Second),
		memlimit.WithLogger(slog.Default()),
	)
}

// main blocks until an interrupt signal arrives and then logs it. Presumably
// this keeps the process alive so that refreshes of limit.txt can be observed.
func main() {
	c := make(chan os.Signal, 1)
	signal.Notify(c, os.Interrupt)

	s := <-c
	slog.Info("signal captured", slog.Any("signal", s))
}

// FileProvider returns a memlimit.Provider that re-reads the file at path on
// every call and parses its whitespace-trimmed contents as a base-10 uint64.
//   - If the file does not exist, it delegates to
//     memlimit.ApplyFallback(memlimit.FromCgroup, memlimit.FromSystem).
//   - If the file is empty or whitespace-only, it returns memlimit.ErrNoLimit.
//   - Any other read error, or a parse error, is returned unchanged.
func FileProvider(path string) memlimit.Provider {
	return func() (uint64, error) {
		b, err := os.ReadFile(path)
		if err != nil {
			if errors.Is(err, os.ErrNotExist) {
				return memlimit.ApplyFallback(memlimit.FromCgroup, memlimit.FromSystem)()
			}
			return 0, err
		}

		b = bytes.TrimSpace(b)
		if len(b) == 0 {
			return 0, memlimit.ErrNoLimit
		}

		return strconv.ParseUint(string(b), 10, 64)
	}
}

--------------------------------------------------------------------------------
/examples/logger/go.mod:
--------------------------------------------------------------------------------
module github.com/KimMachineGun/automemlimit/examples/logger

go 1.22.0

toolchain go1.23.3

require github.com/KimMachineGun/automemlimit v0.0.0

require github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 // indirect

replace github.com/KimMachineGun/automemlimit => ../../

--------------------------------------------------------------------------------
/examples/logger/go.sum:
--------------------------------------------------------------------------------
github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 h1:onHthvaw9LFnH4t2DcNVpwGmV9E1BkGknEliJkfwQj0=
github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58/go.mod h1:DXv8WO4yhMYhSNPKjeNKa5WY9YCIEBRbNzFFPJbWO6Y=

--------------------------------------------------------------------------------
/examples/logger/main.go:
--------------------------------------------------------------------------------
package main

import (
	"log/slog"
	"os"

	"github.com/KimMachineGun/automemlimit/memlimit"
)

// init configures GOMEMLIMIT from a fixed limit of 1024*1024*1024 (1 GiB). It
// sends the library's log output to a JSON slog handler on stderr.
func init() {
	memlimit.SetGoMemLimitWithOpts(
		memlimit.WithProvider(
			memlimit.Limit(1024*1024*1024),
		),
		memlimit.WithLogger(slog.New(slog.NewJSONHandler(os.Stderr, nil))),
	)
}

// main is intentionally empty; all of the example's work happens in init.
func main() {}

--------------------------------------------------------------------------------
/examples/system/go.mod:
--------------------------------------------------------------------------------
module github.com/KimMachineGun/automemlimit/examples/system

go 1.22.0

toolchain go1.23.3

require github.com/KimMachineGun/automemlimit v0.0.0

require github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 // indirect

replace github.com/KimMachineGun/automemlimit => ../../

--------------------------------------------------------------------------------
/examples/system/go.sum:
--------------------------------------------------------------------------------
github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 h1:onHthvaw9LFnH4t2DcNVpwGmV9E1BkGknEliJkfwQj0=
github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58/go.mod h1:DXv8WO4yhMYhSNPKjeNKa5WY9YCIEBRbNzFFPJbWO6Y=

--------------------------------------------------------------------------------
/examples/system/main.go:
--------------------------------------------------------------------------------
package main

import (
	"github.com/KimMachineGun/automemlimit/memlimit"
)

// init configures GOMEMLIMIT using
// memlimit.ApplyFallback(memlimit.FromCgroup, memlimit.FromSystem). Per the
// README, the system memory limit is used when cgroup limits are unavailable.
func init() {
	memlimit.SetGoMemLimitWithOpts(
		memlimit.WithProvider(
			memlimit.ApplyFallback(
				memlimit.FromCgroup,
				memlimit.FromSystem,
			),
		),
	)
}

// main is intentionally empty; all of the example's work happens in init.
func main() {}

--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
module github.com/KimMachineGun/automemlimit

go 1.22.0

toolchain go1.23.3

require github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58

--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 h1:onHthvaw9LFnH4t2DcNVpwGmV9E1BkGknEliJkfwQj0=
github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58/go.mod h1:DXv8WO4yhMYhSNPKjeNKa5WY9YCIEBRbNzFFPJbWO6Y=

--------------------------------------------------------------------------------
/memlimit/cgroups.go:
--------------------------------------------------------------------------------
package memlimit

import (
	"bufio"
	"errors"
	"fmt"
	"io"
	"math"
	"os"
	"path/filepath"
	"slices"
	"strconv"
	"strings"
)

var (
	// ErrNoCgroup is returned when the process is not in cgroup, i.e. when no
	// cgroup v1 or cgroup v2 hierarchy is detected in /proc/self/mountinfo.
	ErrNoCgroup = errors.New("process is not in cgroup")
	// ErrCgroupsNotSupported is returned when the system does not support cgroups.
	ErrCgroupsNotSupported = errors.New("cgroups is not supported on this system")
)

// fromCgroup retrieves the memory limit from the cgroup of the current process.
// The versionDetector function is used to detect the cgroup version from the mountinfo.
// It reports whether cgroup v1 and cgroup v2 should be consulted, in that order.
// If neither is reported, ErrNoCgroup is returned. When v2 is reported it is tried
// first, and v1 is consulted only if the v2 lookup fails and v1 is also reported.
25 | func fromCgroup(versionDetector func(mis []mountInfo) (bool, bool)) (uint64, error) { 26 | mf, err := os.Open("/proc/self/mountinfo") 27 | if err != nil { 28 | return 0, fmt.Errorf("failed to open /proc/self/mountinfo: %w", err) 29 | } 30 | defer mf.Close() 31 | 32 | mis, err := parseMountInfo(mf) 33 | if err != nil { 34 | return 0, fmt.Errorf("failed to parse mountinfo: %w", err) 35 | } 36 | 37 | v1, v2 := versionDetector(mis) 38 | if !(v1 || v2) { 39 | return 0, ErrNoCgroup 40 | } 41 | 42 | cf, err := os.Open("/proc/self/cgroup") 43 | if err != nil { 44 | return 0, fmt.Errorf("failed to open /proc/self/cgroup: %w", err) 45 | } 46 | defer cf.Close() 47 | 48 | chs, err := parseCgroupFile(cf) 49 | if err != nil { 50 | return 0, fmt.Errorf("failed to parse cgroup file: %w", err) 51 | } 52 | 53 | if v2 { 54 | limit, err := getMemoryLimitV2(chs, mis) 55 | if err == nil { 56 | return limit, nil 57 | } else if !v1 { 58 | return 0, err 59 | } 60 | } 61 | 62 | return getMemoryLimitV1(chs, mis) 63 | } 64 | 65 | // detectCgroupVersion detects the cgroup version from the mountinfo. 66 | func detectCgroupVersion(mis []mountInfo) (bool, bool) { 67 | var v1, v2 bool 68 | for _, mi := range mis { 69 | switch mi.FilesystemType { 70 | case "cgroup": 71 | v1 = true 72 | case "cgroup2": 73 | v2 = true 74 | } 75 | } 76 | return v1, v2 77 | } 78 | 79 | // getMemoryLimitV2 retrieves the memory limit from the cgroup v2 controller. 80 | func getMemoryLimitV2(chs []cgroupHierarchy, mis []mountInfo) (uint64, error) { 81 | // find the cgroup v2 path for the memory controller. 82 | // in cgroup v2, the paths are unified and the controller list is empty. 83 | idx := slices.IndexFunc(chs, func(ch cgroupHierarchy) bool { 84 | return ch.HierarchyID == "0" && ch.ControllerList == "" 85 | }) 86 | if idx == -1 { 87 | return 0, errors.New("cgroup v2 path not found") 88 | } 89 | relPath := chs[idx].CgroupPath 90 | 91 | // find the mountpoint for the cgroup v2 controller. 
92 | idx = slices.IndexFunc(mis, func(mi mountInfo) bool { 93 | return mi.FilesystemType == "cgroup2" 94 | }) 95 | if idx == -1 { 96 | return 0, errors.New("cgroup v2 mountpoint not found") 97 | } 98 | root, mountPoint := mis[idx].Root, mis[idx].MountPoint 99 | 100 | // resolve the actual cgroup path 101 | cgroupPath, err := resolveCgroupPath(mountPoint, root, relPath) 102 | if err != nil { 103 | return 0, err 104 | } 105 | 106 | // retrieve the memory limit from the memory.max file 107 | return readMemoryLimitV2FromPath(filepath.Join(cgroupPath, "memory.max")) 108 | } 109 | 110 | // readMemoryLimitV2FromPath reads the memory limit for cgroup v2 from the given path. 111 | // this function expects the path to be memory.max file. 112 | func readMemoryLimitV2FromPath(path string) (uint64, error) { 113 | b, err := os.ReadFile(path) 114 | if err != nil { 115 | if errors.Is(err, os.ErrNotExist) { 116 | return 0, ErrNoLimit 117 | } 118 | return 0, fmt.Errorf("failed to read memory.max: %w", err) 119 | } 120 | 121 | slimit := strings.TrimSpace(string(b)) 122 | if slimit == "max" { 123 | return 0, ErrNoLimit 124 | } 125 | 126 | limit, err := strconv.ParseUint(slimit, 10, 64) 127 | if err != nil { 128 | return 0, fmt.Errorf("failed to parse memory.max value: %w", err) 129 | } 130 | 131 | return limit, nil 132 | } 133 | 134 | // getMemoryLimitV1 retrieves the memory limit from the cgroup v1 controller. 135 | func getMemoryLimitV1(chs []cgroupHierarchy, mis []mountInfo) (uint64, error) { 136 | // find the cgroup v1 path for the memory controller. 137 | idx := slices.IndexFunc(chs, func(ch cgroupHierarchy) bool { 138 | return slices.Contains(strings.Split(ch.ControllerList, ","), "memory") 139 | }) 140 | if idx == -1 { 141 | return 0, errors.New("cgroup v1 path for memory controller not found") 142 | } 143 | relPath := chs[idx].CgroupPath 144 | 145 | // find the mountpoint for the cgroup v1 controller. 
146 | idx = slices.IndexFunc(mis, func(mi mountInfo) bool { 147 | return mi.FilesystemType == "cgroup" && slices.Contains(strings.Split(mi.SuperOptions, ","), "memory") 148 | }) 149 | if idx == -1 { 150 | return 0, errors.New("cgroup v1 mountpoint for memory controller not found") 151 | } 152 | root, mountPoint := mis[idx].Root, mis[idx].MountPoint 153 | 154 | // resolve the actual cgroup path 155 | cgroupPath, err := resolveCgroupPath(mountPoint, root, relPath) 156 | if err != nil { 157 | return 0, err 158 | } 159 | 160 | // retrieve the memory limit from the memory.stats and memory.limit_in_bytes files. 161 | return readMemoryLimitV1FromPath(cgroupPath) 162 | } 163 | 164 | // getCgroupV1NoLimit returns the maximum value that is used to represent no limit in cgroup v1. 165 | // the max memory limit is max int64, but it should be multiple of the page size. 166 | func getCgroupV1NoLimit() uint64 { 167 | ps := uint64(os.Getpagesize()) 168 | return math.MaxInt64 / ps * ps 169 | } 170 | 171 | // readMemoryLimitV1FromPath reads the memory limit for cgroup v1 from the given path. 172 | // this function expects the path to be the cgroup directory. 173 | func readMemoryLimitV1FromPath(cgroupPath string) (uint64, error) { 174 | // read hierarchical_memory_limit and memory.limit_in_bytes files. 175 | // but if hierarchical_memory_limit is not available, then use the max value as a fallback. 176 | hml, err := readHierarchicalMemoryLimit(filepath.Join(cgroupPath, "memory.stats")) 177 | if err != nil && !errors.Is(err, os.ErrNotExist) { 178 | return 0, fmt.Errorf("failed to read hierarchical_memory_limit: %w", err) 179 | } else if hml == 0 { 180 | hml = math.MaxUint64 181 | } 182 | 183 | // read memory.limit_in_bytes file. 
184 | b, err := os.ReadFile(filepath.Join(cgroupPath, "memory.limit_in_bytes")) 185 | if err != nil && !errors.Is(err, os.ErrNotExist) { 186 | return 0, fmt.Errorf("failed to read memory.limit_in_bytes: %w", err) 187 | } 188 | lib, err := strconv.ParseUint(strings.TrimSpace(string(b)), 10, 64) 189 | if err != nil { 190 | return 0, fmt.Errorf("failed to parse memory.limit_in_bytes value: %w", err) 191 | } else if lib == 0 { 192 | hml = math.MaxUint64 193 | } 194 | 195 | // use the minimum value between hierarchical_memory_limit and memory.limit_in_bytes. 196 | // if the limit is the maximum value, then it is considered as no limit. 197 | limit := min(hml, lib) 198 | if limit >= getCgroupV1NoLimit() { 199 | return 0, ErrNoLimit 200 | } 201 | 202 | return limit, nil 203 | } 204 | 205 | // readHierarchicalMemoryLimit extracts hierarchical_memory_limit from memory.stats. 206 | // this function expects the path to be memory.stats file. 207 | func readHierarchicalMemoryLimit(path string) (uint64, error) { 208 | file, err := os.Open(path) 209 | if err != nil { 210 | return 0, err 211 | } 212 | defer file.Close() 213 | 214 | scanner := bufio.NewScanner(file) 215 | for scanner.Scan() { 216 | line := scanner.Text() 217 | 218 | fields := strings.Split(line, " ") 219 | if len(fields) < 2 { 220 | return 0, fmt.Errorf("failed to parse memory.stats %q: not enough fields", line) 221 | } 222 | 223 | if fields[0] == "hierarchical_memory_limit" { 224 | if len(fields) > 2 { 225 | return 0, fmt.Errorf("failed to parse memory.stats %q: too many fields for hierarchical_memory_limit", line) 226 | } 227 | return strconv.ParseUint(fields[1], 10, 64) 228 | } 229 | } 230 | if err := scanner.Err(); err != nil { 231 | return 0, err 232 | } 233 | 234 | return 0, nil 235 | } 236 | 237 | // https://www.man7.org/linux/man-pages/man5/proc_pid_mountinfo.5.html 238 | // 731 771 0:59 /sysrq-trigger /proc/sysrq-trigger ro,nosuid,nodev,noexec,relatime - proc proc rw 239 | // 240 | // 36 35 98:0 /mnt1 
/mnt2 rw,noatime master:1 - ext3 /dev/root rw,errors=continue 241 | // (1)(2)(3) (4) (5) (6) (7) (8) (9) (10) (11) 242 | // 243 | // (1) mount ID: a unique ID for the mount (may be reused after umount(2)). 244 | // (2) parent ID: the ID of the parent mount (or of self for the root of this mount namespace's mount tree). 245 | // (3) major:minor: the value of st_dev for files on this filesystem (see stat(2)). 246 | // (4) root: the pathname of the directory in the filesystem which forms the root of this mount. 247 | // (5) mount point: the pathname of the mount point relative to the process's root directory. 248 | // (6) mount options: per-mount options (see mount(2)). 249 | // (7) optional fields: zero or more fields of the form "tag[:value]"; see below. 250 | // (8) separator: the end of the optional fields is marked by a single hyphen. 251 | // (9) filesystem type: the filesystem type in the form "type[.subtype]". 252 | // (10) mount source: filesystem-specific information or "none". 253 | // (11) super options: per-superblock options (see mount(2)). 254 | type mountInfo struct { 255 | Root string 256 | MountPoint string 257 | FilesystemType string 258 | SuperOptions string 259 | } 260 | 261 | // parseMountInfoLine parses a line from the mountinfo file. 
262 | func parseMountInfoLine(line string) (mountInfo, error) { 263 | if line == "" { 264 | return mountInfo{}, errors.New("empty line") 265 | } 266 | 267 | fieldss := strings.SplitN(line, " - ", 2) 268 | if len(fieldss) != 2 { 269 | return mountInfo{}, fmt.Errorf("invalid separator") 270 | } 271 | 272 | fields1 := strings.SplitN(fieldss[0], " ", 7) 273 | if len(fields1) < 6 { 274 | return mountInfo{}, fmt.Errorf("not enough fields before separator: %v", fields1) 275 | } else if len(fields1) == 6 { 276 | fields1 = append(fields1, "") 277 | } 278 | 279 | fields2 := strings.SplitN(fieldss[1], " ", 3) 280 | if len(fields2) < 3 { 281 | return mountInfo{}, fmt.Errorf("not enough fields after separator: %v", fields2) 282 | } 283 | 284 | return mountInfo{ 285 | Root: fields1[3], 286 | MountPoint: fields1[4], 287 | FilesystemType: fields2[0], 288 | SuperOptions: fields2[2], 289 | }, nil 290 | } 291 | 292 | // parseMountInfo parses the mountinfo file. 293 | func parseMountInfo(r io.Reader) ([]mountInfo, error) { 294 | var ( 295 | s = bufio.NewScanner(r) 296 | mis []mountInfo 297 | ) 298 | for s.Scan() { 299 | line := s.Text() 300 | 301 | mi, err := parseMountInfoLine(line) 302 | if err != nil { 303 | return nil, fmt.Errorf("failed to parse mountinfo file %q: %w", line, err) 304 | } 305 | 306 | mis = append(mis, mi) 307 | } 308 | if err := s.Err(); err != nil { 309 | return nil, err 310 | } 311 | 312 | return mis, nil 313 | } 314 | 315 | // https://www.man7.org/linux/man-pages/man7/cgroups.7.html 316 | // 317 | // 5:cpuacct,cpu,cpuset:/daemons 318 | // (1) (2) (3) 319 | // 320 | // (1) hierarchy ID: 321 | // 322 | // cgroups version 1 hierarchies, this field 323 | // contains a unique hierarchy ID number that can be 324 | // matched to a hierarchy ID in /proc/cgroups. For the 325 | // cgroups version 2 hierarchy, this field contains the 326 | // value 0. 
327 | // 328 | // (2) controller list: 329 | // 330 | // For cgroups version 1 hierarchies, this field 331 | // contains a comma-separated list of the controllers 332 | // bound to the hierarchy. For the cgroups version 2 333 | // hierarchy, this field is empty. 334 | // 335 | // (3) cgroup path: 336 | // 337 | // This field contains the pathname of the control group 338 | // in the hierarchy to which the process belongs. This 339 | // pathname is relative to the mount point of the 340 | // hierarchy. 341 | type cgroupHierarchy struct { 342 | HierarchyID string 343 | ControllerList string 344 | CgroupPath string 345 | } 346 | 347 | // parseCgroupHierarchyLine parses a line from the cgroup file. 348 | func parseCgroupHierarchyLine(line string) (cgroupHierarchy, error) { 349 | if line == "" { 350 | return cgroupHierarchy{}, errors.New("empty line") 351 | } 352 | 353 | fields := strings.Split(line, ":") 354 | if len(fields) < 3 { 355 | return cgroupHierarchy{}, fmt.Errorf("not enough fields: %v", fields) 356 | } else if len(fields) > 3 { 357 | return cgroupHierarchy{}, fmt.Errorf("too many fields: %v", fields) 358 | } 359 | 360 | return cgroupHierarchy{ 361 | HierarchyID: fields[0], 362 | ControllerList: fields[1], 363 | CgroupPath: fields[2], 364 | }, nil 365 | } 366 | 367 | // parseCgroupFile parses the cgroup file. 368 | func parseCgroupFile(r io.Reader) ([]cgroupHierarchy, error) { 369 | var ( 370 | s = bufio.NewScanner(r) 371 | chs []cgroupHierarchy 372 | ) 373 | for s.Scan() { 374 | line := s.Text() 375 | 376 | ch, err := parseCgroupHierarchyLine(line) 377 | if err != nil { 378 | return nil, fmt.Errorf("failed to parse cgroup file %q: %w", line, err) 379 | } 380 | 381 | chs = append(chs, ch) 382 | } 383 | if err := s.Err(); err != nil { 384 | return nil, err 385 | } 386 | 387 | return chs, nil 388 | } 389 | 390 | // resolveCgroupPath resolves the actual cgroup path from the mountpoint, root, and cgroupRelPath. 
391 | func resolveCgroupPath(mountpoint, root, cgroupRelPath string) (string, error) { 392 | rel, err := filepath.Rel(root, cgroupRelPath) 393 | if err != nil { 394 | return "", err 395 | } 396 | 397 | // if the relative path is ".", then the cgroupRelPath is the root itself. 398 | if rel == "." { 399 | return mountpoint, nil 400 | } 401 | 402 | // if the relative path starts with "..", then it is outside the root. 403 | if strings.HasPrefix(rel, "..") { 404 | return "", fmt.Errorf("invalid cgroup path: %s is not under root %s", cgroupRelPath, root) 405 | } 406 | 407 | return filepath.Join(mountpoint, rel), nil 408 | } 409 | -------------------------------------------------------------------------------- /memlimit/cgroups_linux.go: -------------------------------------------------------------------------------- 1 | //go:build linux 2 | // +build linux 3 | 4 | package memlimit 5 | 6 | // FromCgroup retrieves the memory limit from the cgroup. 7 | func FromCgroup() (uint64, error) { 8 | return fromCgroup(detectCgroupVersion) 9 | } 10 | 11 | // FromCgroupV1 retrieves the memory limit from the cgroup v1 controller. 12 | // After v1.0.0, this function could be removed and FromCgroup should be used instead. 13 | func FromCgroupV1() (uint64, error) { 14 | return fromCgroup(func(_ []mountInfo) (bool, bool) { 15 | return true, false 16 | }) 17 | } 18 | 19 | // FromCgroupHybrid retrieves the memory limit from the cgroup v2 and v1 controller sequentially, 20 | // basically, it is equivalent to FromCgroup. 21 | // After v1.0.0, this function could be removed and FromCgroup should be used instead. 22 | func FromCgroupHybrid() (uint64, error) { 23 | return FromCgroup() 24 | } 25 | 26 | // FromCgroupV2 retrieves the memory limit from the cgroup v2 controller. 27 | // After v1.0.0, this function could be removed and FromCgroup should be used instead. 
28 | func FromCgroupV2() (uint64, error) { 29 | return fromCgroup(func(_ []mountInfo) (bool, bool) { 30 | return false, true 31 | }) 32 | } 33 | -------------------------------------------------------------------------------- /memlimit/cgroups_linux_test.go: -------------------------------------------------------------------------------- 1 | //go:build linux 2 | // +build linux 3 | 4 | package memlimit 5 | 6 | import ( 7 | "testing" 8 | ) 9 | 10 | func TestFromCgroup(t *testing.T) { 11 | if expected == 0 { 12 | t.Skip() 13 | } 14 | 15 | limit, err := FromCgroup() 16 | if cgVersion == 0 && err != ErrNoCgroup { 17 | t.Fatalf("FromCgroup() error = %v, wantErr %v", err, ErrNoCgroup) 18 | } 19 | 20 | if err != nil { 21 | t.Fatalf("FromCgroup() error = %v, wantErr %v", err, nil) 22 | } 23 | if limit != expected { 24 | t.Fatalf("FromCgroup() got = %v, want %v", limit, expected) 25 | } 26 | } 27 | 28 | func TestFromCgroupHybrid(t *testing.T) { 29 | if expected == 0 { 30 | t.Skip() 31 | } 32 | 33 | limit, err := FromCgroupHybrid() 34 | if cgVersion == 0 && err != ErrNoCgroup { 35 | t.Fatalf("FromCgroupHybrid() error = %v, wantErr %v", err, ErrNoCgroup) 36 | } 37 | 38 | if err != nil { 39 | t.Fatalf("FromCgroupHybrid() error = %v, wantErr %v", err, nil) 40 | } 41 | if limit != expected { 42 | t.Fatalf("FromCgroupHybrid() got = %v, want %v", limit, expected) 43 | } 44 | } 45 | 46 | func TestFromCgroupV1(t *testing.T) { 47 | if expected == 0 || cgVersion != 1 { 48 | t.Skip() 49 | } 50 | limit, err := FromCgroupV1() 51 | if err != nil { 52 | t.Fatalf("FromCgroupV1() error = %v, wantErr %v", err, nil) 53 | } 54 | if limit != expected { 55 | t.Fatalf("FromCgroupV1() got = %v, want %v", limit, expected) 56 | } 57 | } 58 | 59 | func TestFromCgroupV2(t *testing.T) { 60 | if expected == 0 || cgVersion != 2 { 61 | t.Skip() 62 | } 63 | limit, err := FromCgroupV2() 64 | if err != nil { 65 | t.Fatalf("FromCgroupV2() error = %v, wantErr %v", err, nil) 66 | } 67 | if limit != expected { 68 
| t.Fatalf("FromCgroupV2() got = %v, want %v", limit, expected) 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /memlimit/cgroups_test.go: -------------------------------------------------------------------------------- 1 | package memlimit 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | ) 7 | 8 | func TestParseMountInfoLine(t *testing.T) { 9 | tests := []struct { 10 | name string 11 | input string 12 | want mountInfo 13 | wantErr string 14 | }{ 15 | { 16 | name: "valid line with optional field", 17 | input: "36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3 /dev/root rw,errors=continue", 18 | want: mountInfo{ 19 | Root: "/mnt1", 20 | MountPoint: "/mnt2", 21 | FilesystemType: "ext3", 22 | SuperOptions: "rw,errors=continue", 23 | }, 24 | }, 25 | { 26 | name: "valid line without optional field", 27 | input: "731 771 0:59 /sysrq-trigger /proc/sysrq-trigger ro,nosuid,nodev,noexec,relatime - proc proc rw", 28 | want: mountInfo{ 29 | Root: "/sysrq-trigger", 30 | MountPoint: "/proc/sysrq-trigger", 31 | FilesystemType: "proc", 32 | SuperOptions: "rw", 33 | }, 34 | }, 35 | { 36 | name: "valid line with minimal fields (no optional fields)", 37 | input: "25 1 0:22 / /dev rw - devtmpfs udev rw", 38 | want: mountInfo{ 39 | Root: "/", 40 | MountPoint: "/dev", 41 | FilesystemType: "devtmpfs", 42 | SuperOptions: "rw", 43 | }, 44 | }, 45 | { 46 | name: "no separator", 47 | input: "36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 ext3 /dev/root rw,errors=continue", 48 | wantErr: `invalid separator`, 49 | }, 50 | { 51 | name: "not enough fields on left side", 52 | input: "36 35 98:0 /mnt1 /mnt2 - ext3 /dev/root rw,errors=continue", 53 | wantErr: `not enough fields before separator: [36 35 98:0 /mnt1 /mnt2]`, 54 | }, 55 | { 56 | name: "not enough fields on right side", 57 | input: "36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3", 58 | wantErr: `not enough fields after separator: [ext3]`, 59 | }, 60 | { 61 | name: "empty line", 62 | 
input: "", 63 | wantErr: `empty line`, 64 | }, 65 | { 66 | name: "6 fields on left side (no optional field), should add empty optional field", 67 | input: "100 1 8:2 / /data rw - ext4 /dev/sda2 rw,relatime", 68 | want: mountInfo{ 69 | Root: "/", 70 | MountPoint: "/data", 71 | FilesystemType: "ext4", 72 | SuperOptions: "rw,relatime", 73 | }, 74 | }, 75 | { 76 | name: "multiple optional fields on left side (issue #26)", 77 | input: "465 34 253:0 / / rw,relatime shared:409 master:1 - xfs /dev/mapper/fedora-root rw,seclabel,attr2,inode64,logbufs=8,logbsize=32k,noquota", 78 | want: mountInfo{ 79 | Root: "/", 80 | MountPoint: "/", 81 | FilesystemType: "xfs", 82 | SuperOptions: "rw,seclabel,attr2,inode64,logbufs=8,logbsize=32k,noquota", 83 | }, 84 | }, 85 | { 86 | name: "super options have spaces (issue #28)", 87 | input: `1391 1160 0:151 / /Docker/host rw,noatime - 9p C:\134Program\040Files\134Docker\134Docker\134resources rw,dirsync,aname=drvfs;path=C:\Program Files\Docker\Docker\resources;symlinkroot=/mnt/,mmap,access=client,msize=65536,trans=fd,rfd=3,wfd=3`, 88 | want: mountInfo{ 89 | Root: "/", 90 | MountPoint: "/Docker/host", 91 | FilesystemType: "9p", 92 | SuperOptions: `rw,dirsync,aname=drvfs;path=C:\Program Files\Docker\Docker\resources;symlinkroot=/mnt/,mmap,access=client,msize=65536,trans=fd,rfd=3,wfd=3`, 93 | }, 94 | }, 95 | } 96 | 97 | for _, tt := range tests { 98 | t.Run(tt.name, func(t *testing.T) { 99 | got, err := parseMountInfoLine(tt.input) 100 | if tt.wantErr != "" { 101 | if err == nil { 102 | t.Fatalf("expected an error containing %q, got nil", tt.wantErr) 103 | } 104 | if err.Error() != tt.wantErr { 105 | t.Fatalf("expected error containing %q, got %q", tt.wantErr, err.Error()) 106 | } 107 | return 108 | } 109 | 110 | if err != nil { 111 | t.Fatalf("unexpected error: %v", err) 112 | } 113 | 114 | if !reflect.DeepEqual(got, tt.want) { 115 | t.Fatalf("expected %+v, got %+v", tt.want, got) 116 | } 117 | }) 118 | } 119 | } 120 | 121 | func 
TestParseCgroupHierarchyLine(t *testing.T) { 122 | tests := []struct { 123 | name string 124 | input string 125 | want cgroupHierarchy 126 | wantErr string 127 | }{ 128 | { 129 | name: "valid line with multiple controllers", 130 | input: "5:cpuacct,cpu,cpuset:/daemons", 131 | want: cgroupHierarchy{ 132 | HierarchyID: "5", 133 | ControllerList: "cpuacct,cpu,cpuset", 134 | CgroupPath: "/daemons", 135 | }, 136 | }, 137 | { 138 | name: "valid line with no controllers (cgroup v2)", 139 | input: "0::/system.slice/docker.service", 140 | want: cgroupHierarchy{ 141 | HierarchyID: "0", 142 | ControllerList: "", 143 | CgroupPath: "/system.slice/docker.service", 144 | }, 145 | }, 146 | { 147 | name: "invalid line - only two fields", 148 | input: "5:cpuacct,cpu,cpuset", 149 | wantErr: "not enough fields: [5 cpuacct,cpu,cpuset]", 150 | }, 151 | { 152 | name: "invalid line - too many fields", 153 | input: "5:cpuacct,cpu:cpuset:/daemons:extra", 154 | wantErr: "too many fields: [5 cpuacct,cpu cpuset /daemons extra]", 155 | }, 156 | { 157 | name: "empty line", 158 | input: "", 159 | wantErr: "empty line", 160 | }, 161 | { 162 | name: "line with empty controller list but valid fields", 163 | input: "2::/my_cgroup", 164 | want: cgroupHierarchy{ 165 | HierarchyID: "2", 166 | ControllerList: "", 167 | CgroupPath: "/my_cgroup", 168 | }, 169 | }, 170 | } 171 | 172 | for _, tt := range tests { 173 | t.Run(tt.name, func(t *testing.T) { 174 | got, err := parseCgroupHierarchyLine(tt.input) 175 | if tt.wantErr != "" { 176 | if err == nil { 177 | t.Fatalf("expected an error containing %q, got nil", tt.wantErr) 178 | } 179 | if err.Error() != tt.wantErr { 180 | t.Fatalf("expected error containing %q, got %q", tt.wantErr, err.Error()) 181 | } 182 | return 183 | } 184 | if err != nil { 185 | t.Fatalf("unexpected error: %v", err) 186 | } 187 | 188 | if !reflect.DeepEqual(got, tt.want) { 189 | t.Fatalf("expected %+v, got %+v", tt.want, got) 190 | } 191 | }) 192 | } 193 | } 194 | 195 | func 
TestResolveCgroupPath(t *testing.T) { 196 | tests := []struct { 197 | name string 198 | mountpoint string 199 | root string 200 | cgroupRelPath string 201 | want string 202 | wantErr string 203 | }{ 204 | { 205 | name: "exact match with both root and cgroupRelPath as '/'", 206 | mountpoint: "/fake/mount", 207 | root: "/", 208 | cgroupRelPath: "/", 209 | want: "/fake/mount", 210 | }, 211 | { 212 | name: "exact match with a non-root path", 213 | mountpoint: "/fake/mount", 214 | root: "/container0", 215 | cgroupRelPath: "/container0", 216 | want: "/fake/mount", 217 | }, 218 | { 219 | name: "valid subpath under root", 220 | mountpoint: "/fake/mount", 221 | root: "/container0", 222 | cgroupRelPath: "/container0/group1", 223 | want: "/fake/mount/group1", 224 | }, 225 | { 226 | name: "invalid cgroup path outside root", 227 | mountpoint: "/fake/mount", 228 | root: "/container0", 229 | cgroupRelPath: "/other_container", 230 | wantErr: "invalid cgroup path: /other_container is not under root /container0", 231 | }, 232 | } 233 | 234 | for _, tt := range tests { 235 | t.Run(tt.name, func(t *testing.T) { 236 | got, err := resolveCgroupPath(tt.mountpoint, tt.root, tt.cgroupRelPath) 237 | if tt.wantErr != "" { 238 | if err == nil { 239 | t.Fatalf("expected an error containing %q, got nil", tt.wantErr) 240 | } 241 | if err.Error() != tt.wantErr { 242 | t.Fatalf("expected error containing %q, got %q", tt.wantErr, err.Error()) 243 | } 244 | return 245 | } 246 | if err != nil { 247 | t.Fatalf("unexpected error: %v", err) 248 | } 249 | 250 | if got != tt.want { 251 | t.Fatalf("expected path %q, got %q", tt.want, got) 252 | } 253 | }) 254 | } 255 | } 256 | -------------------------------------------------------------------------------- /memlimit/cgroups_unsupported.go: -------------------------------------------------------------------------------- 1 | //go:build !linux 2 | // +build !linux 3 | 4 | package memlimit 5 | 6 | func FromCgroup() (uint64, error) { 7 | return 0, 
ErrCgroupsNotSupported 8 | } 9 | 10 | func FromCgroupV1() (uint64, error) { 11 | return 0, ErrCgroupsNotSupported 12 | } 13 | 14 | func FromCgroupHybrid() (uint64, error) { 15 | return 0, ErrCgroupsNotSupported 16 | } 17 | 18 | func FromCgroupV2() (uint64, error) { 19 | return 0, ErrCgroupsNotSupported 20 | } 21 | -------------------------------------------------------------------------------- /memlimit/cgroups_unsupported_test.go: -------------------------------------------------------------------------------- 1 | //go:build !linux 2 | // +build !linux 3 | 4 | package memlimit 5 | 6 | import ( 7 | "testing" 8 | ) 9 | 10 | func TestFromCgroup(t *testing.T) { 11 | limit, err := FromCgroup() 12 | if err != ErrCgroupsNotSupported { 13 | t.Fatalf("FromCgroup() error = %v, wantErr %v", err, ErrCgroupsNotSupported) 14 | } 15 | if limit != 0 { 16 | t.Fatalf("FromCgroup() got = %v, want %v", limit, 0) 17 | } 18 | } 19 | 20 | func TestFromCgroupV1(t *testing.T) { 21 | limit, err := FromCgroupV1() 22 | if err != ErrCgroupsNotSupported { 23 | t.Fatalf("FromCgroupV1() error = %v, wantErr %v", err, ErrCgroupsNotSupported) 24 | } 25 | if limit != 0 { 26 | t.Fatalf("FromCgroupV1() got = %v, want %v", limit, 0) 27 | } 28 | } 29 | 30 | func TestFromCgroupHybrid(t *testing.T) { 31 | limit, err := FromCgroupHybrid() 32 | if err != ErrCgroupsNotSupported { 33 | t.Fatalf("FromCgroupHybrid() error = %v, wantErr %v", err, ErrCgroupsNotSupported) 34 | } 35 | if limit != 0 { 36 | t.Fatalf("FromCgroupHybrid() got = %v, want %v", limit, 0) 37 | } 38 | } 39 | 40 | func TestFromCgroupV2(t *testing.T) { 41 | limit, err := FromCgroupV2() 42 | if err != ErrCgroupsNotSupported { 43 | t.Fatalf("FromCgroupV2() error = %v, wantErr %v", err, ErrCgroupsNotSupported) 44 | } 45 | if limit != 0 { 46 | t.Fatalf("FromCgroupV2() got = %v, want %v", limit, 0) 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /memlimit/exp_system.go: 
-------------------------------------------------------------------------------- 1 | package memlimit 2 | 3 | import ( 4 | "github.com/pbnjay/memory" 5 | ) 6 | 7 | // FromSystem returns the total memory of the system. 8 | func FromSystem() (uint64, error) { 9 | limit := memory.TotalMemory() 10 | if limit == 0 { 11 | return 0, ErrNoLimit 12 | } 13 | return limit, nil 14 | } 15 | -------------------------------------------------------------------------------- /memlimit/experiment.go: -------------------------------------------------------------------------------- 1 | package memlimit 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "reflect" 7 | "strings" 8 | ) 9 | 10 | const ( 11 | envAUTOMEMLIMIT_EXPERIMENT = "AUTOMEMLIMIT_EXPERIMENT" 12 | ) 13 | 14 | // Experiments is a set of experiment flags. 15 | // It is used to enable experimental features. 16 | // 17 | // You can set the flags by setting the environment variable AUTOMEMLIMIT_EXPERIMENT. 18 | // The value of the environment variable is a comma-separated list of experiment names. 19 | // 20 | // The following experiment names are known: 21 | // 22 | // - none: disable all experiments 23 | // - system: enable fallback to system memory limit 24 | type Experiments struct { 25 | // System enables fallback to system memory limit. 26 | System bool 27 | } 28 | 29 | func parseExperiments() (Experiments, error) { 30 | var exp Experiments 31 | 32 | // Create a map of known experiment names. 33 | names := make(map[string]func(bool)) 34 | rv := reflect.ValueOf(&exp).Elem() 35 | rt := rv.Type() 36 | for i := 0; i < rt.NumField(); i++ { 37 | field := rv.Field(i) 38 | names[strings.ToLower(rt.Field(i).Name)] = field.SetBool 39 | } 40 | 41 | // Parse names. 
42 | for _, f := range strings.Split(os.Getenv(envAUTOMEMLIMIT_EXPERIMENT), ",") { 43 | if f == "" { 44 | continue 45 | } 46 | if f == "none" { 47 | exp = Experiments{} 48 | continue 49 | } 50 | val := true 51 | set, ok := names[f] 52 | if !ok { 53 | return Experiments{}, fmt.Errorf("unknown AUTOMEMLIMIT_EXPERIMENT %s", f) 54 | } 55 | set(val) 56 | } 57 | 58 | return exp, nil 59 | } 60 | -------------------------------------------------------------------------------- /memlimit/experiment_test.go: -------------------------------------------------------------------------------- 1 | package memlimit 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "reflect" 7 | "testing" 8 | ) 9 | 10 | func TestParseExperiments(t *testing.T) { 11 | tests := []struct { 12 | name string 13 | env string 14 | want Experiments 15 | wantErr error 16 | }{ 17 | { 18 | name: "empty", 19 | env: "", 20 | want: Experiments{}, 21 | }, 22 | { 23 | name: "unknown", 24 | env: "unknown", 25 | want: Experiments{}, 26 | wantErr: fmt.Errorf("unknown AUTOMEMLIMIT_EXPERIMENT unknown"), 27 | }, 28 | { 29 | name: "none", 30 | env: "none", 31 | want: Experiments{}, 32 | }, 33 | { 34 | name: "none - with other", 35 | env: "system,none", 36 | want: Experiments{}, 37 | }, 38 | { 39 | name: "system", 40 | env: "system", 41 | want: Experiments{ 42 | System: true, 43 | }, 44 | }, 45 | } 46 | for _, tt := range tests { 47 | t.Run(tt.name, func(t *testing.T) { 48 | exp, ok := os.LookupEnv(envAUTOMEMLIMIT_EXPERIMENT) 49 | t.Cleanup(func() { 50 | if ok { 51 | os.Setenv(envAUTOMEMLIMIT_EXPERIMENT, exp) 52 | } else { 53 | os.Unsetenv(envAUTOMEMLIMIT_EXPERIMENT) 54 | } 55 | }) 56 | 57 | os.Setenv("AUTOMEMLIMIT_EXPERIMENT", tt.env) 58 | exps, err := parseExperiments() 59 | if !reflect.DeepEqual(exps, tt.want) { 60 | t.Errorf("experiments= %#v, want %#v", exps, tt.want) 61 | } 62 | if !reflect.DeepEqual(err, tt.wantErr) { 63 | t.Errorf("err = %#v, want %#v", err, tt.wantErr) 64 | } 65 | }) 66 | } 67 | } 68 | 
-------------------------------------------------------------------------------- /memlimit/logger.go: -------------------------------------------------------------------------------- 1 | package memlimit 2 | 3 | import ( 4 | "context" 5 | "log/slog" 6 | ) 7 | 8 | type noopLogger struct{} 9 | 10 | func (noopLogger) Enabled(context.Context, slog.Level) bool { return false } 11 | func (noopLogger) Handle(context.Context, slog.Record) error { return nil } 12 | func (d noopLogger) WithAttrs([]slog.Attr) slog.Handler { return d } 13 | func (d noopLogger) WithGroup(string) slog.Handler { return d } 14 | -------------------------------------------------------------------------------- /memlimit/memlimit.go: -------------------------------------------------------------------------------- 1 | package memlimit 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "log/slog" 7 | "math" 8 | "os" 9 | "runtime/debug" 10 | "strconv" 11 | "time" 12 | ) 13 | 14 | const ( 15 | envGOMEMLIMIT = "GOMEMLIMIT" 16 | envAUTOMEMLIMIT = "AUTOMEMLIMIT" 17 | // Deprecated: use memlimit.WithLogger instead 18 | envAUTOMEMLIMIT_DEBUG = "AUTOMEMLIMIT_DEBUG" 19 | 20 | defaultAUTOMEMLIMIT = 0.9 21 | ) 22 | 23 | // ErrNoLimit is returned when the memory limit is not set. 24 | var ErrNoLimit = errors.New("memory is not limited") 25 | 26 | type config struct { 27 | logger *slog.Logger 28 | ratio float64 29 | provider Provider 30 | refresh time.Duration 31 | } 32 | 33 | // Option is a function that configures the behavior of SetGoMemLimitWithOptions. 34 | type Option func(cfg *config) 35 | 36 | // WithRatio configures the ratio of the memory limit to set as GOMEMLIMIT. 37 | // 38 | // Default: 0.9 39 | func WithRatio(ratio float64) Option { 40 | return func(cfg *config) { 41 | cfg.ratio = ratio 42 | } 43 | } 44 | 45 | // WithProvider configures the provider. 
46 | // 47 | // Default: FromCgroup 48 | func WithProvider(provider Provider) Option { 49 | return func(cfg *config) { 50 | cfg.provider = provider 51 | } 52 | } 53 | 54 | // WithLogger configures the logger. 55 | // It automatically attaches the "package" attribute to the logs. 56 | // 57 | // Default: slog.New(noopLogger{}) 58 | func WithLogger(logger *slog.Logger) Option { 59 | return func(cfg *config) { 60 | cfg.logger = memlimitLogger(logger) 61 | } 62 | } 63 | 64 | // WithRefreshInterval configures the refresh interval for automemlimit. 65 | // If a refresh interval is greater than 0, automemlimit periodically fetches 66 | // the memory limit from the provider and reapplies it if it has changed. 67 | // If the provider returns an error, it logs the error and continues. 68 | // ErrNoLimit is treated as math.MaxInt64. 69 | // 70 | // Default: 0 (no refresh) 71 | func WithRefreshInterval(refresh time.Duration) Option { 72 | return func(cfg *config) { 73 | cfg.refresh = refresh 74 | } 75 | } 76 | 77 | // WithEnv configures whether to use environment variables. 78 | // 79 | // Default: false 80 | // 81 | // Deprecated: currently this does nothing. 82 | func WithEnv() Option { 83 | return func(cfg *config) {} 84 | } 85 | 86 | func memlimitLogger(logger *slog.Logger) *slog.Logger { 87 | if logger == nil { 88 | return slog.New(noopLogger{}) 89 | } 90 | return logger.With(slog.String("package", "github.com/KimMachineGun/automemlimit/memlimit")) 91 | } 92 | 93 | // SetGoMemLimitWithOpts sets GOMEMLIMIT with options and environment variables. 94 | // 95 | // You can configure how much memory of the cgroup's memory limit to set as GOMEMLIMIT 96 | // through AUTOMEMLIMIT environment variable in the half-open range (0.0,1.0]. 97 | // 98 | // If AUTOMEMLIMIT is not set, it defaults to 0.9. (10% is the headroom for memory sources the Go runtime is unaware of.) 99 | // If GOMEMLIMIT is already set or AUTOMEMLIMIT=off, this function does nothing. 
100 | // 101 | // If AUTOMEMLIMIT_EXPERIMENT is set, it enables experimental features. 102 | // Please see the documentation of Experiments for more details. 103 | // 104 | // Options: 105 | // - WithRatio 106 | // - WithProvider 107 | // - WithLogger 108 | func SetGoMemLimitWithOpts(opts ...Option) (_ int64, _err error) { 109 | // init config 110 | cfg := &config{ 111 | logger: slog.New(noopLogger{}), 112 | ratio: defaultAUTOMEMLIMIT, 113 | provider: FromCgroup, 114 | } 115 | // TODO: remove this 116 | if debug, ok := os.LookupEnv(envAUTOMEMLIMIT_DEBUG); ok { 117 | defaultLogger := memlimitLogger(slog.Default()) 118 | defaultLogger.Warn("AUTOMEMLIMIT_DEBUG is deprecated, use memlimit.WithLogger instead") 119 | if debug == "true" { 120 | cfg.logger = defaultLogger 121 | } 122 | } 123 | for _, opt := range opts { 124 | opt(cfg) 125 | } 126 | 127 | // log error if any on return 128 | defer func() { 129 | if _err != nil { 130 | cfg.logger.Error("failed to set GOMEMLIMIT", slog.Any("error", _err)) 131 | } 132 | }() 133 | 134 | // parse experiments 135 | exps, err := parseExperiments() 136 | if err != nil { 137 | return 0, fmt.Errorf("failed to parse experiments: %w", err) 138 | } 139 | if exps.System { 140 | cfg.logger.Info("system experiment is enabled: using system memory limit as a fallback") 141 | cfg.provider = ApplyFallback(cfg.provider, FromSystem) 142 | } 143 | 144 | // rollback to previous memory limit on panic 145 | snapshot := debug.SetMemoryLimit(-1) 146 | defer rollbackOnPanic(cfg.logger, snapshot, &_err) 147 | 148 | // check if GOMEMLIMIT is already set 149 | if val, ok := os.LookupEnv(envGOMEMLIMIT); ok { 150 | cfg.logger.Info("GOMEMLIMIT is already set, skipping", slog.String(envGOMEMLIMIT, val)) 151 | return 0, nil 152 | } 153 | 154 | // parse AUTOMEMLIMIT 155 | ratio := cfg.ratio 156 | if val, ok := os.LookupEnv(envAUTOMEMLIMIT); ok { 157 | if val == "off" { 158 | cfg.logger.Info("AUTOMEMLIMIT is set to off, skipping") 159 | return 0, nil 160 | } 161 
| ratio, err = strconv.ParseFloat(val, 64) 162 | if err != nil { 163 | return 0, fmt.Errorf("cannot parse AUTOMEMLIMIT: %s", val) 164 | } 165 | } 166 | 167 | // apply ratio to the provider 168 | provider := capProvider(ApplyRatio(cfg.provider, ratio)) 169 | 170 | // set the memory limit and start refresh 171 | limit, err := updateGoMemLimit(uint64(snapshot), provider, cfg.logger) 172 | go refresh(provider, cfg.logger, cfg.refresh) 173 | if err != nil { 174 | if errors.Is(err, ErrNoLimit) { 175 | cfg.logger.Info("memory is not limited, skipping") 176 | // TODO: consider returning the snapshot 177 | return 0, nil 178 | } 179 | return 0, fmt.Errorf("failed to set GOMEMLIMIT: %w", err) 180 | } 181 | 182 | return int64(limit), nil 183 | } 184 | 185 | // updateGoMemLimit updates the Go's memory limit, if it has changed. 186 | func updateGoMemLimit(currLimit uint64, provider Provider, logger *slog.Logger) (uint64, error) { 187 | newLimit, err := provider() 188 | if err != nil { 189 | return 0, err 190 | } 191 | 192 | if newLimit == currLimit { 193 | logger.Debug("GOMEMLIMIT is not changed, skipping", slog.Uint64(envGOMEMLIMIT, newLimit)) 194 | return newLimit, nil 195 | } 196 | 197 | debug.SetMemoryLimit(int64(newLimit)) 198 | logger.Info("GOMEMLIMIT is updated", slog.Uint64(envGOMEMLIMIT, newLimit), slog.Uint64("previous", currLimit)) 199 | 200 | return newLimit, nil 201 | } 202 | 203 | // refresh periodically fetches the memory limit from the provider and reapplies it if it has changed. 204 | // See more details in the documentation of WithRefreshInterval. 
205 | func refresh(provider Provider, logger *slog.Logger, refresh time.Duration) { 206 | if refresh == 0 { 207 | return 208 | } 209 | 210 | provider = noErrNoLimitProvider(provider) 211 | 212 | t := time.NewTicker(refresh) 213 | for range t.C { 214 | err := func() (_err error) { 215 | snapshot := debug.SetMemoryLimit(-1) 216 | defer rollbackOnPanic(logger, snapshot, &_err) 217 | 218 | _, err := updateGoMemLimit(uint64(snapshot), provider, logger) 219 | if err != nil { 220 | return err 221 | } 222 | 223 | return nil 224 | }() 225 | if err != nil { 226 | logger.Error("failed to refresh GOMEMLIMIT", slog.Any("error", err)) 227 | } 228 | } 229 | } 230 | 231 | // rollbackOnPanic rollbacks to the snapshot on panic. 232 | // Since it uses recover, it should be called in a deferred function. 233 | func rollbackOnPanic(logger *slog.Logger, snapshot int64, err *error) { 234 | panicErr := recover() 235 | if panicErr != nil { 236 | if *err != nil { 237 | logger.Error("failed to set GOMEMLIMIT", slog.Any("error", *err)) 238 | } 239 | *err = fmt.Errorf("panic during setting the Go's memory limit, rolling back to previous limit %d: %v", 240 | snapshot, panicErr, 241 | ) 242 | debug.SetMemoryLimit(snapshot) 243 | } 244 | } 245 | 246 | // SetGoMemLimitWithEnv sets GOMEMLIMIT with the value from the environment variables. 247 | // Since WithEnv is deprecated, this function is equivalent to SetGoMemLimitWithOpts(). 248 | // Deprecated: use SetGoMemLimitWithOpts instead. 249 | func SetGoMemLimitWithEnv() { 250 | _, _ = SetGoMemLimitWithOpts() 251 | } 252 | 253 | // SetGoMemLimit sets GOMEMLIMIT with the value from the cgroup's memory limit and given ratio. 254 | func SetGoMemLimit(ratio float64) (int64, error) { 255 | return SetGoMemLimitWithOpts(WithRatio(ratio)) 256 | } 257 | 258 | // SetGoMemLimitWithProvider sets GOMEMLIMIT with the value from the given provider and ratio. 
259 | func SetGoMemLimitWithProvider(provider Provider, ratio float64) (int64, error) { 260 | return SetGoMemLimitWithOpts(WithProvider(provider), WithRatio(ratio)) 261 | } 262 | 263 | func noErrNoLimitProvider(provider Provider) Provider { 264 | return func() (uint64, error) { 265 | limit, err := provider() 266 | if errors.Is(err, ErrNoLimit) { 267 | return math.MaxInt64, nil 268 | } 269 | return limit, err 270 | } 271 | } 272 | 273 | func capProvider(provider Provider) Provider { 274 | return func() (uint64, error) { 275 | limit, err := provider() 276 | if err != nil { 277 | return 0, err 278 | } else if limit > math.MaxInt64 { 279 | return math.MaxInt64, nil 280 | } 281 | return limit, nil 282 | } 283 | } 284 | -------------------------------------------------------------------------------- /memlimit/memlimit_linux_test.go: -------------------------------------------------------------------------------- 1 | //go:build linux 2 | // +build linux 3 | 4 | package memlimit 5 | 6 | import ( 7 | "flag" 8 | "math" 9 | "os" 10 | "runtime/debug" 11 | "testing" 12 | ) 13 | 14 | var ( 15 | cgVersion uint64 16 | expected uint64 17 | expectedSystem uint64 18 | ) 19 | 20 | func TestMain(m *testing.M) { 21 | flag.Uint64Var(&expected, "expected", 0, "Expected cgroup's memory limit") 22 | flag.Uint64Var(&expectedSystem, "expected-system", 0, "Expected system memory limit") 23 | flag.Uint64Var(&cgVersion, "cgroup-version", 0, "Cgroup version") 24 | flag.Parse() 25 | 26 | os.Exit(m.Run()) 27 | } 28 | 29 | func TestSetGoMemLimit(t *testing.T) { 30 | type args struct { 31 | ratio float64 32 | } 33 | tests := []struct { 34 | name string 35 | args args 36 | want int64 37 | wantErr error 38 | gomemlimit int64 39 | skip bool 40 | }{ 41 | { 42 | name: "0.5", 43 | args: args{ 44 | ratio: 0.5, 45 | }, 46 | want: int64(float64(expected) * 0.5), 47 | wantErr: nil, 48 | gomemlimit: int64(float64(expected) * 0.5), 49 | skip: expected == 0 || cgVersion == 0, 50 | }, 51 | { 52 | name: "0.9", 53 | 
args: args{ 54 | ratio: 0.9, 55 | }, 56 | want: int64(float64(expected) * 0.9), 57 | wantErr: nil, 58 | gomemlimit: int64(float64(expected) * 0.9), 59 | skip: expected == 0 || cgVersion == 0, 60 | }, 61 | { 62 | name: "Unavailable", 63 | args: args{ 64 | ratio: 0.9, 65 | }, 66 | want: 0, 67 | wantErr: ErrCgroupsNotSupported, 68 | gomemlimit: math.MaxInt64, 69 | skip: cgVersion != 0, 70 | }, 71 | } 72 | for _, tt := range tests { 73 | t.Run(tt.name, func(t *testing.T) { 74 | if tt.skip { 75 | t.Skip() 76 | } 77 | t.Cleanup(func() { 78 | debug.SetMemoryLimit(math.MaxInt64) 79 | }) 80 | got, err := SetGoMemLimit(tt.args.ratio) 81 | if err != tt.wantErr { 82 | t.Errorf("SetGoMemLimit() error = %v, wantErr %v", err, tt.wantErr) 83 | return 84 | } 85 | if got != tt.want { 86 | t.Errorf("SetGoMemLimit() got = %v, want %v", got, tt.want) 87 | } 88 | if debug.SetMemoryLimit(-1) != tt.gomemlimit { 89 | t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", debug.SetMemoryLimit(-1), tt.gomemlimit) 90 | } 91 | }) 92 | } 93 | } 94 | 95 | func TestSetGoMemLimitWithProvider_WithCgroupProvider(t *testing.T) { 96 | type args struct { 97 | provider Provider 98 | ratio float64 99 | } 100 | tests := []struct { 101 | name string 102 | args args 103 | want int64 104 | wantErr error 105 | gomemlimit int64 106 | skip bool 107 | }{ 108 | { 109 | name: "FromCgroup", 110 | args: args{ 111 | provider: FromCgroup, 112 | ratio: 0.9, 113 | }, 114 | want: int64(float64(expected) * 0.9), 115 | wantErr: nil, 116 | gomemlimit: int64(float64(expected) * 0.9), 117 | skip: expected == 0 || cgVersion == 0, 118 | }, 119 | { 120 | name: "FromCgroup_Unavailable", 121 | args: args{ 122 | provider: FromCgroup, 123 | ratio: 0.9, 124 | }, 125 | want: 0, 126 | wantErr: ErrNoCgroup, 127 | gomemlimit: math.MaxInt64, 128 | skip: expected == 0 || cgVersion != 0, 129 | }, 130 | { 131 | name: "FromCgroupV1", 132 | args: args{ 133 | provider: FromCgroupV1, 134 | ratio: 0.9, 135 | }, 136 | want: int64(float64(expected) 
* 0.9), 137 | wantErr: nil, 138 | gomemlimit: int64(float64(expected) * 0.9), 139 | skip: expected == 0 || cgVersion != 1, 140 | }, 141 | { 142 | name: "FromCgroupHybrid", 143 | args: args{ 144 | provider: FromCgroupHybrid, 145 | ratio: 0.9, 146 | }, 147 | want: int64(float64(expected) * 0.9), 148 | wantErr: nil, 149 | gomemlimit: int64(float64(expected) * 0.9), 150 | skip: expected == 0 || cgVersion != 1, 151 | }, 152 | { 153 | name: "FromCgroupV2", 154 | args: args{ 155 | provider: FromCgroupV2, 156 | ratio: 0.9, 157 | }, 158 | want: int64(float64(expected) * 0.9), 159 | wantErr: nil, 160 | gomemlimit: int64(float64(expected) * 0.9), 161 | skip: expected == 0 || cgVersion != 2, 162 | }, 163 | } 164 | for _, tt := range tests { 165 | t.Run(tt.name, func(t *testing.T) { 166 | if tt.skip { 167 | t.Skip() 168 | } 169 | t.Cleanup(func() { 170 | debug.SetMemoryLimit(math.MaxInt64) 171 | }) 172 | got, err := SetGoMemLimitWithProvider(tt.args.provider, tt.args.ratio) 173 | if err != tt.wantErr { 174 | t.Errorf("SetGoMemLimitWithProvider() error = %v, wantErr %v", err, tt.wantErr) 175 | return 176 | } 177 | if got != tt.want { 178 | t.Errorf("SetGoMemLimitWithProvider() got = %v, want %v", got, tt.want) 179 | } 180 | if debug.SetMemoryLimit(-1) != tt.gomemlimit { 181 | t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", debug.SetMemoryLimit(-1), tt.gomemlimit) 182 | } 183 | }) 184 | } 185 | } 186 | 187 | func TestSetGoMemLimitWithProvider_WithSystemProvider(t *testing.T) { 188 | type args struct { 189 | provider Provider 190 | ratio float64 191 | } 192 | tests := []struct { 193 | name string 194 | args args 195 | want int64 196 | wantErr error 197 | gomemlimit int64 198 | skip bool 199 | }{ 200 | { 201 | name: "FromSystem", 202 | args: args{ 203 | provider: FromSystem, 204 | ratio: 0.9, 205 | }, 206 | want: int64(float64(expectedSystem) * 0.9), 207 | wantErr: nil, 208 | gomemlimit: int64(float64(expectedSystem) * 0.9), 209 | skip: expectedSystem == 0, 210 | }, 211 | } 
212 | for _, tt := range tests { 213 | t.Run(tt.name, func(t *testing.T) { 214 | if tt.skip { 215 | t.Skip() 216 | } 217 | t.Cleanup(func() { 218 | debug.SetMemoryLimit(math.MaxInt64) 219 | }) 220 | got, err := SetGoMemLimitWithProvider(tt.args.provider, tt.args.ratio) 221 | if err != tt.wantErr { 222 | t.Errorf("SetGoMemLimitWithProvider() error = %v, wantErr %v", err, tt.wantErr) 223 | return 224 | } 225 | if got != tt.want { 226 | t.Errorf("SetGoMemLimitWithProvider() got = %v, want %v", got, tt.want) 227 | } 228 | if debug.SetMemoryLimit(-1) != tt.gomemlimit { 229 | t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", debug.SetMemoryLimit(-1), tt.gomemlimit) 230 | } 231 | }) 232 | } 233 | } 234 | -------------------------------------------------------------------------------- /memlimit/memlimit_test.go: -------------------------------------------------------------------------------- 1 | package memlimit 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | "runtime/debug" 7 | "sync/atomic" 8 | "testing" 9 | "time" 10 | ) 11 | 12 | func TestLimit(t *testing.T) { 13 | type args struct { 14 | limit uint64 15 | } 16 | tests := []struct { 17 | name string 18 | args args 19 | want uint64 20 | wantErr error 21 | }{ 22 | { 23 | name: "0bytes", 24 | args: args{ 25 | limit: 0, 26 | }, 27 | want: 0, 28 | wantErr: nil, 29 | }, 30 | { 31 | name: "1kib", 32 | args: args{ 33 | limit: 1024, 34 | }, 35 | want: 1024, 36 | wantErr: nil, 37 | }, 38 | { 39 | name: "1mib", 40 | args: args{ 41 | limit: 1024 * 1024, 42 | }, 43 | want: 1024 * 1024, 44 | wantErr: nil, 45 | }, 46 | { 47 | name: "1gib", 48 | args: args{ 49 | limit: 1024 * 1024 * 1024, 50 | }, 51 | want: 1024 * 1024 * 1024, 52 | wantErr: nil, 53 | }, 54 | } 55 | for _, tt := range tests { 56 | t.Run(tt.name, func(t *testing.T) { 57 | got, err := Limit(tt.args.limit)() 58 | if err != tt.wantErr { 59 | t.Errorf("Limit() error = %v, wantErr %v", err, tt.wantErr) 60 | return 61 | } 62 | if got != tt.want { 63 | t.Errorf("Limit() got = 
%v, want %v", got, tt.want) 64 | } 65 | }) 66 | } 67 | } 68 | 69 | func TestSetGoMemLimitWithProvider(t *testing.T) { 70 | type args struct { 71 | provider Provider 72 | ratio float64 73 | } 74 | tests := []struct { 75 | name string 76 | args args 77 | want int64 78 | wantErr error 79 | gomemlimit int64 80 | }{ 81 | { 82 | name: "Limit_0.5", 83 | args: args{ 84 | provider: Limit(1024 * 1024 * 1024), 85 | ratio: 0.5, 86 | }, 87 | want: 536870912, 88 | wantErr: nil, 89 | gomemlimit: 536870912, 90 | }, 91 | { 92 | name: "Limit_0.9", 93 | args: args{ 94 | provider: Limit(1024 * 1024 * 1024), 95 | ratio: 0.9, 96 | }, 97 | want: 966367641, 98 | wantErr: nil, 99 | gomemlimit: 966367641, 100 | }, 101 | { 102 | name: "Limit_0.9_math.MaxUint64", 103 | args: args{ 104 | provider: Limit(math.MaxUint64), 105 | ratio: 0.9, 106 | }, 107 | want: math.MaxInt64, 108 | wantErr: nil, 109 | gomemlimit: math.MaxInt64, 110 | }, 111 | { 112 | name: "Limit_0.9_math.MaxUint64", 113 | args: args{ 114 | provider: Limit(math.MaxUint64), 115 | ratio: 0.9, 116 | }, 117 | want: math.MaxInt64, 118 | wantErr: nil, 119 | gomemlimit: math.MaxInt64, 120 | }, 121 | { 122 | name: "Limit_0.45_math.MaxUint64", 123 | args: args{ 124 | provider: Limit(math.MaxUint64), 125 | ratio: 0.45, 126 | }, 127 | want: 8301034833169298432, 128 | wantErr: nil, 129 | gomemlimit: 8301034833169298432, 130 | }, 131 | } 132 | for _, tt := range tests { 133 | t.Run(tt.name, func(t *testing.T) { 134 | t.Cleanup(func() { 135 | debug.SetMemoryLimit(math.MaxInt64) 136 | }) 137 | got, err := SetGoMemLimitWithProvider(tt.args.provider, tt.args.ratio) 138 | if err != tt.wantErr { 139 | t.Errorf("SetGoMemLimitWithProvider() error = %v, wantErr %v", err, tt.wantErr) 140 | return 141 | } 142 | if got != tt.want { 143 | t.Errorf("SetGoMemLimitWithProvider() got = %v, want %v", got, tt.want) 144 | } 145 | if debug.SetMemoryLimit(-1) != tt.gomemlimit { 146 | t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", debug.SetMemoryLimit(-1), 
tt.gomemlimit) 147 | } 148 | }) 149 | } 150 | } 151 | 152 | func TestSetGoMemLimitWithOpts(t *testing.T) { 153 | tests := []struct { 154 | name string 155 | opts []Option 156 | want int64 157 | wantErr error 158 | gomemlimit int64 159 | }{ 160 | { 161 | name: "unknown error", 162 | opts: []Option{ 163 | WithProvider(func() (uint64, error) { 164 | return 0, fmt.Errorf("unknown error") 165 | }), 166 | }, 167 | want: 0, 168 | wantErr: fmt.Errorf("failed to set GOMEMLIMIT: unknown error"), 169 | gomemlimit: math.MaxInt64, 170 | }, 171 | { 172 | name: "ErrNoLimit", 173 | opts: []Option{ 174 | WithProvider(func() (uint64, error) { 175 | return 0, ErrNoLimit 176 | }), 177 | }, 178 | want: 0, 179 | wantErr: nil, 180 | gomemlimit: math.MaxInt64, 181 | }, 182 | { 183 | name: "wrapped ErrNoLimit", 184 | opts: []Option{ 185 | WithProvider(func() (uint64, error) { 186 | return 0, fmt.Errorf("wrapped: %w", ErrNoLimit) 187 | }), 188 | }, 189 | want: 0, 190 | wantErr: nil, 191 | gomemlimit: math.MaxInt64, 192 | }, 193 | } 194 | for _, tt := range tests { 195 | t.Run(tt.name, func(t *testing.T) { 196 | got, err := SetGoMemLimitWithOpts(tt.opts...) 
197 | if tt.wantErr != nil && err.Error() != tt.wantErr.Error() { 198 | t.Errorf("SetGoMemLimitWithOpts() error = %v, wantErr %v", err, tt.wantErr) 199 | return 200 | } 201 | if got != tt.want { 202 | t.Errorf("SetGoMemLimitWithOpts() got = %v, want %v", got, tt.want) 203 | } 204 | if debug.SetMemoryLimit(-1) != tt.gomemlimit { 205 | t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", debug.SetMemoryLimit(-1), tt.gomemlimit) 206 | } 207 | }) 208 | } 209 | } 210 | 211 | func TestSetGoMemLimitWithOpts_rollbackOnPanic(t *testing.T) { 212 | t.Cleanup(func() { 213 | debug.SetMemoryLimit(math.MaxInt64) 214 | }) 215 | 216 | limit := int64(987654321) 217 | _ = debug.SetMemoryLimit(987654321) 218 | _, err := SetGoMemLimitWithOpts( 219 | WithProvider(func() (uint64, error) { 220 | debug.SetMemoryLimit(123456789) 221 | panic("panic") 222 | }), 223 | WithRatio(1), 224 | ) 225 | if err == nil { 226 | t.Error("SetGoMemLimitWithOpts() error = nil, want panic") 227 | } 228 | 229 | curr := debug.SetMemoryLimit(-1) 230 | if curr != limit { 231 | t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", curr, limit) 232 | } 233 | } 234 | 235 | func TestSetGoMemLimitWithOpts_WithRefreshInterval(t *testing.T) { 236 | t.Cleanup(func() { 237 | debug.SetMemoryLimit(math.MaxInt64) 238 | }) 239 | 240 | var limit atomic.Int64 241 | output, err := SetGoMemLimitWithOpts( 242 | WithProvider(func() (uint64, error) { 243 | l := limit.Load() 244 | if l == 0 { 245 | return 0, ErrNoLimit 246 | } 247 | return uint64(l), nil 248 | }), 249 | WithRatio(1), 250 | WithRefreshInterval(10*time.Millisecond), 251 | ) 252 | if err != nil { 253 | t.Errorf("SetGoMemLimitWithOpts() error = %v", err) 254 | } else if output != limit.Load() { 255 | t.Errorf("SetGoMemLimitWithOpts() got = %v, want %v", output, limit.Load()) 256 | } 257 | 258 | // 1. 
no limit 259 | curr := debug.SetMemoryLimit(-1) 260 | if curr != math.MaxInt64 { 261 | t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", curr, limit.Load()) 262 | } 263 | 264 | // 2. max limit 265 | limit.Add(math.MaxInt64) 266 | time.Sleep(100 * time.Millisecond) 267 | 268 | curr = debug.SetMemoryLimit(-1) 269 | if curr != math.MaxInt64 { 270 | t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", curr, int64(math.MaxInt64)) 271 | } 272 | 273 | // 3. adjust limit 274 | limit.Add(-1024) 275 | time.Sleep(100 * time.Millisecond) 276 | 277 | curr = debug.SetMemoryLimit(-1) 278 | if curr != math.MaxInt64-1024 { 279 | t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", curr, int64(math.MaxInt64)-1024) 280 | } 281 | 282 | // 4. no limit again 283 | limit.Store(0) 284 | time.Sleep(100 * time.Millisecond) 285 | 286 | curr = debug.SetMemoryLimit(-1) 287 | if curr != math.MaxInt64 { 288 | t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", curr, int64(math.MaxInt64)) 289 | } 290 | 291 | // 5. 
new limit 292 | limit.Store(math.MaxInt32) 293 | time.Sleep(100 * time.Millisecond) 294 | 295 | curr = debug.SetMemoryLimit(-1) 296 | if curr != math.MaxInt32 { 297 | t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", curr, math.MaxInt32) 298 | } 299 | } 300 | -------------------------------------------------------------------------------- /memlimit/memlimit_unsupported_test.go: -------------------------------------------------------------------------------- 1 | //go:build !linux 2 | // +build !linux 3 | 4 | package memlimit 5 | 6 | import ( 7 | "errors" 8 | "flag" 9 | "math" 10 | "os" 11 | "runtime/debug" 12 | "testing" 13 | ) 14 | 15 | var expected uint64 16 | 17 | func TestMain(m *testing.M) { 18 | flag.Uint64Var(&expected, "expected", 0, "Expected memory limit") 19 | flag.Parse() 20 | 21 | os.Exit(m.Run()) 22 | } 23 | 24 | func TestSetGoMemLimit(t *testing.T) { 25 | type args struct { 26 | ratio float64 27 | } 28 | tests := []struct { 29 | name string 30 | args args 31 | want int64 32 | wantErr error 33 | gomemlimit int64 34 | }{ 35 | { 36 | name: "0.5", 37 | args: args{ 38 | ratio: 0.5, 39 | }, 40 | want: 0, 41 | wantErr: ErrCgroupsNotSupported, 42 | gomemlimit: math.MaxInt64, 43 | }, 44 | { 45 | name: "0.9", 46 | args: args{ 47 | ratio: 0.9, 48 | }, 49 | want: 0, 50 | wantErr: ErrCgroupsNotSupported, 51 | gomemlimit: math.MaxInt64, 52 | }, 53 | } 54 | for _, tt := range tests { 55 | t.Run(tt.name, func(t *testing.T) { 56 | t.Cleanup(func() { 57 | debug.SetMemoryLimit(math.MaxInt64) 58 | }) 59 | got, err := SetGoMemLimit(tt.args.ratio) 60 | if !errors.Is(err, tt.wantErr) { 61 | t.Errorf("SetGoMemLimit() error = %v, wantErr %v", err, tt.wantErr) 62 | return 63 | } 64 | if got != tt.want { 65 | t.Errorf("SetGoMemLimit() got = %v, want %v", got, tt.want) 66 | } 67 | if debug.SetMemoryLimit(-1) != tt.gomemlimit { 68 | t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", debug.SetMemoryLimit(-1), tt.gomemlimit) 69 | } 70 | }) 71 | } 72 | } 73 | 74 | func 
TestSetGoMemLimitWithProvider_WithCgroupProvider(t *testing.T) { 75 | type args struct { 76 | provider Provider 77 | ratio float64 78 | } 79 | tests := []struct { 80 | name string 81 | args args 82 | want int64 83 | wantErr error 84 | }{ 85 | { 86 | name: "FromCgroup", 87 | args: args{ 88 | provider: FromCgroup, 89 | ratio: 0.9, 90 | }, 91 | want: 0, 92 | wantErr: ErrCgroupsNotSupported, 93 | }, 94 | { 95 | name: "FromCgroupV1", 96 | args: args{ 97 | provider: FromCgroupV1, 98 | ratio: 0.9, 99 | }, 100 | want: 0, 101 | wantErr: ErrCgroupsNotSupported, 102 | }, 103 | { 104 | name: "FromCgroupHybrid", 105 | args: args{ 106 | provider: FromCgroupHybrid, 107 | ratio: 0.9, 108 | }, 109 | want: 0, 110 | wantErr: ErrCgroupsNotSupported, 111 | }, 112 | { 113 | name: "FromCgroupV2", 114 | args: args{ 115 | provider: FromCgroupV2, 116 | ratio: 0.9, 117 | }, 118 | want: 0, 119 | wantErr: ErrCgroupsNotSupported, 120 | }, 121 | } 122 | for _, tt := range tests { 123 | t.Run(tt.name, func(t *testing.T) { 124 | got, err := SetGoMemLimitWithProvider(tt.args.provider, tt.args.ratio) 125 | if !errors.Is(err, tt.wantErr) { 126 | t.Errorf("SetGoMemLimitWithProvider() error = %v, wantErr %v", err, tt.wantErr) 127 | return 128 | } 129 | if got != tt.want { 130 | t.Errorf("SetGoMemLimitWithProvider() got = %v, want %v", got, tt.want) 131 | } 132 | }) 133 | } 134 | } 135 | 136 | func TestSetGoMemLimitWithProvider_WithSystemProvider(t *testing.T) { 137 | type args struct { 138 | provider Provider 139 | ratio float64 140 | } 141 | tests := []struct { 142 | name string 143 | args args 144 | want int64 145 | wantErr error 146 | gomemlimit int64 147 | skip bool 148 | }{ 149 | { 150 | name: "FromSystem", 151 | args: args{ 152 | provider: FromSystem, 153 | ratio: 0.9, 154 | }, 155 | want: int64(float64(expected) * 0.9), 156 | wantErr: nil, 157 | gomemlimit: int64(float64(expected) * 0.9), 158 | skip: expected == 0, 159 | }, 160 | } 161 | for _, tt := range tests { 162 | t.Run(tt.name, func(t 
*testing.T) { 163 | if tt.skip { 164 | t.Skip() 165 | } 166 | t.Cleanup(func() { 167 | debug.SetMemoryLimit(math.MaxInt64) 168 | }) 169 | got, err := SetGoMemLimitWithProvider(tt.args.provider, tt.args.ratio) 170 | if !errors.Is(err, tt.wantErr) { 171 | t.Errorf("SetGoMemLimitWithProvider() error = %v, wantErr %v", err, tt.wantErr) 172 | return 173 | } 174 | if got != tt.want { 175 | t.Errorf("SetGoMemLimitWithProvider() got = %v, want %v", got, tt.want) 176 | } 177 | if debug.SetMemoryLimit(-1) != tt.gomemlimit { 178 | t.Errorf("debug.SetMemoryLimit(-1) got = %v, want %v", debug.SetMemoryLimit(-1), tt.gomemlimit) 179 | } 180 | }) 181 | } 182 | } 183 | -------------------------------------------------------------------------------- /memlimit/provider.go: -------------------------------------------------------------------------------- 1 | package memlimit 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | // Provider is a function that returns the memory limit. 8 | type Provider func() (uint64, error) 9 | 10 | // Limit is a helper Provider function that returns the given limit. 11 | func Limit(limit uint64) func() (uint64, error) { 12 | return func() (uint64, error) { 13 | return limit, nil 14 | } 15 | } 16 | 17 | // ApplyRationA is a helper Provider function that applies the given ratio to the given provider. 18 | func ApplyRatio(provider Provider, ratio float64) Provider { 19 | if ratio == 1 { 20 | return provider 21 | } 22 | return func() (uint64, error) { 23 | if ratio <= 0 || ratio > 1 { 24 | return 0, fmt.Errorf("invalid ratio: %f, ratio should be in the range (0.0,1.0]", ratio) 25 | } 26 | limit, err := provider() 27 | if err != nil { 28 | return 0, err 29 | } 30 | return uint64(float64(limit) * ratio), nil 31 | } 32 | } 33 | 34 | // ApplyFallback is a helper Provider function that sets the fallback provider. 
35 | func ApplyFallback(provider Provider, fallback Provider) Provider { 36 | return func() (uint64, error) { 37 | limit, err := provider() 38 | if err != nil { 39 | return fallback() 40 | } 41 | return limit, nil 42 | } 43 | } 44 | --------------------------------------------------------------------------------