├── .gitignore ├── LICENSE ├── README.md ├── notebooks ├── app.py ├── assets │ ├── cifar_frog.png │ ├── cifar_horse.png │ ├── cifar_truck.png │ ├── favicon.ico │ └── logo-no-text.png ├── cartpole.yml ├── ch_01_overview.ipynb ├── ch_02_ray_core.ipynb ├── ch_03_core_app.ipynb ├── ch_04_rllib.ipynb ├── ch_05_tune.ipynb ├── ch_06_data_processing.ipynb ├── ch_07_train.ipynb ├── ch_08_model_serving.ipynb ├── ch_09_example_aws.yaml ├── ch_09_example_azure.yaml ├── ch_09_example_gcp.yaml ├── ch_09_example_k8s.yaml ├── ch_09_ray_start_demo.txt ├── ch_09_script.ipynb ├── ch_10_air.ipynb ├── ch_11_ecosystem.ipynb ├── fare_predictor.py ├── gradio_demo.py ├── images │ ├── chapter_01 │ │ ├── AIR.png │ │ ├── Ecosystem.png │ │ ├── cartpole.png │ │ ├── ds_workflow.png │ │ ├── ray_layers.png │ │ ├── ray_layers_old.png │ │ └── simple_cluster.png │ ├── chapter_02 │ │ ├── architecture.png │ │ ├── map_reduce.png │ │ ├── task_dependency.png │ │ └── worker_node.png │ ├── chapter_03 │ │ └── train_policy.png │ ├── chapter_04 │ │ ├── mapping_envs.png │ │ ├── rllib_envs.png │ │ └── rllib_external.png │ ├── chapter_05 │ │ ├── Tune_model_training.png │ │ └── tune_flow.png │ ├── chapter_06 │ │ ├── AIR_data.png │ │ ├── data_pipeline_1.png │ │ ├── data_pipeline_2.png │ │ ├── data_positioning_1.png │ │ ├── data_positioning_2.png │ │ ├── datasets_arch.png │ │ ├── ml_workflow.png │ │ └── ml_workflow_no_logos.png │ ├── chapter_07 │ │ ├── data_model_parallel.png │ │ ├── torch_trainer.png │ │ ├── train_architecture.png │ │ ├── train_overview.png │ │ └── train_tune_execution.png │ ├── chapter_08 │ │ ├── nlp_api_arch.png │ │ ├── serve_arch.png │ │ └── serve_positioning.png │ ├── chapter_09 │ │ ├── kuberay_overview.png │ │ └── ray_kubernetes_operator.png │ ├── chapter_10 │ │ ├── AIR_deployment.png │ │ ├── AIR_predictor.png │ │ ├── AIR_trainer.png │ │ ├── AIR_tuner.png │ │ ├── AIR_workloads.png │ │ ├── Tune_stateful.png │ │ ├── air_overview.png │ │ ├── air_plan.png │ │ └── stateless_air_tasks.png │ ├── chapter_11 │ │ ├── AIR_ML_platform.png │ │ ├── Ray_extended_eco.png │ │ └── custom_integrations.png │ ├── learning_ray.png │ ├── marbled_electric_ray.png │ └── scaling_design_patterns.png ├── index.md ├── maze.py ├── maze.yml ├── maze_gym_env.py ├── nyc_tlc_data │ ├── yellow_tripdata_2020-01.parquet │ └── yellow_tripdata_2021-01.parquet ├── policy_client.py └── policy_server.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | .DS_Store 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | 133 | .idea/ 134 | 135 | mkdocs-material/ 136 | overrides/ 137 | js/ 138 | css/ 139 | 140 | mkdocs.yml 141 | 142 | data/ 143 | mlruns/ 144 | torch_checkpoint/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Max Pumperla 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning Ray - Flexible Distributed Python for Machine Learning 2 | 3 | Jupyter notebooks and other resources for the upcoming book "Learning Ray" (O'Reilly). 4 | All code and diagrams used in the book are available here for free. 5 | The notebooks can be read online, and we keep adding more explanations to this online version. 6 | If you want to support this project and buy the book itself, you can get it 7 | [directly from O'Reilly](https://www.oreilly.com/library/view/learning-ray/9781098117214/), 8 | or [from Amazon](https://www.amazon.com/Learning-Ray-Flexible-Distributed-Machine/dp/1098117220/). 9 | The book will be published in May 2023, but online versions should be available before that. 10 | 11 | ![Learning Ray](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/learning_ray.png) 12 | 13 | The book is organized to guide you chapter by chapter from core concepts of Ray to more sophisticated topics along the way. 14 | The first three chapters of the book teach the basics of Ray as a distributed Python framework with practical examples. 15 | Chapters four to ten introduce Ray's high-level libraries and show how to build applications with them. 16 | The last two chapters give you an overview of Ray's ecosystem and show you where to go next. 17 | Here's what you can expect from each chapter. 18 | 19 | * [_Chapter 1, An Overview of Ray_](https://github.com/maxpumperla/learning_ray/blob/main/notebooks/ch_01_overview.ipynb) 20 | Introduces you at a high level to all of Ray's components, how it can be used in 21 | machine learning and other tasks, what the Ray ecosystem currently looks like and how 22 | Ray as a whole fits into the landscape of distributed Python. 23 | * [_Chapter 2, Getting Started with Ray_](https://github.com/maxpumperla/learning_ray/blob/main/notebooks/ch_02_ray_core.ipynb) 24 | Walks you through the foundations of the Ray project, namely its low-level API. 25 | It also discusses how Ray Tasks and Actors naturally extend from Python functions and classes. 26 | You also learn about all of Ray's system components and how they work together. 27 | * [_Chapter 3, Building Your First Distributed Application with Ray Core_](https://github.com/maxpumperla/learning_ray/blob/main/notebooks/ch_03_core_app.ipynb) 28 | Gives you an introduction to distributed systems and what makes them hard. 29 | We'll then build a first application together and discuss how to peek behind the scenes 30 | and get insights from the Ray toolbox. 31 | * [_Chapter 4, Reinforcement Learning with Ray RLlib_](https://github.com/maxpumperla/learning_ray/blob/main/notebooks/ch_04_rllib.ipynb) 32 | Gives you a quick introduction to reinforcement learning and shows how Ray implements 33 | important concepts in RLlib. After building some examples together, we'll also dive into 34 | more advanced topics like preprocessors, custom models, or working with offline data. 35 | * [_Chapter 5, Hyperparameter Optimization with Ray Tune_](https://github.com/maxpumperla/learning_ray/blob/main/notebooks/ch_05_tune.ipynb) 36 | Covers why efficiently tuning hyperparameters is hard, how Ray Tune works conceptually, 37 | and how you can use it in practice for your machine learning projects. 
38 | * [_Chapter 6, Data Processing with Ray_](https://github.com/maxpumperla/learning_ray/blob/main/notebooks/ch_06_data_processing.ipynb) 39 | Introduces you to the Dataset abstraction of Ray and how it fits into the landscape 40 | of other data structures. You will also learn how to bring pandas data frames, Dask 41 | data structures and Apache Spark workloads to Ray. 42 | * [_Chapter 7, Distributed Training with Ray Train_](https://github.com/maxpumperla/learning_ray/blob/main/notebooks/ch_07_train.ipynb) 43 | Provides you with the basics of distributed model training and shows you how to use 44 | Ray Train with popular frameworks such as TensorFlow or PyTorch, and how to combine it 45 | with Ray Tune for hyperparameter optimization. 46 | * [_Chapter 8, Serving Models with Ray Serve_](https://github.com/maxpumperla/learning_ray/blob/main/notebooks/ch_08_model_serving.ipynb) 47 | Introduces you to model serving with Ray, why it works well within the framework, 48 | and how to do single-node and cluster deployment with it. 49 | * [_Chapter 9, Working with Ray Clusters_](https://github.com/maxpumperla/learning_ray/blob/main/notebooks/ch_09_script.ipynb) 50 | This chapter is all about how you configure, launch and scale Ray clusters for your applications. 51 | You'll learn about Ray's cluster launcher CLI and autoscaler, as well as how to set 52 | up clusters in the cloud and how to deploy on Kubernetes and other cluster managers. 53 | * [_Chapter 10, Getting Started with the Ray AI Runtime_](https://github.com/maxpumperla/learning_ray/blob/main/notebooks/ch_10_air.ipynb) 54 | Introduces you to Ray AIR, a unified toolkit for your ML workloads that offers many 55 | third-party integrations for model training or accessing custom data sources. 56 | * [_Chapter 11, Ray's Ecosystem and Beyond_](https://github.com/maxpumperla/learning_ray/blob/main/notebooks/ch_11_ecosystem.ipynb) 57 | Gives you an overview of the many interesting extensions and 58 | integrations that Ray has attracted over the years. -------------------------------------------------------------------------------- /notebooks/app.py: -------------------------------------------------------------------------------- 1 | # In this file we collect all the Serve deployments that get referenced in the chapter. 2 | # You can run any of these deployments by running `serve run app:<deployment_name>`, 3 | # where <deployment_name> is any of basic_deployment, scaled_deployment, 4 | # nlp_pipeline_driver, or batched_deployment.
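# For example, assuming Ray Serve 2.x and this chapter's dependencies are installed
# and the command is run from this notebooks/ directory, a deployment could be
# started and queried roughly as follows (8000 is Serve's default HTTP port, and the
# input text is only an illustration):
#
#   serve run app:basic_deployment
#   curl "http://localhost:8000/?input_text=Ray%20Serve%20is%20great"
#
# The NLP pipeline driver is started the same way (`serve run app:nlp_pipeline_driver`),
# but expects a `search_term` query parameter instead of `input_text`.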
5 | 6 | from fastapi import FastAPI 7 | from transformers import pipeline 8 | from ray import serve 9 | 10 | 11 | app = FastAPI() 12 | 13 | 14 | @serve.deployment 15 | class SentimentAnalysis: 16 | def __init__(self): 17 | self._classifier = pipeline("sentiment-analysis") 18 | 19 | def __call__(self, request) -> str: 20 | input_text = request.query_params["input_text"] 21 | return self._classifier(input_text)[0]["label"] 22 | 23 | 24 | basic_deployment = SentimentAnalysis.bind() 25 | 26 | 27 | @serve.deployment(num_replicas=2, ray_actor_options={"num_cpus": 2}) 28 | @serve.ingress(app) 29 | class SentimentAnalysis: 30 | def __init__(self): 31 | self._classifier = pipeline("sentiment-analysis") 32 | 33 | @app.get("/") 34 | def classify(self, input_text: str) -> str: 35 | import os 36 | print("from process:", os.getpid()) 37 | return self._classifier(input_text)[0]["label"] 38 | 39 | 40 | scaled_deployment = SentimentAnalysis.bind() 41 | 42 | 43 | @serve.deployment 44 | @serve.ingress(app) 45 | class SentimentAnalysis: 46 | def __init__(self): 47 | self._classifier = pipeline("sentiment-analysis") 48 | 49 | @serve.batch(max_batch_size=10, batch_wait_timeout_s=0.1) 50 | async def classify_batched(self, batched_inputs): 51 | print("Got batch size:", len(batched_inputs)) 52 | results = self._classifier(batched_inputs) 53 | return [result["label"] for result in results] 54 | 55 | @app.get("/") 56 | async def classify(self, input_text: str) -> str: 57 | return await self.classify_batched(input_text) 58 | 59 | 60 | batched_deployment = SentimentAnalysis.bind() 61 | 62 | 63 | from typing import Optional 64 | 65 | import wikipedia 66 | 67 | 68 | def fetch_wikipedia_page(search_term: str) -> Optional[str]: 69 | results = wikipedia.search(search_term) 70 | # If no results, return to caller. 71 | if len(results) == 0: 72 | return None 73 | 74 | # Get the page for the top result. 
75 | return wikipedia.page(results[0]).content 76 | 77 | 78 | from ray import serve 79 | from transformers import pipeline 80 | from typing import List 81 | 82 | 83 | @serve.deployment 84 | class SentimentAnalysis: 85 | def __init__(self): 86 | self._classifier = pipeline("sentiment-analysis") 87 | 88 | @serve.batch(max_batch_size=10, batch_wait_timeout_s=0.1) 89 | async def is_positive_batched(self, inputs: List[str]) -> List[bool]: 90 | results = self._classifier(inputs, truncation=True) 91 | return [result["label"] == "POSITIVE" for result in results] 92 | 93 | async def __call__(self, input_text: str) -> bool: 94 | return await self.is_positive_batched(input_text) 95 | 96 | 97 | @serve.deployment(num_replicas=2) 98 | class Summarizer: 99 | def __init__(self, max_length: Optional[int] = None): 100 | self._summarizer = pipeline("summarization") 101 | self._max_length = max_length 102 | 103 | def __call__(self, input_text: str) -> str: 104 | result = self._summarizer( 105 | input_text, max_length=self._max_length, truncation=True) 106 | return result[0]["summary_text"] 107 | 108 | 109 | @serve.deployment 110 | class EntityRecognition: 111 | def __init__(self, threshold: float = 0.90, max_entities: int = 10): 112 | self._entity_recognition = pipeline("ner") 113 | self._threshold = threshold 114 | self._max_entities = max_entities 115 | 116 | def __call__(self, input_text: str) -> List[str]: 117 | final_results = [] 118 | for result in self._entity_recognition(input_text): 119 | if result["score"] > self._threshold: 120 | final_results.append(result["word"]) 121 | if len(final_results) == self._max_entities: 122 | break 123 | 124 | return final_results 125 | 126 | 127 | from pydantic import BaseModel 128 | 129 | 130 | class Response(BaseModel): 131 | success: bool 132 | message: str = "" 133 | summary: str = "" 134 | named_entities: List[str] = [] 135 | 136 | 137 | from fastapi import FastAPI 138 | 139 | app = FastAPI() 140 | 141 | 142 | @serve.deployment 143 | @serve.ingress(app) 144 | class NLPPipelineDriver: 145 | def __init__(self, sentiment_analysis, summarizer, entity_recognition): 146 | self._sentiment_analysis = sentiment_analysis 147 | self._summarizer = summarizer 148 | self._entity_recognition = entity_recognition 149 | 150 | @app.get("/", response_model=Response) 151 | async def summarize_article(self, search_term: str) -> Response: 152 | # Fetch the top page content for the search term if found. 153 | page_content = fetch_wikipedia_page(search_term) 154 | if page_content is None: 155 | return Response(success=False, message="No pages found.") 156 | 157 | # Conditionally continue based on the sentiment analysis. 158 | is_positive = await self._sentiment_analysis.remote(page_content) 159 | if not is_positive: 160 | return Response(success=False, message="Only positivitiy allowed!") 161 | 162 | # Query the summarizer and named entity recognition models in parallel. 
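        # Both .remote() calls below are issued before either result is awaited,
        # so the summarizer and the entity-recognition deployments process the
        # page content concurrently rather than one after the other.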
163 | summary_result = self._summarizer.remote(page_content) 164 | entities_result = self._entity_recognition.remote(page_content) 165 | return Response( 166 | success=True, 167 | summary=await summary_result, 168 | named_entities=await entities_result 169 | ) 170 | 171 | sentiment_analysis = SentimentAnalysis.bind() 172 | summarizer = Summarizer.bind() 173 | entity_recognition = EntityRecognition.bind(threshold=0.95, max_entities=5) 174 | nlp_pipeline_driver = NLPPipelineDriver.bind( 175 | sentiment_analysis, summarizer, entity_recognition) 176 | 177 | -------------------------------------------------------------------------------- /notebooks/assets/cifar_frog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/assets/cifar_frog.png -------------------------------------------------------------------------------- /notebooks/assets/cifar_horse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/assets/cifar_horse.png -------------------------------------------------------------------------------- /notebooks/assets/cifar_truck.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/assets/cifar_truck.png -------------------------------------------------------------------------------- /notebooks/assets/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/assets/favicon.ico -------------------------------------------------------------------------------- /notebooks/assets/logo-no-text.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/assets/logo-no-text.png -------------------------------------------------------------------------------- /notebooks/cartpole.yml: -------------------------------------------------------------------------------- 1 | cartpole-ppo: 2 | env: CartPole-v1 3 | run: PPO 4 | stop: 5 | episode_reward_mean: 150 6 | timesteps_total: 100000 7 | config: 8 | framework: tf 9 | gamma: 0.99 10 | lr: 0.0003 11 | num_workers: 1 12 | observation_filter: MeanStdFilter 13 | num_sgd_iter: 6 14 | vf_loss_coeff: 0.01 15 | model: 16 | fcnet_hiddens: [32] 17 | fcnet_activation: linear 18 | vf_share_layers: true 19 | enable_connectors: True 20 | -------------------------------------------------------------------------------- /notebooks/ch_03_core_app.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "1bb6bfd0", 6 | "metadata": {}, 7 | "source": [ 8 | "# Building Your First Distributed Application With Ray Core" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "b7219453", 14 | "metadata": {}, 15 | "source": [ 16 | "Having learned the fundamentals of the Ray API, we will now use it to create a more practical project. 
By the end of this chapter, you will have created a reinforcement learning problem from scratch, implemented an algorithm to solve it, and used Ray tasks and actors to parallelize the solution across a local cluster, all within 250 lines of code." 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "id": "ab77bcd8", 22 | "metadata": {}, 23 | "source": [ 24 | "This chapter aims to assist readers who are new to reinforcement learning. We will work through a simple problem and learn the necessary skills to solve it through hands-on practice. Advanced topics and terminology related to reinforcement learning will not be covered in this chapter, as Chapter 4 is dedicated to those subjects. However, even those who are more experienced with reinforcement learning may find value in implementing a traditional algorithm in a distributed environment." 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "id": "fe6d0897", 30 | "metadata": {}, 31 | "source": [ 32 | "The current chapter is focused on using Ray Core exclusively. It is hoped that the reader will come to understand the versatility and efficiency of Ray Core, particularly in regards to conducting distributed experiments that would otherwise require a significant amount of effort to set up. However, before moving on to implementation, it is worth briefly discussing the concept of reinforcement learning in greater detail. If you have prior experience with RL, feel free to skip this section." 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "id": "0e1c204e", 38 | "metadata": {}, 39 | "source": [ 40 | "\n", 41 | "You can run this notebook directly in\n", 42 | "[Colab](https://colab.research.google.com/github/maxpumperla/learning_ray/blob/main/notebooks/ch_03_core_app.ipynb). \n", 43 | "\"Open\n", 44 | "\n", 45 | "In any case, make sure you have Ray installed first:" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "id": "1540d697", 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "! pip install \"ray==2.2.0\"" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "id": "0939c69b", 61 | "metadata": {}, 62 | "source": [ 63 | "## Introducing Reinforcement Learning\n", 64 | "\n", 65 | "There is an app on my phone that is one of my favorites because it is able to accurately identify and label different plants in my garden just by being shown a picture of the plant. This is very useful for me because I am not good at telling them apart. In recent years, there have been a lot of impressive apps like this one that have been developed." 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "id": "611bdb13", 71 | "metadata": {}, 72 | "source": [ 73 | "Ultimately, the aim of AI is to create intelligent agents that are capable of much more than just classifying objects. Imagine an AI application that not only recognizes your plants, but is also able to take care of them. In order to do this, the AI would need to be able to:\n", 74 | "\n", 75 | "- Function in dynamic environments, such as changes in seasons\n", 76 | "- React to changes in the environment, like severe weather or pests affecting the plants\n", 77 | "- Take a series of actions, such as watering and fertilizing the plants\n", 78 | "- Achieve long-term goals, such as prioritizing the health of the plants." 
79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "id": "88ce5cd0", 84 | "metadata": {}, 85 | "source": [ 86 | "An AI that observes its environment would be able to learn to explore possible actions and improve its solutions over time. For instance, it could be used for managing and optimizing a supply chain, restocking a warehouse based on fluctuating demand, or coordinating the processing steps in an assembly line. The \"Coffee Test\" proposed by Stephen Wozniak is another example of what an AI could be capable of - finding and using a coffee machine and all necessary ingredients to brew a cup of coffee, although it may not be able to sit down and enjoy it. Can you think of any other examples that fit these criteria?" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "id": "fad0ad03", 92 | "metadata": {}, 93 | "source": [ 94 | "The above requirements can be understood as part of a machine learning subfield called reinforcement learning (RL). RL involves agents interacting with their environment by observing it and taking actions. In RL, agents evaluate their environment by assigning a reward value to certain outcomes (e.g., the health of a plant on a linear scale). The term \"reinforcement\" refers to the idea that agents will hopefully learn to engage in behaviors that lead to desirable outcomes (high reward) and avoid negative or punishing situations (low or negative reward)." 95 | ] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "id": "17470bcc", 100 | "metadata": {}, 101 | "source": [ 102 | "To better understand how agents interact with their environment, it is common to create a computer simulation of it. However, it is not always possible to do so. To provide an example of this in practice, we will build a simulation where agents interact with their environment together." 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "id": "1ed1b5a9", 108 | "metadata": {}, 109 | "source": [ 110 | "## Setting Up a Simple Maze Problem\n", 111 | "\n", 112 | "The app we are developing consists of a 2D maze game in which a single player can move in four directions. The maze is set up as a 5x5 grid, with one of the 25 cells being the \"goal\" that the player, called the \"seeker,\" must reach. Instead of providing a pre-determined solution, we will use a reinforcement learning algorithm so that the seeker can learn how to find the goal through repeated simulations of the maze. The seeker will be rewarded for reaching the goal and the algorithm will track which decisions were successful and which were not. To make the process more efficient, we will use the Ray API to parallelize both the simulations and the training of the RL algorithm." 113 | ] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "id": "b00ffc33", 118 | "metadata": {}, 119 | "source": [ 120 | "We will continue to use local clusters for now, rather than deploying the application on an actual Ray cluster made up of multiple nodes. If you want to learn about setting up Ray clusters and are interested in infrastructure topics, you can skip ahead to Chapter 9. However, make sure you have Ray installed by running the command `pip install ray`." 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "id": "4db17446", 126 | "metadata": {}, 127 | "source": [ 128 | "We will begin by creating the 2D maze that we previously discussed. This involves creating a 5x5 grid in Python, which starts at (0, 0) and ends at (4, 4). We need to also define how the player can move around the grid. 
To do this, we will use a class called Discrete to represent the four cardinal directions of movement: up, down, left, and right. We implement this class in a general way, so that it can encode any number of discrete actions, not just these four. Don't worry, we will need this generalized Discrete class later on in the process." 129 |    ] 130 |   }, 131 |   { 132 |    "cell_type": "code", 133 |    "execution_count": null, 134 |    "id": "bb000a54", 135 |    "metadata": {}, 136 |    "outputs": [], 137 |    "source": [ 138 |     "import random\n", 139 |     "\n", 140 |     "\n", 141 |     "class Discrete:\n", 142 |     "    def __init__(self, num_actions: int):\n", 143 |     "        \"\"\" Discrete action space for num_actions.\n", 144 |     "        Discrete(4) can be used to encode moving in\n", 145 |     "        one of the cardinal directions.\n", 146 |     "        \"\"\"\n", 147 |     "        self.n = num_actions\n", 148 |     "\n", 149 |     "    def sample(self):\n", 150 |     "        return random.randint(0, self.n - 1)\n", 151 |     "\n", 152 |     "\n", 153 |     "space = Discrete(4)\n", 154 |     "print(space.sample())" 155 |    ] 156 |   }, 157 |   { 158 |    "cell_type": "markdown", 159 |    "id": "94e83338", 160 |    "metadata": {}, 161 |    "source": [ 162 |     "Sampling from a Discrete(4) distribution will randomly generate one of the numbers 0, 1, 2, or 3. These numbers can be interpreted in any way we choose, such as representing the directions \"down,\" \"left,\" \"up,\" and \"right,\" respectively. In order to create a maze and set the position of the player and the goal, we will create a Python class called Environment. This class will be named \"Environment\" because the maze serves as the environment in which the player exists. To make things easier, we will always place the player at the coordinates (0, 0) and the goal at (4, 4). In order to make the player move and attempt to reach the goal, we will initialize the Environment with an action space of Discrete(4)." 163 |    ] 164 |   }, 165 |   { 166 |    "cell_type": "markdown", 167 |    "id": "9789ed9f", 168 |    "metadata": {}, 169 |    "source": [ 170 |     "To complete the setup for our maze environment, we need to encode the seeker's position as a `Discrete(5*5)`. This will allow us to later implement an algorithm that keeps track of which actions lead to successful outcomes for different seeker positions. In reinforcement learning terminology, the information that is accessible to the player is known as an observation. Similarly, we can define an observation space for the seeker. 
The following code demonstrates this concept:" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": null, 176 | "id": "e443af4f", 177 | "metadata": { 178 | "lines_to_next_cell": 1 179 | }, 180 | "outputs": [], 181 | "source": [ 182 | "import os\n", 183 | "\n", 184 | "\n", 185 | "class Environment:\n", 186 | " def __init__(self, *args, **kwargs):\n", 187 | " self.seeker, self.goal = (0, 0), (4, 4)\n", 188 | " self.info = {'seeker': self.seeker, 'goal': self.goal}\n", 189 | "\n", 190 | " self.action_space = Discrete(4)\n", 191 | " self.observation_space = Discrete(5*5)\n", 192 | "\n", 193 | " def reset(self):\n", 194 | " \"\"\"Reset seeker position and return observations.\"\"\"\n", 195 | " self.seeker = (0, 0)\n", 196 | "\n", 197 | " return self.get_observation()\n", 198 | "\n", 199 | " def get_observation(self):\n", 200 | " \"\"\"Encode the seeker position as integer\"\"\"\n", 201 | " return 5 * self.seeker[0] + self.seeker[1]\n", 202 | "\n", 203 | " def get_reward(self):\n", 204 | " \"\"\"Reward finding the goal\"\"\"\n", 205 | " return 1 if self.seeker == self.goal else 0\n", 206 | "\n", 207 | " def is_done(self):\n", 208 | " \"\"\"We're done if we found the goal\"\"\"\n", 209 | " return self.seeker == self.goal\n", 210 | "\n", 211 | " def step(self, action):\n", 212 | " \"\"\"Take a step in a direction and return all available information.\"\"\"\n", 213 | " if action == 0: # move down\n", 214 | " self.seeker = (min(self.seeker[0] + 1, 4), self.seeker[1])\n", 215 | " elif action == 1: # move left\n", 216 | " self.seeker = (self.seeker[0], max(self.seeker[1] - 1, 0))\n", 217 | " elif action == 2: # move up\n", 218 | " self.seeker = (max(self.seeker[0] - 1, 0), self.seeker[1])\n", 219 | " elif action == 3: # move right\n", 220 | " self.seeker = (self.seeker[0], min(self.seeker[1] + 1, 4))\n", 221 | " else:\n", 222 | " raise ValueError(\"Invalid action\")\n", 223 | "\n", 224 | " obs = self.get_observation()\n", 225 | " rew = self.get_reward()\n", 226 | " done = self.is_done()\n", 227 | " return obs, rew, done, self.info\n", 228 | "\n", 229 | " def render(self, *args, **kwargs):\n", 230 | " \"\"\"We override this method here so clear the output in Jupyter notebooks.\n", 231 | " The previous implementation works well in the terminal, but does not clear\n", 232 | " the screen in interactive environments.\n", 233 | " \"\"\"\n", 234 | " os.system('cls' if os.name == 'nt' else 'clear')\n", 235 | " try:\n", 236 | " from IPython.display import clear_output\n", 237 | " clear_output(wait=True)\n", 238 | " except Exception:\n", 239 | " pass\n", 240 | " grid = [['| ' for _ in range(5)] + [\"|\\n\"] for _ in range(5)]\n", 241 | " grid[self.goal[0]][self.goal[1]] = '|G'\n", 242 | " grid[self.seeker[0]][self.seeker[1]] = '|S'\n", 243 | " print(''.join([''.join(grid_row) for grid_row in grid]))" 244 | ] 245 | }, 246 | { 247 | "cell_type": "markdown", 248 | "id": "c456be70", 249 | "metadata": {}, 250 | "source": [ 251 | "We have finished building the Environment class that is used in our 2D-maze game. This class allows us to move through the game, determine when it has ended, and reset it. The player, referred to as the seeker, can also view the game's environment and receive rewards for reaching the goal. Now, we can use this implementation to play a game of finding the goal using a seeker that randomly selects actions by creating a new Environment, applying actions to it, and displaying the environment until the game ends." 
252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": null, 257 | "id": "0f8dc78b", 258 | "metadata": { 259 | "lines_to_next_cell": 2 260 | }, 261 | "outputs": [], 262 | "source": [ 263 | "import time\n", 264 | "\n", 265 | "environment = Environment()\n", 266 | "\n", 267 | "while not environment.is_done():\n", 268 | " random_action = environment.action_space.sample()\n", 269 | " environment.step(random_action)\n", 270 | " time.sleep(0.1)\n", 271 | " environment.render()" 272 | ] 273 | }, 274 | { 275 | "cell_type": "markdown", 276 | "id": "8696b78b", 277 | "metadata": {}, 278 | "source": [ 279 | "If you run the program on your computer, you will eventually see that the seeker has found the goal and the game is over. It may take some time if you are unlucky. While you may argue that this is a very simple problem that can be solved by simply taking 8 steps (4 right and 4 down in any order), the purpose of using machine learning in this situation is to be able to tackle more difficult problems in the future. The idea is to create an algorithm that can figure out how to play the game on its own by observing the game, making decisions about what to do next, and receiving rewards for its actions." 280 | ] 281 | }, 282 | { 283 | "cell_type": "markdown", 284 | "id": "3b35b9a7", 285 | "metadata": {}, 286 | "source": [ 287 | "If you are interested in doing so, now is an appropriate time to make the game more complex on your own. As long as you do not alter the interface established for the Environment class, you have the option to modify the game in numerous ways. Here are a few ideas:\n", 288 | "\n", 289 | "- Make the grid a 10x10 size or randomly determine the initial position of the seeker.\n", 290 | "- Make the outer walls of the grid hazardous. If you try to touch them, you will receive a penalty of -100.\n", 291 | "- Add obstacles in the grid that the seeker cannot pass through.\n", 292 | "\n", 293 | "If you are feeling particularly adventurous, you could also randomly determine the goal position. However, be mindful that the seeker currently has no information about the goal position through the get_observation method. You may want to consider tackling this last exercise after you have completed reading this chapter." 294 | ] 295 | }, 296 | { 297 | "cell_type": "markdown", 298 | "id": "288e6f49", 299 | "metadata": {}, 300 | "source": [ 301 | "## Building a Simulation" 302 | ] 303 | }, 304 | { 305 | "cell_type": "markdown", 306 | "id": "0cb29ae0", 307 | "metadata": {}, 308 | "source": [ 309 | "Now that the Environment class has been implemented, what is required to help the seeker learn how to play the game effectively? How can it consistently find the goal in the minimum required number of eight steps? To assist with this, we have provided the maze environment with reward information so that the seeker can use this to learn how to play the game. In reinforcement learning, the player repeatedly plays the game and learns from their experiences. The player is often referred to as an agent that takes actions in the environment, observes its state, and receives a reward. The better the agent learns, the better it becomes at interpreting the current game state and finding actions that lead to more rewarding outcomes." 
310 | ] 311 | }, 312 | { 313 | "cell_type": "markdown", 314 | "id": "25db3f64", 315 | "metadata": {}, 316 | "source": [ 317 | "In order to use any reinforcement learning algorithm, it is necessary to have a way to simulate the game repeatedly in order to gather experience data. Therefore, we will be creating a simple Simulation class soon. Additionally, we need the concept of a Policy, which is a way to determine the actions to take in a game. Currently, the only option we have is to randomly sample actions for our seeker. However, a `Policy` allows us to select better actions based on the current state of the game. A `Policy` is defined as a class with a `get_action` method that takes a game state and returns an action." 318 | ] 319 | }, 320 | { 321 | "cell_type": "markdown", 322 | "id": "db748def", 323 | "metadata": {}, 324 | "source": [ 325 | "In our game, the seeker has `25` possible states on the grid and can take 4 actions. One strategy is to assign high values to state-action pairs that will result in a high reward and low values to those that will not. For example, moving down or right is usually a good choice, while moving left or up is not. We can create a `25x4` table of all possible state-action pairs and store it in our `Policy`. Then, when given a state, our policy can return the highest value for any action. While figuring out the best values for these pairs is a challenge, we can start by implementing this `Policy` and worry about finding a suitable algorithm later." 326 | ] 327 | }, 328 | { 329 | "cell_type": "code", 330 | "execution_count": null, 331 | "id": "0d8eecd1", 332 | "metadata": {}, 333 | "outputs": [], 334 | "source": [ 335 | "import numpy as np\n", 336 | "\n", 337 | "class Policy:\n", 338 | "\n", 339 | " def __init__(self, env):\n", 340 | " \"\"\"A Policy suggests actions based on the current state.\n", 341 | " We do this by tracking the value of each state-action pair.\n", 342 | " \"\"\"\n", 343 | " self.state_action_table = [\n", 344 | " [0 for _ in range(env.action_space.n)]\n", 345 | " for _ in range(env.observation_space.n)\n", 346 | " ]\n", 347 | " self.action_space = env.action_space\n", 348 | "\n", 349 | " def get_action(self, state, explore=True, epsilon=0.1):\n", 350 | " \"\"\"Explore randomly or exploit the best value currently available.\"\"\"\n", 351 | " if explore and random.uniform(0, 1) < epsilon:\n", 352 | " return self.action_space.sample()\n", 353 | " return np.argmax(self.state_action_table[state])" 354 | ] 355 | }, 356 | { 357 | "cell_type": "markdown", 358 | "id": "23604986", 359 | "metadata": {}, 360 | "source": [ 361 | "I've included a small detail in the Policy definition that could potentially be confusing. The `get_action` method has an explore parameter. The purpose of this is to allow for exploration in situations where the current policy is not effective, such as when it always instructs the player to move left. In other words, sometimes it is necessary to try new approaches instead of relying solely on the current understanding of the game. As previously mentioned, we have not yet discussed how to improve the values in the `state_action_table` for the policy. For now, just keep in mind that the policy provides the actions to follow when playing the maze game." 
362 | ] 363 | }, 364 | { 365 | "cell_type": "markdown", 366 | "id": "9143c282", 367 | "metadata": {}, 368 | "source": [ 369 | "The Simulation class is responsible for running a simulation of a game by taking in an Environment and following a given Policy until the goal is achieved and the game is completed. This process, known as a \"rollout,\" generates a collection of observations and actions, which are referred to as the \"experience\" gained from the simulation. The Simulation class includes a rollout method that executes a full game and returns the resulting experience. The following code represents the implementation of the `Simulation` class:" 370 | ] 371 | }, 372 | { 373 | "cell_type": "code", 374 | "execution_count": null, 375 | "id": "e4de6020", 376 | "metadata": {}, 377 | "outputs": [], 378 | "source": [ 379 | "class Simulation(object):\n", 380 | " def __init__(self, env):\n", 381 | " \"\"\"Simulates rollouts of an environment, given a policy to follow.\"\"\"\n", 382 | " self.env = env\n", 383 | "\n", 384 | " def rollout(self, policy, render=False, explore=True, epsilon=0.1):\n", 385 | " \"\"\"Returns experiences for a policy rollout.\"\"\"\n", 386 | " experiences = []\n", 387 | " state = self.env.reset()\n", 388 | " done = False\n", 389 | " while not done:\n", 390 | " action = policy.get_action(state, explore, epsilon)\n", 391 | " next_state, reward, done, info = self.env.step(action)\n", 392 | " experiences.append([state, action, reward, next_state])\n", 393 | " state = next_state\n", 394 | " if render:\n", 395 | " time.sleep(0.05)\n", 396 | " self.env.render()\n", 397 | "\n", 398 | " return experiences" 399 | ] 400 | }, 401 | { 402 | "cell_type": "markdown", 403 | "id": "0ca73bfa", 404 | "metadata": {}, 405 | "source": [ 406 | "We will be using a specific algorithm to learn from the `experiences` we collect in a rollout, which are comprised of four values: the current state, the action taken, the reward received, and the next state. These values are necessary for our algorithm, but other algorithms may require different experience values. Although our policy has not yet learned anything, we can test its interface by creating a Simulation object, using the rollout method on the policy, and then printing out the `state_action_table` of it." 407 | ] 408 | }, 409 | { 410 | "cell_type": "code", 411 | "execution_count": null, 412 | "id": "3503b079", 413 | "metadata": { 414 | "lines_to_next_cell": 2 415 | }, 416 | "outputs": [], 417 | "source": [ 418 | "untrained_policy = Policy(environment)\n", 419 | "sim = Simulation(environment)\n", 420 | "\n", 421 | "exp = sim.rollout(untrained_policy, render=True, epsilon=1.0)\n", 422 | "for row in untrained_policy.state_action_table:\n", 423 | " print(row)" 424 | ] 425 | }, 426 | { 427 | "cell_type": "markdown", 428 | "id": "7b067645", 429 | "metadata": {}, 430 | "source": [ 431 | "In order to accurately address the issue, both simulation and a policy were utilized. The only remaining task is to create a clever method for updating the policy's internal state based on the collected data so that it is able to effectively learn how to play the maze game." 432 | ] 433 | }, 434 | { 435 | "cell_type": "markdown", 436 | "id": "be4140bb", 437 | "metadata": {}, 438 | "source": [ 439 | "## Training a Reinforcement Learning Model\n", 440 | "\n", 441 | "Suppose we have a set of experiences from a few games. How can we effectively update the values in the `state_action_table` of our `Policy`? 
One way to do this is by considering a specific situation, such as being at position `(4,2)` and deciding to go right, which brings us to position `(4,3)`, just one step away from the goal. In this case, continuing to go right would result in a reward of `1`, indicating that the current state `(4,2)` combined with the action of going right should have a high value. On the other hand, going left in the same situation does not lead to any reward, so the value of the state-action pair involving left movement should be low." 442 |    ] 443 |   }, 444 |   { 445 |    "cell_type": "markdown", 446 |    "id": "2c55c7d5", 447 |    "metadata": {}, 448 |    "source": [ 449 |     "We can look up the expected reward of taking an action from the next state by consulting our state_action_table for the policy. This allows us to evaluate the potential benefits of taking a certain action after reaching the next state. Essentially, this is how we define an experience – by being in a specific state, taking an action that leads to a reward, and then transitioning to the next state.\n", 450 |     "\n", 451 |     "```{python}\n", 452 |     "next_max = np.max(policy.state_action_table[next_state])\n", 453 |     "```" 454 |    ] 455 |   }, 456 |   { 457 |    "cell_type": "markdown", 458 |    "id": "0e4811d3", 459 |    "metadata": {}, 460 |    "source": [ 461 |     "There are several ways to compare this value to the current state-action value, which is the value stored in the policy's state-action table for the given state and action. One option is to calculate a weighted sum of the old value and the expected value, using a formula such as `new_value = 0.9 * value + 0.1 * next_max`. The weights of `0.9` and `0.1` have been chosen to reflect the preference for keeping the old value, and the important factor is that the weights sum to 1. However, this approach does not take into account the important information from the reward, which should be given more trust than the projected `next_max` value. To account for this, it may be beneficial to discount the `next_max` value by 10%. The updated state-action value would then be calculated as follows:\n", 462 |     "\n", 463 |     "```{python}\n", 464 |     "new_value = 0.9 * value + 0.1 * (reward + 0.9 * next_max)\n", 465 |     "```" 466 |    ] 467 |   }, 468 |   { 469 |    "cell_type": "markdown", 470 |    "id": "48569a3f", 471 |    "metadata": {}, 472 |    "source": [ 473 |     "If you do not have a lot of experience with reasoning like this, the previous paragraphs may be overwhelming. However, if you have understood the explanations until now, you will probably find the rest of this chapter easy to follow. Mathematically, this was the most difficult part of this example. If you have worked with RL before, you will have realized that this is an implementation of the Q-Learning algorithm. 
It is called this because the state-action table can be represented as a function `Q(state, action)` that returns values for these pairs.\n", 474 | "\n", 475 | "We’re almost there, so let’s formalize this procedure by implementing an `update_policy` function for a policy and collected experiences:" 476 | ] 477 | }, 478 | { 479 | "cell_type": "code", 480 | "execution_count": null, 481 | "id": "e45eab9e", 482 | "metadata": {}, 483 | "outputs": [], 484 | "source": [ 485 | "def update_policy(policy, experiences, weight=0.1, discount_factor=0.9):\n", 486 | " \"\"\"Updates a given policy with a list of (state, action, reward, state)\n", 487 | " experiences.\"\"\"\n", 488 | " for state, action, reward, next_state in experiences:\n", 489 | " next_max = np.max(policy.state_action_table[next_state])\n", 490 | " value = policy.state_action_table[state][action]\n", 491 | " new_value = (1 - weight) * value + weight * \\\n", 492 | " (reward + discount_factor * next_max)\n", 493 | " policy.state_action_table[state][action] = new_value" 494 | ] 495 | }, 496 | { 497 | "cell_type": "markdown", 498 | "id": "41e925da", 499 | "metadata": {}, 500 | "source": [ 501 | "With this we can easily define a function to train a policy as follows. The `train_policy` function follows the steps of initializing a policy and a simulation, running the simulation multiple times (in this case, 10000 times), collecting the experiences for each game through a rollout, and updating the policy using the update_policy function with the collected experiences." 502 | ] 503 | }, 504 | { 505 | "cell_type": "code", 506 | "execution_count": null, 507 | "id": "78d7fc03", 508 | "metadata": { 509 | "lines_to_next_cell": 2 510 | }, 511 | "outputs": [], 512 | "source": [ 513 | "def train_policy(env, num_episodes=10000, weight=0.1, discount_factor=0.9):\n", 514 | " \"\"\"Training a policy by updating it with rollout experiences.\"\"\"\n", 515 | " policy = Policy(env)\n", 516 | " sim = Simulation(env)\n", 517 | " for _ in range(num_episodes):\n", 518 | " experiences = sim.rollout(policy)\n", 519 | " update_policy(policy, experiences, weight, discount_factor)\n", 520 | "\n", 521 | " return policy\n", 522 | "\n", 523 | "\n", 524 | "trained_policy = train_policy(environment)" 525 | ] 526 | }, 527 | { 528 | "cell_type": "markdown", 529 | "id": "c0af2463", 530 | "metadata": {}, 531 | "source": [ 532 | "In the field of RL, a full play-through of the maze game is referred to as an episode. That is why the train_policy function has an argument called num_episodes, rather than num_games. Now that we have a trained policy, we want to see how well it performs. Previously in this chapter, we ran random policies a couple of times to get an idea of their effectiveness in the maze problem. However, we now want to properly evaluate our trained policy over multiple games and see how it performs on average. Specifically, we will run our simulation for several episodes and measure the number of steps it takes to reach the goal in each episode. To do this, we will implement an evaluate_policy function that accomplishes this task." 
533 | ] 534 | }, 535 | { 536 | "cell_type": "code", 537 | "execution_count": null, 538 | "id": "d9329a87", 539 | "metadata": { 540 | "lines_to_next_cell": 2 541 | }, 542 | "outputs": [], 543 | "source": [ 544 | "def evaluate_policy(env, policy, num_episodes=10):\n", 545 | " \"\"\"Evaluate a trained policy through rollouts.\"\"\"\n", 546 | " simulation = Simulation(env)\n", 547 | " steps = 0\n", 548 | "\n", 549 | " for _ in range(num_episodes):\n", 550 | " experiences = simulation.rollout(policy, render=True, explore=False)\n", 551 | " steps += len(experiences)\n", 552 | "\n", 553 | " print(f\"{steps / num_episodes} steps on average \"\n", 554 | " f\"for a total of {num_episodes} episodes.\")\n", 555 | "\n", 556 | " return steps / num_episodes\n", 557 | "\n", 558 | "\n", 559 | "evaluate_policy(environment, trained_policy)" 560 | ] 561 | }, 562 | { 563 | "cell_type": "markdown", 564 | "id": "703b9258", 565 | "metadata": {}, 566 | "source": [ 567 | "To summarize, the policy that has been trained is capable of finding the best solutions for the maze game. This means that you have successfully implemented your first RL algorithm from scratch. Now, consider whether the evaluation function would still be effective if the seeker was placed in random starting positions. Try making the necessary adjustments to find out. Additionally, think about the assumptions that were made when designing the algorithm, such as the assumption that all state-action pairs can be listed. Would this algorithm still work effectively if there were millions of states and thousands of actions?" 568 | ] 569 | }, 570 | { 571 | "cell_type": "markdown", 572 | "id": "d8eac575", 573 | "metadata": {}, 574 | "source": [ 575 | "## Building a Distributed Ray App" 576 | ] 577 | }, 578 | { 579 | "cell_type": "markdown", 580 | "id": "66f902cc", 581 | "metadata": {}, 582 | "source": [ 583 | "You might be wondering how this example relates to Ray. To turn this RL experiment into a distributed Ray application, we only need to write three code snippets. First, we will make the Simulation a Ray actor with a few lines of code. Next, we will define a parallel version of train_policy that is similar in structure to the original, but only parallelizes the rollouts, not the policy updates. Finally, we will train and evaluate the policy as before, but using the parallel version of train_policy.\n", 584 | "\n", 585 | "The first step is to implement a Ray actor called `SimulationActor`:" 586 | ] 587 | }, 588 | { 589 | "cell_type": "code", 590 | "execution_count": null, 591 | "id": "91d688d9", 592 | "metadata": {}, 593 | "outputs": [], 594 | "source": [ 595 | "import ray\n", 596 | "\n", 597 | "ray.init()\n", 598 | "\n", 599 | "@ray.remote\n", 600 | "class SimulationActor(Simulation):\n", 601 | " \"\"\"Ray actor for a Simulation.\"\"\"\n", 602 | " def __init__(self):\n", 603 | " env = Environment()\n", 604 | " super().__init__(env)" 605 | ] 606 | }, 607 | { 608 | "cell_type": "markdown", 609 | "id": "c00550e7", 610 | "metadata": {}, 611 | "source": [ 612 | "You should be able to understand the code presented in this section thanks to the knowledge of Ray Core that you gained in Chapter 2. While it may take some time to become comfortable with writing this type of code yourself, the concepts should be familiar to you. 
In the following example, we will demonstrate how to use a local Ray cluster to distribute the workload of reinforcement learning (RL) by creating a policy on the driver and four `SimulationActor` instances to perform distributed rollouts. We will store the policy in the object store with ray.put and pass it as an argument to the remote rollout calls to gather experiences over a specified number of training episodes. The finished rollouts will be retrieved with ray.wait, taking into account that some may finish before others, and the policy will be updated with the collected experiences. Finally, the trained policy will be returned." 613 |    ] 614 |   }, 615 |   { 616 |    "cell_type": "code", 617 |    "execution_count": null, 618 |    "id": "dc254782", 619 |    "metadata": {}, 620 |    "outputs": [], 621 |    "source": [ 622 |     "def train_policy_parallel(env, num_episodes=1000, num_simulations=4):\n", 623 |     "    \"\"\"Parallel policy training function.\"\"\"\n", 624 |     "    policy = Policy(env)\n", 625 |     "    simulations = [SimulationActor.remote() for _ in range(num_simulations)]\n", 626 |     "\n", 627 |     "    policy_ref = ray.put(policy)\n", 628 |     "    for _ in range(num_episodes):\n", 629 |     "        experiences = [sim.rollout.remote(policy_ref) for sim in simulations]\n", 630 |     "\n", 631 |     "        while len(experiences) > 0:\n", 632 |     "            finished, experiences = ray.wait(experiences)\n", 633 |     "            for xp in ray.get(finished):\n", 634 |     "                update_policy(policy, xp)\n", 635 |     "\n", 636 |     "    return policy" 637 |    ] 638 |   }, 639 |   { 640 |    "cell_type": "markdown", 641 |    "id": "8a5acf70", 642 |    "metadata": {}, 643 |    "source": [ 644 |     "This allows us to take the last step and run the training procedure in parallel and then\n", 645 |     "evaluate the resulting policy as before:" 646 |    ] 647 |   }, 648 |   { 649 |    "cell_type": "code", 650 |    "execution_count": null, 651 |    "id": "43af8205", 652 |    "metadata": {}, 653 |    "outputs": [], 654 |    "source": [ 655 |     "parallel_policy = train_policy_parallel(environment)\n", 656 |     "evaluate_policy(environment, parallel_policy)" 657 |    ] 658 |   }, 659 |   { 660 |    "cell_type": "markdown", 661 |    "id": "97ebb498", 662 |    "metadata": {}, 663 |    "source": [ 664 |     "The output of the two lines is the same as when we previously ran the serial version of the RL training for the maze. It's helpful to compare `train_policy_parallel` and `train_policy` line by line because they have the same overall structure. All we had to do to parallelize the training process was to use the `ray.remote` decorator on a class appropriately and then make the correct remote calls. It takes some experience to do this correctly, but it's worth noting how little time we spent considering distributed computing and how much time we were able to focus on the actual application code. We didn't need to completely change our programming approach and were able to handle the problem in a natural way. This is what we want, and Ray excels at providing this kind of flexibility. \n", 665 |     "\n", 666 |     "To conclude, let's briefly look at the execution graph of the Ray application we just created, as shown in the following figure.
667 | ] 668 | }, 669 | { 670 | "cell_type": "markdown", 671 | "id": "5eaaf094", 672 | "metadata": {}, 673 | "source": [ 674 | "![Task dependency](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_03/train_policy.png)" 675 | ] 676 | } 677 | ], 678 | "metadata": { 679 | "jupytext": { 680 | "cell_metadata_filter": "-all", 681 | "main_language": "python", 682 | "notebook_metadata_filter": "-all" 683 | }, 684 | "kernelspec": { 685 | "display_name": "Python 3 (ipykernel)", 686 | "language": "python", 687 | "name": "python3" 688 | }, 689 | "language_info": { 690 | "codemirror_mode": { 691 | "name": "ipython", 692 | "version": 3 693 | }, 694 | "file_extension": ".py", 695 | "mimetype": "text/x-python", 696 | "name": "python", 697 | "nbconvert_exporter": "python", 698 | "pygments_lexer": "ipython3", 699 | "version": "3.9.13" 700 | } 701 | }, 702 | "nbformat": 4, 703 | "nbformat_minor": 5 704 | } 705 | -------------------------------------------------------------------------------- /notebooks/ch_04_rllib.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "780507ca", 6 | "metadata": {}, 7 | "source": [ 8 | "# Reinforcement Learning with Ray RLlib" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "2856c7b0", 14 | "metadata": {}, 15 | "source": [ 16 | "\n", 17 | "You can run this notebook directly in\n", 18 | "[Colab](https://colab.research.google.com/github/maxpumperla/learning_ray/blob/main/notebooks/ch_04_rllib.ipynb).\n", 19 | "\n", 20 | "\"Open" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "id": "3ea5e1f7", 26 | "metadata": {}, 27 | "source": [ 28 | "For this chapter you need to install the following dependencies:" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "a9d6d409", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "! pip install \"ray[rllib]==2.2.0\"" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "id": "8b17e3aa", 44 | "metadata": {}, 45 | "source": [ 46 | "\n", 47 | "To import utility files for this chapter, on Colab you will also have to clone\n", 48 | "the repo and copy the code files to the base path of the runtime:" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "id": "8206edba", 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "!git clone https://github.com/maxpumperla/learning_ray\n", 59 | "%cp -r learning_ray/notebooks/* ." 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "id": "4f8c77f4", 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "import gym\n", 70 | "\n", 71 | "\n", 72 | "class Env:\n", 73 | "\n", 74 | " action_space: gym.spaces.Space\n", 75 | " observation_space: gym.spaces.Space\n", 76 | "\n", 77 | " def step(self, action):\n", 78 | " ...\n", 79 | "\n", 80 | " def reset(self):\n", 81 | " ...\n", 82 | "\n", 83 | " def render(self, mode=\"human\"):\n", 84 | " ..." 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "id": "f72525bf", 90 | "metadata": {}, 91 | "source": [ 92 | "In `maze.py` we set `num_rollout_workers=0` for this notebook, so that the code works in Colab. In the book itself we use 2 rollout workers to show that experience collection can be distributed by RLlib." 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "id": "83a6734a", 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "! 
rllib train file maze.py --stop '{\"timesteps_total\": 10000}'" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "id": "e8dfc658", 108 | "metadata": { 109 | "lines_to_next_cell": 2 110 | }, 111 | "source": [ 112 | "\n", 113 | "Try:\n", 114 | "rllib evaluate ~/ray_results/maze_env/\\\n", 115 | " --algo DQN\\\n", 116 | " --env maze_gym_env.Environment\\\n", 117 | " --steps 100" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "id": "965ee003", 124 | "metadata": { 125 | "lines_to_next_cell": 2 126 | }, 127 | "outputs": [], 128 | "source": [ 129 | "from ray.tune.logger import pretty_print\n", 130 | "from maze_gym_env import GymEnvironment\n", 131 | "from ray.rllib.algorithms.dqn import DQNConfig\n", 132 | "\n", 133 | "config = (DQNConfig().environment(GymEnvironment)\n", 134 | " .rollouts(num_rollout_workers=2, create_env_on_local_worker=True))\n", 135 | "\n", 136 | "pretty_print(config.to_dict())\n", 137 | "\n", 138 | "algo = config.build()\n", 139 | "\n", 140 | "for i in range(10):\n", 141 | " result = algo.train()\n", 142 | "\n", 143 | "print(pretty_print(result))" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "id": "283bf0f8", 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "from ray.rllib.algorithms.algorithm import Algorithm\n", 154 | "\n", 155 | "\n", 156 | "checkpoint = algo.save()\n", 157 | "print(checkpoint)\n", 158 | "\n", 159 | "evaluation = algo.evaluate()\n", 160 | "print(pretty_print(evaluation))\n", 161 | "\n", 162 | "algo.stop()\n", 163 | "restored_algo = Algorithm.from_checkpoint(checkpoint)\n", 164 | "\n", 165 | "algo = restored_algo" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "id": "e4e3c637", 172 | "metadata": { 173 | "lines_to_next_cell": 2 174 | }, 175 | "outputs": [], 176 | "source": [ 177 | "env = GymEnvironment()\n", 178 | "done = False\n", 179 | "total_reward = 0\n", 180 | "observations = env.reset()\n", 181 | "\n", 182 | "while not done:\n", 183 | " action = algo.compute_single_action(observations)\n", 184 | " observations, reward, done, info = env.step(action)\n", 185 | " total_reward += reward" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": null, 191 | "id": "6230b49d", 192 | "metadata": {}, 193 | "outputs": [], 194 | "source": [ 195 | "action = algo.compute_actions(\n", 196 | " {\"obs_1\": observations, \"obs_2\": observations}\n", 197 | ")\n", 198 | "print(action)\n", 199 | "# {'obs_1': 0, 'obs_2': 1}" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "id": "51d2c31a", 206 | "metadata": {}, 207 | "outputs": [], 208 | "source": [ 209 | "policy = algo.get_policy()\n", 210 | "print(policy.get_weights())\n", 211 | "\n", 212 | "model = policy.model" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": null, 218 | "id": "83366455", 219 | "metadata": {}, 220 | "outputs": [], 221 | "source": [ 222 | "workers = algo.workers\n", 223 | "workers.foreach_worker(\n", 224 | " lambda remote_trainer: remote_trainer.get_policy().get_weights()\n", 225 | ")" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": null, 231 | "id": "371e167b", 232 | "metadata": {}, 233 | "outputs": [], 234 | "source": [ 235 | "model.base_model.summary()" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": null, 241 | "id": "5a539340", 242 | "metadata": { 243 | "lines_to_next_cell": 2 244 | }, 245 | 
"outputs": [], 246 | "source": [ 247 | "from ray.rllib.models.preprocessors import get_preprocessor\n", 248 | "\n", 249 | "\n", 250 | "env = GymEnvironment()\n", 251 | "obs_space = env.observation_space\n", 252 | "preprocessor = get_preprocessor(obs_space)(obs_space)\n", 253 | "\n", 254 | "observations = env.reset()\n", 255 | "transformed = preprocessor.transform(observations).reshape(1, -1)\n", 256 | "\n", 257 | "model_output, _ = model({\"obs\": transformed})" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": null, 263 | "id": "0f47a55f", 264 | "metadata": { 265 | "lines_to_next_cell": 2 266 | }, 267 | "outputs": [], 268 | "source": [ 269 | "q_values = model.get_q_value_distributions(model_output)\n", 270 | "print(q_values)\n", 271 | "\n", 272 | "action_distribution = policy.dist_class(model_output, model)\n", 273 | "sample = action_distribution.sample()\n", 274 | "print(sample)" 275 | ] 276 | }, 277 | { 278 | "cell_type": "markdown", 279 | "id": "94dc4ca5", 280 | "metadata": {}, 281 | "source": [ 282 | "\n", 283 | "![RLlib Environments](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_04/rllib_envs.png)" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": null, 289 | "id": "9081e6a5", 290 | "metadata": { 291 | "lines_to_next_cell": 0 292 | }, 293 | "outputs": [], 294 | "source": [ 295 | "from ray.rllib.env.multi_agent_env import MultiAgentEnv\n", 296 | "from gym.spaces import Discrete\n", 297 | "import os\n", 298 | "\n", 299 | "\n", 300 | "class MultiAgentMaze(MultiAgentEnv):\n", 301 | "\n", 302 | " def __init__(self, *args, **kwargs):\n", 303 | " self.action_space = Discrete(4)\n", 304 | " self.observation_space = Discrete(5*5)\n", 305 | " self.agents = {1: (4, 0), 2: (0, 4)}\n", 306 | " self.goal = (4, 4)\n", 307 | " self.info = {1: {'obs': self.agents[1]}, 2: {'obs': self.agents[2]}}\n", 308 | "\n", 309 | " def reset(self):\n", 310 | " self.agents = {1: (4, 0), 2: (0, 4)}\n", 311 | "\n", 312 | " return {1: self.get_observation(1), 2: self.get_observation(2)}\n", 313 | "\n", 314 | " def get_observation(self, agent_id):\n", 315 | " seeker = self.agents[agent_id]\n", 316 | " return 5 * seeker[0] + seeker[1]\n", 317 | "\n", 318 | " def get_reward(self, agent_id):\n", 319 | " return 1 if self.agents[agent_id] == self.goal else 0\n", 320 | "\n", 321 | " def is_done(self, agent_id):\n", 322 | " return self.agents[agent_id] == self.goal\n", 323 | "\n", 324 | " def step(self, action):\n", 325 | " agent_ids = action.keys()\n", 326 | "\n", 327 | " for agent_id in agent_ids:\n", 328 | " seeker = self.agents[agent_id]\n", 329 | " if action[agent_id] == 0: # move down\n", 330 | " seeker = (min(seeker[0] + 1, 4), seeker[1])\n", 331 | " elif action[agent_id] == 1: # move left\n", 332 | " seeker = (seeker[0], max(seeker[1] - 1, 0))\n", 333 | " elif action[agent_id] == 2: # move up\n", 334 | " seeker = (max(seeker[0] - 1, 0), seeker[1])\n", 335 | " elif action[agent_id] == 3: # move right\n", 336 | " seeker = (seeker[0], min(seeker[1] + 1, 4))\n", 337 | " else:\n", 338 | " raise ValueError(\"Invalid action\")\n", 339 | " self.agents[agent_id] = seeker\n", 340 | "\n", 341 | " observations = {i: self.get_observation(i) for i in agent_ids}\n", 342 | " rewards = {i: self.get_reward(i) for i in agent_ids}\n", 343 | " done = {i: self.is_done(i) for i in agent_ids}\n", 344 | "\n", 345 | " done[\"__all__\"] = all(done.values())\n", 346 | "\n", 347 | " return observations, rewards, done, self.info\n", 348 | "\n", 349 | 
" def render(self, *args, **kwargs):\n", 350 | " \"\"\"We override this method here so clear the output in Jupyter notebooks.\n", 351 | " The previous implementation works well in the terminal, but does not clear\n", 352 | " the screen in interactive environments.\n", 353 | " \"\"\"\n", 354 | " os.system('cls' if os.name == 'nt' else 'clear')\n", 355 | " try:\n", 356 | " from IPython.display import clear_output\n", 357 | " clear_output(wait=True)\n", 358 | " except Exception:\n", 359 | " pass\n", 360 | " grid = [['| ' for _ in range(5)] + [\"|\\n\"] for _ in range(5)]\n", 361 | " grid[self.goal[0]][self.goal[1]] = '|G'\n", 362 | " grid[self.agents[1][0]][self.agents[1][1]] = '|1'\n", 363 | " grid[self.agents[2][0]][self.agents[2][1]] = '|2'\n", 364 | " grid[self.agents[2][0]][self.agents[2][1]] = '|2'\n", 365 | " print(''.join([''.join(grid_row) for grid_row in grid]))" 366 | ] 367 | }, 368 | { 369 | "cell_type": "markdown", 370 | "id": "7a3057fa", 371 | "metadata": {}, 372 | "source": [ 373 | "![RLlib Mapping Envs](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_04/mapping_envs.png)" 374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "execution_count": null, 379 | "id": "3b74fd5b", 380 | "metadata": { 381 | "lines_to_next_cell": 2 382 | }, 383 | "outputs": [], 384 | "source": [ 385 | "import time\n", 386 | "\n", 387 | "env = MultiAgentMaze()\n", 388 | "\n", 389 | "while True:\n", 390 | " obs, rew, done, info = env.step(\n", 391 | " {1: env.action_space.sample(), 2: env.action_space.sample()}\n", 392 | " )\n", 393 | " time.sleep(0.1)\n", 394 | " env.render()\n", 395 | " if any(done.values()):\n", 396 | " break" 397 | ] 398 | }, 399 | { 400 | "cell_type": "code", 401 | "execution_count": null, 402 | "id": "417642b6", 403 | "metadata": { 404 | "lines_to_next_cell": 2 405 | }, 406 | "outputs": [], 407 | "source": [ 408 | "from ray.rllib.algorithms.dqn import DQNConfig\n", 409 | "\n", 410 | "simple_trainer = DQNConfig().environment(env=MultiAgentMaze).build()\n", 411 | "simple_trainer.train()\n", 412 | "\n", 413 | "algo = DQNConfig()\\\n", 414 | " .environment(env=MultiAgentMaze)\\\n", 415 | " .multi_agent(\n", 416 | " policies={\n", 417 | " \"policy_1\": (\n", 418 | " None, env.observation_space, env.action_space, {\"gamma\": 0.80}\n", 419 | " ),\n", 420 | " \"policy_2\": (\n", 421 | " None, env.observation_space, env.action_space, {\"gamma\": 0.95}\n", 422 | " ),\n", 423 | " },\n", 424 | " policy_mapping_fn = lambda agent_id: f\"policy_{agent_id}\",\n", 425 | " ).build()\n", 426 | "\n", 427 | "print(algo.train())" 428 | ] 429 | }, 430 | { 431 | "cell_type": "markdown", 432 | "id": "9b65341b", 433 | "metadata": {}, 434 | "source": [ 435 | "![RLlib External Envs](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_04/rllib_external.png)" 436 | ] 437 | }, 438 | { 439 | "cell_type": "code", 440 | "execution_count": null, 441 | "id": "897b4d78", 442 | "metadata": { 443 | "lines_to_next_cell": 1 444 | }, 445 | "outputs": [], 446 | "source": [ 447 | "from gym.spaces import Discrete\n", 448 | "import random\n", 449 | "import os\n", 450 | "\n", 451 | "\n", 452 | "class AdvancedEnv(GymEnvironment):\n", 453 | "\n", 454 | " def __init__(self, seeker=None, *args, **kwargs):\n", 455 | " super().__init__(*args, **kwargs)\n", 456 | " self.maze_len = 11\n", 457 | " self.action_space = Discrete(4)\n", 458 | " self.observation_space = Discrete(self.maze_len * self.maze_len)\n", 459 | "\n", 460 | " if seeker:\n", 461 | " assert 
0 <= seeker[0] < self.maze_len and \\\n", 462 | " 0 <= seeker[1] < self.maze_len\n", 463 | " self.seeker = seeker\n", 464 | " else:\n", 465 | " self.reset()\n", 466 | "\n", 467 | " self.goal = (self.maze_len-1, self.maze_len-1)\n", 468 | " self.info = {'seeker': self.seeker, 'goal': self.goal}\n", 469 | "\n", 470 | " self.punish_states = [\n", 471 | " (i, j) for i in range(self.maze_len) for j in range(self.maze_len)\n", 472 | " if i % 2 == 1 and j % 2 == 0\n", 473 | " ]\n", 474 | "\n", 475 | " def reset(self):\n", 476 | " \"\"\"Reset seeker position randomly, return observations.\"\"\"\n", 477 | " self.seeker = (\n", 478 | " random.randint(0, self.maze_len - 1),\n", 479 | " random.randint(0, self.maze_len - 1)\n", 480 | " )\n", 481 | " return self.get_observation()\n", 482 | "\n", 483 | " def get_observation(self):\n", 484 | " \"\"\"Encode the seeker position as integer\"\"\"\n", 485 | " return self.maze_len * self.seeker[0] + self.seeker[1]\n", 486 | "\n", 487 | " def get_reward(self):\n", 488 | " \"\"\"Reward finding the goal and punish forbidden states\"\"\"\n", 489 | " reward = -1 if self.seeker in self.punish_states else 0\n", 490 | " reward += 5 if self.seeker == self.goal else 0\n", 491 | " return reward\n", 492 | "\n", 493 | " def render(self, *args, **kwargs):\n", 494 | " \"\"\"We override this method here so clear the output in Jupyter notebooks.\n", 495 | " The previous implementation works well in the terminal, but does not clear\n", 496 | " the screen in interactive environments.\n", 497 | " \"\"\"\n", 498 | " os.system('cls' if os.name == 'nt' else 'clear')\n", 499 | " try:\n", 500 | " from IPython.display import clear_output\n", 501 | " clear_output(wait=True)\n", 502 | " except Exception:\n", 503 | " pass\n", 504 | " grid = [['| ' for _ in range(self.maze_len)] +\n", 505 | " [\"|\\n\"] for _ in range(self.maze_len)]\n", 506 | " for punish in self.punish_states:\n", 507 | " grid[punish[0]][punish[1]] = '|X'\n", 508 | " grid[self.goal[0]][self.goal[1]] = '|G'\n", 509 | " grid[self.seeker[0]][self.seeker[1]] = '|S'\n", 510 | " print(''.join([''.join(grid_row) for grid_row in grid]))" 511 | ] 512 | }, 513 | { 514 | "cell_type": "code", 515 | "execution_count": null, 516 | "id": "eeec1d65", 517 | "metadata": {}, 518 | "outputs": [], 519 | "source": [ 520 | "from ray.rllib.env.apis.task_settable_env import TaskSettableEnv\n", 521 | "\n", 522 | "\n", 523 | "class CurriculumEnv(AdvancedEnv, TaskSettableEnv):\n", 524 | "\n", 525 | " def __init__(self, *args, **kwargs):\n", 526 | " AdvancedEnv.__init__(self)\n", 527 | "\n", 528 | " def difficulty(self):\n", 529 | " return abs(self.seeker[0] - self.goal[0]) + \\\n", 530 | " abs(self.seeker[1] - self.goal[1])\n", 531 | "\n", 532 | " def get_task(self):\n", 533 | " return self.difficulty()\n", 534 | "\n", 535 | " def set_task(self, task_difficulty):\n", 536 | " while not self.difficulty() <= task_difficulty:\n", 537 | " self.reset()" 538 | ] 539 | }, 540 | { 541 | "cell_type": "code", 542 | "execution_count": null, 543 | "id": "36dcd250", 544 | "metadata": {}, 545 | "outputs": [], 546 | "source": [ 547 | "def curriculum_fn(train_results, task_settable_env, env_ctx):\n", 548 | " time_steps = train_results.get(\"timesteps_total\")\n", 549 | " difficulty = time_steps // 1000\n", 550 | " print(f\"Current difficulty: {difficulty}\")\n", 551 | " return difficulty" 552 | ] 553 | }, 554 | { 555 | "cell_type": "code", 556 | "execution_count": null, 557 | "id": "4db20214", 558 | "metadata": { 559 | "lines_to_next_cell": 2 560 | }, 561 | 
"outputs": [], 562 | "source": [ 563 | "from ray.rllib.algorithms.dqn import DQNConfig\n", 564 | "import tempfile\n", 565 | "\n", 566 | "\n", 567 | "temp = tempfile.mkdtemp()\n", 568 | "\n", 569 | "trainer = (\n", 570 | " DQNConfig()\n", 571 | " .environment(env=CurriculumEnv, env_task_fn=curriculum_fn)\n", 572 | " .offline_data(output=temp)\n", 573 | " .build()\n", 574 | ")\n", 575 | "\n", 576 | "for i in range(15):\n", 577 | " trainer.train()" 578 | ] 579 | }, 580 | { 581 | "cell_type": "code", 582 | "execution_count": null, 583 | "id": "e878b6b2", 584 | "metadata": {}, 585 | "outputs": [], 586 | "source": [ 587 | "imitation_algo = (\n", 588 | " DQNConfig()\n", 589 | " .environment(env=AdvancedEnv)\n", 590 | " .evaluation(off_policy_estimation_methods={})\n", 591 | " .offline_data(input_=temp)\n", 592 | " .exploration(explore=False)\n", 593 | " .build())\n", 594 | "\n", 595 | "for i in range(10):\n", 596 | " imitation_algo.train()\n", 597 | "\n", 598 | "imitation_algo.evaluate()" 599 | ] 600 | } 601 | ], 602 | "metadata": { 603 | "jupytext": { 604 | "cell_metadata_filter": "-all", 605 | "main_language": "python", 606 | "notebook_metadata_filter": "-all" 607 | }, 608 | "kernelspec": { 609 | "display_name": "Python 3 (ipykernel)", 610 | "language": "python", 611 | "name": "python3" 612 | }, 613 | "language_info": { 614 | "codemirror_mode": { 615 | "name": "ipython", 616 | "version": 3 617 | }, 618 | "file_extension": ".py", 619 | "mimetype": "text/x-python", 620 | "name": "python", 621 | "nbconvert_exporter": "python", 622 | "pygments_lexer": "ipython3", 623 | "version": "3.9.13" 624 | } 625 | }, 626 | "nbformat": 4, 627 | "nbformat_minor": 5 628 | } 629 | -------------------------------------------------------------------------------- /notebooks/ch_05_tune.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "3ed82bf7", 6 | "metadata": {}, 7 | "source": [ 8 | "# Hyperparameter Optimization with Ray Tune" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "344a719e", 14 | "metadata": {}, 15 | "source": [ 16 | "\n", 17 | "You can run this notebook directly in\n", 18 | "[Colab](https://colab.research.google.com/github/maxpumperla/learning_ray/blob/main/notebooks/ch_05_tune.ipynb).\n", 19 | "\n", 20 | "\"Open" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "id": "bfdc058b", 26 | "metadata": {}, 27 | "source": [ 28 | "For this chapter you need to install the following dependencies:" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "4c21e464", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "! pip install \"ray[tune]==2.2.0\"\n", 39 | "! pip install \"hyperopt==0.2.7\"\n", 40 | "! pip install \"bayesian-optimization==1.3.1\"\n", 41 | "! pip install \"tensorflow>=2.9.0\"" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "id": "5c0af473", 47 | "metadata": {}, 48 | "source": [ 49 | "\n", 50 | "To import utility files for this chapter, on Colab you will also have to clone\n", 51 | "the repo and copy the code files to the base path of the runtime:" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "id": "3e740de8", 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "!git clone https://github.com/maxpumperla/learning_ray\n", 62 | "%cp -r learning_ray/notebooks/* ." 
63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "id": "4fd95122", 69 | "metadata": { 70 | "lines_to_next_cell": 1 71 | }, 72 | "outputs": [], 73 | "source": [ 74 | "from maze_gym_env import Environment\n", 75 | "import time\n", 76 | "import numpy as np\n", 77 | "\n", 78 | "class Policy:\n", 79 | "\n", 80 | " def __init__(self, env):\n", 81 | " \"\"\"A Policy suggests actions based on the current state.\n", 82 | " We do this by tracking the value of each state-action pair.\n", 83 | " \"\"\"\n", 84 | " self.state_action_table = [\n", 85 | " [0 for _ in range(env.action_space.n)]\n", 86 | " for _ in range(env.observation_space.n)\n", 87 | " ]\n", 88 | " self.action_space = env.action_space\n", 89 | "\n", 90 | " def get_action(self, state, explore=True, epsilon=0.1):\n", 91 | " \"\"\"Explore randomly or exploit the best value currently available.\"\"\"\n", 92 | " if explore and random.uniform(0, 1) < epsilon:\n", 93 | " return self.action_space.sample()\n", 94 | " return np.argmax(self.state_action_table[state])\n", 95 | "\n", 96 | "\n", 97 | "class Simulation(object):\n", 98 | " def __init__(self, env):\n", 99 | " \"\"\"Simulates rollouts of an environment, given a policy to follow.\"\"\"\n", 100 | " self.env = env\n", 101 | "\n", 102 | " def rollout(self, policy, render=False, explore=True, epsilon=0.1):\n", 103 | " \"\"\"Returns experiences for a policy rollout.\"\"\"\n", 104 | " experiences = []\n", 105 | " state = self.env.reset()\n", 106 | " done = False\n", 107 | " while not done:\n", 108 | " action = policy.get_action(state, explore, epsilon)\n", 109 | " next_state, reward, done, info = self.env.step(action)\n", 110 | " experiences.append([state, action, reward, next_state])\n", 111 | " state = next_state\n", 112 | " if render:\n", 113 | " time.sleep(0.05)\n", 114 | " self.env.render()\n", 115 | "\n", 116 | " return experiences\n", 117 | "\n", 118 | "\n", 119 | "def update_policy(policy, experiences, weight=0.1, discount_factor=0.9):\n", 120 | " \"\"\"Updates a given policy with a list of (state, action, reward, state)\n", 121 | " experiences.\"\"\"\n", 122 | " for state, action, reward, next_state in experiences:\n", 123 | " next_max = np.max(policy.state_action_table[next_state])\n", 124 | " value = policy.state_action_table[state][action]\n", 125 | " new_value = (1 - weight) * value + weight * \\\n", 126 | " (reward + discount_factor * next_max)\n", 127 | " policy.state_action_table[state][action] = new_value\n", 128 | "\n", 129 | "\n", 130 | "def train_policy(env, num_episodes=10000, weight=0.1, discount_factor=0.9):\n", 131 | " \"\"\"Training a policy by updating it with rollout experiences.\"\"\"\n", 132 | " policy = Policy(env)\n", 133 | " sim = Simulation(env)\n", 134 | " for _ in range(num_episodes):\n", 135 | " experiences = sim.rollout(policy)\n", 136 | " update_policy(policy, experiences, weight, discount_factor)\n", 137 | "\n", 138 | " return policy\n", 139 | "\n", 140 | "\n", 141 | "def evaluate_policy(env, policy, num_episodes=10):\n", 142 | " \"\"\"Evaluate a trained policy through rollouts.\"\"\"\n", 143 | " simulation = Simulation(env)\n", 144 | " steps = 0\n", 145 | "\n", 146 | " for _ in range(num_episodes):\n", 147 | " experiences = simulation.rollout(policy, render=True, explore=False)\n", 148 | " steps += len(experiences)\n", 149 | "\n", 150 | " print(f\"{steps / num_episodes} steps on average \"\n", 151 | " f\"for a total of {num_episodes} episodes.\")\n", 152 | "\n", 153 | " return steps / num_episodes" 154 | ] 155 | }, 156 | 
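The cell above defines four building blocks: a tabular `Policy`, a `Simulation` that produces rollouts, `update_policy`, and the `train_policy`/`evaluate_policy` helpers. Before Ray and Tune take over the hyperparameter search in the next cells, a minimal single-run usage could look like the sketch below (not one of the notebook's cells); it assumes the `Environment` class imported from `maze_gym_env.py` above, and the episode counts are purely illustrative:

```python
import random  # Policy.get_action relies on random for epsilon-greedy exploration

from maze_gym_env import Environment

env = Environment()

# Train a tabular Q-learning policy with fixed hyperparameters, then roll it
# out greedily to see how many steps it needs on average to reach the goal.
policy = train_policy(env, num_episodes=2000, weight=0.1, discount_factor=0.9)
average_steps = evaluate_policy(env, policy, num_episodes=5)
print(f"Average steps to reach the goal: {average_steps}")
```

The random-search and Tune cells that follow wrap exactly this train-then-evaluate loop inside an objective function.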
{ 157 | "cell_type": "markdown", 158 | "source": [ 159 | "![Tune Flow](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_05/tune_flow.png)\n" 160 | ], 161 | "metadata": { 162 | "collapsed": false 163 | } 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "id": "68b77888", 169 | "metadata": { 170 | "lines_to_next_cell": 2 171 | }, 172 | "outputs": [], 173 | "source": [ 174 | "import random\n", 175 | "search_space = []\n", 176 | "for i in range(10):\n", 177 | " random_choice = {\n", 178 | " 'weight': random.uniform(0, 1),\n", 179 | " 'discount_factor': random.uniform(0, 1)\n", 180 | " }\n", 181 | " search_space.append(random_choice)" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "id": "53678cc4", 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "import ray\n", 192 | "\n", 193 | "\n", 194 | "@ray.remote\n", 195 | "def objective(config):\n", 196 | " environment = Environment()\n", 197 | " policy = train_policy(\n", 198 | " environment,\n", 199 | " weight=config[\"weight\"],\n", 200 | " discount_factor=config[\"discount_factor\"]\n", 201 | " )\n", 202 | " score = evaluate_policy(environment, policy)\n", 203 | " return [score, config]" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": null, 209 | "id": "699de128", 210 | "metadata": { 211 | "lines_to_next_cell": 2 212 | }, 213 | "outputs": [], 214 | "source": [ 215 | "result_objects = [objective.remote(choice) for choice in search_space]\n", 216 | "results = ray.get(result_objects)\n", 217 | "\n", 218 | "results.sort(key=lambda x: x[0])\n", 219 | "print(results[-1])" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "id": "eb562ac8", 226 | "metadata": { 227 | "lines_to_next_cell": 2 228 | }, 229 | "outputs": [], 230 | "source": [ 231 | "from ray import tune\n", 232 | "\n", 233 | "\n", 234 | "search_space = {\n", 235 | " \"weight\": tune.uniform(0, 1),\n", 236 | " \"discount_factor\": tune.uniform(0, 1),\n", 237 | "}" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "id": "4c076bb6", 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [ 247 | "def tune_objective(config):\n", 248 | " environment = Environment()\n", 249 | " policy = train_policy(\n", 250 | " environment,\n", 251 | " weight=config[\"weight\"],\n", 252 | " discount_factor=config[\"discount_factor\"]\n", 253 | " )\n", 254 | " score = evaluate_policy(environment, policy)\n", 255 | "\n", 256 | " return {\"score\": score}" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": null, 262 | "id": "cc2b8963", 263 | "metadata": { 264 | "lines_to_next_cell": 2 265 | }, 266 | "outputs": [], 267 | "source": [ 268 | "analysis = tune.run(tune_objective, config=search_space)\n", 269 | "print(analysis.get_best_config(metric=\"score\", mode=\"min\"))" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": null, 275 | "id": "dd7a09df", 276 | "metadata": { 277 | "lines_to_next_cell": 2 278 | }, 279 | "outputs": [], 280 | "source": [ 281 | "from ray.tune.suggest.bayesopt import BayesOptSearch\n", 282 | "\n", 283 | "\n", 284 | "algo = BayesOptSearch(random_search_steps=4)\n", 285 | "\n", 286 | "tune.run(\n", 287 | " tune_objective,\n", 288 | " config=search_space,\n", 289 | " metric=\"score\",\n", 290 | " mode=\"min\",\n", 291 | " search_alg=algo,\n", 292 | " stop={\"training_iteration\": 10},\n", 293 | ")" 294 | ] 295 | 
}, 296 | { 297 | "cell_type": "code", 298 | "execution_count": null, 299 | "id": "588fccc6", 300 | "metadata": { 301 | "lines_to_next_cell": 2 302 | }, 303 | "outputs": [], 304 | "source": [ 305 | "def objective(config):\n", 306 | " for step in range(30):\n", 307 | " score = config[\"weight\"] * (step ** 0.5) + config[\"bias\"]\n", 308 | " tune.report(score=score)\n", 309 | "\n", 310 | "\n", 311 | "search_space = {\"weight\": tune.uniform(0, 1), \"bias\": tune.uniform(0, 1)}" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": null, 317 | "id": "5de86454", 318 | "metadata": { 319 | "lines_to_next_cell": 2 320 | }, 321 | "outputs": [], 322 | "source": [ 323 | "from ray.tune.schedulers import HyperBandScheduler\n", 324 | "\n", 325 | "\n", 326 | "scheduler = HyperBandScheduler(metric=\"score\", mode=\"min\")\n", 327 | "\n", 328 | "\n", 329 | "analysis = tune.run(\n", 330 | " objective,\n", 331 | " config=search_space,\n", 332 | " scheduler=scheduler,\n", 333 | " num_samples=10,\n", 334 | ")\n", 335 | "\n", 336 | "print(analysis.get_best_config(metric=\"score\", mode=\"min\"))" 337 | ] 338 | }, 339 | { 340 | "cell_type": "code", 341 | "execution_count": null, 342 | "id": "61ddb245", 343 | "metadata": { 344 | "lines_to_next_cell": 2 345 | }, 346 | "outputs": [], 347 | "source": [ 348 | "# NOTE: in the book we have 0.5 GPUs, but set this to 0 here so that it runs on Colab.\n", 349 | "from ray import tune\n", 350 | "\n", 351 | "tune.run(\n", 352 | " objective,\n", 353 | " config=search_space,\n", 354 | " num_samples=10,\n", 355 | " resources_per_trial={\"cpu\": 2, \"gpu\": 0}\n", 356 | ")" 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": null, 362 | "id": "453bd292", 363 | "metadata": {}, 364 | "outputs": [], 365 | "source": [ 366 | "from ray import tune\n", 367 | "from ray.tune import Callback\n", 368 | "from ray.tune.logger import pretty_print\n", 369 | "\n", 370 | "\n", 371 | "class PrintResultCallback(Callback):\n", 372 | " def on_trial_result(self, iteration, trials, trial, result, **info):\n", 373 | " print(f\"Trial {trial} in iteration {iteration}, \"\n", 374 | " f\"got result: {result['score']}\")\n", 375 | "\n", 376 | "\n", 377 | "def objective(config):\n", 378 | " for step in range(30):\n", 379 | " score = config[\"weight\"] * (step ** 0.5) + config[\"bias\"]\n", 380 | " tune.report(score=score, step=step, more_metrics={})" 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": null, 386 | "id": "4fccbcf4", 387 | "metadata": { 388 | "lines_to_next_cell": 2 389 | }, 390 | "outputs": [], 391 | "source": [ 392 | "search_space = {\"weight\": tune.uniform(0, 1), \"bias\": tune.uniform(0, 1)}\n", 393 | "\n", 394 | "analysis = tune.run(\n", 395 | " objective,\n", 396 | " config=search_space,\n", 397 | " mode=\"min\",\n", 398 | " metric=\"score\",\n", 399 | " callbacks=[PrintResultCallback()])\n", 400 | "\n", 401 | "best = analysis.best_trial\n", 402 | "print(pretty_print(best.last_result))" 403 | ] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "execution_count": null, 408 | "id": "aefdaab6", 409 | "metadata": { 410 | "lines_to_next_cell": 2 411 | }, 412 | "outputs": [], 413 | "source": [ 414 | "# NOTE: this will only run if you insert a correct logdir.\n", 415 | "analysis = tune.run(\n", 416 | " objective,\n", 417 | " name=\"\",\n", 418 | " resume=True,\n", 419 | " config=search_space)" 420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "execution_count": null, 425 | "id": "9679008e", 426 | "metadata": { 427 
| "lines_to_next_cell": 2 428 | }, 429 | "outputs": [], 430 | "source": [ 431 | "tune.run(\n", 432 | " objective,\n", 433 | " config=search_space,\n", 434 | " stop={\"training_iteration\": 10})" 435 | ] 436 | }, 437 | { 438 | "cell_type": "code", 439 | "execution_count": null, 440 | "id": "96462e39", 441 | "metadata": { 442 | "lines_to_next_cell": 2 443 | }, 444 | "outputs": [], 445 | "source": [ 446 | "def stopper(trial_id, result):\n", 447 | " return result[\"score\"] < 2\n", 448 | "\n", 449 | "\n", 450 | "tune.run(\n", 451 | " objective,\n", 452 | " config=search_space,\n", 453 | " stop=stopper)" 454 | ] 455 | }, 456 | { 457 | "cell_type": "code", 458 | "execution_count": null, 459 | "id": "5d91fd6e", 460 | "metadata": { 461 | "lines_to_next_cell": 2 462 | }, 463 | "outputs": [], 464 | "source": [ 465 | "from ray import tune\n", 466 | "import numpy as np\n", 467 | "\n", 468 | "search_space = {\n", 469 | " \"weight\": tune.sample_from(\n", 470 | " lambda context: np.random.uniform(low=0.0, high=1.0)\n", 471 | " ),\n", 472 | " \"bias\": tune.sample_from(\n", 473 | " lambda context: context.config.weight * np.random.normal()\n", 474 | " )}\n", 475 | "\n", 476 | "tune.run(objective, config=search_space)" 477 | ] 478 | }, 479 | { 480 | "cell_type": "code", 481 | "execution_count": null, 482 | "id": "fe2dd055", 483 | "metadata": { 484 | "lines_to_next_cell": 2 485 | }, 486 | "outputs": [], 487 | "source": [ 488 | "# NOTE: this run will take incredibly long on Colab, be warned!\n", 489 | "from ray import tune\n", 490 | "\n", 491 | "analysis = tune.run(\n", 492 | " \"DQN\",\n", 493 | " metric=\"episode_reward_mean\",\n", 494 | " mode=\"max\",\n", 495 | " config={\n", 496 | " \"env\": \"CartPole-v1\",\n", 497 | " \"lr\": tune.uniform(1e-5, 1e-4),\n", 498 | " \"train_batch_size\": tune.choice([10000, 20000, 40000]),\n", 499 | " },\n", 500 | ")" 501 | ] 502 | }, 503 | { 504 | "cell_type": "markdown", 505 | "source": [ 506 | "![Tune Model Training](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_05/Tune_model_training.png)\n" 507 | ], 508 | "metadata": { 509 | "collapsed": false 510 | } 511 | }, 512 | { 513 | "cell_type": "code", 514 | "execution_count": null, 515 | "id": "79160b9f", 516 | "metadata": { 517 | "lines_to_next_cell": 2 518 | }, 519 | "outputs": [], 520 | "source": [ 521 | "from tensorflow.keras.datasets import mnist\n", 522 | "from tensorflow.keras.utils import to_categorical\n", 523 | "\n", 524 | "\n", 525 | "def load_data():\n", 526 | " (x_train, y_train), (x_test, y_test) = mnist.load_data()\n", 527 | " num_classes = 10\n", 528 | " x_train, x_test = x_train / 255.0, x_test / 255.0\n", 529 | " y_train = to_categorical(y_train, num_classes)\n", 530 | " y_test = to_categorical(y_test, num_classes)\n", 531 | " return (x_train, y_train), (x_test, y_test)\n", 532 | "\n", 533 | "\n", 534 | "load_data()" 535 | ] 536 | }, 537 | { 538 | "cell_type": "code", 539 | "execution_count": null, 540 | "id": "d2e71386", 541 | "metadata": {}, 542 | "outputs": [], 543 | "source": [ 544 | "from tensorflow.keras.models import Sequential\n", 545 | "from tensorflow.keras.layers import Flatten, Dense, Dropout\n", 546 | "from ray.tune.integration.keras import TuneReportCallback\n", 547 | "\n", 548 | "\n", 549 | "def objective(config):\n", 550 | " (x_train, y_train), (x_test, y_test) = load_data()\n", 551 | " model = Sequential()\n", 552 | " model.add(Flatten(input_shape=(28, 28)))\n", 553 | " model.add(Dense(config[\"hidden\"], activation=config[\"activation\"]))\n", 
554 | " model.add(Dropout(config[\"rate\"]))\n", 555 | " model.add(Dense(10, activation=\"softmax\"))\n", 556 | "\n", 557 | " model.compile(loss=\"categorical_crossentropy\", metrics=[\"accuracy\"])\n", 558 | " model.fit(x_train, y_train, batch_size=128, epochs=10,\n", 559 | " validation_data=(x_test, y_test),\n", 560 | " callbacks=[TuneReportCallback({\"mean_accuracy\": \"accuracy\"})])" 561 | ] 562 | }, 563 | { 564 | "cell_type": "code", 565 | "execution_count": null, 566 | "id": "1826b95f", 567 | "metadata": {}, 568 | "outputs": [], 569 | "source": [ 570 | "from ray import tune\n", 571 | "from ray.tune.suggest.hyperopt import HyperOptSearch\n", 572 | "\n", 573 | "initial_params = [{\"rate\": 0.2, \"hidden\": 128, \"activation\": \"relu\"}]\n", 574 | "algo = HyperOptSearch(points_to_evaluate=initial_params)\n", 575 | "\n", 576 | "search_space = {\n", 577 | " \"rate\": tune.uniform(0.1, 0.5),\n", 578 | " \"hidden\": tune.randint(32, 512),\n", 579 | " \"activation\": tune.choice([\"relu\", \"tanh\"])\n", 580 | "}\n", 581 | "\n", 582 | "\n", 583 | "analysis = tune.run(\n", 584 | " objective,\n", 585 | " name=\"keras_hyperopt_exp\",\n", 586 | " search_alg=algo,\n", 587 | " metric=\"mean_accuracy\",\n", 588 | " mode=\"max\",\n", 589 | " stop={\"mean_accuracy\": 0.99},\n", 590 | " num_samples=10,\n", 591 | " config=search_space,\n", 592 | ")\n", 593 | "print(\"Best hyperparameters found were: \", analysis.best_config)" 594 | ] 595 | } 596 | ], 597 | "metadata": { 598 | "jupytext": { 599 | "cell_metadata_filter": "-all", 600 | "main_language": "python", 601 | "notebook_metadata_filter": "-all" 602 | } 603 | }, 604 | "nbformat": 4, 605 | "nbformat_minor": 5 606 | } 607 | -------------------------------------------------------------------------------- /notebooks/ch_06_data_processing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "372c8436", 6 | "metadata": {}, 7 | "source": [ 8 | "# Data Processing with Ray" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "f7365d60", 14 | "metadata": {}, 15 | "source": [ 16 | "\n", 17 | "You can run this notebook directly in\n", 18 | "[Colab](https://colab.research.google.com/github/maxpumperla/learning_ray/blob/main/notebooks/ch_06_data_processing.ipynb).\n", 19 | "\n", 20 | "\"Open" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "id": "650c6212", 26 | "metadata": {}, 27 | "source": [ 28 | "For this chapter you need to install the following dependencies:" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "60a6140b", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "! pip install \"ray[data]==2.2.0\"\n", 39 | "! pip install \"scikit-learn==1.0.2\"\n", 40 | "! pip install \"dask==2022.2.0\"" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "id": "40becbf9", 46 | "metadata": {}, 47 | "source": [ 48 | "\n", 49 | "To import utility files for this chapter, on Colab you will also have to clone\n", 50 | "the repo and copy the code files to the base path of the runtime:" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "id": "17aca7a4", 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "!git clone https://github.com/maxpumperla/learning_ray\n", 61 | "%cp -r learning_ray/notebooks/* ." 
62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "source": [ 67 | "![Simple Ray Data](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_06/AIR_data.png)\n" 68 | ], 69 | "metadata": { 70 | "collapsed": false 71 | } 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "source": [ 76 | "![Data Pipeline 1](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_06/data_pipeline_1.png)" 77 | ], 78 | "metadata": { 79 | "collapsed": false 80 | } 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "source": [ 85 | "![Data Pipeline 2](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_06/data_pipeline_2.png)" 86 | ], 87 | "metadata": { 88 | "collapsed": false 89 | } 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "source": [ 94 | "![Data Positioning 1](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_06/data_positioning_1.png)" 95 | ], 96 | "metadata": { 97 | "collapsed": false 98 | } 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "source": [ 103 | "![Data Positioning 2](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_06/data_positioning_2.png)" 104 | ], 105 | "metadata": { 106 | "collapsed": false 107 | } 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "source": [ 112 | "![Data Architecture](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_06/datasets_arch.png)" 113 | ], 114 | "metadata": { 115 | "collapsed": false 116 | } 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "source": [ 121 | "![Data ML Workflow](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_06/ml_workflow.png)" 122 | ], 123 | "metadata": { 124 | "collapsed": false 125 | } 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "id": "20447273", 131 | "metadata": { 132 | "lines_to_next_cell": 2 133 | }, 134 | "outputs": [], 135 | "source": [ 136 | "import ray\n", 137 | "\n", 138 | "# Create a dataset containing integers in the range [0, 10000).\n", 139 | "ds = ray.data.range(10000)\n", 140 | "\n", 141 | "# Basic operations: show the size of the dataset, get a few samples, print the schema.\n", 142 | "print(ds.count()) # -> 10000\n", 143 | "print(ds.take(5)) # -> [0, 1, 2, 3, 4]\n", 144 | "print(ds.schema()) # -> " 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "id": "665a2ff7", 151 | "metadata": { 152 | "lines_to_next_cell": 2 153 | }, 154 | "outputs": [], 155 | "source": [ 156 | "# Save the dataset to a local file and load it back.\n", 157 | "ray.data.range(10000).write_csv(\"local_dir\")\n", 158 | "ds = ray.data.read_csv(\"local_dir\")\n", 159 | "print(ds.count())" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "id": "5847b020", 166 | "metadata": { 167 | "lines_to_next_cell": 2 168 | }, 169 | "outputs": [], 170 | "source": [ 171 | "ds1 = ray.data.range(10000)\n", 172 | "ds2 = ray.data.range(10000)\n", 173 | "ds3 = ds1.union(ds2)\n", 174 | "print(ds3.count()) # -> 20000\n", 175 | "\n", 176 | "# Filter the combined dataset to only the even elements.\n", 177 | "ds3 = ds3.filter(lambda x: x % 2 == 0)\n", 178 | "print(ds3.count()) # -> 10000\n", 179 | "print(ds3.take(5)) # -> [0, 2, 4, 6, 8]\n", 180 | "\n", 181 | "# Sort the filtered dataset.\n", 182 | "ds3 = ds3.sort() # <3>\n", 183 | "print(ds3.take(5)) # -> [0, 0, 2, 2, 4]" 184 | ] 185 | 
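The cells above create, combine, filter, and sort Datasets. As a small complementary sketch (not a cell from the notebook), the same API also provides distributed aggregations over a Dataset; the snippet below assumes the simple integer Dataset produced by `ray.data.range`, as used above:

```python
import ray

ds = ray.data.range(10000)

# Built-in aggregations execute as distributed tasks across the dataset's blocks.
print(ds.sum())            # -> 49995000
print(ds.mean())           # -> 4999.5
print(ds.std())            # -> ~2886.9, the same value the Dask interop example prints later
print(ds.min(), ds.max())  # -> 0 9999
```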
}, 186 | { 187 | "cell_type": "code", 188 | "execution_count": null, 189 | "id": "09c5ecc7", 190 | "metadata": { 191 | "lines_to_next_cell": 2 192 | }, 193 | "outputs": [], 194 | "source": [ 195 | "ds1 = ray.data.range(10000)\n", 196 | "print(ds1.num_blocks()) # -> 200\n", 197 | "ds2 = ray.data.range(10000)\n", 198 | "print(ds2.num_blocks()) # -> 200\n", 199 | "ds3 = ds1.union(ds2)\n", 200 | "print(ds3.num_blocks()) # -> 400\n", 201 | "\n", 202 | "print(ds3.repartition(200).num_blocks()) # -> 200" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": null, 208 | "id": "08f6cbde", 209 | "metadata": { 210 | "lines_to_next_cell": 2 211 | }, 212 | "outputs": [], 213 | "source": [ 214 | "ds = ray.data.from_items([{\"id\": \"abc\", \"value\": 1}, {\"id\": \"def\", \"value\": 2}])\n", 215 | "print(ds.schema()) # -> id: string, value: int64" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": null, 221 | "id": "8c4dcf23", 222 | "metadata": { 223 | "lines_to_next_cell": 2 224 | }, 225 | "outputs": [], 226 | "source": [ 227 | "pandas_df = ds.to_pandas() # pandas_df will inherit the schema from our Dataset." 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": null, 233 | "id": "bdb1d60e", 234 | "metadata": { 235 | "lines_to_next_cell": 2 236 | }, 237 | "outputs": [], 238 | "source": [ 239 | "ds = ray.data.range(10000).map(lambda x: x ** 2)\n", 240 | "ds.take(5) # -> [0, 1, 4, 9, 16]" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": null, 246 | "id": "51bd5b81", 247 | "metadata": { 248 | "lines_to_next_cell": 2 249 | }, 250 | "outputs": [], 251 | "source": [ 252 | "import numpy as np\n", 253 | "\n", 254 | "\n", 255 | "ds = ray.data.range(10000).map_batches(lambda batch: np.square(batch).tolist())\n", 256 | "ds.take(5) # -> [0, 1, 4, 9, 16]" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": null, 262 | "id": "d5553d8b", 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "def load_model():\n", 267 | " # Returns a dummy model for this example.\n", 268 | " # In reality, this would likely load some model weights onto a GPU.\n", 269 | " class DummyModel:\n", 270 | " def __call__(self, batch):\n", 271 | " return batch\n", 272 | "\n", 273 | " return DummyModel()\n", 274 | "\n", 275 | "\n", 276 | "class MLModel:\n", 277 | " def __init__(self):\n", 278 | " # load_model() will only run once per actor that's started.\n", 279 | " self._model = load_model()\n", 280 | "\n", 281 | " def __call__(self, batch):\n", 282 | " return self._model(batch)\n", 283 | "\n", 284 | "\n", 285 | "ds.map_batches(MLModel, compute=\"actors\")\n", 286 | "\n", 287 | "\n", 288 | "cpu_intensive_preprocessing = lambda batch: batch\n", 289 | "gpu_intensive_inference = lambda batch: batch" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": null, 295 | "id": "1e2350d3", 296 | "metadata": { 297 | "lines_to_next_cell": 2 298 | }, 299 | "outputs": [], 300 | "source": [ 301 | "# NOTE: this only works if you create an S3 bucket and upload the data there.\n", 302 | "ds = (ray.data.read_parquet(\"s3://my_bucket/input_data\")\n", 303 | " .map(cpu_intensive_preprocessing)\n", 304 | " .map_batches(gpu_intensive_inference, compute=\"actors\", num_gpus=1)\n", 305 | " .repartition(10))\n", 306 | "\n", 307 | "ds.write_parquet(\"s3://my_bucket/output_predictions\")" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": null, 313 | "id": "9f3c1e05", 314 | "metadata": { 
315 | "lines_to_next_cell": 2 316 | }, 317 | "outputs": [], 318 | "source": [ 319 | "# NOTE: this only works if you create an S3 bucket and upload the data there.\n", 320 | "ds = (ray.data.read_parquet(\"s3://my_bucket/input_data\")\n", 321 | " .window(blocks_per_window=5)\n", 322 | " .map(cpu_intensive_preprocessing)\n", 323 | " .map_batches(gpu_intensive_inference, compute=\"actors\", num_gpus=1)\n", 324 | " .repartition(10))\n", 325 | "ds.write_parquet(\"s3://my_bucket/output_predictions\")" 326 | ] 327 | }, 328 | { 329 | "cell_type": "code", 330 | "execution_count": null, 331 | "id": "6cc33042", 332 | "metadata": {}, 333 | "outputs": [], 334 | "source": [ 335 | "from sklearn import datasets\n", 336 | "from sklearn.linear_model import SGDClassifier\n", 337 | "from sklearn.model_selection import train_test_split\n", 338 | "\n", 339 | "\n", 340 | "@ray.remote\n", 341 | "class TrainingWorker:\n", 342 | " def __init__(self, alpha: float):\n", 343 | " self._model = SGDClassifier(alpha=alpha)\n", 344 | "\n", 345 | " def train(self, train_shard: ray.data.Dataset):\n", 346 | " for i, epoch in enumerate(train_shard.iter_epochs()):\n", 347 | " X, Y = zip(*list(epoch.iter_rows()))\n", 348 | " self._model.partial_fit(X, Y, classes=[0, 1])\n", 349 | "\n", 350 | " return self._model\n", 351 | "\n", 352 | " def test(self, X_test: np.ndarray, Y_test: np.ndarray):\n", 353 | " return self._model.score(X_test, Y_test)" 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": null, 359 | "id": "7437935a", 360 | "metadata": { 361 | "lines_to_next_cell": 2 362 | }, 363 | "outputs": [], 364 | "source": [ 365 | "ALPHA_VALS = [0.00008, 0.00009, 0.0001, 0.00011, 0.00012]\n", 366 | "\n", 367 | "print(f\"Starting {len(ALPHA_VALS)} training workers.\")\n", 368 | "workers = [TrainingWorker.remote(alpha) for alpha in ALPHA_VALS]" 369 | ] 370 | }, 371 | { 372 | "cell_type": "code", 373 | "execution_count": null, 374 | "id": "46198d41", 375 | "metadata": { 376 | "lines_to_next_cell": 2 377 | }, 378 | "outputs": [], 379 | "source": [ 380 | "X_train, X_test, Y_train, Y_test = train_test_split(\n", 381 | " *datasets.make_classification()\n", 382 | ")\n", 383 | "\n", 384 | "train_ds = ray.data.from_items(list(zip(X_train, Y_train)))\n", 385 | "shards = (train_ds.repeat(10)\n", 386 | " .random_shuffle_each_window()\n", 387 | " .split(len(workers), locality_hints=workers))\n", 388 | "\n", 389 | "ray.get([worker.train.remote(shard) for worker, shard in zip(workers, shards)])" 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": null, 395 | "id": "9969a4ef", 396 | "metadata": {}, 397 | "outputs": [], 398 | "source": [ 399 | "# Get validation results from each worker.\n", 400 | "print(ray.get([worker.test.remote(X_test, Y_test) for worker in workers]))\n", 401 | "\n", 402 | "ray.shutdown()" 403 | ] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "execution_count": null, 408 | "id": "adf0b4b7", 409 | "metadata": { 410 | "lines_to_next_cell": 2 411 | }, 412 | "outputs": [], 413 | "source": [ 414 | "import ray\n", 415 | "from ray.util.dask import enable_dask_on_ray\n", 416 | "\n", 417 | "ray.init() # Start or connect to Ray.\n", 418 | "enable_dask_on_ray() # Enable the Ray scheduler backend for Dask." 
419 | ] 420 | }, 421 | { 422 | "cell_type": "code", 423 | "execution_count": null, 424 | "id": "ef41df5b", 425 | "metadata": { 426 | "lines_to_next_cell": 2 427 | }, 428 | "outputs": [], 429 | "source": [ 430 | "import dask\n", 431 | "\n", 432 | "df = dask.datasets.timeseries()\n", 433 | "df = df[df.y > 0].groupby(\"name\").x.std()\n", 434 | "df.compute() # Trigger the task graph to be evaluated." 435 | ] 436 | }, 437 | { 438 | "cell_type": "code", 439 | "execution_count": null, 440 | "id": "17bb6db1", 441 | "metadata": {}, 442 | "outputs": [], 443 | "source": [ 444 | "import ray\n", 445 | "ds = ray.data.range(10000)\n", 446 | "\n", 447 | "# Convert the Dataset to a Dask DataFrame.\n", 448 | "df = ds.to_dask()\n", 449 | "print(df.std().compute()) # -> 2886.89568\n", 450 | "\n", 451 | "# Convert the Dask DataFrame back to a Dataset.\n", 452 | "ds = ray.data.from_dask(df)\n", 453 | "print(ds.std()) # -> 2886.89568" 454 | ] 455 | } 456 | ], 457 | "metadata": { 458 | "jupytext": { 459 | "cell_metadata_filter": "-all", 460 | "main_language": "python", 461 | "notebook_metadata_filter": "-all" 462 | } 463 | }, 464 | "nbformat": 4, 465 | "nbformat_minor": 5 466 | } 467 | -------------------------------------------------------------------------------- /notebooks/ch_07_train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "89466d94", 6 | "metadata": {}, 7 | "source": [ 8 | "# Distributed Training with Ray Train" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "2c23b3af", 14 | "metadata": {}, 15 | "source": [ 16 | "\n", 17 | "You can run this notebook directly in\n", 18 | "[Colab](https://colab.research.google.com/github/maxpumperla/learning_ray/blob/main/notebooks/ch_07_train.ipynb).\n", 19 | "\n", 20 | "\"Open" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "id": "94e05d4f", 26 | "metadata": {}, 27 | "source": [ 28 | "For this chapter you will need to install the following dependencies:" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "1765ed6f", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "! pip install \"ray[data,train]==2.2.0\" \"dask==2022.2.0\" \"torch==1.12.1\"\n", 39 | "! pip install \"xgboost==1.6.2\" \"xgboost-ray>=0.1.10\"" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "id": "28ef63a5", 45 | "metadata": {}, 46 | "source": [ 47 | "\n", 48 | "To import utility files for this chapter, on Colab you will also have to clone\n", 49 | "the repo and copy the code files to the base path of the runtime:" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "id": "4a179146", 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "!git clone https://github.com/maxpumperla/learning_ray\n", 60 | "%cp -r learning_ray/notebooks/* ." 
61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "source": [ 66 | "![Data Model Parallel](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_07/data_model_parallel.png)" 67 | ], 68 | "metadata": { 69 | "collapsed": false 70 | } 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "source": [ 75 | "![Torch Trainer](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_07/torch_trainer.png)" 76 | ], 77 | "metadata": { 78 | "collapsed": false 79 | } 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "source": [ 84 | "![Train Architecture](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_07/train_architecture.png)" 85 | ], 86 | "metadata": { 87 | "collapsed": false 88 | } 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "source": [ 93 | "![Train Overview](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_07/train_overview.png)" 94 | ], 95 | "metadata": { 96 | "collapsed": false 97 | } 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "source": [ 102 | "![Train Tune Execution](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_07/train_tune_execution.png)" 103 | ], 104 | "metadata": { 105 | "collapsed": false 106 | } 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "id": "409cb5d3", 112 | "metadata": {}, 113 | "outputs": [], 114 | "source": [ 115 | "import ray\n", 116 | "from ray.util.dask import enable_dask_on_ray\n", 117 | "\n", 118 | "import dask.dataframe as dd\n", 119 | "\n", 120 | "LABEL_COLUMN = \"is_big_tip\"\n", 121 | "FEATURE_COLUMNS = [\"passenger_count\", \"trip_distance\", \"fare_amount\",\n", 122 | " \"trip_duration\", \"hour\", \"day_of_week\"]\n", 123 | "\n", 124 | "enable_dask_on_ray()\n", 125 | "\n", 126 | "\n", 127 | "def load_dataset(path: str, *, include_label=True):\n", 128 | " columns = [\"tpep_pickup_datetime\", \"tpep_dropoff_datetime\", \"tip_amount\",\n", 129 | " \"passenger_count\", \"trip_distance\", \"fare_amount\"]\n", 130 | " df = dd.read_parquet(path, columns=columns)\n", 131 | "\n", 132 | " df = df.dropna()\n", 133 | " df = df[(df[\"passenger_count\"] <= 4) &\n", 134 | " (df[\"trip_distance\"] < 100) &\n", 135 | " (df[\"fare_amount\"] < 1000)]\n", 136 | "\n", 137 | " df[\"tpep_pickup_datetime\"] = dd.to_datetime(df[\"tpep_pickup_datetime\"])\n", 138 | " df[\"tpep_dropoff_datetime\"] = dd.to_datetime(df[\"tpep_dropoff_datetime\"])\n", 139 | "\n", 140 | " df[\"trip_duration\"] = (df[\"tpep_dropoff_datetime\"] -\n", 141 | " df[\"tpep_pickup_datetime\"]).dt.seconds\n", 142 | " df = df[df[\"trip_duration\"] < 4 * 60 * 60] # 4 hours.\n", 143 | " df[\"hour\"] = df[\"tpep_pickup_datetime\"].dt.hour\n", 144 | " df[\"day_of_week\"] = df[\"tpep_pickup_datetime\"].dt.weekday\n", 145 | "\n", 146 | " if include_label:\n", 147 | " df[LABEL_COLUMN] = df[\"tip_amount\"] > 0.2 * df[\"fare_amount\"]\n", 148 | "\n", 149 | " df = df.drop(\n", 150 | " columns=[\"tpep_pickup_datetime\", \"tpep_dropoff_datetime\", \"tip_amount\"]\n", 151 | " )\n", 152 | "\n", 153 | " return ray.data.from_dask(df).repartition(100)" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "id": "f91e9d1b", 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "import torch\n", 164 | "import torch.nn as nn\n", 165 | "import torch.nn.functional as F\n", 166 | "\n", 167 | "\n", 168 | "class FarePredictor(nn.Module):\n", 169 | " def 
__init__(self):\n", 170 | " super().__init__()\n", 171 | "\n", 172 | " self.fc1 = nn.Linear(6, 256)\n", 173 | " self.fc2 = nn.Linear(256, 16)\n", 174 | " self.fc3 = nn.Linear(16, 1)\n", 175 | "\n", 176 | " self.bn1 = nn.BatchNorm1d(256)\n", 177 | " self.bn2 = nn.BatchNorm1d(16)\n", 178 | "\n", 179 | " def forward(self, x):\n", 180 | " x = F.relu(self.fc1(x))\n", 181 | " x = self.bn1(x)\n", 182 | " x = F.relu(self.fc2(x))\n", 183 | " x = self.bn2(x)\n", 184 | " x = torch.sigmoid(self.fc3(x))\n", 185 | "\n", 186 | " return x" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": null, 192 | "id": "7cd15858", 193 | "metadata": {}, 194 | "outputs": [], 195 | "source": [ 196 | "from ray.air import session\n", 197 | "from ray.air.config import ScalingConfig\n", 198 | "import ray.train as train\n", 199 | "from ray.train.torch import TorchCheckpoint, TorchTrainer\n", 200 | "\n", 201 | "\n", 202 | "def train_loop_per_worker(config: dict):\n", 203 | " batch_size = config.get(\"batch_size\", 32)\n", 204 | " lr = config.get(\"lr\", 1e-2)\n", 205 | " num_epochs = config.get(\"num_epochs\", 3)\n", 206 | "\n", 207 | " dataset_shard = session.get_dataset_shard(\"train\")\n", 208 | "\n", 209 | " model = FarePredictor()\n", 210 | " dist_model = train.torch.prepare_model(model)\n", 211 | "\n", 212 | " loss_function = nn.SmoothL1Loss()\n", 213 | " optimizer = torch.optim.Adam(dist_model.parameters(), lr=lr)\n", 214 | "\n", 215 | " for epoch in range(num_epochs):\n", 216 | " loss = 0\n", 217 | " num_batches = 0\n", 218 | " for batch in dataset_shard.iter_torch_batches(\n", 219 | " batch_size=batch_size, dtypes=torch.float\n", 220 | " ):\n", 221 | " labels = torch.unsqueeze(batch[LABEL_COLUMN], dim=1)\n", 222 | " inputs = torch.cat(\n", 223 | " [torch.unsqueeze(batch[f], dim=1) for f in FEATURE_COLUMNS], dim=1\n", 224 | " )\n", 225 | " output = dist_model(inputs)\n", 226 | " batch_loss = loss_function(output, labels)\n", 227 | " optimizer.zero_grad()\n", 228 | " batch_loss.backward()\n", 229 | " optimizer.step()\n", 230 | "\n", 231 | " num_batches += 1\n", 232 | " loss += batch_loss.item()\n", 233 | "\n", 234 | " session.report(\n", 235 | " {\"epoch\": epoch, \"loss\": loss},\n", 236 | " checkpoint=TorchCheckpoint.from_model(dist_model)\n", 237 | " )" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "id": "5209607a", 244 | "metadata": { 245 | "lines_to_next_cell": 2 246 | }, 247 | "outputs": [], 248 | "source": [ 249 | "# NOTE: In the book we use num_workers=2, but reduce this here, so that it runs on Colab.\n", 250 | "# In any case, this training loop will take considerable time to run.\n", 251 | "trainer = TorchTrainer(\n", 252 | " train_loop_per_worker=train_loop_per_worker,\n", 253 | " train_loop_config={\n", 254 | " \"lr\": 1e-2, \"num_epochs\": 3, \"batch_size\": 64\n", 255 | " },\n", 256 | " scaling_config=ScalingConfig(num_workers=1, resources_per_worker={\"CPU\": 1, \"GPU\": 0}),\n", 257 | " datasets={\n", 258 | " \"train\": load_dataset(\"nyc_tlc_data/yellow_tripdata_2020-01.parquet\")\n", 259 | " },\n", 260 | ")\n", 261 | "\n", 262 | "result = trainer.fit()\n", 263 | "trained_model = result.checkpoint" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": null, 269 | "id": "88deaa0b", 270 | "metadata": { 271 | "lines_to_next_cell": 2 272 | }, 273 | "outputs": [], 274 | "source": [ 275 | "from ray.train.torch import TorchPredictor\n", 276 | "from ray.train.batch_predictor import BatchPredictor\n", 277 | "\n", 278 | 
"batch_predictor = BatchPredictor(trained_model, TorchPredictor)\n", 279 | "ds = load_dataset(\n", 280 | " \"nyc_tlc_data/yellow_tripdata_2021-01.parquet\", include_label=False)\n", 281 | "\n", 282 | "batch_predictor.predict_pipelined(ds, blocks_per_window=10)" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": null, 288 | "id": "4087d98b", 289 | "metadata": {}, 290 | "outputs": [], 291 | "source": [ 292 | "import torch\n", 293 | "import torch.nn as nn\n", 294 | "import torch.nn.functional as F\n", 295 | "from ray.data import from_torch\n", 296 | "\n", 297 | "num_samples = 20\n", 298 | "input_size = 10\n", 299 | "layer_size = 15\n", 300 | "output_size = 5\n", 301 | "num_epochs = 3\n", 302 | "\n", 303 | "\n", 304 | "class NeuralNetwork(nn.Module):\n", 305 | " def __init__(self):\n", 306 | " super().__init__()\n", 307 | " self.fc1 = nn.Linear(input_size, layer_size)\n", 308 | " self.relu = nn.ReLU()\n", 309 | " self.fc2 = nn.Linear(layer_size, output_size)\n", 310 | "\n", 311 | " def forward(self, x):\n", 312 | " x = F.relu(self.fc1(x))\n", 313 | " x = self.fc2(x)\n", 314 | " return x\n", 315 | "\n", 316 | "\n", 317 | "def train_data():\n", 318 | " return torch.randn(num_samples, input_size)\n", 319 | "\n", 320 | "\n", 321 | "input_data = train_data()\n", 322 | "label_data = torch.randn(num_samples, output_size)\n", 323 | "train_dataset = from_torch(input_data)\n", 324 | "\n", 325 | "\n", 326 | "def train_one_epoch(model, loss_fn, optimizer):\n", 327 | " output = model(input_data)\n", 328 | " loss = loss_fn(output, label_data)\n", 329 | " optimizer.zero_grad()\n", 330 | " loss.backward()\n", 331 | " optimizer.step()\n", 332 | "\n", 333 | "\n", 334 | "def training_loop():\n", 335 | " model = NeuralNetwork()\n", 336 | " loss_fn = nn.MSELoss()\n", 337 | " optimizer = torch.optim.SGD(model.parameters(), lr=0.1)\n", 338 | " for epoch in range(num_epochs):\n", 339 | " train_one_epoch(model, loss_fn, optimizer)" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": null, 345 | "id": "0cf11c52", 346 | "metadata": {}, 347 | "outputs": [], 348 | "source": [ 349 | "from ray.train.torch import prepare_model\n", 350 | "\n", 351 | "\n", 352 | "def distributed_training_loop():\n", 353 | " model = NeuralNetwork()\n", 354 | " model = prepare_model(model)\n", 355 | " loss_fn = nn.MSELoss()\n", 356 | " optimizer = torch.optim.SGD(model.parameters(), lr=0.1)\n", 357 | " for epoch in range(num_epochs):\n", 358 | " train_one_epoch(model, loss_fn, optimizer)" 359 | ] 360 | }, 361 | { 362 | "cell_type": "code", 363 | "execution_count": null, 364 | "id": "afbc5a35", 365 | "metadata": { 366 | "lines_to_next_cell": 2 367 | }, 368 | "outputs": [], 369 | "source": [ 370 | "from ray.air.config import ScalingConfig\n", 371 | "from ray.train.torch import TorchTrainer\n", 372 | "\n", 373 | "\n", 374 | "trainer = TorchTrainer(\n", 375 | " train_loop_per_worker=distributed_training_loop,\n", 376 | " scaling_config=ScalingConfig(\n", 377 | " num_workers=2,\n", 378 | " use_gpu=False\n", 379 | " ),\n", 380 | " datasets={\"train\": train_dataset}\n", 381 | ")\n", 382 | "\n", 383 | "result = trainer.fit()" 384 | ] 385 | }, 386 | { 387 | "cell_type": "code", 388 | "execution_count": null, 389 | "id": "3c049af7", 390 | "metadata": { 391 | "lines_to_next_cell": 2 392 | }, 393 | "outputs": [], 394 | "source": [ 395 | "import ray\n", 396 | "\n", 397 | "from ray.air.config import ScalingConfig\n", 398 | "from ray import tune\n", 399 | "from ray.data.preprocessors import StandardScaler, 
MinMaxScaler\n", 400 | "\n", 401 | "\n", 402 | "dataset = ray.data.from_items(\n", 403 | " [{\"X\": x, \"Y\": 1} for x in range(0, 100)] +\n", 404 | " [{\"X\": x, \"Y\": 0} for x in range(100, 200)]\n", 405 | ")\n", 406 | "prep_v1 = StandardScaler(columns=[\"X\"])\n", 407 | "prep_v2 = MinMaxScaler(columns=[\"X\"])\n", 408 | "\n", 409 | "param_space = {\n", 410 | " \"scaling_config\": ScalingConfig(\n", 411 | " num_workers=tune.grid_search([2, 4]),\n", 412 | " resources_per_worker={\n", 413 | " \"CPU\": 2,\n", 414 | " \"GPU\": 0,\n", 415 | " },\n", 416 | " ),\n", 417 | " \"preprocessor\": tune.grid_search([prep_v1, prep_v2]),\n", 418 | " \"params\": {\n", 419 | " \"objective\": \"binary:logistic\",\n", 420 | " \"tree_method\": \"hist\",\n", 421 | " \"eval_metric\": [\"logloss\", \"error\"],\n", 422 | " \"eta\": tune.loguniform(1e-4, 1e-1),\n", 423 | " \"subsample\": tune.uniform(0.5, 1.0),\n", 424 | " \"max_depth\": tune.randint(1, 9),\n", 425 | " },\n", 426 | "}" 427 | ] 428 | }, 429 | { 430 | "cell_type": "code", 431 | "execution_count": null, 432 | "id": "c9e305d2", 433 | "metadata": {}, 434 | "outputs": [], 435 | "source": [ 436 | "from ray.train.xgboost import XGBoostTrainer\n", 437 | "from ray.air.config import RunConfig\n", 438 | "from ray.tune import Tuner\n", 439 | "\n", 440 | "\n", 441 | "trainer = XGBoostTrainer(\n", 442 | " params={},\n", 443 | " run_config=RunConfig(verbose=2),\n", 444 | " preprocessor=None,\n", 445 | " scaling_config=None,\n", 446 | " label_column=\"Y\",\n", 447 | " datasets={\"train\": dataset}\n", 448 | ")\n", 449 | "\n", 450 | "tuner = Tuner(\n", 451 | " trainer,\n", 452 | " param_space=param_space,\n", 453 | ")\n", 454 | "\n", 455 | "results = tuner.fit()" 456 | ] 457 | } 458 | ], 459 | "metadata": { 460 | "jupytext": { 461 | "cell_metadata_filter": "-all", 462 | "main_language": "python", 463 | "notebook_metadata_filter": "-all" 464 | } 465 | }, 466 | "nbformat": 4, 467 | "nbformat_minor": 5 468 | } 469 | -------------------------------------------------------------------------------- /notebooks/ch_08_model_serving.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "36d84a44", 6 | "metadata": {}, 7 | "source": [ 8 | "# Online Inference with Ray Serve" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "dcc579d9", 14 | "metadata": {}, 15 | "source": [ 16 | "\n", 17 | "You can run this notebook directly in\n", 18 | "[Colab](https://colab.research.google.com/github/maxpumperla/learning_ray/blob/main/notebooks/ch_08_model_serving.ipynb).\n", 19 | "\n", 20 | "\"Open" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "id": "1545f53d", 26 | "metadata": {}, 27 | "source": [ 28 | "For this chapter you need to install the following dependencies:" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "9633bd83", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "! pip install \"ray[serve]==2.2.0\" \"transformers==4.21.2\"\n", 39 | "! 
pip install \"requests==2.28.1\" \"wikipedia==1.4.0\"" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "id": "b9d2fed7", 45 | "metadata": {}, 46 | "source": [ 47 | "\n", 48 | "To import utility files for this chapter, on Colab you will also have to clone\n", 49 | "the repo and copy the code files to the base path of the runtime:" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "id": "34fe8c35", 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "!git clone https://github.com/maxpumperla/learning_ray\n", 60 | "%cp -r learning_ray/notebooks/* ." 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "source": [ 66 | "![Serve Positioning](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_08/serve_positioning.png)" 67 | ], 68 | "metadata": { 69 | "collapsed": false 70 | } 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "source": [ 75 | "![Serve Architecture](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_08/serve_arch.png)" 76 | ], 77 | "metadata": { 78 | "collapsed": false 79 | } 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "source": [ 84 | "![NLP API Architecture](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_08/nlp_api_arch.png)" 85 | ], 86 | "metadata": { 87 | "collapsed": false 88 | } 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "id": "c18b1864", 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "from ray import serve\n", 98 | "\n", 99 | "from transformers import pipeline\n", 100 | "\n", 101 | "\n", 102 | "@serve.deployment\n", 103 | "class SentimentAnalysis:\n", 104 | " def __init__(self):\n", 105 | " self._classifier = pipeline(\"sentiment-analysis\")\n", 106 | "\n", 107 | " def __call__(self, request) -> str:\n", 108 | " input_text = request.query_params[\"input_text\"]\n", 109 | " return self._classifier(input_text)[0][\"label\"]" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "id": "60908300", 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "basic_deployment = SentimentAnalysis.bind()" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "id": "b7659c67", 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "# Run this in a separate process to avoid any blocking:\n", 130 | "! 
serve run --non-blocking app:basic_deployment" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "id": "042a810c", 137 | "metadata": { 138 | "lines_to_next_cell": 2 139 | }, 140 | "outputs": [], 141 | "source": [ 142 | "import requests\n", 143 | "\n", 144 | "print(requests.get(\n", 145 | " \"http://localhost:8000/\", params={\"input_text\": \"Hello friend!\"}\n", 146 | ").json())" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "id": "8452a0b4", 153 | "metadata": { 154 | "lines_to_next_cell": 2 155 | }, 156 | "outputs": [], 157 | "source": [ 158 | "from fastapi import FastAPI\n", 159 | "\n", 160 | "app = FastAPI()\n", 161 | "\n", 162 | "\n", 163 | "@serve.deployment\n", 164 | "@serve.ingress(app)\n", 165 | "class SentimentAnalysis:\n", 166 | " def __init__(self):\n", 167 | " self._classifier = pipeline(\"sentiment-analysis\")\n", 168 | "\n", 169 | " @app.get(\"/\")\n", 170 | " def classify(self, input_text: str) -> str:\n", 171 | " return self._classifier(input_text)[0][\"label\"]\n", 172 | "\n", 173 | "\n", 174 | "fastapi_deployment = SentimentAnalysis.bind()" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": null, 180 | "id": "987200e3", 181 | "metadata": { 182 | "lines_to_next_cell": 2 183 | }, 184 | "outputs": [], 185 | "source": [ 186 | "app = FastAPI()\n", 187 | "\n", 188 | "\n", 189 | "@serve.deployment(num_replicas=2, ray_actor_options={\"num_cpus\": 2})\n", 190 | "@serve.ingress(app)\n", 191 | "class SentimentAnalysis:\n", 192 | " def __init__(self):\n", 193 | " self._classifier = pipeline(\"sentiment-analysis\")\n", 194 | "\n", 195 | " @app.get(\"/\")\n", 196 | " def classify(self, input_text: str) -> str:\n", 197 | " import os\n", 198 | " print(\"from process:\", os.getpid())\n", 199 | " return self._classifier(input_text)[0][\"label\"]\n", 200 | "\n", 201 | "\n", 202 | "scaled_deployment = SentimentAnalysis.bind()" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": null, 208 | "id": "28b7996d", 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [ 212 | "app = FastAPI()\n", 213 | "\n", 214 | "\n", 215 | "@serve.deployment\n", 216 | "@serve.ingress(app)\n", 217 | "class SentimentAnalysis:\n", 218 | " def __init__(self):\n", 219 | " self._classifier = pipeline(\"sentiment-analysis\")\n", 220 | "\n", 221 | " @serve.batch(max_batch_size=10, batch_wait_timeout_s=0.1)\n", 222 | " async def classify_batched(self, batched_inputs):\n", 223 | " print(\"Got batch size:\", len(batched_inputs))\n", 224 | " results = self._classifier(batched_inputs)\n", 225 | " return [result[\"label\"] for result in results]\n", 226 | "\n", 227 | " @app.get(\"/\")\n", 228 | " async def classify(self, input_text: str) -> str:\n", 229 | " return await self.classify_batched(input_text)\n", 230 | "\n", 231 | "\n", 232 | "batched_deployment = SentimentAnalysis.bind()" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": null, 238 | "id": "a167c10b", 239 | "metadata": { 240 | "lines_to_next_cell": 2 241 | }, 242 | "outputs": [], 243 | "source": [ 244 | "import ray\n", 245 | "from ray import serve\n", 246 | "from app import batched_deployment\n", 247 | "\n", 248 | "handle = serve.run(batched_deployment)\n", 249 | "ray.get([handle.classify.remote(\"sample text\") for _ in range(10)])" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": null, 255 | "id": "c6eb34fe", 256 | "metadata": { 257 | "lines_to_next_cell": 2 258 | }, 259 | 
"outputs": [], 260 | "source": [ 261 | "@serve.deployment\n", 262 | "class DownstreamModel:\n", 263 | " def __call__(self, inp: str):\n", 264 | " return \"Hi from downstream model!\"\n", 265 | "\n", 266 | "\n", 267 | "@serve.deployment\n", 268 | "class Driver:\n", 269 | " def __init__(self, downstream):\n", 270 | " self._d = downstream\n", 271 | "\n", 272 | " async def __call__(self, *args) -> str:\n", 273 | " return await self._d.remote()\n", 274 | "\n", 275 | "\n", 276 | "downstream = DownstreamModel.bind()\n", 277 | "driver = Driver.bind(downstream)" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": null, 283 | "id": "e56ac5a3", 284 | "metadata": { 285 | "lines_to_next_cell": 2 286 | }, 287 | "outputs": [], 288 | "source": [ 289 | "@serve.deployment\n", 290 | "class DownstreamModel:\n", 291 | " def __init__(self, my_val: str):\n", 292 | " self._my_val = my_val\n", 293 | "\n", 294 | " def __call__(self, inp: str):\n", 295 | " return inp + \"|\" + self._my_val\n", 296 | "\n", 297 | "\n", 298 | "@serve.deployment\n", 299 | "class PipelineDriver:\n", 300 | " def __init__(self, model1, model2):\n", 301 | " self._m1 = model1\n", 302 | " self._m2 = model2\n", 303 | "\n", 304 | " async def __call__(self, *args) -> str:\n", 305 | " intermediate = self._m1.remote(\"input\")\n", 306 | " final = self._m2.remote(intermediate)\n", 307 | " return await final\n", 308 | "\n", 309 | "\n", 310 | "m1 = DownstreamModel.bind(\"val1\")\n", 311 | "m2 = DownstreamModel.bind(\"val2\")\n", 312 | "pipeline_driver = PipelineDriver.bind(m1, m2)" 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": null, 318 | "id": "7a5cfea3", 319 | "metadata": { 320 | "lines_to_next_cell": 2 321 | }, 322 | "outputs": [], 323 | "source": [ 324 | "@serve.deployment\n", 325 | "class DownstreamModel:\n", 326 | " def __init__(self, my_val: str):\n", 327 | " self._my_val = my_val\n", 328 | "\n", 329 | " def __call__(self):\n", 330 | " return self._my_val\n", 331 | "\n", 332 | "\n", 333 | "@serve.deployment\n", 334 | "class BroadcastDriver:\n", 335 | " def __init__(self, model1, model2):\n", 336 | " self._m1 = model1\n", 337 | " self._m2 = model2\n", 338 | "\n", 339 | " async def __call__(self, *args) -> str:\n", 340 | " output1, output2 = self._m1.remote(), self._m2.remote()\n", 341 | " return [await output1, await output2]\n", 342 | "\n", 343 | "\n", 344 | "m1 = DownstreamModel.bind(\"val1\")\n", 345 | "m2 = DownstreamModel.bind(\"val2\")\n", 346 | "broadcast_driver = BroadcastDriver.bind(m1, m2)" 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": null, 352 | "id": "26b08ba4", 353 | "metadata": { 354 | "lines_to_next_cell": 2 355 | }, 356 | "outputs": [], 357 | "source": [ 358 | "@serve.deployment\n", 359 | "class DownstreamModel:\n", 360 | " def __init__(self, my_val: str):\n", 361 | " self._my_val = my_val\n", 362 | "\n", 363 | " def __call__(self):\n", 364 | " return self._my_val\n", 365 | "\n", 366 | "\n", 367 | "@serve.deployment\n", 368 | "class ConditionalDriver:\n", 369 | " def __init__(self, model1, model2):\n", 370 | " self._m1 = model1\n", 371 | " self._m2 = model2\n", 372 | "\n", 373 | " async def __call__(self, *args) -> str:\n", 374 | " import random\n", 375 | " if random.random() > 0.5:\n", 376 | " return await self._m1.remote()\n", 377 | " else:\n", 378 | " return await self._m2.remote()\n", 379 | "\n", 380 | "\n", 381 | "m1 = DownstreamModel.bind(\"val1\")\n", 382 | "m2 = DownstreamModel.bind(\"val2\")\n", 383 | "conditional_driver = 
ConditionalDriver.bind(m1, m2)" 384 | ] 385 | }, 386 | { 387 | "cell_type": "code", 388 | "execution_count": null, 389 | "id": "33ed39cb", 390 | "metadata": {}, 391 | "outputs": [], 392 | "source": [ 393 | "from typing import Optional\n", 394 | "\n", 395 | "import wikipedia\n", 396 | "\n", 397 | "\n", 398 | "def fetch_wikipedia_page(search_term: str) -> Optional[str]:\n", 399 | " results = wikipedia.search(search_term)\n", 400 | " # If no results, return to caller.\n", 401 | " if len(results) == 0:\n", 402 | " return None\n", 403 | "\n", 404 | " # Get the page for the top result.\n", 405 | " return wikipedia.page(results[0]).content" 406 | ] 407 | }, 408 | { 409 | "cell_type": "code", 410 | "execution_count": null, 411 | "id": "2646606f", 412 | "metadata": {}, 413 | "outputs": [], 414 | "source": [ 415 | "from ray import serve\n", 416 | "from transformers import pipeline\n", 417 | "from typing import List\n", 418 | "\n", 419 | "\n", 420 | "@serve.deployment\n", 421 | "class SentimentAnalysis:\n", 422 | " def __init__(self):\n", 423 | " self._classifier = pipeline(\"sentiment-analysis\")\n", 424 | "\n", 425 | " @serve.batch(max_batch_size=10, batch_wait_timeout_s=0.1)\n", 426 | " async def is_positive_batched(self, inputs: List[str]) -> List[bool]:\n", 427 | " results = self._classifier(inputs, truncation=True)\n", 428 | " return [result[\"label\"] == \"POSITIVE\" for result in results]\n", 429 | "\n", 430 | " async def __call__(self, input_text: str) -> bool:\n", 431 | " return await self.is_positive_batched(input_text)" 432 | ] 433 | }, 434 | { 435 | "cell_type": "code", 436 | "execution_count": null, 437 | "id": "5d37f206", 438 | "metadata": {}, 439 | "outputs": [], 440 | "source": [ 441 | "@serve.deployment(num_replicas=2)\n", 442 | "class Summarizer:\n", 443 | " def __init__(self, max_length: Optional[int] = None):\n", 444 | " self._summarizer = pipeline(\"summarization\")\n", 445 | " self._max_length = max_length\n", 446 | "\n", 447 | " def __call__(self, input_text: str) -> str:\n", 448 | " result = self._summarizer(\n", 449 | " input_text, max_length=self._max_length, truncation=True)\n", 450 | " return result[0][\"summary_text\"]" 451 | ] 452 | }, 453 | { 454 | "cell_type": "code", 455 | "execution_count": null, 456 | "id": "931685ff", 457 | "metadata": {}, 458 | "outputs": [], 459 | "source": [ 460 | "@serve.deployment\n", 461 | "class EntityRecognition:\n", 462 | " def __init__(self, threshold: float = 0.90, max_entities: int = 10):\n", 463 | " self._entity_recognition = pipeline(\"ner\")\n", 464 | " self._threshold = threshold\n", 465 | " self._max_entities = max_entities\n", 466 | "\n", 467 | " def __call__(self, input_text: str) -> List[str]:\n", 468 | " final_results = []\n", 469 | " for result in self._entity_recognition(input_text):\n", 470 | " if result[\"score\"] > self._threshold:\n", 471 | " final_results.append(result[\"word\"])\n", 472 | " if len(final_results) == self._max_entities:\n", 473 | " break\n", 474 | "\n", 475 | " return final_results" 476 | ] 477 | }, 478 | { 479 | "cell_type": "code", 480 | "execution_count": null, 481 | "id": "4b925bb1", 482 | "metadata": {}, 483 | "outputs": [], 484 | "source": [ 485 | "from pydantic import BaseModel\n", 486 | "\n", 487 | "\n", 488 | "class Response(BaseModel):\n", 489 | " success: bool\n", 490 | " message: str = \"\"\n", 491 | " summary: str = \"\"\n", 492 | " named_entities: List[str] = []" 493 | ] 494 | }, 495 | { 496 | "cell_type": "code", 497 | "execution_count": null, 498 | "id": "479767a8", 499 | "metadata": {}, 
500 | "outputs": [], 501 | "source": [ 502 | "from fastapi import FastAPI\n", 503 | "\n", 504 | "app = FastAPI()\n", 505 | "\n", 506 | "\n", 507 | "@serve.deployment\n", 508 | "@serve.ingress(app)\n", 509 | "class NLPPipelineDriver:\n", 510 | " def __init__(self, sentiment_analysis, summarizer, entity_recognition):\n", 511 | " self._sentiment_analysis = sentiment_analysis\n", 512 | " self._summarizer = summarizer\n", 513 | " self._entity_recognition = entity_recognition\n", 514 | "\n", 515 | " @app.get(\"/\", response_model=Response)\n", 516 | " async def summarize_article(self, search_term: str) -> Response:\n", 517 | " # Fetch the top page content for the search term if found.\n", 518 | " page_content = fetch_wikipedia_page(search_term)\n", 519 | " if page_content is None:\n", 520 | " return Response(success=False, message=\"No pages found.\")\n", 521 | "\n", 522 | " # Conditionally continue based on the sentiment analysis.\n", 523 | " is_positive = await self._sentiment_analysis.remote(page_content)\n", 524 | " if not is_positive:\n", 525 | " return Response(success=False, message=\"Only positivitiy allowed!\")\n", 526 | "\n", 527 | " # Query the summarizer and named entity recognition models in parallel.\n", 528 | " summary_result = self._summarizer.remote(page_content)\n", 529 | " entities_result = self._entity_recognition.remote(page_content)\n", 530 | " return Response(\n", 531 | " success=True,\n", 532 | " summary=await summary_result,\n", 533 | " named_entities=await entities_result\n", 534 | " )" 535 | ] 536 | }, 537 | { 538 | "cell_type": "code", 539 | "execution_count": null, 540 | "id": "b7ec00f5", 541 | "metadata": {}, 542 | "outputs": [], 543 | "source": [ 544 | "sentiment_analysis = SentimentAnalysis.bind()\n", 545 | "summarizer = Summarizer.bind()\n", 546 | "entity_recognition = EntityRecognition.bind(threshold=0.95, max_entities=5)\n", 547 | "nlp_pipeline_driver = NLPPipelineDriver.bind(\n", 548 | " sentiment_analysis, summarizer, entity_recognition)" 549 | ] 550 | }, 551 | { 552 | "cell_type": "code", 553 | "execution_count": null, 554 | "id": "62408257", 555 | "metadata": {}, 556 | "outputs": [], 557 | "source": [ 558 | "# Run this in a separate process to avoid any blocking:\n", 559 | "! 
serve run --non-blocking app:nlp_pipeline_driver" 560 | ] 561 | }, 562 | { 563 | "cell_type": "code", 564 | "execution_count": null, 565 | "id": "f456cfc9", 566 | "metadata": { 567 | "lines_to_next_cell": 2 568 | }, 569 | "outputs": [], 570 | "source": [ 571 | "import requests\n", 572 | "\n", 573 | "\n", 574 | "print(requests.get(\n", 575 | " \"http://localhost:8000/\", params={\"search_term\": \"rayserve\"}\n", 576 | ").text)" 577 | ] 578 | }, 579 | { 580 | "cell_type": "code", 581 | "execution_count": null, 582 | "id": "6448177a", 583 | "metadata": { 584 | "lines_to_next_cell": 2 585 | }, 586 | "outputs": [], 587 | "source": [ 588 | "print(requests.get(\n", 589 | " \"http://localhost:8000/\", params={\"search_term\": \"war\"}\n", 590 | ").text)" 591 | ] 592 | }, 593 | { 594 | "cell_type": "code", 595 | "execution_count": null, 596 | "id": "e0b6dccf", 597 | "metadata": { 598 | "lines_to_next_cell": 2 599 | }, 600 | "outputs": [], 601 | "source": [ 602 | "print(requests.get(\n", 603 | " \"http://localhost:8000/\", params={\"search_term\": \"physicist\"}\n", 604 | ").text)" 605 | ] 606 | } 607 | ], 608 | "metadata": { 609 | "jupytext": { 610 | "cell_metadata_filter": "-all", 611 | "main_language": "python", 612 | "notebook_metadata_filter": "-all" 613 | } 614 | }, 615 | "nbformat": 4, 616 | "nbformat_minor": 5 617 | } 618 | -------------------------------------------------------------------------------- /notebooks/ch_09_example_aws.yaml: -------------------------------------------------------------------------------- 1 | # An unique identifier for the head node and workers of this cluster. 2 | cluster_name: minimal 3 | 4 | # The maximum number of workers nodes to launch in addition to the head 5 | # node. min_workers default to 0. 6 | max_workers: 1 7 | 8 | # Cloud-provider specific configuration. 9 | provider: 10 | type: aws 11 | region: us-west-2 12 | availability_zone: us-west-2a 13 | 14 | # How Ray will authenticate with newly launched nodes. 15 | auth: 16 | ssh_user: ubuntu -------------------------------------------------------------------------------- /notebooks/ch_09_example_azure.yaml: -------------------------------------------------------------------------------- 1 | # An unique identifier for the head node and workers of this cluster. 2 | cluster_name: minimal 3 | 4 | # The maximum number of workers nodes to launch in addition to the head 5 | # node. min_workers default to 0. 6 | max_workers: 1 7 | 8 | # Cloud-provider specific configuration. 9 | provider: 10 | type: azure 11 | location: westus2 12 | resource_group: ray-cluster 13 | 14 | # How Ray will authenticate with newly launched nodes. 15 | auth: 16 | ssh_user: ubuntu 17 | # you must specify paths to matching private and public key pair files 18 | # use `ssh-keygen -t rsa -b 4096` to generate a new ssh key pair 19 | ssh_private_key: ~/.ssh/id_rsa 20 | # changes to this should match what is specified in file_mounts 21 | ssh_public_key: ~/.ssh/id_rsa.pub -------------------------------------------------------------------------------- /notebooks/ch_09_example_gcp.yaml: -------------------------------------------------------------------------------- 1 | # A unique identifier for the head node and workers of this cluster. 2 | cluster_name: minimal 3 | 4 | # The maximum number of worker nodes to launch in addition to the head 5 | # node. min_workers default to 0. 6 | max_workers: 1 7 | 8 | # Cloud-provider specific configuration. 
9 | provider: 10 | type: gcp 11 | region: us-west1 12 | availability_zone: us-west1-a 13 | project_id: null # Globally unique project id 14 | 15 | # How Ray will authenticate with newly launched nodes. 16 | auth: 17 | ssh_user: ubuntu -------------------------------------------------------------------------------- /notebooks/ch_09_example_k8s.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: ray.io/v1alpha1 2 | kind: RayCluster 3 | metadata: 4 | name: raycluster-complete 5 | spec: 6 | headGroupSpec: 7 | rayStartParams: 8 | port: '6379' 9 | num-cpus: '1' 10 | ... 11 | template: # Pod template 12 | metadata: # Pod metadata 13 | spec: # Pod spec 14 | containers: 15 | - name: ray-head 16 | image: rayproject/ray:1.12.1 17 | resources: 18 | limits: 19 | cpu: "1" 20 | memory: "1024Mi" 21 | requests: 22 | cpu: "1" 23 | memory: "1024Mi" 24 | ports: 25 | - containerPort: 6379 26 | name: gcs 27 | - containerPort: 8265 28 | name: dashboard 29 | - containerPort: 10001 30 | name: client 31 | env: 32 | - name: "RAY_LOG_TO_STDERR" 33 | value: "1" 34 | volumeMounts: 35 | - mountPath: /tmp/ray 36 | name: ray-logs 37 | volumes: 38 | - name: ray-logs 39 | emptyDir: {} 40 | workerGroupSpecs: 41 | - groupName: small-group 42 | replicas: 2 43 | rayStartParams: 44 | ... 45 | template: # Pod template 46 | ... 47 | - groupName: medium-group 48 | ... -------------------------------------------------------------------------------- /notebooks/ch_09_ray_start_demo.txt: -------------------------------------------------------------------------------- 1 | # tag::start_head[] 2 | ray start --head --port=6379 3 | # end::start_head[] 4 | 5 | # tag::start_head_out[] 6 | ... 7 | Next steps 8 | To connect to this Ray runtime from another node, run 9 | ray start --address=':6379' 10 | 11 | If connection fails, check your firewall settings and network configuration. 12 | # end::start_head_out[] 13 | 14 | # tag::start_worker[] 15 | ray start --address= 16 | # end::start_worker[] 17 | 18 | # tag::start_worker_out[] 19 | -------------------- 20 | Ray runtime started. 21 | -------------------- 22 | 23 | To terminate the Ray runtime, run 24 | ray stop 25 | # end::start_worker_out[] 26 | 27 | # tag::job_submission[] 28 | Job submission server address: http://127.0.0.1:8265 29 | 2022-05-20 23:35:36,066 INFO dashboard_sdk.py:276 30 | -- Uploading package gcs://_ray_pkg_533a957683abeba8.zip. 31 | 2022-05-20 23:35:36,067 INFO packaging.py:416 32 | -- Creating a file package for local directory '.'. 
33 | 34 | ------------------------------------------------------- 35 | Job 'raysubmit_U5hfr1rqJZWwJmLP' submitted successfully 36 | ------------------------------------------------------- 37 | 38 | Next steps 39 | Query the logs of the job: 40 | ray job logs raysubmit_U5hfr1rqJZWwJmLP 41 | Query the status of the job: 42 | ray job status raysubmit_U5hfr1rqJZWwJmLP 43 | Request the job to be stopped: 44 | ray job stop raysubmit_U5hfr1rqJZWwJmLP 45 | 46 | Tailing logs until the job exits (disable with --no-wait): 47 | {'memory': 47157884109.0, 'object_store_memory': 2147483648.0, 48 | 'CPU': 16.0, 'node:127.0.0.1': 1.0} 49 | 50 | ------------------------------------------ 51 | Job 'raysubmit_U5hfr1rqJZWwJmLP' succeeded 52 | ------------------------------------------ 53 | # end::job_submission[] -------------------------------------------------------------------------------- /notebooks/ch_09_script.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "ec21798b", 6 | "metadata": {}, 7 | "source": [ 8 | "# Ray Clusters" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "5cd0b7fe", 14 | "metadata": {}, 15 | "source": [ 16 | "\n", 17 | "You can run this notebook directly in\n", 18 | "[Colab](https://colab.research.google.com/github/maxpumperla/learning_ray/blob/main/notebooks/ch_09_script.ipynb).\n", 19 | "\n", 20 | "\"Open" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "id": "feb7685c", 26 | "metadata": {}, 27 | "source": [ 28 | "For this chapter you need to install the following dependencies:" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "9eb880e7", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "! pip install \"ray==2.2.0\" boto3" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "id": "f74a1486", 44 | "metadata": {}, 45 | "source": [ 46 | "\n", 47 | "To import utility files for this chapter, on Colab you will also have to clone\n", 48 | "the repo and copy the code files to the base path of the runtime:" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "id": "7a22faf4", 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "!git clone https://github.com/maxpumperla/learning_ray\n", 59 | "%cp -r learning_ray/notebooks/* ." 
60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "source": [ 65 | "![Kuberay Overview](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_09/kuberay_overview.png)" 66 | ], 67 | "metadata": { 68 | "collapsed": false 69 | } 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "source": [ 74 | "![Ray Kubernetes Operator](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_09/ray_kubernetes_operator.png)" 75 | ], 76 | "metadata": { 77 | "collapsed": false 78 | } 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "id": "dd58b910", 84 | "metadata": { 85 | "lines_to_next_cell": 2 86 | }, 87 | "outputs": [], 88 | "source": [ 89 | "import ray\n", 90 | "ray.init(address=\"auto\")\n", 91 | "print(ray.cluster_resources())\n", 92 | "\n", 93 | "\n", 94 | "@ray.remote\n", 95 | "def test():\n", 96 | " return 12\n", 97 | "\n", 98 | "\n", 99 | "ray.get([test.remote() for i in range(12)])" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "id": "8d331f2e", 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "import ray\n", 110 | "ray.init(address=\"ray://localhost:10001\")\n", 111 | "print(ray.cluster_resources())\n", 112 | "\n", 113 | "\n", 114 | "@ray.remote\n", 115 | "def test():\n", 116 | " return 12\n", 117 | "\n", 118 | "\n", 119 | "ray.get([test.remote() for i in range(12)])" 120 | ] 121 | } 122 | ], 123 | "metadata": { 124 | "jupytext": { 125 | "cell_metadata_filter": "-all", 126 | "main_language": "python", 127 | "notebook_metadata_filter": "-all" 128 | } 129 | }, 130 | "nbformat": 4, 131 | "nbformat_minor": 5 132 | } 133 | -------------------------------------------------------------------------------- /notebooks/ch_10_air.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "ba89df43", 6 | "metadata": {}, 7 | "source": [ 8 | "# Getting Started with the Ray AI Runtime" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "1900de5b", 14 | "metadata": {}, 15 | "source": [ 16 | "\n", 17 | "You can run this notebook directly in\n", 18 | "[Colab](https://colab.research.google.com/github/maxpumperla/learning_ray/blob/main/notebooks/ch_10_air.ipynb).\n", 19 | "\n", 20 | "\"Open" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "id": "ad4f0182", 26 | "metadata": {}, 27 | "source": [ 28 | "For this chapter you will also need to install the following dependencies:" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "0a8477db", 35 | "metadata": { 36 | "lines_to_next_cell": 2 37 | }, 38 | "outputs": [], 39 | "source": [ 40 | "! pip install \"ray[air]==2.2.0\" \"xgboost-ray>=0.1.10\" \"xgboost>=1.6.2\"\n", 41 | "! 
pip install \"numpy>=1.19.5\" \"pandas>=1.3.5\" \"pyarrow>=6.0.1\" \"aiorwlock==1.3.0\"" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "id": "72b8ef5a", 47 | "metadata": {}, 48 | "source": [ 49 | "![AIR Deployment](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_10/AIR_deployment.png)" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "id": "d79efcf0", 55 | "metadata": {}, 56 | "source": [ 57 | "![AIR Overview](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_10/air_overview.png)" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "id": "f9dfc551", 63 | "metadata": {}, 64 | "source": [ 65 | "![AIR Plan](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_10/air_plan.png)" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "id": "97f3c29b", 71 | "metadata": {}, 72 | "source": [ 73 | "![AIR Predictor](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_10/AIR_predictor.png)" 74 | ] 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "id": "35907632", 79 | "metadata": {}, 80 | "source": [ 81 | "![AIR Trainer](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_10/AIR_trainer.png)" 82 | ] 83 | }, 84 | { 85 | "cell_type": "markdown", 86 | "id": "24287bd5", 87 | "metadata": {}, 88 | "source": [ 89 | "![AIR Tuner](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_10/AIR_tuner.png)" 90 | ] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "id": "a68c01dd", 95 | "metadata": {}, 96 | "source": [ 97 | "![AIR Workloads](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_10/AIR_workloads.png)" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "id": "86b34d0e", 103 | "metadata": {}, 104 | "source": [ 105 | "![AIR Stateless Tasks](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_10/stateless_air_tasks.png)" 106 | ] 107 | }, 108 | { 109 | "cell_type": "markdown", 110 | "id": "214b0291", 111 | "metadata": {}, 112 | "source": [ 113 | "![Tune Stateful Computation](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_10/Tune_stateful.png)" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "id": "020752a3", 120 | "metadata": { 121 | "lines_to_next_cell": 2 122 | }, 123 | "outputs": [], 124 | "source": [ 125 | "import ray\n", 126 | "from ray.data.preprocessors import StandardScaler\n", 127 | "\n", 128 | "\n", 129 | "dataset = ray.data.read_csv(\"s3://anonymous@air-example-data/breast_cancer.csv\")\n", 130 | "\n", 131 | "train_dataset, valid_dataset = dataset.train_test_split(test_size=0.2)\n", 132 | "test_dataset = valid_dataset.drop_columns(cols=[\"target\"])\n", 133 | "\n", 134 | "preprocessor = StandardScaler(columns=[\"mean radius\", \"mean texture\"])" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "id": "0867208c", 141 | "metadata": { 142 | "lines_to_next_cell": 2 143 | }, 144 | "outputs": [], 145 | "source": [ 146 | "# NOTE: Colab does not have enough resources to run this example.\n", 147 | "# try using num_workers=1, resources_per_worker={\"CPU\": 1, \"GPU\": 0} in your\n", 148 | "# ScalingConfig below.\n", 149 | "# In any case, this training loop will take considerable time to run.\n", 150 | "from ray.air.config import ScalingConfig\n", 151 | 
"from ray.train.xgboost import XGBoostTrainer\n", 152 | "\n", 153 | "\n", 154 | "trainer = XGBoostTrainer(\n", 155 | " scaling_config=ScalingConfig(\n", 156 | " num_workers=2,\n", 157 | " use_gpu=False,\n", 158 | " ),\n", 159 | " label_column=\"target\",\n", 160 | " num_boost_round=20,\n", 161 | " params={\n", 162 | " \"objective\": \"binary:logistic\",\n", 163 | " \"eval_metric\": [\"logloss\", \"error\"],\n", 164 | " },\n", 165 | " datasets={\"train\": train_dataset, \"valid\": valid_dataset},\n", 166 | " preprocessor=preprocessor,\n", 167 | ")\n", 168 | "result = trainer.fit()\n", 169 | "print(result.metrics)" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "id": "215dd80f", 176 | "metadata": {}, 177 | "outputs": [], 178 | "source": [ 179 | "# NOTE: Colab does not have enough resources to run this example.\n", 180 | "from ray import tune\n", 181 | "\n", 182 | "param_space = {\"params\": {\"max_depth\": tune.randint(1, 9)}}\n", 183 | "metric = \"train-logloss\"\n", 184 | "\n", 185 | "from ray.tune.tuner import Tuner, TuneConfig\n", 186 | "from ray.air.config import RunConfig\n", 187 | "\n", 188 | "tuner = Tuner(\n", 189 | " trainer,\n", 190 | " param_space=param_space,\n", 191 | " run_config=RunConfig(verbose=1),\n", 192 | " tune_config=TuneConfig(num_samples=2, metric=metric, mode=\"min\"),\n", 193 | ")\n", 194 | "result_grid = tuner.fit()\n", 195 | "\n", 196 | "best_result = result_grid.get_best_result()\n", 197 | "print(\"Best Result:\", best_result)" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": null, 203 | "id": "598890b1", 204 | "metadata": { 205 | "lines_to_next_cell": 2 206 | }, 207 | "outputs": [], 208 | "source": [ 209 | "checkpoint = best_result.checkpoint\n", 210 | "print(checkpoint)" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": null, 216 | "id": "302102ed", 217 | "metadata": { 218 | "lines_to_next_cell": 2 219 | }, 220 | "outputs": [], 221 | "source": [ 222 | "from ray.train.tensorflow import TensorflowCheckpoint\n", 223 | "import tensorflow as tf\n", 224 | "\n", 225 | "model = tf.keras.Sequential([\n", 226 | " tf.keras.layers.InputLayer(input_shape=(1,)),\n", 227 | " tf.keras.layers.Dense(1)\n", 228 | "])\n", 229 | "\n", 230 | "keras_checkpoint = TensorflowCheckpoint.from_model(model)" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": null, 236 | "id": "911c0364", 237 | "metadata": { 238 | "lines_to_next_cell": 2 239 | }, 240 | "outputs": [], 241 | "source": [ 242 | "from ray.train.batch_predictor import BatchPredictor\n", 243 | "from ray.train.xgboost import XGBoostPredictor\n", 244 | "\n", 245 | "checkpoint = best_result.checkpoint\n", 246 | "batch_predictor = BatchPredictor.from_checkpoint(checkpoint, XGBoostPredictor)\n", 247 | "\n", 248 | "predicted_probabilities = batch_predictor.predict(test_dataset)\n", 249 | "predicted_probabilities.show()" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": null, 255 | "id": "1b781908", 256 | "metadata": {}, 257 | "outputs": [], 258 | "source": [ 259 | "from ray import serve\n", 260 | "from fastapi import Request\n", 261 | "import pandas as pd\n", 262 | "from ray.serve import PredictorDeployment\n", 263 | "\n", 264 | "\n", 265 | "async def adapter(request: Request):\n", 266 | " payload = await request.json()\n", 267 | " return pd.DataFrame.from_dict(payload)\n", 268 | "\n", 269 | "\n", 270 | "serve.start(detached=True)\n", 271 | "deployment = 
PredictorDeployment.options(name=\"XGBoostService\")\n", 272 | "\n", 273 | "deployment.deploy(\n", 274 | " XGBoostPredictor,\n", 275 | " checkpoint,\n", 276 | " http_adapter=adapter\n", 277 | ")\n", 278 | "\n", 279 | "print(deployment.url)" 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "execution_count": null, 285 | "id": "55052dd1", 286 | "metadata": { 287 | "lines_to_next_cell": 2 288 | }, 289 | "outputs": [], 290 | "source": [ 291 | "import requests\n", 292 | "\n", 293 | "first_item = test_dataset.take(1)\n", 294 | "sample_input = dict(first_item[0])\n", 295 | "\n", 296 | "result = requests.post(\n", 297 | " deployment.url,\n", 298 | " json=[sample_input]\n", 299 | ")\n", 300 | "print(result.json())\n", 301 | "\n", 302 | "serve.shutdown()" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": null, 308 | "id": "9956dc58", 309 | "metadata": { 310 | "lines_to_next_cell": 2 311 | }, 312 | "outputs": [], 313 | "source": [ 314 | "from ray.tune.tuner import Tuner\n", 315 | "from ray.train.rl.rl_trainer import RLTrainer\n", 316 | "from ray.air.config import RunConfig, ScalingConfig\n", 317 | "\n", 318 | "\n", 319 | "trainer = RLTrainer(\n", 320 | " run_config=RunConfig(stop={\"training_iteration\": 5}),\n", 321 | " scaling_config=ScalingConfig(num_workers=2, use_gpu=False),\n", 322 | " algorithm=\"PPO\",\n", 323 | " config={\"env\": \"CartPole-v1\"},\n", 324 | ")\n", 325 | "\n", 326 | "tuner = Tuner(\n", 327 | " trainer,\n", 328 | " _tuner_kwargs={\"checkpoint_at_end\": True},\n", 329 | ")\n", 330 | "\n", 331 | "result = tuner.fit()[0]" 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": null, 337 | "id": "e35d9922", 338 | "metadata": { 339 | "lines_to_next_cell": 2 340 | }, 341 | "outputs": [], 342 | "source": [ 343 | "from ray.train.rl.rl_predictor import RLPredictor\n", 344 | "from ray.serve import PredictorDeployment\n", 345 | "\n", 346 | "\n", 347 | "serve.start(detached=True)\n", 348 | "deployment = PredictorDeployment.options(name=\"RLDeployment\")\n", 349 | "deployment.deploy(RLPredictor, result.checkpoint)\n", 350 | "\n", 351 | "\n", 352 | "serve.run(\n", 353 | " PredictorDeployment.options(name=\"RLDeployment\").bind(RLPredictor, result.checkpoint)\n", 354 | ")" 355 | ] 356 | }, 357 | { 358 | "cell_type": "code", 359 | "execution_count": null, 360 | "id": "14fabd6d", 361 | "metadata": {}, 362 | "outputs": [], 363 | "source": [ 364 | "import gym\n", 365 | "import requests\n", 366 | "\n", 367 | "\n", 368 | "num_episodes = 5\n", 369 | "env = gym.make(\"CartPole-v1\")\n", 370 | "\n", 371 | "rewards = []\n", 372 | "for i in range(num_episodes):\n", 373 | " obs = env.reset()\n", 374 | " reward = 0.0\n", 375 | " done = False\n", 376 | " while not done:\n", 377 | " action = requests.post(\n", 378 | " deployment.url,\n", 379 | " json={\"array\": obs.tolist()}\n", 380 | " ).json()\n", 381 | " obs, rew, done, _ = env.step(action)\n", 382 | " reward += rew\n", 383 | " rewards.append(reward)\n", 384 | "\n", 385 | "print(\"Episode rewards:\", rewards)\n", 386 | "\n", 387 | "serve.shutdown()" 388 | ] 389 | } 390 | ], 391 | "metadata": { 392 | "jupytext": { 393 | "cell_metadata_filter": "-all", 394 | "main_language": "python", 395 | "notebook_metadata_filter": "-all" 396 | }, 397 | "kernelspec": { 398 | "display_name": "Python 3 (ipykernel)", 399 | "language": "python", 400 | "name": "python3" 401 | }, 402 | "language_info": { 403 | "codemirror_mode": { 404 | "name": "ipython", 405 | "version": 3 406 | }, 407 | "file_extension": ".py", 408 
| "mimetype": "text/x-python", 409 | "name": "python", 410 | "nbconvert_exporter": "python", 411 | "pygments_lexer": "ipython3", 412 | "version": "3.9.13" 413 | } 414 | }, 415 | "nbformat": 4, 416 | "nbformat_minor": 5 417 | } 418 | -------------------------------------------------------------------------------- /notebooks/ch_11_ecosystem.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "bfd352a6", 6 | "metadata": {}, 7 | "source": [ 8 | "# Ray's Ecosystem and Beyond" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "bac1625d", 14 | "metadata": {}, 15 | "source": [ 16 | "\n", 17 | "You can run this notebook directly in\n", 18 | "[Colab](https://colab.research.google.com/github/maxpumperla/learning_ray/blob/main/notebooks/ch_11_ecosystem.ipynb).\n", 19 | "\n", 20 | "\"Open" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "id": "c18c1823", 26 | "metadata": {}, 27 | "source": [ 28 | "For this chapter you will also need to install the following dependencies:" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "e84c8b05", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "! pip install \"ray[air, serve]==2.2.0\" \"gradio==3.5.0\" \"requests==2.28.1\"\n", 39 | "! pip install \"mlflow==1.30.0\" \"torch==1.12.1\" \"torchvision==0.13.1\"" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "id": "92f35a22", 45 | "metadata": {}, 46 | "source": [ 47 | "\n", 48 | "To import utility files for this chapter, on Colab you will also have to clone\n", 49 | "the repo and copy the code files to the base path of the runtime:" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "id": "d1cb335b", 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "!git clone https://github.com/maxpumperla/learning_ray\n", 60 | "%cp -r learning_ray/notebooks/* ." 
61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "id": "e21a9ab7", 66 | "metadata": {}, 67 | "source": [ 68 | "![AIR ML Platform](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_11/AIR_ML_platform.png)" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "id": "8297767b", 74 | "metadata": {}, 75 | "source": [ 76 | "![Custom Integrations](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_11/custom_integrations.png)" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "id": "cda14a9f", 82 | "metadata": {}, 83 | "source": [ 84 | "![Ray Extended Ecosystem](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/chapter_11/Ray_extended_eco.png)" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "id": "4587afa5", 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "from torchvision import transforms, datasets\n", 95 | "\n", 96 | "\n", 97 | "def load_cifar(train: bool):\n", 98 | " transform = transforms.Compose([\n", 99 | " transforms.ToTensor(),\n", 100 | " transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n", 101 | " ])\n", 102 | "\n", 103 | " return datasets.CIFAR10(\n", 104 | " root=\"./data\",\n", 105 | " download=True,\n", 106 | " train=train,\n", 107 | " transform=transform\n", 108 | " )" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "id": "430187ce", 115 | "metadata": { 116 | "lines_to_next_cell": 2 117 | }, 118 | "outputs": [], 119 | "source": [ 120 | "from ray.data import from_torch\n", 121 | "\n", 122 | "\n", 123 | "train_dataset = from_torch(load_cifar(train=True))\n", 124 | "test_dataset = from_torch(load_cifar(train=False))" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "id": "9e593fac", 131 | "metadata": { 132 | "lines_to_next_cell": 2 133 | }, 134 | "outputs": [], 135 | "source": [ 136 | "import numpy as np\n", 137 | "\n", 138 | "\n", 139 | "def to_labeled_image(batch):\n", 140 | " return {\n", 141 | " \"image\": np.array([image.numpy() for image, _ in batch]),\n", 142 | " \"label\": np.array([label for _, label in batch]),\n", 143 | " }\n", 144 | "\n", 145 | "\n", 146 | "train_dataset = train_dataset.map_batches(to_labeled_image)\n", 147 | "test_dataset = test_dataset.map_batches(to_labeled_image)" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "id": "b75b9f60", 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "import torch\n", 158 | "import torch.nn as nn\n", 159 | "import torch.nn.functional as F\n", 160 | "\n", 161 | "\n", 162 | "class Net(nn.Module):\n", 163 | " def __init__(self):\n", 164 | " super().__init__()\n", 165 | " self.conv1 = nn.Conv2d(3, 6, 5)\n", 166 | " self.pool = nn.MaxPool2d(2, 2)\n", 167 | " self.conv2 = nn.Conv2d(6, 16, 5)\n", 168 | " self.fc1 = nn.Linear(16 * 5 * 5, 120)\n", 169 | " self.fc2 = nn.Linear(120, 84)\n", 170 | " self.fc3 = nn.Linear(84, 10)\n", 171 | "\n", 172 | " def forward(self, x):\n", 173 | " x = self.pool(F.relu(self.conv1(x)))\n", 174 | " x = self.pool(F.relu(self.conv2(x)))\n", 175 | " x = torch.flatten(x, 1)\n", 176 | " x = F.relu(self.fc1(x))\n", 177 | " x = F.relu(self.fc2(x))\n", 178 | " x = self.fc3(x)\n", 179 | " return x" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": null, 185 | "id": "c2d53f79", 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "from ray import 
train\n", 190 | "from ray.air import session, Checkpoint\n", 191 | "\n", 192 | "\n", 193 | "def train_loop(config):\n", 194 | " model = train.torch.prepare_model(Net())\n", 195 | " loss_fct = nn.CrossEntropyLoss()\n", 196 | " optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)\n", 197 | "\n", 198 | " train_batches = session.get_dataset_shard(\"train\").iter_torch_batches(\n", 199 | " batch_size=config[\"batch_size\"],\n", 200 | " )\n", 201 | "\n", 202 | " for epoch in range(config[\"epochs\"]):\n", 203 | " running_loss = 0.0\n", 204 | " for i, data in enumerate(train_batches):\n", 205 | " inputs, labels = data[\"image\"], data[\"label\"]\n", 206 | "\n", 207 | " optimizer.zero_grad()\n", 208 | " forward_outputs = model(inputs)\n", 209 | " loss = loss_fct(forward_outputs, labels)\n", 210 | " loss.backward()\n", 211 | " optimizer.step()\n", 212 | "\n", 213 | " running_loss += loss.item()\n", 214 | " if i % 1000 == 0:\n", 215 | " print(f\"[{epoch + 1}, {i + 1:4d}] loss: \"\n", 216 | " f\"{running_loss / 1000:.3f}\")\n", 217 | " running_loss = 0.0\n", 218 | "\n", 219 | " session.report(\n", 220 | " dict(running_loss=running_loss),\n", 221 | " checkpoint=Checkpoint.from_dict(\n", 222 | " dict(model=model.module.state_dict())\n", 223 | " ),\n", 224 | " )" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": null, 230 | "id": "61ad9d9d", 231 | "metadata": { 232 | "lines_to_next_cell": 2 233 | }, 234 | "outputs": [], 235 | "source": [ 236 | "from ray.train.torch import TorchTrainer\n", 237 | "from ray.air.config import ScalingConfig, RunConfig\n", 238 | "from ray.air.callbacks.mlflow import MLflowLoggerCallback\n", 239 | "\n", 240 | "\n", 241 | "trainer = TorchTrainer(\n", 242 | " train_loop_per_worker=train_loop,\n", 243 | " train_loop_config={\"batch_size\": 10, \"epochs\": 5},\n", 244 | " datasets={\"train\": train_dataset},\n", 245 | " scaling_config=ScalingConfig(num_workers=2),\n", 246 | " run_config=RunConfig(callbacks=[\n", 247 | " MLflowLoggerCallback(experiment_name=\"torch_trainer\")\n", 248 | " ])\n", 249 | "\n", 250 | ")\n", 251 | "result = trainer.fit()" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": null, 257 | "id": "95ca2514", 258 | "metadata": {}, 259 | "outputs": [], 260 | "source": [ 261 | "CHECKPOINT_PATH = \"torch_checkpoint\"\n", 262 | "result.checkpoint.to_directory(CHECKPOINT_PATH)" 263 | ] 264 | }, 265 | { 266 | "cell_type": "markdown", 267 | "id": "816b2ca4", 268 | "metadata": {}, 269 | "source": [ 270 | "If you run this notebook in Colab, please make sure the \"torch_checkpoint\" gets\n", 271 | "generated properly. The folder needs an \".is_checkpoint\" file in it, as well as\n", 272 | "\".tune_metadata\" and a \"dict_checkpoint.pkl\". The gradio demo will throw an error\n", 273 | "on faulty checkpoints." 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "execution_count": null, 279 | "id": "429b4d97", 280 | "metadata": { 281 | "lines_to_next_cell": 2 282 | }, 283 | "outputs": [], 284 | "source": [ 285 | "# Note: if the checkpoint didn't get generated properly, you will get a \"pickle\" error here.\n", 286 | "! 
serve run --non-blocking gradio_demo:app" 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": null, 292 | "id": "273e364b", 293 | "metadata": { 294 | "lines_to_next_cell": 2 295 | }, 296 | "outputs": [], 297 | "source": [ 298 | "from ray.data import read_datasource, datasource\n", 299 | "\n", 300 | "\n", 301 | "class SnowflakeDatasource(datasource.Datasource):\n", 302 | " pass\n", 303 | "\n", 304 | "\n", 305 | "dataset = read_datasource(SnowflakeDatasource(), ...)" 306 | ] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "execution_count": null, 311 | "id": "4a0516f9", 312 | "metadata": { 313 | "lines_to_next_cell": 2 314 | }, 315 | "outputs": [], 316 | "source": [ 317 | "from ray.train.data_parallel_trainer import DataParallelTrainer\n", 318 | "\n", 319 | "\n", 320 | "class JaxTrainer(DataParallelTrainer):\n", 321 | " pass\n", 322 | "\n", 323 | "\n", 324 | "trainer = JaxTrainer(\n", 325 | " ...,\n", 326 | " scaling_config=ScalingConfig(...),\n", 327 | " datasets=dict(train=dataset),\n", 328 | ")" 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": null, 334 | "id": "72e08a53", 335 | "metadata": {}, 336 | "outputs": [], 337 | "source": [ 338 | "from ray.tune import logger, tuner\n", 339 | "from ray.air.config import RunConfig\n", 340 | "\n", 341 | "\n", 342 | "class NeptuneCallback(logger.LoggerCallback):\n", 343 | " pass\n", 344 | "\n", 345 | "\n", 346 | "tuner = tuner.Tuner(\n", 347 | " trainer,\n", 348 | " run_config=RunConfig(callbacks=[NeptuneCallback()])\n", 349 | ")" 350 | ] 351 | } 352 | ], 353 | "metadata": { 354 | "jupytext": { 355 | "cell_metadata_filter": "-all", 356 | "main_language": "python", 357 | "notebook_metadata_filter": "-all" 358 | }, 359 | "kernelspec": { 360 | "display_name": "venv", 361 | "language": "python", 362 | "name": "venv" 363 | }, 364 | "language_info": { 365 | "codemirror_mode": { 366 | "name": "ipython", 367 | "version": 3 368 | }, 369 | "file_extension": ".py", 370 | "mimetype": "text/x-python", 371 | "name": "python", 372 | "nbconvert_exporter": "python", 373 | "pygments_lexer": "ipython3", 374 | "version": "3.7.13" 375 | } 376 | }, 377 | "nbformat": 4, 378 | "nbformat_minor": 5 379 | } 380 | -------------------------------------------------------------------------------- /notebooks/fare_predictor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | NUM_FEATURES = 6 6 | 7 | 8 | class FarePredictor(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | self.fc1 = nn.Linear(NUM_FEATURES, 256) 13 | self.fc2 = nn.Linear(256, 16) 14 | self.fc3 = nn.Linear(16, 1) 15 | 16 | self.bn1 = nn.BatchNorm1d(256) 17 | self.bn2 = nn.BatchNorm1d(16) 18 | 19 | def forward(self, x): 20 | x = F.relu(self.fc1(x)) 21 | x = self.bn1(x) 22 | x = F.relu(self.fc2(x)) 23 | x = self.bn2(x) 24 | x = torch.sigmoid(self.fc3(x)) 25 | 26 | return x 27 | -------------------------------------------------------------------------------- /notebooks/gradio_demo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Net(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | self.conv1 = nn.Conv2d(3, 6, 5) 10 | self.pool = nn.MaxPool2d(2, 2) 11 | self.conv2 = nn.Conv2d(6, 16, 5) 12 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 13 | self.fc2 = nn.Linear(120, 84) 14 | self.fc3 = 
nn.Linear(84, 10) 15 | 16 | def forward(self, x): 17 | x = self.pool(F.relu(self.conv1(x))) 18 | x = self.pool(F.relu(self.conv2(x))) 19 | x = torch.flatten(x, 1) # flatten all dimensions except batch 20 | x = F.relu(self.fc1(x)) 21 | x = F.relu(self.fc2(x)) 22 | x = self.fc3(x) 23 | return x 24 | 25 | 26 | from ray.train.torch import TorchCheckpoint, TorchPredictor 27 | 28 | CHECKPOINT_PATH = "torch_checkpoint" 29 | checkpoint = TorchCheckpoint.from_directory(CHECKPOINT_PATH) 30 | predictor = TorchPredictor.from_checkpoint( 31 | checkpoint=checkpoint, 32 | model=Net() 33 | ) 34 | 35 | 36 | from ray.serve.gradio_integrations import GradioServer 37 | import gradio as gr 38 | import numpy as np 39 | 40 | 41 | def predict(payload): 42 | payload = np.array(payload, dtype=np.float32) 43 | array = payload.reshape((1, 3, 32, 32)) 44 | return np.argmax(predictor.predict(array)) 45 | 46 | 47 | demo = gr.Interface( 48 | fn=predict, 49 | inputs=gr.Image(), 50 | outputs=gr.Label(num_top_classes=10) 51 | ) 52 | 53 | # To just run the Gradio demo, without Serve, simply uncomment the line below 54 | # and start the script with `python gradio_demo.py`: 55 | # demo.launch() 56 | 57 | app = GradioServer.options( 58 | num_replicas=2, 59 | ray_actor_options={"num_cpus": 2} 60 | ).bind(demo) 61 | -------------------------------------------------------------------------------- /notebooks/images/chapter_01/AIR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_01/AIR.png -------------------------------------------------------------------------------- /notebooks/images/chapter_01/Ecosystem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_01/Ecosystem.png -------------------------------------------------------------------------------- /notebooks/images/chapter_01/cartpole.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_01/cartpole.png -------------------------------------------------------------------------------- /notebooks/images/chapter_01/ds_workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_01/ds_workflow.png -------------------------------------------------------------------------------- /notebooks/images/chapter_01/ray_layers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_01/ray_layers.png -------------------------------------------------------------------------------- /notebooks/images/chapter_01/ray_layers_old.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_01/ray_layers_old.png -------------------------------------------------------------------------------- /notebooks/images/chapter_01/simple_cluster.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_01/simple_cluster.png -------------------------------------------------------------------------------- /notebooks/images/chapter_02/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_02/architecture.png -------------------------------------------------------------------------------- /notebooks/images/chapter_02/map_reduce.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_02/map_reduce.png -------------------------------------------------------------------------------- /notebooks/images/chapter_02/task_dependency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_02/task_dependency.png -------------------------------------------------------------------------------- /notebooks/images/chapter_02/worker_node.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_02/worker_node.png -------------------------------------------------------------------------------- /notebooks/images/chapter_03/train_policy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_03/train_policy.png -------------------------------------------------------------------------------- /notebooks/images/chapter_04/mapping_envs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_04/mapping_envs.png -------------------------------------------------------------------------------- /notebooks/images/chapter_04/rllib_envs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_04/rllib_envs.png -------------------------------------------------------------------------------- /notebooks/images/chapter_04/rllib_external.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_04/rllib_external.png -------------------------------------------------------------------------------- /notebooks/images/chapter_05/Tune_model_training.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_05/Tune_model_training.png -------------------------------------------------------------------------------- /notebooks/images/chapter_05/tune_flow.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_05/tune_flow.png -------------------------------------------------------------------------------- /notebooks/images/chapter_06/AIR_data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_06/AIR_data.png -------------------------------------------------------------------------------- /notebooks/images/chapter_06/data_pipeline_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_06/data_pipeline_1.png -------------------------------------------------------------------------------- /notebooks/images/chapter_06/data_pipeline_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_06/data_pipeline_2.png -------------------------------------------------------------------------------- /notebooks/images/chapter_06/data_positioning_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_06/data_positioning_1.png -------------------------------------------------------------------------------- /notebooks/images/chapter_06/data_positioning_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_06/data_positioning_2.png -------------------------------------------------------------------------------- /notebooks/images/chapter_06/datasets_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_06/datasets_arch.png -------------------------------------------------------------------------------- /notebooks/images/chapter_06/ml_workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_06/ml_workflow.png -------------------------------------------------------------------------------- /notebooks/images/chapter_06/ml_workflow_no_logos.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_06/ml_workflow_no_logos.png -------------------------------------------------------------------------------- /notebooks/images/chapter_07/data_model_parallel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_07/data_model_parallel.png -------------------------------------------------------------------------------- 
/notebooks/images/chapter_07/torch_trainer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_07/torch_trainer.png -------------------------------------------------------------------------------- /notebooks/images/chapter_07/train_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_07/train_architecture.png -------------------------------------------------------------------------------- /notebooks/images/chapter_07/train_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_07/train_overview.png -------------------------------------------------------------------------------- /notebooks/images/chapter_07/train_tune_execution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_07/train_tune_execution.png -------------------------------------------------------------------------------- /notebooks/images/chapter_08/nlp_api_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_08/nlp_api_arch.png -------------------------------------------------------------------------------- /notebooks/images/chapter_08/serve_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_08/serve_arch.png -------------------------------------------------------------------------------- /notebooks/images/chapter_08/serve_positioning.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_08/serve_positioning.png -------------------------------------------------------------------------------- /notebooks/images/chapter_09/kuberay_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_09/kuberay_overview.png -------------------------------------------------------------------------------- /notebooks/images/chapter_09/ray_kubernetes_operator.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_09/ray_kubernetes_operator.png -------------------------------------------------------------------------------- /notebooks/images/chapter_10/AIR_deployment.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_10/AIR_deployment.png 
-------------------------------------------------------------------------------- /notebooks/images/chapter_10/AIR_predictor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_10/AIR_predictor.png -------------------------------------------------------------------------------- /notebooks/images/chapter_10/AIR_trainer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_10/AIR_trainer.png -------------------------------------------------------------------------------- /notebooks/images/chapter_10/AIR_tuner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_10/AIR_tuner.png -------------------------------------------------------------------------------- /notebooks/images/chapter_10/AIR_workloads.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_10/AIR_workloads.png -------------------------------------------------------------------------------- /notebooks/images/chapter_10/Tune_stateful.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_10/Tune_stateful.png -------------------------------------------------------------------------------- /notebooks/images/chapter_10/air_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_10/air_overview.png -------------------------------------------------------------------------------- /notebooks/images/chapter_10/air_plan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_10/air_plan.png -------------------------------------------------------------------------------- /notebooks/images/chapter_10/stateless_air_tasks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_10/stateless_air_tasks.png -------------------------------------------------------------------------------- /notebooks/images/chapter_11/AIR_ML_platform.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_11/AIR_ML_platform.png -------------------------------------------------------------------------------- /notebooks/images/chapter_11/Ray_extended_eco.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_11/Ray_extended_eco.png -------------------------------------------------------------------------------- /notebooks/images/chapter_11/custom_integrations.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/chapter_11/custom_integrations.png -------------------------------------------------------------------------------- /notebooks/images/learning_ray.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/learning_ray.png -------------------------------------------------------------------------------- /notebooks/images/marbled_electric_ray.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/marbled_electric_ray.png -------------------------------------------------------------------------------- /notebooks/images/scaling_design_patterns.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/images/scaling_design_patterns.png -------------------------------------------------------------------------------- /notebooks/index.md: -------------------------------------------------------------------------------- 1 | # Learning Ray - Flexible Distributed Python for Machine Learning 2 | 3 | -- _Max Pumperla, Edward Oakes, Richard Liaw_ 4 | 5 | Online version of "Learning Ray" (O'Reilly). 6 | All code and diagrams used in the book are fully open-sourced, and you can find self-contained notebooks accompanying the book here for free. 7 | You won't get the exact same reading experience as with the printed book, but you should get a good idea if the book is for you. 8 | If you want to support this project and buy the book, you can e.g. get it 9 | [directly from O'Reilly](https://www.oreilly.com/library/view/learning-ray/9781098117214/), 10 | or [from Amazon](https://www.amazon.com/Learning-Ray-Flexible-Distributed-Machine/dp/1098117220/). 11 | The book will be published in May 2023, but online formats should be available before that. 12 | 13 | 14 | ![Learning Ray](https://raw.githubusercontent.com/maxpumperla/learning_ray/main/notebooks/images/learning_ray.png) 15 | 16 | 17 | ## Overview 18 | 19 | The book is organized to guide you chapter by chapter from core concepts of Ray to more sophisticated topics along the way. 20 | The first three chapters of the book teach the basics of Ray as a distributed Python framework with practical examples. 21 | Chapters four to ten introduce Ray's high-level libraries and show how to build applications with them. 22 | The last two chapters give you an overview of Ray's ecosystem and show you where to go next. 23 | Here's what you can expect from each chapter. 24 | 25 | * [_Chapter 1, An Overview of Ray_](./ch_01_overview) 26 | Introduces you at a high level to all of Ray's components, how it can be used in 27 | machine learning and other tasks, what the Ray ecosystem currently looks like and how 28 | Ray as a whole fits into the landscape of distributed Python. 
29 | * [_Chapter 2, Getting Started with Ray_](./ch_02_ray_core) 30 | Walks you through the foundations of the Ray project, namely its low-level API. 31 | It also discusses how Ray Tasks and Actors naturally extend from Python functions and classes (a short sketch right after this list illustrates the idea). 32 | You also learn about all of Ray's system components and how they work together. 33 | * [_Chapter 3, Building Your First Distributed Application with Ray Core_](./ch_03_core_app) 34 | Gives you an introduction to distributed systems and what makes them hard. 35 | We'll then build a first application together and discuss how to peek behind the scenes 36 | and get insights from the Ray toolbox. 37 | * [_Chapter 4, Reinforcement Learning with Ray RLlib_](./ch_04_rllib) 38 | Gives you a quick introduction to reinforcement learning and shows how Ray implements 39 | important concepts in RLlib. After building some examples together, we'll also dive into 40 | more advanced topics like preprocessors, custom models, or working with offline data. 41 | * [_Chapter 5, Hyperparameter Optimization with Ray Tune_](./ch_05_tune) 42 | Covers why efficiently tuning hyperparameters is hard, how Ray Tune works conceptually, 43 | and how you can use it in practice for your machine learning projects. 44 | * [_Chapter 6, Data Processing with Ray_](./ch_06_data_processing) 45 | Introduces you to the Dataset abstraction of Ray and how it fits into the landscape 46 | of other data structures. You will also learn how to bring pandas data frames, Dask 47 | data structures and Apache Spark workloads to Ray. 48 | * [_Chapter 7, Distributed Training with Ray Train_](./ch_07_train) 49 | Provides you with the basics of distributed model training and shows you how to use 50 | Ray Train with popular frameworks such as TensorFlow or PyTorch, and how to combine it 51 | with Ray Tune for hyperparameter optimization. 52 | * [_Chapter 8, Serving Models with Ray Serve_](./ch_08_model_serving) 53 | Introduces you to model serving with Ray, why it works well within the framework, 54 | and how to do single-node and cluster deployment with it. 55 | * [_Chapter 9, Working with Ray Clusters_](./ch_09_script) 56 | This chapter is all about how you configure, launch and scale Ray clusters for your applications. 57 | You'll learn about Ray's cluster launcher CLI and autoscaler, as well as how to set 58 | up clusters in the cloud and how to deploy on Kubernetes and other cluster managers. 59 | * [_Chapter 10, Getting Started with the Ray AI Runtime_](./ch_10_air) 60 | Introduces you to Ray AIR, a unified toolkit for your ML workloads that offers many 61 | third-party integrations for model training or accessing custom data sources. 62 | * [_Chapter 11, Ray's Ecosystem and Beyond_](./ch_11_ecosystem) 63 | Gives you an overview of the many interesting extensions and 64 | integrations that Ray has attracted over the years.
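To give a flavor of what "Tasks and Actors naturally extend from Python functions and classes" means, here is a minimal, self-contained sketch using Ray Core. It is not an excerpt from the book's chapters; the names `square` and `Counter` are purely illustrative:

```python
import ray

ray.init()


@ray.remote
def square(x):
    # Decorating a plain function turns it into a Ray task.
    # Calling .remote() schedules it on the cluster and returns a future.
    return x * x


@ray.remote
class Counter:
    # Decorating a plain class turns it into a Ray actor,
    # i.e. a stateful worker process you can call into remotely.
    def __init__(self):
        self.count = 0

    def increment(self):
        self.count += 1
        return self.count


futures = [square.remote(i) for i in range(4)]
print(ray.get(futures))  # [0, 1, 4, 9]

counter = Counter.remote()
print(ray.get(counter.increment.remote()))  # 1
```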
65 | 66 | ```python 67 | 68 | ``` 69 | -------------------------------------------------------------------------------- /notebooks/maze.py: -------------------------------------------------------------------------------- 1 | from ray.rllib.algorithms.dqn import DQNConfig 2 | 3 | config = DQNConfig().environment("maze_gym_env.GymEnvironment")\ 4 | .rollouts(num_rollout_workers=0) 5 | -------------------------------------------------------------------------------- /notebooks/maze.yml: -------------------------------------------------------------------------------- 1 | maze_env: 2 | env: maze_gym_env.GymEnvironment 3 | run: DQN 4 | checkpoint_freq: 1 5 | stop: 6 | timesteps_total: 10000 7 | -------------------------------------------------------------------------------- /notebooks/maze_gym_env.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import random 4 | 5 | 6 | class Discrete: 7 | def __init__(self, num_actions: int): 8 | """ Discrete action space for num_actions. 9 | Discrete(4) can be used as encoding moving in one of the cardinal directions. 10 | """ 11 | self.n = num_actions 12 | 13 | def sample(self): 14 | return random.randint(0, self.n - 1) 15 | 16 | 17 | class Environment: 18 | 19 | seeker, goal = (0, 0), (4, 4) 20 | info = {'seeker': seeker, 'goal': goal} 21 | 22 | def __init__(self, *args, **kwargs): 23 | self.action_space = Discrete(4) 24 | self.observation_space = Discrete(5*5) 25 | 26 | def reset(self): 27 | """Reset seeker and goal positions, return observations.""" 28 | self.seeker = (0, 0) 29 | self.goal = (4, 4) 30 | 31 | return self.get_observation() 32 | 33 | def get_observation(self): 34 | """Encode the seeker position as integer""" 35 | return 5 * self.seeker[0] + self.seeker[1] 36 | 37 | def get_reward(self): 38 | """Reward finding the goal""" 39 | return 1 if self.seeker == self.goal else 0 40 | 41 | def is_done(self): 42 | """We're done if we found the goal""" 43 | return self.seeker == self.goal 44 | 45 | def step(self, action): 46 | """Take a step in a direction and return all available information.""" 47 | if action == 0: # move down 48 | self.seeker = (min(self.seeker[0] + 1, 4), self.seeker[1]) 49 | elif action == 1: # move left 50 | self.seeker = (self.seeker[0], max(self.seeker[1] - 1, 0)) 51 | elif action == 2: # move up 52 | self.seeker = (max(self.seeker[0] - 1, 0), self.seeker[1]) 53 | elif action == 3: # move right 54 | self.seeker = (self.seeker[0], min(self.seeker[1] + 1, 4)) 55 | else: 56 | raise ValueError("Invalid action") 57 | 58 | return self.get_observation(), self.get_reward(), self.is_done(), self.info 59 | 60 | def render(self, *args, **kwargs): 61 | """Render the environment, e.g. 
by printing its representation.""" 62 | os.system('cls' if os.name == 'nt' else 'clear') 63 | try: 64 | from IPython.display import clear_output 65 | clear_output(wait=True) 66 | except Exception: 67 | pass 68 | grid = [['| ' for _ in range(5)] + ["|\n"] for _ in range(5)] 69 | grid[self.goal[0]][self.goal[1]] = '|G' 70 | grid[self.seeker[0]][self.seeker[1]] = '|S' 71 | print(''.join([''.join(grid_row) for grid_row in grid])) 72 | 73 | 74 | import gym 75 | from gym.spaces import Discrete 76 | 77 | 78 | class GymEnvironment(Environment, gym.Env): 79 | def __init__(self, *args, **kwargs): 80 | """Make our original `Environment` a gym `Env`.""" 81 | super().__init__(*args, **kwargs) 82 | 83 | 84 | gym_env = GymEnvironment() 85 | -------------------------------------------------------------------------------- /notebooks/nyc_tlc_data/yellow_tripdata_2020-01.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/nyc_tlc_data/yellow_tripdata_2020-01.parquet -------------------------------------------------------------------------------- /notebooks/nyc_tlc_data/yellow_tripdata_2021-01.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maxpumperla/learning_ray/321ebe5fdab451f75f2736683fd40921feffdf27/notebooks/nyc_tlc_data/yellow_tripdata_2021-01.parquet -------------------------------------------------------------------------------- /notebooks/policy_client.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import gym 3 | from ray.rllib.env.policy_client import PolicyClient 4 | from maze_gym_env import GymEnvironment 5 | 6 | if __name__ == "__main__": 7 | env = GymEnvironment() 8 | client = PolicyClient("http://localhost:9900", inference_mode="remote") 9 | 10 | obs = env.reset() 11 | episode_id = client.start_episode(training_enabled=True) 12 | 13 | while True: 14 | action = client.get_action(episode_id, obs) 15 | 16 | obs, reward, done, info = env.step(action) 17 | 18 | client.log_returns(episode_id, reward, info=info) 19 | 20 | if done: 21 | client.end_episode(episode_id, obs) 22 | exit(0) 23 | -------------------------------------------------------------------------------- /notebooks/policy_server.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import ray 3 | from ray.rllib.agents.dqn import DQNConfig 4 | from ray.rllib.env.policy_server_input import PolicyServerInput 5 | import gym 6 | 7 | 8 | ray.init() 9 | 10 | 11 | def policy_input(context): 12 | return PolicyServerInput(context, "localhost", 9900) 13 | 14 | 15 | config = DQNConfig()\ 16 | .environment( 17 | env=None, 18 | action_space=gym.spaces.Discrete(4), 19 | observation_space=gym.spaces.Discrete(5*5))\ 20 | .debugging(log_level="INFO")\ 21 | .rollouts(num_rollout_workers=0)\ 22 | .offline_data( 23 | input=policy_input, 24 | input_evaluation=[])\ 25 | 26 | 27 | algo = config.build() 28 | 29 | 30 | if __name__ == "__main__": 31 | 32 | time_steps = 0 33 | for _ in range(100): 34 | results = algo.train() 35 | checkpoint = algo.save() 36 | if time_steps >= 1000: 37 | break 38 | time_steps += results["timesteps_total"] 39 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | mkdocs 2 | pymdown-extensions 3 | mkdocs-jupyter==0.18.2 4 | mkdocs-material==7.3.6 5 | mkdocs-material-extensions==1.0.3 6 | mdx-include==1.4.2 --------------------------------------------------------------------------------