├── README.md ├── CMakeLists.txt ├── main.cpp ├── values_4_fill.hpp ├── matrix.hpp ├── NeuralNetwork.hpp └── rapidcsv.h /README.md: -------------------------------------------------------------------------------- 1 | # HLS_Transformer 2 | c++ version of ViT (https://openreview.net/pdf?id=YicbFdNTTy) 3 | 4 | For FPGA HLS 5 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.20) 2 | project(NN_3) 3 | 4 | set(CMAKE_CXX_STANDARD 14) 5 | 6 | add_executable(NN_3 main.cpp values_4_fill.hpp params.h) 7 | add_compile_options(-bigobj) -------------------------------------------------------------------------------- /main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "NeuralNetwork.hpp" 3 | #include 4 | #include 5 | 6 | int main() 7 | { 8 | std::vector topology = {64,64,256,64}; 9 | 10 | std::vector Blocklist = {}; 11 | sp::Block block_0(topology); 12 | sp::Block block_1(topology); 13 | sp::Block block_2(topology); 14 | sp::Block block_3(topology); 15 | sp::Block block_4(topology); 16 | sp::Block block_5(topology); 17 | sp::Block block_6(topology); 18 | sp::Block block_7(topology); 19 | sp::Block block_8(topology); 20 | sp::Block block_9(topology); 21 | sp::Block block_10(topology); 22 | sp::Block block_11(topology); 23 | 24 | Blocklist.push_back(block_0); 25 | Blocklist.push_back(block_1); 26 | Blocklist.push_back(block_2); 27 | Blocklist.push_back(block_3); 28 | Blocklist.push_back(block_4); 29 | Blocklist.push_back(block_5); 30 | Blocklist.push_back(block_6); 31 | Blocklist.push_back(block_7); 32 | Blocklist.push_back(block_8); 33 | Blocklist.push_back(block_9); 34 | Blocklist.push_back(block_10); 35 | Blocklist.push_back(block_11); 36 | 37 | 38 | int num_blks = Blocklist.size(); 39 | 40 | 41 | sp::SimpleNN nn(Blocklist, 1.0f); 42 | 43 | 44 | sp::Matrix2D input (4,300); 45 | std::fill(input._vals.begin(), input._vals.end(), 0.5); 46 | 47 | std::cout << "training start\n"; 48 | nn.feedForward(input, num_blks); 49 | 50 | // test 51 | std::vector preds = nn.getPredictions(); 52 | std::cout << "training complete\n"; 53 | std::cout << preds[0] <<','< 5 | #include 6 | #include "rapidcsv.h" 7 | 8 | //pos_embedding, (301, 64) 9 | std::vector posit {}; 10 | //cls_token, (1, 64) 11 | std::vector cls {}; 12 | //init weight, (64, 4) 13 | std::vector ini_w {}; 14 | //init bias, (64) 15 | std::vector ini_b {}; 16 | 17 | //BLOCK 18 | //block_0 19 | std::vector blk0_norm_w0 {}; //(64) 20 | std::vector blk0_norm_b0 {}; //(64) 21 | std::vector blk0_norm_w1 {}; //(64) 22 | std::vector blk0_norm_b1 {}; //(64) 23 | 24 | std::vector blk0_w0 {}; // to_qkv, No bias (64,64) 25 | std::vector blk0_w1 {}; //(64,256) 26 | std::vector blk0_b1 {}; //(64,1) 27 | std::vector blk0_w2 {}; //(256,64) 28 | std::vector blk0_b2 {}; //(256,1) 29 | 30 | //block_1 31 | std::vector blk1_norm_w0 {}; 32 | std::vector blk1_norm_b0 {}; 33 | std::vector blk1_norm_w1 {}; 34 | std::vector blk1_norm_b1 {}; 35 | 36 | std::vector blk1_w0 {}; // to_qkv, No bias 37 | std::vector blk1_w1 {}; 38 | std::vector blk1_b1 {}; 39 | std::vector blk1_w2 {}; 40 | std::vector blk1_b2 {}; 41 | 42 | //block_2 43 | std::vector blk2_norm_w0 {}; 44 | std::vector blk2_norm_b0 {}; 45 | std::vector blk2_norm_w1 {}; 46 | std::vector blk2_norm_b1 {}; 47 | 48 | std::vector blk2_w0 {}; // to_qkv, No bias 49 | std::vector blk2_w1 {}; 50 | std::vector blk2_b1 {}; 51 | std::vector blk2_w2 {}; 52 | std::vector blk2_b2 {}; 53 | 54 | //block_3 55 | std::vector blk3_norm_w0 {}; 56 | std::vector blk3_norm_b0 {}; 57 | std::vector blk3_norm_w1 {}; 58 | std::vector blk3_norm_b1 {}; 59 | 60 | std::vector blk3_w0 {}; // to_qkv, No bias 61 | std::vector blk3_w1 {}; 62 | std::vector blk3_b1 {}; 63 | std::vector blk3_w2 {}; 64 | std::vector blk3_b2 {}; 65 | 66 | //block_4 67 | std::vector blk4_norm_w0 {}; 68 | std::vector blk4_norm_b0 {}; 69 | std::vector blk4_norm_w1 {}; 70 | std::vector blk4_norm_b1 {}; 71 | 72 | std::vector blk4_w0 {}; // to_qkv, No bias 73 | std::vector blk4_w1 {}; 74 | std::vector blk4_b1 {}; 75 | std::vector blk4_w2 {}; 76 | std::vector blk4_b2 {}; 77 | 78 | //block_5 79 | std::vector blk5_norm_w0 {}; 80 | std::vector blk5_norm_b0 {}; 81 | std::vector blk5_norm_w1 {}; 82 | std::vector blk5_norm_b1 {}; 83 | 84 | std::vector blk5_w0 {}; // to_qkv, No bias 85 | std::vector blk5_w1 {}; 86 | std::vector blk5_b1 {}; 87 | std::vector blk5_w2 {}; 88 | std::vector blk5_b2 {}; 89 | 90 | //block_6 91 | std::vector blk6_norm_w0 {}; 92 | std::vector blk6_norm_b0 {}; 93 | std::vector blk6_norm_w1 {}; 94 | std::vector blk6_norm_b1 {}; 95 | 96 | std::vector blk6_w0 {}; // to_qkv, No bias 97 | std::vector blk6_w1 {}; 98 | std::vector blk6_b1 {}; 99 | std::vector blk6_w2 {}; 100 | std::vector blk6_b2 {}; 101 | 102 | //block_7 103 | std::vector blk7_norm_w0 {}; 104 | std::vector blk7_norm_b0 {}; 105 | std::vector blk7_norm_w1 {}; 106 | std::vector blk7_norm_b1 {}; 107 | 108 | std::vector blk7_w0 {}; // to_qkv, No bias 109 | std::vector blk7_w1 {}; 110 | std::vector blk7_b1 {}; 111 | std::vector blk7_w2 {}; 112 | std::vector blk7_b2 {}; 113 | 114 | //block_8 115 | std::vector blk8_norm_w0 {}; 116 | std::vector blk8_norm_b0 {}; 117 | std::vector blk8_norm_w1 {}; 118 | std::vector blk8_norm_b1 {}; 119 | 120 | std::vector blk8_w0 {}; // to_qkv, No bias 121 | std::vector blk8_w1 {}; 122 | std::vector blk8_b1 {}; 123 | std::vector blk8_w2 {}; 124 | std::vector blk8_b2 {}; 125 | 126 | //block_9 127 | std::vector blk9_norm_w0 {}; 128 | std::vector blk9_norm_b0 {}; 129 | std::vector blk9_norm_w1 {}; 130 | std::vector blk9_norm_b1 {}; 131 | 132 | std::vector blk9_w0 {}; // to_qkv, No bias 133 | std::vector blk9_w1 {}; 134 | std::vector blk9_b1 {}; 135 | std::vector blk9_w2 {}; 136 | std::vector blk9_b2 {}; 137 | 138 | //block_10 139 | std::vector blk10_norm_w0 {}; 140 | std::vector blk10_norm_b0 {}; 141 | std::vector blk10_norm_w1 {}; 142 | std::vector blk10_norm_b1 {}; 143 | 144 | std::vector blk10_w0 {}; // to_qkv, No bias 145 | std::vector blk10_w1 {}; 146 | std::vector blk10_b1 {}; 147 | std::vector blk10_w2 {}; 148 | std::vector blk10_b2 {}; 149 | 150 | //block_11 151 | std::vector blk11_norm_w0 {}; 152 | std::vector blk11_norm_b0 {}; 153 | std::vector blk11_norm_w1 {}; 154 | std::vector blk11_norm_b1 {}; 155 | 156 | std::vector blk11_w0 {}; // to_qkv, No bias 157 | std::vector blk11_w1 {}; 158 | std::vector blk11_b1 {}; 159 | std::vector blk11_w2 {}; 160 | std::vector blk11_b2 {}; 161 | -------------------------------------------------------------------------------- /matrix.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | 9 | namespace sp 10 | { 11 | template 12 | class Matrix2D 13 | { 14 | public: 15 | uint32_t _cols; 16 | uint32_t _rows; 17 | std::vector _vals; 18 | 19 | 20 | public: 21 | Matrix2D(uint32_t cols, uint32_t rows) 22 | : _cols(cols), 23 | _rows(rows), 24 | _vals({}) 25 | { 26 | _vals.resize(rows * cols, T()); 27 | } 28 | 29 | Matrix2D() 30 | : _cols(0), 31 | _rows(0), 32 | _vals({}) 33 | { 34 | } 35 | 36 | T& at(uint32_t col, uint32_t row) 37 | { 38 | return _vals[row * _cols + col]; 39 | } 40 | 41 | bool isSquare() 42 | { 43 | return _rows == _cols; 44 | } 45 | 46 | 47 | Matrix2D negetive() 48 | { 49 | Matrix2D output(_cols, _rows); 50 | for (uint32_t y = 0; y < output._rows; y++) 51 | for (uint32_t x = 0; x < output._cols; x++) 52 | { 53 | output.at(x, y) = -at(x, y); 54 | } 55 | return output; 56 | } 57 | 58 | Matrix2D fill_value(std::vector value_vec) 59 | { 60 | Matrix2D output(_cols, _rows); 61 | for (int i=0; i func) 122 | { 123 | Matrix2D output(_cols, _rows); 124 | for (uint32_t y = 0; y < output._rows; y++) 125 | for (uint32_t x = 0; x < output._cols; x++) 126 | { 127 | output.at(x, y) = func(at(x, y)); 128 | } 129 | return output; 130 | } 131 | 132 | Matrix2D multiplyScaler(float s) 133 | { 134 | Matrix2D output(_cols, _rows); 135 | for (uint32_t y = 0; y < output._rows; y++) 136 | for (uint32_t x = 0; x < output._cols; x++) 137 | { 138 | output.at(x, y) = at(x, y) * s; 139 | } 140 | return output; 141 | 142 | } 143 | 144 | Matrix2D addScaler(float s) 145 | { 146 | Matrix2D output(_cols, _rows); 147 | for (uint32_t y = 0; y < output._rows; y++) 148 | for (uint32_t x = 0; x < output._cols; x++) 149 | { 150 | output.at(x, y) = at(x, y) + s; 151 | } 152 | return output; 153 | 154 | } 155 | Matrix2D transpose() 156 | { 157 | Matrix2D output(_rows, _cols); 158 | for (uint32_t y = 0; y < _rows; y++) 159 | for (uint32_t x = 0; x < _cols; x++) 160 | { 161 | output.at(y, x) = at(x, y); 162 | } 163 | return output; 164 | } 165 | 166 | }; // class Matrix2D 167 | 168 | template 169 | void LogMatrix2D(Matrix2D& mat) 170 | { 171 | for (uint32_t y = 0; y < mat._rows; y++) 172 | { 173 | for (uint32_t x = 0; x < mat._cols; x++) 174 | std::cout << std::setw(10) << mat.at(x, y) << " "; 175 | std::cout << std::endl; 176 | } 177 | } 178 | 179 | } 180 | -------------------------------------------------------------------------------- /NeuralNetwork.hpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by 12192 on 2021/9/25. 3 | // 4 | #pragma once 5 | #include "matrix.hpp" 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | //#include "params.h" 17 | //#include "rapidcsv.h" 18 | 19 | 20 | namespace sp { 21 | 22 | //simple activation function 23 | inline float Sigmoid(float x) { 24 | return 1.0f / (1 + exp(-x)); 25 | } 26 | 27 | //derivative of activation function 28 | // x = sigmoid(input); 29 | inline float DSigmoid(float x) { 30 | return (x * (1 - x)); 31 | } 32 | 33 | // multi-label activation 34 | inline Matrix2D Softmax(sp::Matrix2D input){ 35 | sp::Matrix2D exp_input(input._cols, input._rows); 36 | for(int i=0; i result(input._cols, input._rows); 41 | for(int i=0; i D_Softmax(sp::Matrix2D input){ //input:(1row, 3col) 48 | Matrix2D tensor_1 (input._cols, input._cols);//3*3 49 | for(int k=0; k tensor_1_T = tensor_1.transpose(); 56 | tensor_1 = tensor_1.multiplyElements(tensor_1_T); 57 | 58 | Matrix2D tensor_2 (input._cols, input._cols); 59 | std::fill_n(tensor_2._vals.begin(), input._cols * input._rows, 0); 60 | for (int i=0; i tensor_1_nega = tensor_1.negetive(); 64 | Matrix2D result = tensor_2.add(tensor_1_nega); // tsr2-tsr1 65 | return result; //3*3 66 | 67 | } 68 | 69 | inline sp::Matrix2DAttention(sp::Matrix2D input) 70 | { 71 | sp::Matrix2D q(input); //copy input to q 72 | sp::Matrix2D k(input); //copy input to k 73 | sp::Matrix2D v(input); //copy input to v 74 | 75 | sp::Matrix2D atten = q.transpose().multiply(k); // atten = k * qT 76 | atten = atten.multiplyScaler(0.125); // atten * 1/8 77 | std::cout<< "q*k is ok..."< atten_sig = atten.applyFunction(Sigmoid); // square: attention score 79 | sp::Matrix2D atten_sig = Softmax(atten); // square: attention score, (301,301) 80 | std::cout<< "softmax(q*k) is ok..."< attention = v.multiply(atten_sig); // same shape as input, (301,64) 82 | std::cout<< "softmax(q*k)*v is ok..."< Layer_Norm (sp::Matrix2D input, sp::Matrix2D weight, sp::Matrix2D bias) 88 | { 89 | assert(input._cols == weight._rows && input._cols == bias._rows); 90 | sp::Matrix2D result (input._cols, input._rows); // norm in each cols 91 | //weight: (64,1), bias:(64,1) 92 | float epsln = 1e-05; 93 | 94 | for (int col_id=0; col_id celoss(sp::Matrix2D pred, sp::Matrix2D target) 117 | { 118 | sp::Matrix2D log_pred(pred._cols, 1); 119 | for(int i=0; i error = log_pred.multiplyElements(target); 123 | error = error.negetive(); 124 | return error; 125 | } 126 | 127 | inline sp::Matrix2D Cat (sp::Matrix2D input_1, sp::Matrix2D input_2){ 128 | sp::Matrix2D result(input_1._cols, input_1._rows+input_2._rows); 129 | 130 | for(int row_id=0; row_id ReLU (sp::Matrix2D input){ 144 | sp::Matrix2D result (input._cols, input._rows); 145 | 146 | for(int i=0; i0){ 148 | result._vals[i] = input._vals[i]; 149 | } 150 | else{ 151 | result._vals[i] = 0; 152 | } 153 | } 154 | return result; 155 | } 156 | 157 | inline std::vector txt_2_vec (std::string filemane){ 158 | std::string x; 159 | std::ifstream inFile(filemane); 160 | std::vector result; 161 | if (!inFile){ 162 | std::cout << "Unable to open file! "; 163 | } 164 | 165 | while(std::getline(inFile, x, ',')) 166 | { 167 | result.push_back(atof(x.c_str())); 168 | } 169 | return result; 170 | inFile.close(); 171 | 172 | } 173 | 174 | // ================================================================================ 175 | 176 | 177 | // ================================================================================ 178 | class Block { 179 | public: 180 | std::vector> _weightMatrices; 181 | std::vector> _valueMatrices; 182 | std::vector> _biasMatrices; 183 | std::vector> _norm_weight; 184 | std::vector> _norm_bias; 185 | std::vector> _class_out; 186 | 187 | std::vector _topology; 188 | 189 | public: 190 | Block(std::vector topolo) : 191 | _topology(topolo), 192 | _weightMatrices({}), 193 | _valueMatrices({}), 194 | _biasMatrices({}), 195 | _norm_weight({}), 196 | _norm_bias({}), 197 | _class_out({}) 198 | 199 | { 200 | for (int i = 0; i < topolo.size() - 1; i++) { 201 | Matrix2D weightMatrix(topolo[i + 1], topolo[i]); //3 202 | _weightMatrices.push_back(weightMatrix); 203 | 204 | Matrix2D biasMatrix(1, topolo[i + 1]); //3 205 | _biasMatrices.push_back(biasMatrix); 206 | } 207 | for (int i = 0; i < 2; i++){ 208 | Matrix2D norm(1,64); 209 | _norm_weight.push_back(norm); 210 | _norm_bias.push_back(norm); 211 | } 212 | _valueMatrices.resize(topolo.size()); //4, num of element, not M shape. 213 | _class_out.resize(1); 214 | } 215 | 216 | }; 217 | 218 | 219 | class SimpleNN { 220 | public: 221 | std::vector _Blocklist; 222 | float _learningRate; 223 | 224 | public: 225 | SimpleNN(std::vector Blocklist, float learningRate = 0.1f) : 226 | _Blocklist(Blocklist), 227 | _learningRate(learningRate) { 228 | 229 | } 230 | 231 | bool feedForward(sp::Matrix2D values, int num_blk) { 232 | sp::Matrix2D class_tok (64, 1); 233 | sp::Matrix2D cat_class_tok (64, 301); 234 | sp::Matrix2D posit_enc (64, 301); 235 | sp::Matrix2D add_posit_enc (64, 301); 236 | sp::Matrix2D ini_weight (64, 4); 237 | sp::Matrix2D ini_bias (1, 64); 238 | sp::Matrix2D atten_input (64, 301); 239 | sp::Matrix2D atten_out (64, 301); 240 | sp::Matrix2D normed (64, 301); 241 | 242 | //=========================================== 243 | // = 244 | // Fill values Out of BLOCKs = 245 | // = 246 | //=========================================== 247 | std::string file; 248 | file = "..\\param\\cls_token.txt"; 249 | std::vector cls_token = txt_2_vec(file); // 64 250 | 251 | file = "..\\param\\pos_embedding.txt"; 252 | std::vector pos_embedding = txt_2_vec(file); // 19264 253 | 254 | file = "..\\param\\init_weight.txt"; 255 | std::vector init_w = txt_2_vec(file); 256 | 257 | file = "..\\param\\init_bias.txt"; 258 | std::vector init_b = txt_2_vec(file); 259 | 260 | class_tok = class_tok.fill_value(cls_token); 261 | posit_enc = posit_enc.fill_value(pos_embedding); 262 | ini_weight = ini_weight.fill_value(init_w); 263 | ini_bias = ini_bias.fill_value(init_b); 264 | 265 | 266 | // ini weight, bias input(300, 4) * w(4, 64) -> (300, 64) 267 | values = values.multiply(ini_weight); 268 | values = values.addBias(ini_bias); 269 | std::cout<<" input success... "< blk0_norm_w0 = txt_2_vec(file); 297 | 298 | file = "..\\param\\BLK " + std::to_string(blk) + "\\blk" 299 | + std::to_string(blk) + "_norm_b0.txt"; 300 | std::vector blk0_norm_b0 = txt_2_vec(file); 301 | 302 | file = "..\\param\\BLK " + std::to_string(blk) + "\\blk" 303 | + std::to_string(blk) + "_norm_w1.txt"; 304 | std::vector blk0_norm_w1 = txt_2_vec(file); 305 | 306 | file = "..\\param\\BLK " + std::to_string(blk) + "\\blk" 307 | + std::to_string(blk) + "_norm_b1.txt"; 308 | std::vector blk0_norm_b1 = txt_2_vec(file); 309 | 310 | file = "..\\param\\BLK " + std::to_string(blk) + "\\blk" 311 | + std::to_string(blk) + "_w0.txt"; 312 | std::vector blk0_w0 = txt_2_vec(file); 313 | std::cout< blk0_w1 = txt_2_vec(file); 317 | std::cout< blk0_b1 = txt_2_vec(file); 322 | std::cout< blk0_w2 = txt_2_vec(file); 327 | std::cout< blk0_b2 = txt_2_vec(file); 332 | std::cout< block_param_list 334 | _Blocklist[blk]._norm_weight[0] = _Blocklist[blk]._norm_weight[0].fill_value(blk0_norm_w0); 335 | _Blocklist[blk]._norm_bias[0] = _Blocklist[blk]._norm_bias[0].fill_value(blk0_norm_b0); 336 | _Blocklist[blk]._norm_weight[1] = _Blocklist[blk]._norm_weight[1].fill_value(blk0_norm_w1); 337 | _Blocklist[blk]._norm_bias[1] = _Blocklist[blk]._norm_bias[1].fill_value(blk0_norm_b1); 338 | 339 | _Blocklist[blk]._weightMatrices[0] = _Blocklist[blk]._weightMatrices[0].fill_value(blk0_w0); 340 | _Blocklist[blk]._weightMatrices[1] = _Blocklist[blk]._weightMatrices[1].fill_value(blk0_w1); 341 | 342 | _Blocklist[blk]._biasMatrices[1] = _Blocklist[blk]._biasMatrices[1].fill_value(blk0_b1); 343 | _Blocklist[blk]._weightMatrices[2] = _Blocklist[blk]._weightMatrices[2].fill_value(blk0_w2); 344 | _Blocklist[blk]._biasMatrices[2] = _Blocklist[blk]._biasMatrices[2].fill_value(blk0_b2); 345 | 346 | 347 | //================================================= 348 | //================================================= 349 | 350 | _Blocklist[blk]._valueMatrices[0] = values; //(301, 64) 351 | // for (auto value: values._vals) 352 | // std::cout <<"attention input: "<(301, 64) 360 | _Blocklist[blk]._valueMatrices[1] = values; //(301, 64) 361 | atten_out = Attention(values); //(301, 64) 362 | values = atten_out; 363 | std::cout<<" Attention ok..."<(301, 256) 372 | values = values.addBias(_Blocklist[blk]._biasMatrices[1]); //bias(256), is shape of(out_features) 373 | _Blocklist[blk]._valueMatrices[2] = values; 374 | values = ReLU(values); 375 | values = values.multiply(_Blocklist[blk]._weightMatrices[2]); //(301, 256)*(256, 64)->(301, 64) 376 | values = values.addBias(_Blocklist[blk]._biasMatrices[2]); //bias(64) 377 | _Blocklist[blk]._valueMatrices[3] = values; 378 | std::cout<<" MLP Layers finished..."< class_token (64, 1); 387 | sp::Matrix2D final_w (5, 64); 388 | sp::Matrix2D final_b (5,1); 389 | sp::Matrix2D class_out (5,1); 390 | 391 | file = "..\\param\\final_weight.txt"; 392 | std::vector final_weight = txt_2_vec(file); 393 | file = "..\\param\\final_bias.txt"; 394 | std::vector final_bias = txt_2_vec(file); 395 | final_w = final_w.fill_value(final_weight); 396 | final_b = final_b.fill_value(final_bias); 397 | 398 | for (int col_id=0; col_id getPredictions() { 425 | // return _Blocklist.back()._valueMatrices.back()._vals; 426 | return _Blocklist.back()._class_out[0]._vals; 427 | } 428 | 429 | }; 430 | } -------------------------------------------------------------------------------- /rapidcsv.h: -------------------------------------------------------------------------------- 1 | /* 2 | * rapidcsv.h 3 | * 4 | * URL: https://github.com/d99kris/rapidcsv 5 | * Version: 8.53 6 | * 7 | * Copyright (C) 2017-2021 Kristofer Berggren 8 | * All rights reserved. 9 | * 10 | * rapidcsv is distributed under the BSD 3-Clause license, see LICENSE for details. 11 | * 12 | */ 13 | 14 | #pragma once 15 | 16 | #include 17 | #include 18 | #include 19 | #ifdef HAS_CODECVT 20 | #include 21 | #include 22 | #endif 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | #include 30 | #include 31 | 32 | #if defined(_MSC_VER) 33 | #include 34 | typedef SSIZE_T ssize_t; 35 | #endif 36 | 37 | namespace rapidcsv 38 | { 39 | #if defined(_MSC_VER) 40 | static const bool sPlatformHasCR = true; 41 | #else 42 | static const bool sPlatformHasCR = false; 43 | #endif 44 | 45 | /** 46 | * @brief Datastructure holding parameters controlling how invalid numbers (including 47 | * empty strings) should be handled. 48 | */ 49 | struct ConverterParams 50 | { 51 | /** 52 | * @brief Constructor 53 | * @param pHasDefaultConverter specifies if conversion of non-numerical strings shall be 54 | * converted to a default numerical value, instead of causing 55 | * an exception to be thrown (default). 56 | * @param pDefaultFloat floating-point default value to represent invalid numbers. 57 | * @param pDefaultInteger integer default value to represent invalid numbers. 58 | */ 59 | explicit ConverterParams(const bool pHasDefaultConverter = false, 60 | const long double pDefaultFloat = std::numeric_limits::signaling_NaN(), 61 | const long long pDefaultInteger = 0) 62 | : mHasDefaultConverter(pHasDefaultConverter) 63 | , mDefaultFloat(pDefaultFloat) 64 | , mDefaultInteger(pDefaultInteger) 65 | { 66 | } 67 | 68 | /** 69 | * @brief specifies if conversion of non-numerical strings shall be converted to a default 70 | * numerical value, instead of causing an exception to be thrown (default). 71 | */ 72 | bool mHasDefaultConverter; 73 | 74 | /** 75 | * @brief floating-point default value to represent invalid numbers. 76 | */ 77 | long double mDefaultFloat; 78 | 79 | /** 80 | * @brief integer default value to represent invalid numbers. 81 | */ 82 | long long mDefaultInteger; 83 | }; 84 | 85 | /** 86 | * @brief Exception thrown when attempting to access Document data in a datatype which 87 | * is not supported by the Converter class. 88 | */ 89 | class no_converter : public std::exception 90 | { 91 | /** 92 | * @brief Provides details about the exception 93 | * @returns an explanatory string 94 | */ 95 | virtual const char* what() const throw() 96 | { 97 | return "unsupported conversion datatype"; 98 | } 99 | }; 100 | 101 | /** 102 | * @brief Class providing conversion to/from numerical datatypes and strings. Only 103 | * intended for rapidcsv internal usage, but exposed externally to allow 104 | * specialization for custom datatype conversions. 105 | */ 106 | template 107 | class Converter 108 | { 109 | public: 110 | /** 111 | * @brief Constructor 112 | * @param pConverterParams specifies how conversion of non-numerical values to 113 | * numerical datatype shall be handled. 114 | */ 115 | Converter(const ConverterParams& pConverterParams) 116 | : mConverterParams(pConverterParams) 117 | { 118 | } 119 | 120 | /** 121 | * @brief Converts numerical value to string representation. 122 | * @param pVal numerical value 123 | * @param pStr output string 124 | */ 125 | void ToStr(const T& pVal, std::string& pStr) const 126 | { 127 | if (typeid(T) == typeid(int) || 128 | typeid(T) == typeid(long) || 129 | typeid(T) == typeid(long long) || 130 | typeid(T) == typeid(unsigned) || 131 | typeid(T) == typeid(unsigned long) || 132 | typeid(T) == typeid(unsigned long long) || 133 | typeid(T) == typeid(float) || 134 | typeid(T) == typeid(double) || 135 | typeid(T) == typeid(long double) || 136 | typeid(T) == typeid(char)) 137 | { 138 | std::ostringstream out; 139 | out << pVal; 140 | pStr = out.str(); 141 | } 142 | else 143 | { 144 | throw no_converter(); 145 | } 146 | } 147 | 148 | /** 149 | * @brief Converts string holding a numerical value to numerical datatype representation. 150 | * @param pVal numerical value 151 | * @param pStr output string 152 | */ 153 | void ToVal(const std::string& pStr, T& pVal) const 154 | { 155 | try 156 | { 157 | if (typeid(T) == typeid(int)) 158 | { 159 | pVal = static_cast(std::stoi(pStr)); 160 | return; 161 | } 162 | else if (typeid(T) == typeid(long)) 163 | { 164 | pVal = static_cast(std::stol(pStr)); 165 | return; 166 | } 167 | else if (typeid(T) == typeid(long long)) 168 | { 169 | pVal = static_cast(std::stoll(pStr)); 170 | return; 171 | } 172 | else if (typeid(T) == typeid(unsigned)) 173 | { 174 | pVal = static_cast(std::stoul(pStr)); 175 | return; 176 | } 177 | else if (typeid(T) == typeid(unsigned long)) 178 | { 179 | pVal = static_cast(std::stoul(pStr)); 180 | return; 181 | } 182 | else if (typeid(T) == typeid(unsigned long long)) 183 | { 184 | pVal = static_cast(std::stoull(pStr)); 185 | return; 186 | } 187 | } 188 | catch (...) 189 | { 190 | if (!mConverterParams.mHasDefaultConverter) 191 | { 192 | throw; 193 | } 194 | else 195 | { 196 | pVal = static_cast(mConverterParams.mDefaultInteger); 197 | return; 198 | } 199 | } 200 | 201 | try 202 | { 203 | if (typeid(T) == typeid(float)) 204 | { 205 | pVal = static_cast(std::stof(pStr)); 206 | return; 207 | } 208 | else if (typeid(T) == typeid(double)) 209 | { 210 | pVal = static_cast(std::stod(pStr)); 211 | return; 212 | } 213 | else if (typeid(T) == typeid(long double)) 214 | { 215 | pVal = static_cast(std::stold(pStr)); 216 | return; 217 | } 218 | } 219 | catch (...) 220 | { 221 | if (!mConverterParams.mHasDefaultConverter) 222 | { 223 | throw; 224 | } 225 | else 226 | { 227 | pVal = static_cast(mConverterParams.mDefaultFloat); 228 | return; 229 | } 230 | } 231 | 232 | if (typeid(T) == typeid(char)) 233 | { 234 | pVal = static_cast(pStr[0]); 235 | return; 236 | } 237 | else 238 | { 239 | throw no_converter(); 240 | } 241 | } 242 | 243 | private: 244 | const ConverterParams& mConverterParams; 245 | }; 246 | 247 | /** 248 | * @brief Specialized implementation handling string to string conversion. 249 | * @param pVal string 250 | * @param pStr string 251 | */ 252 | template<> 253 | inline void Converter::ToStr(const std::string& pVal, std::string& pStr) const 254 | { 255 | pStr = pVal; 256 | } 257 | 258 | /** 259 | * @brief Specialized implementation handling string to string conversion. 260 | * @param pVal string 261 | * @param pStr string 262 | */ 263 | template<> 264 | inline void Converter::ToVal(const std::string& pStr, std::string& pVal) const 265 | { 266 | pVal = pStr; 267 | } 268 | 269 | template 270 | using ConvFunc = std::function; 271 | 272 | /** 273 | * @brief Datastructure holding parameters controlling which row and column should be 274 | * treated as labels. 275 | */ 276 | struct LabelParams 277 | { 278 | /** 279 | * @brief Constructor 280 | * @param pColumnNameIdx specifies the zero-based row index of the column labels, setting 281 | * it to -1 prevents column lookup by label name, and gives access 282 | * to all rows as document data. Default: 0 283 | * @param pRowNameIdx specifies the zero-based column index of the row labels, setting 284 | * it to -1 prevents row lookup by label name, and gives access 285 | * to all columns as document data. Default: -1 286 | */ 287 | explicit LabelParams(const int pColumnNameIdx = 0, const int pRowNameIdx = -1) 288 | : mColumnNameIdx(pColumnNameIdx) 289 | , mRowNameIdx(pRowNameIdx) 290 | { 291 | } 292 | 293 | /** 294 | * @brief specifies the zero-based row index of the column labels. 295 | */ 296 | int mColumnNameIdx; 297 | 298 | /** 299 | * @brief specifies the zero-based column index of the row labels. 300 | */ 301 | int mRowNameIdx; 302 | }; 303 | 304 | /** 305 | * @brief Datastructure holding parameters controlling how the CSV data fields are separated. 306 | */ 307 | struct SeparatorParams 308 | { 309 | /** 310 | * @brief Constructor 311 | * @param pSeparator specifies the column separator (default ','). 312 | * @param pTrim specifies whether to trim leading and trailing spaces from 313 | * cells read (default false). 314 | * @param pHasCR specifies whether a new document (i.e. not an existing document read) 315 | * should use CR/LF instead of only LF (default is to use standard 316 | * behavior of underlying platforms - CR/LF for Win, and LF for others). 317 | * @param pQuotedLinebreaks specifies whether to allow line breaks in quoted text (default false) 318 | * @param pAutoQuote specifies whether to automatically dequote data during read, and add 319 | * quotes during write (default true). 320 | */ 321 | explicit SeparatorParams(const char pSeparator = ',', const bool pTrim = false, 322 | const bool pHasCR = sPlatformHasCR, const bool pQuotedLinebreaks = false, 323 | const bool pAutoQuote = true) 324 | : mSeparator(pSeparator) 325 | , mTrim(pTrim) 326 | , mHasCR(pHasCR) 327 | , mQuotedLinebreaks(pQuotedLinebreaks) 328 | , mAutoQuote(pAutoQuote) 329 | { 330 | } 331 | 332 | /** 333 | * @brief specifies the column separator. 334 | */ 335 | char mSeparator; 336 | 337 | /** 338 | * @brief specifies whether to trim leading and trailing spaces from cells read. 339 | */ 340 | bool mTrim; 341 | 342 | /** 343 | * @brief specifies whether new documents should use CR/LF instead of LF. 344 | */ 345 | bool mHasCR; 346 | 347 | /** 348 | * @brief specifies whether to allow line breaks in quoted text. 349 | */ 350 | bool mQuotedLinebreaks; 351 | 352 | /** 353 | * @brief specifies whether to automatically dequote cell data. 354 | */ 355 | bool mAutoQuote; 356 | }; 357 | 358 | /** 359 | * @brief Datastructure holding parameters controlling how special line formats should be 360 | * treated. 361 | */ 362 | struct LineReaderParams 363 | { 364 | /** 365 | * @brief Constructor 366 | * @param pSkipCommentLines specifies whether to skip lines prefixed with 367 | * mCommentPrefix. Default: false 368 | * @param pCommentPrefix specifies which prefix character to indicate a comment 369 | * line. Default: # 370 | * @param pSkipEmptyLines specifies whether to skip empty lines. Default: false 371 | */ 372 | explicit LineReaderParams(const bool pSkipCommentLines = false, 373 | const char pCommentPrefix = '#', 374 | const bool pSkipEmptyLines = false) 375 | : mSkipCommentLines(pSkipCommentLines) 376 | , mCommentPrefix(pCommentPrefix) 377 | , mSkipEmptyLines(pSkipEmptyLines) 378 | { 379 | } 380 | 381 | /** 382 | * @brief specifies whether to skip lines prefixed with mCommentPrefix. 383 | */ 384 | bool mSkipCommentLines; 385 | 386 | /** 387 | * @brief specifies which prefix character to indicate a comment line. 388 | */ 389 | char mCommentPrefix; 390 | 391 | /** 392 | * @brief specifies whether to skip empty lines. 393 | */ 394 | bool mSkipEmptyLines; 395 | }; 396 | 397 | /** 398 | * @brief Class representing a CSV document. 399 | */ 400 | class Document 401 | { 402 | public: 403 | /** 404 | * @brief Constructor 405 | * @param pPath specifies the path of an existing CSV-file to populate the Document 406 | * data with. 407 | * @param pLabelParams specifies which row and column should be treated as labels. 408 | * @param pSeparatorParams specifies which field and row separators should be used. 409 | * @param pConverterParams specifies how invalid numbers (including empty strings) should be 410 | * handled. 411 | * @param pLineReaderParams specifies how special line formats should be treated. 412 | */ 413 | explicit Document(const std::string& pPath = std::string(), 414 | const LabelParams& pLabelParams = LabelParams(), 415 | const SeparatorParams& pSeparatorParams = SeparatorParams(), 416 | const ConverterParams& pConverterParams = ConverterParams(), 417 | const LineReaderParams& pLineReaderParams = LineReaderParams()) 418 | : mPath(pPath) 419 | , mLabelParams(pLabelParams) 420 | , mSeparatorParams(pSeparatorParams) 421 | , mConverterParams(pConverterParams) 422 | , mLineReaderParams(pLineReaderParams) 423 | { 424 | if (!mPath.empty()) 425 | { 426 | ReadCsv(); 427 | } 428 | } 429 | 430 | /** 431 | * @brief Constructor 432 | * @param pStream specifies an input stream to read CSV data from. 433 | * @param pLabelParams specifies which row and column should be treated as labels. 434 | * @param pSeparatorParams specifies which field and row separators should be used. 435 | * @param pConverterParams specifies how invalid numbers (including empty strings) should be 436 | * handled. 437 | * @param pLineReaderParams specifies how special line formats should be treated. 438 | */ 439 | explicit Document(std::istream& pStream, 440 | const LabelParams& pLabelParams = LabelParams(), 441 | const SeparatorParams& pSeparatorParams = SeparatorParams(), 442 | const ConverterParams& pConverterParams = ConverterParams(), 443 | const LineReaderParams& pLineReaderParams = LineReaderParams()) 444 | : mPath() 445 | , mLabelParams(pLabelParams) 446 | , mSeparatorParams(pSeparatorParams) 447 | , mConverterParams(pConverterParams) 448 | , mLineReaderParams(pLineReaderParams) 449 | { 450 | ReadCsv(pStream); 451 | } 452 | 453 | /** 454 | * @brief Read Document data from file. 455 | * @param pPath specifies the path of an existing CSV-file to populate the Document 456 | * data with. 457 | * @param pLabelParams specifies which row and column should be treated as labels. 458 | * @param pSeparatorParams specifies which field and row separators should be used. 459 | * @param pConverterParams specifies how invalid numbers (including empty strings) should be 460 | * handled. 461 | * @param pLineReaderParams specifies how special line formats should be treated. 462 | */ 463 | void Load(const std::string& pPath, 464 | const LabelParams& pLabelParams = LabelParams(), 465 | const SeparatorParams& pSeparatorParams = SeparatorParams(), 466 | const ConverterParams& pConverterParams = ConverterParams(), 467 | const LineReaderParams& pLineReaderParams = LineReaderParams()) 468 | { 469 | mPath = pPath; 470 | mLabelParams = pLabelParams; 471 | mSeparatorParams = pSeparatorParams; 472 | mConverterParams = pConverterParams; 473 | mLineReaderParams = pLineReaderParams; 474 | ReadCsv(); 475 | } 476 | 477 | /** 478 | * @brief Read Document data from stream. 479 | * @param pStream specifies an input stream to read CSV data from. 480 | * @param pLabelParams specifies which row and column should be treated as labels. 481 | * @param pSeparatorParams specifies which field and row separators should be used. 482 | * @param pConverterParams specifies how invalid numbers (including empty strings) should be 483 | * handled. 484 | * @param pLineReaderParams specifies how special line formats should be treated. 485 | */ 486 | void Load(std::istream& pStream, 487 | const LabelParams& pLabelParams = LabelParams(), 488 | const SeparatorParams& pSeparatorParams = SeparatorParams(), 489 | const ConverterParams& pConverterParams = ConverterParams(), 490 | const LineReaderParams& pLineReaderParams = LineReaderParams()) 491 | { 492 | mPath = ""; 493 | mLabelParams = pLabelParams; 494 | mSeparatorParams = pSeparatorParams; 495 | mConverterParams = pConverterParams; 496 | mLineReaderParams = pLineReaderParams; 497 | ReadCsv(pStream); 498 | } 499 | 500 | /** 501 | * @brief Write Document data to file. 502 | * @param pPath optionally specifies the path where the CSV-file will be created 503 | * (if not specified, the original path provided when creating or 504 | * loading the Document data will be used). 505 | */ 506 | void Save(const std::string& pPath = std::string()) 507 | { 508 | if (!pPath.empty()) 509 | { 510 | mPath = pPath; 511 | } 512 | WriteCsv(); 513 | } 514 | 515 | /** 516 | * @brief Write Document data to stream. 517 | * @param pStream specifies an output stream to write the data to. 518 | */ 519 | void Save(std::ostream& pStream) 520 | { 521 | WriteCsv(pStream); 522 | } 523 | 524 | /** 525 | * @brief Clears loaded Document data. 526 | * 527 | */ 528 | void Clear() 529 | { 530 | mData.clear(); 531 | mColumnNames.clear(); 532 | mRowNames.clear(); 533 | #ifdef HAS_CODECVT 534 | mIsUtf16 = false; 535 | mIsLE = false; 536 | #endif 537 | } 538 | 539 | /** 540 | * @brief Get column index by name. 541 | * @param pColumnName column label name. 542 | * @returns zero-based column index. 543 | */ 544 | ssize_t GetColumnIdx(const std::string& pColumnName) const 545 | { 546 | if (mLabelParams.mColumnNameIdx >= 0) 547 | { 548 | if (mColumnNames.find(pColumnName) != mColumnNames.end()) 549 | { 550 | return mColumnNames.at(pColumnName) - (mLabelParams.mRowNameIdx + 1); 551 | } 552 | } 553 | return -1; 554 | } 555 | 556 | /** 557 | * @brief Get column by index. 558 | * @param pColumnIdx zero-based column index. 559 | * @returns vector of column data. 560 | */ 561 | template 562 | std::vector GetColumn(const size_t pColumnIdx) const 563 | { 564 | const ssize_t columnIdx = pColumnIdx + (mLabelParams.mRowNameIdx + 1); 565 | std::vector column; 566 | Converter converter(mConverterParams); 567 | for (auto itRow = mData.begin(); itRow != mData.end(); ++itRow) 568 | { 569 | if (std::distance(mData.begin(), itRow) > mLabelParams.mColumnNameIdx) 570 | { 571 | if (columnIdx < static_cast(itRow->size())) 572 | { 573 | T val; 574 | converter.ToVal(itRow->at(columnIdx), val); 575 | column.push_back(val); 576 | } 577 | else 578 | { 579 | const std::string errStr = "requested column index " + 580 | std::to_string(columnIdx - (mLabelParams.mRowNameIdx + 1)) + " >= " + 581 | std::to_string(itRow->size() - (mLabelParams.mRowNameIdx + 1)) + 582 | " (number of columns on row index " + 583 | std::to_string(std::distance(mData.begin(), itRow) - 584 | (mLabelParams.mColumnNameIdx + 1)) + ")"; 585 | throw std::out_of_range(errStr); 586 | } 587 | } 588 | } 589 | return column; 590 | } 591 | 592 | /** 593 | * @brief Get column by index. 594 | * @param pColumnIdx zero-based column index. 595 | * @param pToVal conversion function. 596 | * @returns vector of column data. 597 | */ 598 | template 599 | std::vector GetColumn(const size_t pColumnIdx, ConvFunc pToVal) const 600 | { 601 | const ssize_t columnIdx = pColumnIdx + (mLabelParams.mRowNameIdx + 1); 602 | std::vector column; 603 | for (auto itRow = mData.begin(); itRow != mData.end(); ++itRow) 604 | { 605 | if (std::distance(mData.begin(), itRow) > mLabelParams.mColumnNameIdx) 606 | { 607 | T val; 608 | pToVal(itRow->at(columnIdx), val); 609 | column.push_back(val); 610 | } 611 | } 612 | return column; 613 | } 614 | 615 | /** 616 | * @brief Get column by name. 617 | * @param pColumnName column label name. 618 | * @returns vector of column data. 619 | */ 620 | template 621 | std::vector GetColumn(const std::string& pColumnName) const 622 | { 623 | const ssize_t columnIdx = GetColumnIdx(pColumnName); 624 | if (columnIdx < 0) 625 | { 626 | throw std::out_of_range("column not found: " + pColumnName); 627 | } 628 | return GetColumn(columnIdx); 629 | } 630 | 631 | /** 632 | * @brief Get column by name. 633 | * @param pColumnName column label name. 634 | * @param pToVal conversion function. 635 | * @returns vector of column data. 636 | */ 637 | template 638 | std::vector GetColumn(const std::string& pColumnName, ConvFunc pToVal) const 639 | { 640 | const ssize_t columnIdx = GetColumnIdx(pColumnName); 641 | if (columnIdx < 0) 642 | { 643 | throw std::out_of_range("column not found: " + pColumnName); 644 | } 645 | return GetColumn(columnIdx, pToVal); 646 | } 647 | 648 | /** 649 | * @brief Set column by index. 650 | * @param pColumnIdx zero-based column index. 651 | * @param pColumn vector of column data. 652 | */ 653 | template 654 | void SetColumn(const size_t pColumnIdx, const std::vector& pColumn) 655 | { 656 | const size_t columnIdx = pColumnIdx + (mLabelParams.mRowNameIdx + 1); 657 | 658 | while (pColumn.size() + (mLabelParams.mColumnNameIdx + 1) > GetDataRowCount()) 659 | { 660 | std::vector row; 661 | row.resize(GetDataColumnCount()); 662 | mData.push_back(row); 663 | } 664 | 665 | if ((columnIdx + 1) > GetDataColumnCount()) 666 | { 667 | for (auto itRow = mData.begin(); itRow != mData.end(); ++itRow) 668 | { 669 | itRow->resize(columnIdx + 1 + (mLabelParams.mRowNameIdx + 1)); 670 | } 671 | } 672 | 673 | Converter converter(mConverterParams); 674 | for (auto itRow = pColumn.begin(); itRow != pColumn.end(); ++itRow) 675 | { 676 | std::string str; 677 | converter.ToStr(*itRow, str); 678 | mData.at(std::distance(pColumn.begin(), itRow) + (mLabelParams.mColumnNameIdx + 1)).at(columnIdx) = str; 679 | } 680 | } 681 | 682 | /** 683 | * @brief Set column by name. 684 | * @param pColumnName column label name. 685 | * @param pColumn vector of column data. 686 | */ 687 | template 688 | void SetColumn(const std::string& pColumnName, const std::vector& pColumn) 689 | { 690 | const ssize_t columnIdx = GetColumnIdx(pColumnName); 691 | if (columnIdx < 0) 692 | { 693 | throw std::out_of_range("column not found: " + pColumnName); 694 | } 695 | SetColumn(columnIdx, pColumn); 696 | } 697 | 698 | /** 699 | * @brief Remove column by index. 700 | * @param pColumnIdx zero-based column index. 701 | */ 702 | void RemoveColumn(const size_t pColumnIdx) 703 | { 704 | const ssize_t columnIdx = pColumnIdx + (mLabelParams.mRowNameIdx + 1); 705 | for (auto itRow = mData.begin(); itRow != mData.end(); ++itRow) 706 | { 707 | itRow->erase(itRow->begin() + columnIdx); 708 | } 709 | } 710 | 711 | /** 712 | * @brief Remove column by name. 713 | * @param pColumnName column label name. 714 | */ 715 | void RemoveColumn(const std::string& pColumnName) 716 | { 717 | ssize_t columnIdx = GetColumnIdx(pColumnName); 718 | if (columnIdx < 0) 719 | { 720 | throw std::out_of_range("column not found: " + pColumnName); 721 | } 722 | 723 | RemoveColumn(columnIdx); 724 | } 725 | 726 | /** 727 | * @brief Insert column at specified index. 728 | * @param pColumnIdx zero-based column index. 729 | * @param pColumn vector of column data (optional argument). 730 | * @param pColumnName column label name (optional argument). 731 | */ 732 | template 733 | void InsertColumn(const size_t pColumnIdx, const std::vector& pColumn = std::vector(), 734 | const std::string& pColumnName = std::string()) 735 | { 736 | const size_t columnIdx = pColumnIdx + (mLabelParams.mRowNameIdx + 1); 737 | 738 | std::vector column; 739 | if (pColumn.empty()) 740 | { 741 | column.resize(GetDataRowCount()); 742 | } 743 | else 744 | { 745 | column.resize(pColumn.size() + (mLabelParams.mColumnNameIdx + 1)); 746 | Converter converter(mConverterParams); 747 | for (auto itRow = pColumn.begin(); itRow != pColumn.end(); ++itRow) 748 | { 749 | std::string str; 750 | converter.ToStr(*itRow, str); 751 | const size_t rowIdx = std::distance(pColumn.begin(), itRow) + (mLabelParams.mColumnNameIdx + 1); 752 | column.at(rowIdx) = str; 753 | } 754 | } 755 | 756 | while (column.size() > GetDataRowCount()) 757 | { 758 | std::vector row; 759 | const size_t columnCount = std::max(static_cast(mLabelParams.mColumnNameIdx + 1), 760 | GetDataColumnCount()); 761 | row.resize(columnCount); 762 | mData.push_back(row); 763 | } 764 | 765 | for (auto itRow = mData.begin(); itRow != mData.end(); ++itRow) 766 | { 767 | const size_t rowIdx = std::distance(mData.begin(), itRow); 768 | itRow->insert(itRow->begin() + columnIdx, column.at(rowIdx)); 769 | } 770 | 771 | if (!pColumnName.empty()) 772 | { 773 | SetColumnName(pColumnIdx, pColumnName); 774 | } 775 | } 776 | 777 | /** 778 | * @brief Get number of data columns (excluding label columns). 779 | * @returns column count. 780 | */ 781 | size_t GetColumnCount() const 782 | { 783 | const ssize_t count = static_cast((mData.size() > 0) ? mData.at(0).size() : 0) - 784 | (mLabelParams.mRowNameIdx + 1); 785 | return (count >= 0) ? count : 0; 786 | } 787 | 788 | /** 789 | * @brief Get row index by name. 790 | * @param pRowName row label name. 791 | * @returns zero-based row index. 792 | */ 793 | ssize_t GetRowIdx(const std::string& pRowName) const 794 | { 795 | if (mLabelParams.mRowNameIdx >= 0) 796 | { 797 | if (mRowNames.find(pRowName) != mRowNames.end()) 798 | { 799 | return mRowNames.at(pRowName) - (mLabelParams.mColumnNameIdx + 1); 800 | } 801 | } 802 | return -1; 803 | } 804 | 805 | /** 806 | * @brief Get row by index. 807 | * @param pRowIdx zero-based row index. 808 | * @returns vector of row data. 809 | */ 810 | template 811 | std::vector GetRow(const size_t pRowIdx) const 812 | { 813 | const ssize_t rowIdx = pRowIdx + (mLabelParams.mColumnNameIdx + 1); 814 | std::vector row; 815 | Converter converter(mConverterParams); 816 | for (auto itCol = mData.at(rowIdx).begin(); itCol != mData.at(rowIdx).end(); ++itCol) 817 | { 818 | if (std::distance(mData.at(rowIdx).begin(), itCol) > mLabelParams.mRowNameIdx) 819 | { 820 | T val; 821 | converter.ToVal(*itCol, val); 822 | row.push_back(val); 823 | } 824 | } 825 | return row; 826 | } 827 | 828 | /** 829 | * @brief Get row by index. 830 | * @param pRowIdx zero-based row index. 831 | * @param pToVal conversion function. 832 | * @returns vector of row data. 833 | */ 834 | template 835 | std::vector GetRow(const size_t pRowIdx, ConvFunc pToVal) const 836 | { 837 | const ssize_t rowIdx = pRowIdx + (mLabelParams.mColumnNameIdx + 1); 838 | std::vector row; 839 | Converter converter(mConverterParams); 840 | for (auto itCol = mData.at(rowIdx).begin(); itCol != mData.at(rowIdx).end(); ++itCol) 841 | { 842 | if (std::distance(mData.at(rowIdx).begin(), itCol) > mLabelParams.mRowNameIdx) 843 | { 844 | T val; 845 | pToVal(*itCol, val); 846 | row.push_back(val); 847 | } 848 | } 849 | return row; 850 | } 851 | 852 | /** 853 | * @brief Get row by name. 854 | * @param pRowName row label name. 855 | * @returns vector of row data. 856 | */ 857 | template 858 | std::vector GetRow(const std::string& pRowName) const 859 | { 860 | ssize_t rowIdx = GetRowIdx(pRowName); 861 | if (rowIdx < 0) 862 | { 863 | throw std::out_of_range("row not found: " + pRowName); 864 | } 865 | return GetRow(rowIdx); 866 | } 867 | 868 | /** 869 | * @brief Get row by name. 870 | * @param pRowName row label name. 871 | * @param pToVal conversion function. 872 | * @returns vector of row data. 873 | */ 874 | template 875 | std::vector GetRow(const std::string& pRowName, ConvFunc pToVal) const 876 | { 877 | ssize_t rowIdx = GetRowIdx(pRowName); 878 | if (rowIdx < 0) 879 | { 880 | throw std::out_of_range("row not found: " + pRowName); 881 | } 882 | return GetRow(rowIdx, pToVal); 883 | } 884 | 885 | /** 886 | * @brief Set row by index. 887 | * @param pRowIdx zero-based row index. 888 | * @param pRow vector of row data. 889 | */ 890 | template 891 | void SetRow(const size_t pRowIdx, const std::vector& pRow) 892 | { 893 | const size_t rowIdx = pRowIdx + (mLabelParams.mColumnNameIdx + 1); 894 | 895 | while ((rowIdx + 1) > GetDataRowCount()) 896 | { 897 | std::vector row; 898 | row.resize(GetDataColumnCount()); 899 | mData.push_back(row); 900 | } 901 | 902 | if (pRow.size() > GetDataColumnCount()) 903 | { 904 | for (auto itRow = mData.begin(); itRow != mData.end(); ++itRow) 905 | { 906 | itRow->resize(pRow.size() + (mLabelParams.mRowNameIdx + 1)); 907 | } 908 | } 909 | 910 | Converter converter(mConverterParams); 911 | for (auto itCol = pRow.begin(); itCol != pRow.end(); ++itCol) 912 | { 913 | std::string str; 914 | converter.ToStr(*itCol, str); 915 | mData.at(rowIdx).at(std::distance(pRow.begin(), itCol) + (mLabelParams.mRowNameIdx + 1)) = str; 916 | } 917 | } 918 | 919 | /** 920 | * @brief Set row by name. 921 | * @param pRowName row label name. 922 | * @param pRow vector of row data. 923 | */ 924 | template 925 | void SetRow(const std::string& pRowName, const std::vector& pRow) 926 | { 927 | ssize_t rowIdx = GetRowIdx(pRowName); 928 | if (rowIdx < 0) 929 | { 930 | throw std::out_of_range("row not found: " + pRowName); 931 | } 932 | return SetRow(rowIdx, pRow); 933 | } 934 | 935 | /** 936 | * @brief Remove row by index. 937 | * @param pRowIdx zero-based row index. 938 | */ 939 | void RemoveRow(const size_t pRowIdx) 940 | { 941 | const ssize_t rowIdx = pRowIdx + (mLabelParams.mColumnNameIdx + 1); 942 | mData.erase(mData.begin() + rowIdx); 943 | } 944 | 945 | /** 946 | * @brief Remove row by name. 947 | * @param pRowName row label name. 948 | */ 949 | void RemoveRow(const std::string& pRowName) 950 | { 951 | ssize_t rowIdx = GetRowIdx(pRowName); 952 | if (rowIdx < 0) 953 | { 954 | throw std::out_of_range("row not found: " + pRowName); 955 | } 956 | 957 | RemoveRow(rowIdx); 958 | } 959 | 960 | /** 961 | * @brief Insert row at specified index. 962 | * @param pRowIdx zero-based row index. 963 | * @param pRow vector of row data (optional argument). 964 | * @param pRowName row label name (optional argument). 965 | */ 966 | template 967 | void InsertRow(const size_t pRowIdx, const std::vector& pRow = std::vector(), 968 | const std::string& pRowName = std::string()) 969 | { 970 | const size_t rowIdx = pRowIdx + (mLabelParams.mColumnNameIdx + 1); 971 | 972 | std::vector row; 973 | if (pRow.empty()) 974 | { 975 | row.resize(GetDataColumnCount()); 976 | } 977 | else 978 | { 979 | row.resize(pRow.size() + (mLabelParams.mRowNameIdx + 1)); 980 | Converter converter(mConverterParams); 981 | for (auto itCol = pRow.begin(); itCol != pRow.end(); ++itCol) 982 | { 983 | std::string str; 984 | converter.ToStr(*itCol, str); 985 | row.at(std::distance(pRow.begin(), itCol) + (mLabelParams.mRowNameIdx + 1)) = str; 986 | } 987 | } 988 | 989 | while (rowIdx > GetDataRowCount()) 990 | { 991 | std::vector tempRow; 992 | tempRow.resize(GetDataColumnCount()); 993 | mData.push_back(tempRow); 994 | } 995 | 996 | mData.insert(mData.begin() + rowIdx, row); 997 | 998 | if (!pRowName.empty()) 999 | { 1000 | SetRowName(pRowIdx, pRowName); 1001 | } 1002 | } 1003 | 1004 | /** 1005 | * @brief Get number of data rows (excluding label rows). 1006 | * @returns row count. 1007 | */ 1008 | size_t GetRowCount() const 1009 | { 1010 | const ssize_t count = static_cast(mData.size()) - (mLabelParams.mColumnNameIdx + 1); 1011 | return (count >= 0) ? count : 0; 1012 | } 1013 | 1014 | /** 1015 | * @brief Get cell by index. 1016 | * @param pColumnIdx zero-based column index. 1017 | * @param pRowIdx zero-based row index. 1018 | * @returns cell data. 1019 | */ 1020 | template 1021 | T GetCell(const size_t pColumnIdx, const size_t pRowIdx) const 1022 | { 1023 | const ssize_t columnIdx = pColumnIdx + (mLabelParams.mRowNameIdx + 1); 1024 | const ssize_t rowIdx = pRowIdx + (mLabelParams.mColumnNameIdx + 1); 1025 | 1026 | T val; 1027 | Converter converter(mConverterParams); 1028 | converter.ToVal(mData.at(rowIdx).at(columnIdx), val); 1029 | return val; 1030 | } 1031 | 1032 | /** 1033 | * @brief Get cell by index. 1034 | * @param pColumnIdx zero-based column index. 1035 | * @param pRowIdx zero-based row index. 1036 | * @param pToVal conversion function. 1037 | * @returns cell data. 1038 | */ 1039 | template 1040 | T GetCell(const size_t pColumnIdx, const size_t pRowIdx, ConvFunc pToVal) const 1041 | { 1042 | const ssize_t columnIdx = pColumnIdx + (mLabelParams.mRowNameIdx + 1); 1043 | const ssize_t rowIdx = pRowIdx + (mLabelParams.mColumnNameIdx + 1); 1044 | 1045 | T val; 1046 | pToVal(mData.at(rowIdx).at(columnIdx), val); 1047 | return val; 1048 | } 1049 | 1050 | /** 1051 | * @brief Get cell by name. 1052 | * @param pColumnName column label name. 1053 | * @param pRowName row label name. 1054 | * @returns cell data. 1055 | */ 1056 | template 1057 | T GetCell(const std::string& pColumnName, const std::string& pRowName) const 1058 | { 1059 | const ssize_t columnIdx = GetColumnIdx(pColumnName); 1060 | if (columnIdx < 0) 1061 | { 1062 | throw std::out_of_range("column not found: " + pColumnName); 1063 | } 1064 | 1065 | const ssize_t rowIdx = GetRowIdx(pRowName); 1066 | if (rowIdx < 0) 1067 | { 1068 | throw std::out_of_range("row not found: " + pRowName); 1069 | } 1070 | 1071 | return GetCell(columnIdx, rowIdx); 1072 | } 1073 | 1074 | /** 1075 | * @brief Get cell by name. 1076 | * @param pColumnName column label name. 1077 | * @param pRowName row label name. 1078 | * @param pToVal conversion function. 1079 | * @returns cell data. 1080 | */ 1081 | template 1082 | T GetCell(const std::string& pColumnName, const std::string& pRowName, ConvFunc pToVal) const 1083 | { 1084 | const ssize_t columnIdx = GetColumnIdx(pColumnName); 1085 | if (columnIdx < 0) 1086 | { 1087 | throw std::out_of_range("column not found: " + pColumnName); 1088 | } 1089 | 1090 | const ssize_t rowIdx = GetRowIdx(pRowName); 1091 | if (rowIdx < 0) 1092 | { 1093 | throw std::out_of_range("row not found: " + pRowName); 1094 | } 1095 | 1096 | return GetCell(columnIdx, rowIdx, pToVal); 1097 | } 1098 | 1099 | /** 1100 | * @brief Get cell by column name and row index. 1101 | * @param pColumnName column label name. 1102 | * @param pRowIdx zero-based row index. 1103 | * @returns cell data. 1104 | */ 1105 | template 1106 | T GetCell(const std::string& pColumnName, const size_t pRowIdx) const 1107 | { 1108 | const ssize_t columnIdx = GetColumnIdx(pColumnName); 1109 | if (columnIdx < 0) 1110 | { 1111 | throw std::out_of_range("column not found: " + pColumnName); 1112 | } 1113 | 1114 | return GetCell(columnIdx, pRowIdx); 1115 | } 1116 | 1117 | /** 1118 | * @brief Get cell by column name and row index. 1119 | * @param pColumnName column label name. 1120 | * @param pRowIdx zero-based row index. 1121 | * @param pToVal conversion function. 1122 | * @returns cell data. 1123 | */ 1124 | template 1125 | T GetCell(const std::string& pColumnName, const size_t pRowIdx, ConvFunc pToVal) const 1126 | { 1127 | const ssize_t columnIdx = GetColumnIdx(pColumnName); 1128 | if (columnIdx < 0) 1129 | { 1130 | throw std::out_of_range("column not found: " + pColumnName); 1131 | } 1132 | 1133 | return GetCell(columnIdx, pRowIdx, pToVal); 1134 | } 1135 | 1136 | /** 1137 | * @brief Get cell by column index and row name. 1138 | * @param pColumnIdx zero-based column index. 1139 | * @param pRowName row label name. 1140 | * @returns cell data. 1141 | */ 1142 | template 1143 | T GetCell(const size_t pColumnIdx, const std::string& pRowName) const 1144 | { 1145 | const ssize_t rowIdx = GetRowIdx(pRowName); 1146 | if (rowIdx < 0) 1147 | { 1148 | throw std::out_of_range("row not found: " + pRowName); 1149 | } 1150 | 1151 | return GetCell(pColumnIdx, rowIdx); 1152 | } 1153 | 1154 | /** 1155 | * @brief Get cell by column index and row name. 1156 | * @param pColumnIdx zero-based column index. 1157 | * @param pRowName row label name. 1158 | * @param pToVal conversion function. 1159 | * @returns cell data. 1160 | */ 1161 | template 1162 | T GetCell(const size_t pColumnIdx, const std::string& pRowName, ConvFunc pToVal) const 1163 | { 1164 | const ssize_t rowIdx = GetRowIdx(pRowName); 1165 | if (rowIdx < 0) 1166 | { 1167 | throw std::out_of_range("row not found: " + pRowName); 1168 | } 1169 | 1170 | return GetCell(pColumnIdx, rowIdx, pToVal); 1171 | } 1172 | 1173 | /** 1174 | * @brief Set cell by index. 1175 | * @param pRowIdx zero-based row index. 1176 | * @param pColumnIdx zero-based column index. 1177 | * @param pCell cell data. 1178 | */ 1179 | template 1180 | void SetCell(const size_t pColumnIdx, const size_t pRowIdx, const T& pCell) 1181 | { 1182 | const size_t columnIdx = pColumnIdx + (mLabelParams.mRowNameIdx + 1); 1183 | const size_t rowIdx = pRowIdx + (mLabelParams.mColumnNameIdx + 1); 1184 | 1185 | while ((rowIdx + 1) > GetDataRowCount()) 1186 | { 1187 | std::vector row; 1188 | row.resize(GetDataColumnCount()); 1189 | mData.push_back(row); 1190 | } 1191 | 1192 | if ((columnIdx + 1) > GetDataColumnCount()) 1193 | { 1194 | for (auto itRow = mData.begin(); itRow != mData.end(); ++itRow) 1195 | { 1196 | itRow->resize(columnIdx + 1); 1197 | } 1198 | } 1199 | 1200 | std::string str; 1201 | Converter converter(mConverterParams); 1202 | converter.ToStr(pCell, str); 1203 | mData.at(rowIdx).at(columnIdx) = str; 1204 | } 1205 | 1206 | /** 1207 | * @brief Set cell by name. 1208 | * @param pColumnName column label name. 1209 | * @param pRowName row label name. 1210 | * @param pCell cell data. 1211 | */ 1212 | template 1213 | void SetCell(const std::string& pColumnName, const std::string& pRowName, const T& pCell) 1214 | { 1215 | const ssize_t columnIdx = GetColumnIdx(pColumnName); 1216 | if (columnIdx < 0) 1217 | { 1218 | throw std::out_of_range("column not found: " + pColumnName); 1219 | } 1220 | 1221 | const ssize_t rowIdx = GetRowIdx(pRowName); 1222 | if (rowIdx < 0) 1223 | { 1224 | throw std::out_of_range("row not found: " + pRowName); 1225 | } 1226 | 1227 | SetCell(columnIdx, rowIdx, pCell); 1228 | } 1229 | 1230 | /** 1231 | * @brief Get column name 1232 | * @param pColumnIdx zero-based column index. 1233 | * @returns column name. 1234 | */ 1235 | std::string GetColumnName(const ssize_t pColumnIdx) 1236 | { 1237 | const ssize_t columnIdx = pColumnIdx + (mLabelParams.mRowNameIdx + 1); 1238 | if (mLabelParams.mColumnNameIdx < 0) 1239 | { 1240 | throw std::out_of_range("column name row index < 0: " + std::to_string(mLabelParams.mColumnNameIdx)); 1241 | } 1242 | 1243 | return mData.at(mLabelParams.mColumnNameIdx).at(columnIdx); 1244 | } 1245 | 1246 | /** 1247 | * @brief Set column name 1248 | * @param pColumnIdx zero-based column index. 1249 | * @param pColumnName column name. 1250 | */ 1251 | void SetColumnName(size_t pColumnIdx, const std::string& pColumnName) 1252 | { 1253 | const ssize_t columnIdx = pColumnIdx + (mLabelParams.mRowNameIdx + 1); 1254 | mColumnNames[pColumnName] = columnIdx; 1255 | if (mLabelParams.mColumnNameIdx < 0) 1256 | { 1257 | throw std::out_of_range("column name row index < 0: " + std::to_string(mLabelParams.mColumnNameIdx)); 1258 | } 1259 | 1260 | // increase table size if necessary: 1261 | const int rowIdx = mLabelParams.mColumnNameIdx; 1262 | if (rowIdx >= static_cast(mData.size())) 1263 | { 1264 | mData.resize(rowIdx + 1); 1265 | } 1266 | auto& row = mData[rowIdx]; 1267 | if (columnIdx >= static_cast(row.size())) 1268 | { 1269 | row.resize(columnIdx + 1); 1270 | } 1271 | 1272 | mData.at(mLabelParams.mColumnNameIdx).at(columnIdx) = pColumnName; 1273 | } 1274 | 1275 | /** 1276 | * @brief Get column names 1277 | * @returns vector of column names. 1278 | */ 1279 | std::vector GetColumnNames() 1280 | { 1281 | if (mLabelParams.mColumnNameIdx >= 0) 1282 | { 1283 | return std::vector(mData.at(mLabelParams.mColumnNameIdx).begin() + 1284 | (mLabelParams.mRowNameIdx + 1), 1285 | mData.at(mLabelParams.mColumnNameIdx).end()); 1286 | } 1287 | 1288 | return std::vector(); 1289 | } 1290 | 1291 | /** 1292 | * @brief Get row name 1293 | * @param pRowIdx zero-based column index. 1294 | * @returns row name. 1295 | */ 1296 | std::string GetRowName(const ssize_t pRowIdx) 1297 | { 1298 | const ssize_t rowIdx = pRowIdx + (mLabelParams.mColumnNameIdx + 1); 1299 | if (mLabelParams.mRowNameIdx < 0) 1300 | { 1301 | throw std::out_of_range("row name column index < 0: " + std::to_string(mLabelParams.mRowNameIdx)); 1302 | } 1303 | 1304 | return mData.at(rowIdx).at(mLabelParams.mRowNameIdx); 1305 | } 1306 | 1307 | /** 1308 | * @brief Set row name 1309 | * @param pRowIdx zero-based row index. 1310 | * @param pRowName row name. 1311 | */ 1312 | void SetRowName(size_t pRowIdx, const std::string& pRowName) 1313 | { 1314 | const ssize_t rowIdx = pRowIdx + (mLabelParams.mColumnNameIdx + 1); 1315 | mRowNames[pRowName] = rowIdx; 1316 | if (mLabelParams.mRowNameIdx < 0) 1317 | { 1318 | throw std::out_of_range("row name column index < 0: " + std::to_string(mLabelParams.mRowNameIdx)); 1319 | } 1320 | 1321 | // increase table size if necessary: 1322 | if (rowIdx >= static_cast(mData.size())) 1323 | { 1324 | mData.resize(rowIdx + 1); 1325 | } 1326 | auto& row = mData[rowIdx]; 1327 | if (mLabelParams.mRowNameIdx >= static_cast(row.size())) 1328 | { 1329 | row.resize(mLabelParams.mRowNameIdx + 1); 1330 | } 1331 | 1332 | mData.at(rowIdx).at(mLabelParams.mRowNameIdx) = pRowName; 1333 | } 1334 | 1335 | /** 1336 | * @brief Get row names 1337 | * @returns vector of row names. 1338 | */ 1339 | std::vector GetRowNames() 1340 | { 1341 | std::vector rownames; 1342 | if (mLabelParams.mRowNameIdx >= 0) 1343 | { 1344 | for (auto itRow = mData.begin(); itRow != mData.end(); ++itRow) 1345 | { 1346 | if (std::distance(mData.begin(), itRow) > mLabelParams.mColumnNameIdx) 1347 | { 1348 | rownames.push_back(itRow->at(mLabelParams.mRowNameIdx)); 1349 | } 1350 | } 1351 | } 1352 | return rownames; 1353 | } 1354 | 1355 | private: 1356 | void ReadCsv() 1357 | { 1358 | std::ifstream stream; 1359 | stream.exceptions(std::ifstream::failbit | std::ifstream::badbit); 1360 | stream.open(mPath, std::ios::binary); 1361 | ReadCsv(stream); 1362 | } 1363 | 1364 | void ReadCsv(std::istream& pStream) 1365 | { 1366 | Clear(); 1367 | pStream.seekg(0, std::ios::end); 1368 | std::streamsize length = pStream.tellg(); 1369 | pStream.seekg(0, std::ios::beg); 1370 | 1371 | #ifdef HAS_CODECVT 1372 | std::vector bom2b(2, '\0'); 1373 | if (length >= 2) 1374 | { 1375 | pStream.read(bom2b.data(), 2); 1376 | pStream.seekg(0, std::ios::beg); 1377 | } 1378 | 1379 | static const std::vector bomU16le = { '\xff', '\xfe' }; 1380 | static const std::vector bomU16be = { '\xfe', '\xff' }; 1381 | if ((bom2b == bomU16le) || (bom2b == bomU16be)) 1382 | { 1383 | mIsUtf16 = true; 1384 | mIsLE = (bom2b == bomU16le); 1385 | 1386 | std::wifstream wstream; 1387 | wstream.exceptions(std::wifstream::failbit | std::wifstream::badbit); 1388 | wstream.open(mPath, std::ios::binary); 1389 | if (mIsLE) 1390 | { 1391 | wstream.imbue(std::locale(wstream.getloc(), 1392 | new std::codecvt_utf16(std::consume_header | 1394 | std::little_endian)>)); 1395 | } 1396 | else 1397 | { 1398 | wstream.imbue(std::locale(wstream.getloc(), 1399 | new std::codecvt_utf16)); 1401 | } 1402 | std::wstringstream wss; 1403 | wss << wstream.rdbuf(); 1404 | std::string utf8 = ToString(wss.str()); 1405 | std::stringstream ss(utf8); 1406 | ParseCsv(ss, utf8.size()); 1407 | } 1408 | else 1409 | #endif 1410 | { 1411 | // check for UTF-8 Byte order mark and skip it when found 1412 | if (length >= 3) 1413 | { 1414 | std::vector bom3b(3, '\0'); 1415 | pStream.read(bom3b.data(), 3); 1416 | static const std::vector bomU8 = { '\xef', '\xbb', '\xbf' }; 1417 | if (bom3b != bomU8) 1418 | { 1419 | // file does not start with a UTF-8 Byte order mark 1420 | pStream.seekg(0, std::ios::beg); 1421 | } 1422 | else 1423 | { 1424 | // file did start with a UTF-8 Byte order mark, simply skip it 1425 | length -= 3; 1426 | } 1427 | } 1428 | 1429 | ParseCsv(pStream, length); 1430 | } 1431 | } 1432 | 1433 | void ParseCsv(std::istream& pStream, std::streamsize p_FileLength) 1434 | { 1435 | const std::streamsize bufLength = 64 * 1024; 1436 | std::vector buffer(bufLength); 1437 | std::vector row; 1438 | std::string cell; 1439 | bool quoted = false; 1440 | int cr = 0; 1441 | int lf = 0; 1442 | 1443 | while (p_FileLength > 0) 1444 | { 1445 | std::streamsize readLength = std::min(p_FileLength, bufLength); 1446 | pStream.read(buffer.data(), readLength); 1447 | for (int i = 0; i < readLength; ++i) 1448 | { 1449 | if (buffer[i] == '"') 1450 | { 1451 | if (cell.empty() || cell[0] == '"') 1452 | { 1453 | quoted = !quoted; 1454 | } 1455 | cell += buffer[i]; 1456 | } 1457 | else if (buffer[i] == mSeparatorParams.mSeparator) 1458 | { 1459 | if (!quoted) 1460 | { 1461 | row.push_back(Unquote(Trim(cell))); 1462 | cell.clear(); 1463 | } 1464 | else 1465 | { 1466 | cell += buffer[i]; 1467 | } 1468 | } 1469 | else if (buffer[i] == '\r') 1470 | { 1471 | if (mSeparatorParams.mQuotedLinebreaks && quoted) 1472 | { 1473 | cell += buffer[i]; 1474 | } 1475 | else 1476 | { 1477 | ++cr; 1478 | } 1479 | } 1480 | else if (buffer[i] == '\n') 1481 | { 1482 | if (mSeparatorParams.mQuotedLinebreaks && quoted) 1483 | { 1484 | cell += buffer[i]; 1485 | } 1486 | else 1487 | { 1488 | ++lf; 1489 | if (mLineReaderParams.mSkipEmptyLines && row.empty() && cell.empty()) 1490 | { 1491 | // skip empty line 1492 | } 1493 | else 1494 | { 1495 | row.push_back(Unquote(Trim(cell))); 1496 | 1497 | if (mLineReaderParams.mSkipCommentLines && !row.at(0).empty() && 1498 | (row.at(0)[0] == mLineReaderParams.mCommentPrefix)) 1499 | { 1500 | // skip comment line 1501 | } 1502 | else 1503 | { 1504 | mData.push_back(row); 1505 | } 1506 | 1507 | cell.clear(); 1508 | row.clear(); 1509 | quoted = false; 1510 | } 1511 | } 1512 | } 1513 | else 1514 | { 1515 | cell += buffer[i]; 1516 | } 1517 | } 1518 | p_FileLength -= readLength; 1519 | } 1520 | 1521 | // Handle last line without linebreak 1522 | if (!cell.empty() || !row.empty()) 1523 | { 1524 | row.push_back(Unquote(Trim(cell))); 1525 | cell.clear(); 1526 | mData.push_back(row); 1527 | row.clear(); 1528 | } 1529 | 1530 | // Assume CR/LF if at least half the linebreaks have CR 1531 | mSeparatorParams.mHasCR = (cr > (lf / 2)); 1532 | 1533 | // Set up column labels 1534 | if ((mLabelParams.mColumnNameIdx >= 0) && 1535 | (static_cast(mData.size()) > mLabelParams.mColumnNameIdx)) 1536 | { 1537 | int i = 0; 1538 | for (auto& columnName : mData[mLabelParams.mColumnNameIdx]) 1539 | { 1540 | mColumnNames[columnName] = i++; 1541 | } 1542 | } 1543 | 1544 | // Set up row labels 1545 | if ((mLabelParams.mRowNameIdx >= 0) && 1546 | (static_cast(mData.size()) > 1547 | (mLabelParams.mColumnNameIdx + 1))) 1548 | { 1549 | int i = 0; 1550 | for (auto& dataRow : mData) 1551 | { 1552 | if (static_cast(dataRow.size()) > mLabelParams.mRowNameIdx) 1553 | { 1554 | mRowNames[dataRow[mLabelParams.mRowNameIdx]] = i++; 1555 | } 1556 | } 1557 | } 1558 | } 1559 | 1560 | void WriteCsv() const 1561 | { 1562 | #ifdef HAS_CODECVT 1563 | if (mIsUtf16) 1564 | { 1565 | std::stringstream ss; 1566 | WriteCsv(ss); 1567 | std::string utf8 = ss.str(); 1568 | std::wstring wstr = ToWString(utf8); 1569 | 1570 | std::wofstream wstream; 1571 | wstream.exceptions(std::wofstream::failbit | std::wofstream::badbit); 1572 | wstream.open(mPath, std::ios::binary | std::ios::trunc); 1573 | 1574 | if (mIsLE) 1575 | { 1576 | wstream.imbue(std::locale(wstream.getloc(), 1577 | new std::codecvt_utf16(std::little_endian)>)); 1579 | } 1580 | else 1581 | { 1582 | wstream.imbue(std::locale(wstream.getloc(), 1583 | new std::codecvt_utf16)); 1584 | } 1585 | 1586 | wstream << static_cast(0xfeff); 1587 | wstream << wstr; 1588 | } 1589 | else 1590 | #endif 1591 | { 1592 | std::ofstream stream; 1593 | stream.exceptions(std::ofstream::failbit | std::ofstream::badbit); 1594 | stream.open(mPath, std::ios::binary | std::ios::trunc); 1595 | WriteCsv(stream); 1596 | } 1597 | } 1598 | 1599 | void WriteCsv(std::ostream& pStream) const 1600 | { 1601 | for (auto itr = mData.begin(); itr != mData.end(); ++itr) 1602 | { 1603 | for (auto itc = itr->begin(); itc != itr->end(); ++itc) 1604 | { 1605 | if (mSeparatorParams.mAutoQuote && 1606 | ((itc->find(mSeparatorParams.mSeparator) != std::string::npos) || 1607 | (itc->find(' ') != std::string::npos))) 1608 | { 1609 | // escape quotes in string 1610 | std::string str = *itc; 1611 | ReplaceString(str, "\"", "\"\""); 1612 | 1613 | pStream << "\"" << str << "\""; 1614 | } 1615 | else 1616 | { 1617 | pStream << *itc; 1618 | } 1619 | 1620 | if (std::distance(itc, itr->end()) > 1) 1621 | { 1622 | pStream << mSeparatorParams.mSeparator; 1623 | } 1624 | } 1625 | pStream << (mSeparatorParams.mHasCR ? "\r\n" : "\n"); 1626 | } 1627 | } 1628 | 1629 | size_t GetDataRowCount() const 1630 | { 1631 | return mData.size(); 1632 | } 1633 | 1634 | size_t GetDataColumnCount() const 1635 | { 1636 | return (mData.size() > 0) ? mData.at(0).size() : 0; 1637 | } 1638 | 1639 | std::string Trim(const std::string& pStr) 1640 | { 1641 | if (mSeparatorParams.mTrim) 1642 | { 1643 | std::string str = pStr; 1644 | 1645 | // ltrim 1646 | str.erase(str.begin(), std::find_if(str.begin(), str.end(), [](int ch) { return !isspace(ch); })); 1647 | 1648 | // rtrim 1649 | str.erase(std::find_if(str.rbegin(), str.rend(), [](int ch) { return !isspace(ch); }).base(), str.end()); 1650 | 1651 | return str; 1652 | } 1653 | else 1654 | { 1655 | return pStr; 1656 | } 1657 | } 1658 | 1659 | std::string Unquote(const std::string& pStr) 1660 | { 1661 | if (mSeparatorParams.mAutoQuote && (pStr.size() >= 2) && (pStr.front() == '"') && (pStr.back() == '"')) 1662 | { 1663 | // remove start/end quotes 1664 | std::string str = pStr.substr(1, pStr.size() - 2); 1665 | 1666 | // unescape quotes in string 1667 | ReplaceString(str, "\"\"", "\""); 1668 | 1669 | return str; 1670 | } 1671 | else 1672 | { 1673 | return pStr; 1674 | } 1675 | } 1676 | 1677 | #ifdef HAS_CODECVT 1678 | #if defined(_MSC_VER) 1679 | #pragma warning (disable: 4996) 1680 | #endif 1681 | static std::string ToString(const std::wstring& pWStr) 1682 | { 1683 | return std::wstring_convert, wchar_t>{ }.to_bytes(pWStr); 1684 | } 1685 | 1686 | static std::wstring ToWString(const std::string& pStr) 1687 | { 1688 | return std::wstring_convert, wchar_t>{ }.from_bytes(pStr); 1689 | } 1690 | #if defined(_MSC_VER) 1691 | #pragma warning (default: 4996) 1692 | #endif 1693 | #endif 1694 | 1695 | static void ReplaceString(std::string& pStr, const std::string& pSearch, const std::string& pReplace) 1696 | { 1697 | size_t pos = 0; 1698 | 1699 | while ((pos = pStr.find(pSearch, pos)) != std::string::npos) 1700 | { 1701 | pStr.replace(pos, pSearch.size(), pReplace); 1702 | pos += pReplace.size(); 1703 | } 1704 | } 1705 | 1706 | private: 1707 | std::string mPath; 1708 | LabelParams mLabelParams; 1709 | SeparatorParams mSeparatorParams; 1710 | ConverterParams mConverterParams; 1711 | LineReaderParams mLineReaderParams; 1712 | std::vector> mData; 1713 | std::map mColumnNames; 1714 | std::map mRowNames; 1715 | #ifdef HAS_CODECVT 1716 | bool mIsUtf16 = false; 1717 | bool mIsLE = false; 1718 | #endif 1719 | }; 1720 | } 1721 | --------------------------------------------------------------------------------