├── .gitignore ├── 1-deploy-esm3-inference-endpoint.ipynb ├── 2-basic-patterns.ipynb ├── 3-enzyme-scaffold-modification.ipynb ├── 4-cleanup.ipynb ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── esm3-sagemaker-sample-notebook.ipynb ├── images ├── all.png ├── cot.png ├── esm3-architecture.png ├── model-from-marketplace.png ├── part_seq-seq.png ├── seq-func.png ├── seq-str.png ├── seq_str_out.png └── str-seq.png ├── requirements.txt └── src ├── __init__.py └── esmhelpers.py /.gitignore: -------------------------------------------------------------------------------- 1 | .env 2 | .venv 3 | -------------------------------------------------------------------------------- /1-deploy-esm3-inference-endpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Deploy ESM3-open Model Package from AWS Marketplace \n", 8 | "\n", 9 | "---\n", 10 | "## 1. Overview\n", 11 | "\n", 12 | "### 1.1. Important Note:\n", 13 | "\n", 14 | "Please visit the model detail page at https://aws.amazon.com/marketplace/pp/prodview-xbvra5ylcu4xq to learn more. If you do not have access to the link, please contact your account admin for help.\n", 15 | "\n", 16 | "You will find details about the model, including pricing, supported regions, and the end user license agreement. To use the model, please click “Continue to Subscribe” on the detail page, then return here to learn how to deploy the model and run inference.\n", 17 | "\n", 18 | "ESM3 is a frontier generative model for biology, able to jointly reason across three fundamental biological properties of proteins: sequence, structure, and function. These three data modalities are represented as tracks of discrete tokens at the input and output of ESM3. 
You can present the model with a combination of partial inputs across the tracks, and ESM3 will provide output predictions for all the tracks.\n", 19 | "ESM3 is a generative masked language model. You can prompt it with partial sequence, structure, and function keywords, and iteratively sample masked positions until all positions are unmasked.\n", 20 | "\n", 21 | "![ESM3 architecture](images/esm3-architecture.png)\n", 22 | "\n", 23 | "\n", 24 | "The ESM3 architecture is highly scalable due to its transformer backbone and all-to-all reasoning over discrete token sequences. At its largest scale, ESM3 was trained with 1.07e24 FLOPs on 2.78 billion proteins and 771 billion unique tokens, and has 98 billion parameters.\n", 25 | "Here we present esm3-open-small. With 1.4B parameters it is the smallest and fastest model in the family, trained specifically to be open sourced. ESM3-open is available under a non-commercial license.\n", 26 | "\n", 27 | "This sample notebook shows you how to deploy [EvolutionaryScale - ESM3](https://aws.amazon.com/marketplace/pp/prodview-xbvra5ylcu4xq) using Amazon SageMaker.\n", 28 | "\n", 29 | "> **Note**: This is a reference notebook, and it cannot run unless you make the changes suggested in the notebook.\n", 30 | "\n", 31 | "> The ESM3 model package supports SageMaker Real-Time Inference but not SageMaker Batch Transform.\n", 32 | "\n", 33 | "### 1.2. Prerequisites\n", 34 | "- This notebook contains elements that render correctly in the Jupyter interface. Open this notebook from an Amazon SageMaker Notebook Instance or Amazon SageMaker Studio.\n", 35 | "- Ensure that the IAM role used has **AmazonSageMakerFullAccess** and a trust policy for `sagemaker.amazonaws.com`, as described in the SageMaker documentation.\n", 36 | "- To deploy this ML model successfully, ensure that you meet one of the following conditions:\n", 37 | " 1. 
Your IAM role has these three permissions and you have authority to make AWS Marketplace subscriptions in the AWS account used: \n", 38 | " - **aws-marketplace:ViewSubscriptions**\n", 39 | " - **aws-marketplace:Unsubscribe**\n", 40 | " - **aws-marketplace:Subscribe** \n", 41 | " 2. Your AWS account has a subscription to [ESM3](https://aws.amazon.com/marketplace/pp/prodview-xbvra5ylcu4xq). If so, skip the step: [Subscribe to the model package](#2.-Subscribe-to-the-model-package)\n", 42 | "\n", 43 | "### 1.3. Contents\n", 44 | "1. [Overview](#1.-Overview)\n", 45 | "2. [Subscribe to the model package](#2.-Subscribe-to-the-model-package)\n", 46 | "3. [Create a real-time inference endpoint](#3.-Create-a-real-time-inference-endpoint)\n", 47 | "4. [Test endpoint](#4.-Test-endpoint)\n", 48 | "5. [Clean up](#5.-Clean-up)\n", 49 | "\n", 50 | "\n", 51 | "### 1.4. Usage instructions\n", 52 | "You can run this notebook one cell at a time by pressing Shift+Enter." 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": {}, 58 | "source": [ 59 | "---\n", 60 | "## 2. Subscribe to the model package" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": {}, 66 | "source": [ 67 | "1. Open the model package listing page [EvolutionaryScale ESM3 Model](https://aws.amazon.com/marketplace/pp/prodview-xbvra5ylcu4xq)\n", 68 | "1. On the AWS Marketplace listing, click on the **Continue to subscribe** button.\n", 69 | "1. On the **Subscribe to this software** page, review and click on **\"Accept Offer\"** if you and your organization agree with the EULA, pricing, and support terms. \n", 70 | "1. Once you click on the **Continue to configuration** button and choose a **region**, you will see a **Product Arn** displayed. This is the model package ARN that you need to specify when creating a deployable model using Boto3. Copy the ARN corresponding to your region and enter it in the following cell." 
71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": {}, 76 | "source": [ 77 | "---\n", 78 | "## 3. Create a real-time inference endpoint" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": {}, 84 | "source": [ 85 | "To learn more about real-time inference on Amazon SageMaker, please visit the [Documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works-hosting.html)." 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": {}, 91 | "source": [ 92 | "### 3.1. Setup" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "metadata": {}, 98 | "source": [ 99 | "Install dependencies. For inference capabilities we will use EvolutionaryScale's `esm` package. The order of installation is important so that dependencies don't get overridden." 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": { 106 | "tags": [] 107 | }, 108 | "outputs": [], 109 | "source": [ 110 | "from IPython.display import clear_output\n", 111 | "\n", 112 | "%pip install -U esm --no-deps\n", 113 | "%pip install -U -r requirements.txt\n", 114 | "%pip install -U sagemaker\n", 115 | "\n", 116 | "clear_output()" 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "metadata": {}, 122 | "source": [ 123 | "Define the ESM3 model package and instance information. You can find the `MODEL_NAME` and `ESM3_PACKAGE_ID` in the SageMaker console: go to SageMaker Console > Inference > Marketplace model packages, then open the **AWS Marketplace Subscriptions** tab.\n", 124 | "\n", 125 | "This notebook is designed to work with different models from EvolutionaryScale; here we use the ESM3 Open model." 
126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "metadata": {}, 132 | "outputs": [], 133 | "source": [ 134 | "ESM3_PACKAGE_ID = \"esm3-sm-open-v1-e218175afc0b3c8d959cb2702a2d1097\" \n", 135 | "MODEL_NAME = \"esm3-sm-open-v1\" # This is the open model version\n", 136 | "INSTANCE_TYPE = \"ml.g5.2xlarge\"\n", 137 | "INITIAL_INSTANCE_COUNT = 1" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "metadata": { 144 | "tags": [] 145 | }, 146 | "outputs": [], 147 | "source": [ 148 | "import boto3\n", 149 | "import sagemaker\n", 150 | "\n", 151 | "# Create SageMaker clients\n", 152 | "sagemaker_session = sagemaker.Session()\n", 153 | "region = sagemaker_session.boto_region_name\n", 154 | "sagemaker_client = boto3.client(\"sagemaker\", region_name=region)\n", 155 | "sagemaker_runtime_client = boto3.client(\"sagemaker-runtime\", region_name=region)\n", 156 | "\n", 157 | "# Get SageMaker execution role\n", 158 | "try:\n", 159 | " role = sagemaker.get_execution_role()\n", 160 | " print(f\"Default SageMaker execution role: {role}\")\n", 161 | "except ValueError as e:\n", 162 | " print(f\"Error getting default execution role: {e}\")\n", 163 | " print(\n", 164 | " \"You may need to specify a role explicitly or create one if not running in a SageMaker environment.\"\n", 165 | " )\n", 166 | "\n", 167 | "# Identify the model package ARN\n", 168 | "model_package_map = {\n", 169 | " \"ap-northeast-1\": f\"arn:aws:sagemaker:ap-northeast-1:977537786026:model-package/{ESM3_PACKAGE_ID}\",\n", 170 | " \"ap-northeast-2\": f\"arn:aws:sagemaker:ap-northeast-2:745090734665:model-package/{ESM3_PACKAGE_ID}\",\n", 171 | " \"ap-south-1\": f\"arn:aws:sagemaker:ap-south-1:077584701553:model-package/{ESM3_PACKAGE_ID}\",\n", 172 | " \"ap-southeast-1\": f\"arn:aws:sagemaker:ap-southeast-1:192199979996:model-package/{ESM3_PACKAGE_ID}\",\n", 173 | " \"ap-southeast-2\": 
f\"arn:aws:sagemaker:ap-southeast-2:666831318237:model-package/{ESM3_PACKAGE_ID}\",\n", 174 | " \"ca-central-1\": f\"arn:aws:sagemaker:ca-central-1:470592106596:model-package/{ESM3_PACKAGE_ID}\",\n", 175 | " \"eu-central-1\": f\"arn:aws:sagemaker:eu-central-1:446921602837:model-package/{ESM3_PACKAGE_ID}\",\n", 176 | " \"eu-north-1\": f\"arn:aws:sagemaker:eu-north-1:136758871317:model-package/{ESM3_PACKAGE_ID}\",\n", 177 | " \"eu-west-1\": f\"arn:aws:sagemaker:eu-west-1:985815980388:model-package/{ESM3_PACKAGE_ID}\",\n", 178 | " \"eu-west-2\": f\"arn:aws:sagemaker:eu-west-2:856760150666:model-package/{ESM3_PACKAGE_ID}\",\n", 179 | " \"eu-west-3\": f\"arn:aws:sagemaker:eu-west-3:843114510376:model-package/{ESM3_PACKAGE_ID}\",\n", 180 | " \"sa-east-1\": f\"arn:aws:sagemaker:sa-east-1:270155090741:model-package/{ESM3_PACKAGE_ID}\",\n", 181 | " \"us-east-1\": f\"arn:aws:sagemaker:us-east-1:865070037744:model-package/{ESM3_PACKAGE_ID}\",\n", 182 | " \"us-east-2\": f\"arn:aws:sagemaker:us-east-2:057799348421:model-package/{ESM3_PACKAGE_ID}\",\n", 183 | " \"us-west-1\": f\"arn:aws:sagemaker:us-west-1:382657785993:model-package/{ESM3_PACKAGE_ID}\",\n", 184 | " \"us-west-2\": f\"arn:aws:sagemaker:us-west-2:594846645681:model-package/{ESM3_PACKAGE_ID}\",\n", 185 | "}\n", 186 | "\n", 187 | "if region not in model_package_map.keys():\n", 188 | " raise Exception(f\"Current boto3 session region {region} is not supported.\")\n", 189 | "\n", 190 | "model_package_arn = model_package_map[region]\n", 191 | "print(f\"Model package ARN: {model_package_arn}\")" 192 | ] 193 | }, 194 | { 195 | "cell_type": "markdown", 196 | "metadata": {}, 197 | "source": [ 198 | "### 3.2. 
Create a model from the subscribed model package" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [ 207 | "model = sagemaker.model.ModelPackage(\n", 208 | " role=role,\n", 209 | " model_package_arn=model_package_arn,\n", 210 | " sagemaker_session=sagemaker_session,\n", 211 | " enable_network_isolation=True,\n", 212 | " predictor_cls=sagemaker.predictor.Predictor,\n", 213 | ")" 214 | ] 215 | }, 216 | { 217 | "cell_type": "markdown", 218 | "metadata": {}, 219 | "source": [ 220 | "### 3.3. Create a real-time endpoint" 221 | ] 222 | }, 223 | { 224 | "cell_type": "markdown", 225 | "metadata": {}, 226 | "source": [ 227 | "Note: This step will take 10-20 minutes." 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": null, 233 | "metadata": { 234 | "tags": [] 235 | }, 236 | "outputs": [], 237 | "source": [ 238 | "predictor = model.deploy(\n", 239 | " initial_instance_count=INITIAL_INSTANCE_COUNT,\n", 240 | " instance_type=INSTANCE_TYPE,\n", 241 | " sagemaker_session=sagemaker_session,\n", 242 | " serializer=sagemaker.base_serializers.JSONSerializer(),\n", 243 | " deserializer=sagemaker.base_deserializers.JSONDeserializer(),\n", 244 | ")\n", 245 | "\n", 246 | "print(f\"Deployed endpoint name is {predictor.endpoint_name}\")\n", 247 | "print(f\"Model name is {MODEL_NAME}\")" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": null, 253 | "metadata": {}, 254 | "outputs": [], 255 | "source": [ 256 | "print(predictor.endpoint_name)" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": null, 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [ 265 | "ENDPOINT_NAME = predictor.endpoint_name" 266 | ] 267 | }, 268 | { 269 | "cell_type": "markdown", 270 | "metadata": {}, 271 | "source": [ 272 | "It will take several minutes to deploy the model to an endpoint." 273 | ] 274 | }, 275 | { 276 | "cell_type": 
"markdown", 277 | "metadata": {}, 278 | "source": [ 279 | "---\n", 280 | "## 4. Test endpoint" 281 | ] 282 | }, 283 | { 284 | "cell_type": "markdown", 285 | "metadata": {}, 286 | "source": [ 287 | "### 4.1. Let's create a simple new protein sequence as a test for our endpoint" 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": null, 293 | "metadata": {}, 294 | "outputs": [], 295 | "source": [ 296 | "from esm.sdk.api import ESMProtein, GenerationConfig\n", 297 | "from esm.sdk.sagemaker import ESM3SageMakerClient\n", 298 | "from src.esmhelpers import format_seq, quick_pdb_plot, quick_aligment_plot\n", 299 | "\n", 300 | "model = ESM3SageMakerClient(endpoint_name=ENDPOINT_NAME, model=MODEL_NAME)" 301 | ] 302 | }, 303 | { 304 | "cell_type": "markdown", 305 | "metadata": {}, 306 | "source": [ 307 | "ESM3 is a generative model, so the most basic task it can accomplish is to create the sequence and structure of a new protein. All ESM3 inference requests must include sequence information, so in this case we will pass a string of \"_\" symbols. This is the \"mask\" token that indicates where we want ESM3 to fill in the blanks.\n", 308 | "\n", 309 | "We start by generating a new protein sequence." 
310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": null, 315 | "metadata": {}, 316 | "outputs": [], 317 | "source": [ 318 | "%%time\n", 319 | "\n", 320 | "n_masked = 64\n", 321 | "\n", 322 | "masked_sequence = \"_\" * n_masked\n", 323 | "\n", 324 | "prompt = ESMProtein(sequence=masked_sequence)\n", 325 | "sequence_generation_config = GenerationConfig(\n", 326 | " track=\"sequence\", # We want ESM3 to generate tokens for the sequence track\n", 327 | " num_steps=prompt.sequence.count(\"_\") // 4, # We'll use num(mask tokens) // 4 steps to decode the sequence\n", 328 | " temperature=0.7, # We'll use a temperature of 0.7 to increase the randomness of the decoding process\n", 329 | ")\n", 330 | "\n", 331 | "# Call the ESM3 inference endpoint\n", 332 | "generated_protein = model.generate(\n", 333 | " prompt,\n", 334 | " sequence_generation_config,\n", 335 | ")\n", 336 | "\n", 337 | "# View the generated sequence\n", 338 | "print(f\"Sequence length: {len(generated_protein.sequence)}\")\n", 339 | "print(format_seq(generated_protein.sequence))\n" 340 | ] 341 | }, 342 | { 343 | "cell_type": "markdown", 344 | "metadata": {}, 345 | "source": [ 346 | "## Next steps\n", 347 | "\n", 348 | "Voilà! You have a new protein sequence generated using the ESM3 Open model. Head to the next notebook for more basic patterns." 
349 | ] 350 | } 351 | ], 352 | "metadata": { 353 | "instance_type": "ml.t3.medium", 354 | "kernelspec": { 355 | "display_name": "Python 3 (ipykernel)", 356 | "language": "python", 357 | "name": "python3" 358 | }, 359 | "language_info": { 360 | "codemirror_mode": { 361 | "name": "ipython", 362 | "version": 3 363 | }, 364 | "file_extension": ".py", 365 | "mimetype": "text/x-python", 366 | "name": "python", 367 | "nbconvert_exporter": "python", 368 | "pygments_lexer": "ipython3", 369 | "version": "3.10.12" 370 | }, 371 | "vscode": { 372 | "interpreter": { 373 | "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e" 374 | } 375 | } 376 | }, 377 | "nbformat": 4, 378 | "nbformat_minor": 4 379 | } 380 | -------------------------------------------------------------------------------- /2-basic-patterns.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "ec6f0e6c-e514-496c-ba0c-0e5f7184ad9e", 6 | "metadata": { 7 | "id": "GZeZDsBYTe6z" 8 | }, 9 | "source": [ 10 | "# ESM3 on SageMaker JumpStart\n", 11 | "\n", 12 | "The demo will showcase ESM3's ability to perform several protein design tasks.\n", 13 | "\n", 14 | "![1](images/all.png)" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "id": "147044cf", 20 | "metadata": {}, 21 | "source": [ 22 | "## 1. Setup" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "id": "d0963b84-e393-4a61-8927-afd48540d299", 28 | "metadata": {}, 29 | "source": [ 30 | "Note: you'll need to run the first notebook, `1-deploy-esm3-inference-endpoint.ipynb`, before running this one to get the `ENDPOINT_NAME` used below. Alternatively, if you deployed the model via the SageMaker console, you can find the endpoint name in the Inference section of the console." 
31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "id": "0e1ff740", 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "ENDPOINT_NAME = \"\"\n", 41 | "MODEL_NAME = \"esm3-sm-open-v1\"" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "id": "3fc6f04b", 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "from esm.sdk.api import ESMProtein, GenerationConfig\n", 52 | "from esm.sdk.sagemaker import ESM3SageMakerClient\n", 53 | "from src.esmhelpers import format_seq, quick_pdb_plot, quick_aligment_plot\n", 54 | "\n", 55 | "model = ESM3SageMakerClient(endpoint_name=ENDPOINT_NAME, model=MODEL_NAME)" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "id": "ab1db171-8f75-41c9-b59c-beb9ae6733b1", 61 | "metadata": {}, 62 | "source": [ 63 | "---\n", 64 | "## 2. Sequence + Structure Generation\n", 65 | "\n", 66 | "![Sequence and Structure Generation](images/seq_str_out.png)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "id": "c4f2f3d2", 72 | "metadata": {}, 73 | "source": [ 74 | "ESM3 is a generative model, so the most basic task it can accomplish is to create the sequence and structure of a new protein. All ESM3 inference requests must include sequence information, so in this case we will pass a string of \"_\" symbols. This is the \"mask\" token that indicates where we want ESM3 to fill in the blanks.\n", 75 | "\n", 76 | "We start by generating a new protein sequence." 
77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "id": "3e39093e", 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "%%time\n", 87 | "\n", 88 | "n_masked = 64\n", 89 | "\n", 90 | "masked_sequence = \"_\" * n_masked\n", 91 | "\n", 92 | "prompt = ESMProtein(sequence=masked_sequence)\n", 93 | "sequence_generation_config = GenerationConfig(\n", 94 | " track=\"sequence\", # We want ESM3 to generate tokens for the sequence track\n", 95 | " num_steps=prompt.sequence.count(\"_\") // 4, # We'll use num(mask tokens) // 4 steps to decode the sequence\n", 96 | " temperature=0.7, # We'll use a temperature of 0.7 to increase the randomness of the decoding process\n", 97 | ")\n", 98 | "\n", 99 | "# Call the ESM3 inference endpoint\n", 100 | "generated_protein = model.generate(\n", 101 | " prompt,\n", 102 | " sequence_generation_config,\n", 103 | ")\n", 104 | "\n", 105 | "# View the generated sequence\n", 106 | "print(f\"Sequence length: {len(generated_protein.sequence)}\")\n", 107 | "print(format_seq(generated_protein.sequence))\n" 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "id": "2bca5a11", 113 | "metadata": {}, 114 | "source": [ 115 | "Next, we predict the structure of the generated sequence and display the results." 
116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "id": "d0873929", 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "%%time\n", 126 | "\n", 127 | "import py3Dmol\n", 128 | "\n", 129 | "prompt = generated_protein\n", 130 | "\n", 131 | "structure_generation_config = GenerationConfig(\n", 132 | " track=\"structure\", # We want ESM3 to generate tokens for the structure track\n", 133 | " num_steps=len(generated_protein.sequence) // 8,\n", 134 | " temperature=0.0, \n", 135 | ")\n", 136 | "\n", 137 | "generated_protein = model.generate(\n", 138 | " prompt,\n", 139 | " structure_generation_config,\n", 140 | ")\n", 141 | "print(f\"Structure coordinates dimensions: {tuple(generated_protein.coordinates.shape)}\")\n", 142 | "\n", 143 | "quick_pdb_plot(generated_protein.to_protein_chain().infer_oxygen().to_pdb_string(), color=\"spectrum\")\n" 144 | ] 145 | }, 146 | { 147 | "cell_type": "markdown", 148 | "id": "b0b3fbc3", 149 | "metadata": {}, 150 | "source": [ 151 | "Let's repeat the sequence + structure generation a few more times. In this case we'll generate all of the tokens in a single step. This makes the inference much faster, but reduces accuracy." 
152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": null, 157 | "id": "a7702a29", 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "# Generate sequence\n", 162 | "for i in range(3):\n", 163 | " print(f\"Iteration {i+1}\")\n", 164 | " sequence_prompt = ESMProtein(sequence=\"_\" * n_masked)\n", 165 | " sequence_generation_config = GenerationConfig(\n", 166 | " track=\"sequence\",\n", 167 | " num_steps=1,\n", 168 | " temperature=0.7,\n", 169 | " )\n", 170 | " generated_protein = model.generate(\n", 171 | " sequence_prompt,\n", 172 | " sequence_generation_config,\n", 173 | " )\n", 174 | " print(format_seq(generated_protein.sequence))\n", 175 | "\n", 176 | " # Generate structure\n", 177 | " structure_prompt = generated_protein\n", 178 | " structure_generation_config = GenerationConfig(\n", 179 | " track=\"structure\",\n", 180 | " num_steps=1,\n", 181 | " temperature=0.0,\n", 182 | " )\n", 183 | "\n", 184 | " generated_protein = model.generate(\n", 185 | " generated_protein,\n", 186 | " structure_generation_config,\n", 187 | " )\n", 188 | "\n", 189 | " quick_pdb_plot(\n", 190 | " generated_protein.to_protein_chain().infer_oxygen().to_pdb_string(),\n", 191 | " width=400,\n", 192 | " height=300,\n", 193 | " color=\"spectrum\",\n", 194 | " )" 195 | ] 196 | }, 197 | { 198 | "cell_type": "markdown", 199 | "id": "9a2891a9", 200 | "metadata": {}, 201 | "source": [ 202 | "---\n", 203 | "## 3. Sequence to Function Prediction\n", 204 | "\n", 205 | "![Sequence In - Function Out](images/seq-func.png)\n", 206 | "\n", 207 | "Another common task is function prediction. Given an unknown amino acid sequence, can we predict the function of its domains? Let's try an example.\n", 208 | "\n", 209 | "For this example, we'll look at pyruvate kinase (PDB ID: [1PKN](https://www.rcsb.org/structure/1PKN)), a key enzyme involved in the breakdown of sugar into energy. 
It is composed of two different domains, or functional units: the “Barrel Domain” (colored in green below) and the “C-Terminal Domain” (colored in orange)." 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": null, 215 | "id": "483516da", 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [ 219 | "from esm.utils.structure.protein_chain import ProteinChain\n", 220 | "import py3Dmol\n", 221 | "\n", 222 | "pdb_id = \"1PKN\"\n", 223 | "chain_id = \"A\"\n", 224 | "\n", 225 | "# Download the mmCIF file for 1PKN from PDB\n", 226 | "pyruvate_kinase_chain = ProteinChain.from_rcsb(pdb_id, chain_id)\n", 227 | "\n", 228 | "# Display the sequence\n", 229 | "print(format_seq(pyruvate_kinase_chain.sequence))\n", 230 | "\n", 231 | "# Display the structure\n", 232 | "view = py3Dmol.view(width=400, height=300)\n", 233 | "view.addModel(pyruvate_kinase_chain.to_pdb_string(), \"pdb\")\n", 234 | "view.setStyle({\"cartoon\": {\"color\": \"lightgrey\"}})\n", 235 | "view.addStyle({\"resi\": list(range(40, 373))}, {\"cartoon\": {\"color\": \"#38EF7D\"}})\n", 236 | "view.addStyle({\"resi\": list(range(408, 526))}, {\"cartoon\": {\"color\": \"#FF9900\"}})\n", 237 | "view.rotate(150, \"x\")\n", 238 | "view.rotate(45, \"y\")\n", 239 | "view.rotate(45, \"z\")\n", 240 | "view.zoomTo()\n", 241 | "view.show()" 242 | ] 243 | }, 244 | { 245 | "cell_type": "markdown", 246 | "id": "22f66f7d", 247 | "metadata": {}, 248 | "source": [ 249 | "Let's submit the pyruvate kinase sequence to ESM3 and request functional annotations by setting the `track` parameter to `function`." 
250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": null, 255 | "id": "a2a96bf8", 256 | "metadata": {}, 257 | "outputs": [], 258 | "source": [ 259 | "prompt = ESMProtein.from_protein_chain(pyruvate_kinase_chain)\n", 260 | "\n", 261 | "function_prediction_config = GenerationConfig(\n", 262 | " track=\"function\",\n", 263 | " num_steps=len(prompt.sequence)\n", 264 | " // 8,\n", 265 | ")\n", 266 | "\n", 267 | "generated_protein = model.generate(\n", 268 | " prompt,\n", 269 | " function_prediction_config,\n", 270 | ")\n", 271 | "\n", 272 | "for annotation in generated_protein.function_annotations:\n", 273 | " print(annotation.to_tuple())" 274 | ] 275 | }, 276 | { 277 | "cell_type": "markdown", 278 | "id": "4f38fee2-3c53-45f9-b33a-ecfd96a8bd0e", 279 | "metadata": {}, 280 | "source": [ 281 | "### Note:\n", 282 | "**Now choose a specific annotation label. In this case we're choosing `Pyruvate kinase-like, insert domain superfamily (IPR011037)` and parsing it to see where the annotation appears within the full sequence.**" 283 | ] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "execution_count": null, 288 | "id": "fc8fa87f", 289 | "metadata": {}, 290 | "outputs": [], 291 | "source": [ 292 | "from src.esmhelpers import parse_annotations_by_label, format_annotations\n", 293 | "\n", 294 | "parsed_annotations = parse_annotations_by_label(generated_protein.function_annotations)\n", 295 | "\n", 296 | "print(\n", 297 | " \" \".ljust(25),\n", 298 | " format_seq(\n", 299 | " generated_protein.sequence,\n", 300 | " width=len(generated_protein.sequence) + 1,\n", 301 | " line_numbers=False,\n", 302 | " ),\n", 303 | ")\n", 304 | "\n", 305 | "for label, flags in format_annotations(\n", 306 | " parsed_annotations,\n", 307 | " len(generated_protein.sequence),\n", 308 | " [\n", 309 | " \"Pyruvate kinase-like, insert domain superfamily (IPR011037)\",\n", 310 | " ],\n", 311 | ").items():\n", 312 | " print(\n", 313 | " label[:24].ljust(25),\n", 314 | " format_seq(\n", 315 | " flags,\n", 316
| " width=len(generated_protein.sequence) + 1,\n", 317 | " line_numbers=False,\n", 318 | " ),\n", 319 | " )" 320 | ] 321 | }, 322 | { 323 | "cell_type": "markdown", 324 | "id": "b6e2ac49", 325 | "metadata": {}, 326 | "source": [ 327 | "ESM3 was able to correctly identify the barrel and C-terminal domains, as well as some additional sequence annotations." 328 | ] 329 | }, 330 | { 331 | "cell_type": "markdown", 332 | "id": "d1902bdf", 333 | "metadata": {}, 334 | "source": [ 335 | "---\n", 336 | "## 4. Sequence to Structure Prediction\n", 337 | "\n", 338 | "![Sequence In - Structure Out](images/seq-str.png)\n", 339 | "\n" 340 | ] 341 | }, 342 | { 343 | "cell_type": "markdown", 344 | "id": "53224f0f", 345 | "metadata": {}, 346 | "source": [ 347 | "Another common task for bioFMs is to translate between sequence and structure (protein folding). Let's try to predict the structure of human beta 3 alcohol dehydrogenase, the enzyme responsible for breaking down alcohol in the liver." 348 | ] 349 | }, 350 | { 351 | "cell_type": "code", 352 | "execution_count": null, 353 | "id": "953431a5", 354 | "metadata": {}, 355 | "outputs": [], 356 | "source": [ 357 | "from esm.utils.structure.protein_chain import ProteinChain\n", 358 | "import py3Dmol\n", 359 | "\n", 360 | "pdb_id = \"1HTB\"\n", 361 | "chain_id = \"A\"\n", 362 | "\n", 363 | "# Download the mmCIF file for 1HTB from PDB\n", 364 | "adh_ref_chain = ProteinChain.from_rcsb(pdb_id, chain_id)\n", 365 | "\n", 366 | "# Display the sequence\n", 367 | "print(format_seq(adh_ref_chain.sequence))\n", 368 | "\n", 369 | "# Display the structure\n", 370 | "quick_pdb_plot(adh_ref_chain.to_pdb_string(), color=\"#007FAA\", width=400, height=300)" 371 | ] 372 | }, 373 | { 374 | "cell_type": "markdown", 375 | "id": "b0183412", 376 | "metadata": {}, 377 | "source": [ 378 | "Now we use ESM3 to predict the structure, conditioned on the sequence." 379 | ] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "execution_count": null, 384 | "id": 
"4de7b9fc", 385 | "metadata": {}, 386 | "outputs": [], 387 | "source": [ 388 | "prompt = ESMProtein.from_protein_chain(adh_ref_chain)\n", 389 | "\n", 390 | "structure_generation_config = GenerationConfig(\n", 391 | " track=\"structure\",\n", 392 | " num_steps=len(prompt.sequence) // 8,\n", 393 | " temperature=0.0, # Lower temperature means more deterministic predictions.\n", 394 | ")\n", 395 | "\n", 396 | "generated_protein = model.generate(\n", 397 | " prompt,\n", 398 | " structure_generation_config,\n", 399 | ")\n", 400 | "\n", 401 | "generated_chain = generated_protein.to_protein_chain()\n", 402 | "generated_chain = generated_chain.align(adh_ref_chain)\n", 403 | "\n", 404 | "quick_pdb_plot(\n", 405 | " generated_protein.to_pdb_string(), color=\"#00f174\", width=400, height=300\n", 406 | ")" 407 | ] 408 | }, 409 | { 410 | "cell_type": "markdown", 411 | "id": "82ce5143", 412 | "metadata": {}, 413 | "source": [ 414 | "Finally, we compute the cRMSD between the aligned generated and reference structures and view the overlay." 415 | ] 416 | }, 417 | { 418 | "cell_type": "code", 419 | "execution_count": null, 420 | "id": "00b6db39", 421 | "metadata": {}, 422 | "outputs": [], 423 | "source": [ 424 | "# Calculate the cRMSD\n", 425 | "crmsd = generated_chain.rmsd(adh_ref_chain)\n", 426 | "print(\n", 427 | " \"cRMSD of the generated structure vs. the original structure: \", crmsd\n", 428 | ")\n", 429 | "\n", 430 | "view = py3Dmol.view(width=800, height=600)\n", 431 | "view.addModel(adh_ref_chain.to_pdb_string(), \"pdb\")\n", 432 | "view.addModel(generated_chain.to_pdb_string(), \"pdb\")\n", 433 | "view.setStyle({\"model\": 0}, {\"cartoon\": {\"color\": \"#007FAA\"}})\n", 434 | "view.setStyle({\"model\": 1}, {\"cartoon\": {\"color\": \"#00f174\"}})\n", 435 | "view.zoomTo()\n", 436 | "view.show()" 437 | ] 438 | }, 439 | { 440 | "cell_type": "markdown", 441 | "id": "8f5e9108", 442 | "metadata": {}, 443 | "source": [ 444 | "The structure prediction is quite good, with a cRMSD of less than 1. 
The reference structure was determined using X-ray diffraction at a resolution of 2.4 angstroms, so this prediction is comparable to the experimental accuracy." 445 | ] 446 | }, 447 | { 448 | "cell_type": "markdown", 449 | "id": "952007c7", 450 | "metadata": {}, 451 | "source": [ 452 | "---\n", 453 | "## 5. Structure to Sequence Prediction\n", 454 | "\n", 455 | "We can also translate in the other direction, from structure to sequence.\n", 456 | "\n", 457 | "![Structure In - Sequence Out](images/str-seq.png)" 458 | ] 459 | }, 460 | { 461 | "cell_type": "code", 462 | "execution_count": null, 463 | "id": "33e50f5b", 464 | "metadata": {}, 465 | "outputs": [], 466 | "source": [ 467 | "masked_sequence = \"_\" * len(adh_ref_chain.sequence)\n", 468 | "\n", 469 | "prompt = ESMProtein(\n", 470 | " sequence=masked_sequence,\n", 471 | " coordinates=generated_protein.coordinates,\n", 472 | ")\n", 473 | "sequence_generation_config = GenerationConfig(\n", 474 | " track=\"sequence\",\n", 475 | " num_steps=prompt.sequence.count(\"_\") // 4,\n", 476 | " temperature=0.0,\n", 477 | ")\n", 478 | "generated_protein = model.generate(\n", 479 | " prompt,\n", 480 | " sequence_generation_config,\n", 481 | ")\n", 482 | "print(format_seq(generated_protein.sequence))" 483 | ] 484 | }, 485 | { 486 | "cell_type": "code", 487 | "execution_count": null, 488 | "id": "202656e1", 489 | "metadata": {}, 490 | "outputs": [], 491 | "source": [ 492 | "quick_aligment_plot(adh_ref_chain.sequence, generated_protein.sequence)" 493 | ] 494 | }, 495 | { 496 | "cell_type": "markdown", 497 | "id": "4e8269b7", 498 | "metadata": {}, 499 | "source": [ 500 | "Given only the predicted 3D structure of ADH, ESM3 was able to recover more than 85% of the actual sequence." 
501 | ] 502 | }, 503 | { 504 | "cell_type": "markdown", 505 | "id": "8d67e718-a94e-44de-b6dd-325f58c0824d", 506 | "metadata": {}, 507 | "source": [ 508 | "## Next steps" 509 | ] 510 | }, 511 | { 512 | "cell_type": "markdown", 513 | "id": "a712992e-9e41-4f82-a2df-192fcbce5545", 514 | "metadata": {}, 515 | "source": [ 516 | "In the next notebook, you'll work on some enzyme engineering tasks." 517 | ] 518 | } 519 | ], 520 | "metadata": { 521 | "kernelspec": { 522 | "display_name": "Python 3 (ipykernel)", 523 | "language": "python", 524 | "name": "python3" 525 | }, 526 | "language_info": { 527 | "codemirror_mode": { 528 | "name": "ipython", 529 | "version": 3 530 | }, 531 | "file_extension": ".py", 532 | "mimetype": "text/x-python", 533 | "name": "python", 534 | "nbconvert_exporter": "python", 535 | "pygments_lexer": "ipython3", 536 | "version": "3.10.14" 537 | } 538 | }, 539 | "nbformat": 4, 540 | "nbformat_minor": 5 541 | } 542 | -------------------------------------------------------------------------------- /3-enzyme-scaffold-modification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "ec6f0e6c-e514-496c-ba0c-0e5f7184ad9e", 6 | "metadata": { 7 | "id": "GZeZDsBYTe6z" 8 | }, 9 | "source": [ 10 | "# Enzyme Engineering with ESM3 using Amazon SageMaker real-time inference endpoints\n", 11 | "\n", 12 | "This demo showcases ESM3's ability to modify enzyme sequences and structures.\n", 13 | "\n", 14 | "Rather than generating new sequences from scratch, it can be much more interesting to modify an existing protein sequence. You might do this to increase binding to a ligand, for example, or to design a new protein that incorporates a known active site.\n", 15 | "\n", 16 | "![Protein Modification](images/part_seq-seq.png)" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "id": "147044cf", 22 | "metadata": {}, 23 | "source": [ 24 | "## 1. 
Setup" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "id": "a82fa5d4-fe25-4288-8b62-35243d5a4794", 30 | "metadata": {}, 31 | "source": [ 32 | "Note: you'll need to run the first notebook `1-deploy-esm3-inference-endpoint.ipynb` before running this one to get the `ENDPOINT_NAME` used below. Alternatively, if you deployed the model via the SageMaker console, you can find the endpoint name in the Inference section of the console." 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "id": "e6f35c7e", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "ENDPOINT_NAME = \"\"\n", 43 | "MODEL_NAME = \"esm3-sm-open-v1\"" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "id": "100abaa5", 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "from esm.sdk.api import ESMProtein, GenerationConfig\n", 54 | "from esm.sdk.sagemaker import ESM3SageMakerClient\n", 55 | "from src.esmhelpers import format_seq\n", 56 | "\n", 57 | "model = ESM3SageMakerClient(endpoint_name=ENDPOINT_NAME, model=MODEL_NAME)" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "id": "8bf50659", 63 | "metadata": {}, 64 | "source": [ 65 | "---\n", 66 | "## 2. Download enzyme structure\n", 67 | "\n", 68 | "Ornithine transcarbamylase (OTC) deficiency is a rare genetic disorder that affects the liver's ability to process ammonia, a waste product produced during the breakdown of proteins. It is the most common urea cycle disorder. Treatment involves a low-protein diet, ammonia-lowering medications, and sometimes liver transplantation for severe cases. Early diagnosis and management are crucial to prevent brain damage and other complications.\n", 69 | "\n", 70 | "One treatment approach for certain genetic diseases like OTCD is enzyme replacement therapy, where patients receive an intravenous infusion of the missing or deficient enzyme on a regular basis. This can be effective, but expensive. 
Instead, scientists have proposed using modified versions of these enzymes that require lower or less-frequent dosing. \n", 71 | "\n", 72 | "Let's see how ESM3 can improve protein engineering projects like this. First, we download the OTC reference structure from PDB and visualize the active site residues necessary for its function." 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "id": "eb5bab72", 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "from esm.utils.structure.protein_chain import ProteinChain\n", 83 | "import py3Dmol\n", 84 | "\n", 85 | "pdb_id = \"1OTH\"\n", 86 | "chain_id = \"A\"\n", 87 | "\n", 88 | "# Download chain A of structure 1OTH from the RCSB PDB\n", 89 | "otc_reference_chain = ProteinChain.from_rcsb(pdb_id, chain_id)\n", 90 | "otc_reference_chain.residue_index = (\n", 91 | " otc_reference_chain.residue_index - otc_reference_chain.residue_index[0] + 1\n", 92 | ")\n", 93 | "otc_reference_protein = ESMProtein.from_protein_chain(otc_reference_chain)\n", 94 | "\n", 95 | "# Display the sequence\n", 96 | "print(format_seq(otc_reference_chain.sequence))\n", 97 | "\n", 98 | "active_site_residues = [\n", 99 | " 56,\n", 100 | " 57,\n", 101 | " 58,\n", 102 | " 59,\n", 103 | " 60,\n", 104 | " 61,\n", 105 | " 108,\n", 106 | " 130,\n", 107 | " 135,\n", 108 | " 138,\n", 109 | " 165,\n", 110 | " 166,\n", 111 | " 167,\n", 112 | " 230,\n", 113 | " 231,\n", 114 | " 234,\n", 115 | " 235,\n", 116 | " 270,\n", 117 | " 271,\n", 118 | " 272,\n", 119 | " 297,\n", 120 | "]\n", 121 | "\n", 122 | "# Display the structure\n", 123 | "view = py3Dmol.view(width=800, height=600)\n", 124 | "view.addModel(otc_reference_chain.infer_oxygen().to_pdb_string(), \"pdb\")\n", 125 | "view.setStyle({\"cartoon\": {\"color\": \"#007FAA\"}})\n", 126 | "view.addStyle({\"resi\": active_site_residues}, {\"cartoon\": {\"color\": \"#eb982c\"}})\n", 127 | "\n", 128 | "view.zoomTo()\n", 129 | "view.show()" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 
| "id": "6ec7050a", 135 | "metadata": {}, 136 | "source": [ 137 | "---\n", 138 | "## 3. Prepare masked prompt" 139 | ] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "id": "1fcb964b", 144 | "metadata": {}, 145 | "source": [ 146 | "Next, we encode the reference sequence and structure into tokens. This will make it easier to select specific portions of the protein for redesign, especially for the structure." 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "id": "caef3ba8", 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "otc_reference_tokens = model.encode(otc_reference_protein)\n", 157 | "print(f\"Encoded sequence:\\n{otc_reference_tokens.sequence}\")\n", 158 | "print(f\"Encoded structure:\\n{otc_reference_tokens.structure}\")" 159 | ] 160 | }, 161 | { 162 | "cell_type": "markdown", 163 | "id": "aaade3fc", 164 | "metadata": {}, 165 | "source": [ 166 | "Next, we create a prompt that masks all of the protein except for the binding pocket highlighted above. First, we can construct a sequence prompt of all masks and then fill in the active site residues." 
167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "id": "c9e79e61", 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "import torch\n", 177 | "from esm.utils.constants import esm3 as esm3_constants\n", 178 | "\n", 179 | "prompt_token_length = len(otc_reference_tokens.sequence)\n", 180 | "print(f\"Sequence token count: {prompt_token_length}\")\n", 181 | "masked_sequence_tokens = torch.full(\n", 182 | " [prompt_token_length], esm3_constants.SEQUENCE_MASK_TOKEN\n", 183 | ")\n", 184 | "masked_sequence_tokens[0] = esm3_constants.SEQUENCE_BOS_TOKEN\n", 185 | "masked_sequence_tokens[-1] = esm3_constants.SEQUENCE_EOS_TOKEN\n", 186 | "\n", 187 | "for idx in active_site_residues:\n", 188 | " masked_sequence_tokens[idx - 1] = otc_reference_tokens.sequence[idx - 1]\n", 189 | "\n", 190 | "masked_sequence_token_count = (\n", 191 | " (masked_sequence_tokens == esm3_constants.SEQUENCE_MASK_TOKEN).sum().item()\n", 192 | ")\n", 193 | "\n", 194 | "print(f\"Masked sequence token count: {masked_sequence_token_count}\")" 195 | ] 196 | }, 197 | { 198 | "cell_type": "markdown", 199 | "id": "1bf11347", 200 | "metadata": {}, 201 | "source": [ 202 | "Next, we do something similar for the structure. Rather than dealing with 3D coordinates, we instead work with the encoded structure tokens. We construct a fully masked structure track 
and then fill in structure tokens for the active site.\n" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": null, 208 | "id": "63a5a030", 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [ 212 | "import torch\n", 213 | "\n", 214 | "masked_structure_tokens = torch.full(\n", 215 | " [prompt_token_length], esm3_constants.STRUCTURE_MASK_TOKEN\n", 216 | ")\n", 217 | "\n", 218 | "masked_structure_tokens[0] = esm3_constants.STRUCTURE_BOS_TOKEN\n", 219 | "masked_structure_tokens[-1] = esm3_constants.STRUCTURE_EOS_TOKEN\n", 220 | "\n", 221 | "otc_reference_tokens = model.encode(otc_reference_protein)\n", 222 | "for idx in active_site_residues:\n", 223 | " masked_structure_tokens[idx - 1] = otc_reference_tokens.structure[idx - 1]\n", 224 | "\n", 225 | "masked_structure_token_count = (\n", 226 | " (masked_structure_tokens == esm3_constants.STRUCTURE_MASK_TOKEN).sum().item()\n", 227 | ")\n", 228 | "\n", 229 | "print(f\"Masked structure token count: {masked_structure_token_count}\")\n", 230 | "\n", 231 | "assert masked_sequence_token_count == masked_structure_token_count\n", 232 | "masked_token_count = masked_sequence_token_count" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": null, 238 | "id": "0a073861", 239 | "metadata": {}, 240 | "outputs": [], 241 | "source": [ 242 | "from esm.sdk.api import ESMProteinTensor\n", 243 | "\n", 244 | "encoded_prompt = ESMProteinTensor(\n", 245 | " sequence=masked_sequence_tokens, \n", 246 | " structure=masked_structure_tokens\n", 247 | ")" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": null, 253 | "id": "02f0a717", 254 | "metadata": {}, 255 | "outputs": [], 256 | "source": [ 257 | "print(\"Reference sequence:\")\n", 258 | "print(\n", 259 | " format_seq(\n", 260 | " otc_reference_chain.sequence, width=prompt_token_length + 1, line_numbers=False\n", 261 | " )\n", 262 | ")\n", 263 | "print(\"Masked sequence:\")\n", 264 | "print(\n", 265 | " 
format_seq(\n", 266 | " model.decode(encoded_prompt).sequence,\n", 267 | " width=prompt_token_length + 1,\n", 268 | " line_numbers=False,\n", 269 | " )\n", 270 | ")\n", 271 | "print(\"Masked structure:\")\n", 272 | "print(\n", 273 | " format_seq(\n", 274 | " \"\".join([\"✔\" if st < 4096 else \"_\" for st in encoded_prompt.structure][1:-1]),\n", 275 | " width=prompt_token_length + 1,\n", 276 | " line_numbers=False,\n", 277 | " )\n", 278 | ")" 279 | ] 280 | }, 281 | { 282 | "cell_type": "markdown", 283 | "id": "fdfb9516", 284 | "metadata": {}, 285 | "source": [ 286 | "---\n", 287 | "## 4. Generate structure" 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": null, 293 | "id": "7d3e418b", 294 | "metadata": {}, 295 | "outputs": [], 296 | "source": [ 297 | "structure_generation_config = GenerationConfig(\n", 298 | " track=\"structure\", num_steps=masked_token_count // 8, temperature=1.0\n", 299 | ")\n", 300 | "\n", 301 | "generated_protein_1 = model.generate(encoded_prompt, structure_generation_config)\n", 302 | "\n", 303 | "decoded_protein_chain = model.decode(generated_protein_1).to_protein_chain()\n", 304 | "\n", 305 | "view = py3Dmol.view(width=600, height=400)\n", 306 | "view.addModel(decoded_protein_chain.infer_oxygen().to_pdb_string(), \"pdb\")\n", 307 | "view.setStyle({\"cartoon\": {\"color\": \"#007FAA\"}})\n", 308 | "view.addStyle({\"resi\": active_site_residues}, {\"cartoon\": {\"color\": \"#eb982c\"}})\n", 309 | "view.zoomTo()\n", 310 | "view.show()" 311 | ] 312 | }, 313 | { 314 | "cell_type": "markdown", 315 | "id": "ab5cf6e5", 316 | "metadata": {}, 317 | "source": [ 318 | "Verify that:\n", 319 | " 1. The new structure has a very similar active site to the reference\n", 320 | " 2. 
The new structure has a very DISSIMILAR backbone structure" 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": null, 326 | "id": "384481d9", 327 | "metadata": {}, 328 | "outputs": [], 329 | "source": [ 330 | "constrained_site_rmsd = otc_reference_chain[active_site_residues].rmsd(\n", 331 | " decoded_protein_chain[active_site_residues]\n", 332 | ")\n", 333 | "backbone_rmsd = otc_reference_chain.rmsd(decoded_protein_chain)\n", 334 | "\n", 335 | "c_pass = \"✅\" if constrained_site_rmsd < 1.5 else \"❌\"\n", 336 | "b_pass = \"✅\" if backbone_rmsd > 1.5 else \"❌\"\n", 337 | "\n", 338 | "print(f\"Constrained site RMSD: {constrained_site_rmsd:.2f} Ang {c_pass}\")\n", 339 | "print(f\"Backbone RMSD: {backbone_rmsd:.2f} Ang {b_pass}\")" 340 | ] 341 | }, 342 | { 343 | "cell_type": "markdown", 344 | "id": "595372f1", 345 | "metadata": {}, 346 | "source": [ 347 | "---\n", 348 | "## 5. Generate sequence\n", 349 | "\n", 350 | "Next, we use the generated structure as conditioning to generate a new sequence." 351 | ] 352 | }, 353 | { 354 | "cell_type": "code", 355 | "execution_count": null, 356 | "id": "b648d8c9", 357 | "metadata": {}, 358 | "outputs": [], 359 | "source": [ 360 | "sequence_generation_config = GenerationConfig(\n", 361 | " track=\"sequence\",\n", 362 | " num_steps=masked_token_count // 4,\n", 363 | " temperature=1.0,\n", 364 | ")\n", 365 | "generated_protein_2 = model.generate(generated_protein_1, sequence_generation_config)\n", 366 | "\n", 367 | "print(\"Reference sequence:\")\n", 368 | "print(\n", 369 | " format_seq(\n", 370 | " otc_reference_chain.sequence, width=prompt_token_length + 1, line_numbers=False\n", 371 | " )\n", 372 | ")\n", 373 | "print(\"Masked sequence:\")\n", 374 | "print(\n", 375 | " format_seq(\n", 376 | " model.decode(encoded_prompt).sequence,\n", 377 | " width=prompt_token_length + 1,\n", 378 | " line_numbers=False,\n", 379 | " )\n", 380 | ")\n", 381 | "print(\"Generated sequence:\")\n", 382 | "print(\n", 383 | "
format_seq(\n", 384 | " model.decode(generated_protein_2).sequence, width=prompt_token_length + 1, line_numbers=False\n", 385 | " )\n", 386 | ")" 387 | ] 388 | }, 389 | { 390 | "cell_type": "markdown", 391 | "id": "1c4dd846", 392 | "metadata": {}, 393 | "source": [ 394 | "Finally, refold the generated sequence without any other conditioning." 395 | ] 396 | }, 397 | { 398 | "cell_type": "code", 399 | "execution_count": null, 400 | "id": "c8e2958a", 401 | "metadata": {}, 402 | "outputs": [], 403 | "source": [ 404 | "prompt = ESMProteinTensor(sequence=generated_protein_2.sequence, structure=None)\n", 405 | "\n", 406 | "structure_generation_config = GenerationConfig(\n", 407 | " track=\"structure\", num_steps=masked_token_count // 8, temperature=0.0\n", 408 | ")\n", 409 | "\n", 410 | "generated_protein_3 = model.generate(prompt, structure_generation_config)\n", 411 | "final_protein = model.decode(generated_protein_3)\n", 412 | "print(format_seq(final_protein.sequence))" 413 | ] 414 | }, 415 | { 416 | "cell_type": "markdown", 417 | "id": "f0e74d33", 418 | "metadata": {}, 419 | "source": [ 420 | "---\n", 421 | "## 6. Validation" 422 | ] 423 | }, 424 | { 425 | "cell_type": "markdown", 426 | "id": "c3224c60", 427 | "metadata": {}, 428 | "source": [ 429 | "Compare the generated sequence to the reference." 
430 | ] 431 | }, 432 | { 433 | "cell_type": "code", 434 | "execution_count": null, 435 | "id": "fd5163c7", 436 | "metadata": {}, 437 | "outputs": [], 438 | "source": [ 439 | "import biotite.sequence as seq\n", 440 | "import biotite.sequence.align as align\n", 441 | "import biotite.sequence.graphics as graphics\n", 442 | "import matplotlib.pyplot as pl\n", 443 | "\n", 444 | "\n", 445 | "seq1 = seq.ProteinSequence(otc_reference_protein.sequence)\n", 446 | "seq2 = seq.ProteinSequence(final_protein.sequence)\n", 447 | "\n", 448 | "alignments = align.align_optimal(\n", 449 | " seq1,\n", 450 | " seq2,\n", 451 | " align.SubstitutionMatrix.std_protein_matrix(),\n", 452 | " gap_penalty=(-10, -1),\n", 453 | ")\n", 454 | "\n", 455 | "alignment = alignments[0]\n", 456 | "\n", 457 | "identity = align.get_sequence_identity(alignment)\n", 458 | "print(f\"Sequence identity: {100*identity:.2f}%\")\n", 459 | "\n", 460 | "print(\"\\nSequence alignment:\")\n", 461 | "fig = pl.figure(figsize=(8.0, 4.0))\n", 462 | "ax = fig.add_subplot(111)\n", 463 | "graphics.plot_alignment_similarity_based(\n", 464 | " ax, alignment, symbols_per_line=45, spacing=2,\n", 465 | " show_numbers=True,\n", 466 | ")\n", 467 | "fig.tight_layout()\n", 468 | "pl.show()" 469 | ] 470 | }, 471 | { 472 | "cell_type": "markdown", 473 | "id": "3b6d4e07", 474 | "metadata": {}, 475 | "source": [ 476 | "Compare the generated structure to the reference." 
477 | ] 478 | }, 479 | { 480 | "cell_type": "code", 481 | "execution_count": null, 482 | "id": "c805619a", 483 | "metadata": {}, 484 | "outputs": [], 485 | "source": [ 486 | "generated_chain = final_protein.to_protein_chain()\n", 487 | "generated_chain = generated_chain.align(otc_reference_chain)\n", 488 | "\n", 489 | "constrained_site_rmsd = otc_reference_chain[active_site_residues].rmsd(\n", 490 | " generated_chain[active_site_residues]\n", 491 | ")\n", 492 | "backbone_rmsd = otc_reference_chain.rmsd(generated_chain)\n", 493 | "\n", 494 | "c_pass = \"✅\" if constrained_site_rmsd < 1.5 else \"❌\"\n", 495 | "b_pass = \"🤷‍♂️\"\n", 496 | "\n", 497 | "print(f\"Constrained site RMSD: {constrained_site_rmsd:.2f} Ang {c_pass}\")\n", 498 | "print(f\"Backbone RMSD: {backbone_rmsd:.2f} Ang {b_pass}\")\n" 499 | ] 500 | }, 501 | { 502 | "cell_type": "code", 503 | "execution_count": null, 504 | "id": "c244824d", 505 | "metadata": {}, 506 | "outputs": [], 507 | "source": [ 508 | "view = py3Dmol.view(width=600, height=600)\n", 509 | "view.addModel(otc_reference_chain.infer_oxygen().to_pdb_string(), \"pdb\")\n", 510 | "view.addModel(generated_chain.infer_oxygen().to_pdb_string(), \"pdb\")\n", 511 | "view.setStyle({\"model\": 0}, {\"cartoon\": {\"color\": \"#007FAA\"}})\n", 512 | "view.setStyle({\"model\": 1}, {\"cartoon\": {\"color\": \"lightgreen\"}})\n", 513 | "view.addStyle({\"resi\": active_site_residues}, {\"cartoon\": {\"color\": \"#eb982c\"}})\n", 514 | "view.zoomTo()\n", 515 | "view.show()" 516 | ] 517 | }, 518 | { 519 | "cell_type": "markdown", 520 | "id": "f8c2cda2", 521 | "metadata": {}, 522 | "source": [ 523 | "We have successfully generated a new protein with a similar active site to the reference but a different backbone structure and sequence. Repeating this process many times will give us a good library of candidates for lab testing." 
524 | ] 525 | }, 526 | { 527 | "cell_type": "markdown", 528 | "id": "859c1316-c3cb-4d72-b227-64798906f2df", 529 | "metadata": {}, 530 | "source": [ 531 | "## Congratulations\n", 532 | "\n", 533 | "You've gone through all the example notebooks. Please feel free to experiment further or go to the next notebook, where you can learn how to delete the inference endpoints." 534 | ] 535 | } 536 | ], 537 | "metadata": { 538 | "kernelspec": { 539 | "display_name": "Python 3 (ipykernel)", 540 | "language": "python", 541 | "name": "python3" 542 | }, 543 | "language_info": { 544 | "codemirror_mode": { 545 | "name": "ipython", 546 | "version": 3 547 | }, 548 | "file_extension": ".py", 549 | "mimetype": "text/x-python", 550 | "name": "python", 551 | "nbconvert_exporter": "python", 552 | "pygments_lexer": "ipython3", 553 | "version": "3.10.14" 554 | } 555 | }, 556 | "nbformat": 4, 557 | "nbformat_minor": 5 558 | } 559 | -------------------------------------------------------------------------------- /4-cleanup.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "6c11a19d-1445-495c-a21c-5af9d220664e", 6 | "metadata": {}, 7 | "source": [ 8 | "# Clean up" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "69cc0126-3f7a-4da5-80a3-4d9e73eee39c", 14 | "metadata": {}, 15 | "source": [ 16 | "This notebook will create an instance of the predictor and delete the endpoint. Alternatively, you can delete the endpoint directly through the SageMaker console." 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "id": "fb0392bc-264e-4a47-80d0-9830d6ef6093", 22 | "metadata": {}, 23 | "source": [ 24 | "## 1. 
Delete endpoint" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 3, 30 | "id": "498b4143-624c-4207-b462-6eebdb4905ec", 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "ENDPOINT_NAME = \"\"" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 4, 40 | "id": "f033d199-9712-496f-8236-e9b5c83c1b8b", 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "from sagemaker.predictor import Predictor\n", 45 | "from sagemaker import Session\n", 46 | "\n", 47 | "session = Session()\n", 48 | "\n", 49 | "predictor = Predictor(ENDPOINT_NAME, sagemaker_session=session)" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 5, 55 | "id": "159c7ff9-76a7-49b3-bbc4-0a3289225984", 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "predictor.delete_endpoint()" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "id": "f72ff07c-abf1-4482-b0a0-30d68daf6d50", 65 | "metadata": {}, 66 | "source": [ 67 | "## 2. Unsubscribe from the listing (Optional)" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "id": "34bcac8a-0304-44c6-ab1d-111fb034ccc6", 73 | "metadata": {}, 74 | "source": [ 75 | "If you would like to unsubscribe from the model package, follow these steps. Before you cancel the subscription, ensure that you do not have any [deployable model](https://console.aws.amazon.com/sagemaker/home#/models) created from the model package or using the algorithm. Note: you can find this information by looking at the container name associated with the model. \n", 76 | "\n", 77 | "**Steps to unsubscribe from a product on AWS Marketplace**:\n", 78 | "1. Navigate to the __Machine Learning__ tab on [__Your Software subscriptions page__](https://aws.amazon.com/marketplace/ai/library?productType=ml&ref_=mlmp_gitdemo_indust)\n", 79 | "2. 
Locate the listing that you want to cancel the subscription for, and then choose __Cancel Subscription__ to cancel the subscription.\n", 80 | "\n" 81 | ] 82 | } 83 | ], 84 | "metadata": { 85 | "kernelspec": { 86 | "display_name": "Python 3 (ipykernel)", 87 | "language": "python", 88 | "name": "python3" 89 | }, 90 | "language_info": { 91 | "codemirror_mode": { 92 | "name": "ipython", 93 | "version": 3 94 | }, 95 | "file_extension": ".py", 96 | "mimetype": "text/x-python", 97 | "name": "python", 98 | "nbconvert_exporter": "python", 99 | "pygments_lexer": "ipython3", 100 | "version": "3.10.14" 101 | } 102 | }, 103 | "nbformat": 4, 104 | "nbformat_minor": 5 105 | } 106 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 
13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. 
As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT No Attribution 2 | 3 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so. 10 | 11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 13 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR 14 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 15 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 16 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 17 | 18 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## EvolutionaryScale ESM3 on Amazon SageMaker 2 | 3 | This repository contains the sample notebooks to run EvolutionaryScale's ESM3 biological model on Amazon SageMaker. 4 | 5 | 6 | ### What is ESM3 7 | ESM3 is a frontier generative model for biology, able to jointly reason across three fundamental biological properties of proteins: sequence, structure, and function. These three data modalities are represented as tracks of discrete tokens at the input and output of ESM3. You can present the model with a combination of partial inputs across the tracks, and ESM3 will provide output predictions for all the tracks. ESM3 is a generative masked language model. You can prompt it with partial sequence, structure, and function keywords, and iteratively sample masked positions until all positions are unmasked. 8 | 9 | ![ESM3 Architecture](images/esm3-architecture.png) 10 | 11 | The ESM3 architecture is highly scalable due to its transformer backbone and all-to-all reasoning over discrete token sequences. At its largest scale, ESM3 was trained with 1.07e24 FLOPs on 2.78 billion proteins and 771 billion unique tokens, and has 98 billion parameters. Here we present `esm3-open-small`. With 1.4B parameters it is the smallest and fastest model in the family, trained specifically to be open sourced. 
12 | 13 | ### Notebooks in this repository 14 | * [Deploy ESM3-open Model Package from AWS Marketplace ](1-deploy-esm3-inference-endpoint.ipynb) 15 | * [Basic protein generation tasks](2-basic-patterns.ipynb) 16 | * [Enzyme engineering](3-enzyme-scaffold-modification.ipynb) 17 | * [Clean up](4-cleanup.ipynb) 18 | 19 | 20 | ### About EvolutionaryScale 21 | 22 | EvolutionaryScale is a frontier AI research lab and Public Benefit Corporation dedicated to developing artificial intelligence for the life sciences. EvolutionaryScale’s models support groundbreaking research and development in health, environmental science, and beyond. The company was founded in July 2023, with a founding team widely recognized for its pioneering work in transformer protein language models. For more information, visit https://evolutionaryscale.ai 23 | 24 | 25 | ## Security 26 | 27 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. 28 | 29 | ## License 30 | 31 | This library is licensed under the MIT-0 License. See the LICENSE file. 32 | 33 | -------------------------------------------------------------------------------- /esm3-sagemaker-sample-notebook.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Deploy ESM3-open Model Package from AWS Marketplace \n", 8 | "\n", 9 | "---\n", 10 | "## 1. Overview\n", 11 | "\n", 12 | "### 1.1. Important Note:\n", 13 | "\n", 14 | "Please visit the model detail page at https://aws.amazon.com/marketplace/pp/prodview-xbvra5ylcu4xq to learn more. If you do not have access to the link, please contact your account admin for help.\n", 15 | "\n", 16 | "You will find details about the model including pricing, supported regions, and the end user license agreement. 
To use the model, please click “Continue to Subscribe” from the detail page, then come back here to learn how to deploy the model and run inference.\n", 17 | "\n", 18 | "ESM3 is a frontier generative model for biology, able to jointly reason across three fundamental biological properties of proteins: sequence, structure, and function. These three data modalities are represented as tracks of discrete tokens at the input and output of ESM3. You can present the model with a combination of partial inputs across the tracks, and ESM3 will provide output predictions for all the tracks.\n", 19 | "ESM3 is a generative masked language model. You can prompt it with partial sequence, structure, and function keywords, and iteratively sample masked positions until all positions are unmasked.\n", 20 | "\n", 21 | "\"ESM3\n", 22 | "\n", 23 | "\n", 24 | "The ESM3 architecture is highly scalable due to its transformer backbone and all-to-all reasoning over discrete token sequences. At its largest scale, ESM3 was trained with 1.07e24 FLOPs on 2.78 billion proteins and 771 billion unique tokens, and has 98 billion parameters.\n", 25 | "Here we present esm3-open-small. With 1.4B parameters, it is the smallest and fastest model in the family, trained specifically to be open sourced. ESM3-open is available under a non-commercial license.\n", 26 | "\n", 27 | "This sample notebook shows you how to deploy [EvolutionaryScale - ESM3](https://aws.amazon.com/marketplace/pp/prodview-xbvra5ylcu4xq) using Amazon SageMaker.\n", 28 | "\n", 29 | "> **Note**: This is a reference notebook and it cannot run unless you make changes suggested in the notebook.\n", 30 | "\n", 31 | "> The ESM3 model package supports SageMaker Real-Time Inference but not SageMaker Batch Transform.\n", 32 | "\n", 33 | "### 1.2. Prerequisites\n", 34 | "- This notebook contains elements that render correctly in the Jupyter interface. 
Open this notebook from an Amazon SageMaker Notebook Instance or Amazon SageMaker Studio.\n", 35 | "- Ensure that the IAM role used has **AmazonSageMakerFullAccess** and a trust policy for `sagemaker.amazonaws.com`, as described in the SageMaker documentation.\n", 36 | "- To deploy this ML model successfully, ensure that you meet one of the following conditions:\n", 37 | "  1. Your IAM role has these three permissions and you have authority to make AWS Marketplace subscriptions in the AWS account used: \n", 38 | "     - **aws-marketplace:ViewSubscriptions**\n", 39 | "     - **aws-marketplace:Unsubscribe**\n", 40 | "     - **aws-marketplace:Subscribe**  \n", 41 | "  2. Your AWS account has a subscription to [ESM3](https://aws.amazon.com/marketplace/pp/prodview-xbvra5ylcu4xq). If so, skip the step [Subscribe to the model package](#2.-Subscribe-to-the-model-package)\n", 42 | "\n", 43 | "### 1.3. Contents\n", 44 | "1. [Overview](#1.-Overview)\n", 45 | "2. [Subscribe to the model package](#2.-Subscribe-to-the-model-package)\n", 46 | "3. [Create a real-time inference endpoint ](#3.-Create-a-real-time-inference-endpoint)\n", 47 | "4. [Test endpoint](#4.-Test-endpoint)\n", 48 | "5. [Clean up](#5.-Clean-up)\n", 49 | "\n", 50 | "\n", 51 | "### 1.4. Usage instructions\n", 52 | "You can run this notebook one cell at a time by pressing the Shift+Enter keys." 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": {}, 58 | "source": [ 59 | "---\n", 60 | "## 2. Subscribe to the model package" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": {}, 66 | "source": [ 67 | "1. Open the model package listing page [EvolutionaryScale ESM3 Model](https://aws.amazon.com/marketplace/pp/prodview-xbvra5ylcu4xq)\n", 68 | "1. On the AWS Marketplace listing, click on the **Continue to subscribe** button.\n", 69 | "1. On the **Subscribe to this software** page, review and click on **\"Accept Offer\"** if you and your organization agree with the EULA, pricing, and support terms. 
\n", 70 | "1. Once you click on the **Continue to configuration** button and then choose a **region**, you will see a **Product Arn** displayed. This is the model package ARN that you need to specify while creating a deployable model using Boto3. Copy the ARN corresponding to your region and specify the same in the following cell." 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": {}, 76 | "source": [ 77 | "---\n", 78 | "## 3. Create a real-time inference endpoint" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": {}, 84 | "source": [ 85 | "To learn more about real-time inference on Amazon SageMaker, please visit the [Documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works-hosting.html)." 86 | ] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | "metadata": {}, 91 | "source": [ 92 | "### 3.1. Setup" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "metadata": {}, 98 | "source": [ 99 | "Install dependencies. For inference capabilities we will use EvolutionaryScale's `esm` package. The order of installation is important so that dependencies don't get overridden." 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": { 106 | "tags": [] 107 | }, 108 | "outputs": [], 109 | "source": [ 110 | "from IPython.display import clear_output\n", 111 | "\n", 112 | "%pip install -U esm --no-deps\n", 113 | "%pip install -U -r requirements.txt\n", 114 | "%pip install -U sagemaker\n", 115 | "\n", 116 | "clear_output()" 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "metadata": {}, 122 | "source": [ 123 | "Define ESM3 model package and instance information. You can find the `MODEL_NAME` and `ESM3_PACKAGE_ID` in the SageMaker console: go to the SageMaker Console > Inference > Marketplace model packages, then open the **AWS Marketplace Subscriptions** tab.\n", 124 | "\n", 125 | "This notebook is designed to work with different models from EvolutionaryScale. 
This notebook shows how you can use the ESM3 Open model." 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "metadata": {}, 132 | "outputs": [], 133 | "source": [ 134 | "ESM3_PACKAGE_ID = \"esm3-sm-open-v1-e218175afc0b3c8d959cb2702a2d1097\" \n", 135 | "MODEL_NAME = \"esm3-sm-open-v1\"  # This is the open model version\n", 136 | "INSTANCE_TYPE = \"ml.g5.2xlarge\"\n", 137 | "INITIAL_INSTANCE_COUNT = 1" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "metadata": { 144 | "tags": [] 145 | }, 146 | "outputs": [], 147 | "source": [ 148 | "import boto3\n", 149 | "import sagemaker\n", 150 | "\n", 151 | "# Create SageMaker clients\n", 152 | "sagemaker_session = sagemaker.Session()\n", 153 | "region = sagemaker_session.boto_region_name\n", 154 | "sagemaker_client = boto3.client(\"sagemaker\", region_name=region)\n", 155 | "sagemaker_runtime_client = boto3.client(\"sagemaker-runtime\", region_name=region)\n", 156 | "\n", 157 | "# Get SageMaker execution role\n", 158 | "try:\n", 159 | "    role = sagemaker.get_execution_role()\n", 160 | "    print(f\"Default SageMaker execution role: {role}\")\n", 161 | "except ValueError as e:\n", 162 | "    print(f\"Error getting default execution role: {e}\")\n", 163 | "    print(\n", 164 | "        \"You may need to specify a role explicitly or create one if not running in a SageMaker environment.\"\n", 165 | "    )\n", 166 | "\n", 167 | "# Identify the model package ARN\n", 168 | "model_package_map = {\n", 169 | "    \"ap-northeast-1\": f\"arn:aws:sagemaker:ap-northeast-1:977537786026:model-package/{ESM3_PACKAGE_ID}\",\n", 170 | "    \"ap-northeast-2\": f\"arn:aws:sagemaker:ap-northeast-2:745090734665:model-package/{ESM3_PACKAGE_ID}\",\n", 171 | "    \"ap-south-1\": f\"arn:aws:sagemaker:ap-south-1:077584701553:model-package/{ESM3_PACKAGE_ID}\",\n", 172 | "    \"ap-southeast-1\": f\"arn:aws:sagemaker:ap-southeast-1:192199979996:model-package/{ESM3_PACKAGE_ID}\",\n", 173 | "    \"ap-southeast-2\": 
f\"arn:aws:sagemaker:ap-southeast-2:666831318237:model-package/{ESM3_PACKAGE_ID}\",\n", 174 | " \"ca-central-1\": f\"arn:aws:sagemaker:ca-central-1:470592106596:model-package/{ESM3_PACKAGE_ID}\",\n", 175 | " \"eu-central-1\": f\"arn:aws:sagemaker:eu-central-1:446921602837:model-package/{ESM3_PACKAGE_ID}\",\n", 176 | " \"eu-north-1\": f\"arn:aws:sagemaker:eu-north-1:136758871317:model-package/{ESM3_PACKAGE_ID}\",\n", 177 | " \"eu-west-1\": f\"arn:aws:sagemaker:eu-west-1:985815980388:model-package/{ESM3_PACKAGE_ID}\",\n", 178 | " \"eu-west-2\": f\"arn:aws:sagemaker:eu-west-2:856760150666:model-package/{ESM3_PACKAGE_ID}\",\n", 179 | " \"eu-west-3\": f\"arn:aws:sagemaker:eu-west-3:843114510376:model-package/{ESM3_PACKAGE_ID}\",\n", 180 | " \"sa-east-1\": f\"arn:aws:sagemaker:sa-east-1:270155090741:model-package/{ESM3_PACKAGE_ID}\",\n", 181 | " \"us-east-1\": f\"arn:aws:sagemaker:us-east-1:865070037744:model-package/{ESM3_PACKAGE_ID}\",\n", 182 | " \"us-east-2\": f\"arn:aws:sagemaker:us-east-2:057799348421:model-package/{ESM3_PACKAGE_ID}\",\n", 183 | " \"us-west-1\": f\"arn:aws:sagemaker:us-west-1:382657785993:model-package/{ESM3_PACKAGE_ID}\",\n", 184 | " \"us-west-2\": f\"arn:aws:sagemaker:us-west-2:594846645681:model-package/{ESM3_PACKAGE_ID}\",\n", 185 | "}\n", 186 | "\n", 187 | "if region not in model_package_map.keys():\n", 188 | " raise Exception(f\"Current boto3 session region {region} is not supported.\")\n", 189 | "\n", 190 | "model_package_arn = model_package_map[region]\n", 191 | "print(f\"Model package ARN: {model_package_arn}\")" 192 | ] 193 | }, 194 | { 195 | "cell_type": "markdown", 196 | "metadata": {}, 197 | "source": [ 198 | "### 3.2. 
Create a model from the subscribed model package" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [ 207 | "model = sagemaker.model.ModelPackage(\n", 208 | " role=role,\n", 209 | " model_package_arn=model_package_arn,\n", 210 | " sagemaker_session=sagemaker_session,\n", 211 | " enable_network_isolation=True,\n", 212 | " predictor_cls=sagemaker.predictor.Predictor,\n", 213 | ")" 214 | ] 215 | }, 216 | { 217 | "cell_type": "markdown", 218 | "metadata": {}, 219 | "source": [ 220 | "### 3.3. Create a real-time endpoint" 221 | ] 222 | }, 223 | { 224 | "cell_type": "markdown", 225 | "metadata": {}, 226 | "source": [ 227 | "Note: This step will take 10-20 minutes." 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": null, 233 | "metadata": { 234 | "tags": [] 235 | }, 236 | "outputs": [], 237 | "source": [ 238 | "predictor = model.deploy(\n", 239 | " initial_instance_count=INITIAL_INSTANCE_COUNT,\n", 240 | " instance_type=INSTANCE_TYPE,\n", 241 | " sagemaker_session=sagemaker_session,\n", 242 | " serializer=sagemaker.base_serializers.JSONSerializer(),\n", 243 | " deserializer=sagemaker.base_deserializers.JSONDeserializer(),\n", 244 | ")\n", 245 | "\n", 246 | "print(f\"Deployed endpoint name is {predictor.endpoint_name}\")\n", 247 | "print(f\"Model name is {MODEL_NAME}\")" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": null, 253 | "metadata": {}, 254 | "outputs": [], 255 | "source": [ 256 | "print(predictor.endpoint_name)" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": null, 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [ 265 | "ENDPOINT_NAME = predictor.endpoint_name" 266 | ] 267 | }, 268 | { 269 | "cell_type": "markdown", 270 | "metadata": {}, 271 | "source": [ 272 | "It will require several minutes to deploy the model to an endpoint" 273 | ] 274 | }, 275 | { 276 | "cell_type": 
"markdown", 277 | "metadata": {}, 278 | "source": [ 279 | "---\n", 280 | "## 4. Test endpoint" 281 | ] 282 | }, 283 | { 284 | "cell_type": "markdown", 285 | "metadata": {}, 286 | "source": [ 287 | "### 4.1. Let's create a simple new protein sequence as a test for our endpoint" 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": null, 293 | "metadata": {}, 294 | "outputs": [], 295 | "source": [ 296 | "from esm.sdk.api import ESMProtein, GenerationConfig\n", 297 | "from esm.sdk.sagemaker import ESM3SageMakerClient\n", 298 | "from src.esmhelpers import format_seq, quick_pdb_plot, quick_aligment_plot\n", 299 | "\n", 300 | "model = ESM3SageMakerClient(endpoint_name=ENDPOINT_NAME, model=MODEL_NAME)" 301 | ] 302 | }, 303 | { 304 | "cell_type": "markdown", 305 | "metadata": {}, 306 | "source": [ 307 | "ESM3 is a generative model, so the most basic task it can accomplish is to create the sequence and structure of a new protein. All ESM3 inference requests must include sequence information, so in this case we will pass a string of \"_\" symbols. This is the \"mask\" token that indicates where we want ESM3 to fill in the blanks.\n", 308 | "\n", 309 | "We start by generating a new protein sequence." 
310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": null, 315 | "metadata": {}, 316 | "outputs": [], 317 | "source": [ 318 | "%%time\n", 319 | "\n", 320 | "n_masked = 64\n", 321 | "\n", 322 | "masked_sequence = \"_\" * n_masked\n", 323 | "\n", 324 | "prompt = ESMProtein(sequence=masked_sequence)\n", 325 | "sequence_generation_config = GenerationConfig(\n", 326 | "    track=\"sequence\",  # We want ESM3 to generate tokens for the sequence track\n", 327 | "    num_steps=prompt.sequence.count(\"_\") // 4,  # We'll use num(mask tokens) // 4 steps to decode the sequence\n", 328 | "    temperature=0.7,  # We'll use a temperature of 0.7 to increase the randomness of the decoding process\n", 329 | ")\n", 330 | "\n", 331 | "# Call the ESM3 inference endpoint\n", 332 | "generated_protein = model.generate(\n", 333 | "    prompt,\n", 334 | "    sequence_generation_config,\n", 335 | ")\n", 336 | "\n", 337 | "# View the generated sequence\n", 338 | "print(f\"Sequence length: {len(generated_protein.sequence)}\")\n", 339 | "print(format_seq(generated_protein.sequence))\n" 340 | ] 341 | }, 342 | { 343 | "cell_type": "markdown", 344 | "metadata": {}, 345 | "source": [ 346 | "## Next steps\n", 347 | "\n", 348 | "Voila! You have a new protein sequence generated using the ESM3 Open model. Head to the next notebook for more basic patterns." 
349 | ] 350 | } 351 | ], 352 | "metadata": { 353 | "instance_type": "ml.t3.medium", 354 | "kernelspec": { 355 | "display_name": "Python 3 (ipykernel)", 356 | "language": "python", 357 | "name": "python3" 358 | }, 359 | "language_info": { 360 | "codemirror_mode": { 361 | "name": "ipython", 362 | "version": 3 363 | }, 364 | "file_extension": ".py", 365 | "mimetype": "text/x-python", 366 | "name": "python", 367 | "nbconvert_exporter": "python", 368 | "pygments_lexer": "ipython3", 369 | "version": "3.10.12" 370 | }, 371 | "vscode": { 372 | "interpreter": { 373 | "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e" 374 | } 375 | } 376 | }, 377 | "nbformat": 4, 378 | "nbformat_minor": 4 379 | } 380 | -------------------------------------------------------------------------------- /images/all.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/esm3-on-amazon-sagemaker/cca42a44d472dfbbac7b6a30ef730d2b219baf0c/images/all.png -------------------------------------------------------------------------------- /images/cot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/esm3-on-amazon-sagemaker/cca42a44d472dfbbac7b6a30ef730d2b219baf0c/images/cot.png -------------------------------------------------------------------------------- /images/esm3-architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/esm3-on-amazon-sagemaker/cca42a44d472dfbbac7b6a30ef730d2b219baf0c/images/esm3-architecture.png -------------------------------------------------------------------------------- /images/model-from-marketplace.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aws-samples/esm3-on-amazon-sagemaker/cca42a44d472dfbbac7b6a30ef730d2b219baf0c/images/model-from-marketplace.png -------------------------------------------------------------------------------- /images/part_seq-seq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/esm3-on-amazon-sagemaker/cca42a44d472dfbbac7b6a30ef730d2b219baf0c/images/part_seq-seq.png -------------------------------------------------------------------------------- /images/seq-func.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/esm3-on-amazon-sagemaker/cca42a44d472dfbbac7b6a30ef730d2b219baf0c/images/seq-func.png -------------------------------------------------------------------------------- /images/seq-str.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/esm3-on-amazon-sagemaker/cca42a44d472dfbbac7b6a30ef730d2b219baf0c/images/seq-str.png -------------------------------------------------------------------------------- /images/seq_str_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/esm3-on-amazon-sagemaker/cca42a44d472dfbbac7b6a30ef730d2b219baf0c/images/seq_str_out.png -------------------------------------------------------------------------------- /images/str-seq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/esm3-on-amazon-sagemaker/cca42a44d472dfbbac7b6a30ef730d2b219baf0c/images/str-seq.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | attrs 2 | biotite==0.41.2 3 | biopython 4 | boto3 5 | brotli 6 | cloudpathlib 
7 | einops 8 | ipykernel 9 | ipython 10 | ipywidgets 11 | matplotlib 12 | msgpack-numpy 13 | pandas 14 | pillow 15 | py3dmol 16 | scikit-learn 17 | torch --extra-index-url https://download.pytorch.org/whl/cpu 18 | transformers 19 | tqdm -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/esm3-on-amazon-sagemaker/cca42a44d472dfbbac7b6a30ef730d2b219baf0c/src/__init__.py -------------------------------------------------------------------------------- /src/esmhelpers.py: -------------------------------------------------------------------------------- 1 | import biotite.sequence as seq 2 | import biotite.sequence.align as align 3 | import biotite.sequence.graphics as graphics 4 | import matplotlib.pyplot as plt 5 | from PIL import ImageColor 6 | import py3Dmol 7 | 8 | 9 | def format_seq( 10 | seq: str, 11 | width: int = 80, 12 | block_size: int = 10, 13 | gap: str = " ", 14 | line_numbers: bool = True, 15 | color_scheme_name: str = "flower", 16 | ) -> str: 17 | """ 18 | Format a biological sequence into pretty blocks with (optional) line numbers. 
19 | """ 20 | 21 | output = "" 22 | output += f"{1:<4}" + " " if line_numbers else "" 23 | if not isinstance(seq, str): 24 | seq = str(seq) 25 | 26 | for i, res in enumerate(seq, start=1): 27 | output += color_amino_acid(res, color_scheme_name) 28 | if i % width == 0: 29 | output += "\n" 30 | output += f"{i+1:<4}" + " " if line_numbers else "" 31 | elif i % block_size == 0: 32 | output += gap 33 | 34 | return output 35 | 36 | 37 | def quick_pdb_plot( 38 | pdb_str: str, width: int = 800, height: int = 600, color: str = "#007FAA" 39 | ) -> None: 40 | """ 41 | Plot a PDB structure using py3Dmol 42 | """ 43 | view = py3Dmol.view(width=width, height=height) 44 | view.addModel(pdb_str, "pdb") 45 | view.setStyle({"cartoon": {"color": color}}) 46 | view.zoomTo() 47 | view.show() 48 | return None 49 | 50 | 51 | def quick_aligment_plot(seq_1: str, seq_2: str) -> None: 52 | seq1 = seq.ProteinSequence(seq_1) 53 | seq2 = seq.ProteinSequence(seq_2) 54 | # Get BLOSUM62 matrix 55 | matrix = align.SubstitutionMatrix.std_protein_matrix() 56 | # Perform pairwise sequence alignment with affine gap penalty 57 | # Terminal gaps are not penalized 58 | alignments = align.align_optimal( 59 | seq1, seq2, matrix, gap_penalty=(-10, -1), terminal_penalty=False 60 | ) 61 | 62 | print("Alignment Score: ", alignments[0].score) 63 | print("Sequence identity:", align.get_sequence_identity(alignments[0])) 64 | 65 | # Draw the first and only alignment 66 | # The color intensity indicates the similarity 67 | fig = plt.figure() 68 | ax = fig.add_subplot(111) 69 | graphics.plot_alignment_similarity_based( 70 | ax, 71 | alignments[0], 72 | matrix=matrix, 73 | labels=["Reference", "Prediction"], 74 | show_numbers=False, 75 | show_line_position=True, 76 | color=(0.0, 127 / 255, 170 / 255), 77 | ) 78 | fig.tight_layout() 79 | plt.show() 80 | return None 81 | 82 | 83 | def color_text(text: str, hex: str): 84 | """ 85 | Color text 86 | """ 87 | rgb = ImageColor.getrgb(hex) 88 | color_code = 
f"\033[38;2;{rgb[0]};{rgb[1]};{rgb[2]}m" 89 | return color_code + text + "\033[0m" 90 | 91 | 92 | def color_amino_acid(res, color_scheme_name="flower"): 93 | colors = graphics.get_color_scheme(color_scheme_name, seq.ProteinSequence.alphabet) 94 | color_map = dict(zip(seq.ProteinSequence.alphabet, colors)) 95 | color_map.update( 96 | { 97 | "B": "#FFFFFF", 98 | "U": "#FFFFFF", 99 | "Z": "#FFFFFF", 100 | "O": "#FFFFFF", 101 | ".": "#FFFFFF", 102 | "-": "#FFFFFF", 103 | "|": "#FFFFFF", 104 | "_": "#000000", 105 | "✔": "#FF9900", 106 | } 107 | ) 108 | return color_text(res, color_map[res]) 109 | 110 | 111 | def color_protein_sequence(protein_sequence: str, color_scheme_name="flower"): 112 | return "".join( 113 | [color_amino_acid(res, color_scheme_name) for res in protein_sequence] 114 | ) 115 | 116 | 117 | def parse_annotations_by_label(annotations) -> dict: 118 | """ 119 | Generate a dictionary of annotation labels and their corresponding sequence positions 120 | """ 121 | 122 | parsed_annotations = {} 123 | 124 | for annotation in annotations: 125 | annotation_idx = list(range(annotation.start, annotation.end + 1)) 126 | if annotation.label in parsed_annotations: 127 | parsed_annotations[annotation.label].extend(annotation_idx) 128 | parsed_annotations[annotation.label] = list( 129 | set(parsed_annotations[annotation.label]) 130 | ) 131 | else: 132 | parsed_annotations[annotation.label] = annotation_idx 133 | return parsed_annotations 134 | 135 | 136 | def parse_annotations_by_index(annotations, sequence_length) -> list: 137 | """ 138 | Generate a list of sequence positions and their corresponding annotation labels 139 | """ 140 | 141 | parsed_annotations = [] 142 | annotations_by_label = parse_annotations_by_label(annotations) 143 | 144 | # Positions are 1-indexed and inclusive of the final residue 145 | for idx in range(1, sequence_length + 1): 146 | idx_annotations = [] 147 | for k, v in annotations_by_label.items(): 148 | if idx in v: 149 | idx_annotations.append(k) 150 | parsed_annotations.append(idx_annotations) 151 | 152 | 
return parsed_annotations 152 | 153 | def format_annotations(parsed_annotations, length, filter = None): 154 | output = {} 155 | parsed_annotations = {k: parsed_annotations[k] for k in filter} if filter else parsed_annotations 156 | for k, v in parsed_annotations.items(): 157 | tmp = ["_"] * length 158 | for i in v: 159 | tmp[i - 1] = "✔" 160 | output[k] = "".join(tmp) 161 | return output --------------------------------------------------------------------------------