├── mstream ├── demo.sh ├── anom.hpp ├── Makefile ├── categhash.hpp ├── numerichash.hpp ├── recordhash.hpp ├── numerichash.cpp ├── categhash.cpp ├── recordhash.cpp ├── anom.cpp ├── main.cpp ├── lshhash.h └── argparse.hpp ├── results.py ├── results_win.py ├── pca.py ├── ae.py ├── README.md ├── ib.py ├── run.sh └── LICENSE /mstream/demo.sh: -------------------------------------------------------------------------------- 1 | make 2 | ./mstream -n demonumerical.txt -c democategorical.txt -t demotime.txt -------------------------------------------------------------------------------- /mstream/anom.hpp: -------------------------------------------------------------------------------- 1 | #ifndef anom_hpp 2 | #define anom_hpp 3 | 4 | #include 5 | #include 6 | 7 | using namespace std; 8 | 9 | vector * 10 | mstream(vector > &numeric, vector > &categ, vector ×, int num_rows, 11 | int num_buckets, double factor, int dimension1, int dimension2); 12 | 13 | #endif /* anom_hpp */ 14 | -------------------------------------------------------------------------------- /mstream/Makefile: -------------------------------------------------------------------------------- 1 | #Source: https://gist.github.com/Wenchy/64db1636845a3da0c4c7 2 | CC := g++ 3 | CFLAGS := -Wall -g -O3 -std=c++17 4 | TARGET := mstream 5 | 6 | SRCS := $(wildcard *.cpp) 7 | OBJS := $(patsubst %.cpp,%.o,$(SRCS)) 8 | 9 | all: $(TARGET) 10 | $(TARGET): $(OBJS) 11 | $(CC) -o $@ $^ 12 | %.o: %.cpp 13 | $(CC) $(CFLAGS) -c $< 14 | clean: 15 | rm -rf $(TARGET) *.o 16 | rm -rf ../*.txt 17 | 18 | .PHONY: all clean -------------------------------------------------------------------------------- /mstream/categhash.hpp: -------------------------------------------------------------------------------- 1 | #ifndef categhash_hpp 2 | #define categhash_hpp 3 | 4 | #include 5 | #include 6 | 7 | class Categhash { 8 | public: 9 | Categhash(int r, int b); 10 | void insert(long cur_int, double weight); 11 | double get_count(long cur_int); 12 
| void clear(); 13 | void lower(double factor); 14 | 15 | private: 16 | int num_rows; 17 | int num_buckets; 18 | std::vector hash_a, hash_b; 19 | std::vector > count; 20 | 21 | int hash(long cur_int, int i); 22 | }; 23 | 24 | #endif /* categhash_hpp */ 25 | -------------------------------------------------------------------------------- /mstream/numerichash.hpp: -------------------------------------------------------------------------------- 1 | #ifndef numerichash_hpp 2 | #define numerichash_hpp 3 | 4 | #include 5 | #include 6 | 7 | class Numerichash { 8 | public: 9 | Numerichash(int r, int b); 10 | 11 | void insert(double cur_node, double weight); 12 | 13 | double get_count(double cur_node); 14 | 15 | void clear(); 16 | 17 | void lower(double factor); 18 | 19 | private: 20 | int num_rows; 21 | int num_buckets; 22 | std::vector > count; 23 | 24 | int hash(double cur_node); 25 | }; 26 | 27 | #endif /* numerichash_hpp */ 28 | -------------------------------------------------------------------------------- /results.py: -------------------------------------------------------------------------------- 1 | from sklearn import metrics 2 | import pandas as pd 3 | import numpy as np 4 | import argparse 5 | 6 | parser = argparse.ArgumentParser(description="Find AUC") 7 | parser.add_argument("--label", help="labels file", required=True) 8 | parser.add_argument("--scores", help="scores file", required=True) 9 | args = parser.parse_args() 10 | 11 | data = pd.read_csv(args.label, names=["label"]) 12 | is_anom = data.label 13 | scores = pd.read_csv(args.scores, header=None, squeeze=True) 14 | fpr, tpr, _ = metrics.roc_curve(is_anom, scores) 15 | auc = metrics.roc_auc_score(is_anom, scores) 16 | count = np.sum(is_anom) 17 | preds = np.zeros_like(is_anom) 18 | indices = np.argsort(scores, axis=0)[::-1] 19 | preds[indices[:count]] = 1 20 | print( 21 | "AUC: ", auc, 22 | ) 23 | -------------------------------------------------------------------------------- /results_win.py: 
-------------------------------------------------------------------------------- 1 | from sklearn import metrics 2 | import pandas as pd 3 | import numpy as np 4 | import argparse 5 | 6 | parser = argparse.ArgumentParser(description="Find AUC") 7 | parser.add_argument("--label", help="labels file", required=True) 8 | parser.add_argument("--scores", help="scores file", required=True) 9 | parser.add_argument("--window", help="window size", type=int, required=True) 10 | args = parser.parse_args() 11 | 12 | data = pd.read_csv(args.label, names=["label"]) 13 | is_anom = data.label 14 | scores = pd.read_csv(args.scores, header=None, squeeze=True) 15 | window = args.window 16 | aucfile = open('aucwindow'+str(window)+'.txt', 'w') 17 | for low in range(0, data.shape[0], window): 18 | auc = metrics.roc_auc_score(is_anom[low:low+window], scores[low:low+window]) 19 | aucfile.write('%f\n'%(auc)) 20 | aucfile.close() 21 | -------------------------------------------------------------------------------- /mstream/recordhash.hpp: -------------------------------------------------------------------------------- 1 | #ifndef recordhash_hpp 2 | #define recordhash_hpp 3 | 4 | #include 5 | #include 6 | 7 | class Recordhash { 8 | public: 9 | Recordhash(int r, int b, int dim1, int dim2); 10 | 11 | void insert(std::vector &cur_numeric, std::vector &cur_categ, double weight); 12 | 13 | double get_count(std::vector &cur_numeric, std::vector &cur_categ); 14 | 15 | void clear(); 16 | 17 | void lower(double factor); 18 | 19 | private: 20 | int num_rows; 21 | int num_buckets; 22 | int dimension1; 23 | int dimension2; 24 | std::vector > > num_recordhash; 25 | std::vector > cat_recordhash; 26 | std::vector > count; 27 | 28 | int numerichash(const std::vector& cur_numeric, int i); 29 | 30 | int categhash(std::vector &cur_categ, int i); 31 | }; 32 | 33 | #endif /* recordhash_hpp */ 34 | -------------------------------------------------------------------------------- /pca.py: 
-------------------------------------------------------------------------------- 1 | from sklearn.decomposition import PCA 2 | import numpy as np 3 | import time 4 | import argparse 5 | 6 | np.random.seed(0) # For reproducibility 7 | np.seterr(divide="ignore", invalid="ignore") 8 | parser = argparse.ArgumentParser(description="Training for MSTREAM-PCA") 9 | parser.add_argument("--dim", type=int, help="number of dimensions", default=12) 10 | parser.add_argument("--input", help="input file", required=True) 11 | parser.add_argument("--output", help="output file", default="pca.txt") 12 | parser.add_argument( 13 | "--numRecords", type=int, help="number of records for training", default=256 14 | ) 15 | args = parser.parse_args() 16 | pca = PCA(n_components=args.dim) 17 | 18 | data = np.loadtxt(args.input, delimiter=",") 19 | mean, std = data.mean(0), data.std(0) 20 | new = (data - mean) / std 21 | new[:, std == 0] = 0 22 | t = time.time() 23 | pca.fit(new[: args.numRecords]) 24 | new = pca.transform(new) 25 | print("Time for Training PCA is: ", time.time() - t) 26 | np.savetxt(args.output, new, delimiter=",", fmt="%.2f") 27 | -------------------------------------------------------------------------------- /mstream/numerichash.cpp: -------------------------------------------------------------------------------- 1 | #define MIN(X, Y) (((X) < (Y)) ? (X) : (Y)) 2 | #define MAX(X, Y) (((X) > (Y)) ? 
(X) : (Y)) 3 | 4 | #include "numerichash.hpp" 5 | #include 6 | 7 | using namespace std; 8 | 9 | Numerichash::Numerichash(int r, int b) { 10 | num_rows = r; 11 | num_buckets = b; 12 | this->clear(); 13 | } 14 | 15 | int Numerichash::hash(double cur_node) { 16 | int bucket; 17 | cur_node = cur_node * (num_buckets - 1); 18 | bucket = floor(cur_node); 19 | if(bucket < 0) 20 | bucket = (bucket%num_buckets + num_buckets)%num_buckets; 21 | return bucket; 22 | } 23 | 24 | void Numerichash::insert(double cur_node, double weight) { 25 | int bucket; 26 | bucket = hash(cur_node); 27 | count[0][bucket] += weight; 28 | } 29 | 30 | double Numerichash::get_count(double cur_node) { 31 | int bucket; 32 | bucket = hash(cur_node); 33 | return count[0][bucket]; 34 | } 35 | 36 | void Numerichash::clear() { 37 | count = vector >(num_rows, vector(num_buckets, 0.0)); 38 | } 39 | 40 | void Numerichash::lower(double factor) { 41 | for (int i = 0; i < num_rows; i++) { 42 | for (int j = 0; j < num_buckets; j++) { 43 | count[i][j] = count[i][j] * factor; 44 | } 45 | } 46 | } -------------------------------------------------------------------------------- /mstream/categhash.cpp: -------------------------------------------------------------------------------- 1 | #define MIN(X, Y) (((X) < (Y)) ? (X) : (Y)) 2 | #define MAX(X, Y) (((X) > (Y)) ? (X) : (Y)) 3 | 4 | #include 5 | #include "categhash.hpp" 6 | #include 7 | 8 | using namespace std; 9 | 10 | Categhash::Categhash(int r, int b) { 11 | num_rows = r; 12 | num_buckets = b; 13 | hash_a.resize(num_rows); 14 | hash_b.resize(num_rows); 15 | for (int i = 0; i < num_rows; i++) { 16 | // a is in [1, p-1]; b is in [0, p-1] 17 | hash_a[i] = rand() % (num_buckets - 1) + 1; 18 | hash_b[i] = rand() % num_buckets; 19 | } 20 | this->clear(); 21 | } 22 | 23 | int Categhash::hash(long a, int i) { 24 | int resid = (a * hash_a[i] + hash_b[i]) % num_buckets; 25 | return resid + (resid < 0 ? 
num_buckets : 0); 26 | } 27 | 28 | void Categhash::insert(long cur_int, double weight) { 29 | int bucket; 30 | for (int i = 0; i < num_rows; i++) { 31 | bucket = hash(cur_int, i); 32 | count[i][bucket] += weight; 33 | } 34 | } 35 | 36 | double Categhash::get_count(long cur_int) { 37 | double min_count = numeric_limits::max(); 38 | int bucket; 39 | for (int i = 0; i < num_rows; i++) { 40 | bucket = hash(cur_int, i); 41 | min_count = MIN(min_count, count[i][bucket]); 42 | } 43 | return min_count; 44 | } 45 | 46 | void Categhash::clear() { 47 | count = vector >(num_rows, vector(num_buckets, 0.0)); 48 | } 49 | 50 | void Categhash::lower(double factor) { 51 | for (int i = 0; i < num_rows; i++) { 52 | for (int j = 0; j < num_buckets; j++) { 53 | count[i][j] = count[i][j] * factor; 54 | } 55 | } 56 | } -------------------------------------------------------------------------------- /ae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import time 5 | import argparse 6 | import os 7 | os.environ['KMP_DUPLICATE_LIB_OK']=True # For MAC MKL Optimization 8 | np.random.seed(0) 9 | torch.manual_seed(0) 10 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 11 | 12 | parser = argparse.ArgumentParser(description="Training for MSTREAM-AE") 13 | parser.add_argument( 14 | "--outputdim", type=int, help="number of output dimensions", default=12 15 | ) 16 | parser.add_argument( 17 | "--inputdim", type=int, help="number of input dimensions", required=True 18 | ) 19 | parser.add_argument("--input", help="input file", required=True) 20 | parser.add_argument("--output", help="output file", default="ae.txt") 21 | parser.add_argument( 22 | "--numRecords", type=int, help="number of records for training", default=256 23 | ) 24 | parser.add_argument("--lr", type=float, help="learning rate", required=True) 25 | parser.add_argument("--numEpochs", type=int, help="number of 
epochs", required=True) 26 | args = parser.parse_args() 27 | 28 | 29 | class AutoEncoder(nn.Module): 30 | def __init__(self): 31 | super(AutoEncoder, self).__init__() 32 | self.e1 = nn.Linear(args.inputdim, args.outputdim) 33 | self.output_layer = nn.Linear(args.outputdim, args.inputdim) 34 | 35 | def forward(self, x): 36 | x = F.relu(self.e1(x)) 37 | x = self.output_layer(x) 38 | return x 39 | 40 | 41 | ae = AutoEncoder().to(device) 42 | loss_func = nn.MSELoss() 43 | optimizer = torch.optim.Adam(ae.parameters(), lr=args.lr) 44 | data = torch.Tensor(np.loadtxt(args.input, delimiter=",")) 45 | t = time.time() 46 | mean, std = data.mean(0), data.std(0) 47 | new = (data - mean) / std 48 | new[:, std == 0] = 0 49 | for epoch in range(args.numEpochs): 50 | x = torch.autograd.Variable(new[: args.numRecords]).to(device) 51 | optimizer.zero_grad() 52 | pred = ae(x) 53 | loss = loss_func(pred, x) 54 | loss.backward() 55 | optimizer.step() 56 | recon = F.relu(ae.e1(torch.autograd.Variable(new).to(device))).detach().cpu() 57 | print("Time for Training AE is ", time.time() - t) 58 | np.savetxt(args.output, recon.numpy(), delimiter=",", fmt="%.2f") 59 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MSᴛʀᴇᴀᴍ 2 | 3 |

