├── .all-contributorsrc ├── .config.ini ├── .coveragerc ├── .editorconfig ├── .github ├── FUNDING.yml ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── feature_request.md │ └── question.md ├── stale.yml └── workflows │ └── test.yml ├── .gitignore ├── .readthedocs.yml ├── .test_durations ├── LICENSE ├── README.md ├── docs ├── _static │ ├── css │ │ ├── extra.css │ │ └── modify.css │ ├── fonts │ │ └── RobotoSlab │ │ │ └── roboto-slab.eot │ ├── images │ │ ├── multi_feature_model.png │ │ ├── ner_f1_scores.png │ │ └── smp2018ecdtcorpus_f1_score.png │ └── js │ │ └── baidu-static.js ├── about │ ├── contributing.md │ └── release-notes.md ├── advance-use │ └── tensorflow-serving.md ├── apis │ ├── classification.rst │ ├── corpus.rst │ ├── embeddings.rst │ ├── generators.rst │ ├── labeling.rst │ └── processors.rst ├── conf.py ├── embeddings │ ├── bare-embedding.rst │ ├── bert-embedding.rst │ ├── index.rst │ ├── transformer-embedding.rst │ └── word-embedding.rst ├── index.rst ├── requirements.txt └── tutorial │ ├── seq2seq.md │ ├── text-classification.md │ └── text-labeling.md ├── examples ├── benchmarks │ ├── benchmark_utils.py │ ├── classification.ipynb │ ├── multi_label_classificaiton.ipynb │ └── named_entity_recognition.ipynb ├── custom_generator.py ├── k_fold_evaluation.py ├── tools.py ├── train_with_generator.ipynb ├── translate_with_seq2seq.ipynb └── web_qa_reading_comprehence.py ├── kashgari ├── __init__.py ├── __version__.py ├── callbacks │ ├── __init__.py │ ├── eval_callBack.py │ └── save_callback.py ├── corpus.py ├── embeddings │ ├── __init__.py │ ├── abc_embedding.py │ ├── bare_embedding.py │ ├── bert_embedding.py │ ├── transformer_embedding.py │ └── word_embedding.py ├── generators.py ├── layers │ ├── __init__.py │ ├── behdanau_attention.py │ └── conditional_random_field.py ├── logger.py ├── macros.py ├── metrics │ ├── __init__.py │ ├── multi_label_classification.py │ └── sequence_labeling.py ├── processors │ ├── __init__.py │ ├── abc_processor.py │ ├── class_processor.py │ ├── sequence_processor.py │ └── tools.py ├── tasks │ ├── __init__.py │ ├── abs_task_model.py │ ├── classification │ │ ├── __init__.py │ │ ├── abc_model.py │ │ ├── bi_gru_model.py │ │ ├── bi_lstm_model.py │ │ ├── cnn_attention_model.py │ │ ├── cnn_gru_model.py │ │ ├── cnn_lstm_model.py │ │ └── cnn_model.py │ ├── labeling │ │ ├── __init__.py │ │ ├── abc_model.py │ │ ├── bi_gru_crf_model.py │ │ ├── bi_gru_model.py │ │ ├── bi_lstm_crf_model.py │ │ ├── bi_lstm_model.py │ │ └── cnn_lstm_model.py │ └── seq2seq │ │ ├── __init__.py │ │ ├── decoder │ │ ├── __init__.py │ │ ├── att_gru_decoder.py │ │ └── gru_decoder.py │ │ ├── encoder │ │ ├── __init__.py │ │ └── gru_encoder.py │ │ └── model.py ├── tokenizers │ ├── __init__.py │ ├── base_tokenizer.py │ ├── bert_tokenizer.py │ └── jieba_tokenizer.py ├── types.py └── utils │ ├── __init__.py │ ├── data.py │ ├── model.py │ ├── multi_label.py │ └── serialize.py ├── legacy_docs ├── docs │ ├── CNAME │ ├── index.md │ └── version_selection.jpg ├── mkdocs.yml └── readme.md ├── requirements.dev.txt ├── requirements.txt ├── scripts ├── clean.sh ├── docs-generate.sh ├── docs-lint.sh ├── docs-live.sh ├── install_addons.py ├── install_tf.py ├── lint.sh ├── markdown2rst.py └── tests.sh ├── setup.py ├── sonar-project.properties ├── test_performance ├── classifications.py ├── labeling.py └── readme.md └── tests ├── __init__.py ├── test_classification ├── __init__.py ├── test_bi_gru_model.py ├── test_bi_lstm_model.py ├── test_cnn_attention_model.py ├── test_cnn_gru_model.py ├── test_cnn_lstm_model.py ├── 
test_cnn_model.py └── test_custom_model.py ├── test_corpus.py ├── test_embeddings ├── __init__.py ├── test_bare_embedding.py ├── test_transformer_embedding.py └── test_word_embedding.py ├── test_generator.py ├── test_labeling ├── __init__.py ├── test_bi_gru_crf_model.py ├── test_bi_gru_model.py ├── test_bi_lstm_crf_model.py ├── test_bi_lstm_model.py └── test_cnn_lstm_model.py ├── test_macros.py ├── test_processor ├── __init__.py ├── test_class_processor.py └── test_sequence_processor.py ├── test_seq2seq ├── __init__.py └── test_seq2seq.py ├── test_tokenizers.py └── test_utils.py /.all-contributorsrc: -------------------------------------------------------------------------------- 1 | { 2 | "files": [ 3 | "README.md" 4 | ], 5 | "imageSize": 100, 6 | "commit": false, 7 | "contributors": [ 8 | { 9 | "login": "BrikerMan", 10 | "name": "Eliyar Eziz", 11 | "avatar_url": "https://avatars1.githubusercontent.com/u/9368907?v=4", 12 | "profile": "https://developers.google.com/community/experts/directory/profile/profile-eliyar_eziz", 13 | "contributions": [ 14 | "doc", 15 | "test", 16 | "code" 17 | ] 18 | }, 19 | { 20 | "login": "alexwwang", 21 | "name": "Alex Wang", 22 | "avatar_url": "https://avatars3.githubusercontent.com/u/856746?v=4", 23 | "profile": "http://www.chuanxilu.com", 24 | "contributions": [ 25 | "code" 26 | ] 27 | }, 28 | { 29 | "login": "lsgrep", 30 | "name": "Yusup", 31 | "avatar_url": "https://avatars3.githubusercontent.com/u/3893940?v=4", 32 | "profile": "https://github.com/lsgrep", 33 | "contributions": [ 34 | "code" 35 | ] 36 | }, 37 | { 38 | "login": "Adline125", 39 | "name": "Adline", 40 | "avatar_url": "https://avatars1.githubusercontent.com/u/5442229?v=4", 41 | "profile": "https://github.com/adlinex", 42 | "contributions": [ 43 | "code" 44 | ] 45 | } 46 | ], 47 | "contributorsPerLine": 7, 48 | "projectName": "Kashgari", 49 | "projectOwner": "BrikerMan", 50 | "repoType": "github", 51 | "repoHost": "https://github.com", 52 | "skipCi": true 53 | } 54 | -------------------------------------------------------------------------------- /.config.ini: -------------------------------------------------------------------------------- 1 | [flake8] 2 | 3 | ignore = A003, 4 | E402, E701, E722, 5 | F401, F541 6 | T400, T484, 7 | W291, W292, W391 8 | 9 | exclude = 10 | *migrations*, 11 | # python related 12 | *.pyc, 13 | .git, 14 | __pycache__, 15 | 16 | max-line-length=160 17 | max-complexity=12 18 | format=pylint 19 | show_source = True 20 | statistics = True 21 | count = True 22 | 23 | builtins = 24 | ignore 25 | override 26 | 27 | [mypy] 28 | 29 | disallow_untyped_defs = True 30 | ignore_missing_imports = True 31 | allow_redefinition = True 32 | strict_optional = False 33 | no_implicit_optional = True 34 | 35 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [report] 2 | # Regexes for lines to exclude from consideration 3 | exclude_lines = 4 | # Have to re-enable the standard pragma 5 | pragma: no cover 6 | 7 | # Don't complain about missing debug-only code: 8 | def __repr__ 9 | if self\.debug 10 | if debug_info: 11 | 12 | # Don't complain if tests don't hit defensive assertion code: 13 | raise AssertionError 14 | raise NotImplementedError 15 | raise ValueError 16 | except Exception as e: 17 | except ModuleNotFoundError: 18 | logging.debug 19 | 20 | # Don't complain if non-runnable code isn't run: 21 | if __name__ == .__main__.: 22 | 23 | ignore_errors = True 
--------------------------------------------------------------------------------
/.editorconfig:
--------------------------------------------------------------------------------
# EditorConfig is awesome: http://EditorConfig.org

# top-most EditorConfig file
root = true

# Unix-style newlines with a newline ending every file
[*]
end_of_line = lf
insert_final_newline = true
trim_trailing_whitespace = true
max_line_length = 200

# Matches multiple files with brace expansion notation
# Set default charset
[*.{js,py}]
charset = utf-8

# 4 space indentation
[*.py]
indent_style = space
indent_size = 4
max_line_length = 120

# 2 space indentation
[*.{html, yml, json, rst, ini, md}]
indent_style = space
indent_size = 2
max_line_length = 120

# Tab indentation (no size specified)
[Makefile]
indent_style = tab
tab_width = 4
--------------------------------------------------------------------------------
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
# These are supported funding model platforms

patreon: brikerman
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
---
name: Bug report
about: Create a report to help us improve
title: "[BUG] "
labels: bug
assignees: BrikerMan

---

**You must follow the issue template and provide as much information as possible. Otherwise, this issue will be closed.
请按照 issue 模板要求填写信息。如果没有按照 issue 模板填写,将会忽略并关闭这个 issue**

## Check List

Thanks for considering opening an issue. Before you submit your issue, please confirm these boxes are checked.

**You can post pictures, but if specific text or code is required to reproduce the issue, please provide the text in a plain text format for easy copy/paste.**

- [ ] I have searched in [existing issues](https://github.com/BrikerMan/Kashgari/issues?utf8=%E2%9C%93&q=is%3Aissue+) but did not find the same one.
- [ ] I have read the [documents](https://kashgari.bmio.net)

## Environment

- OS [e.g. Mac OS, Linux]:
- Python Version:
- requirements.txt:

```txt
[Paste requirements.txt file here]
```

## Issue Description

### What

[Tell us about the issue]

### Reproduce

[The steps to reproduce this issue. What is the URL you were trying to play, where did you put your code, etc.]

### Other Comment

[Add anything else here]
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
---
name: Feature request
about: Suggest an idea for this project
title: "[Feature request] "
labels: enhancement
assignees: BrikerMan

---

**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]

**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.

**Additional context**
Add any other context or screenshots about the feature request here.
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/question.md:
--------------------------------------------------------------------------------
---
name: Question
about: Any question related to the usage of Kashgari
title: "[Question] "
labels: question
assignees: BrikerMan

---

**You must follow the issue template and provide as much information as possible. Otherwise, this issue will be closed.
请按照 issue 模板要求填写信息。如果没有按照 issue 模板填写,将会忽略并关闭这个 issue**

## Check List

Thanks for considering opening an issue. Before you submit your issue, please confirm these boxes are checked.

**You can post pictures, but if specific text or code is required to reproduce the issue, please provide the text in a plain text format for easy copy/paste.**

- [ ] I have searched in [existing issues](https://github.com/BrikerMan/Kashgari/issues?utf8=%E2%9C%93&q=is%3Aissue+) but did not find the same one.
- [ ] I have read the [documents](https://kashgari.bmio.net)

## Environment

- OS [e.g. Mac OS, Linux]:
- Python Version:
- requirements.txt:

```txt
[Paste requirements.txt file here]
```

## Question

[A clear and concise description of what you want to know.]
--------------------------------------------------------------------------------
/.github/stale.yml:
--------------------------------------------------------------------------------
# Number of days of inactivity before an issue becomes stale
daysUntilStale: 180
# Number of days of inactivity before a stale issue is closed
daysUntilClose: 30
# Issues with these labels will never be considered stale
exemptLabels:
  - pinned
  - security
# Label to use when marking an issue as stale
staleLabel: wontfix
# Comment to post when marking an issue as stale. Set to `false` to disable
markComment: >
  This issue has been automatically marked as stale because it has not had
  recent activity. It will be closed if no further activity occurs. Thank you
  for your contributions.
# Comment to post when closing a stale issue.
Set to `false` to disable 17 | closeComment: false 18 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | on: 3 | push: 4 | branches: 5 | - v2-main 6 | - v2-dev 7 | - v2/github-actions 8 | pull_request: 9 | types: [opened, synchronize, reopened] 10 | jobs: 11 | lint: 12 | name: Lint 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v2 16 | - name: Set up Python 3.8 17 | uses: actions/setup-python@v1 18 | with: 19 | python-version: 3.8 20 | - name: Install deps 21 | run: | 22 | python -m pip install --upgrade pip 23 | python scripts/install_tf.py 2.2 24 | python scripts/install_addons.py 2.2 25 | pip install -r requirements.dev.txt 26 | pip install -r requirements.txt 27 | - name: Run lint script 28 | run: sh ./scripts/lint.sh 29 | 30 | test: 31 | if: always() 32 | name: "Test with TF ${{ matrix.tensorflow_version }} - ${{ matrix.group }}" 33 | needs: lint 34 | runs-on: ubuntu-latest 35 | strategy: 36 | matrix: 37 | group: [ 1, 2, 3 ] 38 | tensorflow_version: [2.2, 2.3, 2.4, 2.5] 39 | steps: 40 | - uses: actions/checkout@v2 41 | - name: Set up Python 3.8 42 | uses: actions/setup-python@v1 43 | with: 44 | python-version: 3.8 45 | - name: Install deps 46 | run: | 47 | python -m pip install --upgrade pip 48 | python scripts/install_tf.py '${{ matrix.tensorflow_version }}' 49 | python scripts/install_addons.py '${{ matrix.tensorflow_version }}' 50 | pip install -r requirements.dev.txt 51 | pip install -r requirements.txt 52 | 53 | - name: Run pytest 54 | run: 'pytest 55 | --doctest-modules 56 | --junitxml=test-reports/junit-${{ matrix.tensorflow_version }}-${{ matrix.group }}.xml 57 | --cov=kashgari 58 | --cov-report=xml:cov-reports/coverage-${{ matrix.tensorflow_version }}-${{ matrix.group }}.xml 59 | --cov-report term 60 | --cov-config .coveragerc 61 | --cov 62 | --splits 3 63 | --group ${{ matrix.group }} 64 | tests/' 65 | 66 | - name: Upload unit test 67 | uses: actions/upload-artifact@v2 68 | with: 69 | name: junitxml-${{ matrix.tensorflow_version }}-${{ matrix.group }} 70 | path: test-reports 71 | 72 | - name: Upload coverage 73 | uses: actions/upload-artifact@v2 74 | with: 75 | name: coverage-${{ matrix.tensorflow_version }}-${{ matrix.group }} 76 | path: cov-reports 77 | 78 | sonarcloud: 79 | if: "!contains(github.event.head_commit.message, 'skip ci')" 80 | name: SonarCloud 81 | runs-on: ubuntu-latest 82 | needs: test 83 | steps: 84 | - uses: actions/checkout@v2 85 | with: 86 | fetch-depth: 0 # Shallow clones should be disabled for a better relevancy of analysis 87 | - uses: actions/download-artifact@v2 88 | with: 89 | path: artifacts 90 | - name: Display structure of downloaded files 91 | run: ls -R 92 | - name: Copy Artifacts to target file 93 | run: | 94 | mkdir -p test-reports && cp artifacts/junit*/* test-reports 95 | mkdir -p cov-reports && cp artifacts/cov*/* cov-reports 96 | - name: Display structure of downloaded files 97 | run: ls -R 98 | - name: SonarCloud Scan 99 | uses: SonarSource/sonarcloud-github-action@master 100 | env: 101 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # Needed to get PR information, if any 102 | SONAR_TOKEN: ${{ secrets.SONAR_TOKEN }} 103 | # - name: Publish Unit Test Results 104 | # uses: EnricoMi/publish-unit-test-result-action@v1.3 105 | # if: always() 106 | # with: 107 | # github_token: ${{ secrets.GITHUB_TOKEN }} 108 | # check_name: Unit Test Results 109 | # files: 
test-results/*.xml 110 | # report_individual_runs: true 111 | # deduplicate_classes_by_file_name: false 112 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | _build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Environments 86 | .env 87 | .venv 88 | pip-venv/* 89 | env/ 90 | venv/ 91 | venv2/ 92 | ENV/ 93 | env.bak/ 94 | venv.bak/ 95 | 96 | # Spyder project settings 97 | .spyderproject 98 | .spyproject 99 | 100 | # Rope project settings 101 | .ropeproject 102 | 103 | # mkdocs documentation 104 | /site 105 | 106 | # mypy 107 | .mypy_cache/ 108 | 109 | docs/README.rst 110 | .idea 111 | .vscode 112 | venv-tf/* 113 | .pytype/ 114 | mkdocs/site 115 | node_modules 116 | test-reports 117 | 118 | 119 | _site 120 | _site_src 121 | examples/benchmarks/tf_dir 122 | examples/tf_dir/ 123 | tf_dir/ 124 | 125 | venv-doc 126 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Build documentation in the docs/ directory with Sphinx 9 | sphinx: 10 | configuration: docs/conf.py 11 | builder: dirhtml 12 | fail_on_warning: false 13 | 14 | # Build documentation with MkDocs 15 | #mkdocs: 16 | # configuration: mkdocs.yml 17 | 18 | # Optionally build your docs in additional formats such as PDF and ePub 19 | formats: 20 | - htmlzip 21 | 22 | # Optionally set the version of Python and requirements required to build your docs 23 | python: 24 | version: 3.8 25 | install: 26 | - requirements: ./requirements.dev.txt 27 | - requirements: ./requirements.txt 28 | 29 | -------------------------------------------------------------------------------- /docs/_static/css/extra.css: -------------------------------------------------------------------------------- 1 | body, 2 | h1, 3 | h2, 4 | h3, 5 | h4, 6 | h5, 7 | h6, 8 | input { 9 | font-family: "Roboto", "Helvetica 
Neue", Helvetica, Arial, "Microsoft YaHei", "微软正黑体", "Microsoft JhengHei", sans-serif; 10 | } 11 | 12 | .wy-side-nav-search > div.version { 13 | margin-top: 5px; 14 | } 15 | 16 | .wy-nav-content { 17 | max-width: 1000px; 18 | } 19 | 20 | .pre { 21 | font-family: "JetBrains Mono", SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", Courier, monospace 22 | } 23 | -------------------------------------------------------------------------------- /docs/_static/css/modify.css: -------------------------------------------------------------------------------- 1 | :root { 2 | /* Colors */ 3 | --color--white: #fff; 4 | --color--lightwash: #f7fbfb; 5 | --color--mediumwash: #eff7f8; 6 | --color--darkwash: #e6f3f3; 7 | --color--warmgraylight: #eeedee; 8 | --color--warmgraydark: #a3acb0; 9 | --color--coolgray1: #c5c5d2; 10 | --color--coolgray2: #8e8ea0; 11 | --color--coolgray3: #6e6e80; 12 | --color--coolgray4: #404452; 13 | --color--black: #050505; 14 | --color--pink: #e6a2e4; 15 | --color--magenta: #dd5ce5; 16 | --color--red: #bd1c5f; 17 | --color--kashgarired: #ef5350; 18 | --color--brightred: #ef4146; 19 | --color--orange: #e86c09; 20 | --color--golden: #f4ac36; 21 | --color--yellow: #ebe93d; 22 | --color--lightgreen: #68de7a; 23 | --color--darkgreen: #10a37f; 24 | --color--teal: #2ff3ce; 25 | --color--lightblue: #27b5ea; 26 | --color--mediumblue: #2e95d3; 27 | --color--darkblue: #5436da; 28 | --color--navyblue: #1d0d4c; 29 | --color--lightpurple: #6b40d8; 30 | --color--darkpurple: #412991; 31 | --color--lightgrayishpurple: #cdc3cf; 32 | --color--mediumgrayishpurple: #9c88a3; 33 | --color--darkgrayishpurple: #562f5f; 34 | } 35 | 36 | body { 37 | color: var(--color--darkgray) !important; 38 | fill: var(--color--darkgray) !important; 39 | } 40 | 41 | h1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend { 42 | /* font-weight: 500; 43 | font-family: Colfax, sans-serif !important; */ 44 | font-family: "Lato", "proxima-nova", "Helvetica Neue", Arial, sans-serif !important; 45 | } 46 | 47 | .wy-nav-top { 48 | background-color: var(--color--coolgray4) !important; 49 | } 50 | 51 | .rst-content .toc-backref { 52 | color: #404040 !important; 53 | } 54 | 55 | .footnote { 56 | padding-left: 0.75rem; 57 | background-color: var(--color--warmgraylight) !important; 58 | } 59 | 60 | .wy-nav-top a, .wy-nav-top a:visited { 61 | color: var(--color--white) !important; 62 | } 63 | 64 | .wy-menu-vertical header, .wy-menu-vertical p.caption { 65 | font-weight: 500 !important; 66 | letter-spacing: 1px; 67 | margin-top: 1.25rem; 68 | } 69 | 70 | .wy-side-nav-search { 71 | background-color: var(--color--warmgraylight) !important; 72 | } 73 | 74 | .wy-body-for-nav { 75 | background-color: var(--color--coolgray4) !important; 76 | } 77 | 78 | .wy-menu-vertical li span.toctree-expand { 79 | color: var(--color--coolgray2) !important; 80 | } 81 | 82 | .wy-nav-side { 83 | color: var(--color--coolgray1) !important; 84 | background-color: var(--color--coolgray4) !important; 85 | } 86 | 87 | .wy-side-nav-search input[type=text] { 88 | border-color: var(--color--warmgraydark) !important; 89 | } 90 | 91 | a { 92 | color: var(--color--kashgarired) !important; 93 | } 94 | 95 | a:visited { 96 | color: #ff1744 !important; 97 | } 98 | 99 | .wy-menu-vertical a { 100 | color: var(--color--coolgray2) !important; 101 | } 102 | 103 | .wy-menu-vertical li.current a { 104 | border-right: none !important; 105 | color: var(--color--coolgray4) !important; 106 | } 107 | 108 | .wy-menu-vertical li.current { 109 | 
background-color: var(--color--warmgraylight) !important; 110 | } 111 | 112 | .wy-menu-vertical li.toctree-l2.current > a { 113 | background-color: var(--color--coolgray1) !important; 114 | } 115 | 116 | .wy-menu-vertical a:hover, .wy-menu-vertical li.current a:hover, .wy-menu-vertical li.toctree-l2.current > a:hover { 117 | color: var(--color--warmgraylight) !important; 118 | background-color: var(--color--coolgray3) !important; 119 | } 120 | 121 | .wy-alert-title, .rst-content .admonition-title { 122 | background-color: var(--color--kashgarired) !important; 123 | } 124 | 125 | .wy-alert, .rst-content .note, .rst-content .attention, .rst-content .caution, .rst-content .danger, .rst-content .error, .rst-content .hint, .rst-content .important, .rst-content .tip, .rst-content .warning, .rst-content .seealso, .rst-content .admonition-todo, .rst-content .admonition { 126 | background-color: var(--color--warmgraylight) !important; 127 | } 128 | 129 | .rst-content dl:not(.docutils) dt { 130 | border-color: var(--color--kashgarired) !important; 131 | background-color: var(--color--warmgraylight) !important; 132 | } 133 | 134 | /* .rst-content pre.literal-block, .rst-content div[class^='highlight'] { 135 | background-color: var(--color--warmgraylight) !important; 136 | } */ 137 | 138 | .wy-table-odd td, .wy-table-striped tr:nth-child(2n-1) td, .rst-content table.docutils:not(.field-list) tr:nth-child(2n-1) td { 139 | background-color: var(--color--warmgraylight) !important; 140 | } 141 | 142 | @media screen and (min-width: 1100px) { 143 | .wy-nav-content-wrap { 144 | background-color: var(--color--warmgraylight) !important; 145 | } 146 | } 147 | 148 | .wy-side-nav-search img { 149 | height: auto !important; 150 | width: 100% !important; 151 | padding: 0 !important; 152 | background-color: inherit !important; 153 | border-radius: 0 !important; 154 | margin: 0 !important 155 | } 156 | 157 | .wy-side-nav-search > a, .wy-side-nav-search .wy-dropdown > a { 158 | margin-bottom: 0 !important; 159 | } 160 | 161 | .wy-menu-vertical li.toctree-l1.current > a { 162 | border: none !important; 163 | } 164 | 165 | .wy-side-nav-search > div.version { 166 | color: var(--color--coolgray2) !important; 167 | } 168 | -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab/roboto-slab.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BrikerMan/Kashgari/ffe730d33f894e99a6fd7aa17ca67d161bf70359/docs/_static/fonts/RobotoSlab/roboto-slab.eot -------------------------------------------------------------------------------- /docs/_static/images/multi_feature_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BrikerMan/Kashgari/ffe730d33f894e99a6fd7aa17ca67d161bf70359/docs/_static/images/multi_feature_model.png -------------------------------------------------------------------------------- /docs/_static/images/ner_f1_scores.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BrikerMan/Kashgari/ffe730d33f894e99a6fd7aa17ca67d161bf70359/docs/_static/images/ner_f1_scores.png -------------------------------------------------------------------------------- /docs/_static/images/smp2018ecdtcorpus_f1_score.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BrikerMan/Kashgari/ffe730d33f894e99a6fd7aa17ca67d161bf70359/docs/_static/images/smp2018ecdtcorpus_f1_score.png
--------------------------------------------------------------------------------
/docs/_static/js/baidu-static.js:
--------------------------------------------------------------------------------
var _hmt = _hmt || [];
(function() {
  var hm = document.createElement("script");
  hm.src = "https://hm.baidu.com/hm.js?5bbafbb1b5e47d6c8da68e6875345893";
  var s = document.getElementsByTagName("script")[0];
  s.parentNode.insertBefore(hm, s);
})();
--------------------------------------------------------------------------------
/docs/about/contributing.md:
--------------------------------------------------------------------------------
# Contributing & Support

We are happy to accept contributions that make `Kashgari` better and more awesome! You could contribute in various ways:

## Bug Reports

1. Please **read the documentation** and **search the issue tracker** to try and find the answer to your question **before** posting an issue.

2. When creating an issue on the repository, please provide as much info as possible:

   - Version being used.
   - Operating system.
   - Version of Python.
   - Errors in console.
   - Detailed description of the problem.
   - Examples for reproducing the error. You can post pictures, but if specific text or code is required to reproduce the issue, please provide the text in a plain text format for easy copy/paste.

   The more info provided, the greater the chance someone will take the time to answer, implement, or fix the issue.

3. Be prepared to answer questions and provide additional information if required. Issues in which the creator refuses to respond to follow-up questions will be marked as stale and closed.

## Reviewing Code

Take part in reviewing pull requests and/or reviewing direct commits. Make suggestions to improve the code and discuss solutions to overcome weaknesses in the algorithm.

## Answer Questions in Issues

Take time to answer questions and offer suggestions to people who've created issues in the issue tracker. Often people will have questions that you might have an answer for. Or maybe you know how to help them accomplish a specific task they are asking about. Feel free to share your experience with others to help them out.

## Pull Requests

Pull requests are welcome, and a great way to help fix bugs and add new features.

### Accuracy Benchmarks

Use Kashgari on your own data, and report the F-1 score.

### Adding New Models

New models can be of two basic types:

### Adding New Tasks

Currently, Kashgari can handle text-classification and sequence-labeling tasks. If you want to apply Kashgari to a new task, please submit a request issue and explain why we should consider adding the new task to Kashgari.

## Documentation Improvements

A ton of time has been spent not only creating and supporting this tool, but also making this documentation. If you feel it is still lacking, show your appreciation for the tool by helping to improve/translate the documentation.

You can build the docs by running these commands in the project root folder. Source files are in the `docs` folder.
```bash
pip install -r docs/requirements.txt
python setup.py install
sh ./scripts/docs-live.sh
```
--------------------------------------------------------------------------------
/docs/advance-use/tensorflow-serving.md:
--------------------------------------------------------------------------------
# Tensorflow Serving

```python
from kashgari.tasks.classification import BiGRU_Model
from kashgari.corpus import SMP2018ECDTCorpus
from kashgari import utils

train_x, train_y = SMP2018ECDTCorpus.load_data()

model = BiGRU_Model()
model.fit(train_x, train_y)

# Save model
utils.convert_to_saved_model(model,
                             model_path="saved_model/bgru",
                             version=1)
```

Then run tensorflow-serving.

```bash
docker run -t --rm -p 8501:8501 -v "/saved_model:/models/" -e MODEL_NAME=bgru tensorflow/serving
```

Load the processors from the model, then predict through the serving endpoint.

We need to check the model's input keys first.

```python
import requests
res = requests.get("http://localhost:8501/v1/models/bgru/metadata")
inputs = res.json()['metadata']['signature_def']['signature_def']['serving_default']['inputs']
input_sample_keys = list(inputs.keys())
print(input_sample_keys)
# ['Input-Token', 'Input-Segment']
```

If we have only one input key (i.e., we are not using a BERT-like embedding), we need to pass JSON in this format to the predict endpoint.

```json
{
  "instances": [
    [2, 1, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [2, 9, 41, 459, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0]
  ]
}
```

Here is the code.

```python
import requests
import numpy as np
from kashgari.processors import load_processors_from_model

text_processor, label_processor = load_processors_from_model('/Users/brikerman/Desktop/tf-serving/1603683152')

samples = [
    ['hello', 'world'],
    ['你', '好', '世', '界']
]
tensor = text_processor.transform(samples)

instances = [i.tolist() for i in tensor]

# predict
r = requests.post("http://localhost:8501/v1/models/bgru:predict", json={"instances": instances})
predictions = r.json()['predictions']

# Convert result back to labels
labels = label_processor.inverse_transform(np.array(predictions).argmax(-1))
print(labels)
```

If we are using BERT, we need to handle multiple input fields. For example, when we get the two keys `['Input-Token', 'Input-Segment']` from the metadata endpoint, we need to pass JSON in this format.

```json
[
  {
    "Input-Token": [2, 1, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    "Input-Segment": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
  },
  {
    "Input-Token": [2, 9, 41, 459, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    "Input-Segment": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
  }
]
```

Here is the code.
```python
import requests
import numpy as np
from kashgari.processors import load_processors_from_model

text_processor, label_processor = load_processors_from_model('/Users/brikerman/Desktop/tf-serving/1603683152')

samples = [
    ['hello', 'world'],
    ['你', '好', '世', '界']
]
tensor = text_processor.transform(samples)

instances = [{
    "Input-Token": i.tolist(),
    "Input-Segment": np.zeros(i.shape).tolist()
} for i in tensor]

# predict
r = requests.post("http://localhost:8501/v1/models/bgru:predict", json={"instances": instances})
predictions = r.json()['predictions']

# Convert result back to labels
labels = label_processor.inverse_transform(np.array(predictions).argmax(-1))
print(labels)
```
--------------------------------------------------------------------------------
/docs/apis/classification.rst:
--------------------------------------------------------------------------------
=====================
Classification Models
=====================

.. contents:: Table of Contents

Bidirectional LSTM Model
------------------------

.. autoclass:: kashgari.tasks.classification.BiLSTM_Model
    :members:
    :undoc-members:


Bidirectional GRU Model
-----------------------

.. autoclass:: kashgari.tasks.classification.BiGRU_Model
    :members:
    :undoc-members:


CNN Model
---------

.. autoclass:: kashgari.tasks.classification.CNN_Model
--------------------------------------------------------------------------------
/docs/apis/corpus.rst:
--------------------------------------------------------------------------------
======
Corpus
======

.. contents:: Table of Contents

ChineseDailyNerCorpus
=====================

.. autoclass:: kashgari.corpus.ChineseDailyNerCorpus
    :members:

SMP2018ECDTCorpus
=================

.. autoclass:: kashgari.corpus.SMP2018ECDTCorpus
    :members:


JigsawToxicCommentCorpus
========================

.. autoclass:: kashgari.corpus.JigsawToxicCommentCorpus
    :members:
--------------------------------------------------------------------------------
/docs/apis/embeddings.rst:
--------------------------------------------------------------------------------
Embeddings
==========

BareEmbedding
-------------

.. autoclass:: kashgari.embeddings.BareEmbedding

WordEmbedding
-------------

.. autoclass:: kashgari.embeddings.WordEmbedding

TransformerEmbedding
--------------------

.. autoclass:: kashgari.embeddings.TransformerEmbedding

BertEmbedding
-------------

.. autoclass:: kashgari.embeddings.BertEmbedding
--------------------------------------------------------------------------------
/docs/apis/generators.rst:
--------------------------------------------------------------------------------
==========
Generators
==========

.. contents:: Table of Contents

CorpusGenerator
===============

.. autoclass:: kashgari.generators.CorpusGenerator
    :members:

BatchDataSet
============

.. autoclass:: kashgari.generators.BatchDataSet
    :members:
--------------------------------------------------------------------------------
/docs/apis/labeling.rst:
--------------------------------------------------------------------------------
===============
Labeling Models
===============

.. contents:: Table of Contents

Bidirectional LSTM Model
------------------------

.. autoclass:: kashgari.tasks.labeling.BiLSTM_Model
    :members:
    :undoc-members:


Bidirectional GRU Model
-----------------------

.. autoclass:: kashgari.tasks.labeling.BiGRU_Model
    :members:
    :undoc-members:

Bidirectional LSTM CRF Model
----------------------------

.. autoclass:: kashgari.tasks.labeling.BiLSTM_CRF_Model
    :members:
    :undoc-members:

Bidirectional GRU CRF Model
---------------------------

.. autoclass:: kashgari.tasks.labeling.BiGRU_CRF_Model
    :members:
    :undoc-members:


CNN LSTM Model
--------------

.. autoclass:: kashgari.tasks.labeling.CNN_LSTM_Model
    :members:
    :undoc-members:
--------------------------------------------------------------------------------
/docs/apis/processors.rst:
--------------------------------------------------------------------------------
===============
Data Processors
===============

.. contents:: Table of Contents

SequenceProcessor
=================

.. autoclass:: kashgari.processors.SequenceProcessor

ClassificationProcessor
=======================

.. autoclass:: kashgari.processors.ClassificationProcessor
--------------------------------------------------------------------------------
/docs/embeddings/bare-embedding.rst:
--------------------------------------------------------------------------------

.. _bare-embedding:

Bare Embedding
==============

BareEmbedding is a randomly initialized ``tf.keras.layers.Embedding`` layer for text sequence embedding, and it is the default embedding class for Kashgari models.

.. autofunction:: kashgari.embeddings.BareEmbedding.__init__

Here is a sample of how to use the embedding class. The key difference here is that you must call the ``analyze_corpus`` function before using the ``embed`` function. This is because the embedding layer is not pre-trained and does not contain any word list, so we need to build the word list from the corpus.

.. code-block:: python

    import kashgari
    from kashgari.embeddings import BareEmbedding

    embedding = BareEmbedding(embedding_size=100)

    embedding.analyze_corpus(x_data, y_data)

    embed_tensor = embedding.embed_one(['语', '言', '模', '型'])
--------------------------------------------------------------------------------
/docs/embeddings/bert-embedding.rst:
--------------------------------------------------------------------------------

.. _bert-embedding:

Bert Embedding
==============

BertEmbedding is a simple wrapper around `Transformer Embedding <../transformer-embedding>`_. If you need to load another kind of transformer-based language model, please use `Transformer Embedding <../transformer-embedding>`_ directly.

.. note::
    When using a pre-trained embedding, remember to use the same tokenization tool as the embedding model; this allows you to access the full power of the embedding.
.. autofunction:: kashgari.embeddings.BertEmbedding.__init__

Example Usage - Text Classification
-----------------------------------

Let's run a text classification model with BERT.

.. code-block:: python

    sentences = [
        "Jim Henson was a puppeteer.",
        "This here's an example of using the BERT tokenizer.",
        "Why did the chicken cross the road?"
    ]
    labels = [
        "class1",
        "class2",
        "class1"
    ]
    ########## Load Bert Embedding ##########
    import os
    from kashgari.embeddings import BertEmbedding
    from kashgari.tokenizers import BertTokenizer

    bert_embedding = BertEmbedding('')

    tokenizer = BertTokenizer.load_from_vocab_file(os.path.join('', 'vocab_chinese.txt'))
    sentences_tokenized = [tokenizer.tokenize(s) for s in sentences]

    """
    The sentences will be tokenized into:
    [
        ['jim', 'henson', 'was', 'a', 'puppet', '##eer', '.'],
        ['this', 'here', "'", 's', 'an', 'example', 'of', 'using', 'the', 'bert', 'token', '##izer', '.'],
        ['why', 'did', 'the', 'chicken', 'cross', 'the', 'road', '?']
    ]
    """

    train_x, train_y = sentences_tokenized[:2], labels[:2]
    validate_x, validate_y = sentences_tokenized[2:], labels[2:]

    ########## build model ##########
    from kashgari.tasks.classification import CNN_LSTM_Model
    model = CNN_LSTM_Model(bert_embedding)

    ########## /build model ##########
    model.fit(
        train_x, train_y,
        validate_x, validate_y,
        epochs=3,
        batch_size=32
    )
    # save model
    model.save('path/to/save/model/to')

Use sentence pairs for input
----------------------------

Let's assume the input pair sample is ``"First do it" "then do it right"``. First tokenize the sentences using the bert tokenizer, then:

.. code-block:: python

    sentence1 = ['First', 'do', 'it']
    sentence2 = ['then', 'do', 'it', 'right']

    sample = sentence1 + ["[SEP]"] + sentence2
    # Add a special separation token `[SEP]` between the two sentences' tokens
    # to generate a new token list:
    # ['First', 'do', 'it', '[SEP]', 'then', 'do', 'it', 'right']

    train_x = [sample]
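From here, pair samples train like any other classification input. A minimal sketch under that assumption — the second sample and the ``pair_labels`` values are illustrative placeholders, not part of the original example:

.. code-block:: python

    # Hypothetical pair corpus: each sample is
    # sentence1 tokens + ["[SEP]"] + sentence2 tokens
    train_x = [
        ['First', 'do', 'it', '[SEP]', 'then', 'do', 'it', 'right'],
        ['First', 'do', 'it', '[SEP]', 'never', 'do', 'it', 'again']
    ]
    pair_labels = ['class1', 'class2']  # illustrative labels

    from kashgari.tasks.classification import CNN_LSTM_Model
    model = CNN_LSTM_Model(bert_embedding)  # the embedding loaded above
    model.fit(train_x, pair_labels)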
--------------------------------------------------------------------------------
/docs/embeddings/index.rst:
--------------------------------------------------------------------------------

Language Embeddings
===================

Kashgari provides several embeddings for language representation. Embedding layers convert an input sequence to a tensor for the downstream task. Available embeddings:

.. list-table::
   :header-rows: 1

   * - class name
     - description
   * - `BareEmbedding <./bare-embedding>`_
     - randomly initialized ``tf.keras.layers.Embedding`` layer for text sequence embedding
   * - `WordEmbedding <./word-embedding>`_
     - pre-trained Word2Vec embedding
   * - `BERTEmbedding <./bert-embedding>`_
     - pre-trained BERT embedding
   * - `TransformerEmbedding <./transformer-embedding>`_
     - pre-trained Transformer embedding (BERT, ALBERT, RoBERTa, NEZHA)


All embedding classes inherit from the ``Embedding`` class and implement ``embed()`` to embed your input sequence, as well as an ``embed_model`` property for building your own model. By providing the ``embed()`` function and the ``embed_model`` property, Kashgari hides the complexity of the different language embeddings from users; all you need to care about is which language embedding you need.

You can check out the Embedding API documentation `here <../apis/embeddings/>`_.

Quick start
-----------

Feature Extraction From a Pre-trained Embedding
-----------------------------------------------

Feature extraction is one of the major ways to use a pre-trained language embedding, and Kashgari provides a simple API for this task. All you need to do is initialize an embedding object, set up its pre-processor, and then call the ``embed`` function. Here is an example; all embeddings share the same ``embed`` API.

.. code-block:: python

    from kashgari.embeddings import BertEmbedding
    from kashgari.processors import SequenceProcessor

    bert = BertEmbedding('')
    processor = SequenceProcessor()
    bert.setup_text_processor(processor)
    # call for embed
    embed_tensor = bert.embed([['语', '言', '模', '型']])

    print(embed_tensor)
    # array([[-0.5001117 ,  0.9344998 , -0.55165815, ...,  0.49122602,
    #         -0.2049343 ,  0.25752577],
    #        [-1.05762   , -0.43353617,  0.54398274, ..., -0.61096823,
    #          0.04312163,  0.03881482],
    #        [ 0.14332692, -0.42566583,  0.68867105, ...,  0.42449307,
    #          0.41105768,  0.08222893],
    #        ...,
    #        [-0.86124015,  0.08591427, -0.34404194, ...,  0.19915134,
    #         -0.34176797,  0.06111742],
    #        [-0.73940575, -0.02692179, -0.5826528 , ...,  0.26934686,
    #         -0.29708537,  0.01855129],
    #        [-0.85489404,  0.007399  , -0.26482674, ...,  0.16851354,
    #         -0.36805922, -0.0052386 ]], dtype=float32)

Classification and Labeling
---------------------------

See details in the classification and labeling tutorials.

Customized model
----------------

You can access the underlying ``tf.keras`` model of an embedding and add your own layers or any other kind of customization; just access the ``embed_model`` property of the embedding object, as shown in the sketch below.
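Here is a minimal sketch of that workflow, assuming a ``BareEmbedding`` whose corpus has already been analyzed; ``x_data`` / ``y_data`` are placeholders for your corpus, and the extra layers and the 3-class output are illustrative choices, not part of Kashgari's API:

.. code-block:: python

    import tensorflow as tf
    from kashgari.embeddings import BareEmbedding

    embedding = BareEmbedding(embedding_size=100)
    embedding.analyze_corpus(x_data, y_data)  # x_data / y_data: your own corpus

    # ``embed_model`` is a plain tf.keras model, so you can stack layers on top of it
    embed_model = embedding.embed_model
    tensor = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64))(embed_model.output)
    output = tf.keras.layers.Dense(3, activation='softmax')(tensor)

    custom_model = tf.keras.Model(embed_model.inputs, output)
    custom_model.compile(loss='categorical_crossentropy', optimizer='adam')
    custom_model.summary()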
--------------------------------------------------------------------------------
/docs/embeddings/transformer-embedding.rst:
--------------------------------------------------------------------------------

.. _transformer-embedding:

Transformer Embedding
=====================

TransformerEmbedding is based on `bert4keras <https://github.com/bojone/bert4keras>`_. The embeddings themselves are wrapped into our simple embedding interface so that they can be used like any other embedding.

TransformerEmbedding supports these models:

.. list-table::
   :header-rows: 1

   * - Model
     - Author
     - Link
   * - BERT
     - Google
     - https://github.com/google-research/bert
   * - ALBERT
     - Google
     - https://github.com/google-research/ALBERT
   * - ALBERT
     - brightmart
     - https://github.com/brightmart/albert_zh
   * - RoBERTa
     - brightmart
     - https://github.com/brightmart/roberta_zh
   * - RoBERTa
     - 哈工大
     - https://github.com/ymcui/Chinese-BERT-wwm
   * - RoBERTa
     - 苏剑林
     - https://github.com/ZhuiyiTechnology/pretrained-models
   * - NEZHA
     - Huawei
     - https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/NEZHA

.. note::
    When using a pre-trained embedding, remember to use the same tokenization tool as the embedding model; this allows you to access the full power of the embedding.

.. autofunction:: kashgari.embeddings.TransformerEmbedding.__init__

Example Usage - Text Classification
-----------------------------------

Let's run a text classification model with ALBERT.

.. code-block:: python

    sentences = [
        "Jim Henson was a puppeteer.",
        "This here's an example of using the BERT tokenizer.",
        "Why did the chicken cross the road?"
    ]
    labels = [
        "class1",
        "class2",
        "class1"
    ]
    # ------------ Load Bert Embedding ------------
    import os
    from kashgari.embeddings import TransformerEmbedding
    from kashgari.tokenizers import BertTokenizer

    # Setup paths
    model_folder = '/xxx/xxx/albert_base'
    checkpoint_path = os.path.join(model_folder, 'model.ckpt-best')
    config_path = os.path.join(model_folder, 'albert_config.json')
    vocab_path = os.path.join(model_folder, 'vocab_chinese.txt')

    tokenizer = BertTokenizer.load_from_vocab_file(vocab_path)
    embed = TransformerEmbedding(vocab_path, config_path, checkpoint_path,
                                 bert_type='albert')

    sentences_tokenized = [tokenizer.tokenize(s) for s in sentences]
    """
    The sentences will be tokenized into:
    [
        ['jim', 'henson', 'was', 'a', 'puppet', '##eer', '.'],
        ['this', 'here', "'", 's', 'an', 'example', 'of', 'using', 'the', 'bert', 'token', '##izer', '.'],
        ['why', 'did', 'the', 'chicken', 'cross', 'the', 'road', '?']
    ]
    """

    train_x, train_y = sentences_tokenized[:2], labels[:2]
    validate_x, validate_y = sentences_tokenized[2:], labels[2:]

    # ------------ Build Model Start ------------
    from kashgari.tasks.classification import CNN_LSTM_Model
    model = CNN_LSTM_Model(embed)

    # ------------ Build Model End ------------

    model.fit(
        train_x, train_y,
        validate_x, validate_y,
        epochs=3,
        batch_size=32
    )
    # save model
    model.save('path/to/save/model/to')
--------------------------------------------------------------------------------
/docs/embeddings/word-embedding.rst:
--------------------------------------------------------------------------------

.. _word-embedding:

Word Embedding
==============

WordEmbedding is a ``tf.keras.layers.Embedding`` layer with pre-trained Word2Vec/GloVe embedding weights.

.. autofunction:: kashgari.embeddings.WordEmbedding.__init__
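Usage follows the same pattern as the other embeddings. A short sketch under these assumptions — the vectors path, ``train_x``, and ``train_y`` are placeholders, and the exact constructor arguments are documented in the autofunction above:

.. code-block:: python

    from kashgari.embeddings import WordEmbedding
    from kashgari.tasks.classification import BiLSTM_Model

    # placeholder path to your own pre-trained Word2Vec / GloVe vectors
    embedding = WordEmbedding('path/to/w2v.txt')

    # plug it into any task model as the embedding layer
    model = BiLSTM_Model(embedding)
    model.fit(train_x, train_y)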
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
.. include:: README.rst

.. toctree::
    :maxdepth: 1
    :caption: Overview

    README.rst

.. toctree::
    :maxdepth: 2
    :caption: Tutorials

    tutorial/text-classification
    tutorial/text-labeling
    tutorial/seq2seq

.. toctree::
    :maxdepth: 2
    :caption: Embeddings

    embeddings/index.rst
    embeddings/bare-embedding.rst
    embeddings/word-embedding.rst
    embeddings/bert-embedding.rst
    embeddings/transformer-embedding.rst

.. toctree::
    :maxdepth: 2
    :caption: Advanced Use Cases

    advance-use/tensorflow-serving.md

.. toctree::
    :maxdepth: 3
    :caption: API

    apis/corpus
    apis/embeddings
    apis/classification
    apis/labeling
    apis/generators
    apis/processors

.. toctree::
    :maxdepth: 2
    :caption: About

    about/contributing.md
    about/release-notes.md
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
# documents
sphinx
sphinx-autobuild
sphinx-rtd-theme
sphinx-markdown-tables
sphinx-autodoc-typehints
recommonmark
m2r

# dependence
numpy>=1.18.1
gensim>=3.8.1
pandas>=1.0.1
tqdm
bert4keras>=0.7.9
scikit-learn>=0.21.1
tensorflow>=2.1.0
--------------------------------------------------------------------------------
/docs/tutorial/seq2seq.md:
--------------------------------------------------------------------------------
# Seq2Seq Model

## Train a translation model

```python
# Original Corpus
x_original = [
    'Who am I?',
    'I am sick.',
    'I like you.',
    'I need help.',
    'It may hurt.',
    'Good morning.']

y_original = [
    'مەن كىم ؟',
    'مەن كېسەل.',
    'مەن سىزنى ياخشى كۆرمەن',
    'ماڭا ياردەم كېرەك.',
    'ئاغىرىشى مۇمكىن.',
    'خەيىرلىك ئەتىگەن.']

# Tokenize sentences with a custom tokenizing function
# We use the Bert Tokenizer for this demo
from kashgari.tokenizers import BertTokenizer
tokenizer = BertTokenizer()
x_tokenized = [tokenizer.tokenize(sample) for sample in x_original]
y_tokenized = [tokenizer.tokenize(sample) for sample in y_original]
```

After tokenizing the corpus, we can build a seq2seq model.

```python
from kashgari.tasks.seq2seq import Seq2Seq

model = Seq2Seq()
model.fit(x_tokenized, y_tokenized)

# predict with model
preds, attention = model.predict(x_tokenized)
print(preds)
```

## Train with custom embedding

You can define both the encoder's and the decoder's embeddings. This is how to use [Bert Embedding](./../embeddings/bert-embedding) as the encoder's embedding layer.
48 | 49 | ```python 50 | from kashgari.tasks.seq2seq import Seq2Seq 51 | from kashgari.embeddings import BertEmbedding 52 | 53 | bert = BertEmbedding('') 54 | model = Seq2Seq(encoder_embedding=bert, hidden_size=512) 55 | 56 | model.fit(x_tokenized, y_tokenized) 57 | ``` 58 | -------------------------------------------------------------------------------- /examples/benchmarks/benchmark_utils.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: benchmark_utils.py 8 | # time: 11:05 下午 9 | 10 | import os 11 | import json 12 | from typing import List, Dict 13 | 14 | 15 | class BenchMarkHelper: 16 | 17 | @classmethod 18 | def save_training_logs(cls, 19 | log_file: str, 20 | embedding_name: str, 21 | model_name: str, 22 | logs: List, 23 | **kwargs: Dict): 24 | 25 | if not os.path.exists(log_file): 26 | data = {} 27 | else: 28 | data = json.loads(open(log_file, 'r').read()) 29 | 30 | if embedding_name not in data: 31 | data[embedding_name] = {} 32 | 33 | data[embedding_name][model_name] = { 34 | 'logs': logs, 35 | **kwargs 36 | } 37 | 38 | with open(log_file, 'w') as f: 39 | f.write(json.dumps(data, indent=2)) 40 | 41 | 42 | if __name__ == "__main__": 43 | BenchMarkHelper.save_training_logs('./training.json', 44 | embedding_name='embed_name', 45 | model_name='model_name', 46 | logs={}, 47 | training_duration=321) 48 | -------------------------------------------------------------------------------- /examples/custom_generator.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: custom_generator.py 8 | # time: 4:13 下午 9 | 10 | import os 11 | import linecache 12 | from tensorflow.keras.utils import get_file 13 | from kashgari.generators import ABCGenerator 14 | 15 | 16 | def download_data(duplicate=1000): 17 | url_list = [ 18 | 'https://raw.githubusercontent.com/BrikerMan/JointSLU/master/data/atis-2.train.w-intent.iob', 19 | 'https://raw.githubusercontent.com/BrikerMan/JointSLU/master/data/atis-2.dev.w-intent.iob', 20 | 'https://raw.githubusercontent.com/BrikerMan/JointSLU/master/data/atis.test.w-intent.iob', 21 | 'https://raw.githubusercontent.com/BrikerMan/JointSLU/master/data/atis.train.w-intent.iob' 22 | ] 23 | files = [] 24 | for url in url_list: 25 | files.append(get_file(url.split('/')[-1], url)) 26 | 27 | return files * duplicate 28 | 29 | 30 | class ClassificationGenerator: 31 | def __init__(self, files): 32 | self.files = files 33 | self._line_count = sum(sum(1 for line in open(file, 'r')) for file in files) 34 | 35 | @property 36 | def steps(self) -> int: 37 | return self._line_count 38 | 39 | def __iter__(self): 40 | for file in self.files: 41 | with open(file, 'r') as f: 42 | for line in f: 43 | rows = line.split('\t') 44 | x = rows[0].strip().split(' ')[1:-1] 45 | y = rows[1].strip().split(' ')[-1] 46 | yield x, y 47 | 48 | 49 | class LabelingGenerator(ABCGenerator): 50 | def __init__(self, files): 51 | self.files = files 52 | self._line_count = sum(sum(1 for line in open(file, 'r')) for file in files) 53 | 54 | @property 55 | def steps(self) -> int: 56 | return self._line_count 57 | 58 | def __iter__(self): 59 | for file in self.files: 60 | with open(file, 'r') as f: 61 | for line in f: 62 | rows = line.split('\t') 63 | x = rows[0].strip().split(' ')[1:-1] 64 | y = 
rows[1].strip().split(' ')[1:-1] 65 | yield x, y 66 | 67 | 68 | def run_classification_model(): 69 | from kashgari.tasks.classification import BiGRU_Model 70 | files = download_data() 71 | gen = ClassificationGenerator(files) 72 | 73 | model = BiGRU_Model() 74 | model.fit_generator(gen) 75 | 76 | 77 | def run_labeling_model(): 78 | from kashgari.tasks.labeling import BiGRU_Model 79 | files = download_data() 80 | gen = LabelingGenerator(files) 81 | 82 | model = BiGRU_Model() 83 | model.fit_generator(gen) 84 | 85 | 86 | if __name__ == "__main__": 87 | run_classification_model() 88 | -------------------------------------------------------------------------------- /examples/k_fold_evaluation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Author : BrikerMan 4 | # Site : https://eliyar.biz 5 | 6 | # Time : 2020/9/3 7:23 下午 7 | # File : k_fold_evaluation.py 8 | # Project : Kashgari 9 | 10 | 11 | from sklearn.model_selection import StratifiedKFold 12 | import numpy as np 13 | from kashgari.corpus import SMP2018ECDTCorpus 14 | from kashgari.tasks.classification import BiLSTM_Model 15 | 16 | # fix random seed for reproducibility 17 | seed = 7 18 | np.random.seed(seed) 19 | 20 | # Combine all data for k-folding 21 | 22 | train_x, train_y = SMP2018ECDTCorpus.load_data('train') 23 | valid_x, valid_y = SMP2018ECDTCorpus.load_data('valid') 24 | test_x, test_y = SMP2018ECDTCorpus.load_data('test') 25 | 26 | X = train_x + valid_x + test_x 27 | Y = train_y + valid_y + test_y 28 | 29 | # define 10-fold cross validation test harness 30 | k_fold = StratifiedKFold(n_splits=10, shuffle=True, random_state=seed) 31 | scores = [] 32 | 33 | for train_indexs, test_indexs in k_fold.split(X, Y): 34 | train_x, train_y = [], [] 35 | test_x, test_y = [], [] 36 | 37 | for i in train_indexs: 38 | train_x.append(X[i]) 39 | train_y.append(Y[i]) 40 | 41 | assert len(train_x) == len(train_y) 42 | for i in test_indexs: 43 | test_x.append(X[i]) 44 | test_y.append(Y[i]) 45 | 46 | assert len(test_x) == len(test_y) 47 | model = BiLSTM_Model() 48 | model.fit(train_x, train_y, epochs=10) 49 | 50 | report = model.evaluate(test_x, test_y) 51 | # extract your target metric from report, for example f1 52 | scores.append(report['f1-score']) 53 | 54 | print(f"{np.mean(scores):.2f} (+/- {np.std(scores):.2f})") 55 | -------------------------------------------------------------------------------- /examples/tools.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Author : BrikerMan 4 | # Site : https://eliyar.biz 5 | 6 | # Time : 2020/8/29 11:11 上午 7 | # File : tools.py 8 | # Project : Kashgari 9 | 10 | import os 11 | import zipfile 12 | import pathlib 13 | from tensorflow.keras.utils import get_file 14 | from kashgari import macros as K 15 | 16 | 17 | def get_bert_path() -> str: 18 | url = "https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip" 19 | bert_path = os.path.join(K.DATA_PATH, 'datasets', 'bert') 20 | model_path = os.path.join(bert_path, 'chinese_L-12_H-768_A-12') 21 | pathlib.Path(bert_path).mkdir(parents=True, exist_ok=True) 22 | if not os.path.exists(model_path): 23 | zip_file_path = get_file("bert/chinese_L-12_H-768_A-12.zip", 24 | url, 25 | cache_dir=K.DATA_PATH, ) 26 | 27 | with zipfile.ZipFile(zip_file_path, 'r') as zip_ref: 28 | zip_ref.extractall(bert_path) 29 | return model_path 30 | 31 | 32 | if __name__ == 
'__main__': 33 | for k, v in os.environ.items(): 34 | print(f'{k:20}: {v}') 35 | get_bert_path() 36 | -------------------------------------------------------------------------------- /examples/web_qa_reading_comprehence.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Author : BrikerMan 4 | # Site : https://eliyar.biz 5 | 6 | # Time : 2020/9/2 5:51 下午 7 | # File : web_qa_reading_comprehence.py 8 | # Project : Kashgari 9 | 10 | import os 11 | import json 12 | from kashgari.tokenizers.bert_tokenizer import BertTokenizer 13 | from kashgari.embeddings import BertEmbedding 14 | from examples.tools import get_bert_path 15 | from kashgari.tasks.seq2seq.model import Seq2Seq 16 | import tensorflow as tf 17 | 18 | gpus = tf.config.experimental.list_physical_devices('GPU') 19 | if gpus: 20 | try: 21 | for gpu in gpus: 22 | tf.config.experimental.set_memory_growth(gpu, True) 23 | except RuntimeError as e: 24 | print(e) 25 | 26 | WEB_QA_PATH = '/home/brikerman/Downloads/SogouQA.json' 27 | Sogou_QA_PATH = '/home/brikerman/Downloads/SogouQA.json' 28 | 29 | with open(Sogou_QA_PATH, 'r') as f: 30 | corpus_data = json.loads(f.read()) 31 | 32 | bert_path = get_bert_path() 33 | tokenizer = BertTokenizer.load_from_vocab_file(os.path.join(bert_path, 'vocab.txt')) 34 | 35 | # Filter the data: keep only passages that contain an answer 36 | seps, strips = u'\n。!?!?;;,, ', u';;,, ' 37 | x_data = [] 38 | y_data = [] 39 | 40 | for d in corpus_data: 41 | for p in d['passages']: 42 | if p['answer']: 43 | x = tokenizer.tokenize(d['question']) + ['[SEP]'] + tokenizer.tokenize(p['passage']) 44 | x_data.append(x) 45 | y_data.append(tokenizer.tokenize(p['answer'])) 46 | 47 | print(x_data[:3]) 48 | print(y_data[:3]) 49 | 50 | bert = BertEmbedding(bert_path) 51 | model = Seq2Seq(encoder_seq_length=256) 52 | 53 | 54 | class CustomCallback(tf.keras.callbacks.Callback): 55 | def __init__(self, model): 56 | self.model = model 57 | self.sample_count = 5 58 | 59 | def on_epoch_end(self, epoch, logs=None): 60 | if epoch % 4 != 0: 61 | return 62 | import random 63 | samples = random.sample(x_data, self.sample_count) 64 | translates, _ = self.model.predict(samples) 65 | print() 66 | for index in range(len(samples)): 67 | print(f"X: {''.join(samples[index])}") 68 | print(f"Y: {''.join(translates[index])}") 69 | print('------------------------------') 70 | 71 | 72 | his_callback = CustomCallback(model) 73 | history = model.fit(x_data, 74 | y_data, 75 | callbacks=[his_callback], 76 | epochs=50, 77 | batch_size=16) 78 | 79 | if __name__ == '__main__': 80 | pass 81 | -------------------------------------------------------------------------------- /kashgari/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: BrikerMan 4 | @contact: eliyar917@gmail.com 5 | @blog: https://eliyar.biz 6 | 7 | @version: 1.0 8 | @license: Apache Licence 9 | @file: __init__.py 10 | @time: 2019-05-17 11:15 11 | 12 | """ 13 | 14 | import os 15 | from distutils.version import LooseVersion 16 | from typing import Any, Dict 17 | 18 | os.environ["TF_KERAS"] = "1" 19 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 20 | 21 | custom_objects: Dict[str, Any] = {} 22 | 23 | 24 | def check_tfa_version(tf_version: str) -> str: 25 | if LooseVersion(tf_version) < "2.2.0": 26 | return "0.9.1" 27 | elif LooseVersion(tf_version) < "2.3.0": 28 | return "0.11.2" 29 | else: 30 | return "0.13.0" 31 | 32 | 33 | def dependency_check() -> None: 34 | import tensorflow as
tf 35 | 36 | tfa_version = check_tfa_version(tf_version=tf.__version__) 37 | try: 38 | import tensorflow_addons as tfa 39 | except ImportError: 40 | raise ImportError( 41 | "Kashgari requires tensorflow_addons, please install it via " 42 | f"`pip install tensorflow_addons=={tfa_version}`" 43 | ) 44 | 45 | 46 | dependency_check() 47 | 48 | from kashgari import corpus, embeddings, layers, macros, processors, tasks, utils 49 | from kashgari.__version__ import __version__ 50 | from kashgari.macros import config 51 | 52 | custom_objects = layers.resigter_custom_layers(custom_objects) 53 | -------------------------------------------------------------------------------- /kashgari/__version__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: __version__.py 8 | # time: 2019-05-20 16:32 9 | 10 | __version__ = '2.0.2' 11 | -------------------------------------------------------------------------------- /kashgari/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: __init__.py 8 | # time: 8:08 下午 9 | 10 | 11 | from kashgari.callbacks.eval_callBack import EvalCallBack 12 | 13 | if __name__ == "__main__": 14 | pass 15 | -------------------------------------------------------------------------------- /kashgari/callbacks/eval_callBack.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: eval_callBack.py 8 | # time: 6:53 下午 9 | 10 | from typing import List, Any, Dict 11 | 12 | import tensorflow as tf 13 | from tensorflow import keras 14 | 15 | from kashgari.tasks.abs_task_model import ABCTaskModel 16 | 17 | 18 | class EvalCallBack(keras.callbacks.Callback): 19 | 20 | def __init__(self, 21 | kash_model: ABCTaskModel, 22 | x_data: List[Any], 23 | y_data: List[Any], 24 | *, 25 | step: int = 5, 26 | truncating: bool = False, 27 | batch_size: int = 256) -> None: 28 | """ 29 | Evaluation callback that calculates precision, recall and f1-score 30 | Args: 31 | kash_model: the kashgari task model to evaluate 32 | x_data: feature data for evaluation 33 | y_data: label data for evaluation 34 | step: evaluate every `step` epochs, default 5 35 | truncating: remove values from sequences larger than `model.embedding.sequence_length` 36 | batch_size: batch size, default 256 37 | """ 38 | super(EvalCallBack, self).__init__() 39 | self.kash_model: ABCTaskModel = kash_model 40 | self.x_data = x_data 41 | self.y_data = y_data 42 | self.step = step 43 | self.truncating = truncating 44 | self.batch_size = batch_size 45 | self.logs: List[Dict] = [] 46 | 47 | def on_epoch_end(self, epoch: int, logs: Any = None) -> None: 48 | if (epoch + 1) % self.step == 0: 49 | report = self.kash_model.evaluate(self.x_data, # type: ignore 50 | self.y_data, 51 | truncating=self.truncating, 52 | batch_size=self.batch_size) 53 | 54 | self.logs.append({ 55 | 'precision': report['precision'], 56 | 'recall': report['recall'], 57 | 'f1-score': report['f1-score'] 58 | }) 59 | 60 | tf.summary.scalar('eval f1-score', data=report['f1-score'], step=epoch) 61 | tf.summary.scalar('eval recall', data=report['recall'], step=epoch) 62 | tf.summary.scalar('eval precision',
data=report['precision'], step=epoch) 63 | print(f"\nepoch: {epoch} precision: {report['precision']:.6f}," 64 | f" recall: {report['recall']:.6f}, f1-score: {report['f1-score']:.6f}") 65 | -------------------------------------------------------------------------------- /kashgari/callbacks/save_callback.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from typing import Union, Any, AnyStr 4 | 5 | import tensorflow as tf 6 | from kashgari.tasks.abs_task_model import ABCTaskModel 7 | from kashgari.logger import logger 8 | 9 | 10 | class KashgariModelCheckpoint(tf.keras.callbacks.ModelCheckpoint): 11 | """Save the model after every epoch. 12 | Arguments: 13 | filepath: string, path to save the model file. 14 | monitor: quantity to monitor. 15 | verbose: verbosity mode, 0 or 1. 16 | save_best_only: if `save_best_only=True`, the latest best model according 17 | to the quantity monitored will not be overwritten. 18 | mode: one of {auto, min, max}. If `save_best_only=True`, the decision to 19 | overwrite the current save file is made based on either the maximization 20 | or the minimization of the monitored quantity. For `val_acc`, this 21 | should be `max`, for `val_loss` this should be `min`, etc. In `auto` 22 | mode, the direction is automatically inferred from the name of the 23 | monitored quantity. 24 | save_weights_only: if True, then only the model's weights will be saved 25 | (`model.save_weights(filepath)`), else the full model is saved 26 | (`model.save(filepath)`). 27 | save_freq: `'epoch'` or integer. When using `'epoch'`, the callback saves 28 | the model after each epoch. When using integer, the callback saves the 29 | model at end of a batch at which this many samples have been seen since 30 | last saving. Note that if the saving isn't aligned to epochs, the 31 | monitored metric may potentially be less reliable (it could reflect as 32 | little as 1 batch, since the metrics get reset every epoch). Defaults to 33 | `'epoch'` 34 | **kwargs: Additional arguments for backwards compatibility. Possible key 35 | is `period`. 36 | """ 37 | 38 | def __init__(self, 39 | filepath: AnyStr, 40 | monitor: str = 'val_loss', 41 | verbose: int = 1, 42 | save_best_only: bool = False, 43 | save_weights_only: bool = False, 44 | mode: str = 'auto', 45 | save_freq: Union[str, int] = 'epoch', 46 | kash_model: ABCTaskModel = None, 47 | **kwargs: Any) -> None: 48 | super(KashgariModelCheckpoint, self).__init__( 49 | filepath=filepath, 50 | monitor=monitor, 51 | verbose=verbose, 52 | save_best_only=save_best_only, 53 | save_weights_only=save_weights_only, 54 | mode=mode, 55 | save_freq=save_freq, 56 | **kwargs) 57 | self.kash_model = kash_model 58 | 59 | def _save_model(self, epoch: int, logs: dict) -> None: 60 | """Saves the model. 61 | Arguments: 62 | epoch: the epoch this iteration is in. 63 | logs: the `logs` dict passed in to `on_batch_end` or `on_epoch_end`. 
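Note: per the branches below, when `save_weights_only` is set the weights
are written to `<filepath>/cp` via `model.save_weights()`; otherwise the
full Kashgari model is saved with `kash_model.save(filepath)`.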
64 | """ 65 | logs = logs or {} 66 | 67 | if isinstance(self.save_freq, 68 | int) or self.epochs_since_last_save >= self.period: 69 | self.epochs_since_last_save: int = 0 70 | filepath = self._get_file_path(epoch, logs) 71 | 72 | if self.save_best_only: 73 | current = logs.get(self.monitor) 74 | if current is None: 75 | logger.warning('Can save best model only with %s available, skipping.', self.monitor) 76 | else: 77 | if self.monitor_op(current, self.best): 78 | if self.verbose > 0: 79 | print('\nEpoch %d: %s improved from %0.5f to %0.5f,' 80 | ' saving model to %s' % (epoch + 1, self.monitor, self.best, 81 | current, filepath)) 82 | self.best: float = current 83 | if self.save_weights_only: 84 | filepath = os.path.join(filepath, 'cp') 85 | self.model.save_weights(filepath, overwrite=True) 86 | logger.info(f'checkpoint saved to {filepath}') 87 | else: 88 | self.kash_model.save(filepath) 89 | else: 90 | if self.verbose > 0: 91 | print('\nEpoch %d: %s did not improve from %0.5f' % 92 | (epoch + 1, self.monitor, self.best)) 93 | else: 94 | if self.verbose > 0: 95 | print('\nEpoch %d: saving model to %s' % (epoch + 1, filepath)) 96 | if self.save_weights_only: 97 | filepath = os.path.join(filepath, 'cp') 98 | self.model.save_weights(filepath, overwrite=True) 99 | logger.info(f'checkpoint saved to {filepath}') 100 | else: 101 | self.kash_model.save(filepath) 102 | 103 | self._maybe_remove_file() 104 | -------------------------------------------------------------------------------- /kashgari/embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: __init__.py 8 | # time: 3:06 下午 9 | 10 | from .abc_embedding import ABCEmbedding 11 | from .bare_embedding import BareEmbedding 12 | from .bert_embedding import BertEmbedding 13 | from .transformer_embedding import TransformerEmbedding 14 | from .word_embedding import WordEmbedding 15 | 16 | if __name__ == "__main__": 17 | pass 18 | -------------------------------------------------------------------------------- /kashgari/embeddings/abc_embedding.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: abc_embedding.py 8 | # time: 2:43 下午 9 | 10 | import json 11 | from typing import Dict, List, Any, Optional, Union 12 | 13 | import numpy as np 14 | import tensorflow as tf 15 | import tqdm 16 | 17 | import kashgari 18 | from kashgari.generators import CorpusGenerator 19 | from kashgari.logger import logger 20 | from kashgari.processors import ABCProcessor 21 | 22 | L = tf.keras.layers 23 | 24 | 25 | class ABCEmbedding: 26 | def to_dict(self) -> Dict[str, Any]: 27 | config: Dict[str, Any] = { 28 | 'segment': self.segment, 29 | 'embedding_size': self.embedding_size, 30 | 'max_position': self.max_position, 31 | **self.kwargs 32 | } 33 | return { 34 | '__class_name__': self.__class__.__name__, 35 | '__module__': self.__class__.__module__, 36 | 'config': config, 37 | 'embed_model': json.loads(self.embed_model.to_json()) 38 | } 39 | 40 | def __init__(self, 41 | segment: bool = False, 42 | embedding_size: int = 100, 43 | max_position: int = None, 44 | **kwargs: Any): 45 | 46 | self.embed_model: tf.keras.Model = None 47 | 48 | self.segment: bool = segment # type: ignore 49 | self.kwargs = kwargs 50 | 51 | self.embedding_size: 
int = embedding_size # type: ignore 52 | self.max_position: int = max_position # type: ignore 53 | self.vocab2idx = self.load_embed_vocab() 54 | self._text_processor: Optional[ABCProcessor] = None 55 | 56 | def _override_load_model(self, config: Dict) -> None: 57 | embed_model_json_str = json.dumps(config['embed_model']) 58 | self.embed_model = tf.keras.models.model_from_json(embed_model_json_str, 59 | custom_objects=kashgari.custom_objects) 60 | 61 | def setup_text_processor(self, processor: ABCProcessor) -> None: 62 | self._text_processor = processor 63 | self.build_embedding_model(vocab_size=processor.vocab_size) 64 | self._text_processor.segment = self.segment 65 | if self.vocab2idx: 66 | self._text_processor.vocab2idx = self.vocab2idx 67 | self._text_processor.idx2vocab = dict([(v, k) for k, v in self.vocab2idx.items()]) 68 | 69 | def get_seq_length_from_corpus(self, 70 | generators: List[CorpusGenerator], 71 | *, 72 | use_label: bool = False, 73 | cover_rate: float = 0.95) -> int: 74 | """ 75 | Calculate a proper sequence length according to the corpus 76 | 77 | Args: 78 | generators: corpus generators to scan 79 | use_label: measure label sequences instead of input sentences 80 | cover_rate: fraction of samples the chosen length should fully cover 81 | 82 | Returns: 83 | the sequence length that covers `cover_rate` of the corpus 84 | """ 85 | seq_lens = [] 86 | for gen in generators: 87 | for sentence, label in tqdm.tqdm(gen, desc="Calculating sequence length"): 88 | if use_label: 89 | seq_lens.append(len(label)) 90 | else: 91 | seq_lens.append(len(sentence)) 92 | if cover_rate == 1.0: 93 | target_index = -1 94 | else: 95 | target_index = int(cover_rate * len(seq_lens)) 96 | sequence_length = sorted(seq_lens)[target_index] 97 | logger.debug(f'Calculated sequence length = {sequence_length}') 98 | return sequence_length 99 | 100 | def load_embed_vocab(self) -> Optional[Dict[str, int]]: 101 | """ 102 | Load vocab dict from embedding layer 103 | 104 | Returns: 105 | vocab dict or None 106 | """ 107 | raise NotImplementedError 108 | 109 | def build_embedding_model(self, 110 | *, 111 | vocab_size: int = None, 112 | force: bool = False, 113 | **kwargs: Dict) -> None: 114 | raise NotImplementedError 115 | 116 | def embed(self, 117 | sentences: List[List[str]], 118 | *, 119 | debug: bool = False) -> np.ndarray: 120 | """ 121 | batch embed sentences 122 | 123 | Args: 124 | sentences: Sentence list to embed 125 | debug: show debug info 126 | Returns: 127 | vectorized sentence list 128 | """ 129 | if self._text_processor is None: 130 | raise ValueError('Call `embedding.setup_text_processor` before calling the embed function.') 131 | 132 | tensor_x = self._text_processor.transform(sentences, 133 | segment=self.segment, 134 | seq_length=self.max_position) 135 | if debug: 136 | logger.debug(f'sentence tensor: {tensor_x}') 137 | embed_results = self.embed_model.predict(tensor_x) 138 | return embed_results 139 | 140 | 141 | if __name__ == "__main__": 142 | pass 143 | -------------------------------------------------------------------------------- /kashgari/embeddings/bare_embedding.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: bare_embedding.py 8 | # time: 2:17 下午 9 | 10 | from typing import Dict, Any, Optional 11 | 12 | from tensorflow import keras 13 | 14 | from kashgari.embeddings.abc_embedding import ABCEmbedding 15 | 16 | L = keras.layers 17 | 18 | 19 | class BareEmbedding(ABCEmbedding): 20 | """ 21 | BareEmbedding is a randomly initialized `tf.keras.layers.Embedding` layer for text sequence embedding, 22 | which
is the default embedding class for kashgari models. 23 | """ 24 | 25 | def __init__(self, 26 | embedding_size: int = 100, 27 | **kwargs: Any): 28 | """ 29 | 30 | Args: 31 | embedding_size: Dimension of the dense embedding. 32 | kwargs: additional params 33 | """ 34 | self.embedding_size: int = embedding_size 35 | super(BareEmbedding, self).__init__(embedding_size=embedding_size, 36 | **kwargs) 37 | 38 | def load_embed_vocab(self) -> Optional[Dict[str, int]]: 39 | return None 40 | 41 | def build_embedding_model(self, 42 | *, 43 | vocab_size: int = None, 44 | force: bool = False, 45 | **kwargs: Dict) -> None: 46 | if self.embed_model is None or force: 47 | input_tensor = L.Input(shape=(None,), 48 | name=f'input') 49 | layer_embedding = L.Embedding(vocab_size, 50 | self.embedding_size, 51 | mask_zero=True, 52 | name=f'layer_embedding') 53 | 54 | embedded_tensor = layer_embedding(input_tensor) 55 | self.embed_model = keras.Model(input_tensor, embedded_tensor) 56 | 57 | 58 | if __name__ == "__main__": 59 | pass 60 | -------------------------------------------------------------------------------- /kashgari/embeddings/bert_embedding.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: bert_embedding.py 8 | # time: 2:49 下午 9 | 10 | import os 11 | from typing import Dict, Any 12 | 13 | from kashgari.embeddings.transformer_embedding import TransformerEmbedding 14 | 15 | 16 | class BertEmbedding(TransformerEmbedding): 17 | """ 18 | BertEmbedding is a simple wrapper class around TransformerEmbedding. 19 | If you need to load another kind of transformer-based language model, please use TransformerEmbedding directly. 20 | """ 21 | 22 | def to_dict(self) -> Dict[str, Any]: 23 | info_dic = super(BertEmbedding, self).to_dict() 24 | info_dic['config']['model_folder'] = self.model_folder 25 | return info_dic 26 | 27 | def __init__(self, 28 | model_folder: str, 29 | **kwargs: Any): 30 | """ 31 | 32 | Args: 33 | model_folder: path of checkpoint folder. 34 | kwargs: additional params 35 | """ 36 | self.model_folder = model_folder 37 | vocab_path = os.path.join(self.model_folder, 'vocab.txt') 38 | config_path = os.path.join(self.model_folder, 'bert_config.json') 39 | checkpoint_path = os.path.join(self.model_folder, 'bert_model.ckpt') 40 | kwargs['vocab_path'] = vocab_path 41 | kwargs['config_path'] = config_path 42 | kwargs['checkpoint_path'] = checkpoint_path 43 | kwargs['model_type'] = 'bert' 44 | super(BertEmbedding, self).__init__(**kwargs) 45 | 46 | 47 | if __name__ == "__main__": 48 | pass 49 | -------------------------------------------------------------------------------- /kashgari/embeddings/transformer_embedding.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: transformer_embedding.py 8 | # time: 11:41 上午 9 | 10 | import codecs 11 | import json 12 | from typing import Dict, List, Any, Optional 13 | 14 | from bert4keras.models import build_transformer_model 15 | 16 | from kashgari.embeddings.abc_embedding import ABCEmbedding 17 | from kashgari.logger import logger 18 | 19 | 20 | class TransformerEmbedding(ABCEmbedding): 21 | """ 22 | TransformerEmbedding is based on bert4keras.
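A minimal usage sketch (the folder and file names are placeholders for
whichever checkpoint you have downloaded; the arguments match the
constructor below):

>>> embed = TransformerEmbedding(vocab_path='<model_folder>/vocab.txt',
...                              config_path='<model_folder>/bert_config.json',
...                              checkpoint_path='<model_folder>/bert_model.ckpt',
...                              model_type='bert')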
23 | The embedding itself is wrapped into our simple embedding interface so that it can be used like any other embedding. 24 | """ 25 | def to_dict(self) -> Dict[str, Any]: 26 | info_dic = super(TransformerEmbedding, self).to_dict() 27 | info_dic['config']['vocab_path'] = self.vocab_path 28 | info_dic['config']['config_path'] = self.config_path 29 | info_dic['config']['checkpoint_path'] = self.checkpoint_path 30 | info_dic['config']['model_type'] = self.model_type 31 | return info_dic 32 | 33 | def __init__(self, 34 | vocab_path: str, 35 | config_path: str, 36 | checkpoint_path: str, 37 | model_type: str = 'bert', 38 | **kwargs: Any): 39 | """ 40 | 41 | Args: 42 | vocab_path: vocab file path, example `vocab.txt` 43 | config_path: model config path, example `config.json` 44 | checkpoint_path: model weight path, example `model.ckpt-100000` 45 | model_type: transformer model type, one of {bert, albert, nezha, gpt2_ml, t5} 46 | kwargs: additional params 47 | """ 48 | self.vocab_path = vocab_path 49 | self.config_path = config_path 50 | self.checkpoint_path = checkpoint_path 51 | self.model_type = model_type 52 | self.vocab_list: List[str] = [] 53 | kwargs['segment'] = True 54 | super(TransformerEmbedding, self).__init__(**kwargs) 55 | 56 | def load_embed_vocab(self) -> Optional[Dict[str, int]]: 57 | token2idx: Dict[str, int] = {} 58 | with codecs.open(self.vocab_path, 'r', 'utf8') as reader: 59 | for line in reader: 60 | token = line.strip() 61 | self.vocab_list.append(token) 62 | token2idx[token] = len(token2idx) 63 | top_words = [k for k, v in list(token2idx.items())[:50]] 64 | logger.debug('------------------------------------------------') 65 | logger.debug("Loaded transformer model's vocab") 66 | logger.debug(f'config_path : {self.config_path}') 67 | logger.debug(f'vocab_path : {self.vocab_path}') 68 | logger.debug(f'checkpoint_path : {self.checkpoint_path}') 69 | logger.debug(f'Top 50 words : {top_words}') 70 | logger.debug('------------------------------------------------') 71 | 72 | return token2idx 73 | 74 | def build_embedding_model(self, 75 | *, 76 | vocab_size: int = None, 77 | force: bool = False, 78 | **kwargs: Dict) -> None: 79 | if self.embed_model is None: 80 | config_path = self.config_path 81 | with open(config_path, 'r') as f: 82 | config = json.loads(f.read()) 83 | if 'max_position' in config: 84 | self.max_position = config['max_position'] 85 | else: 86 | self.max_position = config.get('max_position_embeddings') 87 | 88 | bert_model = build_transformer_model(config_path=self.config_path, 89 | checkpoint_path=self.checkpoint_path, 90 | model=self.model_type, 91 | application='encoder', 92 | return_keras_model=True) 93 | for layer in bert_model.layers: 94 | layer.trainable = False 95 | self.embed_model = bert_model 96 | self.embedding_size = bert_model.output.shape[-1] 97 | 98 | 99 | if __name__ == "__main__": 100 | pass 101 | -------------------------------------------------------------------------------- /kashgari/embeddings/word_embedding.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: word_embedding.py 8 | # time: 3:06 下午 9 | 10 | from typing import Dict, Any, Optional 11 | 12 | import numpy as np 13 | from gensim.models import KeyedVectors 14 | from tensorflow import keras 15 | 16 | from kashgari.embeddings.abc_embedding import ABCEmbedding 17 | from kashgari.logger import logger 18 | 19 | L = keras.layers 20
| 21 | 22 | class WordEmbedding(ABCEmbedding): 23 | def to_dict(self) -> Dict[str, Any]: 24 | info_dic = super(WordEmbedding, self).to_dict() 25 | info_dic['config']['w2v_path'] = self.w2v_path 26 | info_dic['config']['w2v_kwargs'] = self.w2v_kwargs 27 | return info_dic 28 | 29 | def __init__(self, 30 | w2v_path: str, 31 | *, 32 | w2v_kwargs: Dict[str, Any] = None, 33 | **kwargs: Any): 34 | """ 35 | Args: 36 | w2v_path: Word2Vec file path. 37 | w2v_kwargs: params passed to the ``load_word2vec_format()`` function of ``gensim.models.KeyedVectors`` 38 | kwargs: additional params 39 | """ 40 | if w2v_kwargs is None: 41 | w2v_kwargs = {} 42 | 43 | self.w2v_path = w2v_path 44 | self.w2v_kwargs = w2v_kwargs 45 | 46 | self.embedding_size = None 47 | self.w2v_matrix: np.ndarray = None 48 | 49 | super(WordEmbedding, self).__init__(**kwargs) 50 | 51 | def load_embed_vocab(self) -> Optional[Dict[str, int]]: 52 | w2v = KeyedVectors.load_word2vec_format(self.w2v_path, **self.w2v_kwargs) 53 | 54 | token2idx = { 55 | '[PAD]': 0, 56 | '[UNK]': 1, 57 | '[BOS]': 2, 58 | '[EOS]': 3 59 | } 60 | 61 | for token in w2v.index2word: 62 | token2idx[token] = len(token2idx) 63 | 64 | vector_matrix = np.zeros((len(token2idx), w2v.vector_size)) 65 | vector_matrix[1] = np.random.rand(w2v.vector_size) 66 | vector_matrix[4:] = w2v.vectors 67 | 68 | self.embedding_size = w2v.vector_size 69 | self.w2v_matrix = vector_matrix 70 | w2v_top_words = w2v.index2entity[:50] 71 | 72 | logger.debug('------------------------------------------------') 73 | logger.debug("Loaded gensim word2vec model's vocab") 74 | logger.debug('model : {}'.format(self.w2v_path)) 75 | logger.debug('word count : {}'.format(len(self.w2v_matrix))) 76 | logger.debug('Top 50 words : {}'.format(w2v_top_words)) 77 | logger.debug('------------------------------------------------') 78 | 79 | return token2idx 80 | 81 | def build_embedding_model(self, 82 | *, 83 | vocab_size: int = None, 84 | force: bool = False, 85 | **kwargs: Dict) -> None: 86 | if self.embed_model is None: 87 | input_tensor = L.Input(shape=(None,), 88 | name=f'input') 89 | layer_embedding = L.Embedding(len(self.vocab2idx), 90 | self.embedding_size, 91 | weights=[self.w2v_matrix], 92 | trainable=False, 93 | mask_zero=True, 94 | name=f'layer_embedding') 95 | 96 | embedded_tensor = layer_embedding(input_tensor) 97 | self.embed_model = keras.Model(input_tensor, embedded_tensor) 98 | 99 | 100 | 101 | if __name__ == "__main__": 102 | pass 103 | -------------------------------------------------------------------------------- /kashgari/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: __init__.py 8 | # time: 7:39 下午 9 | 10 | from typing import Dict, Any 11 | from tensorflow import keras 12 | 13 | from .conditional_random_field import KConditionalRandomField 14 | from .behdanau_attention import BahdanauAttention # type: ignore 15 | 16 | L = keras.layers 17 | L.BahdanauAttention = BahdanauAttention 18 | L.KConditionalRandomField = KConditionalRandomField 19 | 20 | 21 | def resigter_custom_layers(custom_objects: Dict[str, Any]) -> Dict[str, Any]: 22 | custom_objects['KConditionalRandomField'] = KConditionalRandomField 23 | custom_objects['BahdanauAttention'] = BahdanauAttention 24 | return custom_objects 25 | 26 | 27 | if __name__ == "__main__": 28 | pass 29 |
-------------------------------------------------------------------------------- /kashgari/layers/behdanau_attention.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: behdanau_attention.py 8 | # time: 2:31 下午 9 | 10 | # type: ignore 11 | 12 | import tensorflow as tf 13 | from tensorflow.python.util.tf_export import keras_export 14 | 15 | 16 | @keras_export('keras.layers.BahdanauAttention') 17 | class BahdanauAttention(tf.keras.layers.Layer): 18 | def __init__(self, units): 19 | super(BahdanauAttention, self).__init__() 20 | self.W1 = tf.keras.layers.Dense(units) 21 | self.W2 = tf.keras.layers.Dense(units) 22 | self.V = tf.keras.layers.Dense(1) 23 | 24 | def call(self, query, values): 25 | # query hidden state shape == (batch_size, hidden size) 26 | # query_with_time_axis shape == (batch_size, 1, hidden size) 27 | # values shape == (batch_size, max_len, hidden size) 28 | # we are doing this to broadcast addition along the time axis to calculate the score 29 | query_with_time_axis = tf.expand_dims(query, 1) 30 | 31 | # score shape == (batch_size, max_length, 1) 32 | # we get 1 at the last axis because we are applying score to self.V 33 | # the shape of the tensor before applying self.V is (batch_size, max_length, units) 34 | score = self.V(tf.nn.tanh( 35 | self.W1(query_with_time_axis) + self.W2(values))) 36 | 37 | # attention_weights shape == (batch_size, max_length, 1) 38 | attention_weights = tf.nn.softmax(score, axis=1) 39 | 40 | # context_vector shape after sum == (batch_size, hidden_size) 41 | context_vector = attention_weights * values 42 | context_vector = tf.reduce_sum(context_vector, axis=1) 43 | 44 | return context_vector, attention_weights 45 | 46 | 47 | if __name__ == "__main__": 48 | pass 49 | -------------------------------------------------------------------------------- /kashgari/layers/conditional_random_field.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Author : BrikerMan 4 | # Site : https://eliyar.biz 5 | 6 | # Time : 2020/9/2 9:19 下午 7 | # File : conditional_random_field.py 8 | # Project : Kashgari 9 | 10 | # mypy: ignore-errors 11 | 12 | from distutils.version import LooseVersion 13 | 14 | import tensorflow as tf 15 | import tensorflow.keras.backend as K 16 | import tensorflow_addons as tfa 17 | 18 | 19 | class KConditionalRandomField(tf.keras.layers.Layer): 20 | """ 21 | K is to mark Kashgari version of CRF 22 | Conditional Random Field layer (tf.keras) 23 | `CRF` can be used as the last layer in a network (as a classifier). Input shape (features) 24 | must be equal to the number of classes the CRF can predict (a linear layer is recommended). 25 | 26 | Args: 27 | num_labels (int): the number of labels to tag each temporal input. 28 | 29 | Input shape: 30 | nD tensor with shape `(batch_size, sentence length, num_classes)`. 31 | 32 | Output shape: 33 | nD tensor with shape: `(batch_size, sentence length, num_classes)`. 34 | 35 | Masking 36 | This layer supports keras masking for input data with a variable number 37 | of timesteps. 
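(When a mask is present, `call` uses it to compute the per-sample sequence
lengths that `tfa.text.crf_decode` needs.)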
To introduce masks to your data, 38 | use an embedding layer with the `mask_zero` parameter 39 | set to `True`, or add a Masking layer before this layer 40 | """ 41 | 42 | def __init__(self, 43 | sparse_target=True, 44 | **kwargs): 45 | if LooseVersion(tf.__version__) < '2.2.0': 46 | raise ImportError("KConditionalRandomField requires TensorFlow 2.2.0 or higher.") 47 | 48 | super().__init__() 49 | self.transitions = kwargs.pop('transitions', None) 50 | self.output_dim = kwargs.pop('output_dim', None) 51 | self.sparse_target = sparse_target 52 | self.sequence_lengths = None 53 | self.mask = None 54 | 55 | def get_config(self): 56 | config = { 57 | "output_dim": self.output_dim, 58 | "transitions": K.eval(self.transitions), 59 | } 60 | base_config = super().get_config() 61 | return dict(**base_config, **config) 62 | 63 | def build(self, input_shape): 64 | self.output_dim = input_shape[-1] 65 | assert len(input_shape) == 3 66 | self.transitions = self.add_weight( 67 | name="transitions", 68 | shape=[input_shape[-1], input_shape[-1]], 69 | initializer="glorot_uniform", 70 | trainable=True 71 | ) 72 | 73 | def call(self, inputs, mask=None, **kwargs): 74 | if mask is not None: 75 | self.sequence_lengths = K.sum(K.cast(mask, 'int32'), axis=-1) 76 | self.mask = mask 77 | else: 78 | self.sequence_lengths = K.sum(K.ones_like(inputs[:, :, 0], dtype='int32'), axis=-1) 79 | viterbi_sequence, _ = tfa.text.crf_decode( 80 | inputs, self.transitions, self.sequence_lengths 81 | ) 82 | output = K.cast(K.one_hot(viterbi_sequence, inputs.shape[-1]), inputs.dtype) 83 | return K.in_train_phase(inputs, output) 84 | 85 | def loss(self, y_true, y_pred): 86 | if len(K.int_shape(y_true)) == 3: 87 | y_true = K.argmax(y_true, axis=-1) 88 | log_likelihood, self.transitions = tfa.text.crf_log_likelihood( 89 | y_pred, 90 | y_true, 91 | self.sequence_lengths, 92 | transition_params=self.transitions, 93 | ) 94 | return tf.reduce_mean(-log_likelihood) 95 | 96 | def compute_output_shape(self, input_shape): 97 | return input_shape[:2] + (self.output_dim,) 98 | 99 | # Use CRF decode to estimate accuracy 100 | def accuracy(self, y_true, y_pred): 101 | mask = self.mask 102 | if len(K.int_shape(y_true)) == 3: 103 | y_true = K.argmax(y_true, axis=-1) 104 | 105 | y_pred, _ = tfa.text.crf_decode( 106 | y_pred, self.transitions, self.sequence_lengths 107 | ) 108 | y_true = K.cast(y_true, y_pred.dtype) 109 | is_equal = K.equal(y_true, y_pred) 110 | is_equal = K.cast(is_equal, y_pred.dtype) 111 | if mask is None: 112 | return K.mean(is_equal) 113 | else: 114 | mask = K.cast(mask, y_pred.dtype) 115 | return K.sum(is_equal * mask) / K.sum(mask) 116 | 117 | # Use argmax to estimate accuracy 118 | def fast_accuracy(self, y_true, y_pred): 119 | mask = self.mask 120 | if len(K.int_shape(y_true)) == 3: 121 | y_true = K.argmax(y_true, axis=-1) 122 | y_pred = K.argmax(y_pred, -1) 123 | y_true = K.cast(y_true, y_pred.dtype) 124 | # Take the per-label argmax to roughly estimate training accuracy 125 | isequal = K.equal(y_true, y_pred) 126 | isequal = K.cast(isequal, y_pred.dtype) 127 | if mask is None: 128 | return K.mean(isequal) 129 | else: 130 | mask = K.cast(mask, y_pred.dtype) 131 | return K.sum(isequal * mask) / K.sum(mask) 132 | -------------------------------------------------------------------------------- /kashgari/logger.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: logger.py 8 | # time: 11:43 下午 9 | 10 |
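# Shared module-level logger for the whole package. Setting the environment
# variable KASHGARI_DEV=True switches to a verbose log format that also
# includes the source file name and line number.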
import os 11 | import logging 12 | 13 | logger = logging.Logger('kashgari', level='DEBUG') 14 | stream_handler = logging.StreamHandler() 15 | 16 | if os.environ.get('KASHGARI_DEV') == 'True': 17 | log_format = '%(asctime)s [%(levelname)s] %(name)s:%(filename)s:%(lineno)d - %(message)s' 18 | else: 19 | log_format = '%(asctime)s [%(levelname)s] %(name)s - %(message)s' 20 | 21 | stream_handler.setFormatter(logging.Formatter(log_format)) 22 | logger.addHandler(stream_handler) 23 | 24 | if __name__ == "__main__": 25 | pass 26 | -------------------------------------------------------------------------------- /kashgari/macros.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: macros.py 8 | # time: 12:37 下午 9 | 10 | import os 11 | from pathlib import Path 12 | from typing import Dict 13 | 14 | DATA_PATH = os.path.join(str(Path.home()), '.kashgari') 15 | 16 | Path(DATA_PATH).mkdir(exist_ok=True, parents=True) 17 | 18 | 19 | class Config: 20 | 21 | def __init__(self) -> None: 22 | self.verbose = False 23 | 24 | def to_dict(self) -> Dict: 25 | return { 26 | 'verbose': self.verbose 27 | } 28 | 29 | 30 | config = Config() 31 | 32 | if __name__ == "__main__": 33 | pass 34 | -------------------------------------------------------------------------------- /kashgari/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: __init__.py 8 | # time: 10:44 下午 9 | 10 | from kashgari.metrics.multi_label_classification import multi_label_classification_report 11 | from kashgari.metrics.sequence_labeling import sequence_labeling_report 12 | 13 | if __name__ == "__main__": 14 | pass 15 | -------------------------------------------------------------------------------- /kashgari/metrics/multi_label_classification.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: multi_label_classification.py 8 | # time: 6:33 下午 9 | 10 | from typing import Dict, Any, TYPE_CHECKING 11 | 12 | import numpy as np 13 | from sklearn import metrics 14 | 15 | from kashgari.types import MultiLabelClassificationLabelVar 16 | 17 | if TYPE_CHECKING: 18 | from kashgari.utils import MultiLabelBinarizer 19 | 20 | 21 | def multi_label_classification_report(y_true: MultiLabelClassificationLabelVar, 22 | y_pred: MultiLabelClassificationLabelVar, 23 | *, 24 | binarizer: 'MultiLabelBinarizer', 25 | digits: int = 4, 26 | verbose: int = 1) -> Dict[str, Any]: 27 | y_pred_b = binarizer.transform(y_pred) 28 | y_true_b = binarizer.transform(y_true) 29 | 30 | report_dic: Dict = {} 31 | details: Dict = {} 32 | 33 | rows = [] 34 | ps, rs, f1s, s = [], [], [], [] 35 | for c_index, c in enumerate(binarizer.classes): 36 | precision = metrics.precision_score(y_true_b[:, c_index], y_pred_b[:, c_index]) 37 | recall = metrics.recall_score(y_true_b[:, c_index], y_pred_b[:, c_index]) 38 | f1 = metrics.f1_score(y_true_b[:, c_index], y_pred_b[:, c_index]) 39 | support = len(np.where(y_true_b[:, c_index] == 1)[0]) 40 | details[c] = { 41 | 'precision': precision, 42 | 'recall': recall, 43 | 'f1': f1, 44 | 'support': support 45 | } 46 | 47 | rows.append((c, precision, recall, 
f1, support)) 48 | ps.append(precision) 49 | rs.append(recall) 50 | f1s.append(f1) 51 | s.append(support) 52 | 53 | report_dic['precision'] = np.average(ps, weights=s) 54 | report_dic['recall'] = np.average(rs, weights=s) 55 | report_dic['f1-score'] = np.average(f1s, weights=s) 56 | report_dic['support'] = np.sum(s) 57 | 58 | headers = ["precision", "recall", "f1-score", "support"] 59 | head_fmt = '{:>{width}s} ' + ' {:>9}' * len(headers) + '\n' 60 | 61 | report = head_fmt.format('', *headers, width=20) 62 | 63 | row_fmt = '{:>{width}s} {:>9.{digits}f} {:>9.{digits}f} {:>9.{digits}f} {:>9}\n' 64 | 65 | for row in rows: 66 | report += row_fmt.format(*row, width=20, digits=digits) 67 | 68 | # compute averages 69 | report += row_fmt.format('macro avg', 70 | np.average(ps, weights=s), 71 | np.average(rs, weights=s), 72 | np.average(f1s, weights=s), 73 | np.sum(s), 74 | width=20, digits=digits) 75 | 76 | report_dic['detail'] = details 77 | print(report) 78 | 79 | return report_dic 80 | 81 | 82 | if __name__ == "__main__": 83 | pass 84 | -------------------------------------------------------------------------------- /kashgari/processors/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: __init__.py 8 | # time: 12:27 下午 9 | 10 | from .abc_processor import ABCProcessor 11 | from .class_processor import ClassificationProcessor 12 | from .sequence_processor import SequenceProcessor 13 | 14 | from .tools import load_processors_from_model 15 | 16 | if __name__ == "__main__": 17 | pass 18 | -------------------------------------------------------------------------------- /kashgari/processors/abc_processor.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: abs_processor.py 8 | # time: 2:53 下午 9 | 10 | from abc import ABC 11 | from typing import Dict, List, Optional, Any, Tuple 12 | 13 | import numpy as np 14 | 15 | from kashgari.generators import CorpusGenerator 16 | from kashgari.types import TextSamplesVar 17 | 18 | 19 | class ABCProcessor(ABC): 20 | def to_dict(self) -> Dict[str, Any]: 21 | return { 22 | 'config': { 23 | 'token_pad': self.token_pad, 24 | 'token_unk': self.token_unk, 25 | 'token_bos': self.token_bos, 26 | 'token_eos': self.token_eos, 27 | 'vocab2idx': self.vocab2idx, 28 | 'segment': self.segment 29 | }, 30 | '__class_name__': self.__class__.__name__, 31 | '__module__': self.__class__.__module__, 32 | } 33 | 34 | def __init__(self, **kwargs: Any) -> None: 35 | self.vocab2idx = kwargs.get('vocab2idx', {}) 36 | self.idx2vocab = dict([(v, k) for k, v in self.vocab2idx.items()]) 37 | 38 | self.segment = False 39 | 40 | self.token_pad: str = kwargs.get('token_pad', '[PAD]') # type: ignore 41 | self.token_unk: str = kwargs.get('token_unk', '[UNK]') # type: ignore 42 | self.token_bos: str = kwargs.get('token_bos', '[CLS]') # type: ignore 43 | self.token_eos: str = kwargs.get('token_eos', '[SEP]') # type: ignore 44 | 45 | self._sequence_length_from_saved_model: Optional[int] = None 46 | 47 | @property 48 | def vocab_size(self) -> int: 49 | return len(self.vocab2idx) 50 | 51 | @property 52 | def is_vocab_build(self) -> bool: 53 | return self.vocab_size != 0 54 | 55 | def build_vocab(self, 56 | x_data: TextSamplesVar, 57 | y_data: TextSamplesVar) -> None: 58 | 
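"""
Build the vocabulary from raw samples by wrapping them in a
CorpusGenerator and delegating to `build_vocab_generator`.
"""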
corpus_gen = CorpusGenerator(x_data, y_data) 59 | self.build_vocab_generator([corpus_gen]) 60 | 61 | def build_vocab_generator(self, 62 | generators: List[CorpusGenerator]) -> None: 63 | raise NotImplementedError 64 | 65 | def get_tensor_shape(self, batch_size: int, seq_length: int) -> Tuple: 66 | if self.segment: 67 | return 2, batch_size, seq_length 68 | else: 69 | return batch_size, seq_length 70 | 71 | def transform(self, 72 | samples: TextSamplesVar, 73 | *, 74 | seq_length: int = None, 75 | max_position: int = None, 76 | segment: bool = False) -> np.ndarray: 77 | raise NotImplementedError 78 | 79 | def inverse_transform(self, 80 | labels: List[int], 81 | *, 82 | lengths: List[int] = None, 83 | threshold: float = 0.5, 84 | **kwargs: Any) -> List[str]: 85 | raise NotImplementedError 86 | 87 | 88 | if __name__ == "__main__": 89 | pass 90 | -------------------------------------------------------------------------------- /kashgari/processors/class_processor.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: label_processor.py 8 | # time: 2:53 下午 9 | 10 | import collections 11 | import operator 12 | from typing import List, Union, Dict, Optional, Any, Tuple 13 | 14 | import numpy as np 15 | import tqdm 16 | 17 | from kashgari.generators import CorpusGenerator 18 | from kashgari.processors.abc_processor import ABCProcessor 19 | from kashgari.types import TextSamplesVar 20 | 21 | 22 | class ClassificationProcessor(ABCProcessor): 23 | 24 | def to_dict(self) -> Dict[str, Any]: 25 | data = super(ClassificationProcessor, self).to_dict() 26 | data['config']['multi_label'] = self.multi_label 27 | return data 28 | 29 | def __init__(self, 30 | multi_label: bool = False, 31 | **kwargs: Any) -> None: 32 | from kashgari.utils import MultiLabelBinarizer 33 | super(ClassificationProcessor, self).__init__(**kwargs) 34 | self.multi_label = multi_label 35 | self.multi_label_binarizer = MultiLabelBinarizer(self.vocab2idx) 36 | 37 | def build_vocab_generator(self, 38 | generators: List[CorpusGenerator]) -> None: 39 | from kashgari.utils import MultiLabelBinarizer 40 | if self.vocab2idx: 41 | return 42 | 43 | vocab2idx: Dict[str, int] = {} 44 | token2count: Dict[str, int] = {} 45 | for generator in generators: 46 | if self.multi_label: 47 | for _, label in tqdm.tqdm(generator, desc="Preparing classification label vocab dict"): 48 | for token in label: 49 | count = token2count.get(token, 0) 50 | token2count[token] = count + 1 51 | else: 52 | for _, label in tqdm.tqdm(generator, desc="Preparing classification label vocab dict"): 53 | count = token2count.get(label, 0) 54 | token2count[label] = count + 1 55 | 56 | sorted_token2count = sorted(token2count.items(), 57 | key=operator.itemgetter(1), 58 | reverse=True) 59 | token2count = collections.OrderedDict(sorted_token2count) 60 | 61 | for token, token_count in token2count.items(): 62 | if token not in vocab2idx: 63 | vocab2idx[token] = len(vocab2idx) 64 | self.vocab2idx = vocab2idx 65 | self.idx2vocab = dict([(v, k) for k, v in self.vocab2idx.items()]) 66 | self.multi_label_binarizer = MultiLabelBinarizer(self.vocab2idx) 67 | 68 | def get_tensor_shape(self, batch_size: int, seq_length: int) -> Tuple: 69 | if self.multi_label: 70 | return batch_size, len(self.vocab2idx) 71 | else: 72 | return (batch_size,) 73 | 74 | def transform(self, 75 | samples: TextSamplesVar, 76 | *, 77 | seq_length: int = None, 
78 | max_position: int = None, 79 | segment: bool = False) -> np.ndarray: 80 | if self.multi_label: 81 | sample_tensor = self.multi_label_binarizer.transform(samples) 82 | return sample_tensor 83 | 84 | sample_tensor = [self.vocab2idx[i] for i in samples] 85 | return np.array(sample_tensor) 86 | 87 | def inverse_transform(self, # type: ignore[override] 88 | labels: Union[List[int], np.ndarray], 89 | *, 90 | lengths: List[int] = None, 91 | threshold: float = 0.5, 92 | **kwargs: Any) -> Union[List[List[str]], List[str]]: 93 | if self.multi_label: 94 | return self.multi_label_binarizer.inverse_transform(labels, 95 | threshold=threshold) 96 | else: 97 | return [self.idx2vocab[i] for i in labels] 98 | 99 | 100 | if __name__ == "__main__": 101 | pass 102 | -------------------------------------------------------------------------------- /kashgari/processors/tools.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: tools.py 8 | # time: 11:24 上午 9 | 10 | import json 11 | import os 12 | from typing import Tuple 13 | 14 | from kashgari.processors.abc_processor import ABCProcessor 15 | from kashgari.utils.serialize import load_data_object 16 | 17 | 18 | def load_processors_from_model(model_path: str) -> Tuple[ABCProcessor, ABCProcessor]: 19 | with open(os.path.join(model_path, 'model_config.json'), 'r') as f: 20 | model_config = json.loads(f.read()) 21 | text_processor: ABCProcessor = load_data_object(model_config['text_processor']) 22 | label_processor: ABCProcessor = load_data_object(model_config['label_processor']) 23 | 24 | sequence_length_from_saved_model = model_config['config'].get('sequence_length', None) 25 | text_processor._sequence_length_from_saved_model = sequence_length_from_saved_model 26 | label_processor._sequence_length_from_saved_model = sequence_length_from_saved_model 27 | 28 | return text_processor, label_processor 29 | 30 | 31 | if __name__ == "__main__": 32 | text_processor, label_processor = load_processors_from_model('/Users/brikerman/Desktop/tf-serving/1603683152') 33 | x = text_processor.transform([list('我想你了')]) 34 | print(x.tolist()) 35 | -------------------------------------------------------------------------------- /kashgari/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: __init__.py 8 | # time: 4:04 下午 9 | 10 | from kashgari.tasks import classification 11 | 12 | if __name__ == "__main__": 13 | pass 14 | -------------------------------------------------------------------------------- /kashgari/tasks/abs_task_model.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: abs_task_model.py 8 | # time: 1:43 下午 9 | 10 | import json 11 | import os 12 | import pathlib 13 | from abc import ABC, abstractmethod 14 | from typing import TYPE_CHECKING, Any, Dict, Union 15 | 16 | import tensorflow as tf 17 | 18 | import kashgari 19 | from kashgari.embeddings import ABCEmbedding 20 | from kashgari.layers import KConditionalRandomField 21 | from kashgari.logger import logger 22 | from kashgari.processors.abc_processor import ABCProcessor 23 | from kashgari.utils import 
load_data_object 24 | 25 | if TYPE_CHECKING: 26 | from kashgari.tasks.classification import ABCClassificationModel 27 | from kashgari.tasks.labeling import ABCLabelingModel 28 | 29 | 30 | class ABCTaskModel(ABC): 31 | 32 | def __init__(self) -> None: 33 | self.tf_model: tf.keras.Model = None 34 | self.embedding: ABCEmbedding = None 35 | self.hyper_parameters: Dict[str, Any] 36 | self.sequence_length: int 37 | self.text_processor: ABCProcessor 38 | self.label_processor: ABCProcessor 39 | 40 | def to_dict(self) -> Dict[str, Any]: 41 | model_json_str = self.tf_model.to_json() 42 | 43 | return { 44 | 'tf_version': tf.__version__, # type: ignore 45 | 'kashgari_version': kashgari.__version__, 46 | '__class_name__': self.__class__.__name__, 47 | '__module__': self.__class__.__module__, 48 | 'config': { 49 | 'hyper_parameters': self.hyper_parameters, # type: ignore 50 | 'sequence_length': self.sequence_length # type: ignore 51 | }, 52 | 'embedding': self.embedding.to_dict(), # type: ignore 53 | 'text_processor': self.text_processor.to_dict(), 54 | 'label_processor': self.label_processor.to_dict(), 55 | 'tf_model': json.loads(model_json_str) 56 | } 57 | 58 | @classmethod 59 | def default_hyper_parameters(cls) -> Dict[str, Dict[str, Any]]: 60 | """ 61 | The default hyper parameters of the model dict, **all models must implement this function.** 62 | 63 | You could easily change model's hyper-parameters. 64 | 65 | For example, change the LSTM unit in BiLSTM_Model from 128 to 32. 66 | 67 | >>> from kashgari.tasks.classification import BiLSTM_Model 68 | >>> hyper = BiLSTM_Model.default_hyper_parameters() 69 | >>> print(hyper) 70 | {'layer_bi_lstm': {'units': 128, 'return_sequences': False}, 'layer_output': {}} 71 | >>> hyper['layer_bi_lstm']['units'] = 32 72 | >>> model = BiLSTM_Model(hyper_parameters=hyper) 73 | 74 | Returns: 75 | hyper params dict 76 | """ 77 | raise NotImplementedError 78 | 79 | def save(self, model_path: str, encoding: str = 'utf-8') -> str: 80 | pathlib.Path(model_path).mkdir(exist_ok=True, parents=True) 81 | model_path = os.path.abspath(model_path) 82 | 83 | with open(os.path.join(model_path, 'model_config.json'), 'w', encoding=encoding) as f: 84 | f.write(json.dumps(self.to_dict(), indent=2, ensure_ascii=False)) 85 | f.close() 86 | 87 | self.embedding.embed_model.save_weights(os.path.join(model_path, 'embed_model_weights.h5')) 88 | self.tf_model.save_weights(os.path.join(model_path, 'model_weights.h5')) # type: ignore 89 | logger.info('model saved to {}'.format(os.path.abspath(model_path))) 90 | return model_path 91 | 92 | @classmethod 93 | def load_model(cls, model_path: str, 94 | custom_objects: Dict = None, 95 | encoding: str = 'utf-8') -> Union["ABCLabelingModel", "ABCClassificationModel"]: 96 | if custom_objects is None: 97 | custom_objects = {} 98 | 99 | if cls.__name__ not in custom_objects: 100 | custom_objects[cls.__name__] = cls 101 | 102 | model_config_path = os.path.join(model_path, 'model_config.json') 103 | model_config = json.loads(open(model_config_path, 'r', encoding=encoding).read()) 104 | model = load_data_object(model_config, custom_objects) 105 | 106 | model.embedding = load_data_object(model_config['embedding'], custom_objects) 107 | model.text_processor = load_data_object(model_config['text_processor'], custom_objects) 108 | model.label_processor = load_data_object(model_config['label_processor'], custom_objects) 109 | 110 | tf_model_str = json.dumps(model_config['tf_model']) 111 | 112 | model.tf_model = tf.keras.models.model_from_json(tf_model_str, 113 | 
custom_objects=kashgari.custom_objects) 114 | 115 | if isinstance(model.tf_model.layers[-1], KConditionalRandomField): 116 | model.crf_layer = model.tf_model.layers[-1] 117 | 118 | model.tf_model.load_weights(os.path.join(model_path, 'model_weights.h5')) 119 | model.embedding.embed_model.load_weights(os.path.join(model_path, 'embed_model_weights.h5')) 120 | return model 121 | 122 | @abstractmethod 123 | def build_model(self, 124 | x_data: Any, 125 | y_data: Any) -> None: 126 | raise NotImplementedError 127 | -------------------------------------------------------------------------------- /kashgari/tasks/classification/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: __init__.py 8 | # time: 4:05 下午 9 | 10 | from .abc_model import ABCClassificationModel 11 | from .bi_gru_model import BiGRU_Model 12 | from .bi_lstm_model import BiLSTM_Model 13 | from .cnn_attention_model import CNN_Attention_Model 14 | from .cnn_gru_model import CNN_GRU_Model 15 | from .cnn_lstm_model import CNN_LSTM_Model 16 | from .cnn_model import CNN_Model 17 | 18 | ALL_MODELS = [ 19 | BiGRU_Model, 20 | BiLSTM_Model, 21 | CNN_Attention_Model, 22 | CNN_GRU_Model, 23 | CNN_LSTM_Model, 24 | CNN_Model 25 | ] 26 | 27 | if __name__ == "__main__": 28 | pass 29 | -------------------------------------------------------------------------------- /kashgari/tasks/classification/bi_gru_model.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: bi_gru_model.py 8 | # time: 4:37 下午 9 | 10 | from typing import Dict, Any 11 | 12 | from tensorflow import keras 13 | 14 | from kashgari.layers import L 15 | from kashgari.tasks.classification.abc_model import ABCClassificationModel 16 | 17 | 18 | class BiGRU_Model(ABCClassificationModel): 19 | 20 | @classmethod 21 | def default_hyper_parameters(cls) -> Dict[str, Dict[str, Any]]: 22 | return { 23 | 'layer_bi_gru': { 24 | 'units': 128, 25 | 'return_sequences': False 26 | }, 27 | 'layer_output': { 28 | } 29 | } 30 | 31 | def build_model_arc(self) -> None: 32 | output_dim = self.label_processor.vocab_size 33 | config = self.hyper_parameters 34 | embed_model = self.embedding.embed_model 35 | 36 | layer_stack = [ 37 | L.Bidirectional(L.GRU(**config['layer_bi_gru'])), 38 | L.Dense(output_dim, **config['layer_output']), 39 | self._activation_layer() 40 | ] 41 | 42 | tensor = embed_model.output 43 | for layer in layer_stack: 44 | tensor = layer(tensor) 45 | 46 | self.tf_model = keras.Model(embed_model.inputs, tensor) 47 | -------------------------------------------------------------------------------- /kashgari/tasks/classification/bi_lstm_model.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: cnn_lstm_model.py 8 | # time: 4:06 下午 9 | 10 | from typing import Dict, Any 11 | 12 | from tensorflow import keras 13 | 14 | from kashgari.layers import L 15 | from kashgari.tasks.classification.abc_model import ABCClassificationModel 16 | 17 | 18 | class BiLSTM_Model(ABCClassificationModel): 19 | @classmethod 20 | def default_hyper_parameters(cls) -> Dict[str, Dict[str, Any]]: 21 | return { 22 | 'layer_bi_lstm': { 
23 | 'units': 128, 24 | 'return_sequences': False 25 | }, 26 | 'layer_output': { 27 | 28 | } 29 | } 30 | 31 | def build_model_arc(self) -> None: 32 | output_dim = self.label_processor.vocab_size 33 | 34 | config = self.hyper_parameters 35 | embed_model = self.embedding.embed_model 36 | 37 | # build model structure in sequent way 38 | layer_stack = [ 39 | L.Bidirectional(L.LSTM(**config['layer_bi_lstm'])), 40 | L.Dense(output_dim, **config['layer_output']), 41 | self._activation_layer() 42 | ] 43 | 44 | tensor = embed_model.output 45 | for layer in layer_stack: 46 | tensor = layer(tensor) 47 | 48 | self.tf_model: keras.Model = keras.Model(embed_model.inputs, tensor) 49 | -------------------------------------------------------------------------------- /kashgari/tasks/classification/cnn_attention_model.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: Adline 4 | # contact: gglfxsld@gmail.com 5 | # blog: https://medium.com/@Adline125 6 | 7 | # file: cnn_attention_model.py 8 | # time: 3:05 下午 9 | 10 | from abc import ABC 11 | from typing import Dict, Any 12 | import tensorflow as tf 13 | import tensorflow.keras.layers as L 14 | from tensorflow import keras 15 | from kashgari.logger import logger 16 | 17 | from kashgari.tasks.classification.abc_model import ABCClassificationModel 18 | 19 | 20 | class CNN_Attention_Model(ABCClassificationModel, ABC): 21 | 22 | @classmethod 23 | def default_hyper_parameters(cls) -> Dict[str, Dict[str, Any]]: 24 | return { 25 | 'conv_layer1': { 26 | 'filters': 264, 27 | 'kernel_size': 3, 28 | 'padding': 'same', 29 | 'activation': 'relu' 30 | }, 31 | 'conv_layer2': { 32 | 'filters': 128, 33 | 'kernel_size': 3, 34 | 'padding': 'same', 35 | 'activation': 'relu' 36 | }, 37 | 'conv_layer3': { 38 | 'filters': 64, 39 | 'kernel_size': 3, 40 | 'padding': 'same', 41 | 'activation': 'relu' 42 | }, 43 | 'layer_output': { 44 | }, 45 | } 46 | 47 | def build_model_arc(self) -> None: 48 | if tuple(tf.__version__.split('.')) < tuple('2.1.0'.split('.')): 49 | logger.warning("Attention layer not serializable because it takes init args " 50 | "but doesn't implement get_config. " 51 | "Please try Attention layer with tf versions >= 2.1.0. " 52 | "Issue: https://github.com/tensorflow/tensorflow/issues/32662") 53 | output_dim = self.label_processor.vocab_size 54 | config = self.hyper_parameters 55 | 56 | embed_model = self.embedding.embed_model 57 | # Query embeddings of shape [batch_size, Tq, dimension]. 58 | query_embeddings = embed_model.output 59 | # Value embeddings of shape [batch_size, Tv, dimension]. 60 | value_embeddings = embed_model.output 61 | 62 | # CNN layer. 63 | cnn_layer_1 = L.Conv1D(**config['conv_layer1']) 64 | # Query encoding of shape [batch_size, Tq, filters]. 65 | query_seq_encoding = cnn_layer_1(query_embeddings) 66 | # Value encoding of shape [batch_size, Tv, filters]. 67 | value_seq_encoding = cnn_layer_1(value_embeddings) 68 | 69 | cnn_layer_2 = L.Conv1D(**config['conv_layer2']) 70 | query_seq_encoding = cnn_layer_2(query_seq_encoding) 71 | value_seq_encoding = cnn_layer_2(value_seq_encoding) 72 | 73 | cnn_layer_3 = L.Conv1D(**config['conv_layer3']) 74 | query_seq_encoding = cnn_layer_3(query_seq_encoding) 75 | value_seq_encoding = cnn_layer_3(value_seq_encoding) 76 | 77 | # Query-value attention of shape [batch_size, Tq, filters]. 
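# L.Attention is Keras' dot-product attention. Called with the list
# [query, value], the value tensor also serves as the key, so the scores
# are softmax(query_seq_encoding @ transpose(value_seq_encoding)).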
78 | query_value_attention_seq = L.Attention()( 79 | [query_seq_encoding, value_seq_encoding]) 80 | 81 | # Reduce over the sequence axis to produce encodings of shape 82 | # [batch_size, filters]. 83 | query_encoding = L.GlobalMaxPool1D()(query_seq_encoding) 84 | query_value_attention = L.GlobalMaxPool1D()(query_value_attention_seq) 85 | 86 | # Concatenate query and document encodings to produce a DNN input layer. 87 | input_layer = L.Concatenate(axis=-1)([query_encoding, query_value_attention]) 88 | 89 | output = L.Dense(output_dim, **config['layer_output'])(input_layer) 90 | output = self._activation_layer()(output) 91 | 92 | self.tf_model = keras.Model(embed_model.input, output) 93 | 94 | 95 | if __name__ == "__main__": 96 | pass 97 | -------------------------------------------------------------------------------- /kashgari/tasks/classification/cnn_gru_model.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: cnn_gru_model.py 8 | # time: 5:08 下午 9 | 10 | from typing import Dict, Any 11 | 12 | from tensorflow import keras 13 | 14 | from kashgari.layers import L 15 | from kashgari.tasks.classification.abc_model import ABCClassificationModel 16 | 17 | 18 | class CNN_GRU_Model(ABCClassificationModel): 19 | @classmethod 20 | def default_hyper_parameters(cls) -> Dict[str, Dict[str, Any]]: 21 | return { 22 | 'conv1d_layer': { 23 | 'filters': 128, 24 | 'kernel_size': 5, 25 | 'activation': 'relu' 26 | }, 27 | 'max_pool_layer': {}, 28 | 'gru_layer': { 29 | 'units': 100 30 | }, 31 | 'layer_output': { 32 | 33 | }, 34 | } 35 | 36 | def build_model_arc(self) -> None: 37 | output_dim = self.label_processor.vocab_size 38 | 39 | config = self.hyper_parameters 40 | embed_model = self.embedding.embed_model 41 | 42 | # build model structure in a sequential way 43 | layer_stack = [ 44 | L.Conv1D(**config['conv1d_layer']), 45 | L.MaxPooling1D(**config['max_pool_layer']), 46 | L.GRU(**config['gru_layer']), 47 | L.Dense(output_dim, **config['layer_output']), 48 | self._activation_layer() 49 | ] 50 | 51 | tensor = embed_model.output 52 | for layer in layer_stack: 53 | tensor = layer(tensor) 54 | 55 | self.tf_model = keras.Model(embed_model.inputs, tensor) 56 | 57 | 58 | if __name__ == "__main__": 59 | pass 60 | -------------------------------------------------------------------------------- /kashgari/tasks/classification/cnn_lstm_model.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: cnn_lstm_model.py 8 | # time: 5:07 下午 9 | 10 | from typing import Dict, Any 11 | 12 | from tensorflow import keras 13 | 14 | from kashgari.layers import L 15 | from kashgari.tasks.classification.abc_model import ABCClassificationModel 16 | 17 | 18 | class CNN_LSTM_Model(ABCClassificationModel): 19 | @classmethod 20 | def default_hyper_parameters(cls) -> Dict[str, Dict[str, Any]]: 21 | return { 22 | 'conv1d_layer': { 23 | 'filters': 128, 24 | 'kernel_size': 5, 25 | 'activation': 'relu' 26 | }, 27 | 'max_pool_layer': {}, 28 | 'lstm_layer': { 29 | 'units': 100 30 | }, 31 | 'layer_output': { 32 | 33 | }, 34 | } 35 | 36 | def build_model_arc(self) -> None: 37 | output_dim = self.label_processor.vocab_size 38 | 39 | config = self.hyper_parameters 40 | embed_model = self.embedding.embed_model 41 | 42 | # build model
structure in a sequential way 43 | layer_stack = [ 44 | L.Conv1D(**config['conv1d_layer']), 45 | L.MaxPooling1D(**config['max_pool_layer']), 46 | L.LSTM(**config['lstm_layer']), 47 | L.Dense(output_dim, **config['layer_output']), 48 | self._activation_layer() 49 | ] 50 | 51 | tensor = embed_model.output 52 | for layer in layer_stack: 53 | tensor = layer(tensor) 54 | 55 | self.tf_model = keras.Model(embed_model.inputs, tensor) 56 | 57 | 58 | if __name__ == "__main__": 59 | pass 60 | -------------------------------------------------------------------------------- /kashgari/tasks/classification/cnn_model.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: cnn_model.py 8 | # time: 3:31 下午 9 | 10 | from typing import Dict, Any 11 | 12 | from tensorflow import keras 13 | 14 | from kashgari.layers import L 15 | from kashgari.tasks.classification.abc_model import ABCClassificationModel 16 | 17 | 18 | class CNN_Model(ABCClassificationModel): 19 | @classmethod 20 | def default_hyper_parameters(cls) -> Dict[str, Dict[str, Any]]: 21 | return { 22 | 'conv1d_layer': { 23 | 'filters': 128, 24 | 'kernel_size': 5, 25 | 'activation': 'relu' 26 | }, 27 | 'max_pool_layer': {}, 28 | 'dense_layer': { 29 | 'units': 64, 30 | 'activation': 'relu' 31 | }, 32 | 'layer_output': { 33 | 34 | }, 35 | } 36 | 37 | def build_model_arc(self) -> None: 38 | output_dim = self.label_processor.vocab_size 39 | 40 | config = self.hyper_parameters 41 | embed_model = self.embedding.embed_model 42 | 43 | # build model structure in a sequential way 44 | layer_stack = [ 45 | L.Conv1D(**config['conv1d_layer']), 46 | L.GlobalMaxPooling1D(**config['max_pool_layer']), 47 | L.Dense(**config['dense_layer']), 48 | L.Dense(output_dim, **config['layer_output']), 49 | self._activation_layer() 50 | ] 51 | 52 | tensor = embed_model.output 53 | for layer in layer_stack: 54 | tensor = layer(tensor) 55 | 56 | self.tf_model = keras.Model(embed_model.inputs, tensor) 57 | 58 | 59 | if __name__ == "__main__": 60 | pass 61 | -------------------------------------------------------------------------------- /kashgari/tasks/labeling/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: __init__.py 8 | # time: 4:30 下午 9 | 10 | from .abc_model import ABCLabelingModel 11 | from .bi_gru_model import BiGRU_Model 12 | from .bi_gru_crf_model import BiGRU_CRF_Model 13 | from .bi_lstm_model import BiLSTM_Model 14 | from .bi_lstm_crf_model import BiLSTM_CRF_Model 15 | from .cnn_lstm_model import CNN_LSTM_Model 16 | 17 | ALL_MODELS = [ 18 | BiGRU_Model, 19 | BiGRU_CRF_Model, 20 | BiLSTM_Model, 21 | BiLSTM_CRF_Model, 22 | CNN_LSTM_Model, 23 | ] 24 | 25 | if __name__ == "__main__": 26 | pass 27 | -------------------------------------------------------------------------------- /kashgari/tasks/labeling/bi_gru_crf_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Author : BrikerMan 4 | # Site : https://eliyar.biz 5 | 6 | # Time : 2020/9/1 10:26 下午 7 | # File : bi_gru_crf_model.py 8 | # Project : Kashgari 9 | 10 | from typing import Dict, Any 11 | 12 | from tensorflow import keras 13 | 14 | from kashgari.layers import L, KConditionalRandomField 15 | from
kashgari.tasks.labeling.abc_model import ABCLabelingModel 16 | 17 | 18 | class BiGRU_CRF_Model(ABCLabelingModel): 19 | 20 | @classmethod 21 | def default_hyper_parameters(cls) -> Dict[str, Dict[str, Any]]: 22 | return { 23 | 'layer_bgru': { 24 | 'units': 128, 25 | 'return_sequences': True 26 | }, 27 | 'layer_dropout': { 28 | 'rate': 0.4 29 | }, 30 | 'layer_time_distributed': {}, 31 | 'layer_activation': { 32 | 'activation': 'softmax' 33 | } 34 | } 35 | 36 | def build_model_arc(self) -> None: 37 | output_dim = self.label_processor.vocab_size 38 | 39 | config = self.hyper_parameters 40 | embed_model = self.embedding.embed_model 41 | 42 | crf = KConditionalRandomField() 43 | 44 | layer_stack = [ 45 | L.Bidirectional(L.GRU(**config['layer_bgru']), name='layer_bgru'), 46 | L.Dropout(**config['layer_dropout'], name='layer_dropout'), 47 | L.Dense(output_dim, **config['layer_time_distributed']), 48 | crf 49 | ] 50 | 51 | tensor = embed_model.output 52 | for layer in layer_stack: 53 | tensor = layer(tensor) 54 | 55 | self.tf_model = keras.Model(embed_model.inputs, tensor) 56 | self.crf_layer = crf 57 | 58 | def compile_model(self, 59 | loss: Any = None, 60 | optimizer: Any = None, 61 | metrics: Any = None, 62 | **kwargs: Any) -> None: 63 | if loss is None: 64 | loss = self.crf_layer.loss 65 | if metrics is None: 66 | metrics = [self.crf_layer.accuracy] 67 | super(BiGRU_CRF_Model, self).compile_model(loss=loss, 68 | optimizer=optimizer, 69 | metrics=metrics, 70 | **kwargs) 71 | 72 | 73 | if __name__ == "__main__": 74 | from kashgari.corpus import ChineseDailyNerCorpus 75 | from kashgari.callbacks import EvalCallBack 76 | 77 | train_x, train_y = ChineseDailyNerCorpus.load_data('train') 78 | valid_x, valid_y = ChineseDailyNerCorpus.load_data('valid') 79 | test_x, test_y = ChineseDailyNerCorpus.load_data('test') 80 | 81 | model = BiGRU_CRF_Model(sequence_length=10) 82 | 83 | eval_callback = EvalCallBack(kash_model=model, 84 | x_data=valid_x, 85 | y_data=valid_y, 86 | truncating=True, 87 | step=1) 88 | 89 | model.fit(train_x, train_y, valid_x, valid_y, epochs=1, 90 | callbacks=[]) 91 | y = model.predict(test_x[:200]) 92 | model.tf_model.summary() 93 | model.evaluate(test_x, test_y) 94 | -------------------------------------------------------------------------------- /kashgari/tasks/labeling/bi_gru_model.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: bi_gru_model.py 8 | # time: 5:01 下午 9 | 10 | from typing import Dict, Any 11 | 12 | from tensorflow import keras 13 | 14 | from kashgari.layers import L 15 | from kashgari.tasks.labeling.abc_model import ABCLabelingModel 16 | 17 | 18 | class BiGRU_Model(ABCLabelingModel): 19 | 20 | @classmethod 21 | def default_hyper_parameters(cls) -> Dict[str, Dict[str, Any]]: 22 | return { 23 | 'layer_bgru': { 24 | 'units': 128, 25 | 'return_sequences': True 26 | }, 27 | 'layer_dropout': { 28 | 'rate': 0.4 29 | }, 30 | 'layer_time_distributed': {}, 31 | 'layer_activation': { 32 | 'activation': 'softmax' 33 | } 34 | } 35 | 36 | def build_model_arc(self) -> None: 37 | output_dim = self.label_processor.vocab_size 38 | 39 | config = self.hyper_parameters 40 | embed_model = self.embedding.embed_model 41 | 42 | layer_stack = [ 43 | L.Bidirectional(L.GRU(**config['layer_bgru']), name='layer_bgru'), 44 | L.Dropout(**config['layer_dropout'], name='layer_dropout'), 45 | L.TimeDistributed(L.Dense(output_dim, 
**config['layer_time_distributed']), name='layer_time_distributed'), 46 | L.Activation(**config['layer_activation']) 47 | ] 48 | 49 | tensor = embed_model.output 50 | for layer in layer_stack: 51 | tensor = layer(tensor) 52 | 53 | self.tf_model = keras.Model(embed_model.inputs, tensor) 54 | 55 | 56 | if __name__ == "__main__": 57 | from kashgari.corpus import ChineseDailyNerCorpus 58 | from kashgari.callbacks import EvalCallBack 59 | 60 | train_x, train_y = ChineseDailyNerCorpus.load_data('train') 61 | valid_x, valid_y = ChineseDailyNerCorpus.load_data('valid') 62 | test_x, test_y = ChineseDailyNerCorpus.load_data('test') 63 | 64 | model = BiGRU_Model(sequence_length=10) 65 | 66 | eval_callback = EvalCallBack(kash_model=model, 67 | x_data=valid_x, 68 | y_data=valid_y, 69 | truncating=True, 70 | step=1) 71 | 72 | model.fit(train_x[:300], train_y[:300], valid_x, valid_y, epochs=1, 73 | callbacks=[eval_callback]) 74 | y = model.predict(train_x[:200]) 75 | model.tf_model.summary() 76 | model.evaluate(test_x, test_y) 77 | -------------------------------------------------------------------------------- /kashgari/tasks/labeling/bi_lstm_crf_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Author : BrikerMan 4 | # Site : https://eliyar.biz 5 | 6 | # Time : 2020/9/1 11:51 下午 7 | # File : bi_lstm_crf_model.py 8 | # Project : Kashgari 9 | 10 | from typing import Dict, Any 11 | 12 | from tensorflow import keras 13 | from kashgari.layers import L, KConditionalRandomField 14 | from kashgari.tasks.labeling.abc_model import ABCLabelingModel 15 | 16 | 17 | class BiLSTM_CRF_Model(ABCLabelingModel): 18 | 19 | @classmethod 20 | def default_hyper_parameters(cls) -> Dict[str, Dict[str, Any]]: 21 | return { 22 | 'layer_blstm': { 23 | 'units': 128, 24 | 'return_sequences': True 25 | }, 26 | 'layer_dropout': { 27 | 'rate': 0.4 28 | }, 29 | 'layer_time_distributed': {}, 30 | 'layer_activation': { 31 | 'activation': 'softmax' 32 | } 33 | } 34 | 35 | def build_model_arc(self) -> None: 36 | output_dim = self.label_processor.vocab_size 37 | 38 | config = self.hyper_parameters 39 | embed_model = self.embedding.embed_model 40 | 41 | crf = KConditionalRandomField() 42 | 43 | layer_stack = [ 44 | L.Bidirectional(L.LSTM(**config['layer_blstm']), name='layer_blstm'), 45 | L.Dropout(**config['layer_dropout'], name='layer_dropout'), 46 | L.Dense(output_dim, **config['layer_time_distributed']), 47 | crf 48 | ] 49 | 50 | tensor = embed_model.output 51 | for layer in layer_stack: 52 | tensor = layer(tensor) 53 | 54 | self.tf_model = keras.Model(embed_model.inputs, tensor) 55 | self.crf_layer = crf 56 | 57 | def compile_model(self, 58 | loss: Any = None, 59 | optimizer: Any = None, 60 | metrics: Any = None, 61 | **kwargs: Any) -> None: 62 | if loss is None: 63 | loss = self.crf_layer.loss 64 | if metrics is None: 65 | metrics = [self.crf_layer.accuracy] 66 | super(BiLSTM_CRF_Model, self).compile_model(loss=loss, 67 | optimizer=optimizer, 68 | metrics=metrics, 69 | **kwargs) 70 | -------------------------------------------------------------------------------- /kashgari/tasks/labeling/bi_lstm_model.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: bi_lstm_model.py 8 | # time: 4:36 下午 9 | 10 | from typing import Dict, Any 11 | 12 | from tensorflow import keras 13 | 14 | 
from kashgari.layers import L 15 | from kashgari.tasks.labeling.abc_model import ABCLabelingModel 16 | 17 | 18 | class BiLSTM_Model(ABCLabelingModel): 19 | @classmethod 20 | def default_hyper_parameters(cls) -> Dict[str, Dict[str, Any]]: 21 | return { 22 | 'layer_blstm': { 23 | 'units': 128, 24 | 'return_sequences': True 25 | }, 26 | 'layer_dropout': { 27 | 'rate': 0.4 28 | }, 29 | 'layer_time_distributed': {}, 30 | 'layer_activation': { 31 | 'activation': 'softmax' 32 | } 33 | } 34 | 35 | def build_model_arc(self) -> None: 36 | output_dim = self.label_processor.vocab_size 37 | 38 | config = self.hyper_parameters 39 | embed_model = self.embedding.embed_model 40 | 41 | layer_stack = [ 42 | L.Bidirectional(L.LSTM(**config['layer_blstm']), name='layer_blstm'), 43 | L.Dropout(**config['layer_dropout'], name='layer_dropout'), 44 | L.Dense(output_dim, **config['layer_time_distributed']), 45 | L.Activation(**config['layer_activation']) 46 | ] 47 | tensor = embed_model.output 48 | for layer in layer_stack: 49 | tensor = layer(tensor) 50 | 51 | self.tf_model = keras.Model(embed_model.inputs, tensor) 52 | 53 | 54 | if __name__ == "__main__": 55 | from kashgari.corpus import ChineseDailyNerCorpus 56 | 57 | x, y = ChineseDailyNerCorpus.load_data() 58 | x_valid, y_valid = ChineseDailyNerCorpus.load_data('valid') 59 | model = BiLSTM_Model() 60 | model.fit(x, y, x_valid, y_valid, epochs=2) 61 | model.evaluate(*ChineseDailyNerCorpus.load_data('test')) 62 | -------------------------------------------------------------------------------- /kashgari/tasks/labeling/cnn_lstm_model.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: cnn_lstm_model.py 8 | # time: 5:28 下午 9 | 10 | from typing import Dict, Any 11 | 12 | from tensorflow import keras 13 | 14 | from kashgari.layers import L 15 | from kashgari.tasks.labeling.abc_model import ABCLabelingModel 16 | 17 | 18 | class CNN_LSTM_Model(ABCLabelingModel): 19 | 20 | @classmethod 21 | def default_hyper_parameters(cls) -> Dict[str, Dict[str, Any]]: 22 | return { 23 | 'layer_bgru': { 24 | 'units': 128, 25 | 'return_sequences': True 26 | }, 27 | 'layer_dropout': { 28 | 'rate': 0.4 29 | }, 30 | 'layer_time_distributed': {}, 31 | 'layer_activation': { 32 | 'activation': 'softmax' 33 | } 34 | } 35 | 36 | def build_model_arc(self) -> None: 37 | output_dim = self.label_processor.vocab_size 38 | 39 | config = self.hyper_parameters 40 | embed_model = self.embedding.embed_model 41 | 42 | layer_stack = [ 43 | L.Bidirectional(L.GRU(**config['layer_bgru']), name='layer_bgru'), 44 | L.Dropout(**config['layer_dropout'], name='layer_dropout'), 45 | L.TimeDistributed(L.Dense(output_dim, **config['layer_time_distributed']), name='layer_time_distributed'), 46 | L.Activation(**config['layer_activation']) 47 | ] 48 | 49 | tensor = embed_model.output 50 | for layer in layer_stack: 51 | tensor = layer(tensor) 52 | 53 | self.tf_model = keras.Model(embed_model.inputs, tensor) 54 | 55 | 56 | if __name__ == "__main__": 57 | pass 58 | -------------------------------------------------------------------------------- /kashgari/tasks/seq2seq/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: __init__.py 8 | # time: 2:30 下午 9 | 10 | from .model import Seq2Seq 11 | 
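# A minimal usage sketch, hedged: it assumes Seq2Seq follows the fit/predict style
# of the other task models, and `src_samples` / `tgt_samples` are hypothetical
# tokenized corpora, not part of this package:
#
#     model = Seq2Seq(hidden_size=256)
#     model.fit(src_samples, tgt_samples)
#     preds = model.predict(src_samples[:10])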
12 | if __name__ == "__main__": 13 | pass 14 | -------------------------------------------------------------------------------- /kashgari/tasks/seq2seq/decoder/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: __init__.py 8 | # time: 2:33 下午 9 | 10 | from .att_gru_decoder import AttGRUDecoder # type: ignore 11 | from .gru_decoder import GRUDecoder # type: ignore 12 | 13 | if __name__ == "__main__": 14 | pass 15 | -------------------------------------------------------------------------------- /kashgari/tasks/seq2seq/decoder/att_gru_decoder.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: att_gru_decoder.py 8 | # time: 9:42 下午 9 | 10 | # type: ignore 11 | 12 | import tensorflow as tf 13 | 14 | from kashgari.embeddings.abc_embedding import ABCEmbedding 15 | from kashgari.layers import L 16 | 17 | 18 | class AttGRUDecoder(tf.keras.Model): 19 | def __init__(self, 20 | embedding: ABCEmbedding, 21 | vocab_size: int, 22 | hidden_size: int = 1024): 23 | super(AttGRUDecoder, self).__init__() 24 | self.embedding = embedding 25 | self.hidden_size = hidden_size 26 | self.gru = tf.keras.layers.GRU(hidden_size, 27 | return_sequences=True, 28 | return_state=True, 29 | recurrent_initializer='glorot_uniform') 30 | self.fc = tf.keras.layers.Dense(vocab_size) 31 | 32 | # attention module over the encoder outputs 33 | self.attention = L.BahdanauAttention(hidden_size) 34 | 35 | def call(self, x, hidden, enc_output): 36 | # enc_output shape == (batch_size, max_length, hidden_size) 37 | context_vector, attention_weights = self.attention(hidden, enc_output) 38 | 39 | if self.embedding.segment: 40 | x = x, tf.zeros(x.shape) 41 | 42 | # x shape after passing through embedding == (batch_size, 1, embedding_dim) 43 | x = self.embedding.embed_model(x) 44 | 45 | # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size) 46 | x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1) 47 | 48 | # passing the concatenated vector to the GRU 49 | output, state = self.gru(x) 50 | 51 | # output shape == (batch_size * 1, hidden_size) 52 | output = tf.reshape(output, (-1, output.shape[2])) 53 | 54 | # output shape == (batch_size, vocab) 55 | x = self.fc(output) 56 | 57 | return x, state, attention_weights 58 | 59 | def model(self): 60 | x1 = L.Input(shape=(None,)) 61 | x2 = L.Input(shape=(self.hidden_size,)) 62 | x3 = L.Input(shape=(self.hidden_size,)) 63 | return tf.keras.Model(inputs=[x1, x2, x3], 64 | outputs=self.call(x1, x2, x3), 65 | name='AttGRUDecoder') 66 | 67 | 68 | if __name__ == "__main__": 69 | pass 70 | -------------------------------------------------------------------------------- /kashgari/tasks/seq2seq/decoder/gru_decoder.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: gru_decoder.py 8 | # time: 9:41 下午 9 | 10 | # type: ignore 11 | 12 | import tensorflow as tf 13 | 14 | from kashgari.embeddings.abc_embedding import ABCEmbedding 15 | 16 | 17 | class GRUDecoder(tf.keras.Model): 18 | def __init__(self, 19 | embedding: ABCEmbedding, 20 | hidden_size: int, 21 | vocab_size: int): 22 | super(GRUDecoder, self).__init__() 23 |
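# A plain (attention-free) decoder: embed the previous target token, run it through a GRU seeded with the encoder state, then project the GRU output to vocabulary logits.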
self.embedding = embedding 24 | 25 | self.gru = tf.keras.layers.GRU(hidden_size, 26 | return_sequences=True, 27 | return_state=True, 28 | recurrent_initializer='glorot_uniform') 29 | self.fc = tf.keras.layers.Dense(vocab_size) 30 | 31 | def call(self, dec_input, dec_hidden, enc_output): 32 | # shape of x after the embedding layer == (batch_size, 1, embedding_dim) 33 | decoder_embedding = self.embedding.embed_model(dec_input) 34 | 35 | s = self.gru(decoder_embedding, initial_state=dec_hidden) 36 | decoder_outputs, decoder_state = s 37 | 38 | # output shape == (batch_size * 1, hidden_size) 39 | output = tf.reshape(decoder_outputs, (-1, decoder_outputs.shape[2])) 40 | 41 | # output shape == (batch_size, vocab) 42 | x = self.fc(output) 43 | return x, decoder_state, None 44 | 45 | 46 | if __name__ == "__main__": 47 | pass 48 | -------------------------------------------------------------------------------- /kashgari/tasks/seq2seq/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: __init__.py 8 | # time: 2:30 下午 9 | 10 | from .gru_encoder import GRUEncoder 11 | 12 | if __name__ == "__main__": 13 | pass 14 | -------------------------------------------------------------------------------- /kashgari/tasks/seq2seq/encoder/gru_encoder.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: gru_encoder.py 8 | # time: 2:31 下午 9 | 10 | import numpy as np 11 | import tensorflow as tf 12 | 13 | from kashgari.embeddings.abc_embedding import ABCEmbedding 14 | from kashgari.layers import L 15 | 16 | 17 | class GRUEncoder(tf.keras.Model): 18 | def __init__(self, embedding: ABCEmbedding, hidden_size: int = 1024): 19 | super(GRUEncoder, self).__init__() 20 | self.embedding = embedding 21 | self.hidden_size = hidden_size 22 | self.gru = tf.keras.layers.GRU(hidden_size, 23 | return_sequences=True, 24 | return_state=True, 25 | recurrent_initializer='glorot_uniform') 26 | 27 | def call(self, x: np.ndarray, hidden: np.ndarray) -> np.ndarray: 28 | if self.embedding.segment: 29 | x = (x, tf.zeros(x.shape)) 30 | x = self.embedding.embed_model(x) 31 | output, state = self.gru(x, initial_state=hidden) 32 | return output, state 33 | 34 | def model(self) -> tf.keras.Model: 35 | x1 = L.Input(shape=(None,)) 36 | x2 = L.Input(shape=(self.hidden_size,)) 37 | return tf.keras.Model(inputs=[x1, x2], 38 | outputs=self.call(x1, x2), 39 | name='GRUEncoder') 40 | 41 | 42 | if __name__ == "__main__": 43 | pass 44 | -------------------------------------------------------------------------------- /kashgari/tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: __init__.py 8 | # time: 11:34 上午 9 | 10 | from kashgari.tokenizers.base_tokenizer import Tokenizer 11 | from kashgari.tokenizers.bert_tokenizer import BertTokenizer 12 | from kashgari.tokenizers.jieba_tokenizer import JiebaTokenizer 13 | -------------------------------------------------------------------------------- /kashgari/tokenizers/base_tokenizer.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog:
https://eliyar.biz 6 | 7 | # file: base_tokenizer.py 8 | # time: 11:24 上午 9 | 10 | from typing import List 11 | 12 | 13 | class Tokenizer: 14 | """ 15 | Abstract base class for all implemented tokenizers. 16 | """ 17 | 18 | def tokenize(self, text: str) -> List[str]: 19 | """ 20 | Tokenize text into token sequence 21 | Args: 22 | text: target text sample 23 | 24 | Returns: 25 | List of tokens in this sample 26 | """ 27 | return text.split(' ') 28 | -------------------------------------------------------------------------------- /kashgari/tokenizers/bert_tokenizer.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: bert_tokenizer.py 8 | # time: 11:33 上午 9 | 10 | # flake8: noqa: E127 11 | 12 | import codecs 13 | import unicodedata 14 | from typing import List, Optional, Dict 15 | 16 | from kashgari.tokenizers.base_tokenizer import Tokenizer 17 | 18 | TOKEN_PAD = '' # Token for padding 19 | TOKEN_UNK = '[UNK]' # Token for unknown words 20 | TOKEN_CLS = '[CLS]' # Token for classification 21 | TOKEN_SEP = '[SEP]' # Token for separation 22 | TOKEN_MASK = '[MASK]' # Token for masking 23 | 24 | 25 | class BertTokenizer(Tokenizer): 26 | """ 27 | BERT-like tokenizer, ref: https://github.com/CyberZHG/keras-bert/blob/master/keras_bert/tokenizer.py 28 | 29 | """ 30 | 31 | def __init__(self, 32 | *, 33 | token_dict: Optional[Dict[str, int]] = None, 34 | token_cls: str = TOKEN_CLS, 35 | token_sep: str = TOKEN_SEP, 36 | token_unk: str = TOKEN_UNK, 37 | pad_index: int = 0, 38 | cased: bool = False) -> None: 39 | """ 40 | Initialize tokenizer. 41 | Args: 42 | token_dict: A dict that maps tokens to indices. 43 | token_cls: The token that represents classification. 44 | token_sep: The token that represents separation. 45 | token_unk: The token that represents unknown tokens. 46 | pad_index: The index to pad. 47 | cased: Whether to keep the case. 48 | """ 49 | self._token_dict: Dict[str, int] 50 | 51 | if token_dict: 52 | self._token_dict = token_dict 53 | else: 54 | self._token_dict = {} 55 | 56 | self._token_dict_inv: Dict[int, str] = {v: k for k, v in self._token_dict.items()} 57 | self._token_cls: str = token_cls 58 | self._token_sep: str = token_sep 59 | self._token_unk: str = token_unk 60 | self._pad_index: int = pad_index 61 | self._cased: bool = cased 62 | 63 | @classmethod 64 | def load_from_vocab_file(cls, vocab_path: str) -> 'BertTokenizer': 65 | token2idx: Dict[str, int] = {} 66 | with codecs.open(vocab_path, 'r', 'utf8') as reader: 67 | for line in reader: 68 | token = line.strip() 69 | token2idx[token] = len(token2idx) 70 | return BertTokenizer(token_dict=token2idx) 71 | 72 | def tokenize(self, text: str) -> List[str]: 73 | """ 74 | Split text to tokens. 75 | Args: 76 | text: text to tokenize. 77 | 78 | Returns: 79 | A list of strings.
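For example, with no token_dict loaded, tokenize('Hello World!') returns ['hello', 'world', '!'] (lower-cased, punctuation split off).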
80 | """ 81 | tokens = self._tokenize(text) 82 | return tokens 83 | 84 | def _tokenize(self, text: str) -> List[str]: 85 | if not self._cased: 86 | text = unicodedata.normalize('NFD', text) 87 | text = ''.join([ch for ch in text if unicodedata.category(ch) != 'Mn']) 88 | text = text.lower() 89 | spaced = '' 90 | for ch in text: 91 | if self._is_punctuation(ch) or self._is_cjk_character(ch): 92 | spaced += ' ' + ch + ' ' 93 | elif self._is_space(ch): 94 | spaced += ' ' 95 | elif ord(ch) == 0 or ord(ch) == 0xfffd or self._is_control(ch): 96 | continue 97 | else: 98 | spaced += ch 99 | 100 | if len(self._token_dict) > 0: 101 | tokens: List[str] = [] 102 | for word in spaced.strip().split(): 103 | tokens += self._word_piece_tokenize(word) 104 | return tokens 105 | else: 106 | return spaced.strip().split() 107 | 108 | def _word_piece_tokenize(self, word: str) -> List[str]: 109 | if word in self._token_dict: 110 | return [word] 111 | tokens = [] 112 | start, stop = 0, 0 113 | while start < len(word): 114 | stop = len(word) 115 | while stop > start: 116 | sub = word[start:stop] 117 | if start > 0: 118 | sub = '##' + sub 119 | if sub in self._token_dict: 120 | break 121 | stop -= 1 122 | if start == stop: 123 | stop += 1 124 | tokens.append(sub) 125 | start = stop 126 | return tokens 127 | 128 | @staticmethod 129 | def _is_punctuation(ch: str) -> bool: # noqa: E127 130 | code = ord(ch) 131 | return 33 <= code <= 47 or \ 132 | 58 <= code <= 64 or \ 133 | 91 <= code <= 96 or \ 134 | 123 <= code <= 126 or \ 135 | unicodedata.category(ch).startswith('P') 136 | 137 | @staticmethod 138 | def _is_cjk_character(ch: str) -> bool: 139 | code = ord(ch) 140 | return 0x4E00 <= code <= 0x9FFF or \ 141 | 0x3400 <= code <= 0x4DBF or \ 142 | 0x20000 <= code <= 0x2A6DF or \ 143 | 0x2A700 <= code <= 0x2B73F or \ 144 | 0x2B740 <= code <= 0x2B81F or \ 145 | 0x2B820 <= code <= 0x2CEAF or \ 146 | 0xF900 <= code <= 0xFAFF or \ 147 | 0x2F800 <= code <= 0x2FA1F 148 | 149 | @staticmethod 150 | def _is_space(ch: str) -> bool: 151 | return ch == ' ' or ch == '\n' or ch == '\r' or ch == '\t' or unicodedata.category(ch) == 'Zs' 152 | 153 | @staticmethod 154 | def _is_control(ch: str) -> bool: 155 | return unicodedata.category(ch) in ('Cc', 'Cf') 156 | -------------------------------------------------------------------------------- /kashgari/tokenizers/jieba_tokenizer.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: jieba_tokenizer.py 8 | # time: 11:54 上午 9 | 10 | from typing import List, Any 11 | 12 | from kashgari.tokenizers.base_tokenizer import Tokenizer 13 | 14 | 15 | class JiebaTokenizer(Tokenizer): 16 | """ 17 | Jieba tokenizer 18 | """ 19 | 20 | def __init__(self) -> None: 21 | try: 22 | import jieba 23 | self._jieba = jieba 24 | except ModuleNotFoundError: 25 | raise ModuleNotFoundError("Jieba module not found, please install use `pip install jieba`") 26 | 27 | def tokenize(self, text: str, **kwargs: Any) -> List[str]: 28 | """ 29 | Tokenize text into token sequence 30 | Args: 31 | text: target text sample 32 | 33 | Returns: 34 | List of tokens in this sample 35 | """ 36 | 37 | return list(self._jieba.cut(text, **kwargs)) 38 | -------------------------------------------------------------------------------- /kashgari/types.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 
4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: types.py 8 | # time: 3:54 下午 9 | 10 | from typing import List, Union, Tuple 11 | 12 | TextSamplesVar = List[List[str]] 13 | NumSamplesListVar = List[List[int]] 14 | LabelSamplesVar = Union[TextSamplesVar, List[str]] 15 | 16 | ClassificationLabelVar = List[str] 17 | MultiLabelClassificationLabelVar = Union[List[List[str]], List[Tuple[str]]] 18 | 19 | if __name__ == "__main__": 20 | pass 21 | -------------------------------------------------------------------------------- /kashgari/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: __init__.py 8 | # time: 11:22 上午 9 | 10 | import warnings 11 | import tensorflow as tf 12 | from typing import TYPE_CHECKING, Union 13 | from tensorflow.keras.utils import CustomObjectScope 14 | 15 | from kashgari import custom_objects 16 | from .data import get_list_subset 17 | from .data import unison_shuffled_copies 18 | from .multi_label import MultiLabelBinarizer 19 | from .serialize import load_data_object 20 | from .model import convert_to_saved_model 21 | 22 | if TYPE_CHECKING: 23 | from kashgari.tasks.labeling import ABCLabelingModel 24 | from kashgari.tasks.classification import ABCClassificationModel 25 | 26 | 27 | def custom_object_scope() -> CustomObjectScope: 28 | return tf.keras.utils.custom_object_scope(custom_objects) 29 | 30 | 31 | def load_model(model_path: str) -> Union["ABCLabelingModel", "ABCClassificationModel"]: 32 | warnings.warn("The 'load_model' function is deprecated, " 33 | "use 'XX_Model.load_model' instead", DeprecationWarning, 2) 34 | from kashgari.tasks.abs_task_model import ABCTaskModel 35 | return ABCTaskModel.load_model(model_path=model_path) 36 | 37 | 38 | if __name__ == "__main__": 39 | pass 40 | -------------------------------------------------------------------------------- /kashgari/utils/data.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: data.py 8 | # time: 11:22 上午 9 | 10 | import random 11 | from typing import List, Union, TypeVar, Tuple 12 | 13 | import numpy as np 14 | 15 | T = TypeVar("T") 16 | 17 | 18 | def get_list_subset(target: List[T], index_list: List[int]) -> List[T]: 19 | """ 20 | Get the subset of the target list 21 | Args: 22 | target: target list 23 | index_list: subset items index 24 | 25 | Returns: 26 | subset of the original list 27 | """ 28 | return [target[i] for i in index_list if i < len(target)] 29 | 30 | 31 | def unison_shuffled_copies(a: List[T], 32 | b: List[T]) -> Union[Tuple[List[T], ...], Tuple[np.ndarray, ...]]: 33 | """ 34 | Shuffle two sequences in unison 35 | Args: 36 | a: first sequence 37 | b: second sequence, same length as ``a`` 38 | 39 | Returns: 40 | The two shuffled copies, still aligned index by index. 41 | """ 42 | data_type = type(a) 43 | assert len(a) == len(b) 44 | c = list(zip(a, b)) 45 | random.shuffle(c) 46 | a, b = zip(*c) 47 | if data_type == np.ndarray: 48 | return np.array(a), np.array(b) 49 | return list(a), list(b) 50 | -------------------------------------------------------------------------------- /kashgari/utils/model.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: model.py 8 | # time:
10:57 上午 9 | 10 | import json 11 | import os 12 | import pathlib 13 | import time 14 | from typing import Union, Any 15 | 16 | from kashgari.tasks.abs_task_model import ABCTaskModel 17 | 18 | 19 | def convert_to_saved_model(model: ABCTaskModel, 20 | model_path: str, 21 | version: Union[str, int] = None, 22 | signatures: Any = None, 23 | options: Any = None) -> None: 24 | """ 25 | Export model for tensorflow serving 26 | Args: 27 | model: Target model. 28 | model_path: The path to which the SavedModel will be stored. 29 | version: The model version code, default timestamp 30 | signatures: Signatures to save with the SavedModel. Applicable to the 31 | 'tf' format only. Please see the `signatures` argument in 32 | `tf.saved_model.save` for details. 33 | options: Optional `tf.saved_model.SaveOptions` object that specifies 34 | options for saving to SavedModel. 35 | 36 | """ 37 | if not isinstance(model, ABCTaskModel): 38 | raise ValueError("Only supports the classification model and labeling model") 39 | if version is None: 40 | version = round(time.time()) 41 | export_path = os.path.join(model_path, str(version)) 42 | 43 | pathlib.Path(export_path).mkdir(exist_ok=True, parents=True) 44 | model.tf_model.save(export_path, save_format='tf', signatures=signatures, options=options) 45 | 46 | with open(os.path.join(export_path, 'model_config.json'), 'w') as f: 47 | f.write(json.dumps(model.to_dict(), indent=2, ensure_ascii=True)) 48 | f.close() 49 | 50 | 51 | if __name__ == "__main__": 52 | pass 53 | -------------------------------------------------------------------------------- /kashgari/utils/multi_label.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: multi_label.py 8 | # time: 11:23 上午 9 | 10 | 11 | from typing import List, Dict 12 | 13 | import numpy as np 14 | 15 | from kashgari.types import MultiLabelClassificationLabelVar 16 | 17 | 18 | class MultiLabelBinarizer: 19 | def __init__(self, vocab2idx: Dict[str, int]): 20 | self.vocab2idx = vocab2idx 21 | self.idx2vocab = dict([(v, k) for k, v in vocab2idx.items()]) 22 | 23 | @property 24 | def classes(self) -> List[str]: 25 | return list(self.idx2vocab.values()) 26 | 27 | def transform(self, samples: MultiLabelClassificationLabelVar) -> np.ndarray: 28 | data = np.zeros((len(samples), len(self.vocab2idx))) 29 | for sample_index, sample in enumerate(samples): 30 | for label in sample: 31 | data[sample_index][self.vocab2idx[label]] = 1 32 | return data 33 | 34 | def inverse_transform(self, preds: np.ndarray, threshold: float = 0.5) -> List[List[str]]: 35 | data = [] 36 | for sample in preds: 37 | x = [] 38 | for label_x in np.where(sample >= threshold)[0]: 39 | x.append(self.idx2vocab[label_x]) 40 | data.append(x) 41 | return data 42 | 43 | 44 | if __name__ == "__main__": 45 | pass 46 | -------------------------------------------------------------------------------- /kashgari/utils/serialize.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: serialize.py 8 | # time: 11:23 上午 9 | 10 | import pydoc 11 | from typing import Dict, Any 12 | 13 | 14 | def load_data_object(data: Dict, 15 | custom_objects: Dict = None, 16 | **kwargs: Dict) -> Any: 17 | """ 18 | Load Object From Dict 19 | Args: 20 | data: 21 | custom_objects: 22 
| **kwargs: 23 | 24 | Returns: 25 | 26 | """ 27 | if custom_objects is None: 28 | custom_objects = {} 29 | 30 | if data['__class_name__'] in custom_objects: 31 | obj: Any = custom_objects[data['__class_name__']](**data['config'], **kwargs) 32 | else: 33 | module_name = f"{data['__module__']}.{data['__class_name__']}" 34 | obj: Any = pydoc.locate(module_name)(**data['config'], **kwargs) # type: ignore 35 | if hasattr(obj, '_override_load_model'): 36 | obj._override_load_model(data) 37 | 38 | return obj 39 | 40 | 41 | if __name__ == "__main__": 42 | pass 43 | -------------------------------------------------------------------------------- /legacy_docs/docs/CNAME: -------------------------------------------------------------------------------- 1 | kashgari.bmio.net 2 | -------------------------------------------------------------------------------- /legacy_docs/docs/index.md: -------------------------------------------------------------------------------- 1 | # Welcome to Kashgari Documents 2 | 3 | Documents migrated to ReadTheDoc. 4 | 5 | 文档已迁移至 ReadTheDoc, 请按照版本查阅。 6 | 7 | - [Kashgari V1.x docs](https://kashgari.readthedocs.io/en/v1.1.5/) 8 | - [Kashgari V2.x docs](https://kashgari.readthedocs.io/en/latest/) 9 | 10 | ![](./version_selection.jpg) 11 | -------------------------------------------------------------------------------- /legacy_docs/docs/version_selection.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BrikerMan/Kashgari/ffe730d33f894e99a6fd7aa17ca67d161bf70359/legacy_docs/docs/version_selection.jpg -------------------------------------------------------------------------------- /legacy_docs/mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Kashgari 2 | theme: 3 | name: material 4 | feature: 5 | tabs: true 6 | palette: 7 | primary: red 8 | accent: red 9 | language: "en" 10 | repo_url: https://github.com/BrikerMan/Kashgari 11 | 12 | copyright: "Copyright © 2018 - 2018 BrikerMan" 13 | 14 | google_analytics: 15 | - "UA-78705624-3" 16 | - "auto" 17 | 18 | #markdown_extensions: 19 | # - markdown.extensions.admonition 20 | # - markdown.extensions.def_list 21 | # - markdown.extensions.footnotes 22 | # - markdown.extensions.meta 23 | # - markdown.extensions.toc: 24 | # permalink: "#" 25 | # # pymdown-extensions 26 | # - pymdownx.arithmatex 27 | # - pymdownx.betterem: 28 | # smart_enable: none 29 | # - pymdownx.caret 30 | # - pymdownx.critic 31 | # - pymdownx.details 32 | # - pymdownx.emoji: 33 | # emoji_generator: !!python/name:pymdownx.emoji.to_svg 34 | # - pymdownx.inlinehilite 35 | # - pymdownx.keys 36 | # - pymdownx.magiclink 37 | # - pymdownx.mark 38 | # - pymdownx.smartsymbols 39 | # - pymdownx.superfences 40 | # - pymdownx.highlight: 41 | # # css_class: codehilite # https://github.com/facelessuser/pymdown-extensions/blob/master/mkdocs.yml 42 | # extend_pygments_lang: 43 | # - name: python 44 | # lang: python 45 | # options: 46 | # startinline: true 47 | # - pymdownx.tasklist: 48 | # custom_checkbox: true 49 | # - pymdownx.tilde 50 | 51 | #extra_css: 52 | # - static/css/extra.css 53 | # 54 | extra_javascript: 55 | - static/js/baidu-static.js 56 | 57 | #extra: 58 | # social: 59 | # - type: github-alt 60 | # link: https://github.com/BrikerMan 61 | 62 | #nav: 63 | # - Introduction: index.md 64 | # - 中文文档: https://kashgari-zh.bmio.net 65 | # - Tutorials: 66 | # - tutorial/text-classification.md 67 | # - tutorial/text-labeling.md 68 | # - 
tutorial/text-scoring.md 69 | # 70 | # - Embeddings: 71 | # - embeddings/index.md 72 | # - embeddings/bare-embedding.md 73 | # - embeddings/word-embedding.md 74 | # - embeddings/bert-embedding.md 75 | # - embeddings/gpt2-embedding.md 76 | # - embeddings/numeric-features-embedding.md 77 | # - embeddings/stacked-embedding.md 78 | # 79 | # - Advanced: 80 | # - advance-use/multi-output-model.md 81 | # - advance-use/handle-numeric-features.md 82 | # - advance-use/tensorflow-serving.md 83 | # 84 | # - API: 85 | # - api/corpus.md 86 | # - api/embeddings.md 87 | # - api/tasks.classification.md 88 | # - api/tasks.labeling.md 89 | # - api/utils.md 90 | # - api/callbacks.md 91 | # 92 | # - FAQ: 93 | # - FAQ.md 94 | # 95 | # - About: 96 | # - about/contributing.md 97 | # - about/release-notes.md 98 | -------------------------------------------------------------------------------- /legacy_docs/readme.md: -------------------------------------------------------------------------------- 1 | # Legacy docs 2 | 3 | These are the original docs hosted at https://kashgari.bmio.net/. 4 | 5 | I had to switch to ReadTheDoc for versioning. 6 | -------------------------------------------------------------------------------- /requirements.dev.txt: -------------------------------------------------------------------------------- 1 | # test & coverage 2 | flake8==3.8.4 3 | flake8-builtins 4 | mypy==0.790 5 | pytest==5.4.3 6 | pytest-cov 7 | pytest-split 8 | coveralls 9 | 10 | # documents 11 | sphinx 12 | sphinx-autobuild 13 | sphinx-rtd-theme 14 | sphinx-markdown-tables 15 | sphinx-autodoc-typehints 16 | recommonmark 17 | m2r 18 | 19 | # develop 20 | jupyterlab 21 | tabulate 22 | pandas 23 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.18.1 2 | gensim>=3.8.1,<4.0.0 3 | pandas>=1.0.1 4 | tqdm 5 | # Limit this version to avoid json serialization issue.
6 | # See https://github.com/bojone/bert4keras/issues/241 7 | bert4keras>=0.9.1 8 | scikit-learn>=0.21.1 9 | # tensorflow>=2.1.0 10 | # tensorflow_addons 11 | -------------------------------------------------------------------------------- /scripts/clean.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | echo "delete caches" 6 | 7 | if [ -d '_site' ] ; then 8 | rm -r _site 9 | fi 10 | 11 | if [ -d 'kashgari.egg-info' ] ; then 12 | rm -r kashgari.egg-info 13 | fi 14 | 15 | if [ -d '.coverage.imac.46767.552596' ] ; then 16 | rm -r .coverage.imac.46767.552596 17 | fi 18 | 19 | if [ -d '.pytest_cache' ] ; then 20 | rm -r .pytest_cache 21 | fi 22 | 23 | if [ -d 'tf_dir' ] ; then 24 | rm -r tf_dir 25 | fi 26 | 27 | if [ -d '_site_src' ] ; then 28 | rm -r _site_src 29 | fi 30 | 31 | if [ -d 'dist' ] ; then 32 | rm -r dist 33 | fi 34 | 35 | if [ -d 'build' ] ; then 36 | rm -r build 37 | fi 38 | 39 | if [ -d 'htmlcov' ] ; then 40 | rm -r htmlcov 41 | fi 42 | 43 | if [ -d 'test_report' ] ; then 44 | rm -r test_report 45 | fi 46 | -------------------------------------------------------------------------------- /scripts/docs-generate.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | echo "Build API documents to _site folder" 6 | 7 | if [ -d '_site_src' ] ; then 8 | rm -r _site_src 9 | fi 10 | 11 | cp -r docs _site_src 12 | 13 | if [ -d 'site' ] ; then 14 | rm -r site 15 | fi 16 | 17 | sphinx-build _site_src _site -n -a -T -b dirhtml 18 | -------------------------------------------------------------------------------- /scripts/docs-lint.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | docker run --rm -it -u $(id -u):$(id -g) -v "$PWD/_site":/mnt linkchecker/linkchecker index.html 6 | -------------------------------------------------------------------------------- /scripts/docs-live.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | echo "Build and Run API documents" 4 | 5 | sh scripts/docs-generate.sh 6 | 7 | python3 -m http.server --directory _site 8808 8 | -------------------------------------------------------------------------------- /scripts/install_addons.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | tf_version = str(sys.argv[1]) 5 | addons_version = None  # default, so the check below never hits an undefined name 6 | # TF 2.0, 2.1 7 | if tf_version in ['2.0', '2.1']: 8 | addons_version = '0.9.1' 9 | # TF 2.2 10 | elif tf_version == '2.2': 11 | addons_version = '0.11.2' 12 | # TF 2.3+ 13 | elif tf_version in ['2.3', '2.4', '2.5']: 14 | addons_version = '0.13.0' 15 | 16 | if addons_version: 17 | print(f'Installing tensorflow-addons=={addons_version}') 18 | os.system(f"pip install 'tensorflow-addons=={addons_version}'") 19 | -------------------------------------------------------------------------------- /scripts/install_tf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | tf_args = str(sys.argv[1]) 5 | major_version, minor_version = tf_args.split('.') 6 | 7 | command = ( 8 | f"pip install 'tensorflow>={major_version}.{minor_version}.0," 9 | f"<{major_version}.{int(minor_version)+1}.0'" 10 | ) 11 | print(command) 12 | os.system(command) 13 | --------------------------------------------------------------------------------
/scripts/lint.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | set -x 5 | 6 | mypy kashgari --config-file=.config.ini 7 | flake8 kashgari --config=.config.ini 8 | -------------------------------------------------------------------------------- /scripts/markdown2rst.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: markdown2rst.py 8 | # time: 12:08 上午 9 | 10 | import sys 11 | from m2r import convert 12 | 13 | 14 | def convert_file(file_path: str, target_path: str = None): 15 | if target_path is None: 16 | target_path = file_path.replace('.md', '.rst') 17 | 18 | with open(file_path, 'r') as f: 19 | md_content = f.read() 20 | 21 | with open(target_path, 'w') as f: 22 | f.write(convert(md_content)) 23 | print(f'Saved RST file to {target_path}') 24 | 25 | 26 | if __name__ == "__main__": 27 | convert_file(*sys.argv[1:]) 28 | -------------------------------------------------------------------------------- /scripts/tests.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | set -x 5 | 6 | pytest --doctest-modules --junitxml=test-reports/junit.xml \ 7 | --cov=kashgari --cov-report=xml:coverage.xml --cov-report=html:htmlcov --cov-config .config.ini tests 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | """ 4 | @author: BrikerMan 5 | @contact: eliyar917@gmail.com 6 | @blog: https://eliyar.biz 7 | 8 | @version: 1.0 9 | @license: Apache Licence 10 | @file: setup.py 11 | @time: 2019-01-24 16:42 12 | 13 | """ 14 | import codecs 15 | import os 16 | import pathlib 17 | import re 18 | 19 | from setuptools import find_packages, setup 20 | 21 | HERE = pathlib.Path(__file__).parent 22 | 23 | 24 | def read(*parts): 25 | with codecs.open(os.path.join(HERE, *parts), 'r') as fp: 26 | return fp.read() 27 | 28 | 29 | def find_version(*file_paths): 30 | version_file = read(*file_paths) 31 | version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", 32 | version_file, re.M) 33 | if version_match: 34 | return version_match.group(1) 35 | raise RuntimeError("Unable to find version string.") 36 | 37 | 38 | __name__ = 'kashgari' 39 | __author__ = "BrikerMan" 40 | __copyright__ = "Copyright 2018, BrikerMan" 41 | __credits__ = [] 42 | __license__ = "Apache License 2.0" 43 | __maintainer__ = "BrikerMan" 44 | __email__ = "eliyar917@gmail.com" 45 | 46 | __url__ = 'https://github.com/BrikerMan/Kashgari' 47 | __description__ = 'Simple, Keras-powered multilingual NLP framework,' \ 48 | ' allows you to build your models in 5 minutes for named entity recognition (NER),' \ 49 | ' part-of-speech tagging (PoS) and text classification tasks. ' \ 50 | 'Includes BERT, GPT-2 and word2vec embedding.' 
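# The package version below is read from kashgari/__version__.py (via find_version above), keeping a single source of truth for the version string.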
51 | 52 | __version__ = find_version('kashgari', '__version__.py') 53 | README = (HERE / "README.md").read_text(encoding='utf-8') 54 | 55 | with codecs.open('requirements.txt', 'r', 'utf8') as reader: 56 | install_requires = list(map(lambda x: x.strip(), reader.readlines())) 57 | 58 | setup( 59 | name=__name__, 60 | version=__version__, 61 | description=__description__, 62 | python_requires='>3.6', 63 | long_description=README, 64 | long_description_content_type="text/markdown", 65 | author=__author__, 66 | author_email=__email__, 67 | url=__url__, 68 | packages=find_packages(exclude=('tests',)), 69 | install_requires=install_requires, 70 | include_package_data=True, 71 | license=__license__, 72 | classifiers=[ 73 | 'License :: OSI Approved :: Apache Software License', 74 | # 'Programming Language :: Python', 75 | 'Programming Language :: Python :: 3.6', 76 | 'Programming Language :: Python :: Implementation :: CPython', 77 | 'Programming Language :: Python :: Implementation :: PyPy' 78 | ], 79 | ) 80 | 81 | if __name__ == "__main__": 82 | print("Hello world") 83 | -------------------------------------------------------------------------------- /sonar-project.properties: -------------------------------------------------------------------------------- 1 | sonar.projectKey=BrikerMan_Kashgari 2 | sonar.organization=brikerman-github 3 | 4 | # This is the name and version displayed in the SonarCloud UI. 5 | #sonar.projectName=Kashgari 6 | #sonar.projectVersion=1.0 7 | 8 | # Path is relative to the sonar-project.properties file. Replace "\" by "/" on Windows. 9 | #sonar.sources=. 10 | 11 | # Encoding of the source code. Default is default system encoding 12 | #sonar.sourceEncoding=UTF-8 13 | 14 | sonar.python.coverage.reportPaths=artifacts/coverage*/coverage*.xml 15 | sonar.python.xunit.reportPath=artifacts/junit*/junit-*.xml 16 | -------------------------------------------------------------------------------- /test_performance/classifications.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Author : BrikerMan 4 | # Site : https://eliyar.biz 5 | 6 | # Time : 2020/8/29 11:16 上午 7 | # File : classifications.py 8 | # Project : Kashgari 9 | 10 | import logging 11 | import os 12 | import time 13 | from datetime import datetime 14 | from typing import Type 15 | 16 | import pandas as pd 17 | import tensorflow as tf 18 | 19 | from kashgari.callbacks import EvalCallBack 20 | from kashgari.corpus import SMP2018ECDTCorpus 21 | from kashgari.embeddings import BertEmbedding 22 | from kashgari.tasks.classification import ABCClassificationModel 23 | from kashgari.tasks.classification import ALL_MODELS 24 | from examples.tools import get_bert_path 25 | 26 | log_root = "tf_dir/classification/" + datetime.now().strftime("%m%d-%H:%M") 27 | 28 | 29 | class ClassificationPerformance: 30 | MODELS = ALL_MODELS 31 | 32 | def run_with_model_class(self, model_class: Type[ABCClassificationModel], epochs: int): 33 | bert_path = get_bert_path() 34 | 35 | train_x, train_y = SMP2018ECDTCorpus.load_data('train') 36 | valid_x, valid_y = SMP2018ECDTCorpus.load_data('valid') 37 | test_x, test_y = SMP2018ECDTCorpus.load_data('test') 38 | 39 | bert_embed = BertEmbedding(bert_path) 40 | model = model_class(bert_embed) 41 | 42 | log_path = os.path.join(log_root, model_class.__name__) 43 | file_writer = tf.summary.create_file_writer(log_path + "/metrics") 44 | file_writer.set_as_default() 45 | callbacks = [EvalCallBack(model, test_x, 
test_y, step=1)] 46 | 47 | model.fit(train_x, train_y, valid_x, valid_y, epochs=epochs, callbacks=callbacks) 48 | 49 | report = model.evaluate(test_x, test_y) 50 | del model 51 | del bert_embed 52 | return report 53 | 54 | def run(self, epochs=10): 55 | logging.basicConfig(level='DEBUG') 56 | reports = [] 57 | for model_class in self.MODELS: 58 | logging.info("=" * 80) 59 | logging.info("") 60 | logging.info("") 61 | logging.info(f" Start Training {model_class.__name__}") 62 | logging.info("") 63 | logging.info("") 64 | logging.info("=" * 80) 65 | start = time.time() 66 | report = self.run_with_model_class(model_class, epochs=epochs) 67 | time_cost = time.time() - start 68 | reports.append({ 69 | 'model_name': model_class.__name__, 70 | "epoch": epochs, 71 | 'f1-score': report['f1-score'], 72 | 'precision': report['precision'], 73 | 'recall': report['recall'], 74 | 'time': f"{int(time_cost // 60):02}:{int(time_cost % 60):02}" 75 | }) 76 | 77 | df = pd.DataFrame(reports) 78 | print(df.to_markdown()) 79 | 80 | 81 | if __name__ == '__main__': 82 | p = ClassificationPerformance() 83 | p.run() 84 | -------------------------------------------------------------------------------- /test_performance/labeling.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Author : BrikerMan 4 | # Site : https://eliyar.biz 5 | 6 | # Time : 2020/8/29 11:47 上午 7 | # File : labeling.py 8 | # Project : Kashgari 9 | 10 | import os 11 | from datetime import datetime 12 | from typing import Type 13 | 14 | import tensorflow as tf 15 | 16 | from kashgari.callbacks import EvalCallBack 17 | from kashgari.corpus import ChineseDailyNerCorpus 18 | from kashgari.embeddings import BertEmbedding 19 | from kashgari.tasks.labeling import ABCLabelingModel 20 | from kashgari.tasks.labeling import ALL_MODELS 21 | from test_performance.classifications import ClassificationPerformance 22 | from examples.tools import get_bert_path 23 | 24 | log_root = "tf_dir/labeling/" + datetime.now().strftime("%m%d-%H:%M") 25 | 26 | gpus = tf.config.experimental.list_physical_devices('GPU') 27 | if gpus: 28 | try: 29 | for gpu in gpus: 30 | tf.config.experimental.set_memory_growth(gpu, True) 31 | except RuntimeError as e: 32 | print(e) 33 | 34 | 35 | class LabelingPerformance(ClassificationPerformance): 36 | MODELS = ALL_MODELS 37 | 38 | def run_with_model_class(self, model_class: Type[ABCLabelingModel], epochs: int): 39 | bert_path = get_bert_path() 40 | 41 | train_x, train_y = ChineseDailyNerCorpus.load_data('train') 42 | valid_x, valid_y = ChineseDailyNerCorpus.load_data('valid') 43 | test_x, test_y = ChineseDailyNerCorpus.load_data('test') 44 | 45 | bert_embed = BertEmbedding(bert_path) 46 | model = model_class(bert_embed) 47 | 48 | log_path = os.path.join(log_root, model_class.__name__) 49 | file_writer = tf.summary.create_file_writer(log_path + "/metrics") 50 | file_writer.set_as_default() 51 | callbacks = [EvalCallBack(model, test_x, test_y, step=1, truncating=True)] 52 | # callbacks = [] 53 | model.fit(train_x, train_y, valid_x, valid_y, epochs=epochs, callbacks=callbacks) 54 | 55 | report = model.evaluate(test_x, test_y) 56 | del model 57 | del bert_embed 58 | return report 59 | 60 | 61 | if __name__ == '__main__': 62 | p = LabelingPerformance() 63 | p.run(epochs=10) 64 | -------------------------------------------------------------------------------- /test_performance/readme.md: 
-------------------------------------------------------------------------------- 1 | # Performance 2 | 3 | This directory is for running performance reports on models with BERT embedding. 4 | 5 | 6 | ## Classification 7 | 8 | ```python 9 | from kashgari.corpus import SMP2018ECDTCorpus 10 | 11 | train_x, train_y = SMP2018ECDTCorpus.load_data('train') 12 | valid_x, valid_y = SMP2018ECDTCorpus.load_data('valid') 13 | test_x, test_y = SMP2018ECDTCorpus.load_data('test') 14 | ``` 15 | 16 | | | model_name | epoch | f1-score | precision | recall | time (mm:ss) | 17 | |---:|:--------------------|--------:|-----------:|------------:|---------:|:-------| 18 | | 0 | BiGRU_Model | 10 | 0.9335 | 0.937795 | 0.935065 | 00:33 | 19 | | 1 | BiLSTM_Model | 10 | 0.929075 | 0.930548 | 0.92987 | 00:33 | 20 | | 2 | CNN_Attention_Model | 10 | 0.862197 | 0.888507 | 0.866234 | 00:27 | 21 | | 3 | CNN_GRU_Model | 10 | 0.840024 | 0.886519 | 0.850649 | 00:28 | 22 | | 4 | CNN_LSTM_Model | 10 | 0.424649 | 0.551247 | 0.511688 | 00:27 | 23 | | 5 | CNN_Model | 10 | 0.930336 | 0.938373 | 0.931169 | 00:26 | 24 | 25 | ## NER 26 | 27 | ```python 28 | from kashgari.corpus import ChineseDailyNerCorpus 29 | 30 | train_x, train_y = ChineseDailyNerCorpus.load_data('train') 31 | valid_x, valid_y = ChineseDailyNerCorpus.load_data('valid') 32 | test_x, test_y = ChineseDailyNerCorpus.load_data('test') 33 | ``` 34 | 35 | | | model_name | epoch | f1-score | precision | recall | time (mm:ss) | 36 | |---:|:-----------------|--------:|-----------:|------------:|---------:|:-------| 37 | | 0 | BiGRU_Model | 10 | 0.921583 | 0.913184 | 0.930532 | 19:10 | 38 | | 1 | BiGRU_CRF_Model | 10 | 0.935163 | 0.931246 | 0.939118 | 24:30 | 39 | | 2 | BiLSTM_Model | 10 | 0.915363 | 0.906566 | 0.924418 | 19:12 | 40 | | 3 | BiLSTM_CRF_Model | 10 | 0.940539 | 0.944549 | 0.936646 | 24:31 | 41 | | 4 | CNN_LSTM_Model | 10 | 0.919783 | 0.909695 | 0.930272 | 19:07 | 42 | 
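Both tables above are produced by the scripts in this directory. A minimal sketch of how to reproduce them (this assumes a BERT checkpoint reachable through `examples.tools.get_bert_path`; times depend on your GPU):

```python
from test_performance.classifications import ClassificationPerformance
from test_performance.labeling import LabelingPerformance

# Each runner trains every model class in ALL_MODELS with a BERT embedding,
# evaluates it on the test split, and prints a markdown table like the ones above.
ClassificationPerformance().run(epochs=10)
LabelingPerformance().run(epochs=10)
```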
-------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: __init__.py 8 | # time: 2019-05-31 19:09 9 | 10 | 11 | if __name__ == "__main__": 12 | print("Hello world") 13 | -------------------------------------------------------------------------------- /tests/test_classification/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: __init__.py 8 | # time: 1:57 下午 9 | 10 | 11 | if __name__ == "__main__": 12 | pass 13 | -------------------------------------------------------------------------------- /tests/test_classification/test_bi_gru_model.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: test_bi_gru_model.py 8 | # time: 3:28 下午 9 | 10 | import unittest 11 | 12 | import tests.test_classification.test_bi_lstm_model as base 13 | from kashgari.embeddings import WordEmbedding 14 | from kashgari.tasks.classification import BiGRU_Model 15 | from tests.test_macros import TestMacros 16 | 17 | 18 | class TestBiGRU_Model(base.TestBiLSTM_Model): 19 | @classmethod 20 | def setUpClass(cls): 21 | cls.EPOCH_COUNT = 2 22 | cls.TASK_MODEL_CLASS = BiGRU_Model 23 | cls.w2v_embedding = WordEmbedding(TestMacros.w2v_path) 24 | 25 | 26 | if __name__ == "__main__": 27 | unittest.main() 28 | -------------------------------------------------------------------------------- /tests/test_classification/test_bi_lstm_model.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: test_bi_lstm_model.py 8 | # time: 1:57 下午 9 | 10 | import os 11 | import tempfile 12 | import time 13 | import unittest 14 | 15 | from kashgari.corpus import SMP2018ECDTCorpus 16 | from kashgari.embeddings import WordEmbedding 17 | from kashgari.tasks.classification import BiLSTM_Model 18 | from tests.test_macros import TestMacros 19 | 20 | 21 | class TestBiLSTM_Model(unittest.TestCase): 22 | 23 | @classmethod 24 | def setUpClass(cls): 25 | cls.EPOCH_COUNT = 1 26 | cls.TASK_MODEL_CLASS = BiLSTM_Model 27 | cls.w2v_embedding = WordEmbedding(TestMacros.w2v_path) 28 | 29 | @classmethod 30 | def tearDownClass(cls) -> None: 31 | cls.w2v_embedding = None 32 | 33 | def test_basic_use(self): 34 | model = self.TASK_MODEL_CLASS(sequence_length=20) 35 | train_x, train_y = SMP2018ECDTCorpus.load_data() 36 | valid_x, valid_y = train_x, train_y 37 | 38 | model.fit(train_x, 39 | train_y, 40 | x_validate=valid_x, 41 | y_validate=valid_y, 42 | epochs=self.EPOCH_COUNT) 43 | 44 | model_path = os.path.join(tempfile.gettempdir(), str(time.time())) 45 | original_y = model.predict(train_x[:20]) 46 | model.save(model_path) 47 | 48 | # Make sure softmax is used as the activation function 49 | assert model.tf_model.layers[-1].activation.__name__ == 'softmax' 50 | 51 | del model 52 | new_model = self.TASK_MODEL_CLASS.load_model(model_path) 53 | new_model.tf_model.summary() 54 | new_y = new_model.predict(train_x[:20]) 55 | assert new_y == original_y 56 | 57 | report = new_model.evaluate(valid_x, valid_y) 58 | for key in ['precision', 'recall', 'f1-score', 'support', 'detail']: 59 | assert key in report 60 | 61 | # Make sure softmax is used as the activation function 62 | assert new_model.tf_model.layers[-1].activation.__name__ == 'softmax' 63 | 64 | # TF Serving Test 65 | from kashgari.utils import convert_to_saved_model 66 | convert_to_saved_model(new_model, 67 | os.path.join(model_path, 'serving'), 68 | version=1) 69 | 70 | from kashgari.processors import load_processors_from_model 71 | _ = load_processors_from_model(os.path.join(model_path, 'serving', '1')) 72 | 73 | def test_multi_label(self): 74 | corpus = TestMacros.jigsaw_mini_corpus 75 | model = self.TASK_MODEL_CLASS(sequence_length=20, multi_label=True) 76 | x, y = corpus.load_data() 77 | model.fit(x, y, epochs=self.EPOCH_COUNT) 78 | 79 | model_path = os.path.join(tempfile.gettempdir(), str(time.time())) 80 | original_y = model.predict(x[:20]) 81 | model.save(model_path) 82 | 83 | # Make sure sigmoid is used as the activation function 84 | assert model.tf_model.layers[-1].activation.__name__ == 'sigmoid' 85 | del model 86 | 87 | new_model = self.TASK_MODEL_CLASS.load_model(model_path) 88 | new_model.tf_model.summary() 89 | new_y = new_model.predict(x[:20]) 90 | 91 | assert new_y == original_y 92 | 93 | report = new_model.evaluate(x, y) 94 | for key in ['precision', 'recall', 'f1-score', 'support', 'detail']: 95 | assert key in report 96 | 97 | # Make sure sigmoid is used as the activation function 98 | assert new_model.tf_model.layers[-1].activation.__name__ == 'sigmoid' 99 | 
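# The w2v fixture used below comes from TestMacros.w2v_path. A minimal
# standalone sketch of the same flow, assuming a local word2vec text file
# (the path here is hypothetical, not part of the repo):
#
#     embedding = WordEmbedding('/path/to/sample_w2v.txt')
#     model = BiLSTM_Model(embedding=embedding)
#     model.fit(train_x, train_y, epochs=1)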
100 | def test_with_word_embedding(self): 101 | model = self.TASK_MODEL_CLASS(embedding=self.w2v_embedding) 102 | train_x, train_y = SMP2018ECDTCorpus.load_data() 103 | valid_x, valid_y = train_x, train_y 104 | 105 | model.fit(train_x, 106 | train_y, 107 | x_validate=valid_x, 108 | y_validate=valid_y, 109 | epochs=self.EPOCH_COUNT) 110 | 111 | model_path = os.path.join(tempfile.gettempdir(), str(time.time())) 112 | _ = model.predict(valid_x[:20]) 113 | model.save(model_path) 114 | 115 | del model 116 | 117 | new_model = self.TASK_MODEL_CLASS.load_model(model_path) 118 | new_model.tf_model.summary() 119 | _ = new_model.predict(valid_x[:20]) 120 | 121 | 122 | if __name__ == '__main__': 123 | unittest.main() 124 | -------------------------------------------------------------------------------- /tests/test_classification/test_cnn_attention_model.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: Adline 4 | # contact: gglfxsld@gmail.com 5 | # blog: https://medium.com/@Adline125 6 | 7 | # file: test_cnn_attention_model.py 8 | # time: 3:31 下午 9 | 10 | from distutils.version import LooseVersion 11 | 12 | import pytest 13 | import unittest 14 | import tensorflow as tf 15 | 16 | import tests.test_classification.test_bi_lstm_model as base 17 | from kashgari.embeddings import WordEmbedding 18 | from kashgari.tasks.classification.cnn_attention_model import CNN_Attention_Model 19 | from tests.test_macros import TestMacros 20 | 21 | 22 | @pytest.mark.xfail(LooseVersion(tf.__version__) < '2.1.0', 23 | reason='Attention Layer cannot be loaded and saved in TF 2.0.0') 24 | class TestCnnAttention_Model(base.TestBiLSTM_Model): 25 | @classmethod 26 | def setUpClass(cls): 27 | cls.EPOCH_COUNT = 1 28 | cls.TASK_MODEL_CLASS = CNN_Attention_Model 29 | cls.w2v_embedding = WordEmbedding(TestMacros.w2v_path) 30 | 31 | def test_multi_label(self): 32 | super(TestCnnAttention_Model, self).test_multi_label() 33 | 34 | 35 | if __name__ == "__main__": 36 | unittest.main() 37 | -------------------------------------------------------------------------------- /tests/test_classification/test_cnn_gru_model.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: test_cnn_gru_model.py 8 | # time: 5:17 下午 9 | 10 | import unittest 11 | 12 | import tests.test_classification.test_bi_lstm_model as base 13 | from kashgari.embeddings import WordEmbedding 14 | from kashgari.tasks.classification import CNN_GRU_Model 15 | from tests.test_macros import TestMacros 16 | 17 | 18 | class TestCNN_GRU_Model(base.TestBiLSTM_Model): 19 | @classmethod 20 | def setUpClass(cls): 21 | cls.EPOCH_COUNT = 1 22 | cls.TASK_MODEL_CLASS = CNN_GRU_Model 23 | cls.w2v_embedding = WordEmbedding(TestMacros.w2v_path) 24 | 25 | 26 | if __name__ == "__main__": 27 | unittest.main() 28 | -------------------------------------------------------------------------------- /tests/test_classification/test_cnn_lstm_model.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: test_cnn_lstm_model.py 8 | # time: 5:17 下午 9 | 10 | import unittest 11 | 12 | import tests.test_classification.test_bi_lstm_model as base 13 | from kashgari.embeddings import WordEmbedding 14 | from kashgari.tasks.classification import CNN_LSTM_Model 15 | from tests.test_macros import TestMacros 16 | 17 | 
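# Each module in tests/test_classification follows the same pattern: subclass
# TestBiLSTM_Model and override only setUpClass, so the whole shared suite
# (fit, save, load, predict, evaluate) is re-run against the new model class.
# A minimal sketch of the pattern for a hypothetical MyModel (the name is
# illustrative, not part of the library):
#
#     class TestMy_Model(base.TestBiLSTM_Model):
#         @classmethod
#         def setUpClass(cls):
#             cls.EPOCH_COUNT = 1
#             cls.TASK_MODEL_CLASS = MyModel
#             cls.w2v_embedding = WordEmbedding(TestMacros.w2v_path)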
18 | class TestCNN_LSTM_Model(base.TestBiLSTM_Model): 19 | @classmethod 20 | def setUpClass(cls): 21 | cls.EPOCH_COUNT = 1 22 | cls.TASK_MODEL_CLASS = CNN_LSTM_Model 23 | cls.w2v_embedding = WordEmbedding(TestMacros.w2v_path) 24 | 25 | 26 | if __name__ == "__main__": 27 | unittest.main() 28 | -------------------------------------------------------------------------------- /tests/test_classification/test_cnn_model.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: test_cnn_model.py 8 | # time: 3:33 下午 9 | 10 | import unittest 11 | 12 | import tests.test_classification.test_bi_lstm_model as base 13 | from kashgari.embeddings import WordEmbedding 14 | from kashgari.tasks.classification import CNN_Model 15 | from tests.test_macros import TestMacros 16 | 17 | 18 | class TestCNN_Model(base.TestBiLSTM_Model): 19 | @classmethod 20 | def setUpClass(cls): 21 | cls.EPOCH_COUNT = 1 22 | cls.TASK_MODEL_CLASS = CNN_Model 23 | cls.w2v_embedding = WordEmbedding(TestMacros.w2v_path) 24 | 25 | 26 | if __name__ == "__main__": 27 | unittest.main() 28 | -------------------------------------------------------------------------------- /tests/test_classification/test_custom_model.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: test_custom_model.py 8 | # time: 6:07 下午 9 | 10 | import unittest 11 | from typing import Dict, Any 12 | 13 | from tensorflow import keras 14 | 15 | import tests.test_classification.test_bi_lstm_model as base 16 | from kashgari.embeddings import WordEmbedding 17 | from kashgari.layers import L 18 | from kashgari.tasks.classification.abc_model import ABCClassificationModel 19 | from tests.test_macros import TestMacros 20 | 21 | 22 | class Double_BiLSTM_Model(ABCClassificationModel): 23 | @classmethod 24 | def default_hyper_parameters(cls) -> Dict[str, Any]: 25 | return { 26 | 'layer_lstm1': { 27 | 'units': 128, 28 | 'return_sequences': True 29 | }, 30 | 'layer_lstm2': { 31 | 'units': 64, 32 | 'return_sequences': False 33 | }, 34 | 'layer_dropout': { 35 | 'rate': 0.5 36 | }, 37 | 'layer_output': { 38 | 39 | } 40 | } 41 | 42 | def build_model_arc(self) -> None: 43 | config = self.hyper_parameters 44 | output_dim = self.label_processor.vocab_size 45 | embed_model = self.embedding.embed_model 46 | 47 | # Define the model architecture 48 | self.tf_model = keras.Sequential([ 49 | embed_model, 50 | L.Bidirectional(L.LSTM(**config['layer_lstm1'])), 51 | L.Bidirectional(L.LSTM(**config['layer_lstm2'])), 52 | L.Dropout(**config['layer_dropout']), 53 | L.Dense(output_dim, **config['layer_output']), 54 | self._activation_layer() 55 | ]) 56 | 57 | 58 | class TestCustom_Model(base.TestBiLSTM_Model): 59 | @classmethod 60 | def setUpClass(cls): 61 | cls.EPOCH_COUNT = 1 62 | cls.TASK_MODEL_CLASS = Double_BiLSTM_Model 63 | cls.w2v_embedding = WordEmbedding(TestMacros.w2v_path) 64 | 65 | 66 | if __name__ == "__main__": 67 | unittest.main() 68 | -------------------------------------------------------------------------------- /tests/test_corpus.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: test_corpus.py 8 | # time: 10:47 上午 9 | 10 | 
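# The corpus classes exercised below share one contract: load_data(subset),
# with subset one of 'train', 'valid' or 'test' (the tests also call it with
# no argument for the default split), returns a pair of parallel lists (x, y).
# A minimal usage sketch of that contract, using only calls seen in this repo:
#
#     from kashgari.corpus import SMP2018ECDTCorpus
#     train_x, train_y = SMP2018ECDTCorpus.load_data('train')
#     assert len(train_x) == len(train_y)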
import unittest 11 | from kashgari.corpus import ChineseDailyNerCorpus 12 | from kashgari.corpus import SMP2018ECDTCorpus 13 | 14 | 15 | class TestChineseDailyNerCorpus(unittest.TestCase): 16 | 17 | def test_load_data(self): 18 | train_x, train_y = ChineseDailyNerCorpus.load_data() 19 | assert len(train_x) == len(train_y) 20 | assert len(train_x) > 0 21 | assert train_x[:5] != train_y[:5] 22 | 23 | test_x, test_y = ChineseDailyNerCorpus.load_data('test') 24 | assert len(test_x) == len(test_y) 25 | assert len(test_x) > 0 26 | 27 | test_x, test_y = ChineseDailyNerCorpus.load_data('valid') 28 | assert len(test_x) == len(test_y) 29 | assert len(test_x) > 0 30 | 31 | 32 | class TestSMP2018ECDTCorpus(unittest.TestCase): 33 | 34 | def test_load_data(self): 35 | train_x, train_y = SMP2018ECDTCorpus.load_data() 36 | assert len(train_x) == len(train_y) 37 | assert len(train_x) > 0 38 | assert train_x[:5] != train_y[:5] 39 | 40 | test_x, test_y = SMP2018ECDTCorpus.load_data('test') 41 | assert len(test_x) == len(test_y) 42 | assert len(test_x) > 0 43 | 44 | test_x, test_y = SMP2018ECDTCorpus.load_data('valid') 45 | assert len(test_x) == len(test_y) 46 | assert len(test_x) > 0 47 | 48 | 49 | if __name__ == "__main__": 50 | pass 51 | -------------------------------------------------------------------------------- /tests/test_embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: __init__.py 8 | # time: 10:47 上午 9 | 10 | 11 | if __name__ == "__main__": 12 | pass 13 | -------------------------------------------------------------------------------- /tests/test_embeddings/test_bare_embedding.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: test_bare_embedding.py 8 | # time: 2:29 下午 9 | 10 | import os 11 | import time 12 | import random 13 | import tempfile 14 | import unittest 15 | from kashgari.logger import logger 16 | from kashgari.processors import SequenceProcessor 17 | from kashgari.corpus import SMP2018ECDTCorpus 18 | from kashgari.embeddings import BareEmbedding 19 | from kashgari.tasks.classification import BiGRU_Model 20 | from kashgari.utils import load_data_object 21 | 22 | sample_count = 50 23 | 24 | 25 | class TestBareEmbedding(unittest.TestCase): 26 | 27 | def build_embedding(self): 28 | embedding = BareEmbedding() 29 | return embedding 30 | 31 | def test_base_cases(self): 32 | embedding = self.build_embedding() 33 | x, y = SMP2018ECDTCorpus.load_data() 34 | processor = SequenceProcessor() 35 | processor.build_vocab(x, y) 36 | embedding.setup_text_processor(processor) 37 | 38 | samples = random.sample(x, sample_count) 39 | res = embedding.embed(samples) 40 | max_len = max([len(i) for i in samples]) + 2 41 | 42 | if embedding.max_position is not None: 43 | max_len = embedding.max_position 44 | 45 | assert res.shape == (len(samples), max_len, embedding.embedding_size) 46 | 47 | # Test Save And Load 48 | embed_dict = embedding.to_dict() 49 | embedding2 = load_data_object(embed_dict) 50 | embedding2.setup_text_processor(processor) 51 | assert embedding2.embed(samples).shape == (len(samples), max_len, embedding.embedding_size) 52 | 53 | def test_with_model(self): 54 | x, y = SMP2018ECDTCorpus.load_data('test') 55 | embedding = self.build_embedding() 56 
| 57 | model = BiGRU_Model(embedding=embedding) 58 | model.build_model(x, y) 59 | model_summary = [] 60 | embedding.embed_model.summary(print_fn=lambda x: model_summary.append(x)) 61 | logger.debug('\n'.join(model_summary)) 62 | 63 | model.fit(x, y, epochs=1) 64 | 65 | model_path = os.path.join(tempfile.gettempdir(), str(time.time())) 66 | model.save(model_path) 67 | 68 | 69 | if __name__ == "__main__": 70 | unittest.main() 71 | -------------------------------------------------------------------------------- /tests/test_embeddings/test_transformer_embedding.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: test_transformer_embedding.py 8 | # time: 2:47 下午 9 | 10 | from tensorflow.keras.utils import get_file 11 | 12 | from kashgari.embeddings import BertEmbedding 13 | from kashgari.macros import DATA_PATH 14 | from tests.test_embeddings.test_bare_embedding import TestBareEmbedding 15 | 16 | 17 | class TestTransferEmbedding(TestBareEmbedding): 18 | 19 | def build_embedding(self): 20 | bert_path = get_file('bert_sample_model', 21 | "http://s3.bmio.net/kashgari/bert_sample_model.tar.bz2", 22 | cache_dir=DATA_PATH, 23 | untar=True) 24 | embedding = BertEmbedding(model_folder=bert_path) 25 | return embedding 26 | 27 | 28 | if __name__ == "__main__": 29 | pass 30 | -------------------------------------------------------------------------------- /tests/test_embeddings/test_word_embedding.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: test_word_embedding.py 8 | # time: 2:55 下午 9 | 10 | import unittest 11 | 12 | from tensorflow.keras.utils import get_file 13 | 14 | from kashgari.embeddings import WordEmbedding 15 | from kashgari.macros import DATA_PATH 16 | from tests.test_embeddings.test_bare_embedding import TestBareEmbedding 17 | 18 | 19 | class TestWordEmbedding(TestBareEmbedding): 20 | 21 | def build_embedding(self): 22 | sample_w2v_path = get_file('sample_w2v.txt', 23 | "http://s3.bmio.net/kashgari/sample_w2v.txt", 24 | cache_dir=DATA_PATH) 25 | embedding = WordEmbedding(sample_w2v_path) 26 | return embedding 27 | 28 | 29 | if __name__ == '__main__': 30 | unittest.main() 31 | -------------------------------------------------------------------------------- /tests/test_generator.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: test_generator.py 8 | # time: 5:46 下午 9 | 10 | import unittest 11 | 12 | from kashgari.corpus import ChineseDailyNerCorpus 13 | from kashgari.generators import CorpusGenerator, BatchDataSet 14 | from kashgari.processors import SequenceProcessor 15 | from tests.test_macros import TestMacros 16 | 17 | 18 | class TestGenerator(unittest.TestCase): 19 | def test_corpus_generator(self): 20 | x_set, y_set = TestMacros.load_labeling_corpus() 21 | corpus_gen = CorpusGenerator(x_set, y_set) 22 | pass 23 | 24 | def test_batch_generator(self): 25 | x, y = ChineseDailyNerCorpus.load_data('valid') 26 | 27 | text_processor = SequenceProcessor() 28 | label_processor = SequenceProcessor(build_vocab_from_labels=True, min_count=1) 29 | 30 | corpus_gen = CorpusGenerator(x, y) 31 | 32 | 
text_processor.build_vocab_generator([corpus_gen]) 33 | label_processor.build_vocab_generator([corpus_gen]) 34 | 35 | batch_dataset = BatchDataSet(corpus_gen, 36 | text_processor=text_processor, 37 | label_processor=label_processor, 38 | segment=False, 39 | seq_length=None, 40 | max_position=100, 41 | batch_size=12) 42 | 43 | duplicate_len = len(batch_dataset) 44 | assert len(list(batch_dataset.take(duplicate_len))) == duplicate_len 45 | assert len(list(batch_dataset.take(1))) == 1 46 | 47 | def test_huge_batch_size(self): 48 | x, y = [['this', 'is', 'Jack', 'Ma']], [['O', 'O', 'B', 'I']] 49 | 50 | text_processor = SequenceProcessor() 51 | label_processor = SequenceProcessor(build_vocab_from_labels=True, min_count=1) 52 | 53 | corpus_gen = CorpusGenerator(x, y) 54 | 55 | text_processor.build_vocab_generator([corpus_gen]) 56 | label_processor.build_vocab_generator([corpus_gen]) 57 | 58 | batch_dataset = BatchDataSet(corpus_gen, 59 | text_processor=text_processor, 60 | label_processor=label_processor, 61 | segment=False, 62 | seq_length=None, 63 | max_position=100, 64 | batch_size=512) 65 | 66 | for x_b, y_b in batch_dataset.take(1): 67 | print(y_b.shape) 68 | duplicate_len = len(batch_dataset) 69 | assert len(list(batch_dataset.take(duplicate_len))) == duplicate_len 70 | assert len(list(batch_dataset.take(1))) == 1 71 | 72 | 73 | if __name__ == '__main__': 74 | unittest.main() 75 | -------------------------------------------------------------------------------- /tests/test_labeling/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: __init__.py 8 | # time: 4:40 下午 9 | 10 | 11 | if __name__ == "__main__": 12 | pass 13 | -------------------------------------------------------------------------------- /tests/test_labeling/test_bi_gru_crf_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Author : BrikerMan 4 | # Site : https://eliyar.biz 5 | 6 | # Time : 2020/9/2 2:09 下午 7 | # File : test_bi_gru_crf_model.py 8 | # Project : Kashgari 9 | from distutils.version import LooseVersion 10 | 11 | import pytest 12 | import tensorflow as tf 13 | 14 | import tests.test_labeling.test_bi_lstm_model as base 15 | from kashgari.tasks.labeling import BiGRU_CRF_Model 16 | 17 | 18 | @pytest.mark.skipif(LooseVersion(tf.__version__) < '2.2.0', 19 | reason="The KConditionalRandomField requires TensorFlow 2.2.x version or higher.") 20 | class TestBiGRU_CRF_Model(base.TestBiLSTM_Model): 21 | 22 | @classmethod 23 | def setUpClass(cls): 24 | cls.EPOCH_COUNT = 1 25 | cls.TASK_MODEL_CLASS = BiGRU_CRF_Model 26 | -------------------------------------------------------------------------------- /tests/test_labeling/test_bi_gru_model.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: test_bi_gru_model.py 8 | # time: 12:35 上午 9 | 10 | import unittest 11 | 12 | import tests.test_labeling.test_bi_lstm_model as base 13 | from kashgari.tasks.labeling import BiGRU_Model 14 | 15 | 16 | class TestBiGRU_Model(base.TestBiLSTM_Model): 17 | 18 | @classmethod 19 | def setUpClass(cls): 20 | cls.EPOCH_COUNT = 1 21 | cls.TASK_MODEL_CLASS = BiGRU_Model 22 | 23 | def test_basic_use(self): 24 | 
super(TestBiGRU_Model, self).test_basic_use() 25 | 26 | def test_predict_and_callback(self): 27 | from kashgari.corpus import ChineseDailyNerCorpus 28 | from kashgari.callbacks import EvalCallBack 29 | 30 | train_x, train_y = ChineseDailyNerCorpus.load_data('train') 31 | valid_x, valid_y = ChineseDailyNerCorpus.load_data('valid') 32 | 33 | model = BiGRU_Model(sequence_length=10) 34 | 35 | eval_callback = EvalCallBack(kash_model=model, 36 | x_data=valid_x[:200], 37 | y_data=valid_y[:200], 38 | truncating=True, 39 | step=1) 40 | 41 | model.fit(train_x[:300], train_y[:300], 42 | valid_x[:200], valid_y[:200], 43 | epochs=1, 44 | callbacks=[eval_callback]) 45 | response = model.predict(train_x[:200], truncating=True) 46 | lengths = [len(i) for i in response] 47 | assert all([(i <= 10) for i in lengths]) 48 | 49 | response = model.predict(train_x[:200]) 50 | lengths = [len(i) for i in response] 51 | assert not all([(i <= 10) for i in lengths]) 52 | 53 | 54 | if __name__ == "__main__": 55 | pass 56 | -------------------------------------------------------------------------------- /tests/test_labeling/test_bi_lstm_crf_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Author : BrikerMan 4 | # Site : https://eliyar.biz 5 | 6 | # Time : 2020/9/2 2:10 下午 7 | # File : test_bi_lstm_crf_model.py 8 | # Project : Kashgari 9 | 10 | from distutils.version import LooseVersion 11 | 12 | import pytest 13 | import tensorflow as tf 14 | 15 | import tests.test_labeling.test_bi_lstm_model as base 16 | from kashgari.tasks.labeling import BiLSTM_CRF_Model 17 | 18 | 19 | @pytest.mark.skipif(LooseVersion(tf.__version__) < '2.2.0', 20 | reason="The KConditionalRandomField requires TensorFlow 2.2.x version or higher.") 21 | class TestBiLSTM_CRF_Model(base.TestBiLSTM_Model): 22 | 23 | @classmethod 24 | def setUpClass(cls): 25 | cls.EPOCH_COUNT = 1 26 | cls.TASK_MODEL_CLASS = BiLSTM_CRF_Model 27 | -------------------------------------------------------------------------------- /tests/test_labeling/test_bi_lstm_model.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: test_bi_lstm_model.py 8 | # time: 4:41 下午 9 | 10 | import os 11 | import tempfile 12 | import time 13 | import unittest 14 | from typing import Type 15 | 16 | from tensorflow.keras.utils import get_file 17 | 18 | from kashgari.embeddings import BertEmbedding 19 | from kashgari.embeddings import WordEmbedding 20 | from kashgari.macros import DATA_PATH 21 | from kashgari.tasks.labeling import BiLSTM_Model, ABCLabelingModel 22 | from tests.test_macros import TestMacros 23 | 24 | 25 | class TestBiLSTM_Model(unittest.TestCase): 26 | 27 | @classmethod 28 | def setUpClass(cls): 29 | cls.EPOCH_COUNT = 1 30 | cls.TASK_MODEL_CLASS: Type[ABCLabelingModel] = BiLSTM_Model 31 | 32 | def test_basic_use(self): 33 | model = self.TASK_MODEL_CLASS() 34 | train_x, train_y = TestMacros.load_labeling_corpus() 35 | 36 | model.fit(train_x, 37 | train_y, 38 | epochs=self.EPOCH_COUNT) 39 | 40 | model_path = os.path.join(tempfile.gettempdir(), str(time.time())) 41 | original_y = model.predict(train_x[:20]) 42 | model.save(model_path) 43 | del model 44 | 45 | new_model = self.TASK_MODEL_CLASS.load_model(model_path) 46 | new_model.tf_model.summary() 47 | new_y = new_model.predict(train_x[:20]) 48 | assert new_y == original_y 49 
| 50 | report = new_model.evaluate(train_x, train_y) 51 | print(report) 52 | 53 | def test_with_word_embedding(self): 54 | w2v_embedding = WordEmbedding(TestMacros.w2v_path) 55 | model = self.TASK_MODEL_CLASS(embedding=w2v_embedding, sequence_length=120) 56 | train_x, train_y = TestMacros.load_labeling_corpus() 57 | valid_x, valid_y = train_x, train_y 58 | 59 | model.fit(train_x, 60 | train_y, 61 | x_validate=valid_x, 62 | y_validate=valid_y, 63 | epochs=self.EPOCH_COUNT) 64 | 65 | def test_with_bert(self): 66 | bert_path = get_file('bert_sample_model', 67 | "http://s3.bmio.net/kashgari/bert_sample_model.tar.bz2", 68 | cache_dir=DATA_PATH, 69 | untar=True) 70 | embedding = BertEmbedding(model_folder=bert_path) 71 | model = self.TASK_MODEL_CLASS(embedding=embedding) 72 | train_x, train_y = TestMacros.load_labeling_corpus() 73 | valid_x, valid_y = train_x, train_y 74 | 75 | model.fit(train_x, 76 | train_y, 77 | x_validate=valid_x, 78 | y_validate=valid_y, 79 | epochs=self.EPOCH_COUNT) 80 | 81 | model.evaluate(valid_x, valid_y) 82 | model.evaluate(valid_x, valid_y, truncating=True) 83 | model.predict(valid_x) 84 | model.predict(valid_x, truncating=True) 85 | 86 | 87 | if __name__ == '__main__': 88 | unittest.main() 89 | -------------------------------------------------------------------------------- /tests/test_labeling/test_cnn_lstm_model.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: test_cnn_lstm_model.py 8 | # time: 5:41 下午 9 | 10 | 11 | import unittest 12 | 13 | import tests.test_labeling.test_bi_lstm_model as base 14 | from kashgari.tasks.labeling import CNN_LSTM_Model 15 | 16 | 17 | class TestCNN_LSTM_Model(base.TestBiLSTM_Model): 18 | 19 | @classmethod 20 | def setUpClass(cls): 21 | cls.EPOCH_COUNT = 1 22 | cls.TASK_MODEL_CLASS = CNN_LSTM_Model 23 | 24 | 25 | if __name__ == "__main__": 26 | unittest.main() 27 | -------------------------------------------------------------------------------- /tests/test_processor/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: __init__.py 8 | # time: 12:06 下午 9 | 10 | 11 | if __name__ == "__main__": 12 | pass 13 | -------------------------------------------------------------------------------- /tests/test_processor/test_class_processor.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: test_class_processor.py 8 | # time: 12:04 下午 9 | 10 | 11 | import unittest 12 | from tests.test_macros import TestMacros 13 | 14 | from kashgari.utils import load_data_object 15 | from kashgari.generators import CorpusGenerator 16 | from kashgari.processors import ClassificationProcessor 17 | 18 | 19 | class TestClassificationProcessor(unittest.TestCase): 20 | def test_processor(self): 21 | x_set, y_set = TestMacros.load_classification_corpus() 22 | processor = ClassificationProcessor() 23 | processor.build_vocab(x_set, y_set) 24 | transformed_idx = processor.transform(y_set[20:40]) 25 | 26 | info_dict = processor.to_dict() 27 | 28 | p2: ClassificationProcessor = load_data_object(info_dict) 29 | assert (transformed_idx == p2.transform(y_set[20:40])).all() 30 | assert 
y_set[20:40] == p2.inverse_transform(transformed_idx) 31 | 32 | def test_multi_label_processor(self): 33 | from kashgari.corpus import JigsawToxicCommentCorpus 34 | file_path = TestMacros.jigsaw_mini_corpus_path 35 | corpus = JigsawToxicCommentCorpus(file_path) 36 | x_set, y_set = corpus.load_data() 37 | 38 | corpus_gen = CorpusGenerator(x_set, y_set) 39 | 40 | processor = ClassificationProcessor(multi_label=True) 41 | processor.build_vocab_generator([corpus_gen]) 42 | transformed_idx = processor.transform(y_set[20:40]) 43 | 44 | info_dict = processor.to_dict() 45 | 46 | p2: ClassificationProcessor = load_data_object(info_dict) 47 | assert (transformed_idx == p2.transform(y_set[20:40])).all() 48 | 49 | x1s = y_set[20:40] 50 | x2s = p2.inverse_transform(transformed_idx) 51 | for sample_x1, sample_x2 in zip(x1s, x2s): 52 | assert sorted(sample_x1) == sorted(sample_x2) 53 | 54 | 55 | if __name__ == "__main__": 56 | pass 57 | -------------------------------------------------------------------------------- /tests/test_processor/test_sequence_processor.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: test_sequence_processor.py 8 | # time: 12:09 下午 9 | 10 | import random 11 | import unittest 12 | from tests.test_macros import TestMacros 13 | 14 | from kashgari.utils import load_data_object 15 | from kashgari.processors import SequenceProcessor 16 | 17 | 18 | class TestSequenceProcessor(unittest.TestCase): 19 | def test_text_processor(self): 20 | x_set, y_set = TestMacros.load_labeling_corpus() 21 | x_samples = random.sample(x_set, 5) 22 | text_processor = SequenceProcessor(min_count=1) 23 | text_processor.build_vocab(x_set, y_set) 24 | text_idx = text_processor.transform(x_samples) 25 | 26 | text_info_dict = text_processor.to_dict() 27 | text_processor2: SequenceProcessor = load_data_object(text_info_dict) 28 | 29 | text_idx2 = text_processor2.transform(x_samples) 30 | sample_lengths = [len(i) for i in x_samples] 31 | 32 | assert (text_idx2 == text_idx).all() 33 | assert text_processor.inverse_transform(text_idx, lengths=sample_lengths) == x_samples 34 | assert text_processor2.inverse_transform(text_idx2, lengths=sample_lengths) == x_samples 35 | 36 | def test_label_processor(self): 37 | x_set, y_set = TestMacros.load_labeling_corpus() 38 | text_processor = SequenceProcessor(build_vocab_from_labels=True, min_count=1) 39 | text_processor.build_vocab(x_set, y_set) 40 | 41 | samples = random.sample(y_set, 20) 42 | 43 | text_idx = text_processor.transform(samples) 44 | 45 | text_info_dict = text_processor.to_dict() 46 | 47 | text_processor2: SequenceProcessor = load_data_object(text_info_dict) 48 | 49 | text_idx2 = text_processor2.transform(samples) 50 | lengths = [len(i) for i in samples] 51 | assert (text_idx2 == text_idx).all() 52 | assert text_processor2.inverse_transform(text_idx, lengths=lengths) == samples 53 | assert text_processor2.inverse_transform(text_idx2, lengths=lengths) == samples 54 | 55 | text_idx3 = text_processor.transform(samples, seq_length=20) 56 | assert [len(i) for i in text_idx3] == [20] * len(text_idx3) 57 | 58 | 59 | if __name__ == "__main__": 60 | pass 61 | -------------------------------------------------------------------------------- /tests/test_seq2seq/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # 
contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: __init__.py 8 | # time: 4:46 下午 9 | 10 | 11 | if __name__ == "__main__": 12 | pass 13 | -------------------------------------------------------------------------------- /tests/test_seq2seq/test_seq2seq.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: test_seq2seq.py 8 | # time: 4:46 下午 9 | 10 | import os 11 | import time 12 | import unittest 13 | import tempfile 14 | from kashgari.tasks.seq2seq import Seq2Seq 15 | from kashgari.corpus import ChineseDailyNerCorpus 16 | 17 | 18 | class TestSeq2Seq(unittest.TestCase): 19 | def test_base_use_case(self): 20 | x, y = ChineseDailyNerCorpus.load_data('test') 21 | x = x[:200] 22 | y = y[:200] 23 | seq2seq = Seq2Seq(hidden_size=64, 24 | encoder_seq_length=64, 25 | decoder_seq_length=64) 26 | seq2seq.fit(x, y, epochs=1) 27 | res, att = seq2seq.predict(x) 28 | 29 | model_path = os.path.join(tempfile.gettempdir(), str(time.time())) 30 | seq2seq.save(model_path) 31 | 32 | s2 = Seq2Seq.load_model(model_path) 33 | res2, att2 = s2.predict(x) 34 | 35 | assert res2 == res 36 | assert (att2 == att).all() 37 | 38 | 39 | if __name__ == '__main__': 40 | unittest.main() 41 | -------------------------------------------------------------------------------- /tests/test_tokenizers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Author : BrikerMan 4 | # Site : https://eliyar.biz 5 | 6 | # Time : 2020/9/4 6:46 上午 7 | # File : test_tokenizers.py 8 | # Project : Kashgari 9 | 10 | import unittest 11 | import numpy as np 12 | import os 13 | from kashgari.tokenizers import Tokenizer, JiebaTokenizer, BertTokenizer 14 | from tests.test_macros import TestMacros 15 | 16 | 17 | class TestUtils(unittest.TestCase): 18 | 19 | def test_jieba_tokenizer(self): 20 | os.system("pip3 uninstall -y jieba") 21 | 22 | with self.assertRaises(ModuleNotFoundError): 23 | _ = JiebaTokenizer() 24 | 25 | os.system("pip3 install jieba") 26 | t = JiebaTokenizer() 27 | assert ['你好', '世界', '!', ' ', 'Hello', ' ', 'World'] == t.tokenize('你好世界! Hello World') 28 | 29 | def test_base_tokenizer(self): 30 | t = Tokenizer() 31 | assert ['Hello', 'World'] == t.tokenize('Hello World') 32 | 33 | def test_bert_tokenizer(self): 34 | bert_path = TestMacros.bert_path 35 | vocab_path = os.path.join(bert_path, 'vocab.txt') 36 | tokenizer = BertTokenizer.load_from_vocab_file(vocab_path) 37 | 38 | assert ['你', '好', '世', '界', '!', 39 | 'h', '##e', '##l', '##l', '##o', 40 | 'w', '##o', '##r', '##l', '##d'] == tokenizer.tokenize("你好世界! Hello World") 41 | assert ['jack', 'makes', 'c', '##a', '##k', '##e'] == tokenizer.tokenize("Jack makes cake") 42 | assert ['你', '好', '呀'] == tokenizer.tokenize("你好呀") 43 | 44 | tokenizer = BertTokenizer() 45 | assert ['你', '好', '世', '界', '!', 'hello', 'world'] == tokenizer.tokenize("你好世界! 
Hello World") 46 | assert ['jack', 'makes', 'cake'] == tokenizer.tokenize("Jack makes cake") 47 | assert ['你', '好', '呀'] == tokenizer.tokenize("你好呀") 48 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | # author: BrikerMan 4 | # contact: eliyar917@gmail.com 5 | # blog: https://eliyar.biz 6 | 7 | # file: test_utils.py 8 | # time: 10:48 上午 9 | 10 | import unittest 11 | import numpy as np 12 | from kashgari.utils import unison_shuffled_copies 13 | from kashgari.utils import get_list_subset 14 | 15 | 16 | class TestUtils(unittest.TestCase): 17 | 18 | def test_unison_shuffled_copies(self): 19 | x: np.ndarray = np.random.randint(0, 10, size=(100, 5)) 20 | y: np.ndarray = np.random.randint(0, 10, size=(100, )) 21 | 22 | new_x, new_y = unison_shuffled_copies(x, y) 23 | assert new_x.shape == x.shape 24 | assert new_y.shape == y.shape 25 | 26 | def test_get_list_subset(self): 27 | x = list(range(0, 100)) 28 | subset = get_list_subset(x, list(range(10, 20))) 29 | assert subset == [10, 11, 12, 13, 14, 15, 16, 17, 18, 19] 30 | 31 | 32 | if __name__ == "__main__": 33 | pass 34 | --------------------------------------------------------------------------------