├── License ├── README.md ├── go.mod ├── image.png ├── reader.go ├── reader_test.go ├── shared_test.go ├── writer.go └── writer_test.go /License: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Logan Spears 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RowBoat 2 | 3 | [![Go Reference](https://pkg.go.dev/badge/github.com/notnil/rowboat.svg)](https://pkg.go.dev/github.com/notnil/rowboat) 4 | [![Go Report Card](https://goreportcard.com/badge/github.com/notnil/rowboat)](https://goreportcard.com/report/github.com/notnil/rowboat) 5 | [![License](https://img.shields.io/badge/license-MIT-blue.svg)](License) 6 | 7 | 8 | Build Status 9 | 10 | 11 | RowBoat is a Go package that provides a simple and efficient way to read from and write to CSV files using Go's generics. It leverages struct tags and reflection to map CSV headers to struct fields, making it easy to work with CSV data in a type-safe manner. 12 | 13 | ## Features 14 | 15 | - **Generic CSV Reader and Writer**: Read CSV rows into custom structs, and write structs back out as CSV, using Go generics. 16 | - **Struct Field Mapping**: Automatically maps CSV headers to struct fields based on field names or `csv` tags. 17 | - **Custom Marshaling and Unmarshaling**: Support for custom types that implement the `CSVMarshaler` and `CSVUnmarshaler` interfaces. 18 | - **Field Indexing**: Control the order of columns in CSV output using `index` in struct tags. 19 | - **Support for Basic Types**: Handles basic Go types including `string`, `int`, `float64`, `bool`, and `time.Time`. 20 | 21 | ## Installation 22 | 23 | ```bash 24 | go get github.com/notnil/rowboat 25 | ``` 26 | 27 | ## Usage 28 | 29 | ### Defining Your Struct 30 | 31 | Define a struct that represents the CSV data. Use struct tags to specify CSV headers and, if needed, column indexes.
32 | 33 | ```go 34 | package main 35 | 36 | import ( 37 | "time" 38 | ) 39 | 40 | // Person is one CSV row; each csv tag names the header column for its field. 41 | type Person struct { 42 | Name string `csv:"Name"` 43 | Email string `csv:"Email"` 44 | Age int `csv:"Age"` 45 | JoinedAt time.Time `csv:"JoinedAt"` 46 | } 47 | ``` 48 | 49 | ### Reading CSV Data 50 | 51 | Create a `Reader` instance and read CSV data into your struct. 52 | 53 | ```go 54 | // main.go 55 | 56 | package main 57 | 58 | import ( 59 | "fmt" 60 | "os" 61 | "github.com/notnil/rowboat" 62 | "slices" 63 | ) 64 | 65 | func main() { 66 | // Open your CSV file 67 | file, err := os.Open("people.csv") 68 | if err != nil { 69 | panic(err) 70 | } 71 | defer file.Close() 72 | 73 | // Create a new Reader instance 74 | rb, err := rowboat.NewReader[Person](file) 75 | if err != nil { 76 | panic(err) 77 | } 78 | 79 | // Collect all records (the iterator panics if a row cannot be read or decoded) 80 | people := slices.Collect(rb.All()) 81 | 82 | // Use the data 83 | for _, person := range people { 84 | fmt.Printf("%+v\n", person) 85 | } 86 | } 87 | ``` 88 | 89 | ### Writing CSV Data 90 | 91 | Create a `Writer` instance and write your struct data to a CSV file.
92 | 93 | ```go 94 | // main.go 95 | 96 | package main 97 | 98 | import ( 99 | "os" 100 | 101 | "github.com/notnil/rowboat" 102 | ) 103 | 104 | func main() { 105 | // Open a file for writing 106 | file, err := os.Create("people_output.csv") 107 | if err != nil { 108 | panic(err) 109 | } 110 | defer file.Close() 111 | 112 | // Create a new Writer instance 113 | writer, err := rowboat.NewWriter[Person](file) 114 | if err != nil { 115 | panic(err) 116 | } 117 | 118 | // Sample data 119 | people := []Person{ 120 | {Name: "Alice", Email: "alice@example.com", Age: 30}, 121 | {Name: "Bob", Email: "bob@example.com", Age: 25}, 122 | } 123 | 124 | if err := writer.WriteHeader(); err != nil { 125 | panic(err) 126 | } 127 | 128 | // Write all records 129 | for _, person := range people { 130 | if err := writer.Write(person); err != nil { 131 | panic(err) 132 | } 133 | } 134 | } 135 | ``` 136 | 137 | ## Advanced Features 138 | 139 | ### Custom Unmarshaling 140 | 141 | If you have custom types, implement the `CSVUnmarshaler` interface to define how to parse CSV strings. 142 | 143 | ```go 144 | // point.go 145 | 146 | package main 147 | 148 | import ( 149 | "fmt" 150 | "strconv" 151 | "strings" 152 | ) 153 | 154 | type Point struct { 155 | X, Y float64 156 | } 157 | 158 | func (p *Point) UnmarshalCSV(value string) error { 159 | parts := strings.Split(value, ";") 160 | if len(parts) != 2 { 161 | return fmt.Errorf("invalid point format") 162 | } 163 | x, err := strconv.ParseFloat(parts[0], 64) 164 | if err != nil { 165 | return err 166 | } 167 | y, err := strconv.ParseFloat(parts[1], 64) 168 | if err != nil { 169 | return err 170 | } 171 | p.X = x 172 | p.Y = y 173 | return nil 174 | } 175 | ``` 176 | 177 | ### Custom Marshaling 178 | 179 | For writing custom types, implement the `CSVMarshaler` interface. 
180 | 181 | ```go 182 | // point.go 183 | 184 | func (p Point) MarshalCSV() (string, error) { 185 | return fmt.Sprintf("%.2f;%.2f", p.X, p.Y), nil 186 | } 187 | ``` 188 | 189 | ### Field Indexing 190 | 191 | Control the order of columns in the CSV output using the `index` tag option. 192 | 193 | ```go 194 | type IndexedPerson struct { 195 | Age int `csv:"Age,index=0"` 196 | Name string `csv:"Name,index=1"` 197 | Email string `csv:"Email,index=2"` 198 | } 199 | ``` 200 | 201 | ## Examples 202 | 203 | ### Reading with Filters 204 | 205 | Use the `Filter` function to keep only the records that match a predicate. 206 | 207 | ```go 208 | // main.go 209 | 210 | package main 211 | 212 | import ( 213 | "fmt" 214 | "slices" 215 | "strings" 216 | 217 | "github.com/notnil/rowboat" 218 | ) 219 | 220 | func main() { 221 | csvData := `Name,Email,Age 222 | Alice,alice@example.com,30 223 | Bob,bob@example.com,25 224 | Charlie,charlie@example.com,35` 225 | 226 | reader := strings.NewReader(csvData) 227 | rb, err := rowboat.NewReader[Person](reader) 228 | if err != nil { 229 | panic(err) 230 | } 231 | 232 | adults := slices.Collect(rowboat.Filter(func(p Person) bool { 233 | return p.Age >= 30 234 | }, rb.All())) 235 | 236 | fmt.Println(adults) 237 | } 238 | ``` 239 | 240 | ### Writing All Records from an Iterator 241 | 242 | ```go 243 | // main.go 244 | 245 | package main 246 | 247 | import ( 248 | "os" 249 | "slices" 250 | 251 | "github.com/notnil/rowboat" 252 | ) 253 | 254 | // WriteAll accepts any iter.Seq[Person]; slices.Values adapts a slice. 255 | func main() { 256 | file, err := os.Create("people_output.csv") 257 | if err != nil { 258 | panic(err) 259 | } 260 | defer file.Close() 261 | 262 | writer, err := rowboat.NewWriter[Person](file) 263 | if err != nil { 264 | panic(err) 265 | } 266 | 267 | people := []Person{ 268 | {Name: "Alice", Email: "alice@example.com", Age: 30}, 269 | {Name: "Bob", Email: "bob@example.com", Age: 25}, 270 | } 271 | 272 | if err := writer.WriteHeader(); err != nil { 273 | panic(err) 274 | } 275 | 276 | // Write all records 277 | if err := 
writer.WriteAll(slices.Values(people)); err != nil { 278 | panic(err) 279 | } 280 | } 281 | ``` 282 | 283 | ## Struct Tag Details 284 | 285 | - **`csv:"ColumnName"`**: Specifies the CSV header name for the field. 286 | - **`csv:"-"`**: Skips the field; it will not be read from or written to CSV. 287 | - **`index=N`**: Sets the index (order) of the field in the CSV. Lower indexes come first. 288 | 289 | ## Custom Types Interface Definitions 290 | 291 | ```go 292 | // reader.go 293 | 294 | type CSVUnmarshaler interface { 295 | UnmarshalCSV(string) error 296 | } 297 | ``` 298 | 299 | ```go 300 | // writer.go 301 | 302 | type CSVMarshaler interface { 303 | MarshalCSV() (string, error) 304 | } 305 | ``` 306 | 307 | ## License 308 | 309 | This project is licensed under the MIT License. 310 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/notnil/rowboat 2 | 3 | go 1.23.2 4 | -------------------------------------------------------------------------------- /image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notnil/rowboat/525daa4903c88f86713f57815425bf1b143453f7/image.png -------------------------------------------------------------------------------- /reader.go: -------------------------------------------------------------------------------- 1 | package rowboat 2 | 3 | import ( 4 | "encoding/csv" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "iter" 9 | "reflect" 10 | "sort" 11 | "strconv" 12 | "strings" 13 | "time" 14 | ) 15 | 16 | // CSVUnmarshaler is an interface for custom CSV unmarshaling 17 | type CSVUnmarshaler interface { 18 | UnmarshalCSV(string) error 19 | } 20 | 21 | // Reader struct holds the CSV reader and mapping information 22 | type Reader[T any] struct { 23 | reader *csv.Reader 24 | headers []string 25 | fieldMap map[int]reflect.StructField 26 | err 
error 27 | current T 28 | } 29 | 30 | // NewReader creates a new RowBoat reader instance 31 | func NewReader[T any](r io.Reader) (*Reader[T], error) { 32 | rb := &Reader[T]{} 33 | rb.reader = csv.NewReader(r) 34 | 35 | // Read headers 36 | headers, err := rb.reader.Read() 37 | if err != nil { 38 | return nil, err 39 | } 40 | rb.headers = headers 41 | 42 | // Map CSV headers to struct fields 43 | if err := rb.createFieldMap(); err != nil { 44 | return nil, err 45 | } 46 | 47 | return rb, nil 48 | } 49 | 50 | // createFieldMap maps CSV headers to struct fields using struct tags 51 | func (rb *Reader[T]) createFieldMap() error { 52 | rb.fieldMap = make(map[int]reflect.StructField) 53 | 54 | var t T 55 | tType := reflect.TypeOf(t) 56 | if tType.Kind() != reflect.Struct { 57 | return errors.New("generic type T must be a struct") 58 | } 59 | 60 | // First pass: collect fields and their indexes 61 | type fieldInfo struct { 62 | field reflect.StructField 63 | name string 64 | index int 65 | } 66 | fields := make([]fieldInfo, 0, tType.NumField()) 67 | maxIndex := -1 68 | 69 | for i := 0; i < tType.NumField(); i++ { 70 | field := tType.Field(i) 71 | csvTag := field.Tag.Get("csv") 72 | if csvTag == "-" { 73 | continue 74 | } 75 | 76 | name := field.Name 77 | index := i // default index is field order 78 | tagParts := strings.Split(csvTag, ",") 79 | if len(tagParts) > 0 && tagParts[0] != "" { 80 | name = tagParts[0] 81 | } 82 | 83 | for _, part := range tagParts[1:] { 84 | part = strings.TrimSpace(part) 85 | if strings.HasPrefix(part, "index=") { 86 | idxStr := strings.TrimPrefix(part, "index=") 87 | idx, err := strconv.Atoi(idxStr) 88 | if err != nil { 89 | return fmt.Errorf("invalid index value '%s' in field '%s': %v", idxStr, field.Name, err) 90 | } 91 | index = idx 92 | if index > maxIndex { 93 | maxIndex = index 94 | } 95 | } 96 | } 97 | 98 | fields = append(fields, fieldInfo{ 99 | field: field, 100 | name: name, 101 | index: index, 102 | }) 103 | } 104 | 105 | // Assign 
indexes to fields without explicit index 106 | nextIndex := maxIndex + 1 107 | for i := range fields { 108 | if fields[i].index == fields[i].field.Index[0] { 109 | fields[i].index = nextIndex 110 | nextIndex++ 111 | } 112 | } 113 | 114 | // Sort fields by index 115 | sort.Slice(fields, func(i, j int) bool { 116 | return fields[i].index < fields[j].index 117 | }) 118 | 119 | // Map headers to fields 120 | headerMap := make(map[string]int) 121 | for i, header := range rb.headers { 122 | headerMap[strings.TrimSpace(header)] = i 123 | } 124 | 125 | // Create final field mapping 126 | for _, fi := range fields { 127 | if idx, ok := headerMap[fi.name]; ok { 128 | rb.fieldMap[idx] = fi.field 129 | } 130 | } 131 | 132 | return nil 133 | } 134 | 135 | // nextRow advances the iterator and parses the next record 136 | func (rb *Reader[T]) nextRow() bool { 137 | record, err := rb.reader.Read() 138 | if err == io.EOF { 139 | return false 140 | } 141 | if err != nil { 142 | rb.err = err 143 | return false 144 | } 145 | 146 | var t T 147 | tValue := reflect.ValueOf(&t).Elem() 148 | 149 | for idx, value := range record { 150 | if field, ok := rb.fieldMap[idx]; ok { 151 | fieldValue := tValue.FieldByName(field.Name) 152 | if !fieldValue.CanSet() { 153 | continue 154 | } 155 | if err := setFieldValue(fieldValue, value); err != nil { 156 | rb.err = fmt.Errorf("error setting field %s: %w", field.Name, err) 157 | return false 158 | } 159 | } 160 | } 161 | rb.current = t 162 | return true 163 | } 164 | 165 | // setFieldValue sets the value of a struct field based on its type 166 | func setFieldValue(field reflect.Value, value string) error { 167 | csvUnmarshalerType := reflect.TypeOf((*CSVUnmarshaler)(nil)).Elem() 168 | 169 | // Check if the field implements CSVUnmarshaler 170 | if field.CanInterface() && field.Type().Implements(csvUnmarshalerType) { 171 | unmarshaler := field.Interface().(CSVUnmarshaler) 172 | return unmarshaler.UnmarshalCSV(value) 173 | } 174 | 175 | // Check if the 
pointer to the field implements CSVUnmarshaler 176 | if field.CanAddr() && field.Addr().CanInterface() && field.Addr().Type().Implements(csvUnmarshalerType) { 177 | unmarshaler := field.Addr().Interface().(CSVUnmarshaler) 178 | return unmarshaler.UnmarshalCSV(value) 179 | } 180 | 181 | // Handle specific types like time.Time 182 | if field.Type() == reflect.TypeOf(time.Time{}) { 183 | t, err := time.Parse(time.RFC3339, value) 184 | if err != nil { 185 | return err 186 | } 187 | field.Set(reflect.ValueOf(t)) 188 | return nil 189 | } 190 | 191 | // Handle basic kinds 192 | switch field.Kind() { 193 | case reflect.String: 194 | field.SetString(value) 195 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 196 | intValue, err := strconv.ParseInt(value, 10, 64) 197 | if err != nil { 198 | return err 199 | } 200 | field.SetInt(intValue) 201 | case reflect.Float32, reflect.Float64: 202 | floatValue, err := strconv.ParseFloat(value, 64) 203 | if err != nil { 204 | return err 205 | } 206 | field.SetFloat(floatValue) 207 | case reflect.Bool: 208 | boolValue, err := strconv.ParseBool(value) 209 | if err != nil { 210 | return err 211 | } 212 | field.SetBool(boolValue) 213 | default: 214 | return fmt.Errorf("unsupported field type: %s", field.Type()) 215 | } 216 | return nil 217 | } 218 | 219 | // All returns an iterator over all records in the CSV file. 220 | // Each iteration returns a parsed struct of type T. 
221 | func (rb *Reader[T]) All() iter.Seq[T] { 222 | return func(yield func(T) bool) { 223 | // Keep reading records until we hit EOF or an error 224 | for rb.nextRow() { 225 | // Get the current record 226 | record := rb.current 227 | 228 | // Pass to yield function - if it returns false, stop iteration 229 | if !yield(record) { 230 | return 231 | } 232 | } 233 | 234 | // Check if we stopped due to an error 235 | if rb.err != nil && rb.err != io.EOF { 236 | // We can't return an error directly from the iterator, 237 | // but we can panic which will be caught by the range loop 238 | panic(rb.err) 239 | } 240 | } 241 | } 242 | 243 | // Filter returns a sequence that contains the elements 244 | // of s for which f returns true. 245 | func Filter[V any](f func(V) bool, s iter.Seq[V]) iter.Seq[V] { 246 | return func(yield func(V) bool) { 247 | for v := range s { 248 | if f(v) { 249 | if !yield(v) { 250 | return 251 | } 252 | } 253 | } 254 | } 255 | } 256 | -------------------------------------------------------------------------------- /reader_test.go: -------------------------------------------------------------------------------- 1 | package rowboat_test 2 | 3 | import ( 4 | "reflect" 5 | "slices" 6 | "strings" 7 | "testing" 8 | "time" 9 | 10 | "github.com/notnil/rowboat" 11 | ) 12 | 13 | func TestRowBoat(t *testing.T) { 14 | // CSV data as a string 15 | csvData := `Name,Email,Age 16 | Alice,alice@example.com,30 17 | Bob,bob@example.com,25 18 | Charlie,charlie@example.com,35` 19 | 20 | // Create a reader from the CSV data string 21 | reader := strings.NewReader(csvData) 22 | 23 | // Create a new RowBoat instance 24 | rb, err := rowboat.NewReader[Person](reader) 25 | if err != nil { 26 | t.Fatalf("Failed to create RowBoat: %v", err) 27 | } 28 | 29 | // Expected results 30 | expected := []Person{ 31 | {Name: "Alice", Email: "alice@example.com", Age: 30}, 32 | {Name: "Charlie", Email: "charlie@example.com", Age: 35}, 33 | } 34 | results := 
slices.Collect(rowboat.Filter(func(p Person) bool { 35 | return p.Age > 25 36 | }, rb.All())) 37 | 38 | // Compare results 39 | if !reflect.DeepEqual(results, expected) { 40 | t.Errorf("Parsed results do not match expected.\nExpected: %+v\nGot: %+v", expected, results) 41 | } 42 | } 43 | 44 | func TestExtraneousColumns(t *testing.T) { 45 | // CSV data with extra columns that aren't in the struct 46 | csvData := `Name,Email,Age,ExtraCol1,ExtraCol2 47 | Alice,alice@example.com,30,unused1,unused2 48 | Bob,bob@example.com,25,unused3,unused4` 49 | 50 | reader := strings.NewReader(csvData) 51 | 52 | // Create a new RowBoat instance 53 | rb, err := rowboat.NewReader[Person](reader) 54 | if err != nil { 55 | t.Fatalf("Failed to create RowBoat: %v", err) 56 | } 57 | 58 | // Expected results - should ignore the extra columns 59 | expected := []Person{ 60 | {Name: "Alice", Email: "alice@example.com", Age: 30}, 61 | {Name: "Bob", Email: "bob@example.com", Age: 25}, 62 | } 63 | 64 | results := slices.Collect(rb.All()) 65 | 66 | // Compare results 67 | if !reflect.DeepEqual(results, expected) { 68 | t.Errorf("Parsed results do not match expected.\nExpected: %+v\nGot: %+v", expected, results) 69 | } 70 | } 71 | 72 | func TestBlankRows(t *testing.T) { 73 | // CSV data with blank rows 74 | csvData := `Name,Email,Age 75 | 76 | Alice,alice@example.com,30 77 | 78 | Bob,bob@example.com,25 79 | ` 80 | reader := strings.NewReader(csvData) 81 | 82 | // Create a new RowBoat instance 83 | rb, err := rowboat.NewReader[Person](reader) 84 | if err != nil { 85 | t.Fatalf("Failed to create RowBoat: %v", err) 86 | } 87 | 88 | // Expected results - should skip blank rows 89 | expected := []Person{ 90 | {Name: "Alice", Email: "alice@example.com", Age: 30}, 91 | {Name: "Bob", Email: "bob@example.com", Age: 25}, 92 | } 93 | 94 | results := slices.Collect(rb.All()) 95 | 96 | // Compare results 97 | if !reflect.DeepEqual(results, expected) { 98 | t.Errorf("Parsed results do not match 
expected.\nExpected: %+v\nGot: %+v", expected, results) 99 | } 100 | } 101 | 102 | func TestComplexTypes(t *testing.T) { 103 | // CSV data with various types 104 | csvData := `name,created_at,active,score,count,rate,tags 105 | John,2023-01-02T15:04:05Z,true,98.6,42,3.14,test;debug 106 | Jane,2023-06-15T09:30:00Z,false,75.2,100,2.718,prod;live` 107 | 108 | reader := strings.NewReader(csvData) 109 | 110 | rb, err := rowboat.NewReader[ComplexRecord](reader) 111 | if err != nil { 112 | t.Fatalf("Failed to create RowBoat: %v", err) 113 | } 114 | 115 | // Parse expected time values 116 | t1, _ := time.Parse(time.RFC3339, "2023-01-02T15:04:05Z") 117 | t2, _ := time.Parse(time.RFC3339, "2023-06-15T09:30:00Z") 118 | 119 | expected := []ComplexRecord{ 120 | { 121 | Name: "John", 122 | CreatedAt: t1, 123 | Active: true, 124 | Score: 98.6, 125 | Count: 42, 126 | Rate: 3.14, 127 | Tags: "test;debug", 128 | }, 129 | { 130 | Name: "Jane", 131 | CreatedAt: t2, 132 | Active: false, 133 | Score: 75.2, 134 | Count: 100, 135 | Rate: 2.718, 136 | Tags: "prod;live", 137 | }, 138 | } 139 | 140 | results := slices.Collect(rb.All()) 141 | 142 | if !reflect.DeepEqual(results, expected) { 143 | t.Errorf("Parsed results do not match expected.\nExpected: %+v\nGot: %+v", expected, results) 144 | } 145 | } 146 | 147 | func TestCustomUnmarshaler(t *testing.T) { 148 | csvData := `point 149 | 1;2 150 | 3;4` 151 | reader := strings.NewReader(csvData) 152 | 153 | rb, err := rowboat.NewReader[Custom](reader) 154 | if err != nil { 155 | t.Fatalf("Failed to create RowBoat: %v", err) 156 | } 157 | 158 | results := slices.Collect(rb.All()) 159 | 160 | expected := []Custom{ 161 | {Point: Point{X: 1, Y: 2}}, 162 | {Point: Point{X: 3, Y: 4}}, 163 | } 164 | 165 | if !reflect.DeepEqual(results, expected) { 166 | t.Errorf("Parsed results do not match expected.\nExpected: %+v\nGot: %+v", expected, results) 167 | } 168 | } 169 | -------------------------------------------------------------------------------- 
/shared_test.go: -------------------------------------------------------------------------------- 1 | package rowboat_test 2 | 3 | import ( 4 | "fmt" 5 | "strconv" 6 | "strings" 7 | "time" 8 | ) 9 | 10 | type Person struct { 11 | Name string `csv:"Name"` 12 | Email string `csv:"Email"` 13 | Age int `csv:"Age"` 14 | } 15 | 16 | type ComplexRecord struct { 17 | Name string `csv:"name"` 18 | CreatedAt time.Time `csv:"created_at"` 19 | Active bool `csv:"active"` 20 | Score float64 `csv:"score"` 21 | Count int `csv:"count"` 22 | Rate float32 `csv:"rate"` 23 | Tags string `csv:"tags"` 24 | } 25 | 26 | type Point struct { 27 | X float64 `csv:"x"` 28 | Y float64 `csv:"y"` 29 | } 30 | 31 | func (p *Point) UnmarshalCSV(value string) error { 32 | parts := strings.Split(value, ";") 33 | if len(parts) != 2 { 34 | return fmt.Errorf("invalid point format: %s", value) 35 | } 36 | var err error 37 | p.X, err = strconv.ParseFloat(parts[0], 64) 38 | if err != nil { 39 | return err 40 | } 41 | p.Y, err = strconv.ParseFloat(parts[1], 64) 42 | return err 43 | } 44 | 45 | func (p Point) MarshalCSV() (string, error) { 46 | return fmt.Sprintf("%.2f;%.2f", p.X, p.Y), nil 47 | } 48 | 49 | type Custom struct { 50 | Point Point `csv:"point"` 51 | } 52 | -------------------------------------------------------------------------------- /writer.go: -------------------------------------------------------------------------------- 1 | package rowboat 2 | 3 | import ( 4 | "encoding/csv" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "iter" 9 | "reflect" 10 | "sort" 11 | "strconv" 12 | "strings" 13 | "time" 14 | ) 15 | 16 | // CSVMarshaler is an interface for custom CSV marshaling 17 | type CSVMarshaler interface { 18 | MarshalCSV() (string, error) 19 | } 20 | 21 | // fieldInfo contains information about a struct field for CSV writing 22 | type fieldInfo struct { 23 | Index int 24 | Name string 25 | Field reflect.StructField 26 | } 27 | 28 | // Writer struct holds the CSV writer and mapping information 29 | type 
Writer[T any] struct { 30 | writer *csv.Writer 31 | fields []fieldInfo 32 | } 33 | 34 | // NewWriter creates a new RowBoat writer instance 35 | func NewWriter[T any](w io.Writer) (*Writer[T], error) { 36 | rw := &Writer[T]{} 37 | rw.writer = csv.NewWriter(w) 38 | 39 | // Analyze the struct fields 40 | if err := rw.createFieldInfo(); err != nil { 41 | return nil, err 42 | } 43 | 44 | return rw, nil 45 | } 46 | 47 | func (rw *Writer[T]) WriteHeader() error { 48 | headers := make([]string, len(rw.fields)) 49 | for i, fi := range rw.fields { 50 | headers[i] = fi.Name 51 | } 52 | if err := rw.writer.Write(headers); err != nil { 53 | return err 54 | } 55 | rw.writer.Flush() 56 | return nil 57 | } 58 | 59 | // Write writes a single record to the CSV writer 60 | func (rw *Writer[T]) Write(record T) error { 61 | recordValues := make([]string, len(rw.fields)) 62 | v := reflect.ValueOf(record) 63 | for i, fi := range rw.fields { 64 | fieldValue := v.FieldByName(fi.Field.Name) 65 | strValue, err := getFieldStringValue(fieldValue) 66 | if err != nil { 67 | return fmt.Errorf("error marshaling field %s: %w", fi.Field.Name, err) 68 | } 69 | recordValues[i] = strValue 70 | } 71 | 72 | if err := rw.writer.Write(recordValues); err != nil { 73 | return err 74 | } 75 | rw.writer.Flush() 76 | return nil 77 | } 78 | 79 | // WriteAll writes multiple records from an iterator 80 | func (rw *Writer[T]) WriteAll(records iter.Seq[T]) error { 81 | var err error 82 | records(func(record T) bool { 83 | if err = rw.Write(record); err != nil { 84 | return false 85 | } 86 | return true 87 | }) 88 | return err 89 | } 90 | 91 | // createFieldInfo extracts information about struct fields, including indexes 92 | func (rw *Writer[T]) createFieldInfo() error { 93 | var t T 94 | tType := reflect.TypeOf(t) 95 | if tType.Kind() != reflect.Struct { 96 | return errors.New("generic type T must be a struct") 97 | } 98 | 99 | fields := make([]fieldInfo, 0, tType.NumField()) 100 | maxIndex := -1 101 | 102 | for i 
:= 0; i < tType.NumField(); i++ { 103 | field := tType.Field(i) 104 | csvTag := field.Tag.Get("csv") 105 | if csvTag == "-" { 106 | continue // skip field 107 | } 108 | 109 | name := field.Name 110 | index := i // default index is the field order 111 | tagParts := strings.Split(csvTag, ",") 112 | if len(tagParts) > 0 && tagParts[0] != "" { 113 | name = tagParts[0] 114 | } 115 | 116 | for _, part := range tagParts[1:] { 117 | part = strings.TrimSpace(part) 118 | if strings.HasPrefix(part, "index=") { 119 | idxStr := strings.TrimPrefix(part, "index=") 120 | idx, err := strconv.Atoi(idxStr) 121 | if err != nil { 122 | return fmt.Errorf("invalid index value '%s' in field '%s': %v", idxStr, field.Name, err) 123 | } 124 | index = idx 125 | if index > maxIndex { 126 | maxIndex = index 127 | } 128 | } 129 | } 130 | 131 | fields = append(fields, fieldInfo{ 132 | Index: index, 133 | Name: name, 134 | Field: field, 135 | }) 136 | } 137 | 138 | // Assign indexes to fields without an explicit index, starting from maxIndex+1 139 | nextIndex := maxIndex + 1 140 | for i := range fields { 141 | if fields[i].Index == fields[i].Field.Index[0] { // Field's default index 142 | fields[i].Index = nextIndex 143 | nextIndex++ 144 | } 145 | } 146 | 147 | // Sort the fields based on the index 148 | sort.Slice(fields, func(i, j int) bool { 149 | return fields[i].Index < fields[j].Index 150 | }) 151 | 152 | rw.fields = fields 153 | return nil 154 | } 155 | 156 | // getFieldStringValue converts a struct field value to string for CSV 157 | func getFieldStringValue(field reflect.Value) (string, error) { 158 | csvMarshalerType := reflect.TypeOf((*CSVMarshaler)(nil)).Elem() 159 | 160 | // Check if the field implements CSVMarshaler 161 | if field.CanInterface() && field.Type().Implements(csvMarshalerType) { 162 | marshaler := field.Interface().(CSVMarshaler) 163 | return marshaler.MarshalCSV() 164 | } 165 | 166 | // Check if the pointer to the field implements CSVMarshaler 167 | if field.CanAddr() 
&& field.Addr().CanInterface() && field.Addr().Type().Implements(csvMarshalerType) { 168 | marshaler := field.Addr().Interface().(CSVMarshaler) 169 | return marshaler.MarshalCSV() 170 | } 171 | 172 | // Handle specific types like time.Time 173 | if field.Type() == reflect.TypeOf(time.Time{}) { 174 | t := field.Interface().(time.Time) 175 | return t.Format(time.RFC3339), nil 176 | } 177 | 178 | // Handle basic kinds 179 | switch field.Kind() { 180 | case reflect.String: 181 | return field.String(), nil 182 | case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 183 | return strconv.FormatInt(field.Int(), 10), nil 184 | case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 185 | return strconv.FormatUint(field.Uint(), 10), nil 186 | case reflect.Float32, reflect.Float64: 187 | return strconv.FormatFloat(field.Float(), 'f', -1, 64), nil 188 | case reflect.Bool: 189 | return strconv.FormatBool(field.Bool()), nil 190 | default: 191 | return "", fmt.Errorf("unsupported field type: %s", field.Type()) 192 | } 193 | } 194 | -------------------------------------------------------------------------------- /writer_test.go: -------------------------------------------------------------------------------- 1 | package rowboat_test 2 | 3 | import ( 4 | "bytes" 5 | "iter" 6 | "reflect" 7 | "slices" 8 | "strings" 9 | "testing" 10 | "time" 11 | 12 | "github.com/notnil/rowboat" 13 | ) 14 | 15 | func TestWriter(t *testing.T) { 16 | // Create test data 17 | people := []Person{ 18 | {Name: "Alice", Email: "alice@example.com", Age: 30}, 19 | {Name: "Bob", Email: "bob@example.com", Age: 25}, 20 | {Name: "Charlie", Email: "charlie@example.com", Age: 35}, 21 | } 22 | 23 | // Create a buffer to write CSV data 24 | var buf bytes.Buffer 25 | 26 | // Create a new Writer instance 27 | writer, err := rowboat.NewWriter[Person](&buf) 28 | if err != nil { 29 | t.Fatalf("Failed to create Writer: %v", err) 30 | } 31 | 32 | if err := 
writer.WriteHeader(); err != nil { 33 | t.Fatalf("Failed to write header: %v", err) 34 | } 35 | 36 | // Write all records 37 | for _, person := range people { 38 | if err := writer.Write(person); err != nil { 39 | t.Fatalf("Failed to write record: %v", err) 40 | } 41 | } 42 | 43 | // Get the written CSV data 44 | csvData := buf.String() 45 | 46 | // Create a reader to verify the written data 47 | reader := strings.NewReader(csvData) 48 | rb, err := rowboat.NewReader[Person](reader) 49 | if err != nil { 50 | t.Fatalf("Failed to create Reader: %v", err) 51 | } 52 | 53 | // Read back the data and compare 54 | results := slices.Collect(rb.All()) 55 | 56 | if !reflect.DeepEqual(results, people) { 57 | t.Errorf("Written results do not match expected.\nExpected: %+v\nGot: %+v", people, results) 58 | } 59 | } 60 | // TestComplexWriter round-trips ComplexRecord values (time, bool, numeric and string fields) through Writer and Reader and checks the decoded records equal the originals. 61 | func TestComplexWriter(t *testing.T) { 62 | // Create test data with complex types 63 | t1 := time.Date(2023, 1, 2, 15, 4, 5, 0, time.UTC) 64 | t2 := time.Date(2023, 6, 15, 9, 30, 0, 0, time.UTC) 65 | 66 | records := []ComplexRecord{ 67 | { 68 | Name: "John", 69 | CreatedAt: t1, 70 | Active: true, 71 | Score: 98.6, 72 | Count: 42, 73 | Rate: 3.14, 74 | Tags: "test;debug", 75 | }, 76 | { 77 | Name: "Jane", 78 | CreatedAt: t2, 79 | Active: false, 80 | Score: 75.2, 81 | Count: 100, 82 | Rate: 2.718, 83 | Tags: "prod;live", 84 | }, 85 | } 86 | 87 | var buf bytes.Buffer 88 | writer, err := rowboat.NewWriter[ComplexRecord](&buf) 89 | if err != nil { 90 | t.Fatalf("Failed to create Writer: %v", err) 91 | } 92 | 93 | if err := writer.WriteHeader(); err != nil { 94 | t.Fatalf("Failed to write header: %v", err) 95 | } 96 | 97 | for _, record := range records { 98 | if err := writer.Write(record); err != nil { 99 | t.Fatalf("Failed to write record: %v", err) 100 | } 101 | } 102 | 103 | // Read back and verify 104 | reader := strings.NewReader(buf.String()) 105 | rb, err := rowboat.NewReader[ComplexRecord](reader) 106 | if err != nil { 107 | t.Fatalf("Failed to create 
Reader: %v", err) 108 | } 109 | 110 | results := slices.Collect(rb.All()) 111 | 112 | if !reflect.DeepEqual(results, records) { 113 | t.Errorf("Written results do not match expected.\nExpected: %+v\nGot: %+v", records, results) 114 | } 115 | } 116 | // TestWriteAll checks that Writer.WriteAll, called after WriteHeader, writes every record yielded by an iter.Seq, and that the output reads back as the original records. 117 | func TestWriteAll(t *testing.T) { 118 | t1 := time.Date(2023, 1, 2, 15, 4, 5, 0, time.UTC) 119 | t2 := time.Date(2023, 6, 15, 9, 30, 0, 0, time.UTC) 120 | 121 | records := []ComplexRecord{ 122 | { 123 | Name: "John", 124 | CreatedAt: t1, 125 | Active: true, 126 | Score: 98.6, 127 | Count: 42, 128 | Rate: 3.14, 129 | Tags: "test;debug", 130 | }, 131 | { 132 | Name: "Jane", 133 | CreatedAt: t2, 134 | Active: false, 135 | Score: 75.2, 136 | Count: 100, 137 | Rate: 2.718, 138 | Tags: "prod;live", 139 | }, 140 | } 141 | 142 | var buf bytes.Buffer 143 | writer, err := rowboat.NewWriter[ComplexRecord](&buf) 144 | if err != nil { 145 | t.Fatalf("Failed to create Writer: %v", err) 146 | } 147 | 148 | // Adapt the slice to an iter.Seq for WriteAll. NOTE(review): slices.Values(records) yields the same sequence and could replace this closure. 149 | recordIter := iter.Seq[ComplexRecord](func(yield func(ComplexRecord) bool) { 150 | for _, record := range records { 151 | if !yield(record) { 152 | return 153 | } 154 | } 155 | }) 156 | 157 | if err := writer.WriteHeader(); err != nil { 158 | t.Fatalf("Failed to write header: %v", err) 159 | } 160 | 161 | if err := writer.WriteAll(recordIter); err != nil { 162 | t.Fatalf("Failed to write records: %v", err) 163 | } 164 | 165 | // Read back and verify 166 | reader := strings.NewReader(buf.String()) 167 | rb, err := rowboat.NewReader[ComplexRecord](reader) 168 | if err != nil { 169 | t.Fatalf("Failed to create Reader: %v", err) 170 | } 171 | 172 | results := slices.Collect(rb.All()) 173 | 174 | if !reflect.DeepEqual(results, records) { 175 | t.Errorf("Written results do not match expected.\nExpected: %+v\nGot: %+v", records, results) 176 | } 177 | } 178 | // TestCustomMarshaler checks that Point values are written in their custom X;Y form under the point header, comparing the output byte-for-byte against the expected CSV. 179 | func TestCustomMarshaler(t *testing.T) { 180 | records := []Custom{ 181 | { 182 | Point: Point{X: 1.23, Y: 4.56}, 183 | 
}, 184 | { 185 | Point: Point{X: 7.89, Y: 0.12}, 186 | }, 187 | } 188 | 189 | var buf bytes.Buffer 190 | writer, err := rowboat.NewWriter[Custom](&buf) 191 | if err != nil { 192 | t.Fatalf("Failed to create Writer: %v", err) 193 | } 194 | 195 | // Adapt the slice to an iter.Seq for WriteAll. NOTE(review): slices.Values(records) yields the same sequence and could replace this closure. 196 | recordIter := iter.Seq[Custom](func(yield func(Custom) bool) { 197 | for _, record := range records { 198 | if !yield(record) { 199 | return 200 | } 201 | } 202 | }) 203 | 204 | if err := writer.WriteHeader(); err != nil { 205 | t.Fatalf("Failed to write header: %v", err) 206 | } 207 | 208 | if err := writer.WriteAll(recordIter); err != nil { 209 | t.Fatalf("Failed to write records: %v", err) 210 | } 211 | 212 | expected := "point\n1.23;4.56\n7.89;0.12\n" 213 | if buf.String() != expected { 214 | t.Errorf("Written CSV does not match expected.\nExpected:\n%s\nGot:\n%s", expected, buf.String()) 215 | } 216 | } 217 | // TestWriterWithIndexing checks that index= options in csv struct tags set the column order (Age, Name, Email) of the written CSV, and that the output reads back to the original people. 218 | func TestWriterWithIndexing(t *testing.T) { 219 | // Create test data 220 | people := []Person{ 221 | {Name: "Alice", Email: "alice@example.com", Age: 30}, 222 | {Name: "Bob", Email: "bob@example.com", Age: 25}, 223 | {Name: "Charlie", Email: "charlie@example.com", Age: 35}, 224 | } 225 | 226 | // IndexedPerson mirrors Person but uses index= tag options to put Age first, then Name, then Email. 227 | type IndexedPerson struct { 228 | Name string `csv:"Name,index=1"` 229 | Email string `csv:"Email,index=2"` 230 | Age int `csv:"Age,index=0"` 231 | } 232 | 233 | // Convert people to IndexedPerson 234 | indexedPeople := make([]IndexedPerson, len(people)) 235 | for i, p := range people { 236 | indexedPeople[i] = IndexedPerson{ 237 | Name: p.Name, 238 | Email: p.Email, 239 | Age: p.Age, 240 | } 241 | } 242 | 243 | // Create a buffer to write CSV data 244 | var buf bytes.Buffer 245 | 246 | // Create a new Writer instance 247 | writer, err := rowboat.NewWriter[IndexedPerson](&buf) 248 | if err != nil { 249 | t.Fatalf("Failed to create Writer: %v", err) 250 | } 251 | 252 | if err := writer.WriteHeader(); err != nil { 253 | 
t.Fatalf("Failed to write header: %v", err) 254 | } 255 | 256 | // Write all records 257 | for _, person := range indexedPeople { 258 | if err := writer.Write(person); err != nil { 259 | t.Fatalf("Failed to write record: %v", err) 260 | } 261 | } 262 | 263 | // Expected CSV: columns follow the index= order (Age=0, Name=1, Email=2), not the field declaration order. 264 | expectedCSV := `Age,Name,Email 265 | 30,Alice,alice@example.com 266 | 25,Bob,bob@example.com 267 | 35,Charlie,charlie@example.com 268 | ` 269 | 270 | csvStr := buf.String() 271 | if csvStr != expectedCSV { 272 | t.Errorf("Written CSV does not match expected.\nExpected:\n%s\nGot:\n%s", expectedCSV, csvStr) 273 | } 274 | 275 | // Verify by reading back 276 | reader := strings.NewReader(csvStr) 277 | rb, err := rowboat.NewReader[IndexedPerson](reader) 278 | if err != nil { 279 | t.Fatalf("Failed to create Reader: %v", err) 280 | } 281 | 282 | // Read back the data and compare 283 | results := slices.Collect(rb.All()) 284 | 285 | // Convert back to Person so the round-tripped values can be compared with the original input. 286 | readPeople := make([]Person, len(results)) 287 | for i, ip := range results { 288 | readPeople[i] = Person{ 289 | Name: ip.Name, 290 | Email: ip.Email, 291 | Age: ip.Age, 292 | } 293 | } 294 | 295 | if !reflect.DeepEqual(readPeople, people) { 296 | t.Errorf("Written and read results do not match expected.\nExpected: %+v\nGot: %+v", people, readPeople) 297 | } 298 | } 299 | --------------------------------------------------------------------------------