4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 |

23 | 24 | Implementation of 25 | 26 | - [MSᴛʀᴇᴀᴍ: Fast Anomaly Detection in Multi-Aspect Streams](https://dl.acm.org/doi/pdf/10.1145/3442381.3450023). *Siddharth Bhatia, Arjit Jain, Pan Li, Ritesh Kumar, Bryan Hooi.* The Web Conference (formerly WWW), 2021. 27 | 28 | ![](https://www.comp.nus.edu.sg/~sbhatia/assets/img/mstream.png) 29 | MSᴛʀᴇᴀᴍ detects group anomalies from a multi-aspect data stream in constant time and memory. We output an anomaly score for each record. MSᴛʀᴇᴀᴍ builds on top of [MIDAS](https://github.com/Stream-AD/MIDAS) to work in a multi-aspect setting such as event-log data, multi-attributed graphs etc. 30 | 31 | ## Demo 32 | 33 | 1. Run `bash run.sh KDD` to compile the code and run it on the KDD dataset. 34 | 2. Run `bash run.sh DOS` to compile the code and run it on the DOS dataset. 35 | 3. Run `bash run.sh UNSW` to compile the code and run it on the UNSW dataset. 36 | 4. Run `bash run.sh DDOS` to compile the code and run it on the DDOS dataset. 37 | 38 | ## MSᴛʀᴇᴀᴍ 39 | 1. Change directory to the MSᴛʀᴇᴀᴍ folder: `cd mstream` 40 | 2. Run `make` to compile the code and create the binary 41 | 3. Run `./mstream -n numericalfile -c categoricalfile -t timefile` 42 | 4. Run `make clean` to remove the binary and object files (note: this also deletes the `.txt` files in the parent directory) 43 | 44 | ## Command line options 45 | * `-h --help`: produce help message 46 | * `-n --numerical`: Numerical file name 47 | * `-c --categorical`: Categorical file name 48 | * `-t --times`: Timestamps file name (required) 49 | * `-o --output`: Output file name (default: scores.txt) 50 | * `-r --rows`: Number of Hash Functions (default: 2) 51 | * `-b --buckets`: Number of Buckets (default: 1024) 52 | * `-a --alpha`: Temporal Decay Factor, strictly between 0 and 1 (default: 0.8) 53 | 54 | 55 | ## Input file format for MSᴛʀᴇᴀᴍ 56 | MSᴛʀᴇᴀᴍ expects the input multi-aspect record stream to be stored in three files: 57 | 1. `Numerical file`: contains `,` separated Numerical Features. 58 | 2. `Categorical file`: contains `,` separated Categorical Features. 59 | 3. `Time File`: contains integer Timestamps.
60 | 61 | Both Numerical and Categorical files contain corresponding features of the multi-aspect record. Records should be sorted in non-decreasing order of their time stamps and the column delimiter should be `,` 62 | 63 | 64 | ## Datasets 65 | 1. [KDDCUP99](http://kdd.ics.uci.edu/databases/kddcup99/kddcup99.html) 66 | 2. [CICIDS-DoS](https://www.unb.ca/cic/datasets/ids-2018.html) 67 | 2. [UNSW-NB 15](https://www.unsw.adfa.edu.au/unsw-canberra-cyber/cybersecurity/ADFA-NB15-Datasets/) 68 | 3. [CICIDS-DDoS](https://www.unb.ca/cic/datasets/ids-2018.html) 69 | 70 | 71 | ## Citation 72 | 73 | If you use this code for your research, please consider citing our WWW paper. 74 | 75 | ```bibtex 76 | @inproceedings{bhatia2021mstream, 77 | title={Fast Anomaly Detection in Multi-Aspect Streams}, 78 | author={Siddharth Bhatia and Arjit Jain and Pan Li and Ritesh Kumar and Bryan Hooi}, 79 | booktitle={The Web Conference (WWW)}, 80 | year={2021} 81 | } 82 | 83 | ``` 84 | -------------------------------------------------------------------------------- /mstream/recordhash.cpp: -------------------------------------------------------------------------------- 1 | #define MIN(X, Y) (((X) < (Y)) ? (X) : (Y)) 2 | #define MAX(X, Y) (((X) > (Y)) ? 
(X) : (Y)) 3 | 4 | #include 5 | #include "recordhash.hpp" 6 | #include 7 | #include 8 | #include "lshhash.h" 9 | #include 10 | 11 | using namespace std; 12 | 13 | Recordhash::Recordhash(int r, int b, int dim1, int dim2) { 14 | num_rows = r; 15 | num_buckets = b; 16 | dimension1 = dim1; 17 | dimension2 = dim2; 18 | MTRand mtr; 19 | int log_bucket; 20 | 21 | num_recordhash.resize(num_rows); 22 | for (int i = 0; i < num_rows; i++) { 23 | log_bucket = ceil(log2(num_buckets)); 24 | num_recordhash[i].resize(log_bucket); 25 | for (int j = 0; j < log_bucket; j++) { 26 | num_recordhash[i][j].resize(dimension1); 27 | for (int k = 0; k < dimension1; k++) { 28 | num_recordhash[i][j][k] = mtr.randNorm(); 29 | } 30 | } 31 | } 32 | 33 | cat_recordhash.resize(num_rows); 34 | for (int i = 0; i < num_rows; i++) { 35 | cat_recordhash[i].resize(dimension2); 36 | for (int k = 0; k < dimension2 - 1; k++) { 37 | cat_recordhash[i][k] = (rand() % (num_buckets - 1) + 1); 38 | } 39 | if (dimension2) 40 | cat_recordhash[i][dimension2 - 1] = (rand() % num_buckets); 41 | } 42 | 43 | this->clear(); 44 | } 45 | 46 | int Recordhash::numerichash(const vector &cur_numeric, int i) { 47 | 48 | double sum = 0.0; 49 | int bitcounter = 0; 50 | int log_bucket = ceil(log2(num_buckets)); 51 | bitset<30> b; 52 | 53 | for (int iter = 0; iter < log_bucket; iter++) { 54 | sum = 0; 55 | for (int k = 0; k < dimension1; k++) { 56 | sum = sum + num_recordhash[i][iter][k] * cur_numeric[k]; 57 | } 58 | 59 | if (sum < 0) 60 | b.set(bitcounter, 0); 61 | else 62 | b.set(bitcounter, 1); 63 | bitcounter++; 64 | } 65 | 66 | return b.to_ulong(); 67 | } 68 | 69 | int Recordhash::categhash(vector &cur_categ, int i) { 70 | 71 | int counter = 0; 72 | int resid = 0; 73 | 74 | for (int k = 0; k < dimension2; k++) { 75 | resid = (resid + cat_recordhash[i][counter] * cur_categ[counter]) % num_buckets; 76 | counter++; 77 | } 78 | return resid + (resid < 0 ? 
num_buckets : 0); 79 | } 80 | 81 | void Recordhash::insert(vector &cur_numeric, vector &cur_categ, double weight) { 82 | int bucket1, bucket2, bucket; 83 | 84 | for (int i = 0; i < num_rows; i++) { 85 | bucket1 = numerichash(cur_numeric, i); 86 | bucket2 = categhash(cur_categ, i); 87 | bucket = (bucket1 + bucket2) % num_buckets; 88 | count[i][bucket] += weight; 89 | } 90 | } 91 | 92 | double Recordhash::get_count(vector &cur_numeric, vector &cur_categ) { 93 | double min_count = numeric_limits::max(); 94 | int bucket1, bucket2, bucket; 95 | for (int i = 0; i < num_rows; i++) { 96 | bucket1 = numerichash(cur_numeric, i); 97 | bucket2 = categhash(cur_categ, i); 98 | bucket = (bucket1 + bucket2) % num_buckets; 99 | min_count = MIN(min_count, count[i][bucket]); 100 | } 101 | return min_count; 102 | } 103 | 104 | void Recordhash::clear() { 105 | count = vector >(num_rows, vector(num_buckets, 0.0)); 106 | } 107 | 108 | void Recordhash::lower(double factor) { 109 | for (int i = 0; i < num_rows; i++) { 110 | for (int j = 0; j < num_buckets; j++) { 111 | count[i][j] = count[i][j] * factor; 112 | } 113 | } 114 | } -------------------------------------------------------------------------------- /ib.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import time 6 | from torch.autograd import Variable 7 | import pandas as pd 8 | import argparse 9 | import os 10 | os.environ['KMP_DUPLICATE_LIB_OK']=True # For MAC MKL Optimization 11 | np.random.seed(0) 12 | torch.manual_seed(0) 13 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 14 | """Modified from https://github.com/burklight/nonlinear-IB-PyTorch""" 15 | 16 | 17 | def compute_distances(x): 18 | x_norm = (x ** 2).sum(1).view(-1, 1) 19 | x_t = torch.transpose(x, 0, 1) 20 | x_t_norm = x_norm.view(1, -1) 21 | dist = x_norm + x_t_norm - 2.0 * torch.mm(x, x_t) 22 | dist = torch.clamp(dist, 
0, np.inf) 23 | 24 | return dist 25 | 26 | 27 | def KDE_IXT_estimation(logvar_t, mean_t): 28 | n_batch, d = mean_t.shape 29 | var = torch.exp(logvar_t) + 1e-10 # to avoid 0's in the log 30 | normalization_constant = math.log(n_batch) 31 | dist = compute_distances(mean_t) 32 | distance_contribution = -torch.mean(torch.logsumexp(input=-0.5 * dist / var, dim=1)) 33 | I_XT = normalization_constant + distance_contribution 34 | 35 | return I_XT 36 | 37 | 38 | def get_IXT(mean_t, logvar_t): 39 | IXT = KDE_IXT_estimation(logvar_t, mean_t) # in natts 40 | IXT = IXT / np.log(2) # in bits 41 | return IXT 42 | 43 | 44 | def get_ITY(logits_y, y): 45 | HY_given_T = ce(logits_y, y) 46 | ITY = (np.log(2) - HY_given_T) / np.log(2) # in bits 47 | return ITY 48 | 49 | 50 | def get_loss(IXT_upper, ITY_lower): 51 | loss = -1.0 * (ITY_lower - beta * IXT_upper) 52 | return loss 53 | 54 | 55 | parser = argparse.ArgumentParser(description="Training for MSTREAM-IB") 56 | parser.add_argument( 57 | "--outputdim", type=int, help="number of output dimensions", default=12 58 | ) 59 | parser.add_argument( 60 | "--inputdim", type=int, help="number of input dimensions", required=True 61 | ) 62 | parser.add_argument("--input", help="input file", required=True) 63 | parser.add_argument("--label", help="labels file", required=True) 64 | parser.add_argument("--output", help="output file", default="ib.txt") 65 | parser.add_argument( 66 | "--numRecords", type=int, help="number of records for training", default=256 67 | ) 68 | parser.add_argument("--beta", type=float, help="beta value of IB", default=0.5) 69 | parser.add_argument("--lr", type=float, help="learning rate", required=True) 70 | parser.add_argument("--numEpochs", type=int, help="number of epochs", required=True) 71 | args = parser.parse_args() 72 | beta = args.beta 73 | 74 | 75 | class AutoEncoder(nn.Module): 76 | def __init__(self): 77 | super(AutoEncoder, self).__init__() 78 | self.e1 = nn.Linear(args.inputdim, args.outputdim) 79 | 
self.output_layer = nn.Linear(args.outputdim, 1) 80 | 81 | def forward(self, x): 82 | mu = self.e1(x) 83 | intermed = mu + torch.randn_like(mu) * 1 84 | x = self.output_layer(intermed) 85 | return x, mu 86 | 87 | 88 | ce = torch.nn.BCEWithLogitsLoss() 89 | data = torch.Tensor(np.loadtxt(args.input, delimiter=",")) 90 | label = pd.read_csv(args.label, names=["label"])[: args.numRecords] 91 | t = time.time() 92 | mean, std = data.mean(0), data.std(0) 93 | new = (data - mean) / std 94 | new[:, std == 0] = 0 95 | label = torch.Tensor(np.array(label.label).reshape(-1, 1)) 96 | 97 | ae = AutoEncoder().to(device) 98 | optimizer = torch.optim.Adam(ae.parameters(), lr=args.lr) 99 | 100 | for epoch in range(args.numEpochs): 101 | train_x = Variable(new[: args.numRecords]).to(device) 102 | train_y = Variable(label).to(device) 103 | optimizer.zero_grad() 104 | train_logits_y, train_mean_t = ae(train_x) 105 | train_ITY = get_ITY(train_logits_y, train_y) 106 | logvar_t = torch.Tensor([0]).to(device) 107 | train_IXT = get_IXT(train_mean_t, logvar_t) 108 | loss = get_loss(train_IXT, train_ITY) 109 | loss.backward() 110 | optimizer.step() 111 | 112 | recon = ae.e1(torch.autograd.Variable(new).to(device)).detach().cpu() 113 | print("Time for Training IB is ", time.time() - t) 114 | np.savetxt(args.output, recon.numpy(), delimiter=",", fmt="%.2f") 115 | -------------------------------------------------------------------------------- /mstream/anom.cpp: -------------------------------------------------------------------------------- 1 | #define MIN(X, Y) (((X) < (Y)) ? (X) : (Y)) 2 | #define MAX(X, Y) (((X) > (Y)) ? 
(X) : (Y)) 3 | 4 | #include 5 | #include 6 | #include 7 | #include "anom.hpp" 8 | #include "numerichash.hpp" 9 | #include "recordhash.hpp" 10 | #include "categhash.hpp" 11 | 12 | double counts_to_anom(double tot, double cur, int cur_t) { 13 | double cur_mean = tot / cur_t; 14 | double sqerr = pow(MAX(0, cur - cur_mean), 2); 15 | return sqerr / cur_mean + sqerr / (cur_mean * MAX(1, cur_t - 1)); 16 | } 17 | 18 | vector *mstream(vector > &numeric, vector > &categ, vector ×, int num_rows, 19 | int num_buckets, double factor, int dimension1, int dimension2) { 20 | 21 | int length = times.size(), cur_t = 1; 22 | 23 | Recordhash cur_count(num_rows, num_buckets, dimension1, dimension2); 24 | Recordhash total_count(num_rows, num_buckets, dimension1, dimension2); 25 | 26 | auto *anom_score = new vector(length); 27 | 28 | vector numeric_score(dimension1, Numerichash(num_rows, num_buckets)); 29 | vector numeric_total(dimension1, Numerichash(num_rows, num_buckets)); 30 | vector categ_score(dimension2, Categhash(num_rows, num_buckets)); 31 | vector categ_total(dimension2, Categhash(num_rows, num_buckets)); 32 | 33 | vector cur_numeric(0); 34 | vector max_numeric(0); 35 | vector min_numeric(0); 36 | if (dimension1) { 37 | max_numeric.resize(dimension1, numeric_limits::min()); 38 | min_numeric.resize(dimension1, numeric_limits::max()); 39 | } 40 | vector cur_categ(0); 41 | 42 | for (int i = 0; i < length; i++) { 43 | if (i == 0 || times[i] > cur_t) { 44 | cur_count.lower(factor); 45 | for (int j = 0; j < dimension1; j++) { 46 | numeric_score[j].lower(factor); 47 | } 48 | for (int j = 0; j < dimension2; j++) { 49 | categ_score[j].lower(factor); 50 | } 51 | cur_t = times[i]; 52 | } 53 | 54 | if (dimension1) 55 | cur_numeric.swap(numeric[i]); 56 | if (dimension2) 57 | cur_categ.swap(categ[i]); 58 | 59 | double sum = 0.0, t, cur_score; 60 | for (int node_iter = 0; node_iter < dimension1; node_iter++) { 61 | cur_numeric[node_iter] = log10(1 + cur_numeric[node_iter]); 62 | if (!i) { 63 
| max_numeric[node_iter] = cur_numeric[node_iter]; 64 | min_numeric[node_iter] = cur_numeric[node_iter]; 65 | cur_numeric[node_iter] = 0; 66 | } else { 67 | min_numeric[node_iter] = MIN(min_numeric[node_iter], cur_numeric[node_iter]); 68 | max_numeric[node_iter] = MAX(max_numeric[node_iter], cur_numeric[node_iter]); 69 | if (max_numeric[node_iter] == min_numeric[node_iter]) cur_numeric[node_iter] = 0; 70 | else cur_numeric[node_iter] = (cur_numeric[node_iter] - min_numeric[node_iter]) / 71 | (max_numeric[node_iter] - min_numeric[node_iter]); 72 | } 73 | numeric_score[node_iter].insert(cur_numeric[node_iter], 1); 74 | numeric_total[node_iter].insert(cur_numeric[node_iter], 1); 75 | t = counts_to_anom(numeric_total[node_iter].get_count(cur_numeric[node_iter]), 76 | numeric_score[node_iter].get_count(cur_numeric[node_iter]), cur_t); 77 | sum = sum+t; 78 | } 79 | cur_count.insert(cur_numeric, cur_categ, 1); 80 | total_count.insert(cur_numeric, cur_categ, 1); 81 | 82 | for (int node_iter = 0; node_iter < dimension2; node_iter++) { 83 | categ_score[node_iter].insert(cur_categ[node_iter], 1); 84 | categ_total[node_iter].insert(cur_categ[node_iter], 1); 85 | t = counts_to_anom(categ_total[node_iter].get_count(cur_categ[node_iter]), 86 | categ_score[node_iter].get_count(cur_categ[node_iter]), cur_t); 87 | sum = sum+t; 88 | } 89 | 90 | cur_score = counts_to_anom(total_count.get_count(cur_numeric, cur_categ), 91 | cur_count.get_count(cur_numeric, cur_categ), cur_t); 92 | sum = sum + cur_score; 93 | (*anom_score)[i] = log(1 + sum); 94 | 95 | } 96 | return anom_score; 97 | } 98 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd mstream/ || exit 1 3 | make clean 4 | make || exit 1 5 | cd ../ 6 | 7 | if [ $1 == "KDD" ]; then 8 | echo "KDD" 9 | echo "Vanilla MSTREAM" 10 | mstream/mstream -t 'data/kddtime.txt' -n 'data/kddnumeric.txt' -c 'data/kddcateg.txt' -o
'score.txt' -a 0.8 11 | python3 results.py --label 'data/kdd_label.txt' --scores 'score.txt' 12 | 13 | echo "MSTREAM-PCA" 14 | python3 pca.py --input 'data/kddnumeric.txt' 15 | mstream/mstream -t 'data/kddtime.txt' -n 'pca.txt' -c 'data/kddcateg.txt' -o 'pcascore.txt' -a 0.8 16 | python3 results.py --label 'data/kdd_label.txt' --scores 'pcascore.txt' 17 | 18 | echo "MSTREAM-IB" 19 | python3 ib.py --input 'data/kddnumeric.txt' --inputdim 34 --label 'data/kdd_label.txt' --lr 0.01 --numEpochs 100 20 | mstream/mstream -t 'data/kddtime.txt' -n 'ib.txt' -c 'data/kddcateg.txt' -o 'ibscore.txt' -a 0.8 21 | python3 results.py --label 'data/kdd_label.txt' --scores 'ibscore.txt' 22 | 23 | echo "MSTREAM-AE" 24 | python3 ae.py --input 'data/kddnumeric.txt' --inputdim 34 --lr 0.01 --numEpochs 100 25 | mstream/mstream -t 'data/kddtime.txt' -n 'ae.txt' -c 'data/kddcateg.txt' -o 'aescore.txt' -a 0.8 26 | python3 results.py --label 'data/kdd_label.txt' --scores 'aescore.txt' 27 | fi 28 | 29 | if [ $1 == "UNSW" ]; then 30 | echo "UNSW" 31 | echo "Vanilla MSTREAM" 32 | mstream/mstream -t 'data/unswtime.txt' -n 'data/unswnumeric.txt' -c 'data/unswcateg.txt' -o 'score.txt' -a 0.4 33 | python3 results.py --label 'data/unsw_label.txt' --scores 'score.txt' 34 | 35 | echo "MSTREAM-PCA" 36 | python3 pca.py --input 'data/unswnumeric.txt' 37 | mstream/mstream -t 'data/unswtime.txt' -n 'pca.txt' -c 'data/unswcateg.txt' -o 'pcascore.txt' -a 0.4 38 | python3 results.py --label 'data/unsw_label.txt' --scores 'pcascore.txt' 39 | 40 | echo "MSTREAM-IB" 41 | python3 ib.py --input 'data/unswnumeric.txt' --inputdim 39 --label 'data/unsw_label.txt' --lr 0.01 --numEpochs 100 42 | mstream/mstream -t 'data/unswtime.txt' -n 'ib.txt' -c 'data/unswcateg.txt' -o 'ibscore.txt' -a 0.4 43 | python3 results.py --label 'data/unsw_label.txt' --scores 'ibscore.txt' 44 | 45 | echo "MSTREAM-AE" 46 | python3 ae.py --input 'data/unswnumeric.txt' --inputdim 39 --lr 0.01 --numEpochs 100 47 | mstream/mstream -t 
'data/unswtime.txt' -n 'ae.txt' -c 'data/unswcateg.txt' -o 'aescore.txt' -a 0.4 48 | python3 results.py --label 'data/unsw_label.txt' --scores 'aescore.txt' 49 | fi 50 | 51 | if [ $1 == "DOS" ]; then 52 | echo "DOS" 53 | echo "Vanilla MSTREAM" 54 | mstream/mstream -t 'data/dostime.txt' -n 'data/dosnumeric.txt' -c 'data/doscateg.txt' -o 'score.txt' -a 0.95 55 | python3 results.py --label 'data/dos_label.txt' --scores 'score.txt' 56 | 57 | echo "MSTREAM-PCA" 58 | python3 pca.py --input 'data/dosnumeric.txt' 59 | mstream/mstream -t 'data/dostime.txt' -n 'pca.txt' -c 'data/doscateg.txt' -o 'pcascore.txt' -a 0.95 60 | python3 results.py --label 'data/dos_label.txt' --scores 'pcascore.txt' 61 | 62 | echo "MSTREAM-IB" 63 | python3 ib.py --input 'data/dosnumeric.txt' --inputdim 76 --label 'data/dos_label.txt' --lr 0.01 --numEpochs 200 64 | mstream/mstream -t 'data/dostime.txt' -n 'ib.txt' -c 'data/doscateg.txt' -o 'ibscore.txt' -a 0.95 65 | python3 results.py --label 'data/dos_label.txt' --scores 'ibscore.txt' 66 | 67 | echo "MSTREAM-AE" 68 | python3 ae.py --input 'data/dosnumeric.txt' --inputdim 76 --lr 0.0001 --numEpochs 1000 69 | mstream/mstream -t 'data/dostime.txt' -n 'ae.txt' -c 'data/doscateg.txt' -o 'aescore.txt' -a 0.95 70 | python3 results.py --label 'data/dos_label.txt' --scores 'aescore.txt' 71 | fi 72 | 73 | if [ $1 == "DDOS" ]; then 74 | echo "DDOS" 75 | echo "Vanilla MSTREAM" 76 | mstream/mstream -t 'data/ddostime.txt' -n 'data/ddosnumeric.txt' -c 'data/ddoscateg.txt' -o 'score.txt' -a 0.95 77 | python3 results.py --label 'data/ddos_label.txt' --scores 'score.txt' 78 | 79 | echo "MSTREAM-PCA" 80 | python3 pca.py --input 'data/ddosnumeric.txt' 81 | mstream/mstream -t 'data/ddostime.txt' -n 'pca.txt' -c 'data/ddoscateg.txt' -o 'pcascore.txt' -a 0.95 82 | python3 results.py --label 'data/ddos_label.txt' --scores 'pcascore.txt' 83 | 84 | echo "MSTREAM-IB" 85 | python3 ib.py --input 'data/ddosnumeric.txt' --inputdim 76 --label 'data/ddos_label.txt' --lr 0.001
--numEpochs 200 86 | mstream/mstream -t 'data/ddostime.txt' -n 'ib.txt' -c 'data/ddoscateg.txt' -o 'ibscore.txt' -a 0.95 87 | python3 results.py --label 'data/ddos_label.txt' --scores 'ibscore.txt' 88 | 89 | echo "MSTREAM-AE" 90 | python3 ae.py --input 'data/ddosnumeric.txt' --inputdim 76 --lr 0.001 --numEpochs 100 91 | mstream/mstream -t 'data/ddostime.txt' -n 'ae.txt' -c 'data/ddoscateg.txt' -o 'aescore.txt' -a 0.95 92 | python3 results.py --label 'data/ddos_label.txt' --scores 'aescore.txt' 93 | fi 94 | -------------------------------------------------------------------------------- /mstream/main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include "anom.hpp" 6 | #include "argparse.hpp" 7 | 8 | using namespace std; 9 | 10 | void load_data(vector > &numeric, vector > &categorical, vector ×, 11 | const string &numeric_filename, const string &categ_filename, const string &time_filename) { 12 | int l = 0; 13 | string s, line; 14 | if (!numeric_filename.empty()) { 15 | ifstream numericFile(numeric_filename); 16 | while (numericFile) { 17 | l++; 18 | if (!getline(numericFile, s)) 19 | break; 20 | if (s[0] != '#') { 21 | istringstream ss(s); 22 | vector record; 23 | while (ss) { 24 | 25 | if (!getline(ss, line, ',')) { 26 | break; 27 | } 28 | try { 29 | record.push_back(stod(line)); 30 | } 31 | catch (const std::invalid_argument &e) { 32 | cout << "NaN found in file " << numeric_filename << " line " << l 33 | << endl; 34 | e.what(); 35 | } 36 | } 37 | numeric.push_back(record); 38 | } 39 | } 40 | if (!numericFile.eof()) { 41 | cerr << "Could not read file " << numeric_filename << "\n"; 42 | __throw_invalid_argument("File not found."); 43 | } 44 | } 45 | if (!categ_filename.empty()) { 46 | ifstream categFile(categ_filename); 47 | l = 0; 48 | while (categFile) { 49 | l++; 50 | if (!getline(categFile, s)) 51 | break; 52 | if (s[0] != '#') { 53 | istringstream ss(s); 54 | 
vector record; 55 | while (ss) { 56 | if (!getline(ss, line, ',')) 57 | break; 58 | try { 59 | record.push_back(stol(line)); 60 | } 61 | catch (const std::invalid_argument &e) { 62 | cout << "NaN found in file " << categ_filename << " line " << l 63 | << endl; 64 | e.what(); 65 | } 66 | } 67 | categorical.push_back(record); 68 | } 69 | } 70 | if (!categFile.eof()) { 71 | cerr << "Could not read file " << categ_filename << "\n"; 72 | __throw_invalid_argument("File not found."); 73 | } 74 | } 75 | ifstream timeFile(time_filename); 76 | l = 0; 77 | while (timeFile) { 78 | l++; 79 | if (!getline(timeFile, s)) 80 | break; 81 | if (s[0] != '#') { 82 | istringstream ss(s); 83 | while (ss) { 84 | if (!getline(ss, line, ',')) 85 | break; 86 | try { 87 | times.push_back(stoi(line)); 88 | } 89 | catch (const std::invalid_argument &e) { 90 | cout << "NaN found in file " << time_filename << " line " << l 91 | << endl; 92 | e.what(); 93 | } 94 | } 95 | } 96 | } 97 | if (!timeFile.eof()) { 98 | cerr << "Could not read file " << time_filename << "\n"; 99 | __throw_invalid_argument("File not found."); 100 | } 101 | } 102 | 103 | int main(int argc, const char *argv[]) { 104 | 105 | argparse::ArgumentParser program("mstream"); 106 | program.add_argument("-n", "--numerical") 107 | .default_value(string("")) 108 | .help("Numerical Data File"); 109 | program.add_argument("-c", "--categorical") 110 | .default_value(string("")) 111 | .help("Categorical Data File"); 112 | program.add_argument("-t", "--times") 113 | .required() 114 | .help("Timestamp Data File"); 115 | program.add_argument("-r", "--rows") 116 | .default_value(2) 117 | .action([](const std::string &value) { return std::stoi(value); }) 118 | .help("Number of rows. Default is 2"); 119 | program.add_argument("-b", "--buckets") 120 | .default_value(1024) 121 | .action([](const std::string &value) { return std::stoi(value); }) 122 | .help("Number of buckets. 
Default is 1024"); 123 | program.add_argument("-a", "--alpha") 124 | .default_value(0.8) 125 | .action([](const std::string &value) { return std::stod(value); }) 126 | .help("Alpha: Temporal Decay Factor. Default is 0.8"); 127 | program.add_argument("-o", "--output").default_value(string("scores.txt")).help( 128 | "Output File. Default is scores.txt"); 129 | try { 130 | program.parse_args(argc, argv); 131 | } 132 | catch (const std::runtime_error &err) { 133 | std::cout << err.what() << std::endl; 134 | program.print_help(); 135 | exit(1); 136 | } 137 | 138 | string numeric_filename = program.get("-n"); 139 | string categ_filename = program.get("-c"); 140 | string times_filename = program.get("-t"); 141 | string output_filename = program.get("-o"); 142 | int rows = program.get("-r"); 143 | int buckets = program.get("-b"); 144 | auto alpha = program.get("-a"); 145 | 146 | if (rows < 1) { 147 | cerr << "Number of numerichash functions should be positive.\n"; 148 | exit(1); 149 | } 150 | 151 | if (buckets < 2) { 152 | cerr << "Number of buckets should be at least 2\n"; 153 | exit(1); 154 | } 155 | 156 | if (alpha <= 0 || alpha >= 1) { 157 | cerr << "Alpha: Temporal Decay Factor must be between 0 and 1.\n"; 158 | exit(1); 159 | } 160 | 161 | if (numeric_filename.empty() && categ_filename.empty()) { 162 | cerr << "Please give at least one of numeric or categorical data file\n"; 163 | exit(1); 164 | } 165 | 166 | vector > numeric; 167 | vector > categ; 168 | vector times; 169 | int dimension1 = 0, dimension2 = 0; 170 | load_data(numeric, categ, times, numeric_filename, categ_filename, times_filename); 171 | if (!numeric.empty()) 172 | dimension1 = numeric[0].size(); 173 | if (!categ.empty()) 174 | dimension2 = categ[0].size(); 175 | if ((dimension1 && times.size() != numeric.size()) || (dimension2 && times.size() != categ.size())) { 176 | cerr << "Number of records in the files do not match.\n"; 177 | exit(1); 178 | } 179 | 180 | cout << "Finished loading" << endl; 181 | 
182 | clock_t start_time2 = clock(); 183 | vector *scores2 = mstream(numeric, categ, times, rows, buckets, alpha, dimension1, dimension2); 184 | cout << "@ " << ((double) (clock() - start_time2)) / CLOCKS_PER_SEC << endl; 185 | 186 | FILE *output_file = fopen(output_filename.c_str(), "w"); 187 | for (double i : *scores2) { 188 | fprintf(output_file, "%f\n", i); 189 | } 190 | return 0; 191 | } 192 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /mstream/lshhash.h: -------------------------------------------------------------------------------- 1 | // MersenneTwister.h 2 | // Mersenne Twister random number generator -- a C++ class MTRand 3 | // Based on code by Makoto Matsumoto, Takuji Nishimura, and Shawn Cokus 4 | // Richard J. Wagner v1.1 28 September 2009 wagnerr@umich.edu 5 | 6 | // The Mersenne Twister is an algorithm for generating random numbers. It 7 | // was designed with consideration of the flaws in various other generators. 8 | // The period, 2^19937-1, and the order of equidistribution, 623 dimensions, 9 | // are far greater. The generator is also fast; it avoids multiplication and 10 | // division, and it benefits from caches and pipelines. For more information 11 | // see the inventors' web page at 12 | // http://www.math.sci.hiroshima-u.ac.jp/~m-mat/MT/emt.html 13 | 14 | // Reference 15 | // M. Matsumoto and T. 
Nishimura, "Mersenne Twister: A 623-Dimensionally 16 | // Equidistributed Uniform Pseudo-Random Number Generator", ACM Transactions on 17 | // Modeling and Computer Simulation, Vol. 8, No. 1, January 1998, pp 3-30. 18 | 19 | // Copyright (C) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura, 20 | // Copyright (C) 2000 - 2009, Richard J. Wagner 21 | // All rights reserved. 22 | // 23 | // Redistribution and use in source and binary forms, with or without 24 | // modification, are permitted provided that the following conditions 25 | // are met: 26 | // 27 | // 1. Redistributions of source code must retain the above copyright 28 | // notice, this list of conditions and the following disclaimer. 29 | // 30 | // 2. Redistributions in binary form must reproduce the above copyright 31 | // notice, this list of conditions and the following disclaimer in the 32 | // documentation and/or other materials provided with the distribution. 33 | // 34 | // 3. The names of its contributors may not be used to endorse or promote 35 | // products derived from this software without specific prior written 36 | // permission. 37 | // 38 | // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 39 | // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 40 | // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 41 | // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 42 | // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 43 | // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 44 | // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 45 | // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 46 | // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 47 | // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 48 | // POSSIBILITY OF SUCH DAMAGE. 
49 | 50 | // The original code included the following notice: 51 | // 52 | // When you use this, send an email to: m-mat@math.sci.hiroshima-u.ac.jp 53 | // with an appropriate reference to your work. 54 | // 55 | // It would be nice to CC: wagnerr@umich.edu and Cokus@math.washington.edu 56 | // when you write. 57 | 58 | #ifndef lshhash_hpp 59 | #define lshhash_hpp 60 | 61 | // Not thread safe (unless auto-initialization is avoided and each thread has 62 | // its own MTRand object) 63 | 64 | #include 65 | #include 66 | #include 67 | #include 68 | #include 69 | 70 | class MTRand { 71 | // Data 72 | public: 73 | typedef unsigned long uint32; // unsigned integer type, at least 32 bits 74 | 75 | enum { N = 624 }; // length of state vector 76 | enum { SAVE = N + 1 }; // length of array for save() 77 | 78 | protected: 79 | enum { M = 397 }; // period parameter 80 | 81 | uint32 state[N]; // internal state 82 | uint32 *pNext; // next value to get from state 83 | int left; // number of values left before reload needed 84 | 85 | // Methods 86 | public: 87 | MTRand( const uint32 oneSeed ); // initialize with a simple uint32 88 | MTRand( uint32 *const bigSeed, uint32 const seedLength = N ); // or array 89 | MTRand(); // auto-initialize with /dev/urandom or time() and clock() 90 | MTRand( const MTRand& o ); // copy 91 | 92 | // Do NOT use for CRYPTOGRAPHY without securely hashing several returned 93 | // values together, otherwise the generator state can be learned after 94 | // reading 624 consecutive values. 
95 | 96 | // Access to 32-bit random numbers 97 | uint32 randInt(); // integer in [0,2^32-1] 98 | uint32 randInt( const uint32 n ); // integer in [0,n] for n < 2^32 99 | double rand(); // real number in [0,1] 100 | double rand( const double n ); // real number in [0,n] 101 | double randExc(); // real number in [0,1) 102 | double randExc( const double n ); // real number in [0,n) 103 | double randDblExc(); // real number in (0,1) 104 | double randDblExc( const double n ); // real number in (0,n) 105 | double operator()(); // same as rand() 106 | 107 | // Access to 53-bit random numbers (capacity of IEEE double precision) 108 | double rand53(); // real number in [0,1) 109 | 110 | // Access to nonuniform random number distributions 111 | double randNorm( const double mean = 0.0, const double stddev = 1.0 ); 112 | 113 | // Re-seeding functions with same behavior as initializers 114 | void seed( const uint32 oneSeed ); 115 | void seed( uint32 *const bigSeed, const uint32 seedLength = N ); 116 | void seed(); 117 | 118 | // Saving and loading generator state 119 | void save( uint32* saveArray ) const; // to array of size SAVE 120 | void load( uint32 *const loadArray ); // from such array 121 | friend std::ostream& operator<<( std::ostream& os, const MTRand& mtrand ); 122 | friend std::istream& operator>>( std::istream& is, MTRand& mtrand ); 123 | MTRand& operator=( const MTRand& o ); 124 | 125 | protected: 126 | void initialize( const uint32 oneSeed ); 127 | void reload(); 128 | uint32 hiBit( const uint32 u ) const { return u & 0x80000000UL; } 129 | uint32 loBit( const uint32 u ) const { return u & 0x00000001UL; } 130 | uint32 loBits( const uint32 u ) const { return u & 0x7fffffffUL; } 131 | uint32 mixBits( const uint32 u, const uint32 v ) const 132 | { return hiBit(u) | loBits(v); } 133 | uint32 magic( const uint32 u ) const 134 | { return loBit(u) ? 
0x9908b0dfUL : 0x0UL; } 135 | uint32 twist( const uint32 m, const uint32 s0, const uint32 s1 ) const 136 | { return m ^ (mixBits(s0,s1)>>1) ^ magic(s1); } 137 | static uint32 hash( time_t t, clock_t c ); 138 | }; 139 | 140 | // Functions are defined in order of usage to assist inlining 141 | 142 | inline MTRand::uint32 MTRand::hash( time_t t, clock_t c ) 143 | { 144 | // Get a uint32 from t and c 145 | // Better than uint32(x) in case x is floating point in [0,1] 146 | // Based on code by Lawrence Kirby (fred@genesis.demon.co.uk) 147 | 148 | static uint32 differ = 0; // guarantee time-based seeds will change 149 | 150 | uint32 h1 = 0; 151 | unsigned char *p = (unsigned char *) &t; 152 | for( size_t i = 0; i < sizeof(t); ++i ) 153 | { 154 | h1 *= UCHAR_MAX + 2U; 155 | h1 += p[i]; 156 | } 157 | uint32 h2 = 0; 158 | p = (unsigned char *) &c; 159 | for( size_t j = 0; j < sizeof(c); ++j ) 160 | { 161 | h2 *= UCHAR_MAX + 2U; 162 | h2 += p[j]; 163 | } 164 | return ( h1 + differ++ ) ^ h2; 165 | } 166 | 167 | inline void MTRand::initialize( const uint32 seed ) 168 | { 169 | // Initialize generator state with seed 170 | // See Knuth TAOCP Vol 2, 3rd Ed, p.106 for multiplier. 171 | // In previous versions, most significant bits (MSBs) of the seed affect 172 | // only MSBs of the state array. Modified 9 Jan 2002 by Makoto Matsumoto. 
173 | uint32 *s = state; 174 | uint32 *r = state; 175 | int i = 1; 176 | *s++ = seed & 0xffffffffUL; 177 | for( ; i < N; ++i ) 178 | { 179 | *s++ = ( 1812433253UL * ( *r ^ (*r >> 30) ) + i ) & 0xffffffffUL; 180 | r++; 181 | } 182 | } 183 | 184 | inline void MTRand::reload() 185 | { 186 | // Generate N new values in state 187 | // Made clearer and faster by Matthew Bellew (matthew.bellew@home.com) 188 | static const int MmN = int(M) - int(N); // in case enums are unsigned 189 | uint32 *p = state; 190 | int i; 191 | for( i = N - M; i--; ++p ) 192 | *p = twist( p[M], p[0], p[1] ); 193 | for( i = M; --i; ++p ) 194 | *p = twist( p[MmN], p[0], p[1] ); 195 | *p = twist( p[MmN], p[0], state[0] ); 196 | 197 | left = N, pNext = state; 198 | } 199 | 200 | inline void MTRand::seed( const uint32 oneSeed ) 201 | { 202 | // Seed the generator with a simple uint32 203 | initialize(oneSeed); 204 | reload(); 205 | } 206 | 207 | inline void MTRand::seed( uint32 *const bigSeed, const uint32 seedLength ) 208 | { 209 | // Seed the generator with an array of uint32's 210 | // There are 2^19937-1 possible initial states. This function allows 211 | // all of those to be accessed by providing at least 19937 bits (with a 212 | // default seed length of N = 624 uint32's). Any bits above the lower 32 213 | // in each element are discarded. 214 | // Just call seed() if you want to get array from /dev/urandom 215 | initialize(19650218UL); 216 | int i = 1; 217 | uint32 j = 0; 218 | int k = ( N > seedLength ? 
N : seedLength ); 219 | for( ; k; --k ) 220 | { 221 | state[i] = 222 | state[i] ^ ( (state[i-1] ^ (state[i-1] >> 30)) * 1664525UL ); 223 | state[i] += ( bigSeed[j] & 0xffffffffUL ) + j; 224 | state[i] &= 0xffffffffUL; 225 | ++i; ++j; 226 | if( i >= N ) { state[0] = state[N-1]; i = 1; } 227 | if( j >= seedLength ) j = 0; 228 | } 229 | for( k = N - 1; k; --k ) 230 | { 231 | state[i] = 232 | state[i] ^ ( (state[i-1] ^ (state[i-1] >> 30)) * 1566083941UL ); 233 | state[i] -= i; 234 | state[i] &= 0xffffffffUL; 235 | ++i; 236 | if( i >= N ) { state[0] = state[N-1]; i = 1; } 237 | } 238 | state[0] = 0x80000000UL; // MSB is 1, assuring non-zero initial array 239 | reload(); 240 | } 241 | 242 | inline void MTRand::seed() 243 | { 244 | // Seed the generator with an array from /dev/urandom if available 245 | // Otherwise use a numerichash of time() and clock() values 246 | 247 | // First try getting an array from /dev/urandom 248 | FILE* urandom = fopen( "/dev/urandom", "rb" ); 249 | if( urandom ) 250 | { 251 | uint32 bigSeed[N]; 252 | uint32 *s = bigSeed; 253 | int i = N; 254 | bool success = true; 255 | while( success && i-- ) 256 | success = fread( s++, sizeof(uint32), 1, urandom ); 257 | fclose(urandom); 258 | if( success ) { seed( bigSeed, N ); return; } 259 | } 260 | 261 | // Was not successful, so use time() and clock() instead 262 | seed( hash( time(NULL), clock() ) ); 263 | } 264 | 265 | inline MTRand::MTRand( const uint32 oneSeed ) 266 | { seed(oneSeed); } 267 | 268 | inline MTRand::MTRand( uint32 *const bigSeed, const uint32 seedLength ) 269 | { seed(bigSeed,seedLength); } 270 | 271 | inline MTRand::MTRand() 272 | { seed(); } 273 | 274 | inline MTRand::MTRand( const MTRand& o ) 275 | { 276 | const uint32 *t = o.state; 277 | uint32 *s = state; 278 | int i = N; 279 | for( ; i--; *s++ = *t++ ) {} 280 | left = o.left; 281 | pNext = &state[N-left]; 282 | } 283 | 284 | inline MTRand::uint32 MTRand::randInt() 285 | { 286 | // Pull a 32-bit integer from the generator state 
287 | // Every other access function simply transforms the numbers extracted here 288 | 289 | if( left == 0 ) reload(); 290 | --left; 291 | 292 | uint32 s1; 293 | s1 = *pNext++; 294 | s1 ^= (s1 >> 11); 295 | s1 ^= (s1 << 7) & 0x9d2c5680UL; 296 | s1 ^= (s1 << 15) & 0xefc60000UL; 297 | return ( s1 ^ (s1 >> 18) ); 298 | } 299 | 300 | inline MTRand::uint32 MTRand::randInt( const uint32 n ) 301 | { 302 | // Find which bits are used in n 303 | // Optimized by Magnus Jonsson (magnus@smartelectronix.com) 304 | uint32 used = n; 305 | used |= used >> 1; 306 | used |= used >> 2; 307 | used |= used >> 4; 308 | used |= used >> 8; 309 | used |= used >> 16; 310 | 311 | // Draw numbers until one is found in [0,n] 312 | uint32 i; 313 | do 314 | i = randInt() & used; // toss unused bits to shorten search 315 | while( i > n ); 316 | return i; 317 | } 318 | 319 | inline double MTRand::rand() 320 | { return double(randInt()) * (1.0/4294967295.0); } 321 | 322 | inline double MTRand::rand( const double n ) 323 | { return rand() * n; } 324 | 325 | inline double MTRand::randExc() 326 | { return double(randInt()) * (1.0/4294967296.0); } 327 | 328 | inline double MTRand::randExc( const double n ) 329 | { return randExc() * n; } 330 | 331 | inline double MTRand::randDblExc() 332 | { return ( double(randInt()) + 0.5 ) * (1.0/4294967296.0); } 333 | 334 | inline double MTRand::randDblExc( const double n ) 335 | { return randDblExc() * n; } 336 | 337 | inline double MTRand::rand53() 338 | { 339 | uint32 a = randInt() >> 5, b = randInt() >> 6; 340 | return ( a * 67108864.0 + b ) * (1.0/9007199254740992.0); // by Isaku Wada 341 | } 342 | 343 | inline double MTRand::randNorm( const double mean, const double stddev ) 344 | { 345 | // Return a real number from a normal (Gaussian) distribution with given 346 | // mean and standard deviation by polar form of Box-Muller transformation 347 | double x, y, r; 348 | do 349 | { 350 | x = 2.0 * rand() - 1.0; 351 | y = 2.0 * rand() - 1.0; 352 | r = x * x + y * 
y; 353 | } 354 | while ( r >= 1.0 || r == 0.0 ); 355 | double s = sqrt( -2.0 * log(r) / r ); 356 | return mean + x * s * stddev; 357 | } 358 | 359 | inline double MTRand::operator()() 360 | { 361 | return rand(); 362 | } 363 | 364 | inline void MTRand::save( uint32* saveArray ) const 365 | { 366 | const uint32 *s = state; 367 | uint32 *sa = saveArray; 368 | int i = N; 369 | for( ; i--; *sa++ = *s++ ) {} 370 | *sa = left; 371 | } 372 | 373 | inline void MTRand::load( uint32 *const loadArray ) 374 | { 375 | uint32 *s = state; 376 | uint32 *la = loadArray; 377 | int i = N; 378 | for( ; i--; *s++ = *la++ ) {} 379 | left = *la; 380 | pNext = &state[N-left]; 381 | } 382 | 383 | inline std::ostream& operator<<( std::ostream& os, const MTRand& mtrand ) 384 | { 385 | const MTRand::uint32 *s = mtrand.state; 386 | int i = mtrand.N; 387 | for( ; i--; os << *s++ << "\t" ) {} 388 | return os << mtrand.left; 389 | } 390 | 391 | inline std::istream& operator>>( std::istream& is, MTRand& mtrand ) 392 | { 393 | MTRand::uint32 *s = mtrand.state; 394 | int i = mtrand.N; 395 | for( ; i--; is >> *s++ ) {} 396 | is >> mtrand.left; 397 | mtrand.pNext = &mtrand.state[mtrand.N-mtrand.left]; 398 | return is; 399 | } 400 | 401 | inline MTRand& MTRand::operator=( const MTRand& o ) 402 | { 403 | if( this == &o ) return (*this); 404 | const uint32 *t = o.state; 405 | uint32 *s = state; 406 | int i = N; 407 | for( ; i--; *s++ = *t++ ) {} 408 | left = o.left; 409 | pNext = &state[N-left]; 410 | return (*this); 411 | } 412 | 413 | #endif // MERSENNETWISTER_H 414 | 415 | // Change log: 416 | // 417 | // v0.1 - First release on 15 May 2000 418 | // - Based on code by Makoto Matsumoto, Takuji Nishimura, and Shawn Cokus 419 | // - Translated from C to C++ 420 | // - Made completely ANSI compliant 421 | // - Designed convenient interface for initialization, seeding, and 422 | // obtaining numbers in default or user-defined ranges 423 | // - Added automatic seeding from /dev/urandom or time() and clock() 
424 | // - Provided functions for saving and loading generator state 425 | // 426 | // v0.2 - Fixed bug which reloaded generator one step too late 427 | // 428 | // v0.3 - Switched to clearer, faster reload() code from Matthew Bellew 429 | // 430 | // v0.4 - Removed trailing newline in saved generator format to be consistent 431 | // with output format of built-in types 432 | // 433 | // v0.5 - Improved portability by replacing static const int's with enum's and 434 | // clarifying return values in seed(); suggested by Eric Heimburg 435 | // - Removed MAXINT constant; use 0xffffffffUL instead 436 | // 437 | // v0.6 - Eliminated seed overflow when uint32 is larger than 32 bits 438 | // - Changed integer [0,n] generator to give better uniformity 439 | // 440 | // v0.7 - Fixed operator precedence ambiguity in reload() 441 | // - Added access for real numbers in (0,1) and (0,n) 442 | // 443 | // v0.8 - Included time.h header to properly support time_t and clock_t 444 | // 445 | // v1.0 - Revised seeding to match 26 Jan 2002 update of Nishimura and Matsumoto 446 | // - Allowed for seeding with arrays of any length 447 | // - Added access for real numbers in [0,1) with 53-bit resolution 448 | // - Added access for real numbers from normal (Gaussian) distributions 449 | // - Increased overall speed by optimizing twist() 450 | // - Doubled speed of integer [0,n] generation 451 | // - Fixed out-of-range number generation on 64-bit machines 452 | // - Improved portability by substituting literal constants for long enum's 453 | // - Changed license from GNU LGPL to BSD 454 | // 455 | // v1.1 - Corrected parameter label in randNorm from "variance" to "stddev" 456 | // - Changed randNorm algorithm from basic to polar form for efficiency 457 | // - Updated includes from deprecated to standard forms 458 | // - Cleaned declarations and definitions to please Intel compiler 459 | // - Revised twist() operator to work on ones'-complement machines 460 | // - Fixed reload() function to 
work when N and M are unsigned 461 | // - Added copy constructor and copy operator from Salvador Espana -------------------------------------------------------------------------------- /mstream/argparse.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | __ _ _ __ __ _ _ __ __ _ _ __ ___ ___ 3 | / _` | '__/ _` | '_ \ / _` | '__/ __|/ _ \ Argument Parser for Modern C++ 4 | | (_| | | | (_| | |_) | (_| | | \__ \ __/ http://github.com/p-ranav/argparse 5 | \__,_|_| \__, | .__/ \__,_|_| |___/\___| 6 | |___/|_| 7 | Licensed under the MIT License <http://opensource.org/licenses/MIT>. 8 | SPDX-License-Identifier: MIT 9 | Copyright (c) 2019 Pranav Srinivas Kumar <pranav.srinivas.kumar@gmail.com>. 10 | Permission is hereby granted, free of charge, to any person obtaining a copy 11 | of this software and associated documentation files (the "Software"), to deal 12 | in the Software without restriction, including without limitation the rights 13 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | copies of the Software, and to permit persons to whom the Software is 15 | furnished to do so, subject to the following conditions: 16 | The above copyright notice and this permission notice shall be included in all 17 | copies or substantial portions of the Software. 18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | SOFTWARE.
/* End of the argparse.hpp MIT license header. */
#pragma once
#include <algorithm>
#include <any>
#include <cctype>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <iterator>
#include <list>
#include <map>
#include <memory>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

