├── .github └── workflows │ └── tutorial-tests.yml ├── .gitignore ├── LICENSE ├── README.md ├── bin └── micromamba ├── causalimages.pdf ├── causalimages ├── .DS_Store ├── DESCRIPTION ├── NAMESPACE ├── R │ ├── CI_BuildBackend.R │ ├── CI_Confounding.R │ ├── CI_GetAndSaveGeolocatedImages.R │ ├── CI_GetMoments.R │ ├── CI_Heterogeneity.R │ ├── CI_Heterogeneity_plot.R │ ├── CI_ImageModelBackbones.R │ ├── CI_InitializeJAX.R │ ├── CI_PredictiveRun.R │ ├── CI_TFRecordManagement.R │ ├── CI_TfRecordFxns.R │ ├── CI_TrainDefine.R │ ├── CI_TrainDo.R │ ├── CI_helperFxns.R │ └── CI_image2.R ├── data │ ├── CausalImagesTutorialData.RData │ └── datalist ├── man │ ├── AnalyzeImageConfounding.Rd │ ├── AnalyzeImageHeterogeneity.Rd │ ├── BuildBackend.Rd │ ├── GetAndSaveGeolocatedImages.Rd │ ├── GetElementFromTfRecordAtIndices.Rd │ ├── GetImageRepresentations.Rd │ ├── GetMoments.Rd │ ├── LongLat2CRS.Rd │ ├── PredictiveRun.Rd │ ├── TFRecordManagement.Rd │ ├── TrainDefine.Rd │ ├── TrainDo.Rd │ ├── WriteTfRecord.Rd │ ├── image2.Rd │ ├── message2.Rd │ └── print2.Rd └── tests │ ├── Test_AAARunAllTutorialsSuite.R │ ├── Test_AnalyzeImageConfounding.R │ ├── Test_AnalyzeImageHeterogeneity.R │ ├── Test_BuildBackend.R │ ├── Test_ExtractImageRepresentations.R │ └── Test_UsingTfRecords.R ├── documentPackage.R ├── misc ├── dataverse │ ├── DataverseReadme_confounding.md │ ├── DataverseReadme_heterogeneity.md │ ├── DataverseTutorial_confounding.R │ └── DataverseTutorial_heterogeneity.R ├── docker │ └── setup │ │ ├── CodexHelpers.sh │ │ ├── Dockerfile │ │ ├── DockerfileNoCompile │ │ ├── FindDependencies.sh │ │ ├── GenEnv.sh │ │ └── GetRDependencyOrder.R └── notes │ └── MaintainerNotes.txt ├── other ├── DataverseReadme_confounding.md ├── DataverseReadme_heterogeneity.md ├── DataverseTutorial_confounding.R ├── DataverseTutorial_heterogeneity.R └── PackageRunChecks.R └── tutorials ├── AnalyzeImageConfounding_Tutorial_Advanced.R ├── AnalyzeImageConfounding_Tutorial_Base.R ├── 
AnalyzeImageConfounding_Tutorial_EmbeddingsOnly.R ├── AnalyzeImageConfounding_Tutorial_Simulation.R ├── AnalyzeImageHeterogeneity_Tutorial.R ├── BuildBackend_Tutorial.R ├── ExtractImageRepresentations_Tutorial.R └── UsingTfRecords_Tutorial.R /.github/workflows/tutorial-tests.yml: -------------------------------------------------------------------------------- 1 | name: tutorial-tests 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | 8 | jobs: 9 | test: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v3 13 | - uses: conda-incubator/setup-miniconda@v2 14 | with: 15 | auto-update-conda: true 16 | python-version: 3.10 17 | - uses: r-lib/actions/setup-r@v2 18 | with: 19 | r-version: '4.2' 20 | - name: Install devtools 21 | run: | 22 | Rscript -e 'install.packages("devtools", repos="https://cloud.r-project.org")' 23 | - name: Install package dependencies 24 | run: | 25 | Rscript -e 'devtools::install_deps(dependencies = TRUE)' 26 | - name: Install causalimages package 27 | run: R CMD INSTALL causalimages 28 | - name: Build backend 29 | run: Rscript -e 'causalimages::BuildBackend()' 30 | - name: Link repository for tests 31 | run: | 32 | mkdir -p "$HOME/Documents" 33 | ln -s "$GITHUB_WORKSPACE" "$HOME/Documents/causalimages-software" 34 | - name: Run tutorial tests 35 | run: Rscript tests/Test_AAARunAllTutorialsSuite.R 36 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .ipynb_checkpoints 3 | .DS_Store 4 | .history 5 | default.profraw 6 | 7 | build 8 | dist 9 | /misc/docker/binaries/ 10 | /misc/docker/binaries/bin/ 11 | /misc/docker/binaries/src/ 12 | */.DS_Store 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2025 Connor Jerzak 2 | 3 | This file 
is part of causalimages-software. 4 | 5 | causalimages-software is free software: you can redistribute it and/or modify 6 | it under the terms of the GNU General Public License as published by 7 | the Free Software Foundation, either version 3 of the License, or (at 8 | your option) any later version. 9 | 10 | causalimages-software is distributed in the hope that it will be useful, 11 | but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | GNU General Public License for more details. 14 | 15 | For more information, see . -------------------------------------------------------------------------------- /bin/micromamba: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjerzak/causalimages-software/a6ca8dd8aae433c5207dedaa89e2c099d5700b8f/bin/micromamba -------------------------------------------------------------------------------- /causalimages.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjerzak/causalimages-software/a6ca8dd8aae433c5207dedaa89e2c099d5700b8f/causalimages.pdf -------------------------------------------------------------------------------- /causalimages/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjerzak/causalimages-software/a6ca8dd8aae433c5207dedaa89e2c099d5700b8f/causalimages/.DS_Store -------------------------------------------------------------------------------- /causalimages/DESCRIPTION: -------------------------------------------------------------------------------- 1 | Package: causalimages 2 | Title: Causal Inference with Earth Observation, Bio-medical, 3 | and Social Science Images 4 | Version: 0.1 5 | Authors@R: 6 | c(person(given = "Connor", 7 | family = "Jerzak", 8 | role = c("aut", "cre"), 9 | email = "connor.jerzak@gmail.com", 10 | comment 
= c(ORCID = "0000-0003-1914-8905")), 11 | person(given = "Adel", 12 | family = "Daoud", 13 | role = "aut", 14 | comment = c(ORCID = "0000-0001-7478-8345"))) 15 | Description: Provides a system for performing causal inference with earth observation, 16 | bio-medical, and social science images and image sequences (videos). The package 17 | uses a 'JAX' backend for GPU/TPU acceleration. Key functionalities include building 18 | conda-based backends (e.g., via 'BuildBackend'), implementing image-based confounder 19 | and heterogeneity analyses (e.g., 'AnalyzeImageConfounding', 'AnalyzeImageHeterogeneity'), 20 | and writing/reading large image corpora as '.tfrecord' files for use in training 21 | (via 'WriteTfRecord' and 'GetElementFromTfRecordAtIndices'). This allows researchers 22 | to scale causal inference to modern large-scale imagery data, bridging R with 23 | hardware-accelerated Python libraries. The package is partly based on Jerzak 24 | and Daoud (2023) . 25 | URL: https://github.com/cjerzak/causalimages-software 26 | BugReports: https://github.com/cjerzak/causalimages-software/issues 27 | Depends: R (>= 3.3.3) 28 | License: GPL-3 29 | Encoding: UTF-8 30 | LazyData: false 31 | Imports: 32 | tensorflow, 33 | reticulate, 34 | geosphere, 35 | raster, 36 | rrapply, 37 | glmnet, 38 | sf, 39 | data.table, 40 | pROC 41 | Suggests: 42 | knitr, 43 | rmarkdown 44 | VignetteBuilder: knitr 45 | RoxygenNote: 7.3.2 46 | RemoteType: local 47 | RemotePkgRef: local::~/Documents/causalimages-software/causalimages 48 | RemoteUrl: /Users/cjerzak/Documents/causalimages-software/causalimages 49 | -------------------------------------------------------------------------------- /causalimages/NAMESPACE: -------------------------------------------------------------------------------- 1 | # Generated by roxygen2: do not edit by hand 2 | 3 | export(AnalyzeImageConfounding) 4 | export(AnalyzeImageHeterogeneity) 5 | export(BuildBackend) 6 | export(GetAndSaveGeolocatedImages) 7 | 
# (NAMESPACE, continued -- generated by roxygen2: do not edit by hand)
export(GetElementFromTfRecordAtIndices)
export(GetImageRepresentations)
export(GetMoments)
export(LongLat2CRS)
export(PredictiveRun)
export(TFRecordManagement)
export(TrainDefine)
export(TrainDo)
export(WriteTfRecord)
export(image2)
export(message2)
export(print2)
import(raster)
import(reticulate)
import(rrapply)

# File: causalimages/R/CI_BuildBackend.R --------------------------------------

#' Build the environment for CausalImages models. Builds a conda environment in which jax, tensorflow, tensorflow-probability optax, equinox, and jmp are installed.
#'
#' @param conda_env (default = `"CausalImagesEnv"`) Name of the conda environment in which to place the backends.
#' @param conda (default = `auto`) The path to a conda executable. Using `"auto"` allows reticulate to attempt to automatically find an appropriate conda binary.
#'
#' @return Builds the computational environment for `causalimages`. This function requires an Internet connection.
#' You may find out a list of conda Python paths via: `system("which python")`
#'
#' @examples
#' # For a tutorial, see
#' # github.com/cjerzak/causalimages-software/
#'
#' @export
#' @md

BuildBackend <- function(conda_env = "CausalImagesEnv", conda = "auto") {
  # --- helpers ---------------------------------------------------------------
  os <- Sys.info()[["sysname"]]
  machine <- Sys.info()[["machine"]]
  msg <- function(...) message(sprintf(...))

  # Install packages into the target conda env via pip; reticulate::py_install
  # errors on failure, so reaching the final TRUE signals success.
  pip_install <- function(pkgs, ...) {
    reticulate::py_install(
      packages = pkgs,
      envname = conda_env,
      conda = conda,
      pip = TRUE,
      ...
    )
    TRUE
  }

  # Find the Python executable inside the target conda env (for manual pip calls)
  env_python_path <- function() {
    # try via conda_list
    cl <- try(reticulate::conda_list(), silent = TRUE)
    if (!inherits(cl, "try-error") && any(cl$name == conda_env)) {
      py <- cl$python[match(conda_env, cl$name)]
      if (length(py) == 1 && !is.na(py) && nzchar(py) && file.exists(py)) return(py)
    }
    # fallback via conda binary location
    cb <- try(reticulate::conda_binary(conda), silent = TRUE)
    prefix <- if (!inherits(cb, "try-error") && nzchar(cb)) dirname(dirname(cb)) else {
      # last-ditch default
      if (os == "Windows") "C:/Miniconda3" else file.path(Sys.getenv("HOME"), "miniconda3")
    }
    if (os == "Windows")
      file.path(prefix, "envs", conda_env, "python.exe")
    else
      file.path(prefix, "envs", conda_env, "bin", "python")
  }

  # Manual `pip install -f <find_links> <spec>` fallback (used for legacy CUDA
  # wheel indexes not on PyPI); returns FALSE (rather than erroring) on failure.
  pip_install_from_findlinks <- function(spec, find_links) {
    py <- env_python_path()
    cmd <- sprintf(
      "%s -m pip install --upgrade --no-user -f %s %s",
      shQuote(py), shQuote(find_links), shQuote(spec)
    )
    res <- try(system(cmd, intern = TRUE), silent = TRUE)
    !inherits(res, "try-error")
  }

  # --- conda env -------------------------------------------------------------
  # NOTE(review): conda_create errors if an env of this name already exists --
  # TODO confirm whether re-running BuildBackend() over an existing env is intended.
  reticulate::conda_create(
    envname = conda_env,
    conda = conda,
    python_version = "3.13"
  )

  # Install numpy early to stabilize BLAS/ABI choices if needed
  pip_install("numpy")

  # --- JAX first: hardware-aware selection -----------------------------------
  install_jax <- function() {
    if (os == "Darwin" && machine %in% c("arm64", "aarch64")) {
      # Apple Silicon: Metal backend
      #msg("Apple Silicon detected: installing JAX (Metal).")
      #pip_install(c("jax==0.5.0", "jaxlib==0.5.0", "jax-metal==0.1.1"))
      pip_install(c("jax", "jaxlib"))
      return(invisible(TRUE))
    }

    if (identical(os, "Linux")) {
      # Query NVIDIA driver major version (e.g., '535.171.04' -> 535)
      drv <- try(suppressWarnings(
        system("nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -n1",
               intern = TRUE)
      ), silent = TRUE)
      drv_major <- suppressWarnings(as.integer(sub("^([0-9]+).*", "\\1", drv[1])))
      msg("Detected NVIDIA driver: %s", ifelse(length(drv) > 0, drv[1], "none/unknown"))

      # Prefer CUDA 13 when possible, fall back to CUDA 12, then CPU
      if (!is.na(drv_major) && drv_major >= 580) {
        msg("Driver >= 580: trying JAX CUDA 13 wheels.")
        ok <- try(pip_install("jax[cuda13]"), silent = TRUE)
        ok <- isTRUE(ok) && !inherits(ok, "try-error")
        if (!ok) {
          msg("CUDA 13 wheels failed; falling back to CUDA 12 extras.")
          ok <- try(pip_install("jax[cuda12]"), silent = TRUE)
          ok <- isTRUE(ok) && !inherits(ok, "try-error")
        }
        if (!ok) {
          msg("CUDA 12 extras failed; trying legacy 'cuda12_pip' via find-links.")
          ok <- pip_install_from_findlinks(
            "jax[cuda12_pip]",
            "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
          )
        }
        if (!ok) {
          msg("All CUDA wheel attempts failed; installing CPU-only JAX.")
          pip_install("jax")
        }
      } else if (!is.na(drv_major) && drv_major >= 525) {
        msg("Driver >= 525 and < 580: installing JAX CUDA 12 wheels.")
        ok <- try(pip_install("jax[cuda12]"), silent = TRUE)
        ok <- isTRUE(ok) && !inherits(ok, "try-error")
        if (!ok) {
          msg("CUDA 12 extras failed; trying legacy 'cuda12_pip' via find-links.")
          ok <- pip_install_from_findlinks(
            "jax[cuda12_pip]",
            "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
          )
        }
        if (!ok) {
          msg("CUDA wheels failed; installing CPU-only JAX.")
          pip_install("jax")
        }
      } else {
        msg("No suitable NVIDIA driver found (or too old); installing CPU-only JAX.")
        pip_install("jax")
      }
      return(invisible(TRUE))
    }

    # Other OSes: CPU-only JAX
    msg("Non-Linux or non-Apple-Silicon platform; installing CPU-only JAX.")
    pip_install("jax")
  }

  install_jax()

  # Optionally neutralize LD_LIBRARY_PATH within this env to avoid host overrides
  if (os == "Linux") {
    cb <- try(reticulate::conda_binary(conda), silent = TRUE)
    conda_prefix <- if (!inherits(cb, "try-error") && nzchar(cb)) dirname(dirname(cb)) else {
      file.path(Sys.getenv("HOME"), "miniconda3")
    }
    env_dir <- file.path(conda_prefix, "envs", conda_env)
    actdir <- file.path(env_dir, "etc", "conda", "activate.d")
    dir.create(actdir, recursive = TRUE, showWarnings = FALSE)
    # Activation hook: conda sources every *.sh in activate.d when the env activates.
    try(writeLines("unset LD_LIBRARY_PATH", file.path(actdir, "00-unset-ld.sh")), silent = TRUE)
  }

  # --- Remaining packages (do NOT include 'jax' again to avoid downgrades) ----
  pip_install(c(
    "tensorflow",
    "optax",
    "torch",
    "transformers",
    "pillow",
    "tf-keras",
    "equinox",
    "jmp"
  ))

  # reinstall JAX again on Mac after forced upgrade to breaking version of JAX
  #if(os == "Darwin"){ install_jax() }

  # Prefer the package's message2() logger when it is loaded; plain message() otherwise.
  done_msg <- sprintf("Done building causalimages backend (env '%s').", conda_env)
  if (exists("message2", mode = "function")) message2(done_msg) else message(done_msg)
}

# File: causalimages/R/CI_GetAndSaveGeolocatedImages.R ------------------------

#' Getting and saving geo-located images from a pool of .tif's
#'
#' A function that finds the image slice associated with the `long` and `lat` values, saves images by band (if `save_as = "csv"`) in save_folder.
#'
#' @param long Vector of numeric longitudes.
#' @param lat Vector of numeric latitudes.
#' @param keys The image keys associated with the long/lat coordinates.
#' @param tif_pool A character vector specifying the fully qualified path to a corpus of .tif files.
#' @param save_folder (default = `"."`) What folder should be used to save the output? Example: `"~/Downloads"`
#' @param image_pixel_width An even integer specifying the pixel width (and height) of the saved images.
#' @param save_as (default = `"csv"`) What format should the output be saved as? Only one option currently (`csv`).
#' @param lyrs (default = `NULL`) Integer (vector) specifying the layers to be extracted. Default is for all layers to be extracted.
#'
#' @return Finds the image slice associated with the `long` and `lat` values, saves images by band (if `save_as = "csv"`) in save_folder.
#' The save format is: `sprintf("%s/Key%s_BAND%s.csv", save_folder, keys[i], band_)`
#'
#' @examples
#'
#' # Example use (not run):
#' #MASTER_IMAGE_POOL_FULL_DIR <- c("./LargeTifs/tif1.tif","./LargeTifs/tif2.tif")
#' #GetAndSaveGeolocatedImages(
#' #long = GeoKeyMat$geo_long,
#' #lat = GeoKeyMat$geo_lat,
#' #image_pixel_width = 500L,
#' #keys = row.names(GeoKeyMat),
#' #tif_pool = MASTER_IMAGE_POOL_FULL_DIR,
#' #save_folder = "./Data/Uganda2000_processed",
#' #save_as = "csv",
#' #lyrs = NULL)
#'
#' @import raster
#' @export
#' @md
#'
GetAndSaveGeolocatedImages <- function(
    long,
    lat,
    keys,
    tif_pool,
    image_pixel_width = 256L,
    save_folder = ".",
    save_as = "csv",
    lyrs = NULL){

  library(raster); library(sf)

  # Extraction window geometry: a DIAMETER_CELLS x DIAMETER_CELLS pixel block
  # centered (when possible) on the target coordinate.
  RADIUS_CELLS <- (DIAMETER_CELLS <- image_pixel_width) / 2

  bad_indices <- c()
  observation_indices <- seq_along(long)  # seq_along() is safe for zero-length input
  counter_b <- 0
  for (i in observation_indices) {
    counter_b <- counter_b + 1
    if (counter_b %% 10 == 0) {
      message(sprintf("At image %s of %s", counter_b, length(observation_indices)))
    }
    SpatialTarget_longlat <- c(long[i], lat[i])

    # Scan the .tif pool until one contains the target coordinate; running past
    # the end of the pool (NA subscript) marks the observation as bad.
    found_ <- FALSE; counter_ <- 0
    while (found_ == FALSE) {
      counter_ <- counter_ + 1
      if (is.na(tif_pool[counter_])) { found_ <- bad_ <- TRUE }
      if (!is.na(tif_pool[counter_])) {
        MASTER_IMAGE_ <- try(raster::brick(tif_pool[counter_]), TRUE)
        # Skip unreadable/corrupt .tif files instead of passing a try-error
        # into raster::crs() (which would abort the entire run).
        if (inherits(MASTER_IMAGE_, "try-error")) { next }

        # Re-project the long/lat target into this raster's CRS, then locate
        # its row/col cell position (NA if the point falls outside the raster).
        SpatialTarget_utm <- LongLat2CRS(
          long = SpatialTarget_longlat[1],
          lat = SpatialTarget_longlat[2],
          CRS_ref = raster::crs(MASTER_IMAGE_))
        SpatialTargetCellLoc <- raster::cellFromXY(
          object = MASTER_IMAGE_,
          xy = SpatialTarget_utm)
        SpatialTargetRowCol <- raster::rowColFromCell(
          object = MASTER_IMAGE_,
          cell = SpatialTargetCellLoc)
        if (!is.na(sum(SpatialTargetRowCol))) { found_ <- TRUE; bad_ <- FALSE }
        if (counter_ > 1000000) { stop("ERROR! Target not found anywhere in pool!") }
      }
    }
    if (bad_) {
      message(sprintf("Failure at %s. Apparently, no .tif contains the reference point", i))
      bad_indices <- c(bad_indices, i)
    }
    if (!bad_) {
      message(sprintf("Success at %s - Extracting & saving image!", i))
      # available rows/cols
      rows_available <- nrow(MASTER_IMAGE_)
      cols_available <- ncol(MASTER_IMAGE_)

      # define start row/col (window centered on the target cell)
      start_row <- SpatialTargetRowCol[1, "row"] - RADIUS_CELLS
      start_col <- SpatialTargetRowCol[1, "col"] - RADIUS_CELLS

      # find end row/col
      end_row <- start_row + DIAMETER_CELLS
      end_col <- start_col + DIAMETER_CELLS

      # perform checks to deal with spilling over the image edges:
      # clamp the window back inside [1, rows_available] x [1, cols_available]
      if (start_row <= 0) { start_row <- 1 }
      if (start_col <= 0) { start_col <- 1 }
      if (end_row > rows_available) { start_row <- rows_available - DIAMETER_CELLS }
      if (end_col > cols_available) { start_col <- cols_available - DIAMETER_CELLS }

      for (iof in 0:0) {
        if (is.null(lyrs)) { lyrs <- 1:dim(MASTER_IMAGE_)[3] }
        # csv output writes every requested band; other formats only band 1
        band_iters <- if (grepl(x = save_as, pattern = "csv")) lyrs else 1L
        for (band_ in band_iters) {
          if (iof > 0) {
            # (iof > 0 is unreachable with the current 0:0 loop; retained for
            # the original's random-window sampling pathway)
            start_row <- sample(1:(nrow(MASTER_IMAGE_) - DIAMETER_CELLS - 1), 1)
            start_col <- sample(1:(ncol(MASTER_IMAGE_) - DIAMETER_CELLS - 1), 1)
          }
          SpatialTargetImage_ <- getValuesBlock(MASTER_IMAGE_[[band_]],
                                                row = start_row, nrows = DIAMETER_CELLS,
                                                col = start_col, ncols = DIAMETER_CELLS,
                                                format = "matrix", lyrs = 1L)
          # near-constant windows (fewer than 5 unique values) are flagged as bad
          if (length(unique(c(SpatialTargetImage_))) < 5) { bad_indices <- c(bad_indices, i) }
          check_ <- dim(SpatialTargetImage_) - c(DIAMETER_CELLS, DIAMETER_CELLS)
          if (any(check_ < 0)) {
            # previously: print(...) + browser() -- an interactive debugger
            # breakpoint left in package code; warn and continue instead
            warning(sprintf(
              "Extracted window at index %s is smaller than the requested %sx%s block; source raster may be smaller than image_pixel_width.",
              i, DIAMETER_CELLS, DIAMETER_CELLS))
          }
          if (grepl(x = save_as, pattern = "tif")) {
            # in progress
          }
          if (grepl(x = save_as, pattern = "csv")) {
            if (iof == 0) {
              data.table::fwrite(file = sprintf("%s/Key%s_BAND%s.csv",
                                                save_folder, keys[i], band_),
                                 data.table::as.data.table(SpatialTargetImage_))
            }
          }
        }
      }
    }
  }
  message("Done with GetAndSaveGeolocatedImages()!")
}

# File: causalimages/R/CI_GetMoments.R ----------------------------------------

#' Get moments for normalization (internal function)
#'
#' An internal function for obtaining moments for channel normalization.
#'
#' @param iterator An iterator
#' @param dataType A string denoting data type
#' @param image_dtype The array dtype used for the returned normalization arrays.
#' @param momentCalIters Number of minibatches with which to estimate moments
#'
#' @return Returns mean/sd arrays for normalization.
#'
#' @examples
#' # (Not run)
#' # GetMoments(iterator, dataType, image_dtype, momentCalIters = 34L)
#' @export
#' @md
#'
GetMoments <- function(iterator, dataType, image_dtype, momentCalIters = 34L){
  message2("Calibrating moments for input data normalization...")

  # Channel axis of a batch: video batches are (batch, time, H, W, channel),
  # image batches are (batch, H, W, channel). Loop-invariant, so computed once.
  ApplyAxis <- ifelse(dataType == "video", yes = 5, no = 4)

  NORM_SD <- NORM_MEAN <- c()
  for(momentCalIter in 1L:momentCalIters){
    # get a data batch; the iterator may error once exhausted, in which case
    # this draw is simply skipped
    ds_next_ <- try(iterator$get_next(), TRUE)

    if(!inherits(ds_next_, "try-error")){
      # convert the batch to an R array once and reuse it for both moments
      # (previously converted twice per batch)
      batch_ <- cienv$np$array(ds_next_[[1]])

      # sanity check
      # causalimages::image2( cienv$np$array((ds_next_train[[1]])[2,,,1] )

      # update normalizations: per-channel SD and mean for this batch
      NORM_SD <- rbind(NORM_SD, apply(batch_, ApplyAxis, sd))
      NORM_MEAN <- rbind(NORM_MEAN, apply(batch_, ApplyAxis, mean))
    }
  }

  # fail loudly if no batch could be drawn (previously this fell through and
  # produced malformed moments downstream)
  if(is.null(NORM_MEAN)){
    stop("GetMoments(): no minibatches could be drawn from `iterator`; cannot calibrate normalization moments.")
  }

  # mean calc
  NORM_MEAN_mat <- NORM_MEAN # batch-level means (rows = batches), kept for Rubin's rule
  NORM_MEAN <- apply(NORM_MEAN, 2, mean) # overall mean across all batches
  NORM_MEAN_array <- cienv$jnp$array(array(NORM_MEAN, dim = c(1,1,1,length(NORM_MEAN))), image_dtype)

  # SD calc using Rubin's rule: combine within- and between-batch variances
  NORM_SD_mat <- NORM_SD # matrix: rows = batches, cols = features

  # combine information to get total variance
  m <- nrow(NORM_SD_mat)
  W <- colMeans(NORM_SD_mat^2) # average within-batch variance
  # between-batch variance of the means; with a single batch var() would be
  # NA (propagating NA SDs), so it degenerates to 0 there
  B <- if(m > 1L){ apply(NORM_MEAN_mat, 2, var) } else { rep(0, ncol(NORM_MEAN_mat)) }
  T_var <- W + (1 + 1/m) * B # total variance
  NORM_SD <- sqrt(T_var) # combined SD
  # plot(apply(NORM_SD_mat,2,median),NORM_SD);abline(a=0,b=1)
  # (removed dead debug line: NORM_SD is sqrt() output here and can never be a
  #  try-error, so the old `browser()` breakpoint was unreachable clutter)
  NORM_SD_array <- cienv$jnp$array(array(NORM_SD, dim = c(1,1,1,length(NORM_SD))), image_dtype)

  # video arrays get an extra leading broadcast dimension
  if(dataType == "video"){
    NORM_MEAN_array <- cienv$jnp$expand_dims(NORM_MEAN_array, 0L)
    NORM_SD_array <- cienv$jnp$expand_dims(NORM_SD_array, 0L)
  }

  return(list("NORM_MEAN_array" = NORM_MEAN_array,
              "NORM_SD_array" = NORM_SD_array,
              "NORM_MEAN" = NORM_MEAN,
              "NORM_SD" = NORM_SD))
}

# File: causalimages/R/CI_InitializeJAX.R -------------------------------------

# Connects R to the Python computational backend (conda env built via
# causalimages::BuildBackend()) and caches module handles in `cienv`.
initialize_jax <- function(conda_env = "cienv",
                           conda_env_required = TRUE,
                           Sys.setenv_text = NULL) {
  message2("Establishing connection to computational environment (build via causalimages::BuildBackend())")

  library(reticulate)
  #library(tensorflow)

  # Load reticulate (Declared in Imports: in DESCRIPTION)
  reticulate::use_condaenv(condaenv = conda_env, required = conda_env_required)

  # Optionally evaluate caller-supplied Sys.setenv(...) text before importing
  # Python modules (e.g., to set XLA/CUDA environment variables).
  if(!is.null(Sys.setenv_text)){
    eval(parse(text = Sys.setenv_text), envir = .GlobalEnv)
  }

  # Import Python packages once, storing them in cienv
  if (!exists("jax", envir = cienv, inherits = FALSE)) {
    cienv$jax <- reticulate::import("jax")
    cienv$jnp <- reticulate::import("jax.numpy")
    cienv$flash_mha <- try(import("flash_attn_jax.flash_mha"),TRUE)  # optional; try-error if absent
    cienv$tf <- reticulate::import("tensorflow")
    cienv$np <- reticulate::import("numpy")
    cienv$jmp <- reticulate::import("jmp")
    cienv$optax <- reticulate::import("optax")
    #cienv$oryx <- reticulate::import("tensorflow_probability.substrates.jax")
    cienv$eq <- reticulate::import("equinox")
    cienv$py_gc <- reticulate::import("gc")
  }

  # set memory growth for tensorflow
  for(device_ in cienv$tf$config$list_physical_devices()){ try(cienv$tf$config$experimental$set_memory_growth(device_, T),T) }

  # ensure tensorflow doesn't use GPU
  try(cienv$tf$config$set_visible_devices(list(), "GPU"), silent = TRUE)

  # Disable 64-bit computations
  cienv$jax$config$update("jax_enable_x64", FALSE)
  cienv$jaxFloatType <- cienv$jnp$float32
}
# Lazily imports torch/transformers into `cienv` (one-time).
# NOTE(review): unlike initialize_jax(), the conda_env / conda_env_required /
# Sys.setenv_text arguments are accepted but never used here -- presumably the
# conda env is assumed to be activated already by initialize_jax(); confirm.
initialize_torch <- function(conda_env = "cienv",
                             conda_env_required = TRUE,
                             Sys.setenv_text = NULL) {
  # Import Python packages once, storing them in cienv
  if (!exists("torch", envir = cienv, inherits = FALSE)) {
    cienv$torch <- reticulate::import("torch")
    cienv$transformers <- reticulate::import("transformers")
  }
}

