├── .editorconfig
├── .gitignore
├── .pylintrc
├── CHANGELOG.md
├── LICENSE
├── README.md
├── SECURITY.md
├── SparseSC.sln
├── _config.yml
├── docs
├── api_ref.md
├── azure_batch.md
├── conf.py
├── dev_notes.md
├── estimate-effects.md
├── examples
│ ├── estimate_effects.md
│ ├── prospective_anomaly_detection.md
│ ├── retrospecitve_restricted_synthetic_controls.md
│ └── retrospecitve_synthetic_controls.md
├── fit.md
├── index.rst
├── model-types.md
├── overview.md
├── performance-notes.md
├── rtd-base.txt
└── rtd-requirements.txt
├── example-code.py
├── examples
├── AATest.ipynb
├── DifferentialTrends.ipynb
├── __init__.py
├── example_graphs.py
└── strip_magic.py
├── makefile
├── replication
├── .gitignore
├── figs
│ ├── sparsesc_fast_pe.pdf
│ ├── sparsesc_fast_xf2_pe.pdf
│ ├── sparsesc_fast_xf_pe.pdf
│ ├── sparsesc_full_pe.pdf
│ ├── sparsesc_full_xf2_pe.pdf
│ ├── sparsesc_full_xf_pe.pdf
│ ├── standard_flat55_pe.pdf
│ ├── standard_nested_pe.pdf
│ ├── standard_spfast_pe.pdf
│ ├── standard_spfast_xf2_pe.pdf
│ ├── standard_spfast_xf_pe.pdf
│ ├── standard_spfull_pe.pdf
│ ├── standard_spfull_xf2_pe.pdf
│ └── standard_spfull_xf_pe.pdf
├── repl2010.do
├── repl2010.log
├── sc2010.ipynb
├── smoking.dta
├── vmats
│ ├── fast_fit.txt
│ ├── full_fit.txt
│ ├── xf_fits_fast.txt
│ ├── xf_fits_fast2.txt
│ ├── xf_fits_full.txt
│ └── xf_fits_full2.txt
└── w_pen_switch.ipynb
├── setup.py
├── src
└── SparseSC
│ ├── .pylintrc
│ ├── SparseSC.pyproj
│ ├── __init__.py
│ ├── cli
│ ├── __init__.py
│ ├── daemon.py
│ ├── daemon_process.py
│ ├── scgrad.py
│ └── stt.py
│ ├── cross_validation.py
│ ├── estimate_effects.py
│ ├── fit.py
│ ├── fit_ct.py
│ ├── fit_fast.py
│ ├── fit_fold.py
│ ├── fit_loo.py
│ ├── optimizers
│ ├── __init__.py
│ ├── cd_line_search.py
│ └── simplex_step.py
│ ├── tensor.py
│ ├── utils
│ ├── AzureBatch
│ │ ├── __init__.py
│ │ ├── aggregate_results.py
│ │ ├── azure_batch_client.py
│ │ ├── build_batch_job.py
│ │ ├── constants.py
│ │ └── gradient_batch_client.py
│ ├── __init__.py
│ ├── batch_gradient.py
│ ├── descr_sets.py
│ ├── dist_summary.py
│ ├── local_grad_daemon.py
│ ├── match_space.py
│ ├── metrics_utils.py
│ ├── misc.py
│ ├── ols_model.py
│ ├── penalty_utils.py
│ ├── print_progress.py
│ ├── sub_matrix_inverse.py
│ └── warnings.py
│ └── weights.py
├── static
├── controls-only-pre-and-post.png
└── pre-only-controls-and-treated.png
└── test
├── AzureBatch
├── README.md
├── test_batch_aggregate.py
├── test_batch_build.py
└── test_batch_run.py
├── CausalImpact_test.R
├── SparseSC_36.yml
├── __init__.py
├── dgp
├── __init__.py
├── factor_model.py
└── group_effects.py
├── test.pyproj
├── test_batchFile.py
├── test_estimation.py
├── test_fit.py
├── test_fit_batch.py
├── test_normal.py
└── test_simulation.py
/.editorconfig:
--------------------------------------------------------------------------------
1 | # EditorConfig helps developers define and maintain consistent coding styles between different editors and IDEs. http://EditorConfig.org
2 |
3 | # top-most EditorConfig file
4 | root = true
5 |
6 | # 4 space indentation
7 | [*.py]
8 | indent_style = space
9 | indent_size = 4
10 | insert_final_newline = true
11 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # ignore testing data
2 | **/data/
3 | replication/*.pkl
4 |
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Sphinx documentation
55 | docs/_build/
56 |
57 | # Jupyter Notebook
58 | .ipynb_checkpoints
59 |
60 | # pyenv
61 | .python-version
62 |
63 | # Environments
64 | .env
65 | .venv
66 | env/
67 | venv/
68 | ENV/
69 | env.bak/
70 | venv.bak/
71 |
72 | # Spyder project settings
73 | .spyderproject
74 | .spyproject
75 |
76 | # mypy
77 | .mypy_cache/
78 |
79 | #Visual Studio
80 | .vs/
81 |
82 | .vscode/
83 |
84 | # own-generated content
85 | docs/SyntheticControlsReadme.*
86 | examples/*.pdf
87 | examples/*.html
88 | examples/DifferentialTrends.py
89 |
90 | #Jupyter Windows
91 | jupyter_launch.log
92 |
93 | tmp/
94 |
95 |
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | # Changelog
2 | All notable changes to this project will be documented in this file.
3 |
4 | The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
5 | and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
6 |
7 |
8 | ## [Unreleased](https://github.com/Microsoft/SparseSC/compare/v0.2.0...master)
9 | ### Added
10 | - Added `fit_args` to `estimate_effect()` that can be used with `fit_fast()` with the default variable weight algorithm (`sklearn`'s [MTLassoCV](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.MultiTaskLassoCV.html)). Common usages would include increasing `max_iter` (if you want to improve convergence) or `n_jobs` (if you want to run in parallel).
11 | - Separated CV folds from Cross-fitting folds in `estimate_effects()`. `cv_folds` controls the amount of additional estimation done on the control units (this used to be controlled by `max_n_pl`, but that parameter now only governs the amount of post-estimation processing that is done). This will do `cv_folds` extra estimations per treatment time period, though if `=1` then no extra work will be done (but control residuals might be biased toward 0).
12 | - Added additional option `Y_col_block_size` to `MTLassoCV_MatchSpace_factory` to estimate `V` on block-averages of `Y` (e.g. taking a 150 cols down to 5 by doing averages over 30 cols at a time).
13 | - Added `se_factor` to `MTLassoCV_MatchSpace_factory` to use a different penalty than the MSE min.
14 | - For large data, approximate the outcomes using a normal distribution (`DescrSet`), and allow for calculating estimates.
15 |
16 | ## 0.2.0 - 2020-05-06
17 | ### Added
18 | - Added tools too to use `fit_fast()` with large datasets. This includes the `sample_frac` option to `MTLassoCV_MatchSpace_factory()` to estimate the match space on a subset of the observations. It also includes `fit_fast()` option `avoid_NxN_mats` which will avoid making large matrices (at the expense of only returning the Synthetic control `targets` and `targets_aux` and not the full weight matrix)
19 | - Added logging in `fit_fast` via the `verbose` numerical option. This can help identify out-of-memory errors.
20 | - Added a pseudo-Doubly robust match space maker `D_LassoCV_MatchSpace_factory`. It apportions some of the normalized variable V weight to those variables that are good predictors of treatment. This should only be done if there are many treated units so that one can reasonably model this relationship.
21 | - Switched using standardized Azure Batch config library
22 |
23 |
24 | ## 0.1.0 - 2019-07-25
25 | Initial release.
26 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | MIT License
3 |
4 | Copyright (c) Microsoft Corporation. All rights reserved.
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [Sparse Synthetic Controls](https://sparsesc.readthedocs.io/en/latest/)
2 |
3 | SparseSC is a package that implements an ML-enhanced version of Synthetic Control Methodologies, which introduces penalties on the both feature and unit weights. Specifically it optimizes these two equations:
4 |
5 | $$\gamma_0 = \left\Vert Y_T - W\cdot Y_C \right\Vert_F^2 + \lambda_V \left\Vert V \right\Vert_1 $$
6 |
7 | $$\gamma_1 = \left\Vert X_T - W\cdot X_C \right\Vert_V + \lambda_W \left\Vert W - J/c \right\Vert_F^2 $$
8 |
9 | by optimizing $\lambda_V$ and $\lambda_W$ using cross validation within the control units in the usual way, and where:
10 |
11 | - $X_T$ and $X_C$ are matrices of features (covariates and/or pre-treatment outcomes) for the treated and control units, respectively
12 |
13 | - $Y_T$ and $Y_C$ are matrices of post-treatment outcomes on the treated and control units, respectively
14 |
15 | - $W$
16 | is a matrix of weights such that $W \cdot (X_C \left |Y_C \right)$ forms a matrix of synthetic controls for all units
17 |
18 | - $V$ is a diagonal matrix of weights applied to the covariates / pre-treatment outcomes.
19 |
20 | Note that all matrices are formatted with one row per unit and one column per feature / outcome. Breaking down the two main equations, we have:
21 |
22 | - $\left\Vert Y_T - W\cdot Y_C \right\Vert_F^2$ is the out-of-sample squared prediction error (i.e. the squared Frobenius Norm), measured within the control units under cross validation
23 |
24 | - $\left\Vert V \right\Vert_1$ represents how much the model depends on the features
25 |
26 | - $\left\Vert X_T - W\cdot X_C \right\Vert_V$ is the difference between synthetic and observed units in the feature space weighted by $V$. Specifically, $\left\Vert A \right\Vert_V = AVA^T$
27 |
28 | - $\left\Vert W - J/c \right\Vert_F^2$ is the difference between optimistic weights and simple ($1/N_c$) weighted averages of the control units.
29 |
30 |
31 | Typically this is used to estimate causal effects from binary treatments on observational panel (longitudinal) data. The functions `fit()` and `fit_fast()` provide basic fitting of the model. If you are estimating treatment effects, fitting and diagnostic information can be done via `estimate_effects()`.
32 |
33 | Though the fitting methods do not require such structure, the typical setup is where we have panel data of an outcome variable `Y` for `T` time periods for `N` observation units (customer, computers, etc.). We may additionally have some baseline characteristics `X` about the units. In the treatment effect setting, we will also have a discrete change in treatment status (e.g. some policy change) at time, `T0`, for a select group of units. When there is treatment, we can think of the pre-treatment data as [`X`, `Y_pre`] and post-treatment data as [`Y_post`].
34 |
35 | ```py
36 | import SparseSC
37 |
38 | # Fit the model:
39 | treatment_unit_size = np.full((N), np.NaN)
40 | treatment_unit_size[treated_unit_idx] = T0
41 | fitted_estimates = SparseSC.estimate_effects(Y,unit_treatment_periods,...)
42 |
43 | # Print summary of the model including effect size estimates,
44 | # p-values, and confidence intervals:
45 | print(fitted_estimates)
46 |
47 | # Extract model attributes:
48 | fitted_estimates.pl_res_post.avg_joint_effect.p_value
49 | fitted_estimates.pl_res_post.avg_joint_effect.CI
50 |
51 | # access the fitted Synthetic Controls model:
52 | fitted_model = fitted_estimates.fit
53 | ```
54 |
55 | See [the docs](https://sparsesc.readthedocs.io/en/latest/) for more details
56 |
57 | ## Overview
58 |
59 | See [here](https://en.wikipedia.org/wiki/Synthetic_control_method) for more info on Synthetic Controls. In essence, it is a type of matching estimator. For each unit it will find a weighted average of untreated units that is similar on key pre-treatment data. The goal of Synthetic controls is to find out which variables are important to match on (the `V` matrix) and then, given those, to find a vector of per-unit weights that combine the control units into its synthetic control. The synthetic control acts as the counterfactual for a unit, and the estimate of a treatment effect is the difference between the observed outcome in the post-treatment period and the synthetic control's outcome.
60 |
61 | SparseSC makes a number of changes to Synthetic Controls. It uses regularization and feature learning to avoid overfitting, ensure uniqueness of the solution, automate researcher decisions, and allow for estimation on large datasets. See the docs for more details.
62 |
63 | The main choices to make are:
64 | 1. The solution structure
65 | 2. The model-type
66 |
67 | ### SparseSC Solution Structure
68 | The first choice is whether to calculate all of the high-level parameters (`V`, its regularization parameter, and the regularization parameters for the weights) on the main matching objective or whether to get approximate/fast estimates of them using non-matching formulations. The options are:
69 | * Full joint (done by `fit()`): We optimize over `v_pen`, `w_pen` and `V`, so that the resulting SC for controls have smallest squared prediction error on `Y_post`.
70 | * Separate (done by `fit_fast()`): We note that we can efficiently estimate `w_pen` on main matching objective, since, given `V`, we can reformulate the finding problem into a Ridge Regression and use efficient LOO cross-validation (e.g. `RidgeCV`) to estimate `w_pen`. We will estimate `V` using an alternative, non-matching objective (such as a `MultiTaskLasso` of using `X,Y_pre` to predict `Y_post`). This setup also allows for feature generation to select the match space. There are two variants depending on how we handle `v_pen`:
71 | * Mixed. Choose `v_pen` based on the resulting down-stream main matching objective.
72 | * Full separate: Choose `v_pen` base on approximate objective (e.g., `MultiTaskLassoCV`).
73 |
74 | The Fully Separate solution is fast and often quite good so we recommend starting there, and if need be, advancing to the Mixed and then Fully Joint optimizations.
75 |
76 | ### Model types
77 | There are two main model-types (corresponding to different cuts of the data) that can be used to estimate treatment effects.
78 | 1. Retrospective: The goal is to minimize squared prediction error of the control units on `Y_post` and the full-pre history of the outcome is used as features in fitting. This is the default and was used in the descriptive elements above.
79 | 2. Prospective: We make an artificial split in time before any treatment actually happens (`Y_pre=[Y_train,Y_test]`). The goal is to minimize squared prediction error of all units on `Y_test` and `Y_train` for all units is used as features in fitting.
80 |
81 | Given the same amount of features, the two will only differ when there are a non-trivial number of treated units. In this case the prospective model may provide lower prediction error for the treated units, though at the cost of less pre-history data used for fitting. When there are a trivial number of treated units, the retrospective design will be the most efficient.
82 |
83 | See more details about these and two additional model types (Prospective-restrictive, and full) at the docs.
84 |
85 | ## Fitting a synthetic control model
86 |
87 | ### Documentation
88 |
89 | You can read these online at [Read the
90 | Docs](https://sparsesc.readthedocs.io/en/latest/). See there for:
91 | * Custom Donor Pools
92 | * Parallelization
93 | * Constraining the `V` matrix to be in the unit simplex
94 | * Performance Notes for `fit()`
95 | * Additional Performance Considerations for `fit()`
96 | * Full parameter listings
97 |
98 | To build the documentation see `docs/dev_notes.md`.
99 |
100 | ## Citation
101 | Brian Quistorff, Matt Goldman, and Jason Thorpe (2020) [Sparse Synthetic Controls: Unit-Level Counterfactuals from High-Dimensional Data](https://drive.google.com/file/d/1lfH1CK_JZpc0ou7hP60FhQpkeoXhR6fC/view?usp=sharing), Microsoft Journal of Applied Research, 14, pp.155-170.
102 |
103 | ## Installation
104 |
105 | The easiest way to install `SparseSC` is with:
106 |
107 | ```sh
108 | pip install git+https://github.com/microsoft/SparseSC.git
109 | ```
110 |
111 | Additional commands to run tests and examples are in the makefile.
112 |
113 | ## Contributing
114 |
115 | This project welcomes contributions and suggestions. Most contributions
116 | require you to agree to a Contributor License Agreement (CLA) declaring
117 | that you have the right to, and actually do, grant us the rights to use
118 | your contribution. For details, visit https://cla.microsoft.com.
119 |
120 | When you submit a pull request, a CLA-bot will automatically determine
121 | whether you need to provide a CLA and decorate the PR appropriately (e.g.,
122 | label, comment). Simply follow the instructions provided by the bot. You
123 | will only need to do this once across all repos using our CLA.
124 |
125 | This project has adopted the [Microsoft Open Source Code of
126 | Conduct](https://opensource.microsoft.com/codeofconduct/). For more
127 | information see the [Code of Conduct
128 | FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact
129 | [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional
130 | questions or comments.
131 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
40 |
41 |
42 |
--------------------------------------------------------------------------------
/SparseSC.sln:
--------------------------------------------------------------------------------
1 |
2 | Microsoft Visual Studio Solution File, Format Version 12.00
3 | # Visual Studio Version 16
4 | VisualStudioVersion = 16.0.29613.14
5 | MinimumVisualStudioVersion = 10.0.40219.1
6 | Project("{888888A0-9F3D-457C-B088-3A5042F75D52}") = "SparseSC", "src\SparseSC\SparseSC.pyproj", "{3FDC664A-C1C8-47F0-8D77-8E0679E53C82}"
7 | EndProject
8 | Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Misc", "Misc", "{AA17F266-A75A-4FDE-A774-168EADE9A32E}"
9 | ProjectSection(SolutionItems) = preProject
10 | docs\api_ref.md = docs\api_ref.md
11 | docs\azure_batch.md = docs\azure_batch.md
12 | CHANGELOG.md = CHANGELOG.md
13 | docs\dev_notes.md = docs\dev_notes.md
14 | docs\estimate-effects.md = docs\estimate-effects.md
15 | example-code.py = example-code.py
16 | examples\example_graphs.py = examples\example_graphs.py
17 | docs\fit.md = docs\fit.md
18 | examples\fit_poc.py = examples\fit_poc.py
19 | docs\model-types.md = docs\model-types.md
20 | docs\overview.md = docs\overview.md
21 | docs\performance-notes.md = docs\performance-notes.md
22 | README.md = README.md
23 | EndProjectSection
24 | EndProject
25 | Project("{888888A0-9F3D-457C-B088-3A5042F75D52}") = "test", "test\test.pyproj", "{01446F94-F552-4EE1-90DC-E93A0DB22B4C}"
26 | EndProject
27 | Global
28 | GlobalSection(SolutionConfigurationPlatforms) = preSolution
29 | Debug|Any CPU = Debug|Any CPU
30 | Release|Any CPU = Release|Any CPU
31 | EndGlobalSection
32 | GlobalSection(ProjectConfigurationPlatforms) = postSolution
33 | {3FDC664A-C1C8-47F0-8D77-8E0679E53C82}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
34 | {3FDC664A-C1C8-47F0-8D77-8E0679E53C82}.Release|Any CPU.ActiveCfg = Release|Any CPU
35 | {01446F94-F552-4EE1-90DC-E93A0DB22B4C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
36 | {01446F94-F552-4EE1-90DC-E93A0DB22B4C}.Release|Any CPU.ActiveCfg = Release|Any CPU
37 | EndGlobalSection
38 | GlobalSection(SolutionProperties) = preSolution
39 | HideSolutionNode = FALSE
40 | EndGlobalSection
41 | GlobalSection(ExtensibilityGlobals) = postSolution
42 | SolutionGuid = {9EE1A574-A448-4146-B4A8-998CAEAFDD85}
43 | EndGlobalSection
44 | EndGlobal
45 |
--------------------------------------------------------------------------------
/_config.yml:
--------------------------------------------------------------------------------
1 | theme: jekyll-theme-slate
--------------------------------------------------------------------------------
/docs/api_ref.md:
--------------------------------------------------------------------------------
1 | # API Reference
2 |
3 |
4 | ## Estimate Treatment Effects
5 | ```eval_rst
6 | .. autofunction:: SparseSC.estimate_effects.estimate_effects
7 | :noindex:
8 |
9 | ```
10 |
11 | ```eval_rst
12 | .. autoclass:: SparseSC.estimate_effects.SparseSCEstResults
13 | :members:
14 | :show-inheritance:
15 | :noindex:
16 |
17 | ```
18 |
19 | ```eval_rst
20 | .. autoclass:: SparseSC.utils.metrics_utils.PlaceboResults
21 | :members:
22 | :show-inheritance:
23 | :noindex:
24 |
25 | ```
26 |
27 | ```eval_rst
28 | .. autoclass:: SparseSC.utils.metrics_utils.EstResultCI
29 | :members:
30 | :show-inheritance:
31 | :noindex:
32 |
33 | ```
34 |
35 | ```eval_rst
36 | .. autoclass:: SparseSC.utils.metrics_utils.CI_int
37 | :members:
38 | :show-inheritance:
39 | :noindex:
40 |
41 | ```
42 |
43 | ```eval_rst
44 | .. autoclass:: SparseSC.utils.metrics_utils.AA_results
45 | :members:
46 | :show-inheritance:
47 | :noindex:
48 |
49 | ```
50 |
51 |
52 | ## Fit a Synthetic Controls Model (Slow, Joint)
53 | ```eval_rst
54 | .. autofunction:: SparseSC.fit.fit
55 | :noindex:
56 |
57 | ```
58 |
59 | ```eval_rst
60 | .. autoclass:: SparseSC.fit.SparseSCFit
61 | :members:
62 | :show-inheritance:
63 | :noindex:
64 |
65 | ```
66 |
67 | ## Fit a Synthetic Controls Model (Fast, Separate)
68 | ```eval_rst
69 | .. autofunction:: SparseSC.fit_fast.fit_fast
70 | :noindex:
71 |
72 | ```
73 |
74 | ```eval_rst
75 | .. autofunction:: SparseSC.utils.match_space.Fixed_V_factory
76 | :noindex:
77 |
78 | ```
79 |
80 | ```eval_rst
81 | .. autofunction:: SparseSC.utils.match_space.MTLassoCV_MatchSpace_factory
82 | :noindex:
83 |
84 | ```
85 |
86 | ```eval_rst
87 | .. autofunction:: SparseSC.utils.match_space.MTLSTMMixed_MatchSpace_factory
88 | :noindex:
89 |
90 | ```
91 |
92 | ```eval_rst
93 | .. autofunction:: SparseSC.utils.match_space.MTLassoMixed_MatchSpace_factory
94 | :noindex:
95 |
96 | ```
97 |
--------------------------------------------------------------------------------
/docs/azure_batch.md:
--------------------------------------------------------------------------------
1 | # Running Jobs in Parallel with Azure Batch
2 |
3 | Fitting a Sparse Synthetic Controls model can result in a very long running
4 | time. Fortunately much of the work can be done in parallel and executed in
5 | the cloud, and the SparseSC package comes with an Azure Batch utility which
6 | can be used to fit a Synthetic controls model using Azure Batch. There are
7 | code examples for Windows CMD, Bash and Powershell.
8 |
9 | ## Setup
10 |
11 | Running SparseSC with Azure Batch requires the `super_batch` library which
12 | can be installed with:
13 |
14 | ```bash
15 | pip install git+https://github.com/jdthorpe/batch-config.git
16 | ```
17 |
18 | Also note that this module has only been tested with Python 3.7.
19 | You will also need `pyyaml` and `psutil`.
20 |
21 | ### Create the Required Azure resources
22 |
23 | Running SparseSC with Azure Batch requires an Azure account and a handful of
24 | resources and credentials. These can be set up by following along with
25 | [section 4 of the super-batch README](https://github.com/jdthorpe/batch-config#step-4-create-the-required-azure-resources).
26 |
27 | ### Prepare parameters for the Batch Job
28 |
29 | The parameters required to run a batch job can be created using `fit()` by
30 | providing a directory where the parameters files should be stored:
31 |
32 | ```python
33 | from SparseSC import fit
34 | batch_dir = "/path/to/my/batch/data/"
35 |
36 | # initialize the batch parameters in the directory `batch_dir`
37 | fit(x, y, ... , batchDir = batch_dir)
38 | ```
39 |
40 | ### Executing the Batch Job
41 |
42 | In the following Python script, a Batch configuration is created and the
43 | batch job is executed with Azure Batch. Note that in the following script,
44 | the various Batch Account and Storage Account credentials are taken from
45 | system environment variables, as in the [super-batch readme](https://github.com/jdthorpe/batch-config#step-4-create-the-required-azure-resources).
46 |
47 | ```python
48 | import os
49 | from datetime import datetime
50 | from super_batch import Client
51 | from SparseSC.utils.AzureBatch import (
52 | DOCKER_IMAGE_NAME,
53 | create_job,
54 | )
55 | # Batch job names must be unique, and a timestamp is one way to keep it unique across runs
56 | timestamp = datetime.utcnow().strftime("%H%M%S")
57 | batch_dir = "/path/to/my/batch/data/"
58 |
59 | batch_client = Client(
60 | # Name of the VM pool
61 | POOL_ID= name,
62 | # number of standard nodes
63 | POOL_NODE_COUNT=5,
64 | # number of low priority nodes
65 | POOL_LOW_PRIORITY_NODE_COUNT=5,
66 | # VM type
67 | POOL_VM_SIZE= "STANDARD_A1_v2",
68 | # Job ID. Note that this must be unique.
69 | JOB_ID= name + timestamp,
70 | # Name of the storage container for storing parameters and results
71 | CONTAINER_NAME= name,
72 | # local directory with the parameters, and where the results will go
73 | BATCH_DIRECTORY= batch_dir,
74 | # Keep the pool around after the run, which saves time when doing
75 | # multiple batch jobs, as it typically takes a few minutes to spin up a
76 | # pool of VMs. (Optional. Default = False)
77 | DELETE_POOL_WHEN_DONE=False,
78 | # Keeping the job details can be useful for debugging:
79 | # (Optional. Default = False)
80 | DELETE_JOB_WHEN_DONE=False
81 | )
82 |
83 | create_job(batch_client, batch_dir)
84 | # run the batch job
85 | batch_client.run()
86 |
87 | # aggregate the results into a fitted model instance
88 | fitted_model = aggregate_batch_results(batch_dir)
89 | ```
90 |
91 | ## Cleaning Up
92 |
93 | When you are done fitting your model with Azure Batch be sure to
94 | [clean up your Azure Resources](https://github.com/jdthorpe/batch-config#step-6-clean-up)
95 | in order to prevent unexpected charges on your Azure account.
96 |
97 | ## Solving
98 |
99 | The Azure batch will just vary one of the penalty parameters. You should therefore not specify the
100 | simplex constraint for the V matrix as then it will be missing one degree of freedom.
101 |
102 | ## FAQ
103 |
104 | 1. What if I get disconnected while the batch job is running?
105 |
106 | Once the pool and the job are created, they will keep running until the
107 | job completes, or you delete the resources. You can re-create the
108 | `batch_client` as in the example above and then reconnect to the job and
109 | download the results with:
110 |
111 | ```python
112 | batch_client.load_results()
113 | fitted_model = aggregate_batch_results(batch_dir)
114 | ```
115 |
116 | In fact, if you'd rather not wait for the job to complete, you can
117 | add the parameter `batch_client.run(... ,wait=False)` and the
118 | `run_batch_job` will return as soon as the job and pool configuration
119 | have been created in Azure.
120 |
121 | 1. `batch_client.run()` or `batch_client.load_results()` complain that the
122 | results are incomplete. What happened?
123 |
124 | Typically this means that one or more of the jobs failed, and a common
125 | reason for the job to fail is that the VM runs out of memory while
126 | running the batch job. Failed Jobs can be viewed in either the Azure
127 | Batch Explorer or the Azure Portal. The `POOL_VM_SIZE` use above
128 | ("STANDARD_A1_v2") is one of the smallest (and cheapest) VMs available
129 | on Azure. Upgrading to a VM with more memory can help in this
130 | situation.
131 |
132 | 1. Why does `aggregate_batch_results()` take so long?
133 |
134 | Each batch job runs a single gradient descent in V space using a subset
135 | (Cross Validation fold) of the data and with a single pair of penalty
136 | parameters, and returns the out-of-sample error for the held-out samples.
137 | `aggregate_batch_results()` very quickly aggregates these out of sample
138 | errors and chooses the optimal penalty parameters given the `choice`
139 | parameter provided to `fit()` or `aggregate_batch_results()`. Finally,
140 | with the selected parameters, a final gradient descent is run using the
141 |    full dataset, which will be larger and will take longer, as the rate
142 | limiting step
143 | ( [scipy.linalg.solve](https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.solve.html) )
144 | has a running time of
145 | [`O(N^3)`](https://stackoverflow.com/a/12665483/1519199). While it is
146 | possible to run this step in parallel as well, it hasn't yet been
147 | implemented.
148 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | #
4 | # SparseSC documentation build configuration file, created by
5 | # sphinx-quickstart on Thu Sep 27 14:53:55 2018.
6 | #
7 | # This file is execfile()d with the current directory set to its
8 | # containing dir.
9 | #
10 | # Note that not all possible configuration values are present in this
11 | # autogenerated file.
12 | #
13 | # All configuration values have a default; values that are commented out
14 | # serve to show the default.
15 |
16 | # If extensions (or modules to document with autodoc) are in another directory,
17 | # add these directories to sys.path here. If the directory is relative to the
18 | # documentation root, use os.path.abspath to make it absolute, like shown here.
19 | #
20 | import os
21 | import sys
22 | ##Allow MarkDown.
23 | ##Prerequisite. pip install recommonmark
24 | import recommonmark
25 | from recommonmark.transform import AutoStructify
26 |
27 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
28 |
29 | #Allow capturing output
30 | #from SparseSC.utils.misc import capture
31 |
32 |
33 | # -- General configuration ------------------------------------------------
34 |
35 | # If your documentation needs a minimal Sphinx version, state it here.
36 | #
37 | # needs_sphinx = '1.0'
38 |
39 | # Add any Sphinx extension module names here, as strings. They can be
40 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
41 | # ones.
42 | extensions = ['sphinx.ext.autodoc',
43 | 'sphinx.ext.mathjax',
44 | 'sphinx_markdown_tables']
45 |
46 | # Add any paths that contain templates here, relative to this directory.
47 | templates_path = ['_templates']
48 |
49 | # The suffix(es) of source filenames.
50 | # You can specify multiple suffix as a list of string:
51 | #
52 | source_suffix = ['.rst', '.md']
53 |
54 | # The master toctree document.
55 | master_doc = 'index'
56 |
57 | # General information about the project.
58 | project = 'SparseSC'
59 | copyright = '2018, Jason Thorpe, Brian Quistorff, Matt Goldman'
60 | author = 'Jason Thorpe, Brian Quistorff, Matt Goldman'
61 |
62 | # The version info for the project you're documenting, acts as replacement for
63 | # |version| and |release|, also used in various other places throughout the
64 | # built documents.
65 | #
66 | from SparseSC import __version__
67 | version = __version__
68 | # The full version, including alpha/beta/rc tags. (For now, keep the same)
69 | release = version
70 |
71 | # The language for content autogenerated by Sphinx. Refer to documentation
72 | # for a list of supported languages.
73 | #
74 | # This is also used if you do content translation via gettext catalogs.
75 | # Usually you set "language" from the command line for these cases.
76 | language = None
77 |
78 | html_copy_source=False
79 |
80 | # List of patterns, relative to source directory, that match files and
81 | # directories to ignore when looking for source files.
82 | # This patterns also effect to html_static_path and html_extra_path
83 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
84 |
85 | # The name of the Pygments (syntax highlighting) style to use.
86 | pygments_style = 'sphinx'
87 |
88 | # If true, `todo` and `todoList` produce output, else they produce nothing.
89 | todo_include_todos = False
90 |
91 | # Run sphinx-apidoc from within conf.py (via the 'builder-inited' event) rather
92 | # than as a separate process, so the docs also build on Read the Docs:
93 | # https://github.com/rtfd/readthedocs.org/issues/1139
94 | # We used to exclude some files, but that caused warnings (they weren't removed from the auto-generated index), so we made a curated api doc instead. Kept for now as we might go back.
95 | def run_apidoc(app):
96 |     from sphinx.apidoc import main as apidoc_main
97 |     cur_dir = os.path.abspath(os.path.dirname(__file__))
98 |     buildapidocdir = os.path.join(app.outdir, "apidoc","SparseSC")
99 |     module = os.path.join(cur_dir,"..","src","SparseSC")
100 |     to_excl = [] # previously excluded: "cross_validation","fit_ct","fit_fold", "fit_loo","optimizers","optimizers/cd_line_search","tensor","utils/ols_model","utils/penalty_utils","utils/print_progress","utils/sub_matrix_inverse","weights"
101 |     # Locally, wrapping each to_excl entry in "*"s and appending them to the apidoc command works as exclude patterns, but that doesn't work on RTD
102 |     # with capture() as out:  # apidoc_main has no quiet option; capture() would silence its output
103 |     apidoc_main([None, '-f', '-e', '-o', buildapidocdir, module])
104 |     # remove modules.rst: we never link to it directly, and deleting it silences the resulting warning
105 |     os.remove(os.path.join(buildapidocdir, "modules.rst"))
106 |     for excl in to_excl:
107 |         path = os.path.join(buildapidocdir, "SparseSC."+excl.replace("/",".")+".rst")
108 |         print("removing: "+path)
109 |         os.remove(path)
110 |
111 | def skip(app, what, name, obj, skip, options):
112 |     # 'autodoc-skip-member' handler: never skip __init__ (force documenting constructors)
113 |     if name == "__init__":
114 |         return False
115 |
116 |     skip_fns = []
117 |     if what=="class" and '__qualname__' in dir(obj) and obj.__qualname__ in skip_fns:
118 |         return True
119 |
120 |     # No better way found to skip properties more targetedly, so match on bare name (affects all classes)
121 |     skip_prs = []
122 |     if what=="class" and name in skip_prs:
123 |         return True
124 |
125 |     skip_mds = []
126 |     if what=="module" and name in skip_mds:
127 |         return True
128 |     # debugging aid: uncomment to inspect members as autodoc visits them
129 |     #print(what, name, obj, dir(obj))
130 |
131 |     return skip
132 |
133 | def setup(app):
134 |     app.connect('builder-inited', run_apidoc)
135 |
136 |     app.connect("autodoc-skip-member", skip)
137 |     # Markdown support: configure recommonmark's AutoStructify transform
138 |     app.add_config_value('recommonmark_config', {
139 |         'url_resolver': lambda url: "build/apidoc/" + url,
140 |         'auto_toc_tree_section': ['Contents','Examples'],
141 |         'enable_eval_rst': True,
142 |         #'enable_auto_doc_ref': True,
143 |         'enable_math': True,
144 |         'enable_inline_math': True
145 |     }, True)
146 |     app.add_transform(AutoStructify)
147 |
148 |
149 | # -- Options for HTML output ----------------------------------------------
150 |
151 | # The theme to use for HTML and HTML Help pages. See the documentation for
152 | # a list of builtin themes.
153 | #
154 | html_theme = 'sphinx_rtd_theme'
155 |
156 | # Theme options are theme-specific and customize the look and feel of a theme
157 | # further. For a list of options available for each theme, see the
158 | # documentation.
159 | #
160 | # html_theme_options = {}
161 |
162 | # Add any paths that contain custom static files (such as style sheets) here,
163 | # relative to this directory. They are copied after the builtin static files,
164 | # so a file named "default.css" will overwrite the builtin "default.css".
165 | #html_static_path = ['_static']
166 |
167 | # Custom sidebar templates, must be a dictionary that maps document names
168 | # to template names.
169 | #
170 | # This is required for the alabaster theme
171 | # refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars
172 | html_sidebars = {
173 | '**': [
174 | 'about.html',
175 | 'navigation.html',
176 | 'relations.html', # needs 'show_related': True theme option to display
177 | 'searchbox.html',
178 | 'donate.html',
179 | ]
180 | }
181 |
182 | ##Allow MarkDown.
183 | source_parsers = {'.md': 'recommonmark.parser.CommonMarkParser', }
184 |
185 |
186 | # -- Options for HTMLHelp output ------------------------------------------
187 |
188 | # Output file base name for HTML help builder.
189 | htmlhelp_basename = 'SparseSCdoc'
190 |
191 |
192 | # -- Options for LaTeX output ---------------------------------------------
193 |
194 | latex_elements = {
195 | # The paper size ('letterpaper' or 'a4paper').
196 | #
197 | # 'papersize': 'letterpaper',
198 |
199 | # The font size ('10pt', '11pt' or '12pt').
200 | #
201 | # 'pointsize': '10pt',
202 |
203 | # Additional stuff for the LaTeX preamble.
204 | #
205 | # 'preamble': '',
206 |
207 | # Latex figure (float) alignment
208 | #
209 | # 'figure_align': 'htbp',
210 | }
211 |
212 | # Grouping the document tree into LaTeX files. List of tuples
213 | # (source start file, target name, title,
214 | # author, documentclass [howto, manual, or own class]).
215 | latex_documents = [
216 | (master_doc, 'SparseSC.tex', 'SparseSC Documentation',
217 | 'Jason Thorpe, Brian Quistorff, Matt Goldman', 'manual'),
218 | ]
219 |
220 |
221 | # -- Options for manual page output ---------------------------------------
222 |
223 | # One entry per manual page. List of tuples
224 | # (source start file, name, description, authors, manual section).
225 | man_pages = [
226 | (master_doc, 'sparsesc', 'SparseSC Documentation',
227 | [author], 1)
228 | ]
229 |
230 |
231 | # -- Options for Texinfo output -------------------------------------------
232 |
233 | # Grouping the document tree into Texinfo files. List of tuples
234 | # (source start file, target name, title, author,
235 | # dir menu entry, description, category)
236 | texinfo_documents = [
237 | (master_doc, 'SparseSC', 'SparseSC Documentation',
238 | author, 'SparseSC', 'One line description of project.',
239 | 'Miscellaneous'),
240 | ]
241 |
242 |
243 |
244 |
--------------------------------------------------------------------------------
/docs/dev_notes.md:
--------------------------------------------------------------------------------
1 | # Developer Notes
2 |
3 | ## Python environments
4 |
5 | You can create Anaconda environments using
6 | ```bash
7 | conda env create -f test/SparseSC_36.yml
8 | ```
9 | You can can do `update` rather than `create` to update existing ones (to avoid [potential bugs](https://stackoverflow.com/a/46114295/3429373) make sure the env isn't currently active).
10 |
11 | Note: When regenerating these files (`conda env export > test/SparseSC_*.yml`) make sure to remove the final `prefix` line since that's computer specific. You can do this automatically on Linux by inserting `| grep -v "prefix"` and on Windows by inserting `| findstr -v "prefix"`.
12 |
13 | ## Building the docs
14 | Requires Python >=3.6 and packages: `sphinx`, `recommonmark`, `sphinx-markdown-tables`.
15 | Use `(n)make htmldocs` and an index HTML file is made at `docs/build/html/index.html`.
16 |
17 | To build a mini-RTD environment to test building docs:
18 | 1) You can make a new environment with Python 3.7 (`conda create -n SparseSC_37_rtd python=3.7`)
19 | 2) update `pip` (likely fine).
20 | 3) `pip install --upgrade --no-cache-dir -r docs/rtd-base.txt` . This file is loosely kept in sync by looking at the install commands on the rtd run.
21 | 4) `pip install --exists-action=w --no-cache-dir -r docs/rtd-requirements.txt` . This file doesn't list the full environment versions because that causes headaches when the rtd base environment got updated. It downgrades Sphinx to a known good version that allows markdown files to have math in code quotes ([GH Issues](https://github.com/readthedocs/recommonmark/issues/133)) (there might be higher ones that also work, didn't try).
22 |
23 | ## Running examples
24 | The Jupyter notebooks require `matplotlib`, `jupyter`, and `notebook`.
25 |
26 | ## Testing
27 | We use the built-in `unittest`. Can run from makefile using the `tests` target or you can run python directly from the repo root using the following types of commands:
28 |
29 | ```bash
30 | python -m unittest test/test_fit.py #file (only Python >=3.5)
31 | python -m unittest test.test_fit #module
32 | python -m unittest test.test_fit.TestFit #class
33 | python -m unittest test.test_fit.TestFit.test_retrospective #function
34 | ```
35 |
36 |
37 | ## Release Process
38 | * Ensure the makefile target `check` (which does pylint, tests, doc building, and packaging) runs clean
39 | * If new version, check that it's been updated in `src/SparseSC/__init__.py`
40 | * Update `CHANGELOG.md`
41 | * Tag/Release in version control
42 |
43 |
--------------------------------------------------------------------------------
/docs/estimate-effects.md:
--------------------------------------------------------------------------------
1 | # Treatment Effects
2 |
3 | The `estimate_effects()` function can be used to conduct
4 | [DID](https://en.wikipedia.org/wiki/Difference_in_differences) style
5 | analyses where counter-factual observations are constructed using Sparse
6 | Synthetic Controls.
7 |
8 | ```py
9 | import SparseSC
10 |
11 | # Fit the model:
12 | fitted_estimates = SparseSC.estimate_effects(outcomes,unit_treatment_periods,covariates=X,fast=True,...)
13 |
14 | # Print summary of the model including effect size estimates,
15 | # p-values, and confidence intervals:
16 | print(fitted_estimates)
17 |
18 | # Extract model attributes:
19 | fitted_estimates.pl_res_post.avg_joint_effect.p_value
20 | fitted_estimates.pl_res_post.avg_joint_effect.CI
21 |
22 | # access the fitted Synthetic Controls model:
23 | fitted_model = fitted_estimates.fit
24 | ```
25 |
26 | The returned object is of class `SparseSCEstResults`.
27 |
28 | #### Feature and Target Data
29 |
30 | When estimating synthetic controls, units of observation are divided into
31 | control and treated units. Data collected on these units may include
32 | observations of the outcome of interest, as well as other characteristics
33 | of the units (termed "covariates", herein). Outcomes may be observed both
34 | before and after an intervention on the treated units.
35 |
36 | To maintain independence of the fitted synthetic controls and the
37 | post-intervention outcomes of interest of treated units, the
38 | post-intervention outcomes from treated units are not used in the fitting
39 | process. There are two cuts from the remaining data that may be used to
40 | fit synthetic controls, and each has its advantages and disadvantages.
41 |
42 | In the call to `estimate_effects()`, `outcomes` should
43 | be numeric matrices containing data on the target variables collected prior
44 | to (after) the treatment / intervention ( respectively), and the optional
45 | parameter `covariates` may be a matrix of additional features. All matrices
46 | should have one row per unit and one column per observation.
47 |
48 | In addition, the rows in `covariates` and `outcomes` which contain units that were affected
49 | by the intervention ("treated units") should be indicated using the
50 | `treated_units` parameter, which may be a vector of booleans or integers
51 | indicating the rows which belong to treated units.
52 |
53 | #### Statistical parameters
54 |
55 | The confidence level may be specified with the `level` parameter, and the
56 | maximum number of simulations used to produce the placebo distribution may
57 | be set with the `max_n_pl` parameter.
58 |
59 | #### Additional parameters
60 |
61 | Additional keyword arguments are passed on to the call to `fit()`, which is
62 | responsible for fitting the Synthetic Controls used to create the
63 | counterfactuals.
64 |
--------------------------------------------------------------------------------
/docs/examples/estimate_effects.md:
--------------------------------------------------------------------------------
1 | # Treatment effects
2 |
3 | [ Coming soon ]
4 |
--------------------------------------------------------------------------------
/docs/examples/prospective_anomaly_detection.md:
--------------------------------------------------------------------------------
1 | # Anomaly Detection
2 |
3 | ### Overview
4 |
5 | In this scenario the goal is to identify irregular values in an outcome
6 | variable prospectively in a homogeneous population (i.e. when no
7 | treatment / intervention is planned). As an example, we may wish to detect
8 | failure of any one machine in a cluster, and to do so, we wish to create a
9 | synthetic unit for each machine which is composed of a weighted average of
10 | other machines in the cluster. In particular, there may be variation of
11 | the workload across the cluster and where workload may vary across the
12 | cluster by (possibly unobserved) differences in machine hardware, cluster
13 | architecture, scheduler versions, networking architecture, job type, etc.
14 |
15 | Like the Prospective Treatment Effects scenario, *Feature* data consist of
16 | unit attributes (covariates) and a subset of the pre-intervention values
17 | from the outcome of interest, and **target** data consist of the remaining
18 | pre-intervention values for the outcome of interest, and Cross fold
19 | validation is conducted using the entire dataset, and Cross validation and
20 | gradient folds are determined randomly.
21 |
22 | ### Example
23 |
24 | In this scenario, we'll need a matrix with past observations of the outcome
25 | (target) of interest (`targets`), with one row per unit of observation, and
26 | one column per time period, ordered from left to right. Additionally we
27 | may have another matrix of additional features with one row per unit and
28 | one column per feature (`features`). Armed with this we may wish to construct a
29 | synthetic control model to help decide whether future observations
30 | (`additional_observations`) deviate from their synthetic predictions.
31 |
32 | The strategy will be to divide the `targets` matrix into two parts (before
33 | and after column `t`), one of which will be used as features, and other
34 | which will be treated as outcomes for the purpose of fitting the weights
35 | which make up the synthetic controls model.
36 |
37 | ```python
38 | from numpy import hstack
39 | from SparseSC import fit
40 |
41 | # Let X be the features plus some of the targets
42 | X = hstack([features, targets[:,:t]])
43 |
44 | # And let Y be the remaining targets
45 | Y = targets[:,t:]
46 |
47 | # fit the model:
48 | fitted_model = fit(X=X,
49 | Y=Y,
50 | model_type="full")
51 | ```
52 |
53 | The `model_type="full"` produces a model in which every unit can
54 | serve as a control for every other unit, unless of course the parameter
55 | `custom_donor_pool` is specified.
56 |
57 | Now with our fitted synthetic control model, as soon as new set of targets
58 | outcomes are observed for each unit, we can create synthetic outcomes using
59 | our fitted model using the `predict()` method:
60 |
61 | ```python
62 | synthetic_controls = fitted_model.predict(additional_observations)
63 | ```
64 |
65 | Note that while the call to `fit()` is computationally intensive, the call
66 | to `model.predict()` is fast and can be used for real time anomaly
67 | detection.
68 |
69 | ### Model Details:
70 |
71 | This model yields a synthetic unit for every unit in the dataset, and
72 | synthetic units are composed of the remaining units not included in the
73 | same gradient fold.
74 |
75 | | Type | Units used to fit V & penalties | Donor pool for W |
76 | |---|---|---|
77 | |(prospective) full|All units|All units|
78 |
--------------------------------------------------------------------------------
/docs/examples/retrospecitve_restricted_synthetic_controls.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/docs/examples/retrospecitve_restricted_synthetic_controls.md
--------------------------------------------------------------------------------
/docs/examples/retrospecitve_synthetic_controls.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/docs/examples/retrospecitve_synthetic_controls.md
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. SparseSC documentation master file, created by
2 | sphinx-quickstart on Thu Sep 27 14:53:55 2018.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Welcome to SparseSC's |version| documentation!
7 | ==============================================
8 |
9 | .. toctree::
10 | :maxdepth: 2
11 | :caption: Contents:
12 |
13 | overview
14 | fit
15 | estimate-effects
16 | model-types
17 | api_ref
18 | azure_batch
19 | dev_notes
20 |
21 |
22 | .. the complete docs (build/apidoc/SparseSC/SparseSC) are a bit chaotic* and
23 | redundant to :ref:`modindex`. Replacing with the tamer api_ref.md
24 |
25 | .. toctree::
26 | :maxdepth: 2
27 | :caption: Examples:
28 |
29 | examples/prospective_anomaly_detection
30 |
31 |
32 | Indices and tables
33 | ==================
34 |
35 | * :ref:`genindex`
36 | * :ref:`modindex`
37 |
38 | .. not needed: * :ref:`search`
39 |
--------------------------------------------------------------------------------
/docs/model-types.md:
--------------------------------------------------------------------------------
1 | # Model Types
2 |
3 | There are three distinct types of model fitting with respect choosing
4 | optimal values of the penalty parameters and the collection of units
5 | that are used for cross validation.
6 |
7 | Recall that a synthetic treatment unit is defined as a weighted average of
8 | control units where weights are determined from the targets and are chosen
9 | to minimize the difference between each treated unit and its
10 | synthetic unit in the absence of an intervention. This methodology can be
11 | combined with cross-fold validation in several ways in order to
12 | accommodate a variety of use cases.
13 |
14 | ## Retrospective Treatment Effects
15 |
16 | `model_type = "retrospective"`
17 |
18 | In a retrospective analysis, a subset of units have received a treatment
19 | which possibly correlates with features of the units and values for the
20 | target variable have been collected both before and after an intervention.
21 | For example, a software update may have been applied to machines in a
22 | cluster which were experiencing unusually high latency, and retrospectively an
23 | analyst wishes to understand the effect of the update on another outcome
24 | such as memory utilization.
25 |
26 | In this scenario, for each treated unit we wish to create a synthetic unit
27 | composed only of untreated units. The units are divided into a training
28 | set consisting of just the control units, and a test set consisting of the
29 | treated units. Within the training set, *feature* data consist of anything
30 | known prior to a treatment or intervention, such as pre-intervention values
31 | from the outcome of interest as well as unit level attributes, sometimes
32 | called "covariates". Likewise, **target** data consist of observations
33 | from the outcome of interest collected after the treatment or intervention
34 | is initiated.
35 |
36 | Cross-fold validation is done within the training set, holding out a single
37 | fold, identifying feature weights within the remaining folds, creating
38 | synthetic units for each held out unit defined as a weighted average of the
39 | non-held-out units. Out-of-sample prediction errors are calculated for each
40 | training fold and summed across folds. The set of penalty parameters that
41 | minimizes the Cross-Fold out-of-sample error are chosen. Feature weights
42 | are calculated within the training set for the chosen penalties, and
43 | finally individual synthetic units are calculated for each treated unit
44 | which is a weighted average of the control units.
45 |
46 | This model yields a synthetic unit for each treated unit composed of
47 | control units.
48 |
49 | ## Prospective Treatment Effects
50 |
51 | `model_type = "prospective"`
52 |
53 | In a prospective analysis, a subset of units have been designated to
54 | receive a treatment but the treatment has not yet occurred and the
55 | designation of the treatment may be correlated with a (possibly unobserved)
56 | feature of the treatment units. For example, a software update may have
57 | been planned for machines in a cluster which are experiencing unusually
58 | high latency, and there is a desire to understand the impact of the software on
59 | memory utilization.
60 |
61 | Like the retrospective scenario, for each treated unit we wish to create a
62 | synthetic unit composed only of untreated units.
63 |
64 | *Feature* data consist of unit attributes (covariates) and a subset
65 | of the pre-intervention values from the outcome of interest, and **target**
66 | data consist of the remaining pre-intervention values for the outcome of
67 | interest
68 |
69 | Cross fold validation is conducted using the entire dataset without regard
70 | to intent to treat. However, treated units are allocated to a single gradient
71 | fold, ensuring that synthetic treated units are composed of only the
72 | control units.
73 |
74 | Cross-fold validation is done within the *test* set, holding out a single
75 | fold, identifying feature weights within the remaining treatment folds
76 | combined with the control units, synthetic units for each held out unit
77 | defined as a weighted average of the full set of control units.
78 |
79 | Out-of-sample prediction errors are calculated for each treatment fold and
80 | the sum of these defines the Cross-Fold out-of-sample error. The set of
81 | penalty parameters that minimizes the Cross-Fold out-of-sample error are
82 | chosen. Feature weights are calculated within the training set for the
83 | chosen penalties, and finally individual synthetic units are calculated for
84 | each full unit which is a weighted average of the control units.
85 |
86 | This model yields a synthetic unit for each treated unit composed of
87 | control units.
88 |
89 | ## Prospective Treatment Effects training
90 |
91 | `model_type = "prospective-restricted"`
92 |
93 | This is motivated by the same example as the previous sample. It requires
94 | a larger set of treated units for similar levels of precision, with the
95 | benefit of substantially faster running time.
96 |
97 | The units are divided into a training set consisting of just the control
98 | units, and a test set consisting of the unit which will be treated.
99 | *feature* data will consist of unit attributes (covariates) and a subset
100 | of the pre-intervention values from the outcome of interest, and **target**
101 | data consist of the remaining pre-intervention values for the outcome of
102 | interest
103 |
104 | Cross-fold validation is done within the *test* set, holding out a single
105 | fold, identifying feature weights within the remaining treatment folds
106 | combined with the control units, synthetic units for each held out unit
107 | defined as a weighted average of the full set of control units.
108 |
109 | Out-of-sample prediction errors are calculated for each treatment fold and
110 | the sum of these defines the Cross-Fold out-of-sample error. The set of
111 | penalty parameters that minimizes the Cross-Fold out-of-sample error are
112 | chosen. Feature weights are calculated within the training set for the
113 | chosen penalties, and finally individual synthetic units are calculated for
114 | each full unit which is a weighted average of the control units.
115 |
116 | This model yields a synthetic unit for each treated unit composed of
117 | control units.
118 |
119 | Note that this model will tend to have wider confidence intervals and smaller estimated treatment effects given the sample it is fit on.
120 |
121 | ## Prospective Failure Detection
122 |
123 | `model_type = "full"`
124 |
125 | In this scenario the goal is to identify irregular values in an outcome
126 | variable prospectively in a homogeneous population (i.e. when no
127 | treatment / intervention is planned). As an example, we may wish to detect
128 | failure of any one machine in a cluster, and to do so, we wish to create a
129 | synthetic unit for each machine which is composed of a weighted average of
130 | other machines in the cluster. In particular, there may be variation of
131 | the workload across the cluster and where workload may vary across the
132 | cluster by (possibly unobserved) differences in machine hardware, cluster
133 | architecture, scheduler versions, networking architecture, job type, etc.
134 |
135 | Like the Prospective Treatment Effects scenario, *Feature* data consist of
136 | unit attributes (covariates) and a subset of the pre-intervention values
137 | from the outcome of interest, and **target** data consist of the remaining
138 | pre-intervention values for the outcome of interest, and Cross fold
139 | validation is conducted using the entire dataset, and Cross validation and
140 | gradient folds are determined randomly.
141 |
142 | This model yields a synthetic unit for every unit in the dataset, and
143 | synthetic units are composed of the remaining units not included in the
144 | same gradient fold.
145 |
146 | ## Summary
147 |
148 | Here is a summary of the main differences between the model types.
149 |
150 | | Type | Units used to fit V & penalties | Donor pool for W |
151 | |---|---|---|
152 | |retrospective|Controls|Controls|
153 | |prospective|All|Controls|
154 | |prospective-restricted|Treated|Controls|
155 | |(prospective) full|All|All|
156 |
157 | A tree view of differences:
158 | * Treatment date: The *prospective* studies differ from the *retrospective* study in that they can use all units for fitting.
159 | * (Prospective studies) Treated units: The intended-to-treat (*ITT*) studies differ from the *full* in that the "treated" units can't be used for donors.
160 | * (Prospective-ITT studies): The *restrictive* model differs in that it tries to maximize predictive power for just the treated units.
161 |
--------------------------------------------------------------------------------
/docs/performance-notes.md:
--------------------------------------------------------------------------------
1 | # Performance Notes
2 |
3 | ## Running time
4 |
5 | The function `get_max_lambda()` requires a single calculation of the
6 | gradient using all of the available data. In contrast, ` SC.CV_score()`
7 | performs gradient descent within each validation-fold of the data.
8 | Furthermore, in the 'pre-only' scenario the gradient is calculated once for
9 | each iteration of the gradient descent, whereas in the 'controls-only'
10 | scenario the gradient is calculated once for each control unit.
11 | Specifically, each control unit is excluded from the set of units that can
12 | be used to predict its own post-intervention outcomes, resulting in
13 | leave-one-out gradient descent.
14 |
15 | For large sample sizes in the 'controls-only' scenario, it may be
16 | sufficient to divide the non-held out control units into "gradient folds", such
17 | that controls within the same gradient-fold are not used to predict the
18 | post-intervention outcomes of other control units in the same fold. This
19 | results in K-fold gradient descent, which improves the speed of
20 | calculating the overall gradient by a factor slightly greater than `c/k`
21 | (where `c` is the number of control units) with an even greater reduction
22 | in memory usage.
23 |
24 | K-fold gradient descent is enabled by passing the parameter `grad_splits`
25 | to `CV_score()`, and for consistency across calls to `CV_score()` it is
26 | recommended to also pass a value to the parameter `random_state`, which is
27 | used in selecting the gradient folds.
28 |
29 | ## Additional Considerations
30 |
31 | If you have the BLAS/LAPACK libraries installed and available to Python,
32 | you should not need to do any further optimization to ensure that maximum
33 | number of processors are used during the execution of `CV_score()`. If
34 | not, you may wish to set the parameter `parallel=True` when you call
35 | `CV_score()` which will split the work across N - 2 sub-processes where N
36 | is the [number of cores in your
37 | machine](https://docs.python.org/2/library/multiprocessing.html#miscellaneous).
38 | (Note that setting `parallel=True` when the BLAS/LAPACK are available will
39 | tend to increase running times.)
40 |
41 |
--------------------------------------------------------------------------------
/docs/rtd-base.txt:
--------------------------------------------------------------------------------
1 | setuptools==54.0.0
2 | mock==1.0.1
3 | pillow==5.4.1
4 | alabaster>=0.7,<0.8,!=0.7.5
5 | commonmark==0.8.1
6 | recommonmark==0.5.0
7 | sphinx<2
8 | sphinx-rtd-theme<0.5
9 | readthedocs-sphinx-ext<2.2
10 |
--------------------------------------------------------------------------------
/docs/rtd-requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.13.1
2 | scikit-learn>=0.19.1
3 | scipy>=0.19.1
4 | Sphinx==1.6.3
5 | sphinx-markdown-tables==0.0.15
6 | statsmodels>=0.8.0
7 | pyyaml>=5.4.1
8 | psutil>=5.8.0
9 | git+https://github.com/jdthorpe/batch-config@7eae164#egg=batch-config
10 |
--------------------------------------------------------------------------------
/examples/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/examples/__init__.py
--------------------------------------------------------------------------------
/examples/strip_magic.py:
--------------------------------------------------------------------------------
# Rewrite DifferentialTrends.py in place, dropping the leftover Jupyter
# magic calls that nbconvert emits (see the makefile note about
# https://github.com/jupyter/nbconvert/issues/503).
with open("DifferentialTrends.py", "r+") as script:
    all_lines = script.readlines()
    # Rewind and write back only the lines we keep, then truncate the tail.
    script.seek(0)
    kept_lines = (line for line in all_lines if "get_ipython().magic" not in line)
    script.writelines(kept_lines)
    script.truncate()
8 |
--------------------------------------------------------------------------------
/makefile:
--------------------------------------------------------------------------------
1 | #On Windows can use 'nmake'
2 | # Incl w/ VS (w/ "Desktop development with C++" components)
3 | # Incl in path or use the "Developer Command Prompt for VS...."
4 | # If manually including the folder in path, this can change when VS's MSVC version gets bumped.
5 | # So to always get it you can run (or change for your system):
6 | # "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Auxiliary\Build\vcvars64.bat"
7 | # Can run 'nmake /NOLOGO ...' to remove logo output.
8 | #nmake can't do automatic (pattern) rules like make (has inference rules which aren't cross-platform)
9 |
10 | # For linux, can use Anaconda or if using virtualenv, install virtualenvwrapper
11 | # and alias activate->workon.
12 |
13 | help:
14 | @echo "Use one of the common targets: pylint, package, readmedocs, htmldocs"
15 |
16 | #Allow for slightly different commands for nmake and make
17 | #NB: Don't always need the different DIR_SEP
18 | # \
19 | !ifndef 0 # \
20 | # nmake specific code here \
21 | RMDIR_CMD = rmdir /S /Q # \
22 | RM_CMD = del # \
23 | DIR_SEP = \ # \
24 | !else
25 | # make specific code here
26 | RMDIR_CMD = rm -rf
27 | RM_CMD = rm
28 | DIR_SEP = /# \
29 | # \
30 | !endif
31 |
32 | #Creates a "Source Distribution" and a "Pure Python Wheel" (which is a bit easier for user)
33 | package: package_both
34 |
35 | package_both:
36 | python setup.py sdist bdist_wheel
37 |
38 | package_sdist:
39 | python setup.py sdist
40 |
41 | package_bdist_wheel:
42 | python setup.py bdist_wheel
43 |
44 | pypi_upload:
45 | twine upload dist/*
46 | #python -m twine upload --verbose --repository-url https://test.pypi.org/legacy/ dist/*
47 |
48 | pylint:
49 | -mkdir build
50 | -pylint SparseSC > build$(DIR_SEP)pylint_msgs.txt
51 |
52 | SOURCEDIR = docs/
53 | BUILDDIR = docs/build
54 | BUILDDIRHTML = docs$(DIR_SEP)build$(DIR_SEP)html
55 | BUILDAPIDOCDIR= docs$(DIR_SEP)build$(DIR_SEP)apidoc
56 |
57 | #$(O) is meant as a shortcut for $(SPHINXOPTS).
58 | htmldocs:
59 | -$(RMDIR_CMD) $(BUILDDIRHTML)
60 | -$(RMDIR_CMD) $(BUILDAPIDOCDIR)
61 | # sphinx-apidoc -f -o $(BUILDAPIDOCDIR)/SparseSC SparseSC
62 | # $(RM_CMD) $(BUILDAPIDOCDIR)$(DIR_SEP)SparseSC$(DIR_SEP)modules.rst
63 | @python -msphinx -b html -T -E "$(SOURCEDIR)" "$(BUILDDIR)" $(O)
64 |
65 | examples:
66 | python example-code.py
67 | python examples/fit_poc.py
68 |
69 | tests:
70 | python -m unittest test.test_fit.TestFitForErrors test.test_fit.TestFitFastForErrors test.test_normal.TestNormalForErrors test.test_estimation.TestEstimationForErrors
71 |
72 | #tests_both:
73 | # activate SparseSC_36 && python -m unittest test.test_fit
74 |
#add examples here when working
#NB: previously depended on "tests_both", which is commented out above and so
#made this target fail; "tests" is the defined test target.
check: pylint package_bdist_wheel tests
77 |
78 | #Have to strip because of unfixed https://github.com/jupyter/nbconvert/issues/503
79 | examples/DifferentialTrends.py: examples/DifferentialTrends.ipynb
80 | jupyter nbconvert examples/DifferentialTrends.ipynb --to script
81 | cd examples && python strip_magic.py
82 |
83 | examples/DifferentialTrends.html: examples/DifferentialTrends.ipynb
84 | jupyter nbconvert examples/DifferentialTrends.ipynb --to html
85 |
86 | examples/DifferentialTrends.pdf: examples/DifferentialTrends.ipynb
87 | cd examples && jupyter nbconvert DifferentialTrends.ipynb --to pdf
88 |
89 | clear_ipynb_output:
90 | jupyter nbconvert --ClearOutputPreprocessor.enabled=True --inplace examples/DifferentialTrends.ipynb
91 | gen_ipynb_output:
92 | jupyter nbconvert --to notebook --execute examples/DifferentialTrends.ipynb
93 |
94 | #Have to cd into subfulder otherwise will pick up potential SparseSC pkg in build/
95 | #TODO: Make the prefix filter automatic
96 | #TODO: check if this way of doing phony targets for nmake works with make
97 | test/SparseSC_36.yml: .phony
98 | activate SparseSC_36 && cd test && conda env export > SparseSC_36.yml
99 | echo Make sure to remove the last prefix line and the pip sparsesc line, as user does pip install -e for that
100 | .phony:
101 |
102 | #Old:
103 | # Don't generate requirements-rtd.txt from conda environments (e.g. pip freeze > rtd-requirements.txt)
104 | # 1) Can be finicky to get working since using pip and docker images and don't need lots of packages (e.g. for Jupyter)
105 | # 2) Github complains about requests<=2.19.1. Conda can't install 2.20 w/ Python <3.6. Our env is 3.5, but RTD uses Python3.7
106 | # Could switch to using conda
107 | #doc/rtd-requirements.txt:
108 |
#"conda_env_upate" is a typo kept for backward compatibility; the correctly
#spelled alias below delegates to it.
conda_env_update: conda_env_upate
conda_env_upate:
	deactivate && conda env update -f test/SparseSC_36.yml
111 |
112 | #Just needs to be done once
113 | conda_env_create:
114 | conda env create -f test/SparseSC_36.yml
115 |
116 | jupyter_DifferentialTrends:
117 | START jupyter notebook examples/DifferentialTrends.ipynb > jupyter_launch.log
118 |
--------------------------------------------------------------------------------
/replication/.gitignore:
--------------------------------------------------------------------------------
1 | dta_dir/
2 |
3 |
--------------------------------------------------------------------------------
/replication/figs/sparsesc_fast_pe.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/replication/figs/sparsesc_fast_pe.pdf
--------------------------------------------------------------------------------
/replication/figs/sparsesc_fast_xf2_pe.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/replication/figs/sparsesc_fast_xf2_pe.pdf
--------------------------------------------------------------------------------
/replication/figs/sparsesc_fast_xf_pe.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/replication/figs/sparsesc_fast_xf_pe.pdf
--------------------------------------------------------------------------------
/replication/figs/sparsesc_full_pe.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/replication/figs/sparsesc_full_pe.pdf
--------------------------------------------------------------------------------
/replication/figs/sparsesc_full_xf2_pe.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/replication/figs/sparsesc_full_xf2_pe.pdf
--------------------------------------------------------------------------------
/replication/figs/sparsesc_full_xf_pe.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/replication/figs/sparsesc_full_xf_pe.pdf
--------------------------------------------------------------------------------
/replication/figs/standard_flat55_pe.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/replication/figs/standard_flat55_pe.pdf
--------------------------------------------------------------------------------
/replication/figs/standard_nested_pe.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/replication/figs/standard_nested_pe.pdf
--------------------------------------------------------------------------------
/replication/figs/standard_spfast_pe.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/replication/figs/standard_spfast_pe.pdf
--------------------------------------------------------------------------------
/replication/figs/standard_spfast_xf2_pe.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/replication/figs/standard_spfast_xf2_pe.pdf
--------------------------------------------------------------------------------
/replication/figs/standard_spfast_xf_pe.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/replication/figs/standard_spfast_xf_pe.pdf
--------------------------------------------------------------------------------
/replication/figs/standard_spfull_pe.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/replication/figs/standard_spfull_pe.pdf
--------------------------------------------------------------------------------
/replication/figs/standard_spfull_xf2_pe.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/replication/figs/standard_spfull_xf2_pe.pdf
--------------------------------------------------------------------------------
/replication/figs/standard_spfull_xf_pe.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/replication/figs/standard_spfull_xf_pe.pdf
--------------------------------------------------------------------------------
/replication/smoking.dta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/replication/smoking.dta
--------------------------------------------------------------------------------
/replication/vmats/fast_fit.txt:
--------------------------------------------------------------------------------
1 | beer(1986) lnincome(1985) lnincome(1987) lnincome(1988) retprice(1985) age15to24(1988) cigsale(1986) cigsale(1987) cigsale(1988)
2 | 0.5130071259983364 22.968210890267105 66.11994103276801 20.372467610421825 0.3953643337713626 146.42793943085758 0.40076349898165653 1.0914176928934336 1.6124490971940708
3 |
--------------------------------------------------------------------------------
/replication/vmats/full_fit.txt:
--------------------------------------------------------------------------------
1 | beer beer(1984) beer(1985) beer(1986) beer(1987) beer(1988) cigsale(1970) cigsale(1971) cigsale(1972) cigsale(1973) cigsale(1974) cigsale(1975) cigsale(1976) cigsale(1977) cigsale(1978) cigsale(1979) cigsale(1980) cigsale(1981) cigsale(1982) cigsale(1983) cigsale(1984) cigsale(1985) cigsale(1986) cigsale(1987) cigsale(1988)
2 | 0.0001931675090059052 0.00011673718456481037 0.00017442025581167862 0.00019492525551082786 0.00023112626647203255 0.0002485464241715504 0.04316485032629001 0.04855303157528119 0.05649902721742458 0.05969122641855661 0.0605678608204393 0.0626985287374379 0.0689133161012072 0.06649196460099964 0.06167250333634058 0.05650907357046557 0.05510538799749318 0.052133342737103204 0.05262319599234059 0.048749408792219784 0.04201710074463961 0.04108083109059892 0.04125358244144677 0.040033436688557196 0.04108340791562128
3 |
--------------------------------------------------------------------------------
/replication/vmats/xf_fits_fast2.txt:
--------------------------------------------------------------------------------
1 | 23 lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988)
2 | 23 52.9138366517419 9.288186511419811 0.8048612299517014 0.20594545615924512 0.4714023965271641 0.9417936107673718 1.3465241057644206
3 | 9 lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988)
4 | 9 32.650879763821514 37.02197484931736 0.542540877404274 0.014584754143649414 0.4798681872289629 0.9368634841278982 1.371249208763717
5 | 27 lnincome(1985) lnincome(1987) retprice(1985) cigsale(1986) cigsale(1987) cigsale(1988)
6 | 27 32.004396105062035 30.46557116615753 0.6970722297526234 0.38688970430837605 0.8419576015155998 1.5282063874976552
7 | 14 lnincome(1985) lnincome(1987) retprice(1985) cigsale(1984) cigsale(1987) cigsale(1988)
8 | 14 41.19800297067463 15.510595853489166 0.722493021351017 0.0205302688781237 0.9289118032490858 1.7545389645850142
9 | 25 lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988)
10 | 25 37.216813057528164 25.221489830962824 0.657228392750966 0.02454777512573974 0.43500171319846226 0.8125446272254442 1.5057689475284177
11 | 4 lnincome(1985) lnincome(1987) retprice(1985) cigsale(1986) cigsale(1987) cigsale(1988)
12 | 4 22.033380508626728 37.95607243445006 0.7392112404836514 0.2679160799269573 0.8907623648242043 1.5853183181096129
13 | 2 lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988)
14 | 2 23.75657592484401 41.23483140470908 0.6910275048224122 0.04477245768082374 0.32116721757583316 0.916959715791747 1.5374270606732956
15 | 28 beer(1988) lnincome(1985) lnincome(1987) retprice(1985) cigsale(1986) cigsale(1988)
16 | 28 0.4278699614050563 18.808995663124076 51.15028075063535 0.3855812959197057 0.42320896120863377 2.5603609090753428
17 | 3 lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988)
18 | 3 35.774461358952315 24.42133906654064 0.70260293128289 0.03315716245442422 0.4498448389541703 0.8010006448441934 1.4886819652342287
19 | 6 beer(1986) lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988)
20 | 6 0.09165879470209615 21.0404041711949 44.783387436919206 0.6866341032510951 0.15078832914078785 0.12407411039828849 1.1316974441477015 1.542310302014068
21 | 5 beer(1988) lnincome(1983) lnincome(1988) retprice(1980) retprice(1985) retprice(1988) age15to24(1988) cigsale(1982) cigsale(1984) cigsale(1986) cigsale(1987) cigsale(1988)
22 | 5 1.441267670831488 16.17283654516924 104.89648163273378 0.09551657763150784 0.3641232956059621 0.33445253937605773 851.0159554272302 0.10935326306525987 0.028820136107687223 0.003648718605989471 1.416359875540445 1.6223994793627619
23 | 16 lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988)
24 | 16 37.614732176886335 27.89280012177113 0.620584501883219 0.08323182778087279 0.4319512962017977 0.8848872277804385 1.4643646690933532
25 | 32 lnincome(1987) retprice(1985) cigsale(1986) cigsale(1987) cigsale(1988)
26 | 32 70.46915458836764 0.5141370149716578 0.7971683142142422 0.8480369389565701 1.2087497747969596
27 | 31 lnincome(1985) lnincome(1987) retprice(1985) cigsale(1986) cigsale(1987) cigsale(1988)
28 | 31 19.951200540111767 38.488804338029084 0.6758079493890541 0.48036987327851277 0.7061985896568476 1.5592920543203923
29 | 34 lnincome(1985) lnincome(1987) retprice(1985) cigsale(1986) cigsale(1987) cigsale(1988)
30 | 34 39.40376183490209 22.354550796059744 0.6888351098369737 0.41524836151985645 0.7825350950415934 1.5554797074795266
31 | 15 lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988)
32 | 15 30.132849920047796 30.779908003822648 0.5972116627569648 0.11990144050863925 0.35550657037139605 0.8863180945932602 1.5035821890610244
33 | 11 beer(1986) lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988)
34 | 11 0.002505453803009291 26.818904888838087 38.1254682990119 0.6921247832162963 0.025809909197581832 0.3455231672678883 0.8743694551676355 1.5570267410097178
35 | 21 beer(1986) lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988)
36 | 21 0.4589963686470479 10.137739000768292 63.320284106226666 0.49840518043861753 0.028262507662936767 0.39409111305372047 1.2395339989892469 1.2661476752987006
37 | 17 lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988)
38 | 17 64.5033184509774 9.574472312436525 0.5627863776716775 0.005630207000854121 0.8150684830960727 0.701261488167543 1.2831255659946963
39 | 26 lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988)
40 | 26 53.81482286497893 4.8721745069110325 0.6285881994463606 0.11099082230607583 0.27774834257279696 0.7975338506147123 1.644516378562484
41 | 1 lnincome(1985) lnincome(1987) retprice(1985) cigsale(1986) cigsale(1987) cigsale(1988)
42 | 1 35.321178164904296 18.753395130096887 0.7248123993386513 0.4455704580670278 0.7709306503211723 1.4955447491570548
43 | 10 lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988)
44 | 10 26.172315514404612 40.5066411648697 0.6210816258045354 0.08574858166221508 0.3894929472153092 0.9582031390511888 1.4532274644713794
45 | 33 beer(1986) lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988)
46 | 33 0.478314791014334 28.231179947143886 41.24021289768206 0.45429005823127383 0.13228216707157708 0.31092784094021314 1.1063592582571 1.357207450058452
47 | 7 lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988)
48 | 7 38.44525779541849 22.241683904867678 0.6672100263119967 0.00020003168746482707 0.4358017018642251 0.790338563123268 1.5019443042961558
49 | 19 lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988)
50 | 19 33.33300260657349 28.87629219649373 0.646379658811407 0.03824224699002728 0.3994839059986097 0.7982194807630935 1.5626017823817469
51 | 22 lnincome(1983) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988)
52 | 22 9.828161509868865 57.55044358391577 0.48095435298865974 0.2830639710490712 0.33786416421453 0.8822333421833521 1.4843815147743358
53 | 24 lnincome(1988) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988)
54 | 24 61.8693982343369 0.6118933951108044 0.07169034659486957 0.24410578816292353 0.9136417430117386 1.616141985352236
55 | 36 lnincome(1985) lnincome(1987) lnincome(1988) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988)
56 | 36 14.475861078773029 26.943137832098724 15.92435359735392 0.9875744643350647 0.041163824055681746 0.16873475155440118 1.21535634363386 1.4074242537938177
57 | 30 lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988)
58 | 30 33.17457684604746 29.073985497061805 0.5162346705854688 0.11486578771775793 0.4982760481771537 0.8095874955348279 1.4697053562688922
59 | 12 lnincome(1984) lnincome(1987) retprice(1985) cigsale(1987) cigsale(1988)
60 | 12 26.92911058554163 6.155518834714997 0.6285961030571781 1.0346295966649277 1.4340086985377358
61 | 18 lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988)
62 | 18 19.45223698594143 41.78134781747795 0.6792763549942265 0.003319460377707231 0.24435409463401464 0.9326189526484684 1.5990422694981496
63 | 38 lnincome(1985) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988)
64 | 38 72.4455632990079 0.4853074391447769 0.10797260029294552 0.3127912565447208 1.149499721947342 1.3325911487315836
65 | 35 lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988)
66 | 35 21.367246762453107 25.873179715901948 0.9101086883278595 0.17639448181510284 0.1591127187569252 0.9462375093430284 1.5790071062927429
67 | 13 lnincome(1985) lnincome(1987) retprice(1985) cigsale(1986) cigsale(1987) cigsale(1988)
68 | 13 45.663501999478804 10.455542220073793 0.8321049818628562 0.34649744106868596 0.7978118480212745 1.5693247671585517
69 | 37 beer(1986) lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988)
70 | 37 0.20369210499860185 28.583256638938593 35.23310804285026 0.7322974820605626 0.013874845991732802 0.44732514640064847 0.8581754884927489 1.4884211448589109
71 | 29 lnincome(1985) lnincome(1987) retprice(1985) retprice(1988) cigsale(1986) cigsale(1987) cigsale(1988)
72 | 29 35.18711223047968 26.12276075274429 0.699080302144707 0.06147429869908371 0.3675090246616547 0.8458766224462938 1.5266674252392296
73 | 20 lnincome(1987) lnincome(1988) retprice(1985) retprice(1988) cigsale(1987) cigsale(1988)
74 | 20 59.399857194938505 10.616584838603158 0.24505149601082157 0.0521000199791986 1.3759206906091328 1.580541231282191
75 | 8 lnincome(1985) retprice(1985) cigsale(1986) cigsale(1987) cigsale(1988)
76 | 8 48.170218174935776 0.7191122125984405 0.7690613403578862 0.44355310614471416 1.4764919171905777
77 |
--------------------------------------------------------------------------------
/replication/vmats/xf_fits_full.txt:
--------------------------------------------------------------------------------
1 | 23 cigsale(1977) cigsale(1986) cigsale(1987) cigsale(1988)
2 | 23 3.220031805241539e-05 0.0004541363359538803 0.0003971751428356052 0.0016313917361968722
3 | 9 cigsale(1986) cigsale(1987) cigsale(1988)
4 | 9 0.0003952631426078369 0.0003601154537986567 0.0011483353677890413
5 | 27 cigsale(1977) cigsale(1986) cigsale(1987) cigsale(1988)
6 | 27 1.8871490715416953e-05 0.0004399069826426927 0.00043629534490380686 0.0010540606594254798
7 | 14 cigsale(1986) cigsale(1987) cigsale(1988)
8 | 14 0.0003229640528205937 0.0003688476045066298 0.000956740279593809
9 | 25 cigsale(1986) cigsale(1987) cigsale(1988)
10 | 25 0.0004202607334967443 0.0004268848440616784 0.0010671839324497162
11 | 4 cigsale(1986) cigsale(1987) cigsale(1988)
12 | 4 0.0003196300849811584 0.00032930779514542914 0.0016052914489879876
13 | 2 cigsale(1986) cigsale(1987) cigsale(1988)
14 | 2 0.00038734543047050905 0.00037124265455011347 0.0010588787288313994
15 | 28 cigsale(1982) cigsale(1986) cigsale(1988)
16 | 28 1.024800187075679e-05 0.00024196009018830406 0.0025578912517229798
17 | 3 cigsale(1986) cigsale(1987) cigsale(1988)
18 | 3 0.00040220420600810497 0.0003571963463967961 0.000890850495876067
19 | 6 cigsale(1977) cigsale(1986) cigsale(1987) cigsale(1988)
20 | 6 1.2390431604427915e-07 0.00046409781938842903 0.0002743527155968748 0.0017239069267463766
21 | 5 cigsale(1977) cigsale(1986) cigsale(1987) cigsale(1988)
22 | 5 2.7787923373098215e-05 0.0003434544188081404 0.0003422614209224295 0.0008778496833327705
23 | 16 cigsale(1986) cigsale(1987) cigsale(1988)
24 | 16 0.0003784240793891911 0.0003674397411472507 0.0009145390885686904
25 | 32 cigsale(1982) cigsale(1983) cigsale(1986) cigsale(1987) cigsale(1988)
26 | 32 2.7245185321026544e-05 2.655913644964166e-05 0.0005845624034860476 0.00046112657110766037 0.0010158811024341598
27 | 31 cigsale(1977) cigsale(1982) cigsale(1986) cigsale(1987) cigsale(1988)
28 | 31 0.00018248850627347452 1.4800388877421827e-05 0.0004229485947024648 0.0003769731202171168 0.0010430100065317626
29 | 34 cigsale(1977) cigsale(1982) cigsale(1986) cigsale(1987) cigsale(1988)
30 | 34 0.00013728124270578521 3.721020977056891e-05 0.0004237767444671827 0.00044270188843288644 0.0012123166239067715
31 | 15 cigsale(1985) cigsale(1986) cigsale(1987) cigsale(1988)
32 | 15 1.494300101581489e-05 0.0003595505300900124 0.00041894917648444024 0.0008757901102167003
33 | 11 cigsale(1977) cigsale(1986) cigsale(1987) cigsale(1988)
34 | 11 9.333161499217062e-06 0.0003719157586207366 0.00033404170597104007 0.0009539006346480628
35 | 21 cigsale(1986) cigsale(1987) cigsale(1988)
36 | 21 0.00045315631293933257 0.00038923440996003735 0.0013524841267793333
37 | 17 cigsale(1986) cigsale(1987) cigsale(1988)
38 | 17 0.0003852429561616505 0.00033760087204227126 0.000882231757549232
39 | 26 cigsale(1977) cigsale(1982) cigsale(1983) cigsale(1986) cigsale(1987) cigsale(1988)
40 | 26 0.00011592293176551638 7.637050088258385e-06 3.2884018372150726e-06 0.00042086462533162314 0.00042901457399833744 0.00114115134113562
41 | 1 cigsale(1986) cigsale(1987) cigsale(1988)
42 | 1 0.0003869479555100085 0.00036510350377175503 0.0010168374051120169
43 | 10 cigsale(1986) cigsale(1987) cigsale(1988)
44 | 10 0.0003878740099791977 0.00035772101520275027 0.0009611957205379999
45 | 33 cigsale(1986) cigsale(1987) cigsale(1988)
46 | 33 0.00035909168155274026 0.0004524634847694947 0.0009077421449587082
47 | 7 cigsale(1986) cigsale(1987) cigsale(1988)
48 | 7 0.00036534733256354617 0.00036294902200991505 0.0009171959812324014
49 | 19 cigsale(1986) cigsale(1987) cigsale(1988)
50 | 19 0.00037962900054159084 0.00035386606609022855 0.0009644690512068232
51 | 22 cigsale(1986) cigsale(1987) cigsale(1988)
52 | 22 0.0004390091607431598 0.0003705713962368925 0.0010729243300373226
53 | 24 cigsale(1986) cigsale(1987) cigsale(1988)
54 | 24 0.00043586384804969745 0.0004063534938212729 0.0011187371006597577
55 | 36 cigsale(1985) cigsale(1986) cigsale(1987) cigsale(1988)
56 | 36 1.2467606071326915e-05 0.000441784626549294 0.0005516730720741426 0.0011395891829565716
57 | 30 cigsale(1986) cigsale(1987) cigsale(1988)
58 | 30 0.0004292474615245855 0.00043906072446289914 0.0009789791169024593
59 | 12 cigsale(1986) cigsale(1987) cigsale(1988)
60 | 12 0.00015634756678617497 0.00046524598371104576 0.0010862701532716264
61 | 18 cigsale(1986) cigsale(1987) cigsale(1988)
62 | 18 0.0004345321289109919 0.00021028873110012252 0.0018020174936725004
63 | 38 cigsale(1986) cigsale(1987) cigsale(1988)
64 | 38 0.00040178596776896587 0.0005333706872139673 0.001150960005824483
65 | 35 cigsale(1977) cigsale(1986) cigsale(1987) cigsale(1988)
66 | 35 6.88134845640188e-06 0.0004183524461092219 0.00045105353956890554 0.0010166728003982875
67 | 13 cigsale(1986) cigsale(1987) cigsale(1988)
68 | 13 0.0003543644684414958 0.0003600256236189256 0.0009416488053131033
69 | 37 cigsale(1986) cigsale(1987) cigsale(1988)
70 | 37 0.00042587151170159475 0.0004661787252079279 0.0011675865092952683
71 | 29 cigsale(1986) cigsale(1987) cigsale(1988)
72 | 29 0.000481981960140239 0.0004486368134434901 0.0011979810228173145
73 | 20 cigsale(1977) cigsale(1986) cigsale(1987) cigsale(1988)
74 | 20 2.008220371182916e-05 0.0002263114136367005 0.0004515698795582304 0.001129838472899092
75 | 8 cigsale(1986) cigsale(1987) cigsale(1988)
76 | 8 0.0004636874855871064 0.0002868862186762217 0.0010817068639995256
77 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- encoding: utf-8 -*-
3 | from __future__ import absolute_import
4 | from __future__ import print_function
5 | import re
6 | from glob import glob
7 | from os.path import basename, dirname, join, splitext, abspath
8 | from setuptools import find_packages, setup
9 | import codecs
10 |
11 | # Allow single version in source file to be used here
12 | # From https://packaging.python.org/guides/single-sourcing-package-version/
def read(*parts):
    """Return the contents of the file at *parts*, joined relative to this
    file's directory.

    Used to single-source the package version and long description.
    """
    # intentionally *not* adding an encoding option to open
    # see here: https://github.com/pypa/virtualenv/issues/201#issuecomment-3145690
    here = abspath(dirname(__file__))
    # Context manager ensures the handle is closed promptly; the original
    # left it to garbage collection.
    with codecs.open(join(here, *parts), "r") as handle:
        return handle.read()
18 |
19 |
def find_version(*file_paths):
    """Extract the ``__version__`` string from the file at *file_paths*.

    Raises ``RuntimeError`` if no version assignment is found.
    """
    contents = read(*file_paths)
    match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", contents, re.M)
    if not match:
        raise RuntimeError("Unable to find version string.")
    return match.group(1)
26 |
27 |
# Package metadata and build configuration. The long description is the
# README (with any badge block stripped) followed by the CHANGELOG with
# Sphinx cross-reference roles rewritten as plain literals.
setup(
    name="SparseSC",
    version=find_version("src", "SparseSC", "__init__.py"),
    description="Sparse Synthetic Controls",
    license="MIT",
    long_description="%s\n%s"
    % (
        re.compile("^.. start-badges.*^.. end-badges", re.M | re.S).sub(
            "", read("README.md")
        ),
        re.sub(":[a-z]+:`~?(.*?)`", r"``\1``", read("CHANGELOG.md")),
    ),
    long_description_content_type="text/markdown",
    author="Microsoft Research",
    url="https://github.com/Microsoft/SparseSyntheticControls",
    packages=find_packages("src"),
    package_dir={"": "src"},
    py_modules=[splitext(basename(path))[0] for path in glob("src/*.py")],
    include_package_data=True,
    zip_safe=False,
    classifiers=[
        # complete classifier list: http://pypi.python.org/pypi?%3Aaction=list_classifiers
        "Development Status :: 5 - Production/Stable",
        "Intended Audience :: Developers",
        # Fixed: previously "BSD License", contradicting license="MIT" above.
        "License :: OSI Approved :: MIT License",
        "Operating System :: Unix",
        "Operating System :: POSIX",
        "Operating System :: Microsoft :: Windows",
        "Programming Language :: Python",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.4",
        "Programming Language :: Python :: 3.5",
        "Programming Language :: Python :: 3.6",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: Implementation :: CPython",
        "Programming Language :: Python :: Implementation :: PyPy",
        "Topic :: Utilities",
    ],
    keywords=["Sparse", "Synthetic", "Controls"],
    # "scipy" lower-cased to the canonical distribution name (PyPI resolves
    # names case-insensitively, so this is a consistency fix only).
    install_requires=["numpy", "scipy", "scikit-learn", "pandas", "pyyaml"],
    entry_points={
        "console_scripts": [
            "scgrad=SparseSC.cli.scgrad:main",
            "daemon=SparseSC.cli.daemon_process:main",
            "stt=SparseSC.cli.stt:main",
        ]
    },
)
76 |
--------------------------------------------------------------------------------
/src/SparseSC/SparseSC.pyproj:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | Debug
5 | 2.0
6 | {3fdc664a-c1c8-47f0-8d77-8e0679e53c82}
7 |
8 |
9 |
10 | ..\
11 | .
12 | .
13 | {888888a0-9f3d-457c-b088-3a5042f75d52}
14 | Standard Python launcher
15 |
16 |
17 |
18 |
19 |
20 | 10.0
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 | Code
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 | Code
52 |
53 |
54 | Code
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
--------------------------------------------------------------------------------
/src/SparseSC/__init__.py:
--------------------------------------------------------------------------------
1 | """ Public API for SparseSC
2 | """
3 |
4 | # PRIMARY FITTING FUNCTIONS
5 | from SparseSC.fit import fit, TrivialUnitsWarning
6 | from SparseSC.fit_fast import fit_fast, _fit_fast_inner, _fit_fast_match
7 | from SparseSC.utils.match_space import (
8 | keras_reproducible, MTLassoCV_MatchSpace_factory, MTLasso_MatchSpace_factory, MTLassoMixed_MatchSpace_factory, MTLSTMMixed_MatchSpace_factory,
9 | Fixed_V_factory, D_LassoCV_MatchSpace_factory
10 | )
11 | from SparseSC.utils.penalty_utils import RidgeCVSolution
12 | from SparseSC.fit_loo import loo_v_matrix, loo_weights, loo_score
13 | from SparseSC.fit_ct import ct_v_matrix, ct_weights, ct_score
14 |
15 | # ESTIMATION FUNCTIONS
16 | from SparseSC.estimate_effects import estimate_effects, get_c_predictions_honest
17 | from SparseSC.utils.dist_summary import SSC_DescrStat, Estimate
18 | from SparseSC.utils.descr_sets import DescrSet, MatchingEstimate
19 |
20 | # Public API
21 | from SparseSC.cross_validation import (
22 | score_train_test,
23 | score_train_test_sorted_v_pens,
24 | CV_score,
25 | )
26 | from SparseSC.tensor import tensor
27 | from SparseSC.weights import weights
28 | from SparseSC.utils.penalty_utils import get_max_w_pen, get_max_v_pen, w_pen_guestimate
29 |
30 | # The version as used in the setup.py
31 | __version__ = "0.2.0"
32 |
--------------------------------------------------------------------------------
/src/SparseSC/cli/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/src/SparseSC/cli/__init__.py
--------------------------------------------------------------------------------
/src/SparseSC/cli/daemon.py:
--------------------------------------------------------------------------------
1 | """Generic linux daemon base class for python 3.x."""
2 | # pylint: disable=multiple-imports
3 |
4 | import sys, os, time, atexit, signal
5 |
6 |
class Daemon:
    """A generic daemon class.

    Usage: subclass the daemon class and override the run() method.
    """

    def __init__(self, pidfile, fifofile):
        # pidfile: path of the file that holds the daemon's pid
        # fifofile: path of the named pipe used to communicate with the daemon
        self.pidfile = pidfile
        self.fifofile = fifofile

    def daemonize(self):
        """Deamonize class. UNIX double fork mechanism."""

        try:
            pid = os.fork()
            if pid > 0:
                # exit first parent
                sys.exit(0)
        except OSError as err:
            sys.stderr.write("fork #1 failed: {0}\n".format(err))
            sys.exit(1)

        # decouple from parent environment
        os.chdir("/")
        os.setsid()
        os.umask(0)

        # do second fork
        try:
            pid = os.fork()
            if pid > 0:
                # exit from second parent
                sys.exit(0)
        except OSError as err:
            sys.stderr.write("fork #2 failed: {0}\n".format(err))
            sys.exit(1)

        # redirect standard file descriptors to /dev/null
        sys.stdout.flush()
        sys.stderr.flush()
        si = open(os.devnull, "r")
        so = open(os.devnull, "a+")
        se = open(os.devnull, "a+")

        os.dup2(si.fileno(), sys.stdin.fileno())
        os.dup2(so.fileno(), sys.stdout.fileno())
        os.dup2(se.fileno(), sys.stderr.fileno())

        # clean up the pidfile/fifo when the daemon exits
        atexit.register(self.delpid)

        os.mkfifo(self.fifofile)  # pylint: disable=no-member

        # write pidfile
        pid = str(os.getpid())
        with open(self.pidfile, "w+") as f:
            f.write(pid + "\n")

        # send subsequent stdout/stderr to log files in the batch task's
        # working directory (falls back to /tmp outside Azure Batch)
        workdir = os.getenv("AZ_BATCH_TASK_WORKING_DIR", "/tmp")
        sys.stdout = open(os.path.join(workdir, "sc-out.txt"), "a+")
        sys.stderr = open(os.path.join(workdir, "sc-err.txt"), "a+")
        print("daemonized >>>>>>>>>>>"); sys.stdout.flush()

    def delpid(self):
        """Remove the pidfile and the fifo; raise RuntimeError on failure."""
        try:
            os.remove(self.pidfile)
        except OSError:  # narrowed from a bare except: os.remove raises OSError
            print("failed to remove pidfile"); sys.stdout.flush()
            raise RuntimeError("failed to remove pidfile")
        else:
            print("removed pidfile"); sys.stdout.flush()
        try:
            os.remove(self.fifofile)
        except OSError:
            print("failed to remove fifofile"); sys.stdout.flush()
            # BUGFIX: this message previously (incorrectly) said "pidfile"
            raise RuntimeError("failed to remove fifofile")
        else:
            print("removed fifofile"); sys.stdout.flush()

    def start(self):
        """Start the daemon."""

        # Check for a pidfile to see if the daemon already runs
        try:
            with open(self.pidfile, "r") as pf:
                pid = int(pf.read().strip())
        except IOError:
            pid = None

        if pid:
            message = "pidfile {0} already exist. " + "Daemon already running?\n"
            sys.stderr.write(message.format(self.pidfile))
            sys.exit(1)

        # Start the daemon
        self.daemonize()
        self.run()

    def stop(self):
        """Stop the daemon."""

        # Get the pid from the pidfile
        try:
            with open(self.pidfile, "r") as pf:
                pid = int(pf.read().strip())
        except IOError:
            pid = None

        if not pid:
            message = "pidfile {0} does not exist. " + "Daemon not running?\n"
            sys.stderr.write(message.format(self.pidfile))
            return  # not an error in a restart

        # Try killing the daemon process: keep signalling until the process
        # disappears, then clean up the pid/fifo files.
        try:
            while 1:
                os.kill(pid, signal.SIGTERM)
                time.sleep(0.1)
        except OSError as err:
            e = str(err.args)
            if e.find("No such process") > 0:
                if os.path.exists(self.pidfile):
                    os.remove(self.pidfile)
                if os.path.exists(self.fifofile):
                    os.remove(self.fifofile)
            else:
                print(str(err.args))
                sys.exit(1)

    def restart(self):
        """Restart the daemon."""
        self.stop()
        self.start()

    def run(self):
        """You should override this method when you subclass Daemon.

        It will be called after the process has been daemonized by
        start() or restart()."""
--------------------------------------------------------------------------------
/src/SparseSC/cli/daemon_process.py:
--------------------------------------------------------------------------------
1 | """
A module to run as a background process
3 |
4 | pip install psutil
5 |
6 | usage:
7 | python -m SparseSC.cli.daemon_process start
8 | python -m SparseSC.cli.daemon_process stop
9 | python -m SparseSC.cli.daemon_process status
10 | """
11 | # pylint: disable=multiple-imports
12 |
13 | import sys, os, time, atexit, signal, json,psutil
14 | from yaml import load, dump
15 | try:
16 | from yaml import CLoader as Loader, CDumper as Dumper
17 | except ImportError:
18 | from yaml import Loader, Dumper
19 |
20 | from .scgrad import grad_part, DAEMON_FIFO, DAEMON_PID, _BASENAMES
21 |
22 | #-- DAEMON_FIFO = "/tmp/sc-daemon.fifo"
23 | #-- DAEMON_PID = "/tmp/sc-gradient-daemon.pid"
24 | #--
25 | #-- _CONTAINER_OUTPUT_FILE = "output.yaml" # Standard Output file
26 | #-- _GRAD_COMMON_FILE = "common.yaml"
27 | #-- _GRAD_PART_FILE = "part.yaml"
28 | #--
29 | #-- _BASENAMES = [_GRAD_COMMON_FILE, _GRAD_PART_FILE, _CONTAINER_OUTPUT_FILE]
30 |
31 |
32 |
33 | pidfile, fifofile = DAEMON_PID, DAEMON_FIFO
34 |
35 |
def stop():
    """Stop the daemon."""

    # Read the daemon's pid from the pidfile, if one exists.
    try:
        with open(pidfile, "r") as pf:
            pid = int(pf.read().strip())
    except IOError:
        pid = None

    if not pid:
        message = "pidfile {0} does not exist. " + "Daemon not running?\n"
        sys.stderr.write(message.format(pidfile))
        return  # not an error in a restart

    # Signal the process repeatedly until it disappears, then clean up.
    try:
        while True:
            os.kill(pid, signal.SIGTERM)
            time.sleep(0.1)
    except OSError as err:
        description = str(err.args)
        if description.find("No such process") > 0:
            for stale in (pidfile, fifofile):
                if os.path.exists(stale):
                    os.remove(stale)
        else:
            print(description)
            sys.exit(1)
66 |
67 |
def run():
    """
    Daemon work loop: block on the fifo for job parameters, compute the
    requested gradient part, write the result to the job's output file, and
    report success ("0") or failure ("1") on the caller's return fifo.
    """
    while True:
        with open(DAEMON_FIFO, "r") as fifo:
            # BUGFIX: previously `return_fifo` was only bound inside the try,
            # so a failure before/inside json.loads made the except handler
            # raise NameError (or reuse a stale value from a prior iteration).
            return_fifo = None
            try:
                params = fifo.read()
                print("params: " + params)
                sys.stdout.flush()
                tmpdirname, return_fifo, k = json.loads(params)
                print(_BASENAMES)
                for file in os.listdir(tmpdirname):
                    print(file)
                common_file, part_file, out_file = [
                    os.path.join(tmpdirname, name) for name in _BASENAMES
                ]
                print([common_file, part_file, out_file, return_fifo, k])
                sys.stdout.flush()

                # LOAD IN THE INPUT FILES
                with open(common_file, "r") as fp:
                    common = load(fp, Loader=Loader)
                with open(part_file, "r") as fp:
                    part = load(fp, Loader=Loader)

                # DO THE WORK
                print("about to do work: ")
                sys.stdout.flush()
                grad = grad_part(common, part, int(k))
                print("did work: ")
                sys.stdout.flush()

                # DUMP THE RESULT TO THE OUTPUT FILE
                with open(out_file, "w") as fp:
                    fp.write(dump(grad, Dumper=Dumper))

            except Exception as err:  # pylint: disable=broad-except

                # SOMETHING WENT WRONG, RESPOND WITH A NON-ZERO (if we got far
                # enough to know where to respond)
                try:
                    if return_fifo is None:
                        raise RuntimeError("failed before the return fifo was known")
                    with open(return_fifo, "w") as rf:
                        rf.write("1")
                except:  # pylint: disable=bare-except
                    print("double failed...: ")
                    sys.stdout.flush()
                else:
                    # BUGFIX: this was print("failed with {}: {}", a, b), which
                    # printed the raw template plus the args instead of
                    # formatting them into the message.
                    print(
                        "failed with {}: {}".format(
                            err.__class__.__name__, getattr(err, "message", "<>")
                        )
                    )
                    sys.stdout.flush()

            else:
                # SEND THE SUCCESS RESPONSE
                print("success...: ")
                sys.stdout.flush()
                with open(return_fifo, "w") as rf:
                    rf.write("0")
                print("and wrote about it...: ")
                sys.stdout.flush()
130 |
131 |
def start():
    """
    Start the process daemon if it's not already started.
    """

    # IS THE DAEMON ALREADY RUNNING?
    try:
        with open(pidfile, "r") as pf:
            pid = int(pf.read().strip())
    except IOError:
        pid = None

    if pid:
        message = "pidfile {0} already exist. " + "Daemon already running?\n"
        sys.stderr.write(message.format(pidfile))
        return

    def delpid():
        """Remove the pid and fifo files; raise RuntimeError on failure."""
        try:
            os.remove(pidfile)
        except OSError:  # narrowed from a bare except: os.remove raises OSError
            print("failed to remove pidfile"); sys.stdout.flush()
            raise RuntimeError("failed to remove pidfile")
        else:
            print("removed pidfile"); sys.stdout.flush()
        try:
            os.remove(fifofile)
        except OSError:
            print("failed to remove fifofile"); sys.stdout.flush()
            # BUGFIX: this message previously (incorrectly) said "pidfile"
            raise RuntimeError("failed to remove fifofile")
        else:
            print("removed fifofile"); sys.stdout.flush()
    atexit.register(delpid)

    os.mkfifo(fifofile)  # pylint: disable=no-member

    # write this process's pid so stop()/status() can find it
    pid = str(os.getpid())
    with open(pidfile, "w+") as f:
        f.write(pid + "\n")

    # log to files in the batch task's working directory (or /tmp)
    workdir = os.getenv("AZ_BATCH_TASK_WORKING_DIR", "/tmp")
    sys.stdout = open(os.path.join(workdir, "sc-ps-out.txt"), "a+")
    sys.stderr = open(os.path.join(workdir, "sc-ps-err.txt"), "a+")
    print("process started >>>>>>>>>>>"); sys.stdout.flush()
    run()
181 |
182 |
def status():
    """
    Report whether the process daemon is running (based on the pidfile
    and the live pid list).
    """
    if not os.path.exists(DAEMON_PID):
        print("Daemon not running")
        return
    with open(DAEMON_PID, 'r') as fh:
        _pid = int(fh.read())

    if _pid in psutil.pids():
        print("daemon process (pid {}) is running".format(_pid))
    else:
        # BUGFIX: message previously read "NOT is running"
        print("daemon process (pid {}) is NOT running".format(_pid))
197 |
def main():
    """Dispatch the sub-command named on the command line."""
    args = sys.argv[1:]

    if not args:
        print("no args provided")
        return

    command = args[0]
    if command in ("trystart", "start"):
        start()
    elif command == "stop":
        stop()
    elif command == "status":
        status()
    else:
        print("unknown command '{}'".format(command))
213 |
214 |
215 | if __name__ == "__main__":
216 | main()
217 |
--------------------------------------------------------------------------------
/src/SparseSC/cli/stt.py:
--------------------------------------------------------------------------------
1 | """
Command-line entry point for scoring a single cross-validation batch task.
3 | """
4 | # pylint: disable=invalid-name, unused-import
5 | import sys
6 | import numpy
7 | from yaml import load, dump
8 |
9 | try:
10 | from yaml import CLoader as Loader, CDumper as Dumper
11 | except ImportError:
12 | from yaml import Loader, Dumper
13 |
14 | from SparseSC.cross_validation import score_train_test
15 |
16 |
def get_config(infile):
    """
    Read the contents of the inputs yaml file, normalizing the
    penalty parameters (scalar or sequence) to tuples.
    """
    with open(infile, "r") as fp:
        config = load(fp, Loader=Loader)

    def _as_tuple(value):
        # scalars are not iterable: tuple() raises TypeError, so wrap them
        try:
            return tuple(value)
        except TypeError:
            return (value,)

    v_pen = _as_tuple(config["v_pen"])
    w_pen = _as_tuple(config["w_pen"])
    return v_pen, w_pen, config
35 |
36 |
def main():
    """
    Score one (fold, v_pen, w_pen) combination, selected by the batch number
    given on the command line, and write the result as YAML.

    Expects three positional arguments: the input (config) file, the output
    file, and the batch number.
    """
    # GET THE COMMAND LINE ARGS (guard against an empty argv tail)
    ARGS = sys.argv[1:]
    if ARGS and ARGS[0] == "ssc.py":
        ARGS.pop(0)
    # BUGFIX: the message previously said "2 parameters" although three
    # arguments are required (and asserted) here.
    assert (
        len(ARGS) == 3
    ), "ssc.py expects 3 parameters: an input file, an output file, and a batch number"
    infile, outfile, batchNumber = ARGS
    batchNumber = int(batchNumber)

    v_pen, w_pen, config = get_config(infile)
    # one batch task per (fold, v_pen, w_pen) combination
    n_folds = len(config["folds"]) * len(v_pen) * len(w_pen)

    assert 0 <= batchNumber < n_folds, "Batch number out of range"
    # unravel the flat batch number: fold varies fastest, then v, then w
    i_fold = batchNumber % len(config["folds"])
    i_v = (batchNumber // len(config["folds"])) % len(v_pen)
    i_w = (batchNumber // len(config["folds"])) // len(v_pen)

    # everything else in the config is forwarded to score_train_test
    params = config.copy()
    del params["folds"]
    del params["v_pen"]
    del params["w_pen"]

    train, test = config["folds"][i_fold]
    out = score_train_test(
        train=train, test=test, v_pen=v_pen[i_v], w_pen=w_pen[i_w], **params
    )

    with open(outfile, "w") as fp:
        fp.write(
            dump(
                {
                    "batch": batchNumber,
                    "i_fold": i_fold,
                    "i_v": i_v,
                    "i_w": i_w,
                    "results": out,
                },
                Dumper=Dumper,
            )
        )
79 |
80 |
if __name__ == "__main__":
    try:
        main()
    except Exception as e:  # pylint: disable=broad-except
        # BUGFIX: was a bare except printing only the exception *class*
        # (sys.exc_info()[0]); narrowed so SystemExit/KeyboardInterrupt
        # propagate, and report the exception itself for a useful message.
        print("STT Error: %s" % e)
88 |
89 |
--------------------------------------------------------------------------------
/src/SparseSC/optimizers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/src/SparseSC/optimizers/__init__.py
--------------------------------------------------------------------------------
/src/SparseSC/optimizers/simplex_step.py:
--------------------------------------------------------------------------------
1 | """
2 | Gradient descent within the simplex
3 |
Method inspired by the question "where would a water droplet stuck inside the
positive simplex go when pulled in the direction of the gradient which would
otherwise take the droplet outside of the simplex"
7 | """
8 | # pylint: disable=invalid-name
9 | import numpy as np
10 | from numpy import (
11 | array,
12 | append,
13 | arange,
14 | cumsum,
15 | random,
16 | zeros,
17 | ones,
18 | logical_or,
19 | logical_not,
20 | logical_and,
21 | maximum,
22 | argmin,
23 | where,
24 | sort,
25 | )
26 |
27 |
28 | def _sub_simplex_project(d_hat, indx):
29 | """
30 | A utility function which projects the gradient into the subspace of the
31 | simplex which intersects the plane x[_index] == 0
32 | """
33 | # now project the gradient perpendicular to the edge we just came up against
34 | n = len(d_hat)
35 | _n = float(n)
36 | a_dot_a = (n - 1) / _n
37 | a_tilde = -ones(n) / _n
38 | a_tilde[indx] += 1 # plus a'
39 | proj_a_d = (d_hat.dot(a_tilde) / a_dot_a) * a_tilde
40 | d_tilde = d_hat - proj_a_d
41 | return d_tilde
42 |
43 |
def simplex_step(x, g, verbose=False):
    """
    Follow the gradient as far as you can within the positive simplex.

    :param x: current point (assumed to lie in the simplex)
    :param g: gradient/step direction (same length as x)
    :param verbose: print per-iteration diagnostics
    :returns: the new point after walking along (projections of) g
    """
    i = 0
    # work on copies so the caller's arrays are not mutated
    x, g = x.copy(), g.copy()
    # project the gradient into the simplex
    g = g - (g.sum() / len(x)) * ones(len(g))
    _g = g.copy()
    while True:
        if verbose:
            print("iter: %s, g: %s" % (i, g))
        # we can move in the direction of the gradient if either
        # (a) the gradient points away from the axis
        # (b) we're not yet touching the axis
        valid_directions = logical_or(g < 0, x > 0)
        if verbose:
            print(
                " valid_directions(%s, %s, %s): %s "
                % (
                    valid_directions.sum(),
                    (g < 0).sum(),
                    (x > 0).sum(),
                    ", ".join(str(x) for x in valid_directions),
                )
            )
        if not valid_directions.any():
            break
        if any(g[logical_not(valid_directions)] != 0):
            # TODO: make sure is is invariant on the order of operations
            n_valid = where(valid_directions)[0]
            W = where(logical_not(valid_directions))[0]
            # NOTE(review): this loop variable shadows the outer iteration
            # counter `i`, so the `i > len(x)` sanity check below operates on
            # whichever was assigned last — confirm this is intentional.
            for i, _w in enumerate(W):
                # TODO: Project the invalid directions into the current (valid) subspace of
                # the simplex
                mask = append(array(_w), n_valid)
                # print("work in progress")
                g[mask] = _sub_simplex_project(g[mask], 0)
                g[_w] = 0  # may not be exactly zero due to rounding error
        # rounding error can take us out of the simplex positive orthant:
        g = maximum(0, g)
        if (g == zeros(len(g))).all():
            # we've arrived at a corner and the gradient points outside the constrained simplex
            break
        # HOW FAR CAN WE GO? the binding constraint is the smallest x/g ratio
        # among coordinates that are shrinking toward their axis
        limit_directions = logical_and(valid_directions, g > 0)
        xl = x[limit_directions]
        gl = g[limit_directions]
        ratios = xl / gl
        try:
            c = ratios.min()
        except:
            # NOTE(review): debugger left in as the failure handler — the min()
            # of an empty `ratios` would raise here; consider replacing with a
            # proper error before production use.
            import pdb

            pdb.set_trace()
        if c > 1:
            # the full step stays inside the simplex: take it and stop
            x = x - g
            # pdb.set_trace()
            break
        arange(len(g))
        indx = argmin(ratios)
        # MOVE
        # there's gotta be a better way...
        # translate `indx` (an index into the limited subset) into `_indx`
        # (an index into the full vector) and `__indx` (an index into the
        # valid-directions subset)
        _indx = where(limit_directions)[0][indx]
        tmp = -ones(len(x))
        tmp[valid_directions] = arange(valid_directions.sum())
        __indx = int(tmp[_indx])
        # get the index
        del xl, gl, ratios
        # walk to the face we hit, then project the remaining step onto it
        x = x - c * g
        # PROJECT THE GRADIENT
        d_tilde = _sub_simplex_project(g[valid_directions] * (1 - c), __indx)
        if verbose:
            print(
                "i: %s, which: %s, g.sum(): %f, x.sum(): %f, x[i]: %f, g[i]: %f, d_tilde[i]: %f"
                % (
                    i,
                    indx,
                    g.sum(),
                    x.sum(),
                    x[valid_directions][__indx],
                    g[valid_directions][__indx],
                    d_tilde[__indx],
                )
            )
        g[valid_directions] = d_tilde
        # handle rounding error...
        x[_indx] = 0
        g[_indx] = 0
        # INCREMENT THE COUNTER
        i += 1
        if i > len(x):
            raise RuntimeError("something went wrong")
    return x
138 |
def simplex_step_proj_sort(x, g, verbose=False):
    """Take the full gradient step, then project the result back onto the simplex."""
    return simplex_proj_sort(x - g)
142 |
143 | #There's a fast version which uses the median finding algorithm rather than full sorting, but more complicated
144 | #See https://en.wikipedia.org/wiki/Simplex#Projection_onto_the_standard_simplex
145 | # and https://gist.github.com/mblondel/6f3b7aaad90606b98f71
def simplex_proj_sort(v, verbose=False):
    """
    Euclidean projection of v onto the standard (probability) simplex,
    using the sort-based algorithm.
    """
    k = v.shape[0]
    # a 1-dimensional simplex is the single point {1}
    if k == 1:
        return np.array([1])

    descending = sort(v)[::-1]  # largest first
    ranks = arange(1, k + 1)  # 1-indexed positions
    candidates = (cumsum(descending) - 1) / ranks
    # rho: the largest (1-indexed) rank whose coordinate stays positive
    rho = ranks[(descending - candidates) > 0][-1]
    threshold = candidates[rho - 1]  # back to 0-indexing
    return maximum(v - threshold, 0)
159 |
--------------------------------------------------------------------------------
/src/SparseSC/tensor.py:
--------------------------------------------------------------------------------
1 | """
2 | Calculates the tensor (V) matrix which puts the metric on the covariate space
3 | """
4 | from SparseSC.fit_fold import fold_v_matrix
5 | from SparseSC.fit_loo import loo_v_matrix
6 | from SparseSC.fit_ct import ct_v_matrix
7 | import numpy as np
8 |
9 |
def tensor(X, Y, X_treat=None, Y_treat=None, grad_splits=None, **kwargs):
    """ Presents a unified api for ct_v_matrix and loo_v_matrix

    :param X: features, one row per unit; coercible to float64
    :param Y: targets, one row per unit; coercible to float64
    :param X_treat: features for the treated units (supply both or neither
        of X_treat / Y_treat)
    :param Y_treat: targets for the treated units
    :param grad_splits: when fitting controls to themselves, fit the V
        matrix with gradient folds instead of leave-one-out
    :param kwargs: forwarded to ct_v_matrix / fold_v_matrix / loo_v_matrix
    :returns: the fitted V matrix
    :raises ValueError: on non-numeric or dimensionally inconsistent inputs
    """
    # PARAMETER QC
    try:
        X = np.float64(X)
    except ValueError:
        raise ValueError("X is not coercible to float64")
    try:
        Y = np.float64(Y)
    except ValueError:
        raise ValueError("Y is not coercible to float64")

    Y = np.asmatrix(Y)  # this needs to be deprecated properly -- bc Array.dot(Array) != matrix(Array).dot(matrix(Array)) -- not even close !!!
    X = np.asmatrix(X)

    if X.shape[1] == 0:
        raise ValueError("X.shape[1] == 0")
    if Y.shape[1] == 0:
        raise ValueError("Y.shape[1] == 0")
    if X.shape[0] != Y.shape[0]:
        raise ValueError(
            "X and Y have different number of rows (%s and %s)"
            % (X.shape[0], Y.shape[0])
        )

    if (X_treat is None) != (Y_treat is None):
        raise ValueError(
            "parameters `X_treat` and `Y_treat` must both be Matrices or None"
        )

    if X_treat is not None:
        # Fit the Treated units to the control units; assuming that Y contains
        # pre-intervention outcomes:

        # PARAMETER QC
        try:
            X_treat = np.float64(X_treat)
        except ValueError:
            raise ValueError("X_treat is not coercible to float64")
        try:
            Y_treat = np.float64(Y_treat)
        except ValueError:
            raise ValueError("Y_treat is not coercible to float64")

        Y_treat = np.asmatrix(Y_treat)  # this needs to be deprecated properly -- bc Array.dot(Array) != matrix(Array).dot(matrix(Array)) -- not even close !!!
        X_treat = np.asmatrix(X_treat)

        if X_treat.shape[1] == 0:
            raise ValueError("X_treat.shape[1] == 0")
        if Y_treat.shape[1] == 0:
            raise ValueError("Y_treat.shape[1] == 0")
        if X_treat.shape[0] != Y_treat.shape[0]:
            raise ValueError(
                "X_treat and Y_treat have different number of rows (%s and %s)"
                % (X_treat.shape[0], Y_treat.shape[0])
            )

        # FIT THE V-MATRIX AND POSSIBLY CALCULATE THE w_pen
        # note that the weights, score, and loss function value returned here
        # are for the in-sample predictions
        _, v_mat, _, _, _, _ = ct_v_matrix(
            X=np.vstack((X, X_treat)),
            Y=np.vstack((Y, Y_treat)),
            control_units=np.arange(X.shape[0]),
            treated_units=np.arange(X_treat.shape[0]) + X.shape[0],
            **kwargs
        )

    else:
        # Fit the control units to themselves; Y may contain post-intervention outcomes:
        # BUGFIX: removed the unused `adjusted` flag, which read
        # kwargs["w_pen"] unconditionally and therefore raised KeyError
        # whenever w_pen was not supplied; it fed nothing but a
        # commented-out debug print.
        if grad_splits is not None:
            _, v_mat, _, _, _, _ = fold_v_matrix(
                X=X,
                Y=Y,
                control_units=np.arange(X.shape[0]),
                treated_units=np.arange(X.shape[0]),
                grad_splits=grad_splits,
                **kwargs
            )
        else:
            _, v_mat, _, _, _, _ = loo_v_matrix(
                X=X,
                Y=Y,
                control_units=np.arange(X.shape[0]),
                treated_units=np.arange(X.shape[0]),
                **kwargs
            )
    return v_mat
108 |
--------------------------------------------------------------------------------
/src/SparseSC/utils/AzureBatch/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | AzureBatch module
3 | """
4 | # from .gradient_batch_client import gradient_batch_client # abandon for now
5 | DOCKER_IMAGE_NAME = "jdthorpe/sparsesc:latest"
6 | from .aggregate_results import aggregate_batch_results
7 | from .build_batch_job import create_job
8 |
--------------------------------------------------------------------------------
/src/SparseSC/utils/AzureBatch/aggregate_results.py:
--------------------------------------------------------------------------------
1 | """
2 | aggregate batch results, and optionally use batch to compute the final gradient descent.
3 | """
4 | import numpy as np
5 | import os
6 |
7 | # $ from ...cross_validation import _score_from_batch
8 | from ...fit import _which, SparseSCFit
9 | from ...weights import weights
10 | from ...tensor import tensor
11 | from .constants import _BATCH_CV_FILE_NAME, _BATCH_FIT_FILE_NAME
12 |
13 |
def aggregate_batch_results(batchDir, batch_client_config=None, choice=None):
    """
    Aggregate results from a batch run.

    :param batchDir: local directory holding the batch parameter files
        (cv/fit parameter YAML) and the per-fold result files
    :param batch_client_config: optional batch-client configuration,
        forwarded to ``tensor`` for the final fit
    :param choice: penalty-selection rule; when None, the rule stored in
        the fit parameters file is used
    :returns: a fitted ``SparseSCFit``
    """

    from yaml import load

    # prefer the C-accelerated loader when available
    try:
        from yaml import CLoader as Loader
    except ImportError:
        from yaml import Loader

    # read back the parameters that were serialized when the job was built
    with open(os.path.join(batchDir, _BATCH_CV_FILE_NAME), "r") as fp:
        _cv_params = load(fp, Loader=Loader)

    with open(os.path.join(batchDir, _BATCH_FIT_FILE_NAME), "r") as fp:
        _fit_params = load(fp, Loader=Loader)

    # https://stackoverflow.com/a/17074606/1519199
    pluck = lambda d, *args: (d[arg] for arg in args)

    X_cv, Y_cv, grad_splits, random_state, v_pen, w_pen = pluck(
        _cv_params, "X", "Y", "grad_splits", "random_state", "v_pen", "w_pen"
    )

    choice = choice if choice is not None else _fit_params["choice"]
    X, Y, treated_units, custom_donor_pool, model_type, kwargs = pluck(
        _fit_params,
        "X",
        "Y",
        "treated_units",
        "custom_donor_pool",
        "model_type",
        "kwargs",
    )

    # this is on purpose (allows for debugging remote sessions at no cost to the local console user)
    kwargs["print_path"] = 1

    # Phase 1: collect the cross-validation scores from the per-fold files
    scores, scores_se = _score_from_batch(batchDir, _cv_params)

    # determine whether each penalty was supplied as a grid (iterable) or a scalar
    try:
        iter(w_pen)
    except TypeError:
        w_pen_is_iterable = False
    else:
        w_pen_is_iterable = True

    try:
        iter(v_pen)
    except TypeError:
        v_pen_is_iterable = False
    else:
        v_pen_is_iterable = True

    # GET THE INDEX OF THE BEST SCORE
    def _choose(scores, scores_se):
        """ helper function which implements the choice of covariate weights penalty parameter

        Nested here for access to v_pen, w_pe,n w_pen_is_iterable and
        v_pen_is_iterable, and choice, via Lexical Scoping
        """
        # GET THE INDEX OF THE BEST SCORE
        if w_pen_is_iterable:
            indx = _which(scores, scores_se, choice)
            return v_pen, w_pen[indx], scores[indx], indx
        if v_pen_is_iterable:
            indx = _which(scores, scores_se, choice)
            return v_pen[indx], w_pen, scores[indx], indx
        # neither penalty is a grid: nothing to choose
        return v_pen, w_pen, scores, None

    best_v_pen, best_w_pen, score, which = _choose(scores, scores_se)

    # --------------------------------------------------
    # Phase 2: extract V and weights: slow ( tens of seconds to minutes )
    # --------------------------------------------------
    # NOTE(review): assumes X and Y are array-like and row-indexable by unit
    # lists — confirm against the serialization done by fit().
    if treated_units is not None:
        control_units = [u for u in range(Y.shape[0]) if u not in treated_units]
        Xtrain = X[control_units, :]
        Xtest = X[treated_units, :]
        Ytrain = Y[control_units, :]
        Ytest = Y[treated_units, :]
    else:
        control_units = None

    if model_type == "prospective-restricted":
        # fit V against the treated units' pre-period outcomes
        best_V = tensor(
            X=X_cv,
            Y=Y_cv,
            w_pen=best_w_pen,
            v_pen=best_v_pen,
            #
            X_treat=Xtest,
            Y_treat=Ytest,
            #
            batch_client_config=batch_client_config,  # TODO: not sure if this makes sense...
            **_fit_params["kwargs"]
        )
    else:
        # fit the controls to themselves (fold or leave-one-out gradients)
        best_V = tensor(
            X=X_cv,
            Y=Y_cv,
            w_pen=best_w_pen,
            v_pen=best_v_pen,
            #
            grad_splits=grad_splits,
            random_state=random_state,
            #
            batch_client_config=batch_client_config,
            **_fit_params["kwargs"]
        )

    if treated_units is not None:

        # GET THE BEST SET OF WEIGHTS: treated rows are fit against the
        # controls, control rows against each other
        sc_weights = np.empty((X.shape[0], Ytrain.shape[0]))
        if custom_donor_pool is None:
            custom_donor_pool_t = None
            custom_donor_pool_c = None
        else:
            custom_donor_pool_t = custom_donor_pool[treated_units, :]
            custom_donor_pool_c = custom_donor_pool[control_units, :]
        sc_weights[treated_units, :] = weights(
            Xtrain,
            Xtest,
            V=best_V,
            w_pen=best_w_pen,
            custom_donor_pool=custom_donor_pool_t,
        )
        sc_weights[control_units, :] = weights(
            Xtrain, V=best_V, w_pen=best_w_pen, custom_donor_pool=custom_donor_pool_c
        )

    else:
        # GET THE BEST SET OF WEIGHTS
        sc_weights = weights(
            X, V=best_V, w_pen=best_w_pen, custom_donor_pool=custom_donor_pool
        )

    return SparseSCFit(
        features=X,
        targets=Y,
        control_units=control_units,
        treated_units=treated_units,
        model_type=model_type,
        # fitting parameters
        fitted_v_pen=best_v_pen,
        fitted_w_pen=best_w_pen,
        initial_w_pen=w_pen,
        initial_v_pen=v_pen,
        V=best_V,
        # Fitted Synthetic Controls
        sc_weights=sc_weights,
        score=score,
        scores=scores,
        selected_score=which,
    )
171 |
172 |
def _score_from_batch(batchDir, config):
    """
    Read in the results from a batch run.

    :param batchDir: directory containing the per-fold ``fold_<i>.yaml`` files
    :param config: the cross-validation parameters dict (``folds``,
        ``v_pen``, ``w_pen``, ...)
    :returns: tuple of (total score per penalty value, standard error per
        penalty value)
    """
    from yaml import load

    # prefer the C-accelerated loader when available
    try:
        from yaml import CLoader as Loader
    except ImportError:
        from yaml import Loader

    # normalize scalar penalties to 1-tuples
    try:
        v_pen = tuple(config["v_pen"])
    except TypeError:
        v_pen = (config["v_pen"],)

    try:
        w_pen = tuple(config["w_pen"])
    except TypeError:
        w_pen = (config["w_pen"],)

    # one result file per (fold, v_pen, w_pen) combination
    n_folds = len(config["folds"]) * len(v_pen) * len(w_pen)
    # NOTE(review): this assumes at most one of v_pen / w_pen has length > 1;
    # if both are grids, `i_pen` below ranges up to len(v_pen)*len(w_pen)-1
    # and would over-index the (n_pens, n_cv_folds) `scores` array — confirm.
    n_pens = np.max((len(v_pen), len(w_pen)))
    n_cv_folds = n_folds // n_pens

    scores = np.empty((n_pens, n_cv_folds))
    for i in range(n_folds):
        # i_fold, i_v, i_w = pluck(res, "i_fold", "i_v", "i_w", )
        # fold index varies fastest in the flat batch numbering
        i_fold = i % len(config["folds"])
        i_pen = i // len(config["folds"])
        with open(os.path.join(batchDir, "fold_{}.yaml".format(i)), "r") as fp:
            res = load(fp, Loader=Loader)
        assert (
            res["batch"] == i
        ), "Batch File Import Error Inconsistent batch identifiers"
        scores[i_pen, i_fold] = res["results"][2]

    # TODO: np.sqrt(len(scores)) * np.std(scores) is a quick and dirty hack for
    # calculating the standard error of the sum from the partial sums. It's
    # assumes the samples are equal size and randomly allocated (which is true
    # in the default settings). However, it could be made more formal with a
    # fixed effects framework, and leveraging the individual errors.
    # https://stats.stackexchange.com/a/271223/67839

    # NOTE(review): both v_pen and w_pen are always non-empty tuples here, so
    # this condition is always true and the `else` branch is unreachable.
    if len(v_pen) > 0 or len(w_pen):
        n_pens = np.max((len(v_pen), len(w_pen)))
        n_cv_folds = n_folds // n_pens
        total_score = scores.sum(axis=1)
        se = np.sqrt(n_cv_folds) * scores.std(axis=1)
    else:
        total_score = sum(scores)
        se = np.sqrt(len(scores)) * np.std(scores)

    return total_score, se
227 |
228 |
--------------------------------------------------------------------------------
/src/SparseSC/utils/AzureBatch/build_batch_job.py:
--------------------------------------------------------------------------------
1 | """
2 | USAGE:
3 |
4 | python
5 | create_job()
6 | """
7 | from __future__ import print_function
8 | from os.path import join
9 | import super_batch
10 | from SparseSC.cli.stt import get_config
11 | from .constants import (
12 | _CONTAINER_OUTPUT_FILE,
13 | _CONTAINER_INPUT_FILE,
14 | _BATCH_CV_FILE_NAME,
15 | )
16 |
17 | LOCAL_OUTPUTS_PATTERN = "fold_{}.yaml"
18 |
19 |
def create_job(client: super_batch.Client, batch_dir: str) -> None:
    r"""
    Build the Azure Batch job: one task per cross-validation fold.

    :param client: A :class:`super_batch.Client` instance with the Azure Batch run parameters
    :type client: :class:super_batch.Client

    :param str batch_dir: path of the local batch temp directory
    """
    local_input_file = join(batch_dir, _BATCH_CV_FILE_NAME)

    # One task is created per (fold, v_pen, w_pen) combination.
    v_pen, w_pen, model_data = get_config(local_input_file)
    task_count = len(model_data["folds"]) * len(v_pen) * len(w_pen)

    # CREATE THE COMMON INPUT FILE RESOURCE (shared by every task)
    input_resource = client.build_resource_file(
        local_input_file, _CONTAINER_INPUT_FILE
    )

    for fold in range(task_count):
        # BUILD THE COMMAND LINE: run the `stt` CLI on this fold
        command_line = "/bin/bash -c 'stt {} {} {}'".format(
            _CONTAINER_INPUT_FILE, _CONTAINER_OUTPUT_FILE, fold
        )

        # CREATE AN OUTPUT RESOURCE: each task writes its own fold_N.yaml
        output_resource = client.build_output_file(
            _CONTAINER_OUTPUT_FILE, LOCAL_OUTPUTS_PATTERN.format(fold)
        )

        # CREATE A TASK
        client.add_task([input_resource], [output_resource], command_line=command_line)
51 |
--------------------------------------------------------------------------------
/src/SparseSC/utils/AzureBatch/constants.py:
--------------------------------------------------------------------------------
1 | """
2 | CONSTANTS USED BY THIS MODULE
3 | """
4 |
5 | _DOCKER_IMAGE = "jdthorpe/sparsesc"
6 | _STANDARD_OUT_FILE_NAME = "stdout.txt" # Standard Output file
7 | _CONTAINER_OUTPUT_FILE = "output.yaml" # Standard Output file
8 | _CONTAINER_INPUT_FILE = "input.yaml" # Standard Output file
9 | _BATCH_CV_FILE_NAME = "cv_parameters.yaml"
10 | _BATCH_FIT_FILE_NAME = "fit_parameters.yaml"
11 | _GRAD_COMMON_FILE = "common.yaml"
12 | _GRAD_PART_FILE = "part.yaml"
13 |
14 |
--------------------------------------------------------------------------------
/src/SparseSC/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/src/SparseSC/utils/__init__.py
--------------------------------------------------------------------------------
/src/SparseSC/utils/batch_gradient.py:
--------------------------------------------------------------------------------
1 | """
2 | utilities for TESTING gradient descent in batches
3 | """
4 | import os
5 | import numpy as np
6 | import subprocess
7 |
8 |
def single_grad_cli(
    tmpdir, N0, N1, in_controls, splits, b_i, w_pen, treated_units, Y_treated, Y_control
):
    """
    CLI-backed variant of ``single_grad`` used for testing: the common
    parameters are serialized to YAML once, and the returned closure shells
    out to the ``scgrad`` command for each gradient component.

    :param tmpdir: directory in which the temporary YAML files are written
    :param N0: number of control units
    :param N1: number of treated units
    :param in_controls: per-fold index arrays of control units
    :param splits: per-fold (train, test) index pairs
    :param b_i: per-fold weight vectors
    :param w_pen: weight penalty (forwarded to the CLI)
    :param treated_units: indices of the treated units
    :param Y_treated: outcome matrix for treated units
    :param Y_control: outcome matrix for control units
    :returns: a function ``inner(A, weights, dA_dV_ki_k, dB_dV_ki_k)``
        yielding one gradient component
    """
    from yaml import load, dump

    try:
        from yaml import CLoader as Loader, CDumper as Dumper
    except ImportError:
        from yaml import Loader, Dumper

    _common_params = {
        "N0": N0,
        "N1": N1,
        "in_controls": in_controls,
        "splits": splits,
        "b_i": b_i,
        "w_pen": w_pen,
        "treated_units": treated_units,
        "Y_treated": Y_treated,
        "Y_control": Y_control,
    }

    COMMONFILE = os.path.join(tmpdir, "commonfile.yaml")
    PARTFILE = os.path.join(tmpdir, "partfile.yaml")
    OUTFILE = os.path.join(tmpdir, "outfile.yaml")

    # The common parameters are written once; only the part file changes per call.
    with open(COMMONFILE, "w") as fp:
        fp.write(dump(_common_params, Dumper=Dumper))

    def inner(A, weights, dA_dV_ki_k, dB_dV_ki_k):
        """
        Calculate a single component of the gradient
        """
        _local_params = {
            "A": A,
            "weights": weights,
            "dA_dV_ki_k": dA_dV_ki_k,
            "dB_dV_ki_k": dB_dV_ki_k,
        }
        with open(PARTFILE, "w") as fp:
            fp.write(dump(_local_params, Dumper=Dumper))

        # BUG FIX: the exit status was previously ignored, so a failed
        # `scgrad` run would silently re-read a stale or missing OUTFILE.
        subprocess.run(["scgrad", COMMONFILE, PARTFILE, OUTFILE], check=True)

        with open(OUTFILE, "r") as fp:
            val = load(fp, Loader=Loader)
        return val

    return inner
61 |
62 |
def single_grad(
    N0, N1, in_controls, splits, b_i, w_pen, treated_units, Y_treated, Y_control
):
    """
    Build a closure that computes a single component of the gradient.

    :param N0: number of control units
    :param N1: number of treated units
    :param in_controls: per-fold index arrays of control units
    :param splits: per-fold (train, test) index pairs
    :param b_i: per-fold weight vectors
    :param w_pen: weight penalty (used only for the error hint on singular A)
    :param treated_units: indices of the treated units
    :param Y_treated: outcome matrix for treated units
    :param Y_control: outcome matrix for control units
    :returns: a function ``inner(A, weights, dA_dV_ki_k, dB_dV_ki_k)``
    """
    # Pre-compute the (rows, cols) index pairs selecting each fold's
    # control-by-control sub-matrix of A.
    control_blocks = [np.ix_(idx, idx) for idx in in_controls]

    def inner(A, weights, dA_dV_ki_k, dB_dV_ki_k):
        """
        Calculate a single component of the gradient
        """
        dPI_dV = np.zeros((N0, N1))  # stupid notation: PI = W.T
        for fold, (_, split) in enumerate(zip(in_controls, splits)):
            _, test = split
            rhs = dB_dV_ki_k[fold] - dA_dV_ki_k[fold].dot(b_i[fold])
            try:
                b = np.linalg.solve(A[control_blocks[fold]], rhs)
            except np.linalg.LinAlgError as exc:
                print("Unique weights not possible.")
                if w_pen == 0:
                    print("Try specifying a very small w_pen rather than 0.")
                raise exc
            dPI_dV[np.ix_(in_controls[fold], treated_units[test])] = b
        # einsum is faster than the equivalent (Ey * Y_control.T.dot(dPI_dV).T.getA()).sum()
        residual = weights.T.dot(Y_control) - Y_treated
        return 2 * np.einsum("ij,kj,ki->", residual, Y_control, dPI_dV)

    return inner
94 |
--------------------------------------------------------------------------------
/src/SparseSC/utils/descr_sets.py:
--------------------------------------------------------------------------------
1 | """Store the typical information for a match.
2 | """
3 | from collections import namedtuple
4 |
5 | from SparseSC.utils.dist_summary import SSC_DescrStat
6 |
MatchingEstimate = namedtuple(
    'MatchingEstimate', 'att_est att_debiased_est atut_est atut_debiased_est ate_est ate_debiased_est aa_est naive_est')
# Field meanings:
# aa = difference between control counterfactuals for controls and controls (Y_c_cf_c - Y_c)
# "debiased" = the aa estimate is subtracted from the raw estimate
# att = Average Treatment effect on the Treated (Y_t - Y_t_cf_c)
# atut = Average Treatment effect on the UnTreated (Y_c_cf_t - Y_c)
# ate = Average Treatment Effect (pooling the att and atut samples)
# naive = plain Y_t vs Y_c comparison (no matching); helpful for gauging selection size
17 |
class DescrSet:
    """Holds potential distribution summaries for the various data used for matching.

    Naming convention: Y_t / Y_c are outcomes for treated / control units; a
    cf_c / cf_t suffix marks (control / treated) counterfactuals; "diff"
    entries summarize element-wise differences used by the estimators.
    """

    def __init__(self, descr_Y_t=None, descr_Y_t_cf_c=None, descr_Y_diff_t_cf_c=None,
                 descr_Y_c=None, descr_Y_diff_t_c=None, descr_Y_c_cf_c=None, descr_Y_diff_c_cf_c=None,
                 descr_Y_c_cf_t=None, descr_Y_diff_c_cf_t=None):
        """Generate the common descriptive stats from data and differences
        :param descr_Y_t: SSC_DescrStat for Y_t (outcomes for treated units)
        :param descr_Y_t_cf_c: SSC_DescrStat for Y_t_cf_c (outcomes for (control) counterfactuals of treated units)
        :param descr_Y_diff_t_cf_c: SSC_DescrStat of Y_t - Y_t_cf_c
        :param descr_Y_c: SSC_DescrStat of Y_c (outcomes for control units)
        :param descr_Y_diff_t_c: SSC_DescrStat for Y_t-Y_c
        :param descr_Y_c_cf_c: SSC_DescrStat for Y_c_cf_c (outcomes for (control) counterfactuals of control units)
        :param descr_Y_diff_c_cf_c: SSC_DescrStat for Y_c - Y_c_cf_c
        :param descr_Y_c_cf_t: SSC_DescrStat for Y_c_cf_t (outcomes for (TREATED) counterfactuals of control units)
        :param descr_Y_diff_c_cf_t: SSC_DescrStat for Y_c_cf_t - Y_c
        """
        self.descr_Y_t = descr_Y_t
        self.descr_Y_t_cf_c = descr_Y_t_cf_c
        self.descr_Y_diff_t_cf_c = descr_Y_diff_t_cf_c

        self.descr_Y_c = descr_Y_c
        self.descr_Y_diff_t_c = descr_Y_diff_t_c
        self.descr_Y_c_cf_c = descr_Y_c_cf_c
        self.descr_Y_diff_c_cf_c = descr_Y_diff_c_cf_c

        self.descr_Y_c_cf_t = descr_Y_c_cf_t
        self.descr_Y_diff_c_cf_t = descr_Y_diff_c_cf_t

    def __repr__(self):
        # BUG FIX: the template previously lacked the ", " separator between
        # descr_Y_c=%s and descr_Y_diff_t_c=%s, producing a malformed repr.
        return ("%s(descr_Y_t=%s, descr_Y_t_cf_c=%s, descr_Y_diff_t_cf_c=%s, descr_Y_c=%s, " +
                "descr_Y_diff_t_c=%s, descr_Y_c_cf_c=%s, descr_Y_diff_c_cf_c=%s, descr_Y_c_cf_t=%s," +
                " descr_Y_diff_c_cf_t=%s)") % (self.__class__.__name__, self.descr_Y_t,
                self.descr_Y_t_cf_c, self.descr_Y_diff_t_cf_c, self.descr_Y_c,
                self.descr_Y_diff_t_c, self.descr_Y_c_cf_c, self.descr_Y_diff_c_cf_c,
                self.descr_Y_c_cf_t, self.descr_Y_diff_c_cf_t)

    @staticmethod
    def from_data(Y_t=None, Y_t_cf_c=None, Y_c=None, Y_c_cf_c=None, Y_c_cf_t=None):
        """Generate the common descriptive stats from data and differences
        :param Y_t: np.array of dim=(N_t, T) for outcomes for treated units
        :param Y_t_cf_c: np.array of dim=(N_t, T) for outcomes for (control) counterfactuals of treated units (used to get the average treatment effect on the treated (ATT))
        :param Y_c: np.array of dim=(N_c, T) for outcomes for control units
        :param Y_c_cf_c: np.array of dim=(N_c, T) for outcomes for (control) counterfactuals of control units (used for AA test)
        :param Y_c_cf_t: np.array of dim=(N_c, T) for outcomes for (TREATED) counterfactuals of control units (used to calculate average treatment effect on the untreated (ATUT), or pooled with ATT to get the average treatment effect (ATE))
        :returns: DescrSet
        """
        # Note: While possible, there's no real use for treated matched to other treateds.
        def _gen_if_valid(Y):
            # Summarize Y when provided; propagate None otherwise.
            return SSC_DescrStat.from_data(Y) if Y is not None else None

        def _gen_diff_if_valid(Y1, Y2):
            # Summarize Y1-Y2 only when both pieces are available.
            return SSC_DescrStat.from_data(Y1-Y2) if (Y1 is not None and Y2 is not None) else None

        descr_Y_t = _gen_if_valid(Y_t)
        descr_Y_t_cf_c = _gen_if_valid(Y_t_cf_c)
        descr_Y_diff_t_cf_c = _gen_diff_if_valid(Y_t, Y_t_cf_c)

        descr_Y_c = _gen_if_valid(Y_c)
        descr_Y_diff_t_c = _gen_diff_if_valid(Y_t, Y_c)
        descr_Y_c_cf_c = _gen_if_valid(Y_c_cf_c)
        descr_Y_diff_c_cf_c = _gen_diff_if_valid(Y_c_cf_c, Y_c)
        descr_Y_c_cf_t = _gen_if_valid(Y_c_cf_t)
        descr_Y_diff_c_cf_t = _gen_diff_if_valid(Y_c_cf_t, Y_c)

        return DescrSet(descr_Y_t, descr_Y_t_cf_c, descr_Y_diff_t_cf_c,
                        descr_Y_c, descr_Y_diff_t_c, descr_Y_c_cf_c, descr_Y_diff_c_cf_c,
                        descr_Y_c_cf_t, descr_Y_diff_c_cf_t)

    def __add__(self, other):
        """Combine two DescrSets component-wise (None when either side is missing)."""
        def _add_if_valid(a, b):
            return a+b if (a is not None and b is not None) else None
        return DescrSet(descr_Y_t=_add_if_valid(self.descr_Y_t, other.descr_Y_t),
                        descr_Y_t_cf_c=_add_if_valid(self.descr_Y_t_cf_c, other.descr_Y_t_cf_c),
                        descr_Y_diff_t_cf_c=_add_if_valid(self.descr_Y_diff_t_cf_c, other.descr_Y_diff_t_cf_c),
                        descr_Y_c=_add_if_valid(self.descr_Y_c, other.descr_Y_c),
                        descr_Y_diff_t_c=_add_if_valid(self.descr_Y_diff_t_c, other.descr_Y_diff_t_c),
                        descr_Y_c_cf_c=_add_if_valid(self.descr_Y_c_cf_c, other.descr_Y_c_cf_c),
                        descr_Y_diff_c_cf_c=_add_if_valid(self.descr_Y_diff_c_cf_c, other.descr_Y_diff_c_cf_c),
                        descr_Y_c_cf_t=_add_if_valid(self.descr_Y_c_cf_t, other.descr_Y_c_cf_t),
                        descr_Y_diff_c_cf_t=_add_if_valid(self.descr_Y_diff_c_cf_t, other.descr_Y_diff_c_cf_t))

    def calc_estimates(self):
        """ Takes matrices of effects for multiple events and return averaged results
        :returns: MatchingEstimate (entries are None when inputs are missing)
        """
        def _calc_estimate(descr_stat1, descr_stat2):
            # t-test of the difference in means; None when either side missing.
            if descr_stat1 is None or descr_stat2 is None:
                return None
            return SSC_DescrStat.lcl_comp_means(descr_stat1, descr_stat2)

        att_est = _calc_estimate(self.descr_Y_t, self.descr_Y_t_cf_c)
        att_debiased_est = _calc_estimate(self.descr_Y_diff_t_cf_c, self.descr_Y_diff_c_cf_c)

        atut_est = _calc_estimate(self.descr_Y_c_cf_t, self.descr_Y_c)
        atut_debiased_est = _calc_estimate(self.descr_Y_diff_c_cf_t, self.descr_Y_diff_c_cf_c)

        ate_est, ate_debiased_est = None, None
        if all(d is not None for d in [self.descr_Y_t, self.descr_Y_c_cf_t, self.descr_Y_t_cf_c, self.descr_Y_c]):
            # Pool the ATT and ATUT samples for the ATE.
            ate_est = _calc_estimate(self.descr_Y_t + self.descr_Y_c_cf_t,
                                     self.descr_Y_t_cf_c + self.descr_Y_c)
        if all(d is not None for d in [self.descr_Y_diff_t_cf_c, self.descr_Y_diff_c_cf_t, self.descr_Y_diff_c_cf_c]):
            ate_debiased_est = _calc_estimate(self.descr_Y_diff_t_cf_c +
                                              self.descr_Y_diff_c_cf_t, self.descr_Y_diff_c_cf_c)

        aa_est = _calc_estimate(self.descr_Y_c_cf_c, self.descr_Y_c)
        # descr_Y_diff_c_cf_c #used for the double comparisons

        naive_est = _calc_estimate(self.descr_Y_t, self.descr_Y_c)
        # descr_Y_diff_t_c #Don't think this could be useful, but just in case.

        return MatchingEstimate(att_est, att_debiased_est,
                                atut_est, atut_debiased_est,
                                ate_est, ate_debiased_est,
                                aa_est, naive_est)
133 |
--------------------------------------------------------------------------------
/src/SparseSC/utils/dist_summary.py:
--------------------------------------------------------------------------------
1 | """This is a way to summarize (using normal approximations) the distributions of real and
2 | synthetic controls so that all data doesn't have to be stored.
3 | """
4 | import math
5 | from collections import namedtuple
6 |
7 | import statsmodels
8 | import numpy as np
9 |
10 | # TODO:
11 | # - Allow passed in vector and return scalar
12 |
def tstat_generic(mean1, mean2, stdm, dof):
    """Vectorized wrapper around statsmodels' _tstat_generic
    :param mean1: int or np.array
    :param mean2: int or np.array
    :param stdm: int or np.array of the standard deviation of the pooled sample
    :param dof: int degrees of freedom
    :returns: (tstat, pval) — scalars for length-1 input, tuples otherwise
    """
    from statsmodels.stats.weightstats import _tstat_generic

    size = len(mean1)
    if size == 1:
        return _tstat_generic(mean1, mean2, stdm, dof, 'two-sided', diff=0)
    # Element-wise tests, then transpose the list of (t, p) pairs.
    pairs = [
        _tstat_generic(mean1[i], mean2[i], stdm[i], dof, 'two-sided', diff=0)
        for i in range(size)
    ]
    tstat, pval = zip(*pairs)
    return tstat, pval
31 |
32 | # def pooled_variances_scalar(sample_variances, sample_sizes):
33 | # """Estimate pooled variance from a set of samples. Assumes same variance but allows different means."""
34 | # return np.average(sample_variances, weights=(sample_sizes-1))
35 |
def pooled_variances(sample_variances, sample_sizes):
    """Estimate the pooled variance from a set of samples, assuming a common
    variance but allowing different means.

    If inputs are nxl then returns a length-l vector.
    :param sample_variances: per-sample, per-column variances
    :param sample_sizes: per-sample observation counts
    :returns: pooled variance per column
    """
    # https://en.wikipedia.org/wiki/Pooled_variance
    degrees_of_freedom = sample_sizes - 1
    return np.average(sample_variances, weights=degrees_of_freedom, axis=0)
42 |
43 |
class Estimate: # can hold scalars or vectors
    """Container for an estimated effect and its p-value.

    Attributes may be scalars or vectors (e.g. one entry per outcome column).
    """
    def __init__(self, effect, pval, baseline=None):
        # effect: point estimate(s), e.g. a difference in means
        # pval: p-value(s) of the associated test
        # baseline: optional reference level (unused by lcl_comp_means)
        self.effect = effect
        self.pval = pval
        self.baseline=baseline
49 |
class SSC_DescrStat(object):
    """Stores mean and variance for a sample in a way that can be updated
    with new observations or adding together summaries. Similar to statsmodel's DescrStatW
    except we don't keep the raw data, and we use 'online' algorithm's to allow for incremental approach."""
    # Similar to https://github.com/grantjenks/python-runstats but uses numpy and doesn't do higher order stats and has multiple columns
    # Ref: https://www.statsmodels.org/stable/generated/statsmodels.stats.weightstats.DescrStatsW.html#statsmodels.stats.weightstats.DescrStatsW

    def __init__(self, nobs, mean, M2):
        """
        :param nobs: scalar number of observations
        :param mean: vector of per-column means
        :param M2: vector of same length as mean. Sum of squared deviations (sum_i (x_i-mean)^2).
            Sometimes called 'S' (capital; though 's' is often sample variance)
        :raises: ValueError
        """
        import numbers
        if not isinstance(nobs, numbers.Number):
            # BUG FIX: message previously said "mean should be np vector".
            raise ValueError('nobs should be a number')
        self.nobs = nobs
        if len(mean.shape) != 1:
            raise ValueError('mean should be np vector')
        if len(M2.shape) != 1:
            raise ValueError('M2 should be np vector')
        # BUG FIX: previously compared the number of dimensions (always 1 at
        # this point) instead of the actual lengths, so mismatched vectors
        # passed validation.
        if M2.shape != mean.shape:
            raise ValueError('M2 and mean should be the same length')
        self.mean = mean
        self.M2 = M2  # sometimes called S (though s is often sample variance)

    def __repr__(self):
        return "%s(nobs=%s, mean=%s, M2=%s)" % (self.__class__.__name__, self.nobs, self.mean, self.M2)

    def __eq__(self, obj):
        return isinstance(obj, SSC_DescrStat) and np.array_equal(obj.nobs, self.nobs) and np.array_equal(obj.mean, self.mean) and np.array_equal(obj.M2, self.M2)

    @staticmethod
    def from_data(data):
        """
        :param data: 2D np.array. We compute stats per column.
        :returns: SSC_DescrStat object
        """
        N = data.shape[0]
        mean = np.average(data, axis=0)
        # Population variance times N equals the sum of squared deviations.
        M2 = np.var(data, axis=0) * N
        return SSC_DescrStat(N, mean, M2)

    def __add__(self, other, alt=False):
        """
        Chan's parallel algorithm
        :param other: Other SSC_DescrStat object
        :param alt: Use when roughly similar sizes and both are large. Avoid catastrophic cancellation
        :returns: new SSC_DescrStat object
        """
        # See https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
        # TODO: could make alt auto (look at, e.g., abs(self.nobs - other.nobs)/new_n)
        new_n = self.nobs + other.nobs
        delta = other.mean - self.mean
        if not alt:
            new_mean = self.mean + delta * (other.nobs / new_n)
        else:
            new_mean = (self.nobs * self.mean + other.nobs * other.mean) / new_n
        new_M2 = self.M2 + other.M2 + np.square(delta) * self.nobs * other.nobs / new_n
        return SSC_DescrStat(new_n, new_mean, new_M2)

    def update(self, obs):
        """Welford's online algorithm
        :param obs: new observation vector
        """
        # See https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
        self.nobs += 1
        # Accept a 1xK row matrix as well as a plain vector.
        if len(obs.shape) == 2 and obs.shape[0] == 1:
            obs = obs[0, :]
        delta = obs - self.mean
        self.mean += delta * 1. / self.nobs
        self.M2 += delta * (obs - self.mean)

    def variance(self, ddof=1):
        """
        :param ddof: delta degree of difference. 1 will give sample variance (s).
        :returns: Variance
        """
        if self.nobs == 1:
            # Variance is undefined for a single observation; return zeros.
            return np.zeros(self.mean.shape)
        return self.M2 / (self.nobs - ddof)

    @property
    def var(self):
        """Sample variance (ddof=1)."""
        return self.variance()

    def stddev(self, ddof=1):
        """Standard Deviation
        :param ddof: delta degree of difference. 1 will give sample Standard Deviation
        :returns: Standard Deviation
        """
        return np.sqrt(self.variance(ddof))

    @property
    def std(self):
        """Sample standard deviation (ddof=1)."""
        return self.stddev()

    def std_mean(self, ddof=1):
        """Standard Deviation/Error of the mean
        :param ddof: delta degree of difference. 1 will give sample variance (s).
        :returns: Standard Error
        """
        # Matches statsmodels DescrStatsW.std_mean: std / sqrt(nobs - 1).
        return self.stddev(ddof) / math.sqrt(self.nobs - 1)

    @property
    def sumsquares(self):
        """Sum of squared deviations from the mean (M2)."""
        return self.M2

    @property
    def sum(self):
        """Per-column sum of the observations."""
        return self.mean * self.nobs

    @property
    def sum_weights(self):
        # BUG FIX: previously returned self.sum (= mean * nobs). For
        # unweighted data the sum of weights is the number of observations,
        # matching statsmodels' DescrStatsW.sum_weights.
        return self.nobs

    @staticmethod
    def lcl_comp_means(descr1, descr2):
        """ Calclulates the t-test of the difference in means. Local version of statsmodels.stats.weightstats import CompareMeans
        :param descr1: DescrStatW-type object of sample statistics
        :param descr2: DescrStatW-type object of sample statistics
        :returns: Estimate with the mean difference and p-value
        """
        #from statsmodels.stats.weightstats import CompareMeans
        # Do statsmodels.CompareMeans with just summary stats.
        var_pooled = pooled_variances(np.array([descr1.var, descr2.var]), np.array([descr1.nobs, descr2.nobs]))
        stdm = np.sqrt(var_pooled * (1. / descr1.nobs + 1. / descr2.nobs))  # ~sample std dev/sqrt(N)
        dof = descr1.nobs - 1 + descr2.nobs - 1
        tstat, pval = tstat_generic(descr1.mean, descr2.mean, stdm, dof)
        effect = descr1.mean - descr2.mean
        return Estimate(effect, pval)
183 |
--------------------------------------------------------------------------------
/src/SparseSC/utils/local_grad_daemon.py:
--------------------------------------------------------------------------------
1 | """
2 | for local testing of the daemon
3 | """
4 | # pylint: disable=differing-type-doc, differing-param-doc, missing-param-doc, missing-raises-doc, missing-return-doc
5 | import datetime
6 | import os
7 | import json
8 | import tempfile
9 | import atexit
10 | import subprocess
11 | import itertools
12 | import tarfile
13 | import io
14 | import sys
15 |
16 | import numpy as np
17 | from ..cli.scgrad import GradientDaemon, DAEMON_PID, DAEMON_FIFO
18 |
19 | from yaml import load, dump
20 |
21 | try:
22 | from yaml import CLoader as Loader, CDumper as Dumper
23 | except ImportError:
24 | from yaml import Loader, Dumper
25 |
26 | # pylint: disable=fixme, too-few-public-methods
27 | # pylint: disable=bad-continuation, invalid-name, protected-access, line-too-long
28 |
29 | try:
30 | input = raw_input # pylint: disable=redefined-builtin
31 | except NameError:
32 | pass
33 |
34 |
_CONTAINER_OUTPUT_FILE = "output.yaml" # per-component result file written by the daemon
_GRAD_COMMON_FILE = "common.yaml" # gradient parameters shared by all components
_GRAD_PART_FILE = "part.yaml" # per-call gradient parameters

# FIFO on which this process listens for the daemon's status replies.
RETURN_FIFO = "/tmp/sc-return.fifo"
if sys.platform!="win32": #mkfifo not available on Windows
    os.mkfifo(RETURN_FIFO) # pylint: disable=no-member


def cleanup():
    """Remove the response FIFO at interpreter exit."""
    if sys.platform!="win32": #allow sphinx to build on windows w/o error
        os.remove(RETURN_FIFO)


atexit.register(cleanup)
51 |
52 |
class local_batch_daemon:
    """
    Client object for performing gradient calculations with azure batch.
    Local stand-in: instead of submitting Azure Batch tasks it talks to the
    SparseSC.cli.scgrad daemon over named pipes (FIFOs) and temp files.
    """

    def __init__(self, common_data, K):
        """
        :param common_data: gradient parameters shared by every component;
            serialized once to the common YAML file in a temp directory.
        :param K: number of gradient components to compute per do_grad() call.
        """
        # Start the gradient daemon in a separate process.
        subprocess.call(["python", "-m", "SparseSC.cli.scgrad", "start"])
        # CREATE THE RESPONSE FIFO
        # replace any missing values with environment variables
        self.common_data = common_data
        self.K = K

        # BUILT THE TEMPORARY FILE NAMES
        self.tmpDirManager = tempfile.TemporaryDirectory()
        self.tmpdirname = self.tmpDirManager.name
        print("Created temporary directory:", self.tmpdirname)
        self.GRAD_PART_FILE = os.path.join(self.tmpdirname, _GRAD_PART_FILE)
        self.CONTAINER_OUTPUT_FILE = os.path.join(self.tmpdirname, _CONTAINER_OUTPUT_FILE)

        # WRITE THE COMMON DATA TO FILE:
        with open(os.path.join(self.tmpdirname, _GRAD_COMMON_FILE), "w") as fh:
            fh.write(dump(self.common_data, Dumper=Dumper))


    #-- # A UTILITY FUNCTION
    #-- def tarify(x,name):
    #--     with tarfile.open(os.path.join(self.tmpdirname, '{}.tar.gz'.format(name)), mode='w:gz') as dest_file:
    #--         for i, k in itertools.product( range(len(x)), range(len(x[0]))):
    #--             fname = 'arr_{}_{}'.format(i,k)
    #--             array_bytes = x[i][k].tobytes()
    #--             info = tarfile.TarInfo(fname)
    #--             info.size = len(array_bytes)
    #--             dest_file.addfile(info,io.BytesIO(array_bytes) )
    #--
    #-- tarify(part_data["dA_dV_ki"],"dA_dV_ki")
    #-- tarify(part_data["dB_dV_ki"],"dB_dV_ki")
    #-- import pdb; pdb.set_trace()


    def stop(self):
        """
        stop the daemon
        """
        # pylint: disable=no-self-use
        subprocess.call(["python", "-m", "SparseSC.cli.scgrad", "stop"])

    def do_grad(self, part_data):
        """
        calculate the gradient

        :param part_data: per-call gradient inputs, serialized to the part
            YAML file for the daemon to read.
        :returns: length-K np.array of gradient components.
        :raises RuntimeError: when the daemon replies with a non-"0" status.
        """
        start_time = datetime.datetime.now().replace(microsecond=0)
        print("Gradient start time: {}".format(start_time))

        # WRITE THE PART DATA TO FILE
        with open(self.GRAD_PART_FILE, "w") as fh:
            fh.write(dump(part_data, Dumper=Dumper))

        print("Gradient A")
        #-- for key in part_data.keys():
        #--     with open(os.path.join(self.tmpdirname, key), "w") as fh:
        #--         fh.write(dump(part_data[key], Dumper=Dumper))
        #-- with open(os.path.join(self.tmpdirname, "item0.yaml"), "w") as fh: fh.write(dump(part_data["dA_dV_ki"][0][0], Dumper=Dumper))
        #-- with open(os.path.join(self.tmpdirname, "item0.bytes"), "wb") as fh: fh.write(part_data["dA_dV_ki"][0][0].tobytes())
        #-- import gzip
        #--
        #-- with gzip.open(os.path.join(self.tmpdirname, "item0_{}.gz".format(i)), "rb") as fh: matbytes = fh.read()

        # One round trip to the daemon per component k: send the work-item
        # args over the daemon's FIFO, block on our return FIFO for the
        # status, then read the result value from the output file.
        dGamma0_dV_term2 = np.zeros(self.K)
        for k in range(self.K):
            print(k, end=" ")
            # SEND THE ARGS TO THE DAEMON
            with open(DAEMON_FIFO, "w") as df:
                df.write(
                    json.dumps(
                        [
                            self.tmpdirname,
                            RETURN_FIFO,
                            k,
                        ]
                    )
                )

            # LISTEN FOR THE RESPONSE
            with open(RETURN_FIFO, "r") as rf:
                response = rf.read()
            if response != "0":
                raise RuntimeError("Something went wrong in the daemon: {}".format( response))

            with open(self.CONTAINER_OUTPUT_FILE, "r") as fh:
                dGamma0_dV_term2[k] = load(fh.read(), Loader=Loader)

        return dGamma0_dV_term2
145 |
146 |
--------------------------------------------------------------------------------
/src/SparseSC/utils/misc.py:
--------------------------------------------------------------------------------
1 | # Allow capturing output
2 | # Modified (to not capture stderr too) from https://stackoverflow.com/questions/5136611/
3 | import contextlib
4 | import sys
5 |
6 | from .print_progress import it_progressbar, it_progressmsg
7 |
@contextlib.contextmanager
def capture():
    """Temporarily discard everything written to sys.stdout."""
    saved_out = sys.stdout
    try:
        sys.stdout = DummyFile()
        yield
    finally:
        sys.stdout = saved_out


class DummyFile(object):
    """A write-only sink that silently discards all writes."""

    def write(self, x):
        pass


@contextlib.contextmanager
def capture_all():
    """Temporarily discard writes to both sys.stdout and sys.stderr."""
    saved_out = sys.stdout
    saved_err = sys.stderr
    try:
        sys.stdout = DummyFile()
        sys.stderr = DummyFile()
        yield
    finally:
        sys.stdout = saved_out
        sys.stderr = saved_err
31 |
32 |
def par_map(part_fn, it, F, loop_verbose, n_multi=0, header="LOOP"):
    """Map part_fn over it, optionally in parallel and/or with progress output.

    :param part_fn: function applied to each item
    :param it: iterable of inputs
    :param F: expected number of items (used only for progress display)
    :param loop_verbose: 0 = silent, 1 = progress bar, 2 = per-item messages
    :param n_multi: worker-process count; 0 means run serially
    :param header: label for the progress output
    :returns: list of results in input order
    """
    if n_multi > 0:
        from multiprocessing import Pool

        with Pool(n_multi) as pool:
            # Pool.map would consume `it` eagerly, so use imap when progress
            # reporting is requested (can't use it_progressbar(it) directly).
            if loop_verbose == 1:
                print(header + ":")
                results = [r for r in it_progressbar(pool.imap(part_fn, it), count=F)]
            elif loop_verbose == 2:
                results = [r for r in it_progressmsg(pool.imap(part_fn, it), prefix=header, count=F)]
            else:
                results = pool.map(part_fn, it)
        return results
    if loop_verbose == 1:
        print(header + ":")
        it = it_progressbar(it, count=F)
    elif loop_verbose == 2:
        it = it_progressmsg(it, prefix=header, count=F)
    return [part_fn(item) for item in it]
58 |
59 |
class PreDemeanScaler:
    """
    Row-wise demeaning using "pre"-period columns only.

    Units are rows; columns are time periods split into "pre" and "post".
    fit() learns each row's mean over the supplied (pre) columns, and
    transform() subtracts that mean from every column of the row.
    """

    # maybe fit should just take Y and T0 (in init())?
    # Try in sklearn.pipeline with fit() for that and predict (on default Y_post)
    # might want wrappers around fit to make that work fine with pipeline (given its standard arguments).
    # maybe call the vars X rather than Y?
    def __init__(self):
        # Per-row means; populated by fit().
        self.means = None
        # self.T0 = T0

    def fit(self, Y):
        """
        Learn the per-row means, e.g. fit(Y.iloc[:,0:T0])
        """
        import numpy as np

        self.means = np.mean(Y, axis=1)

    def transform(self, Y):
        """Subtract each row's learned mean from that row."""
        demeaned = Y.T - self.means
        return demeaned.T

    def inverse_transform(self, Y):
        """Add each row's learned mean back to that row."""
        restored = Y.T + self.means
        return restored.T
87 |
88 |
89 | def _ensure_good_donor_pool(custom_donor_pool, control_units):
90 | N0 = custom_donor_pool.shape[1]
91 | custom_donor_pool_c = custom_donor_pool[control_units, :]
92 | for i in range(N0):
93 | custom_donor_pool_c[i, i] = False
94 | custom_donor_pool[control_units, :] = custom_donor_pool_c
95 | return custom_donor_pool
96 |
97 |
98 | def _get_fit_units(model_type, control_units, treated_units, N):
99 | if model_type == "retrospective":
100 | return control_units
101 | elif model_type == "prospective":
102 | return range(N)
103 | elif model_type == "prospective-restricted:":
104 | return treated_units
105 | # model_type=="full":
106 | return range(N) # same as control_units
107 |
108 |
--------------------------------------------------------------------------------
/src/SparseSC/utils/ols_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import statsmodels.api as sm
3 |
4 | import SparseSC as SC
5 |
6 | #All data in the long format (rows mark id-time)
def OLS_avg_AA_simple(N=100, T=10, K=0, treat_ratio=.1, T0=None):
    """Run the OLS AA (placebo) simulation on pure-noise panel data.

    :param N: number of units
    :param T: number of time periods per unit
    :param K: unused; kept for interface compatibility
    :param treat_ratio: fraction of units given (placebo) treatment
    :param T0: first post period; defaults to T/2
    """
    # 0-based index for i and t. Each unit's panel is contiguous
    # (fine index is time, gross index is unit).
    Y = np.random.normal(0, 1, (N * T))
    # BUG FIX: id and time were swapped — id was np.tile(range(T), N) (the
    # time pattern) and time was np.repeat(range(N), T) (the unit pattern),
    # which made the downstream `time >= T0` comparison meaningless.
    unit_id = np.repeat(np.array(range(N)), T)
    time = np.tile(np.array(range(T)), N)
    if T0 is None:
        T0 = int(T / 2)
    return OLS_avg_AA(Y, unit_id, time, T0, N, T, treat_ratio)
15 |
def OLS_avg_AA(Y, id, time, post_start, N, T, treat_ratio, num_sim=1000, X=None, level=0.95):
    """Monte Carlo AA (placebo) evaluation of the diff-in-diff OLS estimator.

    Repeatedly draws a random placebo treatment assignment, fits OLS of Y on
    [const, post, (X,) treat, treat*post], and summarizes the distribution of
    the interaction estimate (which should be ~0 under the null).

    :param Y: outcome vector in long format (unit panels stacked)
    :param id: unit identifier per row (currently unused in the regression)
    :param time: time index per row
    :param post_start: first post-treatment period
    :param N: number of units
    :param T: number of periods per unit
    :param treat_ratio: fraction of units given placebo treatment
    :param num_sim: number of Monte Carlo draws
    :param X: optional additional covariate columns
    :param level: confidence level for the reported intervals
    :returns: the summary statistics from simulation_eval
    """
    Const = np.ones((Y.shape[0], 1))
    Post = np.expand_dims((time >= post_start).astype(int), axis=1)
    X_base = np.hstack((Const, Post))
    #X_base = sm.add_constant(X_base)
    if X is not None:
        # BUG FIX: np.hstack takes one tuple of arrays; the previous call
        # np.hstack(X_base, X) raised a TypeError whenever X was supplied.
        X_base = np.hstack((X_base, X))
    alpha = 1 - level

    tes = np.empty((num_sim))
    ci_ls = np.empty((num_sim))
    ci_us = np.empty((num_sim))
    N1 = int(N * treat_ratio)
    sel_idx = np.concatenate((np.repeat(1, N1), np.repeat(0, N - N1)))
    for s in range(num_sim):
        np.random.shuffle(sel_idx)
        Treat = np.expand_dims(np.repeat(sel_idx, T), axis=1)
        D = Treat * Post
        X_design = np.hstack((X_base, Treat, D))
        model = sm.OLS(Y, X_design, hasconst=True)
        results = model.fit()
        # BUG FIX: the interaction coefficient was hard-coded at index 3,
        # which is wrong when X adds extra columns; D is the last column.
        d_ix = X_design.shape[1] - 1
        tes[s] = results.params[d_ix]
        [ci_ls[s], ci_us[s]] = results.conf_int(alpha, cols=[d_ix])[0]

    stats = SC.utils.metrics_utils.simulation_eval(tes, ci_ls, ci_us, true_effect=0)
    print(stats)
    return stats
42 |
43 |
44 | #Do separate effects for each post treatment period?
45 | #def OLS_AA_vec(Y, id, time, treat_ratio, post_times, X=None):
46 | # for post_time in post_times:
47 | # Post_ind_t = time==post_time
48 | # X_base = np.hstack((X_base, Post_ind_t))
49 | #
50 | #def OLS_AA_vec_specific(Y, X_base, sel_idx):
51 | # base_init_len = X_base.shape[1]
52 | # Treat = np.vstack(id_pre==sel_id, id_post==sel_id)
53 | # X_base = np.hstack(X_base, Treat)
54 | # for post_idx in 2:base_init_len:
55 | # D_t = Treat and X_base[:,post_idx]
56 | # X = np.hstack((X,D_t))
57 |
58 | # model = sm.OLS(Y,X, hasconst=True)
59 | # results = model.fit()
60 | # results.params[base_init_len+1:]
61 |
--------------------------------------------------------------------------------
/src/SparseSC/utils/print_progress.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """ A utility for displaying a progress bar
3 | """
4 | # https://gist.github.com/aubricus/f91fb55dc6ba5557fbab06119420dd6a
5 | import sys
6 | import datetime
7 | SparseSC_prev_iteration = 0
def print_progress(iteration, total=100, prefix='', suffix='', decimals=1, bar_length=100, file=sys.stdout):
    """
    Call in a loop to create progress bar. If this isn't a tty-like (re-writable) output then
    no percent-complete number will be outputted.
    @params:
        iteration - Required : current iteration (Int) (typically 1 is first possible)
        total - Required : total iterations (Int)
        prefix - Optional : prefix string (Str)
        suffix - Optional : suffix string (Str)
        decimals - Optional : positive number of decimals in percent complete (Int)
        bar_length - Optional : character length of bar (Int)
        file - Optional : file descriptor for output
    """
    fill_char = '>' # Note that the "█" character is not compatible with every platform...
    empty_char = '-'
    filled_length = int(round(bar_length * iteration / float(total)))
    if file.isatty():
        # Interactive output: redraw the whole bar in place via '\r'.
        str_format = "{0:." + str(decimals) + "f}"
        percents = str_format.format(100 * (iteration / float(total)))

        progress_bar = fill_char * filled_length + empty_char * (bar_length - filled_length)

        file.write('\r%s |%s| %s%s %s' % (prefix, progress_bar, percents, '%', suffix))

        if iteration == total:
            file.write('\n')
        file.flush()
    else: # Can't do interactive re-writing (e.g. w/ the /r character)
        # Non-interactive output (e.g. a log file): append only the newly
        # filled portion, tracking the previous call's position in a module
        # global so successive calls extend the same bar.
        global SparseSC_prev_iteration
        if iteration == 1:
            file.write('%s |' % (prefix))
        else:
            if iteration <= SparseSC_prev_iteration:
                # Iteration went backwards: a new loop started, so reset.
                SparseSC_prev_iteration = 0
            prev_fill_length = int(round(bar_length * SparseSC_prev_iteration / float(total)))
            progress_bar = fill_char * (filled_length-prev_fill_length)

            file.write('%s' % (progress_bar))

        if iteration == total:
            file.write('| %s\n' % (suffix))
        SparseSC_prev_iteration = iteration
        file.flush()
51 |
def it_progressmsg(it, prefix="Loop", file=sys.stdout, count=None):
    """Generator wrapper that writes one "<prefix>: <i>" line per item to
    ``file`` (with "of <count>" appended when ``count`` is given), then a
    final "<prefix>: FINISHED" line once the iterable is exhausted."""
    for idx, element in enumerate(it):
        if count is None:
            message = f"{prefix}: {idx}\n"
        else:
            message = f"{prefix}: {idx} of {count}\n"
        file.write(message)
        file.flush()
        yield element
    file.write(f"{prefix}: FINISHED\n")
    file.flush()
62 |
63 | #Similar to above, but you wrap an iterator
def it_progressbar(it, prefix="", suffix='', decimals=1, bar_length=100, file=sys.stdout, count=None):
    """Wrap an iterable with a textual progress bar written to ``file``.

    On a tty the full bar is redrawn in place each step; otherwise only the
    newly filled characters are appended.  ``count`` defaults to ``len(it)``,
    so it must be supplied explicitly for generators."""
    if count is None:
        count = len(it)

    fill_char = '>'
    empty_char = '-'

    def render(step):
        # Draw the bar state corresponding to `step` completed items.
        filled = int(bar_length * step / count)
        pct = "{0:.{dec}f}".format(100 * (step / float(count)), dec=decimals)
        bar = fill_char * filled + empty_char * (bar_length - filled)
        if file.isatty():
            file.write("%s |%s| %s%s %s\r" % (prefix, bar, pct, '%', suffix))
            file.flush()
        else:  # Can't do interactive re-writing (e.g. w/ the /r character)
            if step == 0:
                file.write('%s |' % prefix)
            else:
                previously_filled = int(bar_length * (step - 1) / count)
                file.write(fill_char * (filled - previously_filled))
            if step == count:
                file.write('| %s' % suffix)

    render(0)
    for position, element in enumerate(it):
        yield element
        render(position + 1)
    file.write("\n")
    file.flush()
92 |
def log_if_necessary(str_note_start, verbose):
    """Print a timestamped note when ``verbose > 0``; additionally dump a
    tracemalloc memory snapshot (via ``print_memory_snapshot``) when
    ``verbose > 1``."""
    stamped = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') + ": " + str_note_start
    if verbose > 0:
        print(stamped)
    if verbose > 1:
        print_memory_snapshot(extra_str=stamped)
99 |
100 |
def print_memory_snapshot(extra_str=None):
    """Write a summary of the top tracemalloc allocation sites.

    Output goes to the file named by the ``SparseSC_log_file`` environment
    variable (opened in append mode) when set, otherwise to stdout.

    NOTE: requires tracemalloc tracing to be active (``tracemalloc.start()``);
    ``take_snapshot`` raises RuntimeError otherwise.

    :param extra_str: optional header line written before the statistics
    """
    import os
    import tracemalloc
    log_path = os.getenv("SparseSC_log_file")  # None if non-existant
    snapshot = tracemalloc.take_snapshot()
    top_stats = snapshot.statistics('lineno')
    # BUG FIX: the original re-bound `log_file` to the open handle and only
    # closed it on the success path; a failure in any print leaked the handle.
    out_file = open(log_path, "a") if log_path is not None else None
    try:
        if extra_str is not None:
            print(extra_str, file=out_file)
        limit = 10
        # keep the header in sync with `limit` (was hardcoded "[ Top 10 ] ")
        print("[ Top %s ] " % limit, file=out_file)
        for stat in top_stats[:limit]:
            print(stat, file=out_file)
        other = top_stats[limit:]
        if other:
            size = sum(stat.size for stat in other)
            print("%s other: %.1f KiB" % (len(other), size / 1024), file=out_file)
        total = sum(stat.size for stat in top_stats)
        print("Total allocated size: %.1f KiB" % (total / 1024), file=out_file)
        #if old_snapshot is not None:
        #    diff_stats = snapshot.compare_to(old_snapshot, 'lineno')
    finally:
        if out_file is not None:
            out_file.close()
125 |
--------------------------------------------------------------------------------
/src/SparseSC/utils/sub_matrix_inverse.py:
--------------------------------------------------------------------------------
1 | """ In the leave-one-out method with larger sample sizes most of the time is spent in calculating
2 | A.I.dot(B) using np.linalg.solve (which is much faster than the brute force method i.e. `A.I.dot(B)`).
For example, with 200 units, about 95% of the time is spent in this line.
4 |
5 | However we can take advantage of the fact that we're calculating the
6 | inverse of N matrices which are all sub-matrices of a common matrix by inverting the common matrix
7 | and calculating each subset inverse from there.
8 |
9 | https://math.stackexchange.com/a/208021/252693
10 |
11 |
12 | TODO: this could be made a bit faster by passing in the indexes (k_rng,
13 | k_rng2) instead of re-building them
14 | """
15 | # pylint: skip-file
16 | import numpy as np
17 |
def subinv(x, eps=None):
    """Compute the inverses of every leave-one-out sub-matrix of ``x``.

    Instead of inverting each (N-1)x(N-1) sub-matrix from scratch, the full
    inverse ``x.I`` is computed once and each sub-inverse is derived from it
    by a rank-one correction.

    :param x: a square ``np.matrix``; for each index ``k``, the inverse of the
        sub-matrix with row and column ``k`` removed is returned.
    :param eps: If not None, assert that each calculated sub-matrix-inverse is
        within eps of the brute force calculation.  Testing only: this slows
        the process way down since each sub-matrix is also inverted directly.
        Typically a multiple of ``np.finfo(float).eps``.
    :returns: a list of length N of the leave-one-out inverse matrices
    :raises RuntimeError: if ``eps`` is given and any fast result disagrees
        with the brute force inverse by more than ``eps``.
    """
    xi = x.I  # inverse of the full matrix, computed once
    n = x.shape[0]
    indices = np.arange(n)
    results = [None] * n
    for k in range(n):
        keep = indices[indices != k]
        # rank-one update: remove row/column k from the full inverse
        results[k] = xi[np.ix_(keep, keep)] - xi[keep, k].dot(xi[k, keep]) / xi[k, k]
        if eps is not None:
            brute = x[np.ix_(keep, keep)].I
            if not (abs(results[k] - brute) < eps).all():
                raise RuntimeError("Fast and brute force methods were not within epsilon (%s) for sub-matrix k = %s; max difference = %s" %
                                   (eps, k, abs(results[k] - brute).max(), ) )
    return results
41 |
def subinv_k(xi, k, eps=None):
    """Calculate the inverse of a single leave-one-out sub-matrix of ``x``,
    given ``xi``, the inverse of the full matrix ``x``.

    :param xi: the inverse (``x.I``) of the full square ``np.matrix`` ``x``
    :param k: the column and row to leave out
    :param eps: If not None, assert that the calculated sub-matrix-inverse is
        within eps of the brute force calculation.  Testing only: this slows
        the process way down since the sub-matrix is also inverted directly.
        Typically a multiple of ``np.finfo(float).eps``.
    :returns: the inverse of ``x`` with row and column ``k`` removed
    :raises RuntimeError: if ``eps`` is given and the fast result disagrees
        with the brute force inverse by more than ``eps``.
    """
    N = xi.shape[0]
    rng = np.arange(N)
    k_rng = rng[rng != k]
    # rank-one update: remove row/column k from the full inverse
    out = xi[np.ix_(k_rng, k_rng)] - xi[k_rng, k].dot(xi[k, k_rng]) / xi[k, k]
    if eps is not None:
        # BUG FIX: the original referenced an undefined name `x` and indexed
        # `out[k]` (a single row) instead of comparing the whole result.
        # Reconstruct x from its inverse and compare against brute force.
        x = xi.I
        brute = x[np.ix_(k_rng, k_rng)].I
        if not (abs(out - brute) < eps).all():
            raise RuntimeError("Fast and brute force methods were not within epsilon (%s) for sub-matrix k = %s; max difference = %s" % (eps, k, abs(out - brute).max(), ) )
    return out
62 |
63 |
64 |
65 | # ---------------------------------------------
66 | # single sub-matrix
67 | # ---------------------------------------------
68 |
if __name__ == "__main__":
    # Ad-hoc benchmark comparing three ways of computing A.I.dot(B) for
    # leave-one-out sub-matrices: brute force inversion, the rank-one
    # update (subinv/subinv_k), and np.linalg.solve.

    import time

    n = 200
    B = np.matrix(np.random.random((n, n,)))

    n = 5
    p = 3
    a = np.matrix(np.random.random((n, p,)))
    v = np.diag(np.random.random((p,)))
    x = a.dot(v).dot(a.T)
    x.dot(x.I)

    x = np.matrix(np.random.random((n, n,)))
    x.dot(x.I)

    xi = x.I

    B = np.matrix(np.random.random((n, n,)))

    # demonstration: leave-TWO-out sub-matrix inverse via the same identity
    k = np.arange(2)
    N = xi.shape[0]
    rng = np.arange(N)
    k_rng = rng[np.logical_not(np.isin(rng, k))]

    out = xi[np.ix_(k_rng, k_rng)] - xi[np.ix_(k_rng, k)].dot(xi[np.ix_(k, k_rng)]) / np.linalg.det(xi[np.ix_(k, k)])

    for i in range(100):
        # create a sub-matrix that meets the matching criteria
        x = np.matrix(np.random.random((n, n,)))
        try:
            zz = subinv(x, 10e-10)
            break
        except RuntimeError:  # BUG FIX: was a bare `except:`; subinv raises RuntimeError on mismatch
            pass
    else:
        # BUG FIX: the original printed the raw template with no % arguments
        print("Failed to generate a %sx%s matrix whose inverses are all within %s of the quick method" % (n, n, 10e-10))

    k = 5
    n_tests = 1000

    # =======================
    t0 = time.time()
    for i in range(n_tests):
        _N = xi.shape[0]
        rng = np.arange(_N)
        k_rng = rng[rng != k]
        k_rng2 = np.ix_(k_rng, k_rng)
        zz = x[k_rng2].I.dot(B[k_rng2])

    t1 = time.time()
    slow_time = t1 - t0
    print("A.I.dot(B): brute force time (N = %s): %s" % (n, t1 - t0))

    # =======================
    t0 = time.time()
    for i in range(n_tests):
        # make the comparison fair: re-invert the full matrix periodically
        if i % n == 0:
            xi = x.I
        zz = subinv_k(xi, k).dot(B[k_rng2])

    t1 = time.time()
    fast_time = t1 - t0
    print("A.I.dot(B): quick time (N = %s): %s" % (n, t1 - t0))

    # =======================
    t0 = time.time()
    for i in range(n_tests):
        _N = xi.shape[0]
        rng = np.arange(_N)
        k_rng = rng[rng != k]
        k_rng2 = np.ix_(k_rng, k_rng)
        zz = np.linalg.solve(x[k_rng2], B[k_rng2])

    t1 = time.time()
    fast_time = t1 - t0
    print("A.I.dot(B): np.linalg.solve time (N = %s): %s" % (n, t1 - t0))

    # ---------------------------------------------
    # ---------------------------------------------

    t0 = time.time()
    for i in range(100):
        zz = subinv(x, 10e-10)

    t1 = time.time()
    slow_time = t1 - t0
    # BUG FIX: these last two prints passed the values as a second positional
    # argument to print() instead of %-formatting them into the template
    print("Full set of inverses: brute force time (N = %s): %s" % (n, t1 - t0))

    t0 = time.time()
    for i in range(100):
        zz = subinv(x)

    t1 = time.time()
    fast_time = t1 - t0
    print("Full set of inverses: quick time (N = %s): %s" % (n, t1 - t0))
170 |
171 | # ---------------------------------------------
172 | # ---------------------------------------------
173 |
174 |
--------------------------------------------------------------------------------
/src/SparseSC/utils/warnings.py:
--------------------------------------------------------------------------------
1 | """
Warnings used throughout the module. Exported and inherited from a common
3 | warning class so as to facilitate the filtering of warnings
4 | """
5 | # pylint: disable=too-few-public-methods, missing-docstring
class SparseSCWarning(RuntimeWarning):
    """Common base class for all SparseSC warnings (enables filtering)."""


class UnpenalizedRecords(SparseSCWarning):
    """Warning relating to unpenalized records."""
8 |
9 |
10 |
--------------------------------------------------------------------------------
/src/SparseSC/weights.py:
--------------------------------------------------------------------------------
1 | """
2 | Presents a unified API for the various weights methods
3 | """
4 | from SparseSC.fit_loo import loo_weights
5 | from SparseSC.fit_ct import ct_weights
6 | from SparseSC.fit_fold import fold_weights
7 | import numpy as np
8 |
9 |
def weights(X, X_treat=None, grad_splits=None, custom_donor_pool=None, **kwargs):
    """ Calculate synthetic control weights, dispatching to the fitter that
    matches the study design.

    :param X: matrix of features, one row per unit; coerced to float64
    :param X_treat: optional matrix of features for the treated units.  When
        provided, weights are fit via ``ct_weights`` with the treated units
        stacked below the controls.
    :param grad_splits: when ``X_treat`` is None and this is provided, fit
        with cross-fitting folds via ``fold_weights``.
    :param custom_donor_pool: optional donor-pool restriction forwarded to the
        underlying fitter.
    :param kwargs: forwarded to the underlying weights function.
    :raises TypeError: if ``X`` is not coercible to float64
    :raises ValueError: if ``X_treat`` is not coercible to float64 or has no
        columns
    """

    # PARAMETER QC
    try:
        X = np.float64(X)
    except ValueError as err:
        raise TypeError("X is not coercible to float64") from err

    X = np.asmatrix(X)  # this needs to be deprecated properly -- bc Array.dot(Array) != matrix(Array).dot(matrix(Array)) -- not even close !!!
    if X_treat is not None:
        # weight for the treated units against the controls:

        # PARAMETER QC
        try:
            X_treat = np.float64(X_treat)
        except ValueError as err:
            # NOTE: kept as ValueError (unlike the TypeError raised for X)
            # for backward compatibility with existing callers.
            raise ValueError("X_treat is not coercible to float64") from err

        # this needs to be deprecated properly -- bc Array.dot(Array) != matrix(Array).dot(matrix(Array)) -- not even close !!!
        X_treat = np.asmatrix(X_treat)

        # BUG FIX: this check used to appear twice, once *before* the float64
        # coercion (where a plain nested list would have no `.shape`); a
        # single post-coercion check suffices.
        if X_treat.shape[1] == 0:
            raise ValueError("X_treat.shape[1] == 0")

        # FIT THE V-MATRIX AND POSSIBLY CALCULATE THE w_pen
        # note that the weights, score, and loss function value returned here
        # are for the in-sample predictions
        return ct_weights(
            X=np.vstack((X, X_treat)),
            control_units=np.arange(X.shape[0]),
            treated_units=np.arange(X_treat.shape[0]) + X.shape[0],
            custom_donor_pool=custom_donor_pool,
            **kwargs
        )

    # === X_treat is None: ===

    if grad_splits is not None:
        return fold_weights(X=X, grad_splits=grad_splits, **kwargs)

    # === X_treat is None and grad_splits is None: ===

    # weight for each control unit against the remaining controls
    return loo_weights(
        X=X,
        control_units=np.arange(X.shape[0]),
        treated_units=np.arange(X.shape[0]),
        custom_donor_pool=custom_donor_pool,
        **kwargs
    )
65 |
--------------------------------------------------------------------------------
/static/controls-only-pre-and-post.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/static/controls-only-pre-and-post.png
--------------------------------------------------------------------------------
/static/pre-only-controls-and-treated.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/static/pre-only-controls-and-treated.png
--------------------------------------------------------------------------------
/test/AzureBatch/README.md:
--------------------------------------------------------------------------------
1 | # Batch config testing
2 |
3 | ## Install Batch Config
4 |
5 | ```shell
6 | pip install git+https://github.com/jdthorpe/batch-configo
7 | ```
8 |
9 | ## Gather required credentials
10 |
11 | These commands assume the CMD terminal. For other terminals, [see here](https://jdthorpe.github.io/super-batch-docs/create-resources).
12 |
13 | ```bat
14 | az login
15 | # optionally: az account set -s xxxxxxx-xxxx-xxxx-xxxx-xxxxxxxx
16 |
17 | set name=sparsescbatchtesting
18 | set rgname=SparseSC-batch-testing
19 | set BATCH_ACCOUNT_NAME=%name%
20 | for /f %i in ('az batch account keys list -n %name% -g %rgname% --query primary') do @set BATCH_ACCOUNT_KEY=%i
21 | for /f %i in ('az batch account show -n %name% -g %rgname% --query accountEndpoint') do @set BATCH_ACCOUNT_ENDPOINT=%i
22 | for /f %i in ('az storage account keys list -n %name% --query [0].value') do @set STORAGE_ACCOUNT_KEY=%i
23 | for /f %i in ('az storage account show-connection-string --name %name% --query connectionString') do @set STORAGE_ACCOUNT_CONNECTION_STRING=%i
24 |
25 | # clean up the quotes
26 | set BATCH_ACCOUNT_KEY=%BATCH_ACCOUNT_KEY:"=%
27 | set BATCH_ACCOUNT_ENDPOINT=%BATCH_ACCOUNT_ENDPOINT:"=%
28 | set STORAGE_ACCOUNT_KEY=%STORAGE_ACCOUNT_KEY:"=%
29 | set STORAGE_ACCOUNT_CONNECTION_STRING=%STORAGE_ACCOUNT_CONNECTION_STRING:"=%
30 | ```
31 |
32 | ## run the tests
33 |
34 | ```bat
35 | cd test\AzureBatch
36 | rm -rf data
37 | python test_batch_build.py
38 | python test_batch_run.py
39 | python test_batch_aggregate.py
40 | ```
41 |
--------------------------------------------------------------------------------
/test/AzureBatch/test_batch_aggregate.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------------------------------
2 | # Programmer: Jason Thorpe
3 | # Date 1/25/2019 3:34:02 PM
4 | # Language: Python (.py) Version 2.7 or 3.5
5 | # Usage:
6 | #
7 | # Test all model types
8 | #
9 | # \SpasrseSC > python -m unittest test/test_fit.py
10 | #
11 | # Test a specific model type (e.g. "prospective-restricted"):
12 | #
13 | # \SpasrseSC > python -m unittest test.test_fit.TestFit.test_retrospective
14 | #
15 | # --------------------------------------------------------------------------------
16 | # pylint: disable=multiple-imports, missing-docstring
17 | """
18 | usage
19 |
20 | az login
21 |
22 | name="sparsesctest"
23 | location="westus2"
24 |
25 | export BATCH_ACCOUNT_NAME=$name
26 | export BATCH_ACCOUNT_KEY=$(az batch account keys list -n $name -g $name --query primary)
27 | export BATCH_ACCOUNT_URL="https://$name.$location.batch.azure.com"
28 | export STORAGE_ACCOUNT_NAME=$name
29 | export STORAGE_ACCOUNT_KEY=$(az storage account keys list -n $name --query [0].value)
30 |
31 | """
32 |
33 |
34 | from __future__ import print_function # for compatibility with python 2.7
35 | import numpy as np
36 | import sys, os, random
37 | import unittest
38 | import warnings
39 |
40 | try:
41 | from SparseSC import fit
42 | except ImportError:
43 | raise RuntimeError("SparseSC is not installed. use 'pip install -e .' to install")
44 | from scipy.optimize.linesearch import LineSearchWarning
45 | from SparseSC.utils.AzureBatch import aggregate_batch_results
46 |
47 |
class TestFit(unittest.TestCase):
    """Compares an in-process `fit` against `aggregate_batch_results` over a
    previously built batch directory (see test_batch_build.py)."""

    def setUp(self):
        # Fixed seeds so the in-process fit matches the batch artifacts
        # produced by test_batch_build.py from the same data.
        random.seed(12345)
        np.random.seed(101101001)
        control_units = 50
        treated_units = 20
        features = 10
        targets = 5

        self.X = np.random.rand(control_units + treated_units, features)
        self.Y = np.random.rand(control_units + treated_units, targets)
        self.treated_units = np.arange(treated_units)

    @classmethod
    def run_test(cls, obj, model_type, verbose=False):
        """Fit in-process and assert the scores match the aggregated batch run."""
        if verbose:
            print("Calling fit with `model_type = '%s'`..." % (model_type,), end="")
            sys.stdout.flush()

        # the batch artifacts must already exist (built by test_batch_build.py)
        batchdir = os.path.join(
            os.path.dirname(os.path.realpath(__file__)), "data", "batchTest"
        )
        assert os.path.exists(batchdir), "Batch Directory '{}' does not exist".format(
            batchdir
        )

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=PendingDeprecationWarning)
            warnings.filterwarnings("ignore", category=LineSearchWarning)
            try:
                # NOTE(review): this rebinding shadows the method's `verbose`
                # parameter, so the "DONE" branch below can never run and
                # `print_path` is always 0 -- looks like a debugging leftover;
                # confirm before removing.
                verbose = 0
                model_a = fit(
                    features=obj.X,
                    targets=obj.Y,
                    model_type=model_type,
                    treated_units=obj.treated_units
                    if model_type
                    in ("retrospective", "prospective", "prospective-restricted")
                    else None,
                    # KWARGS:
                    print_path=verbose,
                    stopping_rule=1,
                    progress=0,
                    grid_length=5,
                    min_iter=-1,
                    tol=1,
                    verbose=0,
                )

                # aggregate the results previously computed by the batch job
                model_b = aggregate_batch_results(
                    batchDir=batchdir
                )  # , batch_client_config="sg_daemon"

                # the two fits must agree to within rounding error
                assert np.all(
                    np.abs(model_a.scores - model_b.scores) < 1e-14
                ), "model scores are not within rounding error"

                if verbose:
                    print("DONE")
            except LineSearchWarning:
                pass
            except PendingDeprecationWarning:
                pass
            except Exception as exc:  # pylint: disable=broad-except
                print(
                    "Failed with %s(%s)"
                    % (exc.__class__.__name__, getattr(exc, "message", ""))
                )
                raise exc

    def test_retrospective(self):
        TestFit.run_test(self, "retrospective")
121 |
122 |
123 | # -- def test_prospective(self):
124 | # -- TestFit.run_test(self, "prospective")
125 | # --
126 | # -- def test_prospective_restrictive(self):
127 | # -- # Catch the LineSearchWarning silently, but allow others
128 | # --
129 | # -- TestFit.run_test(self, "prospective-restricted")
130 | # --
131 | # -- def test_full(self):
132 | # -- TestFit.run_test(self, "full")
133 |
134 |
# Allow running this test file directly (outside of a test runner).
if __name__ == "__main__":
    unittest.main()
137 |
--------------------------------------------------------------------------------
/test/AzureBatch/test_batch_build.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------------------------------
2 | # Programmer: Jason Thorpe
3 | # Date 1/25/2019 3:34:02 PM
4 | # Language: Python (.py) Version 2.7 or 3.5
5 | # Usage:
6 | #
7 | # Test all model types
8 | #
9 | # \SpasrseSC > python -m unittest test/test_fit.py
10 | #
11 | # Test a specific model type (e.g. "prospective-restricted"):
12 | #
13 | # \SpasrseSC > python -m unittest test.test_fit.TestFit.test_retrospective
14 | #
15 | # --------------------------------------------------------------------------------
16 |
17 | from __future__ import print_function # for compatibility with python 2.7
18 | import numpy as np
19 | import sys, os, random
20 | import unittest
21 | import warnings
22 | from scipy.optimize.linesearch import LineSearchWarning
23 |
24 | try:
25 | from SparseSC import fit
26 | except ImportError:
27 | raise RuntimeError("SparseSC is not installed. use 'pip install -e .' to install")
28 |
29 |
30 | class TestFit(unittest.TestCase):
31 | def setUp(self):
32 |
33 | random.seed(12345)
34 | np.random.seed(101101001)
35 | control_units = 50
36 | treated_units = 20
37 | features = 10
38 | targets = 5
39 |
40 | self.X = np.random.rand(control_units + treated_units, features)
41 | self.Y = np.random.rand(control_units + treated_units, targets)
42 | self.treated_units = np.arange(treated_units)
43 |
44 | @classmethod
45 | def run_test(cls, obj, model_type, verbose=False):
46 | if verbose:
47 | print("Calling fit with `model_type = '%s'`..." % (model_type,), end="")
48 | sys.stdout.flush()
49 |
50 | batchdir = os.path.join(
51 | os.path.dirname(os.path.realpath(__file__)), "data", "batchTest"
52 | )
53 | print("dumping batch artifacts to: {}'".format(batchdir))
54 |
55 | with warnings.catch_warnings():
56 | warnings.filterwarnings("ignore", category=PendingDeprecationWarning)
57 | warnings.filterwarnings("ignore", category=LineSearchWarning)
58 | try:
59 | fit(
60 | features=obj.X,
61 | targets=obj.Y,
62 | model_type=model_type,
63 | treated_units=obj.treated_units
64 | if model_type
65 | in ("retrospective", "prospective", "prospective-restricted")
66 | else None,
67 | # KWARGS:
68 | print_path=False,
69 | stopping_rule=1,
70 | progress=verbose,
71 | grid_length=5,
72 | min_iter=-1,
73 | tol=1,
74 | verbose=0,
75 | batchDir=batchdir,
76 | )
77 | if verbose:
78 | print("DONE")
79 | except LineSearchWarning:
80 | pass
81 | except PendingDeprecationWarning:
82 | pass
83 | except Exception as exc: # pylint: disable=broad-except
84 | print(
85 | "Failed with %s(%s)"
86 | % (exc.__class__.__name__, getattr(exc, "message", ""))
87 | )
88 | raise exc
89 |
90 | def test_retrospective(self):
91 | TestFit.run_test(self, "retrospective")
92 |
93 |
94 | # -- def test_prospective(self):
95 | # -- TestFit.run_test(self, "prospective")
96 | # --
97 | # -- def test_prospective_restrictive(self):
98 | # -- # Catch the LineSearchWarning silently, but allow others
99 | # --
100 | # -- TestFit.run_test(self, "prospective-restricted")
101 | # --
102 | # -- def test_full(self):
103 | # -- TestFit.run_test(self, "full")
104 |
105 |
# Allow running this test file directly (outside of a test runner).
if __name__ == "__main__":
    # (debugging leftovers: run a single test without unittest's machinery)
    # t = TestFit()
    # t.setUp()
    # t.test_retrospective()
    unittest.main()
111 |
--------------------------------------------------------------------------------
/test/AzureBatch/test_batch_run.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------------------------------
2 | # Programmer: Jason Thorpe
3 | # Date 1/25/2019 3:34:02 PM
4 | # Language: Python (.py) Version 2.7 or 3.5
5 | # Usage:
6 | #
7 | # Test all model types
8 | #
9 | # \SpasrseSC > python -m unittest test/test_fit.py
10 | #
11 | # Test a specific model type (e.g. "prospective-restricted"):
12 | #
13 | # \SpasrseSC > python -m unittest test.test_fit.TestFit.test_retrospective
14 | #
15 | # --------------------------------------------------------------------------------
16 | # pylint: disable=multiple-imports, missing-docstring, no-self-use
17 | """
18 | USAGE (CMD):
19 |
20 | az login
21 |
22 | set name=sparsescbatchtesting
23 | set rgname=SparseSC-batch-testing
24 | set BATCH_ACCOUNT_NAME=%name%
25 | for /f %i in ('az batch account keys list -n %name% -g %rgname% --query primary') do @set BATCH_ACCOUNT_KEY=%i
26 | for /f %i in ('az batch account show -n %name% -g %rgname% --query accountEndpoint') do @set BATCH_ACCOUNT_ENDPOINT=%i
27 | for /f %i in ('az storage account keys list -n %name% --query [0].value') do @set STORAGE_ACCOUNT_KEY=%i
28 | for /f %i in ('az storage account show-connection-string --name %name% --query connectionString') do @set STORAGE_ACCOUNT_CONNECTION_STRING=%i
29 |
30 | # clean up the quotes
31 | set BATCH_ACCOUNT_KEY=%BATCH_ACCOUNT_KEY:"=%
32 | set BATCH_ACCOUNT_ENDPOINT=%BATCH_ACCOUNT_ENDPOINT:"=%
33 | set STORAGE_ACCOUNT_KEY=%STORAGE_ACCOUNT_KEY:"=%
34 | set STORAGE_ACCOUNT_CONNECTION_STRING=%STORAGE_ACCOUNT_CONNECTION_STRING:"=%
35 |
36 | cd test\AzureBatch
37 | rm -rf data
38 | python test_batch_build.py
39 | python test_batch_run.py
40 | python test_batch_aggregate.py
41 | """
42 |
43 | from __future__ import print_function # for compatibility with python 2.7
44 | import os, unittest, datetime
45 | from os.path import join, realpath, dirname, exists
46 | from super_batch import Client
47 |
48 | from SparseSC.utils.AzureBatch import (
49 | DOCKER_IMAGE_NAME,
50 | create_job,
51 | )
52 |
53 |
54 | class TestFit(unittest.TestCase):
55 | def test_retrospective_no_wait(self):
56 | """
57 | test the no-wait and load_results API
58 | """
59 |
60 | name = os.getenv("name")
61 | if name is None:
62 | raise RuntimeError(
63 | "Please create an environment variable called 'name' as en the example docs"
64 | )
65 | batch_dir = join(dirname(realpath(__file__)), "data", "batchTest")
66 | assert exists(batch_dir), "Batch Directory '{}' does not exist".format(
67 | batch_dir
68 | )
69 |
70 | timestamp = datetime.datetime.utcnow().strftime("%H%M%S")
71 |
72 | batch_client = Client(
73 | POOL_ID=name,
74 | POOL_LOW_PRIORITY_NODE_COUNT=5,
75 | POOL_VM_SIZE="STANDARD_A1_v2",
76 | JOB_ID=name + timestamp,
77 | BLOB_CONTAINER_NAME=name,
78 | BATCH_DIRECTORY=batch_dir,
79 | DOCKER_IMAGE=DOCKER_IMAGE_NAME,
80 | )
81 | create_job(batch_client, batch_dir)
82 |
83 | batch_client.run(wait=False)
84 | batch_client.load_results()
85 |
86 | def test_retrospective(self):
87 |
88 | name = os.getenv("name")
89 | if name is None:
90 | raise RuntimeError(
91 | "Please create an environment variable called 'name' as en the example docs"
92 | )
93 | batch_dir = join(dirname(realpath(__file__)), "data", "batchTest")
94 | assert exists(batch_dir), "Batch Directory '{}' does not exist".format(
95 | batch_dir
96 | )
97 |
98 | timestamp = datetime.datetime.utcnow().strftime("%H%M%S")
99 |
100 | batch_client = Client(
101 | POOL_ID=name,
102 | POOL_LOW_PRIORITY_NODE_COUNT=5,
103 | POOL_VM_SIZE="STANDARD_A1_v2",
104 | JOB_ID=name + timestamp,
105 | BLOB_CONTAINER_NAME=name,
106 | BATCH_DIRECTORY=batch_dir,
107 | DOCKER_IMAGE=DOCKER_IMAGE_NAME,
108 | )
109 |
110 | create_job(batch_client, batch_dir)
111 | batch_client.run()
112 |
113 |
# Allow running this test file directly (outside of a test runner).
if __name__ == "__main__":
    unittest.main()
116 |
--------------------------------------------------------------------------------
/test/CausalImpact_test.R:
--------------------------------------------------------------------------------
# Sanity-check script for the CausalImpact R package on simulated data;
# presumably used as a point of comparison for SparseSC -- confirm.
library(CausalImpact)

set.seed(1)
# four highly persistent AR(1) covariate series around a level of 100
x1 <- 100 + arima.sim(model = list(ar = 0.999), n = 100)
x2 <- 100 + arima.sim(model = list(ar = 0.999), n = 100)
x3 <- 100 + arima.sim(model = list(ar = 0.999), n = 100)
x4 <- 100 + arima.sim(model = list(ar = 0.999), n = 100)
# outcome is a noisy linear combination of the covariates ...
y <- 1.2 * x1 + rnorm(100) + .8*x2 + .7*x3 + .6*x4
# ... with a constant +10 "treatment effect" added from period 71 on
y[71:100] <- y[71:100] + 10
data <- cbind(y, x1, x2, x3, x4)

# NOTE(review): pre-period ends at 3 rather than 70 (the actual intervention
# point, per the commented "#70") -- deliberately short? confirm.
pre.period <- c(1, 3) #70
post.period <- c(4, 100)

#impact <- CausalImpact(data, pre.period, post.period, alpha = 0.05)
impact <- CausalImpact(data, pre.period, post.period)

plot(impact)
summary(impact)

# inspect the fitted structural time-series model and effect estimates
plot(impact$model$bsts.model, "coefficients")
impact$summary$AbsEffect
impact$summary$AbsEffect.lower
impact$summary$AbsEffect.upper
s = impact$series

# regress the counterfactual predictions on two covariates
summary(lm(s$point.pred~x1+x2))
--------------------------------------------------------------------------------
/test/SparseSC_36.yml:
--------------------------------------------------------------------------------
1 | name: SparseSC_36
2 | channels:
3 | - defaults
4 | dependencies:
5 | - _r-mutex=1.0.0=anacondar_1
6 | - _tflow_select=2.3.0=mkl
7 | - absl-py=0.7.1=py36_0
8 | - alabaster=0.7.12=py36_0
9 | - asn1crypto=0.24.0=py36_0
10 | - astor=0.7.1=py36_0
11 | - astroid=1.6.5=py36_0
12 | - attrs=19.1.0=py36_1
13 | - autopep8=1.4.4=py_0
14 | - babel=2.6.0=py36_0
15 | - backcall=0.1.0=py36_0
16 | - blas=1.0=mkl
17 | - bleach=3.1.0=py36_0
18 | - ca-certificates=2020.1.1=0
19 | - certifi=2020.4.5.1=py36_0
20 | - cffi=1.12.3=py36h7a1dbc1_0
21 | - chardet=3.0.4=py36_1
22 | - colorama=0.4.1=py36_0
23 | - cryptography=2.6.1=py36h7a1dbc1_0
24 | - cycler=0.10.0=py36h009560c_0
25 | - decorator=4.4.0=py36_1
26 | - defusedxml=0.6.0=py_0
27 | - docutils=0.14=py36h6012d8f_0
28 | - entrypoints=0.3=py36_0
29 | - freetype=2.9.1=ha9979f8_1
30 | - gast=0.2.2=py36_0
31 | - grpcio=1.16.1=py36h351948d_1
32 | - h5py=2.9.0=py36h5e291fa_0
33 | - hdf5=1.10.4=h7ebc959_0
34 | - icc_rt=2019.0.0=h0cc432a_1
35 | - icu=58.2=ha66f8fd_1
36 | - idna=2.8=py36_0
37 | - imagesize=1.1.0=py36_0
38 | - intel-openmp=2019.3=203
39 | - ipykernel=5.1.0=py36h39e3cac_0
40 | - ipython=7.5.0=py36h39e3cac_0
41 | - ipython_genutils=0.2.0=py36h3c5d0ee_0
42 | - isort=4.3.19=py36_0
43 | - jedi=0.13.3=py36_0
44 | - jinja2=2.10.1=py36_0
45 | - jpeg=9b=hb83a4c4_2
46 | - jsonschema=3.0.1=py36_0
47 | - jupyter_client=5.2.4=py36_0
48 | - jupyter_core=4.4.0=py36_0
49 | - keras=2.2.4=0
50 | - keras-applications=1.0.8=py_0
51 | - keras-base=2.2.4=py36_0
52 | - keras-preprocessing=1.1.0=py_1
53 | - kiwisolver=1.1.0=py36ha925a31_0
54 | - lazy-object-proxy=1.4.1=py36he774522_0
55 | - libmklml=2019.0.5=0
56 | - libpng=1.6.37=h2a8f88b_0
57 | - libprotobuf=3.8.0=h7bd577a_0
58 | - libsodium=1.0.16=h9d3ae62_0
59 | - m2w64-bwidget=1.9.10=2
60 | - m2w64-bzip2=1.0.6=6
61 | - m2w64-expat=2.1.1=2
62 | - m2w64-fftw=3.3.4=6
63 | - m2w64-flac=1.3.1=3
64 | - m2w64-gcc-libgfortran=5.3.0=6
65 | - m2w64-gcc-libs=5.3.0=7
66 | - m2w64-gcc-libs-core=5.3.0=7
67 | - m2w64-gettext=0.19.7=2
68 | - m2w64-gmp=6.1.0=2
69 | - m2w64-libiconv=1.14=6
70 | - m2w64-libjpeg-turbo=1.4.2=3
71 | - m2w64-libogg=1.3.2=3
72 | - m2w64-libpng=1.6.21=2
73 | - m2w64-libsndfile=1.0.26=2
74 | - m2w64-libtiff=4.0.6=2
75 | - m2w64-libvorbis=1.3.5=2
76 | - m2w64-libwinpthread-git=5.0.0.4634.697f757=2
77 | - m2w64-mpfr=3.1.4=4
78 | - m2w64-openblas=0.2.19=1
79 | - m2w64-pcre=8.38=2
80 | - m2w64-speex=1.2rc2=3
81 | - m2w64-speexdsp=1.2rc3=3
82 | - m2w64-tcl=8.6.5=3
83 | - m2w64-tk=8.6.5=3
84 | - m2w64-tktable=2.10=5
85 | - m2w64-wineditline=2.101=5
86 | - m2w64-xz=5.2.2=2
87 | - m2w64-zlib=1.2.8=10
88 | - markupsafe=1.1.1=py36he774522_0
89 | - matplotlib=3.1.0=py36hc8f65d3_0
90 | - mccabe=0.6.1=py36_1
91 | - mistune=0.8.4=py36he774522_0
92 | - mkl=2018.0.3=1
93 | - mkl_fft=1.0.6=py36hdbbee80_0
94 | - mkl_random=1.0.1=py36h77b88f5_1
95 | - mock=3.0.5=py36_0
96 | - msys2-conda-epoch=20160418=1
97 | - nbconvert=5.5.0=py_0
98 | - nbformat=4.4.0=py36h3a5bc1b_0
99 | - notebook=5.6.0=py36_0
100 | - numpy=1.15.1=py36hc27ee41_0
101 | - numpy-base=1.15.1=py36h8128ebf_0
102 | - openssl=1.1.1c=he774522_1
103 | - pandas=0.24.2=py36ha925a31_0
104 | - pandoc=2.2.3.2=0
105 | - pandocfilters=1.4.2=py36_1
106 | - parso=0.4.0=py_0
107 | - patsy=0.5.1=py36_0
108 | - pickleshare=0.7.5=py36_0
109 | - pip=19.1.1=py36_0
110 | - prometheus_client=0.6.0=py36_0
111 | - prompt_toolkit=2.0.9=py36_0
112 | - protobuf=3.8.0=py36h33f27b4_0
113 | - psutil=5.6.3=py36he774522_0
114 | - pycodestyle=2.5.0=py36_0
115 | - pycparser=2.19=py36_0
116 | - pygments=2.4.0=py_0
117 | - pylint=1.9.2=py36_0
118 | - pyopenssl=19.0.0=py36_0
119 | - pyparsing=2.4.0=py_0
120 | - pyqt=5.9.2=py36h6538335_2
121 | - pyreadline=2.1=py36_1
122 | - pyrsistent=0.14.11=py36he774522_0
123 | - pysocks=1.7.0=py36_0
124 | - python=3.6.8=h9f7ef89_7
125 | - python-dateutil=2.8.0=py36_0
126 | - pytz=2019.1=py_0
127 | - pywinpty=0.5.5=py36_1000
128 | - pyyaml=5.1=py36he774522_0
129 | - pyzmq=18.0.0=py36ha925a31_0
130 | - qt=5.9.7=vc14h73c81de_0
131 | - r-assertthat=0.2.1=r36h6115d3f_0
132 | - r-base=3.6.0=hf18239d_0
133 | - r-bh=1.69.0_1=r36h6115d3f_0
134 | - r-bit=1.1_14=r36h6115d3f_0
135 | - r-bit64=0.9_7=r36h6115d3f_0
136 | - r-blob=1.1.1=r36h6115d3f_0
137 | - r-cli=1.1.0=r36h6115d3f_0
138 | - r-crayon=1.3.4=r36h6115d3f_0
139 | - r-dbi=1.0.0=r36h6115d3f_0
140 | - r-dbplyr=1.4.0=r36h6115d3f_0
141 | - r-digest=0.6.18=r36h6115d3f_0
142 | - r-dplyr=0.8.0.1=r36h6115d3f_0
143 | - r-fansi=0.4.0=r36h6115d3f_0
144 | - r-glue=1.3.1=r36h6115d3f_0
145 | - r-magrittr=1.5=r36h6115d3f_4
146 | - r-memoise=1.1.0=r36h6115d3f_0
147 | - r-pillar=1.3.1=r36h6115d3f_0
148 | - r-pkgconfig=2.0.2=r36h6115d3f_0
149 | - r-plogr=0.2.0=r36h6115d3f_0
150 | - r-prettyunits=1.0.2=r36h6115d3f_0
151 | - r-purrr=0.3.2=r36h6115d3f_0
152 | - r-r6=2.4.0=r36h6115d3f_0
153 | - r-rcpp=1.0.1=r36h6115d3f_0
154 | - r-rlang=0.3.4=r36h6115d3f_0
155 | - r-rsqlite=2.1.1=r36h6115d3f_0
156 | - r-tibble=2.1.1=r36h6115d3f_0
157 | - r-tidyselect=0.2.5=r36h6115d3f_0
158 | - r-utf8=1.1.4=r36h6115d3f_0
159 | - requests=2.21.0=py36_0
160 | - rope=0.14.0=py_0
161 | - rpy2=2.9.4=py36r36h39e3cac_0
162 | - scikit-learn=0.19.1=py36hae9bb9f_0
163 | - scipy=1.0.1=py36hce232c7_0
164 | - send2trash=1.5.0=py36_0
165 | - setuptools=41.0.1=py36_0
166 | - sip=4.19.8=py36h6538335_0
167 | - six=1.12.0=py36_0
168 | - snowballstemmer=1.2.1=py36h763602f_0
169 | - sphinx=1.6.3=py36h9bb690b_0
170 | - sphinxcontrib=1.0=py36_1
171 | - sphinxcontrib-websupport=1.1.0=py36_1
172 | - sqlite=3.28.0=he774522_0
173 | - statsmodels=0.10.1=py36h8c2d366_0
174 | - tbb=2019.4=h74a9793_0
175 | - tbb4py=2019.4=py36h74a9793_0
176 | - tensorboard=1.13.1=py36h33f27b4_0
177 | - tensorflow=1.13.1=mkl_py36hd212fbe_0
178 | - tensorflow-base=1.13.1=mkl_py36hcaf7020_0
179 | - tensorflow-estimator=1.13.0=py_0
180 | - termcolor=1.1.0=py36_1
181 | - terminado=0.8.2=py36_0
182 | - testpath=0.4.2=py36_0
183 | - tornado=5.1.1=py36hfa6e2cd_0
184 | - traitlets=4.3.2=py36h096827d_0
185 | - typing=3.6.4=py36_0
186 | - urllib3=1.24.2=py36_0
187 | - vc=14.1=h0510ff6_4
188 | - vs2015_runtime=14.15.26706=h3a45250_4
189 | - wcwidth=0.1.7=py36h3d5aa90_0
190 | - webencodings=0.5.1=py36_1
191 | - werkzeug=0.15.4=py_0
192 | - wheel=0.33.4=py36_0
193 | - win_inet_pton=1.1.0=py36_0
194 | - wincertstore=0.2=py36h7fe50ca_0
195 | - winpty=0.4.3=4
196 | - wrapt=1.11.1=py36he774522_0
197 | - yaml=0.1.7=hc54c509_2
198 | - zeromq=4.3.1=h33f27b4_3
199 | - zlib=1.2.11=h62dcd97_3
200 | - pip:
201 | - adal==1.2.1
202 | - azure-batch==6.0.1
203 | - azure-common==1.1.20
204 | - azure-storage-blob==2.0.1
205 | - azure-storage-common==2.0.0
206 | - commonmark==0.9.0
207 | - future==0.17.1
208 | - isodate==0.6.0
209 | - markdown==2.6.11
210 | - msrest==0.6.6
211 | - msrestazure==0.6.0
212 | - oauthlib==3.0.1
213 | - pyjwt==1.7.1
214 | - recommonmark==0.5.0
215 | - requests-oauthlib==1.2.0
216 | - sphinx-markdown-tables==0.0.9
217 | - sphinx-rtd-theme==0.4.3
218 |
219 |
--------------------------------------------------------------------------------
/test/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/test/__init__.py
--------------------------------------------------------------------------------
/test/dgp/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SparseSC/3b4d2ebd87b41fc8e0ec3e97fcb2be1b689275c0/test/dgp/__init__.py
--------------------------------------------------------------------------------
/test/dgp/factor_model.py:
--------------------------------------------------------------------------------
1 | """
2 | Factor DGP
3 | """
4 |
5 | import numpy as np
6 |
7 | def factor_dgp(N0,N1,T0,T1,K,R,F):
8 | '''
9 | Factor DGP. Treatment effect is 0.
10 |
11 | Covariates: Values are drawn from N(0,1) and coefficients are drawn from an
12 | exponential(1) and then scaled by the max.
13 | Factors and Loadings are from N(0,1)
14 | Errors are from N(0,1)
15 |
16 | :param N0:
17 | :param N1:
18 | :param T0:
19 | :param T1:
20 | :param K: Number of covariates that affect outcome
21 | :param R: Number of (noise) covariates to do not affect outcome
22 | :param F: Number of factors
23 | :returns: (X_control, X_treated, Y_pre_control, Y_pre_treated, Y_post_control, Y_post_treated, Loadings_control, Loadings_treated)
24 | '''
25 |
26 | if K>0:
27 | # COVARIATE EFFECTS
28 | X_control = np.matrix(np.random.normal(0,1,((N0), K+R)))
29 | X_treated = np.matrix(np.random.normal(0,1,((N1), K+R)))
30 |
31 | b_cause = np.random.exponential(1,K)
32 | b_cause *= 1 / b_cause.max()
33 |
34 | beta = np.matrix(np.concatenate( ( b_cause , np.zeros(R)) ) ).T
35 | Xbeta_C = X_control.dot(beta)
36 | Xbeta_T = X_treated.dot(beta)
37 | else:
38 | X_control = np.empty((N0,0))
39 | X_treated = np.empty((N1,0))
40 | Xbeta_C = np.zeros((N0,1))
41 | Xbeta_T = np.zeros((N1,1))
42 |
43 | # FACTORS
44 | Loadings_control = np.matrix(np.random.normal(0,1,((N0), F)))
45 | Loadings_treated = np.matrix(np.random.normal(0,1,((N1), F)))
46 |
47 | Factors_pre = np.matrix(np.random.normal(0,1,((F), T0)))
48 | Factors_post = np.matrix(np.random.normal(0,1,((F), T1)))
49 |
50 | # RANDOM ERRORS
51 | Y_pre_err_control = np.matrix(np.random.normal(0, 1, ( N0, T0, ) ))
52 | Y_pre_err_treated = np.matrix(np.random.normal(0, 1, ( N1, T0, ) ))
53 | Y_post_err_control = np.matrix(np.random.normal(0, 1, ( N0, T1, ) ))
54 | Y_post_err_treated = np.matrix(np.random.normal(0, 1, ( N1, T1, ) ))
55 |
56 | # OUTCOMES
57 | Y_pre_control = np.tile(Xbeta_C, (1,T0)) + Loadings_control.dot(Factors_pre) + Y_pre_err_control
58 | Y_pre_treated = np.tile(Xbeta_T, (1,T0)) + Loadings_treated.dot(Factors_pre) + Y_pre_err_treated
59 |
60 | Y_post_control = np.tile(Xbeta_C, (1,T1)) + Loadings_control.dot(Factors_post) + Y_post_err_control
61 | Y_post_treated = np.tile(Xbeta_T, (1,T1)) + Loadings_treated.dot(Factors_post) + Y_post_err_treated
62 |
63 | return X_control, X_treated, Y_pre_control, Y_pre_treated, Y_post_control, Y_post_treated, Loadings_control, Loadings_treated
64 |
65 |
--------------------------------------------------------------------------------
/test/dgp/group_effects.py:
--------------------------------------------------------------------------------
1 | """
Group-effects DGP
3 | """
4 |
5 | import numpy as np
6 |
7 | def ge_dgp(N0,N1,T0,T1,K,S,R,groups,group_scale,beta_scale,confounders_scale,model= "full"):
8 | """
9 | From example-code.py
10 | """
11 |
12 | # COVARIATE EFFECTS
13 | X_control = np.matrix(np.random.normal(0,1,((N0)*groups, K+S+R)))
14 | X_treated = np.matrix(np.random.normal(0,1,((N1)*groups, K+S+R)))
15 |
16 | # CAUSAL
17 | b_cause = np.random.exponential(1,K)
18 | b_cause *= beta_scale / b_cause.max()
19 |
20 | # CONFOUNDERS
21 | b_confound = np.random.exponential(1,S)
22 | b_confound *= confounders_scale / b_confound.max()
23 |
24 | beta_control = np.matrix(np.concatenate( ( b_cause ,b_confound, np.zeros(R)) ) ).T
25 | beta_treated = np.matrix(np.concatenate( ( b_cause ,np.zeros(S), np.zeros(R)) ) ).T
26 |
27 | # GROUP EFFECTS (hidden)
28 |
29 | Y_pre_group_effects = np.random.normal(0,group_scale,(groups,T0))
30 | Y_pre_ge_control = Y_pre_group_effects[np.repeat(np.arange(groups),N0)]
31 | Y_pre_ge_treated = Y_pre_group_effects[np.repeat(np.arange(groups),N1)]
32 |
33 | Y_post_group_effects = np.random.normal(0,group_scale,(groups,T1))
34 | Y_post_ge_control = Y_post_group_effects[np.repeat(np.arange(groups),N0)]
35 | Y_post_ge_treated = Y_post_group_effects[np.repeat(np.arange(groups),N1)]
36 |
37 | # RANDOM ERRORS
38 | Y_pre_err_control = np.matrix(np.random.random( ( N0*groups, T0, ) ))
39 | Y_pre_err_treated = np.matrix(np.random.random( ( N1*groups, T0, ) ))
40 |
41 | Y_post_err_control = np.matrix(np.random.random( ( N0*groups, T1, ) ))
42 | Y_post_err_treated = np.matrix(np.random.random( ( N1*groups, T1, ) ))
43 |
44 | # THE DATA GENERATING PROCESS
45 |
46 | if model == "full":
47 | """
48 | In the full model, covariates (X) are correlated with pre and post
49 | outcomes, and variance of the outcomes pre- and post- outcomes is lower
50 | within groups which span both treated and control units.
51 | """
52 | Y_pre_control = X_control.dot(beta_control) + Y_pre_ge_control + Y_pre_err_control
53 | Y_pre_treated = X_treated.dot(beta_treated) + Y_pre_ge_treated + Y_pre_err_treated
54 |
55 | Y_post_control = X_control.dot(beta_control) + Y_post_ge_control + Y_post_err_control
56 | Y_post_treated = X_treated.dot(beta_treated) + Y_post_ge_treated + Y_post_err_treated
57 |
58 | elif model == "hidden":
59 | """
60 | In the hidden model outcomes are independent of the covariates, but
61 | variance of the outcomes pre- and post- outcomes is lower within groups
62 | which span both treated and control units.
63 | """
64 | Y_pre_control = Y_pre_ge_control + Y_pre_err_control
65 | Y_pre_treated = Y_pre_ge_treated + Y_pre_err_treated
66 |
67 | Y_post_control = Y_post_ge_control + Y_post_err_control
68 | Y_post_treated = Y_post_ge_treated + Y_post_err_treated
69 |
70 | elif model == "null":
71 | """
72 | Purely random data
73 | """
74 | Y_pre_control = Y_pre_err_control
75 | Y_pre_treated = Y_pre_err_treated
76 |
77 | Y_post_control = Y_post_err_control
78 | Y_post_treated = Y_post_err_treated
79 |
80 | else:
81 | raise ValueError("Unknown model type: "+model)
82 |
83 | return X_control, X_treated, Y_pre_control, Y_pre_treated, Y_post_control, Y_post_treated
84 |
85 |
--------------------------------------------------------------------------------
/test/test.pyproj:
--------------------------------------------------------------------------------
1 |
2 |
3 | Debug
4 | 2.0
5 | 01446f94-f552-4ee1-90dc-e93a0db22b4c
6 | .
7 | test_estimation.py
8 |
9 |
10 | .
11 | .
12 | test
13 | test
14 | CondaEnv|CondaEnv|SparseSC_36
15 | true
16 |
17 |
18 | true
19 | false
20 |
21 |
22 | true
23 | false
24 |
25 |
26 |
27 | Code
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
47 |
48 |
49 |
50 |
51 |
52 |
--------------------------------------------------------------------------------
/test/test_batchFile.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------------------------------
2 | # Programmer: Jason Thorpe
3 | # Date 1/25/2019 3:34:02 PM
4 | # Language: Python (.py) Version 2.7 or 3.5
5 | # Usage:
6 | #
7 | # Test all model types
8 | #
#       \SparseSC > python -m unittest test/test_fit.py
10 | #
11 | # Test a specific model type (e.g. "prospective-restricted"):
12 | #
#       \SparseSC > python -m unittest test.test_fit.TestFit.test_retrospective
14 | #
15 | # --------------------------------------------------------------------------------
16 |
17 | from __future__ import print_function # for compatibility with python 2.7
18 | import numpy as np
19 | import sys
20 | import random
21 | import unittest
22 | import warnings
23 | from os.path import expanduser, join
24 | from scipy.optimize.linesearch import LineSearchWarning
25 |
26 | try:
27 | from SparseSC import fit
28 | except ImportError:
29 | raise RuntimeError("SparseSC is not installed. use 'pip install -e .' to install")
30 |
31 |
class TestFit(unittest.TestCase):
    """
    Smoke-tests `fit` for each model type, writing batch parameters to a
    file (`batchFile`) under ~/temp. Only checks that fitting completes.
    """

    def setUp(self):
        """Build a small, seeded random dataset (70 units, 10 features, 5 targets)."""
        random.seed(12345)
        np.random.seed(101101001)
        control_units = 50
        treated_units = 20
        features = 10
        targets = 5

        self.X = np.random.rand(control_units + treated_units, features)
        self.Y = np.random.rand(control_units + treated_units, targets)
        self.treated_units = np.arange(treated_units)

    @classmethod
    def run_test(cls, obj, model_type, verbose=False):
        """
        Call `fit` with the given model type, suppressing line-search and
        pending-deprecation warnings; re-raises any other exception after
        printing a diagnostic.
        """
        if verbose:
            print("Calling fit with `model_type = '%s'`..." % (model_type,), end="")
            sys.stdout.flush()

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=PendingDeprecationWarning)
            warnings.filterwarnings("ignore", category=LineSearchWarning)
            try:
                fit(
                    X=obj.X,
                    Y=obj.Y,
                    model_type=model_type,
                    # treated_units only applies to models with a treated group
                    treated_units=obj.treated_units
                    if model_type
                    in ("retrospective", "prospective", "prospective-restricted")
                    else None,
                    # KWARGS:
                    print_path=False,
                    progress=verbose,
                    grid_length=5,
                    min_iter=-1,
                    tol=1,
                    verbose=0,
                    batchFile=join(
                        expanduser("~"), "temp", "%s_batch_params.py" % model_type
                    ),
                )
                # (a leftover `pdb.set_trace()` breakpoint was removed here --
                # it halted every automated run of this test)
                if verbose:
                    print("DONE")
            except LineSearchWarning:
                pass
            except PendingDeprecationWarning:
                pass
            except Exception as exc:
                print("Failed with %s: %s" % (exc.__class__.__name__, str(exc)))
                raise exc

    def test_retrospective(self):
        TestFit.run_test(self, "retrospective")

    def test_prospective(self):
        TestFit.run_test(self, "prospective")

    def test_prospective_restrictive(self):
        # Catch the LineSearchWarning silently, but allow others
        TestFit.run_test(self, "prospective-restricted")

    def test_full(self):
        TestFit.run_test(self, "full")
101 |
102 |
if __name__ == "__main__":
    # Run the full suite; the commented lines are handy for running a
    # single test under a debugger.
    # t = TestFit()
    # t.setUp()
    # t.test_retrospective()
    unittest.main()
108 |
--------------------------------------------------------------------------------
/test/test_estimation.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests for model fitness
3 | """
4 |
5 | import unittest
6 | import random
7 | import numpy as np
8 | import pandas as pd
9 |
10 | #import warnings
11 | #warnings.simplefilter("error")
12 |
13 | try:
14 | import SparseSC as SC
15 | except ImportError:
16 | raise RuntimeError("SparseSC is not installed. Use 'pip install -e .' or 'conda develop .' from repo root to install in dev mode")
17 | from os.path import join, abspath, dirname
18 | from test.dgp.factor_model import factor_dgp
19 |
20 | # import matplotlib.pyplot as plt
21 |
22 | import sys
23 | sys.path.insert(0, join(dirname(abspath(__file__)), "..", "examples"))
24 | from example_graphs import *
25 |
26 |
27 | # pylint: disable=no-self-use, missing-docstring
28 |
29 | # here for lexical scoping
30 | command_line_options = {}
31 |
class TestEstimationForErrors(unittest.TestCase):
    """
    Smoke-tests SC.estimate_effects across model types and input container
    types (ndarray, DataFrame, DataFrame with a datetime column index),
    checking only that no exception is raised.
    """

    def setUp(self):
        # Small random panel: 52 units (2 treated) over 15 periods, 2 covariates.
        random.seed(12345)
        np.random.seed(101101001)
        control_units = 50
        treated_units = 2
        N = control_units + treated_units
        T = 15
        K_X = 2

        self.Y = np.random.rand(N, T)
        self.X = np.random.rand(N, K_X)
        self.treated_units = np.arange(treated_units)
        # NaN = never treated; units 0 and 1 are treated at periods 7 and 8.
        self.unit_treatment_periods = np.full((N), np.nan)
        self.unit_treatment_periods[0] = 7
        self.unit_treatment_periods[1] = 8
        #self.
        #self.unit_treatment_periods[treated_name] = treatment_date_ms

    @classmethod
    def run_test(cls, obj, model_type="retrospective", frame_type="ndarray"): # frame_type: "ndarray", "NDFrame", or "timeindex"
        """Run SC.estimate_effects on the fixture data, first converting the
        inputs to the container type named by `frame_type`."""
        X = obj.X
        Y = obj.Y
        unit_treatment_periods = obj.unit_treatment_periods
        if frame_type=="NDFrame" or frame_type=="timeindex":
            X = pd.DataFrame(X)
            Y = pd.DataFrame(Y)
            if frame_type=="timeindex":
                # Re-express periods as calendar dates; treatment periods
                # become a datetime Series (NaT = never treated).
                t_index = pd.Index(np.datetime64('2000-01-01','D') + range(Y.shape[1]))
                unit_treatment_periods = pd.Series(np.datetime64('NaT'), index=Y.index)
                unit_treatment_periods[0] = t_index[7]
                unit_treatment_periods[1] = t_index[8]
                Y.columns = t_index

        SC.estimate_effects(covariates=X, outcomes=Y, model_type=model_type, unit_treatment_periods=unit_treatment_periods)

    def test_all(self): #RidgeCV returns: RuntimeWarning: invalid value encountered in true_divide \n return (c / G_diag) ** 2, c
        # Exercise every (model_type, frame_type) combination.
        for model_type in ["retrospective", "prospective", "prospective-restricted"]:
            for frame_type in ["ndarray", "NDFrame", "timeindex"]:
                TestEstimationForErrors.run_test(self, model_type, frame_type)
73 |
class TestDGPs(unittest.TestCase):
    """
    Model-fitness tests against known data generating processes.
    """
    @staticmethod
    def simple_summ(fit, Y):
        """Print a quick summary of `fit`: match space (or V), a few example
        synthetic-control weights, and the treated unit's prediction error."""
        #print("V_pen=%s, W_pen=%s" % (fit.fitted_v_pen, fit.fitted_w_pen))
        if fit.match_space_desc is not None:
            print(fit.match_space_desc)
        else:
            print("V=%s" % np.diag(fit.V))
        # NOTE(review): the hard-coded indices below (0/1/51 rows, 49/99 cols)
        # assume the unit layout of testSimpleTrendDGP (1 treated, 50 "sim"
        # donors, 50 "not" donors) -- confirm before reusing elsewhere.
        print("Treated weights: sim=%s, uns=%s, sum=%s" % ( fit.sc_weights[0, 49], fit.sc_weights[0, 99], sum(fit.sc_weights[0, :]),))
        print("Sim Con weights: sim=%s, uns=%s, sum=%s" % ( fit.sc_weights[1, 49], fit.sc_weights[1, 99], sum(fit.sc_weights[1, :]),))
        print("Uns Con weights: sim=%s, uns=%s, sum=%s" % ( fit.sc_weights[51, 49], fit.sc_weights[51, 99], sum(fit.sc_weights[51, :]),))
        Y_sc = fit.predict(Y)
        print("Treated diff: %s" % (Y - Y_sc)[0, :])


    def testSimpleTrendDGP(self):
        """
        No X, just Y; half the donors are great, other half are bad.
        Checks that the estimated CI covers the true effect (te=2) and
        that the p-value is small.
        """
        N1, N0_sim, N0_not = 1, 50, 50
        N0 = N0_sim + N0_not
        N = N1 + N0
        treated_units, control_units = range(N1), range(N1, N)
        T0, T1 = 5, 2
        T = T0 + T1 # unused
        # "sim" donors share the treated unit's trend; "not" donors differ.
        proto_sim = np.array([1, 2, 3, 4, 5] + [6,7], ndmin=2)
        proto_not = np.array([0, 2, 4, 6, 8] + [10, 12], ndmin=2)
        te = 2
        # Treated prototype = sim prototype + te in the post-period only.
        proto_tr = proto_sim + np.hstack((np.zeros((1, T0)), np.full((1, T1), te)))
        Y1 = np.matmul(np.ones((N1, 1)), proto_tr)
        Y0_sim = np.matmul(np.ones((N0_sim, 1)), proto_sim)
        Y0_sim = Y0_sim + np.random.normal(0,0.1,Y0_sim.shape)
        #Y0_sim = Y0_sim + np.hstack((np.zeros((N0_sim,1)),
        #                             np.random.normal(0,0.1,(N0_sim,1)),
        #                             np.zeros((N0_sim,T-2))))
        Y0_not = np.matmul(np.ones((N0_not, 1)), proto_not)
        Y0_not = Y0_not + np.random.normal(0,0.1,Y0_not.shape)
        Y = np.vstack((Y1, Y0_sim, Y0_not))

        # -1 = never treated; unit 0 is treated at period T0.
        unit_treatment_periods = np.full((N), -1)
        unit_treatment_periods[0] = T0

        # Y += np.random.normal(0, 0.01, Y.shape)

        # OPTIMIZE OVER THE V_PEN'S
        # for v_pen, w_pen in [(1,1), (1,1e-10), (1e-10,1e-10), (1e-10,1), (None, None)]: #
        # print("\nv_pen=%s, w_pen=%s" % (v_pen, w_pen))
        ret = SC.estimate_effects(
            Y,
            unit_treatment_periods,
            ret_CI=True,
            max_n_pl=200,
            fast = True,
            #stopping_rule=4,
            **command_line_options,
        )
        TestDGPs.simple_summ(ret.fits[T0], Y)
        V_penalty = ret.fits[T0].fitted_v_pen  # (unused)

        Y_sc = ret.fits[T0].predict(Y)# [control_units, :]
        # NOTE(review): [0:T0:] selects rows 0..T0-1; the intent was likely
        # (Y - Y_sc)[0, T0:] (treated unit, post-period). The variable is
        # unused, so this is currently harmless -- confirm before using it.
        te_vec_est = (Y - Y_sc)[0:T0:]
        # weight_sums = np.sum(ret.fit.sc_weights, axis=1)

        # print(ret.fit.scores)
        p_value = ret.p_value
        #print("p-value: %s" % p_value)
        #print( ret.CI)
        #print(np.diag(ret.fit.V))
        #import pdb; pdb.set_trace()
        # print(ret)
        assert te in ret.CI, "Confidence interval does not include the true effect"
        assert p_value is not None
        assert p_value < 0.1, "P-value is larger than expected"

        # [sc_raw, sc_diff] = ind_sc_plots(Y[0, :], Y_sc[0, :], T0, ind_ci=ret.ind_CI)
        # plt.figure("sc_raw")
        # plt.title("Unit 0")
        # ### SHOW() blocks!!!!
        # # plt.show()
        # plt.figure("sc_diff")
        # plt.title("Unit 0")
        # # plt.show()
        # [te] = te_plot(ret)
        # plt.figure("te")
        # plt.title("Average Treatment Effect")
        # # plt.show()

    def testFactorDGP(self):
        """
        Factor-DGP-based test: estimate effects on data generated by
        factor_dgp (true treatment effect is 0) without raising.
        """
        N1, N0 = 2, 100
        treated_units = [0, 1]
        T0, T1 = 20, 10
        K, R, F = 5, 5, 5
        (
            Cov_control,
            Cov_treated,
            Out_pre_control,
            Out_pre_treated,
            Out_post_control,
            Out_post_treated,
            _,
            _,
        ) = factor_dgp(N0, N1, T0, T1, K, R, F)

        # Stack treated units first so indices [0, 1] are the treated rows.
        Cov = np.vstack((Cov_treated, Cov_control))
        Out_pre = np.vstack((Out_pre_treated, Out_pre_control))
        Out_post = np.vstack((Out_post_treated, Out_post_control))

        SC.estimate_effects(
            Out_pre,
            Out_post,
            treated_units,
            Cov,
            **command_line_options,
        )

        # print(fit_res)
        # est_res = SC.estimate_effects(
        #     Cov, Out_pre, Out_post, treated_units, V_penalty=0, W_penalty=0.001
        # )
        # print(est_res)
200 |
201 |
if __name__ == "__main__":
    # Parse command-line options (forwarded to SC.estimate_effects through
    # the module-level `command_line_options` dict), then run the error tests.
    import argparse

    parser = argparse.ArgumentParser(prog="PROG", allow_abbrev=False)
    parser.add_argument(
        "--constrain", choices=["orthant", "simplex"], default="simplex"
    )
    args = parser.parse_args()
    command_line_options.update(vars(args))

    # Seed both RNGs for reproducible runs.
    random.seed(12345)
    np.random.seed(10101)

    t = TestEstimationForErrors()
    t.setUp()
    t.test_all()
    # unittest.main()
219 |
--------------------------------------------------------------------------------
/test/test_fit_batch.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------------------------------
2 | # Programmer: Jason Thorpe
3 | # Date 1/25/2019 3:34:02 PM
4 | # Language: Python (.py) Version 2.7 or 3.5
5 | # Usage:
6 | #
7 | # Test all model types
8 | #
#       \SparseSC > python -m unittest test/test_fit.py
10 | #
11 | # Test a specific model type (e.g. "prospective-restricted"):
12 | #
#       \SparseSC > python -m unittest test.test_fit.TestFit.test_retrospective
14 | #
15 | # --------------------------------------------------------------------------------
16 |
17 | from __future__ import print_function # for compatibility with python 2.7
18 | import numpy as np
19 | import sys, random
20 | import unittest
21 | import warnings
22 | from scipy.optimize.linesearch import LineSearchWarning
23 |
24 | try:
25 | from SparseSC import fit
26 | except ImportError:
27 | raise RuntimeError("SparseSC is not installed. use 'pip install -e .' to install")
28 |
29 |
class TestFit(unittest.TestCase):
    """
    Smoke-tests `fit`, running it once in-process and once through the
    "sg_daemon" batch client. Only checks that fitting completes.
    """

    def setUp(self):
        """Build a small, seeded random dataset (70 units, 10 features, 5 targets)."""
        random.seed(12345)
        np.random.seed(101101001)
        control_units = 50
        treated_units = 20
        features = 10
        targets = 5

        self.X = np.random.rand(control_units + treated_units, features)
        self.Y = np.random.rand(control_units + treated_units, targets)
        self.treated_units = np.arange(treated_units)

    @classmethod
    def run_test(cls, obj, model_type, verbose=False):
        """
        Fit the model twice -- once in-process and once via the batch
        client -- suppressing line-search and deprecation warnings.
        Re-raises any other exception after printing a diagnostic.
        """
        if verbose:
            print("Calling fit with `model_type = '%s'`..." % (model_type,), end="")
            sys.stdout.flush()

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=PendingDeprecationWarning)
            warnings.filterwarnings("ignore", category=LineSearchWarning)
            try:
                # In-process fit. (Return values are not inspected, so they
                # are not bound to unused locals.)
                fit(
                    features=obj.X,
                    targets=obj.Y,
                    model_type=model_type,
                    treated_units=obj.treated_units
                    if model_type
                    in ("retrospective", "prospective", "prospective-restricted")
                    else None,
                    # KWARGS:
                    print_path=False,
                    stopping_rule=1,
                    progress=verbose,
                    grid_length=5,
                    min_iter=-1,
                    tol=1,
                    verbose=0,
                )
                # Same fit, routed through the batch client.
                fit(
                    features=obj.X,
                    targets=obj.Y,
                    model_type=model_type,
                    treated_units=obj.treated_units
                    if model_type
                    in ("retrospective", "prospective", "prospective-restricted")
                    else None,
                    # KWARGS:
                    print_path=False,
                    stopping_rule=1,
                    progress=verbose,
                    grid_length=5,
                    min_iter=-1,
                    tol=1,
                    verbose=0,
                    batch_client_config="sg_daemon",
                )
                if verbose:
                    print("DONE")
            except LineSearchWarning:
                pass
            except PendingDeprecationWarning:
                pass
            except Exception as exc:
                # str(exc) rather than getattr(exc, "message", "<>"):
                # `exc.message` is a Python-2-ism that never exists on
                # Python 3, so the old code always printed "<>".
                print("Failed with %s: %s" % (exc.__class__.__name__, str(exc)))
                raise exc

    def test_retrospective(self):
        TestFit.run_test(self, "retrospective")
104 |
105 |
106 | # -- def test_prospective(self):
107 | # -- TestFit.run_test(self, "prospective")
108 | # --
109 | # -- def test_prospective_restrictive(self):
110 | # -- # Catch the LineSearchWarning silently, but allow others
111 | # --
112 | # -- TestFit.run_test(self, "prospective-restricted")
113 | # --
114 | # -- def test_full(self):
115 | # -- TestFit.run_test(self, "full")
116 |
117 |
if __name__ == "__main__":
    # Run the full suite; the commented lines are handy for running a
    # single test under a debugger.
    # t = TestFit()
    # t.setUp()
    # t.test_retrospective()
    unittest.main()
123 |
--------------------------------------------------------------------------------
/test/test_normal.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import numpy as np
3 |
4 | try:
5 | import SparseSC
6 | except ImportError:
7 | raise RuntimeError("SparseSC is not installed. Use 'pip install -e .' or 'conda develop .' from repo root to install in dev mode")
8 |
9 |
class TestNormalForErrors(unittest.TestCase):
    """Smoke-tests the descriptive-statistics helpers exported by SparseSC."""

    def test_SSC_DescrStat(self):
        """Stats merged with ``+`` or ``update`` must equal stats over all rows."""
        data = np.arange(20).reshape(4, 5)
        upper, lower = data[:3, :], data[3:, :]
        stat_upper = SparseSC.SSC_DescrStat.from_data(upper)
        stat_lower = SparseSC.SSC_DescrStat.from_data(lower)
        combined = stat_upper + stat_lower
        stat_upper.update(lower)  # in-place merge of the remaining rows
        stat_full = SparseSC.SSC_DescrStat.from_data(data)
        assert combined == stat_upper, "ds_add == ds_top"
        assert combined == stat_full, "ds_add == ds_whole"

    def test_DescrSet(self):
        """A constant +1 treated-vs-counterfactual gap must yield an effect of 1."""
        counterfactual = np.arange(20).reshape(4, 5)
        treated = counterfactual + 1
        descr = SparseSC.DescrSet.from_data(Y_t=treated, Y_t_cf_c=counterfactual)
        estimates = descr.calc_estimates()
        assert np.array_equal(estimates.att_est.effect, np.full((5), 1.0)), "estimations not right"
28 |
if __name__ == "__main__":
    # Run the two checks directly (without the unittest runner).
    t = TestNormalForErrors()
    t.test_SSC_DescrStat()
    t.test_DescrSet()

    #unittest.main()
35 |
--------------------------------------------------------------------------------