namespace argparse {

namespace details { // namespace for helper methods

// Empty helper whose template arguments only have to be well-formed; it lets
// the partial specialization below detect container-like types via SFINAE.
template <typename... Ts> struct is_container_helper {};

// By default a type is not treated as a container.
template <typename T, typename = void>
struct is_container : std::false_type {};

// std::string exposes begin/end/size/value_type but is handled as one value.
template <> struct is_container<std::string> : std::false_type {};

// A type exposing value_type, begin(), end() and size() is a container.
template <typename T>
struct is_container<
    T, std::conditional_t<
           false,
           is_container_helper<typename T::value_type,
                               decltype(std::declval<T>().begin()),
                               decltype(std::declval<T>().end()),
                               decltype(std::declval<T>().size())>,
           void>> : public std::true_type {};

// inline (not static): one definition shared by every translation unit.
template <typename T>
inline constexpr bool is_container_v = is_container<T>::value;

template <typename T>
using enable_if_container = std::enable_if_t<is_container_v<T>, T>;

template <typename T>
using enable_if_not_container = std::enable_if_t<!is_container_v<T>, T>;
} // namespace details

/// One command-line argument: its names, its configuration (help text,
/// default/implicit value, arity, conversion action) and the values collected
/// for it during parsing.
class Argument {
  friend class ArgumentParser;

public:
  Argument() = default;

