├── .Rbuildignore ├── .binder ├── apt.txt ├── install.R ├── postBuild └── runtime.txt ├── .gitattributes ├── .github ├── .gitignore └── workflows │ └── pkgdown.yaml ├── .gitignore ├── CITATION.cff ├── DESCRIPTION ├── LICENSE.md ├── NAMESPACE ├── R ├── lime.R ├── misc.R ├── setup.R └── train.R ├── README.Rmd ├── README.md ├── _pkgdown.yml ├── _quarto.yml ├── data ├── ecosent.rda ├── smallunciviltweets.rda ├── supported_model_types.rda └── unciviltweets.rda ├── inst ├── CITATION ├── grafzahl.yml ├── grafzahl_gpu.yml └── python │ └── st.py ├── man ├── detect_cuda.Rd ├── ecosent.Rd ├── figures │ ├── grafzahl_logo.png │ └── grafzahl_logo.svg ├── get_amharic_data.Rd ├── grafzahl.Rd ├── hydrate.Rd ├── predict.grafzahl.Rd ├── setup_grafzahl.Rd ├── supported_model_types.Rd ├── unciviltweets.Rd └── use_nonconda.Rd ├── methodshub.qmd ├── paper ├── .here ├── aup_logo.pdf ├── azime.md ├── azime.qmd ├── ccr.cls ├── coltekin.md ├── coltekin.qmd ├── dobbrick.md ├── dobbrick.qmd ├── fig1.png ├── fig2-1.pdf ├── grafzahl_sp.bib ├── grafzahl_sp.qmd ├── grafzahl_sp.rmd ├── grafzahl_sp.tex ├── img │ ├── fig-fig1-1.pdf │ ├── fig1.png │ ├── learning-curve-1.png │ └── theocharis-roc-1.png ├── misc │ ├── explore_lime.R │ ├── hindi.R │ ├── movie.R │ ├── multilingual.R │ └── tm.R ├── paper.rmd ├── plot_training.R ├── svm_curve.csv ├── theocharis.md ├── theocharis.qmd ├── vanatteveldt.md └── vanatteveldt.qmd ├── rawdata ├── createdata.R ├── synthetic-labels.csv └── training-data.csv ├── tests ├── testdata │ └── fake │ │ ├── .gitattributes │ │ ├── README.md │ │ ├── config.json │ │ ├── special_tokens_map.json │ │ └── tokenizer_config.json ├── testthat.R └── testthat │ ├── test_grafzahl.R │ └── test_setup.R └── vignettes ├── .gitignore └── grafzahl.Rmd /.Rbuildignore: -------------------------------------------------------------------------------- 1 | ^README\.* 2 | ^LICENSE\.md$ 3 | ^rawdata/ 4 | ^data/smallunciviltweets\.rda$ 5 | ^paper/ 6 | ^cran-comments\.md$ 7 | ^CRAN-SUBMISSION$ 8 | ^.gitattributes$ 9 | ^_pkgdown\.yml$ 10 | ^docs$ 11 | ^pkgdown$ 12 | ^\.github$ 13 | ^CITATION\.cff$ 14 | ^install\.R$ 15 | ^postBuild$ 16 | ^apt\.txt$ 17 | ^runtime\.txt$ 18 | ^_quarto\.yml$ 19 | ^\.quarto$ 20 | ^methodshub 21 | ^\.binder$ 22 | -------------------------------------------------------------------------------- /.binder/apt.txt: -------------------------------------------------------------------------------- 1 | zip -------------------------------------------------------------------------------- /.binder/install.R: -------------------------------------------------------------------------------- 1 | install.packages("grafzahl") 2 | -------------------------------------------------------------------------------- /.binder/postBuild: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env -S bash -v 2 | 3 | # determine which version of Quarto to install 4 | QUARTO_VERSION=1.6.39 5 | 6 | # See whether we need to lookup a Quarto version 7 | if [ $QUARTO_VERSION = "prerelease" ]; then 8 | QUARTO_JSON="_prerelease.json" 9 | elif [ $QUARTO_VERSION = "release" ]; then 10 | QUARTO_JSON="_download.json" 11 | fi 12 | 13 | if [ $QUARTO_JSON != "" ]; then 14 | 15 | # create a python script and run it 16 | PYTHON_SCRIPT=_quarto_version.py 17 | if [ -e $PYTHON_SCRIPT ]; then 18 | rm -rf $PYTHON_SCRIPT 19 | fi 20 | 21 | cat > $PYTHON_SCRIPT < 18 | to modern Transformer-based text classification models (Wolf et al., 2020) , 19 | in order to facilitate supervised machine learning for textual data. 
This package 20 | mimics the behaviors of ''quanteda.textmodels'' and provides a function to setup 21 | the ''Python'' environment to use the pretrained models from ''Hugging Face'' . 22 | More information: .' 23 | authors: 24 | - family-names: Chan 25 | given-names: Chung-hong 26 | email: chainsawtiney@gmail.com 27 | orcid: https://orcid.org/0000-0002-6232-7530 28 | preferred-citation: 29 | type: article 30 | title: 'grafzahl: fine-tuning Transformers for text data from within R.' 31 | authors: 32 | - family-names: Chan 33 | given-names: Chung-hong 34 | email: chainsawtiney@gmail.com 35 | orcid: https://orcid.org/0000-0002-6232-7530 36 | journal: Computational Communication Research 37 | doi: 10.5117/CCR2023.1.003.CHAN 38 | volume: '5' 39 | issue: '1' 40 | year: '2023' 41 | start: 76-84 42 | repository: https://CRAN.R-project.org/package=grafzahl 43 | repository-code: https://github.com/gesistsa/grafzahl 44 | url: https://gesistsa.github.io/grafzahl/ 45 | contact: 46 | - family-names: Chan 47 | given-names: Chung-hong 48 | email: chainsawtiney@gmail.com 49 | orcid: https://orcid.org/0000-0002-6232-7530 50 | references: 51 | - type: software 52 | title: knitr 53 | abstract: 'knitr: A General-Purpose Package for Dynamic Report Generation in R' 54 | notes: Suggests 55 | url: https://yihui.org/knitr/ 56 | repository: https://CRAN.R-project.org/package=knitr 57 | authors: 58 | - family-names: Xie 59 | given-names: Yihui 60 | email: xie@yihui.name 61 | orcid: https://orcid.org/0000-0003-0645-5666 62 | year: '2024' 63 | doi: 10.32614/CRAN.package.knitr 64 | - type: software 65 | title: rmarkdown 66 | abstract: 'rmarkdown: Dynamic Documents for R' 67 | notes: Suggests 68 | url: https://pkgs.rstudio.com/rmarkdown/ 69 | repository: https://CRAN.R-project.org/package=rmarkdown 70 | authors: 71 | - family-names: Allaire 72 | given-names: JJ 73 | email: jj@posit.co 74 | - family-names: Xie 75 | given-names: Yihui 76 | email: xie@yihui.name 77 | orcid: https://orcid.org/0000-0003-0645-5666 78 | - family-names: Dervieux 79 | given-names: Christophe 80 | email: cderv@posit.co 81 | orcid: https://orcid.org/0000-0003-4474-2498 82 | - family-names: McPherson 83 | given-names: Jonathan 84 | email: jonathan@posit.co 85 | - family-names: Luraschi 86 | given-names: Javier 87 | - family-names: Ushey 88 | given-names: Kevin 89 | email: kevin@posit.co 90 | - family-names: Atkins 91 | given-names: Aron 92 | email: aron@posit.co 93 | - family-names: Wickham 94 | given-names: Hadley 95 | email: hadley@posit.co 96 | - family-names: Cheng 97 | given-names: Joe 98 | email: joe@posit.co 99 | - family-names: Chang 100 | given-names: Winston 101 | email: winston@posit.co 102 | - family-names: Iannone 103 | given-names: Richard 104 | email: rich@posit.co 105 | orcid: https://orcid.org/0000-0003-3925-190X 106 | year: '2024' 107 | doi: 10.32614/CRAN.package.rmarkdown 108 | - type: software 109 | title: testthat 110 | abstract: 'testthat: Unit Testing for R' 111 | notes: Suggests 112 | url: https://testthat.r-lib.org 113 | repository: https://CRAN.R-project.org/package=testthat 114 | authors: 115 | - family-names: Wickham 116 | given-names: Hadley 117 | email: hadley@posit.co 118 | year: '2024' 119 | doi: 10.32614/CRAN.package.testthat 120 | version: '>= 3.0.0' 121 | - type: software 122 | title: withr 123 | abstract: 'withr: Run Code ''With'' Temporarily Modified Global State' 124 | notes: Suggests 125 | url: https://withr.r-lib.org 126 | repository: https://CRAN.R-project.org/package=withr 127 | authors: 128 | - 
family-names: Hester 129 | given-names: Jim 130 | - family-names: Henry 131 | given-names: Lionel 132 | email: lionel@posit.co 133 | - family-names: Müller 134 | given-names: Kirill 135 | email: krlmlr+r@mailbox.org 136 | - family-names: Ushey 137 | given-names: Kevin 138 | email: kevinushey@gmail.com 139 | - family-names: Wickham 140 | given-names: Hadley 141 | email: hadley@posit.co 142 | - family-names: Chang 143 | given-names: Winston 144 | year: '2024' 145 | doi: 10.32614/CRAN.package.withr 146 | - type: software 147 | title: jsonlite 148 | abstract: 'jsonlite: A Simple and Robust JSON Parser and Generator for R' 149 | notes: Imports 150 | url: https://jeroen.r-universe.dev/jsonlite 151 | repository: https://CRAN.R-project.org/package=jsonlite 152 | authors: 153 | - family-names: Ooms 154 | given-names: Jeroen 155 | email: jeroenooms@gmail.com 156 | orcid: https://orcid.org/0000-0002-4035-0289 157 | year: '2024' 158 | doi: 10.32614/CRAN.package.jsonlite 159 | - type: software 160 | title: lime 161 | abstract: 'lime: Local Interpretable Model-Agnostic Explanations' 162 | notes: Imports 163 | url: https://lime.data-imaginist.com 164 | repository: https://CRAN.R-project.org/package=lime 165 | authors: 166 | - family-names: Hvitfeldt 167 | given-names: Emil 168 | email: emilhhvitfeldt@gmail.com 169 | orcid: https://orcid.org/0000-0002-0679-1945 170 | - family-names: Pedersen 171 | given-names: Thomas Lin 172 | email: thomasp85@gmail.com 173 | orcid: https://orcid.org/0000-0002-5147-4711 174 | - family-names: Benesty 175 | given-names: Michaël 176 | email: michael@benesty.fr 177 | year: '2024' 178 | doi: 10.32614/CRAN.package.lime 179 | - type: software 180 | title: quanteda 181 | abstract: 'quanteda: Quantitative Analysis of Textual Data' 182 | notes: Imports 183 | url: https://quanteda.io 184 | repository: https://CRAN.R-project.org/package=quanteda 185 | authors: 186 | - family-names: Benoit 187 | given-names: Kenneth 188 | email: kbenoit@lse.ac.uk 189 | orcid: https://orcid.org/0000-0002-0797-564X 190 | - family-names: Watanabe 191 | given-names: Kohei 192 | email: watanabe.kohei@gmail.com 193 | orcid: https://orcid.org/0000-0001-6519-5265 194 | - family-names: Wang 195 | given-names: Haiyan 196 | email: whyinsa@yahoo.com 197 | orcid: https://orcid.org/0000-0003-4992-4311 198 | - family-names: Nulty 199 | given-names: Paul 200 | email: paul.nulty@gmail.com 201 | orcid: https://orcid.org/0000-0002-7214-4666 202 | - family-names: Obeng 203 | given-names: Adam 204 | email: quanteda@binaryeagle.com 205 | orcid: https://orcid.org/0000-0002-2906-4775 206 | - family-names: Müller 207 | given-names: Stefan 208 | email: stefan.mueller@ucd.ie 209 | orcid: https://orcid.org/0000-0002-6315-4125 210 | - family-names: Matsuo 211 | given-names: Akitaka 212 | email: a.matsuo@essex.ac.uk 213 | orcid: https://orcid.org/0000-0002-3323-6330 214 | - family-names: Lowe 215 | given-names: William 216 | email: lowe@hertie-school.org 217 | orcid: https://orcid.org/0000-0002-1549-6163 218 | year: '2024' 219 | doi: 10.32614/CRAN.package.quanteda 220 | - type: software 221 | title: reticulate 222 | abstract: 'reticulate: Interface to ''Python''' 223 | notes: Imports 224 | url: https://rstudio.github.io/reticulate/ 225 | repository: https://CRAN.R-project.org/package=reticulate 226 | authors: 227 | - family-names: Ushey 228 | given-names: Kevin 229 | email: kevin@posit.co 230 | - family-names: Allaire 231 | given-names: JJ 232 | email: jj@posit.co 233 | - family-names: Tang 234 | given-names: Yuan 235 | email: 
terrytangyuan@gmail.com 236 | orcid: https://orcid.org/0000-0001-5243-233X 237 | year: '2024' 238 | doi: 10.32614/CRAN.package.reticulate 239 | - type: software 240 | title: utils 241 | abstract: 'R: A Language and Environment for Statistical Computing' 242 | notes: Imports 243 | authors: 244 | - name: R Core Team 245 | institution: 246 | name: R Foundation for Statistical Computing 247 | address: Vienna, Austria 248 | year: '2024' 249 | - type: software 250 | title: stats 251 | abstract: 'R: A Language and Environment for Statistical Computing' 252 | notes: Imports 253 | authors: 254 | - name: R Core Team 255 | institution: 256 | name: R Foundation for Statistical Computing 257 | address: Vienna, Austria 258 | year: '2024' 259 | - type: software 260 | title: 'R: A Language and Environment for Statistical Computing' 261 | notes: Depends 262 | url: https://www.R-project.org/ 263 | authors: 264 | - name: R Core Team 265 | institution: 266 | name: R Foundation for Statistical Computing 267 | address: Vienna, Austria 268 | year: '2024' 269 | version: '>= 3.5' 270 | 271 | -------------------------------------------------------------------------------- /DESCRIPTION: -------------------------------------------------------------------------------- 1 | Package: grafzahl 2 | Title: Supervised Machine Learning for Textual Data Using Transformers and 'Quanteda' 3 | Version: 0.0.11 4 | Authors@R: 5 | person("Chung-hong", "Chan", , "chainsawtiney@gmail.com", role = c("aut", "cre"), 6 | comment = c(ORCID = "0000-0002-6232-7530")) 7 | Description: Duct tape the 'quanteda' ecosystem (Benoit et al., 2018) to modern Transformer-based text classification models (Wolf et al., 2020) , in order to facilitate supervised machine learning for textual data. This package mimics the behaviors of 'quanteda.textmodels' and provides a function to setup the 'Python' environment to use the pretrained models from 'Hugging Face' . More information: . 
8 | License: GPL (>= 3) 9 | Encoding: UTF-8 10 | Roxygen: list(markdown = TRUE) 11 | RoxygenNote: 7.3.1 12 | URL: https://gesistsa.github.io/grafzahl/, https://github.com/gesistsa/grafzahl 13 | BugReports: https://github.com/gesistsa/grafzahl/issues 14 | Suggests: 15 | knitr, 16 | quanteda.textmodels, 17 | rmarkdown, 18 | testthat (>= 3.0.0), 19 | withr 20 | Config/testthat/edition: 3 21 | Imports: 22 | jsonlite, 23 | lime, 24 | quanteda, 25 | reticulate, 26 | utils, 27 | stats 28 | LazyData: true 29 | Depends: 30 | R (>= 3.5) 31 | VignetteBuilder: knitr 32 | Config/Needs/website: gesistsa/tsatemplate 33 | -------------------------------------------------------------------------------- /NAMESPACE: -------------------------------------------------------------------------------- 1 | # Generated by roxygen2: do not edit by hand 2 | 3 | S3method(grafzahl,character) 4 | S3method(grafzahl,corpus) 5 | S3method(grafzahl,default) 6 | S3method(model_type,grafzahl) 7 | S3method(predict,grafzahl) 8 | S3method(predict_model,grafzahl) 9 | S3method(print,grafzahl) 10 | export(detect_conda) 11 | export(detect_cuda) 12 | export(get_amharic_data) 13 | export(grafzahl) 14 | export(hydrate) 15 | export(setup_grafzahl) 16 | export(textmodel_transformer) 17 | export(use_nonconda) 18 | importFrom(lime,model_type) 19 | importFrom(lime,predict_model) 20 | importFrom(stats,predict) 21 | importFrom(stats,runif) 22 | importFrom(utils,download.file) 23 | importFrom(utils,installed.packages) 24 | importFrom(utils,tail) 25 | -------------------------------------------------------------------------------- /R/lime.R: -------------------------------------------------------------------------------- 1 | #' @importFrom lime predict_model 2 | #' @method predict_model grafzahl 3 | #' @export 4 | predict_model.grafzahl <- function(x, newdata, type, ...) { 5 | if (!requireNamespace('grafzahl', quietly = TRUE)) { 6 | stop('grafzahl must be available when working with grafzahl models') 7 | } 8 | if (type == "raw") { 9 | res <- predict(x, newdata = newdata, return_raw = FALSE, ...) 10 | return(data.frame(Response = as.character(res), stringsAsFactors = FALSE)) 11 | } else if (type == "prob") { 12 | res <- predict(x, newdata = newdata, return_raw = TRUE, ...) 13 | ey <- exp(res) 14 | output <- as.data.frame(ey / apply(ey, 1, sum)) 15 | colnames(output) <- x$levels 16 | return(output) 17 | } else { 18 | stop("Unknown `type`.") 19 | } 20 | } 21 | 22 | #' @importFrom lime model_type 23 | #' @method model_type grafzahl 24 | #' @export 25 | model_type.grafzahl <- function(x, ...) { 26 | if (x$regression) { 27 | return("regression") 28 | } else { 29 | return("classification") 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /R/misc.R: -------------------------------------------------------------------------------- 1 | utils::globalVariables(c("py_train", "py_predict", "py_detect_cuda")) 2 | 3 | #' @importFrom stats predict runif 4 | #' @importFrom utils download.file installed.packages tail 5 | NULL 6 | 7 | #' A Corpus Of Tweets With Incivility Labels 8 | #' 9 | #' This is a dataset from the paper "The Dynamics of Political Incivility on Twitter". The tweets were by Members of Congress elected to the 115th Congress (2017–2018). It is important to note that not all the incivility labels were coded by human. Majority of the labels were coded by the Google Perspective API. All mentions were removed. The dataset is available from Pablo Barbera's Github. 
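#'
#' @examples
#' # A minimal sketch of inspecting the labelled corpus; it assumes the
#' # quanteda package is installed. The single docvar holds the incivility
#' # labels described above.
#' if (requireNamespace("quanteda", quietly = TRUE)) {
#'   quanteda::ndoc(unciviltweets)
#'   table(quanteda::docvars(unciviltweets)[, 1])
#' }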
10 | #' 11 | #' @references 12 | #' Theocharis, Y., Barberá, P., Fazekas, Z., & Popa, S. A. (2020). The dynamics of political incivility on Twitter. Sage Open, 10(2), 2158244020919447. 13 | "unciviltweets" 14 | 15 | #' A Corpus Of Dutch News Headlines 16 | #' 17 | #' This is a dataset from the paper "The Validity of Sentiment Analysis: Comparing Manual Annotation, Crowd-Coding, Dictionary Approaches, and Machine Learning Algorithms." 18 | #' The data frame contains four columns: id (identifier), headline (the actual text data), value (sentiment: 0 Neutral, +1 Positive, -1 Negative), gold (whether or not this row is "gold standard", i.e. test set). The data is available from Wouter van Atteveldt's Github. 19 | #' @references 20 | #' Van Atteveldt, W., Van der Velden, M. A., & Boukes, M. (2021). The validity of sentiment analysis: Comparing manual annotation, crowd-coding, dictionary approaches, and machine learning algorithms. Communication Methods and Measures, 15(2), 121-140. 21 | "ecosent" 22 | 23 | #' Supported model types 24 | #' 25 | #' A vector of all supported model types. 26 | #' 27 | "supported_model_types" 28 | 29 | #' Download The Amharic News Text Classification Dataset 30 | #' 31 | #' This function downloads the training and test sets of the Amharic News Text Classification Dataset from Hugging Face. 32 | #' 33 | #' @return A named list of two corpora: training and test 34 | #' @references Azime, Israel Abebe, and Nebil Mohammed (2021). "An Amharic News Text classification Dataset." arXiv preprint arXiv:2103.05639 35 | #' @export 36 | get_amharic_data <- function() { 37 | current_tempdir <- tempdir() 38 | download.file("https://huggingface.co/datasets/israel/Amharic-News-Text-classification-Dataset/resolve/main/train.csv", 39 | destfile = file.path(current_tempdir, "train.csv")) 40 | download.file("https://huggingface.co/datasets/israel/Amharic-News-Text-classification-Dataset/resolve/main/test.csv", 41 | destfile = file.path(current_tempdir, "test.csv")) 42 | training <- utils::read.csv(file.path(current_tempdir, "train.csv")) 43 | test <- utils::read.csv(file.path(current_tempdir, "test.csv")) 44 | training <- training[training$category != "",] 45 | test <- test[test$category != "",] 46 | training_corpus <- quanteda::corpus(training$article) 47 | quanteda::docvars(training_corpus, "category") <- training$category 48 | test_corpus <- quanteda::corpus(test$article) 49 | quanteda::docvars(test_corpus, "category") <- test$category 50 | return(list(training = training_corpus, test = test_corpus)) 51 | } 52 | 53 | ## .count_lang <- function(x, force_cld2 = FALSE, sampl = NULL, safety_multiplier = 1.5) { 54 | ## if (isTRUE("cld3" %in% row.names(installed.packages())) & !force_cld2) { 55 | ## trans_fun <- cld3::detect_language 56 | ## } else { 57 | ## trans_fun <- cld2::detect_language 58 | ## } 59 | ## if (length(x) > 5000 & !is.null(sampl)) { 60 | ## if (sampl >= 1 | sampl <= 0) { 61 | ## stop("`sampl` must be < 1 and > 0.") 62 | ## } 63 | ## x <- sample(x, ceiling(length(x) * sampl)) 64 | ## } 65 | ## longest <- max(quanteda::ntoken(x, what = "word"), na.rm = TRUE) * safety_multiplier 66 | ## res <- trans_fun(x) 67 | ## non_na_res <- res[!is.na(res)] 68 | ## freqcount <- table(non_na_res) 69 | ## majority_lang <- tail(names(sort(freqcount)), 1) 70 | ## majority_n <- max(table(non_na_res)) 71 | ## ratio <- majority_n / length(non_na_res) 72 | ## caseness <- stringi::stri_detect_charclass(x, "\\p{Lu}") 73 | ## output <- list(n = length(res), valid_n = length(non_na_res), 74 | ## 
maj_lang = majority_lang, maj_ratio = ratio, 75 | ## caseness = any(caseness), longest = longest) 76 | ## return(output) 77 | ## } 78 | 79 | -------------------------------------------------------------------------------- /R/setup.R: -------------------------------------------------------------------------------- 1 | .is_windows <- function() { 2 | Sys.info()[['sysname']] == "Windows" 3 | } 4 | 5 | .gen_conda_path <- function(envvar = "GRAFZAHL_MINICONDA_PATH", bin = FALSE) { 6 | if (Sys.getenv(envvar) == "") { 7 | main_path <- reticulate::miniconda_path() 8 | } else { 9 | main_path <- Sys.getenv(envvar) 10 | } 11 | if (isFALSE(bin)) { 12 | return(main_path) 13 | } 14 | if (.is_windows()) { 15 | return(file.path(main_path, "Scripts", "conda.exe")) 16 | } 17 | file.path(main_path, "bin", "conda") 18 | } 19 | 20 | ## list all conda envs, but restrict to .gen_conda_path 21 | ## Should err somehow 22 | .list_condaenvs <- function() { 23 | all_condaenvs <- reticulate::conda_list(conda = .gen_conda_path(bin = TRUE)) 24 | ##if (.is_windows()) { 25 | return(all_condaenvs$name) 26 | ##} 27 | ##all_condaenvs[grepl(.gen_conda_path(), all_condaenvs$python),]$name 28 | } 29 | 30 | .have_conda <- function() { 31 | ## !is.null(tryCatch(reticulate::conda_list(), error = function(e) NULL)) 32 | ## Not a very robust test, but take it. 33 | file.exists(.gen_conda_path(bin = TRUE)) 34 | } 35 | 36 | #' @rdname detect_cuda 37 | #' @export 38 | detect_conda <- function() { 39 | if(!.have_conda()) { 40 | return(FALSE) 41 | } 42 | envnames <- grep("^grafzahl_condaenv", .list_condaenvs(), value = TRUE) 43 | length(envnames) != 0 44 | } 45 | 46 | .gen_envname <- function(cuda = TRUE) { 47 | envname <- "grafzahl_condaenv" 48 | if (cuda) { 49 | envname <- paste0(envname, "_cuda") 50 | } 51 | return(envname) 52 | } 53 | 54 | .initialize_python <- function(envname, verbose = FALSE) { 55 | if (is.null(getOption("python_init")) && isTRUE(getOption("grafzahl.nonconda"))) { 56 | options("python_init" = TRUE) 57 | .say(verbose = verbose, "[Non-conda MODE] Use a non-conda Python environment. The environment is not checked.") 58 | return(invisible(NULL)) 59 | } 60 | if (is.null(getOption("python_init"))) { 61 | if (.is_windows()) { 62 | python_executable <- file.path(.gen_conda_path(), "envs", envname, "python.exe") 63 | } else { 64 | python_executable <- file.path(.gen_conda_path(), "envs", envname, "bin", "python") 65 | } 66 | ## Until rstydio/reticulate#1308 is fixed; mask it for now 67 | Sys.setenv(RETICULATE_MINICONDA_PATH = .gen_conda_path()) 68 | reticulate::use_miniconda(python_executable, required = TRUE) 69 | options("python_init" = TRUE) 70 | .say(verbose = verbose, "Conda environment ", envname, " is initialized.") 71 | } 72 | return(invisible(NULL)) 73 | } 74 | 75 | #' Detecting Miniconda And Cuda 76 | #' 77 | #' These functions detects miniconda and cuda. 78 | #' 79 | #' `detect_conda` conducts a test to check whether 1) a miniconda installation and 2) the grafzahl miniconda environment exist. 80 | #' 81 | #' `detect_cuda` checks whether cuda is available. If `setup_grafzahl` was executed with `cuda` being `FALSE`, this function will return `FALSE`. Even if `setup_grafzahl` was executed with `cuda` being `TRUE` but with any factor that can't enable cuda (e.g. no Nvidia GPU, the environment was incorrectly created), this function will also return `FALSE`. 82 | #' @return boolean, whether the system is available. 
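#' @examples
#' # A minimal usage sketch: both helpers take no arguments and return a
#' # logical. The results depend entirely on the local machine, so the guard
#' # mirrors the package's other examples.
#' if (detect_conda() && interactive()) {
#'   detect_cuda()
#' }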
83 | #' @export 84 | detect_cuda <- function() { 85 | options("python_init" = NULL) 86 | if (Sys.getenv("KILL_SWITCH") == "KILL") { 87 | return(NA) 88 | } 89 | if (is.null(getOption("grafzahl.nonconda"))) { 90 | envnames <- grep("^grafzahl_condaenv", .list_condaenvs(), value = TRUE) 91 | if (length(envnames) == 0) { 92 | stop("No conda environment found. Run `setup_grafzahl` to bootstrap one.") 93 | } 94 | if ("grafzahl_condaenv_cuda" %in% envnames) { 95 | envname <- "grafzahl_condaenv_cuda" 96 | } else { 97 | envname <- "grafzahl_condaenv" 98 | } 99 | } else { 100 | envname <- NA 101 | } 102 | .initialize_python(envname = envname, verbose = FALSE) 103 | reticulate::source_python(system.file("python", "st.py", package = "grafzahl")) 104 | return(py_detect_cuda()) 105 | } 106 | 107 | .install_gpu_pytorch <- function(cuda_version) { 108 | .initialize_python(.gen_envname(cuda = TRUE)) 109 | conda_executable <- .gen_conda_path(bin = TRUE) 110 | status <- system2(conda_executable, args = c("install", "-n", .gen_envname(cuda = TRUE), "pytorch", "pytorch-cuda", paste0("cudatoolkit=", cuda_version), "-c", "pytorch", "-c", "nvidia", "-y")) 111 | if (status != 0) { 112 | stop("Cannot set up `pytorch`.") 113 | } 114 | python_executable <- reticulate::py_config()$python 115 | status <- system2(python_executable, args = c("-m", "pip", "install", "simpletransformers", "\"transformers\"", "\"scipy\"")) 116 | if (status != 0) { 117 | stop("Cannot set up `simpletransformers`.") 118 | } 119 | } 120 | 121 | #' Setup grafzahl 122 | #' 123 | #' Install a self-contained miniconda environment with all Python components (PyTorch, Transformers, Simpletransformers, etc) which grafzahl required. The default location is "~/.local/share/r-miniconda/envs/grafzahl_condaenv" (suffix "_cuda" is added if `cuda` is `TRUE`). 124 | #' On Linux or Mac and if miniconda is not found, this function will also install miniconda. The path can be changed by the environment variable `GRAFZAHL_MINICONDA_PATH` 125 | #' @param cuda logical, if `TRUE`, indicate whether a CUDA-enabled environment is wanted. 126 | #' @param force logical, if `TRUE`, delete previous environment (if exists) and create a new environment 127 | #' @param cuda_version character, indicate CUDA version, ignore if `cuda` is `FALSE` 128 | #' @examples 129 | #' # setup an environment with cuda enabled. 130 | #' if (detect_conda() && interactive()) { 131 | #' setup_grafzahl(cuda = TRUE) 132 | #' } 133 | #' @return TRUE (invisibly) if installation is successful. 
134 | #' @export 135 | setup_grafzahl <- function(cuda = FALSE, force = FALSE, cuda_version = "11.3") { 136 | envname <- .gen_envname(cuda = cuda) 137 | if (!.have_conda()) { 138 | if (!force) { 139 | message("No conda was found in ", .gen_conda_path()) 140 | ans <- utils::menu(c("No", "Yes"), title = paste0("Do you want to install miniconda in ", .gen_conda_path())) 141 | if (ans == 1) { 142 | stop("Setup aborted.\n") 143 | } 144 | } 145 | reticulate::install_miniconda(.gen_conda_path(bin = FALSE), update = TRUE, force = TRUE) 146 | } 147 | allenvs <- .list_condaenvs() 148 | if (envname %in% allenvs && !force) { 149 | stop(paste0("Conda environment ", envname, " already exists.\nForce reinstallation by setting `force` to `TRUE`.\n")) 150 | } 151 | if (envname %in% allenvs && force) { 152 | reticulate::conda_remove(envname = envname, conda = .gen_conda_path(bin = TRUE)) 153 | } 154 | ## The actual installation 155 | ## https://github.com/rstudio/reticulate/issues/779 156 | ##conda_executable <- file.path(.gen_conda_path(), "bin/conda") 157 | if (isTRUE(cuda)) { 158 | yml_file <- "grafzahl_gpu.yml" 159 | } else { 160 | yml_file <- "grafzahl.yml" 161 | } 162 | status <- system2(.gen_conda_path(bin = TRUE), args = c("env", "create", paste0("-f=", system.file(yml_file, package = 'grafzahl')), "-n", envname)) 163 | if (status != 0) { 164 | stop("Cannot set up the basic conda environment.") 165 | } 166 | if (isTRUE(cuda)) { 167 | .install_gpu_pytorch(cuda_version = cuda_version) 168 | } 169 | ## Post-setup checks 170 | if (!detect_conda()) { 171 | stop("Conda can't be detected.") 172 | } 173 | if (detect_cuda() != cuda) { 174 | stop("Cuda wasn't configurated correctly.") 175 | } 176 | return(invisible(TRUE)) 177 | } 178 | 179 | #' Set up grafzahl to be used on Google Colab or similar environments 180 | #' 181 | #' Set up grafzahl to be used on Google Colab or similar environments. This function is also useful if you do not 182 | #' want to use conda on a local machine, e.g. you have configurateed the required Python package. 183 | #' 184 | #' @param install logical, whether to install the required Python packages 185 | #' @param check logical, whether to perform a check after the setup. The check displays 1) whether CUDA can be detected, 2) whether 186 | #' the non-conda mode has been activated, i.e. whether the option 'grafzahl.nonconda' is `TRUE`. 187 | #' @param verbose, logical, whether to display messages 188 | #' @examples 189 | #' # A typical use case for Google Colab 190 | #' if (interactive()) { 191 | #' use_nonconda() 192 | #' } 193 | #' @return TRUE (invisibly) if installation is successful. 194 | #' @export 195 | use_nonconda <- function(install = TRUE, check = TRUE, verbose = TRUE) { 196 | if (install) { 197 | system("python3 -m pip install simpletransformers emoji", intern = TRUE, ignore.stdout = !verbose) 198 | } 199 | options("grafzahl.nonconda" = TRUE) 200 | if (check) { 201 | .say(verbose, "Post-setup Check:") 202 | .say(verbose, "CUDA detected: ", detect_cuda()) 203 | .say(verbose, "Non-conda mode activated: ", isTRUE(getOption("grafzahl.nonconda"))) 204 | } 205 | return(invisible(TRUE)) 206 | } 207 | -------------------------------------------------------------------------------- /R/train.R: -------------------------------------------------------------------------------- 1 | .say <- function(verbose, ...) 
{ 2 | if (isTRUE(verbose)) { 3 | message(paste(..., sep = " ")) 4 | } 5 | invisible() 6 | } 7 | 8 | .download_from_huggingface <- function(model_name, json_file = tempfile()) { 9 | json_url <- paste0("https://huggingface.co/", model_name, "/raw/main/config.json") 10 | tryCatch({ 11 | suppressWarnings(download.file(url = json_url, destfile = json_file, quiet = TRUE)) 12 | }, error = function(e) { 13 | stop("Fail to download the model `", model_name, "` from Hugging Face", call. = FALSE) 14 | }) 15 | return(json_file) 16 | } 17 | 18 | .infer_model_type <- function(model_name) { 19 | if (!dir.exists(model_name)) { 20 | json_file <- .download_from_huggingface(model_name) 21 | } else { 22 | json_file <- file.path(model_name, "config.json") 23 | } 24 | jsonlite::fromJSON(json_file)$model_type 25 | } 26 | 27 | .check_model_type <- function(model_type, model_name) { 28 | if (missing(model_name)) { 29 | stop("You must provide `model_name`", call. = FALSE) 30 | } 31 | if (is.null(model_type)) { 32 | model_type <- .infer_model_type(model_name) 33 | } 34 | model_type <- gsub("-", "", tolower(model_type)) 35 | if (!model_type %in% grafzahl::supported_model_types) { 36 | stop("Invalid `model_type`.", call. = FALSE) 37 | } 38 | return(model_type) 39 | } 40 | 41 | ## Create a factor-like thing that is zero-indexed. It should work with numeric vectors, char vectors and factors. 42 | ## The output is a vector 43 | .make_0i <- function(x) { 44 | levels <- unique(x) 45 | matching_levels <- seq(0, (length(levels) - 1)) 46 | res <- matching_levels[match(x, levels)] 47 | attr(res, "levels") <- levels 48 | return(res) 49 | } 50 | 51 | .restore_0i <- function(x) { 52 | attr(x, "levels")[x + 1] 53 | } 54 | 55 | .prepare_y <- function(y, x) { 56 | if (is.null(y)) { 57 | if (ncol(quanteda::docvars(x)) == 1) { 58 | return(as.vector(quanteda::docvars(x)[,1])) 59 | } else { 60 | stop("Please either specify `y` or set exactly one `docvars` in `x`.", call. = FALSE) 61 | } 62 | } 63 | if (length(y) == 1) { 64 | ## It should be a docvars name, but it's better check 65 | if (!y %in% colnames(quanteda::docvars(x))) { 66 | stop(paste0(y, " is not a docvar.")) 67 | } 68 | return(as.vector(quanteda::docvars(x)[,y])) 69 | } 70 | return(y) 71 | } 72 | 73 | ##stole from quanteda 74 | .generate_meta <- function() { 75 | list("package-version" = utils::packageVersion("grafzahl"), 76 | "r-version" = getRversion(), 77 | "system" = Sys.info()[c("sysname", "machine", "user")], 78 | "directory" = getwd(), 79 | "created" = Sys.Date()) 80 | } 81 | 82 | .generate_random_dir <- function(lowest = 1, highest = 1000) { 83 | random_dir <- file.path(tempdir(), sample(seq(from = lowest, to = highest), 1)) 84 | if (!dir.exists(random_dir)) { 85 | dir.create(random_dir) 86 | } 87 | return(normalizePath(random_dir)) 88 | } 89 | 90 | .create_object <- function(call = NA, input_data = NULL, output_dir, model_type, model_name = NA, regression, levels = NULL, manual_seed = NULL) { 91 | result <- list( 92 | call = call, 93 | input_data = input_data, 94 | output_dir = output_dir, 95 | model_type = model_type, 96 | model_name = model_name, 97 | regression = regression, 98 | levels = levels, 99 | manual_seed = manual_seed, 100 | meta = .generate_meta() 101 | ) 102 | class(result) <- c("grafzahl", "textmodel_transformer", "textmodel", "list") 103 | return(result) 104 | } 105 | 106 | #' Fine tune a pretrained Transformer model for texts 107 | #' 108 | #' Fine tune (or train) a pretrained Transformer model for your given training labelled data `x` and `y`. 
The prediction task can be classification (if `regression` is `FALSE`, default) or regression (if `regression` is `TRUE`). 109 | #' @param x the [corpus] or character vector of texts on which the model will be trained. Depending on `train_size`, some texts will be used for cross-validation. 110 | #' @param y training labels. It can either be a single string indicating which [docvars] of the [corpus] is the training labels; a vector of training labels in either character or factor; or `NULL` if the [corpus] contains exactly one column in [docvars] and that column is the training labels. If `x` is a character vector, `y` must be a vector of the same length. 111 | #' @param model_name string indicates either 1) the model name on Hugging Face website; 2) the local path of the model 112 | #' @param regression logical, if `TRUE`, the task is regression, classification otherwise. 113 | #' @param output_dir string, location of the output model. If missing, the model will be stored in a temporary directory. Important: Please note that if this directory exists, it will be overwritten. 114 | #' @param cuda logical, whether to use CUDA, default to [detect_cuda()]. 115 | #' @param num_train_epochs numeric, if `train_size` is not exactly 1.0, the maximum number of epochs to try in the "early stop" regime will be this number times 5 (i.e. 4 * 5 = 20 by default). If `train_size` is exactly 1.0, the number of epochs is exactly that. 116 | #' @param train_size numeric, proportion of data in `x` and `y` to be used actually for training. The rest will be used for cross validation. 117 | #' @param args list, additionally parameters to be used in the underlying simple transformers 118 | #' @param cleanup logical, if `TRUE`, the `runs` directory generated will be removed when the training is done 119 | #' @param model_type a string indicating model_type of the input model. If `NULL`, it will be inferred from `model_name`. Supported model types are available in [supported_model_types]. 120 | #' @param manual_seed numeric, random seed 121 | #' @param verbose logical, if `TRUE`, debug messages will be displayed 122 | #' @param ... 
paramters pass to [grafzahl()] 123 | #' @return a `grafzahl` S3 object with the following items 124 | #' \item{call}{original function call} 125 | #' \item{input_data}{input_data for the underlying python function} 126 | #' \item{output_dir}{location of the output model} 127 | #' \item{model_type}{model type} 128 | #' \item{model_name}{model name} 129 | #' \item{regression}{whether or not it is a regression model} 130 | #' \item{levels}{factor levels of y} 131 | #' \item{manual_seed}{random seed} 132 | #' \item{meta}{metadata about the current session} 133 | #' @examples 134 | #' if (detect_conda() && interactive()) { 135 | #' library(quanteda) 136 | #' set.seed(20190721) 137 | #' ## Using the default cross validation method 138 | #' model1 <- grafzahl(unciviltweets, model_type = "bertweet", model_name = "vinai/bertweet-base") 139 | #' predict(model1) 140 | #' 141 | #' ## Using LIME 142 | #' input <- corpus(ecosent, text_field = "headline") 143 | #' training_corpus <- corpus_subset(input, !gold) 144 | #' model2 <- grafzahl(x = training_corpus, 145 | #' y = "value", 146 | #' model_name = "GroNLP/bert-base-dutch-cased") 147 | #' test_corpus <- corpus_subset(input, gold) 148 | #' predicted_sentiment <- predict(model2, test_corpus) 149 | #' require(lime) 150 | #' sentences <- c("Dijsselbloem pessimistisch over snelle stappen Grieken", 151 | #' "Aandelenbeurzen zetten koersopmars voort") 152 | #' explainer <- lime(training_corpus, model2) 153 | #' explanations <- explain(sentences, explainer, n_labels = 1, 154 | #' n_features = 2) 155 | #' plot_text_explanations(explanations) 156 | #' } 157 | #' @seealso [predict.grafzahl()] 158 | #' @export 159 | grafzahl <- function(x, y = NULL, model_name = "xlm-roberta-base", 160 | regression = FALSE, output_dir, cuda = detect_cuda(), num_train_epochs = 4, 161 | train_size = 0.8, args = NULL, cleanup = TRUE, model_type = NULL, 162 | manual_seed = floor(runif(1, min = 1, max = 721831)), verbose = TRUE) { 163 | UseMethod("grafzahl") 164 | } 165 | 166 | #' @rdname grafzahl 167 | #' @export 168 | grafzahl.default <- function(x, y = NULL, model_name = "xlm-roberta-base", 169 | regression = FALSE, output_dir, cuda = detect_cuda(), num_train_epochs = 4, 170 | train_size = 0.8, args = NULL, cleanup = TRUE, model_type = NULL, 171 | manual_seed = floor(runif(1, min = 1, max = 721831)), verbose = TRUE) { 172 | return(invisible(NULL)) 173 | } 174 | 175 | #' @rdname grafzahl 176 | #' @export 177 | grafzahl.corpus <- function(x, y = NULL, model_name = "xlm-roberta-base", 178 | regression = FALSE, output_dir, cuda = detect_cuda(), num_train_epochs = 4, 179 | train_size = 0.8, args = NULL, cleanup = TRUE, model_type = NULL, 180 | manual_seed = floor(runif(1, min = 1, max = 721831)), verbose = TRUE) { 181 | if (quanteda::ndoc(x) <= 1) { 182 | stop("Too few documents.") 183 | } 184 | y <- .prepare_y(y, x) 185 | if (!regression) { 186 | y <- .make_0i(y) 187 | num_labels <- length(attr(y, "levels")) 188 | levels <- attr(y, "levels") 189 | } else { 190 | num_labels <- 1L 191 | levels <- NULL 192 | } 193 | if (!is.integer(manual_seed)) { 194 | manual_seed <- as.integer(manual_seed) 195 | } 196 | model_type <- .check_model_type(model_type = model_type, model_name = model_name) 197 | input_data <- data.frame("text" = as.vector(x), "label" = as.vector(y)) 198 | if (Sys.getenv("KILL_SWITCH") == "KILL") { 199 | return(NA) 200 | } 201 | if (missing(output_dir)) { 202 | output_dir <- .generate_random_dir() 203 | .say(verbose, "No `output_dir` provided. 
The output model will be written to:", output_dir, "\n") 204 | } else { 205 | if (dir.exists(output_dir)) { 206 | .say(verbose, output_dir, "exists. Will be overwritten.\n") 207 | } else { 208 | dir.create(output_dir) 209 | .say(verbose, output_dir, "created.\n") 210 | } 211 | output_dir <- normalizePath(output_dir) 212 | } 213 | best_model_dir <- file.path(output_dir, "best_model") 214 | cache_dir <- .generate_random_dir(9999, 300000) 215 | .initialize_python(envname = .gen_envname(cuda = cuda), verbose = verbose) 216 | reticulate::source_python(system.file("python", "st.py", package = "grafzahl")) 217 | if (isTRUE(getOption("grafzahl.nonconda"))) { 218 | .say(verbose, "[Non-conda MODE] If you are running this on Google Colab, you will not see the training progress.") 219 | } 220 | py_train(data = input_data, num_labels = num_labels, output_dir = output_dir, best_model_dir = best_model_dir, cache_dir = cache_dir, model_type = model_type, model_name = model_name, num_train_epochs = num_train_epochs, train_size = train_size, manual_seed = manual_seed, regression = regression, verbose = verbose) 221 | if (cleanup && dir.exists(file.path("./", "runs"))) { 222 | unlink(file.path("./", "runs"), recursive = TRUE, force = TRUE) 223 | } 224 | result <- .create_object(call = match.call(), 225 | input_data = input_data, 226 | output_dir = output_dir, 227 | model_type = model_type, 228 | model_name = model_name, 229 | regression = regression, 230 | levels = levels, 231 | manual_seed = manual_seed) 232 | return(result) 233 | } 234 | 235 | #' @rdname grafzahl 236 | #' @export 237 | textmodel_transformer <- function(...) { 238 | grafzahl(...) 239 | } 240 | 241 | #' @rdname grafzahl 242 | #' @export 243 | grafzahl.character <- function(x, y = NULL, model_name = "xlmroberta", 244 | regression = FALSE, output_dir, cuda = detect_cuda(), num_train_epochs = 4, 245 | train_size = 0.8, args = NULL, cleanup = TRUE, model_type = NULL, 246 | manual_seed = floor(runif(1, min = 1, max = 721831)), verbose = TRUE) { 247 | if (is.null(y)) { 248 | stop("`y` cannot be NULL when x is a character vector.", call. = FALSE) 249 | } 250 | if (length(x) != length(y)) { 251 | stop("`y` must have the same length as `x`.", call. = FALSE) 252 | } 253 | grafzahl(x = quanteda::corpus(x), y = y, model_type = model_type, model_name = model_name, regression = regression, 254 | output_dir = output_dir, cuda = cuda, num_train_epochs = num_train_epochs, train_size = train_size, 255 | args = args, cleanup = cleanup, manual_seed = manual_seed, verbose = verbose) 256 | } 257 | 258 | 259 | #' Prediction from a fine-tuned grafzahl object 260 | #' 261 | #' Make prediction from a fine-tuned grafzahl object. 262 | #' @param object an S3 object trained with [grafzahl()] 263 | #' @param newdata a [corpus] or a character vector of texts on which prediction should be made. 264 | #' @inheritParams grafzahl 265 | #' @param return_raw logical, if `TRUE`, return a matrix of logits; a vector of class prediction otherwise 266 | #' @param ... not used 267 | #' @return a vector of class prediction or a matrix of logits 268 | #' @method predict grafzahl 269 | #' @export 270 | predict.grafzahl <- function(object, newdata, cuda = detect_cuda(), return_raw = FALSE, ...) { 271 | if (missing(newdata)) { 272 | if (!is.data.frame(object$input_data)) { 273 | stop("`newdata` is missing. And no input data in the `grafzahl` object.", call. 
= FALSE) 274 | } 275 | newdata <- object$input_data$text 276 | } 277 | if (Sys.getenv("KILL_SWITCH") == "KILL") { 278 | return(NA) 279 | } 280 | .initialize_python(envname = .gen_envname(cuda = cuda), verbose = FALSE) 281 | reticulate::source_python(system.file("python", "st.py", package = "grafzahl")) 282 | res <- py_predict(to_predict = newdata, model_type = object$model_type, output_dir = object$output_dir, return_raw = return_raw, use_cuda = cuda) 283 | if (return_raw || is.null(object$levels)) { 284 | return(res) 285 | } 286 | return(object$levels[res + 1]) 287 | } 288 | 289 | #' @method print grafzahl 290 | #' @export 291 | print.grafzahl <- function(x, ...) { 292 | if (is.data.frame(x$input_data)) { 293 | n_training <- nrow(x$input_data) 294 | } 295 | cat("\nCall:\n") 296 | print(x$call) 297 | cat("\n", 298 | "output_dir:", x$output_dir, ";", 299 | "model_type:", x$model_type, ";", 300 | "model_name:", x$model_name, ";", 301 | n_training, "training documents; ", 302 | "\n", sep = " ") 303 | } 304 | 305 | #' Create a grafzahl S3 object from the output_dir 306 | #' 307 | #' Create a grafzahl S3 object from the output_dir 308 | #' @inherit grafzahl return params 309 | #' @export 310 | hydrate <- function(output_dir, model_type = NULL, regression = FALSE) { 311 | if (missing(output_dir)) { 312 | stop("You must provide `output_dir`") 313 | } 314 | model_type <- .check_model_type(model_type = model_type, model_name = output_dir) 315 | results <- .create_object( 316 | output_dir = output_dir, 317 | model_type = model_type, 318 | regression = regression, 319 | ) 320 | return(results) 321 | } 322 | -------------------------------------------------------------------------------- /README.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | output: github_document 3 | --- 4 | 5 | 6 | 7 | ```{r, include = FALSE} 8 | knitr::opts_chunk$set( 9 | collapse = TRUE, 10 | comment = "#>", 11 | fig.path = "man/figures/README-", 12 | out.width = "100%" 13 | ) 14 | ``` 15 | 16 | # grafzahl 17 | 18 | 19 | [![CRAN status](https://www.r-pkg.org/badges/version/grafzahl)](https://CRAN.R-project.org/package=grafzahl) 20 | 21 | 22 | The goal of grafzahl (**G**racious **R** **A**nalytical **F**ramework for **Z**appy **A**nalysis of **H**uman **L**anguages [^1]) is to duct tape the [quanteda](https://github.com/quanteda/quanteda) ecosystem to modern [Transformer-based text classification models](https://simpletransformers.ai/), e.g. BERT, RoBERTa, etc. The model object looks and feels like the textmodel S3 object from the package [quanteda.textmodels](https://github.com/quanteda/quanteda.textmodels). 23 | 24 | If you don't know what I am talking about, don't worry, this package is gracious. You don't need to know a lot about Transformers to use this package. See the examples below. 25 | 26 | Please cite this software as: 27 | 28 | Chan, C., (2023). [grafzahl: fine-tuning Transformers for text data from within R](paper/grafzahl_sp.pdf). *Computational Communication Research* 5(1): 76-84. [https://doi.org/10.5117/CCR2023.1.003.CHAN](https://doi.org/10.5117/CCR2023.1.003.CHAN) 29 | 30 | ## Installation: Local environment 31 | 32 | Install the CRAN version 33 | 34 | ```r 35 | install.packages("grafzahl") 36 | ``` 37 | 38 | After that, you need to setup your conda environment 39 | 40 | ```r 41 | require(grafzahl) 42 | setup_grafzahl(cuda = TRUE) ## if you have GPU(s) 43 | ``` 44 | 45 | ## On remote environments, e.g. 
Google Colab 46 | 47 | On Google Colab, you need to enable non-Conda mode 48 | 49 | ```r 50 | install.packages("grafzahl") 51 | require(grafzahl) 52 | use_nonconda() 53 | ``` 54 | 55 | Please refer the vignette. 56 | 57 | ## Usage 58 | 59 | Suppose you have a bunch of tweets in the quanteda corpus format. And the corpus has exactly one docvar that denotes the labels you want to predict. The data is from [this repository](https://github.com/pablobarbera/incivility-sage-open) (Theocharis et al., 2020). 60 | 61 | ```{r, echo = FALSE, message = FALSE} 62 | devtools::load_all() 63 | ``` 64 | 65 | ```{r} 66 | unciviltweets 67 | ``` 68 | 69 | In order to train a Transfomer model, please select the `model_name` from [Hugging Face's list](https://huggingface.co/models). The table below lists some common choices. In most of the time, providing `model_name` is sufficient, there is no need to provide `model_type`. 70 | 71 | Suppose you want to train a Transformer model using "bertweet" (Nguyen et al., 2020) because it matches your domain of usage. By default, it will save the model in the `output` directory of the current directory. You can change it to elsewhere using the `output_dir` parameter. 72 | 73 | ```r 74 | model <- grafzahl(unciviltweets, model_type = "bertweet", model_name = "vinai/bertweet-base") 75 | ### If you are hardcore quanteda user: 76 | ## model <- textmodel_transformer(unciviltweets, 77 | ## model_type = "bertweet", model_name = "vinai/bertweet-base") 78 | ``` 79 | 80 | Make prediction 81 | 82 | ```r 83 | predict(model) 84 | ``` 85 | 86 | That is it. 87 | 88 | ## Extended examples 89 | 90 | Several extended examples are also available. 91 | 92 | | Examples | file | 93 | |-------------------------------------------------|------------------------------------------------| 94 | | van Atteveldt et al. (2021) | [paper/vanatteveldt.md](paper/vanatteveldt.md) | 95 | | Dobbrick et al. (2021) | [paper/dobbrick.md](paper/dobbrick.md) | 96 | | Theocharis et al. (2020) | [paper/theocharis.md](paper/theocharis.md) | 97 | | OffensEval-TR (2020) | [paper/coltekin.md](paper/coltekin.md) | 98 | | Amharic News Text classification Dataset (2021) | [paper/azime.md](paper/azime.md) | 99 | 100 | ## Some common choices of `model_name` 101 | 102 | | Your data | model_type | model_name | 103 | |-------------------|------------|------------------------------------| 104 | | English tweets | bertweet | vinai/bertweet-base | 105 | | Lightweight | mobilebert | google/mobilebert-uncased | 106 | | | distilbert | distilbert-base-uncased | 107 | | Long Text | longformer | allenai/longformer-base-4096 | 108 | | | bigbird | google/bigbird-roberta-base | 109 | | English (General) | bert | bert-base-uncased | 110 | | | bert | bert-base-cased | 111 | | | electra | google/electra-small-discriminator | 112 | | | roberta | roberta-base | 113 | | Multilingual | xlm | xlm-mlm-17-1280 | 114 | | | xml | xlm-mlm-100-1280 | 115 | | | bert | bert-base-multilingual-cased | 116 | | | xlmroberta | xlm-roberta-base | 117 | | | xlmroberta | xlm-roberta-large | 118 | 119 | # References 120 | 121 | 1. Theocharis, Y., Barberá, P., Fazekas, Z., & Popa, S. A. (2020). The dynamics of political incivility on Twitter. Sage Open, 10(2), 2158244020919447. 122 | 2. Nguyen, D. Q., Vu, T., & Nguyen, A. T. (2020). BERTweet: A pre-trained language model for English Tweets. arXiv preprint arXiv:2005.10200. 123 | 124 | --- 125 | [^1]: Yes, I totally made up the meaningless long name. 
Actually, it is the German name of the *Sesame Street* character [Count von Count](https://de.wikipedia.org/wiki/Sesamstra%C3%9Fe#Graf_Zahl), meaning "Count (the noble title) Number". And it seems to be so that it is compulsory to name absolutely everything related to Transformers after Seasame Street characters. 126 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | # grafzahl 5 | 6 | 7 | 8 | [![CRAN 9 | status](https://www.r-pkg.org/badges/version/grafzahl)](https://CRAN.R-project.org/package=grafzahl) 10 | 11 | 12 | The goal of grafzahl (**G**racious **R** **A**nalytical **F**ramework 13 | for **Z**appy **A**nalysis of **H**uman **L**anguages \[1\]) is to duct 14 | tape the [quanteda](https://github.com/quanteda/quanteda) ecosystem to 15 | modern [Transformer-based text classification 16 | models](https://simpletransformers.ai/), e.g. BERT, RoBERTa, etc. The 17 | model object looks and feels like the textmodel S3 object from the 18 | package 19 | [quanteda.textmodels](https://github.com/quanteda/quanteda.textmodels). 20 | 21 | If you don’t know what I am talking about, don’t worry, this package is 22 | gracious. You don’t need to know a lot about Transformers to use this 23 | package. See the examples below. 24 | 25 | Please cite this software as: 26 | 27 | Chan, C., (2023). [grafzahl: fine-tuning Transformers for text data from 28 | within R](paper/grafzahl_sp.pdf). *Computational Communication Research* 29 | 5(1): 76-84. 30 | 31 | ## Installation: Local environment 32 | 33 | Install the CRAN version 34 | 35 | ``` r 36 | install.packages("grafzahl") 37 | ``` 38 | 39 | After that, you need to setup your conda environment 40 | 41 | ``` r 42 | require(grafzahl) 43 | setup_grafzahl(cuda = TRUE) ## if you have GPU(s) 44 | ``` 45 | 46 | ## On remote environments, e.g. Google Colab 47 | 48 | On Google Colab, you need to enable non-Conda mode 49 | 50 | ``` r 51 | install.packages("grafzahl") 52 | require(grafzahl) 53 | use_nonconda() 54 | ``` 55 | 56 | Please refer the vignette. 57 | 58 | ## Usage 59 | 60 | Suppose you have a bunch of tweets in the quanteda corpus format. And 61 | the corpus has exactly one docvar that denotes the labels you want to 62 | predict. The data is from [this 63 | repository](https://github.com/pablobarbera/incivility-sage-open) 64 | (Theocharis et al., 2020). 65 | 66 | ``` r 67 | unciviltweets 68 | #> Corpus consisting of 19,982 documents and 1 docvar. 69 | #> text1 : 70 | #> "@ @ Karma gave you a second chance yesterday. Start doing m..." 71 | #> 72 | #> text2 : 73 | #> "@ With people like you, Steve King there's still hope for we..." 74 | #> 75 | #> text3 : 76 | #> "@ @ You bill is a joke and will sink the GOP. #WEDESERVEBETT..." 77 | #> 78 | #> text4 : 79 | #> "@ Dream on. The only thing trump understands is how to enric..." 80 | #> 81 | #> text5 : 82 | #> "@ @ Just like the Democrat taliban party was up front with t..." 83 | #> 84 | #> text6 : 85 | #> "@ you are going to have more of the same with HRC, and you a..." 86 | #> 87 | #> [ reached max_ndoc ... 19,976 more documents ] 88 | ``` 89 | 90 | In order to train a Transfomer model, please select the `model_name` 91 | from [Hugging Face’s list](https://huggingface.co/models). The table 92 | below lists some common choices. In most of the time, providing 93 | `model_name` is sufficient, there is no need to provide `model_type`. 
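If `model_type` is omitted, grafzahl infers it from the model's
`config.json` on Hugging Face, and the bundled `supported_model_types`
vector lists every type the package accepts. The snippet below is a
small sketch of checking that vector (its exact content depends on the
installed version of grafzahl):

``` r
## All model types known to this version of grafzahl
supported_model_types

## "bertweet", used in the example below, is one of them
"bertweet" %in% supported_model_types
```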
94 | 95 | Suppose you want to train a Transformer model using “bertweet” (Nguyen 96 | et al., 2020) because it matches your domain of usage. By default, it 97 | will save the model in the `output` directory of the current directory. 98 | You can change it to elsewhere using the `output_dir` parameter. 99 | 100 | ``` r 101 | model <- grafzahl(unciviltweets, model_type = "bertweet", model_name = "vinai/bertweet-base") 102 | ### If you are hardcore quanteda user: 103 | ## model <- textmodel_transformer(unciviltweets, 104 | ## model_type = "bertweet", model_name = "vinai/bertweet-base") 105 | ``` 106 | 107 | Make prediction 108 | 109 | ``` r 110 | predict(model) 111 | ``` 112 | 113 | That is it. 114 | 115 | ## Extended examples 116 | 117 | Several extended examples are also available. 118 | 119 | | Examples | file | 120 | | ----------------------------------------------- | ---------------------------------------------- | 121 | | van Atteveldt et al. (2021) | [paper/vanatteveldt.md](paper/vanatteveldt.md) | 122 | | Dobbrick et al. (2021) | [paper/dobbrick.md](paper/dobbrick.md) | 123 | | Theocharis et al. (2020) | [paper/theocharis.md](paper/theocharis.md) | 124 | | OffensEval-TR (2020) | [paper/coltekin.md](paper/coltekin.md) | 125 | | Amharic News Text classification Dataset (2021) | [paper/azime.md](paper/azime.md) | 126 | 127 | ## Some common choices of `model_name` 128 | 129 | | Your data | model\_type | model\_name | 130 | | ----------------- | ----------- | ---------------------------------- | 131 | | English tweets | bertweet | vinai/bertweet-base | 132 | | Lightweight | mobilebert | google/mobilebert-uncased | 133 | | | distilbert | distilbert-base-uncased | 134 | | Long Text | longformer | allenai/longformer-base-4096 | 135 | | | bigbird | google/bigbird-roberta-base | 136 | | English (General) | bert | bert-base-uncased | 137 | | | bert | bert-base-cased | 138 | | | electra | google/electra-small-discriminator | 139 | | | roberta | roberta-base | 140 | | Multilingual | xlm | xlm-mlm-17-1280 | 141 | | | xml | xlm-mlm-100-1280 | 142 | | | bert | bert-base-multilingual-cased | 143 | | | xlmroberta | xlm-roberta-base | 144 | | | xlmroberta | xlm-roberta-large | 145 | 146 | # References 147 | 148 | 1. Theocharis, Y., Barberá, P., Fazekas, Z., & Popa, S. A. (2020). The 149 | dynamics of political incivility on Twitter. Sage Open, 10(2), 150 | 2158244020919447. 151 | 2. Nguyen, D. Q., Vu, T., & Nguyen, A. T. (2020). BERTweet: A 152 | pre-trained language model for English Tweets. arXiv preprint 153 | arXiv:2005.10200. 154 | 155 | ----- 156 | 157 | 1. Yes, I totally made up the meaningless long name. Actually, it is 158 | the German name of the *Sesame Street* character [Count von 159 | Count](https://de.wikipedia.org/wiki/Sesamstra%C3%9Fe#Graf_Zahl), 160 | meaning “Count (the noble title) Number”. And it seems to be so that 161 | it is compulsory to name absolutely everything related to 162 | Transformers after Seasame Street characters. 
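As a final illustration of the `model_name` table above, a multilingual
model can be dropped into the same workflow. The following is only a
sketch (it assumes `setup_grafzahl()` or `use_nonconda()` has already
been run) and reuses the bundled Dutch `ecosent` data documented in the
package:

``` r
library(quanteda)
input <- corpus(ecosent, text_field = "headline")
training_corpus <- corpus_subset(input, !gold)

## xlm-roberta-base is the package default; model_type is inferred
model <- grafzahl(x = training_corpus, y = "value",
                  model_name = "xlm-roberta-base")
predict(model, newdata = corpus_subset(input, gold))
```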
163 | -------------------------------------------------------------------------------- /_pkgdown.yml: -------------------------------------------------------------------------------- 1 | url: https://gesistsa.github.io/grafzahl/ 2 | template: 3 | package: tsatemplate 4 | 5 | -------------------------------------------------------------------------------- /_quarto.yml: -------------------------------------------------------------------------------- 1 | project: 2 | title: grafzahl 3 | type: default 4 | render: 5 | - methodshub.qmd 6 | -------------------------------------------------------------------------------- /data/ecosent.rda: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gesistsa/grafzahl/0cb7d375e1bb64f6c2fedc5bd4ba896f81f68572/data/ecosent.rda -------------------------------------------------------------------------------- /data/smallunciviltweets.rda: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gesistsa/grafzahl/0cb7d375e1bb64f6c2fedc5bd4ba896f81f68572/data/smallunciviltweets.rda -------------------------------------------------------------------------------- /data/supported_model_types.rda: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gesistsa/grafzahl/0cb7d375e1bb64f6c2fedc5bd4ba896f81f68572/data/supported_model_types.rda -------------------------------------------------------------------------------- /data/unciviltweets.rda: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gesistsa/grafzahl/0cb7d375e1bb64f6c2fedc5bd4ba896f81f68572/data/unciviltweets.rda -------------------------------------------------------------------------------- /inst/CITATION: -------------------------------------------------------------------------------- 1 | citHeader("To cite grafzahl in publications use:") 2 | 3 | 4 | bibentry(bibtype = "article", 5 | title = "grafzahl: fine-tuning Transformers for text data from within R.", 6 | journal = "Computational Communication Research", 7 | author = c(person("Chung-hong", "Chan")), 8 | doi = "10.5117/CCR2023.1.003.CHAN", 9 | volume = 5, 10 | number = 1, 11 | pages = "76-84", 12 | year = 2023 13 | ) 14 | -------------------------------------------------------------------------------- /inst/grafzahl.yml: -------------------------------------------------------------------------------- 1 | name: grafzahl 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - anaconda 6 | - defaults 7 | dependencies: 8 | - python 9 | - pip 10 | - pytorch>=1.6+cpuonly 11 | - pip: 12 | - pandas 13 | - tqdm 14 | - simpletransformers 15 | - emoji 16 | - transformers 17 | - scipy 18 | -------------------------------------------------------------------------------- /inst/grafzahl_gpu.yml: -------------------------------------------------------------------------------- 1 | name: grafzahl 2 | channels: 3 | - conda-forge 4 | - anaconda 5 | - defaults 6 | dependencies: 7 | - python 8 | - pip 9 | - pip: 10 | - pandas 11 | - tqdm 12 | - emoji 13 | -------------------------------------------------------------------------------- /inst/python/st.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from simpletransformers.classification import ClassificationModel, ClassificationArgs 3 | import os 4 | import pandas as pd 5 | import random 6 | 7 | 8 | from sklearn.model_selection 
import train_test_split 9 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 10 | 11 | def py_detect_cuda(): 12 | return(torch.cuda.is_available()) 13 | 14 | def py_train(data, num_labels, output_dir, best_model_dir, cache_dir, model_type, model_name, num_train_epochs, train_size, manual_seed, regression, verbose): 15 | if py_detect_cuda(): 16 | torch.cuda.empty_cache() 17 | random.seed(manual_seed) 18 | data.columns = ["text", "labels"] 19 | if train_size < 1 and num_train_epochs <= 4: 20 | num_train_epochs = 20 21 | mod_args = { 22 | 'reprocess_input_data': True, 23 | 'overwrite_output_dir': True, 24 | 'fp16': True, 25 | 'output_dir': output_dir, 26 | "best_model_dir": best_model_dir, 27 | "cache_dir": cache_dir, 28 | "use_multiprocessing": False, 29 | "use_multiprocessing_for_evaluation": False, 30 | "save_steps": -1, 31 | "save_eval_checkpoints": False, 32 | "save_model_every_epoch": False, 33 | "num_train_epochs": num_train_epochs, 34 | "manual_seed": manual_seed, 35 | "silent": not verbose 36 | } 37 | if regression: 38 | mod_args["regression"] = True 39 | if train_size < 1: 40 | mod_args["use_early_stopping"] = True 41 | mod_args["evaluate_during_training"] = True 42 | mod_args["early_stopping_delta"] = 0.02 43 | mod_args["early_stopping_patience"] = 1 44 | if regression: 45 | mod_args["early_stopping_metric"] = "eval_loss" 46 | else: 47 | ## Classification 48 | mod_args["early_stopping_metric"] = "mcc" 49 | mod_args["early_stopping_metric_minimize"] = False 50 | model = ClassificationModel(model_type = model_type, model_name = model_name, num_labels = num_labels, use_cuda = py_detect_cuda(), args = mod_args) 51 | data_train, data_cv = train_test_split(data, train_size = train_size, stratify = data['labels'].values.tolist()) 52 | model.train_model(data_train, eval_df = data_cv, verbose = verbose, show_running_loss = verbose) 53 | else: 54 | model = ClassificationModel(model_type = model_type, model_name = model_name, num_labels = num_labels, use_cuda = py_detect_cuda(), args = mod_args) 55 | model.train_model(data, verbose = verbose, show_running_loss = verbose) 56 | 57 | def py_predict(to_predict, model_type, output_dir, return_raw, use_cuda): 58 | if len(to_predict) == 1: 59 | to_predict = [to_predict] 60 | model = ClassificationModel(model_type, output_dir, use_cuda = use_cuda, args = { 61 | 'reprocess_input_data': True, 62 | "use_multiprocessing": False, 63 | "fp16": True, 64 | "use_multiprocessing_for_evaluation": False 65 | }) 66 | predictions, raw_outputs = model.predict(to_predict) 67 | if return_raw: 68 | return(raw_outputs) 69 | else: 70 | return(predictions) 71 | -------------------------------------------------------------------------------- /man/detect_cuda.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/setup.R 3 | \name{detect_conda} 4 | \alias{detect_conda} 5 | \alias{detect_cuda} 6 | \title{Detecting Miniconda And Cuda} 7 | \usage{ 8 | detect_conda() 9 | 10 | detect_cuda() 11 | } 12 | \value{ 13 | boolean, whether the system is available. 14 | } 15 | \description{ 16 | These functions detects miniconda and cuda. 17 | } 18 | \details{ 19 | \code{detect_conda} conducts a test to check whether 1) a miniconda installation and 2) the grafzahl miniconda environment exist. 20 | 21 | \code{detect_cuda} checks whether cuda is available. If \code{setup_grafzahl} was executed with \code{cuda} being \code{FALSE}, this function will return \code{FALSE}. 
Even if \code{setup_grafzahl} was executed with \code{cuda} being \code{TRUE} but with any factor that can't enable cuda (e.g. no Nvidia GPU, the environment was incorrectly created), this function will also return \code{FALSE}. 22 | } 23 | -------------------------------------------------------------------------------- /man/ecosent.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/misc.R 3 | \docType{data} 4 | \name{ecosent} 5 | \alias{ecosent} 6 | \title{A Corpus Of Dutch News Headlines} 7 | \format{ 8 | An object of class \code{data.frame} with 6322 rows and 4 columns. 9 | } 10 | \usage{ 11 | ecosent 12 | } 13 | \description{ 14 | This is a dataset from the paper "The Validity of Sentiment Analysis: Comparing Manual Annotation, Crowd-Coding, Dictionary Approaches, and Machine Learning Algorithms." 15 | The data frame contains four columns: id (identifier), headline (the actual text data), value (sentiment: 0 Neutral, +1 Positive, -1 Negative), gold (whether or not this row is "gold standard", i.e. test set). The data is available from Wouter van Atteveldt's Github. \url{https://github.com/vanatteveldt/ecosent} 16 | } 17 | \references{ 18 | Van Atteveldt, W., Van der Velden, M. A., & Boukes, M. (2021). The validity of sentiment analysis: Comparing manual annotation, crowd-coding, dictionary approaches, and machine learning algorithms. Communication Methods and Measures, 15(2), 121-140. 19 | } 20 | \keyword{datasets} 21 | -------------------------------------------------------------------------------- /man/figures/grafzahl_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gesistsa/grafzahl/0cb7d375e1bb64f6c2fedc5bd4ba896f81f68572/man/figures/grafzahl_logo.png -------------------------------------------------------------------------------- /man/get_amharic_data.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/misc.R 3 | \name{get_amharic_data} 4 | \alias{get_amharic_data} 5 | \title{Download The Amharic News Text Classification Dataset} 6 | \usage{ 7 | get_amharic_data() 8 | } 9 | \value{ 10 | A named list of two corpora: training and test 11 | } 12 | \description{ 13 | This function downloads the training and test sets of the Amharic News Text Classification Dataset from Hugging Face. 14 | } 15 | \references{ 16 | Azime, Israel Abebe, and Nebil Mohammed (2021). "An Amharic News Text classification Dataset." 
arXiv preprint arXiv:2103.05639 17 | } 18 | -------------------------------------------------------------------------------- /man/grafzahl.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/train.R 3 | \name{grafzahl} 4 | \alias{grafzahl} 5 | \alias{grafzahl.default} 6 | \alias{grafzahl.corpus} 7 | \alias{textmodel_transformer} 8 | \alias{grafzahl.character} 9 | \title{Fine tune a pretrained Transformer model for texts} 10 | \usage{ 11 | grafzahl( 12 | x, 13 | y = NULL, 14 | model_name = "xlm-roberta-base", 15 | regression = FALSE, 16 | output_dir, 17 | cuda = detect_cuda(), 18 | num_train_epochs = 4, 19 | train_size = 0.8, 20 | args = NULL, 21 | cleanup = TRUE, 22 | model_type = NULL, 23 | manual_seed = floor(runif(1, min = 1, max = 721831)), 24 | verbose = TRUE 25 | ) 26 | 27 | \method{grafzahl}{default}( 28 | x, 29 | y = NULL, 30 | model_name = "xlm-roberta-base", 31 | regression = FALSE, 32 | output_dir, 33 | cuda = detect_cuda(), 34 | num_train_epochs = 4, 35 | train_size = 0.8, 36 | args = NULL, 37 | cleanup = TRUE, 38 | model_type = NULL, 39 | manual_seed = floor(runif(1, min = 1, max = 721831)), 40 | verbose = TRUE 41 | ) 42 | 43 | \method{grafzahl}{corpus}( 44 | x, 45 | y = NULL, 46 | model_name = "xlm-roberta-base", 47 | regression = FALSE, 48 | output_dir, 49 | cuda = detect_cuda(), 50 | num_train_epochs = 4, 51 | train_size = 0.8, 52 | args = NULL, 53 | cleanup = TRUE, 54 | model_type = NULL, 55 | manual_seed = floor(runif(1, min = 1, max = 721831)), 56 | verbose = TRUE 57 | ) 58 | 59 | textmodel_transformer(...) 60 | 61 | \method{grafzahl}{character}( 62 | x, 63 | y = NULL, 64 | model_name = "xlmroberta", 65 | regression = FALSE, 66 | output_dir, 67 | cuda = detect_cuda(), 68 | num_train_epochs = 4, 69 | train_size = 0.8, 70 | args = NULL, 71 | cleanup = TRUE, 72 | model_type = NULL, 73 | manual_seed = floor(runif(1, min = 1, max = 721831)), 74 | verbose = TRUE 75 | ) 76 | } 77 | \arguments{ 78 | \item{x}{the \link{corpus} or character vector of texts on which the model will be trained. Depending on \code{train_size}, some texts will be used for cross-validation.} 79 | 80 | \item{y}{training labels. It can either be a single string indicating which \link{docvars} of the \link{corpus} is the training labels; a vector of training labels in either character or factor; or \code{NULL} if the \link{corpus} contains exactly one column in \link{docvars} and that column is the training labels. If \code{x} is a character vector, \code{y} must be a vector of the same length.} 81 | 82 | \item{model_name}{string indicates either 1) the model name on Hugging Face website; 2) the local path of the model} 83 | 84 | \item{regression}{logical, if \code{TRUE}, the task is regression, classification otherwise.} 85 | 86 | \item{output_dir}{string, location of the output model. If missing, the model will be stored in a temporary directory. Important: Please note that if this directory exists, it will be overwritten.} 87 | 88 | \item{cuda}{logical, whether to use CUDA, default to \code{\link[=detect_cuda]{detect_cuda()}}.} 89 | 90 | \item{num_train_epochs}{numeric, if \code{train_size} is not exactly 1.0, the maximum number of epochs to try in the "early stop" regime will be this number times 5 (i.e. 4 * 5 = 20 by default). 
If \code{train_size} is exactly 1.0, the number of epochs is exactly that.} 91 | 92 | \item{train_size}{numeric, proportion of data in \code{x} and \code{y} to be used actually for training. The rest will be used for cross validation.} 93 | 94 | \item{args}{list, additionally parameters to be used in the underlying simple transformers} 95 | 96 | \item{cleanup}{logical, if \code{TRUE}, the \code{runs} directory generated will be removed when the training is done} 97 | 98 | \item{model_type}{a string indicating model_type of the input model. If \code{NULL}, it will be inferred from \code{model_name}. Supported model types are available in \link{supported_model_types}.} 99 | 100 | \item{manual_seed}{numeric, random seed} 101 | 102 | \item{verbose}{logical, if \code{TRUE}, debug messages will be displayed} 103 | 104 | \item{...}{paramters pass to \code{\link[=grafzahl]{grafzahl()}}} 105 | } 106 | \value{ 107 | a \code{grafzahl} S3 object with the following items 108 | \item{call}{original function call} 109 | \item{input_data}{input_data for the underlying python function} 110 | \item{output_dir}{location of the output model} 111 | \item{model_type}{model type} 112 | \item{model_name}{model name} 113 | \item{regression}{whether or not it is a regression model} 114 | \item{levels}{factor levels of y} 115 | \item{manual_seed}{random seed} 116 | \item{meta}{metadata about the current session} 117 | } 118 | \description{ 119 | Fine tune (or train) a pretrained Transformer model for your given training labelled data \code{x} and \code{y}. The prediction task can be classification (if \code{regression} is \code{FALSE}, default) or regression (if \code{regression} is \code{TRUE}). 120 | } 121 | \examples{ 122 | if (detect_conda() && interactive()) { 123 | library(quanteda) 124 | set.seed(20190721) 125 | ## Using the default cross validation method 126 | model1 <- grafzahl(unciviltweets, model_type = "bertweet", model_name = "vinai/bertweet-base") 127 | predict(model1) 128 | 129 | ## Using LIME 130 | input <- corpus(ecosent, text_field = "headline") 131 | training_corpus <- corpus_subset(input, !gold) 132 | model2 <- grafzahl(x = training_corpus, 133 | y = "value", 134 | model_name = "GroNLP/bert-base-dutch-cased") 135 | test_corpus <- corpus_subset(input, gold) 136 | predicted_sentiment <- predict(model2, test_corpus) 137 | require(lime) 138 | sentences <- c("Dijsselbloem pessimistisch over snelle stappen Grieken", 139 | "Aandelenbeurzen zetten koersopmars voort") 140 | explainer <- lime(training_corpus, model2) 141 | explanations <- explain(sentences, explainer, n_labels = 1, 142 | n_features = 2) 143 | plot_text_explanations(explanations) 144 | } 145 | } 146 | \seealso{ 147 | \code{\link[=predict.grafzahl]{predict.grafzahl()}} 148 | } 149 | -------------------------------------------------------------------------------- /man/hydrate.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/train.R 3 | \name{hydrate} 4 | \alias{hydrate} 5 | \title{Create a grafzahl S3 object from the output_dir} 6 | \usage{ 7 | hydrate(output_dir, model_type = NULL, regression = FALSE) 8 | } 9 | \arguments{ 10 | \item{output_dir}{string, location of the output model. If missing, the model will be stored in a temporary directory. Important: Please note that if this directory exists, it will be overwritten.} 11 | 12 | \item{model_type}{a string indicating model_type of the input model. 
If \code{NULL}, it will be inferred from \code{model_name}. Supported model types are available in \link{supported_model_types}.} 13 | 14 | \item{regression}{logical, if \code{TRUE}, the task is regression, classification otherwise.} 15 | } 16 | \value{ 17 | a \code{grafzahl} S3 object with the following items 18 | \item{call}{original function call} 19 | \item{input_data}{input_data for the underlying python function} 20 | \item{output_dir}{location of the output model} 21 | \item{model_type}{model type} 22 | \item{model_name}{model name} 23 | \item{regression}{whether or not it is a regression model} 24 | \item{levels}{factor levels of y} 25 | \item{manual_seed}{random seed} 26 | \item{meta}{metadata about the current session} 27 | } 28 | \description{ 29 | Create a grafzahl S3 object from the output_dir 30 | } 31 | -------------------------------------------------------------------------------- /man/predict.grafzahl.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/train.R 3 | \name{predict.grafzahl} 4 | \alias{predict.grafzahl} 5 | \title{Prediction from a fine-tuned grafzahl object} 6 | \usage{ 7 | \method{predict}{grafzahl}(object, newdata, cuda = detect_cuda(), return_raw = FALSE, ...) 8 | } 9 | \arguments{ 10 | \item{object}{an S3 object trained with \code{\link[=grafzahl]{grafzahl()}}} 11 | 12 | \item{newdata}{a \link{corpus} or a character vector of texts on which prediction should be made.} 13 | 14 | \item{cuda}{logical, whether to use CUDA, default to \code{\link[=detect_cuda]{detect_cuda()}}.} 15 | 16 | \item{return_raw}{logical, if \code{TRUE}, return a matrix of logits; a vector of class prediction otherwise} 17 | 18 | \item{...}{not used} 19 | } 20 | \value{ 21 | a vector of class prediction or a matrix of logits 22 | } 23 | \description{ 24 | Make prediction from a fine-tuned grafzahl object. 25 | } 26 | -------------------------------------------------------------------------------- /man/setup_grafzahl.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/setup.R 3 | \name{setup_grafzahl} 4 | \alias{setup_grafzahl} 5 | \title{Setup grafzahl} 6 | \usage{ 7 | setup_grafzahl(cuda = FALSE, force = FALSE, cuda_version = "11.3") 8 | } 9 | \arguments{ 10 | \item{cuda}{logical, if \code{TRUE}, indicate whether a CUDA-enabled environment is wanted.} 11 | 12 | \item{force}{logical, if \code{TRUE}, delete previous environment (if exists) and create a new environment} 13 | 14 | \item{cuda_version}{character, indicate CUDA version, ignore if \code{cuda} is \code{FALSE}} 15 | } 16 | \value{ 17 | TRUE (invisibly) if installation is successful. 18 | } 19 | \description{ 20 | Install a self-contained miniconda environment with all Python components (PyTorch, Transformers, Simpletransformers, etc) which grafzahl required. The default location is "~/.local/share/r-miniconda/envs/grafzahl_condaenv" (suffix "_cuda" is added if \code{cuda} is \code{TRUE}). 21 | On Linux or Mac and if miniconda is not found, this function will also install miniconda. The path can be changed by the environment variable \code{GRAFZAHL_MINICONDA_PATH} 22 | } 23 | \examples{ 24 | # setup an environment with cuda enabled. 
25 | if (detect_conda() && interactive()) { 26 | setup_grafzahl(cuda = TRUE) 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /man/supported_model_types.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/misc.R 3 | \docType{data} 4 | \name{supported_model_types} 5 | \alias{supported_model_types} 6 | \title{Supported model types} 7 | \format{ 8 | An object of class \code{character} of length 23. 9 | } 10 | \usage{ 11 | supported_model_types 12 | } 13 | \description{ 14 | A vector of all supported model types. 15 | } 16 | \keyword{datasets} 17 | -------------------------------------------------------------------------------- /man/unciviltweets.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/misc.R 3 | \docType{data} 4 | \name{unciviltweets} 5 | \alias{unciviltweets} 6 | \title{A Corpus Of Tweets With Incivility Labels} 7 | \format{ 8 | An object of class \code{corpus} (inherits from \code{character}) of length 19982. 9 | } 10 | \usage{ 11 | unciviltweets 12 | } 13 | \description{ 14 | This is a dataset from the paper "The Dynamics of Political Incivility on Twitter". The tweets were posted by Members of Congress elected to the 115th Congress (2017–2018). It is important to note that not all the incivility labels were coded by humans. The majority of the labels were coded by the Google Perspective API. All mentions were removed. The dataset is available from Pablo Barbera's Github. \url{https://github.com/pablobarbera/incivility-sage-open} 15 | } 16 | \references{ 17 | Theocharis, Y., Barberá, P., Fazekas, Z., & Popa, S. A. (2020). The dynamics of political incivility on Twitter. Sage Open, 10(2), 2158244020919447. 18 | } 19 | \keyword{datasets} 20 | -------------------------------------------------------------------------------- /man/use_nonconda.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/setup.R 3 | \name{use_nonconda} 4 | \alias{use_nonconda} 5 | \title{Set up grafzahl to be used on Google Colab or similar environments} 6 | \usage{ 7 | use_nonconda(install = TRUE, check = TRUE, verbose = TRUE) 8 | } 9 | \arguments{ 10 | \item{install}{logical, whether to install the required Python packages} 11 | 12 | \item{check}{logical, whether to perform a check after the setup. The check displays 1) whether CUDA can be detected, 2) whether 13 | the non-conda mode has been activated, i.e. whether the option 'grafzahl.nonconda' is \code{TRUE}.} 14 | 15 | \item{verbose, }{logical, whether to display messages} 16 | } 17 | \value{ 18 | TRUE (invisibly) if installation is successful. 19 | } 20 | \description{ 21 | Set up grafzahl to be used on Google Colab or similar environments. This function is also useful if you do not 22 | want to use conda on a local machine, e.g. you have already configured the required Python packages. 
23 | } 24 | \examples{ 25 | # A typical use case for Google Colab 26 | if (interactive()) { 27 | use_nonconda() 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /methodshub.qmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: grafzahl - Supervised Machine Learning for Textual Data Using Transformers and 'Quanteda' 3 | format: 4 | html: 5 | embed-resources: true 6 | gfm: default 7 | --- 8 | 9 | ## Description 10 | 11 | 12 | 13 | Duct tape the 'quanteda' ecosystem (Benoit et al., 2018) [doi:10.21105/joss.00774](https://doi.org/10.21105/joss.00774) to modern Transformer-based text classification models (Wolf et al., 2020) [doi:10.18653/v1/2020.emnlp-demos.6](https://doi.org/10.18653/v1/2020.emnlp-demos.6), in order to facilitate supervised machine learning for textual data. This package mimics the behaviors of 'quanteda.textmodels' and provides a function to setup the 'Python' environment to use the pretrained models from 'Hugging Face' . More information: [doi:10.5117/CCR2023.1.003.CHAN](https://doi.org/10.5117/CCR2023.1.003.CHAN). 14 | 15 | ## Keywords 16 | 17 | 18 | 19 | * Deep Learning 20 | * Supervised machine learning 21 | * Text analysis 22 | 23 | ## Science Usecase(s) 24 | 25 | 26 | 27 | 28 | 29 | 30 | This package can be used in any typical supervised machine learning usecase involving text data. In the software paper ([Chan et al.](https://doi.org/10.5117/CCR2023.1.003.CHAN)), several cases were presented, e.g. Prediction of incivility based on tweets ([Theocharis et al., 2020](https://doi.org/10.1177/2158244020919447)). 31 | 32 | ## Repository structure 33 | 34 | This repository follows [the standard structure of an R package](https://cran.r-project.org/doc/FAQ/R-exts.html#Package-structure). 35 | 36 | ## Environment Setup 37 | 38 | With R installed: 39 | 40 | ```r 41 | install.packages("grafzahl") 42 | ``` 43 | 44 | ## Hardware Requirements (Optional) 45 | 46 | A GPU that supports CUDA is optional. 47 | 48 | ## Input Data 49 | 50 | 51 | 52 | 53 | 54 | 55 | `grafzahl` accepts text data as either character vector or the `corpus` data structure of `quanteda`. 56 | 57 | ## Sample Input and Output Data 58 | 59 | 60 | 61 | 62 | A sample input is a `corpus`. This is an example dataset: 63 | 64 | ```{r} 65 | #| message: false 66 | library(grafzahl) 67 | library(quanteda) 68 | unciviltweets 69 | ``` 70 | 71 | The output is an S3 object. 72 | 73 | ## How to Use 74 | 75 | Before training, please setup the conda environment. 76 | 77 | ```r 78 | setup_grafzahl(cuda = TRUE) ## if you have GPU(s) 79 | ``` 80 | 81 | A typical way to train and make predictions. 82 | 83 | ```r 84 | input <- corpus(ecosent, text_field = "headline") 85 | training_corpus <- corpus_subset(input, !gold) 86 | ``` 87 | 88 | Use the `x` (text data), `y` (label, in this case a [`docvar`](https://quanteda.io/reference/docvars.html)), and `model_name` (Model name, from Hugging Face) parameters to control how the supervised machine learning model is trained. 89 | 90 | ```r 91 | model2 <- grafzahl(x = training_corpus, 92 | y = "value", 93 | model_name = "GroNLP/bert-base-dutch-cased") 94 | test_corpus <- corpus_subset(input, gold) 95 | predict(model2, test_corpus) 96 | ``` 97 | 98 | ## Contact Details 99 | 100 | Maintainer: Chung-hong Chan 101 | 102 | Issue Tracker: [https://github.com/gesistsa/grafzahl/issues](https://github.com/gesistsa/grafzahl/issues) 103 | 104 | ## Publication 105 | 106 | 1. Chan, C. H. (2023). 
grafzahl: fine-tuning Transformers for text data from within R. Computational Communication Research, 5(1), 76. 107 | 108 | 109 | 110 | 111 | 112 | 113 | -------------------------------------------------------------------------------- /paper/.here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gesistsa/grafzahl/0cb7d375e1bb64f6c2fedc5bd4ba896f81f68572/paper/.here -------------------------------------------------------------------------------- /paper/aup_logo.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gesistsa/grafzahl/0cb7d375e1bb64f6c2fedc5bd4ba896f81f68572/paper/aup_logo.pdf -------------------------------------------------------------------------------- /paper/azime.md: -------------------------------------------------------------------------------- 1 | Azime & Mohammed. (2021) 2 | ================ 3 | 4 | Amharic is a Semitic language mainly spoken in Ethiopia. After Arabic, 5 | Amharic is the second most-spoken Semitic language. Unlike many Semitic 6 | languages using the *abjad* (consonant-only) writing system, Amharic is 7 | written in a unique alphasyllabary writing system called *Ge’ez*. 8 | Syntactically, Amharic is also different from many Germanic languages 9 | for its SOV (subject-object-verb) word order [^1]. It is in general 10 | considered to be a “low resource” language. Only recently, the first 11 | news classification dataset called “Amharic News Text classification 12 | Dataset” is available \[[link](https://arxiv.org/abs/2103.05639)\]. 13 | 14 | Amharic News Text classification Dataset contains 50,706 news articles 15 | curated from various Amharic websites. The original paper reports the 16 | baseline out-of-sample accuracy of 62.2% using Naive Bayes. The released 17 | data also contains the training-and-test split [^2]. It is a much bigger 18 | dataset than the two previous examples (training set: 41,185 articles, 19 | test set: 10,287). News articles were annotated into the following 20 | categories (originally written in *Ge’ez*, transliterated to Latin 21 | characters here): *hāgeri āk’efi zēna* (national news), *mezinanya* 22 | (entertainment), *siporiti* (sport), *bīzinesi* (business), *’alemi 23 | āk’efi zēna* (international news), and *poletīka* (politics). 24 | 25 | In this example, the AfriBERTa is used as the pretrained model. The 26 | AfriBERTa model was trained with a small corpus of 11 African languages. 27 | 28 | # Obtain the data 29 | 30 | ``` r 31 | download.file("https://huggingface.co/datasets/israel/Amharic-News-Text-classification-Dataset/resolve/main/train.csv", destfile = here::here("am_train.csv")) 32 | 33 | download.file("https://huggingface.co/datasets/israel/Amharic-News-Text-classification-Dataset/resolve/main/test.csv", destfile = here::here("am_test.csv")) 34 | ``` 35 | 36 | # Preserve a model 37 | 38 | We can directly use the AfriBERTa model online. We can also preserve a 39 | local copy of a pretrained model. As all models on Hugging Face are 40 | stored as a Git repository, one can use git to clone the model locally. 41 | A cloned model usually takes around 1G of local storage. 
42 | 43 | ``` bash 44 | ## make sure you have installed git lfs 45 | ## https://git-lfs.github.com/ 46 | git lfs install 47 | git clone https://huggingface.co/castorini/afriberta_base localafriberta 48 | ``` 49 | 50 | # Train a classifer using the preserved AfriBERTa model 51 | 52 | ``` r 53 | require(quanteda) 54 | #> Loading required package: quanteda 55 | #> Package version: 3.2.4 56 | #> Unicode version: 13.0 57 | #> ICU version: 66.1 58 | #> Parallel computing: 16 of 16 threads used. 59 | #> See https://quanteda.io for tutorials and examples. 60 | require(readtext) 61 | #> Loading required package: readtext 62 | require(grafzahl) 63 | #> Loading required package: grafzahl 64 | input <- readtext::readtext(here::here("am_train.csv"), text_field = "article") %>% 65 | corpus %>% corpus_subset(category != "") 66 | ``` 67 | 68 | ``` r 69 | model <- grafzahl(x = input, 70 | y = "category", 71 | model_name = here::here("localafriberta")) 72 | ``` 73 | 74 | # Evaluate 75 | 76 | Accuracy: 84% 77 | 78 | ``` r 79 | testset_corpus <- readtext::readtext(here::here("am_test.csv"), text_field = "article") %>% corpus %>% corpus_subset(category != "") 80 | 81 | preds <- predict(model, newdata = testset_corpus) 82 | caret::confusionMatrix(table(preds, docvars(testset_corpus, "category"))) 83 | #> Confusion Matrix and Statistics 84 | #> 85 | #> 86 | #> preds ሀገር አቀፍ ዜና መዝናኛ ስፖርት ቢዝነስ ዓለም አቀፍ ዜና ፖለቲካ 87 | #> ሀገር አቀፍ ዜና 3434 23 57 194 88 234 88 | #> መዝናኛ 25 100 0 7 4 0 89 | #> ስፖርት 33 4 2052 4 7 4 90 | #> ቢዝነስ 130 2 0 454 7 104 91 | #> ዓለም አቀፍ ዜና 115 4 4 5 1136 17 92 | #> ፖለቲካ 351 2 2 120 82 1492 93 | #> 94 | #> Overall Statistics 95 | #> 96 | #> Accuracy : 0.8418 97 | #> 95% CI : (0.8346, 0.8488) 98 | #> No Information Rate : 0.397 99 | #> P-Value [Acc > NIR] : < 2.2e-16 100 | #> 101 | #> Kappa : 0.7878 102 | #> 103 | #> Mcnemar's Test P-Value : 1.794e-15 104 | #> 105 | #> Statistics by Class: 106 | #> 107 | #> Class: ሀገር አቀፍ ዜና Class: መዝናኛ Class: ስፖርት Class: ቢዝነስ 108 | #> Sensitivity 0.8400 0.740741 0.9702 0.57908 109 | #> Specificity 0.9040 0.996457 0.9936 0.97446 110 | #> Pos Pred Value 0.8521 0.735294 0.9753 0.65136 111 | #> Neg Pred Value 0.8956 0.996555 0.9923 0.96562 112 | #> Prevalence 0.3970 0.013111 0.2054 0.07614 113 | #> Detection Rate 0.3335 0.009712 0.1993 0.04409 114 | #> Detection Prevalence 0.3914 0.013208 0.2043 0.06769 115 | #> Balanced Accuracy 0.8720 0.868599 0.9819 0.77677 116 | #> Class: ዓለም አቀፍ ዜና Class: ፖለቲካ 117 | #> Sensitivity 0.8580 0.8061 118 | #> Specificity 0.9838 0.9341 119 | #> Pos Pred Value 0.8868 0.7282 120 | #> Neg Pred Value 0.9791 0.9565 121 | #> Prevalence 0.1286 0.1798 122 | #> Detection Rate 0.1103 0.1449 123 | #> Detection Prevalence 0.1244 0.1990 124 | #> Balanced Accuracy 0.9209 0.8701 125 | ``` 126 | 127 | [^1]: Actually, majority of the languages are SOV, while SVO (many 128 | Germanic languages) are slightly less common. 129 | 130 | [^2]: https://huggingface.co/datasets/israel/Amharic-News-Text-classification-Dataset/tree/main 131 | -------------------------------------------------------------------------------- /paper/azime.qmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: Azime & Mohammed. 
(2021) 3 | format: gfm 4 | --- 5 | 6 | ```{r} 7 | #| include: false 8 | knitr::opts_chunk$set( 9 | collapse = TRUE, 10 | comment = "#>", 11 | fig.path = "img/", 12 | out.width = "100%" 13 | ) 14 | model <- readRDS(here::here("amharic.RDS")) 15 | model$output_dir <- here::here("amharic") 16 | ``` 17 | 18 | Amharic is a Semitic language mainly spoken in Ethiopia. After Arabic, Amharic is the second most-spoken Semitic language. Unlike many Semitic languages using the *abjad* (consonant-only) writing system, Amharic is written in a unique alphasyllabary writing system called *Ge'ez*. Syntactically, Amharic is also different from many Germanic languages for its SOV (subject-object-verb) word order [^SOV]. It is in general considered to be a "low resource" language. Only recently, the first news classification dataset called "Amharic News Text classification Dataset" is available [[link](https://arxiv.org/abs/2103.05639)]. 19 | 20 | Amharic News Text classification Dataset contains 50,706 news articles curated from various Amharic websites. The original paper reports the baseline out-of-sample accuracy of 62.2\% using Naive Bayes. The released data also contains the training-and-test split [^Amharic]. It is a much bigger dataset than the two previous examples (training set: 41,185 articles, test set: 10,287). News articles were annotated into the following categories (originally written in *Ge'ez*, transliterated to Latin characters here): *hāgeri āk’efi zēna* (national news), *mezinanya* (entertainment), *siporiti* (sport), *bīzinesi* (business), *‘alemi āk’efi zēna* (international news), and *poletīka* (politics). 21 | 22 | In this example, the AfriBERTa is used as the pretrained model. The AfriBERTa model was trained with a small corpus of 11 African languages. 23 | 24 | # Obtain the data 25 | 26 | ```{r} 27 | #| eval: false 28 | download.file("https://huggingface.co/datasets/israel/Amharic-News-Text-classification-Dataset/resolve/main/train.csv", destfile = here::here("am_train.csv")) 29 | 30 | download.file("https://huggingface.co/datasets/israel/Amharic-News-Text-classification-Dataset/resolve/main/test.csv", destfile = here::here("am_test.csv")) 31 | ``` 32 | 33 | # Preserve a model 34 | 35 | We can directly use the AfriBERTa model online. We can also preserve a local copy of a pretrained model. As all models on Hugging Face are stored as a Git repository, one can use git to clone the model locally. A cloned model usually takes around 1G of local storage. 
36 | 37 | 38 | ```bash 39 | ## make sure you have installed git lfs 40 | ## https://git-lfs.github.com/ 41 | git lfs install 42 | git clone https://huggingface.co/castorini/afriberta_base localafriberta 43 | ``` 44 | 45 | # Train a classifer using the preserved AfriBERTa model 46 | 47 | ```{r} 48 | require(quanteda) 49 | require(readtext) 50 | require(grafzahl) 51 | input <- readtext::readtext(here::here("am_train.csv"), text_field = "article") %>% 52 | corpus %>% corpus_subset(category != "") 53 | ``` 54 | 55 | ```{r} 56 | #| eval: false 57 | model <- grafzahl(x = input, 58 | y = "category", 59 | model_name = here::here("localafriberta")) 60 | ``` 61 | 62 | # Evaluate 63 | 64 | Accuracy: 84\% 65 | 66 | ```{r} 67 | testset_corpus <- readtext::readtext(here::here("am_test.csv"), text_field = "article") %>% corpus %>% corpus_subset(category != "") 68 | 69 | preds <- predict(model, newdata = testset_corpus) 70 | caret::confusionMatrix(table(preds, docvars(testset_corpus, "category"))) 71 | ``` 72 | 73 | 74 | [^SOV]: Actually, majority of the languages are SOV, while SVO (many Germanic languages) are slightly less common. 75 | 76 | [^Amharic]: https://huggingface.co/datasets/israel/Amharic-News-Text-classification-Dataset/tree/main 77 | -------------------------------------------------------------------------------- /paper/ccr.cls: -------------------------------------------------------------------------------- 1 | % Template for CCR Articles (very WIP) 2 | % 2023 Wouter van Atteveldt, Damian Trilling, Chung-hong Chan 3 | % Version 0.02 4 | % Please see https://github.com/vanatteveldt/ccr-quarto for the latest version of this file 5 | 6 | \ProvidesClass{ccr}[2023-02-03 v0.01] 7 | \NeedsTeXFormat{LaTeX2e} 8 | 9 | \LoadClass[twoside]{article} 10 | 11 | %%%%%%%%%%%%% OVRERALL PAGE LAYOUT %%%%%%%%%%%%%%%%%%% 12 | \RequirePackage[papersize={6.53in,9.61in}, 13 | left=1.1in,right=1.1in,top=1in,textheight=7.25in]{geometry} 14 | 15 | \RequirePackage{graphicx} 16 | 17 | \RequirePackage[english]{babel} 18 | \renewcommand*\oldstylenums[1]{\textosf{#1}} 19 | \RequirePackage{ifxetex} 20 | \ifxetex 21 | \RequirePackage[protrusion=true,final,babel]{microtype} 22 | \RequirePackage{fontspec} 23 | \setmainfont{erewhon} 24 | \setsansfont{Noto Sans} 25 | \usepackage{unicode-math} 26 | \defaultfontfeatures{ Scale=MatchLowercase, Ligatures=TeX } 27 | \setmathfont{TeX Gyre Termes Math} 28 | 29 | \else 30 | \RequirePackage[protrusion=true,expansion=true,final,babel]{microtype} 31 | \RequirePackage[sfdefault,scaled=.9]{noto} 32 | \RequirePackage[proportional,scaled=1]{erewhon} 33 | \RequirePackage[erewhon,vvarbb,bigdelims]{newtxmath} 34 | \fi 35 | \renewcommand{\floatpagefraction}{.7} 36 | \linespread{1.1} 37 | 38 | %%%%%%%%%%%%%% Information from authors %%%%%%%%%%%%%%%%%%%%% 39 | \RequirePackage{ifthen} 40 | 41 | \def\@shorttitle{} 42 | \newcommand{\shorttitle}[1]{\def\@shorttitle{#1}} 43 | \newcommand{\show@shorttitle}{% 44 | \ifthenelse{\equal{\@shorttitle}{}}% 45 | {\MakeLowercase{\@title}}{\MakeLowercase{\@shorttitle}}% 46 | } 47 | 48 | \def\@abstract{(specify abstract using \textbackslash abstract in preamble)} 49 | \renewcommand{\abstract}[1]{\def\@abstract{#1}} 50 | \newcommand{\show@abstract}{\@abstract} 51 | 52 | \def\@keywords{(specify keywords using \textbackslash keywords in preamble)} 53 | \newcommand{\keywords}[1]{\def\@keywords{#1}} 54 | \newcommand{\show@keywords}{\@keywords} 55 | 56 | \def\@shortauthors{(please specify \textbackslash shortauthors)} 57 | 
\newcommand{\shortauthors}[1]{\def\@shortauthors{#1}} 58 | \newcommand{\show@shortauthors}{\@shortauthors} 59 | 60 | \def\@volume{X} 61 | \newcommand{\volume}[1]{\def\@volume{#1}} 62 | \newcommand{\show@volume}{\@volume} 63 | \def\@pubnumber{Y} 64 | \newcommand{\pubnumber}[1]{\def\@pubnumber{#1}} 65 | \newcommand{\show@pubnumber}{\@pubnumber} 66 | \def\@pubyear{20xx} 67 | \newcommand{\pubyear}[1]{\def\@pubyear{#1}} 68 | \newcommand{\show@pubyear}{\@pubyear} 69 | \def\@doi{10.5117/ccr20xx.xxx.xxxx} 70 | \newcommand{\doi}[1]{\def\@doi{#1}} 71 | \newcommand{\show@doi}{\@doi} 72 | 73 | 74 | \newcommand{\firstpage}[1]{\setcounter{page}{#1}} 75 | %%%%%%%%%%%%%% HEADER / FOOTER %%%%%%%%%%%%%%%%%%%%%%%% 76 | \RequirePackage{lastpage} 77 | \RequirePackage{fancyhdr} 78 | 79 | \newcommand{\smallcaps}[1]{\textsc{\footnotesize #1}} 80 | \newcommand{\smallcapsl}[1]{\MakeLowercase{\smallcaps{#1}}} 81 | \setlength{\headheight}{24pt} 82 | 83 | %\usepackage{showframe} % useful for debugging header / margin 84 | \fancypagestyle{firstpage}{% 85 | \fancyhf{} % clear all six fields 86 | \fancyhead[L]{\includegraphics[height=2em]{aup_logo.pdf}} 87 | \fancyhead[R]{\raisebox{.1em}{\smallcaps{computational communication research 88 | \oldstylenums{\show@volume}.\oldstylenums{\show@pubnumber} (\oldstylenums{\show@pubyear}) 89 | \oldstylenums{\thepage}--\oldstylenums{\pageref{LastPage}}}}% 90 | \\% 91 | \raisebox{.4em}{\scriptsize\MakeUppercase{\sc\url{https://doi.org/\show@doi}}}} 92 | \fancyfoot[RE,LO]{\footnotesize© The author(s). This is an open access article distributed under the \href{https://creativecommons.org/licenses/by/4.0/}{\textsc{cc by} \oldstylenums{4.0} license}} 93 | \fancyfoot[LE,RO]{\smallcaps{\thepage}} 94 | \renewcommand{\headrulewidth}{0pt} 95 | \renewcommand{\footrulewidth}{0pt} 96 | } 97 | \fancypagestyle{followingpage}{% 98 | \fancyhf{} % clear all six fields 99 | \fancyhead[RE]{\smallcaps{\show@shorttitle}} 100 | \fancyhead[LO]{\smallcapsl{Computational Communication Research}} 101 | \fancyfoot[LE,RO]{\smallcaps{\thepage}} 102 | \fancyfoot[LO]{\smallcapsl{\show@shortauthors}} 103 | \fancyfoot[RE]{\smallcaps{vol. \oldstylenums{\show@volume}, no. 
\oldstylenums{\show@pubnumber}, \oldstylenums{\show@pubyear}}} 104 | \renewcommand{\headrulewidth}{0Pt} 105 | \renewcommand{\footrulewidth}{0pt} 106 | } 107 | 108 | \pagestyle{followingpage} 109 | \AtBeginDocument{\thispagestyle{firstpage}} 110 | 111 | 112 | 113 | 114 | %%%%%%%%%%%%%%%%%% Other markup / styling %%%%%%%%%%%%%%%%%%%%% 115 | \RequirePackage{etoolbox} 116 | % Tables 117 | \AtBeginEnvironment{tabularx}{\footnotesize\sffamily} 118 | \AtBeginEnvironment{tabular}{\footnotesize\sffamily} 119 | \renewcommand{\arraystretch}{1.3} 120 | 121 | \RequirePackage{makecell} 122 | \renewcommand\theadfont{ \bfseries} 123 | \renewcommand\theadalign{ll} 124 | \usepackage[font={footnotesize,sf}]{caption} 125 | 126 | % More compact enumerations 127 | \renewcommand{\@listI}{% 128 | \itemsep=0\parsep} 129 | % Bibliography style 130 | 131 | %%%%%%%%%%%%% Author information black magic %%%%%%%%%%%%%%%%% 132 | % Mostly stolen from https://github.com/dan-weiss/apa7-latex-cls-source/blob/62f31e0b2c8c75e260a7690928c745d803333549/apa7/apa7.dtx (LPPL licensed) 133 | 134 | \newcommand*\listauthors{} 135 | \newcommand*\listsuperscripts{} 136 | \newcommand*\listaffiliations{} 137 | 138 | \newcommand*{\authorsnames}[2][]{ 139 | \def\def@multipleauthors{\@multipleauthorsmode} % 140 | \renewcommand*\listauthors{} 141 | \renewcommand*\listsuperscripts{} 142 | \newcounter{NumberOfAuthors} 143 | \newcounter{NumberOfSuperscripts} 144 | \forcsvlist{\stepcounter{NumberOfAuthors}\listadd\listauthors}{#2} 145 | \forcsvlist{\stepcounter{NumberOfSuperscripts}\listadd\listsuperscripts}{#1} 146 | } 147 | 148 | \newcommand*{\authorsaffiliations}[1]{ 149 | \def\def@multipleaffils{\@multipleaffilsmode} % 150 | \renewcommand*\listaffiliations{} 151 | \newcounter{NumberOfAffiliations} 152 | \forcsvlist{\stepcounter{NumberOfAffiliations}\listadd\listaffiliations}{#1} 153 | } 154 | 155 | \catcode`\|=3 156 | 157 | \def\looptwo#1#2{% 158 | \edef\tmp{\noexpand\xtwo% 159 | \unexpanded\expandafter{#1}\relax % no added delimiter here 160 | \unexpanded\expandafter{#2}\relax % no added delimiter here 161 | } \tmp% 162 | }% 163 | 164 | \def\xtwo#1|#2\relax#3|#4\relax{% 165 | \dotwo{#1}{#3}% 166 | \def\tmp{#2}% 167 | \ifx\empty\tmp% 168 | \expandafter\@gobble% 169 | \else% 170 | \expandafter\@firstofone% 171 | \fi% 172 | {\xtwo#2\relax#4\relax}% 173 | }% 174 | 175 | \catcode`\|=12 176 | \newcommand*{\dotwo}[2]{} 177 | 178 | 179 | %%%%%%%%%%%%%% Title page %%%%%%%%%%%%%%%%%%%%% 180 | \RequirePackage{calc} 181 | \RequirePackage{pstricks} 182 | \RequirePackage{hyphenat} 183 | 184 | \renewcommand{\maketitle}{% 185 | \sloppy 186 | \noindent{\fontsize{14}{13.5}\fontseries{b}\selectfont\raggedright\nohyphens{\@title}} 187 | \vspace{.5in} 188 | 189 | \fussy 190 | \renewcommand*{\dotwo}[2]{% 191 | \noindent##1\\ 192 | \noindent\textit{##2} 193 | \vspace{1em}\par 194 | } 195 | \looptwo\listauthors\listaffiliations 196 | 197 | \vspace{1em} 198 | \parbox{\textwidth-\parindent-\parindent}{\small 199 | \textbf{Abstract}\\\show@abstract 200 | \\\\ 201 | \textbf{Keywords:} \show@keywords 202 | } 203 | } 204 | 205 | %%%%%%%%%%%%%% Sections %%%%%%%%%%%%%%%%%%%%% 206 | \renewcommand{\section}{% 207 | \@startsection{section}{1}{0pt}% 208 | {-1.5ex plus -1ex minus -.2ex}{2ex}% 209 | {\fontsize{14}{13.5}\fontseries{b}\selectfont}% 210 | } 211 | \setcounter{secnumdepth}{0} 212 | 213 | \renewcommand{\appendixname}{Appendix} 214 | 215 | % From: https://tex.stackexchange.com/a/160850 216 | \def\@seccntformat#1{\@ifundefined{#1@cntformat}% 217 | {\csname 
the#1\endcsname\quad} % default 218 | {\csname #1@cntformat\endcsname}% enable individual control 219 | } 220 | \let\oldappendix\appendix %% save current definition of \appendix 221 | \renewcommand\appendix{% 222 | \oldappendix 223 | % Force sections to start on new page 224 | \let\oldsection\section 225 | \renewcommand{\section}{\clearpage\oldsection} 226 | % Renew sections to 'Appendix A ' 227 | \setcounter{secnumdepth}{1} % start numbering sections again 228 | \newcommand{\section@cntformat}{\appendixname~\thesection\quad} 229 | % Count figures/tables from A1 230 | \setcounter{table}{0} 231 | \renewcommand{\thetable}{A\arabic{table}} 232 | \setcounter{figure}{0} 233 | \renewcommand{\thefigure}{A\arabic{figure}} 234 | } 235 | \makeatother 236 | 237 | 238 | %%%%%%%%%%%%%% Reference handling %%%%%%%%%%%%%%%%%%%%% 239 | \RequirePackage{csquotes} 240 | \RequirePackage[style=apa,sortcites=true,sorting=nyt,backend=biber]{biblatex} 241 | 242 | \DeclareLanguageMapping{american}{american-apa} 243 | 244 | \renewcommand{\bibfont}{\small} 245 | \setlength{\bibhang}{\parindent} 246 | -------------------------------------------------------------------------------- /paper/coltekin.md: -------------------------------------------------------------------------------- 1 | Çöltekin (2020) 2 | ================ 3 | 4 | OffensEval-TR 2020 is a [shared 5 | task](https://sites.google.com/site/offensevalsharedtask/results-and-paper-submission). 6 | The Turkish social media dataset by Çöltekin (2020) 7 | \[[link](https://aclanthology.org/2020.lrec-1.758)\] is available here. 8 | 9 | In this subtask, Turkish tweets, 31,756 and 3,528 in the training and 10 | test sets respectively, were coded as “Offensive” or “Not Offensive”. 11 | The state-of-the-art performance by the world’s best NLP experts for 12 | this subtask is 82.58% (Macro F1). Of course, it is quite impossible for 13 | this R package with default settings to obtain this performance. But it 14 | would be interesting to see how close this package could get to that 15 | performance. 16 | 17 | ## Obtaining the data 18 | 19 | ``` r 20 | url <- "https://coltekin.github.io/offensive-turkish/offenseval2020-turkish.zip" 21 | temp <- tempfile(fileext = ".zip") 22 | download.file(url, temp) 23 | unzip(temp, exdir = here::here("paper")) 24 | ``` 25 | 26 | ## Create the training corpus 27 | 28 | ``` r 29 | require(quanteda) 30 | #> Loading required package: quanteda 31 | #> Package version: 3.2.4 32 | #> Unicode version: 13.0 33 | #> ICU version: 66.1 34 | #> Parallel computing: 16 of 16 threads used. 35 | #> See https://quanteda.io for tutorials and examples. 36 | require(readtext) 37 | #> Loading required package: readtext 38 | input <- readtext::readtext(here::here("offenseval2020-turkish/offenseval-tr-training-v1/offenseval-tr-training-v1.tsv"), text_field = "tweet", quote = "") %>% corpus 39 | ``` 40 | 41 | ## Train a classifier 42 | 43 | The model is based on the BERTurk model by the *Bayerische 44 | Staatsbibliothek* [^1]. 
45 | 46 | ``` r 47 | set.seed(721) 48 | model <- grafzahl(x = input, 49 | y = "subtask_a", 50 | model_type = "bert", 51 | model_name = "dbmdz/bert-base-turkish-cased", 52 | output_dir = here::here("turkmodel")) 53 | saveRDS(model, here::here("turkmodel.RDS")) 54 | ``` 55 | 56 | ## Create the test corpus 57 | 58 | ``` r 59 | test <- rio::import(here::here("offenseval2020-turkish/offenseval-tr-testset-v1/offenseval-tr-testset-v1.tsv"), quote = "") 60 | 61 | labels <- rio::import(here::here("offenseval2020-turkish/offenseval-tr-testset-v1/offenseval-tr-labela-v1.tsv"), quote = "") 62 | 63 | colnames(labels)[1] <- "id" 64 | colnames(labels)[2] <- "subtask_a" 65 | 66 | require(dplyr) 67 | #> Loading required package: dplyr 68 | #> 69 | #> Attaching package: 'dplyr' 70 | #> The following objects are masked from 'package:stats': 71 | #> 72 | #> filter, lag 73 | #> The following objects are masked from 'package:base': 74 | #> 75 | #> intersect, setdiff, setequal, union 76 | test %>% left_join(labels) -> test 77 | #> Joining, by = "id" 78 | 79 | 80 | corpus(test, text_field = "tweet") -> test_corpus 81 | ``` 82 | 83 | ## Calculation of Macro-F1 84 | 85 | ``` r 86 | preds <- predict(model, newdata = test_corpus) 87 | 88 | sum(caret::confusionMatrix(table(preds, docvars(test_corpus, "subtask_a")), mode = "prec_recall", positive = "OFF")$byClass["F1"], caret::confusionMatrix(table(preds, docvars(test_corpus, "subtask_a")), mode = "prec_recall", positive = "NOT")$byClass["F1"]) / 2 89 | #> [1] 0.7972064 90 | ``` 91 | 92 | Not bad (vs the SOTA: 82.58%)! 93 | 94 | [^1]: https://huggingface.co/dbmdz/bert-base-turkish-cased 95 | -------------------------------------------------------------------------------- /paper/coltekin.qmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: Çöltekin (2020) 3 | format: gfm 4 | --- 5 | 6 | ```{r} 7 | #| include: false 8 | knitr::opts_chunk$set( 9 | collapse = TRUE, 10 | comment = "#>", 11 | fig.path = "img/", 12 | out.width = "100%" 13 | ) 14 | require(grafzahl) 15 | model <- readRDS(here::here("turkmodel.RDS")) 16 | ``` 17 | 18 | 19 | OffensEval-TR 2020 is a [shared task](https://sites.google.com/site/offensevalsharedtask/results-and-paper-submission). The Turkish social media dataset by 20 | Çöltekin (2020) [[link](https://aclanthology.org/2020.lrec-1.758)] is available here. 21 | 22 | In this subtask, Turkish tweets, 31,756 and 3,528 in the training and test sets respectively, were coded as "Offensive" or "Not Offensive". The state-of-the-art performance by the world's best NLP experts for this subtask is 82.58\% (Macro F1). Of course, it is quite impossible for this R package with default settings to obtain this performance. But it would be interesting to see how close this package could get to that performance. 23 | 24 | ## Obtaining the data 25 | 26 | ```{r} 27 | #| eval: false 28 | url <- "https://coltekin.github.io/offensive-turkish/offenseval2020-turkish.zip" 29 | temp <- tempfile(fileext = ".zip") 30 | download.file(url, temp) 31 | unzip(temp, exdir = here::here("paper")) 32 | ``` 33 | 34 | ## Create the training corpus 35 | 36 | ```{r} 37 | require(quanteda) 38 | require(readtext) 39 | input <- readtext::readtext(here::here("offenseval2020-turkish/offenseval-tr-training-v1/offenseval-tr-training-v1.tsv"), text_field = "tweet", quote = "") %>% corpus 40 | ``` 41 | 42 | ## Train a classifier 43 | 44 | The model is based on the BERTurk model by the *Bayerische Staatsbibliothek* [^BERTurk]. 
45 | 46 | [^BERTurk]: https://huggingface.co/dbmdz/bert-base-turkish-cased 47 | 48 | ```{r} 49 | #| eval: false 50 | set.seed(721) 51 | model <- grafzahl(x = input, 52 | y = "subtask_a", 53 | model_type = "bert", 54 | model_name = "dbmdz/bert-base-turkish-cased", 55 | output_dir = here::here("turkmodel")) 56 | saveRDS(model, here::here("turkmodel.RDS")) 57 | ``` 58 | 59 | ## Create the test corpus 60 | 61 | ```{r} 62 | test <- rio::import(here::here("offenseval2020-turkish/offenseval-tr-testset-v1/offenseval-tr-testset-v1.tsv"), quote = "") 63 | 64 | labels <- rio::import(here::here("offenseval2020-turkish/offenseval-tr-testset-v1/offenseval-tr-labela-v1.tsv"), quote = "") 65 | 66 | colnames(labels)[1] <- "id" 67 | colnames(labels)[2] <- "subtask_a" 68 | 69 | require(dplyr) 70 | test %>% left_join(labels) -> test 71 | 72 | 73 | corpus(test, text_field = "tweet") -> test_corpus 74 | ``` 75 | 76 | ## Calculation of Macro-F1 77 | 78 | ```{r} 79 | preds <- predict(model, newdata = test_corpus) 80 | 81 | sum(caret::confusionMatrix(table(preds, docvars(test_corpus, "subtask_a")), mode = "prec_recall", positive = "OFF")$byClass["F1"], caret::confusionMatrix(table(preds, docvars(test_corpus, "subtask_a")), mode = "prec_recall", positive = "NOT")$byClass["F1"]) / 2 82 | ``` 83 | 84 | Not bad (vs the SOTA: 82.58\%)! 85 | -------------------------------------------------------------------------------- /paper/dobbrick.md: -------------------------------------------------------------------------------- 1 | Dobbrick et al. (2021) 2 | ================ 3 | 4 | The following is to analyse the same data used in Dobbrick et al. (2021) 5 | “Enhancing Theory-Informed Dictionary Approaches with “Glass-box” 6 | Machine Learning: The Case of Integrative Complexity in Social Media 7 | Comments” \[[doi](https://doi.org/10.1080/19312458.2021.1999913)\]. The 8 | data is available from [osf](https://doi.org/10.17605/OSF.IO/578MG). 9 | 10 | Dobbrick et al. present a study of comparing various methods to learn 11 | and predict integrative complexity of English and German online user 12 | comments from Facebook, Twitter, and news website comment sections. 13 | According to the original paper, “Integrative complexity is a 14 | psychological measure that researchers increasingly implement to assess 15 | the argumentative quality of public debate contributions.” (p. 3) 16 | Comments were coded with a standard coding scheme into a 7-point Likert 17 | scale from 1 (lowest complexity) to 7 (highest complexity). The paper 18 | presents two approaches: Assumption-based approach and Shotgun approach. 19 | The Shotgun approach is similar to the traditional full-text machine 20 | learning approach. Dobbrick et al. report that CNN with word embeddings 21 | provides the best out-of-sample performance. The original paper reports 22 | 10-fold cross-validation. Root mean squared error (RMSE) of .75 23 | (English) and .84 (German) were reported. It is also important to note 24 | that Dobbrick et al. trained an individual model for each language. The 25 | human annotated data and the original training-and-test split are 26 | publicly available. In total, there are 4,800 annotated comments. 27 | 28 | Please note that this is a regression example. 
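To make the evaluation metric concrete, here is a tiny worked example of
how RMSE is computed. This is only an illustration; the numbers are made
up and are not taken from the Dobbrick et al. data.

``` r
## RMSE on a made-up vector of gold scores and predictions
gold <- c(1, 3, 4, 7)
pred <- c(2, 3, 3, 6)
sqrt(mean((gold - pred)^2))
#> [1] 0.8660254
```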
29 | 30 | # Obtain the data from OSF 31 | 32 | ``` r 33 | temp <- tempdir() 34 | require(osfr) 35 | osf_retrieve_file("https://osf.io/m6a9n") %>% 36 | osf_download(path = temp) 37 | ## sanity check 38 | file.exists(file.path(temp, "glassbox.zip")) 39 | ## goldstandard_ic_en.csv 40 | ## goldstandard_ic_de.csv 41 | unzip(file.path(temp, "glassbox.zip"), files = c("glassbox/data/goldstandard_ic_de.csv", "glassbox/data/goldstandard_ic_en.csv"), exdir = temp) 42 | 43 | ## sanity check 44 | file.exists(file.path(temp, "glassbox/data/goldstandard_ic_de.csv")) 45 | file.exists(file.path(temp, "glassbox/data/goldstandard_ic_en.csv")) 46 | 47 | file.copy(file.path(temp, "glassbox/data/goldstandard_ic_de.csv"), here::here("goldstandard_ic_de.csv")) 48 | file.copy(file.path(temp, "glassbox/data/goldstandard_ic_en.csv"), here::here("goldstandard_ic_en.csv")) 49 | ``` 50 | 51 | # Read the data 52 | 53 | ``` r 54 | require(quanteda) 55 | #> Loading required package: quanteda 56 | #> Package version: 3.2.4 57 | #> Unicode version: 13.0 58 | #> ICU version: 66.1 59 | #> Parallel computing: 16 of 16 threads used. 60 | #> See https://quanteda.io for tutorials and examples. 61 | require(grafzahl) 62 | #> Loading required package: grafzahl 63 | require(rio) 64 | #> Loading required package: rio 65 | #> 66 | #> Attaching package: 'rio' 67 | #> The following object is masked from 'package:quanteda': 68 | #> 69 | #> convert 70 | require(dplyr) 71 | #> Loading required package: dplyr 72 | #> 73 | #> Attaching package: 'dplyr' 74 | #> The following objects are masked from 'package:stats': 75 | #> 76 | #> filter, lag 77 | #> The following objects are masked from 'package:base': 78 | #> 79 | #> intersect, setdiff, setequal, union 80 | 81 | ## The csv file is actually the "European" variant; can't use readtext 82 | ## https://github.com/quanteda/readtext/issues/170 83 | en_data <- rio::import(here::here("goldstandard_ic_en.csv")) %>% tibble::as_tibble() %>% filter(!is_redacted & main_language == "en" & WC > 0) 84 | en_data %>% pull(post) %>% corpus -> en_corpus 85 | docvars(en_corpus, "icom") <- as.numeric(en_data$ic_ordinal) 86 | docnames(en_corpus) <- paste0("en", seq_along(en_corpus)) 87 | 88 | de_data <- rio::import(here::here("goldstandard_ic_de.csv")) %>% tibble::as_tibble() %>% filter(!is_redacted & main_language == "de" & WC > 0) 89 | de_data %>% pull(post) %>% corpus -> de_corpus 90 | docvars(de_corpus, "icom") <- as.numeric(de_data$ic_ordinal) 91 | 92 | docnames(de_corpus) <- paste0("de", seq_along(de_corpus)) 93 | ``` 94 | 95 | # Generate the 10-fold cross validation setup 96 | 97 | ``` r 98 | set.seed(2020) 99 | en_ranid <- sample(1:10, size = ndoc(en_corpus), replace = TRUE) 100 | 101 | set.seed(2020) 102 | de_ranid <- sample(1:10, size = ndoc(de_corpus), replace = TRUE) 103 | ``` 104 | 105 | # Do the 10-fold cross validation 106 | 107 | The distil mBERT is used in this case to model the two languages 108 | simultaneously, whereas the original paper modeled the two languages 109 | separately. 110 | 111 | Also, we can see here that the quanteda `corpus` objects can be combined 112 | by `+`. In this example, as there is only one `docvar` in the corpus 113 | (“icom”), that docvar is used as the label to be predicted. 
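As a minimal illustration of this behaviour (with made-up toy texts, not
the Dobbrick et al. data), combining two corpora with `+` concatenates
the documents and carries over the shared `docvar`:

``` r
require(quanteda)
a <- corpus(c("first comment", "second comment"),
            docnames = c("en1", "en2"),
            docvars = data.frame(icom = c(2, 5)))
b <- corpus("dritter Kommentar",
            docnames = "de1",
            docvars = data.frame(icom = 4))
combined <- a + b
ndoc(combined)
#> [1] 3
docvars(combined, "icom")
#> [1] 2 5 4
```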
114 | 115 | ``` r 116 | rmse <- function(x, y) { 117 | sqrt(mean((x - y)^2)) 118 | } 119 | 120 | fold <- function(i, en_corpus, de_corpus, en_ranid, de_ranid) { 121 | mod <- grafzahl(en_corpus[en_ranid != i] + de_corpus[de_ranid != i], 122 | model_name = "distilbert-base-multilingual-cased", 123 | output_dir = here::here(paste0("dobbrick", i)), 124 | regression = TRUE) 125 | pred_en <- predict(mod, en_corpus[en_ranid == i]) 126 | pred_de <- predict(mod, de_corpus[de_ranid == i]) 127 | x1 <- cor(docvars(en_corpus, "icom")[en_ranid == i], pred_en) 128 | x2 <- rmse(docvars(en_corpus, "icom")[en_ranid == i], pred_en) 129 | x3 <- cor(docvars(de_corpus, "icom")[de_ranid == i], pred_de) 130 | x4 <- rmse(docvars(de_corpus, "icom")[de_ranid == i], pred_de) 131 | return(tibble(i = i, en_cor = x1, en_rmse = x2, de_cor = x3, de_rmse = x4)) 132 | } 133 | 134 | res <- purrr::map_dfr(1:10, fold, en_corpus = en_corpus, de_corpus = de_corpus, en_ranid = en_ranid, de_ranid = de_ranid) 135 | ``` 136 | 137 | # Results 138 | 139 | Apply the same 10-fold cross-validation setup, the RMSE for English and 140 | German are .67 and .74 respectively (vs. .75 and .84 from the original 141 | paper, lower is better). 142 | 143 | ``` r 144 | res %>% select(-i) %>% summarise_all(mean) 145 | #> # A tibble: 1 × 4 146 | #> en_cor en_rmse de_cor de_rmse 147 | #> <dbl> <dbl> <dbl> <dbl> 148 | #> 1 0.740 0.671 0.726 0.738 149 | ``` 150 | -------------------------------------------------------------------------------- /paper/dobbrick.qmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: Dobbrick et al. (2021) 3 | format: gfm 4 | --- 5 | 6 | ```{r} 7 | #| include: false 8 | knitr::opts_chunk$set( 9 | collapse = TRUE, 10 | comment = "#>", 11 | fig.path = "img/", 12 | out.width = "100%" 13 | ) 14 | res <- readRDS(here::here("julia_cv.RDS")) 15 | ``` 16 | 17 | The following is to analyse the same data used in Dobbrick et al. (2021) "Enhancing Theory-Informed Dictionary Approaches with “Glass-box” Machine Learning: The Case of Integrative Complexity in Social Media Comments" [[doi](https://doi.org/10.1080/19312458.2021.1999913)]. The data is available from [osf](https://doi.org/10.17605/OSF.IO/578MG). 18 | 19 | Dobbrick et al. present a study of comparing various methods to learn and predict integrative complexity of English and German online user comments from Facebook, Twitter, and news website comment sections. According to the original paper, "Integrative complexity is a psychological measure that researchers increasingly implement to assess the argumentative quality of public debate contributions." (p. 3) Comments were coded with a standard coding scheme into a 7-point Likert scale from 1 (lowest complexity) to 7 (highest complexity). The paper presents two approaches: Assumption-based approach and Shotgun approach. The Shotgun approach is similar to the traditional full-text machine learning approach. Dobbrick et al. report that CNN with word embeddings provides the best out-of-sample performance. The original paper reports 10-fold cross-validation. Root mean squared error (RMSE) of .75 (English) and .84 (German) were reported. It is also important to note that Dobbrick et al. trained an individual model for each language. The human annotated data and the original training-and-test split are publicly available. In total, there are 4,800 annotated comments. 20 | 21 | Please note that this is a regression example. 
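To make the evaluation metric concrete, here is a tiny worked example of how RMSE is computed. This is only an illustration; the numbers are made up and are not taken from the Dobbrick et al. data.

```{r}
## RMSE on a made-up vector of gold scores and predictions
gold <- c(1, 3, 4, 7)
pred <- c(2, 3, 3, 6)
sqrt(mean((gold - pred)^2))
```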
22 | 23 | # Obtain the data from OSF 24 | 25 | ```{r} 26 | #| eval: false 27 | temp <- tempdir() 28 | require(osfr) 29 | osf_retrieve_file("https://osf.io/m6a9n") %>% 30 | osf_download(path = temp) 31 | ## sanity check 32 | file.exists(file.path(temp, "glassbox.zip")) 33 | ## goldstandard_ic_en.csv 34 | ## goldstandard_ic_de.csv 35 | unzip(file.path(temp, "glassbox.zip"), files = c("glassbox/data/goldstandard_ic_de.csv", "glassbox/data/goldstandard_ic_en.csv"), exdir = temp) 36 | 37 | ## sanity check 38 | file.exists(file.path(temp, "glassbox/data/goldstandard_ic_de.csv")) 39 | file.exists(file.path(temp, "glassbox/data/goldstandard_ic_en.csv")) 40 | 41 | file.copy(file.path(temp, "glassbox/data/goldstandard_ic_de.csv"), here::here("goldstandard_ic_de.csv")) 42 | file.copy(file.path(temp, "glassbox/data/goldstandard_ic_en.csv"), here::here("goldstandard_ic_en.csv")) 43 | ``` 44 | 45 | # Read the data 46 | 47 | ```{r} 48 | require(quanteda) 49 | require(grafzahl) 50 | require(rio) 51 | require(dplyr) 52 | 53 | ## The csv file is actually the "European" variant; can't use readtext 54 | ## https://github.com/quanteda/readtext/issues/170 55 | en_data <- rio::import(here::here("goldstandard_ic_en.csv")) %>% tibble::as_tibble() %>% filter(!is_redacted & main_language == "en" & WC > 0) 56 | en_data %>% pull(post) %>% corpus -> en_corpus 57 | docvars(en_corpus, "icom") <- as.numeric(en_data$ic_ordinal) 58 | docnames(en_corpus) <- paste0("en", seq_along(en_corpus)) 59 | 60 | de_data <- rio::import(here::here("goldstandard_ic_de.csv")) %>% tibble::as_tibble() %>% filter(!is_redacted & main_language == "de" & WC > 0) 61 | de_data %>% pull(post) %>% corpus -> de_corpus 62 | docvars(de_corpus, "icom") <- as.numeric(de_data$ic_ordinal) 63 | 64 | docnames(de_corpus) <- paste0("de", seq_along(de_corpus)) 65 | ``` 66 | 67 | # Generate the 10-fold cross validation setup 68 | 69 | ```{r} 70 | set.seed(2020) 71 | en_ranid <- sample(1:10, size = ndoc(en_corpus), replace = TRUE) 72 | 73 | set.seed(2020) 74 | de_ranid <- sample(1:10, size = ndoc(de_corpus), replace = TRUE) 75 | ``` 76 | 77 | # Do the 10-fold cross validation 78 | 79 | The distilled mBERT model (`distilbert-base-multilingual-cased`) is used in this case to model the two languages simultaneously, whereas the original paper modeled the two languages separately. 80 | 81 | Also, we can see here that the quanteda `corpus` objects can be combined by `+`. In this example, as there is only one `docvar` in the corpus ("icom"), that docvar is used as the label to be predicted.
82 | 83 | ```{r} 84 | #| eval: false 85 | rmse <- function(x, y) { 86 | sqrt(mean((x - y)^2)) 87 | } 88 | 89 | fold <- function(i, en_corpus, de_corpus, en_ranid, de_ranid) { 90 | mod <- grafzahl(en_corpus[en_ranid != i] + de_corpus[de_ranid != i], 91 | model_name = "distilbert-base-multilingual-cased", 92 | output_dir = here::here(paste0("dobbrick", i)), 93 | regression = TRUE) 94 | pred_en <- predict(mod, en_corpus[en_ranid == i]) 95 | pred_de <- predict(mod, de_corpus[de_ranid == i]) 96 | x1 <- cor(docvars(en_corpus, "icom")[en_ranid == i], pred_en) 97 | x2 <- rmse(docvars(en_corpus, "icom")[en_ranid == i], pred_en) 98 | x3 <- cor(docvars(de_corpus, "icom")[de_ranid == i], pred_de) 99 | x4 <- rmse(docvars(de_corpus, "icom")[de_ranid == i], pred_de) 100 | return(tibble(i = i, en_cor = x1, en_rmse = x2, de_cor = x3, de_rmse = x4)) 101 | } 102 | 103 | res <- purrr::map_dfr(1:10, fold, en_corpus = en_corpus, de_corpus = de_corpus, en_ranid = en_ranid, de_ranid = de_ranid) 104 | ``` 105 | 106 | # Results 107 | 108 | Apply the same 10-fold cross-validation setup, the RMSE for English and German are .67 and .74 respectively (vs. .75 and .84 from the original paper, lower is better). 109 | 110 | ```{r} 111 | res %>% select(-i) %>% summarise_all(mean) 112 | ``` 113 | -------------------------------------------------------------------------------- /paper/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gesistsa/grafzahl/0cb7d375e1bb64f6c2fedc5bd4ba896f81f68572/paper/fig1.png -------------------------------------------------------------------------------- /paper/fig2-1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gesistsa/grafzahl/0cb7d375e1bb64f6c2fedc5bd4ba896f81f68572/paper/fig2-1.pdf -------------------------------------------------------------------------------- /paper/grafzahl_sp.bib: -------------------------------------------------------------------------------- 1 | @article{azime2021amharic, 2 | author = {Azime, Israel Abebe and Mohammed, Nebil}, 3 | journal = {arXiv preprint arXiv:2103.05639}, 4 | title = {{An Amharic News Text classification Dataset}}, 5 | year = {2021}, 6 | } 7 | 8 | @article{theocharis:2020:DPI, 9 | author = {Theocharis, Yannis and Barberá, Pablo and Fazekas, 10 | Zoltán and Popa, Sebastian Adrian}, 11 | doi = {10.1177/2158244020919447}, 12 | issn = {2158-2440}, 13 | journal = {SAGE Open}, 14 | month = {Apr}, 15 | number = {2}, 16 | pages = {215824402091944}, 17 | publisher = {SAGE Publications}, 18 | title = {The Dynamics of Political Incivility on Twitter}, 19 | url = {http://dx.doi.org/10.1177/2158244020919447}, 20 | volume = {10}, 21 | year = {2020}, 22 | } 23 | 24 | @article{joshi2020state, 25 | author = {Joshi, Pratik and Santy, Sebastin and Budhiraja, Amar and Bali, Kalika and Choudhury, Monojit}, 26 | journal = {arXiv preprint arXiv:2004.09095}, 27 | title = {The state and fate of linguistic diversity and inclusion in the NLP world}, 28 | year = {2020}, 29 | } 30 | 31 | @inproceedings{ribeiro2016should, 32 | author = {Ribeiro, Marco Tulio and Singh, Sameer and Guestrin, Carlos}, 33 | booktitle = {Proceedings of the 22nd ACM SIGKDD international conference on knowledge discovery and data mining}, 34 | pages = {1135--1144}, 35 | title = {"Why should i trust you?" 
Explaining the predictions of any classifier}, 36 | year = {2016}, 37 | } 38 | 39 | @inproceedings{ccoltekin2020corpus, 40 | author = {{Ç}{ö}ltekin, {Ç}a{ğ}r{ı}}, 41 | booktitle = {Proceedings of the 12th language resources and evaluation conference}, 42 | pages = {6174--6184}, 43 | title = {{A corpus of Turkish offensive language on social media}}, 44 | year = {2020}, 45 | } 46 | 47 | @manual{quantedatextmodels, 48 | author = {Kenneth Benoit and Kohei Watanabe and Haiyan Wang and Patrick O. Perry and Benjamin Lauderdale and Johannes Gruber and William Lowe}, 49 | note = {R package version 0.9.4}, 50 | title = {quanteda.textmodels: {Scaling Models and Classifiers for Textual Data}}, 51 | url = {https://CRAN.R-project.org/package=quanteda.textmodels}, 52 | year = {2021}, 53 | } 54 | 55 | @inproceedings{ogueji2021small, 56 | author = {Ogueji, Kelechi and Zhu, Yuxin and Lin, Jimmy}, 57 | booktitle = {Proceedings of the 1st Workshop on Multilingual Representation Learning}, 58 | pages = {116--126}, 59 | title = {Small data? no problem! exploring the viability of pretrained multilingual language models for low-resourced languages}, 60 | year = {2021}, 61 | } 62 | 63 | @article{atteveldt:2021:VSA, 64 | author = {{Van Atteveldt}, Wouter and {Van der Velden}, Mariken 65 | A. C. G. and Boukes, Mark}, 66 | doi = {10.1080/19312458.2020.1869198}, 67 | issn = {1931-2466}, 68 | journal = {Communication Methods and Measures}, 69 | month = {Jan}, 70 | pages = {1–20}, 71 | publisher = {Informa UK Limited}, 72 | title = {The Validity of Sentiment Analysis:Comparing Manual 73 | Annotation, Crowd-Coding, Dictionary Approaches, and 74 | Machine Learning Algorithms}, 75 | url = {http://dx.doi.org/10.1080/19312458.2020.1869198}, 76 | year = {2021}, 77 | } 78 | 79 | @article{dobbrick:2021:ETI, 80 | author = {Dobbrick, Timo and Jakob, Julia and Chan, Chung-Hong 81 | and Wessler, Hartmut}, 82 | doi = {10.1080/19312458.2021.1999913}, 83 | issn = {1931-2466}, 84 | journal = {Communication Methods and Measures}, 85 | month = {Nov}, 86 | pages = {1–18}, 87 | publisher = {Informa UK Limited}, 88 | title = {Enhancing Theory-Informed Dictionary Approaches with 89 | “Glass-box” Machine Learning: The Case of 90 | Integrative Complexity in Social Media Comments}, 91 | url = {http://dx.doi.org/10.1080/19312458.2021.1999913}, 92 | year = {2021}, 93 | } 94 | 95 | @article{de2019bertje, 96 | author = {{de Vries}, Wietse and van Cranenburgh, Andreas and Bisazza, Arianna and Caselli, Tommaso and van Noord, Gertjan and Nissim, Malvina}, 97 | journal = {arXiv preprint arXiv:1912.09582}, 98 | title = {{Bertje: A Dutch BERT model}}, 99 | year = {2019}, 100 | } 101 | 102 | @manual{simpletransformers, 103 | author = {Thilina Rajapakse}, 104 | title = {{Simple Transformers}}, 105 | url = {https://simpletransformers.ai/}, 106 | year = {2022}, 107 | } 108 | 109 | @article{kuhn:2008:BPM, 110 | author = {Kuhn, Max}, 111 | doi = {10.18637/jss.v028.i05}, 112 | issn = {1548-7660}, 113 | journal = {Journal of Statistical Software}, 114 | number = {5}, 115 | publisher = {Foundation for Open Access Statistic}, 116 | title = {Building Predictive Models in {R} Using the caret Package}, 117 | url = {http://dx.doi.org/10.18637/jss.v028.i05}, 118 | volume = {28}, 119 | year = {2008}, 120 | } 121 | 122 | @manual{reticulate, 123 | author = {Kevin Ushey and JJ Allaire and Yuan Tang}, 124 | note = {R package version 1.25}, 125 | title = {{reticulate: Interface to 'Python'}}, 126 | url = {https://CRAN.R-project.org/package=reticulate}, 127 | year = {2022}, 
128 | } 129 | 130 | @manual{lime, 131 | author = {Thomas Lin Pedersen and Michaël Benesty}, 132 | note = {R package version 0.5.2}, 133 | title = {lime: Local Interpretable Model-Agnostic Explanations}, 134 | url = {https://CRAN.R-project.org/package=lime}, 135 | year = {2021}, 136 | } 137 | 138 | @inproceedings{wolf-etal-2020-transformers, 139 | address = {Online}, 140 | author = {Thomas Wolf and Lysandre Debut and Victor Sanh and Julien Chaumond and Clement Delangue and Anthony Moi and Pierric Cistac and Tim Rault and Rémi Louf and Morgan Funtowicz and Joe Davison and Sam Shleifer and Patrick von Platen and Clara Ma and Yacine Jernite and Julien Plu and Canwen Xu and Teven Le Scao and Sylvain Gugger and Mariama Drame and Quentin Lhoest and Alexander M. Rush}, 141 | booktitle = {Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing: System Demonstrations}, 142 | month = {October}, 143 | pages = {38--45}, 144 | publisher = {Association for Computational Linguistics}, 145 | title = {Transformers: State-of-the-Art Natural Language Processing}, 146 | url = {https://www.aclweb.org/anthology/2020.emnlp-demos.6}, 147 | year = {2020}, 148 | } 149 | 150 | @article{baden:2021:TGC, 151 | author = {Baden, Christian and Pipal, Christian and 152 | Schoonvelde, Martijn and {van der Velden}, Mariken 153 | A. C. G}, 154 | doi = {10.1080/19312458.2021.2015574}, 155 | issn = {1931-2466}, 156 | journal = {Communication Methods and Measures}, 157 | month = {Dec}, 158 | number = {1}, 159 | pages = {1–18}, 160 | publisher = {Informa UK Limited}, 161 | title = {Three Gaps in Computational Text Analysis Methods 162 | for Social Sciences: A Research Agenda}, 163 | url = {http://dx.doi.org/10.1080/19312458.2021.2015574}, 164 | volume = {16}, 165 | year = {2021}, 166 | } 167 | 168 | -------------------------------------------------------------------------------- /paper/grafzahl_sp.qmd: -------------------------------------------------------------------------------- 1 | --- 2 | title : "grafzahl: fine-tuning Transformers for text data from within R" 3 | shorttitle : "PUT THE R BACK IN TRANSFORMERS" 4 | 5 | author: 6 | - name : "Chung-hong Chan" 7 | affiliation : "GESIS - Leibniz-Institut für Sozialwissenschaften, Germany" 8 | corresponding : yes 9 | address : "Unter Sachsenhausen 6-8, 50667 Köln" 10 | email : "chung-hong.chan@gesis.org" 11 | orcid : "0000-0002-6232-7530" 12 | 13 | authornote: | 14 | Source code and data are available at https://github.com/chainsawriot/grafzahl. 15 | 16 | abstract: | 17 | This paper introduces `grafzahl`, an R package for fine-tuning Transformers for text data from within R. The package is used in this paper to reproduce the analyses in other papers. Very significant improvement in model accuracy over traditional machine learning approaches such as convolutional Neural Network is observed. 18 | 19 | keywords : "machine learning, transformers, R, python, automated content analysis" 20 | volume: 5 21 | pubnumber: 1 22 | pubyear: 2023 23 | firstpage: 76 24 | doi: "10.5117/CCR2023.1.003.CHAN" 25 | bibliography : "grafzahl_sp.bib" 26 | shortauthors: "Chan" 27 | knitr: 28 | opts_chunk: 29 | fig.path: img/ 30 | format: 31 | ccr-pdf: 32 | keep-tex: true 33 | --- 34 | 35 | ## Put the R back in Transformers 36 | 37 | The purpose of this R package, `grafzahl`, is to provide the missing link between R and modern Transformers language models. 
Under the hood, the training part is based on the Python packages `transformers` [@wolf-etal-2020-transformers] and `simpletransformers` [@simpletransformers]. The integration based on `reticulate` [@reticulate] is seamless. With this seamless integration provided, communication researchers can produce the most advanced supervised learning models entirely from within R. This package provides the function `grafzahl()`, which emulates the behaviors of `quanteda.textmodels` [@quantedatextmodels]. [^f] 38 | 39 | [^f]: This package uses reasonable default settings which suit what communication researchers would like to achieve with these models. But the package also provides the freedom for communication researchers to finely adjust the parameters for their specific applications. However, the reanalysis of several examples in communication suggests that even the default settings can generate great improvement over the performance as reported in the original papers. Also, there is almost no need to conduct the cumbersome proprocessing and feature engineering steps, which all examples originally required. 40 | 41 | Two examples [@atteveldt:2021:VSA;@azime2021amharic] are presented here. Additional examples [@theocharis:2020:DPI; @dobbrick:2021:ETI; @ccoltekin2020corpus] are available in the Github repository of the package ([https://github.com/chainsawriot/grafzahl](https://github.com/chainsawriot/grafzahl)). 42 | 43 | # Monolingual classification example 44 | 45 | @atteveldt:2021:VSA compare various methods to analyze the tone of Dutch economic news' headlines. Headlines were coded into three categories: negative (-1), neutral (0), and positive (+1). 46 | 47 | In the original paper, @atteveldt:2021:VSA show that the best method for predicting expert coding, other than coding by student helpers, is convolutional neural network (CNN) with Dutch word embeddings trained on Dutch news. The out-of-sample F1 of .63, .66, and .56 were reported for the three categories. As the data (including the training-and-test split) are publicly available [^wouter] and included in this package (as `ecosent`), I can provide a head-to-head comparison between the reported CNN and the Transformer-based model trained with `grafzahl`. 48 | 49 | [^wouter]: [https://github.com/vanatteveldt/ecosent/](https://github.com/vanatteveldt/ecosent/) 50 | 51 | There are three important columns in the `ecosent` data frame: 52 | 53 | 1. `headline`: the actual text data 54 | 2. `value`: the sentiment 55 | 3. `gold`: whether or not this row is "gold standard", i.e. test set. There are 6,038 and 300 headlines in the training and test set respectively. 56 | 57 | ## Workflow 58 | 59 | ### Step 0: Setup `grafzahl` 60 | 61 | This step only needs to be done once. A miniconda environment needs to be setup. It is in general not recommended to use this package without a CUDA-compatible GPU. Without a CUDA-compatible GPU, the fine-tuning processes below might take days, if not weeks. 62 | 63 | If there is a GPU capable of performing CUDA, run: 64 | 65 | ```r 66 | ## Github version 67 | ## remotes::install_github("chainsawriot/grafzahl") 68 | install.packages("grafzahl") ## CRAN version 69 | require(grafzahl) 70 | setup_grafzahl(cuda = TRUE) # set to FALSE otherwise 71 | detect_cuda() 72 | ``` 73 | 74 | If the automatic setup failed, one can also set up the miniconda environment manually to diagnose what went wrong. 
The complete instructions are available here: [https://github.com/chainsawriot/grafzahl/wiki/setup_grafzahl](https://github.com/chainsawriot/grafzahl/wiki/setup_grafzahl). 75 | 76 | ### Step 1: Get information of the pretrained Transformer 77 | 78 | The first step of training a Transformer-based model is to find a suitable pretrained Transformer model on Hugging Face [^hugg], which would work for the data. As the data are in Dutch, the pretrained Dutch Transformer model BERTje should work [@de2019bertje] [^bertje]. The model name of BERTje is `GroNLP/bert-base-dutch-cased`. It is also important to note the citation information to properly cite the pretrained Transformer model. 79 | 80 | [^hugg]: Hugging Face ([https://huggingface.co](https://huggingface.co)) is an online repository of pretrained machine learning models. 81 | 82 | [^bertje]: Available from [https://huggingface.co/GroNLP/bert-base-dutch-cased](https://huggingface.co/GroNLP/bert-base-dutch-cased) 83 | 84 | ### Step 2: Create the corpus 85 | 86 | The second step is to read the data as a corpus. [^CORPUS] 87 | 88 | [^CORPUS]: This step is not absolutely needed. The package can also work with character vectors. The `corpus` data structure is a better representation of character vector. 89 | 90 | ```r 91 | require(readtext) 92 | require(quanteda) 93 | input <- corpus(ecosent, text_field = "headline") 94 | ``` 95 | 96 | We can manipulate the corpus object using the functions provided by `quanteda`. For example, one can subset the training set using the function `corpus_subset()`. 97 | 98 | ```r 99 | ## selecting documents where the docvar `gold` is FALSE 100 | training_corpus <- corpus_subset(input, !gold) 101 | ``` 102 | 103 | ### Step 3: Fine-tune the model 104 | 105 | With the corpus and model name, the `grafzahl` function is used to fine-tune the model. 106 | 107 | ```r 108 | model <- grafzahl(x = training_corpus, 109 | y = "value", 110 | model_name = "GroNLP/bert-base-dutch-cased") 111 | #### specify `output_dir` 112 | ## model <- grafzahl(x = training_corpus, 113 | ## y = "value", 114 | ## model_name = "GroNLP/bert-base-dutch-cased", 115 | ## output_dir = "~/dutch_model") 116 | ``` 117 | 118 | in general, it is better to specify `output_dir` (where to put the saved model object). By default, it will be a random temporary directory. The R function `set.seed()` can also be used to preserve the random seed for reproducibility. 119 | 120 | On a regular off-the-shelf gaming laptop with a GeForce RTX 3050 Ti GPU and 4G of GPU ram, the process took around 20 minutes. 121 | 122 | ### Step 4: Make prediction 123 | 124 | Following the convention of `lm()` and many other R packages, the object returned by the function `grafzahl()` has a `predict()` S3 method. The following code gets the predicted sentiment of the headlines in the test set. 125 | 126 | ```r 127 | test_corpus <- corpus_subset(input, gold) 128 | predicted_sentiment <- predict(model, test_corpus) 129 | ``` 130 | 131 | ### Step 5: Evaluate Performance 132 | 133 | With the predicted sentiment and the ground truth, there are many ways to evaluate the performance of the fine-tuned model. The simplest way is to construct a confusion matrix using the standard `table()` function. 134 | 135 | ```r 136 | cm <- table(predicted_sentiment, 137 | ground_truth = docvars(test_corpus, "value")) 138 | ``` 139 | 140 | The R package `caret` [@kuhn:2008:BPM] can also be used to calculate standard performance metrics such as Precision, Recall, and F1 [^caret]. 
141 | 142 | ``` r 143 | require(caret) 144 | confusionMatrix(cm, mode = "prec_recall") 145 | ``` 146 | 147 | The out-of-sample F1 measures of the fine-tuned model are .76, .67, and .72 (vs reported .63, .66, and .56). There is great improvement over the CNN model reported by @atteveldt:2021:VSA, although the prediction accuracy for the neutral category is just on par. @atteveldt:2021:VSA also provide the learning curves of CNN and Support Vector Machine (SVM). A learning curve plots the out-of-sample prediction performance as a function of the number of training examples. I repeat the analysis in a similar manner to @atteveldt:2021:VSA and plot the learning curve of the Transformer-based model trained using the default workflow of `grafzahl`. 148 | 149 | @fig-fig1 shows the fine-tuned Transformer model's learning curve alongside CNN's and SVM's [^learningcode]. The fine-tuned model has much better performance than CNN and SVM even with only 500 training examples. Unlike CNN and SVM, the gain in performance appears to plateau after 2,500 training examples. It points to the fact that one does not need to have a lot of training data to fine-tune a Transformer model. 150 | 151 | ```{r} 152 | #| echo: false 153 | #| fig.cap: Learning curve of machine learning algorithms 154 | #| label: fig-fig1 155 | readRDS(here::here("paper/learning.RDS")) 156 | ``` 157 | 158 | [^caret]: The function `confusionMatrix()` can accept the predicted values and ground truth directly, without using `table()` first. But the predicted values and ground truth must be `factor`: `confusionMatrix(as.factor(predicted_sentiment), as.factor(docvars(test_corpus, "value")), mode = "prec_recall")`. 159 | 160 | [^learningcode]: The R code for generating the learning curves is available in the official repository: [https://github.com/chainsawriot/grafzahl](https://github.com/chainsawriot/grafzahl) 161 | 162 | ### Step 6: Explain the prediction 163 | 164 | Unlike "glass-box" machine learning models [@dobbrick:2021:ETI], Transformer-based prediction models are "black-box". There are so many parameters in Transformers (the BERT base model has 110 million parameters) and this complexity makes each individual parameter of the model uninterpretable. 165 | 166 | A reasonable compromise is to make the prediction *explainable* instead. Generating Local Interpretable Model-agnostic Explanations (LIME) [@ribeiro2016should; R implementation by @lime] is a good way to explain how the model makes its prediction. The gist of the method is to perturb the input text data by deleting parts of the sentence. For example: the sentence "I hate this movie" will be perturbed as "I this movie", "I hate movie", "I hate this", "I hate" etc. These perturbed sentences are then fed into the machine learning model to make predictions. The relationship between what gets deleted and the prediction is then studied. The parts whose deletion changes the prediction the most are the ones most *causal* to the original prediction. 167 | 168 | With the trained model, we can explain the predictions made for the following two Dutch headlines: *"Dijsselbloem pessimistisch over snelle stappen Grieken"* (Dijsselbloem [the Former Minister of Finance of the Netherlands] pessimistic about rapid maneuvers from Greeks) and *"Aandelenbeurzen zetten koersopmars voort"* (Stock markets continue to rise). Models trained with `grafzahl` support the R package `lime` directly.
One can get explanations using the following code: 169 | 170 | ```r 171 | require(lime) 172 | sentences <- 173 | c("Dijsselbloem pessimistisch over snelle stappen Grieken", 174 | "Aandelenbeurzen zetten koersopmars voort") 175 | explainer <- lime(training_corpus, model) 176 | explanations <- explain(sentences, explainer, n_labels = 1, 177 | n_features = 3) 178 | plot_text_explanations(explanations) 179 | ``` 180 | 181 | ```{r} 182 | #| label: fig-fig2 183 | #| echo: false 184 | #| fig.cap: Generating Local Interpretable Model-agnostic Explanations (LIME) of two predictions from the trained Dutch sentiment model 185 | knitr::include_graphics("fig1.png", dpi = 150) 186 | ``` 187 | 188 | @fig-fig2 shows that for the sentence *"Dijsselbloem pessimistisch over snelle stappen Grieken"* (classified as negative), the tokens *pessimistisch* and *stappen* are pushing the prediction towards the classified position (negative). But the token *Dijsselbloem* is pulling it away. 189 | 190 | # Non-Germanic example: Amharic 191 | 192 | I want to emphasize that `grafzahl` is not just another package focusing only on English, or Germanic languages such as Dutch. @baden:2021:TGC criticize this tendency. 193 | 194 | Amharic is a Semitic language mainly spoken in Ethiopia and is in general considered to be a "low resource" language [@joshi2020state]. Only recently has the first news classification dataset, the "Amharic News Text classification Dataset", become available [@azime2021amharic]. The dataset contains 50,706 news articles curated from various Amharic websites. The original paper reports the baseline out-of-sample accuracy of 62.2\% using Naive Bayes. The released data also contains the training-and-test split [^Amharic]. In this example, AfriBERTa is used as the pretrained model [@ogueji2021small]. The AfriBERTa model was trained with a small corpus of 11 African languages. Similar to the previous example, the default settings of `grafzahl` are used. 195 | 196 | [^Amharic]: [https://huggingface.co/datasets/israel/Amharic-News-Text-classification-Dataset](https://huggingface.co/datasets/israel/Amharic-News-Text-classification-Dataset) 197 | 198 | ```r 199 | input <- get_amharic_data() 200 | model <- grafzahl(x = input$training, 201 | y = "category", 202 | model_name = "castorini/afriberta_base") 203 | 204 | ## Calculate the out-of-sample accuracy 205 | 206 | preds <- predict(model, newdata = input$test) 207 | caret::confusionMatrix(table(preds, docvars(input$test, 208 | "category"))) 209 | ``` 210 | 211 | ## Results 212 | 213 | The final out-of-sample accuracy is 84.18\%, a solid improvement from the baseline of 62.2\%. 214 | 215 | # Conclusion 216 | 217 | This paper presents the R package `grafzahl` and demonstrates its applicability to communication research by replicating the supervised machine learning components of published studies. 218 | 219 | # Acknowledgments 220 | 221 | I would like to thank 1) Jarvis Labs for providing discounted GPU cloud service for the development of this package; 2) Pablo Barberá (University of Southern California) and Wouter van Atteveldt (VU Amsterdam) for allowing me to include their datasets in this package.
222 | 223 | # References 224 | -------------------------------------------------------------------------------- /paper/grafzahl_sp.rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title : "grafzahl: fine-tuning Transformers for text data from within R" 3 | shorttitle : "PUT THE R BACK IN TRANSFORMERS" 4 | 5 | author: 6 | - name : "Chung-hong Chan" 7 | affiliation : "1" 8 | corresponding : yes 9 | address : "Unter Sachsenhausen 6-8, 50667 Köln" 10 | email : "chung-hong.chan@gesis.org" 11 | 12 | affiliation: 13 | - id : "1" 14 | institution : "GESIS - Leibniz-Institut für Sozialwissenschaften, Germany" 15 | 16 | authornote: | 17 | Source code and data are available at https://github.com/chainsawriot/grafzahl. I would like to thank 1) Jarvis Labs for providing discounted GPU cloud service for the development of this package; 2) Pablo Barberá (University of Southern California) and Wouter van Atteveldt (VU Amsterdam) for allowing me to include their datasets in this package. 18 | 19 | abstract: | 20 | This paper introduces `grafzahl`, an R package for fine-tuning Transformers for text data from within R. The package is used in this paper to reproduce the analyses of published communication papers and of non-Germanic benchmark datasets. Very significant improvement in model accuracy over traditional machine learning approaches such as convolutional neural networks is observed. 21 | 22 | keywords : "machine learning, transformers, R, python, automated content analysis" 23 | wordcount : "2018" 24 | 25 | bibliography : "grafzahl_sp.bib" 26 | 27 | floatsintext : yes 28 | figurelist : no 29 | tablelist : no 30 | figsintext : yes 31 | footnotelist : no 32 | linenumbers : no 33 | mask : no 34 | draft : no 35 | 36 | documentclass : "apa6" 37 | classoption : "man" 38 | output: 39 | papaja::apa6_pdf: 40 | latex_engine: xelatex 41 | --- 42 | 43 | ```{r setup, include = FALSE} 44 | library("papaja") 45 | ``` 46 | 47 | ```{r analysis-preferences} 48 | # Seed for random number generation 49 | set.seed(42) 50 | knitr::opts_chunk$set(cache.extra = knitr::rand_seed) 51 | ``` 52 | 53 | ## Put the R back in Transformers 54 | 55 | The purpose of this R package, `grafzahl`, is to provide the missing link between R and modern Transformers language models. Under the hood, the training part is based on the Python packages `transformers` [@wolf-etal-2020-transformers] and `simpletransformers` [@simpletransformers]. The integration based on `reticulate` [@reticulate] is seamless. With this seamless integration provided, communication researchers can produce the most advanced supervised learning models entirely from within R. This package provides the function `grafzahl()`, which emulates the behaviors of `quanteda.textmodels` [@quantedatextmodels]. [^f] 56 | 57 | [^f]: This package uses reasonable default settings which suit what communication researchers would like to achieve with these models. But the package also provides the freedom for communication researchers to finely adjust the parameters for their specific applications. However, the reanalysis of several examples in communication suggests that even the default settings can generate great improvement over the performance as reported in the original papers. Also, there is almost no need to conduct the cumbersome preprocessing and feature engineering steps, which all examples originally required. 58 | 59 | Two examples [@atteveldt:2021:VSA;@azime2021amharic] are presented here.
Additional examples [@theocharis:2020:DPI; @dobbrick:2021:ETI; @ccoltekin2020corpus] are available in the Github repository of the package (https://github.com/chainsawriot/grafzahl). 60 | 61 | # Monolingual classification example 62 | 63 | @atteveldt:2021:VSA compare various methods to analyze the tone of Dutch economic news' headlines. Headlines were coded into three categories: negative (-1), neutral (0), and positive (+1). 64 | 65 | In the original paper, @atteveldt:2021:VSA show that the best method for predicting expert coding, other than coding by student helpers, is convoluted neural network (CNN) with Dutch word embeddings trained on Dutch news. The out-of-sample F1 of .63, .66, and .56 were reported for the three categories. As the data (including the training-and-test split) are publicly available [^wouter] and included in this package (as `ecosent`), I can provide a head-to-head comparison between the reported CNN and the Transformer-based model trained with `grafzahl`. 66 | 67 | [^wouter]: https://github.com/vanatteveldt/ecosent/ 68 | 69 | There are three important columns in the `ecosent` data frame: 70 | 71 | 1. `headline`: the actual text data 72 | 2. `value`: the sentiment 73 | 3. `gold`: whether or not this row is "gold standard", i.e. test set. There are 6,038 and 300 headlines in the training and test set respectively. 74 | 75 | ## Workflow 76 | 77 | ### Step 0: Setup `grafzahl` 78 | 79 | This step only needs to be done once. A miniconda environment needs to be setup. It is in general not recommended to use this package without a CUDA-compatible GPU. Without a CUDA-compatible GPU, the fine-tuning processes below might take days, if not weeks. 80 | 81 | If there is a GPU capable of performing CUDA, run: 82 | 83 | ```r 84 | require(grafzahl) 85 | setup_grafzahl(cuda = TRUE) # set to FALSE otherwise 86 | detect_cuda() 87 | ``` 88 | 89 | If the automatic setup failed, one can also set up the miniconda environment manually to diagnose what went wrong. The complete instructions are available here: https://github.com/chainsawriot/grafzahl/wiki/setup_grafzahl 90 | 91 | ### Step 1: Get information of the pretrained Transformer 92 | 93 | The first step of training a Transformer-based model is to find a suitable pretrained Transformer model on Hugging Face [^hugg], which would work for the data. As the data are in Dutch, the pretrained Dutch Transformer model BERTje should work [@de2019bertje, available from https://huggingface.co/GroNLP/bert-base-dutch-cased]. The model name of BERTje is `GroNLP/bert-base-dutch-cased`. It is also important to note the citation information to properly cite the pretrained Transformer model. 94 | 95 | [^hugg]: Hugging Face (https://huggingface.co) is an online repository of pretrained machine learning models. 96 | 97 | ### Step 2: Create the corpus 98 | 99 | The second step is to read the data as a corpus. [^CORPUS] 100 | 101 | [^CORPUS]: This step is not absolutely needed. The package can also work with character vectors. The `corpus` data structure is a better representation of character vector. 102 | 103 | ```r 104 | require(readtext) 105 | require(quanteda) 106 | input <- corpus(ecosent, text_field = "headline") 107 | ``` 108 | 109 | We can manipulate the corpus object using the functions provided by `quanteda`. For example, one can subset the training set using the function `corpus_subset()`. 
110 | 111 | ``` r 112 | ## selecting documents where the docvar `gold` is FALSE 113 | training_corpus <- corpus_subset(input, !gold) 114 | ``` 115 | 116 | ### Step 3: Fine-tune the model 117 | 118 | With the corpus and model name, the `grafzahl` function is used to fine-tune the model. 119 | 120 | ``` r 121 | model <- grafzahl(x = training_corpus, 122 | y = "value", 123 | model_name = "GroNLP/bert-base-dutch-cased") 124 | ``` 125 | 126 | In general, it is better to specify `output_dir` (where to put the saved model object). By default, it will be a random temporary directory. The R function `set.seed()` can also be used to preserve the random seed for reproducibility. 127 | 128 | On a regular off-the-shelf gaming laptop with a GeForce RTX 3050 Ti GPU and 4 GB of GPU RAM, the process took around 20 minutes. 129 | 130 | ### Step 4: Make prediction 131 | 132 | Following the convention of `lm()` and many other R packages, the object returned by the function `grafzahl()` has a `predict()` S3 method. The following code gets the predicted sentiment of the headlines in the test set. 133 | 134 | ``` r 135 | test_corpus <- corpus_subset(input, gold) 136 | predicted_sentiment <- predict(model, test_corpus) 137 | ``` 138 | 139 | ### Step 5: Evaluate performance 140 | 141 | With the predicted sentiment and the ground truth, there are many ways to evaluate the performance of the fine-tuned model. The simplest way is to construct a confusion matrix using the standard `table()` function. 142 | 143 | ``` r 144 | cm <- table(predicted_sentiment, 145 | ground_truth = docvars(test_corpus, "value")) 146 | ``` 147 | 148 | The R package `caret` [@kuhn:2008:BPM] can also be used to calculate standard performance metrics such as Precision, Recall, and F1 [^caret]. 149 | 150 | ``` r 151 | require(caret) 152 | confusionMatrix(cm, mode = "prec_recall") 153 | ``` 154 | 155 | The out-of-sample F1 measures of the fine-tuned model are .76, .67, and .72 (vs reported .63, .66, and .56). There is great improvement over the CNN model reported by @atteveldt:2021:VSA, although the prediction accuracy for the neutral category is just on par. @atteveldt:2021:VSA also provide the learning curves of CNN and Support Vector Machine (SVM). A learning curve plots the out-of-sample prediction performance as a function of the number of training examples. I repeat the analysis in a similar manner to @atteveldt:2021:VSA and plot the learning curve of the Transformer-based model trained using the default workflow of `grafzahl`. 156 | 157 | Figure \@ref(fig:fig2) shows the fine-tuned Transformer model's learning curve alongside CNN's and SVM's [^learningcode]. The fine-tuned model has much better performance than CNN and SVM even with only 500 training examples. Unlike CNN and SVM, the gain in performance appears to plateau after 2,500 training examples. It points to the fact that one does not need to have a lot of training data to fine-tune a Transformer model. 158 | 159 | ```{r fig2, fig.cap = "Learning curve of machine learning algorithms"} 160 | readRDS(here::here("learning.RDS")) 161 | ``` 162 | 163 | [^caret]: The function `confusionMatrix()` can accept the predicted values and ground truth directly, without using `table()` first. But the predicted values and ground truth must be `factor`: `confusionMatrix(as.factor(predicted_sentiment), as.factor(docvars(test_corpus, "value")), mode = "prec_recall")`.
164 | 165 | [^learningcode]: The R code for generating the learning curves is available in the official repository: https://github.com/chainsawriot/grafzahl 166 | 167 | ### Step 6: Explain the prediction 168 | 169 | Unlike "glass-box" machine learning models [@dobbrick:2021:ETI], Transformer-based prediction models are "black-box". There are so many parameters in Transformers (the BERT base model has 110 million parameters) and this complexity makes each individual parameter of the model uninterpretable. 170 | 171 | A reasonable compromise is to make the prediction *explainable* instead. Generating Local Interpretable Model-agnostic Explanations (LIME) [@ribeiro2016should; R implementation by @lime] is a good way to explain how the model makes its prediction. The gist of the method is to perturb the input text data by deleting parts of the sentence. For example: the sentence "I hate this movie" will be perturbed as "I this movie", "I hate movie", "I hate this", "I hate" etc. These perturbed sentences are then fed into the machine learning model to make predictions. The relationship between what gets deleted and the prediction is then studied. The parts whose deletion changes the prediction the most are the ones most *causal* to the original prediction. 172 | 173 | With the trained model, we can explain the predictions made for the following two Dutch headlines: *"Dijsselbloem pessimistisch over snelle stappen Grieken"* (Dijsselbloem [the Former Minister of Finance of the Netherlands] pessimistic about rapid maneuvers from Greeks) and *"Aandelenbeurzen zetten koersopmars voort"* (Stock markets continue to rise). Models trained with `grafzahl` support the R package `lime` directly. One can get explanations using the following code: 174 | 175 | ```r 176 | require(lime) 177 | sentences <- c("Dijsselbloem pessimistisch over snelle stappen Grieken", 178 | "Aandelenbeurzen zetten koersopmars voort") 179 | explainer <- lime(training_corpus, model) 180 | explanations <- explain(sentences, explainer, n_labels = 1, 181 | n_features = 3) 182 | plot_text_explanations(explanations) 183 | ``` 184 | 185 | ```{r fig1, echo = FALSE, fig.cap = 'Generating Local Interpretable Model-agnostic Explanations (LIME) of two predictions from the trained Dutch sentiment model', out.width = "100%"} 186 | knitr::include_graphics("fig1.png") 187 | ``` 188 | 189 | Figure \@ref(fig:fig1) shows that for the sentence *"Dijsselbloem pessimistisch over snelle stappen Grieken"* (classified as negative), the tokens *pessimistisch* and *stappen* are pushing the prediction towards the classified position (negative). But the token *Dijsselbloem* is pulling it away. 190 | 191 | # Non-Germanic example: Amharic 192 | 193 | I want to emphasize that `grafzahl` is not just another package focusing only on English, or Germanic languages such as Dutch. @baden:2021:TGC criticize this tendency. 194 | 195 | Amharic is a Semitic language mainly spoken in Ethiopia and is in general considered to be a "low resource" language [@joshi2020state]. Only recently has the first news classification dataset, the "Amharic News Text classification Dataset", become available [@azime2021amharic]. The dataset contains 50,706 news articles curated from various Amharic websites. The original paper reports the baseline out-of-sample accuracy of 62.2\% using Naive Bayes. The released data also contains the training-and-test split [^Amharic]. In this example, AfriBERTa is used as the pretrained model [@ogueji2021small].
The AfriBERTa model was trained with a small corpus of 11 African languages. Similar to the previous example, the default settings of `grafzahl` are used. 196 | 197 | [^Amharic]: https://huggingface.co/datasets/israel/Amharic-News-Text-classification-Dataset 198 | 199 | ```r 200 | input <- get_amharic_data() 201 | model <- grafzahl(x = input$training, 202 | y = "category", 203 | model_name = "castorini/afriberta_base") 204 | 205 | ## Calculate the out-of-sample accuracy 206 | 207 | preds <- predict(model, newdata = input$test) 208 | caret::confusionMatrix(table(preds, docvars(input$test, "category"))) 209 | ``` 210 | 211 | ## Results 212 | 213 | The final out-of-sample accuracy is 84.18\%, a solid improvement from the baseline of 62.2\%. 214 | 215 | # Conclusion 216 | 217 | This paper presents the R packages `grafzahl` and demonstrates its applicability to communication research by replicating the supervised machine learning part of published communication research. 218 | 219 | # References 220 | 221 | \begingroup 222 | \setlength{\parindent}{-0.5in} 223 | \setlength{\leftskip}{0.5in} 224 | 225 | <div id="refs" custom-style="Bibliography"></div> 226 | \endgroup 227 | -------------------------------------------------------------------------------- /paper/img/fig-fig1-1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gesistsa/grafzahl/0cb7d375e1bb64f6c2fedc5bd4ba896f81f68572/paper/img/fig-fig1-1.pdf -------------------------------------------------------------------------------- /paper/img/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gesistsa/grafzahl/0cb7d375e1bb64f6c2fedc5bd4ba896f81f68572/paper/img/fig1.png -------------------------------------------------------------------------------- /paper/img/learning-curve-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gesistsa/grafzahl/0cb7d375e1bb64f6c2fedc5bd4ba896f81f68572/paper/img/learning-curve-1.png -------------------------------------------------------------------------------- /paper/img/theocharis-roc-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gesistsa/grafzahl/0cb7d375e1bb64f6c2fedc5bd4ba896f81f68572/paper/img/theocharis-roc-1.png -------------------------------------------------------------------------------- /paper/misc/explore_lime.R: -------------------------------------------------------------------------------- 1 | require(lime) 2 | devtools::load_all() 3 | library(caret) 4 | library(lime) 5 | 6 | # Split up the data set 7 | iris_test <- iris[1:5, 1:4] 8 | iris_train <- iris[-(1:5), 1:4] 9 | iris_lab <- iris[[5]][-(1:5)] 10 | 11 | # Create Random Forest model on iris data 12 | model <- train(iris_train, iris_lab, method = 'nb') 13 | 14 | predict_model(model, newdata = iris_test, type = "prob") 15 | 16 | # Create an explainer object 17 | explainer <- lime(iris_train, model) 18 | 19 | model <- readRDS(here::here("paper/va_model.RDS")) 20 | 21 | 22 | y <- predict(model, return_raw = TRUE) 23 | 24 | input <- rio::import(here::here("paper/sentences_ml.csv")) %>% tibble::as_tibble() 25 | 26 | table(input$gold) 27 | 28 | training <- input$headline[!input$gold] 29 | y <- input$value[!input$gold] 30 | testset <- input$headline[input$gold] 31 | predict_model(model, testset, type = "rawaa") 32 | 33 | 34 | explainer <- lime(training, model) 35 
| 36 | explanation <- explain(testset[c(2, 4, 284)], explainer, n_labels = 1, n_features = 5) 37 | 38 | -------------------------------------------------------------------------------- /paper/misc/hindi.R: -------------------------------------------------------------------------------- 1 | ## https://arxiv.org/pdf/2101.06949.pdf 2 | require(quanteda) 3 | x <- readr::read_tsv(here::here("paper/hindi-train.csv"), col_names = FALSE) 4 | 5 | input_corpus <- corpus(x$X2) 6 | docvars(input_corpus, "outcome") <- x$X1 7 | 8 | devtools::load_all() 9 | set.seed(721) 10 | ## model <- grafzahl(input_corpus, model_type = "albert", model_name = "ai4bharat/indic-bert", output_dir = here::here("paper/hindi"), num_train_epoch = 20, train_size = 1) 11 | ## saveRDS(model, here::here("paper/hindi_model.RDS")) 12 | model <- readRDS(here::here("paper/hindi_model.RDS")) 13 | x2 <- readr::read_tsv(here::here("paper/hindi-test.csv"), col_names = FALSE) 14 | 15 | test_corpus <- corpus(x2$X2) 16 | docvars(test_corpus, "outcome") <- x2$X1 17 | 18 | pred <- predict(model, newdata = test_corpus) 19 | 20 | u <- union(pred, x2$X1) 21 | t <- table(factor(pred, u), factor(x2$X1, u)) 22 | mean(caret::confusionMatrix(t)$byClass[,11]) 23 | 24 | caret::confusionMatrix(t)$byClass[,11] 25 | -------------------------------------------------------------------------------- /paper/misc/movie.R: -------------------------------------------------------------------------------- 1 | require(grafzahl) 2 | 3 | require(quanteda.textmodels) 4 | require(quanteda) 5 | set.seed(20190721) 6 | model <- grafzahl(x = data_corpus_moviereviews, y = "sentiment", 7 | model_type = "bert", model_name = "bert-base-uncased", 8 | train_size = 1, num_train_epochs = 2) 9 | preds <- predict(model) 10 | table(preds, docvars(data_corpus_moviereviews, "sentiment")) 11 | 12 | mr_vec <- as.vector(data_corpus_moviereviews) 13 | 14 | sent <- docvars(data_corpus_moviereviews, "sentiment") 15 | 16 | grafzahl(x = mr_vec, y = sent, 17 | model_type = "bert", model_name = "bert-base-uncased", 18 | train_size = 1, num_train_epochs = 2) 19 | -------------------------------------------------------------------------------- /paper/misc/multilingual.R: -------------------------------------------------------------------------------- 1 | require(readr) 2 | 3 | all_dirs <- tail(list.dirs(here::here("paper/sentiment")), -1) 4 | 5 | .reading <- function(dir, file = "train.tsv") { 6 | res <- readr::read_tsv(file.path(dir, file), col_names = FALSE) 7 | res$dir <- dir 8 | return(res) 9 | } 10 | 11 | training_set <- suppressMessages(purrr::map_dfr(all_dirs, .reading)) 12 | 13 | test_set <- suppressMessages(purrr::map_dfr(all_dirs, .reading, file = "test.tsv")) 14 | 15 | 16 | devtools::load_all() 17 | 18 | require(quanteda) 19 | training_corpus <- corpus(training_set$X2) 20 | docvars(training_corpus, "sentiment") <- training_set$X1 21 | 22 | set.seed(1212121) 23 | ##model <- grafzahl(training_corpus, model_type = "distilbert", model_name = "distilbert-base-multilingual-cased", output_dir = here::here("paper/multi")) 24 | ##saveRDS(model, here::here("paper/multimodel.RDS")) 25 | model <- readRDS(here::here("paper/multimodel.RDS")) 26 | training_pred <- predict(model, training_corpus) 27 | 28 | table(docvars(training_corpus, "sentiment"), training_pred) 29 | 30 | test_pred <- predict(model, corpus(test_set$X2)) 31 | 32 | dir <- unique(test_set$dir) 33 | 34 | dir[4] 35 | 36 | i <- 1 37 | .acc <- function(i) { 38 | F1 <- caret::confusionMatrix(table(test_set$X1[test_set$dir == dir[i]], 
test_pred[test_set$dir == dir[i]]))$byClass["F1"] 39 | lang <- dir[i] 40 | tibble::tibble(lang, F1) 41 | } 42 | 43 | purrr::map_dfr(seq_along(dir), .acc) 44 | -------------------------------------------------------------------------------- /paper/misc/tm.R: -------------------------------------------------------------------------------- 1 | require(RTextTools) 2 | devtools::load_all() 3 | data(USCongress) 4 | USCongress$text <- as.character(USCongress$text) 5 | USCongress$major <- as.factor(USCongress$major) 6 | ## training <- USCongress[1:4000,] 7 | ## test <- USCongress[4001:4449,] 8 | 9 | require(quanteda) 10 | USCongress_corpus <- corpus(USCongress$text) 11 | docvars(USCongress_corpus, "major") <- as.numeric(USCongress$major) - 1 12 | set.seed(721) 13 | model <- grafzahl(USCongress_corpus[1:4000], model_type = "bert", model_name = "bert-base-cased", output_dir = here::here("paper/rtt")) 14 | saveRDS(model, here::here("paper/tm_mod.RDS")) 15 | output <- predict(model, newdata = USCongress_corpus[1:4000]) 16 | 17 | testset_out <- predict(model, newdata = USCongress_corpus[4001:4449]) 18 | 19 | u <- union(testset_out, docvars(USCongress_corpus, "major")[4001:4449]) 20 | 21 | tb <- table(factor(docvars(USCongress_corpus, "major")[4001:4449], u), factor(testset_out, u)) 22 | cb <- caret::confusionMatrix(tb) 23 | 24 | mean(cb$byClass[,"F1"], na.rm = TRUE) 25 | 26 | doc_matrix <- create_matrix(USCongress$text, language="english", removeNumbers=TRUE, stemWords=TRUE, removeSparseTerms=.998) 27 | container <- create_container(doc_matrix, USCongress$major, trainSize=1:4000, testSize=4001:4449, virgin=FALSE) 28 | SVM <- train_model(container,"SVM") 29 | 30 | 31 | SVM_CLASSIFY <- classify_model(container, SVM) 32 | 33 | analytics <- create_analytics(container, cbind(SVM_CLASSIFY)) 34 | topic_summary <- analytics@label_summary 35 | alg_summary <- analytics@algorithm_summary 36 | 37 | mean(alg_summary$SVM_FSCORE, na.rm = TRUE) 38 | 39 | -------------------------------------------------------------------------------- /paper/plot_training.R: -------------------------------------------------------------------------------- 1 | require(tidyverse) 2 | 3 | res <- readRDS("va_learning.RDS") 4 | n <- rep(seq(500, 6000, by = 500), 10) 5 | 6 | acc <- purrr::map_dbl(res, ~.$overall['Accuracy']) 7 | 8 | ## Downright stole from Van Atteveldt. 9 | ## Except their Amsterdam style of R programming. 
10 | ## https://github.com/vanatteveldt/ecosent/blob/36b84628ec908666ea8280593cb335c89c4e5e7e/src/analysis/performance.md 11 | 12 | curve <- rbind(read_csv(here::here("paper/cnn_curve.csv")) %>% add_column(method="CNN", .before=1), read_csv(here::here("paper/svm_curve.csv")) %>% add_column(method="SVM", .before=1)) %>% group_by(method, perc) %>% summarize(n=mean(n), acc=mean(acc)) %>% ungroup 13 | 14 | tibble::tibble(n, acc, method = "Transformer (BERTje)", perc = 0) %>% group_by(method, n) %>% summarise(acc = mean(acc)) %>% ungroup %>% add_column(perc = 1, .before = "n") %>% bind_rows(curve) -> curve 15 | 16 | 17 | plot <- ggplot(curve, aes(x=n, y=acc, group=method, lty=method)) + geom_line() + 18 | scale_linetype(name="Method") + 19 | xlab("Number of training examples") + ylab("Accuracy") + 20 | scale_y_continuous(labels = scales::percent_format(accuracy = 1))+ 21 | ggthemes::theme_clean() + theme(legend.position = "top", legend.background = element_blank(), 22 | plot.background = element_blank()) 23 | saveRDS(plot, here::here("paper/learning.RDS")) 24 | -------------------------------------------------------------------------------- /paper/theocharis.md: -------------------------------------------------------------------------------- 1 | Theocharis et al. (2020) 2 | ================ 3 | 4 | The following is to analyse the same data used in Theocharis et 5 | al. (2020) “The Dynamics of Political Incivility on Twitter” 6 | \[[doi](https://doi.org/10.1177/2158244020919447)\]. The data is 7 | available from [Professor Pablo Barberá’s 8 | Github](https://github.com/pablobarbera/incivility-sage-open). 9 | 10 | # Data and Lasso regression 11 | 12 | The dataset `unciviltweets` is available in this package by agreement of 13 | Professor Pablo Barberá. The dataset bundled in this package is a 14 | quanteda corpus of 19,982 tweets and a single docvar of incivility, the 15 | label to be predicted. 16 | 17 | The following attempts to train the [lasso incivility 18 | classifier](https://github.com/pablobarbera/incivility-sage-open/blob/master/02-classifier.R) 19 | in the original paper. 20 | 21 | ## Creation of train-test split 22 | 23 | Preprocessing 24 | 25 | ``` r 26 | require(quanteda) 27 | #> Loading required package: quanteda 28 | #> Package version: 3.2.4 29 | #> Unicode version: 13.0 30 | #> ICU version: 66.1 31 | #> Parallel computing: 16 of 16 threads used. 32 | #> See https://quanteda.io for tutorials and examples. 33 | require(grafzahl) 34 | require(caret) 35 | #> Loading required package: caret 36 | #> Loading required package: ggplot2 37 | #> Loading required package: lattice 38 | require(glmnet) 39 | #> Loading required package: glmnet 40 | #> Loading required package: Matrix 41 | #> Loaded glmnet 4.1-6 42 | require(pROC) 43 | #> Loading required package: pROC 44 | #> Type 'citation("pROC")' for a citation. 45 | #> 46 | #> Attaching package: 'pROC' 47 | #> The following objects are masked from 'package:stats': 48 | #> 49 | #> cov, smooth, var 50 | 51 | uncivildfm <- unciviltweets %>% tokens(remove_url = TRUE, remove_numbers = TRUE) %>% tokens_wordstem() %>% dfm() %>% dfm_remove(stopwords("english")) %>% dfm_trim(min_docfreq = 2) 52 | y <- docvars(unciviltweets)[,1] 53 | seed <- 123 54 | set.seed(seed) 55 | training <- sample(seq_along(y), floor(.80 * length(y))) 56 | test <- (seq_along(y))[seq_along(y) %in% training == FALSE] 57 | ``` 58 | 59 | A “downsample” process was introduced in the original paper. 
60 | 61 | ``` r 62 | small_class <- which.min(table(y[training])) - 1 63 | n_small_class <- sum(y[training] == small_class) 64 | downsample <- sample(training[y[training] != small_class], n_small_class, replace = TRUE) 65 | training <- c(training[y[training] == small_class], downsample) 66 | original_training <- setdiff(seq_along(y), test) ## retain a copy 67 | ``` 68 | 69 | ## Training a lasso classifier 70 | 71 | Confusion matrix 72 | 73 | ``` r 74 | X <- as(uncivildfm, "dgCMatrix") 75 | 76 | lasso <- glmnet::cv.glmnet(x = X[training,], y = y[training], alpha = 1, nfold = 5, family = "binomial") 77 | ``` 78 | 79 | ### Evaluation 80 | 81 | ``` r 82 | preds <- predict(lasso, uncivildfm[test,], type="response") 83 | caret::confusionMatrix(table(y[test], ifelse(preds > .5, 1, 0)), mode = "prec_recall") 84 | #> Confusion Matrix and Statistics 85 | #> 86 | #> 87 | #> 0 1 88 | #> 0 2929 384 89 | #> 1 183 501 90 | #> 91 | #> Accuracy : 0.8581 92 | #> 95% CI : (0.8469, 0.8688) 93 | #> No Information Rate : 0.7786 94 | #> P-Value [Acc > NIR] : < 2.2e-16 95 | #> 96 | #> Kappa : 0.5522 97 | #> 98 | #> Mcnemar's Test P-Value : < 2.2e-16 99 | #> 100 | #> Precision : 0.8841 101 | #> Recall : 0.9412 102 | #> F1 : 0.9118 103 | #> Prevalence : 0.7786 104 | #> Detection Rate : 0.7328 105 | #> Detection Prevalence : 0.8289 106 | #> Balanced Accuracy : 0.7536 107 | #> 108 | #> 'Positive' Class : 0 109 | #> 110 | ``` 111 | 112 | ROC 113 | 114 | ``` r 115 | pROC::auc(as.vector((y[test])*1), as.vector((preds)*1)) 116 | #> Setting levels: control = 0, case = 1 117 | #> Setting direction: controls < cases 118 | #> Area under the curve: 0.8734 119 | ``` 120 | 121 | ## Training a BERTweet classifier 122 | 123 | In this example, a BERTweet-based classifier (Nguyen et al. 2020) is 124 | trained. Please note that the following doesn’t involve the 125 | preprocessing and downsampling procedures. 
126 | 127 | ``` r 128 | set.seed(721) 129 | model <- grafzahl(unciviltweets[original_training], model_type = "bertweet", model_name = "vinai/bertweet-base", output_dir = here::here("theocharis")) 130 | ``` 131 | 132 | ### Evaluation 133 | 134 | ``` r 135 | pred_bert <- predict(model, unciviltweets[test]) 136 | pred_bert2 <- predict(model, unciviltweets[test], return_raw = TRUE) 137 | 138 | caret::confusionMatrix(table(y[test], pred_bert), mode = "prec_recall") 139 | #> Confusion Matrix and Statistics 140 | #> 141 | #> pred_bert 142 | #> 0 1 143 | #> 0 3162 151 144 | #> 1 186 498 145 | #> 146 | #> Accuracy : 0.9157 147 | #> 95% CI : (0.9066, 0.9241) 148 | #> No Information Rate : 0.8376 149 | #> P-Value [Acc > NIR] : < 2e-16 150 | #> 151 | #> Kappa : 0.6966 152 | #> 153 | #> Mcnemar's Test P-Value : 0.06401 154 | #> 155 | #> Precision : 0.9544 156 | #> Recall : 0.9444 157 | #> F1 : 0.9494 158 | #> Prevalence : 0.8376 159 | #> Detection Rate : 0.7911 160 | #> Detection Prevalence : 0.8289 161 | #> Balanced Accuracy : 0.8559 162 | #> 163 | #> 'Positive' Class : 0 164 | #> 165 | ``` 166 | 167 | ### ROC 168 | 169 | ``` r 170 | pROC::auc(as.vector((y[test])*1), pred_bert2[,1]) 171 | #> Setting levels: control = 0, case = 1 172 | #> Setting direction: controls > cases 173 | #> Area under the curve: 0.9274 174 | ``` 175 | 176 | ### Plotting the two curves 177 | 178 | ``` r 179 | require(ROCR) 180 | #> Loading required package: ROCR 181 | performance_bert <- performance(prediction(pred_bert2[,2], y[test]), "tpr", "fpr") 182 | performance_origin <- performance(prediction(preds, y[test]), "tpr", "fpr") 183 | plot(performance_origin) 184 | abline(a = 0, b = 1, col = "grey") 185 | plot(performance_bert, add = TRUE, col = "red") 186 | ``` 187 | 188 | <img src="img/theocharis-roc-1.png" style="width:100.0%" /> 189 | 190 | ## References 191 | 192 | 1. Nguyen, D. Q., Vu, T., & Nguyen, A. T. (2020). BERTweet: A 193 | pre-trained language model for English Tweets. arXiv preprint 194 | arXiv:2005.10200. 195 | -------------------------------------------------------------------------------- /paper/theocharis.qmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: Theocharis et al. (2020) 3 | format: gfm 4 | --- 5 | 6 | ```{r} 7 | #| include: false 8 | knitr::opts_chunk$set( 9 | collapse = TRUE, 10 | comment = "#>", 11 | fig.path = "img/", 12 | out.width = "100%" 13 | ) 14 | require(grafzahl) 15 | 16 | ## model <- hydrate(here::here("theocharis"), model_type = "bertweet") 17 | model <- readRDS(here::here("theo.RDS")) 18 | ``` 19 | 20 | The following is to analyse the same data used in Theocharis et al. (2020) "The Dynamics of Political Incivility on Twitter" [[doi](https://doi.org/10.1177/2158244020919447)]. The data is available from [Professor Pablo Barberá's Github 21 | ](https://github.com/pablobarbera/incivility-sage-open). 22 | 23 | # Data and Lasso regression 24 | 25 | The dataset `unciviltweets` is available in this package by agreement of Professor Pablo Barberá. The dataset bundled in this package is a quanteda corpus of 19,982 tweets and a single docvar of incivility, the label to be predicted. 26 | 27 | The following attempts to train the [lasso incivility classifier](https://github.com/pablobarbera/incivility-sage-open/blob/master/02-classifier.R) in the original paper. 
28 | 29 | ## Creation of train-test split 30 | 31 | Preprocessing 32 | 33 | ```{r} 34 | require(quanteda) 35 | require(grafzahl) 36 | require(caret) 37 | require(glmnet) 38 | require(pROC) 39 | 40 | uncivildfm <- unciviltweets %>% tokens(remove_url = TRUE, remove_numbers = TRUE) %>% tokens_wordstem() %>% dfm() %>% dfm_remove(stopwords("english")) %>% dfm_trim(min_docfreq = 2) 41 | y <- docvars(unciviltweets)[,1] 42 | seed <- 123 43 | set.seed(seed) 44 | training <- sample(seq_along(y), floor(.80 * length(y))) 45 | test <- (seq_along(y))[seq_along(y) %in% training == FALSE] 46 | ``` 47 | 48 | A "downsample" process was introduced in the original paper. 49 | 50 | ```{r} 51 | small_class <- which.min(table(y[training])) - 1 52 | n_small_class <- sum(y[training] == small_class) 53 | downsample <- sample(training[y[training] != small_class], n_small_class, replace = TRUE) 54 | training <- c(training[y[training] == small_class], downsample) 55 | original_training <- setdiff(seq_along(y), test) ## retain a copy 56 | ``` 57 | 58 | ## Training a lasso classifier 59 | 60 | Fitting the lasso classifier with 5-fold cross-validation 61 | 62 | ```{r} 63 | X <- as(uncivildfm, "dgCMatrix") 64 | 65 | lasso <- glmnet::cv.glmnet(x = X[training,], y = y[training], alpha = 1, nfold = 5, family = "binomial") 66 | ``` 67 | 68 | ### Evaluation 69 | 70 | ```{r} 71 | preds <- predict(lasso, uncivildfm[test,], type="response") 72 | caret::confusionMatrix(table(y[test], ifelse(preds > .5, 1, 0)), mode = "prec_recall") 73 | ``` 74 | 75 | ROC 76 | 77 | ```{r} 78 | pROC::auc(as.vector((y[test])*1), as.vector((preds)*1)) 79 | ``` 80 | 81 | ## Training a BERTweet classifier 82 | 83 | In this example, a BERTweet-based classifier (Nguyen et al. 2020) is trained. Please note that the following does not involve the preprocessing and downsampling procedures used for the lasso classifier. 84 | 85 | ```{r} 86 | #| eval: false 87 | set.seed(721) 88 | model <- grafzahl(unciviltweets[original_training], model_type = "bertweet", model_name = "vinai/bertweet-base", output_dir = here::here("theocharis")) 89 | ``` 90 | 91 | ### Evaluation 92 | 93 | ```{r} 94 | pred_bert <- predict(model, unciviltweets[test]) 95 | pred_bert2 <- predict(model, unciviltweets[test], return_raw = TRUE) 96 | 97 | caret::confusionMatrix(table(y[test], pred_bert), mode = "prec_recall") 98 | ``` 99 | 100 | ### ROC 101 | 102 | ```{r} 103 | pROC::auc(as.vector((y[test])*1), pred_bert2[,1]) 104 | ``` 105 | 106 | ### Plotting the two curves 107 | 108 | ```{r} 109 | #| label: theocharis-roc 110 | require(ROCR) 111 | performance_bert <- performance(prediction(pred_bert2[,2], y[test]), "tpr", "fpr") 112 | performance_origin <- performance(prediction(preds, y[test]), "tpr", "fpr") 113 | plot(performance_origin) 114 | abline(a = 0, b = 1, col = "grey") 115 | plot(performance_bert, add = TRUE, col = "red") 116 | ``` 117 | 118 | ## References 119 | 120 | 1. Nguyen, D. Q., Vu, T., & Nguyen, A. T. (2020). BERTweet: A pre-trained language model for English Tweets. arXiv preprint arXiv:2005.10200. 121 | -------------------------------------------------------------------------------- /paper/vanatteveldt.md: -------------------------------------------------------------------------------- 1 | van Atteveldt et al. (2021) 2 | ================ 3 | 4 | The following is to analyse the same data used in van Atteveldt et 5 | al. (2021) “The Validity of Sentiment Analysis: Comparing Manual 6 | Annotation, Crowd-Coding, Dictionary Approaches, and Machine Learning 7 | Algorithms” \[[doi](https://doi.org/10.1080/19312458.2020.1869198)\].
8 | The data is available from this [github 9 | repo](https://github.com/vanatteveldt/ecosent). 10 | 11 | # Read the data directly from Github using `readtext` 12 | 13 | ``` r 14 | require(quanteda) 15 | #> Loading required package: quanteda 16 | #> Package version: 3.2.4 17 | #> Unicode version: 13.0 18 | #> ICU version: 66.1 19 | #> Parallel computing: 16 of 16 threads used. 20 | #> See https://quanteda.io for tutorials and examples. 21 | require(grafzahl) 22 | #> Loading required package: grafzahl 23 | require(readtext) 24 | #> Loading required package: readtext 25 | url <- "https://raw.githubusercontent.com/vanatteveldt/ecosent/master/data/intermediate/sentences_ml.csv" 26 | input <- readtext(url, text_field = "headline") %>% corpus 27 | training_corpus <- corpus_subset(input, !gold) 28 | ``` 29 | 30 | # Training 31 | 32 | In this analysis, [the Dutch language BERT 33 | (BERTje)](https://huggingface.co/GroNLP/bert-base-dutch-cased) is used. 34 | 35 | ``` r 36 | model <- grafzahl(x = training_corpus, y = "value", model_name = "GroNLP/bert-base-dutch-cased", 37 | output_dir = here::here("va_output"), manual_seed = 721) 38 | saveRDS(model, here::here("va_model.RDS")) 39 | ``` 40 | 41 | # Make prediction for the test set 42 | 43 | ``` r 44 | model <- readRDS(here::here("va_model.RDS")) 45 | test_corpus<- corpus_subset(input, gold) 46 | predicted_sentiment <- predict(model, newdata = test_corpus) 47 | ``` 48 | 49 | # Confusion matrix 50 | 51 | ``` r 52 | mt <- table(predicted_sentiment, gt = docvars(test_corpus, "value")) 53 | caret::confusionMatrix(mt, mode = "prec_recall") 54 | #> Confusion Matrix and Statistics 55 | #> 56 | #> gt 57 | #> predicted_sentiment -1 0 1 58 | #> -1 81 27 6 59 | #> 0 14 74 18 60 | #> 1 4 11 49 61 | #> 62 | #> Overall Statistics 63 | #> 64 | #> Accuracy : 0.7183 65 | #> 95% CI : (0.6621, 0.7699) 66 | #> No Information Rate : 0.3944 67 | #> P-Value [Acc > NIR] : <2e-16 68 | #> 69 | #> Kappa : 0.5699 70 | #> 71 | #> Mcnemar's Test P-Value : 0.1018 72 | #> 73 | #> Statistics by Class: 74 | #> 75 | #> Class: -1 Class: 0 Class: 1 76 | #> Precision 0.7105 0.6981 0.7656 77 | #> Recall 0.8182 0.6607 0.6712 78 | #> F1 0.7606 0.6789 0.7153 79 | #> Prevalence 0.3486 0.3944 0.2570 80 | #> Detection Rate 0.2852 0.2606 0.1725 81 | #> Detection Prevalence 0.4014 0.3732 0.2254 82 | #> Balanced Accuracy 0.8199 0.7373 0.8001 83 | ``` 84 | 85 | # LIME 86 | 87 | Explaining the prediction using Local Interpretable Model-agnostic 88 | Explanations (LIME). 89 | 90 | ``` r 91 | require(lime) 92 | sentences <- c("Dijsselbloem pessimistisch over snelle stappen Grieken", 93 | "Aandelenbeurzen zetten koersopmars voort") 94 | explainer <- lime(training_corpus, model) 95 | explanations <- explain(sentences, explainer, n_labels = 1, 96 | n_features = 5) 97 | plot_text_explanations(explanations) 98 | ``` 99 | 100 | You should see something like this: 101 | 102 | ![Fig 1.Generating Local Interpretable Model-agnostic Explanations 103 | (LIME) of two predictions from the trained Dutch sentiment 104 | model](img/fig1.png) 105 | 106 | # Learning curve 107 | 108 | van Atteveldt et al. (2021) present also the learning curves for CNN and 109 | SVM. Let’s overlay the learning curve for BERTje here. 
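
Note that the replication below fits 120 separate models (12 training-set sizes, 10 replicates each), which is computationally demanding. As a small sketch, the grid used below can be inspected before committing to the full run:

``` r
n <- rep(seq(500, 6000, by = 500), 10)
length(n)  ## 120 model fits in total
table(n)   ## 10 replicates per training-set size
```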
110 | 111 | ## Training using different training sizes 112 | 113 | ``` r 114 | n <- rep(seq(500, 6000, by = 500), 10) 115 | res <- list() 116 | test_corpus<- corpus_subset(input, gold) 117 | set.seed(721831) 118 | for (i in seq_along(n)) { 119 | current_corpus <- corpus_sample(training_corpus, n[i]) 120 | model <- grafzahl(x = current_corpus, y = "value", model_name = "GroNLP/bert-base-dutch-cased", output_dir = here::here("va_size")) 121 | predicted_sentiment <- predict(model, newdata = test_corpus) 122 | res[[i]] <- caret::confusionMatrix(table(predicted_sentiment, gt = docvars(test_corpus, "value")), mode = "prec_recall") 123 | } 124 | saveRDS(res, here::here("va_learning.RDS")) 125 | ``` 126 | 127 | ## Plotting 128 | 129 | ``` r 130 | require(tidyverse) 131 | n <- rep(seq(500, 6000, by = 500), 10) 132 | res <- readRDS(here::here("va_learning.RDS")) 133 | acc <- purrr::map_dbl(res, ~.$overall['Accuracy']) 134 | 135 | ## Downright stole from Van Atteveldt. 136 | ## Except their Amsterdam style of R programming. 137 | ## https://github.com/vanatteveldt/ecosent/blob/36b84628ec908666ea8280593cb335c89c4e5e7e/src/analysis/performance.md 138 | 139 | url_cnn <- "https://raw.githubusercontent.com/vanatteveldt/ecosent/master/data/intermediate/cnn_curve.csv" 140 | url_svm <- "https://raw.githubusercontent.com/vanatteveldt/ecosent/master/data/intermediate/svm_curve.csv" 141 | 142 | curve <- rbind(readr::read_csv(url_cnn) %>% add_column(method="CNN", .before=1), 143 | readr::read_csv(url_svm) %>% add_column(method="SVM", .before=1)) %>% 144 | group_by(method, perc) %>% summarize(n=mean(n), acc=mean(acc)) %>% ungroup 145 | 146 | tibble::tibble(n, acc, method = "Transformer (BERTje)", perc = 0) %>% group_by(method, n) %>% 147 | summarise(acc = mean(acc)) %>% ungroup %>% add_column(perc = 1, .before = "n") %>% 148 | bind_rows(curve) -> curve 149 | 150 | learning <- ggplot(curve, aes(x=n, y=acc, group=method, lty=method)) + geom_line() + 151 | scale_linetype(name="Method") + 152 | xlab("Number of training examples") + ylab("Accuracy") + 153 | scale_y_continuous(labels = scales::percent_format(accuracy = 1))+ 154 | ggthemes::theme_clean() + theme(legend.position = "top", legend.background = element_blank(), 155 | plot.background = element_blank()) 156 | saveRDS(learning, here::here("learning.RDS")) 157 | learning 158 | ``` 159 | 160 | <figure> 161 | <img src="img/learning-curve-1.png" style="width:100.0%" 162 | alt="Learning curve of machine learning algorithms" /> 163 | <figcaption aria-hidden="true">Learning curve of machine learning 164 | algorithms</figcaption> 165 | </figure> 166 | -------------------------------------------------------------------------------- /paper/vanatteveldt.qmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: van Atteveldt et al. (2021) 3 | format: gfm 4 | --- 5 | 6 | ```{r} 7 | #| include: false 8 | knitr::opts_chunk$set( 9 | collapse = TRUE, 10 | comment = "#>", 11 | fig.path = "img/", 12 | out.width = "100%" 13 | ) 14 | ``` 15 | 16 | The following is to analyse the same data used in van Atteveldt et al. (2021) "The Validity of Sentiment Analysis: Comparing Manual Annotation, Crowd-Coding, Dictionary Approaches, and Machine Learning Algorithms" [[doi](https://doi.org/10.1080/19312458.2020.1869198)]. The data is available from this [github repo](https://github.com/vanatteveldt/ecosent). 
17 | 18 | # Read the data directly from Github using `readtext` 19 | 20 | ```{r} 21 | require(quanteda) 22 | require(grafzahl) 23 | require(readtext) 24 | url <- "https://raw.githubusercontent.com/vanatteveldt/ecosent/master/data/intermediate/sentences_ml.csv" 25 | input <- readtext(url, text_field = "headline") %>% corpus 26 | training_corpus <- corpus_subset(input, !gold) 27 | ``` 28 | 29 | # Training 30 | 31 | In this analysis, [the Dutch language BERT (BERTje)](https://huggingface.co/GroNLP/bert-base-dutch-cased) is used. 32 | 33 | ```{r} 34 | #| eval: false 35 | model <- grafzahl(x = training_corpus, y = "value", model_name = "GroNLP/bert-base-dutch-cased", 36 | output_dir = here::here("va_output"), manual_seed = 721) 37 | saveRDS(model, here::here("va_model.RDS")) 38 | ``` 39 | 40 | # Make prediction for the test set 41 | 42 | ```{r} 43 | model <- readRDS(here::here("va_model.RDS")) 44 | test_corpus<- corpus_subset(input, gold) 45 | predicted_sentiment <- predict(model, newdata = test_corpus) 46 | ``` 47 | 48 | # Confusion matrix 49 | 50 | ```{r} 51 | mt <- table(predicted_sentiment, gt = docvars(test_corpus, "value")) 52 | caret::confusionMatrix(mt, mode = "prec_recall") 53 | ``` 54 | 55 | # LIME 56 | 57 | Explaining the prediction using Local Interpretable Model-agnostic Explanations (LIME). 58 | 59 | ```{r} 60 | #| eval: false 61 | require(lime) 62 | sentences <- c("Dijsselbloem pessimistisch over snelle stappen Grieken", 63 | "Aandelenbeurzen zetten koersopmars voort") 64 | explainer <- lime(training_corpus, model) 65 | explanations <- explain(sentences, explainer, n_labels = 1, 66 | n_features = 5) 67 | plot_text_explanations(explanations) 68 | ``` 69 | 70 | You should see something like this: 71 | 72 | ![Fig 1.Generating Local Interpretable Model-agnostic Explanations (LIME) of two predictions from the trained Dutch sentiment model](img/fig1.png) 73 | 74 | # Learning curve 75 | 76 | van Atteveldt et al. (2021) present also the learning curves for CNN and SVM. Let's overlay the learning curve for BERTje here. 77 | 78 | ## Training using different training sizes 79 | 80 | ```{r} 81 | #| eval: false 82 | n <- rep(seq(500, 6000, by = 500), 10) 83 | res <- list() 84 | test_corpus<- corpus_subset(input, gold) 85 | set.seed(721831) 86 | for (i in seq_along(n)) { 87 | current_corpus <- corpus_sample(training_corpus, n[i]) 88 | model <- grafzahl(x = current_corpus, y = "value", model_name = "GroNLP/bert-base-dutch-cased", output_dir = here::here("va_size")) 89 | predicted_sentiment <- predict(model, newdata = test_corpus) 90 | res[[i]] <- caret::confusionMatrix(table(predicted_sentiment, gt = docvars(test_corpus, "value")), mode = "prec_recall") 91 | } 92 | saveRDS(res, here::here("va_learning.RDS")) 93 | ``` 94 | 95 | ## Plotting 96 | 97 | ```{r} 98 | #| label: learning-curve 99 | #| fig-cap: Learning curve of machine learning algorithms 100 | #| warning: false 101 | 102 | require(tidyverse) 103 | n <- rep(seq(500, 6000, by = 500), 10) 104 | res <- readRDS(here::here("va_learning.RDS")) 105 | acc <- purrr::map_dbl(res, ~.$overall['Accuracy']) 106 | 107 | ## Downright stole from Van Atteveldt. 108 | ## Except their Amsterdam style of R programming. 
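## The lines below rebuild their learning-curve data (see the notebook linked in
## the next comment): the CNN and SVM accuracy curves are read from the ecosent
## repository, averaged per training-set size, and combined with the BERTje
## accuracies computed above so that all three methods appear in one plot.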
109 | ## https://github.com/vanatteveldt/ecosent/blob/36b84628ec908666ea8280593cb335c89c4e5e7e/src/analysis/performance.md 110 | 111 | url_cnn <- "https://raw.githubusercontent.com/vanatteveldt/ecosent/master/data/intermediate/cnn_curve.csv" 112 | url_svm <- "https://raw.githubusercontent.com/vanatteveldt/ecosent/master/data/intermediate/svm_curve.csv" 113 | 114 | curve <- rbind(readr::read_csv(url_cnn) %>% add_column(method="CNN", .before=1), 115 | readr::read_csv(url_svm) %>% add_column(method="SVM", .before=1)) %>% 116 | group_by(method, perc) %>% summarize(n=mean(n), acc=mean(acc)) %>% ungroup 117 | 118 | tibble::tibble(n, acc, method = "Transformer (BERTje)", perc = 0) %>% group_by(method, n) %>% 119 | summarise(acc = mean(acc)) %>% ungroup %>% add_column(perc = 1, .before = "n") %>% 120 | bind_rows(curve) -> curve 121 | 122 | learning <- ggplot(curve, aes(x=n, y=acc, group=method, lty=method)) + geom_line() + 123 | scale_linetype(name="Method") + 124 | xlab("Number of training examples") + ylab("Accuracy") + 125 | scale_y_continuous(labels = scales::percent_format(accuracy = 1))+ 126 | ggthemes::theme_clean() + theme(legend.position = "top", legend.background = element_blank(), 127 | plot.background = element_blank()) 128 | saveRDS(learning, here::here("learning.RDS")) 129 | learning 130 | ``` 131 | -------------------------------------------------------------------------------- /rawdata/createdata.R: -------------------------------------------------------------------------------- 1 | 2 | d <- readr::read_csv(here::here("rawdata/training-data.csv"), col_types="cccc") 3 | 4 | # adding synthetic labels 5 | d2 <- readr::read_csv(here::here("rawdata/synthetic-labels.csv"), col_types="cccc") 6 | d <- rbind(d, d2) 7 | 8 | 9 | d$text <- gsub('@[0-9_A-Za-z]+', '@', d$text) 10 | d$uncivil_dummy <- ifelse(d$uncivil=="yes", 1, 0) 11 | 12 | unciviltweets <- quanteda::corpus(d$text) 13 | 14 | quanteda::docvars(unciviltweets, "uncivil") <- d$uncivil_dummy 15 | usethis::use_data(unciviltweets) 16 | 17 | set.seed(123) 18 | smallunciviltweets <- quanteda::corpus_sample(unciviltweets, 200) 19 | usethis::use_data(smallunciviltweets) 20 | 21 | download.file(url <- "https://raw.githubusercontent.com/vanatteveldt/ecosent/master/data/intermediate/sentences_ml.csv", "rawdata/sentences_ml.csv") 22 | 23 | ecosent <- read.csv("rawdata/sentences_ml.csv", encoding = "UTF-8")[c("id", "headline", "value", "gold")] 24 | save(ecosent, file = "data/ecosent.rda", ascii = FALSE, compress = "xz") 25 | 26 | supported_model_types <- c("albert", "bert", "bertweet", "bigbird", "camembert", "deberta", "distilbert", "electra", "flaubert", 27 | "herbert", "layoutlm", "layoutlmv2", "longformer", "mpnet", "mobilebert", "rembert", "roberta", "squeezebert", 28 | "squeezebert", "xlm", "xlmroberta", "xlnet", "debertav2") 29 | usethis::use_data(supported_model_types) 30 | -------------------------------------------------------------------------------- /tests/testdata/fake/.gitattributes: -------------------------------------------------------------------------------- 1 | *.7z filter=lfs diff=lfs merge=lfs -text 2 | *.arrow filter=lfs diff=lfs merge=lfs -text 3 | *.bin filter=lfs diff=lfs merge=lfs -text 4 | *.bin.* filter=lfs diff=lfs merge=lfs -text 5 | *.bz2 filter=lfs diff=lfs merge=lfs -text 6 | *.ftz filter=lfs diff=lfs merge=lfs -text 7 | *.gz filter=lfs diff=lfs merge=lfs -text 8 | *.h5 filter=lfs diff=lfs merge=lfs -text 9 | *.joblib filter=lfs diff=lfs merge=lfs -text 10 | *.lfs.* filter=lfs diff=lfs merge=lfs -text 11 | 
*.model filter=lfs diff=lfs merge=lfs -text 12 | *.msgpack filter=lfs diff=lfs merge=lfs -text 13 | *.onnx filter=lfs diff=lfs merge=lfs -text 14 | *.ot filter=lfs diff=lfs merge=lfs -text 15 | *.parquet filter=lfs diff=lfs merge=lfs -text 16 | *.pb filter=lfs diff=lfs merge=lfs -text 17 | *.pt filter=lfs diff=lfs merge=lfs -text 18 | *.pth filter=lfs diff=lfs merge=lfs -text 19 | *.rar filter=lfs diff=lfs merge=lfs -text 20 | saved_model/**/* filter=lfs diff=lfs merge=lfs -text 21 | *.tar.* filter=lfs diff=lfs merge=lfs -text 22 | *.tflite filter=lfs diff=lfs merge=lfs -text 23 | *.tgz filter=lfs diff=lfs merge=lfs -text 24 | *.xz filter=lfs diff=lfs merge=lfs -text 25 | *.zip filter=lfs diff=lfs merge=lfs -text 26 | *.zstandard filter=lfs diff=lfs merge=lfs -text 27 | *tfevents* filter=lfs diff=lfs merge=lfs -text 28 | -------------------------------------------------------------------------------- /tests/testdata/fake/README.md: -------------------------------------------------------------------------------- 1 | Hugging Face's logo 2 | --- 3 | language: 4 | - om 5 | - am 6 | - rw 7 | - rn 8 | - ha 9 | - ig 10 | - pcm 11 | - so 12 | - sw 13 | - ti 14 | - yo 15 | - multilingual 16 | 17 | --- 18 | # afriberta_base 19 | ## Model description 20 | AfriBERTa base is a pretrained multilingual language model with around 111 million parameters. 21 | The model has 8 layers, 6 attention heads, 768 hidden units and 3072 feed forward size. 22 | The model was pretrained on 11 African languages namely - Afaan Oromoo (also called Oromo), Amharic, Gahuza (a mixed language containing Kinyarwanda and Kirundi), Hausa, Igbo, Nigerian Pidgin, Somali, Swahili, Tigrinya and Yorùbá. 23 | The model has been shown to obtain competitive downstream performances on text classification and Named Entity Recognition on several African languages, including those it was not pretrained on. 24 | 25 | 26 | ## Intended uses & limitations 27 | 28 | #### How to use 29 | You can use this model with Transformers for any downstream task. 30 | For example, assuming we want to finetune this model on a token classification task, we do the following: 31 | 32 | ```python 33 | >>> from transformers import AutoTokenizer, AutoModelForTokenClassification 34 | >>> model = AutoModelForTokenClassification.from_pretrained("castorini/afriberta_base") 35 | >>> tokenizer = AutoTokenizer.from_pretrained("castorini/afriberta_base") 36 | # we have to manually set the model max length because it is an imported sentencepiece model, which huggingface does not properly support right now 37 | >>> tokenizer.model_max_length = 512 38 | ``` 39 | 40 | #### Limitations and bias 41 | - This model is possibly limited by its training dataset which are majorly obtained from news articles from a specific span of time. Thus, it may not generalize well. 42 | - This model is trained on very little data (less than 1 GB), hence it may not have seen enough data to learn very complex linguistic relations. 43 | 44 | ## Training data 45 | The model was trained on an aggregation of datasets from the BBC news website and Common Crawl. 46 | 47 | ## Training procedure 48 | For information on training procedures, please refer to the AfriBERTa [paper]() or [repository](https://github.com/keleog/afriberta) 49 | 50 | ### BibTeX entry and citation info 51 | ``` 52 | @inproceedings{ogueji-etal-2021-small, 53 | title = "Small Data? No Problem! 
Exploring the Viability of Pretrained Multilingual Language Models for Low-resourced Languages", 54 | author = "Ogueji, Kelechi and 55 | Zhu, Yuxin and 56 | Lin, Jimmy", 57 | booktitle = "Proceedings of the 1st Workshop on Multilingual Representation Learning", 58 | month = nov, 59 | year = "2021", 60 | address = "Punta Cana, Dominican Republic", 61 | publisher = "Association for Computational Linguistics", 62 | url = "https://aclanthology.org/2021.mrl-1.11", 63 | pages = "116--126", 64 | } 65 | ``` 66 | 67 | 68 | -------------------------------------------------------------------------------- /tests/testdata/fake/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "/Users/kelechogueji/Downloads/afriberta_base", 3 | "architectures": [ 4 | "XLMRobertaForMaskedLM" 5 | ], 6 | "attention_probs_dropout_prob": 0.1, 7 | "bos_token_id": 0, 8 | "eos_token_id": 2, 9 | "gradient_checkpointing": false, 10 | "hidden_act": "gelu", 11 | "hidden_dropout_prob": 0.1, 12 | "hidden_size": 768, 13 | "initializer_range": 0.02, 14 | "intermediate_size": 3072, 15 | "layer_norm_eps": 1e-05, 16 | "max_length": 512, 17 | "max_position_embeddings": 514, 18 | "model_type": "xlm-roberta", 19 | "num_attention_heads": 6, 20 | "num_hidden_layers": 8, 21 | "output_past": true, 22 | "pad_token_id": 1, 23 | "position_embedding_type": "absolute", 24 | "transformers_version": "4.2.1", 25 | "type_vocab_size": 1, 26 | "use_cache": true, 27 | "vocab_size": 70006 28 | } 29 | -------------------------------------------------------------------------------- /tests/testdata/fake/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": "<mask>"} -------------------------------------------------------------------------------- /tests/testdata/fake/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "sep_token": "</s>", "cls_token": "<s>", "pad_token": "<pad>", "mask_token": "<mask>", "special_tokens_map_file": null, "tokenizer_file": null, "name_or_path": "spm_models/spm_model_final_70k"} -------------------------------------------------------------------------------- /tests/testthat.R: -------------------------------------------------------------------------------- 1 | # This file is part of the standard setup for testthat. 2 | # It is recommended that you do not modify it. 3 | # 4 | # Where should you do additional test configuration? 
5 | # Learn more about the roles of various files in: 6 | # * https://r-pkgs.org/tests.html 7 | # * https://testthat.r-lib.org/reference/test_package.html#special-files 8 | 9 | library(testthat) 10 | library(grafzahl) 11 | 12 | test_check("grafzahl") 13 | -------------------------------------------------------------------------------- /tests/testthat/test_grafzahl.R: -------------------------------------------------------------------------------- 1 | Sys.setenv(KILL_SWITCH = "KILL") 2 | txt <- c(d1 = "Chinese Beijing Chinese", 3 | d2 = "Chinese Chinese Shanghai", 4 | d3 = "Chinese", 5 | d4 = "Tokyo Japan Chinese", 6 | d5 = "Chinese Chinese Chinese Tokyo Japan") 7 | y <- factor(c("Y", "Y", "Y", "N", "Y"), ordered = TRUE) 8 | 9 | test_that("basic", { 10 | expect_error(grafzahl(x = txt, y = y, model_name = "bert-base-cased", train_size = 1, num_train_epochs = 1), NA) 11 | }) 12 | 13 | test_that(".infer local", { 14 | expect_error(.infer_model_type("../testdata/fake"), NA) 15 | expect_equal(.infer_model_type("../testdata/fake"), "xlm-roberta") 16 | ## Integration with .check_model_type 17 | expect_equal(.check_model_type(model_type = NULL, model_name = "../testdata/fake"), "xlmroberta") 18 | ## Integration with grafzahl 19 | expect_error(grafzahl(x = txt, y = y, model_name = "../testdata/fake", train_size = 1, num_train_epochs = 1), NA) 20 | }) 21 | 22 | test_that(".check_model_type", { 23 | expect_error(.check_model_type(model_type = NULL)) 24 | expect_error(.check_model_type(model_type = "idk")) 25 | expect_error(.check_model_type(model_type = "xlmroberta", model_name = "xlm-roberta-base"), NA) 26 | expect_error(.check_model_type(model_type = "xlm-roberta", model_name = "xlm-roberta-base"), NA) 27 | expect_error(.check_model_type(model_type = "XLM-roberta", model_name = "xlm-roberta-base"), NA) 28 | ## Integration with grafzahl 29 | expect_error(grafzahl(x = txt, y = y, model_type = "idk", model_name = "bert-base-cased", train_size = 1, num_train_epochs = 1)) 30 | expect_error(grafzahl(x = txt, y = y, model_type = "xlmroberta", model_name = "xlm-roberta-base", train_size = 1, num_train_epochs = 1), NA) 31 | expect_error(grafzahl(x = txt, y = y, model_type = "xlm-roberta", model_name = "xlm-roberta-base", train_size = 1, num_train_epochs = 1), NA) 32 | expect_error(grafzahl(x = txt, y = y, model_type = "XLM-roberta", model_name = "xlm-roberta-base", train_size = 1, num_train_epochs = 1), NA) 33 | }) 34 | 35 | Sys.setenv(KILL_SWITCH = "") 36 | -------------------------------------------------------------------------------- /tests/testthat/test_setup.R: -------------------------------------------------------------------------------- 1 | ## test_that("setup", { 2 | ## skip_on_cran() 3 | ## testbed <- file.path(tempdir(), "testbed") 4 | ## dir.create(testbed) 5 | ## withr::local_envvar(GRAFZAHL_MINICONDA_PATH = testbed) 6 | ## expect_error(setup_grafzahl(force = TRUE), NA) 7 | ## expect_true(detect_conda()) 8 | ## expect_false(detect_cuda()) 9 | ## txt <- c(d1 = "Chinese Beijing Chinese", 10 | ## d2 = "Chinese Chinese Shanghai", 11 | ## d3 = "Chinese", 12 | ## d4 = "Tokyo Japan Chinese", 13 | ## d5 = "Chinese Chinese Chinese Tokyo Japan") 14 | ## y <- factor(c("Y", "Y", "Y", "N", "Y"), ordered = TRUE) 15 | 16 | ## expect_error(model <- grafzahl(x = txt, y = y, train_size = 1, 17 | ## num_train_epochs = 1, 18 | ## model_name = "bert-base-cased", 19 | ## cuda = FALSE), NA) 20 | ## expect_error(predict(model, cuda = FALSE), NA) 21 | ## }) 22 | 
-------------------------------------------------------------------------------- /vignettes/.gitignore: -------------------------------------------------------------------------------- 1 | *.html 2 | *.R 3 | -------------------------------------------------------------------------------- /vignettes/grafzahl.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "Setup Guide" 3 | output: rmarkdown::html_vignette 4 | vignette: > 5 | %\VignetteIndexEntry{Setup Guide} 6 | %\VignetteEngine{knitr::rmarkdown} 7 | %\VignetteEncoding{UTF-8} 8 | --- 9 | 10 | ```{r, include = FALSE} 11 | knitr::opts_chunk$set( 12 | collapse = TRUE, 13 | comment = "#>" 14 | ) 15 | ``` 16 | 17 | This is a quick setup guide for different situations. 18 | 19 | `grafzahl` requires a Python environment. By default, `grafzahl` assumes you would like to use a miniconda-based Python environment, which can be installed with the provided `setup_grafzahl()` function. 20 | 21 | ```r 22 | require(grafzahl) 23 | setup_grafzahl(cuda = TRUE) # FALSE if you don't have CUDA-compatible GPUs 24 | 25 | ## Use grafzahl right away, an example 26 | model <- grafzahl(unciviltweets, model_type = "bertweet", model_name = "vinai/bertweet-base") 27 | 28 | ``` 29 | 30 | There are other setup options. 31 | 32 | # Google Colab and similar services 33 | 34 | To use `grafzahl` on Google Colab, please choose the R runtime (Runtime > Change Runtime Type > Runtime Type: R). You might also want to choose a hardware accelerator, e.g. a T4 GPU. 35 | 36 | In this case, you need to enable the non-conda mode with `use_nonconda()`. By default, it also installs the required Python packages. 37 | 38 | ```r 39 | install.packages("grafzahl") 40 | use_nonconda(install = TRUE, check = TRUE) # default 41 | 42 | ## Use grafzahl right away, an example 43 | model <- grafzahl(unciviltweets, model_type = "bertweet", model_name = "vinai/bertweet-base") 44 | ``` 45 | 46 | # Default Python 47 | 48 | If you don't want to use any conda configuration on your local machine, you can just install the Python packages `simpletransformers` and `emoji` yourself. 49 | 50 | ```bash 51 | python3 -m pip install simpletransformers emoji 52 | ``` 53 | 54 | And then, in R: 55 | 56 | ```r 57 | require(grafzahl) 58 | use_nonconda(install = FALSE, check = TRUE) ## this simply sets: options("grafzahl.nonconda" = TRUE) 59 | 60 | ## Use grafzahl right away, an example 61 | model <- grafzahl(unciviltweets, model_type = "bertweet", model_name = "vinai/bertweet-base") 62 | ``` 63 | 64 | # Use conda, but not grafzahl's default 65 | 66 | Suppose you have a conda installation elsewhere. Take note of the `base` path of that installation:
67 | 68 | ```bash 69 | conda env list 70 | ``` 71 | 72 | Create a new conda environment with the default grafzahl environment name. 73 | 74 | ## With CUDA 75 | 76 | ```bash 77 | conda create -n grafzahl_condaenv_cuda 78 | conda activate grafzahl_condaenv_cuda 79 | conda install -n grafzahl_condaenv_cuda python pip pytorch pytorch-cuda cudatoolkit -c pytorch -c nvidia 80 | python -m pip install simpletransformers emoji 81 | conda deactivate 82 | 83 | ## Test the CUDA installation with 84 | 85 | Rscript -e "grafzahl::detect_cuda()" 86 | ``` 87 | 88 | ## Without CUDA 89 | 90 | ```bash 91 | conda create -n grafzahl_condaenv 92 | conda activate grafzahl_condaenv 93 | conda install -n grafzahl_condaenv python pip pytorch -c pytorch 94 | python -m pip install simpletransformers emoji 95 | conda deactivate 96 | ``` 97 | 98 | In R, you then have to change the default conda base path: 99 | 100 | ```r 101 | ## suppose /home/yourname/miniconda is the base path of your conda installation 102 | require(grafzahl) 103 | Sys.setenv(GRAFZAHL_MINICONDA_PATH = "/home/yourname/miniconda") 104 | 105 | ## Use grafzahl right away, an example 106 | model <- grafzahl(unciviltweets, model_type = "bertweet", model_name = "vinai/bertweet-base") 107 | ``` 108 | 109 | # Explanation: Important options and envvars 110 | 111 | There are two important settings: an R option and an environment variable. The option `options("grafzahl.nonconda")` controls whether the non-conda mode is used. The environment variable `GRAFZAHL_MINICONDA_PATH` controls the base path of the conda installation. If it is `""` (the default), `reticulate::miniconda_path()` is used as the base path. 112 | --------------------------------------------------------------------------------