├── .editorconfig ├── .gitignore ├── .pylintrc ├── CHANGELOG.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SparseSC.sln ├── _config.yml ├── docs ├── api_ref.md ├── azure_batch.md ├── conf.py ├── dev_notes.md ├── estimate-effects.md ├── examples │ ├── estimate_effects.md │ ├── prospective_anomaly_detection.md │ ├── retrospecitve_restricted_synthetic_controls.md │ └── retrospecitve_synthetic_controls.md ├── fit.md ├── index.rst ├── model-types.md ├── overview.md ├── performance-notes.md ├── rtd-base.txt └── rtd-requirements.txt ├── example-code.py ├── examples ├── AATest.ipynb ├── DifferentialTrends.ipynb ├── __init__.py ├── example_graphs.py └── strip_magic.py ├── makefile ├── replication ├── .gitignore ├── figs │ ├── sparsesc_fast_pe.pdf │ ├── sparsesc_fast_xf2_pe.pdf │ ├── sparsesc_fast_xf_pe.pdf │ ├── sparsesc_full_pe.pdf │ ├── sparsesc_full_xf2_pe.pdf │ ├── sparsesc_full_xf_pe.pdf │ ├── standard_flat55_pe.pdf │ ├── standard_nested_pe.pdf │ ├── standard_spfast_pe.pdf │ ├── standard_spfast_xf2_pe.pdf │ ├── standard_spfast_xf_pe.pdf │ ├── standard_spfull_pe.pdf │ ├── standard_spfull_xf2_pe.pdf │ └── standard_spfull_xf_pe.pdf ├── repl2010.do ├── repl2010.log ├── sc2010.ipynb ├── smoking.dta ├── vmats │ ├── fast_fit.txt │ ├── full_fit.txt │ ├── xf_fits_fast.txt │ ├── xf_fits_fast2.txt │ ├── xf_fits_full.txt │ └── xf_fits_full2.txt └── w_pen_switch.ipynb ├── setup.py ├── src └── SparseSC │ ├── .pylintrc │ ├── SparseSC.pyproj │ ├── __init__.py │ ├── cli │ ├── __init__.py │ ├── daemon.py │ ├── daemon_process.py │ ├── scgrad.py │ └── stt.py │ ├── cross_validation.py │ ├── estimate_effects.py │ ├── fit.py │ ├── fit_ct.py │ ├── fit_fast.py │ ├── fit_fold.py │ ├── fit_loo.py │ ├── optimizers │ ├── __init__.py │ ├── cd_line_search.py │ └── simplex_step.py │ ├── tensor.py │ ├── utils │ ├── AzureBatch │ │ ├── __init__.py │ │ ├── aggregate_results.py │ │ ├── azure_batch_client.py │ │ ├── build_batch_job.py │ │ ├── constants.py │ │ └── gradient_batch_client.py │ ├── __init__.py │ ├── batch_gradient.py │ ├── descr_sets.py │ ├── dist_summary.py │ ├── local_grad_daemon.py │ ├── match_space.py │ ├── metrics_utils.py │ ├── misc.py │ ├── ols_model.py │ ├── penalty_utils.py │ ├── print_progress.py │ ├── sub_matrix_inverse.py │ └── warnings.py │ └── weights.py ├── static ├── controls-only-pre-and-post.png └── pre-only-controls-and-treated.png └── test ├── AzureBatch ├── README.md ├── test_batch_aggregate.py ├── test_batch_build.py └── test_batch_run.py ├── CausalImpact_test.R ├── SparseSC_36.yml ├── __init__.py ├── dgp ├── __init__.py ├── factor_model.py └── group_effects.py ├── test.pyproj ├── test_batchFile.py ├── test_estimation.py ├── test_fit.py ├── test_fit_batch.py ├── test_normal.py └── test_simulation.py /.editorconfig: -------------------------------------------------------------------------------- 1 | # EditorConfig helps developers define and maintain consistent coding styles between different editors and IDEs. 
http://EditorConfig.org 2 | 3 | # top-most EditorConfig file 4 | root = true 5 | 6 | # 4 space indentation 7 | [*.py] 8 | indent_style = space 9 | indent_size = 4 10 | insert_final_newline = true 11 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # ignore testing data 2 | **/data/ 3 | replication/*.pkl 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Sphinx documentation 55 | docs/_build/ 56 | 57 | # Jupyter Notebook 58 | .ipynb_checkpoints 59 | 60 | # pyenv 61 | .python-version 62 | 63 | # Environments 64 | .env 65 | .venv 66 | env/ 67 | venv/ 68 | ENV/ 69 | env.bak/ 70 | venv.bak/ 71 | 72 | # Spyder project settings 73 | .spyderproject 74 | .spyproject 75 | 76 | # mypy 77 | .mypy_cache/ 78 | 79 | #Visual Studio 80 | .vs/ 81 | 82 | .vscode/ 83 | 84 | # own-generated content 85 | docs/SyntheticControlsReadme.* 86 | examples/*.pdf 87 | examples/*.html 88 | examples/DifferentialTrends.py 89 | 90 | #Jupyter Windows 91 | jupyter_launch.log 92 | 93 | tmp/ 94 | 95 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | All notable changes to this project will be documented in this file. 3 | 4 | The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) 5 | and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). 6 | 7 | 8 | ## [Unreleased](https://github.com/Microsoft/SparseSC/compare/v0.2.0...master) 9 | ### Added 10 | - Added `fit_args` to `estimate_effects()` that can be used with `fit_fast()` with the default variable weight algorithm (`sklearn`'s [MTLassoCV](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.MultiTaskLassoCV.html)). Common usages would include increasing `max_iter` (if you want to improve convergence) or `n_jobs` (if you want to run in parallel). 11 | - Separated CV folds from cross-fitting folds in `estimate_effects()`. `cv_folds` controls the amount of additional estimation done on the control units (this used to be controlled by `max_n_pl`, but that parameter now only governs the amount of post-estimation processing that is done). This will do `cv_folds` extra estimations per treatment time period, though if `=1` then no extra work will be done (but control residuals might be biased toward 0). 12 | - Added additional option `Y_col_block_size` to `MTLassoCV_MatchSpace_factory` to estimate `V` on block-averages of `Y` (e.g. taking 150 columns down to 5 by averaging over 30 columns at a time).
13 | - Added `se_factor` to `MTLassoCV_MatchSpace_factory` to use a penalty different from the MSE-minimizing one. 14 | - For large data, approximate the outcomes using a normal distribution (`DescrSet`), and allow for calculating estimates. 15 | 16 | ## 0.2.0 - 2020-05-06 17 | ### Added 18 | - Added tools to use `fit_fast()` with large datasets. This includes the `sample_frac` option to `MTLassoCV_MatchSpace_factory()` to estimate the match space on a subset of the observations. It also includes the `fit_fast()` option `avoid_NxN_mats`, which will avoid making large matrices (at the expense of only returning the synthetic control `targets` and `targets_aux` and not the full weight matrix). 19 | - Added logging in `fit_fast` via the `verbose` numerical option. This can help identify out-of-memory errors. 20 | - Added a pseudo-doubly-robust match space maker `D_LassoCV_MatchSpace_factory`. It apportions some of the normalized variable V weight to those variables that are good predictors of treatment. This should only be done if there are many treated units so that one can reasonably model this relationship. 21 | - Switched to using the standardized Azure Batch config library. 22 | 23 | 24 | ## 0.1.0 - 2019-07-25 25 | Initial release. 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | MIT License 3 | 4 | Copyright (c) Microsoft Corporation. All rights reserved. 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [Sparse Synthetic Controls](https://sparsesc.readthedocs.io/en/latest/) 2 | 3 | SparseSC is a package that implements an ML-enhanced version of Synthetic Control Methodologies, which introduces penalties on both the feature and unit weights.
Specifically, it optimizes these two objective functions: 4 | 5 | $$\gamma_0 = \left\Vert Y_T - W\cdot Y_C \right\Vert_F^2 + \lambda_V \left\Vert V \right\Vert_1 $$ 6 | 7 | $$\gamma_1 = \left\Vert X_T - W\cdot X_C \right\Vert_V + \lambda_W \left\Vert W - J/c \right\Vert_F^2 $$ 8 | 9 | by optimizing $\lambda_V$ and $\lambda_W$ using cross validation within the control units in the usual way, and where: 10 | 11 | - $X_T$ and $X_C$ are matrices of features (covariates and/or pre-treatment outcomes) for the treated and control units, respectively 12 | 13 | - $Y_T$ and $Y_C$ are matrices of post-treatment outcomes on the treated and control units, respectively 14 | 15 | - $W$ 16 | is a matrix of weights such that $W \cdot \left(X_C \mid Y_C\right)$ forms a matrix of synthetic controls for all units 17 | 18 | - $V$ is a diagonal matrix of weights applied to the covariates / pre-treatment outcomes. 19 | 20 | Note that all matrices are formatted with one row per unit and one column per feature / outcome. Breaking down the two main equations, we have: 21 | 22 | - $\left\Vert Y_T - W\cdot Y_C \right\Vert_F^2$ is the out-of-sample squared prediction error (i.e. the squared Frobenius norm), measured within the control units under cross validation 23 | 24 | - $\left\Vert V \right\Vert_1$ represents how much the model depends on the features 25 | 26 | - $\left\Vert X_T - W\cdot X_C \right\Vert_V$ is the difference between synthetic and observed units in the feature space weighted by $V$. Specifically, $\left\Vert A \right\Vert_V = AVA^T$ 27 | 28 | - $\left\Vert W - J/c \right\Vert_F^2$ is the difference between the optimized weights and simple ($1/N_c$) weighted averages of the control units. 29 | 30 | 31 | Typically this is used to estimate causal effects from binary treatments on observational panel (longitudinal) data. The functions `fit()` and `fit_fast()` provide basic fitting of the model. If you are estimating treatment effects, fitting and diagnostic information can be obtained via `estimate_effects()`. 32 | 33 | Though the fitting methods do not require such structure, the typical setup is where we have panel data of an outcome variable `Y` for `T` time periods for `N` observation units (customers, computers, etc.). We may additionally have some baseline characteristics `X` about the units. In the treatment effect setting, we will also have a discrete change in treatment status (e.g. some policy change) at time `T0` for a select group of units. When there is treatment, we can think of the pre-treatment data as [`X`, `Y_pre`] and post-treatment data as [`Y_post`]. 34 | 35 | ```py 36 | import numpy as np
import SparseSC 37 | 38 | # Fit the model (mark each treated unit's treatment time; controls stay NaN): 39 | unit_treatment_periods = np.full(N, np.nan) 40 | unit_treatment_periods[treated_unit_idx] = T0 41 | fitted_estimates = SparseSC.estimate_effects(Y,unit_treatment_periods,...) 42 | 43 | # Print summary of the model including effect size estimates, 44 | # p-values, and confidence intervals: 45 | print(fitted_estimates) 46 | 47 | # Extract model attributes: 48 | fitted_estimates.pl_res_post.avg_joint_effect.p_value 49 | fitted_estimates.pl_res_post.avg_joint_effect.CI 50 | 51 | # access the fitted Synthetic Controls model: 52 | fitted_model = fitted_estimates.fit 53 | ``` 54 | 55 | See [the docs](https://sparsesc.readthedocs.io/en/latest/) for more details 56 | 57 | ## Overview 58 | 59 | See [here](https://en.wikipedia.org/wiki/Synthetic_control_method) for more info on Synthetic Controls. In essence, it is a type of matching estimator.
For each unit it will find a weighted average of untreated units that is similar on key pre-treatment data. The goal of Synthetic Controls is to find out which variables are important to match on (the `V` matrix) and then, given those, to find a vector of per-unit weights that combine the control units into its synthetic control. The synthetic control acts as the counterfactual for a unit, and the estimate of a treatment effect is the difference between the observed outcome in the post-treatment period and the synthetic control's outcome. 60 | 61 | SparseSC makes a number of changes to Synthetic Controls. It uses regularization and feature learning to avoid overfitting, ensure uniqueness of the solution, automate researcher decisions, and allow for estimation on large datasets. See the docs for more details. 62 | 63 | The main choices to make are: 64 | 1. The solution structure 65 | 2. The model-type 66 | 67 | ### SparseSC Solution Structure 68 | The first choice is whether to calculate all of the high-level parameters (`V`, its regularization parameter, and the regularization parameters for the weights) on the main matching objective or whether to get approximate/fast estimates of them using non-matching formulations. The options are: 69 | * Full joint (done by `fit()`): We optimize over `v_pen`, `w_pen` and `V`, so that the resulting synthetic controls for the control units have the smallest squared prediction error on `Y_post`. 70 | * Separate (done by `fit_fast()`): We note that we can efficiently estimate `w_pen` on the main matching objective, since, given `V`, we can reformulate the weight-finding problem as a Ridge Regression and use efficient LOO cross-validation (e.g. `RidgeCV`) to estimate `w_pen`. We will estimate `V` using an alternative, non-matching objective (such as a `MultiTaskLasso` that uses `X,Y_pre` to predict `Y_post`). This setup also allows for feature generation to select the match space. There are two variants depending on how we handle `v_pen`: 71 | * Mixed: Choose `v_pen` based on the resulting down-stream main matching objective. 72 | * Fully separate: Choose `v_pen` based on the approximate objective (e.g., `MultiTaskLassoCV`). 73 | 74 | The Fully Separate solution is fast and often quite good, so we recommend starting there and, if need be, advancing to the Mixed and then Fully Joint optimizations. 75 | 76 | ### Model types 77 | There are two main model-types (corresponding to different cuts of the data) that can be used to estimate treatment effects. 78 | 1. Retrospective: The goal is to minimize squared prediction error of the control units on `Y_post`, and the full pre-history of the outcome is used as features in fitting. This is the default and was used in the descriptive elements above. 79 | 2. Prospective: We make an artificial split in time before any treatment actually happens (`Y_pre=[Y_train,Y_test]`). The goal is to minimize squared prediction error of all units on `Y_test`, and `Y_train` for all units is used as features in fitting. 80 | 81 | Given the same set of features, the two will only differ when there are a non-trivial number of treated units. In this case the prospective model may provide lower prediction error for the treated units, though at the cost of less pre-history data used for fitting. When there are only a trivial number of treated units, the retrospective design will be the most efficient. 82 | 83 | See more details about these and two additional model types (Prospective-restricted, and full) in the docs.
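To make the solution-structure choice concrete, here is a minimal sketch of the two entry points. This is a sketch only: it assumes a feature matrix `X`, an outcome matrix `Y`, and the `treated_unit_idx` indices defined as above, and the keyword arguments follow the API reference.

```py
import SparseSC

# Separate solution (fast): V is estimated on a non-matching objective
# (e.g. MultiTaskLassoCV) and w_pen via efficient LOO cross-validation.
fast_fit = SparseSC.fit_fast(X, Y, treated_units=treated_unit_idx)

# Full joint solution (slow): v_pen, w_pen, and V are all chosen to
# minimize the main matching objective.
joint_fit = SparseSC.fit(X, Y, treated_units=treated_unit_idx)
```

Both return a fitted Synthetic Controls model; per the recommendation above, start with `fit_fast()` and escalate to `fit()` only if needed.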
83 | 84 | ## Fitting a synthetic control model 85 | 86 | ### Documentation 87 | 88 | You can read these online at [Read the 89 | Docs](https://sparsesc.readthedocs.io/en/latest/). See there for: 90 | 91 | * Custom Donor Pools 92 | * Parallelization 93 | * Constraining the `V` matrix to be in the unit simplex 94 | * Performance Notes for `fit()` 95 | * Additional Performance Considerations for `fit()` 96 | * Full parameter listings 97 | 98 | To build the documentation see `docs/dev_notes.md`. 99 | 100 | ## Citation 101 | Brian Quistorff, Matt Goldman, and Jason Thorpe (2020) [Sparse Synthetic Controls: Unit-Level Counterfactuals from High-Dimensional Data](https://drive.google.com/file/d/1lfH1CK_JZpc0ou7hP60FhQpkeoXhR6fC/view?usp=sharing), Microsoft Journal of Applied Research, 14, pp.155-170. 102 | 103 | ## Installation 104 | 105 | The easiest way to install `SparseSC` is with: 106 | 107 | ```sh 108 | pip install git+https://github.com/microsoft/SparseSC.git 109 | ``` 110 | 111 | Additional commands to run tests and examples are in the makefile. 112 | 113 | ## Contributing 114 | 115 | This project welcomes contributions and suggestions. Most contributions 116 | require you to agree to a Contributor License Agreement (CLA) declaring 117 | that you have the right to, and actually do, grant us the rights to use 118 | your contribution. For details, visit https://cla.microsoft.com. 119 | 120 | When you submit a pull request, a CLA-bot will automatically determine 121 | whether you need to provide a CLA and decorate the PR appropriately (e.g., 122 | label, comment). Simply follow the instructions provided by the bot. You 123 | will only need to do this once across all repos using our CLA. 124 | 125 | This project has adopted the [Microsoft Open Source Code of 126 | Conduct](https://opensource.microsoft.com/codeofconduct/). For more 127 | information see the [Code of Conduct 128 | FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact 129 | [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional 130 | questions or comments. 131 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com).
If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /SparseSC.sln: -------------------------------------------------------------------------------- 1 |  2 | Microsoft Visual Studio Solution File, Format Version 12.00 3 | # Visual Studio Version 16 4 | VisualStudioVersion = 16.0.29613.14 5 | MinimumVisualStudioVersion = 10.0.40219.1 6 | Project("{888888A0-9F3D-457C-B088-3A5042F75D52}") = "SparseSC", "src\SparseSC\SparseSC.pyproj", "{3FDC664A-C1C8-47F0-8D77-8E0679E53C82}" 7 | EndProject 8 | Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Misc", "Misc", "{AA17F266-A75A-4FDE-A774-168EADE9A32E}" 9 | ProjectSection(SolutionItems) = preProject 10 | docs\api_ref.md = docs\api_ref.md 11 | docs\azure_batch.md = docs\azure_batch.md 12 | CHANGELOG.md = CHANGELOG.md 13 | docs\dev_notes.md = docs\dev_notes.md 14 | docs\estimate-effects.md = docs\estimate-effects.md 15 | example-code.py = example-code.py 16 | examples\example_graphs.py = examples\example_graphs.py 17 | docs\fit.md = docs\fit.md 18 | examples\fit_poc.py = examples\fit_poc.py 19 | docs\model-types.md = docs\model-types.md 20 | docs\overview.md = docs\overview.md 21 | docs\performance-notes.md = docs\performance-notes.md 22 | README.md = README.md 23 | EndProjectSection 24 | EndProject 25 | Project("{888888A0-9F3D-457C-B088-3A5042F75D52}") = "test", "test\test.pyproj", "{01446F94-F552-4EE1-90DC-E93A0DB22B4C}" 26 | EndProject 27 | Global 28 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 29 | Debug|Any CPU = Debug|Any CPU 30 | Release|Any CPU = Release|Any CPU 31 | EndGlobalSection 32 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 33 | {3FDC664A-C1C8-47F0-8D77-8E0679E53C82}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 34 | {3FDC664A-C1C8-47F0-8D77-8E0679E53C82}.Release|Any CPU.ActiveCfg = Release|Any CPU 35 | 
{01446F94-F552-4EE1-90DC-E93A0DB22B4C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 36 | {01446F94-F552-4EE1-90DC-E93A0DB22B4C}.Release|Any CPU.ActiveCfg = Release|Any CPU 37 | EndGlobalSection 38 | GlobalSection(SolutionProperties) = preSolution 39 | HideSolutionNode = FALSE 40 | EndGlobalSection 41 | GlobalSection(ExtensibilityGlobals) = postSolution 42 | SolutionGuid = {9EE1A574-A448-4146-B4A8-998CAEAFDD85} 43 | EndGlobalSection 44 | EndGlobal 45 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-slate -------------------------------------------------------------------------------- /docs/api_ref.md: -------------------------------------------------------------------------------- 1 | # API Reference 2 | 3 | 4 | ## Estimate Treatment Effects 5 | ```eval_rst 6 | .. autofunction:: SparseSC.estimate_effects.estimate_effects 7 | :noindex: 8 | 9 | ``` 10 | 11 | ```eval_rst 12 | .. autoclass:: SparseSC.estimate_effects.SparseSCEstResults 13 | :members: 14 | :show-inheritance: 15 | :noindex: 16 | 17 | ``` 18 | 19 | ```eval_rst 20 | .. autoclass:: SparseSC.utils.metrics_utils.PlaceboResults 21 | :members: 22 | :show-inheritance: 23 | :noindex: 24 | 25 | ``` 26 | 27 | ```eval_rst 28 | .. autoclass:: SparseSC.utils.metrics_utils.EstResultCI 29 | :members: 30 | :show-inheritance: 31 | :noindex: 32 | 33 | ``` 34 | 35 | ```eval_rst 36 | .. autoclass:: SparseSC.utils.metrics_utils.CI_int 37 | :members: 38 | :show-inheritance: 39 | :noindex: 40 | 41 | ``` 42 | 43 | ```eval_rst 44 | .. autoclass:: SparseSC.utils.metrics_utils.AA_results 45 | :members: 46 | :show-inheritance: 47 | :noindex: 48 | 49 | ``` 50 | 51 | 52 | ## Fit a Synthetic Controls Model (Slow, Joint) 53 | ```eval_rst 54 | .. autofunction:: SparseSC.fit.fit 55 | :noindex: 56 | 57 | ``` 58 | 59 | ```eval_rst 60 | .. autoclass:: SparseSC.fit.SparseSCFit 61 | :members: 62 | :show-inheritance: 63 | :noindex: 64 | 65 | ``` 66 | 67 | ## Fit a Synthetic Controls Model (Fast, Separate) 68 | ```eval_rst 69 | .. autofunction:: SparseSC.fit_fast.fit_fast 70 | :noindex: 71 | 72 | ``` 73 | 74 | ```eval_rst 75 | .. autofunction:: SparseSC.utils.match_space.Fixed_V_factory 76 | :noindex: 77 | 78 | ``` 79 | 80 | ```eval_rst 81 | .. autofunction:: SparseSC.utils.match_space.MTLassoCV_MatchSpace_factory 82 | :noindex: 83 | 84 | ``` 85 | 86 | ```eval_rst 87 | .. autofunction:: SparseSC.utils.match_space.MTLSTMMixed_MatchSpace_factory 88 | :noindex: 89 | 90 | ``` 91 | 92 | ```eval_rst 93 | .. autofunction:: SparseSC.utils.match_space.MTLassoMixed_MatchSpace_factory 94 | :noindex: 95 | 96 | ``` 97 | -------------------------------------------------------------------------------- /docs/azure_batch.md: -------------------------------------------------------------------------------- 1 | # Running Jobs in Parallel with Azure Batch 2 | 3 | Fitting a Sparse Synthetic Controls model can result in a very long running 4 | time. Fortunately much of the work can be done in parallel and executed in 5 | the cloud, and the SparseSC package comes with an Azure Batch utility which 6 | can be used to fit a Synthetic controls model using Azure Batch. There are 7 | code examples for Windows CMD, Bash and Powershell. 
8 | 9 | ## Setup 10 | 11 | Running SparseSC with Azure Batch requires the `super_batch` library which 12 | can be installed with: 13 | 14 | ```bash 15 | pip install git+https://github.com/jdthorpe/batch-config.git 16 | ``` 17 | 18 | Also note that this module has only been tested with Python 3.7. 19 | You will also need `pyyaml` and `psutil`. 20 | 21 | ### Create the Required Azure resources 22 | 23 | Running SparseSC with Azure Batch requires an Azure account and a handful of 24 | resources and credentials. These can be set up by following along with 25 | [section 4 of the super-batch README](https://github.com/jdthorpe/batch-config#step-4-create-the-required-azure-resources). 26 | 27 | ### Prepare parameters for the Batch Job 28 | 29 | The parameters required to run a batch job can be created using `fit()` by 30 | providing a directory where the parameter files should be stored: 31 | 32 | ```python 33 | from SparseSC import fit 34 | batch_dir = "/path/to/my/batch/data/" 35 | 36 | # initialize the batch parameters in the directory `batch_dir` 37 | fit(x, y, ... , batchDir = batch_dir) 38 | ``` 39 | 40 | ### Executing the Batch Job 41 | 42 | In the following Python script, a Batch configuration is created and the 43 | batch job is executed with Azure Batch. Note that in the following script, 44 | the various Batch Account and Storage Account credentials are taken from 45 | system environment variables, as in the [super-batch readme](https://github.com/jdthorpe/batch-config#step-4-create-the-required-azure-resources). 46 | 47 | ```python 48 | import os 49 | from datetime import datetime 50 | from super_batch import Client 51 | from SparseSC.utils.AzureBatch import ( 52 | DOCKER_IMAGE_NAME, 53 | create_job,
    aggregate_batch_results,  # assumed importable alongside create_job 54 | ) 55 | # Batch job names must be unique, and a timestamp is one way to keep them unique across runs 56 | timestamp = datetime.utcnow().strftime("%H%M%S") 57 | batch_dir = "/path/to/my/batch/data/"
name = "sparsesc"  # hypothetical base name used for the pool, job, and container below 58 | 59 | batch_client = Client( 60 | # Name of the VM pool 61 | POOL_ID= name, 62 | # number of standard nodes 63 | POOL_NODE_COUNT=5, 64 | # number of low priority nodes 65 | POOL_LOW_PRIORITY_NODE_COUNT=5, 66 | # VM type 67 | POOL_VM_SIZE= "STANDARD_A1_v2", 68 | # Job ID. Note that this must be unique. 69 | JOB_ID= name + timestamp, 70 | # Name of the storage container for storing parameters and results 71 | CONTAINER_NAME= name, 72 | # local directory with the parameters, and where the results will go 73 | BATCH_DIRECTORY= batch_dir, 74 | # Keep the pool around after the run, which saves time when doing 75 | # multiple batch jobs, as it typically takes a few minutes to spin up a 76 | # pool of VMs. (Optional. Default = False) 77 | DELETE_POOL_WHEN_DONE=False, 78 | # Keeping the job details can be useful for debugging: 79 | # (Optional. Default = False) 80 | DELETE_JOB_WHEN_DONE=False 81 | ) 82 | 83 | create_job(batch_client, batch_dir) 84 | # run the batch job 85 | batch_client.run() 86 | 87 | # aggregate the results into a fitted model instance 88 | fitted_model = aggregate_batch_results(batch_dir) 89 | ``` 90 | 91 | ## Cleaning Up 92 | 93 | When you are done fitting your model with Azure Batch be sure to 94 | [clean up your Azure Resources](https://github.com/jdthorpe/batch-config#step-6-clean-up) 95 | in order to prevent unexpected charges on your Azure account. 96 | 97 | ## Solving 98 | 99 | The Azure Batch job only varies one of the penalty parameters. You should therefore not specify the 100 | simplex constraint for the V matrix, as the search would then be missing one degree of freedom.
101 | 102 | ## FAQ 103 | 104 | 1. What if I get disconnected while the batch job is running? 105 | 106 | Once the pool and the job are created, they will keep running until the 107 | job completes, or you delete the resources. You can re-create the 108 | `batch_client` as in the example above and then reconnect to the job and 109 | download the results with: 110 | 111 | ```python 112 | batch_client.load_results() 113 | fitted_model = aggregate_batch_results(batch_dir) 114 | ``` 115 | 116 | In fact, if you'd rather not wait for the job to complete, you can 117 | add the parameter `wait=False` (i.e. `batch_client.run(..., wait=False)`) and the 118 | call will return as soon as the job and pool configuration 119 | have been created in Azure. 120 | 121 | 1. `batch_client.run()` or `batch_client.load_results()` complain that the 122 | results are incomplete. What happened? 123 | 124 | Typically this means that one or more of the jobs failed, and a common 125 | reason for the job to fail is that the VM runs out of memory while 126 | running the batch job. Failed jobs can be viewed in either the Azure 127 | Batch Explorer or the Azure Portal. The `POOL_VM_SIZE` used above 128 | ("STANDARD_A1_v2") is one of the smallest (and cheapest) VMs available 129 | on Azure. Upgrading to a VM with more memory can help in this 130 | situation. 131 | 132 | 1. Why does `aggregate_batch_results()` take so long? 133 | 134 | Each batch job runs a single gradient descent in V space using a subset 135 | (cross-validation fold) of the data and with a single pair of penalty 136 | parameters, and returns the out-of-sample error for the held out samples. 137 | `aggregate_batch_results()` very quickly aggregates these out-of-sample 138 | errors and chooses the optimal penalty parameters given the `choice` 139 | parameter provided to `fit()` or `aggregate_batch_results()`. Finally, 140 | with the selected parameters, a final gradient descent is run using the 141 | full dataset, which is larger than the CV folds and takes longer, as the rate 142 | limiting step 143 | ( [scipy.linalg.solve](https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.solve.html) ) 144 | has a running time of 145 | [`O(N^3)`](https://stackoverflow.com/a/12665483/1519199). While it is 146 | possible to run this step in parallel as well, it hasn't yet been 147 | implemented. 148 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | # SparseSC documentation build configuration file, created by 5 | # sphinx-quickstart on Thu Sep 27 14:53:55 2018. 6 | # 7 | # This file is execfile()d with the current directory set to its 8 | # containing dir. 9 | # 10 | # Note that not all possible configuration values are present in this 11 | # autogenerated file. 12 | # 13 | # All configuration values have a default; values that are commented out 14 | # serve to show the default. 15 | 16 | # If extensions (or modules to document with autodoc) are in another directory, 17 | # add these directories to sys.path here. If the directory is relative to the 18 | # documentation root, use os.path.abspath to make it absolute, like shown here. 19 | # 20 | import os 21 | import sys 22 | ##Allow MarkDown. 23 | ##Prerequisite.
pip install recommonmark 24 | import recommonmark 25 | from recommonmark.transform import AutoStructify 26 | 27 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) 28 | 29 | #Allow capturing output 30 | #from SparseSC.utils.misc import capture 31 | 32 | 33 | # -- General configuration ------------------------------------------------ 34 | 35 | # If your documentation needs a minimal Sphinx version, state it here. 36 | # 37 | # needs_sphinx = '1.0' 38 | 39 | # Add any Sphinx extension module names here, as strings. They can be 40 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 41 | # ones. 42 | extensions = ['sphinx.ext.autodoc', 43 | 'sphinx.ext.mathjax', 44 | 'sphinx_markdown_tables'] 45 | 46 | # Add any paths that contain templates here, relative to this directory. 47 | templates_path = ['_templates'] 48 | 49 | # The suffix(es) of source filenames. 50 | # You can specify multiple suffix as a list of string: 51 | # 52 | source_suffix = ['.rst', '.md'] 53 | 54 | # The master toctree document. 55 | master_doc = 'index' 56 | 57 | # General information about the project. 58 | project = 'SparseSC' 59 | copyright = '2018, Jason Thorpe, Brian Quistorff, Matt Goldman' 60 | author = 'Jason Thorpe, Brian Quistorff, Matt Goldman' 61 | 62 | # The version info for the project you're documenting, acts as replacement for 63 | # |version| and |release|, also used in various other places throughout the 64 | # built documents. 65 | # 66 | from SparseSC import __version__ 67 | version = __version__ 68 | # The full version, including alpha/beta/rc tags. (For now, keep the same) 69 | release = version 70 | 71 | # The language for content autogenerated by Sphinx. Refer to documentation 72 | # for a list of supported languages. 73 | # 74 | # This is also used if you do content translation via gettext catalogs. 75 | # Usually you set "language" from the command line for these cases. 76 | language = None 77 | 78 | html_copy_source=False 79 | 80 | # List of patterns, relative to source directory, that match files and 81 | # directories to ignore when looking for source files. 82 | # This patterns also effect to html_static_path and html_extra_path 83 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 84 | 85 | # The name of the Pygments (syntax highlighting) style to use. 86 | pygments_style = 'sphinx' 87 | 88 | # If true, `todo` and `todoList` produce output, else they produce nothing. 89 | todo_include_todos = False 90 | 91 | #Run apidoc from here rather than separate process (so that we can do Read the Docs easily) 92 | #https://github.com/rtfd/readthedocs.org/issues/1139 93 | #Used to exclude some files, but this caused warnings (they weren't removed from the auto-generated index) 94 | # and instead we just made a currated api doc. Leave for now as we might go back. 
95 | def run_apidoc(app): 96 | from sphinx.apidoc import main as apidoc_main 97 | cur_dir = os.path.abspath(os.path.dirname(__file__)) 98 | buildapidocdir = os.path.join(app.outdir, "apidoc","SparseSC") 99 | module = os.path.join(cur_dir,"..","src","SparseSC") 100 | to_excl = [] #"cross_validation","fit_ct","fit_fold", "fit_loo","optimizers","optimizers/cd_line_search","tensor","utils/ols_model","utils/penalty_utils","utils/print_progress","utils/sub_matrix_inverse","weights"] 101 | #Locally could wrap each to_excl with "*" "*" and put in the apidoc cmd and end and works as exclude patterns, but doesn't work on RTD 102 | #with capture() as out: #doesn't have quiet option 103 | apidoc_main([None, '-f', '-e', '-o', buildapidocdir, module]) 104 | #rm module file because we don't link to it directly and this silences the warning 105 | os.remove(os.path.join(buildapidocdir, "modules.rst")) 106 | for excl in to_excl: 107 | path = os.path.join(buildapidocdir, "SparseSC."+excl.replace("/",".")+".rst") 108 | print("removing: "+path) 109 | os.remove(path) 110 | 111 | def skip(app, what, name, obj, skip, options): 112 | #force showing __init__()'s 113 | if name == "__init__": 114 | return False 115 | 116 | skip_fns = [] 117 | if what=="class" and '__qualname__' in dir(obj) and obj.__qualname__ in skip_fns: 118 | return True 119 | 120 | # Can't figure out how to get the properties class to skip more targettedly 121 | skip_prs = [] 122 | if what=="class" and name in skip_prs: 123 | return True 124 | 125 | skip_mds = [] 126 | if what=="module" and name in skip_mds: 127 | return True 128 | #helpful debugging line 129 | #print what, name, obj, dir(obj) 130 | 131 | return skip 132 | 133 | def setup(app): 134 | app.connect('builder-inited', run_apidoc) 135 | 136 | app.connect("autodoc-skip-member", skip) 137 | #Allow MarkDown 138 | app.add_config_value('recommonmark_config', { 139 | 'url_resolver': lambda url: "build/apidoc/" + url, 140 | 'auto_toc_tree_section': ['Contents','Examples'], 141 | 'enable_eval_rst': True, 142 | #'enable_auto_doc_ref': True, 143 | 'enable_math': True, 144 | 'enable_inline_math': True 145 | }, True) 146 | app.add_transform(AutoStructify) 147 | 148 | 149 | # -- Options for HTML output ---------------------------------------------- 150 | 151 | # The theme to use for HTML and HTML Help pages. See the documentation for 152 | # a list of builtin themes. 153 | # 154 | html_theme = 'sphinx_rtd_theme' 155 | 156 | # Theme options are theme-specific and customize the look and feel of a theme 157 | # further. For a list of options available for each theme, see the 158 | # documentation. 159 | # 160 | # html_theme_options = {} 161 | 162 | # Add any paths that contain custom static files (such as style sheets) here, 163 | # relative to this directory. They are copied after the builtin static files, 164 | # so a file named "default.css" will overwrite the builtin "default.css". 165 | #html_static_path = ['_static'] 166 | 167 | # Custom sidebar templates, must be a dictionary that maps document names 168 | # to template names. 169 | # 170 | # This is required for the alabaster theme 171 | # refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars 172 | html_sidebars = { 173 | '**': [ 174 | 'about.html', 175 | 'navigation.html', 176 | 'relations.html', # needs 'show_related': True theme option to display 177 | 'searchbox.html', 178 | 'donate.html', 179 | ] 180 | } 181 | 182 | ##Allow MarkDown. 
183 | source_parsers = {'.md': 'recommonmark.parser.CommonMarkParser', } 184 | 185 | 186 | # -- Options for HTMLHelp output ------------------------------------------ 187 | 188 | # Output file base name for HTML help builder. 189 | htmlhelp_basename = 'SparseSCdoc' 190 | 191 | 192 | # -- Options for LaTeX output --------------------------------------------- 193 | 194 | latex_elements = { 195 | # The paper size ('letterpaper' or 'a4paper'). 196 | # 197 | # 'papersize': 'letterpaper', 198 | 199 | # The font size ('10pt', '11pt' or '12pt'). 200 | # 201 | # 'pointsize': '10pt', 202 | 203 | # Additional stuff for the LaTeX preamble. 204 | # 205 | # 'preamble': '', 206 | 207 | # Latex figure (float) alignment 208 | # 209 | # 'figure_align': 'htbp', 210 | } 211 | 212 | # Grouping the document tree into LaTeX files. List of tuples 213 | # (source start file, target name, title, 214 | # author, documentclass [howto, manual, or own class]). 215 | latex_documents = [ 216 | (master_doc, 'SparseSC.tex', 'SparseSC Documentation', 217 | 'Jason Thorpe, Brian Quistorff, Matt Goldman', 'manual'), 218 | ] 219 | 220 | 221 | # -- Options for manual page output --------------------------------------- 222 | 223 | # One entry per manual page. List of tuples 224 | # (source start file, name, description, authors, manual section). 225 | man_pages = [ 226 | (master_doc, 'sparsesc', 'SparseSC Documentation', 227 | [author], 1) 228 | ] 229 | 230 | 231 | # -- Options for Texinfo output ------------------------------------------- 232 | 233 | # Grouping the document tree into Texinfo files. List of tuples 234 | # (source start file, target name, title, author, 235 | # dir menu entry, description, category) 236 | texinfo_documents = [ 237 | (master_doc, 'SparseSC', 'SparseSC Documentation', 238 | author, 'SparseSC', 'One line description of project.', 239 | 'Miscellaneous'), 240 | ] 241 | 242 | 243 | 244 | -------------------------------------------------------------------------------- /docs/dev_notes.md: -------------------------------------------------------------------------------- 1 | # Developer Notes 2 | 3 | ## Python environments 4 | 5 | You can create Anaconda environments using 6 | ```bash 7 | conda env create -f test/SparseSC_36.yml 8 | ``` 9 | You can do `update` rather than `create` to update existing ones (to avoid [potential bugs](https://stackoverflow.com/a/46114295/3429373) make sure the env isn't currently active). 10 | 11 | Note: When regenerating these files (`conda env export > test/SparseSC_*.yml`) make sure to remove the final `prefix` line since that's computer specific. You can do this automatically on Linux by inserting `| grep -v "prefix"` and on Windows by inserting `| findstr -v "prefix"`. 12 | 13 | ## Building the docs 14 | Requires Python >=3.6 and packages: `sphinx`, `recommonmark`, `sphinx-markdown-tables`. 15 | Use `(n)make htmldocs` and an index HTML file is made at `docs/build/html/index.html`. 16 | 17 | To build a mini-RTD environment to test building docs: 18 | 1) You can make a new environment with Python 3.7 (`conda create -n SparseSC_37_rtd python=3.7`) 19 | 2) update `pip` (likely fine). 20 | 3) `pip install --upgrade --no-cache-dir -r docs/rtd-base.txt` . This file is loosely kept in sync by looking at the install commands on the rtd run.
It downgrades Sphinx to a known good version that allows markdown files to have math in code quotes ([GH Issues](https://github.com/readthedocs/recommonmark/issues/133)) (there might be higher ones that also work, didn't try). 22 | 23 | ## Running examples 24 | The Jupyter notebooks require `matplotlib`, `jupyter`, and `notebook`. 25 | 26 | ## Testing 27 | We use the built-in `unittest`. Tests can be run from the makefile using the `tests` target, or you can run Python directly from the repo root using the following types of commands: 28 | 29 | ```bash 30 | python -m unittest test/test_fit.py #file (only Python >=3.5) 31 | python -m unittest test.test_fit #module 32 | python -m unittest test.test_fit.TestFit #class 33 | python -m unittest test.test_fit.TestFit.test_retrospective #function 34 | ``` 35 | 36 | 37 | ## Release Process 38 | * Ensure the makefile target `check` (which does pylint, tests, doc building, and packaging) runs clean 39 | * If new version, check that it's been updated in `src/SparseSC/__init__.py` 40 | * Update `CHANGELOG.md` 41 | * Tag/Release in version control 42 | 43 | -------------------------------------------------------------------------------- /docs/estimate-effects.md: -------------------------------------------------------------------------------- 1 | # Treatment Effects 2 | 3 | The `estimate_effects()` function can be used to conduct 4 | [DID](https://en.wikipedia.org/wiki/Difference_in_differences) style 5 | analyses where counter-factual observations are constructed using Sparse 6 | Synthetic Controls. 7 | 8 | ```py 9 | import SparseSC 10 | 11 | # Fit the model: 12 | fitted_estimates = SparseSC.estimate_effects(outcomes,unit_treatment_periods,covariates=X,fast=True,...) 13 | 14 | # Print summary of the model including effect size estimates, 15 | # p-values, and confidence intervals: 16 | print(fitted_estimates) 17 | 18 | # Extract model attributes: 19 | fitted_estimates.pl_res_post.avg_joint_effect.p_value 20 | fitted_estimates.pl_res_post.avg_joint_effect.CI 21 | 22 | # access the fitted Synthetic Controls model: 23 | fitted_model = fitted_estimates.fit 24 | ``` 25 | 26 | The returned object is of class `SparseSCEstResults`. 27 | 28 | #### Feature and Target Data 29 | 30 | When estimating synthetic controls, units of observation are divided into 31 | control and treated units. Data collected on these units may include 32 | observations of the outcome of interest, as well as other characteristics 33 | of the units (termed "covariates", herein). Outcomes may be observed both 34 | before and after an intervention on the treated units. 35 | 36 | To maintain independence of the fitted synthetic controls and the 37 | post-intervention outcomes of interest of treated units, the 38 | post-intervention outcomes from treated units are not used in the fitting 39 | process. There are two cuts from the remaining data that may be used to 40 | fit synthetic controls, and each has its advantages and disadvantages. 41 | 42 | In the call to `estimate_effects()`, `outcomes` should 43 | be a numeric matrix containing data on the target variable collected both prior 44 | to and after the treatment / intervention, and the optional 45 | parameter `covariates` may be a matrix of additional features. All matrices 46 | should have one row per unit and one column per observation.
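As a schematic illustration of these shapes, here is a minimal sketch; the unit counts, random data, and treatment period are hypothetical, and the call mirrors the example above:

```py
import numpy as np
import SparseSC

N, T, K = 100, 24, 5               # units, time periods, covariates (hypothetical)
outcomes = np.random.rand(N, T)    # one row per unit, one column per time period
covariates = np.random.rand(N, K)  # one row per unit, one column per feature

# Treatment period per unit: NaN for never-treated controls,
# the treatment time index for treated units.
unit_treatment_periods = np.full(N, np.nan)
unit_treatment_periods[:10] = 18   # first 10 units treated at period 18

fitted_estimates = SparseSC.estimate_effects(
    outcomes, unit_treatment_periods, covariates=covariates, fast=True
)
print(fitted_estimates)
```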
47 | 48 | In addition, the rows in `covariates` and `outcomes` which contain units that were affected 49 | by the intervention ("treated units") should be indicated using the 50 | `treated_units` parameter, which may be a vector of booleans or integers 51 | indicating the rows which belong to treated units. 52 | 53 | #### Statistical parameters 54 | 55 | The confidence level may be specified with the `level` parameter, and the 56 | maximum number of simulations used to produce the placebo distribution may 57 | be set with the `max_n_pl` parameter. 58 | 59 | #### Additional parameters 60 | 61 | Additional keyword arguments are passed on to the call to `fit()`, which is 62 | responsible for fitting the Synthetic Controls used to create the 63 | counterfactuals. 64 | -------------------------------------------------------------------------------- /docs/examples/estimate_effects.md: -------------------------------------------------------------------------------- 1 | # Treatment effects 2 | 3 | [ Coming soon ] 4 | -------------------------------------------------------------------------------- /docs/examples/prospective_anomaly_detection.md: -------------------------------------------------------------------------------- 1 | # Anomaly Detection 2 | 3 | ### Overview 4 | 5 | In this scenario the goal is to identify irregular values in an outcome 6 | variable prospectively in a homogeneous population (i.e. when no 7 | treatment / intervention is planned). As an example, we may wish to detect 8 | failure of any one machine in a cluster, and to do so, we wish to create a 9 | synthetic unit for each machine which is composed of a weighted average of 10 | other machines in the cluster. In particular, workload may vary across the 11 | cluster due to (possibly unobserved) differences in machine hardware, cluster 12 | architecture, scheduler versions, networking architecture, job type, etc. 13 | 14 | Like the Prospective Treatment Effects scenario, *feature* data consist 15 | of unit attributes (covariates) and a subset of the pre-intervention values 16 | from the outcome of interest, **target** data consist of the remaining 17 | pre-intervention values for the outcome of interest, cross-fold 18 | validation is conducted using the entire dataset, and cross-validation and 19 | gradient folds are determined randomly. 20 | 21 | ### Example 22 | 23 | In this scenario, we'll need a matrix with past observations of the outcome 24 | (target) of interest (`targets`), with one row per unit of observation, and 25 | one column per time period, ordered from left to right. Additionally we 26 | may have another matrix of additional features with one row per unit and 27 | one column per feature (`features`). Armed with this we may wish to construct a 28 | synthetic control model to help decide whether future observations 29 | (`additional_observations`) deviate from their synthetic predictions. 30 | 31 | The strategy will be to divide the `targets` matrix into two parts (before 32 | and after column `t`), one of which will be used as features, and the other 33 | of which will be treated as outcomes for the purpose of fitting the weights 34 | which make up the synthetic controls model.
36 | 37 | ```python 38 | from numpy import hstack 39 | from SparseSC import fit 40 | 41 | # Let X be the features plus some of the targets 42 | X = hstack([features, targets[:,:t]]) 43 | 44 | # And let Y be the remaining targets 45 | Y = targets[:,t:] 46 | 47 | # fit the model: 48 | fitted_model = fit(X=X, 49 | Y=Y, 50 | model_type="full") 51 | ``` 52 | 53 | The `model_type="full"` option produces a model in which every unit can 54 | serve as a control for every other unit, unless of course the parameter 55 | `custom_donor_pool` is specified. 56 | 57 | Now with our fitted synthetic control model, as soon as a new set of target 58 | outcomes is observed for each unit, we can create synthetic outcomes from 59 | our fitted model using the `predict()` method: 60 | 61 | ```python 62 | synthetic_controls = fitted_model.predict(additional_observations) 63 | ``` 64 | 65 | Note that while the call to `fit()` is computationally intensive, the call 66 | to `predict()` is fast and can be used for real-time anomaly 67 | detection. 68 | 69 | ### Model Details: 70 | 71 | This model yields a synthetic unit for every unit in the dataset, and 72 | synthetic units are composed of the remaining units not included in the 73 | same gradient fold. 74 | 75 | | Type | Units used to fit V & penalties | Donor pool for W | 76 | |---|---|---| 77 | |(prospective) full|All units|All units| 78 | -------------------------------------------------------------------------------- /docs/examples/retrospecitve_restricted_synthetic_controls.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/docs/examples/retrospecitve_restricted_synthetic_controls.md -------------------------------------------------------------------------------- /docs/examples/retrospecitve_synthetic_controls.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/docs/examples/retrospecitve_synthetic_controls.md -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. SparseSC documentation master file, created by 2 | sphinx-quickstart on Thu Sep 27 14:53:55 2018. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to SparseSC's |version| documentation! 7 | ============================================== 8 | 9 | .. toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | overview 14 | fit 15 | estimate-effects 16 | model-types 17 | api_ref 18 | azure_batch 19 | dev_notes 20 | 21 | 22 | .. the complete docs (build/apidoc/SparseSC/SparseSC) are a bit chaotic* and 23 | redundant to :ref:`modindex`. Replacing with the tamer api_ref.md 24 | 25 | .. toctree:: 26 | :maxdepth: 2 27 | :caption: Examples: 28 | 29 | examples/prospective_anomaly_detection 30 | 31 | 32 | Indices and tables 33 | ================== 34 | 35 | * :ref:`genindex` 36 | * :ref:`modindex` 37 | 38 | ..
not needed: * :ref:`search` 39 | -------------------------------------------------------------------------------- /docs/model-types.md: -------------------------------------------------------------------------------- 1 | # Model Types 2 | 3 | There are four distinct types of model fitting with respect to choosing 4 | optimal values of the penalty parameters and the collection of units 5 | that are used for cross validation. 6 | 7 | Recall that a synthetic treatment unit is defined as a weighted average of 8 | control units where weights are determined from the targets and are chosen 9 | to minimize the difference between each treated unit and its 10 | synthetic unit in the absence of an intervention. This methodology can be 11 | combined with cross-fold validation in several ways in order to 12 | accommodate a variety of use cases. 13 | 14 | ## Retrospective Treatment Effects 15 | 16 | `model_type = "retrospective"` 17 | 18 | In a retrospective analysis, a subset of units have received a treatment 19 | which possibly correlates with features of the units, and values for the 20 | target variable have been collected both before and after an intervention. 21 | For example, a software update may have been applied to machines in a 22 | cluster which were experiencing unusually high latency, and retrospectively an 23 | analyst wishes to understand the effect of the update on another outcome 24 | such as memory utilization. 25 | 26 | In this scenario, for each treated unit we wish to create a synthetic unit 27 | composed only of untreated units. The units are divided into a training 28 | set consisting of just the control units, and a test set consisting of the 29 | treated units. Within the training set, *feature* data consist of anything 30 | known prior to a treatment or intervention, such as pre-intervention values 31 | from the outcome of interest as well as unit level attributes, sometimes 32 | called "covariates". Likewise, **target** data consist of observations 33 | from the outcome of interest collected after the treatment or intervention 34 | is initiated. 35 | 36 | Cross-fold validation is done within the training set, holding out a single 37 | fold, identifying feature weights within the remaining folds, and creating 38 | synthetic units for each held-out unit defined as a weighted average of the 39 | non-held-out units. Out-of-sample prediction errors are calculated for each 40 | training fold and summed across folds. The set of penalty parameters that 41 | minimizes the cross-fold out-of-sample error is chosen. Feature weights 42 | are calculated within the training set for the chosen penalties, and 43 | finally individual synthetic units are calculated for each treated unit 44 | which is a weighted average of the control units. 45 | 46 | This model yields a synthetic unit for each treated unit composed of 47 | control units. 48 | 49 | ## Prospective Treatment Effects 50 | 51 | `model_type = "prospective"` 52 | 53 | In a prospective analysis, a subset of units have been designated to 54 | receive a treatment but the treatment has not yet occurred, and the 55 | designation of the treatment may be correlated with a (possibly unobserved) 56 | feature of the treatment units. For example, a software update may have 57 | been planned for machines in a cluster which are experiencing unusually 58 | high latency, and there is a desire to understand the impact of the software on 59 | memory utilization.
60 | 61 | Like the retrospective scenario, for each treated unit we wish to create a 62 | synthetic unit composed only of untreated units. 63 | 64 | *Feature* data consist of unit attributes (covariates) and a subset 65 | of the pre-intervention values from the outcome of interest, and **target** 66 | data consist of the remaining pre-intervention values for the outcome of 67 | interest. 68 | 69 | Cross-fold validation is conducted using the entire dataset without regard 70 | to intent to treat. However, treated units are allocated to a single gradient 71 | fold, ensuring that synthetic treated units are composed of only the 72 | control units. 73 | 74 | Cross-fold validation is done within the *test* set, holding out a single 75 | fold, identifying feature weights within the remaining treatment folds 76 | combined with the control units, and creating synthetic units for each held-out unit 77 | defined as a weighted average of the full set of control units. 78 | 79 | Out-of-sample prediction errors are calculated for each treatment fold and 80 | the sum of these defines the cross-fold out-of-sample error. The set of 81 | penalty parameters that minimizes the cross-fold out-of-sample error is 82 | chosen. Feature weights are calculated within the training set for the 83 | chosen penalties, and finally individual synthetic units are calculated for 84 | each unit, which is a weighted average of the control units. 85 | 86 | This model yields a synthetic unit for each treated unit composed of 87 | control units. 88 | 89 | ## Prospective Treatment Effects training 90 | 91 | `model_type = "prospective-restricted"` 92 | 93 | This is motivated by the same example as the previous section. It requires 94 | a larger set of treated units for similar levels of precision, with the 95 | benefit of substantially faster running time. 96 | 97 | The units are divided into a training set consisting of just the control 98 | units, and a test set consisting of the units which will be treated. 99 | *Feature* data consist of unit attributes (covariates) and a subset 100 | of the pre-intervention values from the outcome of interest, and **target** 101 | data consist of the remaining pre-intervention values for the outcome of 102 | interest. 103 | 104 | Cross-fold validation is done within the *test* set, holding out a single 105 | fold, identifying feature weights within the remaining treatment folds 106 | combined with the control units, and creating synthetic units for each held-out unit 107 | defined as a weighted average of the full set of control units. 108 | 109 | Out-of-sample prediction errors are calculated for each treatment fold and 110 | the sum of these defines the cross-fold out-of-sample error. The set of 111 | penalty parameters that minimizes the cross-fold out-of-sample error is 112 | chosen. Feature weights are calculated within the training set for the 113 | chosen penalties, and finally individual synthetic units are calculated for 114 | each unit, which is a weighted average of the control units. 115 | 116 | This model yields a synthetic unit for each treated unit composed of 117 | control units. 118 | 119 | Note that this model will tend to have wider confidence intervals and smaller estimated treatment effects given the sample it is fit on. 120 | 121 | ## Prospective Failure Detection 122 | 123 | `model_type = "full"` 124 | 125 | In this scenario the goal is to identify irregular values in an outcome 126 | variable prospectively in a homogeneous population (i.e. when no 127 | treatment / intervention is planned).
## Prospective Failure Detection

`model_type = "full"`

In this scenario the goal is to identify irregular values in an outcome
variable prospectively in a homogeneous population (i.e. when no
treatment / intervention is planned). As an example, we may wish to detect
the failure of any one machine in a cluster, and to do so, we wish to
create a synthetic unit for each machine which is composed of a weighted
average of the other machines in the cluster. This is particularly useful
when workload varies across the cluster owing to (possibly unobserved)
differences in machine hardware, cluster architecture, scheduler versions,
networking architecture, job type, etc.

As in the Prospective Treatment Effects scenario, *feature* data consist of
unit attributes (covariates) and a subset of the pre-intervention values
from the outcome of interest, and **target** data consist of the remaining
pre-intervention values for the outcome of interest. Cross-fold validation
is conducted using the entire dataset, and cross-validation and gradient
folds are determined randomly.

This model yields a synthetic unit for every unit in the dataset, and each
synthetic unit is composed of the units not included in the same gradient
fold.

## Summary

Here is a summary of the main differences between the model types.

| Type | Units used to fit V & penalties | Donor pool for W |
|---|---|---|
|retrospective|Controls|Controls|
|prospective|All|Controls|
|prospective-restricted|Treated|Controls|
|(prospective) full|All|All|

A tree view of the differences:
* Treatment date: The *prospective* studies differ from the *retrospective* study in that they can use all units for fitting.
* (Prospective studies) Treated units: The intended-to-treat (*ITT*) studies differ from the *full* model in that the "treated" units can't be used as donors.
* (Prospective-ITT studies): The *restricted* model differs in that it tries to maximize predictive power for just the treated units.
--------------------------------------------------------------------------------
/docs/performance-notes.md:
--------------------------------------------------------------------------------

# Performance Notes

## Running time

The function `get_max_lambda()` requires a single calculation of the
gradient using all of the available data. In contrast, `SC.CV_score()`
performs gradient descent within each validation fold of the data.
Furthermore, in the 'pre-only' scenario the gradient is calculated once for
each iteration of the gradient descent, whereas in the 'controls-only'
scenario the gradient is calculated once for each control unit.
Specifically, each control unit is excluded from the set of units that can
be used to predict its own post-intervention outcomes, resulting in
leave-one-out gradient descent.

For large sample sizes in the 'controls-only' scenario, it may be
sufficient to divide the non-held-out control units into "gradient folds",
such that controls within the same gradient fold are not used to predict
the post-intervention outcomes of other control units in the same fold.
This results in K-fold gradient descent, which improves the speed of
calculating the overall gradient by a factor slightly greater than `c/k`
(where `c` is the number of control units) with an even greater reduction
in memory usage.

K-fold gradient descent is enabled by passing the parameter `grad_splits`
to `CV_score()`, and for consistency across calls to `CV_score()` it is
recommended to also pass a value to the parameter `random_state`, which is
used in selecting the gradient folds.
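A minimal sketch of enabling K-fold gradient descent follows; the
`grad_splits` and `random_state` parameters are as described above, while the
data arguments and penalty names are assumptions to be checked against the
`CV_score()` signature.

```python
import numpy as np
from SparseSC import CV_score

rng = np.random.RandomState(0)
X = rng.normal(size=(200, 20))  # features for 200 control units
Y = rng.normal(size=(200, 10))  # target (held-out) outcomes

# With 200 controls, grad_splits=10 replaces 200 leave-one-out gradient
# calculations with 10 fold-level calculations per descent step.
score = CV_score(
    X, Y,
    v_pen=0.1, w_pen=0.01,  # illustrative penalty values
    grad_splits=10,         # number of gradient folds
    random_state=42,        # reproducible fold assignment across calls
)
```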
## Additional Considerations

If you have the BLAS/LAPACK libraries installed and available to Python,
you should not need to do any further optimization to ensure that the
maximum number of processors is used during the execution of `CV_score()`.
If not, you may wish to set the parameter `parallel=True` when you call
`CV_score()`, which will split the work across N - 2 sub-processes, where N
is the [number of cores in your
machine](https://docs.python.org/2/library/multiprocessing.html#miscellaneous).
(Note that setting `parallel=True` when BLAS/LAPACK are available will tend
to increase running times.)

--------------------------------------------------------------------------------
/docs/rtd-base.txt:
--------------------------------------------------------------------------------

setuptools==54.0.0
mock==1.0.1
pillow==5.4.1
alabaster>=0.7,<0.8,!=0.7.5
commonmark==0.8.1
recommonmark==0.5.0
sphinx<2
sphinx-rtd-theme<0.5
readthedocs-sphinx-ext<2.2

--------------------------------------------------------------------------------
/docs/rtd-requirements.txt:
--------------------------------------------------------------------------------

numpy>=1.13.1
scikit-learn>=0.19.1
scipy>=0.19.1
Sphinx==1.6.3
sphinx-markdown-tables==0.0.15
statsmodels>=0.8.0
pyyaml>=5.4.1
psutil>=5.8.0
git+git://github.com/jdthorpe/batch-config@7eae164#egg=batch-config

--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/examples/__init__.py

--------------------------------------------------------------------------------
/examples/strip_magic.py:
--------------------------------------------------------------------------------

# Strip IPython magic lines so the nbconvert-generated script runs under plain Python.
with open("DifferentialTrends.py", "r+") as f:
    new_f = f.readlines()
    f.seek(0)
    for line in new_f:
        if "get_ipython().magic" not in line:
            f.write(line)
    f.truncate()

--------------------------------------------------------------------------------
/makefile:
--------------------------------------------------------------------------------

#On Windows can use 'nmake'
# Incl w/ VS (w/ "Desktop development with C++" components)
# Incl in path or use the "Developer Command Prompt for VS...."
# If manually including the folder in path, this can change when VS's MSVC version gets bumped.
# So to always get it you can run (or change for your system):
# "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Auxiliary\Build\vcvars64.bat"
# Can run 'nmake /NOLOGO ...' to remove logo output.
#nmake can't do automatic (pattern) rules like make (has inference rules which aren't cross-platform)

# For linux, can use Anaconda or if using virtualenv, install virtualenvwrapper
# and alias activate->workon.
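
# Note on the "!ifndef 0" block after the help target below: nmake evaluates
# the !if directives, while GNU make never sees them -- each "# \" line ends
# with a backslash, so make folds the whole nmake branch into one long comment.
# The two tools therefore pick up different RMDIR_CMD / RM_CMD / DIR_SEP
# definitions from the same file.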
12 | 13 | help: 14 | @echo "Use one of the common targets: pylint, package, readmedocs, htmldocs" 15 | 16 | #Allow for slightly different commands for nmake and make 17 | #NB: Don't always need the different DIR_SEP 18 | # \ 19 | !ifndef 0 # \ 20 | # nmake specific code here \ 21 | RMDIR_CMD = rmdir /S /Q # \ 22 | RM_CMD = del # \ 23 | DIR_SEP = \ # \ 24 | !else 25 | # make specific code here 26 | RMDIR_CMD = rm -rf 27 | RM_CMD = rm 28 | DIR_SEP = /# \ 29 | # \ 30 | !endif 31 | 32 | #Creates a "Source Distribution" and a "Pure Python Wheel" (which is a bit easier for user) 33 | package: package_both 34 | 35 | package_both: 36 | python setup.py sdist bdist_wheel 37 | 38 | package_sdist: 39 | python setup.py sdist 40 | 41 | package_bdist_wheel: 42 | python setup.py bdist_wheel 43 | 44 | pypi_upload: 45 | twine upload dist/* 46 | #python -m twine upload --verbose --repository-url https://test.pypi.org/legacy/ dist/* 47 | 48 | pylint: 49 | -mkdir build 50 | -pylint SparseSC > build$(DIR_SEP)pylint_msgs.txt 51 | 52 | SOURCEDIR = docs/ 53 | BUILDDIR = docs/build 54 | BUILDDIRHTML = docs$(DIR_SEP)build$(DIR_SEP)html 55 | BUILDAPIDOCDIR= docs$(DIR_SEP)build$(DIR_SEP)apidoc 56 | 57 | #$(O) is meant as a shortcut for $(SPHINXOPTS). 58 | htmldocs: 59 | -$(RMDIR_CMD) $(BUILDDIRHTML) 60 | -$(RMDIR_CMD) $(BUILDAPIDOCDIR) 61 | # sphinx-apidoc -f -o $(BUILDAPIDOCDIR)/SparseSC SparseSC 62 | # $(RM_CMD) $(BUILDAPIDOCDIR)$(DIR_SEP)SparseSC$(DIR_SEP)modules.rst 63 | @python -msphinx -b html -T -E "$(SOURCEDIR)" "$(BUILDDIR)" $(O) 64 | 65 | examples: 66 | python example-code.py 67 | python examples/fit_poc.py 68 | 69 | tests: 70 | python -m unittest test.test_fit.TestFitForErrors test.test_fit.TestFitFastForErrors test.test_normal.TestNormalForErrors test.test_estimation.TestEstimationForErrors 71 | 72 | #tests_both: 73 | # activate SparseSC_36 && python -m unittest test.test_fit 74 | 75 | #add examples here when working 76 | check: pylint package_bdist_wheel tests_both 77 | 78 | #Have to strip because of unfixed https://github.com/jupyter/nbconvert/issues/503 79 | examples/DifferentialTrends.py: examples/DifferentialTrends.ipynb 80 | jupyter nbconvert examples/DifferentialTrends.ipynb --to script 81 | cd examples && python strip_magic.py 82 | 83 | examples/DifferentialTrends.html: examples/DifferentialTrends.ipynb 84 | jupyter nbconvert examples/DifferentialTrends.ipynb --to html 85 | 86 | examples/DifferentialTrends.pdf: examples/DifferentialTrends.ipynb 87 | cd examples && jupyter nbconvert DifferentialTrends.ipynb --to pdf 88 | 89 | clear_ipynb_output: 90 | jupyter nbconvert --ClearOutputPreprocessor.enabled=True --inplace examples/DifferentialTrends.ipynb 91 | gen_ipynb_output: 92 | jupyter nbconvert --to notebook --execute examples/DifferentialTrends.ipynb 93 | 94 | #Have to cd into subfulder otherwise will pick up potential SparseSC pkg in build/ 95 | #TODO: Make the prefix filter automatic 96 | #TODO: check if this way of doing phony targets for nmake works with make 97 | test/SparseSC_36.yml: .phony 98 | activate SparseSC_36 && cd test && conda env export > SparseSC_36.yml 99 | echo Make sure to remove the last prefix line and the pip sparsesc line, as user does pip install -e for that 100 | .phony: 101 | 102 | #Old: 103 | # Don't generate requirements-rtd.txt from conda environments (e.g. pip freeze > rtd-requirements.txt) 104 | # 1) Can be finicky to get working since using pip and docker images and don't need lots of packages (e.g. for Jupyter) 105 | # 2) Github compliains about requests<=2.19.1. 
Conda can't install 2.20 w/ Python <3.6. Our env is 3.5, but RTD uses Python3.7 106 | # Could switch to using conda 107 | #doc/rtd-requirements.txt: 108 | 109 | conda_env_upate: 110 | deactivate && conda env update -f test/SparseSC_36.yml 111 | 112 | #Just needs to be done once 113 | conda_env_create: 114 | conda env create -f test/SparseSC_36.yml 115 | 116 | jupyter_DifferentialTrends: 117 | START jupyter notebook examples/DifferentialTrends.ipynb > jupyter_launch.log 118 | -------------------------------------------------------------------------------- /replication/.gitignore: -------------------------------------------------------------------------------- 1 | dta_dir/ 2 | 3 | -------------------------------------------------------------------------------- /replication/figs/sparsesc_fast_pe.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/replication/figs/sparsesc_fast_pe.pdf -------------------------------------------------------------------------------- /replication/figs/sparsesc_fast_xf2_pe.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/replication/figs/sparsesc_fast_xf2_pe.pdf -------------------------------------------------------------------------------- /replication/figs/sparsesc_fast_xf_pe.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/replication/figs/sparsesc_fast_xf_pe.pdf -------------------------------------------------------------------------------- /replication/figs/sparsesc_full_pe.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/replication/figs/sparsesc_full_pe.pdf -------------------------------------------------------------------------------- /replication/figs/sparsesc_full_xf2_pe.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/replication/figs/sparsesc_full_xf2_pe.pdf -------------------------------------------------------------------------------- /replication/figs/sparsesc_full_xf_pe.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/replication/figs/sparsesc_full_xf_pe.pdf -------------------------------------------------------------------------------- /replication/figs/standard_flat55_pe.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/replication/figs/standard_flat55_pe.pdf -------------------------------------------------------------------------------- /replication/figs/standard_nested_pe.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/replication/figs/standard_nested_pe.pdf -------------------------------------------------------------------------------- /replication/figs/standard_spfast_pe.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/replication/figs/standard_spfast_pe.pdf -------------------------------------------------------------------------------- /replication/figs/standard_spfast_xf2_pe.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/replication/figs/standard_spfast_xf2_pe.pdf -------------------------------------------------------------------------------- /replication/figs/standard_spfast_xf_pe.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/replication/figs/standard_spfast_xf_pe.pdf -------------------------------------------------------------------------------- /replication/figs/standard_spfull_pe.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/replication/figs/standard_spfull_pe.pdf -------------------------------------------------------------------------------- /replication/figs/standard_spfull_xf2_pe.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/replication/figs/standard_spfull_xf2_pe.pdf -------------------------------------------------------------------------------- /replication/figs/standard_spfull_xf_pe.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/replication/figs/standard_spfull_xf_pe.pdf -------------------------------------------------------------------------------- /replication/smoking.dta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/replication/smoking.dta -------------------------------------------------------------------------------- /replication/vmats/fast_fit.txt: -------------------------------------------------------------------------------- 1 | beer(1986) lnincome(1985) lnincome(1987) lnincome(1988) retprice(1985) age15to24(1988) cigsale(1986) cigsale(1987) cigsale(1988) 2 | 0.5130071259983364 22.968210890267105 66.11994103276801 20.372467610421825 0.3953643337713626 146.42793943085758 0.40076349898165653 1.0914176928934336 1.6124490971940708 3 | -------------------------------------------------------------------------------- /replication/vmats/full_fit.txt: -------------------------------------------------------------------------------- 1 | beer beer(1984) beer(1985) beer(1986) beer(1987) beer(1988) cigsale(1970) cigsale(1971) cigsale(1972) cigsale(1973) cigsale(1974) cigsale(1975) cigsale(1976) cigsale(1977) cigsale(1978) cigsale(1979) cigsale(1980) cigsale(1981) cigsale(1982) cigsale(1983) cigsale(1984) cigsale(1985) cigsale(1986) cigsale(1987) cigsale(1988) 2 | 0.0001931675090059052 0.00011673718456481037 0.00017442025581167862 0.00019492525551082786 0.00023112626647203255 0.0002485464241715504 0.04316485032629001 0.04855303157528119 0.05649902721742458 0.05969122641855661 0.0605678608204393 0.0626985287374379 0.0689133161012072 
0.06649196460099964 0.06167250333634058 0.05650907357046557 0.05510538799749318 0.052133342737103204 0.05262319599234059 0.048749408792219784 0.04201710074463961 0.04108083109059892 0.04125358244144677 0.040033436688557196 0.04108340791562128 3 | -------------------------------------------------------------------------------- /replication/vmats/xf_fits_fast2.txt: -------------------------------------------------------------------------------- 1 | 23 lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988) 2 | 23 52.9138366517419 9.288186511419811 0.8048612299517014 0.20594545615924512 0.4714023965271641 0.9417936107673718 1.3465241057644206 3 | 9 lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988) 4 | 9 32.650879763821514 37.02197484931736 0.542540877404274 0.014584754143649414 0.4798681872289629 0.9368634841278982 1.371249208763717 5 | 27 lnincome(1985) lnincome(1987) retprice(1985) cigsale(1986) cigsale(1987) cigsale(1988) 6 | 27 32.004396105062035 30.46557116615753 0.6970722297526234 0.38688970430837605 0.8419576015155998 1.5282063874976552 7 | 14 lnincome(1985) lnincome(1987) retprice(1985) cigsale(1984) cigsale(1987) cigsale(1988) 8 | 14 41.19800297067463 15.510595853489166 0.722493021351017 0.0205302688781237 0.9289118032490858 1.7545389645850142 9 | 25 lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988) 10 | 25 37.216813057528164 25.221489830962824 0.657228392750966 0.02454777512573974 0.43500171319846226 0.8125446272254442 1.5057689475284177 11 | 4 lnincome(1985) lnincome(1987) retprice(1985) cigsale(1986) cigsale(1987) cigsale(1988) 12 | 4 22.033380508626728 37.95607243445006 0.7392112404836514 0.2679160799269573 0.8907623648242043 1.5853183181096129 13 | 2 lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988) 14 | 2 23.75657592484401 41.23483140470908 0.6910275048224122 0.04477245768082374 0.32116721757583316 0.916959715791747 1.5374270606732956 15 | 28 beer(1988) lnincome(1985) lnincome(1987) retprice(1985) cigsale(1986) cigsale(1988) 16 | 28 0.4278699614050563 18.808995663124076 51.15028075063535 0.3855812959197057 0.42320896120863377 2.5603609090753428 17 | 3 lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988) 18 | 3 35.774461358952315 24.42133906654064 0.70260293128289 0.03315716245442422 0.4498448389541703 0.8010006448441934 1.4886819652342287 19 | 6 beer(1986) lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988) 20 | 6 0.09165879470209615 21.0404041711949 44.783387436919206 0.6866341032510951 0.15078832914078785 0.12407411039828849 1.1316974441477015 1.542310302014068 21 | 5 beer(1988) lnincome(1983) lnincome(1988) retprice(1980) retprice(1985) retprice(1988) age15to24(1988) cigsale(1982) cigsale(1984) cigsale(1986) cigsale(1987) cigsale(1988) 22 | 5 1.441267670831488 16.17283654516924 104.89648163273378 0.09551657763150784 0.3641232956059621 0.33445253937605773 851.0159554272302 0.10935326306525987 0.028820136107687223 0.003648718605989471 1.416359875540445 1.6223994793627619 23 | 16 lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988) 24 | 16 37.614732176886335 27.89280012177113 0.620584501883219 0.08323182778087279 0.4319512962017977 0.8848872277804385 1.4643646690933532 25 | 32 lnincome(1987) retprice(1985) cigsale(1986) cigsale(1987) 
cigsale(1988) 26 | 32 70.46915458836764 0.5141370149716578 0.7971683142142422 0.8480369389565701 1.2087497747969596 27 | 31 lnincome(1985) lnincome(1987) retprice(1985) cigsale(1986) cigsale(1987) cigsale(1988) 28 | 31 19.951200540111767 38.488804338029084 0.6758079493890541 0.48036987327851277 0.7061985896568476 1.5592920543203923 29 | 34 lnincome(1985) lnincome(1987) retprice(1985) cigsale(1986) cigsale(1987) cigsale(1988) 30 | 34 39.40376183490209 22.354550796059744 0.6888351098369737 0.41524836151985645 0.7825350950415934 1.5554797074795266 31 | 15 lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988) 32 | 15 30.132849920047796 30.779908003822648 0.5972116627569648 0.11990144050863925 0.35550657037139605 0.8863180945932602 1.5035821890610244 33 | 11 beer(1986) lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988) 34 | 11 0.002505453803009291 26.818904888838087 38.1254682990119 0.6921247832162963 0.025809909197581832 0.3455231672678883 0.8743694551676355 1.5570267410097178 35 | 21 beer(1986) lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988) 36 | 21 0.4589963686470479 10.137739000768292 63.320284106226666 0.49840518043861753 0.028262507662936767 0.39409111305372047 1.2395339989892469 1.2661476752987006 37 | 17 lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988) 38 | 17 64.5033184509774 9.574472312436525 0.5627863776716775 0.005630207000854121 0.8150684830960727 0.701261488167543 1.2831255659946963 39 | 26 lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988) 40 | 26 53.81482286497893 4.8721745069110325 0.6285881994463606 0.11099082230607583 0.27774834257279696 0.7975338506147123 1.644516378562484 41 | 1 lnincome(1985) lnincome(1987) retprice(1985) cigsale(1986) cigsale(1987) cigsale(1988) 42 | 1 35.321178164904296 18.753395130096887 0.7248123993386513 0.4455704580670278 0.7709306503211723 1.4955447491570548 43 | 10 lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988) 44 | 10 26.172315514404612 40.5066411648697 0.6210816258045354 0.08574858166221508 0.3894929472153092 0.9582031390511888 1.4532274644713794 45 | 33 beer(1986) lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988) 46 | 33 0.478314791014334 28.231179947143886 41.24021289768206 0.45429005823127383 0.13228216707157708 0.31092784094021314 1.1063592582571 1.357207450058452 47 | 7 lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988) 48 | 7 38.44525779541849 22.241683904867678 0.6672100263119967 0.00020003168746482707 0.4358017018642251 0.790338563123268 1.5019443042961558 49 | 19 lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988) 50 | 19 33.33300260657349 28.87629219649373 0.646379658811407 0.03824224699002728 0.3994839059986097 0.7982194807630935 1.5626017823817469 51 | 22 lnincome(1983) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988) 52 | 22 9.828161509868865 57.55044358391577 0.48095435298865974 0.2830639710490712 0.33786416421453 0.8822333421833521 1.4843815147743358 53 | 24 lnincome(1988) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988) 54 | 24 61.8693982343369 0.6118933951108044 0.07169034659486957 0.24410578816292353 
0.9136417430117386 1.616141985352236 55 | 36 lnincome(1985) lnincome(1987) lnincome(1988) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988) 56 | 36 14.475861078773029 26.943137832098724 15.92435359735392 0.9875744643350647 0.041163824055681746 0.16873475155440118 1.21535634363386 1.4074242537938177 57 | 30 lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988) 58 | 30 33.17457684604746 29.073985497061805 0.5162346705854688 0.11486578771775793 0.4982760481771537 0.8095874955348279 1.4697053562688922 59 | 12 lnincome(1984) lnincome(1987) retprice(1985) cigsale(1987) cigsale(1988) 60 | 12 26.92911058554163 6.155518834714997 0.6285961030571781 1.0346295966649277 1.4340086985377358 61 | 18 lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988) 62 | 18 19.45223698594143 41.78134781747795 0.6792763549942265 0.003319460377707231 0.24435409463401464 0.9326189526484684 1.5990422694981496 63 | 38 lnincome(1985) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988) 64 | 38 72.4455632990079 0.4853074391447769 0.10797260029294552 0.3127912565447208 1.149499721947342 1.3325911487315836 65 | 35 lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988) 66 | 35 21.367246762453107 25.873179715901948 0.9101086883278595 0.17639448181510284 0.1591127187569252 0.9462375093430284 1.5790071062927429 67 | 13 lnincome(1985) lnincome(1987) retprice(1985) cigsale(1986) cigsale(1987) cigsale(1988) 68 | 13 45.663501999478804 10.455542220073793 0.8321049818628562 0.34649744106868596 0.7978118480212745 1.5693247671585517 69 | 37 beer(1986) lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988) 70 | 37 0.20369210499860185 28.583256638938593 35.23310804285026 0.7322974820605626 0.013874845991732802 0.44732514640064847 0.8581754884927489 1.4884211448589109 71 | 29 lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988) 72 | 29 35.18711223047968 26.12276075274429 0.699080302144707 0.06147429869908371 0.3675090246616547 0.8458766224462938 1.5266674252392296 73 | 20 lnincome(1987) lnincome(1988) retprice(1985) retprice(1988) cigsale(1987) cigsale(1988) 74 | 20 59.399857194938505 10.616584838603158 0.24505149601082157 0.0521000199791986 1.3759206906091328 1.580541231282191 75 | 8 lnincome(1985) retprice(1985) cigsale(1986) cigsale(1987) cigsale(1988) 76 | 8 48.170218174935776 0.7191122125984405 0.7690613403578862 0.44355310614471416 1.4764919171905777 77 | -------------------------------------------------------------------------------- /replication/vmats/xf_fits_full.txt: -------------------------------------------------------------------------------- 1 | 23 cigsale(1977) cigsale(1986) cigsale(1987) cigsale(1988) 2 | 23 3.220031805241539e-05 0.0004541363359538803 0.0003971751428356052 0.0016313917361968722 3 | 9 cigsale(1986) cigsale(1987) cigsale(1988) 4 | 9 0.0003952631426078369 0.0003601154537986567 0.0011483353677890413 5 | 27 cigsale(1977) cigsale(1986) cigsale(1987) cigsale(1988) 6 | 27 1.8871490715416953e-05 0.0004399069826426927 0.00043629534490380686 0.0010540606594254798 7 | 14 cigsale(1986) cigsale(1987) cigsale(1988) 8 | 14 0.0003229640528205937 0.0003688476045066298 0.000956740279593809 9 | 25 cigsale(1986) cigsale(1987) cigsale(1988) 10 | 25 0.0004202607334967443 0.0004268848440616784 0.0010671839324497162 11 | 4 cigsale(1986) cigsale(1987) 
cigsale(1988) 12 | 4 0.0003196300849811584 0.00032930779514542914 0.0016052914489879876 13 | 2 cigsale(1986) cigsale(1987) cigsale(1988) 14 | 2 0.00038734543047050905 0.00037124265455011347 0.0010588787288313994 15 | 28 cigsale(1982) cigsale(1986) cigsale(1988) 16 | 28 1.024800187075679e-05 0.00024196009018830406 0.0025578912517229798 17 | 3 cigsale(1986) cigsale(1987) cigsale(1988) 18 | 3 0.00040220420600810497 0.0003571963463967961 0.000890850495876067 19 | 6 cigsale(1977) cigsale(1986) cigsale(1987) cigsale(1988) 20 | 6 1.2390431604427915e-07 0.00046409781938842903 0.0002743527155968748 0.0017239069267463766 21 | 5 cigsale(1977) cigsale(1986) cigsale(1987) cigsale(1988) 22 | 5 2.7787923373098215e-05 0.0003434544188081404 0.0003422614209224295 0.0008778496833327705 23 | 16 cigsale(1986) cigsale(1987) cigsale(1988) 24 | 16 0.0003784240793891911 0.0003674397411472507 0.0009145390885686904 25 | 32 cigsale(1982) cigsale(1983) cigsale(1986) cigsale(1987) cigsale(1988) 26 | 32 2.7245185321026544e-05 2.655913644964166e-05 0.0005845624034860476 0.00046112657110766037 0.0010158811024341598 27 | 31 cigsale(1977) cigsale(1982) cigsale(1986) cigsale(1987) cigsale(1988) 28 | 31 0.00018248850627347452 1.4800388877421827e-05 0.0004229485947024648 0.0003769731202171168 0.0010430100065317626 29 | 34 cigsale(1977) cigsale(1982) cigsale(1986) cigsale(1987) cigsale(1988) 30 | 34 0.00013728124270578521 3.721020977056891e-05 0.0004237767444671827 0.00044270188843288644 0.0012123166239067715 31 | 15 cigsale(1985) cigsale(1986) cigsale(1987) cigsale(1988) 32 | 15 1.494300101581489e-05 0.0003595505300900124 0.00041894917648444024 0.0008757901102167003 33 | 11 cigsale(1977) cigsale(1986) cigsale(1987) cigsale(1988) 34 | 11 9.333161499217062e-06 0.0003719157586207366 0.00033404170597104007 0.0009539006346480628 35 | 21 cigsale(1986) cigsale(1987) cigsale(1988) 36 | 21 0.00045315631293933257 0.00038923440996003735 0.0013524841267793333 37 | 17 cigsale(1986) cigsale(1987) cigsale(1988) 38 | 17 0.0003852429561616505 0.00033760087204227126 0.000882231757549232 39 | 26 cigsale(1977) cigsale(1982) cigsale(1983) cigsale(1986) cigsale(1987) cigsale(1988) 40 | 26 0.00011592293176551638 7.637050088258385e-06 3.2884018372150726e-06 0.00042086462533162314 0.00042901457399833744 0.00114115134113562 41 | 1 cigsale(1986) cigsale(1987) cigsale(1988) 42 | 1 0.0003869479555100085 0.00036510350377175503 0.0010168374051120169 43 | 10 cigsale(1986) cigsale(1987) cigsale(1988) 44 | 10 0.0003878740099791977 0.00035772101520275027 0.0009611957205379999 45 | 33 cigsale(1986) cigsale(1987) cigsale(1988) 46 | 33 0.00035909168155274026 0.0004524634847694947 0.0009077421449587082 47 | 7 cigsale(1986) cigsale(1987) cigsale(1988) 48 | 7 0.00036534733256354617 0.00036294902200991505 0.0009171959812324014 49 | 19 cigsale(1986) cigsale(1987) cigsale(1988) 50 | 19 0.00037962900054159084 0.00035386606609022855 0.0009644690512068232 51 | 22 cigsale(1986) cigsale(1987) cigsale(1988) 52 | 22 0.0004390091607431598 0.0003705713962368925 0.0010729243300373226 53 | 24 cigsale(1986) cigsale(1987) cigsale(1988) 54 | 24 0.00043586384804969745 0.0004063534938212729 0.0011187371006597577 55 | 36 cigsale(1985) cigsale(1986) cigsale(1987) cigsale(1988) 56 | 36 1.2467606071326915e-05 0.000441784626549294 0.0005516730720741426 0.0011395891829565716 57 | 30 cigsale(1986) cigsale(1987) cigsale(1988) 58 | 30 0.0004292474615245855 0.00043906072446289914 0.0009789791169024593 59 | 12 cigsale(1986) cigsale(1987) cigsale(1988) 60 | 12 0.00015634756678617497 
0.00046524598371104576 0.0010862701532716264 61 | 18 cigsale(1986) cigsale(1987) cigsale(1988) 62 | 18 0.0004345321289109919 0.00021028873110012252 0.0018020174936725004 63 | 38 cigsale(1986) cigsale(1987) cigsale(1988) 64 | 38 0.00040178596776896587 0.0005333706872139673 0.001150960005824483 65 | 35 cigsale(1977) cigsale(1986) cigsale(1987) cigsale(1988) 66 | 35 6.88134845640188e-06 0.0004183524461092219 0.00045105353956890554 0.0010166728003982875 67 | 13 cigsale(1986) cigsale(1987) cigsale(1988) 68 | 13 0.0003543644684414958 0.0003600256236189256 0.0009416488053131033 69 | 37 cigsale(1986) cigsale(1987) cigsale(1988) 70 | 37 0.00042587151170159475 0.0004661787252079279 0.0011675865092952683 71 | 29 cigsale(1986) cigsale(1987) cigsale(1988) 72 | 29 0.000481981960140239 0.0004486368134434901 0.0011979810228173145 73 | 20 cigsale(1977) cigsale(1986) cigsale(1987) cigsale(1988) 74 | 20 2.008220371182916e-05 0.0002263114136367005 0.0004515698795582304 0.001129838472899092 75 | 8 cigsale(1986) cigsale(1987) cigsale(1988) 76 | 8 0.0004636874855871064 0.0002868862186762217 0.0010817068639995256 77 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | from __future__ import absolute_import 4 | from __future__ import print_function 5 | import re 6 | from glob import glob 7 | from os.path import basename, dirname, join, splitext, abspath 8 | from setuptools import find_packages, setup 9 | import codecs 10 | 11 | # Allow single version in source file to be used here 12 | # From https://packaging.python.org/guides/single-sourcing-package-version/ 13 | def read(*parts): 14 | # intentionally *not* adding an encoding option to open 15 | # see here: https://github.com/pypa/virtualenv/issues/201#issuecomment-3145690 16 | here = abspath(dirname(__file__)) 17 | return codecs.open(join(here, *parts), "r").read() 18 | 19 | 20 | def find_version(*file_paths): 21 | version_file = read(*file_paths) 22 | version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M) 23 | if version_match: 24 | return version_match.group(1) 25 | raise RuntimeError("Unable to find version string.") 26 | 27 | 28 | setup( 29 | name="SparseSC", 30 | version=find_version("src", "SparseSC", "__init__.py"), 31 | description="Sparse Synthetic Controls", 32 | license="MIT", 33 | long_description="%s\n%s" 34 | % ( 35 | re.compile("^.. start-badges.*^.. 
end-badges", re.M | re.S).sub( 36 | "", read("README.md") 37 | ), 38 | re.sub(":[a-z]+:`~?(.*?)`", r"``\1``", read("CHANGELOG.md")), 39 | ), 40 | long_description_content_type="text/markdown", 41 | author="Microsoft Research", 42 | url="https://github.com/Microsoft/SparseSyntheticControls", 43 | packages=find_packages("src"), 44 | package_dir={"": "src"}, 45 | py_modules=[splitext(basename(path))[0] for path in glob("src/*.py")], 46 | include_package_data=True, 47 | zip_safe=False, 48 | classifiers=[ 49 | # complete classifier list: http://pypi.python.org/pypi?%3Aaction=list_classifiers 50 | "Development Status :: 5 - Production/Stable", 51 | "Intended Audience :: Developers", 52 | "License :: OSI Approved :: BSD License", 53 | "Operating System :: Unix", 54 | "Operating System :: POSIX", 55 | "Operating System :: Microsoft :: Windows", 56 | "Programming Language :: Python", 57 | "Programming Language :: Python :: 3", 58 | "Programming Language :: Python :: 3.4", 59 | "Programming Language :: Python :: 3.5", 60 | "Programming Language :: Python :: 3.6", 61 | "Programming Language :: Python :: 3.7", 62 | "Programming Language :: Python :: Implementation :: CPython", 63 | "Programming Language :: Python :: Implementation :: PyPy", 64 | "Topic :: Utilities", 65 | ], 66 | keywords=["Sparse", "Synthetic", "Controls"], 67 | install_requires=["numpy", "Scipy", "scikit-learn", "pandas", "pyyaml"], 68 | entry_points={ 69 | "console_scripts": [ 70 | "scgrad=SparseSC.cli.scgrad:main", 71 | "daemon=SparseSC.cli.daemon_process:main", 72 | "stt=SparseSC.cli.stt:main", 73 | ] 74 | }, 75 | ) 76 | -------------------------------------------------------------------------------- /src/SparseSC/SparseSC.pyproj: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | Debug 5 | 2.0 6 | {3fdc664a-c1c8-47f0-8d77-8e0679e53c82} 7 | 8 | 9 | 10 | ..\ 11 | . 12 | . 
13 | {888888a0-9f3d-457c-b088-3a5042f75d52} 14 | Standard Python launcher 15 | 16 | 17 | 18 | 19 | 20 | 10.0 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | Code 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | Code 52 | 53 | 54 | Code 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /src/SparseSC/__init__.py: -------------------------------------------------------------------------------- 1 | """ Public API for SparseSC 2 | """ 3 | 4 | # PRIMARY FITTING FUNCTIONS 5 | from SparseSC.fit import fit, TrivialUnitsWarning 6 | from SparseSC.fit_fast import fit_fast, _fit_fast_inner, _fit_fast_match 7 | from SparseSC.utils.match_space import ( 8 | keras_reproducible, MTLassoCV_MatchSpace_factory, MTLasso_MatchSpace_factory, MTLassoMixed_MatchSpace_factory, MTLSTMMixed_MatchSpace_factory, 9 | Fixed_V_factory, D_LassoCV_MatchSpace_factory 10 | ) 11 | from SparseSC.utils.penalty_utils import RidgeCVSolution 12 | from SparseSC.fit_loo import loo_v_matrix, loo_weights, loo_score 13 | from SparseSC.fit_ct import ct_v_matrix, ct_weights, ct_score 14 | 15 | # ESTIMATION FUNCTIONS 16 | from SparseSC.estimate_effects import estimate_effects, get_c_predictions_honest 17 | from SparseSC.utils.dist_summary import SSC_DescrStat, Estimate 18 | from SparseSC.utils.descr_sets import DescrSet, MatchingEstimate 19 | 20 | # Public API 21 | from SparseSC.cross_validation import ( 22 | score_train_test, 23 | score_train_test_sorted_v_pens, 24 | CV_score, 25 | ) 26 | from SparseSC.tensor import tensor 27 | from SparseSC.weights import weights 28 | from SparseSC.utils.penalty_utils import get_max_w_pen, get_max_v_pen, w_pen_guestimate 29 | 30 | # The version as used in the setup.py 31 | __version__ = "0.2.0" 32 | -------------------------------------------------------------------------------- /src/SparseSC/cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/src/SparseSC/cli/__init__.py -------------------------------------------------------------------------------- /src/SparseSC/cli/daemon.py: -------------------------------------------------------------------------------- 1 | """Generic linux daemon base class for python 3.x.""" 2 | # pylint: disable=multiple-imports 3 | 4 | import sys, os, time, atexit, signal 5 | 6 | 7 | class Daemon: 8 | """A generic daemon class. 9 | 10 | Usage: subclass the daemon class and override the run() method. 11 | """ 12 | 13 | def __init__(self, pidfile, fifofile): 14 | self.pidfile = pidfile 15 | self.fifofile = fifofile 16 | 17 | def daemonize(self): 18 | """Deamonize class. 
UNIX double fork mechanism."""

        try:
            pid = os.fork()
            if pid > 0:
                # exit first parent
                sys.exit(0)
        except OSError as err:
            sys.stderr.write("fork #1 failed: {0}\n".format(err))
            sys.exit(1)

        # decouple from parent environment
        os.chdir("/")
        os.setsid()
        os.umask(0)

        # do second fork
        try:
            pid = os.fork()
            if pid > 0:
                # exit from second parent
                sys.exit(0)
        except OSError as err:
            sys.stderr.write("fork #2 failed: {0}\n".format(err))
            sys.exit(1)

        # redirect standard file descriptors
        sys.stdout.flush()
        sys.stderr.flush()
        si = open(os.devnull, "r")
        so = open(os.devnull, "a+")
        se = open(os.devnull, "a+")

        os.dup2(si.fileno(), sys.stdin.fileno())
        os.dup2(so.fileno(), sys.stdout.fileno())
        os.dup2(se.fileno(), sys.stderr.fileno())

        # write pidfile
        atexit.register(self.delpid)

        os.mkfifo(self.fifofile)  # pylint: disable=no-member

        pid = str(os.getpid())
        with open(self.pidfile, "w+") as f:
            f.write(pid + "\n")

        workdir = os.getenv("AZ_BATCH_TASK_WORKING_DIR", "/tmp")
        sys.stdout = open(os.path.join(workdir, "sc-out.txt"), "a+")
        sys.stderr = open(os.path.join(workdir, "sc-err.txt"), "a+")
        print("daemonized >>>>>>>>>>>"); sys.stdout.flush()

    def delpid(self):
        "clean up"
        try:
            os.remove(self.pidfile)
        except:
            print("failed to remove pidfile"); sys.stdout.flush()
            raise RuntimeError("failed to remove pidfile")
        else:
            print("removed pidfile"); sys.stdout.flush()
        try:
            os.remove(self.fifofile)
        except:
            print("failed to remove fifofile"); sys.stdout.flush()
            raise RuntimeError("failed to remove fifofile")
        else:
            print("removed fifofile"); sys.stdout.flush()

    def start(self):
        """Start the daemon."""

        # Check for a pidfile to see if the daemon already runs
        try:
            with open(self.pidfile, "r") as pf:
                pid = int(pf.read().strip())
        except IOError:
            pid = None

        if pid:
            message = "pidfile {0} already exists. " + "Daemon already running?\n"
            sys.stderr.write(message.format(self.pidfile))
            sys.exit(1)

        # Start the daemon
        self.daemonize()
        self.run()

    def stop(self):
        """Stop the daemon."""

        # Get the pid from the pidfile
        try:
            with open(self.pidfile, "r") as pf:
                pid = int(pf.read().strip())
        except IOError:
            pid = None

        if not pid:
            message = "pidfile {0} does not exist. " + "Daemon not running?\n"
            sys.stderr.write(message.format(self.pidfile))
            return  # not an error in a restart

        # Try killing the daemon process
        try:
            while 1:
                os.kill(pid, signal.SIGTERM)
                time.sleep(0.1)
        except OSError as err:
            e = str(err.args)
            if e.find("No such process") > 0:
                if os.path.exists(self.pidfile):
                    os.remove(self.pidfile)
                if os.path.exists(self.fifofile):
                    os.remove(self.fifofile)
            else:
                print(str(err.args))
                sys.exit(1)

    def restart(self):
        """Restart the daemon."""
        self.stop()
        self.start()

    def run(self):
        """You should override this method when you subclass Daemon.
It will be called after the process has been daemonized by
        start() or restart()."""
--------------------------------------------------------------------------------
/src/SparseSC/cli/daemon_process.py:
--------------------------------------------------------------------------------

"""
A module to run as a background process

pip install psutil

usage:
    python -m SparseSC.cli.daemon_process start
    python -m SparseSC.cli.daemon_process stop
    python -m SparseSC.cli.daemon_process status
"""
# pylint: disable=multiple-imports

import sys, os, time, atexit, signal, json, psutil
from yaml import load, dump
try:
    from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
    from yaml import Loader, Dumper

from .scgrad import grad_part, DAEMON_FIFO, DAEMON_PID, _BASENAMES

#-- DAEMON_FIFO = "/tmp/sc-daemon.fifo"
#-- DAEMON_PID = "/tmp/sc-gradient-daemon.pid"
#--
#-- _CONTAINER_OUTPUT_FILE = "output.yaml"  # Standard Output file
#-- _GRAD_COMMON_FILE = "common.yaml"
#-- _GRAD_PART_FILE = "part.yaml"
#--
#-- _BASENAMES = [_GRAD_COMMON_FILE, _GRAD_PART_FILE, _CONTAINER_OUTPUT_FILE]


pidfile, fifofile = DAEMON_PID, DAEMON_FIFO


def stop():
    """Stop the daemon."""

    # Get the pid from the pidfile
    try:
        with open(pidfile, "r") as pf:
            pid = int(pf.read().strip())
    except IOError:
        pid = None

    if not pid:
        message = "pidfile {0} does not exist. " + "Daemon not running?\n"
        sys.stderr.write(message.format(pidfile))
        return  # not an error in a restart

    # Try killing the daemon process
    try:
        while 1:
            os.kill(pid, signal.SIGTERM)
            time.sleep(0.1)
    except OSError as err:
        e = str(err.args)
        if e.find("No such process") > 0:
            if os.path.exists(pidfile):
                os.remove(pidfile)
            if os.path.exists(fifofile):
                os.remove(fifofile)
        else:
            print(str(err.args))
            sys.exit(1)


def run():
    """
    do work
    """
    while True:
        with open(DAEMON_FIFO, "r") as fifo:
            return_fifo = None  # known only after the request parses; checked on the error path
            try:
                params = fifo.read()
                print("params: " + params)
                sys.stdout.flush()
                tmpdirname, return_fifo, k = json.loads(params)
                print(_BASENAMES)
                for file in os.listdir(tmpdirname):
                    print(file)
                common_file, part_file, out_file = [
                    os.path.join(tmpdirname, name) for name in _BASENAMES
                ]
                print([common_file, part_file, out_file, return_fifo, k])
                sys.stdout.flush()

                # LOAD IN THE INPUT FILES
                with open(common_file, "r") as fp:
                    common = load(fp, Loader=Loader)
                with open(part_file, "r") as fp:
                    part = load(fp, Loader=Loader)

                # DO THE WORK
                print("about to do work: ")
                sys.stdout.flush()
                grad = grad_part(common, part, int(k))
                print("did work: ")
                sys.stdout.flush()

                # DUMP THE RESULT TO THE OUTPUT FILE
                with open(out_file, "w") as fp:
                    fp.write(dump(grad, Dumper=Dumper))

            except Exception as err:  # pylint: disable=broad-except

                # SOMETHING WENT WRONG, RESPOND WITH A NON-ZERO
                try:
                    with open(return_fifo, "w") as rf:
                        rf.write("1")
                except:  # pylint: disable=bare-except
                    print("double failed...: ")
                    sys.stdout.flush()
                else:
                    print(
                        "failed with {}: {}".format(
                            err.__class__.__name__,
                            getattr(err, "message", "<>"),
                        )
                    )
                    sys.stdout.flush()

            else:
                # SEND THE SUCCESS RESPONSE
                print("success...: ")
                sys.stdout.flush()
                with open(return_fifo, "w") as rf:
                    rf.write("0")
                print("and wrote about it...: ")
                sys.stdout.flush()
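
# Request protocol, as implied by run() above (a sketch, not a documented API):
# a client writes a JSON triple [tmpdirname, return_fifo, k] to DAEMON_FIFO,
# where tmpdirname holds common.yaml and part.yaml; the daemon computes
# gradient part k, writes output.yaml into tmpdirname, and then writes "0"
# (success) or "1" (failure) to return_fifo. For example:
#
#     with open(DAEMON_FIFO, "w") as f:
#         f.write(json.dumps(["/tmp/work", "/tmp/reply.fifo", 3]))
#     with open("/tmp/reply.fifo", "r") as f:
#         ok = f.read() == "0"
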
def start():
    """
    start the daemon process if it's not already started
    """

    # IS THE DAEMON ALREADY RUNNING?
    try:
        with open(pidfile, "r") as pf:
            pid = int(pf.read().strip())
    except IOError:
        pid = None

    if pid:
        message = "pidfile {0} already exists. " + "Daemon already running?\n"
        sys.stderr.write(message.format(pidfile))
        return

    def delpid():
        "clean up"
        try:
            os.remove(pidfile)
        except:
            print("failed to remove pidfile"); sys.stdout.flush()
            raise RuntimeError("failed to remove pidfile")
        else:
            print("removed pidfile"); sys.stdout.flush()
        try:
            os.remove(fifofile)
        except:
            print("failed to remove fifofile"); sys.stdout.flush()
            raise RuntimeError("failed to remove fifofile")
        else:
            print("removed fifofile"); sys.stdout.flush()
    atexit.register(delpid)

    os.mkfifo(fifofile)  # pylint: disable=no-member

    pid = str(os.getpid())
    with open(pidfile, "w+") as f:
        f.write(pid + "\n")

    workdir = os.getenv("AZ_BATCH_TASK_WORKING_DIR", "/tmp")

    # """Start the daemon."""
    sys.stdout = open(os.path.join(workdir, "sc-ps-out.txt"), "a+")
    sys.stderr = open(os.path.join(workdir, "sc-ps-err.txt"), "a+")
    print("process started >>>>>>>>>>>"); sys.stdout.flush()
    run()


def status():
    """
    check the status of the daemon process
    """
    if not os.path.exists(DAEMON_PID):
        print("Daemon not running")
        return
    with open(DAEMON_PID, 'r') as fh:
        _pid = int(fh.read())

    if _pid in psutil.pids():
        print("daemon process (pid {}) is running".format(_pid))
    else:
        print("daemon process (pid {}) is NOT running".format(_pid))


def main():
    ARGS = sys.argv[1:]

    if not ARGS:
        print("no args provided")
    elif ARGS[0] in ("start", "trystart"):
        start()
    elif ARGS[0] == "stop":
        stop()
    elif ARGS[0] == "status":
        status()
    else:
        print("unknown command '{}'".format(ARGS[0]))


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/src/SparseSC/cli/stt.py:
--------------------------------------------------------------------------------

"""
CLI entry point which scores a single (fold, v_pen, w_pen) combination
from a batch configuration file.
"""
# pylint: disable=invalid-name, unused-import
import sys
import numpy
from yaml import load, dump

try:
    from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
    from yaml import Loader, Dumper

from SparseSC.cross_validation import score_train_test


def get_config(infile):
    """
    read in the contents of the inputs yaml file
    """
    with open(infile, "r") as fp:
        config = load(fp, Loader=Loader)

    try:
        v_pen = tuple(config["v_pen"])
    except TypeError:
        v_pen = (config["v_pen"],)

    try:
        w_pen = tuple(config["w_pen"])
    except TypeError:
        w_pen = (config["w_pen"],)

    return v_pen, w_pen, config


def main():
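    # Batch-number layout (worked illustration): batches are indexed
    # fold-fastest. With 5 folds, 3 v_pen values, and 2 w_pen values there are
    # 5 * 3 * 2 = 30 batches, and e.g. batchNumber 13 decomposes to
    #   i_fold = 13 % 5 = 3,  i_v = (13 // 5) % 3 = 2,  i_w = (13 // 5) // 3 = 0,
    # matching the arithmetic below.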
    # GET THE COMMAND LINE ARGS
    ARGS = sys.argv[1:]
    if ARGS[0] == "ssc.py":
        ARGS.pop(0)
    assert (
        len(ARGS) == 3
    ), "ssc.py expects 3 parameters: an input file, an output file, and a batch number"
    infile, outfile, batchNumber = ARGS
    batchNumber = int(batchNumber)

    v_pen, w_pen, config = get_config(infile)
    n_folds = len(config["folds"]) * len(v_pen) * len(w_pen)

    assert 0 <= batchNumber < n_folds, "Batch number out of range"
    i_fold = batchNumber % len(config["folds"])
    i_v = (batchNumber // len(config["folds"])) % len(v_pen)
    i_w = (batchNumber // len(config["folds"])) // len(v_pen)

    params = config.copy()
    del params["folds"]
    del params["v_pen"]
    del params["w_pen"]

    train, test = config["folds"][i_fold]
    out = score_train_test(
        train=train, test=test, v_pen=v_pen[i_v], w_pen=w_pen[i_w], **params
    )

    with open(outfile, "w") as fp:
        fp.write(
            dump(
                {
                    "batch": batchNumber,
                    "i_fold": i_fold,
                    "i_v": i_v,
                    "i_w": i_w,
                    "results": out,
                },
                Dumper=Dumper,
            )
        )


if __name__ == "__main__":
    try:
        main()
    except:  # catch *all* exceptions
        e = sys.exc_info()[0]
        print("STT Error: %s" % e)
--------------------------------------------------------------------------------
/src/SparseSC/optimizers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/src/SparseSC/optimizers/__init__.py
--------------------------------------------------------------------------------
/src/SparseSC/optimizers/simplex_step.py:
--------------------------------------------------------------------------------

"""
Gradient descent within the simplex

Method inspired by the question: where would a water droplet stuck inside
the positive simplex go when pulled in the direction of a gradient which
would otherwise take it outside of the simplex?
"""
# pylint: disable=invalid-name
import numpy as np
from numpy import (
    array,
    append,
    arange,
    cumsum,
    random,
    zeros,
    ones,
    logical_or,
    logical_not,
    logical_and,
    maximum,
    argmin,
    where,
    sort,
)


def _sub_simplex_project(d_hat, indx):
    """
    A utility function which projects the gradient into the subspace of the
    simplex which intersects the plane x[indx] == 0
    """
    # now project the gradient perpendicular to the edge we just came up against
    n = len(d_hat)
    _n = float(n)
    a_dot_a = (n - 1) / _n
    a_tilde = -ones(n) / _n
    a_tilde[indx] += 1  # plus a'
    proj_a_d = (d_hat.dot(a_tilde) / a_dot_a) * a_tilde
    d_tilde = d_hat - proj_a_d
    return d_tilde


def simplex_step(x, g, verbose=False):
    """
    follow the gradient as far as you can within the positive simplex
    """
    i = 0
    x, g = x.copy(), g.copy()
    # project the gradient into the simplex
    g = g - (g.sum() / len(x)) * ones(len(g))
    _g = g.copy()
    while True:
        if verbose:
            print("iter: %s, g: %s" % (i, g))
        # we can move in the direction of the gradient if either
        # (a) the gradient points away from the axis
        # (b) we're not yet touching the axis
        valid_directions = logical_or(g < 0, x > 0)
        if verbose:
            print(
                "valid_directions(%s, %s, %s): %s "
                % (
                    valid_directions.sum(),
                    (g < 0).sum(),
                    (x > 0).sum(),
                    ", ".join(str(x) for x in valid_directions),
                )
            )
        if not valid_directions.any():
            break
        if any(g[logical_not(valid_directions)] != 0):
            # TODO: make sure this is invariant on the order of operations
            n_valid = where(valid_directions)[0]
            W = where(logical_not(valid_directions))[0]
            for i, _w in enumerate(W):
                # TODO: Project the invalid directions into the current (valid) subspace of
                # the simplex
                mask = append(array(_w), n_valid)
                g[mask] = _sub_simplex_project(g[mask], 0)
                g[_w] = 0  # may not be exactly zero due to rounding error
            # rounding error can take us out of the simplex positive orthant:
            g = maximum(0, g)
            if (g == zeros(len(g))).all():
                # we've arrived at a corner and the gradient points outside the constrained simplex
                break
        # HOW FAR CAN WE GO?
        limit_directions = logical_and(valid_directions, g > 0)
        xl = x[limit_directions]
        gl = g[limit_directions]
        ratios = xl / gl
        if len(ratios) == 0:
            # no direction is limited by an axis; this should not happen for a
            # simplex-projected gradient, so fail loudly rather than drop into a debugger
            raise RuntimeError("simplex_step: no limiting directions found")
        c = ratios.min()
        if c > 1:
            x = x - g
            break
        indx = argmin(ratios)
        # MOVE
        # there's gotta be a better way...
        _indx = where(limit_directions)[0][indx]
        tmp = -ones(len(x))
        tmp[valid_directions] = arange(valid_directions.sum())
        __indx = int(tmp[_indx])
        # get the index
        del xl, gl, ratios
        x = x - c * g
        # PROJECT THE GRADIENT
        d_tilde = _sub_simplex_project(g[valid_directions] * (1 - c), __indx)
        if verbose:
            print(
                "i: %s, which: %s, g.sum(): %f, x.sum(): %f, x[i]: %f, g[i]: %f, d_tilde[i]: %f"
                % (
                    i,
                    indx,
                    g.sum(),
                    x.sum(),
                    x[valid_directions][__indx],
                    g[valid_directions][__indx],
                    d_tilde[__indx],
                )
            )
        g[valid_directions] = d_tilde
        # handle rounding error...
131 | x[_indx] = 0 132 | g[_indx] = 0 133 | # INCREMENT THE COUNTER 134 | i += 1 135 | if i > len(x): 136 | raise RuntimeError("something went wrong") 137 | return x 138 | 139 | def simplex_step_proj_sort(x, g, verbose=False): 140 | x_new = simplex_proj_sort(x-g) 141 | return x_new 142 | 143 | #There's a fast version which uses the median finding algorithm rather than full sorting, but more complicated 144 | #See https://en.wikipedia.org/wiki/Simplex#Projection_onto_the_standard_simplex 145 | # and https://gist.github.com/mblondel/6f3b7aaad90606b98f71 146 | def simplex_proj_sort(v, verbose=False): 147 | k = v.shape[0] 148 | if k == 1: 149 | return np.array([1]) 150 | 151 | u = sort(v)[::-1] #switches the order 152 | ind = arange(1, k+1) #shift to 1-indexing 153 | pis = (cumsum(u) - 1) / ind 154 | rho = ind[(u - pis) > 0][-1] 155 | theta = pis[rho-1] #shift back to 0-indexing 156 | v_new = maximum(v - theta, 0) 157 | 158 | return v_new 159 | -------------------------------------------------------------------------------- /src/SparseSC/tensor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Calculates the tensor (V) matrix which puts the metric on the covariate space 3 | """ 4 | from SparseSC.fit_fold import fold_v_matrix 5 | from SparseSC.fit_loo import loo_v_matrix 6 | from SparseSC.fit_ct import ct_v_matrix 7 | import numpy as np 8 | 9 | 10 | def tensor(X, Y, X_treat=None, Y_treat=None, grad_splits=None, **kwargs): 11 | """ Presents a unified api for ct_v_matrix and loo_v_matrix 12 | """ 13 | # PARAMETER QC 14 | try: 15 | X = np.float64(X) 16 | except ValueError: 17 | raise ValueError("X is not coercible to float64") 18 | try: 19 | Y = np.float64(Y) 20 | except ValueError: 21 | raise ValueError("Y is not coercible to float64") 22 | 23 | Y = np.asmatrix(Y) # this needs to be deprecated properly -- bc Array.dot(Array) != matrix(Array).dot(matrix(Array)) -- not even close !!! 24 | X = np.asmatrix(X) 25 | 26 | if X.shape[1] == 0: 27 | raise ValueError("X.shape[1] == 0") 28 | if Y.shape[1] == 0: 29 | raise ValueError("Y.shape[1] == 0") 30 | if X.shape[0] != Y.shape[0]: 31 | raise ValueError( 32 | "X and Y have different number of rows (%s and %s)" 33 | % (X.shape[0], Y.shape[0]) 34 | ) 35 | 36 | if (X_treat is None) != (Y_treat is None): 37 | raise ValueError( 38 | "parameters `X_treat` and `Y_treat` must both be Matrices or None" 39 | ) 40 | 41 | if X_treat is not None: 42 | # Fit the Treated units to the control units; assuming that Y contains 43 | # pre-intervention outcomes: 44 | 45 | # PARAMETER QC 46 | try: 47 | X_treat = np.float64(X_treat) 48 | except ValueError: 49 | raise ValueError("X_treat is not coercible to float64") 50 | try: 51 | Y_treat = np.float64(Y_treat) 52 | except ValueError: 53 | raise ValueError("Y_treat is not coercible to float64") 54 | 55 | Y_treat = np.asmatrix(Y_treat) # this needs to be deprecated properly -- bc Array.dot(Array) != matrix(Array).dot(matrix(Array)) -- not even close !!! 
56 | X_treat = np.asmatrix(X_treat) 57 | 58 | if X_treat.shape[1] == 0: 59 | raise ValueError("X_treat.shape[1] == 0") 60 | if Y_treat.shape[1] == 0: 61 | raise ValueError("Y_treat.shape[1] == 0") 62 | if X_treat.shape[0] != Y_treat.shape[0]: 63 | raise ValueError( 64 | "X_treat and Y_treat have different number of rows (%s and %s)" 65 | % (X_treat.shape[0], Y_treat.shape[0]) 66 | ) 67 | 68 | # FIT THE V-MATRIX AND POSSIBLY CALCULATE THE w_pen 69 | # note that the weights, score, and loss function value returned here 70 | # are for the in-sample predictions 71 | _, v_mat, _, _, _, _ = ct_v_matrix( 72 | X=np.vstack((X, X_treat)), 73 | Y=np.vstack((Y, Y_treat)), 74 | control_units=np.arange(X.shape[0]), 75 | treated_units=np.arange(X_treat.shape[0]) + X.shape[0], 76 | **kwargs 77 | ) 78 | 79 | else: 80 | # Fit the control units to themselves; Y may contain post-intervention outcomes: 81 | 82 | adjusted=False 83 | if kwargs["w_pen"] < 1: 84 | adjusted=True 85 | 86 | if grad_splits is not None: 87 | _, v_mat, _, _, _, _ = fold_v_matrix( 88 | X=X, 89 | Y=Y, 90 | control_units=np.arange(X.shape[0]), 91 | treated_units=np.arange(X.shape[0]), 92 | grad_splits=grad_splits, 93 | **kwargs 94 | ) 95 | 96 | #if adjusted: 97 | # print("vmat: %s" % (np.diag(v_mat))) 98 | 99 | else: 100 | _, v_mat, _, _, _, _ = loo_v_matrix( 101 | X=X, 102 | Y=Y, 103 | control_units=np.arange(X.shape[0]), 104 | treated_units=np.arange(X.shape[0]), 105 | **kwargs 106 | ) 107 | return v_mat 108 | -------------------------------------------------------------------------------- /src/SparseSC/utils/AzureBatch/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | AzureBatch module 3 | """ 4 | # from .gradient_batch_client import gradient_batch_client # abandon for now 5 | DOCKER_IMAGE_NAME = "jdthorpe/sparsesc:latest" 6 | from .aggregate_results import aggregate_batch_results 7 | from .build_batch_job import create_job 8 | -------------------------------------------------------------------------------- /src/SparseSC/utils/AzureBatch/aggregate_results.py: -------------------------------------------------------------------------------- 1 | """ 2 | aggregate batch results, and optionally use batch to compute the final gradient descent. 
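Typical usage (a sketch; it assumes every fold task has already written its
`fold_<n>.yaml` result into `batchDir`):

    fitted = aggregate_batch_results(batchDir)  # returns a SparseSCFit
    print(fitted.score)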
3 | """ 4 | import numpy as np 5 | import os 6 | 7 | # $ from ...cross_validation import _score_from_batch 8 | from ...fit import _which, SparseSCFit 9 | from ...weights import weights 10 | from ...tensor import tensor 11 | from .constants import _BATCH_CV_FILE_NAME, _BATCH_FIT_FILE_NAME 12 | 13 | 14 | def aggregate_batch_results(batchDir, batch_client_config=None, choice=None): 15 | """ 16 | Aggregate results from a batch run 17 | """ 18 | 19 | from yaml import load 20 | 21 | try: 22 | from yaml import CLoader as Loader 23 | except ImportError: 24 | from yaml import Loader 25 | 26 | with open(os.path.join(batchDir, _BATCH_CV_FILE_NAME), "r") as fp: 27 | _cv_params = load(fp, Loader=Loader) 28 | 29 | with open(os.path.join(batchDir, _BATCH_FIT_FILE_NAME), "r") as fp: 30 | _fit_params = load(fp, Loader=Loader) 31 | 32 | # https://stackoverflow.com/a/17074606/1519199 33 | pluck = lambda d, *args: (d[arg] for arg in args) 34 | 35 | X_cv, Y_cv, grad_splits, random_state, v_pen, w_pen = pluck( 36 | _cv_params, "X", "Y", "grad_splits", "random_state", "v_pen", "w_pen" 37 | ) 38 | 39 | choice = choice if choice is not None else _fit_params["choice"] 40 | X, Y, treated_units, custom_donor_pool, model_type, kwargs = pluck( 41 | _fit_params, 42 | "X", 43 | "Y", 44 | "treated_units", 45 | "custom_donor_pool", 46 | "model_type", 47 | "kwargs", 48 | ) 49 | 50 | # this is on purpose (allows for debugging remote sessions at no cost to the local console user) 51 | kwargs["print_path"] = 1 52 | 53 | scores, scores_se = _score_from_batch(batchDir, _cv_params) 54 | 55 | try: 56 | iter(w_pen) 57 | except TypeError: 58 | w_pen_is_iterable = False 59 | else: 60 | w_pen_is_iterable = True 61 | 62 | try: 63 | iter(v_pen) 64 | except TypeError: 65 | v_pen_is_iterable = False 66 | else: 67 | v_pen_is_iterable = True 68 | 69 | # GET THE INDEX OF THE BEST SCORE 70 | def _choose(scores, scores_se): 71 | """ helper function which implements the choice of covariate weights penalty parameter 72 | 73 | Nested here for access to v_pen, w_pe,n w_pen_is_iterable and 74 | v_pen_is_iterable, and choice, via Lexical Scoping 75 | """ 76 | # GET THE INDEX OF THE BEST SCORE 77 | if w_pen_is_iterable: 78 | indx = _which(scores, scores_se, choice) 79 | return v_pen, w_pen[indx], scores[indx], indx 80 | if v_pen_is_iterable: 81 | indx = _which(scores, scores_se, choice) 82 | return v_pen[indx], w_pen, scores[indx], indx 83 | return v_pen, w_pen, scores, None 84 | 85 | best_v_pen, best_w_pen, score, which = _choose(scores, scores_se) 86 | 87 | # -------------------------------------------------- 88 | # Phase 2: extract V and weights: slow ( tens of seconds to minutes ) 89 | # -------------------------------------------------- 90 | if treated_units is not None: 91 | control_units = [u for u in range(Y.shape[0]) if u not in treated_units] 92 | Xtrain = X[control_units, :] 93 | Xtest = X[treated_units, :] 94 | Ytrain = Y[control_units, :] 95 | Ytest = Y[treated_units, :] 96 | else: 97 | control_units = None 98 | 99 | if model_type == "prospective-restricted": 100 | best_V = tensor( 101 | X=X_cv, 102 | Y=Y_cv, 103 | w_pen=best_w_pen, 104 | v_pen=best_v_pen, 105 | # 106 | X_treat=Xtest, 107 | Y_treat=Ytest, 108 | # 109 | batch_client_config=batch_client_config, # TODO: not sure if this makes sense... 
110 | **_fit_params["kwargs"] 111 | ) 112 | else: 113 | best_V = tensor( 114 | X=X_cv, 115 | Y=Y_cv, 116 | w_pen=best_w_pen, 117 | v_pen=best_v_pen, 118 | # 119 | grad_splits=grad_splits, 120 | random_state=random_state, 121 | # 122 | batch_client_config=batch_client_config, 123 | **_fit_params["kwargs"] 124 | ) 125 | 126 | if treated_units is not None: 127 | 128 | # GET THE BEST SET OF WEIGHTS 129 | sc_weights = np.empty((X.shape[0], Ytrain.shape[0])) 130 | if custom_donor_pool is None: 131 | custom_donor_pool_t = None 132 | custom_donor_pool_c = None 133 | else: 134 | custom_donor_pool_t = custom_donor_pool[treated_units, :] 135 | custom_donor_pool_c = custom_donor_pool[control_units, :] 136 | sc_weights[treated_units, :] = weights( 137 | Xtrain, 138 | Xtest, 139 | V=best_V, 140 | w_pen=best_w_pen, 141 | custom_donor_pool=custom_donor_pool_t, 142 | ) 143 | sc_weights[control_units, :] = weights( 144 | Xtrain, V=best_V, w_pen=best_w_pen, custom_donor_pool=custom_donor_pool_c 145 | ) 146 | 147 | else: 148 | # GET THE BEST SET OF WEIGHTS 149 | sc_weights = weights( 150 | X, V=best_V, w_pen=best_w_pen, custom_donor_pool=custom_donor_pool 151 | ) 152 | 153 | return SparseSCFit( 154 | features=X, 155 | targets=Y, 156 | control_units=control_units, 157 | treated_units=treated_units, 158 | model_type=model_type, 159 | # fitting parameters 160 | fitted_v_pen=best_v_pen, 161 | fitted_w_pen=best_w_pen, 162 | initial_w_pen=w_pen, 163 | initial_v_pen=v_pen, 164 | V=best_V, 165 | # Fitted Synthetic Controls 166 | sc_weights=sc_weights, 167 | score=score, 168 | scores=scores, 169 | selected_score=which, 170 | ) 171 | 172 | 173 | def _score_from_batch(batchDir, config): 174 | """ 175 | read in the results from a batch run 176 | """ 177 | from yaml import load 178 | 179 | try: 180 | from yaml import CLoader as Loader 181 | except ImportError: 182 | from yaml import Loader 183 | 184 | try: 185 | v_pen = tuple(config["v_pen"]) 186 | except TypeError: 187 | v_pen = (config["v_pen"],) 188 | 189 | try: 190 | w_pen = tuple(config["w_pen"]) 191 | except TypeError: 192 | w_pen = (config["w_pen"],) 193 | 194 | n_folds = len(config["folds"]) * len(v_pen) * len(w_pen) 195 | n_pens = np.max((len(v_pen), len(w_pen))) 196 | n_cv_folds = n_folds // n_pens 197 | 198 | scores = np.empty((n_pens, n_cv_folds)) 199 | for i in range(n_folds): 200 | # i_fold, i_v, i_w = pluck(res, "i_fold", "i_v", "i_w", ) 201 | i_fold = i % len(config["folds"]) 202 | i_pen = i // len(config["folds"]) 203 | with open(os.path.join(batchDir, "fold_{}.yaml".format(i)), "r") as fp: 204 | res = load(fp, Loader=Loader) 205 | assert ( 206 | res["batch"] == i 207 | ), "Batch File Import Error Inconsistent batch identifiers" 208 | scores[i_pen, i_fold] = res["results"][2] 209 | 210 | # TODO: np.sqrt(len(scores)) * np.std(scores) is a quick and dirty hack for 211 | # calculating the standard error of the sum from the partial sums. It's 212 | # assumes the samples are equal size and randomly allocated (which is true 213 | # in the default settings). However, it could be made more formal with a 214 | # fixed effects framework, and leveraging the individual errors. 
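# Concretely: for K = n_cv_folds equal-size random folds with partial sums s_1..s_K, Var(sum_k s_k) = K * Var(s_k), which is what the `se = np.sqrt(n_cv_folds) * scores.std(axis=1)` below estimates.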
215 | # https://stats.stackexchange.com/a/271223/67839 216 | 217 | if len(v_pen) > 1 or len(w_pen) > 1: 218 | total_score = scores.sum(axis=1) 219 | se = np.sqrt(n_cv_folds) * scores.std(axis=1) 220 | else: 221 | # both penalties are scalars: `scores` is a single 1-by-n_cv_folds row 222 | total_score = scores.sum() 223 | se = np.sqrt(n_cv_folds) * scores.std() 224 | 225 | 226 | return total_score, se 227 | 228 | -------------------------------------------------------------------------------- /src/SparseSC/utils/AzureBatch/build_batch_job.py: -------------------------------------------------------------------------------- 1 | """ 2 | Build an Azure Batch job for the cross-validation grid: 3 | one task per (fold x penalty) combination. 4 | 5 | USAGE: create_job(client, batch_dir) 6 | """ 7 | from __future__ import print_function 8 | from os.path import join 9 | import super_batch 10 | from SparseSC.cli.stt import get_config 11 | from .constants import ( 12 | _CONTAINER_OUTPUT_FILE, 13 | _CONTAINER_INPUT_FILE, 14 | _BATCH_CV_FILE_NAME, 15 | ) 16 | 17 | LOCAL_OUTPUTS_PATTERN = "fold_{}.yaml" 18 | 19 | 20 | def create_job(client: super_batch.Client, batch_dir: str) -> None: 21 | r""" 22 | :param client: A :class:`super_batch.Client` instance with the Azure Batch run parameters 23 | :type client: :class:super_batch.Client 24 | 25 | :param str batch_dir: path of the local batch temp directory 26 | """ 27 | _LOCAL_INPUT_FILE = join(batch_dir, _BATCH_CV_FILE_NAME) 28 | 29 | v_pen, w_pen, model_data = get_config(_LOCAL_INPUT_FILE) 30 | n_folds = len(model_data["folds"]) * len(v_pen) * len(w_pen) 31 | 32 | # CREATE THE COMMON INPUT FILE RESOURCE 33 | input_resource = client.build_resource_file( 34 | _LOCAL_INPUT_FILE, _CONTAINER_INPUT_FILE 35 | ) 36 | 37 | for fold_number in range(n_folds): 38 | 39 | # BUILD THE COMMAND LINE 40 | command_line = "/bin/bash -c 'stt {} {} {}'".format( 41 | _CONTAINER_INPUT_FILE, _CONTAINER_OUTPUT_FILE, fold_number 42 | ) 43 | 44 | # CREATE AN OUTPUT RESOURCE: 45 | output_resource = client.build_output_file( 46 | _CONTAINER_OUTPUT_FILE, LOCAL_OUTPUTS_PATTERN.format(fold_number) 47 | ) 48 | 49 | # CREATE A TASK 50 | client.add_task([input_resource], [output_resource], command_line=command_line) 51 | -------------------------------------------------------------------------------- /src/SparseSC/utils/AzureBatch/constants.py: -------------------------------------------------------------------------------- 1 | """ 2 | CONSTANTS USED BY THIS MODULE 3 | """ 4 | 5 | _DOCKER_IMAGE = "jdthorpe/sparsesc" 6 | _STANDARD_OUT_FILE_NAME = "stdout.txt" # Standard Output file 7 | _CONTAINER_OUTPUT_FILE = "output.yaml" # Task output file (within the container) 8 | _CONTAINER_INPUT_FILE = "input.yaml" # Task input file (within the container) 9 | _BATCH_CV_FILE_NAME = "cv_parameters.yaml" 10 | _BATCH_FIT_FILE_NAME = "fit_parameters.yaml" 11 | _GRAD_COMMON_FILE = "common.yaml" 12 | _GRAD_PART_FILE = "part.yaml" 13 | 14 | -------------------------------------------------------------------------------- /src/SparseSC/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/src/SparseSC/utils/__init__.py -------------------------------------------------------------------------------- /src/SparseSC/utils/batch_gradient.py: -------------------------------------------------------------------------------- 1 | """ 2 | utilities for TESTING gradient descent in batches 3 | """ 4 | import os 5 | import numpy as np 6 | import subprocess 7 | 8 | 9 | def single_grad_cli( 10 | tmpdir, N0, N1, in_controls, splits, 
b_i, w_pen, treated_units, Y_treated, Y_control 11 | ): 12 | """ 13 | wrapper for the real function 14 | """ 15 | from yaml import load, dump 16 | 17 | try: 18 | from yaml import CLoader as Loader, CDumper as Dumper 19 | except ImportError: 20 | from yaml import Loader, Dumper 21 | 22 | _common_params = { 23 | "N0": N0, 24 | "N1": N1, 25 | "in_controls": in_controls, 26 | "splits": splits, 27 | "b_i": b_i, 28 | "w_pen": w_pen, 29 | "treated_units": treated_units, 30 | "Y_treated": Y_treated, 31 | "Y_control": Y_control, 32 | } 33 | 34 | COMMONFILE = os.path.join(tmpdir, "commonfile.yaml") 35 | PARTFILE = os.path.join(tmpdir, "partfile.yaml") 36 | OUTFILE = os.path.join(tmpdir, "outfile.yaml") 37 | 38 | with open(COMMONFILE, "w") as fp: 39 | fp.write(dump(_common_params, Dumper=Dumper)) 40 | 41 | def inner(A, weights, dA_dV_ki_k, dB_dV_ki_k): 42 | """ 43 | Calculate a single component of the gradient 44 | """ 45 | _local_params = { 46 | "A": A, 47 | "weights": weights, 48 | "dA_dV_ki_k": dA_dV_ki_k, 49 | "dB_dV_ki_k": dB_dV_ki_k, 50 | } 51 | with open(PARTFILE, "w") as fp: 52 | fp.write(dump(_local_params, Dumper=Dumper)) 53 | 54 | subprocess.run(["scgrad", COMMONFILE, PARTFILE, OUTFILE]) 55 | 56 | with open(OUTFILE, "r") as fp: 57 | val = load(fp, Loader=Loader) 58 | return val 59 | 60 | return inner 61 | 62 | 63 | def single_grad( 64 | N0, N1, in_controls, splits, b_i, w_pen, treated_units, Y_treated, Y_control 65 | ): 66 | """ 67 | wrapper for the real function 68 | """ 69 | 70 | in_controls2 = [np.ix_(i, i) for i in in_controls] 71 | 72 | def inner(A, weights, dA_dV_ki_k, dB_dV_ki_k): 73 | """ 74 | Calculate a single component of the gradient 75 | """ 76 | dPI_dV = np.zeros((N0, N1)) # stupid notation: PI = W.T 77 | for i, (_, (_, test)) in enumerate(zip(in_controls, splits)): 78 | dA = dA_dV_ki_k[i] 79 | dB = dB_dV_ki_k[i] 80 | try: 81 | b = np.linalg.solve(A[in_controls2[i]], dB - dA.dot(b_i[i])) 82 | except np.linalg.LinAlgError as exc: 83 | print("Unique weights not possible.") 84 | if w_pen == 0: 85 | print("Try specifying a very small w_pen rather than 0.") 86 | raise exc 87 | dPI_dV[np.ix_(in_controls[i], treated_units[test])] = b 88 | # einsum is faster than the equivalent (Ey * Y_control.T.dot(dPI_dV).T.getA()).sum() 89 | return 2 * np.einsum( 90 | "ij,kj,ki->", (weights.T.dot(Y_control) - Y_treated), Y_control, dPI_dV 91 | ) 92 | 93 | return inner 94 | -------------------------------------------------------------------------------- /src/SparseSC/utils/descr_sets.py: -------------------------------------------------------------------------------- 1 | """Store the typical information for a match. 2 | """ 3 | from collections import namedtuple 4 | 5 | from SparseSC.utils.dist_summary import SSC_DescrStat 6 | 7 | MatchingEstimate = namedtuple( 8 | 'MatchingEstimate', 'att_est att_debiased_est atut_est atut_debiased_est ate_est ate_debiased_est aa_est naive_est') 9 | #""" 10 | #aa = The difference between control counterfactuals for controls and controls (Y_c_cf_c - Y_c) 11 | #Debiased means we subtract from the estimate the aa estimate 12 | #att = Average Treatment effect on the Treated (Y_t - Y_t_cf_c). 13 | #atut = Average Treatment on the UnTreated (Y_c_cf_t - Y_c) 14 | #ate = Average Treatment Effect (pooling att and atut samples) 15 | #naive = Just comparing Y_t and Y_c (no matching). 
Helpful for comparison (gauging selection size) 16 | #""" 17 | 18 | class DescrSet: 19 | """Holds potential distribution summaries for the various data used for matching 20 | """ 21 | 22 | def __init__(self, descr_Y_t=None, descr_Y_t_cf_c=None, descr_Y_diff_t_cf_c=None, 23 | descr_Y_c=None, descr_Y_diff_t_c=None, descr_Y_c_cf_c=None, descr_Y_diff_c_cf_c=None, 24 | descr_Y_c_cf_t=None, descr_Y_diff_c_cf_t=None): 25 | """Generate the common descriptive stats from data and differences 26 | :param descr_Y_t: SSC_DescrStat for Y_t (outcomes for treated units) 27 | :param descr_Y_t_cf_c: SSC_DescrStat for Y_t_cf_c (outcomes for (control) counterfactuals of treated units) 28 | :param descr_Y_diff_t_cf_c: SSC_DescrStat of Y_t - Y_t_cf_c 29 | :param descr_Y_c: SSC_DescrStat of Y_c (outcomes for control units) 30 | :param descr_Y_diff_t_c: SSC_DescrStat for Y_t-Y_c 31 | :param descr_Y_c_cf_c: SSC_DescrStat for Y_c_cf_c (outcomes for (control) counterfactuals of control units) 32 | :param descr_Y_diff_c_cf_c: SSC_DescrStat for Y_c - Y_c_cf_c 33 | :param descr_Y_c_cf_t: SSC_DescrStat for Y_c_cf_t (outcomes for (TREATED) counterfactuals of control units) 34 | :param descr_Y_diff_c_cf_t: SSC_DescrStat for Y_c_cf_t - Y_c 35 | """ 36 | self.descr_Y_t = descr_Y_t 37 | self.descr_Y_t_cf_c = descr_Y_t_cf_c 38 | self.descr_Y_diff_t_cf_c = descr_Y_diff_t_cf_c 39 | 40 | self.descr_Y_c = descr_Y_c 41 | self.descr_Y_diff_t_c = descr_Y_diff_t_c 42 | self.descr_Y_c_cf_c = descr_Y_c_cf_c 43 | self.descr_Y_diff_c_cf_c = descr_Y_diff_c_cf_c 44 | 45 | self.descr_Y_c_cf_t = descr_Y_c_cf_t 46 | self.descr_Y_diff_c_cf_t = descr_Y_diff_c_cf_t 47 | 48 | def __repr__(self): 49 | return ("%s(descr_Y_t=%s, descr_Y_t_cf_c=%s, descr_Y_diff_t_cf_c=%s, descr_Y_c=%s" + 50 | "descr_Y_diff_t_c=%s, descr_Y_c_cf_c=%s, descr_Y_diff_c_cf_c=%s, descr_Y_c_cf_t=%s," + 51 | " descr_Y_diff_c_cf_t=%s)") % (self.__class__.__name__, self.descr_Y_t, 52 | self.descr_Y_t_cf_c, self.descr_Y_diff_t_cf_c, self.descr_Y_c, 53 | self.descr_Y_diff_t_c, self.descr_Y_c_cf_c, self.descr_Y_diff_c_cf_c, 54 | self.descr_Y_c_cf_t, self.descr_Y_diff_c_cf_t) 55 | 56 | @staticmethod 57 | def from_data(Y_t=None, Y_t_cf_c=None, Y_c=None, Y_c_cf_c=None, Y_c_cf_t=None): 58 | """Generate the common descriptive stats from data and differences 59 | :param Y_t: np.array of dim=(N_t, T) for outcomes for treated units 60 | :param Y_t_cf_c: np.array of dim=(N_t, T) for outcomes for (control) counterfactuals of treated units (used to get the average treatment effect on the treated (ATT)) 61 | :param Y_c: np.array of dim=(N_c, T) for outcomes for control units 62 | :param Y_c_cf_c: np.array of dim=(N_c, T) for outcomes for (control) counterfactuals of control units (used for AA test) 63 | :param Y_c_cf_t: np.array of dim=(N_c, T) for outcomes for (TREATED) counterfactuals of control units (used to calculate average treatment effect on the untreated (ATUT), or pooled with ATT to get the average treatment effect (ATE)) 64 | :returns: DescrSet 65 | """ 66 | # Note: While possible, there's no real use for treated matched to other treateds. 
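# Illustrative call (with hypothetical outcome matrices Y_t and Y_t_cf_c only): DescrSet.from_data(Y_t=Y_t, Y_t_cf_c=Y_t_cf_c) populates just the ATT-related summaries; every other field stays None, and calc_estimates() then returns None for any estimate that depends on them.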
67 | def _gen_if_valid(Y): 68 | return SSC_DescrStat.from_data(Y) if Y is not None else None 69 | 70 | def _gen_diff_if_valid(Y1, Y2): 71 | return SSC_DescrStat.from_data(Y1-Y2) if (Y1 is not None and Y2 is not None) else None 72 | 73 | descr_Y_t = _gen_if_valid(Y_t) 74 | descr_Y_t_cf_c = _gen_if_valid(Y_t_cf_c) 75 | descr_Y_diff_t_cf_c = _gen_diff_if_valid(Y_t, Y_t_cf_c) 76 | 77 | descr_Y_c = _gen_if_valid(Y_c) 78 | descr_Y_diff_t_c = _gen_diff_if_valid(Y_t, Y_c) 79 | descr_Y_c_cf_c = _gen_if_valid(Y_c_cf_c) 80 | descr_Y_diff_c_cf_c = _gen_diff_if_valid(Y_c_cf_c, Y_c) 81 | descr_Y_c_cf_t = _gen_if_valid(Y_c_cf_t) 82 | descr_Y_diff_c_cf_t = _gen_diff_if_valid(Y_c_cf_t, Y_c) 83 | 84 | return DescrSet(descr_Y_t, descr_Y_t_cf_c, descr_Y_diff_t_cf_c, 85 | descr_Y_c, descr_Y_diff_t_c, descr_Y_c_cf_c, descr_Y_diff_c_cf_c, 86 | descr_Y_c_cf_t, descr_Y_diff_c_cf_t) 87 | 88 | def __add__(self, other): 89 | def _add_if_valid(a, b): 90 | return a+b if (a is not None and b is not None) else None 91 | return DescrSet(descr_Y_t=_add_if_valid(self.descr_Y_t, other.descr_Y_t), 92 | descr_Y_t_cf_c=_add_if_valid(self.descr_Y_t_cf_c, other.descr_Y_t_cf_c), 93 | descr_Y_diff_t_cf_c=_add_if_valid(self.descr_Y_diff_t_cf_c, other.descr_Y_diff_t_cf_c), 94 | descr_Y_c=_add_if_valid(self.descr_Y_c, other.descr_Y_c), 95 | descr_Y_diff_t_c=_add_if_valid(self.descr_Y_diff_t_c, other.descr_Y_diff_t_c), 96 | descr_Y_c_cf_c=_add_if_valid(self.descr_Y_c_cf_c, other.descr_Y_c_cf_c), 97 | descr_Y_diff_c_cf_c=_add_if_valid(self.descr_Y_diff_c_cf_c, other.descr_Y_diff_c_cf_c), 98 | descr_Y_c_cf_t=_add_if_valid(self.descr_Y_c_cf_t, other.descr_Y_c_cf_t), 99 | descr_Y_diff_c_cf_t=_add_if_valid(self.descr_Y_diff_c_cf_t, other.descr_Y_diff_c_cf_t)) 100 | 101 | def calc_estimates(self): 102 | """ Takes matrices of effects for multiple events and return averaged results 103 | """ 104 | def _calc_estimate(descr_stat1, descr_stat2): 105 | if descr_stat1 is None or descr_stat2 is None: 106 | return None 107 | return SSC_DescrStat.lcl_comp_means(descr_stat1, descr_stat2) 108 | 109 | att_est = _calc_estimate(self.descr_Y_t, self.descr_Y_t_cf_c) 110 | att_debiased_est = _calc_estimate(self.descr_Y_diff_t_cf_c, self.descr_Y_diff_c_cf_c) 111 | 112 | atut_est = _calc_estimate(self.descr_Y_c_cf_t, self.descr_Y_c) 113 | atut_debiased_est = _calc_estimate(self.descr_Y_diff_c_cf_t, self.descr_Y_diff_c_cf_c) 114 | 115 | ate_est, ate_debiased_est = None, None 116 | if all(d is not None for d in [self.descr_Y_t, self.descr_Y_c_cf_t, self.descr_Y_t_cf_c, self.descr_Y_c]): 117 | ate_est = _calc_estimate(self.descr_Y_t + self.descr_Y_c_cf_t, 118 | self.descr_Y_t_cf_c + self.descr_Y_c) 119 | if all(d is not None for d in [self.descr_Y_diff_t_cf_c, self.descr_Y_diff_c_cf_t, self.descr_Y_diff_c_cf_c]): 120 | ate_debiased_est = _calc_estimate(self.descr_Y_diff_t_cf_c + 121 | self.descr_Y_diff_c_cf_t, self.descr_Y_diff_c_cf_c) 122 | 123 | aa_est = _calc_estimate(self.descr_Y_c_cf_c, self.descr_Y_c) 124 | # descr_Y_diff_c_cf_c #used for the double comparisons 125 | 126 | naive_est = _calc_estimate(self.descr_Y_t, self.descr_Y_c) 127 | # descr_Y_diff_t_c #Don't think this could be useful, but just in case. 
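# When reading the results: the gap between naive_est.effect and att_est.effect is a rough gauge of the selection effect noted in the MatchingEstimate notes above.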
128 | 129 | return MatchingEstimate(att_est, att_debiased_est, 130 | atut_est, atut_debiased_est, 131 | ate_est, ate_debiased_est, 132 | aa_est, naive_est) -------------------------------------------------------------------------------- /src/SparseSC/utils/dist_summary.py: -------------------------------------------------------------------------------- 1 | """This is a way to summarize (using normal approximations) the distributions of real and 2 | synthetic controls so that all data doesn't have to be stored. 3 | """ 4 | import math 5 | 6 | 7 | 8 | import numpy as np 9 | 10 | # TODO: 11 | # - Allow passed in vector and return scalar 12 | 13 | def tstat_generic(mean1, mean2, stdm, dof): 14 | """Vectorized version of statsmodels' _tstat_generic 15 | :param mean1: float or np.array 16 | :param mean2: float or np.array 17 | :param stdm: float or np.array: the standard deviation of the pooled sample 18 | :param dof: int 19 | """ 20 | from statsmodels.stats.weightstats import _tstat_generic 21 | l = len(mean1) 22 | if l == 1: 23 | tstat, pval = _tstat_generic(mean1, mean2, stdm, dof, 'two-sided', diff=0) 24 | else: 25 | tstat, pval = zip(*[_tstat_generic(mean1[i], mean2[i], stdm[i], dof, 'two-sided', diff=0) 26 | for i in range(l)]) 27 | #tstat = (mean1 - mean2) / stdm # 28 | #pvalue = stats.t.sf(np.abs(tstat), dof)*2 29 | # Cohen's d: diff / sample std dev 30 | return tstat, pval 31 | 32 | # def pooled_variances_scalar(sample_variances, sample_sizes): 33 | # """Estimate pooled variance from a set of samples. Assumes same variance but allows different means.""" 34 | # return np.average(sample_variances, weights=(sample_sizes-1)) 35 | 36 | def pooled_variances(sample_variances, sample_sizes): 37 | """Estimate pooled variance from a set of samples. Assumes a common variance but allows different means. 38 | If the inputs are n x l, returns a length-l vector. 39 | """ 40 | # https://en.wikipedia.org/wiki/Pooled_variance 41 | return np.average(sample_variances, weights=(sample_sizes-1), axis=0) 42 | 43 | 44 | class Estimate: # can hold scalars or vectors 45 | def __init__(self, effect, pval, baseline=None): 46 | self.effect = effect 47 | self.pval = pval 48 | self.baseline = baseline 49 | 50 | class SSC_DescrStat(object): 51 | """Stores mean and variance for a sample in a way that can be updated 52 | with new observations or by adding summaries together. Similar to statsmodels' DescrStatsW, 53 | except we don't keep the raw data, and we use 'online' algorithms to allow an incremental approach.""" 54 | # Similar to https://github.com/grantjenks/python-runstats but uses numpy and doesn't do higher order stats and has multiple columns 55 | # Ref: https://www.statsmodels.org/stable/generated/statsmodels.stats.weightstats.DescrStatsW.html#statsmodels.stats.weightstats.DescrStatsW 56 | 57 | def __init__(self, nobs, mean, M2): 58 | """ 59 | :param nobs: scalar 60 | :param mean: vector 61 | :param M2: vector of same length as mean. Sum of squared deviations (sum_i (x_i-mean)^2). 
62 | Sometimes called 'S' (capital; though 's' is often sample variance) 63 | :raises: ValueError 64 | """ 65 | import numbers 66 | if not isinstance(nobs, numbers.Number): 67 | raise ValueError('nobs should be a number') 68 | self.nobs = nobs 69 | if len(mean.shape)!=1: 70 | raise ValueError('mean should be np vector') 71 | if len(M2.shape)!=1: 72 | raise ValueError('M2 should be np vector') 73 | if M2.shape != mean.shape: 74 | raise ValueError('M2 and mean should be the same length') 75 | self.mean = mean 76 | self.M2 = M2 # sometimes called S (though s is often sample variance) 77 | 78 | def __repr__(self): 79 | return "%s(nobs=%s, mean=%s, M2=%s)" % (self.__class__.__name__, self.nobs, self.mean, self.M2) 80 | 81 | def __eq__(self, obj): 82 | return isinstance(obj, SSC_DescrStat) and np.array_equal(obj.nobs,self.nobs) and np.array_equal(obj.mean, self.mean) and np.array_equal(obj.M2, self.M2) 83 | 84 | @staticmethod 85 | def from_data(data): 86 | """ 87 | :param data: 2D np.array. We compute stats per column. 88 | :returns: SSC_DescrStat object 89 | """ 90 | N = data.shape[0] 91 | mean = np.average(data, axis=0) 92 | M2 = np.var(data, axis=0)*N 93 | return SSC_DescrStat(N, mean, M2) 94 | 95 | def __add__(self, other, alt=False): 96 | """ 97 | Chan's parallel algorithm 98 | :param other: Other SSC_DescrStat object 99 | :param alt: Use when the two samples are large and of roughly similar size, to avoid catastrophic cancellation 100 | :returns: new SSC_DescrStat object 101 | """ 102 | # See https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance 103 | # TODO: could make alt auto (look at, e.g., abs(self.nobs - other.nobs)/new_n) 104 | new_n = self.nobs + other.nobs 105 | delta = other.mean - self.mean 106 | if not alt: 107 | new_mean = self.mean+delta*(other.nobs/new_n) 108 | else: 109 | new_mean = (self.nobs*self.mean + other.nobs*other.mean)/new_n 110 | new_M2 = self.M2 + other.M2 + np.square(delta) * self.nobs * other.nobs / new_n 111 | return SSC_DescrStat(new_n, new_mean, new_M2) 112 | 113 | def update(self, obs): 114 | """Welford's online algorithm 115 | :param obs: new observation vector 116 | """ 117 | # See https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm 118 | self.nobs += 1 119 | if len(obs.shape)==2 and obs.shape[0]==1: 120 | obs = obs[0,:] 121 | delta = obs - self.mean 122 | self.mean += delta*1./self.nobs 123 | self.M2 += delta*(obs - self.mean) 124 | 125 | def variance(self, ddof=1): 126 | """ 127 | :param ddof: delta degrees of freedom. 1 gives the sample variance (s). 128 | :returns: Variance 129 | """ 130 | if self.nobs == 1: 131 | return np.zeros(self.mean.shape) 132 | return self.M2/(self.nobs-ddof) 133 | 134 | @property 135 | def var(self): 136 | return self.variance() 137 | 138 | def stddev(self, ddof=1): 139 | """Standard Deviation 140 | :param ddof: delta degrees of freedom. 1 gives the sample Standard Deviation 141 | :returns: Standard Deviation 142 | """ 143 | return np.sqrt(self.variance(ddof)) 144 | 145 | @property 146 | def std(self): 147 | return self.stddev() 148 | 149 | def std_mean(self, ddof=1): 150 | """Standard Error of the mean 151 | :param ddof: delta degrees of freedom. 
152 | :returns: Standard Error of the mean 153 | """ 154 | return self.stddev(ddof)/math.sqrt(self.nobs-1) 155 | 156 | @property 157 | def sumsquares(self): 158 | return self.M2 159 | 160 | @property 161 | def sum(self): 162 | return self.mean*self.nobs 163 | 164 | @property 165 | def sum_weights(self): 166 | return self.sum # NOTE: unlike statsmodels (where sum_weights equals nobs for unit weights), this returns the data sum 167 | 168 | 169 | @staticmethod 170 | def lcl_comp_means(descr1, descr2): 171 | """ Calculates the t-test of the difference in means. Local version of statsmodels.stats.weightstats.CompareMeans 172 | :param descr1: DescrStatW-type object of sample statistics 173 | :param descr2: DescrStatW-type object of sample statistics 174 | """ 175 | #from statsmodels.stats.weightstats import CompareMeans 176 | # Do statsmodels.CompareMeans with just summary stats. 177 | var_pooled = pooled_variances(np.array([descr1.var, descr2.var]), np.array([descr1.nobs, descr2.nobs])) 178 | stdm = np.sqrt(var_pooled * (1. / descr1.nobs + 1. / descr2.nobs)) # ~ sample std dev / sqrt(N) 179 | dof = descr1.nobs - 1 + descr2.nobs - 1 180 | tstat, pval = tstat_generic(descr1.mean, descr2.mean, stdm, dof) 181 | effect = descr1.mean - descr2.mean 182 | return Estimate(effect, pval) -------------------------------------------------------------------------------- /src/SparseSC/utils/local_grad_daemon.py: -------------------------------------------------------------------------------- 1 | """ 2 | for local testing of the daemon 3 | """ 4 | # pylint: disable=differing-type-doc, differing-param-doc, missing-param-doc, missing-raises-doc, missing-return-doc 5 | import datetime 6 | import os 7 | import json 8 | import tempfile 9 | import atexit 10 | import subprocess 11 | import itertools 12 | import tarfile 13 | import io 14 | import sys 15 | 16 | import numpy as np 17 | from ..cli.scgrad import GradientDaemon, DAEMON_PID, DAEMON_FIFO 18 | 19 | from yaml import load, dump 20 | 21 | try: 22 | from yaml import CLoader as Loader, CDumper as Dumper 23 | except ImportError: 24 | from yaml import Loader, Dumper 25 | 26 | # pylint: disable=fixme, too-few-public-methods 27 | # pylint: disable=bad-continuation, invalid-name, protected-access, line-too-long 28 | 29 | try: 30 | input = raw_input # pylint: disable=redefined-builtin 31 | except NameError: 32 | pass 33 | 34 | 35 | _CONTAINER_OUTPUT_FILE = "output.yaml" # Container output file 36 | _GRAD_COMMON_FILE = "common.yaml" 37 | _GRAD_PART_FILE = "part.yaml" 38 | 39 | RETURN_FIFO = "/tmp/sc-return.fifo" 40 | if sys.platform!="win32": #mkfifo not available on Windows 41 | os.mkfifo(RETURN_FIFO) # pylint: disable=no-member 42 | 43 | 44 | def cleanup(): 45 | """ clean up""" 46 | if sys.platform!="win32": #allow sphinx to build on windows w/o error 47 | os.remove(RETURN_FIFO) 48 | 49 | 50 | atexit.register(cleanup) 51 | 52 | 53 | class local_batch_daemon: 54 | """ 55 | Client object for performing gradient calculations with the local gradient daemon (stands in for the Azure Batch client during testing) 56 | """ 57 | 58 | def __init__(self, common_data, K): 59 | subprocess.call(["python", "-m", "SparseSC.cli.scgrad", "start"]) 60 | # CREATE THE RESPONSE FIFO 61 | # replace any missing values with environment variables 62 | self.common_data = common_data 63 | self.K = K 64 | 65 | # BUILD THE TEMPORARY FILE NAMES 66 | self.tmpDirManager = tempfile.TemporaryDirectory() 67 | self.tmpdirname = self.tmpDirManager.name 68 | print("Created temporary directory:", self.tmpdirname) 69 | self.GRAD_PART_FILE = os.path.join(self.tmpdirname, _GRAD_PART_FILE) 70 | self.CONTAINER_OUTPUT_FILE = os.path.join(self.tmpdirname, _CONTAINER_OUTPUT_FILE) 71 | 
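# NOTE: keep self.tmpDirManager referenced -- tempfile.TemporaryDirectory removes the directory (and the yaml files written below) once it is cleaned up or garbage-collected.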
72 | # WRITE THE COMMON DATA TO FILE: 73 | with open(os.path.join(self.tmpdirname, _GRAD_COMMON_FILE), "w") as fh: 74 | fh.write(dump(self.common_data, Dumper=Dumper)) 75 | 76 | 77 | #-- # A UTILITY FUNCTION 78 | #-- def tarify(x,name): 79 | #-- with tarfile.open(os.path.join(self.tmpdirname, '{}.tar.gz'.format(name)), mode='w:gz') as dest_file: 80 | #-- for i, k in itertools.product( range(len(x)), range(len(x[0]))): 81 | #-- fname = 'arr_{}_{}'.format(i,k) 82 | #-- array_bytes = x[i][k].tobytes() 83 | #-- info = tarfile.TarInfo(fname) 84 | #-- info.size = len(array_bytes) 85 | #-- dest_file.addfile(info,io.BytesIO(array_bytes) ) 86 | #-- 87 | #-- tarify(part_data["dA_dV_ki"],"dA_dV_ki") 88 | #-- tarify(part_data["dB_dV_ki"],"dB_dV_ki") 89 | #-- import pdb; pdb.set_trace() 90 | 91 | 92 | def stop(self): 93 | """ 94 | stop the daemon 95 | """ 96 | # pylint: disable=no-self-use 97 | subprocess.call(["python", "-m", "SparseSC.cli.scgrad", "stop"]) 98 | 99 | def do_grad(self, part_data): 100 | """ 101 | calculate the gradient 102 | """ 103 | start_time = datetime.datetime.now().replace(microsecond=0) 104 | print("Gradient start time: {}".format(start_time)) 105 | 106 | # WRITE THE PART DATA TO FILE 107 | with open(self.GRAD_PART_FILE, "w") as fh: 108 | fh.write(dump(part_data, Dumper=Dumper)) 109 | 110 | print("Gradient A") 111 | #-- for key in part_data.keys(): 112 | #-- with open(os.path.join(self.tmpdirname, key), "w") as fh: 113 | #-- fh.write(dump(part_data[key], Dumper=Dumper)) 114 | #-- with open(os.path.join(self.tmpdirname, "item0.yaml"), "w") as fh: fh.write(dump(part_data["dA_dV_ki"][0][0], Dumper=Dumper)) 115 | #-- with open(os.path.join(self.tmpdirname, "item0.bytes"), "wb") as fh: fh.write(part_data["dA_dV_ki"][0][0].tobytes()) 116 | #-- import gzip 117 | #-- 118 | #-- with gzip.open(os.path.join(self.tmpdirname, "item0_{}.gz".format(i)), "rb") as fh: matbytes = fh.read() 119 | 120 | dGamma0_dV_term2 = np.zeros(self.K) 121 | for k in range(self.K): 122 | print(k, end=" ") 123 | # SEND THE ARGS TO THE DAEMON 124 | with open(DAEMON_FIFO, "w") as df: 125 | df.write( 126 | json.dumps( 127 | [ 128 | self.tmpdirname, 129 | RETURN_FIFO, 130 | k, 131 | ] 132 | ) 133 | ) 134 | 135 | # LISTEN FOR THE RESPONSE 136 | with open(RETURN_FIFO, "r") as rf: 137 | response = rf.read() 138 | if response != "0": 139 | raise RuntimeError("Something went wrong in the daemon: {}".format( response)) 140 | 141 | with open(self.CONTAINER_OUTPUT_FILE, "r") as fh: 142 | dGamma0_dV_term2[k] = load(fh.read(), Loader=Loader) 143 | 144 | return dGamma0_dV_term2 145 | 146 | -------------------------------------------------------------------------------- /src/SparseSC/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Allow capturing output 2 | # Modified (to not capture stderr too) from https://stackoverflow.com/questions/5136611/ 3 | import contextlib 4 | import sys 5 | 6 | from .print_progress import it_progressbar, it_progressmsg 7 | 8 | @contextlib.contextmanager 9 | def capture(): 10 | STDOUT = sys.stdout 11 | try: 12 | sys.stdout = DummyFile() 13 | yield 14 | finally: 15 | sys.stdout = STDOUT 16 | 17 | 18 | class DummyFile(object): 19 | def write(self, x): 20 | pass 21 | 22 | 23 | @contextlib.contextmanager 24 | def capture_all(): 25 | STDOUT, STDERR = sys.stdout, sys.stderr 26 | try: 27 | sys.stdout, sys.stderr = DummyFile(), DummyFile() 28 | yield 29 | finally: 30 | sys.stdout, sys.stderr = STDOUT, STDERR 31 | 32 | 33 | def par_map(part_fn, it, F, 
loop_verbose, n_multi=0, header="LOOP"): 34 | if n_multi>0: 35 | from multiprocessing import Pool 36 | 37 | with Pool(n_multi) as p: 38 | # p.map would consume `it` eagerly, so wrap p.imap with the progress iterators instead 39 | if loop_verbose==1: 40 | rets = [] 41 | print(header + ":") 42 | for ret in it_progressbar(p.imap(part_fn, it), count=F): 43 | rets.append(ret) 44 | elif loop_verbose==2: 45 | rets = [] 46 | for ret in it_progressmsg(p.imap(part_fn, it), prefix=header, count=F): 47 | rets.append(ret) 48 | else: 49 | rets = p.map(part_fn, it) 50 | else: 51 | if loop_verbose==1: 52 | print(header + ":") 53 | it = it_progressbar(it, count=F) 54 | elif loop_verbose==2: 55 | it = it_progressmsg(it, prefix=header, count=F) 56 | rets = list(map(part_fn, it)) 57 | return rets 58 | 59 | 60 | class PreDemeanScaler: 61 | """ 62 | Units are defined by rows and cols are "pre" and "post" separated. 63 | Demeans each row by the "pre" mean. 64 | """ 65 | 66 | # maybe fit should just take Y and T0 (in init())? 67 | # Try in sklearn.pipeline with fit() for that and predict (on default Y_post) 68 | # might want wrappers around fit to make that work fine with pipeline (given its standard arguments). 69 | # maybe call the vars X rather than Y? 70 | def __init__(self): 71 | self.means = None 72 | # self.T0 = T0 73 | 74 | def fit(self, Y): 75 | """ 76 | Ex. fit(Y.iloc[:,0:T0]) 77 | """ 78 | import numpy as np 79 | 80 | self.means = np.mean(Y, axis=1) 81 | 82 | def transform(self, Y): 83 | return (Y.T - self.means).T 84 | 85 | def inverse_transform(self, Y): 86 | return (Y.T + self.means).T 87 | 88 | 89 | def _ensure_good_donor_pool(custom_donor_pool, control_units): 90 | N0 = custom_donor_pool.shape[1] 91 | custom_donor_pool_c = custom_donor_pool[control_units, :] 92 | for i in range(N0): 93 | custom_donor_pool_c[i, i] = False 94 | custom_donor_pool[control_units, :] = custom_donor_pool_c 95 | return custom_donor_pool 96 | 97 | 98 | def _get_fit_units(model_type, control_units, treated_units, N): 99 | if model_type == "retrospective": 100 | return control_units 101 | elif model_type == "prospective": 102 | return range(N) 103 | elif model_type == "prospective-restricted": 104 | return treated_units 105 | # model_type=="full": 106 | return range(N) # all units 107 | 
-------------------------------------------------------------------------------- /src/SparseSC/utils/ols_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import statsmodels.api as sm 3 | 4 | import SparseSC as SC 5 | 6 | #All data in the long format (rows mark id-time) 7 | def OLS_avg_AA_simple(N=100, T=10, K=0, treat_ratio=.1, T0=None): 8 | #0-based index for i and t. 
each unit's panel is together (fine index is time, gross is unit) 9 | Y = np.random.normal(0,1,(N*T)) 10 | id = np.repeat(np.arange(N), T) # unit index: each unit's T rows are contiguous 11 | time = np.tile(np.arange(T), N) # time index within each unit's panel 12 | if T0 is None: 13 | T0 = int(T/2) 14 | return OLS_avg_AA(Y, id, time, T0, N, T, treat_ratio) 15 | 16 | def OLS_avg_AA(Y, id, time, post_start, N, T, treat_ratio, num_sim=1000, X=None, level=0.95): 17 | Const = np.ones((Y.shape[0], 1)) 18 | Post = np.expand_dims((time>=post_start).astype(int), axis=1) 19 | X_base = np.hstack((Const, Post)) 20 | #X_base = sm.add_constant(X_base) 21 | if X is not None: 22 | X_base = np.hstack((X_base, X)) 23 | alpha = 1-level 24 | 25 | tes = np.empty((num_sim)) 26 | ci_ls = np.empty((num_sim)) 27 | ci_us = np.empty((num_sim)) 28 | N1 = int(N*treat_ratio) 29 | sel_idx = np.concatenate((np.repeat(1,N1), np.repeat(0,N-N1))) 30 | for s in range(num_sim): 31 | np.random.shuffle(sel_idx) 32 | Treat = np.expand_dims(np.repeat(sel_idx, T), axis=1) 33 | D = Treat * Post 34 | X = np.hstack((X_base,Treat, D)) 35 | model = sm.OLS(Y,X, hasconst=True) 36 | results = model.fit() 37 | tes[s] = results.params[3] 38 | [ci_ls[s], ci_us[s]] = results.conf_int(alpha, cols=[3])[0] 39 | 40 | stats = SC.utils.metrics_utils.simulation_eval(tes, ci_ls, ci_us, true_effect=0) 41 | print(stats) 42 | 43 | 44 | #Do separate effects for each post treatment period? 45 | #def OLS_AA_vec(Y, id, time, treat_ratio, post_times, X=None): 46 | # for post_time in post_times: 47 | # Post_ind_t = time==post_time 48 | # X_base = np.hstack((X_base, Post_ind_t)) 49 | # 50 | #def OLS_AA_vec_specific(Y, X_base, sel_idx): 51 | # base_init_len = X_base.shape[1] 52 | # Treat = np.vstack(id_pre==sel_id, id_post==sel_id) 53 | # X_base = np.hstack(X_base, Treat) 54 | # for post_idx in 2:base_init_len: 55 | # D_t = Treat and X_base[:,post_idx] 56 | # X = np.hstack((X,D_t)) 57 | 58 | # model = sm.OLS(Y,X, hasconst=True) 59 | # results = model.fit() 60 | # results.params[base_init_len+1:] 61 | -------------------------------------------------------------------------------- /src/SparseSC/utils/print_progress.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ A utility for displaying a progress bar 3 | """ 4 | # https://gist.github.com/aubricus/f91fb55dc6ba5557fbab06119420dd6a 5 | import sys 6 | import datetime 7 | SparseSC_prev_iteration = 0 8 | def print_progress(iteration, total=100, prefix='', suffix='', decimals=1, bar_length=100, file=sys.stdout): 9 | """ 10 | Call in a loop to create a progress bar. If this isn't a tty-like (re-writable) output then 11 | no percent-complete number will be output. 12 | @params: 13 | iteration - Required : current iteration (Int) (typically 1 is first possible) 14 | total - Required : total iterations (Int) 15 | prefix - Optional : prefix string (Str) 16 | suffix - Optional : suffix string (Str) 17 | decimals - Optional : positive number of decimals in percent complete (Int) 18 | bar_length - Optional : character length of bar (Int) 19 | file - Optional : file descriptor for output 20 | """ 21 | fill_char = '>' # Note that the "█" character is not compatible with every platform... 22 | empty_char = '-' 23 | filled_length = int(round(bar_length * iteration / float(total))) 24 | if file.isatty(): 25 | str_format = "{0:." 
+ str(decimals) + "f}" 26 | percents = str_format.format(100 * (iteration / float(total))) 27 | 28 | progress_bar = fill_char * filled_length + empty_char * (bar_length - filled_length) 29 | 30 | file.write('\r%s |%s| %s%s %s' % (prefix, progress_bar, percents, '%', suffix)) 31 | 32 | if iteration == total: 33 | file.write('\n') 34 | file.flush() 35 | else: # Can't do interactive re-writing (e.g. w/ the /r character) 36 | global SparseSC_prev_iteration 37 | if iteration == 1: 38 | file.write('%s |' % (prefix)) 39 | else: 40 | if iteration <= SparseSC_prev_iteration: 41 | SparseSC_prev_iteration = 0 42 | prev_fill_length = int(round(bar_length * SparseSC_prev_iteration / float(total))) 43 | progress_bar = fill_char * (filled_length-prev_fill_length) 44 | 45 | file.write('%s' % (progress_bar)) 46 | 47 | if iteration == total: 48 | file.write('| %s\n' % (suffix)) 49 | SparseSC_prev_iteration = iteration 50 | file.flush() 51 | 52 | def it_progressmsg(it, prefix="Loop", file=sys.stdout, count=None): 53 | for i, item in enumerate(it): 54 | if count is None: 55 | file.write(f"{prefix}: {i}\n") 56 | else: 57 | file.write(f"{prefix}: {i} of {count}\n") 58 | file.flush() 59 | yield item 60 | file.write(prefix + ": FINISHED\n") 61 | file.flush() 62 | 63 | #Similar to above, but you wrap an iterator 64 | def it_progressbar(it, prefix="", suffix='', decimals=1, bar_length=100, file=sys.stdout, count=None): 65 | if count is None: 66 | count = len(it) 67 | def show(j): 68 | fill_char = '>' 69 | empty_char = '-' 70 | x = int(bar_length*j/count) 71 | str_format = "{0:." + str(decimals) + "f}" 72 | percents = str_format.format(100 * (j / float(count))) 73 | progress_bar = fill_char * x + empty_char * (bar_length - x) 74 | if file.isatty(): 75 | file.write("%s |%s| %s%s %s\r" % (prefix, progress_bar, percents, '%', suffix)) 76 | file.flush() 77 | else: # Can't do interactive re-writing (e.g. 
w/ the \r character) 78 | if j == 0: 79 | file.write('%s |' % (prefix)) 80 | else: 81 | prev_x = int(bar_length*(j-1)/count) 82 | progress_bar = fill_char * (x-prev_x) 83 | file.write('%s' % (progress_bar)) 84 | if j == count: 85 | file.write('| %s' % (suffix)) 86 | show(0) 87 | for i, item in enumerate(it): 88 | yield item 89 | show(i+1) 90 | file.write("\n") 91 | file.flush() 92 | 93 | def log_if_necessary(str_note_start, verbose): 94 | str_note = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') + ": " + str_note_start 95 | if verbose>0: 96 | print(str_note) 97 | if verbose>1: 98 | print_memory_snapshot(extra_str=str_note) 99 | 100 | 101 | def print_memory_snapshot(extra_str=None): 102 | import os 103 | import tracemalloc 104 | log_file = os.getenv("SparseSC_log_file") #None if non-existent 105 | snapshot = tracemalloc.take_snapshot() 106 | top_stats = snapshot.statistics('lineno') 107 | if log_file is not None: 108 | log_file = open(log_file, "a") 109 | if extra_str is not None: 110 | print(extra_str, file=log_file) 111 | limit=10 112 | print("[ Top 10 ] ", file=log_file) 113 | for stat in top_stats[:limit]: 114 | print(stat, file=log_file) 115 | other = top_stats[limit:] 116 | if other: 117 | size = sum(stat.size for stat in other) 118 | print("%s other: %.1f KiB" % (len(other), size / 1024), file=log_file) 119 | total = sum(stat.size for stat in top_stats) 120 | print("Total allocated size: %.1f KiB" % (total / 1024), file=log_file) 121 | #if old_snapshot is not None: 122 | # diff_stats = snapshot.compare_to(old_snapshot, 'lineno') 123 | if log_file is not None: 124 | log_file.close() 125 | -------------------------------------------------------------------------------- /src/SparseSC/utils/sub_matrix_inverse.py: -------------------------------------------------------------------------------- 1 | """ In the leave-one-out method with larger sample sizes most of the time is spent solving 2 | A x = B via np.linalg.solve (which is much faster than explicitly forming the inverse, i.e. `A.I.dot(B)`). 3 | For example, with 200 units, about 95% of the time is spent in this line. 4 | 5 | However we can take advantage of the fact that we're calculating the 6 | inverse of N matrices which are all sub-matrices of a common matrix by inverting the common matrix 7 | and calculating each subset inverse from there. 8 | 9 | https://math.stackexchange.com/a/208021/252693 10 | 11 | 12 | TODO: this could be made a bit faster by passing in the indexes (k_rng, 13 | k_rng2) instead of re-building them 14 | """ 15 | # pylint: skip-file 16 | import numpy as np 17 | 18 | def subinv(x,eps=None): 19 | """ Given a matrix (x), calculate the inverses of all of its leave-one-out sub-matrices. 20 | 21 | :param x: a square matrix for which to find the inverses of all of its leave-one-out sub-matrices. 22 | :param eps: If not None, used to assert that each calculated 23 | sub-matrix-inverse is within eps of the brute force calculation. 24 | Testing only, this slows the process way down since the inverse of 25 | each sub-matrix is calculated by the brute force method. Typically
Typically 26 | set to a multiple of `np.finfo(float).eps` 27 | """ 28 | # handy constant for indexing 29 | xi = x.I 30 | N = x.shape[0] 31 | rng = np.arange(N) 32 | out = [None,] * N 33 | for k in range(N): 34 | k_rng = rng[rng != k] 35 | out[k] = xi[np.ix_(k_rng,k_rng)] - xi[k_rng,k].dot(xi[k,k_rng])/xi[k,k] 36 | if eps is not None: 37 | if not (abs(out[k] - x[np.ix_(k_rng,k_rng)].I) < eps).all(): 38 | raise RuntimeError("Fast and brute force methods were not within epsilon (%s) for sub-matrix k = %s; max difference = %s" % 39 | (eps, k, abs(out[k] - x[np.ix_(k_rng,k_rng)].I).max(), ) ) 40 | return out 41 | 42 | def subinv_k(xi,k,eps=None): 43 | """ Given an matrix (x), calculate all the inverses of leave-one-out sub-matrices. 44 | 45 | :param x: a square matrix for which to find the inverses of all it's leave one out sub-matrices. 46 | :param k: the column and row to leave out 47 | :param eps: If not None, used to assert that the each calculated 48 | sub-matrix-inverse is within eps of the brute force calculation. 49 | Testing only, this slows the process way down since the inverse of 50 | each sub-matrix is calculated by the brute force method. Typically 51 | set to a multiple of `np.finfo(float).eps` 52 | """ 53 | # handy constant for indexing 54 | N = xi.shape[0] 55 | rng = np.arange(N) 56 | k_rng = rng[rng != k] 57 | out = xi[np.ix_(k_rng,k_rng)] - xi[k_rng,k].dot(xi[k,k_rng])/xi[k,k] 58 | if eps is not None: 59 | if not (abs(out[k] - x[np.ix_(k_rng,k_rng)].I) < eps).all(): 60 | raise RuntimeError("Fast and brute force methods were not within epsilon (%s) for sub-matrix k = %s; max difference = %s" % (eps, k, abs(out[k] - x[np.ix_(k_rng,k_rng)].I).max(), ) ) 61 | return out 62 | 63 | 64 | 65 | # --------------------------------------------- 66 | # single sub-matrix 67 | # --------------------------------------------- 68 | 69 | if __name__ == "__main__": 70 | 71 | import time 72 | 73 | n = 200 74 | B = np.matrix(np.random.random((n,n,))) 75 | 76 | 77 | n = 5 78 | p = 3 79 | a = np.matrix(np.random.random((n,p,))) 80 | v = np.diag(np.random.random((p,))) 81 | x = a.dot(v).dot(a.T) 82 | x.dot(x.I) 83 | 84 | x = np.matrix(np.random.random((n,n,))) 85 | x.dot(x.I) 86 | 87 | xi = x.I 88 | 89 | 90 | B = np.matrix(np.random.random((n,n,))) 91 | 92 | k = np.arange(2) 93 | N = xi.shape[0] 94 | rng = np.arange(N) 95 | k_rng = rng[np.logical_not(np.isin(rng,k))] 96 | 97 | out = xi[np.ix_(k_rng,k_rng)] - xi[np.ix_(k_rng,k)].dot(xi[np.ix_(k,k_rng)])/np.linalg.det(xi[np.ix_(k,k)]) 98 | 99 | for i in range(100): 100 | # create a sub-matrix that meets the matching criteria 101 | x = np.matrix(np.random.random((n,n,))) 102 | try: 103 | zz = subinv(x,10e-10) 104 | break 105 | except: 106 | pass 107 | else: 108 | print("Failed to generate a %sx%s matrix whose inverses are all within %s of the quick method") 109 | 110 | 111 | k = 5 112 | n_tests = 1000 113 | 114 | # ======================= 115 | t0 = time.time() 116 | for i in range(n_tests): 117 | _N = xi.shape[0] 118 | rng = np.arange(_N) 119 | k_rng = rng[rng != k] 120 | k_rng2 = np.ix_(k_rng,k_rng) 121 | zz = x[k_rng2].I.dot(B[k_rng2]) 122 | 123 | t1 = time.time() 124 | slow_time = t1 - t0 125 | print("A.I.dot(B): brute force time (N = %s): %s"% (n,t1 - t0)) 126 | 127 | # ======================= 128 | t0 = time.time() 129 | for i in range(n_tests): 130 | # make the comparison fair 131 | if i % n == 0: 132 | xi = x.I 133 | zz = subinv_k(xi,k).dot(B[k_rng2]) 134 | 135 | t1 = time.time() 136 | fast_time = t1 - t0 137 | print("A.I.dot(B): quick time (N = 
138 | 139 | # ======================= 140 | t0 = time.time() 141 | for i in range(n_tests): 142 | _N = xi.shape[0] 143 | rng = np.arange(_N) 144 | k_rng = rng[rng != k] 145 | k_rng2 = np.ix_(k_rng,k_rng) 146 | zz = np.linalg.solve(x[k_rng2],B[k_rng2]) 147 | 148 | t1 = time.time() 149 | fast_time = t1 - t0 150 | print("A.I.dot(B): np.linalg.solve time (N = %s): %s"% (n,t1 - t0)) 151 | 152 | # --------------------------------------------- 153 | # --------------------------------------------- 154 | 155 | t0 = time.time() 156 | for i in range(100): 157 | zz = subinv(x,10e-10) 158 | 159 | t1 = time.time() 160 | slow_time = t1 - t0 161 | print("Full set of inverses: brute force time (N = %s): %s" % (n,t1 - t0)) 162 | 163 | t0 = time.time() 164 | for i in range(100): 165 | zz = subinv(x) 166 | 167 | t1 = time.time() 168 | fast_time = t1 - t0 169 | print("Full set of inverses: quick time (N = %s): %s" % (n,t1 - t0)) 170 | 171 | # --------------------------------------------- 172 | # --------------------------------------------- 173 | 174 | -------------------------------------------------------------------------------- /src/SparseSC/utils/warnings.py: -------------------------------------------------------------------------------- 1 | """ 2 | Warnings used throughout the module. Exported and inherited from a common 3 | warning class to facilitate filtering of warnings 4 | """ 5 | # pylint: disable=too-few-public-methods, missing-docstring 6 | class SparseSCWarning(RuntimeWarning):pass 7 | class UnpenalizedRecords(SparseSCWarning):pass 8 | 9 | 10 | -------------------------------------------------------------------------------- /src/SparseSC/weights.py: -------------------------------------------------------------------------------- 1 | """ 2 | Presents a unified API for the various weights methods 3 | """ 4 | from SparseSC.fit_loo import loo_weights 5 | from SparseSC.fit_ct import ct_weights 6 | from SparseSC.fit_fold import fold_weights 7 | import numpy as np 8 | 9 | 10 | def weights(X, X_treat=None, grad_splits=None, custom_donor_pool=None, **kwargs): 11 | """ Calculate synthetic control weights 12 | """ 13 | 14 | # PARAMETER QC 15 | try: 16 | X = np.float64(X) 17 | except ValueError: 18 | raise ValueError("X is not coercible to float64") 19 | 20 | X = np.asmatrix(X) # this needs to be deprecated properly -- bc Array.dot(Array) != matrix(Array).dot(matrix(Array)) -- not even close !!! 21 | if X_treat is not None: 22 | # weights for the treated units against the controls: 23 | 24 | 25 | 26 | 27 | # PARAMETER QC 28 | try: 29 | X_treat = np.float64(X_treat) 30 | except ValueError: 31 | raise ValueError("X_treat is not coercible to float64") 32 | 33 | # this needs to be deprecated properly -- bc Array.dot(Array) != matrix(Array).dot(matrix(Array)) -- not even close !!! 
34 | X_treat = np.asmatrix(X_treat) 35 | 36 | if X_treat.shape[1] == 0: 37 | raise ValueError("X_treat.shape[1] == 0") 38 | 39 | # CALCULATE THE WEIGHTS FOR THE TREATED UNITS AGAINST THE CONTROL DONOR POOL 40 | # note that the weights, score, and loss function value returned here 41 | # are for the in-sample predictions 42 | return ct_weights( 43 | X=np.vstack((X, X_treat)), 44 | control_units=np.arange(X.shape[0]), 45 | treated_units=np.arange(X_treat.shape[0]) + X.shape[0], 46 | custom_donor_pool=custom_donor_pool, 47 | **kwargs 48 | ) 49 | 50 | # === X_treat is None: === 51 | 52 | if grad_splits is not None: 53 | return fold_weights(X=X, grad_splits=grad_splits, **kwargs) 54 | 55 | # === X_treat is None and grad_splits is None: === 56 | 57 | # weight for the control units against the remaining controls 58 | return loo_weights( 59 | X=X, 60 | control_units=np.arange(X.shape[0]), 61 | treated_units=np.arange(X.shape[0]), 62 | custom_donor_pool=custom_donor_pool, 63 | **kwargs 64 | ) 65 | -------------------------------------------------------------------------------- /static/controls-only-pre-and-post.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/static/controls-only-pre-and-post.png -------------------------------------------------------------------------------- /static/pre-only-controls-and-treated.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/static/pre-only-controls-and-treated.png -------------------------------------------------------------------------------- /test/AzureBatch/README.md: -------------------------------------------------------------------------------- 1 | # Batch config testing 2 | 3 | ## Install Batch Config 4 | 5 | ```shell 6 | pip install git+https://github.com/jdthorpe/batch-config 7 | ``` 8 | 9 | ## Gather required credentials 10 | 11 | These commands assume the CMD terminal. For other terminals, [see here](https://jdthorpe.github.io/super-batch-docs/create-resources). 
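Each `for /f` loop below captures the single value printed by an `az ... --query` command into an environment variable. Because `az` prints JSON, the values arrive wrapped in double quotes; the `set VAR=%VAR:"=%` lines at the end strip those quotes.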
12 | 13 | ```bat 14 | az login 15 | # optionally: az account set -s xxxxxxx-xxxx-xxxx-xxxx-xxxxxxxx 16 | 17 | set name=sparsescbatchtesting 18 | set rgname=SparseSC-batch-testing 19 | set BATCH_ACCOUNT_NAME=%name% 20 | for /f %i in ('az batch account keys list -n %name% -g %rgname% --query primary') do @set BATCH_ACCOUNT_KEY=%i 21 | for /f %i in ('az batch account show -n %name% -g %rgname% --query accountEndpoint') do @set BATCH_ACCOUNT_ENDPOINT=%i 22 | for /f %i in ('az storage account keys list -n %name% --query [0].value') do @set STORAGE_ACCOUNT_KEY=%i 23 | for /f %i in ('az storage account show-connection-string --name %name% --query connectionString') do @set STORAGE_ACCOUNT_CONNECTION_STRING=%i 24 | 25 | # clean up the quotes 26 | set BATCH_ACCOUNT_KEY=%BATCH_ACCOUNT_KEY:"=% 27 | set BATCH_ACCOUNT_ENDPOINT=%BATCH_ACCOUNT_ENDPOINT:"=% 28 | set STORAGE_ACCOUNT_KEY=%STORAGE_ACCOUNT_KEY:"=% 29 | set STORAGE_ACCOUNT_CONNECTION_STRING=%STORAGE_ACCOUNT_CONNECTION_STRING:"=% 30 | ``` 31 | 32 | ## run the tests 33 | 34 | ```bat 35 | cd test\AzureBatch 36 | rm -rf data 37 | python test_batch_build.py 38 | python test_batch_run.py 39 | python test_batch_aggregate.py 40 | ``` 41 | -------------------------------------------------------------------------------- /test/AzureBatch/test_batch_aggregate.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- 2 | # Programmer: Jason Thorpe 3 | # Date 1/25/2019 3:34:02 PM 4 | # Language: Python (.py) Version 2.7 or 3.5 5 | # Usage: 6 | # 7 | # Test all model types 8 | # 9 | # \SpasrseSC > python -m unittest test/test_fit.py 10 | # 11 | # Test a specific model type (e.g. "prospective-restricted"): 12 | # 13 | # \SpasrseSC > python -m unittest test.test_fit.TestFit.test_retrospective 14 | # 15 | # -------------------------------------------------------------------------------- 16 | # pylint: disable=multiple-imports, missing-docstring 17 | """ 18 | usage 19 | 20 | az login 21 | 22 | name="sparsesctest" 23 | location="westus2" 24 | 25 | export BATCH_ACCOUNT_NAME=$name 26 | export BATCH_ACCOUNT_KEY=$(az batch account keys list -n $name -g $name --query primary) 27 | export BATCH_ACCOUNT_URL="https://$name.$location.batch.azure.com" 28 | export STORAGE_ACCOUNT_NAME=$name 29 | export STORAGE_ACCOUNT_KEY=$(az storage account keys list -n $name --query [0].value) 30 | 31 | """ 32 | 33 | 34 | from __future__ import print_function # for compatibility with python 2.7 35 | import numpy as np 36 | import sys, os, random 37 | import unittest 38 | import warnings 39 | 40 | try: 41 | from SparseSC import fit 42 | except ImportError: 43 | raise RuntimeError("SparseSC is not installed. use 'pip install -e .' to install") 44 | from scipy.optimize.linesearch import LineSearchWarning 45 | from SparseSC.utils.AzureBatch import aggregate_batch_results 46 | 47 | 48 | class TestFit(unittest.TestCase): 49 | def setUp(self): 50 | 51 | random.seed(12345) 52 | np.random.seed(101101001) 53 | control_units = 50 54 | treated_units = 20 55 | features = 10 56 | targets = 5 57 | 58 | self.X = np.random.rand(control_units + treated_units, features) 59 | self.Y = np.random.rand(control_units + treated_units, targets) 60 | self.treated_units = np.arange(treated_units) 61 | 62 | @classmethod 63 | def run_test(cls, obj, model_type, verbose=False): 64 | if verbose: 65 | print("Calling fit with `model_type = '%s'`..." 
% (model_type,), end="") 66 | sys.stdout.flush() 67 | 68 | batchdir = os.path.join( 69 | os.path.dirname(os.path.realpath(__file__)), "data", "batchTest" 70 | ) 71 | assert os.path.exists(batchdir), "Batch Directory '{}' does not exist".format( 72 | batchdir 73 | ) 74 | 75 | with warnings.catch_warnings(): 76 | warnings.filterwarnings("ignore", category=PendingDeprecationWarning) 77 | warnings.filterwarnings("ignore", category=LineSearchWarning) 78 | try: 79 | verbose = 0 80 | model_a = fit( 81 | features=obj.X, 82 | targets=obj.Y, 83 | model_type=model_type, 84 | treated_units=obj.treated_units 85 | if model_type 86 | in ("retrospective", "prospective", "prospective-restricted") 87 | else None, 88 | # KWARGS: 89 | print_path=verbose, 90 | stopping_rule=1, 91 | progress=0, 92 | grid_length=5, 93 | min_iter=-1, 94 | tol=1, 95 | verbose=0, 96 | ) 97 | 98 | model_b = aggregate_batch_results( 99 | batchDir=batchdir 100 | ) # , batch_client_config="sg_daemon" 101 | 102 | assert np.all( 103 | np.abs(model_a.scores - model_b.scores) < 1e-14 104 | ), "model scores are not within rounding error" 105 | 106 | if verbose: 107 | print("DONE") 108 | except LineSearchWarning: 109 | pass 110 | except PendingDeprecationWarning: 111 | pass 112 | except Exception as exc: # pylint: disable=broad-except 113 | print( 114 | "Failed with %s(%s)" 115 | % (exc.__class__.__name__, getattr(exc, "message", "")) 116 | ) 117 | raise exc 118 | 119 | def test_retrospective(self): 120 | TestFit.run_test(self, "retrospective") 121 | 122 | 123 | # -- def test_prospective(self): 124 | # -- TestFit.run_test(self, "prospective") 125 | # -- 126 | # -- def test_prospective_restrictive(self): 127 | # -- # Catch the LineSearchWarning silently, but allow others 128 | # -- 129 | # -- TestFit.run_test(self, "prospective-restricted") 130 | # -- 131 | # -- def test_full(self): 132 | # -- TestFit.run_test(self, "full") 133 | 134 | 135 | if __name__ == "__main__": 136 | unittest.main() 137 | -------------------------------------------------------------------------------- /test/AzureBatch/test_batch_build.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- 2 | # Programmer: Jason Thorpe 3 | # Date 1/25/2019 3:34:02 PM 4 | # Language: Python (.py) Version 2.7 or 3.5 5 | # Usage: 6 | # 7 | # Test all model types 8 | # 9 | # \SpasrseSC > python -m unittest test/test_fit.py 10 | # 11 | # Test a specific model type (e.g. "prospective-restricted"): 12 | # 13 | # \SpasrseSC > python -m unittest test.test_fit.TestFit.test_retrospective 14 | # 15 | # -------------------------------------------------------------------------------- 16 | 17 | from __future__ import print_function # for compatibility with python 2.7 18 | import numpy as np 19 | import sys, os, random 20 | import unittest 21 | import warnings 22 | from scipy.optimize.linesearch import LineSearchWarning 23 | 24 | try: 25 | from SparseSC import fit 26 | except ImportError: 27 | raise RuntimeError("SparseSC is not installed. use 'pip install -e .' 
to install") 28 | 29 | 30 | class TestFit(unittest.TestCase): 31 | def setUp(self): 32 | 33 | random.seed(12345) 34 | np.random.seed(101101001) 35 | control_units = 50 36 | treated_units = 20 37 | features = 10 38 | targets = 5 39 | 40 | self.X = np.random.rand(control_units + treated_units, features) 41 | self.Y = np.random.rand(control_units + treated_units, targets) 42 | self.treated_units = np.arange(treated_units) 43 | 44 | @classmethod 45 | def run_test(cls, obj, model_type, verbose=False): 46 | if verbose: 47 | print("Calling fit with `model_type = '%s'`..." % (model_type,), end="") 48 | sys.stdout.flush() 49 | 50 | batchdir = os.path.join( 51 | os.path.dirname(os.path.realpath(__file__)), "data", "batchTest" 52 | ) 53 | print("dumping batch artifacts to: {}'".format(batchdir)) 54 | 55 | with warnings.catch_warnings(): 56 | warnings.filterwarnings("ignore", category=PendingDeprecationWarning) 57 | warnings.filterwarnings("ignore", category=LineSearchWarning) 58 | try: 59 | fit( 60 | features=obj.X, 61 | targets=obj.Y, 62 | model_type=model_type, 63 | treated_units=obj.treated_units 64 | if model_type 65 | in ("retrospective", "prospective", "prospective-restricted") 66 | else None, 67 | # KWARGS: 68 | print_path=False, 69 | stopping_rule=1, 70 | progress=verbose, 71 | grid_length=5, 72 | min_iter=-1, 73 | tol=1, 74 | verbose=0, 75 | batchDir=batchdir, 76 | ) 77 | if verbose: 78 | print("DONE") 79 | except LineSearchWarning: 80 | pass 81 | except PendingDeprecationWarning: 82 | pass 83 | except Exception as exc: # pylint: disable=broad-except 84 | print( 85 | "Failed with %s(%s)" 86 | % (exc.__class__.__name__, getattr(exc, "message", "")) 87 | ) 88 | raise exc 89 | 90 | def test_retrospective(self): 91 | TestFit.run_test(self, "retrospective") 92 | 93 | 94 | # -- def test_prospective(self): 95 | # -- TestFit.run_test(self, "prospective") 96 | # -- 97 | # -- def test_prospective_restrictive(self): 98 | # -- # Catch the LineSearchWarning silently, but allow others 99 | # -- 100 | # -- TestFit.run_test(self, "prospective-restricted") 101 | # -- 102 | # -- def test_full(self): 103 | # -- TestFit.run_test(self, "full") 104 | 105 | 106 | if __name__ == "__main__": 107 | # t = TestFit() 108 | # t.setUp() 109 | # t.test_retrospective() 110 | unittest.main() 111 | -------------------------------------------------------------------------------- /test/AzureBatch/test_batch_run.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- 2 | # Programmer: Jason Thorpe 3 | # Date 1/25/2019 3:34:02 PM 4 | # Language: Python (.py) Version 2.7 or 3.5 5 | # Usage: 6 | # 7 | # Test all model types 8 | # 9 | # \SpasrseSC > python -m unittest test/test_fit.py 10 | # 11 | # Test a specific model type (e.g. 
"prospective-restricted"): 12 | # 13 | # \SpasrseSC > python -m unittest test.test_fit.TestFit.test_retrospective 14 | # 15 | # -------------------------------------------------------------------------------- 16 | # pylint: disable=multiple-imports, missing-docstring, no-self-use 17 | """ 18 | USAGE (CMD): 19 | 20 | az login 21 | 22 | set name=sparsescbatchtesting 23 | set rgname=SparseSC-batch-testing 24 | set BATCH_ACCOUNT_NAME=%name% 25 | for /f %i in ('az batch account keys list -n %name% -g %rgname% --query primary') do @set BATCH_ACCOUNT_KEY=%i 26 | for /f %i in ('az batch account show -n %name% -g %rgname% --query accountEndpoint') do @set BATCH_ACCOUNT_ENDPOINT=%i 27 | for /f %i in ('az storage account keys list -n %name% --query [0].value') do @set STORAGE_ACCOUNT_KEY=%i 28 | for /f %i in ('az storage account show-connection-string --name %name% --query connectionString') do @set STORAGE_ACCOUNT_CONNECTION_STRING=%i 29 | 30 | # clean up the quotes 31 | set BATCH_ACCOUNT_KEY=%BATCH_ACCOUNT_KEY:"=% 32 | set BATCH_ACCOUNT_ENDPOINT=%BATCH_ACCOUNT_ENDPOINT:"=% 33 | set STORAGE_ACCOUNT_KEY=%STORAGE_ACCOUNT_KEY:"=% 34 | set STORAGE_ACCOUNT_CONNECTION_STRING=%STORAGE_ACCOUNT_CONNECTION_STRING:"=% 35 | 36 | cd test\AzureBatch 37 | rm -rf data 38 | python test_batch_build.py 39 | python test_batch_run.py 40 | python test_batch_aggregate.py 41 | """ 42 | 43 | from __future__ import print_function # for compatibility with python 2.7 44 | import os, unittest, datetime 45 | from os.path import join, realpath, dirname, exists 46 | from super_batch import Client 47 | 48 | from SparseSC.utils.AzureBatch import ( 49 | DOCKER_IMAGE_NAME, 50 | create_job, 51 | ) 52 | 53 | 54 | class TestFit(unittest.TestCase): 55 | def test_retrospective_no_wait(self): 56 | """ 57 | test the no-wait and load_results API 58 | """ 59 | 60 | name = os.getenv("name") 61 | if name is None: 62 | raise RuntimeError( 63 | "Please create an environment variable called 'name' as en the example docs" 64 | ) 65 | batch_dir = join(dirname(realpath(__file__)), "data", "batchTest") 66 | assert exists(batch_dir), "Batch Directory '{}' does not exist".format( 67 | batch_dir 68 | ) 69 | 70 | timestamp = datetime.datetime.utcnow().strftime("%H%M%S") 71 | 72 | batch_client = Client( 73 | POOL_ID=name, 74 | POOL_LOW_PRIORITY_NODE_COUNT=5, 75 | POOL_VM_SIZE="STANDARD_A1_v2", 76 | JOB_ID=name + timestamp, 77 | BLOB_CONTAINER_NAME=name, 78 | BATCH_DIRECTORY=batch_dir, 79 | DOCKER_IMAGE=DOCKER_IMAGE_NAME, 80 | ) 81 | create_job(batch_client, batch_dir) 82 | 83 | batch_client.run(wait=False) 84 | batch_client.load_results() 85 | 86 | def test_retrospective(self): 87 | 88 | name = os.getenv("name") 89 | if name is None: 90 | raise RuntimeError( 91 | "Please create an environment variable called 'name' as en the example docs" 92 | ) 93 | batch_dir = join(dirname(realpath(__file__)), "data", "batchTest") 94 | assert exists(batch_dir), "Batch Directory '{}' does not exist".format( 95 | batch_dir 96 | ) 97 | 98 | timestamp = datetime.datetime.utcnow().strftime("%H%M%S") 99 | 100 | batch_client = Client( 101 | POOL_ID=name, 102 | POOL_LOW_PRIORITY_NODE_COUNT=5, 103 | POOL_VM_SIZE="STANDARD_A1_v2", 104 | JOB_ID=name + timestamp, 105 | BLOB_CONTAINER_NAME=name, 106 | BATCH_DIRECTORY=batch_dir, 107 | DOCKER_IMAGE=DOCKER_IMAGE_NAME, 108 | ) 109 | 110 | create_job(batch_client, batch_dir) 111 | batch_client.run() 112 | 113 | 114 | if __name__ == "__main__": 115 | unittest.main() 116 | 
-------------------------------------------------------------------------------- /test/CausalImpact_test.R: -------------------------------------------------------------------------------- 1 | library(CausalImpact) 2 | 3 | set.seed(1) 4 | x1 <- 100 + arima.sim(model = list(ar = 0.999), n = 100) 5 | x2 <- 100 + arima.sim(model = list(ar = 0.999), n = 100) 6 | x3 <- 100 + arima.sim(model = list(ar = 0.999), n = 100) 7 | x4 <- 100 + arima.sim(model = list(ar = 0.999), n = 100) 8 | y <- 1.2 * x1 + rnorm(100) + .8*x2 + .7*x3 + .6*x4 9 | y[71:100] <- y[71:100] + 10 10 | data <- cbind(y, x1, x2, x3, x4) 11 | 12 | pre.period <- c(1, 3) #70 13 | post.period <- c(4, 100) 14 | 15 | #impact <- CausalImpact(data, pre.period, post.period, alpha = 0.05) 16 | impact <- CausalImpact(data, pre.period, post.period) 17 | 18 | plot(impact) 19 | summary(impact) 20 | 21 | plot(impact$model$bsts.model, "coefficients") 22 | impact$summary$AbsEffect 23 | impact$summary$AbsEffect.lower 24 | impact$summary$AbsEffect.upper 25 | s = impact$series 26 | 27 | summary(lm(s$point.pred~x1+x2)) 28 | -------------------------------------------------------------------------------- /test/SparseSC_36.yml: -------------------------------------------------------------------------------- 1 | name: SparseSC_36 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _r-mutex=1.0.0=anacondar_1 6 | - _tflow_select=2.3.0=mkl 7 | - absl-py=0.7.1=py36_0 8 | - alabaster=0.7.12=py36_0 9 | - asn1crypto=0.24.0=py36_0 10 | - astor=0.7.1=py36_0 11 | - astroid=1.6.5=py36_0 12 | - attrs=19.1.0=py36_1 13 | - autopep8=1.4.4=py_0 14 | - babel=2.6.0=py36_0 15 | - backcall=0.1.0=py36_0 16 | - blas=1.0=mkl 17 | - bleach=3.1.0=py36_0 18 | - ca-certificates=2020.1.1=0 19 | - certifi=2020.4.5.1=py36_0 20 | - cffi=1.12.3=py36h7a1dbc1_0 21 | - chardet=3.0.4=py36_1 22 | - colorama=0.4.1=py36_0 23 | - cryptography=2.6.1=py36h7a1dbc1_0 24 | - cycler=0.10.0=py36h009560c_0 25 | - decorator=4.4.0=py36_1 26 | - defusedxml=0.6.0=py_0 27 | - docutils=0.14=py36h6012d8f_0 28 | - entrypoints=0.3=py36_0 29 | - freetype=2.9.1=ha9979f8_1 30 | - gast=0.2.2=py36_0 31 | - grpcio=1.16.1=py36h351948d_1 32 | - h5py=2.9.0=py36h5e291fa_0 33 | - hdf5=1.10.4=h7ebc959_0 34 | - icc_rt=2019.0.0=h0cc432a_1 35 | - icu=58.2=ha66f8fd_1 36 | - idna=2.8=py36_0 37 | - imagesize=1.1.0=py36_0 38 | - intel-openmp=2019.3=203 39 | - ipykernel=5.1.0=py36h39e3cac_0 40 | - ipython=7.5.0=py36h39e3cac_0 41 | - ipython_genutils=0.2.0=py36h3c5d0ee_0 42 | - isort=4.3.19=py36_0 43 | - jedi=0.13.3=py36_0 44 | - jinja2=2.10.1=py36_0 45 | - jpeg=9b=hb83a4c4_2 46 | - jsonschema=3.0.1=py36_0 47 | - jupyter_client=5.2.4=py36_0 48 | - jupyter_core=4.4.0=py36_0 49 | - keras=2.2.4=0 50 | - keras-applications=1.0.8=py_0 51 | - keras-base=2.2.4=py36_0 52 | - keras-preprocessing=1.1.0=py_1 53 | - kiwisolver=1.1.0=py36ha925a31_0 54 | - lazy-object-proxy=1.4.1=py36he774522_0 55 | - libmklml=2019.0.5=0 56 | - libpng=1.6.37=h2a8f88b_0 57 | - libprotobuf=3.8.0=h7bd577a_0 58 | - libsodium=1.0.16=h9d3ae62_0 59 | - m2w64-bwidget=1.9.10=2 60 | - m2w64-bzip2=1.0.6=6 61 | - m2w64-expat=2.1.1=2 62 | - m2w64-fftw=3.3.4=6 63 | - m2w64-flac=1.3.1=3 64 | - m2w64-gcc-libgfortran=5.3.0=6 65 | - m2w64-gcc-libs=5.3.0=7 66 | - m2w64-gcc-libs-core=5.3.0=7 67 | - m2w64-gettext=0.19.7=2 68 | - m2w64-gmp=6.1.0=2 69 | - m2w64-libiconv=1.14=6 70 | - m2w64-libjpeg-turbo=1.4.2=3 71 | - m2w64-libogg=1.3.2=3 72 | - m2w64-libpng=1.6.21=2 73 | - m2w64-libsndfile=1.0.26=2 74 | - m2w64-libtiff=4.0.6=2 75 | - m2w64-libvorbis=1.3.5=2 76 | - 
m2w64-libwinpthread-git=5.0.0.4634.697f757=2 77 | - m2w64-mpfr=3.1.4=4 78 | - m2w64-openblas=0.2.19=1 79 | - m2w64-pcre=8.38=2 80 | - m2w64-speex=1.2rc2=3 81 | - m2w64-speexdsp=1.2rc3=3 82 | - m2w64-tcl=8.6.5=3 83 | - m2w64-tk=8.6.5=3 84 | - m2w64-tktable=2.10=5 85 | - m2w64-wineditline=2.101=5 86 | - m2w64-xz=5.2.2=2 87 | - m2w64-zlib=1.2.8=10 88 | - markupsafe=1.1.1=py36he774522_0 89 | - matplotlib=3.1.0=py36hc8f65d3_0 90 | - mccabe=0.6.1=py36_1 91 | - mistune=0.8.4=py36he774522_0 92 | - mkl=2018.0.3=1 93 | - mkl_fft=1.0.6=py36hdbbee80_0 94 | - mkl_random=1.0.1=py36h77b88f5_1 95 | - mock=3.0.5=py36_0 96 | - msys2-conda-epoch=20160418=1 97 | - nbconvert=5.5.0=py_0 98 | - nbformat=4.4.0=py36h3a5bc1b_0 99 | - notebook=5.6.0=py36_0 100 | - numpy=1.15.1=py36hc27ee41_0 101 | - numpy-base=1.15.1=py36h8128ebf_0 102 | - openssl=1.1.1c=he774522_1 103 | - pandas=0.24.2=py36ha925a31_0 104 | - pandoc=2.2.3.2=0 105 | - pandocfilters=1.4.2=py36_1 106 | - parso=0.4.0=py_0 107 | - patsy=0.5.1=py36_0 108 | - pickleshare=0.7.5=py36_0 109 | - pip=19.1.1=py36_0 110 | - prometheus_client=0.6.0=py36_0 111 | - prompt_toolkit=2.0.9=py36_0 112 | - protobuf=3.8.0=py36h33f27b4_0 113 | - psutil=5.6.3=py36he774522_0 114 | - pycodestyle=2.5.0=py36_0 115 | - pycparser=2.19=py36_0 116 | - pygments=2.4.0=py_0 117 | - pylint=1.9.2=py36_0 118 | - pyopenssl=19.0.0=py36_0 119 | - pyparsing=2.4.0=py_0 120 | - pyqt=5.9.2=py36h6538335_2 121 | - pyreadline=2.1=py36_1 122 | - pyrsistent=0.14.11=py36he774522_0 123 | - pysocks=1.7.0=py36_0 124 | - python=3.6.8=h9f7ef89_7 125 | - python-dateutil=2.8.0=py36_0 126 | - pytz=2019.1=py_0 127 | - pywinpty=0.5.5=py36_1000 128 | - pyyaml=5.1=py36he774522_0 129 | - pyzmq=18.0.0=py36ha925a31_0 130 | - qt=5.9.7=vc14h73c81de_0 131 | - r-assertthat=0.2.1=r36h6115d3f_0 132 | - r-base=3.6.0=hf18239d_0 133 | - r-bh=1.69.0_1=r36h6115d3f_0 134 | - r-bit=1.1_14=r36h6115d3f_0 135 | - r-bit64=0.9_7=r36h6115d3f_0 136 | - r-blob=1.1.1=r36h6115d3f_0 137 | - r-cli=1.1.0=r36h6115d3f_0 138 | - r-crayon=1.3.4=r36h6115d3f_0 139 | - r-dbi=1.0.0=r36h6115d3f_0 140 | - r-dbplyr=1.4.0=r36h6115d3f_0 141 | - r-digest=0.6.18=r36h6115d3f_0 142 | - r-dplyr=0.8.0.1=r36h6115d3f_0 143 | - r-fansi=0.4.0=r36h6115d3f_0 144 | - r-glue=1.3.1=r36h6115d3f_0 145 | - r-magrittr=1.5=r36h6115d3f_4 146 | - r-memoise=1.1.0=r36h6115d3f_0 147 | - r-pillar=1.3.1=r36h6115d3f_0 148 | - r-pkgconfig=2.0.2=r36h6115d3f_0 149 | - r-plogr=0.2.0=r36h6115d3f_0 150 | - r-prettyunits=1.0.2=r36h6115d3f_0 151 | - r-purrr=0.3.2=r36h6115d3f_0 152 | - r-r6=2.4.0=r36h6115d3f_0 153 | - r-rcpp=1.0.1=r36h6115d3f_0 154 | - r-rlang=0.3.4=r36h6115d3f_0 155 | - r-rsqlite=2.1.1=r36h6115d3f_0 156 | - r-tibble=2.1.1=r36h6115d3f_0 157 | - r-tidyselect=0.2.5=r36h6115d3f_0 158 | - r-utf8=1.1.4=r36h6115d3f_0 159 | - requests=2.21.0=py36_0 160 | - rope=0.14.0=py_0 161 | - rpy2=2.9.4=py36r36h39e3cac_0 162 | - scikit-learn=0.19.1=py36hae9bb9f_0 163 | - scipy=1.0.1=py36hce232c7_0 164 | - send2trash=1.5.0=py36_0 165 | - setuptools=41.0.1=py36_0 166 | - sip=4.19.8=py36h6538335_0 167 | - six=1.12.0=py36_0 168 | - snowballstemmer=1.2.1=py36h763602f_0 169 | - sphinx=1.6.3=py36h9bb690b_0 170 | - sphinxcontrib=1.0=py36_1 171 | - sphinxcontrib-websupport=1.1.0=py36_1 172 | - sqlite=3.28.0=he774522_0 173 | - statsmodels=0.10.1=py36h8c2d366_0 174 | - tbb=2019.4=h74a9793_0 175 | - tbb4py=2019.4=py36h74a9793_0 176 | - tensorboard=1.13.1=py36h33f27b4_0 177 | - tensorflow=1.13.1=mkl_py36hd212fbe_0 178 | - tensorflow-base=1.13.1=mkl_py36hcaf7020_0 179 | - tensorflow-estimator=1.13.0=py_0 
180 | - termcolor=1.1.0=py36_1 181 | - terminado=0.8.2=py36_0 182 | - testpath=0.4.2=py36_0 183 | - tornado=5.1.1=py36hfa6e2cd_0 184 | - traitlets=4.3.2=py36h096827d_0 185 | - typing=3.6.4=py36_0 186 | - urllib3=1.24.2=py36_0 187 | - vc=14.1=h0510ff6_4 188 | - vs2015_runtime=14.15.26706=h3a45250_4 189 | - wcwidth=0.1.7=py36h3d5aa90_0 190 | - webencodings=0.5.1=py36_1 191 | - werkzeug=0.15.4=py_0 192 | - wheel=0.33.4=py36_0 193 | - win_inet_pton=1.1.0=py36_0 194 | - wincertstore=0.2=py36h7fe50ca_0 195 | - winpty=0.4.3=4 196 | - wrapt=1.11.1=py36he774522_0 197 | - yaml=0.1.7=hc54c509_2 198 | - zeromq=4.3.1=h33f27b4_3 199 | - zlib=1.2.11=h62dcd97_3 200 | - pip: 201 | - adal==1.2.1 202 | - azure-batch==6.0.1 203 | - azure-common==1.1.20 204 | - azure-storage-blob==2.0.1 205 | - azure-storage-common==2.0.0 206 | - commonmark==0.9.0 207 | - future==0.17.1 208 | - isodate==0.6.0 209 | - markdown==2.6.11 210 | - msrest==0.6.6 211 | - msrestazure==0.6.0 212 | - oauthlib==3.0.1 213 | - pyjwt==1.7.1 214 | - recommonmark==0.5.0 215 | - requests-oauthlib==1.2.0 216 | - sphinx-markdown-tables==0.0.9 217 | - sphinx-rtd-theme==0.4.3 218 | 219 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/test/__init__.py -------------------------------------------------------------------------------- /test/dgp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/test/dgp/__init__.py -------------------------------------------------------------------------------- /test/dgp/factor_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Factor DGP 3 | """ 4 | 5 | import numpy as np 6 | 7 | def factor_dgp(N0,N1,T0,T1,K,R,F): 8 | ''' 9 | Factor DGP. Treatment effect is 0. 10 | 11 | Covariates: Values are drawn from N(0,1) and coefficients are drawn from an 12 | exponential(1) and then scaled by the max. 
13 | Factors and Loadings are from N(0,1) 14 | Errors are from N(0,1) 15 | 16 | :param N0: Number of control units 17 | :param N1: Number of treated units 18 | :param T0: Number of pre-treatment periods 19 | :param T1: Number of post-treatment periods 20 | :param K: Number of covariates that affect outcome 21 | :param R: Number of (noise) covariates that do not affect outcome 22 | :param F: Number of factors 23 | :returns: (X_control, X_treated, Y_pre_control, Y_pre_treated, Y_post_control, Y_post_treated, Loadings_control, Loadings_treated) 24 | ''' 25 | 26 | if K>0: 27 | # COVARIATE EFFECTS 28 | X_control = np.matrix(np.random.normal(0,1,((N0), K+R))) 29 | X_treated = np.matrix(np.random.normal(0,1,((N1), K+R))) 30 | 31 | b_cause = np.random.exponential(1,K) 32 | b_cause *= 1 / b_cause.max() 33 | 34 | beta = np.matrix(np.concatenate( ( b_cause , np.zeros(R)) ) ).T 35 | Xbeta_C = X_control.dot(beta) 36 | Xbeta_T = X_treated.dot(beta) 37 | else: 38 | X_control = np.empty((N0,0)) 39 | X_treated = np.empty((N1,0)) 40 | Xbeta_C = np.zeros((N0,1)) 41 | Xbeta_T = np.zeros((N1,1)) 42 | 43 | # FACTORS 44 | Loadings_control = np.matrix(np.random.normal(0,1,((N0), F))) 45 | Loadings_treated = np.matrix(np.random.normal(0,1,((N1), F))) 46 | 47 | Factors_pre = np.matrix(np.random.normal(0,1,((F), T0))) 48 | Factors_post = np.matrix(np.random.normal(0,1,((F), T1))) 49 | 50 | # RANDOM ERRORS 51 | Y_pre_err_control = np.matrix(np.random.normal(0, 1, ( N0, T0, ) )) 52 | Y_pre_err_treated = np.matrix(np.random.normal(0, 1, ( N1, T0, ) )) 53 | Y_post_err_control = np.matrix(np.random.normal(0, 1, ( N0, T1, ) )) 54 | Y_post_err_treated = np.matrix(np.random.normal(0, 1, ( N1, T1, ) )) 55 | 56 | # OUTCOMES 57 | Y_pre_control = np.tile(Xbeta_C, (1,T0)) + Loadings_control.dot(Factors_pre) + Y_pre_err_control 58 | Y_pre_treated = np.tile(Xbeta_T, (1,T0)) + Loadings_treated.dot(Factors_pre) + Y_pre_err_treated 59 | 60 | Y_post_control = np.tile(Xbeta_C, (1,T1)) + Loadings_control.dot(Factors_post) + Y_post_err_control 61 | Y_post_treated = np.tile(Xbeta_T, (1,T1)) + Loadings_treated.dot(Factors_post) + Y_post_err_treated 62 | 63 | return X_control, X_treated, Y_pre_control, Y_pre_treated, Y_post_control, Y_post_treated, Loadings_control, Loadings_treated 64 | 65 |
-------------------------------------------------------------------------------- /test/dgp/group_effects.py: -------------------------------------------------------------------------------- 1 | """ 2 | Group-effects DGP 3 | """ 4 | 5 | import numpy as np 6 | 7 | def ge_dgp(N0,N1,T0,T1,K,S,R,groups,group_scale,beta_scale,confounders_scale,model= "full"): 8 | """ 9 | From example-code.py 10 | """ 11 | 12 | # COVARIATE EFFECTS 13 | X_control = np.matrix(np.random.normal(0,1,((N0)*groups, K+S+R))) 14 | X_treated = np.matrix(np.random.normal(0,1,((N1)*groups, K+S+R))) 15 | 16 | # CAUSAL 17 | b_cause = np.random.exponential(1,K) 18 | b_cause *= beta_scale / b_cause.max() 19 | 20 | # CONFOUNDERS 21 | b_confound = np.random.exponential(1,S) 22 | b_confound *= confounders_scale / b_confound.max() 23 | 24 | beta_control = np.matrix(np.concatenate( ( b_cause ,b_confound, np.zeros(R)) ) ).T 25 | beta_treated = np.matrix(np.concatenate( ( b_cause ,np.zeros(S), np.zeros(R)) ) ).T 26 | 27 | # GROUP EFFECTS (hidden) 28 | 29 | Y_pre_group_effects = np.random.normal(0,group_scale,(groups,T0)) 30 | Y_pre_ge_control = Y_pre_group_effects[np.repeat(np.arange(groups),N0)] 31 | Y_pre_ge_treated = Y_pre_group_effects[np.repeat(np.arange(groups),N1)] 32 | 33 | Y_post_group_effects = np.random.normal(0,group_scale,(groups,T1)) 34 | Y_post_ge_control = Y_post_group_effects[np.repeat(np.arange(groups),N0)]
35 | Y_post_ge_treated = Y_post_group_effects[np.repeat(np.arange(groups),N1)] 36 | 37 | # RANDOM ERRORS 38 | Y_pre_err_control = np.matrix(np.random.random( ( N0*groups, T0, ) )) 39 | Y_pre_err_treated = np.matrix(np.random.random( ( N1*groups, T0, ) )) 40 | 41 | Y_post_err_control = np.matrix(np.random.random( ( N0*groups, T1, ) )) 42 | Y_post_err_treated = np.matrix(np.random.random( ( N1*groups, T1, ) )) 43 | 44 | # THE DATA GENERATING PROCESS 45 | 46 | if model == "full": 47 | """ 48 | In the full model, covariates (X) are correlated with the pre- and 49 | post-treatment outcomes, and the variance of those outcomes is lower 50 | within groups which span both treated and control units. 51 | """ 52 | Y_pre_control = X_control.dot(beta_control) + Y_pre_ge_control + Y_pre_err_control 53 | Y_pre_treated = X_treated.dot(beta_treated) + Y_pre_ge_treated + Y_pre_err_treated 54 | 55 | Y_post_control = X_control.dot(beta_control) + Y_post_ge_control + Y_post_err_control 56 | Y_post_treated = X_treated.dot(beta_treated) + Y_post_ge_treated + Y_post_err_treated 57 | 58 | elif model == "hidden": 59 | """ 60 | In the hidden model, outcomes are independent of the covariates, but 61 | the variance of the pre- and post-treatment outcomes is lower within 62 | groups which span both treated and control units. 63 | """ 64 | Y_pre_control = Y_pre_ge_control + Y_pre_err_control 65 | Y_pre_treated = Y_pre_ge_treated + Y_pre_err_treated 66 | 67 | Y_post_control = Y_post_ge_control + Y_post_err_control 68 | Y_post_treated = Y_post_ge_treated + Y_post_err_treated 69 | 70 | elif model == "null": 71 | """ 72 | Purely random data 73 | """ 74 | Y_pre_control = Y_pre_err_control 75 | Y_pre_treated = Y_pre_err_treated 76 | 77 | Y_post_control = Y_post_err_control 78 | Y_post_treated = Y_post_err_treated 79 | 80 | else: 81 | raise ValueError("Unknown model type: "+model) 82 | 83 | return X_control, X_treated, Y_pre_control, Y_pre_treated, Y_post_control, Y_post_treated 84 | 85 |
-------------------------------------------------------------------------------- /test/test.pyproj: -------------------------------------------------------------------------------- [Visual Studio Python project file; the XML markup was not preserved in extraction. The surviving values indicate: Configuration=Debug, SchemaVersion=2.0, ProjectGuid 01446f94-f552-4ee1-90dc-e93a0db22b4c, startup file test_estimation.py, project name and root namespace "test", and interpreter CondaEnv|CondaEnv|SparseSC_36.]
-------------------------------------------------------------------------------- /test/test_batchFile.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- 2 | # Programmer: Jason Thorpe 3 | # Date 1/25/2019 3:34:02 PM 4 | # Language: Python (.py) Version 2.7 or 3.5 5 | # Usage: 6 | # 7 | # Test all model types 8 | # 9 | # \SparseSC > python -m unittest test/test_fit.py 10 | # 11 | # Test a specific model type (e.g. "prospective-restricted"): 12 | #
"prospective-restricted"): 12 | # 13 | # \SpasrseSC > python -m unittest test.test_fit.TestFit.test_retrospective 14 | # 15 | # -------------------------------------------------------------------------------- 16 | 17 | from __future__ import print_function # for compatibility with python 2.7 18 | import numpy as np 19 | import sys 20 | import random 21 | import unittest 22 | import warnings 23 | from os.path import expanduser, join 24 | from scipy.optimize.linesearch import LineSearchWarning 25 | 26 | try: 27 | from SparseSC import fit 28 | except ImportError: 29 | raise RuntimeError("SparseSC is not installed. use 'pip install -e .' to install") 30 | 31 | 32 | class TestFit(unittest.TestCase): 33 | def setUp(self): 34 | 35 | random.seed(12345) 36 | np.random.seed(101101001) 37 | control_units = 50 38 | treated_units = 20 39 | features = 10 40 | targets = 5 41 | 42 | self.X = np.random.rand(control_units + treated_units, features) 43 | self.Y = np.random.rand(control_units + treated_units, targets) 44 | self.treated_units = np.arange(treated_units) 45 | 46 | @classmethod 47 | def run_test(cls, obj, model_type, verbose=False): 48 | if verbose: 49 | print("Calling fit with `model_type = '%s'`..." % (model_type,), end="") 50 | sys.stdout.flush() 51 | 52 | with warnings.catch_warnings(): 53 | warnings.filterwarnings("ignore", category=PendingDeprecationWarning) 54 | warnings.filterwarnings("ignore", category=LineSearchWarning) 55 | try: 56 | fit( 57 | X=obj.X, 58 | Y=obj.Y, 59 | model_type=model_type, 60 | treated_units=obj.treated_units 61 | if model_type 62 | in ("retrospective", "prospective", "prospective-restricted") 63 | else None, 64 | # KWARGS: 65 | print_path=False, 66 | progress=verbose, 67 | grid_length=5, 68 | min_iter=-1, 69 | tol=1, 70 | verbose=0, 71 | batchFile=join( 72 | expanduser("~"), "temp", "%s_batch_params.py" % model_type 73 | ), 74 | ) 75 | import pdb 76 | 77 | pdb.set_trace() 78 | if verbose: 79 | print("DONE") 80 | except LineSearchWarning: 81 | pass 82 | except PendingDeprecationWarning: 83 | pass 84 | except Exception as exc: 85 | print("Failed with %s: %s" % (exc.__class__.__name__, str(exc))) 86 | raise exc 87 | 88 | def test_retrospective(self): 89 | TestFit.run_test(self, "retrospective") 90 | 91 | def test_prospective(self): 92 | TestFit.run_test(self, "prospective") 93 | 94 | def test_prospective_restrictive(self): 95 | # Catch the LineSearchWarning silently, but allow others 96 | 97 | TestFit.run_test(self, "prospective-restricted") 98 | 99 | def test_full(self): 100 | TestFit.run_test(self, "full") 101 | 102 | 103 | if __name__ == "__main__": 104 | # t = TestFit() 105 | # t.setUp() 106 | # t.test_retrospective() 107 | unittest.main() 108 | -------------------------------------------------------------------------------- /test/test_estimation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests for model fitness 3 | """ 4 | 5 | import unittest 6 | import random 7 | import numpy as np 8 | import pandas as pd 9 | 10 | #import warnings 11 | #warnings.simplefilter("error") 12 | 13 | try: 14 | import SparseSC as SC 15 | except ImportError: 16 | raise RuntimeError("SparseSC is not installed. Use 'pip install -e .' or 'conda develop .' 
17 | from os.path import join, abspath, dirname 18 | from test.dgp.factor_model import factor_dgp 19 | 20 | # import matplotlib.pyplot as plt 21 | 22 | import sys 23 | sys.path.insert(0, join(dirname(abspath(__file__)), "..", "examples")) 24 | from example_graphs import * 25 | 26 | 27 | # pylint: disable=no-self-use, missing-docstring 28 | 29 | # here for lexical scoping 30 | command_line_options = {} 31 | 32 | class TestEstimationForErrors(unittest.TestCase): 33 | def setUp(self): 34 | 35 | random.seed(12345) 36 | np.random.seed(101101001) 37 | control_units = 50 38 | treated_units = 2 39 | N = control_units + treated_units 40 | T = 15 41 | K_X = 2 42 | 43 | self.Y = np.random.rand(N, T) 44 | self.X = np.random.rand(N, K_X) 45 | self.treated_units = np.arange(treated_units) 46 | self.unit_treatment_periods = np.full((N), np.nan) 47 | self.unit_treatment_periods[0] = 7 48 | self.unit_treatment_periods[1] = 8 49 | #self. 50 | #self.unit_treatment_periods[treated_name] = treatment_date_ms 51 | 52 | @classmethod 53 | def run_test(cls, obj, model_type="retrospective", frame_type="ndarray"): # frame_type options: "ndarray", "NDFrame", "timeindex" 54 | X = obj.X 55 | Y = obj.Y 56 | unit_treatment_periods = obj.unit_treatment_periods 57 | if frame_type=="NDFrame" or frame_type=="timeindex": 58 | X = pd.DataFrame(X) 59 | Y = pd.DataFrame(Y) 60 | if frame_type=="timeindex": 61 | t_index = pd.Index(np.datetime64('2000-01-01','D') + range(Y.shape[1])) 62 | unit_treatment_periods = pd.Series(np.datetime64('NaT'), index=Y.index) 63 | unit_treatment_periods[0] = t_index[7] 64 | unit_treatment_periods[1] = t_index[8] 65 | Y.columns = t_index 66 | 67 | SC.estimate_effects(covariates=X, outcomes=Y, model_type=model_type, unit_treatment_periods=unit_treatment_periods) 68 | 69 | def test_all(self): # RidgeCV may emit: RuntimeWarning: invalid value encountered in true_divide \n return (c / G_diag) ** 2, c 70 | for model_type in ["retrospective", "prospective", "prospective-restricted"]: 71 | for frame_type in ["ndarray", "NDFrame", "timeindex"]: 72 | TestEstimationForErrors.run_test(self, model_type, frame_type) 73 | 74 | class TestDGPs(unittest.TestCase): 75 | """ 76 | testing fixture 77 | """ 78 | @staticmethod 79 | def simple_summ(fit, Y): 80 | #print("V_pen=%s, W_pen=%s" % (fit.fitted_v_pen, fit.fitted_w_pen)) 81 | if fit.match_space_desc is not None: 82 | print(fit.match_space_desc) 83 | else: 84 | print("V=%s" % np.diag(fit.V)) 85 | print("Treated weights: sim=%s, uns=%s, sum=%s" % ( fit.sc_weights[0, 49], fit.sc_weights[0, 99], sum(fit.sc_weights[0, :]),)) 86 | print("Sim Con weights: sim=%s, uns=%s, sum=%s" % ( fit.sc_weights[1, 49], fit.sc_weights[1, 99], sum(fit.sc_weights[1, :]),)) 87 | print("Uns Con weights: sim=%s, uns=%s, sum=%s" % ( fit.sc_weights[51, 49], fit.sc_weights[51, 99], sum(fit.sc_weights[51, :]),)) 88 | Y_sc = fit.predict(Y) 89 | print("Treated diff: %s" % (Y - Y_sc)[0, :]) 90 | 91 | 92 | def testSimpleTrendDGP(self): 93 | """ 94 | No X, just Y; half the donors are great, other half are bad 95 | """ 96 | N1, N0_sim, N0_not = 1, 50, 50 97 | N0 = N0_sim + N0_not 98 | N = N1 + N0 99 | treated_units, control_units = range(N1), range(N1, N) 100 | T0, T1 = 5, 2 101 | T = T0 + T1 # unused 102 | proto_sim = np.array([1, 2, 3, 4, 5] + [6,7], ndmin=2) 103 | proto_not = np.array([0, 2, 4, 6, 8] + [10, 12], ndmin=2) 104 | te = 2 105 | proto_tr = proto_sim + np.hstack((np.zeros((1, T0)), np.full((1, T1), te))) 106 | Y1 = np.matmul(np.ones((N1, 1)), proto_tr) 107 | Y0_sim = np.matmul(np.ones((N0_sim, 1)), proto_sim)
108 | Y0_sim = Y0_sim + np.random.normal(0,0.1,Y0_sim.shape) 109 | #Y0_sim = Y0_sim + np.hstack((np.zeros((N0_sim,1)), 110 | # np.random.normal(0,0.1,(N0_sim,1)), 111 | # np.zeros((N0_sim,T-2)))) 112 | Y0_not = np.matmul(np.ones((N0_not, 1)), proto_not) 113 | Y0_not = Y0_not + np.random.normal(0,0.1,Y0_not.shape) 114 | Y = np.vstack((Y1, Y0_sim, Y0_not)) 115 | 116 | unit_treatment_periods = np.full((N), -1) 117 | unit_treatment_periods[0] = T0 118 | 119 | # Y += np.random.normal(0, 0.01, Y.shape) 120 | 121 | # OPTIMIZE OVER THE V_PEN'S 122 | # for v_pen, w_pen in [(1,1), (1,1e-10), (1e-10,1e-10), (1e-10,1), (None, None)]: # 123 | # print("\nv_pen=%s, w_pen=%s" % (v_pen, w_pen)) 124 | ret = SC.estimate_effects( 125 | Y, 126 | unit_treatment_periods, 127 | ret_CI=True, 128 | max_n_pl=200, 129 | fast = True, 130 | #stopping_rule=4, 131 | **command_line_options, 132 | ) 133 | TestDGPs.simple_summ(ret.fits[T0], Y) 134 | V_penalty = ret.fits[T0].fitted_v_pen 135 | 136 | Y_sc = ret.fits[T0].predict(Y)# [control_units, :] 137 | te_vec_est = (Y - Y_sc)[0, T0:]  # estimated effect vector for the treated unit over the post-period 138 | # weight_sums = np.sum(ret.fit.sc_weights, axis=1) 139 | 140 | # print(ret.fit.scores) 141 | p_value = ret.p_value 142 | #print("p-value: %s" % p_value) 143 | #print( ret.CI) 144 | #print(np.diag(ret.fit.V)) 145 | #import pdb; pdb.set_trace() 146 | # print(ret) 147 | assert te in ret.CI, "Confidence interval does not include the true effect" 148 | assert p_value is not None 149 | assert p_value < 0.1, "P-value is larger than expected" 150 | 151 | # [sc_raw, sc_diff] = ind_sc_plots(Y[0, :], Y_sc[0, :], T0, ind_ci=ret.ind_CI) 152 | # plt.figure("sc_raw") 153 | # plt.title("Unit 0") 154 | # ### SHOW() blocks!!!! 155 | # # plt.show() 156 | # plt.figure("sc_diff") 157 | # plt.title("Unit 0") 158 | # # plt.show() 159 | # [te] = te_plot(ret) 160 | # plt.figure("te") 161 | # plt.title("Average Treatment Effect") 162 | # # plt.show() 163 | 164 | def testFactorDGP(self): 165 | """ 166 | factor DGP based test 167 | """ 168 | N1, N0 = 2, 100 169 | treated_units = [0, 1] 170 | T0, T1 = 20, 10 171 | K, R, F = 5, 5, 5 172 | ( 173 | Cov_control, 174 | Cov_treated, 175 | Out_pre_control, 176 | Out_pre_treated, 177 | Out_post_control, 178 | Out_post_treated, 179 | _, 180 | _, 181 | ) = factor_dgp(N0, N1, T0, T1, K, R, F) 182 | 183 | Cov = np.vstack((Cov_treated, Cov_control)) 184 | Out_pre = np.vstack((Out_pre_treated, Out_pre_control)) 185 | Out_post = np.vstack((Out_post_treated, Out_post_control)) 186 | 187 | SC.estimate_effects( 188 | Out_pre, 189 | Out_post, 190 | treated_units, 191 | Cov, 192 | **command_line_options, 193 | ) 194 | 195 | # print(fit_res) 196 | # est_res = SC.estimate_effects( 197 | # Cov, Out_pre, Out_post, treated_units, V_penalty=0, W_penalty=0.001 198 | # ) 199 | # print(est_res) 200 | 201 | 202 | if __name__ == "__main__": 203 | import argparse 204 | 205 | parser = argparse.ArgumentParser(prog="PROG", allow_abbrev=False) 206 | parser.add_argument( 207 | "--constrain", choices=["orthant", "simplex"], default="simplex" 208 | ) 209 | args = parser.parse_args() 210 | command_line_options.update(vars(args)) 211 | 212 | random.seed(12345) 213 | np.random.seed(10101) 214 | 215 | t = TestEstimationForErrors() 216 | t.setUp() 217 | t.test_all() 218 | # unittest.main() 219 |
-------------------------------------------------------------------------------- /test/test_fit_batch.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------------------
2 | # Programmer: Jason Thorpe 3 | # Date 1/25/2019 3:34:02 PM 4 | # Language: Python (.py) Version 2.7 or 3.5 5 | # Usage: 6 | # 7 | # Test all model types 8 | # 9 | # \SparseSC > python -m unittest test/test_fit.py 10 | # 11 | # Test a specific model type (e.g. "prospective-restricted"): 12 | # 13 | # \SparseSC > python -m unittest test.test_fit.TestFit.test_retrospective 14 | # 15 | # -------------------------------------------------------------------------------- 16 | 17 | from __future__ import print_function # for compatibility with python 2.7 18 | import numpy as np 19 | import sys, random 20 | import unittest 21 | import warnings 22 | from scipy.optimize.linesearch import LineSearchWarning 23 | 24 | try: 25 | from SparseSC import fit 26 | except ImportError: 27 | raise RuntimeError("SparseSC is not installed. use 'pip install -e .' to install") 28 | 29 | 30 | class TestFit(unittest.TestCase): 31 | def setUp(self): 32 | 33 | random.seed(12345) 34 | np.random.seed(101101001) 35 | control_units = 50 36 | treated_units = 20 37 | features = 10 38 | targets = 5 39 | 40 | self.X = np.random.rand(control_units + treated_units, features) 41 | self.Y = np.random.rand(control_units + treated_units, targets) 42 | self.treated_units = np.arange(treated_units) 43 | 44 | @classmethod 45 | def run_test(cls, obj, model_type, verbose=False): 46 | if verbose: 47 | print("Calling fit with `model_type = '%s'`..." % (model_type,), end="") 48 | sys.stdout.flush() 49 | 50 | with warnings.catch_warnings(): 51 | warnings.filterwarnings("ignore", category=PendingDeprecationWarning) 52 | warnings.filterwarnings("ignore", category=LineSearchWarning) 53 | try: 54 | Model1 = fit(  # baseline: plain in-process fit 55 | features=obj.X, 56 | targets=obj.Y, 57 | model_type=model_type, 58 | treated_units=obj.treated_units 59 | if model_type 60 | in ("retrospective", "prospective", "prospective-restricted") 61 | else None, 62 | # KWARGS: 63 | print_path=False, 64 | stopping_rule=1, 65 | progress=verbose, 66 | grid_length=5, 67 | min_iter=-1, 68 | tol=1, 69 | verbose=0, 70 | ) 71 | Model2 = fit(  # same fit dispatched through the batch daemon; the results are not compared here, so this only exercises the batch code path 72 | features=obj.X, 73 | targets=obj.Y, 74 | model_type=model_type, 75 | treated_units=obj.treated_units 76 | if model_type 77 | in ("retrospective", "prospective", "prospective-restricted") 78 | else None, 79 | # KWARGS: 80 | print_path=False, 81 | stopping_rule=1, 82 | progress=verbose, 83 | grid_length=5, 84 | min_iter=-1, 85 | tol=1, 86 | verbose=0, 87 | batch_client_config="sg_daemon", 88 | ) 89 | if verbose: 90 | print("DONE") 91 | except LineSearchWarning: 92 | pass 93 | except PendingDeprecationWarning: 94 | pass 95 | except Exception as exc: 96 | print( 97 | "Failed with %s: %s" 98 | % (exc.__class__.__name__, getattr(exc, "message", "<>")) 99 | ) 100 | raise exc 101 | 102 | def test_retrospective(self): 103 | TestFit.run_test(self, "retrospective") 104 | 105 | 106 | # -- def test_prospective(self): 107 | # -- TestFit.run_test(self, "prospective") 108 | # -- 109 | # -- def test_prospective_restrictive(self): 110 | # -- # Catch the LineSearchWarning silently, but allow others 111 | # -- 112 | # -- TestFit.run_test(self, "prospective-restricted") 113 | # -- 114 | # -- def test_full(self): 115 | # -- TestFit.run_test(self, "full") 116 | 117 | 118 | if __name__ == "__main__": 119 | # t = TestFit() 120 | # t.setUp() 121 | # t.test_retrospective() 122 | unittest.main() 123 |
-------------------------------------------------------------------------------- /test/test_normal.py:
-------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | 4 | try: 5 | import SparseSC 6 | except ImportError: 7 | raise RuntimeError("SparseSC is not installed. Use 'pip install -e .' or 'conda develop .' from repo root to install in dev mode") 8 | 9 | 10 | class TestNormalForErrors(unittest.TestCase): 11 | def test_SSC_DescrStat(self): 12 | mat = np.arange(20).reshape(4,5) 13 | ds_top = SparseSC.SSC_DescrStat.from_data(mat[:3,:]) 14 | ds_bottom = SparseSC.SSC_DescrStat.from_data(mat[3:,:]) 15 | ds_add = ds_top + ds_bottom 16 | ds_top.update(mat[3:,:]) 17 | ds_whole = SparseSC.SSC_DescrStat.from_data(mat) 18 | assert ds_add == ds_top, "summed stats should equal incrementally updated stats (ds_add == ds_top)" 19 | assert ds_add == ds_whole, "summed stats should equal stats computed on the full data (ds_add == ds_whole)" 20 | 21 | def test_DescrSet(self): 22 | Y_t = (np.arange(20)+1).reshape(4,5) 23 | Y_c = np.arange(20).reshape(4,5) 24 | Y_t_cf_c = Y_c 25 | ds = SparseSC.DescrSet.from_data(Y_t=Y_t, Y_t_cf_c=Y_t_cf_c) 26 | est = ds.calc_estimates() 27 | assert np.array_equal(est.att_est.effect, np.full((5), 1.0)), "ATT effect estimates do not equal the expected constant effect of 1.0" 28 | 29 | if __name__ == "__main__": 30 | t = TestNormalForErrors() 31 | t.test_SSC_DescrStat() 32 | t.test_DescrSet() 33 | 34 | #unittest.main() 35 | --------------------------------------------------------------------------------