  /// Creates an argument known by every name in `args` (e.g. "-n",
  /// "--numerical"). Names are stored shortest-first, ties broken
  /// lexicographically, so mNames[0] is the shortest name. The argument is
  /// optional when any of its names looks like an option (see is_optional).
  template <typename... Args>
  explicit Argument(Args... args) : mNames({std::move(args)...}) {
    // Classify from the stored names: when Args are std::string, the
    // parameters have already been moved from and would appear empty.
    mIsOptional = std::any_of(mNames.begin(), mNames.end(), is_optional);
    std::sort(mNames.begin(), mNames.end(),
              [](const auto &lhs, const auto &rhs) {
                return lhs.size() == rhs.size() ? lhs < rhs
                                                : lhs.size() < rhs.size();
              });
  }

  /// Sets the description printed by ArgumentParser::print_help.
  Argument &help(std::string aHelp) {
    mHelp = std::move(aHelp);
    return *this;
  }

  /// Value returned by get() when nothing was parsed for this argument.
  /// Stored as std::any, so it must have exactly the type later requested
  /// from get() (for container requests: the whole container type).
  Argument &default_value(std::any aDefaultValue) {
    mDefaultValue = std::move(aDefaultValue);
    return *this;
  }

  /// Marks the argument as required; enforced by validate().
  Argument &required() {
    mIsRequired = true;
    return *this;
  }

