(undefined);
18 | React.useEffect(() => {
19 | if (data) {
20 | console.log('HERE');
21 | let filteredRequests: PromoteRequest[] = Object.assign([], data);
22 | if (modelFilter.length > 0) {
23 | filteredRequests = filteredRequests.filter((request) => modelFilter.includes(request.model_name));
24 | }
25 |
26 | if (statusFilter.length > 0) {
27 | filteredRequests = filteredRequests.filter((request) => {
28 | if (request.status === 'open') {
29 | return statusFilter.includes('Open');
30 | }
31 | if (request.status === 'closed' || request.status === 'approved') {
32 | return statusFilter.includes('Closed');
33 | }
34 | return false;
35 | });
36 | }
37 | setRequests(filteredRequests);
38 | } else {
39 | setRequests(undefined);
40 | }
41 | }, [modelFilter, statusFilter, data]);
42 |
43 | const modelOptions = Array.from(new Set(data?.map((request) => request.model_name) || []));
44 | return (
45 |
46 |
47 |
48 |
49 | ({ displayName: model_name, value: model_name }))}
52 | onChangeValue={setModelFilter}
53 | />
54 |
55 |
64 |
65 |
66 |
67 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 | );
78 | };
79 |
80 | export default Component;
81 |
--------------------------------------------------------------------------------
/checkpoint/frontend/src/pages/RequestList/RequestListVM.tsx:
--------------------------------------------------------------------------------
1 | import * as React from 'react';
2 | import { AppState } from '../../redux/state';
3 | import { fetchRequests } from '../../redux/actions/appActions';
4 | import { ThunkDispatch } from 'redux-thunk';
5 | import { AnyAction } from 'redux';
6 | import { connect, ConnectedProps } from 'react-redux';
7 | import RequestList from './RequestList';
8 |
9 | const mapState = (state: AppState) => {
10 | return {
11 | requests: state.app.requests,
12 | };
13 | };
14 |
15 | const mapDispatchToProps = (dispatch: ThunkDispatch) => {
16 | return {
17 | fetchRequests: () => dispatch(fetchRequests()),
18 | };
19 | };
20 |
21 | const connector = connect(mapState, mapDispatchToProps);
22 |
23 | type PropsFromRedux = ConnectedProps<typeof connector>;
24 |
25 | type Props = PropsFromRedux;
26 |
27 | const RequestListVM: React.FC = ({ requests, fetchRequests }) => {
28 | React.useEffect(() => {
29 | console.log('Home set');
30 | fetchRequests();
31 | return () => {
32 | console.log('Home clear');
33 | };
34 | }, []);
35 | return ;
36 | };
37 |
38 | const ConnectedViewModel = connector(RequestListVM);
39 |
40 | export default ConnectedViewModel;
41 |
--------------------------------------------------------------------------------
/checkpoint/frontend/src/pages/RequestList/index.ts:
--------------------------------------------------------------------------------
1 | import RequestList from './RequestListVM';
2 |
3 | export default RequestList;
4 |
--------------------------------------------------------------------------------
/checkpoint/frontend/src/pages/ShowRequest/ShowRequest.module.scss:
--------------------------------------------------------------------------------
1 | .diffTable {
2 | :global {
3 | .ant-table-selection-column {
4 | display: none;
5 | }
6 | }
7 | }
--------------------------------------------------------------------------------
/checkpoint/frontend/src/pages/ShowRequest/ShowRequestVM.tsx:
--------------------------------------------------------------------------------
1 | import * as React from 'react';
2 | import { AppState } from '../../redux/state';
3 | import { ThunkDispatch } from 'redux-thunk';
4 | import { AnyAction } from 'redux';
5 | import { connect, ConnectedProps } from 'react-redux';
6 | import ShowRequest from './ShowRequest';
7 | import { useParams } from 'react-router-dom';
8 | import { fetchRequestDetails, clearRequestDetails, fetchRequests, submitReview } from 'redux/actions/appActions';
9 | import { CreateReview } from '../../utils/types';
10 | import { History } from 'history';
11 |
12 | const mapState = (state: AppState) => {
13 | return {
14 | requests: state.app.requests,
15 | details: state.app.details,
16 | };
17 | };
18 |
19 | const mapDispatchToProps = (dispatch: ThunkDispatch) => {
20 | return {
21 | fetchRequestDetails: (request_id: number) => dispatch(fetchRequestDetails(request_id)),
22 | clearRequestDetails: () => dispatch(clearRequestDetails()),
23 | fetchRequests: () => dispatch(fetchRequests()),
24 | submitReview: (history: History, request_id: string, request: CreateReview) =>
25 | dispatch(submitReview(history, request_id, request)),
26 | };
27 | };
28 |
29 | const connector = connect(mapState, mapDispatchToProps);
30 |
31 | type PropsFromRedux = ConnectedProps<typeof connector>;
32 |
33 | type Props = PropsFromRedux;
34 |
35 | interface Params {
36 | request_id?: string;
37 | }
38 |
39 | const ShowRequestVM: React.FC = ({
40 | requests,
41 | details,
42 | fetchRequests,
43 | fetchRequestDetails,
44 | clearRequestDetails,
45 | submitReview,
46 | }) => {
47 | const { request_id } = useParams<Params>();
48 | let rid: number | undefined = undefined;
49 | if (request_id) {
50 | rid = parseInt(request_id, 10);
51 | }
52 | React.useEffect(() => {
53 | clearRequestDetails();
54 | fetchRequests();
55 | if (rid) {
56 | fetchRequestDetails(rid);
57 | }
58 | }, []);
59 | const request = requests?.find((request) => request.id === rid);
60 |
61 | return ;
62 | };
63 |
64 | const ConnectedViewModel = connector(ShowRequestVM);
65 |
66 | export default ConnectedViewModel;
67 |
--------------------------------------------------------------------------------
/checkpoint/frontend/src/pages/ShowRequest/index.ts:
--------------------------------------------------------------------------------
1 | import ShowRequestVM from './ShowRequestVM';
2 |
3 | export default ShowRequestVM;
4 |
--------------------------------------------------------------------------------
/checkpoint/frontend/src/pages/index.ts:
--------------------------------------------------------------------------------
1 | export { default as RequestList } from './RequestList';
2 | export { default as ShowRequest } from './ShowRequest';
3 | export { default as RequestForm } from './RequestForm';
4 |
--------------------------------------------------------------------------------
/checkpoint/frontend/src/react-app-env.d.ts:
--------------------------------------------------------------------------------
1 | ///
2 |
--------------------------------------------------------------------------------
/checkpoint/frontend/src/redux/CONSTANTS.ts:
--------------------------------------------------------------------------------
1 | export const example = 3;
2 |
--------------------------------------------------------------------------------
/checkpoint/frontend/src/redux/actions/appActions.ts:
--------------------------------------------------------------------------------
1 | // You can use CONSTANTS.js file for below definitions of constants and import here.
2 | import { ThunkDispatch } from 'redux-thunk';
3 | import { AppState } from '../state';
4 | import { PromoteRequest, CreatePromoteRequest, RequestDetails, CreateReview } from '../../utils/types';
5 | import { AnyAction } from 'redux';
6 | import { History } from 'history';
7 |
8 | const root = process.env.NODE_ENV === 'development' ? 'http://localhost:5000' : '';
9 |
10 | export const API_ERROR = 'API_ERROR';
11 |
12 | export const FETCH_MODELS = 'FETCH_MODELS';
13 |
14 | export const GOT_MODELS = 'GOT_MODELS';
15 |
16 | export const gotModels = (models: string[]): AnyAction => ({
17 | type: GOT_MODELS,
18 | models: models,
19 | });
20 |
21 | export const fetchModels = () => {
22 | return async (dispatch: ThunkDispatch): Promise => {
23 | const response = await fetch(`${root}/checkpoint/api/models`);
24 | if (response.status === 200) {
25 | try {
26 | dispatch(gotModels(await response.json()));
27 | } catch {}
28 | } else {
29 | console.error(response);
30 | }
31 | };
32 | };
33 |
34 | export const FETCH_VERSIONS = 'FETCH_VERSIONS';
35 |
36 | export const GOT_VERSIONS = 'GOT_VERSIONS';
37 |
38 | export const gotVersions = (versions: string[]): AnyAction => ({
39 | type: GOT_VERSIONS,
40 | versions: versions,
41 | });
42 |
43 | export const fetchVersions = (model: string) => {
44 | return async (dispatch: ThunkDispatch): Promise => {
45 | const response = await fetch(`${root}/checkpoint/api/models/${model}/versions`);
46 | if (response.status === 200) {
47 | try {
48 | dispatch(gotVersions(await response.json()));
49 | } catch {}
50 | } else {
51 | console.error(response);
52 | }
53 | };
54 | };
55 |
56 | export const FETCH_REQUESTS = 'FETCH_REQUESTS';
57 |
58 | export const GOT_REQUESTS = 'GOT_REQUESTS';
59 |
60 | export const gotRequests = (requests: PromoteRequest[]): AnyAction => ({
61 | type: GOT_REQUESTS,
62 | requests: requests,
63 | });
64 |
65 | export const fetchRequests = () => {
66 | return async (dispatch: ThunkDispatch): Promise => {
67 | const response = await fetch(`${root}/checkpoint/api/requests`);
68 | if (response.status === 200) {
69 | try {
70 | dispatch(gotRequests(await response.json()));
71 | } catch {}
72 | } else {
73 | console.error(response);
74 | }
75 | };
76 | };
77 |
78 | export const FETCH_STAGES = 'FETCH_STAGES';
79 |
80 | export const GOT_STAGES = 'GOT_STAGES';
81 |
82 | export const gotStages = (stages: string[]): AnyAction => ({
83 | type: GOT_STAGES,
84 | stages: stages,
85 | });
86 |
87 | export const fetchStages = () => {
88 | return async (dispatch: ThunkDispatch): Promise => {
89 | const response = await fetch(`${root}/checkpoint/api/stages`);
90 | if (response.status === 200) {
91 | try {
92 | dispatch(gotStages(await response.json()));
93 | } catch {}
94 | } else {
95 | console.error(response);
96 | }
97 | };
98 | };
99 |
100 | export const SUBMIT_REQUEST = 'SUBMIT_REQUEST';
101 |
102 | export const SUBMIT_REQUEST_ERROR = 'SUBMIT_REQUEST_ERROR';
103 |
104 | export const gotSubmitRequestError = (error: string): AnyAction => ({
105 | type: SUBMIT_REQUEST_ERROR,
106 | error: error,
107 | });
108 |
109 | export const CLEAR_SUBMIT_REQUEST_ERROR = 'CLEAR_SUBMIT_REQUEST_ERROR';
110 |
111 | export const clearSubmitRequestError = (): AnyAction => ({
112 | type: CLEAR_SUBMIT_REQUEST_ERROR,
113 | });
114 |
115 | export const submitRequest = (history: History, request: CreatePromoteRequest) => {
116 | return async (dispatch: ThunkDispatch): Promise => {
117 | try {
118 | const response = await fetch(`${root}/checkpoint/api/requests`, {
119 | method: 'POST',
120 | headers: {
121 | 'Content-Type': 'application/json',
122 | },
123 | body: JSON.stringify(request),
124 | });
125 | if (response.status === 200) {
126 | const data = await response.json();
127 | console.log(data);
128 | history.push(`/checkpoint/requests/${data.id}`);
129 | console.log('SUCCESS');
130 | } else {
131 | dispatch(
132 | gotSubmitRequestError(
133 | `Error submitting request (${response.status}: ${response.statusText}): ${await response.text()}`,
134 | ),
135 | );
136 | }
137 | } catch (error) {
138 | console.error(error);
139 | }
140 | };
141 | };
142 |
143 | export const FETCH_REQUEST_DETAILS = 'FETCH_REQUEST_DETAILS';
144 |
145 | export const CLEAR_REQUEST_DETAILS = 'CLEAR_REQUEST_DETAILS';
146 |
147 | export const clearRequestDetails = (): AnyAction => ({
148 | type: CLEAR_REQUEST_DETAILS,
149 | });
150 |
151 | export const GOT_REQUEST_DETAILS = 'GOT_REQUEST_DETAILS';
152 |
153 | export const gotRequestDetails = (details: RequestDetails): AnyAction => ({
154 | type: GOT_REQUEST_DETAILS,
155 | details: details,
156 | });
157 |
158 | export const fetchRequestDetails = (request_id: number) => {
159 | return async (dispatch: ThunkDispatch): Promise => {
160 | const response = await fetch(`${root}/checkpoint/api/requests/${request_id}/details`);
161 | if (response.status === 200) {
162 | try {
163 | dispatch(gotRequestDetails(await response.json()));
164 | } catch {}
165 | } else {
166 | console.error(response);
167 | }
168 | };
169 | };
170 |
171 | export const SUBMIT_REVIEW = 'SUBMIT_REVIEW';
172 |
173 | export const submitReview = (history: History, request_id: string, request: CreateReview) => {
174 | return async (dispatch: ThunkDispatch): Promise => {
175 | try {
176 | const response = await fetch(`${root}/checkpoint/api/requests/${request_id}`, {
177 | method: 'PUT',
178 | headers: {
179 | 'Content-Type': 'application/json',
180 | },
181 | body: JSON.stringify(request),
182 | });
183 | if (response.status === 200) {
184 | history.push(`/checkpoint/requests`);
185 | } else {
186 | dispatch(
187 | gotSubmitRequestError(
188 | `Error submitting request (${response.status}: ${response.statusText}): ${await response.text()}`,
189 | ),
190 | );
191 | }
192 | } catch (error) {
193 | console.error(error);
194 | }
195 | };
196 | };
197 |
--------------------------------------------------------------------------------
/checkpoint/frontend/src/redux/reducers/appReducer.ts:
--------------------------------------------------------------------------------
1 | import { createReducer } from 'typesafe-actions';
2 | import { RootState } from '../state';
3 |
4 | const initialState = {
5 | requests: undefined,
6 | models: undefined,
7 | versions: undefined,
8 | } as RootState;
9 |
10 | export const appReducer = createReducer(initialState, {
11 | API_ERROR: (state) => {
12 | return state;
13 | },
14 | GOT_MODELS: (state, action) => {
15 | return {
16 | ...state,
17 | models: action.models,
18 | };
19 | },
20 | GOT_VERSIONS: (state, action) => {
21 | return {
22 | ...state,
23 | versions: action.versions,
24 | };
25 | },
26 | GOT_STAGES: (state, action) => {
27 | return {
28 | ...state,
29 | stages: action.stages,
30 | };
31 | },
32 | GOT_REQUESTS: (state, action) => {
33 | return {
34 | ...state,
35 | requests: action.requests,
36 | };
37 | },
38 | SUBMIT_REQUEST_ERROR: (state, action) => {
39 | return {
40 | ...state,
41 | error: action.error,
42 | };
43 | },
44 | CLEAR_SUBMIT_REQUEST_ERROR: (state) => {
45 | return {
46 | ...state,
47 | error: undefined,
48 | };
49 | },
50 | GOT_REQUEST_DETAILS: (state, action) => {
51 | return {
52 | ...state,
53 | details: action.details,
54 | };
55 | },
56 | CLEAR_REQUEST_DETAILS: (state) => {
57 | return {
58 | ...state,
59 | details: undefined,
60 | };
61 | },
62 | });
63 |
--------------------------------------------------------------------------------
/checkpoint/frontend/src/redux/reducers/index.ts:
--------------------------------------------------------------------------------
1 | import { combineReducers } from 'redux';
2 | import { appReducer } from './appReducer';
3 |
4 | export const rootReducer = combineReducers({
5 | app: appReducer,
6 | });
7 |
--------------------------------------------------------------------------------
/checkpoint/frontend/src/redux/state.ts:
--------------------------------------------------------------------------------
1 | import { PromoteRequest, Model, ModelVersion, RequestDetails } from '../utils/types';
2 |
3 | export interface AppState {
4 | app: RootState;
5 | }
6 |
7 | export interface RootState {
8 | requests?: PromoteRequest[];
9 | models?: Model[];
10 | stages?: string[];
11 | versions?: ModelVersion[];
12 | error?: string;
13 | details?: RequestDetails;
14 | }
15 |
--------------------------------------------------------------------------------
/checkpoint/frontend/src/redux/store.ts:
--------------------------------------------------------------------------------
1 | import { Store, applyMiddleware, compose, createStore } from 'redux';
2 | import thunk from 'redux-thunk';
3 | import { rootReducer } from './reducers';
4 |
5 | declare global {
6 | interface Window {
7 | __REDUX_DEVTOOLS_EXTENSION_COMPOSE__?: typeof compose;
8 | }
9 | }
10 |
11 | const composeEnhancers = window.__REDUX_DEVTOOLS_EXTENSION_COMPOSE__ || compose;
12 |
13 | export const store: Store = createStore(rootReducer, composeEnhancers(applyMiddleware(thunk)));
14 |
--------------------------------------------------------------------------------
/checkpoint/frontend/src/setupTests.ts:
--------------------------------------------------------------------------------
1 | // jest-dom adds custom jest matchers for asserting on DOM nodes.
2 | // allows you to do things like:
3 | // expect(element).toHaveTextContent(/react/i)
4 | // learn more: https://github.com/testing-library/jest-dom
5 | import '@testing-library/jest-dom';
6 |
--------------------------------------------------------------------------------
/checkpoint/frontend/src/utils/index.ts:
--------------------------------------------------------------------------------
1 | export * from './types';
2 |
--------------------------------------------------------------------------------
/checkpoint/frontend/src/utils/types.ts:
--------------------------------------------------------------------------------
1 | export interface UserInfo {
2 | user: string;
3 | email: string;
4 | }
5 |
6 | export interface PromoteRequest {
7 | title: string;
8 | description?: string;
9 | model_name: string;
10 | model_version: string;
11 | target_stage: string;
12 | current_stage: string;
13 | id: number;
14 | author_username: string;
15 | reviewer_username?: string;
16 | review_comment?: string;
17 | status: 'open' | 'closed' | 'approved';
18 | closed_at_epoch: number;
19 | created_at_epoch: number;
20 | }
21 |
22 | export interface CreatePromoteRequest {
23 | title: string;
24 | description?: string;
25 | model_name: string;
26 | version_id: string;
27 | target_stage: string;
28 | }
29 |
30 | export interface Model {
31 | name: string;
32 | }
33 |
34 | export interface ModelVersion {
35 | model_name: string;
36 | id: string;
37 | }
38 |
39 | export interface RequestDetails {
40 | promote_request_id: number;
41 | challenger_version_details: VersionDetails;
42 | champion_version_details?: VersionDetails;
43 | }
44 |
45 | export interface VersionDetails {
46 | id: string;
47 | stage: string;
48 | tags: Record;
49 | // eslint-disable-next-line
50 | parameters: Record;
51 | metrics: Record;
52 | }
53 |
54 | export interface CreateReview {
55 | status: 'closed' | 'approved';
56 | review_comment: string;
57 | }
58 |
--------------------------------------------------------------------------------
/checkpoint/frontend/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "baseUrl": "./src",
4 | "target": "es5",
5 | "lib": [
6 | "dom",
7 | "dom.iterable",
8 | "esnext"
9 | ],
10 | "allowJs": true,
11 | "skipLibCheck": true,
12 | "esModuleInterop": true,
13 | "allowSyntheticDefaultImports": true,
14 | "strict": true,
15 | "forceConsistentCasingInFileNames": true,
16 | "noFallthroughCasesInSwitch": true,
17 | "module": "esnext",
18 | "moduleResolution": "node",
19 | "resolveJsonModule": true,
20 | "isolatedModules": true,
21 | "noEmit": true,
22 | "jsx": "react-jsx"
23 | },
24 | "include": [
25 | "src"
26 | ],
27 | "exclude": ["**/*.stories.*",]
28 | }
29 |
--------------------------------------------------------------------------------
/checkpoint/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 79
3 | target-version = ['py38']
4 |
--------------------------------------------------------------------------------
/checkpoint/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | black==20.8b1
2 | flake8==3.8.4
3 | mypy==0.800
4 | pre-commit==2.14.0
5 | pytest==6.2.2
6 |
--------------------------------------------------------------------------------
/checkpoint/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | setuptools.setup(
4 | name="checkpoint",
5 | version="0.1.0",
6 | author="Kevin Flansburg, Josh Broomberg",
7 | author_email="kevin.flansburg@dominodatalab.com, josh.broomberg@dominodatalab.com",
8 | description="Model approval layer for your model registry.",
9 | url="https://github.com/dominodatalab/domino-research/checkpoint",
10 | packages=setuptools.find_packages(),
11 | install_requires=[
12 | "beautifulsoup4==4.9",
13 | "Flask==2.0",
14 | "flask-sqlalchemy==2.5.1",
15 | "mixpanel==4.9.0",
16 | "mlflow==1.19",
17 | "requests==2.26",
18 | "psycopg2-binary==2.9",
19 | ],
20 | entry_points={"console_scripts": ["checkpoint = checkpoint.app:main"]},
21 | )
22 |
--------------------------------------------------------------------------------
/checkpoint/tests/test_app.py:
--------------------------------------------------------------------------------
1 | def test_hello_world():
2 | assert True
3 |
--------------------------------------------------------------------------------
/flare/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # pytype static type analyzer
135 | .pytype/
136 |
137 | # Cython debug symbols
138 | cython_debug/
139 |
140 | .DS_Store
141 | .brdg-models
142 |
143 | *.json
144 |
--------------------------------------------------------------------------------
/flare/examples/.gitignore:
--------------------------------------------------------------------------------
1 | *.json
2 | *.joblib
3 |
--------------------------------------------------------------------------------
/flare/examples/README.md:
--------------------------------------------------------------------------------
1 | # Flare Example
2 |
3 | This folder contains a small example of Flare's usage.
4 |
5 | ## Training
6 |
7 | In `train.py`, a simple linear regression model is trained
8 | on the wine quality dataset. Before training, Flare's
9 | `baseline` method is invoked using the training dataframe.
10 | This creates our baseline files (`statistics.json` and
11 | `constraints.json`). Finally, the training script saves
12 | the trained model for later use.
13 |
14 | ## Inference
15 |
16 | In `infer.py`, an example of using Flare at inference time
17 | is shown. First, the Slack alert target is configured, as
18 | shown in the [quickstart](/flare#configure-alert-target).
19 | Next, the model is invoked twice using the Flare context.
20 | The second invocation purposefully introduces an outlier, which
21 | should trigger a Flare alert.
22 |
--------------------------------------------------------------------------------
/flare/examples/infer.py:
--------------------------------------------------------------------------------
1 | from flare.alerting import SlackAlertTarget
2 | from flare.runtime import Flare
3 | import pandas as pd
4 | from joblib import load
5 | import logging
6 |
7 | logging.basicConfig(level=logging.DEBUG)
8 |
9 | # Create alert client
10 | alert_target = SlackAlertTarget("/XXXXX/XXXXXX/XXXXXXXXXXXXXXXXXXXX")
11 |
12 | # Load sample data
13 | x = pd.read_csv("winequality-red.csv", sep=";").head(1)
14 | del x["quality"]
15 |
16 | # Load model
17 | model = load("model.joblib")
18 |
19 | # Valid Data; No Alerts
20 | with Flare("wine-quality", x, alert_target):
21 | output = model.predict(x)
22 |
23 |
24 | # Insert invalid (below minimum bound) value
25 | x["fixed acidity"] = 3.0
26 |
27 | # Generates an error notification
28 | with Flare("wine-quality", x, alert_target):
29 | output = model.predict(x)
30 |
--------------------------------------------------------------------------------
/flare/examples/train.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from flare.generators import baseline
3 | from sklearn.linear_model import LinearRegression
4 | from joblib import dump
5 |
6 | df = pd.read_csv("winequality-red.csv", sep=";")
7 |
8 | y = df["quality"]
9 | X = df.copy()
10 | del X["quality"]
11 |
12 | # Create Flare Baseline
13 | baseline(X)
14 |
15 | # Train Model
16 | model = LinearRegression()
17 | model.fit(X, y)
18 |
19 | dump(model, "model.joblib")
20 |
--------------------------------------------------------------------------------
/flare/flare/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dominodatalab/domino-research/c9f6285987bfd9337312a4d52ea497a39ff6e381/flare/flare/__init__.py
--------------------------------------------------------------------------------
/flare/flare/alerting.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from dataclasses import dataclass, asdict
3 | from abc import ABC, abstractmethod
4 | import requests # type: ignore
5 | from typing import Dict, List, Optional, Any
6 | import logging
7 |
8 |
9 | logger = logging.getLogger("flare")
10 |
11 |
12 | class FeatureAlertKind(Enum):
13 | # Sample was more than x standard deviations from mean
14 | OUTLIER = "Outlier"
15 | # Sample was outside of lower or upper bounds
16 | BOUND = "Bound"
17 | # Sample was not the correct data type
18 | TYPE = "Type"
19 | # Sample was null and feature is non-nullable
20 | NULL = "Null"
21 | # Sample was negative and feature was non-negative
22 | NEGATIVE = "Negative"
23 | # Sample was not a valid variant of a categorical feature
24 | CATEGORICAL = "Categorical"
25 |
26 |
27 | @dataclass
28 | class FeatureAlert:
29 | # Feature Name
30 | name: str
31 | # See FeatureAlertKind.value
32 | kind: str
33 |
34 |
35 | @dataclass
36 | class InferenceException:
37 | message: str
38 | traceback: str
39 |
40 |
41 | @dataclass
42 | class Alert:
43 | model_name: str
44 | features: List[FeatureAlert]
45 | exception: Optional[InferenceException]
46 |
47 |
48 | class AlertWebhookTarget(ABC):
49 | @abstractmethod
50 | def _alert_webhook_url(self) -> str:
51 | pass
52 |
53 | @abstractmethod
54 | def _format_alert(self, alert: Alert) -> Dict[Any, Any]:
55 | pass
56 |
57 | def send_alert(self, alert: Alert):
58 | if (alert.exception is None) and (len(alert.features) == 0):
59 | logger.error("Alert has no exception/feature alerts. Not sending")
60 | return
61 |
62 | formatted_alert = self._format_alert(alert)
63 | logger.debug(formatted_alert)
64 |
65 | resp = requests.post(self._alert_webhook_url(), json=formatted_alert)
66 |
67 | if resp.ok:
68 | logger.info(f"Sent alert to {type(self).__name__}")
69 | else:
70 | logger.error(
71 | f"Failed to send alert to {type(self).__name__}. "
72 | + f"Code: {resp.status_code}. Body: {resp.text}"
73 | )
74 |
75 |
76 | class SlackAlertTarget(AlertWebhookTarget):
77 | def __init__(self, slack_webhook_path: str):
78 | # Everything after https://hooks.slack.com/services
79 | # FORMAT: /XXXXX/XXXXXX/XXXXXXXXXXXXXXXXXXXX
80 | self.slack_webhook_path = slack_webhook_path
81 |
82 | def _alert_webhook_url(self) -> str:
83 | return f"https://hooks.slack.com/services{self.slack_webhook_path}"
84 |
85 | def _format_alert(self, alert: Alert) -> Dict[str, Any]:
86 | msg_structure: Dict[str, Any] = {
87 | "blocks": [
88 | {
89 | "type": "section",
90 | "text": {
91 | "type": "mrkdwn",
92 | "text": f"Inference in model *{alert.model_name}* "
93 | + "triggered Flare :sparkler: alerts:",
94 | },
95 | },
96 | ]
97 | }
98 |
99 | if alert.exception is not None:
100 | msg_structure["blocks"].extend(
101 | [
102 | {
103 | "type": "header",
104 | "text": {
105 | "type": "plain_text",
106 | "text": "Runtime Exception",
107 | },
108 | },
109 | {
110 | "type": "section",
111 | "text": {
112 | "type": "mrkdwn",
113 | "text": f"*Message:* {alert.exception.message}\n"
114 | + "*Full trace:*\n\n"
115 | + f"{alert.exception.traceback[:20*1000]}",
116 | },
117 | },
118 | {"type": "divider"},
119 | ]
120 | )
121 |
122 | if (num_alerts := len(alert.features)) > 0:
123 | msg_structure["blocks"].extend(
124 | [
125 | {
126 | "type": "header",
127 | "text": {
128 | "type": "plain_text",
129 | "text": f"Input feature alerts ({num_alerts})",
130 | },
131 | },
132 | {
133 | "type": "section",
134 | "text": {
135 | "type": "mrkdwn",
136 | "text": "\n".join(
137 | [
138 | f"- Feature: {fa.name}. "
139 | + f"Alert kind: {fa.kind}"
140 | for fa in alert.features
141 | ]
142 | ),
143 | },
144 | },
145 | ]
146 | )
147 |
148 | return msg_structure
149 |
150 |
151 | class ZapierAlertTarget(AlertWebhookTarget):
152 | def __init__(self, zapier_webhook_path: str):
153 | # this is everything after https://hooks.zapier.com/hooks/catch
154 | # FORMAT: /XXXXX/XXXXXX
155 | self.zapier_webhook_path = zapier_webhook_path
156 |
157 | def _alert_webhook_url(self) -> str:
158 | return f"https://hooks.zapier.com/hooks/catch{self.zapier_webhook_path}"
159 |
160 | def _format_alert(self, alert: Alert) -> Dict[str, str]:
161 | return asdict(alert)
162 |
163 |
164 | class CustomAlertTarget(AlertWebhookTarget):
165 | def __init__(self, webhook_url: str):
166 | self.webhook_url = webhook_url
167 |
168 | def _alert_webhook_url(self) -> str:
169 | return self.webhook_url
170 |
171 | def _format_alert(self, alert: Alert) -> Dict[str, str]:
172 | return asdict(alert)
173 |
--------------------------------------------------------------------------------
/flare/flare/analytics.py:
--------------------------------------------------------------------------------
1 | import os
2 | import hashlib
3 | from mixpanel import Mixpanel # type: ignore
4 | from typing import Dict, Union
5 | from flare.constants import MIXPANEL_API_KEY
6 | import logging
7 | import uuid
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
class AnalyticsClient:
    """Thin wrapper around Mixpanel for anonymous product analytics.

    Events are sent only when an API key is configured AND the user has
    not opted out via the FLARE_ANALYTICS_OPT_OUT environment variable.
    """

    BASELINE_CREATED_EVENT_NAME = "baseline_created"

    def __init__(self):
        api_key = MIXPANEL_API_KEY

        # false if None (unset) or ""
        has_key = bool(api_key)
        opted_out = os.environ.get("FLARE_ANALYTICS_OPT_OUT") is not None

        self.analytics_enabled = has_key and not opted_out

        if self.analytics_enabled:
            logger.warning(
                "Flare Analytics ENABLED. "
                "To opt out set FLARE_ANALYTICS_OPT_OUT=1 "
                "in your environment."
            )
        else:
            logger.warning("Flare Analytics DISABLED.")
            logger.debug(
                "Analytics disabled diagnosis: "
                + f"API key present: {has_key}. "
                + f"Opt out: {opted_out}"
            )

        # The client is constructed unconditionally; _track_event gates
        # actual sends on self.analytics_enabled.
        self.mp_client = Mixpanel(api_key)
        # Random per-process identifier; hashed again before reporting.
        self.client_id = str(uuid.uuid4())

    def track_baseline_created(self):
        """Record that a baseline was generated."""
        self._track_event(
            self.client_id,
            self.BASELINE_CREATED_EVENT_NAME,
            {},
        )

    def _track_event(
        self,
        distinct_id: str,
        event_name: str,
        event_data: Dict[str, Union[str, int, float]],
    ):
        # Guard clause: log-and-skip when analytics are off.
        if not self.analytics_enabled:
            logger.debug(
                "Analytics disabled. "
                + f"Not reporting analytics event: {event_name} {event_data}."
            )
            return

        self.mp_client.track(
            self._anonymize_id(distinct_id), event_name, event_data
        )
        logger.debug(
            f"Reporting analytics event: {event_name} {event_data}"
        )

    def _anonymize_id(self, distinct_id: str) -> str:
        # One-way hash so the raw id never leaves the process.
        return hashlib.sha256(distinct_id.encode()).hexdigest()
72 |
--------------------------------------------------------------------------------
/flare/flare/constants.py:
--------------------------------------------------------------------------------
1 | MIXPANEL_API_KEY = "fa624cf5c80af9b98fb383d4c3ea8dda"
2 |
--------------------------------------------------------------------------------
/flare/flare/constraints.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List, Optional
3 |
4 |
@dataclass
class NumericalConstraints:
    """Constraints applying to numeric (Integral/Fractional) features."""

    # True when every observed baseline value was >= 0.
    is_non_negative: bool
8 |
9 |
@dataclass
class StringConstraints:
    """Constraints applying to string features."""

    # Allowed values; populated from the unique non-null values observed
    # in the baseline (see flare/generators.py).
    domains: List[str]
13 |
14 |
@dataclass
class DistributionConstraints:
    """Settings for baseline-vs-runtime distribution comparison."""

    # 'Enabled' | 'Disabled'
    perform_comparison: str = "Enabled"
    # Drift threshold; consumed by the runtime (not shown in this module).
    comparison_threshold: float = 0.1
    # 'Simple' | 'Robust'
    comparison_method: str = "Simple"
22 |
23 |
@dataclass
class MonitoringConfig:
    """Top-level monitoring switches and thresholds."""

    distribution_constraints: DistributionConstraints
    # 'Enabled' | 'Disabled'
    evaluate_constraints: str = "Enabled"
    # 'Enabled' | 'Disabled'
    emit_metrics: str = "Enabled"
    # Tolerated fraction of datatype violations before alerting
    # — presumably; verify against the runtime. TODO confirm.
    datatype_check_threshold: float = 0.1
    domain_content_threshold: float = 0.1
33 |
34 |
@dataclass
class Feature:
    """Per-column constraints entry in the constraints document."""

    name: str
    inferred_type: str
    # denotes observed non-null value percentage
    completeness: float
    # At most one of the type-specific constraint slots is populated
    # (see generators._create_constraints_feature).
    num_constraints: Optional[NumericalConstraints] = None
    string_constraints: Optional[StringConstraints] = None
    # NOTE(review): camelCase is inconsistent with the snake_case
    # siblings, but renaming would change the serialized
    # constraints.json schema — left as-is.
    monitoringConfigOverrides: Optional[MonitoringConfig] = None
44 |
45 |
@dataclass
class Constraints:
    """Top-level constraints document (schema `version` 0)."""

    features: List[Feature]
    monitoring_config: MonitoringConfig
    version: int = 0
51 |
--------------------------------------------------------------------------------
/flare/flare/examples.py:
--------------------------------------------------------------------------------
1 | import pandas as pd # type: ignore
2 | import numpy as np # type: ignore
3 |
EXAMPLE_DF_N_COLS_PER_TYPE = 2
EXAMPLE_DF_N_PER_COL = 1000
EXAMPLE_DF_MIN_INT = -100
EXAMPLE_DF_MAX_INT = 100
EXAMPLE_DF_STRING_DOMAINS = ["eeny", "meeny", "miny", "moe"]

EXAMPLE_DF_NULL_PERCENTS = {"int_0": 0.1, "float_0": 0.2, "str_0": 0.3}


def generate_example_dataframe() -> pd.DataFrame:
    """Build a synthetic DataFrame covering the feature types Flare infers.

    Produces EXAMPLE_DF_N_COLS_PER_TYPE columns each of: signed ints,
    non-negative ints, floats, strings drawn from
    EXAMPLE_DF_STRING_DOMAINS, and mixed-type objects. Selected columns
    then have a fraction of rows nulled out per EXAMPLE_DF_NULL_PERCENTS.
    """
    shape = (EXAMPLE_DF_N_PER_COL, EXAMPLE_DF_N_COLS_PER_TYPE)

    def names(prefix: str):
        # Column names per type family, e.g. int_0, int_1.
        return [f"{prefix}_{n}" for n in range(EXAMPLE_DF_N_COLS_PER_TYPE)]

    # NOTE: parts are built in the same order as the original draws so
    # the np.random stream is consumed identically.
    parts = [
        pd.DataFrame(
            np.random.randint(
                EXAMPLE_DF_MIN_INT, EXAMPLE_DF_MAX_INT, size=shape
            ),
            columns=names("int"),
        ),
        pd.DataFrame(
            np.random.randint(0, EXAMPLE_DF_MAX_INT, size=shape),
            columns=names("positive_int"),
        ),
        pd.DataFrame(
            np.random.random(size=shape),
            columns=names("float"),
        ),
        pd.DataFrame(
            np.random.choice(EXAMPLE_DF_STRING_DOMAINS, size=shape),
            columns=names("str"),
        ),
        pd.DataFrame(
            np.random.choice(
                np.array(
                    [EXAMPLE_DF_STRING_DOMAINS[0], 1, 2.0], dtype="object"
                ),
                size=shape,
            ),
            columns=names("mixed"),
        ),
    ]

    full_df = pd.concat(parts, axis=1)

    # Null out a random sample of rows in the configured columns.
    for target_col, target_frac in EXAMPLE_DF_NULL_PERCENTS.items():
        indices = np.random.choice(
            range(EXAMPLE_DF_N_PER_COL),
            size=int(EXAMPLE_DF_N_PER_COL * target_frac),
            replace=False,
        )
        full_df.loc[indices, target_col] = None

    return full_df
71 |
--------------------------------------------------------------------------------
/flare/flare/generators.py:
--------------------------------------------------------------------------------
1 | from flare.constraints import (
2 | Constraints,
3 | NumericalConstraints,
4 | StringConstraints,
5 | MonitoringConfig,
6 | DistributionConstraints,
7 | )
8 | from flare.constraints import Feature as ConstraintFeature
9 | from flare.statistics import (
10 | Statistics,
11 | Dataset,
12 | NumericalStatistics,
13 | StringStatistics,
14 | CommonStatistics,
15 | )
16 | from flare.statistics import Feature as StatisticsFeature
17 | from flare.analytics import AnalyticsClient
18 |
19 | from flare.types import FeatureType
20 | import pandas as pd # type: ignore
21 | from dataclasses import asdict
22 | import json
23 | import numpy as np
24 |
25 | MAX_UNIQUES_THRESHOLD = 20
26 | MAX_ROWS_FOR_OBJECT_TYPE_INFERENCE = 10 ** 5
27 |
28 | analytics = AnalyticsClient()
29 |
30 |
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that downcasts NumPy scalars/arrays to Python natives."""

    def default(self, obj):
        # Each NumPy family maps to its closest built-in representation;
        # anything unrecognized falls through to the base implementation
        # (which raises TypeError).
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        return super().default(obj)
40 |
41 |
def baseline(df: pd.DataFrame):
    """Generate and persist baseline artifacts for `df`.

    Writes constraints.json and statistics.json to the current working
    directory, then records an (anonymous, opt-out) analytics event.
    """
    artifacts = {
        "constraints.json": gen_constraints(df),
        "statistics.json": gen_statistics(df),
    }

    for path, artifact in artifacts.items():
        with open(path, "w") as f:
            json.dump(asdict(artifact), f, cls=NumpyEncoder)

    analytics.track_baseline_created()
53 |
54 |
def gen_statistics(df: pd.DataFrame) -> Statistics:
    """Compute per-feature summary statistics for every column of `df`.

    Returns a Statistics document with the dataset row count plus one
    feature entry per column.
    """
    # FIX: `iteritems()` was deprecated in pandas 1.5 and removed in 2.0;
    # `items()` is the equivalent and exists in the pinned 1.3.x too.
    statistics = Statistics(
        dataset=Dataset(len(df)),
        features=[
            _create_statistics_feature(feature_series)
            for _, feature_series in df.items()
        ],
    )

    return statistics
65 |
66 |
def gen_constraints(df: pd.DataFrame) -> Constraints:
    """Infer per-feature constraints for every column of `df`.

    Returns a Constraints document with one entry per column and the
    default monitoring configuration.
    """
    # FIX: `iteritems()` was deprecated in pandas 1.5 and removed in 2.0;
    # `items()` is the equivalent and exists in the pinned 1.3.x too.
    features = [
        _create_constraints_feature(feature_series)
        for _, feature_series in df.items()
    ]

    monitoring_config = MonitoringConfig(DistributionConstraints())

    constraints = Constraints(features, monitoring_config)

    return constraints
78 |
79 |
def _create_statistics_feature(feature_series: pd.Series) -> StatisticsFeature:
    """Build the statistics entry for a single column."""
    inferred = _infer_feature_type(feature_series)

    feature = StatisticsFeature(
        name=feature_series.name, inferred_type=inferred.value
    )

    # NOTE (Josh): this is duplicative of the completeness
    # constraint and also seems internally redundant?
    missing = feature_series.isna().sum()
    common = CommonStatistics(len(feature_series) - missing, missing)

    if inferred in {FeatureType.INTEGRAL, FeatureType.FRACTIONAL}:
        feature.numerical_statistics = NumericalStatistics(
            common=common,
            mean=feature_series.mean(),
            sum=feature_series.sum(),
            std_dev=feature_series.std(),
            min=feature_series.min(),
            max=feature_series.max(),
        )
    elif inferred == FeatureType.STRING:
        feature.string_statistics = StringStatistics(
            common=common,
            distinct_count=len(feature_series.dropna().unique()),
        )

    # UNKNOWN features carry only name/type — no statistics slots.
    return feature
111 |
112 |
def _create_constraints_feature(
    feature_series: pd.Series,
) -> ConstraintFeature:
    """Build the constraints entry for a single column."""
    # 1. Infer type
    inferred = _infer_feature_type(feature_series)

    # 2. Measure completeness (observed non-null fraction)
    missing = feature_series.isna().sum()
    completeness = 1 - (missing / len(feature_series))

    feature = ConstraintFeature(
        name=feature_series.name,
        inferred_type=inferred.value,
        completeness=completeness,
    )

    # 3. Enrich with type-specific constraints
    if inferred in {FeatureType.INTEGRAL, FeatureType.FRACTIONAL}:
        feature.num_constraints = NumericalConstraints(
            is_non_negative=bool(feature_series.min() >= 0)
        )
    elif inferred == FeatureType.STRING:
        uniques = feature_series.dropna().unique()
        # Only emit a domain when cardinality is small enough to act as
        # a meaningful categorical constraint.
        if len(uniques) <= MAX_UNIQUES_THRESHOLD:
            feature.string_constraints = StringConstraints(
                domains=list(uniques)
            )

    return feature
145 |
146 |
def _infer_feature_type(feature_series: pd.Series) -> FeatureType:
    """Map a pandas dtype (plus light content sniffing) to a FeatureType."""
    dtype_name = str(feature_series.dtype)

    # {"int8", "int16", "int32", "int64", "intp"}
    # {"uint8", "uint16", "uint32", "uint64", "uintp"}
    if dtype_name.startswith("int") or dtype_name.startswith("uint"):
        feature_type = FeatureType.INTEGRAL

    # {"float16", "float32", "float64", "float96", "float128"}:
    elif dtype_name.startswith("float"):
        feature_type = FeatureType.FRACTIONAL

    # {"string", "U16/32/...", "=U16/32/..."}
    # BUG FIX: the old check compared dtype_name[:2] against {"U", "=U"};
    # for "U16" the first two characters are "U1", so plain "U…" dtypes
    # never matched. Use prefix checks instead.
    elif (
        dtype_name == "string"
        or dtype_name.startswith("U")
        or dtype_name.startswith("=U")
    ):
        feature_type = FeatureType.STRING

    elif dtype_name == "object":
        # Assume unknown type as object
        # dtype is assigned to mixed type
        # data
        feature_type = FeatureType.UNKNOWN

        # If dataset is small-ish, attempt to infer
        # if object dtype is actually strings
        if len(feature_series) <= MAX_ROWS_FOR_OBJECT_TYPE_INFERENCE:
            types = set(map(type, feature_series.dropna()))
            if types == {str}:
                feature_type = FeatureType.STRING

    else:
        # Bools, datetimes, etc are all treated as unknown
        feature_type = FeatureType.UNKNOWN

    return feature_type
181 |
--------------------------------------------------------------------------------
/flare/flare/statistics.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import List, Optional
3 |
4 |
@dataclass
class KLLSketchParameters:
    """Tuning parameters of a KLL quantile sketch."""

    c: float
    k: float
9 |
10 |
@dataclass
class KLLBucket:
    """One histogram bucket derived from a KLL sketch."""

    lower_bound: float
    upper_bound: float
    count: int
16 |
17 |
@dataclass
class KLLSketch:
    """Raw KLL sketch payload: parameters plus compactor data."""

    parameters: KLLSketchParameters
    data: List[List[float]]
22 |
23 |
@dataclass
class CategoryBucket:
    """Observation count for one categorical value."""

    value: str
    count: int
28 |
29 |
@dataclass
class CategoricalDistribution:
    """Distribution of a categorical feature as per-value counts."""

    buckets: List[CategoryBucket]
33 |
34 |
@dataclass
class KLLDistribution:
    """Numerical distribution: derived buckets plus the source sketch."""

    buckets: List[KLLBucket]
    sketch: KLLSketch
39 |
40 |
@dataclass
class StringDistribution:
    """Distribution wrapper for string features."""

    categorical: CategoricalDistribution
44 |
45 |
@dataclass
class NumericalDistribution:
    """Distribution wrapper for numerical features."""

    kll: KLLDistribution
49 |
50 |
@dataclass
class CommonStatistics:
    """Null/non-null counts shared by all feature statistics."""

    num_present: int
    num_missing: int
55 |
56 |
@dataclass
class NumericalStatistics:
    """Summary statistics for an Integral/Fractional feature."""

    common: CommonStatistics
    mean: float
    sum: float
    std_dev: float
    min: float
    max: float

    # TODO: make this non-optional when we
    # decide to tackle sketches
    distribution: Optional[NumericalDistribution] = None
69 |
70 |
@dataclass
class StringStatistics:
    """Summary statistics for a String feature."""

    common: CommonStatistics
    distinct_count: int

    # TODO: make this non-optional when we
    # decide to tackle string distros
    distribution: Optional[StringDistribution] = None
79 |
80 |
@dataclass
class Feature:
    """Per-column statistics entry in the statistics document."""

    name: str
    inferred_type: str
    # At most one of the type-specific slots is populated
    # (see generators._create_statistics_feature).
    numerical_statistics: Optional[NumericalStatistics] = None
    string_statistics: Optional[StringStatistics] = None
87 |
88 |
@dataclass
class Dataset:
    """Dataset-level statistics (row count)."""

    item_count: int
92 |
93 |
@dataclass
class Statistics:
    """Top-level statistics document (schema `version` 0)."""

    dataset: Dataset
    version: int = 0
    features: List[Feature] = field(default_factory=list)
99 |
--------------------------------------------------------------------------------
/flare/flare/types.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 |
class FeatureType(Enum):
    """Inferred logical type of a dataset column/feature.

    Values are the strings serialized into constraints/statistics JSON.
    """

    FRACTIONAL = "Fractional"
    INTEGRAL = "Integral"
    STRING = "String"
    UNKNOWN = "Unknown"  # bools, datetimes, mixed-type object columns
9 |
--------------------------------------------------------------------------------
/flare/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 79
3 | target-version = ['py38']
4 |
--------------------------------------------------------------------------------
/flare/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | black==20.8b1
2 | flake8==3.8.4
3 | mypy==0.800
4 | pre-commit==2.14.0
5 | pytest==6.2.2
6 | requests==2.26.0
7 |
--------------------------------------------------------------------------------
/flare/setup.py:
--------------------------------------------------------------------------------
import setuptools
import os

setuptools.setup(
    name="domino-flare",
    # RELEASE_VERSION is injected by CI; local builds fall back to SNAPSHOT.
    version=os.environ.get("RELEASE_VERSION", "SNAPSHOT"),
    author="Kevin Flansburg, Josh Broomberg",
    author_email="kevin.flansburg@dominodatalab.com,josh.broomberg@dominodatalab.com",
    description="Lightweight model monitoring framework.",
    url="https://github.com/dominodatalab/domino-research/flare",
    packages=setuptools.find_packages(),
    # Pinned exactly; keep in sync with flare's import surface.
    install_requires=[
        "dacite==1.6",
        "mixpanel==4.9.0",
        "numpy==1.21.1",
        "pandas==1.3.1",
        "requests==2.26.0",
    ],
)
20 |
--------------------------------------------------------------------------------
/flare/tests/test_app.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | from flare.examples import (
3 | generate_example_dataframe,
4 | EXAMPLE_DF_NULL_PERCENTS,
5 | EXAMPLE_DF_STRING_DOMAINS,
6 | )
7 |
8 | from flare.types import FeatureType
9 | from flare.generators import gen_constraints
10 |
11 | from flare.constraints import Constraints
12 | from flare.constraints import Feature as ConstraintsFeature
13 |
14 |
def test_type_inference():
    """End-to-end check of dtype -> FeatureType inference via constraints."""
    test_df = generate_example_dataframe()
    constraints = gen_constraints(test_df)

    def inferred(feature_name):
        # Shorthand: inferred type string for a named feature.
        return fetch_constrain_feature(feature_name, constraints).inferred_type

    # Integer column with no missing values is detected as Integral
    assert test_df["int_1"].isna().sum() == 0
    assert inferred("int_1") == FeatureType.INTEGRAL.value

    # Integer column with missing values is detected as Fractional
    assert test_df["int_0"].isna().sum() > 0
    assert inferred("int_0") == FeatureType.FRACTIONAL.value

    # Fractional column with missing values is detected as Fractional
    assert test_df["float_0"].isna().sum() > 0
    assert inferred("float_0") == FeatureType.FRACTIONAL.value

    # Mixed-type column is detected as unknown
    assert len(set(map(type, test_df["mixed_0"].dropna()))) > 1
    assert inferred("mixed_0") == FeatureType.UNKNOWN.value

    # String column with missing is detected as String
    assert test_df["str_0"].isna().sum() > 0
    assert inferred("str_0") == FeatureType.STRING.value
46 |
47 |
def test_completeness_inference():
    """Inferred completeness should match the seeded null fraction."""
    test_df = generate_example_dataframe()
    constraints = gen_constraints(test_df)

    float_0_feature = fetch_constrain_feature("float_0", constraints)

    expected = 1 - EXAMPLE_DF_NULL_PERCENTS["float_0"]
    # BUG FIX: the original asserted `completeness - expected <= 0.1`,
    # which is vacuously true whenever completeness is LOWER than
    # expected (the difference goes negative). Compare absolute error.
    assert (
        abs(float_0_feature.completeness - expected)
        <= 0.1  # allow some room for rounding
    )
59 |
60 |
def test_string_constraints_inference():
    """Inferred string domains should equal the seeded domain values."""
    test_df = generate_example_dataframe()
    constraints = gen_constraints(test_df)

    feature = fetch_constrain_feature("str_0", constraints)
    observed_domains = set(feature.string_constraints.domains)

    assert observed_domains == set(EXAMPLE_DF_STRING_DOMAINS)
70 |
71 |
def test_numerical_constraints():
    """The non-negativity flag should track the observed minimum."""
    test_df = generate_example_dataframe()
    constraints = gen_constraints(test_df)

    signed = fetch_constrain_feature("int_0", constraints)
    unsigned = fetch_constrain_feature("positive_int_0", constraints)

    # int_0 spans negative values -> flag must be False.
    assert test_df["int_0"].dropna().min() < 0
    assert not signed.num_constraints.is_non_negative

    # positive_int_0 is sampled from [0, MAX) -> flag must be True.
    assert test_df["positive_int_0"].dropna().min() >= 0
    assert unsigned.num_constraints.is_non_negative
84 |
85 |
def fetch_constrain_feature(
    feature_name: str, constraints: Constraints
) -> Optional[ConstraintsFeature]:
    """Look up a constraints feature by name; None when absent.

    BUG FIX: `next()` previously had no default, so a missing feature
    raised StopIteration instead of returning None as the Optional
    return type promises.
    """
    feature: Optional[ConstraintsFeature] = next(
        filter(lambda x: x.name == feature_name, constraints.features),
        None,
    )

    return feature
95 |
--------------------------------------------------------------------------------
/flare/tests/test_runtime.py:
--------------------------------------------------------------------------------
1 | from flare.runtime import Flare
2 | import tempfile
3 | import json
4 | import pytest
5 | from dataclasses import asdict
6 | import os
7 | import pandas as pd # type: ignore
8 | from flare.runtime import FLARE_STATISTICS_PATH_VAR
9 | import logging
10 | from flare.alerting import FeatureAlert, CustomAlertTarget
11 |
12 |
def generate_statistics():
    """Build a small in-memory baseline Statistics document.

    Three single-row features: a float bounded [0.0, 2.0], an int bounded
    [0.0, 3.0], and a string with one distinct value. Used by test_bound
    to provoke out-of-bounds alerts.
    """
    # Imported locally so merely importing this test module stays cheap.
    from flare.statistics import (
        Dataset,
        StringStatistics,
        CommonStatistics,
        StringDistribution,
        CategoricalDistribution,
        NumericalStatistics,
        NumericalDistribution,
        KLLDistribution,
        KLLSketch,
        KLLSketchParameters,
        Statistics,
        Feature as FeatureStatistics,
    )

    # Baseline range [0.0, 2.0]; sketch left empty (not exercised here).
    float_feature = FeatureStatistics(
        name="float",
        inferred_type="Fractional",
        numerical_statistics=NumericalStatistics(
            common=CommonStatistics(num_present=1, num_missing=0),
            mean=1.0,
            sum=1.0,
            min=0.0,
            max=2.0,
            std_dev=1.0,
            distribution=NumericalDistribution(
                kll=KLLDistribution(
                    buckets=[],
                    sketch=KLLSketch(
                        parameters=KLLSketchParameters(
                            c=0.0,
                            k=0.0,
                        ),
                        data=[],
                    ),
                )
            ),
        ),
        string_statistics=None,
    )
    # Baseline range [0.0, 3.0].
    int_feature = FeatureStatistics(
        name="int",
        inferred_type="Integral",
        numerical_statistics=NumericalStatistics(
            common=CommonStatistics(num_present=1, num_missing=0),
            mean=1.0,
            sum=1.0,
            min=0.0,
            max=3.0,
            std_dev=1.0,
            distribution=NumericalDistribution(
                kll=KLLDistribution(
                    buckets=[],
                    sketch=KLLSketch(
                        parameters=KLLSketchParameters(
                            c=0.0,
                            k=0.0,
                        ),
                        data=[],
                    ),
                )
            ),
        ),
        string_statistics=None,
    )
    string_feature = FeatureStatistics(
        name="string",
        inferred_type="String",
        numerical_statistics=None,
        string_statistics=StringStatistics(
            distinct_count=1,
            distribution=StringDistribution(
                categorical=CategoricalDistribution(buckets=[])
            ),
            common=CommonStatistics(num_present=1, num_missing=0),
        ),
    )

    stats = Statistics(
        version=0,
        dataset=Dataset(item_count=1),
        features=[float_feature, int_feature, string_feature],
    )

    return stats
99 |
100 |
@pytest.fixture()
def statistics():
    # Serialize the synthetic baseline to a temp JSON file and yield its
    # path; the file is deleted when the context exits after the test.
    with tempfile.NamedTemporaryFile(suffix=".json", mode="w") as f:
        json.dump(asdict(generate_statistics()), f)
        f.flush()  # ensure the reader sees the full payload
        yield f.name
107 |
108 |
def test_bound(statistics):
    # Point the runtime at the baseline statistics file from the fixture.
    os.environ[FLARE_STATISTICS_PATH_VAR] = statistics
    # NOTE(review): "TRACE" is not a stdlib logging level; this relies on
    # something (presumably flare.runtime) registering it via
    # logging.addLevelName — otherwise basicConfig raises. Confirm.
    level = logging.getLevelName("TRACE")
    logging.basicConfig(level=level)
    # float=-1.0 is below the baseline min (0.0) and int=4 is above the
    # baseline max (3.0); the string value is in-range.
    x = pd.DataFrame([[-1.0, 4, "3"]], columns=["float", "int", "string"])
    target = CustomAlertTarget("")
    session = Flare("test-model", x, target)
    assert session.feature_alerts == [
        FeatureAlert(name="float", kind="Bound"),
        FeatureAlert(name="int", kind="Bound"),
    ]
120 |
--------------------------------------------------------------------------------
/guides/mlflow/.gitignore:
--------------------------------------------------------------------------------
1 | mlflowArtifactData
2 | mlflowDBData
3 | checkpoint.db
--------------------------------------------------------------------------------
/guides/mlflow/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM continuumio/miniconda3:latest
2 | RUN pip install mlflow boto3 pymysql
3 |
--------------------------------------------------------------------------------
/guides/mlflow/README.md:
--------------------------------------------------------------------------------
1 | ## MLflow Quickstart
2 |
3 | This is a quick guide to running MLflow locally. A full MLflow
4 | installation consists of 3 components:
5 |
6 | * MLflow tracking server / model registry
7 | * Database backend - stores run and model metadata
8 | * Storage backend - stores run and model artifacts
9 |
10 | This guide uses `docker-compose` to run MLflow, MySQL (as the database backend),
11 | and Minio as an S3-compatible artifact backend. You'll need Docker and `docker-compose`
12 | installed on your machine.
13 |
14 | ### 0. Clone this repo
15 |
16 | Clone the repo and change to the `guides/mlflow` directory.
17 |
18 | HTTPS: `git clone https://github.com/dominodatalab/domino-research.git && cd domino-research/guides/mlflow`
19 |
20 | SSH: `git clone git@github.com:dominodatalab/domino-research.git && cd domino-research/guides/mlflow`
21 |
22 | ### 1. Start MLflow
23 |
24 | Make sure you're in the `guides/mlflow` subdirectory then run the command below to start MLflow:
25 |
26 | ```bash
27 | docker-compose up -d
28 | ```
29 |
30 | MLflow will take about 15-30 seconds to start up and that's it!
31 | If you navigate to `http://localhost:5555`, you should see the MLflow UI.
You can now proceed with the quick start for the project that brought you here.
33 |
34 | Note that we pre-seed the registry with a simple demo model called `ScikitElasticnetWineModel`.
35 | It has 3 versions, with one marked for Staging and one marked for Production. The training
36 | code for this model is in `seed_models/scikit_elasticnet_wine/train.py`.
37 |
38 |
39 |
40 | ### 2. [Optional] Using the MLflow registry with your own models
41 |
42 | If you'd like to add new models using your own training code,
43 | you can use the sample configuration below. Note that any model versions
44 | used with Bridge will need a valid `MLmodel` specification (one that allows you to call
45 | `mlflow models serve -m your_model`). This file will be created for you if you log
with the `mlflow.<flavor>.log_model` SDK (e.g. `mlflow.sklearn.log_model`).
47 |
48 | ```python
49 | import os
50 | import mlflow
51 |
SERVER_URI = "http://localhost:5555"
53 | S3_ENDPOINT_URL = "http://localhost:9000"
54 |
55 | os.environ["MLFLOW_S3_ENDPOINT_URL"] = S3_ENDPOINT_URL
os.environ["MLFLOW_S3_IGNORE_TLS"] = "true"
os.environ["AWS_ACCESS_KEY_ID"] = "AKIAIfoobar"
os.environ["AWS_SECRET_ACCESS_KEY"] = "deadbeef"
59 |
60 | # For MLflow tracking:
61 | mlflow.set_tracking_uri(SERVER_URI)
62 | # ...
63 |
64 | # For the MLflow API client:
65 | MlflowClient(
66 | registry_uri=SERVER_URI, tracking_uri=SERVER_URI
67 | )
68 | # ...
69 | ```
70 |
71 | ### 3. [Optional] Stop MLflow
72 |
73 | Make sure you're in the `guides/mlflow` subdirectory then
74 | run the command below to stop MLflow:
75 |
76 | ```
77 | docker-compose down
78 | ```
79 |
80 | To wipe the artifact and metadata stored locally by MLflow, delete
81 | the `mlflowArtifactData` and `mlflowDBData` subdirectories in the `guides/mlflow` folder.
82 |
--------------------------------------------------------------------------------
/guides/mlflow/docker-compose.yaml:
--------------------------------------------------------------------------------
1 | version: '3.2'
2 | services:
3 | minio:
4 | image: minio/minio
5 | container_name: mlflow-minio
6 | environment:
7 | - "MINIO_ROOT_USER=AKIAIfoobar"
8 | - "MINIO_ROOT_PASSWORD=deadbeef"
9 | healthcheck:
10 | test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
11 | interval: 30s
12 | timeout: 20s
13 | retries: 3
14 | ports:
15 | - 9000:9000
16 | command: "server /data"
17 | volumes:
18 | - ./mlflowArtifactData:/data
19 | networks:
20 | - mlflow
21 |
22 | createbuckets:
23 | container_name: mlflow-createbuckets
24 | image: minio/mc
25 | depends_on:
26 | - minio
27 | entrypoint: >
28 | /bin/sh -c "
29 | until (/usr/bin/mc config host add myminio http://minio:9000 AKIAIfoobar deadbeef) do echo '...waiting...' && sleep 1; done;
30 | /usr/bin/mc mb --ignore-existing --region us-east-2 myminio/local-mlflow-artifacts;
31 | /usr/bin/mc policy set public myminio/local-mlflow-artifacts;
32 | exit 0;
33 | "
34 | networks:
35 | - mlflow
36 |
37 | db:
38 | restart: always
39 | image: mysql/mysql-server:5.7.28
40 | container_name: mlflow-db
41 | expose:
42 | - "3306"
43 | environment:
44 | - MYSQL_DATABASE=mlflow
45 | - MYSQL_USER=user
46 | - MYSQL_PASSWORD=pass
47 | - MYSQL_ROOT_PASSWORD=rootpass
48 | volumes:
49 | - ./mlflowDBData:/var/lib/mysql
50 | networks:
51 | - mlflow
52 |
53 | mlflow:
54 | container_name: mlflow-server
55 | image: local-mlflow
56 | build:
57 | context: .
58 | dockerfile: Dockerfile
59 | healthcheck:
60 | test: ["CMD", "curl", "-f", "http://localhost:5555/health"]
61 | interval: 2s
62 | timeout: 60s
63 | ports:
64 | - "5555:5555"
65 | environment:
66 | - AWS_ACCESS_KEY_ID=AKIAIfoobar
67 | - AWS_SECRET_ACCESS_KEY=deadbeef
68 | - MLFLOW_S3_ENDPOINT_URL=http://minio:9000
69 | - MLFLOW_S3_IGNORE_TLS=true
70 | networks:
71 | - mlflow
72 | entrypoint: mlflow server --backend-store-uri mysql+pymysql://user:pass@db:3306/mlflow --default-artifact-root s3://local-mlflow-artifacts/ -h 0.0.0.0 -p 5555
73 | depends_on:
74 | - db
75 | - createbuckets
76 |
77 | seed-mlflow:
78 | image: seed-mlflow
79 | container_name: mlflow-seed-models
80 | build:
81 | context: seed_models
82 | dockerfile: Dockerfile
83 | environment:
84 | - AWS_ACCESS_KEY_ID=AKIAIfoobar
85 | - AWS_SECRET_ACCESS_KEY=deadbeef
86 | - MLFLOW_S3_ENDPOINT_URL=http://minio:9000
87 | - MLFLOW_S3_IGNORE_TLS=true
88 | networks:
89 | - mlflow
90 | depends_on:
91 | - mlflow
92 |
93 | networks:
94 | mlflow:
95 | name: mlflow
96 | driver: bridge
97 |
--------------------------------------------------------------------------------
/guides/mlflow/seed_models/Dockerfile:
--------------------------------------------------------------------------------
1 | # Don't really need conda but this image is pulled anyway to run MLflow.
2 | # So re-using it is efficient.
3 | FROM continuumio/miniconda3:latest
4 |
5 | COPY scikit_elasticnet_wine/requirements.txt /home/scikit_elasticnet_wine/requirements.txt
6 | RUN pip install -r /home/scikit_elasticnet_wine/requirements.txt
7 |
8 | COPY . /home
9 |
10 | ENTRYPOINT ["/bin/bash", "-c"]
11 |
12 | CMD ["./home/wait-for-it.sh -h mlflow -p 5555 -s -t 60 -- python ./home/seed-if-empty.py"]
--------------------------------------------------------------------------------
/guides/mlflow/seed_models/scikit_elasticnet_wine/requirements.txt:
--------------------------------------------------------------------------------
1 | scikit-learn==0.24.2
2 | mlflow==1.20.2
3 | boto3==1.18.39
--------------------------------------------------------------------------------
/guides/mlflow/seed_models/scikit_elasticnet_wine/train.py:
--------------------------------------------------------------------------------
1 | # The data set used in this example is from http://archive.ics.uci.edu/ml/datasets/Wine+Quality
2 | # P. Cortez, A. Cerdeira, F. Almeida, T. Matos and J. Reis.
3 | # Modeling wine preferences by data mining from physicochemical properties.
4 | # In Decision Support Systems, Elsevier, 47(4):547-553, 2009.
5 |
6 | import warnings
7 | import sys
8 | import os
9 | import pandas as pd
10 | import numpy as np
11 | from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
12 | from sklearn.model_selection import train_test_split
13 | from sklearn.linear_model import ElasticNet
14 | import mlflow
15 | import mlflow.sklearn
16 |
17 | import logging
18 |
19 | logging.basicConfig(level=logging.WARN)
20 | logger = logging.getLogger(__name__)
21 |
22 |
23 | REMOTE_SERVER_URI = "http://mlflow:5555"
24 | S3_ENDPOINT_URL = "http://minio:9000"
25 |
26 | mlflow.set_tracking_uri(REMOTE_SERVER_URI)
27 | os.environ["MLFLOW_S3_ENDPOINT_URL"] = S3_ENDPOINT_URL
28 |
29 | MODEL_NAME = "ScikitElasticnetWineModel"
30 |
31 |
def eval_metrics(actual, pred):
    """Return the (rmse, mae, r2) regression metrics for predictions."""
    return (
        np.sqrt(mean_squared_error(actual, pred)),
        mean_absolute_error(actual, pred),
        r2_score(actual, pred),
    )
37 |
38 |
if __name__ == "__main__":
    warnings.filterwarnings("ignore")
    np.random.seed(40)  # deterministic splits/metrics across runs

    # Read the wine-quality csv file from disk
    CUR_DIR = os.path.dirname(os.path.abspath(__file__))
    csv_path = os.path.join(CUR_DIR, "winequality-red.csv")

    try:
        data = pd.read_csv(csv_path, sep=";")
    except Exception as e:
        logger.exception(
            "Unable to read training & test CSV. Error: %s", e
        )
        raise e

    # Split the data into training and test sets. (0.75, 0.25) split.
    train, test = train_test_split(data)

    # The predicted column is "quality" which is a scalar from [3, 9]
    train_x = train.drop(["quality"], axis=1)
    test_x = test.drop(["quality"], axis=1)
    train_y = train[["quality"]]
    test_y = test[["quality"]]

    # Hyperparameters come from argv so seed-if-empty.py can create
    # several distinct model versions; defaults are 0.5/0.5.
    alpha = float(sys.argv[1]) if len(sys.argv) > 1 else 0.5
    l1_ratio = float(sys.argv[2]) if len(sys.argv) > 2 else 0.5

    with mlflow.start_run():
        lr = ElasticNet(alpha=alpha, l1_ratio=l1_ratio, random_state=42)
        lr.fit(train_x, train_y)

        predicted_qualities = lr.predict(test_x)

        (rmse, mae, r2) = eval_metrics(test_y, predicted_qualities)

        print("Elasticnet model (alpha=%f, l1_ratio=%f):" % (alpha, l1_ratio))
        print("  RMSE: %s" % rmse)
        print("  MAE: %s" % mae)
        print("  R2: %s" % r2)

        mlflow.log_param("alpha", alpha)
        mlflow.log_param("l1_ratio", l1_ratio)
        mlflow.log_metric("rmse", rmse)
        mlflow.log_metric("r2", r2)
        mlflow.log_metric("mae", mae)


    # Register the model
    # NOTE(review): this runs AFTER the `with mlflow.start_run()` block
    # exits, so the model is not attached to the run that logged the
    # params/metrics above — confirm whether it should be indented into
    # the run context.
    mlflow.sklearn.log_model(lr, "model", registered_model_name=MODEL_NAME)
89 |
--------------------------------------------------------------------------------
/guides/mlflow/seed_models/seed-if-empty.py:
--------------------------------------------------------------------------------
from mlflow.tracking import MlflowClient
import subprocess

# Talk to the MLflow tracking server / model registry inside the compose network.
client = MlflowClient(
    registry_uri="http://mlflow:5555",
    tracking_uri="http://mlflow:5555"
)

models = client.list_registered_models()

# Seed only when the registry is empty: train three model versions with
# different regularization strengths. check=True raises on a non-zero exit so a
# failed training run surfaces immediately instead of leaving an empty registry
# (previously the results were ignored via an unused `list_files` variable).
if len(models) == 0:
    subprocess.run(["python", "/home/scikit_elasticnet_wine/train.py", "0.1"], check=True)
    subprocess.run(["python", "/home/scikit_elasticnet_wine/train.py", "0.5"], check=True)
    subprocess.run(["python", "/home/scikit_elasticnet_wine/train.py", "0.9"], check=True)

# NOTE(review): these transitions run unconditionally (even when the registry
# was already populated) -- presumably intended to be idempotent; confirm that
# re-transitioning an already-staged version is acceptable for this guide.
client.transition_model_version_stage(
    name="ScikitElasticnetWineModel",
    version=1,
    stage="Production"
)

client.transition_model_version_stage(
    name="ScikitElasticnetWineModel",
    version=2,
    stage="Staging"
)
27 |
--------------------------------------------------------------------------------
/guides/mlflow/seed_models/wait-for-it.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Use this script to test if a given TCP host/port are available
# https://raw.githubusercontent.com/vishnubob/wait-for-it/master/wait-for-it.sh
# NOTE(review): vendored third-party script -- intentionally kept byte-identical
# to upstream (apart from this note) so future upstream updates diff cleanly.

WAITFORIT_cmdname=${0##*/}
6 |
7 | echoerr() { if [[ $WAITFORIT_QUIET -ne 1 ]]; then echo "$@" 1>&2; fi }
8 |
9 | usage()
10 | {
11 | cat << USAGE >&2
12 | Usage:
13 | $WAITFORIT_cmdname host:port [-s] [-t timeout] [-- command args]
14 | -h HOST | --host=HOST Host or IP under test
15 | -p PORT | --port=PORT TCP port under test
16 | Alternatively, you specify the host and port as host:port
17 | -s | --strict Only execute subcommand if the test succeeds
18 | -q | --quiet Don't output any status messages
19 | -t TIMEOUT | --timeout=TIMEOUT
20 | Timeout in seconds, zero for no timeout
21 | -- COMMAND ARGS Execute command with args after the test finishes
22 | USAGE
23 | exit 1
24 | }
25 |
26 | wait_for()
27 | {
28 | if [[ $WAITFORIT_TIMEOUT -gt 0 ]]; then
29 | echoerr "$WAITFORIT_cmdname: waiting $WAITFORIT_TIMEOUT seconds for $WAITFORIT_HOST:$WAITFORIT_PORT"
30 | else
31 | echoerr "$WAITFORIT_cmdname: waiting for $WAITFORIT_HOST:$WAITFORIT_PORT without a timeout"
32 | fi
33 | WAITFORIT_start_ts=$(date +%s)
34 | while :
35 | do
36 | if [[ $WAITFORIT_ISBUSY -eq 1 ]]; then
37 | nc -z $WAITFORIT_HOST $WAITFORIT_PORT
38 | WAITFORIT_result=$?
39 | else
40 | (echo -n > /dev/tcp/$WAITFORIT_HOST/$WAITFORIT_PORT) >/dev/null 2>&1
41 | WAITFORIT_result=$?
42 | fi
43 | if [[ $WAITFORIT_result -eq 0 ]]; then
44 | WAITFORIT_end_ts=$(date +%s)
45 | echoerr "$WAITFORIT_cmdname: $WAITFORIT_HOST:$WAITFORIT_PORT is available after $((WAITFORIT_end_ts - WAITFORIT_start_ts)) seconds"
46 | break
47 | fi
48 | sleep 1
49 | done
50 | return $WAITFORIT_result
51 | }
52 |
53 | wait_for_wrapper()
54 | {
55 | # In order to support SIGINT during timeout: http://unix.stackexchange.com/a/57692
56 | if [[ $WAITFORIT_QUIET -eq 1 ]]; then
57 | timeout $WAITFORIT_BUSYTIMEFLAG $WAITFORIT_TIMEOUT $0 --quiet --child --host=$WAITFORIT_HOST --port=$WAITFORIT_PORT --timeout=$WAITFORIT_TIMEOUT &
58 | else
59 | timeout $WAITFORIT_BUSYTIMEFLAG $WAITFORIT_TIMEOUT $0 --child --host=$WAITFORIT_HOST --port=$WAITFORIT_PORT --timeout=$WAITFORIT_TIMEOUT &
60 | fi
61 | WAITFORIT_PID=$!
62 | trap "kill -INT -$WAITFORIT_PID" INT
63 | wait $WAITFORIT_PID
64 | WAITFORIT_RESULT=$?
65 | if [[ $WAITFORIT_RESULT -ne 0 ]]; then
66 | echoerr "$WAITFORIT_cmdname: timeout occurred after waiting $WAITFORIT_TIMEOUT seconds for $WAITFORIT_HOST:$WAITFORIT_PORT"
67 | fi
68 | return $WAITFORIT_RESULT
69 | }
70 |
71 | # process arguments
72 | while [[ $# -gt 0 ]]
73 | do
74 | case "$1" in
75 | *:* )
76 | WAITFORIT_hostport=(${1//:/ })
77 | WAITFORIT_HOST=${WAITFORIT_hostport[0]}
78 | WAITFORIT_PORT=${WAITFORIT_hostport[1]}
79 | shift 1
80 | ;;
81 | --child)
82 | WAITFORIT_CHILD=1
83 | shift 1
84 | ;;
85 | -q | --quiet)
86 | WAITFORIT_QUIET=1
87 | shift 1
88 | ;;
89 | -s | --strict)
90 | WAITFORIT_STRICT=1
91 | shift 1
92 | ;;
93 | -h)
94 | WAITFORIT_HOST="$2"
95 | if [[ $WAITFORIT_HOST == "" ]]; then break; fi
96 | shift 2
97 | ;;
98 | --host=*)
99 | WAITFORIT_HOST="${1#*=}"
100 | shift 1
101 | ;;
102 | -p)
103 | WAITFORIT_PORT="$2"
104 | if [[ $WAITFORIT_PORT == "" ]]; then break; fi
105 | shift 2
106 | ;;
107 | --port=*)
108 | WAITFORIT_PORT="${1#*=}"
109 | shift 1
110 | ;;
111 | -t)
112 | WAITFORIT_TIMEOUT="$2"
113 | if [[ $WAITFORIT_TIMEOUT == "" ]]; then break; fi
114 | shift 2
115 | ;;
116 | --timeout=*)
117 | WAITFORIT_TIMEOUT="${1#*=}"
118 | shift 1
119 | ;;
120 | --)
121 | shift
122 | WAITFORIT_CLI=("$@")
123 | break
124 | ;;
125 | --help)
126 | usage
127 | ;;
128 | *)
129 | echoerr "Unknown argument: $1"
130 | usage
131 | ;;
132 | esac
133 | done
134 |
135 | if [[ "$WAITFORIT_HOST" == "" || "$WAITFORIT_PORT" == "" ]]; then
136 | echoerr "Error: you need to provide a host and port to test."
137 | usage
138 | fi
139 |
140 | WAITFORIT_TIMEOUT=${WAITFORIT_TIMEOUT:-15}
141 | WAITFORIT_STRICT=${WAITFORIT_STRICT:-0}
142 | WAITFORIT_CHILD=${WAITFORIT_CHILD:-0}
143 | WAITFORIT_QUIET=${WAITFORIT_QUIET:-0}
144 |
145 | # Check to see if timeout is from busybox?
146 | WAITFORIT_TIMEOUT_PATH=$(type -p timeout)
147 | WAITFORIT_TIMEOUT_PATH=$(realpath $WAITFORIT_TIMEOUT_PATH 2>/dev/null || readlink -f $WAITFORIT_TIMEOUT_PATH)
148 |
149 | WAITFORIT_BUSYTIMEFLAG=""
150 | if [[ $WAITFORIT_TIMEOUT_PATH =~ "busybox" ]]; then
151 | WAITFORIT_ISBUSY=1
152 | # Check if busybox timeout uses -t flag
153 | # (recent Alpine versions don't support -t anymore)
154 | if timeout &>/dev/stdout | grep -q -e '-t '; then
155 | WAITFORIT_BUSYTIMEFLAG="-t"
156 | fi
157 | else
158 | WAITFORIT_ISBUSY=0
159 | fi
160 |
161 | if [[ $WAITFORIT_CHILD -gt 0 ]]; then
162 | wait_for
163 | WAITFORIT_RESULT=$?
164 | exit $WAITFORIT_RESULT
165 | else
166 | if [[ $WAITFORIT_TIMEOUT -gt 0 ]]; then
167 | wait_for_wrapper
168 | WAITFORIT_RESULT=$?
169 | else
170 | wait_for
171 | WAITFORIT_RESULT=$?
172 | fi
173 | fi
174 |
175 | if [[ $WAITFORIT_CLI != "" ]]; then
176 | if [[ $WAITFORIT_RESULT -ne 0 && $WAITFORIT_STRICT -eq 1 ]]; then
177 | echoerr "$WAITFORIT_cmdname: strict mode, refusing to execute subprocess"
178 | exit $WAITFORIT_RESULT
179 | fi
180 | exec "${WAITFORIT_CLI[@]}"
181 | else
182 | exit $WAITFORIT_RESULT
183 | fi
--------------------------------------------------------------------------------
/local-notebooks/.gitignore:
--------------------------------------------------------------------------------
1 | .aws
2 | .env
3 | .mypy_cache
--------------------------------------------------------------------------------
/local-notebooks/Dockerfile:
--------------------------------------------------------------------------------
# Local JupyterLab image: conda environment sharing, S3-backed home directory,
# and real-time collaboration support.
FROM jupyter/base-notebook

# Environment manager
RUN mamba install -c conda-forge jupyterlab mamba_gator nb_conda_kernels

# Install git extension and cell execution timer
RUN pip install --upgrade jupyterlab-git jupyterlab_execute_time

# Install S3FS and other deps (root required for apt)
USER root
RUN apt-get update && apt-get install -y git s3fs jq cron && rm -rf /var/lib/apt/lists/*

# Configure S3FS
RUN echo "user_allow_other" >> /etc/fuse.conf

# Configure cron (root crontab drives the per-minute conda environment backup)
COPY config/cron /usr/local/sbin/dom_cron/
RUN crontab /usr/local/sbin/dom_cron/root
RUN chmod +x /usr/local/sbin/dom_cron/dump_conda.sh

# Drop back to the unprivileged notebook user
USER jovyan

# Install our custom notebook config
COPY config/jupyter_lab_config.py /home/jovyan/.jupyter/jupyter_lab_config.py

# Override settings
COPY config/overrides.json /opt/conda/share/jupyter/lab/settings/overrides.json

# Install our custom before-start scripts
COPY config/start-notebook /usr/local/bin/start-notebook.d/
COPY config/before-notebook /usr/local/bin/before-notebook.d/

# Install the script for launching a tunnel.
COPY --chmod=0755 config/start-tunnel /usr/local/bin/start-tunnel.d/
35 |
36 |
37 |
--------------------------------------------------------------------------------
/local-notebooks/README.md:
--------------------------------------------------------------------------------
1 | # Local Notebooks
2 |
3 | Locally run JupyterLab, with configuration for environment and data sharing,
4 | and real-time collaboration.
5 |
6 | # Installation
7 |
8 | ## 0. Requirements
9 |
10 | * `awscli` Version 2
11 | * Docker
12 |
13 | ## 1. Build Container Image
14 |
15 | ```bash
16 | docker build -t local-notebooks:latest .
17 | ```
18 |
## 2. Provision AWS Resources
20 |
21 | Select a region, bucket name (must be globally unique), and user name (must be unique in account):
22 |
23 | ```bash
24 | # us-east-1 is recommended, s3fs has some known issues with other regions
25 | export AWS_REGION=us-east-1
26 | export AWS_BUCKET=
27 | export AWS_USER=local-notebooks
28 | ```
29 |
30 | Create S3 bucket:
31 |
32 | ```bash
aws s3 mb s3://${AWS_BUCKET} --region ${AWS_REGION}
34 | ```
35 |
36 | Create IAM credentials for bucket access:
37 |
38 | ```bash
39 | # Create new IAM user
40 | aws iam create-user --user-name ${AWS_USER}
41 | # Create policy for access to S3 bucket created above
42 | aws iam create-policy --policy-name ${AWS_USER} --policy-document "$(envsubst < config/policy.json)"
43 | # Attach policy to user
44 | aws iam attach-user-policy --user-name ${AWS_USER} --policy-arn arn:aws:iam::$(aws sts get-caller-identity --query Account --output text):policy/${AWS_USER}
45 | # Generate API credentials for user
46 | mkdir .aws && aws iam create-access-key --user-name ${AWS_USER} > .aws/credentials
47 | ```
48 |
## 3. Run Notebook
50 |
51 | ```bash
52 | docker run --name local-notebook \
53 | -it -p 8888:8888 \
54 | -e JUPYTER_ENABLE_LAB=yes \
55 | -e NB_USER=domino-research \
56 | --user root \
57 | --cap-add SYS_ADMIN \
58 | --security-opt apparmor:unconfined \
59 | --device /dev/fuse \
60 | -e AWS_BUCKET=${AWS_BUCKET} \
61 | -e AWS_REGION=${AWS_REGION} \
62 | -v $(pwd)/.aws:/etc/aws \
63 | local-notebooks
64 | ```
65 |
66 | # Usage
67 |
68 | Once the notebook starts, JupyterLab will print out a link to access the UI as usual, use the one with hostname `127.0.0.1` and port `8888`.
69 |
70 | ## Shared Data
71 |
72 | The S3 bucket is mounted at `/mnt/home`, and set as the default path for the notebook. Any files you store here will be persisted to S3, and can be used by other users with a similar configuration.
73 |
74 | ## Environments
75 |
76 | Environments are managed through Conda. We have included a number of customizations to make this easy:
77 |
78 | - To use an environment, simply select the kernel that corresponds to the environment. We use `nb_conda_kernels` to automatically make all Conda environments available for use in Jupyter as kernels. You will start with only a base environment, but can add as many as you like.
79 |
- To create, modify and delete environments, navigate to settings > Conda Packages Manager. From here, you can create, delete and modify environments without touching the command line.
81 |
- To use a shared environment, start this tool pointing at the same S3 bucket as your collaborator. We will automatically load the latest version of all of the Conda environments that are backed up in the S3 bucket and make them available as kernels. Any changes you make or new environments you add will be saved back to the S3 bucket.
83 |
84 | Notes and limitations:
85 |
86 | - Environments are backed up every minute. Changes you make in the last minute before shut down may not be saved. We suggest waiting a minute before shutting down if you edit the environment.
87 |
88 | - Loading shared environments is done synchronously prior to server launch. We are looking into solutions to speed this up by saving the entire environment to S3 rather than only a YAML of the installed packages. We are also looking into making this process asynchronous.
89 |
90 | - You can also create and modify environments using the `conda` CLI. But, when you do this, you must add a `--name XXX` to the commands to target a specific environment. If you do `conda install` in a notebook without supplying a name, it will install into the base Conda environment and not the environment you have selected via the kernel.
91 |
92 | ## Real-Time Collaboration
93 |
94 | To start a network tunnel allowing other users to access your local notebook,
95 | open Terminal (inside the Notebook webapp) and launch
96 | `/usr/local/bin/start-tunnel.d/start-tunnel.sh` script.
97 | A public URL for your tunnel will appear in the console. For example,
98 |
99 | ```bash
100 | (base) domino-research@f6d754573428:~$ /usr/local/bin/start-tunnel.d/start-tunnel.sh
101 | % Total % Received % Xferd Average Speed Time Time Time Current
102 | Dload Upload Total Spent Left Speed
103 | 100 159 100 159 0 0 703 0 --:--:-- --:--:-- --:--:-- 703
104 | 100 631 100 631 0 0 1923 0 --:--:-- --:--:-- --:--:-- 1923
105 | 100 29.9M 100 29.9M 0 0 25.6M 0 0:00:01 0:00:01 --:--:-- 57.8M
106 | 2021-10-27T18:52:57Z INF Thank you for trying Cloudflare Tunnel. Doing so, without a Cloudflare account, is a quick way to experiment and try it out. However, be aware that these account-less Tunnels have no uptime guarantee. If you intend to use Tunnels in production you should use a pre-created named tunnel by following: https://developers.cloudflare.com/cloudflare-one/connections/connect-apps
107 | 2021-10-27T18:52:57Z INF Requesting new quick Tunnel on trycloudflare.com...
108 | 2021-10-27T18:52:59Z INF +--------------------------------------------------------------------------------------------+
109 | 2021-10-27T18:52:59Z INF | Your quick Tunnel has been created! Visit it at (it may take some time to be reachable): |
110 | 2021-10-27T18:52:59Z INF | https://afghanistan-null-ron-norwegian.trycloudflare.com |
111 | 2021-10-27T18:52:59Z INF +--------------------------------------------------------------------------------------------+
112 | 2021-10-27T18:52:59Z INF Cannot determine default configuration path. No file [config.yml config.yaml] in [~/.cloudflared ~/.cloudflare-warp ~/cloudflare-warp /etc/cloudflared /usr/local/etc/cloudflared]
113 | 2021-10-27T18:52:59Z INF Version 2021.10.5
114 | 2021-10-27T18:52:59Z INF GOOS: linux, GOVersion: devel +a84af465cb Mon Aug 9 10:31:00 2021 -0700, GoArch: amd64
115 | 2021-10-27T18:52:59Z INF Settings: map[protocol:quic url:http://127.0.0.1:8888]
116 | 2021-10-27T18:52:59Z INF cloudflared will not automatically update when run from the shell. To enable auto-updates, run cloudflared as a service: https://developers.cloudflare.com/cloudflare-one/connections/connect-apps/run-tunnel/run-as-service
117 | 2021-10-27T18:52:59Z INF Generated Connector ID: 1d0fca10-ba61-4620-902a-60ada75c622a
118 | 2021-10-27T18:52:59Z INF Initial protocol quic
119 | 2021-10-27T18:52:59Z INF Starting metrics server on 127.0.0.1:35577/metrics
120 | <...>
121 | ```
122 |
123 | Here, `https://afghanistan-null-ron-norwegian.trycloudflare.com` is the URL
124 | that can be shared with your collaborators. Note that the tunnel always uses
125 | a default port for the given protocol.
126 |
127 |
--------------------------------------------------------------------------------
/local-notebooks/config/before-notebook/cron.sh:
--------------------------------------------------------------------------------
#!/bin/bash
set -e

# Start the cron daemon so scheduled jobs (the per-minute conda environment
# backup installed in the image) run inside the container.
echo "Configuring cron"
cron
echo "Configured cron"
7 |
--------------------------------------------------------------------------------
/local-notebooks/config/before-notebook/load_conda.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Restore conda environments from the most recent backup folder under the S3
# mount, so shared environments from collaborators appear as kernels.

# Init Conda so `conda` is available as a shell function in this script
CONDA_BASE=$(/opt/conda/bin/conda info --base)
CONDA_FUNCTION="etc/profile.d/conda.sh"
CONDA="$CONDA_BASE/$CONDA_FUNCTION"
source "$CONDA"

# Find the most recent backup folder (names are sortable timestamps,
# YYYYmmddHHMMSS, written by dump_conda.sh)
BACKUP_PREFIX="/mnt/home/conda_backups"
BACKUP_FOLDER=$(ls "$BACKUP_PREFIX" 2>/dev/null | sort -r | head -n 1)

# First run / fresh bucket: nothing to restore, and the loop below would fail
# on an empty path -- exit cleanly instead.
if [ -z "$BACKUP_FOLDER" ]; then
    echo "No environment backups found under $BACKUP_PREFIX"
    exit 0
fi

echo "Using environment backups from $BACKUP_FOLDER"

ENVS=$(ls "$BACKUP_PREFIX/$BACKUP_FOLDER")
for env in $ENVS; do
    NAME="$(basename "$env" .yml)"
    echo "Env loading: $NAME"
    # --prune removes packages not in the backup so the env matches exactly
    conda env update --name "$NAME" --file "$BACKUP_PREFIX/$BACKUP_FOLDER/$env" --prune >>/dev/null
    echo "Env loaded: $NAME"
done
22 |
--------------------------------------------------------------------------------
/local-notebooks/config/cron/dump_conda.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Export every conda environment to a timestamped YAML backup folder on the
# shared S3 mount. Run from cron every minute (see config/cron/root); the
# backups are restored at startup by load_conda.sh.

# modified from: https://github.com/conda/conda/issues/5165#issuecomment-665354035

# Init Conda so `conda activate` works in this non-interactive shell
CONDA_BASE=$(/opt/conda/bin/conda info --base)
CONDA_FUNCTION="etc/profile.d/conda.sh"
CONDA="$CONDA_BASE/$CONDA_FUNCTION"
source "$CONDA"

# Create backup folder named with a sortable timestamp
NOW=$(date "+%Y%m%d%H%M%S")
LOC=/mnt/home/conda_backups/$NOW
mkdir -p "$LOC"

# Environment names: first column of `conda env list`, skipping comment lines
ENVS=$(conda env list | grep '^\w' | cut -d' ' -f1)
for env in $ENVS; do
    echo "Exporting $env"
    conda activate "$env"
    conda env export > "$LOC/$env.yml"
done
22 |
--------------------------------------------------------------------------------
/local-notebooks/config/cron/root:
--------------------------------------------------------------------------------
1 | # To debug: install and start rsyslog and then
2 | # view logs with grep 'conda_dump' /var/log/syslog
3 | * * * * * /usr/local/sbin/dom_cron/dump_conda.sh 2>&1 | logger -t conda_dump
4 |
--------------------------------------------------------------------------------
/local-notebooks/config/jupyter_lab_config.py:
--------------------------------------------------------------------------------
# Enable collaborative mode
c.LabApp.collaborative = True

# Default notebook path to S3 mount
# NOTE(review): NotebookApp is the legacy trait name; recent Jupyter Server
# versions read c.ServerApp.root_dir instead -- confirm this still takes effect
# on the jupyter/base-notebook version in use.
c.NotebookApp.notebook_dir = "/mnt/home"
6 |
7 |
--------------------------------------------------------------------------------
/local-notebooks/config/overrides.json:
--------------------------------------------------------------------------------
1 | {
2 | "@jupyterlab/notebook-extension:tracker": {
3 | "recordTiming": true
4 | },
5 | "@jupyterlab/extensionmanager-extension:plugin": {
6 | "enabled": true,
7 | "disclaimed": true
8 | }
9 | }
10 |
--------------------------------------------------------------------------------
/local-notebooks/config/policy.json:
--------------------------------------------------------------------------------
1 | {
2 | "Version": "2012-10-17",
3 | "Statement": [
4 | {
5 | "Sid": "VisualEditor0",
6 | "Effect": "Allow",
7 | "Action": [
8 | "s3:PutObject",
9 | "s3:GetObject",
10 | "s3:DeleteObject"
11 | ],
12 | "Resource": "arn:aws:s3:::${AWS_BUCKET}/*"
13 | },
14 | {
15 | "Sid": "VisualEditor1",
16 | "Effect": "Allow",
17 | "Action": "s3:ListBucket",
18 | "Resource": "arn:aws:s3:::${AWS_BUCKET}"
19 | }
20 | ]
21 | }
--------------------------------------------------------------------------------
/local-notebooks/config/start-notebook/s3.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Mount the shared S3 bucket at /mnt/home via s3fs, using the IAM access key
# JSON produced by `aws iam create-access-key` (mounted at /etc/aws/credentials).
set -e

mkdir -p /mnt/home

# Configure S3FS
mkdir -p /tmp/s3fs

# Build the "ACCESS_KEY_ID:SECRET" line s3fs expects. A single jq invocation
# reads the credentials file once instead of the previous two cat|jq pipelines.
jq -r '.AccessKey.AccessKeyId + ":" + .AccessKey.SecretAccessKey' /etc/aws/credentials > /etc/passwd-s3fs
chmod 600 /etc/passwd-s3fs

# allow_other lets the notebook user access the root-created mount
s3fs ${AWS_BUCKET} /mnt/home -o allow_other -o use_cache=/tmp/s3fs

echo "S3 Mounted"

sleep 1
17 |
--------------------------------------------------------------------------------
/local-notebooks/config/start-tunnel/start-tunnel.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# This script downloads and installs an executable for an anonymous Cloudflare
# tunnel, and then starts the tunnel pointed to a local endpoint specified below.
# The public URL for the tunnel is printed in the logs -- it's different each time.

LOCAL_ENDPOINT=http://127.0.0.1:8888

DOWNLOAD_URL=https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64
TUNNEL_EXE=~/opt/cloudflared

if [ ! -e "$TUNNEL_EXE" ]
then
    DIR=$(dirname "$TUNNEL_EXE")
    mkdir -p "$DIR"
    # -f: fail on HTTP errors so a 404/500 response body is not saved and
    # marked executable as if it were the cloudflared binary.
    curl -fL -o "$TUNNEL_EXE" "$DOWNLOAD_URL"
    chmod u+x "$TUNNEL_EXE"
fi

"$TUNNEL_EXE" tunnel -url "$LOCAL_ENDPOINT"
21 |
22 |
--------------------------------------------------------------------------------
/terraform/mlflow-runtime-ecr/.gitignore:
--------------------------------------------------------------------------------
1 | # Local .terraform directories
2 | **/.terraform/*
3 |
4 | # .tfstate files
5 | *.tfstate
6 | *.tfstate.*
7 |
8 | # Crash log files
9 | crash.log
10 |
# Exclude all .tfvars files, which are likely to contain sensitive data, such as
# passwords, private keys, and other secrets. These should not be part of version
# control as they are data points which are potentially sensitive and subject
# to change depending on the environment.
15 | #
16 | *.tfvars
17 |
18 | # Ignore override files as they are usually used to override resources locally and so
19 | # are not checked in
20 | override.tf
21 | override.tf.json
22 | *_override.tf
23 | *_override.tf.json
24 |
25 | # Include override files you do wish to add to version control using negated pattern
26 | #
27 | # !example_override.tf
28 |
29 | # Include tfplan files to ignore the plan output of command: terraform plan -out=tfplan
30 | # example: *tfplan*
31 |
32 | # Ignore CLI configuration files
33 | .terraformrc
34 | terraform.rc
35 |
--------------------------------------------------------------------------------
/terraform/mlflow-runtime-ecr/.terraform.lock.hcl:
--------------------------------------------------------------------------------
1 | # This file is maintained automatically by "terraform init".
2 | # Manual edits may be lost in future updates.
3 |
4 | provider "registry.terraform.io/hashicorp/aws" {
5 | version = "3.53.0"
6 | constraints = "3.53.0"
7 | hashes = [
8 | "h1:kcda9YVaFUzBFVtKXNZrQB801i2XkH1Y5gbdOHNpB38=",
9 | "zh:35a77c79170b0cf3fb7eb835f3ce0b715aeeceda0a259e96e49fed5a30cf6646",
10 | "zh:519d5470a932b1ec9a0fe08876c5e0f0f84f8e506b652c051e4ab708be081e89",
11 | "zh:58cfa5b454602d57c47acd15c2ad166a012574742cdbcf950787ce79b6510218",
12 | "zh:5fc3c0162335a730701c0175809250233f45f1021da8fa52c73635e4c08372d8",
13 | "zh:6790f9d6261eb4bd5cdd7cd9125f103befce2ba127f9ba46eef83585b86e1d11",
14 | "zh:76e1776c3bf9568d520f78419ec143c081f653b8df4fb22577a8c4a35d3315f9",
15 | "zh:ca8ed88d0385e45c35223ace59b1bf77d81cd2154d5416e63a3dddaf0def30e6",
16 | "zh:d002562c4a89a9f1f6cd8d854fad3c66839626fc260e5dde5267f6d34dbd97a4",
17 | "zh:da5e47fb769e90a2f16c90fd0ba95d62da3d76eb006823664a5c6e96188731b0",
18 | "zh:dfe7f33ec252ea550e090975a5f10940c27302bebb5559957957937b069646ea",
19 | "zh:fa91574605ddce726e8a4e421297009a9dabe023106e139ac46da49c8285f2fe",
20 | ]
21 | }
22 |
--------------------------------------------------------------------------------
/terraform/mlflow-runtime-ecr/generate_vendor_plan.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Regenerate vendor.tf: for every AWS region visible to this account, emit an
# aliased provider block plus a module instantiating ./vendor in that region.
# vendor.tf is machine-generated -- rerun this script rather than hand-editing.
aws ec2 describe-regions | jq -rc '.[][].RegionName' | while read region; do echo -e "provider \"aws\" {\n region = \"${region}\"\n alias = \"${region}\"\n}\n\nmodule \"${region}\" {\n source = \"./vendor\"\n\n providers = {\n aws = aws.${region}\n }\n}\n"; done > vendor.tf
3 |
--------------------------------------------------------------------------------
/terraform/mlflow-runtime-ecr/provider.tf:
--------------------------------------------------------------------------------
1 | terraform {
2 | required_providers {
3 | aws = {
4 | source = "hashicorp/aws"
5 | version = "3.53.0"
6 | }
7 | }
8 | }
9 |
--------------------------------------------------------------------------------
/terraform/mlflow-runtime-ecr/push.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Push the locally built mlflow-pyfunc image to the bridge-mlflow-runtime ECR
# repository in every AWS region (repositories created by the terraform here).
# NOTE(review): account id 667552661262 is hard-coded -- verify it matches the
# account the terraform was applied to.

aws ec2 describe-regions | jq -rc '.[][].RegionName' | while read region
do
  echo ${region}
  aws ecr get-login-password --region ${region} | docker login --username AWS --password-stdin 667552661262.dkr.ecr.${region}.amazonaws.com
  docker tag mlflow-pyfunc:latest 667552661262.dkr.ecr.${region}.amazonaws.com/bridge-mlflow-runtime:latest
  docker push 667552661262.dkr.ecr.${region}.amazonaws.com/bridge-mlflow-runtime:latest
done
10 |
--------------------------------------------------------------------------------
/terraform/mlflow-runtime-ecr/vendor.tf:
--------------------------------------------------------------------------------
# NOTE: this file is machine-generated by generate_vendor_plan.sh --
# regenerate it with that script instead of editing by hand.
provider "aws" {
  region = "eu-north-1"
  alias  = "eu-north-1"
}

module "eu-north-1" {
  source = "./vendor"

  providers = {
    aws = aws.eu-north-1
  }
}
13 |
14 | provider "aws" {
15 | region = "ap-south-1"
16 | alias = "ap-south-1"
17 | }
18 |
19 | module "ap-south-1" {
20 | source = "./vendor"
21 |
22 | providers = {
23 | aws = aws.ap-south-1
24 | }
25 | }
26 |
27 | provider "aws" {
28 | region = "eu-west-3"
29 | alias = "eu-west-3"
30 | }
31 |
32 | module "eu-west-3" {
33 | source = "./vendor"
34 |
35 | providers = {
36 | aws = aws.eu-west-3
37 | }
38 | }
39 |
40 | provider "aws" {
41 | region = "eu-west-2"
42 | alias = "eu-west-2"
43 | }
44 |
45 | module "eu-west-2" {
46 | source = "./vendor"
47 |
48 | providers = {
49 | aws = aws.eu-west-2
50 | }
51 | }
52 |
53 | provider "aws" {
54 | region = "eu-west-1"
55 | alias = "eu-west-1"
56 | }
57 |
58 | module "eu-west-1" {
59 | source = "./vendor"
60 |
61 | providers = {
62 | aws = aws.eu-west-1
63 | }
64 | }
65 |
66 | provider "aws" {
67 | region = "ap-northeast-3"
68 | alias = "ap-northeast-3"
69 | }
70 |
71 | module "ap-northeast-3" {
72 | source = "./vendor"
73 |
74 | providers = {
75 | aws = aws.ap-northeast-3
76 | }
77 | }
78 |
79 | provider "aws" {
80 | region = "ap-northeast-2"
81 | alias = "ap-northeast-2"
82 | }
83 |
84 | module "ap-northeast-2" {
85 | source = "./vendor"
86 |
87 | providers = {
88 | aws = aws.ap-northeast-2
89 | }
90 | }
91 |
92 | provider "aws" {
93 | region = "ap-northeast-1"
94 | alias = "ap-northeast-1"
95 | }
96 |
97 | module "ap-northeast-1" {
98 | source = "./vendor"
99 |
100 | providers = {
101 | aws = aws.ap-northeast-1
102 | }
103 | }
104 |
105 | provider "aws" {
106 | region = "sa-east-1"
107 | alias = "sa-east-1"
108 | }
109 |
110 | module "sa-east-1" {
111 | source = "./vendor"
112 |
113 | providers = {
114 | aws = aws.sa-east-1
115 | }
116 | }
117 |
118 | provider "aws" {
119 | region = "ca-central-1"
120 | alias = "ca-central-1"
121 | }
122 |
123 | module "ca-central-1" {
124 | source = "./vendor"
125 |
126 | providers = {
127 | aws = aws.ca-central-1
128 | }
129 | }
130 |
131 | provider "aws" {
132 | region = "ap-southeast-1"
133 | alias = "ap-southeast-1"
134 | }
135 |
136 | module "ap-southeast-1" {
137 | source = "./vendor"
138 |
139 | providers = {
140 | aws = aws.ap-southeast-1
141 | }
142 | }
143 |
144 | provider "aws" {
145 | region = "ap-southeast-2"
146 | alias = "ap-southeast-2"
147 | }
148 |
149 | module "ap-southeast-2" {
150 | source = "./vendor"
151 |
152 | providers = {
153 | aws = aws.ap-southeast-2
154 | }
155 | }
156 |
157 | provider "aws" {
158 | region = "eu-central-1"
159 | alias = "eu-central-1"
160 | }
161 |
162 | module "eu-central-1" {
163 | source = "./vendor"
164 |
165 | providers = {
166 | aws = aws.eu-central-1
167 | }
168 | }
169 |
170 | provider "aws" {
171 | region = "us-east-1"
172 | alias = "us-east-1"
173 | }
174 |
175 | module "us-east-1" {
176 | source = "./vendor"
177 |
178 | providers = {
179 | aws = aws.us-east-1
180 | }
181 | }
182 |
183 | provider "aws" {
184 | region = "us-east-2"
185 | alias = "us-east-2"
186 | }
187 |
188 | module "us-east-2" {
189 | source = "./vendor"
190 |
191 | providers = {
192 | aws = aws.us-east-2
193 | }
194 | }
195 |
196 | provider "aws" {
197 | region = "us-west-1"
198 | alias = "us-west-1"
199 | }
200 |
201 | module "us-west-1" {
202 | source = "./vendor"
203 |
204 | providers = {
205 | aws = aws.us-west-1
206 | }
207 | }
208 |
209 | provider "aws" {
210 | region = "us-west-2"
211 | alias = "us-west-2"
212 | }
213 |
214 | module "us-west-2" {
215 | source = "./vendor"
216 |
217 | providers = {
218 | aws = aws.us-west-2
219 | }
220 | }
221 |
222 |
--------------------------------------------------------------------------------
/terraform/mlflow-runtime-ecr/vendor/ecr.tf:
--------------------------------------------------------------------------------
1 | resource "aws_ecr_repository" "this" {
2 | name = "bridge-mlflow-runtime"
3 | image_tag_mutability = "MUTABLE"
4 |
5 | tags = {
6 | ManagedBy = "Terraform"
7 | }
8 | }
9 |
10 | resource "aws_ecr_repository_policy" "this" {
11 | repository = aws_ecr_repository.this.name
12 |
13 | policy = <