├── ChemTopicModel ├── __init__.py ├── chemTopicModel.py ├── drawFPBits.py ├── drawTopicModel.py ├── topicModelGUI.py ├── utilsChemTM.py ├── utilsDrawing.py ├── utilsEvaluation.py └── utilsFP.py ├── LICENSE ├── README.md ├── notebooks_data ├── README.md ├── data │ ├── TMChEMBL23_combine_w_text_100Ksample.csv │ ├── TMChEMBL23_nofeatureMols.csv │ └── TMChEMBL23_subtopic61.csv ├── notebooks │ ├── ChemTM-CHEMBL23-combine-with-text.ipynb │ ├── ChemTM_ChEMBL23-HierachicalTopics-noFeatureMolecules.ipynb │ ├── ChemTM_ChEMBL23-HierachicalTopics-subTopic61.ipynb │ └── ChemTM_Exploring-100-topic-ChEMBL23-model-using-different-metrics.ipynb ├── runTM_example.sh └── runTopicModel.py └── setup.py /ChemTopicModel/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdkit/CheTo/8147adf9ec094c54213879e4dd0ea73ea800dded/ChemTopicModel/__init__.py -------------------------------------------------------------------------------- /ChemTopicModel/chemTopicModel.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2016, Novartis Institutes for BioMedical Research Inc. 3 | # All rights reserved. 4 | # 5 | # Redistribution and use in source and binary forms, with or without 6 | # modification, are permitted provided that the following conditions are 7 | # met: 8 | # 9 | # * Redistributions of source code must retain the above copyright 10 | # notice, this list of conditions and the following disclaimer. 11 | # * Redistributions in binary form must reproduce the above 12 | # copyright notice, this list of conditions and the following 13 | # disclaimer in the documentation and/or other materials provided 14 | # with the distribution. 15 | # * Neither the name of Novartis Institutes for BioMedical Research Inc. 16 | # nor the names of its contributors may be used to endorse or promote 17 | # products derived from this software without specific prior written permission. 18 | # 19 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 20 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 21 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 22 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 23 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 24 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 25 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 26 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 27 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 28 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | # 31 | # Created by Nadine Schneider, June 2016 32 | 33 | 34 | import random 35 | import pandas as pd 36 | import numpy as np 37 | from collections import defaultdict, Counter 38 | import re 39 | 40 | from sklearn.feature_extraction import DictVectorizer 41 | from sklearn.decomposition import LatentDirichletAllocation 42 | from joblib import Parallel, delayed 43 | 44 | from rdkit import Chem 45 | from rdkit.Chem import AllChem, rdqueries, BRICS 46 | 47 | from ChemTopicModel import utilsFP 48 | 49 | #### FRAGMENT GENERATION #################### 50 | 51 | def _prepBRICSSmiles(m): 52 | smi = Chem.MolToSmiles(m,isomericSmiles=True, allHsExplicit=True, allBondsExplicit=True) 53 | # delete the connection ids 54 | smi = re.sub(r"\[\d+\*\]", "[*]", smi) 55 | order = eval(m.GetProp("_smilesAtomOutputOrder")) 56 | # make the smiles more descriptive, add properties 57 | return utilsFP.writePropsToSmiles(m,smi,order) 58 | 59 | def _generateFPs(mol,fragmentMethod='Morgan'): 60 | aBits={} 61 | fp=None 62 | # circular Morgan fingerprint fragmentation, we use a simple invariant than ususal here 63 | if fragmentMethod=='Morgan': 64 | tmp={} 65 | fp = AllChem.GetMorganFingerprint(mol,radius=2,invariants=utilsFP.generateAtomInvariant(mol),bitInfo=tmp) 66 | aBits = utilsFP.getMorganEnvironment(mol, tmp, fp=fp, minRad=2) 67 | fp = fp.GetNonzeroElements() 68 | # path-based RDKit fingerprint fragmentation 69 | elif fragmentMethod=='RDK': 70 | fp = AllChem.UnfoldedRDKFingerprintCountBased(mol,maxPath=5,minPath=3,bitInfo=aBits) 71 | fp = fp.GetNonzeroElements() 72 | # get the final BRICS fragmentation (= smallest possible BRICS fragments of a molecule) 73 | elif fragmentMethod=='Brics': 74 | fragMol=BRICS.BreakBRICSBonds(mol) 75 | propSmi = _prepBRICSSmiles(fragMol) 76 | fp=Counter(propSmi.split('.')) 77 | else: 78 | print("Unknown fragment method") 79 | return fp, aBits 80 | 81 | # this function is not part of the class due to parallelisation 82 | # generate the fragments of a molecule, return a map with moleculeID and fragment dict 83 | def _generateMolFrags(datachunk, vocabulary, fragmentMethod, fragIdx=None): 84 | if fragIdx is None and fragmentMethod == 'Brics': 85 | return 86 | result={} 87 | for idx, smi in datachunk: 88 | mol = Chem.MolFromSmiles(str(smi)) 89 | if mol == None: 90 | continue 91 | fp,_=_generateFPs(mol,fragmentMethod=fragmentMethod) 92 | if fp is None: 93 | continue 94 | tmp={} 95 | for k,v in fp.items(): 96 | if k not in vocabulary: 97 | continue 98 | # save memory: for BRICS use index instead of long complicated SMILES 99 | if fragmentMethod == 'Brics': 100 | tmp[fragIdx[k]]=v 101 | else: 102 | tmp[k]=v 103 | result[idx]=tmp 104 | return result 105 | 106 | ########### chemical topic modeling class ################### 107 | class ChemTopicModel: 108 | 109 | # initialisation chemical topic model 110 | def __init__(self, fragmentMethod = 'Morgan', randomState=42, sizeSampleDataSet=0.1, rareThres=0.001, 111 | commonThres=0.1, verbose=0, n_jobs=1, chunksize=1000, learningMethod='batch'): 112 | self.fragmentMethod = fragmentMethod 113 | self.seed = randomState 114 | self.sizeSampleDataSet = sizeSampleDataSet 115 | self.rareThres = rareThres 116 | self.commonThres = commonThres 117 | self.verbose = verbose 118 | self.n_jobs = n_jobs 119 | self.chunksize = chunksize 120 | self.learningMethod = learningMethod 121 | 122 | # generate the fragments used for the model, exclude rare and common fragments depending on a threshold 123 | def _generateFragmentVocabulary(self,molSample): 124 | fps=defaultdict(int) 125 | # collect fragments from a sample of the dataset 126 | for smi in molSample: 127 | mol = Chem.MolFromSmiles(str(smi)) 128 | if mol is None: 129 | continue 130 | fp,_=_generateFPs(mol,fragmentMethod=self.fragmentMethod) 131 | if fp is None: 132 | continue 133 | for bit in fp.keys(): 134 | fps[bit]+=1 135 | # filter rare and common fragments 136 | fragOcc = np.array(list(fps.values())) 137 | normFragOcc = fragOcc/float(len(molSample)) 138 | ids = normFragOcc > self.commonThres 139 | normFragOcc[ids] = 0 140 | ids = normFragOcc < self.rareThres 141 | normFragOcc[ids] = 0 142 | keys = list(fps.keys()) 143 | self.vocabulary = sorted(n for n,i in zip(keys,normFragOcc) if i != 0) 144 | self.fragIdx=dict((i,j) for j,i in enumerate(self.vocabulary)) 145 | if self.verbose: 146 | print('Created alphabet, size: {0}, used sample size: {1}'.format(len(self.alphabet),len(molSample))) 147 | 148 | # generate the fragment templates important for the visualisation of the topics later 149 | def _generateFragmentTemplates(self,molSample): 150 | fragTemplateDict=defaultdict(list) 151 | voc=set(self.vocabulary) 152 | if not len(self.vocabulary): 153 | print('Please generate your vocabulary first') 154 | return 155 | sizeVocabulary=len(self.vocabulary) 156 | for n,smi in enumerate(molSample): 157 | mol = Chem.MolFromSmiles(str(smi)) 158 | if mol == None: 159 | continue 160 | fp,aBits=_generateFPs(mol,fragmentMethod=self.fragmentMethod) 161 | if fp is None: 162 | continue 163 | for k,v in fp.items(): 164 | if k not in voc or k in fragTemplateDict: 165 | continue 166 | # save memory: for brics use index instead of long complicated smarts 167 | if self.fragmentMethod in ['Brics','BricsAll']: 168 | fragTemplateDict[self.fragIdx[k]]=['', []] 169 | else: 170 | fragTemplateDict[k]=[smi, aBits[k][0]] 171 | if len(fragTemplateDict) == sizeVocabulary: 172 | break 173 | tmp = [[k,v[0],v[1]] for k,v in fragTemplateDict.items()] 174 | self.fragmentTemplates = pd.DataFrame(tmp,columns=['bitIdx','templateMol','bitPathTemplateMol']) 175 | if self.verbose: 176 | print('Created fragment templates', self.fragmentTemplates.shape) 177 | 178 | # generate fragments for the whole dataset 179 | def _generateFragments(self): 180 | voc=set(self.vocabulary) 181 | fpsdict = dict([(idx,{}) for idx in self.moldata.index]) 182 | nrows = self.moldata.shape[0] 183 | counter = 0 184 | with Parallel(n_jobs=self.n_jobs,verbose=self.verbose) as parallel: 185 | while counter < nrows: 186 | nextChunk = min(counter+(self.n_jobs*self.chunksize),nrows) 187 | result = parallel(delayed(_generateMolFrags)(mollist, voc, 188 | self.fragmentMethod, 189 | self.fragIdx) 190 | for mollist in self._produceDataChunks(counter,nextChunk,self.chunksize)) 191 | for r in result: 192 | counter+=len(r) 193 | fpsdict.update(r) 194 | self.moldata['fps'] = np.array(sorted(fpsdict.items()))[:,1] 195 | 196 | # construct the molecule-fragment matrix as input for the LDA algorithm 197 | def _generateFragmentMatrix(self): 198 | fragM=[] 199 | vsize=len(self.vocabulary) 200 | for n,fps in enumerate(self.moldata['fps']): 201 | # we only use 8 bit integers for the counts to save memory 202 | t=np.zeros((vsize,),dtype=np.uint8) 203 | for k,v in fps.items(): 204 | idx = k 205 | if self.fragmentMethod in ['Morgan', 'RDK']: 206 | idx = self.fragIdx[k] 207 | if v > 255: 208 | print("WARNING: too many fragments of type {0} in molecule {1}".format(k,self.moldata['smiles'][len(fragM)])) 209 | t[idx]=255 210 | else: 211 | t[idx]=v 212 | fragM.append(t) 213 | self.fragM = np.array(fragM) 214 | 215 | # helper functions for parallelisation 216 | def _produceDataChunks(self,start,end,chunksize): 217 | for start in range(start,end,chunksize): 218 | end=min(self.moldata.shape[0],start+chunksize) 219 | yield list(zip(self.moldata[start:end].index, self.moldata[start:end]['smiles'])) 220 | 221 | def _generateMatrixChunks(self, start,end,chunksize=10000): 222 | for start in range(start,end,chunksize): 223 | end=min(self.fragM.shape[0],start+chunksize) 224 | yield self.fragM[start:end,:], start 225 | 226 | ############# main functions #####################################S 227 | 228 | # load the data (molecule table in SMILES format (required) and optionally some lables for the molecules) 229 | def loadData(self, inputDataFrame): 230 | self.moldata = inputDataFrame 231 | oriLabelNames = list(self.moldata.columns) 232 | self.oriLabelNames = oriLabelNames[1:] 233 | self.moldata.rename(columns=dict(zip(oriLabelNames, ['smiles']+['label_'+str(i) for i in range(len(oriLabelNames)-1)])), 234 | inplace=True) 235 | 236 | def generateFragments(self): 237 | # set a fixed seed due to order dependence of the LDA method --> the same data should get the same results 238 | sample = self.moldata.sample(frac=self.sizeSampleDataSet,random_state=np.random.RandomState(42)) 239 | self._generateFragmentVocabulary(sample['smiles']) 240 | self._generateFragmentTemplates(sample['smiles']) 241 | self._generateFragments() 242 | self._generateFragmentMatrix() 243 | 244 | # it is better use these functions instead of buildTopicModel if the dataset is larger 245 | def fitTopicModel(self, numTopics, max_iter=100, nJobs=1, sizeFittingDataset=1.0, **kwargs): 246 | 247 | self.lda = LatentDirichletAllocation(n_components=numTopics,learning_method=self.learningMethod,random_state=self.seed, 248 | n_jobs=nJobs, max_iter=max_iter, batch_size=self.chunksize, **kwargs) 249 | 250 | inputMatrix=self.fragM 251 | if sizeFittingDataset < 1.0: 252 | 253 | np.random.seed(self.seed) 254 | upperIndex = self.fragM.shape[0]-1 255 | size = int(self.fragM.shape[0]*sizeFittingDataset) 256 | ids = np.random.randint(0,upperIndex, size=size) 257 | inputMatrix = self.fragM[sorted(ids)] 258 | 259 | if inputMatrix.shape[0] > self.chunksize: 260 | # fit the model in chunks 261 | self.lda.learning_method = 'online' 262 | self.lda.fit(inputMatrix) 263 | 264 | def transformDataToTopicModel(self,lowerPrecision=False): 265 | 266 | try: 267 | self.lda 268 | except: 269 | raise ValueError('No topic model is available') 270 | 271 | if lowerPrecision: 272 | print('WARNING: using lower precision mode') 273 | 274 | if self.fragM.shape[0] > self.chunksize: 275 | # after fitting transform the data to our model 276 | for chunk in self._generateMatrixChunks(0,self.fragM.shape[0],chunksize=self.chunksize): 277 | resultLDA = self.lda.transform(chunk[0]) 278 | # here using a 32bit float instead of the 64bit float would save memory and might be enough precision. Test that later!! 279 | if chunk[1] > 0: 280 | if lowerPrecision: 281 | self.documentTopicProbabilities = np.concatenate((self.documentTopicProbabilities, 282 | (resultLDA/resultLDA.sum(axis=1,keepdims=1)).astype(np.float32)), axis=0) 283 | else: 284 | self.documentTopicProbabilities = np.concatenate((self.documentTopicProbabilities, 285 | resultLDA/resultLDA.sum(axis=1,keepdims=1)), axis=0) 286 | else: 287 | self.documentTopicProbabilities = resultLDA/resultLDA.sum(axis=1,keepdims=1) 288 | if lowerPrecision: 289 | self.documentTopicProbabilities = self.documentTopicProbabilities.astype(np.float32) 290 | else: 291 | resultLDA = self.lda.transform(self.fragM) 292 | # next line is not need anymore since it is normalized in sklearn already since version 0.18 293 | # self.documentTopicProbabilities = resultLDA/resultLDA.sum(axis=1,keepdims=1) 294 | self.documentTopicProbabilities = resultLDA 295 | if lowerPrecision: 296 | self.documentTopicProbabilities = self.documentTopicProbabilities.astype(np.float32) 297 | 298 | 299 | # use this if the dataset is small- to medium-sized 300 | def buildTopicModel(self, numTopics, max_iter=100, nJobs=1, lowerPrecision=False, sizeFittingDataset=0.1, **kwargs): 301 | 302 | self.fitTopicModel(numTopics, max_iter=max_iter, nJobs=nJobs, sizeFittingDataset=sizeFittingDataset, **kwargs) 303 | self.transformDataToTopicModel(lowerPrecision=lowerPrecision) 304 | 305 | 306 | def getTopicFragmentProbabilities(self): 307 | 308 | try: 309 | self.lda 310 | except: 311 | raise ValueError('No topic model is available') 312 | return self.lda.components_/self.lda.components_.sum(axis=1,keepdims=1) 313 | 314 | -------------------------------------------------------------------------------- /ChemTopicModel/drawFPBits.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2016, Novartis Institutes for BioMedical Research Inc. 3 | # All rights reserved. 4 | # 5 | # Redistribution and use in source and binary forms, with or without 6 | # modification, are permitted provided that the following conditions are 7 | # met: 8 | # 9 | # * Redistributions of source code must retain the above copyright 10 | # notice, this list of conditions and the following disclaimer. 11 | # * Redistributions in binary form must reproduce the above 12 | # copyright notice, this list of conditions and the following 13 | # disclaimer in the documentation and/or other materials provided 14 | # with the distribution. 15 | # * Neither the name of Novartis Institutes for BioMedical Research Inc. 16 | # nor the names of its contributors may be used to endorse or promote 17 | # products derived from this software without specific prior written permission. 18 | # 19 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 20 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 21 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 22 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 23 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 24 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 25 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 26 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 27 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 28 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | # 31 | # Created by Nadine Schneider, June 2016 32 | 33 | 34 | from rdkit import Chem 35 | from rdkit.Chem.Draw import rdMolDraw2D 36 | from rdkit.Chem import rdDepictor 37 | 38 | import numpy as np 39 | import re 40 | 41 | def _drawFPBit(smi,bitPath,molSize=(150,150),kekulize=True,baseRad=0.05,svg=True, fontSize=0.9,**kwargs): 42 | mol = Chem.MolFromSmiles(smi) 43 | rdDepictor.Compute2DCoords(mol) 44 | 45 | # get the atoms for highlighting 46 | atomsToUse=[] 47 | for b in bitPath: 48 | atomsToUse.append(mol.GetBondWithIdx(b).GetBeginAtomIdx()) 49 | atomsToUse.append(mol.GetBondWithIdx(b).GetEndAtomIdx()) 50 | atomsToUse = list(set(atomsToUse)) 51 | 52 | # enlarge the environment by one further bond 53 | enlargedEnv=[] 54 | for atom in atomsToUse: 55 | a = mol.GetAtomWithIdx(atom) 56 | for b in a.GetBonds(): 57 | bidx=b.GetIdx() 58 | if bidx not in bitPath: 59 | enlargedEnv.append(bidx) 60 | enlargedEnv = list(set(enlargedEnv)) 61 | enlargedEnv+=bitPath 62 | 63 | # set the coordinates of the submol based on the coordinates of the original molecule 64 | amap={} 65 | submol = Chem.PathToSubmol(mol,enlargedEnv,atomMap=amap) 66 | rdDepictor.Compute2DCoords(submol) 67 | conf = submol.GetConformer(0) 68 | confOri = mol.GetConformer(0) 69 | for i1,i2 in amap.items(): 70 | conf.SetAtomPosition(i2,confOri.GetAtomPosition(i1)) 71 | 72 | envSubmol=[] 73 | for i1,i2 in amap.items(): 74 | for b in bitPath: 75 | beginAtom=amap[mol.GetBondWithIdx(b).GetBeginAtomIdx()] 76 | endAtom=amap[mol.GetBondWithIdx(b).GetEndAtomIdx()] 77 | envSubmol.append(submol.GetBondBetweenAtoms(beginAtom,endAtom).GetIdx()) 78 | 79 | # Drawing 80 | if svg: 81 | drawer = rdMolDraw2D.MolDraw2DSVG(molSize[0],molSize[1]) 82 | else: 83 | drawer = rdMolDraw2D.MolDraw2DCairo(molSize[0],molSize[1]) 84 | 85 | drawer.SetFontSize(fontSize) 86 | drawopt=drawer.drawOptions() 87 | drawopt.continuousHighlight=False 88 | 89 | # color all atoms of the submol in gray which are not part of the bit 90 | # highlight atoms which are in rings 91 | color = (.9,.9,.9) 92 | atomcolors,bondcolors={},{} 93 | highlightAtoms,highlightBonds=[],[] 94 | 95 | for aidx in amap.keys(): 96 | if aidx in atomsToUse: 97 | if mol.GetAtomWithIdx(aidx).GetIsAromatic(): 98 | atomcolors[amap[aidx]]=(0.9,0.9,0.2) 99 | highlightAtoms.append(amap[aidx]) 100 | elif mol.GetAtomWithIdx(aidx).IsInRing(): 101 | atomcolors[amap[aidx]]=(0.8,0.8,0.8) 102 | highlightAtoms.append(amap[aidx]) 103 | else: 104 | drawopt.atomLabels[amap[aidx]]='*' 105 | submol.GetAtomWithIdx(amap[aidx]).SetAtomicNum(1) 106 | for bid in submol.GetBonds(): 107 | bidx=bid.GetIdx() 108 | if bidx not in envSubmol: 109 | bondcolors[bidx]=color 110 | highlightBonds.append(bidx) 111 | 112 | drawer.DrawMolecule(submol,highlightAtoms=highlightAtoms,highlightAtomColors=atomcolors, 113 | highlightBonds=highlightBonds,highlightBondColors=bondcolors, 114 | **kwargs) 115 | drawer.FinishDrawing() 116 | return drawer.GetDrawingText() 117 | 118 | def drawFPBitPNG(smi,bitPath,molSize=(150,150),kekulize=True,baseRad=0.05,**kwargs): 119 | return _drawFPBit(smi,bitPath,molSize=molSize,kekulize=kekulize,baseRad=baseRad, svg=False,**kwargs) 120 | 121 | def drawFPBit(smi,bitPath,molSize=(150,150),kekulize=True,baseRad=0.05,fontSize=0.9,**kwargs): 122 | svg = _drawFPBit(smi,bitPath,molSize=molSize,kekulize=kekulize,baseRad=baseRad,**kwargs) 123 | return svg.replace('svg:','') 124 | 125 | def _drawBricsFrag(smi,molSize=(150,150),kekulize=True,baseRad=0.05,svg=True,**kwargs): 126 | 127 | # delete smarts specific syntax from the pattern 128 | smi = re.sub(r"\;R\d?\;D\d+", "", smi) 129 | mol = Chem.MolFromSmiles(smi, sanitize=True) 130 | mc = rdMolDraw2D.PrepareMolForDrawing(mol, kekulize=kekulize) 131 | 132 | # Drawing 133 | drawer = rdMolDraw2D.MolDraw2DSVG(molSize[0],molSize[1]) 134 | if not svg: 135 | drawer = rdMolDraw2D.MolDraw2DCairo(molSize[0],molSize[1]) 136 | drawer.DrawMolecule(mc,**kwargs) 137 | drawer.FinishDrawing() 138 | return drawer.GetDrawingText() 139 | 140 | def drawBricsFragPNG(smi,molSize=(150,150),kekulize=True,baseRad=0.05,**kwargs): 141 | return _drawBricsFrag(smi,molSize=molSize,kekulize=kekulize,baseRad=baseRad,svg=False,**kwargs) 142 | 143 | def drawBricsFrag(smi,molSize=(150,150),kekulize=True,baseRad=0.05,**kwargs): 144 | svg = _drawBricsFrag(smi,molSize=molSize,kekulize=kekulize,baseRad=baseRad,**kwargs) 145 | return svg.replace('svg:','') 146 | -------------------------------------------------------------------------------- /ChemTopicModel/drawTopicModel.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2016, Novartis Institutes for BioMedical Research Inc. 3 | # All rights reserved. 4 | # 5 | # Redistribution and use in source and binary forms, with or without 6 | # modification, are permitted provided that the following conditions are 7 | # met: 8 | # 9 | # * Redistributions of source code must retain the above copyright 10 | # notice, this list of conditions and the following disclaimer. 11 | # * Redistributions in binary form must reproduce the above 12 | # copyright notice, this list of conditions and the following 13 | # disclaimer in the documentation and/or other materials provided 14 | # with the distribution. 15 | # * Neither the name of Novartis Institutes for BioMedical Research Inc. 16 | # nor the names of its contributors may be used to endorse or promote 17 | # products derived from this software without specific prior written permission. 18 | # 19 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 20 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 21 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 22 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 23 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 24 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 25 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 26 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 27 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 28 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | # 31 | # Created by Nadine Schneider, June 2016 32 | 33 | 34 | from rdkit import Chem 35 | from rdkit.Chem import rdqueries 36 | from rdkit.Chem.Draw import rdMolDraw2D 37 | from rdkit.Chem import rdDepictor 38 | 39 | from IPython.display import display,HTML,SVG 40 | 41 | from collections import defaultdict 42 | import operator 43 | import seaborn as sns 44 | import numpy as np 45 | 46 | from ChemTopicModel import utilsDrawing, drawFPBits, chemTopicModel 47 | 48 | # get the topic fragment probabilites per atom for highlighting 49 | def _getAtomWeights(mol, molID, topicID, topicModel): 50 | 51 | weights=[0]*mol.GetNumAtoms() 52 | # ignore "wildcard atoms" in BRICS fragments 53 | q = rdqueries.AtomNumEqualsQueryAtom(0) 54 | # get all fragments of a certain molecule 55 | _,aBits=chemTopicModel._generateFPs(mol, topicModel.fragmentMethod) 56 | fp=topicModel.moldata.loc[molID,'fps'] 57 | probs = topicModel.getTopicFragmentProbabilities() 58 | nTopics, nFrags = probs.shape 59 | # use the max probability of a fragment associated with a certain topic 60 | # to normalize the fragment weights 61 | maxWeightTopic = max(probs[topicID]) 62 | r = 0.0 63 | # calculate the weight of an atom concerning a certain topic 64 | for bit in fp.keys(): 65 | try: 66 | idxBit = bit 67 | if topicModel.fragmentMethod in ['Morgan', 'RDK']: 68 | idxBit = topicModel.fragIdx[bit] 69 | except: 70 | continue 71 | try: 72 | r = probs[topicID,idxBit] 73 | except: 74 | continue 75 | if r <= 1./nFrags: 76 | continue 77 | # Morgan/RDK fingerprints 78 | if topicModel.fragmentMethod in ['Morgan', 'RDK'] and bit in aBits: 79 | paths = aBits[bit] 80 | for p in paths: 81 | for b in p: 82 | bond = mol.GetBondWithIdx(b) 83 | # for overlapping fragments take the highest weight for the atom 84 | weights[bond.GetBeginAtomIdx()]=max(r,weights[bond.GetBeginAtomIdx()]) 85 | weights[bond.GetEndAtomIdx()]=max(r,weights[bond.GetEndAtomIdx()]) 86 | elif topicModel.fragmentMethod.startswith('Brics'): 87 | # BRICS fragments 88 | submol = Chem.MolFromSmarts(topicModel.vocabulary[idxBit]) 89 | ignoreWildcards = [i.GetIdx() for i in list(submol.GetAtomsMatchingQuery(q))] 90 | matches = mol.GetSubstructMatches(submol) 91 | for m in matches: 92 | for n,atomidx in enumerate(m): 93 | if n in ignoreWildcards: 94 | continue 95 | # for overlapping fragments take the highest weight for the atom, this not happen for BRICS though :) 96 | weights[atomidx]=max(r,weights[atomidx]) 97 | atomWeights = np.array(weights) 98 | return atomWeights,maxWeightTopic 99 | 100 | # hightlight a topic in a molecule 101 | def drawTopicWeightsMolecule(mol, molID, topicID, topicModel, molSize=(450,200), kekulize=True,\ 102 | baseRad=0.1, color=(.9,.9,.9), fontSize=0.9): 103 | 104 | # get the atom weights 105 | atomWeights,maxWeightTopic=_getAtomWeights(mol, molID, topicID, topicModel) 106 | atRads={} 107 | atColors={} 108 | 109 | # color the atoms and set their highlight radius according to their weight 110 | if np.sum(atomWeights) > 0: 111 | for at,score in enumerate(atomWeights): 112 | atColors[at]=color 113 | atRads[at]=max(atomWeights[at]/maxWeightTopic, 0.0) * baseRad 114 | 115 | if atRads[at] > 0 and atRads[at] < 0.2: 116 | atRads[at] = 0.2 117 | 118 | try: 119 | mc = rdMolDraw2D.PrepareMolForDrawing(mol, kekulize=kekulize) 120 | except ValueError: # <- can happen on a kekulization failure 121 | mc = rdMolDraw2D.PrepareMolForDrawing(mol, kekulize=False) 122 | 123 | drawer = rdMolDraw2D.MolDraw2DSVG(molSize[0],molSize[1]) 124 | drawer.SetFontSize(fontSize) 125 | drawer.DrawMolecule(mc,highlightAtoms=atColors.keys(), 126 | highlightAtomColors=atColors,highlightAtomRadii=atRads, 127 | highlightBonds=[]) 128 | drawer.FinishDrawing() 129 | svg = drawer.GetDrawingText() 130 | return svg.replace('svg:','') 131 | 132 | # generates all svgs of molecules belonging to a certain topic and highlights this topic within the molecule 133 | def generateMoleculeSVGsbyTopicIdx(topicModel, topicIdx, idsLabelToShow=[0], topicProbThreshold = 0.5, baseRad=0.5,\ 134 | molSize=(250,150),color=(.0,.0, 1.),maxMols=100, fontSize=0.9, maxTopicProb=0.5): 135 | svgs=[] 136 | namesSVGs=[] 137 | numDocs, numTopics = topicModel.documentTopicProbabilities.shape 138 | 139 | if topicIdx >= numTopics: 140 | return "Topic not found" 141 | tmp=topicModel.documentTopicProbabilities[:,topicIdx] 142 | ids=np.where(tmp >= topicProbThreshold) 143 | molset = sorted(list(zip(tmp[ids].tolist(),ids[0].tolist())), reverse=True)[:maxMols] 144 | if maxTopicProb > topicProbThreshold: 145 | ids=np.where((tmp >= topicProbThreshold) & (tmp < maxTopicProb)) 146 | molset = sorted(list(zip(tmp[ids].tolist(),ids[0].tolist())), reverse=False)[:maxMols] 147 | 148 | for prob,doc in molset: 149 | data = topicModel.moldata.iloc[doc] 150 | smi = data['smiles'] 151 | name = '' 152 | for idx in idsLabelToShow: 153 | name += str(data['label_'+str(idx)]) 154 | name += ' | ' 155 | mol = Chem.MolFromSmiles(smi) 156 | topicProb = prob #topicModel.documentTopicProbabilities[doc,topicIdx] 157 | svg = drawTopicWeightsMolecule(mol, doc, topicIdx, topicModel, molSize=molSize, baseRad=baseRad, color=color, fontSize=fontSize) 158 | svgs.append(svg) 159 | maxTopicID= np.argmax(topicModel.documentTopicProbabilities[doc]) 160 | maxProb = np.max(topicModel.documentTopicProbabilities[doc]) 161 | namesSVGs.append('{0}(p={1:.2f}) | (pmax({2})={3:.2f})'.format(name,topicProb,maxTopicID,maxProb)) 162 | if not len(svgs): 163 | #print('No molecules can be drawn') 164 | return [],[] 165 | 166 | return svgs, namesSVGs 167 | 168 | # generates all svgs of molecules having a certain label attached and highlights most probable topic within the molecule 169 | def generateMoleculeSVGsbyLabel(topicModel, label, idLabelToMatch=0, baseRad=0.5, molSize=(250,150),maxMols=100): 170 | 171 | data = topicModel.moldata.loc[topicModel.moldata['label_'+str(idLabelToMatch)] == label] 172 | 173 | if not len(data): 174 | return "Label not found" 175 | 176 | svgs=[] 177 | namesSVGs=[] 178 | numDocs, numTopics = topicModel.documentTopicProbabilities.shape 179 | colors = sns.husl_palette(numTopics, s=.6) 180 | 181 | topicIdx = np.argmax(topicModel.documentTopicProbabilities[data.index,:],axis=1) 182 | topicProb = np.amax(topicModel.documentTopicProbabilities[data.index,:],axis=1) 183 | topicdata = list(zip(data.index, topicIdx, topicProb)) 184 | topicdata_sorted = sorted(topicdata, key=operator.itemgetter(2), reverse=True) 185 | 186 | for idx,tIdx,tProb in topicdata_sorted[:maxMols]: 187 | mol = Chem.MolFromSmiles(data['smiles'][idx]) 188 | color = tuple(colors[tIdx]) 189 | svg = drawTopicWeightsMolecule(mol, idx, tIdx, topicModel, molSize=molSize, baseRad=baseRad, color=color) 190 | svgs.append(svg) 191 | namesSVGs.append(str("Topic "+str(tIdx)+" | (p="+str(round(tProb,2))+")")) 192 | 193 | return svgs, namesSVGs 194 | 195 | ### draw mols by label, highlight different topics 196 | 197 | # draws molecules of a certain label in a html table and highlights the most probable topic 198 | def drawMolsByLabel(topicModel, label, idLabelToMatch=0, baseRad=0.5, molSize=(250,150),\ 199 | numRowsShown=3, tableHeader='', maxMols=100): 200 | 201 | result = generateMoleculeSVGsbyLabel(topicModel, label, idLabelToMatch=idLabelToMatch,baseRad=baseRad,\ 202 | molSize=molSize, maxMols=maxMols) 203 | if len(result) == 1: 204 | print(result) 205 | return 206 | 207 | svgs, namesSVGs = result 208 | finalsvgs = [] 209 | for svg in svgs: 210 | # make the svg scalable 211 | finalsvgs.append(svg.replace(' prior, sorted(scores[topicIdx,:], reverse=True)[:n_top_frags]))) 285 | 286 | svgGrid = utilsDrawing.SvgsToGrid(svgs, namesSVGs, svgsPerRow=svgsPerRow, molSize=molSize) 287 | 288 | return svgGrid 289 | 290 | # generates svgs of the fragments related to a certain topic 291 | def generateTopicRelatedFragmentSVGs(topicModel, topicIdx, n_top_frags=10, molSize=(100,100),\ 292 | svg=True, prior=-1.0, fontSize=0.9): 293 | svgs=[] 294 | probs = topicModel.getTopicFragmentProbabilities() 295 | numTopics, numFragments = probs.shape 296 | if prior < 0: 297 | prior = 1./numFragments 298 | # only consider the top n fragments 299 | for i in probs[topicIdx,:].argsort()[::-1][:n_top_frags]: 300 | if probs[topicIdx,i] > prior: 301 | bit = topicModel.vocabulary[i] 302 | 303 | # allows including words 304 | if type(bit) != int: 305 | svgs.append(bit) 306 | continue 307 | 308 | # draw the bits using the templates 309 | if topicModel.fragmentMethod in ['Morgan', 'RDK']: 310 | templMol = topicModel.fragmentTemplates.loc[topicModel.fragmentTemplates['bitIdx'] == bit]['templateMol'].item() 311 | pathTemplMol = topicModel.fragmentTemplates.loc[topicModel.fragmentTemplates['bitIdx'] == bit]['bitPathTemplateMol'].item() 312 | if svg: 313 | svgs.append(drawFPBits.drawFPBit(templMol,pathTemplMol,molSize=molSize, fontSize=fontSize)) 314 | else: 315 | svgs.append(drawFPBits.drawFPBitPNG(templMol,pathTemplMol,molSize=molSize)) 316 | else: 317 | if svg: 318 | svgs.append(drawFPBits.drawBricsFrag(bit,molSize=molSize)) 319 | else: 320 | svgs.append(drawFPBits.drawBricsFragPNG(bit,molSize=molSize)) 321 | return svgs 322 | 323 | # draw the svgs of the fragments related to a certain topic in a html table 324 | def drawFragmentsbyTopic(topicModel, topicIdx, n_top_frags=10, numRowsShown=4, cssTableName='fragTab', \ 325 | prior=-1.0, numColumns=4, tableHeader='',fontSize=0.9): 326 | 327 | scores = topicModel.getTopicFragmentProbabilities() 328 | numTopics, numFragments = scores.shape 329 | if prior < 0: 330 | prior = 1./numFragments 331 | svgs=generateTopicRelatedFragmentSVGs(topicModel, topicIdx, n_top_frags=n_top_frags, prior=prior,fontSize=fontSize) 332 | namesSVGs = list(map(lambda x: "p(k={0})={1:.2f}".format(topicIdx,x), \ 333 | filter(lambda y: y > prior, sorted(scores[topicIdx,:], reverse=True)[:n_top_frags]))) 334 | if tableHeader == '': 335 | tableHeader = "Topic "+str(topicIdx) 336 | return display(HTML(utilsDrawing.drawSVGsToHTMLGrid(svgs,tableHeader=tableHeader,cssTableName=cssTableName,\ 337 | namesSVGs=namesSVGs,size=(120,100),numRowsShown=numRowsShown,\ 338 | numColumns=numColumns))) 339 | 340 | -------------------------------------------------------------------------------- /ChemTopicModel/topicModelGUI.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2016, Novartis Institutes for BioMedical Research Inc. 3 | # All rights reserved. 4 | # 5 | # Redistribution and use in source and binary forms, with or without 6 | # modification, are permitted provided that the following conditions are 7 | # met: 8 | # 9 | # * Redistributions of source code must retain the above copyright 10 | # notice, this list of conditions and the following disclaimer. 11 | # * Redistributions in binary form must reproduce the above 12 | # copyright notice, this list of conditions and the following 13 | # disclaimer in the documentation and/or other materials provided 14 | # with the distribution. 15 | # * Neither the name of Novartis Institutes for BioMedical Research Inc. 16 | # nor the names of its contributors may be used to endorse or promote 17 | # products derived from this software without specific prior written permission. 18 | # 19 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 20 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 21 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 22 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 23 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 24 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 25 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 26 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 27 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 28 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | # 31 | # Created by Nadine Schneider, December 2016 32 | 33 | 34 | from ipywidgets import * 35 | from IPython.display import display, clear_output 36 | from ChemTopicModel import drawTopicModel, chemTopicModel 37 | 38 | # allows choosing of topic colors 39 | from matplotlib.colors import hex2color, rgb2hex, cnames 40 | import seaborn as sns 41 | import matplotlib.pyplot as plt 42 | import os 43 | import time 44 | import pandas as pd 45 | import numpy as np 46 | from collections import defaultdict 47 | 48 | # some nice interactive bokeh plots 49 | from bokeh.models import HoverTool, ColumnDataSource 50 | from bokeh.plotting import show, figure 51 | from bokeh.io import output_notebook 52 | 53 | # outputy bokeh plots within the notebook 54 | output_notebook() 55 | # use seaborn style 56 | sns.set() 57 | # for encoding the png images 58 | import base64 59 | 60 | def to_base64(png): 61 | return "data:image/png;base64," + base64.b64encode(png).decode("utf-8") 62 | 63 | # main GUI 64 | def TopicModel(): 65 | 66 | def buildModel(sender): 67 | clear_output() 68 | showTopicButton.disabled = True 69 | showMoleculesButton.disabled = True 70 | saveAsButton.disabled = True 71 | saveAsButton2.disabled = True 72 | statsButton.disabled = True 73 | statsButton2.disabled = True 74 | progressBar = widgets.FloatProgress(min=0, max=100, width='300px', margin='10px 5px 10px 10px') 75 | labelProgressBar.value = 'Loading data' 76 | display(progressBar) 77 | 78 | filename = dataSetSelector.value 79 | if filename == '': 80 | print('No data set specified, please check your input.') 81 | progressBar.close() 82 | return 83 | try: 84 | data = pd.read_csv(filename) 85 | labelProgressBar.value = 'Generating fragments (may take several minutes for larger data sets)' 86 | progressBar.value +=33 87 | except: 88 | progressBar.value +=100 89 | labelProgressBar.value = 'Reading data failed' 90 | print('Invalid data file, please check your file.') 91 | progressBar.close() 92 | return 93 | try: 94 | starttime = time.time() 95 | topicModel=chemTopicModel.ChemTopicModel(fragmentMethod=fragmentmethodSelector.value, rareThres=rareFilterSelector.value, commonThres=commonFilterSelector.value) 96 | topicModel.loadData(data) 97 | topicModel.generateFragments() 98 | labelProgressBar.value = 'Building the model (may take several minutes for larger data sets and many topics)' 99 | progressBar.value +=33 100 | topicModel.buildTopicModel(numTopicSelector.value) 101 | finaltime = time.time() - starttime 102 | progressBar.value +=34 103 | labelProgressBar.value = 'Finished model successfully in %.3f sec'%finaltime 104 | 105 | # Update parameters,dropdown options etc. 106 | labelSelector.options = topicModel.oriLabelNames 107 | labelSelector2.options = topicModel.oriLabelNames 108 | labelSelector2a.options = topicModel.oriLabelNames 109 | numDocs, numTopics = topicModel.documentTopicProbabilities.shape 110 | labelID = topicModel.oriLabelNames.index(labelSelector2.value) 111 | labelSelector3.options = sorted(list(set(topicModel.moldata['label_'+str(labelID)]))) 112 | labelSelector3a.options = sorted(list(set(topicModel.moldata['label_'+str(labelID)]))) 113 | topicSelector.max=numTopics 114 | params['labels'] = topicModel.oriLabelNames 115 | params['topicModel'] = topicModel 116 | params['colors'] = sns.husl_palette(numTopics, s=.6) 117 | params['numTopics'] = numTopics 118 | showTopicButton.disabled = False 119 | showMoleculesButton.disabled = False 120 | saveAsButton.disabled = False 121 | saveAsButton2.disabled = False 122 | statsButton.disabled = False 123 | statsButton2.disabled = False 124 | progressBar.close() 125 | except: 126 | progressBar.value +=100 127 | labelProgressBar.value = 'Model building failed' 128 | print('Topic model could not be built.') 129 | return 130 | 131 | _tooltipstr=""" 132 |
133 |
134 | Topic $index
135 | Top 3 fragments 136 |
137 |
138 |
139 | 144 |
[@desc4]
145 |
146 |
147 | 152 |
[@desc5]
153 |
154 |
155 | 160 |
[@desc6]
161 |
162 |
163 |
164 | """ 165 | 166 | def _getToolTipImages(topicModel, numTopics, nTopFrags): 167 | 168 | tmp=[] 169 | name=[] 170 | scores=topicModel.getTopicFragmentProbabilities() 171 | for i in range(0,numTopics): 172 | try: 173 | imgs = drawTopicModel.generateTopicRelatedFragmentSVGs(topicModel, i, n_top_frags=nTopFrags,molSize=(100,80),svg=False) 174 | t = [to_base64(i) for i in imgs] 175 | if len(t) < nTopFrags: 176 | for j in range(len(t),nTopFrags): 177 | t.append('') 178 | tmp.append(t) 179 | except: 180 | pass 181 | names = list(map(lambda x: "Score %.2f" % x, filter(lambda y: y > 0.0, sorted(scores[i,:], reverse=True)[:nTopFrags]))) 182 | name.append(names) 183 | name = np.array(name) 184 | 185 | edges = np.arange(numTopics+1) 186 | if len(tmp) == 0: 187 | tmp=[['','','']]*numTopics 188 | tmp = np.array(tmp) 189 | return name,tmp,edges 190 | 191 | def calcOverallStatistics(sender): 192 | clear_output() 193 | labelProgressBar.value='' 194 | topicModel = params['topicModel'] 195 | numDocs, numTopics = topicModel.documentTopicProbabilities.shape 196 | topicDocStats=[0]*numTopics 197 | for doc in range(0,numDocs): 198 | topicDocStats[np.argmax(topicModel.documentTopicProbabilities[doc,:])]+=1 199 | topicDocStatsNorm=np.array(topicDocStats).astype(float)/numDocs 200 | 201 | name,tmp,edges = _getToolTipImages(topicModel, numTopics, 3) 202 | 203 | source = ColumnDataSource( data = dict( y = topicDocStatsNorm, l = edges[ :-1 ], r = edges[ 1: ], desc1 = tmp[:,0], \ 204 | desc2 = tmp[:,1], desc3 = tmp[:,2], desc4 = name[:,0], \ 205 | desc5 = name[:,1], desc6 = name[:,2])) 206 | 207 | hover=HoverTool() 208 | hover.tooltips= _tooltipstr 209 | 210 | p = figure(width=800, height=400, tools=[hover], toolbar_location=None, title="Overall topic distribution") 211 | 212 | p.quad( top = 'y', bottom = 0, left = 'l', right = 'r', 213 | fill_color = "#036564", line_color = "#033649", source = source ) 214 | 215 | p.xaxis.axis_label = "Topics" 216 | p.yaxis.axis_label = "% molecules per topic" 217 | p.xaxis.minor_tick_line_color = None 218 | show(p) 219 | 220 | def calcSubsetStatistics(sender): 221 | clear_output() 222 | labelProgressBar.value='' 223 | topicModel = params['topicModel'] 224 | label = labelSelector3a.value 225 | labelID = params['labels'].index(labelSelector2a.value) 226 | numDocs, numTopics = topicModel.documentTopicProbabilities.shape 227 | 228 | data = topicModel.moldata.loc[topicModel.moldata['label_'+str(labelID)] == label] 229 | topicProfile = np.zeros((numTopics,), dtype=np.int) 230 | 231 | for idx in data.index: 232 | topicProfile = np.sum([topicProfile, topicModel.documentTopicProbabilities[idx]], axis=0) 233 | topicProfileNorm=np.array(topicProfile).astype(float)/data.shape[0] 234 | 235 | name,tmp,edges = _getToolTipImages(topicModel, numTopics, 3) 236 | 237 | source = ColumnDataSource( data = dict( y = topicProfileNorm, l = edges[ :-1 ], r = edges[ 1: ], desc1 = tmp[:,0], \ 238 | desc2 = tmp[:,1], desc3 = tmp[:,2], desc4 = name[:,0], \ 239 | desc5 = name[:,1], desc6 = name[:,2])) 240 | 241 | hover=HoverTool() 242 | hover.tooltips= _tooltipstr 243 | 244 | p = figure(width=800, height=400, tools=[hover],toolbar_location=None, title="Topic profile for "+str(label)) 245 | 246 | p.quad( top = 'y', bottom = 0, left = 'l', right = 'r', 247 | fill_color = "#036564", line_color = "#033649", source = source ) 248 | 249 | p.xaxis.axis_label = "Topics" 250 | p.yaxis.axis_label = "Mean probability of topics" 251 | p.xaxis.minor_tick_line_color = None 252 | show(p) 253 | 254 | 255 | def showTopic(sender): 256 | topicModel = params['topicModel'] 257 | clear_output() 258 | labelProgressBar.value='' 259 | topicID = topicSelector.value 260 | labelID = params['labels'].index(labelSelector.value) 261 | if chooseColor.value: 262 | c = colorSelector.value 263 | if not c.startswith('#'): 264 | c = cnames[c] 265 | hex_color = c 266 | rgb_color = hex2color(hex_color) 267 | else: 268 | rgb_color = tuple(params['colors'][topicSelector.value]) 269 | colorSelector.value = rgb2hex(rgb_color) 270 | 271 | temp=None 272 | if topicID == '' or labelID == '': 273 | print("Please check your input") 274 | else: 275 | drawTopicModel.drawFragmentsbyTopic(topicModel, topicID, n_top_frags=20, numRowsShown=1.2,\ 276 | numColumns=8, tableHeader='Top fragments of topic '+str(topicID)) 277 | 278 | drawTopicModel.drawMolsByTopic(topicModel, topicID, idsLabelToShow=[labelID], topicProbThreshold = 0.1, baseRad=0.9,\ 279 | numRowsShown=3, color=rgb_color) 280 | 281 | def showMolecules(sender): 282 | topicModel = params['topicModel'] 283 | clear_output() 284 | labelProgressBar.value='' 285 | label = labelSelector3.value 286 | labelID = params['labels'].index(labelSelector2.value) 287 | 288 | if label == '' or labelID == '': 289 | print("Please check your input") 290 | else: 291 | drawTopicModel.drawMolsByLabel(topicModel, label, idLabelToMatch=labelID, baseRad=0.9, \ 292 | molSize=(250,150), numRowsShown=3) 293 | 294 | 295 | def saveTopicAs(sender): 296 | topicModel = params['topicModel'] 297 | topicID = topicSelector.value 298 | labelID = params['labels'].index(labelSelector.value) 299 | path = filePath.value 300 | 301 | if chooseColor.value: 302 | c = colorSelector.value 303 | if not c.startswith('#'): 304 | c = cnames[c] 305 | hex_color = c 306 | rgb_color = hex2color(hex_color) 307 | else: 308 | rgb_color = tuple(params['colors'][topicSelector.value]) 309 | colorSelector.value = rgb2hex(rgb_color) 310 | 311 | temp=None 312 | if topicID == '' or labelID == '': 313 | print("Please check your input") 314 | else: 315 | svgGrid = drawTopicModel.generateSVGGridMolsbyTopic(topicModel, 0, idLabelToShow=labelID, topicProbThreshold = 0.1, 316 | baseRad=0.9, color=rgb_color) 317 | with open(path+'.svg','w') as out: 318 | out.write(svgGrid) 319 | print("Saved topic image to: "+os.getcwd()+'/'+path+'.svg') 320 | 321 | 322 | def saveMolSetAs(sender): 323 | topicModel = params['topicModel'] 324 | if topicModel == None: 325 | print('No topic model available, please build a valid model first.') 326 | return 327 | 328 | path = filePath2.value 329 | label = labelSelector3.value 330 | labelID = params['labels'].index(labelSelector2.value) 331 | 332 | if label == '' or labelID == '': 333 | print("Please check your input") 334 | else: 335 | svgGrid = drawTopicModel.generateSVGGridMolsByLabel(topicModel, label, idLabelToMatch=labelID, baseRad=0.9) 336 | 337 | with open(path+'.svg','w') as out: 338 | out.write(svgGrid) 339 | print("Saved molecule set image to: "+os.getcwd()+'/'+path+'.svg') 340 | 341 | 342 | def getMolLabels(labelName): 343 | topicModel = params['topicModel'] 344 | try: 345 | labelID = params['labels'].index(labelName) 346 | return list(set(topicModel.moldata['label_'+str(labelID)])) 347 | except: 348 | return [] 349 | 350 | def selectMolSet(sender): 351 | labelSelector3.options = sorted(getMolLabels(labelSelector2.value)) 352 | 353 | def selectMolSeta(sender): 354 | labelSelector3a.options = sorted(getMolLabels(labelSelector2a.value)) 355 | 356 | def topicColor(sender): 357 | rgb_color = tuple(params['colors'][topicSelector.value]) 358 | colorSelector.value = rgb2hex(rgb_color) 359 | 360 | 361 | # init values 362 | params=dict([('labels',[]),('numTopics',50),('colors',sns.husl_palette(20, s=.6)),('topicModel',None),('rareThres',0.001),('commonThres',0.1)]) 363 | 364 | labelProgressBar = widgets.Label(value='') 365 | 366 | 367 | ########### Model building widgets 368 | # widgets 369 | dataSetSelector = widgets.Text(description='Data set:',value='data/datasetA.csv', width='450px', margin='10px 5px 10px 10px') 370 | numTopicSelector = widgets.IntText(description='Number of topics', width='200px', value=params['numTopics'],\ 371 | margin='10px 5px 10px 10px') 372 | rareFilterSelector = widgets.BoundedFloatText(min=0,max=1.0,description='Threshold rare fragments', width='200px', value=params['rareThres'], margin='10px 5px 10px 10px') 373 | commonFilterSelector = widgets.BoundedFloatText(min=0,max=1.0,description='Threshold common fragments', width='200px', value=params['commonThres'], margin='10px 5px 10px 10px') 374 | fragmentmethodSelector = widgets.Dropdown(options=['Morgan', 'RDK', 'Brics'], description='Fragment method:',\ 375 | width='200px',margin='10px 5px 10px 10px') 376 | doItButton = widgets.Button(description="Build model", button_style='danger', width='300px', margin='10px 5px 10px 10px') 377 | 378 | # actions 379 | labels = doItButton.on_click(buildModel) 380 | 381 | # layout widgets 382 | set1 = widgets.HBox() 383 | set1.children = [dataSetSelector] 384 | set2 = widgets.HBox() 385 | set2.children = [numTopicSelector, fragmentmethodSelector] 386 | set2a = widgets.HBox() 387 | set2a.children = [rareFilterSelector, commonFilterSelector] 388 | set3 = widgets.HBox() 389 | set3.children = [doItButton] 390 | finalLayout = widgets.VBox() 391 | finalLayout.children = [set1, set2, set2a, set3] 392 | 393 | ########### Model statistics widget 394 | statsButton = widgets.Button(description="Show overall topic distribution", disabled=True, button_style='danger',\ 395 | width='300px', margin='10px 5px 10px 10px') 396 | labelSelector2a = widgets.Dropdown(options=params['labels'], description='Label:', width='300px', margin='10px 5px 10px 10px') 397 | init = labelSelector2a.value 398 | labelSelector3a = widgets.Dropdown(options=getMolLabels(init), description='Molecule set:', width='300px', margin='10px 5px 10px 10px') 399 | statsButton2 = widgets.Button(description="Show topic profile by label", disabled=True, button_style='danger',\ 400 | width='300px', margin='10px 5px 10px 10px') 401 | 402 | # actions 403 | statsButton.on_click(calcOverallStatistics) 404 | statsButton2.on_click(calcSubsetStatistics) 405 | labelSelector2a.observe(selectMolSeta) 406 | 407 | # layout 408 | statsLayout = widgets.HBox() 409 | statsLayout.children = [statsButton] 410 | statsLayout2 = widgets.HBox() 411 | statsLayout2.children = [labelSelector2a, labelSelector3a, statsButton2] 412 | finalLayoutStats= widgets.VBox() 413 | finalLayoutStats.children = [statsLayout, statsLayout2] 414 | 415 | ########### Model exploration widgets 416 | # choose topic tab 417 | labelSelector = widgets.Dropdown(options=params['labels'], description='Label to show:', width='300px', margin='10px 5px 10px 10px') 418 | topicSelector = widgets.BoundedIntText(description="Topic to show", min=0, max=params['numTopics']-1, width='200px',\ 419 | margin='10px 5px 10px 10px') 420 | lableChooseColor = widgets.Label(value='Define topic color',margin='10px 5px 10px 10px') 421 | chooseColor = widgets.Checkbox(value=False,margin='10px 5px 10px 10px') 422 | showTopicButton = widgets.Button(description="Show the topic", button_style='danger',disabled=True,\ 423 | width='200px', margin='10px 5px 10px 10px') 424 | # choose molecules tab 425 | labelSelector2 = widgets.Dropdown(options=params['labels'], description='Label:', width='300px', margin='10px 5px 10px 10px') 426 | init = labelSelector2.value 427 | labelSelector3 = widgets.Dropdown(options=getMolLabels(init), description='Molecule set:', width='300px', margin='10px 5px 10px 10px') 428 | showMoleculesButton = widgets.Button(description="Show the molecules", button_style='danger',disabled=True, width='200px',\ 429 | margin='10px 5px 10px 10px') 430 | # choose color tab 431 | colorSelector = widgets.ColorPicker(concise=False, description='Topic highlight color', value='#e0e3e4',width='200px', \ 432 | margin='10px 5px 10px 10px') 433 | # save as tab 434 | filePath = widgets.Text(description="Save file as:", width='450px', margin='10px 5px 10px 10px') 435 | saveAsButton = widgets.Button(description="Save topic image", button_style='info',disabled=True,width='200px',\ 436 | margin='10px 5px 10px 10px') 437 | filePath2 = widgets.Text(description="Save file as:", width='450px', margin='10px 5px 10px 10px') 438 | saveAsButton2 = widgets.Button(description="Save molecule set image", button_style='info',disabled=True,width='200px', \ 439 | margin='10px 5px 10px 10px') 440 | 441 | # actions 442 | showTopicButton.on_click(showTopic) 443 | saveAsButton.on_click(saveTopicAs) 444 | saveAsButton2.on_click(saveMolSetAs) 445 | showMoleculesButton.on_click(showMolecules) 446 | labelSelector2.observe(selectMolSet) 447 | 448 | # layout widgets 449 | tab1 = widgets.HBox() 450 | tab1.children = [topicSelector, labelSelector, lableChooseColor, chooseColor, showTopicButton] 451 | tab2 = widgets.HBox() 452 | tab2.children = [labelSelector2, labelSelector3, showMoleculesButton] 453 | tab3 = widgets.HBox() 454 | tab3.children = [colorSelector] 455 | tab4a = widgets.HBox() 456 | tab4a.children = [filePath, saveAsButton] 457 | tab4b = widgets.HBox() 458 | tab4b.children = [filePath2, saveAsButton2] 459 | tab4 = widgets.VBox() 460 | tab4.children = [tab4a, tab4b] 461 | 462 | children = [tab1, tab2, tab3, tab4] 463 | tabs = widgets.Tab(children=children) 464 | tabs.set_title(0,'Topic to explore') 465 | tabs.set_title(1,'Molecule set to explore') 466 | tabs.set_title(2,'Choose color') 467 | tabs.set_title(3,'Save images as') 468 | 469 | accordion = widgets.Accordion(children=[finalLayout, finalLayoutStats, tabs]) 470 | accordion.set_title(0, 'Build Topic model') 471 | accordion.set_title(1, 'Statistics Topic model') 472 | accordion.set_title(2, 'Explore Topic model') 473 | display(accordion) 474 | display(labelProgressBar) -------------------------------------------------------------------------------- /ChemTopicModel/utilsChemTM.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2016, Novartis Institutes for BioMedical Research Inc. 3 | # All rights reserved. 4 | # 5 | # Redistribution and use in source and binary forms, with or without 6 | # modification, are permitted provided that the following conditions are 7 | # met: 8 | # 9 | # * Redistributions of source code must retain the above copyright 10 | # notice, this list of conditions and the following disclaimer. 11 | # * Redistributions in binary form must reproduce the above 12 | # copyright notice, this list of conditions and the following 13 | # disclaimer in the documentation and/or other materials provided 14 | # with the distribution. 15 | # * Neither the name of Novartis Institutes for BioMedical Research Inc. 16 | # nor the names of its contributors may be used to endorse or promote 17 | # products derived from this software without specific prior written permission. 18 | # 19 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 20 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 21 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 22 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 23 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 24 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 25 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 26 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 27 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 28 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | # 31 | # Created by Nadine Schneider, April 2019 32 | 33 | 34 | from collections import defaultdict, Counter 35 | import numpy as np 36 | import pandas as pd 37 | 38 | from ChemTopicModel import chemTopicModel 39 | 40 | 41 | def rankInterestingTopics(TM, minMaxProb=0.6, topXfrags=10): 42 | #ratio of high prob mols 43 | numMols, numTopics = TM.documentTopicProbabilities.shape 44 | fracHighProbMols=[] 45 | relTopicSize=[] 46 | absTopicSize=[] 47 | for i in range(numTopics): 48 | #find all molecules that have maximum topic probability for topic i 49 | subM = TM.documentTopicProbabilities[np.where(TM.documentTopicProbabilities.argmax(axis=1) == i)] 50 | numMolsMaxProb,_ = subM.shape 51 | if numMolsMaxProb > 0: 52 | relTopicSize.append(numMolsMaxProb/numMols) 53 | #get the fraction of molecules with a probability of at least minMaxProb 54 | numHighProbMols = np.where(subM.max(axis=1) >= minMaxProb)[0].shape[0] 55 | fracHighProbMols.append(numHighProbMols/numMolsMaxProb) 56 | absTopicSize.append(numMolsMaxProb) 57 | else: 58 | relTopicSize.append(0.0) 59 | absTopicSize.append(0.0) 60 | fracHighProbMols.append(0.0) 61 | 62 | fragsprob=TM.getTopicFragmentProbabilities() 63 | probTopXFrags = [sum(sorted(fragsprob[k], reverse=True)[:topXfrags]) for k in range(numTopics)] 64 | numFrags=[len(Counter(i.astype('float16')))-1 for i in fragsprob] 65 | minfrag = min(numFrags) 66 | maxfrag = max(numFrags) 67 | normNumFrags=[(x-minfrag)/(maxfrag-minfrag) for x in numFrags] 68 | 69 | tmpResult = list(zip(list(range(numTopics)),fracHighProbMols,absTopicSize,relTopicSize,normNumFrags,probTopXFrags)) 70 | result = pd.DataFrame(tmpResult, columns=['Topic Idx', 'fraction high prob. mols', 71 | 'abs. topic size', 'rel. topic size', 72 | 'rel. num. relevant frags', 'sum prob. top 5 frags']) 73 | return result -------------------------------------------------------------------------------- /ChemTopicModel/utilsDrawing.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2016, Novartis Institutes for BioMedical Research Inc. 3 | # All rights reserved. 4 | # 5 | # Redistribution and use in source and binary forms, with or without 6 | # modification, are permitted provided that the following conditions are 7 | # met: 8 | # 9 | # * Redistributions of source code must retain the above copyright 10 | # notice, this list of conditions and the following disclaimer. 11 | # * Redistributions in binary form must reproduce the above 12 | # copyright notice, this list of conditions and the following 13 | # disclaimer in the documentation and/or other materials provided 14 | # with the distribution. 15 | # * Neither the name of Novartis Institutes for BioMedical Research Inc. 16 | # nor the names of its contributors may be used to endorse or promote 17 | # products derived from this software without specific prior written permission. 18 | # 19 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 20 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 21 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 22 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 23 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 24 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 25 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 26 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 27 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 28 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | # 31 | # Created by Nadine Schneider, June 2016 32 | 33 | 34 | import numpy as np 35 | import pandas as pd 36 | import copy 37 | import re 38 | from rdkit.Chem import PandasTools 39 | from IPython.display import SVG 40 | 41 | # generate an HTML table of the svg images to visulize them nicely in the Jupyter notebook 42 | PandasTools.RenderImagesInAllDataFrames(images=True) 43 | def drawSVGsToHTMLGrid(svgs, cssTableName='default', tableHeader='', namesSVGs=[], size=(150,150), numColumns=4, numRowsShown=2, noHeader=False): 44 | rows=[] 45 | names=copy.deepcopy(namesSVGs) 46 | rows = [SVG(i).data if i.startswith(' 0: 50 | rows+=['']*(numColumns-x) 51 | d+=1 52 | if len(names)>0: 53 | names+=['']*(numColumns-x) 54 | rows=np.array(rows).reshape(d,numColumns) 55 | finalRows=[] 56 | if len(names)>0: 57 | names = np.array(names).reshape(d,numColumns) 58 | for r,n in zip(rows,names): 59 | finalRows.append(r) 60 | finalRows.append(n) 61 | d*=2 62 | else: 63 | finalRows=rows 64 | 65 | headerRemove = int(max(numColumns,d)) 66 | df=pd.DataFrame(finalRows) 67 | 68 | style = '\n' 79 | if not noHeader: 80 | style += '
'+str(tableHeader)+'
\n' 81 | style += '
\n' 82 | dfhtml=style+df.to_html()+'\n
\n' 83 | dfhtml=dfhtml.replace('class="dataframe"','class="'+cssTableName+'"') 84 | dfhtml=dfhtml.replace('','') 85 | for i in range(0,headerRemove): 86 | dfhtml=dfhtml.replace(''+str(i)+'','') 87 | return dfhtml 88 | 89 | # build an svg grid image to print 90 | def SvgsToGrid(svgs, labels, svgsPerRow=4,molSize=(250,150),fontSize=12): 91 | 92 | matcher = re.compile(r'^(<.*>\n)(\n)(.*)',re.DOTALL) 93 | hdr='' 94 | ftr='' 95 | rect='' 96 | nRows = len(svgs)//svgsPerRow 97 | if len(svgs)%svgsPerRow : nRows+=1 98 | blocks = ['']*(nRows*svgsPerRow) 99 | labelSizeDist = fontSize*5 100 | fullSize=(svgsPerRow*(molSize[0]+molSize[0]/10.0),nRows*(molSize[1]+labelSizeDist)) 101 | print(fullSize) 102 | 103 | count=0 104 | for svg,name in zip(svgs,labels): 105 | h,r,b = matcher.match(svg).groups() 106 | if not hdr: 107 | hdr = h.replace("width='"+str(molSize[0])+"px'","width='%dpx'"%fullSize[0]) 108 | hdr = hdr.replace("height='"+str(molSize[1])+"px'","height='%dpx'"%fullSize[1]) 109 | if not rect: 110 | rect = r 111 | legend = '\n' 112 | legend += ''+name.split('|')[0]+'\n' 113 | if len(name.split('|')) > 1: 114 | legend += ''+name.split('|')[1]+'\n' 115 | legend += '\n' 116 | blocks[count] = b + legend 117 | count+=1 118 | 119 | for i,elem in enumerate(blocks): 120 | row = i//svgsPerRow 121 | col = i%svgsPerRow 122 | elem = rect+elem 123 | blocks[i] = '%s'%(col*(molSize[0]+molSize[0]/10.0),row*(molSize[1]+labelSizeDist),elem) 124 | res = hdr + '\n'.join(blocks)+ftr 125 | return res 126 | -------------------------------------------------------------------------------- /ChemTopicModel/utilsEvaluation.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2016, Novartis Institutes for BioMedical Research Inc. 3 | # All rights reserved. 4 | # 5 | # Redistribution and use in source and binary forms, with or without 6 | # modification, are permitted provided that the following conditions are 7 | # met: 8 | # 9 | # * Redistributions of source code must retain the above copyright 10 | # notice, this list of conditions and the following disclaimer. 11 | # * Redistributions in binary form must reproduce the above 12 | # copyright notice, this list of conditions and the following 13 | # disclaimer in the documentation and/or other materials provided 14 | # with the distribution. 15 | # * Neither the name of Novartis Institutes for BioMedical Research Inc. 16 | # nor the names of its contributors may be used to endorse or promote 17 | # products derived from this software without specific prior written permission. 18 | # 19 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 20 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 21 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 22 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 23 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 24 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 25 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 26 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 27 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 28 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | # 31 | # Created by Nadine Schneider, September 2016 32 | 33 | 34 | from collections import defaultdict, Counter 35 | import numpy as np 36 | import pandas as pd 37 | import operator 38 | 39 | # continous Tanimoto coefficient 40 | def _calcContTanimoto(x,y): 41 | m=np.array([x,y]) 42 | minsum = m.min(0).sum() 43 | maxsum = m.max(0).sum() 44 | if maxsum > 0: 45 | return minsum/maxsum 46 | else: 47 | return 0 48 | 49 | # continous Tanimoto coefficient matrix for a count-based FP matrix 50 | def calcContTanimotoDistMatrix(fragMatrix): 51 | from scipy.spatial.distance import squareform,pdist 52 | dists = pdist(fragMatrix, lambda u, v: _calcContTanimoto(u,v)) 53 | return squareform(dists) 54 | 55 | # calculate recall and precision for a topic model based on a given label 56 | def generateStatistics(topicModel, idLabelToUse=1): 57 | 58 | label_topics=defaultdict(list) 59 | topics_label=defaultdict(list) 60 | numMolsPerLabel=Counter(topicModel.moldata['label_'+str(idLabelToUse)].tolist()) 61 | numDocs, numTopics = topicModel.documentTopicProbabilities.shape 62 | for i in range(0, numDocs): 63 | label = topicModel.moldata['label_'+str(idLabelToUse)][i] 64 | maxTopic=np.argmax(topicModel.documentTopicProbabilities[i]) 65 | label_topics[label].append(maxTopic) 66 | topics_label[maxTopic].append(label) 67 | label_topics2=defaultdict(dict) 68 | for tid,topics in label_topics.items(): 69 | label_topics2[tid]=Counter(topics) 70 | topics_label2=defaultdict(dict) 71 | for topicid,tid in topics_label.items(): 72 | topics_label2[topicid]=Counter(tid) 73 | 74 | data=[] 75 | for label,topics in label_topics2.items(): 76 | tsorted = sorted(topics.items(), key=operator.itemgetter(1),reverse=True) 77 | maxTopic = tsorted[0][0] 78 | numMolsMaxTopic = tsorted[0][1] 79 | numMols = numMolsPerLabel[label] 80 | precisionMT = numMolsMaxTopic/float(sum(topics_label2[maxTopic].values())) 81 | recallMT = numMolsMaxTopic/float(numMols) 82 | F1 = 2 * (precisionMT * recallMT) / (precisionMT + recallMT) 83 | data.append([label, numMols, len(topics.keys()), int(maxTopic), recallMT, precisionMT, F1]) 84 | data=pd.DataFrame(data, columns=['label','# mols','# topics','main topic ID','recall in main topic',\ 85 | 'precision in main topic', 'F1']) 86 | data = data.sort_values(['main topic ID']) 87 | overall=['Median'] 88 | overall.extend(data[['# mols','# topics']].median().values.tolist()) 89 | overall.append('-') 90 | overall.extend(data[['recall in main topic','precision in main topic', 'F1']].median().values.tolist()) 91 | data.loc[len(data)] = overall 92 | return data 93 | -------------------------------------------------------------------------------- /ChemTopicModel/utilsFP.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2016, Novartis Institutes for BioMedical Research Inc. 3 | # All rights reserved. 4 | # 5 | # Redistribution and use in source and binary forms, with or without 6 | # modification, are permitted provided that the following conditions are 7 | # met: 8 | # 9 | # * Redistributions of source code must retain the above copyright 10 | # notice, this list of conditions and the following disclaimer. 11 | # * Redistributions in binary form must reproduce the above 12 | # copyright notice, this list of conditions and the following 13 | # disclaimer in the documentation and/or other materials provided 14 | # with the distribution. 15 | # * Neither the name of Novartis Institutes for BioMedical Research Inc. 16 | # nor the names of its contributors may be used to endorse or promote 17 | # products derived from this software without specific prior written permission. 18 | # 19 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 20 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 21 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 22 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 23 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 24 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 25 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 26 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 27 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 28 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | # 31 | # Created by Nadine Schneider, June 2016 32 | 33 | 34 | from rdkit import Chem 35 | from rdkit.Chem import AllChem 36 | 37 | import copy 38 | from collections import defaultdict 39 | import hashlib 40 | 41 | def getMorganEnvironment(mol, bitInfo, fp=None, minRad=0): 42 | """ 43 | 44 | >>> m = Chem.MolFromSmiles('CC(O)C') 45 | >>> bi = {} 46 | >>> fp = AllChem.GetMorganFingerprintAsBitVect(m,2,2048,bitInfo=bi) 47 | >>> getMorganEnvironment(m,bi) 48 | defaultdict(, {1: [[]], 227: [[1]], 283: [[0], [2]], 709: [[0, 1, 2]], 807: [[]], 1057: [[], []]}) 49 | >>> getMorganEnvironment(m,bi,minRad=1) 50 | defaultdict(, {227: [[1]], 283: [[0], [2]], 709: [[0, 1, 2]]}) 51 | >>> list(fp.GetOnBits()) 52 | [1, 227, 283, 709, 807, 1057] 53 | >>> getMorganEnvironment(m,bi,minRad=1,fp=fp) 54 | defaultdict(, {227: [[1]], 283: [[0], [2]], 709: [[0, 1, 2]]}) 55 | >>> list(fp.GetOnBits()) 56 | [227, 283, 709] 57 | 58 | """ 59 | bitPaths=defaultdict(list) 60 | for bit,info in bitInfo.items(): 61 | for atomID,radius in info: 62 | if radius < minRad: 63 | if fp != None: 64 | fp[bit]=0 65 | continue 66 | env = Chem.FindAtomEnvironmentOfRadiusN(mol,radius,atomID) 67 | bitPaths[bit].append(list(env)) 68 | return bitPaths 69 | 70 | def _includeRingMembership(s, n, noRingAtom=False): 71 | r=';R]' 72 | if noRingAtom: 73 | r=';R0]' 74 | d="]" 75 | return r.join([d.join(s.split(d)[:n]),d.join(s.split(d)[n:])]) 76 | 77 | def _includeDegree(s, n, d): 78 | r=';D'+str(d)+']' 79 | d="]" 80 | return r.join([d.join(s.split(d)[:n]),d.join(s.split(d)[n:])]) 81 | 82 | def writePropsToSmiles(mol,smi,order): 83 | """ 84 | 85 | >>> writePropsToSmiles(Chem.MolFromSmiles('Cc1ncccc1'),'[cH]:[n]:[c]-[CH3]',(3,2,1,0)) 86 | '[cH;R;D2]:[n;R;D2]:[c;R;D3]-[CH3;R0;D1]' 87 | 88 | """ 89 | finalsmi = copy.deepcopy(smi) 90 | for i,a in enumerate(order,1): 91 | atom = mol.GetAtomWithIdx(a) 92 | if not atom.GetAtomicNum(): 93 | continue 94 | finalsmi = _includeRingMembership(finalsmi, i, noRingAtom = not atom.IsInRing()) 95 | finalsmi = _includeDegree(finalsmi, i, atom.GetDegree()) 96 | return finalsmi 97 | 98 | def getSubstructSmi(mol,env,propsToSmiles=True): 99 | """ 100 | 101 | >>> getSubstructSmi(Chem.MolFromSmiles('Cc1ncccc1'),((0,1,2))) 102 | '[cH;R;D2]:[n;R;D2]:[c;R;D3]-[CH3;R0;D1]' 103 | 104 | """ 105 | atomsToUse=set() 106 | if not len(env): 107 | return '' 108 | for b in env: 109 | atomsToUse.add(mol.GetBondWithIdx(b).GetBeginAtomIdx()) 110 | atomsToUse.add(mol.GetBondWithIdx(b).GetEndAtomIdx()) 111 | # no isomeric smiles since we don't include that in the fingerprints 112 | smi = Chem.MolFragmentToSmiles(mol,atomsToUse,isomericSmiles=False, 113 | bondsToUse=env,allHsExplicit=True, allBondsExplicit=True) 114 | if propsToSmiles: 115 | order = eval(mol.GetProp("_smilesAtomOutputOrder")) 116 | smi = writePropsToSmiles(mol,smi,order) 117 | return smi 118 | 119 | def generateAtomInvariant(mol): 120 | """ 121 | 122 | >>> generateAtomInvariant(Chem.MolFromSmiles("Cc1ncccc1")) 123 | [346999948, 3963180082, 3525326240, 2490398925, 2490398925, 2490398925, 2490398925] 124 | 125 | """ 126 | num_atoms = mol.GetNumAtoms() 127 | invariants = [0]*num_atoms 128 | for i,a in enumerate(mol.GetAtoms()): 129 | descriptors=[] 130 | descriptors.append(a.GetAtomicNum()) 131 | descriptors.append(a.GetTotalDegree()) 132 | descriptors.append(a.GetTotalNumHs()) 133 | descriptors.append(a.IsInRing()) 134 | descriptors.append(a.GetIsAromatic()) 135 | invariants[i]=int(hashlib.sha256(str(descriptors).encode('utf-8')).hexdigest(),16)& 0xffffffff 136 | return invariants 137 | 138 | 139 | #------------------------------------ 140 | # 141 | # doctest boilerplate 142 | # 143 | def _test(): 144 | import doctest, sys 145 | return doctest.testmod(sys.modules["__main__"]) 146 | 147 | 148 | if __name__ == '__main__': 149 | import sys 150 | failed, tried = _test() 151 | sys.exit(failed) 152 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2017, NadineSchneider 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | CheTo - RC(=O)R 2 | -------- 3 | 4 | 5 | CheTo (ChemicalTopic) allows to apply topic modeling, a method developed in the text-mining field, to chemical data. Please see our recent publication for detailed information: 6 | 7 | Schneider, N.; Fechner, N.; Landrum, G. A.; Stiefl, N. *Chemical Topic Modeling: Exploring Molecular Data Sets Using a Common Text-Mining Approach*. J. Chem. Inf. Model. 2017, [http://pubs.acs.org/doi/10.1021/acs.jcim.7b00249](http://pubs.acs.org/doi/10.1021/acs.jcim.7b00249) 8 | 9 | The [supplementary](http://pubs.acs.org/doi/suppl/10.1021/acs.jcim.7b00249) of the paper contains exemplary data sets extracted from the [ChEMBL database](https://www.ebi.ac.uk/chembl/) and Jupyter notebooks to run the experiments described in the paper. 10 | 11 | An interactive web page showing an exemplary topic model of data set A from our paper can be found here [http://www.t5informatics.com/Papers/InteractiveTopicModelDatasetA.html](http://www.t5informatics.com/Papers/InteractiveTopicModelDatasetA.html) 12 | 13 | **Installation** 14 | 15 | To install CheTo using Conda, simply run: 16 | 17 | `conda install -c rdkit cheto` 18 | 19 | **Further reading** 20 | 21 | Using CheTo in KNIME: [http://rdkit.blogspot.ch/2017/08/chemical-topic-modeling-with-rdkit-and.html](http://rdkit.blogspot.ch/2017/08/chemical-topic-modeling-with-rdkit-and.html) 22 | 23 | After publication of our article we were made aware that applying topic modeling to chemical data was also suggested by Rajarshi Guha in 2012 in his blog ([http://blog.rguha.net/?p=997](http://blog.rguha.net/?p=997)). 24 | -------------------------------------------------------------------------------- /notebooks_data/README.md: -------------------------------------------------------------------------------- 1 | ## New experiments with CheTo 2 | 3 | In a recently published book chapter we are describing a set of different experiments using chemical topic modeling to dive deeper into the data in ChEMBL23. In this folder you can find the notebooks to create the results/images in the book chapter cited below. They should inspire you how CheTo can be used to explore chemical data sets. 4 | 5 | #### Book chapter: 6 | 7 | Schneider, N., Fechner, N., Stiefl, N., & Landrum, G. A. (2020). Chemical Topic Modeling–An Unsupervised Approach Originating from Text-mining to Organize Chemical Data. Artificial Intelligence in Drug Discovery, 75, 17. 8 | 9 | https://pubs.rsc.org/en/content/chapter/bk9781788015479-00015/978-1-78801-547-9 10 | 11 | 12 | #### Software dependencies 13 | 14 | - python 3.6.8 15 | - numpy 1.16.2 16 | - pandas 0.24.1 17 | - rdkit 2018.09.2.0 18 | - scikit-learn 0.20.2 -------------------------------------------------------------------------------- /notebooks_data/runTM_example.sh: -------------------------------------------------------------------------------- 1 | # default learning_rate and default learning_offset 2 | python runTopicModel.py --data chembl23_mols.csv.shuffled --rareThres 0.0005 --njobsFrag 10 --numTopics 100 --sizeSampleDataSet 0.1 --outfilePrefix tm_chembl23_100 --maxIterOpt 50 --chunksize 85000 --lowPrec 1 > chembl23_100.log & 3 | -------------------------------------------------------------------------------- /notebooks_data/runTopicModel.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2019, Novartis Institutes for BioMedical Research Inc. 3 | # All rights reserved. 4 | # 5 | # Redistribution and use in source and binary forms, with or without 6 | # modification, are permitted provided that the following conditions are 7 | # met: 8 | # 9 | # * Redistributions of source code must retain the above copyright 10 | # notice, this list of conditions and the following disclaimer. 11 | # * Redistributions in binary form must reproduce the above 12 | # copyright notice, this list of conditions and the following 13 | # disclaimer in the documentation and/or other materials provided 14 | # with the distribution. 15 | # * Neither the name of Novartis Institutes for BioMedical Research Inc. 16 | # nor the names of its contributors may be used to endorse or promote 17 | # products derived from this software without specific prior written permission. 18 | # 19 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 20 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 21 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 22 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 23 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 24 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 25 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 26 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 27 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 28 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | # 31 | # Created by Nadine Schneider, May 2019 32 | 33 | 34 | import pandas as pd 35 | import numpy as np 36 | 37 | import ChemTopicModel 38 | from ChemTopicModel import chemTopicModel, utilsEvaluation, drawTopicModel 39 | 40 | print('\n----------------------------------------------------------') 41 | print('----------------------------------------------------------') 42 | print('------------------- CHETO ----------------------------') 43 | print('----------- Chemical topic modeling ------------------') 44 | print('----------------------------------------------------------') 45 | print('----------------------------------------------------------\n\n') 46 | 47 | import time 48 | print(time.asctime()) 49 | 50 | import sklearn 51 | from rdkit import rdBase 52 | print('RDKit version: ',rdBase.rdkitVersion) 53 | print('Pandas version:', pd.__version__) 54 | print('Scikit-Learn version:', sklearn.__version__) 55 | print('Numpy version:', np.__version__) 56 | print(ChemTopicModel.__file__) 57 | 58 | print('\n----------------------------------------------------------\n') 59 | 60 | import pickle 61 | import argparse 62 | 63 | def main(): 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument('--data', type=str, required=True, help='Please specify the path to your data file. Required format: csv; first column needs to contain a SMILES of your molecule, the following columns can be different labels for the data.') 66 | parser.add_argument('--numTopics', type=int, required=True, help='Please specify the number of topics for your model.') 67 | parser.add_argument('--fragMethod', type=str, default='Morgan', choices=['Morgan', 'RDK', 'Brics'], help='Please select your fragmentation method. Default: Morgan') 68 | parser.add_argument('--sizeSampleDataSet', type=float, default=1.0, 69 | help='Please choose a ratio between 0.0 and 1.0. Default: 1.0; for large datasets a value of 0.1 is recommended.') 70 | parser.add_argument('--rareThres', type=float, default=0.001, 71 | help='Please choose a threshold between 0.0 and 1.0. Default: 0.001; for small datasets a value of 0.01 is recommended.') 72 | parser.add_argument('--commonThres', type=float, default=0.1, help='Please choose a threshold between 0.0 and 1.0. Default: 0.1') 73 | parser.add_argument('--njobsFrag', type=int, default=1, help='Please specify the number of jobs used to fragment the molecules. Default: 1') 74 | parser.add_argument('--njobsLDA', type=int, default=1, help='Please specify the number of jobs used to fragment the molecules. Default: 1') 75 | parser.add_argument('--maxIterOpt', type=int, default=10, help='Please specify the number of iterations for the LDA optimization. Default: 10') 76 | parser.add_argument('--outfilePrefix', type=str, default='tm_', help='Please specify a filename to store the model.') 77 | parser.add_argument('--chunksize', type=int, default=1000, help='Please specify the chunksize for online training. Default: 1000') 78 | parser.add_argument('--lowPrec', type=bool, default=0, help='Choose a lower precision if you expect your model to be huge. Default: False') 79 | parser.add_argument('--ratioCmpdsMB', type=float, default=1.0, help='Choose the number of cmpds the model will be build on. Default: 1.0') 80 | args = parser.parse_args() 81 | 82 | print('---> Reading data') 83 | datafile = args.data 84 | data = pd.read_csv(datafile) 85 | 86 | data = data.sample(frac=1.0,random_state=np.random.RandomState(42)) 87 | data.reset_index(drop=True,inplace=True) 88 | data.to_csv(datafile+'.shuffled',index=False) 89 | 90 | seed=57 91 | tm=chemTopicModel.ChemTopicModel(sizeSampleDataSet=args.sizeSampleDataSet, fragmentMethod=args.fragMethod, 92 | rareThres=args.rareThres, commonThres=args.commonThres, randomState=seed, 93 | n_jobs=args.njobsFrag, chunksize=args.chunksize) 94 | 95 | tm.loadData(data) 96 | 97 | print("---> Generating fragments") 98 | stime = time.time() 99 | tm.generateFragments() 100 | print("Time:", time.time()-stime) 101 | print("Size fragment matrix ", tm.fragM.shape) 102 | 103 | print("---> Fitting topic model") 104 | stime = time.time() 105 | tm.fitTopicModel(args.numTopics, max_iter=args.maxIterOpt, nJobs=args.njobsLDA, sizeFittingDataset=args.ratioCmpdsMB) 106 | print("Time:", time.time()-stime) 107 | print("---> Transforming topic model") 108 | stime = time.time() 109 | tm.transformDataToTopicModel(lowerPrecision=args.lowPrec) 110 | print("Time:", time.time()-stime) 111 | 112 | print("---> Saving topic model") 113 | # you need protocol 4 to save large files (> 4GB), this is only possible with pyhton version > 3.4 114 | with open(args.outfilePrefix+'.pkl', 'wb') as fp: 115 | pickle.dump(tm, fp, protocol=4) 116 | 117 | print('---> DONE. Enjoy your model!') 118 | 119 | main() -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup(name='ChemTopicModel', 4 | version='0.2', 5 | description='Applying topic modeling for chemistry. https://dx.doi.org/10.1021/acs.jcim.7b00249', 6 | url='http://rdkit.org', 7 | author='Nadine Schneider', 8 | author_email='nadine.schneider.shb@gmail.com', 9 | license='BSD', 10 | packages=['ChemTopicModel'], 11 | install_requires=['scikit-learn'], 12 | zip_safe=False) 13 | --------------------------------------------------------------------------------