  /// Value recorded when the argument is present on the command line. Also
  /// sets nargs to 0, i.e. the argument consumes no following tokens.
  Argument &implicit_value(std::any aImplicitValue) {
    mImplicitValue = std::move(aImplicitValue);
    mNumArgs = 0;
    return *this;
  }

  /// Conversion applied to each consumed token; the default action stores
  /// the token as a std::string.
  Argument &action(std::function<std::any(const std::string &)> aAction) {
    mAction = std::move(aAction);
    return *this;
  }

  /// Number of tokens this argument consumes.
  Argument &nargs(size_t aNumArgs) {
    mNumArgs = aNumArgs;
    return *this;
  }

  /// Consumes this argument's values from [start, end).
  /// @param usedName spelling typed by the user (empty for positionals);
  ///        used in error messages.
  /// @return iterator to the first token that was not consumed.
  /// @throws std::runtime_error on a repeated argument, on an option-like
  ///         token among the values, or when too few tokens remain and there
  ///         is no default value.
  template <typename Iterator>
  Iterator consume(Iterator start, Iterator end, std::string usedName = {}) {
    if (mIsUsed) {
      throw std::runtime_error("Duplicate argument");
    }
    mIsUsed = true;
    mUsedName = std::move(usedName);
    if (mNumArgs == 0) {
      mValues.emplace_back(mImplicitValue);
      return start;
    } else if (mNumArgs <= static_cast<size_t>(std::distance(start, end))) {
      end = std::next(start, mNumArgs);
      if (std::any_of(start, end, Argument::is_optional)) {
        throw std::runtime_error("optional argument in parameter sequence");
      }
      std::transform(start, end, std::back_inserter(mValues), mAction);
      return end;
    } else if (mDefaultValue.has_value()) {
      return start;
    } else {
      throw std::runtime_error("Too few arguments");
    }
  }

