├── .Rbuildignore ├── .gitignore ├── CellSpace.Rproj ├── CellSpace.png ├── DESCRIPTION ├── LICENSE.md ├── NAMESPACE ├── R ├── CellSpace.R └── README.md ├── README.md ├── cpp ├── .gitignore ├── README.md ├── StarSpace-MIT-License.md ├── makefile └── src │ ├── data.cpp │ ├── data.h │ ├── dict.cpp │ ├── dict.h │ ├── main.cpp │ ├── matrix.h │ ├── model.cpp │ ├── model.h │ ├── parser.cpp │ ├── parser.h │ ├── proj.cpp │ ├── proj.h │ ├── starspace.cpp │ ├── starspace.h │ └── utils │ ├── args.cpp │ ├── args.h │ ├── normalize.cpp │ ├── normalize.h │ ├── utils.cpp │ └── utils.h ├── man ├── CellSpace-class.Rd ├── CellSpace.Rd ├── DNA_sequence_embedding.Rd ├── add_motif_db.Rd ├── cosine_similarity.Rd ├── docs │ ├── CellSpace.md │ ├── DNA_sequence_embedding.md │ ├── add_motif_db.md │ ├── cosine_similarity.md │ ├── embedding_distance.md │ ├── find_clusters.md │ ├── find_neighbors.md │ ├── merge_small_clusters.md │ ├── motif_embedding.md │ └── run_UMAP.md ├── embedding_distance.Rd ├── find_clusters.Rd ├── find_neighbors.Rd ├── merge_small_clusters.Rd ├── motif_embedding.Rd └── run_UMAP.Rd └── tutorial ├── README.md ├── data ├── CellSpace_embedding-var_tiles.tsv.gz ├── PWM-list.rds ├── cell-names.txt ├── cell_by_peak-counts.mtx.gz ├── cell_by_tile-counts.mtx.gz ├── palette.rds ├── sample-info.tsv ├── var_peaks.fa.gz └── var_tiles.fa.gz ├── plot-functions.R ├── plots ├── UMAP-cells.png ├── UMAP-cells_and_TFs.png └── motif-scores.png ├── variable-peaks ├── IterativeLSI.R └── filter-peaks.R └── variable-tiles ├── 1-fastq-dump.sh ├── 2-TrimGalore.sh ├── 3-bowtie2.sh ├── 4-samtools.sh ├── 5-ArchR.R ├── README.md ├── SRR_Acc_List.txt └── addBarcodeTag.cpp /.Rbuildignore: -------------------------------------------------------------------------------- 1 | ^CellSpace\.Rproj$ 2 | ^\.Rproj\.user$ 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.DS_Store 2 | 
.Rproj.user 3 | .Rhistory 4 | 5 | 6 | -------------------------------------------------------------------------------- /CellSpace.Rproj: -------------------------------------------------------------------------------- 1 | Version: 1.0 2 | 3 | RestoreWorkspace: No 4 | SaveWorkspace: No 5 | AlwaysSaveHistory: Default 6 | 7 | EnableCodeIndexing: Yes 8 | UseSpacesForTab: Yes 9 | NumSpacesForTab: 2 10 | Encoding: UTF-8 11 | 12 | RnwWeave: Sweave 13 | LaTeX: pdfLaTeX 14 | 15 | AutoAppendNewline: Yes 16 | StripTrailingWhitespace: Yes 17 | LineEndingConversion: Posix 18 | 19 | BuildType: Package 20 | PackageUseDevtools: Yes 21 | PackageInstallArgs: --no-multiarch --with-keep.source 22 | PackageRoxygenize: rd,collate,namespace 23 | -------------------------------------------------------------------------------- /CellSpace.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zakieh-tayyebi/CellSpace/c5dc59b1a3c843596834acab4d4e421e4cd59343/CellSpace.png -------------------------------------------------------------------------------- /DESCRIPTION: -------------------------------------------------------------------------------- 1 | Package: CellSpace 2 | Title: Downstream Analyses for CellSpace 3 | Description: A toolkit to analyze single-cell ATAC-seq data based on a learned CellSpace embedding 4 | Version: 0.0.6 5 | Authors@R: c( 6 | person("Zakieh", "Tayyebi", email = "zakieh.tayyebi@gmail.com", role = c("aut", "cre"), 7 | comment = c(ORCID = "https://orcid.org/0000-0002-7821-9905")), 8 | person("The Christina Leslie Lab", role = "fnd", 9 | comment = c(ORCID = "https://orcid.org/0000-0002-4571-5910")), 10 | person("The Tri-Institutional PhD Program in Computational Biology & Medicine", role = "fnd")) 11 | URL: https://github.com/zakieh-tayyebi/CellSpace 12 | BugReports: https://github.com/zakieh-tayyebi/CellSpace/issues 13 | Imports: 14 | Seurat, 15 | Biostrings 16 | Encoding: UTF-8 17 | Roxygen: 
list(markdown = TRUE) 18 | RoxygenNote: 7.2.3 19 | Suggests: 20 | knitr, 21 | rmarkdown 22 | VignetteBuilder: knitr 23 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2023 The Christina Leslie Lab, MSK CC 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
10 | -------------------------------------------------------------------------------- /NAMESPACE: -------------------------------------------------------------------------------- 1 | # Generated by roxygen2: do not edit by hand 2 | 3 | export(CellSpace) 4 | export(DNA_sequence_embedding) 5 | export(add_motif_db) 6 | export(cosine_similarity) 7 | export(embedding_distance) 8 | export(find_clusters) 9 | export(find_neighbors) 10 | export(merge_small_clusters) 11 | export(motif_embedding) 12 | export(run_UMAP) 13 | exportClasses(CellSpace) 14 | importFrom(Biostrings,DNAStringSet) 15 | importFrom(Biostrings,as.matrix) 16 | importFrom(Biostrings,reverseComplement) 17 | importFrom(Seurat,FindClusters) 18 | importFrom(Seurat,FindNeighbors) 19 | importFrom(Seurat,RunUMAP) 20 | -------------------------------------------------------------------------------- /R/CellSpace.R: -------------------------------------------------------------------------------- 1 | #' The CellSpace Class 2 | #' 3 | #' The \code{CellSpace} class stores CellSpace embedding and 4 | #' related information needed for performing downstream analyses. 
#'
#' @slot project title of the project
#' @slot emb.file the .tsv output of CellSpace containing the embedding matrix for cells and k-mers
#' @slot cell.emb the embedding matrix for cells
#' @slot kmer.emb the embedding matrix for k-mers
#' @slot motif.emb the embedding matrix for TF motifs
#' @slot meta.data data frame containing meta-information about each cell
#' @slot dim the dimensions of the CellSpace embeddings
#' @slot k the length of DNA k-mers
#' @slot similarity the similarity function in hinge loss
#' @slot p the embedding of an entity equals the sum of its \code{M} feature embedding vectors divided by \code{M^p}
#' @slot label cell label prefix
#' @slot neighbors list containing nearest neighbor graphs
#' @slot reductions list containing dimensional reductions
#' @slot misc list containing miscellaneous objects
#'
#' @name CellSpace-class
#' @rdname CellSpace-class
#' @exportClass CellSpace
#'
setClass(
  "CellSpace",
  slots = c(
    # project title and the embedding file the object was built from
    project = "character",
    emb.file = "character",
    # embedding matrices (rows are cells / k-mers); motif embeddings are stored per database name
    cell.emb = "matrix",
    kmer.emb = "matrix",
    motif.emb = "list",
    # per-cell annotations
    meta.data = "data.frame",
    # model settings
    dim = "integer",
    k = "integer",
    similarity = "character",
    p = "numeric",
    label = "character",
    # results of downstream analyses
    neighbors = "list",
    reductions = "list",
    misc = "list"
  )
)

setMethod(
  "show", "CellSpace",
  function(object){
    # Writes "<title><item1>, <item2>, ...\n" as a single line.
    print_items <- function(title, items){
      cat(title, paste(items, collapse = ", "), "\n", sep = "")
    }

    cat("An object of class \"CellSpace\"\n")
    if(object@project != "") cat("Project:", object@project, "\n")
    cat("CellSpace model: ", object@dim, "-dimensional embedding for ",
        nrow(object@cell.emb), " cells and all DNA k-mers (k=", object@k, ")\n", sep = ""
    )

    if(length(object@motif.emb) > 0)
      print_items("TF motif embedding matrices: ", names(object@motif.emb))

    meta.cols <- colnames(object@meta.data)
    if(ncol(object@meta.data) > 0){
      # More than ten columns: show only the first three and the last two.
      if(length(meta.cols) > 10)
        meta.cols <- c(head(meta.cols, 3), "...", tail(meta.cols, 2))
      print_items("Cell meta-data: ", meta.cols)
    }

    if(length(object@neighbors) > 0)
      print_items("Nearest neighbor graphs created: ", names(object@neighbors))
    if(length(object@reductions) > 0)
      print_items("Dimensional reductions calculated: ", names(object@reductions))
    if(length(object@misc) > 0)
      print_items("miscellaneous: ", names(object@misc))
  }
)

setMethod(
  "$", "CellSpace",
  function(x, name){
    # Unknown column names give NULL, like `$` on a list.
    if(!(name %in% colnames(x@meta.data))) return(NULL)
    x@meta.data[, name]
  }
)

setMethod(
  "$<-", "CellSpace",
  function(x, name, value){
    # Assign into a copy of the meta-data and store it back in the slot.
    meta <- x@meta.data
    meta[, name] <- value
    x@meta.data <- meta
    x
  }
)

#' CellSpace
#'
#' Generates an object from the \code{CellSpace} class.
#'
#' @param emb.file the .tsv output of CellSpace containing the embedding matrix for cells and k-mers
#' @param cell.names vector of unique cell names
#' @param meta.data a \code{data.frame} containing meta-information about each cell
#' @param project title of the project
#' @param similarity the similarity function in hinge loss
#' @param p the embedding of an entity equals the sum of its \code{M} feature embedding vectors divided by \code{M^p}
#' @param label cell label prefix
#'
#' @return a new \code{CellSpace} object
#'
#' @name CellSpace
#' @rdname CellSpace
#' @export
#'
CellSpace <- function(
    emb.file,
    cell.names = NULL,
    meta.data = NULL,
    project = NULL,
    similarity = "cosine",
    p = 0.5,
    label = "__label__"
){
  # A missing project title is stored as "", which `show` treats as "no title".
  project <- if(is.null(project)) "" else as.character(project)[1]

  emb.file <- normalizePath(emb.file)
  emb <- read.table(emb.file, header = FALSE, row.names = 1, sep = "\t")
  dim <- ncol(emb)
  colnames(emb) <- paste0("CS", seq_len(dim))

  # Cell rows are named <label>C<index>. The prefix is matched literally, not
  # as a regular expression, so a `label` containing regex metacharacters still works.
  cell.label <- paste0(label, "C")
  cell.labels <- rownames(emb)[startsWith(rownames(emb), cell.label)]
  if(length(cell.labels) == 0)
    stop("No cell embeddings (rows starting with '", cell.label, "') found in '", emb.file, "'!")
  cell.idx <- sort(as.integer(substring(cell.labels, nchar(cell.label) + 1)))
  # drop = FALSE keeps a matrix (with row names) even for a 1-dimensional embedding
  cell.emb <- as.matrix(emb[paste0(cell.label, cell.idx), , drop = FALSE])

  # Every row that does not contain the label is treated as a k-mer embedding.
  kmer.emb <- as.matrix(emb[!grepl(label, rownames(emb), fixed = TRUE), , drop = FALSE])
  k <- unique(nchar(rownames(kmer.emb)))
  if(length(k) > 1) stop("Varying k-mer lengths!")

  # Invalid optional inputs are discarded with a warning rather than failing.
  if(!is.null(cell.names)){
    if(length(cell.names) != nrow(cell.emb) || any(duplicated(cell.names))){
      warning("\'cell.names\' must be a character vector with ", nrow(cell.emb), " unique values!")
      cell.names <- NULL
    }
  }

  if(!is.null(meta.data)){
    meta.data <- as.data.frame(meta.data)
    if(nrow(meta.data) != nrow(cell.emb)){
      warning("The number of rows in \'meta.data\' does not match the cell embedding matrix!")
      meta.data <- NULL
    }
  }

  # Cell names take precedence. Otherwise use meta.data row names, or fall back to the embedding labels.
  if(is.null(cell.names)){
    if(is.null(meta.data)){
      meta.data <- data.frame(row.names = rownames(cell.emb), check.rows = FALSE, check.names = FALSE)
    } else rownames(cell.emb) <- rownames(meta.data)
  } else {
    if(is.null(meta.data)){
      meta.data <- data.frame(row.names = cell.names, check.rows = FALSE, check.names = FALSE)
    } else rownames(meta.data) <- cell.names
    rownames(cell.emb) <- cell.names
  }

  new("CellSpace",
      project = project,
      emb.file = emb.file,
      cell.emb = cell.emb,
      kmer.emb = kmer.emb,
      motif.emb = list(),
      meta.data = meta.data,
      dim = dim,
      k = k,
      similarity = similarity,
      p = p,
      label = label,
      neighbors = list(),
      reductions = list(),
      misc = list()
  )
}

#' find_neighbors
#'
#' Builds a nearest neighbor graph and shared nearest neighbor graph from the CellSpace embedding.
#'
#' @importFrom Seurat FindNeighbors
#'
#' @param object a \code{CellSpace} object
#' @param n.neighbors the number of nearest neighbors for the KNN algorithm
#' @param emb the embedding matrix used to create the nearest neighbor graphs
#' @param emb.name prefix for the graph names that will be added to the \code{neighbors} slot
#' @param ... arguments passed to \code{Seurat::FindNeighbors}
#'
#' @return a \code{CellSpace} object containing nearest neighbor and shared nearest neighbor graphs in the \code{neighbors} slot
#'
#' @name find_neighbors
#' @rdname find_neighbors
#' @export
#'
find_neighbors <- function(
    object,
    n.neighbors = 30,
    emb = object@cell.emb,
    emb.name = "cells",
    ...
){
  # Annoy search with the model's similarity metric. The embedding is used as-is (no L2 normalization).
  graphs <- FindNeighbors(
    emb, distance.matrix = FALSE, k.param = n.neighbors, l2.norm = FALSE,
    nn.method = "annoy", annoy.metric = object@similarity, ...
  )
  # Prefix each returned graph name with `emb.name` (presumably giving e.g. "cells_snn",
  # the default graph used by find_clusters).
  names(graphs) <- paste(emb.name, names(graphs), sep = "_")
  object@neighbors[names(graphs)] <- graphs
  return(object)
}

#' find_clusters
#'
#' Finds clusters from a shared nearest neighbor graph built from the CellSpace embedding.
#'
#' @importFrom Seurat FindClusters
#'
#' @param object a \code{CellSpace} object
#' @param graph name of the shared nearest neighbor graph in the \code{neighbors} slot used to find clusters
#' @param ... arguments passed to \code{Seurat::FindClusters}
#'
#' @return a \code{CellSpace} object with the cell clusters added to the \code{meta.data} slot
#'
#' @name find_clusters
#' @rdname find_clusters
#' @export
#'
find_clusters <- function(object, graph = "cells_snn", ...){
  if(!(graph %in% names(object@neighbors)))
    stop("\'", graph, "\' not available! Run \'find_neighbors\' to create nearest neighbor graphs.")
  cl <- FindClusters(object@neighbors[[graph]], ...)
  # Result columns (presumably "res.<resolution>") are stored as "Clusters.res_<resolution>".
  colnames(cl) <- gsub("res\\.", "Clusters.res_", colnames(cl))
  object@meta.data[, colnames(cl)] <- cl
  return(object)
}

#' merge_small_clusters
#'
#' Merges cells from small clusters with the nearest clusters.
#'
#' @param object a \code{CellSpace} object
#' @param clusters a vector of cluster labels, or the name of a column in the \code{meta.data} slot containing cluster labels
#' @param min.cells any cluster with fewer cells than \code{min.cells} will be merged with the nearest cluster
#' @param graph a shared nearest neighbor graph, or the name of a graph in the \code{neighbors} slot, used to find clusters
#' @param seed random seed used to break ties between equally connected clusters
#'
#' @return new cluster labels, as a factor whose levels are the clusters that were kept
#'
#' @name merge_small_clusters
#' @rdname merge_small_clusters
#' @export
#'
merge_small_clusters <- function(
    object,
    clusters,
    min.cells = 10,
    graph = "cells_snn",
    seed = 1
){
  if(is.character(graph) && length(graph) == 1 && graph %in% names(object@neighbors)){
    graph <- object@neighbors[[graph]]
  } else if(!inherits(graph, "Graph"))
    stop("'graph' must be a nearest neighbor graph of class 'Graph'")

  if(length(clusters) == 1 && clusters %in% colnames(object@meta.data)){
    clusters <- object@meta.data[, clusters]
  } else if(length(clusters) != nrow(graph))
    stop("'clusters' must be the name of a column in 'object@meta.data', or a vector of the same length as the number of cells.")
  # Labels read from meta.data may be character/integer; levels() must be defined below.
  if(!is.factor(clusters)) clusters <- factor(clusters)

  # Base R only: magrittr's `%>%` is not an imported dependency of the package.
  sizes <- table(clusters)
  small.clusters <- intersect(names(sizes)[sizes < min.cells], levels(clusters))
  cluster_names <- setdiff(levels(clusters), small.clusters)
  if(length(small.clusters) > 0 && length(cluster_names) == 0){
    warning("All clusters have fewer than ", min.cells, " cells; no cluster to merge into.")
    return(clusters)
  }

  connectivity <- numeric(length(cluster_names))
  names(connectivity) <- cluster_names

  new.ids <- clusters
  for(i in small.clusters){
    i.cells <- which(clusters == i)
    # Unused factor levels have no cells to reassign; they are simply dropped.
    if(length(i.cells) == 0) next
    for(j in cluster_names){
      j.cells <- which(clusters == j)
      # Mean edge weight between the two clusters. sum() covers both the sub-matrix
      # and the plain vector returned when one side holds a single cell.
      connectivity[j] <- sum(graph[i.cells, j.cells]) / (length(i.cells) * length(j.cells))
    }
    # Reseed for every small cluster so that tie-breaking is reproducible.
    set.seed(seed = seed)
    m <- max(connectivity, na.rm = TRUE)
    best <- names(connectivity)[which(connectivity == m)]
    new.ids[i.cells] <- sample(best, 1)
  }

  return(factor(new.ids, levels = cluster_names))
}

#' run_UMAP
#'
#' Computes a UMAP embedding from the CellSpace embedding.
#'
#' @importFrom Seurat RunUMAP
#'
#' @param object a \code{CellSpace} object
#' @param emb the embedding matrix used to compute the UMAP embedding
#' @param graph name of the nearest neighbor graph in the \code{neighbors} slot used to compute the UMAP embedding
#' @param name name of the lower-dimensional embedding that will be added to the \code{reductions} slot
#' @param ... arguments passed to \code{Seurat::RunUMAP}
#'
#' @return a \code{CellSpace} object containing a UMAP embedding in the \code{reductions} slot
#'
#' @name run_UMAP
#' @rdname run_UMAP
#' @export
#'
run_UMAP <- function(
    object,
    emb = object@cell.emb,
    graph = NULL,
    name = "cells_UMAP",
    ...
){
  # A named graph takes precedence over `emb`.
  if(!is.null(graph)){
    if(!(graph %in% names(object@neighbors)))
      stop("\'", graph, "\' not available! Run \'find_neighbors\' to create nearest neighbor graphs.")
    umap <- RunUMAP(object = object@neighbors[[graph]], assay = "CellSpace", ...)
  } else if(!is.null(emb)){
    umap <- RunUMAP(
      object = emb,
      metric = object@similarity,
      assay = "CellSpace",
      ...
    )
  } else stop("'emb' or 'graph' must be provided!")
  object@reductions[[name]] <- umap@cell.embeddings
  return(object)
}

#' cosine_similarity
#'
#' Computes cosine similarity in the embedding space.
#'
#' @param x an embedding matrix
#' @param y an embedding matrix with compatible dimensions to \code{x}, or \code{NULL}, in which case \code{y=x}
#'
#' @return a matrix containing the cosine similarity between rows of \code{x} and \code{y}
#'
#' @name cosine_similarity
#' @rdname cosine_similarity
#' @export
#'
cosine_similarity <- function(x, y = NULL){
  # Plain vectors are treated as a single row.
  as_rows <- function(v) if(is.matrix(v)) v else matrix(v, nrow = 1)

  x <- as_rows(x)
  norm.x <- sqrt(rowSums(x ^ 2))
  if(is.null(y)){
    # Self-similarity: reuse x and its row norms.
    y <- x
    norm.y <- norm.x
  } else {
    y <- as_rows(y)
    norm.y <- sqrt(rowSums(y ^ 2))
  }

  # Pairwise dot products of x rows with y rows, divided by the product of their norms.
  tcrossprod(x, y) / outer(norm.x, norm.y)
}

#' embedding_distance
#'
#' Computes distance in the embedding space based on cosine similarity.
#'
#' @param x an embedding matrix
#' @param y an embedding matrix with compatible dimensions to \code{x}, or \code{NULL}, in which case \code{y=x}
#' @param distance the distance metric, either 'cosine' or 'angular', to compute from the cosine similarity
#'
#' @return a matrix containing the distance between rows of \code{x} and \code{y}, computed from their cosine similarity
#'
#' @name embedding_distance
#' @rdname embedding_distance
#' @export
#'
embedding_distance <- function(x, y = NULL, distance = c("cosine", "angular")){
  # Clamp floating-point excursions outside [-1, 1] so that acos() stays defined.
  sim <- cosine_similarity(x = x, y = y)
  sim[which(sim > 1)] <- 1
  sim[which(sim < -1)] <- -1

  # Only the first element of `distance` is used (the default selects "cosine").
  metric <- distance[1]
  if(metric == "cosine") return(1 - sim)
  if(metric == "angular") return(acos(sim) / pi)
  stop("The distance metric must be \'cosine\' or \'angular\'!\n")
}

#' DNA_sequence_embedding
#'
#' Maps a DNA sequence to the embedding space.
#'
#' @importFrom Biostrings reverseComplement DNAStringSet
#'
#' @param object a \code{CellSpace} object
#' @param seq a DNA sequence
#'
#' @return a numerical vector containing the CellSpace embedding of \code{seq}, or a vector of \code{NA}s (with a warning) if \code{seq} is shorter than \code{k} or contains k-mers that cannot be mapped to the model
#'
#' @name DNA_sequence_embedding
#' @rdname DNA_sequence_embedding
#' @export
#'
DNA_sequence_embedding <- function(object, seq){
  if(!is.character(seq)) seq <- as.character(seq)
  if(length(seq) != 1) stop("'seq' must be a single DNA sequence!")
  if(is.na(seq)){
    warning("\'", seq, "\' is not a valid DNA sequence!")
    return(rep(NA, object@dim))
  }
  sl <- nchar(seq)
  if(sl < object@k){
    warning("The sequence \'", seq, "\' is shorter than CellSpace k-mers (k=", object@k, ")!")
    return(rep(NA, object@dim))
  }

  # All overlapping k-mers of the upper-cased sequence.
  starts <- seq_len(sl - object@k + 1)
  kmers <- substring(text = toupper(seq), first = starts, last = starts + object@k - 1)

  # K-mers absent from the model are looked up by their reverse complement
  # (presumably only one strand of each k-mer pair is stored).
  known <- rownames(object@kmer.emb)
  unmatched <- !(kmers %in% known)
  if(any(unmatched)){
    kmers[unmatched] <- tryCatch(
      toupper(as.character(reverseComplement(DNAStringSet(kmers[unmatched])))),
      # DNAStringSet() errors on characters outside the DNA alphabet; treat such
      # k-mers as unmappable so the warning below is actually reached.
      error = function(e) rep(NA_character_, sum(unmatched))
    )
  }
  if(!all(kmers %in% known)){
    warning("\'", seq, "\' is not a valid DNA sequence!")
    return(rep(NA, object@dim))
  }

  # Sum of k-mer embeddings, scaled by (number of k-mers)^p.
  kn <- length(kmers)
  if(kn == 1) return(object@kmer.emb[kmers, ])
  colSums(object@kmer.emb[kmers, , drop = FALSE]) / (kn ^ object@p)
}

#' motif_embedding
#'
#' Maps a motif to the embedding space.
#'
#' @importFrom Biostrings as.matrix
#'
#' @param object a \code{CellSpace} object
#' @param PWM \code{PFMatrix} or \code{PWMatrix}
#'
#' @return a numerical vector containing the CellSpace embedding of the consensus sequence for \code{PWM}
#'
#' @name motif_embedding
#' @rdname motif_embedding
#' @export
#'
motif_embedding <- function(object, PWM){
  pwm <- as.matrix(PWM)
  # Consensus: at each position (column), take the row name of the highest-scoring nucleotide.
  consensus <- paste(rownames(pwm)[apply(pwm, 2, which.max)], collapse = "")
  DNA_sequence_embedding(object = object, seq = consensus)
}

#' add_motif_db
#'
#' Computes the CellSpace embedding and activity scores of transcription factor motifs.
#'
#' @param object a \code{CellSpace} object
#' @param motif.db \code{PFMatrixList} or \code{PWMatrixList}
#' @param db.name the name of the transcription factor motif database
#'
#' @return a \code{CellSpace} object containing the motif embedding matrix, in the \code{motif.emb} slot, and the corresponding similarity Z-scores, in the \code{misc} slot
#'
#' @name add_motif_db
#' @rdname add_motif_db
#' @export
#'
add_motif_db <- function(object, motif.db, db.name){
  # Plain base-R composition: magrittr's `%>%` is not an imported dependency of
  # this package, so the pipe only worked if the user had attached it.
  embeddings <- lapply(motif.db, function(motif.pwm){
    motif_embedding(object, PWM = motif.pwm)
  })
  # One row per motif. Motifs whose consensus could not be embedded (NA rows) are dropped.
  motif.emb <- na.omit(do.call(rbind, embeddings))
  object@motif.emb[[db.name]] <- motif.emb

  # Cells x motifs cosine similarity; scale() standardizes each motif (column) across cells.
  object@misc[[paste(db.name, "scores", sep = "_")]] <- scale(
    cosine_similarity(x = object@cell.emb, y = motif.emb)
  )

  return(object)
}

