├── .coveragerc ├── .dockerignore ├── .flake8 ├── .github └── workflows │ └── ci.yaml ├── .gitignore ├── .isort.cfg ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── Dockerfile ├── Makefile ├── README.md ├── conf └── app.conf.template ├── docker-compose.yml ├── funcx_web_service ├── __init__.py ├── application.py ├── authentication │ ├── __init__.py │ ├── auth.py │ ├── auth_state.py │ └── globus_auth.py ├── container_service_adapter.py ├── error_responses.py ├── models │ ├── __init__.py │ ├── auth_groups.py │ ├── container.py │ ├── endpoint.py │ ├── function.py │ ├── search.py │ ├── serializer.py │ ├── tasks.py │ ├── user.py │ └── utils.py ├── response.py ├── routes │ ├── __init__.py │ ├── container.py │ └── funcx.py └── version.py ├── integration_tests ├── Tutorial.ipynb ├── funcX.postman_collection.json ├── get_valid_token.py └── integration_test.py ├── migrations ├── README ├── alembic.ini ├── env.py ├── script.py.mako └── versions │ ├── v0.0.3_.py │ └── v0.2.0_.py ├── mypy.ini ├── requirements.in ├── requirements.txt ├── requirements_test.txt ├── scripts ├── store_endpoint_info.py └── store_usage.py ├── tests ├── __init__.py ├── conftest.py ├── integration │ └── test_endpoint_api.py ├── test_container_service_adapter.py └── unit │ ├── auth │ ├── test_auth_state.py │ └── test_authorization_functions.py │ ├── routes │ ├── conftest.py │ ├── test_auth.py │ ├── test_funcx.py │ ├── test_register_container.py │ ├── test_register_endpoint.py │ ├── test_register_function.py │ ├── test_status.py │ ├── test_submit_function.py │ └── test_task_groups.py │ ├── test_app_init.py │ └── test_task_behavior.py ├── tox.ini ├── uwsgi.ini └── web-entrypoint.sh /.coveragerc: -------------------------------------------------------------------------------- 1 | [report] 2 | show_missing = true 3 | skip_covered = true 4 | 5 | exclude_lines = 6 | # the pragma to disable coverage 7 | pragma: no cover 8 | # don't complain if tests don't hit unimplemented methods/modes 9 | raise 
NotImplementedError 10 | # don't check on executable components of importable modules 11 | if __name__ == .__main__.: 12 | # don't check coverage on type checking conditionals 13 | if TYPE_CHECKING: 14 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | dbsetup/ 2 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] # black-compatible 2 | max-line-length = 88 3 | ignore = W503,W504 4 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: CI/CD 2 | 3 | on: 4 | push: 5 | branches: 6 | - "*" 7 | tags: 8 | - "*" 9 | pull_request: 10 | 11 | jobs: 12 | test: 13 | strategy: 14 | matrix: 15 | python-version: [3.7] 16 | runs-on: ubuntu-latest 17 | 18 | steps: 19 | - uses: actions/checkout@master 20 | - name: Set up Python ${{ matrix.python-version }} 21 | uses: actions/setup-python@v1 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | - name: Install tox 25 | run: python -m pip install tox 26 | - name: Lint 27 | run: python -m tox -e lint,mypy 28 | - name: Safety Check 29 | run: python -m tox -e safety 30 | - name: Test 31 | run: python -m tox -e py 32 | - name: Report coverage with Codecov 33 | run: python -m tox -e codecov -- --token=${{ secrets.CODECOV_TOKEN }} 34 | 35 | publish: 36 | # only trigger on pushes to the main repo (not forks, and not PRs) 37 | if: ${{ github.repository == 'funcx-faas/funcx-web-service' && github.event_name == 'push' }} 38 | needs: test 39 | runs-on: ubuntu-latest 40 | steps: 41 | - uses: actions/checkout@master 42 | - name: Extract tag name 43 | shell: bash 44 | run: echo "##[set-output name=imagetag;]$(echo 
${GITHUB_REF##*/})" 45 | id: extract_tag_name 46 | 47 | - name: Build funcX-web-service Image 48 | uses: elgohr/Publish-Docker-Github-Action@master 49 | with: 50 | name: funcx/web-service:${{ steps.extract_tag_name.outputs.imagetag }} 51 | username: ${{ secrets.DOCKER_USERNAME }} 52 | password: ${{ secrets.DOCKER_PASSWORD }} 53 | tag: "${GITHUB_REF##*/}" 54 | 55 | # If this is a merge to main branch then we want to restart the web service 56 | # pod on dev cluster to pick up the changes 57 | deploy: 58 | needs: publish 59 | runs-on: ubuntu-latest 60 | if: github.ref == 'refs/heads/main' 61 | steps: 62 | - name: Configure AWS credentials 63 | uses: aws-actions/configure-aws-credentials@v1 64 | with: 65 | aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} 66 | aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} 67 | aws-region: us-east-1 68 | 69 | - name: scale to webservice pods to zero 70 | uses: kodermax/kubectl-aws-eks@master 71 | env: 72 | KUBE_CONFIG_DATA: ${{ secrets.KUBE_CONFIG_DATA_STAGING }} 73 | with: 74 | args: scale deployment funcx-funcx-web-service --replicas=0 75 | 76 | - name: scale to webservice pods back up 77 | uses: kodermax/kubectl-aws-eks@master 78 | env: 79 | KUBE_CONFIG_DATA: ${{ secrets.KUBE_CONFIG_DATA_STAGING }} 80 | with: 81 | args: scale deployment funcx-funcx-web-service --replicas=1 82 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | dockers/secrets 2 | dockers/certs 3 | conf/app.conf 4 | *~ 5 | *.pyc 6 | parsl.egg-info/* 7 | .scripts 8 | .*out 9 | .*err 10 | *log 11 | .ipynb_checkpoints 12 | tutorials/* 13 | venv/* 14 | web/* 15 | api/* 16 | .idea/* 17 | forwarder/.idea/* 18 | forwarder/.idea/forwarder.iml 19 | *.iml 20 | # TODO: remove this after we move to a tool for managing dependencies 21 | /frozen-requirements-tree.txt 22 | 23 | # Byte-compiled / optimized / DLL files 24 | __pycache__/ 25 | 
*.py[cod] 26 | *$py.class 27 | 28 | # C extensions 29 | *.so 30 | 31 | # Distribution / packaging 32 | .Python 33 | env/ 34 | build/ 35 | develop-eggs/ 36 | dist/ 37 | downloads/ 38 | eggs/ 39 | .eggs/ 40 | lib/ 41 | lib64/ 42 | parts/ 43 | sdist/ 44 | var/ 45 | wheels/ 46 | *.egg-info/ 47 | .installed.cfg 48 | *.egg 49 | 50 | # PyInstaller 51 | # Usually these files are written by a python script from a template 52 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 53 | *.manifest 54 | *.spec 55 | 56 | # Installer logs 57 | pip-log.txt 58 | pip-delete-this-directory.txt 59 | 60 | # Unit test / coverage reports 61 | htmlcov/ 62 | .tox/ 63 | .coverage 64 | .coverage.* 65 | .cache 66 | nosetests.xml 67 | coverage.xml 68 | *.cover 69 | .hypothesis/ 70 | 71 | # Translations 72 | *.mo 73 | *.pot 74 | 75 | # Django stuff: 76 | *.log 77 | local_settings.py 78 | 79 | # Flask stuff: 80 | instance/ 81 | .webassets-cache 82 | 83 | # Scrapy stuff: 84 | .scrapy 85 | 86 | # Sphinx documentation 87 | docs/_build/ 88 | 89 | # PyBuilder 90 | target/ 91 | 92 | # Jupyter Notebook 93 | .ipynb_checkpoints 94 | 95 | # pyenv 96 | .python-version 97 | 98 | # celery beat schedule file 99 | celerybeat-schedule 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # dotenv 105 | .env 106 | 107 | # virtualenv 108 | .venv 109 | venv/ 110 | ENV/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | 125 | # emacs buffers 126 | \#* 127 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | profile = black 3 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: 
-------------------------------------------------------------------------------- 1 | repos: 2 | - repo: meta 3 | hooks: 4 | - id: check-hooks-apply 5 | - id: check-useless-excludes 6 | - repo: https://github.com/pre-commit/pre-commit-hooks.git 7 | rev: v4.0.1 8 | hooks: 9 | - id: check-merge-conflict 10 | - id: trailing-whitespace 11 | - repo: https://github.com/sirosen/check-jsonschema 12 | rev: 0.5.1 13 | hooks: 14 | - id: check-github-workflows 15 | - repo: https://gitlab.com/pycqa/flake8 16 | rev: 3.9.2 17 | hooks: 18 | - id: flake8 19 | # because pre-commit invokes flake8 directly on files, the exclude in 20 | # `.flake8` config does not apply and the files must be excluded from the hook 21 | exclude: ^migrations/.* 22 | additional_dependencies: ['flake8-bugbear==21.4.3'] 23 | - repo: https://github.com/python/black 24 | rev: 21.9b0 25 | hooks: 26 | - id: black 27 | - repo: https://github.com/timothycrosley/isort 28 | rev: 5.9.3 29 | hooks: 30 | - id: isort 31 | - repo: https://github.com/asottile/pyupgrade 32 | rev: v2.29.0 33 | hooks: 34 | - id: pyupgrade 35 | args: ["--py36-plus"] 36 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guide 2 | 3 | This doc covers dev setup and guidelines for contributing. 4 | 5 | FIXME: This doc is a stub. 6 | 7 | ## Requirements 8 | 9 | - python3.7+ (prefer 3.7), pip, virtualenv 10 | - docker 11 | 12 | ### Recommended 13 | 14 | - [pre-commit](https://pre-commit.com/) 15 | 16 | ## Linting & Testing 17 | 18 | Testing should be done in a virtualenv with pytest. Setup: 19 | 20 | pip install -r ./requirements.txt 21 | pip install -r ./requirements_test.txt 22 | 23 | Run tests with 24 | 25 | pytest 26 | 27 | Linting can be run via pre-commit. 
Run for all files in the repo: 28 | 29 | pre-commit run -a 30 | 31 | ### (Optional) Setup pre-commit Hooks 32 | 33 | For the best development experience, set up linting and autofixing pre-commit 34 | git hooks using the `pre-commit` tool. 35 | 36 | After installing `pre-commit`, run 37 | 38 | pre-commit install 39 | 40 | in the repo to configure hooks. 41 | 42 | > NOTE: If necessary, you can always skip hooks with `git commit --no-verify` 43 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.7 2 | 3 | # Create a group and user 4 | RUN addgroup uwsgi && useradd -g uwsgi uwsgi 5 | 6 | WORKDIR /opt/funcx-web-service 7 | 8 | COPY ./requirements.txt . 9 | 10 | RUN pip install -r requirements.txt 11 | RUN pip install --disable-pip-version-check uwsgi 12 | 13 | COPY uwsgi.ini . 14 | COPY ./funcx_web_service/ ./funcx_web_service/ 15 | COPY ./migrations/ ./migrations/ 16 | COPY web-entrypoint.sh . 
17 | 18 | USER uwsgi 19 | EXPOSE 5000 20 | 21 | CMD sh web-entrypoint.sh 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: lint 2 | lint: 3 | tox -e lint,mypy 4 | 5 | .PHONY: test 6 | test: 7 | tox 8 | 9 | .PHONY: py-safety 10 | py-safety: 11 | tox -e safety 12 | 13 | # use sed to trim leading whitespace, and `sort -u` to remove duplicates 14 | # intermediate file allows the use of tee, to show output, and saves the frozen 15 | # deps 16 | # 17 | # the funcx git requirements are intentionally excluded from the deptree 18 | # because they look to pip like a conflict once frozen 19 | # they are pulled out of the `requirements.in` data to get a complete 20 | # requirement specification 21 | # 22 | # FIXME: 23 | # the funcx requirement munging should be possible to remove as soon as 24 | # we've removed the dependency on the forwarder and switch to a packaged version 25 | # of the SDK 26 | .PHONY: freezedeps 27 | freezedeps: 28 | echo "# frozen requirements; generate with 'make freezedeps'" > requirements.txt 29 | tox -qq -e freezedeps | tee frozen-requirements-tree.txt 30 | sed 's/ //g' frozen-requirements-tree.txt | sort -u >> requirements.txt 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FuncX Web Service 2 | This is the client interface to FuncX 3 | 4 | ## How to build docker image 5 | First build the docker image 6 | ```shell script 7 | docker build -t funcx_web_service:develop . 8 | ``` 9 | 10 | ## How to Test With Kubernetes 11 | You can create all of the required infrastructure for funcX web service and 12 | run it on your host for debugging. 13 | 14 | 1. Deploy the [helm chart](https://github.com/funcx-faas/helm-chart) 15 | 2.
Set up your local app config with the following values: 16 | ```python 17 | DB_NAME = "public" 18 | DB_USER = "funcx" 19 | DB_PASSWORD = "leftfoot1" 20 | DB_HOST = "localhost" 21 | 22 | REDIS_PORT = 6379 23 | REDIS_HOST = "localhost" 24 | ``` 25 | 3. Forward the postgres pod ports to your host. This command will not return so 26 | start it in another shell. 27 | ```shell script 28 | kubectl port-forward funcx-postgresql-0 5432:5432 29 | ``` 30 | 4. Forward the Redis master pod ports to your host. This command will not 31 | return so start it in another shell. 32 | ```shell script 33 | kubectl port-forward funcx-redis-master-0 6379:6379 34 | ``` 35 | 5. Launch the flask app: 36 | ```shell script 37 | APP_CONFIG_FILE=../conf/app.conf PYTHONPATH=. python funcx_web_service/application.py 38 | ``` 39 | 6. Obtain a JWT to authenticate requests to the REST server 40 | ```shell script 41 | python integration_tests/get_valid_token.py 42 | ``` 43 | 7. Use the postman tests in `integration_tests/funcX.postman_collection.json` 44 | with the `host` variable set to `localhost:5000` and the `access_token` set 45 | to your JWT. 
46 | 47 | -------------------------------------------------------------------------------- /conf/app.conf.template: -------------------------------------------------------------------------------- 1 | REDIS_PORT = 6379 2 | REDIS_HOST = "localhost" 3 | 4 | GLOBUS_CLIENT = <> 5 | GLOBUS_KEY = <> 6 | 7 | CONTAINER_SERVICE_ENABLED = False 8 | 9 | # URL of Container Service 10 | CONTAINER_SERVICE = http://localhost:5001 -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.8' 2 | services: 3 | funcx-web-service: 4 | build: funcx-web-service/ 5 | restart: always 6 | ports: 7 | - "8080:80" 8 | environment: 9 | - FLASK_APP=funcx-web-service 10 | - FLASK_DEBUG=1 11 | env_file: funcx-web-service/funcx-web-service.env 12 | volumes: 13 | - ./funcx-web-service:/opt/funcx-web-service 14 | secrets: 15 | - globus_client 16 | - globus_key 17 | - web_cert 18 | - web_key 19 | networks: 20 | default: 21 | aliases: 22 | - funcx.org 23 | cap_add: 24 | - NET_ADMIN 25 | - NET_RAW 26 | forwarder: 27 | build: funcx-web-service/forwarder 28 | ports: 29 | - "8081:8080" 30 | volumes: 31 | - ./funcx-web-service/forwarder:/opt/forwarder 32 | - funcx_install:/funcx 33 | depends_on: 34 | - funcx-web-service 35 | - mockredis 36 | serializer: 37 | build: funcx-web-service/serializer 38 | ports: 39 | - "8082:8080" 40 | environment: 41 | - FLASK_ENV=development 42 | volumes: 43 | - ./funcx-web-service/serializer/serializer:/opt/serializer 44 | - funcx_install:/funcx 45 | mockrds: 46 | build: funcx-web-service/dockers/rds 47 | ports: 48 | - "5432:5432" 49 | mockredis: 50 | build: funcx-web-service/dockers/redis 51 | ports: 52 | - "6379:6379" 53 | sysctls: 54 | net.core.somaxconn: 1024 55 | endpoints: 56 | build: funcx-web-service/dockers/endpoints 57 | ports: 58 | - "8888:8888" 59 | volumes: 60 | - ./funcx-web-service/dockers/endpoints:/data 61 | - 
funcx_install:/funcx 62 | secrets: 63 | - web_cert 64 | - funcx_sdk_tokens 65 | - funcx_config 66 | depends_on: 67 | - funcx-web-service 68 | - forwarder 69 | secrets: 70 | globus_client: 71 | file: funcx-web-service/dockers/secrets/globus_client.txt 72 | globus_key: 73 | file: funcx-web-service/dockers/secrets/globus_key.txt 74 | web_cert: 75 | file: funcx-web-service/dockers/secrets/web-cert.pem 76 | web_key: 77 | file: funcx-web-service/dockers/secrets/web-key.pem 78 | funcx_sdk_tokens: 79 | file: funcx-web-service/dockers/secrets/funcx-credentials/funcx_sdk_tokens.json 80 | funcx_config: 81 | file: funcx-web-service/dockers/secrets/funcx-config.py 82 | 83 | volumes: 84 | funcx_install: 85 | driver: local 86 | driver_opts: 87 | type: none 88 | device: $PWD/funcX 89 | o: bind 90 | -------------------------------------------------------------------------------- /funcx_web_service/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from distutils.util import strtobool 4 | 5 | from flask import Flask 6 | from flask.logging import default_handler 7 | from pythonjsonlogger import jsonlogger 8 | 9 | from funcx_web_service.container_service_adapter import ContainerServiceAdapter 10 | from funcx_web_service.error_responses import create_error_response 11 | from funcx_web_service.models import db, load_all_models 12 | from funcx_web_service.response import FuncxResponse 13 | from funcx_web_service.routes.container import container_api 14 | from funcx_web_service.routes.funcx import funcx_api 15 | 16 | 17 | def _override_config_with_environ(app): 18 | """ 19 | Use app.config as a guide to configuration settings that can be overridden from env 20 | vars. 21 | """ 22 | # Env vars will be strings. 
Convert boolean values 23 | def _convert_string(value): 24 | return value if value not in ["true", "false"] else strtobool(value) 25 | 26 | # Create a dictionary of environment vars that have keys that match keys from the 27 | # loaded config. These will override anything from the config file 28 | return { 29 | k: (lambda key, value: _convert_string(os.environ[k]))(k, v) 30 | for (k, v) in app.config.items() 31 | if k in os.environ 32 | } 33 | 34 | 35 | def create_app(test_config=None): 36 | application = Flask(__name__) 37 | 38 | level = os.environ.get("LOGLEVEL", "DEBUG").upper() 39 | logger = application.logger 40 | logger.setLevel(level) 41 | 42 | handler = logging.StreamHandler() 43 | handler.setLevel(level) 44 | formatter = jsonlogger.JsonFormatter( 45 | "%(asctime)s %(name)s %(levelname)s %(message)s" 46 | ) 47 | handler.setFormatter(formatter) 48 | logger.addHandler(handler) 49 | 50 | # This removes the default Flask handler. Since we have added a JSON 51 | # log formatter and handler above, we must disable the default handler 52 | # to prevent duplicate log messages (where one is the normal log format 53 | # and the other is JSON format). 
54 | logger.removeHandler(default_handler) 55 | 56 | application.response_class = FuncxResponse 57 | 58 | if test_config: 59 | application.config.from_mapping(test_config) 60 | else: 61 | application.config.from_envvar("APP_CONFIG_FILE") 62 | application.config.update(_override_config_with_environ(application)) 63 | 64 | if not hasattr(application, "extensions"): 65 | application.extensions = {} 66 | 67 | if application.config.get("CONTAINER_SERVICE_ENABLED", False): 68 | container_service = ContainerServiceAdapter( 69 | application.config["CONTAINER_SERVICE"] 70 | ) 71 | application.extensions["ContainerService"] = container_service 72 | else: 73 | application.extensions["ContainerService"] = None 74 | 75 | load_all_models() 76 | db.init_app(application) 77 | 78 | @application.before_first_request 79 | def create_tables(): 80 | db.create_all() 81 | 82 | @application.errorhandler(Exception) 83 | def handle_exception(e): 84 | logger.exception(e) 85 | return create_error_response(e, jsonify_response=True) 86 | 87 | # Include the API blueprint 88 | application.register_blueprint(funcx_api, url_prefix="/v2") 89 | application.register_blueprint(container_api, url_prefix="/v2") 90 | # Keeping these routes for backwards compatibility on tests. 
91 | application.register_blueprint(funcx_api, url_prefix="/v1") 92 | application.register_blueprint(funcx_api, url_prefix="/api/v1") 93 | return application 94 | -------------------------------------------------------------------------------- /funcx_web_service/application.py: -------------------------------------------------------------------------------- 1 | from flask_migrate import Migrate 2 | 3 | from funcx_web_service import create_app 4 | from funcx_web_service.models import db 5 | 6 | app = create_app() 7 | db.init_app(app) 8 | 9 | migrate = Migrate(app, db) 10 | 11 | if __name__ == "__main__": 12 | app.run("0.0.0.0", port=5000) 13 | -------------------------------------------------------------------------------- /funcx_web_service/authentication/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/funcx-faas/funcx-web-service/5714d5ab889396a72d81f25cf17e9523a9ecc82b/funcx_web_service/authentication/__init__.py -------------------------------------------------------------------------------- /funcx_web_service/authentication/auth.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from functools import wraps 3 | 4 | from flask import current_app, make_response 5 | from funcx_common.response_errors import ( 6 | EndpointNotFound, 7 | FunctionNotFound, 8 | FunctionNotPermitted, 9 | ) 10 | from globus_nexus_client import NexusClient 11 | from globus_sdk import AccessTokenAuthorizer 12 | from globus_sdk.base import BaseClient 13 | 14 | from funcx_web_service.models.auth_groups import AuthGroup 15 | from funcx_web_service.models.endpoint import Endpoint 16 | from funcx_web_service.models.function import Function, FunctionAuthGroup 17 | 18 | from .auth_state import get_auth_state 19 | from .globus_auth import get_auth_client 20 | 21 | # Default scope if not provided in config 22 | FUNCX_SCOPE = 
"https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all" 23 | 24 | 25 | def authenticated(f): 26 | """Decorator for globus auth.""" 27 | 28 | @wraps(f) 29 | def decorated_function(*args, **kwargs): 30 | auth_state = get_auth_state() 31 | auth_state.assert_is_authenticated() 32 | auth_state.assert_has_default_scope() 33 | 34 | # TODO: review, should this be getting logged here? 35 | # it's the raw introspect response and could be logged by the 36 | # AuthenticationState object if that's desirable 37 | introspect_detail = getattr( 38 | auth_state.introspect_data, "data", auth_state.introspect_data 39 | ) 40 | current_app.logger.debug( 41 | "auth_detail", 42 | extra={"log_type": "auth_detail", "auth_detail": introspect_detail}, 43 | ) 44 | 45 | response = make_response(f(auth_state.user_object, *args, **kwargs)) 46 | response._log_data.set_user(auth_state.user_object) 47 | return response 48 | 49 | return decorated_function 50 | 51 | 52 | def authenticated_w_uuid(f): 53 | """Decorator for globus auth.""" 54 | 55 | @wraps(f) 56 | def decorated_function(*args, **kwargs): 57 | auth_state = get_auth_state() 58 | auth_state.assert_is_authenticated() 59 | auth_state.assert_has_default_scope() 60 | 61 | # TODO: review, as above 62 | introspect_detail = getattr( 63 | auth_state.introspect_data, "data", auth_state.introspect_data 64 | ) 65 | current_app.logger.debug( 66 | "auth_detail", 67 | extra={"log_type": "auth_detail", "auth_detail": introspect_detail}, 68 | ) 69 | 70 | response = make_response( 71 | f(auth_state.user_object, auth_state.identity_id, *args, **kwargs) 72 | ) 73 | response._log_data.set_user(auth_state.user_object) 74 | return response 75 | 76 | return decorated_function 77 | 78 | 79 | def check_group_membership(token, endpoint_groups): 80 | """Determine whether or not the user is a member 81 | of any of the groups 82 | 83 | Parameters 84 | ---------- 85 | token : str 86 | The user's nexus token 87 | endpoint_groups : list 88 | A list of 
the group ids associated with the endpoint 89 | 90 | Returns 91 | ------- 92 | bool 93 | Whether or not the user is a member of any of the groups 94 | """ 95 | client = get_auth_client() 96 | dep_tokens = client.oauth2_get_dependent_tokens(token) 97 | 98 | if "groups.api.globus.org" in dep_tokens.by_resource_server: 99 | current_app.logger.debug("Using groups v2 api.") 100 | token = dep_tokens.by_resource_server["groups.api.globus.org"]["access_token"] 101 | user_group_ids = _get_group_ids_groups_api(token) 102 | else: 103 | current_app.logger.debug("Using legacy nexus api.") 104 | token = dep_tokens.by_resource_server["nexus.api.globus.org"]["access_token"] 105 | user_group_ids = _get_group_ids_nexus_api(token) 106 | 107 | # Check if any of the user's groups match 108 | if user_group_ids & set(endpoint_groups): 109 | return True 110 | return False 111 | 112 | 113 | def _get_group_ids_groups_api(token): 114 | # Create a nexus client to retrieve the user's groups 115 | groups_client = BaseClient( 116 | "groups", 117 | base_url="https://groups.api.globus.org", 118 | base_path="/v2/groups/", 119 | authorizer=AccessTokenAuthorizer(token), 120 | ) 121 | user_groups = groups_client.get("my_groups").data 122 | user_group_ids = {_["id"] for _ in user_groups} 123 | return user_group_ids 124 | 125 | 126 | def _get_group_ids_nexus_api(token): 127 | # Create a nexus client to retrieve the user's groups 128 | nexus_client = NexusClient() 129 | nexus_client.authorizer = AccessTokenAuthorizer(token) 130 | user_groups = nexus_client.list_groups( 131 | my_statuses="active", fields="id", for_all_identities=True 132 | ) 133 | user_group_ids = {_["id"] for _ in user_groups} 134 | return user_group_ids 135 | 136 | 137 | @functools.lru_cache() 138 | def authorize_endpoint(user_id, endpoint_uuid, function_uuid, token): 139 | """Determine whether or not the user is allowed to access this endpoint. 140 | This is done in two steps: first, check if the user owns the endpoint. 
If not, 141 | check if there are any groups associated with the endpoint and determine if the user 142 | is a member of any of them. 143 | 144 | Raises an Exception if the endpoint does not exist, or if the endpoint is 145 | restricted and the provided function is not whitelisted. 146 | 147 | Parameters 148 | ---------- 149 | user_id : str 150 | The primary identity of the user 151 | endpoint_uuid : str 152 | The uuid of the endpoint 153 | function_uuid : str 154 | The uuid of the function 155 | token : str 156 | The auth token 157 | 158 | Returns 159 | ------- 160 | bool 161 | Whether or not the user is allowed access to the endpoint 162 | """ 163 | 164 | authorized = False 165 | endpoint = Endpoint.find_by_uuid(endpoint_uuid) 166 | authorized = False 167 | 168 | if not endpoint: 169 | raise EndpointNotFound(endpoint_uuid) 170 | 171 | if endpoint.restricted: 172 | current_app.logger.debug("Restricted endpoint, checking function is allowed.") 173 | whitelisted_functions = [f.function_uuid for f in endpoint.restricted_functions] 174 | 175 | if function_uuid not in whitelisted_functions: 176 | raise FunctionNotPermitted(function_uuid, endpoint_uuid) 177 | 178 | if endpoint.public: 179 | authorized = True 180 | elif endpoint.user_id == user_id: 181 | authorized = True 182 | 183 | if not authorized: 184 | # Check if there are any groups associated with this endpoint 185 | groups = AuthGroup.find_by_endpoint_uuid(endpoint_uuid) 186 | endpoint_groups = [g.group_id for g in groups] 187 | if len(endpoint_groups) > 0: 188 | authorized = check_group_membership(token, endpoint_groups) 189 | 190 | return authorized 191 | 192 | 193 | @functools.lru_cache() 194 | def authorize_function(user_id, function_uuid, token): 195 | """Determine whether or not the user is allowed to access this function. 196 | This is done in two steps: first, check if the user owns the function. 
If not, 197 | check if there are any groups associated with the function and determine if the user 198 | is a member of any of them. 199 | 200 | Raises an Exception if the function does not exist. 201 | 202 | Parameters 203 | ---------- 204 | user_id : str 205 | The primary identity of the user 206 | function_uuid : str 207 | The uuid of the function 208 | token : str 209 | The auth token 210 | 211 | Returns 212 | ------- 213 | bool 214 | Whether or not the user is allowed access to the function 215 | """ 216 | 217 | authorized = False 218 | function = Function.find_by_uuid(function_uuid) 219 | 220 | if not function: 221 | raise FunctionNotFound(function_uuid) 222 | 223 | if function.user_id == user_id: 224 | authorized = True 225 | elif function.public: 226 | authorized = True 227 | 228 | if not authorized: 229 | # Check if there are any groups associated with this function 230 | groups = FunctionAuthGroup.find_by_function_id(function.id) 231 | function_groups = [g.group_id for g in groups] 232 | 233 | if len(function_groups) > 0: 234 | authorized = check_group_membership(token, function_groups) 235 | 236 | return authorized 237 | -------------------------------------------------------------------------------- /funcx_web_service/authentication/auth_state.py: -------------------------------------------------------------------------------- 1 | import typing as t 2 | 3 | import globus_sdk 4 | from flask import abort, current_app, g, request 5 | 6 | from funcx_web_service.models.user import User 7 | 8 | from .globus_auth import introspect_token 9 | 10 | 11 | class AuthenticationState: 12 | """ 13 | This is a dedicated object for handling authentication. 14 | 15 | It takes in a Globus Auth token and resolve it to a user and various data about that 16 | user. It is the "auth_state" object for the application within the context of a 17 | request, showing "who" is calling the application (e.g. identity_id) and some 18 | information about "how" the call is being made (e.g. 
scopes). 19 | 20 | For the most part, this should not handle authorization checks, to maintain 21 | separation of concerns. 22 | """ 23 | 24 | # Default scope if not provided in config 25 | DEFAULT_FUNCX_SCOPE = ( 26 | "https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all" 27 | ) 28 | 29 | def __init__( 30 | self, token: t.Optional[str], *, assert_default_scope: bool = True 31 | ) -> None: 32 | self.funcx_all_scope: str = current_app.config.get( 33 | "FUNCX_SCOPE", self.DEFAULT_FUNCX_SCOPE 34 | ) 35 | self.token = token 36 | 37 | self.introspect_data: t.Optional[globus_sdk.GlobusHTTPResponse] = None 38 | self.identity_id: t.Optional[str] = None 39 | self.username: t.Optional[str] = None 40 | self._user_object: t.Optional[User] = None 41 | self.scopes: t.Set[str] = set() 42 | 43 | if token: 44 | self._handle_token() 45 | 46 | def _handle_token(self) -> None: 47 | """Given a token, flesh out the AuthenticationState.""" 48 | self.introspect_data = introspect_token(t.cast(str, self.token)) 49 | self.username = self.introspect_data["username"] 50 | self.identity_id = self.introspect_data["sub"] 51 | self.scopes = set(self.introspect_data["scope"].split(" ")) 52 | 53 | @property 54 | def user_object(self) -> User: 55 | if self._user_object is None: 56 | self._user_object = User.resolve_user(self.username) 57 | return self._user_object 58 | 59 | @property 60 | def is_authenticated(self): 61 | return self.identity_id is not None 62 | 63 | def assert_is_authenticated(self): 64 | """ 65 | This tests that is_authenticated=True, and raises an Unauthorized error 66 | (401) if it is not. 
67 | """ 68 | if not self.is_authenticated: 69 | abort(401, "method requires token authenticated access") 70 | 71 | def assert_has_scope(self, scope: str) -> None: 72 | # the user may be thought of as "not properly authorized" if they do not have 73 | # required scopes 74 | # this leaves a question of whether this should be a 401 (Unauthorized) or 75 | # 403 (Forbidden) 76 | # 77 | # the answer is that it must be a 403 78 | # this requirement is set by the OAuth2 spec 79 | # 80 | # for more detail, see 81 | # https://datatracker.ietf.org/doc/html/rfc6750#section-3.1 82 | if scope not in self.scopes: 83 | abort(403, "Missing Scopes") 84 | 85 | def assert_has_default_scope(self) -> None: 86 | self.assert_has_scope(self.funcx_all_scope) 87 | 88 | 89 | def get_auth_state(): 90 | """ 91 | Get the current AuthenticationState. This may be called at any time in the 92 | application within a request context, but will always return the same state object. 93 | 94 | It is especially useful for tests, which can mock over the return value of this 95 | function by setting `g.auth_state` and be assured that the application will respect 96 | the fake authentication info. 
97 | """ 98 | if not g.get("auth_state", None): 99 | cred = request.headers.get("Authorization", None) 100 | if cred and cred.startswith("Bearer "): 101 | cred = cred[7:] 102 | else: 103 | cred = None 104 | 105 | g.auth_state = AuthenticationState(cred) 106 | return g.auth_state 107 | -------------------------------------------------------------------------------- /funcx_web_service/authentication/globus_auth.py: -------------------------------------------------------------------------------- 1 | import globus_sdk 2 | from flask import abort, current_app 3 | 4 | 5 | def introspect_token( 6 | token: str, *, verify: bool = True 7 | ) -> globus_sdk.GlobusHTTPResponse: 8 | client = get_auth_client() 9 | data = client.oauth2_token_introspect(token) 10 | if verify: 11 | if not data.get("active", False): 12 | abort(401, "Credentials are inactive.") 13 | return data 14 | 15 | 16 | # FIXME: 17 | # this should be only creating a single client per web worker, not a new one per call 18 | def get_auth_client(): 19 | """Create an AuthClient for the service.""" 20 | return globus_sdk.ConfidentialAppAuthClient( 21 | current_app.config["GLOBUS_CLIENT"], current_app.config["GLOBUS_KEY"] 22 | ) 23 | -------------------------------------------------------------------------------- /funcx_web_service/container_service_adapter.py: -------------------------------------------------------------------------------- 1 | from urllib.parse import urljoin 2 | 3 | import requests 4 | 5 | 6 | class ContainerServiceAdapter: 7 | def __init__(self, service_url): 8 | self.service_url = service_url 9 | 10 | def get_version(self): 11 | result = requests.get(urljoin(self.service_url, "version")) 12 | if result.status_code == 200: 13 | return result.json() 14 | else: 15 | return {"version": "Service Unavailable"} 16 | -------------------------------------------------------------------------------- /funcx_web_service/error_responses.py: 
-------------------------------------------------------------------------------- 1 | from flask import jsonify 2 | from funcx_common.response_errors import FuncxResponseError 3 | from werkzeug.exceptions import HTTPException 4 | 5 | 6 | def create_error_response(exception, jsonify_response=False): 7 | """Creates JSON object responses for errors that occur in the service. 8 | These responses can be sent back to the funcx SDK client to be decoded. 9 | They also have a "reason" property so that they are human-readable. 10 | Note that the returned JSON object will be a dict unless jsonify_response 11 | is enabled, in which case it will return JSON that Flask can respond with. 12 | This helper will not raise the exception passed in, it will only turn 13 | it into a JSON object. 14 | 15 | Parameters 16 | ========== 17 | 18 | exception : Exception 19 | Exception to convert to a JSON object 20 | jsonify_response : bool 21 | Whether or not to call 'jsonify' on the resulting dict 22 | 23 | Returns: 24 | JSON object, HTTP status code 25 | """ 26 | if isinstance(exception, FuncxResponseError): 27 | # the pack method turns a FuncxResponseError into a record 28 | # which will become json 29 | response = exception.pack() 30 | status_code = int(exception.http_status_code) 31 | else: 32 | status_code = None 33 | # if there is an HTTPException (e.g. 
due to calling of Flask abort()) 34 | # we can grab the status code from the exception 35 | if isinstance(exception, HTTPException): 36 | status_code = exception.code 37 | reason = str(exception) 38 | else: 39 | reason = f"An unknown error occurred: {exception}" 40 | 41 | if status_code is None: 42 | status_code = 500 43 | # if the error is not recognized as a FuncxResponseError, a generic 44 | # response of the same format will be sent back, indicating an 45 | # internal server error 46 | response = { 47 | "status": "Failed", 48 | "code": 0, 49 | "error_args": [], 50 | "reason": reason, 51 | "http_status_code": status_code, 52 | } 53 | 54 | if jsonify_response: 55 | response = jsonify(response) 56 | 57 | return response, status_code 58 | -------------------------------------------------------------------------------- /funcx_web_service/models/__init__.py: -------------------------------------------------------------------------------- 1 | from flask_sqlalchemy import SQLAlchemy 2 | 3 | db: SQLAlchemy = SQLAlchemy() 4 | 5 | 6 | def load_all_models(): 7 | # deferred import of the necessary model code 8 | from . 
import auth_groups, container, function, user  # noqa: F401


# funcx_web_service/models/auth_groups.py
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm.exc import NoResultFound  # noqa: F401  (kept; see note below)

from funcx_web_service.models import db


class AuthGroup(db.Model):
    """Maps a Globus Auth group to an endpoint it is allowed to use."""

    __tablename__ = "auth_groups"
    id = Column(Integer, primary_key=True)
    group_id = Column(String(67))
    endpoint_id = Column(String(67))

    @classmethod
    def find_by_uuid(cls, uuid):
        # Query.first() returns None when there is no match; it never raises
        # NoResultFound, so the old try/except here was dead code.
        return cls.query.filter_by(group_id=uuid).first()

    @classmethod
    def find_by_endpoint_uuid(cls, endpoint_uuid):
        return cls.query.filter_by(endpoint_id=endpoint_uuid).all()


# funcx_web_service/models/container.py
from datetime import datetime

from sqlalchemy import DateTime, ForeignKey, Integer
from sqlalchemy.orm import relationship

from funcx_web_service.models import db


class Container(db.Model):
    """A user-registered container, with zero or more images."""

    __tablename__ = "containers"
    id = db.Column(db.Integer, primary_key=True)
    author = db.Column(Integer, ForeignKey("users.id"))
    container_uuid = db.Column(db.String(67))
    name = db.Column(db.String(1024))
    description = db.Column(db.Text)
    created_at = db.Column(DateTime, default=datetime.utcnow)
    modified_at = db.Column(DateTime, default=datetime.utcnow)

    images = relationship("ContainerImage")
    functions = relationship("FunctionContainer")
    user = relationship("User", back_populates="containers")

    def save_to_db(self):
        db.session.add(self)
        db.session.commit()

    @classmethod
    def find_by_uuid(cls, uuid):
        # first() -> None on no match; NoResultFound was never raised here
        return cls.query.filter_by(container_uuid=uuid).first()

    @classmethod
    def find_by_uuid_and_type(cls, uuid, type):
        # NOTE: `type` shadows the builtin, but is part of the public signature
        return (
            cls.query.filter_by(container_uuid=uuid)
            .join(ContainerImage)
            .filter_by(type=type)
            .first()
        )

    def to_json(self):
        """Serialize for API responses.

        Image type/location are only included when unambiguous (exactly one
        image is attached).
        """
        result = {"container_uuid": self.container_uuid, "name": self.name}
        if self.images and len(self.images) == 1:
            result["type"] = self.images[0].type
            result["location"] = self.images[0].location
        return result


class ContainerImage(db.Model):
    """A concrete image (type + location) belonging to a Container."""

    __tablename__ = "container_images"
    id = db.Column(db.Integer, primary_key=True)
    container_id = db.Column(Integer, ForeignKey("containers.id"))
    type = db.Column(db.String(256))
    location = db.Column(db.String(1024))
    created_at = db.Column(DateTime, default=datetime.utcnow)
    modified_at = db.Column(DateTime, default=datetime.utcnow)


# funcx_web_service/models/endpoint.py
from datetime import datetime

from sqlalchemy import Boolean, DateTime, Float, ForeignKey, Integer, String, and_
from sqlalchemy.orm import relationship

from funcx_web_service.models import db
from funcx_web_service.models.user import User

# Association table: functions whitelisted on a restricted endpoint.
restricted_endpoint_table = db.Table(
    "restricted_endpoint_functions",
    db.Column("id", Integer, primary_key=True),
    db.Column("endpoint_id", Integer, ForeignKey("sites.id")),
    db.Column("function_id", Integer, ForeignKey("functions.id")),
)


class Endpoint(db.Model):
    """A registered funcX endpoint (stored in the legacy "sites" table)."""

    __tablename__ = "sites"
    __table_args__ = (
        db.UniqueConstraint("endpoint_uuid", name="unique_endpoint_uuid"),
    )

    id = db.Column(Integer, primary_key=True)
    name = db.Column(String(256))
    description = db.Column(String(256))
    user_id = db.Column(Integer, ForeignKey("users.id"))
    status = db.Column(String(10))
    endpoint_name = db.Column(String(256))
    endpoint_uuid = db.Column(String(38))
    public = db.Column(Boolean, default=False)
    deleted = db.Column(Boolean, default=False)
    ip_addr = db.Column(String(15))
    city = db.Column(String(256))
    region = db.Column(String(256))
    country = db.Column(String(256))
    zipcode = db.Column(String(10))
    latitude = db.Column(Float)
    longitude = db.Column(Float)
    core_hours = db.Column(Float, default=0)
    hostname = db.Column(String(256))
    org = db.Column(String(256))
    restricted = db.Column(Boolean, default=False)
    created_at = db.Column(DateTime, default=datetime.utcnow)

    user = relationship("User", back_populates="endpoints")
    tasks = relationship("DBTask")

    restricted_functions = db.relationship(
        "Function",
        secondary=restricted_endpoint_table,
        back_populates="restricted_endpoints",
    )

    def save_to_db(self):
        db.session.add(self)
        db.session.commit()

    def delete_whitelist_for_function(self, function):
        """Remove `function` from this endpoint's whitelist.

        Fix: the raw connection from db.engine.connect() was never closed;
        the context manager returns it to the pool even if execute() raises.
        """
        s = restricted_endpoint_table.delete().where(
            and_(
                restricted_endpoint_table.c.endpoint_id == self.id,
                restricted_endpoint_table.c.function_id == function.id,
            )
        )
        with db.engine.connect() as conn:
            conn.execute(s)

    @classmethod
    def find_by_uuid(cls, uuid):
        # first() -> None on no match; NoResultFound was never raised here
        return cls.query.filter_by(endpoint_uuid=uuid).first()

    @classmethod
    def delete_endpoint(cls, user: User, endpoint_uuid):
        """Soft-delete an endpoint owned by `user`.

        Parameters
        ----------
        user : User
            The primary identity of the user
        endpoint_uuid : str
            The uuid of the endpoint

        Returns
        -------
        int
            HTTP-style status code (the original docstring wrongly said str):
            302 success (redirect), 403 endpoint not owned by user,
            404 non-existent or previously-deleted endpoint, 500 on error
        """
        try:
            existing_endpoint = Endpoint.find_by_uuid(endpoint_uuid)
            if not existing_endpoint:
                return 404  # Endpoint not found
            if existing_endpoint.deleted:
                return 404  # Endpoint is already deleted
            if existing_endpoint.user_id != user.id:
                return 403  # Endpoint doesn't belong to user
            existing_endpoint.deleted = True
            existing_endpoint.save_to_db()
            return 302
        except Exception as e:
            # TODO(review): replace print with real logging; this module
            # currently imports no logger
            print(e)
            return 500


# funcx_web_service/models/function.py
from datetime import datetime

from sqlalchemy import DateTime, ForeignKey, Integer, String, Text
from sqlalchemy.orm import relationship

from funcx_web_service.models import db
from funcx_web_service.models.endpoint import restricted_endpoint_table


class Function(db.Model):
    """A registered function, including its serialized source code."""

    __tablename__ = "functions"
    __table_args__ = (
        db.UniqueConstraint("function_uuid", name="unique_function_uuid"),
    )

    id = db.Column(Integer, nullable=False, primary_key=True, autoincrement=True)
    user_id = db.Column(Integer, ForeignKey("users.id"))
    name = db.Column(String(1024))
    description = db.Column(Text)
    status = db.Column(String(1024))
    function_name = db.Column(String(1024))
    function_uuid = db.Column(String(38))
    function_source_code = db.Column(Text)
    timestamp = db.Column(DateTime, default=datetime.utcnow)
    entry_point = db.Column(String(38))
    modified_at = db.Column(DateTime, default=datetime.utcnow)
    deleted = db.Column(db.Boolean, default=False)
    public = db.Column(db.Boolean, default=False)

    container = relationship(
        "FunctionContainer", uselist=False, back_populates="function"
    )
    auth_groups = relationship("FunctionAuthGroup")

    tasks = relationship("DBTask")

    restricted_endpoints = relationship(
        "Endpoint",
        secondary=restricted_endpoint_table,
        back_populates="restricted_functions",
    )

    user = relationship("User", back_populates="functions")

    def save_to_db(self):
        db.session.add(self)
        db.session.commit()

    @classmethod
    def find_by_uuid(cls, uuid):
        # first() -> None on no match; NoResultFound was never raised here
        return cls.query.filter_by(function_uuid=uuid).first()


class FunctionContainer(db.Model):
    """Join table linking a function to its container."""

    __tablename__ = "function_containers"
    id = db.Column(Integer, primary_key=True)
    container_id = db.Column(Integer, ForeignKey("containers.id"))
    function_id = db.Column(Integer, ForeignKey("functions.id"))
    created_at = db.Column(DateTime, default=datetime.utcnow)
    modified_at = db.Column(DateTime, default=datetime.utcnow)

    function = relationship("Function", back_populates="container")
    container = relationship("Container", back_populates="functions")


class FunctionAuthGroup(db.Model):
    """Join table linking a function to a Globus Auth group."""

    __tablename__ = "function_auth_groups"
    id = db.Column(Integer, primary_key=True)
    group_id = db.Column(String(38))
    function_id = db.Column(Integer, ForeignKey("functions.id"))
    function = relationship("Function", back_populates="auth_groups")

    @classmethod
    def find_by_function_id(cls, function_id):
        return cls.query.filter_by(function_id=function_id).all()
-------------------------------------------------------------------------------- /funcx_web_service/models/search.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | from flask import current_app as app 4 | from globus_sdk import AccessTokenAuthorizer, SearchAPIError, SearchClient 5 | 6 | import funcx_web_service.authentication.auth 7 | 8 | FUNCTION_SEARCH_INDEX_NAME = "funcx" 9 | FUNCTION_SEARCH_INDEX_ID = "673a4b58-3231-421d-9473-9df1b6fa3a9d" 10 | ENDPOINT_SEARCH_INDEX_NAME = "funcx_endpoints" 11 | ENDPOINT_SEARCH_INDEX_ID = "85bcc497-3ee9-4d73-afbb-2abf292e398b" 12 | SEARCH_SCOPE = "urn:globus:auth:scope:search.api.globus.org:all" 13 | 14 | # Search limit defined by the globus API 15 | SEARCH_LIMIT = 10000 16 | 17 | # By default we will return 10 functions at a time 18 | DEFAULT_SEARCH_LIMIT = 10 19 | 20 | 21 | def _sanitize_tokens(token_data: Dict[str, Any]) -> Dict[str, Any]: 22 | access_token = token_data.get("access_token", "") 23 | if access_token is not None: 24 | access_token = f"***{token_data['access_token'][-5:]}" 25 | 26 | refresh_token = token_data.get("refresh_token", "") 27 | if refresh_token is not None: 28 | refresh_token = f"***{token_data['refresh_token'][:-5:]}" 29 | 30 | return { 31 | "scope": token_data.get("scope"), 32 | "access_token": access_token, 33 | "refresh_token": refresh_token, 34 | } 35 | 36 | 37 | def get_search_client(): 38 | """Creates a Globus Search Client using FuncX's client token""" 39 | auth_client = funcx_web_service.authentication.auth.get_auth_client() 40 | tokens = auth_client.oauth2_client_credentials_tokens( 41 | requested_scopes=[SEARCH_SCOPE] 42 | ) 43 | search_token = tokens.by_scopes[SEARCH_SCOPE] 44 | app.logger.debug(f"Search token: {_sanitize_tokens(search_token)}") 45 | access_token = search_token["access_token"] 46 | authorizer = AccessTokenAuthorizer(access_token) 47 | app.logger.debug("Acquired AccessTokenAuthorizer for search") 48 
| search_client = SearchClient(authorizer) 49 | app.logger.debug("Acquired SearchClient with that authorizer") 50 | return search_client 51 | 52 | 53 | def _trim_func_data(func_data): 54 | """Remove unnecessary fields from FuncX function metadata for ingest 55 | 56 | Parameters 57 | ---------- 58 | func_data : dict 59 | the data put into redis for a function 60 | 61 | Returns 62 | ------- 63 | dict 64 | a dict with the fields we want in search, notably including an author field 65 | """ 66 | return { 67 | "function_name": func_data["function_name"], 68 | "function_code": func_data["function_code"], 69 | "function_source": func_data["function_source"], 70 | "container_uuid": func_data.get("container_uuid", ""), 71 | "description": func_data["description"], 72 | "public": func_data["public"], 73 | "group": func_data["group"], 74 | "author": "", 75 | } 76 | 77 | 78 | def _exists(client, index, func_uuid): 79 | """Checks if a func_uuid exists in the search index 80 | 81 | Mainly used to determine whether we need a create or an update call to the 82 | search API 83 | 84 | Parameters 85 | ---------- 86 | func_uuid : str 87 | the uuid of the function 88 | 89 | Returns 90 | ------- 91 | bool 92 | True if `func_uuid` is a subject in Globus Search index 93 | """ 94 | try: 95 | res = client.get_entry(index, func_uuid) 96 | return len(res.data["entries"]) > 0 97 | except SearchAPIError as err: 98 | if err.http_status == 404: 99 | return False 100 | raise err 101 | 102 | 103 | def func_ingest_or_update(func_uuid, func_data, author="", author_urn=""): 104 | """Update or create a function in search index 105 | 106 | Parameters 107 | ---------- 108 | func_uuid : str 109 | func_data : dict 110 | author : str 111 | author_urn : str 112 | """ 113 | client = get_search_client() 114 | app.logger.debug("Acquired search client") 115 | acl = [] 116 | if func_data["public"]: 117 | acl.append("public") 118 | elif func_data["group"]: 119 | group: str = func_data["group"] 120 | if not 
group.startswith("urn"): 121 | group = f"urn:globus:groups:id:{group}" 122 | acl.append(group) 123 | 124 | # Ensure that the author of the function and the funcx search admin group have 125 | # access 126 | # TODO: do we want access to everything? 127 | # Is this the default since we control the index? 128 | acl.append(author_urn) 129 | acl.append("urn:globus:groups:id:69e12e30-b499-11ea-91c1-0a0ee5aecb35") 130 | 131 | content = _trim_func_data(func_data) 132 | content["author"] = author 133 | content["version"] = "0" 134 | 135 | ingest_data = {"subject": func_uuid, "visible_to": acl, "content": content} 136 | 137 | # Since we associate only 1 entry with each subject (func_uuid), there is basically 138 | # no difference between creating and updating, other than the method... 139 | app.logger.debug( 140 | f"Ingesting the following data to {FUNCTION_SEARCH_INDEX_ID}: {ingest_data}" 141 | ) 142 | 143 | if not _exists(client, FUNCTION_SEARCH_INDEX_ID, func_uuid): 144 | client.create_entry(FUNCTION_SEARCH_INDEX_ID, ingest_data) 145 | else: 146 | client.update_entry(FUNCTION_SEARCH_INDEX_ID, ingest_data) 147 | 148 | 149 | def endpoint_ingest_or_update(ep_uuid, data, owner="", owner_urn=""): 150 | """ 151 | 152 | Parameters 153 | ---------- 154 | ep_uuid 155 | data 156 | owner 157 | owner_urn 158 | 159 | Returns 160 | ------- 161 | 162 | """ 163 | client = get_search_client() 164 | acl = [] 165 | if data["public"]: 166 | acl.append("public") 167 | 168 | acl.extend(data["visible_to"]) 169 | del data["visible_to"] 170 | 171 | # Ensure that the author of the function and the funcx search admin group have 172 | # access 173 | # TODO: do we want access to everything? 174 | # Is this the default since we control the index? 
175 | acl.append(owner_urn) 176 | acl.append("urn:globus:groups:id:69e12e30-b499-11ea-91c1-0a0ee5aecb35") 177 | 178 | content = data.copy() 179 | content["owner"] = owner_urn 180 | 181 | ingest_data = {"subject": ep_uuid, "visible_to": acl, "content": content} 182 | 183 | app.logger.debug( 184 | f"Ingesting endpoint data to {ENDPOINT_SEARCH_INDEX_NAME}: {ingest_data}" 185 | ) 186 | # TODO: security (if exists, dont allow updates if not owner) 187 | if not _exists(client, ENDPOINT_SEARCH_INDEX_ID, ep_uuid): 188 | res = client.create_entry(ENDPOINT_SEARCH_INDEX_ID, ingest_data) 189 | else: 190 | res = client.update_entry(ENDPOINT_SEARCH_INDEX_ID, ingest_data) 191 | 192 | app.logger.debug(f"received response from Search API: {res.text}") 193 | -------------------------------------------------------------------------------- /funcx_web_service/models/serializer.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from flask import current_app as app 3 | 4 | 5 | def serialize_inputs(input_data): 6 | """Use the serialization service to encode input data. 7 | 8 | Parameters 9 | ---------- 10 | input_data : str 11 | The input data to pass to ther serializer 12 | 13 | Returns 14 | ------- 15 | str : The encoded data 16 | """ 17 | ser_addr = app.config["SERIALIZATION_ADDR"] 18 | ser_port = app.config["SERIALIZATION_PORT"] 19 | 20 | res = requests.post(f"http://{ser_addr}:{ser_port}/serialize", json=input_data) 21 | if res.status_code == 200: 22 | return res.json() 23 | 24 | return None 25 | 26 | 27 | def deserialize_result(result): 28 | """Use the serialization service to decode result. 
29 | 30 | Parameters 31 | ---------- 32 | result : str 33 | The data to pass to the deserializer 34 | 35 | Returns 36 | ------- 37 | str : The decoded data 38 | """ 39 | ser_addr = app.config["SERIALIZATION_ADDR"] 40 | ser_port = app.config["SERIALIZATION_PORT"] 41 | 42 | res = requests.post(f"http://{ser_addr}:{ser_port}/deserialize", json=result) 43 | if res.status_code == 200: 44 | return res.json() 45 | 46 | return None 47 | -------------------------------------------------------------------------------- /funcx_web_service/models/tasks.py: -------------------------------------------------------------------------------- 1 | import typing as t 2 | from datetime import datetime, timedelta 3 | from enum import Enum 4 | 5 | from funcx_common.redis import ( 6 | INT_SERDE, 7 | JSON_SERDE, 8 | FuncxRedisEnumSerde, 9 | HasRedisFieldsMeta, 10 | RedisField, 11 | ) 12 | from funcx_common.tasks import TaskProtocol, TaskState 13 | from redis import Redis 14 | from sqlalchemy import DateTime, ForeignKey, Integer, String 15 | from sqlalchemy.orm import relationship 16 | 17 | from funcx_web_service.models import db 18 | 19 | 20 | # This internal state is never shown to the user and is meant to track whether 21 | # or not the forwarder has succeeded in fully processing the task 22 | class InternalTaskState(str, Enum): 23 | INCOMPLETE = "incomplete" 24 | COMPLETE = "complete" 25 | 26 | 27 | class DBTask(db.Model): 28 | __tablename__ = "tasks" 29 | id = db.Column(Integer, primary_key=True) 30 | user_id = db.Column(Integer, ForeignKey("users.id")) 31 | task_uuid = db.Column(String(38)) 32 | status = db.Column(String(10), default="UNKNOWN") 33 | created_at = db.Column(DateTime, default=datetime.utcnow) 34 | modified_at = db.Column(DateTime, default=datetime.utcnow) 35 | endpoint_id = db.Column(String(38), ForeignKey("sites.endpoint_uuid")) 36 | function_id = db.Column(String(38), ForeignKey("functions.function_uuid")) 37 | 38 | function = relationship("Function", 
back_populates="tasks") 39 | endpoint = relationship("Endpoint", back_populates="tasks") 40 | 41 | user = relationship("User", back_populates="tasks") 42 | 43 | def save_to_db(self): 44 | db.session.add(self) 45 | db.session.commit() 46 | 47 | 48 | class RedisTask(TaskProtocol, metaclass=HasRedisFieldsMeta): 49 | """ 50 | ORM-esque class to wrap access to properties of tasks for better style and 51 | encapsulation 52 | """ 53 | 54 | status = t.cast(TaskState, RedisField(serde=FuncxRedisEnumSerde(TaskState))) 55 | internal_status = t.cast( 56 | InternalTaskState, RedisField(serde=FuncxRedisEnumSerde(InternalTaskState)) 57 | ) 58 | user_id = RedisField(serde=INT_SERDE) 59 | function_id = RedisField() 60 | endpoint = t.cast(str, RedisField()) 61 | container = RedisField() 62 | payload = RedisField(serde=JSON_SERDE) 63 | payload_reference = t.cast( 64 | t.Optional[t.Dict[str, t.Any]], RedisField(serde=JSON_SERDE) 65 | ) 66 | result = RedisField() 67 | result_reference = t.cast( 68 | t.Optional[t.Dict[str, t.Any]], RedisField(serde=JSON_SERDE) 69 | ) 70 | exception = RedisField() 71 | completion_time = RedisField() 72 | task_group_id = RedisField() 73 | 74 | # must keep ttl and _set_expire in merge 75 | # tasks expire in 1 week, we are giving some grace period for 76 | # long-lived clients, and we'll revise this if there are complaints 77 | TASK_TTL = timedelta(weeks=2) 78 | 79 | def __init__( 80 | self, 81 | redis_client: Redis, 82 | task_id: str, 83 | *, 84 | user_id: t.Optional[int] = None, 85 | function_id: t.Optional[str] = None, 86 | container: t.Optional[str] = None, 87 | payload: t.Any = None, 88 | task_group_id: t.Optional[str] = None, 89 | ): 90 | """ 91 | If optional values are passed, then they will be written. 92 | Otherwise, they will fetched from any existing task entry. 
93 | 94 | :param redis_client: Redis client for properties to get/set 95 | :param task_id: UUID of the task, as str 96 | :param user_id: ID of user to whom this task belongs 97 | :param function_id: UUID of the function for this task, as str 98 | :param container: UUID of container in which to run, as str 99 | :param payload: serialized function + input data 100 | :param task_group_id: UUID of task group that this task belongs to 101 | """ 102 | self.redis_client = redis_client 103 | self.task_id = task_id 104 | self.hname = f"task_{task_id}" 105 | 106 | # if required attributes are not yet set, initialize them to their defaults 107 | if self.status is None: 108 | self.status = TaskState.WAITING_FOR_EP 109 | if self.internal_status is None: 110 | self.internal_status = InternalTaskState.INCOMPLETE 111 | 112 | if user_id is not None: 113 | self.user_id = user_id 114 | if function_id is not None: 115 | self.function_id = function_id 116 | if container is not None: 117 | self.container = container 118 | if payload is not None: 119 | self.payload = payload 120 | if task_group_id is not None: 121 | self.task_group_id = task_group_id 122 | 123 | # Used to pass bits of information to EP 124 | self.header = f"{self.task_id};{self.container};None" 125 | self._set_expire() 126 | 127 | def _set_expire(self): 128 | """Expires task after TASK_TTL, if not already set.""" 129 | ttl = self.redis_client.ttl(self.hname) 130 | if ttl < 0: 131 | # expire was not already set 132 | self.redis_client.expire(self.hname, RedisTask.TASK_TTL) 133 | 134 | def delete(self): 135 | """Removes this task from Redis, to be used after the result is gotten""" 136 | self.redis_client.delete(self.hname) 137 | 138 | @classmethod 139 | def exists(cls, redis_client: Redis, task_id: str) -> bool: 140 | """Check if a given task_id exists in Redis""" 141 | return bool(redis_client.exists(f"task_{task_id}")) 142 | 143 | 144 | class TaskGroup(metaclass=HasRedisFieldsMeta): 145 | """ 146 | ORM-esque class to 
wrap access to properties of batches for better style and 147 | encapsulation 148 | """ 149 | 150 | user_id = RedisField(serde=INT_SERDE) 151 | 152 | TASK_GROUP_TTL = timedelta(weeks=1) 153 | 154 | def __init__(self, redis_client: Redis, task_group_id: str, user_id: int = None): 155 | """ 156 | If the kwargs are passed, then they will be overwritten. Otherwise, they 157 | will gotten from existing task entry. 158 | 159 | :param redis_client: Redis client so that properties can get/set 160 | """ 161 | self.redis_client = redis_client 162 | self.task_group_id = task_group_id 163 | self.hname = f"task_group_{task_group_id}" 164 | 165 | if user_id is not None: 166 | self.user_id = user_id 167 | 168 | self.header = self.task_group_id 169 | self._set_expire() 170 | 171 | def _set_expire(self): 172 | """Expires task after TASK_TTL, if not already set.""" 173 | ttl = self.redis_client.ttl(self.hname) 174 | if ttl < 0: 175 | # expire was not already set 176 | self.redis_client.expire(self.hname, TaskGroup.TASK_GROUP_TTL) 177 | 178 | def delete(self): 179 | """Removes this task group from Redis, to be used after the result is gotten""" 180 | self.redis_client.delete(self.hname) 181 | 182 | @classmethod 183 | def exists(cls, redis_client: Redis, task_group_id: str) -> bool: 184 | """Check if a given task_group_id exists in Redis""" 185 | return bool(redis_client.exists(f"task_group_{task_group_id}")) 186 | -------------------------------------------------------------------------------- /funcx_web_service/models/user.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | from sqlalchemy import Boolean, Column, DateTime, Integer, String 4 | from sqlalchemy.orm import relationship 5 | from sqlalchemy.orm.exc import NoResultFound 6 | 7 | from funcx_web_service.models import db 8 | 9 | 10 | class User(db.Model): 11 | __tablename__ = "users" 12 | id = Column(Integer, primary_key=True) 13 | username = 
Column(String(256)) 14 | globus_identity = Column(String(256)) 15 | created_at = db.Column(DateTime, default=datetime.utcnow) 16 | namespace = Column(String(1024)) 17 | deleted = Column(Boolean, default=False) 18 | 19 | functions = relationship("Function") 20 | endpoints = relationship("Endpoint") 21 | containers = relationship("Container") 22 | tasks = relationship("DBTask") 23 | 24 | def save_to_db(self): 25 | db.session.add(self) 26 | db.session.commit() 27 | 28 | @classmethod 29 | def find_by_username(cls, username): 30 | try: 31 | return cls.query.filter_by(username=username).first() 32 | except NoResultFound: 33 | return None 34 | 35 | @classmethod 36 | def resolve_user(cls, username): 37 | existing_user = cls.find_by_username(username) 38 | 39 | if existing_user: 40 | return existing_user 41 | else: 42 | new_user = User(username=username) 43 | new_user.save_to_db() 44 | return new_user 45 | -------------------------------------------------------------------------------- /funcx_web_service/models/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import uuid 3 | 4 | import redis 5 | from flask import current_app as app 6 | from funcx_common.response_errors import EndpointAlreadyRegistered, FunctionNotFound 7 | 8 | from funcx_web_service.models import search 9 | from funcx_web_service.models.endpoint import Endpoint 10 | from funcx_web_service.models.function import Function 11 | from funcx_web_service.models.tasks import DBTask 12 | from funcx_web_service.models.user import User 13 | 14 | 15 | class db_invocation_logger: 16 | def log(self, user_id, task_id, function_id, endpoint_id, deferred=False): 17 | try: 18 | status = "CREATED" 19 | task_record = DBTask( 20 | user_id=user_id, 21 | function_id=function_id, 22 | endpoint_id=endpoint_id, 23 | status=status, 24 | ) 25 | task_record.save_to_db() 26 | except Exception: 27 | app.logger.exception("Caught error while writing log update to db") 28 | 29 | 
def commit(self): 30 | pass 31 | 32 | 33 | def add_ep_whitelist(user: User, endpoint_uuid, functions): 34 | """Add a list of function to the endpoint's whitelist. 35 | 36 | This function is only allowed by the owner of the endpoint. 37 | 38 | Parameters 39 | ---------- 40 | user : User 41 | The user making the request 42 | endpoint_uuid : str 43 | The uuid of the endpoint to add the whitelist entries for 44 | functions : list 45 | A list of the function ids to add to the whitelist. 46 | 47 | Returns 48 | ------- 49 | json 50 | The result of adding the functions to the whitelist 51 | """ 52 | 53 | user_id = user.id 54 | 55 | endpoint = Endpoint.find_by_uuid(endpoint_uuid) 56 | 57 | if not endpoint: 58 | return { 59 | "status": "Failed", 60 | "reason": f"Endpoint {endpoint_uuid} is not found in database", 61 | } 62 | 63 | if endpoint.user_id != user_id: 64 | return { 65 | "status": "Failed", 66 | "reason": f"Endpoint does not belong to User {user.username}", 67 | } 68 | 69 | try: 70 | restricted_functions = [Function.find_by_uuid(f) for f in functions] 71 | endpoint.restricted_functions.extend(restricted_functions) 72 | endpoint.save_to_db() 73 | except Exception as e: 74 | print(e) 75 | return { 76 | "status": "Failed", 77 | "reason": f"Unable to add functions {functions} " 78 | f"to endpoint {endpoint_uuid}, {e}", 79 | } 80 | 81 | return { 82 | "status": "Success", 83 | "reason": f"Added functions {functions} " 84 | f"to endpoint {endpoint_uuid} whitelist.", 85 | } 86 | 87 | 88 | def get_ep_whitelist(user: User, endpoint_id): 89 | """Get the list of functions in an endpoint's whitelist. 90 | 91 | This function is only allowed by the owner of the endpoint. 
92 | 93 | Parameters 94 | ---------- 95 | user : User 96 | The name of the user making the request 97 | endpoint_id : str 98 | The uuid of the endpoint to add the whitelist entries for 99 | 100 | Returns 101 | ------- 102 | json 103 | The functions in the whitelist 104 | """ 105 | 106 | endpoint = Endpoint.find_by_uuid(endpoint_id) 107 | if not endpoint: 108 | return {"status": "Failed", "reason": f"Could not find endpoint {endpoint_id}"} 109 | 110 | if endpoint.user != user: 111 | return { 112 | "status": "Failed", 113 | "reason": f"User {user.username} is not authorized to perform this action " 114 | f"on endpoint {endpoint_id}", 115 | } 116 | 117 | functions = [f.function_uuid for f in endpoint.restricted_functions] 118 | return {"status": "Success", "result": functions} 119 | 120 | 121 | def delete_ep_whitelist(user: User, endpoint_id, function_id): 122 | """Delete the functions from an endpoint's whitelist. 123 | 124 | This function is only allowed by the owner of the endpoint. 125 | 126 | Parameters 127 | ---------- 128 | user : User 129 | The the user making the request 130 | endpoint_id : str 131 | The uuid of the endpoint to add the whitelist entries for 132 | function_id : str 133 | The uuid of the function to remove from the whitelist 134 | 135 | Returns 136 | ------- 137 | json 138 | A dict describing the success or failure of removing the function 139 | """ 140 | 141 | saved_endpoint = Endpoint.find_by_uuid(endpoint_id) 142 | if not saved_endpoint: 143 | return { 144 | "status": "Failed", 145 | "reason": f"Endpoint {endpoint_id} not found in database", 146 | } 147 | 148 | if saved_endpoint.user != user: 149 | return { 150 | "status": "Failed", 151 | "reason": f"User {user.username} is not authorized to perform this action " 152 | f"on endpoint {endpoint_id}", 153 | } 154 | 155 | saved_function = Function.find_by_uuid(function_id) 156 | 157 | if not saved_function: 158 | return { 159 | "status": "Failed", 160 | "reason": f"Function {function_id} not 
found in database", 161 | } 162 | 163 | saved_endpoint.delete_whitelist_for_function(saved_function) 164 | return {"status": "Success", "result": function_id} 165 | 166 | 167 | def ingest_function(function: Function, function_source, user_uuid): 168 | """Ingest a function into Globus Search 169 | 170 | Restructures data for ingest purposes. 171 | 172 | Parameters 173 | ---------- 174 | function : Function 175 | 176 | Returns 177 | ------- 178 | None 179 | """ 180 | selected_group = ( 181 | None if not function.auth_groups else function.auth_groups[0].group_id 182 | ) 183 | container_uuid = ( 184 | None if not function.container else function.container.container.container_uuid 185 | ) 186 | data = { 187 | "function_name": function.function_name, 188 | "function_code": function.function_source_code, 189 | "function_source": function_source, 190 | "container_uuid": container_uuid, 191 | "entry_point": function.entry_point, 192 | "description": function.description, 193 | "public": function.public, 194 | "group": selected_group, 195 | } 196 | user_urn = f"urn:globus:auth:identity:{user_uuid}" 197 | search.func_ingest_or_update( 198 | function.function_uuid, data, author=function.user.username, author_urn=user_urn 199 | ) 200 | 201 | 202 | def ingest_endpoint(user_name, user_uuid, ep_uuid, data): 203 | owner_urn = f"urn:globus:auth:identity:{user_uuid}" 204 | search.endpoint_ingest_or_update( 205 | ep_uuid, data, owner=user_name, owner_urn=owner_urn 206 | ) 207 | 208 | 209 | def register_endpoint(user: User, endpoint_name, description, endpoint_uuid=None): 210 | """Register the endpoint in the database. 
211 | 212 | Parameters 213 | ---------- 214 | user : User 215 | The primary identity of the user 216 | endpoint_name : str 217 | The name of the endpoint 218 | description : str 219 | A description of the endpoint 220 | endpoint_uuid : str 221 | The uuid of the endpoint (if it exists) 222 | 223 | Returns 224 | ------- 225 | str 226 | The uuid of the endpoint 227 | """ 228 | user_id = user.id 229 | 230 | if endpoint_uuid: 231 | # Check it is a valid uuid 232 | uuid.UUID(endpoint_uuid) 233 | 234 | existing_endpoint = Endpoint.find_by_uuid(endpoint_uuid) 235 | 236 | if existing_endpoint: 237 | # Make sure user owns this endpoint 238 | if existing_endpoint.user_id == user_id: 239 | existing_endpoint.name = endpoint_name 240 | existing_endpoint.description = description 241 | existing_endpoint.save_to_db() 242 | return endpoint_uuid 243 | else: 244 | app.logger.debug( 245 | f"Endpoint {endpoint_uuid} was previously registered " 246 | f"with user {existing_endpoint.user_id} not {user_id}" 247 | ) 248 | raise EndpointAlreadyRegistered(endpoint_uuid) 249 | else: 250 | endpoint_uuid = str(uuid.uuid4()) 251 | try: 252 | new_endpoint = Endpoint( 253 | user=user, 254 | endpoint_name=endpoint_name, 255 | description=description, 256 | status="OFFLINE", 257 | endpoint_uuid=endpoint_uuid, 258 | ) 259 | new_endpoint.save_to_db() 260 | except Exception as e: 261 | app.logger.error(e) 262 | raise e 263 | return endpoint_uuid 264 | 265 | 266 | def resolve_function(user_id, function_uuid): 267 | """Get the function uuid from database 268 | 269 | Parameters 270 | ---------- 271 | user_id : str 272 | The uuid of the user 273 | function_uuid : str 274 | The uuid of the function 275 | 276 | Returns 277 | ------- 278 | str 279 | The function code 280 | str 281 | The function entry point 282 | str 283 | The uuid of the container image to use 284 | """ 285 | 286 | start = time.time() 287 | 288 | saved_function = Function.find_by_uuid(function_uuid) 289 | 290 | if not saved_function: 291 | 
raise FunctionNotFound(function_uuid) 292 | 293 | function_code = saved_function.function_source_code 294 | function_entry = saved_function.entry_point 295 | 296 | if saved_function.container: 297 | container_uuid = saved_function.container.container.container_uuid 298 | else: 299 | container_uuid = None 300 | 301 | delta = time.time() - start 302 | app.logger.info(f"Time to fetch function {delta * 1000:.1f}ms") 303 | return function_code, function_entry, container_uuid 304 | 305 | 306 | def get_redis_client(): 307 | """Return a redis client 308 | 309 | Returns 310 | ------- 311 | redis.StrictRedis 312 | A client for redis 313 | """ 314 | try: 315 | redis_client = redis.StrictRedis( 316 | host=app.config["REDIS_HOST"], 317 | port=app.config["REDIS_PORT"], 318 | decode_responses=True, 319 | ) 320 | return redis_client 321 | except Exception as e: 322 | print(e) 323 | 324 | 325 | def update_function( 326 | user_name, 327 | function_uuid, 328 | function_name, 329 | function_desc, 330 | function_entry_point, 331 | function_code, 332 | ): 333 | """Delete a function 334 | 335 | Parameters 336 | ---------- 337 | user_name : str 338 | The primary identity of the user 339 | function_uuid : str 340 | The uuid of the function 341 | function_name : str 342 | The name of the function 343 | function_desc : str 344 | The description of the function 345 | function_entry_point : str 346 | The entry point of the function 347 | function_code : str 348 | The code of the function 349 | 350 | Returns 351 | ------- 352 | str 353 | The result as a status code integer 354 | "302" for success and redirect 355 | "403" for unauthorized 356 | "404" for a non-existent or previously-deleted function 357 | "500" for try statement error 358 | """ 359 | 360 | saved_function = Function.find_by_uuid(function_uuid) 361 | if not saved_function or saved_function.deleted: 362 | return 404 363 | 364 | saved_user = User.resolve_user(user_name) 365 | 366 | if not saved_user or saved_function.user != 
saved_user: 367 | return 403 368 | 369 | saved_function.function_name = function_name 370 | saved_function.function_desc = function_desc 371 | saved_function.function_entry_point = function_entry_point 372 | saved_function.function_source_code = function_code 373 | saved_function.save_to_db() 374 | return 302 375 | 376 | 377 | def delete_function(user: User, function_uuid): 378 | """Delete a function 379 | 380 | Parameters 381 | ---------- 382 | user : User 383 | The primary identity of the user 384 | function_uuid : str 385 | The uuid of the function 386 | 387 | Returns 388 | ------- 389 | str 390 | The result as a status code integer 391 | "302" for success and redirect 392 | "403" for unauthorized 393 | "404" for a non-existent or previously-deleted function 394 | "500" for try statement error 395 | """ 396 | saved_function = Function.find_by_uuid(function_uuid) 397 | if not saved_function or saved_function.deleted: 398 | return 404 399 | 400 | if saved_function.user != user: 401 | return 403 402 | 403 | saved_function.deleted = True 404 | saved_function.save_to_db() 405 | return 302 406 | -------------------------------------------------------------------------------- /funcx_web_service/response.py: -------------------------------------------------------------------------------- 1 | from flask import Response 2 | 3 | 4 | class FuncxResponseLogData: 5 | def __init__(self): 6 | self.data = {} 7 | 8 | def set_user(self, user): 9 | self.data["user_id"] = user.id 10 | 11 | 12 | class FuncxResponse(Response): 13 | def __init__(self, *args, **kwargs): 14 | self._log_data = FuncxResponseLogData() 15 | super().__init__(*args, **kwargs) 16 | -------------------------------------------------------------------------------- /funcx_web_service/routes/__init__.py: -------------------------------------------------------------------------------- 
# --- funcx_web_service/routes/__init__.py ---
# (stored upstream as a raw-content pointer; package marker only)

# --- funcx_web_service/routes/container.py ---
import uuid

from flask import Blueprint
from flask import current_app as app
from flask import jsonify, request
from funcx_common.response_errors import InternalError, RequestKeyError

from funcx_web_service.authentication.auth import authenticated

from ..models.container import Container, ContainerImage
from ..models.user import User

container_api = Blueprint("container_routes", __name__)


@container_api.route("/containers/<container_id>/<container_type>", methods=["GET"])
@authenticated
def get_cont(user: User, container_id, container_type):
    """Get the details of a container.

    Parameters
    ----------
    user : User
        The primary identity of the user
    container_id : str
        The id of the container
    container_type : str
        The type of containers to return: Docker, Singularity, Shifter, etc.

    Returns
    -------
    dict
        A dictionary of container details
    """
    app.logger.info(f"Getting container details: {container_id}")
    container = Container.find_by_uuid_and_type(container_id, container_type)
    app.logger.info(f"Got container: {container}")
    return jsonify({"container": container.to_json()})


@container_api.route("/containers", methods=["POST"])
@authenticated
def reg_container(user: User):
    """Register a new container.

    Parameters
    ----------
    user : User
        The primary identity of the user

    JSON Body
    ---------
    name: Str
    description: Str
    type: The type of containers that will be used (Singularity, Shifter, Docker)
    location: The location of the container (e.g., its docker url).

    Returns
    -------
    dict
        A dictionary of container details including its uuid
    """
    app.logger.debug("Creating container.")
    post_req = request.json

    try:
        container_rec = Container(
            author=user.id,
            name=post_req["name"],
            # "description" is optional: .get() avoids a spurious
            # RequestKeyError when the key is omitted entirely, while empty
            # strings still collapse to None as before.
            description=post_req.get("description") or None,
            container_uuid=str(uuid.uuid4()),
        )
        container_rec.images = [
            ContainerImage(type=post_req["type"], location=post_req["location"])
        ]

        container_rec.save_to_db()

        app.logger.info(f"Created container: {container_rec.container_uuid}")
        return jsonify({"container_id": container_rec.container_uuid})
    except KeyError as e:
        raise RequestKeyError(str(e))

    except Exception as e:
        raise InternalError(f"error adding container - {e}")


# --- funcx_web_service/routes/funcx.py ---
import json
import time
import typing as t
import uuid

import requests
from flask import Blueprint
from flask import current_app as app
from flask import g, jsonify, request
from funcx_common.redis import FuncxRedisPubSub, default_redis_connection_factory
from funcx_common.response_errors import (
    ContainerNotFound,
    EndpointAccessForbidden,
    EndpointOutdated,
    EndpointStatsError,
    ForwarderRegistrationError,
    FunctionAccessForbidden,
    InternalError,
    InvalidUUID,
    RequestKeyError,
    RequestMalformed,
    TaskGroupAccessForbidden,
    TaskGroupNotFound,
    TaskNotFound,
    UserNotFound,
)
from funcx_common.task_storage import TaskStorage, get_default_task_storage
from redis.client import Redis

from funcx_web_service.authentication.auth import (
    authenticated,
    authenticated_w_uuid,
    authorize_endpoint,
    authorize_function,
)
from funcx_web_service.error_responses import create_error_response
from funcx_web_service.models.tasks import RedisTask, TaskGroup
from funcx_web_service.models.utils import (
    add_ep_whitelist,
    db_invocation_logger,
    delete_ep_whitelist,
    delete_function,
    get_ep_whitelist,
    get_redis_client,
    ingest_endpoint,
    ingest_function,
    register_endpoint,
    resolve_function,
    update_function,
)
from funcx_web_service.version import MIN_SDK_VERSION, VERSION

from ..models.container import Container
from ..models.endpoint import Endpoint
from ..models.function import Function, FunctionAuthGroup, FunctionContainer
from ..models.serializer import deserialize_result, serialize_inputs
from ..models.user import User

funcx_api = Blueprint("routes", __name__)


def get_db_logger():
    """Return the request-scoped DB invocation logger, creating it on first use."""
    if "db_logger" not in g:
        g.db_logger = db_invocation_logger()
    return g.db_logger


def g_redis_client():
    """Return the request-scoped redis client, creating it on first use."""
    if "redis_client" not in g:
        g.redis_client = get_redis_client()
    return g.redis_client


def g_redis_pubsub():
    """Return the request-scoped FuncxRedisPubSub, creating it on first use."""
    if "redis_pubsub" not in g:
        # TODO: remove support for REDIS_HOST and REDIS_PORT
        # instead, allow configuration via FUNCX_COMMON_REDIS_URL and instantiate
        # FuncxRedisPubSub with no arguments
        if "REDIS_HOST" in app.config and "REDIS_PORT" in app.config:
            redis_client = default_redis_connection_factory(
                f"redis://{app.config['REDIS_HOST']}:{app.config['REDIS_PORT']}"
            )
        else:
            redis_client = None
        g.redis_pubsub = FuncxRedisPubSub(redis_client=redis_client)
        # Fail fast if redis is unreachable.
        g.redis_pubsub.redis_client.ping()
    return g.redis_pubsub


def get_task_storage() -> TaskStorage:
    """Return the request-scoped TaskStorage, creating it on first use."""
    if not hasattr(g, "task_storage"):
        g.task_storage = get_default_task_storage()
    return g.task_storage


def auth_and_launch(
    user_id,
    function_uuid,
    endpoint_uuid,
    input_data,
    app,
    token,
    task_group_id,
    serialize=None,
):
    """Here we do basic auth for (user, fn, endpoint) and launch the function.

    Parameters
    ==========

    user_id : str
        user id
    function_uuid : str
        function uuid
    endpoint_uuid : str
        endpoint uuid
    input_data : string_buffer
        input payload data
    app : app object
    token : globus token
    task_group_id : str
        uuid of the task group this task belongs to
    serialize : bool
        Whether or not to serialize the input using the serialization service. This is
        used when the input is not already serialized by the SDK.

    Returns:
        JSON response object, containing task_uuid, http_status_code, and success or
        error info
    """
    task_uuid = str(uuid.uuid4())
    try:
        # Check if the user is allowed to access the function
        if not authorize_function(user_id, function_uuid, token):
            raise FunctionAccessForbidden(function_uuid)

        fn_code, fn_entry, container_uuid = resolve_function(user_id, function_uuid)

        # Make sure the user is allowed to use the function on this endpoint
        if not authorize_endpoint(user_id, endpoint_uuid, function_uuid, token):
            raise EndpointAccessForbidden(endpoint_uuid)

        app.logger.info(f"Got function container_uuid :{container_uuid}")

        # We should replace this with container_hdr = ";ctnr={container_uuid}"
        if not container_uuid:
            container_uuid = "RAW"

        rc = g_redis_client()
        task_channel = g_redis_pubsub()

        db_logger = get_db_logger()

        if serialize:
            serialize_res = serialize_inputs(input_data)
            if serialize_res:
                input_data = serialize_res

        # At this point the packed function body and the args are concatable strings
        payload = fn_code + input_data
        task = RedisTask(
            rc,
            task_uuid,
            user_id=user_id,
            function_id=function_uuid,
            container=container_uuid,
            task_group_id=task_group_id,
        )
        get_task_storage().store_payload(task, payload)
        task_channel.put(endpoint_uuid, task)

        extra_logging = {
            "user_id": user_id,
            "task_id": task_uuid,
            "task_group_id": task_group_id,
            "function_id": function_uuid,
            "endpoint_id": endpoint_uuid,
            "container_id": container_uuid,
            "log_type": "task_transition",
        }
        app.logger.info("received", extra=extra_logging)

        # increment the counter
        rc.incr("funcx_invocation_counter")
        # add an invocation to the database
        # log_invocation(user_id, task_uuid, function_uuid, ep)
        db_logger.log(user_id, task_uuid, function_uuid, endpoint_uuid, deferred=True)

        db_logger.commit()

        return {"status": "Success", "task_uuid": task_uuid, "http_status_code": 200}
    except Exception as e:
        app.logger.exception(e)
        res = create_error_response(e)[0]
        res["task_uuid"] = task_uuid
        return res


@funcx_api.route("/submit", methods=["POST"])
@authenticated
def submit(user: User):
    """Puts the task request(s) into Redis and returns a list of task UUID(s)

    Parameters
    ----------
    user : User
        The primary identity of the user

    POST payload
    ------------
    {
        tasks: []
    }

    Returns
    -------
    json
        The task document
    """
    app.logger.info(f"batch_run invoked by user:{user.username}")

    user_id = user.id

    # Extract the token for endpoint verification
    token_str = request.headers.get("Authorization")
    token = str.replace(str(token_str), "Bearer ", "")

    # Parse out the function info
    tasks = []
    task_group_id = None
    # Defensive default so the batch path can never hit an unbound local.
    serialize = None
    try:
        post_req = request.json
        if "tasks" in post_req:
            # new client is being used
            # TODO: validate that this tasks list is formatted correctly so
            # that more useful errors can be sent back
            tasks = post_req["tasks"]
            task_group_id = post_req.get("task_group_id", str(uuid.uuid4()))
            try:
                # check that task_group_id is a valid UUID
                uuid.UUID(task_group_id)
            except Exception:
                raise InvalidUUID("Invalid task_group_id UUID provided")
        else:
            # old client was used and create a new task
            function_uuid = post_req["func"]
            endpoint = post_req["endpoint"]
            input_data = post_req["payload"]
            tasks.append([function_uuid, endpoint, input_data])
        serialize = post_req.get("serialize", None)
    except KeyError as e:
        # this should raise a 500 because it prevented any tasks from launching
        raise RequestKeyError(str(e))

    rc = g_redis_client()
    task_group = None
    if task_group_id and TaskGroup.exists(rc, task_group_id):
        app.logger.debug(
            f"Task Group {task_group_id} submitted to by user {user_id} "
            "already exists, checking if user is authorized"
        )
        # TODO: This could be cached to minimize lookup cost.
        task_group = TaskGroup(rc, task_group_id)
        if task_group.user_id != user_id:
            raise TaskGroupAccessForbidden(task_group_id)

    # this is a breaking change for old funcx sdk versions
    results: t.Dict[str, t.Any] = {
        "response": "batch",
        "task_group_id": task_group_id,
        "results": [],
    }

    final_http_status = 200
    success_count = 0
    for task in tasks:
        res = auth_and_launch(
            user_id,
            function_uuid=task[0],
            endpoint_uuid=task[1],
            input_data=task[2],
            app=app,
            token=token,
            task_group_id=task_group_id,
            serialize=serialize,
        )

        if res.get("status", "Failed") == "Success":
            success_count += 1
        else:
            # the response code is a 207 if some tasks failed to submit
            final_http_status = 207

        results["results"].append(res)

    # create a TaskGroup if there are actually tasks with results to wait on and
    # a TaskGroup with the provided ID doesn't already exist
    if success_count > 0 and task_group_id and not task_group:
        app.logger.debug(f"Creating new Task Group {task_group_id} for user {user_id}")
        TaskGroup(rc, task_group_id, user_id)

    return jsonify(results), final_http_status


def get_tasks_from_redis(task_ids, user: User):
    """Fetch status/result documents for the given task ids, keyed by id.

    Tasks that do not exist or that belong to a different user are reported
    uniformly as "Unknown task id" so task ownership is not leaked.
    """
    all_tasks = {}

    rc = g_redis_client()
    for task_id in task_ids:
        # Get the task from redis
        if not RedisTask.exists(rc, task_id):
            all_tasks[task_id] = {
                "task_id": task_id,
                "status": "Failed",
                "reason": "Unknown task id",
            }
            continue

        task = RedisTask(rc, task_id)
        if task.user_id != user.id:
            all_tasks[task_id] = {
                "task_id": task_id,
                "status": "Failed",
                "reason": "Unknown task id",
            }
            continue

        task_status = task.status
        task_result = get_task_storage().get_result(task)
        task_exception = task.exception
        task_completion_t = task.completion_time
        # Results/exceptions are delivered once: delete the task after a
        # terminal state has been read.
        if task_result or task_exception:
            task.delete()

        all_tasks[task_id] = {
            "task_id": task_id,
            "status": task_status,
            "result": task_result,
            "completion_t": task_completion_t,
            "exception": task_exception,
        }

        # Note: this is for backwards compat, when we can't include a None result and
        # have a non-complete status, we must forgo the result field if task not
        # complete.
        if task_result is None:
            del all_tasks[task_id]["result"]

        # Same backwards-compat rule for the exception field.
        if task_exception is None:
            del all_tasks[task_id]["exception"]
    return all_tasks


def get_task_or_404(rc: Redis, task_id: str) -> RedisTask:
    """Return the RedisTask for task_id, raising TaskNotFound if absent."""
    if not RedisTask.exists(rc, task_id):
        raise TaskNotFound(task_id)
    return RedisTask(rc, task_id)


def authorize_task_or_404(task: RedisTask, user: User):
    """Raise TaskNotFound (not Forbidden) if the task belongs to another user.

    Reporting "not found" avoids leaking the existence of other users' tasks.
    """
    if task.user_id != user.id:
        raise TaskNotFound(task.task_id)


# TODO: Old APIs look at "/<task_id>/status" for status and result, when that's
# changed, we should remove this route
@funcx_api.route("/<task_id>/status", methods=["GET"])
@funcx_api.route("/tasks/<task_id>", methods=["GET"])
@authenticated
def status_and_result(user: User, task_id):
    """Check the status of a task. Return result if available.

    If the query param deserialize=True is passed, then we deserialize the result
    object.

    Parameters
    ----------
    user : User
        The primary identity of the user
    task_id : str
        The task uuid to look up

    Returns
    -------
    json
        The status of the task
    """
    rc = g_redis_client()
    task = get_task_or_404(rc, task_id)
    authorize_task_or_404(task, user)

    task_status = task.status
    task_result = get_task_storage().get_result(task)
    task_exception = task.exception
    task_completion_t = task.completion_time
    if task_result or task_exception:
        extra_logging = {
            "user_id": task.user_id,
            "task_id": task_id,
            "task_group_id": task.task_group_id,
            "function_id": task.function_id,
            "endpoint_id": task.endpoint,
            "container_id": task.container,
            "log_type": "task_transition",
        }
        app.logger.info("user_fetched", extra=extra_logging)

        # Results are delivered once; drop the task after the fetch.
        task.delete()

    deserialize = request.args.get("deserialize", False)
    if deserialize and task_result:
        task_result = deserialize_result(task_result)

    # TODO: change client to have better naming conventions
    # these fields like 'status' should be changed to 'task_status', because 'status' is
    # normally used for HTTP codes.
    response = {
        "task_id": task_id,
        "status": task_status,
        "result": task_result,
        "completion_t": task_completion_t,
        "exception": task_exception,
    }

    # Note: this is for backwards compat, when we can't include a None result and have a
    # non-complete status, we must forgo the result field if task not complete.
    if task_result is None:
        del response["result"]

    if task_exception is None:
        del response["exception"]

    return jsonify(response)


@funcx_api.route("/batch_status", methods=["POST"])
@authenticated
def batch_status(user: User):
    """Check the status of a batch of tasks.

    Parameters
    ----------
    user : User
        The primary identity of the user

    POST payload
    ------------
    {
        task_ids: []
    }

    Returns
    -------
    json
        The per-task status documents, keyed by task id
    """
    app.logger.debug("batch_status_request", extra=request.json)
    results = get_tasks_from_redis(request.json["task_ids"], user)

    return jsonify({"response": "batch", "results": results})


def register_with_hub(address, endpoint_id, endpoint_address):
    """This registers with the Forwarder micro service.

    Can be used as an example of how to make calls this it, while the main API
    is updated to do this calling on behalf of the endpoint in the second iteration.

    Parameters
    ----------
    address : str
        Address of the forwarder service of the form http://<ip>:<port>
    endpoint_id : str
        The uuid of the endpoint to register
    endpoint_address : str
        The IP address the endpoint registered from

    Raises
    ------
    ForwarderRegistrationError
        On timeout, request failure, or a non-200 forwarder response
    """
    try:
        r = requests.post(
            address + "/register",
            json={
                "endpoint_id": endpoint_id,
                "redis_address": app.config["ADVERTISED_REDIS_HOST"],
                "endpoint_addr": endpoint_address,
            },
            timeout=2,  # timeout for the forwarder response
        )
    except requests.Timeout:
        raise ForwarderRegistrationError(
            "Forwarder is un-responsive, unable to register endpoint within timeout:2s"
        )
    except Exception as e:
        raise ForwarderRegistrationError(
            f"Request to Forwarder failed, unable to register endpoint: {e}"
        )

    if r.status_code != 200:
        # Log (rather than print) the failing response before surfacing it.
        app.logger.error(f"Forwarder registration failed: {r.status_code} {r.reason}")
        raise ForwarderRegistrationError(r.reason)

    return r.json()


def get_forwarder_version():
    """Fetch the forwarder's version document from its /version endpoint."""
    forwarder_ip = app.config["FORWARDER_IP"]
    r = requests.get(f"http://{forwarder_ip}:8080/version", timeout=2)
    return r.json()


@funcx_api.route("/version", methods=["GET"])
def get_version():
    """Report version info for the api, forwarder, or all services.

    The "service" query param selects: "api" (default), "forwarder", or "all".
    """
    s = request.args.get("service")
    if s == "api" or s is None:
        return jsonify(VERSION)

    forwarder_v_info = get_forwarder_version()
    forwarder_version = forwarder_v_info["forwarder"]
    min_ep_version = forwarder_v_info["min_ep_version"]
    if s == "forwarder":
        return jsonify(forwarder_version)

    if s == "all":
        result = {
            "api": VERSION,
            "forwarder": forwarder_version,
            "min_sdk_version": MIN_SDK_VERSION,
            "min_ep_version": min_ep_version,
        }

        if app.extensions["ContainerService"]:
            result["container_service"] = app.extensions[
                "ContainerService"
            ].get_version()["version"]
        return jsonify(result)

    raise RequestMalformed("unknown service type or other error.")


# Endpoint routes
@funcx_api.route("/endpoints", methods=["POST"])
@authenticated_w_uuid
def reg_endpoint(user: User, user_uuid: str):
    """
    Register an endpoint. Add this endpoint to the database and associate it with
    this user.

    Returns
    -------
    json
        A dict containing the endpoint details

    Raises
    ------
    RequestKeyError
        If required keys are missing from the request payload.
    EndpointOutdated
        If the endpoint's funcx version is below the forwarder's minimum.
    InvalidUUID
        If the supplied endpoint UUID is malformed.
    """
    app.logger.info("register_endpoint triggered", extra=request.json)

    v_info = get_forwarder_version()
    min_ep_version = v_info["min_ep_version"]
    if "version" not in request.json:
        raise RequestKeyError(
            "Endpoint funcx version must be passed in the 'version' field."
        )

    if request.json["version"] < min_ep_version:
        raise EndpointOutdated(min_ep_version)

    # Cooley ALCF is the default used here.
    endpoint_ip_addr = "140.221.68.108"
    if request.environ.get("HTTP_X_FORWARDED_FOR") is None:
        endpoint_ip_addr = request.environ["REMOTE_ADDR"]
    else:
        # behind a proxy, the client address lives in the forwarded header
        endpoint_ip_addr = request.environ["HTTP_X_FORWARDED_FOR"]
    app.logger.info(f"Registering endpoint IP address as: {endpoint_ip_addr}")

    # always return the jsonified error response as soon as it is available below
    # to prevent further registration steps being taken after an error
    try:
        app.logger.info(f"requesting registration for {request.json}")
        endpoint_uuid = register_endpoint(
            user,
            request.json["endpoint_name"],
            "",  # use description from meta? why store here at all
            endpoint_uuid=request.json["endpoint_uuid"],
        )
        app.logger.info(f"Successfully registered {endpoint_uuid} in database")

    except KeyError as e:
        app.logger.exception("Missing keys in json request")
        raise RequestKeyError(str(e))

    except UserNotFound as e:
        app.logger.exception("User not found")
        raise e

    except ValueError:
        app.logger.exception("Invalid UUID sent for endpoint")
        raise InvalidUUID("Invalid endpoint UUID provided")

    except Exception as e:
        app.logger.exception("Caught error while registering endpoint")
        raise e

    try:
        forwarder_ip = app.config["FORWARDER_IP"]
        response = register_with_hub(
            f"http://{forwarder_ip}:8080", endpoint_uuid, endpoint_ip_addr
        )
        app.logger.info(f"Successfully registered {endpoint_uuid} with forwarder")

    except Exception as e:
        app.logger.exception("Caught error during forwarder initialization")
        raise e

    if "meta" in request.json and endpoint_uuid:
        ingest_endpoint(user.username, user_uuid, endpoint_uuid, request.json["meta"])
        app.logger.info(f"Ingested endpoint {endpoint_uuid}")

    # 'response' is guaranteed to be bound here: register_with_hub either
    # returned it or raised (and the exception was re-raised above), so the
    # former 'except NameError: return "oof"' fallback was dead code.
    return jsonify(response)


@funcx_api.route("/endpoints/<endpoint_id>/status", methods=["GET"])
@authenticated
def get_ep_stats(user: User, endpoint_id):
    """Retrieve the status updates from an endpoint.

    Parameters
    ----------
    user : User
        The primary identity of the user
    endpoint_id : str
        The endpoint uuid to look up

    Returns
    -------
    json
        The status of the endpoint

    Raises
    ------
    EndpointAccessForbidden
        If the user is not authorized to view this endpoint.
    EndpointStatsError
        If the stored status entries cannot be loaded or parsed.
    """
    alive_threshold = (
        2 * 60
    )  # time in seconds since last heartbeat to be counted as alive
    last = 10

    user_id = user.id

    # Extract the token for endpoint verification
    token_str = request.headers.get("Authorization")
    token = str.replace(str(token_str), "Bearer ", "")

    if not authorize_endpoint(user_id, endpoint_id, None, token):
        raise EndpointAccessForbidden(endpoint_id)

    rc = g_redis_client()

    status: str = "offline"
    status_logs: t.List[t.Dict[str, t.Any]] = []
    try:
        end = min(rc.llen(f"ep_status_{endpoint_id}"), last)
        # log (rather than print) the number of entries for debugging
        app.logger.debug("Total len : %s", end)
        items = rc.lrange(f"ep_status_{endpoint_id}", 0, end)
        if items:
            for i in items:
                dataitem = json.loads(i)
                if not isinstance(dataitem, dict):
                    raise EndpointStatsError(
                        endpoint_id, "endpoint stats failed to load, nondict data"
                    )
                status_logs.append(dataitem)

            # timestamp is an epoch timestamp
            newest_timestamp = status_logs[0].get("timestamp")
            if newest_timestamp is None or (
                not isinstance(newest_timestamp, (int, float))
            ):
                raise EndpointStatsError(
                    endpoint_id, "could not load latest timestamp from ep status info"
                )

            now = time.time()
            if now - newest_timestamp < alive_threshold:
                status = "online"

    # FIXME: identify error conditions and remove this blanket error capture
    except Exception as e:
        if isinstance(e, EndpointStatsError):
            raise
        raise EndpointStatsError(endpoint_id, str(e))

    status_info = {"logs": status_logs, "status": status}
    return jsonify(status_info)

@funcx_api.route("/endpoints/", methods=["DELETE"]) 688 | @authenticated 689 | def del_endpoint(user: User, endpoint_id): 690 | """Delete the endpoint. 691 | 692 | Parameters 693 | ---------- 694 | user : User 695 | The primary identity of the user 696 | endpoint_id : str 697 | The endpoint uuid to delete 698 | 699 | Returns 700 | ------- 701 | json 702 | Dict containing the result 703 | """ 704 | try: 705 | result = Endpoint.delete_endpoint(user, endpoint_id) 706 | return jsonify({"result": result}) 707 | except Exception as e: 708 | app.logger.error(e) 709 | 710 | 711 | # Whitelist routes 712 | @funcx_api.route("/endpoints//whitelist", methods=["POST", "GET"]) 713 | @authenticated 714 | def endpoint_whitelist(user: User, endpoint_id): 715 | """Get or insert into the endpoint's whitelist. 716 | If POST, insert the list of function ids into the whitelist. 717 | if GET, return the list of function ids in the whitelist 718 | 719 | Parameters 720 | ---------- 721 | user : User 722 | The primary identity of the user 723 | endpoint_id : str 724 | The id of the endpoint 725 | 726 | Returns 727 | ------- 728 | json 729 | A dict including a list of whitelisted functions for this endpoint 730 | """ 731 | 732 | app.logger.info( 733 | f"Adding to endpoint {endpoint_id} whitelist by user: {user.username}" 734 | ) 735 | 736 | if request.method == "GET": 737 | return get_ep_whitelist(user, endpoint_id) 738 | else: 739 | # Otherwise we need the list of functions passed in 740 | try: 741 | post_req = request.json 742 | functions = post_req["func"] 743 | except KeyError as e: 744 | raise RequestKeyError(str(e)) 745 | except Exception as e: 746 | raise RequestMalformed(str(e)) 747 | return add_ep_whitelist(user, endpoint_id, functions) 748 | 749 | 750 | @funcx_api.route("/endpoints//whitelist/", methods=["DELETE"]) 751 | @authenticated 752 | def del_endpoint_whitelist(user: User, endpoint_id, function_id): 753 | """Delete from an endpoint's whitelist. 
Return the success/failure of the delete. 754 | 755 | Parameters 756 | ---------- 757 | user : User 758 | The primary identity of the user 759 | endpoint_id : str 760 | The id of the endpoint 761 | function_id : str 762 | The id of the function to delete 763 | 764 | Returns 765 | ------- 766 | json 767 | A dict describing the result of deleting from the endpoint's whitelist 768 | """ 769 | 770 | app.logger.info( 771 | f"Deleting function {function_id} from endpoint {endpoint_id} whitelist by " 772 | f"user: {user.username}" 773 | ) 774 | 775 | return delete_ep_whitelist(user, endpoint_id, function_id) 776 | 777 | 778 | @funcx_api.route("/functions", methods=["POST"]) 779 | @authenticated_w_uuid 780 | def reg_function(user: User, user_uuid): 781 | """Register the function. 782 | 783 | Parameters 784 | ---------- 785 | user : str 786 | The primary identity of the user 787 | 788 | POST Payload 789 | ------------ 790 | { "function_name" : , 791 | "entry_point" : , 792 | "function_code" : , 793 | "function_source": , 795 | "description" : , 796 | "group": 797 | "public" : 798 | "searchable" : 799 | } 800 | 801 | Returns 802 | ------- 803 | json 804 | Dict containing the function details 805 | """ 806 | 807 | function_rec = None 808 | function_source = None 809 | try: 810 | function_source = request.json["function_source"] 811 | function_rec = Function( 812 | function_uuid=str(uuid.uuid4()), 813 | function_name=request.json["function_name"], 814 | entry_point=request.json["entry_point"], 815 | description=request.json["description"], 816 | function_source_code=request.json["function_code"], 817 | public=request.json.get("public", False), 818 | user_id=user.id, 819 | ) 820 | 821 | container_uuid = request.json.get("container_uuid", None) 822 | container = None 823 | if container_uuid: 824 | container = Container.find_by_uuid(container_uuid) 825 | if not container: 826 | raise ContainerNotFound(container_uuid) 827 | 828 | group_uuid = request.json.get("group", None) 829 | 
searchable = request.json.get("searchable", True) 830 | 831 | app.logger.info( 832 | f"Registering function {function_rec.function_name} " 833 | f"with container {container_uuid}" 834 | ) 835 | 836 | if container: 837 | function_rec.container = FunctionContainer( 838 | function=function_rec, container=container 839 | ) 840 | 841 | if group_uuid: 842 | function_rec.auth_groups = [ 843 | FunctionAuthGroup(group_id=group_uuid, function=function_rec) 844 | ] 845 | 846 | function_rec.save_to_db() 847 | 848 | response = jsonify({"function_uuid": function_rec.function_uuid}) 849 | 850 | if not searchable: 851 | return response 852 | 853 | except KeyError as key_error: 854 | app.logger.error(key_error) 855 | raise RequestKeyError(str(key_error)) 856 | 857 | except Exception as e: 858 | function_name = ( 859 | function_rec.function_name if function_rec is not None else "" 860 | ) 861 | message = ( 862 | f"Function registration failed for user={user.username} , " 863 | f"function_name={function_name} due to {e}" 864 | ) 865 | app.logger.error( 866 | "function_registration_fail", 867 | extra={"user_id": user_uuid, "function_name": function_name}, 868 | ) 869 | raise InternalError(message) 870 | 871 | try: 872 | ingest_function(function_rec, function_source, user_uuid) 873 | except Exception as e: 874 | message = ( 875 | f"Function ingest to search failed for user:{user.username} " 876 | f"function_name:{function_rec.function_name} due to {e}" 877 | ) 878 | app.logger.error( 879 | "function_ingest_failed", 880 | extra={"user_id": user_uuid, "function_name": function_rec.function_name}, 881 | ) 882 | raise InternalError(message) 883 | 884 | return response 885 | 886 | 887 | @funcx_api.route("/functions/", methods=["PUT"]) 888 | @authenticated 889 | def upd_function(user: User, function_id): 890 | """Update the function. 
891 | 892 | Parameters 893 | ---------- 894 | user : User 895 | The primary identity of the user 896 | function_id : str 897 | The function to update 898 | 899 | Returns 900 | ------- 901 | json 902 | Dict containing the result as an integer 903 | """ 904 | try: 905 | function_name = request.json["name"] 906 | function_desc = request.json["desc"] 907 | function_entry_point = request.json["entry_point"] 908 | function_code = request.json["code"] 909 | result = update_function( 910 | user.username, 911 | function_id, 912 | function_name, 913 | function_desc, 914 | function_entry_point, 915 | function_code, 916 | ) 917 | if result == 302: 918 | return jsonify({"function_uuid": function_id}), 302 919 | elif result == 403: 920 | message = ( 921 | f"Unable to update function for user:{user.username} " 922 | f"function_id:{function_id}. 403 Unauthorized" 923 | ) 924 | app.logger.error(message) 925 | raise InternalError(message) 926 | elif result == 404: 927 | message = ( 928 | f"Unable to update function for user:{user.username} " 929 | f"function_id:{function_id}. 404 Function not found." 930 | ) 931 | app.logger.error(message) 932 | raise InternalError(message) 933 | except Exception as e: 934 | app.logger.exception(e) 935 | message = ( 936 | "Unable to update function for user:{} function_id:{} due to {}".format( 937 | user.username, function_id, e 938 | ) 939 | ) 940 | app.logger.error(message) 941 | raise InternalError(message) 942 | 943 | 944 | @funcx_api.route("/functions/", methods=["DELETE"]) 945 | @authenticated 946 | def del_function(user: User, function_id): 947 | """Delete the function. 
948 | 949 | Parameters 950 | ---------- 951 | user : User 952 | The primary identity of the user 953 | function_id : str 954 | The function uuid to delete 955 | 956 | Returns 957 | ------- 958 | json 959 | Dict containing the result 960 | """ 961 | try: 962 | result = delete_function(user, function_id) 963 | return jsonify({"result": result}) 964 | except Exception as e: 965 | app.logger.error(e) 966 | 967 | 968 | @funcx_api.route("/stats", methods=["GET"]) 969 | def funcx_stats(): 970 | """Get various usage stats.""" 971 | app.logger.debug("Getting stats") 972 | try: 973 | rc = g_redis_client() 974 | result = int(rc.get("funcx_invocation_counter")) 975 | return jsonify({"total_function_invocations": result}), 200 976 | except Exception as e: 977 | app.logger.exception(e) 978 | message = f"Unable to get invocation count due to {e}" 979 | app.logger.error(message) 980 | raise InternalError(message) 981 | 982 | 983 | @funcx_api.route("/authenticate", methods=["GET"]) 984 | @authenticated 985 | def authenticate(user: User): 986 | return "OK" 987 | 988 | 989 | @funcx_api.route("/task_groups/", methods=["GET"]) 990 | @authenticated 991 | def get_batch_info(user: User, task_group_id): 992 | rc = g_redis_client() 993 | 994 | if not TaskGroup.exists(rc, task_group_id): 995 | raise TaskGroupNotFound(task_group_id) 996 | 997 | task_group = TaskGroup(rc, task_group_id) 998 | 999 | if task_group.user_id != user.id: 1000 | raise TaskGroupAccessForbidden(task_group_id) 1001 | 1002 | return jsonify({"authorized": True}) 1003 | -------------------------------------------------------------------------------- /funcx_web_service/version.py: -------------------------------------------------------------------------------- 1 | """Set module version. 2 | 3 | ..[alpha/beta/..] 
4 | Alphas will be numbered like this -> 0.4.0a0 5 | """ 6 | VERSION = "0.3.7" 7 | MIN_SDK_VERSION = "0.2.1" 8 | -------------------------------------------------------------------------------- /integration_tests/Tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# funcX Tutorial\n", 8 | "\n", 9 | "funcX is a Function-as-a-Service (FaaS) platform for science that enables you to register functions in a cloud-hosted service and then reliably execute those functions on a remote funcX endpoint. This tutorial is configured to use a tutorial endpoint hosted by the funcX team. You can set up and use your own endpoint by following the [funcX documentation](https://funcx.readthedocs.io/en/latest/endpoints.html)" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "## funcX Python SDK\n", 17 | "\n", 18 | "The funcX Python SDK provides programming abstractions for interacting with the funcX service. Before running this tutorial locally, you should first install the funcX SDK as follows:\n", 19 | "\n", 20 | " $ pip install funcx\n", 21 | "\n", 22 | "(If you are running on binder, we've already done this for you in the binder environment.)\n", 23 | "The funcX SDK exposes a `FuncXClient` object for all interactions with the funcX service. In order to use the funcX service, you must first authenticate using one of hundreds of supported identity providers (e. g., your institution, ORCID, Google). As part of the authentication process, you must grant permission for funcX to access your identity information (to retrieve your email address), Globus Groups management access (to share functions and endpoints), and Globus Search (to discover functions and endpoints). 
" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 3, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "from funcx.sdk.client import FuncXClient\n", 33 | "\n", 34 | "fxc = FuncXClient(funcx_service_address=\"http://localhost:5000/v1\")" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "# Basic usage\n", 42 | "\n", 43 | "The following example demonstrates how you can register and execute a function. \n", 44 | "\n", 45 | "## Registering a function\n", 46 | "\n", 47 | "funcX works like any other FaaS platform: you must first register a function with funcX before being able to execute it on a remote endpoint. The registration process will serialize the function body and store it securely in the funcX service. As we will see below, you may share functions with others and discover functions shared with you.\n", 48 | "\n", 49 | "When you register a function, funcX will return a universally unique identifier (UUID) for it. This UUID can then be used to manage and invoke the function." 
50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 4, 55 | "metadata": {}, 56 | "outputs": [ 57 | { 58 | "name": "stdout", 59 | "output_type": "stream", 60 | "text": [ 61 | "3ee63a1a-4cb0-4dc3-ba11-b6215e7f9e16\n" 62 | ] 63 | } 64 | ], 65 | "source": [ 66 | "def hello_world():\n", 67 | " return \"Hello World!\"\n", 68 | "\n", 69 | "func_uuid = fxc.register_function(hello_world)\n", 70 | "print(func_uuid)" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": {}, 76 | "source": [ 77 | "# Endpoints" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 5, 83 | "metadata": {}, 84 | "outputs": [ 85 | { 86 | "ename": "TypeError", 87 | "evalue": "search_endpoint() missing 1 required positional argument: 'q'", 88 | "output_type": "error", 89 | "traceback": [ 90 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 91 | "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", 92 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfxc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msearch_endpoint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 93 | "\u001b[0;31mTypeError\u001b[0m: search_endpoint() missing 1 required positional argument: 'q'" 94 | ] 95 | } 96 | ], 97 | "source": [ 98 | "print(fxc.search_endpoint())" 99 | ] 100 | }, 101 | { 102 | "cell_type": "markdown", 103 | "metadata": {}, 104 | "source": [ 105 | "## Running a function \n", 106 | "\n", 107 | "To invoke a function, you must provide a) the function's UUID; and b) the `endpoint_id` of the endpoint on which you wish to execute that function. Note: here we use the public funcX tutorial endpoint; you may change the `endpoint_id` to the UUID of any endpoint on which you have permission to execute functions. 
\n", 108 | "\n", 109 | "funcX functions are designed to be executed remotely and asynchrously. To avoid synchronous invocation, the result of a function invocation (called a `task`) is a UUID, which may be introspected to monitor execution status and retrieve results.\n", 110 | "\n", 111 | "The funcX service will manage the reliable execution of a task, for example, by qeueing tasks when the endpoint is busy or offline and retrying tasks in case of node failures. " 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 15, 117 | "metadata": {}, 118 | "outputs": [ 119 | { 120 | "name": "stdout", 121 | "output_type": "stream", 122 | "text": [ 123 | "af466817-1fed-4904-9119-76cb4fb8595f\n" 124 | ] 125 | } 126 | ], 127 | "source": [ 128 | "tutorial_endpoint = '4b116d3c-1703-4f8f-9f6f-39921e5864df' # Public tutorial endpoint\n", 129 | "local_endpoint= '2c5cd24a-dc6b-4d8a-9024-0014e0f90846'\n", 130 | "res = fxc.run(endpoint_id=local_endpoint, function_id=func_uuid)\n", 131 | "print(res)" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 16, 137 | "metadata": {}, 138 | "outputs": [ 139 | { 140 | "name": "stdout", 141 | "output_type": "stream", 142 | "text": [ 143 | "af466817-1fed-4904-9119-76cb4fb8595f {'pending': True}\n" 144 | ] 145 | } 146 | ], 147 | "source": [ 148 | "print(res, fxc.get_task(res))" 149 | ] 150 | }, 151 | { 152 | "cell_type": "markdown", 153 | "metadata": {}, 154 | "source": [ 155 | "## Retrieving results\n", 156 | "\n", 157 | "When the task has completed executing, you can access the results via the funcX client as follows:" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 10, 163 | "metadata": {}, 164 | "outputs": [ 165 | { 166 | "ename": "Exception", 167 | "evalue": "Task pending", 168 | "output_type": "error", 169 | "traceback": [ 170 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 171 | "\u001b[0;31mException\u001b[0m 
Traceback (most recent call last)", 172 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mfxc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_result\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mres\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 173 | "\u001b[0;32m~/dev/funcx/funcX/funcx_sdk/funcx/sdk/client.py\u001b[0m in \u001b[0;36mget_result\u001b[0;34m(self, task_id)\u001b[0m\n\u001b[1;32m 211\u001b[0m \u001b[0mtask\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_task\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtask_id\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 212\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtask\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'pending'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 213\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Task pending\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 214\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 215\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;34m'result'\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtask\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 174 | "\u001b[0;31mException\u001b[0m: Task pending" 175 | ] 176 | } 177 | ], 178 | "source": [ 179 | "fxc.get_result(res)" 180 | ] 181 | }, 182 | { 183 | "cell_type": "markdown", 184 | "metadata": {}, 185 | "source": [ 186 | "## Functions with arguments\n", 187 | "\n", 188 | "funcX supports registration and invocation of functions with arbitrary arguments and returned parameters. 
funcX will serialize any \\*args and \\*\\*kwargs when invoking a function and it will serialize any return parameters or exceptions. Note: funcX uses standard Python serialization libraries (e. g., Pickle, Dill). It also limits the size of input arguments and returned parameters to 5 MB.\n", 189 | "\n", 190 | "The following example shows a function that computes the sum of a list of input arguments. First we register the function as above:" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "metadata": { 197 | "collapsed": true 198 | }, 199 | "outputs": [], 200 | "source": [ 201 | "def funcx_sum(items):\n", 202 | " return sum(items)\n", 203 | "\n", 204 | "sum_function = fxc.register_function(funcx_sum)" 205 | ] 206 | }, 207 | { 208 | "cell_type": "markdown", 209 | "metadata": {}, 210 | "source": [ 211 | "When invoking the function, you can pass in arguments like any other function, either by position or with keyword arguments. " 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": null, 217 | "metadata": { 218 | "collapsed": true 219 | }, 220 | "outputs": [], 221 | "source": [ 222 | "items = [1, 2, 3, 4, 5]\n", 223 | "\n", 224 | "res = fxc.run(items, endpoint_id=tutorial_endpoint, function_id=sum_function)\n", 225 | "\n", 226 | "print (fxc.get_result(res))" 227 | ] 228 | }, 229 | { 230 | "cell_type": "markdown", 231 | "metadata": {}, 232 | "source": [ 233 | "## Functions with dependencies\n", 234 | "\n", 235 | "funcX requires that functions explictly state all dependencies within the function body. It also assumes that the dependent libraries are available on the endpoint in which the function will execute. For example, in the following function we explictly import the time module. 
" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": null, 241 | "metadata": { 242 | "collapsed": true 243 | }, 244 | "outputs": [], 245 | "source": [ 246 | "def funcx_date():\n", 247 | " from datetime import date\n", 248 | " return date.today()\n", 249 | "\n", 250 | "date_function = fxc.register_function(funcx_date)\n", 251 | "\n", 252 | "res = fxc.run(endpoint_id=tutorial_endpoint, function_id=date_function)\n", 253 | "\n", 254 | "print (fxc.get_result(res))\n" 255 | ] 256 | }, 257 | { 258 | "cell_type": "markdown", 259 | "metadata": {}, 260 | "source": [ 261 | "## Calling external applications\n", 262 | "\n", 263 | "Depending on the configuration of the funcX endpoint, you can often invoke external applications that are avaialble in the endpoint environment. \n" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": null, 269 | "metadata": { 270 | "collapsed": true 271 | }, 272 | "outputs": [], 273 | "source": [ 274 | "def funcx_echo(name):\n", 275 | " import os\n", 276 | " return os.popen(\"echo Hello %s\" % name).read()\n", 277 | "\n", 278 | "echo_function = fxc.register_function(funcx_echo)\n", 279 | "\n", 280 | "res = fxc.run(\"World\", endpoint_id=tutorial_endpoint, function_id=echo_function)\n", 281 | "\n", 282 | "print (fxc.get_result(res))" 283 | ] 284 | }, 285 | { 286 | "cell_type": "markdown", 287 | "metadata": {}, 288 | "source": [ 289 | "## Catching exceptions\n", 290 | "\n", 291 | "When functions fail, the exception is captured and serialized by the funcX endpoint, and is reraised when you try to get the result. In the following example, the 'deterministic failure' exception is raised when `fxc.get_result` is called on the failing function." 
292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": null, 297 | "metadata": { 298 | "collapsed": true 299 | }, 300 | "outputs": [], 301 | "source": [ 302 | "def failing():\n", 303 | " raise Exception(\"deterministic failure\")\n", 304 | "\n", 305 | "failing_function = fxc.register_function(failing)\n", 306 | "\n", 307 | "res = fxc.run(endpoint_id=tutorial_endpoint, function_id=failing_function)\n", 308 | "\n", 309 | "fxc.get_result(res)" 310 | ] 311 | }, 312 | { 313 | "cell_type": "markdown", 314 | "metadata": {}, 315 | "source": [ 316 | "## Running functions many times\n", 317 | "\n", 318 | "After registering a function, you can invoke it repeatedly. The following example shows how the monte carlo method can be used to estimate pi. \n", 319 | "\n", 320 | "Specifically, if a circle with radius $r$ is inscribed inside a square with side length $2r$, the area of the circle is $\\pi r^2$ and the area of the square is $(2r)^2$. Thus, if $N$ uniformly-distributed random points are dropped within the square, approximately $N\\pi/4$ will be inside the circle.\n" 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": null, 326 | "metadata": { 327 | "collapsed": true 328 | }, 329 | "outputs": [], 330 | "source": [ 331 | "import time\n", 332 | "\n", 333 | "# function that estimates pi by placing points in a box\n", 334 | "def pi(num_points):\n", 335 | " from random import random\n", 336 | " inside = 0 \n", 337 | " for i in range(num_points):\n", 338 | " x, y = random(), random() # Drop a random point in the box.\n", 339 | " if x**2 + y**2 < 1: # Count points within the circle.\n", 340 | " inside += 1 \n", 341 | " return (inside*4 / num_points)\n", 342 | "\n", 343 | "# register the function\n", 344 | "pi_function = fxc.register_function(pi)\n", 345 | "\n", 346 | "# execute the function 3 times \n", 347 | "estimates = []\n", 348 | "for i in range(3):\n", 349 | " estimates.append(fxc.run(10**5, endpoint_id=tutorial_endpoint, 
function_id=pi_function))\n", 350 | "\n", 351 | "# wait for tasks to complete\n", 352 | "time.sleep(5)\n", 353 | "\n", 354 | "# wait for all tasks to complete\n", 355 | "for e in estimates: \n", 356 | " while fxc.get_task(e)['pending'] == 'True':\n", 357 | " time.sleep(3)\n", 358 | "\n", 359 | "# get the results and calculate the total\n", 360 | "results = [fxc.get_result(i) for i in estimates]\n", 361 | "total = 0\n", 362 | "for r in results: \n", 363 | " total += r\n", 364 | "\n", 365 | "# print the results\n", 366 | "print(\"Estimates: %s\" % results)\n", 367 | "print(\"Average: {:.5f}\".format(total/len(results)))" 368 | ] 369 | }, 370 | { 371 | "cell_type": "markdown", 372 | "metadata": {}, 373 | "source": [ 374 | "# Describing and discovering functions \n", 375 | "\n", 376 | "funcX manages a registry of functions that can be shared, discovered and reused. \n", 377 | "\n", 378 | "When registering a function, you may choose to set a description to support discovery, as well as making it `public` (so that others can run it) and/or `searchable` (so that others can discover it). " 379 | ] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "execution_count": null, 384 | "metadata": { 385 | "collapsed": true 386 | }, 387 | "outputs": [], 388 | "source": [ 389 | "def hello_world():\n", 390 | " return \"Hello World!\"\n", 391 | "\n", 392 | "func_uuid = fxc.register_function(hello_world, description=\"hello world function\", public=True, searchable=True)\n", 393 | "print(func_uuid)" 394 | ] 395 | }, 396 | { 397 | "cell_type": "markdown", 398 | "metadata": {}, 399 | "source": [ 400 | "You can search previously registered functions to which you have access using `search_function`. The first parameter is searched against all the fields, such as author, description, function name, and function source. You can navigate through pages of results with the `offset` and `limit` keyword args. 
\n", 401 | "\n", 402 | "The object returned is a simple wrapper on a list, so you can index into it, but also can have a pretty-printed table. " 403 | ] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "execution_count": null, 408 | "metadata": { 409 | "collapsed": true 410 | }, 411 | "outputs": [], 412 | "source": [ 413 | "search_results = fxc.search_function(\"hello\", offset=0, limit=5)\n", 414 | "print(search_results)" 415 | ] 416 | }, 417 | { 418 | "cell_type": "markdown", 419 | "metadata": {}, 420 | "source": [ 421 | "# Managing endpoints\n", 422 | "\n", 423 | "funcX endpoints advertise whether or not they are online as well as information about their available resources, queued tasks, and other information. If you are permitted to execute functions on an endpoint, you can also retrieve the status of the endpoint. The following example shows how to look up the status (online or offline) and the number of number of waiting tasks and workers connected to the endpoint. " 424 | ] 425 | }, 426 | { 427 | "cell_type": "code", 428 | "execution_count": null, 429 | "metadata": { 430 | "collapsed": true 431 | }, 432 | "outputs": [], 433 | "source": [ 434 | "endpoint_status = fxc.get_endpoint_status(tutorial_endpoint)\n", 435 | "\n", 436 | "print(\"Status: %s\" % endpoint_status['status'])\n", 437 | "print(\"Workers: %s\" % endpoint_status['logs'][0]['total_workers'])\n", 438 | "print(\"Tasks: %s\" % endpoint_status['logs'][0]['outstanding_tasks'])" 439 | ] 440 | }, 441 | { 442 | "cell_type": "markdown", 443 | "metadata": {}, 444 | "source": [ 445 | "# Advanced features\n", 446 | "\n", 447 | "funcX provides several features that address more advanced use cases. \n", 448 | "\n", 449 | "## Running batches\n", 450 | "\n", 451 | "After registering a function, you might want to invoke that function many times without making individual calls to the funcX service. Such examples occur when running monte carlo simulations, ensembles, and parameter sweep applications. 
\n", 452 | "\n", 453 | "funcX provides a batch interface that enables specification of a range of function invocations. To use this interface, you must create a funcX batch object and then add each invocation to that object. You can then pass the constructed object to the `batch_run` interface. " 454 | ] 455 | }, 456 | { 457 | "cell_type": "code", 458 | "execution_count": null, 459 | "metadata": { 460 | "collapsed": true 461 | }, 462 | "outputs": [], 463 | "source": [ 464 | "def squared(x):\n", 465 | " return x**2\n", 466 | "\n", 467 | "squared_function = fxc.register_function(squared)\n", 468 | "\n", 469 | "inputs = list(range(10))\n", 470 | "batch = fxc.create_batch()\n", 471 | "\n", 472 | "for x in inputs:\n", 473 | " batch.add(x, endpoint_id=tutorial_endpoint, function_id=squared_function)\n", 474 | " \n", 475 | "batch_res = fxc.batch_run(batch)" 476 | ] 477 | }, 478 | { 479 | "cell_type": "markdown", 480 | "metadata": {}, 481 | "source": [ 482 | "Similary, funcX provides an interface to retrieve the status of the entire batch of invocations. 
" 483 | ] 484 | }, 485 | { 486 | "cell_type": "code", 487 | "execution_count": null, 488 | "metadata": { 489 | "collapsed": true 490 | }, 491 | "outputs": [], 492 | "source": [ 493 | "fxc.get_batch_status(batch_res)" 494 | ] 495 | } 496 | ], 497 | "metadata": { 498 | "kernelspec": { 499 | "display_name": "Python 3", 500 | "language": "python", 501 | "name": "python3" 502 | }, 503 | "language_info": { 504 | "codemirror_mode": { 505 | "name": "ipython", 506 | "version": 3 507 | }, 508 | "file_extension": ".py", 509 | "mimetype": "text/x-python", 510 | "name": "python", 511 | "nbconvert_exporter": "python", 512 | "pygments_lexer": "ipython3", 513 | "version": "3.7.6" 514 | } 515 | }, 516 | "nbformat": 4, 517 | "nbformat_minor": 4 518 | } 519 | -------------------------------------------------------------------------------- /integration_tests/funcX.postman_collection.json: -------------------------------------------------------------------------------- 1 | { 2 | "info": { 3 | "_postman_id": "ab5e0e17-f1c6-454b-a6b1-42129437a9b4", 4 | "name": "funcX", 5 | "schema": "https://schema.getpostman.com/json/collection/v2.1.0/collection.json" 6 | }, 7 | "item": [ 8 | { 9 | "name": "batch_status", 10 | "request": { 11 | "method": "GET", 12 | "header": [], 13 | "url": { 14 | "raw": "" 15 | } 16 | }, 17 | "response": [] 18 | } 19 | ], 20 | "protocolProfileBehavior": {} 21 | } -------------------------------------------------------------------------------- /integration_tests/get_valid_token.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from funcx.sdk.client import FuncXClient 5 | 6 | fxc = FuncXClient( 7 | funcx_service_address="http://localhost:5000/api/v1", force_login=False 8 | ) 9 | token_file = os.path.join(fxc.TOKEN_DIR, fxc.TOKEN_FILENAME) 10 | with open(token_file) as f: 11 | data = json.load(f) 12 | print(data["funcx_service"]["access_token"]) 13 | 
-------------------------------------------------------------------------------- /integration_tests/integration_test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | 4 | from funcx.sdk.client import FuncXClient 5 | 6 | endpoint = "4a55d09b-a4c3-4f02-8e9e-4e5371d73b54" 7 | 8 | if len(sys.argv) > 0: 9 | endpoint = sys.argv[1] 10 | 11 | fxc = FuncXClient(funcx_service_address="http://localhost:5000/v1") 12 | container = fxc.register_container( 13 | "NCSA", "Docker", "test-image", "This is just a test" 14 | ) 15 | print("Registered container ", container) 16 | 17 | saved_container = fxc.get_container(container, "Docker") 18 | print("Saved container ", saved_container) 19 | 20 | 21 | def hello_world(): 22 | return "Hello World!" 23 | 24 | 25 | def hello_world2(): 26 | return "so it goes" 27 | 28 | 29 | func_uuid = fxc.register_function(hello_world) 30 | task_id = fxc.run(endpoint_id=endpoint, function_id=func_uuid) 31 | print(task_id) 32 | 33 | func_container_uuid = fxc.register_function(hello_world, container_uuid=container) 34 | print(func_uuid) 35 | 36 | 37 | print(fxc.add_to_whitelist(endpoint, [func_uuid, func_container_uuid])) 38 | print("White list", fxc.get_whitelist(endpoint)) 39 | 40 | print("Deleteing ", func_uuid, " from endpoint ", endpoint) 41 | fxc.delete_from_whitelist(endpoint, [func_uuid]) 42 | # This doesn't actually do anything! 
Empty implementation in SDK 43 | print("update function", fxc.update_function(func_uuid, hello_world2)) 44 | 45 | 46 | # Check batch requests are working 47 | def test_batch1(a, b, c=2, d=2): 48 | return a + b + c + d 49 | 50 | 51 | def test_batch2(a, b, c=2, d=2): 52 | return a * b * c * d 53 | 54 | 55 | def test_batch3(a, b, c=2, d=2): 56 | return a + 2 * b + 3 * c + 4 * d 57 | 58 | 59 | funcs = [test_batch1, test_batch2, test_batch3] 60 | func_ids = [] 61 | for func in funcs: 62 | func_ids.append(fxc.register_function(func, description="test")) 63 | 64 | start = time.time() 65 | task_count = 5 66 | batch = fxc.create_batch() 67 | for func_id in func_ids: 68 | for i in range(task_count): 69 | batch.add(i, i + 1, c=i + 2, d=i + 3, endpoint_id=endpoint, function_id=func_id) 70 | 71 | task_ids = fxc.batch_run(batch) 72 | 73 | delta = time.time() - start 74 | print(f"Time to launch {task_count * len(func_ids)} tasks: {delta:8.3f} s") 75 | print(f"Got {len(task_ids)} tasks_ids ") 76 | 77 | for _i in range(10): 78 | x = fxc.get_batch_status(task_ids) 79 | complete_count = sum( 80 | 1 for t in task_ids if t in x and not x[t].get("pending", True) 81 | ) 82 | print(f"Batch status : {complete_count}/{len(task_ids)} complete") 83 | if complete_count == len(task_ids): 84 | print(x) 85 | break 86 | time.sleep(5) 87 | 88 | 89 | # Verify exception deserialization 90 | def failing(): 91 | raise Exception("deterministic failure") 92 | 93 | 94 | failing_function = fxc.register_function(failing) 95 | 96 | res = fxc.run(endpoint_id=endpoint, function_id=failing_function) 97 | 98 | try: 99 | fxc.get_result(res) 100 | except Exception as e: 101 | print(e) 102 | 103 | 104 | # Check task status updates 105 | def funcx_sleep(val): 106 | import time 107 | 108 | time.sleep(int(val)) 109 | return "done" 110 | 111 | 112 | func_uuid = fxc.register_function(funcx_sleep, description="A sleep function") 113 | 114 | # check for pending status 115 | print("check pending") 116 | payload = 2 117 | 
res = fxc.run(payload, endpoint_id=endpoint, function_id=func_uuid)
print(res)
# Immediately asking for the result should raise "pending" since the sleep
# task has not finished yet.
try:
    print(fxc.get_result(res))
except Exception as e:
    print(e)
    # NOTE(review): this `pass` (and the ones below) are redundant no-ops
    # after the print calls.
    pass

# Check for done
print("check done")
time.sleep(3)
print(fxc.get_result(res))

# check for running
print("check running")
payload = 90
res = fxc.run(payload, endpoint_id=endpoint, function_id=func_uuid)
print(res)
time.sleep(60)
# After 60 s of a 90 s sleep the task should still be running, so both
# get_result calls below are expected to raise.
try:
    print(fxc.get_result(res))
except Exception as e:
    print(e)
    print("check still running")
    try:
        print(fxc.get_result(res))
    except Exception as e:
        print(e)
        pass
    pass

print("check done")
# 60 + 32 s > the 90 s sleep, so the result should now be available.
time.sleep(32)
print(fxc.get_result(res))
--------------------------------------------------------------------------------
/migrations/README:
--------------------------------------------------------------------------------
Generic single-database configuration.
--------------------------------------------------------------------------------
/migrations/alembic.ini:
--------------------------------------------------------------------------------
# A generic, single database configuration.

[alembic]
# template used to generate migration files
# file_template = %%(rev)s_%%(slug)s

# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false


# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = WARN
handlers = console
qualname =

[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
--------------------------------------------------------------------------------
/migrations/env.py:
--------------------------------------------------------------------------------
import logging
from logging.config import fileConfig

from alembic import context
from sqlalchemy import engine_from_config, pool

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config

# Interpret the config file for Python logging.
# This line sets up loggers basically.
fileConfig(config.config_file_name)
logger = logging.getLogger("alembic.env")

# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
# NOTE: the mid-file import is the Flask-Migrate convention -- this module
# is executed inside an app context, so current_app is available here.
from flask import current_app

# Point Alembic at the app's database; '%' is escaped because configparser
# treats it as an interpolation character.
config.set_main_option(
    "sqlalchemy.url",
    str(current_app.extensions["migrate"].db.engine.url).replace("%", "%%"),
)
target_metadata = current_app.extensions["migrate"].db.metadata

# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.


def run_migrations_offline():
    """Run migrations in 'offline' mode.

    This configures the context with just a URL
    and not an Engine, though an Engine is acceptable
    here as well. By skipping the Engine creation
    we don't even need a DBAPI to be available.

    Calls to context.execute() here emit the given string to the
    script output.

    """
    url = config.get_main_option("sqlalchemy.url")
    context.configure(url=url, target_metadata=target_metadata, literal_binds=True)

    with context.begin_transaction():
        context.run_migrations()


def run_migrations_online():
    """Run migrations in 'online' mode.

    In this scenario we need to create an Engine
    and associate a connection with the context.

    """

    # this callback is used to prevent an auto-migration from being generated
    # when there are no changes to the schema
    # reference: http://alembic.zzzcomputing.com/en/latest/cookbook.html
    def process_revision_directives(context, revision, directives):
        if getattr(config.cmd_opts, "autogenerate", False):
            script = directives[0]
            if script.upgrade_ops.is_empty():
                directives[:] = []
                logger.info("No changes in schema detected.")

    connectable = engine_from_config(
        config.get_section(config.config_ini_section),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
        pool_pre_ping=True,
    )

    with connectable.connect() as connection:
        context.configure(
            connection=connection,
            target_metadata=target_metadata,
            process_revision_directives=process_revision_directives,
            **current_app.extensions["migrate"].configure_args
        )

        with context.begin_transaction():
            context.run_migrations()


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
--------------------------------------------------------------------------------
/migrations/script.py.mako:
--------------------------------------------------------------------------------
"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
branch_labels = ${repr(branch_labels)}
depends_on = ${repr(depends_on)}


def upgrade():
    ${upgrades if upgrades else "pass"}


def downgrade():
    ${downgrades if downgrades else "pass"}
--------------------------------------------------------------------------------
/migrations/versions/v0.0.3_.py:
--------------------------------------------------------------------------------
"""empty message

Revision ID: v0.0.3
Revises:
Create Date: 2020-10-20 13:57:55.728947

"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "v0.0.3"
down_revision = None
branch_labels = None
depends_on = None


def upgrade():
    """Create the initial funcX schema: users, endpoints ('sites'),
    functions, containers, tasks and their association tables."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "auth_groups",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("group_id", sa.String(length=67), nullable=True),
        sa.Column("endpoint_id", sa.String(length=67), nullable=True),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_table(
        "users",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("username", sa.String(length=256), nullable=True),
        sa.Column("globus_identity", sa.String(length=256), nullable=True),
        sa.Column("created_at", sa.DateTime(), nullable=True),
        sa.Column("namespace", sa.String(length=1024), nullable=True),
        sa.Column("deleted", sa.Boolean(), nullable=True),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_table(
        "containers",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("author", sa.Integer(), nullable=True),
        sa.Column("container_uuid", sa.String(length=67), nullable=True),
        sa.Column("name", sa.String(length=1024), nullable=True),
        sa.Column("description", sa.Text(), nullable=True),
        sa.Column("created_at", sa.DateTime(), nullable=True),
        sa.Column("modified_at", sa.DateTime(), nullable=True),
        sa.ForeignKeyConstraint(
            ["author"],
            ["users.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_table(
        "functions",
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column("user_id", sa.Integer(), nullable=True),
        sa.Column("name", sa.String(length=1024), nullable=True),
        sa.Column("description", sa.Text(), nullable=True),
        sa.Column("status", sa.String(length=1024), nullable=True),
        sa.Column("function_name", sa.String(length=1024), nullable=True),
        sa.Column("function_uuid", sa.String(length=38), nullable=True),
        sa.Column("function_source_code", sa.Text(), nullable=True),
        sa.Column("timestamp", sa.DateTime(), nullable=True),
        sa.Column("entry_point", sa.String(length=38), nullable=True),
        sa.Column("modified_at", sa.DateTime(), nullable=True),
        sa.Column("deleted", sa.Boolean(), nullable=True),
        sa.Column("public", sa.Boolean(), nullable=True),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["users.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("function_uuid", name="unique_function_uuid"),
    )
    # 'sites' is the endpoints table (historical name).
    op.create_table(
        "sites",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("name", sa.String(length=256), nullable=True),
        sa.Column("description", sa.String(length=256), nullable=True),
        sa.Column("user_id", sa.Integer(), nullable=True),
        sa.Column("status", sa.String(length=10), nullable=True),
        sa.Column("endpoint_name", sa.String(length=256), nullable=True),
        sa.Column("endpoint_uuid", sa.String(length=38), nullable=True),
        sa.Column("public", sa.Boolean(), nullable=True),
        sa.Column("deleted", sa.Boolean(), nullable=True),
        sa.Column("ip_addr", sa.String(length=15), nullable=True),
        sa.Column("city", sa.String(length=256), nullable=True),
        sa.Column("region", sa.String(length=256), nullable=True),
        sa.Column("country", sa.String(length=256), nullable=True),
        sa.Column("zipcode", sa.String(length=10), nullable=True),
        sa.Column("latitude", sa.Float(), nullable=True),
        sa.Column("longitude", sa.Float(), nullable=True),
        sa.Column("core_hours", sa.Float(), nullable=True),
        sa.Column("hostname", sa.String(length=256), nullable=True),
        sa.Column("org", sa.String(length=256), nullable=True),
        sa.Column("restricted", sa.Boolean(), nullable=True),
        sa.Column("created_at", sa.DateTime(), nullable=True),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["users.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("endpoint_uuid", name="unique_endpoint_uuid"),
    )
    op.create_table(
        "container_images",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("container_id", sa.Integer(), nullable=True),
        sa.Column("type", sa.String(length=256), nullable=True),
        sa.Column("location", sa.String(length=1024), nullable=True),
        sa.Column("created_at", sa.DateTime(), nullable=True),
        sa.Column("modified_at", sa.DateTime(), nullable=True),
        sa.ForeignKeyConstraint(
            ["container_id"],
            ["containers.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_table(
        "function_auth_groups",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("group_id", sa.Integer(), nullable=True),
        sa.Column("function_id", sa.Integer(), nullable=True),
        sa.ForeignKeyConstraint(
            ["function_id"],
            ["functions.id"],
        ),
        sa.ForeignKeyConstraint(
            ["group_id"],
            ["auth_groups.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_table(
        "function_containers",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("container_id", sa.Integer(), nullable=True),
        sa.Column("function_id", sa.Integer(), nullable=True),
        sa.Column("created_at", sa.DateTime(), nullable=True),
        sa.Column("modified_at", sa.DateTime(), nullable=True),
        sa.ForeignKeyConstraint(
            ["container_id"],
            ["containers.id"],
        ),
        sa.ForeignKeyConstraint(
            ["function_id"],
            ["functions.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_table(
        "restricted_endpoint_functions",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("endpoint_id", sa.Integer(), nullable=True),
        sa.Column("function_id", sa.Integer(), nullable=True),
        sa.ForeignKeyConstraint(
            ["endpoint_id"],
            ["sites.id"],
        ),
        sa.ForeignKeyConstraint(
            ["function_id"],
            ["functions.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_table(
        "tasks",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("user_id", sa.Integer(), nullable=True),
        sa.Column("task_uuid", sa.String(length=38), nullable=True),
        sa.Column("status", sa.String(length=10), nullable=True),
        sa.Column("created_at", sa.DateTime(), nullable=True),
        sa.Column("modified_at", sa.DateTime(), nullable=True),
        sa.Column("endpoint_id", sa.String(length=38), nullable=True),
        sa.Column("function_id", sa.String(length=38), nullable=True),
        sa.ForeignKeyConstraint(
            ["endpoint_id"],
            ["sites.endpoint_uuid"],
        ),
        sa.ForeignKeyConstraint(
            ["function_id"],
            ["functions.function_uuid"],
        ),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["users.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    # ### end Alembic commands ###


def downgrade():
    """Drop all tables in reverse dependency order."""
    # ### commands auto generated by Alembic - please adjust! ###
### 194 | op.drop_table("tasks") 195 | op.drop_table("restricted_endpoint_functions") 196 | op.drop_table("function_containers") 197 | op.drop_table("function_auth_groups") 198 | op.drop_table("container_images") 199 | op.drop_table("sites") 200 | op.drop_table("functions") 201 | op.drop_table("containers") 202 | op.drop_table("users") 203 | op.drop_table("auth_groups") 204 | # ### end Alembic commands ### 205 | -------------------------------------------------------------------------------- /migrations/versions/v0.2.0_.py: -------------------------------------------------------------------------------- 1 | """empty message 2 | 3 | Revision ID: v0.2.0 4 | Revises: v0.0.3 5 | Create Date: 2021-05-10 13:23:28.011569 6 | 7 | """ 8 | import sqlalchemy as sa 9 | from alembic import op 10 | 11 | # revision identifiers, used by Alembic. 12 | from sqlalchemy import Column, String 13 | 14 | revision = "v0.2.0" 15 | down_revision = "v0.0.3" 16 | branch_labels = None 17 | depends_on = None 18 | 19 | 20 | def upgrade(): 21 | # ### commands auto generated by Alembic - please adjust! ### 22 | op.drop_constraint( 23 | "function_auth_groups_group_id_fkey", "function_auth_groups", type_="foreignkey" 24 | ) 25 | op.drop_column("function_auth_groups", "group_id") 26 | op.add_column("function_auth_groups", Column("group_id", String(38))) 27 | # ### end Alembic commands ### 28 | 29 | 30 | def downgrade(): 31 | # ### commands auto generated by Alembic - please adjust! 
### 32 | op.drop_column("function_auth_groups", "group_id") 33 | op.add_column( 34 | "function_auth_groups", Column("group_id", sa.Integer(), nullable=True) 35 | ) 36 | op.create_foreign_key( 37 | "function_auth_groups_group_id_fkey", 38 | "function_auth_groups", 39 | "auth_groups", 40 | ["group_id"], 41 | ["id"], 42 | ) 43 | # ### end Alembic commands ### 44 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | ignore_missing_imports = true 3 | # desired conf (do not set yet): 4 | # 5 | # strict = true 6 | # warn_unreachable = true 7 | # warn_no_return = true 8 | -------------------------------------------------------------------------------- /requirements.in: -------------------------------------------------------------------------------- 1 | # application requirements as a requirements.txt file, specifying only the 2 | # things which are needed by the app itself (first-order requirements) 3 | # 4 | # freeze with `pipdeptree -f` to get structured requirements.txt data for use 5 | # with `pip install` 6 | # 7 | # note that a requirement which is satisfied by a second-order requirement 8 | # (e.g. globus-sdk=>requests) should also be listed here, in case the upstream 9 | # libraries change and alter their dependencies 10 | # 11 | # tl,dr: If it's used by the application, it goes in this list. 
12 | 13 | # flask 14 | Flask<3 15 | Werkzeug<3 16 | 17 | # flask plugins 18 | flask-sqlalchemy<3 19 | Flask-Migrate<3 20 | 21 | # sqlalchemy + psycopg2 for postgres connections 22 | sqlalchemy<1.5 23 | psycopg2-binary==2.8.5 24 | 25 | # redis connections 26 | redis==3.5.3 27 | 28 | # funcx tools 29 | funcx-common[redis,boto3]==0.0.11 30 | 31 | # globus clients 32 | globus-nexus-client==0.3.0 33 | globus-sdk<3 34 | 35 | 36 | 37 | requests>=2.24,<3 38 | python-json-logger<3 39 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # frozen requirements; generate with 'make freezedeps' 2 | alembic==1.7.5 3 | boto3==1.20.28 4 | botocore==1.23.28 5 | certifi==2021.10.8 6 | cffi==1.15.0 7 | charset-normalizer==2.0.10 8 | click==8.0.3 9 | cryptography==36.0.1 10 | Flask==2.0.2 11 | Flask-Migrate==2.7.0 12 | Flask-SQLAlchemy==2.5.1 13 | funcx-common==0.0.11 14 | globus-nexus-client==0.3.0 15 | globus-sdk==2.0.3 16 | greenlet==1.1.2 17 | idna==3.3 18 | importlib-metadata==4.10.0 19 | importlib-resources==5.4.0 20 | itsdangerous==2.0.1 21 | Jinja2==3.0.3 22 | jmespath==0.10.0 23 | Mako==1.1.6 24 | MarkupSafe==2.0.1 25 | psycopg2-binary==2.8.5 26 | pycparser==2.21 27 | PyJWT==1.7.1 28 | python-dateutil==2.8.2 29 | python-json-logger==2.0.2 30 | redis==3.5.3 31 | requests==2.27.0 32 | s3transfer==0.5.0 33 | six==1.16.0 34 | SQLAlchemy==1.4.29 35 | typing-extensions==4.0.1 36 | urllib3==1.26.7 37 | Werkzeug==2.0.2 38 | zipp==3.7.0 39 | -------------------------------------------------------------------------------- /requirements_test.txt: -------------------------------------------------------------------------------- 1 | flake8>=3.8 2 | pytest<7 3 | pytest-cov<3 4 | pytest-flask==1.0.0 5 | coverage>=5.2 6 | codecov==2.1.8 7 | pytest-mock==3.2.0 8 | responses==0.14.0 9 | fakeredis<2 10 | moto[s3]<3 11 | 
--------------------------------------------------------------------------------
/scripts/store_endpoint_info.py:
--------------------------------------------------------------------------------
"""Copy per-endpoint location/usage metadata from Redis into the 'sites' table."""
import json
import os
import uuid

import psycopg2
import psycopg2.extras
import redis


def get_endpoint_info():
    """Connect to redis and get endpoint info.

    Scans all `ep_status_*` keys, pairs each with its `endpoint:<uuid>`
    metadata hash, and returns a list of per-endpoint dicts.
    """
    REDIS_HOST = os.environ.get("redis_host")
    REDIS_PORT = os.environ.get("redis_port")
    rc = redis.StrictRedis(REDIS_HOST, port=REDIS_PORT, decode_responses=True)

    redis_data = []

    for key in rc.keys("ep_status_*"):
        try:
            print(f"Getting key {key}")
            # Everything after the "ep_status_" prefix should be the endpoint UUID.
            endpoint_id = key.split("ep_status_")[1]
            try:
                uuid.UUID(str(endpoint_id))
            except ValueError:
                # Key suffix is not a UUID; skip malformed entries.
                print("skipping ep:", key)
                continue

            # Most recent status entry for this endpoint (head of the list).
            items = rc.lrange(key, 0, 0)
            if items:
                last = json.loads(items[0])
            else:
                continue
            # NOTE(review): ep_id duplicates endpoint_id above (UUIDs contain
            # no underscores, so split("_")[2] yields the same value).
            ep_id = key.split("_")[2]
            ep_meta = rc.hgetall(f"endpoint:{ep_id}")
            lat, lon = ep_meta["loc"].split(",")
            current = ep_meta
            current["endpoint_id"] = endpoint_id
            current["core_hours"] = last["total_core_hrs"]
            current["latitude"] = lat
            current["longitude"] = lon
            if "hostname" not in current:
                current["hostname"] = None
            redis_data.append(current)

        except Exception as e:
            # Deliberate best-effort: one bad endpoint record should not
            # abort the whole scan.
            print(f"Failed to parse for key {key}")
            print(f"Error : {e}")

    return redis_data


def store_data(ep_data, conn, cur):
    """Insert the info into the usage table

    data struct:
    {'ip': '140.221.68.107', 'hostname': 'cooleylogin1.cooley.pub.alcf.anl.gov',
    'city': 'New York City', 'region': 'New York', 'country': 'US',
    'loc': '40.7143,-74.0060', 'org': 'AS683 Argonne National Lab', 'postal': '10004',
    'timezone': 'America/New_York', 'readme': 'https://ipinfo.io/missingauth',
    'core_hours': 108.97, 'latitude': '40.7143', 'longitude': '-74.0060'}

    """
    for data in ep_data:
        # Parameterized UPDATE keyed on the endpoint UUID.
        query = (
            "update sites set latitude = %s, longitude = %s, ip_addr = %s, city = %s, "
            "region = %s, country = %s, "
            "zipcode = %s, hostname = %s, org = %s, core_hours = %s "
            "where endpoint_uuid = %s "
        )
        cur.execute(
            query,
            (
                data["latitude"],
                data["longitude"],
                data["ip"],
                data["city"],
                data["region"],
                data["country"],
                data["postal"],
                data["hostname"],
                data["org"],
                data["core_hours"],
                data["endpoint_id"],
            ),
        )

    # Single commit after all updates.
    conn.commit()


def get_info():
    """Extract usage data from the database and redis then store it in the database"""
    DB_HOST = os.environ.get("db_host")
    DB_USER = os.environ.get("db_user")
    DB_NAME = os.environ.get("db_name")
    DB_PASSWORD = os.environ.get("db_password")

    con_str = f"dbname={DB_NAME} user={DB_USER} password={DB_PASSWORD} host={DB_HOST}"

    conn = psycopg2.connect(con_str)
    cur = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)

    data = get_endpoint_info()

    store_data(data, conn, cur)
    print("done")


if __name__ == "__main__":
    get_info()
--------------------------------------------------------------------------------
/scripts/store_usage.py:
--------------------------------------------------------------------------------
"""Snapshot aggregate usage metrics from Postgres and Redis into usage_info."""
import os

import psycopg2
import psycopg2.extras
import redis


def rds_usage(conn, cur):
    """Connect to the database and pull out usage info"""
    db_data = {}
    # Number of users
    query = "select count(*) from users"
    cur.execute(query)
    row = cur.fetchone()
    if row and "count" in row:
        db_data["users"] = row["count"]

    # Number of endpoints
    query = "select count(*) from sites"
    cur.execute(query)
    row = cur.fetchone()
    if row and "count" in row:
        db_data["endpoints"] = row["count"]

    # Number of functions
    query = "select count(*) from functions"
    cur.execute(query)
    row = cur.fetchone()
    if row and "count" in row:
        db_data["functions"] = row["count"]

    # Active endpoints, users, functions in last day
    query = (
        "select count(distinct function_id) as functions, "
        "count(distinct user_id) as users, "
        "count(distinct endpoint_id) as endpoints "
        "from tasks "
        "WHERE created_at > current_date - interval '1' day; "
    )
    cur.execute(query)
    row = cur.fetchone()
    if row:
        db_data["endpoints_day"] = row["endpoints"]
        db_data["functions_day"] = row["functions"]
        db_data["users_day"] = row["users"]

    # Active endpoints, users, functions in last week
    query = (
        "select count(distinct function_id) as functions, "
        "count(distinct user_id) as users, count(distinct "
        "endpoint_id) as endpoints from tasks "
        "WHERE created_at > current_date - interval '7' day; "
    )
    cur.execute(query)
    row = cur.fetchone()
    if row:
        db_data["endpoints_week"] = row["endpoints"]
        db_data["functions_week"] = row["functions"]
        db_data["users_week"] = row["users"]

    # Active things this month
    query = (
        "select count(distinct function_id) as functions, "
        "count(distinct user_id) as users, count(distinct "
        "endpoint_id) as endpoints from tasks "
        "WHERE created_at >= date_trunc('month', CURRENT_DATE); "
    )
    cur.execute(query)
    row = cur.fetchone()
    if row:
        db_data["endpoints_month"] = row["endpoints"]
        db_data["functions_month"] = row["functions"]
        db_data["users_month"] = row["users"]

    return db_data


def redis_usage():
    """Connect to redis and get counters"""
    REDIS_HOST = os.environ.get("redis_host")
    REDIS_PORT = os.environ.get("redis_port")
    rc = redis.StrictRedis(REDIS_HOST, port=REDIS_PORT, decode_responses=True)

    redis_data = {}
    # Total core hours
    # NOTE(review): rc.get returns None when the counter key is missing;
    # store_data() below calls float()/int() on these values and would raise.
    redis_data["core_hours"] = rc.get("funcx_worldwide_counter")
    # Total function invocations
    redis_data["invocations"] = rc.get("funcx_invocation_counter")
    return redis_data


def store_data(data, conn, cur):
    """Insert the info into the usage table

    DB STRUCTURE:
    total_functions int,
    total_endpoints int,
    total_users int,
    total_core_hours float,
    total_invocations int,
    functions_day int,
    functions_week int,
    functions_month int,
    endpoints_day int,
    endpoints_week int,
    endpoints_month int,
    users_day int,
    users_week int,
    users_month int
    """
    # Parameterized INSERT; column order must match the tuple below.
    query = (
        "insert into usage_info (total_functions, total_endpoints, total_users, "
        "total_core_hours, total_invocations, functions_day, functions_week, "
        "functions_month, endpoints_day, endpoints_week, endpoints_month, users_day, "
        "users_week, users_month) values (%s, %s, %s, %s, "
        "%s, %s, %s, %s, %s, %s, %s, %s, %s, %s);"
    )
    cur.execute(
        query,
        (
            data["functions"],
            data["endpoints"],
            data["users"],
            float(data["core_hours"]),
            int(data["invocations"]),
            data["functions_day"],
            data["functions_week"],
            data["functions_month"],
            data["endpoints_day"],
            data["endpoints_week"],
            data["endpoints_month"],
            data["users_day"],
            data["users_week"],
            data["users_month"],
        ),
    )
    conn.commit()


def record_usage():
    """Extract usage data from the database and redis then store it in the database"""
    DB_HOST = os.environ.get("db_host")
    DB_USER = os.environ.get("db_user")
    DB_NAME = os.environ.get("db_name")
    DB_PASSWORD = os.environ.get("db_password")

    con_str = f"dbname={DB_NAME} user={DB_USER} password={DB_PASSWORD} host={DB_HOST}"

conn = psycopg2.connect(con_str) 150 | cur = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) 151 | 152 | data = rds_usage(conn, cur) 153 | redis_data = redis_usage() 154 | 155 | # Combine them together 156 | data.update(redis_data) 157 | print(data) 158 | 159 | store_data(data, conn, cur) 160 | print("done") 161 | 162 | 163 | if __name__ == "__main__": 164 | record_usage() 165 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/funcx-faas/funcx-web-service/5714d5ab889396a72d81f25cf17e9523a9ecc82b/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import typing as t 3 | import uuid 4 | 5 | import boto3 6 | import fakeredis 7 | import flask 8 | import moto 9 | import pytest 10 | import responses 11 | 12 | from funcx_web_service import create_app 13 | from funcx_web_service.models import db 14 | from funcx_web_service.models.user import User 15 | 16 | TEST_FORWARDER_IP = "192.162.3.5" 17 | DEFAULT_FUNCX_SCOPE = ( 18 | "https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all" 19 | ) 20 | 21 | 22 | class FakeAuthState: 23 | """A fake object to replace the AuthenticationState during tests.""" 24 | 25 | def __init__( 26 | self, 27 | *, 28 | user: t.Optional[User], 29 | scope: t.Optional[str], 30 | introspect_data: t.Optional[dict], 31 | ): 32 | self.is_authenticated = user is not None 33 | self.user_object = user 34 | self.username = user.username if user is not None else None 35 | self.identity_id = user.globus_identity if user is not None else None 36 | self.scopes = {scope} 37 | 38 | if introspect_data is None: 39 | if self.is_authenticated: 40 | self.introspect_data = { 41 | "active": True, 42 | "username": 
@pytest.fixture
def mocked_responses():
    """Yield an active ``responses.RequestsMock`` for stubbing outbound HTTP."""
    with responses.RequestsMock() as requests_mock:
        yield requests_mock


@pytest.fixture
def mock_redis_server():
    """Provide a fresh in-memory fakeredis server for each test."""
    return fakeredis.FakeServer()


@pytest.fixture
def mock_redis_pubsub(mocker):
    """Patch the funcx route module's pubsub accessor; return the stand-in."""
    pubsub_stub = mocker.Mock()
    mocker.patch(
        "funcx_web_service.routes.funcx.g_redis_pubsub", return_value=pubsub_stub
    )
    return pubsub_stub


@pytest.fixture
def mock_redis(mocker, mock_redis_server):
    """Patch the funcx route module's redis client getter with a fakeredis client."""
    fake_client = fakeredis.FakeStrictRedis(
        server=mock_redis_server, decode_responses=True
    )
    mocker.patch(
        "funcx_web_service.routes.funcx.get_redis_client",
        return_value=fake_client,
    )
    return fake_client
@pytest.fixture
def flask_app_ctx(flask_app):
    """Push a Flask application context for the duration of a test."""
    with flask_app.app_context() as app_ctx:
        yield app_ctx


@pytest.fixture
def flask_request_ctx(flask_app, flask_app_ctx):
    """Push a Flask request context (on top of the shared app context)."""
    with flask_app.test_request_context() as request_ctx:
        yield request_ctx


@pytest.fixture
def flask_test_client(flask_app, flask_app_ctx):
    """Return a Flask test client bound to the shared test app."""
    return flask_app.test_client()


@pytest.fixture
def enable_mock_container_service(flask_app, mocker):
    """Return a context manager which temporarily installs a mocked ContainerService.

    The mock reports version "3.14"; the service is removed again on exit.
    """
    mock_container_service = mocker.Mock()
    mock_container_service.get_version = mocker.Mock(return_value={"version": "3.14"})

    @contextlib.contextmanager
    def func():
        flask_app.extensions["ContainerService"] = mock_container_service
        flask_app.config["CONTAINER_SERVICE_ENABLED"] = True
        try:
            yield
        finally:
            # BUG FIX: the original skipped this cleanup when the test body
            # raised, leaking the enabled container service into later tests.
            flask_app.extensions["ContainerService"] = None
            flask_app.config["CONTAINER_SERVICE_ENABLED"] = False

    return func


@pytest.fixture
def mock_s3_bucket(monkeypatch):
    """Create a moto-backed S3 bucket and export its name via the environment."""
    bucket = "funcx-web-service-test-bucket"
    monkeypatch.setenv("FUNCX_S3_BUCKET_NAME", bucket)
    with moto.mock_s3():
        client = boto3.client("s3")
        client.create_bucket(Bucket=bucket)
        yield bucket


@pytest.fixture
def mock_user_identity_id():
    """A random Globus identity id for the default mock user."""
    return str(uuid.uuid1())


@pytest.fixture
def mock_user(flask_app_ctx, mock_user_identity_id):
    """The default test User record."""
    return User(username="foo-user", globus_identity=mock_user_identity_id, id=22)
@pytest.fixture
def in_mock_auth_state(mock_auth_state):
    """
    Variant of mock_auth_state which enters the context manager automatically,
    so the default mocked auth state is active for the entire test.
    """
    with mock_auth_state():
        yield


@pytest.fixture
def default_forwarder_responses(mocked_responses):
    """Register stub replies for the forwarder's /version and /register endpoints."""
    mocked_responses.add(
        responses.GET,
        f"http://{TEST_FORWARDER_IP}:8080/version",
        status=200,
        json={
            "forwarder": "0.3.5",
            "min_ep_version": "0.0.1",
        },
    )
    mocked_responses.add(
        responses.POST,
        f"http://{TEST_FORWARDER_IP}:8080/register",
        status=200,
        body="{}",
    )
@responses.activate
def test_version():
    """A healthy container service's version payload is returned verbatim."""
    responses.add(
        responses.GET,
        "http://container-service:5000/version",
        json={"version": "3.14"},
        status=200,
    )
    adapter = ContainerServiceAdapter("http://container-service:5000")
    assert adapter.get_version() == {"version": "3.14"}


@responses.activate
def test_version_server_error():
    """A 500 from the container service maps to a 'Service Unavailable' marker."""
    responses.add(responses.GET, "http://container-service:5000/version", status=500)
    adapter = ContainerServiceAdapter("http://container-service:5000")
    assert adapter.get_version() == {"version": "Service Unavailable"}
# Canonical Globus Auth token-introspection payload used by the fixtures below.
INTROSPECT_RESPONSE = {
    "active": True,
    "scope": "https://auth.globus.org/scopes/facd7ccc-c5f4-42aa-916b-a0e270e2c2a9/all",
    "sub": "79cb54bb-2296-424a-9ab2-8dabcf1457ff",
    "username": "example@globus.org",
    "name": None,
    "email": None,
    "client_id": "facd7ccc-c5f4-42aa-916b-a0e270e2c2a9",
    "aud": ["facd7ccc-c5f4-42aa-916b-a0e270e2c2a9"],
    "iss": "https://auth.globus.org",
    "exp": 1915236549,
    "iat": 1599703671,
    "nbf": 1599703671,
    "identity_set": ["79cb54bb-2296-424a-9ab2-8dabcf1457ff"],
}


@pytest.fixture
def good_introspect(mocked_responses):
    """Stub the Globus introspect endpoint with a valid, correctly-scoped token."""
    mocked_responses.add(
        responses.POST,
        "https://auth.globus.org/v2/oauth2/token/introspect",
        json=INTROSPECT_RESPONSE,
        status=200,
    )


@pytest.fixture
def badscope_introspect(mocked_responses):
    """Stub the Globus introspect endpoint with a valid token missing its scope."""
    payload = dict(INTROSPECT_RESPONSE)
    payload["scope"] = ""
    mocked_responses.add(
        responses.POST,
        "https://auth.globus.org/v2/oauth2/token/introspect",
        json=payload,
        status=200,
    )


def test_get_auth_state_no_authz_header(flask_app, flask_app_ctx):
    """Without an Authorization header the auth state must be unauthenticated."""
    # A customized flask request context is used here to exercise the
    # get_auth_state codepath; normally tests should just construct an
    # AuthenticationState directly.
    with flask_app.test_request_context(headers={}):
        state = get_auth_state()
        assert isinstance(state, AuthenticationState)
        assert state.is_authenticated is False

        with pytest.raises(Unauthorized):
            state.assert_is_authenticated()
def test_auth_state_bad_scope(flask_request_ctx, badscope_introspect):
    """A valid token without the funcx scope authenticates but fails the scope check."""
    state = AuthenticationState("foo")
    assert isinstance(state, AuthenticationState)
    assert state.is_authenticated is True
    assert state.username == INTROSPECT_RESPONSE["username"]
    assert state.identity_id == INTROSPECT_RESPONSE["sub"]

    # TODO: is this right? seems like this should be a class of 401
    with pytest.raises(Forbidden):
        state.assert_has_default_scope()


def test_auth_state_user_object(flask_request_ctx, good_introspect):
    """Fetching a user object from the AuthenticationState yields a matching User."""
    state = AuthenticationState("foo")

    user_record = state.user_object
    assert user_record is not None
    assert isinstance(user_record, User)
    assert user_record.username == state.username
def test_authorize_endpoint_restricted_whitelist(mocker):
    """
    A restricted endpoint authorizes a request whose function is on the
    endpoint's whitelist.
    """
    authorize_endpoint.cache_clear()
    mock_endpoint_find = mocker.patch.object(
        Endpoint,
        "find_by_uuid",
        return_value=Endpoint(
            public=True,
            restricted=True,
            restricted_functions=[Function(function_uuid="123")],
        ),
    )

    result = authorize_endpoint(
        user_id="test_user",
        endpoint_uuid="123-45-566",
        function_uuid="123",
        token="ttttt",
    )
    assert result
    mock_endpoint_find.assert_called_with("123-45-566")


def test_authorize_endpoint_restricted_not_whitelist(mocker):
    """
    A restricted endpoint rejects a request whose function is NOT on the
    endpoint's whitelist.
    """
    authorize_endpoint.cache_clear()

    mock_endpoint_find = mocker.patch.object(
        Endpoint,
        "find_by_uuid",
        return_value=Endpoint(
            public=True,
            restricted=True,
            restricted_functions=[Function(function_uuid="456")],
        ),
    )

    # NOTE(review): Exception is broad; narrow this once the concrete error
    # type raised by authorize_endpoint is confirmed.
    with pytest.raises(Exception):
        authorize_endpoint(
            user_id="test_user",
            endpoint_uuid="123-45-566",
            function_uuid="123",
            token="ttttt",
        )
    # BUG FIX: removed a leftover debugging `print(excinfo)` from the original;
    # the captured excinfo was never asserted on.
    mock_endpoint_find.assert_called_with("123-45-566")
def test_authorize_endpoint_user(mocker):
    """The owning user is authorized on a private, unrestricted endpoint."""
    authorize_endpoint.cache_clear()

    ep_find = mocker.patch.object(
        Endpoint,
        "find_by_uuid",
        return_value=Endpoint(public=False, restricted=False, user_id=42),
    )
    assert authorize_endpoint(
        user_id=42, endpoint_uuid="123-45-566", function_uuid="123", token="ttttt"
    )
    ep_find.assert_called_with("123-45-566")


def test_authorize_endpoint_group(mocker):
    """A non-owner is authorized when the Globus group-membership check passes."""
    authorize_endpoint.cache_clear()

    ep_find = mocker.patch.object(
        Endpoint,
        "find_by_uuid",
        return_value=Endpoint(public=False, restricted=False, user_id=1),
    )
    group_find = mocker.patch.object(
        AuthGroup,
        "find_by_endpoint_uuid",
        return_value=[AuthGroup(group_id="my-group", endpoint_id="123-45-566")],
    )
    membership_check = mocker.patch.object(
        funcx_web_service.authentication.auth,
        "check_group_membership",
        return_value=True,
    )

    assert authorize_endpoint(
        user_id=42, endpoint_uuid="123-45-566", function_uuid="123", token="ttttt"
    )
    ep_find.assert_called_with("123-45-566")
    group_find.assert_called_with("123-45-566")
    membership_check.assert_called_with("ttttt", ["my-group"])
def test_authorize_function_user_owns(mocker):
    """The owner of a private function is authorized to run it."""
    authorize_function.cache_clear()

    fn_find = mocker.patch.object(
        Function, "find_by_uuid", return_value=Function(public=False, user_id=44)
    )
    assert authorize_function(user_id=44, function_uuid="123", token="ttttt")
    fn_find.assert_called_with("123")


def test_authorize_function_public(mocker):
    """Any user is authorized to run a public function."""
    authorize_function.cache_clear()

    fn_find = mocker.patch.object(
        Function, "find_by_uuid", return_value=Function(public=True, user_id=1)
    )
    assert authorize_function(user_id=44, function_uuid="123", token="ttttt")
    fn_find.assert_called_with("123")
@pytest.fixture
def mock_user():
    """Default user record for route tests."""
    return User(username="bob", globus_identity="123-456", id=22)


@pytest.fixture
def mock_endpoint():
    """A restricted endpoint with an empty function whitelist."""
    return Endpoint(
        user_id=1,
        endpoint_uuid="11111111-2222-3333-4444-555555555555",
        restricted=True,
        restricted_functions=[],
    )


@pytest.fixture
def mock_redis_task_factory(mock_redis, mock_user, mocker):
    """Return a factory which writes RedisTask records into the fake redis.

    ``status`` and ``internal_status`` default to WAITING_FOR_EP and INCOMPLETE
    respectively when not supplied.
    """

    def func(task_id, user_id=mock_user.id, status=None, internal_status=None):
        t = RedisTask(mock_redis, task_id=task_id)
        t.user_id = user_id
        # BUG FIX: the original accepted status/internal_status but ignored
        # them, always storing the defaults; honor them when callers pass them.
        t.status = TaskState.WAITING_FOR_EP if status is None else status
        t.internal_status = (
            InternalTaskState.INCOMPLETE if internal_status is None else internal_status
        )
        return t

    return func
def test_stats(flask_test_client, mocker, mock_redis):
    """/stats reports the invocation counter stored in redis."""
    mock_redis.set("funcx_invocation_counter", 1024)
    redis_get_spy = mocker.spy(mock_redis, "get")

    response = flask_test_client.get("/api/v1/stats")
    assert response.status_code == 200
    assert response.json["total_function_invocations"] == 1024

    redis_get_spy.assert_called_once_with("funcx_invocation_counter")


def test_stats_malformed_underlying_data(flask_test_client, mocker, mock_redis):
    """/stats fails with a 500 when the redis counter is not an integer."""
    mock_redis.set("funcx_invocation_counter", "foo")
    redis_get_spy = mocker.spy(mock_redis, "get")

    response = flask_test_client.get("/api/v1/stats")
    assert response.status_code == 500
    assert b"Unable to get invocation count" in response.data

    redis_get_spy.assert_called_once_with("funcx_invocation_counter")
def test_register_container(flask_test_client, mocker, in_mock_auth_state):
    """Registering a container persists the record and its single image."""
    payload = {
        "name": "myContainer",
        "function_name": "test fun",
        "description": "this is a test",
        "type": "docker",
        "location": "http://hub.docker.com/myContainer",
    }
    result = flask_test_client.post(
        "v2/containers", json=payload, headers={"Authorization": "my_token"}
    )
    assert result.status_code == 200
    assert "container_id" in result.json
    container_uuid = result.json["container_id"]

    stored = Container.find_by_uuid(container_uuid)
    assert stored
    assert stored.name == "myContainer"
    assert stored.container_uuid == container_uuid
    assert stored.description == "this is a test"

    assert stored.images
    assert len(stored.images) == 1
    image = stored.images[0]
    assert image.type == "docker"
    assert image.location == "http://hub.docker.com/myContainer"


def test_register_container_invalid_spec(flask_test_client, mocker, in_mock_auth_state):
    """Registration without a name is rejected with a REQUEST_KEY_ERROR."""
    result = flask_test_client.post(
        "v2/containers",
        json={
            "type": "docker",
            "location": "http://hub.docker.com/myContainer",
        },
        headers={"Authorization": "my_token"},
    )
    assert result.status_code == 400
    assert result.json["status"] == "Failed"
    assert result.json["code"] == int(ResponseErrorCode.REQUEST_KEY_ERROR)
    assert result.json["reason"] == "Missing key in JSON request - 'name'"
@responses.activate
def test_register_endpoint(flask_test_client, mocker, in_mock_auth_state, mock_user):
    """A well-formed registration hits the forwarder and returns 200."""
    responses.add(
        responses.GET,
        "http://192.162.3.5:8080/version",
        json={"forwarder": "1.2.3", "min_ep_version": "3.2.1"},
        status=200,
    )
    responses.add(
        responses.POST,
        "http://192.162.3.5:8080/register",
        json={"forwarder": "1.1", "min_ep_version": "1.2"},
        status=200,
    )

    register_ep_mock = mocker.patch(
        "funcx_web_service.routes.funcx.register_endpoint",
        return_value="123-45-6789-1011",
    )

    result = flask_test_client.post(
        "api/v1/endpoints",
        json={
            "version": "3.2.2",
            "endpoint_name": "my-endpoint",
            "endpoint_uuid": None,
        },
        headers={"Authorization": "my_token"},
    )

    # The service must check the forwarder version, then register with it.
    assert len(responses.calls) == 2
    assert responses.calls[0].request.url == "http://192.162.3.5:8080/version"

    register_ep_mock.assert_called_with(
        mock_user, "my-endpoint", "", endpoint_uuid=None
    )

    assert responses.calls[1].request.url == "http://192.162.3.5:8080/register"
    assert json.loads(responses.calls[1].request.body) == {
        "endpoint_id": "123-45-6789-1011",
        "redis_address": "my-redis.com",
        "endpoint_addr": "127.0.0.1",
    }

    assert result.status_code == 200


def test_register_endpoint_version_mismatch(
    flask_test_client, mocker, in_mock_auth_state
):
    """An endpoint older than the forwarder minimum is refused as outdated."""
    forwarder_version = mocker.patch(
        "funcx_web_service.routes.funcx.get_forwarder_version",
        return_value={"min_ep_version": "1.42.0"},
    )

    result = flask_test_client.post(
        "api/v1/endpoints",
        json={"version": "1.0.0"},
        headers={"Authorization": "my_token"},
    )
    forwarder_version.assert_called()
    assert result.status_code == 400
    assert result.json["status"] == "Failed"
    assert result.json["code"] == int(ResponseErrorCode.ENDPOINT_OUTDATED)
    assert "Endpoint is out of date." in result.json["reason"]


def test_register_endpoint_no_version(flask_test_client, mocker, in_mock_auth_state):
    """Registration without a version field is rejected with a key error."""
    forwarder_version = mocker.patch(
        "funcx_web_service.routes.funcx.get_forwarder_version",
        return_value={"min_ep_version": "1.42.0"},
    )

    result = flask_test_client.post(
        "api/v1/endpoints", json={}, headers={"Authorization": "my_token"}
    )
    forwarder_version.assert_called()
    assert result.status_code == 400
    assert result.json["status"] == "Failed"
    assert result.json["code"] == int(ResponseErrorCode.REQUEST_KEY_ERROR)
    assert "version must be passed in" in result.json["reason"]
def test_register_endpoint_already_registered(
    flask_test_client, mocker, in_mock_auth_state, mock_endpoint
):
    """Re-registering another user's endpoint uuid fails with a clear error."""
    forwarder_version = mocker.patch(
        "funcx_web_service.routes.funcx.get_forwarder_version",
        return_value={"min_ep_version": "1.0.0"},
    )
    mocker.patch.object(Endpoint, "find_by_uuid", return_value=mock_endpoint)

    result = flask_test_client.post(
        "api/v1/endpoints",
        json={
            "version": "1.0.0",
            "endpoint_name": "my-endpoint",
            "endpoint_uuid": "11111111-2222-3333-4444-555555555555",
        },
        headers={"Authorization": "my_token"},
    )
    forwarder_version.assert_called()

    assert result.status_code == 400
    assert result.json["status"] == "Failed"
    assert result.json["code"] == int(ResponseErrorCode.ENDPOINT_ALREADY_REGISTERED)
    expected_reason = (
        "Endpoint 11111111-2222-3333-4444-555555555555 "
        "was already registered by a different user"
    )
    assert result.json["reason"] == expected_reason
def test_endpoint_status(flask_test_client, mocker, in_mock_auth_state, mock_redis):
    """The status route reads the endpoint's status list and reports 'online'."""
    mocker.patch("funcx_web_service.routes.funcx.authorize_endpoint", return_value=True)

    epid = 123
    mock_redis.lpush(f"ep_status_{epid}", json.dumps({"timestamp": time.time()}))
    lrange_spy = mocker.spy(mock_redis, "lrange")

    response = flask_test_client.get(
        f"api/v1/endpoints/{epid}/status", headers={"Authorization": "my_token"}
    )
    payload = response.json
    assert len(payload["logs"]) == 1
    assert payload["status"] == "online"
    lrange_spy.assert_called_with("ep_status_123", 0, 1)


def test_endpoint_delete(flask_test_client, mocker, in_mock_auth_state, mock_user):
    """Deleting an endpoint delegates to Endpoint.delete_endpoint."""
    delete_mock = mocker.patch.object(Endpoint, "delete_endpoint", return_value="Ok")

    response = flask_test_client.delete(
        "api/v1/endpoints/123", headers={"Authorization": "my_token"}
    )
    assert response.json["result"] == "Ok"
    delete_mock.assert_called_with(mock_user, "123")
whitelist_result["functions"] == ["1", "2", "3"] 211 | get_ep_whitelist.assert_called_with(mock_user, "123") 212 | 213 | 214 | def test_add_whitelist(flask_test_client, mocker, in_mock_auth_state, mock_user): 215 | add_ep_whitelist = mocker.patch( 216 | "funcx_web_service.routes.funcx.add_ep_whitelist", 217 | return_value={"status": "success"}, 218 | ) 219 | 220 | result = flask_test_client.post( 221 | "api/v1/endpoints/123/whitelist", 222 | json={"func": ["1", "2", "3"]}, 223 | headers={"Authorization": "my_token"}, 224 | ) 225 | whitelist_result = result.json 226 | assert whitelist_result["status"] == "success" 227 | add_ep_whitelist.assert_called_with(mock_user, "123", ["1", "2", "3"]) 228 | 229 | 230 | def test_delete_whitelisted(flask_test_client, mocker, in_mock_auth_state, mock_user): 231 | delete_ep_whitelist = mocker.patch( 232 | "funcx_web_service.routes.funcx.delete_ep_whitelist", 233 | return_value={"status": "success"}, 234 | ) 235 | 236 | result = flask_test_client.delete( 237 | "api/v1/endpoints/123/whitelist/678-9", 238 | headers={"Authorization": "my_token"}, 239 | ) 240 | whitelist_result = result.json 241 | assert whitelist_result["status"] == "success" 242 | delete_ep_whitelist.assert_called_with(mock_user, "123", "678-9") 243 | -------------------------------------------------------------------------------- /tests/unit/routes/test_register_function.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import ANY 2 | 3 | from funcx_web_service.models.container import Container 4 | from funcx_web_service.models.function import Function, FunctionAuthGroup 5 | from funcx_web_service.models.user import User 6 | 7 | 8 | def test_register_function(flask_test_client, in_mock_auth_state, mocker): 9 | mock_ingest = mocker.patch("funcx_web_service.routes.funcx.ingest_function") 10 | result = flask_test_client.post( 11 | "api/v1/functions", 12 | json={ 13 | "function_source": "def fun(x): return x+1", 
14 | "function_name": "test fun", 15 | "entry_point": "func()", 16 | "description": "this is a test", 17 | "function_code": "flksdjfldkjdlkfjslk", 18 | "public": True, 19 | }, 20 | headers={"Authorization": "my_token"}, 21 | ) 22 | assert result.status_code == 200 23 | assert "function_uuid" in result.json 24 | 25 | saved_function = Function.find_by_uuid(result.json["function_uuid"]) 26 | assert saved_function.function_uuid == result.json["function_uuid"] 27 | assert saved_function.function_name == "test fun" 28 | assert saved_function.entry_point == "func()" 29 | assert saved_function.description == "this is a test" 30 | assert saved_function.function_source_code == "flksdjfldkjdlkfjslk" 31 | assert saved_function.public 32 | 33 | assert mock_ingest.call_args[0][0].function_uuid == result.json["function_uuid"] 34 | assert mock_ingest.call_args[0][1] == "def fun(x): return x+1" 35 | assert mock_ingest.call_args[0][2] == "123-456" 36 | 37 | 38 | def test_register_function_no_search(flask_test_client, in_mock_auth_state, mocker): 39 | mock_ingest = mocker.patch("funcx_web_service.routes.funcx.ingest_function") 40 | mock_user = User(id=42, username="bob") 41 | 42 | mocker.patch.object(User, "find_by_username", return_value=mock_user) 43 | result = flask_test_client.post( 44 | "api/v1/functions", 45 | json={ 46 | "function_source": "def fun(x): return x+1", 47 | "function_name": "test fun", 48 | "entry_point": "func()", 49 | "description": "this is a test", 50 | "function_code": "flksdjfldkjdlkfjslk", 51 | "public": True, 52 | "searchable": False, 53 | }, 54 | headers={"Authorization": "my_token"}, 55 | ) 56 | assert result.status_code == 200 57 | assert "function_uuid" in result.json 58 | assert mock_ingest.not_called 59 | 60 | 61 | def test_register_function_with_container( 62 | flask_test_client, in_mock_auth_state, mocker 63 | ): 64 | mock_ingest = mocker.patch("funcx_web_service.routes.funcx.ingest_function") 65 | mock_user = User(id=42, username="bob") 66 | 
mocker.patch.object(User, "find_by_username", return_value=mock_user) 67 | 68 | mock_container = Container(id=44) 69 | 70 | mock_container_read = mocker.patch.object( 71 | Container, "find_by_uuid", return_value=mock_container 72 | ) 73 | result = flask_test_client.post( 74 | "api/v1/functions", 75 | json={ 76 | "function_source": "def fun(x): return x+1", 77 | "function_name": "test fun", 78 | "entry_point": "func()", 79 | "description": "this is a test", 80 | "function_code": "flksdjfldkjdlkfjslk", 81 | "public": True, 82 | "searchable": False, 83 | "container_uuid": "11122-22111", 84 | }, 85 | headers={"Authorization": "my_token"}, 86 | ) 87 | assert result.status_code == 200 88 | assert "function_uuid" in result.json 89 | assert mock_ingest.not_called 90 | mock_container_read.assert_called_with("11122-22111") 91 | 92 | saved_function = Function.find_by_uuid(result.json["function_uuid"]) 93 | assert saved_function.container.container_id == 44 94 | 95 | 96 | def test_register_function_with_group_auth( 97 | flask_test_client, in_mock_auth_state, mocker 98 | ): 99 | mock_ingest = mocker.patch("funcx_web_service.routes.funcx.ingest_function") 100 | mock_user = User(id=42, username="bob") 101 | mocker.patch.object(User, "find_by_username", return_value=mock_user) 102 | 103 | mock_auth_group = FunctionAuthGroup(id=45) 104 | 105 | mock_authgroup_read = mocker.patch( 106 | "funcx_web_service.routes.funcx.FunctionAuthGroup" 107 | ) 108 | mock_authgroup_read.return_value = mock_auth_group 109 | 110 | result = flask_test_client.post( 111 | "api/v1/functions", 112 | json={ 113 | "function_source": "def fun(x): return x+1", 114 | "function_name": "test fun", 115 | "entry_point": "func()", 116 | "description": "this is a test", 117 | "function_code": "flksdjfldkjdlkfjslk", 118 | "public": True, 119 | "searchable": False, 120 | "group": "222-111", 121 | }, 122 | headers={"Authorization": "my_token"}, 123 | ) 124 | assert result.status_code == 200 125 | assert "function_uuid" 
in result.json 126 | assert mock_ingest.not_called 127 | mock_authgroup_read.assert_called_with(function=ANY, group_id="222-111") 128 | 129 | saved_function = Function.find_by_uuid(result.json["function_uuid"]) 130 | assert len(saved_function.auth_groups) == 1 131 | assert saved_function.auth_groups[0].id == 45 132 | 133 | 134 | def test_update_function(flask_test_client, in_mock_auth_state, mocker): 135 | mock_update = mocker.patch( 136 | "funcx_web_service.routes.funcx.update_function", return_value=302 137 | ) 138 | result = flask_test_client.put( 139 | "api/v1/functions/123-45", 140 | json={ 141 | "function_source": "def fun(x): return x+1", 142 | "name": "test fun", 143 | "desc": "this is a test", 144 | "entry_point": "func()", 145 | "code": "flksdjfldkjdlkfjslk", 146 | "public": True, 147 | }, 148 | headers={"Authorization": "my_token"}, 149 | ) 150 | assert result.status_code == 302 151 | mock_update.assert_called_with( 152 | "bob", 153 | "123-45", 154 | "test fun", 155 | "this is a test", 156 | "func()", 157 | "flksdjfldkjdlkfjslk", 158 | ) 159 | 160 | 161 | def test_delete_function(flask_test_client, in_mock_auth_state, mock_user, mocker): 162 | mock_delete = mocker.patch( 163 | "funcx_web_service.routes.funcx.delete_function", return_value=302 164 | ) 165 | result = flask_test_client.delete( 166 | "api/v1/functions/123-45", headers={"Authorization": "my_token"} 167 | ) 168 | assert result.status_code == 200 169 | assert mock_delete.called_with("bob", "123-45") 170 | -------------------------------------------------------------------------------- /tests/unit/routes/test_status.py: -------------------------------------------------------------------------------- 1 | from unittest import mock 2 | 3 | from funcx_web_service.models.tasks import RedisTask 4 | from funcx_web_service.models.user import User 5 | 6 | 7 | def test_get_status( 8 | flask_test_client, in_mock_auth_state, mock_redis_task_factory, mock_user: User 9 | ): 10 | mock_redis_task_factory("42") 
11 | result = flask_test_client.get( 12 | "/api/v1/tasks/42", headers={"Authorization": "my_token"} 13 | ) 14 | assert result.status_code == 200 15 | assert result.json["task_id"] == "42", result.json 16 | 17 | 18 | def test_unauthorized_get_status( 19 | flask_test_client, in_mock_auth_state, mock_redis_task_factory, mock_user: User 20 | ): 21 | """ 22 | Verify that a user cannot retrieve a Task status which is not theirs 23 | """ 24 | mock_redis_task_factory("42", user_id=123) 25 | result = flask_test_client.get( 26 | "/api/v1/tasks/42", headers={"Authorization": "my_token"} 27 | ) 28 | assert result.status_code == 404 29 | assert result.json["status"] == "Failed", result.json 30 | 31 | 32 | def test_get_batch_status( 33 | flask_test_client, 34 | in_mock_auth_state, 35 | mock_redis, 36 | mock_redis_task_factory, 37 | mock_user: User, 38 | ): 39 | mock_redis_task_factory("1") 40 | mock_redis_task_factory("2") 41 | mock_redis.hset("task_2", "result", "foo-some-result") 42 | 43 | response = flask_test_client.post( 44 | "/api/v1/batch_status", 45 | headers={"Authorization": "my_token"}, 46 | json={"task_ids": ["1", "2"]}, 47 | ) 48 | assert response 49 | data = response.json 50 | assert data["response"] == "batch" 51 | results = data["results"] 52 | assert len(results) == 2 53 | assert {"1", "2"} == set(results.keys()) 54 | for task_data in results.values(): 55 | assert "completion_t" in task_data, data 56 | assert "status" in task_data, data 57 | assert results["2"]["result"] == "foo-some-result" 58 | 59 | 60 | def test_unauthorized_get_batch_status( 61 | flask_test_client, mocker, in_mock_auth_state, mock_redis_task_factory, mock_redis 62 | ): 63 | """ 64 | Verify that a user cannot retrieve a status for a Batch which is not theirs 65 | """ 66 | mock_redis_task_factory("1", user_id=123) 67 | mock_redis_task_factory("2", user_id=123) 68 | 69 | exists_spy = mocker.spy(RedisTask, "exists") 70 | 71 | result = flask_test_client.post( 72 | "/api/v1/batch_status", 73 | 
headers={"Authorization": "my_token"}, 74 | json={"task_ids": ["1", "2"]}, 75 | ) 76 | 77 | assert result.status_code == 200 78 | assert "results" in result.json, result.json 79 | assert "1" in result.json["results"] 80 | assert "2" in result.json["results"] 81 | for task_id in result.json["results"]: 82 | assert result.json["results"][task_id]["reason"] == "Unknown task id" 83 | assert result.json["results"][task_id]["status"] == "Failed" 84 | 85 | exists_spy.assert_has_calls( 86 | [mock.call(mock_redis, "1"), mock.call(mock_redis, "2")] 87 | ) 88 | -------------------------------------------------------------------------------- /tests/unit/routes/test_submit_function.py: -------------------------------------------------------------------------------- 1 | from funcx_common.response_errors import ResponseErrorCode 2 | 3 | from funcx_web_service.models.tasks import TaskGroup 4 | 5 | 6 | def test_submit_function_access_forbidden( 7 | flask_test_client, mocker, in_mock_auth_state 8 | ): 9 | mock_authorize_function = mocker.patch( 10 | "funcx_web_service.routes.funcx.authorize_function", return_value=False 11 | ) 12 | 13 | mocker.patch.object(TaskGroup, attribute="exists", return_value=False) 14 | 15 | result = flask_test_client.post( 16 | "api/v1/submit", 17 | json={"tasks": [("1111", "2222", "")]}, 18 | headers={"Authorization": "my_token"}, 19 | ) 20 | 21 | mock_authorize_function.assert_called() 22 | 23 | assert result.status_code == 207 24 | res = result.json["results"][0] 25 | assert res["http_status_code"] == 403 26 | assert res["status"] == "Failed" 27 | assert res["code"] == int(ResponseErrorCode.FUNCTION_ACCESS_FORBIDDEN) 28 | assert res["reason"] == "Unauthorized access to function 1111" 29 | 30 | 31 | def test_submit_function_not_found(flask_test_client, mocker, in_mock_auth_state): 32 | 33 | mock_find_function = mocker.patch( 34 | "funcx_web_service.authentication.auth.Function.find_by_uuid", 35 | return_value=None, 36 | ) 37 | 38 | 
mocker.patch.object(TaskGroup, attribute="exists", return_value=False) 39 | 40 | result = flask_test_client.post( 41 | "api/v1/submit", 42 | json={"tasks": [("1111", "2222", "")]}, 43 | headers={"Authorization": "my_token"}, 44 | ) 45 | 46 | mock_find_function.assert_called() 47 | 48 | assert result.status_code == 207 49 | res = result.json["results"][0] 50 | assert res["http_status_code"] == 404 51 | assert res["status"] == "Failed" 52 | assert res["code"] == int(ResponseErrorCode.FUNCTION_NOT_FOUND) 53 | assert res["reason"] == "Function 1111 could not be resolved" 54 | 55 | 56 | def test_submit_endpoint_access_forbidden( 57 | flask_test_client, mocker, in_mock_auth_state 58 | ): 59 | mock_authorize_function = mocker.patch( 60 | "funcx_web_service.routes.funcx.authorize_function", return_value=True 61 | ) 62 | 63 | mock_resolve_function = mocker.patch( 64 | "funcx_web_service.routes.funcx.resolve_function", 65 | return_value=("1", "2", "3"), 66 | ) 67 | 68 | mock_authorize_endpoint = mocker.patch( 69 | "funcx_web_service.routes.funcx.authorize_endpoint", return_value=False 70 | ) 71 | 72 | mocker.patch.object(TaskGroup, attribute="exists", return_value=False) 73 | 74 | result = flask_test_client.post( 75 | "api/v1/submit", 76 | json={"tasks": [("1111", "2222", "")]}, 77 | headers={"Authorization": "my_token"}, 78 | ) 79 | 80 | mock_authorize_function.assert_called() 81 | mock_resolve_function.assert_called() 82 | mock_authorize_endpoint.assert_called() 83 | 84 | assert result.status_code == 207 85 | res = result.json["results"][0] 86 | assert res["http_status_code"] == 403 87 | assert res["status"] == "Failed" 88 | assert res["code"] == int(ResponseErrorCode.ENDPOINT_ACCESS_FORBIDDEN) 89 | assert res["reason"] == "Unauthorized access to endpoint 2222" 90 | 91 | 92 | def test_submit_endpoint_not_found(flask_test_client, mocker, in_mock_auth_state): 93 | mock_authorize_function = mocker.patch( 94 | "funcx_web_service.routes.funcx.authorize_function", 
return_value=True 95 | ) 96 | 97 | mock_resolve_function = mocker.patch( 98 | "funcx_web_service.routes.funcx.resolve_function", 99 | return_value=("1", "2", "3"), 100 | ) 101 | 102 | mock_find_endpoint = mocker.patch( 103 | "funcx_web_service.authentication.auth.Endpoint.find_by_uuid", 104 | return_value=None, 105 | ) 106 | 107 | mocker.patch.object(TaskGroup, attribute="exists", return_value=False) 108 | 109 | result = flask_test_client.post( 110 | "api/v1/submit", 111 | json={"tasks": [("1111", "2222", "")]}, 112 | headers={"Authorization": "my_token"}, 113 | ) 114 | 115 | mock_authorize_function.assert_called() 116 | mock_resolve_function.assert_called() 117 | mock_find_endpoint.assert_called() 118 | 119 | assert result.status_code == 207 120 | res = result.json["results"][0] 121 | assert res["http_status_code"] == 404 122 | assert res["status"] == "Failed" 123 | assert res["code"] == int(ResponseErrorCode.ENDPOINT_NOT_FOUND) 124 | assert res["reason"] == "Endpoint 2222 could not be resolved" 125 | 126 | 127 | def test_submit_function_not_permitted( 128 | flask_test_client, mocker, in_mock_auth_state, mock_endpoint 129 | ): 130 | mock_authorize_function = mocker.patch( 131 | "funcx_web_service.routes.funcx.authorize_function", return_value=True 132 | ) 133 | 134 | mock_resolve_function = mocker.patch( 135 | "funcx_web_service.routes.funcx.resolve_function", 136 | return_value=("1", "2", "3"), 137 | ) 138 | 139 | mock_find_endpoint = mocker.patch( 140 | "funcx_web_service.authentication.auth.Endpoint.find_by_uuid", 141 | return_value=mock_endpoint, 142 | ) 143 | 144 | mocker.patch.object(TaskGroup, attribute="exists", return_value=False) 145 | 146 | result = flask_test_client.post( 147 | "api/v1/submit", 148 | json={"tasks": [("1111", "2222", "")]}, 149 | headers={"Authorization": "my_token"}, 150 | ) 151 | 152 | mock_authorize_function.assert_called() 153 | mock_resolve_function.assert_called() 154 | mock_find_endpoint.assert_called() 155 | 156 | assert 
result.status_code == 207 157 | res = result.json["results"][0] 158 | assert res["http_status_code"] == 403 159 | assert res["status"] == "Failed" 160 | assert res["code"] == int(ResponseErrorCode.FUNCTION_NOT_PERMITTED) 161 | assert res["reason"] == "Function 1111 not permitted on endpoint 2222" 162 | 163 | 164 | def test_submit_function( 165 | flask_test_client, mocker, in_mock_auth_state, mock_redis_pubsub, mock_redis 166 | ): 167 | mock_function_auth = mocker.patch( 168 | "funcx_web_service.routes.funcx.authorize_function", return_value=True 169 | ) 170 | mock_endpoint_auth = mocker.patch( 171 | "funcx_web_service.routes.funcx.authorize_endpoint", return_value=True 172 | ) 173 | mock_resolve = mocker.patch( 174 | "funcx_web_service.routes.funcx.resolve_function", 175 | return_value=("codecode", "entry", "123-45"), 176 | ) 177 | 178 | result = flask_test_client.post( 179 | "/api/v1/submit", 180 | json={"tasks": [["12", "13", "my_data"]]}, 181 | headers={"Authorization": "my_token"}, 182 | ) 183 | 184 | assert result.status_code == 200 185 | submit_result = result.json 186 | assert submit_result["response"] == "batch" 187 | assert len(submit_result["results"]) == 1 188 | assert submit_result["results"][0]["status"] == "Success" 189 | 190 | mock_function_auth.assert_called_with(22, "12", "my_token") 191 | mock_endpoint_auth.assert_called_with(22, "13", "12", "my_token") 192 | mock_resolve.assert_called_with(22, "12") 193 | 194 | put_call = mock_redis_pubsub.put.call_args 195 | assert put_call[0][0] == "13" 196 | -------------------------------------------------------------------------------- /tests/unit/routes/test_task_groups.py: -------------------------------------------------------------------------------- 1 | from funcx_web_service.models.tasks import TaskGroup 2 | 3 | 4 | def test_get_batch_info( 5 | flask_test_client, mocker, mock_user, in_mock_auth_state, mock_redis 6 | ): 7 | TaskGroup(mock_redis, "123", user_id=mock_user.id) 8 | exists_spy = 
mocker.spy(TaskGroup, "exists") 9 | 10 | result = flask_test_client.get( 11 | "/api/v1/task_groups/123", headers={"Authorization": "my_token"} 12 | ) 13 | assert result.json["authorized"] 14 | exists_spy.assert_called_with(mock_redis, "123") 15 | -------------------------------------------------------------------------------- /tests/unit/test_app_init.py: -------------------------------------------------------------------------------- 1 | import funcx_web_service 2 | 3 | TEST_CONFIG = """\ 4 | FOO = "bar" 5 | SECRET_VALUE = "blah" 6 | BOOL_VALUE = False 7 | CONTAINER_SERVICE_ENABLED = False 8 | """ 9 | 10 | 11 | def test_read_from_config(tmp_path, monkeypatch): 12 | conf_file = tmp_path / "test.config" 13 | conf_file.write_text(TEST_CONFIG) 14 | monkeypatch.setenv("APP_CONFIG_FILE", str(conf_file)) 15 | 16 | app = funcx_web_service.create_app() 17 | assert app.config["FOO"] == "bar" 18 | assert app.config["SECRET_VALUE"] == "blah" 19 | assert not app.config["BOOL_VALUE"] 20 | 21 | monkeypatch.setenv("SECRET_VALUE", "shhh") 22 | monkeypatch.setenv("BOOL_VALUE", "true") 23 | app_from_env = funcx_web_service.create_app() 24 | assert app_from_env.config["SECRET_VALUE"] == "shhh" 25 | assert app_from_env.config["BOOL_VALUE"] 26 | -------------------------------------------------------------------------------- /tests/unit/test_task_behavior.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | from funcx_common.tasks import TaskState 4 | 5 | from funcx_web_service.models.tasks import InternalTaskState, RedisTask, TaskGroup 6 | 7 | 8 | def test_redis_task_creation(mock_redis): 9 | task_id = str(uuid.uuid1()) 10 | 11 | assert not RedisTask.exists(mock_redis, task_id) 12 | new_task = RedisTask(mock_redis, task_id) 13 | 14 | assert RedisTask.exists(mock_redis, task_id) 15 | assert new_task.status == TaskState.WAITING_FOR_EP 16 | assert new_task.internal_status == InternalTaskState.INCOMPLETE 17 | 18 | 19 | # this is a 
regression test for a bug in which newly created RedisTask objects would 20 | # "reset" the status and internal_status fields of a task 21 | def test_redis_task_double_lookup(mock_redis): 22 | task_id = str(uuid.uuid1()) 23 | 24 | assert not RedisTask.exists(mock_redis, task_id) 25 | new_task = RedisTask(mock_redis, task_id) 26 | new_task.status = TaskState.SUCCESS 27 | new_task.internal_status = InternalTaskState.COMPLETE 28 | 29 | assert RedisTask.exists(mock_redis, task_id) 30 | second_task = RedisTask(mock_redis, task_id) 31 | assert second_task.status == TaskState.SUCCESS 32 | assert second_task.internal_status == InternalTaskState.COMPLETE 33 | 34 | 35 | def test_redis_task_delete(mock_redis): 36 | task_id = str(uuid.uuid1()) 37 | 38 | assert not RedisTask.exists(mock_redis, task_id) 39 | task = RedisTask(mock_redis, task_id) 40 | 41 | assert RedisTask.exists(mock_redis, task_id) 42 | task.delete() 43 | 44 | assert not RedisTask.exists(mock_redis, task_id) 45 | 46 | 47 | def test_task_group_create(mock_redis): 48 | task_group_id = str(uuid.uuid1()) 49 | user_id = 101 50 | 51 | assert not TaskGroup.exists(mock_redis, task_group_id) 52 | 53 | tg = TaskGroup(mock_redis, task_group_id, user_id=user_id) 54 | assert tg.user_id == user_id 55 | 56 | assert TaskGroup.exists(mock_redis, task_group_id) 57 | 58 | 59 | def test_task_group_delete(mock_redis): 60 | task_group_id = str(uuid.uuid1()) 61 | user_id = 101 62 | 63 | assert not TaskGroup.exists(mock_redis, task_group_id) 64 | tg = TaskGroup(mock_redis, task_group_id, user_id=user_id) 65 | assert TaskGroup.exists(mock_redis, task_group_id) 66 | tg.delete() 67 | assert not TaskGroup.exists(mock_redis, task_group_id) 68 | 69 | 70 | def test_task_group_no_user_id(mock_redis): 71 | # creating a TaskGroup without setting user_id results in no creation within redis, 72 | # and the TaskGroup does not exist 73 | # TODO: determine if this is the desired behavior 74 | # maybe __init__ should raise an exception if `user_id` 
is not set? 75 | task_group_id = str(uuid.uuid1()) 76 | 77 | assert not TaskGroup.exists(mock_redis, task_group_id) 78 | 79 | tg = TaskGroup(mock_redis, task_group_id) 80 | assert tg.user_id is None 81 | 82 | assert not TaskGroup.exists(mock_redis, task_group_id) 83 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py37 3 | 4 | [testenv] 5 | skip_install = true 6 | deps = 7 | -r requirements.txt 8 | -r requirements_test.txt 9 | commands = pytest tests --cov=funcx_web_service {posargs} 10 | 11 | [testenv:lint] 12 | deps = pre-commit<3 13 | skip_install = true 14 | commands = pre-commit run --all-files 15 | 16 | [testenv:mypy] 17 | deps = 18 | mypy 19 | types-requests 20 | types-redis 21 | commands = mypy funcx_web_service/ 22 | 23 | [testenv:safety] 24 | skip_install = true 25 | deps = 26 | safety 27 | -r requirements.txt 28 | commands = safety check 29 | 30 | [testenv:freezedeps] 31 | skip_install = true 32 | recreate = true 33 | deps = 34 | pipdeptree 35 | -r requirements.in 36 | commands = pipdeptree -f --exclude pip,pipdeptree,setuptools,wheel 37 | 38 | [testenv:codecov] 39 | skip_install = true 40 | deps = codecov 41 | commands = codecov {posargs} 42 | -------------------------------------------------------------------------------- /uwsgi.ini: -------------------------------------------------------------------------------- 1 | [uwsgi] 2 | uid = uwsgi 3 | binary-path = /usr/local/bin/uwsgi 4 | plugin = python37,http 5 | http = 0.0.0.0:5000 6 | manage-script-name = true 7 | http-keepalive = 1 8 | log-master=true 9 | module = funcx_web_service.application:app 10 | -------------------------------------------------------------------------------- /web-entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | FLASK_APP=funcx_web_service/application.py flask db upgrade 
3 | uwsgi --ini uwsgi.ini --------------------------------------------------------------------------------