  /*
   * Checks the collected values against the configured arity and the
   * required() flag.
   * @throws std::runtime_error if argument values are not valid
   */
  void validate() const {
    if (mIsOptional) {
      if (mIsUsed && mValues.size() != mNumArgs && !mDefaultValue.has_value()) {
        std::stringstream stream;
        stream << display_name() << ": expected " << mNumArgs
               << " argument(s). " << mValues.size() << " provided.";
        throw std::runtime_error(stream.str());
      } else {
        // TODO: check if an implicit value was programmed for this argument
        if (!mIsUsed && !mDefaultValue.has_value() && mIsRequired) {
          std::stringstream stream;
          stream << display_name() << ": required.";
          throw std::runtime_error(stream.str());
        }
        if (mIsUsed && mIsRequired && mValues.size() == 0) {
          std::stringstream stream;
          stream << display_name() << ": no value provided.";
          throw std::runtime_error(stream.str());
        }
      }
    } else {
      if (mValues.size() != mNumArgs && !mDefaultValue.has_value()) {
        std::stringstream stream;
        stream << display_name() << ": expected " << mNumArgs
               << " argument(s). " << mValues.size() << " provided.";
        throw std::runtime_error(stream.str());
      }
    }
  }

  /// Total length of all names plus one separating space per name; used to
  /// align the help output.
  size_t get_arguments_length() const {
    return std::accumulate(std::begin(mNames), std::end(mNames), size_t(0),
                           [](const auto &sum, const auto &s) {
                             return sum + s.size() +
                                    1; // +1 for space between names
                           });
  }