-------------------------------------------------------------------------------- /R/README.md: -------------------------------------------------------------------------------- 1 | ## API 2 | 3 | | Function | Description | 4 | |-------------------------|-----------------------------------------------| 5 | 
[`CellSpace`](../man/docs/CellSpace.md) | Generates an object from the CellSpace class. | 6 | | [`cosine_similarity`](../man/docs/cosine_similarity.md) | Computes cosine similarity in the embedding space. | 7 | | [`embedding_distance`](../man/docs/embedding_distance.md) | Computes distance in the embedding space based on cosine similarity. | 8 | | [`find_neighbors`](../man/docs/find_neighbors.md) | Builds a nearest neighbor graph and shared nearest neighbor graph from the CellSpace embedding. | 9 | | [`find_clusters`](../man/docs/find_clusters.md) | Finds clusters in a nearest neighbor graph built from the CellSpace embedding. | 10 | | [`merge_small_clusters`](../man/docs/merge_small_clusters.md) | Merges cells from small clusters with the nearest clusters. | 11 | | [`run_UMAP`](../man/docs/run_UMAP.md) | Computes a UMAP embedding from the CellSpace embedding. | 12 | | [`DNA_sequence_embedding`](../man/docs/DNA_sequence_embedding.md) | Maps a DNA sequence to the embedding space. | 13 | | [`motif_embedding`](../man/docs/motif_embedding.md) | Maps a motif to the embedding space. | 14 | | [`add_motif_db`](../man/docs/add_motif_db.md) | Computes the CellSpace embedding and activity scores of transcription factor motifs. | 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CellSpace 2 | 3 | CellSpace is a sequence-informed embedding method for scATAC-seq that learns a mapping of DNA *k*-mers and cells to the same space. 4 | 5 | See our [publication](https://doi.org/10.1038/s41592-024-02274-x) for more details. 6 | 7 | 8 | 9 | ## Installation and Usage 10 | 11 | 1. Compile the **C++** program to use as a command line tool to train a CellSpace model. 
12 | 13 | CellSpace, which uses the **C++** implementation of [StarSpace](https://github.com/facebookresearch/StarSpace) [[Wu *et al.*, 2017](https://doi.org/10.48550/arXiv.1709.03856)], builds on modern macOS and Linux distributions. It requires a compiler with **C++11** support and a working **make**. 14 | 15 | Install the [Boost](https://www.boost.org) library and specify its path in the [makefile](cpp/makefile) (set variable **BOOST_DIR**). The default path will work if you install **Boost** by: 16 | 17 | ``` bash 18 | wget https://archives.boost.io/release/1.63.0/source/boost_1_63_0.zip 19 | unzip boost_1_63_0.zip 20 | sudo mv boost_1_63_0 /usr/local/bin 21 | ``` 22 | 23 | Download and build CellSpace: 24 | 25 | ``` bash 26 | git clone https://github.com/zakieh-tayyebi/CellSpace.git 27 | cd CellSpace/cpp/ 28 | make 29 | export PATH=$(pwd):$PATH 30 | ``` 31 | 32 | Verify that it was successfully compiled: 33 | 34 | ``` bash 35 | CellSpace --help 36 | ``` 37 | 38 | 2. Install the **R** package to use the trained CellSpace model for downstream analysis. 39 | 40 | Run the following commands in **R**: 41 | 42 | ``` r 43 | install.packages("devtools") 44 | devtools::install_github("https://github.com/zakieh-tayyebi/CellSpace.git") 45 | library(CellSpace) 46 | ``` 47 | 48 | Installation should take only a few minutes. For details about the **R** functions, please refer to the [API](R/README.md). 49 | 50 | 3. A tutorial on CellSpace usage can be found [here](tutorial/README.md). 51 | 52 | ## Citation 53 | 54 | Please cite our [Nature Methods paper](https://doi.org/10.1038/s41592-024-02274-x) if you use CellSpace: 55 | 56 | ``` 57 | Tayyebi, Z., Pine, A.R. & Leslie, C.S. Scalable and unbiased sequence-informed embedding of single-cell ATAC-seq data with CellSpace. Nature Methods 21, 1014–1022 (2024).
https://doi.org/10.1038/s41592-024-02274-x 58 | ``` 59 | 60 | ## Contact 61 | 62 | - [zakieh.tayyebi\@gmail.com](mailto:zakieh.tayyebi@gmail.com) (Zakieh Tayyebi) 63 | - [lesliec\@mskcc.org](mailto:lesliec@mskcc.org) (Christina S. Leslie, PhD) 64 | -------------------------------------------------------------------------------- /cpp/.gitignore: -------------------------------------------------------------------------------- 1 | CellSpace 2 | /CellSpace.dSYM/ 3 | *.DS_Store 4 | 5 | # Prerequisites 6 | *.d 7 | 8 | # Compiled Object files 9 | *.slo 10 | *.lo 11 | *.o 12 | *.obj 13 | 14 | # Precompiled Headers 15 | *.gch 16 | *.pch 17 | 18 | # Compiled Dynamic libraries 19 | *.so 20 | *.dylib 21 | *.dll 22 | 23 | # Fortran module files 24 | *.mod 25 | *.smod 26 | 27 | # Compiled Static libraries 28 | *.lai 29 | *.la 30 | *.a 31 | *.lib 32 | 33 | # Executables 34 | *.exe 35 | *.out 36 | *.app 37 | 38 | -------------------------------------------------------------------------------- /cpp/README.md: -------------------------------------------------------------------------------- 1 | This code is an extension of [StarSpace](https://github.com/facebookresearch/StarSpace) (see [MIT license](StarSpace-MIT-License.md)) and incorporates the StarSpace algorithm (mode 0) for scATAC-seq data. 2 | -------------------------------------------------------------------------------- /cpp/StarSpace-MIT-License.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Facebook, Inc. and its affiliates. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /cpp/makefile: --------------------------------------------------------------------------------
# CellSpace v0.0.5

CXX = g++
CXXFLAGS = -pthread -std=gnu++11

BOOST_DIR = /usr/local/bin/boost_1_63_0/

OBJS = normalize.o dict.o args.o proj.o parser.o data.o model.o starspace.o utils.o
INCLUDES = -I$(BOOST_DIR)

# These targets name actions rather than files; declaring them phony keeps a
# stray file with the same name from suppressing them.
.PHONY: opt starspace clean

# Default goal: optimized build. Target-specific flags also apply to every
# prerequisite that is built on behalf of 'opt'.
opt: CXXFLAGS += -O3 -funroll-loops
opt: starspace

normalize.o: src/utils/normalize.cpp src/utils/normalize.h
	$(CXX) $(CXXFLAGS) -g -c src/utils/normalize.cpp

dict.o: src/dict.cpp src/dict.h src/utils/args.h
	$(CXX) $(CXXFLAGS) $(INCLUDES) -g -c src/dict.cpp

args.o: src/utils/args.cpp src/utils/args.h
	$(CXX) $(CXXFLAGS) -g -c src/utils/args.cpp

model.o: data.o src/model.cpp src/model.h src/utils/args.h src/proj.h
	$(CXX) $(CXXFLAGS) $(INCLUDES) -g -c src/model.cpp

proj.o: src/proj.cpp src/proj.h src/matrix.h
	$(CXX) $(CXXFLAGS) $(INCLUDES) -g -c src/proj.cpp

data.o: parser.o src/data.cpp src/data.h
	$(CXX) $(CXXFLAGS) $(INCLUDES) -g -c src/data.cpp -o data.o

utils.o: src/utils/utils.cpp src/utils/utils.h
	$(CXX) $(CXXFLAGS) $(INCLUDES) -g -c src/utils/utils.cpp -o utils.o

parser.o: dict.o src/parser.cpp src/parser.h
	$(CXX) $(CXXFLAGS) $(INCLUDES) -g -c src/parser.cpp -o parser.o

starspace.o: src/starspace.cpp src/starspace.h
	$(CXX) $(CXXFLAGS) $(INCLUDES) -g -c src/starspace.cpp

# 'starspace' is kept as an alias ('make starspace' still works). The real target
# is the CellSpace binary, so it is relinked only when one of its inputs changes.
starspace: CellSpace

CellSpace: $(OBJS) src/main.cpp
	$(CXX) $(CXXFLAGS) $(OBJS) $(INCLUDES) -g src/main.cpp -o CellSpace

clean:
	rm -rf *.o CellSpace CellSpace.dSYM

-------------------------------------------------------------------------------- /cpp/src/data.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates.
3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | #include "data.h" 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | using namespace std; 16 | 17 | namespace starspace { 18 | 19 | InternDataHandler::InternDataHandler(shared_ptr args, std::shared_ptr dict, std::shared_ptr parser) { 20 | size_ = 0; 21 | idx_ = -1; 22 | examples_.clear(); 23 | dict_ = dict; 24 | parser_ = parser; 25 | args_= args; 26 | } 27 | 28 | void InternDataHandler::errorOnZeroExample(const string& fileName) { 29 | std::cerr << "ERROR: File '" << fileName 30 | << "' does not contain any valid example.\n" 31 | << "Please check: is the file empty? " 32 | << "Do the examples contain proper feature and label according to the trainMode? " 33 | << "If your examples are unlabeled, try to set trainMode=5.\n"; 34 | exit(EXIT_FAILURE); 35 | } 36 | 37 | inline bool check_header(string header){ 38 | if(header == "%%MatrixMarket matrix coordinate pattern general") return(true); 39 | else if(header == "%%MatrixMarket matrix coordinate integer general") return(false); 40 | else { 41 | cerr << "Unsupported matrix format!" 
<< endl; 42 | exit(EXIT_FAILURE); 43 | } 44 | } 45 | 46 | void InternDataHandler::loadFromFile() { 47 | 48 | examples_ = vector(args_->nPeaks_total); 49 | ngrams_ = vector>(args_->nPeaks_total); 50 | // hashes_ = vector>(args_->nPeaks_total); 51 | 52 | unsigned pfn = 0, pnt = 0, nPeaks_cur = 0; 53 | for(auto peaks: args_->peaks_list){ 54 | auto nPeaks = args_->nPeaks_list[pfn++]; 55 | cout << "Reading " << nPeaks << " peak sequences from \'" << peaks << "\'" << endl; 56 | ifstream pf(peaks); 57 | string line, seq; 58 | int pi = -1; 59 | while(getline(pf, line)) 60 | if(line[0] == '>'){ 61 | if(pi >= 0) add_kmer_tokens(nPeaks_cur + pi, seq); 62 | pi++; pnt++; 63 | seq = ""; 64 | } else seq += line; 65 | if(pi >= 0 && seq != "") add_kmer_tokens(nPeaks_cur + pi, seq); 66 | pf.close(); 67 | 68 | if(pi != nPeaks - 1){ 69 | cerr << "The number of peak sequences should match the number of columns in the corresponding count matrix!" << endl; 70 | exit(EXIT_FAILURE); 71 | } 72 | 73 | nPeaks_cur += nPeaks; 74 | } 75 | assert(pnt == args_->nPeaks_total); 76 | 77 | unsigned cfn = 0, nCells_cur = 0; nPeaks_cur = 0; 78 | for(auto cp_matrix: args_->cp_matrix_list){ 79 | cout << "Reading a " << args_->nCells_list[cfn] << " cell by " << args_->nPeaks_list[cfn] 80 | << " peak count matrix from \'" << cp_matrix << "\'" << endl; 81 | string str, header, mat_format = cp_matrix.substr(cp_matrix.find_last_of(".") + 1); 82 | bool bin; 83 | unsigned nCells, nPeaks, peak, cell; 84 | if(mat_format == "gz"){ 85 | #ifdef COMPRESS_FILE 86 | ifstream ifs2(cp_matrix); 87 | if (!ifs2.good()) exit(EXIT_FAILURE); 88 | filtering_istream cf; 89 | cf.push(gzip_decompressor()); 90 | cf.push(ifs2); 91 | 92 | getline(cf, header); 93 | bin = check_header(header); 94 | 95 | cf >> nCells >> nPeaks >> str; 96 | assert(nCells == args_->nCells_list[cfn] && nPeaks == args_->nPeaks_list[cfn]); 97 | 98 | if(bin) 99 | while(cf >> cell){ 100 | cf >> peak; 101 | addCellLabel(nPeaks_cur + peak - 1, nCells_cur + cell); 
102 | } 103 | else 104 | while(cf >> cell){ 105 | cf >> peak >> str; 106 | addCellLabel(nPeaks_cur + peak - 1, nCells_cur + cell); 107 | } 108 | 109 | nCells_cur += nCells; 110 | nPeaks_cur += nPeaks; 111 | 112 | ifs2.close(); 113 | #endif 114 | } else { 115 | ifstream cf(cp_matrix); 116 | if(!cf.good()) exit(EXIT_FAILURE); 117 | 118 | getline(cf, header); 119 | bin = check_header(header); 120 | 121 | cf >> nCells >> nPeaks >> str; 122 | assert(nCells == args_->nCells_list[cfn] && nPeaks == args_->nPeaks_list[cfn]); 123 | 124 | if(bin) 125 | while(cf >> cell){ 126 | cf >> peak; 127 | addCellLabel(nPeaks_cur + peak - 1, nCells_cur + cell); 128 | } 129 | else 130 | while(cf >> cell){ 131 | cf >> peak >> str; 132 | addCellLabel(nPeaks_cur + peak - 1, nCells_cur + cell); 133 | } 134 | 135 | nCells_cur += nCells; 136 | nPeaks_cur += nPeaks; 137 | 138 | cf.close(); 139 | } 140 | 141 | for(unsigned pi = nPeaks_cur - args_->nPeaks_list[cfn]; pi < nPeaks_cur; pi++) 142 | examples_[pi].dataset = cfn; 143 | 144 | // if(args_->batchLabels) 145 | // for(unsigned pi = nPeaks_cur - args_->nPeaks_list[cfn]; pi < nPeaks_cur; pi++){ 146 | // int32_t wid = dict_->getId(args_->label + "B" + std::to_string(cfn + 1)); 147 | // assert(wid >= 0); 148 | // examples_[pi].RHSTokens.push_back(make_pair(wid, 1.0)); 149 | // } 150 | 151 | cfn ++; 152 | } 153 | 154 | assert(nCells_cur == args_->nCells_total && nPeaks_cur == args_->nPeaks_total); 155 | 156 | size_ = args_->nPeaks_total * args_->exmpPerPeak; 157 | cout << "Number of datasets: " << args_->nCells_list.size() << endl 158 | << "Total number of cells: " << args_->nCells_total << endl 159 | << "Total number of peaks: " << args_->nPeaks_total << endl 160 | << "Number of training examples: " << size_ << " (" << args_->exmpPerPeak << " per peak)" 161 | << "\n-----------------\n" << endl; 162 | 163 | // for(auto ex: examples_){ 164 | // for(auto word: ex.RHSTokens) cerr << dict_->getSymbol(word.first) << " "; 165 | // cerr << ": "; 166 | // 
for(auto word: ex.LHSTokens) cerr << dict_->getSymbol(word.first) << " "; 167 | // cerr << ex.LHSTokens.size(); 168 | // cerr << endl; 169 | // } 170 | 171 | } 172 | 173 | int first_ngram_idx(int L, int i, int n){ 174 | if(i <= L - n + 1) 175 | return(i * (n - 1)); 176 | else { 177 | int first = (L - n + 1) * (n - 1); 178 | for(int j = L - n + 2; j <= i; j++) 179 | first += L - j; 180 | return(first); 181 | } 182 | } 183 | 184 | // Convert an example for training/testing if needed. 185 | // In the case of trainMode=1, a random label from r.h.s will be selected 186 | // as label, and the rest of labels from r.h.s. will be input features 187 | void InternDataHandler::convert( 188 | const ParseResults& example, 189 | // const std::vector& hashes, 190 | const std::vector& ngrams, 191 | ParseResults& rslt) const { 192 | 193 | #define LHS_SIZE (example.LHSTokens.size()) 194 | int num_tokens = (args_->sampleLen == -1) ? LHS_SIZE : (args_->sampleLen - args_->k + 1); 195 | if(num_tokens > LHS_SIZE) num_tokens = LHS_SIZE; 196 | int first = (num_tokens == LHS_SIZE) ? 
0 : (rand() % (LHS_SIZE - num_tokens + 1)); 197 | 198 | rslt.weight = example.weight; 199 | rslt.dataset = example.dataset; 200 | rslt.LHSTokens.clear(); 201 | rslt.RHSTokens.clear(); 202 | auto first_token = example.LHSTokens.begin() + first; 203 | rslt.LHSTokens.insert(rslt.LHSTokens.end(), first_token, first_token + num_tokens); 204 | 205 | if(args_->ngrams > 1){ 206 | int first_ngram = first_ngram_idx(LHS_SIZE, first, args_->ngrams), 207 | last_ngram = first_ngram_idx(LHS_SIZE, first + num_tokens - 2, args_->ngrams); 208 | if(last_ngram >= first_ngram){ 209 | rslt.LHSTokens.insert(rslt.LHSTokens.end(), ngrams.begin() + first_ngram, ngrams.begin() + last_ngram); 210 | } 211 | // std::cerr << num_tokens << "\t" << last_ngram - first_ngram + 1 << std::endl; 212 | } 213 | 214 | if (args_->trainMode == 0) { 215 | // lhs is the same, pick one random label as rhs 216 | assert(example.LHSTokens.size() > 0); 217 | assert(example.RHSTokens.size() > 0); 218 | auto idx = rand() % example.RHSTokens.size(); 219 | rslt.RHSTokens.push_back(example.RHSTokens[idx]); 220 | } else { 221 | assert(example.RHSTokens.size() > 1); 222 | if (args_->trainMode == 1) { 223 | // pick one random label as rhs and the rest is lhs 224 | auto idx = rand() % example.RHSTokens.size(); 225 | for (unsigned int i = 0; i < example.RHSTokens.size(); i++) { 226 | auto tok = example.RHSTokens[i]; 227 | if (i == idx) { 228 | rslt.RHSTokens.push_back(tok); 229 | } else { 230 | rslt.LHSTokens.push_back(tok); 231 | } 232 | } 233 | } else 234 | if (args_->trainMode == 2) { 235 | // pick one random label as lhs and the rest is rhs 236 | auto idx = rand() % example.RHSTokens.size(); 237 | for (unsigned int i = 0; i < example.RHSTokens.size(); i++) { 238 | auto tok = example.RHSTokens[i]; 239 | if (i == idx) { 240 | rslt.LHSTokens.push_back(tok); 241 | } else { 242 | rslt.RHSTokens.push_back(tok); 243 | } 244 | } 245 | } else 246 | if (args_->trainMode == 3) { 247 | // pick two random labels, one as lhs and the 
other as rhs 248 | auto idx = rand() % example.RHSTokens.size(); 249 | unsigned int idx2; 250 | do { 251 | idx2 = rand() % example.RHSTokens.size(); 252 | } while (idx2 == idx); 253 | rslt.LHSTokens.push_back(example.RHSTokens[idx]); 254 | rslt.RHSTokens.push_back(example.RHSTokens[idx2]); 255 | } else 256 | if (args_->trainMode == 4) { 257 | // the first one as lhs and the second one as rhs 258 | rslt.LHSTokens.push_back(example.RHSTokens[0]); 259 | rslt.RHSTokens.push_back(example.RHSTokens[1]); 260 | } 261 | } 262 | } 263 | 264 | void InternDataHandler::getWordExamples( 265 | const vector& doc, 266 | vector& rslts) const { 267 | 268 | rslts.clear(); 269 | for (int widx = 0; widx < (int)(doc.size()); widx++) { 270 | ParseResults rslt; 271 | rslt.LHSTokens.clear(); 272 | rslt.RHSTokens.clear(); 273 | rslt.RHSTokens.push_back(doc[widx]); 274 | for (unsigned int i = max(widx - args_->ws, 0); 275 | i < min(size_t(widx + args_->ws), doc.size()); i++) { 276 | if ((int)i != widx) { 277 | rslt.LHSTokens.push_back(doc[i]); 278 | } 279 | } 280 | rslt.weight = args_->wordWeight; 281 | rslts.emplace_back(rslt); 282 | } 283 | } 284 | 285 | void InternDataHandler::getWordExamples( 286 | int idx, 287 | vector& rslts) const { 288 | assert(idx >= 0 && idx < size_); 289 | idx = idx % args_->nPeaks_total; 290 | const auto& example = examples_[idx]; 291 | getWordExamples(example.LHSTokens, rslts); 292 | } 293 | 294 | void InternDataHandler::addExample(const ParseResults& example) { 295 | examples_.push_back(example); 296 | size_++; 297 | } 298 | 299 | void InternDataHandler::getExampleById(int32_t idx, ParseResults& rslt) const { 300 | assert(idx >= 0 && idx < size_); 301 | idx = idx % args_->nPeaks_total; 302 | convert(examples_[idx], ngrams_[idx], rslt); 303 | } 304 | 305 | void InternDataHandler::getNextExample(ParseResults& rslt) { 306 | // assert(args_->nPeaks_total > 0); 307 | idx_ = idx_ + 1; 308 | // go back to the beginning of the examples if we reach the end 309 | if (idx_ 
>= args_->nPeaks_total) { 310 | idx_ = idx_ - args_->nPeaks_total; 311 | } 312 | convert(examples_[idx_], ngrams_[idx_], rslt); 313 | } 314 | 315 | void InternDataHandler::getRandomExample(ParseResults& rslt) const { 316 | // assert(args_->nPeaks_total > 0); 317 | int32_t idx = rand() % args_->nPeaks_total; 318 | convert(examples_[idx], ngrams_[idx], rslt); 319 | } 320 | 321 | void InternDataHandler::getKRandomExamples(int K, vector& c) { 322 | auto kSamples = min(K, (int)args_->nPeaks_total); 323 | for (int i = 0; i < kSamples; i++) { 324 | ParseResults example; 325 | getRandomExample(example); 326 | c.push_back(example); 327 | } 328 | } 329 | 330 | void InternDataHandler::getNextKExamples(int K, vector& c) { 331 | auto kSamples = min(K, (int)args_->nPeaks_total); 332 | for (int i = 0; i < kSamples; i++) { 333 | idx_ = (idx_ + 1) % args_->nPeaks_total; 334 | ParseResults example; 335 | convert(examples_[idx_], ngrams_[idx_], example); 336 | c.push_back(example); 337 | } 338 | } 339 | 340 | void InternDataHandler::getRandomWord(vector& result) { 341 | result.push_back(word_negatives_[word_iter_]); 342 | word_iter_++; 343 | if (word_iter_ >= (int)word_negatives_.size()) { 344 | word_iter_ = 0; 345 | } 346 | } 347 | 348 | void InternDataHandler::initWordNegatives() { 349 | word_iter_ = 0; 350 | word_negatives_.clear(); 351 | // assert(args_->nPeaks_total > 0); 352 | for (int i = 0; i < MAX_WORD_NEGATIVES_SIZE; i++) { 353 | word_negatives_.emplace_back(genRandomWord()); 354 | } 355 | } 356 | 357 | Base InternDataHandler::genRandomWord() const { 358 | // assert(args_->nPeaks_total > 0); 359 | auto& ex = examples_[rand() % args_->nPeaks_total]; 360 | int r = rand() % ex.LHSTokens.size(); 361 | return ex.LHSTokens[r]; 362 | } 363 | 364 | // Randomly sample one example and randomly sample a label from this example 365 | // The result is usually used as negative samples in training 366 | void InternDataHandler::getRandomRHS(vector& results, unsigned dataset) const { 367 | 
// assert(args_->nPeaks_total > 0); 368 | results.clear(); 369 | 370 | unsigned pi = args_->first_peak_idx[dataset] + (rand() % args_->nPeaks_list[dataset]); 371 | // std::cerr << args_->first_peak_idx[dataset] << " " << pi << " - "; 372 | auto& ex = examples_[pi]; 373 | unsigned int r = rand() % ex.RHSTokens.size(); 374 | if (args_->trainMode == 2) { 375 | for (unsigned int i = 0; i < ex.RHSTokens.size(); i++) { 376 | if (i != r) { 377 | results.push_back(ex.RHSTokens[i]); 378 | } 379 | } 380 | } else { 381 | results.push_back(ex.RHSTokens[r]); 382 | } 383 | } 384 | 385 | void InternDataHandler::save(std::ostream& out) { 386 | out << "data size : " << args_->nPeaks_total << endl; 387 | for (auto& example : examples_) { 388 | out << "lhs : "; 389 | for (auto t : example.LHSTokens) {out << t.first << ':' << t.second << ' ';} 390 | out << endl; 391 | out << "rhs : "; 392 | for (auto t : example.RHSTokens) {out << t.first << ':' << t.second << ' ';} 393 | out << endl; 394 | } 395 | } 396 | 397 | } // unamespace starspace 398 | -------------------------------------------------------------------------------- /cpp/src/data.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 
6 | */ 7 | 8 | #pragma once 9 | 10 | #include "dict.h" 11 | #include "parser.h" 12 | #include "utils/utils.h" 13 | #include 14 | #include 15 | #include 16 | 17 | namespace starspace { 18 | 19 | class InternDataHandler { 20 | public: 21 | explicit InternDataHandler(std::shared_ptr args, std::shared_ptr dict, std::shared_ptr parser); 22 | 23 | virtual void loadFromFile(); 24 | 25 | virtual void convert(const ParseResults& example, const std::vector& ngrams, ParseResults& rslt) const; 26 | 27 | virtual void getRandomRHS(std::vector& results, unsigned dataset) const; 28 | 29 | virtual void save(std::ostream& out); 30 | 31 | virtual void getWordExamples(int idx, std::vector& rslt) const; 32 | 33 | void add_kmer_tokens(int pi, std::string seq){ 34 | std::vector tokens; 35 | for(unsigned i = 0; i < seq.length() - args_->k + 1; i++){ 36 | std::string kmer = "", rc_kmer = ""; 37 | bool N = false; 38 | for(unsigned j = i; (j < i + args_->k) && !N; j++){ 39 | switch(seq[j]){ 40 | case 'A': 41 | case 'a': 42 | kmer = kmer + "A"; 43 | rc_kmer = "T" + rc_kmer; 44 | break; 45 | case 'C': 46 | case 'c': 47 | kmer = kmer + "C"; 48 | rc_kmer = "G" + rc_kmer; 49 | break; 50 | case 'T': 51 | case 't': 52 | kmer = kmer + "T"; 53 | rc_kmer = "A" + rc_kmer; 54 | break; 55 | case 'G': 56 | case 'g': 57 | kmer = kmer + "G"; 58 | rc_kmer = "C" + rc_kmer; 59 | break; 60 | default: 61 | N = true; 62 | i = j; 63 | } 64 | } 65 | if(!N){ 66 | int32_t kmer_wid = dict_->getId(kmer), rc_kmer_wid = dict_->getId(rc_kmer); 67 | if(!(kmer_wid < 0)){ 68 | tokens.push_back(kmer); 69 | examples_[pi].LHSTokens.push_back(std::make_pair(kmer_wid, 1.0)); 70 | } else if(!(rc_kmer_wid < 0)){ 71 | tokens.push_back(rc_kmer); 72 | examples_[pi].LHSTokens.push_back(std::make_pair(rc_kmer_wid, 1.0)); 73 | } else { 74 | std::cerr << "Invalid DNA k-mer! 
\'" << kmer << "\'" << std::endl; 75 | continue; 76 | } 77 | } 78 | } 79 | 80 | assert(examples_[pi].LHSTokens.size() > 0); 81 | 82 | if(args_->ngrams > 1){ 83 | parser_->addNgrams(tokens, ngrams_[pi], args_->ngrams); 84 | // for(auto token: tokens) hashes_[pi].push_back(dict_->hash(token)); 85 | } 86 | } 87 | 88 | void getWordExamples( 89 | const std::vector& doc, 90 | std::vector& rslt) const; 91 | 92 | void addExample(const ParseResults& example); 93 | 94 | void addCellLabel(unsigned peak, unsigned cell){ 95 | int32_t wid = dict_->getId(args_->label + "C" + std::to_string(cell)); 96 | assert(wid >= 0); 97 | examples_[peak].RHSTokens.push_back(std::make_pair(wid, 1.0)); 98 | } 99 | 100 | void getExampleById(int32_t idx, ParseResults& rslt) const; 101 | 102 | void getNextExample(ParseResults& rslt); 103 | 104 | void getRandomExample(ParseResults& rslt) const; 105 | 106 | void getKRandomExamples(int K, std::vector& c); 107 | 108 | void getNextKExamples(int K, std::vector& c); 109 | 110 | size_t getSize() const { return size_; }; 111 | 112 | void errorOnZeroExample(const std::string& fileName); 113 | 114 | void initWordNegatives(); 115 | void getRandomWord(std::vector& result); 116 | 117 | 118 | protected: 119 | virtual Base genRandomWord() const; 120 | 121 | static const int32_t MAX_VOCAB_SIZE = 10000000; 122 | static const int32_t MAX_WORD_NEGATIVES_SIZE = 10000000; 123 | 124 | std::shared_ptr args_; 125 | std::shared_ptr dict_; 126 | std::shared_ptr parser_; 127 | std::vector examples_; 128 | std::vector> ngrams_; 129 | // std::vector> hashes_; 130 | 131 | int32_t idx_ = -1; 132 | int32_t size_ = 0; 133 | 134 | int32_t word_iter_; 135 | std::vector word_negatives_; 136 | }; 137 | 138 | } 139 | -------------------------------------------------------------------------------- /cpp/src/dict.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 
3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | #include "dict.h" 9 | #include "parser.h" 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | using namespace std; 19 | using namespace boost::iostreams; 20 | 21 | namespace starspace { 22 | 23 | const std::string Dictionary::EOS = ""; 24 | const uint32_t Dictionary::HASH_C = 116049371; 25 | 26 | Dictionary::Dictionary(shared_ptr args) : args_(args), 27 | hashToIndex_(MAX_VOCAB_SIZE, -1), size_(0), nwords_(0), nlabels_(0), 28 | ntokens_(0) 29 | { 30 | entryList_.clear(); 31 | } 32 | 33 | // hash trick from fastText 34 | uint32_t Dictionary::hash(const std::string& str) const { 35 | uint32_t h = 2166136261; 36 | for (size_t i = 0; i < str.size(); i++) { 37 | h = h ^ uint32_t(str[i]); 38 | h = h * 16777619; 39 | } 40 | return h; 41 | } 42 | 43 | int32_t Dictionary::find(const std::string& w) const { 44 | int32_t h = hash(w) % MAX_VOCAB_SIZE; 45 | while (hashToIndex_[h] != -1 && entryList_[hashToIndex_[h]].symbol != w) { 46 | h = (h + 1) % MAX_VOCAB_SIZE; 47 | } 48 | return h; 49 | } 50 | 51 | int32_t Dictionary::getId(const string& symbol) const { 52 | int32_t h = find(symbol); 53 | return hashToIndex_[h]; 54 | } 55 | 56 | const std::string& Dictionary::getSymbol(int32_t id) const { 57 | assert(id >= 0); 58 | assert(id < size_); 59 | return entryList_[id].symbol; 60 | } 61 | 62 | const std::string& Dictionary::getLabel(int32_t lid) const { 63 | assert(lid >= 0); 64 | assert(lid < nlabels_); 65 | return entryList_[lid + nwords_].symbol; 66 | } 67 | 68 | entry_type Dictionary::getType(int32_t id) const { 69 | assert(id >= 0); 70 | assert(id < size_); 71 | return entryList_[id].type; 72 | } 73 | 74 | entry_type Dictionary::getType(const string& w) const { 75 | return (w.find(args_->label) == 0)? 
entry_type::label : entry_type::word; 76 | } 77 | 78 | void Dictionary::insert(const string& symbol) { 79 | int32_t h = find(symbol); 80 | ntokens_++; 81 | if (hashToIndex_[h] == -1) { 82 | entry e; 83 | e.symbol = symbol; 84 | e.count = 1; 85 | e.type = getType(symbol); 86 | entryList_.push_back(e); 87 | hashToIndex_[h] = size_++; 88 | } else { 89 | entryList_[hashToIndex_[h]].count++; 90 | } 91 | } 92 | 93 | void Dictionary::save(std::ostream& out) const { 94 | out.write((char*) &size_, sizeof(int32_t)); 95 | out.write((char*) &nwords_, sizeof(int32_t)); 96 | out.write((char*) &nlabels_, sizeof(int32_t)); 97 | out.write((char*) &ntokens_, sizeof(int64_t)); 98 | for (int32_t i = 0; i < size_; i++) { 99 | entry e = entryList_[i]; 100 | out.write(e.symbol.data(), e.symbol.size() * sizeof(char)); 101 | out.put(0); 102 | out.write((char*) &(e.count), sizeof(int64_t)); 103 | out.write((char*) &(e.type), sizeof(entry_type)); 104 | } 105 | } 106 | 107 | void Dictionary::load(std::istream& in) { 108 | entryList_.clear(); 109 | std::fill(hashToIndex_.begin(), hashToIndex_.end(), -1); 110 | in.read((char*) &size_, sizeof(int32_t)); 111 | in.read((char*) &nwords_, sizeof(int32_t)); 112 | in.read((char*) &nlabels_, sizeof(int32_t)); 113 | in.read((char*) &ntokens_, sizeof(int64_t)); 114 | for (int32_t i = 0; i < size_; i++) { 115 | char c; 116 | entry e; 117 | while ((c = in.get()) != 0) { 118 | e.symbol.push_back(c); 119 | } 120 | in.read((char*) &e.count, sizeof(int64_t)); 121 | in.read((char*) &e.type, sizeof(entry_type)); 122 | entryList_.push_back(e); 123 | hashToIndex_[find(e.symbol)] = i; 124 | } 125 | } 126 | 127 | /* Build dictionary from file. 128 | * In dictionary building process, if the current dictionary is at 75% capacity, 129 | * it automatically increases the threshold for both word and label. 130 | * At the end the -minCount and -minCountLabel from arguments will be applied 131 | * as thresholds. 
132 |  */
133 | void Dictionary::readFromFile(
134 |     const std::string& file,
135 |     shared_ptr parser) {
136 | 
137 |   int64_t minThreshold = 1;
138 |   size_t lines_read = 0;
139 |   // Tokenize each line with the parser and insert every token; past 75% of MAX_VOCAB_SIZE, prune with a raised threshold.
140 |   auto readFromInputStream = [&](std::istream& in) {
141 |     string line;
142 |     while (getline(in, line, '\n')) {
143 |       vector tokens;
144 |       parser->parseForDict(line, tokens);
145 |       lines_read++;
146 |       for (auto token : tokens) {
147 |         insert(token);
148 |         if ((ntokens_ % 1000000 == 0) && args_->verbose) {
149 |           std::cerr << "\rRead " << ntokens_ / 1000000 << "M words" << std::flush;
150 |         }
151 |         if (size_ > 0.75 * MAX_VOCAB_SIZE) {
152 |           minThreshold++;
153 |           threshold(minThreshold, minThreshold);
154 |         }
155 |       }
156 |     }
157 |   };
158 |   /* With COMPRESS_FILE and compressFile == gzip, parts file + "00.gz" .. (numGzFile - 1, two digits) are read; missing parts are skipped. */
159 | #ifdef COMPRESS_FILE
160 |   if (args_->compressFile == "gzip") {
161 |     cout << "Build dict from compressed input file.\n";
162 |     for (int i = 0; i < args_->numGzFile; i++) {
163 |       filtering_istream in;
164 |       auto str_idx = boost::str(boost::format("%02d") % i);
165 |       auto fname = file + str_idx + ".gz";
166 |       ifstream ifs(fname);
167 |       if (!ifs.good()) {
168 |         continue;
169 |       }
170 |       in.push(gzip_decompressor());
171 |       in.push(ifs);
172 |       readFromInputStream(in);
173 |       ifs.close();
174 |     }
175 |   } else {
176 |     cout << "Build dict from input file : " << file << endl;
177 |     ifstream fin(file);
178 |     if (!fin.is_open()) {
179 |       cerr << "Input file cannot be opened!" << endl;
180 |       exit(EXIT_FAILURE);
181 |     }
182 |     readFromInputStream(fin);
183 |     fin.close();
184 |   }
185 | #else
186 |   cout << "Build dict from input file : " << file << endl;
187 |   ifstream fin(file);
188 |   if (!fin.is_open()) {
189 |     cerr << "Input file cannot be opened!" << endl;
190 |     exit(EXIT_FAILURE);
191 |   }
192 |   readFromInputStream(fin);
193 |   fin.close();
194 | #endif
195 |   // Final pruning with the user-supplied -minCount / -minCountLabel thresholds.
196 |   threshold(args_->minCount, args_->minCountLabel);
197 | 
198 |   std::cerr << "\rRead " << ntokens_ / 1000000 << "M words" << std::endl;
199 |   std::cerr << "Number of words (" << args_->k << "-mers) in dictionary: " << nwords_ << std::endl;
200 |   std::cerr << "Number of labels in dictionary: " << nlabels_ << std::endl;
201 |   if (lines_read == 0) {
202 |     std::cerr << "ERROR: Empty file." << std::endl;
203 |     exit(EXIT_FAILURE);
204 |   }
205 |   if (size_ == 0) {
206 |     std::cerr << "Empty vocabulary. Try a smaller -minCount value."
207 |               << std::endl;
208 |     exit(EXIT_FAILURE);
209 |   }
210 | }
211 | // Recursively enumerate all 4^k k-mers over A/C/T/G; a k-mer is inserted as a word only if its reverse complement is not already present, so each complementary pair gets one entry. The defaults of len/seq/rc_seq are declared here, not in dict.h.
212 | void Dictionary::addKmers(unsigned k, int64_t &minThreshold, unsigned len = 0,
213 |                           string seq = "", string rc_seq = ""){
214 |   if(len == k){
215 |     int32_t h = find(rc_seq);
216 |     if(hashToIndex_[h] == -1){
217 |       insert(seq);
218 |       if(size_ > 0.75 * MAX_VOCAB_SIZE){
219 |         minThreshold++;
220 |         threshold(minThreshold, minThreshold);
221 |       }
222 |     }
223 |   } else {
224 |     addKmers(k, minThreshold, len + 1, seq + "A", "T" + rc_seq);
225 |     addKmers(k, minThreshold, len + 1, seq + "C", "G" + rc_seq);
226 |     addKmers(k, minThreshold, len + 1, seq + "T", "A" + rc_seq);
227 |     addKmers(k, minThreshold, len + 1, seq + "G", "C" + rc_seq);
228 |   }
229 | }
230 | // Insert one label per cell: label + "C1" .. label + "C<nCells>" (1-based).
231 | void Dictionary::addCells(unsigned nCells, int64_t &minThreshold, string label){
232 |   for (unsigned ci = 0; ci < nCells; ci ++) {
233 |     insert(label + "C" + to_string(ci + 1));
234 |     if (size_ > 0.75 * MAX_VOCAB_SIZE) {
235 |       minThreshold++;
236 |       threshold(minThreshold, minThreshold);
237 |     }
238 |   }
239 | }
240 | // Insert one label per batch: label + "B1" .. label + "B<nBatches>" (its call in CreateForCellSpace is commented out).
241 | void Dictionary::addBatches(unsigned nBatches, int64_t &minThreshold, string label){
242 |   for (unsigned bi = 0; bi < nBatches; bi ++) {
243 |     insert(label + "B" + to_string(bi + 1));
244 |     if (size_ > 0.75 * MAX_VOCAB_SIZE) {
245 |       minThreshold++;
246 |       threshold(minThreshold, minThreshold);
247 |     }
248 |   }
249 | }
250 | // Build the vocabulary from the arguments alone (no input file): every k-mer as a word and one label per cell. `parser` is not used here.
251 | void Dictionary::CreateForCellSpace(shared_ptr parser){
252 |   int64_t minThreshold = 1;
253 |   addKmers(args_->k, minThreshold); // Add k-mer IDs as words
254 |   addCells(args_->nCells_total, minThreshold, args_->label); // Add cell IDs as labels
255 |   // addPeaks(args_->nPeaks_total, minThreshold, args_->label); // Add peak IDs as labels
256 |   // if(args_->batchLabels) addBatches(args_->nBatches, minThreshold, args_->label); // Add batch IDs as labels
257 | 
258 |   threshold(args_->minCount, args_->minCountLabel);
259 | 
260 |   // std::cerr << "\rRead " << ntokens_ / 1000000 << "M words" << std::endl;
261 |   std::cout << "Number of words (" << args_->k << "-mers) in dictionary: " << nwords_ << std::endl
262 |             << "Number of labels in dictionary: " << nlabels_ << std::endl;
263 | }
264 | 
265 | // Sort the dictionary by [word, label] order and by number of occurrences.
266 | // Removes words / labels below their respective threshold (t for words, tl for labels), then rebuilds the index.
267 | void Dictionary::threshold(int64_t t, int64_t tl) {
268 |   sort(entryList_.begin(), entryList_.end(), [](const entry& e1, const entry& e2) {
269 |     if (e1.type != e2.type) return e1.type < e2.type;
270 |     return e1.count > e2.count;
271 |   });
272 |   entryList_.erase(remove_if(entryList_.begin(), entryList_.end(), [&](const entry& e) {
273 |     return (e.type == entry_type::word && e.count < t) ||
274 |            (e.type == entry_type::label && e.count < tl);
275 |   }), entryList_.end());
276 | 
277 |   entryList_.shrink_to_fit();
278 | 
279 |   computeCounts();
280 | }
281 | // Rebuild hashToIndex_ and the size / word / label counts from entryList_ (ids follow entryList_ order).
282 | void Dictionary::computeCounts() {
283 |   size_ = 0;
284 |   nwords_ = 0;
285 |   nlabels_ = 0;
286 |   std::fill(hashToIndex_.begin(), hashToIndex_.end(), -1);
287 |   for (auto it = entryList_.begin(); it != entryList_.end(); ++it) {
288 |     int32_t h = find(it->symbol);
289 |     hashToIndex_[h] = size_++;
290 |     if (it->type == entry_type::word) nwords_++;
291 |     if (it->type == entry_type::label) nlabels_++;
292 |   }
293 | }
294 | 
295 | // Given
a model saved in .tsv format, build the dictionary from model. 296 | void Dictionary::loadDictFromModel(const string& modelfile) { 297 | cout << "Loading dict from model file : " << modelfile << endl; 298 | ifstream fin(modelfile); 299 | string line; 300 | while (getline(fin, line)) { 301 | string symbol; 302 | stringstream ss(line); 303 | ss >> symbol; 304 | insert(symbol); 305 | } 306 | fin.close(); 307 | computeCounts(); 308 | 309 | std::cout << "Number of words in dictionary: " << nwords_ << std::endl; 310 | std::cout << "Number of labels in dictionary: " << nlabels_ << std::endl; 311 | } 312 | 313 | } // namespace 314 | -------------------------------------------------------------------------------- /cpp/src/dict.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | /** 9 | * The implementation of dictionary here is very similar to the dictionary used 10 | * in fastText (https://github.com/facebookresearch/fastText). 
11 | */ 12 | 13 | #pragma once 14 | 15 | #include "utils/args.h" 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | 25 | #ifdef COMPRESS_FILE 26 | #include 27 | #include 28 | #endif 29 | 30 | namespace starspace { 31 | 32 | class DataParser; 33 | 34 | enum class entry_type : int8_t {word=0, label=1}; 35 | 36 | struct entry { 37 | std::string symbol; 38 | int64_t count; 39 | entry_type type; 40 | }; 41 | 42 | class Dictionary { 43 | public: 44 | static const std::string EOS; 45 | static const uint32_t HASH_C; 46 | 47 | explicit Dictionary(std::shared_ptr); 48 | int32_t size() const { return size_; }; 49 | int32_t nwords() const { return nwords_; }; 50 | int32_t nlabels() const { return nlabels_; }; 51 | int32_t ntokens() const { return ntokens_; }; 52 | int32_t getId(const std::string&) const; 53 | entry_type getType(int32_t) const; 54 | entry_type getType(const std::string&) const; 55 | const std::string& getSymbol(int32_t) const; 56 | const std::string& getLabel(int32_t) const; 57 | 58 | uint32_t hash(const std::string& str) const; 59 | void insert(const std::string&); 60 | 61 | void load(std::istream&); 62 | void save(std::ostream&) const; 63 | void readFromFile(const std::string&, std::shared_ptr); 64 | bool readWord(std::istream&, std::string&) const; 65 | 66 | void threshold(int64_t, int64_t); 67 | void computeCounts(); 68 | void loadDictFromModel(const std::string& model); 69 | 70 | void addKmers(unsigned, int64_t&, unsigned, std::string, std::string); 71 | void addCells(unsigned, int64_t&, std::string); 72 | void addBatches(unsigned, int64_t&, std::string); 73 | void CreateForCellSpace(std::shared_ptr); 74 | 75 | private: 76 | static const int32_t MAX_VOCAB_SIZE = 30000000; 77 | 78 | int32_t find(const std::string&) const; 79 | 80 | void addNgrams( 81 | std::vector& line, 82 | const std::vector& hashes, 83 | int32_t n) const; 84 | 85 | std::shared_ptr args_; 86 | std::vector entryList_; 87 | 
std::vector hashToIndex_; 88 | 89 | int32_t size_; 90 | int32_t nwords_; 91 | int32_t nlabels_; 92 | int64_t ntokens_; 93 | }; 94 | 95 | } 96 | -------------------------------------------------------------------------------- /cpp/src/main.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | #include "starspace.h" 9 | #include 10 | #include 11 | 12 | using namespace std; 13 | using namespace starspace; 14 | 15 | int main(int argc, char** argv) { 16 | shared_ptr args = make_shared(); 17 | args->parseArgs(argc, argv); 18 | args->printArgs(); 19 | 20 | StarSpace sp(args); 21 | if (args->isTrain) { 22 | // if (!args->initModel.empty()) { 23 | // if (boost::algorithm::ends_with(args->initModel, ".tsv")) { 24 | // sp.initFromTsv(args->initModel); 25 | // } else { 26 | // sp.initFromSavedModel(args->initModel); 27 | // cout << "------Loaded model args:\n"; 28 | // args->printArgs(); 29 | // } 30 | // } else { 31 | sp.init(); 32 | // } 33 | sp.train(); 34 | // sp.saveModel(args->model); 35 | sp.saveModelTsv(args->model + ".tsv"); 36 | } 37 | // else { 38 | // if (boost::algorithm::ends_with(args->model, ".tsv")) { 39 | // sp.initFromTsv(args->model); 40 | // } else { 41 | // sp.initFromSavedModel(args->model); 42 | // cout << "------Loaded model args:\n"; 43 | // args->printArgs(); 44 | // } 45 | // sp.evaluate(); 46 | // } 47 | 48 | return 0; 49 | } 50 | -------------------------------------------------------------------------------- /cpp/src/matrix.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 
6 | */ 7 | 8 | /** 9 | * Mostly a collection of convenience routines around ublas. 10 | * We avoid doing any actual compute-intensive work in this file. 11 | */ 12 | 13 | #pragma once 14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | 23 | #include 24 | #include 25 | #include 26 | 27 | namespace starspace { 28 | 29 | struct MatrixDims { // Row (r) and column (c) extent of a Matrix. 30 | size_t r, c; 31 | size_t numElts() const { return r * c; } 32 | bool operator==(const MatrixDims& rhs) const { // const: comparing must not require a mutable left operand. 33 | return r == rhs.r && c == rhs.c; 34 | } 35 | }; 36 | 37 | template 38 | struct Matrix { // Thin wrapper around a ublas matrix; operator[] hands out a raw pointer to the start of row i. 39 | static const int kAlign = 64; 40 | boost::numeric::ublas::matrix matrix; 41 | 42 | explicit Matrix(MatrixDims dims, // Allocates dims.r x dims.c and, when sd > 0, fills it with N(0, sd) samples. 43 | Real sd = 1.0) : 44 | matrix(dims.r, dims.c) 45 | { 46 | assert(matrix.size1() == dims.r); 47 | assert(matrix.size2() == dims.c); 48 | if (sd > 0.0) { 49 | randomInit(sd); 50 | } 51 | } 52 | 53 | explicit Matrix(const std::vector>& init) { // Builds from ragged rows, zero-padding each row to the longest one. 54 | size_t rows = init.size(); 55 | size_t maxCols = 0; 56 | for (const auto& r : init) { 57 | maxCols = std::max(maxCols, r.size()); 58 | } 59 | alloc(rows, maxCols); 60 | for (size_t i = 0; i < numRows(); i++) { 61 | size_t j; 62 | for (j = 0; j < init[i].size(); j++) { 63 | (*this)[i][j] = init[i][j]; 64 | } 65 | for (; j < numCols(); j++) { 66 | (*this)[i][j] = 0.0; 67 | } 68 | } 69 | } 70 | 71 | explicit Matrix(std::istream& in) { // Deserializes via ublas operator>>. 72 | in >> matrix; 73 | } 74 | 75 | Matrix() { // Empty 0 x 0 matrix. 76 | alloc(0, 0); 77 | } 78 | 79 | Real* operator[](size_t i) { // Pointer to the first element of row i (asserts i < numRows()). 80 | assert(i >= 0); 81 | assert(i < numRows()); 82 | return &matrix(i, 0); 83 | } 84 | 85 | const Real* operator[](size_t i) const { 86 | assert(i >= 0); 87 | assert(i < numRows()); 88 | return &matrix(i, 0); 89 | } 90 | 91 | Real& cell(size_t i, size_t j) { // Bounds-asserted element access. 92 | assert(i >= 0); 93 | assert(i < numRows()); 94 | assert(j < numCols()); 95 | assert(j >= 0); 96 | return matrix(i, j); 97 | } 98 | 99 | void add(const Matrix& rhs, Real scale = 1.0) { // this += scale * rhs, element-wise. 100 | matrix += scale *
rhs.matrix; 101 | } 102 | 103 | void forEachCell(std::function l) { // Visits every cell, row by row. 104 | for (size_t i = 0; i < numRows(); i++) 105 | for (size_t j = 0; j < numCols(); j++) 106 | l(matrix(i, j)); 107 | } 108 | 109 | void forEachCell(std::function l) const { 110 | for (size_t i = 0; i < numRows(); i++) 111 | for (size_t j = 0; j < numCols(); j++) 112 | l(matrix(i, j)); 113 | } 114 | 115 | void forEachCell(std::function l) { // Same traversal, also passing (row, col). 116 | for (size_t i = 0; i < numRows(); i++) 117 | for (size_t j = 0; j < numCols(); j++) 118 | l(matrix(i, j), i, j); 119 | } 120 | 121 | void forEachCell(std::function l) const { 122 | for (size_t i = 0; i < numRows(); i++) 123 | for (size_t j = 0; j < numCols(); j++) 124 | l(matrix(i, j), i, j); 125 | } 126 | 127 | void sanityCheck() const { // Asserts that no cell is NaN or infinite; compiled out under NDEBUG. 128 | #ifndef NDEBUG 129 | forEachCell([&](Real r, size_t i, size_t j) { 130 | assert(!std::isnan(r)); 131 | assert(!std::isinf(r)); 132 | }); 133 | #endif 134 | } 135 | 136 | void forRow(size_t r, std::function l) { // Visits row r as l(cell, columnIndex). 137 | for (size_t j = 0; j < numCols(); j++) l(matrix(r, j), j); 138 | } 139 | 140 | void forRow(size_t r, std::function l) const { 141 | for (size_t j = 0; j < numCols(); j++) l(matrix(r, j), j); 142 | } 143 | 144 | void forCol(size_t c, std::function l) { // Visits column c as l(cell, rowIndex). 145 | for (size_t i = 0; i < numRows(); i++) l(matrix(i, c), i); 146 | } 147 | 148 | void forCol(size_t c, std::function l) const { // Must index (row i, column c) like the non-const overload; it previously read matrix(c, i). 149 | for (size_t i = 0; i < numRows(); i++) l(matrix(i, c), i); 150 | } 151 | 152 | static void mul(const Matrix& l, const Matrix& r, Matrix& dest) { // dest = l * r (matrix product). 153 | dest.matrix = boost::numeric::ublas::prod(l.matrix, r.matrix); 154 | } 155 | 156 | void updateRow(size_t r, Matrix& addend, Real scale = 1.0) { // row(r) += scale * addend, where addend is 1 x numCols(). 157 | using namespace boost::numeric::ublas; 158 | assert(addend.numRows() == 1); 159 | assert(addend.numCols() == numCols()); 160 | row(r) += Row { addend.matrix, 0 } * scale; 161 | } 162 | 163 | typedef boost::numeric::ublas::matrix_row> 164 | Row; 165 | Row row(size_t r) { return Row{ matrix, r }; } 166 | 167 | /* implicit */
operator Row() { 168 | assert(numRows() == 1); 169 | return Row{ matrix, 0 }; 170 | } 171 | 172 | size_t numElts() const { return numRows() * numCols(); } 173 | size_t numRows() const { return matrix.size1(); } 174 | size_t numCols() const { return matrix.size2(); } 175 | MatrixDims getDims() const { return { numRows(), numCols() }; } 176 | 177 | void reshape(MatrixDims dims) { // Reallocates (contents are not preserved) only when the shape actually changes. 178 | if (dims == getDims()) return; 179 | alloc(dims.r, dims.c); 180 | } 181 | 182 | typedef size_t iterator; // Iteration is over flat element indices [0, numElts()). 183 | iterator begin() { return 0; } 184 | iterator end() { return numElts(); } 185 | 186 | void write(std::ostream& out) { 187 | out << matrix; 188 | } 189 | 190 | void randomInit(Real sd = 1.0) { // Overwrites every cell with a sample from N(0, sd). 191 | if (numElts() > 0) { 192 | // Single-threaded fill from a default-seeded std::minstd_rand, so every 193 | // call yields the same sequence (deterministic initial weights). 194 | auto d = &matrix(0, 0); 195 | std::minstd_rand gen; 196 | auto nd = std::normal_distribution(0, sd); 197 | for (size_t i = 0; i < numElts(); i++) { 198 | d[i] = nd(gen); 199 | } 200 | } 201 | } 202 | 203 | private: 204 | void alloc(size_t r, size_t c) { // Replaces the storage with a freshly allocated r x c matrix. 205 | matrix = boost::numeric::ublas::matrix(r, c); 206 | } 207 | }; 208 | 209 | } 210 | -------------------------------------------------------------------------------- /cpp/src/model.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree.
6 | */ 7 | 8 | #pragma once 9 | 10 | #include "matrix.h" 11 | #include "proj.h" 12 | #include "dict.h" 13 | #include "utils/normalize.h" 14 | #include "utils/args.h" 15 | #include "data.h" 16 | // #include "doc_data.h" 17 | 18 | #include 19 | #include 20 | #include 21 | 22 | 23 | namespace starspace { 24 | 25 | typedef float Real; 26 | typedef boost::numeric::ublas::matrix_row::matrix)> 27 | MatrixRow; 28 | typedef boost::numeric::ublas::vector Vector; 29 | 30 | /* 31 | * The model is basically two lookup tables: one for left hand side 32 | * (LHS) entities, one for right hand side (RHS) entities. 33 | */ 34 | struct EmbedModel : public boost::noncopyable { 35 | public: 36 | explicit EmbedModel(std::shared_ptr args, 37 | std::shared_ptr dict); 38 | 39 | 40 | typedef std::vector Corpus; 41 | float train(std::shared_ptr data, // The returned value is reported by StarSpace::train as the epoch's train error. 42 | int numThreads, 43 | std::chrono::time_point t_start, 44 | int epochs_done, 45 | Real startRate, 46 | Real endRate, 47 | bool verbose = true); 48 | 49 | float test(std::shared_ptr data, int numThreads) { // Evaluation pass: reuses train() with epochs_done 0, both rates 0.0 and verbose off. 50 | return this->train(data, numThreads, 51 | std::chrono::high_resolution_clock::now(), 0, 52 | 0.0, 0.0, false); 53 | } 54 | 55 | float trainOneBatch(std::shared_ptr data, 56 | const std::vector& batch_exs, 57 | size_t negSearchLimits, 58 | Real rate, 59 | bool trainWord = false); 60 | 61 | float trainNLLBatch(std::shared_ptr data, 62 | const std::vector& batch_exs, 63 | int32_t negSearchLimit, 64 | Real rate, 65 | bool trainWord = false); 66 | 67 | void backward(const std::vector& batch_exs, 68 | const std::vector>>& negLabels, 69 | std::vector> gradW, 70 | std::vector> lhs, 71 | const std::vector& num_negs, 72 | Real rate_lhs, 73 | const std::vector& rate_rhsP, 74 | const std::vector>& nRate, 75 | std::vector dataset); 76 | 77 | // Querying 78 | std::vector> 79 | kNN(std::shared_ptr> lookup, 80 | Matrix point, 81 | int numSim); 82 | 83 | std::vector> 84 | findLHSLike(Matrix point, int numSim = 5) { // kNN over the LHS embedding table. 85 | return kNN(LHSEmbeddings_,
point, numSim); 86 | } 87 | 88 | std::vector> 89 | findRHSLike(Matrix point, int numSim = 5) { // kNN over the RHS embedding table. 90 | return kNN(RHSEmbeddings_, point, numSim); 91 | } 92 | 93 | Matrix projectRHS(const std::vector& ws); 94 | Matrix projectLHS(const std::vector& ws); 95 | 96 | void projectLHS(const std::vector& ws, Matrix& retval); 97 | void projectRHS(const std::vector& ws, Matrix& retval); 98 | 99 | void loadTsv(std::istream& in, const std::string sep = "\t "); 100 | void loadTsv(const char* fname, const std::string sep = "\t "); 101 | void loadTsv(const std::string& fname, const std::string sep = "\t ") { // Convenience overload; forwards to the const char* version. 102 | return loadTsv(fname.c_str(), sep); 103 | } 104 | void loadFeatEmb(const std::string& fname, const std::string sep = "\t"); 105 | 106 | void saveTsv(std::ostream& out, const char sep = '\t') const; 107 | 108 | void save(std::ostream& out) const; 109 | 110 | void load(std::ifstream& in); 111 | 112 | const std::string& lookupLHS(int32_t idx) const { // Name of LHS id idx, via Dictionary::getSymbol. 113 | return dict_->getSymbol(idx); 114 | } 115 | const std::string& lookupRHS(int32_t idx) const { // Name of RHS id idx, via Dictionary::getLabel (index convention defined there). 116 | return dict_->getLabel(idx); 117 | } 118 | 119 | void loadTsvLine(std::string& line, int lineNum, int cols, 120 | const std::string sep = "\t"); 121 | 122 | std::shared_ptr getDict() { return dict_; } 123 | 124 | std::shared_ptr>& getLHSEmbeddings() { 125 | return LHSEmbeddings_; 126 | } 127 | const std::shared_ptr>& getLHSEmbeddings() const { 128 | return LHSEmbeddings_; 129 | } 130 | std::shared_ptr>& getRHSEmbeddings() { 131 | return RHSEmbeddings_; 132 | } 133 | const std::shared_ptr>& getRHSEmbeddings() const { 134 | return RHSEmbeddings_; 135 | } 136 | 137 | void initModelWeights(); 138 | 139 | Real similarity(const MatrixRow& a, const MatrixRow& b); 140 | Real similarity(Matrix& a, Matrix& b) { // a and b must be single-row matrices (asRow asserts this). 141 | return similarity(asRow(a), asRow(b)); 142 | } 143 | 144 | static Real cosine(const MatrixRow& a, const MatrixRow& b); 145 | static Real cosine(Matrix& a, Matrix& b) { // Single-row overload; asRow asserts each argument has exactly one row. 146 | return cosine(asRow(a), asRow(b)); 147 | } 148 |
149 | static MatrixRow asRow(Matrix& m) { // Row-0 view of a 1 x n matrix; asserts numRows() == 1. 150 | assert(m.numRows() == 1); 151 | return MatrixRow(m.matrix, 0); 152 | } 153 | 154 | static void normalize(Matrix::Row row, double maxNorm = 1.0); 155 | static void normalize(Matrix& m) { normalize(asRow(m)); } // Applies normalize() to the single row of m with the default maxNorm (1.0). 156 | 157 | private: 158 | std::shared_ptr dict_; 159 | std::shared_ptr> LHSEmbeddings_; 160 | std::shared_ptr> RHSEmbeddings_; 161 | std::shared_ptr args_; 162 | 163 | std::vector LHSUpdates_; 164 | std::vector RHSUpdates_; 165 | // NOTE(review): both branches below set debug = false, so the ublas check() overload is a no-op even when NDEBUG is undefined -- confirm this is intentional. 166 | #ifdef NDEBUG 167 | static const bool debug = false; 168 | #else 169 | static const bool debug = false; 170 | #endif 171 | 172 | static void check(const Matrix& m) { // Delegates to Matrix::sanityCheck, which only asserts when NDEBUG is undefined. 173 | m.sanityCheck(); 174 | } 175 | 176 | static void check(const boost::numeric::ublas::matrix& m) { 177 | if (!debug) return; 178 | for (unsigned int i = 0; i < m.size1(); i++) { 179 | for (unsigned int j = 0; j < m.size2(); j++) { 180 | assert(!std::isnan(m(i, j))); 181 | assert(!std::isinf(m(i, j))); 182 | } 183 | } 184 | } 185 | 186 | }; 187 | 188 | } 189 | -------------------------------------------------------------------------------- /cpp/src/parser.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree.
6 | */ 7 | 8 | 9 | #include "parser.h" 10 | #include "utils/normalize.h" 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #include 17 | 18 | using namespace std; 19 | 20 | namespace starspace { 21 | 22 | void chomp(std::string& line, char toChomp = '\n') { // Removes a single trailing toChomp character, if present. 23 | auto sz = line.size(); 24 | if (sz >= 1 && line[sz - 1] == toChomp) { 25 | line.resize(sz - 1); 26 | } 27 | } 28 | 29 | DataParser::DataParser( 30 | shared_ptr dict, 31 | shared_ptr args) { 32 | dict_ = dict; 33 | args_ = args; 34 | } 35 | 36 | bool DataParser::parse( // Chomps s, splits it on any character of sep, and parses the resulting tokens. 37 | std::string& s, 38 | ParseResults& rslts, 39 | const string& sep) { 40 | 41 | chomp(s); 42 | vector toks; 43 | boost::split(toks, s, boost::is_any_of(string(sep))); 44 | 45 | return parse(toks, rslts); 46 | } 47 | 48 | void DataParser::parseForDict( // Appends dictionary tokens: weight suffix stripped, optionally normalized, __weight__ tokens dropped. 49 | string& line, 50 | vector& tokens, 51 | const string& sep) { 52 | 53 | chomp(line); 54 | vector toks; 55 | boost::split(toks, line, boost::is_any_of(string(sep))); 56 | for (unsigned int i = 0; i < toks.size(); i++) { 57 | string token = toks[i]; 58 | if (args_->useWeight) { 59 | std::size_t pos = toks[i].find(args_->weightSep); 60 | if (pos != std::string::npos) { 61 | token = toks[i].substr(0, pos); 62 | } 63 | } 64 | if (args_->normalizeText) { 65 | normalize_text(token); 66 | } 67 | if (token.find("__weight__") == std::string::npos) { 68 | tokens.push_back(token); 69 | } 70 | } 71 | } 72 | 73 | // check whether it is a valid example 74 | bool DataParser::check(const ParseResults& example) { 75 | if (args_->trainMode == 0) { 76 | // require lhs and rhs 77 | return !example.RHSTokens.empty() && !example.LHSTokens.empty(); 78 | } else if (args_->trainMode == 5) { 79 | // only requires lhs.
80 | return !example.LHSTokens.empty(); 81 | } else { 82 | // lhs is not required, but rhs should contain at least 2 examples 83 | return example.RHSTokens.size() > 1; 84 | } 85 | } 86 | 87 | void DataParser::addNgrams( // Appends hashed ids (weight 1.0) for n-grams of length 2..n over the word tokens; non-word tokens are skipped before pairing. 88 | const std::vector& tokens, 89 | std::vector& line, 90 | int n) 91 | { 92 | vector hashes; 93 | 94 | for (auto token: tokens) { 95 | entry_type type = dict_->getType(token); 96 | if (type == entry_type::word) { 97 | hashes.push_back(dict_->hash(token)); 98 | } 99 | } 100 | 101 | for (int32_t i = 0; i < (int32_t)(hashes.size()); i++){ 102 | uint64_t h = hashes[i]; 103 | for (int32_t j = i + 1; j < (int32_t)(hashes.size()) && j < i + n; j++){ 104 | h = h * Dictionary::HASH_C + hashes[j]; 105 | int64_t id = h % args_->bucket; 106 | line.push_back(make_pair(dict_->nwords() + dict_->nlabels() + id, 1.0)); 107 | // std::cerr << "i=" << i << "\tj=" << j << "\tngram=" << line.size() - 1 << std::endl; 108 | } 109 | } 110 | } 111 | 112 | bool DataParser::parse( // Splits known tokens into LHS (word) / RHS (label) ids; returns check(rslts). 113 | const std::vector& tokens, 114 | ParseResults& rslts) { 115 | 116 | for (auto &token: tokens) { 117 | if (token.find("__weight__") != std::string::npos) { 118 | std::size_t pos = token.find(args_->weightSep); 119 | if (pos != std::string::npos) { 120 | rslts.weight = atof(token.substr(pos + 1).c_str()); 121 | } 122 | continue; 123 | } 124 | string t = token; 125 | float weight = 1.0; 126 | if (args_->useWeight) { 127 | std::size_t pos = token.find(args_->weightSep); 128 | if (pos != std::string::npos) { 129 | t = token.substr(0, pos); 130 | weight = atof(token.substr(pos + 1).c_str()); 131 | } 132 | } 133 | 134 | if (args_->normalizeText) { 135 | normalize_text(t); 136 | } 137 | int32_t wid = dict_->getId(t); 138 | if (wid < 0) { 139 | continue; 140 | } 141 | 142 | entry_type type = dict_->getType(wid); 143 | if (type == entry_type::word) { 144 | rslts.LHSTokens.push_back(make_pair(wid, weight)); 145 | } 146 | if (type == entry_type::label) { 147 | rslts.RHSTokens.push_back(make_pair(wid,
weight)); 148 | } 149 | } 150 | 151 | if (args_->ngrams > 1) { // NOTE(review): n-grams are built from the raw tokens, i.e. before weight-suffix stripping and normalize_text -- confirm intended. 152 | addNgrams(tokens, rslts.LHSTokens, args_->ngrams); 153 | } 154 | 155 | return check(rslts); 156 | } 157 | 158 | bool DataParser::parse( // Flat variant: every known token id (any type) goes into rslts; true if rslts ends up non-empty. 159 | const std::vector& tokens, 160 | vector& rslts) { 161 | 162 | for (auto &token: tokens) { 163 | auto t = token; 164 | float weight = 1.0; 165 | if (args_->useWeight) { 166 | std::size_t pos = token.find(args_->weightSep); 167 | if (pos != std::string::npos) { 168 | t = token.substr(0, pos); 169 | weight = atof(token.substr(pos + 1).c_str()); 170 | } 171 | } 172 | 173 | if (args_->normalizeText) { 174 | normalize_text(t); 175 | } 176 | int32_t wid = dict_->getId(t); 177 | if (wid < 0) { 178 | continue; 179 | } 180 | 181 | //entry_type type = dict_->getType(wid); 182 | rslts.push_back(make_pair(wid, weight)); 183 | } 184 | 185 | if (args_->ngrams > 1) { 186 | addNgrams(tokens, rslts, args_->ngrams); 187 | } 188 | return rslts.size() > 0; 189 | } 190 | 191 | } // namespace starspace 192 | -------------------------------------------------------------------------------- /cpp/src/parser.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | 9 | /** 10 | * This is the basic class of data parsing. 11 | * It provides essential functions as follows: 12 | * - parse(input, output): 13 | * takes input as a line of string (or a vector of string tokens) 14 | * and returns one example containing l.h.s. features 15 | * and r.h.s. features. 16 | * 17 | * - parseForDict(input, tokens): 18 | * takes input as a line of string, output tokens to be added for building 19 | * the dictionary. 20 | * 21 | * - check(example): 22 | * checks whether the example is a valid example.
23 | * 24 | * - addNgrams(input, output): 25 | * add ngrams from input as output. 26 | * 27 | * One can write different parsers for data with different formats. 28 | */ 29 | 30 | #pragma once 31 | 32 | #include "dict.h" 33 | #include 34 | #include 35 | 36 | namespace starspace { 37 | 38 | typedef std::pair Base; // (token id, weight) 39 | 40 | struct ParseResults { 41 | unsigned dataset = 0; // default-initialized so a fresh ParseResults never carries an indeterminate value 42 | float weight = 1.0; // example weight; DataParser::parse overrides it from a __weight__ token 43 | std::vector LHSTokens; // DataParser::parse fills this with word (and n-gram) ids 44 | std::vector RHSTokens; // DataParser::parse fills this with label ids 45 | std::vector> RHSFeatures; 46 | }; 47 | 48 | typedef std::vector Corpus; 49 | 50 | class DataParser { // Turns text lines / token vectors into (id, weight) lists using the dictionary held in dict_. 51 | public: 52 | explicit DataParser( 53 | std::shared_ptr dict, 54 | std::shared_ptr args); 55 | virtual ~DataParser() = default; // polymorphic base: parse and parseForDict are virtual 56 | virtual bool parse( 57 | std::string& s, 58 | ParseResults& rslt, 59 | const std::string& sep="\t "); 60 | 61 | virtual void parseForDict( 62 | std::string& s, 63 | std::vector& tokens, 64 | const std::string& sep="\t "); 65 | 66 | bool parse( 67 | const std::vector& tokens, 68 | std::vector& rslt); 69 | 70 | bool parse( 71 | const std::vector& tokens, 72 | ParseResults& rslt); 73 | 74 | bool check(const ParseResults& example); 75 | 76 | void addNgrams( 77 | const std::vector& tokens, 78 | std::vector& line, 79 | int32_t n); 80 | 81 | std::shared_ptr getDict() { return dict_; } 82 | 83 | void resetDict(std::shared_ptr dict) { dict_ = dict; } 84 | 85 | protected: 86 | std::shared_ptr dict_; 87 | std::shared_ptr args_; 88 | }; 89 | 90 | } 91 | -------------------------------------------------------------------------------- /cpp/src/proj.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree.
6 | */ 7 | 8 | 9 | #include "proj.h" 10 | -------------------------------------------------------------------------------- /cpp/src/proj.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | // The SparseLinear class implements the lookup tables used in starspace model. 9 | 10 | #pragma once 11 | 12 | #include "matrix.h" 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | 21 | namespace starspace { 22 | 23 | template 24 | struct SparseLinear : public Matrix { 25 | explicit SparseLinear(MatrixDims dims, 26 | Real sd = 1.0) : Matrix(dims, sd) { } 27 | 28 | explicit SparseLinear(std::ifstream& in) : Matrix(in) { } 29 | 30 | void forward(int in, Matrix& mout) { 31 | using namespace boost::numeric::ublas; 32 | const auto c = this->numCols(); 33 | mout.matrix.resize(1, c); 34 | memcpy(&mout[0][0], &(*this)[in][0], c * sizeof(Real)); 35 | } 36 | 37 | void forward(const std::vector& in, Matrix& mout) { 38 | using namespace boost::numeric::ublas; 39 | const auto c = this->numCols(); 40 | mout.matrix = zero_matrix(1, c); 41 | auto outRow = mout.row(0); 42 | for (const auto& elt: in) { 43 | assert(elt < this->numRows()); 44 | outRow += this->row(elt); 45 | } 46 | } 47 | 48 | void forward(const std::vector>& in, 49 | Matrix &mout) { 50 | using namespace boost::numeric::ublas; 51 | const auto c = this->numCols(); 52 | mout.matrix = zero_matrix(1, c); 53 | auto outRow = mout.row(0); 54 | for (const auto& pair: in) { 55 | assert(pair.first < this->numRows()); 56 | outRow += this->row(pair.first) * pair.second; 57 | } 58 | } 59 | 60 | void backward(const std::vector& in, 61 | const Matrix& mb, const Real alpha) { 62 | // Just update this racily and in-place. 
63 | assert(mb.numRows() == 1); 64 | auto b = mb[0]; 65 | for (const auto& elt: in) { 66 | auto row = (*this)[elt]; 67 | for (int i = 0; i < this->numCols(); i++) { 68 | row[i] -= alpha * b[i]; 69 | } 70 | } 71 | } 72 | 73 | Real* allocOutput() { 74 | Real* retval; 75 | auto val = posix_memalign((void**)&retval, Matrix::kAlign, 76 | this->numCols() * sizeof(Real)); 77 | if (val != 0) { 78 | perror("could not allocate output"); 79 | throw this; 80 | } 81 | return retval; 82 | } 83 | }; 84 | 85 | } 86 | -------------------------------------------------------------------------------- /cpp/src/starspace.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | 9 | #include "starspace.h" 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | 16 | using namespace std; 17 | 18 | namespace starspace { 19 | 20 | StarSpace::StarSpace(shared_ptr args) 21 | : args_(args) 22 | , dict_(nullptr) 23 | , parser_(nullptr) 24 | , trainData_(nullptr) 25 | , validData_(nullptr) 26 | , testData_(nullptr) 27 | , model_(nullptr) 28 | {} 29 | 30 | void StarSpace::initParser() { 31 | // if (args_->fileFormat == "fastText") { 32 | parser_ = make_shared(dict_, args_); 33 | // } else if (args_->fileFormat == "labelDoc") { 34 | // parser_ = make_shared(dict_, args_); 35 | // } else { 36 | // cerr << "Unsupported file format. 
Currently support: fastText or labelDoc.\n"; 37 | // exit(EXIT_FAILURE); 38 | // } 39 | } 40 | 41 | void StarSpace::initDataHandler() { 42 | if (args_->isTrain) { 43 | trainData_ = initData(); 44 | trainData_->loadFromFile(); 45 | // set validation data 46 | if (!args_->validationFile.empty()) { 47 | validData_ = initData(); 48 | validData_->loadFromFile(); 49 | } 50 | } else { 51 | if (args_->testFile != "") { 52 | testData_ = initData(); 53 | testData_->loadFromFile(); 54 | } 55 | } 56 | } 57 | 58 | shared_ptr StarSpace::initData() { 59 | // if (args_->fileFormat == "fastText") { 60 | return make_shared(args_, dict_, parser_); 61 | // } else if (args_->fileFormat == "labelDoc") { 62 | // return make_shared(args_, dict_, parser_); 63 | // } else { 64 | // cerr << "Unsupported file format. Currently support: fastText or labelDoc.\n"; 65 | // exit(EXIT_FAILURE); 66 | // } 67 | return nullptr; 68 | } 69 | 70 | // initialize dict and load data 71 | void StarSpace::init() { 72 | cout << "\nStart to initialize starspace model.\n"; 73 | assert(args_ != nullptr); 74 | 75 | // build dict 76 | initParser(); 77 | dict_ = make_shared(args_); 78 | if(args_->isCellSpace) dict_->CreateForCellSpace(parser_); 79 | else dict_->readFromFile(args_->trainFile, parser_); 80 | parser_->resetDict(dict_); 81 | if (args_->debug) {dict_->save(cout);} 82 | 83 | // init train data class 84 | trainData_ = initData(); 85 | trainData_->loadFromFile(); 86 | 87 | // init model with args and dict 88 | model_ = make_shared(args_, dict_); 89 | // if(args_->feat_emb != "") model_->loadFeatEmb(args_->feat_emb); 90 | 91 | // set validation data 92 | if (!args_->validationFile.empty()) { 93 | validData_ = initData(); 94 | validData_->loadFromFile(); 95 | } 96 | } 97 | 98 | void StarSpace::initFromSavedModel(const string& filename) { 99 | cout << "Start to load a trained starspace model.\n"; 100 | std::ifstream in(filename, std::ifstream::binary); 101 | if (!in.is_open()) { 102 | std::cerr << "Model file 
cannot be opened for loading!" << std::endl; 103 | exit(EXIT_FAILURE); 104 | } 105 | string magic; 106 | char c; 107 | while ((c = in.get()) != 0) { 108 | magic.push_back(c); 109 | } 110 | cout << magic << endl; 111 | if (magic != kMagic) { 112 | std::cerr << "Magic signature does not match!" << std::endl; 113 | exit(EXIT_FAILURE); 114 | } 115 | // load args 116 | args_->load(in); 117 | 118 | // init and load dict 119 | dict_ = make_shared(args_); 120 | dict_->load(in); 121 | 122 | // init and load model 123 | model_ = make_shared(args_, dict_); 124 | model_->load(in); 125 | cout << "Model loaded.\n"; 126 | 127 | // init data parser 128 | initParser(); 129 | initDataHandler(); 130 | 131 | loadBaseDocs(); 132 | } 133 | 134 | void StarSpace::initFromTsv(const string& filename) { 135 | cout << "Start to load a trained embedding model in tsv format.\n"; 136 | assert(args_ != nullptr); 137 | ifstream in(filename); 138 | if (!in.is_open()) { 139 | std::cerr << "Model file cannot be opened for loading!" << std::endl; 140 | exit(EXIT_FAILURE); 141 | } 142 | // Test dimension of first line, adjust args appropriately 143 | // (This is also so we can load a TSV file without even specifying the dim.) 
144 | string line; 145 | getline(in, line); 146 | vector pieces; 147 | boost::split(pieces, line, boost::is_any_of("\t ")); 148 | int dim = pieces.size() - 1; 149 | if ((int)(args_->dim) != dim) { 150 | args_->dim = dim; 151 | cout << "Setting dim from Tsv file to: " << dim << endl; 152 | } 153 | in.close(); 154 | 155 | // build dict 156 | dict_ = make_shared(args_); 157 | dict_->loadDictFromModel(filename); 158 | if (args_->debug) {dict_->save(cout);} 159 | 160 | // load Model 161 | model_ = make_shared(args_, dict_); 162 | model_->loadTsv(filename, "\t "); 163 | 164 | // init data parser 165 | initParser(); 166 | initDataHandler(); 167 | } 168 | 169 | void StarSpace::train() { 170 | float rate = args_->lr; 171 | float decrPerEpoch = (rate - 1e-9) / args_->epoch; 172 | 173 | int impatience = 0; 174 | float best_valid_err = 1e9; 175 | auto t_start = std::chrono::high_resolution_clock::now(); 176 | for (int i = 0; i < args_->epoch; i++) { 177 | if (args_->saveEveryNEpochs > 0 && i > 0 && i % args_->saveEveryNEpochs == 0) { 178 | auto filename = args_->model + "_epoch" + std::to_string(i); 179 | // saveModel(filename); 180 | saveModelTsv(filename + ".tsv"); 181 | } 182 | cout << "Training epoch " << i << ": " << rate << ' ' << decrPerEpoch << endl; 183 | auto err = model_->train(trainData_, args_->thread, 184 | t_start, i, 185 | rate, rate - decrPerEpoch); 186 | printf("\n ---+++ %20s %4d Train error : %3.8f +++--- %c%c%c\n", 187 | "Epoch", i, err, 188 | 0xe2, 0x98, 0x83); 189 | if (validData_ != nullptr) { 190 | auto valid_err = model_->test(validData_, args_->thread); 191 | cout << "\nValidation error: " << valid_err << endl; 192 | if (valid_err > best_valid_err) { 193 | impatience += 1; 194 | if (impatience > args_->validationPatience) { 195 | cout << "Ran out of Patience! Early stopping based on validation set." 
<< endl; 196 | break; 197 | } 198 | } else { 199 | best_valid_err = valid_err; 200 | } 201 | } 202 | rate -= decrPerEpoch; 203 | 204 | auto t_end = std::chrono::high_resolution_clock::now(); 205 | auto tot_spent = std::chrono::duration(t_end-t_start).count(); 206 | if (tot_spent >args_->maxTrainTime) { 207 | cout << "MaxTrainTime exceeded." << endl; 208 | break; 209 | } 210 | } 211 | } 212 | 213 | void StarSpace::parseDoc( 214 | const string& line, 215 | vector& ids, 216 | const string& sep) { 217 | 218 | vector tokens; 219 | boost::split(tokens, line, boost::is_any_of(string(sep))); 220 | parser_->parse(tokens, ids); 221 | } 222 | 223 | Matrix StarSpace::getDocVector(const string& line, const string& sep) { 224 | vector ids; 225 | parseDoc(line, ids, sep); 226 | return model_->projectLHS(ids); 227 | } 228 | 229 | MatrixRow StarSpace::getNgramVector(const string& phrase) { 230 | vector tokens; 231 | boost::split(tokens, phrase, boost::is_any_of(string(" "))); 232 | if (tokens.size() > (unsigned int)(args_->ngrams)) { 233 | std::cerr << "Error! 
Input ngrams size is greater than model ngrams size.\n"; 234 | exit(EXIT_FAILURE); 235 | } 236 | if (tokens.size() == 1) { 237 | // looking up the entity embedding directly 238 | auto id = dict_->getId(tokens[0]); 239 | if (id != -1) { 240 | return model_->getLHSEmbeddings()->row(id); 241 | } 242 | } 243 | 244 | uint64_t h = 0; 245 | for (auto token: tokens) { 246 | if (dict_->getType(token) == entry_type::word) { 247 | h = h * Dictionary::HASH_C + dict_->hash(token); 248 | } 249 | } 250 | int64_t id = h % args_->bucket; 251 | return model_->getLHSEmbeddings()->row(id + dict_->nwords() + dict_->nlabels()); 252 | } 253 | 254 | void StarSpace::nearestNeighbor(const string& line, int k) { 255 | auto vec = getDocVector(line, " "); 256 | auto preds = model_->findLHSLike(vec, k); 257 | for (auto n : preds) { 258 | cout << dict_->getSymbol(n.first) << ' ' << n.second << endl; 259 | } 260 | } 261 | 262 | unordered_map StarSpace::predictTags(const string& line, int k){ 263 | args_->K = k; 264 | vector query_vec; 265 | parseDoc(line, query_vec, " "); 266 | 267 | vector predictions; 268 | predictOne(query_vec, predictions); 269 | 270 | unordered_map umap; 271 | 272 | for (int i = 0; i < predictions.size(); i++) { 273 | string tmp = printDocStr(baseDocs_[predictions[i].second]); 274 | umap[ tmp ] = predictions[i].first; 275 | } 276 | return umap; 277 | } 278 | 279 | void StarSpace::loadBaseDocs() { 280 | if (args_->basedoc.empty()) { 281 | // if (args_->fileFormat == "labelDoc") { 282 | // std::cerr << "Must provide base labels when label is featured.\n"; 283 | // exit(EXIT_FAILURE); 284 | // } 285 | for (int i = 0; i < dict_->nlabels(); i++) { 286 | baseDocs_.push_back({ make_pair(i + dict_->nwords(), 1.0) }); 287 | baseDocVectors_.push_back( 288 | model_->projectRHS({ make_pair(i + dict_->nwords(), 1.0) }) 289 | ); 290 | } 291 | cout << "Predictions use " << dict_->nlabels() << " known labels." 
<< endl; 292 | } else { 293 | cout << "Loading base docs from file : " << args_->basedoc << endl; 294 | ifstream fin(args_->basedoc); 295 | if (!fin.is_open()) { 296 | std::cerr << "Base doc file cannot be opened for loading!" << std::endl; 297 | exit(EXIT_FAILURE); 298 | } 299 | string line; 300 | while (getline(fin, line)) { 301 | vector ids; 302 | parseDoc(line, ids, "\t "); 303 | baseDocs_.push_back(ids); 304 | auto docVec = model_->projectRHS(ids); 305 | baseDocVectors_.push_back(docVec); 306 | } 307 | fin.close(); 308 | if (baseDocVectors_.size() == 0) { 309 | std::cerr << "ERROR: basedoc file '" << args_->basedoc << "' is empty." << std::endl; 310 | exit(EXIT_FAILURE); 311 | } 312 | cout << "Finished loading " << baseDocVectors_.size() << " base docs.\n"; 313 | } 314 | } 315 | 316 | void StarSpace::predictOne( 317 | const vector& input, 318 | vector& pred) { 319 | auto lhsM = model_->projectLHS(input); 320 | std::priority_queue heap; 321 | for (unsigned int i = 0; i < baseDocVectors_.size(); i++) { 322 | auto cur_score = model_->similarity(lhsM, baseDocVectors_[i]); 323 | heap.push({ cur_score, i }); 324 | } 325 | // get the first K predictions 326 | int i = 0; 327 | while (i < args_->K && heap.size() > 0) { 328 | pred.push_back(heap.top()); 329 | heap.pop(); 330 | i++; 331 | } 332 | } 333 | 334 | Metrics StarSpace::evaluateOne( 335 | const vector& lhs, 336 | const vector& rhs, 337 | vector& pred, 338 | bool excludeLHS) { 339 | 340 | std::priority_queue heap; 341 | 342 | auto lhsM = model_->projectLHS(lhs); 343 | auto rhsM = model_->projectRHS(rhs); 344 | // Our evaluation function currently assumes there is only one correct label. 345 | // TODO: generalize this to the multilabel case. 
346 | auto score = model_->similarity(lhsM, rhsM); 347 | 348 | int rank = 1; 349 | heap.push({ score, 0 }); 350 | 351 | for (unsigned int i = 0; i < baseDocVectors_.size(); i++) { 352 | // in the case basedoc labels are not provided, all labels become basedoc, 353 | // and we skip the correct label for comparison. 354 | if ((args_->basedoc.empty()) && ((int)i == rhs[0].first - dict_->nwords())) { 355 | continue; 356 | } 357 | auto cur_score = model_->similarity(lhsM, baseDocVectors_[i]); 358 | if (cur_score > score) { 359 | rank++; 360 | } else if (cur_score == score) { 361 | float flip = (float) rand() / RAND_MAX; 362 | if (flip > 0.5) { 363 | rank++; 364 | } 365 | } 366 | heap.push({ cur_score, i + 1 }); 367 | } 368 | 369 | // get the first K predictions 370 | int i = 0; 371 | while (i < args_->K && heap.size() > 0) { 372 | Predictions heap_top = heap.top(); 373 | heap.pop(); 374 | 375 | bool keep = true; 376 | if(excludeLHS && (args_->basedoc.empty())) { 377 | int nwords = dict_->nwords(); 378 | auto it = std::find_if( lhs.begin(), lhs.end(), 379 | [&heap_top, &nwords](const Base& el){ return (el.first - nwords + 1) == heap_top.second;} ); 380 | keep = it == lhs.end(); 381 | } 382 | 383 | if(keep) { 384 | pred.push_back(heap_top); 385 | i++; 386 | } 387 | } 388 | 389 | Metrics s; 390 | s.clear(); 391 | s.update(rank); 392 | return s; 393 | } 394 | 395 | void StarSpace::printDoc(ostream& ofs, const vector& tokens) { 396 | for (auto t : tokens) { 397 | // skip ngram tokens 398 | if (t.first < dict_->size()) { 399 | ofs << dict_->getSymbol(t.first) << ' '; 400 | } 401 | } 402 | ofs << endl; 403 | } 404 | 405 | string StarSpace::printDocStr(const vector& tokens) { 406 | for (auto t : tokens) { 407 | if (t.first < dict_->size()) { 408 | return dict_->getSymbol(t.first); 409 | } 410 | } 411 | 412 | return "__label_unk"; 413 | } 414 | 415 | void StarSpace::evaluate() { 416 | // check that it is not in trainMode 5 417 | if (args_->trainMode == 5) { 418 | std::cerr << 
"Test is undefined in trainMode 5. Please use other trainMode for testing.\n"; 419 | exit(EXIT_FAILURE); 420 | } 421 | 422 | // set dropout probability to 0 in test case 423 | args_->dropoutLHS = 0.0; 424 | args_->dropoutRHS = 0.0; 425 | 426 | loadBaseDocs(); 427 | int N = testData_->getSize(); 428 | 429 | auto numThreads = args_->thread; 430 | vector threads; 431 | vector metrics(numThreads); 432 | vector> predictions(N); 433 | int numPerThread = ceil((float) N / numThreads); 434 | assert(numPerThread > 0); 435 | 436 | vector examples; 437 | testData_->getNextKExamples(N, examples); 438 | 439 | auto evalThread = [&] (int idx, int start, int end) { 440 | metrics[idx].clear(); 441 | for (int i = start; i < end; i++) { 442 | auto s = evaluateOne(examples[i].LHSTokens, examples[i].RHSTokens, predictions[i], args_->excludeLHS); 443 | metrics[idx].add(s); 444 | } 445 | }; 446 | 447 | for (int i = 0; i < numThreads; i++) { 448 | auto start = std::min(i * numPerThread, N); 449 | auto end = std::min(start + numPerThread, N); 450 | assert(end >= start); 451 | threads.emplace_back(thread([=] { 452 | evalThread(i, start, end); 453 | })); 454 | } 455 | for (auto& t : threads) t.join(); 456 | 457 | Metrics result; 458 | result.clear(); 459 | for (int i = 0; i < numThreads; i++) { 460 | if (args_->debug) { metrics[i].print(); } 461 | result.add(metrics[i]); 462 | } 463 | result.average(); 464 | result.print(); 465 | 466 | if (!args_->predictionFile.empty()) { 467 | // print out prediction results to file 468 | ofstream ofs(args_->predictionFile); 469 | for (int i = 0; i < N; i++) { 470 | ofs << "Example " << i << ":\nLHS:\n"; 471 | printDoc(ofs, examples[i].LHSTokens); 472 | ofs << "RHS: \n"; 473 | printDoc(ofs, examples[i].RHSTokens); 474 | ofs << "Predictions: \n"; 475 | for (auto pred : predictions[i]) { 476 | if (pred.second == 0) { 477 | ofs << "(++) [" << pred.first << "]\t"; 478 | printDoc(ofs, examples[i].RHSTokens); 479 | } else { 480 | ofs << "(--) [" << pred.first << 
"]\t"; 481 | printDoc(ofs, baseDocs_[pred.second - 1]); 482 | } 483 | } 484 | ofs << "\n"; 485 | } 486 | ofs.close(); 487 | } 488 | } 489 | 490 | void StarSpace::saveModel(const string& filename) { 491 | cout << "Saving model to file : " << filename << endl; 492 | std::ofstream ofs(filename, std::ofstream::binary); 493 | if (!ofs.is_open()) { 494 | std::cerr << "Model file cannot be opened for saving!" << std::endl; 495 | exit(EXIT_FAILURE); 496 | } 497 | // sign model 498 | ofs.write(kMagic.data(), kMagic.size() * sizeof(char)); 499 | ofs.put(0); 500 | args_->save(ofs); 501 | dict_->save(ofs); 502 | model_->save(ofs); 503 | ofs.close(); 504 | } 505 | 506 | void StarSpace::saveModelTsv(const string& filename) { 507 | cout << "Saving model in tsv format : " << filename << endl; 508 | ofstream fout(filename); 509 | model_->saveTsv(fout, '\t'); 510 | fout.close(); 511 | } 512 | 513 | } // starspace 514 | -------------------------------------------------------------------------------- /cpp/src/starspace.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 
6 | */ 7 | 8 | #pragma once 9 | 10 | #include "utils/args.h" 11 | #include "dict.h" 12 | #include "matrix.h" 13 | #include "parser.h" 14 | // #include "doc_parser.h" 15 | #include "model.h" 16 | #include "utils/utils.h" 17 | 18 | namespace starspace { 19 | 20 | typedef std::pair Predictions; /* (score, index) pairs. predictOne() stores the position in baseDocVectors_ as the index; evaluateOne() uses index 0 for the gold RHS and i + 1 for base doc i. NOTE(review): the pair's template arguments are missing from this listing. */ 21 | 22 | class StarSpace { /* Training/evaluation driver: holds the run arguments (args_) plus shared_ptr handles to the dictionary, parser, train/validation/test data and the embedding model. */ 23 | public: 24 | explicit StarSpace(std::shared_ptr args); 25 | 26 | void init(); 27 | void initFromTsv(const std::string& filename); 28 | void initFromSavedModel(const std::string& filename); 29 | 30 | void train(); 31 | void evaluate(); 32 | 33 | MatrixRow getNgramVector(const std::string& phrase); 34 | Matrix getDocVector( 35 | const std::string& line, 36 | const std::string& sep = " \t"); 37 | void parseDoc( 38 | const std::string& line, 39 | std::vector& ids, 40 | const std::string& sep); 41 | 42 | void nearestNeighbor(const std::string& line, int k); 43 | 44 | 45 | std::unordered_map predictTags(const std::string& line, int k); 46 | std::string printDocStr(const std::vector& tokens); 47 | 48 | void saveModel(const std::string& filename); 49 | void saveModelTsv(const std::string& filename); 50 | void printDoc(std::ostream& ofs, const std::vector& tokens); 51 | 52 | const std::string kMagic = "STARSPACE-2018-2"; /* signature written at the start of every binary model file by saveModel() */ 53 | 54 | 55 | void loadBaseDocs(); /* populates baseDocs_ and baseDocVectors_; when reading args_->basedoc it exits if the file cannot be opened or yields no docs */ 56 | 57 | void predictOne( 58 | const std::vector& input, 59 | std::vector& 
pred); /* projects input with projectLHS, scores it against every vector in baseDocVectors_ and appends up to args_->K best (score, index) pairs to pred */ 60 | 61 | std::shared_ptr args_; 62 | std::vector> baseDocs_; 63 | private: 64 | void initParser(); 65 | void initDataHandler(); 66 | std::shared_ptr initData(); 67 | Metrics evaluateOne( 68 | const std::vector& lhs, 69 | const std::vector& rhs, 70 | std::vector& pred, 71 | bool excludeLHS); /* ranks the gold RHS against all base docs (ties broken randomly) and returns Metrics for that single rank; pred receives up to args_->K predictions, skipping LHS labels when excludeLHS is set and no basedoc file is used */ 72 | 73 | std::shared_ptr dict_; 74 | std::shared_ptr parser_; 75 | std::shared_ptr trainData_; 76 | std::shared_ptr validData_; 77 | std::shared_ptr testData_; 78 | std::shared_ptr model_; 79 | 80 | std::vector> baseDocVectors_; /* RHS projections, index-aligned with baseDocs_ (evaluate() prints baseDocs_[pred.second - 1] for a base-doc hit) */ 81 | }; 82 | 83 | } 84 | --------------------------------------------------------------------------------
/cpp/src/utils/args.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | #include "args.h" 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | using namespace std; 18 | 19 | namespace starspace { 20 | 21 | Args::Args() { 22 | lr = 0.01; 23 | termLr = 1e-9; 24 | norm = 1.0; 25 | margin = 0.05; 26 | wordWeight = 0.5; 27 | initRandSd = 0.001; 28 | dropoutLHS = 0.0; 29 | dropoutRHS = 0.0; 30 | p = 0.5; 31 | ws = 5; 32 | maxTrainTime = 60*60*24*100; 33 | validationPatience = 10; 34 | thread = 10; 35 | maxNegSamples = 10; 36 | negSearchLimit = 50; 37 | minCount = 1; 38 | minCountLabel = 1; 39 | K = 5; 40 | batchSize = 5; 41 | verbose = false; 42 | debug = false; 43 | adagrad = true; 44 | normalizeText = false; 45 | trainMode = 0; 46 | // fileFormat = "fastText"; 47 | label = "__label__"; 48 | bucket = 2000000; 49 | isTrain = true; 50 | shareEmb = true; 51 | // saveEveryEpoch = false; 52 | saveEveryNEpochs = -1; 53 | // saveTempModel = false; 54 | useWeight = false; 55 | trainWord = false; 56 | excludeLHS = false; 57 | weightSep = ':'; 58 | numGzFile = 1; 59 | 60 | loss = "hinge"; 61 | similarity = "cosine"; 62 | ngrams = 3; 63 | dim = 30; 64 | epoch = 50; 65 | exmpPerPeak = 20; 66 | k = 8; 67 | sampleLen = 150; 68 | // fixedFeatEmb = false; 69 | // batchLabels = false; 70 | } 71 | 72 | bool Args::isTrue(string arg) { 73 | std::transform(arg.begin(), arg.end(), arg.begin(), 74 | [&](char c) { return tolower(c); } 75 | ); 76 | return (arg == "true" || arg == "1"); 77 | } 78 | 79 | bool valid_file(std::string path){ 80 | if(path.empty()){ 81 | cerr << "\'" << path << "\' cannot be opened!" 
<< endl; 82 | return false; 83 | } else { 84 | std::ifstream fin(path, std::ifstream::in); 85 | if (!fin.is_open()){ 86 | cerr << "\'" << path << "\' cannot be opened!" << endl; 87 | return false; 88 | } 89 | fin.close(); 90 | } 91 | return true; 92 | } 93 | 94 | inline bool check_header(string header){ 95 | if(header == "%%MatrixMarket matrix coordinate pattern general") return(true); 96 | else if(header == "%%MatrixMarket matrix coordinate integer general") return(false); 97 | else { 98 | cerr << "Unsupported matrix format!" << endl; 99 | exit(EXIT_FAILURE); 100 | } 101 | } 102 | 103 | void Args::parseArgs(int argc, char** argv) { 104 | string a0 = argv[0]; 105 | isCellSpace = a0.substr(a0.length() - 9, 9) == "CellSpace"; 106 | 107 | if (argc < 4) { 108 | // cerr << "Usage: need to specify whether it is train or test.\n"; 109 | printHelp(); 110 | exit(EXIT_FAILURE); 111 | } 112 | // if (strcmp(argv[1], "train") == 0) { 113 | // isTrain = true; 114 | // } else if (strcmp(argv[1], "test") == 0) { 115 | // isTrain = false; 116 | // if(isCellSpace){ 117 | // std::cerr << "CellSpace test has not been implemented yet!" << std::endl; 118 | // exit(EXIT_FAILURE); 119 | // } 120 | // } else if (strcmp(argv[1], "-h") == 0 || strcmp(argv[1], "-help") == 0) { 121 | // std::cerr << "Here is the help! Usage:" << std::endl; 122 | // printHelp(); 123 | // exit(EXIT_FAILURE); 124 | // } else { 125 | // cerr << "Usage: the first argument should be either train or test.\n"; 126 | // printHelp(); 127 | // exit(EXIT_FAILURE); 128 | // } 129 | int i = 1; 130 | while (i < argc) { 131 | if (argv[i][0] != '-') { 132 | cout << "Provided argument without a dash! Usage:" << endl; 133 | printHelp(); 134 | exit(EXIT_FAILURE); 135 | } 136 | 137 | // handling "--" 138 | if (strlen(argv[i]) >= 2 && argv[i][1] == '-') { 139 | argv[i] = argv[i] + 1; 140 | } 141 | 142 | if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "-help") == 0) { 143 | // std::cerr << "Here is the help! 
Usage:" << std::endl; 144 | printHelp(); 145 | exit(EXIT_FAILURE); 146 | // } else if (strcmp(argv[i], "-trainFile") == 0 && !isCellSpace) { 147 | // trainFile = string(argv[i + 1]); 148 | // } else if (strcmp(argv[i], "-validationFile") == 0 && !isCellSpace) { 149 | // validationFile = string(argv[i + 1]); 150 | // } else if (strcmp(argv[i], "-testFile") == 0 && !isCellSpace) { 151 | // testFile = string(argv[i + 1]); 152 | // } else if (strcmp(argv[i], "-predictionFile") == 0 && !isCellSpace) { 153 | // predictionFile = string(argv[i + 1]); 154 | // } else if (strcmp(argv[i], "-basedoc") == 0 && !isCellSpace) { 155 | // basedoc = string(argv[i + 1]); 156 | } else if (strcmp(argv[i], "-output") == 0) { 157 | model = string(argv[i + 1]); 158 | // } else if (strcmp(argv[i], "-initModel") == 0 && !isCellSpace) { 159 | // initModel = string(argv[i + 1]); 160 | // } else if (strcmp(argv[i], "-fileFormat") == 0 && !isCellSpace) { 161 | // fileFormat = string(argv[i + 1]); 162 | // } else if (strcmp(argv[i], "-compressFile") == 0 && !isCellSpace) { 163 | // compressFile = string(argv[i + 1]); 164 | // } else if (strcmp(argv[i], "-numGzFile") == 0) { 165 | // numGzFile = atoi(argv[i + 1]); 166 | } else if (strcmp(argv[i], "-label") == 0) { 167 | label = string(argv[i + 1]); 168 | // } else if (strcmp(argv[i], "-weightSep") == 0 && !isCellSpace) { 169 | // weightSep = argv[i + 1][0]; 170 | // } else if (strcmp(argv[i], "-loss") == 0 && !isCellSpace) { 171 | // loss = string(argv[i + 1]); 172 | // } else if (strcmp(argv[i], "-similarity") == 0) { 173 | // similarity = string(argv[i + 1]); 174 | } else if (strcmp(argv[i], "-lr") == 0) { 175 | lr = atof(argv[i + 1]); 176 | } else if (strcmp(argv[i], "-p") == 0) { 177 | p = atof(argv[i + 1]); 178 | // } else if (strcmp(argv[i], "-termLr") == 0) { 179 | // termLr = atof(argv[i + 1]); 180 | // } else if (strcmp(argv[i], "-norm") == 0) { 181 | // norm = atof(argv[i + 1]); 182 | } else if (strcmp(argv[i], "-margin") == 0) { 183 | 
margin = atof(argv[i + 1]); 184 | } else if (strcmp(argv[i], "-initRandSd") == 0) { 185 | initRandSd = atof(argv[i + 1]); 186 | // } else if (strcmp(argv[i], "-dropoutLHS") == 0) { 187 | // dropoutLHS = atof(argv[i + 1]); 188 | // } else if (strcmp(argv[i], "-dropoutRHS") == 0) { 189 | // dropoutRHS = atof(argv[i + 1]); 190 | // } else if (strcmp(argv[i], "-wordWeight") == 0) { 191 | // wordWeight = atof(argv[i + 1]); 192 | } else if (strcmp(argv[i], "-dim") == 0) { 193 | dim = atoi(argv[i + 1]); 194 | } else if (strcmp(argv[i], "-epoch") == 0) { 195 | epoch = atoi(argv[i + 1]); 196 | // } else if (strcmp(argv[i], "-ws") == 0) { 197 | // ws = atoi(argv[i + 1]); 198 | } else if (strcmp(argv[i], "-maxTrainTime") == 0) { 199 | maxTrainTime = atoi(argv[i + 1]); 200 | // } else if (strcmp(argv[i], "-validationPatience") == 0) { 201 | // validationPatience = atoi(argv[i + 1]); 202 | } else if (strcmp(argv[i], "-thread") == 0) { 203 | thread = atoi(argv[i + 1]); 204 | } else if (strcmp(argv[i], "-maxNegSamples") == 0) { 205 | maxNegSamples = atoi(argv[i + 1]); 206 | } else if (strcmp(argv[i], "-negSearchLimit") == 0) { 207 | negSearchLimit = atoi(argv[i + 1]); 208 | // } else if (strcmp(argv[i], "-minCount") == 0 && !isCellSpace) { 209 | // minCount = atoi(argv[i + 1]); 210 | // } else if (strcmp(argv[i], "-minCountLabel") == 0 && !isCellSpace) { 211 | // minCountLabel = atoi(argv[i + 1]); 212 | } else if (strcmp(argv[i], "-bucket") == 0) { 213 | bucket = atoi(argv[i + 1]); 214 | } else if (strcmp(argv[i], "-ngrams") == 0) { 215 | ngrams = atoi(argv[i + 1]); 216 | // } else if (strcmp(argv[i], "-K") == 0) { 217 | // K = atoi(argv[i + 1]); 218 | } else if (strcmp(argv[i], "-batchSize") == 0) { 219 | batchSize = atoi(argv[i + 1]); 220 | // } else if (strcmp(argv[i], "-trainMode") == 0 && !isCellSpace) { 221 | // trainMode = atoi(argv[i + 1]); 222 | // } else if (strcmp(argv[i], "-fixedFeatEmb") == 0 && isCellSpace) { 223 | // fixedFeatEmb = isTrue(string(argv[i + 1])); 224 
| // } else if (strcmp(argv[i], "-batchLabels") == 0 && isCellSpace) { 225 | // batchLabels = isTrue(string(argv[i + 1])); 226 | // } else if (strcmp(argv[i], "-verbose") == 0) { 227 | // verbose = isTrue(string(argv[i + 1])); 228 | // } else if (strcmp(argv[i], "-debug") == 0) { 229 | // debug = isTrue(string(argv[i + 1])); 230 | // } else if (strcmp(argv[i], "-adagrad") == 0) { 231 | // adagrad = isTrue(string(argv[i + 1])); 232 | // } else if (strcmp(argv[i], "-shareEmb") == 0) { 233 | // shareEmb = isTrue(string(argv[i + 1])); 234 | // } else if (strcmp(argv[i], "-normalizeText") == 0 && !isCellSpace) { 235 | // normalizeText = isTrue(string(argv[i + 1])); 236 | } else if (strcmp(argv[i], "-saveIntermediates") == 0) { 237 | if(string(argv[i + 1]) == "final") saveEveryNEpochs = -1; 238 | else saveEveryNEpochs = atoi(argv[i + 1]); 239 | if(saveEveryNEpochs < 1) saveEveryNEpochs = -1; 240 | // } else if (strcmp(argv[i], "-saveTempModel") == 0) { 241 | // saveTempModel = isTrue(string(argv[i + 1])); 242 | // } else if (strcmp(argv[i], "-useWeight") == 0 && !isCellSpace) { 243 | // useWeight = isTrue(string(argv[i + 1])); 244 | // } else if (strcmp(argv[i], "-trainWord") == 0 && !isCellSpace) { 245 | // trainWord = isTrue(string(argv[i + 1])); 246 | // } else if (strcmp(argv[i], "-excludeLHS") == 0) { 247 | // excludeLHS = isTrue(string(argv[i + 1])); 248 | } else if (strcmp(argv[i], "-cpMat") == 0 && isCellSpace) { 249 | assert(i + 1 < argc); 250 | while(i + 1 < argc && argv[i + 1][0] != '-') 251 | cp_matrix_list.push_back(string(argv[(i++) + 1])); 252 | i--; 253 | } else if (strcmp(argv[i], "-peaks") == 0 && isCellSpace) { 254 | assert(i + 1 < argc); 255 | while(i + 1 < argc && argv[i + 1][0] != '-') 256 | peaks_list.push_back(string(argv[(i++) + 1])); 257 | i--; 258 | // } else if (strcmp(argv[i], "-featEmb") == 0 && isCellSpace) { 259 | // feat_emb = string(argv[i + 1]); 260 | } else if (strcmp(argv[i], "-k") == 0 && isCellSpace) { 261 | k = atoi(argv[i + 1]); 
262 | } else if (strcmp(argv[i], "-sampleLen") == 0 && isCellSpace) { 263 | sampleLen = (strcmp(argv[i + 1], "given") == 0) ? -1 : atoi(argv[i + 1]); 264 | } else if (strcmp(argv[i], "-exmpPerPeak") == 0 && isCellSpace) { 265 | exmpPerPeak = atoi(argv[i + 1]); 266 | } else { 267 | cerr << "Unknown argument: " << argv[i] << std::endl; 268 | printHelp(); 269 | exit(EXIT_FAILURE); 270 | } 271 | i += 2; 272 | } 273 | 274 | if(isCellSpace){ 275 | trainMode = 0; 276 | minCount = 1; 277 | minCountLabel = 1; 278 | trainWord = false; 279 | useWeight = false; 280 | normalizeText = false; 281 | compressFile = ""; 282 | 283 | if(peaks_list.size() != cp_matrix_list.size()){ 284 | cerr << "The number of input files for peak sequences should match the number of input files for count matrices!" << endl; 285 | exit(EXIT_FAILURE); 286 | } 287 | nBatches = cp_matrix_list.size(); 288 | 289 | nCells_total = 0; nCells_list.clear(); 290 | nPeaks_total = 0; nPeaks_list.clear(); 291 | first_peak_idx.clear(); 292 | for(auto cp_matrix: cp_matrix_list){ 293 | unsigned nCells = 0, nPeaks = 0; 294 | if(!valid_file(cp_matrix)) exit(EXIT_FAILURE); 295 | else { 296 | string header, mat_format = cp_matrix.substr(cp_matrix.find_last_of(".") + 1); 297 | if(mat_format == "gz"){ 298 | #ifdef COMPRESS_FILE 299 | ifstream ifs2(cp_matrix); 300 | if(!ifs2.good()) exit(EXIT_FAILURE); 301 | filtering_istream fin; 302 | fin.push(gzip_decompressor()); 303 | fin.push(ifs2); 304 | getline(fin, header); 305 | bool bin = check_header(header); 306 | fin >> nCells >> nPeaks; 307 | ifs2.close(); 308 | #endif 309 | } else { 310 | std::ifstream fin(cp_matrix, std::ifstream::in); 311 | getline(fin, header); 312 | bool bin = check_header(header); 313 | fin >> nCells >> nPeaks; 314 | fin.close(); 315 | } 316 | if(nCells < 2 || nPeaks < 2){ 317 | cerr << "There must be at least two cells and two peaks in a dataset!" 
<< endl; 318 | exit(EXIT_FAILURE); 319 | } else { 320 | first_peak_idx.push_back(nPeaks_total); 321 | nCells_list.push_back(nCells); nCells_total += nCells; 322 | nPeaks_list.push_back(nPeaks); nPeaks_total += nPeaks; 323 | // cerr << nCells_list.size() << ") cells: " << nCells << ", peaks: " << nPeaks << endl; 324 | } 325 | } 326 | } 327 | 328 | for(auto peaks: peaks_list) 329 | if(!valid_file(peaks)) exit(EXIT_FAILURE); 330 | 331 | // if(feat_emb != "" && !valid_file(feat_emb)) exit(EXIT_FAILURE); 332 | 333 | if(k < 2 || k > 16){ 334 | cerr << "k must be >1 and <17" << endl; 335 | exit(EXIT_FAILURE); 336 | } 337 | if(sampleLen != -1 && sampleLen < k){ 338 | cerr << "sampleLen must be >=k" << endl; 339 | exit(EXIT_FAILURE); 340 | } 341 | if(exmpPerPeak < 1){ 342 | cerr << "exmpPerPeak must be >0" << endl; 343 | exit(EXIT_FAILURE); 344 | } 345 | } 346 | // else { 347 | // if (isTrain) { 348 | // if (model.empty()) { 349 | // cerr << "Empty output path." << endl; 350 | // printHelp(); 351 | // exit(EXIT_FAILURE); 352 | // } 353 | // } else { 354 | // if (testFile.empty() || model.empty()) { 355 | // cerr << "Empty test file or model path." 
<< endl; 356 | // printHelp(); 357 | // exit(EXIT_FAILURE); 358 | // } 359 | // } 360 | // } 361 | } 362 | 363 | void Args::printHelp() { 364 | if(isCellSpace){ 365 | cout << "\n" 366 | << "\"CellSpace ...\"\n" 367 | 368 | << "\nThe following arguments are mandatory:\n" 369 | << " -output prefix of the output\n" 370 | << " -cpMat sparse cell by peak/tile count matrix (.mtx)\n" 371 | << " -peaks multi-fasta file containing peak/tile DNA sequences with the order they appear in the corresponding count matrix (.fa)\n" 372 | // << " -featEmb optional embedding matrix to initialize k-mer embeddings from a previously trained CellSpace model (.tsv)\n" 373 | 374 | << "\nThe following arguments are optional:\n" 375 | << " -dim size of embedding vectors [default=" << dim << "]\n" 376 | << " -ngrams max length of k-mer ngram [default=" << ngrams << "]\n" 377 | << " -k k-mer length [default=" << k << "]\n" 378 | << " -sampleLen length of the sequences randomly sampled from the peak/tile DNA sequences (integer or 'given') [default=" << (sampleLen == -1 ? "'given'" : to_string(sampleLen)) << "]\n" 379 | << " -exmpPerPeak number of training examples per peak/tile [default=" << exmpPerPeak << "]\n" 380 | << " -epoch number of epochs [default=" << epoch << "]\n" 381 | // << " -fixedFeatEmb whether feature (k-mer) embeddings should be kept fixed [default=" << fixedFeatEmb << "]\n" 382 | // << " -batchLabels whether the batch (i.e. dataset) should be included as a label (i.e. RHS) in training [default=" << batchLabels << "]\n" 383 | << " -margin margin parameter in hinge loss [default=" << margin << "]\n" 384 | // << " -similarity takes value in [cosine, dot]. 
Whether to use cosine or dot product as similarity function in hinge loss [default=" << similarity << "]\n" 385 | << " -bucket number of buckets [default=" << bucket << "]\n" 386 | << " -label labels prefix [default='" << label << "']\n" 387 | << " -lr learning rate [default=" << lr << "]\n" 388 | << " -maxTrainTime max train time (seconds) [default=" << maxTrainTime << "]\n" 389 | << " -negSearchLimit number of negative labels sampled per dataset [default=" << negSearchLimit << "]\n" 390 | << " -maxNegSamples max number of negatives in a batch update [default=" << maxNegSamples << "]\n" 391 | << " -p the embedding of an entity equals the sum of its M feature embedding vectors devided by M^p [default=" << p << "]\n" 392 | // << " -adagrad whether to use adagrad in training [default=" << adagrad << "]\n" 393 | // << " -shareEmb whether to use the same embedding matrix for LHS and RHS [default=" << shareEmb << "]\n" 394 | // << " -dropoutLHS dropout probability for LHS features [default=" << dropoutLHS << "]\n" 395 | // << " -dropoutRHS dropout probability for RHS features [default=" << dropoutRHS << "]\n" 396 | << " -initRandSd initial values of embeddings are randomly generated from normal distribution with mean=0 and standard deviation=initRandSd [default=" << initRandSd << "]\n" 397 | << " -batchSize size of mini batch in training [default=" << batchSize << "]\n" 398 | // << " -verbose verbosity level [default=" << verbose << "]\n" 399 | // << " -debug whether it's in debug mode [default=" << debug << "]\n" 400 | << " -saveIntermediates save intermediate models or only the final model (integer or 'final') [default=" << (saveEveryNEpochs == -1 ? 
"'final'" : to_string(saveEveryNEpochs)) << "]\n" 401 | // << " -saveTempModel save intermediate models after each epoch with an unique name including epoch number [default=" << saveTempModel << "]\n" 402 | << " -thread number of threads [default=" << thread << "]\n\n"; 403 | } 404 | } 405 | 406 | void Args::printArgs() { 407 | cout << "CellSpace Arguments: \n" 408 | << "dim: " << dim << endl 409 | << "ngrams: " << ngrams << endl 410 | << "k: " << k << endl 411 | << "sampleLen: " << (sampleLen == -1 ? "'given'" : to_string(sampleLen)) << endl 412 | << "exmpPerPeak: " << exmpPerPeak << endl 413 | << "epoch: " << epoch << endl 414 | << "margin: " << margin << endl 415 | << "bucket: " << bucket << endl 416 | << "label: '" << label << "'" << endl 417 | << "lr: " << lr << endl 418 | << "maxTrainTime: " << maxTrainTime << endl 419 | << "negSearchLimit: " << negSearchLimit << endl 420 | << "maxNegSamples: " << maxNegSamples << endl 421 | << "p: " << p << endl 422 | << "initRandSd: " << initRandSd << endl 423 | << "batchSize: " << batchSize << endl 424 | << "saveIntermediates: " << (saveEveryNEpochs == -1 ? 
"'final'" : to_string(saveEveryNEpochs)) << endl 425 | << "thread: " << thread << endl 426 | << "-----------------" << endl; 427 | } 428 | 429 | void Args::save(std::ostream& out) { 430 | out.write((char*) &(dim), sizeof(int)); 431 | out.write((char*) &(epoch), sizeof(int)); 432 | out.write((char*) &(maxTrainTime), sizeof(int)); 433 | out.write((char*) &(minCount), sizeof(int)); 434 | out.write((char*) &(minCountLabel), sizeof(int)); 435 | out.write((char*) &(maxNegSamples), sizeof(int)); 436 | out.write((char*) &(negSearchLimit), sizeof(int)); 437 | out.write((char*) &(ngrams), sizeof(int)); 438 | out.write((char*) &(bucket), sizeof(int)); 439 | out.write((char*) &(trainMode), sizeof(int)); 440 | out.write((char*) &(shareEmb), sizeof(bool)); 441 | out.write((char*) &(useWeight), sizeof(bool)); 442 | out.write((char*) &(weightSep), sizeof(char)); 443 | size_t size; // = fileFormat.size(); 444 | out.write((char*) &(size), sizeof(size_t)); 445 | // out.write((char*) &(fileFormat[0]), size); 446 | size = similarity.size(); 447 | out.write((char*) &(size), sizeof(size_t)); 448 | out.write((char*) &(similarity[0]), size); 449 | out.write((char*) &(batchSize), sizeof(int)); 450 | } 451 | 452 | void Args::load(std::istream& in) { 453 | in.read((char*) &(dim), sizeof(int)); 454 | in.read((char*) &(epoch), sizeof(int)); 455 | // in.read((char*) &(maxTrainTime), sizeof(int)); 456 | in.read((char*) &(minCount), sizeof(int)); 457 | in.read((char*) &(minCountLabel), sizeof(int)); 458 | in.read((char*) &(maxNegSamples), sizeof(int)); 459 | in.read((char*) &(negSearchLimit), sizeof(int)); 460 | in.read((char*) &(ngrams), sizeof(int)); 461 | in.read((char*) &(bucket), sizeof(int)); 462 | in.read((char*) &(trainMode), sizeof(int)); 463 | // in.read((char*) &(shareEmb), sizeof(bool)); 464 | in.read((char*) &(useWeight), sizeof(bool)); 465 | in.read((char*) &(weightSep), sizeof(char)); 466 | size_t size; 467 | in.read((char*) &(size), sizeof(size_t)); 468 | // 
fileFormat.resize(size); 469 | // in.read((char*) &(fileFormat[0]), size); 470 | in.read((char*) &(size), sizeof(size_t)); 471 | similarity.resize(size); 472 | in.read((char*) &(similarity[0]), size); 473 | in.read((char*) &(batchSize), sizeof(int)); 474 | } 475 | 476 | } 477 | -------------------------------------------------------------------------------- /cpp/src/utils/args.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | #pragma once 9 | 10 | #include 11 | #include 12 | #include 13 | 14 | namespace starspace { 15 | 16 | class Args { 17 | public: 18 | Args(); 19 | std::string trainFile; 20 | std::string validationFile; 21 | std::string testFile; 22 | std::string predictionFile; 23 | std::string model; 24 | // std::string initModel; 25 | // std::string fileFormat; 26 | std::string compressFile; 27 | std::string label; 28 | std::string basedoc; 29 | std::string loss; 30 | std::string similarity; 31 | 32 | char weightSep; 33 | double lr; 34 | double termLr; 35 | double norm; 36 | double margin; 37 | double initRandSd; 38 | double p; 39 | double dropoutLHS; 40 | double dropoutRHS; 41 | double wordWeight; 42 | size_t dim; 43 | int epoch; 44 | int ws; 45 | int maxTrainTime; 46 | int validationPatience; 47 | int thread; 48 | int maxNegSamples; 49 | int negSearchLimit; 50 | int minCount; 51 | int minCountLabel; 52 | int bucket; 53 | int ngrams; 54 | int trainMode; 55 | int K; 56 | int batchSize; 57 | int numGzFile; 58 | bool verbose; 59 | bool debug; 60 | bool adagrad; 61 | bool isTrain; 62 | bool normalizeText; 63 | // bool saveEveryEpoch; 64 | int saveEveryNEpochs; 65 | // bool saveTempModel; 66 | bool shareEmb; 67 | bool useWeight; 68 | bool trainWord; 69 | bool excludeLHS; 70 | 71 | // CellSpace parameters: 72 | bool 
isCellSpace; 73 | // bool fixedFeatEmb; 74 | // bool batchLabels; 75 | int k; 76 | int sampleLen; 77 | int exmpPerPeak; 78 | std::vector cp_matrix_list; 79 | std::vector peaks_list; 80 | // std::string feat_emb; 81 | std::vector nCells_list, nPeaks_list, first_peak_idx; 82 | unsigned nCells_total, nPeaks_total, nBatches; 83 | 84 | void parseArgs(int, char**); 85 | void printHelp(); 86 | void printArgs(); 87 | void save(std::ostream& out); 88 | void load(std::istream& in); 89 | bool isTrue(std::string arg); 90 | }; 91 | 92 | } 93 | -------------------------------------------------------------------------------- /cpp/src/utils/normalize.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | #include "normalize.h" 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | namespace starspace { 16 | 17 | void normalize_text(std::string& str) { 18 | /* 19 | * We categorize longer strings into the following buckets: 20 | * 21 | * 1. All punctuation-and-numeric. Things in this bucket get 22 | * their numbers flattened, to prevent combinatorial explosions. 23 | * They might be specific numbers, prices, etc. 24 | * 25 | * 2. All letters: case-flattened. 26 | * 27 | * 3. Mixed letters and numbers: a product ID? Flatten case and leave 28 | * numbers alone. 29 | * 30 | * The case-normalization is state-machine-driven. 31 | */ 32 | bool allNumeric = true; 33 | bool containsDigits = false; 34 | 35 | for (char c: str) { 36 | assert(c); // don't shove binary data through this. 
37 | containsDigits |= isdigit(c); 38 | if (!isascii(c)) { 39 | allNumeric = false; 40 | continue; 41 | } 42 | if (!isalpha(c)) continue; 43 | allNumeric = false; 44 | } 45 | 46 | bool flattenCase = true; 47 | bool flattenNum = allNumeric && containsDigits; 48 | if (!flattenNum && !flattenCase) return; 49 | 50 | std::transform(str.begin(), str.end(), str.begin(), 51 | [&](char c) { 52 | if (flattenNum && isdigit(c)) return '0'; 53 | if (isalpha(c)) return char(tolower(c)); 54 | return c; 55 | }); 56 | } 57 | 58 | } 59 | -------------------------------------------------------------------------------- /cpp/src/utils/normalize.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | #pragma once 9 | 10 | #include 11 | 12 | namespace starspace { 13 | 14 | // In-place normalization of UTF-8 strings. 15 | extern void normalize_text(std::string& buf); 16 | 17 | } 18 | -------------------------------------------------------------------------------- /cpp/src/utils/utils.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | #include "utils.h" 9 | 10 | namespace starspace { 11 | namespace detail { thread_local int id; } 12 | } 13 | 14 | 15 | -------------------------------------------------------------------------------- /cpp/src/utils/utils.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Facebook, Inc. and its affiliates. 
3 | * 4 | * This source code is licensed under the MIT license found in the 5 | * LICENSE file in the root directory of this source tree. 6 | */ 7 | 8 | #pragma once 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | #ifdef COMPRESS_FILE 19 | #include 20 | #endif 21 | 22 | namespace starspace { 23 | 24 | struct Metrics { /* Ranking metrics: hit@1/10/20/50 counts, summed rank and example count. clear() zeroes everything, add() sums another Metrics, update() records one 1-based rank, average() turns the sums into means (no-op when count is 0). */ 25 | float hit1, hit10, hit20, hit50, rank; 26 | int32_t count; 27 | 28 | void clear() { 29 | hit1 = 0; 30 | hit10 = 0; 31 | hit20 = 0; 32 | hit50 = 0; 33 | rank = 0; 34 | count = 0; 35 | }; 36 | 37 | void add(const Metrics& b) { 38 | hit1 += b.hit1; 39 | hit10 += b.hit10; 40 | hit20 += b.hit20; 41 | hit50 += b.hit50; 42 | rank += b.rank; 43 | count += b.count; 44 | }; 45 | 46 | void average() { 47 | if (count == 0) { 48 | return ; 49 | } 50 | hit1 /= count; 51 | hit10 /= count; 52 | hit20 /= count; 53 | hit50 /= count; 54 | rank /= count; 55 | } 56 | 57 | void print() { 58 | std::cout << "Evaluation Metrics : \n" 59 | << "hit@1: " << hit1 60 | << " hit@10: " << hit10 61 | << " hit@20: " << hit20 62 | << " hit@50: " << hit50 63 | << " mean ranks : " << rank 64 | << " Total examples : " << count << "\n"; 65 | } 66 | 67 | void update(int cur_rank) { 68 | if (cur_rank == 1) { hit1++; } 69 | if (cur_rank <= 10) { hit10++; } 70 | if (cur_rank <= 20) { hit20++; } 71 | if (cur_rank <= 50) { hit50++; } 72 | rank += cur_rank; 73 | count++; 74 | } 75 | 76 | }; 77 | 78 | 79 | namespace detail { 80 | extern thread_local int id; 81 | } 82 | 83 | namespace { /* getThreadID() returns the worker index that foreach_line() / foreach_line_gz() store in detail::id for the calling thread */ 84 | inline int getThreadID() { 85 | return detail::id; 86 | } 87 | } 88 | 89 | namespace { 90 | template 91 | void reset(Stream& s, std::streampos pos) { 92 | s.clear(); 93 | s.seekg(pos, std::ios_base::beg); 94 | } 95 | 96 | template 97 | std::streampos tellg(Stream& s) { 98 | auto retval = s.tellg(); 99 | return retval; 100 | } 101 | } 102 | 103 | // Apply a closure pointwise to every line of a file. 104 | template
106 | void foreach_line(const String& fname, 107 | Lambda f, 108 | int numThreads = 1) { /* Splits fname into numThreads byte ranges whose inner boundaries are moved to line starts, then runs one worker thread per range calling f(line) with detail::id set to the range index. Throws runtime_error if fname cannot be opened. NOTE(review): numThreads < 1 is not guarded (0 starts no workers, so nothing is processed). */ 109 | using namespace std; 110 | 111 | auto filelen = [&](ifstream& f) { 112 | f.seekg(0, ios_base::end); 113 | return tellg(f); 114 | }; 115 | 116 | ifstream ifs(fname); 117 | if (!ifs.good()) { 118 | throw runtime_error(string("error opening ") + fname); 119 | } 120 | auto len = filelen(ifs); 121 | // partitions[i],partitions[i+1] will be the bytewise boundaries for the i'th 122 | // thread. 123 | std::vector partitions(numThreads + 1); 124 | partitions[0] = 0; 125 | partitions[numThreads] = len; 126 | 127 | // Seek to bytewise partition boundaries, and read one line forward. 128 | string unused; 129 | for (int i = 1; i < numThreads; i++) { 130 | reset(ifs, (len / numThreads) * i); 131 | getline(ifs, unused); 132 | partitions[i] = tellg(ifs); 133 | } 134 | 135 | // It's possible that the ranges in partitions overlap; consider, 136 | // e.g., a machine with 100 hardware threads and only 99 lines 137 | // in the file. In this case, we'll do some excess work, so we ask 138 | // that f() be idempotent. 139 | vector threads; 140 | for (int i = 0; i < numThreads; i++) { 141 | threads.emplace_back([i, f, &fname, &partitions] { 142 | detail::id = i; 143 | // Get our own seek pointer. 144 | ifstream ifs2(fname); 145 | ifs2.seekg(partitions[i]); 146 | string line; 147 | while (tellg(ifs2) < partitions[i + 1] && getline(ifs2, line)) { 148 | // We don't know the line number. Super-bummer.
149 | f(line); 150 | } 151 | }); 152 | } 153 | for (auto &t: threads) { 154 | t.join(); 155 | } 156 | } 157 | 158 | template 160 | void foreach_line_gz( 161 | const String& fname, 162 | int numFiles, 163 | Lambda f, 164 | int numThreads = 1) { /* Reads numFiles gzip shards named fname + two-digit index (00, 01, ...) + .gz, starting one thread per shard and calling f(line) for every line; shards that cannot be opened are skipped. numThreads only bounds the detail::id values (i % numThreads). The whole loop is compiled out unless COMPRESS_FILE is defined. */ 165 | 166 | using namespace std; 167 | using namespace boost::iostreams; 168 | 169 | vector threads; 170 | numThreads = std::min(numFiles, numThreads); /* NOTE(review): if numThreads is 0 here while numFiles > 0, i % numThreads below divides by zero -- consider clamping to at least 1 */ 171 | 172 | #ifdef COMPRESS_FILE 173 | for (int i = 0; i < numFiles; i++) { 174 | auto thread_id = i % numThreads; 175 | threads.emplace_back([thread_id, i, f, &fname] { 176 | detail::id = thread_id; 177 | auto fname_t = fname + boost::str(boost::format("%02d") % i) + ".gz"; 178 | ifstream ifs2(fname_t); /* NOTE(review): gzip input is binary; consider opening with ios_base::binary for platforms that translate newlines in text mode */ 179 | if (!ifs2.good()) { 180 | return; 181 | } 182 | 183 | cout << "Reading file from " << fname_t << endl; 184 | filtering_istream in; 185 | in.push(gzip_decompressor()); 186 | in.push(ifs2); 187 | std::string line; 188 | while (getline(in, line, '\n')) { 189 | f(line); 190 | } 191 | }); 192 | } 193 | for (auto &t: threads) { 194 | t.join(); 195 | } 196 | #endif 197 | } 198 | 199 | } // namespace 200 | -------------------------------------------------------------------------------- /man/CellSpace-class.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/CellSpace.R 3 | \docType{class} 4 | \name{CellSpace-class} 5 | \alias{CellSpace-class} 6 | \title{The CellSpace Class} 7 | \description{ 8 | The \code{CellSpace} class stores CellSpace embedding and 9 | related information needed for performing downstream analyses.
10 | } 11 | \section{Slots}{ 12 | 13 | \describe{ 14 | \item{\code{project}}{title of the project} 15 | 16 | \item{\code{emb.file}}{the .tsv output of CellSpace containing the embedding matrix for cells and k-mers} 17 | 18 | \item{\code{cell.emb}}{the embedding matrix for cells} 19 | 20 | \item{\code{kmer.emb}}{the embedding matrix for k-mers} 21 | 22 | \item{\code{motif.emb}}{the embedding matrix for TF motifs} 23 | 24 | \item{\code{meta.data}}{data frame containing meta-information about each cell} 25 | 26 | \item{\code{dim}}{the dimensions of the CellSpace embeddings} 27 | 28 | \item{\code{k}}{the length of DNA k-mers} 29 | 30 | \item{\code{similarity}}{the similarity function in hinge loss} 31 | 32 | \item{\code{p}}{the embedding of an entity equals the sum of its \code{M} feature embedding vectors divided by \code{M^p}} 33 | 34 | \item{\code{label}}{cell label prefix} 35 | 36 | \item{\code{neighbors}}{list containing nearest neighbor graphs} 37 | 38 | \item{\code{reductions}}{list containing dimensional reductions} 39 | 40 | \item{\code{misc}}{list containing miscellaneous objects} 41 | }} 42 | 43 | -------------------------------------------------------------------------------- /man/CellSpace.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/CellSpace.R 3 | \name{CellSpace} 4 | \alias{CellSpace} 5 | \title{CellSpace} 6 | \usage{ 7 | CellSpace( 8 | emb.file, 9 | cell.names = NULL, 10 | meta.data = NULL, 11 | project = NULL, 12 | similarity = "cosine", 13 | p = 0.5, 14 | label = "__label__" 15 | ) 16 | } 17 | \arguments{ 18 | \item{emb.file}{the .tsv output of CellSpace containing the embedding matrix for cells and k-mers} 19 | 20 | \item{cell.names}{vector of unique cell names} 21 | 22 | \item{meta.data}{a \code{data.frame} containing meta-information about each cell} 23 | 24 | \item{project}{title of the project} 25 | 26 | 
\item{similarity}{the similarity function in hinge loss} 27 | 28 | \item{p}{the embedding of an entity equals the sum of its \code{M} feature embedding vectors divided by \code{M^p}} 29 | 30 | \item{label}{cell label prefix} 31 | } 32 | \value{ 33 | a new \code{CellSpace} object 34 | } 35 | \description{ 36 | Generates an object from the \code{CellSpace} class. 37 | } 38 | -------------------------------------------------------------------------------- /man/DNA_sequence_embedding.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/CellSpace.R 3 | \name{DNA_sequence_embedding} 4 | \alias{DNA_sequence_embedding} 5 | \title{DNA_sequence_embedding} 6 | \usage{ 7 | DNA_sequence_embedding(object, seq) 8 | } 9 | \arguments{ 10 | \item{object}{a \code{CellSpace} object} 11 | 12 | \item{seq}{a DNA sequence} 13 | } 14 | \value{ 15 | a numerical vector containing the CellSpace embedding of \code{seq} 16 | } 17 | \description{ 18 | Maps a DNA sequence to the embedding space. 
19 | } 20 | -------------------------------------------------------------------------------- /man/add_motif_db.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/CellSpace.R 3 | \name{add_motif_db} 4 | \alias{add_motif_db} 5 | \title{add_motif_db} 6 | \usage{ 7 | add_motif_db(object, motif.db, db.name) 8 | } 9 | \arguments{ 10 | \item{object}{a \code{CellSpace} object} 11 | 12 | \item{motif.db}{\code{PFMatrixList} or \code{PWMatrixList}} 13 | 14 | \item{db.name}{the name of the transcription factor motif database} 15 | } 16 | \value{ 17 | a \code{CellSpace} object containing the motif embedding matrix, in the \code{motif.emb} slot, and the corresponding similarity Z-scores, in the \code{misc} slot 18 | } 19 | \description{ 20 | Computes the CellSpace embedding and activity scores of transcription factor motifs. 21 | } 22 | -------------------------------------------------------------------------------- /man/cosine_similarity.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/CellSpace.R 3 | \name{cosine_similarity} 4 | \alias{cosine_similarity} 5 | \title{cosine_similarity} 6 | \usage{ 7 | cosine_similarity(x, y = NULL) 8 | } 9 | \arguments{ 10 | \item{x}{an embedding matrix} 11 | 12 | \item{y}{an embedding matrix with compatible dimensions to \code{x}, or \code{NULL}, in which case \code{y=x}} 13 | } 14 | \value{ 15 | a matrix containing the cosine similarity between rows of \code{x} and \code{y} 16 | } 17 | \description{ 18 | Computes cosine similarity in the embedding space. 
19 | } 20 | -------------------------------------------------------------------------------- /man/docs/CellSpace.md: -------------------------------------------------------------------------------- 1 | # `CellSpace` 2 | 3 | ## Description 4 | 5 | Generates an object from the `CellSpace` class. 6 | 7 | ## Usage 8 | 9 | ``` r 10 | CellSpace( 11 | emb.file, 12 | cell.names = NULL, 13 | meta.data = NULL, 14 | project = NULL, 15 | similarity = "cosine", 16 | p = 0.5, 17 | label = "__label__" 18 | ) 19 | ``` 20 | 21 | ## Arguments 22 | 23 | | Argument | Description | 24 | |-----------------|-------------------------------------------------------| 25 | | `emb.file` | the .tsv output of CellSpace containing the embedding matrix for cells and *k*-mers | 26 | | `cell.names` | vector of unique cell names | 27 | | `meta.data` | a `data.frame` containing meta-information about each cell | 28 | | `project` | title of the project | 29 | | `similarity` | the similarity function in hinge loss | 30 | | `p` | the embedding of an entity equals the sum of its `M` feature embedding vectors divided by `M^p` | 31 | | `label` | cell label prefix | 32 | 33 | ## Value 34 | 35 | a new `CellSpace` object 36 | -------------------------------------------------------------------------------- /man/docs/DNA_sequence_embedding.md: -------------------------------------------------------------------------------- 1 | # `DNA_sequence_embedding` 2 | 3 | ## Description 4 | 5 | Maps a DNA sequence to the embedding space. 
6 | 7 | ## Usage 8 | 9 | ``` r 10 | DNA_sequence_embedding(object, seq) 11 | ``` 12 | 13 | ## Arguments 14 | 15 | | Argument | Description | 16 | |----------|----------------------| 17 | | `object` | a `CellSpace` object | 18 | | `seq` | a DNA sequence | 19 | 20 | ## Value 21 | 22 | a numerical vector containing the CellSpace embedding of `seq` 23 | -------------------------------------------------------------------------------- /man/docs/add_motif_db.md: -------------------------------------------------------------------------------- 1 | # `add_motif_db` 2 | 3 | ## Description 4 | 5 | Computes the CellSpace embedding and activity scores of transcription factor motifs. 6 | 7 | ## Usage 8 | 9 | ``` r 10 | add_motif_db(object, motif.db, db.name) 11 | ``` 12 | 13 | ## Arguments 14 | 15 | | Argument | Description | 16 | |------------|-----------------------------------------------------| 17 | | `object` | a `CellSpace` object | 18 | | `motif.db` | `PFMatrixList` or `PWMatrixList` | 19 | | `db.name` | the name of the transcription factor motif database | 20 | 21 | ## Value 22 | 23 | a `CellSpace` object containing the motif embedding matrix, in the `motif.emb` slot, and the corresponding similarity *Z*-scores, in the `misc` slot 24 | -------------------------------------------------------------------------------- /man/docs/cosine_similarity.md: -------------------------------------------------------------------------------- 1 | # `cosine_similarity` 2 | 3 | ## Description 4 | 5 | Computes cosine similarity in the embedding space. 
6 | 7 | ## Usage 8 | 9 | ``` r 10 | cosine_similarity(x, y = NULL) 11 | ``` 12 | 13 | ## Arguments 14 | 15 | | Argument | Description | 16 | |----------|---------------------------------------------------------------------------------------| 17 | | `x` | an embedding matrix | 18 | | `y` | an embedding matrix with compatible dimensions to `x`, or `NULL`, in which case `y=x` | 19 | 20 | ## Value 21 | 22 | a matrix containing the cosine similarity between rows of `x` and `y` 23 | -------------------------------------------------------------------------------- /man/docs/embedding_distance.md: -------------------------------------------------------------------------------- 1 | # `embedding_distance` 2 | 3 | ## Description 4 | 5 | Computes distance in the embedding space based on cosine similarity. 6 | 7 | ## Usage 8 | 9 | ``` r 10 | embedding_distance(x, y = NULL, distance = c("cosine", "angular")) 11 | ``` 12 | 13 | ## Arguments 14 | 15 | | Argument | Description | 16 | |------------|------------------------------------------------------------------------------------------| 17 | | `x` | an embedding matrix | 18 | | `y` | an embedding matrix with compatible dimensions to `x`, or `NULL`, in which case `y=x` | 19 | | `distance` | the distance metric, either 'cosine' or 'angular', to compute from the cosine similarity | 20 | 21 | ## Value 22 | 23 | a matrix containing the distance between rows of `x` and `y`, computed from their cosine similarity 24 | -------------------------------------------------------------------------------- /man/docs/find_clusters.md: -------------------------------------------------------------------------------- 1 | # `find_clusters` 2 | 3 | ## Description 4 | 5 | Finds clusters in a shared nearest neighbor graph built from the CellSpace embedding. 6 | 7 | ## Usage 8 | 9 | ``` r 10 | find_clusters(object, graph = "cells_snn", ...) 
11 | ``` 12 | 13 | ## Arguments 14 | 15 | | Argument | Description | 16 | |---------------------------------|---------------------------------------| 17 | | `object` | a `CellSpace` object | 18 | | `graph` | name of the shared nearest neighbor graph in the `neighbors` slot used to find clusters | 19 | | `...` | arguments passed to `Seurat::FindClusters` | 20 | 21 | ## Value 22 | 23 | a `CellSpace` object with the cell clusters added to the `meta.data` slot 24 | -------------------------------------------------------------------------------- /man/docs/find_neighbors.md: -------------------------------------------------------------------------------- 1 | # `find_neighbors` 2 | 3 | ## Description 4 | 5 | Builds a nearest neighbor graph and shared nearest neighbor graph from the CellSpace embedding. 6 | 7 | ## Usage 8 | 9 | ``` r 10 | find_neighbors( 11 | object, 12 | n.neighbors = 30, 13 | emb = object@cell.emb, 14 | emb.name = "cells", 15 | ... 16 | ) 17 | ``` 18 | 19 | ## Arguments 20 | 21 | | Argument | Description | 22 | |---------------|-----------------------------------------------------------------------| 23 | | `object` | a `CellSpace` object | 24 | | `n.neighbors` | the number of nearest neighbors for the KNN algorithm | 25 | | `emb` | the embedding matrix used to create the nearest neighbor graphs | 26 | | `emb.name` | prefix for the graph names that will be added to the `neighbors` slot | 27 | | `...` | arguments passed to `Seurat::FindNeighbors` | 28 | 29 | ## Value 30 | 31 | a `CellSpace` object containing nearest neighbor and shared nearest neighbor graphs in the `neighbors` slot 32 | -------------------------------------------------------------------------------- /man/docs/merge_small_clusters.md: -------------------------------------------------------------------------------- 1 | # `merge_small_clusters` 2 | 3 | ## Description 4 | 5 | Merges cells from small clusters with the nearest clusters. 
6 | 7 | ## Usage 8 | 9 | ``` r 10 | merge_small_clusters( 11 | object, 12 | clusters, 13 | min.cells = 10, 14 | graph = "cells_snn", 15 | seed = 1 16 | ) 17 | ``` 18 | 19 | ## Arguments 20 | 21 | | Argument | Description | 22 | |---------------------------------|---------------------------------------| 23 | | `object` | a `CellSpace` object | 24 | | `clusters` | a vector of cluster labels, or the name of a column in the `meta.data` slot containing cluster labels | 25 | | `min.cells` | any cluster with fewer cells than `min.cells` will be merged with the nearest cluster | 26 | | `graph` | a shared nearest neighbor graph, or the name of a graph in the `neighbors` slot, used to find clusters | 27 | 28 | ## Value 29 | 30 | new cluster labels 31 | -------------------------------------------------------------------------------- /man/docs/motif_embedding.md: -------------------------------------------------------------------------------- 1 | # `motif_embedding` 2 | 3 | ## Description 4 | 5 | Maps a motif to the embedding space. 6 | 7 | ## Usage 8 | 9 | ``` r 10 | motif_embedding(object, PWM) 11 | ``` 12 | 13 | ## Arguments 14 | 15 | | Argument | Description | 16 | |----------|--------------------------| 17 | | `object` | a `CellSpace` object | 18 | | `PWM` | `PFMatrix` or `PWMatrix` | 19 | 20 | ## Value 21 | 22 | a numerical vector containing the CellSpace embedding of the consensus sequence for `PWM` 23 | -------------------------------------------------------------------------------- /man/docs/run_UMAP.md: -------------------------------------------------------------------------------- 1 | # `run_UMAP` 2 | 3 | ## Description 4 | 5 | Computes a UMAP embedding from the CellSpace embedding. 6 | 7 | ## Usage 8 | 9 | ``` r 10 | run_UMAP(object, emb = object@cell.emb, graph = NULL, name = "cells_UMAP", ...) 
11 | ``` 12 | 13 | ## Arguments 14 | 15 | | Argument | Description | 16 | |---------------------------------|---------------------------------------| 17 | | `object` | a `CellSpace` object | 18 | | `emb` | the embedding matrix used to compute the UMAP embedding | 19 | | `graph` | name of the nearest neighbor graph in the `neighbors` slot used to compute the UMAP embedding | 20 | | `name` | name of the lower-dimensional embedding that will be added to the `reductions` slot | 21 | | `...` | arguments passed to `Seurat::RunUMAP` | 22 | 23 | ## Value 24 | 25 | a `CellSpace` object containing a UMAP embedding in the `reductions` slot 26 | -------------------------------------------------------------------------------- /man/embedding_distance.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/CellSpace.R 3 | \name{embedding_distance} 4 | \alias{embedding_distance} 5 | \title{embedding_distance} 6 | \usage{ 7 | embedding_distance(x, y = NULL, distance = c("cosine", "angular")) 8 | } 9 | \arguments{ 10 | \item{x}{an embedding matrix} 11 | 12 | \item{y}{an embedding matrix with compatible dimensions to \code{x}, or \code{NULL}, in which case \code{y=x}} 13 | 14 | \item{distance}{the distance metric, either 'cosine' or 'angular', to compute from the cosine similarity} 15 | } 16 | \value{ 17 | a matrix containing the distance between rows of \code{x} and \code{y}, computed from their cosine similarity 18 | } 19 | \description{ 20 | Computes distance in the embedding space based on cosine similarity. 
21 | } 22 | -------------------------------------------------------------------------------- /man/find_clusters.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/CellSpace.R 3 | \name{find_clusters} 4 | \alias{find_clusters} 5 | \title{find_clusters} 6 | \usage{ 7 | find_clusters(object, graph = "cells_snn", ...) 8 | } 9 | \arguments{ 10 | \item{object}{a \code{CellSpace} object} 11 | 12 | \item{graph}{name of the shared nearest neighbor graph in the \code{neighbors} slot used to find clusters} 13 | 14 | \item{...}{arguments passed to \code{Seurat::FindClusters}} 15 | } 16 | \value{ 17 | a \code{CellSpace} object with the cell clusters added to the \code{meta.data} slot 18 | } 19 | \description{ 20 | Finds clusters in a shared nearest neighbor graph built from the CellSpace embedding. 21 | } 22 | -------------------------------------------------------------------------------- /man/find_neighbors.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/CellSpace.R 3 | \name{find_neighbors} 4 | \alias{find_neighbors} 5 | \title{find_neighbors} 6 | \usage{ 7 | find_neighbors( 8 | object, 9 | n.neighbors = 30, 10 | emb = object@cell.emb, 11 | emb.name = "cells", 12 | ... 
13 | ) 14 | } 15 | \arguments{ 16 | \item{object}{a \code{CellSpace} object} 17 | 18 | \item{n.neighbors}{the number of nearest neighbors for the KNN algorithm} 19 | 20 | \item{emb}{the embedding matrix used to create the nearest neighbor graphs} 21 | 22 | \item{emb.name}{prefix for the graph names that will be added to the \code{neighbors} slot} 23 | 24 | \item{...}{arguments passed to \code{Seurat::FindNeighbors}} 25 | } 26 | \value{ 27 | a \code{CellSpace} object containing nearest neighbor and shared nearest neighbor graphs in the \code{neighbors} slot 28 | } 29 | \description{ 30 | Builds a nearest neighbor graph and shared nearest neighbor graph from the CellSpace embedding. 31 | } 32 | -------------------------------------------------------------------------------- /man/merge_small_clusters.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/CellSpace.R 3 | \name{merge_small_clusters} 4 | \alias{merge_small_clusters} 5 | \title{merge_small_clusters} 6 | \usage{ 7 | merge_small_clusters( 8 | object, 9 | clusters, 10 | min.cells = 10, 11 | graph = "cells_snn", 12 | seed = 1 13 | ) 14 | } 15 | \arguments{ 16 | \item{object}{a \code{CellSpace} object} 17 | 18 | \item{clusters}{a vector of cluster labels, or the name of a column in the \code{meta.data} slot containing cluster labels} 19 | 20 | \item{min.cells}{any cluster with fewer cells than \code{min.cells} will be merged with the nearest cluster} 21 | 22 | \item{graph}{a shared nearest neighbor graph, or the name of a graph in the \code{neighbors} slot, used to find clusters} 23 | } 24 | \value{ 25 | new cluster labels 26 | } 27 | \description{ 28 | Merges cells from small clusters with the nearest clusters. 
29 | } 30 | -------------------------------------------------------------------------------- /man/motif_embedding.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/CellSpace.R 3 | \name{motif_embedding} 4 | \alias{motif_embedding} 5 | \title{motif_embedding} 6 | \usage{ 7 | motif_embedding(object, PWM) 8 | } 9 | \arguments{ 10 | \item{object}{a \code{CellSpace} object} 11 | 12 | \item{PWM}{\code{PFMatrix} or \code{PWMatrix}} 13 | } 14 | \value{ 15 | a numerical vector containing the CellSpace embedding of the consensus sequence for \code{PWM} 16 | } 17 | \description{ 18 | Maps a motif to the embedding space. 19 | } 20 | -------------------------------------------------------------------------------- /man/run_UMAP.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/CellSpace.R 3 | \name{run_UMAP} 4 | \alias{run_UMAP} 5 | \title{run_UMAP} 6 | \usage{ 7 | run_UMAP(object, emb = object@cell.emb, graph = NULL, name = "cells_UMAP", ...) 8 | } 9 | \arguments{ 10 | \item{object}{a \code{CellSpace} object} 11 | 12 | \item{emb}{the embedding matrix used to compute the UMAP embedding} 13 | 14 | \item{graph}{name of the nearest neighbor graph in the \code{neighbors} slot used to compute the UMAP embedding} 15 | 16 | \item{name}{name of the lower-dimensional embedding that will be added to the \code{reductions} slot} 17 | 18 | \item{...}{arguments passed to \code{Seurat::RunUMAP}} 19 | } 20 | \value{ 21 | a \code{CellSpace} object containing a UMAP embedding in the \code{reductions} slot 22 | } 23 | \description{ 24 | Computes a UMAP embedding from the CellSpace embedding. 
25 | } 26 | -------------------------------------------------------------------------------- /tutorial/README.md: -------------------------------------------------------------------------------- 1 | ## Pre-processing scATAC-seq data 2 | 3 | Learning a CellSpace embedding requires a cell-by-event matrix, where events are either an atlas of accessible peaks or genomic tiles, as well as the genomic sequences of the events. 4 | 5 | If you want to follow the tutorial using our example dataset, 'Small-scale human hematopoiesis dataset' [[Buenrostro *et al.*, 2018](https://doi.org/10.1016/j.cell.2018.03.074)], you can proceed to the next section and use the files in [data/](data/) after decompressing them: 6 | 7 | ``` bash 8 | gunzip data/*.gz 9 | ``` 10 | 11 | ### Using genomic tiles 12 | 13 | The most convenient way of pre-processing your scATAC-seq data from scratch to train a CellSpace embedding is using [ArchR](https://www.archrproject.com) [[Granja *et al.*, 2021](https://doi.org/10.1038/s41588-021-00790-6)] and exporting its top variable tiles (500bp genomic bins); this way, you do not need to deal with various peak-calling strategies and their biases, or manually filter the peak atlas. 14 | 15 | The scripts used to pre-process the 'Small-scale human hematopoiesis dataset' from scratch, as described in [our manuscript](https://www.biorxiv.org/content/10.1101/2022.05.02.490310v4.full.pdf), are available [here](variable-tiles/README.md). Once you have an ArchR object containing the pre-processed scATAC-seq data (named `archr.obj` here), you can perform iterative LSI and prepare the input files to train a CellSpace embedding. 
16 | 17 | Load required libraries: 18 | 19 | ``` r 20 | library(ArchR) 21 | library(Matrix) 22 | library(dplyr) 23 | library(GenomicRanges) 24 | library(Biostrings) 25 | library(BSgenome.Hsapiens.UCSC.hg19) 26 | ``` 27 | 28 | Perform iterative LSI to identify top 50K variable genomic tiles: 29 | 30 | ``` r 31 | archr.obj <- addIterativeLSI(archr.obj, iterations = 5, varFeatures = 50000) 32 | ``` 33 | 34 | Extract the accessibility matrix for top variable tiles: 35 | 36 | ``` r 37 | tile.mtx <- assays(getMatrixFromProject(archr.obj, useMatrix = "TileMatrix", binarize = T))$TileMatrix 38 | var.tiles <- archr.obj@reducedDims$IterativeLSI$LSIFeatures[, -3] 39 | var.tile.mtx <- tile.mtx[match(var.tiles, tile.mtx@elementMetadata), match(archr.obj$cellNames, colnames(tile.mtx))] 40 | ``` 41 | 42 | Prepare input files for CellSpace: 43 | 44 | ``` r 45 | genome <- BSgenome.Hsapiens.UCSC.hg19 46 | tileSize <- archr.obj@reducedDims$IterativeLSI$tileSize 47 | GRanges( 48 | seqinfo = seqinfo(genome), 49 | seqnames = var.tiles$seqnames, strand = "+", 50 | ranges = IRanges(start = var.tiles$start, width = tileSize) 51 | ) %>% getSeq(x = genome) %>% writeXStringSet(filepath = "data/var_tiles.fa") 52 | gsub("^.+#", "", archr.obj$cellNames) %>% 53 | write(ncolumns = 1, file = "data/cell-names.txt") 54 | writeMM(t(var.tile.mtx), file = "data/cell_by_tile-counts.mtx") 55 | ``` 56 | 57 | ### Using a peak atlas 58 | 59 | You can use CellSpace regardless of your pre-processing pipeline, but we suggest filtering lower-quality cells/peaks; in particular, peaks that are not accessible in any cells must be excluded! 60 | 61 | We also suggest removing promoter-proximal peaks, since it's been demonstrated that promoter-distal regions show significantly higher cell-type-specificity than promoter-proximal regions [[Preissl *et al.*, 2018](https://doi.org/10.1038/s41593-018-0079-3)][[Chung *et al.*, 2019](https://doi.org/10.1016/j.celrep.2019.08.089)]. 
62 | 63 | Additionally, using top variable peaks significantly reduces training time, while preserving or potentially improving the quality of the embedding. We have prepared [an **R** script](variable-peaks/IterativeLSI.R) based on [ArchR's iterative LSI method](https://www.archrproject.com/bookdown/iterative-latent-semantic-indexing-lsi.html) to identify top variable peaks, without the need for an ArchR object and its ArrowFiles. Please refer to [ArchR::addIterativeLSI](https://www.archrproject.com/reference/addIterativeLSI.html) for a detailed description of parameters. 64 | 65 | The script used to filter the peak atlas of the 'Small-scale human hematopoiesis dataset', as described in [our manuscript](https://www.biorxiv.org/content/10.1101/2022.05.02.490310v4.full.pdf), is available [here](variable-peaks/filter-peaks.R). Once you have a peak-by-cell sparse count matrix of class **dgCMatrix** (named `counts` here) and its corresponding peak set of class **GRanges** (named `peaks.gr` here), you can perform iterative LSI and prepare the input files to train a CellSpace embedding. 
66 | 67 | Load required libraries and functions: 68 | 69 | ``` r 70 | library(ArchR) 71 | library(Matrix) 72 | library(dplyr) 73 | library(GenomicRanges) 74 | library(Biostrings) 75 | library(BSgenome.Hsapiens.UCSC.hg19) 76 | source("variable-peaks/IterativeLSI.R") 77 | ``` 78 | 79 | Perform iterative LSI to identify top 50K variable peaks: 80 | 81 | ``` r 82 | itLSI <- IterativeLSI(counts, iterations = 5, varFeatures = 50000) 83 | var.peaks <- itLSI$LSIFeatures$idx 84 | ``` 85 | 86 | Prepare input files for CellSpace: 87 | 88 | ``` r 89 | genome <- BSgenome.Hsapiens.UCSC.hg19 90 | getSeq(peaks.gr[var.peaks], x = genome) %>% 91 | writeXStringSet(filepath = "data/var_peaks.fa") 92 | write(colnames(counts), ncolumns = 1, file = "data/cell-names.txt") 93 | writeMM(t(counts[var.peaks, ]), file = "data/cell_by_peak-counts.mtx") 94 | ``` 95 | 96 | ## Training a CellSpace model 97 | 98 | We will continue this tutorial using the variable genomic tiles and the corresponding count matrix prepared in the previous section. 99 | 100 | Train a CellSpace embedding: 101 | 102 | ``` bash 103 | CellSpace \ 104 | -output data/CellSpace_embedding-var_tiles \ 105 | -cpMat data/cell_by_tile-counts.mtx \ 106 | -peaks data/var_tiles.fa 107 | ``` 108 | 109 | The result is a tab-delimited file containing the CellSpace embedding matrix for cells and DNA *k*-mers. 110 | 111 | ``` bash 112 | CellSpace ... 
113 | 114 | The following arguments are mandatory: 115 | -output prefix of the output 116 | -cpMat sparse cell by peak/tile count matrix (.mtx) 117 | -peaks multi-fasta file containing peak/tile DNA sequences with the order they appear in the corresponding count matrix (.fa) 118 | 119 | The following arguments are optional: 120 | -dim size of embedding vectors [default=30] 121 | -ngrams max length of k-mer ngram [default=3] 122 | -k k-mer length [default=8] 123 | -sampleLen length of the sequences randomly sampled from the peak/tile DNA sequences (integer or 'given') [default=150] 124 | -exmpPerPeak number of training examples per peak/tile [default=20] 125 | -epoch number of epochs [default=50] 126 | -margin margin parameter in hinge loss [default=0.05] 127 | -bucket size of the hashing map for n-grams [default=2000000] 128 | -label cell labels prefix [default='__label__'] 129 | -lr learning rate [default=0.01] 130 | -maxTrainTime max train time (seconds) [default=8640000] 131 | -negSearchLimit number of negative labels sampled per dataset [default=50] 132 | -maxNegSamples max number of negative labels in a batch update [default=10] 133 | -p the embedding of an entity equals the sum of its M feature embedding vectors devided by M^p [default=0.5] 134 | -initRandSd initial values of embeddings are randomly generated from normal distribution with mean=0 and standard deviation=initRandSd [default=0.001] 135 | -batchSize size of mini batch in training [default=5] 136 | -saveIntermediates save intermediate models or only the final model (integer or 'final') [default='final'] 137 | -thread number of threads [default=10] 138 | ``` 139 | 140 | - In order to train CellSpace on multiple datasets, processed with respect to their own peak atlases, their input files must be provided as a list: `-cpMat data1-counts.mtx data2-counts.mtx -peaks data1-peaks.fa data2-peaks.fa`\ 141 | CellSpace will avoid pushing cells from different datasets away from each other by sampling the 
peak, the positive cell, and the negative cells for each training example from the same dataset. 142 | - To get more compact clusters and an overall better embedding, especially for larger or more heterogeneous data sets, we suggest increasing `-epoch`. Increasing the size of *N*-grams will have a similar effect; however, we suggest the default `-ngrams 3` in most cases. 143 | - Running time will increase with the number of events, so we suggest using top variable peaks or genomic tiles instead of a large peak atlas. 144 | - If `-sampleLen given` is specified, the entire peak/tile sequence is used for every training example; otherwise, a fixed-size sequence is randomly sampled from the event for each training example. 145 | - If `-saveIntermediates final` is specified, only the final model will be saved; otherwise, if `-saveIntermediates M` is specified, the model will be saved after every `M` epochs. 146 | - You can increase `-thread` to speed up the training process. 147 | 148 | ## Downstream analysis 149 | 150 | Load required libraries and functions: 151 | 152 | ``` r 153 | library(CellSpace) 154 | library(TFBSTools) 155 | library(JASPAR2020) 156 | library(dplyr) 157 | source("plot-functions.R") 158 | ``` 159 | 160 | Load CellSpace embedding: 161 | 162 | ``` r 163 | cell.idx <- readLines("data/cell-names.txt") 164 | sample.info <- read.table( 165 | file = "data/sample-info.tsv", sep = "\t", 166 | header = T, check.names = F, row.names = 1 167 | ) 168 | pal <- readRDS("data/palette.rds") 169 | sample.info$Cell_type <- factor(sample.info$Cell_type, levels = names(pal$Cell_type)) 170 | 171 | cso <- CellSpace( 172 | project = "tutorial", 173 | emb.file = "data/CellSpace_embedding-var_tiles.tsv", 174 | meta.data = sample.info[cell.idx, ] 175 | ) 176 | ``` 177 | 178 | The CellSpace embedding matrices are stored in `cso@cell.emb` and `cso@kmer.emb`. 
179 | 180 | Note that the order of cells in `meta.data` and `cell.names`, if provided to the constructor of a CellSpace object, must match the order of cells in the count matrix used to train the CellSpace model. If `cell.names` is not provided, the row names of `meta.data` will be used. 181 | 182 | Create a shared nearest-neighbor graph for the cells, cluster the cells, and compute a UMAP embedding of the cells, using wrappers of [Seurat](https://github.com/satijalab/seurat/tree/master) [[Hao, *et al.*, Cell 2021](https://doi.org/10.1016/j.cell.2021.04.048)] functions: 183 | 184 | ``` r 185 | cso <- find_neighbors(cso, n.neighbors = 20) %>% 186 | find_clusters(resolution = c(0.8, 1.2)) %>% 187 | run_UMAP(n.neighbors = 20, min.dist = 0.2, spread = 1) 188 | ``` 189 | 190 | The nearest-neighbor (NN) and shared nearest-neighbor (SNN) graphs are stored in `cso@neighbors`, the clusters are added to `cso@meta.data`, and the UMAP is stored in `cso@reductions`. 191 | 192 | `run_UMAP` uses the CellSpace embedding of cells by default; alternatively, you can provide the name of a pre-computed nearest-neighbor graph in `cso@neighbors` to compute the UMAP embedding from. 193 | 194 | Let's visualize the results: 195 | 196 | ``` r 197 | plot.groups( 198 | cso, vis = "cells_UMAP", groups = "Cell_type", 199 | pal = pal$Cell_type, add.labels = T 200 | ) 201 | plot.groups( 202 | cso, vis = "cells_UMAP", groups = "Donor", 203 | pal = pal$Donor, add.labels = F 204 | ) 205 | plot.groups( 206 | cso, vis = "cells_UMAP", groups = "Clusters.res_1.2", 207 | pal = pal$Cluster, add.labels = T 208 | ) + labs(color = "CellSpace\nCluster") 209 | ``` 210 | 211 | 212 | 213 | You can map any valid DNA sequence that is not shorter than the *k*-mers (8bp by default) to the same embedding space, using `DNA_sequence_embedding`.
The embedding of a transcription factor motif is computed using its consensus sequence: 214 | 215 | ``` r 216 | pwm.list <- readRDS("data/PWM-list.rds") # example CIS-BP TF motifs 217 | motif.emb <- lapply(pwm.list, function(pwm){ 218 | motif_embedding(cso, PWM = pwm) 219 | }) %>% do.call(what = rbind) 220 | ``` 221 | 222 | The cosine similarity/distance of CellSpace embedding matrices can be computed with `cosine_similarity`/`embedding_distance`; for example, the similarity between a TF motif and cell embedding vectors in the latent space represents TF activity scores: 223 | 224 | ``` r 225 | motif.score <- cosine_similarity(x = cso@cell.emb, y = motif.emb) %>% scale() 226 | ``` 227 | 228 | Let's visualize the scores: 229 | 230 | ``` r 231 | md <- cso@meta.data[, c("Cell_type", "Donor", "Clusters.res_1.2")] 232 | colnames(md)[3] <- "Cluster" 233 | plot.scores(t(motif.score), cell.annot = md, pal = pal, column_split = cso$Cell_type) 234 | ``` 235 | 236 | 237 | 238 | You can automatically perform these steps using `add_motif_db`: 239 | 240 | ``` r 241 | jaspar <- getMatrixSet(JASPAR2020@db, opts = list(species = "Homo sapiens", collection = "CORE")) 242 | names(jaspar) <- TFBSTools::name(jaspar) 243 | cso <- add_motif_db(cso, motif.db = jaspar, db.name = "JASPAR2020") 244 | ``` 245 | 246 | The motif embedding matrix and the corresponding similarity *Z*-scores are stored in `cso@motif.emb` and `cso@misc`. 
247 | 248 | Let's visualize cells and TF motifs in the same space: 249 | 250 | ``` r 251 | cso <- run_UMAP( 252 | cso, name = "cell&TFs_UMAP", 253 | emb = rbind(cso@cell.emb, motif.emb), 254 | n.neighbors = 50, min.dist = 0.2, spread = 1 255 | ) 256 | 257 | TFs <- rownames(motif.emb) 258 | TF.umap <- data.frame(cso@reductions$`cell&TFs_UMAP`)[TFs, ] 259 | plot.groups( 260 | cso, vis = "cell&TFs_UMAP", groups = "Cell_type", 261 | pal = pal$Cell_type, add.labels = F 262 | ) + 263 | geom_point(data = TF.umap, color = "black", shape = 17, size = 1.5) + 264 | geom_label_repel( 265 | mapping = aes(label = TFs), data = TF.umap, 266 | size = 3, fontface = "bold", min.segment.length = 0, 267 | max.overlaps = Inf, segment.color = "black" 268 | ) 269 | ``` 270 | 271 | 272 | -------------------------------------------------------------------------------- /tutorial/data/CellSpace_embedding-var_tiles.tsv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zakieh-tayyebi/CellSpace/c5dc59b1a3c843596834acab4d4e421e4cd59343/tutorial/data/CellSpace_embedding-var_tiles.tsv.gz -------------------------------------------------------------------------------- /tutorial/data/PWM-list.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zakieh-tayyebi/CellSpace/c5dc59b1a3c843596834acab4d4e421e4cd59343/tutorial/data/PWM-list.rds -------------------------------------------------------------------------------- /tutorial/data/cell_by_peak-counts.mtx.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zakieh-tayyebi/CellSpace/c5dc59b1a3c843596834acab4d4e421e4cd59343/tutorial/data/cell_by_peak-counts.mtx.gz -------------------------------------------------------------------------------- /tutorial/data/cell_by_tile-counts.mtx.gz: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zakieh-tayyebi/CellSpace/c5dc59b1a3c843596834acab4d4e421e4cd59343/tutorial/data/cell_by_tile-counts.mtx.gz
--------------------------------------------------------------------------------
/tutorial/data/palette.rds:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zakieh-tayyebi/CellSpace/c5dc59b1a3c843596834acab4d4e421e4cd59343/tutorial/data/palette.rds
--------------------------------------------------------------------------------
/tutorial/data/var_peaks.fa.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zakieh-tayyebi/CellSpace/c5dc59b1a3c843596834acab4d4e421e4cd59343/tutorial/data/var_peaks.fa.gz
--------------------------------------------------------------------------------
/tutorial/data/var_tiles.fa.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zakieh-tayyebi/CellSpace/c5dc59b1a3c843596834acab4d4e421e4cd59343/tutorial/data/var_tiles.fa.gz
--------------------------------------------------------------------------------
/tutorial/plot-functions.R:
--------------------------------------------------------------------------------
1 | library(ggplot2)
2 | library(ggrepel)
3 | library(cowplot)
4 | library(ComplexHeatmap)
5 | library(circlize)
6 |
7 | # Scatter plot of a 2-D reduction stored in object@reductions, with cells coloured
8 | # by a column of object@meta.data.
9 | #   object     : CellSpace object (uses @reductions, @meta.data and, for the default
10 | #                cell.idx, the row names of @cell.emb)
11 | #   vis        : name of the reduction in object@reductions
12 | #   groups     : name of the meta-data column used for colouring
13 | #   add.labels : label each group at the median coordinates of its cells
14 | #   sx, sy     : default label nudges; nx, ny: named per-group nudges overriding them
15 | plot.groups <- function(
16 |   object, vis, groups,
17 |   cell.idx = rownames(object@cell.emb),
18 |   mapping = aes(x = UMAP_1, y = UMAP_2),
19 |   add.labels = T, pal = NULL,
20 |   point.size = 1.5, label.size = 5,
21 |   et = element_text(size = 15),
22 |   sx = 0, sy = 0, nx = NULL, ny = NULL
23 | ){
24 |   # read from `object`, not from a global CellSpace object
25 |   vis <- data.frame(object@reductions[[vis]])[cell.idx, ]
26 |   group.name <- groups
27 |   groups <- object@meta.data[cell.idx, groups]
28 |   eb <- element_blank()
29 |   gp <- ggplot(mapping = mapping) +
30 |     geom_point(aes(color = groups), vis, size = point.size, alpha = 0.9) +
31 |     labs(color = group.name) + theme_classic() +
32 |     theme(
33 |       axis.ticks = eb, axis.text = eb,
34 |       axis.title = et, title = et,
35 |       legend.title = et, legend.text = et
36 |     )
37 |   if(add.labels){
38 |     G <- sort(unique(groups))
39 |     NX <- rep(sx, length(G)); names(NX) <- G
40 |     if(!is.null(nx)) NX[names(nx)] <- nx
41 |     NY <- rep(sy, length(G)); names(NY) <- G
42 |     if(!is.null(ny)) NY[names(ny)] <- ny
43 |     # one row per group holding the median coordinates of its cells
44 |     vis.gp <- data.frame(matrix(
45 |       NA, nrow = length(G), ncol = ncol(vis),
46 |       dimnames = list(G, colnames(vis))
47 |     ))
48 |     for(g in G)
49 |       for(col in colnames(vis))
50 |         vis.gp[g, col] <- median(vis[groups == g, col])
51 |     gp <- gp + geom_label(aes(color = G, label = G), vis.gp, show.legend = F,
52 |                           size = label.size, nudge_x = NX, nudge_y = NY)
53 |   }
54 |   gp + scale_color_manual(values = pal)
55 | }
56 |
57 | # Heatmap of motif similarity z-scores with cell annotations on top.
58 | #   motif.score : score matrix, presumably motifs x cells (the tutorial passes t(scores))
59 | #   cell.annot  : data frame of cell annotations; pal: list of colour vectors for them
60 | #   ...         : further arguments for ComplexHeatmap::Heatmap
61 | plot.scores <- function(motif.score, cell.annot, pal, ...){
62 |   Heatmap(
63 |     motif.score, name = "Similarity\nz-score",
64 |     col = colorRamp2(c(-1, 0, 1), c("#007e8c", "white", "#a3005c")),
65 |     top_annotation = columnAnnotation(df = cell.annot, col = pal),
66 |     column_title_rot = 90, show_column_names = F,
67 |     row_title_rot = 0, cluster_rows = F,
68 |     use_raster = T, ...
69 |   )
70 | }
71 |
72 |
--------------------------------------------------------------------------------
/tutorial/plots/UMAP-cells.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zakieh-tayyebi/CellSpace/c5dc59b1a3c843596834acab4d4e421e4cd59343/tutorial/plots/UMAP-cells.png
--------------------------------------------------------------------------------
/tutorial/plots/UMAP-cells_and_TFs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zakieh-tayyebi/CellSpace/c5dc59b1a3c843596834acab4d4e421e4cd59343/tutorial/plots/UMAP-cells_and_TFs.png
--------------------------------------------------------------------------------
/tutorial/plots/motif-scores.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zakieh-tayyebi/CellSpace/c5dc59b1a3c843596834acab4d4e421e4cd59343/tutorial/plots/motif-scores.png
--------------------------------------------------------------------------------
/tutorial/variable-peaks/IterativeLSI.R:
--------------------------------------------------------------------------------
1 | library(ArchR)
2 | library(Matrix)
3 | library(matrixStats)
4 | library(irlba)
5 | library(dplyr)
6 |
7 | # Iterative LSI of a peak-by-cell count matrix (cf. ArchR's iterative LSI).
8 | # Iteration 1 runs LSI on the most accessible peaks; every later iteration runs LSI on
9 | # the peaks most variable across the clusters found in the previous iteration.
10 | # LSI and clustering results of every iteration are saved as .rds files in outDir.
11 | # Returns the LSI object of the last iteration (a SimpleList whose first element,
12 | # matSVD, is the cell-by-dimension LSI matrix).
13 | IterativeLSI <- function(
14 |   counts, # a peak-by-cell count matrix of class dgCMatrix
15 |   outDir = "IterativeLSI/", # where to save the LSI embedding and clustering results of every iteration
16 |   iterations = 2, # number of LSI iterations (at least 2)
17 |   varFeatures = 25000, # number of top variable peaks
18 |   groupNames = NULL, # cell labels provided if you want the random sampling of cells to preserve group proportions
19 |   sampleCellsPre = 10000, # An integer specifying the number of cells to sample in iterations prior to the last in order to perform a sub-sampled LSI and sub-sampled clustering
20 |   sampleCellsFinal = NULL, # An integer specifying the number of cells to sample in order to perform a sub-sampled LSI in final iteration
21 |   filterQuantile = 0.995, # peaks above this accessibility quantile are skipped in iteration 1
22 |   clusterParams = list(resolution = 2, maxClusters = 10),
23 |   LSIMethod = 2, # 1 = "tf-log(idf)", 2 = "log(tf-idf)", 3 = "log(tf-log(idf))"
24 |   scaleTo = 10000,
25 |   dimsToUse = 1:30,
26 |   scaleDims = T, # cluster on row-z-scored LSI coordinates
27 |   outlierQuantiles = c(0.02, 0.98), # per-cell quantiles used to flag outlying cells
28 |   selectionMethod = "var", # "var" or "vmr", see .identifyVarFeatures
29 |   corCutOff = 0.75, # LSI dims with |cor| to log10(depth+1) above this are not used for clustering
30 |   filterBias = T, # passed to ArchR::addClusters as testBias/filterBias
31 |   seed = 1
32 | ){
33 |   # inherits() instead of class() != ...: class() can return several classes
34 |   if(!inherits(counts, "dgCMatrix"))
35 |     stop("'counts' must be a sparse matrix of class 'dgCMatrix'!")
36 |
37 |   if(is.null(colnames(counts)) || any(duplicated(colnames(counts))) ||
38 |      is.null(rownames(counts)) || any(duplicated(rownames(counts))))
39 |     stop("'counts' must be an event-by-cell matrix with unique row and column names!")
40 |
41 |   if(is.null(groupNames))
42 |     groupNames <- rep("G1", ncol(counts))
43 |
44 |   set.seed(seed)
45 |   dir.create(outDir, showWarnings = F, recursive = T)
46 |
47 |   # Handle number of cells to sample #####
48 |   # (sub-sampling is disabled when there are fewer cells than requested)
49 |   cellNames <- colnames(counts)
50 |   if(!is.null(sampleCellsPre)){
51 |     if(length(cellNames) < sampleCellsPre){
52 |       sampleCellsPre <- NULL
53 |     }
54 |   }
55 |   if(!is.null(sampleCellsFinal)){
56 |     if(length(cellNames) < sampleCellsFinal){
57 |       sampleCellsFinal <- NULL
58 |     }
59 |   }
60 |
61 |   # LSI Iteration 1 #####
62 |   nFeature <- varFeatures
63 |   rmTop <- floor((1 - filterQuantile) * nrow(counts))
64 |   message("\nSelecting top ", varFeatures, " most accessible features (excluding top ", rmTop, ") for LSI iteration 1.")
65 |
66 |   totalAcc <- data.frame(
67 |     rowSums = Matrix::rowSums(counts),
68 |     idx = 1:nrow(counts),
69 |     row.names = rownames(counts)
70 |   )
71 |   # skip the rmTop most accessible peaks and keep the next nFeature; the guard is
72 |   # needed because x[-(1:0)] would drop the first element when rmTop == 0
73 |   topIdx <- head(order(totalAcc$rowSums, decreasing = T), nFeature + rmTop)
74 |   if(rmTop > 0) topIdx <- topIdx[-seq_len(rmTop)]
75 |   topFeatures <- totalAcc[sort(topIdx), ]
76 |
77 |   cellDF <- data.frame(
78 |     cellDepth = log10(Matrix::colSums(counts) + 1),
79 |     groupNames = groupNames,
80 |     idx = 1:ncol(counts),
81 |     row.names = colnames(counts)
82 |   )
83 |
84 |   j <- 1
85 |   message("Running LSI iteration ", j, ".")
86 |
87 |   outLSI <- .LSIPartialMatrix(
88 |     counts = counts,
89 |     featureDF = topFeatures,
90 |     cellDF = cellDF,
91 |     LSIMethod = LSIMethod,
92 |     dimsToUse = dimsToUse,
93 |     scaleTo = scaleTo,
94 |     outlierQuantiles = outlierQuantiles,
95 |     sampleCells = sampleCellsPre,
96 |     projectAll = F
97 |   )
98 |   outLSI$scaleDims <- scaleDims
99 |   saveRDS(outLSI, file = paste0(outDir, "/iteration", j, "-LSI.rds"))
100 |
101 |   message("Identify clusters after LSI iteration ", j, ".")
102 |   clusterDF <- .LSICluster(
103 |     outLSI = outLSI,
104 |     cellDF = cellDF,
105 |     filterBias = filterBias,
106 |     dimsToUse = dimsToUse,
107 |     scaleDims = scaleDims,
108 |     corCutOff = corCutOff,
109 |     clusterParams = clusterParams,
110 |     j = j
111 |   )
112 |   saveRDS(clusterDF, file = paste0(outDir, "/iteration", j, "-clusters.rds"))
113 |
114 |   # LSI Iteration 2+ #####
115 |   variableFeatures <- topFeatures
116 |
117 |   while(j < iterations){
118 |     j <- j + 1
119 |
120 |     message("\nSelecting top ", varFeatures, " most variable features for LSI iteration ", j, ".")
121 |     # rank peaks by their variability across the clusters of the previous iteration
122 |     variableFeatures <- .identifyVarFeatures(
123 |       counts = counts,
124 |       outLSI = outLSI,
125 |       clusterDF = clusterDF,
126 |       prevFeatures = variableFeatures,
127 |       scaleTo = scaleTo,
128 |       totalAcc = totalAcc,
129 |       selectionMethod = selectionMethod,
130 |       varFeatures = varFeatures
131 |     )
132 |
133 |     message("Running LSI iteration ", j, ".")
134 |     # only the final iteration projects the cells left out by sub-sampling
135 |     outLSI <- .LSIPartialMatrix(
136 |       counts = counts,
137 |       featureDF = variableFeatures,
138 |       cellDF = cellDF,
139 |       LSIMethod = LSIMethod,
140 |       scaleTo = scaleTo,
141 |       dimsToUse = dimsToUse,
142 |       outlierQuantiles = outlierQuantiles,
143 |       sampleCells = if(j != iterations) sampleCellsPre else sampleCellsFinal,
144 |       projectAll = j == iterations
145 |     )
146 |     outLSI$scaleDims <- scaleDims
147 |     saveRDS(outLSI, file = paste0(outDir, "/iteration", j, "-LSI.rds"))
148 |
149 |     clusterDF <- .LSICluster(
150 |       outLSI = outLSI,
151 |       cellDF = cellDF,
152 |       dimsToUse = dimsToUse,
153 |       scaleDims = scaleDims,
154 |       corCutOff = corCutOff,
155 |       filterBias = filterBias,
156 |       j = j,
157 |       clusterParams = clusterParams
158 |     )
159 |     saveRDS(clusterDF, file = paste0(outDir, "/iteration", j, "-clusters.rds"))
160 |   }
161 |
162 |   return(outLSI)
163 | }
164 |
165 | # Selects the top `varFeatures` peaks for the next LSI iteration. Counts are summed
166 | # per cluster of clusterDF, then peaks are ranked by the variance of the
167 | # log-normalized cluster sums ("var") or by the variance-to-mean ratio of the raw
168 | # cluster sums ("vmr"). Returns the selected rows of totalAcc in their original order.
169 | # outLSI, prevFeatures and firstSelection are accepted for compatibility but unused.
170 | .identifyVarFeatures <- function(
171 |   counts,
172 |   outLSI,
173 |   clusterDF,
174 |   prevFeatures,
175 |   totalAcc,
176 |   scaleTo,
177 |   firstSelection = NULL,
178 |   selectionMethod,
179 |   varFeatures
180 | ){
181 |
182 |   # peak-by-cluster matrix of summed counts (drop = F keeps single-cell clusters a matrix)
183 |   groupMat <- sapply(SimpleList(split(clusterDF$cellNames, clusterDF$clusters)), function(group.idx){
184 |     Matrix::rowSums(counts[totalAcc$idx, group.idx, drop = F])
185 |   }, simplify = T, USE.NAMES = T)
186 |
187 |   nFeature <- varFeatures
188 |   if(tolower(selectionMethod) == "var"){
189 |
190 |     # Log-Normalize
191 |     groupMat <- log2(t(t(groupMat) / colSums(groupMat)) * scaleTo + 1)
192 |     feature.var <- matrixStats::rowVars(groupMat)
193 |     idx <- sort(head(order(feature.var, decreasing = T), nFeature))
194 |     variableFeatures <- totalAcc[idx, ]
195 |
196 |   } else if(tolower(selectionMethod) == "vmr"){
197 |
198 |     # Variance-to-Mean Ratio
199 |     feature.vmr <- matrixStats::rowVars(groupMat) / rowMeans(groupMat)
200 |     idx <- sort(head(order(feature.vmr, decreasing = T), nFeature))
201 |     variableFeatures <- totalAcc[idx, ]
202 |
203 |   } else stop("Feature selection method is not valid!")
204 |
205 |   return(variableFeatures)
206 | }
207 |
208 | # Computes LSI for the cells of cellDF restricted to the peaks of featureDF.
209 | # With sampleCells = NULL all cells are used; otherwise LSI is computed on a
210 | # group-stratified, depth-filtered sample of cells and, when projectAll is TRUE,
211 | # the remaining cells are projected onto it in at most projection.steps chunks.
212 | # The result also stores featureDF (LSIFeatures) and the absolute correlation of
213 | # every LSI dimension with cellDepth, with and without row scaling (corToDepth).
214 | .LSIPartialMatrix <- function(
215 |   counts,
216 |   featureDF,
217 |   cellDF,
218 |   LSIMethod,
219 |   dimsToUse,
220 |   scaleTo,
221 |   outlierQuantiles,
222 |   sampleCells,
223 |   projectAll,
224 |   projection.steps = 10
225 | ){
226 |
227 |   if(is.null(sampleCells)){
228 |
229 |     message("- Computing LSI for all ", nrow(cellDF), " cells.")
230 |
231 |     # Construct Matrix
232 |     mat <- counts[featureDF$idx, cellDF$idx, drop = F]
233 |     if(!all(rownames(cellDF) == colnames(mat)))
234 |       stop("Names of cells don't match!")
235 |
236 |     # Perform LSI
237 |     outLSI <- .computeLSI(
238 |       mat = mat,
239 |       LSIMethod = LSIMethod,
240 |       scaleTo = scaleTo,
241 |       nDimensions = max(dimsToUse),
242 |       outlierQuantiles = outlierQuantiles
243 |     )
244 |
245 |   } else{
246 |
247 |     message("- Computing partial LSI for ", sampleCells, " sampled cells.")
248 |
249 |     sampledCellNames <- .sampleBySample(
250 |       cellNames = rownames(cellDF),
251 |       sampleNames = cellDF$groupNames,
252 |       cellDepth = cellDF$cellDepth,
253 |       sampleCells = sampleCells,
254 |       outlierQuantiles = outlierQuantiles,
255 |       factor = 2
256 |     )
257 |     sampledCells <- cellDF[sampledCellNames, ]
258 |     mat <- counts[featureDF$idx, sampledCells$idx, drop = F]
259 |     if(!all(rownames(sampledCells) == colnames(mat)))
260 |       stop("Names of sampled cells don't match!")
261 |
262 |     print(table(sampledCells$groupNames))
263 |
264 |     # Perform LSI on Partial Sampled Matrix
265 |     outLSI <- .computeLSI(
266 |       mat = mat,
267 |       LSIMethod = LSIMethod,
268 |       scaleTo = scaleTo,
269 |       nDimensions = max(dimsToUse),
270 |       outlierQuantiles = outlierQuantiles
271 |     )
272 |
273 |     rest.idx <- setdiff(cellDF$idx, sampledCells$idx)
274 |     if(projectAll && length(rest.idx) > 0){
275 |
276 |       # split the remaining cells into at most projection.steps contiguous,
277 |       # non-empty chunks (fewer chunks when there are fewer cells than steps)
278 |       nSteps <- min(projection.steps, length(rest.idx))
279 |       chunks <- split(rest.idx, ceiling(seq_along(rest.idx) * nSteps / length(rest.idx)))
280 |       message("- Projecting LSI in ", nSteps, " steps for the rest of the cells.")
281 |       pLSI <- lapply(seq_along(chunks), function(i){
282 |         message("-- Step ", i, " with ", length(chunks[[i]]), " cells:")
283 |         .projectLSI(mat = counts[featureDF$idx, chunks[[i]], drop = F], LSI = outLSI)
284 |       }) %>% Reduce("rbind", .)
285 |       matSVD <- rbind(outLSI$matSVD, pLSI)
286 |
287 |       # TRUE for cells whose LSI was not obtained by projection (field name kept as-is)
288 |       outLSI$exlcude.projection <- !(rownames(cellDF) %in% rownames(pLSI))
289 |       outLSI$matSVD <- matSVD[rownames(cellDF), ]
290 |     }
291 |   }
292 |
293 |   outLSI$LSIFeatures <- featureDF
294 |   outLSI$corToDepth <- list(
295 |     scaled = abs(cor(.scaleDims(outLSI[[1]]), cellDF[rownames(outLSI[[1]]), "cellDepth"]))[, 1],
296 |     none = abs(cor(outLSI[[1]], cellDF[rownames(outLSI[[1]]), "cellDepth"]))[, 1]
297 |   )
298 |
299 |   return(outLSI)
300 | }
301 |
302 | # TF-IDF transform and truncated SVD (irlba) of a peak-by-cell count matrix.
303 | # Counts are binarized and all-zero cells and peaks are dropped. Cells whose number
304 | # of accessible peaks is at or beyond the outlierQuantiles quantiles are left out of
305 | # the SVD and projected onto it afterwards (after checking on up to 10 retained
306 | # cells that re-projection correlates >= 0.95 with the SVD). Returns a SimpleList
307 | # whose first element, matSVD, is the cell-by-dimension LSI; the other elements
308 | # are the quantities .projectLSI needs.
309 | .computeLSI <- function(
310 |   mat,
311 |   LSIMethod,
312 |   scaleTo,
313 |   nDimensions,
314 |   outlierQuantiles
315 | ){ # TF IDF LSI adapted from flyATAC
316 |
317 |   # Binarize Matrix
318 |   mat@x[mat@x > 0] <- 1
319 |
320 |   message("-- Cleaning up the matrix.")
321 |
322 |   # Clean up zero columns
323 |   colSm <- Matrix::colSums(mat)
324 |   exclude <- colSm == 0
325 |   if(any(exclude)){
326 |     mat <- mat[, !exclude, drop = F]
327 |     colSm <- colSm[!exclude]
328 |   }
329 |
330 |   # Remove outlying columns
331 |   cn <- colnames(mat)
332 |   filterOutliers <- 0
333 |   if(!is.null(outlierQuantiles)){
334 |     qCS <- quantile(colSm, probs = c(min(outlierQuantiles), max(outlierQuantiles)))
335 |     idxOutlier <- which(colSm <= qCS[1] | colSm >= qCS[2])
336 |     if(length(idxOutlier) > 0){
337 |       matO <- mat[, idxOutlier, drop = F]
338 |       mat <- mat[, -idxOutlier, drop = F]
339 |       # A 2nd Matrix (up to 10 retained cells) to Check Projection is Working
340 |       mat2 <- mat[, seq_len(min(10, ncol(mat))), drop = F]
341 |       colSm <- colSm[-idxOutlier]
342 |       filterOutliers <- 1
343 |     }
344 |   }
345 |
346 |   # Clean up zero rows
347 |   rowSm <- Matrix::rowSums(mat)
348 |   exclude2 <- rowSm == 0
349 |   if(any(exclude2)){
350 |     mat <- mat[!exclude2, , drop = F]
351 |     rowSm <- rowSm[!exclude2]
352 |   }
353 |
354 |   message("-- Computing TF-IDF.")
355 |
356 |   # TF - Normalize (divide each non-zero entry by its cell's total)
357 |   mat@x <- mat@x / rep.int(colSm, Matrix::diff(mat@p))
358 |
359 |   # TF-IDF
360 |   if(LSIMethod == 1 | tolower(LSIMethod) == "tf-log(idf)"){ #Adapted from Cusanovich et al.
361 |
362 |     #LogIDF
363 |     idf <- as(log(1 + ncol(mat) / rowSm), "sparseVector")
364 |
365 |     #TF-LogIDF
366 |     mat <- as(Matrix::Diagonal(x=as.vector(idf)), "sparseMatrix") %*% mat
367 |
368 |   } else if(LSIMethod == 2 | tolower(LSIMethod) == "log(tf-idf)"){ #Adapted from Stuart et al.
369 |
370 |     #IDF
371 |     idf <- as(ncol(mat) / rowSm, "sparseVector")
372 |
373 |     #TF-IDF
374 |     mat <- as(Matrix::Diagonal(x=as.vector(idf)), "sparseMatrix") %*% mat
375 |
376 |     #Log transform TF-IDF
377 |     mat@x <- log(mat@x * scaleTo + 1)
378 |
379 |   } else if(LSIMethod == 3 | tolower(LSIMethod) == "log(tf-log(idf))"){
380 |
381 |     #LogTF
382 |     mat@x <- log(mat@x + 1)
383 |
384 |     #LogIDF
385 |     idf <- as(log(1 + ncol(mat) / rowSm), "sparseVector")
386 |
387 |     #TF-IDF
388 |     mat <- as(Matrix::Diagonal(x=as.vector(idf)), "sparseMatrix") %*% mat
389 |
390 |   } else stop("Invalid LSI method!")
391 |
392 |   message("-- Calculating SVD and LSI.")
393 |   # Calculate SVD then LSI (cell coordinates = V %*% diag(d))
394 |   svd <- irlba::irlba(mat, nDimensions, nDimensions)
395 |   svdDiag <- matrix(0, nrow = nDimensions, ncol = nDimensions)
396 |   diag(svdDiag) <- svd$d
397 |   matSVD <- t(svdDiag %*% t(svd$v))
398 |   rownames(matSVD) <- colnames(mat)
399 |   colnames(matSVD) <- paste0("LSI", 1:ncol(matSVD))
400 |
401 |   # Return Object
402 |   out <- SimpleList(
403 |     matSVD = matSVD,
404 |     rowSm = rowSm,
405 |     nCol = length(colSm),
406 |     exclude = exclude,
407 |     exclude2 = exclude2,
408 |     svd = svd,
409 |     scaleTo = scaleTo,
410 |     nDimensions = nDimensions,
411 |     LSIMethod = LSIMethod,
412 |     outliers = NA
413 |   )
414 |
415 |   if(filterOutliers == 1){
416 |     message("-- Checking if LSI projection works.")
417 |     # Quick Check LSI-Projection Works
418 |     pCheck <- .projectLSI(mat = mat2, LSI = out)
419 |     pCheck2 <- out[[1]][rownames(pCheck), , drop = F]
420 |     pCheck3 <- lapply(1:ncol(pCheck), function(x){
421 |       cor(pCheck[, x], pCheck2[, x])
422 |     }) %>% unlist
423 |     if(min(pCheck3) < 0.95) stop("cor<0.95 of re-projection!")
424 |
425 |     message("-- Project LSI for ", ncol(matO), " outlying cells.")
426 |     # Project LSI Outliers
427 |     out$outliers <- colnames(matO)
428 |     outlierLSI <- .projectLSI(mat = matO, LSI = out)
429 |     allLSI <- rbind(out[[1]], outlierLSI)
430 |     allLSI <- allLSI[cn, , drop = F] # Re-Order Correctly to original
431 |     out[[1]] <- allLSI
432 |   }
433 |
434 |   return(out)
435 | }
436 |
437 | # Projects cells (columns of `mat`, a count matrix over the same peaks as the matrix
438 | # LSI was computed from) onto an existing LSI: applies LSI's TF-IDF transform using
439 | # its stored peak totals and cell count, then V = t(mat) %*% U %*% diag(1/d) and
440 | # coordinates = V %*% diag(d). Cells without counts are dropped. Returns the
441 | # cell-by-dimension matrix, or list(matSVD, V, X) when returnModel is TRUE.
442 | .projectLSI <- function(
443 |   mat,
444 |   LSI,
445 |   returnModel = F
446 | ){
447 |
448 |   # Get Same Features (drop = F keeps a single-cell chunk a matrix)
449 |   if(any(LSI$exclude2)) mat <- mat[!LSI$exclude2, , drop = F]
450 |
451 |   # Binarize Matrix
452 |   mat@x[mat@x > 0] <- 1
453 |
454 |   message("--- Cleaning up the matrix.")
455 |
456 |   # Clean up zero columns
457 |   colSm <- Matrix::colSums(mat)
458 |   exclude <- colSm == 0
459 |   if(any(exclude)){
460 |     mat <- mat[, !exclude, drop = F]
461 |     colSm <- colSm[!exclude]
462 |   }
463 |
464 |   message("--- Computing TF-IDF.")
465 |
466 |   # TF - Normalize
467 |   mat@x <- mat@x / rep.int(colSm, Matrix::diff(mat@p))
468 |
469 |   # TF-IDF
470 |   if(LSI$LSIMethod == 1 | tolower(LSI$LSIMethod) == "tf-log(idf)"){ #Adapted from Cusanovich et al.
471 |
472 |     #LogIDF
473 |     idf <- as(log(1 + LSI$nCol / LSI$rowSm), "sparseVector")
474 |
475 |     #TF-LogIDF
476 |     mat <- as(Matrix::Diagonal(x=as.vector(idf)), "sparseMatrix") %*% mat
477 |
478 |   } else if(LSI$LSIMethod == 2 | tolower(LSI$LSIMethod) == "log(tf-idf)"){ #Adapted from Stuart et al.
479 |
480 |     #IDF
481 |     idf <- as(LSI$nCol / LSI$rowSm, "sparseVector")
482 |
483 |     #TF-IDF
484 |     mat <- as(Matrix::Diagonal(x=as.vector(idf)), "sparseMatrix") %*% mat
485 |
486 |     #Log transform TF-IDF
487 |     mat@x <- log(mat@x * LSI$scaleTo + 1)
488 |
489 |   } else if(LSI$LSIMethod == 3 | tolower(LSI$LSIMethod) == "log(tf-log(idf))"){
490 |
491 |     #LogTF
492 |     mat@x <- log(mat@x + 1)
493 |
494 |     #LogIDF
495 |     idf <- as(log(1 + LSI$nCol / LSI$rowSm), "sparseVector")
496 |
497 |     #TF-IDF
498 |     mat <- as(Matrix::Diagonal(x=as.vector(idf)), "sparseMatrix") %*% mat
499 |
500 |   } else stop("Invalid LSI method!")
501 |
502 |   # Clean Up Matrix
503 |   idxNA <- Matrix::which(is.na(mat), arr.ind = T)
504 |   if(length(idxNA) > 0) mat[idxNA] <- 0
505 |
506 |   message("--- Calculating V and LSI.")
507 |
508 |   # Calculate V
509 |   V <- Matrix::t(mat) %*% LSI$svd$u %*% Matrix::diag(1 / LSI$svd$d)
510 |
511 |   # LSI Diagonal
512 |   svdDiag <- matrix(0, nrow = LSI$nDimensions, ncol = LSI$nDimensions)
513 |   diag(svdDiag) <- LSI$svd$d
514 |   matSVD <- Matrix::t(svdDiag %*% Matrix::t(V))
515 |   matSVD <- as.matrix(matSVD)
516 |   rownames(matSVD) <- colnames(mat)
517 |   colnames(matSVD) <- paste0("LSI", 1:ncol(matSVD))
518 |
519 |   if(returnModel){
520 |     X <- LSI$svd$u %*% diag(LSI$svd$d) %*% t(V)
521 |     out <- list(matSVD = matSVD, V = V, X = X)
522 |   } else out <- matSVD
523 |
524 |   rm(mat, V, svdDiag, matSVD)
525 |
526 |   return(out)
527 | }
528 |
529 | # Clusters cells on their LSI coordinates with ArchR::addClusters. LSI dimensions
530 | # among dimsToUse whose absolute correlation with depth (outLSI$corToDepth)
531 | # exceeds corCutOff are not used; at least 2 dimensions must remain. Returns a
532 | # DataFrame (cellNames, clusters, idx, groupNames, cellDepth) with the clustering
533 | # parameters in metadata(df)$parClust.
534 | .LSICluster <- function(
535 |   outLSI,
536 |   cellDF,
537 |   corCutOff,
538 |   dimsToUse,
539 |   scaleDims,
540 |   clusterParams,
541 |   j,
542 |   filterBias
543 | ){
544 |
545 |   if(scaleDims){
546 |     dimsPF <- dimsToUse[which(outLSI$corToDepth$scaled[dimsToUse] <= corCutOff)]
547 |   } else{
548 |     dimsPF <- dimsToUse[which(outLSI$corToDepth$none[dimsToUse] <= corCutOff)]
549 |   }
550 |
551 |   if(length(dimsPF) != length(dimsToUse)){
552 |     message("- Filtering ", length(dimsToUse) - length(dimsPF), " dims with cor>", corCutOff, " to log10(depth+1)")
553 |   }
554 |   if(length(dimsPF) < 2){
555 |     stop("Dimensions to use (after filtering for correlation to depth) fewer than 2!")
556 |   }
557 |
558 |   # Time to compute clusters
559 |   parClust <- clusterParams
560 |   if(scaleDims){
561 |     parClust$input <- as.matrix(.scaleDims(outLSI$matSVD)[, dimsPF, drop = F])
562 |   } else parClust$input <- as.matrix(outLSI$matSVD[, dimsPF, drop = F])
563 |
564 |   if(filterBias){
565 |     parClust$testBias <- T
566 |     parClust$filterBias <- T
567 |   }
568 |   parClust$biasVals <- cellDF[rownames(outLSI$matSVD), "cellDepth"]
569 |
570 |   clusters <- do.call(ArchR::addClusters, parClust) %>% suppressMessages()
571 |   parClust$input <- NULL
572 |   parClust$biasVals <- "cellDepth"
573 |   nClust <- length(unique(clusters))
574 |   message("- Identified ", nClust, " clusters:")
575 |   message("-- method=\'Seurat\'")
576 |   message("-- resolution=", parClust$resolution)
577 |   message("-- maxClusters=", parClust$maxClusters)
578 |   # isTRUE: parClust$filterBias is NULL when filterBias = F and clusterParams lacks it
579 |   if(isTRUE(parClust$filterBias)) message("-- biasVals=\'", parClust$biasVals, "\'")
580 |
581 |   df <- DataFrame(
582 |     cellNames = rownames(outLSI$matSVD), clusters = clusters,
583 |     cellDF[rownames(outLSI$matSVD), c("idx", "groupNames", "cellDepth")]
584 |   )
585 |   metadata(df)$parClust <- parClust
586 |   print(table(df[, c("groupNames", "clusters")]))
587 |
588 |   return(df)
589 | }
590 |
591 | # Z-scores each row of m (subtract the row mean, divide by the row SD); with
592 | # limit = TRUE the values are clamped to [min, max].
593 | .scaleDims <- function(m, scaleMax = 2, min = -scaleMax, max = scaleMax, limit = F){
594 |   z <- sweep(m - rowMeans(m), 1, matrixStats::rowSds(m), `/`)
595 |   if(limit){
596 |     z[z > max] <- max
597 |     z[z < min] <- min
598 |   }
599 |   return(z)
600 | }
601 |
602 | # Samples n elements of x. If outlierQuantiles is given, only elements whose vals
603 | # lie within the quantiles c(min(q) / factor, 1 - (1 - max(q)) / factor) are
604 | # eligible; if fewer than n are eligible, x is sampled without this filter.
605 | .filterSample <- function(
606 |   x = NULL,
607 |   n = NULL,
608 |   vals = x,
609 |   outlierQuantiles = NULL,
610 |   factor = 2
611 | ){
612 |   if(!is.null(outlierQuantiles)){
613 |     quant <- quantile(vals, probs = c(min(outlierQuantiles) / factor, 1 - ((1-max(outlierQuantiles)) / factor)))
614 |     idx <- which(vals >= quant[1] & vals <= quant[2])
615 |   } else idx <- seq_along(x)
616 |
617 |   if(length(idx) >= n){
618 |     return(sample(x = x[idx], size = n))
619 |   } else return(sample(x = x, size = n))
620 | }
621 |
622 | # Group-stratified sampling of about sampleCells cells: every group of sampleNames
623 | # contributes ceiling(sampleCells * groupSize / nCells) cells chosen by .filterSample
624 | # on cellDepth. Returns the sorted sampled names, or all cellNames when sampleCells
625 | # is not smaller than the number of cells.
626 | .sampleBySample <- function(
627 |   cellNames = NULL,
628 |   cellDepth = NULL,
629 |   sampleNames = NULL,
630 |   sampleCells = NULL,
631 |   outlierQuantiles = NULL,
632 |   factor = 2
633 | ){
634 |   if(sampleCells < length(cellNames)){
635 |
636 |     sampleN <- ceiling(sampleCells * table(sampleNames) / length(sampleNames))
637 |     splitCells <- split(cellNames, sampleNames)
638 |     splitDepth <- split(cellDepth, sampleNames)
639 |
640 |     sampledCellNames <- lapply(seq_along(splitCells), function(x){
641 |       .filterSample(
642 |         x = splitCells[[x]],
643 |         n = sampleN[names(splitCells)[x]],
644 |         vals = splitDepth[[x]],
645 |         outlierQuantiles = outlierQuantiles,
646 |         factor = factor
647 |       )
648 |     }) %>% unlist %>% sort
649 |
650 |     return(sampledCellNames)
651 |
652 |   } else return(cellNames)
653 | }
654 |
655 |
--------------------------------------------------------------------------------
/tutorial/variable-peaks/filter-peaks.R:
--------------------------------------------------------------------------------
1 | library(GEOquery)
2 | library(GenomicRanges)
3 | library(Matrix)
4 | library(data.table)
5 | library(dplyr)
6 | library(BSgenome.Hsapiens.UCSC.hg19)
7 |
8 | # downloaded data from GEO:
9 | getGEOSuppFiles(GEO = "GSE96769", makeDirectory = F)
10 |
11 | # create GRanges object for peaks:
12 | peaks <- fread(
13 |   file = "GSE96769_PeakFile_20160207.bed.gz",
14 |   header = F, stringsAsFactors = F
15 | )[, c(1:3, 7)]
16 | colnames(peaks) <- c("chr", "start", "end", "annot")
17 | peaks.gr <- makeGRangesFromDataFrame(
18 |   df = peaks,
19 |   seqinfo = seqinfo(BSgenome.Hsapiens.UCSC.hg19),
20 |   starts.in.df.are.0based = T
21 | )
22 |
23 | # create dgCMatrix object for counts:
24 | # the first line of the counts file holds the cell names, the remaining lines
25 | # are tab-separated (peak index, cell index, count) triplets
26 | cellname <- strsplit(
27 |   x = readLines("GSE96769_scATACseq_counts.txt.gz", n = 1),
28 |   split = ";|(#\t)"
29 | )[[1]][-1]
30 | counts <- read.table(
31 |   file = "GSE96769_scATACseq_counts.txt.gz",
32 |   skip = 1, header = F, sep = "\t",
33 |   col.names = c("peak", "cell", "count")
34 | )
35 | # explicit dims: otherwise they are inferred from the largest indices present,
36 | # which can be smaller than the dimnames when trailing peaks/cells have no counts
37 | counts <- sparseMatrix(
38 |   i = counts$cell, j = counts$peak, x = counts$count,
39 |   dims = c(length(cellname), length(peaks.gr)),
40 |   dimnames = list(cellname, as.character(peaks.gr))
41 | ) %>% t()
42 |
43 | # cells filtered by ArchR (../variable-tiles/5-ArchR.R writes ../data/sample-info.tsv):
44 | sample.info <- read.table(
45 |   file = "../data/sample-info.tsv", sep = "\t",
46 |   header = T, check.names = F, row.names = 1
47 | ) %>% subset(!discard)
48 | ci <- match(sample.info$Title, colnames(counts))
49 | counts <- counts[, ci]
50 | colnames(counts) <- sample.info$Run
51 |
52 | # filter peaks: drop chrX/chrY/chrM and promoter peaks, keep peaks with counts
53 | # in at least 5 cells (named keep.peaks so as not to mask base::pi)
54 | num.cells <- rowSums(counts > 0)
55 | keep.peaks <- !(peaks$chr %in% c("chrX", "chrY", "chrM")) &
56 |   !grepl("promoter", peaks$annot) &
57 |   (num.cells >= 5)
58 | peaks.gr <- peaks.gr[keep.peaks]
59 | counts <- counts[keep.peaks, ]
60 |
61 |
--------------------------------------------------------------------------------
/tutorial/variable-tiles/1-fastq-dump.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # sratoolkit v2.11.0-ubuntu64
3 | # Downloads the read pairs of every run listed in SRR_Acc_List.txt into
4 | # fastq/<run>/<run>_S<i>_R{1,2}.fastq.gz (the files 2-TrimGalore.sh looks for).
5 |
6 | mkdir -p fastq/
7 | i=0
8 | for sample in $(cat SRR_Acc_List.txt); do
9 |   let i++
10 |   fastq-dump --split-3 --dumpbase --skip-technical --clip --read-filter pass --outdir fastq/ "$sample"
11 |   gzip fastq/*.fastq
12 |   R1="fastq/${sample}_pass_1.fastq.gz"
13 |   R2="fastq/${sample}_pass_2.fastq.gz"
14 |   # keep only runs for which both mates were written
15 |   if [ -f "$R1" ] && [ -f "$R2" ]; then
16 |     mkdir -p "fastq/$sample"
17 |     mv "$R1" "fastq/$sample/${sample}_S${i}_R1.fastq.gz"
18 |     mv "$R2" "fastq/$sample/${sample}_S${i}_R2.fastq.gz"
19 |   fi
20 | done
21 | # remove the _pass_ files that were not moved above (-f: no error if there are none)
22 | rm -f fastq/*_pass_*.fastq.gz
23 |
24 |
--------------------------------------------------------------------------------
/tutorial/variable-tiles/2-TrimGalore.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # TrimGalore v0.6.6
3 | # Trims adapters from each run's read pairs and renames TrimGalore's *_val_{1,2}.fq.gz
4 | # outputs to *_R{1,2}_trimmed.fastq.gz (the files 3-bowtie2.sh looks for).
5 |
6 | for sample in $(cat SRR_Acc_List.txt); do
7 |   r1=$(ls fastq/$sample/*_R1.fastq.gz 2>/dev/null)
8 |   r2=$(ls fastq/$sample/*_R2.fastq.gz 2>/dev/null)
9 |   # quoted: unquoted, an empty $r1 turns `[ -f $r1 ]` into `[ -f ]`, which is true
10 |   if [ -f "$r1" ] && [ -f "$r2" ]; then
11 |     trim_galore --cores 8 --paired --fastqc --output_dir "fastq/$sample" "$r1" "$r2"
12 |     R1=$(ls fastq/$sample/*_R1_val_1.fq.gz); mv "$R1" "${R1%_val_1.fq.gz}_trimmed.fastq.gz"
13 |     R2=$(ls fastq/$sample/*_R2_val_2.fq.gz); mv "$R2" "${R2%_val_2.fq.gz}_trimmed.fastq.gz"
14 |   fi
15 | done
16 |
17 |
--------------------------------------------------------------------------------
/tutorial/variable-tiles/3-bowtie2.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # bowtie2 v2.4.2-linux-x86_64
3 | # Aligns each run's trimmed read pairs to the hg19 index, writing aligned/<run>.SAM.
4 |
5 | mkdir -p aligned/
6 | for sample in $(cat SRR_Acc_List.txt); do
7 |   R1=$(ls fastq/$sample/*_R1_trimmed.fastq.gz 2>/dev/null)
8 |   R2=$(ls fastq/$sample/*_R2_trimmed.fastq.gz 2>/dev/null)
9 |   # quoted so that a missing file (empty variable) fails the test
10 |   if [ -f "$R1" ] && [ -f "$R2" ]; then
11 |     bowtie2 --threads 16 --mm \
12 |       --maxins 2000 --very-sensitive --no-unal --no-mixed \
13 |       -x Bowtie2_index/hg19 \
14 |       -1 "$R1" -2 "$R2" -S "aligned/$sample.SAM"
15 |   fi
16 | done
17 |
18 |
--------------------------------------------------------------------------------
/tutorial/variable-tiles/4-samtools.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # samtools v1.11
3 |
4 | # For one run ($1): append the tag CB:Z:<run> to every alignment, keep alignments
5 | # with MAPQ >= 30, then coordinate-sort and index them.
6 | prepBam(){
7 |   if [ -f "aligned/$1.SAM" ]; then
8 |     samtools view -H --no-PG "aligned/$1.SAM" > "aligned/${1}_BC.SAM"
9 |     # ./ because the helper is compiled into the working directory below
10 |     samtools view "aligned/$1.SAM" | ./addBarcodeTag "$1" >> "aligned/${1}_BC.SAM"
11 |     samtools view -h -q 30 "aligned/${1}_BC.SAM" | \
12 |       samtools sort -m 8G -l 9 -o "aligned/${1}_BC_MQ30_posSorted.BAM" -
13 |     samtools index "aligned/${1}_BC_MQ30_posSorted.BAM"
14 |     rm "aligned/${1}_BC.SAM"
15 |   fi
16 | }
17 |
18 | g++ -std=c++11 addBarcodeTag.cpp -o addBarcodeTag
19 | export -f prepBam
20 | parallel -j 16 -a SRR_Acc_List.txt prepBam
21 |
22 | # merge the per-run BAMs into the single BAM read by 5-ArchR.R
23 | ls aligned/*_BC_MQ30_posSorted.BAM > bam-files.txt
24 | samtools merge -@ 30 -l 9 -b bam-files.txt aligned/BC_MQ30_posSorted_merged.BAM
25 | samtools index -@ 30 aligned/BC_MQ30_posSorted_merged.BAM
26 |
27 |
--------------------------------------------------------------------------------
/tutorial/variable-tiles/5-ArchR.R:
--------------------------------------------------------------------------------
1 | library(ArchR)
2 | library(parallel)
3 |
4 | addArchRThreads(threads = 16)
5 | addArchRGenome("hg19")
6 | set.seed(1)
7 |
8 | # sample info of 2,971 cells from GSE74310 & GSE96769:
9 | sample.info <- read.table(
10 |   file = "../data/sample-info.tsv", sep = "\t",
11 |   header = T, check.names = F, row.names = 1
12 | )
13 |
14 | # filter low-quality ATAC-seq fragments, low-quality cells, and doublets:
15 | arrow.files <- createArrowFiles(
16 |   inputFiles = c(buenrostro2018 = "aligned/BC_MQ30_posSorted_merged.BAM"),
17 |   minTSS = 4, minFrags = 5000, maxFrags = 2 * 10^5,
18 |   excludeChr = c("chrX", "chrY", "chrM"),
19 |   bcTag = "CB", validBarcodes = sample.info$Run
20 | )
21 | doubScores <- addDoubletScores(arrow.files)
22 | archr.obj <- filterDoublets(ArchRProject(arrow.files))
23 |
24 | # 2,154 cells retained for all analyses:
25 | # (the barcode after '#' in ArchR cell names is the run accession set via the CB tag)
26 | sample.info$discard <- !(sample.info$Run %in% gsub("^.+#", "", archr.obj$cellNames))
27 | write.table(sample.info, file = "../data/sample-info.tsv", sep = "\t", row.names = T)
28 |
29 |
--------------------------------------------------------------------------------
/tutorial/variable-tiles/README.md:
--------------------------------------------------------------------------------
1 | Pre-processing the 'Small-scale human hematopoiesis dataset':
2 |
3 | 1. [Downloading FASTQ files](1-fastq-dump.sh),
4 | 2. [Trimming adapter sequences](2-TrimGalore.sh),
5 | 3. [Aligning reads to the reference genome](3-bowtie2.sh),
6 | 4. [Filtering and sorting the alignments](4-samtools.sh),
7 | 5. [Filtering fragments and cells with ArchR (quality-control and doublet removal)](5-ArchR.R).
8 |
9 | The result is an ArchR object containing the pre-processed scATAC-seq data (named `archr.obj`).
10 | -------------------------------------------------------------------------------- /tutorial/variable-tiles/addBarcodeTag.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | using namespace std; 4 | 5 | int main(int argc, char** argv){ 6 | string line, barcode = string(argv[1]); 7 | while(getline(cin, line)) cout << line << "\t" << "CB:Z:" << barcode << endl; 8 | 9 | return 0; 10 | } 11 | --------------------------------------------------------------------------------