├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── backend ├── lambda │ ├── app.py │ └── requirements.txt └── template.yaml ├── cfn └── initial_template.yaml ├── code ├── inference.py └── requirements.txt ├── frontend ├── package-lock.json ├── package.json ├── public │ ├── favicon.ico │ ├── index.html │ ├── logo192.png │ ├── logo512.png │ ├── manifest.json │ └── robots.txt └── src │ ├── App.css │ ├── App.js │ ├── config │ └── index.js │ ├── images │ └── header.jpg │ ├── index.css │ ├── index.js │ ├── logo.svg │ └── serviceWorker.js ├── frontendExample.png ├── inference.py ├── nlu-based-item-search.ipynb ├── query.png ├── ref.png └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/node,python,react,sam,windows,zsh,serverless,amplify,visualstudiocode 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=node,python,react,sam,windows,zsh,serverless,amplify,visualstudiocode 3 | 4 | ### Amplify ### 5 | # AWS Amplify 6 | amplify/\#current-cloud-backend 7 | amplify/.config/local-* 8 | amplify/mock-data 9 | amplify/backend/amplify-meta.json 10 | amplify/backend/awscloudformation 11 | build/ 12 | dist/ 13 | node_modules/ 14 | aws-exports.js 15 | awsconfiguration.json 16 | amplifyconfiguration.json 17 | amplify-build-config.json 18 | amplify-gradle-config.json 19 | amplifyxc.config 20 | 21 | ### Node ### 22 | # Logs 23 | logs 24 | *.log 25 | npm-debug.log* 26 | yarn-debug.log* 27 | yarn-error.log* 28 | lerna-debug.log* 29 | .pnpm-debug.log* 30 | 31 | # Diagnostic reports (https://nodejs.org/api/report.html) 32 | report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json 33 | 34 | # Runtime data 35 | pids 36 | *.pid 37 | *.seed 38 | *.pid.lock 39 | 40 | # Directory for instrumented libs generated by jscoverage/JSCover 41 | lib-cov 42 | 43 | # Coverage directory used by tools like istanbul 44 | coverage 45 | *.lcov 46 | 47 | # nyc test coverage 48 | .nyc_output 49 | 50 | # Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files) 51 | .grunt 52 | 53 | # Bower dependency directory (https://bower.io/) 54 | bower_components 55 | 56 | # node-waf configuration 57 | .lock-wscript 58 | 59 | # Compiled binary addons (https://nodejs.org/api/addons.html) 60 | build/Release 61 | 62 | # Dependency directories 63 | jspm_packages/ 64 | 65 | # Snowpack dependency directory (https://snowpack.dev/) 66 | web_modules/ 67 | 68 | # TypeScript cache 69 | *.tsbuildinfo 70 | 71 | # Optional npm cache directory 72 | .npm 73 | 74 | # Optional eslint cache 75 | .eslintcache 76 | 77 | # Optional stylelint cache 78 | .stylelintcache 79 | 80 | # Microbundle cache 81 | .rpt2_cache/ 82 | .rts2_cache_cjs/ 83 | .rts2_cache_es/ 84 | .rts2_cache_umd/ 85 | 86 | # Optional REPL history 87 | .node_repl_history 88 | 89 | # Output of 'npm pack' 90 | *.tgz 91 | 92 | # Yarn Integrity file 93 | .yarn-integrity 94 | 95 | # dotenv environment variable files 96 | .env 97 | .env.development.local 98 | .env.test.local 99 | .env.production.local 100 | .env.local 101 | 102 | # parcel-bundler cache (https://parceljs.org/) 103 | .cache 104 | .parcel-cache 105 | 106 | # Next.js build output 107 | .next 108 | out 109 | 110 | # Nuxt.js build / generate output 111 | .nuxt 112 | dist 113 | 114 | # Gatsby files 115 | .cache/ 116 | # Comment in the public line in if your project uses Gatsby and not Next.js 117 | # https://nextjs.org/blog/next-9-1#public-directory-support 118 | 
# public 119 | 120 | # vuepress build output 121 | .vuepress/dist 122 | 123 | # vuepress v2.x temp and cache directory 124 | .temp 125 | 126 | # Docusaurus cache and generated files 127 | .docusaurus 128 | 129 | # Serverless directories 130 | .serverless/ 131 | 132 | # FuseBox cache 133 | .fusebox/ 134 | 135 | # DynamoDB Local files 136 | .dynamodb/ 137 | 138 | # TernJS port file 139 | .tern-port 140 | 141 | # Stores VSCode versions used for testing VSCode extensions 142 | .vscode-test 143 | 144 | # yarn v2 145 | .yarn/cache 146 | .yarn/unplugged 147 | .yarn/build-state.yml 148 | .yarn/install-state.gz 149 | .pnp.* 150 | 151 | ### Node Patch ### 152 | # Serverless Webpack directories 153 | .webpack/ 154 | 155 | # Optional stylelint cache 156 | 157 | # SvelteKit build / generate output 158 | .svelte-kit 159 | 160 | ### Python ### 161 | # Byte-compiled / optimized / DLL files 162 | __pycache__/ 163 | *.py[cod] 164 | *$py.class 165 | 166 | # C extensions 167 | *.so 168 | 169 | # Distribution / packaging 170 | .Python 171 | develop-eggs/ 172 | downloads/ 173 | eggs/ 174 | .eggs/ 175 | lib/ 176 | lib64/ 177 | parts/ 178 | sdist/ 179 | var/ 180 | wheels/ 181 | share/python-wheels/ 182 | *.egg-info/ 183 | .installed.cfg 184 | *.egg 185 | MANIFEST 186 | 187 | # PyInstaller 188 | # Usually these files are written by a python script from a template 189 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 190 | *.manifest 191 | *.spec 192 | 193 | # Installer logs 194 | pip-log.txt 195 | pip-delete-this-directory.txt 196 | 197 | # Unit test / coverage reports 198 | htmlcov/ 199 | .tox/ 200 | .nox/ 201 | .coverage 202 | .coverage.* 203 | nosetests.xml 204 | coverage.xml 205 | *.cover 206 | *.py,cover 207 | .hypothesis/ 208 | .pytest_cache/ 209 | cover/ 210 | 211 | # Translations 212 | *.mo 213 | *.pot 214 | 215 | # Django stuff: 216 | local_settings.py 217 | db.sqlite3 218 | db.sqlite3-journal 219 | 220 | # Flask stuff: 221 | instance/ 222 | .webassets-cache 223 | 224 | # Scrapy stuff: 225 | .scrapy 226 | 227 | # Sphinx documentation 228 | docs/_build/ 229 | 230 | # PyBuilder 231 | .pybuilder/ 232 | target/ 233 | 234 | # Jupyter Notebook 235 | .ipynb_checkpoints 236 | 237 | # IPython 238 | profile_default/ 239 | ipython_config.py 240 | 241 | # pyenv 242 | # For a library or package, you might want to ignore these files since the code is 243 | # intended to run in multiple environments; otherwise, check them in: 244 | # .python-version 245 | 246 | # pipenv 247 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 248 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 249 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 250 | # install all needed dependencies. 251 | #Pipfile.lock 252 | 253 | # poetry 254 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 255 | # This is especially recommended for binary packages to ensure reproducibility, and is more 256 | # commonly ignored for libraries. 257 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 258 | #poetry.lock 259 | 260 | # pdm 261 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 262 | #pdm.lock 263 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 264 | # in version control. 
265 | # https://pdm.fming.dev/#use-with-ide 266 | .pdm.toml 267 | 268 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 269 | __pypackages__/ 270 | 271 | # Celery stuff 272 | celerybeat-schedule 273 | celerybeat.pid 274 | 275 | # SageMath parsed files 276 | *.sage.py 277 | 278 | # Environments 279 | .venv 280 | env/ 281 | venv/ 282 | ENV/ 283 | env.bak/ 284 | venv.bak/ 285 | 286 | # Spyder project settings 287 | .spyderproject 288 | .spyproject 289 | 290 | # Rope project settings 291 | .ropeproject 292 | 293 | # mkdocs documentation 294 | /site 295 | 296 | # mypy 297 | .mypy_cache/ 298 | .dmypy.json 299 | dmypy.json 300 | 301 | # Pyre type checker 302 | .pyre/ 303 | 304 | # pytype static type analyzer 305 | .pytype/ 306 | 307 | # Cython debug symbols 308 | cython_debug/ 309 | 310 | # PyCharm 311 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 312 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 313 | # and can be added to the global gitignore or merged into this file. For a more nuclear 314 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 315 | #.idea/ 316 | 317 | ### react ### 318 | .DS_* 319 | **/*.backup.* 320 | **/*.back.* 321 | 322 | node_modules 323 | 324 | *.sublime* 325 | 326 | psd 327 | thumb 328 | sketch 329 | 330 | ### SAM ### 331 | # Ignore build directories for the AWS Serverless Application Model (SAM) 332 | # Info: https://aws.amazon.com/serverless/sam/ 333 | # Docs: https://docs.aws.amazon.com/serverless-application-model/latest/developerguide/serverless-sam-reference.html 334 | 335 | **/.aws-sam 336 | 337 | ### Serverless ### 338 | # Ignore build directory 339 | .serverless 340 | 341 | ### VisualStudioCode ### 342 | .vscode/* 343 | !.vscode/settings.json 344 | !.vscode/tasks.json 345 | !.vscode/launch.json 346 | !.vscode/extensions.json 347 | !.vscode/*.code-snippets 348 | 349 | # Local History for Visual Studio Code 350 | .history/ 351 | 352 | # Built Visual Studio Code Extensions 353 | *.vsix 354 | 355 | ### VisualStudioCode Patch ### 356 | # Ignore all local history of files 357 | .history 358 | .ionide 359 | 360 | # Support for Project snippet scope 361 | .vscode/*.code-snippets 362 | 363 | # Ignore code-workspaces 364 | *.code-workspace 365 | 366 | ### Windows ### 367 | # Windows thumbnail cache files 368 | Thumbs.db 369 | Thumbs.db:encryptable 370 | ehthumbs.db 371 | ehthumbs_vista.db 372 | 373 | # Dump file 374 | *.stackdump 375 | 376 | # Folder config file 377 | [Dd]esktop.ini 378 | 379 | # Recycle Bin used on file shares 380 | $RECYCLE.BIN/ 381 | 382 | # Windows Installer files 383 | *.cab 384 | *.msi 385 | *.msix 386 | *.msm 387 | *.msp 388 | 389 | # Windows shortcuts 390 | *.lnk 391 | 392 | ### Zsh ### 393 | # Zsh compiled script + zrecompile backup 394 | *.zwc 395 | *.zwc.old 396 | 397 | # Zsh completion-optimization dumpfile 398 | *zcompdump* 399 | 400 | # Zsh zcalc history 401 | .zcalc_history 402 | 403 | # A popular plugin manager's files 404 | ._zinit 405 | .zinit_lstupd 406 | 407 | # zdharma/zshelldoc tool's files 408 | zsdoc/data 409 | 410 | # robbyrussell/oh-my-zsh/plugins/per-directory-history plugin's files 411 | # (when set-up to store the history in the local directory) 412 | .directory_history 413 | 414 | # MichaelAquilina/zsh-autoswitch-virtualenv plugin's files 415 | # (for Zsh plugins using Python) 416 | 417 | # Zunit tests' output 418 | /tests/_output/* 419 | 
!/tests/_output/.gitkeep 420 | 421 | # End of https://www.toptal.com/developers/gitignore/api/node,python,react,sam,windows,zsh,serverless,amplify,visualstudiocode 422 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *master* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 
45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | 61 | We may ask you to sign a [Contributor License Agreement (CLA)](http://en.wikipedia.org/wiki/Contributor_License_Agreement) for larger changes. 62 |
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | this software and associated documentation files (the "Software"), to deal in 5 | the Software without restriction, including without limitation the rights to 6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | the Software, and to permit persons to whom the Software is furnished to do so. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 10 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 11 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 12 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 13 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 14 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 15 | 16 |
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Amazon SageMaker NLU search 2 | 3 | This repository guides users through creating an NLU-based product search using Amazon SageMaker and Amazon Elasticsearch Service. 4 | 5 | 6 | ## How does it work? 7 | 8 | We use the pre-trained BERT model (distilbert-base-nli-stsb-mean-tokens) from sentence-transformers to generate a fixed-length, 768-dimension sentence embedding for each description in the Multi-modal Corpus of Fashion Images from *__feidegger__*, a *__zalandoresearch__* dataset. Those feature vectors are then imported into an Amazon ES KNN index as the reference set. 9 | 10 | ![diagram](../master/ref.png) 11 | 12 | 13 | When we submit a new query text/sentence, the BERT model hosted on Amazon SageMaker computes its embedding, and we query the Amazon ES KNN index for similar texts/sentences, each of which corresponds to an actual product image stored in Amazon S3 (a code sketch of this flow follows at the end of this README). 14 | 15 | ![diagram](../master/query.png) 16 | 17 | ## License 18 | 19 | This library is licensed under the MIT-0 License. See the LICENSE file.
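The query flow described under "How does it work?" can be sketched end to end in a few lines. This is a minimal illustration rather than a drop-in script: it assumes an already-deployed SageMaker endpoint and Amazon ES domain (the region, endpoint name, and host below are placeholders), and it reuses the index and field names (`idx_zalando`, `zalando_nlu_vector`) that `backend/lambda/app.py` queries.

```python
import json

import boto3
from elasticsearch import Elasticsearch, RequestsHttpConnection
from requests_aws4auth import AWS4Auth

REGION = "us-east-1"                                  # placeholder
SM_ENDPOINT = "nlu-bert-endpoint"                     # placeholder SageMaker endpoint name
ES_HOST = "my-domain.us-east-1.es.amazonaws.com"      # placeholder Amazon ES host

# 1) Embed the query text with the SageMaker-hosted BERT model
sm_runtime = boto3.client("sagemaker-runtime", region_name=REGION)
response = sm_runtime.invoke_endpoint(
    EndpointName=SM_ENDPOINT,
    ContentType="text/plain",
    Body="a yellow summery dress with flowers",
)
features = json.loads(response["Body"].read())        # 768-dimension vector

# 2) Find the k nearest product descriptions in the Amazon ES KNN index
credentials = boto3.Session().get_credentials()
awsauth = AWS4Auth(credentials.access_key, credentials.secret_key,
                   REGION, "es", session_token=credentials.token)
es = Elasticsearch(hosts=[{"host": ES_HOST, "port": 443}],
                   http_auth=awsauth, use_ssl=True, verify_certs=True,
                   connection_class=RequestsHttpConnection)
res = es.search(index="idx_zalando", body={
    "size": 3,
    "query": {"knn": {"zalando_nlu_vector": {"vector": features, "k": 3}}},
})
for hit in res["hits"]["hits"]:
    print(hit["_source"]["image"])                    # S3 URI of the matching product image
```

The Lambda function in `backend/lambda/app.py` wraps these same two calls behind the `/postText` API route and additionally converts the returned S3 URIs into presigned URLs for the frontend.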
20 | -------------------------------------------------------------------------------- /backend/lambda/app.py: -------------------------------------------------------------------------------- 1 | import json 2 | from os import environ 3 | 4 | import boto3 5 | from urllib.parse import urlparse 6 | 7 | from elasticsearch import Elasticsearch, RequestsHttpConnection 8 | from requests_aws4auth import AWS4Auth 9 | 10 | # Global variables that are reused 11 | sm_runtime_client = boto3.client('sagemaker-runtime') 12 | s3_client = boto3.client('s3') 13 | 14 | 15 | def get_features(sm_runtime_client, sagemaker_endpoint, payload): 16 | response = sm_runtime_client.invoke_endpoint( 17 | EndpointName=sagemaker_endpoint, 18 | ContentType='text/plain', 19 | Body=payload) 20 | response_body = json.loads((response['Body'].read())) 21 | features = response_body 22 | 23 | return features 24 | 25 | 26 | def get_neighbors(features, es, k_neighbors=3): 27 | idx_name = 'idx_zalando' 28 | res = es.search( 29 | request_timeout=30, index=idx_name, 30 | body={ 31 | 'size': k_neighbors, 32 | 'query': {'knn': {'zalando_nlu_vector': {'vector': features, 'k': k_neighbors}}}} 33 | ) 34 | s3_uris = [res['hits']['hits'][x]['_source']['image'] for x in range(k_neighbors)] 35 | 36 | return s3_uris 37 | 38 | 39 | def es_match_query(payload, es, k=3): 40 | idx_name = 'idx_zalando' 41 | search_body = { 42 | "_source": { 43 | "excludes": ["zalando_nlu_vector"] 44 | }, 45 | "highlight": { 46 | "fields": { 47 | "description": {} 48 | } 49 | }, 50 | "query": { 51 | "match": { 52 | "description": { 53 | "query": payload 54 | } 55 | } 56 | } 57 | } 58 | 59 | search_response = es.search(request_timeout=30, index=idx_name, 60 | body=search_body)['hits']['hits'][:k] 61 | 62 | response = [{'image': x['_source']['image'], 'description': x['highlight']['description']} for x in search_response] 63 | 64 | return response 65 | 66 | 67 | def generate_presigned_urls(s3_uris): 68 | presigned_urls = [s3_client.generate_presigned_url( 69 | 'get_object', 70 | Params={ 71 | 'Bucket': urlparse(x).netloc, 72 | 'Key': urlparse(x).path.lstrip('/')}, 73 | ExpiresIn=300 74 | ) for x in s3_uris] 75 | 76 | return presigned_urls 77 | 78 | 79 | def lambda_handler(event, context): 80 | 81 | # elasticsearch variables 82 | service = 'es' 83 | region = environ['AWS_REGION'] 84 | elasticsearch_endpoint = environ['ES_ENDPOINT'] 85 | 86 | session = boto3.session.Session() 87 | credentials = session.get_credentials() 88 | awsauth = AWS4Auth( 89 | credentials.access_key, 90 | credentials.secret_key, 91 | region, 92 | service, 93 | session_token=credentials.token 94 | ) 95 | 96 | es = Elasticsearch( 97 | hosts=[{'host': elasticsearch_endpoint, 'port': 443}], 98 | http_auth=awsauth, 99 | use_ssl=True, 100 | verify_certs=True, 101 | connection_class=RequestsHttpConnection 102 | ) 103 | 104 | # sagemaker variables 105 | sagemaker_endpoint = environ['SM_ENDPOINT'] 106 | 107 | api_payload = json.loads(event['body']) 108 | k = api_payload['k'] 109 | payload = api_payload['searchString'] 110 | 111 | if event['path'] == '/postText': 112 | features = get_features(sm_runtime_client, sagemaker_endpoint, payload) 113 | s3_uris_neighbors = get_neighbors(features, es, k_neighbors=k) 114 | s3_presigned_urls = generate_presigned_urls(s3_uris_neighbors) 115 | return { 116 | "statusCode": 200, 117 | "headers": { 118 | "Access-Control-Allow-Origin": "*", 119 | "Access-Control-Allow-Headers": "*", 120 | "Access-Control-Allow-Methods": "*" 121 | }, 122 | "body": json.dumps({ 123 | 
"images": s3_presigned_urls, 124 | }), 125 | } 126 | else: 127 | search = es_match_query(payload, es, k) 128 | 129 | for i in range(len(search)): 130 | search[i]['presigned_url'] = generate_presigned_urls([search[i]['image']])[0] 131 | search[i]['description'] = " ".join(search[i]['description']) 132 | search[i]['description'] = search[i]['description'].replace("",'') 133 | return { 134 | "statusCode": 200, 135 | "headers": { 136 | "Access-Control-Allow-Origin": "*", 137 | "Access-Control-Allow-Headers": "*", 138 | "Access-Control-Allow-Methods": "*" 139 | }, 140 | "body": json.dumps(search), 141 | } 142 | -------------------------------------------------------------------------------- /backend/lambda/requirements.txt: -------------------------------------------------------------------------------- 1 | requests 2 | boto3 3 | elasticsearch 4 | requests-aws4auth -------------------------------------------------------------------------------- /backend/template.yaml: -------------------------------------------------------------------------------- 1 | AWSTemplateFormatVersion: '2010-09-09' 2 | Transform: AWS::Serverless-2016-10-31 3 | Description: 'backend 4 | 5 | Sample SAM Template for backend 6 | 7 | ' 8 | Parameters: 9 | BucketName: 10 | Type: String 11 | DomainName: 12 | Type: String 13 | ElasticSearchURL: 14 | Type: String 15 | SagemakerEndpoint: 16 | Type: String 17 | Globals: 18 | Function: 19 | Timeout: 60 20 | MemorySize: 512 21 | Api: 22 | Cors: 23 | AllowMethods: '''*''' 24 | AllowHeaders: '''*''' 25 | AllowOrigin: '''*''' 26 | Resources: 27 | PostGetSimilarTextFunction: 28 | Type: AWS::Serverless::Function 29 | Properties: 30 | CodeUri: s3://aws-ml-blog/artifacts/nlu-search/lambda.zip 31 | Handler: app.lambda_handler 32 | Runtime: python3.7 33 | Environment: 34 | Variables: 35 | ES_ENDPOINT: 36 | Ref: ElasticSearchURL 37 | SM_ENDPOINT: 38 | Ref: SagemakerEndpoint 39 | Policies: 40 | - Version: '2012-10-17' 41 | Statement: 42 | - Sid: AllowSagemakerInvokeEndpoint 43 | Effect: Allow 44 | Action: 45 | - sagemaker:InvokeEndpoint 46 | Resource: 47 | - Fn::Sub: arn:aws:sagemaker:${AWS::Region}:${AWS::AccountId}:endpoint/${SagemakerEndpoint} 48 | - Version: '2012-10-17' 49 | Statement: 50 | - Sid: AllowESS 51 | Effect: Allow 52 | Action: 53 | - es:* 54 | Resource: 55 | - Fn::Sub: arn:aws:es:${AWS::Region}:${AWS::AccountId}:domain/${DomainName}/* 56 | - S3ReadPolicy: 57 | BucketName: 58 | Ref: BucketName 59 | Events: 60 | PostText: 61 | Type: Api 62 | Properties: 63 | Path: /postText 64 | Method: post 65 | PostMatch: 66 | Type: Api 67 | Properties: 68 | Method: post 69 | Path: /postMatch 70 | Outputs: 71 | TextSimilarityApi: 72 | Description: API Gateway endpoint URL for Prod stage for GetSimilarText function 73 | Value: 74 | Fn::Sub: https://${ServerlessRestApi}.execute-api.${AWS::Region}.amazonaws.com/Prod/ 75 | PostGetSimilarTextFunctionArn: 76 | Description: GetSimilarText Lambda Function ARN 77 | Value: 78 | Fn::GetAtt: 79 | - PostGetSimilarTextFunction 80 | - Arn 81 | PostGetSimilarTextLambdaIamRole: 82 | Description: Implicit IAM Role created for GetSimilarText function 83 | Value: 84 | Fn::GetAtt: 85 | - PostGetSimilarTextFunction 86 | - Arn 87 | -------------------------------------------------------------------------------- /cfn/initial_template.yaml: -------------------------------------------------------------------------------- 1 | AWSTemplateFormatVersion: 2010-09-09 2 | Description: Template to start the nlu search blog 3 | 4 | Resources: 5 | CodeRepository: 6 | Type: 
AWS::SageMaker::CodeRepository 7 | Properties: 8 | GitConfig: 9 | RepositoryUrl: https://github.com/aws-samples/amazon-sagemaker-nlu-search 10 | 11 | NotebookInstance: 12 | Type: AWS::SageMaker::NotebookInstance 13 | Properties: 14 | InstanceType: ml.t3.medium 15 | RoleArn: !GetAtt Role.Arn 16 | DefaultCodeRepository: !GetAtt CodeRepository.CodeRepositoryName 17 | 18 | Role: 19 | Type: AWS::IAM::Role 20 | Properties: 21 | Policies: 22 | - PolicyName: CustomNotebookAccess 23 | PolicyDocument: 24 | Version: 2012-10-17 25 | Statement: 26 | - Effect: Allow 27 | Action: 28 | - "es:ESHttp*" 29 | Resource: 30 | - !Sub arn:aws:es:${AWS::Region}:${AWS::AccountId}:domain/* 31 | - Effect: Allow 32 | Action: 33 | - "s3:GetObject" 34 | - "s3:PutObject" 35 | - "s3:DeleteObject" 36 | - "s3:PutObjectAcl" 37 | Resource: 38 | - !Sub arn:aws:s3:::${s3BucketTraining}/* 39 | - !Sub arn:aws:s3:::${s3BucketHosting}/* 40 | ManagedPolicyArns: 41 | - arn:aws:iam::aws:policy/AmazonSageMakerFullAccess 42 | - arn:aws:iam::aws:policy/AWSCloudFormationReadOnlyAccess 43 | - arn:aws:iam::aws:policy/TranslateReadOnly 44 | AssumeRolePolicyDocument: 45 | Version: 2012-10-17 46 | Statement: 47 | - Effect: Allow 48 | Principal: 49 | Service: 50 | - sagemaker.amazonaws.com 51 | Action: 52 | - 'sts:AssumeRole' 53 | 54 | s3BucketTraining: 55 | Type: AWS::S3::Bucket 56 | Properties: 57 | BucketEncryption: 58 | ServerSideEncryptionConfiguration: 59 | - ServerSideEncryptionByDefault: 60 | SSEAlgorithm: "AES256" 61 | VersioningConfiguration: 62 | Status: Enabled 63 | 64 | s3BucketHosting: 65 | Type: AWS::S3::Bucket 66 | Properties: 67 | BucketEncryption: 68 | ServerSideEncryptionConfiguration: 69 | - ServerSideEncryptionByDefault: 70 | SSEAlgorithm: "AES256" 71 | VersioningConfiguration: 72 | Status: Enabled 73 | WebsiteConfiguration: 74 | IndexDocument: index.html 75 | ErrorDocument: error.html 76 | 77 | Domain: 78 | Type: AWS::Elasticsearch::Domain 79 | Properties: 80 | AccessPolicies: 81 | Version: 2012-10-17 82 | Statement: 83 | - Effect: Allow 84 | Principal: 85 | AWS: !Ref AWS::AccountId 86 | Action: 'es:*' 87 | Resource: !Sub arn:aws:es:${AWS::Region}:${AWS::AccountId}:domain/*/* 88 | ElasticsearchVersion: 7.7 89 | ElasticsearchClusterConfig: 90 | InstanceType: "t2.small.elasticsearch" 91 | EBSOptions: 92 | EBSEnabled: True 93 | VolumeSize: 10 94 | VolumeType: gp2 95 | 96 | 97 | Outputs: 98 | esHostName: 99 | Description: Elasticsearch hostname 100 | Value: !GetAtt Domain.DomainEndpoint 101 | 102 | esDomainName: 103 | Description: Elasticsearch domain name 104 | Value: !Ref Domain 105 | 106 | s3BucketTraining: 107 | Description: S3 bucket name for training 108 | Value: !Ref s3BucketTraining 109 | 110 | s3BucketHostingBucketName: 111 | Description: S3 bucket name for frontend hosting 112 | Value: !Ref s3BucketHosting 113 | 114 | S3BucketSecureURL: 115 | Value: !Join 116 | - '' 117 | - - 'https://' 118 | - !GetAtt 119 | - s3BucketHosting 120 | - DomainName 121 | Description: Name of S3 bucket to hold website content 122 | 123 | SageMakerNotebookURL: 124 | Description: SageMaker Notebook Instance 125 | Value: !Join 126 | - '' 127 | - - 'https://console.aws.amazon.com/sagemaker/home?region=us-east-1#/notebook-instances/openNotebook/' 128 | - !GetAtt NotebookInstance.NotebookInstanceName 129 | - '?view=classic' 130 | -------------------------------------------------------------------------------- /code/inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import 
logging 3 | import sagemaker_containers 4 | import requests 5 | 6 | import os 7 | import json 8 | import io 9 | import time 10 | import torch 11 | from transformers import AutoTokenizer, AutoModel 12 | # from sentence_transformers import models, losses, SentenceTransformer 13 | 14 | logger = logging.getLogger(__name__) 15 | logger.setLevel(logging.DEBUG) 16 | 17 | #Mean Pooling - Take attention mask into account for correct averaging 18 | def mean_pooling(model_output, attention_mask): 19 | token_embeddings = model_output[0] #First element of model_output contains all token embeddings 20 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 21 | sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) 22 | sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9) 23 | return sum_embeddings / sum_mask 24 | 25 | def embed_tformer(model, tokenizer, sentences): 26 | encoded_input = tokenizer(sentences, padding=True, truncation=True, max_length=256, return_tensors='pt') 27 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 28 | encoded_input.to(device) 29 | 30 | #Compute token embeddings 31 | with torch.no_grad(): 32 | model_output = model(**encoded_input) 33 | 34 | sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask']) 35 | return sentence_embeddings 36 | 37 | def model_fn(model_dir): 38 | logger.info('model_fn') 39 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 40 | logger.info(model_dir) 41 | tokenizer = AutoTokenizer.from_pretrained(model_dir) 42 | nlp_model = AutoModel.from_pretrained(model_dir) 43 | nlp_model.to(device) 44 | model = {'model':nlp_model, 'tokenizer':tokenizer} 45 | 46 | return model 47 | 48 | # Deserialize the Invoke request body into an object we can perform prediction on 49 | def input_fn(serialized_input_data, content_type='text/plain'): 50 | logger.info('Deserializing the input data.') 51 | try: 52 | data = [serialized_input_data.decode('utf-8')] 53 | return data 54 | except: 55 | raise Exception('Requested unsupported ContentType in content_type: {}'.format(content_type)) 56 | 57 | # Perform prediction on the deserialized object, with the loaded model 58 | def predict_fn(input_object, model): 59 | logger.info("Calling model") 60 | start_time = time.time() 61 | sentence_embeddings = embed_tformer(model['model'], model['tokenizer'], input_object) 62 | print("--- Inference time: %s seconds ---" % (time.time() - start_time)) 63 | response = sentence_embeddings[0].tolist() 64 | return response 65 | 66 | # Serialize the prediction result into the desired response content type 67 | def output_fn(prediction, accept): 68 | logger.info('Serializing the generated output.') 69 | if accept == 'application/json': 70 | output = json.dumps(prediction) 71 | return output 72 | raise Exception('Requested unsupported ContentType in Accept: {}'.format(content_type)) 73 | -------------------------------------------------------------------------------- /code/requirements.txt: -------------------------------------------------------------------------------- 1 | sentence-transformers 2 | sagemaker-containers 3 | numpy>=1.17 4 | -------------------------------------------------------------------------------- /frontend/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "frontend", 3 | "version": "0.1.0", 4 | "private": true, 5 | "dependencies": { 6 | "@material-ui/core": "^4.12.4", 7 | "@material-ui/icons": 
"^4.11.3", 8 | "aws-amplify": "^4.3.34", 9 | "react": "^17.0.2", 10 | "react-dom": "^17.0.2", 11 | "typeface-roboto": "^1.1.13" 12 | }, 13 | "devDependencies": { 14 | "react-scripts": "^5.0.1" 15 | }, 16 | "scripts": { 17 | "start": "react-scripts start", 18 | "build": "react-scripts build", 19 | "eject": "react-scripts eject" 20 | }, 21 | "eslintConfig": { 22 | "extends": "react-app" 23 | }, 24 | "browserslist": { 25 | "production": [ 26 | ">0.2%", 27 | "not dead", 28 | "not op_mini all" 29 | ], 30 | "development": [ 31 | "last 1 chrome version", 32 | "last 1 firefox version", 33 | "last 1 safari version" 34 | ] 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /frontend/public/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-nlu-search/f0c32c17d2a58e8fee0b91096891601002a1bed7/frontend/public/favicon.ico -------------------------------------------------------------------------------- /frontend/public/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 9 | 10 | 11 | 12 | 16 | 17 | 21 | 22 | 31 | AWS Natural Language Search - AWS Samples 32 | 33 | 34 | 35 |
36 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /frontend/public/logo192.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-nlu-search/f0c32c17d2a58e8fee0b91096891601002a1bed7/frontend/public/logo192.png -------------------------------------------------------------------------------- /frontend/public/logo512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-nlu-search/f0c32c17d2a58e8fee0b91096891601002a1bed7/frontend/public/logo512.png -------------------------------------------------------------------------------- /frontend/public/manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | "short_name": "React App", 3 | "name": "Create React App Sample", 4 | "icons": [ 5 | { 6 | "src": "favicon.ico", 7 | "sizes": "64x64 32x32 24x24 16x16", 8 | "type": "image/x-icon" 9 | }, 10 | { 11 | "src": "logo192.png", 12 | "type": "image/png", 13 | "sizes": "192x192" 14 | }, 15 | { 16 | "src": "logo512.png", 17 | "type": "image/png", 18 | "sizes": "512x512" 19 | } 20 | ], 21 | "start_url": ".", 22 | "display": "standalone", 23 | "theme_color": "#000000", 24 | "background_color": "#ffffff" 25 | } 26 | -------------------------------------------------------------------------------- /frontend/public/robots.txt: -------------------------------------------------------------------------------- 1 | # https://www.robotstxt.org/robotstxt.html 2 | User-agent: * 3 | Disallow: 4 | -------------------------------------------------------------------------------- /frontend/src/App.css: -------------------------------------------------------------------------------- 1 | .App { 2 | text-align: center; 3 | } 4 | 5 | .App-logo { 6 | height: 40vmin; 7 | pointer-events: none; 8 | } 9 | 10 | @media (prefers-reduced-motion: no-preference) { 11 | .App-logo { 12 | animation: App-logo-spin infinite 20s linear; 13 | } 14 | } 15 | 16 | .App-header { 17 | background-color: #282c34; 18 | min-height: 100vh; 19 | display: flex; 20 | flex-direction: column; 21 | align-items: center; 22 | justify-content: center; 23 | font-size: calc(10px + 2vmin); 24 | color: white; 25 | } 26 | 27 | .App-link { 28 | color: #61dafb; 29 | } 30 | 31 | @keyframes App-logo-spin { 32 | from { 33 | transform: rotate(0deg); 34 | } 35 | to { 36 | transform: rotate(360deg); 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /frontend/src/App.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import './App.css'; 3 | import 'typeface-roboto'; 4 | import { Button, Input, FormControl, Select, MenuItem } from '@material-ui/core'; 5 | import { withStyles, lighten } from "@material-ui/core/styles"; 6 | import InputAdornment from '@material-ui/core/InputAdornment'; 7 | import SearchIcon from '@material-ui/icons/Search'; 8 | import Paper from '@material-ui/core/Paper'; 9 | import Grid from '@material-ui/core/Grid'; 10 | import GridList from '@material-ui/core/GridList'; 11 | import GridListTile from '@material-ui/core/GridListTile'; 12 | import GridListTileBar from '@material-ui/core/GridListTileBar'; 13 | import LinearProgress from '@material-ui/core/LinearProgress'; 14 | import Typography from '@material-ui/core/Typography'; 15 | import Amplify, { API } from 
"aws-amplify"; 16 | import '@aws-amplify/ui/dist/style.css'; 17 | import Config from './config'; 18 | 19 | 20 | Amplify.configure({ 21 | API: { 22 | endpoints: [ 23 | { 24 | name: "NluSearch", 25 | endpoint: Config.apiEndpoint 26 | } 27 | ] 28 | } 29 | }); 30 | 31 | const styles = theme => ({ 32 | root: { 33 | flexGrow: 1, 34 | }, 35 | paper: { 36 | padding: theme.spacing(2), 37 | textAlign: 'center', 38 | height: "100%", 39 | color: theme.palette.text.secondary 40 | }, 41 | paper2: { 42 | padding: theme.spacing(2), 43 | textAlign: 'center', 44 | color: theme.palette.text.secondary 45 | }, 46 | em: { 47 | backgroundColor: "#f18973" 48 | } 49 | }); 50 | 51 | const BorderLinearProgress = withStyles({ 52 | root: { 53 | height: 10, 54 | backgroundColor: lighten('#ff6c5c', 0.5), 55 | }, 56 | bar: { 57 | borderRadius: 20, 58 | backgroundColor: '#ff6c5c', 59 | }, 60 | })(LinearProgress); 61 | 62 | // const classes = useStyles(); 63 | 64 | class App extends React.Component { 65 | 66 | constructor(props) { 67 | super(props); 68 | this.state = { 69 | pictures: [], 70 | results: [], 71 | completed:0, 72 | k:3 73 | }; 74 | this.handleSearchSubmit = this.handleSearchSubmit.bind(this); 75 | this.handleFormChange = this.handleFormChange.bind(this); 76 | this.handleKChange = this.handleKChange.bind(this); 77 | } 78 | 79 | handleSearchSubmit(event) { 80 | // function for when a use submits a URL 81 | // if the URL bar is empty, it will remove similar photos from state 82 | console.log(this.state.searchText); 83 | if (this.state.searchText === undefined || this.state.searchText === "") { 84 | console.log("Empty Text field"); 85 | this.setState({pictures: [], completed: 0, results: []}); 86 | } else { 87 | const myInit = { 88 | body: {"searchString": this.state.searchText, "k": this.state.k} 89 | }; 90 | 91 | this.setState({completed:66}); 92 | 93 | API.post('NluSearch', '/postText', myInit) 94 | .then(response => { 95 | this.setState({pictures: response.images.map(function(elem) { 96 | let picture = {}; 97 | picture.img = elem; 98 | picture.cols = 1; 99 | return picture; 100 | }) 101 | }); 102 | console.log(this.state.pictures); 103 | }) 104 | .catch(error => { 105 | console.log(error); 106 | }); 107 | 108 | this.setState({completed:85}); 109 | 110 | console.log(this.state.results); 111 | API.post('NluSearch', '/postMatch', myInit) 112 | .then(response => { 113 | // this.setState({results: []}); 114 | this.setState({results: response.map(function(elem) { 115 | let result = {}; 116 | result.img = elem.presigned_url; 117 | result.cols = 1; 118 | result.description = elem.description; 119 | return result; 120 | }) 121 | }); 122 | console.log(this.state.results); 123 | this.setState({completed:100}); 124 | }) 125 | .catch(error => { 126 | console.log(error); 127 | }); 128 | 129 | 130 | }; 131 | event.preventDefault(); 132 | } 133 | 134 | handleFormChange(event) { 135 | this.setState({searchText: event.target.value}); 136 | } 137 | 138 | handleKChange(event) { 139 | this.setState({k: event.target.value}); 140 | } 141 | 142 | render() { 143 | const { classes } = this.props; 144 | const createMarkup = htmlString => ({ __html: htmlString }); 145 | 146 | return ( 147 |
148 | 149 | 150 | {/* 151 | Header 152 | */} 153 | 154 | 155 | AWS Natural Language Search 156 | 157 | 158 | 159 | 160 | Step 1: Select the number of similar products (K neighbors): 161 |

162 | 163 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | Step 2: Enter a natural language search query about dresses. Try entering "I want a summery dress that's flowery and yellow": 181 |

182 |

183 | 194 | 195 | 196 | } 197 | /> 198 | 204 | 205 |
206 |
207 | 208 | 209 | 210 | 211 | Step 3: Results! 212 |

213 | 218 |

219 | 220 |

221 | 222 | KNN Search 223 | 224 | {this.state.pictures.map((tile) => ( 225 | 226 | Similar products... 227 | 228 | ))} 229 | 230 | 231 |

232 | 233 | Elasticsearch Match Search 234 | 235 | {this.state.results.map((tile) => ( 236 | 237 | Similar products... 238 | 239 |

240 | 241 | 242 | ))} 243 | 244 | 245 | 246 | 247 |

248 | );} 249 | } 250 | 251 | export default withStyles(styles, { withTheme: true })(App); 252 | 253 | -------------------------------------------------------------------------------- /frontend/src/config/index.js: -------------------------------------------------------------------------------- 1 | import Config from './config.json' 2 | 3 | export default { 4 | apiEndpoint: Config.apiEndpoint 5 | }; -------------------------------------------------------------------------------- /frontend/src/images/header.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-nlu-search/f0c32c17d2a58e8fee0b91096891601002a1bed7/frontend/src/images/header.jpg -------------------------------------------------------------------------------- /frontend/src/index.css: -------------------------------------------------------------------------------- 1 | body { 2 | margin: 0; 3 | font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen', 4 | 'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue', 5 | sans-serif; 6 | -webkit-font-smoothing: antialiased; 7 | -moz-osx-font-smoothing: grayscale; 8 | } 9 | 10 | code { 11 | font-family: source-code-pro, Menlo, Monaco, Consolas, 'Courier New', 12 | monospace; 13 | } 14 | -------------------------------------------------------------------------------- /frontend/src/index.js: -------------------------------------------------------------------------------- 1 | import React from 'react'; 2 | import ReactDOM from 'react-dom'; 3 | import './index.css'; 4 | import App from './App'; 5 | import * as serviceWorker from './serviceWorker'; 6 | 7 | ReactDOM.render( 8 | 9 | 10 | , 11 | document.getElementById('root') 12 | ); 13 | 14 | // If you want your app to work offline and load faster, you can change 15 | // unregister() to register() below. Note this comes with some pitfalls. 16 | // Learn more about service workers: https://bit.ly/CRA-PWA 17 | serviceWorker.unregister(); 18 | -------------------------------------------------------------------------------- /frontend/src/logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /frontend/src/serviceWorker.js: -------------------------------------------------------------------------------- 1 | // This optional code is used to register a service worker. 2 | // register() is not called by default. 3 | 4 | // This lets the app load faster on subsequent visits in production, and gives 5 | // it offline capabilities. However, it also means that developers (and users) 6 | // will only see deployed updates on subsequent visits to a page, after all the 7 | // existing tabs open on the page have been closed, since previously cached 8 | // resources are updated in the background. 9 | 10 | // To learn more about the benefits of this model and instructions on how to 11 | // opt-in, read https://bit.ly/CRA-PWA 12 | 13 | const isLocalhost = Boolean( 14 | window.location.hostname === 'localhost' || 15 | // [::1] is the IPv6 localhost address. 16 | window.location.hostname === '[::1]' || 17 | // 127.0.0.0/8 are considered localhost for IPv4. 
18 | window.location.hostname.match( 19 | /^127(?:\.(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)){3}$/ 20 | ) 21 | ); 22 | 23 | export function register(config) { 24 | if (process.env.NODE_ENV === 'production' && 'serviceWorker' in navigator) { 25 | // The URL constructor is available in all browsers that support SW. 26 | const publicUrl = new URL(process.env.PUBLIC_URL, window.location.href); 27 | if (publicUrl.origin !== window.location.origin) { 28 | // Our service worker won't work if PUBLIC_URL is on a different origin 29 | // from what our page is served on. This might happen if a CDN is used to 30 | // serve assets; see https://github.com/facebook/create-react-app/issues/2374 31 | return; 32 | } 33 | 34 | window.addEventListener('load', () => { 35 | const swUrl = `${process.env.PUBLIC_URL}/service-worker.js`; 36 | 37 | if (isLocalhost) { 38 | // This is running on localhost. Let's check if a service worker still exists or not. 39 | checkValidServiceWorker(swUrl, config); 40 | 41 | // Add some additional logging to localhost, pointing developers to the 42 | // service worker/PWA documentation. 43 | navigator.serviceWorker.ready.then(() => { 44 | console.log( 45 | 'This web app is being served cache-first by a service ' + 46 | 'worker. To learn more, visit https://bit.ly/CRA-PWA' 47 | ); 48 | }); 49 | } else { 50 | // Is not localhost. Just register service worker 51 | registerValidSW(swUrl, config); 52 | } 53 | }); 54 | } 55 | } 56 | 57 | function registerValidSW(swUrl, config) { 58 | navigator.serviceWorker 59 | .register(swUrl) 60 | .then(registration => { 61 | registration.onupdatefound = () => { 62 | const installingWorker = registration.installing; 63 | if (installingWorker == null) { 64 | return; 65 | } 66 | installingWorker.onstatechange = () => { 67 | if (installingWorker.state === 'installed') { 68 | if (navigator.serviceWorker.controller) { 69 | // At this point, the updated precached content has been fetched, 70 | // but the previous service worker will still serve the older 71 | // content until all client tabs are closed. 72 | console.log( 73 | 'New content is available and will be used when all ' + 74 | 'tabs for this page are closed. See https://bit.ly/CRA-PWA.' 75 | ); 76 | 77 | // Execute callback 78 | if (config && config.onUpdate) { 79 | config.onUpdate(registration); 80 | } 81 | } else { 82 | // At this point, everything has been precached. 83 | // It's the perfect time to display a 84 | // "Content is cached for offline use." message. 85 | console.log('Content is cached for offline use.'); 86 | 87 | // Execute callback 88 | if (config && config.onSuccess) { 89 | config.onSuccess(registration); 90 | } 91 | } 92 | } 93 | }; 94 | }; 95 | }) 96 | .catch(error => { 97 | console.error('Error during service worker registration:', error); 98 | }); 99 | } 100 | 101 | function checkValidServiceWorker(swUrl, config) { 102 | // Check if the service worker can be found. If it can't reload the page. 103 | fetch(swUrl, { 104 | headers: { 'Service-Worker': 'script' }, 105 | }) 106 | .then(response => { 107 | // Ensure service worker exists, and that we really are getting a JS file. 108 | const contentType = response.headers.get('content-type'); 109 | if ( 110 | response.status === 404 || 111 | (contentType != null && contentType.indexOf('javascript') === -1) 112 | ) { 113 | // No service worker found. Probably a different app. Reload the page. 
114 | navigator.serviceWorker.ready.then(registration => { 115 | registration.unregister().then(() => { 116 | window.location.reload(); 117 | }); 118 | }); 119 | } else { 120 | // Service worker found. Proceed as normal. 121 | registerValidSW(swUrl, config); 122 | } 123 | }) 124 | .catch(() => { 125 | console.log( 126 | 'No internet connection found. App is running in offline mode.' 127 | ); 128 | }); 129 | } 130 | 131 | export function unregister() { 132 | if ('serviceWorker' in navigator) { 133 | navigator.serviceWorker.ready 134 | .then(registration => { 135 | registration.unregister(); 136 | }) 137 | .catch(error => { 138 | console.error(error.message); 139 | }); 140 | } 141 | } 142 | -------------------------------------------------------------------------------- /frontendExample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-nlu-search/f0c32c17d2a58e8fee0b91096891601002a1bed7/frontendExample.png -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import sagemaker_containers 4 | import requests 5 | 6 | import os 7 | import json 8 | import io 9 | import time 10 | import torch 11 | from transformers import AutoTokenizer, AutoModel 12 | # from sentence_transformers import models, losses, SentenceTransformer 13 | 14 | logger = logging.getLogger(__name__) 15 | logger.setLevel(logging.DEBUG) 16 | 17 | #Mean Pooling - Take attention mask into account for correct averaging 18 | def mean_pooling(model_output, attention_mask): 19 | token_embeddings = model_output[0] #First element of model_output contains all token embeddings 20 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 21 | sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) 22 | sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9) 23 | return sum_embeddings / sum_mask 24 | 25 | def embed_tformer(model, tokenizer, sentences): 26 | encoded_input = tokenizer(sentences, padding=True, truncation=True, max_length=256, return_tensors='pt') 27 | 28 | #Compute token embeddings 29 | with torch.no_grad(): 30 | model_output = model(**encoded_input) 31 | 32 | sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask']) 33 | return sentence_embeddings 34 | 35 | def model_fn(model_dir): 36 | logger.info('model_fn') 37 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 38 | logger.info(model_dir) 39 | tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/bert-base-nli-mean-tokens") 40 | nlp_model = AutoModel.from_pretrained("sentence-transformers/bert-base-nli-mean-tokens") 41 | nlp_model.to(device) 42 | model = {'model':nlp_model, 'tokenizer':tokenizer} 43 | 44 | # model = SentenceTransformer(model_dir + '/transformer/') 45 | # logger.info(model) 46 | return model 47 | 48 | # Deserialize the Invoke request body into an object we can perform prediction on 49 | def input_fn(serialized_input_data, content_type='text/plain'): 50 | logger.info('Deserializing the input data.') 51 | try: 52 | data = [serialized_input_data.decode('utf-8')] 53 | return data 54 | except: 55 | raise Exception('Requested unsupported ContentType in content_type: {}'.format(content_type)) 56 | 57 | # Perform prediction on the deserialized object, with the loaded model 58 | def 
predict_fn(input_object, model): 59 | logger.info("Calling model") 60 | start_time = time.time() 61 | sentence_embeddings = embed_tformer(model['model'], model['tokenizer'], input_object) 62 | print("--- Inference time: %s seconds ---" % (time.time() - start_time)) 63 | response = sentence_embeddings[0].tolist() 64 | return response 65 | 66 | # Serialize the prediction result into the desired response content type 67 | def output_fn(prediction, accept): 68 | logger.info('Serializing the generated output.') 69 | if accept == 'application/json': 70 | output = json.dumps(prediction) 71 | return output 72 | raise Exception('Requested unsupported ContentType in Accept: {}'.format(content_type)) 73 | -------------------------------------------------------------------------------- /nlu-based-item-search.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# NLU based item search\n", 8 | "_**Using a pretrained BERT and Elasticsearch KNN to search textually similar items**_\n", 9 | "\n", 10 | "---\n", 11 | "\n", 12 | "---\n", 13 | "\n", 14 | "## Contents\n", 15 | "\n", 16 | "\n", 17 | "1. [Background](#Background)\n", 18 | "1. [Setup](#Setup)\n", 19 | "1. [Lauange Translation](#Trnslate)\n", 20 | "1. [SageMaker Model Hosting](#Hosting-Model)\n", 21 | "1. [Build a KNN Index in Elasticsearch](#ES-KNN)\n", 22 | "1. [Evaluate Index Search Results](#Searching-with-ES-k-NN)\n", 23 | "1. [Extensions](#Extensions)\n", 24 | "\n", 25 | "## Background\n", 26 | "\n", 27 | "In this notebook, we'll build the core components of a textually similar item search. Often people don't know what exactly they are looking for and in that case they just type an item description and hope it will retrieve similar items.\n", 28 | "\n", 29 | "One of the core components of textually similar items search is a fixed length sentence/word embedding i.e. a “feature vector” that corresponds to that text. The reference word/sentence embedding typically are generated offline and must be stored so they can be efficiently searched. Generating word/sentence embeddings can be achieved using pretrained language models such as BERT (Bidirectional Encoder Representations from Transformers). In our use case we are using a pretrained BERT model from [HuggingFace Transformers](https://huggingface.co/sentence-transformers/distilbert-base-nli-stsb-mean-tokens).\n", 30 | "\n", 31 | "To enable efficient searches for textually similar items, we'll use Amazon SageMaker to generate fixed length sentence embeddings i.e “feature vectors” and use the K-Nearest Neighbor (KNN) algorithm in Amazon Elasticsearch service. KNN for Amazon Elasticsearch Service7.7 lets you search for points in vector space and find the \"nearest neighbors\" for those points by cosine similarity (Default is Euclidean distance). Use cases include recommendations (for example, an \"other songs you might like\" feature in a music application), image recognition, and fraud detection.\n", 32 | "\n", 33 | "Here are the steps we'll follow to build textually similar items: After some initial setup, we'll host the pretrained BERT language model in SageMaker PyTorch model server. Then generate feature vectors for Multi-modal Corpus of Fashion Images from *__feidegger__*, a *__zalandoresearch__* dataset. Those feature vectors will be imported in Amazon Elasticsearch KNN Index. Next, we'll explore some sample text queries, and visualize the results." 
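To make the "Build a KNN Index in Elasticsearch" step concrete, here is a minimal sketch of the index mapping and a single document import. It assumes an `es` Elasticsearch client already authenticated against the domain and a `bucket`/`vector` value produced earlier; the index name, vector field, and 768 dimension match the rest of this project, while the auxiliary field types and example values are illustrative.

```python
# Sketch only: assumes `es` is an authenticated Elasticsearch client for the Amazon ES domain
knn_index = {
    "settings": {"index.knn": True},  # enable the KNN plugin for this index
    "mappings": {
        "properties": {
            "zalando_nlu_vector": {"type": "knn_vector", "dimension": 768},
            "image": {"type": "keyword"},      # S3 URI of the product image
            "description": {"type": "text"},   # translated English description
        }
    },
}
es.indices.create(index="idx_zalando", body=knn_index, ignore=400)

# Index one example document; `vector` is the embedding returned by the SageMaker endpoint
es.index(index="idx_zalando", body={
    "zalando_nlu_vector": vector,
    "image": f"s3://{bucket}/data/feidegger/fashion/example.jpg",
    "description": "A yellow summer dress with a floral print",
})
```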
34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": { 40 | "scrolled": true 41 | }, 42 | "outputs": [], 43 | "source": [ 44 | "#Install tqdm to have progress bar\n", 45 | "!pip install tqdm\n", 46 | "\n", 47 | "#install necessary pkg to make connection with elasticsearch domain\n", 48 | "!pip install \"elasticsearch<7.14.0\"\n", 49 | "!pip install requests\n", 50 | "!pip install requests-aws4auth\n", 51 | "!pip install \"sagemaker>=2.0.0<3.0.0\"" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "metadata": {}, 57 | "source": [ 58 | "If you run this notebook in SageMaker Studio, you need to make sure `ipywidgets` is installed and restart the kernel, so please uncomment the code in the next cell, and run it.\n" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "# %%capture\n", 68 | "# import IPython\n", 69 | "# import sys\n", 70 | "\n", 71 | "# !{sys.executable} -m pip install ipywidgets\n", 72 | "# IPython.Application.instance().kernel.do_shutdown(True) # has to restart kernel so changes are used" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "import boto3\n", 82 | "import re\n", 83 | "import time\n", 84 | "import sagemaker\n", 85 | "from sagemaker import get_execution_role\n", 86 | "\n", 87 | "role = get_execution_role()\n", 88 | "\n", 89 | "s3_resource = boto3.resource(\"s3\")\n", 90 | "s3 = boto3.client('s3')\n", 91 | "\n", 92 | "print(f'SageMaker SDK Version: {sagemaker.__version__}')" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "cfn = boto3.client('cloudformation')\n", 102 | "\n", 103 | "def get_cfn_outputs(stackname):\n", 104 | " outputs = {}\n", 105 | " for output in cfn.describe_stacks(StackName=stackname)['Stacks'][0]['Outputs']:\n", 106 | " outputs[output['OutputKey']] = output['OutputValue']\n", 107 | " return outputs\n", 108 | "\n", 109 | "## Setup variables to use for the rest of the demo\n", 110 | "cloudformation_stack_name = \"nlu-search\"\n", 111 | "\n", 112 | "outputs = get_cfn_outputs(cloudformation_stack_name)\n", 113 | "\n", 114 | "bucket = outputs['s3BucketTraining']\n", 115 | "es_host = outputs['esHostName']\n", 116 | "\n", 117 | "outputs" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "metadata": {}, 123 | "source": [ 124 | "### Downloading Zalando Research data\n", 125 | "\n", 126 | "The dataset itself consists of 8732 high-resolution images, each depicting a dress from the available on the Zalando shop against a white-background. Each of the images has five textual annotations in German, each of which has been generated by a separate user. \n", 127 | "\n", 128 | "**Downloading Zalando Research data**: Data originally from here: https://github.com/zalandoresearch/feidegger \n", 129 | "\n", 130 | " **Citation:**
\n", 131 | " *@inproceedings{lefakis2018feidegger,*
\n", 132 | " *title={FEIDEGGER: A Multi-modal Corpus of Fashion Images and Descriptions in German},*
\n", 133 | " *author={Lefakis, Leonidas and Akbik, Alan and Vollgraf, Roland},*
\n", 134 | " *booktitle = {{LREC} 2018, 11th Language Resources and Evaluation Conference},*
\n", 135 | " *year = {2018}*
\n", 136 | " *}*" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "metadata": { 143 | "scrolled": true 144 | }, 145 | "outputs": [], 146 | "source": [ 147 | "## Data Preparation\n", 148 | "\n", 149 | "import os \n", 150 | "import shutil\n", 151 | "import json\n", 152 | "import tqdm\n", 153 | "import urllib.request\n", 154 | "from tqdm import notebook\n", 155 | "from multiprocessing import cpu_count\n", 156 | "from tqdm.contrib.concurrent import process_map\n", 157 | "\n", 158 | "images_path = 'data/feidegger/fashion'\n", 159 | "filename = 'metadata.json'\n", 160 | "\n", 161 | "my_bucket = s3_resource.Bucket(bucket)\n", 162 | "\n", 163 | "os.makedirs(images_path, exist_ok=True)\n", 164 | "\n", 165 | "def download_metadata(url):\n", 166 | " if not os.path.exists(filename):\n", 167 | " urllib.request.urlretrieve(url, filename)\n", 168 | " \n", 169 | "#download metadata.json to local notebook\n", 170 | "download_metadata('https://raw.githubusercontent.com/zalandoresearch/feidegger/master/data/FEIDEGGER_release_1.1.json')\n", 171 | "\n", 172 | "def generate_image_list(filename):\n", 173 | " metadata = open(filename,'r')\n", 174 | " data = json.load(metadata)\n", 175 | " url_lst = []\n", 176 | " for i in range(len(data)):\n", 177 | " url_lst.append(data[i]['url'])\n", 178 | " return url_lst\n", 179 | "\n", 180 | "\n", 181 | "def download_image(url):\n", 182 | " urllib.request.urlretrieve(url, images_path + '/' + url.split(\"/\")[-1])\n", 183 | " \n", 184 | "#generate image list \n", 185 | "url_lst = generate_image_list(filename) \n", 186 | "\n", 187 | "workers = 2 * cpu_count()\n", 188 | "\n", 189 | "#downloading images to local disk\n", 190 | "process_map(download_image, url_lst, max_workers=workers)\n" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "metadata": { 197 | "scrolled": true 198 | }, 199 | "outputs": [], 200 | "source": [ 201 | "# Uploading dataset to S3\n", 202 | "\n", 203 | "files_to_upload = []\n", 204 | "dirName = 'data'\n", 205 | "for path, subdirs, files in os.walk('./' + dirName):\n", 206 | " path = path.replace(\"\\\\\",\"/\")\n", 207 | " directory_name = path.replace('./',\"\")\n", 208 | " for file in files:\n", 209 | " files_to_upload.append({\n", 210 | " \"filename\": os.path.join(path, file),\n", 211 | " \"key\": directory_name+'/'+file\n", 212 | " })\n", 213 | " \n", 214 | "\n", 215 | "def upload_to_s3(file):\n", 216 | " my_bucket.upload_file(file['filename'], file['key'])\n", 217 | " \n", 218 | "#uploading images to s3\n", 219 | "process_map(upload_to_s3, files_to_upload, max_workers=workers)" 220 | ] 221 | }, 222 | { 223 | "cell_type": "markdown", 224 | "metadata": {}, 225 | "source": [ 226 | "## Language Translation\n", 227 | "\n", 228 | "This dataset has product descriptions in German, so we will use Amazon Translate to translate each German sentence into English." 
229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": null, 234 | "metadata": {}, 235 | "outputs": [], 236 | "source": [ 237 | "with open(filename) as json_file:\n", 238 | " data = json.load(json_file)" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": null, 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [ 247 | "#Define translator function\n", 248 | "def translate_txt(data):\n", 249 | " results = {}\n", 250 | " results['filename'] = f's3://{bucket}/data/feidegger/fashion/' + data['url'].split(\"/\")[-1]\n", 251 | " results['descriptions'] = []\n", 252 | " translate = boto3.client(service_name='translate', use_ssl=True)\n", 253 | " for i in data['descriptions']:\n", 254 | " result = translate.translate_text(Text=str(i), \n", 255 | " SourceLanguageCode=\"de\", TargetLanguageCode=\"en\")\n", 256 | " results['descriptions'].append(result['TranslatedText'])\n", 257 | " return results" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": null, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "# we are using real-time translation, which will take around ~35 min. \n", 267 | "workers = 1 * cpu_count()\n", 268 | "\n", 269 | "#translating the product descriptions\n", 270 | "results = process_map(translate_txt, data, max_workers=workers)" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": null, 276 | "metadata": {}, 277 | "outputs": [], 278 | "source": [ 279 | "# Saving the translated text in JSON format in case you need it later\n", 280 | "with open('zalando-translated-data.json', 'w', encoding='utf-8') as f:\n", 281 | " json.dump(results, f, ensure_ascii=False, indent=4)" 282 | ] 283 | }, 284 | { 285 | "cell_type": "markdown", 286 | "metadata": {}, 287 | "source": [ 288 | "## SageMaker Model Hosting\n", 289 | "\n", 290 | "In this section we will host the pretrained BERT model in the SageMaker PyTorch model server to generate 768-dimensional fixed length sentence embeddings from [sentence-transformers](https://github.com/UKPLab/sentence-transformers) using [HuggingFace Transformers](https://huggingface.co/sentence-transformers/distilbert-base-nli-stsb-mean-tokens). \n", 291 | "\n", 292 | "**Citation:**
\n", 293 | " @inproceedings{reimers-2019-sentence-bert,
\n", 294 | " title = \"Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks\",
\n", 295 | " author = \"Reimers, Nils and Gurevych, Iryna\",
\n", 296 | " booktitle = \"Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing\",
\n", 297 | " month = \"11\",
\n", 298 | " year = \"2019\",
\n", 299 | " publisher = \"Association for Computational Linguistics\",
\n", 300 | " url = \"http://arxiv.org/abs/1908.10084\",
\n", 301 | "}" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": null, 307 | "metadata": { 308 | "scrolled": true 309 | }, 310 | "outputs": [], 311 | "source": [ 312 | "!pip install transformers[torch]" 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": null, 318 | "metadata": { 319 | "scrolled": true 320 | }, 321 | "outputs": [], 322 | "source": [ 323 | "#Save the model to disk; we will host it on SageMaker\n", 324 | "from transformers import AutoTokenizer, AutoModel\n", 325 | "saved_model_dir = 'transformer'\n", 326 | "os.makedirs(saved_model_dir, exist_ok=True)\n", 327 | "\n", 328 | "tokenizer = AutoTokenizer.from_pretrained(\"sentence-transformers/distilbert-base-nli-stsb-mean-tokens\")\n", 329 | "model = AutoModel.from_pretrained(\"sentence-transformers/distilbert-base-nli-stsb-mean-tokens\") \n", 330 | "\n", 331 | "tokenizer.save_pretrained(saved_model_dir)\n", 332 | "model.save_pretrained(saved_model_dir)" 333 | ] 334 | }, 335 | { 336 | "cell_type": "code", 337 | "execution_count": null, 338 | "metadata": {}, 339 | "outputs": [], 340 | "source": [ 341 | "#Create a SageMaker session and get the execution role for model hosting\n", 342 | "sagemaker_session = sagemaker.Session()\n", 343 | "role = sagemaker.get_execution_role()" 344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "execution_count": null, 349 | "metadata": {}, 350 | "outputs": [], 351 | "source": [ 352 | "#zip the model in tar.gz format\n", 353 | "!cd transformer && tar czvf ../model.tar.gz *" 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": null, 359 | "metadata": {}, 360 | "outputs": [], 361 | "source": [ 362 | "#Upload the model to S3\n", 363 | "\n", 364 | "inputs = sagemaker_session.upload_data(path='model.tar.gz', key_prefix='sentence-transformers-model')\n", 365 | "inputs" 366 | ] 367 | }, 368 | { 369 | "cell_type": "markdown", 370 | "metadata": {}, 371 | "source": [ 372 | "First we need to create a PyTorchModel object. The deploy() method on the model object creates an endpoint which serves prediction requests in real time. If the instance_type is set to a SageMaker instance type (e.g. ml.m5.large) then the model will be deployed on SageMaker. If the instance_type parameter is set to *__local__* then it will be deployed locally as a Docker container and will be ready for testing locally.\n", 373 | "\n", 374 | "We also need to create a [`Predictor`](https://sagemaker.readthedocs.io/en/stable/api/inference/predictors.html) subclass that accepts TEXT as input and outputs JSON. The default behaviour is to accept a numpy array." 
375 | ] 376 | }, 377 | { 378 | "cell_type": "code", 379 | "execution_count": null, 380 | "metadata": {}, 381 | "outputs": [], 382 | "source": [ 383 | "from sagemaker.pytorch import PyTorch, PyTorchModel\n", 384 | "from sagemaker.predictor import Predictor\n", 385 | "from sagemaker import get_execution_role\n", 386 | "\n", 387 | "class StringPredictor(Predictor):\n", 388 | " def __init__(self, endpoint_name, sagemaker_session):\n", 389 | " super(StringPredictor, self).__init__(endpoint_name, sagemaker_session, content_type='text/plain')\n", 390 | " " 391 | ] 392 | }, 393 | { 394 | "cell_type": "code", 395 | "execution_count": null, 396 | "metadata": {}, 397 | "outputs": [], 398 | "source": [ 399 | "pytorch_model = PyTorchModel(model_data = inputs, \n", 400 | " role=role, \n", 401 | " entry_point ='inference.py',\n", 402 | " source_dir = './code',\n", 403 | " py_version = 'py3', \n", 404 | " framework_version = '1.7.1',\n", 405 | " predictor_cls=StringPredictor)\n", 406 | "\n", 407 | "predictor = pytorch_model.deploy(instance_type='ml.g4dn.xlarge', \n", 408 | " initial_instance_count=1, \n", 409 | " endpoint_name = f'nlu-search-model-{int(time.time())}')\n" 410 | ] 411 | }, 412 | { 413 | "cell_type": "markdown", 414 | "metadata": {}, 415 | "source": [ 416 | "The pretrained BERT model from HuggingFace Transformers generates a 768-dimensional embedding for the given text. We will quickly validate this in the next cell." 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": null, 422 | "metadata": {}, 423 | "outputs": [], 424 | "source": [ 425 | "# Doing a quick test to make sure the model is generating the embeddings\n", 426 | "import json\n", 427 | "payload = 'a yellow dress that comes about to the knees'\n", 428 | "features = predictor.predict(payload)\n", 429 | "embedding = json.loads(features)\n", 430 | "\n", 431 | "embedding" 432 | ] 433 | }, 434 | { 435 | "cell_type": "markdown", 436 | "metadata": {}, 437 | "source": [ 438 | "## Build a KNN Index in Elasticsearch\n", 439 | "\n", 440 | "KNN for Amazon Elasticsearch Service lets you search for points in a vector space and find the \"nearest neighbors\" for those points by cosine similarity. Use cases include recommendations (for example, an \"other songs you might like\" feature in a music application), image recognition, and fraud detection.\n", 441 | "\n", 442 | "KNN cosine similarity requires Elasticsearch 7.7 or later (the default is Euclidean distance). Full documentation for the Elasticsearch feature, including descriptions of settings and statistics, as well as background information about the k-nearest neighbors algorithm, is available in the Open Distro for Elasticsearch documentation.\n", 443 | "\n", 444 | "In this step we'll take all the translated product descriptions of the *__zalandoresearch__* dataset and import their embeddings into the Elasticsearch 7.7 domain." 445 | ] 446 | }, 447 | { 448 | "cell_type": "code", 449 | "execution_count": null, 450 | "metadata": {}, 451 | "outputs": [], 452 | "source": [ 453 | "# setting up the Elasticsearch connection\n", 454 | "from elasticsearch import Elasticsearch, RequestsHttpConnection\n", 455 | "from requests_aws4auth import AWS4Auth\n", 456 | "region = 'us-east-1' # e.g. 
us-east-1\n", 457 | "service = 'es'\n", 458 | "credentials = boto3.Session().get_credentials()\n", 459 | "awsauth = AWS4Auth(credentials.access_key, credentials.secret_key, region, service, session_token=credentials.token)\n", 460 | "\n", 461 | "es = Elasticsearch(\n", 462 | " hosts = [{'host': es_host, 'port': 443}],\n", 463 | " http_auth = awsauth,\n", 464 | " use_ssl = True,\n", 465 | " verify_certs = True,\n", 466 | " connection_class = RequestsHttpConnection\n", 467 | ")" 468 | ] 469 | }, 470 | { 471 | "cell_type": "code", 472 | "execution_count": null, 473 | "metadata": {}, 474 | "outputs": [], 475 | "source": [ 476 | "#KNN index mapping\n", 477 | "knn_index = {\n", 478 | " \"settings\": {\n", 479 | " \"index.knn\": True,\n", 480 | " \"index.knn.space_type\": \"cosinesimil\",\n", 481 | " \"analysis\": {\n", 482 | " \"analyzer\": {\n", 483 | " \"default\": {\n", 484 | " \"type\": \"standard\",\n", 485 | " \"stopwords\": \"_english_\"\n", 486 | " }\n", 487 | " }\n", 488 | " }\n", 489 | " },\n", 490 | " \"mappings\": {\n", 491 | " \"properties\": {\n", 492 | " \"zalando_nlu_vector\": { \n", 493 | " \"type\": \"knn_vector\",\n", 494 | " \"dimension\": 768\n", 495 | " } \n", 496 | " }\n", 497 | " }\n", 498 | "}" 499 | ] 500 | }, 501 | { 502 | "cell_type": "code", 503 | "execution_count": null, 504 | "metadata": {}, 505 | "outputs": [], 506 | "source": [ 507 | "#Creating the Elasticsearch index\n", 508 | "es.indices.create(index=\"idx_zalando\",body=knn_index,ignore=400) # change the index name if needed " 509 | ] 510 | }, 511 | { 512 | "cell_type": "code", 513 | "execution_count": null, 514 | "metadata": {}, 515 | "outputs": [], 516 | "source": [ 517 | "# If you need to load zalando-translated-data.json, uncomment and execute the following code\n", 518 | "# with open('zalando-translated-data.json') as json_file:\n", 519 | "# results = json.load(json_file)" 520 | ] 521 | }, 522 | { 523 | "cell_type": "code", 524 | "execution_count": null, 525 | "metadata": {}, 526 | "outputs": [], 527 | "source": [ 528 | "results[0]['descriptions']" 529 | ] 530 | }, 531 | { 532 | "cell_type": "code", 533 | "execution_count": null, 534 | "metadata": {}, 535 | "outputs": [], 536 | "source": [ 537 | "# For each product, we are concatenating all the \n", 538 | "# product descriptions into a single sentence,\n", 539 | "# so that we will have one embedding for each product\n", 540 | "\n", 541 | "def concat_desc(results):\n", 542 | " obj = {\n", 543 | " 'filename': results['filename'],\n", 544 | " }\n", 545 | " obj['descriptions'] = ' '.join(results['descriptions'])\n", 546 | " return obj\n", 547 | "\n", 548 | "concat_results = map(concat_desc, results)\n", 549 | "concat_results = list(concat_results)\n", 550 | "concat_results[0]" 551 | ] 552 | }, 553 | { 554 | "cell_type": "code", 555 | "execution_count": null, 556 | "metadata": {}, 557 | "outputs": [], 558 | "source": [ 559 | "# defining a function to import the feature vector corresponding to each S3 URI into the Elasticsearch KNN index\n", 560 | "# This process will take around ~2 min.\n", 561 | "\n", 562 | "def es_import(concat_result):\n", 563 | " vector = json.loads(predictor.predict(concat_result['descriptions']))\n", 564 | " es.index(index='idx_zalando',\n", 565 | " body={\"zalando_nlu_vector\": vector,\n", 566 | " \"image\": concat_result['filename'],\n", 567 | " \"description\": concat_result['descriptions']}\n", 568 | " )\n", 569 | " \n", 570 | "workers = 4 * cpu_count()\n", 571 | " \n", 572 | "process_map(es_import, concat_results, max_workers=workers)" 573 | ] 574 
| }, 575 | { 576 | "cell_type": "markdown", 577 | "metadata": {}, 578 | "source": [ 579 | "## Evaluate Index Search Results\n", 580 | "\n", 581 | "In this step we will use the SageMaker SDK as well as the Boto3 SDK to query Elasticsearch to retrieve the nearest neighbours, and then retrieve the relevant product images from Amazon S3 to display in the notebook." 582 | ] 583 | }, 584 | { 585 | "cell_type": "code", 586 | "execution_count": null, 587 | "metadata": {}, 588 | "outputs": [], 589 | "source": [ 590 | "#define display_image function\n", 591 | "from PIL import Image\n", 592 | "import io\n", 593 | "def display_image(bucket, key):\n", 594 | " s3_object = my_bucket.Object(key)\n", 595 | " response = s3_object.get()\n", 596 | " file_stream = response['Body']\n", 597 | " img = Image.open(file_stream)\n", 598 | " return display(img)" 599 | ] 600 | }, 601 | { 602 | "cell_type": "markdown", 603 | "metadata": {}, 604 | "source": [ 605 | "## SageMaker SDK Method" 606 | ] 607 | }, 608 | { 609 | "cell_type": "code", 610 | "execution_count": null, 611 | "metadata": {}, 612 | "outputs": [], 613 | "source": [ 614 | "#SageMaker SDK approach\n", 615 | "import json\n", 616 | "payload = 'I want a dress that is yellow and flowery'\n", 617 | "features = predictor.predict(payload)\n", 618 | "embedding = json.loads(features)" 619 | ] 620 | }, 621 | { 622 | "cell_type": "code", 623 | "execution_count": null, 624 | "metadata": {}, 625 | "outputs": [], 626 | "source": [ 627 | "#ES index search\n", 628 | "import json\n", 629 | "k = 3\n", 630 | "idx_name = 'idx_zalando'\n", 631 | "res = es.search(request_timeout=30, index=idx_name,\n", 632 | " body={'size': k, \n", 633 | " 'query': {'knn': {'zalando_nlu_vector': {'vector': embedding, 'k': k}}}})\n" 634 | ] 635 | }, 636 | { 637 | "cell_type": "code", 638 | "execution_count": null, 639 | "metadata": {}, 640 | "outputs": [], 641 | "source": [ 642 | "#Display the image\n", 643 | "\n", 644 | "for x in res['hits']['hits']:\n", 645 | " key = x['_source']['image']\n", 646 | " key = key.replace(f's3://{bucket}/','')\n", 647 | " img = display_image(bucket,key)\n" 648 | ] 649 | }, 650 | { 651 | "cell_type": "markdown", 652 | "metadata": {}, 653 | "source": [ 654 | "## Boto3 Method" 655 | ] 656 | }, 657 | { 658 | "cell_type": "code", 659 | "execution_count": null, 660 | "metadata": {}, 661 | "outputs": [], 662 | "source": [ 663 | "#calling SageMaker Endpoint\n", 664 | "client = boto3.client('sagemaker-runtime')\n", 665 | "ENDPOINT_NAME = predictor.endpoint\n", 666 | "response = client.invoke_endpoint(EndpointName=ENDPOINT_NAME,\n", 667 | " ContentType='text/plain',\n", 668 | " Body=payload)\n", 669 | "\n", 670 | "response_body = json.loads((response['Body'].read()))\n" 671 | ] 672 | }, 673 | { 674 | "cell_type": "code", 675 | "execution_count": null, 676 | "metadata": {}, 677 | "outputs": [], 678 | "source": [ 679 | "#ES index search\n", 680 | "import json\n", 681 | "k = 3\n", 682 | "idx_name = 'idx_zalando'\n", 683 | "res = es.search(request_timeout=30, index=idx_name,\n", 684 | " body={'size': k, \n", 685 | " 'query': {'knn': {'zalando_nlu_vector': {'vector': response_body, 'k': k}}}})" 686 | ] 687 | }, 688 | { 689 | "cell_type": "code", 690 | "execution_count": null, 691 | "metadata": {}, 692 | "outputs": [], 693 | "source": [ 694 | "#Display the image\n", 695 | "\n", 696 | "for i in range(k):\n", 697 | " key = res['hits']['hits'][i]['_source']['image']\n", 698 | " key = key.replace(f's3://{bucket}/','')\n", 699 | " img = display_image(bucket,key)" 700 | ] 701 | }, 702 | { 703 | 
"cell_type": "markdown", 704 | "metadata": {}, 705 | "source": [ 706 | "## Standard Full-Text Amazon ES Search" 707 | ] 708 | }, 709 | { 710 | "cell_type": "code", 711 | "execution_count": null, 712 | "metadata": {}, 713 | "outputs": [], 714 | "source": [ 715 | "fuzzy_search_body = {\n", 716 | " \"_source\": {\n", 717 | " \"excludes\": [ \"zalando_nlu_vector\" ]\n", 718 | " },\n", 719 | " \"query\": {\n", 720 | " \"match\" : {\n", 721 | " \"description\" : {\n", 722 | " \"query\" : payload\n", 723 | " }\n", 724 | " }\n", 725 | " }\n", 726 | "}\n", 727 | "\n", 728 | "result_fuzzy = es.search(request_timeout=30, index=idx_name,\n", 729 | " body=fuzzy_search_body)\n", 730 | " \n", 731 | "for x in result_fuzzy['hits']['hits'][:3]:\n", 732 | " key = x['_source']['image']\n", 733 | " key = key.replace(f's3://{bucket}/','')\n", 734 | " img = display_image(bucket,key)" 735 | ] 736 | }, 737 | { 738 | "cell_type": "markdown", 739 | "metadata": {}, 740 | "source": [ 741 | "## Deploying a full-stack NLU search application" 742 | ] 743 | }, 744 | { 745 | "cell_type": "code", 746 | "execution_count": null, 747 | "metadata": {}, 748 | "outputs": [], 749 | "source": [ 750 | "s3_resource.Object(bucket, 'backend/template.yaml').upload_file('./backend/template.yaml', ExtraArgs={'ACL':'public-read'})\n", 751 | "\n", 752 | "\n", 753 | "sam_template_url = f'https://{bucket}.s3.amazonaws.com/backend/template.yaml'\n", 754 | "\n", 755 | "# Generate the CloudFormation Quick Create Link\n", 756 | "\n", 757 | "print(\"Click the URL below to create the backend API for NLU search:\\n\")\n", 758 | "print((\n", 759 | " 'https://console.aws.amazon.com/cloudformation/home?region=us-east-1#/stacks/create/review'\n", 760 | " f'?templateURL={sam_template_url}'\n", 761 | " '&stackName=nlu-search-api'\n", 762 | " f'&param_BucketName={outputs[\"s3BucketTraining\"]}'\n", 763 | " f'&param_DomainName={outputs[\"esDomainName\"]}'\n", 764 | " f'&param_ElasticSearchURL={outputs[\"esHostName\"]}'\n", 765 | " f'&param_SagemakerEndpoint={predictor.endpoint}'\n", 766 | "))" 767 | ] 768 | }, 769 | { 770 | "cell_type": "markdown", 771 | "metadata": {}, 772 | "source": [ 773 | "Now that you have a working Amazon SageMaker endpoint for extracting text features and a KNN index on Elasticsearch, you are ready to build a real-world full-stack ML-powered web app. The SAM template you just created will deploy an Amazon API Gateway and AWS Lambda function. The Lambda function runs your code in response to HTTP requests that are sent to the API Gateway." 
774 | ] 775 | }, 776 | { 777 | "cell_type": "code", 778 | "execution_count": null, 779 | "metadata": {}, 780 | "outputs": [], 781 | "source": [ 782 | "# Review the content of the Lambda function code.\n", 783 | "!pygmentize backend/lambda/app.py" 784 | ] 785 | }, 786 | { 787 | "cell_type": "markdown", 788 | "metadata": {}, 789 | "source": [ 790 | "## Once the CloudFormation Stack shows CREATE_COMPLETE, proceed to this cell below:" 791 | ] 792 | }, 793 | { 794 | "cell_type": "code", 795 | "execution_count": null, 796 | "metadata": {}, 797 | "outputs": [], 798 | "source": [ 799 | "import json\n", 800 | "api_endpoint = get_cfn_outputs('nlu-search-api')['TextSimilarityApi']\n", 801 | "\n", 802 | "with open('./frontend/src/config/config.json', 'w') as outfile:\n", 803 | " json.dump({'apiEndpoint': api_endpoint}, outfile)" 804 | ] 805 | }, 806 | { 807 | "cell_type": "markdown", 808 | "metadata": {}, 809 | "source": [ 810 | "## Step 2: Deploy frontend services" 811 | ] 812 | }, 813 | { 814 | "cell_type": "code", 815 | "execution_count": null, 816 | "metadata": {}, 817 | "outputs": [], 818 | "source": [ 819 | "# add NPM to the path so we can assemble the web frontend from our notebook code\n", 820 | "\n", 821 | "from os import environ\n", 822 | "\n", 823 | "npm_path = ':/home/ec2-user/anaconda3/envs/JupyterSystemEnv/bin'\n", 824 | "\n", 825 | "if npm_path not in environ['PATH']:\n", 826 | " ADD_NPM_PATH = environ['PATH']\n", 827 | " ADD_NPM_PATH = ADD_NPM_PATH + npm_path\n", 828 | "else:\n", 829 | " ADD_NPM_PATH = environ['PATH']\n", 830 | " \n", 831 | "%set_env PATH=$ADD_NPM_PATH" 832 | ] 833 | }, 834 | { 835 | "cell_type": "code", 836 | "execution_count": null, 837 | "metadata": { 838 | "scrolled": true 839 | }, 840 | "outputs": [], 841 | "source": [ 842 | "%cd ./frontend/\n", 843 | "\n", 844 | "!npm install" 845 | ] 846 | }, 847 | { 848 | "cell_type": "code", 849 | "execution_count": null, 850 | "metadata": {}, 851 | "outputs": [], 852 | "source": [ 853 | "!npm run-script build" 854 | ] 855 | }, 856 | { 857 | "cell_type": "code", 858 | "execution_count": null, 859 | "metadata": {}, 860 | "outputs": [], 861 | "source": [ 862 | "hosting_bucket = f\"s3://{outputs['s3BucketHostingBucketName']}\"\n", 863 | "\n", 864 | "!aws s3 sync ./build/ $hosting_bucket --acl public-read" 865 | ] 866 | }, 867 | { 868 | "cell_type": "markdown", 869 | "metadata": {}, 870 | "source": [ 871 | "## Step 3: Browse your frontend service" 872 | ] 873 | }, 874 | { 875 | "cell_type": "code", 876 | "execution_count": null, 877 | "metadata": {}, 878 | "outputs": [], 879 | "source": [ 880 | "print('Click the URL below:\\n')\n", 881 | "print(outputs['S3BucketSecureURL'] + '/index.html')" 882 | ] 883 | }, 884 | { 885 | "cell_type": "markdown", 886 | "metadata": {}, 887 | "source": [ 888 | "You should see the following page:" 889 | ] 890 | }, 891 | { 892 | "cell_type": "markdown", 893 | "metadata": {}, 894 | "source": [ 895 | "![Website](./frontendExample.png)" 896 | ] 897 | }, 898 | { 899 | "cell_type": "markdown", 900 | "metadata": {}, 901 | "source": [ 902 | "In the search bar, try typing *__\"I want a summery dress that's flowery and yellow\"__*. Notice how the results are more relevant in the KNN Search method in contrast with the conventional full-text match search query. The KNN Search method uses features generated from the search text and finds similar product descriptions, which allows it to find similar products, even if the description doesn't have the exact same phrase." 
903 | ] 904 | }, 905 | { 906 | "cell_type": "markdown", 907 | "metadata": {}, 908 | "source": [ 909 | "## Extensions\n", 910 | "\n", 911 | "We have used a pretrained BERT model from sentence-transformers. You can fine-tune the BERT model with your own data for your own use case." 912 | ] 913 | }, 914 | { 915 | "cell_type": "markdown", 916 | "metadata": {}, 917 | "source": [ 918 | "## Cleanup\n", 919 | "\n", 920 | "Make sure that you stop the notebook instance, delete the Amazon SageMaker endpoint, and delete the Elasticsearch domain to prevent any additional charges." 921 | ] 922 | }, 923 | { 924 | "cell_type": "code", 925 | "execution_count": null, 926 | "metadata": {}, 927 | "outputs": [], 928 | "source": [ 929 | "# Delete the endpoint\n", 930 | "predictor.delete_endpoint()\n", 931 | "\n", 932 | "# Empty S3 Contents\n", 933 | "training_bucket_resource = s3_resource.Bucket(bucket)\n", 934 | "training_bucket_resource.objects.all().delete()\n", 935 | "\n", 936 | "hosting_bucket_resource = s3_resource.Bucket(outputs['s3BucketHostingBucketName'])\n", 937 | "hosting_bucket_resource.objects.all().delete()" 938 | ] 939 | } 940 | ], 941 | "metadata": { 942 | "kernelspec": { 943 | "display_name": "conda_python3", 944 | "language": "python", 945 | "name": "conda_python3" 946 | }, 947 | "language_info": { 948 | "codemirror_mode": { 949 | "name": "ipython", 950 | "version": 3 951 | }, 952 | "file_extension": ".py", 953 | "mimetype": "text/x-python", 954 | "name": "python", 955 | "nbconvert_exporter": "python", 956 | "pygments_lexer": "ipython3", 957 | "version": "3.6.13" 958 | } 959 | }, 960 | "nbformat": 4, 961 | "nbformat_minor": 4 962 | } 963 | -------------------------------------------------------------------------------- /query.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-nlu-search/f0c32c17d2a58e8fee0b91096891601002a1bed7/query.png -------------------------------------------------------------------------------- /ref.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/amazon-sagemaker-nlu-search/f0c32c17d2a58e8fee0b91096891601002a1bed7/ref.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | sentence-transformers 2 | --------------------------------------------------------------------------------