  /// Writes "<names> \t<help>[ [Required]]\n" (one help line).
  friend std::ostream &operator<<(std::ostream &stream,
                                  const Argument &argument) {
    std::stringstream nameStream;
    std::copy(std::begin(argument.mNames), std::end(argument.mNames),
              std::ostream_iterator<std::string>(nameStream, " "));
    stream << nameStream.str() << "\t" << argument.mHelp;
    if (argument.mIsRequired)
      stream << " [Required]";
    stream << "\n";
    return stream;
  }

  template <typename T> bool operator!=(const T &aRhs) const {
    return !(*this == aRhs);
  }

  /*
   * Entry point for template non-container types: compares the parsed (or
   * default) value with aRhs.
   * @throws std::bad_any_cast if the stored value is not of type T
   * @throws std::logic_error if there is neither a value nor a default
   */
  template <typename T>
  std::enable_if_t<!details::is_container_v<T>, bool>
  operator==(const T &aRhs) const {
    return get<T>() == aRhs;
  }

  /*
   * Overload for containers: element-wise comparison with aRhs.
   * @throws std::bad_any_cast if a stored value is not of type T::value_type
   * @throws std::logic_error if there is neither a value nor a default
   */
  template <typename T>
  std::enable_if_t<details::is_container_v<T>, bool>
  operator==(const T &aRhs) const {
    // get<T>() has already converted every element, so compare directly.
    const auto tLhs = get<T>();
    return tLhs.size() == aRhs.size() &&
           std::equal(std::begin(tLhs), std::end(tLhs), std::begin(aRhs));
  }

private:
  // Name used in error messages: the spelling the user typed, otherwise the
  // first (shortest) registered name (e.g. for positional arguments).
  const std::string &display_name() const {
    return (mUsedName.empty() && !mNames.empty()) ? mNames.front() : mUsedName;
  }

  // True when aValue is an entire base-10 integer with an optional sign.
  static bool is_integer(const std::string &aValue) {
    // Cast to unsigned char: std::isdigit is undefined for negative chars.
    if (aValue.empty() ||
        ((!std::isdigit(static_cast<unsigned char>(aValue[0]))) &&
         (aValue[0] != '-') && (aValue[0] != '+')))
      return false;
    char *tPtr = nullptr;
    std::strtol(aValue.c_str(), &tPtr, 10);
    return (*tPtr == 0);
  }

  static bool is_float(const std::string &aValue) {
    std::istringstream tStream(aValue);
    float tFloat;
    // noskipws considers leading whitespace invalid
    tStream >> std::noskipws >> tFloat;
    // Check the entire string was consumed
    // and if either failbit or badbit is set
    return tStream.eof() && !tStream.fail();
  }

  // If an argument starts with "-" or "--" and is not a number, it's optional
  static bool is_optional(const std::string &aName) {
    return (!aName.empty() && aName[0] == '-' && !is_integer(aName) &&
            !is_float(aName));
  }

  static bool is_positional(const std::string &aName) {
    return !is_optional(aName);
  }

