├── .github └── workflows │ ├── no_external_dependencies.sh │ └── test.yml ├── LICENSE ├── README.md ├── archive.go ├── archive_test.go ├── astikit.go ├── astikit_test.go ├── binary.go ├── binary_test.go ├── bit_flags.go ├── bit_flags_test.go ├── bool.go ├── bool_test.go ├── bytes.go ├── bytes_test.go ├── cache.go ├── cache_test.go ├── defer.go ├── defer_test.go ├── errors.go ├── errors_test.go ├── event.go ├── event_test.go ├── exec.go ├── flag.go ├── flag_test.go ├── float.go ├── float_test.go ├── go.mod ├── http.go ├── http_test.go ├── io.go ├── io_test.go ├── ipc ├── posix │ ├── posix.c │ ├── posix.go │ ├── posix.h │ └── posix_test.go └── systemv │ ├── systemv.c │ ├── systemv.go │ ├── systemv.h │ └── systemv_test.go ├── json.go ├── json_test.go ├── limiter.go ├── limiter_test.go ├── logger.go ├── logger_test.go ├── map.go ├── map_test.go ├── os.go ├── os_js.go ├── os_others.go ├── os_test.go ├── pcm.go ├── pcm_test.go ├── ptr.go ├── rand.go ├── sort.go ├── sort_test.go ├── ssh.go ├── ssh_test.go ├── stat.go ├── stat_test.go ├── sync.go ├── sync_test.go ├── template.go ├── template_test.go ├── testdata ├── archive │ ├── d │ │ └── f │ └── f ├── ipc │ └── f ├── os │ ├── d │ │ ├── d1 │ │ │ └── f11 │ │ ├── d2 │ │ │ ├── d21 │ │ │ │ └── f211 │ │ │ └── f21 │ │ └── f1 │ └── f ├── ssh │ └── f ├── template │ ├── layouts │ │ ├── dir │ │ │ └── layout2.html │ │ ├── dummy.css │ │ └── layout1.html │ └── templates │ │ ├── dir │ │ └── template2.html │ │ ├── dummy.css │ │ └── template1.html └── translator │ ├── d1 │ ├── d2 │ │ └── en.json │ └── en.json │ ├── en.json │ ├── fr.json │ └── invalid.csv ├── time.go ├── time_test.go ├── translator.go ├── translator_test.go ├── worker.go └── worker_test.go /.github/workflows/no_external_dependencies.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | if [ "$(go list -m all)" != "github.com/asticode/go-astikit" ]; then 4 | echo "This repo doesn't allow any external dependencies" 5 | 
exit 1 6 | else 7 | echo "cheers!" 8 | fi -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | push: 5 | branches: [ "master", "dev" ] 6 | pull_request: 7 | branches: [ "master" ] 8 | 9 | jobs: 10 | 11 | build: 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | os: [ubuntu-latest, macos-latest, windows-latest] 16 | 17 | runs-on: ${{ matrix.os }} 18 | steps: 19 | - uses: actions/checkout@v3 20 | 21 | - name: Set up Go 22 | uses: actions/setup-go@v4 23 | with: 24 | go-version: '1.20' 25 | 26 | - name: Check external dependencies 27 | run: | 28 | bash .github/workflows/no_external_dependencies.sh 29 | 30 | - name: Install dependencies 31 | run: go mod download 32 | 33 | - name: Run tests 34 | run: go test -race -covermode atomic -coverprofile=covprofile ./... 35 | 36 | - if: github.event_name != 'pull_request' 37 | name: Send coverage 38 | env: 39 | COVERALLS_TOKEN: ${{ secrets.COVERALLS_TOKEN }} 40 | run: | 41 | go install github.com/mattn/goveralls@latest 42 | goveralls -coverprofile=covprofile -service=github 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Quentin Renard 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or 
substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![GoReportCard](http://goreportcard.com/badge/github.com/asticode/go-astikit)](http://goreportcard.com/report/github.com/asticode/go-astikit) 2 | [![GoDoc](https://godoc.org/github.com/asticode/go-astikit?status.svg)](https://godoc.org/github.com/asticode/go-astikit) 3 | [![Test](https://github.com/asticode/go-astikit/actions/workflows/test.yml/badge.svg)](https://github.com/asticode/go-astikit/actions/workflows/test.yml) 4 | [![Coveralls](https://coveralls.io/repos/github/asticode/go-astikit/badge.svg?branch=master)](https://coveralls.io/github/asticode/go-astikit) 5 | 6 | `astikit` is a set of golang helpers that don't require any external dependencies. 
-------------------------------------------------------------------------------- /archive.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "archive/zip" 5 | "context" 6 | "fmt" 7 | "io" 8 | "os" 9 | "path/filepath" 10 | "strings" 11 | ) 12 | 13 | // internal shouldn't lead with a "/" 14 | func zipInternalPath(p string) (external, internal string) { 15 | if items := strings.Split(p, ".zip"); len(items) > 1 { 16 | external = items[0] + ".zip" 17 | internal = strings.TrimPrefix(strings.Join(items[1:], ".zip"), string(os.PathSeparator)) 18 | return 19 | } 20 | external = p 21 | return 22 | } 23 | 24 | // Zip zips a src into a dst 25 | // Possible dst formats are: 26 | // - /path/to/zip.zip 27 | // - /path/to/zip.zip/root/path 28 | func Zip(ctx context.Context, dst, src string) (err error) { 29 | // Get external/internal path 30 | externalPath, internalPath := zipInternalPath(dst) 31 | 32 | // Make sure the directory exists 33 | if err = os.MkdirAll(filepath.Dir(externalPath), DefaultDirMode); err != nil { 34 | return fmt.Errorf("astikit: mkdirall %s failed: %w", filepath.Dir(externalPath), err) 35 | } 36 | 37 | // Create destination file 38 | var dstFile *os.File 39 | if dstFile, err = os.Create(externalPath); err != nil { 40 | return fmt.Errorf("astikit: creating %s failed: %w", externalPath, err) 41 | } 42 | defer dstFile.Close() 43 | 44 | // Create zip writer 45 | var zw = zip.NewWriter(dstFile) 46 | defer zw.Close() 47 | 48 | // Make sure to clean dir path so that we get consistent path separator with filepath.Walk 49 | src = filepath.Clean(src) 50 | 51 | // Walk 52 | if err = filepath.Walk(src, func(path string, info os.FileInfo, e error) (err error) { 53 | // Process error 54 | if e != nil { 55 | err = e 56 | return 57 | } 58 | 59 | // Init header 60 | var h *zip.FileHeader 61 | if h, err = zip.FileInfoHeader(info); err != nil { 62 | return fmt.Errorf("astikit: initializing zip header 
failed: %w", err) 63 | } 64 | 65 | // Set header info 66 | h.Name = filepath.Join(internalPath, strings.TrimPrefix(path, src)) 67 | if info.IsDir() { 68 | h.Name += string(os.PathSeparator) 69 | } else { 70 | h.Method = zip.Deflate 71 | } 72 | 73 | // Create writer 74 | var w io.Writer 75 | if w, err = zw.CreateHeader(h); err != nil { 76 | return fmt.Errorf("astikit: creating zip header failed: %w", err) 77 | } 78 | 79 | // If path is dir, stop here 80 | if info.IsDir() { 81 | return 82 | } 83 | 84 | // Open path 85 | var walkFile *os.File 86 | if walkFile, err = os.Open(path); err != nil { 87 | return fmt.Errorf("astikit: opening %s failed: %w", path, err) 88 | } 89 | defer walkFile.Close() 90 | 91 | // Copy 92 | if _, err = Copy(ctx, w, walkFile); err != nil { 93 | return fmt.Errorf("astikit: copying failed: %w", err) 94 | } 95 | return 96 | }); err != nil { 97 | return fmt.Errorf("astikit: walking failed: %w", err) 98 | } 99 | return 100 | } 101 | 102 | // Unzip unzips a src into a dst 103 | // Possible src formats are: 104 | // - /path/to/zip.zip 105 | // - /path/to/zip.zip/root/path 106 | func Unzip(ctx context.Context, dst, src string) (err error) { 107 | // Get external/internal path 108 | externalPath, internalPath := zipInternalPath(src) 109 | 110 | // Make sure the destination exists 111 | if err = os.MkdirAll(dst, DefaultDirMode); err != nil { 112 | return fmt.Errorf("astikit: mkdirall %s failed: %w", dst, err) 113 | } 114 | 115 | // Open overall reader 116 | var r *zip.ReadCloser 117 | if r, err = zip.OpenReader(externalPath); err != nil { 118 | return fmt.Errorf("astikit: opening overall zip reader on %s failed: %w", externalPath, err) 119 | } 120 | defer r.Close() 121 | 122 | // Loop through files to determine their type 123 | var dirs, files, symlinks = make(map[string]*zip.File), make(map[string]*zip.File), make(map[string]*zip.File) 124 | for _, f := range r.File { 125 | // Validate internal path 126 | if internalPath != "" && 
!strings.HasPrefix(f.Name, internalPath) { 127 | continue 128 | } 129 | var p = filepath.Join(dst, strings.TrimPrefix(f.Name, internalPath)) 130 | 131 | // Check file type 132 | if f.FileInfo().Mode()&os.ModeSymlink != 0 { 133 | symlinks[p] = f 134 | } else if f.FileInfo().IsDir() { 135 | dirs[p] = f 136 | } else { 137 | files[p] = f 138 | } 139 | } 140 | 141 | // Invalid internal path 142 | if internalPath != "" && len(dirs) == 0 && len(files) == 0 && len(symlinks) == 0 { 143 | return fmt.Errorf("astikit: content in archive does not match specified internal path %s", internalPath) 144 | } 145 | 146 | // Create dirs 147 | for p, f := range dirs { 148 | if err = os.MkdirAll(p, f.FileInfo().Mode().Perm()); err != nil { 149 | return fmt.Errorf("astikit: mkdirall %s failed: %w", p, err) 150 | } 151 | } 152 | 153 | // Create files 154 | for p, f := range files { 155 | if err = createZipFile(ctx, f, p); err != nil { 156 | return fmt.Errorf("astikit: creating zip file into %s failed: %w", p, err) 157 | } 158 | } 159 | 160 | // Create symlinks 161 | for p, f := range symlinks { 162 | if err = createZipSymlink(f, p); err != nil { 163 | return fmt.Errorf("astikit: creating zip symlink into %s failed: %w", p, err) 164 | } 165 | } 166 | return 167 | } 168 | 169 | func createZipFile(ctx context.Context, f *zip.File, p string) (err error) { 170 | // Open file reader 171 | var fr io.ReadCloser 172 | if fr, err = f.Open(); err != nil { 173 | return fmt.Errorf("astikit: opening zip reader on file %s failed: %w", f.Name, err) 174 | } 175 | defer fr.Close() 176 | 177 | // Since dirs don't always come up we make sure the directory of the file exists with default 178 | // file mode 179 | if err = os.MkdirAll(filepath.Dir(p), DefaultDirMode); err != nil { 180 | return fmt.Errorf("astikit: mkdirall %s failed: %w", filepath.Dir(p), err) 181 | } 182 | 183 | // Open the file 184 | var fl *os.File 185 | if fl, err = os.OpenFile(p, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 
f.FileInfo().Mode().Perm()); err != nil { 186 | return fmt.Errorf("astikit: opening file %s failed: %w", p, err) 187 | } 188 | defer fl.Close() 189 | 190 | // Copy 191 | if _, err = Copy(ctx, fl, fr); err != nil { 192 | return fmt.Errorf("astikit: copying %s into %s failed: %w", f.Name, p, err) 193 | } 194 | return 195 | } 196 | 197 | func createZipSymlink(f *zip.File, p string) (err error) { 198 | // Open file reader 199 | var fr io.ReadCloser 200 | if fr, err = f.Open(); err != nil { 201 | return fmt.Errorf("astikit: opening zip reader on file %s failed: %w", f.Name, err) 202 | } 203 | defer fr.Close() 204 | 205 | // If file is a symlink we retrieve the target path that is in the content of the file 206 | var b []byte 207 | if b, err = io.ReadAll(fr); err != nil { 208 | return fmt.Errorf("astikit: reading all %s failed: %w", f.Name, err) 209 | } 210 | 211 | // Create the symlink 212 | if err = os.Symlink(string(b), p); err != nil { 213 | return fmt.Errorf("astikit: creating symlink from %s to %s failed: %w", string(b), p, err) 214 | } 215 | return 216 | } 217 | -------------------------------------------------------------------------------- /archive_test.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "context" 5 | "path/filepath" 6 | "testing" 7 | ) 8 | 9 | func TestZip(t *testing.T) { 10 | // Get temp dir 11 | dir := t.TempDir() 12 | 13 | // With internal path 14 | i := "testdata/archive" 15 | f := filepath.Join(dir, "with-internal", "f.zip/root") 16 | err := Zip(context.Background(), f, i) 17 | if err != nil { 18 | t.Fatalf("expected no error, got %+v", err) 19 | } 20 | d := filepath.Join(dir, "with-internal", "d") 21 | err = Unzip(context.Background(), d, filepath.Join(dir, "with-internal", "f.zip/invalid")) 22 | if err == nil { 23 | t.Fatal("expected error, got nil") 24 | } 25 | err = Unzip(context.Background(), d, f) 26 | if err != nil { 27 | t.Fatalf("expected no error, got %+v", 
err) 28 | } 29 | compareDir(t, i, d) 30 | 31 | // Without internal path 32 | f = filepath.Join(dir, "without-internal", "f.zip") 33 | err = Zip(context.Background(), f, i) 34 | if err != nil { 35 | t.Fatalf("expected no error, got %+v", err) 36 | } 37 | d = filepath.Join(dir, "without-internal", "d") 38 | err = Unzip(context.Background(), d, f) 39 | if err != nil { 40 | t.Fatalf("expected no error, got %+v", err) 41 | } 42 | compareDir(t, i, d) 43 | } 44 | -------------------------------------------------------------------------------- /astikit.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "os" 5 | ) 6 | 7 | // Default modes 8 | var ( 9 | DefaultDirMode os.FileMode = 0755 10 | ) 11 | -------------------------------------------------------------------------------- /astikit_test.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "reflect" 7 | "strings" 8 | "testing" 9 | ) 10 | 11 | func fileContent(t *testing.T, path string) string { 12 | b, err := os.ReadFile(path) 13 | if err != nil { 14 | t.Fatalf("expected no error, got %+v", err) 15 | } 16 | return string(b) 17 | } 18 | 19 | func checkFile(t *testing.T, p string, e string) { 20 | if g := fileContent(t, p); e != g { 21 | t.Fatalf("expected %s, got %s", e, g) 22 | } 23 | } 24 | 25 | func compareFile(t *testing.T, expectedPath, gotPath string) { 26 | if e, g := fileContent(t, expectedPath), fileContent(t, gotPath); e != g { 27 | t.Fatalf("expected %s, got %s", e, g) 28 | } 29 | } 30 | 31 | func dirContent(t *testing.T, dir string) (o map[string]string) { 32 | // Make sure to clean dir path so that we get consistent path separator with filepath.Walk 33 | dir = filepath.Clean(dir) 34 | 35 | // Walk 36 | o = make(map[string]string) 37 | err := filepath.Walk(dir, func(path string, info os.FileInfo, e error) (err error) { 38 | // Check 
error 39 | if e != nil { 40 | return e 41 | } 42 | 43 | // Don't process dirs 44 | if info.IsDir() { 45 | return 46 | } 47 | 48 | // Read 49 | var b []byte 50 | if b, err = os.ReadFile(path); err != nil { 51 | return 52 | } 53 | 54 | // Add to map 55 | o[strings.TrimPrefix(path, dir)] = string(b) 56 | return 57 | }) 58 | if err != nil { 59 | t.Fatalf("expected no error, got %+v", err) 60 | } 61 | return 62 | } 63 | 64 | func checkDir(t *testing.T, p string, e map[string]string) { 65 | for k, v := range e { 66 | delete(e, k) 67 | e[filepath.Clean(k)] = v 68 | } 69 | if g := dirContent(t, p); !reflect.DeepEqual(e, g) { 70 | t.Fatalf("expected %s, got %s", e, g) 71 | } 72 | } 73 | 74 | func compareDir(t *testing.T, ePath, gPath string) { 75 | if e, g := dirContent(t, ePath), dirContent(t, gPath); !reflect.DeepEqual(e, g) { 76 | t.Fatalf("expected %+v, got %+v", e, g) 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /binary.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "encoding/binary" 5 | "errors" 6 | "io" 7 | ) 8 | 9 | // BitsWriter represents an object that can write individual bits into a writer 10 | // in a developer-friendly way. Check out the Write method for more information. 11 | // This is particularly helpful when you want to build a slice of bytes based 12 | // on individual bits for testing purposes. 
13 | type BitsWriter struct { 14 | bo binary.ByteOrder 15 | cache byte 16 | cacheLen byte 17 | bsCache []byte 18 | w io.Writer 19 | writeCb BitsWriterWriteCallback 20 | } 21 | 22 | type BitsWriterWriteCallback func([]byte) 23 | 24 | // BitsWriterOptions represents BitsWriter options 25 | type BitsWriterOptions struct { 26 | ByteOrder binary.ByteOrder 27 | // WriteCallback is called every time when full byte is written 28 | WriteCallback BitsWriterWriteCallback 29 | Writer io.Writer 30 | } 31 | 32 | // NewBitsWriter creates a new BitsWriter 33 | func NewBitsWriter(o BitsWriterOptions) (w *BitsWriter) { 34 | w = &BitsWriter{ 35 | bo: o.ByteOrder, 36 | bsCache: make([]byte, 1), 37 | w: o.Writer, 38 | writeCb: o.WriteCallback, 39 | } 40 | if w.bo == nil { 41 | w.bo = binary.BigEndian 42 | } 43 | return 44 | } 45 | 46 | func (w *BitsWriter) SetWriteCallback(cb BitsWriterWriteCallback) { 47 | w.writeCb = cb 48 | } 49 | 50 | // Write writes bits into the writer. Bits are only written when there are 51 | // enough to create a byte. 
When using a string or a bool, bits are added 52 | // from left to right as if 53 | // Available types are: 54 | // - string("10010"): processed as n bits, n being the length of the input 55 | // - []byte: processed as n bytes, n being the length of the input 56 | // - bool: processed as one bit 57 | // - uint8/uint16/uint32/uint64: processed as n bits, if type is uintn 58 | func (w *BitsWriter) Write(i any) error { 59 | // Transform input into "10010" format 60 | 61 | switch a := i.(type) { 62 | case string: 63 | for _, r := range a { 64 | var err error 65 | if r == '1' { 66 | err = w.writeBit(1) 67 | } else { 68 | err = w.writeBit(0) 69 | } 70 | if err != nil { 71 | return err 72 | } 73 | } 74 | case []byte: 75 | return w.writeByteSlice(a) 76 | case bool: 77 | if a { 78 | return w.writeBit(1) 79 | } else { 80 | return w.writeBit(0) 81 | } 82 | case uint8: 83 | return w.writeFullByte(a) 84 | case uint16: 85 | return w.writeFullInt(uint64(a), 2) 86 | case uint32: 87 | return w.writeFullInt(uint64(a), 4) 88 | case uint64: 89 | return w.writeFullInt(a, 8) 90 | default: 91 | return errors.New("astikit: invalid type") 92 | } 93 | 94 | return nil 95 | } 96 | 97 | // Writes exactly n bytes from bs 98 | // Writes first n bytes of bs if len(bs) > n 99 | // Pads with padByte at the end if len(bs) < n 100 | func (w *BitsWriter) WriteBytesN(bs []byte, n int, padByte uint8) error { 101 | if n == 0 { 102 | return nil 103 | } 104 | 105 | if len(bs) >= n { 106 | return w.writeByteSlice(bs[:n]) 107 | } 108 | 109 | if err := w.writeByteSlice(bs); err != nil { 110 | return err 111 | } 112 | 113 | // no bytes.Repeat here to avoid allocation 114 | for i := 0; i < n-len(bs); i++ { 115 | if err := w.writeFullByte(padByte); err != nil { 116 | return err 117 | } 118 | } 119 | 120 | return nil 121 | } 122 | 123 | func (w *BitsWriter) writeByteSlice(in []byte) error { 124 | if len(in) == 0 { 125 | return nil 126 | } 127 | 128 | if w.cacheLen != 0 { 129 | for _, b := range in { 130 | if err 
:= w.writeFullByte(b); err != nil { 131 | return err 132 | } 133 | } 134 | } else { 135 | return w.write(in) 136 | } 137 | 138 | return nil 139 | } 140 | 141 | func (w *BitsWriter) write(b []byte) error { 142 | if _, err := w.w.Write(b); err != nil { 143 | return err 144 | } 145 | if w.writeCb != nil { 146 | for i := range b { 147 | w.writeCb(b[i : i+1]) 148 | } 149 | } 150 | return nil 151 | } 152 | 153 | func (w *BitsWriter) writeFullInt(in uint64, len int) error { 154 | if w.bo == binary.BigEndian { 155 | return w.writeBitsN(in, len*8) 156 | } else { 157 | for i := 0; i < len; i++ { 158 | err := w.writeFullByte(byte(in >> (i * 8))) 159 | if err != nil { 160 | return err 161 | } 162 | } 163 | } 164 | 165 | return nil 166 | } 167 | 168 | func (w *BitsWriter) flushBsCache() error { 169 | return w.write(w.bsCache) 170 | } 171 | 172 | func (w *BitsWriter) writeFullByte(b byte) error { 173 | if w.cacheLen == 0 { 174 | w.bsCache[0] = b 175 | } else { 176 | w.bsCache[0] = w.cache | (b >> w.cacheLen) 177 | w.cache = b << (8 - w.cacheLen) 178 | } 179 | return w.flushBsCache() 180 | } 181 | 182 | func (w *BitsWriter) writeBit(bit byte) error { 183 | if bit != 0 { 184 | w.cache |= 1 << (7 - w.cacheLen) 185 | } 186 | w.cacheLen++ 187 | if w.cacheLen == 8 { 188 | w.bsCache[0] = w.cache 189 | if err := w.flushBsCache(); err != nil { 190 | return err 191 | } 192 | 193 | w.cacheLen = 0 194 | w.cache = 0 195 | } 196 | return nil 197 | } 198 | 199 | func (w *BitsWriter) writeBitsN(toWrite uint64, n int) (err error) { 200 | toWrite &= ^uint64(0) >> (64 - n) 201 | 202 | for n > 0 { 203 | if w.cacheLen == 0 { 204 | if n >= 8 { 205 | n -= 8 206 | w.bsCache[0] = byte(toWrite >> n) 207 | if err = w.flushBsCache(); err != nil { 208 | return 209 | } 210 | } else { 211 | w.cacheLen = uint8(n) 212 | w.cache = byte(toWrite << (8 - w.cacheLen)) 213 | n = 0 214 | } 215 | } else { 216 | free := int(8 - w.cacheLen) 217 | m := n 218 | if m >= free { 219 | m = free 220 | } 221 | 222 | if n <= free 
{ 223 | w.cache |= byte(toWrite << (free - m)) 224 | } else { 225 | w.cache |= byte(toWrite >> (n - m)) 226 | } 227 | 228 | n -= m 229 | w.cacheLen += uint8(m) 230 | 231 | if w.cacheLen == 8 { 232 | w.bsCache[0] = w.cache 233 | if err = w.flushBsCache(); err != nil { 234 | return err 235 | } 236 | 237 | w.cacheLen = 0 238 | w.cache = 0 239 | } 240 | } 241 | } 242 | 243 | return 244 | } 245 | 246 | // WriteN writes the input into n bits 247 | func (w *BitsWriter) WriteN(i any, n int) error { 248 | var toWrite uint64 249 | switch a := i.(type) { 250 | case uint8: 251 | toWrite = uint64(a) 252 | case uint16: 253 | toWrite = uint64(a) 254 | case uint32: 255 | toWrite = uint64(a) 256 | case uint64: 257 | toWrite = a 258 | default: 259 | return errors.New("astikit: invalid type") 260 | } 261 | 262 | return w.writeBitsN(toWrite, n) 263 | } 264 | 265 | // BitsWriterBatch allows to chain multiple Write* calls and check for error only once 266 | // For more info see https://github.com/asticode/go-astikit/pull/6 267 | type BitsWriterBatch struct { 268 | err error 269 | w *BitsWriter 270 | } 271 | 272 | func NewBitsWriterBatch(w *BitsWriter) BitsWriterBatch { 273 | return BitsWriterBatch{ 274 | w: w, 275 | } 276 | } 277 | 278 | // Calls BitsWriter.Write if there was no write error before 279 | func (b *BitsWriterBatch) Write(i any) { 280 | if b.err == nil { 281 | b.err = b.w.Write(i) 282 | } 283 | } 284 | 285 | // Calls BitsWriter.WriteN if there was no write error before 286 | func (b *BitsWriterBatch) WriteN(i any, n int) { 287 | if b.err == nil { 288 | b.err = b.w.WriteN(i, n) 289 | } 290 | } 291 | 292 | // Calls BitsWriter.WriteBytesN if there was no write error before 293 | func (b *BitsWriterBatch) WriteBytesN(bs []byte, n int, padByte uint8) { 294 | if b.err == nil { 295 | b.err = b.w.WriteBytesN(bs, n, padByte) 296 | } 297 | } 298 | 299 | // Returns first write error 300 | func (b *BitsWriterBatch) Err() error { 301 | return b.err 302 | } 303 | 304 | var 
byteHamming84Tab = [256]uint8{ 305 | 0x01, 0xff, 0xff, 0x08, 0xff, 0x0c, 0x04, 0xff, 0xff, 0x08, 0x08, 0x08, 0x06, 0xff, 0xff, 0x08, 306 | 0xff, 0x0a, 0x02, 0xff, 0x06, 0xff, 0xff, 0x0f, 0x06, 0xff, 0xff, 0x08, 0x06, 0x06, 0x06, 0xff, 307 | 0xff, 0x0a, 0x04, 0xff, 0x04, 0xff, 0x04, 0x04, 0x00, 0xff, 0xff, 0x08, 0xff, 0x0d, 0x04, 0xff, 308 | 0x0a, 0x0a, 0xff, 0x0a, 0xff, 0x0a, 0x04, 0xff, 0xff, 0x0a, 0x03, 0xff, 0x06, 0xff, 0xff, 0x0e, 309 | 0x01, 0x01, 0x01, 0xff, 0x01, 0xff, 0xff, 0x0f, 0x01, 0xff, 0xff, 0x08, 0xff, 0x0d, 0x05, 0xff, 310 | 0x01, 0xff, 0xff, 0x0f, 0xff, 0x0f, 0x0f, 0x0f, 0xff, 0x0b, 0x03, 0xff, 0x06, 0xff, 0xff, 0x0f, 311 | 0x01, 0xff, 0xff, 0x09, 0xff, 0x0d, 0x04, 0xff, 0xff, 0x0d, 0x03, 0xff, 0x0d, 0x0d, 0xff, 0x0d, 312 | 0xff, 0x0a, 0x03, 0xff, 0x07, 0xff, 0xff, 0x0f, 0x03, 0xff, 0x03, 0x03, 0xff, 0x0d, 0x03, 0xff, 313 | 0xff, 0x0c, 0x02, 0xff, 0x0c, 0x0c, 0xff, 0x0c, 0x00, 0xff, 0xff, 0x08, 0xff, 0x0c, 0x05, 0xff, 314 | 0x02, 0xff, 0x02, 0x02, 0xff, 0x0c, 0x02, 0xff, 0xff, 0x0b, 0x02, 0xff, 0x06, 0xff, 0xff, 0x0e, 315 | 0x00, 0xff, 0xff, 0x09, 0xff, 0x0c, 0x04, 0xff, 0x00, 0x00, 0x00, 0xff, 0x00, 0xff, 0xff, 0x0e, 316 | 0xff, 0x0a, 0x02, 0xff, 0x07, 0xff, 0xff, 0x0e, 0x00, 0xff, 0xff, 0x0e, 0xff, 0x0e, 0x0e, 0x0e, 317 | 0x01, 0xff, 0xff, 0x09, 0xff, 0x0c, 0x05, 0xff, 0xff, 0x0b, 0x05, 0xff, 0x05, 0xff, 0x05, 0x05, 318 | 0xff, 0x0b, 0x02, 0xff, 0x07, 0xff, 0xff, 0x0f, 0x0b, 0x0b, 0xff, 0x0b, 0xff, 0x0b, 0x05, 0xff, 319 | 0xff, 0x09, 0x09, 0x09, 0x07, 0xff, 0xff, 0x09, 0x00, 0xff, 0xff, 0x09, 0xff, 0x0d, 0x05, 0xff, 320 | 0x07, 0xff, 0xff, 0x09, 0x07, 0x07, 0x07, 0xff, 0xff, 0x0b, 0x03, 0xff, 0x07, 0xff, 0xff, 0x0e, 321 | } 322 | 323 | // ByteHamming84Decode hamming 8/4 decodes 324 | func ByteHamming84Decode(i uint8) (o uint8, ok bool) { 325 | o = byteHamming84Tab[i] 326 | if o == 0xff { 327 | return 328 | } 329 | ok = true 330 | return 331 | } 332 | 333 | var byteParityTab = [256]uint8{ 334 | 0x00, 0x01, 0x01, 0x00, 0x01, 0x00, 0x00, 0x01, 0x01, 
0x00, 0x00, 0x01, 0x00, 0x01, 0x01, 0x00, 335 | 0x01, 0x00, 0x00, 0x01, 0x00, 0x01, 0x01, 0x00, 0x00, 0x01, 0x01, 0x00, 0x01, 0x00, 0x00, 0x01, 336 | 0x01, 0x00, 0x00, 0x01, 0x00, 0x01, 0x01, 0x00, 0x00, 0x01, 0x01, 0x00, 0x01, 0x00, 0x00, 0x01, 337 | 0x00, 0x01, 0x01, 0x00, 0x01, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x01, 0x00, 0x01, 0x01, 0x00, 338 | 0x01, 0x00, 0x00, 0x01, 0x00, 0x01, 0x01, 0x00, 0x00, 0x01, 0x01, 0x00, 0x01, 0x00, 0x00, 0x01, 339 | 0x00, 0x01, 0x01, 0x00, 0x01, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x01, 0x00, 0x01, 0x01, 0x00, 340 | 0x00, 0x01, 0x01, 0x00, 0x01, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x01, 0x00, 0x01, 0x01, 0x00, 341 | 0x01, 0x00, 0x00, 0x01, 0x00, 0x01, 0x01, 0x00, 0x00, 0x01, 0x01, 0x00, 0x01, 0x00, 0x00, 0x01, 342 | 0x01, 0x00, 0x00, 0x01, 0x00, 0x01, 0x01, 0x00, 0x00, 0x01, 0x01, 0x00, 0x01, 0x00, 0x00, 0x01, 343 | 0x00, 0x01, 0x01, 0x00, 0x01, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x01, 0x00, 0x01, 0x01, 0x00, 344 | 0x00, 0x01, 0x01, 0x00, 0x01, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x01, 0x00, 0x01, 0x01, 0x00, 345 | 0x01, 0x00, 0x00, 0x01, 0x00, 0x01, 0x01, 0x00, 0x00, 0x01, 0x01, 0x00, 0x01, 0x00, 0x00, 0x01, 346 | 0x00, 0x01, 0x01, 0x00, 0x01, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x01, 0x00, 0x01, 0x01, 0x00, 347 | 0x01, 0x00, 0x00, 0x01, 0x00, 0x01, 0x01, 0x00, 0x00, 0x01, 0x01, 0x00, 0x01, 0x00, 0x00, 0x01, 348 | 0x01, 0x00, 0x00, 0x01, 0x00, 0x01, 0x01, 0x00, 0x00, 0x01, 0x01, 0x00, 0x01, 0x00, 0x00, 0x01, 349 | 0x00, 0x01, 0x01, 0x00, 0x01, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x01, 0x00, 0x01, 0x01, 0x00, 350 | } 351 | 352 | // ByteParity returns the byte parity 353 | func ByteParity(i uint8) (o uint8, ok bool) { 354 | ok = byteParityTab[i] == 1 355 | o = i & 0x7f 356 | return 357 | } 358 | -------------------------------------------------------------------------------- /binary_test.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 
6 | "io" 7 | "reflect" 8 | "testing" 9 | ) 10 | 11 | func TestBitsWriter(t *testing.T) { 12 | // TODO Need to test LittleEndian 13 | bw := &bytes.Buffer{} 14 | cbBuf := bytes.Buffer{} 15 | w := NewBitsWriter(BitsWriterOptions{ 16 | Writer: bw, 17 | WriteCallback: func(i []byte) { 18 | cbBuf.Write(i) 19 | }, 20 | }) 21 | 22 | err := w.Write("000000") 23 | if err != nil { 24 | t.Fatalf("expected no error, got %+v", err) 25 | } 26 | if e, g := 0, bw.Len(); e != g { 27 | t.Fatalf("expected %d, got %d", e, g) 28 | } 29 | err = w.Write(false) 30 | if err != nil { 31 | t.Fatalf("expected no error, got %+v", err) 32 | } 33 | err = w.Write(true) 34 | if err != nil { 35 | t.Fatalf("expected no error, got %+v", err) 36 | } 37 | if e, g := []byte{1}, bw.Bytes(); !reflect.DeepEqual(e, g) { 38 | t.Fatalf("expected %+v, got %+v", e, g) 39 | } 40 | err = w.Write([]byte{2, 3}) 41 | if err != nil { 42 | t.Fatalf("expected no error, got %+v", err) 43 | } 44 | if e, g := []byte{1, 2, 3}, bw.Bytes(); !reflect.DeepEqual(e, g) { 45 | t.Fatalf("expected %+v, got %+v", e, g) 46 | } 47 | err = w.Write(uint8(4)) 48 | if err != nil { 49 | t.Fatalf("expected no error, got %+v", err) 50 | } 51 | if e, g := []byte{1, 2, 3, 4}, bw.Bytes(); !reflect.DeepEqual(e, g) { 52 | t.Fatalf("expected %+v, got %+v", e, g) 53 | } 54 | err = w.Write(uint16(5)) 55 | if err != nil { 56 | t.Fatalf("expected no error, got %+v", err) 57 | } 58 | if e, g := []byte{1, 2, 3, 4, 0, 5}, bw.Bytes(); !reflect.DeepEqual(e, g) { 59 | t.Fatalf("expected %+v, got %+v", e, g) 60 | } 61 | err = w.Write(uint32(6)) 62 | if err != nil { 63 | t.Fatalf("expected no error, got %+v", err) 64 | } 65 | if e, g := []byte{1, 2, 3, 4, 0, 5, 0, 0, 0, 6}, bw.Bytes(); !reflect.DeepEqual(e, g) { 66 | t.Fatalf("expected %+v, got %+v", e, g) 67 | } 68 | err = w.Write(uint64(7)) 69 | if err != nil { 70 | t.Fatalf("expected no error, got %+v", err) 71 | } 72 | if e, g := []byte{1, 2, 3, 4, 0, 5, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 7}, bw.Bytes(); 
!reflect.DeepEqual(e, g) { 73 | t.Fatalf("expected %+v, got %+v", e, g) 74 | } 75 | if e, g := []byte{1, 2, 3, 4, 0, 5, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 7}, cbBuf.Bytes(); !reflect.DeepEqual(e, g) { 76 | t.Fatalf("callback buffer: expected %+v, got %+v", e, g) 77 | } 78 | err = w.Write(1) 79 | if err == nil { 80 | t.Fatal("expected error") 81 | } 82 | 83 | bw.Reset() 84 | err = w.WriteN(uint8(4), 3) 85 | if err != nil { 86 | t.Fatalf("expected no error, got %+v", err) 87 | } 88 | err = w.WriteN(uint16(4096), 13) 89 | if err != nil { 90 | t.Fatalf("expected no error, got %+v", err) 91 | } 92 | if e, g := []byte{144, 0}, bw.Bytes(); !reflect.DeepEqual(e, g) { 93 | t.Fatalf("expected %+v, got %+v", e, g) 94 | } 95 | } 96 | 97 | var bitsWriter_WriteBytesN_testCases = []struct { 98 | bs []byte 99 | n int 100 | expected []byte 101 | }{ 102 | {nil, 0, nil}, 103 | {[]byte{0x00}, 0, nil}, 104 | {nil, 3, []byte{0xff, 0xff, 0xff}}, 105 | {[]byte("en"), 3, []byte{'e', 'n', 0xff}}, 106 | {[]byte("eng"), 3, []byte{'e', 'n', 'g'}}, 107 | {[]byte("english"), 3, []byte{'e', 'n', 'g'}}, 108 | } 109 | 110 | func TestBitsWriter_WriteBytesN(t *testing.T) { 111 | padByte := uint8(0xff) 112 | for _, tt := range bitsWriter_WriteBytesN_testCases { 113 | t.Run(fmt.Sprintf("%v/%d", tt.bs, tt.n), func(t *testing.T) { 114 | buf := bytes.Buffer{} 115 | w := NewBitsWriter(BitsWriterOptions{Writer: &buf}) 116 | 117 | err := w.WriteBytesN(tt.bs, tt.n, padByte) 118 | if err != nil { 119 | t.Fatalf("expected no error, got %+v", err) 120 | } 121 | 122 | if !reflect.DeepEqual(tt.expected, buf.Bytes()) { 123 | t.Fatalf("expected %+v, got %+v", tt.expected, buf.Bytes()) 124 | } 125 | }) 126 | } 127 | } 128 | 129 | // testLimitedWriter is an implementation of io.Writer with max write size limit to test error handling 130 | type testLimitedWriter struct { 131 | BytesLimit int 132 | } 133 | 134 | func (t *testLimitedWriter) Write(p []byte) (n int, err error) { 135 | t.BytesLimit -= len(p) 136 | if 
t.BytesLimit >= 0 { 137 | return len(p), nil 138 | } 139 | return len(p) + t.BytesLimit, io.EOF 140 | } 141 | 142 | func TestNewBitsWriterBatch(t *testing.T) { 143 | wr := &testLimitedWriter{BytesLimit: 1} 144 | w := NewBitsWriter(BitsWriterOptions{Writer: wr}) 145 | b := NewBitsWriterBatch(w) 146 | 147 | b.Write(uint8(0)) 148 | if err := b.Err(); err != nil { 149 | t.Fatalf("expected no error, got %+v", err) 150 | } 151 | b.Write(uint8(1)) 152 | if err := b.Err(); err == nil { 153 | t.Fatalf("expected error, got %+v", err) 154 | } 155 | 156 | // let's check if the error is persisted 157 | b.Write(uint8(2)) 158 | if err := b.Err(); err == nil { 159 | t.Fatalf("expected error, got %+v", err) 160 | } 161 | } 162 | 163 | func BenchmarkBitsWriter_Write(b *testing.B) { 164 | benchmarks := []struct { 165 | input any 166 | }{ 167 | {"000000"}, 168 | {false}, 169 | {true}, 170 | {[]byte{2, 3}}, 171 | {uint8(4)}, 172 | {uint16(5)}, 173 | {uint32(6)}, 174 | {uint64(7)}, 175 | } 176 | 177 | bw := &bytes.Buffer{} 178 | bw.Grow(1024) 179 | w := NewBitsWriter(BitsWriterOptions{Writer: bw}) 180 | 181 | for _, bm := range benchmarks { 182 | b.Run(fmt.Sprintf("%#v", bm.input), func(b *testing.B) { 183 | b.ReportAllocs() 184 | for i := 0; i < b.N; i++ { 185 | bw.Reset() 186 | w.Write(bm.input) //nolint:errcheck 187 | } 188 | }) 189 | } 190 | } 191 | 192 | func BenchmarkBitsWriter_WriteN(b *testing.B) { 193 | type benchData struct { 194 | i any 195 | n int 196 | } 197 | benchmarks := make([]benchData, 0, 128) 198 | for i := 1; i <= 8; i++ { 199 | benchmarks = append(benchmarks, benchData{uint8(0xff), i}) 200 | } 201 | for i := 1; i <= 16; i++ { 202 | benchmarks = append(benchmarks, benchData{uint16(0xffff), i}) 203 | } 204 | for i := 1; i <= 32; i++ { 205 | benchmarks = append(benchmarks, benchData{uint32(0xffffffff), i}) 206 | } 207 | for i := 1; i <= 64; i++ { 208 | benchmarks = append(benchmarks, benchData{uint64(0xffffffffffffffff), i}) 209 | } 210 | 211 | bw := &bytes.Buffer{} 
212 | bw.Grow(1024) 213 | w := NewBitsWriter(BitsWriterOptions{Writer: bw}) 214 | 215 | for _, bm := range benchmarks { 216 | b.Run(fmt.Sprintf("%#v/%d", bm.i, bm.n), func(b *testing.B) { 217 | b.ReportAllocs() 218 | for i := 0; i < b.N; i++ { 219 | bw.Reset() 220 | w.WriteN(bm.i, bm.n) //nolint:errcheck 221 | } 222 | }) 223 | } 224 | } 225 | 226 | func BenchmarkBitsWriter_WriteBytesN(b *testing.B) { 227 | bw := &bytes.Buffer{} 228 | bw.Grow(1024) 229 | w := NewBitsWriter(BitsWriterOptions{Writer: bw}) 230 | 231 | for _, bm := range bitsWriter_WriteBytesN_testCases { 232 | b.Run(fmt.Sprintf("%v/%d", bm.bs, bm.n), func(b *testing.B) { 233 | b.ReportAllocs() 234 | for i := 0; i < b.N; i++ { 235 | bw.Reset() 236 | w.WriteBytesN(bm.bs, bm.n, 0xff) //nolint:errcheck 237 | } 238 | }) 239 | } 240 | } 241 | 242 | func testByteHamming84Decode(i uint8) (o uint8, ok bool) { 243 | p1, d1, p2, d2, p3, d3, p4, d4 := i>>7&0x1, i>>6&0x1, i>>5&0x1, i>>4&0x1, i>>3&0x1, i>>2&0x1, i>>1&0x1, i&0x1 244 | testA := p1^d1^d3^d4 > 0 245 | testB := d1^p2^d2^d4 > 0 246 | testC := d1^d2^p3^d3 > 0 247 | testD := p1^d1^p2^d2^p3^d3^p4^d4 > 0 248 | if testA && testB && testC { 249 | // p4 may be incorrect 250 | } else if testD && (!testA || !testB || !testC) { 251 | return 252 | } else { 253 | if !testA && testB && testC { 254 | // p1 is incorrect 255 | } else if testA && !testB && testC { 256 | // p2 is incorrect 257 | } else if testA && testB && !testC { 258 | // p3 is incorrect 259 | } else if !testA && !testB && testC { 260 | // d4 is incorrect 261 | d4 ^= 1 262 | } else if testA && !testB && !testC { 263 | // d2 is incorrect 264 | d2 ^= 1 265 | } else if !testA && testB && !testC { 266 | // d3 is incorrect 267 | d3 ^= 1 268 | } else { 269 | // d1 is incorrect 270 | d1 ^= 1 271 | } 272 | } 273 | o = uint8(d4<<3 | d3<<2 | d2<<1 | d1) 274 | ok = true 275 | return 276 | } 277 | 278 | func TestByteHamming84Decode(t *testing.T) { 279 | for i := 0; i < 256; i++ { 280 | v, okV := 
ByteHamming84Decode(uint8(i)) 281 | e, okE := testByteHamming84Decode(uint8(i)) 282 | if !okE { 283 | if okV { 284 | t.Fatal("expected false, got true") 285 | } 286 | } else { 287 | if !okV { 288 | t.Fatal("expected true, got false") 289 | } 290 | if !reflect.DeepEqual(e, v) { 291 | t.Fatalf("expected %+v, got %+v", e, v) 292 | } 293 | } 294 | } 295 | } 296 | 297 | func testByteParity(i uint8) bool { 298 | return (i&0x1)^(i>>1&0x1)^(i>>2&0x1)^(i>>3&0x1)^(i>>4&0x1)^(i>>5&0x1)^(i>>6&0x1)^(i>>7&0x1) > 0 299 | } 300 | 301 | func TestByteParity(t *testing.T) { 302 | for i := 0; i < 256; i++ { 303 | v, okV := ByteParity(uint8(i)) 304 | okE := testByteParity(uint8(i)) 305 | if !okE { 306 | if okV { 307 | t.Fatal("expected false, got true") 308 | } 309 | } else { 310 | if !okV { 311 | t.Fatal("expected true, got false") 312 | } 313 | if e := uint8(i) & 0x7f; e != v { 314 | t.Fatalf("expected %+v, got %+v", e, v) 315 | } 316 | } 317 | } 318 | } 319 | -------------------------------------------------------------------------------- /bit_flags.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | type BitFlags uint64 4 | 5 | func (fs BitFlags) Add(f uint64) uint64 { return uint64(fs) | f } 6 | 7 | func (fs BitFlags) Del(f uint64) uint64 { return uint64(fs) &^ f } 8 | 9 | func (fs BitFlags) Has(f uint64) bool { return uint64(fs)&f > 0 } 10 | -------------------------------------------------------------------------------- /bit_flags_test.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestBitFlags(t *testing.T) { 8 | f := BitFlags(2 | 4) 9 | r := f.Add(1) 10 | if e, g := uint64(7), r; e != g { 11 | t.Fatalf("expected %d, got %d", e, g) 12 | } 13 | r = f.Del(2) 14 | if e, g := uint64(4), r; e != g { 15 | t.Fatalf("expected %d, got %d", e, g) 16 | } 17 | if e, g := false, f.Has(1); e != g { 18 | t.Fatalf("expected %v, 
got %v", e, g) 19 | } 20 | if e, g := true, f.Has(4); e != g { 21 | t.Fatalf("expected %v, got %v", e, g) 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /bool.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | func BoolToUInt32(b bool) uint32 { 4 | if b { 5 | return 1 6 | } 7 | return 0 8 | } 9 | -------------------------------------------------------------------------------- /bool_test.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import "testing" 4 | 5 | func TestBoolToUInt32(t *testing.T) { 6 | if e, g := uint32(0), BoolToUInt32(false); e != g { 7 | t.Fatalf("expected %d, got %d", e, g) 8 | } 9 | if e, g := uint32(1), BoolToUInt32(true); e != g { 10 | t.Fatalf("expected %d, got %d", e, g) 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /bytes.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import "fmt" 4 | 5 | // BytesIterator represents an object capable of iterating sequentially and safely 6 | // through a slice of bytes. This is particularly useful when you need to iterate 7 | // through a slice of bytes and don't want to check for "index out of range" errors 8 | // manually. 
9 | type BytesIterator struct { 10 | bs []byte 11 | offset int 12 | } 13 | 14 | // NewBytesIterator creates a new BytesIterator 15 | func NewBytesIterator(bs []byte) *BytesIterator { 16 | return &BytesIterator{bs: bs} 17 | } 18 | 19 | // NextByte returns the next byte 20 | func (i *BytesIterator) NextByte() (b byte, err error) { 21 | if len(i.bs) < i.offset+1 { 22 | err = fmt.Errorf("astikit: slice length is %d, offset %d is invalid", len(i.bs), i.offset) 23 | return 24 | } 25 | b = i.bs[i.offset] 26 | i.offset++ 27 | return 28 | } 29 | 30 | // NextBytes returns the n next bytes 31 | func (i *BytesIterator) NextBytes(n int) (bs []byte, err error) { 32 | if len(i.bs) < i.offset+n { 33 | err = fmt.Errorf("astikit: slice length is %d, offset %d is invalid", len(i.bs), i.offset+n) 34 | return 35 | } 36 | bs = make([]byte, n) 37 | copy(bs, i.bs[i.offset:i.offset+n]) 38 | i.offset += n 39 | return 40 | } 41 | 42 | // NextBytesNoCopy returns the n next bytes 43 | // Be careful with this function as it doesn't make a copy of returned data. 44 | // bs will point to internal BytesIterator buffer. 
45 | // If you need to modify returned bytes or store it for some time, use NextBytes instead 46 | func (i *BytesIterator) NextBytesNoCopy(n int) (bs []byte, err error) { 47 | if len(i.bs) < i.offset+n { 48 | err = fmt.Errorf("astikit: slice length is %d, offset %d is invalid", len(i.bs), i.offset+n) 49 | return 50 | } 51 | bs = i.bs[i.offset : i.offset+n] 52 | i.offset += n 53 | return 54 | } 55 | 56 | // Seek seeks to the nth byte 57 | func (i *BytesIterator) Seek(n int) { 58 | i.offset = n 59 | } 60 | 61 | // Skip skips the n previous/next bytes 62 | func (i *BytesIterator) Skip(n int) { 63 | i.offset += n 64 | } 65 | 66 | // HasBytesLeft checks whether there are bytes left 67 | func (i *BytesIterator) HasBytesLeft() bool { 68 | return i.offset < len(i.bs) 69 | } 70 | 71 | // Offset returns the offset 72 | func (i *BytesIterator) Offset() int { 73 | return i.offset 74 | } 75 | 76 | // Dump dumps the rest of the slice 77 | func (i *BytesIterator) Dump() (bs []byte) { 78 | if !i.HasBytesLeft() { 79 | return 80 | } 81 | bs = make([]byte, len(i.bs)-i.offset) 82 | copy(bs, i.bs[i.offset:len(i.bs)]) 83 | i.offset = len(i.bs) 84 | return 85 | } 86 | 87 | // Len returns the slice length 88 | func (i *BytesIterator) Len() int { 89 | return len(i.bs) 90 | } 91 | 92 | const ( 93 | padRight = "right" 94 | padLeft = "left" 95 | ) 96 | 97 | type bytesPadder struct { 98 | cut bool 99 | direction string 100 | length int 101 | repeat byte 102 | } 103 | 104 | func newBytesPadder(repeat byte, length int) *bytesPadder { 105 | return &bytesPadder{ 106 | direction: padLeft, 107 | length: length, 108 | repeat: repeat, 109 | } 110 | } 111 | 112 | func (p *bytesPadder) pad(i []byte) []byte { 113 | if len(i) == p.length { 114 | return i 115 | } else if len(i) > p.length { 116 | if p.cut { 117 | return i[:p.length] 118 | } 119 | return i 120 | } else { 121 | o := make([]byte, len(i)) 122 | copy(o, i) 123 | for idx := 0; idx < p.length-len(i); idx++ { 124 | if p.direction == padRight { 125 
| o = append(o, p.repeat) 126 | } else { 127 | o = append([]byte{p.repeat}, o...) 128 | } 129 | o = append(o, p.repeat) 130 | } 131 | o = o[:p.length] 132 | return o 133 | } 134 | } 135 | 136 | // PadOption represents a Pad option 137 | type PadOption func(p *bytesPadder) 138 | 139 | // PadCut is a PadOption 140 | // It indicates to the padder it must cut the input to the provided length 141 | // if its original length is bigger 142 | func PadCut(p *bytesPadder) { p.cut = true } 143 | 144 | // PadLeft is a PadOption 145 | // It indicates additionnal bytes have to be added to the left 146 | func PadLeft(p *bytesPadder) { p.direction = padLeft } 147 | 148 | // PadRight is a PadOption 149 | // It indicates additionnal bytes have to be added to the right 150 | func PadRight(p *bytesPadder) { p.direction = padRight } 151 | 152 | // BytesPad pads the slice of bytes with additionnal options 153 | func BytesPad(i []byte, repeat byte, length int, options ...PadOption) []byte { 154 | p := newBytesPadder(repeat, length) 155 | for _, o := range options { 156 | o(p) 157 | } 158 | return p.pad(i) 159 | } 160 | 161 | // StrPad pads the string with additionnal options 162 | func StrPad(i string, repeat rune, length int, options ...PadOption) string { 163 | return string(BytesPad([]byte(i), byte(repeat), length, options...)) 164 | } 165 | -------------------------------------------------------------------------------- /bytes_test.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | ) 7 | 8 | func TestBytesIterator(t *testing.T) { 9 | i := NewBytesIterator([]byte("12345678")) 10 | if e, g := 8, i.Len(); e != g { 11 | t.Fatalf("expected %v, got %v", e, g) 12 | } 13 | b, err := i.NextByte() 14 | if err != nil { 15 | t.Fatalf("expected no error, got %+v", err) 16 | } 17 | if e := byte('1'); e != b { 18 | t.Fatalf("expected %v, got %v", e, b) 19 | } 20 | bs, err := i.NextBytes(2) 21 | if 
err != nil { 22 | t.Fatalf("expected no error, got %+v", err) 23 | } 24 | if e := []byte("23"); !bytes.Equal(e, bs) { 25 | t.Fatalf("expected %+v, got %+v", e, bs) 26 | } 27 | i.Seek(1) 28 | bs, err = i.NextBytesNoCopy(2) 29 | if err != nil { 30 | t.Fatalf("expected no error, got %+v", err) 31 | } 32 | if e := []byte("23"); !bytes.Equal(e, bs) { 33 | t.Fatalf("expected %+v, got %+v", e, bs) 34 | } 35 | i.Seek(4) 36 | b, err = i.NextByte() 37 | if err != nil { 38 | t.Fatalf("expected no error, got %+v", err) 39 | } 40 | if e := byte('5'); e != b { 41 | t.Fatalf("expected %v, got %v", e, b) 42 | } 43 | i.Skip(1) 44 | b, err = i.NextByte() 45 | if err != nil { 46 | t.Fatalf("expected no error, got %+v", err) 47 | } 48 | if e := byte('7'); e != b { 49 | t.Fatalf("expected %v, got %v", e, b) 50 | } 51 | if e, g := 7, i.Offset(); e != g { 52 | t.Fatalf("expected %v, got %v", e, g) 53 | } 54 | if !i.HasBytesLeft() { 55 | t.Fatal("expected true, got false") 56 | } 57 | bs = i.Dump() 58 | if e := []byte("8"); !bytes.Equal(e, bs) { 59 | t.Fatalf("expected %+v, got %+v", e, bs) 60 | } 61 | if i.HasBytesLeft() { 62 | t.Fatal("expected false, got true") 63 | } 64 | _, err = i.NextByte() 65 | if err == nil { 66 | t.Fatal("expected error") 67 | } 68 | _, err = i.NextBytes(2) 69 | if err == nil { 70 | t.Fatal("expected error") 71 | } 72 | bs = i.Dump() 73 | if e, g := 0, len(bs); e != g { 74 | t.Fatalf("expected %+v, got %+v", e, g) 75 | } 76 | } 77 | func TestBytesPad(t *testing.T) { 78 | if e, g := []byte("test"), BytesPad([]byte("test"), ' ', 4); !bytes.Equal(e, g) { 79 | t.Fatalf("expected %+v, got %+v", e, g) 80 | } 81 | if e, g := []byte("testtest"), BytesPad([]byte("testtest"), ' ', 4); !bytes.Equal(e, g) { 82 | t.Fatalf("expected %+v, got %+v", e, g) 83 | } 84 | if e, g := []byte("test"), BytesPad([]byte("testtest"), ' ', 4, PadCut); !bytes.Equal(e, g) { 85 | t.Fatalf("expected %+v, got %+v", e, g) 86 | } 87 | if e, g := []byte(" test"), BytesPad([]byte("test"), ' ', 6); 
!bytes.Equal(e, g) { 88 | t.Fatalf("expected %+v, got %+v", e, g) 89 | } 90 | if e, g := []byte("test "), BytesPad([]byte("test"), ' ', 6, PadRight); !bytes.Equal(e, g) { 91 | t.Fatalf("expected %+v, got %+v", e, g) 92 | } 93 | if e, g := []byte(" "), BytesPad([]byte{}, ' ', 4); !bytes.Equal(e, g) { 94 | t.Fatalf("expected %+v, got %+v", e, g) 95 | } 96 | } 97 | 98 | func TestStrPad(t *testing.T) { 99 | if e, g := "test", StrPad("test", ' ', 4); e != g { 100 | t.Fatalf("expected %+v, got %+v", e, g) 101 | } 102 | if e, g := "testtest", StrPad("testtest", ' ', 4); e != g { 103 | t.Fatalf("expected %+v, got %+v", e, g) 104 | } 105 | if e, g := "test", StrPad("testtest", ' ', 4, PadCut); e != g { 106 | t.Fatalf("expected %+v, got %+v", e, g) 107 | } 108 | if e, g := " test", StrPad("test", ' ', 6); e != g { 109 | t.Fatalf("expected %+v, got %+v", e, g) 110 | } 111 | if e, g := "test ", StrPad("test", ' ', 6, PadRight); e != g { 112 | t.Fatalf("expected %+v, got %+v", e, g) 113 | } 114 | if e, g := " ", StrPad("", ' ', 4); e != g { 115 | t.Fatalf("expected %+v, got %+v", e, g) 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /cache.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "sync" 5 | ) 6 | 7 | // Cache is an object capable of caching stuff while ensuring cumulated cached size never gets above 8 | // a provided threshold 9 | type Cache struct { 10 | items []CacheItem // We use a slice since we want to reorder items when one has been used 11 | m sync.Mutex // Locks items 12 | o CacheOptions 13 | size int 14 | } 15 | 16 | type CacheItem interface { 17 | Size() int 18 | } 19 | 20 | type CacheOptions struct { 21 | // - 0 disables cache 22 | // - < 0 disables max size 23 | MaxSize int 24 | } 25 | 26 | func NewCache(o CacheOptions) *Cache { 27 | return &Cache{o: o} 28 | } 29 | 30 | func (c *Cache) Get(found func(i CacheItem) bool) 
(CacheItem, bool) { 31 | // Lock 32 | c.m.Lock() 33 | defer c.m.Unlock() 34 | 35 | // Find item 36 | var idx int 37 | for idx = 0; idx < len(c.items); idx++ { 38 | if found(c.items[idx]) { 39 | break 40 | } 41 | } 42 | 43 | // Item was not found 44 | if idx >= len(c.items) { 45 | return nil, false 46 | } 47 | 48 | // Save item 49 | i := c.items[idx] 50 | 51 | // Move entry to the last position 52 | c.items = append(append(c.items[:idx], c.items[idx+1:]...), i) 53 | return i, true 54 | } 55 | 56 | func (c *Cache) Set(i CacheItem) { 57 | // Nothing to do 58 | if c.o.MaxSize == 0 { 59 | return 60 | } 61 | 62 | // Item is bigger than cache max size 63 | if c.o.MaxSize > 0 && i.Size() > c.o.MaxSize { 64 | return 65 | } 66 | 67 | // Lock 68 | c.m.Lock() 69 | defer c.m.Unlock() 70 | 71 | // Make room for item 72 | if c.o.MaxSize > 0 { 73 | for c.size+i.Size() > c.o.MaxSize { 74 | c.size -= c.items[0].Size() 75 | c.items = c.items[1:] 76 | } 77 | } 78 | 79 | // Store image 80 | c.size += i.Size() 81 | c.items = append(c.items, i) 82 | } 83 | 84 | func (c *Cache) Delete(remove func(i CacheItem) bool) { 85 | // Lock 86 | c.m.Lock() 87 | defer c.m.Unlock() 88 | 89 | // Loop through entries 90 | for idx := 0; idx < len(c.items); idx++ { 91 | // Remove 92 | if remove(c.items[idx]) { 93 | c.items = append(c.items[:idx], c.items[idx+1:]...) 
94 | idx-- 95 | } 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /cache_test.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | type cacheItem int 8 | 9 | func (i cacheItem) Size() int { return int(i) } 10 | 11 | func cacheFunc(i int) func(i CacheItem) bool { 12 | return func(ci CacheItem) bool { return int(ci.(cacheItem)) == i } 13 | } 14 | 15 | func TestCache(t *testing.T) { 16 | // Cache can be disabled 17 | c := NewCache(CacheOptions{}) 18 | c.Set(cacheItem(1)) 19 | _, ok := c.Get(cacheFunc(1)) 20 | if ok { 21 | t.Fatal("expected false, got true") 22 | } 23 | 24 | // Cache can be limited 25 | c = NewCache(CacheOptions{MaxSize: 5}) 26 | if _, ok = c.Get(cacheFunc(1)); ok { 27 | t.Fatal("expected false, got true") 28 | } 29 | c.Set(cacheItem(1)) 30 | i, ok := c.Get(cacheFunc(1)) 31 | if !ok { 32 | t.Fatal("expected true, got false") 33 | } 34 | if e, g := 1, int(i.(cacheItem)); e != g { 35 | t.Fatalf("expected %d, got %d", e, g) 36 | } 37 | c.Set(cacheItem(2)) 38 | c.Set(cacheItem(3)) 39 | if _, ok = c.Get(cacheFunc(1)); ok { 40 | t.Fatal("expected false, got true") 41 | } 42 | if _, ok = c.Get(cacheFunc(3)); !ok { 43 | t.Fatal("expected true, got false") 44 | } 45 | // Getting an item makes it less likely to get purged 46 | if _, ok = c.Get(cacheFunc(2)); !ok { 47 | t.Fatal("expected true, got false") 48 | } 49 | c.Set(cacheItem(1)) 50 | if _, ok = c.Get(cacheFunc(3)); ok { 51 | t.Fatal("expected false, got true") 52 | } 53 | if _, ok = c.Get(cacheFunc(1)); !ok { 54 | t.Fatal("expected true, got false") 55 | } 56 | if _, ok = c.Get(cacheFunc(2)); !ok { 57 | t.Fatal("expected true, got false") 58 | } 59 | c.Set(cacheItem(6)) 60 | if _, ok = c.Get(cacheFunc(6)); ok { 61 | t.Fatal("expected false, got true") 62 | } 63 | 64 | // Cache can be unlimited 65 | c = NewCache(CacheOptions{MaxSize: -1}) 66 | 
c.Set(cacheItem(1)) 67 | c.Set(cacheItem(2)) 68 | c.Set(cacheItem(3)) 69 | if _, ok = c.Get(cacheFunc(1)); !ok { 70 | t.Fatal("expected true, got false") 71 | } 72 | if _, ok = c.Get(cacheFunc(2)); !ok { 73 | t.Fatal("expected true, got false") 74 | } 75 | if _, ok = c.Get(cacheFunc(3)); !ok { 76 | t.Fatal("expected true, got false") 77 | } 78 | c.Delete(cacheFunc(2)) 79 | if _, ok = c.Get(cacheFunc(1)); !ok { 80 | t.Fatal("expected true, got false") 81 | } 82 | if _, ok = c.Get(cacheFunc(2)); ok { 83 | t.Fatal("expected false, got true") 84 | } 85 | if _, ok = c.Get(cacheFunc(3)); !ok { 86 | t.Fatal("expected true, got false") 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /defer.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "sync" 5 | ) 6 | 7 | type CloseFunc func() 8 | type CloseFuncWithError func() error 9 | 10 | // Closer is an object that can close several things 11 | type Closer struct { 12 | closed bool 13 | fs []CloseFuncWithError 14 | // We need to split into 2 mutexes to allow using .Add() in .Do() 15 | mc *sync.Mutex // Locks .Close() 16 | mf *sync.Mutex // Locks fs 17 | onClosed CloserOnClosed 18 | } 19 | 20 | type CloserOnClosed func(err error) 21 | 22 | // NewCloser creates a new closer 23 | func NewCloser() *Closer { 24 | return &Closer{ 25 | mc: &sync.Mutex{}, 26 | mf: &sync.Mutex{}, 27 | } 28 | } 29 | 30 | // Close implements the io.Closer interface 31 | func (c *Closer) Close() error { 32 | // Lock 33 | c.mc.Lock() 34 | defer c.mc.Unlock() 35 | 36 | // Get funcs 37 | c.mf.Lock() 38 | fs := c.fs 39 | c.mf.Unlock() 40 | 41 | // Loop through closers 42 | err := NewErrors() 43 | for _, f := range fs { 44 | err.Add(f()) 45 | } 46 | 47 | // Reset closers 48 | c.fs = []CloseFuncWithError{} 49 | 50 | // Update attribute 51 | c.closed = true 52 | 53 | // Callback 54 | if c.onClosed != nil { 55 | c.onClosed(err) 56 | } 57 | 
58 | // Return 59 | if err.IsNil() { 60 | return nil 61 | } 62 | return err 63 | } 64 | 65 | func (c *Closer) Add(f CloseFunc) { 66 | c.AddWithError(func() error { 67 | f() 68 | return nil 69 | }) 70 | } 71 | 72 | func (c *Closer) AddWithError(f CloseFuncWithError) { 73 | // Lock 74 | c.mf.Lock() 75 | defer c.mf.Unlock() 76 | 77 | // Append 78 | c.fs = append([]CloseFuncWithError{f}, c.fs...) 79 | } 80 | 81 | func (c *Closer) Append(dst *Closer) { 82 | // Lock 83 | c.mf.Lock() 84 | dst.mf.Lock() 85 | defer c.mf.Unlock() 86 | defer dst.mf.Unlock() 87 | 88 | // Append 89 | c.fs = append(c.fs, dst.fs...) 90 | } 91 | 92 | // NewChild creates a new child closer 93 | func (c *Closer) NewChild() (child *Closer) { 94 | child = NewCloser() 95 | c.AddWithError(child.Close) 96 | return 97 | } 98 | 99 | // Do executes a callback while ensuring : 100 | // - closer hasn't been closed before 101 | // - closer can't be closed in between 102 | func (c *Closer) Do(fn func()) { 103 | // Lock 104 | c.mc.Lock() 105 | defer c.mc.Unlock() 106 | 107 | // Closer already closed 108 | if c.closed { 109 | return 110 | } 111 | 112 | // Callback 113 | fn() 114 | } 115 | 116 | func (c *Closer) OnClosed(fn CloserOnClosed) { 117 | c.mc.Lock() 118 | defer c.mc.Unlock() 119 | c.onClosed = fn 120 | } 121 | 122 | func (c *Closer) IsClosed() bool { 123 | c.mc.Lock() 124 | defer c.mc.Unlock() 125 | return c.closed 126 | } 127 | -------------------------------------------------------------------------------- /defer_test.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "errors" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func TestCloser(t *testing.T) { 10 | var c int 11 | var o []string 12 | c1 := NewCloser() 13 | c1.OnClosed(func(err error) { c++ }) 14 | c2 := c1.NewChild() 15 | c1.Add(func() { o = append(o, "1") }) 16 | c1.AddWithError(func() error { 17 | o = append(o, "2") 18 | return errors.New("1") 19 | }) 20 | 
c1.AddWithError(func() error { return errors.New("2") }) 21 | c2.AddWithError(func() error { 22 | o = append(o, "3") 23 | return errors.New("3") 24 | }) 25 | err := c1.Close() 26 | if e := []string{"2", "1", "3"}; !reflect.DeepEqual(o, e) { 27 | t.Fatalf("expected %+v, got %+v", e, o) 28 | } 29 | if e, g := "2 && 1 && 3", err.Error(); !reflect.DeepEqual(g, e) { 30 | t.Fatalf("expected %+v, got %+v", e, g) 31 | } 32 | c1.AddWithError(func() error { return nil }) 33 | if err = c1.Close(); err != nil { 34 | t.Fatalf("expected no error, got %+v", err) 35 | } 36 | if e, g := 2, c; e != g { 37 | t.Fatalf("expected %v, got %v", e, g) 38 | } 39 | if !c1.IsClosed() { 40 | t.Fatal("expected true, got false") 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /errors.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "errors" 5 | "strings" 6 | "sync" 7 | ) 8 | 9 | // Errors is an error containing multiple errors 10 | type Errors struct { 11 | m *sync.Mutex // Locks p 12 | p []error 13 | } 14 | 15 | // NewErrors creates new errors 16 | func NewErrors(errs ...error) *Errors { 17 | return &Errors{ 18 | m: &sync.Mutex{}, 19 | p: errs, 20 | } 21 | } 22 | 23 | // Add adds a new error 24 | func (errs *Errors) Add(err error) { 25 | if err == nil { 26 | return 27 | } 28 | errs.m.Lock() 29 | defer errs.m.Unlock() 30 | errs.p = append(errs.p, err) 31 | } 32 | 33 | // IsNil checks whether the error is nil 34 | func (errs *Errors) IsNil() bool { 35 | errs.m.Lock() 36 | defer errs.m.Unlock() 37 | return len(errs.p) == 0 38 | } 39 | 40 | // Loop loops through the errors 41 | func (errs *Errors) Loop(fn func(idx int, err error) bool) { 42 | errs.m.Lock() 43 | defer errs.m.Unlock() 44 | for idx, err := range errs.p { 45 | if stop := fn(idx, err); stop { 46 | return 47 | } 48 | } 49 | } 50 | 51 | // Error implements the error interface 52 | func (errs *Errors) Error() 
string { 53 | errs.m.Lock() 54 | defer errs.m.Unlock() 55 | var ss []string 56 | for _, err := range errs.p { 57 | ss = append(ss, err.Error()) 58 | } 59 | return strings.Join(ss, " && ") 60 | } 61 | 62 | func (errs *Errors) Is(target error) bool { 63 | errs.m.Lock() 64 | defer errs.m.Unlock() 65 | for _, v := range errs.p { 66 | if errors.Is(v, target) { 67 | return true 68 | } 69 | } 70 | return false 71 | } 72 | 73 | // ErrorCause returns the cause of an error 74 | func ErrorCause(err error) error { 75 | for { 76 | if u := errors.Unwrap(err); u != nil { 77 | err = u 78 | continue 79 | } 80 | return err 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /errors_test.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "strconv" 7 | "testing" 8 | ) 9 | 10 | func TestErrors(t *testing.T) { 11 | errs := NewErrors() 12 | if !errs.IsNil() { 13 | t.Fatal("expected true, got false") 14 | } 15 | errs = NewErrors(errors.New("1")) 16 | if e, g := "1", errs.Error(); g != e { 17 | t.Fatalf("expected %+v, got %+v", e, g) 18 | } 19 | errs.Add(errors.New("2")) 20 | if e, g := "1 && 2", errs.Error(); g != e { 21 | t.Fatalf("expected %+v, got %+v", e, g) 22 | } 23 | errs.Loop(func(idx int, err error) bool { 24 | if e, g := strconv.Itoa(idx+1), err.Error(); g != e { 25 | t.Fatalf("expected %v, got %v", e, g) 26 | } 27 | return false 28 | }) 29 | err1 := errors.New("1") 30 | err2 := errors.New("2") 31 | err3 := errors.New("3") 32 | errs = NewErrors(err1, err3) 33 | for _, v := range []struct { 34 | err error 35 | expected bool 36 | }{ 37 | { 38 | err: err1, 39 | expected: true, 40 | }, 41 | { 42 | err: err2, 43 | expected: false, 44 | }, 45 | { 46 | err: err3, 47 | expected: true, 48 | }, 49 | } { 50 | if g := errors.Is(errs, v.err); g != v.expected { 51 | t.Fatalf("expected %v, got %v", v.expected, g) 52 | } 53 | } 54 | } 55 | 56 | func 
TestErrorCause(t *testing.T) { 57 | err1 := errors.New("test 1") 58 | err2 := fmt.Errorf("test 2 failed: %w", err1) 59 | if e, g := err1, ErrorCause(err2); !errors.Is(g, e) { 60 | t.Fatalf("expected %+v, got %+v", e, g) 61 | } 62 | err3 := fmt.Errorf("test 3 failed: %w", err2) 63 | if e, g := err1, ErrorCause(err3); !errors.Is(g, e) { 64 | t.Fatalf("expected %+v, got %+v", e, g) 65 | } 66 | } 67 | -------------------------------------------------------------------------------- /event.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "sync" 5 | ) 6 | 7 | type EventHandler func(payload any) (delete bool) 8 | 9 | type EventName string 10 | 11 | type EventManager struct { 12 | handlerCount uint64 13 | // We use a map[int]... so that deletion is as smooth as possible 14 | hs map[EventName]map[uint64]EventHandler 15 | m *sync.Mutex 16 | } 17 | 18 | func NewEventManager() *EventManager { 19 | return &EventManager{ 20 | hs: make(map[EventName]map[uint64]EventHandler), 21 | m: &sync.Mutex{}, 22 | } 23 | } 24 | 25 | func (m *EventManager) On(n EventName, h EventHandler) uint64 { 26 | // Lock 27 | m.m.Lock() 28 | defer m.m.Unlock() 29 | 30 | // Make sure event name exists 31 | if _, ok := m.hs[n]; !ok { 32 | m.hs[n] = make(map[uint64]EventHandler) 33 | } 34 | 35 | // Increment handler count 36 | m.handlerCount++ 37 | 38 | // Add handler 39 | m.hs[n][m.handlerCount] = h 40 | 41 | // Return id 42 | return m.handlerCount 43 | } 44 | 45 | func (m *EventManager) Off(id uint64) { 46 | // Lock 47 | m.m.Lock() 48 | defer m.m.Unlock() 49 | 50 | // Loop through handlers 51 | for _, ids := range m.hs { 52 | // Loop through ids 53 | for v := range ids { 54 | // Id matches 55 | if id == v { 56 | delete(ids, id) 57 | } 58 | } 59 | } 60 | } 61 | 62 | func (m *EventManager) Emit(n EventName, payload any) { 63 | // Loop through handlers 64 | for _, h := range m.handlers(n) { 65 | if h.h(payload) { 66 | m.Off(h.id) 67 
| } 68 | } 69 | } 70 | 71 | type eventManagerHandler struct { 72 | h EventHandler 73 | id uint64 74 | } 75 | 76 | func (m *EventManager) handlers(n EventName) (hs []eventManagerHandler) { 77 | // Lock 78 | m.m.Lock() 79 | defer m.m.Unlock() 80 | 81 | // Index handlers 82 | hsm := make(map[uint64]eventManagerHandler) 83 | var ids []uint64 84 | if _, ok := m.hs[n]; ok { 85 | for id, h := range m.hs[n] { 86 | hsm[id] = eventManagerHandler{ 87 | h: h, 88 | id: id, 89 | } 90 | ids = append(ids, id) 91 | } 92 | } 93 | 94 | // Sort ids 95 | SortUint64(ids) 96 | 97 | // Append 98 | for _, id := range ids { 99 | hs = append(hs, hsm[id]) 100 | } 101 | return 102 | } 103 | -------------------------------------------------------------------------------- /event_test.go: -------------------------------------------------------------------------------- 1 | package astikit_test 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | 7 | "github.com/asticode/go-astikit" 8 | ) 9 | 10 | func TestEvent(t *testing.T) { 11 | const ( 12 | eventName1 astikit.EventName = "event-name-1" 13 | eventName2 astikit.EventName = "event-name-2" 14 | eventName3 astikit.EventName = "event-name-3" 15 | ) 16 | m := astikit.NewEventManager() 17 | ons := make(map[astikit.EventName][]any) 18 | m.On(eventName1, func(payload any) (delete bool) { 19 | ons[eventName1] = append(ons[eventName1], payload) 20 | return true 21 | }) 22 | id := m.On(eventName3, func(payload any) (delete bool) { 23 | ons[eventName3] = append(ons[eventName3], payload) 24 | return false 25 | }) 26 | 27 | m.Emit(eventName1, 1) 28 | m.Emit(eventName1, 2) 29 | m.Emit(eventName2, 1) 30 | m.Emit(eventName2, 2) 31 | m.Emit(eventName3, 1) 32 | m.Emit(eventName3, 2) 33 | 34 | m.Off(id) 35 | m.Emit(eventName3, 3) 36 | 37 | if e, g := map[astikit.EventName][]any{ 38 | eventName1: {1}, 39 | eventName3: {1, 2}, 40 | }, ons; !reflect.DeepEqual(e, g) { 41 | t.Fatalf("expected %+v, got %+v", e, g) 42 | } 43 | } 44 | 
--------------------------------------------------------------------------------
/exec.go:
--------------------------------------------------------------------------------
1 | package astikit
2 | 
3 | import (
4 | 	"context"
5 | 	"errors"
6 | 	"fmt"
7 | 	"os"
8 | 	"os/exec"
9 | 	"strings"
10 | 	"sync"
11 | )
12 | 
13 | // Statuses
14 | const (
15 | 	ExecStatusCrashed = "crashed"
16 | 	ExecStatusRunning = "running"
17 | 	ExecStatusStopped = "stopped"
18 | )
19 | 
20 | // ExecHandler represents an object capable of handling the execution of a cmd
21 | //
22 | // Status and Stop may be called from any goroutine: err and stopped are
23 | // written by the goroutines started in ExecCmd and are therefore guarded by m.
24 | type ExecHandler struct {
25 | 	cancel  context.CancelFunc
26 | 	ctx     context.Context
27 | 	err     error
28 | 	m       sync.Mutex
29 | 	stopped bool
30 | }
31 | 
32 | // Status returns the cmd status:
33 | //   - ExecStatusRunning while the handler's context is not done
34 | //   - ExecStatusStopped if termination was requested (through Stop or because
35 | //     the worker stopped), if the cmd exited without error, or if its exit has
36 | //     not been recorded yet
37 | //   - ExecStatusCrashed if the cmd exited with an error on its own
38 | func (h *ExecHandler) Status() string {
39 | 	if h.ctx.Err() == nil {
40 | 		return ExecStatusRunning
41 | 	}
42 | 	h.m.Lock()
43 | 	defer h.m.Unlock()
44 | 	if h.stopped || h.err == nil {
45 | 		return ExecStatusStopped
46 | 	}
47 | 	return ExecStatusCrashed
48 | }
49 | 
50 | // Stop stops the cmd
51 | //
52 | // It cancels the handler's context, which makes the goroutine started in
53 | // ExecCmd mark the cmd as stopped and run the stop func unless the cmd has
54 | // already exited. Calling it several times is safe.
55 | func (h *ExecHandler) Stop() {
56 | 	h.cancel()
57 | }
58 | 
59 | // ExecCmdOptions represents exec options
60 | //
61 | // CmdAdapter, if set, is called with the cmd before it is started. StopFunc, if
62 | // set, replaces cmd.Process.Kill() as the way to stop the cmd.
63 | type ExecCmdOptions struct {
64 | 	Args       []string
65 | 	CmdAdapter func(cmd *exec.Cmd, h *ExecHandler) error
66 | 	Name       string
67 | 	StopFunc   func(cmd *exec.Cmd) error
68 | }
69 | 
70 | // ExecCmd executes a cmd
71 | // The process will be stopped when the worker stops
72 | //
73 | // If adapting or starting the cmd fails, the handler's context is canceled
74 | // before returning so that it doesn't leak, and err is non-nil.
75 | func ExecCmd(w *Worker, o ExecCmdOptions) (h *ExecHandler, err error) {
76 | 	// Create handler
77 | 	h = &ExecHandler{}
78 | 	h.ctx, h.cancel = context.WithCancel(w.Context())
79 | 
80 | 	// Release the context if the cmd never gets to run
81 | 	defer func() {
82 | 		if err != nil {
83 | 			h.cancel()
84 | 		}
85 | 	}()
86 | 
87 | 	// Create command
88 | 	cmd := exec.Command(o.Name, o.Args...)
89 | 
90 | 	// Adapt command
91 | 	if o.CmdAdapter != nil {
92 | 		if err = o.CmdAdapter(cmd, h); err != nil {
93 | 			err = fmt.Errorf("astikit: adapting cmd failed: %w", err)
94 | 			return
95 | 		}
96 | 	}
97 | 
98 | 	// Start
99 | 	w.Logger().Infof("astikit: starting %s", strings.Join(cmd.Args, " "))
100 | 	if err = cmd.Start(); err != nil {
101 | 		err = fmt.Errorf("astikit: executing %s: %w", strings.Join(cmd.Args, " "), err)
102 | 		return
103 | 	}
104 | 
105 | 	// Closed once cmd.Wait() has returned, i.e. once the process has exited
106 | 	done := make(chan struct{})
107 | 
108 | 	// Handle context
109 | 	go func() {
110 | 		// Wait for context to be done
111 | 		<-h.ctx.Done()
112 | 
113 | 		// The process has already exited on its own: there is nothing to stop, and
114 | 		// killing it would only fail with os.ErrProcessDone
115 | 		select {
116 | 		case <-done:
117 | 			return
118 | 		default:
119 | 		}
120 | 
121 | 		// Termination is requested by us: this is not a crash
122 | 		h.m.Lock()
123 | 		h.stopped = true
124 | 		h.m.Unlock()
125 | 
126 | 		// Get stop func
127 | 		f := func() error { return cmd.Process.Kill() }
128 | 		if o.StopFunc != nil {
129 | 			f = func() error { return o.StopFunc(cmd) }
130 | 		}
131 | 
132 | 		// Stop. The process may still have exited right before being stopped,
133 | 		// which is not worth reporting
134 | 		if err := f(); err != nil && !errors.Is(err, os.ErrProcessDone) {
135 | 			w.Logger().Error(fmt.Errorf("astikit: stopping cmd failed: %w", err))
136 | 		}
137 | 	}()
138 | 
139 | 	// Execute in a task
140 | 	w.NewTask().Do(func() {
141 | 		errWait := cmd.Wait()
142 | 		h.m.Lock()
143 | 		h.err = errWait
144 | 		h.m.Unlock()
145 | 
146 | 		// done must be closed before the context is canceled so that the goroutine
147 | 		// above, when woken up by this cancel, knows there is nothing to stop
148 | 		close(done)
149 | 		h.cancel()
150 | 		w.Logger().Infof("astikit: status is now %s for %s", h.Status(), strings.Join(cmd.Args, " "))
151 | 	})
152 | 	return
153 | }
154 | 
--------------------------------------------------------------------------------
/flag.go:
--------------------------------------------------------------------------------
1 | package astikit
2 | 
3 | import (
4 | 	"os"
5 | 	"strings"
6 | )
7 | 
8 | // FlagCmd retrieves the command from the input Args
9 | //
10 | // The first argument is considered to be the command when it is not empty and
11 | // doesn't start with "-". In that case it is removed from os.Args, so that the
12 | // remaining flags can be parsed normally, and returned. Otherwise os.Args is
13 | // left untouched and "" is returned.
14 | func FlagCmd() (o string) {
15 | 	// Checking for "" first avoids indexing into an empty argument
16 | 	if len(os.Args) < 2 || os.Args[1] == "" || strings.HasPrefix(os.Args[1], "-") {
17 | 		return
18 | 	}
19 | 	o = os.Args[1]
20 | 	os.Args = append([]string{os.Args[0]}, os.Args[2:]...)
21 | 	return
22 | }
23 | 
24 | // FlagStrings represents a flag that can be set several times and
25 | // stores unique string values
26 | //
27 | // It must be created with NewFlagStrings: the zero value has a nil Map and a
28 | // nil Slice, which makes Set panic.
29 | type FlagStrings struct {
30 | 	Map   map[string]bool
31 | 	Slice *[]string
32 | }
33 | 
34 | // NewFlagStrings creates a new FlagStrings
35 | func NewFlagStrings() FlagStrings {
36 | 	return FlagStrings{
37 | 		Map:   make(map[string]bool),
38 | 		Slice: &[]string{},
39 | 	}
40 | }
41 | 
42 | // String implements the flag.Value interface
43 | func (f FlagStrings) String() string {
44 | 	if f.Slice == nil {
45 | 		return ""
46 | 	}
47 | 	return strings.Join(*f.Slice, ",")
48 | }
49 | 
50 | // Set implements the flag.Value interface
51 | //
52 | // Values that have already been set are ignored so that Slice only contains
53 | // unique values, in insertion order.
54 | func (f FlagStrings) Set(i string) error {
55 | 	if _, ok := f.Map[i]; ok {
56 | 		return nil
57 | 	}
58 | 	f.Map[i] = true
59 | 	*f.Slice = append(*f.Slice, i)
60 | 	return nil
61 | }
62 | 
--------------------------------------------------------------------------------
/flag_test.go:
--------------------------------------------------------------------------------
1 | package astikit
2 | 
3 | import (
4 | 	"flag"
5 | 	"os"
6 | 	"reflect"
7 | 	"testing"
8 | )
9 | 
10 | func TestFlagCmd(t *testing.T) {
11 | 	// FlagCmd mutates the global os.Args: restore it once the test is done
12 | 	args := os.Args
13 | 	t.Cleanup(func() { os.Args = args })
14 | 
15 | 	os.Args = []string{"name"}
16 | 	if e, g := "", FlagCmd(); g != e {
17 | 		t.Fatalf("expected %v, got %v", e, g)
18 | 	}
19 | 	os.Args = []string{"name", "-flag"}
20 | 	if e, g := "", FlagCmd(); g != e {
21 | 		t.Fatalf("expected %v, got %v", e, g)
22 | 	}
23 | 	os.Args = []string{"name", ""}
24 | 	if e, g := "", FlagCmd(); g != e {
25 | 		t.Fatalf("expected %v, got %v", e, g)
26 | 	}
27 | 	os.Args = []string{"name", "cmd", "-flag"}
28 | 	if e, g := "cmd", FlagCmd(); g != e {
29 | 		t.Fatalf("expected %v, got %v", e, g)
30 | 	}
31 | 	if e, g := []string{"name", "-flag"}, os.Args; !reflect.DeepEqual(e, g) {
32 | 		t.Fatalf("expected %+v, got %+v", e, g)
33 | 	}
34 | }
35 | 
36 | func TestFlagStrings(t *testing.T) {
37 | 	f := NewFlagStrings()
38 | 
39 | 	// Use a dedicated flag set rather than the global flag.CommandLine
40 | 	fs := flag.NewFlagSet("test", flag.ContinueOnError)
41 | 	fs.Var(f, "t", "")
42 | 	if err := fs.Parse([]string{"-t", "1", "-t", "2", "-t", "1"}); err != nil {
43 | 		t.Fatalf("expected no error, got %+v", err)
44 | 	}
45 | 	if e := (FlagStrings{
46 | 		Map: map[string]bool{
47 | 			"1": true,
48 | 			"2": true,
49 | 		},
50 | 		Slice: &[]string{"1", "2"},
51 | 	}); !reflect.DeepEqual(e, f) {
52 | 		t.Fatalf("expected %+v, got %+v", e, f)
53 | 	}
54 | }
55 | 
-------------------------------------------------------------------------------- /float.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "strconv" 7 | ) 8 | 9 | // Rational represents a rational 10 | type Rational struct{ den, num int } 11 | 12 | // NewRational creates a new rational 13 | func NewRational(num, den int) *Rational { 14 | return &Rational{ 15 | den: den, 16 | num: num, 17 | } 18 | } 19 | 20 | // Num returns the rational num 21 | func (r *Rational) Num() int { 22 | return r.num 23 | } 24 | 25 | // Den returns the rational den 26 | func (r *Rational) Den() int { 27 | return r.den 28 | } 29 | 30 | // ToFloat64 returns the rational as a float64 31 | func (r *Rational) ToFloat64() float64 { 32 | return float64(r.num) / float64(r.den) 33 | } 34 | 35 | // MarshalText implements the TextMarshaler interface 36 | func (r *Rational) MarshalText() (b []byte, err error) { 37 | b = []byte(fmt.Sprintf("%d/%d", r.num, r.den)) 38 | return 39 | } 40 | 41 | // UnmarshalText implements the TextUnmarshaler interface 42 | func (r *Rational) UnmarshalText(b []byte) (err error) { 43 | r.num = 0 44 | r.den = 1 45 | if len(b) == 0 { 46 | return 47 | } 48 | items := bytes.Split(b, []byte("/")) 49 | if r.num, err = strconv.Atoi(string(items[0])); err != nil { 50 | err = fmt.Errorf("astikit: atoi of %s failed: %w", string(items[0]), err) 51 | return 52 | } 53 | if len(items) > 1 { 54 | if r.den, err = strconv.Atoi(string(items[1])); err != nil { 55 | err = fmt.Errorf("astifloat: atoi of %s failed: %w", string(items[1]), err) 56 | return 57 | } 58 | } 59 | return 60 | } 61 | -------------------------------------------------------------------------------- /float_test.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestRational(t *testing.T) { 8 | r := &Rational{} 9 | err := 
r.UnmarshalText([]byte("")) 10 | if err != nil { 11 | t.Fatalf("expected no error, got %+v", err) 12 | } 13 | if e, g := 0.0, r.ToFloat64(); e != g { 14 | t.Fatalf("expected %+v, got %+v", e, g) 15 | } 16 | err = r.UnmarshalText([]byte("test")) 17 | if err == nil { 18 | t.Fatal("expected error, got nil") 19 | } 20 | err = r.UnmarshalText([]byte("1/test")) 21 | if err == nil { 22 | t.Fatal("expected error, got nil") 23 | } 24 | err = r.UnmarshalText([]byte("0")) 25 | if err != nil { 26 | t.Fatalf("expected no error, got %+v", err) 27 | } 28 | if e, g := 0, r.Num(); e != g { 29 | t.Fatalf("expected %+v, got %+v", e, g) 30 | } 31 | if e, g := 1, r.Den(); e != g { 32 | t.Fatalf("expected %+v, got %+v", e, g) 33 | } 34 | err = r.UnmarshalText([]byte("1/2")) 35 | if err != nil { 36 | t.Fatalf("expected no error, got %+v", err) 37 | } 38 | if e, g := 1, r.Num(); e != g { 39 | t.Fatalf("expected %+v, got %+v", e, g) 40 | } 41 | if e, g := 2, r.Den(); e != g { 42 | t.Fatalf("expected %+v, got %+v", e, g) 43 | } 44 | if e, g := 0.5, r.ToFloat64(); e != g { 45 | t.Fatalf("expected %+v, got %+v", e, g) 46 | } 47 | r = NewRational(1, 2) 48 | b, err := r.MarshalText() 49 | if err != nil { 50 | t.Fatalf("expected no error, got %+v", err) 51 | } 52 | if e, g := "1/2", string(b); e != g { 53 | t.Fatalf("expected %s, got %s", e, g) 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/asticode/go-astikit 2 | 3 | go 1.18 4 | -------------------------------------------------------------------------------- /http_test.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "errors" 8 | "io" 9 | "net" 10 | "net/http" 11 | "path/filepath" 12 | "reflect" 13 | "strings" 14 | "testing" 15 | "time" 16 | ) 17 | 18 | func 
TestServeHTTP(t *testing.T) { 19 | w := NewWorker(WorkerOptions{}) 20 | ln, err := net.Listen("tcp", "127.0.0.1:") 21 | if err != nil { 22 | t.Fatalf("expected no error, got %+v", err) 23 | } 24 | ln.Close() 25 | var i int 26 | ServeHTTP(w, ServeHTTPOptions{ 27 | Addr: ln.Addr().String(), 28 | Handler: http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 29 | i++ 30 | w.Stop() 31 | }), 32 | }) 33 | s := time.Now() 34 | for { 35 | if time.Since(s) > time.Second { 36 | t.Fatal("timed out") 37 | } 38 | 39 | _, err := http.DefaultClient.Get("http://" + ln.Addr().String()) 40 | if err != nil { 41 | time.Sleep(10 * time.Millisecond) 42 | continue 43 | } 44 | break 45 | } 46 | w.Wait() 47 | if e := 1; i != e { 48 | t.Fatalf("expected %+v, got %+v", e, i) 49 | } 50 | } 51 | 52 | type mockedHTTPClient func(req *http.Request) (*http.Response, error) 53 | 54 | func (c mockedHTTPClient) Do(req *http.Request) (*http.Response, error) { return c(req) } 55 | 56 | type mockedNetError struct{ timeout bool } 57 | 58 | func (err mockedNetError) Error() string { return "" } 59 | func (err mockedNetError) Timeout() bool { return err.timeout } 60 | func (err mockedNetError) Temporary() bool { return false } 61 | 62 | type mockedHTTPBody struct { 63 | closed bool 64 | } 65 | 66 | func (b *mockedHTTPBody) Read([]byte) (int, error) { 67 | return 0, nil 68 | } 69 | 70 | func (b *mockedHTTPBody) Close() error { 71 | b.closed = true 72 | return nil 73 | } 74 | 75 | func TestHTTPSender(t *testing.T) { 76 | // All errors 77 | var c int 78 | var bs []*mockedHTTPBody 79 | s1 := NewHTTPSender(HTTPSenderOptions{ 80 | Client: mockedHTTPClient(func(req *http.Request) (resp *http.Response, err error) { 81 | c++ 82 | b := &mockedHTTPBody{} 83 | bs = append(bs, b) 84 | resp = &http.Response{ 85 | Body: b, 86 | StatusCode: http.StatusInternalServerError, 87 | } 88 | return 89 | }), 90 | RetryMax: 3, 91 | }) 92 | if _, err := s1.Send(&http.Request{}); err != nil { 93 | t.Fatalf("expected no 
error, got %+v", err) 94 | } 95 | if e := 4; c != e { 96 | t.Fatalf("expected %v, got %v", e, c) 97 | } 98 | if e, g := 4, len(bs); e != g { 99 | t.Fatalf("expected %v, got %v", e, g) 100 | } 101 | for i := 0; i < len(bs)-1; i++ { 102 | if !bs[i].closed { 103 | t.Fatalf("body #%d is not closed", i+1) 104 | } 105 | } 106 | 107 | // Successful after retries 108 | bs = []*mockedHTTPBody{} 109 | c = 0 110 | s2 := NewHTTPSender(HTTPSenderOptions{ 111 | Client: mockedHTTPClient(func(req *http.Request) (resp *http.Response, err error) { 112 | c++ 113 | switch c { 114 | case 1: 115 | b := &mockedHTTPBody{} 116 | bs = append(bs, b) 117 | resp = &http.Response{ 118 | Body: b, 119 | StatusCode: http.StatusInternalServerError, 120 | } 121 | case 2: 122 | err = mockedNetError{timeout: true} 123 | default: 124 | // No retrying 125 | b := &mockedHTTPBody{} 126 | bs = append(bs, b) 127 | resp = &http.Response{ 128 | Body: b, 129 | StatusCode: http.StatusBadRequest, 130 | } 131 | } 132 | return 133 | }), 134 | RetryMax: 3, 135 | }) 136 | if _, err := s2.Send(&http.Request{}); err != nil { 137 | t.Fatalf("expected no error, got %+v", err) 138 | } 139 | if e := 3; c != e { 140 | t.Fatalf("expected %v, got %v", e, c) 141 | } 142 | if e, g := 2, len(bs); e != g { 143 | t.Fatalf("expected %v, got %v", e, g) 144 | } 145 | for i := 0; i < len(bs)-1; i++ { 146 | if !bs[i].closed { 147 | t.Fatalf("body #%d is not closed", i+1) 148 | } 149 | } 150 | 151 | // JSON 152 | var ( 153 | ebe = "error" 154 | ebi = "body-in" 155 | ebo = "body-out" 156 | ehi = map[string]string{ 157 | "K1": "v1", 158 | "K2": "v2", 159 | } 160 | eho = http.Header{"k1": []string{"v1"}} 161 | eu = "https://domain.com/url" 162 | ) 163 | var gu, gbi string 164 | ghi := make(map[string]string) 165 | s3 := NewHTTPSender(HTTPSenderOptions{ 166 | Client: mockedHTTPClient(func(req *http.Request) (resp *http.Response, err error) { 167 | switch req.Method { 168 | case http.MethodHead: 169 | for k, v := range req.Header { 170 | 
ghi[k] = strings.Join(v, ",") 171 | } 172 | gu = req.URL.String() 173 | resp = &http.Response{ 174 | Body: io.NopCloser(&bytes.Buffer{}), 175 | Header: eho, 176 | StatusCode: http.StatusBadRequest, 177 | } 178 | case http.MethodPost: 179 | json.NewDecoder(req.Body).Decode(&gbi) //nolint:errcheck 180 | resp = &http.Response{Body: io.NopCloser(bytes.NewBuffer([]byte("\"" + ebe + "\""))), StatusCode: http.StatusBadRequest} 181 | case http.MethodGet: 182 | resp = &http.Response{Body: io.NopCloser(bytes.NewBuffer([]byte("\"" + ebo + "\""))), StatusCode: http.StatusOK} 183 | } 184 | return 185 | }), 186 | }) 187 | var gho http.Header 188 | errTest := errors.New("test") 189 | var isce HTTPSenderInvalidStatusCodeError 190 | if err := s3.SendJSON(HTTPSendJSONOptions{ 191 | HeadersIn: ehi, 192 | HeadersOut: func(h http.Header) { gho = h }, 193 | Method: http.MethodHead, 194 | StatusCodeFunc: func(code int) error { 195 | if code == http.StatusBadRequest { 196 | return errTest 197 | } 198 | return nil 199 | }, 200 | URL: eu, 201 | }); err == nil { 202 | t.Fatal("expected error, got nil") 203 | } else if !errors.Is(err, errTest) { 204 | t.Fatal("expected true, got false") 205 | } else if !errors.As(err, &isce) { 206 | t.Fatal("expected true, got false") 207 | } 208 | if e, g := (HTTPSenderInvalidStatusCodeError{ 209 | Err: errTest, 210 | StatusCode: http.StatusBadRequest, 211 | }), isce; !reflect.DeepEqual(e, g) { 212 | t.Fatalf("expected %+v, got %+v", e, g) 213 | } 214 | if !reflect.DeepEqual(ehi, ghi) { 215 | t.Fatalf("expected %+v, got %+v", ehi, ghi) 216 | } 217 | if !reflect.DeepEqual(eho, gho) { 218 | t.Fatalf("expected %+v, got %+v", eho, gho) 219 | } 220 | if gu != eu { 221 | t.Fatalf("expected %s, got %s", eu, gu) 222 | } 223 | var gbe string 224 | if err := s3.SendJSON(HTTPSendJSONOptions{ 225 | BodyError: &gbe, 226 | BodyIn: ebi, 227 | Method: http.MethodPost, 228 | }); !errors.Is(err, ErrHTTPSenderUnmarshaledError) { 229 | t.Fatalf("expected 
ErrHTTPSenderUnmarshaledError, got %s", err) 230 | } 231 | if gbe != ebe { 232 | t.Fatalf("expected %s, got %s", ebe, gbe) 233 | } 234 | if gbi != ebi { 235 | t.Fatalf("expected %s, got %s", ebi, gbi) 236 | } 237 | var gbo string 238 | if err := s3.SendJSON(HTTPSendJSONOptions{ 239 | BodyOut: &gbo, 240 | Method: http.MethodGet, 241 | }); err != nil { 242 | t.Fatalf("expected no error, got %s", err) 243 | } 244 | if gbo != ebo { 245 | t.Fatalf("expected %s, got %s", ebo, gbo) 246 | } 247 | 248 | // Timeout 249 | bs = []*mockedHTTPBody{} 250 | timeoutMockedHTTPClient := mockedHTTPClient(func(req *http.Request) (resp *http.Response, err error) { 251 | ctx, cancel := context.WithCancel(req.Context()) 252 | defer cancel() 253 | <-ctx.Done() 254 | b := &mockedHTTPBody{} 255 | bs = append(bs, b) 256 | resp = &http.Response{Body: b} 257 | return 258 | }) 259 | s4 := NewHTTPSender(HTTPSenderOptions{Client: timeoutMockedHTTPClient}) 260 | if _, err := s4.SendWithTimeout(&http.Request{}, time.Millisecond); err == nil { 261 | t.Fatal("expected error, got nil") 262 | } 263 | if err := s4.SendJSON(HTTPSendJSONOptions{Timeout: time.Millisecond}); err == nil { 264 | t.Fatal("expected error, got nil") 265 | } 266 | if e, g := 2, len(bs); e != g { 267 | t.Fatalf("expected %v, got %v", e, g) 268 | } 269 | for i, b := range bs { 270 | if !b.closed { 271 | t.Fatalf("body #%d is not closed", i+1) 272 | } 273 | } 274 | // Make sure reading response body doesn't fail if timeout is not reached 275 | if err := s3.SendJSON(HTTPSendJSONOptions{ 276 | BodyOut: &gbo, 277 | Method: http.MethodGet, 278 | Timeout: time.Hour, 279 | }); err != nil { 280 | t.Fatalf("expected no error, got %s", err) 281 | } 282 | 283 | // Context 284 | ctx, cancel := context.WithCancel(context.Background()) 285 | cancel() 286 | ctxCheckerMockedHTTPClient := mockedHTTPClient(func(req *http.Request) (resp *http.Response, err error) { 287 | return &http.Response{}, req.Context().Err() 288 | }) 289 | s5 := 
NewHTTPSender(HTTPSenderOptions{Client: ctxCheckerMockedHTTPClient}) 290 | if err := s5.SendJSON(HTTPSendJSONOptions{Context: ctx}); !errors.Is(err, context.Canceled) { 291 | t.Fatalf("expected context cancelled error, got %s", err) 292 | } 293 | } 294 | 295 | func TestHTTPDownloader(t *testing.T) { 296 | // Get temp dir 297 | dir := t.TempDir() 298 | 299 | // Create downloader 300 | d := NewHTTPDownloader(HTTPDownloaderOptions{ 301 | Limiter: GoroutineLimiterOptions{Max: 2}, 302 | Sender: HTTPSenderOptions{ 303 | Client: mockedHTTPClient(func(req *http.Request) (resp *http.Response, err error) { 304 | // In case of DownloadInWriter we want to check if the order is kept event 305 | // if downloaded order is messed up 306 | if req.URL.EscapedPath() == "/path/to/2" { 307 | time.Sleep(time.Millisecond) 308 | } 309 | resp = &http.Response{ 310 | Body: io.NopCloser(bytes.NewBufferString(req.URL.EscapedPath())), 311 | StatusCode: http.StatusOK, 312 | } 313 | return 314 | }), 315 | }, 316 | }) 317 | defer d.Close() 318 | 319 | // Download in directory 320 | err := d.DownloadInDirectory(context.Background(), dir, 321 | HTTPDownloaderSrc{URL: "/path/to/1"}, 322 | HTTPDownloaderSrc{URL: "/path/to/2"}, 323 | HTTPDownloaderSrc{URL: "/path/to/3"}, 324 | ) 325 | if err != nil { 326 | t.Fatalf("expected no error, got %+v", err) 327 | } 328 | checkDir(t, dir, map[string]string{ 329 | "/1": "/path/to/1", 330 | "/2": "/path/to/2", 331 | "/3": "/path/to/3", 332 | }) 333 | 334 | // Download in writer 335 | w := &bytes.Buffer{} 336 | err = d.DownloadInWriter(context.Background(), w, 337 | HTTPDownloaderSrc{URL: "/path/to/1"}, 338 | HTTPDownloaderSrc{URL: "/path/to/2"}, 339 | HTTPDownloaderSrc{URL: "/path/to/3"}, 340 | ) 341 | if err != nil { 342 | t.Fatalf("expected no error, got %+v", err) 343 | } 344 | if e, g := "/path/to/1/path/to/2/path/to/3", w.String(); e != g { 345 | t.Fatalf("expected %s, got %s", e, g) 346 | } 347 | 348 | // Download in file 349 | p := filepath.Join(dir, "f") 
350 | err = d.DownloadInFile(context.Background(), p, 351 | HTTPDownloaderSrc{URL: "/path/to/1"}, 352 | HTTPDownloaderSrc{URL: "/path/to/2"}, 353 | HTTPDownloaderSrc{URL: "/path/to/3"}, 354 | ) 355 | if err != nil { 356 | t.Fatalf("expected no error, got %+v", err) 357 | } 358 | checkFile(t, p, "/path/to/1/path/to/2/path/to/3") 359 | } 360 | -------------------------------------------------------------------------------- /io.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "errors" 7 | "io" 8 | "sync" 9 | "time" 10 | ) 11 | 12 | // Copy is a copy with a context 13 | func Copy(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) { 14 | return io.Copy(dst, NewCtxReader(ctx, src)) 15 | } 16 | 17 | type nopCloser struct { 18 | io.Writer 19 | } 20 | 21 | func (nopCloser) Close() error { return nil } 22 | 23 | // NopCloser returns a WriteCloser with a no-op Close method wrapping 24 | // the provided Writer w. 
25 | func NopCloser(w io.Writer) io.WriteCloser { 26 | return nopCloser{w} 27 | } 28 | 29 | // CtxReader represents a reader with a context 30 | type CtxReader struct { 31 | ctx context.Context 32 | reader io.Reader 33 | } 34 | 35 | // NewCtxReader creates a reader with a context 36 | func NewCtxReader(ctx context.Context, r io.Reader) *CtxReader { 37 | return &CtxReader{ 38 | ctx: ctx, 39 | reader: r, 40 | } 41 | } 42 | 43 | // Read implements the io.Reader interface 44 | func (r *CtxReader) Read(p []byte) (n int, err error) { 45 | // Check context 46 | if err = r.ctx.Err(); err != nil { 47 | return 48 | } 49 | 50 | // Read 51 | return r.reader.Read(p) 52 | } 53 | 54 | // WriterAdapter represents an object that can adapt a Writer 55 | type WriterAdapter struct { 56 | buffer *bytes.Buffer 57 | o WriterAdapterOptions 58 | } 59 | 60 | // WriterAdapterOptions represents WriterAdapter options 61 | type WriterAdapterOptions struct { 62 | Callback func(i []byte) 63 | Split []byte 64 | } 65 | 66 | // NewWriterAdapter creates a new WriterAdapter 67 | func NewWriterAdapter(o WriterAdapterOptions) *WriterAdapter { 68 | return &WriterAdapter{ 69 | buffer: &bytes.Buffer{}, 70 | o: o, 71 | } 72 | } 73 | 74 | // Close closes the adapter properly 75 | func (w *WriterAdapter) Close() error { 76 | if w.buffer.Len() > 0 { 77 | w.write(w.buffer.Bytes()) 78 | } 79 | return nil 80 | } 81 | 82 | // Write implements the io.Writer interface 83 | func (w *WriterAdapter) Write(i []byte) (n int, err error) { 84 | // Update n to avoid broken pipe error 85 | defer func() { 86 | n = len(i) 87 | }() 88 | 89 | // Split 90 | if len(w.o.Split) > 0 { 91 | // Split bytes are not present, write in buffer 92 | if !bytes.Contains(i, w.o.Split) { 93 | w.buffer.Write(i) 94 | return 95 | } 96 | 97 | // Loop in split items 98 | items := bytes.Split(i, w.o.Split) 99 | for i := 0; i < len(items)-1; i++ { 100 | // If this is the first item, prepend the buffer 101 | if i == 0 { 102 | items[i] = 
append(w.buffer.Bytes(), items[i]...) 103 | w.buffer.Reset() 104 | } 105 | 106 | // Write 107 | w.write(items[i]) 108 | } 109 | 110 | // Add remaining to buffer 111 | w.buffer.Write(items[len(items)-1]) 112 | return 113 | } 114 | 115 | // By default, forward the bytes 116 | w.write(i) 117 | return 118 | } 119 | 120 | func (w *WriterAdapter) write(i []byte) { 121 | if w.o.Callback != nil { 122 | w.o.Callback(i) 123 | } 124 | } 125 | 126 | // Piper doesn't block on writes. It will block on reads unless you provide a ReadTimeout in which case 127 | // it will return an optional error, after the provided timeout, if no read is available. When closing the 128 | // piper, it will interrupt any ongoing read/future writes and return io.EOF. 129 | // Piper doesn't handle multiple readers at the same time. 130 | type Piper struct { 131 | buf [][]byte 132 | c *sync.Cond 133 | closed bool 134 | o PiperOptions 135 | m sync.Mutex 136 | } 137 | 138 | type PiperOptions struct { 139 | ReadTimeout time.Duration 140 | ReadTimeoutError error 141 | } 142 | 143 | func NewPiper(o PiperOptions) *Piper { 144 | return &Piper{ 145 | c: sync.NewCond(&sync.Mutex{}), 146 | o: o, 147 | } 148 | } 149 | 150 | func (p *Piper) Close() error { 151 | // Update closed 152 | p.m.Lock() 153 | if p.closed { 154 | p.m.Unlock() 155 | return nil 156 | } 157 | p.closed = true 158 | p.m.Unlock() 159 | 160 | // Signal 161 | p.c.L.Lock() 162 | p.c.Signal() 163 | p.c.L.Unlock() 164 | return nil 165 | } 166 | 167 | func (p *Piper) Read(i []byte) (n int, err error) { 168 | // Handle read timeout 169 | var ctx context.Context 170 | if p.o.ReadTimeout > 0 { 171 | // Create context 172 | var cancel context.CancelFunc 173 | ctx, cancel = context.WithTimeout(context.Background(), p.o.ReadTimeout) 174 | defer cancel() 175 | 176 | // Watch the context in a goroutine 177 | go func() { 178 | // Wait for context to be done 179 | <-ctx.Done() 180 | 181 | // Context has timed out 182 | if errors.Is(ctx.Err(), 
context.DeadlineExceeded) { 183 | // Signal 184 | p.c.L.Lock() 185 | p.c.Signal() 186 | p.c.L.Unlock() 187 | } 188 | }() 189 | } 190 | 191 | // Loop 192 | for { 193 | // Check context 194 | if ctx != nil && ctx.Err() != nil { 195 | return 0, p.o.ReadTimeoutError 196 | } 197 | 198 | // Lock 199 | p.c.L.Lock() 200 | p.m.Lock() 201 | 202 | // Closed 203 | if p.closed { 204 | p.m.Unlock() 205 | p.c.L.Unlock() 206 | return 0, io.EOF 207 | } 208 | 209 | // Get buffer length 210 | l := len(p.buf) 211 | p.m.Unlock() 212 | 213 | // Nothing in the buffer, we need to wait 214 | if l == 0 { 215 | p.c.Wait() 216 | p.c.L.Unlock() 217 | continue 218 | } 219 | p.c.L.Unlock() 220 | 221 | // Copy 222 | p.m.Lock() 223 | n = len(p.buf[0]) 224 | copy(i, p.buf[0]) 225 | p.buf = p.buf[1:] 226 | p.m.Unlock() 227 | return 228 | } 229 | } 230 | 231 | func (p *Piper) Write(i []byte) (n int, err error) { 232 | // Closed 233 | p.m.Lock() 234 | if p.closed { 235 | p.m.Unlock() 236 | return 0, io.EOF 237 | } 238 | p.m.Unlock() 239 | 240 | // Copy 241 | b := make([]byte, len(i)) 242 | copy(b, i) 243 | 244 | // Append 245 | p.m.Lock() 246 | p.buf = append(p.buf, b) 247 | p.m.Unlock() 248 | 249 | // Signal 250 | p.c.L.Lock() 251 | p.c.Signal() 252 | p.c.L.Unlock() 253 | return len(b), nil 254 | } 255 | -------------------------------------------------------------------------------- /io_test.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "errors" 7 | "io" 8 | "reflect" 9 | "testing" 10 | "time" 11 | ) 12 | 13 | func TestCopy(t *testing.T) { 14 | // Context canceled 15 | ctx, cancel := context.WithCancel(context.Background()) 16 | cancel() 17 | r, w := bytes.NewBuffer([]byte("bla bla bla")), &bytes.Buffer{} 18 | n, err := Copy(ctx, w, r) 19 | if e := int64(0); n != e { 20 | t.Fatalf("expected %v, got %v", e, n) 21 | } 22 | if e := context.Canceled; !errors.Is(err, e) { 23 | t.Fatalf("error should 
be %+v, got %+v", e, err) 24 | } 25 | 26 | // Default 27 | n, err = Copy(context.Background(), w, r) 28 | if e := int64(11); n != e { 29 | t.Fatalf("expected %v, got %v", e, n) 30 | } 31 | if err != nil { 32 | t.Fatalf("expected no error, got %+v", err) 33 | } 34 | } 35 | 36 | func TestWriterAdapter(t *testing.T) { 37 | // Init 38 | var o []string 39 | var w = NewWriterAdapter(WriterAdapterOptions{ 40 | Callback: func(i []byte) { 41 | o = append(o, string(i)) 42 | }, 43 | Split: []byte("\n"), 44 | }) 45 | 46 | // No Split 47 | w.Write([]byte("bla bla ")) //nolint:errcheck 48 | if len(o) != 0 { 49 | t.Fatalf("expected %v, got %v", 0, len(o)) 50 | } 51 | 52 | // Multi Split 53 | w.Write([]byte("bla \nbla bla\nbla")) //nolint:errcheck 54 | if e := []string{"bla bla bla ", "bla bla"}; !reflect.DeepEqual(o, e) { 55 | t.Fatalf("expected %+v, got %+v", e, o) 56 | } 57 | 58 | // Close 59 | w.Close() 60 | if e := []string{"bla bla bla ", "bla bla", "bla"}; !reflect.DeepEqual(o, e) { 61 | t.Fatalf("expected %+v, got %+v", e, o) 62 | } 63 | } 64 | 65 | func TestPiper(t *testing.T) { 66 | p1 := NewPiper(PiperOptions{}) 67 | defer p1.Close() 68 | 69 | // Piper shouldn't block on write 70 | w := []byte("test") 71 | ctx1, cancel1 := context.WithTimeout(context.Background(), time.Second) 72 | defer cancel1() 73 | var n int 74 | var err error 75 | go func() { 76 | defer cancel1() 77 | n, err = p1.Write(w) 78 | }() 79 | <-ctx1.Done() 80 | if errCtx := ctx1.Err(); errors.Is(errCtx, context.DeadlineExceeded) { 81 | t.Fatalf("expected no deadline exceeded error, got %+v", errCtx) 82 | } 83 | if err != nil { 84 | t.Fatalf("expected no error, got %+v", err) 85 | } 86 | if e, g := 4, n; e != g { 87 | t.Fatalf("expected %d, got %d", e, g) 88 | } 89 | r := make([]byte, 10) 90 | n, err = p1.Read(r) 91 | if err != nil { 92 | t.Fatalf("expected no error, got %+v", err) 93 | } 94 | if e, g := 4, n; e != g { 95 | t.Fatalf("expected %d, got %d", e, g) 96 | } 97 | if e, g := w, r[:n]; 
!bytes.Equal(e, g) { 98 | t.Fatalf("expected %s, got %s", e, g) 99 | } 100 | 101 | // Piper should block on read unless write or piper is closed 102 | ctx2, cancel2 := context.WithTimeout(context.Background(), 100*time.Millisecond) 103 | defer cancel2() 104 | ctx3, cancel3 := context.WithTimeout(context.Background(), time.Second) 105 | defer cancel3() 106 | r = make([]byte, 10) 107 | go func() { 108 | defer cancel2() 109 | defer cancel3() 110 | n, err = p1.Read(r) 111 | }() 112 | <-ctx2.Done() 113 | if errCtx := ctx2.Err(); !errors.Is(errCtx, context.DeadlineExceeded) { 114 | t.Fatalf("expected deadline exceeded error, got %+v", errCtx) 115 | } 116 | _, errWrite := p1.Write(w) 117 | if errWrite != nil { 118 | t.Fatalf("expected no error, got %+v", errWrite) 119 | } 120 | <-ctx3.Done() 121 | if errCtx := ctx3.Err(); errors.Is(errCtx, context.DeadlineExceeded) { 122 | t.Fatalf("expected no deadline exceeded error, got %+v", errCtx) 123 | } 124 | if err != nil { 125 | t.Fatalf("expected no error, got %+v", err) 126 | } 127 | if e, g := 4, n; e != g { 128 | t.Fatalf("expected %d, got %d", e, g) 129 | } 130 | if e, g := w, r[:n]; !bytes.Equal(e, g) { 131 | t.Fatalf("expected %s, got %s", e, g) 132 | } 133 | ctx4, cancel4 := context.WithTimeout(context.Background(), 100*time.Millisecond) 134 | defer cancel4() 135 | ctx5, cancel5 := context.WithTimeout(context.Background(), time.Second) 136 | defer cancel5() 137 | go func() { 138 | defer cancel4() 139 | defer cancel5() 140 | _, err = p1.Read(r) 141 | }() 142 | <-ctx4.Done() 143 | if errCtx := ctx4.Err(); !errors.Is(errCtx, context.DeadlineExceeded) { 144 | t.Fatalf("expected deadline exceeded error, got %+v", errCtx) 145 | } 146 | p1.Close() 147 | <-ctx5.Done() 148 | if errCtx := ctx5.Err(); errors.Is(errCtx, context.DeadlineExceeded) { 149 | t.Fatalf("expected no deadline exceeded error, got %+v", errCtx) 150 | } 151 | if !errors.Is(err, io.EOF) { 152 | t.Fatalf("expected io.EOF error, got %+v", err) 153 | } 154 | _, err 
= p1.Write(w) 155 | if !errors.Is(err, io.EOF) { 156 | t.Fatalf("expected io.EOF error, got %+v", err) 157 | } 158 | 159 | // Piper should timeout on read if a read timeout is provided 160 | e1 := errors.New("1") 161 | p2 := NewPiper(PiperOptions{ 162 | ReadTimeout: time.Millisecond, 163 | ReadTimeoutError: e1, 164 | }) 165 | defer p2.Close() 166 | ctx6, cancel6 := context.WithTimeout(context.Background(), time.Second) 167 | defer cancel6() 168 | go func() { 169 | defer cancel6() 170 | _, err = p2.Read(r) 171 | }() 172 | <-ctx6.Done() 173 | if errCtx := ctx6.Err(); errors.Is(errCtx, context.DeadlineExceeded) { 174 | t.Fatalf("expected no deadline exceeded error, got %+v", errCtx) 175 | } 176 | if !errors.Is(err, e1) { 177 | t.Fatalf("expected %s, got %s", e1, err) 178 | } 179 | } 180 | -------------------------------------------------------------------------------- /ipc/posix/posix.c: -------------------------------------------------------------------------------- 1 | //go:build !windows 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | int astikit_close(int fd, int *errno_ptr) 9 | { 10 | int ret = close(fd); 11 | if (ret < 0) 12 | { 13 | *errno_ptr = errno; 14 | } 15 | return ret; 16 | } 17 | 18 | int astikit_fstat(int fd, struct stat *s, int *errno_ptr) 19 | { 20 | int ret = fstat(fd, s); 21 | if (ret < 0) 22 | { 23 | *errno_ptr = errno; 24 | } 25 | return ret; 26 | } 27 | 28 | int astikit_ftruncate(int fd, off_t length, int *errno_ptr) 29 | { 30 | int ret = ftruncate(fd, length); 31 | if (ret < 0) 32 | { 33 | *errno_ptr = errno; 34 | } 35 | return ret; 36 | } 37 | 38 | void *astikit_mmap(size_t length, int fd, int *errno_ptr) 39 | { 40 | void *addr = mmap(NULL, length, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); 41 | if (addr == MAP_FAILED) 42 | { 43 | *errno_ptr = errno; 44 | return NULL; 45 | } 46 | return addr; 47 | } 48 | 49 | int astikit_munmap(void *addr, size_t length, int *errno_ptr) 50 | { 51 | int ret = munmap(addr, length); 52 | if (ret 
< 0) 53 | { 54 | *errno_ptr = errno; 55 | } 56 | return ret; 57 | } 58 | 59 | int astikit_shm_open(char *name, int flags, mode_t mode, int *errno_ptr) 60 | { 61 | int fd = shm_open(name, flags, mode); 62 | if (fd < 0) 63 | { 64 | *errno_ptr = errno; 65 | } 66 | return fd; 67 | } 68 | 69 | int astikit_shm_unlink(char *name, int *errno_ptr) 70 | { 71 | int ret = shm_unlink(name); 72 | if (ret < 0) 73 | { 74 | *errno_ptr = errno; 75 | } 76 | return ret; 77 | } -------------------------------------------------------------------------------- /ipc/posix/posix.go: -------------------------------------------------------------------------------- 1 | //go:build !windows 2 | 3 | package astiposix 4 | 5 | //#include 6 | //#include 7 | //#include "posix.h" 8 | import "C" 9 | import ( 10 | "errors" 11 | "fmt" 12 | "os" 13 | "strconv" 14 | "strings" 15 | "sync" 16 | "syscall" 17 | "unsafe" 18 | ) 19 | 20 | type SharedMemory struct { 21 | addr unsafe.Pointer 22 | cname string 23 | fd *C.int 24 | name string 25 | size int 26 | unlink bool 27 | } 28 | 29 | func newSharedMemory(name string, flags, mode int, cb func(shm *SharedMemory) error) (shm *SharedMemory, err error) { 30 | // Create shared memory 31 | shm = &SharedMemory{name: name} 32 | 33 | // To have a similar behavior with python, we need to handle the leading slash the same way: 34 | // - make sure the "public" name has no leading "/" 35 | // - make sure the "internal" name has a leading "/" 36 | shm.name = strings.TrimPrefix(shm.name, "/") 37 | shm.cname = "/" + shm.name 38 | 39 | // Get c name 40 | cname := C.CString(shm.cname) 41 | defer C.free(unsafe.Pointer(cname)) 42 | 43 | // Get file descriptor 44 | var errno C.int 45 | fd := C.astikit_shm_open(cname, C.int(flags), C.mode_t(mode), &errno) 46 | if fd < 0 { 47 | err = fmt.Errorf("astikit: shm_open failed: %w", syscall.Errno(errno)) 48 | return 49 | } 50 | shm.fd = &fd 51 | 52 | // Make sure to close shared memory in case of error 53 | defer func() { 54 | if err != nil 
{ 55 | shm.Close() 56 | } 57 | }() 58 | 59 | // Callback 60 | if cb != nil { 61 | if err = cb(shm); err != nil { 62 | err = fmt.Errorf("astikit: callback failed: %w", err) 63 | return 64 | } 65 | } 66 | 67 | // Get size 68 | var stat C.struct_stat 69 | if ret := C.astikit_fstat(*shm.fd, &stat, &errno); ret < 0 { 70 | err = fmt.Errorf("astikit: fstat failed: %w", syscall.Errno(errno)) 71 | return 72 | } 73 | shm.size = int(stat.st_size) 74 | 75 | // Map memory 76 | addr := C.astikit_mmap(C.size_t(shm.size), *shm.fd, &errno) 77 | if addr == nil { 78 | err = fmt.Errorf("astikit: mmap failed: %w", syscall.Errno(errno)) 79 | return 80 | } 81 | 82 | // Update addr 83 | shm.addr = addr 84 | return 85 | } 86 | 87 | func CreateSharedMemory(name string, size int) (*SharedMemory, error) { 88 | return newSharedMemory(name, os.O_CREATE|os.O_EXCL|os.O_RDWR, 0600, func(shm *SharedMemory) (err error) { 89 | // Shared memory needs to be unlink on close 90 | shm.unlink = true 91 | 92 | // Truncate 93 | var errno C.int 94 | if ret := C.astikit_ftruncate(*shm.fd, C.off_t(size), &errno); ret < 0 { 95 | err = fmt.Errorf("astikit: ftruncate failed: %w", syscall.Errno(errno)) 96 | return 97 | } 98 | return 99 | }) 100 | } 101 | 102 | func OpenSharedMemory(name string) (*SharedMemory, error) { 103 | return newSharedMemory(name, os.O_RDWR, 0600, nil) 104 | } 105 | 106 | func (shm *SharedMemory) Close() error { 107 | // Unlink 108 | if shm.unlink { 109 | // Get c name 110 | cname := C.CString(shm.cname) 111 | defer C.free(unsafe.Pointer(cname)) 112 | 113 | // Unlink 114 | var errno C.int 115 | if ret := C.astikit_shm_unlink(cname, &errno); ret < 0 { 116 | return fmt.Errorf("astikit: unlink failed: %w", syscall.Errno(errno)) 117 | } 118 | shm.unlink = false 119 | } 120 | 121 | // Unmap memory 122 | if shm.addr != nil { 123 | var errno C.int 124 | if ret := C.astikit_munmap(shm.addr, C.size_t(shm.size), &errno); ret < 0 { 125 | return fmt.Errorf("astikit: munmap failed: %w", 
syscall.Errno(errno)) 126 | } 127 | shm.addr = nil 128 | } 129 | 130 | // Close file descriptor 131 | if shm.fd != nil { 132 | var errno C.int 133 | if ret := C.astikit_close(*shm.fd, &errno); ret < 0 { 134 | return fmt.Errorf("astikit: close failed: %w", syscall.Errno(errno)) 135 | } 136 | shm.fd = nil 137 | } 138 | return nil 139 | } 140 | 141 | func (shm *SharedMemory) Write(src unsafe.Pointer, size int) error { 142 | // Unmapped 143 | if shm.addr == nil { 144 | return errors.New("astikit: shared memory is unmapped") 145 | } 146 | 147 | // Copy 148 | C.memcpy(shm.addr, src, C.size_t(size)) 149 | return nil 150 | } 151 | 152 | func (shm *SharedMemory) WriteBytes(b []byte) error { 153 | // Get c bytes 154 | cb := C.CBytes(b) 155 | defer C.free(cb) 156 | 157 | // Write 158 | return shm.Write(cb, len(b)) 159 | } 160 | 161 | func (shm *SharedMemory) ReadBytes(size int) ([]byte, error) { 162 | // Unmapped 163 | if shm.addr == nil { 164 | return nil, errors.New("astikit: shared memory is unmapped") 165 | } 166 | 167 | // Get bytes 168 | return C.GoBytes(shm.addr, C.int(size)), nil 169 | } 170 | 171 | func (shm *SharedMemory) Name() string { 172 | return shm.name 173 | } 174 | 175 | func (shm *SharedMemory) Size() int { 176 | return shm.size 177 | } 178 | 179 | func (shm *SharedMemory) Addr() unsafe.Pointer { 180 | return shm.addr 181 | } 182 | 183 | type VariableSizeSharedMemoryWriter struct { 184 | m sync.Mutex // Locks write operations 185 | prefix string 186 | shm *SharedMemory 187 | } 188 | 189 | func NewVariableSizeSharedMemoryWriter(prefix string) *VariableSizeSharedMemoryWriter { 190 | return &VariableSizeSharedMemoryWriter{prefix: prefix} 191 | } 192 | 193 | func (w *VariableSizeSharedMemoryWriter) closeSharedMemory() { 194 | if w.shm != nil { 195 | w.shm.Close() 196 | } 197 | } 198 | 199 | func (w *VariableSizeSharedMemoryWriter) Close() { 200 | w.closeSharedMemory() 201 | } 202 | 203 | func (w *VariableSizeSharedMemoryWriter) Write(src unsafe.Pointer, size 
int) (ro VariableSizeSharedMemoryReadOptions, err error) { 204 | // Lock 205 | w.m.Lock() 206 | defer w.m.Unlock() 207 | 208 | // Shared memory has not yet been created or previous shared memory segment is too small 209 | if w.shm == nil || size > w.shm.Size() { 210 | // Close previous shared memory 211 | w.closeSharedMemory() 212 | 213 | // Create shared memory 214 | var shm *SharedMemory 215 | if shm, err = CreateSharedMemory(w.prefix+"-"+strconv.Itoa(size), size); err != nil { 216 | err = fmt.Errorf("astikit: creating shared memory failed: %w", err) 217 | return 218 | } 219 | 220 | // Store shared memory 221 | w.shm = shm 222 | } 223 | 224 | // Write 225 | if err = w.shm.Write(src, size); err != nil { 226 | err = fmt.Errorf("astikit: writing to shared memory failed: %w", err) 227 | return 228 | } 229 | 230 | // Create read options 231 | ro = VariableSizeSharedMemoryReadOptions{ 232 | Name: w.shm.Name(), 233 | Size: size, 234 | } 235 | return 236 | } 237 | 238 | func (w *VariableSizeSharedMemoryWriter) WriteBytes(b []byte) (VariableSizeSharedMemoryReadOptions, error) { 239 | // Get c bytes 240 | cb := C.CBytes(b) 241 | defer C.free(cb) 242 | 243 | // Write 244 | return w.Write(cb, len(b)) 245 | } 246 | 247 | type VariableSizeSharedMemoryReader struct { 248 | m sync.Mutex // Locks read operations 249 | shm *SharedMemory 250 | } 251 | 252 | func NewVariableSizeSharedMemoryReader() *VariableSizeSharedMemoryReader { 253 | return &VariableSizeSharedMemoryReader{} 254 | } 255 | 256 | func (r *VariableSizeSharedMemoryReader) closeSharedMemory() { 257 | if r.shm != nil { 258 | r.shm.Close() 259 | } 260 | } 261 | 262 | func (r *VariableSizeSharedMemoryReader) Close() { 263 | r.closeSharedMemory() 264 | } 265 | 266 | type VariableSizeSharedMemoryReadOptions struct { 267 | Name string `json:"name"` 268 | Size int `json:"size"` 269 | } 270 | 271 | func (r *VariableSizeSharedMemoryReader) ReadBytes(o VariableSizeSharedMemoryReadOptions) (b []byte, err error) { 272 | // Lock 
273 | r.m.Lock() 274 | defer r.m.Unlock() 275 | 276 | // Shared memory has not yet been opened or shared memory's name has changed 277 | if r.shm == nil || r.shm.Name() != o.Name { 278 | // Close previous shared memory 279 | r.closeSharedMemory() 280 | 281 | // Open shared memory 282 | var shm *SharedMemory 283 | if shm, err = OpenSharedMemory(o.Name); err != nil { 284 | err = fmt.Errorf("astikit: opening shared memory failed: %w", err) 285 | return 286 | } 287 | 288 | // Store attributes 289 | r.shm = shm 290 | } 291 | 292 | // Copy 293 | b = make([]byte, o.Size) 294 | C.memcpy(unsafe.Pointer(&b[0]), r.shm.Addr(), C.size_t(o.Size)) 295 | return 296 | } 297 | -------------------------------------------------------------------------------- /ipc/posix/posix.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | int astikit_close(int fd, int *errno_ptr); 5 | int astikit_fstat(int fd, struct stat *s, int *errno_ptr); 6 | int astikit_ftruncate(int fd, off_t length, int *errno_ptr); 7 | void *astikit_mmap(size_t length, int fd, int *errno_ptr); 8 | int astikit_munmap(void *addr, size_t length, int *errno_ptr); 9 | int astikit_shm_open(char *name, int flags, mode_t mode, int *errno_ptr); 10 | int astikit_shm_unlink(char *name, int *errno_ptr); -------------------------------------------------------------------------------- /ipc/posix/posix_test.go: -------------------------------------------------------------------------------- 1 | //go:build !windows 2 | 3 | package astiposix 4 | 5 | import ( 6 | "bytes" 7 | "testing" 8 | ) 9 | 10 | func TestSharedMemory(t *testing.T) { 11 | sm1, err := CreateSharedMemory("/test", 8) 12 | if err != nil { 13 | t.Fatalf("expected no error, got %s", err) 14 | } 15 | defer sm1.Close() 16 | if sm1.Addr() == nil { 17 | t.Fatal("expected not nil, got nil") 18 | } 19 | if g := sm1.Size(); g <= 0 { 20 | t.Fatalf("expected > 0, got %d", g) 21 | } 22 | if e, g := "test", sm1.Name(); e 
!= g { 23 | t.Fatalf("expected %v, got %v", e, g) 24 | } 25 | if _, err = CreateSharedMemory("/test", 8); err == nil { 26 | t.Fatal("expected error, got nil") 27 | } 28 | 29 | b1 := []byte("test") 30 | if err := sm1.WriteBytes(b1); err != nil { 31 | t.Fatalf("expected no error, got %s", err) 32 | } 33 | 34 | sm2, err := OpenSharedMemory("test") 35 | if err != nil { 36 | t.Fatalf("expected no error, got %s", err) 37 | } 38 | defer sm2.Close() 39 | b2, err := sm2.ReadBytes(len(b1)) 40 | if err != nil { 41 | t.Fatalf("expected no error, got %s", err) 42 | } 43 | if e, g := b1, b2; !bytes.Equal(b1, b2) { 44 | t.Fatalf("expected %s, got %s", e, g) 45 | } 46 | 47 | if err = sm1.Close(); err != nil { 48 | t.Fatalf("expected no error, got %s", err) 49 | } 50 | if err = sm1.WriteBytes(b1); err == nil { 51 | t.Fatal("expected error, got nil") 52 | } 53 | if err = sm1.Close(); err != nil { 54 | t.Fatalf("expected no error, got %s", err) 55 | } 56 | 57 | if err = sm2.Close(); err != nil { 58 | t.Fatalf("expected no error, got %s", err) 59 | } 60 | if _, err = sm2.ReadBytes(len(b1)); err == nil { 61 | t.Fatal("expected error, got nil") 62 | } 63 | if err = sm2.Close(); err != nil { 64 | t.Fatalf("expected no error, got %s", err) 65 | } 66 | } 67 | 68 | func TestVariableSizeSharedMemory(t *testing.T) { 69 | w := NewVariableSizeSharedMemoryWriter("test-1") 70 | defer w.Close() 71 | r := NewVariableSizeSharedMemoryReader() 72 | defer r.Close() 73 | 74 | b1 := []byte("test") 75 | ro1, err := w.WriteBytes(b1) 76 | if err != nil { 77 | t.Fatalf("expected no error, got %s", err) 78 | } 79 | if e, g := w.shm.Name(), ro1.Name; e != g { 80 | t.Fatalf("expected %s, got %s", e, g) 81 | } 82 | if e, g := len(b1), ro1.Size; e != g { 83 | t.Fatalf("expected %d, got %d", e, g) 84 | } 85 | b2, err := r.ReadBytes(ro1) 86 | if err != nil { 87 | t.Fatalf("expected no error, got %s", err) 88 | } 89 | if !bytes.Equal(b1, b2) { 90 | t.Fatalf("expected %s, got %s", b1, b2) 91 | } 92 | 93 | b3 := 
make([]byte, w.shm.Size()+1) 94 | b3[0] = 'a' 95 | b3[len(b3)-1] = 'b' 96 | ro2, err := w.WriteBytes(b3) 97 | if err != nil { 98 | t.Fatalf("expected no error, got %s", err) 99 | } 100 | if ro1.Name == ro2.Name { 101 | t.Fatal("expected different, got equalt") 102 | } 103 | b4, err := r.ReadBytes(ro2) 104 | if err != nil { 105 | t.Fatalf("expected no error, got %s", err) 106 | } 107 | if !bytes.Equal(b3, b4) { 108 | t.Fatalf("expected %s, got %s", b3, b4) 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /ipc/systemv/systemv.c: -------------------------------------------------------------------------------- 1 | //go:build !windows 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | int astikit_ftok(char *path, int project_id, int *errno_ptr) 9 | { 10 | int key = ftok(path, project_id); 11 | if (key < 0) 12 | { 13 | *errno_ptr = errno; 14 | } 15 | return key; 16 | } 17 | 18 | int astikit_sem_get(key_t key, int flags, int *errno_ptr) 19 | { 20 | int id = semget(key, 1, flags); 21 | if (id < 0) 22 | { 23 | *errno_ptr = errno; 24 | } 25 | return id; 26 | } 27 | 28 | int astikit_sem_close(int id, int *errno_ptr) 29 | { 30 | int ret = semctl(id, 0, IPC_RMID); 31 | if (ret < 0) 32 | { 33 | *errno_ptr = errno; 34 | } 35 | return ret; 36 | } 37 | 38 | // "0" means the resource is free 39 | // "1" means the resource is being used 40 | 41 | int astikit_sem_lock(int id, int *errno_ptr) 42 | { 43 | struct sembuf operations[2]; 44 | 45 | // Wait for the value to be 0 46 | operations[0].sem_num = 0; 47 | operations[0].sem_op = 0; 48 | operations[0].sem_flg = 0; 49 | 50 | // Increment the value 51 | operations[1].sem_num = 0; 52 | operations[1].sem_op = 1; 53 | operations[1].sem_flg = 0; 54 | 55 | int ret = semop(id, operations, 2); 56 | if (ret < 0) 57 | { 58 | *errno_ptr = errno; 59 | } 60 | return ret; 61 | } 62 | 63 | int astikit_sem_unlock(int id, int *errno_ptr) 64 | { 65 | struct sembuf operations[1]; 66 | 67 
| // Decrement the value 68 | operations[0].sem_num = 0; 69 | operations[0].sem_op = -1; 70 | operations[0].sem_flg = 0; 71 | 72 | int ret = semop(id, operations, 1); 73 | if (ret < 0) 74 | { 75 | *errno_ptr = errno; 76 | } 77 | return ret; 78 | } 79 | 80 | int astikit_shm_get(key_t key, int size, int flags, int *errno_ptr) 81 | { 82 | int id = shmget(key, size, flags); 83 | if (id < 0) 84 | { 85 | *errno_ptr = errno; 86 | } 87 | return id; 88 | } 89 | 90 | void *astikit_shm_at(int id, int *errno_ptr) 91 | { 92 | void *addr = shmat(id, NULL, 0); 93 | if (addr == (void *)-1) 94 | { 95 | *errno_ptr = errno; 96 | return NULL; 97 | } 98 | return addr; 99 | } 100 | 101 | int astikit_shm_close(int id, const void *addr, int *errno_ptr) 102 | { 103 | int ret; 104 | if (addr != NULL) 105 | { 106 | ret = shmdt(addr); 107 | if (ret < 0) 108 | { 109 | *errno_ptr = errno; 110 | return ret; 111 | } 112 | } 113 | ret = shmctl(id, IPC_RMID, NULL); 114 | if (ret < 0) 115 | { 116 | *errno_ptr = errno; 117 | } 118 | return ret; 119 | } -------------------------------------------------------------------------------- /ipc/systemv/systemv.go: -------------------------------------------------------------------------------- 1 | //go:build !windows 2 | 3 | package astisystemv 4 | 5 | //#include 6 | //#include 7 | //#include 8 | //#include "systemv.h" 9 | import "C" 10 | import ( 11 | "errors" 12 | "fmt" 13 | "sync" 14 | "syscall" 15 | "time" 16 | "unsafe" 17 | 18 | "github.com/asticode/go-astikit" 19 | ) 20 | 21 | func NewKey(projectID int, path string) (int, error) { 22 | // Get c path 23 | cpath := C.CString(path) 24 | defer C.free(unsafe.Pointer(cpath)) 25 | 26 | // Get key 27 | var errno C.int 28 | key := C.astikit_ftok(cpath, C.int(projectID), &errno) 29 | if key < 0 { 30 | return 0, fmt.Errorf("astikit: ftok failed: %s", syscall.Errno(errno)) 31 | } 32 | return int(key), nil 33 | } 34 | 35 | const ( 36 | IpcCreate = C.IPC_CREAT 37 | IpcExclusive = C.IPC_EXCL 38 | ) 39 | 40 | type 
Semaphore struct { 41 | id C.int 42 | key int 43 | } 44 | 45 | func newSemaphore(key int, flags int) (*Semaphore, error) { 46 | // Get id 47 | var errno C.int 48 | id := C.astikit_sem_get(C.int(key), C.int(flags), &errno) 49 | if id < 0 { 50 | return nil, fmt.Errorf("astikit: sem_get failed: %w", syscall.Errno(errno)) 51 | } 52 | return &Semaphore{ 53 | id: id, 54 | key: key, 55 | }, nil 56 | } 57 | 58 | func CreateSemaphore(key, flags int) (*Semaphore, error) { 59 | return newSemaphore(key, flags) 60 | } 61 | 62 | func OpenSemaphore(key int) (*Semaphore, error) { 63 | return newSemaphore(key, 0) 64 | } 65 | 66 | func (s *Semaphore) Close() error { 67 | // Already closed 68 | if s.id == -1 { 69 | return nil 70 | } 71 | 72 | // Close 73 | var errno C.int 74 | if ret := C.astikit_sem_close(s.id, &errno); ret < 0 { 75 | return fmt.Errorf("astikit: sem_close failed: %w", syscall.Errno(errno)) 76 | } 77 | 78 | // Update 79 | s.id = -1 80 | s.key = -1 81 | return nil 82 | } 83 | 84 | func (s *Semaphore) Lock() error { 85 | // Closed 86 | if s.id == -1 { 87 | return errors.New("astikit: semaphore is closed") 88 | } 89 | 90 | // Lock 91 | var errno C.int 92 | ret := C.astikit_sem_lock(s.id, &errno) 93 | if ret < 0 { 94 | return fmt.Errorf("astikit: sem_lock failed: %w", syscall.Errno(errno)) 95 | } 96 | return nil 97 | } 98 | 99 | func (s *Semaphore) Unlock() error { 100 | // Closed 101 | if s.id == -1 { 102 | return errors.New("astikit: semaphore is closed") 103 | } 104 | 105 | // Unlock 106 | var errno C.int 107 | ret := C.astikit_sem_unlock(s.id, &errno) 108 | if ret < 0 { 109 | return fmt.Errorf("astikit: sem_unlock failed: %w", syscall.Errno(errno)) 110 | } 111 | return nil 112 | } 113 | 114 | func (s *Semaphore) Key() int { 115 | return s.key 116 | } 117 | 118 | type SharedMemory struct { 119 | addr unsafe.Pointer 120 | id C.int 121 | key int 122 | } 123 | 124 | func newSharedMemory(key, size int, flags int) (shm *SharedMemory, err error) { 125 | // Get id 126 | var 
errno C.int 127 | id := C.astikit_shm_get(C.int(key), C.int(size), C.int(flags), &errno) 128 | if id < 0 { 129 | err = fmt.Errorf("astikit: shm_get failed: %w", syscall.Errno(errno)) 130 | return 131 | } 132 | 133 | // Create shared memory 134 | shm = &SharedMemory{ 135 | id: id, 136 | key: key, 137 | } 138 | 139 | // Make sure to close shared memory in case of error 140 | defer func() { 141 | if err != nil { 142 | shm.Close() 143 | } 144 | }() 145 | 146 | // Attach 147 | addr := C.astikit_shm_at(C.int(id), &errno) 148 | if addr == nil { 149 | err = fmt.Errorf("astikit: shm_at failed: %w", syscall.Errno(errno)) 150 | return 151 | } 152 | 153 | // Update addr 154 | shm.addr = addr 155 | return 156 | } 157 | 158 | func CreateSharedMemory(key, size, flags int) (*SharedMemory, error) { 159 | return newSharedMemory(key, size, flags) 160 | } 161 | 162 | func OpenSharedMemory(key int) (*SharedMemory, error) { 163 | return newSharedMemory(key, 0, 0) 164 | } 165 | 166 | func (shm *SharedMemory) Close() error { 167 | // Already closed 168 | if shm.id == -1 { 169 | return nil 170 | } 171 | 172 | // Close 173 | var errno C.int 174 | if ret := C.astikit_shm_close(shm.id, shm.addr, &errno); ret < 0 { 175 | return fmt.Errorf("astikit: shm_close failed: %w", syscall.Errno(errno)) 176 | } 177 | 178 | // Update 179 | shm.addr = nil 180 | shm.id = -1 181 | shm.key = -1 182 | return nil 183 | } 184 | 185 | func (shm *SharedMemory) Write(src unsafe.Pointer, size int) error { 186 | // Closed 187 | if shm.id == -1 { 188 | return errors.New("astikit: shared memory is closed") 189 | } 190 | 191 | // Copy 192 | C.memcpy(shm.addr, src, C.size_t(size)) 193 | return nil 194 | } 195 | 196 | func (shm *SharedMemory) WriteBytes(b []byte) error { 197 | // Get c bytes 198 | cb := C.CBytes(b) 199 | defer C.free(cb) 200 | 201 | // Write 202 | return shm.Write(cb, len(b)) 203 | } 204 | 205 | func (shm *SharedMemory) Addr() unsafe.Pointer { 206 | return shm.addr 207 | } 208 | 209 | func (shm 
*SharedMemory) Key() int { 210 | return shm.key 211 | } 212 | 213 | func (shm *SharedMemory) ReadBytes(size int) ([]byte, error) { 214 | // Closed 215 | if shm.id == -1 { 216 | return nil, errors.New("astikit: shared memory is closed") 217 | } 218 | 219 | // Get bytes 220 | return C.GoBytes(shm.addr, C.int(size)), nil 221 | } 222 | 223 | type SemaphoredSharedMemoryWriter struct { 224 | m sync.Mutex // Locks write operations 225 | sem *Semaphore 226 | shm *SharedMemory 227 | shmAt int64 228 | shmSize int 229 | } 230 | 231 | func NewSemaphoredSharedMemoryWriter() *SemaphoredSharedMemoryWriter { 232 | return &SemaphoredSharedMemoryWriter{} 233 | } 234 | 235 | func (w *SemaphoredSharedMemoryWriter) closeSemaphore() { 236 | if w.sem != nil { 237 | w.sem.Close() 238 | } 239 | } 240 | 241 | func (w *SemaphoredSharedMemoryWriter) closeSharedMemory() { 242 | if w.shm != nil { 243 | w.shm.Close() 244 | } 245 | } 246 | 247 | func (w *SemaphoredSharedMemoryWriter) Close() { 248 | w.closeSemaphore() 249 | w.closeSharedMemory() 250 | } 251 | 252 | func (w *SemaphoredSharedMemoryWriter) generateRandomKey(f func(key int) error) error { 253 | try := 0 254 | for { 255 | key := int(int32(astikit.RandSource.Int63())) 256 | if key == int(C.IPC_PRIVATE) { 257 | continue 258 | } 259 | err := f(key) 260 | if errors.Is(err, syscall.EEXIST) { 261 | if try++; try < 10000 { 262 | continue 263 | } 264 | return errors.New("astikit: max tries reached") 265 | } 266 | return err 267 | } 268 | } 269 | 270 | func (w *SemaphoredSharedMemoryWriter) Write(src unsafe.Pointer, size int) (ro *SemaphoredSharedMemoryReadOptions, err error) { 271 | // Lock 272 | w.m.Lock() 273 | defer w.m.Unlock() 274 | 275 | // Shared memory has not been created or previous shared memory segment is too small, 276 | // we need to allocate a new shared memory segment 277 | if w.shm == nil || size > w.shmSize { 278 | // Close previous shared memory 279 | w.closeSharedMemory() 280 | 281 | // Generate random key 282 | if err = 
w.generateRandomKey(func(key int) (err error) { 283 | // Create shared memory 284 | var shm *SharedMemory 285 | if shm, err = CreateSharedMemory(key, size, IpcCreate|IpcExclusive|0666); err != nil { 286 | err = fmt.Errorf("astikit: creating shared memory failed: %w", err) 287 | return 288 | } 289 | 290 | // Store attributes 291 | w.shm = shm 292 | w.shmAt = time.Now().UnixNano() 293 | w.shmSize = size 294 | return 295 | }); err != nil { 296 | err = fmt.Errorf("astikit: generating random key failed: %w", err) 297 | return 298 | } 299 | } 300 | 301 | // Semaphore has not been created 302 | if w.sem == nil { 303 | // Generate random key 304 | if err = w.generateRandomKey(func(key int) (err error) { 305 | // Create semaphore 306 | var sem *Semaphore 307 | if sem, err = CreateSemaphore(key, IpcCreate|IpcExclusive|0666); err != nil { 308 | err = fmt.Errorf("astikit: creating semaphore failed: %w", err) 309 | return 310 | } 311 | 312 | // Store attributes 313 | w.sem = sem 314 | return 315 | }); err != nil { 316 | err = fmt.Errorf("astikit: generating random key failed: %w", err) 317 | return 318 | } 319 | } 320 | 321 | // Lock 322 | if err = w.sem.Lock(); err != nil { 323 | err = fmt.Errorf("astikit: locking semaphore failed: %w", err) 324 | return 325 | } 326 | 327 | // Write 328 | if err = w.shm.Write(src, size); err != nil { 329 | err = fmt.Errorf("astikit: writing to shared memory failed: %w", err) 330 | return 331 | } 332 | 333 | // Unlock 334 | if err = w.sem.Unlock(); err != nil { 335 | err = fmt.Errorf("astikit: unlocking semaphore failed: %w", err) 336 | return 337 | } 338 | 339 | // Create read options 340 | ro = &SemaphoredSharedMemoryReadOptions{ 341 | SemaphoreKey: w.sem.Key(), 342 | SharedMemoryAt: w.shmAt, 343 | SharedMemoryKey: w.shm.Key(), 344 | Size: size, 345 | } 346 | return 347 | } 348 | 349 | func (w *SemaphoredSharedMemoryWriter) WriteBytes(b []byte) (*SemaphoredSharedMemoryReadOptions, error) { 350 | // Get c bytes 351 | cb := C.CBytes(b) 352 | 
defer C.free(cb) 353 | 354 | // Write 355 | return w.Write(cb, len(b)) 356 | } 357 | 358 | type SemaphoredSharedMemoryReader struct { 359 | m sync.Mutex // Locks read operations 360 | sem *Semaphore 361 | shm *SharedMemory 362 | shmAt int64 363 | } 364 | 365 | func NewSemaphoredSharedMemoryReader() *SemaphoredSharedMemoryReader { 366 | return &SemaphoredSharedMemoryReader{} 367 | } 368 | 369 | func (r *SemaphoredSharedMemoryReader) closeSemaphore() { 370 | if r.sem != nil { 371 | r.sem.Close() 372 | } 373 | } 374 | 375 | func (r *SemaphoredSharedMemoryReader) closeSharedMemory() { 376 | if r.shm != nil { 377 | r.shm.Close() 378 | } 379 | } 380 | 381 | func (r *SemaphoredSharedMemoryReader) Close() { 382 | r.closeSemaphore() 383 | r.closeSharedMemory() 384 | } 385 | 386 | type SemaphoredSharedMemoryReadOptions struct { 387 | SemaphoreKey int 388 | SharedMemoryAt int64 389 | SharedMemoryKey int 390 | Size int 391 | } 392 | 393 | func (r *SemaphoredSharedMemoryReader) ReadBytes(o *SemaphoredSharedMemoryReadOptions) (b []byte, err error) { 394 | // Lock 395 | r.m.Lock() 396 | defer r.m.Unlock() 397 | 398 | // Shared memory is not opened or shared memory has changed 399 | if r.shm == nil || r.shm.Key() != o.SharedMemoryKey || r.shmAt != o.SharedMemoryAt { 400 | // Close previous shared memory 401 | r.closeSharedMemory() 402 | 403 | // Open shared memory 404 | var shm *SharedMemory 405 | if shm, err = OpenSharedMemory(o.SharedMemoryKey); err != nil { 406 | err = fmt.Errorf("astikit: opening shared memory failed: %w", err) 407 | return 408 | } 409 | 410 | // Store attributes 411 | r.shm = shm 412 | r.shmAt = o.SharedMemoryAt 413 | } 414 | 415 | // Semaphore is not opened 416 | if r.sem == nil { 417 | // Close previous semaphore 418 | r.closeSemaphore() 419 | 420 | // Open semaphore 421 | var sem *Semaphore 422 | if sem, err = OpenSemaphore(o.SemaphoreKey); err != nil { 423 | err = fmt.Errorf("astikit: opening semaphore failed: %w", err) 424 | return 425 | } 426 | 427 | // 
Store attributes 428 | r.sem = sem 429 | } 430 | 431 | // Lock 432 | if err = r.sem.Lock(); err != nil { 433 | err = fmt.Errorf("astikit: locking semaphore failed: %w", err) 434 | return 435 | } 436 | 437 | // Copy 438 | b = make([]byte, o.Size) 439 | C.memcpy(unsafe.Pointer(&b[0]), r.shm.Addr(), C.size_t(o.Size)) 440 | 441 | // Unlock 442 | if err = r.sem.Unlock(); err != nil { 443 | err = fmt.Errorf("astikit: unlocking semaphore failed: %w", err) 444 | return 445 | } 446 | return 447 | } 448 | -------------------------------------------------------------------------------- /ipc/systemv/systemv.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | int astikit_ftok(char *path, int project_id, int *errno_ptr); 4 | int astikit_sem_get(key_t key, int flags, int *errno_ptr); 5 | int astikit_sem_close(int id, int *errno_ptr); 6 | int astikit_sem_lock(int id, int *errno_ptr); 7 | int astikit_sem_unlock(int id, int *errno_ptr); 8 | void *astikit_shm_at(int id, int *errno_ptr); 9 | int astikit_shm_get(key_t key, int size, int flags, int *errno_ptr); 10 | int astikit_shm_close(int id, const void *addr, int *errno_ptr); -------------------------------------------------------------------------------- /ipc/systemv/systemv_test.go: -------------------------------------------------------------------------------- 1 | //go:build !windows 2 | 3 | package astisystemv 4 | 5 | import ( 6 | "bytes" 7 | "testing" 8 | ) 9 | 10 | func TestNewKey(t *testing.T) { 11 | _, err := NewKey(1, "../../testdata/ipc/invalid") 12 | if err == nil { 13 | t.Fatal("expected an error, got none") 14 | } 15 | if _, err = NewKey(1, "../../testdata/ipc/f"); err != nil { 16 | t.Fatalf("expected no error, got %s", err) 17 | } 18 | } 19 | 20 | func TestSemaphore(t *testing.T) { 21 | const key = 1 22 | s1, err := CreateSemaphore(key, IpcCreate|IpcExclusive|0666) 23 | if err != nil { 24 | t.Fatalf("expected no error, got %s", err) 25 | } 26 | defer s1.Close() 27 | 
if e, g := key, s1.Key(); e != g { 28 | t.Fatalf("expected %v, got %v", e, g) 29 | } 30 | if err = s1.Lock(); err != nil { 31 | t.Fatalf("expected no error, got %s", err) 32 | } 33 | if err = s1.Unlock(); err != nil { 34 | t.Fatalf("expected no error, got %s", err) 35 | } 36 | s2, err := OpenSemaphore(key) 37 | if err != nil { 38 | t.Fatalf("expected no error, got %s", err) 39 | } 40 | defer s2.Close() 41 | if e, g := key, s2.Key(); e != g { 42 | t.Fatalf("expected %v, got %v", e, g) 43 | } 44 | if err = s2.Lock(); err != nil { 45 | t.Fatalf("expected no error, got %s", err) 46 | } 47 | if err = s2.Unlock(); err != nil { 48 | t.Fatalf("expected no error, got %s", err) 49 | } 50 | if err = s1.Close(); err != nil { 51 | t.Fatalf("expected no error, got %s", err) 52 | } 53 | if err = s1.Lock(); err == nil { 54 | t.Fatal("expected error, got nil") 55 | } 56 | if err = s1.Unlock(); err == nil { 57 | t.Fatal("expected error, got nil") 58 | } 59 | if err = s1.Close(); err != nil { 60 | t.Fatalf("expected no error, got %s", err) 61 | } 62 | if err = s2.Close(); err == nil { 63 | t.Fatal("expected error, got nil") 64 | } 65 | if err = s2.Lock(); err == nil { 66 | t.Fatal("expected error, got nil") 67 | } 68 | if err = s2.Unlock(); err == nil { 69 | t.Fatal("expected error, got nil") 70 | } 71 | } 72 | 73 | func TestSharedMemory(t *testing.T) { 74 | const key = 1 75 | sm1, err := CreateSharedMemory(key, 10, IpcCreate|IpcExclusive|0666) 76 | if err != nil { 77 | t.Fatalf("expected no error, got %s", err) 78 | } 79 | defer sm1.Close() 80 | if sm1.Addr() == nil { 81 | t.Fatal("expected not nil, got nil") 82 | } 83 | if e, g := key, sm1.Key(); e != g { 84 | t.Fatalf("expected %v, got %v", e, g) 85 | } 86 | b1 := []byte("test") 87 | if err := sm1.WriteBytes(b1); err != nil { 88 | t.Fatalf("expected no error, got %s", err) 89 | } 90 | sm2, err := OpenSharedMemory(key) 91 | if err != nil { 92 | t.Fatalf("expected no error, got %s", err) 93 | } 94 | defer sm2.Close() 95 | b2, err := 
sm2.ReadBytes(len(b1)) 96 | if err != nil { 97 | t.Fatalf("expected no error, got %s", err) 98 | } 99 | if e, g := b1, b2; !bytes.Equal(b1, b2) { 100 | t.Fatalf("expected %s, got %s", e, g) 101 | } 102 | if err = sm1.Close(); err != nil { 103 | t.Fatalf("expected no error, got %s", err) 104 | } 105 | if err = sm1.WriteBytes(b1); err == nil { 106 | t.Fatal("expected error, got nil") 107 | } 108 | if err = sm1.Close(); err != nil { 109 | t.Fatalf("expected no error, got %s", err) 110 | } 111 | } 112 | 113 | func TestSemaphoredSharedMemory(t *testing.T) { 114 | w1 := NewSemaphoredSharedMemoryWriter() 115 | defer w1.Close() 116 | w2 := NewSemaphoredSharedMemoryWriter() 117 | defer w2.Close() 118 | r1 := NewSemaphoredSharedMemoryReader() 119 | defer r1.Close() 120 | r2 := NewSemaphoredSharedMemoryReader() 121 | defer r2.Close() 122 | 123 | b1 := []byte("test") 124 | semKeys := make(map[int]bool) 125 | shmAts := make(map[*SemaphoredSharedMemoryWriter]int64) 126 | shmKeys := make(map[int]bool) 127 | for _, v := range []struct { 128 | r *SemaphoredSharedMemoryReader 129 | w *SemaphoredSharedMemoryWriter 130 | }{ 131 | { 132 | r: r1, 133 | w: w1, 134 | }, 135 | { 136 | r: r2, 137 | w: w2, 138 | }, 139 | } { 140 | ro, err := v.w.WriteBytes(b1) 141 | if err != nil { 142 | t.Fatalf("expected no error, got %s", err) 143 | } 144 | if e, g := len(b1), ro.Size; e != g { 145 | t.Fatalf("expected %d, got %d", e, g) 146 | } 147 | if e, g := v.w.sem.Key(), ro.SemaphoreKey; e != g { 148 | t.Fatalf("expected %d, got %d", e, g) 149 | } 150 | if _, ok := semKeys[ro.SemaphoreKey]; ok { 151 | t.Fatal("expected false, got true") 152 | } 153 | semKeys[ro.SemaphoreKey] = true 154 | if g := ro.SharedMemoryAt; g <= 0 { 155 | t.Fatalf("expected > 0, got %d", g) 156 | } 157 | shmAts[v.w] = ro.SharedMemoryAt 158 | if e, g := v.w.shm.Key(), ro.SharedMemoryKey; e != g { 159 | t.Fatalf("expected %d, got %d", e, g) 160 | } 161 | if _, ok := shmKeys[ro.SharedMemoryKey]; ok { 162 | t.Fatal("expected 
false, got true") 163 | } 164 | shmKeys[ro.SharedMemoryKey] = true 165 | 166 | b, err := v.r.ReadBytes(ro) 167 | if err != nil { 168 | t.Fatalf("expected no error, got %s", err) 169 | } 170 | if !bytes.Equal(b1, b) { 171 | t.Fatalf("expected %s, got %s", b1, b) 172 | } 173 | } 174 | 175 | b3 := append(b1, []byte("1")...) 176 | ro, err := w1.WriteBytes(b3) 177 | if err != nil { 178 | t.Fatalf("expected no error, got %s", err) 179 | } 180 | at, ok := shmAts[w1] 181 | if !ok { 182 | t.Fatal("expected false, got true") 183 | } 184 | if ne, g := at, ro.SharedMemoryAt; ne == g { 185 | t.Fatalf("didn't expect %d, got %d", ne, g) 186 | } 187 | 188 | b4, err := r1.ReadBytes(ro) 189 | if err != nil { 190 | t.Fatalf("expected no error, got %s", err) 191 | } 192 | if !bytes.Equal(b3, b4) { 193 | t.Fatalf("expected %s, got %s", b3, b4) 194 | } 195 | } 196 | -------------------------------------------------------------------------------- /json.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "fmt" 7 | ) 8 | 9 | func JSONEqual(a, b any) bool { 10 | ba, err := json.Marshal(a) 11 | if err != nil { 12 | return false 13 | } 14 | bb, err := json.Marshal(b) 15 | if err != nil { 16 | return false 17 | } 18 | return bytes.Equal(ba, bb) 19 | } 20 | 21 | func JSONClone(src, dst any) (err error) { 22 | // Marshal 23 | var b []byte 24 | if b, err = json.Marshal(src); err != nil { 25 | err = fmt.Errorf("main: marshaling failed: %w", err) 26 | return 27 | } 28 | 29 | // Unmarshal 30 | if err = json.Unmarshal(b, dst); err != nil { 31 | err = fmt.Errorf("main: unmarshaling failed: %w", err) 32 | return 33 | } 34 | return 35 | } 36 | -------------------------------------------------------------------------------- /json_test.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import "testing" 4 | 5 | type jsonA struct { 6 | A 
string `json:"a"` 7 | } 8 | 9 | type jsonB struct { 10 | B string `json:"a"` 11 | } 12 | 13 | func TestJSONClone(t *testing.T) { 14 | a := jsonA{A: "a"} 15 | b := &jsonB{} 16 | err := JSONClone(a, b) 17 | if err != nil { 18 | t.Fatalf("expected no error, got %s", err) 19 | } 20 | if !JSONEqual(a, b) { 21 | t.Fatal("expected true, got false") 22 | } 23 | } 24 | 25 | func TestJSONEqual(t *testing.T) { 26 | if JSONEqual(jsonA{A: "a"}, jsonB{B: "b"}) { 27 | t.Fatal("expected false, got true") 28 | } 29 | if !JSONEqual(jsonA{A: "a"}, jsonB{B: "a"}) { 30 | t.Fatal("expected true, got false") 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /limiter.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | "time" 7 | ) 8 | 9 | // Limiter represents a limiter 10 | type Limiter struct { 11 | buckets map[string]*LimiterBucket 12 | m *sync.Mutex // Locks buckets 13 | } 14 | 15 | // NewLimiter creates a new limiter 16 | func NewLimiter() *Limiter { 17 | return &Limiter{ 18 | buckets: make(map[string]*LimiterBucket), 19 | m: &sync.Mutex{}, 20 | } 21 | } 22 | 23 | // Add adds a new bucket 24 | func (l *Limiter) Add(name string, cap int, period time.Duration) *LimiterBucket { 25 | l.m.Lock() 26 | defer l.m.Unlock() 27 | if _, ok := l.buckets[name]; !ok { 28 | l.buckets[name] = newLimiterBucket(cap, period) 29 | } 30 | return l.buckets[name] 31 | } 32 | 33 | // Bucket retrieves a bucket from the limiter 34 | func (l *Limiter) Bucket(name string) (b *LimiterBucket, ok bool) { 35 | l.m.Lock() 36 | defer l.m.Unlock() 37 | b, ok = l.buckets[name] 38 | return 39 | } 40 | 41 | // Close closes the limiter properly 42 | func (l *Limiter) Close() { 43 | l.m.Lock() 44 | defer l.m.Unlock() 45 | for _, b := range l.buckets { 46 | b.Close() 47 | } 48 | } 49 | 50 | // LimiterBucket represents a limiter bucket 51 | type LimiterBucket struct { 52 | cancel 
context.CancelFunc 53 | cap int 54 | ctx context.Context 55 | count int 56 | m sync.Mutex // Locks count 57 | period time.Duration 58 | o *sync.Once 59 | } 60 | 61 | // newLimiterBucket creates a new bucket 62 | func newLimiterBucket(cap int, period time.Duration) (b *LimiterBucket) { 63 | b = &LimiterBucket{ 64 | cap: cap, 65 | count: 0, 66 | period: period, 67 | o: &sync.Once{}, 68 | } 69 | b.ctx, b.cancel = context.WithCancel(context.Background()) 70 | go b.tick() 71 | return 72 | } 73 | 74 | // Inc increments the bucket count 75 | func (b *LimiterBucket) Inc() bool { 76 | b.m.Lock() 77 | defer b.m.Unlock() 78 | if b.count >= b.cap { 79 | return false 80 | } 81 | b.count++ 82 | return true 83 | } 84 | 85 | // tick runs a ticker to purge the bucket 86 | func (b *LimiterBucket) tick() { 87 | var t = time.NewTicker(b.period) 88 | defer t.Stop() 89 | for { 90 | select { 91 | case <-t.C: 92 | b.m.Lock() 93 | b.count = 0 94 | b.m.Unlock() 95 | case <-b.ctx.Done(): 96 | return 97 | } 98 | } 99 | } 100 | 101 | // close closes the bucket properly 102 | func (b *LimiterBucket) Close() { 103 | b.o.Do(func() { 104 | b.cancel() 105 | }) 106 | } 107 | -------------------------------------------------------------------------------- /limiter_test.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | ) 7 | 8 | func TestLimiter(t *testing.T) { 9 | var l = NewLimiter() 10 | defer l.Close() 11 | l.Add("test", 2, time.Second) 12 | b, ok := l.Bucket("test") 13 | if !ok { 14 | t.Fatal("no bucket found") 15 | } 16 | defer b.Close() 17 | if !b.Inc() { 18 | t.Fatalf("got false, expected true") 19 | } 20 | if !b.Inc() { 21 | t.Fatalf("got false, expected true") 22 | } 23 | if b.Inc() { 24 | t.Fatalf("got true, expected false") 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /logger.go: 
-------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "context" 5 | ) 6 | 7 | // LoggerLevel represents a logger level 8 | type LoggerLevel int 9 | 10 | // Logger levels 11 | const ( 12 | LoggerLevelDebug LoggerLevel = iota 13 | LoggerLevelInfo 14 | LoggerLevelWarn 15 | LoggerLevelError 16 | LoggerLevelFatal 17 | ) 18 | 19 | // LoggerLevelFromString creates a logger level from string 20 | func LoggerLevelFromString(s string) LoggerLevel { 21 | switch s { 22 | case "debug": 23 | return LoggerLevelDebug 24 | case "error": 25 | return LoggerLevelError 26 | case "fatal": 27 | return LoggerLevelFatal 28 | case "warn": 29 | return LoggerLevelWarn 30 | default: 31 | return LoggerLevelInfo 32 | } 33 | } 34 | 35 | func (l LoggerLevel) String() string { 36 | switch l { 37 | case LoggerLevelDebug: 38 | return "debug" 39 | case LoggerLevelError: 40 | return "error" 41 | case LoggerLevelFatal: 42 | return "fatal" 43 | case LoggerLevelWarn: 44 | return "warn" 45 | default: 46 | return "info" 47 | } 48 | } 49 | 50 | func (l *LoggerLevel) UnmarshalText(b []byte) error { 51 | *l = LoggerLevelFromString(string(b)) 52 | return nil 53 | } 54 | 55 | func (l LoggerLevel) MarshalText() ([]byte, error) { 56 | b := []byte(l.String()) 57 | return b, nil 58 | } 59 | 60 | // CompleteLogger represents a complete logger 61 | type CompleteLogger interface { 62 | SeverityCtxLogger 63 | SeverityLogger 64 | SeverityWriteLogger 65 | SeverityWriteCtxLogger 66 | StdLogger 67 | } 68 | 69 | // StdLogger represents a standard logger 70 | type StdLogger interface { 71 | Fatal(v ...any) 72 | Fatalf(format string, v ...any) 73 | Print(v ...any) 74 | Printf(format string, v ...any) 75 | } 76 | 77 | // SeverityLogger represents a severity logger 78 | type SeverityLogger interface { 79 | Debug(v ...any) 80 | Debugf(format string, v ...any) 81 | Error(v ...any) 82 | Errorf(format string, v ...any) 83 | Info(v ...any) 84 | Infof(format string, v 
...any) 85 | Warn(v ...any) 86 | Warnf(format string, v ...any) 87 | } 88 | 89 | type TestLogger interface { 90 | Error(v ...any) 91 | Errorf(format string, v ...any) 92 | Fatal(v ...any) 93 | Fatalf(format string, v ...any) 94 | Log(v ...any) 95 | Logf(format string, v ...any) 96 | } 97 | 98 | // SeverityCtxLogger represents a severity with context logger 99 | type SeverityCtxLogger interface { 100 | DebugC(ctx context.Context, v ...any) 101 | DebugCf(ctx context.Context, format string, v ...any) 102 | ErrorC(ctx context.Context, v ...any) 103 | ErrorCf(ctx context.Context, format string, v ...any) 104 | FatalC(ctx context.Context, v ...any) 105 | FatalCf(ctx context.Context, format string, v ...any) 106 | InfoC(ctx context.Context, v ...any) 107 | InfoCf(ctx context.Context, format string, v ...any) 108 | WarnC(ctx context.Context, v ...any) 109 | WarnCf(ctx context.Context, format string, v ...any) 110 | } 111 | 112 | type SeverityWriteLogger interface { 113 | Write(l LoggerLevel, v ...any) 114 | Writef(l LoggerLevel, format string, v ...any) 115 | } 116 | 117 | type SeverityWriteCtxLogger interface { 118 | WriteC(ctx context.Context, l LoggerLevel, v ...any) 119 | WriteCf(ctx context.Context, l LoggerLevel, format string, v ...any) 120 | } 121 | 122 | type completeLogger struct { 123 | print, debug, error, fatal, info, warn func(v ...any) 124 | printf, debugf, errorf, fatalf, infof, warnf func(format string, v ...any) 125 | debugC, errorC, fatalC, infoC, warnC func(ctx context.Context, v ...any) 126 | debugCf, errorCf, fatalCf, infoCf, warnCf func(ctx context.Context, format string, v ...any) 127 | write func(l LoggerLevel, v ...any) 128 | writeC func(ctx context.Context, l LoggerLevel, v ...any) 129 | writeCf func(ctx context.Context, l LoggerLevel, format string, v ...any) 130 | writef func(l LoggerLevel, format string, v ...any) 131 | } 132 | 133 | func newCompleteLogger() *completeLogger { 134 | l := &completeLogger{} 135 | l.debug = func(v ...any) { 
l.print(v...) } 136 | l.debugf = func(format string, v ...any) { l.printf(format, v...) } 137 | l.debugC = func(ctx context.Context, v ...any) { l.debug(v...) } 138 | l.debugCf = func(ctx context.Context, format string, v ...any) { l.debugf(format, v...) } 139 | l.error = func(v ...any) { l.print(v...) } 140 | l.errorf = func(format string, v ...any) { l.printf(format, v...) } 141 | l.errorC = func(ctx context.Context, v ...any) { l.error(v...) } 142 | l.errorCf = func(ctx context.Context, format string, v ...any) { l.errorf(format, v...) } 143 | l.fatal = func(v ...any) { l.print(v...) } 144 | l.fatalf = func(format string, v ...any) { l.printf(format, v...) } 145 | l.fatalC = func(ctx context.Context, v ...any) { l.fatal(v...) } 146 | l.fatalCf = func(ctx context.Context, format string, v ...any) { l.fatalf(format, v...) } 147 | l.info = func(v ...any) { l.print(v...) } 148 | l.infof = func(format string, v ...any) { l.printf(format, v...) } 149 | l.infoC = func(ctx context.Context, v ...any) { l.info(v...) } 150 | l.infoCf = func(ctx context.Context, format string, v ...any) { l.infof(format, v...) } 151 | l.print = func(v ...any) {} 152 | l.printf = func(format string, v ...any) {} 153 | l.warn = func(v ...any) { l.print(v...) } 154 | l.warnf = func(format string, v ...any) { l.printf(format, v...) } 155 | l.warnC = func(ctx context.Context, v ...any) { l.warn(v...) } 156 | l.warnCf = func(ctx context.Context, format string, v ...any) { l.warnf(format, v...) } 157 | l.write = func(lv LoggerLevel, v ...any) { 158 | switch lv { 159 | case LoggerLevelDebug: 160 | l.debug(v...) 161 | case LoggerLevelError: 162 | l.error(v...) 163 | case LoggerLevelFatal: 164 | l.fatal(v...) 165 | case LoggerLevelWarn: 166 | l.warn(v...) 167 | default: 168 | l.info(v...) 169 | } 170 | } 171 | l.writeC = func(ctx context.Context, lv LoggerLevel, v ...any) { 172 | switch lv { 173 | case LoggerLevelDebug: 174 | l.debugC(ctx, v...) 175 | case LoggerLevelError: 176 | l.errorC(ctx, v...) 
177 | case LoggerLevelFatal: 178 | l.fatalC(ctx, v...) 179 | case LoggerLevelWarn: 180 | l.warnC(ctx, v...) 181 | default: 182 | l.infoC(ctx, v...) 183 | } 184 | } 185 | l.writeCf = func(ctx context.Context, lv LoggerLevel, format string, v ...any) { 186 | switch lv { 187 | case LoggerLevelDebug: 188 | l.debugCf(ctx, format, v...) 189 | case LoggerLevelError: 190 | l.errorCf(ctx, format, v...) 191 | case LoggerLevelFatal: 192 | l.fatalCf(ctx, format, v...) 193 | case LoggerLevelWarn: 194 | l.warnCf(ctx, format, v...) 195 | default: 196 | l.infoCf(ctx, format, v...) 197 | } 198 | } 199 | l.writef = func(lv LoggerLevel, format string, v ...any) { 200 | switch lv { 201 | case LoggerLevelDebug: 202 | l.debugf(format, v...) 203 | case LoggerLevelError: 204 | l.errorf(format, v...) 205 | case LoggerLevelFatal: 206 | l.fatalf(format, v...) 207 | case LoggerLevelWarn: 208 | l.warnf(format, v...) 209 | default: 210 | l.infof(format, v...) 211 | } 212 | } 213 | return l 214 | } 215 | 216 | func (l *completeLogger) Debug(v ...any) { l.debug(v...) } 217 | func (l *completeLogger) Debugf(format string, v ...any) { l.debugf(format, v...) } 218 | func (l *completeLogger) DebugC(ctx context.Context, v ...any) { l.debugC(ctx, v...) } 219 | func (l *completeLogger) DebugCf(ctx context.Context, format string, v ...any) { 220 | l.debugCf(ctx, format, v...) 221 | } 222 | func (l *completeLogger) Error(v ...any) { l.error(v...) } 223 | func (l *completeLogger) Errorf(format string, v ...any) { l.errorf(format, v...) } 224 | func (l *completeLogger) ErrorC(ctx context.Context, v ...any) { l.errorC(ctx, v...) } 225 | func (l *completeLogger) ErrorCf(ctx context.Context, format string, v ...any) { 226 | l.errorCf(ctx, format, v...) 227 | } 228 | func (l *completeLogger) Fatal(v ...any) { l.fatal(v...) } 229 | func (l *completeLogger) Fatalf(format string, v ...any) { l.fatalf(format, v...) } 230 | func (l *completeLogger) FatalC(ctx context.Context, v ...any) { l.fatalC(ctx, v...) 
} 231 | func (l *completeLogger) FatalCf(ctx context.Context, format string, v ...any) { 232 | l.fatalCf(ctx, format, v...) 233 | } 234 | func (l *completeLogger) Info(v ...any) { l.info(v...) } 235 | func (l *completeLogger) Infof(format string, v ...any) { l.infof(format, v...) } 236 | func (l *completeLogger) InfoC(ctx context.Context, v ...any) { l.infoC(ctx, v...) } 237 | func (l *completeLogger) InfoCf(ctx context.Context, format string, v ...any) { 238 | l.infoCf(ctx, format, v...) 239 | } 240 | func (l *completeLogger) Print(v ...any) { l.print(v...) } 241 | func (l *completeLogger) Printf(format string, v ...any) { l.printf(format, v...) } 242 | func (l *completeLogger) Warn(v ...any) { l.warn(v...) } 243 | func (l *completeLogger) Warnf(format string, v ...any) { l.warnf(format, v...) } 244 | func (l *completeLogger) WarnC(ctx context.Context, v ...any) { l.warnC(ctx, v...) } 245 | func (l *completeLogger) WarnCf(ctx context.Context, format string, v ...any) { 246 | l.warnCf(ctx, format, v...) 247 | } 248 | func (l *completeLogger) Write(lv LoggerLevel, v ...any) { l.write(lv, v...) } 249 | func (l *completeLogger) Writef(lv LoggerLevel, format string, v ...any) { 250 | l.writef(lv, format, v...) 251 | } 252 | func (l *completeLogger) WriteC(ctx context.Context, lv LoggerLevel, v ...any) { 253 | l.writeC(ctx, lv, v...) 254 | } 255 | func (l *completeLogger) WriteCf(ctx context.Context, lv LoggerLevel, format string, v ...any) { 256 | l.writeCf(ctx, lv, format, v...) 
257 | } 258 | 259 | // AdaptStdLogger transforms an StdLogger into a CompleteLogger if needed 260 | func AdaptStdLogger(i StdLogger) CompleteLogger { 261 | if v, ok := i.(CompleteLogger); ok { 262 | return v 263 | } 264 | l := newCompleteLogger() 265 | if i == nil { 266 | return l 267 | } 268 | l.fatal = i.Fatal 269 | l.fatalf = i.Fatalf 270 | l.print = i.Print 271 | l.printf = i.Printf 272 | if v, ok := i.(SeverityLogger); ok { 273 | l.debug = v.Debug 274 | l.debugf = v.Debugf 275 | l.error = v.Error 276 | l.errorf = v.Errorf 277 | l.info = v.Info 278 | l.infof = v.Infof 279 | l.warn = v.Warn 280 | l.warnf = v.Warnf 281 | } 282 | if v, ok := i.(SeverityCtxLogger); ok { 283 | l.debugC = v.DebugC 284 | l.debugCf = v.DebugCf 285 | l.errorC = v.ErrorC 286 | l.errorCf = v.ErrorCf 287 | l.fatalC = v.FatalC 288 | l.fatalCf = v.FatalCf 289 | l.infoC = v.InfoC 290 | l.infoCf = v.InfoCf 291 | l.warnC = v.WarnC 292 | l.warnCf = v.WarnCf 293 | } 294 | if v, ok := i.(SeverityWriteLogger); ok { 295 | l.write = v.Write 296 | l.writef = v.Writef 297 | } 298 | if v, ok := i.(SeverityWriteCtxLogger); ok { 299 | l.writeC = v.WriteC 300 | l.writeCf = v.WriteCf 301 | } 302 | return l 303 | } 304 | 305 | // AdaptTestLogger transforms a TestLogger into a CompleteLogger if needed 306 | func AdaptTestLogger(i TestLogger) CompleteLogger { 307 | if v, ok := i.(CompleteLogger); ok { 308 | return v 309 | } 310 | l := newCompleteLogger() 311 | if i == nil { 312 | return l 313 | } 314 | l.error = i.Error 315 | l.errorf = i.Errorf 316 | l.fatal = i.Fatal 317 | l.fatalf = i.Fatalf 318 | l.print = i.Log 319 | l.printf = i.Logf 320 | l.debug = l.print 321 | l.debugf = l.printf 322 | l.info = l.print 323 | l.infof = l.printf 324 | l.warn = l.print 325 | l.warnf = l.printf 326 | return l 327 | } 328 | -------------------------------------------------------------------------------- /logger_test.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 
3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestLoggerLevel(t *testing.T) { 8 | var l LoggerLevel 9 | for _, v := range []struct { 10 | l LoggerLevel 11 | s string 12 | }{ 13 | { 14 | l: LoggerLevelDebug, 15 | s: "debug", 16 | }, 17 | { 18 | l: LoggerLevelError, 19 | s: "error", 20 | }, 21 | { 22 | l: LoggerLevelFatal, 23 | s: "fatal", 24 | }, 25 | { 26 | l: LoggerLevelInfo, 27 | s: "info", 28 | }, 29 | { 30 | l: LoggerLevelWarn, 31 | s: "warn", 32 | }, 33 | } { 34 | if e, g := v.s, v.l.String(); e != g { 35 | t.Fatalf("expected %s, got %s", e, g) 36 | } 37 | b, err := v.l.MarshalText() 38 | if err != nil { 39 | t.Fatalf("expected no error, got %s", err) 40 | } 41 | if e, g := v.s, string(b); e != g { 42 | t.Fatalf("expected %s, got %s", e, g) 43 | } 44 | if e, g := v.l, LoggerLevelFromString(v.s); e != g { 45 | t.Fatalf("expected %s, got %s", e, g) 46 | } 47 | err = l.UnmarshalText([]byte(v.s)) 48 | if err != nil { 49 | t.Fatalf("expected no error, got %s", err) 50 | } 51 | if e, g := v.l, l; e != g { 52 | t.Fatalf("expected %s, got %s", e, g) 53 | } 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /map.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "fmt" 5 | "sync" 6 | ) 7 | 8 | // BiMap represents a bidirectional map 9 | type BiMap struct { 10 | forward map[any]any 11 | inverse map[any]any 12 | m *sync.Mutex 13 | } 14 | 15 | // NewBiMap creates a new BiMap 16 | func NewBiMap() *BiMap { 17 | return &BiMap{ 18 | forward: make(map[any]any), 19 | inverse: make(map[any]any), 20 | m: &sync.Mutex{}, 21 | } 22 | } 23 | 24 | func (m *BiMap) get(k any, i map[any]any) (v any, ok bool) { 25 | m.m.Lock() 26 | defer m.m.Unlock() 27 | v, ok = i[k] 28 | return 29 | } 30 | 31 | // Get gets the value in the forward map based on the provided key 32 | func (m *BiMap) Get(k any) (any, bool) { return m.get(k, m.forward) } 33 | 34 | // GetInverse gets the 
value in the inverse map based on the provided key 35 | func (m *BiMap) GetInverse(k any) (any, bool) { return m.get(k, m.inverse) } 36 | 37 | // MustGet gets the value in the forward map based on the provided key and panics if key is not found 38 | func (m *BiMap) MustGet(k any) any { 39 | v, ok := m.get(k, m.forward) 40 | if !ok { 41 | panic(fmt.Sprintf("astikit: key %+v not found in forward map", k)) 42 | } 43 | return v 44 | } 45 | 46 | // MustGetInverse gets the value in the inverse map based on the provided key and panics if key is not found 47 | func (m *BiMap) MustGetInverse(k any) any { 48 | v, ok := m.get(k, m.inverse) 49 | if !ok { 50 | panic(fmt.Sprintf("astikit: key %+v not found in inverse map", k)) 51 | } 52 | return v 53 | } 54 | // set maps k to v in f and v to k in i under the lock; a previous mapping of k or v is overwritten, but its old reverse entry is not removed 55 | func (m *BiMap) set(k, v any, f, i map[any]any) *BiMap { 56 | m.m.Lock() 57 | defer m.m.Unlock() 58 | f[k] = v 59 | i[v] = k 60 | return m 61 | } 62 | 63 | // Set sets the value in the forward and inverse map for the provided forward key 64 | func (m *BiMap) Set(k, v any) *BiMap { return m.set(k, v, m.forward, m.inverse) } 65 | 66 | // SetInverse sets the value in the forward and inverse map for the provided inverse key 67 | func (m *BiMap) SetInverse(k, v any) *BiMap { return m.set(k, v, m.inverse, m.forward) } 68 | -------------------------------------------------------------------------------- /map_test.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import "testing" 4 | 5 | func TestBiMap(t *testing.T) { 6 | m := NewBiMap() 7 | m.Set(0, 1) 8 | v, ok := m.Get(0) 9 | if !ok { 10 | t.Fatal("expected true, got false") 11 | } 12 | if e, g := 1, v.(int); e != g { 13 | t.Fatalf("expected %d, got %d", e, g) 14 | } 15 | _, ok = m.GetInverse(0) 16 | if ok { 17 | t.Fatal("expected false, got true") 18 | } 19 | v, ok = m.GetInverse(1) 20 | if !ok { 21 | t.Fatal("expected true, got false") 22 | } 23 | if e, g := 0, v.(int); e != g { 24 | t.Fatalf("expected %d, got
%d", e, g) 25 | } 26 | m.SetInverse(0, 1) 27 | v, ok = m.GetInverse(0) 28 | if !ok { 29 | t.Fatal("expected true, got false") 30 | } 31 | if e, g := 1, v.(int); e != g { 32 | t.Fatalf("expected %d, got %d", e, g) 33 | } 34 | testPanic(t, false, func() { m.MustGet(0) }) 35 | testPanic(t, true, func() { m.MustGet(2) }) 36 | testPanic(t, false, func() { m.MustGetInverse(0) }) 37 | testPanic(t, true, func() { m.MustGetInverse(2) }) 38 | } 39 | 40 | func testPanic(t *testing.T, shouldPanic bool, fn func()) { 41 | defer func() { 42 | err := recover() 43 | if shouldPanic && err == nil { 44 | t.Fatal("should have panicked") 45 | } else if !shouldPanic && err != nil { 46 | t.Fatal("should not have panicked") 47 | } 48 | }() 49 | fn() 50 | } 51 | -------------------------------------------------------------------------------- /os.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | "path/filepath" 8 | "strings" 9 | ) 10 | 11 | // MoveFile is a cancellable move of a local file to a local or remote location 12 | func MoveFile(ctx context.Context, dst, src string, f CopyFileFunc) (err error) { 13 | // Copy 14 | if err = CopyFile(ctx, dst, src, f); err != nil { 15 | err = fmt.Errorf("astikit: copying file %s to %s failed: %w", src, dst, err) 16 | return 17 | } 18 | 19 | // Delete 20 | if err = os.Remove(src); err != nil { 21 | err = fmt.Errorf("astikit: removing %s failed: %w", src, err) 22 | return 23 | } 24 | return 25 | } 26 | 27 | // CopyFileFunc represents a CopyFile func 28 | type CopyFileFunc func(ctx context.Context, dst string, srcStat os.FileInfo, srcFile *os.File) error 29 | 30 | // CopyFile is a cancellable copy of a local file to a local or remote location 31 | func CopyFile(ctx context.Context, dst, src string, f CopyFileFunc) (err error) { 32 | // Check context 33 | if err = ctx.Err(); err != nil { 34 | return 35 | } 36 | 37 | // Stat src 38 | var srcStat 
os.FileInfo 39 | if srcStat, err = os.Stat(src); err != nil { 40 | err = fmt.Errorf("astikit: stating %s failed: %w", src, err) 41 | return 42 | } 43 | 44 | // Src is a dir 45 | if srcStat.IsDir() { 46 | // Make sure to clean dir path so that we get consistent path separator with filepath.Walk 47 | src = filepath.Clean(src) 48 | 49 | // Walk through the dir 50 | if err = filepath.Walk(src, func(path string, info os.FileInfo, errWalk error) (err error) { 51 | // Check error 52 | if errWalk != nil { 53 | err = errWalk 54 | return 55 | } 56 | 57 | // Do not process root 58 | if src == path { 59 | return 60 | } 61 | 62 | // Copy 63 | p := filepath.Join(dst, strings.TrimPrefix(path, src)) 64 | if err = CopyFile(ctx, p, path, f); err != nil { 65 | err = fmt.Errorf("astikit: copying %s to %s failed: %w", path, p, err) 66 | return 67 | } 68 | return nil 69 | }); err != nil { 70 | err = fmt.Errorf("astikit: walking through %s failed: %w", src, err) 71 | return 72 | } 73 | return 74 | } 75 | 76 | // Open src 77 | var srcFile *os.File 78 | if srcFile, err = os.Open(src); err != nil { 79 | err = fmt.Errorf("astikit: opening %s failed: %w", src, err) 80 | return 81 | } 82 | defer srcFile.Close() 83 | 84 | // Custom 85 | if err = f(ctx, dst, srcStat, srcFile); err != nil { 86 | err = fmt.Errorf("astikit: custom failed: %w", err) 87 | return 88 | } 89 | return 90 | } 91 | 92 | // LocalCopyFileFunc is the local CopyFileFunc that allows doing cross partition copies 93 | func LocalCopyFileFunc(ctx context.Context, dst string, srcStat os.FileInfo, srcFile *os.File) (err error) { 94 | // Check context 95 | if err = ctx.Err(); err != nil { 96 | return 97 | } 98 | 99 | // Create the destination folder 100 | if err = os.MkdirAll(filepath.Dir(dst), DefaultDirMode); err != nil { 101 | err = fmt.Errorf("astikit: mkdirall %s failed: %w", filepath.Dir(dst), err) 102 | return 103 | } 104 | 105 | // Create the destination file 106 | var dstFile *os.File 107 | if dstFile, err = os.Create(dst); 
err != nil { 108 | err = fmt.Errorf("astikit: creating %s failed: %w", dst, err) 109 | return 110 | } 111 | defer dstFile.Close() 112 | 113 | // Chmod using os.Chmod instead of dstFile.Chmod 114 | if err = os.Chmod(dst, srcStat.Mode()); err != nil { 115 | err = fmt.Errorf("astikit: chmod %s %s failed: %w", dst, srcStat.Mode(), err) 116 | return 117 | } 118 | 119 | // Copy the content 120 | if _, err = Copy(ctx, dstFile, srcFile); err != nil { 121 | err = fmt.Errorf("astikit: copying content of %s to %s failed: %w", srcFile.Name(), dstFile.Name(), err) 122 | return 123 | } 124 | return 125 | } 126 | 127 | // SignalHandler represents a func that can handle a signal 128 | type SignalHandler func(s os.Signal) 129 | 130 | // TermSignalHandler returns a SignalHandler that is executed only on a term signal 131 | func TermSignalHandler(f func()) SignalHandler { 132 | return func(s os.Signal) { 133 | if isTermSignal(s) { 134 | f() 135 | } 136 | } 137 | } 138 | 139 | // LoggerSignalHandler returns a SignalHandler that logs received signals at debug level, skipping any listed in ignoredSignals 140 | func LoggerSignalHandler(l SeverityLogger, ignoredSignals ...os.Signal) SignalHandler { 141 | ss := make(map[os.Signal]bool) 142 | for _, s := range ignoredSignals { 143 | ss[s] = true 144 | } 145 | return func(s os.Signal) { 146 | if _, ok := ss[s]; ok { 147 | return 148 | } 149 | l.Debugf("astikit: received signal %s", s) 150 | } 151 | } 152 | -------------------------------------------------------------------------------- /os_js.go: -------------------------------------------------------------------------------- 1 | // +build js,wasm 2 | 3 | package astikit 4 | 5 | import ( 6 | "os" 7 | "syscall" 8 | ) 9 | 10 | func isTermSignal(s os.Signal) bool { 11 | return s == syscall.SIGKILL || s == syscall.SIGINT || s == syscall.SIGQUIT || s == syscall.SIGTERM 12 | } 13 | -------------------------------------------------------------------------------- /os_others.go: -------------------------------------------------------------------------------- 1 |
// +build !js !wasm 2 | 3 | package astikit 4 | 5 | import ( 6 | "os" 7 | "syscall" 8 | ) 9 | 10 | func isTermSignal(s os.Signal) bool { 11 | return s == syscall.SIGABRT || s == syscall.SIGKILL || s == syscall.SIGINT || s == syscall.SIGQUIT || s == syscall.SIGTERM 12 | } 13 | -------------------------------------------------------------------------------- /os_test.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "path/filepath" 7 | "testing" 8 | ) 9 | 10 | func TestCopyFile(t *testing.T) { 11 | // Get temp dir 12 | p := t.TempDir() 13 | 14 | // Copy file 15 | e := "testdata/os/f" 16 | g := filepath.Join(p, "f") 17 | err := CopyFile(context.Background(), g, e, LocalCopyFileFunc) 18 | if err != nil { 19 | t.Fatalf("expected no error, got %+v", err) 20 | } 21 | compareFile(t, e, g) 22 | 23 | // Move file 24 | e = g 25 | g = filepath.Join(p, "m") 26 | err = MoveFile(context.Background(), g, e, LocalCopyFileFunc) 27 | if err != nil { 28 | t.Fatalf("expected no error, got %+v", err) 29 | } 30 | checkFile(t, g, "0") 31 | _, err = os.Stat(e) 32 | if !os.IsNotExist(err) { 33 | t.Fatal("expected true, got false") 34 | } 35 | 36 | // Copy dir 37 | e = "testdata/os/d" 38 | g = filepath.Join(p, "d") 39 | err = CopyFile(context.Background(), g, e, LocalCopyFileFunc) 40 | if err != nil { 41 | t.Fatalf("expected no error, got %+v", err) 42 | } 43 | compareDir(t, e, g) 44 | } 45 | -------------------------------------------------------------------------------- /pcm_test.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | func TestPCMLevel(t *testing.T) { 10 | if e, g := 2.160246899469287, PCMLevel([]int{1, 2, 3}); g != e { 11 | t.Fatalf("got %v, expected %v", g, e) 12 | } 13 | } 14 | 15 | func TestPCMNormalize(t *testing.T) { 16 | // Nothing to do 17 | i := 
[]int{10000, maxPCMSample(16), -10000} 18 | if g := PCMNormalize(i, 16); !reflect.DeepEqual(i, g) { 19 | t.Fatalf("got %+v, expected %+v", g, i) 20 | } 21 | 22 | // Normalize 23 | i = []int{10000, 0, -10000} 24 | if e, g := []int{32767, 0, -32767}, PCMNormalize(i, 16); !reflect.DeepEqual(e, g) { 25 | t.Fatalf("got %+v, expected %+v", g, e) 26 | } 27 | } 28 | 29 | func TestConvertPCMBitDepth(t *testing.T) { 30 | // Nothing to do (NOTE(review): 1>>8 and 1>>24 both evaluate to 0, so every case in this test only checks a zero sample; 1<<8 / 1<<24 was presumably intended — confirm against ConvertPCMBitDepth) 31 | s, err := ConvertPCMBitDepth(1>>8, 16, 16) 32 | if err != nil { 33 | t.Fatalf("expected no error, got %+v", err) 34 | } 35 | if e := 1 >> 8; !reflect.DeepEqual(s, e) { 36 | t.Fatalf("got %+v, expected %+v", s, e) 37 | } 38 | 39 | // Src bit depth > Dst bit depth 40 | s, err = ConvertPCMBitDepth(1>>24, 32, 16) 41 | if err != nil { 42 | t.Fatalf("expected no error, got %+v", err) 43 | } 44 | if e := 1 >> 8; !reflect.DeepEqual(s, e) { 45 | t.Fatalf("got %+v, expected %+v", s, e) 46 | } 47 | 48 | // Src bit depth < Dst bit depth 49 | s, err = ConvertPCMBitDepth(1>>8, 16, 32) 50 | if err != nil { 51 | t.Fatalf("expected no error, got %+v", err) 52 | } 53 | if e := 1 >> 24; !reflect.DeepEqual(s, e) { 54 | t.Fatalf("got %+v, expected %+v", s, e) 55 | } 56 | } 57 | 58 | func TestPCMSampleRateConverter(t *testing.T) { 59 | // Create input 60 | var i []int 61 | for idx := 0; idx < 20; idx++ { 62 | i = append(i, idx+1) 63 | } 64 | 65 | // Create sample func 66 | var o []int 67 | var sampleFunc = func(s int) (err error) { 68 | o = append(o, s) 69 | return 70 | } 71 | 72 | // Nothing to do 73 | c := NewPCMSampleRateConverter(1, 1, 1, sampleFunc) 74 | for _, s := range i { 75 | c.Add(s) //nolint:errcheck 76 | } 77 | if !reflect.DeepEqual(o, i) { 78 | t.Fatalf("got %+v, expected %+v", o, i) 79 | } 80 | 81 | // Simple src sample rate > dst sample rate 82 | o = []int{} 83 | c = NewPCMSampleRateConverter(5, 3, 1, sampleFunc) 84 | for _, s := range i { 85 | c.Add(s) //nolint:errcheck 86 | } 87 | if e := []int{1, 2, 4, 6, 7, 9, 11, 12, 14, 16, 17, 19};
!reflect.DeepEqual(e, o) { 88 | t.Fatalf("got %+v, expected %+v", o, e) 89 | } 90 | 91 | // Multi channels 92 | o = []int{} 93 | c = NewPCMSampleRateConverter(4, 2, 2, sampleFunc) 94 | for _, s := range i { 95 | c.Add(s) //nolint:errcheck 96 | } 97 | if e := []int{1, 2, 4, 5, 8, 9, 12, 13, 16, 17}; !reflect.DeepEqual(e, o) { 98 | t.Fatalf("got %+v, expected %+v", o, e) 99 | } 100 | 101 | // Realistic src sample rate > dst sample rate 102 | i = []int{} 103 | for idx := 0; idx < 4*44100; idx++ { 104 | i = append(i, idx+1) 105 | } 106 | o = []int{} 107 | c = NewPCMSampleRateConverter(44100, 16000, 2, sampleFunc) 108 | for _, s := range i { 109 | c.Add(s) //nolint:errcheck 110 | } 111 | if e, g := 4*16000, len(o); g != e { 112 | t.Fatalf("invalid len, got %v, expected %v", g, e) 113 | } 114 | 115 | // Create input 116 | i = []int{} 117 | for idx := 0; idx < 10; idx++ { 118 | i = append(i, idx+1) 119 | } 120 | 121 | // Simple src sample rate < dst sample rate 122 | o = []int{} 123 | c = NewPCMSampleRateConverter(3, 5, 1, sampleFunc) 124 | for _, s := range i { 125 | c.Add(s) //nolint:errcheck 126 | } 127 | if e := []int{1, 1, 2, 2, 3, 4, 4, 5, 5, 6, 7, 7, 8, 8, 9, 10, 10}; !reflect.DeepEqual(e, o) { 128 | t.Fatalf("got %+v, expected %+v", o, e) 129 | } 130 | 131 | // Multi channels 132 | o = []int{} 133 | c = NewPCMSampleRateConverter(3, 5, 2, sampleFunc) 134 | for _, s := range i { 135 | c.Add(s) //nolint:errcheck 136 | } 137 | if e := []int{1, 2, 1, 2, 3, 4, 3, 4, 5, 6, 7, 8, 7, 8, 9, 10, 9, 10}; !reflect.DeepEqual(e, o) { 138 | t.Fatalf("got %+v, expected %+v", o, e) 139 | } 140 | } 141 | 142 | func TestPCMChannelsConverter(t *testing.T) { 143 | // Create input 144 | var i []int 145 | for idx := 0; idx < 20; idx++ { 146 | i = append(i, idx+1) 147 | } 148 | 149 | // Create sample func 150 | var o []int 151 | var sampleFunc = func(s int) (err error) { 152 | o = append(o, s) 153 | return 154 | } 155 | 156 | // Nothing to do 157 | c := NewPCMChannelsConverter(3, 3, 
sampleFunc) 158 | for _, s := range i { 159 | c.Add(s) //nolint:errcheck 160 | } 161 | if !reflect.DeepEqual(i, o) { 162 | t.Fatalf("got %+v, expected %+v", o, i) 163 | } 164 | 165 | // Throw away data 166 | o = []int{} 167 | c = NewPCMChannelsConverter(3, 1, sampleFunc) 168 | for _, s := range i { 169 | c.Add(s) //nolint:errcheck 170 | } 171 | if e := []int{1, 4, 7, 10, 13, 16, 19}; !reflect.DeepEqual(e, o) { 172 | t.Fatalf("got %+v, expected %+v", o, e) 173 | } 174 | 175 | // Repeat data 176 | o = []int{} 177 | c = NewPCMChannelsConverter(1, 2, sampleFunc) 178 | for _, s := range i { 179 | c.Add(s) //nolint:errcheck 180 | } 181 | if e := []int{1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 20, 20}; !reflect.DeepEqual(o, e) { 182 | t.Fatalf("got %+v, expected %+v", o, e) 183 | } 184 | } 185 | 186 | func TestPCMSilenceDetector(t *testing.T) { 187 | // Create silence detector 188 | sd := NewPCMSilenceDetector(PCMSilenceDetectorOptions{ 189 | MaxSilenceLevel: 2, 190 | MinSilenceDuration: 400 * time.Millisecond, // 2 samples 191 | SampleRate: 5, 192 | StepDuration: 200 * time.Millisecond, // 1 sample 193 | }) 194 | 195 | // Leading non silences + invalid leading silence + trailing silence is leftover 196 | vs := sd.Add([]int{3, 1, 3, 1}) 197 | if e := [][]int(nil); !reflect.DeepEqual(vs, e) { 198 | t.Fatalf("got %+v, expected %+v", vs, e) 199 | } 200 | if e, g := 1, len(sd.analyses); e != g { 201 | t.Fatalf("got %v, expected %v", g, e) 202 | } 203 | 204 | // Valid leading silence but trailing silence is insufficient for now 205 | vs = sd.Add([]int{1, 3, 3, 1}) 206 | if e := [][]int(nil); !reflect.DeepEqual(vs, e) { 207 | t.Fatalf("got %+v, expected %+v", vs, e) 208 | } 209 | if e, g := 5, len(sd.analyses); e != g { 210 | t.Fatalf("got %v, expected %v", g, e) 211 | } 212 | 213 | // Valid samples 214 | vs = sd.Add([]int{1}) 215 | if e := [][]int{{1, 1, 3, 3, 1, 1}}; 
!reflect.DeepEqual(vs, e) { 216 | t.Fatalf("got %+v, expected %+v", vs, e) 217 | } 218 | if e, g := 2, len(sd.analyses); e != g { 219 | t.Fatalf("got %v, expected %v", g, e) 220 | } 221 | 222 | // Multiple valid samples + truncate leading and trailing silences 223 | vs = sd.Add([]int{1, 1, 1, 1, 3, 3, 1, 1, 1, 1, 3, 3, 1, 1, 1, 1}) 224 | if e := [][]int{{1, 1, 3, 3, 1, 1}, {1, 1, 3, 3, 1, 1}}; !reflect.DeepEqual(vs, e) { 225 | t.Fatalf("got %+v, expected %+v", vs, e) 226 | } 227 | if e, g := 2, len(sd.analyses); e != g { 228 | t.Fatalf("got %v, expected %v", g, e) 229 | } 230 | 231 | // Invalid in-between silences that should be kept 232 | vs = sd.Add([]int{1, 1, 1, 3, 3, 1, 3, 3, 1, 3, 3, 1, 1, 1}) 233 | if e := [][]int{{1, 1, 3, 3, 1, 3, 3, 1, 3, 3, 1, 1}}; !reflect.DeepEqual(vs, e) { 234 | t.Fatalf("got %+v, expected %+v", vs, e) 235 | } 236 | if e, g := 2, len(sd.analyses); e != g { 237 | t.Fatalf("got %v, expected %v", g, e) 238 | } 239 | } 240 | -------------------------------------------------------------------------------- /ptr.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import "time" 4 | 5 | // BoolPtr transforms a bool into a *bool 6 | func BoolPtr(i bool) *bool { 7 | return &i 8 | } 9 | 10 | // BytePtr transforms a byte into a *byte 11 | func BytePtr(i byte) *byte { 12 | return &i 13 | } 14 | 15 | // DurationPtr transforms a time.Duration into a *time.Duration 16 | func DurationPtr(i time.Duration) *time.Duration { 17 | return &i 18 | } 19 | 20 | // Float64Ptr transforms a float64 into a *float64 21 | func Float64Ptr(i float64) *float64 { 22 | return &i 23 | } 24 | 25 | // IntPtr transforms an int into an *int 26 | func IntPtr(i int) *int { 27 | return &i 28 | } 29 | 30 | // Int64Ptr transforms an int64 into an *int64 31 | func Int64Ptr(i int64) *int64 { 32 | return &i 33 | } 34 | 35 | // StrSlicePtr transforms a []string into a *[]string 36 | func StrSlicePtr(i []string) *[]string { 37 
| return &i 38 | } 39 | 40 | // StrPtr transforms a string into a *string 41 | func StrPtr(i string) *string { 42 | return &i 43 | } 44 | 45 | // TimePtr transforms a time.Time into a *time.Time 46 | func TimePtr(i time.Time) *time.Time { 47 | return &i 48 | } 49 | 50 | // UInt8Ptr transforms a uint8 into a *uint8 51 | func UInt8Ptr(i uint8) *uint8 { 52 | return &i 53 | } 54 | 55 | // UInt16Ptr transforms a uint16 into a *uint16 56 | func UInt16Ptr(i uint16) *uint16 { 57 | return &i 58 | } 59 | 60 | // UInt32Ptr transforms a uint32 into a *uint32 61 | func UInt32Ptr(i uint32) *uint32 { 62 | return &i 63 | } 64 | 65 | // UInt64Ptr transforms a uint64 into a *uint64 66 | func UInt64Ptr(i uint64) *uint64 { 67 | return &i 68 | } 69 | -------------------------------------------------------------------------------- /rand.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "math/rand" 5 | "strings" 6 | "time" 7 | ) 8 | 9 | const ( 10 | randLetterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890" 11 | randLetterIdxBits = 6 // 6 bits to represent a letter index 12 | randLetterIdxMask = 1<= 0; { 25 | if remain == 0 { 26 | cache, remain = RandSource.Int63(), randLetterIdxMax 27 | } 28 | if idx := int(cache & randLetterIdxMask); idx < len(randLetterBytes) { 29 | sb.WriteByte(randLetterBytes[idx]) 30 | i-- 31 | } 32 | cache >>= randLetterIdxBits 33 | remain-- 34 | } 35 | return sb.String() 36 | } 37 | -------------------------------------------------------------------------------- /sort.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import "sort" 4 | 5 | // SortInt64 sorts a slice of int64s in increasing order. 6 | func SortInt64(a []int64) { sort.Sort(SortInt64Slice(a)) } 7 | 8 | // SortInt64Slice attaches the methods of Interface to []int64, sorting in increasing order. 
9 | type SortInt64Slice []int64 10 | 11 | func (p SortInt64Slice) Len() int { return len(p) } 12 | func (p SortInt64Slice) Less(i, j int) bool { return p[i] < p[j] } 13 | func (p SortInt64Slice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } 14 | 15 | // SortUint64 sorts a slice of uint64s in increasing order. 16 | func SortUint64(a []uint64) { sort.Sort(SortUint64Slice(a)) } 17 | 18 | // SortUint64Slice attaches the methods of Interface to []uint64, sorting in increasing order. 19 | type SortUint64Slice []uint64 20 | 21 | func (p SortUint64Slice) Len() int { return len(p) } 22 | func (p SortUint64Slice) Less(i, j int) bool { return p[i] < p[j] } 23 | func (p SortUint64Slice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } 24 | -------------------------------------------------------------------------------- /sort_test.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | ) 7 | 8 | func TestSort(t *testing.T) { 9 | i := []int64{3, 2, 4, 1} 10 | SortInt64(i) 11 | if e := []int64{1, 2, 3, 4}; !reflect.DeepEqual(e, i) { 12 | t.Fatalf("expected %+v, got %+v", e, i) 13 | } 14 | 15 | ui := []uint64{3, 2, 4, 1} 16 | SortUint64(ui) 17 | if e := []uint64{1, 2, 3, 4}; !reflect.DeepEqual(e, ui) { 18 | t.Fatalf("expected %+v, got %+v", e, ui) 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /ssh.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "io" 7 | "os" 8 | "path/filepath" 9 | "strings" 10 | ) 11 | 12 | // SSHSession represents an SSH Session 13 | type SSHSession interface { 14 | Run(string) error 15 | Start(string) error 16 | StdinPipe() (io.WriteCloser, error) 17 | Wait() error 18 | } 19 | 20 | // SSHSessionFunc represents a func that can return an SSHSession 21 | type SSHSessionFunc func() (s SSHSession, c *Closer, err error) 22 | 
23 | // SSHCopyFileFunc is the SSH CopyFileFunc that allows doing SSH copies 24 | func SSHCopyFileFunc(fn SSHSessionFunc) CopyFileFunc { 25 | return func(ctx context.Context, dst string, srcStat os.FileInfo, srcFile *os.File) (err error) { 26 | // Check context 27 | if err = ctx.Err(); err != nil { 28 | return 29 | } 30 | 31 | // Escape dir path 32 | d := strings.ReplaceAll(filepath.Dir(dst), " ", "\\ ") 33 | 34 | // Using local closure allows better readibility for the defer c.Close() since it 35 | // isolates the use of the ssh session 36 | if err = func() (err error) { 37 | // Create ssh session 38 | var s SSHSession 39 | var c *Closer 40 | if s, c, err = fn(); err != nil { 41 | err = fmt.Errorf("astikit: creating ssh session failed: %w", err) 42 | return 43 | } 44 | defer c.Close() 45 | 46 | // Create the destination folder 47 | if err = s.Run("mkdir -p " + d); err != nil { 48 | err = fmt.Errorf("astikit: creating %s failed: %w", filepath.Dir(dst), err) 49 | return 50 | } 51 | return 52 | }(); err != nil { 53 | return 54 | } 55 | 56 | // Using local closure allows better readibility for the defer c.Close() since it 57 | // isolates the use of the ssh session 58 | if err = func() (err error) { 59 | // Create ssh session 60 | var s SSHSession 61 | var c *Closer 62 | if s, c, err = fn(); err != nil { 63 | err = fmt.Errorf("astikit: creating ssh session failed: %w", err) 64 | return 65 | } 66 | defer c.Close() 67 | 68 | // Create stdin pipe 69 | var stdin io.WriteCloser 70 | if stdin, err = s.StdinPipe(); err != nil { 71 | err = fmt.Errorf("astikit: creating stdin pipe failed: %w", err) 72 | return 73 | } 74 | defer stdin.Close() 75 | 76 | // Use "scp" command 77 | if err = s.Start("scp -qt " + d); err != nil { 78 | err = fmt.Errorf("astikit: scp to %s failed: %w", dst, err) 79 | return 80 | } 81 | 82 | // Send metadata 83 | if _, err = fmt.Fprintln(stdin, fmt.Sprintf("C%04o", srcStat.Mode().Perm()), srcStat.Size(), filepath.Base(dst)); err != nil { 84 | err = 
fmt.Errorf("astikit: sending metadata failed: %w", err) 85 | return 86 | } 87 | 88 | // Copy 89 | if _, err = Copy(ctx, stdin, srcFile); err != nil { 90 | err = fmt.Errorf("astikit: copying failed: %w", err) 91 | return 92 | } 93 | 94 | // Send close 95 | if _, err = fmt.Fprint(stdin, "\x00"); err != nil { 96 | err = fmt.Errorf("astikit: sending close failed: %w", err) 97 | return 98 | } 99 | 100 | // Close stdin 101 | if err = stdin.Close(); err != nil { 102 | err = fmt.Errorf("astikit: closing failed: %w", err) 103 | return 104 | } 105 | 106 | // Wait 107 | if err = s.Wait(); err != nil { 108 | err = fmt.Errorf("astikit: waiting failed: %w", err) 109 | return 110 | } 111 | return 112 | }(); err != nil { 113 | return 114 | } 115 | return 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /ssh_test.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "io" 7 | "path/filepath" 8 | "reflect" 9 | "testing" 10 | ) 11 | 12 | type mockedSSHSession struct { 13 | buf *bytes.Buffer 14 | cmds []string 15 | } 16 | 17 | func newMockedSSHSession() *mockedSSHSession { 18 | return &mockedSSHSession{buf: &bytes.Buffer{}} 19 | } 20 | 21 | func (s *mockedSSHSession) Run(cmd string) error { 22 | s.cmds = append(s.cmds, cmd) 23 | return nil 24 | } 25 | 26 | func (s *mockedSSHSession) Start(cmd string) error { 27 | s.cmds = append(s.cmds, cmd) 28 | return nil 29 | } 30 | 31 | func (s *mockedSSHSession) StdinPipe() (io.WriteCloser, error) { 32 | return NopCloser(s.buf), nil 33 | } 34 | 35 | func (s *mockedSSHSession) Wait() error { return nil } 36 | 37 | func TestSSHCopyFunc(t *testing.T) { 38 | var c int 39 | s := newMockedSSHSession() 40 | err := CopyFile(context.Background(), "/path/to with space/dst", "testdata/ssh/f", SSHCopyFileFunc(func() (SSHSession, *Closer, error) { 41 | c++ 42 | return s, NewCloser(), nil 43 | })) 44 | if err != 
nil { 45 | t.Fatalf("expected no error, got %+v", err) 46 | } 47 | if e := 2; c != e { 48 | t.Fatalf("expected %v, got %v", e, c) 49 | } 50 | if e := []string{"mkdir -p " + filepath.Clean("/path/to\\ with\\ space"), "scp -qt " + filepath.Clean("/path/to\\ with\\ space")}; !reflect.DeepEqual(e, s.cmds) { 51 | t.Fatalf("expected %+v, got %+v", e, s.cmds) 52 | } 53 | if e1, e2, e3, g := "C0775 1 dst\n0\x00", "C0755 1 dst\n0\x00", "C0666 1 dst\n0\x00", s.buf.String(); g != e1 && g != e2 && g != e3 { 54 | t.Fatalf("expected %s or %s or %s, got %s", e1, e2, e3, g) 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /stat.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | "sync/atomic" 7 | "time" 8 | ) 9 | 10 | // Stater is an object that can compute and handle stats 11 | type Stater struct { 12 | cancel context.CancelFunc 13 | ctx context.Context 14 | h StatsHandleFunc 15 | m *sync.Mutex // Locks ss 16 | period time.Duration 17 | running uint32 18 | ss map[*StatMetadata]StatOptions 19 | } 20 | 21 | // StatOptions represents stat options 22 | type StatOptions struct { 23 | Metadata *StatMetadata 24 | // Either a StatValuer or StatValuerOverTime 25 | Valuer any 26 | } 27 | 28 | // StatsHandleFunc is a method that can handle stat values 29 | type StatsHandleFunc func(stats []StatValue) 30 | 31 | // StatMetadata represents a stat metadata 32 | type StatMetadata struct { 33 | Description string 34 | Label string 35 | Name string 36 | Unit string 37 | } 38 | 39 | // StatValuer represents a stat valuer 40 | type StatValuer interface { 41 | Value(delta time.Duration) any 42 | } 43 | 44 | type StatValuerFunc func(d time.Duration) any 45 | 46 | func (f StatValuerFunc) Value(d time.Duration) any { 47 | return f(d) 48 | } 49 | 50 | // StatValue represents a stat value 51 | type StatValue struct { 52 | *StatMetadata 53 | Value any 54 | } 
55 | 56 | // StaterOptions represents stater options 57 | type StaterOptions struct { 58 | HandleFunc StatsHandleFunc 59 | Period time.Duration 60 | } 61 | 62 | // NewStater creates a new stater 63 | func NewStater(o StaterOptions) *Stater { 64 | return &Stater{ 65 | h: o.HandleFunc, 66 | m: &sync.Mutex{}, 67 | period: o.Period, 68 | ss: make(map[*StatMetadata]StatOptions), 69 | } 70 | } 71 | 72 | // Start starts the stater 73 | func (s *Stater) Start(ctx context.Context) { 74 | // Check context 75 | if ctx.Err() != nil { 76 | return 77 | } 78 | 79 | // Make sure to start only once 80 | if atomic.CompareAndSwapUint32(&s.running, 0, 1) { 81 | // Update status 82 | defer atomic.StoreUint32(&s.running, 0) 83 | 84 | // Reset context 85 | s.ctx, s.cancel = context.WithCancel(ctx) 86 | 87 | // Create ticker 88 | t := time.NewTicker(s.period) 89 | defer t.Stop() 90 | 91 | // Loop 92 | lastStatAt := now() 93 | for { 94 | select { 95 | case <-t.C: 96 | // Get delta 97 | n := now() 98 | delta := n.Sub(lastStatAt) 99 | lastStatAt = n 100 | 101 | // Loop through stats 102 | var stats []StatValue 103 | s.m.Lock() 104 | for _, o := range s.ss { 105 | // Get value 106 | var v any 107 | if h, ok := o.Valuer.(StatValuer); ok { 108 | v = h.Value(delta) 109 | } else { 110 | continue 111 | } 112 | 113 | // Append 114 | stats = append(stats, StatValue{ 115 | StatMetadata: o.Metadata, 116 | Value: v, 117 | }) 118 | } 119 | s.m.Unlock() 120 | 121 | // Handle stats 122 | go s.h(stats) 123 | case <-s.ctx.Done(): 124 | return 125 | } 126 | } 127 | } 128 | } 129 | 130 | // Stop stops the stater 131 | func (s *Stater) Stop() { 132 | if s.cancel != nil { 133 | s.cancel() 134 | } 135 | } 136 | 137 | // AddStats adds stats 138 | func (s *Stater) AddStats(os ...StatOptions) { 139 | s.m.Lock() 140 | defer s.m.Unlock() 141 | for _, o := range os { 142 | s.ss[o.Metadata] = o 143 | } 144 | } 145 | 146 | // DelStats deletes stats 147 | func (s *Stater) DelStats(os ...StatOptions) { 148 | s.m.Lock() 
149 | defer s.m.Unlock() 150 | for _, o := range os { 151 | delete(s.ss, o.Metadata) 152 | } 153 | } 154 | 155 | type AtomicUint64RateStat struct { 156 | last *uint64 157 | v *uint64 158 | } 159 | 160 | func NewAtomicUint64RateStat(v *uint64) *AtomicUint64RateStat { 161 | return &AtomicUint64RateStat{v: v} 162 | } 163 | 164 | func (s *AtomicUint64RateStat) Value(d time.Duration) any { 165 | current := atomic.LoadUint64(s.v) 166 | defer func() { s.last = ¤t }() 167 | if d <= 0 { 168 | return 0.0 169 | } 170 | var last uint64 171 | if s.last != nil { 172 | last = *s.last 173 | } 174 | return float64(current-last) / d.Seconds() 175 | } 176 | 177 | type AtomicDurationPercentageStat struct { 178 | d *AtomicDuration 179 | last *time.Duration 180 | } 181 | 182 | func NewAtomicDurationPercentageStat(d *AtomicDuration) *AtomicDurationPercentageStat { 183 | return &AtomicDurationPercentageStat{d: d} 184 | } 185 | 186 | func (s *AtomicDurationPercentageStat) Value(d time.Duration) any { 187 | current := s.d.Duration() 188 | defer func() { s.last = ¤t }() 189 | if d <= 0 { 190 | return 0.0 191 | } 192 | var last time.Duration 193 | if s.last != nil { 194 | last = *s.last 195 | } 196 | return float64(current-last) / float64(d) * 100 197 | } 198 | 199 | type AtomicDurationAvgStat struct { 200 | count *uint64 201 | d *AtomicDuration 202 | last *time.Duration 203 | lastCount *uint64 204 | } 205 | 206 | func NewAtomicDurationAvgStat(d *AtomicDuration, count *uint64) *AtomicDurationAvgStat { 207 | return &AtomicDurationAvgStat{ 208 | count: count, 209 | d: d, 210 | } 211 | } 212 | 213 | func (s *AtomicDurationAvgStat) Value(_ time.Duration) any { 214 | current := s.d.Duration() 215 | currentCount := atomic.LoadUint64(s.count) 216 | defer func() { 217 | s.last = ¤t 218 | s.lastCount = ¤tCount 219 | }() 220 | var last time.Duration 221 | var lastCount uint64 222 | if s.last != nil { 223 | last = *s.last 224 | } 225 | if s.lastCount != nil { 226 | lastCount = *s.lastCount 227 | } 228 | 
if currentCount-lastCount <= 0 { 229 | return time.Duration(0) 230 | } 231 | return time.Duration(float64(current-last) / float64(currentCount-lastCount)) 232 | } 233 | -------------------------------------------------------------------------------- /stat_test.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "context" 5 | "reflect" 6 | "sync" 7 | "sync/atomic" 8 | "testing" 9 | "time" 10 | ) 11 | 12 | func TestStater(t *testing.T) { 13 | // Update the now function so that it increments by 5s every time stats are computed 14 | var c int64 15 | mc := &sync.Mutex{} // Locks c 16 | nowPrevious := now 17 | defer func() { now = nowPrevious }() 18 | mn := &sync.Mutex{} // Locks nowV 19 | nowV := time.Unix(c*5, 0) 20 | now = func() time.Time { 21 | mn.Lock() 22 | defer mn.Unlock() 23 | return nowV 24 | } 25 | 26 | // Add stats 27 | var u1 uint64 28 | v1 := NewAtomicUint64RateStat(&u1) 29 | m1 := &StatMetadata{Description: "1"} 30 | o1 := StatOptions{Metadata: m1, Valuer: v1} 31 | d2 := NewAtomicDuration(0) 32 | v2 := NewAtomicDurationPercentageStat(d2) 33 | m2 := &StatMetadata{Description: "2"} 34 | o2 := StatOptions{Metadata: m2, Valuer: v2} 35 | d3 := NewAtomicDuration(0) 36 | v3 := NewAtomicDurationAvgStat(d3, &u1) 37 | m3 := &StatMetadata{Description: "3"} 38 | o3 := StatOptions{Metadata: m3, Valuer: v3} 39 | v4 := StatValuerFunc(func(d time.Duration) any { return 42 }) 40 | m4 := &StatMetadata{Description: "4"} 41 | o4 := StatOptions{Metadata: m4, Valuer: v4} 42 | 43 | // First time stats are computed, it actually acts as if stats were being updated 44 | // Second time stats are computed, results are stored and context is cancelled 45 | var ss []StatValue 46 | ctx, cancel := context.WithCancel(context.Background()) 47 | s := NewStater(StaterOptions{ 48 | HandleFunc: func(stats []StatValue) { 49 | mc.Lock() 50 | defer mc.Unlock() 51 | c++ 52 | switch c { 53 | case 1: 54 | atomic.AddUint64(&u1, 
10) 55 | d2.Add(4 * time.Second) 56 | d3.Add(10 * time.Second) 57 | mn.Lock() 58 | nowV = time.Unix(5, 0) 59 | mn.Unlock() 60 | case 2: 61 | ss = stats 62 | cancel() 63 | } 64 | }, 65 | Period: time.Millisecond, 66 | }) 67 | s.AddStats(o1, o2, o3, o4) 68 | s.Start(ctx) 69 | defer s.Stop() 70 | for _, e := range []StatValue{ 71 | {StatMetadata: m1, Value: 2.0}, 72 | {StatMetadata: m2, Value: 80.0}, 73 | {StatMetadata: m3, Value: time.Second}, 74 | {StatMetadata: m4, Value: 42}, 75 | } { 76 | found := false 77 | for _, s := range ss { 78 | if reflect.DeepEqual(s, e) { 79 | found = true 80 | break 81 | } 82 | } 83 | if !found { 84 | t.Fatalf("expected %+v, not found", e) 85 | } 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /sync_test.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "reflect" 8 | "strings" 9 | "sync" 10 | "testing" 11 | "time" 12 | ) 13 | 14 | func TestChan(t *testing.T) { 15 | // Do not process all 16 | c := NewChan(ChanOptions{}) 17 | var o []int 18 | c.Add(func() { 19 | o = append(o, 1) 20 | c.Stop() 21 | }) 22 | c.Add(func() { 23 | o = append(o, 2) 24 | }) 25 | c.Start(context.Background()) 26 | if e, g := 1, len(o); e != g { 27 | t.Fatalf("expected %+v, got %+v", e, g) 28 | } 29 | 30 | // Process all 31 | c = NewChan(ChanOptions{ProcessAll: true}) 32 | o = []int{} 33 | c.Add(func() { 34 | o = append(o, 1) 35 | c.Stop() 36 | }) 37 | c.Add(func() { 38 | o = append(o, 2) 39 | }) 40 | c.Start(context.Background()) 41 | if e, g := 2, len(o); e != g { 42 | t.Fatalf("expected %+v, got %+v", e, g) 43 | } 44 | 45 | // Default order 46 | c = NewChan(ChanOptions{ProcessAll: true}) 47 | o = []int{} 48 | c.Add(func() { 49 | o = append(o, 1) 50 | }) 51 | c.Add(func() { 52 | o = append(o, 2) 53 | c.Stop() 54 | }) 55 | c.Start(context.Background()) 56 | if e := []int{1, 2}; !reflect.DeepEqual(o, e) 
{ 57 | t.Fatalf("expected %+v, got %+v", e, o) 58 | } 59 | 60 | // FILO order 61 | c = NewChan(ChanOptions{ 62 | Order: ChanOrderFILO, 63 | ProcessAll: true, 64 | }) 65 | o = []int{} 66 | c.Add(func() { 67 | o = append(o, 1) 68 | }) 69 | c.Add(func() { 70 | o = append(o, 2) 71 | c.Stop() 72 | }) 73 | c.Start(context.Background()) 74 | if e := []int{2, 1}; !reflect.DeepEqual(o, e) { 75 | t.Fatalf("expected %+v, got %+v", e, o) 76 | } 77 | 78 | // Block when started 79 | c = NewChan(ChanOptions{AddStrategy: ChanAddStrategyBlockWhenStarted}) 80 | o = []int{} 81 | go func() { 82 | c.Add(func() { 83 | o = append(o, 1) 84 | }) 85 | o = append(o, 2) 86 | c.Add(func() { 87 | o = append(o, 3) 88 | }) 89 | o = append(o, 4) 90 | c.Stop() 91 | }() 92 | c.Start(context.Background()) 93 | if e := []int{1, 2, 3, 4}; !reflect.DeepEqual(o, e) { 94 | t.Fatalf("expected %+v, got %+v", e, o) 95 | } 96 | } 97 | 98 | func TestGoroutineLimiter(t *testing.T) { 99 | l := NewGoroutineLimiter(GoroutineLimiterOptions{Max: 2}) 100 | defer l.Close() 101 | m := &sync.Mutex{} 102 | var c, max int 103 | const n = 4 104 | wg := &sync.WaitGroup{} 105 | wg.Add(n) 106 | fn := func() { 107 | defer wg.Done() 108 | defer func() { 109 | m.Lock() 110 | c-- 111 | m.Unlock() 112 | }() 113 | m.Lock() 114 | c++ 115 | if c > max { 116 | max = c 117 | } 118 | m.Unlock() 119 | time.Sleep(time.Millisecond) 120 | } 121 | for idx := 0; idx < n; idx++ { 122 | l.Do(fn) //nolint:errcheck 123 | } 124 | wg.Wait() 125 | if e := 2; e != max { 126 | t.Fatalf("expected %+v, got %+v", e, max) 127 | } 128 | } 129 | 130 | func TestEventer(t *testing.T) { 131 | e := NewEventer(EventerOptions{Chan: ChanOptions{ProcessAll: true}}) 132 | var o []string 133 | e.On("1", func(payload any) { o = append(o, payload.(string)) }) 134 | e.On("2", func(payload any) { o = append(o, payload.(string)) }) 135 | go func() { 136 | time.Sleep(10 * time.Millisecond) 137 | e.Dispatch("1", "1.1") 138 | e.Dispatch("2", "2") 139 | e.Dispatch("1", "1.2") 
140 | e.Stop() 141 | }() 142 | e.Start(context.Background()) 143 | if e := []string{"1.1", "2", "1.2"}; !reflect.DeepEqual(e, o) { 144 | t.Fatalf("expected %+v, got %+v", e, o) 145 | } 146 | } 147 | 148 | type mockedStdLogger struct { 149 | m sync.Mutex 150 | ss []string 151 | } 152 | 153 | func (l *mockedStdLogger) Fatal(v ...any) { 154 | l.m.Lock() 155 | defer l.m.Unlock() 156 | l.ss = append(l.ss, "fatal: "+fmt.Sprint(v...)) 157 | } 158 | func (l *mockedStdLogger) Fatalf(format string, v ...any) { 159 | l.m.Lock() 160 | defer l.m.Unlock() 161 | l.ss = append(l.ss, "fatal: "+fmt.Sprintf(format, v...)) 162 | } 163 | func (l *mockedStdLogger) Print(v ...any) { 164 | l.m.Lock() 165 | defer l.m.Unlock() 166 | l.ss = append(l.ss, "print: "+fmt.Sprint(v...)) 167 | } 168 | func (l *mockedStdLogger) Printf(format string, v ...any) { 169 | l.m.Lock() 170 | defer l.m.Unlock() 171 | l.ss = append(l.ss, "print: "+fmt.Sprintf(format, v...)) 172 | } 173 | 174 | func TestDebugMutex(t *testing.T) { 175 | l := &mockedStdLogger{} 176 | m := NewDebugMutex("test", l, DebugMutexWithDeadlockDetection(time.Millisecond)) 177 | m.Lock() 178 | go func() { 179 | time.Sleep(100 * time.Millisecond) 180 | m.Unlock() 181 | }() 182 | m.Lock() 183 | l.m.Lock() 184 | ss := l.ss 185 | l.m.Unlock() 186 | if e, g := 1, len(ss); e != g { 187 | t.Fatalf("expected %d, got %d", e, g) 188 | } 189 | if s, g := "sync_test.go:177", ss[0]; !strings.Contains(g, s) { 190 | t.Fatalf("%s doesn't contain %s", g, s) 191 | } 192 | if s, g := "sync_test.go:182", ss[0]; !strings.Contains(g, s) { 193 | t.Fatalf("%s doesn't contain %s", g, s) 194 | } 195 | } 196 | 197 | func TestFIFOMutex(t *testing.T) { 198 | m := FIFOMutex{} 199 | var r []int 200 | m.Lock() 201 | wg := sync.WaitGroup{} 202 | testFIFOMutex(1, &m, &r, &wg) 203 | m.Unlock() 204 | wg.Wait() 205 | if e, g := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, r; !reflect.DeepEqual(e, g) { 206 | t.Fatalf("expected %v, got %v", e, g) 207 | } 208 | } 209 | 210 | func 
testFIFOMutex(i int, m *FIFOMutex, r *[]int, wg *sync.WaitGroup) { 211 | wg.Add(1) 212 | go func() { 213 | defer wg.Done() 214 | if i < 10 { 215 | testFIFOMutex(i+1, m, r, wg) 216 | } 217 | m.Lock() 218 | *r = append(*r, i) 219 | m.Unlock() 220 | }() 221 | } 222 | 223 | func TestBufferedBatcher(t *testing.T) { 224 | var count int 225 | var batches []map[any]int 226 | var bb1 *BufferedBatcher 227 | ctx1, cancel1 := context.WithCancel(context.Background()) 228 | defer cancel1() 229 | bb1 = NewBufferedBatcher(BufferedBatcherOptions{OnBatch: func(ctx context.Context, batch []any) { 230 | count++ 231 | if len(batch) > 0 { 232 | m := make(map[any]int) 233 | for _, i := range batch { 234 | m[i]++ 235 | } 236 | batches = append(batches, m) 237 | } 238 | switch count { 239 | case 1: 240 | bb1.Add(1) 241 | bb1.Add(1) 242 | bb1.Add(2) 243 | case 2: 244 | bb1.Add(2) 245 | bb1.Add(2) 246 | bb1.Add(3) 247 | case 3: 248 | bb1.Add(1) 249 | bb1.Add(1) 250 | bb1.Add(2) 251 | bb1.Add(2) 252 | bb1.Add(3) 253 | bb1.Add(3) 254 | case 4: 255 | go func() { 256 | time.Sleep(100 * time.Millisecond) 257 | bb1.Add(1) 258 | }() 259 | case 5: 260 | cancel1() 261 | } 262 | }}) 263 | bb1.Add(1) 264 | ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second) 265 | defer cancel2() 266 | go func() { 267 | defer cancel2() 268 | bb1.Start(ctx1) 269 | }() 270 | <-ctx2.Done() 271 | if errors.Is(ctx2.Err(), context.DeadlineExceeded) { 272 | t.Fatal("expected nothing, got timeout") 273 | } 274 | if e, g := []map[any]int{ 275 | {1: 1}, 276 | {1: 1, 2: 1}, 277 | {2: 1, 3: 1}, 278 | {1: 1, 2: 1, 3: 1}, 279 | {1: 1}, 280 | }, batches; !reflect.DeepEqual(e, g) { 281 | t.Fatalf("expected %+v, got %+v", e, g) 282 | } 283 | 284 | var bb2 *BufferedBatcher 285 | bb2 = NewBufferedBatcher(BufferedBatcherOptions{OnBatch: func(ctx context.Context, batch []any) { 286 | bb2.Start(context.Background()) 287 | bb2.Stop() 288 | bb2.Stop() 289 | }}) 290 | bb2.Add(1) 291 | ctx3, cancel3 := 
context.WithTimeout(context.Background(), time.Second) 292 | defer cancel3() 293 | go func() { 294 | defer cancel3() 295 | bb2.Start(context.Background()) 296 | }() 297 | <-ctx3.Done() 298 | if errors.Is(ctx3.Err(), context.DeadlineExceeded) { 299 | t.Fatal("expected nothing, got timeout") 300 | } 301 | } 302 | -------------------------------------------------------------------------------- /template.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path/filepath" 7 | "strings" 8 | "sync" 9 | "text/template" 10 | ) 11 | 12 | // Templater represents an object capable of storing and parsing templates 13 | type Templater struct { 14 | layouts []string 15 | m sync.Mutex 16 | templates map[string]*template.Template 17 | } 18 | 19 | // NewTemplater creates a new templater 20 | func NewTemplater() *Templater { 21 | return &Templater{templates: make(map[string]*template.Template)} 22 | } 23 | 24 | // AddLayoutsFromDir walks through a dir and add files as layouts 25 | func (t *Templater) AddLayoutsFromDir(dirPath, ext string) (err error) { 26 | // Make sure to clean dir path so that we get consistent path separator with filepath.Walk 27 | dirPath = filepath.Clean(dirPath) 28 | 29 | // Get layouts 30 | if err = filepath.Walk(dirPath, func(path string, info os.FileInfo, e error) (err error) { 31 | // Check input error 32 | if e != nil { 33 | err = fmt.Errorf("astikit: walking layouts has an input error for path %s: %w", path, e) 34 | return 35 | } 36 | 37 | // Only process files 38 | if info.IsDir() { 39 | return 40 | } 41 | 42 | // Check extension 43 | if ext != "" && filepath.Ext(path) != ext { 44 | return 45 | } 46 | 47 | // Read layout 48 | var b []byte 49 | if b, err = os.ReadFile(path); err != nil { 50 | err = fmt.Errorf("astikit: reading %s failed: %w", path, err) 51 | return 52 | } 53 | 54 | // Add layout 55 | t.AddLayout(string(b)) 56 | return 57 | }); err != nil { 58 | 
err = fmt.Errorf("astikit: walking layouts in %s failed: %w", dirPath, err) 59 | return 60 | } 61 | return 62 | } 63 | 64 | // AddTemplatesFromDir walks through a dir and add files as templates 65 | func (t *Templater) AddTemplatesFromDir(dirPath, ext string) (err error) { 66 | // Make sure to clean dir path so that we get consistent path separator with filepath.Walk 67 | dirPath = filepath.Clean(dirPath) 68 | 69 | // Loop through templates 70 | if err = filepath.Walk(dirPath, func(path string, info os.FileInfo, e error) (err error) { 71 | // Check input error 72 | if e != nil { 73 | err = fmt.Errorf("astikit: walking templates has an input error for path %s: %w", path, e) 74 | return 75 | } 76 | 77 | // Only process files 78 | if info.IsDir() { 79 | return 80 | } 81 | 82 | // Check extension 83 | if ext != "" && filepath.Ext(path) != ext { 84 | return 85 | } 86 | 87 | // Read file 88 | var b []byte 89 | if b, err = os.ReadFile(path); err != nil { 90 | err = fmt.Errorf("astikit: reading template content of %s failed: %w", path, err) 91 | return 92 | } 93 | 94 | // Add template 95 | // We use ToSlash to homogenize Windows path 96 | if err = t.AddTemplate(filepath.ToSlash(strings.TrimPrefix(path, dirPath)), string(b)); err != nil { 97 | err = fmt.Errorf("astikit: adding template failed: %w", err) 98 | return 99 | } 100 | return 101 | }); err != nil { 102 | err = fmt.Errorf("astikit: walking templates in %s failed: %w", dirPath, err) 103 | return 104 | } 105 | return 106 | } 107 | 108 | // AddLayout adds a new layout 109 | func (t *Templater) AddLayout(c string) { 110 | t.layouts = append(t.layouts, c) 111 | } 112 | 113 | // AddTemplate adds a new template 114 | func (t *Templater) AddTemplate(path, content string) (err error) { 115 | // Parse 116 | var tpl *template.Template 117 | if tpl, err = t.Parse(content); err != nil { 118 | err = fmt.Errorf("astikit: parsing template for path %s failed: %w", path, err) 119 | return 120 | } 121 | 122 | // Add template 123 | 
t.m.Lock() 124 | t.templates[path] = tpl 125 | t.m.Unlock() 126 | return 127 | } 128 | 129 | // DelTemplate deletes a template 130 | func (t *Templater) DelTemplate(path string) { 131 | t.m.Lock() 132 | defer t.m.Unlock() 133 | delete(t.templates, path) 134 | } 135 | 136 | // Template retrieves a templates 137 | func (t *Templater) Template(path string) (tpl *template.Template, ok bool) { 138 | t.m.Lock() 139 | defer t.m.Unlock() 140 | tpl, ok = t.templates[path] 141 | return 142 | } 143 | 144 | // Parse parses the content of a template 145 | func (t *Templater) Parse(content string) (o *template.Template, err error) { 146 | // Parse content 147 | o = template.New("root") 148 | if o, err = o.Parse(content); err != nil { 149 | err = fmt.Errorf("astikit: parsing template content failed: %w", err) 150 | return 151 | } 152 | 153 | // Parse layouts 154 | for idx, l := range t.layouts { 155 | if o, err = o.Parse(l); err != nil { 156 | err = fmt.Errorf("astikit: parsing layout #%d failed: %w", idx+1, err) 157 | return 158 | } 159 | } 160 | return 161 | } 162 | -------------------------------------------------------------------------------- /template_test.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | ) 7 | 8 | func TestTemplater(t *testing.T) { 9 | tp := NewTemplater() 10 | if err := tp.AddLayoutsFromDir("testdata/template/layouts", ".html"); err != nil { 11 | t.Fatalf("expected no error, got %+v", err) 12 | } 13 | if e, g := 2, len(tp.layouts); e != g { 14 | t.Fatalf("expected %v, got %v", e, g) 15 | } 16 | if err := tp.AddTemplatesFromDir("testdata/template/templates", ".html"); err != nil { 17 | t.Fatalf("expected no error, got %+v", err) 18 | } 19 | if e, g := 2, len(tp.templates); e != g { 20 | t.Fatalf("expected %v, got %v", e, g) 21 | } 22 | tp.DelTemplate("/dir/template2.html") 23 | if e, g := 1, len(tp.templates); e != g { 24 | t.Fatalf("expected %v, got %v", 
e, g) 25 | } 26 | v, ok := tp.Template("/template1.html") 27 | if !ok { 28 | t.Fatal("no template found") 29 | } 30 | w := &bytes.Buffer{} 31 | if err := v.Execute(w, nil); err != nil { 32 | t.Fatalf("expected no error, got %+v", err) 33 | } 34 | if e, g := "Layout - Template", w.String(); g != e { 35 | t.Fatalf("expected %s, got %s", e, g) 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /testdata/archive/d/f: -------------------------------------------------------------------------------- 1 | 1 -------------------------------------------------------------------------------- /testdata/archive/f: -------------------------------------------------------------------------------- 1 | 0 -------------------------------------------------------------------------------- /testdata/ipc/f: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asticode/go-astikit/3eedad89c3d8a32a45fc28263a522fcc4fc5deff/testdata/ipc/f -------------------------------------------------------------------------------- /testdata/os/d/d1/f11: -------------------------------------------------------------------------------- 1 | 2 -------------------------------------------------------------------------------- /testdata/os/d/d2/d21/f211: -------------------------------------------------------------------------------- 1 | 4 -------------------------------------------------------------------------------- /testdata/os/d/d2/f21: -------------------------------------------------------------------------------- 1 | 3 -------------------------------------------------------------------------------- /testdata/os/d/f1: -------------------------------------------------------------------------------- 1 | 1 -------------------------------------------------------------------------------- /testdata/os/f: -------------------------------------------------------------------------------- 1 | 0 
-------------------------------------------------------------------------------- /testdata/ssh/f: -------------------------------------------------------------------------------- 1 | 0 -------------------------------------------------------------------------------- /testdata/template/layouts/dir/layout2.html: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asticode/go-astikit/3eedad89c3d8a32a45fc28263a522fcc4fc5deff/testdata/template/layouts/dir/layout2.html -------------------------------------------------------------------------------- /testdata/template/layouts/dummy.css: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asticode/go-astikit/3eedad89c3d8a32a45fc28263a522fcc4fc5deff/testdata/template/layouts/dummy.css -------------------------------------------------------------------------------- /testdata/template/layouts/layout1.html: -------------------------------------------------------------------------------- 1 | {{ define "layout" }}Layout - {{ template "template" . 
}}{{ end }} -------------------------------------------------------------------------------- /testdata/template/templates/dir/template2.html: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asticode/go-astikit/3eedad89c3d8a32a45fc28263a522fcc4fc5deff/testdata/template/templates/dir/template2.html -------------------------------------------------------------------------------- /testdata/template/templates/dummy.css: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asticode/go-astikit/3eedad89c3d8a32a45fc28263a522fcc4fc5deff/testdata/template/templates/dummy.css -------------------------------------------------------------------------------- /testdata/template/templates/template1.html: -------------------------------------------------------------------------------- 1 | {{ define "template" }}Template{{ end }}{{ template "layout" . }} -------------------------------------------------------------------------------- /testdata/translator/d1/d2/en.json: -------------------------------------------------------------------------------- 1 | {"6":"6"} -------------------------------------------------------------------------------- /testdata/translator/d1/en.json: -------------------------------------------------------------------------------- 1 | {"5":"5"} -------------------------------------------------------------------------------- /testdata/translator/en.json: -------------------------------------------------------------------------------- 1 | {"1":"1","2":{"3":"3"},"f":"f%sf"} -------------------------------------------------------------------------------- /testdata/translator/fr.json: -------------------------------------------------------------------------------- 1 | {"4":"4"} -------------------------------------------------------------------------------- /testdata/translator/invalid.csv: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/asticode/go-astikit/3eedad89c3d8a32a45fc28263a522fcc4fc5deff/testdata/translator/invalid.csv -------------------------------------------------------------------------------- /time.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "context" 5 | "encoding" 6 | "encoding/json" 7 | "fmt" 8 | "io" 9 | "strconv" 10 | "strings" 11 | "time" 12 | ) 13 | 14 | // Sleep is a cancellable sleep 15 | func Sleep(ctx context.Context, d time.Duration) (err error) { 16 | for { 17 | select { 18 | case <-time.After(d): 19 | return 20 | case <-ctx.Done(): 21 | err = ctx.Err() 22 | return 23 | } 24 | } 25 | } 26 | 27 | var now = time.Now 28 | 29 | func Now() time.Time { 30 | return now() 31 | } 32 | 33 | type mockedNow struct { 34 | previous func() time.Time 35 | } 36 | 37 | func newMockedNow() *mockedNow { 38 | return &mockedNow{previous: now} 39 | } 40 | 41 | func (m *mockedNow) Close() error { 42 | now = m.previous 43 | return nil 44 | } 45 | 46 | func MockNow(fn func() time.Time) io.Closer { 47 | m := newMockedNow() 48 | now = fn 49 | return m 50 | } 51 | 52 | var ( 53 | _ encoding.TextMarshaler = (*Timestamp)(nil) 54 | _ encoding.TextUnmarshaler = (*Timestamp)(nil) 55 | _ json.Marshaler = (*Timestamp)(nil) 56 | _ json.Unmarshaler = (*Timestamp)(nil) 57 | ) 58 | 59 | type Timestamp struct { 60 | time.Time 61 | } 62 | 63 | func NewTimestamp(t time.Time) *Timestamp { 64 | return &Timestamp{Time: t} 65 | } 66 | 67 | func (t *Timestamp) UnmarshalJSON(text []byte) error { 68 | return t.UnmarshalText(text) 69 | } 70 | 71 | func (t *Timestamp) UnmarshalText(text []byte) (err error) { 72 | var i int 73 | if i, err = strconv.Atoi(string(text)); err != nil { 74 | return 75 | } 76 | t.Time = time.Unix(int64(i), 0) 77 | return 78 | } 79 | 80 | func (t Timestamp) MarshalJSON() ([]byte, error) { 81 | return 
t.MarshalText() 82 | } 83 | 84 | func (t Timestamp) MarshalText() (text []byte, err error) { 85 | text = []byte(strconv.Itoa(int(t.UTC().Unix()))) 86 | return 87 | } 88 | 89 | var ( 90 | _ encoding.TextMarshaler = (*TimestampNano)(nil) 91 | _ encoding.TextUnmarshaler = (*TimestampNano)(nil) 92 | _ json.Marshaler = (*TimestampNano)(nil) 93 | _ json.Unmarshaler = (*TimestampNano)(nil) 94 | ) 95 | 96 | type TimestampNano struct { 97 | time.Time 98 | } 99 | 100 | func NewTimestampNano(t time.Time) *TimestampNano { 101 | return &TimestampNano{Time: t} 102 | } 103 | 104 | func (t *TimestampNano) UnmarshalJSON(text []byte) error { 105 | return t.UnmarshalText(text) 106 | } 107 | 108 | func (t *TimestampNano) UnmarshalText(text []byte) (err error) { 109 | var i int 110 | if i, err = strconv.Atoi(string(text)); err != nil { 111 | return 112 | } 113 | t.Time = time.Unix(0, int64(i)) 114 | return 115 | } 116 | 117 | func (t TimestampNano) MarshalJSON() ([]byte, error) { 118 | return t.MarshalText() 119 | } 120 | 121 | func (t TimestampNano) MarshalText() (text []byte, err error) { 122 | text = []byte(strconv.Itoa(int(t.UTC().UnixNano()))) 123 | return 124 | } 125 | 126 | var ( 127 | _ json.Marshaler = (*Stopwatch)(nil) 128 | _ json.Unmarshaler = (*Stopwatch)(nil) 129 | ) 130 | 131 | type Stopwatch struct { 132 | children []*Stopwatch 133 | createdAt time.Time 134 | doneAt time.Time 135 | id string 136 | } 137 | 138 | func NewStopwatch() *Stopwatch { 139 | return newStopwatch("") 140 | } 141 | 142 | func newStopwatch(id string) *Stopwatch { 143 | return &Stopwatch{ 144 | createdAt: Now(), 145 | id: id, 146 | } 147 | } 148 | 149 | func (s *Stopwatch) NewChild(id string) *Stopwatch { 150 | // Create stopwatch 151 | dst := newStopwatch(id) 152 | 153 | // Make sure to propagate done to children 154 | s.propagateDone(dst.createdAt) 155 | 156 | // Append 157 | s.children = append(s.children, dst) 158 | return dst 159 | } 160 | 161 | func (s *Stopwatch) propagateDone(doneAt 
time.Time) { 162 | // No children 163 | if len(s.children) == 0 { 164 | return 165 | } 166 | 167 | // Get child 168 | c := s.children[len(s.children)-1] 169 | 170 | // Update done at 171 | if c.doneAt.IsZero() { 172 | c.doneAt = doneAt 173 | } 174 | 175 | // Make sure to propagate done to children 176 | c.propagateDone(doneAt) 177 | } 178 | 179 | func (s *Stopwatch) Done() { 180 | // Update done at 181 | if s.doneAt.IsZero() { 182 | s.doneAt = Now() 183 | } 184 | 185 | // Make sure to propagate done to children 186 | s.propagateDone(s.doneAt) 187 | } 188 | 189 | func (s *Stopwatch) FindChild(id string, nextIDs ...string) (*Stopwatch, bool) { 190 | return s.child(append([]string{id}, nextIDs...)...) 191 | } 192 | 193 | func (s *Stopwatch) child(ids ...string) (*Stopwatch, bool) { 194 | // Loop through ids 195 | for idx, id := range ids { 196 | // Loop through children 197 | for _, c := range s.children { 198 | // Child doesn't match 199 | if c.id != id { 200 | continue 201 | } 202 | 203 | // Last id 204 | if idx == len(ids)-1 { 205 | return c, true 206 | } 207 | return c.child(ids[idx:]...) 208 | } 209 | } 210 | return nil, false 211 | } 212 | 213 | func (s *Stopwatch) Duration() time.Duration { 214 | if !s.doneAt.IsZero() { 215 | return s.doneAt.Sub(s.createdAt) 216 | } 217 | return Now().Sub(s.createdAt) 218 | } 219 | 220 | func (s *Stopwatch) Merge(i *Stopwatch) { 221 | // No children 222 | if len(i.children) == 0 { 223 | return 224 | } 225 | 226 | // Make sure to propagate done to children 227 | s.propagateDone(i.children[0].createdAt) 228 | 229 | // Append 230 | s.children = append(s.children, i.children...) 
231 | } 232 | 233 | func (s *Stopwatch) Dump() string { 234 | return s.dump("", s.createdAt) 235 | } 236 | 237 | func (s *Stopwatch) dump(ident string, rootCreatedAt time.Time) string { 238 | // Dump stopwatch 239 | var ss []string 240 | if ident == "" { 241 | ss = append(ss, DurationMinimalistFormat(s.doneAt.Sub(s.createdAt))) 242 | } else { 243 | ss = append(ss, fmt.Sprintf("%s[%s]%s: %s", ident, DurationMinimalistFormat(s.createdAt.Sub(rootCreatedAt)), s.id, DurationMinimalistFormat(s.doneAt.Sub(s.createdAt)))) 244 | } 245 | 246 | // Loop through children 247 | ident += " " 248 | for _, c := range s.children { 249 | // Dump child 250 | ss = append(ss, c.dump(ident, rootCreatedAt)) 251 | } 252 | return strings.Join(ss, "\n") 253 | } 254 | 255 | type stopwatchJSON struct { 256 | Children []stopwatchJSON `json:"children"` 257 | CreatedAt TimestampNano `json:"created_at"` 258 | DoneAt TimestampNano `json:"done_at"` 259 | ID string `json:"id"` 260 | } 261 | 262 | func (sj stopwatchJSON) toStopwatch(s *Stopwatch) { 263 | s.createdAt = sj.CreatedAt.Time 264 | s.doneAt = sj.DoneAt.Time 265 | s.id = sj.ID 266 | for _, cj := range sj.Children { 267 | c := &Stopwatch{} 268 | cj.toStopwatch(c) 269 | s.children = append(s.children, c) 270 | } 271 | } 272 | 273 | func (s *Stopwatch) toStopwatchJSON() (sj stopwatchJSON) { 274 | sj.Children = []stopwatchJSON{} 275 | sj.CreatedAt = *NewTimestampNano(s.createdAt) 276 | sj.DoneAt = *NewTimestampNano(s.doneAt) 277 | sj.ID = s.id 278 | for _, c := range s.children { 279 | sj.Children = append(sj.Children, c.toStopwatchJSON()) 280 | } 281 | return 282 | } 283 | 284 | func (s *Stopwatch) UnmarshalJSON(text []byte) error { 285 | var j stopwatchJSON 286 | if err := json.Unmarshal(text, &j); err != nil { 287 | return err 288 | } 289 | j.toStopwatch(s) 290 | return nil 291 | } 292 | 293 | func (s Stopwatch) MarshalJSON() ([]byte, error) { 294 | return json.Marshal(s.toStopwatchJSON()) 295 | } 296 | 297 | func DurationMinimalistFormat(d 
time.Duration) string { 298 | if d < time.Microsecond { 299 | return strconv.Itoa(int(d)) + "ns" 300 | } else if d < time.Millisecond { 301 | return strconv.Itoa(int(d/1e3)) + "µs" 302 | } else if d < time.Second { 303 | return strconv.Itoa(int(d/1e6)) + "ms" 304 | } 305 | return strconv.Itoa(int(d/1e9)) + "s" 306 | } 307 | -------------------------------------------------------------------------------- /time_test.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "errors" 8 | "reflect" 9 | "sync" 10 | "testing" 11 | "time" 12 | ) 13 | 14 | func TestSleep(t *testing.T) { 15 | var ctx, cancel = context.WithCancel(context.Background()) 16 | var err error 17 | var wg = &sync.WaitGroup{} 18 | wg.Add(1) 19 | go func() { 20 | defer wg.Done() 21 | err = Sleep(ctx, time.Minute) 22 | }() 23 | cancel() 24 | wg.Wait() 25 | if e, g := context.Canceled, err; !errors.Is(g, e) { 26 | t.Fatalf("err should be %s, got %s", e, g) 27 | } 28 | } 29 | 30 | func TestTimestamp(t *testing.T) { 31 | const j = `{"value":1495290215}` 32 | v := struct { 33 | Value Timestamp `json:"value"` 34 | }{} 35 | err := json.Unmarshal([]byte(j), &v) 36 | if err != nil { 37 | t.Fatalf("err should be nil, got %s", err) 38 | } 39 | if e, g := int64(1495290215), v.Value.Unix(); g != e { 40 | t.Fatalf("timestamp should be %v, got %v", e, g) 41 | } 42 | b, err := json.Marshal(v) 43 | if err != nil { 44 | t.Fatalf("err should be nil, got %s", err) 45 | } 46 | if string(b) != j { 47 | t.Fatalf("json should be %s, got %s", j, b) 48 | } 49 | } 50 | 51 | func isAbsoluteTime(t time.Time) bool { 52 | return t.Year() >= time.Now().Year()-1 53 | } 54 | 55 | func TestNow(t *testing.T) { 56 | if g := Now(); !isAbsoluteTime(Now()) { 57 | t.Fatalf("expected %s to be an absolute time", g) 58 | } 59 | var count int64 60 | m := MockNow(func() time.Time { 61 | count++ 62 | return time.Unix(count, 0) 63 | }) 
64 | if e, g := time.Unix(1, 0), Now(); !reflect.DeepEqual(e, g) { 65 | t.Fatalf("expected %s, got %s", e, g) 66 | } 67 | m.Close() 68 | if g := Now(); !isAbsoluteTime(Now()) { 69 | t.Fatalf("expected %s to be an absolute time", g) 70 | } 71 | } 72 | 73 | func TestTimestampNano(t *testing.T) { 74 | const j = `{"value":1732636645443709000}` 75 | v := struct { 76 | Value TimestampNano `json:"value"` 77 | }{} 78 | err := json.Unmarshal([]byte(j), &v) 79 | if err != nil { 80 | t.Fatalf("err should be nil, got %s", err) 81 | } 82 | if e, g := int64(1732636645443709000), v.Value.UnixNano(); g != e { 83 | t.Fatalf("timestamp should be %v, got %v", e, g) 84 | } 85 | b, err := json.Marshal(v) 86 | if err != nil { 87 | t.Fatalf("err should be nil, got %s", err) 88 | } 89 | if string(b) != j { 90 | t.Fatalf("json should be %s, got %s", j, b) 91 | } 92 | } 93 | 94 | func TestStopwatch(t *testing.T) { 95 | var count int64 96 | defer MockNow(func() time.Time { 97 | count++ 98 | return time.Unix(count, 0) 99 | }).Close() 100 | 101 | s1 := NewStopwatch() 102 | s2 := s1.NewChild("1") 103 | if e, g := 2*time.Second, s1.Duration(); e != g { 104 | t.Fatalf("expected %s, got %s", e, g) 105 | } 106 | s2.Done() 107 | s1.NewChild("2") 108 | s3 := s1.NewChild("3") 109 | s3.NewChild("3-1") 110 | s4 := s3.NewChild("3-2") 111 | s4.NewChild("3-2-1") 112 | s5 := NewStopwatch() 113 | s5.NewChild("3-2-2") 114 | s6 := s5.NewChild("3-2-3") 115 | s7 := s6.NewChild("3-2-3-1") 116 | s5.Done() 117 | s4.Merge(s5) 118 | s3.NewChild("3-3") 119 | s1.NewChild("4") 120 | s1.Done() 121 | if e, g := `16s 122 | [1s]1: 2s 123 | [4s]2: 1s 124 | [5s]3: 10s 125 | [6s]3-1: 1s 126 | [7s]3-2: 7s 127 | [8s]3-2-1: 2s 128 | [10s]3-2-2: 1s 129 | [11s]3-2-3: 2s 130 | [12s]3-2-3-1: 1s 131 | [14s]3-3: 1s 132 | [15s]4: 1s`, s1.Dump(); e != g { 133 | t.Fatalf("expected %s, got %s", e, g) 134 | } 135 | if e, g := 16*time.Second, s1.Duration(); e != g { 136 | t.Fatalf("expected %s, got %s", e, g) 137 | } 138 | b, err := 
s5.MarshalJSON() 139 | if err != nil { 140 | t.Fatalf("expected no error, got %s", err) 141 | } 142 | if e, g := []byte(`{"children":[{"children":[],"created_at":11000000000,"done_at":12000000000,"id":"3-2-2"},{"children":[{"children":[],"created_at":13000000000,"done_at":14000000000,"id":"3-2-3-1"}],"created_at":12000000000,"done_at":14000000000,"id":"3-2-3"}],"created_at":10000000000,"done_at":14000000000,"id":""}`), b; !bytes.Equal(e, g) { 143 | t.Fatalf("expected %s, got %s", e, g) 144 | } 145 | var s8 Stopwatch 146 | if err = s8.UnmarshalJSON(b); err != nil { 147 | t.Fatalf("expected no error, got %s", err) 148 | } 149 | if e, g := *s5, s8; !reflect.DeepEqual(e, g) { 150 | t.Fatalf("expected %+v, got %+v", e, g) 151 | } 152 | s9, ok := s1.FindChild("3") 153 | if !ok { 154 | t.Fatal("expected true, got false") 155 | } 156 | if e, g := s3, s9; e != g { 157 | t.Fatalf("expected %+v, got %+v", e, g) 158 | } 159 | if s9, ok = s1.FindChild("3", "3-2", "3-2-3", "3-2-3-1"); !ok { 160 | t.Fatal("expected true, got false") 161 | } 162 | if e, g := s7, s9; e != g { 163 | t.Fatalf("expected %+v, got %+v", e, g) 164 | } 165 | } 166 | 167 | func TestDurationMinimalistFormat(t *testing.T) { 168 | for _, v := range []struct { 169 | d time.Duration 170 | e string 171 | }{ 172 | { 173 | d: 123 * time.Nanosecond, 174 | e: "123ns", 175 | }, 176 | { 177 | d: 123456 * time.Nanosecond, 178 | e: "123µs", 179 | }, 180 | { 181 | d: 123456789 * time.Nanosecond, 182 | e: "123ms", 183 | }, 184 | { 185 | d: 123456789123 * time.Nanosecond, 186 | e: "123s", 187 | }, 188 | } { 189 | if g := DurationMinimalistFormat(v.d); v.e != g { 190 | t.Fatalf("expected %s, got %s", v.e, g) 191 | } 192 | } 193 | } 194 | -------------------------------------------------------------------------------- /translator.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "net/http" 8 | "os" 9 | 
"path/filepath" 10 | "sort" 11 | "strconv" 12 | "strings" 13 | "sync" 14 | ) 15 | 16 | // Translator represents an object capable of translating stuff 17 | type Translator struct { 18 | defaultLanguage string 19 | m *sync.RWMutex // Lock p 20 | p map[string]string 21 | validLanguages map[string]bool 22 | } 23 | 24 | // TranslatorOptions represents Translator options 25 | type TranslatorOptions struct { 26 | DefaultLanguage string 27 | ValidLanguages []string 28 | } 29 | 30 | // NewTranslator creates a new Translator 31 | func NewTranslator(o TranslatorOptions) (t *Translator) { 32 | t = &Translator{ 33 | defaultLanguage: o.DefaultLanguage, 34 | m: &sync.RWMutex{}, 35 | p: make(map[string]string), 36 | validLanguages: make(map[string]bool), 37 | } 38 | for _, l := range o.ValidLanguages { 39 | t.validLanguages[l] = true 40 | } 41 | return 42 | } 43 | 44 | // ParseDir adds translations located in ".json" files in the specified dir 45 | // If ".json" files are located in child dirs, keys will be prefixed with their paths 46 | func (t *Translator) ParseDir(dirPath string) (err error) { 47 | // Default dir path 48 | if dirPath == "" { 49 | if dirPath, err = os.Getwd(); err != nil { 50 | err = fmt.Errorf("astikit: getwd failed: %w", err) 51 | return 52 | } 53 | } 54 | 55 | // Make sure to clean dir path so that we get consistent path separator with filepath.Walk 56 | dirPath = filepath.Clean(dirPath) 57 | 58 | // Walk through dir 59 | if err = filepath.Walk(dirPath, func(path string, info os.FileInfo, e error) (err error) { 60 | // Check input error 61 | if e != nil { 62 | err = fmt.Errorf("astikit: walking %s has an input error for path %s: %w", dirPath, path, e) 63 | return 64 | } 65 | 66 | // Only process files 67 | if info.IsDir() { 68 | return 69 | } 70 | 71 | // Only process ".json" files 72 | if filepath.Ext(path) != ".json" { 73 | return 74 | } 75 | 76 | // Parse file 77 | if err = t.ParseFile(dirPath, path); err != nil { 78 | err = fmt.Errorf("astikit: parsing 
%s failed: %w", path, err) 79 | return 80 | } 81 | return 82 | }); err != nil { 83 | err = fmt.Errorf("astikit: walking %s failed: %w", dirPath, err) 84 | return 85 | } 86 | return 87 | } 88 | 89 | // ParseFile adds translation located in the provided path 90 | func (t *Translator) ParseFile(dirPath, path string) (err error) { 91 | // Lock 92 | t.m.Lock() 93 | defer t.m.Unlock() 94 | 95 | // Open file 96 | var f *os.File 97 | if f, err = os.Open(path); err != nil { 98 | err = fmt.Errorf("astikit: opening %s failed: %w", path, err) 99 | return 100 | } 101 | defer f.Close() 102 | 103 | // Unmarshal 104 | var p map[string]any 105 | if err = json.NewDecoder(f).Decode(&p); err != nil { 106 | err = fmt.Errorf("astikit: unmarshaling %s failed: %w", path, err) 107 | return 108 | } 109 | 110 | // Get language 111 | language := strings.TrimSuffix(filepath.Base(path), filepath.Ext(path)) 112 | 113 | // Update valid languages 114 | t.validLanguages[language] = true 115 | 116 | // Get prefix 117 | prefix := language 118 | if dp := filepath.Dir(path); dp != dirPath { 119 | var fs []string 120 | for _, v := range strings.Split(strings.TrimPrefix(dp, dirPath), string(os.PathSeparator)) { 121 | if v != "" { 122 | fs = append(fs, v) 123 | } 124 | } 125 | prefix += "." + strings.Join(fs, ".") 126 | } 127 | 128 | // Parse 129 | t.parse(p, prefix) 130 | return 131 | } 132 | 133 | func (t *Translator) key(prefix, key string) string { 134 | return prefix + "." 
+ key 135 | } 136 | 137 | func (t *Translator) parse(i map[string]any, prefix string) { 138 | for k, v := range i { 139 | p := t.key(prefix, k) 140 | switch a := v.(type) { 141 | case string: 142 | t.p[p] = a 143 | case map[string]any: 144 | t.parse(a, p) 145 | } 146 | } 147 | } 148 | 149 | // HTTPMiddleware is the Translator HTTP middleware 150 | func (t *Translator) HTTPMiddleware(h http.Handler) http.Handler { 151 | return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { 152 | // Store language in context 153 | if l := r.Header.Get("Accept-Language"); l != "" { 154 | *r = *r.WithContext(contextWithTranslatorLanguage(r.Context(), t.parseAcceptLanguage(l))) 155 | } 156 | 157 | // Next handler 158 | h.ServeHTTP(rw, r) 159 | }) 160 | } 161 | 162 | func (t *Translator) parseAcceptLanguage(h string) string { 163 | // Split on comma 164 | var qs []float64 165 | ls := make(map[float64][]string) 166 | for _, c := range strings.Split(strings.TrimSpace(h), ",") { 167 | // Empty 168 | c = strings.TrimSpace(c) 169 | if c == "" { 170 | continue 171 | } 172 | 173 | // Split on semi colon 174 | ss := strings.Split(c, ";") 175 | 176 | // Parse coefficient 177 | q := float64(1) 178 | if len(ss) > 1 { 179 | s := strings.TrimSpace(ss[1]) 180 | if strings.HasPrefix(s, "q=") { 181 | var err error 182 | if q, err = strconv.ParseFloat(strings.TrimPrefix(s, "q="), 64); err != nil { 183 | q = 1 184 | } 185 | } 186 | } 187 | 188 | // Add 189 | if _, ok := ls[q]; !ok { 190 | qs = append(qs, q) 191 | } 192 | ls[q] = append(ls[q], strings.TrimSpace(ss[0])) 193 | } 194 | 195 | // Order coefficients 196 | sort.Float64s(qs) 197 | 198 | // Loop through coefficients in reverse order 199 | for idx := len(qs) - 1; idx >= 0; idx-- { 200 | for _, l := range ls[qs[idx]] { 201 | if _, ok := t.validLanguages[l]; ok { 202 | return l 203 | } 204 | } 205 | } 206 | return "" 207 | } 208 | 209 | const contextKeyTranslatorLanguage = contextKey("astikit.translator.language") 210 | 211 | type 
contextKey string 212 | 213 | func contextWithTranslatorLanguage(ctx context.Context, language string) context.Context { 214 | return context.WithValue(ctx, contextKeyTranslatorLanguage, language) 215 | } 216 | 217 | func translatorLanguageFromContext(ctx context.Context) string { 218 | v, ok := ctx.Value(contextKeyTranslatorLanguage).(string) 219 | if !ok { 220 | return "" 221 | } 222 | return v 223 | } 224 | 225 | func (t *Translator) language(language string) string { 226 | if language == "" { 227 | return t.defaultLanguage 228 | } 229 | return language 230 | } 231 | 232 | // LanguageCtx returns the translator language from the context, or the default language if not in the context 233 | func (t *Translator) LanguageCtx(ctx context.Context) string { 234 | return t.language(translatorLanguageFromContext(ctx)) 235 | } 236 | 237 | // Translate translates a key into a specific language 238 | func (t *Translator) Translate(language, key string) string { 239 | // Lock 240 | t.m.RLock() 241 | defer t.m.RUnlock() 242 | 243 | // Get translation 244 | k1 := t.key(t.language(language), key) 245 | v, ok := t.p[k1] 246 | if ok { 247 | return v 248 | } 249 | 250 | // Default translation 251 | k2 := t.key(t.defaultLanguage, key) 252 | if v, ok = t.p[k2]; ok { 253 | return v 254 | } 255 | return k1 256 | } 257 | 258 | // Translatef translates a key into a specific language with optional formatting args 259 | func (t *Translator) Translatef(language, key string, args ...any) string { 260 | return fmt.Sprintf(t.Translate(language, key), args...) 
261 | } 262 | 263 | // TranslateCtx is an alias for TranslateC 264 | func (t *Translator) TranslateCtx(ctx context.Context, key string) string { 265 | return t.TranslateC(ctx, key) 266 | } 267 | 268 | // TranslateC translates a key using the language specified in the context 269 | func (t *Translator) TranslateC(ctx context.Context, key string) string { 270 | return t.Translate(translatorLanguageFromContext(ctx), key) 271 | } 272 | 273 | func (t *Translator) TranslateCf(ctx context.Context, key string, args ...any) string { 274 | return t.Translatef(translatorLanguageFromContext(ctx), key, args...) 275 | } 276 | -------------------------------------------------------------------------------- /translator_test.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "net/http" 5 | "net/http/httptest" 6 | "reflect" 7 | "strings" 8 | "testing" 9 | ) 10 | 11 | func TestTranslator(t *testing.T) { 12 | // Setup 13 | tl := NewTranslator(TranslatorOptions{DefaultLanguage: "fr"}) 14 | 15 | // Parse dir 16 | err := tl.ParseDir("testdata/translator") 17 | if err != nil { 18 | t.Fatalf("expected no error, got %v", err) 19 | } 20 | if e := map[string]string{ 21 | "en.1": "1", 22 | "en.2.3": "3", 23 | "en.d1.5": "5", 24 | "en.d1.d2.6": "6", 25 | "en.f": "f%sf", 26 | "fr.4": "4", 27 | }; !reflect.DeepEqual(e, tl.p) { 28 | t.Fatalf("expected %+v, got %+v", e, tl.p) 29 | } 30 | 31 | // Middleware 32 | var o string 33 | s := httptest.NewServer(ChainHTTPMiddlewares(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { 34 | var args []any 35 | if v := r.Header.Get("args"); v != "" { 36 | for _, s := range strings.Split(v, ",") { 37 | args = append(args, s) 38 | } 39 | } 40 | if len(args) > 0 { 41 | o = tl.TranslateCf(r.Context(), r.Header.Get("key"), args...) 
42 | } else { 43 | o = tl.TranslateC(r.Context(), r.Header.Get("key")) 44 | } 45 | }), tl.HTTPMiddleware)) 46 | defer s.Close() 47 | 48 | // Translate 49 | for _, v := range []struct { 50 | args []string 51 | expected string 52 | key string 53 | language string 54 | }{ 55 | { 56 | expected: "4", 57 | key: "4", 58 | }, 59 | { 60 | expected: "fr.1", 61 | key: "1", 62 | }, 63 | { 64 | expected: "3", 65 | key: "2.3", 66 | language: "en-US,en;q=0.8", 67 | }, 68 | { 69 | expected: "4", 70 | key: "4", 71 | language: "en", 72 | }, 73 | { 74 | expected: "en.5", 75 | key: "5", 76 | language: "en", 77 | }, 78 | { 79 | expected: "6", 80 | key: "d1.d2.6", 81 | language: "en", 82 | }, 83 | { 84 | expected: "4", 85 | key: "4", 86 | language: "it", 87 | }, 88 | { 89 | args: []string{"arg"}, 90 | expected: "fargf", 91 | key: "f", 92 | language: "en", 93 | }, 94 | } { 95 | r, err := http.NewRequest(http.MethodGet, s.URL, nil) 96 | if err != nil { 97 | t.Fatalf("expected no error, got %+v", err) 98 | } 99 | if len(v.args) > 0 { 100 | r.Header.Set("args", strings.Join(v.args, ",")) 101 | } 102 | r.Header.Set("key", v.key) 103 | if v.language != "" { 104 | r.Header.Set("Accept-Language", v.language) 105 | } 106 | _, err = http.DefaultClient.Do(r) 107 | if err != nil { 108 | t.Fatalf("expected no error, got %+v", err) 109 | } 110 | if !reflect.DeepEqual(v.expected, o) { 111 | t.Fatalf("expected %+v, got %+v", v.expected, o) 112 | } 113 | } 114 | } 115 | 116 | func TestTranslator_ParseAcceptLanguage(t *testing.T) { 117 | tl := NewTranslator(TranslatorOptions{ValidLanguages: []string{"en", "fr"}}) 118 | if e, g := "", tl.parseAcceptLanguage(""); !reflect.DeepEqual(e, g) { 119 | t.Fatalf("expected %+v, got %+v", e, g) 120 | } 121 | if e, g := "fr", tl.parseAcceptLanguage(" fr-FR, fr ; q=0.9 ,en;q=0.7,en-US;q=0.8 "); !reflect.DeepEqual(e, g) { 122 | t.Fatalf("expected %+v, got %+v", e, g) 123 | } 124 | } 125 | -------------------------------------------------------------------------------- 
/worker.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "context" 5 | "os" 6 | "os/signal" 7 | "sync" 8 | ) 9 | 10 | // Worker represents an object capable of blocking, handling signals and stopping 11 | type Worker struct { 12 | cancel context.CancelFunc 13 | ctx context.Context 14 | l SeverityLogger 15 | os, ow sync.Once 16 | wg *sync.WaitGroup 17 | } 18 | 19 | // WorkerOptions represents worker options 20 | type WorkerOptions struct { 21 | Logger StdLogger 22 | } 23 | 24 | // NewWorker builds a new worker 25 | func NewWorker(o WorkerOptions) (w *Worker) { 26 | w = &Worker{ 27 | l: AdaptStdLogger(o.Logger), 28 | wg: &sync.WaitGroup{}, 29 | } 30 | w.ctx, w.cancel = context.WithCancel(context.Background()) 31 | w.wg.Add(1) 32 | w.l.Info("astikit: starting worker...") 33 | return 34 | } 35 | 36 | // HandleSignals handles signals 37 | func (w *Worker) HandleSignals(hs ...SignalHandler) { 38 | // Prepend mandatory handler 39 | hs = append([]SignalHandler{TermSignalHandler(w.Stop)}, hs...) 
40 | 41 | // Notify 42 | ch := make(chan os.Signal, 1) 43 | signal.Notify(ch) 44 | 45 | // Execute in a task 46 | w.NewTask().Do(func() { 47 | for { 48 | select { 49 | case s := <-ch: 50 | // Loop through handlers 51 | for _, h := range hs { 52 | h(s) 53 | } 54 | 55 | // Return 56 | if isTermSignal(s) { 57 | return 58 | } 59 | case <-w.Context().Done(): 60 | return 61 | } 62 | } 63 | }) 64 | } 65 | 66 | // Stop stops the Worker 67 | func (w *Worker) Stop() { 68 | w.os.Do(func() { 69 | w.l.Info("astikit: stopping worker...") 70 | w.cancel() 71 | w.wg.Done() 72 | }) 73 | } 74 | 75 | // Wait is a blocking pattern 76 | func (w *Worker) Wait() { 77 | w.ow.Do(func() { 78 | w.l.Info("astikit: worker is now waiting...") 79 | w.wg.Wait() 80 | }) 81 | } 82 | 83 | // NewTask creates a new task 84 | func (w *Worker) NewTask() *Task { 85 | return newTask(w.wg) 86 | } 87 | 88 | // Context returns the worker's context 89 | func (w *Worker) Context() context.Context { 90 | return w.ctx 91 | } 92 | 93 | // Logger returns the worker's logger 94 | func (w *Worker) Logger() SeverityLogger { 95 | return w.l 96 | } 97 | 98 | // TaskFunc represents a function that can create a new task 99 | type TaskFunc func() *Task 100 | 101 | // Task represents a task 102 | type Task struct { 103 | od, ow sync.Once 104 | wg, pwg *sync.WaitGroup 105 | } 106 | 107 | func newTask(parentWg *sync.WaitGroup) (t *Task) { 108 | t = &Task{ 109 | wg: &sync.WaitGroup{}, 110 | pwg: parentWg, 111 | } 112 | t.pwg.Add(1) 113 | return 114 | } 115 | 116 | // NewSubTask creates a new sub task 117 | func (t *Task) NewSubTask() *Task { 118 | return newTask(t.wg) 119 | } 120 | 121 | // Do executes the task 122 | func (t *Task) Do(f func()) { 123 | go func() { 124 | // Make sure to mark the task as done 125 | defer t.Done() 126 | 127 | // Custom 128 | f() 129 | 130 | // Wait for first level subtasks to be done 131 | // Wait() can also be called in f() if something needs to be executed just after Wait() 132 | t.Wait() 133 | 
}() 134 | } 135 | 136 | // Done indicates the task is done 137 | func (t *Task) Done() { 138 | t.od.Do(func() { 139 | t.pwg.Done() 140 | }) 141 | } 142 | 143 | // Wait waits for first level subtasks to be finished 144 | func (t *Task) Wait() { 145 | t.ow.Do(func() { 146 | t.wg.Wait() 147 | }) 148 | } 149 | -------------------------------------------------------------------------------- /worker_test.go: -------------------------------------------------------------------------------- 1 | package astikit 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | ) 7 | 8 | func TestWorker(t *testing.T) { 9 | w := NewWorker(WorkerOptions{}) 10 | ts := w.NewTask() 11 | var o []int 12 | ts.Do(func() { 13 | w.Stop() 14 | o = append(o, 1) 15 | }) 16 | w.Wait() 17 | o = append(o, 2) 18 | if e := []int{1, 2}; !reflect.DeepEqual(o, e) { 19 | t.Fatalf("expected %+v, got %+v", e, o) 20 | } 21 | } 22 | --------------------------------------------------------------------------------