├── LICENSE ├── data └── ng20.RData ├── .Rbuildignore ├── tools ├── loss-evolution.png ├── example-visualisation.png └── example-visualisation-basic.png ├── inst ├── example │ ├── example_dtm.rds │ └── example_etm.ckpt ├── tinytest │ ├── data │ │ ├── init-mu_q_theta.bias.txt │ │ ├── end-mu_q_theta.bias.txt │ │ ├── init-logsigma_q_theta.bias.txt │ │ ├── end-logsigma_q_theta.bias.txt │ │ ├── init-2.bias.txt │ │ ├── end-0.bias.txt │ │ ├── end-2.bias.txt │ │ ├── init-0.bias.txt │ │ ├── end-alphas.txt │ │ ├── init-alphas.txt │ │ ├── init-logsigma_q_theta.weight.txt │ │ ├── init-mu_q_theta.weight.txt │ │ ├── end-logsigma_q_theta.weight.txt │ │ ├── end-mu_q_theta.weight.txt │ │ ├── init-2.weight.txt │ │ └── end-2.weight.txt │ └── test_end_to_end.R.R └── orig │ └── ETM │ ├── LICENSE │ ├── skipgram.py │ ├── data.py │ ├── README.md │ ├── utils.py │ ├── etm.py │ ├── scripts │ ├── stops.txt │ ├── data_nyt.py │ └── data_20ng.py │ └── main.py ├── .gitignore ├── tests └── tinytest.R ├── R ├── pkg.R ├── data.R ├── utils.R └── ETM.R ├── NAMESPACE ├── NEWS.md ├── man ├── ng20.Rd ├── as.matrix.ETM.Rd ├── summary.ETM.Rd ├── predict.ETM.Rd ├── plot.ETM.Rd └── ETM.Rd ├── ETM.Rproj ├── .github └── workflows │ ├── R-CMD-check.yml │ └── rhub.yaml ├── DESCRIPTION ├── LICENSE.note └── README.md /LICENSE: -------------------------------------------------------------------------------- 1 | YEAR: 2021 2 | COPYRIGHT HOLDER: Jan Wijffels, BNOSAC -------------------------------------------------------------------------------- /data/ng20.RData: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bnosac/ETM/HEAD/data/ng20.RData -------------------------------------------------------------------------------- /.Rbuildignore: -------------------------------------------------------------------------------- 1 | ^.*\.Rproj$ 2 | ^\.Rproj\.user$ 3 | dev 4 | .github 5 | inst/tinytest 6 | -------------------------------------------------------------------------------- /tools/loss-evolution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bnosac/ETM/HEAD/tools/loss-evolution.png -------------------------------------------------------------------------------- /inst/example/example_dtm.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bnosac/ETM/HEAD/inst/example/example_dtm.rds -------------------------------------------------------------------------------- /inst/example/example_etm.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bnosac/ETM/HEAD/inst/example/example_etm.ckpt -------------------------------------------------------------------------------- /tools/example-visualisation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bnosac/ETM/HEAD/tools/example-visualisation.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .Rproj.user 2 | .Rhistory 3 | .RData 4 | .Ruserdata 5 | src/*.o 6 | src/*.so 7 | src/*.dll 8 | dev 9 | -------------------------------------------------------------------------------- /tools/example-visualisation-basic.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/bnosac/ETM/HEAD/tools/example-visualisation-basic.png -------------------------------------------------------------------------------- /inst/tinytest/data/init-mu_q_theta.bias.txt: -------------------------------------------------------------------------------- 1 | 0.2853255867958069 2 | 0.0652959942817688 3 | 0.19156426191329956 4 | -0.24854962527751923 -------------------------------------------------------------------------------- /inst/tinytest/data/end-mu_q_theta.bias.txt: -------------------------------------------------------------------------------- 1 | 0.16553065180778503 2 | 0.08828012645244598 3 | 0.08991000056266785 4 | -0.17445848882198334 -------------------------------------------------------------------------------- /tests/tinytest.R: -------------------------------------------------------------------------------- 1 | if(requireNamespace("tinytest", quietly = TRUE) && FALSE){ 2 | tinytest::test_package("topicmodels.etm") 3 | } 4 | -------------------------------------------------------------------------------- /inst/tinytest/data/init-logsigma_q_theta.bias.txt: -------------------------------------------------------------------------------- 1 | 0.10274165868759155 2 | -0.41334447264671326 3 | -0.149125337600708 4 | -0.03497129678726196 -------------------------------------------------------------------------------- /inst/tinytest/data/end-logsigma_q_theta.bias.txt: -------------------------------------------------------------------------------- 1 | -0.001679401146247983 2 | -0.2912718951702118 3 | -0.043210797011852264 4 | -0.09374536573886871 -------------------------------------------------------------------------------- /inst/tinytest/data/init-2.bias.txt: -------------------------------------------------------------------------------- 1 | -0.1343691349029541 2 | -0.22685791552066803 3 | -0.03069823980331421 4 | 0.11382496356964111 5 | -0.35526373982429504 -------------------------------------------------------------------------------- /inst/tinytest/data/end-0.bias.txt: -------------------------------------------------------------------------------- 1 | -0.003992969170212746 2 | -0.022207029163837433 3 | -0.021055657416582108 4 | 0.08955368399620056 5 | 0.011268123984336853 -------------------------------------------------------------------------------- /inst/tinytest/data/end-2.bias.txt: -------------------------------------------------------------------------------- 1 | -0.03325639292597771 2 | -0.11705461889505386 3 | 0.07348982244729996 4 | -0.0009831655770540237 5 | -0.24144168198108673 -------------------------------------------------------------------------------- /inst/tinytest/data/init-0.bias.txt: -------------------------------------------------------------------------------- 1 | 0.01540415734052658 2 | -0.014282756485044956 3 | 0.013270918279886246 4 | -0.013338010758161545 5 | -0.01334417425096035 -------------------------------------------------------------------------------- /R/pkg.R: -------------------------------------------------------------------------------- 1 | #' @import torch 2 | #' @importFrom Matrix sparseMatrix rowSums 3 | #' @importFrom graphics par 4 | #' @importFrom stats predict 5 | NULL 6 | 7 | 8 | -------------------------------------------------------------------------------- /inst/tinytest/data/end-alphas.txt: -------------------------------------------------------------------------------- 1 | -0.24042658507823944 2 | 0.1252402365207672 3 | 0.2614149749279022 4 | -0.49488964676856995 5 | 
0.27740034461021423 6 | 0.3321397304534912 7 | -0.23344556987285614 8 | 0.5449241995811462 9 | 0.021241048350930214 10 | -0.03262738510966301 11 | 0.11676688492298126 12 | 0.06686047464609146 -------------------------------------------------------------------------------- /inst/tinytest/data/init-alphas.txt: -------------------------------------------------------------------------------- 1 | -0.24081605672836304 2 | 0.1010899543762207 3 | 0.14850664138793945 4 | -0.5043508410453796 5 | 0.2728678584098816 6 | 0.210035502910614 7 | -0.2348487675189972 8 | 0.5583230257034302 9 | -0.10650685429573059 10 | -0.013019680976867676 11 | 0.10165554285049438 12 | -0.06088513135910034 -------------------------------------------------------------------------------- /NAMESPACE: -------------------------------------------------------------------------------- 1 | # Generated by roxygen2: do not edit by hand 2 | 3 | S3method(as.matrix,ETM) 4 | S3method(plot,ETM) 5 | S3method(predict,ETM) 6 | S3method(summary,ETM) 7 | export(ETM) 8 | import(torch) 9 | importFrom(Matrix,rowSums) 10 | importFrom(Matrix,sparseMatrix) 11 | importFrom(graphics,par) 12 | importFrom(stats,predict) 13 | -------------------------------------------------------------------------------- /R/data.R: -------------------------------------------------------------------------------- 1 | #' @title Bag of words sample of the 20 newsgroups dataset 2 | #' @description Data available at \url{https://github.com/bnosac-dev/ETM/tree/master/data/20ng} 3 | #' @name ng20 4 | #' @docType data 5 | #' @examples 6 | #' data(ng20) 7 | #' str(ng20$vocab) 8 | #' str(ng20$bow_tr$tokens) 9 | #' str(ng20$bow_tr$counts) 10 | NULL -------------------------------------------------------------------------------- /NEWS.md: -------------------------------------------------------------------------------- 1 | ### CHANGES IN ETM VERSION 0.1.1 2 | 3 | - Fix R CMD check notes in the documentation of ETM regarding the use of braces 4 | - fix note about arXiv DOI in DESCRIPTION 5 | - Added extra check to make sure data is passed to the fit method 6 | - Added url of the github repository of the package 7 | - drop ggalt from Suggests 8 | 9 | ### CHANGES IN ETM VERSION 0.1.0 10 | 11 | - Initial package 12 | -------------------------------------------------------------------------------- /man/ng20.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/data.R 3 | \docType{data} 4 | \name{ng20} 5 | \alias{ng20} 6 | \title{Bag of words sample of the 20 newsgroups dataset} 7 | \description{ 8 | Data available at \url{https://github.com/bnosac-dev/ETM/tree/master/data/20ng} 9 | } 10 | \examples{ 11 | data(ng20) 12 | str(ng20$vocab) 13 | str(ng20$bow_tr$tokens) 14 | str(ng20$bow_tr$counts) 15 | } 16 | -------------------------------------------------------------------------------- /ETM.Rproj: -------------------------------------------------------------------------------- 1 | Version: 1.0 2 | 3 | RestoreWorkspace: Default 4 | SaveWorkspace: Default 5 | AlwaysSaveHistory: Default 6 | 7 | EnableCodeIndexing: Yes 8 | UseSpacesForTab: Yes 9 | NumSpacesForTab: 4 10 | Encoding: UTF-8 11 | 12 | RnwWeave: knitr 13 | LaTeX: pdfLaTeX 14 | 15 | BuildType: Package 16 | PackageUseDevtools: Yes 17 | PackageInstallArgs: --no-multiarch --with-keep.source 18 | PackageCheckArgs: --no-multiarch 19 | PackageRoxygenize: rd,collate,namespace 20 | 
-------------------------------------------------------------------------------- /inst/tinytest/data/init-logsigma_q_theta.weight.txt: -------------------------------------------------------------------------------- 1 | -0.09458985924720764 2 | 0.3045787811279297 3 | -0.07669892907142639 4 | 0.39356333017349243 5 | 0.2148970365524292 6 | 0.3843591809272766 7 | 0.43945032358169556 8 | -0.4305282235145569 9 | 0.1182326078414917 10 | 0.1998259425163269 11 | -0.39963650703430176 12 | -0.3835235834121704 13 | 0.10413974523544312 14 | 0.024068444967269897 15 | 0.2728472948074341 16 | 0.29736095666885376 17 | 0.26914018392562866 18 | -0.2911526560783386 19 | 0.2897706627845764 20 | -0.09377476572990417 -------------------------------------------------------------------------------- /inst/tinytest/data/init-mu_q_theta.weight.txt: -------------------------------------------------------------------------------- 1 | -0.2224932163953781 2 | 0.22359275817871094 3 | -0.2332828789949417 4 | 0.14893251657485962 5 | -0.09833827614784241 6 | -0.01769554615020752 7 | -0.20539727807044983 8 | 0.1620616316795349 9 | -0.2507886290550232 10 | -0.4405246376991272 11 | -0.10334399342536926 12 | 0.20999419689178467 13 | -0.23090942203998566 14 | 0.29861021041870117 15 | 0.2051529884338379 16 | 0.1138695478439331 17 | -0.11817577481269836 18 | -0.3219031095504761 19 | 0.19015341997146606 20 | -0.37837737798690796 -------------------------------------------------------------------------------- /inst/tinytest/data/end-logsigma_q_theta.weight.txt: -------------------------------------------------------------------------------- 1 | -0.003587504616007209 2 | 0.19189821183681488 3 | -0.1122555211186409 4 | 0.3069078028202057 5 | 0.10574685782194138 6 | 0.27002888917922974 7 | 0.32435518503189087 8 | -0.31847354769706726 9 | 0.21401840448379517 10 | 0.09162580966949463 11 | -0.2850712239742279 12 | -0.26920658349990845 13 | 0.13660918176174164 14 | 0.11214754730463028 15 | 0.16112181544303894 16 | 0.18487784266471863 17 | 0.15754267573356628 18 | -0.3195997178554535 19 | 0.26248252391815186 20 | -0.0030806842260062695 -------------------------------------------------------------------------------- /inst/tinytest/data/end-mu_q_theta.weight.txt: -------------------------------------------------------------------------------- 1 | -0.11291909962892532 2 | 0.11395995318889618 3 | -0.26816174387931824 4 | 0.053577978163957596 5 | -0.005989446770399809 6 | -0.0013280515559017658 7 | -0.09682752192020416 8 | 0.19537977874279022 9 | -0.26883187890052795 10 | -0.32541659474372864 11 | -0.009363260120153427 12 | 0.10113615542650223 13 | -0.2651570439338684 14 | 0.21574807167053223 15 | 0.09659895300865173 16 | 0.016974007710814476 17 | -0.020255297422409058 18 | -0.28112468123435974 19 | 0.2644701302051544 20 | -0.26414456963539124 -------------------------------------------------------------------------------- /inst/tinytest/data/init-2.weight.txt: -------------------------------------------------------------------------------- 1 | 0.2505561113357544 2 | -0.2742682099342346 3 | -0.1302242875099182 4 | -0.07196003198623657 5 | 0.3319660425186157 6 | -0.35172685980796814 7 | 0.033337295055389404 8 | -0.3178785443305969 9 | 0.4024980068206787 10 | 0.08835029602050781 11 | -0.2920699119567871 12 | 0.11585462093353271 13 | 0.17507719993591309 14 | 0.2591361403465271 15 | -0.0692034661769867 16 | 0.37163084745407104 17 | -0.36633655428886414 18 | 0.03915277123451233 19 | -0.24276334047317505 20 | 0.41848963499069214 21 | 
-0.4231681525707245 22 | -0.06124269962310791 23 | -0.3147905468940735 24 | 0.2641980051994324 25 | 0.3260175585746765 -------------------------------------------------------------------------------- /inst/tinytest/data/end-2.weight.txt: -------------------------------------------------------------------------------- 1 | 0.13966384530067444 2 | -0.1624947488307953 3 | -0.02984231896698475 4 | 0.007785496301949024 5 | 0.21862125396728516 6 | -0.23797279596328735 7 | -0.00547742610797286 8 | -0.20485974848270416 9 | 0.2878909111022949 10 | -0.00013675331138074398 11 | -0.2635123133659363 12 | 0.018475860357284546 13 | 0.07832721620798111 14 | 0.2934812307357788 15 | -0.04077600687742233 16 | 0.35882818698883057 17 | -0.26639243960380554 18 | 0.03586164861917496 19 | -0.294453889131546 20 | 0.3988807797431946 21 | -0.3082766830921173 22 | 0.010483958758413792 23 | -0.20184746384620667 24 | 0.15277725458145142 25 | 0.21280665695667267 -------------------------------------------------------------------------------- /inst/orig/ETM/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Adji B. Dieng, Francisco J. R. Ruiz, David M. Blei 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
23 | 
-------------------------------------------------------------------------------- /.github/workflows/R-CMD-check.yml: -------------------------------------------------------------------------------- 
1 | on: 
2 | push: 
3 | branches: 
4 | - master 
5 | pull_request: 
6 | branches: 
7 | - master 
8 | 
9 | name: R-CMD-check 
10 | 
11 | jobs: 
12 | R-CMD-check: 
13 | runs-on: ${{ matrix.config.os }} 
14 | 
15 | name: ${{ matrix.config.os }} (${{ matrix.config.r }}) 
16 | 
17 | strategy: 
18 | fail-fast: false 
19 | matrix: 
20 | config: 
21 | - {os: macos-latest, r: 'release'} 
22 | - {os: ubuntu-latest, r: 'devel', http-user-agent: 'release'} 
23 | - {os: ubuntu-latest, r: 'release'} 
24 | - {os: ubuntu-latest, r: 'oldrel'} 
25 | - {os: ubuntu-latest, r: 'oldrel-1'} 
26 | - {os: ubuntu-latest, r: 'oldrel-2'} 
27 | - {os: ubuntu-latest, r: 'oldrel-3'} 
28 | 
29 | env: 
30 | R_REMOTES_NO_ERRORS_FROM_WARNINGS: true 
31 | RSPM: ${{ matrix.config.rspm }} 
32 | GITHUB_PAT: ${{ secrets.PAT }} 
33 | steps: 
34 | - uses: actions/checkout@v3 
35 | 
36 | - uses: r-lib/actions/setup-pandoc@v2 
37 | 
38 | - uses: r-lib/actions/setup-r@v2 
39 | with: 
40 | r-version: ${{ matrix.config.r }} 
41 | http-user-agent: ${{ matrix.config.http-user-agent }} 
42 | use-public-rspm: true 
43 | 
44 | - uses: r-lib/actions/setup-r-dependencies@v2 
45 | with: 
46 | extra-packages: any::rcmdcheck 
47 | needs: check 
48 | 
49 | - uses: r-lib/actions/check-r-package@v2 
50 | with: 
51 | upload-snapshots: true 
-------------------------------------------------------------------------------- /DESCRIPTION: -------------------------------------------------------------------------------- 
1 | Package: topicmodels.etm 
2 | Type: Package 
3 | Title: Topic Modelling in Embedding Spaces 
4 | Version: 0.1.1 
5 | Maintainer: Jan Wijffels <jwijffels@bnosac.be> 
6 | Authors@R: c( 
7 | person('Jan', 'Wijffels', role = c('aut', 'cre', 'cph'), email = 'jwijffels@bnosac.be', comment = "R implementation"), 
8 | person('BNOSAC', role = 'cph', comment = "R implementation"), 
9 | person('Adji B. Dieng', role = c('ctb', 'cph'), comment = "original Python implementation in inst/orig"), 
10 | person('Francisco J. R. Ruiz', role = c('ctb', 'cph'), comment = "original Python implementation in inst/orig"), 
11 | person('David M. Blei', role = c('ctb', 'cph'), comment = "original Python implementation in inst/orig")) 
12 | Description: Find topics in texts which are semantically embedded using techniques like word2vec or Glove. 
13 | This topic modelling technique models each word with a categorical distribution whose natural parameter is the inner product between a word embedding and an embedding of its assigned topic. 
14 | The techniques are explained in detail in the paper 'Topic Modeling in Embedding Spaces' by Adji B. Dieng, Francisco J. R. Ruiz, David M. Blei (2019), available at <doi:10.48550/arXiv.1907.04907>.
15 | URL: https://github.com/bnosac/ETM 16 | License: MIT + file LICENSE 17 | Encoding: UTF-8 18 | SystemRequirements: LibTorch (https://pytorch.org/) 19 | Depends: R (>= 2.10) 20 | Imports: 21 | graphics, 22 | stats, 23 | Matrix, 24 | torch (>= 0.5.0) 25 | Suggests: 26 | udpipe (>= 0.8.4), 27 | word2vec, 28 | uwot, 29 | tinytest, 30 | textplot (>= 0.2.0), 31 | ggrepel 32 | RoxygenNote: 7.3.2 33 | -------------------------------------------------------------------------------- /LICENSE.note: -------------------------------------------------------------------------------- 1 | This R package is distributed under the MIT License (http://opensource.org/licenses/MIT) 2 | 3 | The package also includes 3rd party open source software components and data. The following is a list of these components. The full copies of the license agreements of these components are included below. 4 | 5 | ------------------------------------------------------------------------------------------------------ 6 | - data/ng20.RData and original Python code kept as a reference in inst/orig/ETM 7 | ------------------------------------------------------------------------------------------------------ 8 | 9 | MIT License 10 | 11 | Copyright (c) 2019 Adji B. Dieng, Francisco J. R. Ruiz, David M. Blei 12 | 13 | Permission is hereby granted, free of charge, to any person obtaining a copy 14 | of this software and associated documentation files (the "Software"), to deal 15 | in the Software without restriction, including without limitation the rights 16 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 17 | copies of the Software, and to permit persons to whom the Software is 18 | furnished to do so, subject to the following conditions: 19 | 20 | The above copyright notice and this permission notice shall be included in all 21 | copies or substantial portions of the Software. 22 | 23 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 24 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 25 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 26 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 27 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 28 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 29 | SOFTWARE. 30 | -------------------------------------------------------------------------------- /man/as.matrix.ETM.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/ETM.R 3 | \name{as.matrix.ETM} 4 | \alias{as.matrix.ETM} 5 | \title{Get matrices out of an ETM object} 6 | \usage{ 7 | \method{as.matrix}{ETM}(x, type = c("embedding", "beta"), which = c("topics", "words"), ...) 8 | } 9 | \arguments{ 10 | \item{x}{an object of class \code{ETM}} 11 | 12 | \item{type}{character string with the type of information to extract: either 'beta' (words emttied by each topic) or 'embedding' (embeddings of words or topic centers). Defaults to 'embedding'.} 13 | 14 | \item{which}{a character string with either 'words' or 'topics' to get either the embeddings of the words used in the model or the embedding of the topic centers. Defaults to 'topics'. 
Only used if type = 'embedding'.} 15 | 16 | \item{...}{not used} 17 | } 18 | \value{ 19 | a numeric matrix containing, depending on the value supplied in \code{type} 20 | either the embeddings of the topic centers, the embeddings of the words or the words emitted by each topic 21 | } 22 | \description{ 23 | Convenience function to extract 24 | \itemize{ 25 | \item{embeddings of the topic centers} 26 | \item{embeddings of the words used in the model} 27 | \item{words emmitted by each topic (beta), which is the softmax-transformed inner product of word embedding and topic embeddings} 28 | } 29 | } 30 | \examples{ 31 | \dontshow{if(require(torch) && torch::torch_is_installed()) 32 | \{ 33 | } 34 | library(torch) 35 | library(topicmodels.etm) 36 | path <- system.file(package = "topicmodels.etm", "example", "example_etm.ckpt") 37 | model <- torch_load(path) 38 | 39 | topic.centers <- as.matrix(model, type = "embedding", which = "topics") 40 | word.embeddings <- as.matrix(model, type = "embedding", which = "words") 41 | topic.terminology <- as.matrix(model, type = "beta") 42 | \dontshow{ 43 | \} 44 | # End of main if statement running only if the torch is properly installed 45 | } 46 | } 47 | \seealso{ 48 | \code{\link{ETM}} 49 | } 50 | -------------------------------------------------------------------------------- /R/utils.R: -------------------------------------------------------------------------------- 1 | # #' @title Utility functions to map sparse matrices to token count lists 2 | # #' @description Convert a sparse matrix to counts of tokens 3 | # #' @param x a sparse Matrix 4 | # #' @return 5 | # #' a list with elements tokens, counts and vocab with the content of \code{x} 6 | # #' @export 7 | # #' @examples 8 | # #' library(udpipe) 9 | # #' library(Matrix) 10 | # #' data(brussels_reviews_anno, package = "udpipe") 11 | # #' x <- subset(brussels_reviews_anno, upos %in% "NOUN") 12 | # #' x <- document_term_frequencies(x, document = "doc_id", term = "lemma") 13 | # #' x <- document_term_matrix(x) 14 | # #' tk <- as_tokencounts(x) 15 | # #' str(tk) 16 | # #' 17 | # #' ## Test to do the other way around: tokencounts to sparse matrix 18 | # #' as_dtm <- function(tokens, counts, vocab){ 19 | # #' nm <- seq_len(length(tokens)) 20 | # #' mat <- sparseMatrix(i = unlist(Map(nm, tokens, 21 | # #' f = function(nm, key) rep(nm, length(key)))), 22 | # #' j = unlist(tokens, use.names = FALSE), 23 | # #' x = unlist(counts, use.names = FALSE)) 24 | # #' colnames(mat) <- vocab 25 | # #' mat 26 | # #' } 27 | # #' x_back <- as_dtm(tokens = tk$tokens, counts = tk$counts, vocab = tk$vocab) 28 | # #' rownames(x) <- NULL 29 | # #' all.equal(x, x_back) 30 | as_tokencounts <- function(x){ 31 | stopifnot(inherits(x, "dgCMatrix")) 32 | m <- Matrix::summary(x) 33 | tokens <- split(m$j, m$i) 34 | counts <- split(m$x, m$i) 35 | names(tokens) <- NULL 36 | names(counts) <- NULL 37 | list(tokens = tokens, counts = counts, vocab = colnames(x)) 38 | } 39 | 40 | 41 | 42 | as_dtm <- function(tokens, counts, vocab){ 43 | nm <- seq_len(length(tokens)) 44 | mat <- sparseMatrix(i = unlist(Map(nm, tokens, 45 | f = function(nm, key) rep(nm, length(key)))), 46 | j = unlist(tokens, use.names = FALSE), 47 | x = unlist(counts, use.names = FALSE)) 48 | colnames(mat) <- vocab 49 | mat 50 | } 51 | -------------------------------------------------------------------------------- /inst/orig/ETM/skipgram.py: -------------------------------------------------------------------------------- 1 | import gensim 2 | import pickle 3 | import os 4 | 
import numpy as np 5 | import argparse 6 | 7 | parser = argparse.ArgumentParser(description='The Embedded Topic Model') 8 | 9 | ### data and file related arguments 10 | parser.add_argument('--data_file', type=str, default='', help='a .txt file containing the corpus') 11 | parser.add_argument('--emb_file', type=str, default='embeddings.txt', help='file to save the word embeddings') 12 | parser.add_argument('--dim_rho', type=int, default=300, help='dimensionality of the word embeddings') 13 | parser.add_argument('--min_count', type=int, default=2, help='minimum term frequency (to define the vocabulary)') 14 | parser.add_argument('--sg', type=int, default=1, help='whether to use skip-gram') 15 | parser.add_argument('--workers', type=int, default=25, help='number of CPU cores') 16 | parser.add_argument('--negative_samples', type=int, default=10, help='number of negative samples') 17 | parser.add_argument('--window_size', type=int, default=4, help='window size to determine context') 18 | parser.add_argument('--iters', type=int, default=50, help='number of iterationst') 19 | 20 | args = parser.parse_args() 21 | 22 | # Class for a memory-friendly iterator over the dataset 23 | class MySentences(object): 24 | def __init__(self, filename): 25 | self.filename = filename 26 | 27 | def __iter__(self): 28 | for line in open(self.filename): 29 | yield line.split() 30 | 31 | # Gensim code to obtain the embeddings 32 | sentences = MySentences(args.data_file) # a memory-friendly iterator 33 | model = gensim.models.Word2Vec(sentences, min_count=args.min_count, sg=args.sg, size=args.dim_rho, 34 | iter=args.iters, workers=args.workers, negative=args.negative_samples, window=args.window_size) 35 | 36 | # Write the embeddings to a file 37 | with open(args.emb_file, 'w') as f: 38 | for v in list(model.wv.vocab): 39 | vec = list(model.wv.__getitem__(v)) 40 | f.write(v + ' ') 41 | vec_str = ['%.9f' % val for val in vec] 42 | vec_str = " ".join(vec_str) 43 | f.write(vec_str + '\n') 44 | -------------------------------------------------------------------------------- /man/summary.ETM.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/ETM.R 3 | \name{summary.ETM} 4 | \alias{summary.ETM} 5 | \title{Project ETM embeddings using UMAP} 6 | \usage{ 7 | \method{summary}{ETM}(object, type = c("umap"), n_components = 2, top_n = 20, ...) 8 | } 9 | \arguments{ 10 | \item{object}{object of class \code{ETM}} 11 | 12 | \item{type}{character string with the type of summary to extract. Defaults to 'umap', no other summary information currently implemented.} 13 | 14 | \item{n_components}{the dimension of the space to embed into. Passed on to \code{\link[uwot]{umap}}. 
Defaults to 2.} 15 | 16 | \item{top_n}{passed on to \code{\link{predict.ETM}} to get the \code{top_n} most relevant words for each topic in the 2-dimensional space} 17 | 18 | \item{...}{further arguments passed onto \code{\link[uwot]{umap}}} 19 | } 20 | \value{ 21 | a list with elements 22 | \itemize{ 23 | \item{center: a matrix with the embeddings of the topic centers} 24 | \item{words: a matrix with the embeddings of the words} 25 | \item{embed_2d: a data.frame which contains a lower dimensional presentation in 2D of the topics and the top_n words associated with 26 | the topic, containing columns type, term, cluster (the topic number), rank, beta, x, y, weight; where type is either 'words' or 'centers', x/y contain the lower dimensional 27 | positions in 2D of the word and weight is the emitted beta scaled to the highest beta within a topic where the topic center always gets weight 0.8} 28 | } 29 | } 30 | \description{ 31 | Uses the uwot package to map the word embeddings and the center of the topic embeddings to a 2-dimensional space 32 | } 33 | \examples{ 34 | \dontshow{if(require(torch) && torch::torch_is_installed() && require(uwot)) 35 | \{ 36 | } 37 | library(torch) 38 | library(topicmodels.etm) 39 | library(uwot) 40 | path <- system.file(package = "topicmodels.etm", "example", "example_etm.ckpt") 41 | model <- torch_load(path) 42 | overview <- summary(model, 43 | metric = "cosine", n_neighbors = 15, 44 | fast_sgd = FALSE, n_threads = 1, verbose = TRUE) 45 | overview$center 46 | overview$embed_2d 47 | \dontshow{ 48 | \} 49 | # End of main if statement running only if the torch is properly installed 50 | } 51 | } 52 | \seealso{ 53 | \code{\link[uwot]{umap}}, \code{\link{ETM}} 54 | } 55 | -------------------------------------------------------------------------------- /inst/orig/ETM/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import pickle 4 | import numpy as np 5 | import torch 6 | import scipy.io 7 | 8 | def _fetch(path, name): 9 | if name == 'train': 10 | token_file = os.path.join(path, 'bow_tr_tokens.mat') 11 | count_file = os.path.join(path, 'bow_tr_counts.mat') 12 | elif name == 'valid': 13 | token_file = os.path.join(path, 'bow_va_tokens.mat') 14 | count_file = os.path.join(path, 'bow_va_counts.mat') 15 | else: 16 | token_file = os.path.join(path, 'bow_ts_tokens.mat') 17 | count_file = os.path.join(path, 'bow_ts_counts.mat') 18 | tokens = scipy.io.loadmat(token_file)['tokens'].squeeze() 19 | counts = scipy.io.loadmat(count_file)['counts'].squeeze() 20 | if name == 'test': 21 | token_1_file = os.path.join(path, 'bow_ts_h1_tokens.mat') 22 | count_1_file = os.path.join(path, 'bow_ts_h1_counts.mat') 23 | token_2_file = os.path.join(path, 'bow_ts_h2_tokens.mat') 24 | count_2_file = os.path.join(path, 'bow_ts_h2_counts.mat') 25 | tokens_1 = scipy.io.loadmat(token_1_file)['tokens'].squeeze() 26 | counts_1 = scipy.io.loadmat(count_1_file)['counts'].squeeze() 27 | tokens_2 = scipy.io.loadmat(token_2_file)['tokens'].squeeze() 28 | counts_2 = scipy.io.loadmat(count_2_file)['counts'].squeeze() 29 | return {'tokens': tokens, 'counts': counts, 30 | 'tokens_1': tokens_1, 'counts_1': counts_1, 31 | 'tokens_2': tokens_2, 'counts_2': counts_2} 32 | return {'tokens': tokens, 'counts': counts} 33 | 34 | def get_data(path): 35 | with open(os.path.join(path, 'vocab.pkl'), 'rb') as f: 36 | vocab = pickle.load(f) 37 | 38 | train = _fetch(path, 'train') 39 | valid = _fetch(path, 'valid') 40 | test = _fetch(path, 'test') 
41 | 42 | return vocab, train, valid, test 43 | 44 | def get_batch(tokens, counts, ind, vocab_size, device, emsize=300): 45 | """fetch input data by batch.""" 46 | batch_size = len(ind) 47 | data_batch = np.zeros((batch_size, vocab_size)) 48 | 49 | for i, doc_id in enumerate(ind): 50 | doc = tokens[doc_id] 51 | count = counts[doc_id] 52 | L = count.shape[1] 53 | if len(doc) == 1: 54 | doc = [doc.squeeze()] 55 | count = [count.squeeze()] 56 | else: 57 | doc = doc.squeeze() 58 | count = count.squeeze() 59 | if doc_id != -1: 60 | for j, word in enumerate(doc): 61 | data_batch[i, word] = count[j] 62 | data_batch = torch.from_numpy(data_batch).float().to(device) 63 | return data_batch 64 | -------------------------------------------------------------------------------- /man/predict.ETM.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/ETM.R 3 | \name{predict.ETM} 4 | \alias{predict.ETM} 5 | \title{Predict functionality for an ETM object.} 6 | \usage{ 7 | \method{predict}{ETM}( 8 | object, 9 | newdata, 10 | type = c("topics", "terms"), 11 | batch_size = nrow(newdata), 12 | normalize = TRUE, 13 | top_n = 10, 14 | ... 15 | ) 16 | } 17 | \arguments{ 18 | \item{object}{an object of class \code{ETM}} 19 | 20 | \item{newdata}{bag of words document term matrix in \code{dgCMatrix} format. Only used in case type = 'topics'.} 21 | 22 | \item{type}{a character string with either 'topics' or 'terms' indicating to either predict to which 23 | topic a document encoded as a set of bag of words belongs to or to extract the most emitted terms for each topic} 24 | 25 | \item{batch_size}{integer with the size of the batch in order to do chunkwise predictions in chunks of \code{batch_size} rows. Defaults to the whole dataset provided in \code{newdata}. 26 | Only used in case type = 'topics'.} 27 | 28 | \item{normalize}{logical indicating to normalize the bag of words data. Defaults to \code{TRUE} similar as the default when building the \code{ETM} model. 29 | Only used in case type = 'topics'.} 30 | 31 | \item{top_n}{integer with the number of most relevant words for each topic to extract. Only used in case type = 'terms'.} 32 | 33 | \item{...}{not used} 34 | } 35 | \value{ 36 | Returns for 37 | \itemize{ 38 | \item{type 'topics': a matrix with topic probabilities of dimension nrow(newdata) x the number of topics} 39 | \item{type 'terms': a list of data.frame's where each data.frame has columns term, beta and rank indicating the 40 | top_n most emitted terms for that topic. List element 1 corresponds to the top terms emitted by topic 1, element 2 to topic 2 ...} 41 | } 42 | } 43 | \description{ 44 | Predict to which ETM topic a text belongs or extract which words are emitted for each topic. 
45 | } 46 | \examples{ 47 | \dontshow{if(require(torch) && torch::torch_is_installed()) 48 | \{ 49 | } 50 | library(torch) 51 | library(topicmodels.etm) 52 | path <- system.file(package = "topicmodels.etm", "example", "example_etm.ckpt") 53 | model <- torch_load(path) 54 | 55 | # Get most emitted words for each topic 56 | terminology <- predict(model, type = "terms", top_n = 5) 57 | terminology 58 | 59 | # Get topics probabilities for each document 60 | path <- system.file(package = "topicmodels.etm", "example", "example_dtm.rds") 61 | dtm <- readRDS(path) 62 | dtm <- head(dtm, n = 5) 63 | scores <- predict(model, newdata = dtm, type = "topics") 64 | scores 65 | \dontshow{ 66 | \} 67 | # End of main if statement running only if the torch is properly installed 68 | } 69 | } 70 | \seealso{ 71 | \code{\link{ETM}} 72 | } 73 | -------------------------------------------------------------------------------- /inst/orig/ETM/README.md: -------------------------------------------------------------------------------- 1 | # ETM 2 | 3 | This is code that accompanies the paper titled "Topic Modeling in Embedding Spaces" by Adji B. Dieng, Francisco J. R. Ruiz, and David M. Blei. (Arxiv link: https://arxiv.org/abs/1907.04907) 4 | 5 | ETM defines words and topics in the same embedding space. The likelihood of a word under ETM is a Categorical whose natural parameter is given by the dot product between the word embedding and its assigned topic's embedding. ETM is a document model that learns interpretable topics and word embeddings and is robust to large vocabularies that include rare words and stop words. 6 | 7 | ## Dependencies 8 | 9 | + python 3.6.7 10 | + pytorch 1.1.0 11 | 12 | ## Datasets 13 | 14 | All the datasets are pre-processed and can be found below: 15 | 16 | + https://bitbucket.org/franrruiz/data_nyt_largev_4/src/master/ 17 | + https://bitbucket.org/franrruiz/data_nyt_largev_5/src/master/ 18 | + https://bitbucket.org/franrruiz/data_nyt_largev_6/src/master/ 19 | + https://bitbucket.org/franrruiz/data_nyt_largev_7/src/master/ 20 | + https://bitbucket.org/franrruiz/data_stopwords_largev_2/src/master/ (this one contains stop words and was used to showcase robustness of ETM to stop words.) 21 | + https://bitbucket.org/franrruiz/data_20ng_largev/src/master/ 22 | 23 | All the scripts to pre-process a given dataset for ETM can be found in the folder 'scripts'. The script for 20NewsGroup is self-contained as it uses scikit-learn. If you want to run ETM on your own dataset, follow the script for New York Times (given as example) called data_nyt.py 24 | 25 | ## To Run 26 | 27 | To learn interpretable embeddings and topics using ETM on the 20NewsGroup dataset, run 28 | ``` 29 | python main.py --mode train --dataset 20ng --data_path data/20ng --num_topics 50 --train_embeddings 1 --epochs 1000 30 | ``` 31 | 32 | To evaluate perplexity on document completion, topic coherence, topic diversity, and visualize the topics/embeddings run 33 | ``` 34 | python main.py --mode eval --dataset 20ng --data_path data/20ng --num_topics 50 --train_embeddings 1 --tc 1 --td 1 --load_from CKPT_PATH 35 | ``` 36 | 37 | To learn interpretable topics using ETM with pre-fitted word embeddings (called Labelled-ETM in the paper) on the 20NewsGroup dataset: 38 | 39 | + first fit the word embeddings. 
For example to use simple skipgram you can run 40 | ``` 41 | python skipgram.py --data_file PATH_TO_DATA --emb_file PATH_TO_EMBEDDINGS --dim_rho 300 --iters 50 --window_size 4 42 | ``` 43 | 44 | + then run the following 45 | ``` 46 | python main.py --mode train --dataset 20ng --data_path data/20ng --emb_path PATH_TO_EMBEDDINGS --num_topics 50 --train_embeddings 0 --epochs 1000 47 | ``` 48 | 49 | ## Citation 50 | 51 | ``` 52 | @article{dieng2019topic, 53 | title={Topic modeling in embedding spaces}, 54 | author={Dieng, Adji B and Ruiz, Francisco J R and Blei, David M}, 55 | journal={arXiv preprint arXiv:1907.04907}, 56 | year={2019} 57 | } 58 | ``` 59 | 60 | -------------------------------------------------------------------------------- /inst/orig/ETM/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def get_topic_diversity(beta, topk): 5 | num_topics = beta.shape[0] 6 | list_w = np.zeros((num_topics, topk)) 7 | for k in range(num_topics): 8 | idx = beta[k,:].argsort()[-topk:][::-1] 9 | list_w[k,:] = idx 10 | n_unique = len(np.unique(list_w)) 11 | TD = n_unique / (topk * num_topics) 12 | print('Topic diveristy is: {}'.format(TD)) 13 | 14 | def get_document_frequency(data, wi, wj=None): 15 | if wj is None: 16 | D_wi = 0 17 | for l in range(len(data)): 18 | doc = data[l].squeeze(0) 19 | if len(doc) == 1: 20 | continue 21 | else: 22 | doc = doc.squeeze() 23 | if wi in doc: 24 | D_wi += 1 25 | return D_wi 26 | D_wj = 0 27 | D_wi_wj = 0 28 | for l in range(len(data)): 29 | doc = data[l].squeeze(0) 30 | if len(doc) == 1: 31 | doc = [doc.squeeze()] 32 | else: 33 | doc = doc.squeeze() 34 | if wj in doc: 35 | D_wj += 1 36 | if wi in doc: 37 | D_wi_wj += 1 38 | return D_wj, D_wi_wj 39 | 40 | def get_topic_coherence(beta, data, vocab): 41 | D = len(data) ## number of docs...data is list of documents 42 | print('D: ', D) 43 | TC = [] 44 | num_topics = len(beta) 45 | for k in range(num_topics): 46 | print('k: {}/{}'.format(k, num_topics)) 47 | top_10 = list(beta[k].argsort()[-11:][::-1]) 48 | top_words = [vocab[a] for a in top_10] 49 | TC_k = 0 50 | counter = 0 51 | for i, word in enumerate(top_10): 52 | # get D(w_i) 53 | D_wi = get_document_frequency(data, word) 54 | j = i + 1 55 | tmp = 0 56 | while j < len(top_10) and j > i: 57 | # get D(w_j) and D(w_i, w_j) 58 | D_wj, D_wi_wj = get_document_frequency(data, word, top_10[j]) 59 | # get f(w_i, w_j) 60 | if D_wi_wj == 0: 61 | f_wi_wj = -1 62 | else: 63 | f_wi_wj = -1 + ( np.log(D_wi) + np.log(D_wj) - 2.0 * np.log(D) ) / ( np.log(D_wi_wj) - np.log(D) ) 64 | # update tmp: 65 | tmp += f_wi_wj 66 | j += 1 67 | counter += 1 68 | # update TC_k 69 | TC_k += tmp 70 | TC.append(TC_k) 71 | print('counter: ', counter) 72 | print('num topics: ', len(TC)) 73 | TC = np.mean(TC) / counter 74 | print('Topic coherence is: {}'.format(TC)) 75 | 76 | def nearest_neighbors(word, embeddings, vocab): 77 | vectors = embeddings.data.cpu().numpy() 78 | index = vocab.index(word) 79 | print('vectors: ', vectors.shape) 80 | query = vectors[index] 81 | print('query: ', query.shape) 82 | ranks = vectors.dot(query).squeeze() 83 | denom = query.T.dot(query).squeeze() 84 | denom = denom * np.sum(vectors**2, 1) 85 | denom = np.sqrt(denom) 86 | ranks = ranks / denom 87 | mostSimilar = [] 88 | [mostSimilar.append(idx) for idx in ranks.argsort()[::-1]] 89 | nearest_neighbors = mostSimilar[:20] 90 | nearest_neighbors = [vocab[comp] for comp in nearest_neighbors] 91 | return nearest_neighbors 92 | 
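# ----------------------------------------------------------------------------
# Editor's note: illustrative usage sketch, not part of the original utils.py.
# The helpers above score a fitted topic-word matrix `beta` of shape
# num_topics x vocab_size: get_topic_diversity() prints the fraction of unique
# words among the top-k words over all topics, and get_topic_coherence()
# prints a co-occurrence based coherence score. A minimal, hypothetical
# example (only numpy is assumed; the random `beta` merely stands in for a
# trained model):
#
#   import numpy as np
#   beta = np.random.dirichlet(np.ones(1000), size=25)  # 25 topics over a 1000-word vocabulary
#   get_topic_diversity(beta, 25)                       # prints a value in (0, 1]
#
# get_topic_coherence(beta, data, vocab) additionally needs the tokenised
# documents (`data`) and the vocabulary list (`vocab`) used to build `beta`,
# for example as loaded by data.py in this folder.
# ----------------------------------------------------------------------------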
-------------------------------------------------------------------------------- /man/plot.ETM.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/ETM.R 3 | \name{plot.ETM} 4 | \alias{plot.ETM} 5 | \title{Plot functionality for an ETM object} 6 | \usage{ 7 | \method{plot}{ETM}( 8 | x, 9 | type = c("loss", "topics"), 10 | which, 11 | top_n = 4, 12 | title = "ETM topics", 13 | subtitle = "", 14 | encircle = FALSE, 15 | points = FALSE, 16 | ... 17 | ) 18 | } 19 | \arguments{ 20 | \item{x}{an object of class \code{ETM}} 21 | 22 | \item{type}{character string with the type of plot to generate: either 'loss' or 'topics'} 23 | 24 | \item{which}{an integer vector of topics to plot, used in case type = 'topics'. Defaults to all topics. See the example below.} 25 | 26 | \item{top_n}{passed on to \code{summary.ETM} in order to visualise the top_n most relevant words for each topic. Defaults to 4.} 27 | 28 | \item{title}{passed on to textplot_embedding_2d, used in case type = 'topics'} 29 | 30 | \item{subtitle}{passed on to textplot_embedding_2d, used in case type = 'topics'} 31 | 32 | \item{encircle}{passed on to textplot_embedding_2d, used in case type = 'topics'} 33 | 34 | \item{points}{passed on to textplot_embedding_2d, used in case type = 'topics'} 35 | 36 | \item{...}{arguments passed on to \code{\link{summary.ETM}}} 37 | } 38 | \value{ 39 | In case \code{type} is set to 'topics', maps the topic centers and most emitted words for each topic 40 | to 2D using \code{\link{summary.ETM}} and returns a ggplot object by calling \code{\link[textplot]{textplot_embedding_2d}}. \cr 41 | For type 'loss', makes a base graphics plot and returns invisibly nothing. 42 | } 43 | \description{ 44 | Convenience function allowing to plot 45 | \itemize{ 46 | \item{the evolution of the loss on the training / test set in order to inspect training convergence} 47 | \item{the \code{ETM} model in 2D dimensional space using a umap projection. 
48 | This plot uses function \code{\link[textplot]{textplot_embedding_2d}} from the textplot R package and 49 | plots the top_n most emitted words of each topic and the topic centers in 2 dimensions} 50 | } 51 | } 52 | \examples{ 53 | \dontshow{if(require(torch) && torch::torch_is_installed()) 54 | \{ 55 | } 56 | library(torch) 57 | library(topicmodels.etm) 58 | path <- system.file(package = "topicmodels.etm", "example", "example_etm.ckpt") 59 | model <- torch_load(path) 60 | plot(model, type = "loss") 61 | \dontshow{ 62 | \} 63 | # End of main if statement running only if the torch is properly installed 64 | } 65 | 66 | \dontshow{if(require(torch) && torch::torch_is_installed() && 67 | require(textplot) && require(uwot) && require(ggrepel)) 68 | \{ 69 | } 70 | library(torch) 71 | library(topicmodels.etm) 72 | library(textplot) 73 | library(uwot) 74 | library(ggrepel) 75 | path <- system.file(package = "topicmodels.etm", "example", "example_etm.ckpt") 76 | model <- torch_load(path) 77 | plt <- plot(model, type = "topics", top_n = 7, which = c(1, 2, 14, 16, 18, 19), 78 | metric = "cosine", n_neighbors = 15, 79 | fast_sgd = FALSE, n_threads = 2, verbose = TRUE, 80 | title = "ETM Topics example") 81 | plt 82 | \dontshow{ 83 | \} 84 | # End of main if statement running only if the torch is properly installed 85 | } 86 | } 87 | \seealso{ 88 | \code{\link{ETM}}, \code{\link{summary.ETM}}, \code{\link[textplot]{textplot_embedding_2d}} 89 | } 90 | -------------------------------------------------------------------------------- /.github/workflows/rhub.yaml: -------------------------------------------------------------------------------- 1 | # R-hub's generic GitHub Actions workflow file. It's canonical location is at 2 | # https://github.com/r-hub/actions/blob/v1/workflows/rhub.yaml 3 | # You can update this file to a newer version using the rhub2 package: 4 | # 5 | # rhub::rhub_setup() 6 | # 7 | # It is unlikely that you need to modify this file manually. 8 | 9 | name: R-hub 10 | run-name: "${{ github.event.inputs.id }}: ${{ github.event.inputs.name || format('Manually run by {0}', github.triggering_actor) }}" 11 | 12 | on: 13 | workflow_dispatch: 14 | inputs: 15 | config: 16 | description: 'A comma separated list of R-hub platforms to use.' 17 | type: string 18 | default: 'linux,windows,macos,clang-asan,clang-ubsan,gcc-asan,nold,rchk,ubuntu-clang,valgrind' 19 | name: 20 | description: 'Run name. You can leave this empty now.' 21 | type: string 22 | id: 23 | description: 'Unique ID. You can leave this empty now.' 
24 | type: string 25 | 26 | jobs: 27 | 28 | setup: 29 | runs-on: ubuntu-latest 30 | outputs: 31 | containers: ${{ steps.rhub-setup.outputs.containers }} 32 | platforms: ${{ steps.rhub-setup.outputs.platforms }} 33 | 34 | steps: 35 | # NO NEED TO CHECKOUT HERE 36 | - uses: r-hub/actions/setup@v1 37 | with: 38 | config: ${{ github.event.inputs.config }} 39 | id: rhub-setup 40 | 41 | linux-containers: 42 | needs: setup 43 | if: ${{ needs.setup.outputs.containers != '[]' }} 44 | runs-on: ubuntu-latest 45 | name: ${{ matrix.config.label }} 46 | strategy: 47 | fail-fast: false 48 | matrix: 49 | config: ${{ fromJson(needs.setup.outputs.containers) }} 50 | container: 51 | image: ${{ matrix.config.container }} 52 | 53 | steps: 54 | - uses: r-hub/actions/checkout@v1 55 | - uses: r-hub/actions/platform-info@v1 56 | with: 57 | token: ${{ secrets.RHUB_TOKEN }} 58 | job-config: ${{ matrix.config.job-config }} 59 | - uses: r-hub/actions/setup-deps@v1 60 | with: 61 | token: ${{ secrets.RHUB_TOKEN }} 62 | job-config: ${{ matrix.config.job-config }} 63 | - uses: r-hub/actions/run-check@v1 64 | with: 65 | token: ${{ secrets.RHUB_TOKEN }} 66 | job-config: ${{ matrix.config.job-config }} 67 | 68 | other-platforms: 69 | needs: setup 70 | if: ${{ needs.setup.outputs.platforms != '[]' }} 71 | runs-on: ${{ matrix.config.os }} 72 | name: ${{ matrix.config.label }} 73 | strategy: 74 | fail-fast: false 75 | matrix: 76 | config: ${{ fromJson(needs.setup.outputs.platforms) }} 77 | 78 | steps: 79 | - uses: r-hub/actions/checkout@v1 80 | - uses: r-hub/actions/setup-r@v1 81 | with: 82 | job-config: ${{ matrix.config.job-config }} 83 | token: ${{ secrets.RHUB_TOKEN }} 84 | - uses: r-hub/actions/platform-info@v1 85 | with: 86 | token: ${{ secrets.RHUB_TOKEN }} 87 | job-config: ${{ matrix.config.job-config }} 88 | - uses: r-hub/actions/setup-deps@v1 89 | with: 90 | job-config: ${{ matrix.config.job-config }} 91 | token: ${{ secrets.RHUB_TOKEN }} 92 | - uses: r-hub/actions/run-check@v1 93 | with: 94 | job-config: ${{ matrix.config.job-config }} 95 | token: ${{ secrets.RHUB_TOKEN }} 96 | -------------------------------------------------------------------------------- /inst/orig/ETM/etm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | import math 5 | 6 | from torch import nn 7 | 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | class ETM(nn.Module): 11 | def __init__(self, num_topics, vocab_size, t_hidden_size, rho_size, emsize, 12 | theta_act, embeddings=None, train_embeddings=True, enc_drop=0.5): 13 | super(ETM, self).__init__() 14 | 15 | ## define hyperparameters 16 | self.num_topics = num_topics 17 | self.vocab_size = vocab_size 18 | self.t_hidden_size = t_hidden_size 19 | self.rho_size = rho_size 20 | self.enc_drop = enc_drop 21 | self.emsize = emsize 22 | self.t_drop = nn.Dropout(enc_drop) 23 | 24 | self.theta_act = self.get_activation(theta_act) 25 | 26 | ## define the word embedding matrix \rho 27 | if train_embeddings: 28 | self.rho = nn.Linear(rho_size, vocab_size, bias=False) 29 | else: 30 | num_embeddings, emsize = embeddings.size() 31 | rho = nn.Embedding(num_embeddings, emsize) 32 | self.rho = embeddings.clone().float().to(device) 33 | 34 | ## define the matrix containing the topic embeddings 35 | self.alphas = nn.Linear(rho_size, num_topics, bias=False)#nn.Parameter(torch.randn(rho_size, num_topics)) 36 | 37 | ## define variational distribution for \theta_{1:D} via 
amortizartion 38 | self.q_theta = nn.Sequential( 39 | nn.Linear(vocab_size, t_hidden_size), 40 | self.theta_act, 41 | nn.Linear(t_hidden_size, t_hidden_size), 42 | self.theta_act, 43 | ) 44 | self.mu_q_theta = nn.Linear(t_hidden_size, num_topics, bias=True) 45 | self.logsigma_q_theta = nn.Linear(t_hidden_size, num_topics, bias=True) 46 | 47 | def get_activation(self, act): 48 | if act == 'tanh': 49 | act = nn.Tanh() 50 | elif act == 'relu': 51 | act = nn.ReLU() 52 | elif act == 'softplus': 53 | act = nn.Softplus() 54 | elif act == 'rrelu': 55 | act = nn.RReLU() 56 | elif act == 'leakyrelu': 57 | act = nn.LeakyReLU() 58 | elif act == 'elu': 59 | act = nn.ELU() 60 | elif act == 'selu': 61 | act = nn.SELU() 62 | elif act == 'glu': 63 | act = nn.GLU() 64 | else: 65 | print('Defaulting to tanh activations...') 66 | act = nn.Tanh() 67 | return act 68 | 69 | def reparameterize(self, mu, logvar): 70 | """Returns a sample from a Gaussian distribution via reparameterization. 71 | """ 72 | if self.training: 73 | std = torch.exp(0.5 * logvar) 74 | eps = torch.randn_like(std) 75 | return eps.mul_(std).add_(mu) 76 | else: 77 | return mu 78 | 79 | def encode(self, bows): 80 | """Returns paramters of the variational distribution for \theta. 81 | 82 | input: bows 83 | batch of bag-of-words...tensor of shape bsz x V 84 | output: mu_theta, log_sigma_theta 85 | """ 86 | q_theta = self.q_theta(bows) 87 | if self.enc_drop > 0: 88 | q_theta = self.t_drop(q_theta) 89 | mu_theta = self.mu_q_theta(q_theta) 90 | logsigma_theta = self.logsigma_q_theta(q_theta) 91 | kl_theta = -0.5 * torch.sum(1 + logsigma_theta - mu_theta.pow(2) - logsigma_theta.exp(), dim=-1).mean() 92 | return mu_theta, logsigma_theta, kl_theta 93 | 94 | def get_beta(self): 95 | try: 96 | logit = self.alphas(self.rho.weight) # torch.mm(self.rho, self.alphas) 97 | except: 98 | logit = self.alphas(self.rho) 99 | beta = F.softmax(logit, dim=0).transpose(1, 0) ## softmax over vocab dimension 100 | return beta 101 | 102 | def get_theta(self, normalized_bows): 103 | mu_theta, logsigma_theta, kld_theta = self.encode(normalized_bows) 104 | z = self.reparameterize(mu_theta, logsigma_theta) 105 | theta = F.softmax(z, dim=-1) 106 | return theta, kld_theta 107 | 108 | def decode(self, theta, beta): 109 | res = torch.mm(theta, beta) 110 | preds = torch.log(res+1e-6) 111 | return preds 112 | 113 | def forward(self, bows, normalized_bows, theta=None, aggregate=True): 114 | ## get \theta 115 | if theta is None: 116 | theta, kld_theta = self.get_theta(normalized_bows) 117 | else: 118 | kld_theta = None 119 | 120 | ## get \beta 121 | beta = self.get_beta() 122 | 123 | ## get prediction loss 124 | preds = self.decode(theta, beta) 125 | recon_loss = -(preds * bows).sum(1) 126 | if aggregate: 127 | recon_loss = recon_loss.mean() 128 | return recon_loss, kld_theta 129 | 130 | -------------------------------------------------------------------------------- /inst/orig/ETM/scripts/stops.txt: -------------------------------------------------------------------------------- 1 | a 2 | able 3 | about 4 | above 5 | according 6 | accordingly 7 | across 8 | actually 9 | after 10 | afterwards 11 | again 12 | against 13 | all 14 | allow 15 | allows 16 | almost 17 | alone 18 | along 19 | already 20 | also 21 | although 22 | always 23 | am 24 | among 25 | amongst 26 | an 27 | and 28 | another 29 | any 30 | anybody 31 | anyhow 32 | anyone 33 | anything 34 | anyway 35 | anyways 36 | anywhere 37 | apart 38 | appear 39 | appreciate 40 | appropriate 41 | are 42 | around 43 | as 44 | aside 
45 | ask 46 | asking 47 | associated 48 | at 49 | available 50 | away 51 | awfully 52 | b 53 | be 54 | became 55 | because 56 | become 57 | becomes 58 | becoming 59 | been 60 | before 61 | beforehand 62 | behind 63 | being 64 | believe 65 | below 66 | beside 67 | besides 68 | best 69 | better 70 | between 71 | beyond 72 | both 73 | brief 74 | but 75 | by 76 | c 77 | came 78 | can 79 | cannot 80 | cant 81 | cause 82 | causes 83 | certain 84 | certainly 85 | changes 86 | clearly 87 | co 88 | com 89 | come 90 | comes 91 | concerning 92 | consequently 93 | consider 94 | considering 95 | contain 96 | containing 97 | contains 98 | corresponding 99 | could 100 | course 101 | currently 102 | d 103 | definitely 104 | described 105 | despite 106 | did 107 | different 108 | do 109 | does 110 | doing 111 | done 112 | down 113 | downwards 114 | during 115 | e 116 | each 117 | edu 118 | eg 119 | eight 120 | either 121 | else 122 | elsewhere 123 | enough 124 | entirely 125 | especially 126 | et 127 | etc 128 | even 129 | ever 130 | every 131 | everybody 132 | everyone 133 | everything 134 | everywhere 135 | ex 136 | exactly 137 | example 138 | except 139 | f 140 | far 141 | few 142 | fifth 143 | first 144 | five 145 | followed 146 | following 147 | follows 148 | for 149 | former 150 | formerly 151 | forth 152 | four 153 | from 154 | further 155 | furthermore 156 | g 157 | get 158 | gets 159 | getting 160 | given 161 | gives 162 | go 163 | goes 164 | going 165 | gone 166 | got 167 | gotten 168 | greetings 169 | h 170 | had 171 | happens 172 | hardly 173 | has 174 | have 175 | having 176 | he 177 | hello 178 | help 179 | hence 180 | her 181 | here 182 | hereafter 183 | hereby 184 | herein 185 | hereupon 186 | hers 187 | herself 188 | hi 189 | him 190 | himself 191 | his 192 | hither 193 | hopefully 194 | how 195 | howbeit 196 | however 197 | i 198 | ie 199 | if 200 | ignored 201 | immediate 202 | in 203 | inasmuch 204 | inc 205 | indeed 206 | indicate 207 | indicated 208 | indicates 209 | inner 210 | insofar 211 | instead 212 | into 213 | inward 214 | is 215 | it 216 | its 217 | itself 218 | j 219 | just 220 | k 221 | keep 222 | keeps 223 | kept 224 | know 225 | knows 226 | known 227 | l 228 | last 229 | lately 230 | later 231 | latter 232 | latterly 233 | least 234 | less 235 | lest 236 | let 237 | like 238 | liked 239 | likely 240 | little 241 | look 242 | looking 243 | looks 244 | ltd 245 | m 246 | mainly 247 | many 248 | may 249 | maybe 250 | me 251 | mean 252 | meanwhile 253 | merely 254 | might 255 | more 256 | moreover 257 | most 258 | mostly 259 | much 260 | must 261 | my 262 | myself 263 | n 264 | name 265 | namely 266 | nd 267 | near 268 | nearly 269 | necessary 270 | need 271 | needs 272 | neither 273 | never 274 | nevertheless 275 | new 276 | next 277 | nine 278 | no 279 | nobody 280 | non 281 | none 282 | noone 283 | nor 284 | normally 285 | not 286 | nothing 287 | novel 288 | now 289 | nowhere 290 | o 291 | obviously 292 | of 293 | off 294 | often 295 | oh 296 | ok 297 | okay 298 | old 299 | on 300 | once 301 | one 302 | ones 303 | only 304 | onto 305 | or 306 | other 307 | others 308 | otherwise 309 | ought 310 | our 311 | ours 312 | ourselves 313 | out 314 | outside 315 | over 316 | overall 317 | own 318 | p 319 | particular 320 | particularly 321 | per 322 | perhaps 323 | placed 324 | please 325 | plus 326 | possible 327 | presumably 328 | probably 329 | provides 330 | q 331 | que 332 | quite 333 | qv 334 | r 335 | rather 336 | rd 337 | re 338 | really 339 | reasonably 340 | regarding 341 
| regardless 342 | regards 343 | relatively 344 | respectively 345 | right 346 | s 347 | said 348 | same 349 | saw 350 | say 351 | saying 352 | says 353 | second 354 | secondly 355 | see 356 | seeing 357 | seem 358 | seemed 359 | seeming 360 | seems 361 | seen 362 | self 363 | selves 364 | sensible 365 | sent 366 | serious 367 | seriously 368 | seven 369 | several 370 | shall 371 | she 372 | should 373 | since 374 | six 375 | so 376 | some 377 | somebody 378 | somehow 379 | someone 380 | something 381 | sometime 382 | sometimes 383 | somewhat 384 | somewhere 385 | soon 386 | sorry 387 | specified 388 | specify 389 | specifying 390 | still 391 | sub 392 | such 393 | sup 394 | sure 395 | t 396 | take 397 | taken 398 | tell 399 | tends 400 | th 401 | than 402 | thank 403 | thanks 404 | thanx 405 | that 406 | thats 407 | the 408 | their 409 | theirs 410 | them 411 | themselves 412 | then 413 | thence 414 | there 415 | thereafter 416 | thereby 417 | therefore 418 | therein 419 | theres 420 | thereupon 421 | these 422 | they 423 | think 424 | third 425 | this 426 | thorough 427 | thoroughly 428 | those 429 | though 430 | three 431 | through 432 | throughout 433 | thru 434 | thus 435 | to 436 | together 437 | too 438 | took 439 | toward 440 | towards 441 | tried 442 | tries 443 | truly 444 | try 445 | trying 446 | twice 447 | two 448 | u 449 | un 450 | under 451 | unfortunately 452 | unless 453 | unlikely 454 | until 455 | unto 456 | up 457 | upon 458 | us 459 | use 460 | used 461 | useful 462 | uses 463 | using 464 | usually 465 | uucp 466 | v 467 | value 468 | various 469 | very 470 | via 471 | viz 472 | vs 473 | w 474 | want 475 | wants 476 | was 477 | way 478 | we 479 | welcome 480 | well 481 | went 482 | were 483 | what 484 | whatever 485 | when 486 | whence 487 | whenever 488 | where 489 | whereafter 490 | whereas 491 | whereby 492 | wherein 493 | whereupon 494 | wherever 495 | whether 496 | which 497 | while 498 | whither 499 | who 500 | whoever 501 | whole 502 | whom 503 | whose 504 | why 505 | will 506 | willing 507 | wish 508 | with 509 | within 510 | without 511 | wonder 512 | would 513 | would 514 | x 515 | y 516 | yes 517 | yet 518 | you 519 | your 520 | yours 521 | yourself 522 | yourselves 523 | z 524 | zero 525 | -------------------------------------------------------------------------------- /man/ETM.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/ETM.R 3 | \name{ETM} 4 | \alias{ETM} 5 | \title{Topic Modelling in Semantic Embedding Spaces} 6 | \usage{ 7 | ETM( 8 | k = 20, 9 | embeddings, 10 | dim = 800, 11 | activation = c("relu", "tanh", "softplus", "rrelu", "leakyrelu", "elu", "selu", "glu"), 12 | dropout = 0.5, 13 | vocab = rownames(embeddings) 14 | ) 15 | } 16 | \arguments{ 17 | \item{k}{the number of topics to extract} 18 | 19 | \item{embeddings}{either a matrix with pretrained word embeddings or an integer with the dimension of the word embeddings. Defaults to 50 if not provided.} 20 | 21 | \item{dim}{dimension of the variational inference hyperparameter theta (passed on to \code{\link[torch]{nn_linear}}). Defaults to 800.} 22 | 23 | \item{activation}{character string with the activation function of theta. Either one of 'relu', 'tanh', 'softplus', 'rrelu', 'leakyrelu', 'elu', 'selu', 'glu'. Defaults to 'relu'.} 24 | 25 | \item{dropout}{dropout percentage on the variational distribution for theta (passed on to \code{\link[torch]{nn_dropout}}). 
Defaults to 0.5.} 26 | 27 | \item{vocab}{a character vector with the words from the vocabulary. Defaults to the rownames of the \code{embeddings} argument.} 28 | } 29 | \value{ 30 | an object of class ETM which is a torch \code{nn_module} containing, among others, 31 | \itemize{ 32 | \item num_topics: the number of topics 33 | \item vocab: character vector with the terminology used in the model 34 | \item vocab_size: the number of words in \code{vocab} 35 | \item rho: The word embeddings 36 | \item alphas: The topic embeddings 37 | } 38 | } 39 | \description{ 40 | ETM is a generative topic model combining traditional topic models (LDA) with word embeddings (word2vec). \cr 41 | \itemize{ 42 | \item{It models each word with a categorical distribution whose natural parameter is the inner product between 43 | a word embedding and an embedding of its assigned topic.} 44 | \item{The model is fitted using an amortized variational inference algorithm on top of libtorch.} 45 | } 46 | } 47 | \section{Methods}{ 48 | 49 | \describe{ 50 | \item{\code{fit(data, optimizer, epoch, batch_size, normalize = TRUE, clip = 0, lr_anneal_factor = 4, lr_anneal_nonmono = 10)}}{Fit the model on a document term matrix by splitting the data into a 70/30 training/test set and updating the model weights.} 51 | } 52 | } 53 | 54 | \section{Arguments}{ 55 | 56 | \describe{ 57 | \item{data}{bag of words document term matrix in \code{dgCMatrix} format} 58 | \item{optimizer}{object of class \code{torch_Optimizer}} 59 | \item{epoch}{integer with the number of iterations to train} 60 | \item{batch_size}{integer with the size of the batch} 61 | \item{normalize}{logical indicating whether to normalize the bag of words data} 62 | \item{clip}{number between 0 and 1 indicating whether to apply gradient clipping - passed on to \code{\link[torch]{nn_utils_clip_grad_norm_}}} 63 | \item{lr_anneal_factor}{divide the learning rate by this factor when the loss on the test set does not improve for at least \code{lr_anneal_nonmono} training iterations} 64 | \item{lr_anneal_nonmono}{number of iterations after which learning rate annealing is executed if the loss does not decrease} 65 | } 66 | } 67 | 68 | \examples{ 69 | library(torch) 70 | library(topicmodels.etm) 71 | library(word2vec) 72 | library(udpipe) 73 | data(brussels_reviews_anno, package = "udpipe") 74 | ## 75 | ## Toy example with pretrained embeddings 76 | ## 77 | 78 | ## a. build word2vec model 79 | x <- subset(brussels_reviews_anno, language \%in\% "nl") 80 | x <- paste.data.frame(x, term = "lemma", group = "doc_id") 81 | set.seed(4321) 82 | w2v <- word2vec(x = x$lemma, dim = 15, iter = 20, type = "cbow", min_count = 5) 83 | embeddings <- as.matrix(w2v) 84 |
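## Side note (illustration only, not needed for the example below): as described
## above, the topic-term probabilities beta are obtained by taking a softmax over
## the inner products between the word embeddings (rho) and the topic embeddings
## (alphas). A minimal sketch with made-up toy matrices:
softmax    <- function(x) exp(x) / sum(exp(x))
rho_toy    <- matrix(rnorm(5 * 3), nrow = 5)   ## 5 toy words, embedding dimension 3
alphas_toy <- matrix(rnorm(2 * 3), nrow = 2)   ## 2 toy topics, embedding dimension 3
beta_toy   <- t(apply(tcrossprod(rho_toy, alphas_toy), 2, softmax))
rowSums(beta_toy)  ## each row of beta_toy is a probability distribution over the toy words
## in a fitted ETM model this matrix is available through as.matrix(model, type = "beta")

85 | ## b.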
build document term matrix on nouns + adjectives, align with the embedding terms 86 | dtm <- subset(brussels_reviews_anno, language \%in\% "nl" & upos \%in\% c("NOUN", "ADJ")) 87 | dtm <- document_term_frequencies(dtm, document = "doc_id", term = "lemma") 88 | dtm <- document_term_matrix(dtm) 89 | dtm <- dtm_conform(dtm, columns = rownames(embeddings)) 90 | dtm <- dtm[dtm_rowsums(dtm) > 0, ] 91 | 92 | ## create and fit an embedding topic model - 8 topics, theta 100-dimensional 93 | if (torch::torch_is_installed()) { 94 | 95 | set.seed(4321) 96 | torch_manual_seed(4321) 97 | model <- ETM(k = 8, dim = 100, embeddings = embeddings, dropout = 0.5) 98 | optimizer <- optim_adam(params = model$parameters, lr = 0.005, weight_decay = 0.0000012) 99 | overview <- model$fit(data = dtm, optimizer = optimizer, epoch = 40, batch_size = 1000) 100 | scores <- predict(model, dtm, type = "topics") 101 | 102 | lastbatch <- subset(overview$loss, overview$loss$batch_is_last == TRUE) 103 | plot(lastbatch$epoch, lastbatch$loss) 104 | plot(overview$loss_test) 105 | 106 | ## show top words in each topic 107 | terminology <- predict(model, type = "terms", top_n = 7) 108 | terminology 109 | 110 | ## 111 | ## Toy example without pretrained word embeddings 112 | ## 113 | set.seed(4321) 114 | torch_manual_seed(4321) 115 | model <- ETM(k = 8, dim = 100, embeddings = 15, dropout = 0.5, vocab = colnames(dtm)) 116 | optimizer <- optim_adam(params = model$parameters, lr = 0.005, weight_decay = 0.0000012) 117 | overview <- model$fit(data = dtm, optimizer = optimizer, epoch = 40, batch_size = 1000) 118 | terminology <- predict(model, type = "terms", top_n = 7) 119 | terminology 120 | 121 | 122 | 123 | \dontshow{ 124 | ## 125 | ## Another example using fit_original 126 | ## 127 | data(ng20, package = "topicmodels.etm") 128 | vocab <- ng20$vocab 129 | tokens <- ng20$bow_tr$tokens 130 | counts <- ng20$bow_tr$counts 131 | 132 | torch_manual_seed(123456789) 133 | model <- ETM(k = 4, vocab = vocab, dim = 5, embeddings = 25) 134 | model 135 | optimizer <- optim_adam(params = model$parameters, lr = 0.005, weight_decay = 0.0000012) 136 | 137 | traindata <- list(tokens = tokens, counts = counts, vocab = vocab) 138 | test1 <- list(tokens = ng20$bow_ts_h1$tokens, counts = ng20$bow_ts_h1$counts, vocab = vocab) 139 | test2 <- list(tokens = ng20$bow_ts_h2$tokens, counts = ng20$bow_ts_h2$counts, vocab = vocab) 140 | 141 | out <- model$fit_original(data = traindata, test1 = test1, test2 = test2, epoch = 4, 142 | optimizer = optimizer, batch_size = 1000, 143 | lr_anneal_factor = 4, lr_anneal_nonmono = 10) 144 | test <- subset(out$loss, out$loss$batch_is_last == TRUE) 145 | plot(test$epoch, test$loss) 146 | 147 | topic.centers <- as.matrix(model, type = "embedding", which = "topics") 148 | word.embeddings <- as.matrix(model, type = "embedding", which = "words") 149 | topic.terminology <- as.matrix(model, type = "beta") 150 | 151 | terminology <- predict(model, type = "terms", top_n = 4) 152 | terminology 153 | } 154 | 155 | } 156 | } 157 | \references{ 158 | \url{https://arxiv.org/pdf/1907.04907.pdf} 159 | } 160 | -------------------------------------------------------------------------------- /inst/orig/ETM/scripts/data_nyt.py: -------------------------------------------------------------------------------- 1 | from sklearn.feature_extraction.text import CountVectorizer 2 | import numpy as np 3 | import pickle 4 | import random 5 | from scipy import sparse 6 | import itertools 7 | from scipy.io import savemat, loadmat 8 | 9 | # Maximum / 
minimum document frequency 10 | max_df = 0.7 11 | min_df = 100 # choose desired value for min_df 12 | 13 | # Read stopwords 14 | with open('stops.txt', 'r') as f: 15 | stops = f.read().split('\n') 16 | 17 | # Read data 18 | print('reading text file...') 19 | data_file = 'raw/new_york_times_text/nyt_docs.txt' 20 | with open(data_file, 'r') as f: 21 | docs = f.readlines() 22 | 23 | # Create count vectorizer 24 | print('counting document frequency of words...') 25 | cvectorizer = CountVectorizer(min_df=min_df, max_df=max_df, stop_words=None) 26 | cvz = cvectorizer.fit_transform(docs).sign() 27 | 28 | # Get vocabulary 29 | print('building the vocabulary...') 30 | sum_counts = cvz.sum(axis=0) 31 | v_size = sum_counts.shape[1] 32 | sum_counts_np = np.zeros(v_size, dtype=int) 33 | for v in range(v_size): 34 | sum_counts_np[v] = sum_counts[0,v] 35 | word2id = dict([(w, cvectorizer.vocabulary_.get(w)) for w in cvectorizer.vocabulary_]) 36 | id2word = dict([(cvectorizer.vocabulary_.get(w), w) for w in cvectorizer.vocabulary_]) 37 | del cvectorizer 38 | print(' initial vocabulary size: {}'.format(v_size)) 39 | 40 | # Sort elements in vocabulary 41 | idx_sort = np.argsort(sum_counts_np) 42 | vocab_aux = [id2word[idx_sort[cc]] for cc in range(v_size)] 43 | 44 | # Filter out stopwords (if any) 45 | vocab_aux = [w for w in vocab_aux if w not in stops] 46 | print(' vocabulary size after removing stopwords from list: {}'.format(len(vocab_aux))) 47 | print(' vocabulary after removing stopwords: {}'.format(len(vocab_aux))) 48 | 49 | # Create dictionary and inverse dictionary 50 | vocab = vocab_aux 51 | del vocab_aux 52 | word2id = dict([(w, j) for j, w in enumerate(vocab)]) 53 | id2word = dict([(j, w) for j, w in enumerate(vocab)]) 54 | 55 | # Split in train/test/valid 56 | print('tokenizing documents and splitting into train/test/valid...') 57 | num_docs = cvz.shape[0] 58 | trSize = int(np.floor(0.85*num_docs)) 59 | tsSize = int(np.floor(0.10*num_docs)) 60 | vaSize = int(num_docs - trSize - tsSize) 61 | del cvz 62 | idx_permute = np.random.permutation(num_docs).astype(int) 63 | 64 | # Remove words not in train_data 65 | vocab = list(set([w for idx_d in range(trSize) for w in docs[idx_permute[idx_d]].split() if w in word2id])) 66 | word2id = dict([(w, j) for j, w in enumerate(vocab)]) 67 | id2word = dict([(j, w) for j, w in enumerate(vocab)]) 68 | print(' vocabulary after removing words not in train: {}'.format(len(vocab))) 69 | 70 | docs_tr = [[word2id[w] for w in docs[idx_permute[idx_d]].split() if w in word2id] for idx_d in range(trSize)] 71 | docs_ts = [[word2id[w] for w in docs[idx_permute[idx_d+trSize]].split() if w in word2id] for idx_d in range(tsSize)] 72 | docs_va = [[word2id[w] for w in docs[idx_permute[idx_d+trSize+tsSize]].split() if w in word2id] for idx_d in range(vaSize)] 73 | del docs 74 | 75 | print(' number of documents (train): {} [this should be equal to {}]'.format(len(docs_tr), trSize)) 76 | print(' number of documents (test): {} [this should be equal to {}]'.format(len(docs_ts), tsSize)) 77 | print(' number of documents (valid): {} [this should be equal to {}]'.format(len(docs_va), vaSize)) 78 | 79 | # Remove empty documents 80 | print('removing empty documents...') 81 | 82 | def remove_empty(in_docs): 83 | return [doc for doc in in_docs if doc!=[]] 84 | 85 | docs_tr = remove_empty(docs_tr) 86 | docs_ts = remove_empty(docs_ts) 87 | docs_va = remove_empty(docs_va) 88 | 89 | # Remove test documents with length=1 90 | docs_ts = [doc for doc in docs_ts if len(doc)>1] 91 | 92 | # Split 
test set in 2 halves 93 | print('splitting test documents in 2 halves...') 94 | docs_ts_h1 = [[w for i,w in enumerate(doc) if i<=len(doc)/2.0-1] for doc in docs_ts] 95 | docs_ts_h2 = [[w for i,w in enumerate(doc) if i>len(doc)/2.0-1] for doc in docs_ts] 96 | 97 | # Getting lists of words and doc_indices 98 | print('creating lists of words...') 99 | 100 | def create_list_words(in_docs): 101 | return [x for y in in_docs for x in y] 102 | 103 | words_tr = create_list_words(docs_tr) 104 | words_ts = create_list_words(docs_ts) 105 | words_ts_h1 = create_list_words(docs_ts_h1) 106 | words_ts_h2 = create_list_words(docs_ts_h2) 107 | words_va = create_list_words(docs_va) 108 | 109 | print(' len(words_tr): ', len(words_tr)) 110 | print(' len(words_ts): ', len(words_ts)) 111 | print(' len(words_ts_h1): ', len(words_ts_h1)) 112 | print(' len(words_ts_h2): ', len(words_ts_h2)) 113 | print(' len(words_va): ', len(words_va)) 114 | 115 | # Get doc indices 116 | print('getting doc indices...') 117 | 118 | def create_doc_indices(in_docs): 119 | aux = [[j for i in range(len(doc))] for j, doc in enumerate(in_docs)] 120 | return [int(x) for y in aux for x in y] 121 | 122 | doc_indices_tr = create_doc_indices(docs_tr) 123 | doc_indices_ts = create_doc_indices(docs_ts) 124 | doc_indices_ts_h1 = create_doc_indices(docs_ts_h1) 125 | doc_indices_ts_h2 = create_doc_indices(docs_ts_h2) 126 | doc_indices_va = create_doc_indices(docs_va) 127 | 128 | print(' len(np.unique(doc_indices_tr)): {} [this should be {}]'.format(len(np.unique(doc_indices_tr)), len(docs_tr))) 129 | print(' len(np.unique(doc_indices_ts)): {} [this should be {}]'.format(len(np.unique(doc_indices_ts)), len(docs_ts))) 130 | print(' len(np.unique(doc_indices_ts_h1)): {} [this should be {}]'.format(len(np.unique(doc_indices_ts_h1)), len(docs_ts_h1))) 131 | print(' len(np.unique(doc_indices_ts_h2)): {} [this should be {}]'.format(len(np.unique(doc_indices_ts_h2)), len(docs_ts_h2))) 132 | print(' len(np.unique(doc_indices_va)): {} [this should be {}]'.format(len(np.unique(doc_indices_va)), len(docs_va))) 133 | 134 | # Number of documents in each set 135 | n_docs_tr = len(docs_tr) 136 | n_docs_ts = len(docs_ts) 137 | n_docs_ts_h1 = len(docs_ts_h1) 138 | n_docs_ts_h2 = len(docs_ts_h2) 139 | n_docs_va = len(docs_va) 140 | 141 | # Remove unused variables 142 | del docs_tr 143 | del docs_ts 144 | del docs_ts_h1 145 | del docs_ts_h2 146 | del docs_va 147 | 148 | # Create bow representation 149 | print('creating bow representation...') 150 | 151 | def create_bow(doc_indices, words, n_docs, vocab_size): 152 | return sparse.coo_matrix(([1]*len(doc_indices),(doc_indices, words)), shape=(n_docs, vocab_size)).tocsr() 153 | 154 | bow_tr = create_bow(doc_indices_tr, words_tr, n_docs_tr, len(vocab)) 155 | bow_ts = create_bow(doc_indices_ts, words_ts, n_docs_ts, len(vocab)) 156 | bow_ts_h1 = create_bow(doc_indices_ts_h1, words_ts_h1, n_docs_ts_h1, len(vocab)) 157 | bow_ts_h2 = create_bow(doc_indices_ts_h2, words_ts_h2, n_docs_ts_h2, len(vocab)) 158 | bow_va = create_bow(doc_indices_va, words_va, n_docs_va, len(vocab)) 159 | 160 | del words_tr 161 | del words_ts 162 | del words_ts_h1 163 | del words_ts_h2 164 | del words_va 165 | del doc_indices_tr 166 | del doc_indices_ts 167 | del doc_indices_ts_h1 168 | del doc_indices_ts_h2 169 | del doc_indices_va 170 | 171 | # Save vocabulary to file 172 | path_save = './min_df_' + str(min_df) + '/' 173 | if not os.path.isdir(path_save): 174 | os.system('mkdir -p ' + path_save) 175 | 176 | with open(path_save + 'vocab.pkl', 'wb') 
as f: 177 | pickle.dump(vocab, f) 178 | del vocab 179 | 180 | # Split bow intro token/value pairs 181 | print('splitting bow intro token/value pairs and saving to disk...') 182 | 183 | def split_bow(bow_in, n_docs): 184 | indices = [[w for w in bow_in[doc,:].indices] for doc in range(n_docs)] 185 | counts = [[c for c in bow_in[doc,:].data] for doc in range(n_docs)] 186 | return indices, counts 187 | 188 | bow_tr_tokens, bow_tr_counts = split_bow(bow_tr, n_docs_tr) 189 | savemat(path_save + 'bow_tr_tokens', {'tokens': bow_tr_tokens}, do_compression=True) 190 | savemat(path_save + 'bow_tr_counts', {'counts': bow_tr_counts}, do_compression=True) 191 | del bow_tr 192 | del bow_tr_tokens 193 | del bow_tr_counts 194 | 195 | bow_ts_tokens, bow_ts_counts = split_bow(bow_ts, n_docs_ts) 196 | savemat(path_save + 'bow_ts_tokens', {'tokens': bow_ts_tokens}, do_compression=True) 197 | savemat(path_save + 'bow_ts_counts', {'counts': bow_ts_counts}, do_compression=True) 198 | del bow_ts 199 | del bow_ts_tokens 200 | del bow_ts_counts 201 | 202 | bow_ts_h1_tokens, bow_ts_h1_counts = split_bow(bow_ts_h1, n_docs_ts_h1) 203 | savemat(path_save + 'bow_ts_h1_tokens', {'tokens': bow_ts_h1_tokens}, do_compression=True) 204 | savemat(path_save + 'bow_ts_h1_counts', {'counts': bow_ts_h1_counts}, do_compression=True) 205 | del bow_ts_h1 206 | del bow_ts_h1_tokens 207 | del bow_ts_h1_counts 208 | 209 | bow_ts_h2_tokens, bow_ts_h2_counts = split_bow(bow_ts_h2, n_docs_ts_h2) 210 | savemat(path_save + 'bow_ts_h2_tokens', {'tokens': bow_ts_h2_tokens}, do_compression=True) 211 | savemat(path_save + 'bow_ts_h2_counts', {'counts': bow_ts_h2_counts}, do_compression=True) 212 | del bow_ts_h2 213 | del bow_ts_h2_tokens 214 | del bow_ts_h2_counts 215 | 216 | bow_va_tokens, bow_va_counts = split_bow(bow_va, n_docs_va) 217 | savemat(path_save + 'bow_va_tokens', {'tokens': bow_va_tokens}, do_compression=True) 218 | savemat(path_save + 'bow_va_counts', {'counts': bow_va_counts}, do_compression=True) 219 | del bow_va 220 | del bow_va_tokens 221 | del bow_va_counts 222 | 223 | print('Data ready !!') 224 | print('*************') 225 | 226 | -------------------------------------------------------------------------------- /inst/orig/ETM/scripts/data_20ng.py: -------------------------------------------------------------------------------- 1 | from sklearn.feature_extraction.text import CountVectorizer 2 | from sklearn.datasets import fetch_20newsgroups 3 | import numpy as np 4 | import pickle 5 | import random 6 | from scipy import sparse 7 | import itertools 8 | from scipy.io import savemat, loadmat 9 | import re 10 | import string 11 | 12 | # Maximum / minimum document frequency 13 | max_df = 0.7 14 | min_df = 10 # choose desired value for min_df 15 | 16 | # Read stopwords 17 | with open('stops.txt', 'r') as f: 18 | stops = f.read().split('\n') 19 | 20 | # Read data 21 | print('reading data...') 22 | train_data = fetch_20newsgroups(subset='train') 23 | test_data = fetch_20newsgroups(subset='test') 24 | 25 | init_docs_tr = [re.findall(r'''[\w']+|[.,!?;-~{}`´_<=>:/@*()&'$%#"]''', train_data.data[doc]) for doc in range(len(train_data.data))] 26 | init_docs_ts = [re.findall(r'''[\w']+|[.,!?;-~{}`´_<=>:/@*()&'$%#"]''', test_data.data[doc]) for doc in range(len(test_data.data))] 27 | 28 | def contains_punctuation(w): 29 | return any(char in string.punctuation for char in w) 30 | 31 | def contains_numeric(w): 32 | return any(char.isdigit() for char in w) 33 | 34 | init_docs = init_docs_tr + init_docs_ts 35 | init_docs = [[w.lower() for w in 
init_docs[doc] if not contains_punctuation(w)] for doc in range(len(init_docs))] 36 | init_docs = [[w for w in init_docs[doc] if not contains_numeric(w)] for doc in range(len(init_docs))] 37 | init_docs = [[w for w in init_docs[doc] if len(w)>1] for doc in range(len(init_docs))] 38 | init_docs = [" ".join(init_docs[doc]) for doc in range(len(init_docs))] 39 | 40 | # Create count vectorizer 41 | print('counting document frequency of words...') 42 | cvectorizer = CountVectorizer(min_df=min_df, max_df=max_df, stop_words=None) 43 | cvz = cvectorizer.fit_transform(init_docs).sign() 44 | 45 | # Get vocabulary 46 | print('building the vocabulary...') 47 | sum_counts = cvz.sum(axis=0) 48 | v_size = sum_counts.shape[1] 49 | sum_counts_np = np.zeros(v_size, dtype=int) 50 | for v in range(v_size): 51 | sum_counts_np[v] = sum_counts[0,v] 52 | word2id = dict([(w, cvectorizer.vocabulary_.get(w)) for w in cvectorizer.vocabulary_]) 53 | id2word = dict([(cvectorizer.vocabulary_.get(w), w) for w in cvectorizer.vocabulary_]) 54 | del cvectorizer 55 | print(' initial vocabulary size: {}'.format(v_size)) 56 | 57 | # Sort elements in vocabulary 58 | idx_sort = np.argsort(sum_counts_np) 59 | vocab_aux = [id2word[idx_sort[cc]] for cc in range(v_size)] 60 | 61 | # Filter out stopwords (if any) 62 | vocab_aux = [w for w in vocab_aux if w not in stops] 63 | print(' vocabulary size after removing stopwords from list: {}'.format(len(vocab_aux))) 64 | 65 | # Create dictionary and inverse dictionary 66 | vocab = vocab_aux 67 | del vocab_aux 68 | word2id = dict([(w, j) for j, w in enumerate(vocab)]) 69 | id2word = dict([(j, w) for j, w in enumerate(vocab)]) 70 | 71 | # Split in train/test/valid 72 | print('tokenizing documents and splitting into train/test/valid...') 73 | num_docs_tr = len(init_docs_tr) 74 | trSize = num_docs_tr-100 75 | tsSize = len(init_docs_ts) 76 | vaSize = 100 77 | idx_permute = np.random.permutation(num_docs_tr).astype(int) 78 | 79 | # Remove words not in train_data 80 | vocab = list(set([w for idx_d in range(trSize) for w in init_docs[idx_permute[idx_d]].split() if w in word2id])) 81 | word2id = dict([(w, j) for j, w in enumerate(vocab)]) 82 | id2word = dict([(j, w) for j, w in enumerate(vocab)]) 83 | print(' vocabulary after removing words not in train: {}'.format(len(vocab))) 84 | 85 | # Split in train/test/valid 86 | docs_tr = [[word2id[w] for w in init_docs[idx_permute[idx_d]].split() if w in word2id] for idx_d in range(trSize)] 87 | docs_va = [[word2id[w] for w in init_docs[idx_permute[idx_d+trSize]].split() if w in word2id] for idx_d in range(vaSize)] 88 | docs_ts = [[word2id[w] for w in init_docs[idx_d+num_docs_tr].split() if w in word2id] for idx_d in range(tsSize)] 89 | 90 | print(' number of documents (train): {} [this should be equal to {}]'.format(len(docs_tr), trSize)) 91 | print(' number of documents (test): {} [this should be equal to {}]'.format(len(docs_ts), tsSize)) 92 | print(' number of documents (valid): {} [this should be equal to {}]'.format(len(docs_va), vaSize)) 93 | 94 | # Remove empty documents 95 | print('removing empty documents...') 96 | 97 | def remove_empty(in_docs): 98 | return [doc for doc in in_docs if doc!=[]] 99 | 100 | docs_tr = remove_empty(docs_tr) 101 | docs_ts = remove_empty(docs_ts) 102 | docs_va = remove_empty(docs_va) 103 | 104 | # Remove test documents with length=1 105 | docs_ts = [doc for doc in docs_ts if len(doc)>1] 106 | 107 | # Split test set in 2 halves 108 | print('splitting test documents in 2 halves...') 109 | docs_ts_h1 = [[w for i,w in 
enumerate(doc) if i<=len(doc)/2.0-1] for doc in docs_ts] 110 | docs_ts_h2 = [[w for i,w in enumerate(doc) if i>len(doc)/2.0-1] for doc in docs_ts] 111 | 112 | # Getting lists of words and doc_indices 113 | print('creating lists of words...') 114 | 115 | def create_list_words(in_docs): 116 | return [x for y in in_docs for x in y] 117 | 118 | words_tr = create_list_words(docs_tr) 119 | words_ts = create_list_words(docs_ts) 120 | words_ts_h1 = create_list_words(docs_ts_h1) 121 | words_ts_h2 = create_list_words(docs_ts_h2) 122 | words_va = create_list_words(docs_va) 123 | 124 | print(' len(words_tr): ', len(words_tr)) 125 | print(' len(words_ts): ', len(words_ts)) 126 | print(' len(words_ts_h1): ', len(words_ts_h1)) 127 | print(' len(words_ts_h2): ', len(words_ts_h2)) 128 | print(' len(words_va): ', len(words_va)) 129 | 130 | # Get doc indices 131 | print('getting doc indices...') 132 | 133 | def create_doc_indices(in_docs): 134 | aux = [[j for i in range(len(doc))] for j, doc in enumerate(in_docs)] 135 | return [int(x) for y in aux for x in y] 136 | 137 | doc_indices_tr = create_doc_indices(docs_tr) 138 | doc_indices_ts = create_doc_indices(docs_ts) 139 | doc_indices_ts_h1 = create_doc_indices(docs_ts_h1) 140 | doc_indices_ts_h2 = create_doc_indices(docs_ts_h2) 141 | doc_indices_va = create_doc_indices(docs_va) 142 | 143 | print(' len(np.unique(doc_indices_tr)): {} [this should be {}]'.format(len(np.unique(doc_indices_tr)), len(docs_tr))) 144 | print(' len(np.unique(doc_indices_ts)): {} [this should be {}]'.format(len(np.unique(doc_indices_ts)), len(docs_ts))) 145 | print(' len(np.unique(doc_indices_ts_h1)): {} [this should be {}]'.format(len(np.unique(doc_indices_ts_h1)), len(docs_ts_h1))) 146 | print(' len(np.unique(doc_indices_ts_h2)): {} [this should be {}]'.format(len(np.unique(doc_indices_ts_h2)), len(docs_ts_h2))) 147 | print(' len(np.unique(doc_indices_va)): {} [this should be {}]'.format(len(np.unique(doc_indices_va)), len(docs_va))) 148 | 149 | # Number of documents in each set 150 | n_docs_tr = len(docs_tr) 151 | n_docs_ts = len(docs_ts) 152 | n_docs_ts_h1 = len(docs_ts_h1) 153 | n_docs_ts_h2 = len(docs_ts_h2) 154 | n_docs_va = len(docs_va) 155 | 156 | # Remove unused variables 157 | del docs_tr 158 | del docs_ts 159 | del docs_ts_h1 160 | del docs_ts_h2 161 | del docs_va 162 | 163 | # Create bow representation 164 | print('creating bow representation...') 165 | 166 | def create_bow(doc_indices, words, n_docs, vocab_size): 167 | return sparse.coo_matrix(([1]*len(doc_indices),(doc_indices, words)), shape=(n_docs, vocab_size)).tocsr() 168 | 169 | bow_tr = create_bow(doc_indices_tr, words_tr, n_docs_tr, len(vocab)) 170 | bow_ts = create_bow(doc_indices_ts, words_ts, n_docs_ts, len(vocab)) 171 | bow_ts_h1 = create_bow(doc_indices_ts_h1, words_ts_h1, n_docs_ts_h1, len(vocab)) 172 | bow_ts_h2 = create_bow(doc_indices_ts_h2, words_ts_h2, n_docs_ts_h2, len(vocab)) 173 | bow_va = create_bow(doc_indices_va, words_va, n_docs_va, len(vocab)) 174 | 175 | del words_tr 176 | del words_ts 177 | del words_ts_h1 178 | del words_ts_h2 179 | del words_va 180 | del doc_indices_tr 181 | del doc_indices_ts 182 | del doc_indices_ts_h1 183 | del doc_indices_ts_h2 184 | del doc_indices_va 185 | 186 | # Write the vocabulary to a file 187 | path_save = './min_df_' + str(min_df) + '/' 188 | if not os.path.isdir(path_save): 189 | os.system('mkdir -p ' + path_save) 190 | 191 | with open(path_save + 'vocab.pkl', 'wb') as f: 192 | pickle.dump(vocab, f) 193 | del vocab 194 | 195 | # Split bow intro token/value 
pairs 196 | print('splitting bow intro token/value pairs and saving to disk...') 197 | 198 | def split_bow(bow_in, n_docs): 199 | indices = [[w for w in bow_in[doc,:].indices] for doc in range(n_docs)] 200 | counts = [[c for c in bow_in[doc,:].data] for doc in range(n_docs)] 201 | return indices, counts 202 | 203 | bow_tr_tokens, bow_tr_counts = split_bow(bow_tr, n_docs_tr) 204 | savemat(path_save + 'bow_tr_tokens', {'tokens': bow_tr_tokens}, do_compression=True) 205 | savemat(path_save + 'bow_tr_counts', {'counts': bow_tr_counts}, do_compression=True) 206 | del bow_tr 207 | del bow_tr_tokens 208 | del bow_tr_counts 209 | 210 | bow_ts_tokens, bow_ts_counts = split_bow(bow_ts, n_docs_ts) 211 | savemat(path_save + 'bow_ts_tokens', {'tokens': bow_ts_tokens}, do_compression=True) 212 | savemat(path_save + 'bow_ts_counts', {'counts': bow_ts_counts}, do_compression=True) 213 | del bow_ts 214 | del bow_ts_tokens 215 | del bow_ts_counts 216 | 217 | bow_ts_h1_tokens, bow_ts_h1_counts = split_bow(bow_ts_h1, n_docs_ts_h1) 218 | savemat(path_save + 'bow_ts_h1_tokens', {'tokens': bow_ts_h1_tokens}, do_compression=True) 219 | savemat(path_save + 'bow_ts_h1_counts', {'counts': bow_ts_h1_counts}, do_compression=True) 220 | del bow_ts_h1 221 | del bow_ts_h1_tokens 222 | del bow_ts_h1_counts 223 | 224 | bow_ts_h2_tokens, bow_ts_h2_counts = split_bow(bow_ts_h2, n_docs_ts_h2) 225 | savemat(path_save + 'bow_ts_h2_tokens', {'tokens': bow_ts_h2_tokens}, do_compression=True) 226 | savemat(path_save + 'bow_ts_h2_counts', {'counts': bow_ts_h2_counts}, do_compression=True) 227 | del bow_ts_h2 228 | del bow_ts_h2_tokens 229 | del bow_ts_h2_counts 230 | 231 | bow_va_tokens, bow_va_counts = split_bow(bow_va, n_docs_va) 232 | savemat(path_save + 'bow_va_tokens', {'tokens': bow_va_tokens}, do_compression=True) 233 | savemat(path_save + 'bow_va_counts', {'counts': bow_va_counts}, do_compression=True) 234 | del bow_va 235 | del bow_va_tokens 236 | del bow_va_counts 237 | 238 | print('Data ready !!') 239 | print('*************') 240 | 241 | -------------------------------------------------------------------------------- /inst/tinytest/test_end_to_end.R.R: -------------------------------------------------------------------------------- 1 | ## 2 | ## This test uses the ng20 dataset and compares the output obtained by running an ETM model fit using R 3 | ## to the same ETM model fit using the original Python implementation (https://github.com/bnosac-dev/ETM) 4 | ## - on the same data 5 | ## - using the same seed '2019' and without shuffling of the training data 6 | ## - with the same hyperparameters 7 | ## - using libtorch 1.9.0 on CPU 8 | ## 9 | ## Note that this uses no pretrained embeddings and used the following model parameters 10 | ## args.seed = 2019 11 | ## args.data_path = "dev/ETM/data/20ng" 12 | ## args.emb_size = 3 13 | ## args.train_embeddings = True 14 | ## args.epochs = 2 >>> epoch = 2 15 | ## args.t_hidden_size = 5 >>> dim = 5 16 | ## args.rho_size = 3 >>> embeddings = 3 17 | ## args.num_topics = 4 >>> k = 4 18 | ## args.optimizer == 'adam' >>> optim_adam 19 | ## args.lr=0.005 >>> lr = 0.005 20 | ## args.wdeay=1.2e-06 >>> weight_decay = 0.0000012 21 | ## args.theta_act='relu' >>> activation='relu' 22 | ## args.enc_drop=0 >>> dropout=0 23 | 24 | ## rho: word embeddings, alpha: topic emittance 25 | 26 | if (torch::torch_is_installed()) { 27 | library(torch) 28 | library(topicmodels.etm) 29 | path_tinytest_data <- system.file(package = "topicmodels.etm", "tinytest", "data") 30 | #path_tinytest_data <- 
"inst/tinytest/data" 31 | 32 | data(ng20, package = "topicmodels.etm") 33 | vocab <- ng20$vocab 34 | tokens <- ng20$bow_tr$tokens 35 | counts <- ng20$bow_tr$counts 36 | 37 | set.seed(2019) 38 | torch_manual_seed(2019) 39 | model <- ETM(k = 4, vocab = vocab, dim = 5, embeddings = 3, activation = 'relu', dropout = 0) 40 | 41 | ######################################################################################################## 42 | ## check initialisation/randomisation is the same (works since torch R package version 0.5) 43 | ## 44 | ## in R: q_theta.0.weight == in Python 0.weight 45 | ## in R: q_theta.0.bias == in Python 0.bias 46 | ## in R: q_theta.2.weight == in Python 2.weight 47 | ## in R: q_theta.2.bias == in Python 2.bias 48 | ## in R: q_theta.2.bias == in Python 2.bias 49 | #sapply(model$parameters, FUN = function(x) x$numel()) 50 | #model <- self 51 | params_r <- model$named_parameters() 52 | params_r$beta <- as.matrix(model$get_beta()) 53 | params_r$rho <- as.matrix(model$parameters$rho.weight) 54 | params_r$alphas <- as.matrix(model$parameters$alphas.weight) 55 | params_r$mu_q_theta.weight <- as.matrix(model$parameters$mu_q_theta.weight) 56 | params_r$mu_q_theta.bias <- as.numeric(model$parameters$mu_q_theta.bias) 57 | params_r$logsigma_q_theta.weight <- as.matrix(model$parameters$logsigma_q_theta.weight) 58 | params_r$logsigma_q_theta.bias <- as.numeric(model$parameters$logsigma_q_theta.bias) 59 | params_python <- list() 60 | params_python[["0.weight"]] <- as.numeric(readLines(file.path(path_tinytest_data, "init-0.weight.txt"), warn = FALSE)) 61 | params_python[["0.bias"]] <- as.numeric(readLines(file.path(path_tinytest_data, "init-0.bias.txt"), warn = FALSE)) 62 | params_python[["2.weight"]] <- as.numeric(readLines(file.path(path_tinytest_data, "init-2.weight.txt"), warn = FALSE)) 63 | params_python[["2.bias"]] <- as.numeric(readLines(file.path(path_tinytest_data, "init-2.bias.txt"), warn = FALSE)) 64 | params_python[["beta"]] <- as.numeric(readLines(file.path(path_tinytest_data, "init-beta.txt"), warn = FALSE)) 65 | params_python[["rho"]] <- as.numeric(readLines(file.path(path_tinytest_data, "init-rho.txt"), warn = FALSE)) 66 | params_python[["alphas"]] <- as.numeric(readLines(file.path(path_tinytest_data, "init-alphas.txt"), warn = FALSE)) 67 | params_python[["mu_q_theta.weight"]] <- as.numeric(readLines(file.path(path_tinytest_data, "init-mu_q_theta.weight.txt"), warn = FALSE)) 68 | params_python[["mu_q_theta.bias"]] <- as.numeric(readLines(file.path(path_tinytest_data, "init-mu_q_theta.bias.txt"), warn = FALSE)) 69 | params_python[["logsigma_q_theta.weight"]] <- as.numeric(readLines(file.path(path_tinytest_data, "init-logsigma_q_theta.weight.txt"), warn = FALSE)) 70 | params_python[["logsigma_q_theta.bias"]] <- as.numeric(readLines(file.path(path_tinytest_data, "init-logsigma_q_theta.bias.txt"), warn = FALSE)) 71 | params_python[["2.weight"]] <- matrix(params_python[["2.weight"]], nrow = nrow(params_r$q_theta.2.weight), byrow = TRUE) 72 | params_python[["0.weight"]] <- matrix(params_python[["0.weight"]], nrow = nrow(params_r$q_theta.0.weight), byrow = TRUE) 73 | params_python[["beta"]] <- matrix(params_python[["beta"]], nrow = nrow(params_r$beta), byrow = TRUE) 74 | params_python[["rho"]] <- matrix(params_python[["rho"]], nrow = nrow(params_r$rho), byrow = TRUE) 75 | params_python[["alphas"]] <- matrix(params_python[["alphas"]], nrow = nrow(params_r$alphas), byrow = TRUE) 76 | params_python[["mu_q_theta.weight"]] <- matrix(params_python[["mu_q_theta.weight"]], nrow = 
nrow(params_r$mu_q_theta.weight), byrow = TRUE) 77 | params_python[["logsigma_q_theta.weight"]] <- matrix(params_python[["logsigma_q_theta.weight"]], nrow = nrow(params_r$logsigma_q_theta.weight), byrow = TRUE) 78 | 79 | expect_equal(as.numeric(params_r$q_theta.0.bias), params_python[["0.bias"]]) 80 | expect_equal(as.numeric(params_r$q_theta.2.bias), params_python[["2.bias"]]) 81 | expect_equal(as.matrix(params_r$q_theta.2.weight), params_python[["2.weight"]]) 82 | expect_equal(as.matrix(params_r$q_theta.0.weight), params_python[["0.weight"]]) 83 | expect_equal(as.matrix(params_r$beta), params_python[["beta"]]) 84 | expect_equal(as.matrix(params_r$rho), params_python[["rho"]]) 85 | expect_equal(as.matrix(params_r$alphas), params_python[["alphas"]]) 86 | expect_equal(as.numeric(params_r$mu_q_theta.bias), params_python[["mu_q_theta.bias"]]) 87 | expect_equal(as.numeric(params_r$logsigma_q_theta.bias), params_python[["logsigma_q_theta.bias"]]) 88 | expect_equal(as.matrix(params_r$mu_q_theta.weight), params_python[["mu_q_theta.weight"]]) 89 | expect_equal(as.matrix(params_r$logsigma_q_theta.weight), params_python[["logsigma_q_theta.weight"]]) 90 | ## 91 | ## train the model - note that for this we have made sure permuting is not done on the training data in the python script 92 | ## 93 | #optimizer <- optim_sgd(params = model$parameters) 94 | optimizer <- optim_adam(params = model$parameters, lr = 0.005, weight_decay = 0.0000012) 95 | 96 | traindata <- list(tokens = tokens, counts = counts, vocab = vocab) 97 | test1 <- list(tokens = ng20$bow_ts_h1$tokens, counts = ng20$bow_ts_h1$counts, vocab = vocab) 98 | test2 <- list(tokens = ng20$bow_ts_h2$tokens, counts = ng20$bow_ts_h2$counts, vocab = vocab) 99 | 100 | set.seed(2019) 101 | torch_manual_seed(2019) 102 | #debugonce(model$train_epoch) 103 | out <- model$fit_original(data = traindata, test1 = test1, test2 = test2, epoch = 2, 104 | optimizer = optimizer, batch_size = 1000, 105 | lr_anneal_factor = 4, lr_anneal_nonmono = 10, 106 | permute = FALSE) ## do not permute training data in order to get same run as in python 107 | 108 | if(FALSE){ 109 | ## 110 | ## We did various checks on correct implementation of the different subfunctions of the estimation procedure giving exact match with the python implementation 111 | ## as well as the progress of the loss, kl_theta, nelbo which are all the same 112 | ## Unfortunately this doesn's make a test 100% bullet proof on libtorch apparently 113 | ## we noticed all parameters are the same when comparing R / Python implementation even including the backward step but when doing optimizer.step, the results diverge 114 | ## namely at https://github.com/bnosac/ETM/blob/master/R/ETM.R#L323 115 | ## probably this is something we can't make an end-to-end unit test for and we should make an issue at the torch repository 116 | ## 117 | #v <- model$get_beta() 118 | #v <- as.matrix(v) 119 | #pv <- as.numeric(readLines("dev/ETM/test.txt", warn = FALSE)) 120 | #pv <- matrix(pv, nrow = nrow(v), byrow = TRUE) 121 | #all.equal(v, pv) 122 | 123 | ######################################################################################################## 124 | ## check params after model run is the same (works since torch R package version 0.5) 125 | ## 126 | ## in R: q_theta.0.weight == in Python 0.weight 127 | ## in R: q_theta.0.bias == in Python 0.bias 128 | ## in R: q_theta.2.weight == in Python 2.weight 129 | ## in R: q_theta.2.bias == in Python 2.bias 130 | params_r <- model$named_parameters() 131 | params_r$beta <- 
as.matrix(model$get_beta()) 132 | params_r$rho <- as.matrix(model$parameters$rho.weight) 133 | params_r$alphas <- as.matrix(model$parameters$alphas.weight) 134 | params_python <- list() 135 | params_python[["0.weight"]] <- as.numeric(readLines(file.path(path_tinytest_data, "end-0.weight.txt"), warn = FALSE)) 136 | params_python[["0.bias"]] <- as.numeric(readLines(file.path(path_tinytest_data, "end-0.bias.txt"), warn = FALSE)) 137 | params_python[["2.weight"]] <- as.numeric(readLines(file.path(path_tinytest_data, "end-2.weight.txt"), warn = FALSE)) 138 | params_python[["2.bias"]] <- as.numeric(readLines(file.path(path_tinytest_data, "end-2.bias.txt"), warn = FALSE)) 139 | params_python[["beta"]] <- as.numeric(readLines(file.path(path_tinytest_data, "end-beta.txt"), warn = FALSE)) 140 | params_python[["rho"]] <- as.numeric(readLines(file.path(path_tinytest_data, "end-rho.txt"), warn = FALSE)) 141 | params_python[["alphas"]] <- as.numeric(readLines(file.path(path_tinytest_data, "end-alphas.txt"), warn = FALSE)) 142 | params_python[["2.weight"]] <- matrix(params_python[["2.weight"]], nrow = nrow(params_r$q_theta.2.weight), byrow = TRUE) 143 | params_python[["0.weight"]] <- matrix(params_python[["0.weight"]], nrow = nrow(params_r$q_theta.0.weight), byrow = TRUE) 144 | params_python[["beta"]] <- matrix(params_python[["beta"]], nrow = nrow(params_r$beta), byrow = TRUE) 145 | params_python[["rho"]] <- matrix(params_python[["rho"]], nrow = nrow(params_r$rho), byrow = TRUE) 146 | params_python[["alphas"]] <- matrix(params_python[["alphas"]], nrow = nrow(params_r$alphas), byrow = TRUE) 147 | 148 | expect_equal(as.numeric(params_r$q_theta.0.bias), params_python[["0.bias"]]) 149 | expect_equal(as.numeric(params_r$q_theta.2.bias), params_python[["2.bias"]]) 150 | expect_equal(as.matrix(params_r$q_theta.2.weight), params_python[["2.weight"]]) 151 | expect_equal(as.matrix(params_r$q_theta.0.weight), params_python[["0.weight"]]) 152 | expect_equal(as.matrix(params_r$beta), params_python[["beta"]]) 153 | expect_equal(as.matrix(params_r$beta), params_python[["beta"]], tolerance = 0.00001) 154 | expect_equal(as.matrix(params_r$rho), params_python[["rho"]]) 155 | expect_equal(as.matrix(params_r$alphas), params_python[["alphas"]]) 156 | 157 | # test <- subset(out$loss, out$loss$batch_is_last == TRUE) 158 | # plot(test$epoch, test$loss) 159 | # 160 | # topic.centers <- as.matrix(model, type = "embedding", which = "topics") 161 | # word.embeddings <- as.matrix(model, type = "embedding", which = "words") 162 | # topic.terminology <- as.matrix(model, type = "beta") 163 | # 164 | # terminology <- predict(model, type = "terms", top_n = 4) 165 | # terminology 166 | } 167 | } 168 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ETM - R package for Topic Modelling in Embedding Spaces 2 | 3 | This repository contains an R package called `topicmodels.etm` which is an implementation of ETM 4 | 5 | - ETM is a generative topic model combining traditional topic models (LDA) with word embeddings (word2vec) 6 | - It models each word with a categorical distribution whose natural parameter is the inner product between a word embedding and an embedding of its assigned topic 7 | - The model is fitted using an amortized variational inference algorithm on top of libtorch (https://torch.mlverse.org) 8 | - The techniques are explained in detail in the paper: "Topic Modelling in Embedding Spaces" by Adji 
B. Dieng, Francisco J. R. Ruiz and David M. Blei, available at https://arxiv.org/pdf/1907.04907.pdf 9 | 10 | ![](tools/example-visualisation.png) 11 | 12 | ### Installation 13 | 14 | - To install the package from CRAN: 15 | 16 | ``` 17 | pkgs <- c("torch", "topicmodels.etm", "word2vec", "doc2vec", "udpipe", "uwot") 18 | install.packages(pkgs) 19 | library(torch) 20 | library(topicmodels.etm) 21 | ``` 22 | 23 | - To install the development version of this package, run the following in R: 24 | 25 | ``` 26 | install.packages("torch") 27 | install.packages("word2vec") 28 | install.packages("doc2vec") 29 | install.packages("udpipe") 30 | install.packages("remotes") 31 | library(torch) 32 | remotes::install_github('bnosac/ETM', INSTALL_opts = '--no-multiarch') 33 | ``` 34 | 35 | - To be able to plot the models, additionally install: 36 | 37 | ``` 38 | install.packages("textplot") 39 | install.packages("ggrepel") 40 | install.packages("ggalt") 41 | ``` 42 | 43 | ### Example 44 | 45 | Build a topic model on questions asked in the Belgian parliament in 2020, in Dutch. 46 | 47 | #### a. Get data 48 | 49 | - Example text of +/- 6000 questions asked in the Belgian parliament (available in R package doc2vec). 50 | - Standardise the text a bit 51 | 52 | ``` 53 | library(torch) 54 | library(topicmodels.etm) 55 | library(doc2vec) 56 | library(word2vec) 57 | data(be_parliament_2020, package = "doc2vec") 58 | x <- data.frame(doc_id = be_parliament_2020$doc_id, 59 | text = be_parliament_2020$text_nl, 60 | stringsAsFactors = FALSE) 61 | x$text <- txt_clean_word2vec(x$text) 62 | ``` 63 | 64 | #### b. Build a word2vec model to get word embeddings and inspect it a bit 65 | 66 | ``` 67 | w2v <- word2vec(x = x$text, dim = 25, type = "skip-gram", iter = 10, min_count = 5, threads = 2) 68 | embeddings <- as.matrix(w2v) 69 | predict(w2v, newdata = c("migranten", "belastingen"), type = "nearest", top_n = 4) 70 | $migranten 71 | term1 term2 similarity rank 72 | 1 migranten lesbos 0.9434163 1 73 | 2 migranten chios 0.9334459 2 74 | 3 migranten vluchtelingenkampen 0.9269973 3 75 | 4 migranten kamp 0.9175452 4 76 | 77 | $belastingen 78 | term1 term2 similarity rank 79 | 1 belastingen belasting 0.9458982 1 80 | 2 belastingen ontvangsten 0.9091899 2 81 | 3 belastingen geheven 0.9071115 3 82 | 4 belastingen ontduiken 0.9029559 4 83 | ``` 84 |
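As a quick sanity check before building the topic model: `ETM()` expects `embeddings` to be a numeric matrix with one row per term and the terms as rownames (its `vocab` argument defaults to `rownames(embeddings)`). A minimal sketch to verify this on the word2vec embeddings created above:

```
dim(embeddings)            ## number of terms x 25, the word2vec dimension used above
head(rownames(embeddings)) ## terms which will be matched against the columns of the dtm
stopifnot(is.matrix(embeddings), !is.null(rownames(embeddings)))
```

85 | #### c.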
Build the embedding topic model 86 | 87 | - Create a bag of words document term matrix (using the udpipe package but other R packages provide similar functionalities) 88 | - Keep only the top 50% terms with the highest TFIDF 89 | - Make sure document/term/matrix and the embedding matrix have the same vocabulary 90 | 91 | ``` 92 | library(udpipe) 93 | dtm <- strsplit.data.frame(x, group = "doc_id", term = "text", split = " ") 94 | dtm <- document_term_frequencies(dtm) 95 | dtm <- document_term_matrix(dtm) 96 | dtm <- dtm_remove_tfidf(dtm, prob = 0.50) 97 | 98 | vocab <- intersect(rownames(embeddings), colnames(dtm)) 99 | embeddings <- dtm_conform(embeddings, rows = vocab) 100 | dtm <- dtm_conform(dtm, columns = vocab) 101 | dim(dtm) 102 | dim(embeddings) 103 | ``` 104 | 105 | - Learn 20 topics with a 100-dimensional hyperparameter for the variational inference 106 | 107 | ``` 108 | set.seed(1234) 109 | torch_manual_seed(4321) 110 | model <- ETM(k = 20, dim = 100, embeddings = embeddings) 111 | optimizer <- optim_adam(params = model$parameters, lr = 0.005, weight_decay = 0.0000012) 112 | loss <- model$fit(data = dtm, optimizer = optimizer, epoch = 20, batch_size = 1000) 113 | plot(model, type = "loss") 114 | ``` 115 | 116 | ![](tools/loss-evolution.png) 117 | 118 | 119 | #### d. Inspect the model 120 | 121 | ``` 122 | terminology <- predict(model, type = "terms", top_n = 5) 123 | terminology 124 | [[1]] 125 | term beta 126 | 3891 zelfstandigen 0.05245856 127 | 2543 opdeling 0.02827548 128 | 5469 werkloosheid 0.02366866 129 | 3611 ocmw 0.01772762 130 | 4957 zelfstandige 0.01139760 131 | 132 | [[2]] 133 | term beta 134 | 3891 zelfstandigen 0.032309771 135 | 5469 werkloosheid 0.021119611 136 | 4957 zelfstandige 0.010217560 137 | 3611 ocmw 0.009712025 138 | 2543 opdeling 0.008961252 139 | 140 | [[3]] 141 | term beta 142 | 2537 gedetineerden 0.02914266 143 | 3827 nationaliteit 0.02540042 144 | 3079 gevangenis 0.02136421 145 | 5311 gevangenissen 0.01215335 146 | 3515 asielzoekers 0.01204639 147 | 148 | [[4]] 149 | term beta 150 | 3435 btw 0.02814350 151 | 5536 kostprijs 0.02012880 152 | 3508 pod 0.01218093 153 | 2762 vzw 0.01088356 154 | 2996 vennootschap 0.01015108 155 | 156 | [[5]] 157 | term beta 158 | 3372 verbaal 0.011172118 159 | 3264 politiezone 0.008422602 160 | 3546 arrondissement 0.007855867 161 | 3052 inbreuken 0.007204257 162 | 2543 opdeling 0.007149355 163 | 164 | [[6]] 165 | term beta 166 | 3296 instelling 0.04442037 167 | 3540 wetenschappelijke 0.03434755 168 | 2652 china 0.02702594 169 | 3043 volksrepubliek 0.01844959 170 | 3893 hongkong 0.01792639 171 | 172 | [[7]] 173 | term beta 174 | 2133 databank 0.003111386 175 | 3079 gevangenis 0.002650804 176 | 3255 dvz 0.002098217 177 | 3614 centra 0.001884672 178 | 2142 geneesmiddelen 0.001791468 179 | 180 | [[8]] 181 | term beta 182 | 2547 defensie 0.03706463 183 | 3785 kabinet 0.01323747 184 | 4054 griekse 0.01317877 185 | 3750 turkse 0.01238277 186 | 3076 leger 0.00964661 187 | 188 | [[9]] 189 | term beta 190 | 3649 nmbs 0.005472604 191 | 3704 beslag 0.004442090 192 | 2457 nucleaire 0.003911803 193 | 2461 mondmaskers 0.003712016 194 | 3533 materiaal 0.003513884 195 | 196 | [[10]] 197 | term beta 198 | 4586 politiezones 0.017413139 199 | 2248 voertuigen 0.012508971 200 | 3649 nmbs 0.008157282 201 | 2769 politieagenten 0.007591151 202 | 3863 beelden 0.006747020 203 | 204 | [[11]] 205 | term beta 206 | 3827 nationaliteit 0.009992087 207 | 4912 duitse 0.008966853 208 | 3484 turkije 0.008940011 209 | 2652 china 0.008723009 210 | 4008 
overeenkomst 0.007879931 211 | 212 | [[12]] 213 | term beta 214 | 3651 opsplitsen 0.008752496 215 | 4247 kinderen 0.006497230 216 | 2606 sciensano 0.006430181 217 | 3170 tests 0.006420473 218 | 3587 studenten 0.006165542 219 | 220 | [[13]] 221 | term beta 222 | 3052 inbreuken 0.007657704 223 | 2447 drugs 0.006734609 224 | 2195 meldingen 0.005259825 225 | 3372 verbaal 0.005117311 226 | 3625 cyberaanvallen 0.004269334 227 | 228 | [[14]] 229 | term beta 230 | 2234 gebouwen 0.06128503 231 | 3531 digitale 0.03030998 232 | 3895 bpost 0.02974019 233 | 4105 regie 0.02608073 234 | 3224 infrabel 0.01758554 235 | 236 | [[15]] 237 | term beta 238 | 3649 nmbs 0.08117295 239 | 3826 station 0.03944306 240 | 3911 trein 0.03548101 241 | 4965 treinen 0.02843846 242 | 3117 stations 0.02732874 243 | 244 | [[16]] 245 | term beta 246 | 3649 nmbs 0.06778506 247 | 3240 personeelsleden 0.03363639 248 | 2972 telewerk 0.01857295 249 | 4965 treinen 0.01807373 250 | 3785 kabinet 0.01702784 251 | 252 | [[17]] 253 | term beta 254 | 2371 app 0.009092372 255 | 3265 stoffen 0.006641808 256 | 2461 mondmaskers 0.006462210 257 | 3025 persoonsgegevens 0.005374488 258 | 2319 websites 0.005372964 259 | 260 | [[18]] 261 | term beta 262 | 5296 aangifte 0.01940070 263 | 3435 btw 0.01360575 264 | 2762 vzw 0.01307520 265 | 2756 facturen 0.01233578 266 | 2658 rekenhof 0.01196285 267 | 268 | [[19]] 269 | term beta 270 | 3631 beperking 0.017481016 271 | 3069 handicap 0.010403863 272 | 3905 tewerkstelling 0.009714387 273 | 3785 kabinet 0.006984415 274 | 2600 ombudsman 0.006074827 275 | 276 | [[20]] 277 | term beta 278 | 3228 geweld 0.05881281 279 | 4178 vrouwen 0.05113553 280 | 4247 kinderen 0.04818219 281 | 2814 jongeren 0.01803746 282 | 2195 meldingen 0.01548613 283 | ``` 284 | 285 | #### e. Predict alongside the model 286 | 287 | ``` 288 | newdata <- head(dtm, n = 5) 289 | scores <- predict(model, newdata, type = "topics") 290 | scores 291 | ``` 292 | 293 | #### f. Save / Load model 294 | 295 | ``` 296 | torch_save(model, "example_etm.ckpt") 297 | model <- torch_load("example_etm.ckpt") 298 | ``` 299 | 300 | #### g. Optionally - visualise the model in 2D 301 | 302 | Example plot shown above was created using the following code 303 | 304 | - This uses R package [textplot](https://github.com/bnosac/textplot) >= 0.2.0 which was updated on CRAN on 2021-08-18 305 | - The summary function maps the learned embeddings of the words and topic centers in 2D using [UMAP](https://github.com/jlmelville/uwot) and textplot_embedding_2d plots the selected topics of interest in 2D 306 | 307 | ``` 308 | library(textplot) 309 | library(uwot) 310 | library(ggrepel) 311 | library(ggalt) 312 | manifolded <- summary(model, type = "umap", n_components = 2, metric = "cosine", n_neighbors = 15, 313 | fast_sgd = FALSE, n_threads = 2, verbose = TRUE) 314 | space <- subset(manifolded$embed_2d, type %in% "centers") 315 | textplot_embedding_2d(space) 316 | space <- subset(manifolded$embed_2d, cluster %in% c(12, 14, 9, 7) & rank <= 7) 317 | textplot_embedding_2d(space, title = "ETM topics", subtitle = "embedded in 2D using UMAP", 318 | encircle = FALSE, points = TRUE) 319 | ``` 320 | 321 | ![](tools/example-visualisation-basic.png) 322 | 323 | #### z. 
Or you can brew up your own code to plot things 324 | 325 | - Put embeddings of words and topic centers in 2D using UMAP 326 | 327 | ``` 328 | library(uwot) 329 | centers <- as.matrix(model, type = "embedding", which = "topics") 330 | embeddings <- as.matrix(model, type = "embedding", which = "words") 331 | manifold <- umap(embeddings, 332 | n_components = 2, metric = "cosine", n_neighbors = 15, fast_sgd = TRUE, 333 | n_threads = 2, ret_model = TRUE, verbose = TRUE) 334 | centers <- umap_transform(X = centers, model = manifold) 335 | words <- manifold$embedding 336 | ``` 337 | 338 | - Plot words in 2D, color by topic and add topic centers in 2D 339 | - This uses R package textplot >= 0.2.0 (https://github.com/bnosac/textplot) which was put on CRAN on 2021-08-18 340 | 341 | ``` 342 | library(data.table) 343 | terminology <- predict(model, type = "terms", top_n = 7) 344 | terminology <- rbindlist(terminology, idcol = "cluster") 345 | df <- list(words = merge(x = terminology, 346 | y = data.frame(x = words[, 1], y = words[, 2], term = rownames(embeddings)), 347 | by = "term"), 348 | centers = data.frame(x = centers[, 1], y = centers[, 2], 349 | term = paste("Topic-", seq_len(nrow(centers)), sep = ""), 350 | cluster = seq_len(nrow(centers)))) 351 | df <- rbindlist(df, use.names = TRUE, fill = TRUE, idcol = "type") 352 | df <- df[, weight := ifelse(is.na(beta), 0.8, beta / max(beta, na.rm = TRUE)), by = list(cluster)] 353 | 354 | library(textplot) 355 | library(ggrepel) 356 | library(ggalt) 357 | x <- subset(df, type %in% c("words", "centers") & cluster %in% c(1, 3, 4, 8)) 358 | textplot_embedding_2d(x, title = "ETM topics", subtitle = "embedded in 2D using UMAP", encircle = FALSE, points = FALSE) 359 | textplot_embedding_2d(x, title = "ETM topics", subtitle = "embedded in 2D using UMAP", encircle = TRUE, points = TRUE) 360 | ``` 361 | 362 | - Or if you like writing down the full ggplot2 code 363 | 364 | ``` 365 | library(ggplot2) 366 | library(ggrepel) 367 | x$topic <- factor(x$cluster) 368 | plt <- ggplot(x, 369 | aes(x = x, y = y, label = term, color = topic, cex = weight, pch = factor(type, levels = c("centers", "words")))) + 370 | geom_text_repel(show.legend = FALSE) + 371 | theme_void() + 372 | labs(title = "ETM topics", subtitle = "embedded in 2D using UMAP") 373 | plt + geom_point(show.legend = FALSE) 374 | 375 | ## encircle if topics are non-overlapping can provide nice visualisations 376 | library(ggalt) 377 | plt + geom_encircle(aes(group = topic, fill = topic), alpha = 0.4, show.legend = FALSE) + geom_point(show.legend = FALSE) 378 | ``` 379 | 380 | > More examples are provided in the help of the ETM function see `?ETM` 381 | > Don't forget to set seeds to have reproducible behaviour 382 | 383 | ## Support in text mining 384 | 385 | Need support in text mining? 
386 | Contact BNOSAC: http://www.bnosac.be 387 | 388 | -------------------------------------------------------------------------------- /inst/orig/ETM/main.py: -------------------------------------------------------------------------------- 1 | #/usr/bin/python 2 | 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import torch 7 | import pickle 8 | import numpy as np 9 | import os 10 | import math 11 | import random 12 | import sys 13 | import matplotlib.pyplot as plt 14 | import data 15 | import scipy.io 16 | 17 | from torch import nn, optim 18 | from torch.nn import functional as F 19 | 20 | from etm import ETM 21 | from utils import nearest_neighbors, get_topic_coherence, get_topic_diversity 22 | 23 | parser = argparse.ArgumentParser(description='The Embedded Topic Model') 24 | 25 | ### data and file related arguments 26 | parser.add_argument('--dataset', type=str, default='20ng', help='name of corpus') 27 | parser.add_argument('--data_path', type=str, default='data/20ng', help='directory containing data') 28 | parser.add_argument('--emb_path', type=str, default='data/20ng_embeddings.txt', help='directory containing word embeddings') 29 | parser.add_argument('--save_path', type=str, default='./results', help='path to save results') 30 | parser.add_argument('--batch_size', type=int, default=1000, help='input batch size for training') 31 | 32 | ### model-related arguments 33 | parser.add_argument('--num_topics', type=int, default=50, help='number of topics') 34 | parser.add_argument('--rho_size', type=int, default=300, help='dimension of rho') 35 | parser.add_argument('--emb_size', type=int, default=300, help='dimension of embeddings') 36 | parser.add_argument('--t_hidden_size', type=int, default=800, help='dimension of hidden space of q(theta)') 37 | parser.add_argument('--theta_act', type=str, default='relu', help='tanh, softplus, relu, rrelu, leakyrelu, elu, selu, glu)') 38 | parser.add_argument('--train_embeddings', type=int, default=0, help='whether to fix rho or train it') 39 | 40 | ### optimization-related arguments 41 | parser.add_argument('--lr', type=float, default=0.005, help='learning rate') 42 | parser.add_argument('--lr_factor', type=float, default=4.0, help='divide learning rate by this...') 43 | parser.add_argument('--epochs', type=int, default=20, help='number of epochs to train...150 for 20ng 100 for others') 44 | parser.add_argument('--mode', type=str, default='train', help='train or eval model') 45 | parser.add_argument('--optimizer', type=str, default='adam', help='choice of optimizer') 46 | parser.add_argument('--seed', type=int, default=2019, help='random seed (default: 1)') 47 | parser.add_argument('--enc_drop', type=float, default=0.0, help='dropout rate on encoder') 48 | parser.add_argument('--clip', type=float, default=0.0, help='gradient clipping') 49 | parser.add_argument('--nonmono', type=int, default=10, help='number of bad hits allowed') 50 | parser.add_argument('--wdecay', type=float, default=1.2e-6, help='some l2 regularization') 51 | parser.add_argument('--anneal_lr', type=int, default=0, help='whether to anneal the learning rate or not') 52 | parser.add_argument('--bow_norm', type=int, default=1, help='normalize the bows or not') 53 | 54 | ### evaluation, visualization, and logging-related arguments 55 | parser.add_argument('--num_words', type=int, default=10, help='number of words for topic viz') 56 | parser.add_argument('--log_interval', type=int, default=2, help='when to log training') 57 | parser.add_argument('--visualize_every', 
type=int, default=10, help='when to visualize results') 58 | parser.add_argument('--eval_batch_size', type=int, default=1000, help='input batch size for evaluation') 59 | parser.add_argument('--load_from', type=str, default='', help='the name of the ckpt to eval from') 60 | parser.add_argument('--tc', type=int, default=0, help='whether to compute topic coherence or not') 61 | parser.add_argument('--td', type=int, default=0, help='whether to compute topic diversity or not') 62 | 63 | args = parser.parse_args() 64 | 65 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 66 | 67 | print('\n') 68 | np.random.seed(args.seed) 69 | torch.manual_seed(args.seed) 70 | if torch.cuda.is_available(): 71 | torch.cuda.manual_seed(args.seed) 72 | 73 | ## get data 74 | # 1. vocabulary 75 | vocab, train, valid, test = data.get_data(os.path.join(args.data_path)) 76 | vocab_size = len(vocab) 77 | args.vocab_size = vocab_size 78 | 79 | # 1. training data 80 | train_tokens = train['tokens'] 81 | train_counts = train['counts'] 82 | args.num_docs_train = len(train_tokens) 83 | 84 | # 2. dev set 85 | valid_tokens = valid['tokens'] 86 | valid_counts = valid['counts'] 87 | args.num_docs_valid = len(valid_tokens) 88 | 89 | # 3. test data 90 | test_tokens = test['tokens'] 91 | test_counts = test['counts'] 92 | args.num_docs_test = len(test_tokens) 93 | test_1_tokens = test['tokens_1'] 94 | test_1_counts = test['counts_1'] 95 | args.num_docs_test_1 = len(test_1_tokens) 96 | test_2_tokens = test['tokens_2'] 97 | test_2_counts = test['counts_2'] 98 | args.num_docs_test_2 = len(test_2_tokens) 99 | 100 | embeddings = None 101 | if not args.train_embeddings: 102 | emb_path = args.emb_path 103 | vect_path = os.path.join(args.data_path.split('/')[0], 'embeddings.pkl') 104 | vectors = {} 105 | with open(emb_path, 'rb') as f: 106 | for l in f: 107 | line = l.decode().split() 108 | word = line[0] 109 | if word in vocab: 110 | vect = np.array(line[1:]).astype(np.float) 111 | vectors[word] = vect 112 | embeddings = np.zeros((vocab_size, args.emb_size)) 113 | words_found = 0 114 | for i, word in enumerate(vocab): 115 | try: 116 | embeddings[i] = vectors[word] 117 | words_found += 1 118 | except KeyError: 119 | embeddings[i] = np.random.normal(scale=0.6, size=(args.emb_size, )) 120 | embeddings = torch.from_numpy(embeddings).to(device) 121 | args.embeddings_dim = embeddings.size() 122 | 123 | print('=*'*100) 124 | print('Training an Embedded Topic Model on {} with the following settings: {}'.format(args.dataset.upper(), args)) 125 | print('=*'*100) 126 | 127 | ## define checkpoint 128 | if not os.path.exists(args.save_path): 129 | os.makedirs(args.save_path) 130 | 131 | if args.mode == 'eval': 132 | ckpt = args.load_from 133 | else: 134 | ckpt = os.path.join(args.save_path, 135 | 'etm_{}_K_{}_Htheta_{}_Optim_{}_Clip_{}_ThetaAct_{}_Lr_{}_Bsz_{}_RhoSize_{}_trainEmbeddings_{}'.format( 136 | args.dataset, args.num_topics, args.t_hidden_size, args.optimizer, args.clip, args.theta_act, 137 | args.lr, args.batch_size, args.rho_size, args.train_embeddings)) 138 | 139 | ## define model and optimizer 140 | model = ETM(args.num_topics, vocab_size, args.t_hidden_size, args.rho_size, args.emb_size, 141 | args.theta_act, embeddings, args.train_embeddings, args.enc_drop).to(device) 142 | 143 | print('model: {}'.format(model)) 144 | 145 | if args.optimizer == 'adam': 146 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wdecay) 147 | elif args.optimizer == 'adagrad': 148 | optimizer = 
optim.Adagrad(model.parameters(), lr=args.lr, weight_decay=args.wdecay) 149 | elif args.optimizer == 'adadelta': 150 | optimizer = optim.Adadelta(model.parameters(), lr=args.lr, weight_decay=args.wdecay) 151 | elif args.optimizer == 'rmsprop': 152 | optimizer = optim.RMSprop(model.parameters(), lr=args.lr, weight_decay=args.wdecay) 153 | elif args.optimizer == 'asgd': 154 | optimizer = optim.ASGD(model.parameters(), lr=args.lr, t0=0, lambd=0., weight_decay=args.wdecay) 155 | else: 156 | print('Defaulting to vanilla SGD') 157 | optimizer = optim.SGD(model.parameters(), lr=args.lr) 158 | 159 | def train(epoch): 160 | model.train() 161 | acc_loss = 0 162 | acc_kl_theta_loss = 0 163 | cnt = 0 164 | indices = torch.randperm(args.num_docs_train) 165 | indices = torch.split(indices, args.batch_size) 166 | for idx, ind in enumerate(indices): 167 | optimizer.zero_grad() 168 | model.zero_grad() 169 | data_batch = data.get_batch(train_tokens, train_counts, ind, args.vocab_size, device) 170 | sums = data_batch.sum(1).unsqueeze(1) 171 | if args.bow_norm: 172 | normalized_data_batch = data_batch / sums 173 | else: 174 | normalized_data_batch = data_batch 175 | recon_loss, kld_theta = model(data_batch, normalized_data_batch) 176 | total_loss = recon_loss + kld_theta 177 | total_loss.backward() 178 | 179 | if args.clip > 0: 180 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) 181 | optimizer.step() 182 | 183 | acc_loss += torch.sum(recon_loss).item() 184 | acc_kl_theta_loss += torch.sum(kld_theta).item() 185 | cnt += 1 186 | 187 | if idx % args.log_interval == 0 and idx > 0: 188 | cur_loss = round(acc_loss / cnt, 2) 189 | cur_kl_theta = round(acc_kl_theta_loss / cnt, 2) 190 | cur_real_loss = round(cur_loss + cur_kl_theta, 2) 191 | 192 | print('Epoch: {} .. batch: {}/{} .. LR: {} .. KL_theta: {} .. Rec_loss: {} .. NELBO: {}'.format( 193 | epoch, idx, len(indices), optimizer.param_groups[0]['lr'], cur_kl_theta, cur_loss, cur_real_loss)) 194 | 195 | cur_loss = round(acc_loss / cnt, 2) 196 | cur_kl_theta = round(acc_kl_theta_loss / cnt, 2) 197 | cur_real_loss = round(cur_loss + cur_kl_theta, 2) 198 | print('*'*100) 199 | print('Epoch----->{} .. LR: {} .. KL_theta: {} .. Rec_loss: {} .. NELBO: {}'.format( 200 | epoch, optimizer.param_groups[0]['lr'], cur_kl_theta, cur_loss, cur_real_loss)) 201 | print('*'*100) 202 | 203 | def visualize(m, show_emb=True): 204 | if not os.path.exists('./results'): 205 | os.makedirs('./results') 206 | 207 | m.eval() 208 | 209 | queries = ['andrew', 'computer', 'sports', 'religion', 'man', 'love', 210 | 'intelligence', 'money', 'politics', 'health', 'people', 'family'] 211 | 212 | ## visualize topics using monte carlo 213 | with torch.no_grad(): 214 | print('#'*100) 215 | print('Visualize topics...') 216 | topics_words = [] 217 | gammas = m.get_beta() 218 | for k in range(args.num_topics): 219 | gamma = gammas[k] 220 | top_words = list(gamma.cpu().numpy().argsort()[-args.num_words+1:][::-1]) 221 | topic_words = [vocab[a] for a in top_words] 222 | topics_words.append(' '.join(topic_words)) 223 | print('Topic {}: {}'.format(k, topic_words)) 224 | 225 | if show_emb: 226 | ## visualize word embeddings by using V to get nearest neighbors 227 | print('#'*100) 228 | print('Visualize word embeddings by using output embedding matrix') 229 | try: 230 | embeddings = m.rho.weight # Vocab_size x E 231 | except: 232 | embeddings = m.rho # Vocab_size x E 233 | neighbors = [] 234 | for word in queries: 235 | print('word: {} .. 
neighbors: {}'.format( 236 | word, nearest_neighbors(word, embeddings, vocab))) 237 | print('#'*100) 238 | 239 | def evaluate(m, source, tc=False, td=False): 240 | """Compute perplexity on document completion. 241 | """ 242 | m.eval() 243 | with torch.no_grad(): 244 | if source == 'val': 245 | indices = torch.split(torch.tensor(range(args.num_docs_valid)), args.eval_batch_size) 246 | tokens = valid_tokens 247 | counts = valid_counts 248 | else: 249 | indices = torch.split(torch.tensor(range(args.num_docs_test)), args.eval_batch_size) 250 | tokens = test_tokens 251 | counts = test_counts 252 | 253 | ## get \beta here 254 | beta = m.get_beta() 255 | 256 | ### do dc and tc here 257 | acc_loss = 0 258 | cnt = 0 259 | indices_1 = torch.split(torch.tensor(range(args.num_docs_test_1)), args.eval_batch_size) 260 | for idx, ind in enumerate(indices_1): 261 | ## get theta from first half of docs 262 | data_batch_1 = data.get_batch(test_1_tokens, test_1_counts, ind, args.vocab_size, device) 263 | sums_1 = data_batch_1.sum(1).unsqueeze(1) 264 | if args.bow_norm: 265 | normalized_data_batch_1 = data_batch_1 / sums_1 266 | else: 267 | normalized_data_batch_1 = data_batch_1 268 | theta, _ = m.get_theta(normalized_data_batch_1) 269 | 270 | ## get prediction loss using second half 271 | data_batch_2 = data.get_batch(test_2_tokens, test_2_counts, ind, args.vocab_size, device) 272 | sums_2 = data_batch_2.sum(1).unsqueeze(1) 273 | res = torch.mm(theta, beta) 274 | preds = torch.log(res) 275 | recon_loss = -(preds * data_batch_2).sum(1) 276 | 277 | loss = recon_loss / sums_2.squeeze() 278 | loss = loss.mean().item() 279 | acc_loss += loss 280 | cnt += 1 281 | cur_loss = acc_loss / cnt 282 | ppl_dc = round(math.exp(cur_loss), 1) 283 | print('*'*100) 284 | print('{} Doc Completion PPL: {}'.format(source.upper(), ppl_dc)) 285 | print('*'*100) 286 | if tc or td: 287 | beta = beta.data.cpu().numpy() 288 | if tc: 289 | print('Computing topic coherence...') 290 | get_topic_coherence(beta, train_tokens, vocab) 291 | if td: 292 | print('Computing topic diversity...') 293 | get_topic_diversity(beta, 25) 294 | return ppl_dc 295 | 296 | if args.mode == 'train': 297 | ## train model on data 298 | best_epoch = 0 299 | best_val_ppl = 1e9 300 | all_val_ppls = [] 301 | print('\n') 302 | print('Visualizing model quality before training...') 303 | visualize(model) 304 | print('\n') 305 | for epoch in range(1, args.epochs): 306 | train(epoch) 307 | val_ppl = evaluate(model, 'val') 308 | if val_ppl < best_val_ppl: 309 | with open(ckpt, 'wb') as f: 310 | torch.save(model, f) 311 | best_epoch = epoch 312 | best_val_ppl = val_ppl 313 | else: 314 | ## check whether to anneal lr 315 | lr = optimizer.param_groups[0]['lr'] 316 | if args.anneal_lr and (len(all_val_ppls) > args.nonmono and val_ppl > min(all_val_ppls[:-args.nonmono]) and lr > 1e-5): 317 | optimizer.param_groups[0]['lr'] /= args.lr_factor 318 | if epoch % args.visualize_every == 0: 319 | visualize(model) 320 | all_val_ppls.append(val_ppl) 321 | with open(ckpt, 'rb') as f: 322 | model = torch.load(f) 323 | model = model.to(device) 324 | val_ppl = evaluate(model, 'val') 325 | else: 326 | with open(ckpt, 'rb') as f: 327 | model = torch.load(f) 328 | model = model.to(device) 329 | model.eval() 330 | 331 | with torch.no_grad(): 332 | ## get document completion perplexities 333 | test_ppl = evaluate(model, 'test', tc=args.tc, td=args.td) 334 | 335 | ## get most used topics 336 | indices = torch.tensor(range(args.num_docs_train)) 337 | indices = torch.split(indices, 
args.batch_size) 338 | thetaAvg = torch.zeros(1, args.num_topics).to(device) 339 | thetaWeightedAvg = torch.zeros(1, args.num_topics).to(device) 340 | cnt = 0 341 | for idx, ind in enumerate(indices): 342 | data_batch = data.get_batch(train_tokens, train_counts, ind, args.vocab_size, device) 343 | sums = data_batch.sum(1).unsqueeze(1) 344 | cnt += sums.sum(0).squeeze().cpu().numpy() 345 | if args.bow_norm: 346 | normalized_data_batch = data_batch / sums 347 | else: 348 | normalized_data_batch = data_batch 349 | theta, _ = model.get_theta(normalized_data_batch) 350 | thetaAvg += theta.sum(0).unsqueeze(0) / args.num_docs_train 351 | weighed_theta = sums * theta 352 | thetaWeightedAvg += weighed_theta.sum(0).unsqueeze(0) 353 | if idx % 100 == 0 and idx > 0: 354 | print('batch: {}/{}'.format(idx, len(indices))) 355 | thetaWeightedAvg = thetaWeightedAvg.squeeze().cpu().numpy() / cnt 356 | print('\nThe 10 most used topics are {}'.format(thetaWeightedAvg.argsort()[::-1][:10])) 357 | 358 | ## show topics 359 | beta = model.get_beta() 360 | topic_indices = list(np.random.choice(args.num_topics, 10)) # 10 random topics 361 | print('\n') 362 | for k in range(args.num_topics):#topic_indices: 363 | gamma = beta[k] 364 | top_words = list(gamma.cpu().numpy().argsort()[-args.num_words+1:][::-1]) 365 | topic_words = [vocab[a] for a in top_words] 366 | print('Topic {}: {}'.format(k, topic_words)) 367 | 368 | if args.train_embeddings: 369 | ## show etm embeddings 370 | try: 371 | rho_etm = model.rho.weight.cpu() 372 | except: 373 | rho_etm = model.rho.cpu() 374 | queries = ['andrew', 'woman', 'computer', 'sports', 'religion', 'man', 'love', 375 | 'intelligence', 'money', 'politics', 'health', 'people', 'family'] 376 | print('\n') 377 | print('ETM embeddings...') 378 | for word in queries: 379 | print('word: {} .. etm neighbors: {}'.format(word, nearest_neighbors(word, rho_etm, vocab))) 380 | print('\n') 381 | -------------------------------------------------------------------------------- /R/ETM.R: -------------------------------------------------------------------------------- 1 | 2 | #' @title Topic Modelling in Semantic Embedding Spaces 3 | #' @description ETM is a generative topic model combining traditional topic models (LDA) with word embeddings (word2vec). \cr 4 | #' \itemize{ 5 | #' \item{It models each word with a categorical distribution whose natural parameter is the inner product between 6 | #' a word embedding and an embedding of its assigned topic.} 7 | #' \item{The model is fitted using an amortized variational inference algorithm on top of libtorch.} 8 | #' } 9 | #' @param k the number of topics to extract 10 | #' @param embeddings either a matrix with pretrained word embeddings or an integer with the dimension of the word embeddings. Defaults to 50 if not provided. 11 | #' @param dim dimension of the variational inference hyperparameter theta (passed on to \code{\link[torch]{nn_linear}}). Defaults to 800. 12 | #' @param activation character string with the activation function of theta. Either one of 'relu', 'tanh', 'softplus', 'rrelu', 'leakyrelu', 'elu', 'selu', 'glu'. Defaults to 'relu'. 13 | #' @param dropout dropout percentage on the variational distribution for theta (passed on to \code{\link[torch]{nn_dropout}}). Defaults to 0.5. 14 | #' @param vocab a character vector with the words from the vocabulary. Defaults to the rownames of the \code{embeddings} argument. 
15 | #' @references \url{https://arxiv.org/pdf/1907.04907.pdf} 16 | #' @return an object of class ETM which is a torch \code{nn_module} containing, among others, 17 | #' \itemize{ 18 | #' \item num_topics: the number of topics 19 | #' \item vocab: character vector with the terminology used in the model 20 | #' \item vocab_size: the number of words in \code{vocab} 21 | #' \item rho: The word embeddings 22 | #' \item alphas: The topic embeddings 23 | #' } 24 | #' @section Methods: 25 | #' \describe{ 26 | #' \item{\code{fit(data, optimizer, epoch, batch_size, normalize = TRUE, clip = 0, lr_anneal_factor = 4, lr_anneal_nonmono = 10)}}{Fit the model on a document term matrix by splitting the data into a 70/30 training/test set and updating the model weights.} 27 | #' } 28 | #' @section Arguments: 29 | #' \describe{ 30 | #' \item{data}{bag of words document term matrix in \code{dgCMatrix} format} 31 | #' \item{optimizer}{object of class \code{torch_Optimizer}} 32 | #' \item{epoch}{integer with the number of iterations to train} 33 | #' \item{batch_size}{integer with the size of the batch} 34 | #' \item{normalize}{logical indicating to normalize the bag of words data} 35 | #' \item{clip}{number between 0 and 1 indicating to do gradient clipping - passed on to \code{\link[torch]{nn_utils_clip_grad_norm_}}} 36 | #' \item{lr_anneal_factor}{divide the learning rate by this factor when the loss on the test set is monotonic for at least \code{lr_anneal_nonmono} training iterations} 37 | #' \item{lr_anneal_nonmono}{number of iterations after which learning rate annealing is executed if the loss does not decrease} 38 | #' } 39 | #' @export 40 | #' @examples 41 | #' library(torch) 42 | #' library(topicmodels.etm) 43 | #' library(word2vec) 44 | #' library(udpipe) 45 | #' data(brussels_reviews_anno, package = "udpipe") 46 | #' ## 47 | #' ## Toy example with pretrained embeddings 48 | #' ## 49 | #' 50 | #' ## a. build word2vec model 51 | #' x <- subset(brussels_reviews_anno, language %in% "nl") 52 | #' x <- paste.data.frame(x, term = "lemma", group = "doc_id") 53 | #' set.seed(4321) 54 | #' w2v <- word2vec(x = x$lemma, dim = 15, iter = 20, type = "cbow", min_count = 5) 55 | #' embeddings <- as.matrix(w2v) 56 | #' 57 | #' ## b. 
build document term matrix on nouns + adjectives, align with the embedding terms 58 | #' dtm <- subset(brussels_reviews_anno, language %in% "nl" & upos %in% c("NOUN", "ADJ")) 59 | #' dtm <- document_term_frequencies(dtm, document = "doc_id", term = "lemma") 60 | #' dtm <- document_term_matrix(dtm) 61 | #' dtm <- dtm_conform(dtm, columns = rownames(embeddings)) 62 | #' dtm <- dtm[dtm_rowsums(dtm) > 0, ] 63 | #' 64 | #' ## create and fit an embedding topic model - 8 topics, theta 100-dimensional 65 | #' if (torch::torch_is_installed()) { 66 | #' 67 | #' set.seed(4321) 68 | #' torch_manual_seed(4321) 69 | #' model <- ETM(k = 8, dim = 100, embeddings = embeddings, dropout = 0.5) 70 | #' optimizer <- optim_adam(params = model$parameters, lr = 0.005, weight_decay = 0.0000012) 71 | #' overview <- model$fit(data = dtm, optimizer = optimizer, epoch = 40, batch_size = 1000) 72 | #' scores <- predict(model, dtm, type = "topics") 73 | #' 74 | #' lastbatch <- subset(overview$loss, overview$loss$batch_is_last == TRUE) 75 | #' plot(lastbatch$epoch, lastbatch$loss) 76 | #' plot(overview$loss_test) 77 | #' 78 | #' ## show top words in each topic 79 | #' terminology <- predict(model, type = "terms", top_n = 7) 80 | #' terminology 81 | #' 82 | #' ## 83 | #' ## Toy example without pretrained word embeddings 84 | #' ## 85 | #' set.seed(4321) 86 | #' torch_manual_seed(4321) 87 | #' model <- ETM(k = 8, dim = 100, embeddings = 15, dropout = 0.5, vocab = colnames(dtm)) 88 | #' optimizer <- optim_adam(params = model$parameters, lr = 0.005, weight_decay = 0.0000012) 89 | #' overview <- model$fit(data = dtm, optimizer = optimizer, epoch = 40, batch_size = 1000) 90 | #' terminology <- predict(model, type = "terms", top_n = 7) 91 | #' terminology 92 | #' 93 | #' 94 | #' 95 | #' \dontshow{ 96 | #' ## 97 | #' ## Another example using fit_original 98 | #' ## 99 | #' data(ng20, package = "topicmodels.etm") 100 | #' vocab <- ng20$vocab 101 | #' tokens <- ng20$bow_tr$tokens 102 | #' counts <- ng20$bow_tr$counts 103 | #' 104 | #' torch_manual_seed(123456789) 105 | #' model <- ETM(k = 4, vocab = vocab, dim = 5, embeddings = 25) 106 | #' model 107 | #' optimizer <- optim_adam(params = model$parameters, lr = 0.005, weight_decay = 0.0000012) 108 | #' 109 | #' traindata <- list(tokens = tokens, counts = counts, vocab = vocab) 110 | #' test1 <- list(tokens = ng20$bow_ts_h1$tokens, counts = ng20$bow_ts_h1$counts, vocab = vocab) 111 | #' test2 <- list(tokens = ng20$bow_ts_h2$tokens, counts = ng20$bow_ts_h2$counts, vocab = vocab) 112 | #' 113 | #' out <- model$fit_original(data = traindata, test1 = test1, test2 = test2, epoch = 4, 114 | #' optimizer = optimizer, batch_size = 1000, 115 | #' lr_anneal_factor = 4, lr_anneal_nonmono = 10) 116 | #' test <- subset(out$loss, out$loss$batch_is_last == TRUE) 117 | #' plot(test$epoch, test$loss) 118 | #' 119 | #' topic.centers <- as.matrix(model, type = "embedding", which = "topics") 120 | #' word.embeddings <- as.matrix(model, type = "embedding", which = "words") 121 | #' topic.terminology <- as.matrix(model, type = "beta") 122 | #' 123 | #' terminology <- predict(model, type = "terms", top_n = 4) 124 | #' terminology 125 | #' } 126 | #' 127 | #' } 128 | ETM <- nn_module( 129 | classname = "ETM", 130 | initialize = function(k = 20, 131 | embeddings, 132 | dim = 800, 133 | activation = c("relu", "tanh", "softplus", "rrelu", "leakyrelu", "elu", "selu", "glu"), 134 | dropout = 0.5, 135 | vocab = rownames(embeddings)) { 136 | if(missing(embeddings)){ 137 | rho <- 50 138 | }else{ 139 | rho <- 
embeddings 140 | } 141 | num_topics <- k 142 | t_hidden_size <- dim 143 | activation <- match.arg(activation) 144 | if(is.matrix(rho)){ 145 | stopifnot(length(vocab) == nrow(rho)) 146 | stopifnot(all(vocab == rownames(rho))) 147 | train_embeddings <- FALSE 148 | rho_size <- ncol(rho) 149 | }else{ 150 | if(!is.character(vocab)){ 151 | stop("provide in vocab a character vector") 152 | } 153 | train_embeddings <- TRUE 154 | rho_size <- rho 155 | } 156 | enc_drop <- dropout 157 | 158 | vocab_size <- length(vocab) 159 | self$loss_fit <- NULL 160 | self$vocab <- vocab 161 | self$num_topics <- num_topics 162 | self$vocab_size <- vocab_size 163 | self$t_hidden_size <- t_hidden_size 164 | self$rho_size <- rho_size 165 | self$enc_drop <- enc_drop 166 | self$t_drop <- nn_dropout(p = enc_drop) 167 | 168 | self$activation <- activation 169 | self$theta_act <- get_activation(activation) 170 | 171 | 172 | ## define the word embedding matrix \rho 173 | if(train_embeddings){ 174 | self$rho <- nn_linear(rho_size, vocab_size, bias = FALSE) 175 | }else{ 176 | #rho = nn.Embedding(num_embeddings, emsize) 177 | #self.rho = embeddings.clone().float().to(device) 178 | self$rho <- nn_embedding(num_embeddings = vocab_size, embedding_dim = rho_size, .weight = torch_tensor(rho)) 179 | #self$rho <- torch_tensor(rho) 180 | } 181 | 182 | ## define the matrix containing the topic embeddings 183 | self$alphas <- nn_linear(rho_size, self$num_topics, bias = FALSE)#nn.Parameter(torch.randn(rho_size, num_topics)) 184 | 185 | ## define variational distribution for \theta_{1:D} via amortization 186 | self$q_theta <- nn_sequential( 187 | nn_linear(vocab_size, t_hidden_size), 188 | self$theta_act, 189 | nn_linear(t_hidden_size, t_hidden_size), 190 | self$theta_act 191 | ) 192 | self$mu_q_theta <- nn_linear(t_hidden_size, self$num_topics, bias = TRUE) 193 | self$logsigma_q_theta <- nn_linear(t_hidden_size, self$num_topics, bias = TRUE) 194 | }, 195 | print = function(...){ 196 | cat("Embedding Topic Model", sep = "\n") 197 | cat(sprintf(" - topics: %s", self$num_topics), sep = "\n") 198 | cat(sprintf(" - vocabulary size: %s", self$vocab_size), sep = "\n") 199 | cat(sprintf(" - embedding dimension: %s", self$rho_size), sep = "\n") 200 | cat(sprintf(" - variational distribution dimension: %s", self$t_hidden_size), sep = "\n") 201 | cat(sprintf(" - variational distribution activation function: %s", self$activation), sep = "\n") 202 | }, 203 | encode = function(bows){ 204 | # """Returns parameters of the variational distribution for \theta. 
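# The kl_theta returned alongside mu_theta and logsigma_theta is the closed-form KL divergence
# between the Gaussian q(theta | bows), parameterised by mean mu_theta and log-variance logsigma_theta,
# and the standard normal prior: -0.5 * sum(1 + logsigma_theta - mu_theta^2 - exp(logsigma_theta)),
# averaged over the documents in the batch.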
205 | # 206 | # input: bows 207 | # batch of bag-of-words...tensor of shape bsz x V 208 | # output: mu_theta, log_sigma_theta 209 | # """ 210 | q_theta <- self$q_theta(bows) 211 | if(self$enc_drop > 0){ 212 | q_theta <- self$t_drop(q_theta) 213 | } 214 | mu_theta <- self$mu_q_theta(q_theta) 215 | logsigma_theta <- self$logsigma_q_theta(q_theta) 216 | kl_theta <- -0.5 * torch_sum(1 + logsigma_theta - mu_theta$pow(2) - logsigma_theta$exp(), dim = -1)$mean() 217 | list(mu_theta = mu_theta, logsigma_theta = logsigma_theta, kl_theta = kl_theta) 218 | }, 219 | decode = function(theta, beta){ 220 | res <- torch_mm(theta, beta) 221 | preds <- torch_log(res + 1e-6) 222 | preds 223 | }, 224 | get_beta = function(){ 225 | logit <- try(self$alphas(self$rho$weight)) # torch.mm(self.rho, self.alphas) 226 | if(inherits(logit, "try-error")){ 227 | logit <- self$alphas(self$rho) 228 | } 229 | #beta <- nnf_softmax(logit, dim=0)$transpose(1, 0) ## softmax over vocab dimension 230 | beta <- nnf_softmax(logit, dim = 1)$transpose(2, 1) ## softmax over vocab dimension 231 | beta 232 | }, 233 | get_theta = function(normalized_bows){ 234 | reparameterize = function(self, mu, logvar){ 235 | if(self$training){ 236 | std <- torch_exp(0.5 * logvar) 237 | eps <- torch_randn_like(std) 238 | eps$mul_(std)$add_(mu) 239 | }else{ 240 | mu 241 | } 242 | } 243 | msg <- self$encode(normalized_bows) 244 | mu_theta <- msg$mu_theta 245 | logsigma_theta <- msg$logsigma_theta 246 | kld_theta <- msg$kl_theta 247 | z <- reparameterize(self, mu_theta, logsigma_theta) 248 | theta <- nnf_softmax(z, dim=-1) 249 | list(theta = theta, kld_theta = kld_theta) 250 | }, 251 | forward = function(bows, normalized_bows, theta = NULL, aggregate = TRUE) { 252 | ## get \theta 253 | if(is.null(theta)){ 254 | msg <- self$get_theta(normalized_bows) 255 | theta <- msg$theta 256 | kld_theta <- msg$kld_theta 257 | }else{ 258 | kld_theta <- NULL 259 | } 260 | ## get \beta 261 | beta <- self$get_beta() 262 | ## get prediction loss 263 | preds <- self$decode(theta, beta) 264 | recon_loss <- -(preds * bows)$sum(2) 265 | #print(dim(recon_loss)) 266 | if(aggregate){ 267 | recon_loss <- recon_loss$mean() 268 | } 269 | list(recon_loss = recon_loss, kld_theta = kld_theta) 270 | }, 271 | topwords = function(top_n = 10){ 272 | self$eval() 273 | out <- list() 274 | with_no_grad({ 275 | gammas <- self$get_beta() 276 | for(k in seq_len(self$num_topics)){ 277 | gamma <- gammas[k, ] 278 | gamma <- as.numeric(gamma) 279 | gamma <- data.frame(term = self$vocab, beta = gamma, stringsAsFactors = FALSE) 280 | gamma <- gamma[order(gamma$beta, decreasing = TRUE), ] 281 | gamma$rank <- seq_len(nrow(gamma)) 282 | out[[k]] <- head(gamma, n = top_n) 283 | } 284 | }) 285 | out 286 | }, 287 | train_epoch = function(tokencounts, optimizer, epoch, batch_size, normalize = TRUE, clip = 0, permute = TRUE){ 288 | self$train() 289 | train_tokens <- tokencounts$tokens 290 | train_counts <- tokencounts$counts 291 | vocab_size <- length(tokencounts$vocab) 292 | num_docs_train <- length(train_tokens) 293 | acc_loss <- 0 294 | acc_kl_theta_loss <- 0 295 | cnt <- 0 296 | if(permute){ 297 | indices <- torch_randperm(num_docs_train) + 1 298 | }else{ 299 | ## For comparing end-to-end run and unit testing 300 | indices <- torch_tensor(seq_len(num_docs_train)) 301 | } 302 | indices <- torch_split(indices, batch_size) 303 | losses <- list() 304 | for(i in seq_along(indices)){ 305 | ind <- indices[[i]] 306 | optimizer$zero_grad() 307 | self$zero_grad() 308 | data_batch <- get_batch(train_tokens, 
train_counts, ind, vocab_size) 309 | sums <- data_batch$sum(2)$unsqueeze(2) 310 | if(normalize){ 311 | normalized_data_batch <- data_batch / sums 312 | }else{ 313 | normalized_data_batch <- data_batch 314 | } 315 | #as.matrix(self$q_theta(data_batch[1:10, , drop = FALSE])) 316 | out <- self$forward(data_batch, normalized_data_batch) 317 | total_loss <- out$recon_loss + out$kld_theta 318 | total_loss$backward() 319 | 320 | if(clip > 0){ 321 | nn_utils_clip_grad_norm_(self$parameters, max_norm = clip) 322 | } 323 | optimizer$step() 324 | 325 | acc_loss <- acc_loss + torch_sum(out$recon_loss)$item() 326 | acc_kl_theta_loss <- acc_kl_theta_loss + torch_sum(out$kld_theta)$item() 327 | cnt <- cnt + 1 328 | 329 | cur_loss <- round(acc_loss / cnt, 2) 330 | cur_kl_theta <- round(acc_kl_theta_loss / cnt, 2) 331 | cur_real_loss <- round(cur_loss + cur_kl_theta, 2) 332 | 333 | losses[[i]] <- data.frame(epoch = epoch, 334 | batch = i, 335 | batch_is_last = i == length(indices), 336 | lr = optimizer$param_groups[[1]][['lr']], 337 | loss = cur_loss, 338 | kl_theta = cur_kl_theta, 339 | nelbo = cur_real_loss, 340 | batch_loss = acc_loss, 341 | batch_kl_theta = acc_kl_theta_loss, 342 | batch_nelbo = acc_loss + acc_kl_theta_loss) 343 | #cat( 344 | # sprintf('Epoch: %s .. batch: %s/%s .. LR: %s .. KL_theta: %s .. Rec_loss: %s .. NELBO: %s', 345 | # epoch, i, length(indices), optimizer$param_groups[[1]][['lr']], cur_kl_theta, cur_loss, cur_real_loss), sep = "\n") 346 | } 347 | losses <- do.call(rbind, losses) 348 | losses 349 | }, 350 | evaluate = function(data1, data2, batch_size, normalize = TRUE){ 351 | self$eval() 352 | vocab_size <- length(data1$vocab) 353 | tokens1 <- data1$tokens 354 | counts1 <- data1$counts 355 | tokens2 <- data2$tokens 356 | counts2 <- data2$counts 357 | 358 | indices <- torch_split(torch_tensor(seq_along(tokens1)), batch_size) 359 | ppl_dc <- 0 360 | with_no_grad({ 361 | beta <- self$get_beta() 362 | acc_loss <- 0 363 | cnt <- 0 364 | for(i in seq_along(indices)){ 365 | ## get theta from first half of docs 366 | ind <- indices[[i]] 367 | data_batch_1 <- get_batch(tokens1, counts1, ind, vocab_size) 368 | sums <- data_batch_1$sum(2)$unsqueeze(2) 369 | if(normalize){ 370 | normalized_data_batch <- data_batch_1 / sums 371 | }else{ 372 | normalized_data_batch <- data_batch_1 373 | } 374 | msg <- self$get_theta(normalized_data_batch) 375 | theta <- msg$theta 376 | 377 | ## get prediction loss using second half 378 | data_batch_2 <- get_batch(tokens2, counts2, ind, vocab_size) 379 | sums <- data_batch_2$sum(2)$unsqueeze(2) 380 | res <- torch_mm(theta, beta) 381 | preds <- torch_log(res) 382 | recon_loss <- -(preds * data_batch_2)$sum(2) 383 | 384 | loss <- recon_loss / sums$squeeze() 385 | loss <- loss$mean()$item() 386 | acc_loss <- acc_loss + loss 387 | cnt <- cnt + 1 388 | } 389 | cur_loss <- acc_loss / cnt 390 | cur_loss <- as.numeric(cur_loss) 391 | ppl_dc <- round(exp(cur_loss), digits = 1) 392 | }) 393 | ppl_dc 394 | }, 395 | fit = function(data, optimizer, epoch, batch_size, normalize = TRUE, clip = 0, lr_anneal_factor = 4, lr_anneal_nonmono = 10){ 396 | stopifnot(inherits(data, "sparseMatrix")) 397 | data <- data[Matrix::rowSums(data) > 0, ] 398 | if(nrow(data) == 0){ 399 | stop("data argument (document term matrix) does not contain any documents (which contain words part of the vocabulary)") 400 | } 401 | idx <- split_train_test(data, train_pct = 0.7) 402 | test1 <- as_tokencounts(data[idx$test1, ]) 403 | test2 <- as_tokencounts(data[idx$test2, ]) 404 | data <- 
as_tokencounts(data[idx$train, ]) 405 | loss_evolution <- self$fit_original(data = data, test1 = test1, test2 = test2, optimizer = optimizer, epoch = epoch, 406 | batch_size = batch_size, normalize = normalize, clip = clip, 407 | lr_anneal_factor = lr_anneal_factor, lr_anneal_nonmono = lr_anneal_nonmono) 408 | self$loss_fit <- loss_evolution 409 | invisible(loss_evolution) 410 | }, 411 | fit_original = function(data, test1, test2, optimizer, epoch, batch_size, normalize = TRUE, clip = 0, lr_anneal_factor = 4, lr_anneal_nonmono = 10, permute = TRUE){ 412 | epochs <- epoch 413 | anneal_lr <- lr_anneal_factor > 0 414 | best_epoch <- 0 415 | best_val_ppl <- 1e9 416 | all_val_ppls <- c() 417 | losses <- list() 418 | for(epoch in seq_len(epochs)){ 419 | lossevolution <- self$train_epoch(tokencounts = data, optimizer = optimizer, epoch = epoch, batch_size = batch_size, normalize = normalize, clip = clip, permute = permute) 420 | losses[[epoch]] <- lossevolution 421 | val_ppl <- self$evaluate(test1, test2, batch_size = batch_size, normalize = normalize) 422 | if(val_ppl < best_val_ppl){ 423 | best_epoch <- epoch 424 | best_val_ppl <- val_ppl 425 | ## TODO save model 426 | }else{ 427 | ## check whether to anneal lr 428 | lr <- optimizer$param_groups[[1]]$lr 429 | cat(sprintf("%s versus %s", val_ppl, min(tail(all_val_ppls, n = lr_anneal_nonmono))), sep = "\n") 430 | if(anneal_lr & lr > 1e-5 & (length(all_val_ppls) > lr_anneal_nonmono) & val_ppl > min(tail(all_val_ppls, n = lr_anneal_nonmono))){ 431 | optimizer$param_groups[[1]]$lr <- optimizer$param_groups[[1]]$lr / lr_anneal_factor 432 | } 433 | } 434 | all_val_ppls <- append(all_val_ppls, val_ppl) 435 | lossevolution <- subset(lossevolution, batch_is_last == TRUE) 436 | cat( 437 | sprintf('Epoch: %03d/%03d, learning rate: %5f. Training data stats - KL_theta: %2f, Rec_loss: %2f, NELBO: %s. Test data stats - Loss %2f', 438 | lossevolution$epoch, epochs, optimizer$param_groups[[1]][['lr']], lossevolution$kl_theta, lossevolution$loss, lossevolution$nelbo, 439 | val_ppl), sep = "\n") 440 | } 441 | losses <- do.call(rbind, losses) 442 | list(loss = losses, loss_test = all_val_ppls) 443 | } 444 | ) 445 | get_batch <- function(tokens, counts, ind, vocab_size){ 446 | ind <- as.integer(ind) 447 | batch_size <- length(ind) 448 | data_batch <- torch_zeros(c(batch_size, vocab_size)) 449 | tokens <- tokens[ind] 450 | counts <- counts[ind] 451 | for(i in seq_along(tokens)){ 452 | tok <- tokens[[i]] 453 | cnt <- counts[[i]] 454 | data_batch[i, tok] <- as.numeric(cnt) 455 | #for(j in tok){ 456 | # data_batch[i, j] <- cnt[j] 457 | #} 458 | } 459 | data_batch 460 | } 461 | 462 | get_activation = function(act) { 463 | switch(act, 464 | tanh = nn_tanh(), 465 | relu = nn_relu(), 466 | softplus = nn_softplus(), 467 | rrelu = nn_rrelu(), 468 | leakyrelu = nn_leaky_relu(), 469 | elu = nn_elu(), 470 | selu = nn_selu(), 471 | glu = nn_glu()) 472 | } 473 | 474 | 475 | split_train_test <- function(x, train_pct = 0.7){ 476 | stopifnot(train_pct <= 1) 477 | test_pct <- 1 - train_pct 478 | idx <- seq_len(nrow(x)) 479 | tst <- sample(idx, size = nrow(x) * test_pct, replace = FALSE) 480 | tst1 <- sample(tst, size = round(length(tst) / 2), replace = FALSE) 481 | tst2 <- setdiff(tst, tst1) 482 | trn <- setdiff(idx, tst) 483 | list(train = sort(trn), test1 = sort(tst1), test2 = sort(tst2)) 484 | } 485 | 486 | 487 | 488 | #' @title Predict functionality for an ETM object. 489 | #' @description Predict to which ETM topic a text belongs or extract which words are emitted for each topic. 
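#' In case type = 'topics', the bag of words of each document is optionally normalised, passed through
#' the variational encoder of the model and mapped to a vector of topic probabilities by taking the
#' softmax over the topics, processed in chunks of \code{batch_size} rows.
#' In case type = 'terms', the \code{top_n} words with the highest emission probability (beta) are returned for each topic.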
490 | #' @param object an object of class \code{ETM} 491 | #' @param type a character string with either 'topics' or 'terms' indicating to either predict to which 492 | #' topic a document encoded as a set of bag of words belongs to or to extract the most emitted terms for each topic 493 | #' @param newdata bag of words document term matrix in \code{dgCMatrix} format. Only used in case type = 'topics'. 494 | #' @param batch_size integer with the size of the batch in order to do chunkwise predictions in chunks of \code{batch_size} rows. Defaults to the whole dataset provided in \code{newdata}. 495 | #' Only used in case type = 'topics'. 496 | #' @param normalize logical indicating to normalize the bag of words data. Defaults to \code{TRUE} similar as the default when building the \code{ETM} model. 497 | #' Only used in case type = 'topics'. 498 | #' @param top_n integer with the number of most relevant words for each topic to extract. Only used in case type = 'terms'. 499 | #' @param ... not used 500 | #' @seealso \code{\link{ETM}} 501 | #' @return Returns for 502 | #' \itemize{ 503 | #' \item{type 'topics': a matrix with topic probabilities of dimension nrow(newdata) x the number of topics} 504 | #' \item{type 'terms': a list of data.frame's where each data.frame has columns term, beta and rank indicating the 505 | #' top_n most emitted terms for that topic. List element 1 corresponds to the top terms emitted by topic 1, element 2 to topic 2 ...} 506 | #' } 507 | #' @export 508 | #' @examples 509 | #' \dontshow{if(require(torch) && torch::torch_is_installed()) 510 | #' \{ 511 | #' } 512 | #' library(torch) 513 | #' library(topicmodels.etm) 514 | #' path <- system.file(package = "topicmodels.etm", "example", "example_etm.ckpt") 515 | #' model <- torch_load(path) 516 | #' 517 | #' # Get most emitted words for each topic 518 | #' terminology <- predict(model, type = "terms", top_n = 5) 519 | #' terminology 520 | #' 521 | #' # Get topics probabilities for each document 522 | #' path <- system.file(package = "topicmodels.etm", "example", "example_dtm.rds") 523 | #' dtm <- readRDS(path) 524 | #' dtm <- head(dtm, n = 5) 525 | #' scores <- predict(model, newdata = dtm, type = "topics") 526 | #' scores 527 | #' \dontshow{ 528 | #' \} 529 | #' # End of main if statement running only if the torch is properly installed 530 | #' } 531 | predict.ETM <- function(object, newdata, type = c("topics", "terms"), batch_size = nrow(newdata), normalize = TRUE, top_n = 10, ...){ 532 | type <- match.arg(type) 533 | if(type == "terms"){ 534 | object$topwords(top_n) 535 | }else{ 536 | if(any(Matrix::rowSums(newdata) <= 0)){ 537 | stop("All rows of newdata should have at least 1 count") 538 | } 539 | x <- as_tokencounts(newdata) 540 | tokens <- x$tokens 541 | counts <- x$counts 542 | num_topics <- object$num_topics 543 | vocab_size <- object$vocab_size 544 | 545 | preds <- list() 546 | with_no_grad({ 547 | indices = torch_tensor(seq_along(tokens)) 548 | indices = torch_split(indices, batch_size) 549 | thetaWeightedAvg = torch_zeros(1, num_topics) 550 | cnt = 0 551 | for(i in seq_along(indices)){ 552 | ## get theta from first half of docs 553 | ind <- indices[[i]] 554 | data_batch = get_batch(tokens, counts, ind, vocab_size) 555 | sums <- data_batch$sum(2)$unsqueeze(2) 556 | cnt = cnt + as.numeric(sums$sum(1)$squeeze()) 557 | if(normalize){ 558 | normalized_data_batch <- data_batch / sums 559 | }else{ 560 | normalized_data_batch <- data_batch 561 | } 562 | theta <- object$get_theta(normalized_data_batch)$theta 563 | 
preds[[i]] <- as.matrix(theta) 564 | weighed_theta = sums * theta 565 | thetaWeightedAvg = thetaWeightedAvg + weighed_theta$sum(1)$unsqueeze(1) 566 | } 567 | thetaWeightedAvg = thetaWeightedAvg$squeeze() / cnt 568 | }) 569 | preds <- do.call(rbind, preds) 570 | rownames(preds) <- rownames(newdata) 571 | preds 572 | } 573 | } 574 | 575 | 576 | #' @title Get matrices out of an ETM object 577 | #' @description Convenience function to extract 578 | #' \itemize{ 579 | #' \item{embeddings of the topic centers} 580 | #' \item{embeddings of the words used in the model} 581 | #' \item{words emitted by each topic (beta), which is the softmax-transformed inner product of word embedding and topic embeddings} 582 | #' } 583 | #' @param x an object of class \code{ETM} 584 | #' @param type character string with the type of information to extract: either 'beta' (words emitted by each topic) or 'embedding' (embeddings of words or topic centers). Defaults to 'embedding'. 585 | #' @param which a character string with either 'words' or 'topics' to get either the embeddings of the words used in the model or the embedding of the topic centers. Defaults to 'topics'. Only used if type = 'embedding'. 586 | #' @param ... not used 587 | #' @seealso \code{\link{ETM}} 588 | #' @return a numeric matrix containing, depending on the value supplied in \code{type} 589 | #' either the embeddings of the topic centers, the embeddings of the words or the words emitted by each topic 590 | #' @export 591 | #' @examples 592 | #' \dontshow{if(require(torch) && torch::torch_is_installed()) 593 | #' \{ 594 | #' } 595 | #' library(torch) 596 | #' library(topicmodels.etm) 597 | #' path <- system.file(package = "topicmodels.etm", "example", "example_etm.ckpt") 598 | #' model <- torch_load(path) 599 | #' 600 | #' topic.centers <- as.matrix(model, type = "embedding", which = "topics") 601 | #' word.embeddings <- as.matrix(model, type = "embedding", which = "words") 602 | #' topic.terminology <- as.matrix(model, type = "beta") 603 | #' \dontshow{ 604 | #' \} 605 | #' # End of main if statement running only if the torch is properly installed 606 | #' } 607 | as.matrix.ETM <- function(x, type = c("embedding", "beta"), which = c("topics", "words"), ...){ 608 | type <- match.arg(type) 609 | which <- match.arg(which) 610 | self <- x 611 | self$eval() 612 | if(type == "embedding"){ 613 | if(which == "topics"){ 614 | with_no_grad({ 615 | out <- as.matrix(self$parameters$alphas.weight) 616 | }) 617 | }else if(which == "words"){ 618 | with_no_grad({ 619 | out <- as.matrix(self$parameters$rho.weight) 620 | rownames(out) <- self$vocab 621 | }) 622 | } 623 | }else if(type == "beta"){ 624 | with_no_grad({ 625 | gammas <- self$get_beta() 626 | gammas <- as.matrix(gammas) 627 | colnames(gammas) <- self$vocab 628 | }) 629 | out <- t(gammas) 630 | } 631 | out 632 | } 633 | 634 | #' @title Plot functionality for an ETM object 635 | #' @description Convenience function to plot 636 | #' \itemize{ 637 | #' \item{the evolution of the loss on the training / test set in order to inspect training convergence} 638 | #' \item{the \code{ETM} model in 2-dimensional space using a umap projection. 
639 | #' This plot uses function \code{\link[textplot]{textplot_embedding_2d}} from the textplot R package and 640 | #' plots the top_n most emitted words of each topic and the topic centers in 2 dimensions} 641 | #' } 642 | #' @param x an object of class \code{ETM} 643 | #' @param type character string with the type of plot to generate: either 'loss' or 'topics' 644 | #' @param which an integer vector of topics to plot, used in case type = 'topics'. Defaults to all topics. See the example below. 645 | #' @param top_n passed on to \code{summary.ETM} in order to visualise the top_n most relevant words for each topic. Defaults to 4. 646 | #' @param title passed on to textplot_embedding_2d, used in case type = 'topics' 647 | #' @param subtitle passed on to textplot_embedding_2d, used in case type = 'topics' 648 | #' @param encircle passed on to textplot_embedding_2d, used in case type = 'topics' 649 | #' @param points passed on to textplot_embedding_2d, used in case type = 'topics' 650 | #' @param ... arguments passed on to \code{\link{summary.ETM}} 651 | #' @seealso \code{\link{ETM}}, \code{\link{summary.ETM}}, \code{\link[textplot]{textplot_embedding_2d}} 652 | #' @return In case \code{type} is set to 'topics', maps the topic centers and most emitted words for each topic 653 | #' to 2D using \code{\link{summary.ETM}} and returns a ggplot object by calling \code{\link[textplot]{textplot_embedding_2d}}. \cr 654 | #' For type 'loss', makes a base graphics plot and returns invisibly nothing. 655 | #' @export 656 | #' @examples 657 | #' \dontshow{if(require(torch) && torch::torch_is_installed()) 658 | #' \{ 659 | #' } 660 | #' library(torch) 661 | #' library(topicmodels.etm) 662 | #' path <- system.file(package = "topicmodels.etm", "example", "example_etm.ckpt") 663 | #' model <- torch_load(path) 664 | #' plot(model, type = "loss") 665 | #' \dontshow{ 666 | #' \} 667 | #' # End of main if statement running only if the torch is properly installed 668 | #' } 669 | #' 670 | #' \dontshow{if(require(torch) && torch::torch_is_installed() && 671 | #' require(textplot) && require(uwot) && require(ggrepel)) 672 | #' \{ 673 | #' } 674 | #' library(torch) 675 | #' library(topicmodels.etm) 676 | #' library(textplot) 677 | #' library(uwot) 678 | #' library(ggrepel) 679 | #' path <- system.file(package = "topicmodels.etm", "example", "example_etm.ckpt") 680 | #' model <- torch_load(path) 681 | #' plt <- plot(model, type = "topics", top_n = 7, which = c(1, 2, 14, 16, 18, 19), 682 | #' metric = "cosine", n_neighbors = 15, 683 | #' fast_sgd = FALSE, n_threads = 2, verbose = TRUE, 684 | #' title = "ETM Topics example") 685 | #' plt 686 | #' \dontshow{ 687 | #' \} 688 | #' # End of main if statement running only if the torch is properly installed 689 | #' } 690 | plot.ETM <- function(x, type = c("loss", "topics"), which, top_n = 4, 691 | title = "ETM topics", subtitle = "", 692 | encircle = FALSE, points = FALSE, ...){ 693 | type <- match.arg(type) 694 | if(type == "loss"){ 695 | loss_evolution <- x$loss_fit 696 | if(is.null(loss_evolution)){ 697 | stop("You haven't trained the model yet") 698 | } 699 | oldpar <- par(no.readonly = TRUE) 700 | on.exit({ 701 | par(oldpar) 702 | }) 703 | 704 | combined <- loss_evolution$loss[loss_evolution$loss$batch_is_last == TRUE, ] 705 | combined$loss_test <- loss_evolution$loss_test 706 | par(mfrow = c(1, 2)) 707 | plot(combined$epoch, combined$loss, xlab = "Epoch", ylab = "loss", main = "Avg batch loss evolution\non 70% training set", col = "steelblue", type = "b", pch = 20, 
lty = 2) 708 | plot(combined$epoch, combined$loss_test, xlab = "Epoch", ylab = "exp(loss)", main = "Avg batch loss evolution\non 30% test set", col = "purple", type = "b", pch = 20, lty = 2) 709 | invisible() 710 | }else{ 711 | requireNamespace("textplot") 712 | manifolded <- summary(x, top_n = top_n, ...) 713 | space <- manifolded$embed_2d 714 | if(!missing(which)){ 715 | space <- space[space$cluster %in% which, ] 716 | } 717 | textplot::textplot_embedding_2d(space, title = title, subtitle = subtitle, encircle = encircle, points = points) 718 | } 719 | } 720 | 721 | 722 | #' @title Project ETM embeddings using UMAP 723 | #' @description Uses the uwot package to map the word embeddings and the center of the topic embeddings to a 2-dimensional space 724 | #' @param object object of class \code{ETM} 725 | #' @param type character string with the type of summary to extract. Defaults to 'umap', no other summary information currently implemented. 726 | #' @param n_components the dimension of the space to embed into. Passed on to \code{\link[uwot]{umap}}. Defaults to 2. 727 | #' @param top_n passed on to \code{\link{predict.ETM}} to get the \code{top_n} most relevant words for each topic in the 2-dimensional space 728 | #' @param ... further arguments passed onto \code{\link[uwot]{umap}} 729 | #' @seealso \code{\link[uwot]{umap}}, \code{\link{ETM}} 730 | #' @return a list with elements 731 | #' \itemize{ 732 | #' \item{center: a matrix with the embeddings of the topic centers} 733 | #' \item{words: a matrix with the embeddings of the words} 734 | #' \item{embed_2d: a data.frame which contains a lower dimensional presentation in 2D of the topics and the top_n words associated with 735 | #' the topic, containing columns type, term, cluster (the topic number), rank, beta, x, y, weight; where type is either 'words' or 'centers', x/y contain the lower dimensional 736 | #' positions in 2D of the word and weight is the emitted beta scaled to the highest beta within a topic where the topic center always gets weight 0.8} 737 | #' } 738 | #' @export 739 | #' @examples 740 | #' \dontshow{if(require(torch) && torch::torch_is_installed() && require(uwot)) 741 | #' \{ 742 | #' } 743 | #' library(torch) 744 | #' library(topicmodels.etm) 745 | #' library(uwot) 746 | #' path <- system.file(package = "topicmodels.etm", "example", "example_etm.ckpt") 747 | #' model <- torch_load(path) 748 | #' overview <- summary(model, 749 | #' metric = "cosine", n_neighbors = 15, 750 | #' fast_sgd = FALSE, n_threads = 1, verbose = TRUE) 751 | #' overview$center 752 | #' overview$embed_2d 753 | #' \dontshow{ 754 | #' \} 755 | #' # End of main if statement running only if the torch is properly installed 756 | #' } 757 | summary.ETM <- function(object, type = c("umap"), n_components = 2, top_n = 20, ...){ 758 | type <- match.arg(type) 759 | if(type == "umap"){ 760 | requireNamespace("uwot") 761 | centers <- as.matrix(object, type = "embedding", which = "topics") 762 | embeddings <- as.matrix(object, type = "embedding", which = "words") 763 | manifold <- uwot::umap(embeddings, n_components = n_components, ret_model = TRUE, ...) 
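# umap is fitted on the word embeddings only; ret_model = TRUE keeps the fitted model around
# so that the topic center embeddings can be projected into the same low-dimensional space
# with uwot::umap_transform below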
764 | centers <- uwot::umap_transform(X = centers, model = manifold) 765 | words <- manifold$embedding 766 | rownames(words) <- rownames(embeddings) 767 | rownames(centers) <- rownames(centers) 768 | 769 | terminology <- predict(object, type = "terms", top_n = top_n) 770 | terminology <- mapply(seq_along(terminology), terminology, FUN = function(topicnr, terminology){ 771 | terminology$cluster <- rep(topicnr, nrow(terminology)) 772 | terminology 773 | }, SIMPLIFY = FALSE) 774 | terminology <- do.call(rbind, terminology) 775 | space.2d.words <- merge(x = terminology, y = data.frame(x = words[, 1], y = words[, 2], term = rownames(words), stringsAsFactors = FALSE), by = "term") 776 | space.2d.centers <- data.frame(x = centers[, 1], y = centers[, 2], term = paste("Cluster-", seq_len(nrow(centers)), sep = ""), cluster = seq_len(nrow(centers)), stringsAsFactors = FALSE) 777 | space.2d.words$type <- rep("words", nrow(space.2d.words)) 778 | space.2d.words <- space.2d.words[order(space.2d.words$cluster, space.2d.words$rank, decreasing = FALSE), ] 779 | space.2d.centers$type <- rep("centers", nrow(space.2d.centers)) 780 | space.2d.centers$rank <- rep(0L, nrow(space.2d.centers)) 781 | space.2d.centers$beta <- rep(NA_real_, nrow(space.2d.centers)) 782 | fields <- c("type", "term", "cluster", "rank", "beta", "x", "y") 783 | df <- rbind(space.2d.words[, fields], space.2d.centers[, fields]) 784 | df <- split(df, df$cluster) 785 | df <- lapply(df, FUN = function(x){ 786 | x$weight <- ifelse(is.na(x$beta), 0.8, x$beta / max(x$beta, na.rm = TRUE)) 787 | x 788 | }) 789 | df <- do.call(rbind, df) 790 | rownames(df) <- NULL 791 | list(center = centers, words = words, embed_2d = df) 792 | }else{ 793 | .NotYetImplemented() 794 | } 795 | } 796 | 797 | 798 | --------------------------------------------------------------------------------