  /*
   * Getter for template non-container types: the first parsed value,
   * otherwise the default value.
   * @throws std::bad_any_cast if the stored value is not of type T
   * @throws std::logic_error if there is neither a value nor a default
   */
  template <typename T> details::enable_if_not_container<T> get() const {
    if (!mValues.empty()) {
      return std::any_cast<T>(mValues.front());
    }
    if (mDefaultValue.has_value()) {
      return std::any_cast<T>(mDefaultValue);
    }
    throw std::logic_error("No value provided");
  }

  /*
   * Getter for container types: all parsed values converted to
   * CONTAINER::value_type, otherwise the default (stored as a CONTAINER).
   * @throws std::bad_any_cast in case of incompatible types
   * @throws std::logic_error if there is neither a value nor a default
   */
  template <typename CONTAINER>
  details::enable_if_container<CONTAINER> get() const {
    using ValueType = typename CONTAINER::value_type;
    if (!mValues.empty()) {
      CONTAINER tResult;
      std::transform(
          std::begin(mValues), std::end(mValues), std::back_inserter(tResult),
          [](const auto &value) { return std::any_cast<ValueType>(value); });
      return tResult;
    }
    if (mDefaultValue.has_value()) {
      return std::any_cast<CONTAINER>(mDefaultValue);
    }
    throw std::logic_error("No value provided");
  }

  std::vector<std::string> mNames;
  std::string mUsedName;
  std::string mHelp;
  std::any mDefaultValue;
  std::any mImplicitValue;
  std::function<std::any(const std::string &)> mAction =
      [](const std::string &aValue) { return aValue; };
  std::vector<std::any> mValues;
  size_t mNumArgs = 1;
  bool mIsOptional = false;
  bool mIsRequired = false;
  bool mIsUsed = false; // relevant for optional arguments. True if used by user

public:
  static constexpr auto mHelpOption = "-h";
  static constexpr auto mHelpOptionLong = "--help";
};

/// Holds Argument definitions and parses a command line against them.
class ArgumentParser {
public:
  /// Registers the built-in "-h"/"--help" flag. When aProgramName is empty
  /// it is taken from the first parsed token.
  explicit ArgumentParser(std::string aProgramName = {})
      : mProgramName(std::move(aProgramName)) {
    add_argument(Argument::mHelpOption, Argument::mHelpOptionLong)
        .help("show this help message and exit")
        .nargs(0)
        .default_value(false)
        .implicit_value(true);
  }

  // Parameter packing
  // Call add_argument with variadic number of string arguments
  template <typename... Targs> Argument &add_argument(Targs... Fargs) {
    auto tArgument = std::make_shared<Argument>(std::move(Fargs)...);

    if (tArgument->mIsOptional)
      mOptionalArguments.emplace_back(tArgument);
    else
      mPositionalArguments.emplace_back(tArgument);

    // Register under every name; reusing a name replaces the older mapping.
    for (const auto &tName : tArgument->mNames) {
      mArgumentMap.insert_or_assign(tName, tArgument);
    }
    return *tArgument;
  }

  // Parameter packed add_parents method
  // Accepts a variadic number of ArgumentParser objects; their arguments are
  // shared (same Argument objects) with this parser.
  template <typename... Targs> void add_parents(Targs... Fargs) {
    const auto tNewParentParsers = {Fargs...};
    for (const auto &tParentParser : tNewParentParsers) {
      const auto &tPositionalArguments = tParentParser.mPositionalArguments;
      std::copy(std::begin(tPositionalArguments),
                std::end(tPositionalArguments),
                std::back_inserter(mPositionalArguments));

      const auto &tOptionalArguments = tParentParser.mOptionalArguments;
      std::copy(std::begin(tOptionalArguments), std::end(tOptionalArguments),
                std::back_inserter(mOptionalArguments));

      for (const auto &[tKey, tValue] : tParentParser.mArgumentMap) {
        mArgumentMap.insert_or_assign(tKey, tValue);
      }
    }
    // Elements of an initializer_list are const, so this copies.
    std::copy(std::begin(tNewParentParsers), std::end(tNewParentParsers),
              std::back_inserter(mParentParsers));
  }

  /* Call parse_args_internal - which does all the work
   * Then, validate the parsed arguments
   * This variant is used mainly for testing
   * aArguments[0] is the program name; an empty vector parses nothing.
   * @throws std::runtime_error in case of any invalid argument
   */
  void parse_args(const std::vector<std::string> &aArguments) {
    parse_args_internal(aArguments);
    parse_args_validate();
  }

  /* Main entry point for parsing command-line arguments using this
   * ArgumentParser
   * @throws std::runtime_error in case of any invalid argument
   */
  void parse_args(int argc, const char *const argv[]) {
    std::vector<std::string> arguments(argv, argv + argc);
    parse_args(arguments);
  }

  /* Getter for the value of the argument registered as aArgumentName
   * (any of its names). Containers (e.g. std::vector<int>) collect all
   * values; std::string is treated as a single value.
   * @throws std::logic_error in case of an invalid argument name
   * @throws std::logic_error if there is neither a value nor a default
   * @throws std::bad_any_cast in case of incompatible types
   */
  template <typename T = std::string> T get(const std::string &aArgumentName) {
    auto tIterator = mArgumentMap.find(aArgumentName);
    if (tIterator != mArgumentMap.end()) {
      return tIterator->second->get<T>();
    }
    throw std::logic_error("No such argument");
  }

  /* Indexing operator. Return a reference to an Argument object
   * Used in conjunction with Argument.operator== e.g., parser["foo"] == true
   * @throws std::logic_error in case of an invalid argument name
   */
  Argument &operator[](const std::string &aArgumentName) {
    auto tIterator = mArgumentMap.find(aArgumentName);
    if (tIterator != mArgumentMap.end()) {
      return *(tIterator->second);
    }
    throw std::logic_error("No such argument");
  }

  // Printing the one and only help message
  // I've stuck with a simple message format, nothing fancy.
  // The message is written to std::cout and also returned.
  // TODO: support user-defined help and usage messages for the ArgumentParser
  std::string print_help() {
    std::stringstream stream;
    stream << std::left;
    stream << "Usage: ./" << mProgramName << " [options] ";
    size_t tLongestArgumentLength = get_length_of_longest_argument();

    for (const auto &tArgument : mPositionalArguments) {
      stream << tArgument->mNames.front() << " ";
    }
    stream << "\n\n";

    if (!mPositionalArguments.empty())
      stream << "Positional arguments:\n";

    for (const auto &tArgument : mPositionalArguments) {
      stream.width(tLongestArgumentLength);
      stream << *tArgument;
    }

    if (!mOptionalArguments.empty())
      stream << (mPositionalArguments.empty() ? "" : "\n") << "Options:\n";

    for (const auto &tArgument : mOptionalArguments) {
      stream.width(tLongestArgumentLength);
      stream << *tArgument;
    }

    std::cout << stream.str();
    return stream.str();
  }

private:
  /*
   * Assigns tokens after aArguments[0] to arguments: positional tokens fill
   * positional arguments in declaration order; option tokens are looked up
   * by name; "-abc" is split into "-a", "-b", "-c".
   * @throws std::runtime_error in case of any invalid argument, including
   *         "help called" when -h/--help is seen
   */
  void parse_args_internal(const std::vector<std::string> &aArguments) {
    if (aArguments.empty()) {
      // No tokens at all (argc == 0): std::next(begin()) would step past
      // end() and the loop below would read invalid iterators.
      return;
    }
    if (mProgramName.empty()) {
      mProgramName = aArguments.front();
    }
    auto end = std::end(aArguments);
    auto positionalArgumentIt = std::begin(mPositionalArguments);
    for (auto it = std::next(std::begin(aArguments)); it != end;) {
      const auto &tCurrentArgument = *it;
      if (tCurrentArgument == Argument::mHelpOption ||
          tCurrentArgument == Argument::mHelpOptionLong) {
        throw std::runtime_error("help called");
      }
      if (Argument::is_positional(tCurrentArgument)) {
        if (positionalArgumentIt == std::end(mPositionalArguments)) {
          throw std::runtime_error(
              "Maximum number of positional arguments exceeded");
        }
        auto tArgument = *(positionalArgumentIt++);
        it = tArgument->consume(it, end);
      } else if (auto tIterator = mArgumentMap.find(tCurrentArgument);
                 tIterator != mArgumentMap.end()) {
        auto tArgument = tIterator->second;
        it = tArgument->consume(std::next(it), end, tCurrentArgument);
      } else if (const auto &tCompoundArgument = tCurrentArgument;
                 tCompoundArgument.size() > 1 && tCompoundArgument[0] == '-' &&
                 tCompoundArgument[1] != '-') {
        // Treat e.g. "-ab" as "-a" followed by "-b"; every letter must be a
        // registered short option.
        ++it;
        for (size_t j = 1; j < tCompoundArgument.size(); j++) {
          auto tShortName = std::string{'-', tCompoundArgument[j]};
          if (auto tIterator = mArgumentMap.find(tShortName);
              tIterator != mArgumentMap.end()) {
            auto tArgument = tIterator->second;
            it = tArgument->consume(it, end, tShortName);
          } else {
            throw std::runtime_error("Unknown argument");
          }
        }
      } else {
        throw std::runtime_error("Unknown argument");
      }
    }
  }

  /*
   * Runs Argument::validate on every registered argument.
   * @throws std::runtime_error in case of any invalid argument
   */
  void parse_args_validate() {
    for (const auto &tPair : mArgumentMap) {
      tPair.second->validate();
    }
  }

  // Used by print_help: widest get_arguments_length() over all arguments
  // (0 when none are registered).
  size_t get_length_of_longest_argument() {
    size_t tLongest = 0;
    for (const auto &tPair : mArgumentMap) {
      tLongest = std::max(tLongest, tPair.second->get_arguments_length());
    }
    return tLongest;
  }

  std::string mProgramName;
  std::vector<ArgumentParser> mParentParsers;
  std::vector<std::shared_ptr<Argument>> mPositionalArguments;
  std::vector<std::shared_ptr<Argument>> mOptionalArguments;
  std::map<std::string, std::shared_ptr<Argument>> mArgumentMap;
};

// Parses argc/argv; on any std::runtime_error (including "help called")
// prints the message and the help text, then exits.
// NOTE(review): exits with status 0 even for genuine parse errors.
#define PARSE_ARGS(parser, argc, argv)                                         \
  try {                                                                        \
    parser.parse_args(argc, argv);                                             \
  } catch (const std::runtime_error &err) {                                    \
    std::cout << err.what() << std::endl;                                      \
    parser.print_help();                                                       \
    exit(0);                                                                   \
  }

} // namespace argparse
// --------------------------------------------------------------------------------