├── .github
└── workflows
│ └── tutorial-tests.yml
├── .gitignore
├── LICENSE
├── README.md
├── bin
└── micromamba
├── causalimages.pdf
├── causalimages
├── .DS_Store
├── DESCRIPTION
├── NAMESPACE
├── R
│ ├── CI_BuildBackend.R
│ ├── CI_Confounding.R
│ ├── CI_GetAndSaveGeolocatedImages.R
│ ├── CI_GetMoments.R
│ ├── CI_Heterogeneity.R
│ ├── CI_Heterogeneity_plot.R
│ ├── CI_ImageModelBackbones.R
│ ├── CI_InitializeJAX.R
│ ├── CI_PredictiveRun.R
│ ├── CI_TFRecordManagement.R
│ ├── CI_TfRecordFxns.R
│ ├── CI_TrainDefine.R
│ ├── CI_TrainDo.R
│ ├── CI_helperFxns.R
│ └── CI_image2.R
├── data
│ ├── CausalImagesTutorialData.RData
│ └── datalist
├── man
│ ├── AnalyzeImageConfounding.Rd
│ ├── AnalyzeImageHeterogeneity.Rd
│ ├── BuildBackend.Rd
│ ├── GetAndSaveGeolocatedImages.Rd
│ ├── GetElementFromTfRecordAtIndices.Rd
│ ├── GetImageRepresentations.Rd
│ ├── GetMoments.Rd
│ ├── LongLat2CRS.Rd
│ ├── PredictiveRun.Rd
│ ├── TFRecordManagement.Rd
│ ├── TrainDefine.Rd
│ ├── TrainDo.Rd
│ ├── WriteTfRecord.Rd
│ ├── image2.Rd
│ ├── message2.Rd
│ └── print2.Rd
└── tests
│ ├── Test_AAARunAllTutorialsSuite.R
│ ├── Test_AnalyzeImageConfounding.R
│ ├── Test_AnalyzeImageHeterogeneity.R
│ ├── Test_BuildBackend.R
│ ├── Test_ExtractImageRepresentations.R
│ └── Test_UsingTfRecords.R
├── documentPackage.R
├── misc
├── dataverse
│ ├── DataverseReadme_confounding.md
│ ├── DataverseReadme_heterogeneity.md
│ ├── DataverseTutorial_confounding.R
│ └── DataverseTutorial_heterogeneity.R
├── docker
│ └── setup
│ │ ├── CodexHelpers.sh
│ │ ├── Dockerfile
│ │ ├── DockerfileNoCompile
│ │ ├── FindDependencies.sh
│ │ ├── GenEnv.sh
│ │ └── GetRDependencyOrder.R
└── notes
│ └── MaintainerNotes.txt
├── other
├── DataverseReadme_confounding.md
├── DataverseReadme_heterogeneity.md
├── DataverseTutorial_confounding.R
├── DataverseTutorial_heterogeneity.R
└── PackageRunChecks.R
└── tutorials
├── AnalyzeImageConfounding_Tutorial_Advanced.R
├── AnalyzeImageConfounding_Tutorial_Base.R
├── AnalyzeImageConfounding_Tutorial_EmbeddingsOnly.R
├── AnalyzeImageConfounding_Tutorial_Simulation.R
├── AnalyzeImageHeterogeneity_Tutorial.R
├── BuildBackend_Tutorial.R
├── ExtractImageRepresentations_Tutorial.R
└── UsingTfRecords_Tutorial.R
/.github/workflows/tutorial-tests.yml:
--------------------------------------------------------------------------------
1 | name: tutorial-tests
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request:
7 |
8 | jobs:
9 | test:
10 | runs-on: ubuntu-latest
11 | steps:
12 | - uses: actions/checkout@v3
13 | - uses: conda-incubator/setup-miniconda@v2
14 | with:
15 | auto-update-conda: true
16 | python-version: '3.10'
17 | - uses: r-lib/actions/setup-r@v2
18 | with:
19 | r-version: '4.2'
20 | - name: Install devtools
21 | run: |
22 | Rscript -e 'install.packages("devtools", repos="https://cloud.r-project.org")'
23 | - name: Install package dependencies
24 | run: |
25 | Rscript -e 'devtools::install_deps(dependencies = TRUE)'
26 | - name: Install causalimages package
27 | run: R CMD INSTALL causalimages
28 | - name: Build backend
29 | run: Rscript -e 'causalimages::BuildBackend()'
30 | - name: Link repository for tests
31 | run: |
32 | mkdir -p "$HOME/Documents"
33 | ln -s "$GITHUB_WORKSPACE" "$HOME/Documents/causalimages-software"
34 | - name: Run tutorial tests
35 | run: Rscript tests/Test_AAARunAllTutorialsSuite.R
36 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .ipynb_checkpoints
3 | .DS_Store
4 | .history
5 | default.profraw
6 |
7 | build
8 | dist
9 | /misc/docker/binaries/
10 | /misc/docker/binaries/bin/
11 | /misc/docker/binaries/src/
12 | */.DS_Store
13 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2025 Connor Jerzak
2 |
3 | This file is part of causalimages-software.
4 |
5 | causalimages-software is free software: you can redistribute it and/or modify
6 | it under the terms of the GNU General Public License as published by
7 | the Free Software Foundation, either version 3 of the License, or (at
8 | your option) any later version.
9 |
10 | causalimages-software is distributed in the hope that it will be useful,
11 | but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | GNU General Public License for more details.
14 |
15 | For more information, see <https://www.gnu.org/licenses/>.
--------------------------------------------------------------------------------
/bin/micromamba:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cjerzak/causalimages-software/a6ca8dd8aae433c5207dedaa89e2c099d5700b8f/bin/micromamba
--------------------------------------------------------------------------------
/causalimages.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cjerzak/causalimages-software/a6ca8dd8aae433c5207dedaa89e2c099d5700b8f/causalimages.pdf
--------------------------------------------------------------------------------
/causalimages/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cjerzak/causalimages-software/a6ca8dd8aae433c5207dedaa89e2c099d5700b8f/causalimages/.DS_Store
--------------------------------------------------------------------------------
/causalimages/DESCRIPTION:
--------------------------------------------------------------------------------
1 | Package: causalimages
2 | Title: Causal Inference with Earth Observation, Bio-medical,
3 | and Social Science Images
4 | Version: 0.1
5 | Authors@R:
6 | c(person(given = "Connor",
7 | family = "Jerzak",
8 | role = c("aut", "cre"),
9 | email = "connor.jerzak@gmail.com",
10 | comment = c(ORCID = "0000-0003-1914-8905")),
11 | person(given = "Adel",
12 | family = "Daoud",
13 | role = "aut",
14 | comment = c(ORCID = "0000-0001-7478-8345")))
15 | Description: Provides a system for performing causal inference with earth observation,
16 | bio-medical, and social science images and image sequences (videos). The package
17 | uses a 'JAX' backend for GPU/TPU acceleration. Key functionalities include building
18 | conda-based backends (e.g., via 'BuildBackend'), implementing image-based confounder
19 | and heterogeneity analyses (e.g., 'AnalyzeImageConfounding', 'AnalyzeImageHeterogeneity'),
20 | and writing/reading large image corpora as '.tfrecord' files for use in training
21 | (via 'WriteTfRecord' and 'GetElementFromTfRecordAtIndices'). This allows researchers
22 | to scale causal inference to modern large-scale imagery data, bridging R with
23 | hardware-accelerated Python libraries. The package is partly based on Jerzak
24 | and Daoud (2023).
25 | URL: https://github.com/cjerzak/causalimages-software
26 | BugReports: https://github.com/cjerzak/causalimages-software/issues
27 | Depends: R (>= 3.3.3)
28 | License: GPL-3
29 | Encoding: UTF-8
30 | LazyData: false
31 | Imports:
32 | tensorflow,
33 | reticulate,
34 | geosphere,
35 | raster,
36 | rrapply,
37 | glmnet,
38 | sf,
39 | data.table,
40 | pROC
41 | Suggests:
42 | knitr,
43 | rmarkdown
44 | VignetteBuilder: knitr
45 | RoxygenNote: 7.3.2
46 | RemoteType: local
47 | RemotePkgRef: local::~/Documents/causalimages-software/causalimages
48 | RemoteUrl: /Users/cjerzak/Documents/causalimages-software/causalimages
49 |
--------------------------------------------------------------------------------
/causalimages/NAMESPACE:
--------------------------------------------------------------------------------
1 | # Generated by roxygen2: do not edit by hand
2 |
3 | export(AnalyzeImageConfounding)
4 | export(AnalyzeImageHeterogeneity)
5 | export(BuildBackend)
6 | export(GetAndSaveGeolocatedImages)
7 | export(GetElementFromTfRecordAtIndices)
8 | export(GetImageRepresentations)
9 | export(GetMoments)
10 | export(LongLat2CRS)
11 | export(PredictiveRun)
12 | export(TFRecordManagement)
13 | export(TrainDefine)
14 | export(TrainDo)
15 | export(WriteTfRecord)
16 | export(image2)
17 | export(message2)
18 | export(print2)
19 | import(raster)
20 | import(reticulate)
21 | import(rrapply)
22 |
--------------------------------------------------------------------------------
/causalimages/R/CI_BuildBackend.R:
--------------------------------------------------------------------------------
1 | #' Build the environment for CausalImages models. Builds a conda environment in which jax, tensorflow, tensorflow-probability, optax, equinox, and jmp are installed.
2 | #'
3 | #' @param conda_env (default = `"CausalImagesEnv"`) Name of the conda environment in which to place the backends.
4 | #' @param conda (default = `auto`) The path to a conda executable. Using `"auto"` allows reticulate to attempt to automatically find an appropriate conda binary.
5 |
6 | #' @return Builds the computational environment for `causalimages`. This function requires an Internet connection.
7 | #' You may find out a list of conda Python paths via: `system("which python")`
8 | #'
9 | #' @examples
10 | #' # For a tutorial, see
11 | #' # github.com/cjerzak/causalimages-software/
12 | #'
13 | #' @export
14 | #' @md
15 |
16 | BuildBackend <- function(conda_env = "CausalImagesEnv", conda = "auto") {
17 | # --- helpers ---------------------------------------------------------------
18 | os <- Sys.info()[["sysname"]]
19 | machine <- Sys.info()[["machine"]]
20 | msg <- function(...) message(sprintf(...))
21 |
22 | pip_install <- function(pkgs, ...) {
23 | reticulate::py_install(
24 | packages = pkgs,
25 | envname = conda_env,
26 | conda = conda,
27 | pip = TRUE,
28 | ...
29 | )
30 | TRUE
31 | }
32 |
33 | # Find the Python executable inside the target conda env (for manual pip calls)
34 | env_python_path <- function() {
35 | # try via conda_list
36 | cl <- try(reticulate::conda_list(), silent = TRUE)
37 | if (!inherits(cl, "try-error") && any(cl$name == conda_env)) {
38 | py <- cl$python[match(conda_env, cl$name)]
39 | if (length(py) == 1 && !is.na(py) && nzchar(py) && file.exists(py)) return(py)
40 | }
41 | # fallback via conda binary location
42 | cb <- try(reticulate::conda_binary(conda), silent = TRUE)
43 | prefix <- if (!inherits(cb, "try-error") && nzchar(cb)) dirname(dirname(cb)) else {
44 | # last-ditch default
45 | if (os == "Windows") "C:/Miniconda3" else file.path(Sys.getenv("HOME"), "miniconda3")
46 | }
47 | if (os == "Windows")
48 | file.path(prefix, "envs", conda_env, "python.exe")
49 | else
50 | file.path(prefix, "envs", conda_env, "bin", "python")
51 | }
52 |
53 | pip_install_from_findlinks <- function(spec, find_links) {
54 | py <- env_python_path()
55 | cmd <- sprintf(
56 | "%s -m pip install --upgrade --no-user -f %s %s",
57 | shQuote(py), shQuote(find_links), shQuote(spec)
58 | )
59 | res <- try(system(cmd, intern = TRUE), silent = TRUE)
60 | !inherits(res, "try-error")
61 | }
62 |
63 | # --- conda env -------------------------------------------------------------
64 | reticulate::conda_create(
65 | envname = conda_env,
66 | conda = conda,
67 | python_version = "3.13"
68 | )
69 |
70 | # Install numpy early to stabilize BLAS/ABI choices if needed
71 | pip_install("numpy")
72 |
73 | # --- JAX first: hardware-aware selection -----------------------------------
74 | install_jax <- function() {
75 | if (os == "Darwin" && machine %in% c("arm64", "aarch64")) {
76 | # Apple Silicon: Metal backend
77 | #msg("Apple Silicon detected: installing JAX (Metal).")
78 | #pip_install(c("jax==0.5.0", "jaxlib==0.5.0", "jax-metal==0.1.1"))
79 | pip_install(c("jax", "jaxlib"))
80 | return(invisible(TRUE))
81 | }
82 |
83 | if (identical(os, "Linux")) {
84 | # Query NVIDIA driver major version (e.g., '535.171.04' -> 535)
85 | drv <- try(suppressWarnings(
86 | system("nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -n1",
87 | intern = TRUE)
88 | ), silent = TRUE)
89 | drv_major <- suppressWarnings(as.integer(sub("^([0-9]+).*", "\\1", drv[1])))
90 | msg("Detected NVIDIA driver: %s", ifelse(length(drv) > 0, drv[1], "none/unknown"))
91 |
92 | # Prefer CUDA 13 when possible, fall back to CUDA 12, then CPU
93 | if (!is.na(drv_major) && drv_major >= 580) {
94 | msg("Driver >= 580: trying JAX CUDA 13 wheels.")
95 | ok <- try(pip_install("jax[cuda13]"), silent = TRUE)
96 | ok <- isTRUE(ok) && !inherits(ok, "try-error")
97 | if (!ok) {
98 | msg("CUDA 13 wheels failed; falling back to CUDA 12 extras.")
99 | ok <- try(pip_install("jax[cuda12]"), silent = TRUE)
100 | ok <- isTRUE(ok) && !inherits(ok, "try-error")
101 | }
102 | if (!ok) {
103 | msg("CUDA 12 extras failed; trying legacy 'cuda12_pip' via find-links.")
104 | ok <- pip_install_from_findlinks(
105 | "jax[cuda12_pip]",
106 | "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
107 | )
108 | }
109 | if (!ok) {
110 | msg("All CUDA wheel attempts failed; installing CPU-only JAX.")
111 | pip_install("jax")
112 | }
113 | } else if (!is.na(drv_major) && drv_major >= 525) {
114 | msg("Driver >= 525 and < 580: installing JAX CUDA 12 wheels.")
115 | ok <- try(pip_install("jax[cuda12]"), silent = TRUE)
116 | ok <- isTRUE(ok) && !inherits(ok, "try-error")
117 | if (!ok) {
118 | msg("CUDA 12 extras failed; trying legacy 'cuda12_pip' via find-links.")
119 | ok <- pip_install_from_findlinks(
120 | "jax[cuda12_pip]",
121 | "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
122 | )
123 | }
124 | if (!ok) {
125 | msg("CUDA wheels failed; installing CPU-only JAX.")
126 | pip_install("jax")
127 | }
128 | } else {
129 | msg("No suitable NVIDIA driver found (or too old); installing CPU-only JAX.")
130 | pip_install("jax")
131 | }
132 | return(invisible(TRUE))
133 | }
134 |
135 | # Other OSes: CPU-only JAX
136 | msg("Non-Linux or non-Apple-Silicon platform; installing CPU-only JAX.")
137 | pip_install("jax")
138 | }
139 |
140 | install_jax()
141 |
142 | # Optionally neutralize LD_LIBRARY_PATH within this env to avoid host overrides
143 | if (os == "Linux") {
144 | cb <- try(reticulate::conda_binary(conda), silent = TRUE)
145 | conda_prefix <- if (!inherits(cb, "try-error") && nzchar(cb)) dirname(dirname(cb)) else {
146 | file.path(Sys.getenv("HOME"), "miniconda3")
147 | }
148 | env_dir <- file.path(conda_prefix, "envs", conda_env)
149 | actdir <- file.path(env_dir, "etc", "conda", "activate.d")
150 | dir.create(actdir, recursive = TRUE, showWarnings = FALSE)
151 | try(writeLines("unset LD_LIBRARY_PATH", file.path(actdir, "00-unset-ld.sh")), silent = TRUE)
152 | }
153 |
154 | # --- Remaining packages (do NOT include 'jax' again to avoid downgrades) ----
155 | pip_install(c(
156 | "tensorflow",
157 | "optax",
158 | "torch",
159 | "transformers",
160 | "pillow",
161 | "tf-keras",
162 | "equinox",
163 | "jmp"
164 | ))
165 |
166 | # reinstall JAX again on Mac after forced upgrade to breaking version of JAX
167 | #if(os == "Darwin"){ install_jax() }
168 |
169 | done_msg <- sprintf("Done building causalimages backend (env '%s').", conda_env)
170 | if (exists("message2", mode = "function")) message2(done_msg) else message(done_msg)
171 | }
172 |
--------------------------------------------------------------------------------
/causalimages/R/CI_GetAndSaveGeolocatedImages.R:
--------------------------------------------------------------------------------
1 | #' Getting and saving geo-located images from a pool of .tif's
2 | #'
3 | #' A function that finds the image slice associated with the `long` and `lat` values, saves images by band (if `save_as = "csv"`) in save_folder.
4 | #'
5 | #' @param long Vector of numeric longitudes.
6 | #' @param lat Vector of numeric latitudes.
7 | #' @param keys The image keys associated with the long/lat coordinates.
8 | #' @param tif_pool A character vector specifying the fully qualified path to a corpus of .tif files.
9 | #' @param save_folder (default = `"."`) What folder should be used to save the output? Example: `"~/Downloads"`
10 | #' @param image_pixel_width An even integer specifying the pixel width (and height) of the saved images.
11 | #' @param save_as (default = `"csv"`) What format should the output be saved as? Only one option currently (`csv`, saved with a `.csv` extension)
12 | #' @param lyrs (default = NULL) Integer (vector) specifying the layers to be extracted. Default is for all layers to be extracted.
13 | #'
14 | #' @return Finds the image slice associated with the `long` and `lat` values, saves images by band (if `save_as = "csv"`) in save_folder.
15 | #' The save format is: `sprintf("%s/Key%s_BAND%s.csv", save_folder, keys[i], band_)`
16 | #'
17 | #' @examples
18 | #'
19 | #' # Example use (not run):
20 | #' #MASTER_IMAGE_POOL_FULL_DIR <- c("./LargeTifs/tif1.tif","./LargeTifs/tif2.tif")
21 | #' #GetAndSaveGeolocatedImages(
22 | #' #long = GeoKeyMat$geo_long,
23 | #' #lat = GeoKeyMat$geo_lat,
24 | #' #image_pixel_width = 500L,
25 | #' #keys = row.names(GeoKeyMat),
26 | #' #tif_pool = MASTER_IMAGE_POOL_FULL_DIR,
27 | #' #save_folder = "./Data/Uganda2000_processed",
28 | #' #save_as = "csv",
29 | #' #lyrs = NULL)
30 | #'
31 | #' @import raster
32 | #' @export
33 | #' @md
34 | #'
35 | GetAndSaveGeolocatedImages <- function(
36 | long,
37 | lat,
38 | keys,
39 | tif_pool,
40 | image_pixel_width = 256L,
41 | save_folder = ".",
42 | save_as = "csv",
43 | lyrs = NULL){
44 |
45 | library(raster); library(sf)
46 | RADIUS_CELLS <- (DIAMETER_CELLS <- image_pixel_width) / 2
47 | bad_indices <- c();observation_indices <- 1:length(long)
48 | counter_b <- 0; for(i in observation_indices){
49 | counter_b <- counter_b + 1
50 | if(counter_b %% 10 == 0){message(sprintf("At image %s of %s",counter_b,length(observation_indices)))}
51 | SpatialTarget_longlat <- c(long[i],lat[i])
52 |
53 | found_<-F;counter_ <- 0; while(found_ == F){
54 | counter_ <- counter_ + 1
55 | if(is.na(tif_pool[counter_])){ found_ <- bad_ <- T }
56 | if(!is.na(tif_pool[counter_])){
57 | MASTER_IMAGE_ <- try(raster::brick( tif_pool[counter_] ), T)
58 | SpatialTarget_utm <- LongLat2CRS(
59 | long = SpatialTarget_longlat[1],
60 | lat = SpatialTarget_longlat[2],
61 | CRS_ref = raster::crs(MASTER_IMAGE_))
62 |
63 | SpatialTargetCellLoc <- raster::cellFromXY(
64 | object = MASTER_IMAGE_,
65 | xy = SpatialTarget_utm)
66 | SpatialTargetRowCol <- raster::rowColFromCell(
67 | object = MASTER_IMAGE_,
68 | cell = SpatialTargetCellLoc )
69 | if(!is.na(sum(SpatialTargetRowCol))){found_<-T;bad_ <- F}
70 | if(counter_ > 1000000){stop("ERROR! Target not found anywhere in pool!")}
71 | }
72 | }
73 | if(bad_){
74 | message(sprintf("Failure at %s. Apparently, no .tif contains the reference point",i))
75 | bad_indices <- c(bad_indices,i)
76 | }
77 | if(!bad_){
78 | message(sprintf("Success at %s - Extracting & saving image!", i))
79 | # available rows/cols
80 | rows_available <- nrow( MASTER_IMAGE_ )
81 | cols_available <- ncol( MASTER_IMAGE_ )
82 |
83 |
84 | # define start row/col
85 | start_row <- SpatialTargetRowCol[1,"row"] - RADIUS_CELLS
86 | start_col <- SpatialTargetRowCol[1,"col"] - RADIUS_CELLS
87 |
88 | # find end row/col
89 | end_row <- start_row + DIAMETER_CELLS
90 | end_col <- start_col + DIAMETER_CELLS
91 |
92 | # perform checks to deal with spilling over image
93 | if(start_row <= 0){start_row <- 1}
94 | if(start_col <= 0){start_col <- 1}
95 | if(end_row > rows_available){ start_row <- rows_available - DIAMETER_CELLS }
96 | if(end_col > cols_available){ start_col <- cols_available - DIAMETER_CELLS }
97 |
98 | for(iof in 0:0){
99 | if(is.null(lyrs)){lyrs <- 1:dim(MASTER_IMAGE_)[3] }
100 | band_iters <- ifelse(grepl(x = save_as, pattern ="csv"),
101 | yes = list(lyrs), no = list(1L) )[[1]]
102 | for(band_ in band_iters){
103 | if(iof > 0){
104 | start_row <- sample(1:(nrow(MASTER_IMAGE_)-DIAMETER_CELLS-1),1)
105 | start_col <- sample(1:(ncol(MASTER_IMAGE_)-DIAMETER_CELLS-1),1)
106 | }
107 | SpatialTargetImage_ <- getValuesBlock(MASTER_IMAGE_[[band_]],
108 | row = start_row, nrows = DIAMETER_CELLS,
109 | col = start_col, ncols = DIAMETER_CELLS,
110 | format = "matrix", lyrs = 1L)
111 | if(length(unique(c(SpatialTargetImage_)))<5){ bad_indices <- c(bad_indices,i) }
112 | check_ <- dim(SpatialTargetImage_) - c(DIAMETER_CELLS,DIAMETER_CELLS)
113 | if(any(check_ < 0)){print("WARNING: CHECKS FAILED"); browser()}
114 | if(grepl(x = save_as, pattern ="tif")){
115 | # in progress
116 | }
117 | if(grepl(x = save_as, pattern ="csv")){
118 | if(iof == 0){
119 | data.table::fwrite(file = sprintf("%s/Key%s_BAND%s.csv",
120 | save_folder, keys[i], band_),
121 | data.table::as.data.table(SpatialTargetImage_))
122 | }
123 | }
124 | }
125 | }
126 | }
127 | }
128 | message("Done with GetAndSaveGeolocatedImages()!")
129 | }
130 |
--------------------------------------------------------------------------------
/causalimages/R/CI_GetMoments.R:
--------------------------------------------------------------------------------
1 | #' Get moments for normalization (internal function)
2 | #'
3 | #' An internal function for obtaining moments for channel normalization.
4 | #'
5 | #' @param iterator An iterator
6 | #' @param dataType A string denoting data type
7 | #' @param momentCalIters Number of minibatches with which to estimate moments
8 | #'
9 | #' @return Returns mean/sd arrays for normalization.
10 | #'
11 | #' @examples
12 | #' # (Not run)
13 | #' # GetMoments(iterator, dataType, image_dtype, momentCalIters = 34L)
14 | #' @export
15 | #' @md
16 | #'
17 | GetMoments <- function(iterator, dataType, image_dtype, momentCalIters = 34L){
18 | message2("Calibrating moments for input data normalization...")
19 | NORM_SD <- NORM_MEAN <- c(); for(momentCalIter in 1L:momentCalIters){
20 | # get a data batch
21 | ds_next_ <- try(iterator$get_next(),T)
22 |
23 | if(!"try-error" %in% class(ds_next_)){
24 | # setup normalizations
25 | ApplyAxis <- ifelse(dataType == "video", yes = 5, no = 4)
26 |
27 | # sanity check
28 | # causalimages::image2( cienv$np$array((ds_next_train[[1]])[2,,,1] )
29 |
30 | # update normalizations
31 | NORM_SD <- rbind(NORM_SD, apply(cienv$np$array(ds_next_[[1]]),ApplyAxis,sd))
32 | NORM_MEAN <- rbind(NORM_MEAN, apply(cienv$np$array(ds_next_[[1]]),ApplyAxis,mean))
33 | }
34 | }
35 |
36 | # mean calc
37 | NORM_MEAN_mat <- NORM_MEAN # same shape
38 | NORM_MEAN <- apply(NORM_MEAN,2,mean) # overall mean across all batches
39 | NORM_MEAN_array <- cienv$jnp$array(array(NORM_MEAN,dim=c(1,1,1,length(NORM_MEAN))),image_dtype)
40 |
41 | # SD calc using Rubin’s rule: combine within‐ and between‐batch variances
42 | NORM_SD_mat <- NORM_SD # matrix: rows = batches, cols = features
43 |
44 | #combine information to get
45 | m <- nrow(NORM_SD_mat)
46 | W <- colMeans(NORM_SD_mat^2) # average within‐batch variance
47 | B <- apply(NORM_MEAN_mat,2,var) # variance of the batch means
48 | T_var <- W + (1 + 1/m) * B # total variance
49 | NORM_SD <- sqrt(T_var) # combined SD
50 | # plot(apply(NORM_SD_mat,2,median),NORM_SD);abline(a=0,b=1)
51 | if("try-error" %in% class(NORM_SD)){browser()}
52 | NORM_SD_array <- cienv$jnp$array(array(NORM_SD,dim=c(1,1,1,length(NORM_SD))),image_dtype)
53 |
54 | if(dataType == "video"){
55 | NORM_MEAN_array <- cienv$jnp$expand_dims(NORM_MEAN_array, 0L)
56 | NORM_SD_array <- cienv$jnp$expand_dims(NORM_SD_array, 0L)
57 | }
58 |
59 | return(list("NORM_MEAN_array"=NORM_MEAN_array,
60 | "NORM_SD_array"=NORM_SD_array,
61 | "NORM_MEAN" = NORM_MEAN,
62 | "NORM_SD" = NORM_SD))
63 | }
64 |
--------------------------------------------------------------------------------
/causalimages/R/CI_InitializeJAX.R:
--------------------------------------------------------------------------------
1 | initialize_jax <- function(conda_env = "cienv",
2 | conda_env_required = TRUE,
3 | Sys.setenv_text = NULL) {
4 | message2("Establishing connection to computational environment (build via causalimages::BuildBackend())")
5 |
6 | library(reticulate)
7 | #library(tensorflow)
8 |
9 | # Load reticulate (Declared in Imports: in DESCRIPTION)
10 | reticulate::use_condaenv(condaenv = conda_env, required = conda_env_required)
11 |
12 | if(!is.null(Sys.setenv_text)){
13 | eval(parse(text = Sys.setenv_text), envir = .GlobalEnv)
14 | }
15 |
16 | # Import Python packages once, storing them in cienv
17 | if (!exists("jax", envir = cienv, inherits = FALSE)) {
18 | cienv$jax <- reticulate::import("jax")
19 | cienv$jnp <- reticulate::import("jax.numpy")
20 | cienv$flash_mha <- try(import("flash_attn_jax.flash_mha"),TRUE)
21 | cienv$tf <- reticulate::import("tensorflow")
22 | cienv$np <- reticulate::import("numpy")
23 | cienv$jmp <- reticulate::import("jmp")
24 | cienv$optax <- reticulate::import("optax")
25 | #cienv$oryx <- reticulate::import("tensorflow_probability.substrates.jax")
26 | cienv$eq <- reticulate::import("equinox")
27 | cienv$py_gc <- reticulate::import("gc")
28 | }
29 |
30 | # set memory growth for tensorflow
31 | for(device_ in cienv$tf$config$list_physical_devices()){ try(cienv$tf$config$experimental$set_memory_growth(device_, T),T) }
32 |
33 | # ensure tensorflow doesn't use GPU
34 | try(cienv$tf$config$set_visible_devices(list(), "GPU"), silent = TRUE)
35 |
36 | # Disable 64-bit computations
37 | cienv$jax$config$update("jax_enable_x64", FALSE)
38 | cienv$jaxFloatType <- cienv$jnp$float32
39 | }
40 |
41 | initialize_torch <- function(conda_env = "cienv",
42 | conda_env_required = TRUE,
43 | Sys.setenv_text = NULL) {
44 | # Import Python packages once, storing them in cienv
45 | if (!exists("torch", envir = cienv, inherits = FALSE)) {
46 | cienv$torch <- reticulate::import("torch")
47 | cienv$transformers <- reticulate::import("transformers")
48 | }
49 | }
50 |
--------------------------------------------------------------------------------
/causalimages/R/CI_TFRecordManagement.R:
--------------------------------------------------------------------------------
1 | #' Defines an internal TFRecord management routine (internal function)
2 | #'
3 | #' Defines the TFRecord management sequence used by TFRecordManagement(). Internal function.
4 | #'
5 | #' No parameters; operates on objects from the calling environment.
6 | #'
7 | #' @return Internal function defining a tfrecord management sequence.
8 | #'
9 | #' @import reticulate rrapply
10 | #' @export
11 | #' @md
12 | TFRecordManagement <- function(){
13 |
14 | if(is.null(file)){stop("No file specified for tfrecord!")}
15 | changed_wd <- F; if( !is.null( file ) ){
16 | message2("Establishing connection with tfrecord")
17 | tf_record_name <- file
18 | if( !grepl(tf_record_name, pattern = "/") ){
19 | tf_record_name <- paste("./",tf_record_name, sep = "")
20 | }
21 | tf_record_name <- strsplit(tf_record_name,split="/")[[1]]
22 | new_wd <- paste(tf_record_name[-length(tf_record_name)], collapse = "/")
23 | message2(sprintf("Temporarily re-setting the wd to %s", new_wd ) )
24 | changed_wd <- T; setwd( new_wd )
25 |
26 | # define video indicator
27 | useVideoIndicator <- dataType == "video"
28 |
29 | # define tf record
30 | tf_dataset <- cienv$tf$data$TFRecordDataset( tf_record_name[length(tf_record_name)] )
31 |
32 | # helper functions
33 | getParsed_tf_dataset_inference <- function(tf_dataset){
34 | dataset <- tf_dataset$map( function(x){parse_tfr_element(x,
35 | readVideo = useVideoIndicator,
36 | image_dtype = image_dtype_tf)} )
37 | return( dataset <- dataset$batch( ai(max(2L,round(batchSize/2L) ))) )
38 | }
39 |
40 | message2("Setting up iterators...") # skip the first test_size observations
41 | getParsed_tf_dataset_train_Select <- function( tf_dataset ){
42 | return( tf_dataset$map( function(x){ parse_tfr_element(x,
43 | readVideo = useVideoIndicator,
44 | image_dtype = image_dtype_tf)},
45 | num_parallel_calls = cienv$tf$data$AUTOTUNE) )
46 | }
47 | getParsed_tf_dataset_train_BatchAndShuffle <- function( tf_dataset ){
48 | tf_dataset <- tf_dataset$shuffle(buffer_size = cienv$tf$constant(ai(TfRecords_BufferScaler*batchSize),
49 | dtype=cienv$tf$int64),
50 | reshuffle_each_iteration = T)
51 | tf_dataset <- tf_dataset$batch( ai(batchSize) )
52 | tf_dataset <- tf_dataset$prefetch( cienv$tf$data$AUTOTUNE )
53 | return( tf_dataset )
54 | }
55 | if(!is.null(TFRecordControl)){
56 | tf_dataset_train_control <- getParsed_tf_dataset_train_Select(
57 | tf_dataset$skip( test_size <- ai( TFRecordControl$nTest ) )
58 | )$take( ai(TFRecordControl$nControl) )$`repeat`(-1L)
59 |
60 | tf_dataset_train_treated <- getParsed_tf_dataset_train_Select(
61 | tf_dataset$skip( test_size <- ai( TFRecordControl$nTest ) )
62 | )$skip( ai(TFRecordControl$nControl)+1L)$`repeat`(-1L)
63 |
64 | tf_dataset_train_treated <- getParsed_tf_dataset_train_BatchAndShuffle( tf_dataset_train_treated )
65 | tf_dataset_train_control <- getParsed_tf_dataset_train_BatchAndShuffle( tf_dataset_train_control )
66 |
67 | ds_iterator_train_treated <- reticulate::as_iterator( tf_dataset_train_treated )
68 | ds_iterator_train_control <- reticulate::as_iterator( tf_dataset_train_control )
69 | ds_iterator_train <- reticulate::as_iterator( tf_dataset_train_control )
70 | }
71 | if(is.null(TFRecordControl)){
72 | getParsed_tf_dataset_train <- function( tf_dataset ){
73 | dataset <- tf_dataset$map( function(x){ parse_tfr_element(x, readVideo = useVideoIndicator, image_dtype = image_dtype_tf)},
74 | num_parallel_calls = cienv$tf$data$AUTOTUNE)
75 | dataset <- dataset$shuffle(buffer_size = cienv$tf$constant(ai(TfRecords_BufferScaler*batchSize), dtype=cienv$tf$int64),
76 | reshuffle_each_iteration = FALSE) # set FALSE so same train/test split each re-initialization
77 | dataset <- dataset$batch( ai(batchSize) )
78 | dataset <- dataset$prefetch( cienv$tf$data$AUTOTUNE )
79 | return( dataset )
80 | }
81 |
82 | # shuffle (generating different train/test splits)
83 | tf_dataset <- cienv$tf$data$Dataset$shuffle( tf_dataset,
84 | buffer_size = cienv$tf$constant(ai(10L*TfRecords_BufferScaler*batchSize),
85 | dtype=cienv$tf$int64), reshuffle_each_iteration = F )
86 | tf_dataset_train <- getParsed_tf_dataset_train(
87 | tf_dataset$skip(test_size <- as.integer(round(testFrac * length(unique(imageKeysOfUnits)) )) ) )$`repeat`( -1L )
88 | ds_iterator_train <- reticulate::as_iterator( tf_dataset_train )
89 | }
90 |
91 | # define inference iterator
92 | tf_dataset_inference <- getParsed_tf_dataset_inference( tf_dataset )
93 | ds_iterator_inference <- reticulate::as_iterator( tf_dataset_inference )
94 |
95 | # Other helper functions
96 | getParsed_tf_dataset_train_Shuffle <- function( tf_dataset ){
97 | tf_dataset <- tf_dataset$shuffle(buffer_size = cienv$tf$constant(ai(TfRecords_BufferScaler*batchSize),
98 | dtype=cienv$tf$int64),
99 | reshuffle_each_iteration = FALSE )
100 | return(tf_dataset)
101 | }
102 | }
103 | }
--------------------------------------------------------------------------------
/causalimages/R/CI_TfRecordFxns.R:
--------------------------------------------------------------------------------
#' Write an image corpus as a .tfrecord file
#'
#' Writes an image corpus to a `.tfrecord` file for rapid reading of images into memory for fast ML training.
#' Specifically, this function serializes an image or video corpus into a `.tfrecord` file, enabling efficient data loading for machine learning tasks, particularly for image-based causal inference training.
#' It requires that users define an `acquireImageFxn` function that accepts keys and returns the corresponding image or video as an array of dimensions `(length(keys), nSpatialDim1, nSpatialDim2, nChannels)` for images or `(length(keys), nTimeSteps, nSpatialDim1, nSpatialDim2, nChannels)` for video sequences.
#'
#' @param file A character string naming a file for writing.
#' @param uniqueImageKeys A vector specifying the unique image keys of the corpus.
#' A key grabs an image/video array via `acquireImageFxn(key)`.
#' @param acquireImageFxn A function whose input is an observation keys and whose output is an array with dimensions `(length(keys), nSpatialDim1, nSpatialDim2, nChannels)` for images and `(length(keys), nTimeSteps, nSpatialDim1, nSpatialDim2, nChannels)` for image sequence data.
#' @param writeVideo (default = `FALSE`) Should we assume we're writing image sequence data of form batch by time by height by width by channels?
#' @param image_dtype (default = `"float16"`) Name of the TensorFlow dtype to which each image/video array is cast before serialization (e.g. `"float16"`, `"float32"`).
#' @param conda_env (default = `"CausalImagesEnv"`) A `conda` environment where computational environment lives, usually created via `causalimages::BuildBackend()`
#' @param conda_env_required (default = `TRUE`) A Boolean stating whether use of the specified conda environment is required.
#' @param Sys.setenv_text (default = `NULL`) Optional text forwarded to `initialize_jax()` when the backend has not yet been initialized.
#'
#' @return Writes a unique key-referenced `.tfrecord` from an image/video corpus for use in image-based causal inference training.
#'
#' @examples
#' # Example usage (not run):
#' #WriteTfRecord(
#' #  file = "./NigeriaConfoundApp.tfrecord",
#' #  uniqueImageKeys = 1:n,
#' #  acquireImageFxn = acquireImageFxn)
#'
#' @export
#' @md
WriteTfRecord <- function(file,
                          uniqueImageKeys,
                          acquireImageFxn,
                          writeVideo = FALSE,
                          image_dtype = "float16",
                          conda_env = "CausalImagesEnv",
                          conda_env_required = TRUE,
                          Sys.setenv_text = NULL){
  # Lazily initialize the Python backend (modules live in `cienv`) if not done yet.
  if(!"jax" %in% ls(envir = cienv)) {
    initialize_jax(conda_env = conda_env,
                   conda_env_required = conda_env_required,
                   Sys.setenv_text = Sys.setenv_text)
  }

  # Each key maps to exactly one serialized Example, so keys must be unique.
  if(length(uniqueImageKeys) != length(unique(uniqueImageKeys))){
    stop("Stopping because length(uniqueImageKeys) != length(unique(uniqueImageKeys)) \n
         Remember: Input to WriteTfRecord is uniqueImageKeys, not imageKeysOfUnits where redundancies may live")
  }

  # helper fxns
  message2("Initializing tfrecord helpers...")
  {
    # see https://towardsdatascience.com/a-practical-guide-to-tfrecords-584536bc786c

    # Wrap a serialized (byte-string) tensor as a tf.train BytesList feature.
    my_bytes_feature <- function(value){
      value <- value$numpy() # extract the raw bytes from the tensor
      return( cienv$tf$train$Feature(bytes_list = cienv$tf$train$BytesList(value = list(value))) )
    }

    # Wrap a scalar integer as a tf.train Int64List feature.
    my_int_feature <- function(value){
      return( cienv$tf$train$Feature(int64_list = cienv$tf$train$Int64List(value = list(value))) )
    }

    # Serialize an arbitrary tensor into a byte string.
    my_serialize_array <- function(array){ return( cienv$tf$io$serialize_tensor(array) ) }

    # Build one tf.train.Example holding the image/video tensor, its shape,
    # its 1-based position `index`, and its string `key`.
    parse_single_image <- function(image, index, key){
      if(!writeVideo){
        data <- dict(
          "height" = my_int_feature( image$shape[[1]] ),
          "width" = my_int_feature( image$shape[[2]] ),
          "channels" = my_int_feature( image$shape[[3]] ),
          "raw_image" = my_bytes_feature( my_serialize_array( image ) ),
          "index" = my_int_feature( index ),
          "key" = my_bytes_feature( my_serialize_array(key) ))
      } else {
        data <- dict(
          "time" = my_int_feature( image$shape[[1]] ),
          "height" = my_int_feature( image$shape[[2]] ),
          "width" = my_int_feature( image$shape[[3]] ),
          "channels" = my_int_feature( image$shape[[4]] ),
          "raw_image" = my_bytes_feature( my_serialize_array( image ) ),
          "index" = my_int_feature( index ),
          "key" = my_bytes_feature( my_serialize_array(key) ))
      }
      return( cienv$tf$train$Example( features = cienv$tf$train$Features(feature = data) ) )
    }
  }

  # for clarity, set file to tf_record_name; ensure a "./" prefix when no directory given
  message2("Starting save run...")
  tf_record_name <- file
  if( !grepl(tf_record_name, pattern = "/") ){
    tf_record_name <- paste0("./", tf_record_name)
  }

  # Open the writer from inside the target directory; restore the working
  # directory afterwards and, via on.exit, even if an error occurs mid-write.
  orig_wd <- getwd()
  on.exit(setwd(orig_wd), add = TRUE)
  tf_record_name <- strsplit(tf_record_name, split = "/")[[1]]
  new_wd <- paste(tf_record_name[-length(tf_record_name)], collapse = "/")
  setwd( new_wd )
  tf_record_writer <- cienv$tf$io$TFRecordWriter( tf_record_name[length(tf_record_name)] ) # writer that stores our data to disk
  setwd( orig_wd )

  # Resolve the tf dtype object (e.g. cienv$tf$float16) once, outside the loop.
  tf_image_dtype <- cienv$tf[[image_dtype]]

  for(irz in seq_along(uniqueImageKeys)){
    if(irz %% 10 == 0 || irz == 1){
      print2(sprintf("At index %s of %s [%.3f%%]",
                     irz, length(uniqueImageKeys),
                     100 * irz / length(uniqueImageKeys)))
    }
    # Acquire this key's array, cast to the target dtype, and serialize it.
    tf_record_write_output <- parse_single_image(
      image = r2const(acquireImageFxn( uniqueImageKeys[irz] ), tf_image_dtype),
      index = irz,
      key = as.character(uniqueImageKeys[irz]))
    tf_record_writer$write( tf_record_write_output$SerializeToString() )
  }
  print2("Finalizing tfrecords....")
  tf_record_writer$close()
  print2("Done writing tfrecord!")
}
122 |
123 | #!/usr/bin/env Rscript
#' Reads unique key indices from a `.tfrecord` file.
#'
#' Reads unique key indices from a `.tfrecord` file saved via a call to `causalimages::WriteTfRecord`.
#'
#' @usage
#'
#' GetElementFromTfRecordAtIndices(uniqueKeyIndices, filename, nObs, readVideo,
#'        conda_env, conda_env_required, image_dtype,
#'        iterator, return_iterator)
#'
#' @param uniqueKeyIndices (integer vector) Unique image indices to be retrieved from a `.tfrecord`
#' @param filename (character string) A character string stating the path to a `.tfrecord`
#' @param nObs (integer) Total number of records in the `.tfrecord`; used to bound the iterator.
#' @param readVideo (default = `F`) Boolean: read image-sequence (video) records rather than single images.
#' @param conda_env (Default = `NULL`) A `conda` environment where tensorflow v2 lives. Used only if a version of tensorflow is not already active.
#' @param conda_env_required (default = `F`) A Boolean stating whether use of the specified conda environment is required.
#' @param image_dtype (default = `"float16"`) Dtype name (or dtype object) the stored tensors were serialized with.
#' @param iterator (default = `NULL`) Optional list `list(dataset_iterator, last_index_0based)` from a previous call (with `return_iterator = T`) to resume iteration without re-opening the file.
#' @param return_iterator (default = `F`) If `T`, also return the iterator state for reuse in a subsequent call.
#'
#' @return Returns content from a `.tfrecord` associated with `uniqueKeyIndices`
#'
#' @examples
#' # Example usage (not run):
#' #GetElementFromTfRecordAtIndices(
#' #uniqueKeyIndices = 1:10,
#' #file = "./NigeriaConfoundApp.tfrecord")
#'
#' @export
#' @md
GetElementFromTfRecordAtIndices <- function(uniqueKeyIndices, filename, nObs, readVideo = F,
                                            conda_env = NULL, conda_env_required = F, image_dtype = "float16",
                                            iterator = NULL, return_iterator = F){
  # consider passing iterator as input to function to speed up large-batch execution

  # Resolve image_dtype: accept either a dtype name string ("float16") or a
  # dtype object (in which case its $name attribute is tried as a fallback).
  image_dtype_ <- try(eval(parse(text = sprintf("cienv$tf$%s",image_dtype))), T)
  if("try-error" %in% class(image_dtype_)){
    image_dtype_ <- try(eval(parse(text = sprintf("cienv$tf$%s",image_dtype$name))), T)
  }
  image_dtype <- image_dtype_

  # Fresh read: open the tfrecord from its containing directory and build a
  # parsed dataset. NOTE(review): setwd here is only restored at the end of the
  # function; an error mid-read leaves the working directory changed.
  if(is.null(iterator)){
    orig_wd <- getwd()
    tf_record_name <- filename
    if( !grepl(tf_record_name, pattern = "/") ){
      tf_record_name <- paste("./",tf_record_name, sep = "")
    }
    tf_record_name <- strsplit(tf_record_name,split="/")[[1]]
    new_wd <- paste(tf_record_name[-length(tf_record_name)],collapse = "/")
    setwd( new_wd )

    # Load the TFRecord file
    dataset = cienv$tf$data$TFRecordDataset( tf_record_name[length(tf_record_name)] )

    # Parse the tf.Example messages
    dataset <- dataset$map( function(x){ parse_tfr_element(x,
                                                           readVideo = readVideo,
                                                           image_dtype = image_dtype) }) # return

    # One slot per parsed element component (feature, index, key), each holding
    # a list with one entry per requested key index.
    index_counter <- last_in_ <- 0L
    return_list <- replicate(length( dataset$element_spec), {list(replicate(length(uniqueKeyIndices), list()))})
  }

  # Resumed read: reuse a previously returned iterator and its last position.
  if(!is.null(iterator)){
    dataset_iterator <- iterator[[1]]
    last_in_ <- iterator[[2]] # note: last_in_ is 0 indexed
    index_counter <- 0L
    return_list <- replicate(length( dataset_iterator$element_spec),
                             {list(replicate(length(uniqueKeyIndices), list()))})
  }

  # uniqueKeyIndices made 0 indexed
  uniqueKeyIndices <- as.integer( uniqueKeyIndices - 1L )

  # Walk the dataset in sorted index order, skipping unneeded records.
  for(in_ in (indices_sorted <- sort(uniqueKeyIndices))){
    index_counter <- index_counter + 1

    # Skip the first `uniqueKeyIndices` elements, shifted by current loc thru data set
    if( index_counter == 1 & is.null(iterator) ){
      dataset <- dataset$skip( as.integer(in_) )#$prefetch(buffer_size = 5L)
      dataset_iterator <- reticulate::as_iterator( dataset$take( as.integer(nObs - as.integer(in_) ) ))
      element <- dataset_iterator$`next`()
    }

    # Take the next element, then
    # Get the only element in the dataset (as a tuple of features)
    if(index_counter > 1 | !is.null(iterator)){
      # Burn through the gap between the previous and current requested index.
      needThisManyUnsavedIters <- (in_ - last_in_ - 1L)
      if(length(needThisManyUnsavedIters) > 0){ if(needThisManyUnsavedIters > 0){
        for(fari in 1:needThisManyUnsavedIters){ dataset_iterator$`next`() }
      } }
      element <- dataset_iterator$`next`()
    }
    last_in_ <- in_

    # form final output
    if(length(uniqueKeyIndices) == 1){ return_list <- element }
    if(length(uniqueKeyIndices) > 1){
      for(li_ in 1:length(element)){
        # Add a leading batch dim so per-key tensors can be concatenated below.
        return_list[[li_]][[index_counter]] <- cienv$tf$expand_dims(element[[li_]],0L)
      }
    }
    if(index_counter %% 5==0){ try(cienv$py_gc$collect(),T) }
  }

  # Concatenate the per-key tensors along the batch axis and, if the caller's
  # order was not sorted, gather back into the originally requested order.
  if(index_counter > 1){ for(li_ in 1:length(element)){
    return_list[[li_]] <- eval(parse(text =
                                       paste("cienv$tf$concat( list(", paste(paste("return_list[[li_]][[", 1:length(uniqueKeyIndices), "]]"),
                                                                             collapse = ","), "), 0L)", collapse = "") ))
    if( any(diff(uniqueKeyIndices)<0) ){ # re-order if needed
      return_list[[li_]] <- cienv$tf$gather(return_list[[li_]],
                                            indices = as.integer(match(uniqueKeyIndices,indices_sorted)-1L),
                                            axis = 0L)
    }
  }}

  if(is.null(iterator)){ setwd( orig_wd ) }

  # Optionally bundle the iterator state so a subsequent call can resume.
  if(return_iterator == T){
    return_list <- list(return_list, list(dataset_iterator, last_in_))
  }

  return( return_list )
}
241 |
# Parse one serialized tf.Example written by WriteTfRecord back into
# list(feature_tensor, index, key). The feature description must mirror the
# schema used at write time (image: height/width/channels; video adds time).
# Internal helper (not exported); requires an initialized `cienv` backend.
parse_tfr_element <- function(element, readVideo, image_dtype){
  #use the same structure as above; it's kinda an outline of the structure we now want to create

  # Resolve image_dtype: accept a dtype name string or a dtype object
  # (falling back to its $name attribute).
  image_dtype_ <- try(eval(parse(text = sprintf("cienv$tf$%s",image_dtype))), T)
  if("try-error" %in% class(image_dtype_)){
    image_dtype_ <- try(eval(parse(text = sprintf("cienv$tf$%s",image_dtype$name))), T)
  }
  image_dtype <- image_dtype_

  dict_init_val <- list()
  if(!readVideo){
    im_feature_description <- dict(
      'height' = cienv$tf$io$FixedLenFeature(dict_init_val, cienv$tf$int64),
      'width' = cienv$tf$io$FixedLenFeature(dict_init_val, cienv$tf$int64),
      'channels' = cienv$tf$io$FixedLenFeature(dict_init_val, cienv$tf$int64),
      'raw_image' = cienv$tf$io$FixedLenFeature(dict_init_val, cienv$tf$string),
      'index' = cienv$tf$io$FixedLenFeature(dict_init_val, cienv$tf$int64),
      'key' = cienv$tf$io$FixedLenFeature(dict_init_val, cienv$tf$string)
    )
  }

  if(readVideo){
    im_feature_description <- dict(
      'time' = cienv$tf$io$FixedLenFeature(dict_init_val, cienv$tf$int64),
      'height' = cienv$tf$io$FixedLenFeature(dict_init_val, cienv$tf$int64),
      'width' = cienv$tf$io$FixedLenFeature(dict_init_val, cienv$tf$int64),
      'channels' = cienv$tf$io$FixedLenFeature(dict_init_val, cienv$tf$int64),
      'raw_image' = cienv$tf$io$FixedLenFeature(dict_init_val, cienv$tf$string),
      'index' = cienv$tf$io$FixedLenFeature(dict_init_val, cienv$tf$int64),
      'key'= cienv$tf$io$FixedLenFeature(dict_init_val, cienv$tf$string)
    )
  }

  # parse tf record
  content <- cienv$tf$io$parse_single_example(element,
                                              im_feature_description)

  # get 'feature' (e.g., image/image sequence)
  feature <- cienv$tf$io$parse_tensor( content[['raw_image']],
                                       out_type = image_dtype )

  # get the key
  key <- cienv$tf$io$parse_tensor( content[['key']],
                                   out_type = cienv$tf$string )

  # and reshape it appropriately using the shape fields stored alongside the bytes
  if(!readVideo){
    feature = cienv$tf$reshape( feature, shape = c(content[['height']],
                                                   content[['width']],
                                                   content[['channels']]) )
  }
  if(readVideo){
    feature = cienv$tf$reshape( feature, shape = c(content[['time']],
                                                   content[['height']],
                                                   content[['width']],
                                                   content[['channels']]) )
  }

  return( list(feature, content[['index']], key) )
}
301 |
--------------------------------------------------------------------------------
/causalimages/R/CI_TrainDefine.R:
--------------------------------------------------------------------------------
#' Defines an internal training routine (internal function)
#'
#' Defines trainers defined in TrainDefine(). Internal function.
#'
#' NOTE(review): this function references names not among its parameters
#' (`nSGD`, `learningRateMax`, `ModelList`, `cienv`) and creates bindings
#' (`opt_state`, `jit_apply_updates`, `jit_get_update`) that its callers use —
#' presumably it is evaluated in an environment where those are bound; confirm
#' against the calling code.
#'
#' @param . No parameters.
#'
#' @return Internal function defining a training sequence.
#'
#' @import reticulate rrapply
#' @export
#' @md
TrainDefine <- function(){
  message2("Define optimizer and training step...")
  {
    # Warmup-cosine-decay learning-rate schedule: ramps to learningRateMax over
    # the first 5% of steps, then decays back to learningRateMax/100.
    LR_schedule <- cienv$optax$warmup_cosine_decay_schedule(
      warmup_steps = (nWarmup <- (0.05*nSGD)),
      decay_steps = nSGD-nWarmup,
      init_value = learningRateMax/100,
      peak_value = learningRateMax,
      end_value = learningRateMax/100)
    plot(cienv$np$array(LR_schedule(cienv$jnp$array(1:nSGD))),xlab = "Iteration", ylab="Learning rate schedule")
    # Optimizer: adaptive gradient clipping chained with AdaBelief.
    optax_optimizer <- cienv$optax$chain(
      cienv$optax$adaptive_grad_clip(clipping = 0.25, eps = 0.001),
      cienv$optax$adabelief( learning_rate = LR_schedule )
    )
    # NOTE(review): this second plot duplicates the schedule plot above with a
    # different y-label — likely a leftover; confirm whether both are intended.
    plot(cienv$np$array(LR_schedule(cienv$jnp$array(1:nSGD))), xlab = "Iteration", ylab = "Learning rate")

    # model partition, setup state, perform parameter count
    opt_state <- optax_optimizer$init( cienv$eq$partition(ModelList, cienv$eq$is_array)[[1]] )
    message2(sprintf("Total trainable parameter count: %s",
                     nParamsRep <- nTrainable <-
                       sum(unlist(lapply(cienv$jax$tree$leaves(cienv$eq$partition(ModelList,
                                                                                  cienv$eq$is_array)[[1]]), function(zer){zer$size})))))

    # jit update fxns
    jit_apply_updates <- cienv$eq$filter_jit( cienv$optax$apply_updates )
    jit_get_update <- cienv$eq$filter_jit( optax_optimizer$update )
  }
}
40 |
41 |
--------------------------------------------------------------------------------
/causalimages/R/CI_TrainDo.R:
--------------------------------------------------------------------------------
#' Runs a training routine (internal function)
#'
#' Runs trainers defined in TrainDefine(). Internal function.
#'
#' NOTE(review): like TrainDefine(), this function reads and writes many names
#' not among its parameters (`imageKeysOfUnits`, `nSGD`, `TFRecordControl`,
#' `ds_iterator_train`, `ModelList`, `StateList`, `MPList`, `opt_state`, ...);
#' presumably it is evaluated in an environment where those are bound — confirm
#' against the calling code.
#'
#' @param . No parameters.
#'
#' @return Internal function performing model training.
#'
#' @import reticulate rrapply
#' @export
#' @md
TrainDo <- function(){
  par(mfrow=c(1,2))
  # Map each unique image key to the row indices of all units sharing that key.
  keys2indices_list <- tapply(1:length(imageKeysOfUnits), imageKeysOfUnits, c)
  GradNorm_vec <- loss_vec <- rep(NA,times=nSGD)
  keysUsedInTraining <- c();i_<-1L ; DoneUpdates <- 0L; for(i in i_:nSGD){
    # Periodic garbage collection on both the R and Python sides.
    t0 <- Sys.time(); if(i %% 5 == 0 | i == 1){gc(); cienv$py_gc$collect()}

    # --- Batch acquisition: single-iterator path (no treatment balancing). ---
    if(is.null(TFRecordControl)){
      # get next batch
      ds_next_train <- ds_iterator_train$`next`()

      # if we run out of observations, reset iterator
      RestartedIterator <- FALSE; if( is.null(ds_next_train) ){
        message2("Re-setting iterator! (type 1)"); gc(); cienv$py_gc$collect()
        ds_iterator_train <- reticulate::as_iterator( tf_dataset_train )
        ds_next_train <- ds_iterator_train$`next`(); gc();cienv$py_gc$collect()
      }

      # get a new batch if size mismatch - size mismatches generate new cached compiled fxns
      if(!RestartedIterator){ if(dim(ds_next_train[[1]])[1] != batchSize){
        message2("Re-setting iterator! (type 2)"); gc(); cienv$py_gc$collect()
        ds_iterator_train <- reticulate::as_iterator( tf_dataset_train )
        ds_next_train <- ds_iterator_train$`next`(); gc(); cienv$py_gc$collect()
      } }

      # select batch indices based on keys
      batch_keys <- unlist( lapply( p2l(ds_next_train[[3]]$numpy()), as.character) )
      batch_indices <- sapply(batch_keys,function(key_){ f2n( sample(as.character( keys2indices_list[[key_]] ), 1) ) })
      ds_next_train <- ds_next_train[[1]]
    }
    # --- Batch acquisition: paired control/treated iterators (balanced). ---
    if(!is.null(TFRecordControl)){
      # get next batch
      ds_next_train_control <- ds_iterator_train_control$`next`()

      # if we run out of observations, reset iterator
      RestartedIterator <- FALSE; if( is.null(ds_next_train_control) ){
        message2("Re-setting iterator! (type 1)"); gc(); cienv$py_gc$collect()
        ds_iterator_train_control <- reticulate::as_iterator( tf_dataset_train_control )
        ds_next_train_control <- ds_iterator_train_control$`next`(); gc();cienv$py_gc$collect()
      }

      # get a new batch if size mismatch - size mismatches generate new cached compiled fxns
      if(!RestartedIterator){ if(dim(ds_next_train_control[[1]])[1] != batchSize){
        message2("Re-setting iterator! (type 2)"); gc(); cienv$py_gc$collect()
        ds_iterator_train_control <- reticulate::as_iterator( tf_dataset_train_control )
        ds_next_train_control <- ds_iterator_train_control$`next`(); gc(); cienv$py_gc$collect()
      } }

      # get next batch
      ds_next_train_treated <- ds_iterator_train_treated$`next`()

      # if we run out of observations, reset iterator
      RestartedIterator <- F; if( is.null(ds_next_train_treated) ){
        message2("Re-setting iterator! (type 1)"); gc(); cienv$py_gc$collect()
        ds_iterator_train_treated <- reticulate::as_iterator( tf_dataset_train_treated )
        ds_next_train_treated <- ds_iterator_train_treated$`next`(); gc();cienv$py_gc$collect()
      }

      # get a new batch if size mismatch - size mismatches generate new cached compiled fxns
      if(!RestartedIterator){ if(dim(ds_next_train_treated[[1]])[1] != batchSize){
        message2("Re-setting iterator! (type 2)"); gc(); cienv$py_gc$collect()
        ds_iterator_train_treated <- reticulate::as_iterator( tf_dataset_train_treated )
        ds_next_train_treated <- ds_iterator_train_treated$`next`(); gc(); cienv$py_gc$collect()
      } }

      # select batch indices based on keys (control first, then treated)
      batch_keys <- c(unlist( lapply( p2l(ds_next_train_control[[3]]$numpy()), as.character) ),
                      unlist( lapply( p2l(ds_next_train_treated[[3]]$numpy()), as.character) ))
      batch_indices <- sapply(batch_keys,function(key_){ f2n( sample(as.character( keys2indices_list[[key_]] ), 1) ) })
      ds_next_train <- cienv$tf$concat(list(ds_next_train_control[[1]],
                                            ds_next_train_treated[[1]]), 0L)
    }
    # Track which keys have been seen in training so far.
    if(any(!batch_indices %in% keysUsedInTraining)){
      keysUsedInTraining <- c(keysUsedInTraining, batch_keys[!batch_keys %in% keysUsedInTraining])
    }

    # if no treat, define it (unused in GetLoss)
    if(!"obsW" %in% ls()){ obsW <- obsY }

    # training step
    if(!justCheckIterators){
      t1 <- Sys.time()
      # One un-jitted forward pass on the first iteration (warm-up / sanity).
      if(i == 1){
        message2("Initial forward pass...")
        GetLoss(
          MPList[[1]]$cast_to_compute(ModelList), # model list
          MPList[[1]]$cast_to_compute(ModelList_fixed), # model list
          InitImageProcessFn(cienv$jnp$array(ds_next_train), cienv$jax$random$key(600L+i), inference = F), # m
          #InitImageProcessFn(cienv$jnp$array(ds_next_train), cienv$jax$random$key(600L), inference = FALSE), # m
          cienv$jnp$array(ifelse( !is.null(X), yes = list(X[batch_indices,]), no = list(1.))[[1]] , dtype = cienv$jnp$float16), # x
          cienv$jnp$array(as.matrix(obsW[batch_indices]), dtype = cienv$jnp$float16), # treat
          cienv$jnp$array(as.matrix(obsY[batch_indices]), dtype = cienv$jnp$float16), # y
          cienv$jax$random$split(cienv$jax$random$key( 500L+i ),length(batch_indices)), # vseed for observations
          StateList, # StateList
          MPList, # MPlist
          FALSE)[[1]]
      }

      # sanity check (manual debugging block; never executed)
      if(FALSE){
        test_index <- 2
        GetElementFromTfRecordAtIndices(
          uniqueKeyIndices = which(unique(imageKeysOfUnits)==unique(imageKeysOfUnits)[test_index]),
          filename = file,
          readVideo = useVideoIndicator,
          nObs = length(unique(imageKeysOfUnits) ) )
        # unique(imageKeysOfUnits)[test_index]
      }

      # Sanity check for dimension swapping as i varies
      if(i == 1){
        message2(sprintf("Training balance: %s",
                         paste(paste(names(table(obsW[batch_indices])),
                                     table(obsW[batch_indices]), sep = " has "),collapse="; ")
        ))
        if(any(prop.table(table(obsW[batch_indices])) > 0.9) &
           !is.null(TFRecordControl)){
          stop( "Stopping - Balanced training not satisfied despite TFRecordControl being defined!" )
        }
      }
      # causalimages::image2(cienv$np$array(InitImageProcessFn(cienv$jnp$array(ds_next_train), cienv$jax$random$key(600L+sample(1:100,1)), inference = F)[2,,,1]))
      # causalimages::image2(cienv$np$array(InitImageProcessFn(cienv$jnp$array(ds_next_train), cienv$jax$random$key(600L+sample(1:100,1)), inference = F)[1,,,1]))

      # Get gradient update packages
      GradientUpdatePackage <- GradAndLossAndAux(
        MPList[[1]]$cast_to_compute(ModelList), MPList[[1]]$cast_to_compute(ModelList_fixed), # model lists
        InitImageProcessFn(cienv$jnp$array(ds_next_train), cienv$jax$random$key(600L+i), inference = FALSE), # m
        cienv$jnp$array(ifelse( !is.null(X), yes = list(X[batch_indices,]), no = list(1.))[[1]], dtype = ComputeDtype), # x
        cienv$jnp$array(as.matrix(obsW[batch_indices]), dtype = ComputeDtype), # treat (unused for prediction only runs)
        cienv$jnp$array(as.matrix(obsY[batch_indices]), dtype = ComputeDtype), # y
        cienv$jax$random$split(cienv$jax$random$key( 50L+i ),length(batch_indices)), # vseed for observations
        StateList, # StateList
        MPList, # MPlist
        FALSE) # inference

      # perform gradient updates
      {
        # get updated state
        StateList_tmp <- GradientUpdatePackage[[1]][[2]] # state

        # get loss + grad (unscale the loss when training in float16 mixed precision)
        if(image_dtype_char == "float16"){
          loss_vec[i] <- myLoss_fromGrad <- cienv$np$array( MPList[[2]]$unscale( GradientUpdatePackage[[1]][[1]] ) )# value
        }
        if(image_dtype_char != "float16"){
          loss_vec[i] <- myLoss_fromGrad <- cienv$np$array( GradientUpdatePackage[[1]][[1]] )# value
        }
        GradientUpdatePackage <- GradientUpdatePackage[[2]] # grads
        GradientUpdatePackage <- cienv$eq$partition(GradientUpdatePackage, cienv$eq$is_inexact_array)
        GradientUpdatePackage_aux <- GradientUpdatePackage[[2]]; GradientUpdatePackage <- GradientUpdatePackage[[1]]

        # unscale + adjust loss scale is some non-finite or NA
        if(i == 1){
          # Jitted helpers built once: NaN->0 map, per-leaf mean-abs norms, finiteness check.
          Map2Zero <- cienv$eq$filter_jit(function(input){
            cienv$jax$tree$map(function(x){ cienv$jnp$where(cienv$jnp$isnan(x), cienv$jnp$array(0), x)}, input) })
          GetGetNorms <- cienv$eq$filter_jit(function(input){
            cienv$jax$tree$map(function(x){ cienv$jnp$mean(cienv$jnp$abs(x)) }, input) })
          AllFinite <- cienv$jax$jit( cienv$jmp$all_finite )
        }
        if(image_dtype_char == "float16"){
          GradientUpdatePackage <- Map2Zero( MPList[[2]]$unscale( GradientUpdatePackage ) )
        }
        AllFinite_DontAdjust <- AllFinite( GradientUpdatePackage ) &
          cienv$jnp$squeeze(cienv$jnp$array(!is.infinite(myLoss_fromGrad)))
        # MPList[[2]]$adjust( cienv$jnp$array(FALSE) )
        # MPList[[2]]$adjust( cienv$jnp$array(TRUE) )
        MPList[[2]] <- MPList[[2]]$adjust( AllFinite_DontAdjust )
        # which(is.na( c(unlist(lapply(cienv$jax$tree$leaves(myGrad_jax), function(zer){cienv$np$array(zer)}))) ) )
        # which(is.infinite( c(unlist(lapply(cienv$jax$tree$leaves(myGrad_jax), function(zer){cienv$np$array(zer)}))) ) )

        # get update norm
        GradNorm_vec[i] <- mean( GradVec <- unlist( lapply(cienv$jax$tree$leaves(GradientUpdatePackage),
                                                           function(zer){ cienv$np$array(cienv$jnp$mean(cienv$jnp$abs(zer) )) }) ) )

        # update parameters if finite gradients
        DoUpdate <- !is.na(myLoss_fromGrad) & cienv$np$array(AllFinite_DontAdjust) &
          !is.infinite(myLoss_fromGrad) & ( GradNorm_vec[i] > 1e-10)
        if( !DoUpdate ){
          message2("Warning: Not updating parameters due to NA, zero, or non-finite gradients in mixed-precision training...")
        }
        if( DoUpdate ){
          DoneUpdates <- DoneUpdates + 1

          # cast updates to param
          GradientUpdatePackage <- MPList[[1]]$cast_to_param( GradientUpdatePackage )

          # get gradient updates
          # GradientUpdatePackage$SpatialTransformer$ResidualWts # check non-zero gradients
          # GradientUpdatePackage$SpatialTransformer$TransformerRenormer # -> check non-zero gradients here (indicates problem with dropout)
          GradientUpdatePackage <- jit_get_update(
            updates = GradientUpdatePackage,
            state = opt_state,
            params = (cienv$eq$partition(ModelList, cienv$eq$is_inexact_array)[[1]] )
          )

          # manual debugging block; never executed
          if(FALSE){
            # Before the jit_get_update call, add:
            params_tree = cienv$eq$partition(ModelList, cienv$eq$is_array)[[1]]
            grads_tree = GradientUpdatePackage

            # Print tree structures
            message2("Params tree structure:")
            print(cienv$jax$tree$structure(params_tree))
            message2("Grads tree structure:")
            print(cienv$jax$tree$structure(grads_tree))

            # Check for None values
            params_leaves = cienv$jax$tree$leaves(params_tree)
            grads_leaves = cienv$jax$tree$leaves(grads_tree)
            message2(sprintf("Params leaves: %d, Grads leaves: %d",
                             length(params_leaves), length(grads_leaves)))
          }

          # separate updates from state
          opt_state <- GradientUpdatePackage[[2]]
          GradientUpdatePackage <- cienv$eq$combine(GradientUpdatePackage[[1]],
                                                    GradientUpdatePackage_aux)

          # perform updates
          #ModelList_tminus1 <- ModelList
          ModelList <- cienv$eq$combine( jit_apply_updates(
            params = cienv$eq$partition(ModelList, cienv$eq$is_inexact_array)[[1]],
            updates = GradientUpdatePackage),
            cienv$eq$partition(ModelList, cienv$eq$is_inexact_array)[[2]])
          StateList <- StateList_tmp
          # manual debugging block; never executed
          if(FALSE){
            LayerWiseParamDiff <- function(params_new, params_old){
              diff_fn <- function(new, old){
                cienv$jnp$mean(cienv$jnp$abs(new - old))
              }
              diff_list <- cienv$jax$tree$map(diff_fn, params_new, params_old)
              rrapply::rrapply(diff_list, how = "flatten")
            }

            # Use the function right after parameters update:
            param_diffs <- LayerWiseParamDiff(
              cienv$eq$partition(ModelList, cienv$eq$is_array)[[1]],
              cienv$eq$partition(ModelList_tminus1, cienv$eq$is_array)[[1]]
            )
            GradientUpdatePackage
          }
          suppressWarnings( rm(StateList_tmp, GradientUpdatePackage,BNInfo) )
        }
        # Periodic progress report + loss/grad-norm plots.
        i_ <- i ; if( (i %% 25 == 0 | i < 10) &
                      (length(loss_vec[!is.na(loss_vec) & !is.infinite(loss_vec)]) > 5) ){
          loss_vec_ <- loss_vec
          loss_vec_[is.infinite(loss_vec_)] <- NA
          message2(sprintf("SGD iteration %s of %s -- Loss: %.2f (%.1f%%) --
                  Total iter time (s): %.2f - Grad iter time (s): %.2f --
                  Grad norm: %.3f -- Grads zero %%: %.1f%% --
                  %.3f tstat on log(iter)",
                           i, nSGD, loss_vec[i], 100*mean(loss_vec[i] <= loss_vec[1:i],na.rm=T),
                           (Sys.time() - t0)[[1]], (Sys.time() - t1)[[1]],
                           mean(GradVec,na.rm=T), 100*mean(GradVec==0,na.rm=T),
                           ifelse("try-error" %in% class(tstat_ <- try(coef(summary(lm(loss_vec[1:i]~log(1:i))))[2,3], T)),yes = NA, no = tstat_)
          ) )
          loss_vec <- f2n(loss_vec); loss_vec[is.infinite(loss_vec)] <- NA
          plot( (na.omit(loss_vec)), cex.main = 0.95,ylab = "Loss Function",xlab="SGD Iteration Number")
          if(length(na.omit(loss_vec)) > 10){ points(smooth.spline( (na.omit(loss_vec) ),spar=1,cv=TRUE), col="red",type = "l",lwd=5) }
          plot(GradNorm_vec[!is.infinite(GradNorm_vec) & !is.na(GradNorm_vec)], cex.main = 0.95,ylab = "GradNorm",xlab="SGD Iteration Number")
        }

        # Early stopping

        if( !is.null(earlyStopThreshold) ){
          # Compare the trailing loss window to the preceding one; stop after
          # `patience_limit` consecutive no-improvement checks (only once the
          # loss has already dropped below 80% of its initial level).
          window <- 25
          patience_limit <- 25
          if(!"patience_counter" %in% ls()){ patience_counter <- 0 }
          if( i > 2*window & i > 100 ){
            first_avg <- mean(loss_vec[1:10], na.rm = TRUE)
            prev_avg <- mean(loss_vec[(i-2*window):(i-window-1)], na.rm = TRUE)
            curr_avg <- mean(loss_vec[(i-window):i], na.rm = TRUE)

            se_diff <- sqrt( var(loss_vec[(i-2*window):(i-window-1)], na.rm=TRUE)/window +
                               var(loss_vec[(i-window):i], na.rm=TRUE)/window )
            prev_avg_upper <- curr_avg + (t_es<-2.528)*sqrt( var(loss_vec[(i-2*window):(i-window-1)], na.rm=TRUE)/window )
            curr_avg_lower <- prev_avg - t_es*sqrt( var(loss_vec[(i-window):i], na.rm=TRUE)/window )

            if( curr_avg >= prev_avg - t_es*se_diff & curr_avg < 0.8*first_avg ){
              message2("We fail to detect evidence of improvement, early stopping being considered...")
              patience_counter <- patience_counter + 1
              if(patience_counter >= patience_limit){
                message2("Early stopping triggered - No more meaningful improvement.")
                break
              }
            } else {
              patience_counter <- 0 # reset when any improvement detected
            }
          }
        }
      }
    }
  } # end for(i in i_:nSGD){
  par(mfrow=c(1,1))
}
307 |
308 |
--------------------------------------------------------------------------------
/causalimages/R/CI_helperFxns.R:
--------------------------------------------------------------------------------
#' Get the spatial point of long/lat coordinates
#'
#' Projects a longitude/latitude pair (WGS84, EPSG:4326) into another
#' coordinate reference system (CRS).
#'
#' @param long Vector of numeric longitudes.
#' @param lat Vector of numeric latitudes.
#' @param CRS_ref A CRS into which the long-lat point should be projected.
#'
#' @return Numeric vector of length two giving the coordinates of the supplied
#' location in the CRS defined by `CRS_ref`.
#'
#' @examples
#' # (Not run)
#' #spatialPt <- LongLat2CRS(long = 49.932,
#' #                         lat = 35.432,
#' #                         CRS_ref = sf::st_crs("+proj=lcc +lat_1=48 +lat_2=33 +lon_0=-100 +ellps=WGS84"))
#' @export
#' @md
#'
LongLat2CRS <- function(long, lat, CRS_ref){
  # Build an sf point in WGS84 (EPSG:4326) from the supplied coordinates.
  pts_wgs84 <- sf::st_as_sf(
    data.frame(long = as.numeric(long), lat = as.numeric(lat)),
    coords = c("long", "lat"),
    crs = 4326
  )
  # Reproject into the target CRS and extract the first point's coordinates.
  pts_projected <- sf::st_transform(pts_wgs84, crs = sf::st_crs(CRS_ref))
  sf::st_coordinates(pts_projected)[1, ]
}
31 |
# Build a square raster extent of roughly `target_km_diameter` km centered on a
# long/lat point, expressed in the target CRS `CRS_ref`.
LongLat2CRS_extent <- function(point_longlat,
                               CRS_ref,
                               target_km_diameter = 10){
  # Half the diameter converted from km to degrees (~111 km per degree).
  half_offset <- 1/111 * (target_km_diameter/2)
  lon0 <- as.numeric(point_longlat[1])
  lat0 <- as.numeric(point_longlat[2])
  # Opposite corners of the bounding box around the center point.
  point_longlat1 <- c(long = lon0 - half_offset, lat = lat0 - half_offset)
  point_longlat2 <- c(long = lon0 + half_offset, lat = lat0 + half_offset)
  corners <- sf::st_as_sf(rbind(point_longlat1, point_longlat2),
                          coords = c("long", "lat"), crs = 4326)
  corner_coords <- sf::st_coordinates(sf::st_transform(corners, crs = sf::st_crs(CRS_ref)))
  raster::extent(min(corner_coords[, 1]), max(corner_coords[, 1]),
                 min(corner_coords[, 2]), max(corner_coords[, 2]))
}
48 |
# Converts a python builtin to a list: a bare python bytes object is wrapped in
# a one-element list so downstream lapply() calls always iterate over a list;
# anything else is returned unchanged.
p2l <- function(zer){
  # inherits() is the idiomatic class-membership test (vs "x" %in% class(y)).
  if(inherits(zer, "python.builtin.bytes")){ zer <- list(zer) }
  return( zer )
}
54 |
# Zips two equal-length lists element-wise:
# rzip(l1, l2)[[i]] == list(l1[[i]], l2[[i]]).
# Uses Map() instead of growing a list inside a 1:length() loop, which also
# fixes the empty-input case (1:0 iterates over c(1, 0) and errored before).
# unname() preserves the original's unnamed output even when l1 has names.
rzip <- function(l1, l2){ unname(Map(list, l1, l2)) }
57 |
# reshapes a 5-d tensor to 2-d: keeps the first (batch) dimension and flattens
# dimensions 2:5 into one. DEPRECATED (function name keeps the original
# spelling); requires an initialized `cienv` TensorFlow backend.
reshape_fxn_DEPRECIATED <- function(input_){
  ## DEPRECIATED
  cienv$tf$reshape(input_, list(cienv$tf$shape(input_)[1],
                                cienv$tf$reduce_prod(cienv$tf$shape(input_)[2:5])))
}
64 |
# Formats numbers as strings with the decimal part padded out to `roundAt`
# digits (e.g. 2 -> "2.00", 2.5 -> "2.50"). Decimal parts already longer than
# `roundAt` are left untouched (the original errored here: rep() was called
# with a negative `times`). Never rounds or truncates.
fixZeroEndings <- function(zr, roundAt = 2){
  # A string of n zeros; "" when n == 0 (rep with times = 0 gives character(0)).
  pad0 <- function(n) paste(rep("0", times = n), collapse = "")
  unlist(lapply(strsplit(as.character(zr), split = "\\."), function(parts){
    if(length(parts) == 1){
      # No decimal point: append "." followed by roundAt zeros.
      return(paste(parts, pad0(roundAt), sep = "."))
    }
    # Pad the decimal part to roundAt digits; max(0, ...) guards longer decimals.
    paste(parts[1],
          paste0(parts[2], pad0(max(0, roundAt - nchar(parts[2])))),
          sep = ".")
  }))
}
73 |
# Coerce `x` to a TF tensor of the requested dtype: existing tensors are
# cast; plain R values are wrapped via tf$constant.
r2const <- function(x, dtype){
  if (inherits(x, "tensorflow.tensor")) {
    cienv$tf$cast(x, dtype = dtype)
  } else {
    cienv$tf$constant(x, dtype = dtype)
  }
}
79 |
#' print2 print() with timestamps
#'
#' A function that prints a string prefixed with the current date and time.
#'
#' @param text Character string to be printed, with date and time.
#' @param quiet Logical. If TRUE, suppresses the printed output. Default is FALSE.
#'
#' @return Prints `text` with a timestamp (no output when `quiet = TRUE`).
#'
#' @examples
#' print2("Hello world")
#' @export
#' @md
#'
print2 <- function(text, quiet = FALSE){
  # Suppress all output when quiet = TRUE
  if (!quiet) {
    print(sprintf("[%s] %s", format(Sys.time(), "%Y-%m-%d %H:%M:%S"), text))
  }
}
96 |
#' message2 message() with timestamps
#'
#' A function that displays a message (on stderr) prefixed with the
#' current date and time.
#'
#' @param text Character string to be displayed as message, with date and time.
#' @param quiet Logical. If TRUE, suppresses the message output. Default is FALSE.
#'
#' @return Displays message with date and time to stderr.
#'
#' @examples
#' message2("Hello world")
#' message2("Process completed", quiet = FALSE)
#' @export
#' @md
#'
message2 <- function(text, quiet = FALSE){
  if (quiet) { return(invisible(NULL)) }
  stamp <- format(Sys.time(), "%Y-%m-%d %H:%M:%S")
  message(sprintf("[%s] %s", stamp, text))
}
117 |
# LE: depth-first lookup of `key` anywhere in a (possibly nested) list.
# Checks the current level's names before descending into children;
# returns the first matching value, or NULL when the key is absent.
LE <- function(l_, key) {
  dfs <- function(node) {
    # Non-list leaves cannot contain the key.
    if (!is.list(node)) return(NULL)
    # Direct hit at this level wins over any deeper match.
    if (key %in% names(node)) return(node[[key]])
    for (child in node) {
      hit <- dfs(child)
      if (!is.null(hit)) return(hit)
    }
    NULL
  }
  dfs(l_)
}
143 |
# LE_index: depth-first search returning the *index path* to `key` in a
# nested list (for use with chained `[[`). When the matching container
# carries class "list", a trailing 1 is appended to the path (mirrors how
# downstream code unwraps single-element containers).
# Fixes vs original: scalar conditions use `&&` instead of `&`, and the
# fragile `ifelse(cond, yes = 1, no = NULL)` — which errors whenever cond
# is FALSE (e.g. for data.frames, which are lists but not class "list") —
# is replaced with a plain `if`.
LE_index <- function(l_, key) {
  # Recursive helper: `path` accumulates the indices taken so far.
  search_recursive <- function(list_element, key, path) {
    if (is.list(list_element)) {
      # Trailing index appended only for plain "list"-classed containers.
      tail_idx <- if ("list" %in% class(list_element)) 1 else NULL
      if (key %in% names(list_element) && (length(names(list_element)) == 1)) {
        return(c(path, tail_idx))
      }
      if (key %in% names(list_element) && (length(names(list_element)) > 1)) {
        return(c(path, which(names(list_element) == key), tail_idx))
      }
      # No hit at this level: descend into each child.
      for (i in seq_along(list_element)) {
        found <- search_recursive(list_element[[i]], key, c(path, i))
        if (!is.null(found)) { return(found) }
      }
    }
    NULL
  }
  # Start the recursive search with an empty path.
  search_recursive(l_, key, c())
}
174 |
# Split a nested list `zer` into two pruned copies: one retaining only
# the branches for which `eq_fxn` is TRUE, the other retaining the rest.
# Returns list(yes_branches, no_branches), each with the original nesting
# (pruned positions become NULL/empty).
# NOTE(review): relies on ifelse() over list() arguments to select the
# kept branch — fragile but intentional; preserved as-is. Verify behavior
# against the rrapply documentation before refactoring.
GlobalPartition <- function(zer, eq_fxn){
  yes_branches <- rrapply::rrapply(zer,f=function(zerr){
    unlist(ifelse(eq_fxn(zerr),yes = list(zerr), no = list(NULL))[[1]])
  },how="list")
  no_branches <- rrapply::rrapply(zer,f=function(zerrr){
    unlist(ifelse(eq_fxn(zerrr),yes = list(NULL), no = list(zerrr))[[1]])
  },how="list")
  list(yes_branches,no_branches)
}
184 | PartFxn <- function(zerz){ !"first_time_index" %in% names(zerz)}
185 |
186 | AddQuotes <- function(text) { gsub("\\[\\[([A-Za-z]\\w*)", "\\[\\['\\1'", text) }
# Flatten an arbitrarily nested list into a single-level list whose names
# encode the original path, joined by `NameSep` (e.g. "a/b/1"). Data
# frames are kept whole by default; with LinearizeDataFrames = TRUE their
# columns become separate entries. ForceNames replaces missing names with
# positional indices. Path components pass through AddQuotes so bracketed
# name segments stay re-parseable.
LinearizeNestedList <- function (NList,
                                 LinearizeDataFrames = FALSE,
                                 NameSep = "/",
                                 ForceNames = FALSE){
  # Validate scalar control arguments up front.
  stopifnot(is.character(NameSep), length(NameSep) == 1)
  stopifnot(is.logical(LinearizeDataFrames), length(LinearizeDataFrames) == 1)
  stopifnot(is.logical(ForceNames), length(ForceNames) == 1)
  # Non-lists are already flat.
  if (!is.list(NList))
    return(NList)
  # Ensure every element is addressable by name.
  if (is.null(names(NList)) | ForceNames == TRUE)
    names(NList) <- as.character(1:length(NList))
  # Data frames at the top level: return whole, or as a plain column list.
  if (is.data.frame(NList) & LinearizeDataFrames == FALSE)
    return(NList)
  if (is.data.frame(NList) & LinearizeDataFrames == TRUE)
    return(as.list(NList))
  # Walk the list with a moving cursor A; B re-tracks the length since
  # nested elements get spliced in-place as they are flattened.
  A <- 1
  B <- length(NList)
  while (A <= B) {
    Element <- NList[[A]]
    EName <- names(NList)[A]
    if (is.list(Element)) {
      # Slices before/after the current element, for re-splicing.
      Before <- if (A == 1)
        NULL
      else NList[1:(A - 1)]
      After <- if (A == B)
        NULL
      else NList[(A + 1):B]
      if (is.data.frame(Element)) {
        if (LinearizeDataFrames == TRUE) {
          # Splice the data frame's columns in as prefixed entries.
          Jump <- length(Element)
          NList[[A]] <- NULL
          if(is.null(names(Element)) | ForceNames == TRUE)
            names(Element) <- as.character(1:length(Element))
          Element <- as.list(Element)
          names(Element) <- paste(EName, names(Element),
                                  sep = NameSep)
          NList <- c(Before, Element, After)
        }
        # NOTE(review): Jump is reset to 1 even after splicing multiple
        # columns above — TODO confirm this re-scan is intended.
        Jump <- 1
      }
      else {
        # Recursively flatten the sub-list, then splice it in with the
        # parent's name prefixed onto each child's path.
        NList[[A]] <- NULL
        if (is.null(names(Element)) | ForceNames == TRUE)
          names(Element) <- as.character(1:length(Element))
        Element <- LinearizeNestedList(Element, LinearizeDataFrames,
                                       NameSep, ForceNames)
        names(Element) <- AddQuotes( paste(EName,
                                           names(Element),
                                           sep = NameSep) )
        # Skip past everything just spliced in.
        Jump <- length(Element)
        NList <- c(Before, Element, After)
      }
    }
    else {
      # Plain leaf: keep it, advance by one.
      Jump <- 1
    }
    A <- A + Jump
    B <- length(NList)
  }
  return(NList)
}
248 |
249 |
250 | ai <- as.integer
251 |
252 | se <- function(x){ x <- c(na.omit(x)); return(sqrt(var(x)/length(x)))}
253 |
# Inline a zero-argument function's body into `evaluation_environment`:
# deparse the function, strip the "function ()" header, and evaluate the
# remaining body so its assignments land in the target environment.
# NOTE(review): the pattern only matches a literal zero-arg "function ()"
# header as produced by deparse() — TODO confirm no callers pass
# functions with arguments.
LocalFxnSource <- function(fxn, evaluation_environment){
  fxn_text <- paste(deparse(fxn), collapse = "\n")
  # Fix: spell out `replacement` (the original relied on partial argument
  # matching via `replace`) and use the conventional argument order.
  fxn_text <- gsub(pattern = "function \\(\\)", replacement = "", x = fxn_text)
  eval(parse(text = fxn_text), envir = evaluation_environment)
}
259 |
260 | FilterBN <- function(l_){ cienv$eq$partition(l_, function(l__){"first_time_index" %in% names(l__)}) }
261 |
262 | cienv <- new.env( parent = emptyenv() )
263 |
264 |
# Build a dropout layer closure with drop probability `p`.
# For p == 0 the layer is the identity; otherwise the returned function
# applies inverted dropout (mask / keep_prob) during training and skips
# it at inference time via a jax.lax.cond dynamic branch.
dropout_layer_init <- function(p) {
  if (p == 0) {
    # No-op layer; signature kept identical so call sites don't branch.
    return(function(x, key, inference) { x })
  }
  keep_prob <- (1 - p)
  function(x, key, inference) {
    # Efficient dynamic branch: skip dropout at inference time.
    cienv$jax$lax$cond(
      pred = inference,
      # Inference: pass the input through untouched.
      true_fun = function(args){ args[[1]] },
      # Training: sample a keep-mask and rescale (inverted dropout).
      false_fun = function(args){
        keep_mask <- cienv$jax$lax$stop_gradient(
          cienv$jax$random$bernoulli(key = args[[2]], p = keep_prob,
                                     shape = args[[1]]$shape)$astype( args[[1]]$dtype )
        )
        args[[1]] * keep_mask / keep_prob
      },
      operand = list(x, key)  # pack arguments for the traced branches
    )
  }
}
286 |
# Weight initializer: draws N(0, init_std^2) values of the given shape
# via jax.random.normal, cast to the package float type.
# NOTE(review): init_std uses shape[[1]] + shape[[1]] (= 2 * fan_in,
# i.e. LeCun-style sqrt(1/fan_in)). If Glorot/Xavier initialization was
# intended, the second term should be shape[[2]] — TODO confirm before
# changing, as this alters trained-model behavior.
wt_init <- function(shape, seed_key){
  init_std <- sqrt(2.0 / as.numeric(shape[[1]] + shape[[1]]))
  cienv$jax$random$normal(
    key = seed_key,
    shape = shape
  ) * cienv$jnp$array(init_std)$astype( cienv$jaxFloatType )
}
294 |
295 |
--------------------------------------------------------------------------------
/causalimages/R/CI_image2.R:
--------------------------------------------------------------------------------
#' Visualizing matrices as heatmaps with correct north-south-east-west orientation
#'
#' A function for generating a heatmap representation of a matrix with correct spatial orientation.
#'
#' @param x The numeric matrix to be visualized.
#' @param xlab The x-axis labels.
#' @param ylab The y-axis labels.
#' @param xaxt The x-axis tick labels.
#' @param yaxt The y-axis tick labels.
#' @param main The main figure label.
#' @param cex.main The main figure label sizing factor.
#' @param col.lab Axis label color.
#' @param col.main Main label color.
#' @param cex.lab Cex for the labels.
#' @param box Draw a box around the image?
#'
#' @return Returns a heatmap representation of the matrix, `x`, with correct north/south/east/west orientation.
#'
#' @examples
#' # set seed
#' set.seed(1)
#'
#' # Generate data
#' x <- matrix(rnorm(50*50), ncol = 50)
#' diag(x) <- 3
#'
#' # create plot
#' image2(x, main = "Example Text", cex.main = 2)
#'
#' @export
#' @md
image2 <- function(x,
                   xaxt = NULL,
                   yaxt = NULL,
                   xlab = "",
                   ylab = "",
                   main = NULL,
                   cex.main = NULL,
                   col.lab = "black",
                   col.main = "black",
                   cex.lab = 1.5,
                   box = FALSE){
  # Transpose and reverse columns so matrix row 1 appears at the TOP of
  # the plot (base image() otherwise draws the matrix rotated).
  image((t(x)[, nrow(x):1]),
        axes = FALSE,
        main = main,
        xlab = xlab,
        ylab = ylab,
        xaxs = "i",
        cex.lab = cex.lab,
        col.lab = col.lab,
        col.main = col.main,
        cex.main = cex.main)
  if (box == TRUE) { box() }
  # x-axis ticks: one per column of x. Bug fix: positions were previously
  # computed from nrow(x), which made axis() error ("'at' and 'labels'
  # lengths differ") for non-square matrices when xaxt was supplied.
  if (!is.null(xaxt)) {
    axis(1, at = 0:(ncol(x)-1)/ncol(x)*1.04, tick = FALSE,
         labels = (xaxt), cex.axis = 1, las = 1)
  }
  # y-axis ticks: one per row, labels reversed to match the flipped rows.
  if (!is.null(yaxt)) {
    axis(2, at = 0:(nrow(x)-1)/nrow(x)*1.04, tick = FALSE,
         labels = rev(yaxt), cex.axis = 1, las = 2)
  }
}
58 |
59 | f2n <- function(.){as.numeric(as.character(.))}
60 |
--------------------------------------------------------------------------------
/causalimages/data/CausalImagesTutorialData.RData:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cjerzak/causalimages-software/a6ca8dd8aae433c5207dedaa89e2c099d5700b8f/causalimages/data/CausalImagesTutorialData.RData
--------------------------------------------------------------------------------
/causalimages/data/datalist:
--------------------------------------------------------------------------------
1 | CausalImagesTutorialData: FullImageArray KeysOfImages KeysOfObservations LongLat obsW obsY X
2 |
--------------------------------------------------------------------------------
/causalimages/man/AnalyzeImageConfounding.Rd:
--------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/CI_Confounding.R
3 | \name{AnalyzeImageConfounding}
4 | \alias{AnalyzeImageConfounding}
5 | \title{Perform causal estimation under image confounding}
6 | \usage{
7 | AnalyzeImageConfounding(
8 | obsW,
9 | obsY,
10 | X = NULL,
11 | file = NULL,
12 | imageKeysOfUnits = NULL,
13 | fileTransport = NULL,
14 | imageKeysOfUnitsTransport = NULL,
15 | nBoot = 10L,
16 | inputAvePoolingSize = 1L,
17 | useTrainingPertubations = T,
18 | useScalePertubations = F,
19 | crossFit = FALSE,
20 | augmented = FALSE,
21 | orthogonalize = F,
22 | transportabilityMat = NULL,
23 | latTransport = NULL,
24 | longTransport = NULL,
25 | lat = NULL,
26 | long = NULL,
27 | conda_env = "CausalImagesEnv",
28 | conda_env_required = T,
29 | Sys.setenv_text = NULL,
30 | figuresTag = NULL,
31 | figuresPath = "./",
32 | plotBands = 1L,
33 | plotResults = T,
34 | XCrossModal = T,
35 | XForceModal = F,
36 | optimizeImageRep = T,
37 | nonLinearScaler = NULL,
38 | nWidth_ImageRep = 64L,
39 | nDepth_ImageRep = 1L,
40 | kernelSize = 5L,
41 | nWidth_Dense = 64L,
42 | nDepth_Dense = 1L,
43 | imageModelClass = "VisionTransformer",
44 | pretrainedModel = NULL,
45 | strides = 2L,
46 | nDepth_TemporalRep = 3L,
47 | patchEmbedDim = 16L,
48 | dropoutRate = 0.1,
49 | droppathRate = 0.1,
50 | batchSize = 16L,
51 | nSGD = 400L,
52 | testFrac = 0.05,
53 | TfRecords_BufferScaler = 4L,
54 | learningRateMax = 0.001,
55 | TFRecordControl = NULL,
56 | dataType = "image",
57 | image_dtype = "float16",
58 | atError = "stop",
59 | seed = NULL
60 | )
61 | }
62 | \arguments{
63 | \item{obsW}{A numeric vector where \code{0}'s correspond to control units and \code{1}'s to treated units.}
64 |
65 | \item{obsY}{A numeric vector containing observed outcomes.}
66 |
67 | \item{X}{An optional numeric matrix containing tabular information used if \code{orthogonalize = T}. \code{X} is normalized internally and salience maps with respect to \code{X} are transformed back to the original scale.}
68 |
69 | \item{file}{Path to a tfrecord file generated by \code{WriteTfRecord}.}
70 |
71 | \item{imageKeysOfUnits}{A vector of length \code{length(obsY)} specifying the unique image ID associated with each unit. Samples of \code{imageKeysOfUnits} are fed into the package to call images into memory.}
72 |
73 | \item{nBoot}{Number of bootstrap iterations for uncertainty estimation.}
74 |
75 | \item{useTrainingPertubations}{Boolean specifying whether to randomly perturb the image axes during training to reduce overfitting.}
76 |
77 | \item{transportabilityMat}{Optional matrix with a column named \code{imageKeysOfUnits} specifying keys to be used by the package for generating treatment effect predictions for out-of-sample points.}
78 |
79 | \item{long, lat}{Optional vectors specifying longitude and latitude coordinates for units. Used only for describing highest and lowest probability neighborhood units if specified.}
80 |
81 | \item{conda_env}{A \code{conda} environment where computational environment lives, usually created via \code{causalimages::BuildBackend()}. Default = \code{"CausalImagesEnv"}.}
82 |
83 | \item{conda_env_required}{A Boolean stating whether use of the specified conda environment is required.}
84 |
85 | \item{figuresTag}{A string specifying an identifier that is appended to all figure names.}
86 |
87 | \item{figuresPath}{A string specifying file path for saved figures made in the analysis.}
88 |
89 | \item{plotBands}{An integer or vector specifying which band position (from the image representation) should be plotted in the visual results. If a vector, \code{plotBands} should have 3 (and only 3) dimensions (corresponding to the 3 dimensions to be used in RGB plotting).}
90 |
91 | \item{plotResults}{(default = \code{T}) Should analysis results be plotted?}
92 |
93 | \item{optimizeImageRep}{Boolean specifying whether to optimize over the image model representation (or only over downstream parameters).}
94 |
95 | \item{nWidth_ImageRep}{Integer specifying width of image model representation.}
96 |
97 | \item{nDepth_ImageRep}{Integer specifying depth of image model representation.}
98 |
99 | \item{kernelSize}{Dimensions used in spatial convolutions.}
100 |
101 | \item{nWidth_Dense}{Integer specifying width of image model representation.}
102 |
103 | \item{nDepth_Dense}{Integer specifying depth of dense model representation.}
104 |
105 | \item{strides}{(default = \code{2L}) Integer specifying the strides used in the convolutional layers.}
106 |
107 | \item{dropoutRate}{Dropout rate used in training to prevent overfitting (\code{dropoutRate = 0} corresponds to no dropout).}
108 |
109 | \item{droppathRate}{Droppath rate used in training to prevent overfitting (\code{droppathRate = 0} corresponds to no droppath).}
110 |
111 | \item{batchSize}{Batch size used in SGD optimization. Default = \code{16L}.}
112 |
113 | \item{nSGD}{Number of stochastic gradient descent (SGD) iterations. Default = \code{400L}}
114 |
115 | \item{testFrac}{Default = \code{0.05}. Fraction of observations held out as a test set to evaluate out-of-sample loss values.}
116 |
117 | \item{TfRecords_BufferScaler}{The buffer size used in \code{tfrecords} mode is \code{batchSize*TfRecords_BufferScaler}. Lower \code{TfRecords_BufferScaler} towards 1 if out-of-memory problems.}
118 |
119 | \item{dataType}{(default = \code{"image"}) String specifying whether to assume \code{"image"} or \code{"video"} data types.}
120 | }
121 | \value{
122 | Returns a list consisting of
123 | \itemize{
124 | \item \code{ATE_est} ATE estimate.
125 | \item \code{ATE_se} Standard error estimate for the ATE.
126 | \item \code{plotResults} If set to \code{TRUE}, causal salience plots are saved to disk, characterizing the image confounding structure. See references for details.
127 | }
128 | }
129 | \description{
130 | Perform causal estimation under image confounding
131 | }
132 | \section{References}{
133 |
134 | \itemize{
135 | \item Connor T. Jerzak, Fredrik Johansson, Adel Daoud. Integrating Earth Observation Data into Causal Inference: Challenges and Opportunities. \emph{ArXiv Preprint}, 2023.
136 | }
137 | }
138 |
139 | \examples{
140 | # For a tutorial, see
141 | # github.com/cjerzak/causalimages-software/
142 |
143 | }
144 |
--------------------------------------------------------------------------------
/causalimages/man/AnalyzeImageHeterogeneity.Rd:
--------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/CI_Heterogeneity.R
3 | \name{AnalyzeImageHeterogeneity}
4 | \alias{AnalyzeImageHeterogeneity}
5 | \title{Decompose treatment effect heterogeneity by image or image sequence}
6 | \usage{
7 | AnalyzeImageHeterogeneity(
8 | obsW,
9 | obsY,
10 | X = NULL,
11 | orthogonalize = F,
12 | imageKeysOfUnits = 1:length(obsY),
13 | kClust_est = 2,
14 | file = NULL,
15 | transportabilityMat = NULL,
16 | lat = NULL,
17 | long = NULL,
18 | conda_env = "CausalImagesEnv",
19 | conda_env_required = T,
20 | figuresTag = "",
21 | figuresPath = "./",
22 | plotBands = 1L,
23 | heterogeneityModelType = "variational_minimal",
24 | plotResults = F,
25 | optimizeImageRep = T,
26 | nWidth_ImageRep = 64L,
27 | nDepth_ImageRep = 1L,
28 | nWidth_Dense = 64L,
29 | nDepth_Dense = 1L,
30 | nDepth_TemporalRep = 1L,
31 | useTrainingPertubations = T,
32 | strides = 2L,
33 | nonLinearScaler = NULL,
34 | pretrainedModel = NULL,
35 | testFrac = 0.1,
36 | kernelSize = 5L,
37 | learningRateMax = 0.001,
38 | TFRecordControl = NULL,
39 | patchEmbedDim = 16L,
40 | nSGD = 500L,
41 | batchSize = 16L,
42 | seed = NULL,
43 | Sys.setenv_text = NULL,
44 | imageModelClass = "VisionTransformer",
45 | nMonte_predictive = 10L,
46 | nMonte_salience = 10L,
47 | nMonte_variational = 2L,
48 | TfRecords_BufferScaler = 4L,
49 | temperature = 1,
50 | inputAvePoolingSize = 1L,
51 | dataType = "image"
52 | )
53 | }
54 | \arguments{
55 | \item{obsW}{A numeric vector where \code{0}'s correspond to control units and \code{1}'s to treated units.}
56 |
57 | \item{obsY}{A numeric vector containing observed outcomes.}
58 |
59 | \item{X}{Optional numeric matrix containing tabular information used if \code{orthogonalize = T}.}
60 |
61 | \item{orthogonalize}{A Boolean specifying whether to perform the image decomposition after orthogonalizing with respect to tabular covariates specified in \code{X}.}
62 |
63 | \item{imageKeysOfUnits}{A vector of length \code{length(obsY)} specifying the unique image ID associated with each unit. Samples of \code{imageKeysOfUnits} are fed into the package to call images into memory.}
64 |
65 | \item{kClust_est}{Integer specifying the number of clusters used in estimation. Default is \code{2L}.}
66 |
67 | \item{file}{Path to a tfrecord file generated by \code{WriteTfRecord}.}
68 |
69 | \item{transportabilityMat}{An optional matrix with a column named \code{key} specifying keys to be used for generating treatment effect predictions for out-of-sample points in earth observation data settings.}
70 |
71 | \item{long, lat}{Optional vectors specifying longitude and latitude coordinates for units. Used only for describing highest and lowest probability neighborhood units if specified.}
72 |
73 | \item{conda_env}{A \code{conda} environment where computational environment lives, usually created via \code{causalimages::BuildBackend()}. Default = \code{"CausalImagesEnv"}.}
74 |
75 | \item{conda_env_required}{A Boolean stating whether use of the specified conda environment is required.}
76 |
77 | \item{figuresTag}{A string specifying an identifier that is appended to all figure names.}
78 |
79 | \item{figuresPath}{A string specifying file path for saved figures made in the analysis.}
80 |
81 | \item{plotBands}{An integer or vector specifying which band position (from the acquired image representation) should be plotted in the visual results. If a vector, \code{plotBands} should have 3 (and only 3) dimensions (corresponding to the 3 dimensions to be used in RGB plotting).}
82 |
83 | \item{plotResults}{Should analysis results be plotted?}
84 |
85 | \item{optimizeImageRep}{Boolean specifying whether to optimize over the image model representation (or only over downstream parameters).}
86 |
87 | \item{nWidth_ImageRep}{Integer specifying width of image model representation.}
88 |
89 | \item{nDepth_ImageRep}{Integer specifying depth of image model representation.}
90 |
91 | \item{nWidth_Dense}{Integer specifying width of image model representation.}
92 |
93 | \item{nDepth_Dense}{Integer specifying depth of dense model representation.}
94 |
95 | \item{strides}{Integer specifying the strides used in the convolutional layers.}
96 |
97 | \item{kernelSize}{Dimensions used in spatial convolutions.}
98 |
99 | \item{nSGD}{Number of stochastic gradient descent (SGD) iterations.}
100 |
101 | \item{batchSize}{Batch size used in SGD optimization.}
102 |
103 | \item{nMonte_predictive}{An integer specifying how many Monte Carlo iterations to use in the calculation
104 | of posterior means (e.g., mean cluster probabilities).}
105 |
106 | \item{nMonte_salience}{An integer specifying how many Monte Carlo iterations to use in the calculation
107 | of the salience maps (e.g., image gradients of expected cluster probabilities).}
108 |
109 | \item{nMonte_variational}{An integer specifying how many Monte Carlo iterations to use in the
110 | calculation of the expected likelihood in each training step.}
111 |
112 | \item{TfRecords_BufferScaler}{The buffer size used in \code{tfrecords} mode is \code{batchSize*TfRecords_BufferScaler}. Lower \code{TfRecords_BufferScaler} towards 1 if out-of-memory problems.}
113 |
114 | \item{dataType}{String specifying whether to assume \code{"image"} or \code{"video"} data types.}
115 | }
116 | \value{
117 | Returns a list consisting of \itemize{
118 | \item \code{clusterTaus_mean} default
119 | \item \code{clusterProbs_mean}. Estimated mean image effect cluster probabilities.
120 | \item \code{clusterTaus_sigma}. Estimated cluster standard deviations.
121 | \item \code{clusterProbs_lowerConf}. Estimated lower confidence for effect cluster probabilities.
122 | \item \code{impliedATE}. Implied ATE.
123 | \item \code{individualTau_est}. Estimated individual-level image-based treatment effects.
124 | \item \code{transportabilityMat}. Transportability matrix with estimated cluster information.
125 | \item \code{plottedCoordinates}. List containing coordinates plotted in salience maps.
126 | \item \code{whichNA_dropped}. A vector containing observations dropped due to missingness.
127 | }
128 | }
129 | \description{
130 | Implements the image heterogeneity decomposition analysis of Jerzak, Johansson, and Daoud (2023). Users
131 | input in treatment and outcome data, along with a function specifying how to load in images
132 | using keys referenced to each unit (since loading in all image data will usually not be possible due to memory limitations).
133 | This function by default performs estimation, constructs salience maps, and can optionally perform
134 | estimation for new areas outside the original study sites in a transportability analysis.
135 | }
136 | \section{References}{
137 |
138 | \itemize{
139 | \item Connor T. Jerzak, Fredrik Johansson, Adel Daoud. Image-based Treatment Effect Heterogeneity. Forthcoming in \emph{Proceedings of the Second Conference on Causal Learning and Reasoning (CLeaR), Proceedings of Machine Learning Research (PMLR)}, 2023.
140 | }
141 | }
142 |
143 | \examples{
144 | # For a tutorial, see
145 | # github.com/cjerzak/causalimages-software/
146 |
147 | }
148 |
--------------------------------------------------------------------------------
/causalimages/man/BuildBackend.Rd:
--------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/CI_BuildBackend.R
3 | \name{BuildBackend}
4 | \alias{BuildBackend}
5 | \title{Build the environment for CausalImages models. Builds a conda environment in which jax, tensorflow, tensorflow-probability optax, equinox, and jmp are installed.}
6 | \usage{
7 | BuildBackend(conda_env = "CausalImagesEnv", conda = "auto")
8 | }
9 | \arguments{
10 | \item{conda_env}{(default = \code{"CausalImagesEnv"}) Name of the conda environment in which to place the backends.}
11 |
12 | \item{conda}{(default = \code{auto}) The path to a conda executable. Using \code{"auto"} allows reticulate to attempt to automatically find an appropriate conda binary.}
13 | }
14 | \value{
15 | Builds the computational environment for \code{causalimages}. This function requires an Internet connection.
16 | You may find out a list of conda Python paths via: \code{system("which python")}
17 | }
18 | \description{
19 | Build the environment for CausalImages models. Builds a conda environment in which jax, tensorflow, tensorflow-probability optax, equinox, and jmp are installed.
20 | }
21 | \examples{
22 | # For a tutorial, see
23 | # github.com/cjerzak/causalimages-software/
24 |
25 | }
26 |
--------------------------------------------------------------------------------
/causalimages/man/GetAndSaveGeolocatedImages.Rd:
--------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/CI_GetAndSaveGeolocatedImages.R
3 | \name{GetAndSaveGeolocatedImages}
4 | \alias{GetAndSaveGeolocatedImages}
5 | \title{Getting and saving geo-located images from a pool of .tif's}
6 | \usage{
7 | GetAndSaveGeolocatedImages(
8 | long,
9 | lat,
10 | keys,
11 | tif_pool,
12 | image_pixel_width = 256L,
13 | save_folder = ".",
14 | save_as = "csv",
15 | lyrs = NULL
16 | )
17 | }
18 | \arguments{
19 | \item{long}{Vector of numeric longitudes.}
20 |
21 | \item{lat}{Vector of numeric latitudes.}
22 |
23 | \item{keys}{The image keys associated with the long/lat coordinates.}
24 |
25 | \item{tif_pool}{A character vector specifying the fully qualified path to a corpus of .tif files.}
26 |
27 | \item{image_pixel_width}{An even integer specifying the pixel width (and height) of the saved images.}
28 |
29 | \item{save_folder}{(default = \code{"."}) What folder should be used to save the output? Example: \code{"~/Downloads"}}
30 |
31 | \item{save_as}{(default = \code{".csv"}) What format should the output be saved as? Only one option currently (\code{.csv})}
32 |
33 | \item{lyrs}{(default = NULL) Integer (vector) specifying the layers to be extracted. Default is for all layers to be extracted.}
34 | }
35 | \value{
36 | Finds the image slice associated with the \code{long} and \code{lat} values, saves images by band (if \code{save_as = "csv"}) in save_folder.
37 | The save format is: \code{sprintf("\%s/Key\%s_BAND\%s.csv", save_folder, keys[i], band_)}
38 | }
39 | \description{
40 | A function that finds the image slice associated with the \code{long} and \code{lat} values, saves images by band (if \code{save_as = "csv"}) in save_folder.
41 | }
42 | \examples{
43 |
44 | # Example use (not run):
45 | #MASTER_IMAGE_POOL_FULL_DIR <- c("./LargeTifs/tif1.tif","./LargeTifs/tif2.tif")
46 | #GetAndSaveGeolocatedImages(
47 | #long = GeoKeyMat$geo_long,
48 | #lat = GeoKeyMat$geo_lat,
49 | #image_pixel_width = 500L,
50 | #keys = row.names(GeoKeyMat),
51 | #tif_pool = MASTER_IMAGE_POOL_FULL_DIR,
52 | #save_folder = "./Data/Uganda2000_processed",
53 | #save_as = "csv",
54 | #lyrs = NULL)
55 |
56 | }
57 |
--------------------------------------------------------------------------------
/causalimages/man/GetElementFromTfRecordAtIndices.Rd:
--------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/CI_TfRecordFxns.R
3 | \name{GetElementFromTfRecordAtIndices}
4 | \alias{GetElementFromTfRecordAtIndices}
5 | \title{Reads unique key indices from a \code{.tfrecord} file.}
6 | \usage{
7 | GetElementFromTfRecordAtIndices(uniqueKeyIndices, file,
8 | conda_env, conda_env_required)
9 | }
10 | \arguments{
11 | \item{uniqueKeyIndices}{(integer vector) Unique image indices to be retrieved from a \code{.tfrecord}}
12 |
13 | \item{conda_env}{(Default = \code{NULL}) A \code{conda} environment where tensorflow v2 lives. Used only if a version of tensorflow is not already active.}
14 |
15 | \item{conda_env_required}{(default = \code{F}) A Boolean stating whether use of the specified conda environment is required.}
16 |
17 | \item{file}{(character string) A character string stating the path to a \code{.tfrecord}}
18 | }
19 | \value{
20 | Returns content from a \code{.tfrecord} associated with \code{uniqueKeyIndices}
21 | }
22 | \description{
23 | Reads unique key indices from a \code{.tfrecord} file saved via a call to \code{causalimages::WriteTfRecord}.
24 | }
25 | \examples{
26 | # Example usage (not run):
27 | #GetElementFromTfRecordAtIndices(
28 | #uniqueKeyIndices = 1:10,
29 | #file = "./NigeriaConfoundApp.tfrecord")
30 |
31 | }
32 |
--------------------------------------------------------------------------------
/causalimages/man/GetImageRepresentations.Rd:
--------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/CI_ImageModelBackbones.R
3 | \name{GetImageRepresentations}
4 | \alias{GetImageRepresentations}
5 | \title{Generates image and video representations useful in earth observation tasks for causal inference.}
6 | \usage{
7 | GetImageRepresentations(
8 | X = NULL,
9 | imageKeysOfUnits = NULL,
10 | file = NULL,
11 | conda_env = "CausalImagesEnv",
12 | conda_env_required = T,
13 | returnContents = T,
14 | getRepresentations = T,
15 | imageModelClass = "VisionTransformer",
16 | NORM_MEAN = NULL,
17 | NORM_SD = NULL,
18 | Sys.setenv_text = NULL,
19 | InitImageProcess = NULL,
20 | pretrainedModel = NULL,
21 | lat = NULL,
22 | long = NULL,
23 | image_dtype = NULL,
24 | image_dtype_tf = NULL,
25 | XCrossModal = T,
26 | XForceModal = F,
27 | nWidth_ImageRep = 64L,
28 | nDepth_ImageRep = 1L,
29 | nDepth_TemporalRep = 1L,
30 | batchSize = 16L,
31 | nonLinearScaler = NULL,
32 | optimizeImageRep = T,
33 | strides = 1L,
34 | kernelSize = 3L,
35 | patchEmbedDim = 16L,
36 | TfRecords_BufferScaler = 10L,
37 | dropoutRate,
38 | droppathRate,
39 | dataType = "image",
40 | bn_momentum = 0.99,
41 | inputAvePoolingSize = 1L,
42 | CleanupEnv = FALSE,
43 | initializingFxns = FALSE,
44 | seed = NULL
45 | )
46 | }
47 | \arguments{
48 | \item{imageKeysOfUnits}{A vector of length \code{length(imageKeysOfUnits)} specifying the unique image ID associated with each unit. Samples of \code{imageKeysOfUnits} are fed into the package to call images into memory.}
49 |
50 | \item{file}{Path to a tfrecord file generated by \code{causalimages::WriteTfRecord}.}
51 |
52 | \item{conda_env}{A \code{conda} environment where computational environment lives, usually created via \code{causalimages::BuildBackend()}. Default = \code{"CausalImagesEnv"}}
53 |
54 | \item{conda_env_required}{A Boolean stating whether use of the specified conda environment is required.}
55 |
56 | \item{InitImageProcess}{(default = \code{NULL}) Initial image processing function. Usually left \code{NULL}.}
57 |
58 | \item{nWidth_ImageRep}{Number of embedding features output.}
59 |
60 | \item{batchSize}{Integer specifying batch size in obtaining representations.}
61 |
62 | \item{strides}{Integer specifying the strides used in the convolutional layers.}
63 |
64 | \item{kernelSize}{Dimensions used in the convolution kernels.}
65 |
66 | \item{TfRecords_BufferScaler}{The buffer size used in \code{tfrecords} mode is \code{batchSize*TfRecords_BufferScaler}. Lower \code{TfRecords_BufferScaler} towards 1 if out-of-memory problems.}
67 |
68 | \item{dataType}{String specifying whether to assume \code{"image"} or \code{"video"} data types. Default is \code{"image"}.}
69 | }
70 | \value{
71 | A list containing three items:
72 | \itemize{
73 | \item \code{Representations} (matrix) A matrix containing image/video representations, with rows corresponding to observations.
74 | \item \verb{ImageRepArm_OneObs,ImageRepArm_batch_R, ImageRepArm_batch} (functions) Image modeling functions.
75 | \item \code{ImageModel_And_State_And_MPPolicy_List} List containing image model parameters fed into functions.
76 | }
77 | }
78 | \description{
79 | Generates image and video representations useful in earth observation tasks for causal inference.
80 | }
81 | \section{References}{
82 |
83 | \itemize{
84 | \item Rolf, Esther, et al. "A generalizable and accessible approach to machine learning with global satellite imagery." \emph{Nature Communications} 12.1 (2021): 4392.
85 | }
86 | }
87 |
88 | \examples{
89 | # For a tutorial, see
90 | # github.com/cjerzak/causalimages-software/
91 |
92 | }
93 |
--------------------------------------------------------------------------------
/causalimages/man/GetMoments.Rd:
--------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/CI_GetMoments.R
3 | \name{GetMoments}
4 | \alias{GetMoments}
5 | \title{Get moments for normalization (internal function)}
6 | \usage{
7 | GetMoments(iterator, dataType, image_dtype, momentCalIters = 34L)
8 | }
9 | \arguments{
10 | \item{iterator}{An iterator}
11 |
12 | \item{dataType}{A string denoting data type}
13 |
14 | \item{momentCalIters}{Number of minibatches with which to estimate moments}
15 | }
16 | \value{
17 | Returns mean/sd arrays for normalization.
18 | }
19 | \description{
20 | An internal function for obtaining moments for channel normalization.
21 | }
22 | \examples{
23 | # (Not run)
24 | # GetMoments(iterator, dataType, image_dtype, momentCalIters = 34L)
25 | }
26 |
--------------------------------------------------------------------------------
/causalimages/man/LongLat2CRS.Rd:
--------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/CI_helperFxns.R
3 | \name{LongLat2CRS}
4 | \alias{LongLat2CRS}
5 | \title{Get the spatial point of long/lat coordinates}
6 | \usage{
7 | LongLat2CRS(long, lat, CRS_ref)
8 | }
9 | \arguments{
10 | \item{long}{Vector of numeric longitudes.}
11 |
12 | \item{lat}{Vector of numeric latitudes.}
13 |
14 | \item{CRS_ref}{A CRS into which the long-lat point should be projected.}
15 | }
16 | \value{
17 | Numeric vector of length two giving the coordinates of the supplied
18 | location in the CRS defined by \code{CRS_ref}.
19 | }
20 | \description{
21 | Convert longitude and latitude coordinates to a different coordinate reference
22 | system (CRS).
23 | }
24 | \examples{
25 | # (Not run)
26 | #spatialPt <- LongLat2CRS(long = 49.932,
27 | # lat = 35.432,
28 | # CRS_ref = sf::st_crs("+proj=lcc +lat_1=48 +lat_2=33 +lon_0=-100 +ellps=WGS84"))
29 | }
30 |
--------------------------------------------------------------------------------
/causalimages/man/PredictiveRun.Rd:
--------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/CI_PredictiveRun.R
3 | \name{PredictiveRun}
4 | \alias{PredictiveRun}
5 | \title{Perform predictive modeling using images or videos}
6 | \usage{
7 | PredictiveRun(
8 | obsY,
9 | imageKeysOfUnits = NULL,
10 | file = NULL,
11 | fileTransport = NULL,
12 | imageKeysOfUnitsTransport = NULL,
13 | nBoot = 10L,
14 | inputAvePoolingSize = 1L,
15 | useTrainingPertubations = T,
16 | useScalePertubations = F,
17 | X = NULL,
18 | conda_env = "CausalImagesEnv",
19 | conda_env_required = T,
20 | Sys.setenv_text = NULL,
21 | figuresTag = NULL,
22 | figuresPath = "./",
23 | plotBands = 1L,
24 | plotResults = T,
25 | XCrossModal = T,
26 | XForceModal = F,
27 | optimizeImageRep = T,
28 | nWidth_ImageRep = 64L,
29 | nDepth_ImageRep = 1L,
30 | kernelSize = 5L,
31 | nWidth_Dense = 64L,
32 | nDepth_Dense = 1L,
33 | imageModelClass = "VisionTransformer",
34 | pretrainedModel = NULL,
35 | strides = 2L,
36 | nonLinearScaler = NULL,
37 | nDepth_TemporalRep = 3L,
38 | patchEmbedDim = 16L,
39 | dropoutRate = 0.1,
40 | droppathRate = 0.1,
41 | batchSize = 16L,
42 | nSGD = 400L,
43 | testFrac = 0.05,
44 | TfRecords_BufferScaler = 4L,
45 | learningRateMax = 0.001,
46 | TFRecordControl = NULL,
47 | dataType = "image",
48 | image_dtype = "float16",
49 | atError = "stop",
50 | seed = NULL,
51 | modelPath = "./trained_model.eqx",
52 | metricsPath = "./evaluation_metrics.rds"
53 | )
54 | }
55 | \arguments{
56 | \item{obsY}{A numeric vector containing observed outcomes to predict.}
57 |
58 | \item{imageKeysOfUnits}{A vector of length \code{length(obsY)} specifying the unique image ID associated with each unit. Samples of \code{imageKeysOfUnits} are fed into the package to call images into memory.}
59 |
60 | \item{file}{Path to a tfrecord file generated by \code{WriteTfRecord}.}
61 |
62 | \item{nBoot}{Number of bootstrap iterations for uncertainty estimation.}
63 |
64 | \item{useTrainingPertubations}{Boolean specifying whether to randomly perturb the image axes during training to reduce overfitting.}
65 |
66 | \item{X}{An optional numeric matrix containing tabular information. \code{X} is normalized internally.}
67 |
68 | \item{conda_env}{A \code{conda} environment where computational environment lives, usually created via \code{causalimages::BuildBackend()}. Default = \code{"CausalImagesEnv"}.}
69 |
70 | \item{conda_env_required}{A Boolean stating whether use of the specified conda environment is required.}
71 |
72 | \item{figuresTag}{A string specifying an identifier that is appended to all figure names.}
73 |
74 | \item{figuresPath}{A string specifying file path for saved figures made in the analysis.}
75 |
76 | \item{plotBands}{An integer or vector specifying which band position (from the image representation) should be plotted in the visual results. If a vector, \code{plotBands} should have 3 (and only 3) dimensions (corresponding to the 3 dimensions to be used in RGB plotting).}
77 |
78 | \item{plotResults}{(default = \code{T}) Should analysis results be plotted?}
79 |
80 | \item{optimizeImageRep}{Boolean specifying whether to optimize over the image model representation (or only over downstream parameters).}
81 |
82 | \item{nWidth_ImageRep}{Integer specifying width of image model representation.}
83 |
84 | \item{nDepth_ImageRep}{Integer specifying depth of image model representation.}
85 |
86 | \item{kernelSize}{Dimensions used in spatial convolutions.}
87 |
88 | \item{nWidth_Dense}{Integer specifying width of dense model representation.}
89 |
90 | \item{nDepth_Dense}{Integer specifying depth of dense model representation.}
91 |
92 | \item{strides}{(default = \code{2L}) Integer specifying the strides used in the convolutional layers.}
93 |
94 | \item{dropoutRate}{Dropout rate used in training to prevent overfitting (\code{dropoutRate = 0} corresponds to no dropout).}
95 |
96 | \item{droppathRate}{Droppath rate used in training to prevent overfitting (\code{droppathRate = 0} corresponds to no droppath).}
97 |
98 | \item{batchSize}{Batch size used in SGD optimization. Default = \code{16L}.}
99 |
100 | \item{nSGD}{Number of stochastic gradient descent (SGD) iterations. Default = \code{400L}}
101 |
102 | \item{testFrac}{Default = \code{0.05}. Fraction of observations held out as a test set to evaluate out-of-sample loss values.}
103 |
104 | \item{TfRecords_BufferScaler}{The buffer size used in \code{tfrecords} mode is \code{batchSize*TfRecords_BufferScaler}. Lower \code{TfRecords_BufferScaler} towards 1 if out-of-memory problems.}
105 |
106 | \item{dataType}{(default = \code{"image"}) String specifying whether to assume \code{"image"} or \code{"video"} data types.}
107 |
108 | \item{modelPath}{Path to save the trained model. Default = \code{"./trained_model.eqx"}.}
109 |
110 | \item{metricsPath}{Path to save the evaluation metrics as a RDS file. Default = \code{"./evaluation_metrics.rds"}.}
111 |
112 | \item{transportabilityMat}{Optional matrix with a column named \code{imageKeysOfUnits} specifying keys to be used by the package for generating predictions for out-of-sample points.}
113 | }
114 | \value{
115 | Returns a list consisting of
116 | \itemize{
117 | \item \code{predictedY} Predicted values for all units.
118 | \item \code{ModelEvaluationMetrics} Rigorous evaluation metrics (e.g., MSE, R2 for continuous; AUC, accuracy for binary).
119 | }
120 | }
121 | \description{
122 | Perform predictive modeling using images or videos
123 | }
124 | \section{References}{
125 |
126 | \itemize{
127 | \item Connor T. Jerzak, Fredrik Johansson, Adel Daoud. Integrating Earth Observation Data into Causal Inference: Challenges and Opportunities. \emph{ArXiv Preprint}, 2023.
128 | }
129 | }
130 |
131 | \examples{
132 | # For a tutorial, see
133 | # github.com/cjerzak/causalimages-software/
134 |
135 | }
136 |
--------------------------------------------------------------------------------
/causalimages/man/TFRecordManagement.Rd:
--------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/CI_TFRecordManagement.R
3 | \name{TFRecordManagement}
4 | \alias{TFRecordManagement}
5 | \title{Defines an internal TFRecord management routine (internal function)}
6 | \usage{
7 | TFRecordManagement()
8 | }
9 | \arguments{
10 | \item{.}{No parameters.}
11 | }
12 | \value{
13 | Internal function defining a tfrecord management sequence.
14 | }
15 | \description{
16 | Defines the TFRecord management sequence used internally by the package. Internal function.
17 | }
18 |
--------------------------------------------------------------------------------
/causalimages/man/TrainDefine.Rd:
--------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/CI_TrainDefine.R
3 | \name{TrainDefine}
4 | \alias{TrainDefine}
5 | \title{Defines an internal training routine (internal function)}
6 | \usage{
7 | TrainDefine()
8 | }
9 | \arguments{
10 | \item{.}{No parameters.}
11 | }
12 | \value{
13 | Internal function defining a training sequence.
14 | }
15 | \description{
16 | Defines the trainers subsequently run by TrainDo(). Internal function.
17 | }
18 |
--------------------------------------------------------------------------------
/causalimages/man/TrainDo.Rd:
--------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/CI_TrainDo.R
3 | \name{TrainDo}
4 | \alias{TrainDo}
5 | \title{Runs a training routine (internal function)}
6 | \usage{
7 | TrainDo()
8 | }
9 | \arguments{
10 | \item{.}{No parameters.}
11 | }
12 | \value{
13 | Internal function performing model training.
14 | }
15 | \description{
16 | Runs trainers defined in TrainDefine(). Internal function.
17 | }
18 |
--------------------------------------------------------------------------------
/causalimages/man/WriteTfRecord.Rd:
--------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/CI_TfRecordFxns.R
3 | \name{WriteTfRecord}
4 | \alias{WriteTfRecord}
5 | \title{Write an image corpus as a .tfrecord file}
6 | \usage{
7 | WriteTfRecord(
8 | file,
9 | uniqueImageKeys,
10 | acquireImageFxn,
11 | writeVideo = F,
12 | image_dtype = "float16",
13 | conda_env = "CausalImagesEnv",
14 | conda_env_required = T,
15 | Sys.setenv_text = NULL
16 | )
17 | }
18 | \arguments{
19 | \item{file}{A character string naming a file for writing.}
20 |
21 | \item{uniqueImageKeys}{A vector specifying the unique image keys of the corpus.
22 | A key grabs an image/video array via \code{acquireImageFxn(key)}.}
23 |
24 | \item{acquireImageFxn}{A function whose input is an observation keys and whose output is an array with dimensions \verb{(length(keys), nSpatialDim1, nSpatialDim2, nChannels)} for images and \verb{(length(keys), nTimeSteps, nSpatialDim1, nSpatialDim2, nChannels)} for image sequence data.}
25 |
26 | \item{writeVideo}{(default = \code{FALSE}) Should we assume we're writing image sequence data of form batch by time by height by width by channels?}
27 |
28 | \item{conda_env}{(default = \code{"CausalImagesEnv"}) A \code{conda} environment where computational environment lives, usually created via \code{causalimages::BuildBackend()}}
29 |
30 | \item{conda_env_required}{(default = \code{T}) A Boolean stating whether use of the specified conda environment is required.}
31 | }
32 | \value{
33 | Writes a unique key-referenced \code{.tfrecord} from an image/video corpus for use in image-based causal inference training.
34 | }
35 | \description{
36 | Writes an image corpus to a \code{.tfrecord} file for rapid reading of images into memory for fast ML training.
37 | Specifically, this function serializes an image or video corpus into a \code{.tfrecord} file, enabling efficient data loading for machine learning tasks, particularly for image-based causal inference training.
38 | It requires that users define an \code{acquireImageFxn} function that accepts keys and returns the corresponding image or video as an array of dimensions \verb{(length(keys), nSpatialDim1, nSpatialDim2, nChannels)} for images or \verb{(length(keys), nTimeSteps, nSpatialDim1, nSpatialDim2, nChannels)} for video sequences.
39 | }
40 | \examples{
41 | # Example usage (not run):
42 | #WriteTfRecord(
43 | # file = "./NigeriaConfoundApp.tfrecord",
44 | # uniqueImageKeys = 1:n,
45 | # acquireImageFxn = acquireImageFxn)
46 |
47 | }
48 |
--------------------------------------------------------------------------------
/causalimages/man/image2.Rd:
--------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/CI_image2.R
3 | \name{image2}
4 | \alias{image2}
5 | \title{Visualizing matrices as heatmaps with correct north-south-east-west orientation}
6 | \usage{
7 | image2(
8 | x,
9 | xaxt = NULL,
10 | yaxt = NULL,
11 | xlab = "",
12 | ylab = "",
13 | main = NULL,
14 | cex.main = NULL,
15 | col.lab = "black",
16 | col.main = "black",
17 | cex.lab = 1.5,
18 | box = F
19 | )
20 | }
21 | \arguments{
22 | \item{x}{The numeric matrix to be visualized.}
23 |
24 | \item{xaxt}{The x-axis tick labels.}
25 |
26 | \item{yaxt}{The y-axis tick labels.}
27 |
28 | \item{xlab}{The x-axis labels.}
29 |
30 | \item{ylab}{The y-axis labels.}
31 |
32 | \item{main}{The main figure label.}
33 |
34 | \item{cex.main}{The main figure label sizing factor.}
35 |
36 | \item{col.lab}{Axis label color.}
37 |
38 | \item{col.main}{Main label color.}
39 |
40 | \item{cex.lab}{Cex for the labels.}
41 |
42 | \item{box}{Draw a box around the image?}
43 | }
44 | \value{
45 | Returns a heatmap representation of the matrix, \code{x}, with correct north/south/east/west orientation.
46 | }
47 | \description{
48 | A function for generating a heatmap representation of a matrix with correct spatial orientation.
49 | }
50 | \examples{
51 | #set seed
52 | set.seed(1)
53 |
54 | # Generate data
55 | x <- matrix(rnorm(50*50), ncol = 50)
56 | diag(x) <- 3
57 |
58 | # create plot
59 | image2(x, main = "Example Text", cex.main = 2)
60 |
61 | }
62 |
--------------------------------------------------------------------------------
/causalimages/man/message2.Rd:
--------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/CI_helperFxns.R
3 | \name{message2}
4 | \alias{message2}
5 | \title{message2 message() with timestamps}
6 | \usage{
7 | message2(text, quiet = FALSE)
8 | }
9 | \arguments{
10 | \item{text}{Character string to be displayed as message, with date and time.}
11 |
12 | \item{quiet}{Logical. If TRUE, suppresses the message output. Default is FALSE.}
13 | }
14 | \value{
15 | Displays message with date and time to stderr.
16 | }
17 | \description{
18 | A function that displays a message with date and time.
19 | }
20 | \examples{
21 | message2("Hello world")
22 | message2("Process completed", quiet = FALSE)
23 | }
24 |
--------------------------------------------------------------------------------
/causalimages/man/print2.Rd:
--------------------------------------------------------------------------------
1 | % Generated by roxygen2: do not edit by hand
2 | % Please edit documentation in R/CI_helperFxns.R
3 | \name{print2}
4 | \alias{print2}
5 | \title{print2 print() with timestamps}
6 | \usage{
7 | print2(text, quiet = F)
8 | }
9 | \arguments{
10 | \item{text}{Character string to be printed, with date and time.}
11 | }
12 | \value{
13 | Prints with date and time.
14 | }
15 | \description{
16 | A function that prints a string with date and time.
17 | }
18 | \examples{
19 | print2("Hello world")
20 | }
21 |
--------------------------------------------------------------------------------
/causalimages/tests/Test_AAARunAllTutorialsSuite.R:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env Rscript
2 | {
3 | # Test-suite driver: sources each tutorial test script in sequence and stops at the first failure. (TODO: remove all dim calls to arrays)
4 | ##########################################
5 | # Code for testing most functionalities of CausalImage on your hardware.
6 | # Current tests failing on optimizing video representation on METAL. Other tests succeed.
7 | ##########################################
8 | tryTests <- try({
9 | # remote install latest version of the package
10 | # devtools::install_github(repo = "cjerzak/causalimages-software/causalimages")
11 |
12 | # local install for development team
13 | # install.packages("~/Documents/causalimages-software/causalimages",repos = NULL, type = "source",force = F)
14 |
15 | # Before running tests, it may be necessary to (re)build the backend:
16 | # causalimages::BuildBackend()
17 | # See ?causalimages::BuildBackend for help. Remember to re-start your R session after re-building.
18 |
19 | options(error = NULL)
20 |
21 | print("Starting image TfRecords"); setwd("~");
22 | TfRecordsTest <- try(source("~/Documents/causalimages-software/causalimages/tests/Test_UsingTfRecords.R"), silent = TRUE)
23 | if("try-error" %in% class(TfRecordsTest)){ stop("Failed at TfRecordsTest (1)") }; try(dev.off(), TRUE)
24 |
25 | print("Starting ImageRepTest"); setwd("~");
26 | ImageRepTest <- try(source("~/Documents/causalimages-software/causalimages/tests/Test_ExtractImageRepresentations.R"), silent = TRUE)
27 | if("try-error" %in% class(ImageRepTest)){ stop("Failed at ImageRepTest (2)") }; try(dev.off(), TRUE)
28 |
29 | print("Starting ImConfoundTest"); setwd("~");
30 | ImConfoundTest <- try(source("~/Documents/causalimages-software/causalimages/tests/Test_AnalyzeImageConfounding.R"), silent = TRUE)
31 | if("try-error" %in% class(ImConfoundTest)){ stop("Failed at ImConfoundTest (3)") }; try(dev.off(), TRUE)
32 |
33 | #print("Starting HetTest"); setwd("~");
34 | #HetTest <- try(source("~/Documents/causalimages-software/causalimages/tests/Test_AnalyzeImageHeterogeneity.R"), silent = TRUE)
35 | #if("try-error" %in% class(HetTest)){ stop("Failed at HetTest") }; try(dev.off(), TRUE)
36 | }, silent = TRUE)
37 |
38 | if('try-error' %in% class(tryTests)){ print("At least one test failed"); print( tryTests ); stop(tryTests) }
39 | if(!'try-error' %in% class(tryTests)){ print("All tests succeeded!") }
40 | }
41 |
--------------------------------------------------------------------------------
/causalimages/tests/Test_AnalyzeImageConfounding.R:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env Rscript
2 | {
3 | ################################
4 | # Image confounding tutorial using causalimages
5 | ################################
6 | setwd("~/Downloads/"); options( error = NULL )
7 | #setwd("./"); options( error = NULL )
8 |
9 | # remote install latest version of the package if needed
10 | # devtools::install_github(repo = "cjerzak/causalimages-software/causalimages")
11 |
12 | # local install for development team
13 | # install.packages("~/Documents/causalimages-software/causalimages",repos = NULL, type = "source",force = F)
14 |
15 | # build the backend if you haven't already:
16 | # causalimages::BuildBackend()
17 |
18 | # load in package
19 | library( causalimages )
20 |
21 | # resave TfRecords?
22 | reSaveTFRecord <- F
23 |
24 | # load in tutorial data
25 | data( CausalImagesTutorialData )
26 |
27 | # mean imputation for toy example in this tutorial
28 | X <- apply(X[,-1],2,function(zer){
29 | zer[is.na(zer)] <- mean( zer,na.rm = T ); return( zer )
30 | })
31 |
32 | # select observation subset to make the tutorial quick
33 | set.seed(4321L);take_indices <- unlist( tapply(1:length(obsW),obsW,function(zer){sample(zer, 100)}) )
34 |
35 | # perform causal inference with image and tabular confounding
36 | {
37 | # example acquire image function (loading from memory)
38 | # in general, you'll want to write a function that returns images
39 | # saved on disk that are associated with the supplied keys
40 | acquireImageFxn <- function(keys){
41 | # Returns the image array for each unit-associated image key in `keys`,
42 | # indexing the in-memory FullImageArray (rows matched against KeysOfImages).
43 | # Spatial dims/channels are duplicated here purely for testing purposes.
44 | #m_ <- FullImageArray[match(keys, KeysOfImages),c(1:35,1:35),c(1:35,1:35),1:2] # test with two channels
45 | #m_ <- FullImageArray[match(keys, KeysOfImages),c(1:35,1:35),c(1:35,1:35),] # test with three channels
46 | m_ <- FullImageArray[match(keys, KeysOfImages),c(1:35,1:35),c(1:35,1:35),c(1:3,1:2)] # test with five channels
47 |
48 | # if length(keys) == 1, R's subsetting drops the leading batch dimension;
49 | # add it back so output dims are always (batch, height, width, channel)
50 | if(length(keys) == 1){ m_ <- array(m_,dim = c(1L,dim(m_)[1],dim(m_)[2],dim(m_)[3])) }
51 |
52 | return( m_ )
53 | }
54 |
55 | # write tf record
56 | TFRecordName_im <- "./ImageTutorial/TutorialData_im.tfrecord"
57 | if( reSaveTFRecord ){
58 | causalimages::WriteTfRecord(
59 | file = TFRecordName_im,
60 | uniqueImageKeys = unique(KeysOfObservations[ take_indices ]),
61 | acquireImageFxn = acquireImageFxn)
62 | }
63 |
64 | for(ImageModelClass in (c("VisionTransformer","CNN"))){
65 | for(optimizeImageRep in c(T,F)){
66 | print(sprintf("Image confounding analysis & optimizeImageRep: %s & ImageModelClass: %s",optimizeImageRep, ImageModelClass))
67 | ImageConfoundingAnalysis <- causalimages::AnalyzeImageConfounding(
68 | obsW = obsW[ take_indices ],
69 | obsY = obsY[ take_indices ],
70 | X = X[ take_indices,apply(X[ take_indices,],2,sd)>0],
71 | long = LongLat$geo_long[ take_indices ], # optional argument
72 | lat = LongLat$geo_lat[ take_indices ], # optional argument
73 | imageKeysOfUnits = KeysOfObservations[ take_indices ],
74 | file = TFRecordName_im,
75 |
76 | batchSize = 16L,
77 | nBoot = 5L,
78 | optimizeImageRep = optimizeImageRep,
79 | imageModelClass = ImageModelClass,
80 | nDepth_ImageRep = ifelse(optimizeImageRep, yes = 1L, no = 1L),
81 | nWidth_ImageRep = as.integer(2L^6),
82 | learningRateMax = 0.001, nSGD = 10L, #
83 | dropoutRate = NULL, # 0.1,
84 | plotBands = c(1,2,3),
85 | plotResults = T, figuresTag = "ConfoundingImTutorial",
86 | figuresPath = "./ImageTutorial")
87 | try(dev.off(), T)
88 | #ImageConfoundingAnalysis$ModelEvaluationMetrics
89 | }
90 | }
91 |
92 | # ATE estimate (image confounder adjusted)
93 | ImageConfoundingAnalysis$tauHat_propensityHajek
94 |
95 | # ATE se estimate (image confounder adjusted)
96 | ImageConfoundingAnalysis$tauHat_propensityHajek_se
97 |
98 | # some out-of-sample evaluation metrics
99 | ImageConfoundingAnalysis$ModelEvaluationMetrics
100 |
101 | }
102 |
103 | # perform causal inference with image *sequence* and tabular confounding
104 | {
105 | acquireVideoRep <- function(keys) {
106 | # Note: this is a toy function generating image-sequence ("video") arrays
107 | # that simply reuse a single temporal slice. In practice, we will
108 | # want to read in images of different time periods.
109 |
110 | # Get image data as an array, dims (batch, height, width, channel)
111 | tmp <- acquireImageFxn(keys)
112 |
113 | # Expand dimensions: prepend a singleton time dimension
114 | tmp <- array(tmp, dim = c(1, dim(tmp)))
115 |
116 | # Transpose to (batch, time, height, width, channel)
117 | tmp <- aperm(tmp, c(2, 1, 3, 4, 5))
118 |
119 | # Swap the spatial axes so the synthetic time steps differ from one another
120 | tmp_ <- aperm(tmp, c(1, 2, 4, 3, 5))
121 |
122 | # Concatenate along the time axis, yielding 4 time steps
123 | tmp <- abind::abind(tmp, tmp, tmp_, tmp_, along = 2)
124 |
125 | return(tmp)
126 | }
127 |
128 | # write tf record
129 | TFRecordName_imSeq <- "./ImageTutorial/TutorialData_imSeq.tfrecord"
130 | if( reSaveTFRecord ){
131 | causalimages::WriteTfRecord(
132 | file = TFRecordName_imSeq,
133 | uniqueImageKeys = unique(KeysOfObservations[ take_indices ]),
134 | acquireImageFxn = acquireVideoRep,
135 | writeVideo = T)
136 | }
137 |
138 | for(ImageModelClass in c("VisionTransformer","CNN")){
139 | for(optimizeImageRep in c(T, F)){
140 | print(sprintf("Image seq confounding analysis & optimizeImageRep: %s & ImageModelClass: %s",optimizeImageRep, ImageModelClass))
141 | ImageSeqConfoundingAnalysis <- causalimages::AnalyzeImageConfounding(
142 | obsW = obsW[ take_indices ],
143 | obsY = obsY[ take_indices ],
144 | X = X[ take_indices,apply(X[ take_indices,],2,sd)>0],
145 | long = LongLat$geo_long[ take_indices ],
146 | lat = LongLat$geo_lat[ take_indices ],
147 | file = TFRecordName_imSeq, dataType = "video",
148 | imageKeysOfUnits = KeysOfObservations[ take_indices ],
149 |
150 | # model specifics
151 | batchSize = 16L,
152 | optimizeImageRep = optimizeImageRep,
153 | imageModelClass = ImageModelClass,
154 | nDepth_ImageRep = ifelse(optimizeImageRep, yes = 1L, no = 1L),
155 | learningRateMax = 0.001,
156 | nSGD = 50L, #
157 | nWidth_ImageRep = as.integer(2L^7),
158 | nBoot = 5L,
159 | plotBands = c(1,2,3),
160 | plotResults = T, figuresTag = "ConfoundingImSeqTutorial",
161 | figuresPath = "./ImageTutorial") # figures saved here
162 | try(dev.off(), T)
163 | }
164 | }
165 |
166 | # ATE estimate (image confounder adjusted)
167 | ImageSeqConfoundingAnalysis$tauHat_propensityHajek
168 |
169 | # ATE se estimate (image seq confounder adjusted)
170 | ImageSeqConfoundingAnalysis$tauHat_propensityHajek_se
171 |
172 | # some out-of-sample evaluation metrics
173 | ImageSeqConfoundingAnalysis$ModelEvaluationMetrics
174 | print("Done with confounding test!")
175 | }
176 | }
177 |
--------------------------------------------------------------------------------
/causalimages/tests/Test_AnalyzeImageHeterogeneity.R:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env Rscript
2 | {
3 | ################################
4 | # Image heterogeneity tutorial using causalimages
5 | ################################
6 |
7 | # remote install latest version of the package
8 | # devtools::install_github(repo = "cjerzak/causalimages-software/causalimages")
9 |
10 | # local install for development team
11 | # install.packages("~/Documents/causalimages-software/causalimages",repos = NULL, type = "source",force = F)
12 |
13 | # build the backend if you haven't already:
14 | # causalimages::BuildBackend()
15 |
16 | # run code if downloading data for the first time
17 | download_folder <- "~/Downloads/UgandaAnalysis"
18 | reSaveTfRecords <- F
19 | if( reDownloadRawData <- F ){
20 |
21 | # (1) Specify the Dataverse dataset DOI
22 | doi <- "doi:10.7910/DVN/O8XOSF"
23 |
24 | # (2) Construct the download URL for the entire dataset as a .zip
25 | # By convention on Harvard Dataverse, this API endpoint:
26 | # https://{server}/api/access/dataset/:persistentId?persistentId={doi}&format=original
27 | # downloads all files in a single zip.
28 | base_url <- "https://dataverse.harvard.edu"
29 | download_url <- paste0(base_url,
30 | "/api/access/dataset/:persistentId",
31 | "?persistentId=", doi,
32 | "&format=original")
33 |
34 | # (3) Download the ZIP file
35 | destfile <- "~/Downloads/UgandaAnalysis.zip"
36 | download.file(download_url, destfile = destfile, mode = "wb")
37 |
38 | # unzip and list files
39 | unzip(destfile, exdir = "~/Downloads/UgandaAnalysis")
40 | unzip('./UgandaAnalysis/Uganda2000_processed.zip', exdir = "~/Downloads/UgandaAnalysis")
41 | }
42 |
43 | # load in package
44 | library( causalimages ); options(error = NULL)
45 |
46 | # set new wd
47 | setwd( "~/Downloads" )
48 |
49 | # see directory contents
50 | list.files()
51 |
52 | # images saved here
53 | list.files( "./UgandaAnalysis/Uganda2000_processed" )
54 |
55 | # individual-level data
56 | UgandaDataProcessed <- read.csv( "./UgandaAnalysis/UgandaDataProcessed.csv" )
57 |
58 | # unit-level covariates (many covariates are subject to missingness!)
59 | dim( UgandaDataProcessed )
60 | table( UgandaDataProcessed$age )
61 |
62 | # approximate longitude + latitude for units
63 | head( cbind(UgandaDataProcessed$geo_long, UgandaDataProcessed$geo_lat) )
64 |
65 | # image keys of units (use for referencing satellite images)
66 | UgandaDataProcessed$geo_long_lat_key
67 |
68 | # an experimental outcome
69 | UgandaDataProcessed$Yobs
70 |
71 | # treatment variable
72 | UgandaDataProcessed$Wobs
73 |
74 | # information on keys linking to satellite images for all of Uganda
75 | # (not just experimental context, use for constructing transportability maps)
76 | UgandaGeoKeyMat <- read.csv( "./UgandaAnalysis/UgandaGeoKeyMat.csv" )
77 |
78 | # set outcome to an income index
79 | UgandaDataProcessed$Yobs <- UgandaDataProcessed$income_index_e_RECREATED
80 |
81 | # drop observations with NAs in key variables
82 | # (you can also use a multiple imputation strategy)
83 | UgandaDataProcessed <- UgandaDataProcessed[!is.na(UgandaDataProcessed$Yobs) &
84 | !is.na(UgandaDataProcessed$Wobs) &
85 | !is.na(UgandaDataProcessed$geo_lat) , ]
86 |
87 | # sanity checks
88 | {
89 | # write a function that reads in images as saved and process them into an array
90 | NBANDS <- 3L
91 | imageHeight <- imageWidth <- 351L # pixel height/width
92 | acquireImageRep <- function(keys){
93 | # initialize an array shell, dims (1, imageHeight, imageWidth, NBANDS), to hold one image
94 | array_shell <- array(NA, dim = c(1L, imageHeight, imageWidth, NBANDS))
95 |
96 | # iterate over keys:
97 | # -- images are referenced to keys
98 | # -- keys are referenced to units (to allow for duplicate image uses)
99 | array_ <- sapply(keys, function(key_) {
100 | # iterate over all image bands (NBANDS = 3 for RGB images)
101 | for (band_ in 1:NBANDS) {
102 | # read the band CSV and place it in the band slice (first row dropped -- presumably a stored header row; confirm)
103 | array_shell[,,,band_] <-
104 | as.matrix(data.table::fread(
105 | input = sprintf("./UgandaAnalysis/Uganda2000_processed/GeoKey%s_BAND%s.csv", key_, band_), header = FALSE)[-1,])
106 | }
107 | return(array_shell)
108 | }, simplify = "array")
109 |
110 | # return the array in the format c(nBatch, imageHeight, imageWidth, nChannels);
111 | # sapply stacked results along a trailing key dimension, so permute back to batch-leading order
112 | if(length(keys) > 1){ array_ <- aperm(array_[1,,,,], c(4, 1, 2, 3) ) }
113 | if(length(keys) == 1){
114 | array_ <- aperm(array_, c(1,5, 2, 3, 4))
115 | array_ <- array(array_, dim(array_)[-1])
116 | }
117 |
118 | return(array_)
119 | }
120 |
121 | # try out the function
122 | # note: some units are co-located in same area (hence, multiple observations per image key)
123 | ImageBatch <- acquireImageRep( UgandaDataProcessed$geo_long_lat_key[ check_indices <- c(1, 20, 50, 101) ])
124 | acquireImageRep( UgandaDataProcessed$geo_long_lat_key[ check_indices[1] ] )
125 |
126 | # sanity checks in the analysis of earth observation data are essential
127 | # check that images are centered around correct location
128 | causalimages::image2( as.array(ImageBatch)[1,,,1] )
129 | UgandaDataProcessed$geo_long[check_indices[1]]
130 | UgandaDataProcessed$geo_lat[check_indices[1]]
131 | # check against google maps to confirm correctness
132 | # https://www.google.com/maps/place/1%C2%B018'16.4%22N+34%C2%B005'15.1%22E/@1.3111951,34.0518834,10145m/data=!3m1!1e3!4m4!3m3!8m2!3d1.3045556!4d34.0875278?entry=ttu
133 |
134 | # scramble data (important for reading into causalimages::WriteTfRecord
135 | # to ensure no systematic biases in data sequence with model training
136 | set.seed(144L); UgandaDataProcessed <- UgandaDataProcessed[sample(1:nrow(UgandaDataProcessed)),]
137 | }
138 |
139 | # Image heterogeneity example with tfrecords
140 | # write a tf records repository
141 | # whenever changes are made to the input data to AnalyzeImageHeterogeneity, WriteTfRecord() should be re-run
142 | # to ensure correct ordering of data
143 | tfrecord_loc <- "~/Downloads/UgandaExample.tfrecord"
144 | if( reSaveTfRecords ){ # reSaveTfRecords presumably defined earlier in this script — confirm
145 | causalimages::WriteTfRecord(
146 | file = tfrecord_loc,
147 | uniqueImageKeys = unique(UgandaDataProcessed$geo_long_lat_key),
148 | acquireImageFxn = acquireImageRep )
149 | }
150 |
151 | for(ImageModelClass in c("VisionTransformer","CNN")){
152 | for(optimizeImageRep in c(T, F)){
153 | print(sprintf("Image hetero analysis & optimizeImageRep: %s",optimizeImageRep))
154 | ImageHeterogeneityResults <- causalimages::AnalyzeImageHeterogeneity(
155 | # data inputs
156 | obsW = UgandaDataProcessed$Wobs,
157 | obsY = UgandaDataProcessed$Yobs,
158 | X = matrix(rnorm(length(UgandaDataProcessed$Yobs)*10),ncol=10), # X is random noise here (placeholder covariates for the test)
159 | imageKeysOfUnits = UgandaDataProcessed$geo_long_lat_key,
160 | file = tfrecord_loc, # location of tf record (use absolute file paths)
161 | lat = UgandaDataProcessed$geo_lat, # not required but helpful for dealing with redundant locations in EO data
162 | long = UgandaDataProcessed$geo_long, # not required but helpful for dealing with redundant locations in EO data
163 |
164 | # inputs to control where visual results are saved as PDF or PNGs
165 | # (these image grids are large and difficult to display in RStudio's interactive mode)
166 | plotResults = T,
167 | figuresPath = "~/Downloads/HeteroTutorial", # where to write analysis figures
168 | figuresTag = "HeterogeneityImTutorial",plotBands = 1L:3L,
169 |
170 | # optional arguments for generating transportability maps
171 | # here, we leave those NULL for simplicity
172 | transportabilityMat = NULL, #
173 |
174 | # other modeling options
175 | imageModelClass = ImageModelClass,
176 | nSGD = 5L, # make this larger for real applications (e.g., 2000L)
177 | nDepth_ImageRep = ifelse(optimizeImageRep, yes = 1L, no = 1L), # NOTE(review): both branches are 1L, so this ifelse is a no-op — confirm whether depths were meant to differ
178 | nWidth_ImageRep = as.integer(2L^6),
179 | optimizeImageRep = optimizeImageRep,
180 | batchSize = 8L, # make this larger for real application (e.g., 50L)
181 | kClust_est = 1 # vary depending on problem. Usually < 5
182 | )
183 | try(dev.off(), T) # close any graphics device left open by plotting; errors silenced deliberately
184 | }
185 | }
186 |
187 | # video heterogeneity example
188 | {
189 | acquireVideoRep <- function(keys) { # builds a synthetic 2-frame "video" per key from the single image
190 | # Get image data as an array from disk
191 | tmp <- acquireImageRep(keys)
192 |
193 | # Expand dimensions: we create a new dimension at the start
194 | tmp <- array(tmp, dim = c(1, dim(tmp)))
195 |
196 | # Transpose dimensions to get the required order
197 | tmp <- aperm(tmp, c(2, 1, 3, 4, 5))
198 |
199 | # Swap image height/width dims so frame 2 differs from frame 1 (fake variability across time)
200 | tmp_ <- aperm(tmp, c(1, 2, 4, 3, 5))
201 |
202 | # Concatenate along the second (time) axis: frame 1 = original, frame 2 = transposed image
203 | tmp <- abind::abind(tmp, tmp_, along = 2)
204 |
205 | return(tmp)
206 | }
207 |
208 | # write the tf records repository
209 | tfrecord_loc_imSeq <- "~/Downloads/UgandaExampleVideo.tfrecord"
210 | if(reSaveTfRecords){
211 | causalimages::WriteTfRecord( file = tfrecord_loc_imSeq,
212 | uniqueImageKeys = unique(UgandaDataProcessed$geo_long_lat_key),
213 | acquireImageFxn = acquireVideoRep, writeVideo = T )
214 | }
215 |
216 | for(ImageModelClass in (c("VisionTransformer","CNN"))){
217 | for(optimizeImageRep in c(T, F)){
218 | print(sprintf("Image seq hetero analysis & optimizeImageRep: %s",optimizeImageRep))
219 | # Note: optimizeImageRep = T breaks with video on METAL framework
220 | VideoHeterogeneityResults <- causalimages::AnalyzeImageHeterogeneity(
221 | # data inputs
222 | obsW = UgandaDataProcessed$Wobs,
223 | obsY = UgandaDataProcessed$Yobs,
224 | imageKeysOfUnits = UgandaDataProcessed$geo_long_lat_key,
225 | file = tfrecord_loc_imSeq, # location of tf record (absolute paths are safest)
226 | dataType = "video",
227 | lat = UgandaDataProcessed$geo_lat, # not required but helpful for dealing with redundant locations in EO data
228 | long = UgandaDataProcessed$geo_long, # not required but helpful for dealing with redundant locations in EO data
229 |
230 | # inputs to control where visual results are saved as PDF or PNGs
231 | # (these image grids are large and difficult to display in RStudio's interactive mode)
232 | plotResults = T,
233 | figuresPath = "~/Downloads/HeteroTutorial",
234 | plotBands = 1L:3L, figuresTag = "HeterogeneityImSeqTutorial",
235 |
236 | # optional arguments for generating transportability maps
237 | # here, we leave those NULL for simplicity
238 | transportabilityMat = NULL, #
239 |
240 | # other modeling options
241 | imageModelClass = ImageModelClass,
242 | nSGD = 5L, # make this larger for real applications (e.g., 2000L)
243 | nDepth_ImageRep = ifelse(optimizeImageRep, yes = 1L, no = 1L), # NOTE(review): both branches are 1L — no-op ifelse, confirm intent
244 | nWidth_ImageRep = as.integer(2L^5),
245 | optimizeImageRep = optimizeImageRep,
246 | kClust_est = 2, # vary depending on problem. Usually < 5
247 | batchSize = 8L, # make this larger for real application (e.g., 50L)
248 | strides = 2L )
249 | try(dev.off(), T) # close any open graphics device; errors silenced deliberately
250 | }
251 | }
252 | }
253 | print("Done with image heterogeneity test!")
254 | }
255 |
--------------------------------------------------------------------------------
/causalimages/tests/Test_BuildBackend.R:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env Rscript
2 | {
3 | # remote install latest version of the package
4 | # devtools::install_github(repo = "cjerzak/causalimages-software/causalimages")
5 |
6 | # local install for development team
7 | # install.packages("~/Documents/causalimages-software/causalimages",repos = NULL, type = "source",force = F)
8 |
9 | # setup backend, conda points to location of python in conda where packages will be downloaded
10 | # note: This function requires an Internet connection
11 | # you can find out a list of conda Python paths via:
12 | # system("which python")
13 | causalimages::BuildBackend(conda = "/Users/cjerzak/miniforge3/bin/python") # NOTE(review): machine-specific path — will fail on CI/other machines; consider parameterizing
14 | print("Done with BuildBackend() test!")
15 | }
16 |
17 |
--------------------------------------------------------------------------------
/causalimages/tests/Test_ExtractImageRepresentations.R:
--------------------------------------------------------------------------------
1 | {
2 | ################################
3 | # Image and image-sequence embeddings tutorial using causalimages
4 | ################################
5 |
6 | # remote install latest version of the package if needed
7 | # devtools::install_github(repo = "cjerzak/causalimages-software/causalimages")
8 |
9 | # local install for development team
10 | # install.packages("~/Documents/causalimages-software/causalimages",repos = NULL, type = "source",force = F)
11 |
12 | # build the backend if you haven't already:
13 | # causalimages::BuildBackend()
14 |
15 | # load in package
16 | library( causalimages ); options(error = NULL)
17 |
18 | # load in tutorial data
19 | data( CausalImagesTutorialData ) # presumably provides FullImageArray, KeysOfImages, KeysOfObservations, X, obsW used below — confirm in package data docs
20 |
21 | # example acquire image function (loading from memory)
22 | # in general, you'll want to write a function that returns images
23 | # that are saved to disk, associated with keys
24 | acquireImageFromMemory <- function(keys){
25 | # here, the function input keys
26 | # refers to the unit-associated image keys
27 | m_ <- FullImageArray[match(keys, KeysOfImages),,,]
28 |
29 | # if keys == 1, add the batch dimension so output dims are always consistent
30 | # (here in image case, dims are batch by height by width by channel)
31 | if(length(keys) == 1){
32 | m_ <- array(m_,dim = c(1L,dim(m_)[1],dim(m_)[2],dim(m_)[3]))
33 | }
34 | return( m_ )
35 | }
36 |
37 | # drop first column
38 | X <- X[,-1]
39 |
40 | # mean imputation for simplicity
41 | X <- apply(X,2,function(zer){
42 | zer[is.na(zer)] <- mean( zer,na.rm = T )
43 | return( zer )
44 | })
45 |
46 | # select observation subset to make tutorial analyses run faster
47 | take_indices <- unlist( tapply(1:length(obsW),obsW,function(zer){ sample(zer, 50) }) ) # NOTE(review): no set.seed before sample(), so the subset is not reproducible — confirm intended
48 |
49 | # write tf record
50 | TfRecord_name <- "~/Downloads/CausalImagesTutorialDat.tfrecord"
51 | causalimages::WriteTfRecord( file = TfRecord_name,
52 | uniqueImageKeys = unique( KeysOfObservations[ take_indices ] ),
53 | acquireImageFxn = acquireImageFromMemory )
54 |
55 | # obtain image representation
56 | MyImageEmbeddings <- causalimages::GetImageRepresentations(
57 | file = TfRecord_name,
58 | imageModelClass = "VisionTransformer",
59 | pretrainedModel = "clip-rsicd",
60 | imageKeysOfUnits = KeysOfObservations[ take_indices ]
61 | )
62 |
63 | # each row in MyImageEmbeddings$ImageRepresentations corresponds to an observation
64 | # each column represents an embedding dimension associated with the imagery for that location
65 | dim( MyImageEmbeddings$ImageRepresentations )
66 | plot( MyImageEmbeddings$ImageRepresentations )
67 |
68 | # other output quantities include the image model functions and model parameters
69 | names( MyImageEmbeddings )[-1]
70 |
71 | print("Done with image representations test!")
72 | }
73 |
--------------------------------------------------------------------------------
/causalimages/tests/Test_UsingTfRecords.R:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env Rscript
2 | {
3 | ################################
4 | # Image confounding tutorial using causalimages
5 | # and tfrecords for faster results
6 | ################################
7 |
8 | # remote install latest version of the package if needed
9 | # devtools::install_github(repo = "cjerzak/causalimages-software/causalimages")
10 |
11 | # local install for development team
12 | # install.packages("~/Documents/causalimages-software/causalimages",repos = NULL, type = "source",force = F)
13 |
14 | # build the backend if you haven't already:
15 | # causalimages::BuildBackend()
16 |
17 | # load in package
18 | library( causalimages )
19 |
20 | # load in tutorial data
21 | data( CausalImagesTutorialData ) # presumably provides FullImageArray, KeysOfImages, KeysOfObservations, X, obsW used below — confirm in package data docs
22 |
23 | # example acquire image function (loading from memory)
24 | # in general, you'll want to write a function that returns images
25 | # that are saved to disk, associated with keys
26 | acquireImageFromMemory <- function(keys){
27 | # here, the function input keys
28 | # refers to the unit-associated image keys
29 | m_ <- FullImageArray[match(keys, KeysOfImages),,,]
30 |
31 | # if keys == 1, add the batch dimension so output dims are always consistent
32 | # (here in image case, dims are batch by height by width by channel)
33 | if(length(keys) == 1){
34 | m_ <- array(m_,dim = c(1L,dim(m_)[1],dim(m_)[2],dim(m_)[3]))
35 | }
36 |
37 | # uncomment for a test with different image dimensions
38 | #if(length(keys) == 1){ m_ <- abind::abind(m_,m_,m_,along = 3L) }; if(length(keys) > 1){ m_ <- abind::abind(m_,m_,m_,.along = 4L) }
39 | return( m_ )
40 | }
41 |
42 | dim( acquireImageFromMemory(KeysOfImages[1]) )
43 | dim( acquireImageFromMemory(KeysOfImages[1:2]) )
44 |
45 | # drop first column
46 | X <- X[,-1]
47 |
48 | # mean imputation for simplicity
49 | X <- apply(X,2,function(zer){
50 | zer[is.na(zer)] <- mean( zer,na.rm = T )
51 | return( zer )
52 | })
53 |
54 | # select observation subset to make tutorial analyses run faster
55 | # select 50 treatment and 50 control observations
56 | set.seed(1.) # seed makes the subsample reproducible (numeric 1. is accepted by set.seed)
57 | take_indices <- unlist( tapply(1:length(obsW),obsW,function(zer){sample(zer, 50)}) )
58 |
59 | # !!! important note !!!
60 | # when using tf records, it is essential that the data inputs be pre-shuffled as is done here.
61 | # you can use a seed for reproducing the shuffle (so the tfrecord is correctly indexed and you don't need to re-make it)
62 | # tf records read data quasi-sequentially, so systematic patterns in the data ordering
63 | # reduce performance
64 |
65 | # uncomment for a larger n analysis
66 | #take_indices <- 1:length( obsY )
67 |
68 | # set tfrecord save location (safest using absolute path)
69 | tfrecord_loc <- "./ExampleRecord.tfrecord" # NOTE(review): relative path despite the advice above — confirm the working directory at run time
70 |
71 | # you may use relative paths like this:
72 | # tfrecord_loc <- "./Downloads/test1/test2/test3/ExampleRecord.tfrecord"
73 |
74 | # or absolute paths like this:
75 | # tfrecord_loc <- "~/Downloads/test1/test2/test3/ExampleRecord.tfrecord"
76 |
77 | # write a tf records repository
78 | causalimages::WriteTfRecord( file = tfrecord_loc,
79 | uniqueImageKeys = unique( KeysOfObservations[ take_indices ] ),
80 | acquireImageFxn = acquireImageFromMemory )
81 | }
82 |
83 |
--------------------------------------------------------------------------------
/documentPackage.R:
--------------------------------------------------------------------------------
1 | {
2 | rm(list=ls()); options(error = NULL) # maintainer-only script: clears the interactive workspace (avoid in package code)
3 | package_name <- "causalimages"
4 | setwd(sprintf("~/Documents/%s-software", package_name)) # NOTE(review): machine-specific path — works only on the maintainer's machine
5 |
6 | package_path <- sprintf("~/Documents/%s-software/%s",package_name,package_name)
7 | tools::add_datalist(package_path, force = TRUE)
8 | devtools::document(package_path)
9 | try(file.remove(sprintf("./%s.pdf",package_name)),T) # remove stale manual PDF; errors silenced deliberately
10 | system(sprintf("R CMD Rd2pdf %s",package_path))
11 |
12 | # install.packages( sprintf("~/Documents/%s-software/%s",package_name,package_name),repos = NULL, type = "source")
13 | # library( causalimages ); data( CausalImagesTutorialData )
14 | log(sort( sapply(ls(),function(l_){ object.size(eval(parse(text=l_))) }))) # log object sizes of workspace objects; NOTE(review): get(l_) would be cleaner than eval(parse())
15 |
16 | # Check package to ensure it meets CRAN standards.
17 | # devtools::check( package_path )
18 |
19 | # see https://github.com/RConsortium/S7
20 | }
--------------------------------------------------------------------------------
/misc/dataverse/DataverseReadme_confounding.md:
--------------------------------------------------------------------------------
1 | Replication Data for: Integrating Earth Observation Data into Causal Inference: Challenges and Opportunities
2 |
3 | Connor T. Jerzak, Fredrik Johansson, Adel Daoud. Integrating Earth Observation Data into Causal Inference: Challenges and Opportunities. ArXiv Preprint, 2023.
4 |
5 | YandW_mat.csv contains individual-level observational data. In the dataset, LONGITUDE and LATITUDE refer to the approximate geo-referenced long/lat of observational units. Observed outcomes are stored in Yobs. Treatment variable is stored in Wobs. See the tutorial for more information. The unique image key for each observational unit is saved in UNIQUE_ID.
6 |
7 | Geo-referenced satellite images are saved in
8 | "./Nigeria2000_processed/%s_BAND%s.csv", where the first "%s" refers to the image key associated with each observation (saved in UNIQUE_ID in YandW_mat.csv) and BAND%s refers to one of 3 bands in the satellite imagery.
9 |
10 | After downloading the replication package, unzip `Nigeria2000_processed.zip` so
11 | that the `Nigeria2000_processed` folder containing the band CSV files resides in
12 | the same directory as `YandW_mat.csv`.
13 |
14 | For more information, see: https://github.com/cjerzak/causalimages-software/
15 |
--------------------------------------------------------------------------------
/misc/dataverse/DataverseReadme_heterogeneity.md:
--------------------------------------------------------------------------------
1 | Replication Data for: Image-based Treatment Effect Heterogeneity
2 |
3 | Connor Thomas Jerzak, Fredrik Daniel Johansson, Adel Daoud Proceedings of the Second Conference on Causal Learning and Reasoning, PMLR 213:531-552, 2023.
4 |
5 | UgandaDataProcessed.csv contains individual-level data from the YOP experiment. In the dataset, geo_long and geo_lat refer to the approximate geo-referenced long/lat of experimental units. The variable, geo_long_lat_key, refers to the image key associated with each location. Experimental outcomes are stored in Yobs. Treatment variable is stored in Wobs. See the tutorial for more information.
6 |
7 | UgandaGeoKeyMat.csv contains information on keys linking to satellite images for all of Uganda for the transportability analysis.
8 |
9 | Geo-referenced satellite images are saved in "./Uganda2000_processed/GeoKey%s_BAND%s.csv", where GeoKey%s denotes the image key associated with each observation and BAND%s refers to one of 3 bands in the satellite imagery.
10 |
11 | Unzip `Uganda2000_processed.zip` so the `Uganda2000_processed` directory
12 | containing the band CSVs sits alongside `UgandaDataProcessed.csv` before running
13 | the tutorial scripts.
14 |
15 | For more information, see: https://github.com/cjerzak/causalimages-software/blob/main/tutorials/AnalyzeImageHeterogeneity_FullTutorial.R
16 |
--------------------------------------------------------------------------------
/misc/dataverse/DataverseTutorial_confounding.R:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env Rscript
2 |
3 | ################################
4 | # For an up-to-date version of this tutorial, see
5 | # https://github.com/cjerzak/causalimages-software/tree/main/tutorials
6 | ################################
7 |
8 | # set new wd
9 |
--------------------------------------------------------------------------------
/misc/dataverse/DataverseTutorial_heterogeneity.R:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env Rscript
2 |
3 | ################################
4 | # For an up-to-date version of this tutorial, see
5 | # https://github.com/cjerzak/causalimages-software/blob/main/tutorials/AnalyzeImageHeterogeneity_FullTutorial.R
6 | ################################
7 |
8 | # set new wd (assumes download_folder was set earlier to the path of the downloaded .zip)
9 | setwd(sprintf('%s/Public Replication Data, YOP Experiment/',
10 | gsub(download_folder,pattern="\\.zip",replace="")))
11 |
12 | # see directory contents
13 | list.files()
14 |
15 | # images saved here
16 | list.files( "./Uganda2000_processed" )
17 |
18 | # individual-level data
19 | UgandaDataProcessed <- read.csv( "./UgandaDataProcessed.csv" )
20 |
21 | # unit-level covariates (many covariates are subject to missingness!)
22 | dim( UgandaDataProcessed )
23 | table( UgandaDataProcessed$female )
24 | table( UgandaDataProcessed$age )
25 |
26 | # approximate longitude + latitude for units
27 | UgandaDataProcessed$geo_long
28 | UgandaDataProcessed$geo_lat
29 |
30 | # image keys of units (use for referencing satellite images)
31 | UgandaDataProcessed$geo_long_lat_key
32 |
33 | # an experimental outcome
34 | UgandaDataProcessed$Yobs
35 |
36 | # treatment variable
37 | UgandaDataProcessed$Wobs
38 |
39 | # information on keys linking to satellite images for all of Uganda
40 | # (not just experimental context, use for constructing transportability maps)
41 | UgandaGeoKeyMat <- read.csv( "./UgandaGeoKeyMat.csv" )
42 | tail( UgandaGeoKeyMat )
43 |
44 | # Geo-referenced satellite images are saved in
45 | # "./Uganda2000_processed/GeoKey%s_BAND%s.csv",
46 | # where GeoKey%s denotes the image key associated with each observation and
47 | # BAND%s refers to one of 3 bands in the satellite imagery.
48 | # See https://github.com/cjerzak/causalimages-software/blob/main/tutorials/AnalyzeImageHeterogeneity_FullTutorial.R
49 | # for up-to-date usage information.
50 |
--------------------------------------------------------------------------------
/misc/docker/setup/CodexHelpers.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -euo pipefail
4 | # set -x
5 |
6 | #------------------------------------------------------------------------------
7 | # 1. Install R and unzip, other utilities
8 | #------------------------------------------------------------------------------
9 | apt-get update \
10 | && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
11 | r-base r-base-dev \
12 | libssl-dev libxml2-dev unzip \
13 | build-essential libcurl4-openssl-dev \
14 | libgdal-dev libgeos-dev libproj-dev libudunits2-dev ca-certificates \
15 | && apt-get clean \
16 | && rm -rf /var/lib/apt/lists/*
17 |
18 | #------------------------------------------------------------------------------
19 | # 2. Install Miniconda into /opt/conda
20 | #------------------------------------------------------------------------------
21 | readonly MINICONDA_SH="/tmp/Miniconda3-latest-Linux-x86_64.sh"
22 | wget --quiet "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" -O "$MINICONDA_SH"
23 | bash "$MINICONDA_SH" -b -p /opt/conda
24 | rm "$MINICONDA_SH"
25 |
26 | # Make sure conda is on PATH
27 | export PATH="/opt/conda/bin:${PATH}"
28 |
29 | # Initialize conda for this shell
30 | if [[ -f "/opt/conda/etc/profile.d/conda.sh" ]]; then
31 | source "/opt/conda/etc/profile.d/conda.sh"
32 | else
33 | echo "Error: /opt/conda/etc/profile.d/conda.sh not found" >&2
34 | exit 1
35 | fi
36 |
37 |
38 | # Define URLs for ZIP of tars and conda ZIP
39 | THE_ZIP_URL="https://www.dl.dropboxusercontent.com/scl/fi/ek76700s9o4p4grwfw5ko/Archive.zip?rlkey=w0vj0qh78i3eb3c5zt4vlmot9&st=f8kxeq6t&dl=1"
40 | readonly CONDA_ENV_ZIP_URL="https://www.dl.dropboxusercontent.com/scl/fi/k5vylxygjl4icm76drtsz/CausalImagesEnv.zip?rlkey=hmmwpma9bihoze25vee44dktz&st=0vz2vmbk&dl=1"
41 |
42 | #------------------------------------------------------------------------------
43 | # 3. Fetch and unpack the ZIP-of-tar's
44 | #------------------------------------------------------------------------------
45 | BUNDLE_ZIP="binaries.zip"
46 |
47 | # download the bundle
48 | curl -sSL -o "/tmp/${BUNDLE_ZIP}" "${THE_ZIP_URL}"
49 |
50 | # unpack into a temp directory
51 | mkdir -p /tmp/binaries
52 | unzip -q "/tmp/${BUNDLE_ZIP}" -d /tmp/binaries
53 |
54 | echo "Files extracted to /tmp/binaries:"
55 | find /tmp/binaries -type f -printf " ➜ %P\n"
56 |
57 | #------------------------------------------------------------------------------
58 | # 4. Install each inner ZIP into R's library path
59 | #------------------------------------------------------------------------------
60 | # find the first (user) library path and make sure it exists
61 | LIB="$(Rscript -e 'cat(.libPaths()[1])')"
62 | mkdir -p "$LIB"
63 |
64 | #echo "Binary ZIPs found under /tmp/binaries:"
65 | #for f in /tmp/binaries/*.tar.gz; do
66 | # echo " ➜ $f"
67 | # tar -xzf "$f" -C "$LIB"
68 | #done
69 |
70 | echo "Installing binary packages into $LIB:"
71 | for f in /tmp/binaries/*.tar.gz; do
72 | pkg=$(basename "$f" | sed -E 's/_.*//') # e.g. DBI_1.2.3... → DBI
73 | echo " ➜ $pkg"
74 | mkdir -p "$LIB/$pkg" # ensure lib/pkg exists (NOTE: tar below extracts into $LIB, not $LIB/$pkg — the tarball's top-level dir supplies $pkg; confirm)
75 | #tar --strip-components=1 -xzf "$f" -C "$LIB/$pkg" # drop top‐level folder
76 | tar -xzf "$f" -C "$LIB"
77 |
78 | echo " Contents of $LIB/$pkg:"
79 | find "$LIB/$pkg" -maxdepth 1 -mindepth 1 -printf " ➜ %f\n"
80 | done
81 |
82 | # clean up temporary files
83 | rm -rf /tmp/${BUNDLE_ZIP} /tmp/binaries
84 |
85 | echo "Installed R packages in $LIB:"
86 | Rscript -e '
87 | #lib <- Sys.getenv("R_LIBS_USER", .libPaths()[1])
88 | #pkgs <- rownames(installed.packages(lib.loc = lib))
89 | pkgs <- rownames(installed.packages(lib.loc = .libPaths()))
90 | cat(paste0("Installed packages found ➜ ", pkgs, "\n"), sep = "")
91 | '
92 |
93 | echo "Installing causalimages backend..."
94 | readonly CONDA_ENV_ZIP="/tmp/conda_env.zip"
95 |
96 | echo "Downloading pre-built conda environment from $CONDA_ENV_ZIP_URL"
97 | curl -sSL -o "$CONDA_ENV_ZIP" "$CONDA_ENV_ZIP_URL"
98 |
99 | echo "Unpacking conda environment into /opt/conda/envs/"
100 | unzip -q "$CONDA_ENV_ZIP" -d "/opt/conda/envs/"
101 | rm "$CONDA_ENV_ZIP"
102 |
103 | echo "Available conda environments:"
104 | conda env list
105 |
106 | echo "Success: Done with download & unpacking script! Ready to experiment."
107 |
--------------------------------------------------------------------------------
/misc/docker/setup/Dockerfile:
--------------------------------------------------------------------------------
1 | # NOTE: despite the filename "Dockerfile", this file contains a shell `docker run` invocation, not Dockerfile syntax.
2 | docker run --rm -it \
3 | --platform=linux/amd64 \
4 | -e CODEX_ENV_PYTHON_VERSION=3.12 \
5 | -e CODEX_ENV_NODE_VERSION=20 \
6 | -e CODEX_ENV_RUST_VERSION=1.87.0 \
7 | -e CODEX_ENV_GO_VERSION=1.23.8 \
8 | -e CODEX_ENV_SWIFT_VERSION=6.1 \
9 | -v "$HOME/Documents/causalimages-software/misc/docker/binaries:/binaries" \
10 | ghcr.io/openai/codex-universal:latest \
11 | -exc "
12 | set -euo pipefail
13 |
14 | echo \"🔧 Installing system development libraries...\"
15 | # Prepare apt for CRAN
16 | apt-get update -qq && \
17 | apt-get install -y --no-install-recommends \
18 | software-properties-common dirmngr ca-certificates && \
19 |
20 | # Install R base + dev headers + all system libraries
21 | DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
22 | r-base r-base-dev \
23 | build-essential cmake pkg-config git zip \
24 | libcurl4-openssl-dev libssl-dev libxml2-dev \
25 | libgdal-dev libgeos-dev libproj-dev libudunits2-dev \
26 | ca-certificates && \
27 | apt-get clean && \
28 | rm -rf /var/lib/apt/lists/*
29 |
30 | # Clean previous artifacts and prepare directories
31 | rm -rf /binaries/* && mkdir -p /binaries/src /binaries/bin
32 |
33 | # List of CRAN packages in dependency-safe order
34 | pkgs=( \
35 | remotes yaml config ps R6 processx Rcpp RcppEigen RcppTOML rprojroot here jsonlite png rappdirs rlang withr reticulate base64enc magrittr whisker cli glue lifecycle vctrs tidyselect rstudioapi tfruns backports tfautograph tensorflow sp geosphere terra raster rrapply iterators foreach shape glmnet proxy e1071 classInt DBI wk s2 units sf data.table plyr pROC Matrix lattice survival codetools class KernSmooth MASS \
36 | )
37 |
38 | echo \"🔧 Installing Miniconda...\"
39 | MINI=Miniconda3-latest-Linux-x86_64.sh
40 | wget -q https://repo.anaconda.com/miniconda/\${MINI} -O /tmp/\${MINI} && \
41 | bash /tmp/\${MINI} -b -p /opt/miniconda && \
42 | rm /tmp/\${MINI} && \
43 | export PATH=\"/opt/miniconda/bin:\$PATH\" && \
44 | conda config --set always_yes yes --set changeps1 no && \
45 | conda update -q conda
46 |
47 |
48 | # Download CRAN package sources
49 | echo \"📥 Downloading CRAN package sources...\"
50 | cd /binaries/src
51 | for pkg in \"\${pkgs[@]}\"; do
52 | echo \"📦 Downloading \$pkg source...\"
53 | Rscript -e \"download.packages('\$pkg', destdir='/binaries/src', type='source', repos='https://cloud.r-project.org')\" \
54 | || { echo \"❌ Failed to download \$pkg\"; exit 1; }
55 | # verify that a tar.gz actually appeared
56 | if ! compgen -G \"\${pkg}_*.tar.gz\" >/dev/null; then
57 | echo \"❌ No source tarball found for \$pkg after download, likely installed in base-R (recommended package)?.\" >&2
58 | # exit 1
59 | continue
60 | fi
61 | done
62 |
63 | echo \"✅ Downloaded all CRAN sources ✅\"
64 |
65 | # Build CRAN package binaries in correct dependency order
66 | echo \"🛠️ Building CRAN package binaries...\"
67 | cd /binaries/bin
68 | for pkg in \"\${pkgs[@]}\"; do
69 | src_tar=(/binaries/src/\${pkg}_*.tar.gz)
70 | if [[ -f \"\${src_tar[0]}\" ]]; then
71 | echo \"⚙ Building \$pkg from \$(basename \"\${src_tar[0]}\")...\"
72 | R CMD INSTALL --build \"\${src_tar[0]}\"
73 | else
74 | echo \"⚠ Source for \$pkg not found, skipping (likely found in base R install).\"
75 | #exit 1 # <— stop the entire script here
76 | continue
77 | fi
78 | done
79 |
80 | # Build GitHub package binary
81 | cd /binaries/src
82 |
83 | # clone *and* sparsify GitHub repo in one go
84 | git clone \
85 | --depth 1 \
86 | --filter=blob:none \
87 | --sparse \
88 | https://github.com/cjerzak/causalimages-software.git \
89 | /binaries/src/repo
90 |
91 | cd /binaries/src/repo
92 |
93 | # then tell Git which folder you actually want
94 | git sparse-checkout set causalimages
95 |
96 | # now build
97 | R CMD INSTALL --build ./causalimages
98 | mv causalimages_*.tar.gz /binaries/bin/causalimages.tar.gz
99 |
100 | echo \"✅✅ Built binary tarballs ✅✅\"
101 |
102 | echo \"🔧 Building conda backend...\"
103 | export RETICULATE_MINICONDA_PATH=/opt/miniconda
104 | Rscript -e \"
105 | library(causalimages);
106 | causalimages::BuildBackend(conda_env='CausalImagesEnv',
107 | conda='/opt/miniconda/bin/conda')
108 | \"
109 |
110 | echo \"📦 Zipping conda env...\"
111 | zip -r /binaries/CausalImagesEnv.zip /opt/miniconda/envs/CausalImagesEnv
112 |
113 |
114 | echo \"✅✅✅ Built binary tarballs ✅✅✅\"
115 | "
116 |
117 |
118 |
--------------------------------------------------------------------------------
/misc/docker/setup/DockerfileNoCompile:
--------------------------------------------------------------------------------
1 | docker run --platform=linux/amd64 --rm -v "$HOME/Documents/causalimages-software/misc/docker/binaries:/binaries" rocker/r-ver:4.4.0 bash -exc "
2 | set -euo pipefail
3 |
4 | echo \"🔧 Installing system development libraries...\"
5 | apt-get update -qq && \
6 | DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
7 | build-essential cmake pkg-config git libcurl4-openssl-dev libssl-dev libxml2-dev \
8 | libgdal-dev libgeos-dev libproj-dev libudunits2-dev ca-certificates && \
9 | apt-get clean && \
10 | rm -rf /var/lib/apt/lists/*
11 |
12 | # Build GitHub package binary
13 | cd /binaries/src
14 |
15 | # Check and remove existing repo directory if it exists (note: the cd above assumes /binaries/src already exists on the mounted volume)
16 | if [ -d \"/binaries/src/repo\" ]; then
17 | rm -rf /binaries/src/repo
18 | fi
19 |
20 | # clone *and* sparsify GitHub repo in one go
21 | git clone \
22 | --depth 1 \
23 | --filter=blob:none \
24 | --sparse \
25 | https://github.com/cjerzak/causalimages-software.git \
26 | /binaries/src/repo
27 |
28 | cd /binaries/src/repo
29 |
30 | # then tell Git which folder you actually want
31 | git sparse-checkout set causalimages
32 |
33 | # now build
34 | #R CMD build ./causalimages --no-manual --no-build-vignettes
35 | R CMD INSTALL --build ./causalimages
36 | mv causalimages_*.tar.gz /binaries/bin/causalimages.tar.gz
37 |
38 | echo \"✅✅✅ Built binary tarballs ✅✅✅\"
39 | "
40 |
41 |
--------------------------------------------------------------------------------
/misc/docker/setup/FindDependencies.sh:
--------------------------------------------------------------------------------
1 |
2 | docker run --platform=linux/amd64 --rm \
3 | -v "$(pwd)/binaries:/binaries" \
4 | rocker/r-ver:4.4.0 bash -exc "
5 | set -euo pipefail
6 |
7 | echo \"🔧 Installing system development libraries...\"
8 | apt-get update -qq && \
9 | DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
10 | build-essential libcurl4-openssl-dev libssl-dev libxml2-dev \
11 | libgdal-dev libgeos-dev libproj-dev libudunits2-dev ca-certificates \
12 | && apt-get clean && rm -rf /var/lib/apt/lists/*
13 |
14 | echo \"🛠️ Installing remotes from CRAN...\"
15 | Rscript -e \"options(repos = c(CRAN = 'https://cloud.r-project.org')); \
16 | install.packages('remotes', dependencies = TRUE)\"
17 |
18 | echo \"📂 Preparing build directory...\"
19 | mkdir -p /binaries/bin
20 | cd /binaries/bin
21 |
22 | echo \"📦 Building causalimages from GitHub...\"
23 | Rscript -e \"remotes::install_github('cjerzak/causalimages-software', \
24 | subdir='causalimages', \
25 | dependencies=FALSE)\"
26 |
27 | echo \"⚙️ Building causalimages binary...\"
28 | R CMD INSTALL --build causalimages_*.tar.gz # NOTE(review): install_github does not leave a source tarball in cwd — confirm how causalimages_*.tar.gz appears here
29 |
30 | echo \"✅ Built binary tarballs:\"
31 | ls -1 *.tar.gz
32 | "
33 |
--------------------------------------------------------------------------------
/misc/docker/setup/GenEnv.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # build_and_package.sh
3 | # Run this from the same directory as your Dockerfile.
4 |
5 | set -euo pipefail
6 |
7 | IMAGE_NAME="causalimages-env"
8 | TAR_NAME="${IMAGE_NAME}.tar"
9 | ZIP_NAME="${IMAGE_NAME}.zip"
10 |
11 | echo "1) Building Docker image..."
12 | docker build -t ${IMAGE_NAME} .
13 |
14 | echo "2) Saving image to tarball..."
15 | docker save ${IMAGE_NAME} -o ${TAR_NAME}
16 |
17 | echo "3) Zipping up Dockerfile + tarball..."
18 | zip -r ${ZIP_NAME} Dockerfile ${TAR_NAME}
19 |
20 | echo
21 | echo "Done! Your package is: ${ZIP_NAME}"
22 | echo "Push that ZIP to GitHub, then Codex can unzip and run:"
23 | echo " docker load -i ${TAR_NAME}"
24 | echo " docker run -it ${IMAGE_NAME} bash"
25 |
--------------------------------------------------------------------------------
/misc/docker/setup/GetRDependencyOrder.R:
--------------------------------------------------------------------------------
1 | # Script: GetRDependencyOrder.R
2 | # Purpose: - Get all dependencies for set of base packages
3 | # - Get safe dependency order for creation of compiled R packages
4 | {
5 | # Your initial target packages
6 | pkgs <- c("remotes", "tensorflow", "reticulate", "geosphere", "raster",
7 | "rrapply", "glmnet", "sf", "data.table", "pROC")
8 |
9 | # Fetch the CRAN package database
10 | cran_db <- available.packages(repos = "https://cloud.r-project.org")
11 |
12 | # 1. Get the full recursive set of Depends + Imports
13 | recursive_deps <- tools::package_dependencies(
14 | pkgs,
15 | db = cran_db,
16 | which = c("Depends", "Imports"),
17 | recursive = TRUE
18 | )
19 | all_pkgs <- unique(c(pkgs, unlist(recursive_deps)))
20 |
21 | # Identify base/recommended packages to drop
22 | base_pkgs <- rownames(installed.packages(priority = c("base", "recommended")))
23 |
24 | # 2. Build a direct-dependency map for every package
25 | direct_deps <- tools::package_dependencies(
26 | all_pkgs,
27 | db = cran_db,
28 | which = c("Depends", "Imports"),
29 | recursive = FALSE
30 | )
31 | # Remove any base packages from each dependency vector
32 | direct_deps <- lapply(direct_deps, setdiff, y = base_pkgs)
33 |
34 | # 3. Topological sort via DFS
35 | ordered <- character(0) # will hold the final install order
36 | visited <- character(0) # permanently marked nodes
37 | visiting <- character(0) # temporary marks for cycle detection
38 |
39 | visit <- function(pkg) {
40 | if (pkg %in% visited) return()
41 | if (pkg %in% visiting) {
42 | stop("Circular dependency detected involving: ", pkg)
43 | }
44 | visiting <<- c(visiting, pkg)
45 | for (dep in direct_deps[[pkg]]) {
46 | visit(dep)
47 | }
48 | visiting <<- setdiff(visiting, pkg)
49 | visited <<- c(visited, pkg)
50 | ordered <<- c(ordered, pkg)
51 | }
52 |
53 | # Run DFS on all packages
54 | for (p in all_pkgs) {
55 | visit(p)
56 | }
57 |
58 | # Filter to just the ones you want (if you only want your original pkgs + all their deps)
59 | dependency_safe_order <- intersect(ordered, all_pkgs)
60 | dependency_safe_order <- dependency_safe_order[!dependency_safe_order %in%
61 | rownames(installed.packages(priority = c("base")))]
62 |
63 | # manual addins
64 | dependency_safe_order <- c(
65 | dependency_safe_order[1:which(dependency_safe_order=="Rcpp")],
66 | "RcppEigen",
67 | dependency_safe_order[(which(dependency_safe_order=="Rcpp")+1):length(dependency_safe_order)])
68 |
69 | # Inspect
70 | cat("Install in this order:\n")
71 | cat(paste(dependency_safe_order, collapse = " "), "\n")
72 | # must add in: RcppEigen
73 | }
74 |
75 |
76 |
77 |
78 |
79 |
--------------------------------------------------------------------------------
/misc/notes/MaintainerNotes.txt:
--------------------------------------------------------------------------------
1 | # Maintainer notes for GPU support
2 | #
3 | # These notes summarize our experience configuring JAX with GPU acceleration.
4 | # The package relies heavily on JAX and TensorFlow, so version compatibility is
5 | # important.
6 | # Use Python 3.10 or above for jax-metal. Some JAX operations (e.g.
7 | # `jax$nn$softplus` or random sampling) can segfault when used with float16.
8 | # Sampling parameters with `seeds` (not `vseeds`) and wrapping functions with
9 | # `tf$function(jit_compile = TRUE)` works with jax-metal.
10 |
11 |
12 | # Conda environment setup on M1+ Mac:
13 | conda create -n CausalImagesEnv python==3.11
14 | conda activate CausalImagesEnv
15 | python3 -m pip install tensorflow tensorflow-metal optax equinox jmp tensorflow_probability
16 | python3 -m pip install jax-metal
17 |
18 | # After installation, verify that JAX sees your GPU by running in Python:
19 | # >>> import jax; jax.devices()
20 |
--------------------------------------------------------------------------------
/other/DataverseReadme_confounding.md:
--------------------------------------------------------------------------------
1 | Replication Data for: Integrating Earth Observation Data into Causal Inference: Challenges and Opportunities
2 |
3 | Connor T. Jerzak, Fredrik Johansson, Adel Daoud. Integrating Earth Observation Data into Causal Inference: Challenges and Opportunities. ArXiv Preprint, 2023.
4 |
5 | YandW_mat.csv contains individual-level observational data. In the dataset, LONGITUDE and LATITUDE refer to the approximate geo-referenced long/lat of observational units. Outcomes are stored in Yobs. The treatment variable is stored in Wobs. See the tutorial for more information. The unique image key for each observational unit is saved in UNIQUE_ID.
6 |
7 | Geo-referenced satellite images are saved in
8 | "./Nigeria2000_processed/%s_BAND%s.csv", where the first "%s" refers to the image key associated with each observation (saved in UNIQUE_ID in YandW_mat.csv) and BAND%s refers to one of 3 bands in the satellite imagery.
9 |
10 | For more information, see: https://github.com/cjerzak/causalimages-software/
11 |
--------------------------------------------------------------------------------
/other/DataverseReadme_heterogeneity.md:
--------------------------------------------------------------------------------
1 | Replication Data for: Image-based Treatment Effect Heterogeneity
2 |
3 | Connor Thomas Jerzak, Fredrik Daniel Johansson, Adel Daoud. Proceedings of the Second Conference on Causal Learning and Reasoning, PMLR 213:531-552, 2023.
4 |
5 | UgandaDataProcessed.csv contains individual-level data from the YOP experiment. In the dataset, geo_long and geo_lat refer to the approximate geo-referenced long/lat of experimental units. The variable, geo_long_lat_key, refers to the image key associated with each location. Experimental outcomes are stored in Yobs. Treatment variable is stored in Wobs. See the tutorial for more information.
6 |
7 | UgandaGeoKeyMat.csv contains information on keys linking to satellite images for all of Uganda for the transportability analysis.
8 |
9 | Geo-referenced satellite images are saved in "./Uganda2000_processed/GeoKey%s_BAND%s.csv", where GeoKey%s denotes the image key associated with each observation and BAND%s refers to one of 3 bands in the satellite imagery.
10 |
11 | For more information, see: https://github.com/cjerzak/causalimages-software/blob/main/tutorials/AnalyzeImageHeterogeneity_FullTutorial.R
12 |
--------------------------------------------------------------------------------
/other/DataverseTutorial_confounding.R:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env Rscript
2 |
3 | ################################
4 | # For an up-to-date version of this tutorial, see
5 | # https://github.com/cjerzak/causalimages-software/blob/main/tutorials/AnalyzeImageConfounding_Tutorial_Base.R
6 | ################################
7 |
8 | # set new wd
9 |
--------------------------------------------------------------------------------
/other/DataverseTutorial_heterogeneity.R:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env Rscript
2 |
3 | ################################
4 | # For an up-to-date version of this tutorial, see
5 | # https://github.com/cjerzak/causalimages-software/blob/main/tutorials/AnalyzeImageHeterogeneity_FullTutorial.R
6 | ################################
7 |
8 | # set new wd
9 | setwd(sprintf('%s/Public Replication Data, YOP Experiment/',
10 | gsub(download_folder,pattern="\\.zip",replace="")))
11 |
12 | # see directory contents
13 | list.files()
14 |
15 | # images saved here
16 | list.files( "./Uganda2000_processed" )
17 |
18 | # individual-level data
19 | UgandaDataProcessed <- read.csv( "./UgandaDataProcessed.csv" )
20 |
21 | # unit-level covariates (many covariates are subject to missingness!)
22 | dim( UgandaDataProcessed )
23 | table( UgandaDataProcessed$female )
24 | table( UgandaDataProcessed$age )
25 |
26 | # approximate longitude + latitude for units
27 | UgandaDataProcessed$geo_long
28 | UgandaDataProcessed$geo_lat
29 |
30 | # image keys of units (use for referencing satellite images)
31 | UgandaDataProcessed$geo_long_lat_key
32 |
33 | # an experimental outcome
34 | UgandaDataProcessed$Yobs
35 |
36 | # treatment variable
37 | UgandaDataProcessed$Wobs
38 |
39 | # information on keys linking to satellite images for all of Uganda
40 | # (not just experimental context, use for constructing transportability maps)
41 | UgandaGeoKeyMat <- read.csv( "./UgandaGeoKeyMat.csv" )
42 | tail( UgandaGeoKeyMat )
43 |
44 | # Geo-referenced satellite images are saved in
45 | # "./Uganda2000_processed/GeoKey%s_BAND%s.csv",
46 | # where GeoKey%s denotes the image key associated with each observation and
47 | # BAND%s refers to one of 3 bands in the satellite imagery.
48 | # See https://github.com/cjerzak/causalimages-software/blob/main/tutorials/AnalyzeImageHeterogeneity_FullTutorial.R
49 | # for up-to-date usage information.
50 |
--------------------------------------------------------------------------------
/other/PackageRunChecks.R:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env Rscript
2 |
3 | ################################
4 | # master test suite
5 | # note: this code is primarily for the use of the development team
6 | ################################
7 |
8 | # in process
9 |
--------------------------------------------------------------------------------
/tutorials/AnalyzeImageConfounding_Tutorial_Advanced.R:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env Rscript
2 | {
3 | ################################
4 | # Image confounding tutorial using causalimages
5 | ################################
6 |
7 | # clean workspace
8 | rm(list=ls()); options(error = NULL)
9 |
10 | # setup environment
11 | if(Sys.getenv("RSTUDIO_USER_IDENTITY") == "cjerzak"){
12 | setwd("~/Downloads/")
13 | }
14 | if(Sys.getenv("RSTUDIO_USER_IDENTITY") != "cjerzak"){
15 | setwd("./")
16 | # or set directory as desired
17 | }
18 |
19 | # remote install latest version of the package if needed
20 | # devtools::install_github(repo = "cjerzak/causalimages-software/causalimages")
21 |
22 | # local install for development team
23 | # install.packages("~/Documents/causalimages-software/causalimages",repos = NULL, type = "source",force = F)
24 |
25 | # build the backend if you haven't already (run this only once upon (re)installing causalimages!)
26 | # causalimages::BuildBackend()
27 |
28 | # load in package
29 | library( causalimages )
30 |
31 | # resave TfRecords?
32 | reSaveTFRecord <- FALSE
33 |
34 | # load in tutorial data
35 | data( CausalImagesTutorialData )
36 |
37 | # mean imputation for toy example in this tutorial
38 | X <- apply(X[,-1],2,function(zer){
39 | zer[is.na(zer)] <- mean( zer,na.rm = T ); return( zer )
40 | })
41 |
42 | # select observation subset to make the tutorial quick
43 | set.seed(4321L);take_indices <- unlist( tapply(1:length(obsW),obsW,function(zer){sample(zer, 100)}) )
44 |
45 | # example acquire image function (loading from memory)
46 | # in general, you'll want to write a function that returns images
47 | # saved on disk and associated with keys
48 | acquireImageFxn <- function(keys){
49 | # here, the function input keys
50 | # refers to the unit-associated image keys
51 | # we also tweak the image dimensions for testing purposes
52 | #m_ <- FullImageArray[match(keys, KeysOfImages),c(1:35,1:35),c(1:35,1:35),1:2] # test with two channels
53 | m_ <- FullImageArray[match(keys, KeysOfImages),c(1:35,1:35),c(1:35,1:35),] # test with three channels
54 | #m_ <- FullImageArray[match(keys, KeysOfImages),c(1:35,1:35),c(1:35,1:35),c(1:3,1:2)] # test with five channels
55 |
56 | # if keys == 1, add the batch dimension so output dims are always consistent
57 | # (here in image case, dims are batch by height by width by channel)
58 | if(length(keys) == 1){ m_ <- array(m_,dim = c(1L,dim(m_)[1],dim(m_)[2],dim(m_)[3])) }
59 |
60 | return( m_ )
61 | }
62 |
63 | # perform causal inference with image and tabular confounding
64 | if(T == T){
65 | # write tf record
66 | TFRecordName_im <- "./TutorialData_im.tfrecord"
67 | if( reSaveTFRecord ){
68 | causalimages::WriteTfRecord(
69 | file = TFRecordName_im,
70 | uniqueImageKeys = unique(KeysOfObservations[ take_indices ]),
71 | acquireImageFxn = acquireImageFxn
72 | )
73 | }
74 |
75 | # perform causal inference with image-based and tabular confounding
76 | if(T == T){
77 | for(imageModelClass in (c("VisionTransformer","CNN"))){
78 | for(optimizeImageRep in c(T,F)){
79 | print(sprintf("Image confounding analysis & optimizeImageRep: %s & imageModelClass: %s",optimizeImageRep, imageModelClass))
80 | ImageConfoundingAnalysis <- causalimages::AnalyzeImageConfounding(
81 | obsW = obsW[ take_indices ],
82 | obsY = obsY[ take_indices ],
83 | X = X[ take_indices,apply(X[ take_indices,],2,sd)>0],
84 | long = LongLat$geo_long[ take_indices ], # optional argument
85 | lat = LongLat$geo_lat[ take_indices ], # optional argument
86 | imageKeysOfUnits = KeysOfObservations[ take_indices ],
87 | file = TFRecordName_im,
88 |
89 | batchSize = 16L,
90 | nBoot = 5L,
91 | optimizeImageRep = optimizeImageRep,
92 | imageModelClass = imageModelClass,
93 | nDepth_ImageRep = ifelse(optimizeImageRep, yes = 1L, no = 1L),
94 | nWidth_ImageRep = as.integer(2L^6),
95 | learningRateMax = 0.001, nSGD = 300L, #
96 | dropoutRate = NULL, # 0.1,
97 | plotBands = c(1,2,3),
98 | plotResults = T, figuresTag = "ConfoundingImTutorial",
99 | figuresPath = "./")
100 | try(dev.off(), T)
101 | #ImageConfoundingAnalysis$ModelEvaluationMetrics
102 | }
103 | }
104 | # ATE estimate (image confounder adjusted)
105 | ImageConfoundingAnalysis$tauHat_propensityHajek
106 |
107 | # ATE se estimate (image confounder adjusted)
108 | ImageConfoundingAnalysis$tauHat_propensityHajek_se
109 |
110 | # some out-of-sample evaluation metrics
111 | ImageConfoundingAnalysis$ModelEvaluationMetrics
112 | }
113 | }
114 |
115 | # perform causal inference with image sequence and tabular confounding
116 | gc()
117 | if(T == T){
118 | acquireVideoRep <- function(keys) {
119 | # Note: this is a toy function generating image representations
120 | # that simply reuse a single temporal slice. In practice, we will
121 | # want to read in images of different time periods.
122 |
123 | # Get image data as an array from disk
124 | tmp <- acquireImageFxn(keys)
125 |
126 | # Expand dimensions: we create a new dimension at the start
127 | tmp <- array(tmp, dim = c(1, dim(tmp)))
128 |
129 | # Transpose dimensions to get the target order
130 | tmp <- aperm(tmp, c(2, 1, 3, 4, 5))
131 |
132 | # Swap image dimensions to see variability across time
133 | tmp_ <- aperm(tmp, c(1, 2, 4, 3, 5))
134 |
135 | # Concatenate along the second axis
136 | tmp <- abind::abind(tmp, tmp, tmp_, tmp_, along = 2)
137 |
138 | return(tmp)
139 | }
140 |
141 | # sanity check dimensions
142 | print(dim(acquireVideoRep(unique(KeysOfObservations[ take_indices ])[1:2])))
143 |
144 | # write tf record
145 | TFRecordName_imSeq <- "./TutorialData_imSeq.tfrecord"
146 | if( reSaveTFRecord ){
147 | causalimages::WriteTfRecord(
148 | file = TFRecordName_imSeq,
149 | uniqueImageKeys = unique(KeysOfObservations[ take_indices ]),
150 | acquireImageFxn = acquireVideoRep,
151 | writeVideo = TRUE)
152 | }
153 |
154 | for(imageModelClass in c("VisionTransformer","CNN")){
155 | for(optimizeImageRep in c(T, F)){
156 | print(sprintf("Image seq confounding analysis & optimizeImageRep: %s & imageModelClass: %s",optimizeImageRep, imageModelClass))
157 | ImageSeqConfoundingAnalysis <- causalimages::AnalyzeImageConfounding(
158 | obsW = obsW[ take_indices ],
159 | obsY = obsY[ take_indices ],
160 | X = X[ take_indices,apply(X[ take_indices,],2,sd)>0],
161 | long = LongLat$geo_long[ take_indices ],
162 | lat = LongLat$geo_lat[ take_indices ],
163 | file = TFRecordName_imSeq, dataType = "video",
164 | imageKeysOfUnits = KeysOfObservations[ take_indices ],
165 |
166 | # model specifics
167 | batchSize = 16L,
168 | optimizeImageRep = optimizeImageRep,
169 | imageModelClass = imageModelClass,
170 | nDepth_ImageRep = ifelse(optimizeImageRep, yes = 1L, no = 1L),
171 | nWidth_ImageRep = as.integer(2L^7),
172 | learningRateMax = 0.001, nSGD = 300L, #
173 | nBoot = 5L,
174 | plotBands = c(1,2,3),
175 | plotResults = T, figuresTag = "ConfoundingImSeqTutorial",
176 | figuresPath = "./") # figures saved here
177 | try(dev.off(), T)
178 | }
179 | }
180 |
181 | # ATE estimate (image confounder adjusted)
182 | ImageSeqConfoundingAnalysis$tauHat_propensityHajek
183 |
184 | # ATE se estimate (image seq confounder adjusted)
185 | ImageSeqConfoundingAnalysis$tauHat_propensityHajek_se
186 |
187 | # some out-of-sample evaluation metrics
188 | ImageSeqConfoundingAnalysis$ModelEvaluationMetrics
189 | print("Done with confounding tutorial!")
190 | }
191 | }
192 |
--------------------------------------------------------------------------------
/tutorials/AnalyzeImageConfounding_Tutorial_Base.R:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env Rscript
2 | {
3 | ################################
4 | # Image confounding tutorial using causalimages
5 | ################################
6 |
7 | # clean workspace
8 | rm(list=ls()); options(error = NULL)
9 |
10 | # setup environment
11 | if(Sys.getenv("RSTUDIO_USER_IDENTITY") == "cjerzak"){
12 | setwd("~/Downloads/")
13 | }
14 | if(Sys.getenv("RSTUDIO_USER_IDENTITY") != "cjerzak"){
15 | setwd("./")
16 | # or set directory as desired
17 | }
18 |
19 | # remote install latest version of the package if needed
20 | # devtools::install_github(repo = "cjerzak/causalimages-software/causalimages")
21 |
22 | # local install for development team
23 | # install.packages("~/Documents/causalimages-software/causalimages",repos = NULL, type = "source",force = F)
24 |
25 | # build the backend if you haven't already (run this only once upon (re)installing causalimages!)
26 | # causalimages::BuildBackend()
27 |
28 | # load in package
29 | library( causalimages )
30 |
31 | # resave TfRecords?
32 | reSaveTFRecord <- TRUE
33 |
34 | # load in tutorial data
35 | data( CausalImagesTutorialData )
36 |
37 | # mean imputation for toy example in this tutorial
38 | X <- apply(X[,-1],2,function(zer){
39 | zer[is.na(zer)] <- mean( zer,na.rm = T ); return( zer )
40 | })
41 |
42 | # select observation subset to make the tutorial quick
43 | set.seed(4321L);take_indices <-
44 | unlist( tapply(1:length(obsW),obsW,function(zer){sample(zer, 300)}) )
45 |
46 | # perform causal inference with image and tabular confounding
47 | {
48 | # example acquire image function (loading from memory)
49 | # in general, you'll want to write a function that returns images
50 | # saved on disk and associated with keys
51 | acquireImageFxn <- function(keys){
52 | # here, the function input keys
53 | # refers to the unit-associated image keys
54 | # we also tweak the image dimensions for testing purposes
55 | #m_ <- FullImageArray[match(keys, KeysOfImages),c(1:35,1:35),c(1:35,1:35),1:2] # test with two channels
56 | m_ <- FullImageArray[match(keys, KeysOfImages),c(1:35,1:35),c(1:35,1:35),] # test with three channels
57 | #m_ <- FullImageArray[match(keys, KeysOfImages),c(1:35,1:35),c(1:35,1:35),c(1:3,1:2)] # test with five channels
58 |
59 | # if keys == 1, add the batch dimension so output dims are always consistent
60 | # (here in image case, dims are batch by height by width by channel)
61 | if(length(keys) == 1){ m_ <- array(m_,dim = c(1L,dim(m_)[1],dim(m_)[2],dim(m_)[3])) }
62 |
63 | return( m_ )
64 | }
65 |
66 | # look at one of the images
67 | causalimages::image2( FullImageArray[1,,,1] )
68 |
69 | # write tf record
70 | # AnalyzeImageConfounding can efficiently stream batched image data from disk
71 | # (avoiding repeated in-memory loads and speeding up I/O during model training)
72 | TFRecordName_im <- "./TutorialData_im.tfrecord"
73 | if( reSaveTFRecord ){
74 | causalimages::WriteTfRecord(
75 | file = TFRecordName_im,
76 | uniqueImageKeys = unique(KeysOfObservations[ take_indices ]),
77 | acquireImageFxn = acquireImageFxn
78 | )
79 | }
80 |
81 | # perform causal inference with image-based and tabular confounding
82 | imageModelClass <- "VisionTransformer"
83 | optimizeImageRep <- TRUE # train the model to predict treatment, for use in IPW
84 | print(sprintf("Image confounding analysis & optimizeImageRep: %s & imageModelClass: %s",optimizeImageRep, imageModelClass))
85 | ImageConfoundingAnalysis <- causalimages::AnalyzeImageConfounding(
86 | # input data
87 | obsW = obsW[ take_indices ],
88 | obsY = obsY[ take_indices ],
89 | X = X[ take_indices,apply(X[ take_indices,],2,sd)>0],
90 | long = LongLat$geo_long[ take_indices ], # optional argument
91 | lat = LongLat$geo_lat[ take_indices ], # optional argument
92 | imageKeysOfUnits = KeysOfObservations[ take_indices ],
93 | file = TFRecordName_im,
94 |
95 | # modeling parameters
96 | batchSize = 16L,
97 | nBoot = 5L,
98 | optimizeImageRep = optimizeImageRep,
99 | imageModelClass = imageModelClass,
100 | nDepth_ImageRep = ifelse(optimizeImageRep, yes = 1L, no = 1L),
101 | nWidth_ImageRep = as.integer(2L^6),
102 | learningRateMax = 0.001, nSGD = 300L, #
103 | dropoutRate = NULL, # 0.1,
104 | plotBands = c(1,2,3),
105 | plotResults = T, figuresTag = "ConfoundingImTutorial",
106 | figuresPath = "./")
107 | try(dev.off(), T)
108 |
109 | # Analyze in/out sample metrics
110 | ImageConfoundingAnalysis$ModelEvaluationMetrics
111 |
112 | # ATE estimate (image confounder adjusted)
113 | ImageConfoundingAnalysis$tauHat_propensityHajek
114 |
115 | # ATE se estimate (image confounder adjusted)
116 | ImageConfoundingAnalysis$tauHat_propensityHajek_se
117 |
118 | # some out-of-sample evaluation metrics
119 | ImageConfoundingAnalysis$ModelEvaluationMetrics
120 | }
121 | }
122 |
--------------------------------------------------------------------------------
/tutorials/AnalyzeImageConfounding_Tutorial_EmbeddingsOnly.R:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env Rscript
2 | {
3 | ################################
4 | # Image tutorial using embeddings
5 | ################################
6 |
7 | # clean workspace
8 | rm(list=ls()); options(error = NULL)
9 |
10 | # setup environment
11 | setwd("~/Downloads/")
12 |
13 | # fetch image embeddings (requires an Internet connection)
14 | m_embeddings <- read.csv("https://huggingface.co/datasets/cjerzak/PCI_TutorialMaterial/resolve/main/nDepthIS1_analysisTypeISheterogeneity_imageModelClassISVisionTransformer_optimizeImageRepISclip-rsicd_MaxImageDimsIS64_dataTypeISimage_monte_iIS1_applicationISUganda_perturbCenterISFALSE.csv")
15 |
16 | # load data
17 | library(causalimages)
18 | data( CausalImagesTutorialData )
19 |
20 | # image embedding dimensions - one image for each set of geo-located units
21 | # geo-location done by village name
22 | # pre-treatment image covariates via embeddings
23 | # obtained from an EO fine-tuned CLIP model
24 | # https://huggingface.co/flax-community/clip-rsicd-v2
25 | dim(m_embeddings)
26 |
27 | # correlation matrix
28 | cor(m_embeddings)
29 |
30 | # analyze via pca
31 | pca_analysis <- predict(prcomp(m_embeddings,scale = T, center = T))
32 |
33 | # first two principal components
34 | plot(pca_analysis[,1:2])
35 |
36 | # treatment indicator (recipient of YOP cash transfer)
37 | obsW
38 |
39 | # outcome - measure of human capital post intervention
40 | obsY
41 |
42 | # run double ML for image deconfounding
43 | # install.packages(c("DoubleML","mlr3","ranger"))
44 | library(DoubleML) # double/debiased ML framework
45 | library(mlr3) # core mlr3 infrastructure
46 | library(mlr3learners) # access a wide range of learners
47 | library(ranger)
48 |
49 | # Combine outcome, treatment, and image embeddings into a single data.frame
50 | df_dml <- data.frame(
51 | Y = obsY, # outcome vector
52 | W = obsW, # treatment indicator
53 | m_embeddings # precomputed image embeddings (one column per embedding dimension)
54 | )
55 |
56 | # Create a DoubleMLData object
57 | dml_data <- DoubleMLData$new(
58 | data = df_dml,
59 | y_col = "Y",
60 | d_cols = "W",
61 | x_cols = colnames(m_embeddings)
62 | )
63 |
64 | # Specify learners for the nuisance functions
65 | learner_g <- lrn("regr.ranger") # regression learner for E[Y|X,W]
66 | learner_m <- lrn("classif.ranger", predict_type = "prob") # classification learner for P[W=1|X]
67 |
68 | # Instantiate the partially linear regression (PLR) DML model
69 | dml_plr <- DoubleMLPLR$new(
70 | dml_data,
71 | ml_g = learner_g,
72 | ml_m = learner_m,
73 | n_folds = 5 # number of folds for cross-fitting
74 | )
75 |
76 | # fit model
77 | dml_plr$fit()
78 |
79 | # extract results
80 | dml_plr$summary() # prints estimated ATE and standard error
81 |
82 | }
83 |
--------------------------------------------------------------------------------
/tutorials/AnalyzeImageConfounding_Tutorial_Simulation.R:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env Rscript
2 | {
3 | ################################
4 | # Simulated causal data with images using causalimages
5 | ################################
6 |
7 | # clean workspace
8 | rm(list=ls()); options(error = NULL)
9 |
10 | # setup environment
11 | # setwd as needed
12 |
13 | # Install causalimages if not installed
14 | # devtools::install_github(repo = "cjerzak/causalimages-software/causalimages")
15 |
16 | # load package
17 | library(causalimages)
18 |
19 | # Build backend if not already done
20 | # causalimages::BuildBackend()
21 |
22 | # Simulation parameters
23 | n <- 1000 # number of units
24 | tau <- 2.0 # true ATE (constant treatment effect)
25 | image_dim <- 64L # image size
26 |
27 | # Generate synthetic data
28 | set.seed(12321)
29 |
30 | # Latent confounder C that generates images
31 | C <- rnorm(n)
32 |
33 | # Tabular confounder X
34 | X <- matrix(rnorm(n * 2), ncol = 2)
35 | colnames(X) <- c("X1", "X2")
36 |
37 | # Note: we have a dual confounder setup here (tabular + image confounders)
38 |
39 | # Treatment assignment: depends on C and X
40 | logit_prob <- C + X[,1] + X[,2]
41 | prob_W <- 1 / (1 + exp(-logit_prob))
42 | obsW <- rbinom(n,
43 | size = 1,
44 | prob = prob_W)
45 |
46 | # Outcome: depends on both confounders and treatment and noise
47 | obsY <- tau * obsW + C + X[,1] + X[,2] + rnorm(n)
48 |
49 | # Generate synthetic images based on C
50 | # For simplicity, create images where intensity depends on C
51 | # We'll make a 64x64x3 array, with red channel proportional to C
52 | KeysOfImages <- paste0("img_", 1:n) # unique keys
53 | FullImageArray <- array(0, dim = c(n, image_dim, image_dim, 3))
54 | for(i in 1:n) {
55 | base_intensity <- (C[i] - min(C)) / (max(C) - min(C)) # normalize to [0,1]
56 | FullImageArray[i,,,1] <- base_intensity # red channel
57 | FullImageArray[i,,,2] <- runif(1) # green random
58 | FullImageArray[i,,,3] <- runif(1) # blue random
59 | # Add a non-causal pattern to image, e.g., a gradient
60 | for(row in 1:image_dim) {
61 | FullImageArray[i, row,,] <- FullImageArray[i, row,,] * log(1+row / image_dim)^runif(n=1,min=0,max=1/i)
62 | }
63 | }
64 |
65 | # Keys of observations
66 | KeysOfObservations <- KeysOfImages # one-to-one for simplicity
67 |
68 | # Define acquireImageFxn
69 | acquireImageFxn <- function(keys) {
70 | m_ <- FullImageArray[match(keys, KeysOfImages),,,]
71 | if(length(keys) == 1) {
72 | m_ <- array(m_, dim = c(1L, dim(m_)[1], dim(m_)[2], dim(m_)[3]))
73 | }
74 | return(m_)
75 | }
76 |
77 | # run once
78 | # causalimages::BuildBackend()
79 |
80 | # Look at one image
81 | causalimages::image2(FullImageArray[sample(1:1000,1),,,1])
82 |
83 | # Write TFRecord (optional, but as in tutorial)
84 | reSaveTFRecord <- TRUE
85 | TFRecordName_im <- "~/Downloads/SimulatedData_im.tfrecord"
86 | if(reSaveTFRecord) {
87 | causalimages::WriteTfRecord(
88 | file = TFRecordName_im,
89 | uniqueImageKeys = unique(KeysOfObservations),
90 | acquireImageFxn = acquireImageFxn
91 | )
92 | }
93 |
94 | # Perform causal inference with image and tabular confounding
95 | imageModelClass <- "VisionTransformer"
96 | optimizeImageRep <- TRUE
97 | print(sprintf("Image confounding analysis & optimizeImageRep: %s & imageModelClass: %s", optimizeImageRep, imageModelClass))
98 |
99 | ImageConfoundingAnalysis <- causalimages::AnalyzeImageConfounding(
100 | # input data
101 | obsW = obsW,
102 | obsY = obsY,
103 | X = X[, apply(X, 2, sd) > 0],
104 | imageKeysOfUnits = KeysOfObservations,
105 | file = TFRecordName_im,
106 |
107 | # modeling parameters
108 | batchSize = 16L,
109 | nBoot = 5L,
110 | optimizeImageRep = TRUE,
111 | imageModelClass = imageModelClass,
112 | nDepth_ImageRep = 4L,
113 | nWidth_ImageRep = as.integer(2^8),
114 | learningRateMax = 0.0001, nSGD = 300L,
115 | dropoutRate = 0.1,
116 | plotBands = c(1,2,3),
117 | plotResults = TRUE, figuresTag = "SimConfoundingIm",
118 | figuresPath = "./"
119 | )
120 | try(dev.off(), TRUE)
121 |
122 | # Output results
123 | print("True ATE:")
124 | print(tau)
125 |
126 | print("Estimated ATE (no confounder adjustment):")
127 | print(ImageConfoundingAnalysis$tauHat_diffInMeans)
128 |
129 | print("Estimated ATE (image+tabular confounder adjusted):")
130 | print(ImageConfoundingAnalysis$tauHat_propensityHajek)
131 |
132 |
133 | print("Estimated SE:")
134 | print(ImageConfoundingAnalysis$tauHat_propensityHajek_se)
135 |
136 | print("Model Evaluation Metrics:")
137 | print(ImageConfoundingAnalysis$ModelEvaluationMetrics)
138 |
139 | # Comparison
140 | bias_naive <- ImageConfoundingAnalysis$tauHat_diffInMeans - tau
141 | print("Bias (diff in means):")
142 | print(bias_naive) # around 2
143 |
144 | bias_adjusted <- ImageConfoundingAnalysis$tauHat_propensityHajek - tau
145 | print("Bias (image+tabular deconfounding):")
146 | print(bias_adjusted) # around 0.4
147 | }
148 |
--------------------------------------------------------------------------------
/tutorials/AnalyzeImageHeterogeneity_Tutorial.R:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env Rscript
2 | {
3 | ################################
4 | # Image heterogeneity tutorial using causalimages
5 | ################################
6 |
7 | # clean environment
8 | rm(list = ls()); options( error = NULL )
9 |
10 | # remote install latest version of the package
11 | # devtools::install_github(repo = "cjerzak/causalimages-software/causalimages")
12 |
13 | # local install for development team
14 | # install.packages("~/Documents/causalimages-software/causalimages",repos = NULL, type = "source",force = F)
15 |
16 | # build the backend if you haven't already:
17 | # causalimages::BuildBackend()
18 |
19 | # run code if downloading data for the first time
20 | download_folder <- "~/Downloads/UgandaAnalysis.zip"
21 | reSaveTfRecords <- T
22 | if( reDownloadRawData <- F ){
23 | # specify uganda data URL
24 | uganda_data_url <- "https://dl.dropboxusercontent.com/s/xy8xvva4i46di9d/Public%20Replication%20Data%2C%20YOP%20Experiment.zip?dl=0"
25 | download_folder <- "~/Downloads/UgandaAnalysis.zip"
26 |
27 | # download into new directory
28 | download.file( uganda_data_url, destfile = download_folder)
29 |
30 | # unzip and list files
31 | unzip(download_folder, exdir = "~/Downloads/UgandaAnalysis")
32 | }
33 |
34 | # load in package
35 | library( causalimages )
36 |
37 | # set new wd
38 | setwd(sprintf('%s/Public Replication Data, YOP Experiment/',
39 | gsub(download_folder,pattern="\\.zip",replace="")))
40 |
41 | # see directory contents
42 | list.files()
43 |
44 | # images saved here
45 | list.files( "./Uganda2000_processed" )
46 |
47 | # individual-level data
48 | UgandaDataProcessed <- read.csv( "./UgandaDataProcessed.csv" )
49 |
50 | # unit-level covariates (many covariates are subject to missingness!)
51 | dim( UgandaDataProcessed )
52 | table( UgandaDataProcessed$age )
53 |
54 | # approximate longitude + latitude for units
55 | head( cbind(UgandaDataProcessed$geo_long, UgandaDataProcessed$geo_lat) )
56 |
57 | # image keys of units (use for referencing satellite images)
58 | UgandaDataProcessed$geo_long_lat_key
59 |
60 | # an experimental outcome
61 | UgandaDataProcessed$Yobs
62 |
63 | # treatment variable
64 | UgandaDataProcessed$Wobs
65 |
66 | # information on keys linking to satellite images for all of Uganda
67 | # (not just experimental context, use for constructing transportability maps)
68 | UgandaGeoKeyMat <- read.csv( "./UgandaGeoKeyMat.csv" )
69 |
70 | # set outcome
71 | UgandaDataProcessed$Yobs <- UgandaDataProcessed$human_capital_index_e
72 |
73 | # drop observations with NAs in key variables
74 | # (you can also use a multiple imputation strategy)
75 | UgandaDataProcessed <- UgandaDataProcessed[!is.na(UgandaDataProcessed$Yobs) &
76 | !is.na(UgandaDataProcessed$Wobs) &
77 | !is.na(UgandaDataProcessed$geo_lat) , ]
78 |
79 | # sanity checks
80 | {
81 | # write a function that reads in images as saved and process them into an array
82 | NBANDS <- 3L
83 | imageHeight <- imageWidth <- 351L # pixel height/width
84 | acquireImageRep <- function( keys ){ # read per-key band CSVs from disk into one image array
85 | # keys <- unique(UgandaDataProcessed$geo_long_lat_key)[1:5]
86 | # initialize an array shell to hold image slices
87 | array_shell <- array(NA, dim = c(1L, imageHeight, imageWidth, NBANDS))
88 | # dim(array_shell)
89 | 
90 | # iterate over keys:
91 | # -- images are referenced to keys
92 | # -- keys are referenced to units (to allow for duplicate image uses)
93 | array_ <- sapply(keys, function(key_) {
94 | # iterate over all image bands (NBANDS = 3 for RGB images)
95 | for (band_ in 1:NBANDS) {
96 | # place the image in the correct place in the array
97 | array_shell[,,,band_] <-
98 | as.matrix(data.table::fread(
99 | input = sprintf("./Uganda2000_processed/GeoKey%s_BAND%s.csv", key_, band_), header = FALSE)[-1,]) # [-1,] drops the first row (presumably a header artifact — TODO confirm)
100 | }
101 | return(array_shell)
102 | }, simplify = "array") # simplify = "array" appends a trailing key dimension
103 | 
104 | # return the array in the format c(nBatch, imageWidth, imageHeight, nChannels)
105 | # ensure that the dimensions are correctly ordered for further processing
106 | if(length(keys) > 1){ array_ <- aperm(array_[1,,,,], c(4, 1, 2, 3) ) } # (H, W, bands, keys) -> (keys, H, W, bands)
107 | if(length(keys) == 1){
108 | array_ <- aperm(array_, c(1,5, 2, 3, 4)) # move the key dimension next to the length-1 batch dimension
109 | array_ <- array(array_, dim(array_)[-1]) # drop the leading length-1 dim so output is (1, H, W, bands)
110 | }
111 | return(array_)
112 | }
113 |
114 | # try out the function
115 | # note: some units are co-located in same area (hence, multiple observations per image key)
116 | ImageBatch <- acquireImageRep( UgandaDataProcessed$geo_long_lat_key[ check_indices <- c(1, 20, 50, 101) ])
117 | acquireImageRep( UgandaDataProcessed$geo_long_lat_key[ check_indices[1:2] ] )
118 |
119 | # sanity checks in the analysis of earth observation data are essential
120 | # check that images are centered around correct location
121 | causalimages::image2( as.array(ImageBatch)[1,,,1] )
122 | UgandaDataProcessed$geo_long[check_indices[1]]
123 | UgandaDataProcessed$geo_lat[check_indices[1]]
124 | # check against google maps to confirm correctness
125 | # https://www.google.com/maps/place/1%C2%B018'16.4%22N+34%C2%B005'15.1%22E/@1.3111951,34.0518834,10145m/data=!3m1!1e3!4m4!3m3!8m2!3d1.3045556!4d34.0875278?entry=ttu
126 |
127 | # scramble data (important before reading into causalimages::WriteTfRecord)
128 | # to ensure no systematic biases in data sequence during model training
129 | set.seed(144L); UgandaDataProcessed <- UgandaDataProcessed[sample(1:nrow(UgandaDataProcessed)),]
130 | }
131 |
132 | # Image heterogeneity example
133 | if(T == T){ # FIX: was `if(T == ){`, a syntax error; TRUE toggles this example on (matches the sibling block below)
134 | # write a tf records repository
135 | # whenever changes are made to the input data to AnalyzeImageHeterogeneity, WriteTfRecord() should be re-run
136 | # to ensure correct ordering of data
137 | tfrecord_loc <- "~/Downloads/UgandaExample.tfrecord"
138 | if( reSaveTfRecords ){
139 | causalimages::WriteTfRecord(
140 | file = tfrecord_loc,
141 | uniqueImageKeys = unique(UgandaDataProcessed$geo_long_lat_key),
142 | acquireImageFxn = acquireImageRep )
143 | }
144 | 
145 | for(ImageModelClass in c("VisionTransformer","CNN")){
146 | for(optimizeImageRep in c(T, F)){
147 | print(sprintf("Image hetero analysis & optimizeImageRep: %s",optimizeImageRep))
148 | # perform image heterogeneity analysis (toy example)
149 | ImageHeterogeneityResults <- causalimages::AnalyzeImageHeterogeneity(
150 | # data inputs
151 | obsW = UgandaDataProcessed$Wobs,
152 | obsY = UgandaDataProcessed$Yobs,
153 | X = matrix(rnorm(length(UgandaDataProcessed$Yobs)*10),ncol=10),
154 | imageKeysOfUnits = UgandaDataProcessed$geo_long_lat_key,
155 | file = tfrecord_loc, # location of tf record (use absolute file paths)
156 | lat = UgandaDataProcessed$geo_lat, # not required but helpful for dealing with redundant locations in EO data
157 | long = UgandaDataProcessed$geo_long, # not required but helpful for dealing with redundant locations in EO data
158 | 
159 | # inputs to control where visual results are saved as PDF or PNGs
160 | # (these image grids are large and difficult to display in RStudio's interactive mode)
161 | plotResults = T,
162 | figuresPath = "~/Downloads/HeteroTutorial", # where to write analysis figures
163 | figuresTag = "HeterogeneityImTutorial",plotBands = 1L:3L,
164 | 
165 | # optional arguments for generating transportability maps
166 | # here, we leave those NULL for simplicity
167 | transportabilityMat = NULL, #
168 | 
169 | # other modeling options
170 | imageModelClass = ImageModelClass,
171 | nSGD = 5L, # make this larger for real applications (e.g., 2000L)
172 | nDepth_ImageRep = ifelse(optimizeImageRep, yes = 1L, no = 1L), # NOTE(review): both branches are 1L — vary here if depth should differ by setting
173 | nWidth_ImageRep = as.integer(2L^6),
174 | optimizeImageRep = optimizeImageRep,
175 | batchSize = 8L, # make this larger for real application (e.g., 50L)
176 | kClust_est = 2 # vary depending on problem. Usually < 5
177 | )
178 | try(dev.off(), T) # close any open graphics device; silenced since none may be open
179 | }
180 | }
181 | }
182 |
183 | # image sequence example
184 | if(T == T){
185 | acquireVideoRep <- function(keys) { # build a toy 2-frame image sequence from the single-image reader
186 | # Get image data as an array from disk; dims (batch, H, W, bands)
187 | tmp <- acquireImageRep(keys)
188 | 
189 | # Expand dimensions: we create a new dimension at the start -> (1, batch, H, W, bands)
190 | tmp <- array(tmp, dim = c(1, dim(tmp)))
191 | 
192 | # Transpose dimensions to get the required order -> (batch, time = 1, H, W, bands)
193 | tmp <- aperm(tmp, c(2, 1, 3, 4, 5))
194 | 
195 | # Swap image H/W dimensions so the second "frame" differs, giving variability across time
196 | tmp_ <- aperm(tmp, c(1, 2, 4, 3, 5))
197 | 
198 | # Concatenate along the second (time) axis -> (batch, time = 2, H, W, bands)
199 | tmp <- abind::abind(tmp, tmp_, along = 2)
200 | 
201 | return(tmp)
202 | }
203 |
204 | # write the tf records repository
205 | tfrecord_loc_imSeq <- "~/Downloads/UgandaExampleVideo.tfrecord"
206 | if(reSaveTfRecords){
207 | causalimages::WriteTfRecord( file = tfrecord_loc_imSeq,
208 | uniqueImageKeys = unique(UgandaDataProcessed$geo_long_lat_key),
209 | acquireImageFxn = acquireVideoRep, writeVideo = T )
210 | }
211 |
212 | for(ImageModelClass in (c("VisionTransformer","CNN"))){
213 | for(optimizeImageRep in c(T, F)){
214 | print(sprintf("Image seq hetero analysis & optimizeImageRep: %s",optimizeImageRep))
215 | # Note: optimizeImageRep = T breaks with video on METAL framework
216 | VideoHeterogeneityResults <- causalimages::AnalyzeImageHeterogeneity(
217 | # data inputs
218 | obsW = UgandaDataProcessed$Wobs,
219 | obsY = UgandaDataProcessed$Yobs,
220 | imageKeysOfUnits = UgandaDataProcessed$geo_long_lat_key,
221 | file = tfrecord_loc_imSeq, # location of tf record (absolute paths are safest)
222 | dataType = "video",
223 | lat = UgandaDataProcessed$geo_lat, # not required but helpful for dealing with redundant locations in EO data
224 | long = UgandaDataProcessed$geo_long, # not required but helpful for dealing with redundant locations in EO data
225 |
226 | # inputs to control where visual results are saved as PDF or PNGs
227 | # (these image grids are large and difficult to display in RStudio's interactive mode)
228 | plotResults = T,
229 | figuresPath = "~/Downloads/HeteroTutorial",
230 | plotBands = 1L:3L, figuresTag = "HeterogeneityImSeqTutorial",
231 |
232 | # optional arguments for generating transportability maps
233 | # here, we leave those NULL for simplicity
234 | transportabilityMat = NULL, #
235 |
236 | # other modeling options
237 | imageModelClass = ImageModelClass,
238 | nSGD = 5L, # make this larger for real applications (e.g., 2000L)
239 | nDepth_ImageRep = ifelse(optimizeImageRep, yes = 1L, no = 1L),
240 | nWidth_ImageRep = as.integer(2L^5),
241 | optimizeImageRep = optimizeImageRep,
242 | kClust_est = 2, # vary depending on problem. Usually < 5
243 | batchSize = 8L, # make this larger for real application (e.g., 50L)
244 | strides = 2L )
245 | try(dev.off(), T)
246 | }
247 | }
248 | }
249 |
250 | causalimages::print2("Done with image heterogeneity tutorial!")
251 | }
252 |
--------------------------------------------------------------------------------
/tutorials/BuildBackend_Tutorial.R:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env Rscript
2 | {
3 | # clear workspace
4 | rm(list = ls()); options(error = NULL)
5 | 
6 | # remote install latest version of the package
7 | # devtools::install_github(repo = "cjerzak/causalimages-software/causalimages")
8 | 
9 | # local install for development team
10 | # install.packages("~/Documents/causalimages-software/causalimages",repos = NULL, type = "source",force = F)
11 | 
12 | # in general, you will simply use:
13 | causalimages::BuildBackend()
14 | 
15 | # Note: This function requires an Internet connection
16 | # Note: With default arguments, a conda environment called
17 | # "CausalImagesEnv" will be created with required packages saved within.
18 | 
19 | # Advanced tip:
20 | # if you need to point to the location of a specific version of Python
21 | # in conda where packages will be downloaded,
22 | # you can use:
23 | #causalimages::BuildBackend(conda = "/Users/cjerzak/miniforge3/bin/python")
24 | }
25 |
26 |
--------------------------------------------------------------------------------
/tutorials/ExtractImageRepresentations_Tutorial.R:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env Rscript
2 | {
3 | ################################
4 | # Image and image-sequence embeddings tutorial using causalimages
5 | ################################
6 |
7 | # start with a clean environment
8 | rm(list=ls()); options(error = NULL)
9 |
10 | # remote install latest version of the package if needed
11 | # devtools::install_github(repo = "cjerzak/causalimages-software/causalimages")
12 |
13 | # local install for development team
14 | # install.packages("~/Documents/causalimages-software/causalimages",repos = NULL, type = "source",force = F)
15 |
16 | # build the backend if you haven't already:
17 | # causalimages::BuildBackend()
18 |
19 | # load in package
20 | library( causalimages )
21 |
22 | # load in tutorial data
23 | data( CausalImagesTutorialData )
24 |
25 | # example acquire image function (loading from memory)
26 | # in general, you'll want to write a function that returns images
27 | # saved on disk and associated with keys
28 | acquireImageFromMemory <- function(keys){ # look up pre-loaded images by key from the in-memory array
29 | # here, the function input keys
30 | # refers to the unit-associated image keys
31 | m_ <- FullImageArray[match(keys, KeysOfImages),,,]
32 | 
33 | # if length(keys) == 1, single-row subsetting drops the batch dimension;
34 | # re-add it so output dims are always consistent
35 | # (here in image case, dims are batch by height by width by channel)
36 | if(length(keys) == 1){
37 | m_ <- array(m_,dim = c(1L,dim(m_)[1],dim(m_)[2],dim(m_)[3]))
38 | }
39 | return( m_ )
40 | }
40 |
41 | # drop first column
42 | X <- X[,-1]
43 |
44 | # mean imputation for simplicity
45 | X <- apply(X,2,function(zer){
46 | zer[is.na(zer)] <- mean( zer,na.rm = T )
47 | return( zer )
48 | })
49 |
50 | # select observation subset to make tutorial analyses run fast
51 | take_indices <- unlist( tapply(1:length(obsW),obsW,function(zer){ sample(zer, 50) }) )
52 |
53 | # write tf record
54 | TfRecord_name <- "~/Downloads/CausalImagesTutorialDat.tfrecord"
55 | causalimages::WriteTfRecord( file = TfRecord_name,
56 | uniqueImageKeys = unique( KeysOfObservations[ take_indices ] ),
57 | acquireImageFxn = acquireImageFromMemory )
58 |
59 | # obtain image representation (random neural projection)
60 | MyImageEmbeddings_RandomProj <- causalimages::GetImageRepresentations(
61 | file = TfRecord_name,
62 | imageKeysOfUnits = KeysOfObservations[ take_indices ],
63 | nDepth_ImageRep = 1L,
64 | nWidth_ImageRep = 128L,
65 | CleanupEnv = T)
66 |
67 | # sanity check - # of rows in MyImageEmbeddings matches # of image keys
68 | nrow(MyImageEmbeddings_RandomProj$ImageRepresentations) == length(KeysOfObservations[ take_indices ])
69 |
70 | # each row in MyImageEmbeddings$ImageRepresentations corresponds to an observation
71 | # each column represents an embedding dimension associated with the imagery for that location
72 | plot( MyImageEmbeddings_RandomProj$ImageRepresentations )
73 |
74 | # obtain image representation (pre-trained ViT)
75 | MyImageEmbeddings_ViT <- causalimages::GetImageRepresentations(
76 | file = TfRecord_name,
77 | pretrainedModel = "vit-base",
78 | imageKeysOfUnits = KeysOfObservations[ take_indices ],
79 | CleanupEnv = T)
80 |
81 | # analyze ViT representations
82 | plot( MyImageEmbeddings_ViT$ImageRepresentations )
83 |
84 | # obtain image representation (pre-trained CLIP-RCSID)
85 | MyImageEmbeddings_Clip <- causalimages::GetImageRepresentations(
86 | file = TfRecord_name,
87 | pretrainedModel = "clip-rsicd",
88 | imageKeysOfUnits = KeysOfObservations[ take_indices ],
89 | CleanupEnv = T)
90 |
91 | # analyze Clip representations
92 | plot( MyImageEmbeddings_Clip$ImageRepresentations )
93 |
94 | # sanity check - # of rows in MyImageEmbeddings matches # of image keys
95 | nrow(MyImageEmbeddings_Clip$ImageRepresentations) == length(KeysOfObservations[ take_indices ])
96 |
97 | # other output quantities include the image model functions and model parameters
98 | names( MyImageEmbeddings_Clip )[-1]
99 |
100 | print("Done with image representations tutorial!")
101 | }
102 |
--------------------------------------------------------------------------------
/tutorials/UsingTfRecords_Tutorial.R:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env Rscript
2 | {
3 | # clear workspace
4 | rm(list=ls()); options(error = NULL)
5 |
6 | ################################
7 | # Image confounding tutorial using causalimages & tfrecords
8 |
9 | # remote install latest version of the package if needed
10 | # devtools::install_github(repo = "cjerzak/causalimages-software/causalimages")
11 |
12 | # local install for development team
13 | # install.packages("~/Documents/causalimages-software/causalimages",repos = NULL, type = "source",force = F)
14 |
15 | # build the backend if you haven't already:
16 | # causalimages::BuildBackend()
17 |
18 | # load in package
19 | library( causalimages )
20 |
21 | # load in tutorial data
22 | data( CausalImagesTutorialData )
23 |
24 | # example acquire image function (loading from memory)
25 | # in general, you'll want to write a function that returns images
26 | # saved on disk and associated with keys
27 | acquireImageFromMemory <- function(keys){ # look up pre-loaded images by key from the in-memory array
28 | # here, the function input keys
29 | # refers to the unit-associated image keys
30 | m_ <- FullImageArray[match(keys, KeysOfImages),,,]
31 | 
32 | # if length(keys) == 1, single-row subsetting drops the batch dimension;
33 | # re-add it so output dims are always consistent
34 | # (here in image case, dims are batch by height by width by channel)
35 | if(length(keys) == 1){
36 | m_ <- array(m_,dim = c(1L,dim(m_)[1],dim(m_)[2],dim(m_)[3]))
37 | }
38 | 
39 | # uncomment for a test with different image dimensions
40 | #if(length(keys) == 1){ m_ <- abind::abind(m_,m_,m_,along = 3L) }; if(length(keys) > 1){ m_ <- abind::abind(m_,m_,m_,along = 4L) }
41 | return( m_ )
42 | }
42 |
43 | dim( acquireImageFromMemory(KeysOfImages[1]) )
44 | dim( acquireImageFromMemory(KeysOfImages[1:2]) )
45 |
46 | # drop first column
47 | X <- X[,-1]
48 |
49 | # mean imputation for simplicity
50 | X <- apply(X,2,function(zer){
51 | zer[is.na(zer)] <- mean( zer,na.rm = T )
52 | return( zer )
53 | })
54 |
55 | # select observation subset to make tutorial analyses run faster
56 | # select 50 treatment and 50 control observations
57 | set.seed(1.)
58 | take_indices <- unlist( tapply(1:length(obsW),obsW,function(zer){sample(zer, 50)}) )
59 |
60 | # !!! important note !!!
61 | # when using tf recordings, it is essential that the data inputs be pre-shuffled like is done here.
62 | # you can use a seed for reproducing the shuffle (so the tfrecord is correctly indexed and you don't need to re-make it)
63 | # tf records read data quasi-sequentially, so systematic patterns in the data ordering
64 | # reduce performance
65 |
66 | # uncomment for a larger n analysis
67 | #take_indices <- 1:length( obsY )
68 |
69 | # set tfrecord save location (best to use absolute paths, but relative paths should in general work too)
70 | tfrecord_loc <- "~/Downloads/ExampleRecord.tfrecord"
71 |
72 | # write a tf records repository
73 | causalimages::WriteTfRecord( file = tfrecord_loc,
74 | uniqueImageKeys = unique( as.character(KeysOfObservations)[ take_indices ] ),
75 | acquireImageFxn = acquireImageFromMemory )
76 |
77 | # perform causal inference with image and tabular confounding
78 | # toy example for illustration purposes where
79 | # treatment is truly randomized
80 | ImageConfoundingAnalysis <- causalimages::AnalyzeImageConfounding(
81 | obsW = obsW[ take_indices ],
82 | obsY = obsY[ take_indices ],
83 | X = X[ take_indices,apply(X[ take_indices,],2,sd)>0],
84 | long = LongLat$geo_long[ take_indices ],
85 | lat = LongLat$geo_lat[ take_indices ],
86 | batchSize = 16L,
87 | imageKeysOfUnits = as.character(KeysOfObservations)[ take_indices ],
88 | file = tfrecord_loc, # point to tfrecords file
89 |
90 | nSGD = 500L,
91 | imageModelClass = "CNN",
92 | #imageModelClass = "VisionTransformer",
93 | plotBands = c(1,2,3),
94 | figuresTag = "TutorialExample",
95 | figuresPath = "~/Downloads/TFRecordTutorial" # figures saved here (use absolute file paths)
96 | )
97 |
98 | # ATE estimate (image confounder adjusted)
99 | ImageConfoundingAnalysis$tauHat_propensityHajek
100 |
101 | # ATE se estimate (image confounder adjusted)
102 | ImageConfoundingAnalysis$tauHat_propensityHajek_se
103 |
104 | # see figuresPath for image analysis output
105 | causalimages::print2("Done with TfRecords tutorial!")
106 | }
107 |
--------------------------------------------------------------------------------