├── go.mod ├── LICENSE ├── README.md ├── csvplus.go └── csvplus_test.go /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/maxim2266/csvplus 2 | 3 | go 1.23 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2016,2017,2018,2019 Maxim Konakov 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, 5 | are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, 8 | this list of conditions and the following disclaimer. 9 | 2. Redistributions in binary form must reproduce the above copyright notice, 10 | this list of conditions and the following disclaimer in the documentation 11 | and/or other materials provided with the distribution. 12 | 3. Neither the name of the copyright holder nor the names of its contributors 13 | may be used to endorse or promote products derived from this software without 14 | specific prior written permission. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 17 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 18 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
19 | IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, 20 | INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 21 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 22 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 24 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, 25 | EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # csvplus 2 | 3 | [![GoDoc](https://godoc.org/github.com/maxim2266/csvplus?status.svg)](https://pkg.go.dev/github.com/maxim2266/csvplus) 4 | [![Go Report Card](https://goreportcard.com/badge/github.com/maxim2266/csvplus)](https://goreportcard.com/report/github.com/maxim2266/csvplus) 5 | [![License: BSD 3-Clause](https://img.shields.io/badge/License-BSD_3--Clause-yellow.svg)](https://opensource.org/licenses/BSD-3-Clause) 6 | 7 | Package `csvplus` extends the standard Go [encoding/csv](https://golang.org/pkg/encoding/csv/) 8 | package with fluent interface, lazy stream processing operations, indices and joins. 9 | 10 | The library is primarily designed for [ETL](https://en.wikipedia.org/wiki/Extract,_transform,_load)-like processes. 11 | It is mostly useful in places where the more advanced searching/joining capabilities of a fully-featured SQL 12 | database are not required, but at the same time the data transformations needed still include SQL-like operations. 
13 | 14 | ##### License: BSD 15 | 16 | ### Examples 17 | 18 | Simple sequential processing: 19 | ```Go 20 | // create data source from "people.csv" file with selected columns 21 | people := csvplus.FromFile("people.csv").SelectColumns("name", "surname", "id") 22 | 23 | err := csvplus.Take(people). 24 | Filter(csvplus.Like(csvplus.Row{"name": "Amelia"})). // choose only this name 25 | Map(func(row csvplus.Row) csvplus.Row { row["name"] = "Julia"; return row }). // replace the name 26 | ToCsvFile("out.csv", "name", "surname") // write select columns to .csv file 27 | 28 | if err != nil { 29 | return err 30 | } 31 | ``` 32 | 33 | More involved example: 34 | ```Go 35 | // create data source from "people.csv" with selected columns 36 | customers := csvplus.FromFile("people.csv").SelectColumns("id", "name", "surname") 37 | // build unique index on "id" column 38 | custIndex, err := csvplus.Take(customers).UniqueIndexOn("id") 39 | 40 | if err != nil { 41 | return err 42 | } 43 | 44 | // create another data source from "stock.csv" with selected columns 45 | products := csvplus.FromFile("stock.csv").SelectColumns("prod_id", "product", "price") 46 | // build unique index on "prod_id" column 47 | prodIndex, err := csvplus.Take(products).UniqueIndexOn("prod_id") 48 | 49 | if err != nil { 50 | return err 51 | } 52 | 53 | // create one more data source from "orders.csv" with selected columns 54 | orders := csvplus.FromFile("orders.csv").SelectColumns("cust_id", "prod_id", "qty", "ts") 55 | // create iterator on all the above sources, joined 56 | iter := csvplus.Take(orders).Join(custIndex, "cust_id").Join(prodIndex) 57 | 58 | // iterate the result to produce output 59 | return iter(func(row csvplus.Row) error { 60 | // prints lines like: 61 | // John Doe bought 38 oranges for £0.03 each on 2016-09-14T08:48:22+01:00 62 | _, e := fmt.Printf("%s %s bought %s %ss for £%s each on %s\n", 63 | row["name"], row["surname"], row["qty"], row["product"], row["price"], row["ts"]) 64 | return 
e 65 | }) 66 | ``` 67 | 68 | ### Design principles 69 | 70 | The package functionality is based on the operations on the following entities: 71 | - type `Row` 72 | - type `DataSource` 73 | - type `Index` 74 | 75 | #### Type `Row` 76 | `Row` represents one row from a `DataSource`. It is a map from column names 77 | to the string values under those columns on the current row. The package expects a unique name 78 | assigned to every column at source. Compared to using integer indices this provides more 79 | convenience when complex transformations get applied to each row during processing. 80 | 81 | #### type `DataSource` 82 | Type `DataSource` represents any source of zero or more rows, like `.csv` file. This is a function 83 | that when invoked feeds the given callback with the data from its source, one `Row` at a time. 84 | The type also has a number of operations defined on it that provide for easy composition of the 85 | operations on the `DataSource`, forming so called [fluent interface](https://en.wikipedia.org/wiki/Fluent_interface). 86 | All these operations are 'lazy', i.e. they are not performed immediately, but instead each of them 87 | returns a new `DataSource`. 88 | 89 | There is also a number of convenience operations that actually invoke 90 | the `DataSource` function to produce a specific type of output: 91 | - `IndexOn` to build an index on the specified column(s); 92 | - `UniqueIndexOn` to build a unique index on the specified column(s); 93 | - `ToCsv` to serialise the `DataSource` to the given `io.Writer` in `.csv` format; 94 | - `ToCsvFile` to store the `DataSource` in the specified file in `.csv` format; 95 | - `ToJSON` to serialise the `DataSource` to the given `io.Writer` in JSON format; 96 | - `ToJSONFile` to store the `DataSource` in the specified file in JSON format; 97 | - `ToRows` to convert the `DataSource` to a slice of `Row`s. 98 | 99 | #### Type `Index` 100 | `Index` is a sorted collection of rows. 
The sorting is performed on the columns specified when the index 101 | is created. Iteration over an index yields a sorted sequence of rows. An `Index` can be joined with 102 | a `DataSource`. The type has operations for finding rows and creating sub-indices in O(log(n)) time. 103 | Another useful operation is resolving duplicates. Building an index takes O(n*log(n)) time. It should 104 | be noted that the `Index` building operation requires the entire dataset to be read into 105 | the memory, so certain care should be taken when indexing huge datasets. An index can also be 106 | stored to, or loaded from a disk file. 107 | 108 | For more details see the [documentation](https://godoc.org/github.com/maxim2266/csvplus). 109 | 110 | ### Project status 111 | Tested on Linux Mint 22.1 using Go version 1.24.3. 112 | -------------------------------------------------------------------------------- /csvplus.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) 2016,2017,2018,2019 Maxim Konakov 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without modification, 6 | are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, 9 | this list of conditions and the following disclaimer. 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 3. Neither the name of the copyright holder nor the names of its contributors 14 | may be used to endorse or promote products derived from this software without 15 | specific prior written permission. 
16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 18 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 19 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 20 | IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, 21 | INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 22 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 23 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 24 | OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 25 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, 26 | EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | */ 28 | 29 | // Package csvplus extends the standard Go encoding/csv package with fluent 30 | // interface, lazy stream processing operations, indices and joins. 31 | package csvplus 32 | 33 | import ( 34 | "bytes" 35 | "encoding/csv" 36 | "encoding/gob" 37 | "encoding/json" 38 | "errors" 39 | "fmt" 40 | "io" 41 | "os" 42 | "sort" 43 | "strconv" 44 | "strings" 45 | "unsafe" 46 | ) 47 | 48 | /* 49 | Row represents one line from a data source like a .csv file. 50 | 51 | Each Row is a map from column names to the string values under those columns on the current line. 52 | It is assumed that each column has a unique name. 53 | In a .csv file, the column names may either come from the first line of the file ("expected header"), 54 | or they can be set-up via configuration of the reader object ("assumed header"). 55 | 56 | Using meaningful column names instead of indices is usually more convenient when the columns get rearranged 57 | during the execution of the processing pipeline. 58 | */ 59 | type Row map[string]string 60 | 61 | // HasColumn is a predicate returning 'true' when the specified column is present. 
62 | func (row Row) HasColumn(col string) (found bool) { 63 | _, found = row[col] 64 | return 65 | } 66 | 67 | // SafeGetValue returns the value under the specified column, if present, otherwise it returns the 68 | // substitution value. 69 | func (row Row) SafeGetValue(col, subst string) string { 70 | if value, found := row[col]; found { 71 | return value 72 | } 73 | 74 | return subst 75 | } 76 | 77 | // Header returns a slice of all column names, sorted via sort.Strings. 78 | func (row Row) Header() []string { 79 | r := make([]string, 0, len(row)) 80 | 81 | for col := range row { 82 | r = append(r, col) 83 | } 84 | 85 | sort.Strings(r) 86 | return r 87 | } 88 | 89 | // String returns a string representation of the Row. 90 | func (row Row) String() string { 91 | if len(row) == 0 { 92 | return "{}" 93 | } 94 | 95 | header := row.Header() // make order predictable 96 | buff := append(append(append(append([]byte(`{ "`), header[0]...), `" : "`...), row[header[0]]...), '"') 97 | 98 | for _, col := range header[1:] { 99 | buff = append(append(append(append(append(buff, `, "`...), col...), `" : "`...), row[col]...), '"') 100 | } 101 | 102 | buff = append(buff, " }"...) 103 | return *(*string)(unsafe.Pointer(&buff)) 104 | } 105 | 106 | // SelectExisting takes a list of column names and returns a new Row 107 | // containing only those columns from the list that are present in the current Row. 108 | func (row Row) SelectExisting(cols ...string) Row { 109 | r := make(map[string]string, len(cols)) 110 | 111 | for _, name := range cols { 112 | if val, found := row[name]; found { 113 | r[name] = val 114 | } 115 | } 116 | 117 | return r 118 | } 119 | 120 | // Select takes a list of column names and returns a new Row 121 | // containing only the specified columns, or an error if any column is not present. 
122 | func (row Row) Select(cols ...string) (Row, error) { 123 | r := make(map[string]string, len(cols)) 124 | 125 | for _, name := range cols { 126 | var found bool 127 | 128 | if r[name], found = row[name]; !found { 129 | return nil, fmt.Errorf(`missing column %q`, name) 130 | } 131 | } 132 | 133 | return r, nil 134 | } 135 | 136 | // SelectValues takes a list of column names and returns a slice of their 137 | // corresponding values, or an error if any column is not present. 138 | func (row Row) SelectValues(cols ...string) ([]string, error) { 139 | r := make([]string, len(cols)) 140 | 141 | for i, name := range cols { 142 | var found bool 143 | 144 | if r[i], found = row[name]; !found { 145 | return nil, fmt.Errorf(`missing column %q`, name) 146 | } 147 | } 148 | 149 | return r, nil 150 | } 151 | 152 | // Clone returns a copy of the current Row. 153 | func (row Row) Clone() Row { 154 | r := make(map[string]string, len(row)) 155 | 156 | for k, v := range row { 157 | r[k] = v 158 | } 159 | 160 | return r 161 | } 162 | 163 | // ValueAsInt returns the value of the given column converted to integer type, or an error. 164 | // The column must be present on the row. 165 | func (row Row) ValueAsInt(column string) (res int, err error) { 166 | var val string 167 | var found bool 168 | 169 | if val, found = row[column]; !found { 170 | err = fmt.Errorf(`missing column %q`, column) 171 | return 172 | } 173 | 174 | if res, err = strconv.Atoi(val); err != nil { 175 | if e, ok := err.(*strconv.NumError); ok { 176 | err = fmt.Errorf(`column %q: cannot convert %q to integer: %s`, column, val, e.Err) 177 | } else { 178 | err = fmt.Errorf(`column %q: %s`, column, err) 179 | } 180 | } 181 | 182 | return 183 | } 184 | 185 | // ValueAsFloat64 returns the value of the given column converted to floating point type, or an error. 186 | // The column must be present on the row. 
187 | func (row Row) ValueAsFloat64(column string) (res float64, err error) { 188 | var val string 189 | var found bool 190 | 191 | if val, found = row[column]; !found { 192 | err = fmt.Errorf(`missing column %q`, column) 193 | return 194 | } 195 | 196 | if res, err = strconv.ParseFloat(val, 64); err != nil { 197 | if e, ok := err.(*strconv.NumError); ok { 198 | err = fmt.Errorf(`column %q: cannot convert %q to float: %s`, column, val, e.Err) 199 | } else { 200 | err = fmt.Errorf(`column %q: %s`, column, err.Error()) 201 | } 202 | } 203 | 204 | return 205 | } 206 | 207 | // RowFunc is the function type used when iterating Rows. 208 | type RowFunc func(Row) error 209 | 210 | // DataSource is the iterator type used throughout this library. The iterator 211 | // calls the given RowFunc once per each row. The iteration continues until 212 | // either the data source is exhausted or the supplied RowFunc returns a non-nil error, in 213 | // which case the error is returned back to the caller of the iterator. A special case of io.EOF simply 214 | // stops the iteration and the iterator function returns nil error. 215 | type DataSource func(RowFunc) error 216 | 217 | // TakeRows converts a slice of Rows to a DataSource. 218 | func TakeRows(rows []Row) DataSource { 219 | return func(fn RowFunc) error { 220 | return iterate(rows, fn) 221 | } 222 | } 223 | 224 | // the core iteration 225 | func iterate(rows []Row, fn RowFunc) (err error) { 226 | var row Row 227 | var i int 228 | 229 | for i, row = range rows { 230 | if err = fn(row.Clone()); err != nil { 231 | break 232 | } 233 | } 234 | 235 | switch err { 236 | case nil: 237 | // nothing to do 238 | case io.EOF: 239 | err = nil // end of iteration 240 | default: 241 | // wrap error 242 | err = &DataSourceError{ 243 | Line: uint64(i), 244 | Err: err, 245 | } 246 | } 247 | 248 | return 249 | } 250 | 251 | // Take converts anything with Iterate() method to a DataSource. 
252 | func Take(src interface { 253 | Iterate(fn RowFunc) error 254 | }) DataSource { 255 | return src.Iterate 256 | } 257 | 258 | // Transform is the most generic operation on a Row. It takes a function that 259 | // maps a Row to another Row and an error. Any error returned from that function 260 | // stops the iteration, otherwise the returned Row, if not empty, gets passed 261 | // down to the next stage of the processing pipeline. 262 | func (src DataSource) Transform(trans func(Row) (Row, error)) DataSource { 263 | return func(fn RowFunc) error { 264 | return src(func(row Row) (err error) { 265 | if row, err = trans(row); err == nil && len(row) > 0 { 266 | err = fn(row) 267 | } 268 | 269 | return 270 | }) 271 | } 272 | } 273 | 274 | // Filter takes a predicate which, when applied to a Row, decides if that Row 275 | // should be passed down for further processing. The predicate should return 'true' to pass the Row. 276 | func (src DataSource) Filter(pred func(Row) bool) DataSource { 277 | return func(fn RowFunc) error { 278 | return src(func(row Row) (err error) { 279 | if pred(row) { 280 | err = fn(row) 281 | } 282 | 283 | return 284 | }) 285 | } 286 | } 287 | 288 | // Map takes a function which gets applied to each Row when the source is iterated over. The function 289 | // may return a modified input Row, or an entirely new Row. 290 | func (src DataSource) Map(mf func(Row) Row) DataSource { 291 | return func(fn RowFunc) error { 292 | return src(func(row Row) error { 293 | return fn(mf(row)) 294 | }) 295 | } 296 | } 297 | 298 | // Validate takes a function which checks every Row upon iteration and returns an error 299 | // if the validation fails. The iteration stops at the first error encountered. 
300 | func (src DataSource) Validate(vf func(Row) error) DataSource { 301 | return func(fn RowFunc) error { 302 | return src(func(row Row) (err error) { 303 | if err = vf(row); err == nil { 304 | err = fn(row) 305 | } 306 | 307 | return 308 | }) 309 | } 310 | } 311 | 312 | // Top specifies the number of Rows to pass down the pipeline before stopping the iteration. 313 | func (src DataSource) Top(n uint64) DataSource { 314 | return func(fn RowFunc) error { 315 | counter := n 316 | 317 | return src(func(row Row) error { 318 | if counter == 0 { 319 | return io.EOF 320 | } 321 | 322 | counter-- 323 | return fn(row) 324 | }) 325 | } 326 | } 327 | 328 | // Drop specifies the number of Rows to ignore before passing the remaining rows down the pipeline. 329 | func (src DataSource) Drop(n uint64) DataSource { 330 | return func(fn RowFunc) error { 331 | counter := n 332 | 333 | return src(func(row Row) error { 334 | if counter == 0 { 335 | return fn(row) 336 | } 337 | 338 | counter-- 339 | return nil 340 | }) 341 | } 342 | } 343 | 344 | // TakeWhile takes a predicate which gets applied to each Row upon iteration. 345 | // The iteration stops when the predicate returns 'false'. 346 | func (src DataSource) TakeWhile(pred func(Row) bool) DataSource { 347 | return func(fn RowFunc) error { 348 | var done bool 349 | 350 | return src(func(row Row) error { 351 | if done = (done || !pred(row)); done { 352 | return io.EOF 353 | } 354 | 355 | return fn(row) 356 | }) 357 | } 358 | } 359 | 360 | // DropWhile ignores all the Rows for as long as the specified predicate is true; 361 | // afterwards all the remaining Rows are passed down the pipeline. 
362 | func (src DataSource) DropWhile(pred func(Row) bool) DataSource { 363 | return func(fn RowFunc) error { 364 | var yield bool 365 | 366 | return src(func(row Row) (err error) { 367 | if yield = (yield || !pred(row)); yield { 368 | err = fn(row) 369 | } 370 | 371 | return 372 | }) 373 | } 374 | } 375 | 376 | // ToCsv iterates the data source and writes the selected columns in .csv format to the given io.Writer. 377 | // The data are written in the "canonical" form with the header on the first line and with all the lines 378 | // having the same number of fields, using default settings for the underlying csv.Writer. 379 | func (src DataSource) ToCsv(out io.Writer, columns ...string) (err error) { 380 | if len(columns) == 0 { 381 | panic("empty column list in ToCsv() function") 382 | } 383 | 384 | w := csv.NewWriter(out) 385 | 386 | // header 387 | if err = w.Write(columns); err == nil { 388 | // rows 389 | err = src(func(row Row) (e error) { 390 | var values []string 391 | 392 | if values, e = row.SelectValues(columns...); e == nil { 393 | e = w.Write(values) 394 | } 395 | 396 | return 397 | }) 398 | } 399 | 400 | if err == nil { 401 | w.Flush() 402 | err = w.Error() 403 | } 404 | 405 | return 406 | } 407 | 408 | // ToCsvFile iterates the data source and writes the selected columns in .csv format to the given file. 409 | // The data are written in the "canonical" form with the header on the first line and with all the lines 410 | // having the same number of fields, using default settings for the underlying csv.Writer. 411 | func (src DataSource) ToCsvFile(name string, columns ...string) error { 412 | return writeFile(name, func(file io.Writer) error { 413 | return src.ToCsv(file, columns...) 
414 | }) 415 | } 416 | 417 | // call the given function with the file stream open for writing 418 | func writeFile(name string, fn func(io.Writer) error) (err error) { 419 | var file *os.File 420 | 421 | if file, err = os.Create(name); err != nil { 422 | return 423 | } 424 | 425 | defer func() { 426 | if p := recover(); p != nil { 427 | file.Close() 428 | os.Remove(name) 429 | panic(p) 430 | } 431 | 432 | if e := file.Close(); e != nil && err == nil { 433 | err = e 434 | } 435 | 436 | if err != nil { 437 | os.Remove(name) 438 | } 439 | }() 440 | 441 | err = fn(file) 442 | return 443 | } 444 | 445 | // ToJSON iterates over the data source and writes all Rows to the given io.Writer in JSON format. 446 | func (src DataSource) ToJSON(out io.Writer) (err error) { 447 | var buff bytes.Buffer 448 | 449 | buff.WriteByte('[') 450 | 451 | count := uint64(0) 452 | enc := json.NewEncoder(&buff) 453 | 454 | enc.SetIndent("", "") 455 | enc.SetEscapeHTML(false) 456 | 457 | err = src(func(row Row) (e error) { 458 | if count++; count != 1 { 459 | buff.WriteByte(',') 460 | } 461 | 462 | if e = enc.Encode(row); e == nil && buff.Len() > 10000 { 463 | _, e = buff.WriteTo(out) 464 | } 465 | 466 | return 467 | }) 468 | 469 | if err == nil { 470 | buff.WriteByte(']') 471 | _, err = buff.WriteTo(out) 472 | } 473 | 474 | return 475 | } 476 | 477 | // ToJSONFile iterates over the data source and writes all Rows to the given file in JSON format. 478 | func (src DataSource) ToJSONFile(name string) error { 479 | return writeFile(name, src.ToJSON) 480 | } 481 | 482 | // ToRows iterates the DataSource storing the result in a slice of Rows. 483 | func (src DataSource) ToRows() (rows []Row, err error) { 484 | err = src(func(row Row) error { 485 | rows = append(rows, row) 486 | return nil 487 | }) 488 | 489 | return 490 | } 491 | 492 | // DropColumns removes the specifies columns from each row. 
493 | func (src DataSource) DropColumns(columns ...string) DataSource { 494 | if len(columns) == 0 { 495 | panic("no columns specified in DropColumns()") 496 | } 497 | 498 | return func(fn RowFunc) error { 499 | return src(func(row Row) error { 500 | for _, col := range columns { 501 | delete(row, col) 502 | } 503 | 504 | return fn(row) 505 | }) 506 | } 507 | } 508 | 509 | // SelectColumns leaves only the specified columns on each row. It is an error 510 | // if any such column does not exist. 511 | func (src DataSource) SelectColumns(columns ...string) DataSource { 512 | if len(columns) == 0 { 513 | panic("no columns specified in SelectColumns()") 514 | } 515 | 516 | return func(fn RowFunc) error { 517 | return src(func(row Row) (err error) { 518 | if row, err = row.Select(columns...); err == nil { 519 | err = fn(row) 520 | } 521 | 522 | return 523 | }) 524 | } 525 | } 526 | 527 | // IndexOn iterates the input source, building index on the specified columns. 528 | // Columns are taken from the specified list from left to the right. 529 | func (src DataSource) IndexOn(columns ...string) (*Index, error) { 530 | return createIndex(src, columns) 531 | } 532 | 533 | // UniqueIndexOn iterates the input source, building unique index on the specified columns. 534 | // Columns are taken from the specified list from left to the right. 535 | func (src DataSource) UniqueIndexOn(columns ...string) (*Index, error) { 536 | return createUniqueIndex(src, columns) 537 | } 538 | 539 | // Join returns a DataSource which is a join between the current DataSource and the specified 540 | // Index. The specified columns are matched against those from the index, in the order of specification. 541 | // Empty 'columns' list yields a join on the columns from the Index (aka "natural join") which all must 542 | // exist in the current DataSource. 543 | // Each row in the resulting table contains all the columns from both the current table and the index. 
544 | // This is a lazy operation, the actual join is performed only when the resulting table is iterated over. 545 | func (src DataSource) Join(index *Index, columns ...string) DataSource { 546 | if len(columns) == 0 { 547 | columns = index.impl.columns 548 | } else if len(columns) > len(index.impl.columns) { 549 | panic("too many source columns in Join()") 550 | } 551 | 552 | return func(fn RowFunc) error { 553 | return src(func(row Row) (err error) { 554 | var values []string 555 | 556 | if values, err = row.SelectValues(columns...); err == nil { 557 | n := len(index.impl.rows) 558 | 559 | for i := index.impl.first(values); i < n && !index.impl.cmp(i, values, false); i++ { 560 | if err = fn(mergeRows(index.impl.rows[i], row)); err != nil { 561 | break 562 | } 563 | } 564 | } 565 | 566 | return 567 | }) 568 | } 569 | } 570 | 571 | func mergeRows(left, right Row) Row { 572 | r := make(map[string]string, len(left)+len(right)) 573 | 574 | for k, v := range left { 575 | r[k] = v 576 | } 577 | 578 | for k, v := range right { 579 | r[k] = v 580 | } 581 | 582 | return r 583 | } 584 | 585 | // Except returns a table containing all the rows not in the specified Index, unchanged. The specified 586 | // columns are matched against those from the index, in the order of specification. If no columns 587 | // are specified then the columns list is taken from the index. 
588 | func (src DataSource) Except(index *Index, columns ...string) DataSource { 589 | if len(columns) == 0 { 590 | columns = index.impl.columns 591 | } else if len(columns) > len(index.impl.columns) { 592 | panic("too many source columns in Except()") 593 | } 594 | 595 | return func(fn RowFunc) error { 596 | return src(func(row Row) (err error) { 597 | var values []string 598 | 599 | if values, err = row.SelectValues(columns...); err == nil { 600 | if !index.impl.has(values) { 601 | err = fn(row) 602 | } 603 | } 604 | 605 | return 606 | }) 607 | } 608 | } 609 | 610 | // Index is a sorted collection of Rows with O(log(n)) complexity of search 611 | // on the indexed columns. Iteration over the Index yields a sequence of Rows sorted on the index. 612 | type Index struct { 613 | impl indexImpl 614 | } 615 | 616 | // Iterate iterates over all rows of the index. The rows are sorted by the values of the columns 617 | // specified when the Index was created. 618 | func (index *Index) Iterate(fn RowFunc) error { 619 | return iterate(index.impl.rows, fn) 620 | } 621 | 622 | // Find returns a DataSource of all Rows from the Index that match the specified values 623 | // in the indexed columns, left to the right. The number of specified values may be less than 624 | // the number of the indexed columns. 625 | func (index *Index) Find(values ...string) DataSource { 626 | return TakeRows(index.impl.find(values)) 627 | } 628 | 629 | // SubIndex returns an Index containing only the rows where the values of the 630 | // indexed columns match the supplied values, left to the right. The number of specified values 631 | // must be less than the number of indexed columns. 
632 | func (index *Index) SubIndex(values ...string) *Index { 633 | if len(values) >= len(index.impl.columns) { 634 | panic("too many values in SubIndex()") 635 | } 636 | 637 | return &Index{indexImpl{ 638 | rows: index.impl.find(values), 639 | columns: index.impl.columns[len(values):], 640 | }} 641 | } 642 | 643 | // ResolveDuplicates calls the specified function once per each pack of duplicates with the same key. 644 | // The specified function must not modify its parameter and is expected to do one of the following: 645 | // 646 | // - Select and return one row from the input list. The row will be used as the only row with its key; 647 | // 648 | // - Return an empty row. The entire set of rows will be ignored; 649 | // 650 | // - Return an error which will be passed back to the caller of ResolveDuplicates(). 651 | func (index *Index) ResolveDuplicates(resolve func(rows []Row) (Row, error)) error { 652 | return index.impl.dedup(resolve) 653 | } 654 | 655 | // WriteTo writes the index to the specified file. 656 | func (index *Index) WriteTo(fileName string) (err error) { 657 | var file *os.File 658 | 659 | if file, err = os.Create(fileName); err != nil { 660 | return 661 | } 662 | 663 | defer func() { 664 | if e := file.Close(); e != nil || err != nil { 665 | os.Remove(fileName) 666 | 667 | if err == nil { 668 | err = e 669 | } 670 | } 671 | }() 672 | 673 | enc := gob.NewEncoder(file) 674 | 675 | if err = enc.Encode(index.impl.columns); err == nil { 676 | err = enc.Encode(index.impl.rows) 677 | } 678 | 679 | return 680 | } 681 | 682 | // LoadIndex reads the index from the specified file. 
683 | func LoadIndex(fileName string) (*Index, error) { 684 | var file *os.File 685 | var err error 686 | 687 | if file, err = os.Open(fileName); err != nil { 688 | return nil, err 689 | } 690 | 691 | defer file.Close() 692 | 693 | index := new(Index) 694 | dec := gob.NewDecoder(file) 695 | 696 | if err = dec.Decode(&index.impl.columns); err != nil { 697 | return nil, err 698 | } 699 | 700 | if err = dec.Decode(&index.impl.rows); err != nil { 701 | return nil, err 702 | } 703 | 704 | return index, nil 705 | } 706 | 707 | func createIndex(src DataSource, columns []string) (*Index, error) { 708 | switch len(columns) { 709 | case 0: 710 | panic("empty column list in CreateIndex()") 711 | case 1: 712 | // do nothing 713 | default: 714 | if !allColumnsUnique(columns) { 715 | panic("duplicate column name(s) in CreateIndex()") 716 | } 717 | } 718 | 719 | index := &Index{indexImpl{columns: columns}} 720 | 721 | // copy Rows with validation 722 | if err := src(func(row Row) error { 723 | for _, col := range columns { 724 | if !row.HasColumn(col) { 725 | return fmt.Errorf(`missing column %q while creating an index`, col) 726 | } 727 | } 728 | 729 | index.impl.rows = append(index.impl.rows, row) 730 | return nil 731 | }); err != nil { 732 | return nil, err 733 | } 734 | 735 | // sort 736 | sort.Sort(&index.impl) 737 | return index, nil 738 | } 739 | 740 | func createUniqueIndex(src DataSource, columns []string) (index *Index, err error) { 741 | // create index 742 | if index, err = createIndex(src, columns); err != nil || len(index.impl.rows) < 2 { 743 | return 744 | } 745 | 746 | // check for duplicates by linear search; not the best idea. 
747 | rows := index.impl.rows 748 | 749 | for i := 1; i < len(rows); i++ { 750 | if equalRows(columns, rows[i-1], rows[i]) { 751 | return nil, errors.New("duplicate value while creating unique index: " + rows[i].SelectExisting(columns...).String()) 752 | } 753 | } 754 | 755 | return 756 | } 757 | 758 | // compare the specified columns from the two rows 759 | func equalRows(columns []string, r1, r2 Row) bool { 760 | for _, col := range columns { 761 | if r1[col] != r2[col] { 762 | return false 763 | } 764 | } 765 | 766 | return true 767 | } 768 | 769 | // check if all the column names from the specified list are unique 770 | func allColumnsUnique(columns []string) bool { 771 | set := make(map[string]struct{}, len(columns)) 772 | 773 | for _, col := range columns { 774 | if _, found := set[col]; found { 775 | return false 776 | } 777 | 778 | set[col] = struct{}{} 779 | } 780 | 781 | return true 782 | } 783 | 784 | // index implementation 785 | type indexImpl struct { 786 | rows []Row 787 | columns []string 788 | } 789 | 790 | // functions required by sort.Sort() 791 | func (index *indexImpl) Len() int { return len(index.rows) } 792 | func (index *indexImpl) Swap(i, j int) { index.rows[i], index.rows[j] = index.rows[j], index.rows[i] } 793 | 794 | func (index *indexImpl) Less(i, j int) bool { 795 | left, right := index.rows[i], index.rows[j] 796 | 797 | for _, col := range index.columns { 798 | switch strings.Compare(left[col], right[col]) { 799 | case -1: 800 | return true 801 | case 1: 802 | return false 803 | } 804 | } 805 | 806 | return false 807 | } 808 | 809 | // deduplication 810 | func (index *indexImpl) dedup(resolve func(rows []Row) (Row, error)) (err error) { 811 | // find first duplicate 812 | var lower int 813 | 814 | for lower = 1; lower < len(index.rows); lower++ { 815 | if equalRows(index.columns, index.rows[lower-1], index.rows[lower]) { 816 | break 817 | } 818 | } 819 | 820 | if lower == len(index.rows) { 821 | return 822 | } 823 | 824 | dest := 
lower - 1 825 | 826 | // loop: find index of the first row with another key, resolve, then find next duplicate 827 | for lower < len(index.rows) { 828 | // the duplicate is in [lower-1, upper[ range 829 | values, _ := index.rows[lower].SelectValues(index.columns...) 830 | 831 | upper := lower + sort.Search(len(index.rows)-lower, func(i int) bool { 832 | return index.cmp(lower+i, values, false) 833 | }) 834 | 835 | // resolve 836 | var row Row 837 | 838 | if row, err = resolve(index.rows[lower-1 : upper]); err != nil { 839 | return 840 | } 841 | 842 | lower = upper + 1 843 | 844 | // store the chosen row if not 'empty' 845 | if len(row) >= len(index.columns) { 846 | index.rows[dest] = row 847 | dest++ 848 | } 849 | 850 | // find next duplicate 851 | for lower < len(index.rows) { 852 | if equalRows(index.columns, index.rows[lower-1], index.rows[lower]) { 853 | break 854 | } 855 | 856 | index.rows[dest] = index.rows[lower-1] 857 | lower++ 858 | dest++ 859 | } 860 | } 861 | 862 | if err == nil { 863 | index.rows = index.rows[:dest] 864 | } 865 | 866 | return 867 | } 868 | 869 | // search on the index 870 | func (index *indexImpl) find(values []string) []Row { 871 | // check boundaries 872 | if len(values) == 0 { 873 | return index.rows 874 | } 875 | 876 | if len(values) > len(index.columns) { 877 | panic("too many columns in indexImpl.find()") 878 | } 879 | 880 | // get bounds 881 | upper := sort.Search(len(index.rows), func(i int) bool { 882 | return index.cmp(i, values, false) 883 | }) 884 | 885 | lower := sort.Search(upper, func(i int) bool { 886 | return index.cmp(i, values, true) 887 | }) 888 | 889 | // done 890 | return index.rows[lower:upper] 891 | } 892 | 893 | func (index *indexImpl) first(values []string) int { 894 | return sort.Search(len(index.rows), func(i int) bool { 895 | return index.cmp(i, values, true) 896 | }) 897 | } 898 | 899 | func (index *indexImpl) has(values []string) bool { 900 | // find the lowest index 901 | i := index.first(values) 902 | 
903 | // check if the row at the lowest index matches the values 904 | return i < len(index.rows) && !index.cmp(i, values, false) 905 | } 906 | 907 | func (index *indexImpl) cmp(i int, values []string, eq bool) bool { 908 | row := index.rows[i] 909 | 910 | for j, val := range values { 911 | switch strings.Compare(row[index.columns[j]], val) { 912 | case 1: 913 | return true 914 | case -1: 915 | return false 916 | } 917 | } 918 | 919 | return eq 920 | } 921 | 922 | // Reader is iterable csv reader. The iteration is performed in its Iterate() method, which 923 | // may only be invoked once per each instance of the Reader. 924 | type Reader struct { 925 | source maker 926 | delimiter, comment rune 927 | numFields int 928 | lazyQuotes, trimLeadingSpace bool 929 | header map[string]int 930 | headerFromFirstRow bool 931 | } 932 | 933 | type maker = func() (io.Reader, func(), error) 934 | 935 | // FromReader constructs a new csv reader from the given io.Reader, with default settings. 936 | func FromReader(input io.Reader) *Reader { 937 | return makeReader(func() (io.Reader, func(), error) { 938 | return input, func() {}, nil 939 | }) 940 | } 941 | 942 | // FromReadCloser constructs a new csv reader from the given io.ReadCloser, with default settings. 943 | func FromReadCloser(input io.ReadCloser) *Reader { 944 | return makeReader(func() (io.Reader, func(), error) { 945 | return input, func() { input.Close() }, nil 946 | }) 947 | } 948 | 949 | // FromFile constructs a new csv reader bound to the specified file, with default settings. 
950 | func FromFile(name string) *Reader { 951 | return makeReader(func() (io.Reader, func(), error) { 952 | file, err := os.Open(name) 953 | 954 | if err != nil { 955 | return nil, nil, err 956 | } 957 | 958 | return file, func() { file.Close() }, nil 959 | }) 960 | } 961 | 962 | func makeReader(fn maker) *Reader { 963 | return &Reader{ 964 | source: fn, 965 | delimiter: ',', 966 | headerFromFirstRow: true, 967 | } 968 | } 969 | 970 | // Delimiter sets the symbol to be used as a field delimiter. 971 | func (r *Reader) Delimiter(c rune) *Reader { 972 | r.delimiter = c 973 | return r 974 | } 975 | 976 | // CommentChar sets the symbol that starts a comment. 977 | func (r *Reader) CommentChar(c rune) *Reader { 978 | r.comment = c 979 | return r 980 | } 981 | 982 | // LazyQuotes specifies that a quote may appear in an unquoted field and a 983 | // non-doubled quote may appear in a quoted field of the input. 984 | func (r *Reader) LazyQuotes() *Reader { 985 | r.lazyQuotes = true 986 | return r 987 | } 988 | 989 | // TrimLeadingSpace specifies that the leading white space in a field should be ignored. 990 | func (r *Reader) TrimLeadingSpace() *Reader { 991 | r.trimLeadingSpace = true 992 | return r 993 | } 994 | 995 | // AssumeHeader sets the header for those input sources that do not have their column 996 | // names specified on the first row. The header specification is a map 997 | // from the assigned column names to their corresponding column indices. 998 | func (r *Reader) AssumeHeader(spec map[string]int) *Reader { 999 | if len(spec) == 0 { 1000 | panic("Empty header spec") 1001 | } 1002 | 1003 | for name, col := range spec { 1004 | if col < 0 { 1005 | panic("header spec: negative index for column " + name) 1006 | } 1007 | } 1008 | 1009 | r.header = spec 1010 | r.headerFromFirstRow = false 1011 | return r 1012 | } 1013 | 1014 | // ExpectHeader sets the header for input sources that have their column 1015 | // names specified on the first row. 
The row gets verified 1016 | // against this specification when the reading starts. 1017 | // The header specification is a map from the expected column names to their corresponding 1018 | // column indices. A negative value for an index means that the real value of the index 1019 | // will be found by searching the first row for the specified column name. 1020 | func (r *Reader) ExpectHeader(spec map[string]int) *Reader { 1021 | if len(spec) == 0 { 1022 | panic("empty header spec") 1023 | } 1024 | 1025 | r.header = make(map[string]int, len(spec)) 1026 | 1027 | for name, col := range spec { 1028 | r.header[name] = col 1029 | } 1030 | 1031 | r.headerFromFirstRow = true 1032 | return r 1033 | } 1034 | 1035 | // SelectColumns specifies the names of the columns to read from the input source. 1036 | // The header specification is built by searching the first row of the input 1037 | // for the names specified and recording the indices of those columns. It is an error 1038 | // if any column name is not found. 1039 | func (r *Reader) SelectColumns(names ...string) *Reader { 1040 | if len(names) == 0 { 1041 | panic("empty header spec") 1042 | } 1043 | 1044 | r.header = make(map[string]int, len(names)) 1045 | 1046 | for _, name := range names { 1047 | if _, found := r.header[name]; found { 1048 | panic("header spec: duplicate column name: " + name) 1049 | } 1050 | 1051 | r.header[name] = -1 1052 | } 1053 | 1054 | r.headerFromFirstRow = true 1055 | return r 1056 | } 1057 | 1058 | // NumFields sets the expected number of fields on each row of the input. 1059 | // It is an error if any row does not have this exact number of fields. 1060 | func (r *Reader) NumFields(n int) *Reader { 1061 | r.numFields = n 1062 | return r 1063 | } 1064 | 1065 | // NumFieldsAuto specifies that the number of fields on each row must match that of 1066 | // the first row of the input. 
1067 | func (r *Reader) NumFieldsAuto() *Reader { 1068 | return r.NumFields(0) 1069 | } 1070 | 1071 | // NumFieldsAny specifies that each row of the input may have different number 1072 | // of fields. Rows shorter than the maximum column index in the header specification will be padded 1073 | // with empty fields. 1074 | func (r *Reader) NumFieldsAny() *Reader { 1075 | return r.NumFields(-1) 1076 | } 1077 | 1078 | // Iterate reads the input row by row, converts each line to the Row type, and calls 1079 | // the supplied RowFunc. 1080 | func (r *Reader) Iterate(fn RowFunc) error { 1081 | // source 1082 | input, close, err := r.source() 1083 | 1084 | if err != nil { 1085 | return err 1086 | } 1087 | 1088 | defer close() 1089 | 1090 | // csv.Reader 1091 | reader := csv.NewReader(input) 1092 | 1093 | reader.Comma = r.delimiter 1094 | reader.Comment = r.comment 1095 | reader.LazyQuotes = r.lazyQuotes 1096 | reader.TrimLeadingSpace = r.trimLeadingSpace 1097 | reader.FieldsPerRecord = r.numFields 1098 | 1099 | // header 1100 | var header map[string]int 1101 | 1102 | lineNo := uint64(1) 1103 | 1104 | if r.headerFromFirstRow { 1105 | if header, err = r.makeHeader(reader); err != nil { 1106 | return mapError(err, lineNo) 1107 | } 1108 | 1109 | lineNo++ 1110 | } else { 1111 | header = r.header 1112 | } 1113 | 1114 | // iteration 1115 | var line []string 1116 | 1117 | for line, err = reader.Read(); err == nil; line, err = reader.Read() { 1118 | row := make(map[string]string, len(header)) 1119 | 1120 | for name, index := range header { 1121 | if index < len(line) { 1122 | row[name] = line[index] 1123 | } else if r.numFields < 0 { // padding allowed 1124 | row[name] = "" 1125 | } else { 1126 | return &DataSourceError{ 1127 | Line: lineNo, 1128 | Err: fmt.Errorf("column not found: %q (%d)", name, index), 1129 | } 1130 | } 1131 | } 1132 | 1133 | if err = fn(row); err != nil { 1134 | break 1135 | } 1136 | 1137 | lineNo++ 1138 | } 1139 | 1140 | // map error 1141 | if err != io.EOF 
{ 1142 | return mapError(err, lineNo) 1143 | } 1144 | 1145 | return nil 1146 | } 1147 | 1148 | // build header spec from the first row of the input file 1149 | func (r *Reader) makeHeader(reader *csv.Reader) (map[string]int, error) { 1150 | line, err := reader.Read() 1151 | 1152 | if err != nil { 1153 | return nil, err 1154 | } 1155 | 1156 | if len(line) == 0 { 1157 | return nil, errors.New("empty header") 1158 | } 1159 | 1160 | if len(r.header) == 0 { // the header is not specified - get it from the first line 1161 | header := make(map[string]int, len(line)) 1162 | 1163 | for i, name := range line { 1164 | header[name] = i 1165 | } 1166 | 1167 | return header, nil 1168 | } 1169 | 1170 | // check and update the specified header 1171 | header := make(map[string]int, len(r.header)) 1172 | 1173 | // fix column indices 1174 | for i, name := range line { 1175 | if index, found := r.header[name]; found { 1176 | if index == -1 || index == i { 1177 | header[name] = i 1178 | } else { 1179 | return nil, fmt.Errorf(`misplaced column %q: expected at pos. %d, but found at pos. 
%d`, 1180 | name, index, i) 1181 | } 1182 | } 1183 | } 1184 | 1185 | // check if all columns are found 1186 | if len(header) < len(r.header) { 1187 | // compose the list of the missing columns 1188 | var list []string 1189 | 1190 | for name := range r.header { 1191 | if _, found := header[name]; !found { 1192 | list = append(list, name) 1193 | } 1194 | } 1195 | 1196 | // return error message 1197 | if len(list) > 1 { 1198 | return nil, errors.New("columns not found: " + strings.Join(list, ", ")) 1199 | } 1200 | 1201 | return nil, errors.New("column not found: " + list[0]) 1202 | } 1203 | 1204 | // all done 1205 | return header, nil 1206 | } 1207 | 1208 | // annotate error with line number 1209 | func mapError(err error, lineNo uint64) error { 1210 | switch e := err.(type) { 1211 | case *csv.ParseError: 1212 | return &DataSourceError{ 1213 | Line: lineNo, 1214 | Err: e.Err, 1215 | } 1216 | case *os.PathError: 1217 | return &DataSourceError{ 1218 | Line: lineNo, 1219 | Err: errors.New(e.Op + ": " + e.Err.Error()), 1220 | } 1221 | default: 1222 | return &DataSourceError{ 1223 | Line: lineNo, 1224 | Err: err, 1225 | } 1226 | } 1227 | } 1228 | 1229 | // DataSourceError is the type of the error returned from Reader.Iterate method. 1230 | type DataSourceError struct { 1231 | Line uint64 // counting from 1 1232 | Err error 1233 | } 1234 | 1235 | // Error returns a human-readable error message string. 1236 | func (e *DataSourceError) Error() string { 1237 | return fmt.Sprintf(`row %d: %s`, e.Line, e.Err) 1238 | } 1239 | 1240 | // All is a predicate combinator that takes any number of other predicates and 1241 | // produces a new predicate which returns 'true' only if all the specified predicates 1242 | // return 'true' for the same input Row. 
1243 | func All(funcs ...func(Row) bool) func(Row) bool { 1244 | return func(row Row) bool { 1245 | for _, pred := range funcs { 1246 | if !pred(row) { 1247 | return false 1248 | } 1249 | } 1250 | 1251 | return true 1252 | } 1253 | } 1254 | 1255 | // Any is a predicate combinator that takes any number of other predicates and 1256 | // produces a new predicate which returns 'true' if any the specified predicates 1257 | // returns 'true' for the same input Row. 1258 | func Any(funcs ...func(Row) bool) func(Row) bool { 1259 | return func(row Row) bool { 1260 | for _, pred := range funcs { 1261 | if pred(row) { 1262 | return true 1263 | } 1264 | } 1265 | 1266 | return false 1267 | } 1268 | } 1269 | 1270 | // Not produces a new predicate that reverts the return value from the given predicate. 1271 | func Not(pred func(Row) bool) func(Row) bool { 1272 | return func(row Row) bool { 1273 | return !pred(row) 1274 | } 1275 | } 1276 | 1277 | // Like produces a predicate that returns 'true' if its input Row matches all the corresponding 1278 | // values from the specified 'match' Row. 1279 | func Like(match Row) func(Row) bool { 1280 | if len(match) == 0 { 1281 | panic("empty match row in Like() predicate") 1282 | } 1283 | 1284 | return func(row Row) bool { 1285 | for key, val := range match { 1286 | if v, found := row[key]; !found || v != val { 1287 | return false 1288 | } 1289 | } 1290 | 1291 | return true 1292 | } 1293 | } 1294 | -------------------------------------------------------------------------------- /csvplus_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright (c) 2016,2017,2018,2019 Maxim Konakov 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without modification, 6 | are permitted provided that the following conditions are met: 7 | 8 | 1. 
Redistributions of source code must retain the above copyright notice, 9 | this list of conditions and the following disclaimer. 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 3. Neither the name of the copyright holder nor the names of its contributors 14 | may be used to endorse or promote products derived from this software without 15 | specific prior written permission. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 18 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 19 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 20 | IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, 21 | INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, 22 | BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 23 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 24 | OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 25 | NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, 26 | EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
27 | */ 28 | 29 | package csvplus 30 | 31 | import ( 32 | "bytes" 33 | "encoding/csv" 34 | "encoding/json" 35 | "errors" 36 | "flag" 37 | "fmt" 38 | "io" 39 | "math" 40 | "math/rand" 41 | "os" 42 | "sort" 43 | "strconv" 44 | "strings" 45 | "testing" 46 | "time" 47 | ) 48 | 49 | func TestRow(t *testing.T) { 50 | row := Row{ 51 | "id": "12345", 52 | "Name": "John", 53 | "Surname": "Doe", 54 | } 55 | 56 | if !row.HasColumn("Name") || !row.HasColumn("Surname") || !row.HasColumn("id") { 57 | t.Error("Failed HasColumn() test") 58 | return 59 | } 60 | 61 | hdr := row.Header() 62 | 63 | if len(hdr) != len(row) { 64 | t.Error("Invalid header length:", len(hdr)) 65 | return 66 | } 67 | 68 | for _, name := range hdr { 69 | if !row.HasColumn(name) { 70 | t.Error("Column not found:", name) 71 | return 72 | } 73 | } 74 | 75 | if row.SafeGetValue("Name", "") != "John" || row.SafeGetValue("xxx", "@") != "@" { 76 | t.Error("SafeGetValue() test failed") 77 | return 78 | } 79 | 80 | if s := row.SelectExisting("Name", "xxx").String(); s != `{ "Name" : "John" }` { 81 | t.Error("SafeSelect() test failed:", s) 82 | return 83 | } 84 | 85 | if _, e := row.Select("xxx", "zzz"); e == nil || e.Error() != `missing column "xxx"` { 86 | t.Error("Select() test failed:", e) 87 | return 88 | } 89 | 90 | if _, e := row.Select("id", "zzz"); e == nil || e.Error() != `missing column "zzz"` { 91 | t.Error("Select() test failed:", e) 92 | return 93 | } 94 | 95 | if r, e := row.Select("id"); e != nil || r.String() != `{ "id" : "12345" }` { 96 | t.Error("Select() test failed:", e, r.String()) 97 | return 98 | } 99 | 100 | if _, e := row.SelectValues("id", "xxx"); e == nil || e.Error() != `missing column "xxx"` { 101 | t.Error("SelectValues() test failed:", e) 102 | return 103 | } 104 | 105 | l, e := row.SelectValues("id", "Name") 106 | 107 | if e != nil || len(l) != 2 || l[0] != "12345" || l[1] != "John" { 108 | t.Error("SelectValues() test failed:", l, e) 109 | return 110 | } 111 | 112 | if s := 
row.String(); s != `{ "Name" : "John", "Surname" : "Doe", "id" : "12345" }` { 113 | t.Error("String() test failed:", s) 114 | return 115 | } 116 | } 117 | 118 | func TestSimpleDataSource(t *testing.T) { 119 | var n int 120 | 121 | hdr := sortedCopy(peopleHeader) 122 | src := Take(FromFile(tempFiles["people"]).SelectColumns(hdr...)). 123 | Filter(Any(Like(Row{"name": "Jack"}), Like(Row{"name": "Amelia"}))) 124 | 125 | err := src(func(row Row) error { 126 | if name := row.SafeGetValue("name", ""); name != "Jack" && name != "Amelia" { 127 | return errors.New("Unexpected name: " + name) 128 | } 129 | 130 | if len(row) != 4 { 131 | return fmt.Errorf("Unexpected number of columns: %d", len(row)) 132 | } 133 | 134 | for i, name := range row.Header() { 135 | if hdr[i] != name { 136 | return errors.New("Unexpected column name: " + name) 137 | } 138 | } 139 | 140 | n++ 141 | return nil 142 | }) 143 | 144 | if err != nil { 145 | t.Error(err) 146 | return 147 | } else if n != len(peopleSurnames)*2 { 148 | t.Error("Invalid number of rows:", n) 149 | return 150 | } 151 | } 152 | 153 | func TestFilterMap(t *testing.T) { 154 | src := Take(FromFile(tempFiles["people"]).SelectColumns("name", "surname", "id")). 155 | Filter(Like(Row{"name": "Amelia"})). 
156 | Map(func(row Row) Row { row["name"] = "Julia"; return row }) 157 | 158 | err := src(func(row Row) error { 159 | if row["name"] != "Julia" { 160 | return fmt.Errorf("Unexpected name: %s instead of Julia", row["name"]) 161 | } 162 | 163 | return nil 164 | }) 165 | 166 | if err != nil { 167 | t.Error(err) 168 | return 169 | } 170 | } 171 | 172 | func TestWriteFile(t *testing.T) { 173 | var tmpFileName string 174 | var file1, file2 []byte 175 | 176 | defer os.Remove(tmpFileName) 177 | 178 | src := Take(FromFile(tempFiles["people"]).SelectColumns(peopleHeader...)) 179 | 180 | err := anyFrom( 181 | func() (e error) { tmpFileName, e = createTempFile(""); return }, 182 | func() (e error) { e = src.ToCsvFile(tmpFileName, peopleHeader...); return }, 183 | func() (e error) { file1, e = os.ReadFile(tmpFileName); return }, 184 | func() (e error) { file2, e = os.ReadFile(tempFiles["people"]); return }, 185 | ) 186 | 187 | if err == nil { 188 | if !bytes.Equal(bytes.TrimSpace(file1), bytes.TrimSpace(file2)) { 189 | t.Error("Files do not match") 190 | return 191 | } 192 | } else { 193 | t.Error(err) 194 | return 195 | } 196 | } 197 | 198 | func TestIndexImpl(t *testing.T) { 199 | index := indexImpl{ 200 | columns: []string{"x", "y", "z"}, 201 | rows: []Row{ 202 | {"x": "1", "y": "2", "z": "3", "junk": "zzz"}, 203 | {"x": "5", "y": "6", "z": "8", "junk": "nnn"}, 204 | {"x": "0", "y": "5", "z": "3", "junk": "xxx"}, 205 | {"x": "8", "y": "9", "z": "1", "junk": "aaa"}, 206 | {"x": "7", "y": "4", "z": "0", "junk": "bbb"}, 207 | {"x": "5", "y": "6", "z": "9", "junk": "iii"}, 208 | {"x": "2", "y": "6", "z": "7", "junk": "mmm"}, 209 | }, 210 | } 211 | 212 | sort.Sort(&index) 213 | 214 | rows := index.find([]string{"1", "2", "3"}) 215 | 216 | if len(rows) != 1 || 217 | rows[0].SafeGetValue("x", "") != "1" || 218 | rows[0].SafeGetValue("y", "") != "2" || 219 | rows[0].SafeGetValue("z", "") != "3" || 220 | rows[0].SafeGetValue("junk", "") != "zzz" { 221 | t.Errorf("Bad rows: %v", rows) 
222 | return 223 | } 224 | 225 | rows = index.find([]string{"5", "6", "8"}) 226 | 227 | if len(rows) != 1 || 228 | rows[0].SafeGetValue("x", "") != "5" || 229 | rows[0].SafeGetValue("y", "") != "6" || 230 | rows[0].SafeGetValue("z", "") != "8" || 231 | rows[0].SafeGetValue("junk", "") != "nnn" { 232 | t.Errorf("Bad rows: %v", rows) 233 | return 234 | } 235 | 236 | rows = index.find([]string{"5", "6"}) 237 | 238 | if len(rows) != 2 || 239 | rows[0].SafeGetValue("x", "") != "5" || 240 | rows[0].SafeGetValue("y", "") != "6" || 241 | rows[1].SafeGetValue("x", "") != "5" || 242 | rows[1].SafeGetValue("y", "") != "6" { 243 | t.Errorf("Bad rows: %v", rows) 244 | return 245 | } 246 | } 247 | 248 | func TestLongChain(t *testing.T) { 249 | var err error 250 | var products, orders *Index 251 | 252 | orders, err = Take(FromFile(tempFiles["orders"]).SelectColumns("order_id", "cust_id", "prod_id", "qty", "ts")). 253 | IndexOn("cust_id") 254 | 255 | if err != nil { 256 | t.Error(err) 257 | return 258 | } 259 | 260 | products, err = Take(FromFile(tempFiles["stock"]).SelectColumns("prod_id", "product", "price")). 261 | UniqueIndexOn("prod_id") 262 | 263 | if err != nil { 264 | t.Error(err) 265 | return 266 | } 267 | 268 | people := Take(FromFile(tempFiles["people"]).SelectColumns("id", "name", "surname", "born")) 269 | 270 | var n int 271 | 272 | err = people.Filter(func(row Row) bool { 273 | year, e := row.ValueAsInt("born") 274 | 275 | if e != nil { 276 | t.Error(e) 277 | return false 278 | } 279 | 280 | return year > 1970 281 | }). 282 | SelectColumns("id", "name", "surname"). 283 | Join(orders, "id"). 284 | DropColumns("ts", "order_id", "cust_id"). 285 | Join(products). 286 | DropColumns("prod_id"). 287 | Map(func(row Row) Row { 288 | if row["name"] == "Amelia" { 289 | row["name"] = "Julia" 290 | } 291 | 292 | return row 293 | }). 294 | Filter(Like(Row{"surname": "Smith"})). 295 | Top(10). 
296 | DropColumns("id")(func(row Row) error { 297 | if n++; n > 10 { 298 | return errors.New("Too many rows") 299 | } 300 | 301 | if row["surname"] != "Smith" { 302 | return errors.New(`Surname "Smith" not found`) 303 | } 304 | 305 | if row["name"] == "Amelia" { 306 | return errors.New(`Name "Amelia" found`) 307 | } 308 | 309 | if vals := row.SelectExisting("born", "ts", "order_id", "prod_id", "cust_id"); len(vals) != 0 { 310 | return errors.New("Some deleted fields are still there") 311 | } 312 | 313 | if len(row) != 5 { // name, surname, qty, product, price 314 | return fmt.Errorf("Unexpected number of columns: %d instead of 5", len(row)) 315 | } 316 | 317 | return nil 318 | }) 319 | 320 | if err != nil { 321 | t.Error(err) 322 | return 323 | } 324 | 325 | // check the original orders 326 | n = 0 327 | 328 | if err = Take(orders)(func(row Row) error { 329 | n++ 330 | 331 | if _, e := row.Select("order_id", "cust_id", "prod_id", "qty", "ts"); e != nil { 332 | return e 333 | } 334 | 335 | return nil 336 | }); err != nil { 337 | t.Error(err) 338 | return 339 | } 340 | 341 | if n != len(ordersData) { 342 | t.Errorf("Unexpected number of orders: %d instead of %d", n, len(ordersData)) 343 | return 344 | } 345 | 346 | // check the original products 347 | n = 0 348 | 349 | if err = Take(products)(func(row Row) error { 350 | n++ 351 | 352 | if _, e := row.Select("prod_id", "product", "price"); e != nil { 353 | return e 354 | } 355 | 356 | return nil 357 | }); err != nil { 358 | t.Error(err) 359 | return 360 | } 361 | 362 | if n != len(stockItems) { 363 | t.Errorf("Unexpected number of products: %d instead of %d", n, len(stockItems)) 364 | return 365 | } 366 | } 367 | 368 | func TestSimpleUniqueJoin(t *testing.T) { 369 | people := Take(FromFile(tempFiles["people"]).SelectColumns("id", "name", "surname")) 370 | orders := Take(FromFile(tempFiles["orders"]).SelectColumns("order_id", "cust_id", "qty")) 371 | 372 | idIndex, err := people.UniqueIndexOn("id") 373 | 374 | if err 
!= nil { 375 | t.Errorf("Cannot create index: %s", err) 376 | return 377 | } 378 | 379 | qtyMap := make([]int, len(peopleData)) 380 | 381 | err = orders.Join(idIndex, "cust_id")(func(row Row) (e error) { 382 | var id, orderID, custID, qty int 383 | 384 | if id, e = row.ValueAsInt("id"); e != nil { 385 | return 386 | } 387 | 388 | if orderID, e = row.ValueAsInt("order_id"); e != nil { 389 | return 390 | } 391 | 392 | if custID, e = row.ValueAsInt("cust_id"); e != nil { 393 | return 394 | } 395 | 396 | if qty, e = row.ValueAsInt("qty"); e != nil { 397 | return 398 | } 399 | 400 | if id >= len(peopleData) { 401 | return fmt.Errorf("Invalid id: %d", id) 402 | } 403 | 404 | if peopleData[id].Name != row.SafeGetValue("name", "") || 405 | peopleData[id].Surname != row.SafeGetValue("surname", "") { 406 | return fmt.Errorf("Invalid parameters associated with id %d", id) 407 | } 408 | 409 | if id != custID { 410 | return fmt.Errorf("id = %d, cust_id = %d", id, custID) 411 | } 412 | 413 | if orderID >= numOrders { 414 | return fmt.Errorf("Invalid order_id: %d", orderID) 415 | } 416 | 417 | if ordersData[orderID].custID != custID { 418 | return fmt.Errorf("cust_id: got %d instead of %d", custID, ordersData[orderID].custID) 419 | } 420 | 421 | if ordersData[orderID].qty != qty { 422 | return fmt.Errorf("qty: got %d instead of %d", qty, ordersData[orderID].qty) 423 | } 424 | 425 | if len(row) != 6 { 426 | return fmt.Errorf("Invalid number of columns: %d", len(row)) 427 | } 428 | 429 | qtyMap[id] += qty 430 | 431 | return 432 | }) 433 | 434 | if err != nil { 435 | t.Errorf("Join failed: %s", err) 436 | return 437 | } 438 | 439 | // check qty map 440 | origMap := make([]int, len(peopleData)) 441 | 442 | for _, data := range ordersData { 443 | origMap[data.custID] += data.qty 444 | } 445 | 446 | for i, qty := range qtyMap { 447 | if qty != origMap[i] { 448 | t.Errorf("qty for id %d: %d instead of %d", i, qty, origMap[i]) 449 | return 450 | } 451 | } 452 | } 453 | 454 | func 
TestSorted(t *testing.T) { 455 | people := Take(FromFile(tempFiles["people"]).ExpectHeader(map[string]int{ 456 | "name": 1, 457 | "surname": 2, 458 | })) 459 | 460 | // by name, surname 461 | index, err := people.UniqueIndexOn("name", "surname") 462 | 463 | if err != nil { 464 | t.Error(err) 465 | return 466 | } 467 | 468 | if err = Take(index).Top(uint64(len(peopleSurnames)))(func(row Row) error { 469 | if name := row.SafeGetValue("name", "???"); name != "Amelia" { 470 | return errors.New("Unexpected name: " + name) 471 | } 472 | 473 | return nil 474 | }); err != nil { 475 | t.Error(err) 476 | return 477 | } 478 | 479 | // second name, DropWhile() 480 | if err = Take(index). 481 | DropWhile(Like(Row{"name": "Amelia"})). 482 | Top(uint64(len(peopleSurnames)))(func(row Row) error { 483 | if name := row.SafeGetValue("name", "???"); name != "Ava" { 484 | return errors.New("Unexpected name: " + name) 485 | } 486 | 487 | return nil 488 | }); err != nil { 489 | t.Error(err) 490 | return 491 | } 492 | 493 | // by surname, name 494 | index, err = people.UniqueIndexOn("surname", "name") 495 | 496 | if err != nil { 497 | t.Error(err) 498 | return 499 | } 500 | 501 | // take second surname 502 | if err = Take(index). 503 | Drop(uint64(len(peopleNames))). 
504 | Top(uint64(len(peopleNames)))(func(row Row) error { 505 | if surname := row.SafeGetValue("surname", "???"); surname != "Davies" { 506 | return errors.New("Unexpected surname: " + surname) 507 | } 508 | 509 | return nil 510 | }); err != nil { 511 | t.Error(err) 512 | return 513 | } 514 | } 515 | 516 | func TestSimpleTotals(t *testing.T) { 517 | orders := Take(FromFile(tempFiles["orders"]).SelectColumns("cust_id", "prod_id", "qty")) 518 | products := Take(FromFile(tempFiles["stock"]).SelectColumns("prod_id", "price")) 519 | 520 | prodIndex, err := products.UniqueIndexOn("prod_id") 521 | 522 | if err != nil { 523 | t.Error(err) 524 | return 525 | } 526 | 527 | totals := make([]float64, len(peopleData)) 528 | 529 | if err = orders.Join(prodIndex)(func(row Row) error { 530 | var id, qty int 531 | var e error 532 | 533 | if id, e = row.ValueAsInt("cust_id"); e != nil { 534 | return fmt.Errorf("cust_id: %s", e) 535 | } 536 | 537 | if id >= len(peopleData) { 538 | return fmt.Errorf("Invalid id: %d", id) 539 | } 540 | 541 | if qty, e = row.ValueAsInt("qty"); e != nil { 542 | return fmt.Errorf("qty: %s", e) 543 | } 544 | 545 | var price float64 546 | 547 | if price, e = row.ValueAsFloat64("price"); e != nil { 548 | return fmt.Errorf("price: %s", e) 549 | } 550 | 551 | totals[id] = price * float64(qty) 552 | return nil 553 | 554 | }); err != nil { 555 | t.Error(err) 556 | return 557 | } 558 | 559 | origTotals := make([]float64, len(peopleData)) 560 | 561 | for _, order := range ordersData { 562 | origTotals[order.custID] = stockItems[order.prodID].price * float64(order.qty) 563 | } 564 | 565 | for id, total := range totals { 566 | if math.Abs((total-origTotals[id])/total) > 1e-6 { 567 | t.Errorf("total for id %d: %f instead of %f", id, total, origTotals[id]) 568 | return 569 | } 570 | } 571 | } 572 | 573 | func TestMultiIndex(t *testing.T) { 574 | source := Take(FromFile(tempFiles["people"]).SelectColumns("id", "name", "surname")) 575 | index, err := 
source.UniqueIndexOn("name", "surname")

	if err != nil {
		t.Error(err)
		return
	}

	// non-existing name
	if err = index.Find("xxx")(neverCalled); err != nil {
		t.Error(err)
		return
	}

	// test sub-index
	if err = index.Find("Amelia")(func(row Row) (e error) {
		if name := row.SafeGetValue("name", "???"); name != "Amelia" {
			e = fmt.Errorf("name: %s instead of Amelia", name)
		}

		return
	}); err != nil {
		t.Error(err)
		return
	}

	// self-join on existing names
	for _, name := range peopleNames {
		surnames := map[string]int{}
		s := source.Join(index.SubIndex(name))

		if err = s(func(row Row) error {
			surnames[row.SafeGetValue("surname", "???")]++
			return nil
		}); err != nil {
			t.Error(err)
			return
		}

		if len(surnames) != len(peopleSurnames) {
			t.Errorf(`Name "%s": Invalid number of surnames: %d instead of %d`, name, len(surnames), len(peopleSurnames))
			return
		}

		for _, sname := range peopleSurnames {
			if count, found := surnames[sname]; !found || count != len(peopleNames) {
				t.Errorf(`Name "%s": Surname "%s" found %d times`, name, sname, count)
				return
			}
		}
	}

	// find all existing names and surnames
	for _, person := range peopleData {
		var count int

		if err = index.Find(person.Name, person.Surname)(func(Row) error {
			count++
			return nil
		}); err != nil {
			t.Error(err)
			return
		}

		if count != 1 {
			t.Errorf("%s %s found %d times", person.Name, person.Surname, count)
			return
		}
	}

	// try non-existent name and surname
	if err = index.Find("Jack", "xxx")(neverCalled); err != nil {
		t.Error(err)
		return
	}
}

// TestExcept builds an index of all people named "Emily" and verifies that
// Except on "cust_id" lets through exactly those orders whose customer has a
// different name.
func TestExcept(t *testing.T) {
	const name = "Emily"

	people, err := Take(FromFile(tempFiles["people"]).SelectColumns("id", "name", "surname")).
		Filter(Like(Row{"name": name})).
		IndexOn("id")

	if err != nil {
		t.Error(err)
		return
	}

	n := 0

	orders := Take(FromFile(tempFiles["orders"]).SelectColumns("cust_id", "prod_id", "qty")).
		Except(people, "cust_id")

	err = orders(func(row Row) error {
		// a malformed or out-of-range id must fail the test instead of
		// silently aliasing customer 0 or panicking on the slice access
		id, e := strconv.Atoi(row["cust_id"])

		if e != nil {
			return fmt.Errorf("invalid cust_id %q: %w", row["cust_id"], e)
		}

		if id < 0 || id >= len(peopleData) {
			return fmt.Errorf("cust_id %d is out of range", id)
		}

		if peopleData[id].Name == name {
			return fmt.Errorf("Cust. id %d somehow got through", id)
		}

		n++
		return nil
	})

	if err != nil {
		t.Error(err)
		return
	}

	// calculate the right number of orders
	m := 0

	for _, order := range ordersData {
		if peopleData[order.custID].Name != name {
			m++
		}
	}

	if n != m {
		t.Errorf("Unexpected number of orders: %d instead of %d", n, m)
		return
	}
}

// TestResolver repeatedly injects a random number of copies of one randomly
// chosen row into the people rows, indexes the result on ("name", "surname"),
// and checks that ResolveDuplicates invokes the callback exactly once with
// exactly the injected group of identical rows.
func TestResolver(t *testing.T) {
	source, err := Take(FromFile(tempFiles["people"]).SelectColumns("id", "name", "surname")).ToRows()

	if err != nil {
		t.Error(err)
		return
	}

	for i := 0; i < 1000; i++ {
		// copy source
		src := make([]Row, len(source))

		copy(src, source)

		// add random number of duplicates
		dup := src[rand.Intn(len(src))]
		id, name, surname := dup["id"], dup["name"], dup["surname"]
		n := rand.Intn(100) + 1

		for j := 0; j < n; j++ {
			// append the duplicate, then swap it to a random position
			k := rand.Intn(len(src))
			src = append(src, dup)
			src[k], src[len(src)-1] = src[len(src)-1], src[k]
		}

		// index
		index, err := TakeRows(src).IndexOn("name", "surname")

		if err != nil {
			t.Error(err)
			return
		}

		// resolve
		nc := 0

		if err := index.ResolveDuplicates(func(rows []Row) (Row, error) {
			if nc++; nc != 1 {
				return nil, errors.New("Unexpected second call to the resolution function")
			}

			// the original row plus n injected copies
			if len(rows) != n+1 {
				return nil, fmt.Errorf("Unexpected number of duplicates: %d instead of %d", len(rows), n+1)
			}

			for _, r := range rows {
				if r["id"] != id || r["name"] != name || r["surname"] != surname {
					return nil, errors.New("Unexpected duplicate: " + r.String())
				}
			}

			return rows[0], nil
		}); err != nil {
			t.Error(err)
			return
		}
	}
}

// TestTransformedSource sums the "amount" column of the joined and
// transformed orders table per customer and compares the totals with the ones
// computed directly from the generated test data.
func TestTransformedSource(t *testing.T) {
	// cust_id, prod_id, amount
	amounts, err := createAmountsTable()

	if err != nil {
		t.Error(err)
		return
	}

	// aggregate by cust_id
	custAmounts := make([]float64, len(peopleData))

	if err = amounts(func(row Row) (e error) {
		var values []string

		if values, e = row.SelectValues("cust_id", "prod_id", "amount"); e != nil {
			return
		}

		var amount float64

		if amount, e = strconv.ParseFloat(values[2], 64); e != nil {
			return
		}

		var cid int

		if cid, e = strconv.Atoi(values[0]); e != nil {
			return
		}

		custAmounts[cid] += amount
		return
	}); err != nil {
		t.Error(err)
		return
	}

	// check custAmounts
	origCustAmounts := make([]float64, len(peopleData))

	for _, order := range ordersData {
		origCustAmounts[order.custID] += float64(order.qty) * stockItems[order.prodID].price
	}

	for i, amount := range origCustAmounts {
		// relative tolerance written without a division, so that a customer
		// with no orders (amount == 0) does not produce NaN or Inf
		if math.Abs(custAmounts[i]-amount) > 1e-6*math.Abs(amount) {
			t.Errorf("Amount mismatch for %s %s (%d): %f instead of %f",
				peopleData[i].Name, peopleData[i].Surname, i, custAmounts[i], amount)
			return
		}
	}
}

// TestErrors exercises the failure paths: unknown and duplicate column
// selection, index creation errors and panics, duplicate resolution, and
// header validation via ExpectHeader.
func TestErrors(t *testing.T) {
	// invalid column name
	err := Take(FromFile(tempFiles["people"]).SelectColumns("id", "name", "xxx"))(neverCalled)

	if err == nil || !strings.HasSuffix(err.Error(), "row 1: column not found: xxx") {
		t.Error("Unexpected error:", err)
		return
	}

	// duplicate column name
	if err = shouldPanic(func() {
		Take(FromFile(tempFiles["people"]).SelectColumns("id", "name", "id"))
	}); err != nil {
		t.Error(err)
		return
	}

	// missing column on index
	source := Take(FromFile(tempFiles["people"]).SelectColumns("id", "name", "surname"))

	_, err = source.IndexOn("name", "xxx")

	if err == nil || !strings.HasSuffix(err.Error(), `missing column "xxx" while creating an index`) {
		t.Error("Unexpected error:", err)
		return
	}

	// unique index with duplicate keys
	_, err = source.UniqueIndexOn("name")

	if err == nil || !strings.Contains(err.Error(), "duplicate value while creating unique index:") {
		// labelled like the other checks, so a missing error is not reported as a bare "<nil>"
		t.Error("Unexpected error:", err)
		return
	}

	var index *Index

	if index, err = source.IndexOn("name"); err != nil {
		t.Error(err)
		return
	}

	if err = index.ResolveDuplicates(func(rows []Row) (Row, error) {
		if len(rows) != len(peopleSurnames) {
			return nil, fmt.Errorf("Unexpected number of duplicate rows: %d instead of %d", len(rows), len(peopleSurnames))
		}

		return rows[0], nil
	}); err != nil {
		t.Error(err)
		return
	}

	if len(index.impl.rows) != len(peopleNames) {
		t.Errorf("Unexpected number of rows: %d instead of %d", len(index.impl.rows), len(peopleNames))
	}

	// panics on index
	if err = shouldPanic(func() {
		source.IndexOn() // empty list of columns
	}); err != nil {
		t.Error(err)
		return
	}

	if index, err = source.IndexOn("id"); err != nil {
		t.Error(err)
		return
	}

	if err = shouldPanic(func() {
		index.SubIndex("aaa", "bbb") // too many values
	}); err != nil {
		t.Error(err)
		return
	}

	// invalid header
	people := Take(FromFile(tempFiles["people"]).ExpectHeader(map[string]int{
		"name":    1,
		"surname": 3, // wrong column
	}))

	err = people(neverCalled)

	if err == nil || !strings.HasSuffix(err.Error(), `row 1: misplaced column "surname": expected at pos. 3, but found at pos. 2`) {
		t.Error("Unexpected error:", err)
		return
	}

	people = Take(FromFile(tempFiles["people"]).ExpectHeader(map[string]int{
		"name":    1,
		"surname": 25, // non-existent column
	}))

	err = people(neverCalled)

	if err == nil || !strings.HasSuffix(err.Error(), `row 1: misplaced column "surname": expected at pos. 25, but found at pos. 2`) {
		t.Error("Unexpected error:", err)
		return
	}
}

// TestNumericalConversions checks Row.ValueAsInt and Row.ValueAsFloat64 on
// valid input and verifies the exact error text produced for a non-numeric
// value.
func TestNumericalConversions(t *testing.T) {
	row := Row{"int": "12345", "float": "3.1415926", "string": "xyz"}

	var intVal int
	var err error

	if intVal, err = row.ValueAsInt("int"); err != nil {
		t.Error("Unexpected error:", err)
		return
	}

	if intVal != 12345 {
		t.Errorf("Unexpected value in integer conversion: %d instead of %s", intVal, row["int"])
		return
	}

	if _, err = row.ValueAsInt("string"); err == nil {
		t.Error("Missed error in integer conversion")
		return
	}

	if err.Error() != `column "string": cannot convert "xyz" to integer: invalid syntax` {
		t.Error("Unexpected error message in integer conversion:", err)
		return
	}

	var floatVal float64

	if floatVal, err = row.ValueAsFloat64("float"); err != nil {
		t.Error("Unexpected error:", err)
		return
	}

	if math.Abs(floatVal-3.1415926)/floatVal > 1e-6 {
		t.Errorf("Unexpected value in float conversion: %f instead of %s", floatVal, row["float"])
		return
	}

	if _, err = row.ValueAsFloat64("string"); err == nil {
		t.Error("Missed error in float conversion")
		return
	}

	if err.Error() != `column "string": cannot convert "xyz" to float: invalid syntax` {
		t.Error("Unexpected error message in float conversion:", err)
		return
	}
}

// TestIndexStore writes an index to a temporary file, loads it back with
// LoadIndex, and checks that column names and rows are the same in both.
func TestIndexStore(t *testing.T) {
	const namePrefix = "index"

	// read data and build index
	index, err := Take(FromFile(tempFiles["people"]).SelectColumns("id", "name", "surname")).IndexOn("id")

	if err != nil {
		t.Error(err)
		return
	}

	// write index
	if tempFiles[namePrefix], err = createTempFile(namePrefix); err != nil {
		t.Error(err)
		return
	}

	if err = index.WriteTo(tempFiles[namePrefix]); err != nil {
		t.Error(err)
		return
	}

	// read index
	var index2 *Index

	if index2, err = LoadIndex(tempFiles[namePrefix]); err != nil {
		t.Error(err)
		return
	}

	// compare column names
	if len(index.impl.columns) != len(index2.impl.columns) {
		t.Errorf("Column number mismatch: %d instead of %d", len(index2.impl.columns), len(index.impl.columns))
		return
	}

	for i, c := range index.impl.columns {
		if c != index2.impl.columns[i] {
			t.Errorf(`Unexpected column name: "%s" instead of "%s"`, index2.impl.columns[i], c)
			return
		}
	}

	// compare rows
	if len(index.impl.rows) != len(index2.impl.rows) {
		t.Errorf("Rows number mismatch: %d instead of %d", len(index2.impl.rows), len(index.impl.rows))
		return
	}

	for i, row := range index.impl.rows {
		if row.String() != index2.impl.rows[i].String() {
			// interpreted string literal, so that \n and \t are real line breaks and tabs
			t.Errorf("Mismatching rows at %d:\n\t%s\n\t%s", i, row.String(), index2.impl.rows[i].String())
		}
	}
}

// TestJSONStruct converts the people file to JSON, decodes the JSON into a
// slice of personData, and compares it record by record with the generated
// data.
func TestJSONStruct(t *testing.T) {
	var buff bytes.Buffer

	// read input .csv and convert to JSON
	err := Take(FromFile(tempFiles["people"]).SelectColumns("name", "surname", "born")).ToJSON(&buff)

	if err != nil {
		t.Error(err)
		return
	}

	// de-serialise back from JSON to struct slice
	data, err := peopleFromJSON(&buff)

	if err != nil {
		t.Error(err)
		return
	}

	// validate
	if len(data) != len(peopleData) {
		t.Errorf("Invalid number of records: %d instead of %d", len(data), len(peopleData))
		return
	}

	for i := 0; i < len(data); i++ {
		if data[i].Name != peopleData[i].Name ||
			data[i].Surname != peopleData[i].Surname ||
			data[i].Born != peopleData[i].Born {
			t.Errorf("Data mismatch: %v instead of %v", data[i], peopleData[i])
			return
		}
	}
}

// benchmarks -------------------------------------------------------------------------------------

// BenchmarkCreateSmallSingleIndex measures building a unique single-column
// index over the in-memory people rows.
func BenchmarkCreateSmallSingleIndex(b *testing.B) {
	source, err := Take(FromFile(tempFiles["people"]).SelectColumns("id", "name", "surname")).ToRows()

	if err != nil {
		b.Error(err)
		return
	}

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		var index *Index

		if index, err = TakeRows(source).UniqueIndexOn("id"); err != nil {
			b.Error(err)
			return
		}

		// just to do something with the index
		if len(index.impl.columns) != 1 || index.impl.columns[0] != "id" || len(index.impl.rows) != len(peopleData) {
			b.Error("Wrong index")
			return
		}
	}
}

// BenchmarkCreateBiggerMultiIndex measures building a two-column index over
// the in-memory order rows.
func BenchmarkCreateBiggerMultiIndex(b *testing.B) {
	source, err := Take(FromFile(tempFiles["orders"]).SelectColumns("cust_id", "prod_id", "qty")).ToRows()

	if err != nil {
		b.Error(err)
		return
	}

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		var index *Index

		if index, err = TakeRows(source).IndexOn("cust_id", "prod_id"); err != nil {
			b.Error(err)
			return
		}

		// just to do something with the index
		if len(index.impl.columns) != 2 || len(index.impl.rows) != len(ordersData) {
			b.Error("Wrong index")
			return
		}
	}
}

// BenchmarkSearchSmallSingleIndex measures the Find call on a unique
// single-column index; the value returned by Find is not used.
func BenchmarkSearchSmallSingleIndex(b *testing.B) {
	index, err := Take(FromFile(tempFiles["people"]).SelectColumns("id", "name", "surname")).UniqueIndexOn("id")

	if err != nil {
		b.Error(err)
		return
	}

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		index.Find("0")
	}
}

// BenchmarkSearchBiggerMultiIndex measures the Find call on a two-column
// index over the orders; the value returned by Find is not used.
func BenchmarkSearchBiggerMultiIndex(b *testing.B) {
	index, err := Take(FromFile(tempFiles["orders"]).SelectColumns("cust_id", "prod_id", "qty")).IndexOn("cust_id", "prod_id")

	if err != nil {
		b.Error(err)
		return
	}

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		index.Find("0", "0")
	}
}

// BenchmarkJoinOnSmallSingleIndex measures joining the in-memory order rows
// with the unique people index on "cust_id".
func BenchmarkJoinOnSmallSingleIndex(b *testing.B) {
	source, err := Take(FromFile(tempFiles["orders"]).SelectColumns("cust_id", "prod_id", "qty")).ToRows()

	if err != nil {
		b.Error(err)
		return
	}

	var index *Index

	index, err = Take(FromFile(tempFiles["people"]).SelectColumns("id", "name", "surname")).UniqueIndexOn("id")

	if err != nil {
		b.Error(err)
		return
	}

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		if err := TakeRows(source).Join(index, "cust_id")(nop); err != nil {
			b.Error(err)
			return
		}
	}
}

// BenchmarkJoinOnBiggerMultiIndex measures joining the in-memory people rows
// with the two-column orders index on "id".
func BenchmarkJoinOnBiggerMultiIndex(b *testing.B) {
	source, err := Take(FromFile(tempFiles["people"]).SelectColumns("id", "name", "surname")).ToRows()

	if err != nil {
		b.Error(err)
		return
	}

	var index *Index

	index, err = Take(FromFile(tempFiles["orders"]).SelectColumns("cust_id", "prod_id", "qty")).IndexOn("cust_id", "prod_id")

	if err != nil {
		b.Error(err)
		return
	}

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		if err := TakeRows(source).Join(index, "id")(nop); err != nil {
			b.Error(err)
			return
		}
	}
}

// generated test data ----------------------------------------------------------------------------

// personData is one generated person; Born is (de)serialised as a JSON string.
type personData struct {
	Name, Surname string
	Born          int `json:",string"`
}

// peopleData holds one record per (name, surname) combination;
// it is filled in by makePersonsCsvFile.
var peopleData = make([]personData, len(peopleNames)*len(peopleSurnames))

// orderData is one generated order; custID and prodID are used as indices
// into peopleData and stockItems.
type orderData struct {
	custID, prodID, qty int
	ts                  time.Time
}

const numOrders = 10000

// ordersData is filled in by makeOrderCsvFile.
var ordersData [numOrders]orderData

// people.csv -------------------------------------------------------------------------------------
// http://www.ukbabynames.com/
var peopleNames = [...]string{
	"Amelia", "Olivia", "Emily", "Ava", "Isla",
	"Oliver", "Jack", "Harry", "Jacob", "Charlie",
}

// http://surname.sofeminine.co.uk/w/surnames/most-common-surnames-in-great-britain.html
var peopleSurnames = [...]string{
	"Smith", "Jones", "Taylor", "Williams", "Brown", "Davies",
	"Evans", "Wilson", "Thomas", "Roberts", "Johnson", "Lewis",
}

var peopleHeader = []string{"id", "name", "surname", "born"}

// makePersonsCsvFile generates peopleData and writes it, preceded by
// peopleHeader, to the temporary file registered as "people".
func makePersonsCsvFile() error {
	return withTempFileWriter("people", func(out *csv.Writer) error {
		// header
		if err := out.Write(peopleHeader); err != nil {
			return err
		}

		// body
		for i, name := range peopleNames {
			for j, surname := range peopleSurnames {
				// the id doubles as the index into peopleData
				id := i*len(peopleSurnames) + j

				peopleData[id] = personData{
					Name:    name,
					Surname: surname,
					Born:    1916 + rand.Intn(90), // at least 10 years old
				}

				person := &peopleData[id]

				if err := out.Write([]string{
					strconv.Itoa(id),
					person.Name,
					person.Surname,
					strconv.Itoa(person.Born),
				}); err != nil {
					return err
				}
			}
		}

		return nil
	})
}

// peopleFromJSON decodes a JSON array of personData records from the reader.
func peopleFromJSON(in io.Reader) (pd []personData, err error) {
	err = json.NewDecoder(in).Decode(&pd)
	return
}

// stock.csv ------------------------------------------------------------------------------------
var stockItems = [...]struct {
	name  string
	price float64
}{
	{"banana", 0.01},
	{"apple", 0.02},
	{"orange", 0.03},
	{"pea", 0.04},
	{"tomato", 0.05},
	{"potato", 0.06},
	{"cucumber", 0.07},
	{"iPhone", 0.08},
}

var stockItemsHeader = []string{"prod_id", "product", "price"}

// makeStockCsvFile writes stockItems, preceded by stockItemsHeader, to the
// temporary file registered as "stock"; the product id is the item's index.
func makeStockCsvFile() error {
	return withTempFileWriter("stock", func(out *csv.Writer) error {
		// header
		if err := out.Write(stockItemsHeader); err != nil {
			return err
		}

		// body
		for i, item := range stockItems {
			price := strconv.FormatFloat(item.price, 'f', 2, 64)

			if err := out.Write([]string{strconv.Itoa(i), item.name, price}); err != nil {
				return err
			}
		}

		return nil
	})
}

// orders.csv -----------------------------------------------------------------------------------
var orderHeader = []string{"order_id", "cust_id", "prod_id", "qty", "ts"}

// makeOrderCsvFile generates ordersData with random customers, products,
// quantities and timestamps, and writes it, preceded by orderHeader, to the
// temporary file registered as "orders".
func makeOrderCsvFile() error {
	return withTempFileWriter("orders", func(out *csv.Writer) error {
		// header
		if err := out.Write(orderHeader); err != nil {
			return err
		}

		// body
		now := time.Now()

		for i := 0; i < numOrders; i++ {
			ordersData[i] = orderData{
				custID: rand.Intn(len(peopleNames) * len(peopleSurnames)),
				prodID: rand.Intn(len(stockItems)),
				qty:    rand.Intn(100) + 1,
				ts:     now.Add(-time.Second * time.Duration(rand.Intn(100000)+1)),
			}

			order := &ordersData[i]

			if err := out.Write([]string{
				strconv.Itoa(i),
				strconv.Itoa(order.custID),
				strconv.Itoa(order.prodID),
				strconv.Itoa(order.qty),
				order.ts.Format(time.RFC3339),
			}); err != nil {
				return err
			}
		}

		return nil
	})
}

// tests set-up -----------------------------------------------------------------------------------

// TestMain delegates to runTests, so that the clean-up deferred there runs
// before os.Exit terminates the process.
func TestMain(m *testing.M) {
	os.Exit(runTests(m))
}

// runTests generates the temporary .csv files, optionally saves copies of
// them when the "save-temps" flag is set, runs the tests and returns their
// exit code. The temporary files are removed on return. It panics if the
// files cannot be generated or saved.
func runTests(m *testing.M) int {
	defer deleteTemps()

	if err := anyFrom(makePersonsCsvFile, makeOrderCsvFile, makeStockCsvFile); err != nil {
		panic(err)
	}

	saveTemps := flag.Bool("save-temps", false, "Save all generated temporary files")
	flag.Parse()

	if *saveTemps {
		if err := saveTempFiles(); err != nil {
			panic(err)
		}
	}

	return m.Run()
}

// helpers ----------------------------------------------------------------------------------------

// anyFrom calls the given functions in order and returns the first non-nil
// error, without calling the remaining functions.
func anyFrom(funcs ...func() error) error {
	for _, fn := range funcs {
		if err := fn(); err != nil {
			return err
		}
	}

	return nil
}

// sortedCopy returns a sorted copy of the list, leaving the original intact.
func sortedCopy(list []string) (r []string) {
	r = make([]string, len(list))

	copy(r, list)
	sort.Strings(r)
	return
}

1378 | func createTempFile(prefix string) (name string, err error) { 1379 | var file *os.File 1380 | 1381 | if file, err = os.CreateTemp("", prefix); err != nil { 1382 | return 1383 | } 1384 | 1385 | name = file.Name() 1386 | file.Close() 1387 | return 1388 | } 1389 | 1390 | // cust_id, prod_id, amount 1391 | func createAmountsTable() (amounts DataSource, err error) { 1392 | var prodIndex *Index 1393 | 1394 | prodIndex, err = Take(FromFile(tempFiles["stock"]).SelectColumns("prod_id", "price")).UniqueIndexOn("prod_id") 1395 | 1396 | if err == nil { 1397 | amounts = Take(FromFile(tempFiles["orders"]).SelectColumns("cust_id", "prod_id", "qty")). 1398 | Join(prodIndex). 1399 | Transform(func(row Row) (Row, error) { 1400 | var qty int 1401 | var price float64 1402 | var e error 1403 | 1404 | if qty, e = row.ValueAsInt("qty"); e != nil { 1405 | return nil, e 1406 | } 1407 | 1408 | if price, e = row.ValueAsFloat64("price"); e != nil { 1409 | return nil, e 1410 | } 1411 | 1412 | row["amount"] = strconv.FormatFloat(price*float64(qty), 'f', 2, 64) 1413 | 1414 | delete(row, "price") 1415 | delete(row, "qty") 1416 | return row, nil 1417 | }) 1418 | } 1419 | 1420 | return 1421 | } 1422 | 1423 | func neverCalled(Row) error { return errors.New("This must never be called") } 1424 | 1425 | func nop(Row) error { return nil } 1426 | 1427 | func shouldPanic(fn func()) error { 1428 | defer func() { 1429 | _ = recover() 1430 | }() 1431 | 1432 | fn() 1433 | return errors.New("Panic did not happen") 1434 | } 1435 | 1436 | // temporary files -------------------------------------------------------------------------------- 1437 | var tempFiles = make(map[string]string, 10) 1438 | 1439 | func deleteTemps() { 1440 | for _, name := range tempFiles { 1441 | os.Remove(name) 1442 | } 1443 | } 1444 | 1445 | func withTempFileWriter(name string, fn func(*csv.Writer) error) (err error) { 1446 | var file *os.File 1447 | 1448 | if file, err = os.CreateTemp("", name); err != nil { 1449 | return err 1450 
| } 1451 | 1452 | defer func() { 1453 | if e := file.Close(); e != nil && err == nil { 1454 | err = e 1455 | } 1456 | }() 1457 | 1458 | tempFiles[name] = file.Name() 1459 | 1460 | out := csv.NewWriter(file) 1461 | 1462 | defer func() { 1463 | if err == nil { 1464 | out.Flush() 1465 | err = out.Error() 1466 | } 1467 | }() 1468 | 1469 | err = fn(out) 1470 | return 1471 | } 1472 | 1473 | func saveTempFiles() error { 1474 | for name, fileName := range tempFiles { 1475 | dest := name + ".csv" 1476 | 1477 | os.Remove(dest) // otherwise os.Link fails 1478 | 1479 | if err := os.Link(fileName, dest); err != nil { 1480 | return err 1481 | } 1482 | } 1483 | 1484 | return nil 1485 | } 1486 | --------------------------------------------------------------------------------