├── Dockerfile ├── requirements.txt ├── tests ├── __init__.py ├── tasks │ ├── __init__.py │ └── ner │ │ ├── __init__.py │ │ └── test_bert.py ├── utils.py └── dummy.py ├── fastie ├── tasks │ ├── ee │ │ └── __init__.py │ ├── re │ │ └── __init__.py │ ├── ner │ │ ├── bert │ │ │ ├── __init__.py │ │ │ └── bert.py │ │ ├── __init__.py │ │ └── BaseNERTask.py │ ├── __init__.py │ ├── sequential_task.py │ └── build_task.py ├── dataset │ ├── legacy │ │ ├── __init__.py │ │ ├── wikiann.py │ │ └── conll2003.py │ ├── io │ │ ├── __init__.py │ │ ├── sentence.py │ │ ├── jsonlinesNER.py │ │ └── columnNER.py │ ├── __init__.py │ ├── build_dataset.py │ └── base_dataset.py ├── utils │ ├── __init__.py │ ├── hub.py │ ├── path.py │ ├── utils.py │ ├── misc.py │ └── registry.py ├── controller │ ├── __init__.py │ ├── interactor.py │ ├── evaluator.py │ ├── trainer.py │ ├── base_controller.py │ └── inference.py ├── chain.py ├── __init__.py ├── envs.py ├── exhibition.py ├── command.py └── node.py ├── docs ├── source │ ├── docutils.conf │ ├── _static │ │ ├── image │ │ │ ├── fastie-logo.png │ │ │ └── fastie-logo2.png │ │ └── css │ │ │ ├── readthedocs.css │ │ │ └── badge_only.css │ ├── tutorials │ │ └── basic │ │ │ ├── figures │ │ │ └── T3-task-life.jpg │ │ │ ├── index.rst │ │ │ ├── fastie_tutorial_1.ipynb │ │ │ ├── fastie_tutorial_2.ipynb │ │ │ └── fastie_tutorial_0.ipynb │ ├── _templates │ │ └── classtemplate.rst │ ├── api │ │ ├── generated │ │ │ ├── fastie.dataset.Wikiann.rst │ │ │ ├── fastie.node.BaseNode.rst │ │ │ ├── fastie.tasks.BaseTask.rst │ │ │ ├── fastie.tasks.BertNER.rst │ │ │ ├── fastie.dataset.Sentence.rst │ │ │ ├── fastie.controller.Trainer.rst │ │ │ ├── fastie.dataset.ColumnNER.rst │ │ │ ├── fastie.dataset.Conll2003.rst │ │ │ ├── fastie.controller.Evaluator.rst │ │ │ ├── fastie.controller.Inference.rst │ │ │ ├── fastie.controller.Interactor.rst │ │ │ ├── fastie.dataset.BaseDataset.rst │ │ │ ├── fastie.dataset.JsonLinesNER.rst │ │ │ ├── fastie.tasks.BertNERConfig.rst │ │ │ ├── fastie.dataset.WikiannConfig.rst │ │ │ ├── fastie.dataset.build_dataset.rst │ │ │ ├── fastie.node.BaseNodeConfig.rst │ │ │ ├── fastie.tasks.BaseTaskConfig.rst │ │ │ ├── fastie.dataset.SentenceConfig.rst │ │ │ ├── fastie.dataset.ColumnNERConfig.rst │ │ │ ├── fastie.dataset.Conll2003Config.rst │ │ │ ├── fastie.dataset.BaseDatasetConfig.rst │ │ │ └── fastie.dataset.JsonLinesNERConfig.rst │ │ ├── node.rst │ │ ├── controller.rst │ │ ├── tasks.rst │ │ └── dataset.rst │ ├── index.rst │ └── conf.py ├── requirements.txt └── Makefile ├── README.MD ├── configs └── ner │ └── bert │ └── bert-conll2003.py ├── setup.py ├── .pre-commit-config.yaml └── .gitignore /Dockerfile: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /fastie/tasks/ee/__init__.py: -------------------------------------------------------------------------------- 1 | 
--------------------------------------------------------------------------------
/fastie/tasks/re/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/tasks/ner/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/docs/source/docutils.conf:
--------------------------------------------------------------------------------
1 | [html writers]
2 | table_style: colwidths-auto
--------------------------------------------------------------------------------
/README.MD:
--------------------------------------------------------------------------------
1 | # FastIE
2 | 
3 | A general framework for information extraction.
--------------------------------------------------------------------------------
/fastie/tasks/ner/bert/__init__.py:
--------------------------------------------------------------------------------
1 | from .bert import BertNER, BertNERConfig
2 | 
3 | __all__ = ['BertNER', 'BertNERConfig']
4 | 
--------------------------------------------------------------------------------
/docs/source/_static/image/fastie-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-nlplab/fastIE/HEAD/docs/source/_static/image/fastie-logo.png
--------------------------------------------------------------------------------
/docs/source/_static/image/fastie-logo2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-nlplab/fastIE/HEAD/docs/source/_static/image/fastie-logo2.png
--------------------------------------------------------------------------------
/configs/ner/bert/bert-conll2003.py:
--------------------------------------------------------------------------------
1 | _help = 'Sequence labeling of the conll2003 dataset with BERT'
2 | config = dict(
3 |     task='ner/bert',
4 |     dataset='conll2003',
5 | )
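6 | 
7 | # A hedged sketch of how such a config dict is consumed: ``build_task`` in
8 | # fastie/tasks/build_task.py reads the ``task`` key ('ner/bert') and
9 | # instantiates the registered solution. Passing the whole dict directly is
10 | # an assumption based on that module, not a documented entry point:
11 | #
12 | #   from fastie.tasks import build_task
13 | #   task = build_task(config)  # -> a BertNER instance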
--------------------------------------------------------------------------------
/docs/source/tutorials/basic/figures/T3-task-life.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-nlplab/fastIE/HEAD/docs/source/tutorials/basic/figures/T3-task-life.jpg
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | nbsphinx
2 | -e git+https://github.com/x54-729/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
3 | sphinx
4 | sphinx-multiversion
5 | sphinx_autodoc_typehints
6 | sphinx_copybutton
--------------------------------------------------------------------------------
/fastie/dataset/legacy/__init__.py:
--------------------------------------------------------------------------------
1 | from .conll2003 import Conll2003, Conll2003Config
2 | from .wikiann import Wikiann, WikiannConfig
3 | 
4 | __all__ = ['Conll2003', 'Conll2003Config', 'Wikiann', 'WikiannConfig']
--------------------------------------------------------------------------------
/fastie/tasks/ner/__init__.py:
--------------------------------------------------------------------------------
1 | from .BaseNERTask import BaseNERTask, BaseNERTaskConfig
2 | from .bert import BertNER, BertNERConfig
3 | 
4 | __all__ = ['BertNER', 'BertNERConfig', 'BaseNERTaskConfig', 'BaseNERTask']
5 | 
--------------------------------------------------------------------------------
/docs/source/tutorials/basic/index.rst:
--------------------------------------------------------------------------------
1 | FastIE Basic Tutorials
2 | ======================
3 | 
4 | The tutorial series below introduces the basic usage of FastIE; reading it will give you a working understanding of the framework.
5 | 
6 | .. toctree::
7 |    :maxdepth: 1
8 |    :glob:
9 | 
10 |    ./*
--------------------------------------------------------------------------------
/fastie/dataset/io/__init__.py:
--------------------------------------------------------------------------------
1 | from .columnNER import ColumnNER, ColumnNERConfig
2 | from .jsonlinesNER import JsonLinesNER, JsonLinesNERConfig
3 | from .sentence import Sentence, SentenceConfig
4 | 
5 | __all__ = [
6 |     'ColumnNER', 'ColumnNERConfig', 'Sentence', 'SentenceConfig',
7 |     'JsonLinesNERConfig', 'JsonLinesNER'
8 | ]
--------------------------------------------------------------------------------
/docs/source/_templates/classtemplate.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 |     :class: hidden-section
3 | .. currentmodule:: {{ module }}
4 | 
5 | 
6 | {{ name | underline}}
7 | 
8 | .. autoclass:: {{ name }}
9 |     :members:
10 | 
11 | 
12 | ..
13 |     autogenerated from source/_templates/classtemplate.rst
14 |     note it does not have :inherited-members:
15 | 
--------------------------------------------------------------------------------
/docs/source/api/generated/fastie.dataset.Wikiann.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 |     :class: hidden-section
3 | .. currentmodule:: fastie.dataset
4 | 
5 | 
6 | Wikiann
7 | =======
8 | 
9 | .. autoclass:: Wikiann
10 |     :members:
11 | 
12 | 
13 | ..
14 |     autogenerated from source/_templates/classtemplate.rst
15 |     note it does not have :inherited-members:
16 | 
--------------------------------------------------------------------------------
/docs/source/api/generated/fastie.node.BaseNode.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 |     :class: hidden-section
3 | .. currentmodule:: fastie.node
4 | 
5 | 
6 | BaseNode
7 | ========
8 | 
9 | .. autoclass:: BaseNode
10 |     :members:
11 | 
12 | 
13 | ..
14 |     autogenerated from source/_templates/classtemplate.rst
15 |     note it does not have :inherited-members:
16 | 
--------------------------------------------------------------------------------
/docs/source/api/generated/fastie.tasks.BaseTask.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 |     :class: hidden-section
3 | .. currentmodule:: fastie.tasks
4 | 
5 | 
6 | BaseTask
7 | ========
8 | 
9 | .. autoclass:: BaseTask
10 |     :members:
11 | 
12 | 
13 | ..
14 |     autogenerated from source/_templates/classtemplate.rst
15 |     note it does not have :inherited-members:
16 | 
--------------------------------------------------------------------------------
/docs/source/api/generated/fastie.tasks.BertNER.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 |     :class: hidden-section
3 | .. currentmodule:: fastie.tasks
4 | 
5 | 
6 | BertNER
7 | =======
8 | 
9 | .. autoclass:: BertNER
10 |     :members:
11 | 
12 | 
13 | ..
14 | autogenerated from source/_templates/classtemplate.rst 15 | note it does not have :inherited-members: 16 | -------------------------------------------------------------------------------- /docs/source/api/generated/fastie.dataset.Sentence.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: fastie.dataset 4 | 5 | 6 | Sentence 7 | ======== 8 | 9 | .. autoclass:: Sentence 10 | :members: 11 | 12 | 13 | .. 14 | autogenerated from source/_templates/classtemplate.rst 15 | note it does not have :inherited-members: 16 | -------------------------------------------------------------------------------- /docs/source/api/generated/fastie.controller.Trainer.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: fastie.controller 4 | 5 | 6 | Trainer 7 | ======= 8 | 9 | .. autoclass:: Trainer 10 | :members: 11 | 12 | 13 | .. 14 | autogenerated from source/_templates/classtemplate.rst 15 | note it does not have :inherited-members: 16 | -------------------------------------------------------------------------------- /docs/source/api/generated/fastie.dataset.ColumnNER.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: fastie.dataset 4 | 5 | 6 | ColumnNER 7 | ========= 8 | 9 | .. autoclass:: ColumnNER 10 | :members: 11 | 12 | 13 | .. 14 | autogenerated from source/_templates/classtemplate.rst 15 | note it does not have :inherited-members: 16 | -------------------------------------------------------------------------------- /docs/source/api/generated/fastie.dataset.Conll2003.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: fastie.dataset 4 | 5 | 6 | Conll2003 7 | ========= 8 | 9 | .. autoclass:: Conll2003 10 | :members: 11 | 12 | 13 | .. 14 | autogenerated from source/_templates/classtemplate.rst 15 | note it does not have :inherited-members: 16 | -------------------------------------------------------------------------------- /docs/source/api/generated/fastie.controller.Evaluator.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: fastie.controller 4 | 5 | 6 | Evaluator 7 | ========= 8 | 9 | .. autoclass:: Evaluator 10 | :members: 11 | 12 | 13 | .. 14 | autogenerated from source/_templates/classtemplate.rst 15 | note it does not have :inherited-members: 16 | -------------------------------------------------------------------------------- /docs/source/api/generated/fastie.controller.Inference.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: fastie.controller 4 | 5 | 6 | Inference 7 | ========= 8 | 9 | .. autoclass:: Inference 10 | :members: 11 | 12 | 13 | .. 14 | autogenerated from source/_templates/classtemplate.rst 15 | note it does not have :inherited-members: 16 | -------------------------------------------------------------------------------- /docs/source/api/generated/fastie.controller.Interactor.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. 
currentmodule:: fastie.controller 4 | 5 | 6 | Interactor 7 | ========== 8 | 9 | .. autoclass:: Interactor 10 | :members: 11 | 12 | 13 | .. 14 | autogenerated from source/_templates/classtemplate.rst 15 | note it does not have :inherited-members: 16 | -------------------------------------------------------------------------------- /docs/source/api/generated/fastie.dataset.BaseDataset.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: fastie.dataset 4 | 5 | 6 | BaseDataset 7 | =========== 8 | 9 | .. autoclass:: BaseDataset 10 | :members: 11 | 12 | 13 | .. 14 | autogenerated from source/_templates/classtemplate.rst 15 | note it does not have :inherited-members: 16 | -------------------------------------------------------------------------------- /docs/source/api/generated/fastie.dataset.JsonLinesNER.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: fastie.dataset 4 | 5 | 6 | JsonLinesNER 7 | ============ 8 | 9 | .. autoclass:: JsonLinesNER 10 | :members: 11 | 12 | 13 | .. 14 | autogenerated from source/_templates/classtemplate.rst 15 | note it does not have :inherited-members: 16 | -------------------------------------------------------------------------------- /docs/source/api/generated/fastie.tasks.BertNERConfig.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: fastie.tasks 4 | 5 | 6 | BertNERConfig 7 | ============= 8 | 9 | .. autoclass:: BertNERConfig 10 | :members: 11 | 12 | 13 | .. 14 | autogenerated from source/_templates/classtemplate.rst 15 | note it does not have :inherited-members: 16 | -------------------------------------------------------------------------------- /docs/source/api/generated/fastie.dataset.WikiannConfig.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: fastie.dataset 4 | 5 | 6 | WikiannConfig 7 | ============= 8 | 9 | .. autoclass:: WikiannConfig 10 | :members: 11 | 12 | 13 | .. 14 | autogenerated from source/_templates/classtemplate.rst 15 | note it does not have :inherited-members: 16 | -------------------------------------------------------------------------------- /docs/source/api/generated/fastie.dataset.build_dataset.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: fastie.dataset 4 | 5 | 6 | build_dataset 7 | ============= 8 | 9 | .. autoclass:: build_dataset 10 | :members: 11 | 12 | 13 | .. 14 | autogenerated from source/_templates/classtemplate.rst 15 | note it does not have :inherited-members: 16 | -------------------------------------------------------------------------------- /docs/source/api/generated/fastie.node.BaseNodeConfig.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: fastie.node 4 | 5 | 6 | BaseNodeConfig 7 | ============== 8 | 9 | .. autoclass:: BaseNodeConfig 10 | :members: 11 | 12 | 13 | .. 
14 | autogenerated from source/_templates/classtemplate.rst 15 | note it does not have :inherited-members: 16 | -------------------------------------------------------------------------------- /docs/source/api/generated/fastie.tasks.BaseTaskConfig.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: fastie.tasks 4 | 5 | 6 | BaseTaskConfig 7 | ============== 8 | 9 | .. autoclass:: BaseTaskConfig 10 | :members: 11 | 12 | 13 | .. 14 | autogenerated from source/_templates/classtemplate.rst 15 | note it does not have :inherited-members: 16 | -------------------------------------------------------------------------------- /docs/source/api/generated/fastie.dataset.SentenceConfig.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: fastie.dataset 4 | 5 | 6 | SentenceConfig 7 | ============== 8 | 9 | .. autoclass:: SentenceConfig 10 | :members: 11 | 12 | 13 | .. 14 | autogenerated from source/_templates/classtemplate.rst 15 | note it does not have :inherited-members: 16 | -------------------------------------------------------------------------------- /docs/source/api/generated/fastie.dataset.ColumnNERConfig.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: fastie.dataset 4 | 5 | 6 | ColumnNERConfig 7 | =============== 8 | 9 | .. autoclass:: ColumnNERConfig 10 | :members: 11 | 12 | 13 | .. 14 | autogenerated from source/_templates/classtemplate.rst 15 | note it does not have :inherited-members: 16 | -------------------------------------------------------------------------------- /docs/source/api/generated/fastie.dataset.Conll2003Config.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: fastie.dataset 4 | 5 | 6 | Conll2003Config 7 | =============== 8 | 9 | .. autoclass:: Conll2003Config 10 | :members: 11 | 12 | 13 | .. 14 | autogenerated from source/_templates/classtemplate.rst 15 | note it does not have :inherited-members: 16 | -------------------------------------------------------------------------------- /docs/source/api/generated/fastie.dataset.BaseDatasetConfig.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: fastie.dataset 4 | 5 | 6 | BaseDatasetConfig 7 | ================= 8 | 9 | .. autoclass:: BaseDatasetConfig 10 | :members: 11 | 12 | 13 | .. 
14 | autogenerated from source/_templates/classtemplate.rst 15 | note it does not have :inherited-members: 16 | -------------------------------------------------------------------------------- /fastie/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import Config 2 | from .hub import Hub 3 | from .registry import Registry 4 | from .utils import generate_tag_vocab, check_loaded_tag_vocab, parse_config, \ 5 | inspect_function_calling 6 | 7 | __all__ = [ 8 | 'Registry', 'Config', 'Hub', 'generate_tag_vocab', 9 | 'check_loaded_tag_vocab', 'parse_config', 'inspect_function_calling' 10 | ] 11 | -------------------------------------------------------------------------------- /docs/source/api/generated/fastie.dataset.JsonLinesNERConfig.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: fastie.dataset 4 | 5 | 6 | JsonLinesNERConfig 7 | ================== 8 | 9 | .. autoclass:: JsonLinesNERConfig 10 | :members: 11 | 12 | 13 | .. 14 | autogenerated from source/_templates/classtemplate.rst 15 | note it does not have :inherited-members: 16 | -------------------------------------------------------------------------------- /docs/source/_static/css/readthedocs.css: -------------------------------------------------------------------------------- 1 | table.colwidths-auto td { 2 | width: 50% 3 | } 4 | 5 | .header-logo { 6 | background-image: url("../image/fastie-logo2.png"); 7 | background-size: 130px 40px; 8 | height: 40px; 9 | width: 130px; 10 | } 11 | 12 | .two-column-table-wrapper { 13 | width: 50%; 14 | max-width: 300px; 15 | overflow-x: auto; 16 | } 17 | 18 | .two-column-table-wrapper .highlight { 19 | width: 1500px 20 | } 21 | -------------------------------------------------------------------------------- /docs/source/api/node.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | fastie.node 5 | =================================== 6 | 7 | .. contents:: fastie.node 8 | :local: 9 | :depth: 2 10 | :backlinks: top 11 | 12 | .. currentmodule:: fastie.node 13 | 14 | Node 15 | ---------------- 16 | 17 | .. 
autosummary:: 18 | :toctree: generated 19 | :nosignatures: 20 | :template: classtemplate.rst 21 | 22 | BaseNode 23 | BaseNodeConfig 24 | -------------------------------------------------------------------------------- /fastie/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_task import BaseTask, BaseTaskConfig, NER, RE, EE 2 | from .build_task import build_task 3 | from .ner import BertNER, BertNERConfig, BaseNERTask, BaseNERTaskConfig 4 | from .sequential_task import SequentialTask, SequentialTaskConfig 5 | 6 | __all__ = [ 7 | 'BaseTask', 'BaseTaskConfig', 'NER', 'RE', 'EE', 'build_task', 'BertNER', 8 | 'BertNERConfig', 'BaseNERTask', 'BaseNERTaskConfig', 'SequentialTask', 9 | 'SequentialTaskConfig' 10 | ] 11 | -------------------------------------------------------------------------------- /fastie/controller/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_controller import BaseController, CONTROLLER 2 | from .evaluator import Evaluator, EvaluatorConfig 3 | from .inference import Inference, InferenceConfig 4 | from .interactor import Interactor, InteractorConfig 5 | from .trainer import Trainer, TrainerConfig 6 | 7 | __all__ = [ 8 | 'BaseController', 'CONTROLLER', 'Trainer', 'TrainerConfig', 'Inference', 9 | 'InferenceConfig', 'Evaluator', 'EvaluatorConfig', 'Interactor', 10 | 'InteractorConfig' 11 | ] 12 | -------------------------------------------------------------------------------- /docs/source/api/controller.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | fastie.controller 5 | =================================== 6 | 7 | .. contents:: fastie.controller 8 | :local: 9 | :depth: 2 10 | :backlinks: top 11 | 12 | .. currentmodule:: fastie.controller 13 | 14 | Controller 15 | ---------------- 16 | 17 | .. 
autosummary::
18 |    :toctree: generated
19 |    :nosignatures:
20 |    :template: classtemplate.rst
21 | 
22 |    Trainer
23 |    Evaluator
24 |    Inference
25 |    Interactor
--------------------------------------------------------------------------------
/fastie/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_dataset import BaseDataset, DATASET, BaseDatasetConfig
2 | from .build_dataset import build_dataset
3 | from .io import ColumnNER, ColumnNERConfig, Sentence, SentenceConfig, \
4 |     JsonLinesNER, JsonLinesNERConfig
5 | from .legacy import Conll2003, Conll2003Config, WikiannConfig, Wikiann
6 | 
7 | __all__ = [
8 |     'BaseDataset', 'DATASET', 'BaseDatasetConfig', 'Conll2003',
9 |     'Conll2003Config', 'ColumnNER', 'ColumnNERConfig', 'Sentence',
10 |     'SentenceConfig', 'build_dataset', 'JsonLinesNER', 'JsonLinesNERConfig',
11 |     'Wikiann', 'WikiannConfig'
12 | ]
--------------------------------------------------------------------------------
/tests/tasks/ner/test_bert.py:
--------------------------------------------------------------------------------
1 | from fastie.tasks.ner import BertNER
2 | from tests.dummy import dummy_ner_dataset
3 | from tests.utils import UnifiedTaskTest
4 | 
5 | 
6 | class TestBertNER(UnifiedTaskTest):
7 | 
8 |     def setup_class(self):
9 |         super().setup_class(self,
10 |                             task_cls=BertNER,
11 |                             data_bundle=dummy_ner_dataset(),
12 |                             extra_parameters={
13 |                                 'pretrained_model_name_or_path':
14 |                                 'prajjwal1/bert-tiny'
15 |                             })
--------------------------------------------------------------------------------
/docs/source/api/tasks.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 |     :class: hidden-section
3 | 
4 | fastie.tasks
5 | ===================================
6 | 
7 | .. contents:: fastie.tasks
8 |    :local:
9 |    :depth: 2
10 |    :backlinks: top
11 | 
12 | .. currentmodule:: fastie.tasks
13 | 
14 | BaseTask
15 | ----------------
16 | 
17 | .. autosummary::
18 |    :toctree: generated
19 |    :nosignatures:
20 |    :template: classtemplate.rst
21 | 
22 |    BaseTask
23 |    BaseTaskConfig
24 | 
25 | BertNER
26 | ----------------
27 | 
28 | .. autosummary::
29 |    :toctree: generated
30 |    :nosignatures:
31 |    :template: classtemplate.rst
32 | 
33 |    BertNER
34 |    BertNERConfig
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | FastIE Documentation
2 | =====================
3 | 
4 | `FastIE <https://github.com/open-nlplab/fastIE>`_ is a general-purpose, integrated framework for information extraction. You can quickly train, evaluate, and run inference on your own data with the many models shipped in FastIE, and just as quickly create your own model by modifying an existing one.
5 | 
6 | FastIE has the following features:
7 | 
8 | - Simplified training, evaluation, and inference: SOTA models work out of the box;
9 | 
10 | 
11 | .. toctree::
12 |    :maxdepth: 3
13 |    :caption: Quick Start
14 | 
15 |    tutorials/basic/index
16 | 
17 | 
18 | .. toctree::
19 |    :maxdepth: 2
20 |    :caption: API Reference
21 | 
22 |    api/dataset
23 |    api/tasks
24 |    api/controller
25 |    api/node
26 | 
27 | 
28 | 
29 | Indices and Search
30 | ==================
31 | 
32 | * :ref:`genindex`
33 | * :ref:`modindex`
34 | * :ref:`search`
--------------------------------------------------------------------------------
/fastie/utils/hub.py:
--------------------------------------------------------------------------------
1 | import io
2 | import os
3 | import pickle
4 | 
5 | 
6 | class Unpickler(pickle.Unpickler):
7 | 
8 |     def find_class(self, module, name):
9 |         # Remap torch storages to CPU so checkpoints load without a GPU.
10 |         if module == 'torch.storage' and name == '_load_from_bytes':
11 |             import torch
12 |             return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
13 |         else:
14 |             return super().find_class(module, name)
15 | 
16 | 
17 | class Hub:
18 | 
19 |     @classmethod
20 |     def load(cls, path: str):
21 |         if os.path.exists(path) and os.path.isfile(path):
22 |             with open(path, mode='br') as file:
23 |                 return Unpickler(file).load()
24 |         else:
25 |             # download from s3 (not yet implemented)
26 |             pass
27 | 
28 |     @classmethod
29 |     def save(cls, path: str, state_dict: dict):
30 |         with open(path, mode='wb') as file:
31 |             pickle.dump(state_dict, file)
--------------------------------------------------------------------------------
/fastie/chain.py:
--------------------------------------------------------------------------------
1 | from defaultlist import defaultlist
2 | 
3 | from fastie.controller import BaseController
4 | from fastie.dataset import BaseDataset
5 | from fastie.tasks import BaseTask
6 | 
7 | 
8 | class Chain(defaultlist):
9 |     """A three-slot pipeline: ``[dataset, task, controller]``."""
10 | 
11 |     def __init__(self, *args):
12 |         defaultlist.__init__(self, *args)
13 | 
14 |     def run(self):
15 |         result = None
16 |         for node in self:
17 |             if node is None:
18 |                 continue
19 |             result = node(result)
20 |         return result
21 | 
22 |     def __add__(self, other):
23 |         # ``+`` places each node into its fixed slot according to its type.
24 |         if isinstance(other, BaseTask):
25 |             self[1] = other
26 |         elif isinstance(other, BaseDataset):
27 |             self[0] = other
28 |         elif isinstance(other, BaseController):
29 |             self[2] = other
30 |         _ = other.parser
31 |         return self
32 | 
33 |     def __call__(self, *args, **kwargs):
34 |         return self.run()
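35 | 
36 | # A minimal usage sketch, assuming the legacy dataset and BERT task
37 | # registered in this repo: ``+`` slots each node into its fixed position,
38 | # and ``run`` pipes each node's output into the next.
39 | #
40 | #   from fastie.dataset import Conll2003
41 | #   from fastie.tasks import BertNER
42 | #   from fastie.controller import Trainer
43 | #
44 | #   chain = Chain() + Conll2003() + BertNER() + Trainer()
45 | #   chain.run()  # dataset -> task -> controller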
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 | 
4 | # You can set these variables from the command line.
5 | SPHINXOPTS    =
6 | SPHINXAPIDOC  = sphinx-apidoc
7 | SPHINXBUILD   = sphinx-build
8 | SPHINXPROJ    = fastie
9 | # SPHINXEXCLUDE = ../FastIE/transformers/*
10 | SOURCEDIR     = source
11 | BUILDDIR      = temp/
12 | PORT          = 8000
13 | 
14 | # Put it first so that "make" without argument is like "make help".
15 | help:
16 | 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS)
17 | 
18 | server:
19 | 	@cd $(BUILDDIR)/html && python -m http.server $(PORT)
20 | 
21 | versions:
22 | 	sphinx-multiversion "$(SOURCEDIR)" "$(BUILDDIR)"
23 | 
24 | server-versions:
25 | 	cd $(BUILDDIR) && python -m http.server $(PORT)
26 | 
27 | .PHONY: help Makefile
28 | 
29 | # Catch-all target: route all unknown targets to Sphinx using the new
30 | # "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
31 | %: Makefile
32 | 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/fastie/__init__.py:
--------------------------------------------------------------------------------
1 | from .chain import Chain
2 | from .controller import CONTROLLER, Trainer, TrainerConfig, Evaluator, \
3 |     EvaluatorConfig, Inference, InferenceConfig, Interactor, InteractorConfig
4 | from .dataset import DATASET, BaseDataset, BaseDatasetConfig
5 | from .envs import get_flag, set_flag, parser, get_task, logger
6 | from .node import BaseNode, BaseNodeConfig
7 | from .tasks import NER, EE, RE, BaseTask, BaseTaskConfig
8 | from .utils import Registry, Config, Hub, parse_config, generate_tag_vocab, \
9 |     check_loaded_tag_vocab
10 | 
11 | __all__ = [
12 |     'BaseNode', 'Chain', 'get_flag', 'set_flag', 'parser', 'BaseNodeConfig',
13 |     'parse_config', 'NER', 'EE', 'RE', 'DATASET', 'CONTROLLER', 'Trainer',
14 |     'TrainerConfig', 'Evaluator', 'EvaluatorConfig', 'Inference',
15 |     'InferenceConfig', 'Interactor', 'InteractorConfig', 'Registry', 'Config',
16 |     'Hub', 'logger', 'generate_tag_vocab', 'check_loaded_tag_vocab',
17 |     'BaseTask', 'BaseTaskConfig', 'BaseDataset', 'BaseDatasetConfig',
18 |     'get_task'
19 | ]
--------------------------------------------------------------------------------
/docs/source/api/dataset.rst:
--------------------------------------------------------------------------------
1 | .. role:: hidden
2 |     :class: hidden-section
3 | 
4 | fastie.dataset
5 | ===================================
6 | 
7 | .. contents:: fastie.dataset
8 |    :local:
9 |    :depth: 2
10 |    :backlinks: top
11 | 
12 | .. currentmodule:: fastie.dataset
13 | 
14 | BaseDataset
15 | ----------------
16 | 
17 | .. autosummary::
18 |    :toctree: generated
19 |    :nosignatures:
20 |    :template: classtemplate.rst
21 | 
22 |    BaseDataset
23 |    BaseDatasetConfig
24 | 
25 | build_dataset
26 | ----------------
27 | 
28 | .. autosummary::
29 |    :toctree: generated
30 |    :nosignatures:
31 |    :template: classtemplate.rst
32 | 
33 |    build_dataset
34 | 
35 | IO
36 | ----------------
37 | 
38 | .. autosummary::
39 |    :toctree: generated
40 |    :nosignatures:
41 |    :template: classtemplate.rst
42 | 
43 |    ColumnNER
44 |    ColumnNERConfig
45 |    JsonLinesNER
46 |    JsonLinesNERConfig
47 |    Sentence
48 |    SentenceConfig
49 | 
50 | Legacy
51 | ----------------
52 | 
53 | .. autosummary::
54 |    :toctree: generated
55 |    :nosignatures:
56 |    :template: classtemplate.rst
57 | 
58 |    Conll2003
59 |    Conll2003Config
60 |    Wikiann
61 |    WikiannConfig
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | from setuptools import setup, find_packages
4 | 
5 | with open('requirements.txt', encoding='utf-8') as f:
6 |     reqs = f.read()
7 | 
8 | 
9 | def find_configs():
10 |     data_files = []
11 |     for root, dirs, files in os.walk('configs'):
12 |         for file in files:
13 |             if file.endswith('.py'):
14 |                 data_files.append(
15 |                     (os.path.join(
16 |                         os.path.expanduser('~'),
17 |                         os.path.join('.fastie/configs/',
18 |                                      '/'.join(root.split('/')[-2:]))),
19 |                      ['./' + '/'.join(root.split('/')[-3:]) + '/' + file]))
20 |     return data_files
21 | 
22 | 
23 | setup(name='fastie',
24 |       version='0.0.1',
25 |       packages=find_packages(),
26 |       install_requires=reqs.strip().split('\n'),
27 |       entry_points={
28 |           'console_scripts': [
29 |               'fastie-train = fastie.command:main',
30 |               'fastie-eval = fastie.command:main',
31 |               'fastie-infer = fastie.command:main',
32 |               'fastie-interact = fastie.command:main',
33 |           ]
34 |       },
35 |       data_files=find_configs())
--------------------------------------------------------------------------------
/fastie/envs.py:
--------------------------------------------------------------------------------
1 | """FastIE global variables."""
2 | __all__ = [
3 |     'get_task', 'set_task', 'get_dataset', 'set_dataset', 'get_flag',
4 |     'set_flag', 'logger'
5 | ]
6 | 
7 | import os
8 | from argparse import ArgumentParser
9 | from fastNLP import logger as fastnlp_logger
10 | 
11 | logger = fastnlp_logger
12 | 
13 | parser: ArgumentParser = ArgumentParser(prog='fastie-train',
14 |                                         conflict_handler='resolve')
15 | 
16 | FASTIE_HOME = f"{os.environ['HOME']}/.fastie"
17 | 
18 | PARSER_FLAG = 'dataclass'  # "comment"
19 | CONFIG_FLAG = 'dict'  # class
20 | 
21 | task = None
22 | 
23 | 
24 | def get_task():
25 |     return task
26 | 
27 | 
28 | def set_task(_task):
29 |     global task
30 |     task = _task
31 | 
32 | 
33 | dataset = None
34 | 
35 | 
36 | def get_dataset():
37 |     return dataset
38 | 
39 | 
40 | def set_dataset(_dataset):
41 |     global dataset
42 |     dataset = _dataset
43 | 
44 | 
45 | FLAG = None
46 | 
47 | 
48 | def set_flag(_flag: str = 'train'):
49 |     """Set the global flag; unknown values fall back to ``train``."""
50 |     global FLAG
51 |     if _flag not in ['train', 'eval', 'infer']:
52 |         _flag = 'train'
53 |     FLAG = _flag
54 | 
55 | 
56 | def get_flag():
57 |     return FLAG
58 | 
59 | 
60 | sample_type = [int, bool, float, str, list, dict, set, tuple, type(None)]
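61 | 
62 | # A minimal usage sketch (mirrors what the controllers in
63 | # fastie/controller do when they switch modes; the values are
64 | # illustrative):
65 | #
66 | #   set_flag('eval')            # e.g. Evaluator.__init__ does this
67 | #   assert get_flag() == 'eval'
68 | #   set_flag('something-else')  # unknown flags fall back to 'train'
69 | #   assert get_flag() == 'train'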
--------------------------------------------------------------------------------
/fastie/tasks/sequential_task.py:
--------------------------------------------------------------------------------
1 | """Base classes for multi-stage tasks."""
2 | __all__ = ['SequentialTaskConfig', 'SequentialTask']
3 | from fastie.tasks.base_task import BaseTask
4 | from fastie.node import BaseNodeConfig, BaseNode
5 | from fastie.envs import get_flag
6 | 
7 | from fastNLP.io import DataBundle
8 | 
9 | from dataclasses import dataclass
10 | from typing import List
11 | from abc import ABCMeta, abstractmethod
12 | 
13 | 
14 | @dataclass
15 | class SequentialTaskConfig(BaseNodeConfig):
16 |     """Configuration class for multi-stage tasks."""
17 |     pass
18 | 
19 | 
20 | class SequentialTask(BaseNode, metaclass=ABCMeta):
21 | 
22 |     def __init__(self):
23 |         self._tasks: List[BaseTask] = []
24 | 
25 |     @abstractmethod
26 |     def on_train(self, data_bundle: DataBundle):
27 |         """
28 |         Training logic of the multi-stage task.
29 |         :return:
30 |         """
31 | 
32 |     @abstractmethod
33 |     def on_eval(self, data_bundle: DataBundle):
34 |         """
35 |         Evaluation logic of the multi-stage task.
36 |         :return:
37 |         """
38 | 
39 |     @abstractmethod
40 |     def on_infer(self, data_bundle: DataBundle):
41 |         """
42 |         Inference logic of the multi-stage task.
43 |         :return:
44 |         """
45 | 
46 |     def run(self, data_bundle: DataBundle):
47 |         if get_flag() == 'train':
48 |             return self.on_train(data_bundle)
49 |         elif get_flag() == 'eval':
50 |             return self.on_eval(data_bundle)
51 |         elif get_flag() == 'infer':
52 |             return self.on_infer(data_bundle)
53 | 
54 |     def __call__(self, *args, **kwargs):
55 |         return self.run(*args, **kwargs)
--------------------------------------------------------------------------------
/fastie/tasks/build_task.py:
--------------------------------------------------------------------------------
1 | """This module is used to build the task from config."""
2 | __all__ = ['build_task']
3 | 
4 | from typing import Union, Optional
5 | 
6 | from fastie.envs import get_task
7 | from fastie.tasks import NER, RE, EE
8 | from fastie.utils.utils import parse_config
9 | 
10 | 
11 | def build_task(_config: Optional[Union[dict, str]] = None):
12 |     """Build the task you want to use from the config you give.
13 | 
14 |     :param _config: The config you want to use. It can be a dict or a path to a `*.py` config file.
15 |     :return: The task in config.
16 |     """
17 |     task, solution = '', ''
18 |     if _config is not None:
19 |         _config = parse_config(_config)
20 |     if not get_task():
21 |         if _config is None:
22 |             raise ValueError('The task you want to use is not specified.')
23 |         else:
24 |             if isinstance(_config, dict) and 'task' not in _config.keys():
25 |                 raise ValueError('The task you want to use is not specified.')
26 |             else:
27 |                 task, solution = _config['task'].split('/')
28 |     else:
29 |         task, solution = get_task().split('/')
30 |     task_cls = None  # stays ``None`` when the task type is not recognized
31 |     if task.lower() == 'ner':
32 |         task_cls = NER.get(solution)
33 |     elif task.lower() == 're':
34 |         task_cls = RE.get(solution)
35 |     elif task.lower() == 'ee':
36 |         task_cls = EE.get(solution)
37 |     if task_cls is None:
38 |         raise ValueError(
39 |             f'The task {task} with solution {solution} is not supported.')
40 |     task_obj = task_cls.from_config(_config)
41 |     return task_obj
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |   # option reference: https://flake8.pycqa.org/en/latest/user/options.html
3 |   - repo: https://github.com/PyCQA/flake8
4 |     rev: 4.0.1
5 |     hooks:
6 |       - id: flake8
7 |         args: [ "--extend-ignore", "E402,E501,E741,F403,F405" ]
8 | #  - repo: https://github.com/PyCQA/isort
9 | #    rev: 5.10.1
10 | #    hooks:
11 | #      - id: isort
12 | #        args: ["-m=2"]
13 |   - repo: https://github.com/pre-commit/mirrors-yapf
14 |     rev: v0.32.0
15 |     hooks:
16 |       - id: yapf
17 |   - repo: https://github.com/codespell-project/codespell
18 |     rev: v2.2.1
19 |     hooks:
20 |       - id: codespell
21 |         args:
22 |           - --skip=docs/*,tutorials
23 |   - repo: https://github.com/pre-commit/pre-commit-hooks
24 |     rev: v4.2.0
25 |     hooks:
26 |       - id: trailing-whitespace
27 |         args: [ --markdown-linebreak-ext=md ]
28 |       - id: check-added-large-files
29 |         args: [ "--maxkb=100" ]
30 |       - id: debug-statements
31 |       - id: end-of-file-fixer
32 |       - id: requirements-txt-fixer
33 |       - id: double-quote-string-fixer
34 |       - id: check-merge-conflict
35 |       - id: fix-encoding-pragma
36 |         args: [ "--remove" ]
37 |       - id: mixed-line-ending
38 |         args: [ "--fix=lf" ]
39 |   # parameter introduction: https://docformatter.readthedocs.io/en/latest/usage.html#use-from-the-command-line
40 |   - repo: https://github.com/myint/docformatter
41 |     rev: v1.4
42 |     hooks:
43 |       - id: docformatter
44 |         args: [ "--in-place", "--wrap-descriptions", "79" ]
45 |   # check type-hint
46 |   - repo: https://github.com/pre-commit/mirrors-mypy
47 |     rev: v0.991
48 |     hooks:
49 |       - id: mypy
50 |         exclude: |-
51 |           (?x)(
52 |           ^tests
53 |           | ^docs
54 |           )
55 | 
--------------------------------------------------------------------------------
/fastie/controller/interactor.py:
--------------------------------------------------------------------------------
1 | """Interactor for FastIE."""
2 | __all__ = [
3 |     'InteractorConfig',
4 |     'Interactor',
5 | ]
6 | 
7 | from dataclasses import dataclass, field
8 | from typing import Union, Sequence, Optional
9 | 
10 | from fastNLP import DataSet
11 | from fastNLP.io import DataBundle
12 | 
13 | from fastie.controller.base_controller import BaseController, CONTROLLER
14 | from fastie.controller.inference import Inference
15 | from fastie.envs import set_flag, logger
16 | from fastie.node import BaseNodeConfig
17 | 
18 | 
19 | @dataclass
20 | class InteractorConfig(BaseNodeConfig):
21 |     """Configuration for the interactor."""
22 |     log: str = field(
23 |         default='',
24 |         metadata={
25 |             'help':
26 |             'What file to write the interactive log to. If this is not set, '
27 |             'the log will not be written. ',
28 |             'existence':
29 |             'interact'
30 |         })
31 | 
32 | 
33 | @CONTROLLER.register_module('interactor')
34 | class Interactor(BaseController):
35 |     """Interactor: run interactive prediction from the command line, e.g.:
36 | 
37 |     .. code-block:: console
38 | 
39 |         $ fastie-interact --task ner/bert --load_model model.pkl --cuda --log interactive.log
40 | 
41 |     :param log: Path to write the interaction log to; if unset, no log is saved.
42 |     """
43 | 
44 |     def __init__(self, log: Optional[str] = None):
45 |         super(Interactor, self).__init__()
46 |         self.log = log
47 |         if self.log:  # treat both ``None`` and '' as "no log"
48 |             self.inference = Inference(save_path=self.log, verbose=True)
49 |         else:
50 |             self.inference = Inference(verbose=True)
51 |         set_flag('infer')
52 | 
53 |     def run(self,
54 |             parameters_or_data: Optional[Union[dict, DataBundle, DataSet, str,
55 |                                                Sequence[str]]] = None):
56 |         parameters_or_data = BaseController.run(self, parameters_or_data)
57 |         if self._sequential:
58 |             return parameters_or_data
59 |         if parameters_or_data is None:
60 |             logger.error(
61 |                 'The interacting tool does not allow the task and dataset to '
62 |                 'be left empty. ')
63 |             exit(1)
64 |         return self.inference(parameters_or_data=parameters_or_data)
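65 | 
66 | # A hedged usage sketch (``task`` here is whatever a task's ``run``
67 | # method returns, following the doctest examples in evaluator.py; the
68 | # dataset helper comes from this repo's tests and is illustrative):
69 | #
70 | #   from fastie.tasks import BertNER
71 | #   from tests.dummy import dummy_ner_dataset
72 | #
73 | #   task = BertNER().run(dummy_ner_dataset())
74 | #   Interactor(log='interactive.log').run(task)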
--------------------------------------------------------------------------------
/fastie/dataset/io/sentence.py:
--------------------------------------------------------------------------------
1 | """Sentence dataset for inference."""
2 | __all__ = ['SentenceConfig', 'Sentence']
3 | 
4 | from dataclasses import dataclass, field
5 | from typing import Union, Sequence, Optional
6 | 
7 | from fastNLP import DataSet, Instance, Vocabulary
8 | from fastNLP.io import DataBundle
9 | 
10 | from fastie.dataset.base_dataset import DATASET, BaseDataset, BaseDatasetConfig
11 | 
12 | 
13 | @dataclass
14 | class SentenceConfig(BaseDatasetConfig):
15 |     sentence: str = field(default='',
16 |                           metadata=dict(help='Input a sequence as a dataset.',
17 |                                         existence=True,
18 |                                         nargs='+',
19 |                                         multi_method='space-join'))
20 | 
21 | 
22 | @DATASET.register_module('sentence')
23 | class Sentence(BaseDataset):
24 |     """Sentence dataset (only for inference).
25 | 
26 |     :param sentence: Input a sentence or sentences as a dataset (use spaces to separate tokens).
27 |         For example:
28 | 
29 |         .. code-block:: python
30 | 
31 |             data_bundle = Sentence(sentence='I love FastIE .').run()
32 |             data_bundle = Sentence(sentence=['I love FastIE .', 'I love fastNLP .']).run()
33 |     """
34 |     _config = SentenceConfig()
35 |     _help = 'Input a sequence or sentences as a dataset. (Only for inference). '
36 | 
37 |     def __init__(self,
38 |                  sentence: Optional[Union[Sequence[str], str]] = None,
39 |                  cache: bool = False,
40 |                  refresh_cache: bool = False,
41 |                  tag_vocab: Optional[Union[Vocabulary, dict]] = None,
42 |                  **kwargs):
43 |         super(Sentence, self).__init__(cache=cache,
44 |                                        refresh_cache=refresh_cache,
45 |                                        tag_vocab=tag_vocab,
46 |                                        **kwargs)
47 |         self.sentence = sentence
48 | 
49 |     def run(self):
50 |         dataset = DataSet()
51 |         sentences = [self.sentence] if isinstance(self.sentence,
52 |                                                   str) else self.sentence
53 |         for sentence in sentences:
54 |             dataset.append(Instance(tokens=sentence.split(' ')))
55 |         data_bundle = DataBundle(datasets={'infer': dataset})
56 |         return data_bundle
--------------------------------------------------------------------------------
/fastie/tasks/ner/BaseNERTask.py:
--------------------------------------------------------------------------------
1 | """Base class for NER tasks."""
2 | __all__ = ['BaseNERTask', 'BaseNERTaskConfig']
3 | 
4 | import abc
5 | from typing import Optional, Dict
6 | 
7 | from fastNLP import Vocabulary
8 | from fastNLP.io import DataBundle
9 | 
10 | from fastie.envs import logger
11 | from fastie.tasks.base_task import BaseTask, BaseTaskConfig
12 | from fastie.utils.utils import generate_tag_vocab, check_loaded_tag_vocab
13 | 
14 | 
15 | class BaseNERTaskConfig(BaseTaskConfig):
16 |     """Parameters required by NER tasks."""
17 |     pass
18 | 
19 | 
20 | class BaseNERTask(BaseTask, metaclass=abc.ABCMeta):
21 |     """Base class for FastIE NER tasks."""
22 | 
23 |     _config = BaseNERTaskConfig()
24 |     _help = 'Base class for NER tasks. '
25 | 
26 |     def __init__(self, **kwargs):
27 |         super().__init__(**kwargs)
28 | 
29 |     def on_generate_and_check_tag_vocab(self,
30 |                                         data_bundle: DataBundle,
31 |                                         state_dict: Optional[dict]) \
32 |             -> Dict[str, Vocabulary]:
33 |         """Generate the tag vocabulary from ``sample['entity_mentions'][i][1]``
34 |         of every sample in the dataset. If the ``state_dict`` loaded from a
35 |         model already contains a ``tag_vocab``, check that it is consistent
36 |         with the one generated from ``data_bundle`` (the loaded ``tag_vocab``
37 |         takes precedence).
38 | 
39 |         :param data_bundle: the raw dataset, which may contain any of the
40 |             ``train``, ``dev``, ``test`` and ``infer`` splits; each needs to
41 |             be handled separately.
42 |         :param state_dict: the ``state_dict`` loaded from a model; may be ``None``
43 |         :return: the tag vocabulary; may be ``None``
44 |         """
45 |         tag_vocab = {}
46 |         if state_dict is not None and 'tag_vocab' in state_dict:
47 |             tag_vocab = state_dict['tag_vocab']
48 |         generated_tag_vocab = generate_tag_vocab(data_bundle)
49 |         for key, value in tag_vocab.items():
50 |             if key not in generated_tag_vocab.keys():
51 |                 generated_tag_vocab[key] = check_loaded_tag_vocab(value,
52 |                                                                   None)[1]
53 |             else:
54 |                 signal, generated_tag_vocab[key] = check_loaded_tag_vocab(
55 |                     value, generated_tag_vocab[key])
56 |                 if signal == -1:
57 |                     logger.warning(
58 |                         f'It is detected that the loaded ``{key}`` vocabulary '
59 |                         f'conflicts with the generated ``{key}`` vocabulary, '
60 |                         f'so the model loading may fail. ')
61 |         return generated_tag_vocab
--------------------------------------------------------------------------------
/fastie/dataset/build_dataset.py:
--------------------------------------------------------------------------------
1 | """Build dataset from different sources."""
2 | __all__ = ['build_dataset']
3 | 
4 | from typing import Union, Optional, Sequence
5 | 
6 | from fastNLP import DataSet, Instance
7 | from fastNLP.io import DataBundle
8 | 
9 | from fastie.dataset.base_dataset import DATASET
10 | from fastie.dataset.io.sentence import Sentence
11 | from fastie.envs import get_flag, get_dataset
12 | from fastie.utils.utils import parse_config
13 | 
14 | 
15 | def build_dataset(dataset: Optional[Union[str, Sequence[str], dict,
16 |                                           Sequence[dict], DataSet,
17 |                                           DataBundle]],
18 |                   dataset_config: Optional[dict] = None) -> DataBundle:
19 |     """Build a dataset from different sources.
20 | 
21 |     :param dataset: may be ``str``, ``Sequence[str]``, ``dict``,
22 |         ``Sequence[dict]``, ``DataSet`` or ``DataBundle``:
23 | 
24 |         * ``str``: a ``Sentence`` dataset is built automatically; it has a single ``tokens`` field, so separate tokens with spaces;
25 |         * ``Sequence[str]``: a ``Sentence`` dataset with multiple samples is built automatically;
26 |         * ``dict``: a ``DataSet`` is built automatically, with keys mapped to the ``field_name`` s of the ``DataSet``;
27 |         * ``Sequence[dict]``: a ``DataSet`` with multiple samples is built automatically;
28 |         * ``DataSet``: a ``DataBundle`` is built automatically, with the split name (``train``, ``dev``, ``test`` or ``infer``) chosen according to the current ``flag``;
29 |         * ``DataBundle``: returned unchanged;
30 |         * ``None``: the dataset is built from the ``dataset`` entry of the config file.
31 | 
32 |     :param dataset_config: parameters for the ``dataset`` object
33 |     :return: a ``DataBundle``
34 |     """
35 |     data_bundle = DataBundle()
36 |     if dataset is None:
37 |         if not get_dataset():
38 |             raise ValueError('The dataset you want to use is not specified.')
39 |         else:
40 |             if dataset_config is None:
41 |                 data_bundle = DATASET.get(get_dataset())().run()
42 |             else:
43 |                 data_bundle = DATASET.get(
44 |                     get_dataset())(**parse_config(dataset_config)).run()
45 |     else:
46 |         if isinstance(dataset, DataBundle):
47 |             # return a ready-made DataBundle unchanged, as documented above
48 |             return dataset
49 |         if isinstance(dataset, str) or isinstance(dataset, Sequence) \
50 |                 and isinstance(dataset[0], str):
51 |             data_bundle = Sentence(dataset)()  # type: ignore [arg-type]
52 |         if isinstance(dataset, dict):
53 |             dataset = [dataset]
54 |         if isinstance(dataset, Sequence) and isinstance(dataset[0], dict):
55 |             dataset = DataSet([Instance(**sample) for sample in dataset])
56 |         if isinstance(dataset, DataSet):
57 |             if get_flag() == 'train':
58 |                 data_bundle = DataBundle(datasets={'train': dataset})
59 |             elif get_flag() == 'eval':
60 |                 data_bundle = DataBundle(datasets={'test': dataset})
61 |             elif get_flag() == 'infer':
62 |                 data_bundle = DataBundle(datasets={'infer': dataset})
63 |             else:
64 |                 data_bundle = DataBundle(datasets={'train': dataset})
65 |     return data_bundle
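66 | 
67 | # A hedged usage sketch (the sample follows the ``entity_mentions``
68 | # layout used in tests/dummy.py; the field names are this repo's, the
69 | # sentence itself is illustrative):
70 | #
71 | #   data_bundle = build_dataset([{
72 | #       'tokens': ['It', 'is', 'located', 'in', 'Seoul', '.'],
73 | #       'entity_mentions': [[[4], 'LOC']]
74 | #   }])
75 | #   # -> a DataBundle whose split name depends on the current flag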
device, 26 | 'batch_size': 2, 27 | 'epoch': 2, 28 | **self.extra_parameters 29 | }).run(self.data_bundle) 30 | assert Trainer().run(task) 31 | 32 | @pytest.mark.parametrize('device', ['cpu', 'cuda:0', [0, 1]]) 33 | def test_eval(self, device): 34 | task = self.task_cls(**{ 35 | 'device': device, 36 | 'batch_size': 2, 37 | 'epoch': 2, 38 | **self.extra_parameters 39 | }).run(self.data_bundle) 40 | assert Evaluator().run(task) 41 | 42 | @pytest.mark.parametrize('device', ['cpu', 'cuda:0', [0, 1]]) 43 | def test_inference(self, device): 44 | task = self.task_cls(**{ 45 | 'device': device, 46 | 'batch_size': 2, 47 | **self.extra_parameters 48 | }).run(self.data_bundle) 49 | assert Inference().run(task) 50 | 51 | def test_topk(self): 52 | with tempfile.TemporaryDirectory() as tmpdir: 53 | task = self.task_cls( 54 | **{ 55 | 'device': 'cpu', 56 | 'batch_size': 2, 57 | 'epoch': 2, 58 | 'topk': 2, 59 | 'save_model_folder': tmpdir, 60 | **self.extra_parameters 61 | }).run(self.data_bundle) 62 | assert Trainer().run(task) 63 | assert len(os.listdir(tmpdir)) > 0 64 | 65 | def test_load_best_model(self): 66 | with tempfile.TemporaryDirectory() as tmpdir: 67 | task = self.task_cls( 68 | **{ 69 | 'device': 'cpu', 70 | 'batch_size': 2, 71 | 'epoch': 2, 72 | 'load_best_model': True, 73 | 'save_model_folder': tmpdir, 74 | **self.extra_parameters 75 | }).run(self.data_bundle) 76 | assert Trainer().run(task) 77 | assert len(os.listdir(tmpdir)) > 0 78 | 79 | # def test_fp16(self): 80 | # task = self.task_cls(**{ 81 | # "device": "cpu", 82 | # "batch_size": 2, 83 | # "epoch": 2, 84 | # "fp16": True, 85 | # **self.extra_parameters 86 | # }).run(self.data_bundle) 87 | # assert Trainer().run(task) 88 | -------------------------------------------------------------------------------- /tests/dummy.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from fastNLP.io import DataBundle 4 | 5 | from fastie.dataset import build_dataset 6 | 7 | 8 | def dummy_ner_dataset() -> DataBundle: 9 | data_bundle = build_dataset([{ 10 | 'tokens': [ 11 | 'EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 12 | 'lamb', '.' 13 | ], 14 | 'entity_mentions': [[[0], 'ORG'], [[2], 'MISC'], [[6], 'MISC']] 15 | }, { 16 | 'tokens': ['Peter', 'Blackburn'], 17 | 'entity_mentions': [[[0, 1], 'PER']] 18 | }, { 19 | 'tokens': ['BRUSSELS', '1996-08-22'], 20 | 'entity_mentions': [[[0], 'LOC']] 21 | }, { 22 | 'tokens': [ 23 | 'The', 'European', 'Commission', 'said', 'on', 'Thursday', 'it', 24 | 'disagreed', 'with', 'German', 'advice', 'to', 'consumers', 'to', 25 | 'shun', 'British', 'lamb', 'until', 'scientists', 'determine', 26 | 'whether', 'mad', 'cow', 'disease', 'can', 'be', 'transmitted', 27 | 'to', 'sheep', '.' 28 | ], 29 | 'entity_mentions': [[[1, 2], 'ORG'], [[9], 'MISC'], [[15], 'MISC']] 30 | }, { 31 | 'tokens': [ 32 | 'Germany', "\'s", 'representative', 'to', 'the', 'European', 33 | 'Union', "'s", 'veterinary', 'committee', 'Werner', 'Zwingmann', 34 | 'said', 'on', 'Wednesday', 'consumers', 'should', 'buy', 35 | 'sheepmeat', 'from', 'countries', 'other', 'than', 'Britain', 36 | 'until', 'the', 'scientific', 'advice', 'was', 'clearer', '.' 37 | ], 38 | 'entity_mentions': [[[0], 'LOC'], [[5, 6], 'ORG'], [[10, 11], 'PER'], 39 | [[23], 'LOC']] 40 | }, { 41 | 'tokens': [ 42 | 'Rabinovich', 'is', 'winding', 'up', 'his', 'term', 'as', 43 | 'ambassador', '.' 
44 | ], 45 | 'entity_mentions': [[[0], 'PER']] 46 | }, { 47 | 'tokens': [ 48 | 'He', 'will', 'be', 'replaced', 'by', 'Eliahu', 'Ben-Elissar', ',', 49 | 'a', 'former', 'Israeli', 'envoy', 'to', 'Egypt', 'and', 50 | 'right-wing', 'Likud', 'party', 'politician', '.' 51 | ], 52 | 'entity_mentions': [[[5, 6], 'PER'], [[10], 'MISC'], [[13], 'LOC'], 53 | [[16], 'ORG']] 54 | }, { 55 | 'tokens': [ 56 | 'Israel', 'on', 'Wednesday', 'sent', 'Syria', 'a', 'message', ',', 57 | 'via', 'Washington', ',', 'saying', 'it', 'was', 'committed', 'to', 58 | 'peace', 'and', 'wanted', 'to', 'open', 'negotiations', 'without', 59 | 'preconditions', '.' 60 | ], 61 | 'entity_mentions': [[[0], 'LOC'], [[4], 'LOC'], [[9], 'LOC']] 62 | }]) 63 | data_bundle.set_dataset(deepcopy(data_bundle.get_dataset('train')), 'dev') 64 | data_bundle.set_dataset(deepcopy(data_bundle.get_dataset('train')), 'test') 65 | data_bundle.set_dataset(deepcopy(data_bundle.get_dataset('train')), 66 | 'infer') 67 | data_bundle.get_dataset('infer').delete_field('entity_mentions') 68 | return data_bundle 69 | -------------------------------------------------------------------------------- /docs/source/_static/css/badge_only.css: -------------------------------------------------------------------------------- 1 | .fa:before{-webkit-font-smoothing:antialiased}.clearfix{*zoom:1}.clearfix:after,.clearfix:before{display:table;content:""}.clearfix:after{clear:both}@font-face{font-family:FontAwesome;font-style:normal;font-weight:400;src:url(../fonts/fontawesome-webfont.eot?674f50d287a8c48dc19ba404d20fe713?#iefix) format("embedded-opentype"),url(../fonts/fontawesome-webfont.woff2?af7ae505a9eed503f8b8e6982036873e) format("woff2"),url(../fonts/fontawesome-webfont.woff?fee66e712a8a08eef5805a46892932ad) format("woff"),url(../fonts/fontawesome-webfont.ttf?b06871f281fee6b241d60582ae9369b9) format("truetype"),url(../fonts/fontawesome-webfont.svg?912ec66d7572ff821749319396470bde#FontAwesome) format("svg")}.fa:before{font-family:FontAwesome;font-style:normal;font-weight:400;line-height:1}.fa:before,a .fa{text-decoration:inherit}.fa:before,a .fa,li .fa{display:inline-block}li .fa-large:before{width:1.875em}ul.fas{list-style-type:none;margin-left:2em;text-indent:-.8em}ul.fas li .fa{width:.8em}ul.fas li .fa-large:before{vertical-align:baseline}.fa-book:before,.icon-book:before{content:"\f02d"}.fa-caret-down:before,.icon-caret-down:before{content:"\f0d7"}.fa-caret-up:before,.icon-caret-up:before{content:"\f0d8"}.fa-caret-left:before,.icon-caret-left:before{content:"\f0d9"}.fa-caret-right:before,.icon-caret-right:before{content:"\f0da"}.rst-versions{position:fixed;bottom:0;left:0;width:300px;color:#fcfcfc;background:#1f1d1d;font-family:Lato,proxima-nova,Helvetica Neue,Arial,sans-serif;z-index:400}.rst-versions a{color:#2980b9;text-decoration:none}.rst-versions .rst-badge-small{display:none}.rst-versions .rst-current-version{padding:12px;background-color:#272525;display:block;text-align:right;font-size:90%;cursor:pointer;color:#27ae60}.rst-versions .rst-current-version:after{clear:both;content:"";display:block}.rst-versions .rst-current-version .fa{color:#fcfcfc}.rst-versions .rst-current-version .fa-book,.rst-versions .rst-current-version .icon-book{float:left}.rst-versions .rst-current-version.rst-out-of-date{background-color:#e74c3c;color:#fff}.rst-versions .rst-current-version.rst-active-old-version{background-color:#f1c40f;color:#000}.rst-versions.shift-up{height:auto;max-height:100%;overflow-y:scroll}.rst-versions.shift-up 
.rst-other-versions{display:block}.rst-versions .rst-other-versions{font-size:90%;padding:12px;color:grey;display:none}.rst-versions .rst-other-versions hr{display:block;height:1px;border:0;margin:20px 0;padding:0;border-top:1px solid #413d3d}.rst-versions .rst-other-versions dd{display:inline-block;margin:0}.rst-versions .rst-other-versions dd a{display:inline-block;padding:6px;color:#fcfcfc}.rst-versions.rst-badge{width:auto;bottom:20px;right:20px;left:auto;border:none;max-width:300px;max-height:90%}.rst-versions.rst-badge .fa-book,.rst-versions.rst-badge .icon-book{float:none;line-height:30px}.rst-versions.rst-badge.shift-up .rst-current-version{text-align:right}.rst-versions.rst-badge.shift-up .rst-current-version .fa-book,.rst-versions.rst-badge.shift-up .rst-current-version .icon-book{float:left}.rst-versions.rst-badge>.rst-current-version{width:auto;height:30px;line-height:30px;padding:0 6px;display:block;text-align:center}@media screen and (max-width:768px){.rst-versions{width:85%;display:none}.rst-versions.shift{display:block}} 2 | -------------------------------------------------------------------------------- /fastie/controller/evaluator.py: -------------------------------------------------------------------------------- 1 | """Evaluator for FastIE.""" 2 | __all__ = ['Evaluator', 'EvaluatorConfig'] 3 | 4 | from dataclasses import dataclass 5 | from typing import Union, Optional 6 | 7 | from fastNLP import DataSet, auto_param_call 8 | from fastNLP import Evaluator as FastNLP_Evaluator 9 | from fastNLP.io import DataBundle 10 | 11 | from fastie.controller.base_controller import BaseController, CONTROLLER 12 | from fastie.envs import set_flag 13 | from fastie.node import BaseNodeConfig 14 | 15 | 16 | @dataclass 17 | class EvaluatorConfig(BaseNodeConfig): 18 | """验证器的配置类.""" 19 | pass 20 | 21 | 22 | @CONTROLLER.register_module('evaluator') 23 | class Evaluator(BaseController): 24 | """验证器 用于对任务在 ``test`` 数据集上进行检验,并输出 ``test`` 数据集上的 ``metric``""" 25 | 26 | def __init__(self): 27 | super(Evaluator, self).__init__() 28 | set_flag('eval') 29 | 30 | def run( 31 | self, 32 | parameters_or_data: Optional[Union[dict, DataBundle, DataSet]] = None 33 | ) -> dict: 34 | """验证器的 ``run`` 方法,用于实际地对传入的 ``task`` 或是数据集进行验证. 35 | 36 | 也可以使用命令行模式, 例如: 37 | 38 | .. code-block:: console 39 | $ fastie-inference --task ner/bert --dataset conll2003 --save_path result.jsonl 40 | 41 | :param parameters_or_data: 既可以是 task,也可以是数据集: 42 | * 为 ``task`` 时, 应为 :class:`~fastie.BaseTask` 对象 ``run`` 43 | 方法的返回值, 例如: 44 | >>> from fastie.tasks import BertNER 45 | >>> task = BertNER().run() 46 | >>> Evaluator().run(task) 47 | * 为数据集,可以是 ``[dict, DataSet, DataBundle, None]`` 类型的数据集: 48 | * ``dict`` 类型的数据集,例如: 49 | >>> dataset = {'tokens': [ "It", "is", "located", "in", "Seoul", "." ], 50 | >>> 'entity_motions': [([4], "LOC")]} 51 | * ``Sequence[dict]`` 类型的数据集,例如: 52 | >>> dataset = [{'tokens': [ "It", "is", "located", "in", "Seoul", "." ], 53 | >>> 'entity_motions': [([4], "LOC")]}] 54 | * ``DataSet`` 类型的数据集,例如: 55 | >>> from fastNLP import DataSet, Instance 56 | >>> dataset = DataSet([Instance(tokens=[ "It", "is", "located", "in", "Seoul", "." ], 57 | >>> entity_motions=([4], "LOC"))]) 58 | * ``DataBundle`` 类型的数据集,必须包含 ``test`` 子集, 例如: 59 | >>> from fastNLP import DataSet, Instance 60 | >>> from fastNLP.io import DataBundle 61 | >>> dataset = DataBundle(datasets={'test': DataSet([Instance(tokens=[ "It", "is", "located", "in", "Seoul", "." 
], 62 | >>> entity_mentions=([4], "LOC"))])}) 63 | * ``None`` 会自动寻找 ``config`` 中的 ``dataset``, 例如: 64 | >>> config = {'dataset': 'conll2003'} 65 | >>> Evaluator.from_config(config).run() 66 | 67 | :return: ``dict`` 类型的 ``metric`` 结果, 例如: 68 | >>> {'acc': 0.0, 'f1': 0.0, 'precision': 0.0, 'recall': 0.0} 69 | """ 70 | parameters_or_data = BaseController.run(self, parameters_or_data) 71 | if self._sequential: 72 | return parameters_or_data 73 | if parameters_or_data is None: 74 | raise Exception( 75 | 'The Evaluator does not allow the task or dataset to be ' 76 | 'left empty. ') 77 | evaluator = FastNLP_Evaluator(**parameters_or_data) 78 | return auto_param_call(evaluator.run, parameters_or_data) 79 | -------------------------------------------------------------------------------- /fastie/dataset/base_dataset.py: -------------------------------------------------------------------------------- 1 | """Base class for all FastIE datasets.""" 2 | __all__ = ['BaseDataset', 'BaseDatasetConfig', 'load_dataset', 'DATASET'] 3 | 4 | import abc 5 | import os 6 | from dataclasses import dataclass, field 7 | 8 | from fastNLP import cache_results 9 | 10 | from fastie.envs import FASTIE_HOME, logger 11 | from fastie.node import BaseNode, BaseNodeConfig 12 | from fastie.utils import Registry 13 | 14 | DATASET: Registry = Registry('DATASET') 15 | 16 | 17 | def load_dataset(name, *args, **kwargs): 18 | """根据 dataset 的注册名字加载 dataset 对象. 19 | 20 | :param name: dataset 的注册名字. 21 | :param args: 传给 dataset 的位置参数. 22 | :param kwargs: 传给 dataset 的关键字参数. 23 | :return: 实例化后的 dataset 对象. 24 | """ 25 | return DATASET.get(name)(*args, **kwargs) 26 | 27 | 28 | @dataclass 29 | class BaseDatasetConfig(BaseNodeConfig): 30 | """FastIE 数据集基类的配置类.""" 31 | cache: bool = field( 32 | default=False, 33 | metadata=dict( 34 | help='Cache the result of data loading to speed up ' 35 | 'reading the next time it is used. ', 36 | existence=True)) 37 | refresh_cache: bool = field( 38 | default=False, 39 | metadata=dict(help='Clear cache (Use this when your data changes). ', 40 | existence=True)) 41 | 42 | 43 | class BaseDataset(BaseNode, metaclass=abc.ABCMeta): 44 | """FastIE 数据集基类. 45 | 46 | :param cache: 是否缓存数据集. 47 | :param refresh_cache: 是否刷新缓存. 48 | """ 49 | 50 | _config = BaseDatasetConfig() 51 | _help = '数据集基类' 52 | 53 | def __init__(self, 54 | cache: bool = False, 55 | refresh_cache: bool = False, 56 | **kwargs): 57 | BaseNode.__init__(self, **kwargs) 58 | self.refresh_cache: bool = refresh_cache 59 | self.cache: bool = cache 60 | 61 | @property 62 | def cache(self): 63 | return self._cache 64 | 65 | @cache.setter 66 | def cache(self, value: bool): 67 | if value: 68 | # 保存 cache 的位置默认为 `~/.fastie/cache/BaseDataset/cache.pkl` 69 | original_run = self.run 70 | 71 | def run_wrapper(): 72 | cache_name = 'cache' 73 | if 'io' in self.__class__.__module__: 74 | if hasattr(self, 'folder'): 75 | if not self.folder.endswith('/'): 76 | self.folder += '/' 77 | cache_name = os.path.basename( 78 | os.path.dirname(self.folder)) 79 | else: 80 | logger.warning(""" 81 | Please make sure that your IO Dataset class has a ``folder`` attribute. 82 | Otherwise, your dataset will be cached into the same cache file, whether or not you use the same folder the next time. 83 | """) 84 | path = os.path.join( 85 | FASTIE_HOME, 86 | f'cache/{self.__class__.__name__}/{cache_name}.pkl') 87 | return cache_results( 88 | _cache_fp=f'{path}', 89 | _refresh=self.refresh_cache)(original_run)() 90 | 91 | object.__setattr__(self, 'run', run_wrapper) 92 | self._cache = value
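93 | # 缓存用法示意(仅为示意, 以已注册的 'conll2003' 数据集为例): 94 | # dataset = load_dataset('conll2003') 95 | # dataset.cache = True # run() 的结果将缓存到 ~/.fastie/cache/Conll2003/cache.pkl 96 | # data_bundle = dataset.run() # 再次调用时将直接读取缓存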
83 | """) 84 | path = os.path.join( 85 | FASTIE_HOME, 86 | f'cache/{self.__class__.__name__}/{cache_name}.pkl') 87 | return cache_results( 88 | _cache_fp=f'{path}', 89 | _refresh=self.refresh_cache)(original_run)() 90 | 91 | object.__setattr__(self, 'run', run_wrapper) 92 | self._cache = value 93 | 94 | @abc.abstractmethod 95 | def run(self): 96 | """加载数据集, 返回一个 DataBundle 对象. 97 | 98 | :return: 99 | """ 100 | raise NotImplementedError('The `run` method must be implemented. ') 101 | 102 | def __call__(self, *args, **kwargs): 103 | return self.run() 104 | -------------------------------------------------------------------------------- /fastie/dataset/legacy/wikiann.py: -------------------------------------------------------------------------------- 1 | """ 2 | Wikiann dataset for FastIE. . 3 | """ 4 | __all__ = ['Wikiann', 'WikiannConfig'] 5 | 6 | from dataclasses import dataclass, field 7 | 8 | import numpy as np 9 | from datasets import load_dataset 10 | from fastNLP import DataSet, Instance 11 | from fastNLP.io import DataBundle 12 | 13 | from fastie.dataset.base_dataset import DATASET, BaseDataset, BaseDatasetConfig 14 | 15 | 16 | @dataclass 17 | class WikiannConfig(BaseDatasetConfig): 18 | """Wikiann 数据集配置类.""" 19 | language: str = field( 20 | default='en', 21 | metadata=dict(help='Select which language subset in wikiann. ' 22 | 'Refer to https://huggingface.co/datasets/wikiann .', 23 | existence=True)) 24 | 25 | 26 | @DATASET.register_module('wikiann') 27 | class Wikiann(BaseDataset): 28 | """Wikiann 为 NER 数据集,由 172 语言的子集组成,标签包括 LOC, ORG, PER。 29 | 30 | :param language: 选择哪个语言的子集 31 | 参考 https://huggingface.co/datasets/wikiann 32 | """ 33 | _config = WikiannConfig() 34 | _help = 'Wikiann for NER task. Refer to ' \ 35 | 'https://huggingface.co/datasets/wikiann for more information.' 
36 | 37 | def __init__(self, language: str = 'en', **kwargs): 38 | super().__init__(**kwargs) 39 | self.language = language 40 | 41 | def run(self): 42 | raw_dataset = load_dataset('wikiann', self.language) 43 | tag2idx = { 44 | 'O': 0, 45 | 'B-PER': 1, 46 | 'I-PER': 2, 47 | 'B-ORG': 3, 48 | 'I-ORG': 4, 49 | 'B-LOC': 5, 50 | 'I-LOC': 6 51 | } 52 | idx2tag = {value: key for key, value in tag2idx.items()} 53 | datasets = {} 54 | 55 | for split, dataset in raw_dataset.items(): 56 | split = split.replace('validation', 'dev') 57 | datasets[split] = DataSet() 58 | for sample in dataset: 59 | instance = Instance() 60 | instance.add_field('tokens', sample['tokens']) 61 | entity_mentions = [] 62 | span = [] 63 | current_tag = 0 64 | for i in np.arange(len(sample['ner_tags'])): 65 | if sample['ner_tags'][i] != 0: 66 | if len(span) == 0: 67 | current_tag = sample['ner_tags'][i] 68 | span.append(i) 69 | continue 70 | else: 71 | if current_tag == sample['ner_tags'][i] or \ 72 | current_tag + 1 == sample['ner_tags'][i]: 73 | span.append(i) 74 | continue 75 | else: 76 | entity_mentions.append( 77 | (span, idx2tag[current_tag][2:])) 78 | span = [i] 79 | current_tag = sample['ner_tags'][i] 80 | continue 81 | else: 82 | if len(span) > 0: 83 | entity_mentions.append( 84 | (span, 85 | idx2tag[sample['ner_tags'][span[0]]][2:])) 86 | span = [] 87 | if len(span) > 0: 88 | entity_mentions.append( 89 | (span, idx2tag[sample['ner_tags'][span[0]]][2:])) 90 | instance.add_field('entity_mentions', entity_mentions) 91 | datasets[split].append(instance) 92 | 93 | data_bundle = DataBundle(datasets=datasets) 94 | return data_bundle 95 | -------------------------------------------------------------------------------- /fastie/utils/path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os 3 | import os.path as osp 4 | from pathlib import Path 5 | 6 | from .misc import is_str 7 | 8 | 9 | def is_filepath(x): 10 | return is_str(x) or isinstance(x, Path) 11 | 12 | 13 | def fopen(filepath, *args, **kwargs): 14 | if is_str(filepath): 15 | return open(filepath, *args, **kwargs) 16 | elif isinstance(filepath, Path): 17 | return filepath.open(*args, **kwargs) 18 | raise ValueError('`filepath` should be a string or a Path') 19 | 20 | 21 | def check_file_exist(filename, msg_tmpl='file "{}" does not exist'): 22 | if not osp.isfile(filename): 23 | raise FileNotFoundError(msg_tmpl.format(filename)) 24 | 25 | 26 | def mkdir_or_exist(dir_name, mode=0o777): 27 | if dir_name == '': 28 | return 29 | dir_name = osp.expanduser(dir_name) 30 | os.makedirs(dir_name, mode=mode, exist_ok=True) 31 | 32 | 33 | def symlink(src, dst, overwrite=True, **kwargs): 34 | if os.path.lexists(dst) and overwrite: 35 | os.remove(dst) 36 | os.symlink(src, dst, **kwargs) 37 | 38 | 39 | def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True): 40 | """Scan a directory to find the interested files. 41 | 42 | Args: 43 | dir_path (str | :obj:`Path`): Path of the directory. 44 | suffix (str | tuple(str), optional): File suffix that we are 45 | interested in. Default: None. 46 | recursive (bool, optional): If set to True, recursively scan the 47 | directory. Default: False. 48 | case_sensitive (bool, optional) : If set to False, ignore the case of 49 | suffix. Default: True. 50 | Returns: 51 | A generator for all the interested files with relative paths. 
52 | """ 53 | if isinstance(dir_path, (str, Path)): 54 | dir_path = str(dir_path) 55 | else: 56 | raise TypeError('"dir_path" must be a string or Path object') 57 | 58 | if (suffix is not None) and not isinstance(suffix, (str, tuple)): 59 | raise TypeError('"suffix" must be a string or tuple of strings') 60 | 61 | if suffix is not None and not case_sensitive: 62 | suffix = suffix.lower() if isinstance(suffix, str) else tuple( 63 | item.lower() for item in suffix) 64 | 65 | root = dir_path 66 | 67 | def _scandir(dir_path, suffix, recursive, case_sensitive): 68 | for entry in os.scandir(dir_path): 69 | if not entry.name.startswith('.') and entry.is_file(): 70 | rel_path = osp.relpath(entry.path, root) 71 | _rel_path = rel_path if case_sensitive else rel_path.lower() 72 | if suffix is None or _rel_path.endswith(suffix): 73 | yield rel_path 74 | elif recursive and os.path.isdir(entry.path): 75 | # scan recursively if entry.path is a directory 76 | yield from _scandir(entry.path, suffix, recursive, 77 | case_sensitive) 78 | 79 | return _scandir(dir_path, suffix, recursive, case_sensitive) 80 | 81 | 82 | def find_vcs_root(path, markers=('.git', )): 83 | """Finds the root directory (including itself) of specified markers. 84 | 85 | Args: 86 | path (str): Path of directory or file. 87 | markers (list[str], optional): List of file or directory names. 88 | Returns: 89 | The directory contained one of the markers or None if not found. 90 | """ 91 | if osp.isfile(path): 92 | path = osp.dirname(path) 93 | 94 | prev, cur = None, osp.abspath(osp.expanduser(path)) 95 | while cur != prev: 96 | if any(osp.exists(osp.join(cur, marker)) for marker in markers): 97 | return cur 98 | prev, cur = cur, osp.split(cur)[0] 99 | return None 100 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ### Python template 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | docs/src/ 75 | docs/temp/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/#use-with-ide 113 | .pdm.toml 114 | 115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
163 | .idea/ 164 | 165 | # 自己的测试文件夹 166 | playground/ 167 | -------------------------------------------------------------------------------- /fastie/controller/trainer.py: -------------------------------------------------------------------------------- 1 | """Trainer for FastIE.""" 2 | __all__ = ['Trainer', 'TrainerConfig'] 3 | 4 | from dataclasses import dataclass 5 | from typing import Union, Sequence, Optional 6 | 7 | from fastNLP import DataSet, auto_param_call 8 | from fastNLP import Trainer as FastNLP_Trainer 9 | from fastNLP.io import DataBundle 10 | 11 | from fastie.controller.base_controller import BaseController, CONTROLLER 12 | from fastie.envs import set_flag, logger 13 | from fastie.node import BaseNodeConfig 14 | from fastie.tasks.base_task import BaseTask 15 | 16 | 17 | @dataclass 18 | class TrainerConfig(BaseNodeConfig): 19 | """训练器的配置类.""" 20 | pass 21 | 22 | 23 | @CONTROLLER.register_module('trainer') 24 | class Trainer(BaseController): 25 | """训练器 用于对任务在 ``train`` 数据集上进行训练,并输出 ``dev`` 数据集上的 ``metric`` 26 | 27 | 也可以使用命令行模式, 例如: 28 | 29 | .. code-block:: console 30 | $ fastie-train --task ner/bert --dataset conll2003 --topk 3 --save_model model.pkl 31 | """ 32 | _config = TrainerConfig() 33 | _help = 'Trainer for FastIE ' 34 | 35 | def __init__(self): 36 | super(Trainer, self).__init__() 37 | set_flag('train') 38 | 39 | def run(self, 40 | parameters_or_data: Optional[Union[dict, DataBundle, DataSet, str, 41 | Sequence[str]]] = None): 42 | """训练器的 ``run`` 方法,用于实际地对传入的 ``task`` 或是数据集进行训练. 43 | 44 | 也可以使用命令行模式, 例如: 45 | 46 | .. code-block:: console 47 | $ fastie-train --task ner/bert --dataset conll2003 48 | 49 | :param parameters_or_data: 既可以是 task,也可以是数据集: 50 | * 为 ``task`` 时, 应为 :class:`~fastie.BaseTask` 对象 ``run`` 51 | 方法的返回值, 例如: 52 | >>> from fastie.tasks import BertNER 53 | >>> task = BertNER().run() 54 | >>> Trainer().run(task) 55 | * 为数据集,可以是 ``[dict, Sequence[dict], DataSet, DataBundle, None]`` 类型的数据集: 56 | * ``dict`` 类型的数据集,例如: 57 | >>> dataset = {'tokens': [ "It", "is", "located", "in", "Seoul", "." ], 58 | >>> 'entity_mentions': [([4], "LOC")]} 59 | * ``Sequence[dict]`` 类型的数据集,例如: 60 | >>> dataset = [{'tokens': [ "It", "is", "located", "in", "Seoul", "." ], 61 | >>> 'entity_mentions': [([4], "LOC")]}] 62 | * ``DataSet`` 类型的数据集,例如: 63 | >>> from fastNLP import DataSet, Instance 64 | >>> dataset = DataSet([Instance(tokens=[ "It", "is", "located", "in", "Seoul", "." ], 65 | >>> entity_mentions=([4], "LOC"))]) 66 | * ``DataBundle`` 类型的数据集,必须包含 ``train`` 子集, 67 | 如果有 ``dev`` 子集的话,则在每个 epoch 结束后进行检验,例如: 68 | >>> from fastNLP import DataSet, Instance 69 | >>> from fastNLP.io import DataBundle 70 | >>> dataset = DataBundle(datasets={'train': DataSet([Instance(tokens=[ "It", "is", "located", "in", "Seoul", "." ], 71 | >>> entity_mentions=([4], "LOC"))])}) 72 | * ``None`` 会自动寻找 ``config`` 中的 ``dataset``, 例如: 73 | >>> config = {'dataset': 'conll2003'} 74 | >>> Trainer.from_config(config).run() 75 | 76 | :return: 训练结束后 ``task`` 的 ``state_dict``;若没有可用的 ``task`` 则返回 ``None`` 77 | """ 78 | parameters_or_data = BaseController.run(self, parameters_or_data) 79 | if self._sequential: 80 | return parameters_or_data 81 | if parameters_or_data is None: 82 | logger.error( 83 | 'Training does not allow the task and dataset to be left ' 84 | 'empty. 
') 85 | exit(1) 86 | 87 | task: BaseTask = parameters_or_data.pop('fastie_task') 88 | trainer = FastNLP_Trainer(**parameters_or_data) 89 | auto_param_call(trainer.run, parameters_or_data) 90 | if task is not None: 91 | task._on_get_state_dict_cache = task.on_get_state_dict( 92 | model=task._on_setup_model_cache, 93 | data_bundle=task._on_dataset_preprocess_cache, 94 | tag_vocab=task._on_generate_and_check_tag_vocab_cache) 95 | return task._on_get_state_dict_cache 96 | else: 97 | return None 98 | -------------------------------------------------------------------------------- /fastie/controller/base_controller.py: -------------------------------------------------------------------------------- 1 | """控制器基类.""" 2 | 3 | __all__ = ['BaseController', 'CONTROLLER'] 4 | 5 | from typing import Union, Sequence, Generator, Optional 6 | 7 | from fastNLP import DataSet 8 | from fastNLP.io import DataBundle 9 | 10 | from fastie.dataset.build_dataset import build_dataset 11 | from fastie.node import BaseNode 12 | from fastie.tasks import build_task, BaseTask, SequentialTask 13 | from fastie.utils import Registry 14 | 15 | CONTROLLER: Registry = Registry('CONTROLLER') 16 | 17 | 18 | class BaseController(BaseNode): 19 | """Base class for all controllers.""" 20 | 21 | def __init__(self, **kwargs): 22 | BaseNode.__init__(self, **kwargs) 23 | self._sequential = False 24 | 25 | def run(self, 26 | parameters_or_data: Optional[Union[dict, DataBundle, DataSet, str, 27 | Sequence[str], Sequence[dict], 28 | BaseTask, Generator, 29 | SequentialTask]] = None): 30 | """控制器基类的 ``run`` 方法,用于实际地对传入的 ``task`` 或是数据集进行训练, 验证或推理. 31 | 32 | :param parameters_or_data: 既可以是 task,也可以是数据集: 33 | * 为 ``task`` 时, 应为 :class:`~fastie.BaseTask` 对象 ``run`` 34 | 方法的返回值, 例如: 35 | >>> from fastie.tasks import BertNER 36 | >>> task = BertNER().run() 37 | >>> Trainer().run(task) 38 | * 为数据集,可以是 ``[dict, DataSet, DataBundle, str, 39 | Sequence[str], Sequence[dict], None]`` 类型的数据集: 40 | * ``dict`` 类型的数据集,例如: 41 | >>> dataset = {'tokens': [ "It", "is", "located", "in", "Seoul", "." ], 42 | >>> 'entity_mentions': [([4], "LOC")]} 43 | * ``Sequence[dict]`` 类型的数据集,例如: 44 | >>> dataset = [{'tokens': [ "It", "is", "located", "in", "Seoul", "." ], 45 | >>> 'entity_mentions': [([4], "LOC")]}] 46 | * ``DataSet`` 类型的数据集,例如: 47 | >>> from fastNLP import DataSet, Instance 48 | >>> dataset = DataSet([Instance(tokens=[ "It", "is", "located", "in", "Seoul", "." ], 49 | >>> entity_mentions=([4], "LOC"))]) 50 | * ``DataBundle`` 类型的数据集,例如: 51 | >>> from fastNLP import DataSet, Instance 52 | >>> from fastNLP.io import DataBundle 53 | >>> dataset = DataBundle(datasets={'train': DataSet([Instance(tokens=[ "It", "is", "located", "in", "Seoul", "." ], 54 | >>> entity_mentions=([4], "LOC"))])}) 55 | * ``str`` 类型的数据集,会自动根据空格分割转换为 ``token``, 仅适用于推理, 56 | 详见 :class:`fastie.dataset.io.sentence.Sentence`, 例如: 57 | >>> dataset = "It is located in Seoul ."
58 | * ``Sequence[str]`` 类型的数据集,会自动根据空格分割转换为 ``token``, 仅适用于推理, 59 | 详见 :class:`fastie.dataset.io.sentence.Sentence`, 例如: 60 | >>> dataset = ["It is located in Seoul .", "It is located in Beijing ."] 61 | * ``None`` 会自动寻找 ``config`` 中的 ``dataset``, 例如: 62 | >>> config = {'dataset': 'conll2003'} 63 | >>> Trainer.from_config(config).run() 64 | 65 | :return: 训练, 验证, 或推理的结果: 66 | * 训练时, 返回任务的 ``state_dict`` 67 | * 验证时, 返回验证集的 ``metric`` 68 | * 推理时, 返回推理结果 69 | """ 70 | if callable(parameters_or_data): 71 | parameters_or_data = parameters_or_data() 72 | if isinstance(parameters_or_data, Generator) and hasattr( 73 | parameters_or_data, '__qualname__'): 74 | if parameters_or_data.__qualname__ == 'SequentialTask.run': # type: ignore [attr-defined] 75 | self._sequential = True 76 | parameters_or_data = next(parameters_or_data) 77 | if isinstance(parameters_or_data, dict) \ 78 | and 'model' in parameters_or_data.keys(): 79 | return parameters_or_data 80 | else: 81 | # 下面的是直接传入数据集的情况,需要根据 global_config 构建 task 82 | data_bundle = build_dataset(parameters_or_data, 83 | dataset_config=self._overload_config) 84 | parameters_or_data = build_task(self._overload_config)(data_bundle) 85 | if isinstance(parameters_or_data, Generator): 86 | parameters_or_data = next(parameters_or_data) 87 | return parameters_or_data 88 | 89 | def __call__(self, *args, **kwargs): 90 | """重载 ``__call__`` 方法,使得控制器可以直接调用 ``run`` 方法. 91 | 92 | :param args: 与 ``run`` 方法的参数一致 93 | :param kwargs: 与 ``run`` 方法的参数一致 94 | :return: 与 ``run`` 方法的返回结果一致 95 | """ 96 | return self.run(*args, **kwargs) 97 | -------------------------------------------------------------------------------- /fastie/exhibition.py: -------------------------------------------------------------------------------- 1 | """FastIE 展示模块.""" 2 | 3 | __all__ = ['Exhibition'] 4 | 5 | import os 6 | import sys 7 | 8 | from texttable import Texttable 9 | 10 | from fastie.dataset import DATASET 11 | from fastie.envs import FASTIE_HOME 12 | from fastie.tasks import NER, EE, RE 13 | from fastie.utils import Config 14 | 15 | 16 | class Exhibition: 17 | """FastIE 展示模块, 主要用来在命令行模式中展示 FastIE 现有的 ``dataset``, ``task`` 和 18 | ``config``""" 19 | 20 | @property 21 | def NER(self): 22 | table = list() 23 | for key, value in NER.module_dict.items(): 24 | table.append( 25 | dict(task='NER', solution=key, 26 | description=value().description)) 27 | return table 28 | 29 | @property 30 | def RE(self): 31 | table = list() 32 | for key, value in RE.module_dict.items(): 33 | table.append( 34 | dict(task='RE', solution=key, description=value().description)) 35 | return table 36 | 37 | @property 38 | def EE(self): 39 | table = list() 40 | for key, value in EE.module_dict.items(): 41 | table.append( 42 | dict(task='EE', solution=key, description=value().description)) 43 | return table 44 | 45 | @property 46 | def TASK(self): 47 | return self.NER + self.RE + self.EE 48 | 49 | @property 50 | def DATASET(self): 51 | table = list() 52 | for key, value in DATASET.module_dict.items(): 53 | table.append(dict(dataset=key, description=value().description)) 54 | return table 55 | 56 | @classmethod 57 | def intercept(cls): 58 | if '-l' in sys.argv or '--list' in sys.argv: 59 | sys.argv.append('--task') 60 | table = Texttable() 61 | table.set_deco(Texttable.HEADER) 62 | exhibition = cls() 63 | for i in range(len(sys.argv)): 64 | if sys.argv[i].startswith('--task') or sys.argv[i].startswith( 65 | '-t'): 66 | if i < len(sys.argv) - 1 and sys.argv[i + 1].upper( 67 | ).replace('/', '').replace('\\', '') == 'NER': 68 
| table.add_rows([[ 69 | 'Task', 'Solution', 'Description' 70 | ], *[[ 71 | item['task'], item['solution'], item['description'] 72 | ] for item in exhibition.NER]]) 73 | elif i < len(sys.argv) - 1 and sys.argv[i + 1].upper( 74 | ).replace('/', '').replace('\\', '') == 'EE': 75 | table.add_rows([[ 76 | 'Task', 'Solution', 'Description' 77 | ], *[[ 78 | item['task'], item['solution'], item['description'] 79 | ] for item in exhibition.EE]]) 80 | elif i < len(sys.argv) - 1 and sys.argv[i + 1].upper( 81 | ).replace('/', '').replace('\\', '') == 'RE': 82 | table.add_rows([[ 83 | 'Task', 'Solution', 'Description' 84 | ], *[[ 85 | item['task'], item['solution'], item['description'] 86 | ] for item in exhibition.RE]]) 87 | else: 88 | table.add_rows([[ 89 | 'Task', 'Solution', 'Description' 90 | ], *[[ 91 | item['task'], item['solution'], item['description'] 92 | ] for item in exhibition.TASK]]) 93 | table.add_row(['', '', '']) 94 | print(table.draw()) 95 | exit(0) 96 | elif sys.argv[i].startswith( 97 | '--dataset') or sys.argv[i].startswith('-d'): 98 | table.add_rows([['Dataset', 'Description'], 99 | *[[item['dataset'], item['description']] 100 | for item in exhibition.DATASET]]) 101 | table.add_row(['', '']) 102 | print(table.draw()) 103 | exit(0) 104 | elif sys.argv[i].startswith( 105 | '--config') or sys.argv[i].startswith('-c'): 106 | table.add_row(['Config', 'Description']) 107 | for root, dirs, files in os.walk( 108 | os.path.join(FASTIE_HOME, 'configs')): 109 | for file in files: 110 | if file.endswith('.py'): 111 | config = Config.fromfile( 112 | os.path.join(root, file)) 113 | description = f'{file}' 114 | if '_help' in config.keys(): 115 | description = f": {config['_help']}" 116 | table.add_row([ 117 | f"{file.replace('.py', '')}", description 118 | ]) 119 | print(table.draw()) 120 | exit(0) 121 | -------------------------------------------------------------------------------- /docs/source/tutorials/basic/fastie_tutorial_1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "source": [ 6 | "# T1. 使用 SDK 工具\n", 7 | "`fastie` 除了支持命令行工具,还支持 `Python` `SDK`,可以方便的在 `Python` 脚本中使用 `fastie`.\n", 8 | "\n", 9 | "`fastie` 的 `SDK` 工具主要包括三个模块:\n", 10 | "\n", 11 | "- `dataset` 模块\n", 12 | "- `task` 模块\n", 13 | "- `controller` 模块\n", 14 | "\n", 15 | "## 1. dataset 模块\n", 16 | "\n", 17 | "`dataset` 即为命令行参数中的 `--dataset`,用于指定数据集. 例如在上一节基础教程 `T0. 使用命令行工具`\n", 18 | "中, 我们使用了 `jsonlines-ner` 作为参数来加载数据集, 那么在 `python` 中, 我们可以使用如下代码来加载数据集:\n", 19 | "\n", 20 | "```python\n", 21 | "from fastie.dataset.io import JsonLinesNER\n", 22 | "\n", 23 | "data_bundle = JsonLinesNER(folder=\"/path/to/dataset\").run()\n", 24 | "```\n", 25 | "\n", 26 | "注意 `fastie` 中的模块都有传入参数和 `run` 两步, 传入参数用于指定模块的参数, `run` 用于执行模块.\n", 27 | "\n", 28 | "`fastie` 中 `dataset` 获得的数据集是 `fastNLP.io.DataBundle` 类型, 即包含多个子集, 例如\n", 29 | "`train`, `dev`, `test`, `infer`, 详细的使用方法可以参考\n", 30 | "[fastNLP 文档](https://fastnlp.readthedocs.io/zh/latest/)\n", 31 | "\n", 32 | "不同的 `controller` 会使用不同的子集, 例如 `train` 控制器会使用 `train` 子集和 `dev` 子集\n", 33 | "(如果要在训练阶段同时进行验证), `eval` 控制器会使用 `test` 子集, `infer` 控制器会使用 `infer` 子集.\n", 34 | "\n", 35 | "## 2. task 模块\n", 36 | "\n", 37 | "`task` 即为命令行参数中的 `--task`,用于指定任务. 例如在上一节基础教程 `T0. 
使用命令行工具`\n", 38 | "中, 我们使用了 `ner/bert` 作为参数来加载任务, 那么在 `python` 中, 我们可以使用如下代码来加载任务:\n", 39 | "\n", 40 | "```python\n", 41 | "from fastie.tasks.ner import BertNER\n", 42 | "\n", 43 | "task = BertNER(batch_size=32, load_model=\"path/to/model\").run(data_bundle)\n", 44 | "```\n", 45 | "\n", 46 | "与 `dataset` 类似, `task` 也有传入参数和 `run` 两步, 传入参数用于指定模块的参数, `run` 用于执行模块,\n", 47 | "不同的是, `task` 的 `run` 方法需要传入 `data_bundle` 参数, 也就是说 `dataset`, `task`,\n", 48 | "`controller` 具有顺序关系, 体现在每一步的 `run` 要传入上一步的结果.\n", 49 | "\n", 50 | "## 3. controller 模块\n", 51 | "\n", 52 | "`controller` 即为命令行模式中的可执行部分, 例如 `fastie-train`,用于指定控制器. 例如在上一节基础教程\n", 53 | "`T0. 使用命令行工具`, 我们使用了 `fastie-train` 作为启动来加载控制器, 那么在 `python` 中,\n", 54 | "我们可以使用如下代码来加载控制器:\n", 55 | "\n", 56 | "```python\n", 57 | "from fastie.controller import Trainer\n", 58 | "\n", 59 | "state_dict = Trainer().run(task)\n", 60 | "```\n", 61 | "\n", 62 | "与 `dataset` 和 `task` 类似, `controller` 也有传入参数和 `run` 两步, 传入参数用于指定模块的参数,\n", 63 | "`run` 用于执行模块, 不同的是, `controller` 的 `run` 方法需要传入 `task` 参数, 也就是说 `dataset`,\n", 64 | "`task`, `controller` 具有顺序关系, 体现在每一步的 `run` 要传入上一步的结果.\n", 65 | "\n", 66 | "`controller` 的 `run` 返回结果与控制器的类型有关, 当使用 `train` 控制器时, `run` 返回的\n", 67 | "结果是 `task` 的 `state_dict`, 包含模型参数; 当使用 `eval` 控制器时, `run` 返回的结果是\n", 68 | "`json` 格式的验证结果; 当使用 `infer` 控制器时, `run` 返回的结果是推理结果.\n", 69 | "\n", 70 | "## 4. 使用 controller 模块加载配置\n", 71 | "\n", 72 | "前文 `T0. 使用命令行工具` 提到了使用 `config` 文件可以方便地对 `task`, `dataset` 以及其他\n", 73 | "参数进行储存, 在 `SDK` 方式中也可以使用 `config` 文件, 例如:\n", 74 | "\n", 75 | "```python\n", 76 | "from fastie.controller import Trainer\n", 77 | "\n", 78 | "state_dict = Trainer.from_config(\"config.py\").run()\n", 79 | "```\n", 80 | "\n", 81 | "如上所示, 使用 `controller` 的类方法 `from_config`, 即可不指定 `controller` 对象 `run`\n", 82 | "方法的参数, 具体的 `task` 和 `dataset` 将会到配置文件中搜索并自动运行.\n", 83 | "\n", 84 | "`controller` 的 `from_config` 除了支持文件类型的 `config` 外, 还可以输入 `dict` 类型的\n", 85 | "`config`, 例如:\n", 86 | "\n", 87 | "```python\n", 88 | "from fastie.controller import Trainer\n", 89 | "\n", 90 | "config = {\n", 91 | " \"task\": \"ner/bert\",\n", 92 | " \"dataset\": \"conll2003\",\n", 93 | " \"batch_size\": 32,\n", 94 | " \"device\": [0, 1]\n", 95 | "}\n", 96 | "state_dict = Trainer.from_config(config).run()\n", 97 | "```\n", 98 | "\n", 99 | "上述的使用与配置文件基本一致, 只是将配置文件的内容写到了代码里.\n", 100 | "\n", 101 | "## 5. 技巧: 重复使用 task 和 dataset\n", 102 | "\n", 103 | "当您使用 `controller` 训练完成后, 可直接使用 `task` 的结果进行验证和推理, 例如:\n", 104 | "\n", 105 | "```python\n", 106 | "from fastie.controller import Trainer, Evaluator, Inference\n", 107 | "from fastie.tasks.ner import BertNER\n", 108 | "from fastie.dataset import Conll2003\n", 109 | "\n", 110 | "data_bundle = Conll2003().run()\n", 111 | "task = BertNER(batch_size=32, load_best_model=True).run(data_bundle)\n", 112 | "\n", 113 | "Trainer().run(task)\n", 114 | "Evaluator().run(task)\n", 115 | "Inference().run(task)\n", 116 | "```\n", 117 | "\n", 118 | "如上所示, 当 `Trainer().run(task)` 执行结束后, `task` 中的 `model` 就已经训练完成,\n", 119 | "因此可以直接使用同一个 `task` 结果进行验证和推理, 当然, 如果要进行验证, 请确保您的 `data_bundle`\n", 120 | "中包含 `test` 集; 如果要进行推理, 请确保您的 `data_bundle` 中包含 `infer` 集.\n",
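121 | "\n", 122 | "此外, `Trainer().run(task)` 返回的 `state_dict` 是一个普通的 `dict`, 如果想把训练结果持久化到磁盘,\n", 123 | "可以参考下面的写法(仅为示意: 假设 `state_dict` 中的内容均可被序列化, 保存路径 `model.pkl` 为假设的文件名):\n", 124 | "\n", 125 | "```python\n", 126 | "import torch\n", 127 | "\n", 128 | "state_dict = Trainer().run(task)\n", 129 | "torch.save(state_dict, \"model.pkl\")\n", 130 | "\n", 131 | "# 之后可以通过 load_model 参数重新加载(与第 2 节中的用法一致)\n", 132 | "task = BertNER(batch_size=32, load_model=\"model.pkl\").run(data_bundle)\n", 133 | "```"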
134 | ], 135 | "metadata": { 136 | "collapsed": false 137 | } 138 | } 139 | ], 140 | "metadata": { 141 | "kernelspec": { 142 | "display_name": "Python 3", 143 | "language": "python", 144 | "name": "python3" 145 | }, 146 | "language_info": { 147 | "codemirror_mode": { 148 | "name": "ipython", 149 | "version": 2 150 | }, 151 | "file_extension": ".py", 152 | "mimetype": "text/x-python", 153 | "name": "python", 154 | "nbconvert_exporter": "python", 155 | "pygments_lexer": "ipython2", 156 | "version": "2.7.6" 157 | } 158 | }, 159 | "nbformat": 4, 160 | "nbformat_minor": 0 161 | } 162 | -------------------------------------------------------------------------------- /fastie/dataset/io/jsonlinesNER.py: -------------------------------------------------------------------------------- 1 | """JsonLinesNER dataset for FastIE.""" 2 | __all__ = ['JsonLinesNER', 'JsonLinesNERConfig'] 3 | 4 | import json 5 | import os 6 | from dataclasses import dataclass, field 7 | 8 | from fastNLP import DataSet, Instance 9 | from fastNLP.io import Loader, DataBundle 10 | 11 | from fastie.dataset.base_dataset import BaseDataset, DATASET, BaseDatasetConfig 12 | 13 | 14 | @dataclass 15 | class JsonLinesNERConfig(BaseDatasetConfig): 16 | """JsonLinesNER 数据集配置类.""" 17 | folder: str = field( 18 | default='', 19 | metadata=dict(help='The folder where the data set resides. ' 20 | 'We will automatically read the possible train.jsonl, ' 21 | 'dev.jsonl, test.jsonl and infer.jsonl in it. ', 22 | existence=True)) 23 | right_inclusive: bool = field( 24 | default=True, 25 | metadata=dict( 26 | help='When data is in the format of start and end, ' 27 | 'whether each span contains the token corresponding to end. ', 28 | existence=True)) 29 | 30 | 31 | @DATASET.register_module('jsonlines-ner') 32 | class JsonLinesNER(BaseDataset): 33 | """JsonLinesNER dataset for FastIE. Each row has a NER sample in json 34 | format: 35 | 36 | .. code-block:: json 37 | { 38 | "tokens": ["I", "love", "FastIE", "."], 39 | "entity_mentions": [ 40 | { 41 | "entity_index": [2], 42 | "entity_type": "MISC" 43 | } 44 | ] } 45 | 46 | or: 47 | 48 | .. code-block:: json 49 | { 50 | "tokens": ["I", "love", "FastIE", "."], 51 | "entity_mentions": [ 52 | { 53 | "start": 2, 54 | "end": 3, 55 | "entity_type": "MISC" 56 | } 57 | ] } 58 | 59 | :param folder: The folder where the data set resides. 60 | :param right_inclusive: When data is in the format of start and end, 61 | whether each span contains the token corresponding to end. 62 | :param cache: Whether to cache the dataset. 63 | :param refresh_cache: Whether to refresh the cache. 64 | """ 65 | _config = JsonLinesNERConfig() 66 | _help = 'JsonLinesNER dataset for FastIE. Each row has a NER sample in json format. 
' 67 | 68 | def __init__(self, 69 | folder: str = '', 70 | right_inclusive: bool = False, 71 | cache: bool = False, 72 | refresh_cache: bool = False, 73 | **kwargs): 74 | BaseDataset.__init__(self, 75 | cache=cache, 76 | refresh_cache=refresh_cache, 77 | **kwargs) 78 | self.folder = folder 79 | self.right_inclusive = right_inclusive 80 | 81 | def run(self) -> DataBundle: 82 | node = self 83 | 84 | class JsonNERLoader(Loader): 85 | 86 | def _load(self, path: str) -> DataSet: 87 | dataset = DataSet() 88 | with open(path, 'r', encoding='utf-8') as file: 89 | for line in file.readlines(): 90 | line = line.strip() 91 | if line: 92 | sample: dict = json.loads(line) 93 | instance = Instance() 94 | instance.add_field('tokens', sample['tokens']) 95 | if 'entity_mentions' in sample.keys(): 96 | entity_mentions = [] 97 | for entity_mention in sample[ 98 | 'entity_mentions']: 99 | if 'entity_index' in entity_mention.keys(): 100 | entity_mentions.append( 101 | (entity_mention['entity_index'], 102 | entity_mention['entity_type'])) 103 | elif 'start' in entity_mention.keys( 104 | ) and 'end' in entity_mention.keys(): 105 | if node.right_inclusive: 106 | entity_mentions.append((list( 107 | range( 108 | entity_mention['start'], 109 | entity_mention['end'] + 1) 110 | ), entity_mention['entity_type'])) 111 | else: 112 | entity_mentions.append((list( 113 | range(entity_mention['start'], 114 | entity_mention['end']) 115 | ), entity_mention['entity_type'])) 116 | instance.add_field('entity_mentions', 117 | entity_mentions) 118 | dataset.append(instance) 119 | return dataset 120 | 121 | data_bundle = JsonNERLoader().load({ 122 | file: os.path.join(self.folder, f'{file}.jsonl') 123 | for file in ('train', 'dev', 'test', 'infer') 124 | if os.path.exists(os.path.join(self.folder, f'{file}.jsonl')) 125 | }) 126 | return data_bundle 127 | -------------------------------------------------------------------------------- /fastie/dataset/io/columnNER.py: -------------------------------------------------------------------------------- 1 | """Conll2003 like dataset for FastIE.""" 2 | 3 | __all__ = ['ColumnNER', 'ColumnNERConfig'] 4 | 5 | import os 6 | from dataclasses import dataclass, field 7 | from functools import reduce 8 | from typing import Union, Sequence, List 9 | 10 | from fastNLP import DataSet, Instance 11 | from fastNLP.io import Loader, DataBundle 12 | 13 | from fastie.dataset.base_dataset import DATASET, BaseDatasetConfig, BaseDataset 14 | 15 | 16 | @dataclass 17 | class ColumnNERConfig(BaseDatasetConfig): 18 | """ColumnNER 数据集配置类.""" 19 | folder: str = field( 20 | default='', 21 | metadata=dict(help='The folder where the data set resides. \n' 22 | 'We will automatically read the possible train.txt, \n' 23 | 'dev.txt, test.txt and infer.txt in it. ', 24 | existence=True)) 25 | token_index: int = field(default=0, 26 | metadata=dict(help='The index of tokens.', 27 | existence=True)) 28 | tag_index: int = field(default=-1, 29 | metadata=dict(help='The index of tags to predict.', 30 | existence=['train', 'eval'])) 31 | split_char: str = field( 32 | default=' ', 33 | metadata=dict(help='The split char. If this parameter is not set, ' 34 | 'it is separated by space. ', 35 | existence=True)) 36 | skip_content: str = field( 37 | default=' ', 38 | metadata=dict(help='The content to skip. If this item is not set, ' 39 | 'it is divided by newline character. ', 40 | existence=True)) 41 | 42 | 43 | @DATASET.register_module('column-ner') 44 | class ColumnNER(BaseDataset): 45 | """Conll2003 like dataset for FastIE. 
Each row has a token and its 46 | corresponding NER tag. 47 | 48 | :param folder: The folder where the data set resides. ``train.txt``, 49 | ``dev.txt``, ``test.txt`` and ``infer.txt`` in the folder will be loaded. 50 | :param token_index: The index of tokens in a row. 51 | :param tag_index: The index of tags to predict in a row. 52 | :param split_char: The split character. If this parameter is not set, it is 53 | separated by space. 54 | :param skip_content: The content to skip. If this item is not set, it is 55 | divided by newline character. 56 | :param cache: Whether to cache the dataset. 57 | :param refresh_cache: Whether to refresh the cache. 58 | """ 59 | _config = ColumnNERConfig() 60 | _help = 'Conll2003 like dataset for FastIE. Each row has a token and its corresponding NER tag.' 61 | 62 | def __init__(self, 63 | folder: str = './', 64 | token_index: int = 0, 65 | tag_index: int = -1, 66 | split_char: str = ' ', 67 | skip_content: Union[str, Sequence[str]] = '\n', 68 | cache: bool = False, 69 | refresh_cache: bool = False, 70 | **kwargs): 71 | super(ColumnNER, self).__init__(cache=cache, 72 | refresh_cache=refresh_cache, 73 | **kwargs) 74 | self.folder = folder 75 | self.token_index = token_index 76 | self.tag_index = tag_index 77 | self.split_char = split_char 78 | self.skip_content: Sequence = [skip_content] \ 79 | if isinstance(skip_content, str) else skip_content 80 | 81 | def run(self) -> DataBundle: 82 | node = self 83 | 84 | class ColumnNERLoader(Loader): 85 | 86 | def _load(self, path: str) -> DataSet: 87 | ds = DataSet() 88 | data: List[dict] = [] 89 | with open(path, 'r', encoding='utf-8') as f: 90 | for line in f.readlines(): 91 | if reduce(lambda x, y: x or y, [ 92 | line.startswith(content) 93 | for content in node.skip_content 94 | ]) or line.strip() == '': 95 | if len(data) != 0: 96 | if 'infer' in path: 97 | ds.append( 98 | Instance( 99 | tokens=[d['token'] for d in data])) 100 | else: 101 | ds.append( 102 | Instance( 103 | tokens=[d['token'] for d in data], 104 | tags=[d['tag'] for d in data])) 105 | data = [] 106 | continue 107 | lines = line.strip().split(node.split_char) 108 | if 'infer' in path: 109 | data.append({'token': lines[node.token_index]}) 110 | else: 111 | data.append({ 112 | 'token': lines[node.token_index], 113 | 'tag': lines[node.tag_index] 114 | }) 115 | if len(data) != 0: 116 | if 'infer' in path: 117 | ds.append(Instance(tokens=[d['token'] for d in data])) 118 | else: 119 | ds.append(Instance(tokens=[d['token'] for d in data], tags=[d['tag'] for d in data]))
120 | return ds 121 | 122 | data_bundle = ColumnNERLoader().load({ 123 | file: os.path.join(self.folder, f'{file}.txt') 124 | for file in ('train', 'dev', 'test', 'infer') 125 | if os.path.exists(os.path.join(self.folder, f'{file}.txt')) 126 | }) 127 | return data_bundle 128 | -------------------------------------------------------------------------------- /fastie/dataset/legacy/conll2003.py: -------------------------------------------------------------------------------- 1 | """The shared task of CoNLL-2003 concerns language-independent named entity 2 | recognition.""" 3 | __all__ = ['Conll2003', 'Conll2003Config'] 4 | 5 | from dataclasses import dataclass 6 | 7 | import numpy as np 8 | from datasets import load_dataset 9 | from fastNLP import DataSet, Instance 10 | from fastNLP.io import DataBundle 11 | 12 | from fastie.dataset.base_dataset import BaseDataset, DATASET, BaseDatasetConfig 13 | 14 | 15 | @dataclass 16 | class Conll2003Config(BaseDatasetConfig): 17 | """Conll2003 数据集的配置类.""" 18 | pass 19 | 20 | 21 | @DATASET.register_module('conll2003') 22 | class Conll2003(BaseDataset): 23 | """The shared task of CoNLL-2003 concerns language-independent named entity 24 | recognition.""" 25 | _config = Conll2003Config() 26 | _help = 'Conll2003 for NER task. Refer to ' \ 27 | 'https://huggingface.co/datasets/conll2003 for more information.' 28 | 29 | def __init__(self, 30 | cache: bool = False, 31 | refresh_cache: bool = False, 32 | **kwargs): 33 | super(Conll2003, self).__init__(cache=cache, 34 | refresh_cache=refresh_cache, 35 | **kwargs) 36 | 37 | def run(self): 38 | raw_dataset = load_dataset('conll2003') 39 | datasets = {} 40 | tag2idx = { 41 | 'ner': { 42 | 'O': 0, 43 | 'B-PER': 1, 44 | 'I-PER': 2, 45 | 'B-ORG': 3, 46 | 'I-ORG': 4, 47 | 'B-LOC': 5, 48 | 'I-LOC': 6, 49 | 'B-MISC': 7, 50 | 'I-MISC': 8 51 | }, 52 | 'pos': { 53 | '"': 0, 54 | "''": 1, 55 | '#': 2, 56 | '$': 3, 57 | '(': 4, 58 | ')': 5, 59 | ',': 6, 60 | '.': 7, 61 | ':': 8, 62 | '``': 9, 63 | 'CC': 10, 64 | 'CD': 11, 65 | 'DT': 12, 66 | 'EX': 13, 67 | 'FW': 14, 68 | 'IN': 15, 69 | 'JJ': 16, 70 | 'JJR': 17, 71 | 'JJS': 18, 72 | 'LS': 19, 73 | 'MD': 20, 74 | 'NN': 21, 75 | 'NNP': 22, 76 | 'NNPS': 23, 77 | 'NNS': 24, 78 | 'NN|SYM': 25, 79 | 'PDT': 26, 80 | 'POS': 27, 81 | 'PRP': 28, 82 | 'PRP$': 29, 83 | 'RB': 30, 84 | 'RBR': 31, 85 | 'RBS': 32, 86 | 'RP': 33, 87 | 'SYM': 34, 88 | 'TO': 35, 89 | 'UH': 36, 90 | 'VB': 37, 91 | 'VBD': 38, 92 | 'VBG': 39, 93 | 'VBN': 40, 94 | 'VBP': 41, 95 | 'VBZ': 42, 96 | 'WDT': 43, 97 | 'WP': 44, 98 | 'WP$': 45, 99 | 'WRB': 46 100 | }, 101 | 'chunk': { 102 | 'O': 0, 103 | 'B-ADJP': 1, 104 | 'I-ADJP': 2, 105 | 'B-ADVP': 3, 106 | 'I-ADVP': 4, 107 | 'B-CONJP': 5, 108 | 'I-CONJP': 6, 109 | 'B-INTJ': 7, 110 | 'I-INTJ': 8, 111 | 'B-LST': 9, 112 | 'I-LST': 10, 113 | 'B-NP': 11, 114 | 'I-NP': 12, 115 | 'B-PP': 13, 116 | 'I-PP': 14, 117 | 'B-PRT': 15, 118 | 'I-PRT': 16, 119 | 'B-SBAR': 17, 120 | 'I-SBAR': 18, 121 | 'B-UCP': 19, 122 | 'I-UCP': 20, 123 | 'B-VP': 21, 124 | 'I-VP': 22 125 | } 126 | } 127 | idx2tag = {'ner': {}, 'pos': {}, 'chunk': {}} 128 | idx2tag['ner'] = {v: k for k, v in tag2idx['ner'].items()} 129 | idx2tag['pos'] = {v: k for k, v in tag2idx['pos'].items()} 130 | idx2tag['chunk'] = {v: k for k, v in tag2idx['chunk'].items()} 131 | for split, dataset in raw_dataset.items(): 132 | split = split.replace('validation', 'dev') 133 | datasets[split] = DataSet() 134 | for sample in dataset: 135 | instance = Instance() 136 | instance.add_field('tokens', sample['tokens']) 137 
| entity_mentions = [] 138 | span = [] 139 | current_tag = 0 140 | for i in np.arange(len(sample['ner_tags'])): 141 | if sample['ner_tags'][i] != 0: 142 | if len(span) == 0: 143 | current_tag = sample['ner_tags'][i] 144 | span.append(i) 145 | continue 146 | else: 147 | if current_tag == sample['ner_tags'][i] \ 148 | or current_tag + 1 == sample['ner_tags'][i]: 149 | span.append(i) 150 | continue 151 | else: 152 | entity_mentions.append( 153 | (span, idx2tag['ner'][current_tag][2:])) 154 | span = [i] 155 | current_tag = sample['ner_tags'][i] 156 | continue 157 | else: 158 | if len(span) > 0: 159 | entity_mentions.append((span, idx2tag['ner'][ 160 | sample['ner_tags'][span[0]]][2:])) 161 | span = [] 162 | if len(span) > 0: 163 | entity_mentions.append( 164 | (span, 165 | idx2tag['ner'][sample['ner_tags'][span[0]]][2:])) 166 | instance.add_field('entity_mentions', entity_mentions) 167 | instance.add_field('pos_tags', sample['pos_tags']) 168 | instance.add_field('chunk_tags', sample['chunk_tags']) 169 | instance.add_field('ner_tags', sample['ner_tags']) 170 | datasets[split].append(instance) 171 | data_bundle = DataBundle(datasets=datasets) 172 | return data_bundle 173 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # 2 | # Configuration file for the Sphinx documentation builder. 3 | # 4 | # This file does only contain a selection of the most common options. For a 5 | # full list see the documentation: 6 | # http://www.sphinx-doc.org/en/master/config 7 | 8 | # -- Path setup -------------------------------------------------------------- 9 | 10 | # If extensions (or modules to document with autodoc) are in another directory, 11 | # add these directories to sys.path here. If the directory is relative to the 12 | # documentation root, use os.path.abspath to make it absolute, like shown here. 13 | # 14 | import os 15 | import sys 16 | 17 | sys.path.insert(0, os.path.abspath('../../')) 18 | 19 | import pytorch_sphinx_theme 20 | 21 | # -- Project information ----------------------------------------------------- 22 | 23 | project = 'FastIE' 24 | copyright = '2023, FastIE' 25 | author = 'FastIE' 26 | 27 | # The short X.Y version 28 | version = '0.0.1' 29 | # The full version, including alpha/beta/rc tags 30 | release = '0.0.1-beta' 31 | 32 | # -- General configuration --------------------------------------------------- 33 | 34 | # If your documentation needs a minimal Sphinx version, state it here. 35 | # 36 | # needs_sphinx = '1.0' 37 | 38 | # Add any Sphinx extension module names here, as strings. They can be 39 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 40 | # ones. 41 | extensions = [ 42 | 'sphinx.ext.autodoc', 43 | 'sphinx.ext.viewcode', 44 | 'sphinx.ext.autosummary', 45 | 'sphinx.ext.mathjax', 46 | 'sphinx.ext.todo', 47 | 'sphinx_autodoc_typehints', 48 | 'sphinx_multiversion', 49 | 'nbsphinx', 50 | 'sphinx_copybutton', 51 | ] 52 | 53 | autodoc_default_options = { 54 | 'member-order': 'bysource', 55 | 'special-members': '__init__', 56 | 'undoc-members': False, 57 | } 58 | 59 | add_module_names = False 60 | autosummary_ignore_module_all = False 61 | # autodoc_typehints = "description" 62 | autoclass_content = 'class' 63 | typehints_fully_qualified = False 64 | typehints_defaults = 'comma' 65 | 66 | nbsphinx_allow_errors = True 67 | 68 | # Add any paths that contain templates here, relative to this directory. 
69 | templates_path = ['_templates'] 70 | # template_bridge 71 | # The suffix(es) of source filenames. 72 | # You can specify multiple suffix as a list of string: 73 | # 74 | # source_suffix = ['.rst', '.md'] 75 | source_suffix = '.rst' 76 | 77 | # The master toctree document. 78 | master_doc = 'index' 79 | 80 | # The language for content autogenerated by Sphinx. Refer to documentation 81 | # for a list of supported languages. 82 | # 83 | # This is also used if you do content translation via gettext catalogs. 84 | # Usually you set "language" from the command line for these cases. 85 | language = 'zh_CN' 86 | 87 | # List of patterns, relative to source directory, that match files and 88 | # directories to ignore when looking for source files. 89 | # This pattern also affects html_static_path and html_extra_path . 90 | exclude_patterns = ['modules.rst'] 91 | 92 | # The name of the Pygments (syntax highlighting) style to use. 93 | # pygments_style = 'sphinx' 94 | 95 | # -- Options for HTML output ------------------------------------------------- 96 | 97 | # The theme to use for HTML and HTML Help pages. See the documentation for 98 | # a list of builtin themes. 99 | # 100 | # html_theme = 'sphinx_rtd_theme' 101 | html_theme = 'pytorch_sphinx_theme' 102 | html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()] 103 | 104 | # Theme options are theme-specific and customize the look and feel of a theme 105 | # further. For a list of options available for each theme, see the 106 | # documentation. 107 | # 108 | html_theme_options = { 109 | 'logo_url': 110 | 'https://github.com/open-nlplab/fastIE', 111 | 'menu': [ 112 | # A link 113 | { 114 | 'name': 'GitHub', 115 | 'url': 'https://github.com/open-nlplab/fastIE' 116 | }, 117 | # A dropdown menu 118 | # { 119 | # 'name': 'Projects', 120 | # 'children': [ 121 | # # A vanilla dropdown item 122 | # { 123 | # 'name': 'MMCV', 124 | # 'url': 'https://github.com/open-mmlab/mmcv', 125 | # }, 126 | # # A dropdown item with a description 127 | # { 128 | # 'name': 'MMDetection', 129 | # 'url': 'https://github.com/open-mmlab/mmdetection', 130 | # 'description': 'Object detection toolbox and benchmark' 131 | # }, 132 | # ], 133 | # # Optional, determining whether this dropdown menu will always be 134 | # # highlighted. 135 | # 'active': True, 136 | # }, 137 | ], 138 | # Specify the language of shared menu 139 | 'menu_lang': 140 | 'cn', 141 | } 142 | 143 | # Add any paths that contain custom static files (such as style sheets) here, 144 | # relative to this directory. They are copied after the builtin static files, 145 | # so a file named "default.css" will overwrite the builtin "default.css". 146 | html_static_path = ['_static'] 147 | html_css_files = ['css/readthedocs.css', 'css/badge_only.css'] 148 | 149 | # Custom sidebar templates, must be a dictionary that maps document names 150 | # to template names. 151 | # 152 | # The default sidebars (for documents that don't match any pattern) are 153 | # defined by theme itself. Builtin themes are using these templates by 154 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 155 | # 'searchbox.html']``. 156 | # 157 | # html_sidebars = {} 158 | # html_sidebars = { 159 | # '**': [ 160 | # 'versions.html', 161 | # ], 162 | # } 163 | 164 | # -- Options for HTMLHelp output --------------------------------------------- 165 | 166 | # Output file base name for HTML help builder. 
167 | htmlhelp_basename = 'FastIE' 168 | 169 | # -- Options for LaTeX output ------------------------------------------------ 170 | 171 | latex_elements = { 172 | # The paper size ('letterpaper' or 'a4paper'). 173 | # 174 | # 'papersize': 'letterpaper', 175 | 176 | # The font size ('10pt', '11pt' or '12pt'). 177 | # 178 | # 'pointsize': '10pt', 179 | 180 | # Additional stuff for the LaTeX preamble. 181 | # 182 | # 'preamble': '', 183 | 184 | # Latex figure (float) alignment 185 | # 186 | # 'figure_align': 'htbp', 187 | } 188 | 189 | # Grouping the document tree into LaTeX files. List of tuples 190 | # (source start file, target name, title, 191 | # author, documentclass [howto, manual, or own class]). 192 | latex_documents = [] 193 | 194 | # -- Options for manual page output ------------------------------------------ 195 | 196 | # One entry per manual page. List of tuples 197 | # (source start file, name, description, authors, manual section). 198 | man_pages = [(master_doc, 'FastIE', 'FastIE Documentation', [author], 1)] 199 | 200 | # -- Options for Texinfo output ---------------------------------------------- 201 | 202 | # Grouping the document tree into Texinfo files. List of tuples 203 | # (source start file, target name, title, author, 204 | # dir menu entry, description, category) 205 | texinfo_documents = [ 206 | (master_doc, 'FastIE', 'FastIE Documentation', author, 'FastIE', 207 | 'A general integration framework for information extraction.', 208 | 'Miscellaneous'), 209 | ] 210 | 211 | # Ignore >>> when copying code 212 | copybutton_prompt_text = r'>>> |\.\.\. ' 213 | copybutton_prompt_is_regexp = True 214 | 215 | 216 | # -- Extension configuration ------------------------------------------------- 217 | def maybe_skip_member(app, what, name, obj, skip, options): 218 | if obj.__doc__ is None: 219 | return True 220 | if name == '__init__': 221 | return False 222 | if name.startswith('_'): 223 | return True 224 | return skip 225 | 226 | 227 | def setup(app): 228 | app.connect('autodoc-skip-member', maybe_skip_member) 229 | -------------------------------------------------------------------------------- /docs/source/tutorials/basic/fastie_tutorial_2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "source": [ 6 | "# T2. 开发自己的 task\n", 7 | "\n", 8 | "`fastie` 中集成了一些常用的 `task`,但是并不是所有的任务都能满足你的需求,因此你可以自己开发自己的 `task`.\n", 9 | "开发自己的 `task` 可以选择继承已有任务进行修改, 或者继承 `BaseTask` 类全新开发.\n", 10 | "\n", 11 | "## 1. task 的生命周期\n", 12 | "\n", 13 | "`fastie` 中, 每个 `task` 的 `run` 方法都遵守一套固定的流程(我们将其称为 `fastie` 的生命周期),\n", 14 | "其中流程的每个阶段都有固定的任务目标. 换而言之, `task` 的每个生命周期方法都有固定的参数输入和期望的\n", 15 | "返回值. 例如, `task` 的 `on_setup_model` 方法的任务目标是模型的搭建, 因此该方方法期待返回一个拥有\n", 16 | "`train_step`, `evaluation_step` 和 `infer_step` 方法的 `model` 对象. 因此, 在实现自己\n", 17 | "的 `task` 的过程中, 你需要遵守 `fastie` 的生命周期, 并且在每个生命周期方法中返回期望的对象.\n", 18 | "\n", 19 | "`fastie` 的生命周期如下:\n", 20 | "\n", 21 | "" 22 | ], 23 | "metadata": { 24 | "collapsed": false 25 | } 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "source": [ 30 | "### 1.1 on_generate_and_check_tag_vocab\n", 31 | "\n", 32 | "`on_generate_and_check_tag_vocab` 方法的功能为, 根据输入原始数据集 `data_bundle` 生成\n", 33 | "标签词典 `tag_vocab`, 并将生成的词典和可能存在的从模型加载到的词典对比检查.\n", 34 | "\n", 35 | "该方法的输入为原始数据集 `data_bundle` 和 `checkpoint` 信息 `state_dict`. 
该方法应返回一个\n", 36 | "`fastNLP.Vocabulary` 的变量或数组 (存在多个标签值的情况).\n", 37 | "\n", 38 | "如上流程图所示, 该方法为非必要重写的方法, `task` 的基类 `BaseTask` 已经实现了该方法的基本逻辑.\n", 39 | "在有特殊需求的情况下可以重新定义该方法.\n", 40 | "\n", 41 | "### 1.2 on_dataset_preprocess\n", 42 | "\n", 43 | "`on_dataset_preprocess` 方法为将原始数据集 `data_bundle` 进行预处理的方法, 包括把原始的\n", 44 | "`tag` 通过上一步生成的 `tag_vocab` 转换为可以计算的 `id`, 以及将 `token` 转换为 `id` 等.\n", 45 | "\n", 46 | "该方法的输入为原始数据集 `data_bundle`, 上一步产生的一个或多个 `tag_vocab`, 以及 `checkpoint`\n", 47 | "信息 `state_dict`. 该方法应返回处理过后的 `data_bundle`, 数据类型为 `fastNLP.io.DataBundle`.\n", 48 | "\n", 49 | "如上流程图所示, 该方法为必须重写的方法, `task` 的基类 `BaseTask` 未实现该方法. 直接调用会导致异常.\n", 50 | "\n", 51 | "### 1.3 on_setup_model\n", 52 | "\n", 53 | "`on_setup_model` 方法为模型的实现方法, 包括模型的创建和初始化.\n", 54 | "\n", 55 | "该方法的输入为上一步处理过后的 `data_bundle`, `on_generate_and_check_tag_vocab` 的输出\n", 56 | "`tag_vocab`, 以及 `checkpoint` 信息 `state_dict`. 该方法的输出为一个拥有 `train_step`,\n", 57 | "`evaluation_step` 和 `infer_step` 方法的 `model` 对象.\n", 58 | "\n", 59 | "如上流程图所示, 该方法为必须重写的方法, `task` 的基类 `BaseTask` 未实现该方法. 直接调用会导致异常.\n", 60 | "\n", 61 | "### 1.4 on_setup_optimizers\n", 62 | "\n", 63 | "`on_setup_optimizers` 方法为优化器的实现方法, 包括优化器的创建和初始化.\n", 64 | "\n", 65 | "该方法的输入为 `on_dataset_preprocess` 的输出 `data_bundle`, `on_setup_model` 的输出 `model`,\n", 66 | "`on_generate_and_check_tag_vocab` 的输出 `tag_vocab`, 以及 `checkpoint` 信息 `state_dict`.\n", 67 | "该方法的输出为训练所需的优化器, 可以是单独的一个优化器实例,也可以是多个优化器组成的 `List`.\n", 68 | "\n", 69 | "如上流程图所示, 该方法为必须重写的方法, `task` 的基类 `BaseTask` 未实现该方法. 直接调用会导致异常.\n", 70 | "\n", 71 | "### 1.5 on_setup_dataloader\n", 72 | "\n", 73 | "`on_setup_dataloader` 方法为数据加载器的实现方法, 包括数据加载器的创建和初始化.\n", 74 | "\n", 75 | "该方法的输入为 `on_dataset_preprocess` 的输出 `data_bundle`, `on_setup_model` 的输出 `model`,\n", 76 | "`on_generate_and_check_tag_vocab` 的输出 `tag_vocab`, 以及 `checkpoint` 信息 `state_dict`.\n", 77 | "\n", 78 | "该方法的输出需要根据当前的控制器判断, 当前的控制器可以通过 `fastie.envs.get_flag` 方法获得,\n", 79 | "可能的取值包括: `train`, `eval`, `infer`. 需要根据当前的控制器来判断取用 `data_bundle` 的哪个 `split`.\n", 80 | "该方法的输出可以是一个可迭代的 `dataloader` 组成的 `dict`, 其中当 `flag` 为 `train` 时, `key` 为 `train`\n", 81 | "将被用于训练集, 其他的会被用于验证集; 当 `flag` 为 `eval`, `infer` 时, 所有的 `dataloader` 会被用作测试集\n", 82 | "和推理集.\n", 83 | "\n", 84 | "如上流程图所示, 该方法为非必要重写的方法, `task` 的基类 `BaseTask` 已经实现了该方法的基本逻辑.\n", 85 | "在有特殊需求的情况下可以重新定义该方法.\n", 86 | "\n", 87 | "### 1.6 on_setup_callbacks\n", 88 | "\n", 89 | "`on_setup_callbacks` 创建训练过程中的回调项.\n", 90 | "\n", 91 | "该方法的输入为 `on_dataset_preprocess` 的输出 `data_bundle`, `on_setup_model` 的输出 `model`,\n", 92 | "`on_generate_and_check_tag_vocab` 的输出 `tag_vocab`, 以及 `checkpoint` 信息 `state_dict`.\n", 93 | "\n", 94 | "该方法的输出为一个 `fastNLP.Callback` 对象或者 `fastNLP.Callback` 的列表. 具体可参照\n", 95 | "[fastNLP.callbacks](http://www.fastnlp.top/docs/fastNLP/master/api/core.html#callbacks).\n", 96 | "\n", 97 | "如上流程图所示, 该方法为非必要重写的方法, `task` 的基类 `BaseTask` 默认不实现任何回调.\n", 98 | "\n", 99 | "### 1.7 on_setup_metrics\n", 100 | "\n", 101 | "`on_setup_metrics` 创建验证过程中的评估指标.\n", 102 | "\n", 103 | "该方法的输入为 `on_dataset_preprocess` 的输出 `data_bundle`, `on_setup_model` 的输出 `model`,\n", 104 | "`on_generate_and_check_tag_vocab` 的输出 `tag_vocab`, 以及 `checkpoint` 信息 `state_dict`.\n", 105 | "\n", 106 | "该方法的返回结果应该为一个字典, 例如: ``{\"acc1\": Accuracy(), \"acc2\": Accuracy()}``.\n", 107 | "\n", 108 | "目前我们支持的 ``metric`` 的种类有以下几种:\n", 109 | "\n", 110 | "1. fastNLP 的 ``metric``: 详见 [fastNLP.metrics](http://www.fastnlp.top/docs/fastNLP/master/api/core.html#metrics);\n", 111 | "2. torchmetrics;\n", 112 | "3. allennlp.training.metrics;\n",
112 |     "\n",
113 |     "### 1.4 on_setup_optimizers\n",
114 |     "\n",
115 |     "`on_setup_optimizers` builds the optimizers, including their creation and initialization.\n",
116 |     "\n",
117 |     "Its inputs are the `data_bundle` output by `on_dataset_preprocess`, the `model` output by\n",
118 |     "`on_setup_model`, the `tag_vocab` output by `on_generate_and_check_tag_vocab`, and the `checkpoint`\n",
119 |     "information `state_dict`. Its output is whatever optimizers training needs: either a single\n",
120 |     "optimizer instance or a `List` of several optimizers.\n",
121 |     "\n",
122 |     "As the flow chart above shows, this method must be overridden: the `task` base class `BaseTask` does\n",
123 |     "not implement it, and calling it directly raises an exception.\n",
124 |     "\n",
125 |     "### 1.5 on_setup_dataloader\n",
126 |     "\n",
127 |     "`on_setup_dataloader` builds the dataloaders, including their creation and initialization.\n",
128 |     "\n",
129 |     "Its inputs are the `data_bundle` output by `on_dataset_preprocess`, the `model` output by\n",
130 |     "`on_setup_model`, the `tag_vocab` output by `on_generate_and_check_tag_vocab`, and the `checkpoint`\n",
131 |     "information `state_dict`.\n",
132 |     "\n",
133 |     "Its output depends on the current controller, which you can query with `fastie.envs.get_flag`;\n",
134 |     "possible values are `train`, `eval` and `infer`. Use the current controller to decide which `split`\n",
135 |     "of the `data_bundle` to take. The output may be a `dict` of iterable `dataloader`s: when the `flag`\n",
136 |     "is `train`, the entry whose `key` is `train` is used as the training set and the others as\n",
137 |     "validation sets; when the `flag` is `eval` or `infer`, all `dataloader`s are used as test and\n",
138 |     "inference sets.\n",
139 |     "\n",
140 |     "As the flow chart above shows, overriding this method is optional: the `task` base class `BaseTask`\n",
141 |     "already implements its basic logic. Redefine it only when you have special needs.\n",
142 |     "\n",
143 |     "### 1.6 on_setup_callbacks\n",
144 |     "\n",
145 |     "`on_setup_callbacks` creates the callbacks used during training.\n",
146 |     "\n",
147 |     "Its inputs are the `data_bundle` output by `on_dataset_preprocess`, the `model` output by\n",
148 |     "`on_setup_model`, the `tag_vocab` output by `on_generate_and_check_tag_vocab`, and the `checkpoint`\n",
149 |     "information `state_dict`.\n",
150 |     "\n",
151 |     "Its output is a `fastNLP.Callback` object or a list of them; see\n",
152 |     "[fastNLP.callbacks](http://www.fastnlp.top/docs/fastNLP/master/api/core.html#callbacks).\n",
153 |     "\n",
154 |     "As the flow chart above shows, overriding this method is optional: by default the `task` base class\n",
155 |     "`BaseTask` registers no callbacks.\n",
156 |     "\n",
157 |     "### 1.7 on_setup_metrics\n",
158 |     "\n",
159 |     "`on_setup_metrics` creates the metrics used during validation.\n",
160 |     "\n",
161 |     "Its inputs are the `data_bundle` output by `on_dataset_preprocess`, the `model` output by\n",
162 |     "`on_setup_model`, the `tag_vocab` output by `on_generate_and_check_tag_vocab`, and the `checkpoint`\n",
163 |     "information `state_dict`.\n",
164 |     "\n",
165 |     "Its return value should be a dict, for example: ``{\"acc1\": Accuracy(), \"acc2\": Accuracy()}``.\n",
166 |     "\n",
167 |     "The kinds of ``metric`` we currently support are:\n",
168 |     "\n",
169 |     "1. fastNLP ``metric``s: see [fastNLP.metrics](http://www.fastnlp.top/docs/fastNLP/master/api/core.html#metrics);\n",
170 |     "2. torchmetrics;\n",
171 |     "3. allennlp.training.metrics;\n",
172 |     "4. paddle.metric.\n",
173 |     "\n",
174 |     "As the flow chart above shows, overriding this method is optional: by default the `task` base class\n",
175 |     "`BaseTask` registers no `metric`. Note, however, that `fastie` features that require a `metric`,\n",
176 |     "such as `load_best_model` or `topk`, only work if you override this method (a short example cell\n",
177 |     "follows this overview).\n",
178 |     "\n",
179 |     "### 1.8 on_setup_extra_fastnlp_parameters\n",
180 |     "\n",
181 |     "`on_setup_extra_fastnlp_parameters` sets some extra parameters for `fastNLP`.\n",
182 |     "\n",
183 |     "Its inputs are the `data_bundle` output by `on_dataset_preprocess`, the `model` output by\n",
184 |     "`on_setup_model`, the `tag_vocab` output by `on_generate_and_check_tag_vocab`, and the `checkpoint`\n",
185 |     "information `state_dict`.\n",
186 |     "\n",
187 |     "Its return value should be a dict whose entries correspond to parameters of `fastNLP.Trainer` or\n",
188 |     "`fastNLP.Evaluator`; see\n",
189 |     "[fastNLP.Trainer](http://www.fastnlp.top/docs/fastNLP/master/api/generated/fastNLP.core.Trainer.html#fastNLP.core.Trainer) and\n",
190 |     "[fastNLP.Evaluator](http://www.fastnlp.top/docs/fastNLP/master/api/generated/fastNLP.core.Evaluator.html#fastNLP.core.Evaluator).\n",
191 |     "\n",
192 |     "As the flow chart above shows, overriding this method is optional: by default the `task` base class\n",
193 |     "`BaseTask` adds no extra parameters.\n",
194 |     "\n",
195 |     "### 1.9 on_get_state_dict\n",
196 |     "\n",
197 |     "`on_get_state_dict` obtains the `checkpoint` variable `state_dict`.\n",
198 |     "\n",
199 |     "Its inputs are the `data_bundle` output by `on_dataset_preprocess`, the `model` output by\n",
200 |     "`on_setup_model`, and the `tag_vocab` output by `on_generate_and_check_tag_vocab`.\n",
201 |     "\n",
202 |     "Its output should have the same format as the `state_dict` parameter taken by all the preceding\n",
203 |     "life-cycle methods.\n",
204 |     "\n",
205 |     "As the flow chart above shows, overriding this method is optional: by default the `task` base class\n",
206 |     "`BaseTask` saves the `model`'s `state_dict` together with the `tag_vocab`."
207 |    ],
208 |    "metadata": {
209 |     "collapsed": false
210 |    }
211 |   },
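212 |   {
213 |    "cell_type": "markdown",
214 |    "source": [
215 |     "For example, a minimal `on_setup_metrics` sketch that reports token accuracy with fastNLP's\n",
216 |     "built-in `Accuracy` (the same metric the bundled `BertNER` task registers):\n",
217 |     "\n",
218 |     "```python\n",
219 |     "from fastNLP.core.metrics import Accuracy\n",
220 |     "\n",
221 |     "def on_setup_metrics(self, model, data_bundle, tag_vocab, state_dict):\n",
222 |     "    # the keys become the metric names reported during evaluation,\n",
223 |     "    # e.g. accuracy#acc\n",
224 |     "    return {'accuracy': Accuracy()}\n",
225 |     "```"
226 |    ],
227 |    "metadata": {
228 |     "collapsed": false
229 |    }
230 |   },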
231 |   {
232 |    "cell_type": "markdown",
233 |    "source": [
234 |     "## 2. Hands-on\n",
235 |     "\n",
236 |     "`fastie` ships with `BertNER`, a `task` class that performs `NER` with a pretrained `BERT` model;\n",
237 |     "by default it uses the `torch.optim.Adam` optimizer. We can therefore build an `NER` task that uses\n",
238 |     "the `AdamW` optimizer instead simply by overriding the `on_setup_optimizers` method:\n",
239 |     "\n",
240 |     "```python\n",
241 |     "from fastie.tasks import BertNER\n",
242 |     "from torch.optim import AdamW\n",
243 |     "\n",
244 |     "class BertNERAdamW(BertNER):\n",
245 |     "    def on_setup_optimizers(self, model, data_bundle, tag_vocab, state_dict):\n",
246 |     "        return AdamW(model.parameters(), lr=1e-3)\n",
247 |     "```"
248 |    ],
249 |    "metadata": {
250 |     "collapsed": false
251 |    }
252 |   }
253 |  ],
254 |  "metadata": {
255 |   "kernelspec": {
256 |    "display_name": "Python 3",
257 |    "language": "python",
258 |    "name": "python3"
259 |   },
260 |   "language_info": {
261 |    "codemirror_mode": {
262 |     "name": "ipython",
263 |     "version": 2
264 |    },
265 |    "file_extension": ".py",
266 |    "mimetype": "text/x-python",
267 |    "name": "python",
268 |    "nbconvert_exporter": "python",
269 |    "pygments_lexer": "ipython2",
270 |    "version": "2.7.6"
271 |   }
272 |  },
273 |  "nbformat": 4,
274 |  "nbformat_minor": 0
275 | }
--------------------------------------------------------------------------------
/fastie/command.py:
--------------------------------------------------------------------------------
1 | """FastIE command line entry points."""
2 | import sys
3 | from argparse import ArgumentParser, Namespace, Action
4 | from dataclasses import dataclass, field
5 | from typing import Sequence, Optional
6 | 
7 | from fastie.chain import Chain
8 | from fastie.controller import CONTROLLER
9 | from fastie.dataset import DATASET
10 | from fastie.envs import set_flag, parser, logger
11 | from fastie.exhibition import Exhibition
12 | from fastie.node import BaseNodeConfig, BaseNode
13 | from fastie.tasks import NER, EE, RE
14 | from fastie.utils import parse_config
15 | 
16 | chain = Chain([])
17 | 
18 | global_config: dict = dict()
19 | 
20 | 
21 | @dataclass
22 | class CommandNodeConfig(BaseNodeConfig):
23 |     config: str = field(default='',
24 |                         metadata=dict(help='The config file you want to use.',
25 |                                       existence=True,
26 |                                       alias='-c'))
27 |     task: str = field(
28 |         default='',
29 |         metadata=dict(
30 |             help='The task you want to use. Please use / to split the task and '
31 |             'the specific solution.',
32 |             existence=True,
33 |             alias='-t'))
34 |     dataset: str = field(default='',
35 |                          metadata=dict(
36 |                              help='The dataset you want to work with.',
37 |                              existence=['train', 'eval', 'infer'],
38 |                              alias='-d'))
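39 | 
40 | # A hypothetical illustration (not part of this module): adding a new command
41 | # line option only requires adding a field to the config dataclass above, e.g.
42 | #
43 | #     verbose: bool = field(default=False,
44 | #                           metadata=dict(help='Print verbose logs.',
45 | #                                         existence=True,
46 | #                                         alias='-v'))
47 | #
48 | # `existence` controls for which run modes ('train', 'eval', ...) the option
49 | # is exposed, `alias` adds a short form, and `BaseNode.parser` (see
50 | # fastie/node.py) turns every field into an `argparse` argument automatically.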
46 | """ 47 | 48 | _config = CommandNodeConfig() 49 | _help = 'fastIE command line basic arguments' 50 | 51 | def __init__(self, 52 | solution: Optional[str] = None, 53 | dataset: Optional[str] = None): 54 | BaseNode.__init__(self) 55 | self.solution: Optional[str] = solution 56 | self.dataset: Optional[str] = dataset 57 | 58 | @property 59 | def action(self): 60 | node = self 61 | 62 | class ParseAction(Action): 63 | 64 | def __call__(self, 65 | parser: ArgumentParser, 66 | namespace: Namespace, 67 | values, 68 | option_string: Optional[str] = None): 69 | if option_string is None: 70 | return 71 | field_dict = node._config.__class__.__dataclass_fields__ 72 | variable_name = '' 73 | if isinstance(values, list) and len(values) == 1: 74 | values = values[0] 75 | if option_string.replace('--', '') in field_dict.keys(): 76 | variable_name = option_string.replace('--', '') 77 | if variable_name != 'config': 78 | setattr(node, variable_name, values) 79 | setattr(namespace, variable_name, values) 80 | else: 81 | for key, value in field_dict.items(): 82 | if isinstance(value.metadata['alias'], Sequence): 83 | if option_string in value.metadata['alias']: 84 | variable_name = key 85 | if variable_name != 'config': 86 | setattr(node, variable_name, values) 87 | setattr(namespace, variable_name, values) 88 | 89 | elif isinstance(value.metadata['alias'], str): 90 | if option_string == value.metadata['alias']: 91 | variable_name = key 92 | if variable_name != 'config': 93 | setattr(node, variable_name, values) 94 | setattr(namespace, variable_name, values) 95 | if variable_name == 'task': 96 | if '/' in values: 97 | obj_cls = None 98 | task, solution = values.split('/') 99 | if task.lower() == 'ner': 100 | obj_cls = NER.get(solution) 101 | elif task.lower() == 'ee': 102 | obj_cls = EE.get(solution) 103 | elif task.lower() == 're': 104 | obj_cls = RE.get(solution) 105 | else: 106 | logger.error( 107 | f'The task type `{task}` you selected does not ' 108 | f'exist. \n', 109 | 'You can only choose from `ner`, `ee`, or `re`. ' 110 | ) 111 | exit(0) 112 | if obj_cls is not None: 113 | obj = obj_cls() 114 | _ = chain + obj 115 | else: 116 | logger.warn( 117 | f'The solution `{solution}` you selected does ' 118 | f'not exist. ') 119 | logger.info('Here are the optional options: \n') 120 | sys.argv.append('-t') 121 | sys.argv.append('-l') 122 | Exhibition.intercept() 123 | exit(0) 124 | else: 125 | logger.warn( 126 | f'You must specify both the task category and the ' 127 | f'specific solution, such as `ner/bert` instead of ' 128 | f'`{values}`. ') 129 | logger.info('Here are the optional options: \n') 130 | sys.argv.append('-t') 131 | sys.argv.append('-l') 132 | Exhibition.intercept() 133 | exit(0) 134 | elif variable_name == 'dataset': 135 | obj_cls = DATASET.get(values) 136 | if obj_cls is None: 137 | logger.warn( 138 | f'The dataset `{values}` you selected does not ' 139 | f'exist. ') 140 | logger.info('Here are the optional options: \n') 141 | sys.argv.append('-d') 142 | sys.argv.append('-l') 143 | Exhibition.intercept() 144 | exit(0) 145 | else: 146 | obj = obj_cls() 147 | _ = chain + obj 148 | elif variable_name == 'config': 149 | global_config = parse_config(values) 150 | if global_config is None: 151 | logger.warn( 152 | f'The config file `{values}` you selected does not ' 153 | f'exist. 
--------------------------------------------------------------------------------
/fastie/controller/inference.py:
--------------------------------------------------------------------------------
1 | """Inference tool for FastIE."""
2 | __all__ = ['Inference', 'InferenceConfig', 'InferenceMetric']
3 | 
4 | import json
5 | from dataclasses import dataclass
6 | from dataclasses import field
7 | from functools import reduce
8 | from typing import Union, Sequence, Optional
9 | 
10 | from fastNLP import Evaluator, DataSet, Metric, auto_param_call
11 | from fastNLP.io import DataBundle
12 | 
13 | from fastie.controller.base_controller import BaseController, CONTROLLER
14 | from fastie.envs import set_flag, logger
15 | from fastie.node import BaseNodeConfig
16 | 
17 | 
18 | class InferenceMetric(Metric):
19 |     """用于保存推理结果的 Metric.
20 | 
21 |     :param save_path: 保存路径, 应为一个文件名, 例如 ``result.jsonl``
22 |     :param verbose: 是否打印推理结果
23 |     """
24 | 
25 |     def __init__(self, save_path: Optional[str] = None, verbose: bool = True):
26 |         super().__init__(aggregate_when_get_metric=True)
27 |         self.result: list = []
28 |         self.save_path: Optional[str] = save_path
29 |         self.verbose: bool = verbose
30 | 
31 |     def update(self, pred: Sequence[dict]):
32 |         if self.save_path is not None:
33 |             if self.backend.is_distributed():
34 |                 # `all_gather_object` is a collective call, so every rank must
35 |                 # take part in it; only rank 0 then writes the merged result.
36 |                 gathered = reduce(lambda x, y: [*x, *y],
37 |                                   self.all_gather_object(pred))
38 |                 if self.backend.is_global_zero():
39 |                     with open(self.save_path, 'a+') as file:
40 |                         file.write('\n'.join(
41 |                             map(lambda x: json.dumps(x), gathered)) + '\n')
42 |             else:
43 |                 with open(self.save_path, 'a+') as file:
44 |                     file.write('\n'.join(map(lambda x: json.dumps(x), pred)) +
45 |                                '\n')
46 |         if self.verbose and not self.backend.is_distributed():
47 |             for sample in pred:
48 |                 # 判断一下不同的格式
49 |                 # 首先是 NER 小组约定的格式
50 |                 if 'entity_mentions' in sample.keys():
51 |                     print('tokens: ', ' '.join(sample['tokens']))
52 |                     print(
53 |                         'pred: ', ' '.join([
54 |                             sample['tokens'][i] if i
55 |                             in sample['entity_mentions'][0][0] else ''.join(
56 |                                 [' ' for j in range(len(sample['tokens'][i]))])
57 |                             for i in range(len(sample['tokens']))
58 |                         ]), f" {sample['entity_mentions'][0][1]} -> "
59 |                         f"{sample['entity_mentions'][0][2]}"
60 |                         if len(sample['entity_mentions'][0]) >= 3 else
61 |                         f" {sample['entity_mentions'][0][1]}")
62 |                     if len(sample['entity_mentions']) > 1:
63 |                         for entity_mention in sample['entity_mentions'][1:]:
64 |                             print(
65 |                                 '       ', ' '.join([
66 |                                     sample['tokens'][i]
67 |                                     if i in entity_mention[0] else ''.join([
68 |                                         ' ' for j in range(
69 |                                             len(sample['tokens'][i]))
70 |                                     ]) for i in range(len(sample['tokens']))
71 |                                 ]), f' {entity_mention[1]} -> '
72 |                                 f'{entity_mention[2]}' if len(entity_mention)
73 |                                 == 3 else f' {entity_mention[1]}')
74 |                 else:
75 |                     # TODO: 其他类型的格式,例如为关系抽取小组制定的格式
76 |                     pass
77 |         self.result.extend(pred)
78 | 
79 |     def get_metric(self):
80 |         return reduce(lambda x, y: x + y, self.all_gather_object(self.result))
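81 | 
82 | 
83 | # For reference: `InferenceMetric.update` expects each element of `pred` to
84 | # be a JSON-serializable dict in the NER exchange format used across fastie,
85 | # e.g. (illustrative values):
86 | #
87 | #   {'tokens': ['It', 'is', 'in', 'Seoul', '.'],
88 | #    'entity_mentions': [([3], 'LOC', 0.98)]}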
89 | 
90 | 
91 | def generate_step_fn(evaluator, batch):
92 |     # collect (and optionally persist) the raw outputs of an evaluate step;
93 |     # used in place of a metric-driven evaluation
94 |     outputs = evaluator.evaluate_step(batch)
95 |     content = '\n'.join(outputs['pred'])
96 |     if getattr(evaluator, 'generate_save_path', None) is not None:
97 |         with open(evaluator.generate_save_path, 'a+') as f:
98 |             f.write(f'{content}\n')
99 |     else:
100 |         evaluator.result.extend(outputs['pred'])
101 | 
102 | 
103 | @dataclass
104 | class InferenceConfig(BaseNodeConfig):
105 |     """推理器的配置."""
106 |     save_path: Optional[str] = field(
107 |         default=None,
108 |         metadata=dict(
109 |             help='The path to save the generated results. If not set, output to '
110 |             'the returned variable. ',
111 |             existence=['infer']))
112 |     verbose: bool = field(
113 |         default=True,
114 |         metadata=dict(help='Whether to output the contents of each inference. '
115 |                       'Multiple cards are not supported. ',
116 |                       existence=['infer']))
117 | 
118 | 
119 | @CONTROLLER.register_module('inference')
120 | class Inference(BaseController):
121 |     """推理器, 用于对任务在 ``infer`` 数据集上进行检验,并输出 ``infer`` 数据集上的推理结果.
122 | 
123 |     也可以使用命令行模式, 例如:
124 | 
125 |     .. code-block:: console
126 |         $ fastie-infer --task ner/bert --dataset sentence --sentence It is located in Beijing --verbose
127 | 
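128 |     It can also be driven from Python (an illustrative sketch; the task
129 |     needs a trained model or checkpoint to give useful predictions):
130 | 
131 |     .. code-block:: python
132 | 
133 |         from fastie.tasks import BertNER
134 |         Inference(save_path='result.jsonl').run(BertNER().run())
135 | 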
136 |     :param save_path: 推理结果的保存路径, 应为一个文件名, 例如 ``result.jsonl``
137 |         推理结果将保存到 ``save_path`` 中, 保存的格式为 ``jsonlines`` 格式, 每行为一个样本的推理结果的 ``json`` 字符串
138 | 
139 |     :param verbose: 是否在推理的过程中实时打印推理结果
140 |     """
141 |     _config = InferenceConfig()
142 |     _help = 'Inference tool for FastIE. '
143 | 
144 |     def __init__(self,
145 |                  save_path: Optional[str] = None,
146 |                  verbose: bool = True,
147 |                  **kwargs):
148 |         super(Inference, self).__init__(**kwargs)
149 |         self.save_path: Optional[str] = save_path
150 |         self.verbose: bool = verbose
151 |         set_flag('infer')
152 | 
153 |     def run(
154 |         self,
155 |         parameters_or_data: Optional[Union[dict, DataBundle, DataSet, str,
156 |                                            Sequence[str]]] = None
157 |     ) -> Sequence[dict]:
158 |         """推理器的 ``run`` 方法,用于实际地对传入的 ``task`` 或是数据集进行推理.
159 | 
160 |         :param parameters_or_data: 既可以是 task,也可以是数据集:
161 |             * 为 ``task`` 时, 应为 :class:`~fastie.BaseTask` 对象 ``run``
162 |               方法的返回值, 例如:
163 |                 >>> from fastie.tasks import BertNER
164 |                 >>> task = BertNER().run()
165 |                 >>> Inference().run(task)
166 |             * 为数据集,可以是 ``[dict, DataSet, DataBundle, None]`` 类型的数据集:
167 |                 * ``dict`` 类型的数据集,例如:
168 |                     >>> dataset = {'tokens': [ "It", "is", "located", "in", "Seoul", "." ]}
169 |                 * ``Sequence[dict]`` 类型的数据集,例如:
170 |                     >>> dataset = [{'tokens': [ "It", "is", "located", "in", "Seoul", "." ]}]
171 |                 * ``DataSet`` 类型的数据集,例如:
172 |                     >>> from fastNLP import DataSet, Instance
173 |                     >>> dataset = DataSet([Instance(tokens=[ "It", "is", "located", "in", "Seoul", "." ])])
174 |                 * ``DataBundle`` 类型的数据集,必须包含 ``infer`` 子集, 例如:
175 |                     >>> from fastNLP import DataSet, Instance
176 |                     >>> from fastNLP.io import DataBundle
177 |                     >>> dataset = DataBundle(datasets={'infer': DataSet([Instance(tokens=[ "It", "is", "located", "in", "Seoul", "." ])])})
178 |                 * ``str`` 类型的数据集,会自动根据空格分割转换为 ``token``,
179 |                   详见 :class:`fastie.dataset.io.sentence.Sentence` 例如:
180 |                     >>> dataset = "It is located in Seoul ."
181 |                 * ``Sequence[str]`` 类型的数据集,会自动根据空格分割转换为 ``token``,
182 |                   详见 :class:`fastie.dataset.io.sentence.Sentence`, 例如:
183 |                     >>> dataset = ["It is located in Seoul .", "It is located in Beijing ."]
184 |                 * ``None`` 会自动寻找 ``config`` 中的 ``dataset``, 例如:
185 |                     >>> config = {'dataset': 'conll2003'}
186 |                     >>> Inference.from_config(config).run()
187 | 
188 |         :return: ``List[dict]`` 类型的推理结果, 例如:
189 |             >>> [{'tokens': [ "It", "is", "located", "in", "Seoul", "." ],
190 |             >>>   'entity_mentions': [([4], "LOC")]}]
191 |         """
192 |         parameters_or_data = BaseController.run(self, parameters_or_data)
193 |         if self._sequential:
194 |             return parameters_or_data
195 |         if parameters_or_data is None:
196 |             logger.error(
197 |                 'Inference tool does not allow task and dataset to be left '
198 |                 'empty. ')
199 |             exit(1)
200 |         parameters_or_data['evaluate_fn'] = 'infer_step'
201 |         parameters_or_data['verbose'] = False
202 |         inference_metric = InferenceMetric(save_path=self.save_path,
203 |                                            verbose=self.verbose)
204 |         parameters_or_data['metrics'] = {'infer': inference_metric}
205 |         evaluator = Evaluator(**parameters_or_data)
206 |         return auto_param_call(evaluator.run, parameters_or_data)['infer']
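207 | 
208 | 
209 | # Minimal end-to-end sketch (illustrative; assumes `BertNER` restores a
210 | # trained checkpoint through its base-class `load_model` argument, and that a
211 | # dataset is configured -- see the `run` docstring above for input types):
212 | #
213 | #   from fastie.tasks import BertNER
214 | #   task = BertNER(load_model='checkpoint.bin').run()
215 | #   results = Inference(verbose=False).run(task)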
--------------------------------------------------------------------------------
/fastie/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from functools import reduce
4 | from typing import Union, Optional, Set, Tuple, List, Dict
5 | 
6 | from fastNLP import Vocabulary, Instance
7 | from fastNLP.io import DataBundle
8 | 
9 | from fastie.envs import get_flag, CONFIG_FLAG, FASTIE_HOME, logger, set_task, \
10 |     set_dataset
11 | from fastie.utils.config import Config
12 | 
13 | 
14 | def generate_tag_vocab(
15 |         data_bundle: DataBundle,
16 |         unknown: Optional[str] = 'O',
17 |         base_mapping: Optional[dict] = None) -> Dict[str, Vocabulary]:
18 |     """根据数据集中的已标注样本构建 tag_vocab.
19 | 
20 |     :param data_bundle: :class:`~fastNLP.io.DataBundle` 对象
21 |     :param unknown: 未知标签的标记
22 |     :param base_mapping: 基础映射,例如 ``{"entity": {"label": 0}}`` 或者 ``{"entity": {0: "label"}}``
23 |         函数将在确保 ``base_mapping`` 中的标签不会被覆盖改变的前提下构造 vocab
24 |     :return: 如果存在已标注样本,则返回构造成功的 :class:`~fastNLP.Vocabulary` 的 `dict` 对象,
25 |         其中 `value` 为 `vocab` 类别. 否则返回空的 ``dict``
26 |     """
27 |     if 'train' in data_bundle.datasets.keys() \
28 |             or 'dev' in data_bundle.datasets.keys() \
29 |             or 'test' in data_bundle.datasets.keys():
30 |         vocab: Dict = {}
31 | 
32 |         def construct_vocab(instance: Instance):
33 |             # 当然,用来 infer 的数据集是无法构建的,这里判断一下
34 |             if 'entity_mentions' in instance.keys():
35 |                 if 'entity' not in vocab.keys():
36 |                     vocab['entity'] = Vocabulary(padding=None, unknown=unknown)
37 |                 for entity_mention in instance['entity_mentions']:
38 |                     vocab['entity'].add(entity_mention[1])
39 |             else:
40 |                 # TODO: 增加新的标签种类
41 |                 ...
42 |             return instance
43 | 
44 |         data_bundle.apply_more(construct_vocab)
45 |         for key, value in vocab.items():
46 |             if base_mapping and key in base_mapping.keys():
47 |                 base_word2idx = {}
48 |                 base_idx2word = {}
49 |                 if isinstance(base_mapping[key], dict) and isinstance(
50 |                         list(base_mapping[key].keys())[0], str):
51 |                     base_word2idx = base_mapping[key]
52 |                     base_idx2word = \
53 |                         {idx: word for word, idx in base_mapping[key].items()}
54 |                 elif isinstance(base_mapping[key], dict) and isinstance(
55 |                         list(base_mapping[key].keys())[0], int):
56 |                     base_idx2word = base_mapping[key]
57 |                     base_word2idx = \
58 |                         {word: idx for idx, word in base_mapping[key].items()}
59 |                 for k, v in value.word2idx.items():
60 |                     if k not in base_word2idx.keys():
61 |                         # 线性探测法
62 |                         while v in base_idx2word.keys():
63 |                             v += 1
64 |                         base_word2idx[k] = v
65 |                         base_idx2word[v] = k
66 |                 value._word2idx = base_word2idx
67 |                 value._idx2word = base_idx2word
68 |         return vocab
69 |     else:
70 |         # 无已标注样本
71 |         return {}
72 | 
73 | 
74 | def check_loaded_tag_vocab(
75 |         loaded_tag_vocab: Optional[Union[dict, Vocabulary]],
76 |         tag_vocab: Optional[Vocabulary]) -> Tuple[int, Optional[Vocabulary]]:
77 |     """检查加载的 tag_vocab 是否与新生成的 tag_vocab 一致.
78 | 79 | :param loaded_tag_vocab: 从 ``checkpoint`` 中加载得到的 ``tag_vocab``; 80 | 可以为 ``dict`` 类型,也可以是 :class:`~fastNLP.Vocabulary` 类型。 81 | :param tag_vocab: 从数据集中构建的 ``tag_vocab``; 82 | :return: 检查的结果信号和可使用的 ``tag_vocab``,信号取值: 83 | 84 | * 为 ``1`` 时 85 | 表示一致或可矫正的错误,可以直接使用返回的 ``tag_vocab`` 86 | * 为 ``-1`` 时 87 | 表示出现冲突且无法矫正,请抛弃加载得到的 ``loaded_tag_vocab`` 88 | 使用返回的 ``tag_vocab`` 89 | * 为 ``0`` 时 90 | 无可用的 ``tag_vocab``,将直接输出 idx 91 | """ 92 | idx2word = None 93 | word2idx = None 94 | if loaded_tag_vocab is not None: 95 | if isinstance(loaded_tag_vocab, Vocabulary): 96 | idx2word = loaded_tag_vocab.idx2word 97 | word2idx = loaded_tag_vocab.word2idx 98 | elif isinstance(list(loaded_tag_vocab.keys())[0], int): 99 | idx2word = loaded_tag_vocab 100 | word2idx = {word: idx for idx, word in idx2word.items()} 101 | elif isinstance(list(loaded_tag_vocab.keys())[0], str): 102 | word2idx = loaded_tag_vocab 103 | idx2word = {idx: word for word, idx in word2idx.items()} 104 | if loaded_tag_vocab is None and tag_vocab is None: 105 | logger.warn('Error: No tag dictionary is available. ') 106 | return 0, None 107 | if loaded_tag_vocab is None and tag_vocab is not None: 108 | return 1, tag_vocab 109 | if loaded_tag_vocab is not None and tag_vocab is None: 110 | tag_vocab = Vocabulary() 111 | tag_vocab._word2idx = word2idx 112 | tag_vocab._idx2word = idx2word 113 | return 1, tag_vocab 114 | if loaded_tag_vocab is not None and tag_vocab is not None: 115 | if get_flag() != 'infer': 116 | if word2idx != tag_vocab.word2idx: 117 | if set(word2idx.keys()) == set( # type: ignore [union-attr] 118 | tag_vocab.word2idx.keys() # type: ignore [union-attr] 119 | ): # type: ignore [union-attr] 120 | tag_vocab._word2idx.update(word2idx) 121 | tag_vocab._idx2word.update(idx2word) 122 | return 1, tag_vocab 123 | elif set(tag_vocab.word2idx.keys() # type: ignore [union-attr] 124 | ).issubset(set( 125 | word2idx.keys())): # type: ignore [union-attr] 126 | tag_vocab._word2idx.update(word2idx) 127 | tag_vocab._idx2word.update(idx2word) 128 | return 1, tag_vocab 129 | else: 130 | logger.warn( 131 | 'The tag dictionary ' # type: ignore [union-attr] 132 | f"`\n[{','.join(list(tag_vocab._word2idx.keys()))}]`\n" # type: ignore [union-attr] 133 | ' loaded from the model is not the same as the ' 134 | 'tag dictionary ' 135 | f"\n`[{','.join(list(word2idx.keys()))}]`\n" # type: ignore [union-attr] 136 | ' built from the dataset, so the loaded model may be ' 137 | 'discarded') 138 | return -1, tag_vocab 139 | else: 140 | return 1, tag_vocab 141 | else: 142 | tag_vocab._word2idx = word2idx 143 | tag_vocab._idx2word = idx2word 144 | return 1, tag_vocab 145 | return 0, None 146 | 147 | 148 | def parse_config(_config: object) -> Optional[dict]: 149 | config = dict() 150 | if isinstance(_config, dict): 151 | for key, value in _config.items(): 152 | if key == 'task': 153 | set_task(value) 154 | if key == 'dataset': 155 | set_dataset(value) 156 | if not key.startswith('_'): 157 | config[key] = value 158 | return config 159 | elif isinstance(_config, str): 160 | if os.path.exists(_config) and os.path.isfile( 161 | _config) and _config.endswith('.py'): 162 | if CONFIG_FLAG == 'dict': 163 | config_dict = reduce(lambda x, y: { 164 | **x, 165 | **y 166 | }, [ 167 | value 168 | for value in Config.fromfile(_config)._cfg_dict.values() 169 | if isinstance(value, dict) 170 | ]) 171 | return parse_config(config_dict) 172 | elif CONFIG_FLAG == 'class': 173 | config_obj = Config.fromfile(_config)._cfg_dict.Config() 174 | config_dict = { 175 | key: 
getattr(config_obj, key) 176 | for key in dir(config_obj) if not key.startswith('_') 177 | } 178 | return parse_config(config_dict) 179 | else: 180 | for root, dirs, files in os.walk( 181 | os.path.join(FASTIE_HOME, 'configs')): 182 | for file in files: 183 | if _config == file.replace('.py', ''): 184 | return parse_config(os.path.join(root, file)) 185 | return None 186 | else: 187 | for key in _config.__dir__(): 188 | if key == 'task': 189 | set_task(getattr(_config, key)) 190 | if key == 'dataset': 191 | set_dataset(getattr(_config, key)) 192 | if not key.startswith('_'): 193 | config[key] = getattr(_config, key) 194 | return config 195 | 196 | 197 | def inspect_function_calling(func_name: str) -> Optional[Set[str]]: 198 | import inspect 199 | frame_info_list = inspect.stack() 200 | argument_user_provided = [] 201 | for i in range(len(frame_info_list)): 202 | if frame_info_list[i + 1].function == func_name: 203 | for k in range(i, len(frame_info_list)): 204 | if frame_info_list[k + 1].function != func_name: 205 | co_const = frame_info_list[k + 1].frame.f_code.co_consts 206 | if len(co_const) > 1: 207 | for j in range(len(co_const) - 1, -1, -1): 208 | if isinstance(co_const[j], tuple) and \ 209 | isinstance(co_const[j][0], str): 210 | argument_user_provided.extend(co_const[j]) 211 | break 212 | if 'args' in frame_info_list[k].frame.f_locals.keys(): 213 | argument_list = \ 214 | frame_info_list[k + 1].frame.f_locals[func_name].__code__.co_varnames 215 | argument_user_provided. \ 216 | extend(argument_list[:len(frame_info_list[k].frame.f_locals['args'])]) 217 | # 转换为 set 去除重复项 218 | return set(argument_user_provided) 219 | return None 220 | 221 | 222 | def inspect_metrics(parameters: dict = {}) -> List[str]: 223 | """根据参数中的 metrics 字段,返回真正检验结果中可能存在的 metric 名称. 224 | 225 | 例如, 输入参数为 {'metrics': {'accuracy': Accuracy(), 'f1': SpanFPreRecMetric()}}, 226 | 则返回 ['accuracy#acc', 'f1#f', 'f1#pre', 'f1#rec']. 227 | 228 | :param parameters: 可以用于 :class:`fastNLP.Trainer` 的参数字典. 229 | :return: 返回可能存在的 metric 名称列表. 
230 | """ 231 | from fastNLP import Trainer 232 | 233 | if 'metrics' not in parameters: 234 | return [] 235 | try: 236 | stdout = sys.stdout 237 | sys.stdout = open(os.devnull, 'w') 238 | trainer = Trainer(**parameters) 239 | result = trainer.evaluator.run(num_eval_batch_per_dl=1) 240 | sys.stdout = stdout 241 | return list(result.keys()) 242 | except Exception as e: 243 | logger.error(e) 244 | return [] 245 | -------------------------------------------------------------------------------- /fastie/tasks/ner/bert/bert.py: -------------------------------------------------------------------------------- 1 | """BertNER.""" 2 | __all__ = ['BertNER', 'BertNERConfig'] 3 | 4 | from dataclasses import dataclass, field 5 | from functools import reduce 6 | from typing import Optional, Dict 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | from fastNLP import Instance, Vocabulary 12 | from fastNLP.core.metrics import Accuracy 13 | from fastNLP.io import DataBundle 14 | from fastNLP.transformers.torch.models.bert import BertModel, BertConfig, \ 15 | BertTokenizer 16 | from torch import nn 17 | 18 | from fastie.tasks.base_task import NER 19 | from fastie.tasks.ner.BaseNERTask import BaseNERTask, BaseNERTaskConfig 20 | 21 | 22 | class Model(nn.Module): 23 | 24 | def __init__(self, 25 | pretrained_model_name_or_path: Optional[str] = None, 26 | num_labels: int = 9, 27 | tag_vocab: Optional[Vocabulary] = None, 28 | **kwargs): 29 | super(Model, self).__init__() 30 | if pretrained_model_name_or_path is not None: 31 | self.bert = BertModel.from_pretrained( 32 | pretrained_model_name_or_path) 33 | else: 34 | self.bert = BertModel(BertConfig(**kwargs)) 35 | self.bert.requires_grad_(False) 36 | self.num_labels = num_labels 37 | self.classificationHead = nn.Linear(self._get_bert_embedding_dim(), 38 | num_labels) 39 | # 为了推理过程中能输出人类可读的结果,把 tag2label 也传进来 40 | self.tag_vocab = tag_vocab 41 | 42 | def _get_bert_embedding_dim(self): 43 | with torch.no_grad(): 44 | temp = torch.zeros(1, 1).int().to(self.bert.device) 45 | return self.bert(temp).last_hidden_state.shape[-1] 46 | 47 | def forward(self, input_ids, attention_mask): 48 | features = self.bert(input_ids=input_ids, 49 | attention_mask=attention_mask).last_hidden_state 50 | features = self.classificationHead(features) 51 | return dict(features=features) 52 | 53 | def train_step(self, input_ids, attention_mask, offset_mask, 54 | entity_mentions): 55 | features = self.forward(input_ids, attention_mask)['features'] 56 | loss = 0 57 | for b in range(features.shape[0]): 58 | logits = F.softmax( 59 | features[b][offset_mask[b].nonzero(), :].squeeze(1), dim=1) 60 | target = torch.zeros(logits.shape[0]).to(features.device) 61 | for entity_mention in entity_mentions[b]: 62 | for i in entity_mention[0]: 63 | target[i] = entity_mention[1] 64 | for i in range(logits.shape[0]): 65 | one_hot_target = torch.zeros(self.num_labels).to( 66 | features.device) 67 | one_hot_target[int(target[i])] = 1 68 | loss += F.binary_cross_entropy(logits[i], one_hot_target) 69 | return dict(loss=loss) 70 | 71 | def evaluate_step(self, input_ids, attention_mask, offset_mask, 72 | entity_mentions): 73 | features = self.forward(input_ids, attention_mask)['features'] 74 | pred_list = [] 75 | target_list = [] 76 | max_len = 0 77 | for b in range(features.shape[0]): 78 | logits = F.softmax( 79 | features[b][offset_mask[b].nonzero(), :].squeeze(1), dim=1) 80 | pred = logits.argmax(dim=1).to(features.device) 81 | target = torch.zeros(pred.shape[0]).to(features.device) 
82 | if pred.shape[0] > max_len: 83 | max_len = pred.shape[0] 84 | for entity_mention in entity_mentions[b]: 85 | for i in entity_mention[0]: 86 | target[i] = entity_mention[1] 87 | pred_list.append(pred) 88 | target_list.append(target) 89 | pred = torch.stack( 90 | [F.pad(pred, (0, max_len - pred.shape[0])) for pred in pred_list]) 91 | target = torch.stack([ 92 | F.pad(target, (0, max_len - target.shape[0])) 93 | for target in target_list 94 | ]) 95 | return dict(pred=pred, target=target) 96 | 97 | def infer_step(self, tokens, input_ids, attention_mask, offset_mask): 98 | features = self.forward(input_ids, attention_mask)['features'] 99 | pred_list = [] 100 | for b in range(features.shape[0]): 101 | logits = F.softmax( 102 | features[b][offset_mask[b].nonzero(), :].squeeze(1), dim=1) 103 | pred = logits.argmax(dim=1).to(features.device) 104 | pred_dict = {} 105 | pred_dict['tokens'] = tokens[b] 106 | pred_dict['entity_mentions'] = [] 107 | for i in range(pred.shape[0]): 108 | # 考虑一下,如果用户没有传入 tag_vocab,那么这里的输出就是 idx 109 | if self.tag_vocab is not None: 110 | pred_dict['entity_mentions'].append( 111 | ([i], self.tag_vocab.idx2word[int(pred[i])], 112 | round(float(logits[i].max()), 3))) 113 | else: 114 | pred_dict['entity_mentions'].append( 115 | ([i], int(pred[i]), round(float(logits[i].max()), 3))) 116 | pred_list.append(pred_dict) 117 | # 推理的结果一定是可 json 化的,建议 List[Dict],和输入的数据集的格式一致 118 | # 这里的结果是用户可读的,所以建议把 idx2label 存起来 119 | # 怎么存可以看一下下面 233 行 120 | return dict(pred=pred_list) 121 | 122 | 123 | @dataclass 124 | class BertNERConfig(BaseNERTaskConfig): 125 | """BertNER 所需参数.""" 126 | pretrained_model_name_or_path: str = field( 127 | default='bert-base-uncased', 128 | metadata=dict( 129 | help='name of transformer model (see ' 130 | 'https://huggingface.co/transformers/pretrained_models.html for ' 131 | 'options).', 132 | existence='train')) 133 | lr: float = field(default=2e-5, 134 | metadata=dict(help='learning rate', existence='train')) 135 | 136 | 137 | @NER.register_module('bert') 138 | class BertNER(BaseNERTask): 139 | """BertNER 使用预训练的 BERT 模型和分类头来做 NER 任务. 140 | 141 | :param pretrained_model_name_or_path: transformers 预训练 BERT 模型名字或路径. 142 | (see https://huggingface.co/models for options). 143 | :param lr: 学习率 144 | """ 145 | # 必须在这里定义自己 config 146 | _config = BertNERConfig() 147 | # 帮助信息,会显示在命令行分组的帮助信息中 148 | _help = 'Use pre-trained BERT and a classification head to classify tokens' 149 | 150 | def __init__(self, 151 | pretrained_model_name_or_path: str = 'bert-base-uncased', 152 | lr: float = 2e-5, 153 | **kwargs): 154 | # 必须要把父类 (BaseTask)的参数也复制过来,否则用户没有父类的代码提示; 155 | # 在这里进行父类的初始化; 156 | # 父类的参数我们不需要进行任何操作,比如这里的 cuda 和 load_model,我们无视就可以了。 157 | super().__init__(**kwargs) 158 | self.pretrained_model_name_or_path = pretrained_model_name_or_path 159 | self.lr = lr 160 | 161 | def on_dataset_preprocess(self, data_bundle: DataBundle, 162 | tag_vocab: Dict[str, Vocabulary], 163 | state_dict: Optional[dict]) -> DataBundle: 164 | """数据预处理, 包括将 `label` 通过生成或加载的 `tag_vocab` 转化为 id, 并将 `tokens` 通过 165 | `BertTokenizer` 转化为 id. 
166 | 167 | :param data_bundle: 原始数据 168 | :param tag_vocab: 生成或加载的 `tag_vocab` 169 | :param state_dict: 加载的 `checkpoint` 170 | :return: 处理后的数据集 171 | """ 172 | # 数据预处理 173 | tokenizer = BertTokenizer.from_pretrained( 174 | self.pretrained_model_name_or_path) 175 | 176 | def tokenize(instance: Instance): 177 | result_dict = {} 178 | input_ids_list, attention_mask_list, offset_mask_list = [], [], [] 179 | for token in instance['tokens']: 180 | tokenized_token = tokenizer([token], 181 | is_split_into_words=True, 182 | return_tensors='np', 183 | return_attention_mask=True, 184 | return_token_type_ids=False, 185 | add_special_tokens=False) 186 | token_offset_mask = np.zeros( 187 | tokenized_token['input_ids'].shape, dtype=int) 188 | token_offset_mask[0, 0] = 1 189 | input_ids_list.append(tokenized_token['input_ids']) 190 | attention_mask_list.append(tokenized_token['attention_mask']) 191 | offset_mask_list.append(token_offset_mask) 192 | input_ids = reduce(lambda x, y: np.concatenate((x, y), axis=1), 193 | input_ids_list)[0, :] 194 | attention_mask = reduce( 195 | lambda x, y: np.concatenate((x, y), axis=1), 196 | attention_mask_list)[0, :] 197 | offset_mask = reduce(lambda x, y: np.concatenate((x, y), axis=1), 198 | offset_mask_list)[0, :] 199 | result_dict['input_ids'] = input_ids 200 | result_dict['attention_mask'] = attention_mask 201 | result_dict['offset_mask'] = offset_mask 202 | # 顺便把 tag 转化为 id 203 | if 'entity_mentions' in instance.keys(): 204 | for i in range(len(instance['entity_mentions'])): 205 | instance['entity_mentions'][i] = ( 206 | instance['entity_mentions'][i][0], 207 | tag_vocab['entity'].to_index( 208 | instance['entity_mentions'][i][1])) 209 | result_dict['entity_mentions'] = instance['entity_mentions'] 210 | return result_dict 211 | 212 | data_bundle.apply_more(tokenize) 213 | return data_bundle 214 | 215 | def on_setup_model(self, data_bundle: DataBundle, 216 | tag_vocab: Dict[str, Vocabulary], 217 | state_dict: Optional[dict]): 218 | """加载 BERT 模型和分类头. 219 | 220 | :param data_bundle: 预处理后的数据集 221 | :param tag_vocab: 生成或加载的 `tag_vocab` 222 | :param state_dict: 加载的 `checkpoint` 223 | :return: 拥有 ``train_step``、``evaluate_step``、 ``infer_step`` 224 | 方法的对象 225 | """ 226 | if 'entity' not in tag_vocab.keys(): 227 | raise Exception('Failed to get `tag_vocab`') 228 | # 模型加载阶段 229 | model = Model(self.pretrained_model_name_or_path, 230 | num_labels=len(list( 231 | tag_vocab['entity'].word2idx.keys())), 232 | tag_vocab=tag_vocab['entity']) 233 | if state_dict: 234 | model.load_state_dict(state_dict['model']) 235 | return model 236 | 237 | def on_setup_optimizers(self, model, data_bundle: DataBundle, 238 | tag_vocab: Dict[str, Vocabulary], 239 | state_dict: Optional[dict]): 240 | """加载 `Adam` 优化器. 241 | 242 | :param model: 模型 243 | :param data_bundle: 预处理后的数据集 244 | :param tag_vocab: 生成或加载的 `tag_vocab` 245 | :param state_dict: 加载的 `checkpoint` 246 | :return: 247 | """ 248 | # 优化器加载阶段 249 | return torch.optim.Adam(model.parameters(), lr=self.lr) 250 | 251 | def on_setup_metrics(self, model, data_bundle: DataBundle, 252 | tag_vocab: Dict[str, Vocabulary], 253 | state_dict: Optional[dict]) -> dict: 254 | """加载 `Accuracy` 评价指标. 
255 | 
256 |         :param model: 模型
257 |         :param data_bundle: 预处理后的数据集
258 |         :param tag_vocab: 生成或加载的 `tag_vocab`
259 |         :param state_dict: 加载的 `checkpoint`
260 |         :return:
261 |         """
262 |         # 评价指标加载阶段
263 |         return {'accuracy': Accuracy()}
264 | 
265 |     def on_get_state_dict(self, model, data_bundle: DataBundle,
266 |                           tag_vocab: Dict[str, Vocabulary]) -> dict:
267 |         state_dict = super().on_get_state_dict(model, data_bundle, tag_vocab)
268 |         state_dict[
269 |             'pretrained_model_name_or_path'] = self.pretrained_model_name_or_path
270 |         return state_dict
271 | 
--------------------------------------------------------------------------------
/fastie/utils/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import collections.abc
3 | import functools
4 | import itertools
5 | import subprocess
6 | import warnings
7 | from collections import abc
8 | from importlib import import_module
9 | from inspect import getfullargspec
10 | from itertools import repeat
11 | 
12 | 
13 | # From PyTorch internals
14 | def _ntuple(n):
15 | 
16 |     def parse(x):
17 |         if isinstance(x, collections.abc.Iterable):
18 |             return x
19 |         return tuple(repeat(x, n))
20 | 
21 |     return parse
22 | 
23 | 
24 | to_1tuple = _ntuple(1)
25 | to_2tuple = _ntuple(2)
26 | to_3tuple = _ntuple(3)
27 | to_4tuple = _ntuple(4)
28 | to_ntuple = _ntuple
29 | 
30 | 
31 | def is_str(x):
32 |     """Whether the input is a string instance.
33 | 
34 |     Note: This method is deprecated since python 2 is no longer supported.
35 |     """
36 |     return isinstance(x, str)
37 | 
38 | 
39 | def import_modules_from_strings(imports, allow_failed_imports=False):
40 |     """Import modules from the given list of strings.
41 |     Args:
42 |         imports (list | str | None): The given module names to be imported.
43 |         allow_failed_imports (bool): If True, the failed imports will return
44 |             None. Otherwise, an ImportError is raised. Default: False.
45 |     Returns:
46 |         list[module] | module | None: The imported modules.
47 |     Examples:
48 |         >>> osp, sys = import_modules_from_strings(
49 |         ...     ['os.path', 'sys'])
50 |         >>> import os.path as osp_
51 |         >>> import sys as sys_
52 |         >>> assert osp == osp_
53 |         >>> assert sys == sys_
54 |     """
55 |     if not imports:
56 |         return
57 |     single_import = False
58 |     if isinstance(imports, str):
59 |         single_import = True
60 |         imports = [imports]
61 |     if not isinstance(imports, list):
62 |         raise TypeError(
63 |             f'custom_imports must be a list but got type {type(imports)}')
64 |     imported = []
65 |     for imp in imports:
66 |         if not isinstance(imp, str):
67 |             raise TypeError(
68 |                 f'{imp} is of type {type(imp)} and cannot be imported.')
69 |         try:
70 |             imported_tmp = import_module(imp)
71 |         except ImportError:
72 |             if allow_failed_imports:
73 |                 warnings.warn(f'{imp} failed to import and is ignored.',
74 |                               UserWarning)
75 |                 imported_tmp = None
76 |             else:
77 |                 raise ImportError(f'Failed to import {imp}')
78 |         imported.append(imported_tmp)
79 |     if single_import:
80 |         imported = imported[0]
81 |     return imported
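82 | 
83 | 
84 | # e.g. (illustrative) list_cast(['1', '2', '3'], int) -> [1, 2, 3], and
85 | # tuple_cast('ab', str) -> ('a', 'b'); see the cast helpers defined below.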
86 | def iter_cast(inputs, dst_type, return_type=None):
87 |     """Cast elements of an iterable object into some type.
88 | 
89 |     Args:
90 |         inputs (Iterable): The input object.
91 |         dst_type (type): Destination type.
92 |         return_type (type, optional): If specified, the output object will be
93 |             converted to this type, otherwise an iterator.
94 |     Returns:
95 |         iterator or specified type: The converted object.
96 |     """
97 |     if not isinstance(inputs, abc.Iterable):
98 |         raise TypeError('inputs must be an iterable object')
99 |     if not isinstance(dst_type, type):
100 |         raise TypeError('"dst_type" must be a valid type')
101 | 
102 |     out_iterable = map(dst_type, inputs)
103 | 
104 |     if return_type is None:
105 |         return out_iterable
106 |     else:
107 |         return return_type(out_iterable)
108 | 
109 | 
110 | def list_cast(inputs, dst_type):
111 |     """Cast elements of an iterable object into a list of some type.
112 | 
113 |     A partial method of :func:`iter_cast`.
114 |     """
115 |     return iter_cast(inputs, dst_type, return_type=list)
116 | 
117 | 
118 | def tuple_cast(inputs, dst_type):
119 |     """Cast elements of an iterable object into a tuple of some type.
120 | 
121 |     A partial method of :func:`iter_cast`.
122 |     """
123 |     return iter_cast(inputs, dst_type, return_type=tuple)
124 | 
125 | 
126 | def is_seq_of(seq, expected_type, seq_type=None):
127 |     """Check whether it is a sequence of some type.
128 | 
129 |     Args:
130 |         seq (Sequence): The sequence to be checked.
131 |         expected_type (type): Expected type of sequence items.
132 |         seq_type (type, optional): Expected sequence type.
133 |     Returns:
134 |         bool: Whether the sequence is valid.
135 |     """
136 |     if seq_type is None:
137 |         exp_seq_type = abc.Sequence
138 |     else:
139 |         assert isinstance(seq_type, type)
140 |         exp_seq_type = seq_type
141 |     if not isinstance(seq, exp_seq_type):
142 |         return False
143 |     for item in seq:
144 |         if not isinstance(item, expected_type):
145 |             return False
146 |     return True
147 | 
148 | 
149 | def is_list_of(seq, expected_type):
150 |     """Check whether it is a list of some type.
151 | 
152 |     A partial method of :func:`is_seq_of`.
153 |     """
154 |     return is_seq_of(seq, expected_type, seq_type=list)
155 | 
156 | 
157 | def is_tuple_of(seq, expected_type):
158 |     """Check whether it is a tuple of some type.
159 | 
160 |     A partial method of :func:`is_seq_of`.
161 |     """
162 |     return is_seq_of(seq, expected_type, seq_type=tuple)
163 | 
164 | 
165 | def slice_list(in_list, lens):
166 |     """Slice a list into several sub lists by a list of given length.
167 | 
168 |     Args:
169 |         in_list (list): The list to be sliced.
170 |         lens(int or list): The expected length of each out list.
171 |     Returns:
172 |         list: A list of sliced list.
173 |     """
174 |     if isinstance(lens, int):
175 |         assert len(in_list) % lens == 0
176 |         lens = [lens] * int(len(in_list) / lens)
177 |     if not isinstance(lens, list):
178 |         raise TypeError('"lens" must be an integer or a list of integers')
179 |     elif sum(lens) != len(in_list):
180 |         raise ValueError('sum of lens and list length does not '
181 |                          f'match: {sum(lens)} != {len(in_list)}')
182 |     out_list = []
183 |     idx = 0
184 |     for i in range(len(lens)):
185 |         out_list.append(in_list[idx:idx + lens[i]])
186 |         idx += lens[i]
187 |     return out_list
188 | 
189 | 
190 | def concat_list(in_list):
191 |     """Concatenate a list of list into a single list.
192 | 
193 |     Args:
194 |         in_list (list): The list of list to be merged.
195 |     Returns:
196 |         list: The concatenated flat list.
197 |     """
198 |     return list(itertools.chain(*in_list))
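199 | 
200 | 
201 | # Quick examples of the two helpers above (illustrative):
202 | #   slice_list([1, 2, 3, 4, 5, 6], [2, 4]) -> [[1, 2], [3, 4, 5, 6]]
203 | #   concat_list([[1, 2], [3, 4]])          -> [1, 2, 3, 4]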
204 | 
205 | 
206 | def check_prerequisites(
207 |         prerequisites,
208 |         checker,
209 |         msg_tmpl='Prerequisites "{}" are required in method "{}" but not '
210 |         'found, please install them first.'):  # yapf: disable
211 |     """A decorator factory to check if prerequisites are satisfied.
212 | 
213 |     Args:
214 |         prerequisites (str or list[str]): Prerequisites to be checked.
215 |         checker (callable): The checker method that returns True if a
216 |             prerequisite is met, False otherwise.
217 |         msg_tmpl (str): The message template with two variables.
218 |     Returns:
219 |         decorator: A specific decorator.
220 |     """
221 | 
222 |     def wrap(func):
223 | 
224 |         @functools.wraps(func)
225 |         def wrapped_func(*args, **kwargs):
226 |             requirements = [prerequisites] if isinstance(
227 |                 prerequisites, str) else prerequisites
228 |             missing = []
229 |             for item in requirements:
230 |                 if not checker(item):
231 |                     missing.append(item)
232 |             if missing:
233 |                 print(msg_tmpl.format(', '.join(missing), func.__name__))
234 |                 raise RuntimeError('Prerequisites not meet.')
235 |             else:
236 |                 return func(*args, **kwargs)
237 | 
238 |         return wrapped_func
239 | 
240 |     return wrap
241 | 
242 | 
243 | def _check_py_package(package):
244 |     try:
245 |         import_module(package)
246 |     except ImportError:
247 |         return False
248 |     else:
249 |         return True
250 | 
251 | 
252 | def _check_executable(cmd):
253 |     if subprocess.call(f'which {cmd}', shell=True) != 0:
254 |         return False
255 |     else:
256 |         return True
257 | 
258 | 
259 | def requires_package(prerequisites):
260 |     """A decorator to check if some python packages are installed.
261 |     Example:
262 |         >>> @requires_package('numpy')
263 |         >>> def func(arg1, args):
264 |         >>>     return numpy.zeros(1)
265 |         array([0.])
266 |         >>> @requires_package(['numpy', 'non_package'])
267 |         >>> def func(arg1, args):
268 |         >>>     return numpy.zeros(1)
269 |         ImportError
270 |     """
271 |     return check_prerequisites(prerequisites, checker=_check_py_package)
272 | 
273 | 
274 | def requires_executable(prerequisites):
275 |     """A decorator to check if some executable files are installed.
276 |     Example:
277 |         >>> @requires_executable('ffmpeg')
278 |         >>> def func(arg1, args):
279 |         >>>     print(1)
280 |         1
281 |     """
282 |     return check_prerequisites(prerequisites, checker=_check_executable)
283 | 
284 | 
285 | def deprecated_api_warning(name_dict, cls_name=None):
286 |     """A decorator to check if some arguments are deprecated and try to
287 |     replace the deprecated src_arg_name with dst_arg_name.
288 | 
289 |     Args:
290 |         name_dict(dict):
291 |             key (str): Deprecated argument names.
292 |             val (str): Expected argument names.
293 |     Returns:
294 |         func: New function.
295 |     """
296 | 
297 |     def api_warning_wrapper(old_func):
298 | 
299 |         @functools.wraps(old_func)
300 |         def new_func(*args, **kwargs):
301 |             # get the arg spec of the decorated method
302 |             args_info = getfullargspec(old_func)
303 |             # get name of the function
304 |             func_name = old_func.__name__
305 |             if cls_name is not None:
306 |                 func_name = f'{cls_name}.{func_name}'
307 |             if args:
308 |                 arg_names = args_info.args[:len(args)]
309 |                 for src_arg_name, dst_arg_name in name_dict.items():
310 |                     if src_arg_name in arg_names:
311 |                         warnings.warn(
312 |                             f'"{src_arg_name}" is deprecated in '
313 |                             f'`{func_name}`, please use "{dst_arg_name}" '
314 |                             'instead', DeprecationWarning)
315 |                         arg_names[arg_names.index(src_arg_name)] = dst_arg_name
316 |             if kwargs:
317 |                 for src_arg_name, dst_arg_name in name_dict.items():
318 |                     if src_arg_name in kwargs:
319 |                         assert dst_arg_name not in kwargs, (
320 |                             f'The expected behavior is to replace '
321 |                             f'the deprecated key `{src_arg_name}` with the '
322 |                             f'new key `{dst_arg_name}`, but got them '
323 |                             f'in the arguments at the same time, which '
324 |                             f'is confusing. `{src_arg_name}` will be '
325 |                             f'deprecated in the future, please '
326 |                             f'use `{dst_arg_name}` instead.')
327 | 
328 |                         warnings.warn(
329 |                             f'"{src_arg_name}" is deprecated in '
330 |                             f'`{func_name}`, please use "{dst_arg_name}" '
331 |                             'instead', DeprecationWarning)
332 |                         kwargs[dst_arg_name] = kwargs.pop(src_arg_name)
333 | 
334 |             # apply converted arguments to the decorated method
335 |             output = old_func(*args, **kwargs)
336 |             return output
337 | 
338 |         return new_func
339 | 
340 |     return api_warning_wrapper
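341 | 
342 | 
343 | # Usage sketch for the decorator above (hypothetical function and names):
344 | #
345 | #   @deprecated_api_warning({'in_list': 'inputs'})
346 | #   def concat(inputs):
347 | #       return concat_list(inputs)
348 | #
349 | #   concat(in_list=[[1], [2]])  # warns, then forwards to inputs=[[1], [2]]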
350 | 
351 | 
352 | def is_method_overridden(method, base_class, derived_class):
353 |     """Check if a method of base class is overridden in derived class.
354 | 
355 |     Args:
356 |         method (str): the method name to check.
357 |         base_class (type): the class of the base class.
358 |         derived_class (type | Any): the class or instance of the derived class.
359 |     """
360 |     assert isinstance(base_class, type), \
361 |         "base_class doesn't accept an instance; please pass a class instead."
362 | 
363 |     if not isinstance(derived_class, type):
364 |         derived_class = derived_class.__class__
365 | 
366 |     base_method = getattr(base_class, method)
367 |     derived_method = getattr(derived_class, method)
368 |     return derived_method != base_method
369 | 
370 | 
371 | def has_method(obj: object, method: str) -> bool:
372 |     """Check whether the object has a method.
373 | 
374 |     Args:
375 |         method (str): The method name to check.
376 |         obj (object): The object to check.
377 |     Returns:
378 |         bool: True if the object has the method else False.
379 |     """
380 |     return hasattr(obj, method) and callable(getattr(obj, method))
381 | 
--------------------------------------------------------------------------------
/fastie/utils/registry.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import inspect
3 | import warnings
4 | from functools import partial
5 | from typing import Any, Dict, Optional
6 | 
7 | from .misc import deprecated_api_warning, is_seq_of
8 | 
9 | 
10 | def build_from_cfg(cfg: Dict,
11 |                    registry: 'Registry',
12 |                    default_args: Optional[Dict] = None) -> Any:
13 |     """Build a module from config dict when it is a class configuration, or
14 |     call a function from config dict when it is a function configuration.
15 |     Example:
16 |         >>> MODELS = Registry('models')
17 |         >>> @MODELS.register_module()
18 |         >>> class ResNet:
19 |         >>>     pass
20 |         >>> resnet = build_from_cfg(dict(type='ResNet'), MODELS)
21 |         >>> # Returns an instantiated object
22 |         >>> @MODELS.register_module()
23 |         >>> def resnet50():
24 |         >>>     pass
25 |         >>> resnet = build_from_cfg(dict(type='resnet50'), MODELS)
26 |         >>> # Returns the result of calling the function
27 |     Args:
28 |         cfg (dict): Config dict. It should at least contain the key "type".
29 |         registry (:obj:`Registry`): The registry to search the type from.
30 |         default_args (dict, optional): Default initialization arguments.
31 |     Returns:
32 |         object: The constructed object.
33 | """ 34 | if not isinstance(cfg, dict): 35 | raise TypeError(f'cfg must be a dict, but got {type(cfg)}') 36 | if 'type' not in cfg: 37 | if default_args is None or 'type' not in default_args: 38 | raise KeyError( 39 | '`cfg` or `default_args` must contain the key "type", ' 40 | f'but got {cfg}\n{default_args}') 41 | if not isinstance(registry, Registry): 42 | raise TypeError('registry must be an mmcv.Registry object, ' 43 | f'but got {type(registry)}') 44 | if not (isinstance(default_args, dict) or default_args is None): 45 | raise TypeError('default_args must be a dict or None, ' 46 | f'but got {type(default_args)}') 47 | 48 | args = cfg.copy() 49 | 50 | if default_args is not None: 51 | for name, value in default_args.items(): 52 | args.setdefault(name, value) 53 | 54 | obj_type = args.pop('type') 55 | if isinstance(obj_type, str): 56 | obj_cls = registry.get(obj_type) 57 | if obj_cls is None: 58 | raise KeyError( 59 | f'{obj_type} is not in the {registry.name} registry') 60 | elif inspect.isclass(obj_type) or inspect.isfunction(obj_type): 61 | obj_cls = obj_type 62 | else: 63 | raise TypeError( 64 | f'type must be a str or valid type, but got {type(obj_type)}') 65 | try: 66 | return obj_cls(**args) 67 | except Exception as e: 68 | # Normal TypeError does not print class name. 69 | raise type(e)(f'{obj_cls.__name__}: {e}') 70 | 71 | 72 | class Registry: 73 | """A registry to map strings to classes or functions. 74 | Registered object could be built from registry. Meanwhile, registered 75 | functions could be called from registry. 76 | Example: 77 | >>> MODELS = Registry('models') 78 | >>> @MODELS.register_module() 79 | >>> class ResNet: 80 | >>> pass 81 | >>> resnet = MODELS.build(dict(type='ResNet')) 82 | >>> @MODELS.register_module() 83 | >>> def resnet50(): 84 | >>> pass 85 | >>> resnet = MODELS.build(dict(type='resnet50')) 86 | Please refer to 87 | https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for 88 | advanced usage. 89 | Args: 90 | name (str): Registry name. 91 | build_func(func, optional): Build function to construct instance from 92 | Registry, func:`build_from_cfg` is used if neither ``parent`` or 93 | ``build_func`` is specified. If ``parent`` is specified and 94 | ``build_func`` is not given, ``build_func`` will be inherited 95 | from ``parent``. Default: None. 96 | parent (Registry, optional): Parent registry. The class registered in 97 | children registry could be built from parent. Default: None. 98 | scope (str, optional): The scope of registry. It is the key to search 99 | for children registry. If not specified, scope will be the name of 100 | the package where class is defined, e.g. mmdet, mmcls, mmseg. 101 | Default: None. 102 | """ 103 | 104 | def __init__(self, name, build_func=None, parent=None, scope=None): 105 | self._name = name 106 | self._module_dict = dict() 107 | self._children = dict() 108 | self._scope = self.infer_scope() if scope is None else scope 109 | 110 | # self.build_func will be set with the following priority: 111 | # 1. build_func 112 | # 2. parent.build_func 113 | # 3. 
build_from_cfg 114 | if build_func is None: 115 | if parent is not None: 116 | self.build_func = parent.build_func 117 | else: 118 | self.build_func = build_from_cfg 119 | else: 120 | self.build_func = build_func 121 | if parent is not None: 122 | assert isinstance(parent, Registry) 123 | parent._add_children(self) 124 | self.parent = parent 125 | else: 126 | self.parent = None 127 | 128 | def __len__(self): 129 | return len(self._module_dict) 130 | 131 | def __contains__(self, key): 132 | return self.get(key) is not None 133 | 134 | def __repr__(self): 135 | format_str = self.__class__.__name__ + \ 136 | f'(name={self._name}, ' \ 137 | f'items={self._module_dict})' 138 | return format_str 139 | 140 | @staticmethod 141 | def infer_scope(): 142 | """Infer the scope of registry. 143 | The name of the package where registry is defined will be returned. 144 | Example: 145 | >>> # in mmdet/models/backbone/resnet.py 146 | >>> MODELS = Registry('models') 147 | >>> @MODELS.register_module() 148 | >>> class ResNet: 149 | >>> pass 150 | The scope of ``ResNet`` will be ``mmdet``. 151 | Returns: 152 | str: The inferred scope name. 153 | """ 154 | # We access the caller using inspect.currentframe() instead of 155 | # inspect.stack() for performance reasons. See details in PR #1844 156 | frame = inspect.currentframe() 157 | # get the frame where `infer_scope()` is called 158 | infer_scope_caller = frame.f_back.f_back 159 | filename = inspect.getmodule(infer_scope_caller).__name__ 160 | split_filename = filename.split('.') 161 | return split_filename[0] 162 | 163 | @staticmethod 164 | def split_scope_key(key): 165 | """Split scope and key. 166 | The first scope will be split from key. 167 | Examples: 168 | >>> Registry.split_scope_key('mmdet.ResNet') 169 | 'mmdet', 'ResNet' 170 | >>> Registry.split_scope_key('ResNet') 171 | None, 'ResNet' 172 | Return: 173 | tuple[str | None, str]: The former element is the first scope of 174 | the key, which can be ``None``. The latter is the remaining key. 175 | """ 176 | split_index = key.find('.') 177 | if split_index != -1: 178 | return key[:split_index], key[split_index + 1:] 179 | else: 180 | return None, key 181 | 182 | @property 183 | def name(self): 184 | return self._name 185 | 186 | @property 187 | def scope(self): 188 | return self._scope 189 | 190 | @property 191 | def module_dict(self): 192 | return self._module_dict 193 | 194 | @property 195 | def children(self): 196 | return self._children 197 | 198 | def get(self, key): 199 | """Get the registry record. 200 | 201 | Args: 202 | key (str): The class name in string format. 203 | Returns: 204 | class: The corresponding class. 205 | """ 206 | scope, real_key = self.split_scope_key(key) 207 | if scope is None or scope == self._scope: 208 | # get from self 209 | if real_key in self._module_dict: 210 | return self._module_dict[real_key] 211 | else: 212 | # get from self._children 213 | if scope in self._children: 214 | return self._children[scope].get(real_key) 215 | else: 216 | # goto root 217 | parent = self.parent 218 | while parent.parent is not None: 219 | parent = parent.parent 220 | return parent.get(key) 221 | 222 | def build(self, *args, **kwargs): 223 | return self.build_func(*args, **kwargs, registry=self) 224 | 225 | def _add_children(self, registry): 226 | """Add children for a registry. 227 | The ``registry`` will be added as children based on its scope. 228 | The parent registry could build objects from children registry. 
229 | Example: 230 | >>> models = Registry('models') 231 | >>> mmdet_models = Registry('models', parent=models) 232 | >>> @mmdet_models.register_module() 233 | >>> class ResNet: 234 | >>> pass 235 | >>> resnet = models.build(dict(type='mmdet.ResNet')) 236 | """ 237 | 238 | assert isinstance(registry, Registry) 239 | assert registry.scope is not None 240 | assert registry.scope not in self.children, \ 241 | f'scope {registry.scope} exists in {self.name} registry' 242 | self.children[registry.scope] = registry 243 | 244 | @deprecated_api_warning(name_dict=dict(module_class='module')) 245 | def _register_module(self, module, module_name=None, force=False): 246 | if not inspect.isclass(module) and not inspect.isfunction(module): 247 | raise TypeError('module must be a class or a function, ' 248 | f'but got {type(module)}') 249 | 250 | if module_name is None: 251 | module_name = module.__name__ 252 | if isinstance(module_name, str): 253 | module_name = [module_name] 254 | for name in module_name: 255 | if not force and name in self._module_dict: 256 | raise KeyError(f'{name} is already registered ' 257 | f'in {self.name}') 258 | self._module_dict[name] = module 259 | 260 | def deprecated_register_module(self, cls=None, force=False): 261 | warnings.warn( 262 | 'The old API of register_module(module, force=False) ' 263 | 'is deprecated and will be removed, please use the new API ' 264 | 'register_module(name=None, force=False, module=None) instead.', 265 | DeprecationWarning) 266 | if cls is None: 267 | return partial(self.deprecated_register_module, force=force) 268 | self._register_module(cls, force=force) 269 | return cls 270 | 271 | def register_module(self, name=None, force=False, module=None): 272 | """Register a module. 273 | A record will be added to `self._module_dict`, whose key is the class 274 | name or the specified name, and value is the class itself. 275 | It can be used as a decorator or a normal function. 276 | Example: 277 | >>> backbones = Registry('backbone') 278 | >>> @backbones.register_module() 279 | >>> class ResNet: 280 | >>> pass 281 | >>> backbones = Registry('backbone') 282 | >>> @backbones.register_module(name='mnet') 283 | >>> class MobileNet: 284 | >>> pass 285 | >>> backbones = Registry('backbone') 286 | >>> class ResNet: 287 | >>> pass 288 | >>> backbones.register_module(ResNet) 289 | Args: 290 | name (str | None): The module name to be registered. If not 291 | specified, the class name will be used. 292 | force (bool, optional): Whether to override an existing class with 293 | the same name. Default: False. 294 | module (type): Module class or function to be registered. 295 | """ 296 | if not isinstance(force, bool): 297 | raise TypeError(f'force must be a boolean, but got {type(force)}') 298 | # NOTE: This is a walkaround to be compatible with the old api, 299 | # while it may introduce unexpected bugs. 
--------------------------------------------------------------------------------
/fastie/node.py:
--------------------------------------------------------------------------------
1 | """Base class for FastIE nodes. Subclasses of this class gain the following features:
2 | 
3 | * Config items declared in the config class are automatically registered as arguments of the ``argparse`` parser and assigned during parsing
4 | * Instantiation from a ``dict``-style config object or from a config file
5 | """
6 | __all__ = ['BaseNodeConfig', 'BaseNode']
7 | 
8 | import inspect
9 | import re
10 | import zipfile
11 | from argparse import ArgumentParser, Namespace, Action
12 | from dataclasses import dataclass, MISSING
13 | from typing import Union, Sequence, Optional, Dict, Type
14 | 
15 | from fastie.envs import parser as global_parser, get_flag, PARSER_FLAG
16 | from fastie.utils.utils import parse_config
17 | 
18 | 
19 | @dataclass
20 | class BaseNodeConfig:
21 |     """Base class for FastIE node configs."""
22 | 
23 |     def parse(self, obj: object):
24 |         """Assign this config's attribute values to ``obj``.
25 | 
26 |         :param obj: any object
27 |         :return:
28 |         """
29 |         for key, value in self.__dict__.items():
30 |             setattr(obj, key, value)
31 | 
32 |     def to_dict(self) -> dict:
33 |         """Convert this config to a dict.
34 | 
35 |         :return:
36 |         """
37 |         fields = dict()
38 |         for field_name, field_value in self.__class__.__dict__[
39 |                 '__dataclass_fields__'].items():
40 |             fields[field_name] = dict(type=getattr(field_value, 'type'),
41 |                                       default=getattr(field_value, 'default'),
42 |                                       default_factory=getattr(
43 |                                           field_value, 'default_factory'),
44 |                                       metadata=getattr(field_value,
45 |                                                        'metadata'))
46 |         return fields
47 | 
48 |     @classmethod
49 |     def from_dict(cls, _config: dict):
50 |         """Create a config from a dict.
51 | 
52 |         :param _config: config as a ``dict``
53 |         :return: a config of type :class:`BaseNodeConfig`
54 |         """
55 |         config = cls()
56 |         for key in config.__dir__():
57 |             if not key.startswith('_') and key in _config.keys():
58 |                 setattr(config, key, _config[key])
59 |         return config
60 | 
61 |     def keys(self):
62 |         """Get the names of all attributes of this config.
63 | 
64 |         :return: attribute names as a ``list``
65 |         """
66 |         return [key for key in self.__dir__() if not key.startswith('_')]
67 | 
68 |     def __getitem__(self, item):
69 |         """Get an attribute value via ``[]``.
70 | 
71 |         :param item: attribute name
72 |         :return: attribute value
73 |         """
74 |         return getattr(self, item)
75 | 
76 | 
77 | class BaseNode(object):
78 |     """Base class for FastIE nodes.
79 | 
80 |     Subclasses of this class gain the following features:
81 |     * Config items declared in the config class are automatically registered as ``argparse`` parser arguments
82 |     * Instantiation from a ``dict``-style config object or from a config file
83 |     """
84 |     _config = BaseNodeConfig()
85 |     _help = 'The base class of all node objects'
86 | 
87 |     def __init__(self, **kwargs):
88 |         self._parser = global_parser.add_argument_group(
89 |             title=f'Optional argument for {self.__class__.__name__}')
90 |         self._overload_config: dict = {}
91 | 
92 |     @classmethod
93 |     def from_config(cls, config: Union[BaseNodeConfig, str, dict]):
94 |         """Create a node from a config file or a config object.
95 | 
96 |         :param config: either the path of a ``*.py`` file or an object of type :class:`BaseNodeConfig`
97 |         :return: a node of type :class:`BaseNode`
98 |         """
99 |         node = cls()
100 |         if isinstance(config, BaseNodeConfig):
101 |             node._config = config
102 |         else:
103 |             _config = parse_config(config)
104 |             if _config is not None:
105 |                 node._overload_config = _config
106 |                 node._config = node._config.__class__.from_dict(_config)
107 |                 for key, value in _config.items():  # type: ignore [union-attr]
108 |                     if hasattr(node, key):
109 |                         setattr(node, key, value)
110 |         # node._config.parse(node)
111 |         return node
112 | 
113 |     @property
114 |     def parser(self):
115 |         """Build this node's ``argparse`` parser from its config class.
116 | 
117 |         :return: a parser of type :class:`argparse.ArgumentParser`
118 |         """
119 | 
120 |         def inspect_all_bases(cls: type):
121 |             if cls == object:
122 |                 return
123 |             if PARSER_FLAG == 'dataclass':
124 |                 for key, value in self._config.to_dict().items():
125 |                     if isinstance(value['metadata']['existence'], bool) \
126 |                             and value['metadata']['existence'] \
127 |                             or isinstance(value['metadata']['existence'], list) \
128 |                             and get_flag() in value['metadata']['existence'] \
129 |                             or isinstance(value['metadata']['existence'], str) \
130 |                             and get_flag() == value['metadata']['existence']:
131 |                         default_value = None
132 |                         if value['default'] != MISSING:
133 |                             default_value = value['default']
134 |                         if value['default_factory'] != MISSING:
135 |                             default_value = value['default_factory']()
136 |                         arg_flag = [f'--{key}']
137 |                         if 'alias' in value['metadata']:
138 |                             if isinstance(value['metadata']['alias'], str):
139 |                                 arg_flag = [
140 |                                     *arg_flag, value['metadata']['alias']
141 |                                 ]
142 |                             elif isinstance(value['metadata']['alias'],
143 |                                             Sequence):
144 |                                 arg_flag.extend([
145 |                                     item for item in value['metadata']['alias']
146 |                                 ])
147 |                         nargs = 1
148 |                         if 'nargs' in value['metadata']:
149 |                             nargs = value['metadata']['nargs']
150 |                         if type(default_value) == bool:
151 |                             self._parser.add_argument(
152 |                                 *arg_flag,
153 |                                 default=default_value,
154 |                                 help=f"{value['metadata']['help']} "
155 |                                 f'default: {default_value}',
156 |                                 action=self.action,
157 |                                 metavar='',
158 |                                 nargs='?',
159 |                                 const=True,
160 |                                 required=False)
161 |                         else:
162 |                             self._parser.add_argument(
163 |                                 *arg_flag,
164 |                                 default=default_value,
165 |                                 type=type(default_value),
166 |                                 help=f"{value['metadata']['help']} "
167 |                                 f'default: {default_value}',
168 |                                 action=self.action,
169 |                                 metavar='',
170 |                                 nargs=nargs,
171 |                                 required=False)
172 |             elif PARSER_FLAG == 'comment':
173 |                 for key, value in cls().comments.items():
174 |                     if get_flag() in value['flags']:
175 |                         self._parser.add_argument(
176 |                             f'--{key}',
177 |                             default=value['value'],
178 |                             type=type(value['value']),
179 |                             help=f"{value['description']} "
180 |                             f"default: {value['value']}",
181 |                             action=self.action,
182 |                             metavar='',
183 |                             required=False)
184 |             for father in cls.__bases__:
185 |                 inspect_all_bases(father)
186 | 
187 |         inspect_all_bases(self.__class__)
188 |         return self._parser
189 | 
190 |     @property
191 |     def action(self) -> Type[Action]:
192 |         """Build the ``action`` for this node's ``argparse`` parser from its config class.
193 | 
194 |         :return: an ``action`` of type :class:`argparse.Action`
195 |         """
196 |         node = self
197 | 
198 |         class ParseAction(Action):
199 | 
200 |             def __call__(self,
201 |                          parser: ArgumentParser,
202 |                          namespace: Namespace,
203 |                          values,
204 |                          option_string: Optional[str] = None):
205 |                 if option_string is None:
206 |                     return
207 |                 field_dict = node._config.__class__.__dataclass_fields__
208 |                 if option_string.replace('--', '') in field_dict.keys():
209 |                     variable_name = option_string.replace('--', '')
210 |                     if 'multi_method' in field_dict[
211 |                             variable_name].metadata.keys():
212 |                         if field_dict[variable_name].metadata[
213 |                                 'multi_method'] == 'space-join':
214 |                             values = ' '.join(values)
215 |                         else:
216 |                             # TODO: how to handle the values when this option accepts multiple arguments
217 |                             pass
218 |                     if isinstance(values, Sequence) and len(values) == 1:
219 |                         values = values[0]
220 |                     setattr(node, variable_name, values)
221 |                     setattr(namespace, variable_name, values)
222 |                 else:
223 |                     for key, value in field_dict.items():
224 |                         if isinstance(value.metadata['alias'], Sequence):
225 |                             if option_string in value.metadata['alias']:
226 |                                 if 'multi_method' in value.metadata.keys():
227 |                                     if value.metadata[
228 |                                             'multi_method'] == 'space-join':
229 |                                         values = ' '.join(values)
230 |                                     else:
231 |                                         # TODO: how to handle the values when this option accepts multiple arguments
232 |                                         pass
233 |                                 setattr(node, key, values)
234 |                                 setattr(namespace, key, values)
235 | 
236 |                         elif isinstance(value.metadata['alias'], str):
237 |                             if option_string == value.metadata['alias']:
238 |                                 if 'multi_method' in value.metadata.keys():
239 |                                     if value.metadata[
240 |                                             'multi_method'] == 'space-join':
241 |                                         values = ' '.join(values)
242 |                                     else:
243 |                                         # TODO: how to handle the values when this option accepts multiple arguments
244 |                                         pass
245 |                                 setattr(node, key, values)
246 |                                 setattr(namespace, key, values)
247 | 
248 |         return ParseAction
249 | 
250 |     @property
251 |     def comments(self) -> dict:
252 |         """Get the comment information of this node.
253 | 
254 |         .. warning::
255 | 
256 |             This method is deprecated.
257 | 
258 |         :return:
259 |         """
260 |         comments = {}
261 | 
262 |         def inspect_all_bases(cls: type):
263 |             if cls == object:
264 |                 return
265 |             code_path = inspect.getfile(cls)
266 |             try:
267 |                 with open(file=inspect.getfile(cls), mode='r') as file:
268 |                     lines = file.readlines()
269 |             except NotADirectoryError:
270 |                 egg_file = code_path.split('.egg')[0] + '.egg'
271 |                 sub_path = code_path.split('.egg')[1][1:]
272 |                 with zipfile.ZipFile(egg_file, 'r') as zip_file:
273 |                     lines = list(
274 |                         map(lambda x: x.decode(),
275 |                             zip_file.open(sub_path).readlines()))
276 |             for index in range(len(lines)):
277 |                 if 'Args:' in lines[index]:
278 |                     break
279 |             for index in range(index + 1, len(lines)):
280 |                 if ':' in lines[index]:
281 |                     match_result = re.search(
282 |                         r':(.*?)\((.*?)\)\[(.*?)\]=(.*?):(.*?)$', lines[index])
283 |                     if match_result is None:
284 |                         continue
285 |                     else:
286 |                         key, t, flags, value, description = match_result.groups(
287 |                         )
288 |                         comments[key.strip()] = dict(
289 |                             type=t.strip(),
290 |                             flags=flags.strip().split(','),
291 |                             value=value,
292 |                             description=description)
293 |                 if '"""' in lines[index]:
294 |                     break
295 |             for cls in cls.__bases__:
296 |                 inspect_all_bases(cls)
297 | 
298 |         inspect_all_bases(self.__class__)
299 |         return comments
300 | 
301 |     @property
302 |     def description(self):
303 |         """Get the description from this node's comments.
304 | 
305 |         .. warning::
306 |             This method is deprecated.
307 | 
308 |         :return:
309 |         """
310 |         code_path = inspect.getfile(self.__class__)
311 |         try:
312 |             with open(file=inspect.getfile(self.__class__), mode='r') as file:
313 |                 lines = file.readlines()
314 |         except NotADirectoryError:
315 |             egg_file = code_path.split('.egg')[0] + '.egg'
316 |             sub_path = code_path.split('.egg')[1][1:]
317 |             with zipfile.ZipFile(egg_file, 'r') as zip_file:
318 |                 lines = zip_file.open(sub_path).readlines()
319 |                 lines = list(map(lambda x: x.decode(), lines))
320 |         for index in range(len(lines)):
321 |             if '"""' in lines[index]:
322 |                 return lines[index].replace('"""', '').strip()
323 | 
324 |     @property
325 |     def fields(self) -> dict:
326 |         """Get all fields of this node's config class.
327 | 
328 |         :return: field information as a ``dict``
329 |         """
330 |         fields: Dict[str, dict] = dict()
331 |         for key in self.__dir__():
332 |             if isinstance(object.__getattribute__(self, key), BaseNodeConfig):
333 |                 config_cls = object.__getattribute__(self, key).__class__
334 |                 for field_name, field_value in config_cls.__dict__[
335 |                         '__dataclass_fields__'].items():
336 |                     fields[field_name] = dict(
337 |                         type=getattr(field_value, 'type'),
338 |                         default=getattr(field_value, 'default'),
339 |                         default_factory=getattr(field_value,
340 |                                                 'default_factory'),
341 |                         metadata=getattr(field_value, 'metadata'))
342 |         return fields
343 | 
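A minimal sketch of how `BaseNodeConfig` and `BaseNode` above fit together. The names `DemoConfig` and `DemoNode` and the field shown are illustrative, not part of the codebase:

```python
from dataclasses import dataclass, field

from fastie.node import BaseNode, BaseNodeConfig


@dataclass
class DemoConfig(BaseNodeConfig):
    # `existence` controls whether (or in which run modes) the option is
    # exposed on the command line; `help` feeds the generated argparse help.
    batch_size: int = field(default=32,
                            metadata=dict(existence=True, help='Batch size.'))


class DemoNode(BaseNode):
    _config = DemoConfig()
    _help = 'A demo node, for illustration only'


# Instantiate from a dict-style config (from_config also accepts a *.py
# config file path or a BaseNodeConfig instance); keys that match config
# fields are applied to the node's config.
node = DemoNode.from_config({'batch_size': 16})
```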
"stream", 78 | "text": [ 79 | "Task Solution Description \r\n", 80 | "================================================================================\r\n", 81 | "NER bert 用预训练模型 Bert 对 token 进行向量表征,然后通过 classification head 对每个 token\r\n", 82 | " 进行分类。 \r\n", 83 | " \r\n" 84 | ] 85 | } 86 | ], 87 | "source": [ 88 | "!fastie-train --task ner --list" 89 | ], 90 | "metadata": { 91 | "collapsed": false 92 | } 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "source": [ 97 | "#### 1.1.2 dataset\n", 98 | "\n", 99 | "`fastie` 中的 `dataset` 代表具体的 `NLP` 数据集或者数据集结构, 例如, `fastie` 提供了\n", 100 | "`CoNLL-2003` 数据集, 那么对应的 `dataset` 名称为 `conll2003`. `fastie` 除了提供常用\n", 101 | "`NLP` 数据集外(在 `fastie` 中被称为 `legacy` 数据集),还提供了多种数据集格式(在 `fastie`\n", 102 | "中被称为 `io` 数据集), 便于您读取自己的数据集, 例如, `fastie` 提供了 `jsonlines-ner`\n", 103 | "格式的数据集:\n", 104 | "\n", 105 | "```jsonl\n", 106 | "{\"tokens\": [\"I\", \"love\", \"fastie\", \".\"], \"entity_motions\": [\n", 107 | " {\"entity_index\": [0], \"entity_type\": \"PER\"}\n", 108 | " {\"entity_index\": [2], \"entity_type\": \"MISC\"}\n", 109 | "]}\n", 110 | "{\"tokens\": [\"I\", \"love\", \"fastNLP\", \".\"], \"entity_motions\": [\n", 111 | " {\"entity_index\": [0], \"entity_type\": \"PER\"}\n", 112 | " {\"entity_index\": [2], \"entity_type\": \"MISC\"}\n", 113 | "]}\n", 114 | "```\n", 115 | "\n", 116 | "如上所示, `jsonlines` 格式规定您的文件中每行是一个 `json` 格式的样本, 当您的数据集是\n", 117 | "`jsonlines-ner`格式的时候,您就可以将 `dataset` 设置为 `jsonlines-ner`,`fastie` 会自动\n", 118 | "读取您的数据集.\n", 119 | "\n", 120 | "与 `task` 类似, `fastie` 提供了 `--list`, `-l` 工具帮助您快速获悉 `fastie` 中所有可用的\n", 121 | "`dataset` 名称. 例如, 在控制台中输入(在您的控制台中请勿输入开头的 `!` 字符):" 122 | ], 123 | "metadata": { 124 | "collapsed": false 125 | } 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 8, 130 | "outputs": [ 131 | { 132 | "name": "stdout", 133 | "output_type": "stream", 134 | "text": [ 135 | " Dataset Description \r\n", 136 | "================================================================================\r\n", 137 | "conll2003 The shared task of CoNLL-2003 concerns language-independent \r\n", 138 | " named entity \r\n", 139 | "wikiann 这个类还没写好,请勿参考. \r\n", 140 | "column-ner 这个类还没写好,请勿参考. \r\n", 141 | "sentence None \r\n", 142 | "jsonlines-ner None \r\n", 143 | " \r\n" 144 | ] 145 | } 146 | ], 147 | "source": [ 148 | "!fastie-train --dataset --list" 149 | ], 150 | "metadata": { 151 | "collapsed": false 152 | } 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "source": [ 157 | "#### 1.1.3 config\n", 158 | "\n", 159 | "`fastie` 中的 `config` 代表预设配置文件, 您可以使用自己的配置文件, 也可以使用 `fastie` 提\n", 160 | "供的预设配置. 
简单来说, `config` 同时具有 `fastie` 中的 `task` 和 `dataset` 的功能, 例如,\n", 161 | "您可以在当前的工作目录建立 `config.py`:\n", 162 | "\n", 163 | "```python\n", 164 | "config = {\n", 165 | " \"task\": \"ner/bert\",\n", 166 | " \"dataset\": \"conll2003\"\n", 167 | "}\n", 168 | "```\n", 169 | "\n", 170 | "然后在控制台中输入:\n", 171 | "\n", 172 | "```bash\n", 173 | "fastie-train --config config.py\n", 174 | "```\n", 175 | "\n", 176 | "上述的操作等价于直接在命令行输入:\n", 177 | "\n", 178 | "```bash\n", 179 | "fastie-train --task ner/bert --dataset conll2003\n", 180 | "```\n", 181 | "\n", 182 | "如上所示可见 `fastie` 中的 `config` 是使用 `python` 语言编写的, 因此您可以在 `config` 中\n", 183 | "使用任何 `python` 语言的特性, 例如, 您可以在 `config` 中使用 `if` 语句来判断当前的 `task`\n", 184 | "和 `dataset` 并设置不同的参数:\n", 185 | "\n", 186 | "```python\n", 187 | "dataset = \"conll2003\"\n", 188 | "config = {\n", 189 | " \"task\": \"ner/bert\" if dataset == \"conll2003\" else \"ner/bilstm\",\n", 190 | " \"dataset\": dataset\n", 191 | "}\n", 192 | "```\n", 193 | "\n", 194 | "但是请注意, 您提供的配置文件必须包含 `config` 变量, 且 `config` 变量必须是一个 `dict` 类型.\n", 195 | "\n", 196 | "`fastie` 中的配置文件除了可以储存 `task`, `dataset` 这种基础参数外, 还可以储存 `task` 或\n", 197 | "`dataset` 规定的内部参数, 详见 `2. 获得命令行的帮助`.\n", 198 | "\n", 199 | "同样的, `fastie` 提供了 `--list`, `-l` 工具帮助您快速获悉 `fastie` 中所有可用的 `config`\n", 200 | "名称. 例如, 在控制台中输入(在您的控制台中请勿输入开头的 `!` 字符):" 201 | ], 202 | "metadata": { 203 | "collapsed": false 204 | } 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": 9, 209 | "outputs": [ 210 | { 211 | "name": "stdout", 212 | "output_type": "stream", 213 | "text": [ 214 | "Config Description \r\n", 215 | "bert-conll2003 : 使用 bert 对 conll2003 数据集进行序列标注\r\n" 216 | ] 217 | } 218 | ], 219 | "source": [ 220 | "!fastie-train --config --list" 221 | ], 222 | "metadata": { 223 | "collapsed": false 224 | } 225 | }, 226 | { 227 | "cell_type": "markdown", 228 | "source": [ 229 | "### 1.2 fastie-eval\n", 230 | "\n", 231 | "`fastie-eval` 是 `fastie` 中的验证工具, 用于评估 `fastie-train` 训练出的模型. `fastie-eval`\n", 232 | "的使用方法与 `fastie-train` 类似, 同样拥有 `--task`, `--dataset`, `--config` 这三个参数.\n", 233 | "\n", 234 | "在进行模型评测的时候, 我们一般需要设置一个模型文件路径的参数, 用于初始化 `task` 中的模型, 因此\n", 235 | "在这里简答介绍 `fastie` 的渐进式命令行帮助机制:\n", 236 | "\n", 237 | "当您仅在控制台中输入 `fastie-eval --help` 时, `fastie` 只会提供给你 `--task`, `--dataset`,\n", 238 | "`--config` 这三个基本参数的帮助. 而当您进一步明确了 `task` 或者 `dataset` 或者 `config` 后,\n", 239 | "`--help` 提供的参数帮助将会增加, 例如, 当您输入 `fastie-eval --task ner/bert --help` 时:" 240 | ], 241 | "metadata": { 242 | "collapsed": false 243 | } 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": 10, 248 | "outputs": [ 249 | { 250 | "name": "stdout", 251 | "output_type": "stream", 252 | "text": [ 253 | "usage: fastie-train [-h] [--config] [--task] [--dataset] [--cuda ]\r\n", 254 | " [--load_model] [--batch_size] [--shuffle ]\r\n", 255 | " [--pretrained_model_name_or_path]\r\n", 256 | "\r\n", 257 | "options:\r\n", 258 | " -h, --help show this help message and exit\r\n", 259 | "\r\n", 260 | "fastIE command line basic arguments:\r\n", 261 | " --config , -c The config file you want to use. default:\r\n", 262 | " --task , -t The task you want to use. Please use / to split the\r\n", 263 | " task and the specific solution. default:\r\n", 264 | " --dataset , -d The dataset you want to work with. default:\r\n", 265 | "\r\n", 266 | "Use pre-trained BERT and a classification head to classify tokens:\r\n", 267 | " --cuda [] Whether to use your NVIDIA graphics card to accelerate\r\n", 268 | " the process. 
278 |       "  --load_model          Load the model from the path or model name. default:\r\n",
279 |       "  --batch_size          Batch size. default: 32\r\n",
280 |       "  --shuffle []          Whether to shuffle the dataset. default: True\r\n",
281 |       "  --pretrained_model_name_or_path \r\n",
282 |       "                        name of transformer model (see https://huggingface.co/\r\n",
283 |       "                        transformers/pretrained_models.html for options).\r\n",
284 |       "                        default: bert-base-uncased\r\n"
285 |      ]
286 |     }
287 |    ],
288 |    "source": [
289 |     "!fastie-eval --task ner/bert --help"
290 |    ],
291 |    "metadata": {
292 |     "collapsed": false
293 |    }
294 |   },
295 |   {
296 |    "cell_type": "markdown",
297 |    "source": [
298 |     "This happens because `fastie` automatically loads the `task` or `dataset` you specify and\n",
299 |     "prints the help for its available arguments. So, whenever you use `fastie`, make good use of `--help`.\n",
300 |     "\n",
301 |     "Accordingly, if we want to evaluate an existing `bertNER` model on our own `jsonlines-ner`\n",
302 |     "dataset, we can first query which arguments we need to supply:"
303 |    ],
304 |    "metadata": {
305 |     "collapsed": false
306 |    }
307 |   },
308 |   {
309 |    "cell_type": "code",
310 |    "execution_count": 13,
311 |    "outputs": [
312 |     {
313 |      "name": "stdout",
314 |      "output_type": "stream",
315 |      "text": [
316 |       "usage: fastie-train [-h] [--config] [--task] [--dataset] [--cuda ]\r\n",
317 |       "                    [--load_model] [--batch_size] [--shuffle ]\r\n",
318 |       "                    [--pretrained_model_name_or_path] [--use_cache ]\r\n",
319 |       "                    [--refresh_cache ] [--folder] [--right_inclusive ]\r\n",
320 |       "\r\n",
321 |       "options:\r\n",
322 |       "  -h, --help            show this help message and exit\r\n",
323 |       "\r\n",
324 |       "fastIE command line basic arguments:\r\n",
325 |       "  --config , -c         The config file you want to use. default:\r\n",
326 |       "  --task , -t           The task you want to use. Please use / to split the\r\n",
327 |       "                        task and the specific solution. default:\r\n",
328 |       "  --dataset , -d        The dataset you want to work with. default:\r\n",
329 |       "\r\n",
330 |       "Use pre-trained BERT and a classification head to classify tokens:\r\n",
331 |       "  --cuda []             Whether to use your NVIDIA graphics card to accelerate\r\n",
332 |       "                        the process. default: False\r\n",
333 |       "  --load_model          Load the model from the path or model name. default:\r\n",
334 |       "  --batch_size          Batch size. default: 32\r\n",
335 |       "  --shuffle []          Whether to shuffle the dataset. default: True\r\n",
336 |       "  --pretrained_model_name_or_path \r\n",
337 |       "                        name of transformer model (see https://huggingface.co/\r\n",
338 |       "                        transformers/pretrained_models.html for options).\r\n",
339 |       "                        default: bert-base-uncased\r\n",
340 |       "\r\n",
341 |       "Dataset base class:\r\n",
342 |       "  --use_cache []        The result of data loading is cached for accelerated\r\n",
343 |       "                        reading the next time it is used. default: False\r\n",
344 |       "  --refresh_cache []    Clear cache (Use this when your data changes).\r\n",
345 |       "                        default: False\r\n",
346 |       "  --folder              The folder where the data set resides. We will\r\n",
347 |       "                        automatically read the possible train.jsonl,\r\n",
348 |       "                        dev.jsonl, test.jsonl and infer.jsonl in it. default:\r\n",
349 |       "  --right_inclusive []  When data is in the format of start and end, whether\r\n",
350 |       "                        each span contains the token corresponding to end.\r\n",
351 |       "                        default: True\r\n"
352 |      ]
353 |     }
354 |    ],
355 |    "source": [
356 |     "!fastie-eval --task ner/bert --dataset jsonlines-ner --help"
357 |    ],
358 |    "metadata": {
359 |     "collapsed": false
360 |    }
361 |   },
362 |   {
363 |    "cell_type": "markdown",
364 |    "source": [
365 |     "From the help information we can see that we need the `load_model` argument to load the model\n",
366 |     "trained by `fastie-train`, and that the `jsonlines-ner` dataset needs the `folder` argument to specify the dataset path. The full evaluation command is therefore:\n",
367 |     "\n",
368 |     "```bash\n",
369 |     "fastie-eval --task ner/bert --dataset jsonlines-ner --load_model /path/to/model --folder /path/to/dataset\n",
370 |     "```\n",
371 |     "\n",
372 |     "### 1.3 fastie-infer\n",
373 |     "\n",
374 |     "`fastie-infer` is the inference tool of `fastie`, used to predict on new data with a trained\n",
375 |     "model. It works like `fastie-eval`, except that inference datasets generally contain only the `tokens` field.\n",
376 |     "\n",
377 |     "`fastie` also provides the inference-only data structure `sentence` for quickly testing a single sequence, for example:\n",
378 |     "\n",
379 |     "```bash\n",
380 |     "fastie-infer --task ner/bert --load_model /path/to/model --dataset sentence --sentence \"I love fastie .\"\n",
381 |     "```\n",
382 |     "\n",
383 |     "### 1.4 fastie-interact\n",
384 |     "\n",
385 |     "`fastie-interact` is the interactive tool of `fastie`, used to interact with a trained model.\n",
386 |     "It works like `fastie-eval`, but since the dataset is the sequence you type in live, there is\n",
387 |     "no `dataset` argument. A usage sketch follows below.\n",
388 |     "\n",
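389 |     "(The command below is illustrative: the model path is a placeholder, and the exact\n",
390 |     "interactive prompt depends on your installed version.)\n",
391 |     "\n",
392 |     "```bash\n",
393 |     "fastie-interact --task ner/bert --load_model /path/to/model\n",
394 |     "```\n",
395 |     "\n",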
396 |     "## 2. Getting help on the command line\n",
397 |     "\n",
398 |     "Every `fastie` tool provides a `--help` argument for command-line help. The help grows as you\n",
399 |     "further specify a `task`, `dataset`, or `config`; for example, once you set `task` to `ner/bert`,\n",
400 |     "the help also shows the arguments required by `ner/bert`.\n",
401 |     "\n",
402 |     "## 3. Using config files\n",
403 |     "\n",
404 |     "As described in `1.1.3 config` above, a `config` file can store not only the basic arguments\n",
405 |     "`task` and `dataset`, but also the internal arguments defined by a `task` or `dataset`. For\n",
406 |     "example, you can set the `task`'s `batch_size` or the `dataset`'s `folder` in the `config`:\n",
407 |     "\n",
408 |     "```python\n",
409 |     "config = {\n",
410 |     "    \"task\": \"ner/bert\",\n",
411 |     "    \"dataset\": \"conll2003\",\n",
412 |     "    \"batch_size\": 32,\n",
413 |     "    \"folder\": \"/path/to/dataset\"\n",
414 |     "}\n",
415 |     "```\n",
416 |     "\n",
417 |     "Note that the arguments in a `config` file, the arguments shown by `help`, and the arguments of\n",
418 |     "each `task` or `dataset` class in the `SDK` are kept consistent, so whichever interface you use,\n",
419 |     "`--help` gives you the complete argument reference."
420 |    ],
421 |    "metadata": {
422 |     "collapsed": false
423 |    }
424 |   }
425 |  ],
426 |  "metadata": {
427 |   "kernelspec": {
428 |    "display_name": "Python 3",
429 |    "language": "python",
430 |    "name": "python3"
431 |   },
432 |   "language_info": {
433 |    "codemirror_mode": {
434 |     "name": "ipython",
435 |     "version": 2
436 |    },
437 |    "file_extension": ".py",
438 |    "mimetype": "text/x-python",
439 |    "name": "python",
440 |    "nbconvert_exporter": "python",
441 |    "pygments_lexer": "ipython2",
442 |    "version": "2.7.6"
443 |   }
444 |  },
445 |  "nbformat": 4,
446 |  "nbformat_minor": 0
447 | }
448 | 
--------------------------------------------------------------------------------