├── .gitignore
├── CPT.py
├── LICENSE
├── PredictionTree.py
├── README.md
├── data
│   ├── sample_submission.csv
│   ├── test.csv
│   └── train.csv
├── example.py
└── requirements.txt

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/CPT.py:
--------------------------------------------------------------------------------
from PredictionTree import *
import pandas as pd
from tqdm import tqdm


class CPT():

    alphabet = None  # Set of all unique items in the entire data file
    root = None      # Root node of the Prediction Tree
    II = None        # Inverted Index dictionary: key = unique item, value = set of ids of sequences containing this item
    LT = None        # Lookup Table dictionary: key = id of a sequence (row), value = leaf node of the Prediction Tree

    def __init__(self):
        self.alphabet = set()
        self.root = PredictionTree()
        self.II = {}
        self.LT = {}

    def load_files(self, train_file, test_file=None):
        """
        Reads the wide CSV files of comma-separated sequences and returns them as lists of lists.
        The sequences are laid out as below:

            seq1 = A,B,C,D
            seq2 = B,C,E

        Returns: [[A,B,C,D],[B,C,E]]
        """

        data = []    # list of lists holding the training sequences used to build the model
        target = []  # list of lists holding the test sequences whose next n items are to be predicted

        if train_file is None:
            return train_file

        train = pd.read_csv(train_file)

        for index, row in train.iterrows():
            data.append(row.values)

        if test_file is not None:

            test = pd.read_csv(test_file)

            for index, row in test.iterrows():
                data.append(row.values)
                target.append(list(row.values))

            return data, target

        return data
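
    # Illustration (an editor's sketch, not executed code): after train([['A','B','C'], ['A','B','D']])
    # the three structures built by the method below roughly look like
    #
    #   Prediction Tree:      root -> A -> B -> C
    #                                       \-> D
    #   II (Inverted Index):  {'A': {0, 1}, 'B': {0, 1}, 'C': {0}, 'D': {1}}
    #   LT (Lookup Table):    {0: <node C>, 1: <node D>}
    #
    # so every training sequence can later be recovered by walking up from its leaf node.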

    def train(self, data):
        """
        This function populates the Prediction Tree, Inverted Index and Lookup Table for the algorithm.

        Input: the training data as a list of lists
        Output: boolean True
        """

        cursornode = self.root

        for seqid, row in enumerate(data):
            for element in row:

                # Add the item to the Prediction Tree.
                if not cursornode.hasChild(element):
                    cursornode.addChild(element)
                    cursornode = cursornode.getChild(element)
                else:
                    cursornode = cursornode.getChild(element)

                # Add the sequence id to the Inverted Index.
                if self.II.get(element) is None:
                    self.II[element] = set()

                self.II[element].add(seqid)

                self.alphabet.add(element)

            # The leaf node reached for this sequence goes into the Lookup Table.
            self.LT[seqid] = cursornode

            cursornode = self.root

        return True

    def score(self, counttable, key, length, target_size, number_of_similar_sequences, number_items_counttable):
        """
        This function is the main workhorse and calculates the score to be stored against an item. Items are
        predicted using this score.

        Output: the counttable dictionary which stores the score against items. The counttable is specific to a
        particular row (sequence) and is therefore recalculated at each prediction.
        """

        # Note: the length and target_size parameters are not used in the current scoring formula.
        weight_level = 1 / number_of_similar_sequences
        weight_distance = 1 / number_items_counttable
        score = 1 + weight_level + weight_distance * 0.001

        if counttable.get(key) is None:
            counttable[key] = score
        else:
            counttable[key] = score * counttable.get(key)

        return counttable
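
    # Worked illustration of the formula above (an editor's sketch, not part of the
    # original code): with number_of_similar_sequences = 4 and number_items_counttable = 2,
    #
    #   weight_level    = 1/4 = 0.25
    #   weight_distance = 1/2 = 0.5
    #   score           = 1 + 0.25 + 0.5 * 0.001 = 1.2505
    #
    # If the item is already present in the counttable with value 1.2505, the stored value
    # becomes 1.2505 * 1.2505 ≈ 1.5638, so items that follow the target in several similar
    # sequences accumulate higher scores than items seen only once.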

    def predict(self, data, target, k, n=1):
        """
        Here target is the test dataset in the form of a list of lists,
        k is the number of last elements that will be used to find similar sequences and
        n is the number of predictions required.

        Input: training list of lists, target list of lists, k, n

        Output: at most n predictions for each target sequence
        """

        predictions = []

        for each_target in tqdm(target):
            each_target = each_target[-k:]

            intersection = set(range(0, len(data)))

            for element in each_target:
                if self.II.get(element) is None:
                    continue
                intersection = intersection & self.II.get(element)

            # Recover the similar sequences by walking up the tree from their leaf nodes.
            similar_sequences = []

            for element in intersection:
                currentnode = self.LT.get(element)
                tmp = []
                while currentnode.Item is not None:
                    tmp.append(currentnode.Item)
                    currentnode = currentnode.Parent
                similar_sequences.append(tmp)

            for sequence in similar_sequences:
                sequence.reverse()

            counttable = {}

            for sequence in similar_sequences:
                # Locate the last occurrence of the target's final item in this similar
                # sequence; the items after it form the consequent that gets scored.
                try:
                    index = next(i for i, v in zip(range(len(sequence) - 1, 0, -1), reversed(sequence)) if v == each_target[-1])
                except StopIteration:
                    index = None
                if index is not None:
                    count = 1
                    for element in sequence[index + 1:]:
                        if element in each_target:
                            continue

                        counttable = self.score(counttable, element, len(each_target), len(each_target), len(similar_sequences), count)
                        count += 1

            pred = self.get_n_largest(counttable, n)
            predictions.append(pred)

        return predictions

    def get_n_largest(self, dictionary, n):
        """
        A small utility to obtain the top n keys of a dictionary based on their values.
        """
        largest = sorted(dictionary.items(), key=lambda t: t[1], reverse=True)[:n]
        return [key for key, _ in largest]
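
# A minimal smoke test (an editor's illustrative sketch, not part of the original
# library): trains on three toy sequences and asks which items are likely to follow
# ['A', 'B']. The toy data and the k/n values are arbitrary choices for illustration.
if __name__ == "__main__":
    toy_train = [['A', 'B', 'C'], ['A', 'B', 'D'], ['B', 'C', 'E']]
    toy_target = [['A', 'B']]

    cpt = CPT()
    cpt.train(toy_train)

    # k=2: use the last two items of each target to find similar sequences;
    # n=2: return at most two ranked predictions per target sequence.
    print(cpt.predict(toy_train, toy_target, 2, 2))  # expected to print something like [['C', 'D']]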

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Neeraj Singh Sarwan

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/PredictionTree.py:
--------------------------------------------------------------------------------
class PredictionTree():
    Item = None
    Parent = None
    Children = None

    def __init__(self, itemValue=None):
        self.Item = itemValue
        self.Children = []
        self.Parent = None

    def addChild(self, child):
        newchild = PredictionTree(child)
        newchild.Parent = self
        self.Children.append(newchild)

    def getChild(self, target):
        for chld in self.Children:
            if chld.Item == target:
                return chld
        return None

    def getChildren(self):
        return self.Children

    def hasChild(self, target):
        return self.getChild(target) is not None

    def removeChild(self, child):
        for chld in self.Children:
            if chld.Item == child:
                # Stop after the first match so the list is not mutated while it is
                # still being iterated over.
                self.Children.remove(chld)
                return

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CPT (Compact Prediction Tree)

This is a Python implementation of the CPT algorithm for sequence prediction. The library has been written from scratch and, to the best of my knowledge, is the first Python implementation of the algorithm.

The repository is also an exercise on my part in turning a research paper into code. The library is not perfect: I have intentionally left out some optimisations such as CFS (compression of frequent sequences). These features will be added later as an ongoing effort.

The library is based on the following two research papers:

1. [Compact Prediction Tree: A Lossless Model for Accurate Sequence Prediction](http://www.philippe-fournier-viger.com/spmf/ADMA2013_Compact_Prediction_tree)

2. [CPT+: Decreasing the time/space complexity of the Compact Prediction Tree](https://pdfs.semanticscholar.org/bd00/0fe7e222b8095c6591291cd7bef18f970ab7.pdf)


- How to use the library?

  Nothing needs to be compiled, but make sure you have pandas and tqdm installed in your environment; the specific versions are listed in requirements.txt.

- Sample code for training and getting predictions:

~~~
model = CPT()

data, target = model.load_files("train.csv","test.csv")

model.train(data)

predictions = model.predict(data, target, k, n)
~~~

--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
from CPT import *

model = CPT()
data, target = model.load_files("./data/train.csv", "./data/test.csv")
model.train(data)
# Use the last 5 items of each test sequence and return up to 3 predictions per sequence.
predictions = model.predict(data, target, 5, 3)

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
tqdm==4.21.0
pandas==0.20.3

--------------------------------------------------------------------------------
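
A complementary usage sketch (an editor's addition, not a file in the repository): beyond the basic usage shown in the README and example.py, one way to sanity-check the model is to hide the last item of each training sequence and measure how often CPT recovers it among its top-n predictions. The NaN-padding filter assumes the wide CSV pads shorter rows with NaN, and the k=5 / n=3 values are arbitrary illustrative choices.

~~~
from CPT import *

model = CPT()
data = model.load_files("./data/train.csv")        # training sequences only

# Hide the final item of every sequence that has at least two items.
truncated, held_out = [], []
for seq in data:
    items = [x for x in list(seq) if x == x]       # drop NaN padding, if any
    if len(items) < 2:
        continue
    truncated.append(items[:-1])
    held_out.append(items[-1])

# Train on the truncated sequences so the held-out items are never shown to the model.
model.train(truncated)

predictions = model.predict(truncated, truncated, 5, 3)

hits = sum(1 for preds, truth in zip(predictions, held_out) if truth in preds)
print("top-3 hit rate: %.3f" % (hits / max(len(held_out), 1)))
~~~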