├── .gitignore ├── LICENCE.txt ├── MANIFEST.in ├── README.md ├── deepstyle ├── __init__.py └── model.py ├── experiments ├── metrics │ ├── __init__.py │ ├── metrics.py │ └── simrank.py └── notebooks │ ├── authorship-attribution.ipynb │ ├── authorship-clustering.ipynb │ ├── dbert-ft-train.ipynb │ ├── sna-train.ipynb │ └── tfidf-focus.ipynb ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Dirs: 2 | *.egg-info/ 3 | build/ 4 | dist/ 5 | _trash/ 6 | trash/ 7 | 8 | # Misc files: 9 | *.sh 10 | !README.md 11 | TODO.md 12 | 13 | # Python 14 | *.pyc 15 | __pycache__/ 16 | 17 | # hj specific 18 | **/wm-dist/*.json 19 | 20 | # Eclipse 21 | .directory 22 | .project 23 | .metadata 24 | bin/ 25 | tmp/ 26 | *.tmp 27 | *.bak 28 | *.swp 29 | *~.nib 30 | local.properties 31 | .settings/ 32 | .loadpath 33 | .recommenders 34 | 35 | # External tool builders 36 | .externalToolBuilders/ 37 | 38 | # Locally stored "Eclipse launch configurations" 39 | *.launch 40 | 41 | # PyDev specific (Python IDE for Eclipse) 42 | *.pydevproject 43 | 44 | # CDT-specific (C/C++ Development Tooling) 45 | .cproject 46 | 47 | # Java annotation processor (APT) 48 | .factorypath 49 | 50 | # PDT-specific (PHP Development Tools) 51 | .buildpath 52 | 53 | # sbteclipse plugin 54 | .target 55 | 56 | # Tern plugin 57 | .tern-project 58 | 59 | # TeXlipse plugin 60 | .texlipse 61 | 62 | # STS (Spring Tool Suite) 63 | .springBeans 64 | 65 | # Code Recommenders 66 | .recommenders/ 67 | 68 | # Scala IDE specific (Scala & Java development for Eclipse) 69 | .cache-main 70 | .scala_dependencies 71 | .worksheet 72 | 73 | -------------------------------------------------------------------------------- /LICENCE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019 anonym2020 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include *.md
2 | include *.json
3 | include *.sh
4 | include LICENSE
5 | include LICENSE.txt
6 | include RELEASE
7 | include requirements.txt
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DeepStyle
2 | 
3 | **DeepStyle** provides pretrained models that project texts into a stylometric space. The base project consists of a new method of representation learning and a definition of writing style based on distributional properties. This repository contains the datasets, pretrained models and other resources that were used to train and test the models.
4 | 
5 | ## Datasets and pretrained model
6 | 
7 | To get the datasets (i.e. the *R-set* and the 22 main *U-sets*), please send me a private message or create a new issue in the repository.
8 | 
9 | The DeepStyle model (pretrained *DBert-ft*) is available for download: see the command line demo below for direct links to the config and the weights.
10 | 
11 | ## Installation
12 | 
13 | ```bash
14 | git clone https://github.com/hayj/deepstyle
15 | cd deepstyle
16 | python setup.py install
17 | ```
18 | 
19 | Dependencies (tensorflow and transformers) will not be installed automatically, so that users can install newer versions. DeepStyle was tested on tensorflow-gpu `2.0` and transformers `2.4.1`.
20 | 
21 | ## Usage of the *DBert-ft* model
22 | 
23 | 1. Download the pretrained model (both the config and the weights); see the command line demo below for direct download links.
24 | 2. Use the `DeepStyle` class in order to embed documents:
25 | 
26 | ```python
27 | from deepstyle.model import DeepStyle
28 | # Give the folder of the model:
29 | m = DeepStyle("/path/to/the/folder/containing/both/files")
30 | # Sample document:
31 | doc = "Welcome to day two of cold, nasty, wet weather. Ick. Rain is so bad by itself... But when you mix it with a hella cold temperature and nasty wind... Not so much fun anymore."
32 | # Embed a document:
33 | print(m.embed(doc)) # Returns an np.ndarray [-0.6553829, 0.3634828, ..., 1.2970213, 0.1685428]
34 | # Get the pretrained model and use its methods (e.g. to get attentions):
35 | m.model # See https://huggingface.co/transformers/model_doc/distilbert.html#tfdistilbertforsequenceclassification
36 | ```
37 | 
38 | In case you have trouble executing this, create a new Python environment and install these package versions:
39 | 
40 |     pip uninstall -y tensorflow && pip install tensorflow==2.0
41 |     pip uninstall -y transformers && pip install transformers==2.4.1
42 |     pip uninstall -y h5py && pip install h5py==2.10.0
43 | 
44 | ## Experiments
45 | 
46 | The folder `experiments` contains the main experiments of the project. Some parts of the code are notebooks and need to be adapted to your Python environment. For long runs, the notebooks are converted into Python files (e.g. for the *DBert-ft* training).
47 | 
48 | ## "No locks available" issue
49 | 
50 | In case you get this error, you can set the `HDF5_USE_FILE_LOCKING` environment variable, for instance at the beginning of your script:
51 | 
52 | ```python
53 | import os
54 | os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'
55 | ```
56 | 
57 | ## Tested on
58 | 
59 |     tensorflow-gpu==2.0.0
60 |     transformers==2.4.1
61 |     h5py==2.10.0
62 | 
63 | With Python 3.6 and 3.7, CUDA 10.0 and cuDNN 7.
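
To quickly check that your environment matches the versions above, you can print them from Python (an optional sanity check; nearby patch versions may also work):

```python
# Optional: verify the versions DeepStyle was tested on.
import tensorflow as tf
import transformers
import h5py

print(tf.__version__)            # expected: 2.0.0
print(transformers.__version__)  # expected: 2.4.1
print(h5py.__version__)          # expected: 2.10.0
```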
64 | 
65 | ## Command line demo
66 | 
67 | Here is a command line demo of DeepStyle on Ubuntu 20:
68 | 
69 |     cd ~/tmp
70 |     mkdir dbert-ft
71 |     cd dbert-ft
72 |     wget http://212.129.44.40/DeepStyle/dbert-ft/config.json
73 |     wget http://212.129.44.40/DeepStyle/dbert-ft/tf_model.h5
74 |     cd ../
75 |     conda create -n dbertft-env -y python=3.7 anaconda
76 |     conda activate dbertft-env
77 |     git clone https://github.com/hayj/deepstyle ; cd deepstyle ; pip uninstall deepstyle -y ; python setup.py install ; cd ../ ; rm -rf deepstyle
78 |     pip install --ignore-installed --upgrade tensorflow==2.0.0
79 |     pip install --ignore-installed --upgrade transformers==2.4.1
80 |     pip install --ignore-installed --upgrade h5py==2.10.0
81 |     ipython -c "from deepstyle.model import DeepStyle ; m = DeepStyle('dbert-ft') ; m.embed('Hello World')"
82 | 
83 | ## Citation
84 | 
85 | [Link to the publication](https://www.aclweb.org/anthology/2020.wnut-1.30.pdf)
86 | 
87 | > Julien Hay, Bich-Liên Doan, Fabrice Popineau, and Ouassim Ait Elhara. Representation learning of writing style. In Proceedings of the 6th Workshop on Noisy User-generated Text (W-NUT 2020), November 2020.
88 | 
89 | BibTeX format:
90 | 
91 |     @inproceedings{hay-2020-deepstyle,
92 |         title = "Representation learning of writing style",
93 |         author = "Hay, Julien and
94 |           Doan, Bich-Li\^{e}n and
95 |           Popineau, Fabrice and
96 |           Ait Elhara, Ouassim",
97 |         booktitle = "Proceedings of the 6th Workshop on Noisy User-generated Text (W-NUT 2020)",
98 |         month = nov,
99 |         year = "2020"
100 |     }
101 | 
--------------------------------------------------------------------------------
/deepstyle/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.0.1"
--------------------------------------------------------------------------------
/deepstyle/model.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | 
3 | import numpy as np
4 | import os
5 | from transformers import DistilBertConfig, DistilBertTokenizer, TFDistilBertForSequenceClassification
6 | import logging
7 | 
8 | 
9 | def chunks(*args, **kwargs):
10 |     return list(chunksYielder(*args, **kwargs))
11 | def chunksYielder(l, n):
12 |     """Yield successive n-sized chunks from l."""
13 |     if l is None:
14 |         return []
15 |     for i in range(0, len(l), n):
16 |         yield l[i:i + n]
17 | 
18 | def getDistilBertRepresentations\
19 | (
20 |     model,
21 |     inputs,
22 |     layer='distilbert', # distilbert, pre_classifier, dropout, classifier
23 | 
24 | ):
25 |     """
26 |     Return the representation of the given encoded inputs at the requested layer.
27 |     model is a TFDistilBertForSequenceClassification
28 |     See https://huggingface.co/transformers/_modules/transformers/modeling_tf_distilbert.html#TFDistilBertModel
29 |     """
30 |     distilbert_output = model.distilbert(inputs)
31 |     hidden_state = distilbert_output[0]
32 |     pooled_output = hidden_state[:, 0]
33 |     if layer == 'distilbert':
34 |         return pooled_output
35 |     pooled_output = model.pre_classifier(pooled_output)
36 |     if layer == 'pre_classifier':
37 |         return pooled_output
38 |     pooled_output = model.dropout(pooled_output, training=False)
39 |     if layer == 'dropout':
40 |         return pooled_output
41 |     logits = model.classifier(pooled_output)
42 |     if layer == 'classifier':
43 |         return logits
44 |     else:
45 |         raise Exception("Please choose a layer in ['distilbert', 'pre_classifier', 'dropout', 'classifier']")
46 | 
47 | 
48 | distilBertTokenizerSingleton = None
49 | def distilBertEncode\
50 | (
51 |     doc,
52 |     maxLength=512,
53 |     multiSamplage=False,
54 |     multiSamplageMinMaxLengthRatio=0.3,
55 |     bertStartIndex=101,
56 |     bertEndIndex=102,
57 |     preventTokenizerWarnings=False,
58 |     loggerName="transformers.tokenization_utils",
59 |     proxies=None,
60 |     logger=None,
61 |     verbose=True,
62 | ):
63 |     """
64 |     Return an encoded doc for DistilBert.
65 |     This function returns a list of document parts if you set multiSamplage to True.
66 |     """
67 |     # We set the logger level:
68 |     if preventTokenizerWarnings:
69 |         previousLoggerLevel = logging.getLogger(loggerName).level
70 |         logging.getLogger(loggerName).setLevel(logging.ERROR)
71 |     # We init the tokenizer:
72 |     global distilBertTokenizerSingleton
73 |     if distilBertTokenizerSingleton is None:
74 |         distilBertTokenizerSingleton = DistilBertTokenizer.from_pretrained('distilbert-base-uncased', proxies=proxies)
75 |     tokenizer = distilBertTokenizerSingleton
76 |     # We tokenize the doc:
77 |     if isinstance(doc, list):
78 |         doc = " ".join(doc)
79 |     doc = tokenizer.encode(doc, add_special_tokens=False)
80 |     # In case we want multiple parts:
81 |     if multiSamplage:
82 |         # We chunk the doc:
83 |         parts = chunks(doc, (maxLength - 2))
84 |         # We add special tokens (only one [CLS] at the beginning and one [SEP] at the end,
85 |         # even for entire documents):
86 |         parts = [[bertStartIndex] + part + [bertEndIndex] for part in parts]
87 |         # We remove the last part if it is too short:
88 |         if len(parts) > 1 and len(parts[-1]) < int(maxLength * multiSamplageMinMaxLengthRatio):
89 |             parts = parts[:-1]
90 |         # We pad the last part:
91 |         parts[-1] = parts[-1] + [0] * (maxLength - len(parts[-1]))
92 |         # We check the length of each part:
93 |         for part in parts:
94 |             assert len(part) == maxLength
95 |         # We reset the logger:
96 |         if preventTokenizerWarnings:
97 |             logging.getLogger(loggerName).setLevel(previousLoggerLevel)
98 |         return parts
99 |     # In case we have only one part:
100 |     else:
101 |         # We truncate the doc:
102 |         doc = doc[:(maxLength - 2)]
103 |         # We add special tokens:
104 |         doc = [bertStartIndex] + doc + [bertEndIndex]
105 |         # We pad the doc:
106 |         doc = doc + [0] * (maxLength - len(doc))
107 |         # We check the length:
108 |         assert len(doc) == maxLength
109 |         # We reset the logger:
110 |         if preventTokenizerWarnings:
111 |             logging.getLogger(loggerName).setLevel(previousLoggerLevel)
112 |         return doc
113 | 
114 | 
115 | class DeepStyle:
116 |     """
117 |     DeepStyle provides an interface to the DBert-ft model. It allows you to embed documents in a stylometric space.
118 |     """
119 |     def __init__(self, path, batchSize=16, layer='distilbert'):
120 |         """
121 |         path is the directory of the DBert-ft model
122 |         """
123 |         self.batchSize = batchSize
124 |         self.layer = layer
125 |         self.path = path
126 |         if not (os.path.isfile(self.path + "/config.json") and os.path.isfile(self.path + "/tf_model.h5")):
127 |             raise Exception("You need to provide the pretrained model directory that contains tf_model.h5 and config.json")
128 |         self.model = None
129 |         self.__load()
130 | 
131 |     def __load(self):
132 |         dbertConf = DistilBertConfig.from_pretrained(self.path + '/config.json')
133 |         self.model = TFDistilBertForSequenceClassification.from_pretrained\
134 |         (
135 |             self.path + '/tf_model.h5',
136 |             config=dbertConf,
137 |         )
138 | 
139 |     def embed(self, text):
140 |         """
141 |         Give raw text and get the style vector. If the text is longer than 512 wordpieces, it will be split and the style vector will be the mean of all part embeddings.
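        Minimal usage sketch (illustrative; the model folder path is an assumption, see the README for the real files):
            m = DeepStyle("/path/to/dbert-ft")
            vector = m.embed("Some raw text, possibly longer than 512 wordpieces.")
            # vector is a 1-D np.ndarray (768 values for the default 'distilbert' layer)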
142 | """ 143 | encodedText = distilBertEncode\ 144 | ( 145 | text, 146 | maxLength=512, 147 | multiSamplage=True, 148 | preventTokenizerWarnings=True, 149 | ) 150 | encodedBatches = chunks(encodedText, self.batchSize) 151 | embeddings = [] 152 | for encodedBatch in encodedBatches: 153 | outputs = getDistilBertRepresentations(self.model, np.array(encodedBatch), layer=self.layer) 154 | for output in outputs: 155 | embeddings.append(np.array(output)) 156 | return np.mean(embeddings, axis=0) -------------------------------------------------------------------------------- /experiments/metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hayj/DeepStyle/599cbf2fa1fa537070bd64d7ca92c72451e1404c/experiments/metrics/__init__.py -------------------------------------------------------------------------------- /experiments/metrics/metrics.py: -------------------------------------------------------------------------------- 1 | # From https://gist.github.com/bwhite/3726239 2 | 3 | """Information Retrieval metrics 4 | Useful Resources: 5 | http://www.cs.utexas.edu/~mooney/ir-course/slides/Evaluation.ppt 6 | http://www.nii.ac.jp/TechReports/05-014E.pdf 7 | http://www.stanford.edu/class/cs276/handouts/EvaluationNew-handout-6-per.pdf 8 | http://hal.archives-ouvertes.fr/docs/00/72/67/60/PDF/07-busa-fekete.pdf 9 | Learning to Rank for Information Retrieval (Tie-Yan Liu) 10 | """ 11 | import numpy as np 12 | 13 | 14 | def mean_reciprocal_rank(rs): 15 | """Score is reciprocal of the rank of the first relevant item 16 | First element is 'rank 1'. Relevance is binary (nonzero is relevant). 17 | Example from http://en.wikipedia.org/wiki/Mean_reciprocal_rank 18 | >>> rs = [[0, 0, 1], [0, 1, 0], [1, 0, 0]] 19 | >>> mean_reciprocal_rank(rs) 20 | 0.61111111111111105 21 | >>> rs = np.array([[0, 0, 0], [0, 1, 0], [1, 0, 0]]) 22 | >>> mean_reciprocal_rank(rs) 23 | 0.5 24 | >>> rs = [[0, 0, 0, 1], [1, 0, 0], [1, 0, 0]] 25 | >>> mean_reciprocal_rank(rs) 26 | 0.75 27 | Args: 28 | rs: Iterator of relevance scores (list or numpy) in rank order 29 | (first element is the first item) 30 | Returns: 31 | Mean reciprocal rank 32 | """ 33 | rs = (np.asarray(r).nonzero()[0] for r in rs) 34 | return np.mean([1. / (r[0] + 1) if r.size else 0. for r in rs]) 35 | 36 | 37 | def r_precision(r): 38 | """Score is precision after all relevant documents have been retrieved 39 | Relevance is binary (nonzero is relevant). 40 | >>> r = [0, 0, 1] 41 | >>> r_precision(r) 42 | 0.33333333333333331 43 | >>> r = [0, 1, 0] 44 | >>> r_precision(r) 45 | 0.5 46 | >>> r = [1, 0, 0] 47 | >>> r_precision(r) 48 | 1.0 49 | Args: 50 | r: Relevance scores (list or numpy) in rank order 51 | (first element is the first item) 52 | Returns: 53 | R Precision 54 | """ 55 | r = np.asarray(r) != 0 56 | z = r.nonzero()[0] 57 | if not z.size: 58 | return 0. 59 | return np.mean(r[:z[-1] + 1]) 60 | 61 | 62 | def precision_at_k(r, k): 63 | """Score is precision @ k 64 | Relevance is binary (nonzero is relevant). 65 | >>> r = [0, 0, 1] 66 | >>> precision_at_k(r, 1) 67 | 0.0 68 | >>> precision_at_k(r, 2) 69 | 0.0 70 | >>> precision_at_k(r, 3) 71 | 0.33333333333333331 72 | >>> precision_at_k(r, 4) 73 | Traceback (most recent call last): 74 | File "", line 1, in ? 
75 | ValueError: Relevance score length < k 76 | Args: 77 | r: Relevance scores (list or numpy) in rank order 78 | (first element is the first item) 79 | Returns: 80 | Precision @ k 81 | Raises: 82 | ValueError: len(r) must be >= k 83 | """ 84 | assert k >= 1 85 | r = np.asarray(r)[:k] != 0 86 | if r.size != k: 87 | raise ValueError('Relevance score length < k') 88 | return np.mean(r) 89 | 90 | 91 | def average_precision(r): 92 | """Score is average precision (area under PR curve) 93 | Relevance is binary (nonzero is relevant). 94 | >>> r = [1, 1, 0, 1, 0, 1, 0, 0, 0, 1] 95 | >>> delta_r = 1. / sum(r) 96 | >>> sum([sum(r[:x + 1]) / (x + 1.) * delta_r for x, y in enumerate(r) if y]) 97 | 0.7833333333333333 98 | >>> average_precision(r) 99 | 0.78333333333333333 100 | Args: 101 | r: Relevance scores (list or numpy) in rank order 102 | (first element is the first item) 103 | Returns: 104 | Average precision 105 | """ 106 | r = np.asarray(r) != 0 107 | out = [precision_at_k(r, k + 1) for k in range(r.size) if r[k]] 108 | if not out: 109 | return 0. 110 | return np.mean(out) 111 | 112 | 113 | def mean_average_precision(rs): 114 | """Score is mean average precision 115 | Relevance is binary (nonzero is relevant). 116 | >>> rs = [[1, 1, 0, 1, 0, 1, 0, 0, 0, 1]] 117 | >>> mean_average_precision(rs) 118 | 0.78333333333333333 119 | >>> rs = [[1, 1, 0, 1, 0, 1, 0, 0, 0, 1], [0]] 120 | >>> mean_average_precision(rs) 121 | 0.39166666666666666 122 | Args: 123 | rs: Iterator of relevance scores (list or numpy) in rank order 124 | (first element is the first item) 125 | Returns: 126 | Mean average precision 127 | """ 128 | return np.mean([average_precision(r) for r in rs]) 129 | 130 | 131 | def dcg_at_k(r, k, method=0): 132 | """Score is discounted cumulative gain (dcg) 133 | Relevance is positive real values. Can use binary 134 | as the previous methods. 135 | Example from 136 | http://www.stanford.edu/class/cs276/handouts/EvaluationNew-handout-6-per.pdf 137 | >>> r = [3, 2, 3, 0, 0, 1, 2, 2, 3, 0] 138 | >>> dcg_at_k(r, 1) 139 | 3.0 140 | >>> dcg_at_k(r, 1, method=1) 141 | 3.0 142 | >>> dcg_at_k(r, 2) 143 | 5.0 144 | >>> dcg_at_k(r, 2, method=1) 145 | 4.2618595071429155 146 | >>> dcg_at_k(r, 10) 147 | 9.6051177391888114 148 | >>> dcg_at_k(r, 11) 149 | 9.6051177391888114 150 | Args: 151 | r: Relevance scores (list or numpy) in rank order 152 | (first element is the first item) 153 | k: Number of results to consider 154 | method: If 0 then weights are [1.0, 1.0, 0.6309, 0.5, 0.4307, ...] 155 | If 1 then weights are [1.0, 0.6309, 0.5, 0.4307, ...] 156 | Returns: 157 | Discounted cumulative gain 158 | """ 159 | r = np.asfarray(r)[:k] 160 | if r.size: 161 | if method == 0: 162 | return r[0] + np.sum(r[1:] / np.log2(np.arange(2, r.size + 1))) 163 | elif method == 1: 164 | return np.sum(r / np.log2(np.arange(2, r.size + 2))) 165 | else: 166 | raise ValueError('method must be 0 or 1.') 167 | return 0. 168 | 169 | 170 | def ndcg_at_k(r, k, method=0): 171 | """Score is normalized discounted cumulative gain (ndcg) 172 | Relevance is positive real values. Can use binary 173 | as the previous methods. 
174 |     Example from
175 |     http://www.stanford.edu/class/cs276/handouts/EvaluationNew-handout-6-per.pdf
176 |     >>> r = [3, 2, 3, 0, 0, 1, 2, 2, 3, 0]
177 |     >>> ndcg_at_k(r, 1)
178 |     1.0
179 |     >>> r = [2, 1, 2, 0]
180 |     >>> ndcg_at_k(r, 4)
181 |     0.9203032077642922
182 |     >>> ndcg_at_k(r, 4, method=1)
183 |     0.96519546960144276
184 |     >>> ndcg_at_k([0], 1)
185 |     0.0
186 |     >>> ndcg_at_k([1], 2)
187 |     1.0
188 |     Args:
189 |         r: Relevance scores (list or numpy) in rank order
190 |             (first element is the first item)
191 |         k: Number of results to consider
192 |         method: If 0 then weights are [1.0, 1.0, 0.6309, 0.5, 0.4307, ...]
193 |                 If 1 then weights are [1.0, 0.6309, 0.5, 0.4307, ...]
194 |     Returns:
195 |         Normalized discounted cumulative gain
196 |     """
197 |     dcg_max = dcg_at_k(sorted(r, reverse=True), k, method)
198 |     if not dcg_max:
199 |         return 0.
200 |     return dcg_at_k(r, k, method) / dcg_max
201 | 
202 | 
203 | if __name__ == "__main__":
204 |     import doctest
205 |     doctest.testmod()
--------------------------------------------------------------------------------
/experiments/metrics/simrank.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from .metrics import *  # ndcg_at_k and dcg_at_k come from metrics.py in this package
3 | import numpy as np
4 | from sklearn.metrics import pairwise_distances
5 | def pairwiseCosineSimilarity(data):
6 |     """
7 |     This function computes the pairwise cosine similarity of a set of vectors.
8 |     n is the number of input vectors (the rows of data); the output matrix has size n*n.
9 |     """
10 |     return 1 - pairwise_distances(data, metric='cosine', n_jobs=-1)  # n_jobs=-1 uses all CPU cores
11 | 
12 | def similarityNDCG(vectors, labels, returnSimMatrix=False, logger=None, verbose=True):
13 |     """
14 |     This function takes vector representations of documents (so a matrix).
15 |     vectors[0] is the first document and is a vector, for example [1.2, 5.3, -2.4, ..]
16 |     labels are class identifiers of each document.
17 |     Can be strings, for example the author of an article:
18 |     ["author1", "author1", "author2", "author3", "author3", "author3", ...]
19 |     This function returns the averaged nDCG score over all columns of the
20 |     cosine similarity matrix of vectors: for a better understanding of this,
21 |     see the functions `pairwiseCosineSimilarity` and `pairwiseSimNDCG`.
22 |     Set returnSimMatrix to True to get both the generated matrix and the score in a tuple.
23 |     """
24 |     mtx = pairwiseCosineSimilarity(vectors)
25 |     score = pairwiseSimNDCG(mtx, labels, logger=logger, verbose=verbose)
26 |     if returnSimMatrix:
27 |         return (mtx, score)
28 |     else:
29 |         return score
30 | 
31 | def pairwiseSimNDCG(simMatrix, labels, logger=None, verbose=True, useNNDCG=True):
32 |     """
33 |     This function takes a similarity matrix n*n (which is a symmetric matrix)
34 |     with 1.0 on the diagonal.
35 |     It takes labels, which are class identifiers of each document (row/column).
36 |     Can be strings, for example the author of an article:
37 |     ["author1", "author1", "author2", "author3", "author3", "author3"]
38 |     It returns the nDCG at k (with k = n) averaged over all columns, or its normalized variant (see `nndcg`) when useNNDCG is True.
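    Illustrative sketch (hypothetical values, not taken from the experiments):
        vectors = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]])
        labels = ["author1", "author1", "author2"]
        simMatrix = pairwiseCosineSimilarity(vectors)  # shape (3, 3), 1.0 on the diagonal
        score = pairwiseSimNDCG(simMatrix, labels)     # averaged score in [0, 1]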
39 | """ 40 | if useNNDCG: 41 | ndcgFunct = nndcg 42 | else: 43 | ndcgFunct = ndcg 44 | labels = pd.factorize(labels)[0] 45 | def rankLabels(col): 46 | # col = np.array([row[0] + row[1], row[2] + row[3]]) 47 | col = np.vstack([col, labels]) 48 | # https://stackoverflow.com/questions/2828059/sorting-arrays-in-numpy-by-column 49 | # argsort() to get indexes and sort, [::-1] for reverse 50 | col = col[:, col[0,:].argsort()[::-1]] 51 | col = col[1] 52 | return col 53 | simMatrix = np.apply_along_axis(rankLabels, 0, simMatrix) 54 | nDCGs = [] 55 | for x in range(simMatrix.shape[0]): 56 | col = simMatrix[:, x] 57 | label = labels[x] 58 | for y in range(len(col)): 59 | col[y] = col[y] == label 60 | nDCGs.append(ndcgFunct(col)) 61 | return np.average(nDCGs) 62 | 63 | 64 | 65 | def ndcg(r, method=0): 66 | """ 67 | This function return the nDCG at k with k = len(r) 68 | """ 69 | return ndcg_at_k(r, len(r), method=method) 70 | 71 | def nndcg(r, method=0): 72 | k = len(r) 73 | idcg = dcg_at_k(sorted(r, reverse=True), k, method) 74 | if not idcg: 75 | return 0. 76 | wdcg = dcg_at_k(sorted(r, reverse=False), k, method) 77 | if not wdcg: 78 | return 0. 79 | dcg = dcg_at_k(r, k, method) 80 | return (dcg - wdcg) / (idcg - wdcg) 81 | 82 | -------------------------------------------------------------------------------- /experiments/notebooks/authorship-attribution.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Commands" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "# cd ~/asa2-aa-logs ; jupython --no-tail --venv st-venv -o nohup-asa2-aa-$HOSTNAME.out ~/notebooks/asa/eval/asa2-aa.ipynb\n", 17 | "# observe ~/asa2-aa-logs/nohup-asa2-aa-$HOSTNAME.out" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "metadata": {}, 37 | "source": [ 38 | "# Init" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "isNotebook = '__file__' not in locals()" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "TEST = False # isNotebook, False, True" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "import os\n", 66 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "# Imports" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "from newssource.asattribution.utils import *\n", 97 | "from newssource.asattribution.asamin import *\n", 98 | "from newssource.asa.asapreproc import *\n", 99 | "from newssource.asa.models import *\n", 100 | "from 
newssource.metrics.ndcg import *" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "import matplotlib\n", 110 | "import numpy as np\n", 111 | "import matplotlib.pyplot as plt\n", 112 | "if not isNotebook:\n", 113 | " matplotlib.use('Agg')\n", 114 | "import random\n", 115 | "import time\n", 116 | "import pickle\n", 117 | "import copy\n", 118 | "from hashlib import md5" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "metadata": { 125 | "scrolled": false 126 | }, 127 | "outputs": [], 128 | "source": [ 129 | "from keras import backend as K\n", 130 | "K.tensorflow_backend._get_available_gpus()" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "from numpy import array\n", 140 | "from keras.preprocessing.text import one_hot\n", 141 | "from keras.preprocessing.sequence import pad_sequences\n", 142 | "from keras.models import Sequential\n", 143 | "from keras.layers import Dense, LSTM, Flatten\n", 144 | "from keras.layers.embeddings import Embedding\n", 145 | "from keras.models import load_model\n", 146 | "from keras.utils import multi_gpu_model" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "from sklearn.model_selection import train_test_split\n", 156 | "from sklearn.feature_extraction.text import TfidfVectorizer\n", 157 | "from sklearn.naive_bayes import MultinomialNB\n", 158 | "from sklearn import metrics\n", 159 | "from sklearn.model_selection import cross_val_score\n", 160 | "from sklearn.dummy import DummyClassifier\n", 161 | "from sklearn.model_selection import KFold, StratifiedKFold" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": null, 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "from gensim.test.utils import common_texts\n", 171 | "from gensim.models.doc2vec import Doc2Vec, TaggedDocument" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "metadata": {}, 178 | "outputs": [], 179 | "source": [ 180 | "from random import random\n", 181 | "from numpy import array\n", 182 | "from numpy import cumsum\n", 183 | "from keras.models import Sequential\n", 184 | "from keras.layers import LSTM, GRU\n", 185 | "from keras.layers import Dense\n", 186 | "from keras.layers import TimeDistributed\n", 187 | "from keras.utils import multi_gpu_model" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | "import statistics" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": null, 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [ 205 | "from machinelearning.baseline import *\n", 206 | "from machinelearning.encoder import *\n", 207 | "from machinelearning.kerasutils import *\n", 208 | "from machinelearning.kerasmodels import *" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": null, 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "from machinelearning.baseline import *\n", 218 | "from machinelearning.encoder import *\n", 219 | "from machinelearning import kerasutils\n", 220 | "from machinelearning.iterator import *\n", 221 | "from machinelearning.metrics import *" 222 | ] 223 | }, 224 | { 225 | "cell_type": 
"code", 226 | "execution_count": null, 227 | "metadata": {}, 228 | "outputs": [], 229 | "source": [ 230 | "from keras.layers import concatenate, Input\n", 231 | "from keras.models import Model\n", 232 | "from keras.utils import plot_model" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": null, 238 | "metadata": {}, 239 | "outputs": [], 240 | "source": [ 241 | "import scipy" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": null, 247 | "metadata": {}, 248 | "outputs": [], 249 | "source": [ 250 | "from deepstyle.model import *" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": null, 256 | "metadata": {}, 257 | "outputs": [], 258 | "source": [ 259 | "from sklearn.decomposition import TruncatedSVD\n", 260 | "from sklearn.random_projection import sparse_random_matrix" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": null, 266 | "metadata": {}, 267 | "outputs": [], 268 | "source": [ 269 | "import matplotlib\n", 270 | "import numpy as np\n", 271 | "import matplotlib.pyplot as plt\n", 272 | "if not isNotebook:\n", 273 | " matplotlib.use('Agg')\n", 274 | "from sklearn.decomposition import TruncatedSVD\n", 275 | "from sklearn.random_projection import sparse_random_matrix\n", 276 | "from sklearn import svm\n", 277 | "from sklearn.model_selection import train_test_split\n", 278 | "from sklearn import linear_model" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": null, 284 | "metadata": {}, 285 | "outputs": [], 286 | "source": [ 287 | "# pip install Cython ; git clone https://github.com/epfml/sent2vec.git ; cd ./sent2vec ; pip install ." 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": null, 293 | "metadata": {}, 294 | "outputs": [], 295 | "source": [ 296 | "from newssource.dbert.utils import *" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": null, 302 | "metadata": {}, 303 | "outputs": [], 304 | "source": [ 305 | "from nlptools.topicmodeling import *" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": null, 311 | "metadata": {}, 312 | "outputs": [], 313 | "source": [] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": null, 318 | "metadata": {}, 319 | "outputs": [], 320 | "source": [] 321 | }, 322 | { 323 | "cell_type": "markdown", 324 | "metadata": {}, 325 | "source": [ 326 | "# Config" 327 | ] 328 | }, 329 | { 330 | "cell_type": "code", 331 | "execution_count": null, 332 | "metadata": {}, 333 | "outputs": [], 334 | "source": [ 335 | "logger = Logger(tmpDir(\"logs\") + \"/asa2-aa-\" + getHostname() + \".log\")" 336 | ] 337 | }, 338 | { 339 | "cell_type": "code", 340 | "execution_count": null, 341 | "metadata": {}, 342 | "outputs": [], 343 | "source": [ 344 | "tt = TicToc(logger=logger)\n", 345 | "tt.tic()" 346 | ] 347 | }, 348 | { 349 | "cell_type": "code", 350 | "execution_count": null, 351 | "metadata": {}, 352 | "outputs": [], 353 | "source": [ 354 | "files = \\\n", 355 | "[\n", 356 | " 'uset0-l50-dpl50-d18-bc10',\n", 357 | " 'uset0-l50-dpl50-blogger.com',\n", 358 | " 'uset0-l50-dpl50-breitbart.com',\n", 359 | " 'uset0-l50-dpl50-businessinsider.com',\n", 360 | " 'uset0-l50-dpl50-cnn.com',\n", 361 | " 'uset0-l50-dpl50-guardian.co.uk',\n", 362 | " 'uset0-l50-dpl50-livejournal.com',\n", 363 | " 'uset0-l50-dpl50-nytimes.com',\n", 364 | " 'uset0-l50-dpl50-theguardian.com',\n", 365 | " 'uset0-l50-dpl50-washingtonpost.com',\n", 366 | " 'uset1-l50-dpl50-blogger.com',\n", 
367 | " 'uset1-l50-dpl50-d18-bc10',\n", 368 | " 'uset1-l50-dpl50-livejournal.com',\n", 369 | " 'uset2-l50-dpl50-blogger.com',\n", 370 | " 'uset2-l50-dpl50-d18-bc10',\n", 371 | " 'uset2-l50-dpl50-livejournal.com',\n", 372 | " 'uset3-l50-dpl50-blogger.com',\n", 373 | " 'uset3-l50-dpl50-d18-bc10',\n", 374 | " 'uset3-l50-dpl50-livejournal.com',\n", 375 | " 'uset4-l50-dpl50-blogger.com',\n", 376 | " 'uset4-l50-dpl50-d18-bc10',\n", 377 | " 'uset4-l50-dpl50-livejournal.com',\n", 378 | "]\n", 379 | "files = [nosaveDir() + \"/Data/Asa2/detok-usets/\" + e + \"/0.ndjson.bz2\" for e in files]\n", 380 | "bp(files, 5, logger)" 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": null, 386 | "metadata": {}, 387 | "outputs": [], 388 | "source": [ 389 | "tipisNumbers = \"85 86 87 92 95 59 56 58 57 84 83 82 81 88 63 03 93 07 06 90 60 62 61 01 02 89 94\".split()\n", 390 | "tipis = ['tipi' + e for e in sorted(tipisNumbers)]\n", 391 | "bp(tipis, logger)" 392 | ] 393 | }, 394 | { 395 | "cell_type": "code", 396 | "execution_count": null, 397 | "metadata": {}, 398 | "outputs": [], 399 | "source": [ 400 | "# We get the uset for the current tipi:\n", 401 | "association = associate(tipis, files)\n", 402 | "bp(association, 5, logger)\n", 403 | "if not \"tipi\" in getHostname() or isDocker() or isHostname(\"titanv\"):\n", 404 | " file = association['tipi07']\n", 405 | "else:\n", 406 | " assert getHostname() in association\n", 407 | " file = association[getHostname()]\n", 408 | "uset = file.split('/')[-2]" 409 | ] 410 | }, 411 | { 412 | "cell_type": "code", 413 | "execution_count": null, 414 | "metadata": {}, 415 | "outputs": [], 416 | "source": [ 417 | "# Parameters for results:\n", 418 | "dataCol = \"filtered_sentences\" # filtered_sentences, sentences" 419 | ] 420 | }, 421 | { 422 | "cell_type": "code", 423 | "execution_count": null, 424 | "metadata": {}, 425 | "outputs": [], 426 | "source": [ 427 | "# Other parameters:\n", 428 | "randomTestsAmount = 0\n", 429 | "docLength = 1200" 430 | ] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "execution_count": null, 435 | "metadata": {}, 436 | "outputs": [], 437 | "source": [ 438 | "# Logging:\n", 439 | "log(getHostname() + \" handles \" + uset, logger)\n", 440 | "log(\"dataCol: \" + str(dataCol), logger)" 441 | ] 442 | }, 443 | { 444 | "cell_type": "code", 445 | "execution_count": null, 446 | "metadata": {}, 447 | "outputs": [], 448 | "source": [ 449 | "# Data:\n", 450 | "hashs = []\n", 451 | "docs = []\n", 452 | "labels = []\n", 453 | "flatDocs = []\n", 454 | "flatLowDocs = []\n", 455 | "flatTruncDocs = []\n", 456 | "flatTruncLowDocs = []\n", 457 | "detokDocs = []\n", 458 | "detokSentences = []\n", 459 | "for row in NDJson(file):\n", 460 | " detokSentences.append(row['filtered_detokenized_sentences'])\n", 461 | " detokDocs.append(row['filtered_detokenized'])\n", 462 | " sentences = row[dataCol]\n", 463 | " theHash = objectToHash(sentences)\n", 464 | " hashs.append(theHash)\n", 465 | " docs.append(sentences)\n", 466 | " labels.append(row[\"label\"])\n", 467 | " flattenedDoc = flattenLists(sentences)\n", 468 | " flatDocs.append(flattenedDoc)\n", 469 | " flatLowDocs.append([e.lower() for e in flattenedDoc])\n", 470 | " truncatedDoc = flattenedDoc[:docLength]\n", 471 | " flatTruncDocs.append(truncatedDoc)\n", 472 | " flatTruncLowDocs.append([e.lower() for e in truncatedDoc])\n", 473 | "bp(docs, logger)\n", 474 | "tt.tic(\"Got documents\")" 475 | ] 476 | }, 477 | { 478 | "cell_type": "code", 479 | "execution_count": null, 480 | "metadata": {}, 481 | 
"outputs": [], 482 | "source": [ 483 | "(classes, indexLabels) = encodeMulticlassLabels(labels, encoding='index')\n", 484 | "bp(indexLabels, logger)" 485 | ] 486 | }, 487 | { 488 | "cell_type": "code", 489 | "execution_count": null, 490 | "metadata": {}, 491 | "outputs": [], 492 | "source": [] 493 | }, 494 | { 495 | "cell_type": "code", 496 | "execution_count": null, 497 | "metadata": {}, 498 | "outputs": [], 499 | "source": [] 500 | }, 501 | { 502 | "cell_type": "markdown", 503 | "metadata": {}, 504 | "source": [ 505 | "# Results" 506 | ] 507 | }, 508 | { 509 | "cell_type": "code", 510 | "execution_count": null, 511 | "metadata": {}, 512 | "outputs": [], 513 | "source": [ 514 | "(user, password, host) = getOctodsMongoAuth()" 515 | ] 516 | }, 517 | { 518 | "cell_type": "code", 519 | "execution_count": null, 520 | "metadata": {}, 521 | "outputs": [], 522 | "source": [] 523 | }, 524 | { 525 | "cell_type": "code", 526 | "execution_count": null, 527 | "metadata": {}, 528 | "outputs": [], 529 | "source": [] 530 | }, 531 | { 532 | "cell_type": "markdown", 533 | "metadata": {}, 534 | "source": [ 535 | "# Features:" 536 | ] 537 | }, 538 | { 539 | "cell_type": "code", 540 | "execution_count": null, 541 | "metadata": {}, 542 | "outputs": [], 543 | "source": [ 544 | "featuresCache = dict()\n", 545 | "def getAsaFeatures2\\\n", 546 | "(\n", 547 | " docs,\n", 548 | " flatDocs,\n", 549 | " flatLowDocs,\n", 550 | " flatTruncDocs,\n", 551 | " flatTruncLowDocs,\n", 552 | " detokDocs,\n", 553 | " detokSentences,\n", 554 | " \n", 555 | " dataHash=None,\n", 556 | " \n", 557 | " useNMF=False,\n", 558 | " nmfKwargs=None,\n", 559 | " \n", 560 | " useLDA=False,\n", 561 | " ldaKwargs=None,\n", 562 | " \n", 563 | " useDbert=False,\n", 564 | " dbertKwargs=None, # operation, layer, modelName\n", 565 | " \n", 566 | " useDeepStyle=False,\n", 567 | " deepStyleKwargs=None,\n", 568 | " deepStyleRoot=nosaveDir() + \"/asa2-train\",\n", 569 | " \n", 570 | " useTFIDF=False,\n", 571 | " tfidfKwargs=None,\n", 572 | " defaultTFIDFNIter=30,\n", 573 | " \n", 574 | " useDoc2Vec=False,\n", 575 | " doc2VecKwargs=None,\n", 576 | " d2vPath=nosaveDir() + \"/d2v/d2vmodel-t-ds22.02g-s300-w3-n15-e15-lTrue-adFalse-7bb8a\",\n", 577 | " \n", 578 | " useStylo=False,\n", 579 | " styloKwargs=None,\n", 580 | " \n", 581 | " useUsent=False,\n", 582 | " usentKwargs=None,\n", 583 | " usentEmbeddingsPattern=nosaveDir() + \"/usent/usentEmbedding*.pickle\",\n", 584 | " \n", 585 | " useInferSent=False,\n", 586 | " inferSentKwargs=None,\n", 587 | " inferSentRoot=nosaveDir() + '/infersent',\n", 588 | " \n", 589 | " useBERT=False,\n", 590 | " bertKwargs=None,\n", 591 | " \n", 592 | " useSent2Vec=False,\n", 593 | " sent2VecKwargs=None,\n", 594 | " sent2VecRoot=nosaveDir() + '/sent2vec',\n", 595 | " defaultSent2VecModelName=\"wiki_unigrams.bin\",\n", 596 | "\n", 597 | " logger=None,\n", 598 | " verbose=True,\n", 599 | "):\n", 600 | " global featuresCache\n", 601 | " if dataHash is None:\n", 602 | " logWarning(\"Please provide a data hash to prevent its computation each call\", logger)\n", 603 | " dataHash = objectToHash([docs, flatDocs, flatLowDocs, flatTruncDocs, flatTruncLowDocs, detokDocs, detokSentences])\n", 604 | " features = []\n", 605 | " # NMF:\n", 606 | " if useNMF:\n", 607 | " h = objectToHash(['NMF', nmfKwargs, dataHash])\n", 608 | " if h in featuresCache:\n", 609 | " features.append(featuresCache[h])\n", 610 | " else:\n", 611 | " data = nmfFeatures(flatDocs, **nmfKwargs)\n", 612 | " features.append(data)\n", 613 | " featuresCache[h] = data\n", 614 
| " # LDA:\n", 615 | " if useLDA:\n", 616 | " h = objectToHash(['LDA', ldaKwargs, dataHash])\n", 617 | " if h in featuresCache:\n", 618 | " features.append(featuresCache[h])\n", 619 | " else:\n", 620 | " data = ldaFeatures(flatDocs, **ldaKwargs)\n", 621 | " features.append(data)\n", 622 | " featuresCache[h] = data\n", 623 | " # DBert:\n", 624 | " if useDbert:\n", 625 | " h = objectToHash(['DBert', dbertKwargs, dataHash])\n", 626 | " if h in featuresCache:\n", 627 | " features.append(featuresCache[h])\n", 628 | " else:\n", 629 | " modelName = dbertKwargs['modelName']\n", 630 | " if modelName is None:\n", 631 | " modelPath = None\n", 632 | " else:\n", 633 | " if isDir(nosaveDir() + '/dbert-train/' + modelName):\n", 634 | " modelPath = sortedGlob(nosaveDir() + '/dbert-train/' + modelName + '/epochs/ep*')[-1]\n", 635 | " else:\n", 636 | " modelPath = sortedGlob(nosaveDir() + '/dbert-tmp/' + modelName + '/epochs/ep*')[-1]\n", 637 | " layer = dbertKwargs['layer']\n", 638 | " # if modelName is None:\n", 639 | " # (user, password, host) = getOctodsMongoAuth()\n", 640 | " #  dbertCache = SerializableDict(\"dbert-embeddings\",\n", 641 | " # user=user, host=host, password=password,\n", 642 | " # useMongodb=True, logger=logger)\n", 643 | " # else:\n", 644 | " dbertCache = SerializableDict(\"dbert-embeddings-\" + str(modelName),\n", 645 | " nosaveDir() + '/dbert-cache',\n", 646 | " useMongodb=False, logger=logger,\n", 647 | " loadRetry=30, loadSleepMin=0.5, loadSleepMax=30,\n", 648 | " readIsAnAction=False,)\n", 649 | " embeddings = []\n", 650 | " for doc in docs:\n", 651 | " if modelPath is None:\n", 652 | " layer = 'distilbert'\n", 653 | " (currentHash, cacheObject) = getDbertEmbeddingsHash(doc, layer, modelPath)\n", 654 | " if modelPath is None:\n", 655 | " layer = None\n", 656 | " assert currentHash in dbertCache\n", 657 | " embeddings.append(dbertCache[currentHash]['embeddings'])\n", 658 | " dbertData = []\n", 659 | " if dbertKwargs['operation'] == \"first\":\n", 660 | " for emb in embeddings:\n", 661 | " dbertData.append(emb[0])\n", 662 | " elif dbertKwargs['operation'] == \"mean\":\n", 663 | " for emb in embeddings:\n", 664 | " dbertData.append(np.mean(emb, axis=0))\n", 665 | " dbertData = np.array(dbertData)\n", 666 | " features.append(dbertData)\n", 667 | " featuresCache[h] = dbertData\n", 668 | " # DeepStyle:\n", 669 | " if useDeepStyle:\n", 670 | " h = objectToHash(['DeepStyle', deepStyleKwargs, dataHash])\n", 671 | " if h in featuresCache:\n", 672 | " features.append(featuresCache[h])\n", 673 | " else:\n", 674 | " m = DeepStyle(deepStyleRoot + \"/\" + deepStyleKwargs['modelPattern'])\n", 675 | " embeddings = np.array([m.embed(doc) for doc in docs])\n", 676 | " features.append(embeddings)\n", 677 | " featuresCache[h] = embeddings\n", 678 | " # TFIDF:\n", 679 | " if useTFIDF:\n", 680 | " if \"nIter\" not in tfidfKwargs:\n", 681 | " tfidfKwargs[\"nIter\"] = defaultTFIDFNIter\n", 682 | " h = objectToHash(['TFIDF', tfidfKwargs, dataHash])\n", 683 | " if h in featuresCache:\n", 684 | " features.append(featuresCache[h])\n", 685 | " else:\n", 686 | " data = flatTruncDocs if tfidfKwargs['truncate'] else docs\n", 687 | " tfidfInstance = TFIDF(data, doLower=tfidfKwargs['doLower'], logger=logger, verbose=False)\n", 688 | " tfidfData = tfidfInstance.getTFIDFMatrix()\n", 689 | " svd = TruncatedSVD(n_components=tfidfKwargs['nComponents'],\n", 690 | " n_iter=tfidfKwargs['nIter'],\n", 691 | " random_state=42)\n", 692 | " svdTFIDFData = svd.fit_transform(tfidfData)\n", 693 | " 
features.append(svdTFIDFData)\n", 694 | " featuresCache[h] = svdTFIDFData\n", 695 | " # Doc2Vec:\n", 696 | " if useDoc2Vec:\n", 697 | " h = objectToHash(['Doc2Vec', doc2VecKwargs, dataHash])\n", 698 | " if h in featuresCache:\n", 699 | " features.append(featuresCache[h])\n", 700 | " else:\n", 701 | " data = flatTruncLowDocs if doc2VecKwargs['truncate'] else flatLowDocs\n", 702 | " d2vModel = Doc2Vec.load(sortedGlob(d2vPath + \"/*model*.d2v\")[0])\n", 703 | " d2vData = d2vTokenssToEmbeddings(data, d2vModel, logger=logger, verbose=False)\n", 704 | " features.append(d2vData)\n", 705 | " featuresCache[h] = d2vData\n", 706 | " # Stylo:\n", 707 | " if useStylo:\n", 708 | " h = objectToHash(['Stylo', styloKwargs, dataHash])\n", 709 | " if h in featuresCache:\n", 710 | " features.append(featuresCache[h])\n", 711 | " else:\n", 712 | " styloVectors = []\n", 713 | " for text in pb(detokDocs, logger=logger, message=\"Getting stylo features\", verbose=verbose):\n", 714 | " styloVectors.append(stylo(text, asNpArray=True))\n", 715 | " styloVectors = np.array(styloVectors)\n", 716 | " features.append(styloVectors)\n", 717 | " featuresCache[h] = styloVectors\n", 718 | " # Usent:\n", 719 | " if useUsent:\n", 720 | " h = objectToHash(['Usent', usentKwargs, dataHash])\n", 721 | " if h in featuresCache:\n", 722 | " features.append(featuresCache[h])\n", 723 | " else:\n", 724 | " # We get all embeddgins from usent:\n", 725 | " allHashes = set()\n", 726 | " for doc in docs:\n", 727 | " docHash = objectToHash(doc)\n", 728 | " allHashes.add(docHash)\n", 729 | " for sentence in doc:\n", 730 | " theHash = objectToHash(sentence)\n", 731 | " allHashes.add(theHash)\n", 732 | " usentEmbeddings = dict()\n", 733 | " for usentEmbeddingsFile in pb(sortedGlob(usentEmbeddingsPattern),\n", 734 | " printRatio=0.1, logger=logger, message=\"Getting Usent embeddings from all files\"):\n", 735 | " current = deserialize(usentEmbeddingsFile)\n", 736 | " for theHash, value in current.items():\n", 737 | " if theHash in allHashes:\n", 738 | " usentEmbeddings[theHash] = value\n", 739 | " assert len(allHashes) == len(usentEmbeddings)\n", 740 | " if usentKwargs['operation'] == \"full\":\n", 741 | " data = []\n", 742 | " for doc in docs:\n", 743 | " theHash = objectToHash(doc)\n", 744 | " data.append(usentEmbeddings[theHash])\n", 745 | " data = np.array(data)\n", 746 | " elif usentKwargs['operation'] == \"mean\":\n", 747 | " data = []\n", 748 | " for doc in docs:\n", 749 | " docEmbeddings = []\n", 750 | " for sentence in doc:\n", 751 | " theHash = objectToHash(sentence)\n", 752 | " docEmbeddings.append(usentEmbeddings[theHash])\n", 753 | " docEmbedding = np.mean(docEmbeddings, axis=0)\n", 754 | " data.append(docEmbedding)\n", 755 | " data = np.array(data)\n", 756 | " features.append(data)\n", 757 | " featuresCache[h] = data\n", 758 | " # InferSent:\n", 759 | " if useInferSent:\n", 760 | " h = objectToHash(['InferSent', inferSentKwargs, dataHash])\n", 761 | " if h in featuresCache:\n", 762 | " features.append(featuresCache[h])\n", 763 | " else:\n", 764 | " V = inferSentKwargs['V']\n", 765 | " operation = inferSentKwargs['operation']\n", 766 | " MODEL_PATH = inferSentRoot + '/infersent%s.pkl' % V\n", 767 | " params_model = {'bsize': 64, 'word_emb_dim': 300, 'enc_lstm_dim': 2048,\n", 768 | " 'pool_type': 'max', 'dpout_model': 0.0, 'version': V}\n", 769 | " infersent = InferSent(params_model)\n", 770 | " infersent.load_state_dict(torch.load(MODEL_PATH))\n", 771 | " if V == 2:\n", 772 | " W2V_PATH = inferSentRoot + 
'/fastText/crawl-300d-2M.vec'\n", 773 | " else:\n", 774 | " W2V_PATH = inferSentRoot + '/GloVe/glove.840B.300d.txt'\n", 775 | " infersent.set_w2v_path(W2V_PATH)\n", 776 | " infersent.build_vocab(detokDocs, tokenize=True)\n", 777 | " if operation == \"full\":\n", 778 | " detokDocsForInferSent = []\n", 779 | " for current in detokDocs:\n", 780 | " detokDocsForInferSent.append(current[:10000])\n", 781 | " isData = infersent.encode(detokDocsForInferSent, tokenize=True)\n", 782 | " elif operation == \"mean\":\n", 783 | " isData = []\n", 784 | " for currentDetokSentences in detokSentences:\n", 785 | " embedding = np.mean(infersent.encode(currentDetokSentences, tokenize=True), axis=0)\n", 786 | " isData.append(embedding)\n", 787 | " isData = np.array(isData)\n", 788 | " features.append(isData)\n", 789 | " featuresCache[h] = isData\n", 790 | " # BERT:\n", 791 | " if useBERT:\n", 792 | " h = objectToHash(['BERT', bertKwargs, dataHash])\n", 793 | " if h in featuresCache:\n", 794 | " features.append(featuresCache[h])\n", 795 | " else:\n", 796 | " (user, password, host) = getOctodsMongoAuth()\n", 797 | " bertCache = SerializableDict(\"newsid-bretcache\",\n", 798 | " useMongodb=True,\n", 799 | " user=user, password=password, host=host,\n", 800 | " logger=logger)\n", 801 | " bertData = []\n", 802 | " for doc in docs:\n", 803 | " theHash = objectToHash(doc)\n", 804 | " current = bertCache[theHash]\n", 805 | " bertData.append(current)\n", 806 | " bertData = np.array(bertData)\n", 807 | " features.append(bertData)\n", 808 | " featuresCache[h] = bertData\n", 809 | " # Sent2Vec:\n", 810 | " if useSent2Vec:\n", 811 | " if sent2VecKwargs is None:\n", 812 | " sent2VecKwargs = dict()\n", 813 | " if 'modelName' not in sent2VecKwargs:\n", 814 | " sent2VecKwargs['modelName'] = defaultSent2VecModelName\n", 815 | " h = objectToHash(['Sent2Vec', sent2VecKwargs, dataHash])\n", 816 | " if h in featuresCache:\n", 817 | " features.append(featuresCache[h])\n", 818 | " else:\n", 819 | " modelName = sent2VecKwargs['modelName']\n", 820 | " s2vModel = sent2vec.Sent2vecModel()\n", 821 | " s2vModel.load_model(sent2VecRoot + '/' + modelName)\n", 822 | " operation = sent2VecKwargs['operation']\n", 823 | " if operation == \"full\":\n", 824 | " s2vData = s2vModel.embed_sentences(detokDocs)\n", 825 | " elif operation == \"mean\":\n", 826 | " s2vData = []\n", 827 | " for currentDetokSentences in detokSentences:\n", 828 | " embedding = np.mean(s2vModel.embed_sentences(currentDetokSentences), axis=0)\n", 829 | " s2vData.append(embedding)\n", 830 | " s2vData = np.array(s2vData)\n", 831 | " try:\n", 832 | " s2vModel.release_shared_mem(sent2VecRoot + '/' + modelName)\n", 833 | " s2vModel = None\n", 834 | " except Exception as e:\n", 835 | " logException(e, logger)\n", 836 | " features.append(s2vData)\n", 837 | " featuresCache[h] = s2vData\n", 838 | " # Concatenation:\n", 839 | " features = np.concatenate(features, axis=1)\n", 840 | " return features" 841 | ] 842 | }, 843 | { 844 | "cell_type": "code", 845 | "execution_count": null, 846 | "metadata": {}, 847 | "outputs": [], 848 | "source": [] 849 | }, 850 | { 851 | "cell_type": "code", 852 | "execution_count": null, 853 | "metadata": {}, 854 | "outputs": [], 855 | "source": [] 856 | }, 857 | { 858 | "cell_type": "markdown", 859 | "metadata": {}, 860 | "source": [ 861 | "# Execution" 862 | ] 863 | }, 864 | { 865 | "cell_type": "code", 866 | "execution_count": null, 867 | "metadata": {}, 868 | "outputs": [], 869 | "source": [ 870 | "def getDoubleCombinasons():\n", 871 | " for i in 
range(len(uniqueUses)):\n", 872 | " for u in range(i + 1, len(uniqueUses)):\n", 873 | " yield (uniqueUses[i], uniqueUses[u])" 874 | ] 875 | }, 876 | { 877 | "cell_type": "code", 878 | "execution_count": null, 879 | "metadata": {}, 880 | "outputs": [], 881 | "source": [ 882 | "def getScore(features, labels, validationRatio=0.3, doFigShow=isNotebook, logger=None, verbose=True):\n", 883 | " clf = linear_model.SGDClassifier()\n", 884 | " xTrain, xTest, yTrain, yTest = train_test_split(features, labels, test_size=validationRatio, random_state=42)\n", 885 | " scores = scikitLearnFit(clf, xTrain, yTrain, xTest, yTest, doFigShow=doFigShow, doFigSave=False, logger=logger)\n", 886 | " log(\"Features shape: \" + str(features.shape), logger=logger, verbose=verbose)\n", 887 | " return max(scores)" 888 | ] 889 | }, 890 | { 891 | "cell_type": "code", 892 | "execution_count": null, 893 | "metadata": {}, 894 | "outputs": [], 895 | "source": [ 896 | "uniqueUses = ['useNMF', 'useLDA', 'useDbert', 'useTFIDF', 'useStylo', 'useDeepStyle', 'useBERT', 'useDoc2Vec', 'useUsent', 'useInferSent', 'useSent2Vec']\n", 897 | "allKwargs = \\\n", 898 | "{\n", 899 | " 'nmfKwargs': {},\n", 900 | " 'ldaKwargs': {},\n", 901 | " # 'dbertKwargs': {'operation': 'mean', 'layer': 'distilbert', 'modelName': '94bef_ep32'},\n", 902 | " 'dbertKwargs': {'operation': 'mean', 'layer': None, 'modelName': None},\n", 903 | " 'deepStyleKwargs': {'modelPattern': '6ebdd3e05d4388c658ca2d5c53b0bc36'},\n", 904 | " 'tfidfKwargs': {'truncate': False, 'doLower': True, 'nComponents': 50},\n", 905 | " 'doc2VecKwargs': {'truncate': False},\n", 906 | " 'styloKwargs': None,\n", 907 | " 'usentKwargs': {'operation': 'mean'},\n", 908 | " 'inferSentKwargs': {'V': 1, 'operation': 'mean'},\n", 909 | " 'bertKwargs': None,\n", 910 | " 'sent2VecKwargs': {'operation': 'mean'},\n", 911 | "}\n", 912 | "useToKwargsMap = \\\n", 913 | "{\n", 914 | " 'useNMF': 'nmfKwargs',\n", 915 | " 'useLDA': 'ldaKwargs',\n", 916 | " 'useDbert': 'dbertKwargs',\n", 917 | " 'useDeepStyle': 'deepStyleKwargs',\n", 918 | " 'useTFIDF': 'tfidfKwargs',\n", 919 | " 'useDoc2Vec': 'doc2VecKwargs',\n", 920 | " 'useStylo': 'styloKwargs',\n", 921 | " 'useUsent': 'usentKwargs',\n", 922 | " 'useInferSent': 'inferSentKwargs',\n", 923 | " 'useBERT': 'bertKwargs',\n", 924 | " 'useSent2Vec': 'sent2VecKwargs',\n", 925 | "}" 926 | ] 927 | }, 928 | { 929 | "cell_type": "code", 930 | "execution_count": null, 931 | "metadata": {}, 932 | "outputs": [], 933 | "source": [ 934 | "combinasons = [(e,) for e in uniqueUses] + list(getDoubleCombinasons())\n", 935 | "bp(combinasons, 5, logger)\n", 936 | "log(\"Count of combinasons: \" + str(len(combinasons)), logger)\n", 937 | "log(\"Count of points: \" + str(len(combinasons) * len(files)), logger)" 938 | ] 939 | }, 940 | { 941 | "cell_type": "code", 942 | "execution_count": null, 943 | "metadata": {}, 944 | "outputs": [], 945 | "source": [ 946 | "if True:\n", 947 | " logWarning(\"Removing combs that do not have useDbert\", logger)\n", 948 | " newCombinasons = []\n", 949 | " for current in combinasons:\n", 950 | " # if 'useDbert' in current and ('useNMF' in current or 'useLDA' in current):\n", 951 | " # if 'useNMF' in current or 'useLDA' in current:\n", 952 | " if 'useDbert' in current:\n", 953 | " # if 'useDbert' in current and 'useNMF' in current:\n", 954 | " newCombinasons.append(current)\n", 955 | " combinasons = newCombinasons\n", 956 | " bp(combinasons, 3, logger)" 957 | ] 958 | }, 959 | { 960 | "cell_type": "code", 961 | "execution_count": null, 962 | "metadata": 
{}, 963 | "outputs": [], 964 | "source": [ 965 | "if isNotebook:\n", 966 | " combinasons = combinasons[:1]\n", 967 | " bp(combinasons, 3, logger)" 968 | ] 969 | }, 970 | { 971 | "cell_type": "code", 972 | "execution_count": null, 973 | "metadata": {}, 974 | "outputs": [], 975 | "source": [ 976 | "documentsArgs = \\\n", 977 | "(\n", 978 | " docs,\n", 979 | " flatDocs,\n", 980 | " flatLowDocs,\n", 981 | " flatTruncDocs,\n", 982 | " flatTruncLowDocs,\n", 983 | " detokDocs,\n", 984 | " detokSentences,\n", 985 | ")" 986 | ] 987 | }, 988 | { 989 | "cell_type": "code", 990 | "execution_count": null, 991 | "metadata": {}, 992 | "outputs": [], 993 | "source": [ 994 | "dataHash = objectToHash(documentsArgs)\n", 995 | "bp('dataHash: ' + dataHash, logger)" 996 | ] 997 | }, 998 | { 999 | "cell_type": "code", 1000 | "execution_count": null, 1001 | "metadata": {}, 1002 | "outputs": [], 1003 | "source": [ 1004 | "# AA:\n", 1005 | "if True:\n", 1006 | " results = SerializableDict('asa2-aa', useMongodb=True, user=user, password=password, host=host, logger=logger)\n", 1007 | " for comb in pb(combinasons, logger=logger):\n", 1008 | " currentKwargs = copy.deepcopy(allKwargs)\n", 1009 | " for unique in comb:\n", 1010 | " currentKwargs[unique] = True\n", 1011 | " result = copy.deepcopy(currentKwargs)\n", 1012 | " for key in useToKwargsMap:\n", 1013 | " if not (dictContains(result, key) and result[key]):\n", 1014 | " del result[useToKwargsMap[key]]\n", 1015 | " result['uset'] = uset\n", 1016 | " result['dataCol'] = dataCol\n", 1017 | " theHash = objectToHash(result)\n", 1018 | " if theHash not in results:\n", 1019 | " features = getAsaFeatures2\\\n", 1020 | " (\n", 1021 | " *documentsArgs,\n", 1022 | " logger=logger,\n", 1023 | " verbose=True,\n", 1024 | " dataHash=dataHash,\n", 1025 | " **currentKwargs,\n", 1026 | " )\n", 1027 | " score = getScore(features, indexLabels, logger=logger, doFigShow=False)\n", 1028 | " log(\"Score of \" + str(comb) + \": \" + str(truncateFloat(score, 3)), logger)\n", 1029 | " result['score'] = score\n", 1030 | " results[theHash] = result\n", 1031 | " else:\n", 1032 | " log(\"Found the score of \" + str(comb) + \": \" + str(truncateFloat(results[theHash]['score'], 3)), logger)" 1033 | ] 1034 | }, 1035 | { 1036 | "cell_type": "code", 1037 | "execution_count": null, 1038 | "metadata": {}, 1039 | "outputs": [], 1040 | "source": [ 1041 | "# Clustering:\n", 1042 | "if False:\n", 1043 | " results = SerializableDict('asa2-clustering-comb', useMongodb=True, user=user, password=password, host=host, logger=logger)\n", 1044 | " for comb in pb(combinasons, logger=logger):\n", 1045 | " currentKwargs = copy.deepcopy(allKwargs)\n", 1046 | " for unique in comb:\n", 1047 | " currentKwargs[unique] = True\n", 1048 | " result = copy.deepcopy(currentKwargs)\n", 1049 | " for key in useToKwargsMap:\n", 1050 | " if not (dictContains(result, key) and result[key]):\n", 1051 | " del result[useToKwargsMap[key]]\n", 1052 | " result['uset'] = uset\n", 1053 | " result['dataCol'] = dataCol\n", 1054 | " theHash = objectToHash(result)\n", 1055 | " if theHash not in results:\n", 1056 | " features = getAsaFeatures2\\\n", 1057 | " (\n", 1058 | " *documentsArgs,\n", 1059 | " logger=logger,\n", 1060 | " verbose=True,\n", 1061 | " dataHash=dataHash,\n", 1062 | " **currentKwargs,\n", 1063 | " )\n", 1064 | " # We compute and store score:\n", 1065 | " data = features\n", 1066 | " simMatrix = pairwiseCosineSimilarity(data)\n", 1067 | " score = pairwiseSimNDCG(simMatrix, indexLabels)\n", 1068 | " log(\"SimRank of \" + str(comb) 
+ \": \" + str(truncateFloat(score, 3)), logger)\n", 1069 | " # calharScore = metrics.calinski_harabasz_score(data, indexLabels)\n", 1070 | " # log(\"CalHar: \" + str(calharScore), logger)\n", 1071 | " # davbScore = metrics.davies_bouldin_score(data, indexLabels)\n", 1072 | " # log(\"DavB: \" + str(davbScore), logger)\n", 1073 | " # Adding results:\n", 1074 | " result['score'] = score\n", 1075 | " results[theHash] = result\n", 1076 | " else:\n", 1077 | " log(\"Found the score of \" + str(comb) + \": \" + str(truncateFloat(results[theHash]['score'], 3)), logger)" 1078 | ] 1079 | }, 1080 | { 1081 | "cell_type": "code", 1082 | "execution_count": null, 1083 | "metadata": {}, 1084 | "outputs": [], 1085 | "source": [ 1086 | "if False:\n", 1087 | " allKwargs['dbertKwargs'] = {'operation': 'mean', 'layer': None, 'modelName': None}\n", 1088 | " currentKwargs = copy.deepcopy(allKwargs)\n", 1089 | " currentKwargs['useDbert'] = True\n", 1090 | " dbertBaseFeatures = getAsaFeatures2\\\n", 1091 | " (\n", 1092 | " *documentsArgs,\n", 1093 | " logger=logger,\n", 1094 | " verbose=True,\n", 1095 | " dataHash=dataHash,\n", 1096 | " **currentKwargs,\n", 1097 | " )\n", 1098 | " bp(dbertBaseFeatures, logger)\n", 1099 | " allKwargs['dbertKwargs'] = {'operation': 'mean', 'layer': 'distilbert', 'modelName': '94bef_ep32'}\n", 1100 | " currentKwargs = copy.deepcopy(allKwargs)\n", 1101 | " currentKwargs['useDbert'] = True\n", 1102 | " dbertEp32Features = getAsaFeatures2\\\n", 1103 | " (\n", 1104 | " *documentsArgs,\n", 1105 | " logger=logger,\n", 1106 | " verbose=True,\n", 1107 | " dataHash=dataHash,\n", 1108 | " **currentKwargs,\n", 1109 | " )\n", 1110 | " bp(dbertEp32Features, logger)" 1111 | ] 1112 | }, 1113 | { 1114 | "cell_type": "code", 1115 | "execution_count": null, 1116 | "metadata": {}, 1117 | "outputs": [], 1118 | "source": [ 1119 | "if False:\n", 1120 | " features = np.concatenate([dbertEp32Features, dbertBaseFeatures], axis=1)\n", 1121 | " bp(dbertEp32Features, logger)\n", 1122 | " log(features.shape, logger)\n", 1123 | " score = getScore(features, indexLabels, logger=logger, doFigShow=False)\n", 1124 | " log(\"Score of bdert base and 94bef_ep32 for \" + uset + \" --> \" + str(score), logger)" 1125 | ] 1126 | }, 1127 | { 1128 | "cell_type": "code", 1129 | "execution_count": null, 1130 | "metadata": {}, 1131 | "outputs": [], 1132 | "source": [ 1133 | "if False:\n", 1134 | " features = np.concatenate([dbertEp32Features, dbertBaseFeatures], axis=1)\n", 1135 | " data = features\n", 1136 | " simMatrix = pairwiseCosineSimilarity(data)\n", 1137 | " score = pairwiseSimNDCG(simMatrix, indexLabels)\n", 1138 | " log(\"simrank_uuu Score of bdert base and 94bef_ep32 for \" + uset + \" --> \" + str(score), logger)" 1139 | ] 1140 | }, 1141 | { 1142 | "cell_type": "code", 1143 | "execution_count": null, 1144 | "metadata": {}, 1145 | "outputs": [], 1146 | "source": [] 1147 | }, 1148 | { 1149 | "cell_type": "code", 1150 | "execution_count": null, 1151 | "metadata": {}, 1152 | "outputs": [], 1153 | "source": [] 1154 | }, 1155 | { 1156 | "cell_type": "markdown", 1157 | "metadata": {}, 1158 | "source": [ 1159 | "# End" 1160 | ] 1161 | }, 1162 | { 1163 | "cell_type": "code", 1164 | "execution_count": null, 1165 | "metadata": {}, 1166 | "outputs": [], 1167 | "source": [ 1168 | "tt.toc()" 1169 | ] 1170 | }, 1171 | { 1172 | "cell_type": "code", 1173 | "execution_count": null, 1174 | "metadata": {}, 1175 | "outputs": [], 1176 | "source": [ 1177 | "if not isNotebook:\n", 1178 | " exit()" 1179 | ] 1180 | }, 1181 | { 1182 | 
"cell_type": "code", 1183 | "execution_count": null, 1184 | "metadata": {}, 1185 | "outputs": [], 1186 | "source": [] 1187 | }, 1188 | { 1189 | "cell_type": "code", 1190 | "execution_count": null, 1191 | "metadata": {}, 1192 | "outputs": [], 1193 | "source": [] 1194 | }, 1195 | { 1196 | "cell_type": "markdown", 1197 | "metadata": {}, 1198 | "source": [ 1199 | "# Tables" 1200 | ] 1201 | }, 1202 | { 1203 | "cell_type": "markdown", 1204 | "metadata": {}, 1205 | "source": [ 1206 | "### Uniques" 1207 | ] 1208 | }, 1209 | { 1210 | "cell_type": "code", 1211 | "execution_count": null, 1212 | "metadata": {}, 1213 | "outputs": [], 1214 | "source": [ 1215 | "results = SerializableDict('asa2-aa', useMongodb=True, user=user, password=password, host=host, logger=logger)\n", 1216 | "# results = SerializableDict('asa2-clustering-comb', useMongodb=True, user=user, password=password, host=host, logger=logger)" 1217 | ] 1218 | }, 1219 | { 1220 | "cell_type": "code", 1221 | "execution_count": null, 1222 | "metadata": {}, 1223 | "outputs": [], 1224 | "source": [ 1225 | "def extractFeatures(o):\n", 1226 | " k = set(o.keys()) if isinstance(o, dict) else set(o)\n", 1227 | " return [k for k in o if re.search(\"^use[A-Z].*$\", k)]\n", 1228 | "def featuresCount(*args, **kwargs):\n", 1229 | " return len(extractFeatures(*args, **kwargs))" 1230 | ] 1231 | }, 1232 | { 1233 | "cell_type": "code", 1234 | "execution_count": null, 1235 | "metadata": {}, 1236 | "outputs": [], 1237 | "source": [ 1238 | "def getData(fCount=None, usetPattern=None):\n", 1239 | " if usetPattern is None:\n", 1240 | " usetPattern = '.*'\n", 1241 | " if \".*\" not in usetPattern:\n", 1242 | " usetPattern = \".*\" + usetPattern + \".*\"\n", 1243 | " if '^' not in usetPattern:\n", 1244 | " usetPattern = '^' + usetPattern\n", 1245 | " if '$' not in usetPattern:\n", 1246 | " usetPattern = usetPattern + '$'\n", 1247 | " r = []\n", 1248 | " for _, e in results.items():\n", 1249 | " if re.match(usetPattern, e['uset']) and (fCount is None or featuresCount(e) == fCount):\n", 1250 | " r.append(e)\n", 1251 | " return r" 1252 | ] 1253 | }, 1254 | { 1255 | "cell_type": "code", 1256 | "execution_count": null, 1257 | "metadata": {}, 1258 | "outputs": [], 1259 | "source": [ 1260 | "# r = getData(1, uset)\n", 1261 | "# r = getData(1, 'd18-bc10')\n", 1262 | "# r = getData(1, '.*')\n", 1263 | "# r = getData(1, 'blogger')\n", 1264 | "# r = getData(1, 'livejournal')\n", 1265 | "r = getData(1, 'washington')\n", 1266 | "# r = getData(1, 'breitbart')\n", 1267 | "# r = getData(1, 'business')\n", 1268 | "# r = getData(1, 'cnn')\n", 1269 | "# r = getData(1, 'guardian.co.uk')\n", 1270 | "# r = getData(1, 'theguardian.com')\n", 1271 | "# r = getData(1, 'nytimes')" 1272 | ] 1273 | }, 1274 | { 1275 | "cell_type": "code", 1276 | "execution_count": null, 1277 | "metadata": {}, 1278 | "outputs": [], 1279 | "source": [ 1280 | "def filterDbert(r, dbertKwargsFilter, logger=None, verbose=True):\n", 1281 | " logWarning(\"Filtering DBert with \" + str(dbertKwargsFilter), logger)\n", 1282 | " deletedCount = 0\n", 1283 | " acceptedCount = 0\n", 1284 | " newR = []\n", 1285 | " for current in r:\n", 1286 | " if dictContains(current, 'useDbert') and current['useDbert']:\n", 1287 | " ok = True\n", 1288 | " for key in dbertKwargsFilter:\n", 1289 | " if current['dbertKwargs'][key] != dbertKwargsFilter[key]:\n", 1290 | " ok = False\n", 1291 | " break\n", 1292 | " if ok:\n", 1293 | " newR.append(current)\n", 1294 | " acceptedCount += 1\n", 1295 | " else:\n", 1296 | " deletedCount += 1\n", 1297 | " 
else:\n", 1298 | " newR.append(current)\n", 1299 | " log(\"deletedCount: \" + str(deletedCount), logger)\n", 1300 | " log(\"acceptedCount: \" + str(acceptedCount), logger)\n", 1301 | " return newR" 1302 | ] 1303 | }, 1304 | { 1305 | "cell_type": "code", 1306 | "execution_count": null, 1307 | "metadata": {}, 1308 | "outputs": [], 1309 | "source": [ 1310 | "# r = filterDbert(r, {'operation': 'mean', 'layer': None, 'modelName': None}, logger=logger)\n", 1311 | "r = filterDbert(r, {'operation': 'mean', 'layer': 'distilbert', 'modelName': '94bef_ep32'}, logger=logger)" 1312 | ] 1313 | }, 1314 | { 1315 | "cell_type": "code", 1316 | "execution_count": null, 1317 | "metadata": { 1318 | "scrolled": true 1319 | }, 1320 | "outputs": [], 1321 | "source": [ 1322 | "bp(r, 5, logger)" 1323 | ] 1324 | }, 1325 | { 1326 | "cell_type": "code", 1327 | "execution_count": null, 1328 | "metadata": {}, 1329 | "outputs": [], 1330 | "source": [ 1331 | "models = dict()" 1332 | ] 1333 | }, 1334 | { 1335 | "cell_type": "code", 1336 | "execution_count": null, 1337 | "metadata": {}, 1338 | "outputs": [], 1339 | "source": [ 1340 | "log(\"Amount of usets: \" + str(len(set([e['uset'] for e in r]))))\n", 1341 | "log(\"Amount of models: \" + str(len(set([str(extractFeatures(e)) for e in r]))))" 1342 | ] 1343 | }, 1344 | { 1345 | "cell_type": "code", 1346 | "execution_count": null, 1347 | "metadata": {}, 1348 | "outputs": [], 1349 | "source": [ 1350 | "for current in r:\n", 1351 | " f = extractFeatures(current)[0]\n", 1352 | " if f not in models:\n", 1353 | " models[f] = []\n", 1354 | " models[f].append(current['score'])\n", 1355 | "bp(models, 5)" 1356 | ] 1357 | }, 1358 | { 1359 | "cell_type": "code", 1360 | "execution_count": null, 1361 | "metadata": {}, 1362 | "outputs": [], 1363 | "source": [ 1364 | "for key, values in models.items():\n", 1365 | " print(len(values))" 1366 | ] 1367 | }, 1368 | { 1369 | "cell_type": "code", 1370 | "execution_count": null, 1371 | "metadata": {}, 1372 | "outputs": [], 1373 | "source": [ 1374 | "for key in models.keys():\n", 1375 | " models[key] = float(np.mean(models[key]))\n", 1376 | "for m, s in sortBy(models.items(), 1): print(m, s)" 1377 | ] 1378 | }, 1379 | { 1380 | "cell_type": "code", 1381 | "execution_count": null, 1382 | "metadata": {}, 1383 | "outputs": [], 1384 | "source": [] 1385 | }, 1386 | { 1387 | "cell_type": "code", 1388 | "execution_count": null, 1389 | "metadata": {}, 1390 | "outputs": [], 1391 | "source": [] 1392 | }, 1393 | { 1394 | "cell_type": "markdown", 1395 | "metadata": {}, 1396 | "source": [ 1397 | "### Doubles" 1398 | ] 1399 | }, 1400 | { 1401 | "cell_type": "code", 1402 | "execution_count": null, 1403 | "metadata": {}, 1404 | "outputs": [], 1405 | "source": [ 1406 | "# results = SerializableDict('asa2-aa', useMongodb=True, user=user, password=password, host=host, logger=logger)\n", 1407 | "results = SerializableDict('asa2-clustering-comb', useMongodb=True, user=user, password=password, host=host, logger=logger)" 1408 | ] 1409 | }, 1410 | { 1411 | "cell_type": "code", 1412 | "execution_count": null, 1413 | "metadata": {}, 1414 | "outputs": [], 1415 | "source": [ 1416 | "r = getData(2, '.*')\n", 1417 | "# r = getData(2, 'd18-bc10')\n", 1418 | "# r = getData(2, 'blogger')\n", 1419 | "# r = getData(2, 'livejournal')\n", 1420 | "# r = getData(2, 'washington')\n", 1421 | "# r = getData(2, 'breitbart')\n", 1422 | "# r = getData(2, 'business')\n", 1423 | "# r = getData(2, 'cnn')\n", 1424 | "# r = getData(2, 'guardian.co.uk')\n", 1425 | "# r = getData(2, 
'theguardian.com')\n", 1426 | "# r = getData(2, 'nytimes')" 1427 | ] 1428 | }, 1429 | { 1430 | "cell_type": "code", 1431 | "execution_count": null, 1432 | "metadata": {}, 1433 | "outputs": [], 1434 | "source": [ 1435 | "r = filterDbert(r, {'operation': 'mean', 'layer': None, 'modelName': None}, logger=logger)\n", 1436 | "# r = filterDbert(r, {'operation': 'mean', 'layer': 'distilbert', 'modelName': '94bef_ep32'}, logger=logger)" 1437 | ] 1438 | }, 1439 | { 1440 | "cell_type": "code", 1441 | "execution_count": null, 1442 | "metadata": {}, 1443 | "outputs": [], 1444 | "source": [ 1445 | "log(\"Amount of usets: \" + str(len(set([e['uset'] for e in r]))))\n", 1446 | "log(\"Amount of models: \" + str(len(set([str(extractFeatures(e)) for e in r]))))" 1447 | ] 1448 | }, 1449 | { 1450 | "cell_type": "code", 1451 | "execution_count": null, 1452 | "metadata": {}, 1453 | "outputs": [], 1454 | "source": [ 1455 | "models = dict()" 1456 | ] 1457 | }, 1458 | { 1459 | "cell_type": "code", 1460 | "execution_count": null, 1461 | "metadata": {}, 1462 | "outputs": [], 1463 | "source": [ 1464 | "for current in r:\n", 1465 | " f = str(extractFeatures(current))\n", 1466 | " if f not in models:\n", 1467 | " models[f] = []\n", 1468 | " models[f].append(current['score'])\n", 1469 | "bp(models)" 1470 | ] 1471 | }, 1472 | { 1473 | "cell_type": "code", 1474 | "execution_count": null, 1475 | "metadata": {}, 1476 | "outputs": [], 1477 | "source": [ 1478 | "for e in sorted(models.keys()):\n", 1479 | " print(e)" 1480 | ] 1481 | }, 1482 | { 1483 | "cell_type": "code", 1484 | "execution_count": null, 1485 | "metadata": {}, 1486 | "outputs": [], 1487 | "source": [ 1488 | "for key, values in models.items():\n", 1489 | " print(len(values))" 1490 | ] 1491 | }, 1492 | { 1493 | "cell_type": "code", 1494 | "execution_count": null, 1495 | "metadata": {}, 1496 | "outputs": [], 1497 | "source": [ 1498 | "for key in models.keys():\n", 1499 | " models[key] = float(np.mean(models[key]))" 1500 | ] 1501 | }, 1502 | { 1503 | "cell_type": "code", 1504 | "execution_count": null, 1505 | "metadata": {}, 1506 | "outputs": [], 1507 | "source": [ 1508 | "for m, s in sortBy(models.items(), 1):\n", 1509 | " print(m, s)" 1510 | ] 1511 | }, 1512 | { 1513 | "cell_type": "code", 1514 | "execution_count": null, 1515 | "metadata": {}, 1516 | "outputs": [], 1517 | "source": [] 1518 | }, 1519 | { 1520 | "cell_type": "code", 1521 | "execution_count": null, 1522 | "metadata": {}, 1523 | "outputs": [], 1524 | "source": [] 1525 | } 1526 | ], 1527 | "metadata": { 1528 | "kernelspec": { 1529 | "display_name": "Python 3", 1530 | "language": "python", 1531 | "name": "python3" 1532 | }, 1533 | "language_info": { 1534 | "codemirror_mode": { 1535 | "name": "ipython", 1536 | "version": 3 1537 | }, 1538 | "file_extension": ".py", 1539 | "mimetype": "text/x-python", 1540 | "name": "python", 1541 | "nbconvert_exporter": "python", 1542 | "pygments_lexer": "ipython3", 1543 | "version": "3.6.3" 1544 | } 1545 | }, 1546 | "nbformat": 4, 1547 | "nbformat_minor": 2 1548 | } 1549 | -------------------------------------------------------------------------------- /experiments/notebooks/dbert-ft-train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# DBert train\n", 8 | "\n", 9 | "From " 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "## Commands" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | 
"execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "# Titanv tf install:\n", 26 | "# !pip freeze | grep flow\n", 27 | "# !pip install --upgrade pip\n", 28 | "# !pip uninstall --y tensorboard tensorflow-estimator tensorflow tensorflow-gpu\n", 29 | "# !pip install --upgrade tensorflow==2.0.0\n", 30 | "# !pip install --upgrade tensorflow-gpu==2.0.0\n", 31 | "# !pip install --upgrade transformers==2.4.1\n", 32 | "# !pip freeze | grep flow" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "# titanv 1:\n", 42 | "# screen -S dbert-train1\n", 43 | "# source ~/.bash_profile ; source ~/.bash_aliases ; cd ~/dbert-train-logs\n", 44 | "# DOCKER_PORT=9961 nn -o nohup-dbert-train-$HOSTNAME-1.out ~/docker/keras/run-jupython.sh ~/notebooks/asa/train/dbert-train.ipynb titanv\n", 45 | "# observe ~/dbert-train-logs/nohup-dbert-train-$HOSTNAME-1.out" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "# titanv 2:\n", 55 | "# screen -S dbert-train2\n", 56 | "# source ~/.bash_profile ; source ~/.bash_aliases ; cd ~/dbert-train-logs\n", 57 | "# DOCKER_PORT=9962 nn -o nohup-dbert-train-$HOSTNAME-2.out ~/docker/keras/run-jupython.sh ~/notebooks/asa/train/dbert-train.ipynb titanv\n", 58 | "# observe ~/dbert-train-logs/nohup-dbert-train-$HOSTNAME-2.out" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "# cd ; archive-notebooks ; cd ~/logs ; ./mv-old-logs.sh # optionnel\n", 68 | "# sbatch ~/slurm/run-notebook.sh ~/tmp/archives/notebooks/asa/train/dbert-train.ipynb\n", 69 | "# observe ~/logs/*.out" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "# cd ; archive-notebooks ; cd ~/logs\n", 79 | "# sbatch ~/slurm/run-notebook.sh ~/tmp/archives/notebooks/asa/train/dbert-train.ipynb\n", 80 | "# observe ~/logs/*.out" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "metadata": {}, 100 | "source": [ 101 | "## Imports" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "isNotebook = '__file__' not in locals()" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "import os\n", 120 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "import logging\n", 130 | "import math\n", 131 | "import tensorflow as tf\n", 132 | "from tensorflow.keras import callbacks\n", 133 | "from transformers import \\\n", 134 | "(\n", 135 | " DistilBertConfig,\n", 136 | " DistilBertTokenizer,\n", 137 | " TFDistilBertForSequenceClassification,\n", 138 | ")" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [] 147 | }, 148 | { 149 | "cell_type": "code", 150 | 
"execution_count": null, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "metadata": {}, 158 | "source": [ 159 | "## Functions" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "def ksetGen\\\n", 169 | "(\n", 170 | " train=True,\n", 171 | " ksetRoot=dataDir() + '/Asa2/detok-kset' if lri() else homeDir() + \"/asa/asa2-data/detok-kset\",\n", 172 | " maxFiles=None,\n", 173 | " **kwargs,\n", 174 | "):\n", 175 | " # We find files:\n", 176 | " if train:\n", 177 | " files = sortedGlob(ksetRoot + '/train/*.bz2')\n", 178 | " else:\n", 179 | " files = sortedGlob(ksetRoot + '/validation/*.bz2')\n", 180 | " if maxFiles is not None:\n", 181 | " files = files[:maxFiles]\n", 182 | " # we return the generator:\n", 183 | " return genFunct(files, ksetRoot=ksetRoot, **kwargs)" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": null, 189 | "metadata": {}, 190 | "outputs": [], 191 | "source": [ 192 | "def genFunct\\\n", 193 | "(\n", 194 | " files,\n", 195 | " \n", 196 | " ksetRoot=dataDir() + '/Asa2/detok-kset' if lri() else homeDir() + \"/asa/asa2-data/detok-kset\",\n", 197 | " dataCol=\"filtered_detokenized_sentences\",\n", 198 | " labelField='label',\n", 199 | " \n", 200 | " labelEncoding='index',\n", 201 | " labelEncoder=None,\n", 202 | " \n", 203 | " maxSamples=None,\n", 204 | " maxSentences=None,\n", 205 | " \n", 206 | " preventTokenizerWarnings=True,\n", 207 | " loggerName=\"transformers.tokenization_utils\",\n", 208 | " \n", 209 | " logger=None,\n", 210 | " verbose=True,\n", 211 | " \n", 212 | " showProgress=False,\n", 213 | " \n", 214 | " multiSamplage=False,\n", 215 | " **encodeKwargs,\n", 216 | "):\n", 217 | " # Handling unique file:\n", 218 | " if not isinstance(files, list):\n", 219 | " files = [files]\n", 220 | " # Misc init:\n", 221 | " samplesCount = 0\n", 222 | " # We set the logger level:\n", 223 | " if preventTokenizerWarnings:\n", 224 | " previousLoggerLevel = logging.getLogger(loggerName).level\n", 225 | " logging.getLogger(loggerName).setLevel(logging.ERROR)\n", 226 | " if showProgress:\n", 227 | " pbar = ProgressBar(len(files), logger=logger, verbose=verbose)\n", 228 | " # We get labels and encode labels:\n", 229 | " if labelEncoder is None:\n", 230 | " labels = sorted(list(deserialize(ksetRoot + '/validation/labels.pickle')))\n", 231 | " (classes, labels) = encodeMulticlassLabels(labels, encoding=labelEncoding)\n", 232 | " labelEncoder = dict()\n", 233 | " assert len(classes) == len(labels)\n", 234 | " for i in range(len(classes)):\n", 235 | " labelEncoder[classes[i]] = labels[i]\n", 236 | " # For each file:\n", 237 | " for file in files:\n", 238 | " for row in NDJson(file):\n", 239 | " # We get sentences:\n", 240 | " sentences = row[dataCol]\n", 241 | " if not (isinstance(sentences, list) and len(sentences) > 1 and isinstance(sentences[0], str)):\n", 242 | " raise Exception(\"All row[dataCol] must be a list of strings (sentences)\")\n", 243 | " if maxSentences is not None:\n", 244 | " sentences = sentences[:maxSentences]\n", 245 | " # We encode the document:\n", 246 | " parts = tf2utils.distilBertEncode\\\n", 247 | " (\n", 248 | " sentences,\n", 249 | " multiSamplage=multiSamplage,\n", 250 | " preventTokenizerWarnings=False,\n", 251 | " proxies=proxies,\n", 252 | " logger=logger, verbose=verbose,\n", 253 | " **encodeKwargs,\n", 254 | " )\n", 255 | " if not multiSamplage:\n", 256 | " parts = 
[parts]\n", 257 | " # We yield all parts:\n", 258 | " for part in parts:\n", 259 | " yield (np.array(part), labelEncoder[row[labelField]])\n", 260 | " # yield (np.array([np.array(part), np.array(part)]), np.array([labelEncoder[row[labelField]], labelEncoder[row[labelField]]]))\n", 261 | " samplesCount += 1\n", 262 | " if maxSamples is not None and samplesCount >= maxSamples:\n", 263 | " break\n", 264 | " if showProgress:\n", 265 | " pbar.tic(file)\n", 266 | " if maxSamples is not None and samplesCount >= maxSamples:\n", 267 | " break\n", 268 | " # We reset the logger:\n", 269 | " if preventTokenizerWarnings:\n", 270 | " logging.getLogger(loggerName).setLevel(previousLoggerLevel)" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": null, 276 | "metadata": {}, 277 | "outputs": [], 278 | "source": [ 279 | "def saveFunct(model, directory, **kwargs):\n", 280 | " model.save_pretrained(directory)" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": null, 286 | "metadata": {}, 287 | "outputs": [], 288 | "source": [ 289 | "def getSamplesCount(logger=None, verbose=True):\n", 290 | " samplesCountCache = None\n", 291 | " (user, password, host) = getOctodsMongoAuth()\n", 292 | " samplesCountCache = SerializableDict('samples-count', user=user, host=host, password=password, useMongodb=True)\n", 293 | " samplesCountParams = \\\n", 294 | " {\n", 295 | " 'maxFiles': config['maxFiles'],\n", 296 | " 'maxSamples': config['maxSamples'],\n", 297 | " 'multiSamplage': config['multiSamplage'],\n", 298 | " 'maxLength': config['maxLength'],\n", 299 | " 'dataCol': config['dataCol'],\n", 300 | " }\n", 301 | " trainSamplesCountParams = mergeDicts(samplesCountParams, {'train': True})\n", 302 | " trainSamplesCountHash = objectToHash(trainSamplesCountParams)\n", 303 | " validationSamplesCountParams = mergeDicts(samplesCountParams, {'train': False})\n", 304 | " validationSamplesCountHash = objectToHash(validationSamplesCountParams)\n", 305 | " if samplesCountCache is not None and trainSamplesCountHash in samplesCountCache:\n", 306 | " trainSamplesCount = samplesCountCache[trainSamplesCountHash]\n", 307 | " else:\n", 308 | " log(\"Starting to count batches in the train set...\", logger, verbose=verbose)\n", 309 | " trainSamplesCount = 0\n", 310 | " for row in ksetGen\\\n", 311 | " (\n", 312 | " train=True,\n", 313 | " **samplesCountParams,\n", 314 | " showProgress=True,\n", 315 | " logger=logger,\n", 316 | " verbose=True,\n", 317 | " ):\n", 318 | " trainSamplesCount += 1\n", 319 | " if samplesCountCache is not None:\n", 320 | " samplesCountCache[trainSamplesCountHash] = trainSamplesCount\n", 321 | " if samplesCountCache is not None and validationSamplesCountHash in samplesCountCache:\n", 322 | " validationSamplesCount = samplesCountCache[validationSamplesCountHash]\n", 323 | " else:\n", 324 | " log(\"Starting to count batches in the validation set...\", logger, verbose=verbose)\n", 325 | " validationSamplesCount = 0\n", 326 | " for row in ksetGen\\\n", 327 | " (\n", 328 | " train=False,\n", 329 | " **samplesCountParams,\n", 330 | " showProgress=True,\n", 331 | " logger=logger,\n", 332 | " verbose=True,\n", 333 | " ):\n", 334 | " validationSamplesCount += 1\n", 335 | " if samplesCountCache is not None:\n", 336 | " samplesCountCache[validationSamplesCountHash] = validationSamplesCount\n", 337 | " return (trainSamplesCount, validationSamplesCount)" 338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": null, 343 | "metadata": {}, 344 | "outputs": [], 345 | 
"source": [] 346 | }, 347 | { 348 | "cell_type": "code", 349 | "execution_count": null, 350 | "metadata": {}, 351 | "outputs": [], 352 | "source": [] 353 | }, 354 | { 355 | "cell_type": "markdown", 356 | "metadata": {}, 357 | "source": [ 358 | "## Config" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": null, 364 | "metadata": {}, 365 | "outputs": [], 366 | "source": [ 367 | "config = \\\n", 368 | "{\n", 369 | " 'dataCol': 'filtered_detokenized_sentences',\n", 370 | " 'ksetRoot': dataDir() + '/Asa2/detok-kset' if lri() else homeDir() + \"/asa/asa2-data/detok-kset\",\n", 371 | " 'multiSamplage': True,\n", 372 | " 'maxFiles': 30 if isNotebook else None,\n", 373 | " 'maxSamples': 5000 if isNotebook else None,\n", 374 | " 'maxLength': 512,\n", 375 | " 'batchSize': 16,\n", 376 | " \n", 377 | " 'learningRate': 3e-5,\n", 378 | " 'epsilon': 1e-08,\n", 379 | " 'clipnorm': 1.0,\n", 380 | " \n", 381 | " 'trainStepDivider': 2 if isNotebook else 30,\n", 382 | " 'shuffle': 0 if isNotebook else 100,\n", 383 | " 'queueSize': 100,\n", 384 | " \n", 385 | " 'useMLIterator': True,\n", 386 | "}" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": null, 392 | "metadata": {}, 393 | "outputs": [], 394 | "source": [ 395 | "ksetRoot = config['ksetRoot']" 396 | ] 397 | }, 398 | { 399 | "cell_type": "code", 400 | "execution_count": null, 401 | "metadata": {}, 402 | "outputs": [], 403 | "source": [ 404 | "outputDirRoot = homeDir() + '/asa/dbert-train'\n", 405 | "outputDir = outputDirRoot + '/' + objectToHash(config)[:5]\n", 406 | "mkdir(outputDir)" 407 | ] 408 | }, 409 | { 410 | "cell_type": "code", 411 | "execution_count": null, 412 | "metadata": {}, 413 | "outputs": [], 414 | "source": [ 415 | "if False:\n", 416 | " assert config['maxFiles'] == 3\n", 417 | " assert isNotebook\n", 418 | " remove(outputDir)" 419 | ] 420 | }, 421 | { 422 | "cell_type": "code", 423 | "execution_count": null, 424 | "metadata": {}, 425 | "outputs": [], 426 | "source": [ 427 | "logger = Logger(outputDir + '/dbert-train.log')\n", 428 | "log(\"outputDir: \" + str(outputDir), logger)" 429 | ] 430 | }, 431 | { 432 | "cell_type": "code", 433 | "execution_count": null, 434 | "metadata": {}, 435 | "outputs": [], 436 | "source": [ 437 | "trainFiles = sortedGlob(ksetRoot + '/train/*.bz2')\n", 438 | "validationFiles = sortedGlob(ksetRoot + '/validation/*.bz2')\n", 439 | "if config['maxFiles'] is not None:\n", 440 | " log(\"Reducing amount of train files from \" + str(len(trainFiles)) + \" to \" + str(config['maxFiles']), logger)\n", 441 | " trainFiles = trainFiles[:config['maxFiles']]\n", 442 | " log(\"Reducing amount of validation files from \" + str(len(validationFiles)) + \" to \" + str(config['maxFiles']), logger)\n", 443 | " validationFiles = validationFiles[:config['maxFiles']]\n", 444 | "bp(trainFiles, logger)\n", 445 | "bp(validationFiles, logger)" 446 | ] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "execution_count": null, 451 | "metadata": {}, 452 | "outputs": [], 453 | "source": [] 454 | }, 455 | { 456 | "cell_type": "code", 457 | "execution_count": null, 458 | "metadata": {}, 459 | "outputs": [], 460 | "source": [] 461 | }, 462 | { 463 | "cell_type": "markdown", 464 | "metadata": {}, 465 | "source": [ 466 | "## Model" 467 | ] 468 | }, 469 | { 470 | "cell_type": "code", 471 | "execution_count": null, 472 | "metadata": {}, 473 | "outputs": [], 474 | "source": [ 475 | "# In case we reume a previous train:\n", 476 | "batchesPassed = 0\n", 477 | "initialEpoch = 0\n", 478 | 
"lastEpochPath = None\n", 479 | "if len(sortedGlob(outputDir + \"/epochs/ep*\")) > 0:\n", 480 | " lastEpochPath = sortedGlob(outputDir + \"/epochs/ep*\")[-1]\n", 481 | " batchesPassedPath = lastEpochPath + \"/batchesPassed.txt\"\n", 482 | " assert isFile(batchesPassedPath)\n", 483 | " assert not isFile(outputDir + \"/finished\")\n", 484 | " assert not isFile(outputDir + \"/stop\")\n", 485 | " initialEpoch = getFirstNumber(decomposePath(lastEpochPath)[1]) + 1\n", 486 | " batchesPassed = int(fileToStr(batchesPassedPath))\n", 487 | " log(\"We found an epoch to resume: \" + lastEpochPath, logger)\n", 488 | " logWarning(\"We will skip \" + str(batchesPassed) + \" batches because we resume a previous train\", logger)" 489 | ] 490 | }, 491 | { 492 | "cell_type": "code", 493 | "execution_count": null, 494 | "metadata": {}, 495 | "outputs": [], 496 | "source": [ 497 | "if lastEpochPath is not None:\n", 498 | " log(\"Loading previous model...\", logger)\n", 499 | " dbertConfig = DistilBertConfig.from_pretrained(lastEpochPath + '/config.json')\n", 500 | " model = TFDistilBertForSequenceClassification.from_pretrained\\\n", 501 | " (\n", 502 | " lastEpochPath + '/tf_model.h5',\n", 503 | " config=dbertConfig,\n", 504 | " )\n", 505 | "else:\n", 506 | " log(\"Loading a new model from distilbert-base-uncased...\", logger)\n", 507 | " # Labels count:\n", 508 | " numLabels = len(deserialize(ksetRoot + '/validation/labels.pickle'))\n", 509 | " # Config:\n", 510 | " dbertConfig = DistilBertConfig.from_pretrained\\\n", 511 | " (\n", 512 | " \"distilbert-base-uncased\",\n", 513 | " num_labels=numLabels,\n", 514 | " max_length=config['maxLength'],\n", 515 | " proxies=proxies,\n", 516 | " )\n", 517 | " # Model:\n", 518 | " model = TFDistilBertForSequenceClassification.from_pretrained\\\n", 519 | " (\n", 520 | " \"distilbert-base-uncased\",\n", 521 | " config=dbertConfig,\n", 522 | " proxies=proxies,\n", 523 | " )\n", 524 | "log(\"Model loaded.\", logger)" 525 | ] 526 | }, 527 | { 528 | "cell_type": "code", 529 | "execution_count": null, 530 | "metadata": {}, 531 | "outputs": [], 532 | "source": [ 533 | "# Optimizer:\n", 534 | "optKwargs = dict()\n", 535 | "if dictContains(config, 'clipnorm'): optKwargs['clipnorm'] = config['clipnorm']\n", 536 | "if dictContains(config, 'learningRate'): optKwargs['learning_rate'] = config['learningRate']\n", 537 | "if dictContains(config, 'epsilon'): optKwargs['epsilon'] = config['epsilon']\n", 538 | "opt = tf.keras.optimizers.Adam(**optKwargs)\n", 539 | "# Loss:\n", 540 | "loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n", 541 | "# Metric:\n", 542 | "metric = tf.keras.metrics.SparseCategoricalAccuracy(\"accuracy\")\n", 543 | "# Compilation:\n", 544 | "model.compile(optimizer=opt, loss=loss, metrics=[metric])" 545 | ] 546 | }, 547 | { 548 | "cell_type": "code", 549 | "execution_count": null, 550 | "metadata": {}, 551 | "outputs": [], 552 | "source": [ 553 | "model.summary()" 554 | ] 555 | }, 556 | { 557 | "cell_type": "code", 558 | "execution_count": null, 559 | "metadata": {}, 560 | "outputs": [], 561 | "source": [] 562 | }, 563 | { 564 | "cell_type": "code", 565 | "execution_count": null, 566 | "metadata": {}, 567 | "outputs": [], 568 | "source": [] 569 | }, 570 | { 571 | "cell_type": "markdown", 572 | "metadata": {}, 573 | "source": [ 574 | "## Training" 575 | ] 576 | }, 577 | { 578 | "cell_type": "code", 579 | "execution_count": null, 580 | "metadata": {}, 581 | "outputs": [], 582 | "source": [ 583 | "(trainSamplesCount, validationSamplesCount) = 
getSamplesCount(logger=logger)\n", 584 | "log('trainSamplesCount: ' + str(trainSamplesCount) + ', validationSamplesCount: ' + str(validationSamplesCount), logger)" 585 | ] 586 | }, 587 | { 588 | "cell_type": "code", 589 | "execution_count": null, 590 | "metadata": {}, 591 | "outputs": [], 592 | "source": [ 593 | "trainBatchesAmount = math.ceil(trainSamplesCount / config['batchSize'])\n", 594 | "validationBatchesAmount = math.ceil(validationSamplesCount / config['batchSize'])\n", 595 | "trainSteps = math.ceil(trainBatchesAmount / config[\"trainStepDivider\"])\n", 596 | "validationSteps = validationBatchesAmount\n", 597 | "log('trainBatchesAmount: ' + str(trainBatchesAmount), logger)\n", 598 | "log('validationBatchesAmount: ' + str(validationBatchesAmount), logger)\n", 599 | "log('trainSteps: ' + str(trainSteps), logger)\n", 600 | "log('validationSteps: ' + str(validationSteps), logger)" 601 | ] 602 | }, 603 | { 604 | "cell_type": "code", 605 | "execution_count": null, 606 | "metadata": {}, 607 | "outputs": [], 608 | "source": [ 609 | "callback = tf2utils.KerasCallback\\\n", 610 | "(\n", 611 | " model,\n", 612 | " outputDir,\n", 613 | " saveFunct=saveFunct,\n", 614 | " showGraphs=isNotebook,\n", 615 | " earlyStopMonitor=\n", 616 | " {\n", 617 | " 'val_loss': {'patience': 10, 'mode': 'auto'},\n", 618 | " 'val_accuracy': {'patience': 10, 'mode': 'auto'},\n", 619 | " 'val_top_k_categorical_accuracy': {'patience': 10, 'mode': 'auto'},\n", 620 | " },\n", 621 | " initialEpoch=initialEpoch,\n", 622 | " batchesAmount=trainBatchesAmount,\n", 623 | " batchesPassed=batchesPassed,\n", 624 | " removeEpochs=True,\n", 625 | " logger=logger,\n", 626 | ")" 627 | ] 628 | }, 629 | { 630 | "cell_type": "code", 631 | "execution_count": null, 632 | "metadata": {}, 633 | "outputs": [], 634 | "source": [ 635 | "ksetGenKwargs = \\\n", 636 | "{\n", 637 | " 'ksetRoot': ksetRoot,\n", 638 | " 'dataCol': config['dataCol'],\n", 639 | " 'maxLength': config['maxLength'],\n", 640 | " 'multiSamplage': config['multiSamplage'],\n", 641 | " 'maxSamples': config['maxSamples'],\n", 642 | "}\n", 643 | "ksetGenTrainKwargs = mergeDicts(ksetGenKwargs, {'train': True})\n", 644 | "ksetGenValidationKwargs = mergeDicts(ksetGenKwargs, {'train': False})" 645 | ] 646 | }, 647 | { 648 | "cell_type": "code", 649 | "execution_count": null, 650 | "metadata": {}, 651 | "outputs": [], 652 | "source": [ 653 | " if config['useMLIterator']:\n", 654 | " train = IteratorToGenerator\\\n", 655 | " (\n", 656 | " InfiniteBatcher\\\n", 657 | " (\n", 658 | " AgainAndAgain\\\n", 659 | " (\n", 660 | " MLIterator,\n", 661 | " trainFiles,\n", 662 | " genFunct,\n", 663 | " genKwargs=ksetGenKwargs,\n", 664 | " queuesMaxSize=100,\n", 665 | " parallelProcesses=cpuCount(),\n", 666 | " useFlushTimer=False,\n", 667 | " flushTimeout=300,\n", 668 | " logger=logger,\n", 669 | " ),\n", 670 | " batchSize=config['batchSize'],\n", 671 | " shuffle=config['shuffle'],\n", 672 | " queueSize=config['queueSize'],\n", 673 | " skip=batchesPassed,\n", 674 | " logger=logger,\n", 675 | " )\n", 676 | " )\n", 677 | " validation = IteratorToGenerator\\\n", 678 | " (\n", 679 | " InfiniteBatcher\\\n", 680 | " (\n", 681 | " AgainAndAgain\\\n", 682 | " (\n", 683 | " MLIterator,\n", 684 | " validationFiles,\n", 685 | " genFunct,\n", 686 | " genKwargs=ksetGenKwargs,\n", 687 | " queuesMaxSize=100,\n", 688 | " parallelProcesses=cpuCount(),\n", 689 | " useFlushTimer=False,\n", 690 | " flushTimeout=300,\n", 691 | " logger=logger,\n", 692 | " ),\n", 693 | " batchSize=config['batchSize'],\n", 694 | " 
shuffle=config['shuffle'],\n", 695 | " queueSize=config['queueSize'],\n", 696 | " skip=batchesPassed,\n", 697 | " logger=logger,\n", 698 | " )\n", 699 | " )\n", 700 | "else:\n", 701 | " train = IteratorToGenerator(InfiniteBatcher\\\n", 702 | " (\n", 703 | " AgainAndAgain(ksetGen, **ksetGenTrainKwargs),\n", 704 | " batchSize=config['batchSize'],\n", 705 | " shuffle=config['shuffle'],\n", 706 | " queueSize=config['queueSize'],\n", 707 | " skip=batchesPassed,\n", 708 | " logger=logger,\n", 709 | " ))\n", 710 | " validation = IteratorToGenerator(InfiniteBatcher\\\n", 711 | " (\n", 712 | " AgainAndAgain(ksetGen, **ksetGenValidationKwargs),\n", 713 | " batchSize=config['batchSize'],\n", 714 | " shuffle=0,\n", 715 | " queueSize=100,\n", 716 | " skip=0,\n", 717 | " logger=logger,\n", 718 | " ))" 719 | ] 720 | }, 721 | { 722 | "cell_type": "code", 723 | "execution_count": null, 724 | "metadata": {}, 725 | "outputs": [], 726 | "source": [ 727 | "history = model.fit\\\n", 728 | "(\n", 729 | " x=train,\n", 730 | " epochs=100 * config[\"trainStepDivider\"],\n", 731 | " validation_data=validation,\n", 732 | " callbacks=[callback, callbacks.TerminateOnNaN()],\n", 733 | " initial_epoch=initialEpoch,\n", 734 | " steps_per_epoch=trainSteps,\n", 735 | " validation_steps=validationSteps,\n", 736 | ")" 737 | ] 738 | }, 739 | { 740 | "cell_type": "code", 741 | "execution_count": null, 742 | "metadata": {}, 743 | "outputs": [], 744 | "source": [] 745 | }, 746 | { 747 | "cell_type": "code", 748 | "execution_count": null, 749 | "metadata": {}, 750 | "outputs": [], 751 | "source": [] 752 | } 753 | ], 754 | "metadata": { 755 | "kernelspec": { 756 | "display_name": "Python 3", 757 | "language": "python", 758 | "name": "python3" 759 | }, 760 | "language_info": { 761 | "codemirror_mode": { 762 | "name": "ipython", 763 | "version": 3 764 | }, 765 | "file_extension": ".py", 766 | "mimetype": "text/x-python", 767 | "name": "python", 768 | "nbconvert_exporter": "python", 769 | "pygments_lexer": "ipython3", 770 | "version": "3.6.3" 771 | } 772 | }, 773 | "nbformat": 4, 774 | "nbformat_minor": 2 775 | } 776 | -------------------------------------------------------------------------------- /experiments/notebooks/sna-train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Commands" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "# titanv 1:\n", 17 | "# screen -S asa2-train1\n", 18 | "# source ~/.bash_profile ; source ~/.bash_aliases ; cd ~/asa2-train-logs\n", 19 | "# DOCKER_PORT=9961 nn -o nohup-asa2-train-$HOSTNAME-1.out ~/docker/keras/run-jupython.sh ~/notebooks/asa/train/asa2-train.ipynb titanv" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "# titanv 2:\n", 29 | "# screen -S asa2-train2\n", 30 | "# source ~/.bash_profile ; source ~/.bash_aliases ; cd ~/asa2-train-logs\n", 31 | "# DOCKER_PORT=9962 nn -o nohup-asa2-train-$HOSTNAME-2.out ~/docker/keras/run-jupython.sh ~/notebooks/asa/train/asa2-train.ipynb titanv" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [] 47 | }, 48 | { 49 | "cell_type": 
"markdown", 50 | "metadata": {}, 51 | "source": [ 52 | "# Inits" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "isNotebook = '__file__' not in locals()" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "TEST = False\n", 71 | "if isNotebook:\n", 72 | " TEST = False" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "metadata": {}, 92 | "source": [ 93 | "# Force CUDA_VISIBLE_DEVICES" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "import os\n", 103 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "hasGPU = True" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "metadata": {}, 132 | "source": [ 133 | "# Imports" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "from newssource.asattribution.utils import *\n", 143 | "from newssource.asattribution.asamin import *\n", 144 | "from newssource.asa.asapreproc import *\n", 145 | "from newssource.metrics.ndcg import *" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "import matplotlib\n", 155 | "import numpy as np\n", 156 | "import matplotlib.pyplot as plt\n", 157 | "if not isNotebook:\n", 158 | " matplotlib.use('Agg')\n", 159 | "import random\n", 160 | "import time\n", 161 | "import pickle\n", 162 | "from hashlib import md5" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "metadata": { 169 | "scrolled": false 170 | }, 171 | "outputs": [], 172 | "source": [ 173 | "from keras import backend as K\n", 174 | "K.tensorflow_backend._get_available_gpus()" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": null, 180 | "metadata": {}, 181 | "outputs": [], 182 | "source": [ 183 | "from sklearn.model_selection import train_test_split\n", 184 | "from sklearn.feature_extraction.text import TfidfVectorizer\n", 185 | "from sklearn.naive_bayes import MultinomialNB\n", 186 | "from sklearn import metrics\n", 187 | "from sklearn.model_selection import cross_val_score\n", 188 | "from sklearn.dummy import DummyClassifier" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": null, 194 | "metadata": {}, 195 | "outputs": [], 196 | "source": [ 197 | "from gensim.test.utils import common_texts\n", 198 | "from gensim.models.doc2vec import Doc2Vec, TaggedDocument" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [ 207 | "from 
random import random\n", 208 | "from numpy import array\n", 209 | "from numpy import cumsum" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": null, 215 | "metadata": {}, 216 | "outputs": [], 217 | "source": [ 218 | "import statistics" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": null, 224 | "metadata": {}, 225 | "outputs": [], 226 | "source": [ 227 | "from machinelearning.baseline import *\n", 228 | "from machinelearning.encoder import *\n", 229 | "from machinelearning.kerasutils import *\n", 230 | "from machinelearning.kerasmodels import *\n", 231 | "from machinelearning.iterator import *\n", 232 | "from machinelearning.metrics import *\n", 233 | "from machinelearning.attmap.builder import *" 234 | ] 235 | }, 236 | { 237 | "cell_type": "code", 238 | "execution_count": null, 239 | "metadata": {}, 240 | "outputs": [], 241 | "source": [ 242 | "from keras.layers import LSTM, GRU, Dense, CuDNNLSTM, CuDNNGRU, Bidirectional\n", 243 | "from keras.layers import BatchNormalization, Activation, SpatialDropout1D, InputSpec\n", 244 | "from keras.layers import MaxPooling1D, TimeDistributed, Flatten, concatenate, Conv1D\n", 245 | "from keras.utils import multi_gpu_model, plot_model\n", 246 | "from keras.layers import concatenate, Input, Dropout\n", 247 | "from keras.models import Model, load_model, Sequential\n", 248 | "from keras.preprocessing.text import one_hot\n", 249 | "from keras.preprocessing.sequence import pad_sequences\n", 250 | "from keras.layers.embeddings import Embedding\n", 251 | "from keras.callbacks import Callback, History, ModelCheckpoint, EarlyStopping\n", 252 | "from keras import optimizers\n", 253 | "from keras import callbacks\n", 254 | "from keras.engine.topology import Layer\n", 255 | "from keras import initializers as initializers, regularizers, constraints\n", 256 | "from keras import backend as K" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": null, 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": null, 269 | "metadata": {}, 270 | "outputs": [], 271 | "source": [] 272 | }, 273 | { 274 | "cell_type": "markdown", 275 | "metadata": {}, 276 | "source": [ 277 | "# Config" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": null, 283 | "metadata": {}, 284 | "outputs": [], 285 | "source": [ 286 | "config = \\\n", 287 | "{\n", 288 | " \"patience\": 3 if TEST else 20, # \n", 289 | " # 3gramsFiltered, 2gramsFiltered, 1gramsFiltered, textSentences # \n", 290 | " # filtered_sentences, sentences\n", 291 | " \"dataCol\": \"filtered_sentences\", # \n", 292 | " \"wordVectorsPattern\": \"test\" if TEST else \"glove-840B\", # \n", 293 | " \"embeddingsDimension\": 100 if TEST else 300,\n", 294 | " \"minVocDF\": 2, # 10 if isNotebook else 2,\n", 295 | " \"minVocLF\": 2,\n", 296 | " \"minTokensLength\": 3,\n", 297 | " \"docLength\": 1200, # The size of documents representation (median * 3 = 600)\n", 298 | " \"isTrainableEmbeddings\": False,\n", 299 | " \"doMultiGPU\": False, # not isNotebook\n", 300 | " \"batchSize\": 32 if TEST else 128, # 128, 256 # \n", 301 | " \"epochs\": 10 if TEST else 1000,\n", 302 | " \"maxQueueSize\": 10 if TEST else 10,\n", 303 | " \"saveFinalModel\": False,\n", 304 | " \"doNotif\": True, # not isNotebook,\n", 305 | " \"attention\": False, # \n", 306 | " \"bidirectional\": False, # \n", 307 | " \"isCuDNN\": True, # \n", 308 | " \"saveMetrics\":\n", 309 | " {\n", 310 | " 
\"val_loss\": \"min\",\n", 311 | " \"val_acc\": \"max\",\n", 312 | " \"val_top_k_categorical_accuracy\": \"max\",\n", 313 | " },\n", 314 | " \"metrics\": ['accuracy', 'top_k_categorical_accuracy'], # ['sparse_categorical_accuracy', 'sparse_top_k_categorical_accuracy'], ['accuracy', 'top_k_categorical_accuracy']\n", 315 | " \"loss\": 'categorical_crossentropy', # sparse_categorical_crossentropy, categorical_crossentropy\n", 316 | " \"trainStepDivider\": 80 if TEST else 0, # 5 à la base......\n", 317 | " \"infiniteBatcherShuffle\": 0 if TEST else 0,\n", 318 | " \n", 319 | " \"inputEncoding\": \"embedding\", # index, embedding\n", 320 | " \"labelEncoding\": 'onehot', # onehot, index\n", 321 | " \n", 322 | " \"denseUnits\": [500, 100], # [100], [500, 100]\n", 323 | " \"rnnUnits\": 500,\n", 324 | " \n", 325 | " \"persist\": [False, False] if TEST else [False, False],\n", 326 | "}\n", 327 | "if config[\"isCuDNN\"]:\n", 328 | " del config[\"isCuDNN\"]" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": null, 334 | "metadata": {}, 335 | "outputs": [], 336 | "source": [ 337 | "# config[\"hostname\"] = getHostname()\n", 338 | "print(\"hostname: \" + str(getHostname()))" 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": null, 344 | "metadata": {}, 345 | "outputs": [], 346 | "source": [ 347 | "outputDirRoot = nosaveDir() + \"/asa2-train\"\n", 348 | "print(\"outputDirRoot: \" + outputDirRoot)" 349 | ] 350 | }, 351 | { 352 | "cell_type": "code", 353 | "execution_count": null, 354 | "metadata": {}, 355 | "outputs": [], 356 | "source": [ 357 | "config[\"trainPattern\"] = \"train\" # \n", 358 | "config[\"validationPattern\"] = \"validation\" # " 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": null, 364 | "metadata": {}, 365 | "outputs": [], 366 | "source": [ 367 | "dataDirectory = dataDir() + \"/Asa2/kset\"" 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "execution_count": null, 373 | "metadata": {}, 374 | "outputs": [], 375 | "source": [ 376 | "config[\"outputDir\"] = outputDirRoot + \"/\" + objectToHash(config)\n", 377 | "mkdir(config[\"outputDir\"])" 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "execution_count": null, 383 | "metadata": {}, 384 | "outputs": [], 385 | "source": [ 386 | "logger = Logger(config[\"outputDir\"] + \"/asa2-train.log\")" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": null, 392 | "metadata": { 393 | "scrolled": true 394 | }, 395 | "outputs": [], 396 | "source": [ 397 | "tt = TicToc(logger=logger)\n", 398 | "tt.tic()" 399 | ] 400 | }, 401 | { 402 | "cell_type": "code", 403 | "execution_count": null, 404 | "metadata": {}, 405 | "outputs": [], 406 | "source": [ 407 | "toJsonFile(toMongoStorable(config), config[\"outputDir\"] + \"/config.json\")" 408 | ] 409 | }, 410 | { 411 | "cell_type": "code", 412 | "execution_count": null, 413 | "metadata": {}, 414 | "outputs": [], 415 | "source": [ 416 | "log(lts(config), logger)" 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": null, 422 | "metadata": {}, 423 | "outputs": [], 424 | "source": [] 425 | }, 426 | { 427 | "cell_type": "code", 428 | "execution_count": null, 429 | "metadata": {}, 430 | "outputs": [], 431 | "source": [] 432 | }, 433 | { 434 | "cell_type": "markdown", 435 | "metadata": {}, 436 | "source": [ 437 | "# Loading of word embeddings" 438 | ] 439 | }, 440 | { 441 | "cell_type": "code", 442 | "execution_count": null, 443 | "metadata": {}, 444 | "outputs": [], 445 | 
"source": [ 446 | "emb = Embeddings(config[\"wordVectorsPattern\"], config[\"embeddingsDimension\"], verbose=True, logger=logger)" 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "execution_count": null, 452 | "metadata": {}, 453 | "outputs": [], 454 | "source": [ 455 | "wordEmbeddings = emb.getVectors()" 456 | ] 457 | }, 458 | { 459 | "cell_type": "code", 460 | "execution_count": null, 461 | "metadata": {}, 462 | "outputs": [], 463 | "source": [ 464 | "# We get the embedings dimension:\n", 465 | "embeddingsDimension = len(wordEmbeddings[\"the\"])" 466 | ] 467 | }, 468 | { 469 | "cell_type": "code", 470 | "execution_count": null, 471 | "metadata": {}, 472 | "outputs": [], 473 | "source": [ 474 | "tt.tic(\"We loaded word embeddings\")" 475 | ] 476 | }, 477 | { 478 | "cell_type": "code", 479 | "execution_count": null, 480 | "metadata": {}, 481 | "outputs": [], 482 | "source": [] 483 | }, 484 | { 485 | "cell_type": "code", 486 | "execution_count": null, 487 | "metadata": {}, 488 | "outputs": [], 489 | "source": [] 490 | }, 491 | { 492 | "cell_type": "markdown", 493 | "metadata": {}, 494 | "source": [ 495 | "# We get files" 496 | ] 497 | }, 498 | { 499 | "cell_type": "code", 500 | "execution_count": null, 501 | "metadata": {}, 502 | "outputs": [], 503 | "source": [ 504 | "trainFilesPath = sortedGlob(dataDirectory + \"/\" + config[\"trainPattern\"] + \"/*.bz2\")\n", 505 | "validationFilesPath = sortedGlob(dataDirectory + \"/\" + config[\"validationPattern\"] + \"/*.bz2\")\n", 506 | "assert len(trainFilesPath) > 0\n", 507 | "log(\"trainFilesPath:\\n\" + reducedLTS(trainFilesPath, 4), logger)\n", 508 | "log(\"validationFilesPath:\\n\" + reducedLTS(validationFilesPath, 4), logger)" 509 | ] 510 | }, 511 | { 512 | "cell_type": "code", 513 | "execution_count": null, 514 | "metadata": {}, 515 | "outputs": [], 516 | "source": [] 517 | }, 518 | { 519 | "cell_type": "code", 520 | "execution_count": null, 521 | "metadata": {}, 522 | "outputs": [], 523 | "source": [] 524 | }, 525 | { 526 | "cell_type": "markdown", 527 | "metadata": {}, 528 | "source": [ 529 | "# We search an amount of batches to skip" 530 | ] 531 | }, 532 | { 533 | "cell_type": "code", 534 | "execution_count": null, 535 | "metadata": {}, 536 | "outputs": [], 537 | "source": [ 538 | "trainInfiniteBatcherSkip = 0\n", 539 | "# In case we reume a previous train:\n", 540 | "if len(sortedGlob(config[\"outputDir\"] + \"/models/ep*\")) > 0:\n", 541 | " lastEpochPath = sortedGlob(config[\"outputDir\"] + \"/models/ep*\")[-1]\n", 542 | " log(\"We found an epoch to resume: \" + lastEpochPath, logger)\n", 543 | " batchesPassedFile = lastEpochPath + \"/batchesPassed.txt\"\n", 544 | " if isFile(batchesPassedFile):\n", 545 | " trainInfiniteBatcherSkip = int(fileToStr(batchesPassedFile))\n", 546 | " logWarning(\"We will skip \" + str(trainInfiniteBatcherSkip) + \" batches because we resume a previous train\", logger)" 547 | ] 548 | }, 549 | { 550 | "cell_type": "code", 551 | "execution_count": null, 552 | "metadata": {}, 553 | "outputs": [], 554 | "source": [] 555 | }, 556 | { 557 | "cell_type": "code", 558 | "execution_count": null, 559 | "metadata": {}, 560 | "outputs": [], 561 | "source": [] 562 | }, 563 | { 564 | "cell_type": "markdown", 565 | "metadata": {}, 566 | "source": [ 567 | "# We prepare data" 568 | ] 569 | }, 570 | { 571 | "cell_type": "code", 572 | "execution_count": null, 573 | "metadata": {}, 574 | "outputs": [], 575 | "source": [ 576 | "prebuilt = None\n", 577 | "prebuiltPath = config[\"outputDir\"] + \"/asap-prebuilt.pickle\"" 
578 | ] 579 | }, 580 | { 581 | "cell_type": "code", 582 | "execution_count": null, 583 | "metadata": {}, 584 | "outputs": [], 585 | "source": [ 586 | "if isFile(prebuiltPath):\n", 587 | " prebuilt = prebuiltPath\n", 588 | " log(\"We found \" + prebuilt, logger)\n", 589 | "else:\n", 590 | " log(\"We didn't found any asap prebuilt pickle file...\", logger)" 591 | ] 592 | }, 593 | { 594 | "cell_type": "code", 595 | "execution_count": null, 596 | "metadata": {}, 597 | "outputs": [], 598 | "source": [ 599 | "# We init AsaPreproc:\n", 600 | "asap = buildASAP\\\n", 601 | "(\n", 602 | " trainFilesPath,\n", 603 | " validationFilesPath,\n", 604 | " config[\"dataCol\"],\n", 605 | " \n", 606 | " minTokensLength=config[\"minTokensLength\"],\n", 607 | "\n", 608 | " batchSize=config[\"batchSize\"],\n", 609 | " minVocDF=config[\"minVocDF\"],\n", 610 | " minVocLF=config[\"minVocLF\"],\n", 611 | "\n", 612 | " wordEmbeddings=wordEmbeddings,\n", 613 | " \n", 614 | " persist=config[\"persist\"],\n", 615 | " \n", 616 | " docLength=config[\"docLength\"],\n", 617 | " \n", 618 | " prebuilt=prebuilt,\n", 619 | " \n", 620 | " logger=logger,\n", 621 | " verbose=True,\n", 622 | " \n", 623 | " labelEncoding=config[\"labelEncoding\"],\n", 624 | " encoding=config[\"inputEncoding\"],\n", 625 | ")" 626 | ] 627 | }, 628 | { 629 | "cell_type": "code", 630 | "execution_count": null, 631 | "metadata": {}, 632 | "outputs": [], 633 | "source": [ 634 | "# We serialize all:\n", 635 | "asap.serializePrebuilt(prebuiltPath)" 636 | ] 637 | }, 638 | { 639 | "cell_type": "code", 640 | "execution_count": null, 641 | "metadata": {}, 642 | "outputs": [], 643 | "source": [ 644 | "tt.tic(\"All train and validation data are ready\")" 645 | ] 646 | }, 647 | { 648 | "cell_type": "code", 649 | "execution_count": null, 650 | "metadata": {}, 651 | "outputs": [], 652 | "source": [] 653 | }, 654 | { 655 | "cell_type": "code", 656 | "execution_count": null, 657 | "metadata": {}, 658 | "outputs": [], 659 | "source": [] 660 | }, 661 | { 662 | "cell_type": "markdown", 663 | "metadata": {}, 664 | "source": [ 665 | "# We define the model" 666 | ] 667 | }, 668 | { 669 | "cell_type": "code", 670 | "execution_count": null, 671 | "metadata": {}, 672 | "outputs": [], 673 | "source": [ 674 | "opt = optimizers.Adam(clipnorm=1.0)" 675 | ] 676 | }, 677 | { 678 | "cell_type": "code", 679 | "execution_count": null, 680 | "metadata": {}, 681 | "outputs": [], 682 | "source": [ 683 | "initialEpoch = 0\n", 684 | "if len(sortedGlob(config[\"outputDir\"] + \"/models/ep*\")) > 0:\n", 685 | " logWarning(\"#\" * 20 + \" We will resume a previous train \" + \"#\" * 20, logger)\n", 686 | " lastEpochPath = sortedGlob(config[\"outputDir\"] + \"/models/ep*\")[-1]\n", 687 | " initialEpoch = getFirstNumber(decomposePath(lastEpochPath)[1]) + 1\n", 688 | " assert not isFile(config[\"outputDir\"] + \"/finished\")\n", 689 | " # To load the model we first build it:\n", 690 | " modelKwargs = fromJsonFile(lastEpochPath + \"/kwargs.json\")\n", 691 | " assert modelKwargs[\"docLength\"] == asap.getDocLength()\n", 692 | " assert modelKwargs[\"vocSize\"] == len(asap.getVocIndex())\n", 693 | " assert modelKwargs[\"nbClasses\"] == len(asap.getLabelEncoder())\n", 694 | " if \"embeddingsDimension\" not in modelKwargs:\n", 695 | " modelKwargs[\"embeddingsDimension\"] = asap.getEmbeddingsDimension() if config[\"inputEncoding\"] == \"embedding\" else None\n", 696 | " if \"embeddingMatrix\" not in modelKwargs:\n", 697 | " modelKwargs[\"embeddingMatrix\"] = asap.getEmbeddingMatrix() if 
config[\"inputEncoding\"] == \"index\" else None\n", 698 | " (originalModel, modelKwargs, modelScript) = buildRNN\\\n", 699 | " (\n", 700 | " logger=logger,\n", 701 | " verbose=True,\n", 702 | " **modelKwargs,\n", 703 | " )\n", 704 | " # Then we load weights:\n", 705 | " if isFile(lastEpochPath + \"/model.h5\"):\n", 706 | " originalModel = load_model(lastEpochPath + \"/model.h5\")\n", 707 | " else:\n", 708 | " originalModel.load_weights(lastEpochPath + \"/weights.h5\") # WARNING we loose optimizer states\n", 709 | " # Finally we compile it:\n", 710 | " originalModel.compile(loss=config[\"loss\"], optimizer=opt, metrics=config[\"metrics\"])\n", 711 | " parallelModel = None\n", 712 | " model = originalModel\n", 713 | " tt.tic(\"We loaded the model to resume training. Initial epoch: \" + str(initialEpoch), logger)\n", 714 | "else:\n", 715 | " (originalModel, modelKwargs, modelScript) = buildRNN\\\n", 716 | " (\n", 717 | " docLength=asap.getDocLength(),\n", 718 | " vocSize=len(asap.getVocIndex()),\n", 719 | " nbClasses=len(asap.getLabelEncoder()),\n", 720 | " isEmbeddingsTrainable=config[\"isTrainableEmbeddings\"],\n", 721 | " \n", 722 | " denseUnits=config[\"denseUnits\"],\n", 723 | " rnnUnits=config[\"rnnUnits\"],\n", 724 | " \n", 725 | " embSpacialDropout=0.2,\n", 726 | " firstDropout=0.2,\n", 727 | " recurrentDropout=0.2,\n", 728 | " attentionDropout=0.2,\n", 729 | " denseDropout=0.2,\n", 730 | " useRNNDropout=True,\n", 731 | " \n", 732 | " isBidirectional=config[\"bidirectional\"],\n", 733 | " \n", 734 | " isCuDNN='isCuDNN' not in config or config[\"isCuDNN\"],\n", 735 | " \n", 736 | " rnnType='LSTM',\n", 737 | " addAttention=config[\"attention\"],\n", 738 | " \n", 739 | " bnAfterEmbedding=False,\n", 740 | " bnAfterRNN=False,\n", 741 | " bnAfterAttention=False,\n", 742 | " bnAfterDenses=False,\n", 743 | " bnAfterLast=False,\n", 744 | " bnBeforeActivation=True,\n", 745 | "\n", 746 | " embeddingMatrix=asap.getEmbeddingMatrix() if config[\"inputEncoding\"] == \"index\" else None,\n", 747 | " logger=logger,\n", 748 | " verbose=True,\n", 749 | " \n", 750 | " embeddingsDimension=asap.getEmbeddingsDimension() if config[\"inputEncoding\"] == \"embedding\" else None,\n", 751 | " )\n", 752 | " parallelModel = None\n", 753 | " model = originalModel\n", 754 | " model.compile(loss=config[\"loss\"], optimizer=opt, metrics=config[\"metrics\"])" 755 | ] 756 | }, 757 | { 758 | "cell_type": "code", 759 | "execution_count": null, 760 | "metadata": {}, 761 | "outputs": [], 762 | "source": [ 763 | "model.summary()" 764 | ] 765 | }, 766 | { 767 | "cell_type": "code", 768 | "execution_count": null, 769 | "metadata": {}, 770 | "outputs": [], 771 | "source": [ 772 | "strToFile(originalModel.to_json(), config[\"outputDir\"] + \"/model.json\")" 773 | ] 774 | }, 775 | { 776 | "cell_type": "code", 777 | "execution_count": null, 778 | "metadata": {}, 779 | "outputs": [], 780 | "source": [] 781 | }, 782 | { 783 | "cell_type": "code", 784 | "execution_count": null, 785 | "metadata": {}, 786 | "outputs": [], 787 | "source": [] 788 | }, 789 | { 790 | "cell_type": "markdown", 791 | "metadata": {}, 792 | "source": [ 793 | "# We define metrics callbacks" 794 | ] 795 | }, 796 | { 797 | "cell_type": "code", 798 | "execution_count": null, 799 | "metadata": { 800 | "scrolled": false 801 | }, 802 | "outputs": [], 803 | "source": [ 804 | "mainCallback = KerasCallback\\\n", 805 | "(\n", 806 | " originalModel,\n", 807 | " logger=logger,\n", 808 | " graphDir=config[\"outputDir\"] + \"/graphs\",\n", 809 | " 
modelsDir=config[\"outputDir\"] + \"/models\",\n", 810 | " doNotif=False,\n", 811 | " saveMetrics=config[\"saveMetrics\"],\n", 812 | " doPltShow=isNotebook,\n", 813 | " historyFile=config[\"outputDir\"] + \"/history.json\",\n", 814 | " initialEpoch=initialEpoch,\n", 815 | " stopFile=config[\"outputDir\"] + \"/stop\",\n", 816 | " earlyStopMonitor=\\\n", 817 | " {\n", 818 | " 'val_loss': {'patience': config[\"patience\"]},\n", 819 | " 'val_acc': {'patience': config[\"patience\"]},\n", 820 | " 'val_top_k_categorical_accuracy': {'patience': config[\"patience\"]},\n", 821 | " },\n", 822 | " batchesPassed=trainInfiniteBatcherSkip,\n", 823 | " batchesAmount=asap.getBatchesCount(0),\n", 824 | " \n", 825 | " saveFunct=saveModel,\n", 826 | " saveFunctKwargs=\\\n", 827 | " {\n", 828 | " \"makeSubDir\": False, \"kwargs\": modelKwargs, \"script\": modelScript,\n", 829 | " \"extraInfos\":\\\n", 830 | " {\n", 831 | " \"wordVectorsPattern\": config[\"wordVectorsPattern\"],\n", 832 | " \"embeddingsDimension\": config[\"embeddingsDimension\"],\n", 833 | " \"minVocDF\": config[\"minVocDF\"],\n", 834 | " \"minVocLF\": config[\"minVocLF\"],\n", 835 | " \"dataCol\": config[\"dataCol\"],\n", 836 | " \"labelEncoding\": config[\"labelEncoding\"],\n", 837 | " \"inputEncoding\": config[\"inputEncoding\"],\n", 838 | " \"trainStepDivider\": config[\"trainStepDivider\"],\n", 839 | " \"infiniteBatcherShuffle\": config[\"infiniteBatcherShuffle\"],\n", 840 | " },\n", 841 | " },\n", 842 | ")" 843 | ] 844 | }, 845 | { 846 | "cell_type": "code", 847 | "execution_count": null, 848 | "metadata": {}, 849 | "outputs": [], 850 | "source": [] 851 | }, 852 | { 853 | "cell_type": "code", 854 | "execution_count": null, 855 | "metadata": {}, 856 | "outputs": [], 857 | "source": [] 858 | }, 859 | { 860 | "cell_type": "markdown", 861 | "metadata": {}, 862 | "source": [ 863 | "# We train the model" 864 | ] 865 | }, 866 | { 867 | "cell_type": "code", 868 | "execution_count": null, 869 | "metadata": {}, 870 | "outputs": [], 871 | "source": [ 872 | "stepsPerEpoch = asap.getBatchesCount(0)\n", 873 | "if dictContains(config, \"trainStepDivider\") and config[\"trainStepDivider\"] > 1:\n", 874 | " stepsPerEpoch = math.ceil(stepsPerEpoch / config[\"trainStepDivider\"])\n", 875 | " log(\"The stepsPerEpoch was \" + str(asap.getBatchesCount(0)) + \" but now is \" + str(stepsPerEpoch), logger)" 876 | ] 877 | }, 878 | { 879 | "cell_type": "code", 880 | "execution_count": null, 881 | "metadata": { 882 | "scrolled": true 883 | }, 884 | "outputs": [], 885 | "source": [ 886 | "log(\"We launch fit_generator\", logger)\n", 887 | "asap.verbose = True # if TEST else False\n", 888 | "history = model.fit_generator\\\n", 889 | "(\n", 890 | " asap.getInfiniteBatcher(0, shuffle=config[\"infiniteBatcherShuffle\"], skip=trainInfiniteBatcherSkip),\n", 891 | " steps_per_epoch=stepsPerEpoch,\n", 892 | " validation_data=asap.getInfiniteBatcher(1),\n", 893 | " validation_steps=asap.getBatchesCount(1) / 40 if TEST else asap.getBatchesCount(1),\n", 894 | " epochs=config[\"epochs\"],\n", 895 | " verbose=1,\n", 896 | " max_queue_size=config[\"maxQueueSize\"],\n", 897 | " callbacks=[mainCallback, callbacks.TerminateOnNaN()],\n", 898 | " initial_epoch=initialEpoch,\n", 899 | ")" 900 | ] 901 | }, 902 | { 903 | "cell_type": "code", 904 | "execution_count": null, 905 | "metadata": {}, 906 | "outputs": [], 907 | "source": [ 908 | "mainCallback.logHistory()\n", 909 | "log(\"Best val_loss: \" + str(min(mainCallback.history[\"val_loss\"])), logger)\n", 910 | "log(\"Best val_acc: \" 
+ str(max(mainCallback.history[\"val_acc\"])), logger)\n", 911 | "log(\"Best val_top_k_acc: \" + str(max(mainCallback.history[\"val_top_k_categorical_accuracy\"])), logger)\n", 912 | "log(\"Nb epochs: \" + str(len(mainCallback.history[\"val_loss\"])), logger)" 913 | ] 914 | }, 915 | { 916 | "cell_type": "code", 917 | "execution_count": null, 918 | "metadata": {}, 919 | "outputs": [], 920 | "source": [ 921 | "if not isNotebook:\n", 922 | " notif(\"Training done\", lts(config))" 923 | ] 924 | }, 925 | { 926 | "cell_type": "code", 927 | "execution_count": null, 928 | "metadata": {}, 929 | "outputs": [], 930 | "source": [] 931 | }, 932 | { 933 | "cell_type": "code", 934 | "execution_count": null, 935 | "metadata": {}, 936 | "outputs": [], 937 | "source": [] 938 | }, 939 | { 940 | "cell_type": "markdown", 941 | "metadata": {}, 942 | "source": [ 943 | "# We save the model and info" 944 | ] 945 | }, 946 | { 947 | "cell_type": "code", 948 | "execution_count": null, 949 | "metadata": {}, 950 | "outputs": [], 951 | "source": [ 952 | "try:\n", 953 | " # toJsonFile(history.history, config[\"outputDir\"] + \"/history.json\")\n", 954 | " toJsonFile({\"epochs\": mainCallback.epochs, \"history\": mainCallback.history},\n", 955 | " config[\"outputDir\"] + \"/history.json\")\n", 956 | "except Exception as e:\n", 957 | " logException(e, logger)" 958 | ] 959 | }, 960 | { 961 | "cell_type": "code", 962 | "execution_count": null, 963 | "metadata": {}, 964 | "outputs": [], 965 | "source": [ 966 | "if config[\"saveFinalModel\"]:\n", 967 | " finalModelDirectory = config[\"outputDir\"] + \"/final-model\"\n", 968 | " mkdir(finalModelDirectory)\n", 969 | " originalModel.save(finalModelDirectory + '/model.h5')" 970 | ] 971 | }, 972 | { 973 | "cell_type": "code", 974 | "execution_count": null, 975 | "metadata": {}, 976 | "outputs": [], 977 | "source": [ 978 | "touch(config[\"outputDir\"] + \"/finished\")" 979 | ] 980 | }, 981 | { 982 | "cell_type": "code", 983 | "execution_count": null, 984 | "metadata": {}, 985 | "outputs": [], 986 | "source": [] 987 | }, 988 | { 989 | "cell_type": "code", 990 | "execution_count": null, 991 | "metadata": {}, 992 | "outputs": [], 993 | "source": [] 994 | }, 995 | { 996 | "cell_type": "markdown", 997 | "metadata": {}, 998 | "source": [ 999 | "# Attention" 1000 | ] 1001 | }, 1002 | { 1003 | "cell_type": "code", 1004 | "execution_count": null, 1005 | "metadata": {}, 1006 | "outputs": [], 1007 | "source": [ 1008 | "if isNotebook:\n", 1009 | " asap.setParallelProcesses(1) # For consistency" 1010 | ] 1011 | }, 1012 | { 1013 | "cell_type": "code", 1014 | "execution_count": null, 1015 | "metadata": {}, 1016 | "outputs": [], 1017 | "source": [ 1018 | "if isNotebook:\n", 1019 | " # We get val encoded tokens:\n", 1020 | " valEncodedTokens = np.array([tokens for tokens, encodedAd in asap.getPart(1)])\n", 1021 | " # We get documents non-lowered and non-flattened, with masks (pre-padding):\n", 1022 | " valPaddedDocs = [tokens for tokens, _ in asap.getRawPart(1, pad=True)]\n", 1023 | " # We get all encoded labels:\n", 1024 | " valEncodedLabels = np.array([encodedAd for tokens, encodedAd in asap.getPart(1)])\n", 1025 | " # We get all labels:\n", 1026 | " valLabels = [ad for tokens, ad in asap.getRawPart(1)]\n", 1027 | " # We get the dict label -> encodedLabel (we can also use asap.encodedLabelToLabel(encodedLabel)):\n", 1028 | " encodedAds = asap.getLabelEncoder()\n", 1029 | " # We display it:\n", 1030 | " bp(valPaddedDocs)" 1031 | ] 1032 | }, 1033 | { 1034 | "cell_type": "code", 1035 | 
"execution_count": null, 1036 | "metadata": {}, 1037 | "outputs": [], 1038 | "source": [ 1039 | "if isNotebook:\n", 1040 | " # We get predictions:\n", 1041 | " predictionsAsSoftmax = model.predict(valEncodedTokens)\n", 1042 | " # We convert them to encoded labels according to the max probability of softmax vectors:\n", 1043 | " predictionsAsEncodedLabel = []\n", 1044 | " for i in range(len(predictionsAsSoftmax)):\n", 1045 | " predSoftmax = predictionsAsSoftmax[i]\n", 1046 | " predEncodedLabel = np.zeros(len(predSoftmax))\n", 1047 | " predEncodedLabel[np.argmax(predSoftmax)] = 1\n", 1048 | " predictionsAsEncodedLabel.append(predEncodedLabel)\n", 1049 | " # And we convert all to labels:\n", 1050 | " predictionsAsLabel = [asap.decodeLabel(enc) for enc in predictionsAsEncodedLabel]\n", 1051 | " # We display it:\n", 1052 | " bp(predictionsAsLabel)" 1053 | ] 1054 | }, 1055 | { 1056 | "cell_type": "code", 1057 | "execution_count": null, 1058 | "metadata": {}, 1059 | "outputs": [], 1060 | "source": [ 1061 | "if isNotebook:\n", 1062 | " # We get attentions:\n", 1063 | " attentions = getAttentions(model, valEncodedTokens)" 1064 | ] 1065 | }, 1066 | { 1067 | "cell_type": "code", 1068 | "execution_count": null, 1069 | "metadata": {}, 1070 | "outputs": [], 1071 | "source": [ 1072 | "if isNotebook:\n", 1073 | " # Now we compute the accuracy:\n", 1074 | " wellClassifiedCount = 0\n", 1075 | " for i in range(len(predictionsAsLabel)):\n", 1076 | " if valLabels[i] == predictionsAsLabel[i]:\n", 1077 | " wellClassifiedCount += 1\n", 1078 | " print(\"Accuracy: \" + str(truncateFloat(wellClassifiedCount / len(predictionsAsLabel) * 100.0, 2)))" 1079 | ] 1080 | }, 1081 | { 1082 | "cell_type": "code", 1083 | "execution_count": null, 1084 | "metadata": {}, 1085 | "outputs": [], 1086 | "source": [ 1087 | "if isNotebook:\n", 1088 | " # And finally we print some attentions:\n", 1089 | " for i in range(10, 50):\n", 1090 | " doc = valPaddedDocs[i]\n", 1091 | " label = valLabels[i]\n", 1092 | " predLabel = predictionsAsLabel[i]\n", 1093 | " okToken = \"==> OK <==\" if label == predLabel else \"==> FAIL <==\"\n", 1094 | " print(okToken + \" Prediction: \" + str(predLabel)[:30] + \", Ground truth: \" + label[:30])\n", 1095 | " attention = attentions[i]\n", 1096 | " showAttentionMap(doc, attention)" 1097 | ] 1098 | }, 1099 | { 1100 | "cell_type": "code", 1101 | "execution_count": null, 1102 | "metadata": {}, 1103 | "outputs": [], 1104 | "source": [] 1105 | }, 1106 | { 1107 | "cell_type": "code", 1108 | "execution_count": null, 1109 | "metadata": {}, 1110 | "outputs": [], 1111 | "source": [] 1112 | }, 1113 | { 1114 | "cell_type": "markdown", 1115 | "metadata": {}, 1116 | "source": [ 1117 | "# End" 1118 | ] 1119 | }, 1120 | { 1121 | "cell_type": "code", 1122 | "execution_count": null, 1123 | "metadata": {}, 1124 | "outputs": [], 1125 | "source": [ 1126 | "if config[\"doNotif\"]:\n", 1127 | " notif(\"LARGE training done on \" + getHostname())" 1128 | ] 1129 | }, 1130 | { 1131 | "cell_type": "code", 1132 | "execution_count": null, 1133 | "metadata": {}, 1134 | "outputs": [], 1135 | "source": [ 1136 | "tt.toc()" 1137 | ] 1138 | }, 1139 | { 1140 | "cell_type": "code", 1141 | "execution_count": null, 1142 | "metadata": {}, 1143 | "outputs": [], 1144 | "source": [] 1145 | }, 1146 | { 1147 | "cell_type": "code", 1148 | "execution_count": null, 1149 | "metadata": {}, 1150 | "outputs": [], 1151 | "source": [] 1152 | } 1153 | ], 1154 | "metadata": { 1155 | "kernelspec": { 1156 | "display_name": "Python 3", 1157 | "language": "python", 
1158 | "name": "python3" 1159 | }, 1160 | "language_info": { 1161 | "codemirror_mode": { 1162 | "name": "ipython", 1163 | "version": 3 1164 | }, 1165 | "file_extension": ".py", 1166 | "mimetype": "text/x-python", 1167 | "name": "python", 1168 | "nbconvert_exporter": "python", 1169 | "pygments_lexer": "ipython3", 1170 | "version": "3.6.3" 1171 | } 1172 | }, 1173 | "nbformat": 4, 1174 | "nbformat_minor": 2 1175 | } 1176 | -------------------------------------------------------------------------------- /experiments/notebooks/tfidf-focus.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Commands" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "# jupython --venv st-venv ~/notebooks/DeepLearning/asamin-dual.ipynb" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "# titanv 1:\n", 26 | "# screen -S asamin-dual1\n", 27 | "# source ~/.bash_profile ; source ~/.bash_aliases ; cd ~/misc-logs\n", 28 | "# DOCKER_PORT=9961 nn -o nohup-dual-$HOSTNAME-1.out ~/docker/keras/run-jupython.sh ~/notebooks/asa/eval/asamin-dual.ipynb titanv\n", 29 | "# observe ~/misc-logs/nohup-dual-$HOSTNAME-1.out" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "# titanv 2:\n", 39 | "# screen -S asamin-dual2\n", 40 | "# source ~/.bash_profile ; source ~/.bash_aliases ; cd ~/misc-logs\n", 41 | "# DOCKER_PORT=9962 nn -o nohup-dual-$HOSTNAME-2.out ~/docker/keras/run-jupython.sh ~/notebooks/asa/eval/asamin-dual.ipynb titanv\n", 42 | "# observe ~/misc-logs/nohup-dual-$HOSTNAME-2.out" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": {}, 62 | "source": [ 63 | "# Init" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "from machinelearning.seeder import *\n", 73 | "seed()" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "isNotebook = '__file__' not in locals()" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "metadata": {}, 89 | "outputs": [], 90 | "source": [ 91 | "TEST = False # isNotebook, False, True" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "import os\n", 101 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "metadata": {}, 121 | "source": [ 122 | "# Imports" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | 
"from newssource.asattribution.utils import *\n", 132 | "from newssource.asattribution.asamin import *\n", 133 | "from newssource.asa.asapreproc import *\n", 134 | "from newssource.asa.models import *\n", 135 | "from newssource.metrics.ndcg import *" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "import matplotlib\n", 145 | "import numpy as np\n", 146 | "import matplotlib.pyplot as plt\n", 147 | "if not isNotebook:\n", 148 | " matplotlib.use('Agg')\n", 149 | "import random\n", 150 | "import time\n", 151 | "import pickle\n", 152 | "import copy\n", 153 | "from hashlib import md5" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "metadata": { 160 | "scrolled": false 161 | }, 162 | "outputs": [], 163 | "source": [ 164 | "from keras import backend as K\n", 165 | "K.tensorflow_backend._get_available_gpus()" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "from numpy import array\n", 175 | "from keras.preprocessing.text import one_hot\n", 176 | "from keras.preprocessing.sequence import pad_sequences\n", 177 | "from keras.models import Sequential\n", 178 | "from keras.layers import Dense, LSTM, Flatten\n", 179 | "from keras.layers.embeddings import Embedding\n", 180 | "from keras.models import load_model\n", 181 | "from keras.utils import multi_gpu_model" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "metadata": {}, 188 | "outputs": [], 189 | "source": [ 190 | "from sklearn.model_selection import train_test_split\n", 191 | "from sklearn.feature_extraction.text import TfidfVectorizer\n", 192 | "from sklearn.naive_bayes import MultinomialNB\n", 193 | "from sklearn import metrics\n", 194 | "from sklearn.model_selection import cross_val_score\n", 195 | "from sklearn.dummy import DummyClassifier\n", 196 | "from sklearn.model_selection import KFold, StratifiedKFold" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": null, 202 | "metadata": {}, 203 | "outputs": [], 204 | "source": [ 205 | "from gensim.test.utils import common_texts\n", 206 | "from gensim.models.doc2vec import Doc2Vec, TaggedDocument" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": null, 212 | "metadata": {}, 213 | "outputs": [], 214 | "source": [ 215 | "from random import random\n", 216 | "from numpy import array\n", 217 | "from numpy import cumsum\n", 218 | "from keras.models import Sequential\n", 219 | "from keras.layers import LSTM, GRU\n", 220 | "from keras.layers import Dense\n", 221 | "from keras.layers import TimeDistributed\n", 222 | "from keras.utils import multi_gpu_model" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "metadata": {}, 229 | "outputs": [], 230 | "source": [ 231 | "import statistics" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "metadata": {}, 238 | "outputs": [], 239 | "source": [ 240 | "from machinelearning.baseline import *\n", 241 | "from machinelearning.encoder import *\n", 242 | "from machinelearning.kerasutils import *\n", 243 | "from machinelearning.kerasmodels import *" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": null, 249 | "metadata": {}, 250 | "outputs": [], 251 | "source": [ 252 | "from machinelearning.baseline import *\n", 253 | "from 
machinelearning.encoder import *\n", 254 | "from machinelearning import kerasutils\n", 255 | "from machinelearning.iterator import *\n", 256 | "from machinelearning.metrics import *" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": null, 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [ 265 | "from keras.layers import concatenate, Input\n", 266 | "from keras.models import Model\n", 267 | "from keras.utils import plot_model" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": null, 273 | "metadata": {}, 274 | "outputs": [], 275 | "source": [ 276 | "import scipy" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "metadata": {}, 283 | "outputs": [], 284 | "source": [ 285 | "from keras.layers import LSTM, GRU, Dense, CuDNNLSTM, CuDNNGRU, TimeDistributed, Flatten, concatenate\n", 286 | "from keras.utils import multi_gpu_model, plot_model\n", 287 | "from keras.layers import concatenate, Input\n", 288 | "from keras.models import Model, load_model, Sequential\n", 289 | "from keras.preprocessing.text import one_hot\n", 290 | "from keras.preprocessing.sequence import pad_sequences\n", 291 | "from keras.layers.embeddings import Embedding\n", 292 | "from keras.callbacks import Callback, History, ModelCheckpoint" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": null, 298 | "metadata": {}, 299 | "outputs": [], 300 | "source": [] 301 | }, 302 | { 303 | "cell_type": "code", 304 | "execution_count": null, 305 | "metadata": {}, 306 | "outputs": [], 307 | "source": [] 308 | }, 309 | { 310 | "cell_type": "markdown", 311 | "metadata": {}, 312 | "source": [ 313 | "# Modules reloading" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": null, 319 | "metadata": {}, 320 | "outputs": [], 321 | "source": [ 322 | "from importlib import reload" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": null, 328 | "metadata": {}, 329 | "outputs": [], 330 | "source": [ 331 | "def reloadKerasutils():\n", 332 | " global kerasutils\n", 333 | " kerasutils = reload(kerasutils)" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": null, 339 | "metadata": {}, 340 | "outputs": [], 341 | "source": [] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": null, 346 | "metadata": {}, 347 | "outputs": [], 348 | "source": [] 349 | }, 350 | { 351 | "cell_type": "markdown", 352 | "metadata": {}, 353 | "source": [ 354 | "# Misc vars" 355 | ] 356 | }, 357 | { 358 | "cell_type": "code", 359 | "execution_count": null, 360 | "metadata": {}, 361 | "outputs": [], 362 | "source": [ 363 | "logger = Logger(tmpDir() + \"/asamin-dual.log\" if hjlat() else \"asamin-dual.log\")" 364 | ] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "execution_count": null, 369 | "metadata": {}, 370 | "outputs": [], 371 | "source": [ 372 | "tt = TicToc(logger=logger)\n", 373 | "tt.tic()" 374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "execution_count": null, 379 | "metadata": {}, 380 | "outputs": [], 381 | "source": [] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": null, 386 | "metadata": {}, 387 | "outputs": [], 388 | "source": [] 389 | }, 390 | { 391 | "cell_type": "markdown", 392 | "metadata": {}, 393 | "source": [ 394 | "# Results" 395 | ] 396 | }, 397 | { 398 | "cell_type": "code", 399 | "execution_count": null, 400 | "metadata": {}, 401 | "outputs": [], 402 | "source": [ 403 | "results = dict()" 404 | ] 405 | }, 
406 | { 407 | "cell_type": "code", 408 | "execution_count": null, 409 | "metadata": {}, 410 | "outputs": [], 411 | "source": [ 412 | "def addResult(obj, *args, doPrint=True):\n", 413 | " global config\n", 414 | " global results\n", 415 | " if isinstance(obj, float):\n", 416 | " obj = {\"score\": obj}\n", 417 | " if len(args) == 1:\n", 418 | " if isinstance(args[0], list):\n", 419 | " path = args[0]\n", 420 | " else:\n", 421 | " path = [args]\n", 422 | " else:\n", 423 | " path = args\n", 424 | " currentResult = results\n", 425 | " for key in path[:-1]:\n", 426 | " if key not in currentResult:\n", 427 | " currentResult[key] = dict()\n", 428 | " currentResult = currentResult[key]\n", 429 | " key = path[-1]\n", 430 | " if key not in currentResult:\n", 431 | " currentResult[key] = []\n", 432 | " currentResult = currentResult[key]\n", 433 | " localConf = copy.deepcopy(config)\n", 434 | " for k in [\"doNotif\", \"doMultiGPU\", \"doFlattenSentences\", \"filterNonWordOrPunct\", \"punct\", \"outputDir\"]:\n", 435 | " if k in localConf:\n", 436 | " del localConf[k]\n", 437 | " currentResult.append(mergeDicts(obj, {\"config\": localConf}))\n", 438 | " if doPrint:\n", 439 | " printResults()" 440 | ] 441 | }, 442 | { 443 | "cell_type": "code", 444 | "execution_count": null, 445 | "metadata": {}, 446 | "outputs": [], 447 | "source": [ 448 | "def printResults():\n", 449 | " global results\n", 450 | " global logger\n", 451 | " localRes = copy.deepcopy(results)\n", 452 | " def recDel(d):\n", 453 | " if isinstance(d, dict):\n", 454 | " if \"config\" in d:\n", 455 | " del d[\"config\"]\n", 456 | " for key in d.keys():\n", 457 | " d[key] = recDel(d[key])\n", 458 | " elif isinstance(d, list):\n", 459 | " for i in range(len(d)):\n", 460 | " d[i] = recDel(d[i])\n", 461 | " return d\n", 462 | " localRes = recDel(localRes)\n", 463 | " log(\"results without configs:\\n\" + lts(localRes), logger)" 464 | ] 465 | }, 466 | { 467 | "cell_type": "code", 468 | "execution_count": null, 469 | "metadata": {}, 470 | "outputs": [], 471 | "source": [ 472 | "def saveState():\n", 473 | " toJsonFile(mongoStorable(config), config[\"outputDir\"] + \"/config.json\")\n", 474 | " toJsonFile(results, config[\"outputDir\"] + \"/results.json\")" 475 | ] 476 | }, 477 | { 478 | "cell_type": "code", 479 | "execution_count": null, 480 | "metadata": {}, 481 | "outputs": [], 482 | "source": [] 483 | }, 484 | { 485 | "cell_type": "code", 486 | "execution_count": null, 487 | "metadata": {}, 488 | "outputs": [], 489 | "source": [] 490 | }, 491 | { 492 | "cell_type": "markdown", 493 | "metadata": {}, 494 | "source": [ 495 | "# Config" 496 | ] 497 | }, 498 | { 499 | "cell_type": "code", 500 | "execution_count": null, 501 | "metadata": {}, 502 | "outputs": [], 503 | "source": [ 504 | "config = dict()" 505 | ] 506 | }, 507 | { 508 | "cell_type": "code", 509 | "execution_count": null, 510 | "metadata": {}, 511 | "outputs": [], 512 | "source": [ 513 | "config = mergeDicts(config, \\\n", 514 | "{\n", 515 | " # \n", 516 | " \"dataCol\": \"filtered_sentences\", # filtered_sentences, sentences, usetSentences, 3gramsFiltered, 2gramsFiltered, 1gramsFiltered, textSentences\n", 517 | " \"minTokensLength\": 3,\n", 518 | " \"minVocDF\": 2,\n", 519 | " \"minVocLF\": 2,\n", 520 | " \n", 521 | " # \n", 522 | " # \"docLength\": 1200,\n", 523 | " \n", 524 | " # \n", 525 | " # \"wordVectorsPattern\": \"test\", # glove-6B, fasttext, glove-840B.300d\n", 526 | " # \"embeddingsDimension\": 100,\n", 527 | " # \"doLower\": True,\n", 528 | " \n", 529 | " # \n", 530 | " # 
ASA2:\n", 531 | " \"transfer\": {\"dirName\": \"6ebdd3e05d4388c658ca2d5c53b0bc36\", \"epoch\": 35, \"aaPop\": 3, \"srLayerId\": -8}, # filtered_sentences, epochs 35 (???) 33 (???)\n", 532 | " # \"transfer\": {\"dirName\": \"7fa22619af09e90724d2e3a8cf5db796\", \"epoch\": 26, \"aaPop\": 3, \"srLayerId\": -8}, # textSentences\n", 533 | "\n", 534 | " # \n", 535 | " # \"scoring\": \"accuracy\",\n", 536 | " # \"inputEncoding\": \"index\", # index, embedding # \n", 537 | " # \"cv\": 2 if TEST else 10,\n", 538 | " # \"patience\": 0 if TEST else 30, # 30\n", 539 | " # \"batchSize\": 32 if TEST else 32, # 128 doesn't work here, don't know why...\n", 540 | " # \"doNotif\": not isNotebook,\n", 541 | " \n", 542 | " # uset0-l50-dpl50-d18-bc10, asa, bchack (others: blogcorpus, c50)\n", 543 | " # uset0-l50-dpl200-d18-bc10, uset1-l50-dpl200-d18-bc10, uset2-l50-dpl200-d18-bc10, uset3-l50-dpl200-d18-bc10, uset4-l50-dpl200-d18-bc10\n", 544 | " # uset0-l50-dpl160-blogger.com, uset1-l50-dpl160-blogger.com, uset2-l50-dpl160-blogger.com, uset3-l50-dpl160-blogger.com, uset4-l50-dpl160-blogger.com\n", 545 | " \"datasetName\": \"uset4-l50-dpl160-blogger.com\",\n", 546 | " \n", 547 | " \n", 548 | " \"usePretrained\": True,\n", 549 | " \"batchSize\": 128,\n", 550 | " \"patience\": 60,\n", 551 | " \"epochs\": 500,\n", 552 | " \"loss\": 'categorical_crossentropy', # sparse_categorical_crossentropy, categorical_crossentropy\n", 553 | " \"metrics\": ['accuracy', 'top_k_categorical_accuracy'],\n", 554 | " \"saveMetrics\":\n", 555 | " {\n", 556 | " \"val_loss\": \"min\",\n", 557 | " \"val_acc\": \"max\",\n", 558 | " \"val_top_k_categorical_accuracy\": \"max\",\n", 559 | " },\n", 560 | " \"dropout\": 0.2,\n", 561 | " \"denseUnits\": 100,\n", 562 | "})" 563 | ] 564 | }, 565 | { 566 | "cell_type": "code", 567 | "execution_count": null, 568 | "metadata": {}, 569 | "outputs": [], 570 | "source": [ 571 | "config[\"outputDir\"] = nosaveDir() + \"/asa2-focus/\" + objectToHash(config)" 572 | ] 573 | }, 574 | { 575 | "cell_type": "code", 576 | "execution_count": null, 577 | "metadata": {}, 578 | "outputs": [], 579 | "source": [ 580 | "if TEST:\n", 581 | " (dir, filename, ext, filenameExt) = decomposePath(config[\"outputDir\"])\n", 582 | " config[\"outputDir\"] = dir + \"/tests/\" + filenameExt\n", 583 | "mkdir(config[\"outputDir\"])\n", 584 | "logger = Logger(config[\"outputDir\"] + \"/asa2-focus.log\")" 585 | ] 586 | }, 587 | { 588 | "cell_type": "code", 589 | "execution_count": null, 590 | "metadata": {}, 591 | "outputs": [], 592 | "source": [ 593 | "log(lts(config), logger)" 594 | ] 595 | }, 596 | { 597 | "cell_type": "code", 598 | "execution_count": null, 599 | "metadata": {}, 600 | "outputs": [], 601 | "source": [] 602 | }, 603 | { 604 | "cell_type": "code", 605 | "execution_count": null, 606 | "metadata": {}, 607 | "outputs": [], 608 | "source": [] 609 | }, 610 | { 611 | "cell_type": "markdown", 612 | "metadata": {}, 613 | "source": [ 614 | "# We prepare data" 615 | ] 616 | }, 617 | { 618 | "cell_type": "code", 619 | "execution_count": null, 620 | "metadata": {}, 621 | "outputs": [], 622 | "source": [ 623 | "dataRootDir = dataDir() + \"/Asa2/usets\"\n", 624 | "dataDirectory = dataRootDir + \"/\" + config[\"datasetName\"]\n", 625 | "asapFiles = sortedGlob(dataDirectory + \"/*.bz2\")\n", 626 | "bp(asapFiles)" 627 | ] 628 | }, 629 | { 630 | "cell_type": "code", 631 | "execution_count": null, 632 | "metadata": {}, 633 | "outputs": [], 634 | "source": [ 635 | "attTransferDir = nosaveDir() + \"/asa2-train/\" + 
config[\"transfer\"][\"dirName\"]\n", 636 | "attTransferEpochPath = attTransferDir + \"/models/epoch\" + digitalizeIntegers(str(config[\"transfer\"][\"epoch\"]), 4)\n", 637 | "attTransferWeightsPath = attTransferEpochPath + \"/weights.h5\"\n", 638 | "attTransferPrebuilt = deserialize(attTransferDir + \"/asap-prebuilt.pickle\")\n", 639 | "attTransferConfig = fromJsonFile(attTransferDir + \"/config.json\")\n", 640 | "attTransferDocLength = attTransferConfig[\"docLength\"]" 641 | ] 642 | }, 643 | { 644 | "cell_type": "code", 645 | "execution_count": null, 646 | "metadata": {}, 647 | "outputs": [], 648 | "source": [ 649 | "asapKwargs = \\\n", 650 | "{\n", 651 | " 'dataCol': config[\"dataCol\"],\n", 652 | " 'minTokensLength': config[\"minTokensLength\"],\n", 653 | " 'minVocDF': config[\"minVocDF\"],\n", 654 | " 'minVocLF': config[\"minVocLF\"],\n", 655 | " 'persist': [True],\n", 656 | " 'docLength': attTransferConfig[\"docLength\"],\n", 657 | " 'labelEncoding': \"onehot\",\n", 658 | " 'logger': logger,\n", 659 | "}\n", 660 | "asap = buildASAP(asapFiles, **asapKwargs)" 661 | ] 662 | }, 663 | { 664 | "cell_type": "code", 665 | "execution_count": null, 666 | "metadata": {}, 667 | "outputs": [], 668 | "source": [ 669 | "attTransferPrebuilt[\"labelEncoder\"] = asap.labelEncoder\n", 670 | "attTransferPrebuilt[\"samplesCounts\"] = asap.samplesCounts" 671 | ] 672 | }, 673 | { 674 | "cell_type": "code", 675 | "execution_count": null, 676 | "metadata": {}, 677 | "outputs": [], 678 | "source": [ 679 | "attTransferAsapKwargs = \\\n", 680 | "{\n", 681 | " 'dataCol': config[\"dataCol\"],\n", 682 | " 'minTokensLength': config[\"minTokensLength\"],\n", 683 | " # 'batchSize': config[\"batchSize\"],\n", 684 | " 'minVocDF': config[\"minVocDF\"],\n", 685 | " 'minVocLF': config[\"minVocLF\"],\n", 686 | " 'persist': [True],\n", 687 | " 'prebuilt': attTransferPrebuilt,\n", 688 | " 'logger': logger,\n", 689 | " 'verbose': True,\n", 690 | " \n", 691 | " 'docLength': attTransferDocLength,\n", 692 | "}\n", 693 | "attTransferAsapKwargs[\"encoding\"] = attTransferConfig[\"inputEncoding\"]\n", 694 | "attTransferAsap = buildASAP(asapFiles, **attTransferAsapKwargs)" 695 | ] 696 | }, 697 | { 698 | "cell_type": "code", 699 | "execution_count": null, 700 | "metadata": { 701 | "scrolled": true 702 | }, 703 | "outputs": [], 704 | "source": [ 705 | "attTransferModelKwargs = fromJsonFile(attTransferEpochPath + \"/kwargs.json\")\n", 706 | "assert attTransferModelKwargs[\"vocSize\"] == len(attTransferPrebuilt[\"vocIndex\"])\n", 707 | "assert attTransferModelKwargs[\"docLength\"] == attTransferAsap.getDocLength()\n", 708 | "assert attTransferModelKwargs[\"vocSize\"] == len(attTransferAsap.getVocIndex())" 709 | ] 710 | }, 711 | { 712 | "cell_type": "code", 713 | "execution_count": null, 714 | "metadata": {}, 715 | "outputs": [], 716 | "source": [ 717 | "docs = [row[\"filtered_sentences\"] for row in NDJson(asapFiles[0])] " 718 | ] 719 | }, 720 | { 721 | "cell_type": "code", 722 | "execution_count": null, 723 | "metadata": {}, 724 | "outputs": [], 725 | "source": [ 726 | "bp(docs)" 727 | ] 728 | }, 729 | { 730 | "cell_type": "code", 731 | "execution_count": null, 732 | "metadata": {}, 733 | "outputs": [], 734 | "source": [ 735 | "splitIndex = int(0.75 * len(docs))\n", 736 | "trainDocs = docs[:splitIndex]\n", 737 | "testDocs = docs[splitIndex:]\n", 738 | "log(str(len(trainDocs)) + \" docs in the training set and \" + str(len(testDocs)) + \" docs in the test set\", logger)" 739 | ] 740 | }, 741 | { 742 | "cell_type": "code", 743 | 
"execution_count": null, 744 | "metadata": {}, 745 | "outputs": [], 746 | "source": [] 747 | }, 748 | { 749 | "cell_type": "code", 750 | "execution_count": null, 751 | "metadata": {}, 752 | "outputs": [], 753 | "source": [] 754 | }, 755 | { 756 | "cell_type": "markdown", 757 | "metadata": {}, 758 | "source": [ 759 | "# Init of DeepStyle" 760 | ] 761 | }, 762 | { 763 | "cell_type": "code", 764 | "execution_count": null, 765 | "metadata": {}, 766 | "outputs": [], 767 | "source": [ 768 | "from deepstyle.model import *" 769 | ] 770 | }, 771 | { 772 | "cell_type": "code", 773 | "execution_count": null, 774 | "metadata": {}, 775 | "outputs": [], 776 | "source": [ 777 | "# modelName = \"6ebdd3e05d4388c658ca2d5c53b0bc36\"\n", 778 | "modelName = \"extrafocus-model\"" 779 | ] 780 | }, 781 | { 782 | "cell_type": "code", 783 | "execution_count": null, 784 | "metadata": {}, 785 | "outputs": [], 786 | "source": [ 787 | "model = DeepStyle(nosaveDir() + \"/asa2-train/\" + modelName, logger=logger)" 788 | ] 789 | }, 790 | { 791 | "cell_type": "code", 792 | "execution_count": null, 793 | "metadata": {}, 794 | "outputs": [], 795 | "source": [] 796 | }, 797 | { 798 | "cell_type": "code", 799 | "execution_count": null, 800 | "metadata": {}, 801 | "outputs": [], 802 | "source": [] 803 | }, 804 | { 805 | "cell_type": "markdown", 806 | "metadata": {}, 807 | "source": [ 808 | "# We get attentions" 809 | ] 810 | }, 811 | { 812 | "cell_type": "code", 813 | "execution_count": null, 814 | "metadata": {}, 815 | "outputs": [], 816 | "source": [ 817 | "attentions = model.attentions(testDocs, progressVerbose=True)\n", 818 | "# model.save()\n", 819 | "bp(attentions, logger)" 820 | ] 821 | }, 822 | { 823 | "cell_type": "code", 824 | "execution_count": null, 825 | "metadata": {}, 826 | "outputs": [], 827 | "source": [ 828 | "def truncateAttentions(attentions, docs):\n", 829 | " newAttentions = []\n", 830 | " for i in range(len(docs)):\n", 831 | " # Getting the attention of the current doc:\n", 832 | " doc = docs[i]\n", 833 | " attention = attentions[i]\n", 834 | " docLength = len(attention)\n", 835 | " # Unpadding the attention:\n", 836 | " if len(doc) < docLength:\n", 837 | " pad = docLength - len(doc)\n", 838 | " attention = attention[pad:]\n", 839 | " # We check the shape:\n", 840 | " if len(doc) < docLength:\n", 841 | " assert len(attention) == len(doc)\n", 842 | " else:\n", 843 | " assert len(attention) == docLength\n", 844 | " # Making it a proba distrib:\n", 845 | " attention = toProbDist(attention)\n", 846 | " # And finally add it to attentions:\n", 847 | " newAttentions.append(attention)\n", 848 | " return newAttentions" 849 | ] 850 | }, 851 | { 852 | "cell_type": "code", 853 | "execution_count": null, 854 | "metadata": {}, 855 | "outputs": [], 856 | "source": [ 857 | "attDocs = [x[0] for x in attTransferAsap.getRawPart(truncate=True, pad=False)]\n", 858 | "attTestDocs = attDocs[splitIndex:]" 859 | ] 860 | }, 861 | { 862 | "cell_type": "code", 863 | "execution_count": null, 864 | "metadata": {}, 865 | "outputs": [], 866 | "source": [ 867 | "assert len(attTestDocs) == len(attentions)\n", 868 | "attentions = truncateAttentions(attentions, attTestDocs)" 869 | ] 870 | }, 871 | { 872 | "cell_type": "code", 873 | "execution_count": null, 874 | "metadata": {}, 875 | "outputs": [], 876 | "source": [] 877 | }, 878 | { 879 | "cell_type": "code", 880 | "execution_count": null, 881 | "metadata": {}, 882 | "outputs": [], 883 | "source": [] 884 | }, 885 | { 886 | "cell_type": "markdown", 887 | "metadata": {}, 888 | "source": [ 
889 | "# We get TFIDF values" 890 | ] 891 | }, 892 | { 893 | "cell_type": "code", 894 | "execution_count": null, 895 | "metadata": {}, 896 | "outputs": [], 897 | "source": [ 898 | "tfidfDocs = [flattenLists(doc)[:1200] for doc in docs]\n", 899 | "bp(tfidfDocs, 4, logger)" 900 | ] 901 | }, 902 | { 903 | "cell_type": "code", 904 | "execution_count": null, 905 | "metadata": {}, 906 | "outputs": [], 907 | "source": [ 908 | "tfidf = TFIDF(tfidfDocs, logger=logger, doLower=True, sublinearTF=True)" 909 | ] 910 | }, 911 | { 912 | "cell_type": "code", 913 | "execution_count": null, 914 | "metadata": {}, 915 | "outputs": [], 916 | "source": [ 917 | "tfidfValues = tfidf.getTFIDFVectors()" 918 | ] 919 | }, 920 | { 921 | "cell_type": "code", 922 | "execution_count": null, 923 | "metadata": {}, 924 | "outputs": [], 925 | "source": [ 926 | "tfidfValues = tfidfValues[splitIndex:]" 927 | ] 928 | }, 929 | { 930 | "cell_type": "code", 931 | "execution_count": null, 932 | "metadata": {}, 933 | "outputs": [], 934 | "source": [ 935 | "for i in range(len(tfidfValues)):\n", 936 | " tfidfValues[i] = toProbDist(tfidfValues[i])" 937 | ] 938 | }, 939 | { 940 | "cell_type": "code", 941 | "execution_count": null, 942 | "metadata": {}, 943 | "outputs": [], 944 | "source": [ 945 | "assert len(tfidfValues) == len(attentions)" 946 | ] 947 | }, 948 | { 949 | "cell_type": "code", 950 | "execution_count": null, 951 | "metadata": {}, 952 | "outputs": [], 953 | "source": [ 954 | "for i in range(len(tfidfValues)):\n", 955 | " assert len(tfidfValues[i]) == len(attentions[i])" 956 | ] 957 | }, 958 | { 959 | "cell_type": "code", 960 | "execution_count": null, 961 | "metadata": {}, 962 | "outputs": [], 963 | "source": [ 964 | "bp(tfidfValues, logger)" 965 | ] 966 | }, 967 | { 968 | "cell_type": "code", 969 | "execution_count": null, 970 | "metadata": {}, 971 | "outputs": [], 972 | "source": [] 973 | }, 974 | { 975 | "cell_type": "code", 976 | "execution_count": null, 977 | "metadata": {}, 978 | "outputs": [], 979 | "source": [] 980 | }, 981 | { 982 | "cell_type": "markdown", 983 | "metadata": {}, 984 | "source": [ 985 | "# We get the attention focus" 986 | ] 987 | }, 988 | { 989 | "cell_type": "code", 990 | "execution_count": null, 991 | "metadata": {}, 992 | "outputs": [], 993 | "source": [ 994 | "def getAttentionFocus(attentions, tfidfValues):\n", 995 | " attentionFocus = []\n", 996 | " for i in range(len(attentions)):\n", 997 | " attention = attentions[i]\n", 998 | " tfidf = tfidfValues[i]\n", 999 | " attentionFocus.append(np.dot(attention, tfidf))\n", 1000 | " return np.mean(attentionFocus)" 1001 | ] 1002 | }, 1003 | { 1004 | "cell_type": "code", 1005 | "execution_count": null, 1006 | "metadata": {}, 1007 | "outputs": [], 1008 | "source": [ 1009 | "attentionFocus = getAttentionFocus(attentions, tfidfValues)" 1010 | ] 1011 | }, 1012 | { 1013 | "cell_type": "code", 1014 | "execution_count": null, 1015 | "metadata": {}, 1016 | "outputs": [], 1017 | "source": [ 1018 | "attentionFocus = attentionFocus * 100" 1019 | ] 1020 | }, 1021 | { 1022 | "cell_type": "code", 1023 | "execution_count": null, 1024 | "metadata": { 1025 | "scrolled": true 1026 | }, 1027 | "outputs": [], 1028 | "source": [ 1029 | "log(\"attention focus: \" + str(attentionFocus), logger)" 1030 | ] 1031 | }, 1032 | { 1033 | "cell_type": "code", 1034 | "execution_count": null, 1035 | "metadata": {}, 1036 | "outputs": [], 1037 | "source": [] 1038 | }, 1039 | { 1040 | "cell_type": "code", 1041 | "execution_count": null, 1042 | "metadata": {}, 1043 | "outputs": [], 1044 | 
"source": [] 1045 | } 1046 | ], 1047 | "metadata": { 1048 | "kernelspec": { 1049 | "display_name": "Python 3", 1050 | "language": "python", 1051 | "name": "python3" 1052 | }, 1053 | "language_info": { 1054 | "codemirror_mode": { 1055 | "name": "ipython", 1056 | "version": 3 1057 | }, 1058 | "file_extension": ".py", 1059 | "mimetype": "text/x-python", 1060 | "name": "python", 1061 | "nbconvert_exporter": "python", 1062 | "pygments_lexer": "ipython3", 1063 | "version": "3.6.3" 1064 | } 1065 | }, 1066 | "nbformat": 4, 1067 | "nbformat_minor": 2 1068 | } 1069 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hayj/DeepStyle/599cbf2fa1fa537070bd64d7ca92c72451e1404c/requirements.txt -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import os 4 | from setuptools import setup, find_packages 5 | import importlib 6 | import re 7 | 8 | # Vars to set: 9 | description = "" 10 | author = "anonym2020" 11 | author_email = "anonym2020@anonym2020.com" 12 | version = "0.0.1" # replaced by the version in the main init file if exists 13 | 14 | # Current dir: 15 | thelibFolder = os.path.dirname(os.path.realpath(__file__)) 16 | 17 | # We take all requirements from the file or you can set it here : 18 | requirementPath = thelibFolder + '/requirements.txt' 19 | install_requires = [] # Example : ["gunicorn", "docutils >= 0.3", "lxml==0.5a7"] 20 | dependency_links = [] 21 | if os.path.isfile(requirementPath): 22 | with open(requirementPath) as f: 23 | dependency_links = [] 24 | install_requires = [] 25 | required = f.read().splitlines() 26 | for current in required: 27 | if 'git' in current: 28 | if "https" not in current: 29 | current = current.replace("-e git", "https") 30 | current = current.replace(".git#egg", "/zipball/master#egg") 31 | dependency_links.append(current) 32 | else: 33 | install_requires.append(current) 34 | 35 | # dependency_links is deprecated, see https://serverfault.com/questions/608192/pip-install-seems-to-be-ignoring-dependency-links 36 | dependency_links = [] 37 | 38 | # We search a folder containing "__init__.py": 39 | def walklevel(some_dir, level=1): 40 | some_dir = some_dir.rstrip(os.path.sep) 41 | assert os.path.isdir(some_dir) 42 | num_sep = some_dir.count(os.path.sep) 43 | for root, dirs, files in os.walk(some_dir): 44 | yield root, dirs, files 45 | num_sep_this = root.count(os.path.sep) 46 | if num_sep + level <= num_sep_this: 47 | del dirs[:] 48 | mainPackageName = thelibFolder.lower().split('/')[-1] 49 | for dirname, dirnames, filenames in walklevel(thelibFolder): 50 | if "__init__.py" in filenames: 51 | mainPackageName = dirname.split("/")[-1] 52 | packagePath = thelibFolder + '/' + mainPackageName 53 | # Get the version of the lib in the __init__.py: 54 | initFilePath = packagePath + '/' + "__init__.py" 55 | if os.path.isdir(packagePath): 56 | with open(initFilePath, 'r') as f: 57 | text = f.read() 58 | result = re.search('^__version__\s*=\s*["\'](.*)["\']', text, flags=re.MULTILINE) 59 | if result is not None: 60 | version = result.group(1) 61 | 62 | # To import the lib, use: 63 | # thelib = importlib.import_module(mainPackageName) 64 | 65 | # Readme content: 66 | readme = None 67 | readmePath = thelibFolder + '/README.md' 68 | if os.path.isfile(readmePath): 69 | 
try: 70 | import pypandoc 71 | readme = pypandoc.convert(readmePath, 'rst') 72 | except(IOError, ImportError) as e: 73 | print(e) 74 | print("Cannot use pypandoc to convert the README...") 75 | readme = open(readmePath).read() 76 | 77 | packageName = thelibFolder.lower().split('/')[-1] 78 | if packageName.startswith("pip-"): 79 | packageName = mainPackageName.lower() 80 | 81 | # The whole setup: 82 | setup( 83 | 84 | # The name for PyPi: 85 | name=packageName, 86 | 87 | # The version of the code which is located in the main __init__.py: 88 | version=version, 89 | 90 | # All packages to add: 91 | packages=find_packages(), 92 | 93 | # About the author: 94 | author=author, 95 | author_email=author_email, 96 | 97 | # A short desc: 98 | description=description, 99 | 100 | # A long desc with the readme: 101 | long_description=readme, 102 | 103 | # Dependencies: 104 | install_requires=install_requires, 105 | dependency_links=dependency_links, 106 | 107 | # To handle the MANIFEST.in: 108 | include_package_data=True, 109 | 110 | # The url to the official repo: 111 | # url='https://', 112 | 113 | # You can choose what you want here: https://pypi.python.org/pypi?%3Aaction=list_classifiers 114 | classifiers=[ 115 | "Programming Language :: Python", 116 | "Development Status :: 1 - Planning", 117 | "License :: OSI Approved :: MIT License", 118 | "Natural Language :: English", 119 | "Operating System :: OS Independent", 120 | "Programming Language :: Python :: 2.7", 121 | "Topic :: Utilities", 122 | ], 123 | 124 | # If you want a command line like "do-something" mapped to a specific function of the package: 125 | # entry_points = { 126 | # 'console_scripts': [ 127 | # 'wm-setup = workspacemanager.setup:generateSetup', 128 | # 'wm-pew = workspacemanager.venv:generateVenv', 129 | # 'wm-deps = workspacemanager.deps:installDeps', 130 | # ], 131 | # }, 132 | ) 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | --------------------------------------------------------------------------------