# File: causalimages/R/CI_TFRecordManagement.R --------------------------------

#' Defines an internal TFRecord management routine (internal function)
#'
#' Defines management defined in TFRecordManagement(). Internal function.
#'
#' @param . No parameters.
#'
#' @return Internal function defining a tfrecord management sequence.
#'
#' @import reticulate rrapply
#' @export
#' @md
# NOTE(review): this function has no formal arguments yet references many free
# variables (`file`, `dataType`, `batchSize`, `image_dtype_tf`,
# `TfRecords_BufferScaler`, `TFRecordControl`, `testFrac`, `imageKeysOfUnits`,
# `parse_tfr_element`, `ai`) and assigns its iterators/datasets to local names.
# Presumably its body is evaluated inside a caller's environment (where those
# variables live and the assignments take effect) rather than called directly --
# TODO confirm against the call sites in CI_TrainDefine.R / CI_TrainDo.R.
# NOTE(review): setwd() below is never restored (no on.exit()); the working
# directory remains changed after this routine runs -- verify that callers
# restore it.
TFRecordManagement <- function(){

  # A tfrecord path is mandatory for this routine.
  if(is.null(file)){stop("No file specified for tfrecord!")}
  changed_wd <- F; if( !is.null( file ) ){
    message2("Establishing connection with tfrecord")
    tf_record_name <- file
    # Normalize bare filenames to an explicit relative path so the split below works.
    if( !grepl(tf_record_name, pattern = "/") ){
      tf_record_name <- paste("./",tf_record_name, sep = "")
    }
    # Split into directory components + filename; cd into the directory and
    # refer to the tfrecord by its basename from here on.
    tf_record_name <- strsplit(tf_record_name,split="/")[[1]]
    new_wd <- paste(tf_record_name[-length(tf_record_name)], collapse = "/")
    message2(sprintf("Temporarily re-setting the wd to %s", new_wd ) )
    changed_wd <- T; setwd( new_wd )

    # define video indicator
    useVideoIndicator <- dataType == "video"

    # define tf record
    tf_dataset <- cienv$tf$data$TFRecordDataset( tf_record_name[length(tf_record_name)] )

    # helper functions
    # Inference pipeline: parse every record, batch at ~half the training batch size.
    getParsed_tf_dataset_inference <- function(tf_dataset){
      dataset <- tf_dataset$map( function(x){parse_tfr_element(x,
                                                              readVideo = useVideoIndicator,
                                                              image_dtype = image_dtype_tf)} )
      return( dataset <- dataset$batch( ai(max(2L,round(batchSize/2L) ))) )
    }

    message2("Setting up iterators...") # skip the first test_size observations
    # Parse-only pipeline (parallelized); batching/shuffling applied separately below.
    getParsed_tf_dataset_train_Select <- function( tf_dataset ){
      return( tf_dataset$map( function(x){ parse_tfr_element(x,
                                                            readVideo = useVideoIndicator,
                                                            image_dtype = image_dtype_tf)},
                             num_parallel_calls = cienv$tf$data$AUTOTUNE) )
    }
    # Shuffle (re-shuffled each epoch), batch, and prefetch a parsed dataset.
    getParsed_tf_dataset_train_BatchAndShuffle <- function( tf_dataset ){
      tf_dataset <- tf_dataset$shuffle(buffer_size = cienv$tf$constant(ai(TfRecords_BufferScaler*batchSize),
                                                                      dtype=cienv$tf$int64),
                                       reshuffle_each_iteration = T)
      tf_dataset <- tf_dataset$batch( ai(batchSize) )
      tf_dataset <- tf_dataset$prefetch( cienv$tf$data$AUTOTUNE )
      return( tf_dataset )
    }
    # Branch 1: explicit control/treated split supplied via TFRecordControl --
    # records are laid out as [test | control | treated] in the tfrecord.
    if(!is.null(TFRecordControl)){
      tf_dataset_train_control <- getParsed_tf_dataset_train_Select(
        tf_dataset$skip( test_size <- ai( TFRecordControl$nTest ) )
      )$take( ai(TFRecordControl$nControl) )$`repeat`(-1L)

      tf_dataset_train_treated <- getParsed_tf_dataset_train_Select(
        tf_dataset$skip( test_size <- ai( TFRecordControl$nTest ) )
      )$skip( ai(TFRecordControl$nControl)+1L)$`repeat`(-1L)

      tf_dataset_train_treated <- getParsed_tf_dataset_train_BatchAndShuffle( tf_dataset_train_treated )
      tf_dataset_train_control <- getParsed_tf_dataset_train_BatchAndShuffle( tf_dataset_train_control )

      ds_iterator_train_treated <- reticulate::as_iterator( tf_dataset_train_treated )
      ds_iterator_train_control <- reticulate::as_iterator( tf_dataset_train_control )
      ds_iterator_train <- reticulate::as_iterator( tf_dataset_train_control )
    }
    # Branch 2: no explicit split -- shuffle once to generate a random
    # train/test partition, then skip the first test_size unique keys.
    if(is.null(TFRecordControl)){
      getParsed_tf_dataset_train <- function( tf_dataset ){
        dataset <- tf_dataset$map( function(x){ parse_tfr_element(x, readVideo = useVideoIndicator, image_dtype = image_dtype_tf)},
                                  num_parallel_calls = cienv$tf$data$AUTOTUNE)
        dataset <- dataset$shuffle(buffer_size = cienv$tf$constant(ai(TfRecords_BufferScaler*batchSize), dtype=cienv$tf$int64),
                                   reshuffle_each_iteration = FALSE) # set FALSE so same train/test split each re-initialization
        dataset <- dataset$batch( ai(batchSize) )
        dataset <- dataset$prefetch( cienv$tf$data$AUTOTUNE )
        return( dataset )
      }

      # shuffle (generating different train/test splits)
      tf_dataset <- cienv$tf$data$Dataset$shuffle( tf_dataset,
                                                  buffer_size = cienv$tf$constant(ai(10L*TfRecords_BufferScaler*batchSize),
                                                                                  dtype=cienv$tf$int64), reshuffle_each_iteration = F )
      tf_dataset_train <- getParsed_tf_dataset_train(
        tf_dataset$skip(test_size <- as.integer(round(testFrac * length(unique(imageKeysOfUnits)) )) ) )$`repeat`( -1L )
      ds_iterator_train <- reticulate::as_iterator( tf_dataset_train )
    }

    # define inference iterator
    tf_dataset_inference <- getParsed_tf_dataset_inference( tf_dataset )
    ds_iterator_inference <- reticulate::as_iterator( tf_dataset_inference )

    # Other helper functions
    # Shuffle-only variant (fixed shuffle order across re-initializations).
    getParsed_tf_dataset_train_Shuffle <- function( tf_dataset ){
      tf_dataset <- tf_dataset$shuffle(buffer_size = cienv$tf$constant(ai(TfRecords_BufferScaler*batchSize),
                                                                      dtype=cienv$tf$int64),
                                       reshuffle_each_iteration = FALSE )
      return(tf_dataset)
    }
  }
}
#' It requires that users define an `acquireImageFxn` function that accepts keys and returns the corresponding image or video as an array of dimensions `(length(keys), nSpatialDim1, nSpatialDim2, nChannels)` for images or `(length(keys), nTimeSteps, nSpatialDim1, nSpatialDim2, nChannels)` for video sequences.
#'
#' @param file A character string naming a file for writing.
#' @param uniqueImageKeys A vector specifying the unique image keys of the corpus.
#' A key grabs an image/video array via `acquireImageFxn(key)`.
#' @param acquireImageFxn A function whose input is an observation keys and whose output is an array with dimensions `(length(keys), nSpatialDim1, nSpatialDim2, nChannels)` for images and `(length(keys), nTimeSteps, nSpatialDim1, nSpatialDim2, nChannels)` for image sequence data.
#' @param writeVideo (default = `FALSE`) Should we assume we're writing image sequence data of form batch by time by height by width by channels?
#' @param image_dtype (default = `"float16"`) A character string naming the `tensorflow` dtype used when serializing each image/video tensor (e.g., `"float16"`, `"float32"`).
#' @param conda_env (default = `"CausalImagesEnv"`) A `conda` environment where computational environment lives, usually created via `causalimages::BuildBackend()`
#' @param conda_env_required (default = `TRUE`) A Boolean stating whether use of the specified conda environment is required.
#' @param Sys.setenv_text (default = `NULL`) Optional text forwarded to the backend initializer so environment variables can be set before the backend loads.
#'
#' @return Writes a unique key-referenced `.tfrecord` from an image/video corpus for use in image-based causal inference training.
#'
#' @examples
#' # Example usage (not run):
#' #WriteTfRecord(
#' #  file = "./NigeriaConfoundApp.tfrecord",
#' #  uniqueImageKeys = 1:n,
#' #  acquireImageFxn = acquireImageFxn)
#'
#' @export
#' @md
WriteTfRecord <- function(file,
                          uniqueImageKeys,
                          acquireImageFxn,
                          writeVideo = FALSE,
                          image_dtype = "float16",
                          conda_env = "CausalImagesEnv",
                          conda_env_required = TRUE,
                          Sys.setenv_text = NULL){
  # Lazily attach the computational backend if it has not been initialized yet.
  if(!"jax" %in% ls(envir = cienv)) {
    initialize_jax(conda_env = conda_env,
                   conda_env_required = conda_env_required,
                   Sys.setenv_text = Sys.setenv_text)
  }

  # Each record must be uniquely addressable by its key: reject duplicates early.
  if(length(uniqueImageKeys) != length(unique(uniqueImageKeys))){
    stop("Stopping because length(uniqueImageKeys) != length(unique(uniqueImageKeys)) \n
       Remember: Input to WriteTFRecord is uniqueImageKeys, not imageKeysOfUnits where redundancies may live")
  }

  # helper fxns
  message2("Initializing tfrecord helpers...")
  {
    # see https://towardsdatascience.com/a-practical-guide-to-tfrecords-584536bc786c

    # Wrap a serialized tensor's raw bytes as a tf.train BytesList feature.
    my_bytes_feature <- function(value){
      value <- value$numpy() # get value of tensor
      return( cienv$tf$train$Feature(bytes_list = cienv$tf$train$BytesList(value = list(value))) )
    }

    # Equivalent bytes-feature wrapper kept for backward compatibility.
    my_simple_bytes_feature <- function(value){
      return( cienv$tf$train$Feature(bytes_list = cienv$tf$train$BytesList(value = list(value$numpy()))) )
    }

    # Wrap a scalar integer (bool / enum / int / uint) as an Int64List feature.
    my_int_feature <- function(value){
      return( cienv$tf$train$Feature(int64_list = cienv$tf$train$Int64List(value = list(value))) )
    }

    # Serialize a tensor into a byte string for storage in the record.
    my_serialize_array <- function(array){ return( cienv$tf$io$serialize_tensor(array) ) }

    # Build one tf.train.Example holding the image (or video) plus its metadata.
    parse_single_image <- function(image, index, key){
      if(!writeVideo){
        data <- dict(
          "height"    = my_int_feature( image$shape[[1]] ), # note: zero indexed
          "width"     = my_int_feature( image$shape[[2]] ),
          "channels"  = my_int_feature( image$shape[[3]] ),
          "raw_image" = my_bytes_feature( my_serialize_array( image ) ),
          "index"     = my_int_feature( index ),
          "key"       = my_bytes_feature( my_serialize_array(key) ))
      }
      if(writeVideo){
        data <- dict(
          "time"      = my_int_feature( image$shape[[1]] ), # note: zero indexed
          "height"    = my_int_feature( image$shape[[2]] ),
          "width"     = my_int_feature( image$shape[[3]] ),
          "channels"  = my_int_feature( image$shape[[4]] ),
          "raw_image" = my_bytes_feature( my_serialize_array( image ) ),
          "index"     = my_int_feature( index ),
          "key"       = my_bytes_feature( my_serialize_array(key) ) )
      }
      out <- cienv$tf$train$Example( features = cienv$tf$train$Features(feature = data) )
      return( out )
    }
  }

  # Resolve the target tensorflow dtype once, outside the write loop
  # (replaces per-iteration eval(parse(...)) on the dtype name).
  target_dtype <- cienv$tf[[image_dtype]]

  # Open the writer relative to the target directory; restore the working
  # directory even if acquireImageFxn or the writer errors mid-run.
  message2("Starting save run...")
  orig_wd <- getwd()
  on.exit(setwd(orig_wd), add = TRUE)
  setwd(dirname(file))
  tf_record_writer <- cienv$tf$io$TFRecordWriter( basename(file) ) # create a writer that'll store our data to disk
  setwd(orig_wd)

  n_keys <- length(uniqueImageKeys)
  for(irz in seq_len(n_keys)){
    # Timestamped progress report (print2 reproduces the "[time] text" format).
    if(irz %% 10 == 0 || irz == 1){
      print2(sprintf("At index %s of %s [%.3f%%]", irz, n_keys, 100*irz/n_keys))
    }
    tf_record_write_output <- parse_single_image(
      image = r2const(acquireImageFxn( uniqueImageKeys[irz] ), target_dtype),
      index = irz,
      key   = as.character(uniqueImageKeys[irz]))
    tf_record_writer$write( tf_record_write_output$SerializeToString() )
  }
  print("Finalizing tfrecords....")
  tf_record_writer$close()
  print("Done writing tfrecord!")
}

#' Reads unique key indices from a `.tfrecord` file.
#'
#' Reads unique key indices from a `.tfrecord` file saved via a call to `causalimages::WriteTfRecord`.
#'
#' @usage
#'
#' GetElementFromTfRecordAtIndices(uniqueKeyIndices, filename, nObs,
#' readVideo, conda_env, conda_env_required,
#' image_dtype, iterator, return_iterator)
#'
#' @param uniqueKeyIndices (integer vector) Unique image indices to be retrieved from a `.tfrecord`
#' @param filename (character string) A character string stating the path to a `.tfrecord`
#' @param conda_env (Default = `NULL`) A `conda` environment where tensorflow v2 lives. Used only if a version of tensorflow is not already active.
#' @param conda_env_required (default = `FALSE`) A Boolean stating whether use of the specified conda environment is required.
#' @param nObs (integer) The total number of records stored in the `.tfrecord`.
#' @param readVideo (default = `FALSE`) Should records be parsed as image sequences (time by height by width by channels)?
#' @param image_dtype (default = `"float16"`) The `tensorflow` dtype name (or dtype object) with which the record was written.
#' @param iterator (default = `NULL`) Optional `list(dataset_iterator, last_index)` as returned by a previous call with `return_iterator = TRUE`; supplying it resumes the stream instead of re-opening the file, speeding up large-batch execution.
#' @param return_iterator (default = `FALSE`) If `TRUE`, additionally return the live dataset iterator so subsequent calls can resume where this one left off.
#'
#' @return Returns content from a `.tfrecord` associated with `uniqueKeyIndices`
#'
#' @examples
#' # Example usage (not run):
#' #GetElementFromTfRecordAtIndices(
#' #uniqueKeyIndices = 1:10,
#' #filename = "./NigeriaConfoundApp.tfrecord")
#'
#' @export
#' @md
GetElementFromTfRecordAtIndices <- function(uniqueKeyIndices, filename, nObs, readVideo = FALSE,
                                            conda_env = NULL, conda_env_required = FALSE, image_dtype = "float16",
                                            iterator = NULL, return_iterator = FALSE){
  # consider passing iterator as input to function to speed up large-batch execution
  image_dtype <- ResolveTfDtype(image_dtype)

  if(is.null(iterator)){
    # Open the record relative to its directory; the working directory is
    # restored on exit even if parsing fails part-way through.
    orig_wd <- getwd()
    on.exit(setwd(orig_wd), add = TRUE)
    setwd(dirname(filename))

    # Load the TFRecord file and parse each serialized tf.Example message
    dataset <- cienv$tf$data$TFRecordDataset( basename(filename) )
    dataset <- dataset$map(function(x){
      parse_tfr_element(x,
                        readVideo = readVideo,
                        image_dtype = image_dtype)
    })

    index_counter <- last_in_ <- 0L
    return_list <- replicate(length(dataset$element_spec),
                             {list(replicate(length(uniqueKeyIndices), list()))})
  }
  if(!is.null(iterator)){
    # Resume a previously returned iterator; note: last_in_ is 0 indexed
    dataset_iterator <- iterator[[1]]
    last_in_ <- iterator[[2]]
    index_counter <- 0L
    return_list <- replicate(length(dataset_iterator$element_spec),
                             {list(replicate(length(uniqueKeyIndices), list()))})
  }

  # uniqueKeyIndices made 0 indexed
  uniqueKeyIndices <- as.integer( uniqueKeyIndices - 1L )

  # Walk the record stream once in sorted order, skipping unrequested elements.
  for(in_ in (indices_sorted <- sort(uniqueKeyIndices))){
    index_counter <- index_counter + 1

    # First requested element (fresh stream): skip straight to it, then cap
    # the stream at the number of remaining observations.
    if( index_counter == 1 && is.null(iterator) ){
      dataset <- dataset$skip( as.integer(in_) )
      dataset_iterator <- reticulate::as_iterator( dataset$take( as.integer(nObs - as.integer(in_)) ) )
      element <- dataset_iterator$`next`()
    }

    # Subsequent elements (or a resumed iterator): advance over the gap since
    # the last consumed index, then read the target element.
    if(index_counter > 1 || !is.null(iterator)){
      needThisManyUnsavedIters <- (in_ - last_in_ - 1L)
      if(length(needThisManyUnsavedIters) > 0){ if(needThisManyUnsavedIters > 0){
        for(fari in seq_len(needThisManyUnsavedIters)){ dataset_iterator$`next`() }
      } }
      element <- dataset_iterator$`next`()
    }
    last_in_ <- in_

    # form final output
    if(length(uniqueKeyIndices) == 1){ return_list <- element }
    if(length(uniqueKeyIndices) > 1){
      for(li_ in seq_along(element)){
        return_list[[li_]][[index_counter]] <- cienv$tf$expand_dims(element[[li_]], 0L)
      }
    }
    if(index_counter %% 5 == 0){ try(cienv$py_gc$collect(), silent = TRUE) }
  }

  # Stack the gathered tensors along the batch axis (tf$concat accepts the R
  # list of tensors directly; replaces an eval(parse(...)) construction) and
  # restore the caller's requested ordering when it was not ascending.
  if(index_counter > 1){ for(li_ in seq_along(element)){
    return_list[[li_]] <- cienv$tf$concat( return_list[[li_]], 0L )
    if( any(diff(uniqueKeyIndices) < 0) ){ # re-order if needed
      return_list[[li_]] <- cienv$tf$gather(return_list[[li_]],
                                            indices = as.integer(match(uniqueKeyIndices, indices_sorted) - 1L),
                                            axis = 0L)
    }
  } }

  if(return_iterator){
    return_list <- list(return_list, list(dataset_iterator, last_in_))
  }

  return( return_list )
}

# Resolve a tensorflow dtype from either a character name (e.g. "float16") or
# an already-constructed tf dtype object (falls back to its $name field).
ResolveTfDtype <- function(image_dtype){
  dtype_ <- try(cienv$tf[[image_dtype]], silent = TRUE)
  if(inherits(dtype_, "try-error")){
    dtype_ <- try(cienv$tf[[image_dtype$name]], silent = TRUE)
  }
  dtype_
}

# Parse one serialized tf.Example (as written by WriteTfRecord) into
# list(feature_tensor, index, key). `readVideo` selects the video schema
# (time/height/width/channels) vs. the image schema (height/width/channels).
parse_tfr_element <- function(element, readVideo, image_dtype){
  # use the same structure as above; it's kinda an outline of the structure we now want to create
  image_dtype <- ResolveTfDtype(image_dtype)

  dict_init_val <- list()
  if(!readVideo){
    im_feature_description <- dict(
      'height'    = cienv$tf$io$FixedLenFeature(dict_init_val, cienv$tf$int64),
      'width'     = cienv$tf$io$FixedLenFeature(dict_init_val, cienv$tf$int64),
      'channels'  = cienv$tf$io$FixedLenFeature(dict_init_val, cienv$tf$int64),
      'raw_image' = cienv$tf$io$FixedLenFeature(dict_init_val, cienv$tf$string),
      'index'     = cienv$tf$io$FixedLenFeature(dict_init_val, cienv$tf$int64),
      'key'       = cienv$tf$io$FixedLenFeature(dict_init_val, cienv$tf$string)
    )
  }
  if(readVideo){
    im_feature_description <- dict(
      'time'      = cienv$tf$io$FixedLenFeature(dict_init_val, cienv$tf$int64),
      'height'    = cienv$tf$io$FixedLenFeature(dict_init_val, cienv$tf$int64),
      'width'     = cienv$tf$io$FixedLenFeature(dict_init_val, cienv$tf$int64),
      'channels'  = cienv$tf$io$FixedLenFeature(dict_init_val, cienv$tf$int64),
      'raw_image' = cienv$tf$io$FixedLenFeature(dict_init_val, cienv$tf$string),
      'index'     = cienv$tf$io$FixedLenFeature(dict_init_val, cienv$tf$int64),
      'key'       = cienv$tf$io$FixedLenFeature(dict_init_val, cienv$tf$string)
    )
  }

  # parse tf record
  content <- cienv$tf$io$parse_single_example(element,
                                              im_feature_description)

  # get 'feature' (e.g., image/image sequence)
  feature <- cienv$tf$io$parse_tensor( content[['raw_image']],
                                       out_type = image_dtype )

  # get the key
  key <- cienv$tf$io$parse_tensor( content[['key']],
                                   out_type = cienv$tf$string )

  # and reshape it appropriately
  if(!readVideo){
    feature <- cienv$tf$reshape( feature, shape = c(content[['height']],
                                                    content[['width']],
                                                    content[['channels']]) )
  }
  if(readVideo){
    feature <- cienv$tf$reshape( feature, shape = c(content[['time']],
                                                    content[['height']],
                                                    content[['width']],
                                                    content[['channels']]) )
  }

  return( list(feature, content[['index']], key) )
}
-------------------------------------------------------------------------------- /causalimages/R/CI_TrainDefine.R: --------------------------------------------------------------------------------
#' Defines an internal training routine (internal function)
#'
#' Defines trainers defined in TrainDefine(). Internal function.
#'
#' @param . No parameters.
#'
#' @return Internal function defining a training sequence.
#'
#' @import reticulate rrapply
#' @export
#' @md
TrainDefine <- function(){
  # NOTE(review): this routine reads nSGD, learningRateMax, and ModelList from
  # its evaluation environment (no formal parameters) and creates the
  # optimizer objects (opt_state, jit_get_update, jit_apply_updates) that
  # TrainDo() later consumes from that same environment -- confirm against callers.
  message2("Define optimizer and training step...")
  {
    # Warmup-then-cosine-decay schedule: 5% of SGD steps warm up from
    # learningRateMax/100 to learningRateMax, the remainder decays back down.
    LR_schedule <- cienv$optax$warmup_cosine_decay_schedule(
      warmup_steps = (nWarmup <- (0.05*nSGD)),
      decay_steps = nSGD-nWarmup,
      init_value = learningRateMax/100,
      peak_value = learningRateMax,
      end_value = learningRateMax/100)
    # Optimizer: adaptive gradient clipping chained into AdaBelief.
    optax_optimizer <- cienv$optax$chain(
      cienv$optax$adaptive_grad_clip(clipping = 0.25, eps = 0.001),
      cienv$optax$adabelief( learning_rate = LR_schedule )
    )
    # Visualize the schedule once (an accidental duplicate plot of the same
    # curve was removed; only this final rendering was ever left on screen).
    plot(cienv$np$array(LR_schedule(cienv$jnp$array(1:nSGD))), xlab = "Iteration", ylab = "Learning rate")

    # model partition, setup state, perform parameter count
    opt_state <- optax_optimizer$init( cienv$eq$partition(ModelList, cienv$eq$is_array)[[1]] )
    message2(sprintf("Total trainable parameter count: %s",
                     nParamsRep <- nTrainable <-
                       sum(unlist(lapply(cienv$jax$tree$leaves(cienv$eq$partition(ModelList,
                                                                                  cienv$eq$is_array)[[1]]), function(zer){zer$size})))))

    # jit update fxns
    jit_apply_updates <- cienv$eq$filter_jit( cienv$optax$apply_updates )
    jit_get_update <- cienv$eq$filter_jit( optax_optimizer$update )
  }
}

-------------------------------------------------------------------------------- /causalimages/R/CI_TrainDo.R: --------------------------------------------------------------------------------
#' Runs a training routine (internal function)
#'
#' Runs trainers defined in TrainDefine(). Internal function.
#'
#' @param . No parameters.
#'
#' @return Internal function performing model training.
#'
#' @import reticulate rrapply
#' @export
#' @md
TrainDo <- function(){
  # NOTE(review): reads its inputs (imageKeysOfUnits, nSGD, batchSize,
  # TFRecordControl, the tf_dataset_*/ds_iterator_* objects, GetLoss,
  # GradAndLossAndAux, ModelList, MPList, obsY/obsW, X, ...) from its
  # evaluation environment rather than from parameters.
  par(mfrow=c(1,2))
  # map: image key -> indices of all observations sharing that image
  keys2indices_list <- tapply(1:length(imageKeysOfUnits), imageKeysOfUnits, c)
  GradNorm_vec <- loss_vec <- rep(NA,times=nSGD)
  keysUsedInTraining <- c();i_<-1L ; DoneUpdates <- 0L; for(i in i_:nSGD){
    t0 <- Sys.time(); if(i %% 5 == 0 | i == 1){gc(); cienv$py_gc$collect()}

    if(is.null(TFRecordControl)){
      # get next batch
      ds_next_train <- ds_iterator_train$`next`()

      # if we run out of observations, reset iterator
      RestartedIterator <- FALSE; if( is.null(ds_next_train) ){
        message2("Re-setting iterator! (type 1)"); gc(); cienv$py_gc$collect()
        ds_iterator_train <- reticulate::as_iterator( tf_dataset_train )
        ds_next_train <- ds_iterator_train$`next`(); gc();cienv$py_gc$collect()
      }

      # get a new batch if size mismatch - size mismatches generate new cached compiled fxns
      if(!RestartedIterator){ if(dim(ds_next_train[[1]])[1] != batchSize){
        message2("Re-setting iterator! (type 2)"); gc(); cienv$py_gc$collect()
        ds_iterator_train <- reticulate::as_iterator( tf_dataset_train )
        ds_next_train <- ds_iterator_train$`next`(); gc(); cienv$py_gc$collect()
      } }

      # select batch indices based on keys
      batch_keys <- unlist( lapply( p2l(ds_next_train[[3]]$numpy()), as.character) )
      batch_indices <- sapply(batch_keys,function(key_){ f2n( sample(as.character( keys2indices_list[[key_]] ), 1) ) })
      ds_next_train <- ds_next_train[[1]]
    }
    if(!is.null(TFRecordControl)){
      # Balanced sampling path: separate control and treated record streams.
      # get next batch
      ds_next_train_control <- ds_iterator_train_control$`next`()

      # if we run out of observations, reset iterator
      RestartedIterator <- FALSE; if( is.null(ds_next_train_control) ){
        message2("Re-setting iterator! (type 1)"); gc(); cienv$py_gc$collect()
        ds_iterator_train_control <- reticulate::as_iterator( tf_dataset_train_control )
        ds_next_train_control <- ds_iterator_train_control$`next`(); gc();cienv$py_gc$collect()
      }

      # get a new batch if size mismatch - size mismatches generate new cached compiled fxns
      if(!RestartedIterator){ if(dim(ds_next_train_control[[1]])[1] != batchSize){
        message2("Re-setting iterator! (type 2)"); gc(); cienv$py_gc$collect()
        ds_iterator_train_control <- reticulate::as_iterator( tf_dataset_train_control )
        ds_next_train_control <- ds_iterator_train_control$`next`(); gc(); cienv$py_gc$collect()
      } }

      # get next batch
      ds_next_train_treated <- ds_iterator_train_treated$`next`()

      # if we run out of observations, reset iterator
      RestartedIterator <- F; if( is.null(ds_next_train_treated) ){
        message2("Re-setting iterator! (type 1)"); gc(); cienv$py_gc$collect()
        ds_iterator_train_treated <- reticulate::as_iterator( tf_dataset_train_treated )
        ds_next_train_treated <- ds_iterator_train_treated$`next`(); gc();cienv$py_gc$collect()
      }

      # get a new batch if size mismatch - size mismatches generate new cached compiled fxns
      if(!RestartedIterator){ if(dim(ds_next_train_treated[[1]])[1] != batchSize){
        message2("Re-setting iterator! (type 2)"); gc(); cienv$py_gc$collect()
        ds_iterator_train_treated <- reticulate::as_iterator( tf_dataset_train_treated )
        ds_next_train_treated <- ds_iterator_train_treated$`next`(); gc(); cienv$py_gc$collect()
      } }

      # select batch indices based on keys (control first, then treated,
      # matching the concat order of the image tensors below)
      batch_keys <- c(unlist( lapply( p2l(ds_next_train_control[[3]]$numpy()), as.character) ),
                      unlist( lapply( p2l(ds_next_train_treated[[3]]$numpy()), as.character) ))
      batch_indices <- sapply(batch_keys,function(key_){ f2n( sample(as.character( keys2indices_list[[key_]] ), 1) ) })
      ds_next_train <- cienv$tf$concat(list(ds_next_train_control[[1]],
                                            ds_next_train_treated[[1]]), 0L)
    }
    # Track which keys have appeared in at least one training batch.
    # NOTE(review): the condition tests batch_indices against
    # keysUsedInTraining but the appended values are batch_keys -- verify
    # whether the membership test was meant to use batch_keys.
    if(any(!batch_indices %in% keysUsedInTraining)){
      keysUsedInTraining <- c(keysUsedInTraining, batch_keys[!batch_keys %in% keysUsedInTraining])
    }

    # if no treat, define it (unused in GetLoss)
    if(!"obsW" %in% ls()){ obsW <- obsY }

    # training step
    if(!justCheckIterators){
      t1 <- Sys.time()
      if(i == 1){
        # Warm-up forward pass: triggers JIT compilation before timed iterations.
        message2("Initial forward pass...")
        GetLoss(
          MPList[[1]]$cast_to_compute(ModelList), # model list
          MPList[[1]]$cast_to_compute(ModelList_fixed), # model list
          InitImageProcessFn(cienv$jnp$array(ds_next_train), cienv$jax$random$key(600L+i), inference = F), # m
          #InitImageProcessFn(cienv$jnp$array(ds_next_train), cienv$jax$random$key(600L), inference = FALSE), # m
          cienv$jnp$array(ifelse( !is.null(X), yes = list(X[batch_indices,]), no = list(1.))[[1]] , dtype = cienv$jnp$float16), # x
          cienv$jnp$array(as.matrix(obsW[batch_indices]), dtype = cienv$jnp$float16), # treat
          cienv$jnp$array(as.matrix(obsY[batch_indices]), dtype = cienv$jnp$float16), # y
          cienv$jax$random$split(cienv$jax$random$key( 500L+i ),length(batch_indices)), # vseed for observations
          StateList, # StateList
          MPList, # MPlist
          FALSE)[[1]]
      }

      # sanity check
      if(FALSE){
        test_index <- 2
        GetElementFromTfRecordAtIndices(
          uniqueKeyIndices = which(unique(imageKeysOfUnits)==unique(imageKeysOfUnits)[test_index]),
          filename = file,
          readVideo = useVideoIndicator,
          nObs = length(unique(imageKeysOfUnits) ) )
        # unique(imageKeysOfUnits)[test_index]
      }

      # Sanity check for dimension swapping as i varies
      if(i == 1){
        message2(sprintf("Training balance: %s",
                         paste(paste(names(table(obsW[batch_indices])),
                                     table(obsW[batch_indices]), sep = " has "),collapse="; ")
        ))
        if(any(prop.table(table(obsW[batch_indices])) > 0.9) &
           !is.null(TFRecordControl)){
          stop( "Stopping - Balanced training not satisfied despite TFRecordControl being defined!" )
        }
      }
      # causalimages::image2(cienv$np$array(InitImageProcessFn(cienv$jnp$array(ds_next_train), cienv$jax$random$key(600L+sample(1:100,1)), inference = F)[2,,,1]))
      # causalimages::image2(cienv$np$array(InitImageProcessFn(cienv$jnp$array(ds_next_train), cienv$jax$random$key(600L+sample(1:100,1)), inference = F)[1,,,1]))

      # Get gradient update packages
      GradientUpdatePackage <- GradAndLossAndAux(
        MPList[[1]]$cast_to_compute(ModelList), MPList[[1]]$cast_to_compute(ModelList_fixed), # model lists
        InitImageProcessFn(cienv$jnp$array(ds_next_train), cienv$jax$random$key(600L+i), inference = FALSE), # m
        cienv$jnp$array(ifelse( !is.null(X), yes = list(X[batch_indices,]), no = list(1.))[[1]], dtype = ComputeDtype), # x
        cienv$jnp$array(as.matrix(obsW[batch_indices]), dtype = ComputeDtype), # treat (unused for prediction only runs)
        cienv$jnp$array(as.matrix(obsY[batch_indices]), dtype = ComputeDtype), # y
        cienv$jax$random$split(cienv$jax$random$key( 50L+i ),length(batch_indices)), # vseed for observations
        StateList, # StateList
        MPList, # MPlist
        FALSE) # inference

      # perform gradient updates
      {
        # get updated state
        StateList_tmp <- GradientUpdatePackage[[1]][[2]] # state

        # get loss + grad (unscale the loss when training in mixed precision)
        if(image_dtype_char == "float16"){
          loss_vec[i] <- myLoss_fromGrad <- cienv$np$array( MPList[[2]]$unscale( GradientUpdatePackage[[1]][[1]] ) )# value
        }
        if(image_dtype_char != "float16"){
          loss_vec[i] <- myLoss_fromGrad <- cienv$np$array( GradientUpdatePackage[[1]][[1]] )# value
        }
        GradientUpdatePackage <- GradientUpdatePackage[[2]] # grads
        GradientUpdatePackage <- cienv$eq$partition(GradientUpdatePackage, cienv$eq$is_inexact_array)
        GradientUpdatePackage_aux <- GradientUpdatePackage[[2]]; GradientUpdatePackage <- GradientUpdatePackage[[1]]

        # unscale + adjust loss scale is some non-finite or NA
        if(i == 1){
          # One-time construction of jitted helpers for gradient hygiene.
          Map2Zero <- cienv$eq$filter_jit(function(input){
            cienv$jax$tree$map(function(x){ cienv$jnp$where(cienv$jnp$isnan(x), cienv$jnp$array(0), x)}, input) })
          GetGetNorms <- cienv$eq$filter_jit(function(input){
            cienv$jax$tree$map(function(x){ cienv$jnp$mean(cienv$jnp$abs(x)) }, input) })
          AllFinite <- cienv$jax$jit( cienv$jmp$all_finite )
        }
        if(image_dtype_char == "float16"){
          GradientUpdatePackage <- Map2Zero( MPList[[2]]$unscale( GradientUpdatePackage ) )
        }
        AllFinite_DontAdjust <- AllFinite( GradientUpdatePackage ) &
          cienv$jnp$squeeze(cienv$jnp$array(!is.infinite(myLoss_fromGrad)))
        # MPList[[2]]$adjust( cienv$jnp$array(FALSE) )
        # MPList[[2]]$adjust( cienv$jnp$array(TRUE) )
        MPList[[2]] <- MPList[[2]]$adjust( AllFinite_DontAdjust )
        # which(is.na( c(unlist(lapply(cienv$jax$tree$leaves(myGrad_jax), function(zer){cienv$np$array(zer)}))) ) )
        # which(is.infinite( c(unlist(lapply(cienv$jax$tree$leaves(myGrad_jax), function(zer){cienv$np$array(zer)}))) ) )

        # get update norm
        GradNorm_vec[i] <- mean( GradVec <- unlist( lapply(cienv$jax$tree$leaves(GradientUpdatePackage),
                                                           function(zer){ cienv$np$array(cienv$jnp$mean(cienv$jnp$abs(zer) )) }) ) )

        # update parameters if finite gradients
        DoUpdate <- !is.na(myLoss_fromGrad) & cienv$np$array(AllFinite_DontAdjust) &
          !is.infinite(myLoss_fromGrad) & ( GradNorm_vec[i] > 1e-10)
        if( !DoUpdate ){
          message2("Warning: Not updating parameters due to NA, zero, or non-finite gradients in mixed-precision training...")
        }
        if( DoUpdate ){
          DoneUpdates <- DoneUpdates + 1

          # cast updates to param
          GradientUpdatePackage <- MPList[[1]]$cast_to_param( GradientUpdatePackage )

          # get gradient updates
          # GradientUpdatePackage$SpatialTransformer$ResidualWts # check non-zero gradients
          # GradientUpdatePackage$SpatialTransformer$TransformerRenormer # -> check non-zero gradients here (indicates problem with dropout)
          GradientUpdatePackage <- jit_get_update(
            updates = GradientUpdatePackage,
            state = opt_state,
            params = (cienv$eq$partition(ModelList, cienv$eq$is_inexact_array)[[1]] )
          )

          if(FALSE){
            # Debug scaffolding (disabled): compare param/grad tree structures.
            # Before the jit_get_update call, add:
            params_tree = cienv$eq$partition(ModelList, cienv$eq$is_array)[[1]]
            grads_tree = GradientUpdatePackage

            # Print tree structures
            message2("Params tree structure:")
            print(cienv$jax$tree$structure(params_tree))
            message2("Grads tree structure:")
            print(cienv$jax$tree$structure(grads_tree))

            # Check for None values
            params_leaves = cienv$jax$tree$leaves(params_tree)
            grads_leaves = cienv$jax$tree$leaves(grads_tree)
            message2(sprintf("Params leaves: %d, Grads leaves: %d",
                             length(params_leaves), length(grads_leaves)))
          }

          # separate updates from state
          opt_state <- GradientUpdatePackage[[2]]
          GradientUpdatePackage <- cienv$eq$combine(GradientUpdatePackage[[1]],
                                                    GradientUpdatePackage_aux)

          # perform updates
          #ModelList_tminus1 <- ModelList
          ModelList <- cienv$eq$combine( jit_apply_updates(
            params = cienv$eq$partition(ModelList, cienv$eq$is_inexact_array)[[1]],
            updates = GradientUpdatePackage),
            cienv$eq$partition(ModelList, cienv$eq$is_inexact_array)[[2]])
          StateList <- StateList_tmp
          if(FALSE){
            # Debug scaffolding (disabled): layer-wise parameter deltas.
            LayerWiseParamDiff <- function(params_new, params_old){
              diff_fn <- function(new, old){
                cienv$jnp$mean(cienv$jnp$abs(new - old))
              }
              diff_list <- cienv$jax$tree$map(diff_fn, params_new, params_old)
              rrapply::rrapply(diff_list, how = "flatten")
            }

            # Use the function right after parameters update:
            param_diffs <- LayerWiseParamDiff(
              cienv$eq$partition(ModelList, cienv$eq$is_array)[[1]],
              cienv$eq$partition(ModelList_tminus1, cienv$eq$is_array)[[1]]
            )
            GradientUpdatePackage
          }
          suppressWarnings( rm(StateList_tmp, GradientUpdatePackage,BNInfo) )
        }
        # Periodic progress report + loss/grad-norm diagnostic plots.
        i_ <- i ; if( (i %% 25 == 0 | i < 10) &
                      (length(loss_vec[!is.na(loss_vec) & !is.infinite(loss_vec)]) > 5) ){
          loss_vec_ <- loss_vec
          loss_vec_[is.infinite(loss_vec_)] <- NA
          message2(sprintf("SGD iteration %s of %s -- Loss: %.2f (%.1f%%) -- 
                        Total iter time (s): %.2f - Grad iter time (s): %.2f -- 
                        Grad norm: %.3f -- Grads zero %%: %.1f%% -- 
                        %.3f tstat on log(iter)",
                           i, nSGD, loss_vec[i], 100*mean(loss_vec[i] <= loss_vec[1:i],na.rm=T),
                           (Sys.time() - t0)[[1]], (Sys.time() - t1)[[1]],
                           mean(GradVec,na.rm=T), 100*mean(GradVec==0,na.rm=T),
                           ifelse("try-error" %in% class(tstat_ <- try(coef(summary(lm(loss_vec[1:i]~log(1:i))))[2,3], T)),yes = NA, no = tstat_)
          ) )
          loss_vec <- f2n(loss_vec); loss_vec[is.infinite(loss_vec)] <- NA
          plot( (na.omit(loss_vec)), cex.main = 0.95,ylab = "Loss Function",xlab="SGD Iteration Number")
          if(length(na.omit(loss_vec)) > 10){ points(smooth.spline( (na.omit(loss_vec) ),spar=1,cv=TRUE), col="red",type = "l",lwd=5) }
          plot(GradNorm_vec[!is.infinite(GradNorm_vec) & !is.na(GradNorm_vec)], cex.main = 0.95,ylab = "GradNorm",xlab="SGD Iteration Number")
        }

        # Early stopping

        if( !is.null(earlyStopThreshold) ){
          # Compare the last `window` losses against the preceding window; stop
          # after `patience_limit` consecutive windows with no detectable gain.
          window <- 25
          patience_limit <- 25
          if(!"patience_counter" %in% ls()){ patience_counter <- 0 }
          if( i > 2*window & i > 100 ){
            first_avg <- mean(loss_vec[1:10], na.rm = TRUE)
            prev_avg <- mean(loss_vec[(i-2*window):(i-window-1)], na.rm = TRUE)
            curr_avg <- mean(loss_vec[(i-window):i], na.rm = TRUE)

            se_diff <- sqrt( var(loss_vec[(i-2*window):(i-window-1)], na.rm=TRUE)/window +
                               var(loss_vec[(i-window):i], na.rm=TRUE)/window )
            prev_avg_upper <- curr_avg + (t_es<-2.528)*sqrt( var(loss_vec[(i-2*window):(i-window-1)], na.rm=TRUE)/window )
            curr_avg_lower <- prev_avg - t_es*sqrt( var(loss_vec[(i-window):i], na.rm=TRUE)/window )

            if( curr_avg >= prev_avg - t_es*se_diff & curr_avg < 0.8*first_avg ){
              message2("We fail to detect evidence of improvement, early stopping being considered...")
              patience_counter <- patience_counter + 1
              if(patience_counter >= patience_limit){
                message2("Early stopping triggered - No more meaningful improvement.")
                break
              }
            } else {
              patience_counter <- 0 # reset when any improvement detected
            }
          }
        }
      }
    }
  } # end for(i in i_:nSGD){
  par(mfrow=c(1,1))
}

-------------------------------------------------------------------------------- /causalimages/R/CI_helperFxns.R: --------------------------------------------------------------------------------
#' Get the spatial point of long/lat coordinates
#'
#' Convert longitude and latitude coordinates to a different coordinate reference
#' system (CRS).
#'
#' @param long Vector of numeric longitudes.
#' @param lat Vector of numeric latitudes.
#' @param CRS_ref A CRS into which the long-lat point should be projected.
#'
#' @return Numeric vector of length two giving the coordinates of the supplied
#' location in the CRS defined by `CRS_ref`.
#'
#' @examples
#' # (Not run)
#' #spatialPt <- LongLat2CRS(long = 49.932,
#' #                         lat = 35.432,
#' #                         CRS_ref = sf::st_crs("+proj=lcc +lat_1=48 +lat_2=33 +lon_0=-100 +ellps=WGS84"))
#' @export
#' @md
#'
LongLat2CRS <- function(long, lat, CRS_ref){
  # Build an sf point in WGS84 (EPSG:4326), then project it into CRS_ref.
  point_longlat <- sf::st_as_sf(
    data.frame(long = as.numeric(long), lat = as.numeric(lat)),
    coords = c("long", "lat"),
    crs = 4326
  )
  point_longlat_ref <- sf::st_transform(point_longlat, crs = sf::st_crs(CRS_ref))
  coords_ <- sf::st_coordinates(point_longlat_ref)[1, ]
  return(coords_)
}

# Build a square raster::extent of roughly `target_km_diameter` km centered on
# a long/lat point, expressed in the target CRS.
LongLat2CRS_extent <- function(point_longlat,
                               CRS_ref,
                               target_km_diameter = 10){
  target_km <- target_km_diameter
  # ~111 km per degree of latitude; offset is half the requested diameter.
  offset <- 1/111 * (target_km/2)
  point_longlat1 <- c(long = as.numeric(point_longlat[1]) - offset,
                      lat = as.numeric(point_longlat[2]) - offset)
  point_longlat2 <- c(long = as.numeric(point_longlat[1]) + offset,
                      lat = as.numeric(point_longlat[2]) + offset)
  pts <- sf::st_as_sf(rbind(point_longlat1, point_longlat2),
                      coords = c("long", "lat"), crs = 4326)
  pts_ref <- sf::st_transform(pts, crs = sf::st_crs(CRS_ref))
  coords_ <- sf::st_coordinates(pts_ref)
  return(raster::extent(min(coords_[,1]), max(coords_[,1]),
                        min(coords_[,2]), max(coords_[,2])))
}

# converts python builtin to list: a bare python bytes object is wrapped so
# downstream lapply() calls always receive a list
p2l <- function(zer){
  if(inherits(zer, "python.builtin.bytes")){ zer <- list(zer) }
  return( zer )
}

# zips two equal-length lists into a list of pairs
rzip <- function(l1, l2){
  fl <- vector("list", length(l1))  # preallocate instead of growing per iteration
  for(aia in seq_along(l1)){ fl[[aia]] <- list(l1[[aia]], l2[[aia]]) }
  return( fl )
}

# reshapes a batched tensor to (batch, prod(other dims))
reshape_fxn_DEPRECIATED <- function(input_){
  ## DEPRECIATED
  cienv$tf$reshape(input_, list(cienv$tf$shape(input_)[1],
                                cienv$tf$reduce_prod(cienv$tf$shape(input_)[2:5])))
}

# Format numbers as strings with exactly `roundAt` decimal places by padding
# trailing zeros (e.g. 1 -> "1.00", 1.5 -> "1.50"). Inputs are assumed to have
# at most `roundAt` decimals already.
fixZeroEndings <- function(zr, roundAt = 2){
  unlist( lapply(strsplit(as.character(zr), split = "\\."), function(l_){
    if(length(l_) == 1){
      retl <- paste(l_, paste(rep("0", times = roundAt), collapse = ""), sep = ".")
    }
    if(length(l_) == 2){
      retl <- paste(l_[1],
                    paste(l_[2], paste(rep("0", times = roundAt - nchar(l_[2])), collapse = ""), sep = ""),
                    sep = ".")
    }
    return( retl )
  }) )
}

# Coerce `x` to a tensorflow tensor of `dtype`: cast if it already is a
# tensor, otherwise wrap it as a constant.
r2const <- function(x, dtype){
  if(inherits(x, "tensorflow.tensor")){
    x <- cienv$tf$cast(x, dtype = dtype)
  } else {
    x <- cienv$tf$constant(x, dtype = dtype)
  }
  return( x )
}

#' print2 print() with timestamps
#'
#' A function prints a string with date and time.
#'
#' @param text Character string to be printed, with date and time.
#' @param quiet Logical. If TRUE, suppresses the printed output. Default is FALSE.
#'
#' @return Prints with date and time.
#'
#' @examples
#' print2("Hello world")
#' @export
#' @md
#'
print2 <- function(text, quiet = FALSE){
  if(!quiet){ print( sprintf("[%s] %s", format(Sys.time(), "%Y-%m-%d %H:%M:%S"), text) ) }
}

#' message2 message() with timestamps
#'
#' A function that displays a message with date and time.
#'
#' @param text Character string to be displayed as message, with date and time.
#' @param quiet Logical. If TRUE, suppresses the message output. Default is FALSE.
#'
#' @return Displays message with date and time to stderr.
105 | #' 106 | #' @examples 107 | #' message2("Hello world") 108 | #' message2("Process completed", quiet = FALSE) 109 | #' @export 110 | #' @md 111 | #' 112 | message2 <- function(text, quiet = FALSE){ 113 | if(!quiet){ 114 | message(sprintf("[%s] %s", format(Sys.time(), "%Y-%m-%d %H:%M:%S"), text)) 115 | } 116 | } 117 | 118 | # LE <- function(l_, name_){ return( unlist(l_)[[name_]] ) } 119 | # l_ <- DenseList;name <-"Tau_d1" 120 | LE <- function(l_, key) { 121 | # Recursive helper function 122 | search_recursive <- function(list_element, key) { 123 | # Check if the current element is a list 124 | if (is.list(list_element)) { 125 | # If it's a list, check if the key exists in this list 126 | if (key %in% names(list_element)) { 127 | return(list_element[[key]]) 128 | } 129 | # Otherwise, iterate over its elements 130 | for (item in list_element) { 131 | found <- search_recursive(item, key) 132 | if (!is.null(found)) { 133 | return(found) 134 | } 135 | } 136 | } 137 | return(NULL) 138 | } 139 | 140 | # Start the recursive search 141 | return(search_recursive(l_, key)) 142 | } 143 | 144 | LE_index <- function(l_, key) { 145 | # Recursive helper function 146 | search_recursive <- function(list_element, key, path) { 147 | # Check if the current element is a list 148 | if(is.list(list_element)){ 149 | 150 | # If it's a list, check if the key exists in this list 151 | if(key %in% names(list_element) & (length(names(list_element)) == 1)){ 152 | return( c(path, 153 | ifelse("list" %in% class(list_element), yes = 1, no = NULL)) ) 154 | } 155 | if(key %in% names(list_element) & (length(names(list_element)) > 1)){ 156 | return( c(path, 157 | which(names(list_element) == key), 158 | ifelse("list" %in% class(list_element), yes = 1, no = NULL) ) ) 159 | } 160 | 161 | # Otherwise, iterate over its elements 162 | for (i in seq_along(list_element)) { 163 | new_path <- c(path, i) 164 | found <- search_recursive(list_element[[i]], key, new_path) 165 | if (!is.null(found)) { return( 
found ) } 166 | } 167 | } 168 | return(NULL) 169 | } 170 | 171 | # Start the recursive search 172 | return(search_recursive(l_, key, c())) 173 | } 174 | 175 | GlobalPartition <- function(zer, eq_fxn){ 176 | yes_branches <- rrapply::rrapply(zer,f=function(zerr){ 177 | unlist(ifelse(eq_fxn(zerr),yes = list(zerr), no = list(NULL))[[1]]) 178 | },how="list") 179 | no_branches <- rrapply::rrapply(zer,f=function(zerrr){ 180 | unlist(ifelse(eq_fxn(zerrr),yes = list(NULL), no = list(zerrr))[[1]]) 181 | },how="list") 182 | list(yes_branches,no_branches) 183 | } 184 | PartFxn <- function(zerz){ !"first_time_index" %in% names(zerz)} 185 | 186 | AddQuotes <- function(text) { gsub("\\[\\[([A-Za-z]\\w*)", "\\[\\['\\1'", text) } 187 | LinearizeNestedList <- function (NList, 188 | LinearizeDataFrames = FALSE, 189 | NameSep = "/", 190 | ForceNames = FALSE){ 191 | stopifnot(is.character(NameSep), length(NameSep) == 1) 192 | stopifnot(is.logical(LinearizeDataFrames), length(LinearizeDataFrames) == 1) 193 | stopifnot(is.logical(ForceNames), length(ForceNames) == 1) 194 | if (!is.list(NList)) 195 | return(NList) 196 | if (is.null(names(NList)) | ForceNames == TRUE) 197 | names(NList) <- as.character(1:length(NList)) 198 | if (is.data.frame(NList) & LinearizeDataFrames == FALSE) 199 | return(NList) 200 | if (is.data.frame(NList) & LinearizeDataFrames == TRUE) 201 | return(as.list(NList)) 202 | A <- 1 203 | B <- length(NList) 204 | while (A <= B) { 205 | Element <- NList[[A]] 206 | EName <- names(NList)[A] 207 | if (is.list(Element)) { 208 | Before <- if (A == 1) 209 | NULL 210 | else NList[1:(A - 1)] 211 | After <- if (A == B) 212 | NULL 213 | else NList[(A + 1):B] 214 | if (is.data.frame(Element)) { 215 | if (LinearizeDataFrames == TRUE) { 216 | Jump <- length(Element) 217 | NList[[A]] <- NULL 218 | if(is.null(names(Element)) | ForceNames == TRUE) 219 | names(Element) <- as.character(1:length(Element)) 220 | Element <- as.list(Element) 221 | names(Element) <- paste(EName, 
names(Element), 222 | sep = NameSep) 223 | NList <- c(Before, Element, After) 224 | } 225 | Jump <- 1 226 | } 227 | else { 228 | NList[[A]] <- NULL 229 | if (is.null(names(Element)) | ForceNames == TRUE) 230 | names(Element) <- as.character(1:length(Element)) 231 | Element <- LinearizeNestedList(Element, LinearizeDataFrames, 232 | NameSep, ForceNames) 233 | names(Element) <- AddQuotes( paste(EName, 234 | names(Element), 235 | sep = NameSep) ) 236 | Jump <- length(Element) 237 | NList <- c(Before, Element, After) 238 | } 239 | } 240 | else { 241 | Jump <- 1 242 | } 243 | A <- A + Jump 244 | B <- length(NList) 245 | } 246 | return(NList) 247 | } 248 | 249 | 250 | ai <- as.integer 251 | 252 | se <- function(x){ x <- c(na.omit(x)); return(sqrt(var(x)/length(x)))} 253 | 254 | LocalFxnSource <- function(fxn, evaluation_environment){ 255 | fxn_text <- paste(deparse(fxn), collapse="\n") 256 | fxn_text <- gsub(fxn_text,pattern="function \\(\\)", replace="") 257 | eval( parse( text = fxn_text ), envir = evaluation_environment ) 258 | } 259 | 260 | FilterBN <- function(l_){ cienv$eq$partition(l_, function(l__){"first_time_index" %in% names(l__)}) } 261 | 262 | cienv <- new.env( parent = emptyenv() ) 263 | 264 | 265 | dropout_layer_init <- function(p) { 266 | if (p == 0) { return( function(x, key, inference) { return( x ) } ) } 267 | if (p != 0) { 268 | keep_prob <- (1 - p) 269 | return( 270 | function(x, key, inference) { 271 | # Efficient dynamic branch: skip dropout at inference time 272 | cienv$jax$lax$cond( 273 | pred = inference, 274 | true_fun = function(args){ return(args[[1]]) }, 275 | false_fun = function(args){ 276 | mask <- cienv$jax$lax$stop_gradient( 277 | #cienv$jax$random$bernoulli(key = args[[2]], p = keep_prob, shape = args[[1]]$shape)$astype(args[[1]]) 278 | cienv$jax$random$bernoulli(key = args[[2]], p = keep_prob, shape = args[[1]]$shape)$astype( args[[1]]$dtype ) 279 | ) 280 | return( args[[1]] * mask / keep_prob ) 281 | }, 282 | operand = list(x, key) # 
pack arguments 283 | ) } 284 | ) 285 | }} 286 | 287 | wt_init <- function(shape, seed_key){ 288 | init_std <- sqrt(2.0 / as.numeric(shape[[1]] + shape[[1]])) 289 | cienv$jax$random$normal( 290 | key = seed_key, 291 | shape = shape 292 | ) * cienv$jnp$array(init_std)$astype( cienv$jaxFloatType ) 293 | } 294 | 295 | -------------------------------------------------------------------------------- /causalimages/R/CI_image2.R: -------------------------------------------------------------------------------- 1 | #' Visualizing matrices as heatmaps with correct north-south-east-west orientation 2 | #' 3 | #' A function for generating a heatmap representation of a matrix with correct spatial orientation. 4 | #' 5 | #' @param x The numeric matrix to be visualized. 6 | #' @param xlab The x-axis labels. 7 | #' @param ylab The y-axis labels. 8 | #' @param xaxt The x-axis tick labels. 9 | #' @param yaxt The y-axis tick labels. 10 | #' @param main The main figure label. 11 | #' @param cex.main The main figure label sizing factor. 12 | #' @param col.lab Axis label color. 13 | #' @param col.main Main label color. 14 | #' @param cex.lab Cex for the labels. 15 | #' @param box Draw a box around the image? 16 | 17 | #' @return Returns a heatmap representation of the matrix, `x`, with correct north/south/east/west orientation. 
18 | #' 19 | #' @examples 20 | #' #set seed 21 | #' set.seed(1) 22 | #' 23 | #' #Geneate data 24 | #' x <- matrix(rnorm(50*50), ncol = 50) 25 | #' diag(x) <- 3 26 | #' 27 | #' # create plot 28 | #' image2(x, main = "Example Text", cex.main = 2) 29 | #' 30 | #' @export 31 | #' @md 32 | 33 | image2 = function(x, 34 | xaxt=NULL, 35 | yaxt = NULL, 36 | xlab = "", 37 | ylab = "", 38 | main=NULL, 39 | cex.main = NULL, 40 | col.lab = "black", 41 | col.main = "black", 42 | cex.lab = 1.5, 43 | box=F){ 44 | image((t(x)[,nrow(x):1]), 45 | axes = F, 46 | main = main, 47 | xlab = xlab, 48 | ylab = ylab, 49 | xaxs = "i", 50 | cex.lab = cex.lab, 51 | col.lab = col.lab, 52 | col.main = col.main, 53 | cex.main = cex.main) 54 | if(box == T){box()} 55 | if(!is.null(xaxt)){ axis(1, at = 0:(nrow(x)-1)/nrow(x)*1.04, tick=F,labels = (xaxt),cex.axis = 1,las = 1) } 56 | if(!is.null(yaxt)){ axis(2, at = 0:(nrow(x)-1)/nrow(x)*1.04, tick=F,labels = rev(yaxt),cex.axis = 1,las = 2) } 57 | } 58 | 59 | f2n <- function(.){as.numeric(as.character(.))} 60 | -------------------------------------------------------------------------------- /causalimages/data/CausalImagesTutorialData.RData: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cjerzak/causalimages-software/a6ca8dd8aae433c5207dedaa89e2c099d5700b8f/causalimages/data/CausalImagesTutorialData.RData -------------------------------------------------------------------------------- /causalimages/data/datalist: -------------------------------------------------------------------------------- 1 | CausalImagesTutorialData: FullImageArray KeysOfImages KeysOfObservations LongLat obsW obsY X 2 | -------------------------------------------------------------------------------- /causalimages/man/AnalyzeImageConfounding.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in 
R/CI_Confounding.R 3 | \name{AnalyzeImageConfounding} 4 | \alias{AnalyzeImageConfounding} 5 | \title{Perform causal estimation under image confounding} 6 | \usage{ 7 | AnalyzeImageConfounding( 8 | obsW, 9 | obsY, 10 | X = NULL, 11 | file = NULL, 12 | imageKeysOfUnits = NULL, 13 | fileTransport = NULL, 14 | imageKeysOfUnitsTransport = NULL, 15 | nBoot = 10L, 16 | inputAvePoolingSize = 1L, 17 | useTrainingPertubations = T, 18 | useScalePertubations = F, 19 | crossFit = FALSE, 20 | augmented = FALSE, 21 | orthogonalize = F, 22 | transportabilityMat = NULL, 23 | latTransport = NULL, 24 | longTransport = NULL, 25 | lat = NULL, 26 | long = NULL, 27 | conda_env = "CausalImagesEnv", 28 | conda_env_required = T, 29 | Sys.setenv_text = NULL, 30 | figuresTag = NULL, 31 | figuresPath = "./", 32 | plotBands = 1L, 33 | plotResults = T, 34 | XCrossModal = T, 35 | XForceModal = F, 36 | optimizeImageRep = T, 37 | nonLinearScaler = NULL, 38 | nWidth_ImageRep = 64L, 39 | nDepth_ImageRep = 1L, 40 | kernelSize = 5L, 41 | nWidth_Dense = 64L, 42 | nDepth_Dense = 1L, 43 | imageModelClass = "VisionTransformer", 44 | pretrainedModel = NULL, 45 | strides = 2L, 46 | nDepth_TemporalRep = 3L, 47 | patchEmbedDim = 16L, 48 | dropoutRate = 0.1, 49 | droppathRate = 0.1, 50 | batchSize = 16L, 51 | nSGD = 400L, 52 | testFrac = 0.05, 53 | TfRecords_BufferScaler = 4L, 54 | learningRateMax = 0.001, 55 | TFRecordControl = NULL, 56 | dataType = "image", 57 | image_dtype = "float16", 58 | atError = "stop", 59 | seed = NULL 60 | ) 61 | } 62 | \arguments{ 63 | \item{obsW}{A numeric vector where \code{0}'s correspond to control units and \code{1}'s to treated units.} 64 | 65 | \item{obsY}{A numeric vector containing observed outcomes.} 66 | 67 | \item{X}{An optional numeric matrix containing tabular information used if \code{orthogonalize = T}. 
\code{X} is normalized internally and salience maps with respect to \code{X} are transformed back to the original scale.} 68 | 69 | \item{file}{Path to a tfrecord file generated by \code{WriteTfRecord}.} 70 | 71 | \item{imageKeysOfUnits}{A vector of length \code{length(obsY)} specifying the unique image ID associated with each unit. Samples of \code{imageKeysOfUnits} are fed into the package to call images into memory.} 72 | 73 | \item{nBoot}{Number of bootstrap iterations for uncertainty estimation.} 74 | 75 | \item{useTrainingPertubations}{Boolean specifying whether to randomly perturb the image axes during training to reduce overfitting.} 76 | 77 | \item{transportabilityMat}{Optional matrix with a column named \code{imageKeysOfUnits} specifying keys to be used by the package for generating treatment effect predictions for out-of-sample points.} 78 | 79 | \item{long, lat}{Optional vectors specifying longitude and latitude coordinates for units. Used only for describing highest and lowest probability neighborhood units if specified.} 80 | 81 | \item{conda_env}{A \code{conda} environment where computational environment lives, usually created via \code{causalimages::BuildBackend()}. Default = \code{"CausalImagesEnv"}.} 82 | 83 | \item{conda_env_required}{A Boolean stating whether use of the specified conda environment is required.} 84 | 85 | \item{figuresTag}{A string specifying an identifier that is appended to all figure names.} 86 | 87 | \item{figuresPath}{A string specifying file path for saved figures made in the analysis.} 88 | 89 | \item{plotBands}{An integer or vector specifying which band position (from the image representation) should be plotted in the visual results. 
If a vector, \code{plotBands} should have 3 (and only 3) dimensions (corresponding to the 3 dimensions to be used in RGB plotting).} 90 | 91 | \item{plotResults}{(default = \code{T}) Should analysis results be plotted?} 92 | 93 | \item{optimizeImageRep}{Boolean specifying whether to optimize over the image model representation (or only over downstream parameters).} 94 | 95 | \item{nWidth_ImageRep}{Integer specifying width of image model representation.} 96 | 97 | \item{nDepth_ImageRep}{Integer specifying depth of image model representation.} 98 | 99 | \item{kernelSize}{Dimensions used in spatial convolutions.} 100 | 101 | \item{nWidth_Dense}{Integer specifying width of image model representation.} 102 | 103 | \item{nDepth_Dense}{Integer specifying depth of dense model representation.} 104 | 105 | \item{strides}{(default = \code{2L}) Integer specifying the strides used in the convolutional layers.} 106 | 107 | \item{dropoutRate}{Dropout rate used in training to prevent overfitting (\code{dropoutRate = 0} corresponds to no dropout).} 108 | 109 | \item{droppathRate}{Droppath rate used in training to prevent overfitting (\code{droppathRate = 0} corresponds to no droppath).} 110 | 111 | \item{batchSize}{Batch size used in SGD optimization. Default = \code{50L}.} 112 | 113 | \item{nSGD}{Number of stochastic gradient descent (SGD) iterations. Default = \code{400L}} 114 | 115 | \item{testFrac}{Default = \code{0.1}. Fraction of observations held out as a test set to evaluate out-of-sample loss values.} 116 | 117 | \item{TfRecords_BufferScaler}{The buffer size used in \code{tfrecords} mode is \code{batchSize*TfRecords_BufferScaler}. Lower \code{TfRecords_BufferScaler} towards 1 if out-of-memory problems.} 118 | 119 | \item{dataType}{(default = \code{"image"}) String specifying whether to assume \code{"image"} or \code{"video"} data types.} 120 | } 121 | \value{ 122 | Returns a list consisting of 123 | \itemize{ 124 | \item \code{ATE_est} ATE estimate. 
125 | \item \code{ATE_se} Standard error estimate for the ATE. 126 | \item \code{plotResults} If set to \code{TRUE}, causal salience plots are saved to disk, characterizing the image confounding structure. See references for details. 127 | } 128 | } 129 | \description{ 130 | Perform causal estimation under image confounding 131 | } 132 | \section{References}{ 133 | 134 | \itemize{ 135 | \item Connor T. Jerzak, Fredrik Johansson, Adel Daoud. Integrating Earth Observation Data into Causal Inference: Challenges and Opportunities. \emph{ArXiv Preprint}, 2023. 136 | } 137 | } 138 | 139 | \examples{ 140 | # For a tutorial, see 141 | # github.com/cjerzak/causalimages-software/ 142 | 143 | } 144 | -------------------------------------------------------------------------------- /causalimages/man/AnalyzeImageHeterogeneity.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/CI_Heterogeneity.R 3 | \name{AnalyzeImageHeterogeneity} 4 | \alias{AnalyzeImageHeterogeneity} 5 | \title{Decompose treatment effect heterogeneity by image or image sequence} 6 | \usage{ 7 | AnalyzeImageHeterogeneity( 8 | obsW, 9 | obsY, 10 | X = NULL, 11 | orthogonalize = F, 12 | imageKeysOfUnits = 1:length(obsY), 13 | kClust_est = 2, 14 | file = NULL, 15 | transportabilityMat = NULL, 16 | lat = NULL, 17 | long = NULL, 18 | conda_env = "CausalImagesEnv", 19 | conda_env_required = T, 20 | figuresTag = "", 21 | figuresPath = "./", 22 | plotBands = 1L, 23 | heterogeneityModelType = "variational_minimal", 24 | plotResults = F, 25 | optimizeImageRep = T, 26 | nWidth_ImageRep = 64L, 27 | nDepth_ImageRep = 1L, 28 | nWidth_Dense = 64L, 29 | nDepth_Dense = 1L, 30 | nDepth_TemporalRep = 1L, 31 | useTrainingPertubations = T, 32 | strides = 2L, 33 | nonLinearScaler = NULL, 34 | pretrainedModel = NULL, 35 | testFrac = 0.1, 36 | kernelSize = 5L, 37 | learningRateMax = 0.001, 38 | TFRecordControl = NULL, 
39 | patchEmbedDim = 16L, 40 | nSGD = 500L, 41 | batchSize = 16L, 42 | seed = NULL, 43 | Sys.setenv_text = NULL, 44 | imageModelClass = "VisionTransformer", 45 | nMonte_predictive = 10L, 46 | nMonte_salience = 10L, 47 | nMonte_variational = 2L, 48 | TfRecords_BufferScaler = 4L, 49 | temperature = 1, 50 | inputAvePoolingSize = 1L, 51 | dataType = "image" 52 | ) 53 | } 54 | \arguments{ 55 | \item{obsW}{A numeric vector where \code{0}'s correspond to control units and \code{1}'s to treated units.} 56 | 57 | \item{obsY}{A numeric vector containing observed outcomes.} 58 | 59 | \item{X}{Optional numeric matrix containing tabular information used if \code{orthogonalize = T}.} 60 | 61 | \item{orthogonalize}{A Boolean specifying whether to perform the image decomposition after orthogonalizing with respect to tabular covariates specified in \code{X}.} 62 | 63 | \item{imageKeysOfUnits}{A vector of length \code{length(obsY)} specifying the unique image ID associated with each unit. Samples of \code{imageKeysOfUnits} are fed into the package to call images into memory.} 64 | 65 | \item{kClust_est}{Integer specifying the number of clusters used in estimation. Default is \code{2L}.} 66 | 67 | \item{file}{Path to a tfrecord file generated by \code{WriteTfRecord}.} 68 | 69 | \item{transportabilityMat}{An optional matrix with a column named \code{key} specifying keys to be used for generating treatment effect predictions for out-of-sample points in earth observation data settings.} 70 | 71 | \item{long, lat}{Optional vectors specifying longitude and latitude coordinates for units. Used only for describing highest and lowest probability neighborhood units if specified.} 72 | 73 | \item{conda_env}{A \code{conda} environment where computational environment lives, usually created via \code{causalimages::BuildBackend()}. 
Default = \code{"CausalImagesEnv"}.} 74 | 75 | \item{conda_env_required}{A Boolean stating whether use of the specified conda environment is required.} 76 | 77 | \item{figuresTag}{A string specifying an identifier that is appended to all figure names.} 78 | 79 | \item{figuresPath}{A string specifying file path for saved figures made in the analysis.} 80 | 81 | \item{plotBands}{An integer or vector specifying which band position (from the acquired image representation) should be plotted in the visual results. If a vector, \code{plotBands} should have 3 (and only 3) dimensions (corresponding to the 3 dimensions to be used in RGB plotting).} 82 | 83 | \item{plotResults}{Should analysis results be plotted?} 84 | 85 | \item{optimizeImageRep}{Boolean specifying whether to optimize over the image model representation (or only over downstream parameters).} 86 | 87 | \item{nWidth_ImageRep}{Integer specifying width of image model representation.} 88 | 89 | \item{nDepth_ImageRep}{Integer specifying depth of image model representation.} 90 | 91 | \item{nWidth_Dense}{Integer specifying width of image model representation.} 92 | 93 | \item{nDepth_Dense}{Integer specifying depth of dense model representation.} 94 | 95 | \item{strides}{Integer specifying the strides used in the convolutional layers.} 96 | 97 | \item{kernelSize}{Dimensions used in spatial convolutions.} 98 | 99 | \item{nSGD}{Number of stochastic gradient descent (SGD) iterations.} 100 | 101 | \item{batchSize}{Batch size used in SGD optimization.} 102 | 103 | \item{nMonte_predictive}{An integer specifying how many Monte Carlo iterations to use in the calculation 104 | of posterior means (e.g., mean cluster probabilities).} 105 | 106 | \item{nMonte_salience}{An integer specifying how many Monte Carlo iterations to use in the calculation 107 | of the salience maps (e.g., image gradients of expected cluster probabilities).} 108 | 109 | \item{nMonte_variational}{An integer specifying how many Monte Carlo iterations to 
use in the 110 | calculation of the expected likelihood in each training step.} 111 | 112 | \item{TfRecords_BufferScaler}{The buffer size used in \code{tfrecords} mode is \code{batchSize*TfRecords_BufferScaler}. Lower \code{TfRecords_BufferScaler} towards 1 if out-of-memory problems.} 113 | 114 | \item{dataType}{String specifying whether to assume \code{"image"} or \code{"video"} data types.} 115 | } 116 | \value{ 117 | Returns a list consisting of \itemize{ 118 | \item \code{clusterTaus_mean} default 119 | \item \code{clusterProbs_mean}. Estimated mean image effect cluster probabilities. 120 | \item \code{clusterTaus_sigma}. Estimated cluster standard deviations. 121 | \item \code{clusterProbs_lowerConf}. Estimated lower confidence for effect cluster probabilities. 122 | \item \code{impliedATE}. Implied ATE. 123 | \item \code{individualTau_est}. Estimated individual-level image-based treatment effects. 124 | \item \code{transportabilityMat}. Transportability matrix with estimated cluster information. 125 | \item \code{plottedCoordinates}. List containing coordinates plotted in salience maps. 126 | \item \code{whichNA_dropped}. A vector containing observations dropped due to missingness. 127 | } 128 | } 129 | \description{ 130 | Implements the image heterogeneity decomposition analysis of Jerzak, Johansson, and Daoud (2023). Users 131 | input in treatment and outcome data, along with a function specifying how to load in images 132 | using keys referenced to each unit (since loading in all image data will usually not be possible due to memory limitations). 133 | This function by default performs estimation, constructs salience maps, and can optionally perform 134 | estimation for new areas outside the original study sites in a transportability analysis. 135 | } 136 | \section{References}{ 137 | 138 | \itemize{ 139 | \item Connor T. Jerzak, Fredrik Johansson, Adel Daoud. Image-based Treatment Effect Heterogeneity. 
Forthcoming in \emph{Proceedings of the Second Conference on Causal Learning and Reasoning (CLeaR), Proceedings of Machine Learning Research (PMLR)}, 2023. 140 | } 141 | } 142 | 143 | \examples{ 144 | # For a tutorial, see 145 | # github.com/cjerzak/causalimages-software/ 146 | 147 | } 148 | -------------------------------------------------------------------------------- /causalimages/man/BuildBackend.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/CI_BuildBackend.R 3 | \name{BuildBackend} 4 | \alias{BuildBackend} 5 | \title{Build the environment for CausalImages models. Builds a conda environment in which jax, tensorflow, tensorflow-probability optax, equinox, and jmp are installed.} 6 | \usage{ 7 | BuildBackend(conda_env = "CausalImagesEnv", conda = "auto") 8 | } 9 | \arguments{ 10 | \item{conda_env}{(default = \code{"CausalImagesEnv"}) Name of the conda environment in which to place the backends.} 11 | 12 | \item{conda}{(default = \code{auto}) The path to a conda executable. Using \code{"auto"} allows reticulate to attempt to automatically find an appropriate conda binary.} 13 | } 14 | \value{ 15 | Builds the computational environment for \code{causalimages}. This function requires an Internet connection. 16 | You may find out a list of conda Python paths via: \code{system("which python")} 17 | } 18 | \description{ 19 | Build the environment for CausalImages models. Builds a conda environment in which jax, tensorflow, tensorflow-probability optax, equinox, and jmp are installed. 
20 | } 21 | \examples{ 22 | # For a tutorial, see 23 | # github.com/cjerzak/causalimages-software/ 24 | 25 | } 26 | -------------------------------------------------------------------------------- /causalimages/man/GetAndSaveGeolocatedImages.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/CI_GetAndSaveGeolocatedImages.R 3 | \name{GetAndSaveGeolocatedImages} 4 | \alias{GetAndSaveGeolocatedImages} 5 | \title{Getting and saving geo-located images from a pool of .tif's} 6 | \usage{ 7 | GetAndSaveGeolocatedImages( 8 | long, 9 | lat, 10 | keys, 11 | tif_pool, 12 | image_pixel_width = 256L, 13 | save_folder = ".", 14 | save_as = "csv", 15 | lyrs = NULL 16 | ) 17 | } 18 | \arguments{ 19 | \item{long}{Vector of numeric longitudes.} 20 | 21 | \item{lat}{Vector of numeric latitudes.} 22 | 23 | \item{keys}{The image keys associated with the long/lat coordinates.} 24 | 25 | \item{tif_pool}{A character vector specifying the fully qualified path to a corpus of .tif files.} 26 | 27 | \item{image_pixel_width}{An even integer specifying the pixel width (and height) of the saved images.} 28 | 29 | \item{save_folder}{(default = \code{"."}) What folder should be used to save the output? Example: \code{"~/Downloads"}} 30 | 31 | \item{save_as}{(default = \code{".csv"}) What format should the output be saved as? Only one option currently (\code{.csv})} 32 | 33 | \item{lyrs}{(default = NULL) Integer (vector) specifying the layers to be extracted. Default is for all layers to be extracted.} 34 | } 35 | \value{ 36 | Finds the image slice associated with the \code{long} and \code{lat} values, saves images by band (if \code{save_as = "csv"}) in save_folder. 
37 | The save format is: \code{sprintf("\%s/Key\%s_BAND\%s.csv", save_folder, keys[i], band_)} 38 | } 39 | \description{ 40 | A function that finds the image slice associated with the \code{long} and \code{lat} values, saves images by band (if \code{save_as = "csv"}) in save_folder. 41 | } 42 | \examples{ 43 | 44 | # Example use (not run): 45 | #MASTER_IMAGE_POOL_FULL_DIR <- c("./LargeTifs/tif1.tif","./LargeTifs/tif2.tif") 46 | #GetAndSaveGeolocatedImages( 47 | #long = GeoKeyMat$geo_long, 48 | #lat = GeoKeyMat$geo_lat, 49 | #image_pixel_width = 500L, 50 | #keys = row.names(GeoKeyMat), 51 | #tif_pool = MASTER_IMAGE_POOL_FULL_DIR, 52 | #save_folder = "./Data/Uganda2000_processed", 53 | #save_as = "csv", 54 | #lyrs = NULL) 55 | 56 | } 57 | -------------------------------------------------------------------------------- /causalimages/man/GetElementFromTfRecordAtIndices.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/CI_TfRecordFxns.R 3 | \name{GetElementFromTfRecordAtIndices} 4 | \alias{GetElementFromTfRecordAtIndices} 5 | \title{Reads unique key indices from a \code{.tfrecord} file.} 6 | \usage{ 7 | GetElementFromTfRecordAtIndices(uniqueKeyIndices, file, 8 | conda_env, conda_env_required) 9 | } 10 | \arguments{ 11 | \item{uniqueKeyIndices}{(integer vector) Unique image indices to be retrieved from a \code{.tfrecord}} 12 | 13 | \item{conda_env}{(Default = \code{NULL}) A \code{conda} environment where tensorflow v2 lives. 
Used only if a version of tensorflow is not already active.} 14 | 15 | \item{conda_env_required}{(default = \code{F}) A Boolean stating whether use of the specified conda environment is required.} 16 | 17 | \item{file}{(character string) A character string stating the path to a \code{.tfrecord}} 18 | } 19 | \value{ 20 | Returns content from a \code{.tfrecord} associated with \code{uniqueKeyIndices} 21 | } 22 | \description{ 23 | Reads unique key indices from a \code{.tfrecord} file saved via a call to \code{causalimages::WriteTfRecord}. 24 | } 25 | \examples{ 26 | # Example usage (not run): 27 | #GetElementFromTfRecordAtIndices( 28 | #uniqueKeyIndices = 1:10, 29 | #file = "./NigeriaConfoundApp.tfrecord") 30 | 31 | } 32 | -------------------------------------------------------------------------------- /causalimages/man/GetImageRepresentations.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/CI_ImageModelBackbones.R 3 | \name{GetImageRepresentations} 4 | \alias{GetImageRepresentations} 5 | \title{Generates image and video representations useful in earth observation tasks for casual inference.} 6 | \usage{ 7 | GetImageRepresentations( 8 | X = NULL, 9 | imageKeysOfUnits = NULL, 10 | file = NULL, 11 | conda_env = "CausalImagesEnv", 12 | conda_env_required = T, 13 | returnContents = T, 14 | getRepresentations = T, 15 | imageModelClass = "VisionTransformer", 16 | NORM_MEAN = NULL, 17 | NORM_SD = NULL, 18 | Sys.setenv_text = NULL, 19 | InitImageProcess = NULL, 20 | pretrainedModel = NULL, 21 | lat = NULL, 22 | long = NULL, 23 | image_dtype = NULL, 24 | image_dtype_tf = NULL, 25 | XCrossModal = T, 26 | XForceModal = F, 27 | nWidth_ImageRep = 64L, 28 | nDepth_ImageRep = 1L, 29 | nDepth_TemporalRep = 1L, 30 | batchSize = 16L, 31 | nonLinearScaler = NULL, 32 | optimizeImageRep = T, 33 | strides = 1L, 34 | kernelSize = 3L, 35 | patchEmbedDim = 16L, 36 | 
TfRecords_BufferScaler = 10L, 37 | dropoutRate, 38 | droppathRate, 39 | dataType = "image", 40 | bn_momentum = 0.99, 41 | inputAvePoolingSize = 1L, 42 | CleanupEnv = FALSE, 43 | initializingFxns = FALSE, 44 | seed = NULL 45 | ) 46 | } 47 | \arguments{ 48 | \item{imageKeysOfUnits}{A vector of length \code{length(imageKeysOfUnits)} specifying the unique image ID associated with each unit. Samples of \code{imageKeysOfUnits} are fed into the package to call images into memory.} 49 | 50 | \item{file}{Path to a tfrecord file generated by \code{causalimages::WriteTfRecord}.} 51 | 52 | \item{conda_env}{A \code{conda} environment where computational environment lives, usually created via \code{causalimages::BuildBackend()}. Default = \code{"CausalImagesEnv"}} 53 | 54 | \item{conda_env_required}{A Boolean stating whether use of the specified conda environment is required.} 55 | 56 | \item{InitImageProcess}{(default = \code{NULL}) Initial image processing function. Usually left \code{NULL}.} 57 | 58 | \item{nWidth_ImageRep}{Number of embedding features output.} 59 | 60 | \item{batchSize}{Integer specifying batch size in obtaining representations.} 61 | 62 | \item{strides}{Integer specifying the strides used in the convolutional layers.} 63 | 64 | \item{kernelSize}{Dimensions used in the convolution kernels.} 65 | 66 | \item{TfRecords_BufferScaler}{The buffer size used in \code{tfrecords} mode is \code{batchSize*TfRecords_BufferScaler}. Lower \code{TfRecords_BufferScaler} towards 1 if out-of-memory problems.} 67 | 68 | \item{dataType}{String specifying whether to assume \code{"image"} or \code{"video"} data types. Default is \code{"image"}.} 69 | } 70 | \value{ 71 | A list containing two items: 72 | \itemize{ 73 | \item \code{Representations} (matrix) A matrix containing image/video representations, with rows corresponding to observations. 74 | \item \verb{ImageRepArm_OneObs,ImageRepArm_batch_R, ImageRepArm_batch} (functions) Image modeling functions. 
75 | \item \code{ImageModel_And_State_And_MPPolicy_List} List containing image model parameters fed into functions. 76 | } 77 | } 78 | \description{ 79 | Generates image and video representations useful in earth observation tasks for casual inference. 80 | } 81 | \section{References}{ 82 | 83 | \itemize{ 84 | \item Rolf, Esther, et al. "A generalizable and accessible approach to machine learning with global satellite imagery." \emph{Nature Communications} 12.1 (2021): 4392. 85 | } 86 | } 87 | 88 | \examples{ 89 | # For a tutorial, see 90 | # github.com/cjerzak/causalimages-software/ 91 | 92 | } 93 | -------------------------------------------------------------------------------- /causalimages/man/GetMoments.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/CI_GetMoments.R 3 | \name{GetMoments} 4 | \alias{GetMoments} 5 | \title{Get moments for normalization (internal function)} 6 | \usage{ 7 | GetMoments(iterator, dataType, image_dtype, momentCalIters = 34L) 8 | } 9 | \arguments{ 10 | \item{iterator}{An iterator} 11 | 12 | \item{dataType}{A string denoting data type} 13 | 14 | \item{momentCalIters}{Number of minibatches with which to estimate moments} 15 | } 16 | \value{ 17 | Returns mean/sd arrays for normalization. 18 | } 19 | \description{ 20 | An internal function function for obtaining moments for channel normalization. 
21 | } 22 | \examples{ 23 | # (Not run) 24 | # GetMoments(iterator, dataType, image_dtype, momentCalIters = 34L) 25 | } 26 | -------------------------------------------------------------------------------- /causalimages/man/LongLat2CRS.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/CI_helperFxns.R 3 | \name{LongLat2CRS} 4 | \alias{LongLat2CRS} 5 | \title{Get the spatial point of long/lat coordinates} 6 | \usage{ 7 | LongLat2CRS(long, lat, CRS_ref) 8 | } 9 | \arguments{ 10 | \item{long}{Vector of numeric longitudes.} 11 | 12 | \item{lat}{Vector of numeric latitudes.} 13 | 14 | \item{CRS_ref}{A CRS into which the long-lat point should be projected.} 15 | } 16 | \value{ 17 | Numeric vector of length two giving the coordinates of the supplied 18 | location in the CRS defined by \code{CRS_ref}. 19 | } 20 | \description{ 21 | Convert longitude and latitude coordinates to a different coordinate reference 22 | system (CRS). 
23 | } 24 | \examples{ 25 | # (Not run) 26 | #spatialPt <- LongLat2CRS(long = 49.932, 27 | # lat = 35.432, 28 | # CRS_ref = sf::st_crs("+proj=lcc +lat_1=48 +lat_2=33 +lon_0=-100 +ellps=WGS84")) 29 | } 30 | -------------------------------------------------------------------------------- /causalimages/man/PredictiveRun.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/CI_PredictiveRun.R 3 | \name{PredictiveRun} 4 | \alias{PredictiveRun} 5 | \title{Perform predictive modeling using images or videos} 6 | \usage{ 7 | PredictiveRun( 8 | obsY, 9 | imageKeysOfUnits = NULL, 10 | file = NULL, 11 | fileTransport = NULL, 12 | imageKeysOfUnitsTransport = NULL, 13 | nBoot = 10L, 14 | inputAvePoolingSize = 1L, 15 | useTrainingPertubations = T, 16 | useScalePertubations = F, 17 | X = NULL, 18 | conda_env = "CausalImagesEnv", 19 | conda_env_required = T, 20 | Sys.setenv_text = NULL, 21 | figuresTag = NULL, 22 | figuresPath = "./", 23 | plotBands = 1L, 24 | plotResults = T, 25 | XCrossModal = T, 26 | XForceModal = F, 27 | optimizeImageRep = T, 28 | nWidth_ImageRep = 64L, 29 | nDepth_ImageRep = 1L, 30 | kernelSize = 5L, 31 | nWidth_Dense = 64L, 32 | nDepth_Dense = 1L, 33 | imageModelClass = "VisionTransformer", 34 | pretrainedModel = NULL, 35 | strides = 2L, 36 | nonLinearScaler = NULL, 37 | nDepth_TemporalRep = 3L, 38 | patchEmbedDim = 16L, 39 | dropoutRate = 0.1, 40 | droppathRate = 0.1, 41 | batchSize = 16L, 42 | nSGD = 400L, 43 | testFrac = 0.05, 44 | TfRecords_BufferScaler = 4L, 45 | learningRateMax = 0.001, 46 | TFRecordControl = NULL, 47 | dataType = "image", 48 | image_dtype = "float16", 49 | atError = "stop", 50 | seed = NULL, 51 | modelPath = "./trained_model.eqx", 52 | metricsPath = "./evaluation_metrics.rds" 53 | ) 54 | } 55 | \arguments{ 56 | \item{obsY}{A numeric vector containing observed outcomes to predict.} 57 | 58 | \item{imageKeysOfUnits}{A 
vector of length \code{length(obsY)} specifying the unique image ID associated with each unit. Samples of \code{imageKeysOfUnits} are fed into the package to call images into memory.} 59 | 60 | \item{file}{Path to a tfrecord file generated by \code{WriteTfRecord}.} 61 | 62 | \item{nBoot}{Number of bootstrap iterations for uncertainty estimation.} 63 | 64 | \item{useTrainingPertubations}{Boolean specifying whether to randomly perturb the image axes during training to reduce overfitting.} 65 | 66 | \item{X}{An optional numeric matrix containing tabular information. \code{X} is normalized internally.} 67 | 68 | \item{conda_env}{A \code{conda} environment where computational environment lives, usually created via \code{causalimages::BuildBackend()}. Default = \code{"CausalImagesEnv"}.} 69 | 70 | \item{conda_env_required}{A Boolean stating whether use of the specified conda environment is required.} 71 | 72 | \item{figuresTag}{A string specifying an identifier that is appended to all figure names.} 73 | 74 | \item{figuresPath}{A string specifying file path for saved figures made in the analysis.} 75 | 76 | \item{plotBands}{An integer or vector specifying which band position (from the image representation) should be plotted in the visual results. 
If a vector, \code{plotBands} should have 3 (and only 3) dimensions (corresponding to the 3 dimensions to be used in RGB plotting).} 77 | 78 | \item{plotResults}{(default = \code{T}) Should analysis results be plotted?} 79 | 80 | \item{optimizeImageRep}{Boolean specifying whether to optimize over the image model representation (or only over downstream parameters).} 81 | 82 | \item{nWidth_ImageRep}{Integer specifying width of image model representation.} 83 | 84 | \item{nDepth_ImageRep}{Integer specifying depth of image model representation.} 85 | 86 | \item{kernelSize}{Dimensions used in spatial convolutions.} 87 | 88 | \item{nWidth_Dense}{Integer specifying width of image model representation.} 89 | 90 | \item{nDepth_Dense}{Integer specifying depth of dense model representation.} 91 | 92 | \item{strides}{(default = \code{2L}) Integer specifying the strides used in the convolutional layers.} 93 | 94 | \item{dropoutRate}{Dropout rate used in training to prevent overfitting (\code{dropoutRate = 0} corresponds to no dropout).} 95 | 96 | \item{droppathRate}{Droppath rate used in training to prevent overfitting (\code{droppathRate = 0} corresponds to no droppath).} 97 | 98 | \item{batchSize}{Batch size used in SGD optimization. Default = \code{50L}.} 99 | 100 | \item{nSGD}{Number of stochastic gradient descent (SGD) iterations. Default = \code{400L}} 101 | 102 | \item{testFrac}{Default = \code{0.1}. Fraction of observations held out as a test set to evaluate out-of-sample loss values.} 103 | 104 | \item{TfRecords_BufferScaler}{The buffer size used in \code{tfrecords} mode is \code{batchSize*TfRecords_BufferScaler}. Lower \code{TfRecords_BufferScaler} towards 1 if out-of-memory problems.} 105 | 106 | \item{dataType}{(default = \code{"image"}) String specifying whether to assume \code{"image"} or \code{"video"} data types.} 107 | 108 | \item{modelPath}{Path to save the trained model. 
Default = \code{"./trained_model.eqx"}.} 109 | 110 | \item{metricsPath}{Path to save the evaluation metrics as a RDS file. Default = \code{"./evaluation_metrics.rds"}.} 111 | 112 | \item{transportabilityMat}{Optional matrix with a column named \code{imageKeysOfUnits} specifying keys to be used by the package for generating predictions for out-of-sample points.} 113 | } 114 | \value{ 115 | Returns a list consisting of 116 | \itemize{ 117 | \item \code{predictedY} Predicted values for all units. 118 | \item \code{ModelEvaluationMetrics} Rigorous evaluation metrics (e.g., MSE, R2 for continuous; AUC, accuracy for binary). 119 | } 120 | } 121 | \description{ 122 | Perform predictive modeling using images or videos 123 | } 124 | \section{References}{ 125 | 126 | \itemize{ 127 | \item Connor T. Jerzak, Fredrik Johansson, Adel Daoud. Integrating Earth Observation Data into Causal Inference: Challenges and Opportunities. \emph{ArXiv Preprint}, 2023. 128 | } 129 | } 130 | 131 | \examples{ 132 | # For a tutorial, see 133 | # github.com/cjerzak/causalimages-software/ 134 | 135 | } 136 | -------------------------------------------------------------------------------- /causalimages/man/TFRecordManagement.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/CI_TFRecordManagement.R 3 | \name{TFRecordManagement} 4 | \alias{TFRecordManagement} 5 | \title{Defines an internal TFRecord management routine (internal function)} 6 | \usage{ 7 | TFRecordManagement() 8 | } 9 | \arguments{ 10 | \item{.}{No parameters.} 11 | } 12 | \value{ 13 | Internal function defining a tfrecord management sequence. 14 | } 15 | \description{ 16 | Defines management defined in TFRecordManagement(). Internal function. 
17 | } 18 | -------------------------------------------------------------------------------- /causalimages/man/TrainDefine.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/CI_TrainDefine.R 3 | \name{TrainDefine} 4 | \alias{TrainDefine} 5 | \title{Defines an internal training routine (internal function)} 6 | \usage{ 7 | TrainDefine() 8 | } 9 | \arguments{ 10 | \item{.}{No parameters.} 11 | } 12 | \value{ 13 | Internal function defining a training sequence. 14 | } 15 | \description{ 16 | Defines trainers defined in TrainDefine(). Internal function. 17 | } 18 | -------------------------------------------------------------------------------- /causalimages/man/TrainDo.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/CI_TrainDo.R 3 | \name{TrainDo} 4 | \alias{TrainDo} 5 | \title{Runs a training routine (internal function)} 6 | \usage{ 7 | TrainDo() 8 | } 9 | \arguments{ 10 | \item{.}{No parameters.} 11 | } 12 | \value{ 13 | Internal function performing model training. 14 | } 15 | \description{ 16 | Runs trainers defined in TrainDefine(). Internal function. 
17 | } 18 | -------------------------------------------------------------------------------- /causalimages/man/WriteTfRecord.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/CI_TfRecordFxns.R 3 | \name{WriteTfRecord} 4 | \alias{WriteTfRecord} 5 | \title{Write an image corpus as a .tfrecord file} 6 | \usage{ 7 | WriteTfRecord( 8 | file, 9 | uniqueImageKeys, 10 | acquireImageFxn, 11 | writeVideo = F, 12 | image_dtype = "float16", 13 | conda_env = "CausalImagesEnv", 14 | conda_env_required = T, 15 | Sys.setenv_text = NULL 16 | ) 17 | } 18 | \arguments{ 19 | \item{file}{A character string naming a file for writing.} 20 | 21 | \item{uniqueImageKeys}{A vector specifying the unique image keys of the corpus. 22 | A key grabs an image/video array via \code{acquireImageFxn(key)}.} 23 | 24 | \item{acquireImageFxn}{A function whose input is an observation keys and whose output is an array with dimensions \verb{(length(keys), nSpatialDim1, nSpatialDim2, nChannels)} for images and \verb{(length(keys), nTimeSteps, nSpatialDim1, nSpatialDim2, nChannels)} for image sequence data.} 25 | 26 | \item{writeVideo}{(default = \code{FALSE}) Should we assume we're writing image sequence data of form batch by time by height by width by channels?} 27 | 28 | \item{conda_env}{(default = \code{"CausalImagesEnv"}) A \code{conda} environment where computational environment lives, usually created via \code{causalimages::BuildBackend()}} 29 | 30 | \item{conda_env_required}{(default = \code{T}) A Boolean stating whether use of the specified conda environment is required.} 31 | } 32 | \value{ 33 | Writes a unique key-referenced \code{.tfrecord} from an image/video corpus for use in image-based causal inference training. 34 | } 35 | \description{ 36 | Writes an image corpus to a \code{.tfrecord} file for rapid reading of images into memory for fast ML training. 
37 | Specifically, this function serializes an image or video corpus into a \code{.tfrecord} file, enabling efficient data loading for machine learning tasks, particularly for image-based causal inference training. 38 | It requires that users define an \code{acquireImageFxn} function that accepts keys and returns the corresponding image or video as an array of dimensions \verb{(length(keys), nSpatialDim1, nSpatialDim2, nChannels)} for images or \verb{(length(keys), nTimeSteps, nSpatialDim1, nSpatialDim2, nChannels)} for video sequences. 39 | } 40 | \examples{ 41 | # Example usage (not run): 42 | #WriteTfRecord( 43 | # file = "./NigeriaConfoundApp.tfrecord", 44 | # uniqueImageKeys = 1:n, 45 | # acquireImageFxn = acquireImageFxn) 46 | 47 | } 48 | -------------------------------------------------------------------------------- /causalimages/man/image2.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/CI_image2.R 3 | \name{image2} 4 | \alias{image2} 5 | \title{Visualizing matrices as heatmaps with correct north-south-east-west orientation} 6 | \usage{ 7 | image2( 8 | x, 9 | xaxt = NULL, 10 | yaxt = NULL, 11 | xlab = "", 12 | ylab = "", 13 | main = NULL, 14 | cex.main = NULL, 15 | col.lab = "black", 16 | col.main = "black", 17 | cex.lab = 1.5, 18 | box = F 19 | ) 20 | } 21 | \arguments{ 22 | \item{x}{The numeric matrix to be visualized.} 23 | 24 | \item{xaxt}{The x-axis tick labels.} 25 | 26 | \item{yaxt}{The y-axis tick labels.} 27 | 28 | \item{xlab}{The x-axis labels.} 29 | 30 | \item{ylab}{The y-axis labels.} 31 | 32 | \item{main}{The main figure label.} 33 | 34 | \item{cex.main}{The main figure label sizing factor.} 35 | 36 | \item{col.lab}{Axis label color.} 37 | 38 | \item{col.main}{Main label color.} 39 | 40 | \item{cex.lab}{Cex for the labels.} 41 | 42 | \item{box}{Draw a box around the image?} 43 | } 44 | \value{ 45 | Returns a heatmap 
representation of the matrix, \code{x}, with correct north/south/east/west orientation. 46 | } 47 | \description{ 48 | A function for generating a heatmap representation of a matrix with correct spatial orientation. 49 | } 50 | \examples{ 51 | #set seed 52 | set.seed(1) 53 | 54 | # Generate data 55 | x <- matrix(rnorm(50*50), ncol = 50) 56 | diag(x) <- 3 57 | 58 | # create plot 59 | image2(x, main = "Example Text", cex.main = 2) 60 | 61 | } 62 | -------------------------------------------------------------------------------- /causalimages/man/message2.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/CI_helperFxns.R 3 | \name{message2} 4 | \alias{message2} 5 | \title{message2 message() with timestamps} 6 | \usage{ 7 | message2(text, quiet = FALSE) 8 | } 9 | \arguments{ 10 | \item{text}{Character string to be displayed as message, with date and time.} 11 | 12 | \item{quiet}{Logical. If TRUE, suppresses the message output. Default is FALSE.} 13 | } 14 | \value{ 15 | Displays message with date and time to stderr. 16 | } 17 | \description{ 18 | A function that displays a message with date and time. 19 | } 20 | \examples{ 21 | message2("Hello world") 22 | message2("Process completed", quiet = FALSE) 23 | } 24 | -------------------------------------------------------------------------------- /causalimages/man/print2.Rd: -------------------------------------------------------------------------------- 1 | % Generated by roxygen2: do not edit by hand 2 | % Please edit documentation in R/CI_helperFxns.R 3 | \name{print2} 4 | \alias{print2} 5 | \title{print2 print() with timestamps} 6 | \usage{ 7 | print2(text, quiet = F) 8 | } 9 | \arguments{ 10 | \item{text}{Character string to be printed, with date and time.} 11 | } 12 | \value{ 13 | Prints with date and time. 14 | } 15 | \description{ 16 | A function that prints a string with date and time.
17 | } 18 | \examples{ 19 | print2("Hello world") 20 | } 21 | -------------------------------------------------------------------------------- /causalimages/tests/Test_AAARunAllTutorialsSuite.R: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env Rscript 2 | { 3 | # remove all dim calls to arrays 4 | ########################################## 5 | # Code for testing most functionalities of CausalImage on your hardware. 6 | # Current tests failing on optimizing video representation on METAL. Other tests succeed. 7 | ########################################## 8 | tryTests <- try({ 9 | # remote install latest version of the package 10 | # devtools::install_github(repo = "cjerzak/causalimages-software/causalimages") 11 | 12 | # local install for development team 13 | # install.packages("~/Documents/causalimages-software/causalimages",repos = NULL, type = "source",force = F) 14 | 15 | # Before running tests, it may be necessary to (re)build the backend: 16 | # causalimages::BuildBackend() 17 | # See ?causalimages::BuildBackend for help. Remember to re-start your R session after re-building.
18 | 19 | options(error = NULL) 20 | 21 | print("Starting image TfRecords"); setwd("~"); 22 | TfRecordsTest <- try(source("~/Documents/causalimages-software/causalimages/tests/Test_UsingTfRecords.R"),T) 23 | if("try-error" %in% class(TfRecordsTest)){ stop("Failed at TfRecordsTest (1)") }; try(dev.off(), T) 24 | 25 | print("Starting ImageRepTest"); setwd("~"); 26 | ImageRepTest <- try(source("~/Documents/causalimages-software/causalimages/tests/Test_ExtractImageRepresentations.R"),T) 27 | if("try-error" %in% class(ImageRepTest)){ stop("Failed at ImageRepTest (2)") }; try(dev.off(), T) 28 | 29 | print("Starting ImConfoundTest"); setwd("~"); 30 | ImConfoundTest <- try(source("~/Documents/causalimages-software/causalimages/tests/Test_AnalyzeImageConfounding.R"),T) 31 | if("try-error" %in% class(ImConfoundTest)){ stop("Failed at ImConfoundTest (3)") }; try(dev.off(), T) 32 | 33 | #print("Starting HetTest"); setwd("~"); 34 | #HetTest <- try(source("~/Documents/causalimages-software/causalimages/tests/Test_AnalyzeImageHeterogeneity.R"),T) 35 | #if("try-error" %in% class(HetTest)){ stop("Failed at HetTest") }; try(dev.off(), T) 36 | }, T) 37 | 38 | if('try-error' %in% class(tryTests)){ print("At least one test failed"); print( tryTests ); stop(tryTests) } 39 | if(!'try-error' %in% class(tryTests)){ print("All tests succeeded!") } 40 | } 41 | -------------------------------------------------------------------------------- /causalimages/tests/Test_AnalyzeImageConfounding.R: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env Rscript 2 | { 3 | ################################ 4 | # Image confounding tutorial using causalimages 5 | ################################ 6 | setwd("~/Downloads/"); options( error = NULL ) 7 | #setwd("./"); options( error = NULL ) 8 | 9 | # remote install latest version of the package if needed 10 | # devtools::install_github(repo = "cjerzak/causalimages-software/causalimages") 11 | 12 | # local install
for development team 13 | # install.packages("~/Documents/causalimages-software/causalimages",repos = NULL, type = "source",force = F) 14 | 15 | # build backend you haven't ready: 16 | # causalimages::BuildBackend() 17 | 18 | # load in package 19 | library( causalimages ) 20 | 21 | # resave TfRecords? 22 | reSaveTFRecord <- F 23 | 24 | # load in tutorial data 25 | data( CausalImagesTutorialData ) 26 | 27 | # mean imputation for toy example in this tutorial 28 | X <- apply(X[,-1],2,function(zer){ 29 | zer[is.na(zer)] <- mean( zer,na.rm = T ); return( zer ) 30 | }) 31 | 32 | # select observation subset to make the tutorial quick 33 | set.seed(4321L);take_indices <- unlist( tapply(1:length(obsW),obsW,function(zer){sample(zer, 100)}) ) 34 | 35 | # perform causal inference with image and tabular confounding 36 | { 37 | # example acquire image function (loading from memory) 38 | # in general, you'll want to write a function that returns images 39 | # that saved disk associated with keys 40 | acquireImageFxn <- function(keys){ 41 | # here, the function input keys 42 | # refers to the unit-associated image keys 43 | # we also tweak the image dimensions for testing purposes 44 | #m_ <- FullImageArray[match(keys, KeysOfImages),c(1:35,1:35),c(1:35,1:35),1:2] # test with two channels 45 | #m_ <- FullImageArray[match(keys, KeysOfImages),c(1:35,1:35),c(1:35,1:35),] # test with three channels 46 | m_ <- FullImageArray[match(keys, KeysOfImages),c(1:35,1:35),c(1:35,1:35),c(1:3,1:2)] # test with five channels 47 | 48 | # if keys == 1, add the batch dimension so output dims are always consistent 49 | # (here in image case, dims are batch by height by width by channel) 50 | if(length(keys) == 1){ m_ <- array(m_,dim = c(1L,dim(m_)[1],dim(m_)[2],dim(m_)[3])) } 51 | 52 | return( m_ ) 53 | } 54 | 55 | # write tf record 56 | TFRecordName_im <- "./ImageTutorial/TutorialData_im.tfrecord" 57 | if( reSaveTFRecord ){ 58 | causalimages::WriteTfRecord( 59 | file = TFRecordName_im, 60 | 
uniqueImageKeys = unique(KeysOfObservations[ take_indices ]), 61 | acquireImageFxn = acquireImageFxn) 62 | } 63 | 64 | for(ImageModelClass in (c("VisionTransformer","CNN"))){ 65 | for(optimizeImageRep in c(T,F)){ 66 | print(sprintf("Image confounding analysis & optimizeImageRep: %s & ImageModelClass: %s",optimizeImageRep, ImageModelClass)) 67 | ImageConfoundingAnalysis <- causalimages::AnalyzeImageConfounding( 68 | obsW = obsW[ take_indices ], 69 | obsY = obsY[ take_indices ], 70 | X = X[ take_indices,apply(X[ take_indices,],2,sd)>0], 71 | long = LongLat$geo_long[ take_indices ], # optional argument 72 | lat = LongLat$geo_lat[ take_indices ], # optional argument 73 | imageKeysOfUnits = KeysOfObservations[ take_indices ], 74 | file = TFRecordName_im, 75 | 76 | batchSize = 16L, 77 | nBoot = 5L, 78 | optimizeImageRep = optimizeImageRep, 79 | imageModelClass = ImageModelClass, 80 | nDepth_ImageRep = ifelse(optimizeImageRep, yes = 1L, no = 1L), 81 | nWidth_ImageRep = as.integer(2L^6), 82 | learningRateMax = 0.001, nSGD = 10L, # 83 | dropoutRate = NULL, # 0.1, 84 | plotBands = c(1,2,3), 85 | plotResults = T, figuresTag = "ConfoundingImTutorial", 86 | figuresPath = "./ImageTutorial") 87 | try(dev.off(), T) 88 | #ImageConfoundingAnalysis$ModelEvaluationMetrics 89 | } 90 | } 91 | 92 | # ATE estimate (image confounder adjusted) 93 | ImageConfoundingAnalysis$tauHat_propensityHajek 94 | 95 | # ATE se estimate (image confounder adjusted) 96 | ImageConfoundingAnalysis$tauHat_propensityHajek_se 97 | 98 | # some out-of-sample evaluation metrics 99 | ImageConfoundingAnalysis$ModelEvaluationMetrics 100 | 101 | } 102 | 103 | # perform causal inference with image *sequence* and tabular confounding 104 | { 105 | acquireVideoRep <- function(keys) { 106 | # Note: this is a toy function generating image representations 107 | # that simply reuse a single temporal slice. In practice, we will 108 | # weant to read in images of different time periods. 
109 | 110 | # Get image data as an array from disk 111 | tmp <- acquireImageFxn(keys) 112 | 113 | # Expand dimensions: we create a new dimension at the start 114 | tmp <- array(tmp, dim = c(1, dim(tmp))) 115 | 116 | # Transpose dimensions to get the target order 117 | tmp <- aperm(tmp, c(2, 1, 3, 4, 5)) 118 | 119 | # Swap image dimensions to see variability across time 120 | tmp_ <- aperm(tmp, c(1, 2, 4, 3, 5)) 121 | 122 | # Concatenate along the second axis 123 | tmp <- abind::abind(tmp, tmp, tmp_, tmp_, along = 2) 124 | 125 | return(tmp) 126 | } 127 | 128 | # write tf record 129 | TFRecordName_imSeq <- "./ImageTutorial/TutorialData_imSeq.tfrecord" 130 | if( reSaveTFRecord ){ 131 | causalimages::WriteTfRecord( 132 | file = TFRecordName_imSeq, 133 | uniqueImageKeys = unique(KeysOfObservations[ take_indices ]), 134 | acquireImageFxn = acquireVideoRep, 135 | writeVideo = T) 136 | } 137 | 138 | for(ImageModelClass in c("VisionTransformer","CNN")){ 139 | for(optimizeImageRep in c(T, F)){ 140 | print(sprintf("Image seq confounding analysis & optimizeImageRep: %s & ImageModelClass: %s",optimizeImageRep, ImageModelClass)) 141 | ImageSeqConfoundingAnalysis <- causalimages::AnalyzeImageConfounding( 142 | obsW = obsW[ take_indices ], 143 | obsY = obsY[ take_indices ], 144 | X = X[ take_indices,apply(X[ take_indices,],2,sd)>0], 145 | long = LongLat$geo_long[ take_indices ], 146 | lat = LongLat$geo_lat[ take_indices ], 147 | file = TFRecordName_imSeq, dataType = "video", 148 | imageKeysOfUnits = KeysOfObservations[ take_indices ], 149 | 150 | # model specifics 151 | batchSize = 16L, 152 | optimizeImageRep = optimizeImageRep, 153 | imageModelClass = ImageModelClass, 154 | nDepth_ImageRep = ifelse(optimizeImageRep, yes = 1L, no = 1L), 155 | learningRateMax = 0.001, 156 | nSGD = 50L, # 157 | nWidth_ImageRep = as.integer(2L^7), 158 | nBoot = 5L, 159 | plotBands = c(1,2,3), 160 | plotResults = T, figuresTag = "ConfoundingImSeqTutorial", 161 | figuresPath = "./ImageTutorial") # 
figures saved here 162 | try(dev.off(), T) 163 | } 164 | } 165 | 166 | # ATE estimate (image confounder adjusted) 167 | ImageSeqConfoundingAnalysis$tauHat_propensityHajek 168 | 169 | # ATE se estimate (image seq confounder adjusted) 170 | ImageSeqConfoundingAnalysis$tauHat_propensityHajek_se 171 | 172 | # some out-of-sample evaluation metrics 173 | ImageSeqConfoundingAnalysis$ModelEvaluationMetrics 174 | print("Done with confounding test!") 175 | } 176 | } 177 | -------------------------------------------------------------------------------- /causalimages/tests/Test_AnalyzeImageHeterogeneity.R: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env Rscript 2 | { 3 | ################################ 4 | # Image heterogeneity tutorial using causalimages 5 | ################################ 6 | 7 | # remote install latest version of the package 8 | # devtools::install_github(repo = "cjerzak/causalimages-software/causalimages") 9 | 10 | # local install for development team 11 | # install.packages("~/Documents/causalimages-software/causalimages",repos = NULL, type = "source",force = F) 12 | 13 | # build backend you haven't ready: 14 | # causalimages::BuildBackend() 15 | 16 | # run code if downloading data for the first time 17 | download_folder <- "~/Downloads/UgandaAnalysis" 18 | reSaveTfRecords <- F 19 | if( reDownloadRawData <- F ){ 20 | 21 | # (1) Specify the Dataverse dataset DOI 22 | doi <- "doi:10.7910/DVN/O8XOSF" 23 | 24 | # (2) Construct the download URL for the entire dataset as a .zip 25 | # By convention on Harvard Dataverse, this API endpoint: 26 | # https://{server}/api/access/dataset/:persistentId?persistentId={doi}&format=original 27 | # downloads all files in a single zip. 
28 | base_url <- "https://dataverse.harvard.edu" 29 | download_url <- paste0(base_url, 30 | "/api/access/dataset/:persistentId", 31 | "?persistentId=", doi, 32 | "&format=original") 33 | 34 | # (3) Download the ZIP file 35 | destfile <- "~/Downloads/UgandaAnalysis.zip" 36 | download.file(download_url, destfile = destfile, mode = "wb") 37 | 38 | # unzip and list files 39 | unzip(destfile, exdir = "~/Downloads/UgandaAnalysis") 40 | unzip('./UgandaAnalysis/Uganda2000_processed.zip', exdir = "~/Downloads/UgandaAnalysis") 41 | } 42 | 43 | # load in package 44 | library( causalimages ); options(error = NULL) 45 | 46 | # set new wd 47 | setwd( "~/Downloads" ) 48 | 49 | # see directory contents 50 | list.files() 51 | 52 | # images saved here 53 | list.files( "./UgandaAnalysis/Uganda2000_processed" ) 54 | 55 | # individual-level data 56 | UgandaDataProcessed <- read.csv( "./UgandaAnalysis/UgandaDataProcessed.csv" ) 57 | 58 | # unit-level covariates (many covariates are subject to missingness!) 59 | dim( UgandaDataProcessed ) 60 | table( UgandaDataProcessed$age ) 61 | 62 | # approximate longitude + latitude for units 63 | head( cbind(UgandaDataProcessed$geo_long, UgandaDataProcessed$geo_lat) ) 64 | 65 | # image keys of units (use for referencing satellite images) 66 | UgandaDataProcessed$geo_long_lat_key 67 | 68 | # an experimental outcome 69 | UgandaDataProcessed$Yobs 70 | 71 | # treatment variable 72 | UgandaDataProcessed$Wobs 73 | 74 | # information on keys linking to satellite images for all of Uganda 75 | # (not just experimental context, use for constructing transportability maps) 76 | UgandaGeoKeyMat <- read.csv( "./UgandaAnalysis/UgandaGeoKeyMat.csv" ) 77 | 78 | # set outcome to an income index 79 | UgandaDataProcessed$Yobs <- UgandaDataProcessed$income_index_e_RECREATED 80 | 81 | # drop observations with NAs in key variables 82 | # (you can also use a multiple imputation strategy) 83 | UgandaDataProcessed <- UgandaDataProcessed[!is.na(UgandaDataProcessed$Yobs) & 84 
| !is.na(UgandaDataProcessed$Wobs) & 85 | !is.na(UgandaDataProcessed$geo_lat) , ] 86 | 87 | # sanity checks 88 | { 89 | # write a function that reads in images as saved and process them into an array 90 | NBANDS <- 3L 91 | imageHeight <- imageWidth <- 351L # pixel height/width 92 | acquireImageRep <- function(keys){ 93 | # initialize an array shell to hold image slices 94 | array_shell <- array(NA, dim = c(1L, imageHeight, imageWidth, NBANDS)) 95 | 96 | # iterate over keys: 97 | # -- images are referenced to keys 98 | # -- keys are referenced to units (to allow for duplicate images uses) 99 | array_ <- sapply(keys, function(key_) { 100 | # iterate over all image bands (NBANDS = 3 for RBG images) 101 | for (band_ in 1:NBANDS) { 102 | # place the image in the correct place in the array 103 | array_shell[,,,band_] <- 104 | as.matrix(data.table::fread( 105 | input = sprintf("./UgandaAnalysis/Uganda2000_processed/GeoKey%s_BAND%s.csv", key_, band_), header = FALSE)[-1,]) 106 | } 107 | return(array_shell) 108 | }, simplify = "array") 109 | 110 | # return the array in the format c(nBatch, imageWidth, imageHeight, nChannels) 111 | # ensure that the dimensions are correctly ordered for further processing 112 | if(length(keys) > 1){ array_ <- aperm(array_[1,,,,], c(4, 1, 2, 3) ) } 113 | if(length(keys) == 1){ 114 | array_ <- aperm(array_, c(1,5, 2, 3, 4)) 115 | array_ <- array(array_, dim(array_)[-1]) 116 | } 117 | 118 | return(array_) 119 | } 120 | 121 | # try out the function 122 | # note: some units are co-located in same area (hence, multiple observations per image key) 123 | ImageBatch <- acquireImageRep( UgandaDataProcessed$geo_long_lat_key[ check_indices <- c(1, 20, 50, 101) ]) 124 | acquireImageRep( UgandaDataProcessed$geo_long_lat_key[ check_indices[1] ] ) 125 | 126 | # sanity checks in the analysis of earth observation data are essential 127 | # check that images are centered around correct location 128 | causalimages::image2( as.array(ImageBatch)[1,,,1] ) 129 | 
UgandaDataProcessed$geo_long[check_indices[1]] 130 | UgandaDataProcessed$geo_lat[check_indices[1]] 131 | # check against google maps to confirm correctness 132 | # https://www.google.com/maps/place/1%C2%B018'16.4%22N+34%C2%B005'15.1%22E/@1.3111951,34.0518834,10145m/data=!3m1!1e3!4m4!3m3!8m2!3d1.3045556!4d34.0875278?entry=ttu 133 | 134 | # scramble data (important for reading into causalimages::WriteTfRecord 135 | # to ensure no systematic biases in data sequence with model training 136 | set.seed(144L); UgandaDataProcessed <- UgandaDataProcessed[sample(1:nrow(UgandaDataProcessed)),] 137 | } 138 | 139 | # Image heterogeneity example with tfrecords 140 | # write a tf records repository 141 | # whenever changes are made to the input data to AnalyzeImageHeterogeneity, WriteTfRecord() should be re-run 142 | # to ensure correct ordering of data 143 | tfrecord_loc <- "~/Downloads/UgandaExample.tfrecord" 144 | if( reSaveTfRecords ){ 145 | causalimages::WriteTfRecord( 146 | file = tfrecord_loc, 147 | uniqueImageKeys = unique(UgandaDataProcessed$geo_long_lat_key), 148 | acquireImageFxn = acquireImageRep ) 149 | } 150 | 151 | for(ImageModelClass in c("VisionTransformer","CNN")){ 152 | for(optimizeImageRep in c(T, F)){ 153 | print(sprintf("Image hetero analysis & optimizeImageRep: %s",optimizeImageRep)) 154 | ImageHeterogeneityResults <- causalimages::AnalyzeImageHeterogeneity( 155 | # data inputs 156 | obsW = UgandaDataProcessed$Wobs, 157 | obsY = UgandaDataProcessed$Yobs, 158 | X = matrix(rnorm(length(UgandaDataProcessed$Yobs)*10),ncol=10), 159 | imageKeysOfUnits = UgandaDataProcessed$geo_long_lat_key, 160 | file = tfrecord_loc, # location of tf record (use absolute file paths) 161 | lat = UgandaDataProcessed$geo_lat, # not required but helpful for dealing with redundant locations in EO data 162 | long = UgandaDataProcessed$geo_long, # not required but helpful for dealing with redundant locations in EO data 163 | 164 | # inputs to control where visual results are saved as 
PDF or PNGs 165 | # (these image grids are large and difficult to display in RStudio's interactive mode) 166 | plotResults = T, 167 | figuresPath = "~/Downloads/HeteroTutorial", # where to write analysis figures 168 | figuresTag = "HeterogeneityImTutorial",plotBands = 1L:3L, 169 | 170 | # optional arguments for generating transportability maps 171 | # here, we leave those NULL for simplicity 172 | transportabilityMat = NULL, # 173 | 174 | # other modeling options 175 | imageModelClass = ImageModelClass, 176 | nSGD = 5L, # make this larger for real applications (e.g., 2000L) 177 | nDepth_ImageRep = ifelse(optimizeImageRep, yes = 1L, no = 1L), 178 | nWidth_ImageRep = as.integer(2L^6), 179 | optimizeImageRep = optimizeImageRep, 180 | batchSize = 8L, # make this larger for real application (e.g., 50L) 181 | kClust_est = 1 # vary depending on problem. Usually < 5 182 | ) 183 | try(dev.off(), T) 184 | } 185 | } 186 | 187 | # video heterogeneity example 188 | { 189 | acquireVideoRep <- function(keys) { 190 | # Get image data as an array from disk 191 | tmp <- acquireImageRep(keys) 192 | 193 | # Expand dimensions: we create a new dimension at the start 194 | tmp <- array(tmp, dim = c(1, dim(tmp))) 195 | 196 | # Transpose dimensions to get the required order 197 | tmp <- aperm(tmp, c(2, 1, 3, 4, 5)) 198 | 199 | # Swap image dimensions to see variability across time 200 | tmp_ <- aperm(tmp, c(1, 2, 4, 3, 5)) 201 | 202 | # Concatenate along the second axis 203 | tmp <- abind::abind(tmp, tmp_, along = 2) 204 | 205 | return(tmp) 206 | } 207 | 208 | # write the tf records repository 209 | tfrecord_loc_imSeq <- "~/Downloads/UgandaExampleVideo.tfrecord" 210 | if(reSaveTfRecords){ 211 | causalimages::WriteTfRecord( file = tfrecord_loc_imSeq, 212 | uniqueImageKeys = unique(UgandaDataProcessed$geo_long_lat_key), 213 | acquireImageFxn = acquireVideoRep, writeVideo = T ) 214 | } 215 | 216 | for(ImageModelClass in (c("VisionTransformer","CNN"))){ 217 | for(optimizeImageRep in c(T, F)){ 
218 | print(sprintf("Image seq hetero analysis & optimizeImageRep: %s",optimizeImageRep)) 219 | # Note: optimizeImageRep = T breaks with video on METAL framework 220 | VideoHeterogeneityResults <- causalimages::AnalyzeImageHeterogeneity( 221 | # data inputs 222 | obsW = UgandaDataProcessed$Wobs, 223 | obsY = UgandaDataProcessed$Yobs, 224 | imageKeysOfUnits = UgandaDataProcessed$geo_long_lat_key, 225 | file = tfrecord_loc_imSeq, # location of tf record (absolute paths are safest) 226 | dataType = "video", 227 | lat = UgandaDataProcessed$geo_lat, # not required but helpful for dealing with redundant locations in EO data 228 | long = UgandaDataProcessed$geo_long, # not required but helpful for dealing with redundant locations in EO data 229 | 230 | # inputs to control where visual results are saved as PDF or PNGs 231 | # (these image grids are large and difficult to display in RStudio's interactive mode) 232 | plotResults = T, 233 | figuresPath = "~/Downloads/HeteroTutorial", 234 | plotBands = 1L:3L, figuresTag = "HeterogeneityImSeqTutorial", 235 | 236 | # optional arguments for generating transportability maps 237 | # here, we leave those NULL for simplicity 238 | transportabilityMat = NULL, # 239 | 240 | # other modeling options 241 | imageModelClass = ImageModelClass, 242 | nSGD = 5L, # make this larger for real applications (e.g., 2000L) 243 | nDepth_ImageRep = ifelse(optimizeImageRep, yes = 1L, no = 1L), 244 | nWidth_ImageRep = as.integer(2L^5), 245 | optimizeImageRep = optimizeImageRep, 246 | kClust_est = 2, # vary depending on problem. 
Usually < 5 247 | batchSize = 8L, # make this larger for real application (e.g., 50L) 248 | strides = 2L ) 249 | try(dev.off(), T) 250 | } 251 | } 252 | } 253 | print("Done with image heterogeneity test!") 254 | } 255 | -------------------------------------------------------------------------------- /causalimages/tests/Test_BuildBackend.R: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env Rscript 2 | { 3 | # remote install latest version of the package 4 | # devtools::install_github(repo = "cjerzak/causalimages-software/causalimages") 5 | 6 | # local install for development team 7 | # install.packages("~/Documents/causalimages-software/causalimages",repos = NULL, type = "source",force = F) 8 | 9 | # setup backend, conda points to location of python in conda where packages will be downloaded 10 | # note: This function requires an Internet connection 11 | # you can find out a list of conda Python paths via: 12 | # system("which python") 13 | causalimages::BuildBackend(conda = "/Users/cjerzak/miniforge3/bin/python") 14 | print("Done with BuildBackend() test!") 15 | } 16 | 17 | -------------------------------------------------------------------------------- /causalimages/tests/Test_ExtractImageRepresentations.R: -------------------------------------------------------------------------------- 1 | { 2 | ################################ 3 | # Image and image-sequence embeddings tutorial using causalimages 4 | ################################ 5 | 6 | # remote install latest version of the package if needed 7 | # devtools::install_github(repo = "cjerzak/causalimages-software/causalimages") 8 | 9 | # local install for development team 10 | # install.packages("~/Documents/causalimages-software/causalimages",repos = NULL, type = "source",force = F) 11 | 12 | # build backend you haven't ready: 13 | # causalimages::BuildBackend() 14 | 15 | # load in package 16 | library( causalimages ); options(error = NULL) 17 | 18 | # load in 
tutorial data 19 | data( CausalImagesTutorialData ) 20 | 21 | # example acquire image function (loading from memory) 22 | # in general, you'll want to write a function that returns images 23 | # that saved disk associated with keys 24 | acquireImageFromMemory <- function(keys){ 25 | # here, the function input keys 26 | # refers to the unit-associated image keys 27 | m_ <- FullImageArray[match(keys, KeysOfImages),,,] 28 | 29 | # if keys == 1, add the batch dimension so output dims are always consistent 30 | # (here in image case, dims are batch by height by width by channel) 31 | if(length(keys) == 1){ 32 | m_ <- array(m_,dim = c(1L,dim(m_)[1],dim(m_)[2],dim(m_)[3])) 33 | } 34 | return( m_ ) 35 | } 36 | 37 | # drop first column 38 | X <- X[,-1] 39 | 40 | # mean imputation for simplicity 41 | X <- apply(X,2,function(zer){ 42 | zer[is.na(zer)] <- mean( zer,na.rm = T ) 43 | return( zer ) 44 | }) 45 | 46 | # select observation subset to make tutorial analyses run faster 47 | take_indices <- unlist( tapply(1:length(obsW),obsW,function(zer){ sample(zer, 50) }) ) 48 | 49 | # write tf record 50 | TfRecord_name <- "~/Downloads/CausalImagesTutorialDat.tfrecord" 51 | causalimages::WriteTfRecord( file = TfRecord_name, 52 | uniqueImageKeys = unique( KeysOfObservations[ take_indices ] ), 53 | acquireImageFxn = acquireImageFromMemory ) 54 | 55 | # obtain image representation 56 | MyImageEmbeddings <- causalimages::GetImageRepresentations( 57 | file = TfRecord_name, 58 | imageModelClass = "VisionTransformer", 59 | pretrainedModel = "clip-rsicd", 60 | imageKeysOfUnits = KeysOfObservations[ take_indices ] 61 | ) 62 | 63 | # each row in MyImageEmbeddings$ImageRepresentations corresponds to an observation 64 | # each column represents an embedding dimension associated with the imagery for that location 65 | dim( MyImageEmbeddings$ImageRepresentations ) 66 | plot( MyImageEmbeddings$ImageRepresentations ) 67 | 68 | # other output quantities include the image model functions and model 
parameters 69 | names( MyImageEmbeddings )[-1] 70 | 71 | print("Done with image representations test!") 72 | } 73 | -------------------------------------------------------------------------------- /causalimages/tests/Test_UsingTfRecords.R: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env Rscript 2 | { 3 | ################################ 4 | # Image confounding tutorial using causalimages 5 | # and tfrecords for faster results 6 | ################################ 7 | 8 | # remote install latest version of the package if needed 9 | # devtools::install_github(repo = "cjerzak/causalimages-software/causalimages") 10 | 11 | # local install for development team 12 | # install.packages("~/Documents/causalimages-software/causalimages",repos = NULL, type = "source",force = F) 13 | 14 | # build backend you haven't ready: 15 | # causalimages::BuildBackend() 16 | 17 | # load in package 18 | library( causalimages ) 19 | 20 | # load in tutorial data 21 | data( CausalImagesTutorialData ) 22 | 23 | # example acquire image function (loading from memory) 24 | # in general, you'll want to write a function that returns images 25 | # that saved disk associated with keys 26 | acquireImageFromMemory <- function(keys){ 27 | # here, the function input keys 28 | # refers to the unit-associated image keys 29 | m_ <- FullImageArray[match(keys, KeysOfImages),,,] 30 | 31 | # if keys == 1, add the batch dimension so output dims are always consistent 32 | # (here in image case, dims are batch by height by width by channel) 33 | if(length(keys) == 1){ 34 | m_ <- array(m_,dim = c(1L,dim(m_)[1],dim(m_)[2],dim(m_)[3])) 35 | } 36 | 37 | # uncomment for a test with different image dimensions 38 | #if(length(keys) == 1){ m_ <- abind::abind(m_,m_,m_,along = 3L) }; if(length(keys) > 1){ m_ <- abind::abind(m_,m_,m_,.along = 4L) } 39 | return( m_ ) 40 | } 41 | 42 | dim( acquireImageFromMemory(KeysOfImages[1]) ) 43 | dim( 
acquireImageFromMemory(KeysOfImages[1:2]) ) 44 | 45 | # drop first column 46 | X <- X[,-1] 47 | 48 | # mean imputation for simplicity 49 | X <- apply(X,2,function(zer){ 50 | zer[is.na(zer)] <- mean( zer,na.rm = T ) 51 | return( zer ) 52 | }) 53 | 54 | # select observation subset to make tutorial analyses run faster 55 | # select 50 treatment and 50 control observations 56 | set.seed(1.) 57 | take_indices <- unlist( tapply(1:length(obsW),obsW,function(zer){sample(zer, 50)}) ) 58 | 59 | # !!! important note !!! 60 | # when using tf recordings, it is essential that the data inputs be pre-shuffled like is done here. 61 | # you can use a seed for reproducing the shuffle (so the tfrecord is correctly indexed and you don't need to re-make it) 62 | # tf records read data quasi-sequentially, so systematic patterns in the data ordering 63 | # reduce performance 64 | 65 | # uncomment for a larger n analysis 66 | #take_indices <- 1:length( obsY ) 67 | 68 | # set tfrecord save location (safest using absolute path) 69 | tfrecord_loc <- "./ExampleRecord.tfrecord" 70 | 71 | # you may use relative paths like this: 72 | # tfrecord_loc <- "./Downloads/test1/test2/test3/ExampleRecord.tfrecord" 73 | 74 | # or absolute paths like this: 75 | # tfrecord_loc <- "~/Downloads/test1/test2/test3/ExampleRecord.tfrecord" 76 | 77 | # write a tf records repository 78 | causalimages::WriteTfRecord( file = tfrecord_loc, 79 | uniqueImageKeys = unique( KeysOfObservations[ take_indices ] ), 80 | acquireImageFxn = acquireImageFromMemory ) 81 | } 82 | 83 | -------------------------------------------------------------------------------- /documentPackage.R: -------------------------------------------------------------------------------- 1 | { 2 | rm(list=ls()); options(error = NULL) 3 | package_name <- "causalimages" 4 | setwd(sprintf("~/Documents/%s-software", package_name)) 5 | 6 | package_path <- sprintf("~/Documents/%s-software/%s",package_name,package_name) 7 | tools::add_datalist(package_path, force 
= TRUE) 8 | devtools::document(package_path) 9 | try(file.remove(sprintf("./%s.pdf",package_name)),T) 10 | system(sprintf("R CMD Rd2pdf %s",package_path)) 11 | 12 | # install.packages( sprintf("~/Documents/%s-software/%s",package_name,package_name),repos = NULL, type = "source") 13 | # library( causalimages ); data( CausalImagesTutorialData ) 14 | log(sort( sapply(ls(),function(l_){ object.size(eval(parse(text=l_))) }))) 15 | 16 | # Check package to ensure it meets CRAN standards. 17 | # devtools::check( package_path ) 18 | 19 | # see https://github.com/RConsortium/S7 20 | } 21 | -------------------------------------------------------------------------------- /misc/dataverse/DataverseReadme_confounding.md: -------------------------------------------------------------------------------- 1 | Replication Data for: Integrating Earth Observation Data into Causal Inference: Challenges and Opportunities 2 | 3 | Connor T. Jerzak, Fredrik Johansson, Adel Daoud. Integrating Earth Observation Data into Causal Inference: Challenges and Opportunities. ArXiv Preprint, 2023. 4 | 5 | YandW_mat.csv contains individual-level observational data. In the dataset, LONGITUDE and LATITUDE refer to the approximate geo-referenced long/lat of observational units. Experimental outcomes are stored in Yobs. Treatment variable is stored in Wobs. See the tutorial for more information. The unique image key for each observational unit is saved in UNIQUE_ID. 6 | 7 | Geo-referenced satellite images are saved in 8 | "./Nigeria2000_processed/%s_BAND%s.csv", where the first "%s" refers to the image key associated with each observation (saved in UNIQUE_ID in YandW_mat.csv) and BAND%s refers to one of 3 bands in the satellite imagery. 9 | 10 | After downloading the replication package, unzip `Nigeria2000_processed.zip` so 11 | that the `Nigeria2000_processed` folder containing the band CSV files resides in 12 | the same directory as `YandW_mat.csv`.
13 | 14 | For more information, see: https://github.com/cjerzak/causalimages-software/ 15 | -------------------------------------------------------------------------------- /misc/dataverse/DataverseReadme_heterogeneity.md: -------------------------------------------------------------------------------- 1 | Replication Data for: Image-based Treatment Effect Heterogeneity 2 | 3 | Connor Thomas Jerzak, Fredrik Daniel Johansson, Adel Daoud Proceedings of the Second Conference on Causal Learning and Reasoning, PMLR 213:531-552, 2023. 4 | 5 | UgandaDataProcessed.csv contains individual-level data from the YOP experiment. In the dataset, geo_long and geo_lat refer to the approximate geo-referenced long/lat of experimental units. The variable, geo_long_lat_key, refers to the image key associated with each location. Experimental outcomes are stored in Yobs. Treatment variable is stored in Wobs. See the tutorial for more information. 6 | 7 | UgandaGeoKeyMat.csv contains information on keys linking to satellite images for all of Uganda for the transportability analysis. 8 | 9 | Geo-referenced satellite images are saved in "./Uganda2000_processed/GeoKey%s_BAND%s.csv", where GeoKey%s denotes the image key associated with each observation and BAND%s refers to one of 3 bands in the satellite imagery. 10 | 11 | Unzip `Uganda2000_processed.zip` so the `Uganda2000_processed` directory 12 | containing the band CSVs sits alongside `UgandaDataProcessed.csv` before running 13 | the tutorial scripts. 
14 | 15 | For more information, see: https://github.com/cjerzak/causalimages-software/blob/main/tutorials/AnalyzeImageHeterogeneity_FullTutorial.R 16 | -------------------------------------------------------------------------------- /misc/dataverse/DataverseTutorial_confounding.R: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env Rscript 2 | 3 | ################################ 4 | # For an up-to-date version of this tutorial, see 5 | # 6 | ################################ 7 | 8 | # set new wd 9 | -------------------------------------------------------------------------------- /misc/dataverse/DataverseTutorial_heterogeneity.R: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env Rscript 2 | 3 | ################################ 4 | # For an up-to-date version of this tutorial, see 5 | # https://github.com/cjerzak/causalimages-software/blob/main/tutorials/AnalyzeImageHeterogeneity_FullTutorial.R 6 | ################################ 7 | 8 | # set new wd 9 | setwd(sprintf('%s/Public Replication Data, YOP Experiment/', 10 | gsub(download_folder,pattern="\\.zip",replace=""))) 11 | 12 | # see directory contents 13 | list.files() 14 | 15 | # images saved here 16 | list.files( "./Uganda2000_processed" ) 17 | 18 | # individual-level data 19 | UgandaDataProcessed <- read.csv( "./UgandaDataProcessed.csv" ) 20 | 21 | # unit-level covariates (many covariates are subject to missingness!) 
22 | dim( UgandaDataProcessed ) 23 | table( UgandaDataProcessed$female ) 24 | table( UgandaDataProcessed$age ) 25 | 26 | # approximate longitude + latitude for units 27 | UgandaDataProcessed$geo_long 28 | UgandaDataProcessed$geo_lat 29 | 30 | # image keys of units (use for referencing satellite images) 31 | UgandaDataProcessed$geo_long_lat_key 32 | 33 | # an experimental outcome 34 | UgandaDataProcessed$Yobs 35 | 36 | # treatment variable 37 | UgandaDataProcessed$Wobs 38 | 39 | # information on keys linking to satellite images for all of Uganda 40 | # (not just experimental context, use for constructing transportability maps) 41 | UgandaGeoKeyMat <- read.csv( "./UgandaGeoKeyMat.csv" ) 42 | tail( UgandaGeoKeyMat ) 43 | 44 | # Geo-referenced satellite images are saved in 45 | # "./Uganda2000_processed/GeoKey%s_BAND%s.csv", 46 | # where GeoKey%s denotes the image key associated with each observation and 47 | # BAND%s refers to one of 3 bands in the satellite imagery. 48 | # See https://github.com/cjerzak/causalimages-software/blob/main/tutorials/AnalyzeImageHeterogeneity_FullTutorial.R 49 | # for up-to-date usage information. 50 | -------------------------------------------------------------------------------- /misc/docker/setup/CodexHelpers.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -euo pipefail 4 | # set -x 5 | 6 | #------------------------------------------------------------------------------ 7 | # 1.
Install R and unzip, other utilities 8 | #------------------------------------------------------------------------------ 9 | apt-get update \ 10 | && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ 11 | r-base r-base-dev \ 12 | libssl-dev libxml2-dev unzip \ 13 | build-essential libcurl4-openssl-dev \ 14 | libgdal-dev libgeos-dev libproj-dev libudunits2-dev ca-certificates \ 15 | && apt-get clean \ 16 | && rm -rf /var/lib/apt/lists/* 17 | 18 | #------------------------------------------------------------------------------ 19 | # 2. Install Miniconda into /opt/conda 20 | #------------------------------------------------------------------------------ 21 | readonly MINICONDA_SH="/tmp/Miniconda3-latest-Linux-x86_64.sh" 22 | wget --quiet "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" -O "$MINICONDA_SH" 23 | bash "$MINICONDA_SH" -b -p /opt/conda 24 | rm "$MINICONDA_SH" 25 | 26 | # Make sure conda is on PATH 27 | export PATH="/opt/conda/bin:${PATH}" 28 | 29 | # Initialize conda for this shell 30 | if [[ -f "/opt/conda/etc/profile.d/conda.sh" ]]; then 31 | source "/opt/conda/etc/profile.d/conda.sh" 32 | else 33 | echo "Error: /opt/conda/etc/profile.d/conda.sh not found" >&2 34 | exit 1 35 | fi 36 | 37 | 38 | # Define URLs for ZIP of tars and conda ZIP 39 | THE_ZIP_URL="https://www.dl.dropboxusercontent.com/scl/fi/ek76700s9o4p4grwfw5ko/Archive.zip?rlkey=w0vj0qh78i3eb3c5zt4vlmot9&st=f8kxeq6t&dl=1" 40 | readonly CONDA_ENV_ZIP_URL="https://www.dl.dropboxusercontent.com/scl/fi/k5vylxygjl4icm76drtsz/CausalImagesEnv.zip?rlkey=hmmwpma9bihoze25vee44dktz&st=0vz2vmbk&dl=1" 41 | 42 | #------------------------------------------------------------------------------ 43 | # 3. 
Fetch and unpack the ZIP-of-tar's 44 | #------------------------------------------------------------------------------ 45 | BUNDLE_ZIP="binaries.zip" 46 | 47 | # download the bundle 48 | curl -sSL -o "/tmp/${BUNDLE_ZIP}" "${THE_ZIP_URL}" 49 | 50 | # unpack into a temp directory 51 | mkdir -p /tmp/binaries 52 | unzip -q "/tmp/${BUNDLE_ZIP}" -d /tmp/binaries 53 | 54 | echo "Files extracted to /tmp/binaries:" 55 | find /tmp/binaries -type f -printf " ➜ %P\n" 56 | 57 | #------------------------------------------------------------------------------ 58 | # 4. Install each inner ZIP into R's library path 59 | #------------------------------------------------------------------------------ 60 | # find the first (user) library path and make sure it exists 61 | LIB="$(Rscript -e 'cat(.libPaths()[1])')" 62 | mkdir -p "$LIB" 63 | 64 | #echo "Binary ZIPs found under /tmp/binaries:" 65 | #for f in /tmp/binaries/*.tar.gz; do 66 | # echo " ➜ $f" 67 | # tar -xzf "$f" -C "$LIB" 68 | #done 69 | 70 | echo "Installing binary packages into $LIB:" 71 | for f in /tmp/binaries/*.tar.gz; do 72 | pkg=$(basename "$f" | sed -E 's/_.*//') # e.g. DBI_1.2.3... → DBI 73 | echo " ➜ $pkg" 74 | mkdir -p "$LIB/$pkg" # ensure lib/pkg exists 75 | #tar --strip-components=1 -xzf "$f" -C "$LIB/$pkg" # drop top‐level folder 76 | tar -xzf "$f" -C "$LIB" 77 | 78 | echo " Contents of $LIB/$pkg:" 79 | find "$LIB/$pkg" -maxdepth 1 -mindepth 1 -printf " ➜ %f\n" 80 | done 81 | 82 | # clean up temporary files 83 | rm -rf /tmp/${BUNDLE_ZIP} /tmp/binaries 84 | 85 | echo "Installed R packages in $LIB:" 86 | Rscript -e ' 87 | #lib <- Sys.getenv("R_LIBS_USER", .libPaths()[1]) 88 | #pkgs <- rownames(installed.packages(lib.loc = lib)) 89 | pkgs <- rownames(installed.packages(lib.loc = .libPaths())) 90 | cat(paste0("Installed packages found ➜ ", pkgs, "\n"), sep = "") 91 | ' 92 | 93 | echo "Installing causalimages backend..." 
94 | readonly CONDA_ENV_ZIP="/tmp/conda_env.zip" 95 | 96 | echo "Downloading pre-built conda environment from $CONDA_ENV_ZIP_URL" 97 | curl -sSL -o "$CONDA_ENV_ZIP" "$CONDA_ENV_ZIP_URL" 98 | 99 | echo "Unpacking conda environment into /opt/conda/envs/" 100 | unzip -q "$CONDA_ENV_ZIP" -d "/opt/conda/envs/" 101 | rm "$CONDA_ENV_ZIP" 102 | 103 | echo "Available conda environments:" 104 | conda env list 105 | 106 | echo "Success: Done with download & unpacking script! Ready to experiment." 107 | -------------------------------------------------------------------------------- /misc/docker/setup/Dockerfile: -------------------------------------------------------------------------------- 1 | 2 | docker run --rm -it \ 3 | --platform=linux/amd64 \ 4 | -e CODEX_ENV_PYTHON_VERSION=3.12 \ 5 | -e CODEX_ENV_NODE_VERSION=20 \ 6 | -e CODEX_ENV_RUST_VERSION=1.87.0 \ 7 | -e CODEX_ENV_GO_VERSION=1.23.8 \ 8 | -e CODEX_ENV_SWIFT_VERSION=6.1 \ 9 | -v "$HOME/Documents/causalimages-software/misc/docker/binaries:/binaries" \ 10 | ghcr.io/openai/codex-universal:latest \ 11 | -exc " 12 | set -euo pipefail 13 | 14 | echo \"🔧 Installing system development libraries...\" 15 | # Prepare apt for CRAN 16 | apt-get update -qq && \ 17 | apt-get install -y --no-install-recommends \ 18 | software-properties-common dirmngr ca-certificates && \ 19 | 20 | # Install R base + dev headers + all system libraries 21 | DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ 22 | r-base r-base-dev \ 23 | build-essential cmake pkg-config git zip \ 24 | libcurl4-openssl-dev libssl-dev libxml2-dev \ 25 | libgdal-dev libgeos-dev libproj-dev libudunits2-dev \ 26 | ca-certificates && \ 27 | apt-get clean && \ 28 | rm -rf /var/lib/apt/lists/* 29 | 30 | # Clean previous artifacts and prepare directories 31 | rm -rf /binaries/* && mkdir -p /binaries/src /binaries/bin 32 | 33 | # List of CRAN packages in dependency-safe order 34 | pkgs=( \ 35 | remotes yaml config ps R6 processx Rcpp RcppEigen 
RcppTOML rprojroot here jsonlite png rappdirs rlang withr reticulate base64enc magrittr whisker cli glue lifecycle vctrs tidyselect rstudioapi tfruns backports tfautograph tensorflow sp geosphere terra raster rrapply iterators foreach shape glmnet proxy e1071 classInt DBI wk s2 units sf data.table plyr pROC Matrix lattice survival codetools class KernSmooth MASS \ 36 | ) 37 | 38 | echo \"🔧 Installing Miniconda...\" 39 | MINI=Miniconda3-latest-Linux-x86_64.sh 40 | wget -q https://repo.anaconda.com/miniconda/\${MINI} -O /tmp/\${MINI} && \ 41 | bash /tmp/\${MINI} -b -p /opt/miniconda && \ 42 | rm /tmp/\${MINI} && \ 43 | export PATH=\"/opt/miniconda/bin:\$PATH\" && \ 44 | conda config --set always_yes yes --set changeps1 no && \ 45 | conda update -q conda 46 | 47 | 48 | # Download CRAN package sources 49 | echo \"📥 Downloading CRAN package sources...\" 50 | cd /binaries/src 51 | for pkg in \"\${pkgs[@]}\"; do 52 | echo \"📦 Downloading \$pkg source...\" 53 | Rscript -e \"download.packages('\$pkg', destdir='/binaries/src', type='source', repos='https://cloud.r-project.org')\" \ 54 | || { echo \"❌ Failed to download \$pkg\"; exit 1; } 55 | # verify that a tar.gz actually appeared 56 | if ! 
compgen -G \"\${pkg}_*.tar.gz\" >/dev/null; then 57 | echo \"❌ No source tarball found for \$pkg after download, likely installed in base-R (recommended package)?.\" >&2 58 | # exit 1 59 | continue 60 | fi 61 | done 62 | 63 | echo \"✅ Downloaded all CRAN sources ✅\" 64 | 65 | # Build CRAN package binaries in correct dependency order 66 | echo \"🛠️ Building CRAN package binaries...\" 67 | cd /binaries/bin 68 | for pkg in \"\${pkgs[@]}\"; do 69 | src_tar=(/binaries/src/\${pkg}_*.tar.gz) 70 | if [[ -f \"\${src_tar[0]}\" ]]; then 71 | echo \"⚙ Building \$pkg from \$(basename \"\${src_tar[0]}\")...\" 72 | R CMD INSTALL --build \"\${src_tar[0]}\" 73 | else 74 | echo \"⚠ Source for \$pkg not found, skipping (likely found in base R install).\" 75 | #exit 1 # <— stop the entire script here 76 | continue 77 | fi 78 | done 79 | 80 | # Build GitHub package binary 81 | cd /binaries/src 82 | 83 | # clone *and* sparsify GitHub repo in one go 84 | git clone \ 85 | --depth 1 \ 86 | --filter=blob:none \ 87 | --sparse \ 88 | https://github.com/cjerzak/causalimages-software.git \ 89 | /binaries/src/repo 90 | 91 | cd /binaries/src/repo 92 | 93 | # then tell Git which folder you actually want 94 | git sparse-checkout set causalimages 95 | 96 | # now build 97 | R CMD INSTALL --build ./causalimages 98 | mv causalimages_*.tar.gz /binaries/bin/causalimages.tar.gz 99 | 100 | echo \"✅✅ Built binary tarballs ✅✅\" 101 | 102 | echo \"🔧 Building conda backend...\" 103 | export RETICULATE_MINICONDA_PATH=/opt/miniconda 104 | Rscript -e \" 105 | library(causalimages); 106 | causalimages::BuildBackend(conda_env='CausalImagesEnv', 107 | conda='/opt/miniconda/bin/conda') 108 | \" 109 | 110 | echo \"📦 Zipping conda env...\" 111 | zip -r /binaries/CausalImagesEnv.zip /opt/miniconda/envs/CausalImagesEnv 112 | 113 | 114 | echo \"✅✅✅ Built binary tarballs ✅✅✅\" 115 | " 116 | 117 | 118 | -------------------------------------------------------------------------------- /misc/docker/setup/DockerfileNoCompile: 
-------------------------------------------------------------------------------- 1 | docker run --platform=linux/amd64 --rm -v "$HOME/Documents/causalimages-software/misc/docker/binaries:/binaries" rocker/r-ver:4.4.0 bash -exc " 2 | set -euo pipefail 3 | 4 | echo \"🔧 Installing system development libraries...\" 5 | apt-get update -qq && \ 6 | DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ 7 | build-essential cmake pkg-config git libcurl4-openssl-dev libssl-dev libxml2-dev \ 8 | libgdal-dev libgeos-dev libproj-dev libudunits2-dev ca-certificates && \ 9 | apt-get clean && \ 10 | rm -rf /var/lib/apt/lists/* 11 | 12 | # Build GitHub package binary 13 | cd /binaries/src 14 | 15 | # Check and remove existing repo directory if it exists 16 | if [ -d \"/binaries/src/repo\" ]; then 17 | rm -rf /binaries/src/repo 18 | fi 19 | 20 | # clone *and* sparsify GitHub repo in one go 21 | git clone \ 22 | --depth 1 \ 23 | --filter=blob:none \ 24 | --sparse \ 25 | https://github.com/cjerzak/causalimages-software.git \ 26 | /binaries/src/repo 27 | 28 | cd /binaries/src/repo 29 | 30 | # then tell Git which folder you actually want 31 | git sparse-checkout set causalimages 32 | 33 | # now build 34 | #R CMD build ./causalimages --no-manual --no-build-vignettes 35 | R CMD INSTALL --build ./causalimages 36 | mv causalimages_*.tar.gz /binaries/bin/causalimages.tar.gz 37 | 38 | echo \"✅✅✅ Built binary tarballs ✅✅✅\" 39 | " 40 | 41 | -------------------------------------------------------------------------------- /misc/docker/setup/FindDependencies.sh: -------------------------------------------------------------------------------- 1 | 2 | docker run --platform=linux/amd64 --rm \ 3 | -v "$(pwd)/binaries:/binaries" \ 4 | rocker/r-ver:4.4.0 bash -exc " 5 | set -euo pipefail 6 | 7 | echo \"🔧 Installing system development libraries...\" 8 | apt-get update -qq && \ 9 | DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ 10 | build-essential 
libcurl4-openssl-dev libssl-dev libxml2-dev \ 11 | libgdal-dev libgeos-dev libproj-dev libudunits2-dev ca-certificates \ 12 | && apt-get clean && rm -rf /var/lib/apt/lists/* 13 | 14 | echo \"🛠️ Installing remotes from CRAN...\" 15 | Rscript -e \"options(repos = c(CRAN = 'https://cloud.r-project.org')); \ 16 | install.packages('remotes', dependencies = TRUE)\" 17 | 18 | echo \"📂 Preparing build directory...\" 19 | mkdir -p /binaries/bin 20 | cd /binaries/bin 21 | 22 | echo \"📦 Building causalimages from GitHub...\" 23 | Rscript -e \"remotes::install_github('cjerzak/causalimages-software', \ 24 | subdir='causalimages', \ 25 | dependencies=FALSE)\" 26 | 27 | echo \"⚙️ Building causalimages binary...\" 28 | R CMD INSTALL --build causalimages_*.tar.gz 29 | 30 | echo \"✅ Built binary tarballs:\" 31 | ls -1 *.tar.gz 32 | " 33 | -------------------------------------------------------------------------------- /misc/docker/setup/GenEnv.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # build_and_package.sh 3 | # Run this from the same directory as your Dockerfile. 4 | 5 | set -euo pipefail 6 | 7 | IMAGE_NAME="causalimages-env" 8 | TAR_NAME="${IMAGE_NAME}.tar" 9 | ZIP_NAME="${IMAGE_NAME}.zip" 10 | 11 | echo "1) Building Docker image..." 12 | docker build -t ${IMAGE_NAME} . 13 | 14 | echo "2) Saving image to tarball..." 15 | docker save ${IMAGE_NAME} -o ${TAR_NAME} 16 | 17 | echo "3) Zipping up Dockerfile + tarball..." 18 | zip -r ${ZIP_NAME} Dockerfile ${TAR_NAME} 19 | 20 | echo 21 | echo "Done! 
Your package is: ${ZIP_NAME}" 22 | echo "Push that ZIP to GitHub, then Codex can unzip and run:" 23 | echo " docker load -i ${TAR_NAME}" 24 | echo " docker run -it ${IMAGE_NAME} bash" 25 | -------------------------------------------------------------------------------- /misc/docker/setup/GetRDependencyOrder.R: -------------------------------------------------------------------------------- 1 | # Script: GetRDependencyOrder.R 2 | # Purpose: - Get all dependencies for set of base packages 3 | # - Get safe dependency order for creation of compiled R packages 4 | { 5 | # Your initial target packages 6 | pkgs <- c("remotes", "tensorflow", "reticulate", "geosphere", "raster", 7 | "rrapply", "glmnet", "sf", "data.table", "pROC") 8 | 9 | # Fetch the CRAN package database 10 | cran_db <- available.packages(repos = "https://cloud.r-project.org") 11 | 12 | # 1. Get the full recursive set of Depends + Imports 13 | recursive_deps <- tools::package_dependencies( 14 | pkgs, 15 | db = cran_db, 16 | which = c("Depends", "Imports"), 17 | recursive = TRUE 18 | ) 19 | all_pkgs <- unique(c(pkgs, unlist(recursive_deps))) 20 | 21 | # Identify base/recommended packages to drop 22 | base_pkgs <- rownames(installed.packages(priority = c("base", "recommended"))) 23 | 24 | # 2. Build a direct-dependency map for every package 25 | direct_deps <- tools::package_dependencies( 26 | all_pkgs, 27 | db = cran_db, 28 | which = c("Depends", "Imports"), 29 | recursive = FALSE 30 | ) 31 | # Remove any base packages from each dependency vector 32 | direct_deps <- lapply(direct_deps, setdiff, y = base_pkgs) 33 | 34 | # 3. 
Topological sort via DFS 35 | ordered <- character(0) # will hold the final install order 36 | visited <- character(0) # permanently marked nodes 37 | visiting <- character(0) # temporary marks for cycle detection 38 | 39 | visit <- function(pkg) { 40 | if (pkg %in% visited) return() 41 | if (pkg %in% visiting) { 42 | stop("Circular dependency detected involving: ", pkg) 43 | } 44 | visiting <<- c(visiting, pkg) 45 | for (dep in direct_deps[[pkg]]) { 46 | visit(dep) 47 | } 48 | visiting <<- setdiff(visiting, pkg) 49 | visited <<- c(visited, pkg) 50 | ordered <<- c(ordered, pkg) 51 | } 52 | 53 | # Run DFS on all packages 54 | for (p in all_pkgs) { 55 | visit(p) 56 | } 57 | 58 | # Filter to just the ones you want (if you only want your original pkgs + all their deps) 59 | dependency_safe_order <- intersect(ordered, all_pkgs) 60 | dependency_safe_order <- dependency_safe_order[!dependency_safe_order %in% 61 | rownames(installed.packages(priority = c("base")))] 62 | 63 | # manual addins 64 | dependency_safe_order <- c( 65 | dependency_safe_order[1:which(dependency_safe_order=="Rcpp")], 66 | "RcppEigen", 67 | dependency_safe_order[(which(dependency_safe_order=="Rcpp")+1):length(dependency_safe_order)]) 68 | 69 | # Inspect 70 | cat("Install in this order:\n") 71 | cat(paste(dependency_safe_order, collapse = " "), "\n") 72 | # must add in: RcppEigen 73 | } 74 | 75 | 76 | 77 | 78 | 79 | -------------------------------------------------------------------------------- /misc/notes/MaintainerNotes.txt: -------------------------------------------------------------------------------- 1 | # Maintainer notes for GPU support 2 | # 3 | # These notes summarize our experience configuring JAX with GPU acceleration. 4 | # The package relies heavily on JAX and TensorFlow, so version compatibility is 5 | # important. 6 | # Use Python 3.10 or above for jax-metal. Some JAX operations (e.g. 7 | # `jax$nn$softplus` or random sampling) can segfault when used with float16. 
8 | # Sampling parameters with `seeds` (not `vseeds`) and wrapping functions with 9 | # `tf$function(jit_compile = TRUE)` works with jax-metal. 10 | 11 | 12 | # Conda environment setup on M1+ Mac: 13 | conda create -n CausalImagesEnv python==3.11 14 | conda activate CausalImagesEnv 15 | python3 -m pip install tensorflow tensorflow-metal optax equinox jmp tensorflow_probability 16 | python3 -m pip install jax-metal 17 | 18 | # After installation, verify that JAX sees your GPU by running in Python: 19 | # >>> import jax; jax.devices() 20 | -------------------------------------------------------------------------------- /other/DataverseReadme_confounding.md: -------------------------------------------------------------------------------- 1 | Replication Data for: Integrating Earth Observation Data into Causal Inference: Challenges and Opportunities 2 | 3 | Connor T. Jerzak, Fredrik Johansson, Adel Daoud. Integrating Earth Observation Data into Causal Inference: Challenges and Opportunities. ArXiv Preprint, 2023. 4 | 5 | YandW_mat.csv contains individual-level observational data. In the dataset, LONGITUDE and LATITUDE refer to the approximate geo-referenced long/lat of observational units. Experimental outcomes are stored in Yobs. Treatment variable is stored in Wobs. See the tutorial for more information. The unique image key for each observational unit is saved in UNIQUE_ID. 6 | 7 | Geo-referenced satellite images are saved in 8 | "./Nigeria2000_processed/%s_BAND%s.csv", where the first "%s" refers to the image key associated with each observation (saved in UNIQUE_ID in YandW_mat.csv) and BAND%s refers to one of 3 bands in the satellite imagery. 
9 | 10 | For more information, see: https://github.com/cjerzak/causalimages-software/ 11 | -------------------------------------------------------------------------------- /other/DataverseReadme_heterogeneity.md: -------------------------------------------------------------------------------- 1 | Replication Data for: Image-based Treatment Effect Heterogeneity 2 | 3 | Connor Thomas Jerzak, Fredrik Daniel Johansson, Adel Daoud Proceedings of the Second Conference on Causal Learning and Reasoning, PMLR 213:531-552, 2023. 4 | 5 | UgandaDataProcessed.csv contains individual-level data from the YOP experiment. In the dataset, geo_long and geo_lat refer to the approximate geo-referenced long/lat of experimental units. The variable, geo_long_lat_key, refers to the image key associated with each location. Experimental outcomes are stored in Yobs. Treatment variable is stored in Wobs. See the tutorial for more information. 6 | 7 | UgandaGeoKeyMat.csv contains information on keys linking to satellite images for all of Uganda for the transportability analysis. 8 | 9 | Geo-referenced satellite images are saved in "./Uganda2000_processed/GeoKey%s_BAND%s.csv", where GeoKey%s denotes the image key associated with each observation and BAND%s refers to one of 3 bands in the satellite imagery. 
10 | 11 | For more information, see: https://github.com/cjerzak/causalimages-software/blob/main/tutorials/AnalyzeImageHeterogeneity_FullTutorial.R 12 | -------------------------------------------------------------------------------- /other/DataverseTutorial_confounding.R: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env Rscript 2 | 3 | ################################ 4 | # For an up-to-date version of this tutorial, see 5 | # 6 | ################################ 7 | 8 | # set new wd 9 | -------------------------------------------------------------------------------- /other/DataverseTutorial_heterogeneity.R: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env Rscript 2 | 3 | ################################ 4 | # For an up-to-date version of this tutorial, see 5 | # https://github.com/cjerzak/causalimages-software/blob/main/tutorials/AnalyzeImageHeterogeneity_FullTutorial.R 6 | ################################ 7 | 8 | # set new wd 9 | setwd(sprintf('%s/Public Replication Data, YOP Experiment/', 10 | gsub(download_folder,pattern="\\.zip",replace=""))) 11 | 12 | # see directory contents 13 | list.files() 14 | 15 | # images saved here 16 | list.files( "./Uganda2000_processed" ) 17 | 18 | # individual-level data 19 | UgandaDataProcessed <- read.csv( "./UgandaDataProcessed.csv" ) 20 | 21 | # unit-level covariates (many covariates are subject to missingness!) 
22 | dim( UgandaDataProcessed ) 23 | table( UgandaDataProcessed$female ) 24 | table( UgandaDataProcessed$age ) 25 | 26 | # approximate longitude + latitude for units 27 | UgandaDataProcessed$geo_long 28 | UgandaDataProcessed$geo_lat 29 | 30 | # image keys of units (use for referencing satellite images) 31 | UgandaDataProcessed$geo_long_lat_key 32 | 33 | # an experimental outcome 34 | UgandaDataProcessed$Yobs 35 | 36 | # treatment variable 37 | UgandaDataProcessed$Wobs 38 | 39 | # information on keys linking to satellite images for all of Uganda 40 | # (not just experimental context, use for constructing transportability maps) 41 | UgandaGeoKeyMat <- read.csv( "./UgandaGeoKeyMat.csv" ) 42 | tail( UgandaGeoKeyMat ) 43 | 44 | # Geo-referenced satellite images are saved in 45 | # "./Uganda2000_processed/GeoKey%s_BAND%s.csv", 46 | # where GeoKey%s denotes the image key associated with each observation and 47 | # BAND%s refers to one of 3 bands in the satellite imagery. 48 | # See https://github.com/cjerzak/causalimages-software/blob/main/tutorials/AnalyzeImageHeterogeneity_FullTutorial.R 49 | # for up-to-date usage information. 
50 | -------------------------------------------------------------------------------- /other/PackageRunChecks.R: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env Rscript 2 | 3 | ################################ 4 | # master test suite 5 | # note: this code is primarily for the use of the development team 6 | ################################ 7 | 8 | # in process 9 | -------------------------------------------------------------------------------- /tutorials/AnalyzeImageConfounding_Tutorial_Advanced.R: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env Rscript 2 | { 3 | ################################ 4 | # Image confounding tutorial using causalimages 5 | ################################ 6 | 7 | # clean workspace 8 | rm(list=ls()); options(error = NULL) 9 | 10 | # setup environment 11 | if(Sys.getenv()["RSTUDIO_USER_IDENTITY"] == "cjerzak"){ 12 | setwd("~/Downloads/") 13 | } 14 | if(Sys.getenv()["RSTUDIO_USER_IDENTITY"] != "cjerzak"){ 15 | setwd("./") 16 | # or set directory as desired 17 | } 18 | 19 | # remote install latest version of the package if needed 20 | # devtools::install_github(repo = "cjerzak/causalimages-software/causalimages") 21 | 22 | # local install for development team 23 | # install.packages("~/Documents/causalimages-software/causalimages",repos = NULL, type = "source",force = F) 24 | 25 | # build backend you haven't ready (run this only once upon (re)installing causalimages!) 26 | # causalimages::BuildBackend() 27 | 28 | # load in package 29 | library( causalimages ) 30 | 31 | # resave TfRecords? 
32 | reSaveTFRecord <- FALSE 33 | 34 | # load in tutorial data 35 | data( CausalImagesTutorialData ) 36 | 37 | # mean imputation for toy example in this tutorial 38 | X <- apply(X[,-1],2,function(zer){ 39 | zer[is.na(zer)] <- mean( zer,na.rm = T ); return( zer ) 40 | }) 41 | 42 | # select observation subset to make the tutorial quick 43 | set.seed(4321L);take_indices <- unlist( tapply(1:length(obsW),obsW,function(zer){sample(zer, 100)}) ) 44 | 45 | # example acquire image function (loading from memory) 46 | # in general, you'll want to write a function that returns images 47 | # that saved disk associated with keys 48 | acquireImageFxn <- function(keys){ 49 | # here, the function input keys 50 | # refers to the unit-associated image keys 51 | # we also tweak the image dimensions for testing purposes 52 | #m_ <- FullImageArray[match(keys, KeysOfImages),c(1:35,1:35),c(1:35,1:35),1:2] # test with two channels 53 | m_ <- FullImageArray[match(keys, KeysOfImages),c(1:35,1:35),c(1:35,1:35),] # test with three channels 54 | #m_ <- FullImageArray[match(keys, KeysOfImages),c(1:35,1:35),c(1:35,1:35),c(1:3,1:2)] # test with five channels 55 | 56 | # if keys == 1, add the batch dimension so output dims are always consistent 57 | # (here in image case, dims are batch by height by width by channel) 58 | if(length(keys) == 1){ m_ <- array(m_,dim = c(1L,dim(m_)[1],dim(m_)[2],dim(m_)[3])) } 59 | 60 | return( m_ ) 61 | } 62 | 63 | # perform causal inference with image and tabular confounding 64 | if(T == T){ 65 | # write tf record 66 | TFRecordName_im <- "./TutorialData_im.tfrecord" 67 | if( reSaveTFRecord ){ 68 | causalimages::WriteTfRecord( 69 | file = TFRecordName_im, 70 | uniqueImageKeys = unique(KeysOfObservations[ take_indices ]), 71 | acquireImageFxn = acquireImageFxn 72 | ) 73 | } 74 | 75 | # perform causal inference with image-based and tabular confounding 76 | if(T == T){ 77 | for(imageModelClass in (c("VisionTransformer","CNN"))){ 78 | for(optimizeImageRep in c(T,F)){ 79 | 
print(sprintf("Image confounding analysis & optimizeImageRep: %s & imageModelClass: %s",optimizeImageRep, imageModelClass)) 80 | ImageConfoundingAnalysis <- causalimages::AnalyzeImageConfounding( 81 | obsW = obsW[ take_indices ], 82 | obsY = obsY[ take_indices ], 83 | X = X[ take_indices,apply(X[ take_indices,],2,sd)>0], 84 | long = LongLat$geo_long[ take_indices ], # optional argument 85 | lat = LongLat$geo_lat[ take_indices ], # optional argument 86 | imageKeysOfUnits = KeysOfObservations[ take_indices ], 87 | file = TFRecordName_im, 88 | 89 | batchSize = 16L, 90 | nBoot = 5L, 91 | optimizeImageRep = optimizeImageRep, 92 | imageModelClass = imageModelClass, 93 | nDepth_ImageRep = ifelse(optimizeImageRep, yes = 1L, no = 1L), 94 | nWidth_ImageRep = as.integer(2L^6), 95 | learningRateMax = 0.001, nSGD = 300L, # 96 | dropoutRate = NULL, # 0.1, 97 | plotBands = c(1,2,3), 98 | plotResults = T, figuresTag = "ConfoundingImTutorial", 99 | figuresPath = "./") 100 | try(dev.off(), T) 101 | #ImageConfoundingAnalysis$ModelEvaluationMetrics 102 | } 103 | } 104 | # ATE estimate (image confounder adjusted) 105 | ImageConfoundingAnalysis$tauHat_propensityHajek 106 | 107 | # ATE se estimate (image confounder adjusted) 108 | ImageConfoundingAnalysis$tauHat_propensityHajek_se 109 | 110 | # some out-of-sample evaluation metrics 111 | ImageConfoundingAnalysis$ModelEvaluationMetrics 112 | } 113 | } 114 | 115 | # perform causal inference with image sequence and tabular confounding 116 | gc() 117 | if(T == T){ 118 | acquireVideoRep <- function(keys) { 119 | # Note: this is a toy function generating image representations 120 | # that simply reuse a single temporal slice. In practice, we will 121 | # weant to read in images of different time periods. 
122 | 123 | # Get image data as an array from disk 124 | tmp <- acquireImageFxn(keys) 125 | 126 | # Expand dimensions: we create a new dimension at the start 127 | tmp <- array(tmp, dim = c(1, dim(tmp))) 128 | 129 | # Transpose dimensions to get the target order 130 | tmp <- aperm(tmp, c(2, 1, 3, 4, 5)) 131 | 132 | # Swap image dimensions to see variability across time 133 | tmp_ <- aperm(tmp, c(1, 2, 4, 3, 5)) 134 | 135 | # Concatenate along the second axis 136 | tmp <- abind::abind(tmp, tmp, tmp_, tmp_, along = 2) 137 | 138 | return(tmp) 139 | } 140 | 141 | # sanity check dimensions 142 | print(dim(acquireVideoRep(unique(KeysOfObservations[ take_indices ])[1:2]))) 143 | 144 | # write tf record 145 | TFRecordName_imSeq <- "./TutorialData_imSeq.tfrecord" 146 | if( reSaveTFRecord ){ 147 | causalimages::WriteTfRecord( 148 | file = TFRecordName_imSeq, 149 | uniqueImageKeys = unique(KeysOfObservations[ take_indices ]), 150 | acquireImageFxn = acquireVideoRep, 151 | writeVideo = TRUE) 152 | } 153 | 154 | for(imageModelClass in c("VisionTransformer","CNN")){ 155 | for(optimizeImageRep in c(T, F)){ 156 | print(sprintf("Image seq confounding analysis & optimizeImageRep: %s & imageModelClass: %s",optimizeImageRep, imageModelClass)) 157 | ImageSeqConfoundingAnalysis <- causalimages::AnalyzeImageConfounding( 158 | obsW = obsW[ take_indices ], 159 | obsY = obsY[ take_indices ], 160 | X = X[ take_indices,apply(X[ take_indices,],2,sd)>0], 161 | long = LongLat$geo_long[ take_indices ], 162 | lat = LongLat$geo_lat[ take_indices ], 163 | file = TFRecordName_imSeq, dataType = "video", 164 | imageKeysOfUnits = KeysOfObservations[ take_indices ], 165 | 166 | # model specifics 167 | batchSize = 16L, 168 | optimizeImageRep = optimizeImageRep, 169 | imageModelClass = imageModelClass, 170 | nDepth_ImageRep = ifelse(optimizeImageRep, yes = 1L, no = 1L), 171 | nWidth_ImageRep = as.integer(2L^7), 172 | learningRateMax = 0.001, nSGD = 300L, # 173 | nBoot = 5L, 174 | plotBands = c(1,2,3), 175 
| plotResults = T, figuresTag = "ConfoundingImSeqTutorial", 176 | figuresPath = "./") # figures saved here 177 | try(dev.off(), T) 178 | } 179 | } 180 | 181 | # ATE estimate (image confounder adjusted) 182 | ImageSeqConfoundingAnalysis$tauHat_propensityHajek 183 | 184 | # ATE se estimate (image seq confounder adjusted) 185 | ImageSeqConfoundingAnalysis$tauHat_propensityHajek_se 186 | 187 | # some out-of-sample evaluation metrics 188 | ImageSeqConfoundingAnalysis$ModelEvaluationMetrics 189 | print("Done with confounding tutorial!") 190 | } 191 | } 192 | -------------------------------------------------------------------------------- /tutorials/AnalyzeImageConfounding_Tutorial_Base.R: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env Rscript 2 | { 3 | ################################ 4 | # Image confounding tutorial using causalimages 5 | ################################ 6 | 7 | # clean workspace 8 | rm(list=ls()); options(error = NULL) 9 | 10 | # setup environment 11 | if(Sys.getenv()["RSTUDIO_USER_IDENTITY"] == "cjerzak"){ 12 | setwd("~/Downloads/") 13 | } 14 | if(Sys.getenv()["RSTUDIO_USER_IDENTITY"] != "cjerzak"){ 15 | setwd("./") 16 | # or set directory as desired 17 | } 18 | 19 | # remote install latest version of the package if needed 20 | # devtools::install_github(repo = "cjerzak/causalimages-software/causalimages") 21 | 22 | # local install for development team 23 | # install.packages("~/Documents/causalimages-software/causalimages",repos = NULL, type = "source",force = F) 24 | 25 | # build backend you haven't ready (run this only once upon (re)installing causalimages!) 26 | # causalimages::BuildBackend() 27 | 28 | # load in package 29 | library( causalimages ) 30 | 31 | # resave TfRecords? 
32 | reSaveTFRecord <- TRUE 33 | 34 | # load in tutorial data 35 | data( CausalImagesTutorialData ) 36 | 37 | # mean imputation for toy example in this tutorial 38 | X <- apply(X[,-1],2,function(zer){ 39 | zer[is.na(zer)] <- mean( zer,na.rm = T ); return( zer ) 40 | }) 41 | 42 | # select observation subset to make the tutorial quick 43 | set.seed(4321L);take_indices <- 44 | unlist( tapply(1:length(obsW),obsW,function(zer){sample(zer, 300)}) ) 45 | 46 | # perform causal inference with image and tabular confounding 47 | { 48 | # example acquire image function (loading from memory) 49 | # in general, you'll want to write a function that returns images 50 | # that saved disk associated with keys 51 | acquireImageFxn <- function(keys){ 52 | # here, the function input keys 53 | # refers to the unit-associated image keys 54 | # we also tweak the image dimensions for testing purposes 55 | #m_ <- FullImageArray[match(keys, KeysOfImages),c(1:35,1:35),c(1:35,1:35),1:2] # test with two channels 56 | m_ <- FullImageArray[match(keys, KeysOfImages),c(1:35,1:35),c(1:35,1:35),] # test with three channels 57 | #m_ <- FullImageArray[match(keys, KeysOfImages),c(1:35,1:35),c(1:35,1:35),c(1:3,1:2)] # test with five channels 58 | 59 | # if keys == 1, add the batch dimension so output dims are always consistent 60 | # (here in image case, dims are batch by height by width by channel) 61 | if(length(keys) == 1){ m_ <- array(m_,dim = c(1L,dim(m_)[1],dim(m_)[2],dim(m_)[3])) } 62 | 63 | return( m_ ) 64 | } 65 | 66 | # look at one of the images 67 | causalimages::image2( FullImageArray[1,,,1] ) 68 | 69 | # write tf record 70 | # AnalyzeImageConfounding can efficiently stream batched image data from disk 71 | # (avoiding repeated in-memory loads and speeding up I/O during model training) 72 | TFRecordName_im <- "./TutorialData_im.tfrecord" 73 | if( reSaveTFRecord ){ 74 | causalimages::WriteTfRecord( 75 | file = TFRecordName_im, 76 | uniqueImageKeys = unique(KeysOfObservations[ take_indices ]), 
77 | acquireImageFxn = acquireImageFxn 78 | ) 79 | } 80 | 81 | # perform causal inference with image-based and tabular confounding 82 | imageModelClass <- "VisionTransformer" 83 | optimizeImageRep <- TRUE # train the model to predict treatment, for use in IPW 84 | print(sprintf("Image confounding analysis & optimizeImageRep: %s & imageModelClass: %s",optimizeImageRep, imageModelClass)) 85 | ImageConfoundingAnalysis <- causalimages::AnalyzeImageConfounding( 86 | # input data 87 | obsW = obsW[ take_indices ], 88 | obsY = obsY[ take_indices ], 89 | X = X[ take_indices,apply(X[ take_indices,],2,sd)>0], 90 | long = LongLat$geo_long[ take_indices ], # optional argument 91 | lat = LongLat$geo_lat[ take_indices ], # optional argument 92 | imageKeysOfUnits = KeysOfObservations[ take_indices ], 93 | file = TFRecordName_im, 94 | 95 | # modeling parameters 96 | batchSize = 16L, 97 | nBoot = 5L, 98 | optimizeImageRep = optimizeImageRep, 99 | imageModelClass = imageModelClass, 100 | nDepth_ImageRep = ifelse(optimizeImageRep, yes = 1L, no = 1L), 101 | nWidth_ImageRep = as.integer(2L^6), 102 | learningRateMax = 0.001, nSGD = 300L, # 103 | dropoutRate = NULL, # 0.1, 104 | plotBands = c(1,2,3), 105 | plotResults = T, figuresTag = "ConfoundingImTutorial", 106 | figuresPath = "./") 107 | try(dev.off(), T) 108 | 109 | # Analyze in/out sample metrics 110 | ImageConfoundingAnalysis$ModelEvaluationMetrics 111 | 112 | # ATE estimate (image confounder adjusted) 113 | ImageConfoundingAnalysis$tauHat_propensityHajek 114 | 115 | # ATE se estimate (image confounder adjusted) 116 | ImageConfoundingAnalysis$tauHat_propensityHajek_se 117 | 118 | # some out-of-sample evaluation metrics 119 | ImageConfoundingAnalysis$ModelEvaluationMetrics 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /tutorials/AnalyzeImageConfounding_Tutorial_EmbeddingsOnly.R: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env 
Rscript 2 | { 3 | ################################ 4 | # Image tutorial using embeddings 5 | ################################ 6 | 7 | # clean workspace 8 | rm(list=ls()); options(error = NULL) 9 | 10 | # setup environment 11 | setwd("~/Downloads/") 12 | 13 | # fetch image embeddings (requires an Internet connection) 14 | m_embeddings <- read.csv("https://huggingface.co/datasets/cjerzak/PCI_TutorialMaterial/resolve/main/nDepthIS1_analysisTypeISheterogeneity_imageModelClassISVisionTransformer_optimizeImageRepISclip-rsicd_MaxImageDimsIS64_dataTypeISimage_monte_iIS1_applicationISUganda_perturbCenterISFALSE.csv") 15 | 16 | # load data 17 | library(causalimages) 18 | data( CausalImagesTutorialData ) 19 | 20 | # image embedding dimensions - one image for each set of geo-located units 21 | # geo-location done by village name 22 | # pre-treatment image covariates via embeddings 23 | # obtained from an EO-fined tuned CLIP model 24 | # https://huggingface.co/flax-community/clip-rsicd-v2 25 | dim(m_embeddings) 26 | 27 | # correlation matrix 28 | cor(m_embeddings) 29 | 30 | # analyze via pca 31 | pca_anlaysis <- predict(prcomp(m_embeddings,scale = T, center = T)) 32 | 33 | # first two principal components 34 | plot(pca_anlaysis[,1:2]) 35 | 36 | # treatment indicator (recipient of YOP cash transfer) 37 | obsW 38 | 39 | # outcome - measure of human capital post intervention 40 | obsY 41 | 42 | # run double ML for image deconfounding 43 | # install.packages(c("DoubleML","mlr3","ranger")) 44 | library(DoubleML) # double/debiased ML framework 45 | library(mlr3) # core mlr3 infrastructure 46 | library(mlr3learners) # access a wide range of learners 47 | library(ranger) 48 | 49 | # Combine outcome, treatment, and image embeddings into a single data.frame 50 | df_dml <- data.frame( 51 | Y = obsY, # outcome vector 52 | W = obsW, # treatment indicator 53 | m_embeddings # precomputed image embeddings (one column per embedding dimension) 54 | ) 55 | 56 | # Create a DoubleMLData object 57 | 
dml_data <- DoubleMLData$new( 58 | data = df_dml, 59 | y_col = "Y", 60 | d_cols = "W", 61 | x_cols = colnames(m_embeddings) 62 | ) 63 | 64 | # Specify learners for the nuisance functions 65 | learner_g <- lrn("regr.ranger") # regression learner for E[Y|X,W] 66 | learner_m <- lrn("classif.ranger", predict_type = "prob") # classification learner for P[W=1|X] 67 | 68 | # Instantiate the partially linear regression (PLR) DML model 69 | dml_plr <- DoubleMLPLR$new( 70 | dml_data, 71 | ml_g = learner_g, 72 | ml_m = learner_m, 73 | n_folds = 5 # number of folds for cross-fitting 74 | ) 75 | 76 | # fit model 77 | dml_plr$fit() 78 | 79 | # extract results 80 | dml_plr$summary() # prints estimated ATE and standard error 81 | 82 | } 83 | -------------------------------------------------------------------------------- /tutorials/AnalyzeImageConfounding_Tutorial_Simulation.R: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env Rscript 2 | { 3 | ################################ 4 | # Simulated causal data with images using causalimages 5 | ################################ 6 | 7 | # clean workspace 8 | rm(list=ls()); options(error = NULL) 9 | 10 | # setup environment 11 | # setwd as needed 12 | 13 | # Install causalimages if not installed 14 | # devtools::install_github(repo = "cjerzak/causalimages-software/causalimages") 15 | 16 | # load package 17 | library(causalimages) 18 | 19 | # Build backend if not already done 20 | # causalimages::BuildBackend() 21 | 22 | # Simulation parameters 23 | n <- 1000 # number of units 24 | tau <- 2.0 # true ATE (constant treatment effect) 25 | image_dim <- 64L # image size 26 | 27 | # Generate synthetic data 28 | set.seed(12321) 29 | 30 | # Latent confounder C that generates images 31 | C <- rnorm(n) 32 | 33 | # Tabular confounder X 34 | X <- matrix(rnorm(n * 2), ncol = 2) 35 | colnames(X) <- c("X1", "X2") 36 | 37 | # Note: we have a dual confounder setup here (tabular + image confounders) 38 | 39 | 
# Treatment assignment: depends on C and X 40 | logit_prob <- C + X[,1] + X[,2] 41 | prob_W <- 1 / (1 + exp(-logit_prob)) 42 | obsW <- rbinom(n, 43 | size = 1, 44 | prob = prob_W) 45 | 46 | # Outcome: depends on both confounders and treatment and noise 47 | obsY <- tau * obsW + C + X[,1] + X[,2] + rnorm(n) 48 | 49 | # Generate synthetic images based on C 50 | # For simplicity, create images where intensity depends on C 51 | # We'll make a 32x32x3 array, with red channel proportional to C 52 | KeysOfImages <- paste0("img_", 1:n) # unique keys 53 | FullImageArray <- array(0, dim = c(n, image_dim, image_dim, 3)) 54 | for(i in 1:n) { 55 | base_intensity <- (C[i] - min(C)) / (max(C) - min(C)) # normalize to [0,1] 56 | FullImageArray[i,,,1] <- base_intensity # red channel 57 | FullImageArray[i,,,2] <- runif(1) # green random 58 | FullImageArray[i,,,3] <- runif(1) # blue random 59 | # Add a non-causal pattern to image, e.g., a gradient 60 | for(row in 1:image_dim) { 61 | FullImageArray[i, row,,] <- FullImageArray[i, row,,] * log(1+row / image_dim)^runif(n=1,min=0,max=1/i) 62 | } 63 | } 64 | 65 | # Keys of observations 66 | KeysOfObservations <- KeysOfImages # one-to-one for simplicity 67 | 68 | # Define acquireImageFxn 69 | acquireImageFxn <- function(keys) { 70 | m_ <- FullImageArray[match(keys, KeysOfImages),,,] 71 | if(length(keys) == 1) { 72 | m_ <- array(m_, dim = c(1L, dim(m_)[1], dim(m_)[2], dim(m_)[3])) 73 | } 74 | return(m_) 75 | } 76 | 77 | # run once 78 | # causalimages::BuildBackend() 79 | 80 | # Look at one image 81 | causalimages::image2(FullImageArray[sample(1:1000,1),,,1]) 82 | 83 | # Write TFRecord (optional, but as in tutorial) 84 | reSaveTFRecord <- TRUE 85 | TFRecordName_im <- "~/Downloads/SimulatedData_im.tfrecord" 86 | if(reSaveTFRecord) { 87 | causalimages::WriteTfRecord( 88 | file = TFRecordName_im, 89 | uniqueImageKeys = unique(KeysOfObservations), 90 | acquireImageFxn = acquireImageFxn 91 | ) 92 | } 93 | 94 | # Perform causal inference with image 
and tabular confounding 95 | imageModelClass <- "VisionTransformer" 96 | optimizeImageRep <- TRUE 97 | print(sprintf("Image confounding analysis & optimizeImageRep: %s & imageModelClass: %s", optimizeImageRep, imageModelClass)) 98 | 99 | ImageConfoundingAnalysis <- causalimages::AnalyzeImageConfounding( 100 | # input data 101 | obsW = obsW, 102 | obsY = obsY, 103 | X = X[, apply(X, 2, sd) > 0], 104 | imageKeysOfUnits = KeysOfObservations, 105 | file = TFRecordName_im, 106 | 107 | # modeling parameters 108 | batchSize = 16L, 109 | nBoot = 5L, 110 | optimizeImageRep = TRUE, 111 | imageModelClass = imageModelClass, 112 | nDepth_ImageRep = 4L, 113 | nWidth_ImageRep = as.integer(2^8), 114 | learningRateMax = 0.0001, nSGD = 300L, 115 | dropoutRate = 0.1, 116 | plotBands = c(1,2,3), 117 | plotResults = TRUE, figuresTag = "SimConfoundingIm", 118 | figuresPath = "./" 119 | ) 120 | try(dev.off(), TRUE) 121 | 122 | # Output results 123 | print("True ATE:") 124 | print(tau) 125 | 126 | print("Estimated ATE (no confounder adjustment):") 127 | print(ImageConfoundingAnalysis$tauHat_diffInMeans) 128 | 129 | print("Estimated ATE (image+tabular confounder adjusted):") 130 | print(ImageConfoundingAnalysis$tauHat_propensityHajek) 131 | 132 | 133 | print("Estimated SE:") 134 | print(ImageConfoundingAnalysis$tauHat_propensityHajek_se) 135 | 136 | print("Model Evaluation Metrics:") 137 | print(ImageConfoundingAnalysis$ModelEvaluationMetrics) 138 | 139 | # Comparison 140 | bias_naive <- ImageConfoundingAnalysis$tauHat_diffInMeans - tau 141 | print("Bias (diff in means):") 142 | print(bias_naive) # around 2 143 | 144 | bias_adjusted <- ImageConfoundingAnalysis$tauHat_propensityHajek - tau 145 | print("Bias (image+tabular deconfounding):") 146 | print(bias_adjusted) # around 0.4 147 | } 148 | -------------------------------------------------------------------------------- /tutorials/AnalyzeImageHeterogeneity_Tutorial.R: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env Rscript 2 | { 3 | ################################ 4 | # Image heterogeneity tutorial using causalimages 5 | ################################ 6 | 7 | # clean environment 8 | rm(list = ls()); options( error = NULL ) 9 | 10 | # remote install latest version of the package 11 | # devtools::install_github(repo = "cjerzak/causalimages-software/causalimages") 12 | 13 | # local install for development team 14 | # install.packages("~/Documents/causalimages-software/causalimages",repos = NULL, type = "source",force = F) 15 | 16 | # build backend you haven't ready: 17 | # causalimages::BuildBackend() 18 | 19 | # run code if downloading data for the first time 20 | download_folder <- "~/Downloads/UgandaAnalysis.zip" 21 | reSaveTfRecords <- T 22 | if( reDownloadRawData <- F ){ 23 | # specify uganda data URL 24 | uganda_data_url <- "https://dl.dropboxusercontent.com/s/xy8xvva4i46di9d/Public%20Replication%20Data%2C%20YOP%20Experiment.zip?dl=0" 25 | download_folder <- "~/Downloads/UgandaAnalysis.zip" 26 | 27 | # download into new directory 28 | download.file( uganda_data_url, destfile = download_folder) 29 | 30 | # unzip and list files 31 | unzip(download_folder, exdir = "~/Downloads/UgandaAnalysis") 32 | } 33 | 34 | # load in package 35 | library( causalimages ) 36 | 37 | # set new wd 38 | setwd(sprintf('%s/Public Replication Data, YOP Experiment/', 39 | gsub(download_folder,pattern="\\.zip",replace=""))) 40 | 41 | # see directory contents 42 | list.files() 43 | 44 | # images saved here 45 | list.files( "./Uganda2000_processed" ) 46 | 47 | # individual-level data 48 | UgandaDataProcessed <- read.csv( "./UgandaDataProcessed.csv" ) 49 | 50 | # unit-level covariates (many covariates are subject to missingness!) 
51 | dim( UgandaDataProcessed ) 52 | table( UgandaDataProcessed$age ) 53 | 54 | # approximate longitude + latitude for units 55 | head( cbind(UgandaDataProcessed$geo_long, UgandaDataProcessed$geo_lat) ) 56 | 57 | # image keys of units (use for referencing satellite images) 58 | UgandaDataProcessed$geo_long_lat_key 59 | 60 | # an experimental outcome 61 | UgandaDataProcessed$Yobs 62 | 63 | # treatment variable 64 | UgandaDataProcessed$Wobs 65 | 66 | # information on keys linking to satellite images for all of Uganda 67 | # (not just experimental context, use for constructing transportability maps) 68 | UgandaGeoKeyMat <- read.csv( "./UgandaGeoKeyMat.csv" ) 69 | 70 | # set outcome 71 | UgandaDataProcessed$Yobs <- UgandaDataProcessed$human_capital_index_e 72 | 73 | # drop observations with NAs in key variables 74 | # (you can also use a multiple imputation strategy) 75 | UgandaDataProcessed <- UgandaDataProcessed[!is.na(UgandaDataProcessed$Yobs) & 76 | !is.na(UgandaDataProcessed$Wobs) & 77 | !is.na(UgandaDataProcessed$geo_lat) , ] 78 | 79 | # sanity checks 80 | { 81 | # write a function that reads in images as saved and process them into an array 82 | NBANDS <- 3L 83 | imageHeight <- imageWidth <- 351L # pixel height/width 84 | acquireImageRep <- function( keys ){ 85 | # keys <- unique(UgandaDataProcessed$geo_long_lat_key)[1:5] 86 | # initialize an array shell to hold image slices 87 | array_shell <- array(NA, dim = c(1L, imageHeight, imageWidth, NBANDS)) 88 | # dim(array_shell) 89 | 90 | # iterate over keys: 91 | # -- images are referenced to keys 92 | # -- keys are referenced to units (to allow for duplicate images uses) 93 | array_ <- sapply(keys, function(key_) { 94 | # iterate over all image bands (NBANDS = 3 for RBG images) 95 | for (band_ in 1:NBANDS) { 96 | # place the image in the correct place in the array 97 | array_shell[,,,band_] <- 98 | as.matrix(data.table::fread( 99 | input = sprintf("./Uganda2000_processed/GeoKey%s_BAND%s.csv", key_, band_), header = 
FALSE)[-1,]) 100 | } 101 | return(array_shell) 102 | }, simplify = "array") 103 | 104 | # return the array in the format c(nBatch, imageWidth, imageHeight, nChannels) 105 | # ensure that the dimensions are correctly ordered for further processing 106 | if(length(keys) > 1){ array_ <- aperm(array_[1,,,,], c(4, 1, 2, 3) ) } 107 | if(length(keys) == 1){ 108 | array_ <- aperm(array_, c(1,5, 2, 3, 4)) 109 | array_ <- array(array_, dim(array_)[-1]) 110 | } 111 | return(array_) 112 | } 113 | 114 | # try out the function 115 | # note: some units are co-located in same area (hence, multiple observations per image key) 116 | ImageBatch <- acquireImageRep( UgandaDataProcessed$geo_long_lat_key[ check_indices <- c(1, 20, 50, 101) ]) 117 | acquireImageRep( UgandaDataProcessed$geo_long_lat_key[ check_indices[1:2] ] ) 118 | 119 | # sanity checks in the analysis of earth observation data are essential 120 | # check that images are centered around correct location 121 | causalimages::image2( as.array(ImageBatch)[1,,,1] ) 122 | UgandaDataProcessed$geo_long[check_indices[1]] 123 | UgandaDataProcessed$geo_lat[check_indices[1]] 124 | # check against google maps to confirm correctness 125 | # https://www.google.com/maps/place/1%C2%B018'16.4%22N+34%C2%B005'15.1%22E/@1.3111951,34.0518834,10145m/data=!3m1!1e3!4m4!3m3!8m2!3d1.3045556!4d34.0875278?entry=ttu 126 | 127 | # scramble data (important for reading into causalimages::WriteTfRecord 128 | # to ensure no systematic biases in data sequence with model training 129 | set.seed(144L); UgandaDataProcessed <- UgandaDataProcessed[sample(1:nrow(UgandaDataProcessed)),] 130 | } 131 | 132 | # Image heterogeneity example 133 | if(T == T){ 134 | # write a tf records repository 135 | # whenever changes are made to the input data to AnalyzeImageHeterogeneity, WriteTfRecord() should be re-run 136 | # to ensure correct ordering of data 137 | tfrecord_loc <- "~/Downloads/UgandaExample.tfrecord" 138 | if( reSaveTfRecords ){ 139 | 
causalimages::WriteTfRecord( 140 | file = tfrecord_loc, 141 | uniqueImageKeys = unique(UgandaDataProcessed$geo_long_lat_key), 142 | acquireImageFxn = acquireImageRep ) 143 | } 144 | 145 | for(ImageModelClass in c("VisionTransformer","CNN")){ 146 | for(optimizeImageRep in c(T, F)){ 147 | print(sprintf("Image hetero analysis & optimizeImageRep: %s",optimizeImageRep)) 148 | # perform image heterogeneity analysis (toy example) 149 | ImageHeterogeneityResults <- causalimages::AnalyzeImageHeterogeneity( 150 | # data inputs 151 | obsW = UgandaDataProcessed$Wobs, 152 | obsY = UgandaDataProcessed$Yobs, 153 | X = matrix(rnorm(length(UgandaDataProcessed$Yobs)*10),ncol=10), 154 | imageKeysOfUnits = UgandaDataProcessed$geo_long_lat_key, 155 | file = tfrecord_loc, # location of tf record (use absolute file paths) 156 | lat = UgandaDataProcessed$geo_lat, # not required but helpful for dealing with redundant locations in EO data 157 | long = UgandaDataProcessed$geo_long, # not required but helpful for dealing with redundant locations in EO data 158 | 159 | # inputs to control where visual results are saved as PDF or PNGs 160 | # (these image grids are large and difficult to display in RStudio's interactive mode) 161 | plotResults = T, 162 | figuresPath = "~/Downloads/HeteroTutorial", # where to write analysis figures 163 | figuresTag = "HeterogeneityImTutorial",plotBands = 1L:3L, 164 | 165 | # optional arguments for generating transportability maps 166 | # here, we leave those NULL for simplicity 167 | transportabilityMat = NULL, # 168 | 169 | # other modeling options 170 | imageModelClass = ImageModelClass, 171 | nSGD = 5L, # make this larger for real applications (e.g., 2000L) 172 | nDepth_ImageRep = ifelse(optimizeImageRep, yes = 1L, no = 1L), 173 | nWidth_ImageRep = as.integer(2L^6), 174 | optimizeImageRep = optimizeImageRep, 175 | batchSize = 8L, # make this larger for real application (e.g., 50L) 176 | kClust_est = 2 # vary depending on problem. 
Usually < 5 177 | ) 178 | try(dev.off(), T) 179 | } 180 | } 181 | } 182 | 183 | # image sequence example 184 | if(T == T){ 185 | acquireVideoRep <- function(keys) { 186 | # Get image data as an array from disk 187 | tmp <- acquireImageRep(keys) 188 | 189 | # Expand dimensions: we create a new dimension at the start 190 | tmp <- array(tmp, dim = c(1, dim(tmp))) 191 | 192 | # Transpose dimensions to get the required order 193 | tmp <- aperm(tmp, c(2, 1, 3, 4, 5)) 194 | 195 | # Swap image dimensions to see variability across time 196 | tmp_ <- aperm(tmp, c(1, 2, 4, 3, 5)) 197 | 198 | # Concatenate along the second axis 199 | tmp <- abind::abind(tmp, tmp_, along = 2) 200 | 201 | return(tmp) 202 | } 203 | 204 | # write the tf records repository 205 | tfrecord_loc_imSeq <- "~/Downloads/UgandaExampleVideo.tfrecord" 206 | if(reSaveTfRecords){ 207 | causalimages::WriteTfRecord( file = tfrecord_loc_imSeq, 208 | uniqueImageKeys = unique(UgandaDataProcessed$geo_long_lat_key), 209 | acquireImageFxn = acquireVideoRep, writeVideo = T ) 210 | } 211 | 212 | for(ImageModelClass in (c("VisionTransformer","CNN"))){ 213 | for(optimizeImageRep in c(T, F)){ 214 | print(sprintf("Image seq hetero analysis & optimizeImageRep: %s",optimizeImageRep)) 215 | # Note: optimizeImageRep = T breaks with video on METAL framework 216 | VideoHeterogeneityResults <- causalimages::AnalyzeImageHeterogeneity( 217 | # data inputs 218 | obsW = UgandaDataProcessed$Wobs, 219 | obsY = UgandaDataProcessed$Yobs, 220 | imageKeysOfUnits = UgandaDataProcessed$geo_long_lat_key, 221 | file = tfrecord_loc_imSeq, # location of tf record (absolute paths are safest) 222 | dataType = "video", 223 | lat = UgandaDataProcessed$geo_lat, # not required but helpful for dealing with redundant locations in EO data 224 | long = UgandaDataProcessed$geo_long, # not required but helpful for dealing with redundant locations in EO data 225 | 226 | # inputs to control where visual results are saved as PDF or PNGs 227 | # (these image 
grids are large and difficult to display in RStudio's interactive mode) 228 | plotResults = T, 229 | figuresPath = "~/Downloads/HeteroTutorial", 230 | plotBands = 1L:3L, figuresTag = "HeterogeneityImSeqTutorial", 231 | 232 | # optional arguments for generating transportability maps 233 | # here, we leave those NULL for simplicity 234 | transportabilityMat = NULL, # 235 | 236 | # other modeling options 237 | imageModelClass = ImageModelClass, 238 | nSGD = 5L, # make this larger for real applications (e.g., 2000L) 239 | nDepth_ImageRep = ifelse(optimizeImageRep, yes = 1L, no = 1L), 240 | nWidth_ImageRep = as.integer(2L^5), 241 | optimizeImageRep = optimizeImageRep, 242 | kClust_est = 2, # vary depending on problem. Usually < 5 243 | batchSize = 8L, # make this larger for real application (e.g., 50L) 244 | strides = 2L ) 245 | try(dev.off(), T) 246 | } 247 | } 248 | } 249 | 250 | causalimages::print2("Done with image heterogeneity tutorial!") 251 | } 252 | -------------------------------------------------------------------------------- /tutorials/BuildBackend_Tutorial.R: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env Rscript 2 | { 3 | # clear workspace 4 | rm(list = ls()); options(error = NULL) 5 | 6 | # remote install latest version of the package 7 | # devtools::install_github(repo = "cjerzak/causalimages-software/causalimages") 8 | 9 | # local install for development team 10 | # install.packages("~/Documents/causalimages-software/causalimages",repos = NULL, type = "source",force = F) 11 | 12 | # in general, you will simply use: 13 | causalimages::BuildBackend() 14 | 15 | # Note: This function requires an Internet connection 16 | # Note: With default arguments, a conda environment called 17 | # "CausalImagesEnv" will be created with required packages saved within. 
18 | 19 | # Advanced tip: 20 | # if you need to point to the location of a specific version of python 21 | # in conda where packages will be downloaded, 22 | # you can use: 23 | #causalimages::BuildBackend(conda = "/Users/cjerzak/miniforge3/bin/python") 24 | } 25 | 26 | -------------------------------------------------------------------------------- /tutorials/ExtractImageRepresentations_Tutorial.R: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env Rscript 2 | { 3 | ################################ 4 | # Image and image-sequence embeddings tutorial using causalimages 5 | ################################ 6 | 7 | # start with a clean environment 8 | rm(list=ls()); options(error = NULL) 9 | 10 | # remote install latest version of the package if needed 11 | # devtools::install_github(repo = "cjerzak/causalimages-software/causalimages") 12 | 13 | # local install for development team 14 | # install.packages("~/Documents/causalimages-software/causalimages",repos = NULL, type = "source",force = F) 15 | 16 | # build backend if you haven't already: 17 | # causalimages::BuildBackend() 18 | 19 | # load in package 20 | library( causalimages ) 21 | 22 | # load in tutorial data 23 | data( CausalImagesTutorialData ) 24 | 25 | # example acquire image function (loading from memory) 26 | # in general, you'll want to write a function that returns images 27 | # that are saved to disk, associated with keys 28 | acquireImageFromMemory <- function(keys){ 29 | # here, the function input keys 30 | # refers to the unit-associated image keys 31 | m_ <- FullImageArray[match(keys, KeysOfImages),,,] 32 | 33 | # if keys == 1, add the batch dimension so output dims are always consistent 34 | # (here in image case, dims are batch by height by width by channel) 35 | if(length(keys) == 1){ 36 | m_ <- array(m_,dim = c(1L,dim(m_)[1],dim(m_)[2],dim(m_)[3])) 37 | } 38 | return( m_ ) 39 | } 40 | 41 | # drop first column 42 | X <- X[,-1] 43 | 44 | # mean imputation for 
simplicity 45 | X <- apply(X,2,function(zer){ 46 | zer[is.na(zer)] <- mean( zer,na.rm = T ) 47 | return( zer ) 48 | }) 49 | 50 | # select observation subset to make tutorial analyses run fast 51 | take_indices <- unlist( tapply(1:length(obsW),obsW,function(zer){ sample(zer, 50) }) ) 52 | 53 | # write tf record 54 | TfRecord_name <- "~/Downloads/CausalImagesTutorialDat.tfrecord" 55 | causalimages::WriteTfRecord( file = TfRecord_name, 56 | uniqueImageKeys = unique( KeysOfObservations[ take_indices ] ), 57 | acquireImageFxn = acquireImageFromMemory ) 58 | 59 | # obtain image representation (random neural projection) 60 | MyImageEmbeddings_RandomProj <- causalimages::GetImageRepresentations( 61 | file = TfRecord_name, 62 | imageKeysOfUnits = KeysOfObservations[ take_indices ], 63 | nDepth_ImageRep = 1L, 64 | nWidth_ImageRep = 128L, 65 | CleanupEnv = T) 66 | 67 | # sanity check - # of rows in MyImageEmbeddings matches # of image keys 68 | nrow(MyImageEmbeddings_RandomProj$ImageRepresentations) == length(KeysOfObservations[ take_indices ]) 69 | 70 | # each row in MyImageEmbeddings$ImageRepresentations corresponds to an observation 71 | # each column represents an embedding dimension associated with the imagery for that location 72 | plot( MyImageEmbeddings_RandomProj$ImageRepresentations ) 73 | 74 | # obtain image representation (pre-trained ViT) 75 | MyImageEmbeddings_ViT <- causalimages::GetImageRepresentations( 76 | file = TfRecord_name, 77 | pretrainedModel = "vit-base", 78 | imageKeysOfUnits = KeysOfObservations[ take_indices ], 79 | CleanupEnv = T) 80 | 81 | # analyze ViT representations 82 | plot( MyImageEmbeddings_ViT$ImageRepresentations ) 83 | 84 | # obtain image representation (pre-trained CLIP-RCSID) 85 | MyImageEmbeddings_Clip <- causalimages::GetImageRepresentations( 86 | file = TfRecord_name, 87 | pretrainedModel = "clip-rsicd", 88 | imageKeysOfUnits = KeysOfObservations[ take_indices ], 89 | CleanupEnv = T) 90 | 91 | # analyze Clip representations 92 | 
plot( MyImageEmbeddings_Clip$ImageRepresentations ) 93 | 94 | # sanity check - # of rows in MyImageEmbeddings matches # of image keys 95 | nrow(MyImageEmbeddings_Clip$ImageRepresentations) == length(KeysOfObservations[ take_indices ]) 96 | 97 | # other output quantities include the image model functions and model parameters 98 | names( MyImageEmbeddings_Clip )[-1] 99 | 100 | print("Done with image representations tutorial!") 101 | } 102 | -------------------------------------------------------------------------------- /tutorials/UsingTfRecords_Tutorial.R: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env Rscript 2 | { 3 | # clear workspace 4 | rm(list=ls()); options(error = NULL) 5 | 6 | ################################ 7 | # Image confounding tutorial using causalimages & tfrecords 8 | 9 | # remote install latest version of the package if needed 10 | # devtools::install_github(repo = "cjerzak/causalimages-software/causalimages") 11 | 12 | # local install for development team 13 | # install.packages("~/Documents/causalimages-software/causalimages",repos = NULL, type = "source",force = F) 14 | 15 | # build backend if you haven't already: 16 | # causalimages::BuildBackend() 17 | 18 | # load in package 19 | library( causalimages ) 20 | 21 | # load in tutorial data 22 | data( CausalImagesTutorialData ) 23 | 24 | # example acquire image function (loading from memory) 25 | # in general, you'll want to write a function that returns images 26 | # that are saved to disk, associated with keys 27 | acquireImageFromMemory <- function(keys){ 28 | # here, the function input keys 29 | # refers to the unit-associated image keys 30 | m_ <- FullImageArray[match(keys, KeysOfImages),,,] 31 | 32 | # if keys == 1, add the batch dimension so output dims are always consistent 33 | # (here in image case, dims are batch by height by width by channel) 34 | if(length(keys) == 1){ 35 | m_ <- array(m_,dim = c(1L,dim(m_)[1],dim(m_)[2],dim(m_)[3])) 36 | } 37 | 
38 | # uncomment for a test with different image dimensions 39 | #if(length(keys) == 1){ m_ <- abind::abind(m_,m_,m_,along = 3L) }; if(length(keys) > 1){ m_ <- abind::abind(m_,m_,m_,.along = 4L) } 40 | return( m_ ) 41 | } 42 | 43 | dim( acquireImageFromMemory(KeysOfImages[1]) ) 44 | dim( acquireImageFromMemory(KeysOfImages[1:2]) ) 45 | 46 | # drop first column 47 | X <- X[,-1] 48 | 49 | # mean imputation for simplicity 50 | X <- apply(X,2,function(zer){ 51 | zer[is.na(zer)] <- mean( zer,na.rm = T ) 52 | return( zer ) 53 | }) 54 | 55 | # select observation subset to make tutorial analyses run faster 56 | # select 50 treatment and 50 control observations 57 | set.seed(1.) 58 | take_indices <- unlist( tapply(1:length(obsW),obsW,function(zer){sample(zer, 50)}) ) 59 | 60 | # !!! important note !!! 61 | # when using tf recordings, it is essential that the data inputs be pre-shuffled like is done here. 62 | # you can use a seed for reproducing the shuffle (so the tfrecord is correctly indexed and you don't need to re-make it) 63 | # tf records read data quasi-sequentially, so systematic patterns in the data ordering 64 | # reduce performance 65 | 66 | # uncomment for a larger n analysis 67 | #take_indices <- 1:length( obsY ) 68 | 69 | # set tfrecord save location (best to use absolute paths, but relative paths should in general work too) 70 | tfrecord_loc <- "~/Downloads/ExampleRecord.tfrecord" 71 | 72 | # write a tf records repository 73 | causalimages::WriteTfRecord( file = tfrecord_loc, 74 | uniqueImageKeys = unique( as.character(KeysOfObservations)[ take_indices ] ), 75 | acquireImageFxn = acquireImageFromMemory ) 76 | 77 | # perform causal inference with image and tabular confounding 78 | # toy example for illustration purposes where 79 | # treatment is truly randomized 80 | ImageConfoundingAnalysis <- causalimages::AnalyzeImageConfounding( 81 | obsW = obsW[ take_indices ], 82 | obsY = obsY[ take_indices ], 83 | X = X[ take_indices,apply(X[ take_indices,],2,sd)>0], 
84 | long = LongLat$geo_long[ take_indices ], 85 | lat = LongLat$geo_lat[ take_indices ], 86 | batchSize = 16L, 87 | imageKeysOfUnits = as.character(KeysOfObservations)[ take_indices ], 88 | file = tfrecord_loc, # point to tfrecords file 89 | 90 | nSGD = 500L, 91 | imageModelClass = "CNN", 92 | #imageModelClass = "VisionTransformer", 93 | plotBands = c(1,2,3), 94 | figuresTag = "TutorialExample", 95 | figuresPath = "~/Downloads/TFRecordTutorial" # figures saved here (use absolute file paths) 96 | ) 97 | 98 | # ATE estimate (image confounder adjusted) 99 | ImageConfoundingAnalysis$tauHat_propensityHajek 100 | 101 | # ATE se estimate (image confounder adjusted) 102 | ImageConfoundingAnalysis$tauHat_propensityHajek_se 103 | 104 | # see figuresPath for image analysis output 105 | causalimages::print2("Done with TfRecords tutorial!") 106 | } 107 | --------------------------------------------------------------------------------