├── .Rbuildignore
├── .gitignore
├── CellSpace.Rproj
├── CellSpace.png
├── DESCRIPTION
├── LICENSE.md
├── NAMESPACE
├── R
├── CellSpace.R
└── README.md
├── README.md
├── cpp
├── .gitignore
├── README.md
├── StarSpace-MIT-License.md
├── makefile
└── src
│ ├── data.cpp
│ ├── data.h
│ ├── dict.cpp
│ ├── dict.h
│ ├── main.cpp
│ ├── matrix.h
│ ├── model.cpp
│ ├── model.h
│ ├── parser.cpp
│ ├── parser.h
│ ├── proj.cpp
│ ├── proj.h
│ ├── starspace.cpp
│ ├── starspace.h
│ └── utils
│ ├── args.cpp
│ ├── args.h
│ ├── normalize.cpp
│ ├── normalize.h
│ ├── utils.cpp
│ └── utils.h
├── man
├── CellSpace-class.Rd
├── CellSpace.Rd
├── DNA_sequence_embedding.Rd
├── add_motif_db.Rd
├── cosine_similarity.Rd
├── docs
│ ├── CellSpace.md
│ ├── DNA_sequence_embedding.md
│ ├── add_motif_db.md
│ ├── cosine_similarity.md
│ ├── embedding_distance.md
│ ├── find_clusters.md
│ ├── find_neighbors.md
│ ├── merge_small_clusters.md
│ ├── motif_embedding.md
│ └── run_UMAP.md
├── embedding_distance.Rd
├── find_clusters.Rd
├── find_neighbors.Rd
├── merge_small_clusters.Rd
├── motif_embedding.Rd
└── run_UMAP.Rd
└── tutorial
├── README.md
├── data
├── CellSpace_embedding-var_tiles.tsv.gz
├── PWM-list.rds
├── cell-names.txt
├── cell_by_peak-counts.mtx.gz
├── cell_by_tile-counts.mtx.gz
├── palette.rds
├── sample-info.tsv
├── var_peaks.fa.gz
└── var_tiles.fa.gz
├── plot-functions.R
├── plots
├── UMAP-cells.png
├── UMAP-cells_and_TFs.png
└── motif-scores.png
├── variable-peaks
├── IterativeLSI.R
└── filter-peaks.R
└── variable-tiles
├── 1-fastq-dump.sh
├── 2-TrimGalore.sh
├── 3-bowtie2.sh
├── 4-samtools.sh
├── 5-ArchR.R
├── README.md
├── SRR_Acc_List.txt
└── addBarcodeTag.cpp
/.Rbuildignore:
--------------------------------------------------------------------------------
1 | ^CellSpace\.Rproj$
2 | ^\.Rproj\.user$
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.DS_Store
2 | .Rproj.user
3 | .Rhistory
4 |
5 |
6 |
--------------------------------------------------------------------------------
/CellSpace.Rproj:
--------------------------------------------------------------------------------
1 | Version: 1.0
2 |
3 | RestoreWorkspace: No
4 | SaveWorkspace: No
5 | AlwaysSaveHistory: Default
6 |
7 | EnableCodeIndexing: Yes
8 | UseSpacesForTab: Yes
9 | NumSpacesForTab: 2
10 | Encoding: UTF-8
11 |
12 | RnwWeave: Sweave
13 | LaTeX: pdfLaTeX
14 |
15 | AutoAppendNewline: Yes
16 | StripTrailingWhitespace: Yes
17 | LineEndingConversion: Posix
18 |
19 | BuildType: Package
20 | PackageUseDevtools: Yes
21 | PackageInstallArgs: --no-multiarch --with-keep.source
22 | PackageRoxygenize: rd,collate,namespace
23 |
--------------------------------------------------------------------------------
/CellSpace.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zakieh-tayyebi/CellSpace/c5dc59b1a3c843596834acab4d4e421e4cd59343/CellSpace.png
--------------------------------------------------------------------------------
/DESCRIPTION:
--------------------------------------------------------------------------------
1 | Package: CellSpace
2 | Title: Downstream Analyses for CellSpace
Description: A toolkit for analyzing single-cell ATAC-seq data based on a learned CellSpace embedding.
4 | Version: 0.0.6
5 | Authors@R: c(
6 | person("Zakieh", "Tayyebi", email = "zakieh.tayyebi@gmail.com", role = c("aut", "cre"),
7 | comment = c(ORCID = "https://orcid.org/0000-0002-7821-9905")),
8 | person("The Christina Leslie Lab", role = "fnd",
9 | comment = c(ORCID = "https://orcid.org/0000-0002-4571-5910")),
10 | person("The Tri-Institutional PhD Program in Computational Biology & Medicine", role = "fnd"))
11 | URL: https://github.com/zakieh-tayyebi/CellSpace
12 | BugReports: https://github.com/zakieh-tayyebi/CellSpace/issues
13 | Imports:
14 | Seurat,
15 | Biostrings
16 | Encoding: UTF-8
17 | Roxygen: list(markdown = TRUE)
18 | RoxygenNote: 7.2.3
19 | Suggests:
20 | knitr,
21 | rmarkdown
22 | VignetteBuilder: knitr
23 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | The MIT License
2 |
3 | Copyright (c) 2023 The Christina Leslie Lab, MSK CC
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
10 |
--------------------------------------------------------------------------------
/NAMESPACE:
--------------------------------------------------------------------------------
1 | # Generated by roxygen2: do not edit by hand
2 |
3 | export(CellSpace)
4 | export(DNA_sequence_embedding)
5 | export(add_motif_db)
6 | export(cosine_similarity)
7 | export(embedding_distance)
8 | export(find_clusters)
9 | export(find_neighbors)
10 | export(merge_small_clusters)
11 | export(motif_embedding)
12 | export(run_UMAP)
13 | exportClasses(CellSpace)
14 | importFrom(Biostrings,DNAStringSet)
15 | importFrom(Biostrings,as.matrix)
16 | importFrom(Biostrings,reverseComplement)
17 | importFrom(Seurat,FindClusters)
18 | importFrom(Seurat,FindNeighbors)
19 | importFrom(Seurat,RunUMAP)
20 |
--------------------------------------------------------------------------------
/R/CellSpace.R:
--------------------------------------------------------------------------------
#' The CellSpace Class
#'
#' The \code{CellSpace} class stores a CellSpace embedding together with the
#' related information needed to perform downstream analyses.
#'
#' @slot project title of the project
#' @slot emb.file the .tsv output of CellSpace containing the embedding matrix for cells and k-mers
#' @slot cell.emb the embedding matrix for cells
#' @slot kmer.emb the embedding matrix for k-mers
#' @slot motif.emb the embedding matrix for TF motifs
#' @slot meta.data data frame containing meta-information about each cell
#' @slot dim the dimensions of the CellSpace embeddings
#' @slot k the length of DNA k-mers
#' @slot similarity the similarity function in hinge loss
#' @slot p the embedding of an entity equals the sum of its \code{M} feature embedding vectors divided by \code{M^p}
#' @slot label cell label prefix
#' @slot neighbors list containing nearest neighbor graphs
#' @slot reductions list containing dimensional reductions
#' @slot misc list containing miscellaneous objects
#'
#' @name CellSpace-class
#' @rdname CellSpace-class
#' @exportClass CellSpace
#'
setClass(
  "CellSpace",
  slots = c(
    # identification and provenance
    project = "character",
    emb.file = "character",
    # learned embeddings
    cell.emb = "matrix",
    kmer.emb = "matrix",
    motif.emb = "list",
    meta.data = "data.frame",
    # model hyper-parameters
    dim = "integer",
    k = "integer",
    similarity = "character",
    p = "numeric",
    label = "character",
    # results of downstream analyses
    neighbors = "list",
    reductions = "list",
    misc = "list"
  )
)
41 |
# Prints a compact summary of a CellSpace object: project, model size,
# and the names of any motif embeddings, meta-data columns, graphs,
# reductions and miscellaneous results it holds.
setMethod(
  "show", "CellSpace",
  function(object){
    # one line: a title followed by comma-separated items
    print_items <- function(title, items){
      cat(title)
      cat(items, sep = ", ")
      cat("\n")
    }

    cat("An object of class \"CellSpace\"\n")
    if(object@project != "") cat("Project:", object@project, "\n")
    cat(
      "CellSpace model: ", object@dim, "-dimensional embedding for ",
      nrow(object@cell.emb), " cells and all DNA k-mers (k=", object@k, ")\n",
      sep = ""
    )

    if(length(object@motif.emb) > 0)
      print_items("TF motif embedding matrices: ", names(object@motif.emb))

    md.cols <- colnames(object@meta.data)
    if(length(md.cols) > 0){
      # more than 10 columns: show only the first 3 and the last 2 names
      if(length(md.cols) > 10) md.cols <- c(head(md.cols, 3), "...", tail(md.cols, 2))
      print_items("Cell meta-data: ", md.cols)
    }

    if(length(object@neighbors) > 0)
      print_items("Nearest neighbor graphs created: ", names(object@neighbors))
    if(length(object@reductions) > 0)
      print_items("Dimensional reductions calculated: ", names(object@reductions))
    if(length(object@misc) > 0)
      print_items("miscellaneous: ", names(object@misc))
  }
)
81 |
# `object$name` returns the meta-data column `name`, or NULL when the
# column does not exist.
setMethod(
  "$", "CellSpace",
  function(x, name){
    if(!(name %in% colnames(x@meta.data))) return(NULL)
    x@meta.data[, name]
  }
)
90 |
# `object$name <- value` writes `value` into the meta-data column `name`
# (creating the column if needed) and returns the modified object.
setMethod(
  "$<-", "CellSpace",
  function(x, name, value){
    md <- x@meta.data
    md[, name] <- value
    x@meta.data <- md
    x
  }
)
98 |
#' CellSpace
#'
#' Generates an object from the \code{CellSpace} class.
#'
#' Rows of \code{emb.file} whose names start with \code{label} followed by \code{"C"} are read as
#' cell embeddings and ordered by their numeric index. Rows whose names do not start with
#' \code{label} are read as k-mer embeddings.
#'
#' @param emb.file the .tsv output of CellSpace containing the embedding matrix for cells and k-mers
#' @param cell.names vector of unique cell names
#' @param meta.data a \code{data.frame} containing meta-information about each cell
#' @param project title of the project
#' @param similarity the similarity function in hinge loss
#' @param p the embedding of an entity equals the sum of its \code{M} feature embedding vectors divided by \code{M^p}
#' @param label cell label prefix
#'
#' @return a new \code{CellSpace} object
#'
#' @name CellSpace
#' @rdname CellSpace
#' @export
#'
CellSpace <- function(
  emb.file,
  cell.names = NULL,
  meta.data = NULL,
  project = NULL,
  similarity = "cosine",
  p = 0.5,
  label = "__label__"
){
  project <- if(is.null(project)) "" else as.character(project[1])

  emb.file <- normalizePath(emb.file)
  emb <- read.table(emb.file, header = FALSE, row.names = 1, sep = "\t")
  dim <- ncol(emb)
  colnames(emb) <- paste0("CS", seq_len(dim))

  # Cell rows are named "<label>C<index>". Plain prefix matching avoids
  # treating `label` as a regular expression.
  cell.label <- paste0(label, "C")
  cell.labels <- rownames(emb)[startsWith(rownames(emb), cell.label)]
  cell.idx <- sort(as.integer(substring(cell.labels, nchar(cell.label) + 1)))
  # drop = FALSE keeps a matrix (with row names) even for a 1-dimensional embedding
  cell.emb <- as.matrix(emb[paste0(cell.label, cell.idx), , drop = FALSE])

  kmer.emb <- as.matrix(emb[!startsWith(rownames(emb), label), , drop = FALSE])
  k <- unique(nchar(rownames(kmer.emb)))
  if(length(k) > 1) stop("Varying k-mer lengths!")

  if(!is.null(cell.names)){
    if(length(cell.names) != nrow(cell.emb) || any(duplicated(cell.names))){
      warning("\'cell.names\' must be a character vector with ", nrow(cell.emb), " unique values!")
      cell.names <- NULL
    }
  }

  if(!is.null(meta.data)){
    # the meta.data slot must hold a data.frame; accept e.g. a matrix as well
    if(!is.data.frame(meta.data)) meta.data <- as.data.frame(meta.data)
    if(nrow(meta.data) != nrow(cell.emb)){
      warning("The number of rows in \'meta.data\' does not match the cell embedding matrix!")
      meta.data <- NULL
    }
  }

  # Naming precedence: cell.names, then rownames(meta.data), then the cell labels.
  if(is.null(cell.names)){
    if(is.null(meta.data)){
      meta.data <- data.frame(row.names = rownames(cell.emb), check.rows = FALSE, check.names = FALSE)
    } else rownames(cell.emb) <- rownames(meta.data)
  } else {
    if(is.null(meta.data)){
      meta.data <- data.frame(row.names = cell.names, check.rows = FALSE, check.names = FALSE)
    } else rownames(meta.data) <- cell.names
    rownames(cell.emb) <- cell.names
  }

  new("CellSpace",
    project = project,
    emb.file = emb.file,
    cell.emb = cell.emb,
    kmer.emb = kmer.emb,
    motif.emb = list(),
    meta.data = meta.data,
    dim = dim,
    k = k,
    similarity = similarity,
    p = p,
    label = label,
    neighbors = list(),
    reductions = list(),
    misc = list()
  )
}
184 |
#' find_neighbors
#'
#' Builds a nearest neighbor graph and shared nearest neighbor graph from the CellSpace embedding.
#'
#' Neighbors are searched with Annoy, using the object's \code{similarity} as the Annoy metric.
#' The resulting graphs are stored in the \code{neighbors} slot as \code{<emb.name>_<graph name>}.
#'
#' @importFrom Seurat FindNeighbors
#'
#' @param object a \code{CellSpace} object
#' @param n.neighbors the number of nearest neighbors for the KNN algorithm
#' @param emb the embedding matrix used to create the nearest neighbor graphs
#' @param emb.name prefix for the graph names that will be added to the \code{neighbors} slot
#' @param ... arguments passed to \code{Seurat::FindNeighbors}
#'
#' @return a \code{CellSpace} object containing nearest neighbor and shared nearest neighbor graphs in the \code{neighbors} slot
#'
#' @name find_neighbors
#' @rdname find_neighbors
#' @export
#'
find_neighbors <- function(
  object,
  n.neighbors = 30,
  emb = object@cell.emb,
  emb.name = "cells",
  ...
){
  nn.graphs <- FindNeighbors(
    emb,
    distance.matrix = FALSE,
    k.param = n.neighbors,
    l2.norm = FALSE,
    nn.method = "annoy",
    annoy.metric = object@similarity,
    ...
  )
  graph.names <- paste(emb.name, names(nn.graphs), sep = "_")
  object@neighbors[graph.names] <- nn.graphs
  object
}
218 |
#' find_clusters
#'
#' Finds clusters from a shared nearest neighbor graph built from the CellSpace embedding.
#'
#' Cluster columns returned by \code{Seurat::FindClusters} are renamed from \code{res.<r>}
#' to \code{Clusters.res_<r>} before being added to \code{meta.data}.
#'
#' @importFrom Seurat FindClusters
#'
#' @param object a \code{CellSpace} object
#' @param graph name of the shared nearest neighbor graph in the \code{neighbors} slot used to find clusters
#' @param ... arguments passed to \code{Seurat::FindClusters}
#'
#' @return a \code{CellSpace} object with the cell clusters added to the \code{meta.data} slot
#'
#' @name find_clusters
#' @rdname find_clusters
#' @export
#'
find_clusters <- function(object, graph = "cells_snn", ...){
  available <- names(object@neighbors)
  if(!(graph %in% available)){
    stop("\'", graph, "\' not available! Run \'find_neighbors\' to create nearest neighbor graphs.")
  }
  clusters <- FindClusters(object@neighbors[[graph]], ...)
  colnames(clusters) <- gsub("res\\.", "Clusters.res_", colnames(clusters))
  object@meta.data[, colnames(clusters)] <- clusters
  object
}
243 |
#' merge_small_clusters
#'
#' Merges cells from small clusters with the nearest clusters.
#'
#' For every non-empty cluster with fewer than \code{min.cells} cells, connectivity to each remaining
#' cluster is the mean edge weight of \code{graph} between the two sets of cells. The small cluster is
#' merged into the most connected cluster, with ties broken at random. Empty factor levels are dropped.
#'
#' @param object a \code{CellSpace} object
#' @param clusters a vector of cluster labels, or the name of a column in the \code{meta.data} slot containing cluster labels
#' @param min.cells any cluster with fewer cells than \code{min.cells} will be merged with the nearest cluster
#' @param graph a shared nearest neighbor graph, or the name of a graph in the \code{neighbors} slot, used to find clusters
#' @param seed random seed used to break ties between equally connected clusters
#' (\code{set.seed} is called, so the global random number generator state is modified)
#'
#' @return new cluster labels, as a factor whose levels are the clusters with at least \code{min.cells} cells
#'
#' @name merge_small_clusters
#' @rdname merge_small_clusters
#' @export
#'
merge_small_clusters <- function(
  object,
  clusters,
  min.cells = 10,
  graph = "cells_snn",
  seed = 1
){
  if(is.character(graph) && length(graph) == 1 && graph %in% names(object@neighbors)){
    graph <- object@neighbors[[graph]]
  } else if(!inherits(graph, "Graph"))
    stop("'graph' must be a nearest neighbor graph of class 'Graph'")

  if(length(clusters) == 1 && clusters %in% colnames(object@meta.data)){
    clusters <- object@meta.data[, clusters]
  }
  if(length(clusters) != nrow(graph))
    stop("'clusters' must be the name of a column in 'object@meta.data', or a vector of the same length as the number of cells.")
  # meta.data columns are not guaranteed to be factors, and levels() of a
  # character vector is NULL, which would turn every label into NA below
  if(!is.factor(clusters)) clusters <- factor(clusters)

  cluster.sizes <- table(clusters)
  # Empty levels have no cells to reassign. Treating them as "small" would
  # give 0/0 connectivities and a failing sample() call.
  small.clusters <- names(cluster.sizes)[cluster.sizes > 0 & cluster.sizes < min.cells]
  cluster_names <- names(cluster.sizes)[cluster.sizes >= min.cells]
  if(length(small.clusters) > 0 && length(cluster_names) == 0)
    stop("No cluster has at least 'min.cells' cells to merge the small clusters into!")

  # cell indices per cluster, computed once instead of inside the nested loop
  cells.by.cluster <- split(seq_along(clusters), clusters)
  connectivity <- numeric(length(cluster_names))
  names(connectivity) <- cluster_names

  new.ids <- clusters
  for(i in small.clusters){
    i.cells <- cells.by.cluster[[i]]
    for(j in cluster_names){
      j.cells <- cells.by.cluster[[j]]
      # Mean edge weight between the two clusters. sum()/count works whether
      # subsetting yields a sparse matrix or a dropped numeric vector.
      connectivity[j] <- sum(graph[i.cells, j.cells]) / (length(i.cells) * length(j.cells))
    }
    set.seed(seed = seed)
    best <- names(connectivity)[which(connectivity == max(connectivity, na.rm = TRUE))]
    new.ids[i.cells] <- sample(best, 1)
  }

  factor(new.ids, levels = cluster_names)
}
302 |
#' run_UMAP
#'
#' Computes a UMAP embedding from the CellSpace embedding.
#'
#' If \code{graph} is given, UMAP is computed from that nearest neighbor graph and \code{emb} is ignored.
#' Otherwise it is computed from \code{emb}, using the object's \code{similarity} as the metric.
#'
#' @importFrom Seurat RunUMAP
#'
#' @param object a \code{CellSpace} object
#' @param emb the embedding matrix used to compute the UMAP embedding
#' @param graph name of the nearest neighbor graph in the \code{neighbors} slot used to compute the UMAP embedding
#' @param name name of the lower-dimensional embedding that will be added to the \code{reductions} slot
#' @param ... arguments passed to \code{Seurat::RunUMAP}
#'
#' @return a \code{CellSpace} object containing a UMAP embedding in the \code{reductions} slot
#'
#' @name run_UMAP
#' @rdname run_UMAP
#' @export
#'
run_UMAP <- function(
  object,
  emb = object@cell.emb,
  graph = NULL,
  name = "cells_UMAP",
  ...
){
  # `emb` is only evaluated when no graph is given
  if(is.null(graph) && is.null(emb)) stop("'emb' or 'graph' must be provided!")

  umap <- if(is.null(graph)){
    RunUMAP(
      object = emb,
      metric = object@similarity,
      assay = "CellSpace",
      ...
    )
  } else {
    if(!(graph %in% names(object@neighbors)))
      stop("\'", graph, "\' not available! Run \'find_neighbors\' to create nearest neighbor graphs.")
    RunUMAP(object = object@neighbors[[graph]], assay = "CellSpace", ...)
  }

  object@reductions[[name]] <- umap@cell.embeddings
  object
}
343 |
#' cosine_similarity
#'
#' Computes cosine similarity in the embedding space.
#'
#' A plain vector is treated as a single-row matrix. Rows with zero norm produce \code{NaN} similarities.
#'
#' @param x an embedding matrix
#' @param y an embedding matrix with compatible dimensions to \code{x}, or \code{NULL}, in which case \code{y=x}
#'
#' @return a matrix containing the cosine similarity between rows of \code{x} and \code{y}
#'
#' @name cosine_similarity
#' @rdname cosine_similarity
#' @export
#'
cosine_similarity <- function(x, y = NULL){
  as_row_matrix <- function(v) if(is.matrix(v)) v else matrix(v, nrow = 1)

  x <- as_row_matrix(x)
  norm.x <- sqrt(rowSums(x ^ 2))

  if(is.null(y)){
    y <- x
    norm.y <- norm.x
  } else {
    y <- as_row_matrix(y)
    norm.y <- sqrt(rowSums(y ^ 2))
  }

  # dot products of all row pairs, divided by the outer product of the row norms
  tcrossprod(x, y) / outer(norm.x, norm.y)
}
370 |
#' embedding_distance
#'
#' Computes distance in the embedding space based on cosine similarity.
#'
#' Similarities are first clamped to [-1, 1] to absorb rounding error. \code{"cosine"} returns
#' \code{1 - s} and \code{"angular"} returns \code{acos(s) / pi}.
#'
#' @param x an embedding matrix
#' @param y an embedding matrix with compatible dimensions to \code{x}, or \code{NULL}, in which case \code{y=x}
#' @param distance the distance metric, either 'cosine' or 'angular', to compute from the cosine similarity
#'
#' @return a matrix containing the distance between rows of \code{x} and \code{y}, computed from their cosine similarity
#'
#' @name embedding_distance
#' @rdname embedding_distance
#' @export
#'
embedding_distance <- function(x, y = NULL, distance = c("cosine", "angular")){
  metric <- distance[1]
  sim <- cosine_similarity(x = x, y = y)

  # rounding can push |s| slightly above 1, where acos() would give NaN;
  # which() keeps NaN entries (zero-norm rows) out of the assignment
  sim[which(sim > 1)] <- 1
  sim[which(sim < -1)] <- -1

  if(metric == "cosine") return(1 - sim)
  if(metric == "angular") return(acos(sim) / pi)
  stop("The distance metric must be \'cosine\' or \'angular\'!\n")
}
396 |
#' DNA_sequence_embedding
#'
#' Maps a DNA sequence to the embedding space.
#'
#' The sequence is split into overlapping k-mers. A k-mer missing from the embedding vocabulary is
#' replaced by its reverse complement. The embedding is the sum of the k-mer embeddings divided by
#' \code{M^p}, where \code{M} is the number of k-mers. If the sequence is shorter than k, or contains a
#' k-mer that cannot be mapped (e.g. because of invalid letters), a warning is issued and a vector of
#' \code{NA}s is returned.
#'
#' @importFrom Biostrings reverseComplement DNAStringSet
#'
#' @param object a \code{CellSpace} object
#' @param seq a DNA sequence
#'
#' @return a numerical vector containing the CellSpace embedding of \code{seq}
#'
#' @name DNA_sequence_embedding
#' @rdname DNA_sequence_embedding
#' @export
#'
DNA_sequence_embedding <- function(object, seq){
  if(!is.character(seq)) seq <- as.character(seq)
  if(length(seq) != 1) stop("'seq' must be a single DNA sequence!")
  sl <- nchar(seq)
  if(sl < object@k){
    warning("The sequence \'", seq, "\' is shorter than CellSpace k-mers (k=", object@k, ")!")
    return(rep(NA, object@dim))
  }

  b <- 1:(sl - object@k + 1)
  kmers <- substring(text = toupper(seq), first = b, last = b + object@k - 1)

  # Only k-mers absent from the vocabulary need their reverse complement.
  # DNAStringSet() throws an error on invalid letters, so such k-mers become
  # NA and are reported by the validity warning below instead of crashing.
  vocab <- rownames(object@kmer.emb)
  absent <- !(kmers %in% vocab)
  if(any(absent)){
    kmers[absent] <- tryCatch(
      toupper(as.character(reverseComplement(DNAStringSet(kmers[absent])))),
      error = function(e) rep(NA_character_, sum(absent))
    )
  }
  if(!all(kmers %in% vocab)){
    warning("\'", seq, "\' is not a valid DNA sequence!")
    return(rep(NA, object@dim))
  }

  # drop = FALSE keeps a matrix for a single k-mer or a 1-dimensional embedding
  colSums(object@kmer.emb[kmers, , drop = FALSE]) / (length(kmers) ^ object@p)
}
436 |
#' motif_embedding
#'
#' Maps a motif to the embedding space.
#'
#' The consensus sequence takes, at each position, the row name (base) with the largest weight.
#' On ties, the first base wins. The consensus is then embedded with \code{DNA_sequence_embedding}.
#'
#' @importFrom Biostrings as.matrix
#'
#' @param object a \code{CellSpace} object
#' @param PWM \code{PFMatrix} or \code{PWMatrix}
#'
#' @return a numerical vector containing the CellSpace embedding of the consensus sequence for \code{PWM}
#'
#' @name motif_embedding
#' @rdname motif_embedding
#' @export
#'
motif_embedding <- function(object, PWM){
  weights <- as.matrix(PWM)
  top.bases <- rownames(weights)[apply(weights, 2, which.max)]
  DNA_sequence_embedding(object = object, seq = paste(top.bases, collapse = ""))
}
457 |
#' add_motif_db
#'
#' Computes the CellSpace embedding and activity scores of transcription factor motifs.
#'
#' Each motif is embedded through its consensus sequence (see \code{motif_embedding}). Motifs that
#' cannot be embedded (their embedding contains \code{NA}) are dropped. Scores are the cosine
#' similarities between cells (rows) and motifs (columns), standardized per motif with \code{scale}.
#'
#' @param object a \code{CellSpace} object
#' @param motif.db \code{PFMatrixList} or \code{PWMatrixList}
#' @param db.name the name of the transcription factor motif database
#'
#' @return a \code{CellSpace} object containing the motif embedding matrix, in the \code{motif.emb} slot, and the corresponding similarity Z-scores, in the \code{misc} slot
#'
#' @name add_motif_db
#' @rdname add_motif_db
#' @export
#'
add_motif_db <- function(object, motif.db, db.name){
  # Base R only: `%>%` is not imported by this package's NAMESPACE.
  motif.emb <- do.call(rbind, lapply(motif.db, function(motif.pwm){
    motif_embedding(object, PWM = motif.pwm)
  }))
  if(!is.null(motif.emb))
    motif.emb <- motif.emb[rowSums(is.na(motif.emb)) == 0, , drop = FALSE]

  # Without this guard, NULL would be passed as `y` below, and
  # cosine_similarity() would silently return a cell-by-cell matrix.
  if(is.null(motif.emb) || nrow(motif.emb) == 0)
    stop("None of the motifs in \'", db.name, "\' could be mapped to the embedding space!")

  object@motif.emb[[db.name]] <- motif.emb
  object@misc[[paste(db.name, "scores", sep = "_")]] <- scale(
    cosine_similarity(x = object@cell.emb, y = motif.emb)
  )

  return(object)
}
484 |
485 |
--------------------------------------------------------------------------------
/R/README.md:
--------------------------------------------------------------------------------
1 | ## API
2 |
3 | | Function | Description |
4 | |-------------------------|-----------------------------------------------|
5 | | [`CellSpace`](../man/docs/CellSpace.md) | Generates an object from the CellSpace class. |
6 | | [`cosine_similarity`](../man/docs/cosine_similarity.md) | Computes cosine similarity in the embedding space. |
7 | | [`embedding_distance`](../man/docs/embedding_distance.md) | Computes distance in the embedding space based on cosine similarity. |
8 | | [`find_neighbors`](../man/docs/find_neighbors.md) | Builds a nearest neighbor graph and shared nearest neighbor graph from the CellSpace embedding. |
| [`find_clusters`](../man/docs/find_clusters.md) | Finds clusters from a shared nearest neighbor graph built from the CellSpace embedding. |
10 | | [`merge_small_clusters`](../man/docs/merge_small_clusters.md) | Merges cells from small clusters with the nearest clusters. |
11 | | [`run_UMAP`](../man/docs/run_UMAP.md) | Computes a UMAP embedding from the CellSpace embedding. |
12 | | [`DNA_sequence_embedding`](../man/docs/DNA_sequence_embedding.md) | Maps a DNA sequence to the embedding space. |
13 | | [`motif_embedding`](../man/docs/motif_embedding.md) | Maps a motif to the embedding space. |
14 | | [`add_motif_db`](../man/docs/add_motif_db.md) | Computes the CellSpace embedding and activity scores of transcription factor motifs. |
15 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CellSpace
2 |
3 | CellSpace is a sequence-informed embedding method for scATAC-seq that learns a mapping of DNA *k*-mers and cells to the same space.
4 |
5 | See our [publication](https://doi.org/10.1038/s41592-024-02274-x) for more details.
6 |
7 |
8 |
9 | ## Installation and Usage
10 |
11 | 1. Compile the **C++** program to use as a command line tool to train a CellSpace model.
12 |
CellSpace, which uses the **C++** implementation of [StarSpace](https://github.com/facebookresearch/StarSpace) [[Wu *et al.*, 2017](https://doi.org/10.48550/arXiv.1709.03856)], builds on modern macOS and Linux distributions. It requires a compiler with **C++11** support and a working **make**.
14 |
Install the [Boost](https://www.boost.org) library and set the variable **BOOST_DIR** in the [makefile](cpp/makefile) to its path. The default path works if you install **Boost** as follows:
16 |
17 | ``` bash
wget https://archives.boost.io/release/1.63.0/source/boost_1_63_0.zip
unzip boost_1_63_0.zip
sudo mv boost_1_63_0 /usr/local/bin
21 | ```
22 |
23 | Download and build CellSpace:
24 |
25 | ``` bash
26 | git clone https://github.com/zakieh-tayyebi/CellSpace.git
27 | cd CellSpace/cpp/
28 | make
29 | export PATH=$(pwd):$PATH
30 | ```
31 |
32 | Verify that it was successfully compiled:
33 |
34 | ``` bash
35 | CellSpace --help
36 | ```
37 |
38 | 2. Install the **R** package to use the trained CellSpace model for downstream analysis.
39 |
40 | Run the following commands in **R**:
41 |
42 | ``` r
43 | install.packages("devtools")
44 | devtools::install_github("https://github.com/zakieh-tayyebi/CellSpace.git")
45 | library(CellSpace)
46 | ```
47 |
48 | Installation should take only a few minutes. For details about the **R** functions, please refer to the [API](R/README.md).
49 |
50 | 3. A tutorial on CellSpace usage can be found [here](tutorial/README.md).
51 |
52 | ## Citation
53 |
54 | Please cite our [Nature Methods paper](https://doi.org/10.1038/s41592-024-02274-x) if you use CellSpace:
55 |
56 | ```
57 | Tayyebi, Z., Pine, A.R. & Leslie, C.S. Scalable and unbiased sequence-informed embedding of single-cell ATAC-seq data with CellSpace. Nature Methods 21, 1014–1022 (2024). https://doi.org/10.1038/s41592-024-02274-x
58 | ```
59 |
60 | ## Contact
61 |
62 | - [zakieh.tayyebi\@gmail.com](mailto:zakieh.tayyebi@gmail.com) (Zakieh Tayyebi)
63 | - [lesliec\@mskcc.org](mailto:lesliec@mskcc.org) (Christina S. Leslie, PhD)
64 |
--------------------------------------------------------------------------------
/cpp/.gitignore:
--------------------------------------------------------------------------------
1 | CellSpace
2 | /CellSpace.dSYM/
3 | *.DS_Store
4 |
5 | # Prerequisites
6 | *.d
7 |
8 | # Compiled Object files
9 | *.slo
10 | *.lo
11 | *.o
12 | *.obj
13 |
14 | # Precompiled Headers
15 | *.gch
16 | *.pch
17 |
18 | # Compiled Dynamic libraries
19 | *.so
20 | *.dylib
21 | *.dll
22 |
23 | # Fortran module files
24 | *.mod
25 | *.smod
26 |
27 | # Compiled Static libraries
28 | *.lai
29 | *.la
30 | *.a
31 | *.lib
32 |
33 | # Executables
34 | *.exe
35 | *.out
36 | *.app
37 |
38 |
--------------------------------------------------------------------------------
/cpp/README.md:
--------------------------------------------------------------------------------
This code is an extension of [StarSpace](https://github.com/facebookresearch/StarSpace) (see the [MIT license](StarSpace-MIT-License.md)) that adapts the StarSpace algorithm (mode 0) to scATAC-seq data.
2 |
--------------------------------------------------------------------------------
/cpp/StarSpace-MIT-License.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Facebook, Inc. and its affiliates.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/cpp/makefile:
--------------------------------------------------------------------------------
# CellSpace v0.0.5

CXX = g++
CXXFLAGS = -pthread -std=gnu++11

BOOST_DIR = /usr/local/bin/boost_1_63_0/

OBJS = normalize.o dict.o args.o proj.o parser.o data.o model.o starspace.o utils.o
INCLUDES = -I$(BOOST_DIR)

# These targets are commands, not files: declaring them phony keeps a stray
# file with the same name from making them look up to date.
.PHONY: opt starspace clean

# Default goal: optimized build. Target-specific flags also apply to every
# prerequisite built on behalf of `opt`.
opt: CXXFLAGS += -O3 -funroll-loops
opt: starspace

normalize.o: src/utils/normalize.cpp src/utils/normalize.h
	$(CXX) $(CXXFLAGS) -g -c src/utils/normalize.cpp

dict.o: src/dict.cpp src/dict.h src/utils/args.h
	$(CXX) $(CXXFLAGS) $(INCLUDES) -g -c src/dict.cpp

args.o: src/utils/args.cpp src/utils/args.h
	$(CXX) $(CXXFLAGS) -g -c src/utils/args.cpp

model.o: data.o src/model.cpp src/model.h src/utils/args.h src/proj.h
	$(CXX) $(CXXFLAGS) $(INCLUDES) -g -c src/model.cpp

proj.o: src/proj.cpp src/proj.h src/matrix.h
	$(CXX) $(CXXFLAGS) $(INCLUDES) -g -c src/proj.cpp

data.o: parser.o src/data.cpp src/data.h
	$(CXX) $(CXXFLAGS) $(INCLUDES) -g -c src/data.cpp -o data.o

utils.o: src/utils/utils.cpp src/utils/utils.h
	$(CXX) $(CXXFLAGS) $(INCLUDES) -g -c src/utils/utils.cpp -o utils.o

parser.o: dict.o src/parser.cpp src/parser.h
	$(CXX) $(CXXFLAGS) $(INCLUDES) -g -c src/parser.cpp -o parser.o

starspace.o: src/starspace.cpp src/starspace.h
	$(CXX) $(CXXFLAGS) $(INCLUDES) -g -c src/starspace.cpp

# `make starspace` is kept as an alias; the artifact actually produced is the
# CellSpace binary.
starspace: CellSpace

# Relink only when an object file or the driver source changes.
CellSpace: $(OBJS) src/main.cpp src/starspace.h
	$(CXX) $(CXXFLAGS) $(OBJS) $(INCLUDES) -g src/main.cpp -o CellSpace

clean:
	rm -rf *.o CellSpace CellSpace.dSYM
46 |
--------------------------------------------------------------------------------
/cpp/src/data.cpp:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Facebook, Inc. and its affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 | #include "data.h"
9 | #include
10 | #include
11 | #include
12 | #include
13 | #include
14 |
15 | using namespace std;
16 |
17 | namespace starspace {
18 |
19 | InternDataHandler::InternDataHandler(shared_ptr args, std::shared_ptr dict, std::shared_ptr parser) {
20 | size_ = 0;
21 | idx_ = -1;
22 | examples_.clear();
23 | dict_ = dict;
24 | parser_ = parser;
25 | args_= args;
26 | }
27 |
28 | void InternDataHandler::errorOnZeroExample(const string& fileName) {
29 | std::cerr << "ERROR: File '" << fileName
30 | << "' does not contain any valid example.\n"
31 | << "Please check: is the file empty? "
32 | << "Do the examples contain proper feature and label according to the trainMode? "
33 | << "If your examples are unlabeled, try to set trainMode=5.\n";
34 | exit(EXIT_FAILURE);
35 | }
36 |
// Validates the Matrix Market banner (first line) of a cell-by-peak matrix.
// Returns true for "coordinate pattern general" matrices (entries have no
// value column) and false for "coordinate integer general" matrices (each
// entry has a value column). Any other banner prints an error and exits.
// The comparison ignores ASCII case, collapses runs of whitespace and ignores
// leading/trailing whitespace, so CRLF files ("...general\r") are accepted.
inline bool check_header(const std::string& header) {
  static const char* const kSpace = " \t\r\n\v\f";

  // Re-join the whitespace-separated tokens with single spaces.
  std::string normalized;
  std::string::size_type start = header.find_first_not_of(kSpace);
  while (start != std::string::npos) {
    const std::string::size_type stop = header.find_first_of(kSpace, start);
    if (!normalized.empty()) normalized += ' ';
    normalized.append(header, start,
                      stop == std::string::npos ? std::string::npos : stop - start);
    start = header.find_first_not_of(kSpace, stop);  // npos when stop is npos
  }
  for (char& ch : normalized) {
    if (ch >= 'A' && ch <= 'Z') ch = static_cast<char>(ch - 'A' + 'a');
  }

  if (normalized == "%%matrixmarket matrix coordinate pattern general") return true;
  if (normalized == "%%matrixmarket matrix coordinate integer general") return false;

  std::cerr << "Unsupported matrix format!" << std::endl;
  std::cerr << "Expected a 'coordinate pattern general' or 'coordinate integer general'"
            << " Matrix Market header, got: '" << header << "'" << std::endl;
  exit(EXIT_FAILURE);
}
45 |
46 | void InternDataHandler::loadFromFile() {
47 |
48 | examples_ = vector(args_->nPeaks_total);
49 | ngrams_ = vector>(args_->nPeaks_total);
50 | // hashes_ = vector>(args_->nPeaks_total);
51 |
52 | unsigned pfn = 0, pnt = 0, nPeaks_cur = 0;
53 | for(auto peaks: args_->peaks_list){
54 | auto nPeaks = args_->nPeaks_list[pfn++];
55 | cout << "Reading " << nPeaks << " peak sequences from \'" << peaks << "\'" << endl;
56 | ifstream pf(peaks);
57 | string line, seq;
58 | int pi = -1;
59 | while(getline(pf, line))
60 | if(line[0] == '>'){
61 | if(pi >= 0) add_kmer_tokens(nPeaks_cur + pi, seq);
62 | pi++; pnt++;
63 | seq = "";
64 | } else seq += line;
65 | if(pi >= 0 && seq != "") add_kmer_tokens(nPeaks_cur + pi, seq);
66 | pf.close();
67 |
68 | if(pi != nPeaks - 1){
69 | cerr << "The number of peak sequences should match the number of columns in the corresponding count matrix!" << endl;
70 | exit(EXIT_FAILURE);
71 | }
72 |
73 | nPeaks_cur += nPeaks;
74 | }
75 | assert(pnt == args_->nPeaks_total);
76 |
77 | unsigned cfn = 0, nCells_cur = 0; nPeaks_cur = 0;
78 | for(auto cp_matrix: args_->cp_matrix_list){
79 | cout << "Reading a " << args_->nCells_list[cfn] << " cell by " << args_->nPeaks_list[cfn]
80 | << " peak count matrix from \'" << cp_matrix << "\'" << endl;
81 | string str, header, mat_format = cp_matrix.substr(cp_matrix.find_last_of(".") + 1);
82 | bool bin;
83 | unsigned nCells, nPeaks, peak, cell;
84 | if(mat_format == "gz"){
85 | #ifdef COMPRESS_FILE
86 | ifstream ifs2(cp_matrix);
87 | if (!ifs2.good()) exit(EXIT_FAILURE);
88 | filtering_istream cf;
89 | cf.push(gzip_decompressor());
90 | cf.push(ifs2);
91 |
92 | getline(cf, header);
93 | bin = check_header(header);
94 |
95 | cf >> nCells >> nPeaks >> str;
96 | assert(nCells == args_->nCells_list[cfn] && nPeaks == args_->nPeaks_list[cfn]);
97 |
98 | if(bin)
99 | while(cf >> cell){
100 | cf >> peak;
101 | addCellLabel(nPeaks_cur + peak - 1, nCells_cur + cell);
102 | }
103 | else
104 | while(cf >> cell){
105 | cf >> peak >> str;
106 | addCellLabel(nPeaks_cur + peak - 1, nCells_cur + cell);
107 | }
108 |
109 | nCells_cur += nCells;
110 | nPeaks_cur += nPeaks;
111 |
112 | ifs2.close();
113 | #endif
114 | } else {
115 | ifstream cf(cp_matrix);
116 | if(!cf.good()) exit(EXIT_FAILURE);
117 |
118 | getline(cf, header);
119 | bin = check_header(header);
120 |
121 | cf >> nCells >> nPeaks >> str;
122 | assert(nCells == args_->nCells_list[cfn] && nPeaks == args_->nPeaks_list[cfn]);
123 |
124 | if(bin)
125 | while(cf >> cell){
126 | cf >> peak;
127 | addCellLabel(nPeaks_cur + peak - 1, nCells_cur + cell);
128 | }
129 | else
130 | while(cf >> cell){
131 | cf >> peak >> str;
132 | addCellLabel(nPeaks_cur + peak - 1, nCells_cur + cell);
133 | }
134 |
135 | nCells_cur += nCells;
136 | nPeaks_cur += nPeaks;
137 |
138 | cf.close();
139 | }
140 |
141 | for(unsigned pi = nPeaks_cur - args_->nPeaks_list[cfn]; pi < nPeaks_cur; pi++)
142 | examples_[pi].dataset = cfn;
143 |
144 | // if(args_->batchLabels)
145 | // for(unsigned pi = nPeaks_cur - args_->nPeaks_list[cfn]; pi < nPeaks_cur; pi++){
146 | // int32_t wid = dict_->getId(args_->label + "B" + std::to_string(cfn + 1));
147 | // assert(wid >= 0);
148 | // examples_[pi].RHSTokens.push_back(make_pair(wid, 1.0));
149 | // }
150 |
151 | cfn ++;
152 | }
153 |
154 | assert(nCells_cur == args_->nCells_total && nPeaks_cur == args_->nPeaks_total);
155 |
156 | size_ = args_->nPeaks_total * args_->exmpPerPeak;
157 | cout << "Number of datasets: " << args_->nCells_list.size() << endl
158 | << "Total number of cells: " << args_->nCells_total << endl
159 | << "Total number of peaks: " << args_->nPeaks_total << endl
160 | << "Number of training examples: " << size_ << " (" << args_->exmpPerPeak << " per peak)"
161 | << "\n-----------------\n" << endl;
162 |
163 | // for(auto ex: examples_){
164 | // for(auto word: ex.RHSTokens) cerr << dict_->getSymbol(word.first) << " ";
165 | // cerr << ": ";
166 | // for(auto word: ex.LHSTokens) cerr << dict_->getSymbol(word.first) << " ";
167 | // cerr << ex.LHSTokens.size();
168 | // cerr << endl;
169 | // }
170 |
171 | }
172 |
// Offset, within a peak's n-gram list, of the first n-gram starting at token
// position i, for a sequence of L tokens and n-grams of length 2..n.
// Position p contributes c(p) = max(0, min(n - 1, L - 1 - p)) entries (the last
// token contributes none), so the result is c(0) + ... + c(i - 1).
// Positions i <= 0 give 0 and positions past L - 1 add nothing. The
// contribution is clamped at 0 so short sequences (L < n - 1) never yield a
// negative offset.
int first_ngram_idx(int L, int i, int n) {
  if (i <= 0 || L <= 1 || n <= 1) return 0;
  if (i > L - 1) i = L - 1;                     // positions >= L - 1 contribute 0
  const int full = std::max(0, L - n + 1);      // positions 0 .. full-1 contribute n - 1
  if (i <= full) return i * (n - 1);
  // Positions full .. i-1 contribute L-1-full, ..., L-i (an arithmetic series).
  const int terms = i - full;
  return full * (n - 1) + terms * ((L - 1 - full) + (L - i)) / 2;
}
183 |
184 | // Convert an example for training/testing if needed.
185 | // In the case of trainMode=1, a random label from r.h.s will be selected
186 | // as label, and the rest of labels from r.h.s. will be input features
187 | void InternDataHandler::convert(
188 | const ParseResults& example,
189 | // const std::vector& hashes,
190 | const std::vector& ngrams,
191 | ParseResults& rslt) const {
192 |
193 | #define LHS_SIZE (example.LHSTokens.size())
194 | int num_tokens = (args_->sampleLen == -1) ? LHS_SIZE : (args_->sampleLen - args_->k + 1);
195 | if(num_tokens > LHS_SIZE) num_tokens = LHS_SIZE;
196 | int first = (num_tokens == LHS_SIZE) ? 0 : (rand() % (LHS_SIZE - num_tokens + 1));
197 |
198 | rslt.weight = example.weight;
199 | rslt.dataset = example.dataset;
200 | rslt.LHSTokens.clear();
201 | rslt.RHSTokens.clear();
202 | auto first_token = example.LHSTokens.begin() + first;
203 | rslt.LHSTokens.insert(rslt.LHSTokens.end(), first_token, first_token + num_tokens);
204 |
205 | if(args_->ngrams > 1){
206 | int first_ngram = first_ngram_idx(LHS_SIZE, first, args_->ngrams),
207 | last_ngram = first_ngram_idx(LHS_SIZE, first + num_tokens - 2, args_->ngrams);
208 | if(last_ngram >= first_ngram){
209 | rslt.LHSTokens.insert(rslt.LHSTokens.end(), ngrams.begin() + first_ngram, ngrams.begin() + last_ngram);
210 | }
211 | // std::cerr << num_tokens << "\t" << last_ngram - first_ngram + 1 << std::endl;
212 | }
213 |
214 | if (args_->trainMode == 0) {
215 | // lhs is the same, pick one random label as rhs
216 | assert(example.LHSTokens.size() > 0);
217 | assert(example.RHSTokens.size() > 0);
218 | auto idx = rand() % example.RHSTokens.size();
219 | rslt.RHSTokens.push_back(example.RHSTokens[idx]);
220 | } else {
221 | assert(example.RHSTokens.size() > 1);
222 | if (args_->trainMode == 1) {
223 | // pick one random label as rhs and the rest is lhs
224 | auto idx = rand() % example.RHSTokens.size();
225 | for (unsigned int i = 0; i < example.RHSTokens.size(); i++) {
226 | auto tok = example.RHSTokens[i];
227 | if (i == idx) {
228 | rslt.RHSTokens.push_back(tok);
229 | } else {
230 | rslt.LHSTokens.push_back(tok);
231 | }
232 | }
233 | } else
234 | if (args_->trainMode == 2) {
235 | // pick one random label as lhs and the rest is rhs
236 | auto idx = rand() % example.RHSTokens.size();
237 | for (unsigned int i = 0; i < example.RHSTokens.size(); i++) {
238 | auto tok = example.RHSTokens[i];
239 | if (i == idx) {
240 | rslt.LHSTokens.push_back(tok);
241 | } else {
242 | rslt.RHSTokens.push_back(tok);
243 | }
244 | }
245 | } else
246 | if (args_->trainMode == 3) {
247 | // pick two random labels, one as lhs and the other as rhs
248 | auto idx = rand() % example.RHSTokens.size();
249 | unsigned int idx2;
250 | do {
251 | idx2 = rand() % example.RHSTokens.size();
252 | } while (idx2 == idx);
253 | rslt.LHSTokens.push_back(example.RHSTokens[idx]);
254 | rslt.RHSTokens.push_back(example.RHSTokens[idx2]);
255 | } else
256 | if (args_->trainMode == 4) {
257 | // the first one as lhs and the second one as rhs
258 | rslt.LHSTokens.push_back(example.RHSTokens[0]);
259 | rslt.RHSTokens.push_back(example.RHSTokens[1]);
260 | }
261 | }
262 | }
263 |
264 | void InternDataHandler::getWordExamples(
265 | const vector& doc,
266 | vector& rslts) const {
267 |
268 | rslts.clear();
269 | for (int widx = 0; widx < (int)(doc.size()); widx++) {
270 | ParseResults rslt;
271 | rslt.LHSTokens.clear();
272 | rslt.RHSTokens.clear();
273 | rslt.RHSTokens.push_back(doc[widx]);
274 | for (unsigned int i = max(widx - args_->ws, 0);
275 | i < min(size_t(widx + args_->ws), doc.size()); i++) {
276 | if ((int)i != widx) {
277 | rslt.LHSTokens.push_back(doc[i]);
278 | }
279 | }
280 | rslt.weight = args_->wordWeight;
281 | rslts.emplace_back(rslt);
282 | }
283 | }
284 |
285 | void InternDataHandler::getWordExamples(
286 | int idx,
287 | vector& rslts) const {
288 | assert(idx >= 0 && idx < size_);
289 | idx = idx % args_->nPeaks_total;
290 | const auto& example = examples_[idx];
291 | getWordExamples(example.LHSTokens, rslts);
292 | }
293 |
294 | void InternDataHandler::addExample(const ParseResults& example) {
295 | examples_.push_back(example);
296 | size_++;
297 | }
298 |
299 | void InternDataHandler::getExampleById(int32_t idx, ParseResults& rslt) const {
300 | assert(idx >= 0 && idx < size_);
301 | idx = idx % args_->nPeaks_total;
302 | convert(examples_[idx], ngrams_[idx], rslt);
303 | }
304 |
305 | void InternDataHandler::getNextExample(ParseResults& rslt) {
306 | // assert(args_->nPeaks_total > 0);
307 | idx_ = idx_ + 1;
308 | // go back to the beginning of the examples if we reach the end
309 | if (idx_ >= args_->nPeaks_total) {
310 | idx_ = idx_ - args_->nPeaks_total;
311 | }
312 | convert(examples_[idx_], ngrams_[idx_], rslt);
313 | }
314 |
315 | void InternDataHandler::getRandomExample(ParseResults& rslt) const {
316 | // assert(args_->nPeaks_total > 0);
317 | int32_t idx = rand() % args_->nPeaks_total;
318 | convert(examples_[idx], ngrams_[idx], rslt);
319 | }
320 |
321 | void InternDataHandler::getKRandomExamples(int K, vector& c) {
322 | auto kSamples = min(K, (int)args_->nPeaks_total);
323 | for (int i = 0; i < kSamples; i++) {
324 | ParseResults example;
325 | getRandomExample(example);
326 | c.push_back(example);
327 | }
328 | }
329 |
330 | void InternDataHandler::getNextKExamples(int K, vector& c) {
331 | auto kSamples = min(K, (int)args_->nPeaks_total);
332 | for (int i = 0; i < kSamples; i++) {
333 | idx_ = (idx_ + 1) % args_->nPeaks_total;
334 | ParseResults example;
335 | convert(examples_[idx_], ngrams_[idx_], example);
336 | c.push_back(example);
337 | }
338 | }
339 |
340 | void InternDataHandler::getRandomWord(vector& result) {
341 | result.push_back(word_negatives_[word_iter_]);
342 | word_iter_++;
343 | if (word_iter_ >= (int)word_negatives_.size()) {
344 | word_iter_ = 0;
345 | }
346 | }
347 |
348 | void InternDataHandler::initWordNegatives() {
349 | word_iter_ = 0;
350 | word_negatives_.clear();
351 | // assert(args_->nPeaks_total > 0);
352 | for (int i = 0; i < MAX_WORD_NEGATIVES_SIZE; i++) {
353 | word_negatives_.emplace_back(genRandomWord());
354 | }
355 | }
356 |
357 | Base InternDataHandler::genRandomWord() const {
358 | // assert(args_->nPeaks_total > 0);
359 | auto& ex = examples_[rand() % args_->nPeaks_total];
360 | int r = rand() % ex.LHSTokens.size();
361 | return ex.LHSTokens[r];
362 | }
363 |
364 | // Randomly sample one example and randomly sample a label from this example
365 | // The result is usually used as negative samples in training
366 | void InternDataHandler::getRandomRHS(vector& results, unsigned dataset) const {
367 | // assert(args_->nPeaks_total > 0);
368 | results.clear();
369 |
370 | unsigned pi = args_->first_peak_idx[dataset] + (rand() % args_->nPeaks_list[dataset]);
371 | // std::cerr << args_->first_peak_idx[dataset] << " " << pi << " - ";
372 | auto& ex = examples_[pi];
373 | unsigned int r = rand() % ex.RHSTokens.size();
374 | if (args_->trainMode == 2) {
375 | for (unsigned int i = 0; i < ex.RHSTokens.size(); i++) {
376 | if (i != r) {
377 | results.push_back(ex.RHSTokens[i]);
378 | }
379 | }
380 | } else {
381 | results.push_back(ex.RHSTokens[r]);
382 | }
383 | }
384 |
385 | void InternDataHandler::save(std::ostream& out) {
386 | out << "data size : " << args_->nPeaks_total << endl;
387 | for (auto& example : examples_) {
388 | out << "lhs : ";
389 | for (auto t : example.LHSTokens) {out << t.first << ':' << t.second << ' ';}
390 | out << endl;
391 | out << "rhs : ";
392 | for (auto t : example.RHSTokens) {out << t.first << ':' << t.second << ' ';}
393 | out << endl;
394 | }
395 | }
396 |
397 | } // unamespace starspace
398 |
--------------------------------------------------------------------------------
/cpp/src/data.h:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Facebook, Inc. and its affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 | #pragma once
9 |
10 | #include "dict.h"
11 | #include "parser.h"
12 | #include "utils/utils.h"
13 | #include
14 | #include
15 | #include
16 |
17 | namespace starspace {
18 |
19 | class InternDataHandler {
20 | public:
21 | explicit InternDataHandler(std::shared_ptr args, std::shared_ptr dict, std::shared_ptr parser);
22 |
23 | virtual void loadFromFile();
24 |
25 | virtual void convert(const ParseResults& example, const std::vector& ngrams, ParseResults& rslt) const;
26 |
27 | virtual void getRandomRHS(std::vector& results, unsigned dataset) const;
28 |
29 | virtual void save(std::ostream& out);
30 |
31 | virtual void getWordExamples(int idx, std::vector& rslt) const;
32 |
33 | void add_kmer_tokens(int pi, std::string seq){
34 | std::vector tokens;
35 | for(unsigned i = 0; i < seq.length() - args_->k + 1; i++){
36 | std::string kmer = "", rc_kmer = "";
37 | bool N = false;
38 | for(unsigned j = i; (j < i + args_->k) && !N; j++){
39 | switch(seq[j]){
40 | case 'A':
41 | case 'a':
42 | kmer = kmer + "A";
43 | rc_kmer = "T" + rc_kmer;
44 | break;
45 | case 'C':
46 | case 'c':
47 | kmer = kmer + "C";
48 | rc_kmer = "G" + rc_kmer;
49 | break;
50 | case 'T':
51 | case 't':
52 | kmer = kmer + "T";
53 | rc_kmer = "A" + rc_kmer;
54 | break;
55 | case 'G':
56 | case 'g':
57 | kmer = kmer + "G";
58 | rc_kmer = "C" + rc_kmer;
59 | break;
60 | default:
61 | N = true;
62 | i = j;
63 | }
64 | }
65 | if(!N){
66 | int32_t kmer_wid = dict_->getId(kmer), rc_kmer_wid = dict_->getId(rc_kmer);
67 | if(!(kmer_wid < 0)){
68 | tokens.push_back(kmer);
69 | examples_[pi].LHSTokens.push_back(std::make_pair(kmer_wid, 1.0));
70 | } else if(!(rc_kmer_wid < 0)){
71 | tokens.push_back(rc_kmer);
72 | examples_[pi].LHSTokens.push_back(std::make_pair(rc_kmer_wid, 1.0));
73 | } else {
74 | std::cerr << "Invalid DNA k-mer! \'" << kmer << "\'" << std::endl;
75 | continue;
76 | }
77 | }
78 | }
79 |
80 | assert(examples_[pi].LHSTokens.size() > 0);
81 |
82 | if(args_->ngrams > 1){
83 | parser_->addNgrams(tokens, ngrams_[pi], args_->ngrams);
84 | // for(auto token: tokens) hashes_[pi].push_back(dict_->hash(token));
85 | }
86 | }
87 |
88 | void getWordExamples(
89 | const std::vector& doc,
90 | std::vector& rslt) const;
91 |
92 | void addExample(const ParseResults& example);
93 |
94 | void addCellLabel(unsigned peak, unsigned cell){
95 | int32_t wid = dict_->getId(args_->label + "C" + std::to_string(cell));
96 | assert(wid >= 0);
97 | examples_[peak].RHSTokens.push_back(std::make_pair(wid, 1.0));
98 | }
99 |
100 | void getExampleById(int32_t idx, ParseResults& rslt) const;
101 |
102 | void getNextExample(ParseResults& rslt);
103 |
104 | void getRandomExample(ParseResults& rslt) const;
105 |
106 | void getKRandomExamples(int K, std::vector& c);
107 |
108 | void getNextKExamples(int K, std::vector& c);
109 |
110 | size_t getSize() const { return size_; };
111 |
112 | void errorOnZeroExample(const std::string& fileName);
113 |
114 | void initWordNegatives();
115 | void getRandomWord(std::vector& result);
116 |
117 |
118 | protected:
119 | virtual Base genRandomWord() const;
120 |
121 | static const int32_t MAX_VOCAB_SIZE = 10000000;
122 | static const int32_t MAX_WORD_NEGATIVES_SIZE = 10000000;
123 |
124 | std::shared_ptr args_;
125 | std::shared_ptr dict_;
126 | std::shared_ptr parser_;
127 | std::vector examples_;
128 | std::vector> ngrams_;
129 | // std::vector> hashes_;
130 |
131 | int32_t idx_ = -1;
132 | int32_t size_ = 0;
133 |
134 | int32_t word_iter_;
135 | std::vector word_negatives_;
136 | };
137 |
138 | }
139 |
--------------------------------------------------------------------------------
/cpp/src/dict.cpp:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Facebook, Inc. and its affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 | #include "dict.h"
9 | #include "parser.h"
10 |
11 | #include
12 | #include
13 | #include
14 | #include
15 | #include
16 | #include
17 |
18 | using namespace std;
19 | using namespace boost::iostreams;
20 |
21 | namespace starspace {
22 |
23 | const std::string Dictionary::EOS = "";
24 | const uint32_t Dictionary::HASH_C = 116049371;
25 |
26 | Dictionary::Dictionary(shared_ptr args) : args_(args),
27 | hashToIndex_(MAX_VOCAB_SIZE, -1), size_(0), nwords_(0), nlabels_(0),
28 | ntokens_(0)
29 | {
30 | entryList_.clear();
31 | }
32 |
33 | // hash trick from fastText
34 | uint32_t Dictionary::hash(const std::string& str) const {
35 | uint32_t h = 2166136261;
36 | for (size_t i = 0; i < str.size(); i++) {
37 | h = h ^ uint32_t(str[i]);
38 | h = h * 16777619;
39 | }
40 | return h;
41 | }
42 |
43 | int32_t Dictionary::find(const std::string& w) const {
44 | int32_t h = hash(w) % MAX_VOCAB_SIZE;
45 | while (hashToIndex_[h] != -1 && entryList_[hashToIndex_[h]].symbol != w) {
46 | h = (h + 1) % MAX_VOCAB_SIZE;
47 | }
48 | return h;
49 | }
50 |
51 | int32_t Dictionary::getId(const string& symbol) const {
52 | int32_t h = find(symbol);
53 | return hashToIndex_[h];
54 | }
55 |
56 | const std::string& Dictionary::getSymbol(int32_t id) const {
57 | assert(id >= 0);
58 | assert(id < size_);
59 | return entryList_[id].symbol;
60 | }
61 |
62 | const std::string& Dictionary::getLabel(int32_t lid) const {
63 | assert(lid >= 0);
64 | assert(lid < nlabels_);
65 | return entryList_[lid + nwords_].symbol;
66 | }
67 |
68 | entry_type Dictionary::getType(int32_t id) const {
69 | assert(id >= 0);
70 | assert(id < size_);
71 | return entryList_[id].type;
72 | }
73 |
74 | entry_type Dictionary::getType(const string& w) const {
75 | return (w.find(args_->label) == 0)? entry_type::label : entry_type::word;
76 | }
77 |
78 | void Dictionary::insert(const string& symbol) {
79 | int32_t h = find(symbol);
80 | ntokens_++;
81 | if (hashToIndex_[h] == -1) {
82 | entry e;
83 | e.symbol = symbol;
84 | e.count = 1;
85 | e.type = getType(symbol);
86 | entryList_.push_back(e);
87 | hashToIndex_[h] = size_++;
88 | } else {
89 | entryList_[hashToIndex_[h]].count++;
90 | }
91 | }
92 |
93 | void Dictionary::save(std::ostream& out) const {
94 | out.write((char*) &size_, sizeof(int32_t));
95 | out.write((char*) &nwords_, sizeof(int32_t));
96 | out.write((char*) &nlabels_, sizeof(int32_t));
97 | out.write((char*) &ntokens_, sizeof(int64_t));
98 | for (int32_t i = 0; i < size_; i++) {
99 | entry e = entryList_[i];
100 | out.write(e.symbol.data(), e.symbol.size() * sizeof(char));
101 | out.put(0);
102 | out.write((char*) &(e.count), sizeof(int64_t));
103 | out.write((char*) &(e.type), sizeof(entry_type));
104 | }
105 | }
106 |
107 | void Dictionary::load(std::istream& in) {
108 | entryList_.clear();
109 | std::fill(hashToIndex_.begin(), hashToIndex_.end(), -1);
110 | in.read((char*) &size_, sizeof(int32_t));
111 | in.read((char*) &nwords_, sizeof(int32_t));
112 | in.read((char*) &nlabels_, sizeof(int32_t));
113 | in.read((char*) &ntokens_, sizeof(int64_t));
114 | for (int32_t i = 0; i < size_; i++) {
115 | char c;
116 | entry e;
117 | while ((c = in.get()) != 0) {
118 | e.symbol.push_back(c);
119 | }
120 | in.read((char*) &e.count, sizeof(int64_t));
121 | in.read((char*) &e.type, sizeof(entry_type));
122 | entryList_.push_back(e);
123 | hashToIndex_[find(e.symbol)] = i;
124 | }
125 | }
126 |
127 | /* Build dictionary from file.
128 | * In dictionary building process, if the current dictionary is at 75% capacity,
129 | * it automatically increases the threshold for both word and label.
130 | * At the end the -minCount and -minCountLabel from arguments will be applied
131 | * as thresholds.
132 | */
133 | void Dictionary::readFromFile(
134 | const std::string& file,
135 | shared_ptr parser) {
136 |
137 | int64_t minThreshold = 1;
138 | size_t lines_read = 0;
139 |
140 | auto readFromInputStream = [&](std::istream& in) {
141 | string line;
142 | while (getline(in, line, '\n')) {
143 | vector tokens;
144 | parser->parseForDict(line, tokens);
145 | lines_read++;
146 | for (auto token : tokens) {
147 | insert(token);
148 | if ((ntokens_ % 1000000 == 0) && args_->verbose) {
149 | std::cerr << "\rRead " << ntokens_ / 1000000 << "M words" << std::flush;
150 | }
151 | if (size_ > 0.75 * MAX_VOCAB_SIZE) {
152 | minThreshold++;
153 | threshold(minThreshold, minThreshold);
154 | }
155 | }
156 | }
157 | };
158 |
159 | #ifdef COMPRESS_FILE
160 | if (args_->compressFile == "gzip") {
161 | cout << "Build dict from compressed input file.\n";
162 | for (int i = 0; i < args_->numGzFile; i++) {
163 | filtering_istream in;
164 | auto str_idx = boost::str(boost::format("%02d") % i);
165 | auto fname = file + str_idx + ".gz";
166 | ifstream ifs(fname);
167 | if (!ifs.good()) {
168 | continue;
169 | }
170 | in.push(gzip_decompressor());
171 | in.push(ifs);
172 | readFromInputStream(in);
173 | ifs.close();
174 | }
175 | } else {
176 | cout << "Build dict from input file : " << file << endl;
177 | ifstream fin(file);
178 | if (!fin.is_open()) {
179 | cerr << "Input file cannot be opened!" << endl;
180 | exit(EXIT_FAILURE);
181 | }
182 | readFromInputStream(fin);
183 | fin.close();
184 | }
185 | #else
186 | cout << "Build dict from input file : " << file << endl;
187 | ifstream fin(file);
188 | if (!fin.is_open()) {
189 | cerr << "Input file cannot be opened!" << endl;
190 | exit(EXIT_FAILURE);
191 | }
192 | readFromInputStream(fin);
193 | fin.close();
194 | #endif
195 |
196 | threshold(args_->minCount, args_->minCountLabel);
197 |
198 | std::cerr << "\rRead " << ntokens_ / 1000000 << "M words" << std::endl;
199 | std::cerr << "Number of words (" << args_->k << "-mers) in dictionary: " << nwords_ << std::endl;
200 | std::cerr << "Number of labels in dictionary: " << nlabels_ << std::endl;
201 | if (lines_read == 0) {
202 | std::cerr << "ERROR: Empty file." << std::endl;
203 | exit(EXIT_FAILURE);
204 | }
205 | if (size_ == 0) {
206 | std::cerr << "Empty vocabulary. Try a smaller -minCount value."
207 | << std::endl;
208 | exit(EXIT_FAILURE);
209 | }
210 | }
211 |
212 | void Dictionary::addKmers(unsigned k, int64_t &minThreshold, unsigned len = 0,
213 | string seq = "", string rc_seq = ""){
214 | if(len == k){
215 | int32_t h = find(rc_seq);
216 | if(hashToIndex_[h] == -1){
217 | insert(seq);
218 | if(size_ > 0.75 * MAX_VOCAB_SIZE){
219 | minThreshold++;
220 | threshold(minThreshold, minThreshold);
221 | }
222 | }
223 | } else {
224 | addKmers(k, minThreshold, len + 1, seq + "A", "T" + rc_seq);
225 | addKmers(k, minThreshold, len + 1, seq + "C", "G" + rc_seq);
226 | addKmers(k, minThreshold, len + 1, seq + "T", "A" + rc_seq);
227 | addKmers(k, minThreshold, len + 1, seq + "G", "C" + rc_seq);
228 | }
229 | }
230 |
231 | void Dictionary::addCells(unsigned nCells, int64_t &minThreshold, string label){
232 | for (unsigned ci = 0; ci < nCells; ci ++) {
233 | insert(label + "C" + to_string(ci + 1));
234 | if (size_ > 0.75 * MAX_VOCAB_SIZE) {
235 | minThreshold++;
236 | threshold(minThreshold, minThreshold);
237 | }
238 | }
239 | }
240 |
241 | void Dictionary::addBatches(unsigned nBatches, int64_t &minThreshold, string label){
242 | for (unsigned bi = 0; bi < nBatches; bi ++) {
243 | insert(label + "B" + to_string(bi + 1));
244 | if (size_ > 0.75 * MAX_VOCAB_SIZE) {
245 | minThreshold++;
246 | threshold(minThreshold, minThreshold);
247 | }
248 | }
249 | }
250 |
251 | void Dictionary::CreateForCellSpace(shared_ptr parser){
252 | int64_t minThreshold = 1;
253 | addKmers(args_->k, minThreshold); // Add k-mer IDs as words
254 | addCells(args_->nCells_total, minThreshold, args_->label); // Add cell IDs as labels
255 | // addPeaks(args_->nPeaks_total, minThreshold, args_->label); // Add peak IDs as labels
256 | // if(args_->batchLabels) addBatches(args_->nBatches, minThreshold, args_->label); // Add batch IDs as labels
257 |
258 | threshold(args_->minCount, args_->minCountLabel);
259 |
260 | // std::cerr << "\rRead " << ntokens_ / 1000000 << "M words" << std::endl;
261 | std::cout << "Number of words (" << args_->k << "-mers) in dictionary: " << nwords_ << std::endl
262 | << "Number of labels in dictionary: " << nlabels_ << std::endl;
263 | }
264 |
265 | // Sort the dictionary by [word, label] order and by number of occurance.
266 | // Removes word / label that does not pass respective threshold.
267 | void Dictionary::threshold(int64_t t, int64_t tl) {
268 | sort(entryList_.begin(), entryList_.end(), [](const entry& e1, const entry& e2) {
269 | if (e1.type != e2.type) return e1.type < e2.type;
270 | return e1.count > e2.count;
271 | });
272 | entryList_.erase(remove_if(entryList_.begin(), entryList_.end(), [&](const entry& e) {
273 | return (e.type == entry_type::word && e.count < t) ||
274 | (e.type == entry_type::label && e.count < tl);
275 | }), entryList_.end());
276 |
277 | entryList_.shrink_to_fit();
278 |
279 | computeCounts();
280 | }
281 |
282 | void Dictionary::computeCounts() {
283 | size_ = 0;
284 | nwords_ = 0;
285 | nlabels_ = 0;
286 | std::fill(hashToIndex_.begin(), hashToIndex_.end(), -1);
287 | for (auto it = entryList_.begin(); it != entryList_.end(); ++it) {
288 | int32_t h = find(it->symbol);
289 | hashToIndex_[h] = size_++;
290 | if (it->type == entry_type::word) nwords_++;
291 | if (it->type == entry_type::label) nlabels_++;
292 | }
293 | }
294 |
295 | // Given a model saved in .tsv format, build the dictionary from model.
296 | void Dictionary::loadDictFromModel(const string& modelfile) {
297 | cout << "Loading dict from model file : " << modelfile << endl;
298 | ifstream fin(modelfile);
299 | string line;
300 | while (getline(fin, line)) {
301 | string symbol;
302 | stringstream ss(line);
303 | ss >> symbol;
304 | insert(symbol);
305 | }
306 | fin.close();
307 | computeCounts();
308 |
309 | std::cout << "Number of words in dictionary: " << nwords_ << std::endl;
310 | std::cout << "Number of labels in dictionary: " << nlabels_ << std::endl;
311 | }
312 |
313 | } // namespace
314 |
--------------------------------------------------------------------------------
/cpp/src/dict.h:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Facebook, Inc. and its affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 | /**
9 | * The implementation of dictionary here is very similar to the dictionary used
10 | * in fastText (https://github.com/facebookresearch/fastText).
11 | */
12 |
13 | #pragma once
14 |
15 | #include "utils/args.h"
16 | #include
17 | #include
18 | #include
19 | #include
20 | #include
21 | #include
22 | #include
23 | #include
24 |
25 | #ifdef COMPRESS_FILE
26 | #include
27 | #include
28 | #endif
29 |
30 | namespace starspace {
31 |
32 | class DataParser;
33 |
// Kind of dictionary entry. Dictionary::save writes this enum's raw bytes
// (sizeof(entry_type), i.e. the int8_t underlying type), so the underlying
// type and values must not change.
enum class entry_type : int8_t {
  word = 0,
  label = 1
};

// One dictionary record: the symbol text, how often it was inserted, and
// whether it is a word or a label.
struct entry {
  std::string symbol;
  int64_t count;
  entry_type type;
};
41 |
42 | class Dictionary {
43 | public:
44 | static const std::string EOS;
45 | static const uint32_t HASH_C;
46 |
47 | explicit Dictionary(std::shared_ptr);
48 | int32_t size() const { return size_; };
49 | int32_t nwords() const { return nwords_; };
50 | int32_t nlabels() const { return nlabels_; };
51 | int32_t ntokens() const { return ntokens_; };
52 | int32_t getId(const std::string&) const;
53 | entry_type getType(int32_t) const;
54 | entry_type getType(const std::string&) const;
55 | const std::string& getSymbol(int32_t) const;
56 | const std::string& getLabel(int32_t) const;
57 |
58 | uint32_t hash(const std::string& str) const;
59 | void insert(const std::string&);
60 |
61 | void load(std::istream&);
62 | void save(std::ostream&) const;
63 | void readFromFile(const std::string&, std::shared_ptr);
64 | bool readWord(std::istream&, std::string&) const;
65 |
66 | void threshold(int64_t, int64_t);
67 | void computeCounts();
68 | void loadDictFromModel(const std::string& model);
69 |
70 | void addKmers(unsigned, int64_t&, unsigned, std::string, std::string);
71 | void addCells(unsigned, int64_t&, std::string);
72 | void addBatches(unsigned, int64_t&, std::string);
73 | void CreateForCellSpace(std::shared_ptr);
74 |
75 | private:
76 | static const int32_t MAX_VOCAB_SIZE = 30000000;
77 |
78 | int32_t find(const std::string&) const;
79 |
80 | void addNgrams(
81 | std::vector& line,
82 | const std::vector& hashes,
83 | int32_t n) const;
84 |
85 | std::shared_ptr args_;
86 | std::vector entryList_;
87 | std::vector hashToIndex_;
88 |
89 | int32_t size_;
90 | int32_t nwords_;
91 | int32_t nlabels_;
92 | int64_t ntokens_;
93 | };
94 |
95 | }
96 |
--------------------------------------------------------------------------------
/cpp/src/main.cpp:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Facebook, Inc. and its affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 | #include "starspace.h"
9 | #include
10 | #include
11 |
12 | using namespace std;
13 | using namespace starspace;
14 |
15 | int main(int argc, char** argv) {
16 | shared_ptr args = make_shared();
17 | args->parseArgs(argc, argv);
18 | args->printArgs();
19 |
20 | StarSpace sp(args);
21 | if (args->isTrain) {
22 | // if (!args->initModel.empty()) {
23 | // if (boost::algorithm::ends_with(args->initModel, ".tsv")) {
24 | // sp.initFromTsv(args->initModel);
25 | // } else {
26 | // sp.initFromSavedModel(args->initModel);
27 | // cout << "------Loaded model args:\n";
28 | // args->printArgs();
29 | // }
30 | // } else {
31 | sp.init();
32 | // }
33 | sp.train();
34 | // sp.saveModel(args->model);
35 | sp.saveModelTsv(args->model + ".tsv");
36 | }
37 | // else {
38 | // if (boost::algorithm::ends_with(args->model, ".tsv")) {
39 | // sp.initFromTsv(args->model);
40 | // } else {
41 | // sp.initFromSavedModel(args->model);
42 | // cout << "------Loaded model args:\n";
43 | // args->printArgs();
44 | // }
45 | // sp.evaluate();
46 | // }
47 |
48 | return 0;
49 | }
50 |
--------------------------------------------------------------------------------
/cpp/src/matrix.h:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Facebook, Inc. and its affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 | /**
9 | * Mostly a collection of convenience routines around ublas.
10 | * We avoid doing any actual compute-intensive work in this file.
11 | */
12 |
13 | #pragma once
14 |
15 | #include
16 | #include
17 | #include
18 | #include
19 | #include
20 | #include
21 | #include
22 |
23 | #include
24 | #include
25 | #include
26 |
27 | namespace starspace {
28 |
/**
 * Row/column extents of a dense matrix.
 */
struct MatrixDims {
  size_t r, c;

  /// Total number of cells (r * c).
  size_t numElts() const { return r * c; }

  // const-qualified so dimensions held by const reference / const value can
  // be compared (the previous non-const operator rejected a const lhs).
  bool operator==(const MatrixDims& rhs) const {
    return r == rhs.r && c == rhs.c;
  }

  bool operator!=(const MatrixDims& rhs) const {
    return !(*this == rhs);
  }
};
36 |
37 | template
38 | struct Matrix {
39 | static const int kAlign = 64;
40 | boost::numeric::ublas::matrix matrix;
41 |
42 | explicit Matrix(MatrixDims dims,
43 | Real sd = 1.0) :
44 | matrix(dims.r, dims.c)
45 | {
46 | assert(matrix.size1() == dims.r);
47 | assert(matrix.size2() == dims.c);
48 | if (sd > 0.0) {
49 | randomInit(sd);
50 | }
51 | }
52 |
53 | explicit Matrix(const std::vector>& init) {
54 | size_t rows = init.size();
55 | size_t maxCols = 0;
56 | for (const auto& r : init) {
57 | maxCols = std::max(maxCols, r.size());
58 | }
59 | alloc(rows, maxCols);
60 | for (size_t i = 0; i < numRows(); i++) {
61 | size_t j;
62 | for (j = 0; j < init[i].size(); j++) {
63 | (*this)[i][j] = init[i][j];
64 | }
65 | for (; j < numCols(); j++) {
66 | (*this)[i][j] = 0.0;
67 | }
68 | }
69 | }
70 |
71 | explicit Matrix(std::istream& in) {
72 | in >> matrix;
73 | }
74 |
75 | Matrix() {
76 | alloc(0, 0);
77 | }
78 |
79 | Real* operator[](size_t i) {
80 | assert(i >= 0);
81 | assert(i < numRows());
82 | return &matrix(i, 0);
83 | }
84 |
85 | const Real* operator[](size_t i) const {
86 | assert(i >= 0);
87 | assert(i < numRows());
88 | return &matrix(i, 0);
89 | }
90 |
91 | Real& cell(size_t i, size_t j) {
92 | assert(i >= 0);
93 | assert(i < numRows());
94 | assert(j < numCols());
95 | assert(j >= 0);
96 | return matrix(i, j);
97 | }
98 |
99 | void add(const Matrix& rhs, Real scale = 1.0) {
100 | matrix += scale * rhs.matrix;
101 | }
102 |
103 | void forEachCell(std::function l) {
104 | for (size_t i = 0; i < numRows(); i++)
105 | for (size_t j = 0; j < numCols(); j++)
106 | l(matrix(i, j));
107 | }
108 |
109 | void forEachCell(std::function l) const {
110 | for (size_t i = 0; i < numRows(); i++)
111 | for (size_t j = 0; j < numCols(); j++)
112 | l(matrix(i, j));
113 | }
114 |
115 | void forEachCell(std::function l) {
116 | for (size_t i = 0; i < numRows(); i++)
117 | for (size_t j = 0; j < numCols(); j++)
118 | l(matrix(i, j), i, j);
119 | }
120 |
121 | void forEachCell(std::function l) const {
122 | for (size_t i = 0; i < numRows(); i++)
123 | for (size_t j = 0; j < numCols(); j++)
124 | l(matrix(i, j), i, j);
125 | }
126 |
127 | void sanityCheck() const {
128 | #ifndef NDEBUG
129 | forEachCell([&](Real r, size_t i, size_t j) {
130 | assert(!std::isnan(r));
131 | assert(!std::isinf(r));
132 | });
133 | #endif
134 | }
135 |
136 | void forRow(size_t r, std::function l) {
137 | for (size_t j = 0; j < numCols(); j++) l(matrix(r, j), j);
138 | }
139 |
140 | void forRow(size_t r, std::function l) const {
141 | for (size_t j = 0; j < numCols(); j++) l(matrix(r, j), j);
142 | }
143 |
144 | void forCol(size_t c, std::function l) {
145 | for (size_t i = 0; i < numRows(); i++) l(matrix(i, c), i);
146 | }
147 |
148 | void forCol(size_t c, std::function l) const {
149 | for (size_t i = 0; i < numRows(); i++) l(matrix(c, i), i);
150 | }
151 |
152 | static void mul(const Matrix& l, const Matrix& r, Matrix& dest) {
153 | dest.matrix = boost::numeric::ublas::prod(l.matrix, r.matrix);
154 | }
155 |
156 | void updateRow(size_t r, Matrix& addend, Real scale = 1.0) {
157 | using namespace boost::numeric::ublas;
158 | assert(addend.numRows() == 1);
159 | assert(addend.numCols() == numCols());
160 | row(r) += Row { addend.matrix, 0 } * scale;
161 | }
162 |
163 | typedef boost::numeric::ublas::matrix_row>
164 | Row;
165 | Row row(size_t r) { return Row{ matrix, r }; }
166 |
167 | /* implicit */ operator Row() {
168 | assert(numRows() == 1);
169 | return Row{ matrix, 0 };
170 | }
171 |
172 | size_t numElts() const { return numRows() * numCols(); }
173 | size_t numRows() const { return matrix.size1(); }
174 | size_t numCols() const { return matrix.size2(); }
175 | MatrixDims getDims() const { return { numRows(), numCols() }; }
176 |
177 | void reshape(MatrixDims dims) {
178 | if (dims == getDims()) return;
179 | alloc(dims.r, dims.c);
180 | }
181 |
182 | typedef size_t iterator;
183 | iterator begin() { return 0; }
184 | iterator end() { return numElts(); }
185 |
186 | void write(std::ostream& out) {
187 | out << matrix;
188 | }
189 |
190 | void randomInit(Real sd = 1.0) {
191 | if (numElts() > 0) {
192 | // Multi-threaded initialization brings debug init time down
193 | // from minutes to seconds.
194 | auto d = &matrix(0, 0);
195 | std::minstd_rand gen;
196 | auto nd = std::normal_distribution(0, sd);
197 | for (size_t i = 0; i < numElts(); i++) {
198 | d[i] = nd(gen);
199 | };
200 | }
201 | }
202 |
203 | private:
204 | void alloc(size_t r, size_t c) {
205 | matrix = boost::numeric::ublas::matrix(r, c);
206 | }
207 | };
208 |
209 | }
210 |
--------------------------------------------------------------------------------
/cpp/src/model.h:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Facebook, Inc. and its affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 | #pragma once
9 |
10 | #include "matrix.h"
11 | #include "proj.h"
12 | #include "dict.h"
13 | #include "utils/normalize.h"
14 | #include "utils/args.h"
15 | #include "data.h"
16 | // #include "doc_data.h"
17 |
18 | #include
19 | #include
20 | #include
21 |
22 |
23 | namespace starspace {
24 |
25 | typedef float Real;
26 | typedef boost::numeric::ublas::matrix_row::matrix)>
27 | MatrixRow;
28 | typedef boost::numeric::ublas::vector Vector;
29 |
30 | /*
31 | * The model is basically two lookup tables: one for left hand side
32 | * (LHS) entities, one for right hand side (RHS) entities.
33 | */
// Embedding model: holds the dictionary handle plus the LHS and RHS
// SparseLinear lookup tables, and exposes training, projection, similarity
// and TSV/binary (de)serialization. Non-copyable (boost::noncopyable).
34 | struct EmbedModel : public boost::noncopyable {
35 | public:
36 | explicit EmbedModel(std::shared_ptr args,
37 | std::shared_ptr dict);
38 |
39 |
40 | typedef std::vector Corpus;
// One pass over `data` using numThreads threads. StarSpace::train passes the
// epoch's start/end learning rates and prints the returned value as the
// epoch's "Train error" (presumably the rate is interpolated between them —
// confirm in model.cpp).
41 | float train(std::shared_ptr data,
42 | int numThreads,
43 | std::chrono::time_point t_start,
44 | int epochs_done,
45 | Real startRate,
46 | Real endRate,
47 | bool verbose = true);
48 |
// Evaluation reuses train() with epochs_done = 0, both rates 0.0 and verbose
// off. NOTE(review): this relies on a zero rate leaving the weights
// unchanged — confirm in model.cpp.
49 | float test(std::shared_ptr data, int numThreads) {
50 | return this->train(data, numThreads,
51 | std::chrono::high_resolution_clock::now(), 0,
52 | 0.0, 0.0, false);
53 | }
54 |
55 | float trainOneBatch(std::shared_ptr data,
56 | const std::vector& batch_exs,
57 | size_t negSearchLimits,
58 | Real rate,
59 | bool trainWord = false);
60 |
61 | float trainNLLBatch(std::shared_ptr data,
62 | const std::vector& batch_exs,
63 | int32_t negSearchLimit,
64 | Real rate,
65 | bool trainWord = false);
66 |
// NOTE(review): gradW, lhs and dataset are taken by value, so every call
// copies those vectors; confirm whether const references were intended.
67 | void backward(const std::vector& batch_exs,
68 | const std::vector>>& negLabels,
69 | std::vector> gradW,
70 | std::vector> lhs,
71 | const std::vector& num_negs,
72 | Real rate_lhs,
73 | const std::vector& rate_rhsP,
74 | const std::vector>& nRate,
75 | std::vector dataset);
76 |
77 | // Querying
// findLHSLike / findRHSLike run kNN against the LHS / RHS embedding table
// respectively, returning numSim results by default 5.
78 | std::vector>
79 | kNN(std::shared_ptr> lookup,
80 | Matrix point,
81 | int numSim);
82 |
83 | std::vector>
84 | findLHSLike(Matrix point, int numSim = 5) {
85 | return kNN(LHSEmbeddings_, point, numSim);
86 | }
87 |
88 | std::vector>
89 | findRHSLike(Matrix point, int numSim = 5) {
90 | return kNN(RHSEmbeddings_, point, numSim);
91 | }
92 |
// Project a bag of (id, weight) features through the RHS / LHS table, either
// returning the result or writing it into `retval`.
93 | Matrix projectRHS(const std::vector& ws);
94 | Matrix projectLHS(const std::vector& ws);
95 |
96 | void projectLHS(const std::vector& ws, Matrix& retval);
97 | void projectRHS(const std::vector& ws, Matrix& retval);
98 |
// TSV loaders; by default a field separator is any of tab or space. The
// std::string overload forwards to the const char* overload.
99 | void loadTsv(std::istream& in, const std::string sep = "\t ");
100 | void loadTsv(const char* fname, const std::string sep = "\t ");
101 | void loadTsv(const std::string& fname, const std::string sep = "\t ") {
102 | return loadTsv(fname.c_str(), sep);
103 | }
104 | void loadFeatEmb(const std::string& fname, const std::string sep = "\t");
105 |
106 | void saveTsv(std::ostream& out, const char sep = '\t') const;
107 |
108 | void save(std::ostream& out) const;
109 |
110 | void load(std::ifstream& in);
111 |
// Map an index back to its dictionary symbol (LHS) or label (RHS).
112 | const std::string& lookupLHS(int32_t idx) const {
113 | return dict_->getSymbol(idx);
114 | }
115 | const std::string& lookupRHS(int32_t idx) const {
116 | return dict_->getLabel(idx);
117 | }
118 |
119 | void loadTsvLine(std::string& line, int lineNum, int cols,
120 | const std::string sep = "\t");
121 |
122 | std::shared_ptr getDict() { return dict_; }
123 |
124 | std::shared_ptr>& getLHSEmbeddings() {
125 | return LHSEmbeddings_;
126 | }
127 | const std::shared_ptr>& getLHSEmbeddings() const {
128 | return LHSEmbeddings_;
129 | }
130 | std::shared_ptr>& getRHSEmbeddings() {
131 | return RHSEmbeddings_;
132 | }
133 | const std::shared_ptr>& getRHSEmbeddings() const {
134 | return RHSEmbeddings_;
135 | }
136 |
137 | void initModelWeights();
138 |
// The Matrix overloads of similarity()/cosine()/normalize() go through
// asRow(), which asserts the matrix has exactly one row.
139 | Real similarity(const MatrixRow& a, const MatrixRow& b);
140 | Real similarity(Matrix& a, Matrix& b) {
141 | return similarity(asRow(a), asRow(b));
142 | }
143 |
144 | static Real cosine(const MatrixRow& a, const MatrixRow& b);
145 | static Real cosine(Matrix& a, Matrix& b) {
146 | return cosine(asRow(a), asRow(b));
147 | }
148 |
149 | static MatrixRow asRow(Matrix& m) {
150 | assert(m.numRows() == 1);
151 | return MatrixRow(m.matrix, 0);
152 | }
153 |
154 | static void normalize(Matrix::Row row, double maxNorm = 1.0);
155 | static void normalize(Matrix& m) { normalize(asRow(m)); }
156 |
157 | private:
158 | std::shared_ptr dict_;
159 | std::shared_ptr> LHSEmbeddings_;
160 | std::shared_ptr> RHSEmbeddings_;
161 | std::shared_ptr args_;
162 |
163 | std::vector LHSUpdates_;
164 | std::vector RHSUpdates_;
165 |
// NOTE(review): both branches set debug = false, so the ublas-matrix check()
// below returns immediately in every build. The Matrix overload still calls
// sanityCheck(), whose body is compiled out under NDEBUG.
166 | #ifdef NDEBUG
167 | static const bool debug = false;
168 | #else
169 | static const bool debug = false;
170 | #endif
171 |
172 | static void check(const Matrix& m) {
173 | m.sanityCheck();
174 | }
175 |
176 | static void check(const boost::numeric::ublas::matrix& m) {
177 | if (!debug) return;
178 | for (unsigned int i = 0; i < m.size1(); i++) {
179 | for (unsigned int j = 0; j < m.size2(); j++) {
180 | assert(!std::isnan(m(i, j)));
181 | assert(!std::isinf(m(i, j)));
182 | }
183 | }
184 | }
185 |
186 | };
187 |
188 | }
189 |
--------------------------------------------------------------------------------
/cpp/src/parser.cpp:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Facebook, Inc. and its affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 |
9 | #include "parser.h"
10 | #include "utils/normalize.h"
11 | #include
12 | #include
13 | #include
14 | #include
15 |
16 | #include
17 |
18 | using namespace std;
19 |
20 | namespace starspace {
21 |
// Drops a single trailing `toChomp` character (newline by default) from
// `line` in place; does nothing if the line is empty or ends otherwise.
void chomp(std::string& line, char toChomp = '\n') {
  if (line.empty() || line.back() != toChomp) {
    return;
  }
  line.pop_back();
}
28 |
29 | DataParser::DataParser(
30 | shared_ptr dict,
31 | shared_ptr args) {
32 | dict_ = dict;
33 | args_ = args;
34 | }
35 |
36 | bool DataParser::parse(
37 | std::string& s,
38 | ParseResults& rslts,
39 | const string& sep) {
40 |
41 | chomp(s);
42 | vector toks;
43 | boost::split(toks, s, boost::is_any_of(string(sep)));
44 |
45 | return parse(toks, rslts);
46 | }
47 |
48 | void DataParser::parseForDict(
49 | string& line,
50 | vector& tokens,
51 | const string& sep) {
52 |
53 | chomp(line);
54 | vector toks;
55 | boost::split(toks, line, boost::is_any_of(string(sep)));
56 | for (unsigned int i = 0; i < toks.size(); i++) {
57 | string token = toks[i];
58 | if (args_->useWeight) {
59 | std::size_t pos = toks[i].find(args_->weightSep);
60 | if (pos != std::string::npos) {
61 | token = toks[i].substr(0, pos);
62 | }
63 | }
64 | if (args_->normalizeText) {
65 | normalize_text(token);
66 | }
67 | if (token.find("__weight__") == std::string::npos) {
68 | tokens.push_back(token);
69 | }
70 | }
71 | }
72 |
73 | // check wether it is a valid example
74 | bool DataParser::check(const ParseResults& example) {
75 | if (args_->trainMode == 0) {
76 | // require lhs and rhs
77 | return !example.RHSTokens.empty() && !example.LHSTokens.empty();
78 | } if (args_->trainMode == 5) {
79 | // only requires lhs.
80 | return !example.LHSTokens.empty();
81 | } else {
82 | // lhs is not required, but rhs should contain at least 2 example
83 | return example.RHSTokens.size() > 1;
84 | }
85 | }
86 |
// Appends hashed word n-gram features to `line`. Only tokens the dictionary
// classifies as words take part. For every start position i it emits the
// n-grams of length 2..n beginning at i. Each n-gram id is hashed into
// args_->bucket buckets, offset past all words and labels
// (nwords + nlabels), and given weight 1.0.
// NOTE(review): `tokens` are the raw tokens; unlike parse(), weight
// suffixes are not cut and normalize_text is not applied before getType().
// Confirm this is intended when useWeight / normalizeText are on.
87 | void DataParser::addNgrams(
88 | const std::vector& tokens,
89 | std::vector& line,
90 | int n)
91 | {
// NOTE(review): the element type of `hashes` is not visible in this dump.
// If it is a signed 32-bit type, the uint64_t conversion below sign-extends
// hashes >= 2^31, which would not match StarSpace::getNgramVector's hashing
// — verify.
92 | vector hashes;
93 |
94 | for (auto token: tokens) {
95 | entry_type type = dict_->getType(token);
96 | if (type == entry_type::word) {
97 | hashes.push_back(dict_->hash(token));
98 | }
99 | }
100 |
101 | for (int32_t i = 0; i < (int32_t)(hashes.size()); i++){
102 | uint64_t h = hashes[i];
// Extend the running hash one word at a time: j = i+1 .. i+n-1.
103 | for (int32_t j = i + 1; j < (int32_t)(hashes.size()) && j < i + n; j++){
104 | h = h * Dictionary::HASH_C + hashes[j];
105 | int64_t id = h % args_->bucket;
106 | line.push_back(make_pair(dict_->nwords() + dict_->nlabels() + id, 1.0));
107 | // std::cerr << "i=" << i << "\tj=" << j << "\tngram=" << line.size() - 1 << std::endl;
108 | }
109 | }
110 | }
111 |
112 | bool DataParser::parse(
113 | const std::vector& tokens,
114 | ParseResults& rslts) {
115 |
116 | for (auto &token: tokens) {
117 | if (token.find("__weight__") != std::string::npos) {
118 | std::size_t pos = token.find(args_->weightSep);
119 | if (pos != std::string::npos) {
120 | rslts.weight = atof(token.substr(pos + 1).c_str());
121 | }
122 | continue;
123 | }
124 | string t = token;
125 | float weight = 1.0;
126 | if (args_->useWeight) {
127 | std::size_t pos = token.find(args_->weightSep);
128 | if (pos != std::string::npos) {
129 | t = token.substr(0, pos);
130 | weight = atof(token.substr(pos + 1).c_str());
131 | }
132 | }
133 |
134 | if (args_->normalizeText) {
135 | normalize_text(t);
136 | }
137 | int32_t wid = dict_->getId(t);
138 | if (wid < 0) {
139 | continue;
140 | }
141 |
142 | entry_type type = dict_->getType(wid);
143 | if (type == entry_type::word) {
144 | rslts.LHSTokens.push_back(make_pair(wid, weight));
145 | }
146 | if (type == entry_type::label) {
147 | rslts.RHSTokens.push_back(make_pair(wid, weight));
148 | }
149 | }
150 |
151 | if (args_->ngrams > 1) {
152 | addNgrams(tokens, rslts.LHSTokens, args_->ngrams);
153 | }
154 |
155 | return check(rslts);
156 | }
157 |
158 | bool DataParser::parse(
159 | const std::vector& tokens,
160 | vector& rslts) {
161 |
162 | for (auto &token: tokens) {
163 | auto t = token;
164 | float weight = 1.0;
165 | if (args_->useWeight) {
166 | std::size_t pos = token.find(args_->weightSep);
167 | if (pos != std::string::npos) {
168 | t = token.substr(0, pos);
169 | weight = atof(token.substr(pos + 1).c_str());
170 | }
171 | }
172 |
173 | if (args_->normalizeText) {
174 | normalize_text(t);
175 | }
176 | int32_t wid = dict_->getId(t);
177 | if (wid < 0) {
178 | continue;
179 | }
180 |
181 | //entry_type type = dict_->getType(wid);
182 | rslts.push_back(make_pair(wid, weight));
183 | }
184 |
185 | if (args_->ngrams > 1) {
186 | addNgrams(tokens, rslts, args_->ngrams);
187 | }
188 | return rslts.size() > 0;
189 | }
190 |
191 | } // namespace starspace
192 |
--------------------------------------------------------------------------------
/cpp/src/parser.h:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Facebook, Inc. and its affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 |
9 | /**
10 | * This is the basic class of data parsing.
11 | * It provides essential functions as follows:
12 | * - parse(input, output):
13 | * takes input as a line of string (or a vector of string tokens)
14 | * and returns the output result: one example containing l.h.s. features
15 | * and r.h.s. features.
16 | *
17 | * - parseForDict(input, tokens):
18 | * takes input as a line of string, output tokens to be added for building
19 | * the dictionary.
20 | *
21 | * - check(example):
22 | * checks whether the example is a valid example.
23 | *
24 | * - addNgrams(input, output):
25 | * add ngrams from input as output.
26 | *
27 | * One can write different parsers for data with different format.
28 | */
29 |
30 | #pragma once
31 |
32 | #include "dict.h"
33 | #include
34 | #include
35 |
36 | namespace starspace {
37 |
// One sparse feature: (dictionary or n-gram id, weight).
typedef std::pair<int32_t, float> Base;

// A single parsed example.
struct ParseResults {
  // Dataset index of the example (presumably set by the data loader —
  // confirm). Default-initialized to 0 so an example whose producer never
  // assigns it does not carry an indeterminate value.
  unsigned dataset = 0;
  // Example weight; overridden by a "__weight__" token in DataParser::parse.
  float weight = 1.0;
  // Left-hand-side features (dictionary words and word n-grams).
  std::vector<Base> LHSTokens;
  // Right-hand-side features (dictionary labels).
  std::vector<Base> RHSTokens;
  std::vector<std::vector<Base>> RHSFeatures;
};

typedef std::vector<ParseResults> Corpus;
49 |
50 | class DataParser {
51 | public:
52 | explicit DataParser(
53 | std::shared_ptr dict,
54 | std::shared_ptr args);
55 |
56 | virtual bool parse(
57 | std::string& s,
58 | ParseResults& rslt,
59 | const std::string& sep="\t ");
60 |
61 | virtual void parseForDict(
62 | std::string& s,
63 | std::vector& tokens,
64 | const std::string& sep="\t ");
65 |
66 | bool parse(
67 | const std::vector& tokens,
68 | std::vector& rslt);
69 |
70 | bool parse(
71 | const std::vector& tokens,
72 | ParseResults& rslt);
73 |
74 | bool check(const ParseResults& example);
75 |
76 | void addNgrams(
77 | const std::vector& tokens,
78 | std::vector& line,
79 | int32_t n);
80 |
81 | std::shared_ptr getDict() { return dict_; };
82 |
83 | void resetDict(std::shared_ptr dict) { dict_ = dict; };
84 |
85 | protected:
86 | std::shared_ptr dict_;
87 | std::shared_ptr args_;
88 | };
89 |
90 | }
91 |
--------------------------------------------------------------------------------
/cpp/src/proj.cpp:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Facebook, Inc. and its affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 |
9 | #include "proj.h"
10 |
--------------------------------------------------------------------------------
/cpp/src/proj.h:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Facebook, Inc. and its affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 | // The SparseLinear class implements the embedding lookup tables used by the StarSpace model.
9 |
10 | #pragma once
11 |
12 | #include "matrix.h"
13 |
14 | #include
15 | #include
16 | #include
17 | #include
18 | #include
19 | #include
20 |
21 | namespace starspace {
22 |
23 | template
24 | struct SparseLinear : public Matrix {
25 | explicit SparseLinear(MatrixDims dims,
26 | Real sd = 1.0) : Matrix(dims, sd) { }
27 |
28 | explicit SparseLinear(std::ifstream& in) : Matrix(in) { }
29 |
30 | void forward(int in, Matrix& mout) {
31 | using namespace boost::numeric::ublas;
32 | const auto c = this->numCols();
33 | mout.matrix.resize(1, c);
34 | memcpy(&mout[0][0], &(*this)[in][0], c * sizeof(Real));
35 | }
36 |
37 | void forward(const std::vector& in, Matrix& mout) {
38 | using namespace boost::numeric::ublas;
39 | const auto c = this->numCols();
40 | mout.matrix = zero_matrix(1, c);
41 | auto outRow = mout.row(0);
42 | for (const auto& elt: in) {
43 | assert(elt < this->numRows());
44 | outRow += this->row(elt);
45 | }
46 | }
47 |
48 | void forward(const std::vector>& in,
49 | Matrix &mout) {
50 | using namespace boost::numeric::ublas;
51 | const auto c = this->numCols();
52 | mout.matrix = zero_matrix(1, c);
53 | auto outRow = mout.row(0);
54 | for (const auto& pair: in) {
55 | assert(pair.first < this->numRows());
56 | outRow += this->row(pair.first) * pair.second;
57 | }
58 | }
59 |
60 | void backward(const std::vector& in,
61 | const Matrix& mb, const Real alpha) {
62 | // Just update this racily and in-place.
63 | assert(mb.numRows() == 1);
64 | auto b = mb[0];
65 | for (const auto& elt: in) {
66 | auto row = (*this)[elt];
67 | for (int i = 0; i < this->numCols(); i++) {
68 | row[i] -= alpha * b[i];
69 | }
70 | }
71 | }
72 |
73 | Real* allocOutput() {
74 | Real* retval;
75 | auto val = posix_memalign((void**)&retval, Matrix::kAlign,
76 | this->numCols() * sizeof(Real));
77 | if (val != 0) {
78 | perror("could not allocate output");
79 | throw this;
80 | }
81 | return retval;
82 | }
83 | };
84 |
85 | }
86 |
--------------------------------------------------------------------------------
/cpp/src/starspace.cpp:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Facebook, Inc. and its affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 |
9 | #include "starspace.h"
10 | #include
11 | #include
12 | #include
13 |
14 | #include
15 |
16 | using namespace std;
17 |
18 | namespace starspace {
19 |
20 | StarSpace::StarSpace(shared_ptr args)
21 | : args_(args)
22 | , dict_(nullptr)
23 | , parser_(nullptr)
24 | , trainData_(nullptr)
25 | , validData_(nullptr)
26 | , testData_(nullptr)
27 | , model_(nullptr)
28 | {}
29 |
30 | void StarSpace::initParser() {
31 | // if (args_->fileFormat == "fastText") {
32 | parser_ = make_shared(dict_, args_);
33 | // } else if (args_->fileFormat == "labelDoc") {
34 | // parser_ = make_shared(dict_, args_);
35 | // } else {
36 | // cerr << "Unsupported file format. Currently support: fastText or labelDoc.\n";
37 | // exit(EXIT_FAILURE);
38 | // }
39 | }
40 |
41 | void StarSpace::initDataHandler() {
42 | if (args_->isTrain) {
43 | trainData_ = initData();
44 | trainData_->loadFromFile();
45 | // set validation data
46 | if (!args_->validationFile.empty()) {
47 | validData_ = initData();
48 | validData_->loadFromFile();
49 | }
50 | } else {
51 | if (args_->testFile != "") {
52 | testData_ = initData();
53 | testData_->loadFromFile();
54 | }
55 | }
56 | }
57 |
58 | shared_ptr StarSpace::initData() {
59 | // if (args_->fileFormat == "fastText") {
60 | return make_shared(args_, dict_, parser_);
61 | // } else if (args_->fileFormat == "labelDoc") {
62 | // return make_shared(args_, dict_, parser_);
63 | // } else {
64 | // cerr << "Unsupported file format. Currently support: fastText or labelDoc.\n";
65 | // exit(EXIT_FAILURE);
66 | // }
67 | return nullptr;
68 | }
69 |
70 | // initialize dict and load data
71 | void StarSpace::init() {
72 | cout << "\nStart to initialize starspace model.\n";
73 | assert(args_ != nullptr);
74 |
75 | // build dict
76 | initParser();
77 | dict_ = make_shared(args_);
78 | if(args_->isCellSpace) dict_->CreateForCellSpace(parser_);
79 | else dict_->readFromFile(args_->trainFile, parser_);
80 | parser_->resetDict(dict_);
81 | if (args_->debug) {dict_->save(cout);}
82 |
83 | // init train data class
84 | trainData_ = initData();
85 | trainData_->loadFromFile();
86 |
87 | // init model with args and dict
88 | model_ = make_shared(args_, dict_);
89 | // if(args_->feat_emb != "") model_->loadFeatEmb(args_->feat_emb);
90 |
91 | // set validation data
92 | if (!args_->validationFile.empty()) {
93 | validData_ = initData();
94 | validData_->loadFromFile();
95 | }
96 | }
97 |
98 | void StarSpace::initFromSavedModel(const string& filename) {
99 | cout << "Start to load a trained starspace model.\n";
100 | std::ifstream in(filename, std::ifstream::binary);
101 | if (!in.is_open()) {
102 | std::cerr << "Model file cannot be opened for loading!" << std::endl;
103 | exit(EXIT_FAILURE);
104 | }
105 | string magic;
106 | char c;
107 | while ((c = in.get()) != 0) {
108 | magic.push_back(c);
109 | }
110 | cout << magic << endl;
111 | if (magic != kMagic) {
112 | std::cerr << "Magic signature does not match!" << std::endl;
113 | exit(EXIT_FAILURE);
114 | }
115 | // load args
116 | args_->load(in);
117 |
118 | // init and load dict
119 | dict_ = make_shared(args_);
120 | dict_->load(in);
121 |
122 | // init and load model
123 | model_ = make_shared(args_, dict_);
124 | model_->load(in);
125 | cout << "Model loaded.\n";
126 |
127 | // init data parser
128 | initParser();
129 | initDataHandler();
130 |
131 | loadBaseDocs();
132 | }
133 |
134 | void StarSpace::initFromTsv(const string& filename) {
135 | cout << "Start to load a trained embedding model in tsv format.\n";
136 | assert(args_ != nullptr);
137 | ifstream in(filename);
138 | if (!in.is_open()) {
139 | std::cerr << "Model file cannot be opened for loading!" << std::endl;
140 | exit(EXIT_FAILURE);
141 | }
142 | // Test dimension of first line, adjust args appropriately
143 | // (This is also so we can load a TSV file without even specifying the dim.)
144 | string line;
145 | getline(in, line);
146 | vector pieces;
147 | boost::split(pieces, line, boost::is_any_of("\t "));
148 | int dim = pieces.size() - 1;
149 | if ((int)(args_->dim) != dim) {
150 | args_->dim = dim;
151 | cout << "Setting dim from Tsv file to: " << dim << endl;
152 | }
153 | in.close();
154 |
155 | // build dict
156 | dict_ = make_shared(args_);
157 | dict_->loadDictFromModel(filename);
158 | if (args_->debug) {dict_->save(cout);}
159 |
160 | // load Model
161 | model_ = make_shared(args_, dict_);
162 | model_->loadTsv(filename, "\t ");
163 |
164 | // init data parser
165 | initParser();
166 | initDataHandler();
167 | }
168 |
169 | void StarSpace::train() {
170 | float rate = args_->lr;
171 | float decrPerEpoch = (rate - 1e-9) / args_->epoch;
172 |
173 | int impatience = 0;
174 | float best_valid_err = 1e9;
175 | auto t_start = std::chrono::high_resolution_clock::now();
176 | for (int i = 0; i < args_->epoch; i++) {
177 | if (args_->saveEveryNEpochs > 0 && i > 0 && i % args_->saveEveryNEpochs == 0) {
178 | auto filename = args_->model + "_epoch" + std::to_string(i);
179 | // saveModel(filename);
180 | saveModelTsv(filename + ".tsv");
181 | }
182 | cout << "Training epoch " << i << ": " << rate << ' ' << decrPerEpoch << endl;
183 | auto err = model_->train(trainData_, args_->thread,
184 | t_start, i,
185 | rate, rate - decrPerEpoch);
186 | printf("\n ---+++ %20s %4d Train error : %3.8f +++--- %c%c%c\n",
187 | "Epoch", i, err,
188 | 0xe2, 0x98, 0x83);
189 | if (validData_ != nullptr) {
190 | auto valid_err = model_->test(validData_, args_->thread);
191 | cout << "\nValidation error: " << valid_err << endl;
192 | if (valid_err > best_valid_err) {
193 | impatience += 1;
194 | if (impatience > args_->validationPatience) {
195 | cout << "Ran out of Patience! Early stopping based on validation set." << endl;
196 | break;
197 | }
198 | } else {
199 | best_valid_err = valid_err;
200 | }
201 | }
202 | rate -= decrPerEpoch;
203 |
204 | auto t_end = std::chrono::high_resolution_clock::now();
205 | auto tot_spent = std::chrono::duration(t_end-t_start).count();
206 | if (tot_spent >args_->maxTrainTime) {
207 | cout << "MaxTrainTime exceeded." << endl;
208 | break;
209 | }
210 | }
211 | }
212 |
213 | void StarSpace::parseDoc(
214 | const string& line,
215 | vector& ids,
216 | const string& sep) {
217 |
218 | vector tokens;
219 | boost::split(tokens, line, boost::is_any_of(string(sep)));
220 | parser_->parse(tokens, ids);
221 | }
222 |
223 | Matrix StarSpace::getDocVector(const string& line, const string& sep) {
224 | vector ids;
225 | parseDoc(line, ids, sep);
226 | return model_->projectLHS(ids);
227 | }
228 |
229 | MatrixRow StarSpace::getNgramVector(const string& phrase) {
230 | vector tokens;
231 | boost::split(tokens, phrase, boost::is_any_of(string(" ")));
232 | if (tokens.size() > (unsigned int)(args_->ngrams)) {
233 | std::cerr << "Error! Input ngrams size is greater than model ngrams size.\n";
234 | exit(EXIT_FAILURE);
235 | }
236 | if (tokens.size() == 1) {
237 | // looking up the entity embedding directly
238 | auto id = dict_->getId(tokens[0]);
239 | if (id != -1) {
240 | return model_->getLHSEmbeddings()->row(id);
241 | }
242 | }
243 |
244 | uint64_t h = 0;
245 | for (auto token: tokens) {
246 | if (dict_->getType(token) == entry_type::word) {
247 | h = h * Dictionary::HASH_C + dict_->hash(token);
248 | }
249 | }
250 | int64_t id = h % args_->bucket;
251 | return model_->getLHSEmbeddings()->row(id + dict_->nwords() + dict_->nlabels());
252 | }
253 |
254 | void StarSpace::nearestNeighbor(const string& line, int k) {
255 | auto vec = getDocVector(line, " ");
256 | auto preds = model_->findLHSLike(vec, k);
257 | for (auto n : preds) {
258 | cout << dict_->getSymbol(n.first) << ' ' << n.second << endl;
259 | }
260 | }
261 |
262 | unordered_map StarSpace::predictTags(const string& line, int k){
263 | args_->K = k;
264 | vector query_vec;
265 | parseDoc(line, query_vec, " ");
266 |
267 | vector predictions;
268 | predictOne(query_vec, predictions);
269 |
270 | unordered_map umap;
271 |
272 | for (int i = 0; i < predictions.size(); i++) {
273 | string tmp = printDocStr(baseDocs_[predictions[i].second]);
274 | umap[ tmp ] = predictions[i].first;
275 | }
276 | return umap;
277 | }
278 |
279 | void StarSpace::loadBaseDocs() {
280 | if (args_->basedoc.empty()) {
281 | // if (args_->fileFormat == "labelDoc") {
282 | // std::cerr << "Must provide base labels when label is featured.\n";
283 | // exit(EXIT_FAILURE);
284 | // }
285 | for (int i = 0; i < dict_->nlabels(); i++) {
286 | baseDocs_.push_back({ make_pair(i + dict_->nwords(), 1.0) });
287 | baseDocVectors_.push_back(
288 | model_->projectRHS({ make_pair(i + dict_->nwords(), 1.0) })
289 | );
290 | }
291 | cout << "Predictions use " << dict_->nlabels() << " known labels." << endl;
292 | } else {
293 | cout << "Loading base docs from file : " << args_->basedoc << endl;
294 | ifstream fin(args_->basedoc);
295 | if (!fin.is_open()) {
296 | std::cerr << "Base doc file cannot be opened for loading!" << std::endl;
297 | exit(EXIT_FAILURE);
298 | }
299 | string line;
300 | while (getline(fin, line)) {
301 | vector ids;
302 | parseDoc(line, ids, "\t ");
303 | baseDocs_.push_back(ids);
304 | auto docVec = model_->projectRHS(ids);
305 | baseDocVectors_.push_back(docVec);
306 | }
307 | fin.close();
308 | if (baseDocVectors_.size() == 0) {
309 | std::cerr << "ERROR: basedoc file '" << args_->basedoc << "' is empty." << std::endl;
310 | exit(EXIT_FAILURE);
311 | }
312 | cout << "Finished loading " << baseDocVectors_.size() << " base docs.\n";
313 | }
314 | }
315 |
316 | void StarSpace::predictOne(
317 | const vector& input,
318 | vector& pred) {
319 | auto lhsM = model_->projectLHS(input);
320 | std::priority_queue heap;
321 | for (unsigned int i = 0; i < baseDocVectors_.size(); i++) {
322 | auto cur_score = model_->similarity(lhsM, baseDocVectors_[i]);
323 | heap.push({ cur_score, i });
324 | }
325 | // get the first K predictions
326 | int i = 0;
327 | while (i < args_->K && heap.size() > 0) {
328 | pred.push_back(heap.top());
329 | heap.pop();
330 | i++;
331 | }
332 | }
333 |
// Ranks the true RHS `rhs` against every base doc for the query `lhs`.
// Returns a Metrics updated with the 1-based rank of the true RHS. Appends
// up to args_->K (score, index) predictions to `pred`, where index 0 is the
// true RHS and index i + 1 is baseDocVectors_[i]. With excludeLHS and no
// basedoc file, candidates whose label also occurs in `lhs` are left out of
// `pred`.
// NOTE(review): when basedoc is empty, rhs[0] is read, so rhs must be
// non-empty.
334 | Metrics StarSpace::evaluateOne(
335 | const vector& lhs,
336 | const vector& rhs,
337 | vector& pred,
338 | bool excludeLHS) {
339 |
340 | std::priority_queue heap;
341 |
342 | auto lhsM = model_->projectLHS(lhs);
343 | auto rhsM = model_->projectRHS(rhs);
344 | // Our evaluation function currently assumes there is only one correct label.
345 | // TODO: generalize this to the multilabel case.
346 | auto score = model_->similarity(lhsM, rhsM);
347 |
348 | int rank = 1;
// The true RHS enters the candidate heap under index 0.
349 | heap.push({ score, 0 });
350 |
351 | for (unsigned int i = 0; i < baseDocVectors_.size(); i++) {
352 | // in the case basedoc labels are not provided, all labels become basedoc,
353 | // and we skip the correct label for comparison.
354 | if ((args_->basedoc.empty()) && ((int)i == rhs[0].first - dict_->nwords())) {
355 | continue;
356 | }
357 | auto cur_score = model_->similarity(lhsM, baseDocVectors_[i]);
// Every strictly better candidate pushes the rank down by one. A tie does so
// with probability ~1/2 via rand(), so the rank is only reproducible if
// the C RNG is seeded identically.
358 | if (cur_score > score) {
359 | rank++;
360 | } else if (cur_score == score) {
361 | float flip = (float) rand() / RAND_MAX;
362 | if (flip > 0.5) {
363 | rank++;
364 | }
365 | }
366 | heap.push({ cur_score, i + 1 });
367 | }
368 |
369 | // get the first K predictions
370 | int i = 0;
371 | while (i < args_->K && heap.size() > 0) {
372 | Predictions heap_top = heap.top();
373 | heap.pop();
374 |
// Filtered entries do not count towards K.
375 | bool keep = true;
376 | if(excludeLHS && (args_->basedoc.empty())) {
377 | int nwords = dict_->nwords();
// Label id el.first corresponds to heap index (el.first - nwords) + 1.
378 | auto it = std::find_if( lhs.begin(), lhs.end(),
379 | [&heap_top, &nwords](const Base& el){ return (el.first - nwords + 1) == heap_top.second;} );
380 | keep = it == lhs.end();
381 | }
382 |
383 | if(keep) {
384 | pred.push_back(heap_top);
385 | i++;
386 | }
387 | }
388 |
389 | Metrics s;
390 | s.clear();
391 | s.update(rank);
392 | return s;
393 | }
394 |
395 | void StarSpace::printDoc(ostream& ofs, const vector& tokens) {
396 | for (auto t : tokens) {
397 | // skip ngram tokens
398 | if (t.first < dict_->size()) {
399 | ofs << dict_->getSymbol(t.first) << ' ';
400 | }
401 | }
402 | ofs << endl;
403 | }
404 |
405 | string StarSpace::printDocStr(const vector& tokens) {
406 | for (auto t : tokens) {
407 | if (t.first < dict_->size()) {
408 | return dict_->getSymbol(t.first);
409 | }
410 | }
411 |
412 | return "__label_unk";
413 | }
414 |
415 | void StarSpace::evaluate() {
416 | // check that it is not in trainMode 5
417 | if (args_->trainMode == 5) {
418 | std::cerr << "Test is undefined in trainMode 5. Please use other trainMode for testing.\n";
419 | exit(EXIT_FAILURE);
420 | }
421 |
422 | // set dropout probability to 0 in test case
423 | args_->dropoutLHS = 0.0;
424 | args_->dropoutRHS = 0.0;
425 |
426 | loadBaseDocs();
427 | int N = testData_->getSize();
428 |
429 | auto numThreads = args_->thread;
430 | vector threads;
431 | vector metrics(numThreads);
432 | vector> predictions(N);
433 | int numPerThread = ceil((float) N / numThreads);
434 | assert(numPerThread > 0);
435 |
436 | vector examples;
437 | testData_->getNextKExamples(N, examples);
438 |
439 | auto evalThread = [&] (int idx, int start, int end) {
440 | metrics[idx].clear();
441 | for (int i = start; i < end; i++) {
442 | auto s = evaluateOne(examples[i].LHSTokens, examples[i].RHSTokens, predictions[i], args_->excludeLHS);
443 | metrics[idx].add(s);
444 | }
445 | };
446 |
447 | for (int i = 0; i < numThreads; i++) {
448 | auto start = std::min(i * numPerThread, N);
449 | auto end = std::min(start + numPerThread, N);
450 | assert(end >= start);
451 | threads.emplace_back(thread([=] {
452 | evalThread(i, start, end);
453 | }));
454 | }
455 | for (auto& t : threads) t.join();
456 |
457 | Metrics result;
458 | result.clear();
459 | for (int i = 0; i < numThreads; i++) {
460 | if (args_->debug) { metrics[i].print(); }
461 | result.add(metrics[i]);
462 | }
463 | result.average();
464 | result.print();
465 |
466 | if (!args_->predictionFile.empty()) {
467 | // print out prediction results to file
468 | ofstream ofs(args_->predictionFile);
469 | for (int i = 0; i < N; i++) {
470 | ofs << "Example " << i << ":\nLHS:\n";
471 | printDoc(ofs, examples[i].LHSTokens);
472 | ofs << "RHS: \n";
473 | printDoc(ofs, examples[i].RHSTokens);
474 | ofs << "Predictions: \n";
475 | for (auto pred : predictions[i]) {
476 | if (pred.second == 0) {
477 | ofs << "(++) [" << pred.first << "]\t";
478 | printDoc(ofs, examples[i].RHSTokens);
479 | } else {
480 | ofs << "(--) [" << pred.first << "]\t";
481 | printDoc(ofs, baseDocs_[pred.second - 1]);
482 | }
483 | }
484 | ofs << "\n";
485 | }
486 | ofs.close();
487 | }
488 | }
489 |
490 | void StarSpace::saveModel(const string& filename) {
491 | cout << "Saving model to file : " << filename << endl;
492 | std::ofstream ofs(filename, std::ofstream::binary);
493 | if (!ofs.is_open()) {
494 | std::cerr << "Model file cannot be opened for saving!" << std::endl;
495 | exit(EXIT_FAILURE);
496 | }
497 | // sign model
498 | ofs.write(kMagic.data(), kMagic.size() * sizeof(char));
499 | ofs.put(0);
500 | args_->save(ofs);
501 | dict_->save(ofs);
502 | model_->save(ofs);
503 | ofs.close();
504 | }
505 |
506 | void StarSpace::saveModelTsv(const string& filename) {
507 | cout << "Saving model in tsv format : " << filename << endl;
508 | ofstream fout(filename);
509 | model_->saveTsv(fout, '\t');
510 | fout.close();
511 | }
512 |
513 | } // starspace
514 |
--------------------------------------------------------------------------------
/cpp/src/starspace.h:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Facebook, Inc. and its affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 | #pragma once
9 |
10 | #include "utils/args.h"
11 | #include "dict.h"
12 | #include "matrix.h"
13 | #include "parser.h"
14 | // #include "doc_parser.h"
15 | #include "model.h"
16 | #include "utils/utils.h"
17 |
18 | namespace starspace {
19 |
20 | typedef std::pair Predictions;
21 |
22 | class StarSpace {
23 | public:
24 | explicit StarSpace(std::shared_ptr args);
25 |
26 | void init();
27 | void initFromTsv(const std::string& filename);
28 | void initFromSavedModel(const std::string& filename);
29 |
30 | void train();
31 | void evaluate();
32 |
33 | MatrixRow getNgramVector(const std::string& phrase);
34 | Matrix getDocVector(
35 | const std::string& line,
36 | const std::string& sep = " \t");
37 | void parseDoc(
38 | const std::string& line,
39 | std::vector& ids,
40 | const std::string& sep);
41 |
42 | void nearestNeighbor(const std::string& line, int k);
43 |
44 |
45 | std::unordered_map predictTags(const std::string& line, int k);
46 | std::string printDocStr(const std::vector& tokens);
47 |
48 | void saveModel(const std::string& filename);
49 | void saveModelTsv(const std::string& filename);
50 | void printDoc(std::ostream& ofs, const std::vector& tokens);
51 |
52 | const std::string kMagic = "STARSPACE-2018-2";
53 |
54 |
55 | void loadBaseDocs();
56 |
57 | void predictOne(
58 | const std::vector& input,
59 | std::vector& pred);
60 |
61 | std::shared_ptr args_;
62 | std::vector> baseDocs_;
63 | private:
64 | void initParser();
65 | void initDataHandler();
66 | std::shared_ptr initData();
67 | Metrics evaluateOne(
68 | const std::vector& lhs,
69 | const std::vector& rhs,
70 | std::vector& pred,
71 | bool excludeLHS);
72 |
73 | std::shared_ptr dict_;
74 | std::shared_ptr parser_;
75 | std::shared_ptr trainData_;
76 | std::shared_ptr validData_;
77 | std::shared_ptr testData_;
78 | std::shared_ptr model_;
79 |
80 | std::vector> baseDocVectors_;
81 | };
82 |
83 | }
84 |
--------------------------------------------------------------------------------
/cpp/src/utils/args.cpp:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Facebook, Inc. and its affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 | #include "args.h"
9 |
10 | #include
11 | #include
12 | #include
13 | #include
14 | #include
15 | #include
16 |
17 | using namespace std;
18 |
19 | namespace starspace {
20 |
21 | Args::Args() {
22 | lr = 0.01;
23 | termLr = 1e-9;
24 | norm = 1.0;
25 | margin = 0.05;
26 | wordWeight = 0.5;
27 | initRandSd = 0.001;
28 | dropoutLHS = 0.0;
29 | dropoutRHS = 0.0;
30 | p = 0.5;
31 | ws = 5;
32 | maxTrainTime = 60*60*24*100;
33 | validationPatience = 10;
34 | thread = 10;
35 | maxNegSamples = 10;
36 | negSearchLimit = 50;
37 | minCount = 1;
38 | minCountLabel = 1;
39 | K = 5;
40 | batchSize = 5;
41 | verbose = false;
42 | debug = false;
43 | adagrad = true;
44 | normalizeText = false;
45 | trainMode = 0;
46 | // fileFormat = "fastText";
47 | label = "__label__";
48 | bucket = 2000000;
49 | isTrain = true;
50 | shareEmb = true;
51 | // saveEveryEpoch = false;
52 | saveEveryNEpochs = -1;
53 | // saveTempModel = false;
54 | useWeight = false;
55 | trainWord = false;
56 | excludeLHS = false;
57 | weightSep = ':';
58 | numGzFile = 1;
59 |
60 | loss = "hinge";
61 | similarity = "cosine";
62 | ngrams = 3;
63 | dim = 30;
64 | epoch = 50;
65 | exmpPerPeak = 20;
66 | k = 8;
67 | sampleLen = 150;
68 | // fixedFeatEmb = false;
69 | // batchLabels = false;
70 | }
71 |
72 | bool Args::isTrue(string arg) {
73 | std::transform(arg.begin(), arg.end(), arg.begin(),
74 | [&](char c) { return tolower(c); }
75 | );
76 | return (arg == "true" || arg == "1");
77 | }
78 |
/**
 * Returns true if `path` names a file that can be opened for reading.
 * On failure, including an empty path, prints "'<path>' cannot be opened!"
 * to stderr and returns false.
 */
bool valid_file(std::string path){
  const bool readable =
      !path.empty() && std::ifstream(path, std::ifstream::in).is_open();
  if (!readable) {
    std::cerr << "\'" << path << "\' cannot be opened!" << std::endl;
  }
  return readable;
}
93 |
/**
 * Validates a Matrix Market banner line.
 *
 * @return true for "coordinate pattern general" (entries carry no values),
 *         false for "coordinate integer general".
 * Any other banner prints "Unsupported matrix format!" and exits.
 */
inline bool check_header(string header){
  // getline() keeps the '\r' of CRLF (Windows) files, and trailing blanks are
  // harmless. Trim them so such files are not rejected as unsupported.
  const string::size_type last = header.find_last_not_of(" \t\r\n");
  header.erase(last == string::npos ? 0 : last + 1);

  if(header == "%%MatrixMarket matrix coordinate pattern general") return true;
  if(header == "%%MatrixMarket matrix coordinate integer general") return false;
  cerr << "Unsupported matrix format!" << endl;
  exit(EXIT_FAILURE);
}
102 |
103 | void Args::parseArgs(int argc, char** argv) {
104 | string a0 = argv[0];
105 | isCellSpace = a0.substr(a0.length() - 9, 9) == "CellSpace";
106 |
107 | if (argc < 4) {
108 | // cerr << "Usage: need to specify whether it is train or test.\n";
109 | printHelp();
110 | exit(EXIT_FAILURE);
111 | }
112 | // if (strcmp(argv[1], "train") == 0) {
113 | // isTrain = true;
114 | // } else if (strcmp(argv[1], "test") == 0) {
115 | // isTrain = false;
116 | // if(isCellSpace){
117 | // std::cerr << "CellSpace test has not been implemented yet!" << std::endl;
118 | // exit(EXIT_FAILURE);
119 | // }
120 | // } else if (strcmp(argv[1], "-h") == 0 || strcmp(argv[1], "-help") == 0) {
121 | // std::cerr << "Here is the help! Usage:" << std::endl;
122 | // printHelp();
123 | // exit(EXIT_FAILURE);
124 | // } else {
125 | // cerr << "Usage: the first argument should be either train or test.\n";
126 | // printHelp();
127 | // exit(EXIT_FAILURE);
128 | // }
129 | int i = 1;
130 | while (i < argc) {
131 | if (argv[i][0] != '-') {
132 | cout << "Provided argument without a dash! Usage:" << endl;
133 | printHelp();
134 | exit(EXIT_FAILURE);
135 | }
136 |
137 | // handling "--"
138 | if (strlen(argv[i]) >= 2 && argv[i][1] == '-') {
139 | argv[i] = argv[i] + 1;
140 | }
141 |
142 | if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "-help") == 0) {
143 | // std::cerr << "Here is the help! Usage:" << std::endl;
144 | printHelp();
145 | exit(EXIT_FAILURE);
146 | // } else if (strcmp(argv[i], "-trainFile") == 0 && !isCellSpace) {
147 | // trainFile = string(argv[i + 1]);
148 | // } else if (strcmp(argv[i], "-validationFile") == 0 && !isCellSpace) {
149 | // validationFile = string(argv[i + 1]);
150 | // } else if (strcmp(argv[i], "-testFile") == 0 && !isCellSpace) {
151 | // testFile = string(argv[i + 1]);
152 | // } else if (strcmp(argv[i], "-predictionFile") == 0 && !isCellSpace) {
153 | // predictionFile = string(argv[i + 1]);
154 | // } else if (strcmp(argv[i], "-basedoc") == 0 && !isCellSpace) {
155 | // basedoc = string(argv[i + 1]);
156 | } else if (strcmp(argv[i], "-output") == 0) {
157 | model = string(argv[i + 1]);
158 | // } else if (strcmp(argv[i], "-initModel") == 0 && !isCellSpace) {
159 | // initModel = string(argv[i + 1]);
160 | // } else if (strcmp(argv[i], "-fileFormat") == 0 && !isCellSpace) {
161 | // fileFormat = string(argv[i + 1]);
162 | // } else if (strcmp(argv[i], "-compressFile") == 0 && !isCellSpace) {
163 | // compressFile = string(argv[i + 1]);
164 | // } else if (strcmp(argv[i], "-numGzFile") == 0) {
165 | // numGzFile = atoi(argv[i + 1]);
166 | } else if (strcmp(argv[i], "-label") == 0) {
167 | label = string(argv[i + 1]);
168 | // } else if (strcmp(argv[i], "-weightSep") == 0 && !isCellSpace) {
169 | // weightSep = argv[i + 1][0];
170 | // } else if (strcmp(argv[i], "-loss") == 0 && !isCellSpace) {
171 | // loss = string(argv[i + 1]);
172 | // } else if (strcmp(argv[i], "-similarity") == 0) {
173 | // similarity = string(argv[i + 1]);
174 | } else if (strcmp(argv[i], "-lr") == 0) {
175 | lr = atof(argv[i + 1]);
176 | } else if (strcmp(argv[i], "-p") == 0) {
177 | p = atof(argv[i + 1]);
178 | // } else if (strcmp(argv[i], "-termLr") == 0) {
179 | // termLr = atof(argv[i + 1]);
180 | // } else if (strcmp(argv[i], "-norm") == 0) {
181 | // norm = atof(argv[i + 1]);
182 | } else if (strcmp(argv[i], "-margin") == 0) {
183 | margin = atof(argv[i + 1]);
184 | } else if (strcmp(argv[i], "-initRandSd") == 0) {
185 | initRandSd = atof(argv[i + 1]);
186 | // } else if (strcmp(argv[i], "-dropoutLHS") == 0) {
187 | // dropoutLHS = atof(argv[i + 1]);
188 | // } else if (strcmp(argv[i], "-dropoutRHS") == 0) {
189 | // dropoutRHS = atof(argv[i + 1]);
190 | // } else if (strcmp(argv[i], "-wordWeight") == 0) {
191 | // wordWeight = atof(argv[i + 1]);
192 | } else if (strcmp(argv[i], "-dim") == 0) {
193 | dim = atoi(argv[i + 1]);
194 | } else if (strcmp(argv[i], "-epoch") == 0) {
195 | epoch = atoi(argv[i + 1]);
196 | // } else if (strcmp(argv[i], "-ws") == 0) {
197 | // ws = atoi(argv[i + 1]);
198 | } else if (strcmp(argv[i], "-maxTrainTime") == 0) {
199 | maxTrainTime = atoi(argv[i + 1]);
200 | // } else if (strcmp(argv[i], "-validationPatience") == 0) {
201 | // validationPatience = atoi(argv[i + 1]);
202 | } else if (strcmp(argv[i], "-thread") == 0) {
203 | thread = atoi(argv[i + 1]);
204 | } else if (strcmp(argv[i], "-maxNegSamples") == 0) {
205 | maxNegSamples = atoi(argv[i + 1]);
206 | } else if (strcmp(argv[i], "-negSearchLimit") == 0) {
207 | negSearchLimit = atoi(argv[i + 1]);
208 | // } else if (strcmp(argv[i], "-minCount") == 0 && !isCellSpace) {
209 | // minCount = atoi(argv[i + 1]);
210 | // } else if (strcmp(argv[i], "-minCountLabel") == 0 && !isCellSpace) {
211 | // minCountLabel = atoi(argv[i + 1]);
212 | } else if (strcmp(argv[i], "-bucket") == 0) {
213 | bucket = atoi(argv[i + 1]);
214 | } else if (strcmp(argv[i], "-ngrams") == 0) {
215 | ngrams = atoi(argv[i + 1]);
216 | // } else if (strcmp(argv[i], "-K") == 0) {
217 | // K = atoi(argv[i + 1]);
218 | } else if (strcmp(argv[i], "-batchSize") == 0) {
219 | batchSize = atoi(argv[i + 1]);
220 | // } else if (strcmp(argv[i], "-trainMode") == 0 && !isCellSpace) {
221 | // trainMode = atoi(argv[i + 1]);
222 | // } else if (strcmp(argv[i], "-fixedFeatEmb") == 0 && isCellSpace) {
223 | // fixedFeatEmb = isTrue(string(argv[i + 1]));
224 | // } else if (strcmp(argv[i], "-batchLabels") == 0 && isCellSpace) {
225 | // batchLabels = isTrue(string(argv[i + 1]));
226 | // } else if (strcmp(argv[i], "-verbose") == 0) {
227 | // verbose = isTrue(string(argv[i + 1]));
228 | // } else if (strcmp(argv[i], "-debug") == 0) {
229 | // debug = isTrue(string(argv[i + 1]));
230 | // } else if (strcmp(argv[i], "-adagrad") == 0) {
231 | // adagrad = isTrue(string(argv[i + 1]));
232 | // } else if (strcmp(argv[i], "-shareEmb") == 0) {
233 | // shareEmb = isTrue(string(argv[i + 1]));
234 | // } else if (strcmp(argv[i], "-normalizeText") == 0 && !isCellSpace) {
235 | // normalizeText = isTrue(string(argv[i + 1]));
236 | } else if (strcmp(argv[i], "-saveIntermediates") == 0) {
237 | if(string(argv[i + 1]) == "final") saveEveryNEpochs = -1;
238 | else saveEveryNEpochs = atoi(argv[i + 1]);
239 | if(saveEveryNEpochs < 1) saveEveryNEpochs = -1;
240 | // } else if (strcmp(argv[i], "-saveTempModel") == 0) {
241 | // saveTempModel = isTrue(string(argv[i + 1]));
242 | // } else if (strcmp(argv[i], "-useWeight") == 0 && !isCellSpace) {
243 | // useWeight = isTrue(string(argv[i + 1]));
244 | // } else if (strcmp(argv[i], "-trainWord") == 0 && !isCellSpace) {
245 | // trainWord = isTrue(string(argv[i + 1]));
246 | // } else if (strcmp(argv[i], "-excludeLHS") == 0) {
247 | // excludeLHS = isTrue(string(argv[i + 1]));
248 | } else if (strcmp(argv[i], "-cpMat") == 0 && isCellSpace) {
249 | assert(i + 1 < argc);
250 | while(i + 1 < argc && argv[i + 1][0] != '-')
251 | cp_matrix_list.push_back(string(argv[(i++) + 1]));
252 | i--;
253 | } else if (strcmp(argv[i], "-peaks") == 0 && isCellSpace) {
254 | assert(i + 1 < argc);
255 | while(i + 1 < argc && argv[i + 1][0] != '-')
256 | peaks_list.push_back(string(argv[(i++) + 1]));
257 | i--;
258 | // } else if (strcmp(argv[i], "-featEmb") == 0 && isCellSpace) {
259 | // feat_emb = string(argv[i + 1]);
260 | } else if (strcmp(argv[i], "-k") == 0 && isCellSpace) {
261 | k = atoi(argv[i + 1]);
262 | } else if (strcmp(argv[i], "-sampleLen") == 0 && isCellSpace) {
263 | sampleLen = (strcmp(argv[i + 1], "given") == 0) ? -1 : atoi(argv[i + 1]);
264 | } else if (strcmp(argv[i], "-exmpPerPeak") == 0 && isCellSpace) {
265 | exmpPerPeak = atoi(argv[i + 1]);
266 | } else {
267 | cerr << "Unknown argument: " << argv[i] << std::endl;
268 | printHelp();
269 | exit(EXIT_FAILURE);
270 | }
271 | i += 2;
272 | }
273 |
274 | if(isCellSpace){
275 | trainMode = 0;
276 | minCount = 1;
277 | minCountLabel = 1;
278 | trainWord = false;
279 | useWeight = false;
280 | normalizeText = false;
281 | compressFile = "";
282 |
283 | if(peaks_list.size() != cp_matrix_list.size()){
284 | cerr << "The number of input files for peak sequences should match the number of input files for count matrices!" << endl;
285 | exit(EXIT_FAILURE);
286 | }
287 | nBatches = cp_matrix_list.size();
288 |
289 | nCells_total = 0; nCells_list.clear();
290 | nPeaks_total = 0; nPeaks_list.clear();
291 | first_peak_idx.clear();
292 | for(auto cp_matrix: cp_matrix_list){
293 | unsigned nCells = 0, nPeaks = 0;
294 | if(!valid_file(cp_matrix)) exit(EXIT_FAILURE);
295 | else {
296 | string header, mat_format = cp_matrix.substr(cp_matrix.find_last_of(".") + 1);
297 | if(mat_format == "gz"){
298 | #ifdef COMPRESS_FILE
299 | ifstream ifs2(cp_matrix);
300 | if(!ifs2.good()) exit(EXIT_FAILURE);
301 | filtering_istream fin;
302 | fin.push(gzip_decompressor());
303 | fin.push(ifs2);
304 | getline(fin, header);
305 | bool bin = check_header(header);
306 | fin >> nCells >> nPeaks;
307 | ifs2.close();
308 | #endif
309 | } else {
310 | std::ifstream fin(cp_matrix, std::ifstream::in);
311 | getline(fin, header);
312 | bool bin = check_header(header);
313 | fin >> nCells >> nPeaks;
314 | fin.close();
315 | }
316 | if(nCells < 2 || nPeaks < 2){
317 | cerr << "There must be at least two cells and two peaks in a dataset!" << endl;
318 | exit(EXIT_FAILURE);
319 | } else {
320 | first_peak_idx.push_back(nPeaks_total);
321 | nCells_list.push_back(nCells); nCells_total += nCells;
322 | nPeaks_list.push_back(nPeaks); nPeaks_total += nPeaks;
323 | // cerr << nCells_list.size() << ") cells: " << nCells << ", peaks: " << nPeaks << endl;
324 | }
325 | }
326 | }
327 |
328 | for(auto peaks: peaks_list)
329 | if(!valid_file(peaks)) exit(EXIT_FAILURE);
330 |
331 | // if(feat_emb != "" && !valid_file(feat_emb)) exit(EXIT_FAILURE);
332 |
333 | if(k < 2 || k > 16){
334 | cerr << "k must be >1 and <17" << endl;
335 | exit(EXIT_FAILURE);
336 | }
337 | if(sampleLen != -1 && sampleLen < k){
338 | cerr << "sampleLen must be >=k" << endl;
339 | exit(EXIT_FAILURE);
340 | }
341 | if(exmpPerPeak < 1){
342 | cerr << "exmpPerPeak must be >0" << endl;
343 | exit(EXIT_FAILURE);
344 | }
345 | }
346 | // else {
347 | // if (isTrain) {
348 | // if (model.empty()) {
349 | // cerr << "Empty output path." << endl;
350 | // printHelp();
351 | // exit(EXIT_FAILURE);
352 | // }
353 | // } else {
354 | // if (testFile.empty() || model.empty()) {
355 | // cerr << "Empty test file or model path." << endl;
356 | // printHelp();
357 | // exit(EXIT_FAILURE);
358 | // }
359 | // }
360 | // }
361 | }
362 |
363 | void Args::printHelp() {
364 | if(isCellSpace){
365 | cout << "\n"
366 | << "\"CellSpace ...\"\n"
367 |
368 | << "\nThe following arguments are mandatory:\n"
369 | << " -output prefix of the output\n"
370 | << " -cpMat sparse cell by peak/tile count matrix (.mtx)\n"
371 | << " -peaks multi-fasta file containing peak/tile DNA sequences with the order they appear in the corresponding count matrix (.fa)\n"
372 | // << " -featEmb optional embedding matrix to initialize k-mer embeddings from a previously trained CellSpace model (.tsv)\n"
373 |
374 | << "\nThe following arguments are optional:\n"
375 | << " -dim size of embedding vectors [default=" << dim << "]\n"
376 | << " -ngrams max length of k-mer ngram [default=" << ngrams << "]\n"
377 | << " -k k-mer length [default=" << k << "]\n"
378 | << " -sampleLen length of the sequences randomly sampled from the peak/tile DNA sequences (integer or 'given') [default=" << (sampleLen == -1 ? "'given'" : to_string(sampleLen)) << "]\n"
379 | << " -exmpPerPeak number of training examples per peak/tile [default=" << exmpPerPeak << "]\n"
380 | << " -epoch number of epochs [default=" << epoch << "]\n"
381 | // << " -fixedFeatEmb whether feature (k-mer) embeddings should be kept fixed [default=" << fixedFeatEmb << "]\n"
382 | // << " -batchLabels whether the batch (i.e. dataset) should be included as a label (i.e. RHS) in training [default=" << batchLabels << "]\n"
383 | << " -margin margin parameter in hinge loss [default=" << margin << "]\n"
384 | // << " -similarity takes value in [cosine, dot]. Whether to use cosine or dot product as similarity function in hinge loss [default=" << similarity << "]\n"
385 | << " -bucket number of buckets [default=" << bucket << "]\n"
386 | << " -label labels prefix [default='" << label << "']\n"
387 | << " -lr learning rate [default=" << lr << "]\n"
388 | << " -maxTrainTime max train time (seconds) [default=" << maxTrainTime << "]\n"
389 | << " -negSearchLimit number of negative labels sampled per dataset [default=" << negSearchLimit << "]\n"
390 | << " -maxNegSamples max number of negatives in a batch update [default=" << maxNegSamples << "]\n"
391 | << " -p the embedding of an entity equals the sum of its M feature embedding vectors devided by M^p [default=" << p << "]\n"
392 | // << " -adagrad whether to use adagrad in training [default=" << adagrad << "]\n"
393 | // << " -shareEmb whether to use the same embedding matrix for LHS and RHS [default=" << shareEmb << "]\n"
394 | // << " -dropoutLHS dropout probability for LHS features [default=" << dropoutLHS << "]\n"
395 | // << " -dropoutRHS dropout probability for RHS features [default=" << dropoutRHS << "]\n"
396 | << " -initRandSd initial values of embeddings are randomly generated from normal distribution with mean=0 and standard deviation=initRandSd [default=" << initRandSd << "]\n"
397 | << " -batchSize size of mini batch in training [default=" << batchSize << "]\n"
398 | // << " -verbose verbosity level [default=" << verbose << "]\n"
399 | // << " -debug whether it's in debug mode [default=" << debug << "]\n"
400 | << " -saveIntermediates save intermediate models or only the final model (integer or 'final') [default=" << (saveEveryNEpochs == -1 ? "'final'" : to_string(saveEveryNEpochs)) << "]\n"
401 | // << " -saveTempModel save intermediate models after each epoch with an unique name including epoch number [default=" << saveTempModel << "]\n"
402 | << " -thread number of threads [default=" << thread << "]\n\n";
403 | }
404 | }
405 |
406 | void Args::printArgs() {
407 | cout << "CellSpace Arguments: \n"
408 | << "dim: " << dim << endl
409 | << "ngrams: " << ngrams << endl
410 | << "k: " << k << endl
411 | << "sampleLen: " << (sampleLen == -1 ? "'given'" : to_string(sampleLen)) << endl
412 | << "exmpPerPeak: " << exmpPerPeak << endl
413 | << "epoch: " << epoch << endl
414 | << "margin: " << margin << endl
415 | << "bucket: " << bucket << endl
416 | << "label: '" << label << "'" << endl
417 | << "lr: " << lr << endl
418 | << "maxTrainTime: " << maxTrainTime << endl
419 | << "negSearchLimit: " << negSearchLimit << endl
420 | << "maxNegSamples: " << maxNegSamples << endl
421 | << "p: " << p << endl
422 | << "initRandSd: " << initRandSd << endl
423 | << "batchSize: " << batchSize << endl
424 | << "saveIntermediates: " << (saveEveryNEpochs == -1 ? "'final'" : to_string(saveEveryNEpochs)) << endl
425 | << "thread: " << thread << endl
426 | << "-----------------" << endl;
427 | }
428 |
429 | void Args::save(std::ostream& out) {
430 | out.write((char*) &(dim), sizeof(int));
431 | out.write((char*) &(epoch), sizeof(int));
432 | out.write((char*) &(maxTrainTime), sizeof(int));
433 | out.write((char*) &(minCount), sizeof(int));
434 | out.write((char*) &(minCountLabel), sizeof(int));
435 | out.write((char*) &(maxNegSamples), sizeof(int));
436 | out.write((char*) &(negSearchLimit), sizeof(int));
437 | out.write((char*) &(ngrams), sizeof(int));
438 | out.write((char*) &(bucket), sizeof(int));
439 | out.write((char*) &(trainMode), sizeof(int));
440 | out.write((char*) &(shareEmb), sizeof(bool));
441 | out.write((char*) &(useWeight), sizeof(bool));
442 | out.write((char*) &(weightSep), sizeof(char));
443 | size_t size; // = fileFormat.size();
444 | out.write((char*) &(size), sizeof(size_t));
445 | // out.write((char*) &(fileFormat[0]), size);
446 | size = similarity.size();
447 | out.write((char*) &(size), sizeof(size_t));
448 | out.write((char*) &(similarity[0]), size);
449 | out.write((char*) &(batchSize), sizeof(int));
450 | }
451 |
452 | void Args::load(std::istream& in) {
453 | in.read((char*) &(dim), sizeof(int));
454 | in.read((char*) &(epoch), sizeof(int));
455 | // in.read((char*) &(maxTrainTime), sizeof(int));
456 | in.read((char*) &(minCount), sizeof(int));
457 | in.read((char*) &(minCountLabel), sizeof(int));
458 | in.read((char*) &(maxNegSamples), sizeof(int));
459 | in.read((char*) &(negSearchLimit), sizeof(int));
460 | in.read((char*) &(ngrams), sizeof(int));
461 | in.read((char*) &(bucket), sizeof(int));
462 | in.read((char*) &(trainMode), sizeof(int));
463 | // in.read((char*) &(shareEmb), sizeof(bool));
464 | in.read((char*) &(useWeight), sizeof(bool));
465 | in.read((char*) &(weightSep), sizeof(char));
466 | size_t size;
467 | in.read((char*) &(size), sizeof(size_t));
468 | // fileFormat.resize(size);
469 | // in.read((char*) &(fileFormat[0]), size);
470 | in.read((char*) &(size), sizeof(size_t));
471 | similarity.resize(size);
472 | in.read((char*) &(similarity[0]), size);
473 | in.read((char*) &(batchSize), sizeof(int));
474 | }
475 |
476 | }
477 |
--------------------------------------------------------------------------------
/cpp/src/utils/args.h:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Facebook, Inc. and its affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 | #pragma once
9 |
10 | #include
11 | #include
12 | #include
13 |
14 | namespace starspace {
15 |
/**
 * Command-line options and derived dataset bookkeeping for StarSpace/CellSpace.
 * Defaults are set in Args::Args(); parseArgs() fills them from argv.
 */
class Args {
  public:
    Args();
    std::string trainFile;
    std::string validationFile;
    std::string testFile;
    std::string predictionFile;
    std::string model;            // output prefix (-output)
    std::string compressFile;
    std::string label;            // label prefix (-label)
    std::string basedoc;
    std::string loss;
    std::string similarity;

    char weightSep;
    double lr;
    double termLr;
    double norm;
    double margin;
    double initRandSd;
    double p;
    double dropoutLHS;
    double dropoutRHS;
    double wordWeight;
    size_t dim;                   // persisted as an int by save()/load()
    int epoch;
    int ws;
    int maxTrainTime;
    int validationPatience;
    int thread;
    int maxNegSamples;
    int negSearchLimit;
    int minCount;
    int minCountLabel;
    int bucket;
    int ngrams;
    int trainMode;
    int K;
    int batchSize;
    int numGzFile;
    bool verbose;
    bool debug;
    bool adagrad;
    bool isTrain;
    bool normalizeText;
    int saveEveryNEpochs;         // -1 means save only the final model
    bool shareEmb;
    bool useWeight;
    bool trainWord;
    bool excludeLHS;

    // CellSpace parameters:
    // True when the executable name ends in "CellSpace" (set by parseArgs()).
    // Default-initialized because printHelp() reads it.
    bool isCellSpace = false;
    int k;                        // k-mer length
    int sampleLen;                // -1 means use the given sequence length
    int exmpPerPeak;
    std::vector<std::string> cp_matrix_list;  // one count matrix per dataset
    std::vector<std::string> peaks_list;      // one FASTA file per dataset
    // Per-dataset cell/peak counts and each dataset's first global peak index.
    std::vector<unsigned> nCells_list, nPeaks_list, first_peak_idx;
    unsigned nCells_total = 0, nPeaks_total = 0, nBatches = 0;

    void parseArgs(int, char**);
    void printHelp();
    void printArgs();
    void save(std::ostream& out);
    void load(std::istream& in);
    bool isTrue(std::string arg);
};
91 |
92 | }
93 |
--------------------------------------------------------------------------------
/cpp/src/utils/normalize.cpp:
--------------------------------------------------------------------------------
1 | /**
2 | * Copyright (c) Facebook, Inc. and its affiliates.
3 | *
4 | * This source code is licensed under the MIT license found in the
5 | * LICENSE file in the root directory of this source tree.
6 | */
7 |
8 | #include "normalize.h"
9 |
10 | #include
11 | #include