├── .circleci └── config.yml ├── .coveragerc ├── .github ├── ISSUE_TEMPLATE.md └── PULL_REQUEST_TEMPLATE.md ├── .gitignore ├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── activation_venv ├── activation_venv.bat ├── demo ├── atis_joint_model │ ├── atis_joint_config.json │ └── data_processor.py ├── configs │ ├── distributed_docnn.json │ ├── docnn.json │ ├── docnn_dense_feat.json │ ├── docnn_wo_export.json │ ├── export_list.json │ ├── export_options.json │ ├── lm.json │ ├── multitask_sst_lm.json │ ├── new_joint.json │ ├── rnng.json │ ├── seqnn.json │ └── word_tagging.json ├── datasource │ └── source.py ├── examples │ └── tensorizer.py ├── flask_server │ ├── atis.py │ ├── server.py │ └── setup.sh ├── my_tagging │ ├── metric.py │ ├── model.py │ ├── my_config.json │ ├── my_tagging_task.py │ ├── output.py │ └── source.py ├── notebooks │ ├── seq2seq_tutorial.ipynb │ └── xlm_r_tutorial.ipynb └── predictor_service │ ├── Dockerfile │ ├── Makefile │ ├── predictor.thrift │ └── server.cpp ├── docs_requirements.txt ├── install_deps ├── install_deps.bat ├── pytest.ini ├── pytext ├── __init__.py ├── builtin_task.py ├── common │ ├── __init__.py │ ├── constants.py │ └── utils.py ├── config │ ├── __init__.py │ ├── component.py │ ├── config_adapter.py │ ├── contextual_intent_slot.py │ ├── doc_classification.py │ ├── field_config.py │ ├── module_config.py │ ├── pair_classification.py │ ├── pytext_config.py │ ├── query_document_pairwise_ranking.py │ ├── serialize.py │ ├── test │ │ ├── component_test.py │ │ ├── config_adapter_test.py │ │ ├── json_config │ │ │ ├── v15_test_upgrade.json │ │ │ ├── v16_test_upgrade.json │ │ │ ├── v1_test_upgrade.json │ │ │ ├── v22_test_upgrade.json │ │ │ ├── v23_test_downgrade.json │ │ │ ├── v23_test_upgrade.json │ │ │ ├── v24_test_downgrade.json │ │ │ ├── v24_test_upgrade.json │ │ │ ├── v25_test_downgrade.json │ │ │ ├── v25_test_upgrade.json │ │ │ ├── v26_test_downgrade.json │ │ │ ├── v26_test_upgrade.json │ │ │ ├── 
v27_test_downgrade.json │ │ │ ├── v27_test_upgrade.json │ │ │ ├── v28_test_downgrade.json │ │ │ ├── v28_test_upgrade.json │ │ │ ├── v29_test_downgrade.json │ │ │ ├── v29_test_upgrade.json │ │ │ ├── v2_test_upgrade.json │ │ │ ├── v30_test_downgrade.json │ │ │ ├── v30_test_upgrade.json │ │ │ ├── v31_test_downgrade.json │ │ │ ├── v31_test_upgrade.json │ │ │ ├── v32_test_downgrade.json │ │ │ ├── v32_test_upgrade.json │ │ │ ├── v33_test_downgrade.json │ │ │ ├── v33_test_upgrade.json │ │ │ ├── v34_test_downgrade.json │ │ │ ├── v34_test_upgrade.json │ │ │ ├── v35_test_downgrade.json │ │ │ ├── v35_test_upgrade.json │ │ │ ├── v36_test_downgrade.json │ │ │ ├── v36_test_upgrade.json │ │ │ ├── v37_test_downgrade.json │ │ │ ├── v37_test_upgrade.json │ │ │ ├── v38_test_downgrade.json │ │ │ ├── v38_test_upgrade.json │ │ │ ├── v39_test_downgrade.json │ │ │ ├── v39_test_upgrade.json │ │ │ ├── v3_test_upgrade.json │ │ │ ├── v40_test_downgrade.json │ │ │ ├── v40_test_upgrade.json │ │ │ ├── v41_test_downgrade.json │ │ │ ├── v41_test_upgrade.json │ │ │ ├── v42_test_downgrade.json │ │ │ ├── v42_test_upgrade.json │ │ │ ├── v43_test_downgrade.json │ │ │ ├── v43_test_upgrade.json │ │ │ ├── v44_test_downgrade.json │ │ │ ├── v44_test_upgrade.json │ │ │ ├── v45_test_downgrade.json │ │ │ ├── v45_test_upgrade.json │ │ │ ├── v46_test_downgrade.json │ │ │ ├── v46_test_upgrade.json │ │ │ ├── v47_test_downgrade.json │ │ │ ├── v47_test_upgrade.json │ │ │ ├── v48_test_downgrade.json │ │ │ ├── v48_test_upgrade.json │ │ │ ├── v4_test_upgrade.json │ │ │ ├── v6_test_upgrade.json │ │ │ └── v8_test_upgrade.json │ │ ├── pytext_all_config_test.py │ │ ├── pytext_config_test.py │ │ └── serialize_test.py │ └── utils.py ├── data │ ├── __init__.py │ ├── batch_sampler.py │ ├── bert_tensorizer.py │ ├── data.py │ ├── data_handler.py │ ├── data_structures │ │ ├── __init__.py │ │ ├── annotation.py │ │ ├── node.py │ │ └── tests │ │ │ └── annotation_test.py │ ├── decoupled_data.py │ ├── dense_retrieval_tensorizer.py │ 
├── disjoint_multitask_data.py │ ├── disjoint_multitask_data_handler.py │ ├── dynamic_pooling_batcher.py │ ├── featurizer │ │ ├── __init__.py │ │ ├── featurizer.py │ │ └── simple_featurizer.py │ ├── masked_tensorizer.py │ ├── masked_util.py │ ├── packed_lm_data.py │ ├── pickleable_gpt2bpe_encoder │ │ ├── __init__.py │ │ └── pickleable_gpt2bpe_encoder.py │ ├── roberta_tensorizer.py │ ├── sources │ │ ├── __init__.py │ │ ├── conllu.py │ │ ├── data_source.py │ │ ├── dense_retrieval.py │ │ ├── pandas.py │ │ ├── session.py │ │ ├── squad.py │ │ └── tsv.py │ ├── squad_for_bert_tensorizer.py │ ├── squad_tensorizer.py │ ├── tensorizers.py │ ├── test │ │ ├── __init__.py │ │ ├── batch_sampler_test.py │ │ ├── data │ │ │ ├── gpt2_dict.txt │ │ │ ├── gpt2_encoder.json │ │ │ ├── gpt2_vocab.bpe │ │ │ ├── sentencepiece.model │ │ │ ├── sentencepiece_dict_1k.txt │ │ │ ├── spm_ontology.txt │ │ │ └── wordpiece_1k.txt │ │ ├── data_test.py │ │ ├── dynamic_pooling_batcher_test.py │ │ ├── mask_tensorizers_test.py │ │ ├── pandas_data_source_test.py │ │ ├── round_robin_batchiterator_test.py │ │ ├── simple_featurizer_test.py │ │ ├── tensorizers_test.py │ │ ├── tokenizers_test.py │ │ ├── tsv_data_source_test.py │ │ └── utils_test.py │ ├── token_tensorizer.py │ ├── tokenizers │ │ ├── __init__.py │ │ └── tokenizer.py │ ├── utils.py │ ├── xlm_constants.py │ ├── xlm_dictionary.py │ └── xlm_tensorizer.py ├── docs │ ├── Makefile │ ├── make_config_docs.py │ ├── origin │ │ ├── README │ │ └── pytext.odg │ ├── requirements.txt │ └── source │ │ ├── _static │ │ └── img │ │ │ ├── flask_www.png │ │ │ ├── ios_demo.png │ │ │ ├── pytext.png │ │ │ ├── pytext_design.png │ │ │ ├── tb_graph.png │ │ │ ├── tb_test_metrics.png │ │ │ └── tb_train_metrics.png │ │ ├── atis_tutorial.rst │ │ ├── conf.py │ │ ├── config_commands.rst │ │ ├── config_files.rst │ │ ├── create_new_model.rst │ │ ├── datasource_tutorial.rst │ │ ├── dense.rst │ │ ├── disjoint_multitask_tutorial.rst │ │ ├── distributed_training_tutorial.rst │ │ ├── 
execute_your_first_model.rst │ │ ├── hacking_pytext.rst │ │ ├── hierarchical_intent_slot_tutorial.rst │ │ ├── index.rst │ │ ├── installation.rst │ │ ├── overview.rst │ │ ├── pytext_models_in_your_app.rst │ │ ├── seq2seq_tutorial.rst │ │ ├── serving_models_in_production.rst │ │ ├── tensorizer.rst │ │ ├── train_your_first_model.rst │ │ ├── visualize_your_model.rst │ │ └── xlm_r.rst ├── exporters │ ├── __init__.py │ ├── custom_exporters.py │ ├── exporter.py │ └── test │ │ ├── new_text_model_exporter_test.py │ │ └── text_model_exporter_test.py ├── fields │ ├── __init__.py │ ├── char_field.py │ ├── contextual_token_embedding_field.py │ ├── dict_field.py │ ├── field.py │ ├── test │ │ ├── char_field_test.py │ │ ├── contextual_token_embedding_field_test.py │ │ ├── dict_field_test.py │ │ └── field_test.py │ └── text_field_with_special_unk.py ├── legacy │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── batch.py │ │ ├── dataset.py │ │ ├── example.py │ │ ├── field.py │ │ ├── iterator.py │ │ └── pipeline.py │ ├── datasets │ │ ├── __init__.py │ │ ├── babi.py │ │ ├── imdb.py │ │ ├── language_modeling.py │ │ ├── nli.py │ │ ├── sequence_tagging.py │ │ ├── sst.py │ │ ├── text_classification.py │ │ ├── translation.py │ │ ├── trec.py │ │ └── unsupervised_learning.py │ └── vocab.py ├── loss │ ├── __init__.py │ ├── loss.py │ ├── regularized_loss.py │ ├── regularizer.py │ ├── structured_loss.py │ └── tests │ │ ├── ctc_loss_test.py │ │ ├── focal_loss_test.py │ │ ├── label_smoothing_loss_test.py │ │ └── samplewise_label_smoothing_loss_test.py ├── main.py ├── metric_reporters │ ├── __init__.py │ ├── calibration_metric_reporter.py │ ├── channel.py │ ├── classification_metric_reporter.py │ ├── compositional_metric_reporter.py │ ├── compositional_utils.py │ ├── dense_retrieval_metric_reporter.py │ ├── disjoint_multitask_metric_reporter.py │ ├── intent_slot_detection_metric_reporter.py │ ├── language_model_metric_reporter.py │ ├── mask_compositional.py │ ├── mask_seq2seq_topk.py │ ├── 
metric_reporter.py │ ├── multi_span_qa_metric_reporter.py │ ├── pairwise_ranking_metric_reporter.py │ ├── regression_metric_reporter.py │ ├── seq2seq_compositional.py │ ├── seq2seq_metric_reporter.py │ ├── seq2seq_utils.py │ ├── squad_metric_reporter.py │ ├── tests │ │ ├── classification_metric_reporter_test.py │ │ ├── compositional_metric_reporter_test.py │ │ ├── intent_slot_metric_reporter_test.py │ │ ├── language_model_metric_reporter_test.py │ │ ├── multi_label_seq_tagging_metric_reporter_test.py │ │ └── tensorboard_test.py │ └── word_tagging_metric_reporter.py ├── metrics │ ├── __init__.py │ ├── calibration_metrics.py │ ├── dense_retrieval_metrics.py │ ├── intent_slot_metrics.py │ ├── language_model_metrics.py │ ├── mask_metrics.py │ ├── seq2seq_metrics.py │ ├── squad_metrics.py │ └── tests │ │ ├── basic_metrics_test.py │ │ ├── calibration_metrics_test.py │ │ ├── intent_slot_metrics_test.py │ │ ├── metrics_test_base.py │ │ └── multilabel_metrics_test.py ├── models │ ├── __init__.py │ ├── bert_classification_models.py │ ├── bert_regression_model.py │ ├── crf.py │ ├── decoders │ │ ├── __init__.py │ │ ├── decoder_base.py │ │ ├── intent_slot_model_decoder.py │ │ ├── mlp_decoder.py │ │ ├── mlp_decoder_n_tower.py │ │ ├── mlp_decoder_query_response.py │ │ ├── mlp_decoder_tri_tower.py │ │ ├── mlp_decoder_two_tower.py │ │ └── multilabel_decoder.py │ ├── disjoint_multitask_model.py │ ├── distributed_model.py │ ├── doc_model.py │ ├── embeddings │ │ ├── __init__.py │ │ ├── char_embedding.py │ │ ├── contextual_token_embedding.py │ │ ├── dict_embedding.py │ │ ├── embedding_base.py │ │ ├── embedding_list.py │ │ ├── int_single_category_embedding.py │ │ ├── int_weighted_multi_category_embedding.py │ │ ├── mlp_embedding.py │ │ ├── scriptable_embedding_list.py │ │ ├── word_embedding.py │ │ └── word_seq_embedding.py │ ├── ensembles │ │ ├── __init__.py │ │ ├── bagging_doc_ensemble.py │ │ ├── bagging_intent_slot_ensemble.py │ │ └── ensemble.py │ ├── joint_model.py │ ├── 
language_models │ │ ├── __init__.py │ │ └── lmlstm.py │ ├── masked_lm.py │ ├── masking_utils.py │ ├── model.py │ ├── module.py │ ├── output_layers │ │ ├── __init__.py │ │ ├── distance_output_layer.py │ │ ├── doc_classification_output_layer.py │ │ ├── doc_regression_output_layer.py │ │ ├── intent_slot_output_layer.py │ │ ├── lm_output_layer.py │ │ ├── multi_label_classification_layer.py │ │ ├── output_layer_base.py │ │ ├── pairwise_ranking_output_layer.py │ │ ├── squad_output_layer.py │ │ ├── utils.py │ │ └── word_tagging_output_layer.py │ ├── pair_classification_model.py │ ├── qna │ │ ├── __init__.py │ │ ├── bert_squad_qa.py │ │ └── dr_qa.py │ ├── query_document_pairwise_ranking_model.py │ ├── r3f_models.py │ ├── representations │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── augmented_lstm.py │ │ ├── bilstm.py │ │ ├── bilstm_doc_attention.py │ │ ├── bilstm_doc_slot_attention.py │ │ ├── bilstm_slot_attn.py │ │ ├── biseqcnn.py │ │ ├── contextual_intent_slot_rep.py │ │ ├── deepcnn.py │ │ ├── docnn.py │ │ ├── huggingface_bert_sentence_encoder.py │ │ ├── huggingface_electra_sentence_encoder.py │ │ ├── jointcnn_rep.py │ │ ├── lightconv.py │ │ ├── ordered_neuron_lstm.py │ │ ├── pair_rep.py │ │ ├── pass_through.py │ │ ├── pooling.py │ │ ├── pure_doc_attention.py │ │ ├── representation_base.py │ │ ├── seq_rep.py │ │ ├── slot_attention.py │ │ ├── sparse_transformer_sentence_encoder.py │ │ ├── stacked_bidirectional_rnn.py │ │ ├── test │ │ │ ├── augmented_lstm_test.py │ │ │ ├── ordered_neuron_lstm_test.py │ │ │ └── transformer_test.py │ │ ├── traced_transformer_encoder.py │ │ ├── transformer │ │ │ ├── __init__.py │ │ │ ├── luna_attention.py │ │ │ ├── luna_sentence_encoder.py │ │ │ ├── multihead_linear_attention.py │ │ │ ├── representation.py │ │ │ └── sentence_encoder.py │ │ ├── transformer_sentence_encoder.py │ │ └── transformer_sentence_encoder_base.py │ ├── roberta.py │ ├── semantic_parsers │ │ ├── __init__.py │ │ └── rnng │ │ │ ├── __init__.py │ │ │ ├── rnng_constant.py 
│ │ │ ├── rnng_data_structures.py │ │ │ └── rnng_parser.py │ ├── seq_models │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── base.py │ │ ├── benchmarks │ │ │ └── nar_top.json │ │ ├── contextual_intent_slot.py │ │ ├── conv_decoder.py │ │ ├── conv_encoder.py │ │ ├── conv_model.py │ │ ├── light_conv.py │ │ ├── mask_generator.py │ │ ├── nar_length.py │ │ ├── nar_modules.py │ │ ├── nar_output_layer.py │ │ ├── nar_seq2seq_model.py │ │ ├── positional.py │ │ ├── projection_layers.py │ │ ├── rnn_decoder.py │ │ ├── rnn_encoder.py │ │ ├── rnn_encoder_decoder.py │ │ ├── seq2seq_model.py │ │ ├── seq2seq_output_layer.py │ │ ├── seqnn.py │ │ └── utils.py │ ├── test │ │ ├── bilstm_test.py │ │ ├── crf_test.py │ │ ├── dict_embedding_test.py │ │ ├── int_single_category_embedding_test.py │ │ ├── int_weighted_multi_category_embedding_test.py │ │ ├── mlp_decoder_test.py │ │ ├── mlp_embedding_test.py │ │ ├── module_test.py │ │ ├── output_layer_test.py │ │ ├── personalized_doc_model_test.py │ │ ├── rnng_test.py │ │ ├── scripted_seq2seq_generator_test.py │ │ ├── transformer_sentence_encoder_test.py │ │ ├── word_embedding_test.py │ │ └── word_seq_embedding_test.py │ ├── tri_tower_classification_model.py │ ├── two_tower_classification_model.py │ ├── two_tower_regression_model.py │ ├── utils.py │ └── word_model.py ├── optimizer │ ├── __init__.py │ ├── activations.py │ ├── adabelief.py │ ├── fairseq_fp16_utils.py │ ├── fp16_optimizer.py │ ├── lamb.py │ ├── madgrad.py │ ├── optimizers.py │ ├── radam.py │ ├── scheduler.py │ ├── sparsifiers │ │ ├── __init__.py │ │ ├── blockwise_sparsifier.py │ │ ├── sparsifier.py │ │ └── tests │ │ │ └── sparsifier_test.py │ ├── swa.py │ └── tests │ │ ├── fp16optimizer_test.py │ │ └── test_swa.py ├── resources │ ├── __init__.py │ └── roberta.py ├── task │ ├── __init__.py │ ├── disjoint_multitask.py │ ├── new_task.py │ ├── nop_decorator.py │ ├── pytext_checkpoint_management.py │ ├── serialize.py │ ├── task.py │ └── tasks.py ├── torchscript │ ├── __init__.py │ ├── 
batchutils.py │ ├── module.py │ ├── seq2seq │ │ ├── __init__.py │ │ ├── beam_decode.py │ │ ├── beam_search.py │ │ ├── decoder.py │ │ ├── encoder.py │ │ ├── export_model.py │ │ ├── scripted_seq2seq_generator.py │ │ └── seq2seq_rnn_decoder_utils.py │ ├── tensorizer │ │ ├── __init__.py │ │ ├── bert.py │ │ ├── normalizer.py │ │ ├── roberta.py │ │ ├── tensorizer.py │ │ └── xlm.py │ ├── tests │ │ ├── test_batchutils.py │ │ ├── test_module.py │ │ ├── test_tensorizer.py │ │ ├── test_tokenizer.py │ │ └── test_vocab.py │ ├── tokenizer │ │ ├── __init__.py │ │ ├── bpe.py │ │ └── tokenizer.py │ ├── utils.py │ └── vocab.py ├── trainers │ ├── __init__.py │ ├── ensemble_trainer.py │ ├── hogwild_trainer.py │ ├── trainer.py │ └── training_state.py ├── utils │ ├── __init__.py │ ├── ascii_table.py │ ├── config_utils.py │ ├── cuda.py │ ├── data.py │ ├── distributed.py │ ├── documentation.py │ ├── file_io.py │ ├── label.py │ ├── lazy.py │ ├── loss.py │ ├── meter.py │ ├── mobile_onnx.py │ ├── model.py │ ├── onnx.py │ ├── path.py │ ├── precision.py │ ├── tensor.py │ ├── test.py │ ├── tests │ │ ├── ascii_table_test.py │ │ ├── embeddings_utils_test.py │ │ ├── label_test.py │ │ ├── lazy_test.py │ │ ├── path_test.py │ │ ├── timing_test.py │ │ └── utils_test.py │ ├── timing.py │ ├── torch.py │ ├── typing.py │ └── usage.py └── workflow.py ├── readthedocs.yml ├── requirements.txt ├── setup.py └── tests ├── __init__.py ├── data ├── alarm_lm_tiny.tsv ├── assistant_intent_slot_model_train.tsv ├── compositional_seq2seq_unit.tsv ├── contextual_intent_slot_test_tiny.tsv ├── contextual_intent_slot_train_tiny.tsv ├── contextual_intent_slot_train_tiny_dense.tsv ├── dummy_pretrained_embedding_dim4 ├── eval_data_tiny.tsv ├── export_seq2seq_unit.tsv ├── fl_test.tsv ├── knowledge_distillation_test_tiny.tsv ├── msg_topic_train.tsv ├── pairwise_classification.tsv ├── pretrained_embed_raw ├── query_document_pairwise_ranking_different_users.tsv ├── query_document_pairwise_ranking_one_user.tsv ├── 
query_document_pairwise_ranking_tiny.tsv ├── roberta_sp_vocab_small ├── seq2seq_model_unit.tsv ├── seq_tagging_example.tsv ├── squad_tiny.json ├── squad_tiny.tsv ├── sts_tiny.tsv ├── test_data_split_tiny.tsv ├── test_data_tiny.tsv ├── test_data_tiny_csv.tsv ├── test_data_tiny_fl.tsv ├── test_data_tiny_weights.tsv ├── test_dense_features_tiny.tsv ├── test_embed.cached ├── test_embed.raw ├── test_embed_xlu.cached ├── test_lm_tiny.tsv ├── test_lm_tiny_broadcast_data.tsv ├── test_lm_tiny_fl.tsv ├── test_music_samples.json ├── test_personalization_opposite_inputs.tsv ├── test_personalization_same_inputs.tsv ├── test_personalization_single_user.tsv ├── test_rnng.tsv ├── test_tiny.en ├── test_tsv_quoting.tsv ├── test_utf8_errors.tsv ├── train_data_tiny.tsv ├── train_data_tiny_weights.tsv ├── train_dense_features_and_text_tiny.tsv ├── train_dense_features_tiny.tsv ├── train_dense_features_tiny_fl.tsv ├── train_dict_features.tsv ├── train_dict_features_bad_json.tsv ├── train_seq_features.tsv ├── train_tiny_with_lang.tsv └── xlm_vocab_small ├── data_utils.py ├── main_test.py ├── model_utils_test.py ├── module_load_save_test.py ├── predictor_test.py ├── seq2seq_model_tests.py ├── task_load_save_test.py └── utils.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | branch = True 3 | concurrency = multiprocessing 4 | include = pytext/* 5 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Steps to reproduce 2 | 3 | 1. _____ 4 | 2. _____ 5 | 3. _____ 6 | 7 | ## Observed Results 8 | 9 | * What happened? This could be a description, log output, etc. 10 | ## Expected Results 11 | 12 | * What did you expect to happen? 
13 | 14 | ## Relevant Code 15 | 16 | ``` 17 | // TODO(you): code here to reproduce the problem 18 | ``` 19 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Motivation and Context 2 | 3 | 4 | 5 | 6 | 7 | ## How Has This Been Tested 8 | 9 | 10 | 11 | ## Types of changes 12 | 13 | 14 | - [ ] Docs change / refactoring / dependency upgrade 15 | - [ ] Bug fix (non-breaking change which fixes an issue) 16 | - [ ] New feature (non-breaking change which adds functionality) 17 | - [ ] Breaking change (fix or feature that would cause existing functionality to change) 18 | 19 | ## Checklist 20 | 21 | 22 | 23 | - [ ] My code follows the code style of this project. 24 | - [ ] My change requires a change to the documentation. 25 | - [ ] I have updated the documentation accordingly. 26 | - [ ] I have read the [**CONTRIBUTING**](https://github.com/facebookresearch/pytext/blob/master/CONTRIBUTING.md) document. 27 | - [ ] I have completed my CLA (see [**CLA**](https://github.com/facebookresearch/pytext/blob/master/CONTRIBUTING.md#contributor-license-agreement-cla)) 28 | - [ ] I have added tests to cover my changes. 29 | - [ ] All new and existing tests passed. 
30 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[codi] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | eggs/ 15 | .eggs/ 16 | *.egg-info/ 17 | .installed.cfg 18 | *.egg 19 | 20 | # Environments 21 | .env 22 | .venv 23 | pytext_venv/ 24 | env/ 25 | venv/ 26 | ENV/ 27 | env.bak/ 28 | venv.bak/ 29 | 30 | # Backups 31 | *~ 32 | 33 | # Coverage Files 34 | .coverage 35 | 36 | # Generated Document Sources 37 | pytext/docs/source/configs/* 38 | pytext/docs/source/modules/* 39 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to PyText 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: <https://code.facebook.com/cla> 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue.
24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## Coding Style 30 | We use isort and black to format our code. You can use the following commands to format your code prior to submission: 31 | 32 | ``` 33 | (pytext_venv) $ pip install isort black 34 | (pytext_venv) $ black pytext 35 | (pytext_venv) $ isort pytext --recursive --multi-line 3 --trailing-comma --force-grid-wrap 0 --line-width 88 --lines-after-imports 2 --combine-as --section-default THIRDPARTY 36 | ``` 37 | 38 | ## Updates to Docs 39 | The documentation build process works with Python 3.7 and above. 40 | 41 | ## License 42 | By contributing to PyText, you agree that your contributions will be licensed 43 | under the LICENSE file in the root directory of this source tree. 44 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For PyText software 4 | 5 | Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission.
20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /activation_venv: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | VENV_NAME=${1:-pytext_venv} 3 | 4 | if [ ! -d "$VENV_NAME" ] 5 | then 6 | python3 -m venv "$VENV_NAME" 7 | fi 8 | # shellcheck source=/dev/null 9 | source "$VENV_NAME/bin/activate" 10 | -------------------------------------------------------------------------------- /activation_venv.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | ::Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | ::Use venv name if passed, otherwise default 4 | IF "%1"=="" ( 5 | SET "_PYTEXT_ENV_NAME_=pytext_venv" 6 | ) ELSE ( 7 | SET "_PYTEXT_ENV_NAME_=%1" 8 | ) 9 | 10 | IF NOT EXIST %_PYTEXT_ENV_NAME_% ( 11 | python -m venv %_PYTEXT_ENV_NAME_% 12 | ) 13 | 14 | call %_PYTEXT_ENV_NAME_%\Scripts\activate.bat 15 | -------------------------------------------------------------------------------- /demo/configs/distributed_docnn.json: -------------------------------------------------------------------------------- 1 | { 2 | "config": { 3 | "task": { 4 | "DocClassificationTask": { 5 | "data_handler": { 6 | "columns_to_read": [ 7 | "text", 8 | "doc_label" 9 | ], 10 | "train_path": "base_dir/train_tiny.tsv", 11 | "eval_path": "base_dir/test_tiny.tsv", 12 | "test_path": "base_dir/test_tiny.tsv" 13 | } 14 | } 15 | }, 16 | "use_cuda_if_available": true, 17 | "distributed_world_size": 8 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /demo/configs/docnn.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": 8, 3 | "task": { 4 | "DocumentClassificationTask": { 5 | "data": { 6 | "source": { 7 | "TSVDataSource": { 8 | "field_names": ["label", "slots", "text"], 9 | "train_filename": "tests/data/train_data_tiny.tsv", 10 | "test_filename": "tests/data/test_data_tiny.tsv", 11 | "eval_filename": "tests/data/test_data_tiny.tsv" 12 | } 13 | } 14 | }, 15 | "model": { 16 | "DocModel": { 17 | "representation": { 18 | "DocNNRepresentation": {} 19 | } 20 | } 21 | } 22 | } 23 | }, 24 | "save_snapshot_path": "/tmp/model.pt", 25 | "export_torchscript_path": "/tmp/new_docnn.pt1", 26 | "export_caffe2_path": "/tmp/model.caffe2.predictor" 27 | } 28 | -------------------------------------------------------------------------------- /demo/configs/docnn_dense_feat.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": { 3 | 
"DocClassificationTask": { 4 | "features" : { 5 | "dense_feat" : { 6 | "column": "dense_feat", 7 | "dim": 10 8 | } 9 | }, 10 | "data_handler": { 11 | "columns_to_read": ["doc_label", "text", "dict_feat", "dense_feat"], 12 | "train_path": "tests/data/train_dense_features_tiny.tsv", 13 | "eval_path": "tests/data/test_dense_features_tiny.tsv", 14 | "test_path": "tests/data/test_dense_features_tiny.tsv" 15 | } 16 | } 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /demo/configs/docnn_wo_export.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": 8, 3 | "task": { 4 | "DocumentClassificationTask": { 5 | "data": { 6 | "source": { 7 | "TSVDataSource": { 8 | "field_names": ["label", "slots", "text"], 9 | "train_filename": "tests/data/train_data_tiny.tsv", 10 | "test_filename": "tests/data/test_data_tiny.tsv", 11 | "eval_filename": "tests/data/test_data_tiny.tsv" 12 | } 13 | } 14 | }, 15 | "model": { 16 | "DocModel": { 17 | "representation": { 18 | "DocNNRepresentation": {} 19 | } 20 | } 21 | } 22 | } 23 | }, 24 | "save_snapshot_path": "/tmp/model.pt" 25 | } 26 | -------------------------------------------------------------------------------- /demo/configs/export_list.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": 26, 3 | "export_list": [ 4 | { 5 | "accelerate": [ 6 | "nnpi" 7 | ], 8 | "seq_padding_control": [ 9 | 0, 10 | 16, 11 | 40, 12 | 512 13 | ], 14 | "batch_padding_control": [ 15 | 0, 16 | 1, 17 | 2 18 | ], 19 | "export_torchscript_path": "/tmp/intentslot-nnpi.pt1", 20 | "torchscript_quantize": false, 21 | "target": "nnpi" 22 | }, 23 | { 24 | "export_torchscript_path": "/tmp/intentslot-cpu.pt1", 25 | "torchscript_quantize": false, 26 | "target": "cpu" 27 | }, 28 | { 29 | "torchscript_quantize": false, 30 | "export_torchscript_path": "/tmp/intentslot-cuda-fp32.pt1", 31 | "target": "gpu-fp32" 32 | }, 33 | 
{ 34 | "export_caffe2_path": null, 35 | "export_torchscript_path": "/tmp/intentslot-cuda-fp16.pt1", 36 | "export_lite_path": null, 37 | "torchscript_quantize": false, 38 | "accelerate": [ 39 | "cuda:half" 40 | ], 41 | "inference_interface": null, 42 | "seq_padding_control": null, 43 | "batch_padding_control": null, 44 | "target": "gpu-fp16" 45 | } 46 | ] 47 | } 48 | -------------------------------------------------------------------------------- /demo/configs/export_options.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": 23, 3 | "export": { 4 | "export_torchscript_path": "/tmp/new_docnn.pt1", 5 | "export_caffe2_path": "/tmp/model.caffe2.predictor" 6 | } 7 | } 8 | -------------------------------------------------------------------------------- /demo/configs/lm.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": 8, 3 | "task": { 4 | "LMTask": { 5 | "data": { 6 | "source": { 7 | "TSVDataSource": { 8 | "field_names": ["text"], 9 | "train_filename": "tests/data/alarm_lm_tiny.tsv", 10 | "test_filename": "tests/data/alarm_lm_tiny.tsv", 11 | "eval_filename": "tests/data/alarm_lm_tiny.tsv" 12 | } 13 | } 14 | }, 15 | "model": { 16 | "tied_weights": true, 17 | "stateful": true, 18 | "embedding": { 19 | "embed_dim": 10 20 | }, 21 | "decoder": { 22 | "hidden_dims": [ 23 | 10 24 | ] 25 | } 26 | } 27 | } 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /demo/configs/multitask_sst_lm.json: -------------------------------------------------------------------------------- 1 | { 2 | "config": { 3 | "version": 19, 4 | "task": { 5 | "NewDisjointMultitask": { 6 | "trainer": { 7 | "epochs": 5 8 | }, 9 | "task_weights": { 10 | "SST2": 1, 11 | "LM": 1 12 | }, 13 | "tasks": { 14 | "SST2": { 15 | "DocumentClassificationTask": { 16 | "data": { 17 | "source": { 18 | "TSVDataSource": { 19 | "field_names": ["label", "slots", 
"text"], 20 | "train_filename": "tests/data/train_data_tiny.tsv", 21 | "test_filename": "tests/data/test_data_tiny.tsv", 22 | "eval_filename": "tests/data/test_data_tiny.tsv" 23 | } 24 | } 25 | }, 26 | "model": { 27 | "DocModel": { 28 | "representation": { 29 | "BiLSTMDocAttention": { 30 | "lstm": { 31 | "shared_module_key": "LSTM" 32 | } 33 | } 34 | } 35 | } 36 | } 37 | } 38 | }, 39 | "LM": { 40 | "LMTask": { 41 | "data": { 42 | "source": { 43 | "TSVDataSource": { 44 | "field_names": ["text"], 45 | "train_filename": "tests/data/alarm_lm_tiny.tsv", 46 | "test_filename": "tests/data/alarm_lm_tiny.tsv", 47 | "eval_filename": "tests/data/alarm_lm_tiny.tsv" 48 | } 49 | } 50 | }, 51 | "model": { 52 | "representation": { 53 | "shared_module_key": "LSTM" 54 | } 55 | } 56 | } 57 | } 58 | } 59 | } 60 | } 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /demo/configs/new_joint.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": { 3 | "IntentSlotTask": { 4 | "data": { 5 | "source": { 6 | "TSVDataSource": { 7 | "field_names": ["label", "slots", "text", "doc_weight", "word_weight"], 8 | "train_filename": "tests/data/train_data_tiny_weights.tsv", 9 | "test_filename": "tests/data/test_data_tiny_weights.tsv", 10 | "eval_filename": "tests/data/test_data_tiny_weights.tsv" 11 | } 12 | }, 13 | "sort_key": "tokens" 14 | } 15 | } 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /demo/configs/rnng.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": { 3 | "SemanticParsingTask": { 4 | "data": { 5 | "batcher": { 6 | "PoolingBatcher": { 7 | "eval_batch_size": 1, 8 | "test_batch_size": 1, 9 | "train_batch_size": 1 10 | } 11 | }, 12 | "source": { 13 | "TSVDataSource": { 14 | "field_names": ["text", "tokenized_text", "seqlogical"], 15 | "train_filename": "tests/data/test_rnng.tsv", 16 | 
"test_filename": "tests/data/test_rnng.tsv", 17 | "eval_filename": "tests/data/test_rnng.tsv" 18 | } 19 | } 20 | }, 21 | "model": { 22 | "lstm": { 23 | "dropout": 0.34, 24 | "lstm_dim": 16, 25 | "num_layers": 2, 26 | "bidirectional": true 27 | }, 28 | "ablation": { 29 | "use_buffer": true, 30 | "use_stack": true, 31 | "use_action": true, 32 | "use_last_open_NT_feature": false 33 | }, 34 | "constraints": { 35 | "intent_slot_nesting": true, 36 | "ignore_loss_for_unsupported": false, 37 | "no_slots_inside_unsupported": true 38 | }, 39 | "max_open_NT": 10, 40 | "dropout": 0.34, 41 | "compositional_type": "sum" 42 | }, 43 | "metric_reporter": { 44 | "text_column_name": "tokenized_text" 45 | }, 46 | "trainer": { 47 | "real_trainer": { 48 | "report_train_metrics": false, 49 | "epochs": 1 50 | } 51 | } 52 | } 53 | }, 54 | "version": 12 55 | } 56 | -------------------------------------------------------------------------------- /demo/configs/seqnn.json: -------------------------------------------------------------------------------- 1 | { 2 | "version":7, 3 | "task": { 4 | "SeqNNTask": { 5 | "data": { 6 | "source": { 7 | "TSVDataSource": { 8 | "field_names": ["label", "text_seq"], 9 | "train_filename": "tests/data/msg_topic_train.tsv", 10 | "test_filename": "tests/data/msg_topic_train.tsv", 11 | "eval_filename": "tests/data/msg_topic_train.tsv" 12 | } 13 | } 14 | }, 15 | "trainer": { 16 | "epochs": 2 17 | }, 18 | "model": { 19 | "representation": { 20 | "doc_representation": { 21 | "dropout": 0.5 22 | }, 23 | "seq_representation": { 24 | "BiLSTMDocAttention": { 25 | "dropout": 0.5 26 | } 27 | } 28 | }, 29 | "output_layer": { 30 | "loss": { 31 | "BinaryCrossEntropyLoss": {} 32 | } 33 | }, 34 | "decoder": { 35 | "hidden_dims": [ 36 | 100 37 | ] 38 | } 39 | } 40 | } 41 | }, 42 | "use_cuda_if_available": false 43 | } 44 | -------------------------------------------------------------------------------- /demo/configs/word_tagging.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "version": 7, 3 | "task": { 4 | "WordTaggingTask": { 5 | "data": { 6 | "source": { 7 | "TSVDataSource": { 8 | "field_names": ["label", "slots", "text"], 9 | "train_filename": "tests/data/train_data_tiny.tsv", 10 | "test_filename": "tests/data/test_data_tiny.tsv", 11 | "eval_filename": "tests/data/test_data_tiny.tsv" 12 | } 13 | } 14 | }, 15 | "trainer": { 16 | "epochs": 2 17 | } 18 | } 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /demo/examples/tensorizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | from pytext.data.tensorizers import Tensorizer 5 | from pytext.data.utils import pad_and_tensorize, VocabBuilder 6 | 7 | 8 | class MyWordTensorizer(Tensorizer): 9 | """ 10 | Simple Tensorizer that splits a sentence on spaces and create tensors 11 | from the vocabulary index built on the training data. 12 | """ 13 | 14 | class Config(Tensorizer.Config): 15 | #: The name of the text column to parse from the data source. 
16 | column: str = "text" 17 | 18 | @classmethod 19 | def from_config(cls, config: Config): 20 | return cls(column=config.column) 21 | 22 | def __init__(self, column): 23 | self.column = column 24 | self.vocab = None 25 | 26 | @property 27 | def column_schema(self): 28 | return [(self.column, str)] 29 | 30 | def _tokenize(self, row): 31 | raw_text = row[self.column] 32 | return raw_text.split() 33 | 34 | def initialize(self): 35 | """Build vocabulary based on training corpus.""" 36 | vocab_builder = VocabBuilder() 37 | 38 | try: 39 | while True: 40 | row = yield 41 | words = self._tokenize(row) 42 | vocab_builder.add_all(words) 43 | except GeneratorExit: 44 | self.vocab = vocab_builder.make_vocab() 45 | 46 | def numberize(self, row): 47 | """Look up tokens in vocabulary to get their corresponding index""" 48 | words = self._tokenize(row) 49 | idx = self.vocab.lookup_all(words) 50 | # LSTM representations need the length of the sequence 51 | return idx, len(idx) 52 | 53 | def tensorize(self, batch): 54 | tokens, seq_lens = zip(*batch) 55 | return ( 56 | pad_and_tensorize(tokens, self.vocab.get_pad_index()), 57 | pad_and_tensorize(seq_lens), 58 | ) 59 | 60 | def sort_key(self, row): 61 | # LSTM representations need the batches to be sorted by descending seq_len 62 | return row[1] 63 | -------------------------------------------------------------------------------- /demo/flask_server/server.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | import json 5 | 6 | import atis 7 | from flask import Flask, request 8 | 9 | 10 | app = Flask(__name__) 11 | 12 | 13 | @app.route("/") 14 | def predict(): 15 | return json.dumps(atis.predict(request.args.get("text", ""))) 16 | 17 | 18 | if __name__ == "__main__": 19 | app.run(host="0.0.0.0", port=3000) 20 | -------------------------------------------------------------------------------- /demo/flask_server/setup.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh 5 | chmod +x miniconda.sh 6 | ./miniconda.sh -b -p ~/miniconda 7 | rm -f miniconda.sh 8 | source miniconda/bin/activate 9 | 10 | conda install -y protobuf 11 | conda install -y boto3 flask future numpy pip 12 | conda install -y pytorch -c pytorch 13 | 14 | sudo iptables -t nat -A PREROUTING -p tcp --dport 80 -j REDIRECT --to 10000 15 | -------------------------------------------------------------------------------- /demo/my_tagging/metric.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | import itertools 5 | 6 | from pytext.metric_reporters.channel import ConsoleChannel, TensorBoardChannel 7 | from pytext.metric_reporters.metric_reporter import MetricReporter 8 | from pytext.metrics import compute_classification_metrics, LabelPrediction 9 | 10 | 11 | class MyTaggingMetricReporter(MetricReporter): 12 | @classmethod 13 | def from_config0(cls, config, vocab): 14 | return MyTaggingMetricReporter( 15 | channels=[ConsoleChannel(), TensorBoardChannel()], label_names=vocab 16 | ) 17 | 18 | @classmethod 19 | def from_config(cls, config, tensorizers): 20 | return MyTaggingMetricReporter( 21 | channels=[ConsoleChannel(), TensorBoardChannel()], 22 | label_names=tensorizers["slots"].vocab, 23 | ) 24 | 25 | def __init__(self, label_names, channels): 26 | super().__init__(channels) 27 | self.label_names = label_names 28 | 29 | def calculate_metric(self): 30 | return compute_classification_metrics( 31 | list( 32 | itertools.chain.from_iterable( 33 | (LabelPrediction(s, p, e) for s, p, e in zip(scores, pred, expect)) 34 | for scores, pred, expect in zip( 35 | self.all_scores, self.all_preds, self.all_targets 36 | ) 37 | ) 38 | ), 39 | self.label_names, 40 | self.calculate_loss(), 41 | ) 42 | 43 | # def batch_context(self, batch): 44 | # return {} 45 | 46 | @staticmethod 47 | def get_model_select_metric(metrics): 48 | return metrics.accuracy 49 | -------------------------------------------------------------------------------- /demo/my_tagging/my_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "debug_path": "my_model.debug", 3 | "export_caffe2_path": "my_model.caffe2.predictor", 4 | "export_onnx_path": "my_model.onnx", 5 | "save_snapshot_path": "my_model.pt", 6 | "task": { 7 | "MyTaggingTask": { 8 | "data": { 9 | "Data": { 10 | "source": { 11 | "AtisSlotsDataSource": { 12 | "field_names": ["text", "slots"], 13 | "path": "/home/egaudet/atis" 14 | } 15 | } 16 | } 17 | }, 18 | 
"metric_reporter": { 19 | "output_path": "my_test.out" 20 | }, 21 | "model": { 22 | "embedding": { 23 | "embed_dim": 100 24 | } 25 | }, 26 | "trainer": { 27 | "epochs": 3 28 | } 29 | } 30 | }, 31 | "test_out_path": "my_test_out.txt", 32 | "use_cuda_if_available": true, 33 | "use_fp16": false, 34 | "use_tensorboard": true, 35 | "version": 12 36 | } 37 | -------------------------------------------------------------------------------- /demo/my_tagging/my_tagging_task.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | from pytext.metric_reporters.channel import ConsoleChannel, TensorBoardChannel 5 | from pytext.task.new_task import NewTask 6 | 7 | from .metric import MyTaggingMetricReporter 8 | from .model import MyTaggingModel 9 | 10 | 11 | class MyTaggingTask(NewTask): 12 | class Config(NewTask.Config): 13 | model: MyTaggingModel.Config = MyTaggingModel.Config() 14 | metric_reporter: MyTaggingMetricReporter.Config = ( 15 | MyTaggingMetricReporter.Config() 16 | ) 17 | 18 | @classmethod 19 | def create_metric_reporter(cls, config, tensorizers): 20 | return MyTaggingMetricReporter( 21 | channels=[ConsoleChannel(), TensorBoardChannel()], 22 | label_names=list(tensorizers["slots"].vocab), 23 | ) 24 | -------------------------------------------------------------------------------- /demo/my_tagging/output.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from pytext.config.component import create_loss 7 | from pytext.loss import CrossEntropyLoss 8 | from pytext.models.output_layers.output_layer_base import OutputLayerBase 9 | 10 | 11 | class MyTaggingOutputLayer(OutputLayerBase): 12 | class Config(OutputLayerBase.Config): 13 | loss: CrossEntropyLoss.Config = CrossEntropyLoss.Config() 14 | 15 | @classmethod 16 | def from_config(cls, config, vocab, pad_token): 17 | return cls(vocab, create_loss(config.loss, ignore_index=pad_token)) 18 | 19 | def get_loss(self, logit, target, context, reduce=True): 20 | # flatten the logit from [batch_size, seq_lens, dim] to 21 | # [batch_size * seq_lens, dim] 22 | return self.loss_fn(logit.view(-1, logit.size()[-1]), target.view(-1), reduce) 23 | 24 | def get_pred(self, logit, *args, **kwargs): 25 | preds = torch.max(logit, 2)[1] 26 | scores = F.log_softmax(logit, 2) 27 | return preds, scores 28 | -------------------------------------------------------------------------------- /demo/predictor_service/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | FROM ubuntu:18.04 4 | 5 | # Install dependencies 6 | RUN apt-get update && apt-get install -y --no-install-recommends \ 7 | build-essential \ 8 | ca-certificates \ 9 | cmake \ 10 | curl \ 11 | git \ 12 | libcurl4-openssl-dev \ 13 | libgflags-dev \ 14 | unzip 15 | 16 | # Install Thrift + dependencies 17 | WORKDIR / 18 | RUN apt-get update && apt-get install -y \ 19 | libboost-dev \ 20 | libboost-test-dev \ 21 | libboost-program-options-dev \ 22 | libboost-filesystem-dev \ 23 | libboost-thread-dev \ 24 | libevent-dev \ 25 | automake \ 26 | libtool \ 27 | flex \ 28 | bison \ 29 | pkg-config \ 30 | libssl-dev \ 31 | && rm -rf /var/lib/apt/lists/* 32 | RUN curl https://downloads.apache.org/thrift/0.13.0/thrift-0.13.0.tar.gz --output thrift-0.13.0.tar.gz \ 33 | && tar -xvf thrift-0.13.0.tar.gz \ 34 | && rm thrift-0.13.0.tar.gz 35 | WORKDIR /thrift-0.13.0 36 | RUN ./bootstrap.sh \ 37 | && ./configure \ 38 | && make \ 39 | && make install 40 | 41 | # Install Pistache (C++ REST framework) 42 | WORKDIR / 43 | RUN git clone https://github.com/oktal/pistache.git 44 | WORKDIR /pistache 45 | RUN git submodule update --init \ 46 | && mkdir build 47 | WORKDIR /pistache/build 48 | RUN cmake -G "Unix Makefiles" -DCMAKE_BUILD_TYPE=Release .. \ 49 | && make \ 50 | && make install 51 | 52 | # Install libtorch 53 | WORKDIR / 54 | RUN curl https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-1.4.0%2Bcpu.zip --output libtorch.zip \ 55 | && unzip libtorch.zip \ 56 | && rm libtorch.zip 57 | 58 | # Copy local files to /app 59 | COPY . 
/app 60 | WORKDIR /app 61 | 62 | # Compile app 63 | RUN thrift -r --gen cpp predictor.thrift 64 | RUN make 65 | 66 | # Add library search paths 67 | ENV LD_LIBRARY_PATH /libtorch/lib:/usr/local/lib 68 | 69 | # Expose ports for Thrift and REST 70 | EXPOSE 9090 71 | EXPOSE 8080 72 | -------------------------------------------------------------------------------- /demo/predictor_service/Makefile: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | CPPFLAGS += -g -std=c++11 -std=c++14 \ 4 | -I./gen-cpp \ 5 | -I/libtorch/include \ 6 | -Wno-deprecated-declarations 7 | CLIENT_LDFLAGS += -lthrift 8 | SERVER_LDFLAGS += -L/libtorch/lib \ 9 | -lthrift -lpistache -lpthread -ltorch -lc10 -lcurl -lgflags 10 | 11 | server: server.o gen-cpp/Predictor.o 12 | g++ $^ $(SERVER_LDFLAGS) -o $@ 13 | 14 | clean: 15 | rm -f *.o server 16 | -------------------------------------------------------------------------------- /demo/predictor_service/predictor.thrift: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | namespace cpp predictor_service 4 | 5 | service Predictor { 6 | // Returns scores for each class 7 | map predict(1: string doc); 8 | } 9 | -------------------------------------------------------------------------------- /docs_requirements.txt: -------------------------------------------------------------------------------- 1 | click 2 | fairseq 3 | future 4 | hypothesis<4.0 5 | iopath 6 | mock 7 | numpy 8 | onnx 9 | pytorch-pretrained-bert 10 | requests 11 | sentencepiece 12 | torchtext 13 | tensorboard 14 | torch 15 | transformers==3.4.0 16 | pandas 17 | -------------------------------------------------------------------------------- /install_deps: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | pip install --upgrade pip 3 | if [[ "$OSTYPE" == "darwin"* ]]; then 4 | # Mac OSX, need to specify the CFLAGS for fairseq https://github.com/pytorch/fairseq 5 | echo "Mac OS" 6 | CFLAGS="-stdlib=libc++" pip install -e . --upgrade --no-cache-dir --progress-bar off --upgrade-strategy eager 7 | else 8 | # Any other OS 9 | pip install -e . --upgrade --no-cache-dir --progress-bar off --upgrade-strategy eager 10 | fi 11 | -------------------------------------------------------------------------------- /install_deps.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | ::Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | python -m pip install --upgrade pip 4 | pip install -e . 
--process-dependency-links --no-cache-dir --progress-bar off 5 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | filterwarnings = 3 | ignore::DeprecationWarning:.*caffe2.* -------------------------------------------------------------------------------- /pytext/common/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | from .constants import ( 5 | BatchContext, 6 | DatasetFieldName, 7 | DFColumn, 8 | matcha_entity_high_level_domains, 9 | matcha_entity_raw_domains, 10 | PackageFileName, 11 | Padding, 12 | Stage, 13 | VocabMeta, 14 | ) 15 | -------------------------------------------------------------------------------- /pytext/common/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | from sys import stderr 4 | 5 | 6 | def eprint(*args, **kwargs): 7 | print(file=stderr, *args, **kwargs) 8 | -------------------------------------------------------------------------------- /pytext/config/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | from .config_adapter import upgrade_to_latest # noqa 5 | from .pytext_config import ( # noqa 6 | ConfigBase, 7 | ExportConfig, 8 | LATEST_VERSION, 9 | PyTextConfig, 10 | TestConfig, 11 | ) 12 | from .serialize import config_from_json, config_to_json, pytext_config_from_json # noqa 13 | -------------------------------------------------------------------------------- /pytext/config/contextual_intent_slot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | from typing import List, Optional 4 | 5 | from .field_config import ( 6 | CharFeatConfig, 7 | ContextualTokenEmbeddingConfig, 8 | DictFeatConfig, 9 | FloatVectorConfig, 10 | TargetConfigBase, 11 | WordFeatConfig, 12 | ) 13 | from .module_config import ModuleConfig 14 | 15 | 16 | class ModelInputConfig(ModuleConfig): 17 | word_feat: Optional[WordFeatConfig] = WordFeatConfig() 18 | dict_feat: Optional[DictFeatConfig] = None 19 | char_feat: Optional[CharFeatConfig] = None 20 | contextual_token_embedding: Optional[ContextualTokenEmbeddingConfig] = None 21 | seq_word_feat: Optional[WordFeatConfig] = WordFeatConfig() 22 | dense_feat: Optional[FloatVectorConfig] = None 23 | 24 | 25 | TargetConfig = List[TargetConfigBase] 26 | 27 | 28 | class ModelInput: 29 | TEXT = "word_feat" 30 | DICT = "dict_feat" 31 | CHAR = "char_feat" 32 | CONTEXTUAL_TOKEN_EMBEDDING = "contextual_token_embedding" 33 | SEQ = "seq_word_feat" 34 | DENSE = "dense_feat" 35 | 36 | 37 | class ExtraField: 38 | DOC_WEIGHT = "doc_weight" 39 | WORD_WEIGHT = "word_weight" 40 | RAW_WORD_LABEL = "raw_word_label" 41 | TOKEN_RANGE = "token_range" 42 | UTTERANCE = "utterance" 43 | -------------------------------------------------------------------------------- /pytext/config/doc_classification.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | from typing import Optional 5 | 6 | from .field_config import ( 7 | CharFeatConfig, 8 | ContextualTokenEmbeddingConfig, 9 | DictFeatConfig, 10 | DocLabelConfig, 11 | FloatVectorConfig, 12 | WordFeatConfig, 13 | ) 14 | from .module_config import ModuleConfig 15 | 16 | 17 | class ModelInputConfig(ModuleConfig): 18 | word_feat: WordFeatConfig = WordFeatConfig() 19 | dict_feat: Optional[DictFeatConfig] = None 20 | char_feat: Optional[CharFeatConfig] = None 21 | contextual_token_embedding: Optional[ContextualTokenEmbeddingConfig] = None 22 | dense_feat: Optional[FloatVectorConfig] = None 23 | 24 | 25 | TargetConfig = DocLabelConfig 26 | 27 | 28 | class ModelInput: 29 | WORD_FEAT = "word_feat" 30 | DICT_FEAT = "dict_feat" 31 | CHAR_FEAT = "char_feat" 32 | CONTEXTUAL_TOKEN_EMBEDDING = "contextual_token_embedding" 33 | SEQ_LENS = "seq_lens" 34 | DENSE_FEAT = "dense_feat" 35 | 36 | 37 | class ExtraField: 38 | RAW_TEXT = "text" 39 | -------------------------------------------------------------------------------- /pytext/config/module_config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | from enum import Enum 4 | from typing import List, Optional 5 | 6 | from .pytext_config import ConfigBase 7 | 8 | 9 | class ModuleConfig(ConfigBase): 10 | # Checkpoint load path 11 | load_path: Optional[str] = None 12 | # Checkpoint save path, relative to PyTextConfig.modules_save_dir (if set) 13 | save_path: Optional[str] = None 14 | # Freezing a module means its parameters won't be updated during training. 
15 | freeze: bool = False 16 | # modules which have the same shared_module_key and type share parameters 17 | shared_module_key: Optional[str] = None 18 | 19 | 20 | class CNNParams(ConfigBase): 21 | # Number of feature maps for each kernel 22 | kernel_num: int = 100 23 | # Kernel sizes to use in convolution 24 | kernel_sizes: List[int] = [3, 4] 25 | # Use weight norm in convolution 26 | weight_norm: bool = False 27 | # Enables dilated convolutions 28 | dilated: bool = False 29 | # Enables causal convolutions 30 | causal: bool = False 31 | 32 | 33 | class PoolingType(Enum): 34 | MEAN = "mean" 35 | MAX = "max" 36 | LOGSUMEXP = "logsumexp" 37 | NONE = "none" 38 | 39 | 40 | class SlotAttentionType(Enum): 41 | NO_ATTENTION = "no_attention" 42 | CONCAT = "concat" 43 | MULTIPLY = "multiply" 44 | DOT = "dot" 45 | 46 | 47 | class PerplexityType(Enum): 48 | MIN = "min" 49 | MAX = "max" 50 | MEAN = "mean" 51 | MEDIAN = "median" 52 | EOS = "eos" 53 | 54 | 55 | class Activation(Enum): 56 | RELU = "relu" 57 | LEAKYRELU = "leakyrelu" 58 | TANH = "tanh" 59 | GELU = "gelu" 60 | GLU = "glu" 61 | 62 | 63 | class ExporterType(Enum): 64 | PREDICTOR = "predictor" 65 | INIT_PREDICT = "init_predict" 66 | -------------------------------------------------------------------------------- /pytext/config/pair_classification.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | from .field_config import DocLabelConfig, WordFeatConfig 4 | from .module_config import ModuleConfig 5 | 6 | 7 | class ModelInputConfig(ModuleConfig): 8 | text1: WordFeatConfig = WordFeatConfig() 9 | text2: WordFeatConfig = WordFeatConfig() 10 | 11 | 12 | TargetConfig = DocLabelConfig 13 | 14 | 15 | class ModelInput: 16 | TEXT1 = "text1" 17 | TEXT2 = "text2" 18 | 19 | 20 | class ExtraField: 21 | UTTERANCE_PAIR = "utterance" 22 | -------------------------------------------------------------------------------- /pytext/config/query_document_pairwise_ranking.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | from .field_config import WordFeatConfig 4 | from .module_config import ModuleConfig 5 | 6 | 7 | class ModelInputConfig(ModuleConfig): 8 | pos_response: WordFeatConfig = WordFeatConfig() 9 | neg_response: WordFeatConfig = WordFeatConfig() 10 | query: WordFeatConfig = WordFeatConfig() 11 | 12 | 13 | class ModelInput: 14 | QUERY = "query" 15 | POS_RESPONSE = "pos_response" 16 | NEG_RESPONSE = "neg_response" 17 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v1_test_upgrade.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "original": { 4 | "task": { 5 | "DocClassificationTask": { 6 | "optimizer": { 7 | "type": "adam", 8 | "lr": 0.05, 9 | "weight_decay": 0.00001, 10 | "momentum": 0 11 | } 12 | } 13 | } 14 | }, 15 | "adapted": { 16 | "task": { 17 | "DocClassificationTask": { 18 | "optimizer": { 19 | "Adam": { 20 | "lr": 0.05, 21 | "weight_decay": 0.00001 22 | } 23 | } 24 | } 25 | }, 26 | "version": 1 27 | } 28 | }, 29 | { 30 | "original": { 31 | "task": { 32 | "DocClassificationTask": { 33 | "optimizer": { 34 | "lr": 0.05, 35 | "weight_decay": 0.00001, 36 | "momentum": 0 37 | } 38 | } 39 | 
} 40 | }, 41 | "adapted": { 42 | "task": { 43 | "DocClassificationTask": { 44 | "optimizer": { 45 | "Adam": { 46 | "lr": 0.05, 47 | "weight_decay": 0.00001 48 | } 49 | } 50 | } 51 | }, 52 | "version": 1 53 | } 54 | }, 55 | { 56 | "original": { 57 | "task": { 58 | "DocClassificationTask": { 59 | "optimizer": { 60 | "type": "sgd", 61 | "lr": 0.05, 62 | "weight_decay": 0.00001, 63 | "momentum": 0.00001 64 | } 65 | } 66 | } 67 | }, 68 | "adapted": { 69 | "task": { 70 | "DocClassificationTask": { 71 | "optimizer": { 72 | "SGD": { 73 | "lr": 0.05, 74 | "momentum": 0.00001 75 | } 76 | } 77 | } 78 | }, 79 | "version": 1 80 | } 81 | } 82 | ] 83 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v24_test_upgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v27_test_upgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v28_test_upgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v29_test_upgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v30_test_downgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v30_test_upgrade.json: 
-------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v31_test_downgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v31_test_upgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v32_test_downgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v32_test_upgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v33_test_upgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v34_test_downgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v34_test_upgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v35_test_downgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | 
-------------------------------------------------------------------------------- /pytext/config/test/json_config/v35_test_upgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v36_test_downgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v36_test_upgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v37_test_downgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v37_test_upgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v38_test_downgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v38_test_upgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v39_test_downgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v39_test_upgrade.json: 
-------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v40_test_downgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v40_test_upgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v41_test_downgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v41_test_upgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v42_test_downgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v42_test_upgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v43_test_downgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v43_test_upgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | 
-------------------------------------------------------------------------------- /pytext/config/test/json_config/v44_test_downgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v44_test_upgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v45_test_downgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v45_test_upgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v46_test_downgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v46_test_upgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v47_test_downgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v47_test_upgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v48_test_downgrade.json: 
-------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v48_test_upgrade.json: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v4_test_upgrade.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "original": { 4 | "task": { 5 | "DocClassificationTask": { 6 | "features": { 7 | "pretrained_model_embedding": { 8 | "model_paths": { 9 | "en_XX": "/mnt/vol/pytext/nlu/tests/nets/dummy_elmo_net.predictor" 10 | }, 11 | "embed_dim": 128 12 | } 13 | } 14 | } 15 | }, 16 | "version": 3 17 | }, 18 | "adapted": { 19 | "task": { 20 | "DocClassificationTask": { 21 | "features": { 22 | "contextual_token_embedding": { 23 | "model_paths": { 24 | "en_XX": "/mnt/vol/pytext/nlu/tests/nets/dummy_elmo_net.predictor" 25 | }, 26 | "embed_dim": 128 27 | } 28 | } 29 | } 30 | }, 31 | "version": 4 32 | } 33 | } 34 | ] 35 | -------------------------------------------------------------------------------- /pytext/config/test/json_config/v6_test_upgrade.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "original": { 4 | "task": { 5 | "DocClassificationTask": { 6 | "data_handler": { 7 | "train_path": "tests/data/train_data_tiny.tsv", 8 | "eval_path": "tests/data/test_data_tiny.tsv", 9 | "test_path": "tests/data/test_data_tiny.tsv" 10 | } 11 | } 12 | }, 13 | "version": 5 14 | }, 15 | "adapted": { 16 | "task": { 17 | "DocClassificationTask_Deprecated": { 18 | "data_handler": { 19 | "train_path": "tests/data/train_data_tiny.tsv", 20 | "eval_path": "tests/data/test_data_tiny.tsv", 21 | "test_path": "tests/data/test_data_tiny.tsv" 22 | } 23 | } 24 | }, 25 | "version": 6 26 | } 27 | } 28 | ] 29 | 
-------------------------------------------------------------------------------- /pytext/config/test/json_config/v8_test_upgrade.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "original": { 4 | "task": { 5 | "LMTask_Deprecated": { 6 | "data_handler": { 7 | "LanguageModelDataHandler": { 8 | "train_path": "tests/data/train_data_tiny.tsv", 9 | "eval_path": "tests/data/test_data_tiny.tsv", 10 | "test_path": "tests/data/test_data_tiny.tsv" 11 | } 12 | } 13 | } 14 | }, 15 | "version": 7 16 | }, 17 | "adapted": { 18 | "task": { 19 | "LMTask_Deprecated": { 20 | "data_handler": { 21 | "LanguageModelDataHandler": { 22 | "train_path": "tests/data/train_data_tiny.tsv", 23 | "eval_path": "tests/data/test_data_tiny.tsv", 24 | "test_path": "tests/data/test_data_tiny.tsv" 25 | } 26 | } 27 | } 28 | }, 29 | "version": 8 30 | } 31 | } 32 | ] 33 | -------------------------------------------------------------------------------- /pytext/config/test/pytext_all_config_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | import glob 5 | import json 6 | import os 7 | import unittest 8 | 9 | from pytext.builtin_task import register_builtin_tasks 10 | from pytext.config.serialize import parse_config 11 | from pytext.utils.file_io import PathManager 12 | from pytext.utils.path import get_absolute_path, PYTEXT_HOME 13 | 14 | 15 | register_builtin_tasks() 16 | 17 | 18 | # These JSON files are not parseable configs 19 | EXCLUDE_JSON = { 20 | # used by test_merge_token_labels_to_slot 21 | "utils/tests/test_samples.json", 22 | # "pytext/data/test/data/gpt2_encoder.json", 23 | } 24 | # TODO: @stevenliu T52746850 include all config files from demo, include 25 | # as many as possible from fb 26 | EXCLUDE_DIRS = { 27 | "pytext/contrib", 28 | "pytext/config/test/json_config", 29 | "pytext/demo", 30 | "pytext/data/test/data", 31 | "pytext/fb", 32 | "pytext/tests/data", 33 | } 34 | 35 | 36 | class LoadAllConfigTest(unittest.TestCase): 37 | def test_load_all_configs(self): 38 | """ 39 | Try an load all the json files in pytext to make sure we didn't 40 | break the config API. 41 | """ 42 | for filename in glob.iglob("pytext/**/*.json", recursive=True): 43 | if any(f in filename for f in EXCLUDE_JSON): 44 | continue 45 | if any(d in filename for d in EXCLUDE_DIRS): 46 | continue 47 | print("--- loading:", filename) 48 | with PathManager.open(filename) as file: 49 | config_json = json.load(file) 50 | config = parse_config(config_json) 51 | self.assertIsNotNone(config) 52 | -------------------------------------------------------------------------------- /pytext/data/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | from .batch_sampler import ( 5 | AlternatingRandomizedBatchSampler, 6 | BaseBatchSampler, 7 | EvalBatchSampler, 8 | NaturalBatchSampler, 9 | RandomizedBatchSampler, 10 | RoundRobinBatchSampler, 11 | ) 12 | from .data import Batcher, Data, generator_iterator, PoolingBatcher 13 | from .data_handler import BatchIterator, CommonMetadata, DataHandler 14 | from .disjoint_multitask_data import DisjointMultitaskData 15 | from .disjoint_multitask_data_handler import DisjointMultitaskDataHandler 16 | from .dynamic_pooling_batcher import DynamicPoolingBatcher 17 | from .tensorizers import Tensorizer 18 | 19 | 20 | __all__ = [ 21 | "AlternatingRandomizedBatchSampler", 22 | "Batcher", 23 | "BaseBatchSampler", 24 | "BatchIterator", 25 | "CommonMetadata", 26 | "Data", 27 | "DataHandler", 28 | "DisjointMultitaskData", 29 | "DisjointMultitaskDataHandler", 30 | "DynamicPoolingBatcher", 31 | "EvalBatchSampler", 32 | "generator_iterator", 33 | "PoolingBatcher", 34 | "RandomizedBatchSampler", 35 | "RoundRobinBatchSampler", 36 | "NaturalBatchSampler", 37 | "Tensorizer", 38 | ] 39 | -------------------------------------------------------------------------------- /pytext/data/data_structures/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | -------------------------------------------------------------------------------- /pytext/data/data_structures/node.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | from typing import AbstractSet, Any, NamedTuple, Optional 5 | 6 | 7 | class Span(NamedTuple): 8 | """ 9 | Span of a node in an intent-slot tree. 10 | 11 | Attributes: 12 | start: Start position of the node. 13 | end: End position of the node (exclusive). 
14 | """ 15 | 16 | start: int 17 | end: int 18 | 19 | 20 | class Node: 21 | """ 22 | Node in an intent-slot tree, representing either an intent or a slot. 23 | 24 | Attributes: 25 | label (str): Label of the node. 26 | span (Span): Span of the node. 27 | children (:obj:`set` of :obj:`Node`): Children of the node. 28 | """ 29 | 30 | __slots__ = "label", "span", "children", "text" 31 | 32 | def __init__( 33 | self, 34 | label: str, 35 | span: Span, 36 | children: Optional[AbstractSet["Node"]] = None, 37 | text: str = None, 38 | ) -> None: 39 | object.__setattr__(self, "label", label) 40 | object.__setattr__(self, "span", span) 41 | object.__setattr__( 42 | self, "children", children if children is not None else set() 43 | ) 44 | object.__setattr__(self, "text", text) 45 | 46 | def __eq__(self, other: Any) -> bool: 47 | if not isinstance(other, Node): 48 | return NotImplemented 49 | return ( 50 | self.label == other.label # noqa 51 | and self.span == other.span # noqa 52 | and self.children == other.children # noqa 53 | and self.text == other.text # noqa 54 | ) 55 | 56 | def get_depth(self) -> int: 57 | return 1 + max( 58 | (child.get_depth() for child in self.children), default=0 # noqa 59 | ) 60 | -------------------------------------------------------------------------------- /pytext/data/featurizer/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | from .featurizer import Featurizer, InputRecord, OutputRecord 5 | from .simple_featurizer import SimpleFeaturizer 6 | 7 | 8 | __all__ = ["Featurizer", "InputRecord", "OutputRecord", "SimpleFeaturizer"] 9 | -------------------------------------------------------------------------------- /pytext/data/pickleable_gpt2bpe_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | from .pickleable_gpt2bpe_encoder import PickleableGPT2BPEEncoder 5 | 6 | 7 | __all__ = [ 8 | "PickleableGPT2BPEEncoder", 9 | ] 10 | -------------------------------------------------------------------------------- /pytext/data/pickleable_gpt2bpe_encoder/pickleable_gpt2bpe_encoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | import copy 4 | 5 | from fairseq.data.encoders.gpt2_bpe_utils import Encoder as GPT2BPEEncoder 6 | 7 | 8 | class PickleableGPT2BPEEncoder(GPT2BPEEncoder): 9 | """Fairseq's encoder stores the regex module as a local reference on its encoders, 10 | which means they can't be saved via pickle.dumps or torch.save. 
This modified 11 | their save/load logic doesn't store the module, and restores the reference 12 | after re-inflating.""" 13 | 14 | def __getstate__(self): 15 | # make a shallow copy of state to avoid side effect on the original object 16 | state = copy.copy(vars(self)) 17 | state.pop("re") 18 | return state 19 | 20 | def __setstate__(self, state): 21 | vars(self).update(state) 22 | import regex 23 | 24 | self.re = regex 25 | -------------------------------------------------------------------------------- /pytext/data/sources/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | from .conllu import CoNLLUNERDataSource 5 | from .data_source import DataSource, RawExample 6 | from .dense_retrieval import DenseRetrievalDataSource 7 | from .pandas import PandasDataSource 8 | from .squad import SquadDataSource 9 | from .tsv import TSVDataSource 10 | 11 | 12 | __all__ = [ 13 | "DataSource", 14 | "RawExample", 15 | "SquadDataSource", 16 | "TSVDataSource", 17 | "PandasDataSource", 18 | "CoNLLUNERDataSource", 19 | "DenseRetrievalDataSource", 20 | ] 21 | -------------------------------------------------------------------------------- /pytext/data/sources/session.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | from itertools import chain 5 | from typing import List 6 | 7 | from pytext.config.serialize import _get_class_type 8 | from pytext.data.sources.data_source import RootDataSource 9 | 10 | 11 | class SessionDataSource(RootDataSource): 12 | """ 13 | Data source for session based data, the input data is organized in sessions, 14 | each session may have multiple rows. The first column is always the session id. 
15 | Raw input rows are consolidated by session id and returned as one session 16 | per example 17 | """ 18 | 19 | def __init__(self, id_col, **kwargs): 20 | self.id_col = id_col 21 | self.current_id = None 22 | self.current_session = [] 23 | super().__init__(**kwargs) 24 | 25 | def _validate_schema(self): 26 | """Make sure the input schema are all list type, which is the return value 27 | type, and convert it to the actual type (e.g List[T] -> T) when reading the 28 | raw data from file. 29 | """ 30 | for k, v in self.schema.items(): 31 | if k != self.id_col: 32 | assert _get_class_type(v) is list, f"{k} is not a list type!" 33 | self.schema[k] = v.__args__[0] 34 | 35 | def merge_session(self, session): 36 | res = {self.id_col: session[0][self.id_col]} 37 | for k, v in chain.from_iterable([s.items() for s in session]): 38 | if k != self.id_col: 39 | res[k] = res.get(k, []) 40 | res[k].append(v) 41 | return res 42 | 43 | def _convert_raw_source(self, source): 44 | for row in source: 45 | example = self._read_example(row) 46 | if example is None: 47 | continue 48 | if example[self.id_col] == self.current_id: 49 | self.current_session.append(example) 50 | else: 51 | self.current_id = example[self.id_col] 52 | session = self.current_session 53 | self.current_session = [example] 54 | if session: 55 | yield self.merge_session(session) 56 | self.current_id = None 57 | session = self.current_session 58 | self.current_session = [] 59 | if session: 60 | yield self.merge_session(session) 61 | -------------------------------------------------------------------------------- /pytext/data/test/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | -------------------------------------------------------------------------------- /pytext/data/test/data/gpt2_dict.txt: -------------------------------------------------------------------------------- 1 | 19703 850314647 2 | 8690 800385005 3 | -------------------------------------------------------------------------------- /pytext/data/test/data/gpt2_encoder.json: -------------------------------------------------------------------------------- 1 | {"otype": 8690, "Prot": 19703, "Ġ": 220} 2 | -------------------------------------------------------------------------------- /pytext/data/test/data/gpt2_vocab.bpe: -------------------------------------------------------------------------------- 1 | #version: 0.2 2 | ĠProt otype 3 | r o 4 | o t 5 | p e 6 | P ro 7 | y pe 8 | ot ype 9 | Pro t 10 | -------------------------------------------------------------------------------- /pytext/data/test/data/sentencepiece.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/pytext/08754b483421884d2e363f00517ea42e449aec2c/pytext/data/test/data/sentencepiece.model -------------------------------------------------------------------------------- /pytext/data/test/data/spm_ontology.txt: -------------------------------------------------------------------------------- 1 | testing 2 | -------------------------------------------------------------------------------- /pytext/data/test/pandas_data_source_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | import unittest 5 | 6 | import pandas as pd 7 | from pytext.config.component import ComponentType, create_component 8 | from pytext.data.sources import PandasDataSource 9 | 10 | 11 | class PandasDataSourceTest(unittest.TestCase): 12 | def test_create_from_config(self): 13 | source_config = PandasDataSource.Config( 14 | train_df=pd.DataFrame({"c1": [10, 20, 30], "c2": [40, 50, 60]}), 15 | eval_df=pd.DataFrame({"c1": [11, 21, 31], "c2": [41, 51, 61]}), 16 | test_df=pd.DataFrame({"c1": [12, 22, 32], "c2": [42, 52, 62]}), 17 | column_mapping={"c1": "feature1", "c2": "feature2"}, 18 | ) 19 | ds = create_component( 20 | ComponentType.DATA_SOURCE, 21 | source_config, 22 | schema={"feature1": float, "feature2": float}, 23 | ) 24 | self.assertEqual({"feature1": 10, "feature2": 40}, next(iter(ds.train))) 25 | self.assertEqual({"feature1": 11, "feature2": 41}, next(iter(ds.eval))) 26 | self.assertEqual({"feature1": 12, "feature2": 42}, next(iter(ds.test))) 27 | self.assertEqual(3, len(list(ds.train))) 28 | 29 | def test_create_data_source(self): 30 | ds = PandasDataSource( 31 | train_df=pd.DataFrame({"c1": [10, 20, 30], "c2": [40, 50, 60]}), 32 | eval_df=pd.DataFrame({"c1": [11, 21, 31], "c2": [41, 51, 61]}), 33 | test_df=pd.DataFrame({"c1": [12, 22, 32], "c2": [42, 52, 62]}), 34 | schema={"feature1": float, "feature2": float}, 35 | column_mapping={"c1": "feature1", "c2": "feature2"}, 36 | ) 37 | self.assertEqual({"feature1": 10, "feature2": 40}, next(iter(ds.train))) 38 | self.assertEqual({"feature1": 11, "feature2": 41}, next(iter(ds.eval))) 39 | self.assertEqual({"feature1": 12, "feature2": 42}, next(iter(ds.test))) 40 | self.assertEqual(3, len(list(ds.train))) 41 | 42 | def test_empty_data(self): 43 | ds = PandasDataSource(schema={"feature1": float, "feature2": float}) 44 | self.assertEqual(0, len(list(ds.train))) 45 | -------------------------------------------------------------------------------- 
/pytext/data/test/round_robin_batchiterator_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | import unittest 5 | 6 | from pytext.data.disjoint_multitask_data_handler import RoundRobinBatchIterator 7 | 8 | 9 | class RoundRobinBatchIteratorTest(unittest.TestCase): 10 | def test_batch_iterator(self): 11 | iteratorA = [(input, None, {}) for input in ["1", "2", "3", "4", "5"]] 12 | iteratorB = [(input, None, {}) for input in ["a", "b", "c"]] 13 | 14 | # upsample = True, no iter_to_set_epoch 15 | round_robin_iterator = RoundRobinBatchIterator( 16 | {"A": iteratorA, "B": iteratorB}, upsample=True 17 | ) 18 | expected_items = ["1", "a", "2", "b", "3", "c"] 19 | self._check_iterator(round_robin_iterator, expected_items) 20 | 21 | # upsample = True, iter_to_set_epoch = "A" 22 | round_robin_iterator = RoundRobinBatchIterator( 23 | {"A": iteratorA, "B": iteratorB}, upsample=True, iter_to_set_epoch="A" 24 | ) 25 | expected_items = ["1", "a", "2", "b", "3", "c", "4", "a", "5", "b"] 26 | self._check_iterator(round_robin_iterator, expected_items) 27 | 28 | # upsample = False 29 | round_robin_iterator = RoundRobinBatchIterator( 30 | {"A": iteratorA, "B": iteratorB}, upsample=False 31 | ) 32 | expected_items = ["1", "2", "3", "4", "5", "a", "b", "c"] 33 | self._check_iterator(round_robin_iterator, expected_items, fixed_order=False) 34 | 35 | def _check_iterator(self, iterator, expected_items, fixed_order=True): 36 | actual_items = [item for item, _, _ in iterator] 37 | if not fixed_order: 38 | # Order is random, just check that the sorted arrays are equal 39 | actual_items = sorted(actual_items) 40 | expected_items = sorted(expected_items) 41 | self.assertListEqual(actual_items, expected_items) 42 | -------------------------------------------------------------------------------- /pytext/data/tokenizers/__init__.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | from .tokenizer import ( 5 | CppProcessorMixin, 6 | DoNothingTokenizer, 7 | GPT2BPETokenizer, 8 | SentencePieceTokenizer, 9 | SPEandWordTokenizer, 10 | Token, 11 | Tokenizer, 12 | WordPieceTokenizer, 13 | ) 14 | 15 | 16 | __all__ = [ 17 | "GPT2BPETokenizer", 18 | "Token", 19 | "Tokenizer", 20 | "DoNothingTokenizer", 21 | "WordPieceTokenizer", 22 | "CppProcessorMixin", 23 | "SPEandWordTokenizer", 24 | "SentencePieceTokenizer", 25 | ] 26 | -------------------------------------------------------------------------------- /pytext/data/xlm_constants.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | LANG2ID_15 = { 5 | "ar": 0, 6 | "bg": 1, 7 | "de": 2, 8 | "el": 3, 9 | "en": 4, 10 | "es": 5, 11 | "fr": 6, 12 | "hi": 7, 13 | "ru": 8, 14 | "sw": 9, 15 | "th": 10, 16 | "tr": 11, 17 | "ur": 12, 18 | "vi": 13, 19 | "zh": 14, 20 | } 21 | 22 | LANG2ID_20 = { 23 | "ar": 0, 24 | "bn": 1, 25 | "de": 2, 26 | "en": 3, 27 | "es": 4, 28 | "fr": 5, 29 | "hi": 6, 30 | "id": 7, 31 | "it": 8, 32 | "ko": 9, 33 | "my": 10, 34 | "pl": 11, 35 | "pt": 12, 36 | "ru": 13, 37 | "sw": 14, 38 | "th": 15, 39 | "tl": 16, 40 | "tr": 17, 41 | "vi": 18, 42 | "zh": 19, 43 | } 44 | 45 | 46 | LANG2ID_43 = { 47 | "ar_AR": 0, 48 | "bg_BG": 1, 49 | "bn_IN": 2, 50 | "da_DK": 3, 51 | "de_DE": 4, 52 | "el_GR": 5, 53 | "en_XX": 6, 54 | "es_XX": 7, 55 | "fa_IR": 8, 56 | "fr_XX": 9, 57 | "he_IL": 10, 58 | "hi_IN": 11, 59 | "hu_HU": 12, 60 | "id_ID": 13, 61 | "it_IT": 14, 62 | "ja_XX": 15, 63 | "km_KH": 16, 64 | "kn_IN": 17, 65 | "ko_KR": 18, 66 | "lt_LT": 19, 67 | "ml_IN": 20, 68 | "mr_IN": 21, 69 | "ms_MY": 22, 70 | "my_MM": 23, 71 | "nl_XX": 24, 72 | "pa_IN": 25, 73 | "pl_PL": 26, 74 
| "ps_AF": 27, 75 | "pt_XX": 28, 76 | "ro_RO": 29, 77 | "ru_RU": 30, 78 | "si_LK": 31, 79 | "sv_SE": 32, 80 | "sw_KE": 33, 81 | "ta_IN": 34, 82 | "te_IN": 35, 83 | "th_TH": 36, 84 | "tl_XX": 37, 85 | "tr_TR": 38, 86 | "ur_PK": 39, 87 | "vi_VN": 40, 88 | "zh_TW": 41, 89 | "zh_CN": 42, 90 | } 91 | -------------------------------------------------------------------------------- /pytext/docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = -j=auto -W 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = PyText 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | doc-html: 16 | make html 17 | 18 | clean-all: 19 | make clean 20 | rm -r source/modules 21 | rm -r source/configs 22 | 23 | .PHONY: help Makefile 24 | 25 | # Catch-all target: route all unknown targets to Sphinx using the new 26 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 27 | %: Makefile 28 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 29 | -------------------------------------------------------------------------------- /pytext/docs/origin/README: -------------------------------------------------------------------------------- 1 | This directory contains the source files for the images, diagrams, etc. 
2 | 3 | *.odg -> LibreOffice/OpenOffice Draw 4 | -------------------------------------------------------------------------------- /pytext/docs/origin/pytext.odg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/pytext/08754b483421884d2e363f00517ea42e449aec2c/pytext/docs/origin/pytext.odg -------------------------------------------------------------------------------- /pytext/docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx<2.0 2 | sphinx_rtd_theme 3 | -------------------------------------------------------------------------------- /pytext/docs/source/_static/img/flask_www.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/pytext/08754b483421884d2e363f00517ea42e449aec2c/pytext/docs/source/_static/img/flask_www.png -------------------------------------------------------------------------------- /pytext/docs/source/_static/img/ios_demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/pytext/08754b483421884d2e363f00517ea42e449aec2c/pytext/docs/source/_static/img/ios_demo.png -------------------------------------------------------------------------------- /pytext/docs/source/_static/img/pytext.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/pytext/08754b483421884d2e363f00517ea42e449aec2c/pytext/docs/source/_static/img/pytext.png -------------------------------------------------------------------------------- /pytext/docs/source/_static/img/pytext_design.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/facebookresearch/pytext/08754b483421884d2e363f00517ea42e449aec2c/pytext/docs/source/_static/img/pytext_design.png -------------------------------------------------------------------------------- /pytext/docs/source/_static/img/tb_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/pytext/08754b483421884d2e363f00517ea42e449aec2c/pytext/docs/source/_static/img/tb_graph.png -------------------------------------------------------------------------------- /pytext/docs/source/_static/img/tb_test_metrics.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/pytext/08754b483421884d2e363f00517ea42e449aec2c/pytext/docs/source/_static/img/tb_test_metrics.png -------------------------------------------------------------------------------- /pytext/docs/source/_static/img/tb_train_metrics.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/pytext/08754b483421884d2e363f00517ea42e449aec2c/pytext/docs/source/_static/img/tb_train_metrics.png -------------------------------------------------------------------------------- /pytext/docs/source/seq2seq_tutorial.rst: -------------------------------------------------------------------------------- 1 | Semantic parsing with sequence-to-sequence models 2 | ================================================= 3 | 4 | Introduction 5 | ------------ 6 | 7 | PyText provides an encoder-decoder framework that is suitable for any task 8 | that requires mapping a sequence of input tokens to a sequence of output 9 | tokens. The default implementation is based on recurrent neural networks 10 | (RNNs), which have been shown to be `unreasonably effective 11 | `_ at sequence 12 | processing tasks. The default implementation includes three major components 13 | 14 | #. 
A bidirectional LSTM sequence encoder 15 | #. An LSTM sequence decoder 16 | #. A sequence generator that supports incremental decoding and beam search 17 | 18 | All of these components are Torchscript-friendly, so that the trained model 19 | can be exported directly as-is. Following the general design of PyText, each 20 | of these components may be customized via their respective config objects or 21 | replaced entirely by custom components. 22 | 23 | Tutorial 24 | -------- 25 | 26 | `Tutorial in notebook `_ 27 | `Run the tutorial in Google Colab `_ 28 | -------------------------------------------------------------------------------- /pytext/exporters/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | from pytext.exporters.custom_exporters import ( 5 | DenseFeatureExporter, 6 | InitPredictNetExporter, 7 | ) 8 | from pytext.exporters.exporter import ModelExporter 9 | 10 | 11 | __all__ = ["ModelExporter", "DenseFeatureExporter", "InitPredictNetExporter"] 12 | -------------------------------------------------------------------------------- /pytext/fields/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | from .char_field import CharFeatureField 5 | from .contextual_token_embedding_field import ContextualTokenEmbeddingField 6 | from .dict_field import DictFeatureField 7 | from .field import ( 8 | ActionField, 9 | create_fields, 10 | create_label_fields, 11 | DocLabelField, 12 | Field, 13 | FieldMeta, 14 | FloatField, 15 | FloatVectorField, 16 | NestedField, 17 | RawField, 18 | SeqFeatureField, 19 | TextFeatureField, 20 | VocabUsingField, 21 | VocabUsingNestedField, 22 | WordLabelField, 23 | ) 24 | from .text_field_with_special_unk import TextFeatureFieldWithSpecialUnk 25 | 26 | 27 | __all__ = [ 28 | "create_fields", 29 | "create_label_fields", 30 | "ActionField", 31 | "CharFeatureField", 32 | "ContextualTokenEmbeddingField", 33 | "DictFeatureField", 34 | "DocLabelField", 35 | "Field", 36 | "FieldMeta", 37 | "FloatField", 38 | "FloatVectorField", 39 | "RawField", 40 | "TextFeatureField", 41 | "VocabUsingField", 42 | "WordLabelField", 43 | "NestedField", 44 | "VocabUsingNestedField", 45 | "SeqFeatureField", 46 | "TextFeatureFieldWithSpecialUnk", 47 | ] 48 | -------------------------------------------------------------------------------- /pytext/legacy/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | from torchtext import nn, utils 5 | 6 | from . import data, datasets, vocab 7 | 8 | __all__ = ["data", "nn", "datasets", "utils", "vocab"] 9 | -------------------------------------------------------------------------------- /pytext/legacy/data/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | # Those are not in the legacy folder. 
5 | from torchtext.data import functional, metrics, utils 6 | from torchtext.data.functional import ( 7 | custom_replace, 8 | generate_sp_model, 9 | load_sp_model, 10 | numericalize_tokens_from_iterator, 11 | sentencepiece_numericalizer, 12 | sentencepiece_tokenizer, 13 | simple_space_split, 14 | ) 15 | from torchtext.data.metrics import bleu_score 16 | from torchtext.data.utils import get_tokenizer, interleave_keys 17 | 18 | from .batch import Batch 19 | from .dataset import Dataset, TabularDataset 20 | from .example import Example 21 | from .field import Field, LabelField, NestedField, RawField 22 | from .iterator import batch, BPTTIterator, BucketIterator, Iterator, pool 23 | from .pipeline import Pipeline 24 | 25 | __all__ = [ 26 | "Batch", 27 | "Example", 28 | "RawField", 29 | "Field", 30 | "NestedField", 31 | "LabelField", 32 | "batch", 33 | "BucketIterator", 34 | "Iterator", 35 | "BPTTIterator", 36 | "pool", 37 | "Pipeline", 38 | "Dataset", 39 | "TabularDataset", 40 | "metrics", 41 | "bleu_score", 42 | "utils", 43 | "get_tokenizer", 44 | "interleave_keys", 45 | "functional", 46 | "generate_sp_model", 47 | "load_sp_model", 48 | "sentencepiece_numericalizer", 49 | "sentencepiece_tokenizer", 50 | "custom_replace", 51 | "simple_space_split", 52 | "numericalize_tokens_from_iterator", 53 | ] 54 | -------------------------------------------------------------------------------- /pytext/legacy/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | from .babi import BABI20 5 | from .imdb import IMDB 6 | from .language_modeling import ( 7 | LanguageModelingDataset, 8 | PennTreebank, 9 | WikiText103, 10 | WikiText2, 11 | ) # NOQA 12 | from .nli import MultiNLI, SNLI, XNLI 13 | from .sequence_tagging import CoNLL2000Chunking, SequenceTaggingDataset, UDPOS # NOQA 14 | from .sst import SST 15 | from .text_classification import ( 16 | AG_NEWS, 17 | AmazonReviewFull, 18 | AmazonReviewPolarity, 19 | DBpedia, 20 | SogouNews, 21 | TextClassificationDataset, 22 | YahooAnswers, 23 | YelpReviewFull, 24 | YelpReviewPolarity, 25 | ) 26 | from .translation import IWSLT, Multi30k, TranslationDataset, WMT14 # NOQA 27 | from .trec import TREC 28 | from .unsupervised_learning import EnWik9 29 | 30 | __all__ = [ 31 | "LanguageModelingDataset", 32 | "SNLI", 33 | "MultiNLI", 34 | "XNLI", 35 | "SST", 36 | "TranslationDataset", 37 | "Multi30k", 38 | "IWSLT", 39 | "WMT14", 40 | "WikiText2", 41 | "WikiText103", 42 | "PennTreebank", 43 | "TREC", 44 | "IMDB", 45 | "SequenceTaggingDataset", 46 | "UDPOS", 47 | "CoNLL2000Chunking", 48 | "BABI20", 49 | "TextClassificationDataset", 50 | "AG_NEWS", 51 | "SogouNews", 52 | "DBpedia", 53 | "YelpReviewPolarity", 54 | "YelpReviewFull", 55 | "YahooAnswers", 56 | "AmazonReviewPolarity", 57 | "AmazonReviewFull", 58 | "EnWik9", 59 | ] 60 | -------------------------------------------------------------------------------- /pytext/loss/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | from .loss import ( 5 | AUCPRHingeLoss, 6 | BinaryCrossEntropyLoss, 7 | BinaryCrossEntropyWithLogitsLoss, 8 | BinaryFocalLoss, 9 | CosineEmbeddingLoss, 10 | CrossEntropyLoss, 11 | CTCLoss, 12 | FocalLoss, 13 | HingeLoss, 14 | KLDivergenceBCELoss, 15 | KLDivergenceCELoss, 16 | LabelSmoothedCrossEntropyLoss, 17 | Loss, 18 | MAELoss, 19 | MSELoss, 20 | MultiLabelSoftMarginLoss, 21 | NLLLoss, 22 | PairwiseRankingLoss, 23 | SourceType, 24 | ) 25 | from .regularized_loss import ( 26 | LabelSmoothingLoss, 27 | NARSamplewiseSequenceLoss, 28 | NARSequenceLoss, 29 | SamplewiseLabelSmoothingLoss, 30 | ) 31 | from .regularizer import AdaptiveRegularizer, EntropyRegularizer, UniformRegularizer 32 | from .structured_loss import CostFunctionType, StructuredLoss, StructuredMarginLoss 33 | 34 | 35 | __all__ = [ 36 | "AUCPRHingeLoss", 37 | "Loss", 38 | "CrossEntropyLoss", 39 | "CosineEmbeddingLoss", 40 | "BinaryCrossEntropyLoss", 41 | "BinaryCrossEntropyWithLogitsLoss", 42 | "HingeLoss", 43 | "MultiLabelSoftMarginLoss", 44 | "KLDivergenceBCELoss", 45 | "KLDivergenceCELoss", 46 | "MAELoss", 47 | "MSELoss", 48 | "NLLLoss", 49 | "PairwiseRankingLoss", 50 | "LabelSmoothedCrossEntropyLoss", 51 | "SourceType", 52 | "CostFunctionType", 53 | "StructuredLoss", 54 | "StructuredMarginLoss", 55 | "LabelSmoothingLoss", 56 | "SamplewiseLabelSmoothingLoss", 57 | "NARSequenceLoss", 58 | "NARSamplewiseSequenceLoss", 59 | "UniformRegularizer", 60 | "EntropyRegularizer", 61 | "AdaptiveRegularizer", 62 | "BinaryFocalLoss", 63 | "FocalLoss", 64 | "CTCLoss", 65 | ] 66 | -------------------------------------------------------------------------------- /pytext/loss/tests/ctc_loss_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | import unittest 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from pytext.loss.loss import CTCLoss 9 | 10 | 11 | class CTCLossTest(unittest.TestCase): 12 | def test_ctc_loss(self): 13 | torch.manual_seed(0) 14 | 15 | N = 16 # Batch size 16 | T = 50 # Input sequence length 17 | C = 20 # Number of classes (including blank) 18 | S = 30 # Target sequence length of longest target in batch (padding length) 19 | S_min = 10 # Minimum target length (only for testing) 20 | 21 | logits = torch.randn(N, T, C) 22 | targets = torch.randint(1, C, (N, S), dtype=torch.long) 23 | input_lengths = torch.full((N,), T, dtype=torch.long) 24 | target_lengths = torch.randint(S_min, S, (N,), dtype=torch.long) 25 | 26 | config = CTCLoss.Config() 27 | config.blank = 0 # Needs to be set to 0 for CuDNN support. 28 | ctc_loss_fn = CTCLoss(config=config) 29 | 30 | ctc_loss_val = ctc_loss_fn( 31 | logits, 32 | targets, 33 | input_lengths, 34 | target_lengths, 35 | ) 36 | 37 | # PyTorch CTC loss 38 | log_probs = logits.permute(1, 0, 2).log_softmax( 39 | 2 40 | ) # permute to conform to CTC loss input tensor (T,N,C) in PyTorch. 41 | lib_ctc_loss_val = F.ctc_loss(log_probs, targets, input_lengths, target_lengths) 42 | 43 | self.assertAlmostEqual(ctc_loss_val.item(), lib_ctc_loss_val.item()) 44 | -------------------------------------------------------------------------------- /pytext/loss/tests/focal_loss_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | import unittest 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from pytext.loss.loss import BinaryFocalLoss, FocalLoss 9 | 10 | 11 | class FocalLossTest(unittest.TestCase): 12 | def test_focal_loss_base(self): 13 | 14 | target = torch.randint(size=(5, 1), low=0, high=10) 15 | score = torch.randn((5, 10)) 16 | 17 | config = FocalLoss.Config() 18 | config.gamma = 0 19 | config.alpha = 1 20 | loss_fn = FocalLoss(config=config) 21 | 22 | val = loss_fn( 23 | score, 24 | target.reshape( 25 | 5, 26 | ), 27 | ) 28 | val2 = F.nll_loss( 29 | F.log_softmax(score, 1, dtype=torch.float32), 30 | target.reshape( 31 | 5, 32 | ), 33 | ) 34 | 35 | self.assertAlmostEqual(val.item(), val2.item()) 36 | 37 | def test_binary_focal_loss_base(self): 38 | 39 | target = torch.randint(size=(5, 1), low=0, high=10) 40 | score = torch.randn((5, 10)) 41 | 42 | # onehot encoded 43 | target_encode = torch.zeros_like(score) 44 | target_encode.scatter_(1, target, 1) 45 | 46 | config = BinaryFocalLoss.Config() 47 | config.gamma = 0 48 | config.alpha = 1 49 | loss_fn = BinaryFocalLoss(config=config) 50 | 51 | val = loss_fn(score, target_encode) 52 | val2 = F.binary_cross_entropy_with_logits(score, target_encode) 53 | 54 | self.assertAlmostEqual(val.item(), val2.item()) 55 | -------------------------------------------------------------------------------- /pytext/loss/tests/samplewise_label_smoothing_loss_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | import unittest 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from pytext.loss import LabelSmoothingLoss, SamplewiseLabelSmoothingLoss 9 | 10 | 11 | class SamplewiseLabelSmoothingLossTest(unittest.TestCase): 12 | def test_samplewise_label_smoothing_loss(self): 13 | batch_size = 5 14 | num_labels = 5 15 | 16 | label_smoothing_loss = LabelSmoothingLoss( 17 | LabelSmoothingLoss.Config(), ignore_index=-1 18 | ) 19 | samplewise_label_smoothing_loss = SamplewiseLabelSmoothingLoss( 20 | SamplewiseLabelSmoothingLoss.Config(), ignore_index=-1 21 | ) 22 | 23 | logits = F.log_softmax(torch.rand(batch_size, num_labels), 1) 24 | targets = torch.randint(batch_size, (num_labels,)) 25 | 26 | self.assertTrue( 27 | torch.isclose( 28 | label_smoothing_loss(logits, targets, reduce=True), 29 | samplewise_label_smoothing_loss(logits, targets, reduce=True), 30 | ) 31 | ) 32 | self.assertTrue( 33 | torch.isclose( 34 | label_smoothing_loss(logits, targets, reduce=False), 35 | samplewise_label_smoothing_loss(logits, targets, reduce=False), 36 | ).all() 37 | ) 38 | -------------------------------------------------------------------------------- /pytext/metric_reporters/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | from .calibration_metric_reporter import CalibrationMetricReporter 5 | from .channel import Channel 6 | from .classification_metric_reporter import ( 7 | ClassificationMetricReporter, 8 | MultiLabelClassificationMetricReporter, 9 | TopKClassificationMetricReporter, 10 | ) 11 | from .compositional_metric_reporter import CompositionalMetricReporter 12 | from .dense_retrieval_metric_reporter import DenseRetrievalMetricReporter 13 | from .intent_slot_detection_metric_reporter import IntentSlotMetricReporter 14 | from .language_model_metric_reporter import LanguageModelMetricReporter 15 | from .metric_reporter import MetricReporter, PureLossMetricReporter 16 | from .multi_span_qa_metric_reporter import MultiSpanQAMetricReporter 17 | from .pairwise_ranking_metric_reporter import PairwiseRankingMetricReporter 18 | from .regression_metric_reporter import RegressionMetricReporter 19 | from .squad_metric_reporter import SquadMetricReporter 20 | from .word_tagging_metric_reporter import ( 21 | MultiLabelSequenceTaggingMetricReporter, 22 | NERMetricReporter, 23 | SequenceTaggingMetricReporter, 24 | WordTaggingMetricReporter, 25 | ) 26 | 27 | 28 | __all__ = [ 29 | "Channel", 30 | "MetricReporter", 31 | "CalibrationMetricReporter", 32 | "ClassificationMetricReporter", 33 | "TopKClassificationMetricReporter", 34 | "MultiLabelClassificationMetricReporter", 35 | "MultiLabelSequenceTaggingMetricReporter", 36 | "RegressionMetricReporter", 37 | "IntentSlotMetricReporter", 38 | "LanguageModelMetricReporter", 39 | "SquadMetricReporter", 40 | "MultiSpanQAMetricReporter", 41 | "WordTaggingMetricReporter", 42 | "CompositionalMetricReporter", 43 | "PairwiseRankingMetricReporter", 44 | "SequenceTaggingMetricReporter", 45 | "PureLossMetricReporter", 46 | "NERMetricReporter", 47 | "DenseRetrievalMetricReporter", 48 | ] 49 | -------------------------------------------------------------------------------- /pytext/metric_reporters/pairwise_ranking_metric_reporter.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | from pytext.data import CommonMetadata 5 | from pytext.metrics import compute_pairwise_ranking_metrics 6 | 7 | from .channel import ConsoleChannel 8 | from .metric_reporter import MetricReporter 9 | 10 | 11 | class PairwiseRankingMetricReporter(MetricReporter): 12 | @classmethod 13 | def from_config(cls, config, meta: CommonMetadata = None, tensorizers=None): 14 | # TODO: add file channel 15 | return cls([ConsoleChannel()]) 16 | 17 | def calculate_metric(self): 18 | return compute_pairwise_ranking_metrics(self.all_preds, self.all_scores) 19 | 20 | def add_batch_stats( 21 | self, n_batches, preds, targets, scores, loss, m_input, **context 22 | ): 23 | # target = 1 means the first response was ranked higher than the second response 24 | # however, our training data is tuples of {pos_response, neg_response} pairs 25 | # i.e, pos_response is always the first response, neg_response is always the 26 | # second response. so target = 1 for all cases 27 | targets = [1] * preds.shape[0] 28 | super().add_batch_stats( 29 | n_batches, preds, targets, scores, loss, m_input, **context 30 | ) 31 | 32 | @staticmethod 33 | def get_model_select_metric(metrics): 34 | return metrics.accuracy 35 | -------------------------------------------------------------------------------- /pytext/metric_reporters/regression_metric_reporter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | from pytext.metrics import compute_regression_metrics 5 | 6 | from .channel import ConsoleChannel 7 | from .metric_reporter import MetricReporter 8 | 9 | 10 | class RegressionMetricReporter(MetricReporter): 11 | 12 | lower_is_better = False 13 | 14 | class Config(MetricReporter.Config): 15 | pass 16 | 17 | @classmethod 18 | def from_config(cls, config, tensorizers=None): 19 | return cls([ConsoleChannel()]) 20 | 21 | def calculate_metric(self): 22 | assert len(self.all_preds) == len(self.all_targets) 23 | return compute_regression_metrics(self.all_preds, self.all_targets) 24 | 25 | def get_model_select_metric(self, metrics): 26 | return metrics.pearson_correlation 27 | -------------------------------------------------------------------------------- /pytext/metric_reporters/seq2seq_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | 5 | def stringify(token_indices, vocab): 6 | return " ".join([vocab[index] for index in token_indices]) 7 | -------------------------------------------------------------------------------- /pytext/metric_reporters/tests/classification_metric_reporter_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | from unittest import TestCase 5 | 6 | from pytext.metric_reporters.classification_metric_reporter import ( 7 | compute_topk_classification_metrics, 8 | ) 9 | from pytext.metrics import ClassificationMetrics, LabelTopKPrediction 10 | 11 | 12 | class TestClassificationMetricReporter(TestCase): 13 | test_labels = ["hello", "hi", "hey", "how are you"] 14 | 15 | def test_compute_topk_classification_metrics_zero_correct(self): 16 | metrics: ClassificationMetrics = compute_topk_classification_metrics( 17 | predictions=[ 18 | LabelTopKPrediction([0.5, 0.3, 0.2], [0, 1, 2], 3), 19 | LabelTopKPrediction([0.5, 0.3, 0.2], [0, 2, 3], 1), 20 | ], 21 | label_names=self.test_labels, 22 | loss=0, 23 | ) 24 | 25 | self.assertEqual(0, metrics.accuracy) 26 | 27 | def test_compute_topk_classification_metrics_two_thirds_correct(self): 28 | metrics: ClassificationMetrics = compute_topk_classification_metrics( 29 | predictions=[ 30 | LabelTopKPrediction([0.5, 0.3, 0.2], [0, 1, 2], 0), 31 | LabelTopKPrediction([0.5, 0.3, 0.2], [0, 2, 3], 2), 32 | LabelTopKPrediction([0.5, 0.3, 0.2], [1, 2, 3], 0), 33 | ], 34 | label_names=self.test_labels, 35 | loss=0, 36 | ) 37 | 38 | self.assertEqual(2 / 3, metrics.accuracy) 39 | -------------------------------------------------------------------------------- /pytext/metrics/dense_retrieval_metrics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | from typing import NamedTuple 5 | 6 | 7 | class DenseRetrievalMetrics(NamedTuple): 8 | """ 9 | Metric class for dense passage retrieval. 
10 | 11 | Attributes: 12 | num_examples (int): number of samples 13 | accuracy (float): how many times did we get the +ve doc from list of docs 14 | average_rank (float): average rank of positive passage 15 | mean_reciprocal_rank (float): average 1/rank of positive passage 16 | """ 17 | 18 | num_examples: int 19 | accuracy: float 20 | average_rank: float 21 | mean_reciprocal_rank: float 22 | 23 | def print_metrics(self) -> None: 24 | print(f"Number of samples = {self.num_examples}") 25 | print(f"Accuracy = {self.accuracy * 100:.2f}") 26 | print(f"Average Rank = {self.average_rank}") 27 | print(f"Mean Reciprocal Rank = {self.mean_reciprocal_rank}") 28 | -------------------------------------------------------------------------------- /pytext/metrics/language_model_metrics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | import math 5 | from typing import NamedTuple 6 | 7 | 8 | """ 9 | Language model metric utilities. 10 | """ 11 | 12 | 13 | class LanguageModelMetric(NamedTuple): 14 | """ 15 | Class for language model metrics. 16 | 17 | Attributes: 18 | perplexity_per_word: Average perplexity per word of the dataset. 19 | """ 20 | 21 | perplexity_per_word: float 22 | 23 | def print_metrics(self): 24 | print(f"Perplexity per word : {self.perplexity_per_word: 0.2f}") 25 | 26 | 27 | def compute_language_model_metric(loss_per_word: float) -> LanguageModelMetric: 28 | try: 29 | ppl = math.exp(loss_per_word) 30 | except OverflowError: 31 | ppl = float("inf") 32 | return LanguageModelMetric(perplexity_per_word=ppl) 33 | -------------------------------------------------------------------------------- /pytext/metrics/squad_metrics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | from typing import NamedTuple 5 | 6 | from pytext.metrics import ClassificationMetrics 7 | 8 | 9 | class SquadMetrics(NamedTuple): 10 | classification_metrics: ClassificationMetrics 11 | num_examples: int = 1 12 | exact_matches: float = -1.0 13 | f1_score: float = -1.0 14 | f1_score_pos_only: float = -1.0 15 | 16 | def print_metrics(self) -> None: 17 | print(f"Number of Examples = {self.num_examples}") 18 | print(f"Exact Matches = {self.exact_matches:.2f} %") 19 | print(f"Token Level F1 Score = {self.f1_score:.2f} %") 20 | print( 21 | f"Token Level F1 Score for positive examples = {self.f1_score_pos_only:.2f} %" 22 | ) 23 | if self.classification_metrics: 24 | # this is NoneType if we ignore_impossible. 25 | print("======= Has Answer Classification Metrics =======") 26 | self.classification_metrics.print_metrics() 27 | -------------------------------------------------------------------------------- /pytext/metrics/tests/calibration_metrics_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | import unittest 5 | 6 | from pytext.metrics.calibration_metrics import ( 7 | calculate_error, 8 | get_bucket_accuracy, 9 | get_bucket_confidence, 10 | get_bucket_scores, 11 | ) 12 | 13 | 14 | class CalibrationUtilsTest(unittest.TestCase): 15 | def test_calibration(self): 16 | buckets = 10 17 | conf_list = [0.84, 0.98, 0.97, 0.76, 0.59, 0.62, 0.40, 0.33, 0.54, 0.37] 18 | true_list = [1, 5, 3, 2, 4, 2, 7, 4, 5, 2] 19 | pred_list = [1, 5, 3, 5, 2, 1, 7, 4, 5, 3] 20 | 21 | bucket_values, bucket_indices = get_bucket_scores(conf_list, buckets) 22 | bucket_confidence = get_bucket_confidence(bucket_values) 23 | bucket_accuracy = get_bucket_accuracy(bucket_indices, true_list, pred_list) 24 | expected_error, max_error, total_error = calculate_error( 25 | len(conf_list), bucket_values, bucket_confidence, bucket_accuracy 26 | ) 27 | 28 | self.assertEqual( 29 | bucket_values, 30 | [ 31 | [], 32 | [], 33 | [], 34 | [0.33, 0.37], 35 | [0.4], 36 | [0.59, 0.54], 37 | [0.62], 38 | [0.76], 39 | [0.84], 40 | [0.98, 0.97], 41 | ], 42 | ) 43 | self.assertEqual( 44 | bucket_indices, [[], [], [], [7, 9], [6], [4, 8], [5], [3], [0], [1, 2]] 45 | ) 46 | self.assertEqual( 47 | bucket_confidence, 48 | [-1.0, -1.0, -1.0, 0.35, 0.4, 0.565, 0.62, 0.76, 0.84, 0.975], 49 | ) 50 | self.assertEqual( 51 | bucket_accuracy, [-1.0, -1.0, -1.0, 0.5, 1.0, 0.5, 0.0, 0.0, 1.0, 1.0] 52 | ) 53 | self.assertAlmostEqual(expected_error, 26.2) 54 | self.assertAlmostEqual(max_error, 76.0) 55 | self.assertAlmostEqual(total_error, 238.0) 56 | -------------------------------------------------------------------------------- /pytext/metrics/tests/metrics_test_base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | from typing import Any 5 | from unittest import TestCase 6 | 7 | 8 | class MetricsTestBase(TestCase): 9 | def assertMetricsAlmostEqual(self, first: Any, second: Any) -> None: 10 | self.assertEqual(type(first), type(second)) 11 | if first is None: 12 | return 13 | elif isinstance(first, int): 14 | self.assertEqual(first, second) 15 | elif isinstance(first, float): 16 | self.assertAlmostEqual(first, second) 17 | elif isinstance(first, dict): 18 | self.assertEqual(first.keys(), second.keys()) 19 | for key in first.keys(): 20 | self.assertMetricsAlmostEqual(first[key], second[key]) 21 | # Then "first" and "second" should be of type NamedTuple. 22 | else: 23 | self.assertEqual(first._fields, second._fields) 24 | for attr in first._fields: 25 | self.assertMetricsAlmostEqual( 26 | getattr(first, attr), getattr(second, attr) 27 | ) 28 | -------------------------------------------------------------------------------- /pytext/metrics/tests/multilabel_metrics_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | from pytext.metrics import ( 5 | compute_multi_label_classification_metrics, 6 | LabelListPrediction, 7 | ) 8 | from pytext.metrics.tests.metrics_test_base import MetricsTestBase 9 | 10 | 11 | LABEL_NAMES = ["label1", "label2", "label3"] 12 | PREDICTIONS = [ 13 | LabelListPrediction(scores, predicted, expected) 14 | for scores, predicted, expected in [ 15 | ([-0.5, -0.7, -0.8], [1, 0, 0], [0]), 16 | ([-0.9, -0.2, -0.9], [0, 1, 0], [2]), 17 | ([-0.7, -0.4, -0.7], [0, 1, 0], [1]), 18 | ([-0.8, -0.9, -0.3], [0, 0, 1], [1]), 19 | ] 20 | ] 21 | 22 | 23 | class BasicMetricsTest(MetricsTestBase): 24 | def test_compute_multi_label_classification_metrics(self) -> None: 25 | 26 | roc_auc_dict = {"label1": 1.0, "label2": 0.25, "label3": 0.0} 27 | 28 | metrics = compute_multi_label_classification_metrics( 29 | PREDICTIONS, LABEL_NAMES, loss=5.0 30 | ) 31 | self.assertAlmostEqual(metrics.roc_auc, 1.25 / 3) 32 | for k, v in metrics.per_label_soft_scores.items(): 33 | metric_value = getattr(v, "roc_auc", None) 34 | self.assertAlmostEqual(metric_value, roc_auc_dict[k]) 35 | -------------------------------------------------------------------------------- /pytext/models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | from .model import BaseModel, Model 4 | from .tri_tower_classification_model import TriTowerClassificationModel 5 | from .two_tower_classification_model import TwoTowerClassificationModel 6 | from .two_tower_regression_model import TwoTowerRegressionModel 7 | 8 | 9 | __all__ = [ 10 | "Model", 11 | "BaseModel", 12 | "TwoTowerClassificationModel", 13 | "TriTowerClassificationModel", 14 | "TwoTowerRegressionModel", 15 | ] 16 | -------------------------------------------------------------------------------- /pytext/models/decoders/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | from .decoder_base import DecoderBase 5 | from .intent_slot_model_decoder import IntentSlotModelDecoder 6 | from .mlp_decoder import MLPDecoder 7 | 8 | 9 | __all__ = ["DecoderBase", "MLPDecoder", "IntentSlotModelDecoder"] 10 | -------------------------------------------------------------------------------- /pytext/models/decoders/decoder_base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | from pytext.config import ConfigBase 5 | from pytext.models.module import Module 6 | from pytext.utils.usage import log_class_usage 7 | 8 | 9 | class DecoderBase(Module): 10 | """Base class for all decoder modules. 11 | 12 | Args: 13 | config (ConfigBase): Configuration object. 14 | 15 | Attributes: 16 | in_dim (int): Dimension of input Tensor passed to the decoder. 17 | out_dim (int): Dimension of output Tensor produced by the decoder. 
18 | 19 | """ 20 | 21 | def __init__(self, config: ConfigBase): 22 | super().__init__(config) 23 | self.input_dim = 0 24 | self.target_dim = 0 25 | self.num_decoder_modules = 0 26 | log_class_usage(__class__) 27 | 28 | def forward(self, *input): 29 | raise NotImplementedError() 30 | 31 | def get_decoder(self): 32 | """Returns the decoder module.""" 33 | raise NotImplementedError() 34 | 35 | def get_in_dim(self) -> int: 36 | """Returns the dimension of the input Tensor that the decoder accepts.""" 37 | return self.in_dim 38 | 39 | def get_out_dim(self) -> int: 40 | """Returns the dimension of the input Tensor that the decoder emits.""" 41 | return self.out_dim 42 | -------------------------------------------------------------------------------- /pytext/models/decoders/mlp_decoder_query_response.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | from typing import List 5 | 6 | import torch 7 | import torch.nn as nn 8 | from pytext.utils.usage import log_class_usage 9 | 10 | from .decoder_base import DecoderBase 11 | 12 | 13 | class MLPDecoderQueryResponse(DecoderBase): 14 | """ 15 | Implements a 'two-tower' MLP: one for query and one for response 16 | Used in search pairwise ranking: both pos_response and neg_response 17 | use the response-MLP 18 | """ 19 | 20 | class Config(DecoderBase.Config): 21 | # Intermediate hidden dimensions 22 | hidden_dims: List[int] = [] 23 | 24 | def __init__(self, config: Config, from_dim: int, to_dim: int) -> None: 25 | super().__init__(config) 26 | self.mlp_for_response = MLPDecoderQueryResponse.get_mlp( 27 | from_dim, to_dim, config.hidden_dims 28 | ) 29 | self.mlp_for_query = MLPDecoderQueryResponse.get_mlp( 30 | from_dim, to_dim, config.hidden_dims 31 | ) 32 | self.out_dim = (3, to_dim) 33 | log_class_usage 34 | 35 | @staticmethod 36 | def get_mlp(from_dim: int, to_dim: int, hidden_dims: 
List[int]): 37 | layers = [] 38 | current_dim = from_dim 39 | for dim in hidden_dims or []: 40 | layers.append(nn.Linear(current_dim, dim)) 41 | layers.append(nn.ReLU()) 42 | current_dim = dim 43 | layers.append(nn.Linear(current_dim, to_dim)) 44 | return nn.Sequential(*layers) 45 | 46 | def forward(self, *x: List[torch.Tensor]) -> List[torch.Tensor]: 47 | output = [] 48 | assert len(x) == 3 49 | output.append(self.mlp_for_response(x[0])) 50 | output.append(self.mlp_for_response(x[1])) 51 | output.append(self.mlp_for_query(x[2])) 52 | return output 53 | 54 | def get_decoder(self) -> List[nn.Module]: 55 | return [self.mlp_for_response, self.mlp_for_query] 56 | -------------------------------------------------------------------------------- /pytext/models/decoders/multilabel_decoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | from typing import Dict, List 5 | 6 | import torch 7 | import torch.nn as nn 8 | from pytext.utils.usage import log_class_usage 9 | 10 | from .decoder_base import DecoderBase 11 | 12 | 13 | class MultiLabelDecoder(DecoderBase): 14 | """ 15 | Implements a 'n-tower' MLP: one for each of the multi labels 16 | Used in USM/EA: the user satisfaction modeling, pTSR prediction and 17 | Error Attribution are all 3 label sets that need predicting. 
18 | 19 | """ 20 | 21 | class Config(DecoderBase.Config): 22 | # Intermediate hidden dimensions 23 | hidden_dims: List[int] = [] 24 | 25 | def __init__( 26 | self, 27 | config: Config, 28 | in_dim: int, 29 | output_dim: Dict[str, int], 30 | label_names: List[str], 31 | ) -> None: 32 | super().__init__(config) 33 | self.label_mlps = nn.ModuleDict({}) 34 | # Store the ordered list to preserve the ordering of the labels 35 | # when generating the output layer 36 | self.label_names = label_names 37 | aggregate_out_dim = 0 38 | for label_, _ in output_dim.items(): 39 | self.label_mlps[label_] = MultiLabelDecoder.get_mlp( 40 | in_dim, output_dim[label_], config.hidden_dims 41 | ) 42 | aggregate_out_dim += output_dim[label_] 43 | self.out_dim = (1, aggregate_out_dim) 44 | log_class_usage(__class__) 45 | 46 | @staticmethod 47 | def get_mlp(in_dim: int, out_dim: int, hidden_dims: List[int]): 48 | layers = [] 49 | current_dim = in_dim 50 | for dim in hidden_dims or []: 51 | layers.append(nn.Linear(current_dim, dim)) 52 | layers.append(nn.ReLU()) 53 | current_dim = dim 54 | layers.append(nn.Linear(current_dim, out_dim)) 55 | return nn.Sequential(*layers) 56 | 57 | def forward(self, *input: torch.Tensor): 58 | logits = tuple( 59 | self.label_mlps[x](torch.cat(input, 1)) for x in self.label_names 60 | ) 61 | return logits 62 | 63 | def get_decoder(self) -> List[nn.Module]: 64 | return self.label_mlps 65 | -------------------------------------------------------------------------------- /pytext/models/embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | from .char_embedding import CharacterEmbedding 4 | from .contextual_token_embedding import ContextualTokenEmbedding 5 | from .dict_embedding import DictEmbedding 6 | from .embedding_base import EmbeddingBase 7 | from .embedding_list import EmbeddingList 8 | from .int_single_category_embedding import IntSingleCategoryEmbedding 9 | from .int_weighted_multi_category_embedding import IntWeightedMultiCategoryEmbedding 10 | from .mlp_embedding import MLPEmbedding 11 | from .word_embedding import WordEmbedding 12 | from .word_seq_embedding import WordSeqEmbedding 13 | 14 | __all__ = [ 15 | "EmbeddingBase", 16 | "EmbeddingList", 17 | "WordEmbedding", 18 | "DictEmbedding", 19 | "CharacterEmbedding", 20 | "ContextualTokenEmbedding", 21 | "WordSeqEmbedding", 22 | "MLPEmbedding", 23 | "IntSingleCategoryEmbedding", 24 | "IntWeightedMultiCategoryEmbedding", 25 | ] 26 | -------------------------------------------------------------------------------- /pytext/models/embeddings/contextual_token_embedding.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | from typing import Optional 4 | 5 | import torch 6 | from pytext.config.field_config import ContextualTokenEmbeddingConfig 7 | from pytext.models.seq_models.base import PlaceholderIdentity 8 | from torch.nn import Linear 9 | 10 | from .embedding_base import EmbeddingBase 11 | 12 | 13 | class ContextualTokenEmbedding(EmbeddingBase): 14 | """Module for providing token embeddings from a pretrained model.""" 15 | 16 | Config = ContextualTokenEmbeddingConfig 17 | 18 | @classmethod 19 | def from_config(cls, config: ContextualTokenEmbeddingConfig, *args, **kwargs): 20 | return cls(config.embed_dim, downsample_dim=config.downsample_dim) 21 | 22 | def __init__(self, embed_dim: int, downsample_dim: Optional[int] = None) -> None: 23 | super().__init__(embed_dim) 24 | self.input_embed_dim = embed_dim 25 | if downsample_dim: 26 | self.proj = Linear(embed_dim, downsample_dim) 27 | self.embedding_dim = downsample_dim 28 | else: 29 | self.proj = PlaceholderIdentity() 30 | 31 | def forward(self, embedding: torch.Tensor) -> torch.Tensor: 32 | embedding_shape = torch.onnx.operators.shape_as_tensor(embedding) 33 | 34 | # Since embeddings vector is flattened, verify its shape correctness. 35 | if embedding_shape[1].item() % self.input_embed_dim != 0: 36 | raise ValueError( 37 | f"Input embedding_dim {embedding_shape[1]} is not a" 38 | + f" multiple of specified embedding_dim {self.input_embed_dim}" 39 | ) 40 | 41 | # Unflatten embedding Tensor from (batch_size, seq_len * embedding_size) 42 | # to (batch_size, seq_len, embedding_size). 
43 | num_tokens = embedding_shape[1] // self.input_embed_dim 44 | new_embedding_shape = torch.cat( 45 | ( 46 | torch.tensor([-1], dtype=torch.long), 47 | num_tokens.view(1), 48 | torch.tensor([self.input_embed_dim], dtype=torch.long), 49 | ) 50 | ) 51 | reshaped_embed = torch.onnx.operators.reshape_from_tensor_shape( 52 | embedding, new_embedding_shape 53 | ) 54 | return self.proj(reshaped_embed) 55 | -------------------------------------------------------------------------------- /pytext/models/embeddings/embedding_base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | from typing import Dict, List 5 | 6 | import torch.nn as nn 7 | from pytext.models.module import Module 8 | from pytext.utils.usage import log_class_usage 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | 12 | class EmbeddingBase(Module): 13 | """Base class for token level embedding modules. 14 | 15 | Args: 16 | embedding_dim (int): Size of embedding vector. 17 | 18 | Attributes: 19 | num_emb_modules (int): Number of ways to embed a token. 20 | embedding_dim (int): Size of embedding vector. 
21 | 22 | """ 23 | 24 | __EXPANSIBLE__ = True 25 | 26 | def __init__(self, embedding_dim: int): 27 | super().__init__() 28 | # By default has 1 embedding which is itself, for EmbeddingList, this num 29 | # can be greater than 1 30 | self.num_emb_modules = 1 31 | self.embedding_dim = embedding_dim 32 | log_class_usage(__class__) 33 | 34 | def visualize(self, summary_writer: SummaryWriter): 35 | """ 36 | Overridden in sub classes to implement Tensorboard visualization of 37 | embedding space 38 | """ 39 | pass 40 | -------------------------------------------------------------------------------- /pytext/models/ensembles/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | from .bagging_doc_ensemble import BaggingDocEnsembleModel 4 | from .bagging_intent_slot_ensemble import BaggingIntentSlotEnsembleModel 5 | from .ensemble import EnsembleModel 6 | 7 | 8 | __all__ = ["BaggingDocEnsembleModel", "BaggingIntentSlotEnsembleModel", "EnsembleModel"] 9 | -------------------------------------------------------------------------------- /pytext/models/ensembles/bagging_doc_ensemble.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | from typing import List 4 | 5 | import torch 6 | from pytext.models.doc_model import DocModel 7 | 8 | from .ensemble import EnsembleModel 9 | 10 | 11 | class BaggingDocEnsembleModel(EnsembleModel): 12 | """Ensemble class that uses bagging for ensembling document classification 13 | models. 14 | """ 15 | 16 | class Config(EnsembleModel.Config): 17 | """Configuration class for `NewBaggingDocEnsemble`. These attributes are 18 | used by `Ensemble.from_config()` to construct instance of 19 | `NewBaggingDocEnsemble`. 
20 | 21 | Attributes: 22 | models (List[NewDocModel.Config]): List of document classification 23 | model configurations. 24 | 25 | """ 26 | 27 | models: List[DocModel.Config] 28 | 29 | def forward(self, *args, **kwargs) -> torch.Tensor: 30 | """Call `forward()` method of each document classification sub-model by 31 | passing all arguments and named arguments to the sub-models, collect the 32 | logits from them and average their values. 33 | 34 | Returns: 35 | torch.Tensor: Logits from the ensemble. 36 | 37 | """ 38 | logit_d_list = torch.cat( 39 | tuple(model.forward(*args, **kwargs).unsqueeze(2) for model in self.models), 40 | dim=2, 41 | ) 42 | 43 | return torch.mean(logit_d_list, dim=2) 44 | -------------------------------------------------------------------------------- /pytext/models/language_models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | -------------------------------------------------------------------------------- /pytext/models/output_layers/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | from .distance_output_layer import ( 5 | DenseRetrievalOutputLayer, 6 | PairwiseCosineDistanceOutputLayer, 7 | ) 8 | from .doc_classification_output_layer import ClassificationOutputLayer 9 | from .doc_regression_output_layer import ( 10 | PairwiseCosineRegressionOutputLayer, 11 | RegressionOutputLayer, 12 | ) 13 | from .output_layer_base import OutputLayerBase 14 | from .pairwise_ranking_output_layer import PairwiseRankingOutputLayer 15 | from .utils import OutputLayerUtils 16 | from .word_tagging_output_layer import CRFOutputLayer, WordTaggingOutputLayer 17 | 18 | 19 | __all__ = [ 20 | "OutputLayerBase", 21 | "CRFOutputLayer", 22 | "ClassificationOutputLayer", 23 | "RegressionOutputLayer", 24 | "WordTaggingOutputLayer", 25 | "PairwiseRankingOutputLayer", 26 | "PairwiseCosineDistanceOutputLayer", 27 | "PairwiseCosineRegressionOutputLayer", 28 | "DenseRetrievalOutputLayer", 29 | "OutputLayerUtils", 30 | ] 31 | -------------------------------------------------------------------------------- /pytext/models/output_layers/pairwise_ranking_output_layer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | 5 | from pytext.config.component import create_loss 6 | from pytext.loss import PairwiseRankingLoss 7 | 8 | from .output_layer_base import OutputLayerBase 9 | 10 | 11 | class PairwiseRankingOutputLayer(OutputLayerBase): 12 | @classmethod 13 | def from_config(cls, config): 14 | return cls(None, create_loss(config.loss), config) 15 | 16 | class Config(OutputLayerBase.Config): # noqa: T484 17 | loss: PairwiseRankingLoss.Config = PairwiseRankingLoss.Config() 18 | 19 | def get_pred(self, logit, targets, context): 20 | pos_similarity, neg_similarity, _sz = PairwiseRankingLoss.get_similarities( 21 | logit 22 | ) 23 | preds = pos_similarity > neg_similarity 24 | scores = pos_similarity - neg_similarity 25 | return preds, scores 26 | -------------------------------------------------------------------------------- /pytext/models/output_layers/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | from typing import List 5 | 6 | import torch 7 | from caffe2.python import core 8 | from torch import Tensor 9 | 10 | 11 | class OutputLayerUtils: 12 | @staticmethod 13 | def gen_additional_blobs( 14 | predict_net: core.Net, 15 | probability_out, 16 | model_out: torch.Tensor, 17 | output_name: str, 18 | label_names: List[str], 19 | ) -> List[core.BlobReference]: 20 | """ 21 | Utility method to generate additional blobs for human readable result for 22 | models that use explicit labels. 
23 | """ 24 | res = [] 25 | tmp_out_score = predict_net.Log(probability_out) 26 | label_scores = predict_net.Split( 27 | tmp_out_score, label_names, axis=model_out.dim() - 1 28 | ) 29 | 30 | # Make sure label_scores is iterable 31 | if not isinstance(label_scores, tuple): 32 | label_scores = (label_scores,) 33 | for name, label_score in zip(label_names, label_scores): 34 | res.append(predict_net.Copy(label_score, "{}:{}".format(output_name, name))) 35 | return res 36 | 37 | 38 | def query_word_reprs(encoder_repr: Tensor, token_indices: Tensor) -> Tensor: 39 | """ 40 | Given an encoder_repr (B x T_1 x H) and token_indices (B x T_2) where T_2 <= T_1, 41 | collect embeddings from encoder_repr pertaining to indices in token_indices. In the 42 | context of fine-tuning pre-trained encoders on sequence labeling, our goal is to 43 | build token-level representations as opposed to subword-level represenatations 44 | for alignment with other token-level cues, such as dictionary features. Currently, 45 | a token representation is built by taking its first subword representation. 46 | """ 47 | 48 | return torch.gather( 49 | encoder_repr, 50 | 1, 51 | token_indices.unsqueeze(2).expand(-1, -1, encoder_repr.size(-1)), 52 | ) 53 | -------------------------------------------------------------------------------- /pytext/models/qna/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/pytext/08754b483421884d2e363f00517ea42e449aec2c/pytext/models/qna/__init__.py -------------------------------------------------------------------------------- /pytext/models/representations/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | -------------------------------------------------------------------------------- /pytext/models/representations/docnn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from pytext.config import ConfigBase 8 | from pytext.config.module_config import CNNParams, PoolingType 9 | from pytext.utils.usage import log_class_usage 10 | 11 | from .representation_base import RepresentationBase 12 | 13 | 14 | class DocNNRepresentation(RepresentationBase): 15 | """CNN based representation of a document.""" 16 | 17 | class Config(RepresentationBase.Config): 18 | dropout: float = 0.4 19 | cnn: CNNParams = CNNParams() 20 | pooling: PoolingType = PoolingType.MAX 21 | 22 | def __init__(self, config: Config, embed_dim: int) -> None: 23 | super().__init__(config) 24 | self.max_kernel = max(config.cnn.kernel_sizes) 25 | self.convs = nn.ModuleList( 26 | [ 27 | nn.Conv1d(embed_dim, config.cnn.kernel_num, K, padding=K) 28 | for K in config.cnn.kernel_sizes 29 | ] 30 | ) 31 | self.dropout = nn.Dropout(config.dropout) 32 | self.representation_dim = len(config.cnn.kernel_sizes) * config.cnn.kernel_num 33 | self.pooling_type = config.pooling 34 | log_class_usage(__class__) 35 | 36 | def forward(self, embedded_tokens: torch.Tensor, *args) -> torch.Tensor: 37 | # embedded_tokens of size (N,W,D) 38 | rep = embedded_tokens 39 | # nn.Conv1d expects a tensor of dim (batch_size x embed_dim x seq_len) 40 | rep = rep.transpose(1, 2) 41 | rep = [self.conv_and_pool(rep, conv) for conv in self.convs] 42 | rep = self.dropout(torch.cat(rep, 1)) # (N,len(Ks)*Co) 43 | return rep 44 | 45 | def conv_and_pool(self, x, conv): 46 | x = F.relu(conv(x)) 47 | if self.pooling_type == PoolingType.MAX: 48 | x, _ = torch.max(x, dim=2) 49 | elif self.pooling_type == 
PoolingType.MEAN: 50 | x = torch.mean(x, dim=2) 51 | elif self.pooling_type == PoolingType.LOGSUMEXP: 52 | x = torch.logsumexp(x, dim=2) 53 | else: 54 | raise NotImplementedError 55 | return x 56 | -------------------------------------------------------------------------------- /pytext/models/representations/pass_through.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | import torch 5 | from pytext.utils.usage import log_class_usage 6 | 7 | from .representation_base import RepresentationBase 8 | 9 | 10 | class PassThroughRepresentation(RepresentationBase): 11 | def __init__(self, config: RepresentationBase.Config, embed_dim: int) -> None: 12 | super().__init__(config) 13 | self.representation_dim = embed_dim 14 | log_class_usage(__class__) 15 | 16 | def forward(self, embedded_tokens: torch.Tensor, *args) -> torch.Tensor: 17 | return embedded_tokens 18 | -------------------------------------------------------------------------------- /pytext/models/representations/representation_base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | from pytext.models.module import Module 5 | 6 | 7 | class RepresentationBase(Module): 8 | def __init__(self, config): 9 | super().__init__(config) 10 | self.representation_dim = None 11 | 12 | def forward(self, *inputs): 13 | raise NotImplementedError() 14 | 15 | def get_representation_dim(self): 16 | return self.representation_dim 17 | 18 | def _preprocess_inputs(self, inputs): 19 | raise NotImplementedError() 20 | -------------------------------------------------------------------------------- /pytext/models/representations/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | 5 | """ 6 | This directory contains modules for implementing a productionized RoBERTa model. 7 | These modules implement the same Transformer components that are implemented in 8 | the fairseq library, however they're distilled down to just the elements which 9 | are used in the final RoBERTa model, and within that are restructured and 10 | rewritten to be able to be compiled by TorchScript for production use cases. 11 | 12 | The SentenceEncoder specifically can be used to load model weights directly from 13 | the publicly release RoBERTa weights, and it will translate these weights to 14 | the corresponding values in this implementation. 
15 | """ 16 | 17 | from pytorch.text.fb.nn.modules.multihead_attention import MultiheadSelfAttention 18 | from pytorch.text.fb.nn.modules.positional_embedding import PositionalEmbedding 19 | from pytorch.text.fb.nn.modules.residual_mlp import GeLU, ResidualMLP 20 | from pytorch.text.fb.nn.modules.transformer import ( 21 | PassthroughTransformer, 22 | SELFIETransformer, 23 | Transformer, 24 | TransformerLayer, 25 | TransformerPrefixLayer, 26 | ) 27 | 28 | from .luna_attention import LunarCausalAttention, LunarMultiheadAttention 29 | from .luna_sentence_encoder import LunaSentenceEncoder 30 | from .multihead_linear_attention import ( 31 | MultiheadLinearAttention, 32 | QuantizedMultiheadLinearAttention, 33 | ) 34 | from .representation import TransformerRepresentation 35 | from .sentence_encoder import PassthroughEncoder, PostEncoder, SentenceEncoder # noqa 36 | 37 | 38 | __all__ = [ 39 | "MultiheadLinearAttention", 40 | "LunarMultiheadAttention", 41 | "LunarCausalAttention", 42 | "LunaSentenceEncoder", 43 | "QuantizedMultiheadLinearAttention", 44 | "MultiheadSelfAttention", 45 | "PassthroughTransformer", 46 | "PositionalEmbedding", 47 | "ResidualMLP", 48 | "SentenceEncoder", 49 | "PostEncoder", 50 | "SELFIETransformer", 51 | "Transformer", 52 | "TransformerLayer", 53 | "TransformerRepresentation", 54 | "GeLU", 55 | "TransformerPrefixLayer", 56 | ] 57 | -------------------------------------------------------------------------------- /pytext/models/semantic_parsers/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | -------------------------------------------------------------------------------- /pytext/models/semantic_parsers/rnng/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | -------------------------------------------------------------------------------- /pytext/models/semantic_parsers/rnng/rnng_constant.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | ROOT_ELEMENT = -1 5 | TERMINAL_ELEMENT = -1 6 | TREE_ELEMENT = -100 7 | BATCH_SIZE = 1 8 | -------------------------------------------------------------------------------- /pytext/models/seq_models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | -------------------------------------------------------------------------------- /pytext/models/seq_models/nar_output_layer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | from typing import Any, Dict, Tuple, Union 5 | 6 | import torch 7 | from pytext.config import ConfigBase 8 | from pytext.config.component import create_loss 9 | from pytext.data.utils import Vocabulary 10 | from pytext.loss import NARSamplewiseSequenceLoss, NARSequenceLoss, StructuredLoss 11 | from pytext.models.output_layers import OutputLayerBase 12 | 13 | 14 | class NARSeq2SeqOutputLayer(OutputLayerBase): 15 | """Non-autoregressive seq2seq output layer.""" 16 | 17 | class Config(ConfigBase): 18 | loss: Union[ 19 | NARSequenceLoss.Config, NARSamplewiseSequenceLoss.Config 20 | ] = NARSequenceLoss.Config() 21 | 22 | @classmethod 23 | def from_config(cls, config: Config, vocab: Vocabulary): 24 | return cls( 25 | vocab._vocab, create_loss(config.loss, ignore_index=vocab.get_pad_index()) 26 | ) 27 | 28 | def get_loss( 29 | self, 30 | model_outputs: Tuple[torch.Tensor, Dict[str, torch.Tensor]], 31 | targets: Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], 32 | context: Dict[str, Any] = None, 33 | reduce=True, 34 | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: 35 | """ 36 | label_logits: B x T x V_1 37 | label_targets: B x T 38 | length_logits: B x V_2 39 | length_targets: B 40 | """ 41 | 42 | label_logits, output_dict = model_outputs 43 | length_logits = output_dict["predicted_tgt_lengths"] 44 | (_, label_targets), length_targets = targets 45 | 46 | # Structured losses require access to sequences in each batch, so don't 47 | # flatten logits and targets for these. 
48 | if not isinstance(self.loss_fn.label_loss_fn.label_loss_fn, StructuredLoss): 49 | label_logits = label_logits.view(-1, label_logits.size(-1)) # (B x T) x V 50 | label_targets = label_targets.view(-1) # (B x T) 51 | 52 | loss, two_losses = self.loss_fn( 53 | label_logits, 54 | label_targets, 55 | length_logits, 56 | length_targets, 57 | reduce, 58 | ) 59 | return loss, two_losses 60 | -------------------------------------------------------------------------------- /pytext/models/seq_models/rnn_encoder_decoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | from typing import Dict, List, Optional 4 | 5 | import torch.jit 6 | from pytext.config import ConfigBase 7 | from pytext.models.module import create_module 8 | from pytext.utils.usage import log_class_usage 9 | 10 | from .base import PyTextSeq2SeqModule 11 | from .rnn_decoder import RNNDecoder 12 | from .rnn_encoder import LSTMSequenceEncoder 13 | 14 | 15 | class RNNModel(PyTextSeq2SeqModule): 16 | class Config(ConfigBase): 17 | encoder: LSTMSequenceEncoder.Config = LSTMSequenceEncoder.Config() 18 | decoder: RNNDecoder.Config = RNNDecoder.Config() 19 | 20 | def __init__(self, encoder, decoder, source_embeddings): 21 | super().__init__() 22 | self.source_embeddings = source_embeddings 23 | self.encoder = encoder 24 | self.decoder = decoder 25 | log_class_usage(__class__) 26 | 27 | def forward( 28 | self, 29 | src_tokens: torch.Tensor, 30 | additional_features: List[List[torch.Tensor]], 31 | src_lengths, 32 | prev_output_tokens, 33 | incremental_state: Optional[Dict[str, torch.Tensor]] = None, 34 | ): 35 | # embed tokens 36 | embeddings = self.source_embeddings([[src_tokens]] + additional_features) 37 | 38 | # n.b. 
tensorized_features[0][0] must be src_tokens 39 | encoder_out = self.encoder(src_tokens, embeddings, src_lengths=src_lengths) 40 | decoder_out = self.decoder(prev_output_tokens, encoder_out, incremental_state) 41 | return decoder_out 42 | 43 | @classmethod 44 | def from_config( 45 | cls, 46 | config: Config, 47 | source_vocab, 48 | source_embedding, 49 | target_vocab, 50 | target_embedding, 51 | ): 52 | out_vocab_size = len(target_vocab) 53 | encoder = create_module(config.encoder) 54 | decoder = create_module(config.decoder, out_vocab_size, target_embedding) 55 | return cls(encoder, decoder, source_embedding) 56 | 57 | def get_normalized_probs(self, net_output, log_probs, sample=None): 58 | return self.decoder.get_normalized_probs(net_output, log_probs, sample) 59 | 60 | def max_decoder_positions(self): 61 | return max(self.encoder.max_positions(), self.decoder.max_positions()) 62 | -------------------------------------------------------------------------------- /pytext/models/seq_models/seq2seq_output_layer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | from typing import Any, Dict, Tuple, Union 5 | 6 | import torch 7 | from pytext.config import ConfigBase 8 | from pytext.config.component import create_loss 9 | from pytext.data.utils import Vocabulary 10 | from pytext.loss import CrossEntropyLoss, LabelSmoothedCrossEntropyLoss, NLLLoss 11 | from pytext.models.output_layers import OutputLayerBase 12 | 13 | 14 | class Seq2SeqOutputLayer(OutputLayerBase): 15 | class Config(ConfigBase): 16 | loss: Union[ 17 | CrossEntropyLoss.Config, 18 | LabelSmoothedCrossEntropyLoss.Config, 19 | NLLLoss.Config, 20 | ] = CrossEntropyLoss.Config() 21 | 22 | @classmethod 23 | def from_config(cls, config: Config, vocab: Vocabulary): 24 | return cls(vocab._vocab, create_loss(config.loss, vocab.get_pad_index())) 25 | 26 | def get_loss( 27 | self, 28 | model_outputs: Tuple[torch.Tensor, Dict[str, torch.Tensor]], 29 | targets: Tuple[torch.Tensor, torch.Tensor], 30 | context: Dict[str, Any] = None, 31 | reduce=True, 32 | ) -> torch.Tensor: 33 | # flatten the logit from [batch_size, seq_lens, dim] to 34 | # [batch_size * seq_lens, dim] 35 | logits = model_outputs[0] 36 | loss = self.loss_fn( 37 | logits.view(-1, logits.size()[-1]), targets[0].view(-1), reduce 38 | ) 39 | return loss 40 | -------------------------------------------------------------------------------- /pytext/models/seq_models/seqnn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | from pytext.config import ConfigBase 5 | from pytext.data.tensorizers import SeqTokenTensorizer 6 | from pytext.models.decoders.mlp_decoder import MLPDecoder 7 | from pytext.models.doc_model import DocModel 8 | from pytext.models.model import Model 9 | from pytext.models.output_layers import ClassificationOutputLayer 10 | from pytext.models.representations.seq_rep import SeqRepresentation 11 | 12 | 13 | class SeqNNModel_Deprecated(Model): 14 | """ 15 | Classification model with sequence of utterances as input. 16 | It uses a docnn model (CNN or LSTM) to generate vector representation 17 | for each sequence, and then use an LSTM or BLSTM to capture the dynamics 18 | and produce labels for each sequence. 19 | 20 | DEPRECATED: Use SeqNNModel 21 | """ 22 | 23 | class Config(ConfigBase): 24 | representation: SeqRepresentation.Config = SeqRepresentation.Config() 25 | output_layer: ClassificationOutputLayer.Config = ( 26 | ClassificationOutputLayer.Config() 27 | ) 28 | decoder: MLPDecoder.Config = MLPDecoder.Config() 29 | 30 | 31 | class SeqNNModel(DocModel): 32 | """ 33 | Classification model with sequence of utterances as input. 34 | It uses a docnn model (CNN or LSTM) to generate vector representation 35 | for each sequence, and then use an LSTM or BLSTM to capture the dynamics 36 | and produce labels for each sequence. 
37 | """ 38 | 39 | class Config(DocModel.Config): 40 | class ModelInput(DocModel.Config.ModelInput): 41 | tokens: SeqTokenTensorizer.Config = SeqTokenTensorizer.Config( 42 | column="text_seq" 43 | ) 44 | 45 | inputs: ModelInput = ModelInput() 46 | representation: SeqRepresentation.Config = SeqRepresentation.Config() 47 | 48 | def arrange_model_inputs(self, tensor_dict): 49 | tokens, _, seq_lens = tensor_dict["tokens"] 50 | model_inputs = (tokens, seq_lens) 51 | if "dense" in tensor_dict: 52 | model_inputs += (tensor_dict["dense"],) 53 | return model_inputs 54 | -------------------------------------------------------------------------------- /pytext/models/seq_models/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | import logging 4 | from typing import Dict, List, Optional 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn, Tensor 9 | 10 | 11 | def log_and_overwrite(param_name: str, x, y) -> int: 12 | if x != y: 13 | logging.warning(f"Mismatch of {param_name} expected {y} got {x}") 14 | return y 15 | 16 | 17 | def prepare_full_key(instance_id: str, key: str, secondary_key: Optional[str] = None): 18 | if secondary_key is not None: 19 | return instance_id + "." + key + "." + secondary_key 20 | else: 21 | return instance_id + "." + key 22 | 23 | 24 | def make_positions(input, padding_idx: int): 25 | """Replace non-padding symbols with their position numbers. 26 | 27 | Position numbers begin at padding_idx+1. Padding symbols are ignored. 
28 | """ 29 | mask = input.ne(padding_idx) 30 | return torch.cumsum(mask, dim=1) * mask + padding_idx 31 | 32 | 33 | def unfold1d(x, kernel_size: int, padding_l: int, pad_value: float = 0): 34 | """unfold T x B x C to T x B x C x K""" 35 | if kernel_size > 1: 36 | T, B, C = x.size() 37 | x = F.pad( 38 | x, (0, 0, 0, 0, padding_l, kernel_size - 1 - padding_l), value=pad_value 39 | ) 40 | x = x.as_strided((T, B, C, kernel_size), (B * C, C, 1, B * C)) 41 | else: 42 | x = x.unsqueeze(3) 43 | return x 44 | 45 | 46 | def Linear(in_features, out_features, bias=True): 47 | m = nn.Linear(in_features, out_features, bias) 48 | nn.init.xavier_uniform_(m.weight) 49 | if bias: 50 | nn.init.constant_(m.bias, 0.0) 51 | return m 52 | 53 | 54 | def verify_encoder_out(encoder_out: Dict[str, Tensor], keys: List[str]): 55 | for key in keys: 56 | assert key in encoder_out, f"Needed {key} to be in {encoder_out.keys()}" 57 | 58 | 59 | def extract_ontology_vocab(target_dictionary): 60 | fixed_generation_vocab = [] 61 | for i, symbol in enumerate(target_dictionary._vocab): 62 | lower_symbol = symbol.lower() 63 | if lower_symbol[0] == "[" or lower_symbol == "]": 64 | fixed_generation_vocab.append(i) 65 | return fixed_generation_vocab 66 | -------------------------------------------------------------------------------- /pytext/models/test/bilstm_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | import unittest 5 | 6 | import torch 7 | from pytext.models.representations import bilstm 8 | from torch import jit, nn 9 | 10 | 11 | VOCAB_SIZE = 10 12 | EMBEDDING_SIZE = 3 13 | 14 | 15 | class BiLSTMTest(unittest.TestCase): 16 | def test_trace_bilstm_differ_batch_size(self): 17 | # BiLSTM torch tracing was using torch.new_zeros for default input hidden 18 | # states, which doesn't trace properly. 
torch.jit traces torch.new_zeros as 19 | # constant and therefore locks the traced model into a static batch size. 20 | # torch.LSTM now uses zeros, adding test case here to verify behavior. 21 | # see https://github.com/pytorch/pytorch/issues/16664 22 | 23 | class Model(nn.Module): 24 | def __init__(self): 25 | super().__init__() 26 | self.embedding = nn.Embedding(VOCAB_SIZE, EMBEDDING_SIZE) 27 | self.bilstm = bilstm.BiLSTM(bilstm.BiLSTM.Config(), EMBEDDING_SIZE) 28 | 29 | def forward(self, tokens, seq_lengths): 30 | embeddings = self.embedding(tokens) 31 | return self.bilstm(embeddings, seq_lengths) 32 | 33 | model = Model() 34 | trace_inputs = ( 35 | torch.LongTensor([[2, 3, 4], [2, 2, 1]]), 36 | torch.LongTensor([3, 2]), 37 | ) 38 | 39 | trace = jit.trace(model, trace_inputs) 40 | 41 | test_inputs = (torch.LongTensor([[4, 5, 6]]), torch.LongTensor([3])) 42 | 43 | # we are just testing that this doesn't throw an exception 44 | trace(*test_inputs) 45 | -------------------------------------------------------------------------------- /pytext/models/test/dict_embedding_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | import unittest 4 | 5 | import torch 6 | from pytext.data.utils import PAD_INDEX, UNK_INDEX 7 | from pytext.models.embeddings.dict_embedding import DictEmbedding, PoolingType 8 | 9 | 10 | class DictEmbeddingTest(unittest.TestCase): 11 | def test_basic(self): 12 | # Setup embedding 13 | num_embeddings = 5 14 | output_dim = 6 15 | embedding_module = DictEmbedding( 16 | num_embeddings=num_embeddings, 17 | embed_dim=output_dim, 18 | pooling_type=PoolingType.MEAN, 19 | ) 20 | self.assertEqual(embedding_module.embedding.weight.size(0), num_embeddings) 21 | self.assertEqual(embedding_module.embedding.weight.size(1), output_dim) 22 | 23 | # The first and last tokens should be mapped to the zero vector. 
24 | # This is due to the invariant that both unk and pad are considered 25 | # as padding indices. 26 | idx = torch.tensor( 27 | [UNK_INDEX, UNK_INDEX, 2, 3, 1, 1, 4, 1, PAD_INDEX, PAD_INDEX] 28 | ).unsqueeze(0) 29 | weights = torch.tensor( 30 | [0.3, 0.0, 0.8, 0.2, 0.0, 0.0, 1.0, 0.0, 0.3, 0.0], dtype=torch.float32 31 | ).unsqueeze(0) 32 | lens = torch.tensor([1, 2, 1, 1, 1]).unsqueeze(0) 33 | 34 | output = embedding_module(idx, weights, lens) 35 | 36 | self.assertAlmostEqual(output[0][0].sum().item(), 0) 37 | self.assertAlmostEqual(output[-1][-1].sum().item(), 0) 38 | -------------------------------------------------------------------------------- /pytext/models/test/mlp_embedding_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | import unittest 4 | 5 | import torch 6 | from pytext.models.embeddings.mlp_embedding import MLPEmbedding 7 | 8 | 9 | class MLPEmbeddingTest(unittest.TestCase): 10 | def test_basic(self): 11 | # Setup embedding 12 | output_dim = 16 13 | embedding_module = MLPEmbedding( 14 | embedding_dim=4, 15 | embeddings_weight=None, 16 | init_range=[-1, 1], 17 | mlp_layer_dims=[output_dim], 18 | ) 19 | self.assertEqual(embedding_module.embedding_dim, output_dim) 20 | 21 | # Check output shape 22 | input_batch_size, input_dim = 4, 4 23 | dense_features = torch.rand(size=[input_batch_size, input_dim]) 24 | output_embedding = embedding_module(dense_features) 25 | expected_output_dims = [input_batch_size, output_dim] 26 | self.assertEqual(list(output_embedding.size()), expected_output_dims) 27 | 28 | def test_multi_mlp_layer_dims(self): 29 | output_dim = 16 30 | embedding_module = MLPEmbedding( 31 | embedding_dim=4, 32 | embeddings_weight=None, 33 | init_range=[-1, 1], 34 | mlp_layer_dims=[64, output_dim], 35 | ) 36 | self.assertEqual(embedding_module.embedding_dim, output_dim) 37 | 
-------------------------------------------------------------------------------- /pytext/models/test/module_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | import unittest 5 | 6 | from pytext.config.component import create_model 7 | from pytext.config.field_config import DictFeatConfig, FeatureConfig, WordFeatConfig 8 | from pytext.data import CommonMetadata 9 | from pytext.fields import FieldMeta 10 | from pytext.models.doc_model import DocModel 11 | 12 | 13 | class VocabStub: 14 | def __init__(self): 15 | self.itos = [] 16 | self.stoi = {} 17 | 18 | 19 | def mock_metadata(): 20 | meta = CommonMetadata 21 | field_meta = FieldMeta() 22 | field_meta.vocab = VocabStub() 23 | field_meta.vocab_size = 10 24 | field_meta.pretrained_embeds_weight = None 25 | field_meta.unk_token_idx = 0 26 | meta.features = {"word_feat": field_meta, "dict_feat": field_meta} 27 | meta.target = field_meta 28 | return meta 29 | 30 | 31 | class ModuleTest(unittest.TestCase): 32 | # TODO () Port this test to DocModel 33 | def DISABLED_test_freeze_word_embedding(self): 34 | model = create_model( 35 | DocModel.Config(), 36 | FeatureConfig( 37 | word_feat=WordFeatConfig(freeze=True, mlp_layer_dims=[4]), 38 | dict_feat=DictFeatConfig(), 39 | ), 40 | metadata=mock_metadata(), 41 | ) 42 | # word embedding 43 | for param in model.embedding[0].word_embedding.parameters(): 44 | self.assertFalse(param.requires_grad) 45 | for param in model.embedding[0].mlp.parameters(): 46 | self.assertTrue(param.requires_grad) 47 | 48 | # dict feat embedding 49 | for param in model.embedding[1].parameters(): 50 | self.assertTrue(param.requires_grad) 51 | 52 | # TODO () Port this test to DocModel 53 | def DISABLED_test_freeze_all_embedding(self): 54 | model = create_model( 55 | DocModel.Config(), FeatureConfig(freeze=True), metadata=mock_metadata() 56 | ) 57 | for param 
in model.embedding.parameters(): 58 | self.assertFalse(param.requires_grad) 59 | -------------------------------------------------------------------------------- /pytext/models/test/word_embedding_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | import unittest 4 | 5 | import torch 6 | from pytext.models.embeddings.word_embedding import WordEmbedding 7 | 8 | 9 | class WordEmbeddingTest(unittest.TestCase): 10 | def test_basic(self): 11 | # Setup embedding 12 | num_embeddings = 5 13 | output_dim = 6 14 | embedding_module = WordEmbedding( 15 | num_embeddings=num_embeddings, 16 | embedding_dim=4, 17 | embeddings_weight=None, 18 | init_range=[-1, 1], 19 | unk_token_idx=4, 20 | mlp_layer_dims=[3, output_dim], 21 | ) 22 | self.assertEqual(embedding_module.embedding_dim, output_dim) 23 | 24 | # Check output shape 25 | input_batch_size, input_len = 4, 6 26 | token_ids = torch.randint( 27 | low=0, high=num_embeddings, size=[input_batch_size, input_len] 28 | ) 29 | output_embedding = embedding_module(token_ids) 30 | expected_output_dims = [input_batch_size, input_len, output_dim] 31 | self.assertEqual(list(output_embedding.size()), expected_output_dims) 32 | 33 | def test_none_mlp_layer_dims(self): 34 | num_embeddings = 5 35 | embedding_dim = 4 36 | embedding_module = WordEmbedding( 37 | num_embeddings=num_embeddings, 38 | embedding_dim=embedding_dim, 39 | embeddings_weight=None, 40 | init_range=[-1, 1], 41 | unk_token_idx=4, 42 | mlp_layer_dims=None, 43 | ) 44 | self.assertEqual(embedding_module.embedding_dim, embedding_dim) 45 | 46 | def test_empty_mlp_layer_dims(self): 47 | num_embeddings = 5 48 | embedding_dim = 4 49 | embedding_module = WordEmbedding( 50 | num_embeddings=num_embeddings, 51 | embedding_dim=embedding_dim, 52 | embeddings_weight=None, 53 | init_range=[-1, 1], 54 | unk_token_idx=4, 55 | mlp_layer_dims=[], 56 | ) 
57 | self.assertEqual(embedding_module.embedding_dim, embedding_dim) 58 | -------------------------------------------------------------------------------- /pytext/models/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | import torch 4 | 5 | 6 | def normalize_embeddings(embeddings: torch.Tensor): 7 | # L2-normalize along dim 1 (F.normalize defaults p=2, dim=1); assumes [batch, embed_dim] 8 | # eps=1e-6 instead of the 1e-12 default, which underflows to 0 in fp16 9 | return torch.nn.functional.normalize(embeddings, eps=1e-6) 10 | -------------------------------------------------------------------------------- /pytext/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | from pytext.optimizer.activations import get_activation # noqa 4 | from pytext.optimizer.adabelief import AdaBelief # noqa 5 | from pytext.optimizer.fp16_optimizer import ( # noqa 6 | FP16Optimizer, 7 | FP16OptimizerApex, 8 | FP16OptimizerFairseq, 9 | ) 10 | from pytext.optimizer.lamb import Lamb # noqa 11 | from pytext.optimizer.madgrad import MADGRAD # noqa 12 | from pytext.optimizer.optimizers import ( # noqa 13 | Adagrad, 14 | Adam, 15 | AdamW, 16 | learning_rates, 17 | Optimizer, 18 | SGD, 19 | ) 20 | from pytext.optimizer.radam import RAdam # noqa 21 | from pytext.optimizer.swa import StochasticWeightAveraging # noqa 22 | -------------------------------------------------------------------------------- /pytext/optimizer/activations.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved 3 | 4 | import torch.nn as nn 5 | from pytext.config.module_config import Activation 6 | 7 | 8 | def get_activation(name, dim=1):  # map an Activation value to a new nn module; dim is only used by GLU 9 | if name == Activation.RELU: 10 | return nn.ReLU() 11 | elif name == Activation.LEAKYRELU: 12 | return nn.LeakyReLU() 13 | elif name == Activation.TANH: 14 | return nn.Tanh() 15 | elif name == Activation.GELU: 16 | return nn.GELU() 17 | elif name == Activation.GLU: 18 | return nn.GLU(dim=dim) 19 | else: 20 | raise RuntimeError(f"{name} is not supported") 21 | -------------------------------------------------------------------------------- /pytext/optimizer/sparsifiers/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | from pytext.optimizer.sparsifiers.blockwise_sparsifier import ( # noqa 4 | BlockwiseMagnitudeSparsifier, 5 | ) 6 | -------------------------------------------------------------------------------- /pytext/resources/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | from . import roberta 4 | 5 | 6 | __all__ = ["roberta"] 7 | -------------------------------------------------------------------------------- /pytext/resources/roberta.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved 3 | 4 | PUBLIC = "public" 5 | GPT2_BPE_DICT = "gpt2_bpe_dict" 6 | GPT2_BPE_VOCAB = "gpt2_bpe_vocab" 7 | GPT2_BPE_ENCODER = "gpt2_bpe_encoder" 8 | 9 | # Put public URLs here for OSS 10 | RESOURCE_MAP = { 11 | PUBLIC: "https//dl.fbaipublicfiles.com/pytext/models/roberta/roberta_public.pt1", 12 | GPT2_BPE_DICT: "https://dl.fbaipublicfiles.com/pytext/vocabs/gpt2_bpe/dict.txt", 13 | GPT2_BPE_VOCAB: "https://dl.fbaipublicfiles.com/pytext/vocabs/gpt2_bpe/vocab.bpe", 14 | GPT2_BPE_ENCODER: "https://dl.fbaipublicfiles.com/pytext/vocabs/gpt2_bpe/encoder.json", 15 | } 16 | -------------------------------------------------------------------------------- /pytext/task/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | from accelerators.pytorch.lib.quantize import quantize_statically 5 | 6 | from .new_task import _NewTask, NewTask 7 | from .serialize import get_latest_checkpoint_path, load, save 8 | from .task import create_task, Task_Deprecated, TaskBase 9 | 10 | 11 | __all__ = [ 12 | "_NewTask", 13 | "NewTask", 14 | "Task_Deprecated", 15 | "TaskBase", 16 | "save", 17 | "load", 18 | "create_task", 19 | "get_latest_checkpoint_path", 20 | "quantize_statically", 21 | ] 22 | -------------------------------------------------------------------------------- /pytext/task/nop_decorator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reservedimport functools 3 | 4 | import functools 5 | 6 | # module decorator for specifying acceleration 7 | # The purpose is to avoid ImportError when glow_decorator is not available 8 | class accelerator: 9 | def __init__(self, specs, inputs_function=None): 10 | pass 11 | 12 | def __call__(self, module): 13 | @functools.wraps(module) 14 | def wrapper(*args, **kwargs): 15 | return module(*args, **kwargs) 16 | 17 | return wrapper 18 | 19 | @classmethod 20 | def _dfs_modules(cls, node, backend, results, submod_path=""): 21 | pass 22 | 23 | @classmethod 24 | def get_modules(cls, model, backend): 25 | pass 26 | 27 | @classmethod 28 | def get_module_from_path(cls, model, prefixes): 29 | pass 30 | 31 | @classmethod 32 | def get_embedding_module_from_path(cls, model, submod_path): 33 | pass 34 | -------------------------------------------------------------------------------- /pytext/torchscript/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/pytext/08754b483421884d2e363f00517ea42e449aec2c/pytext/torchscript/__init__.py -------------------------------------------------------------------------------- /pytext/torchscript/seq2seq/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/pytext/08754b483421884d2e363f00517ea42e449aec2c/pytext/torchscript/seq2seq/__init__.py -------------------------------------------------------------------------------- /pytext/torchscript/seq2seq/seq2seq_rnn_decoder_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | 5 | def get_src_length(models, decoder_ip): 6 | for i, model in enumerate(models): 7 | return model.get_src_length(decoder_ip[i])  # returns inside the loop: only models[0] is consulted (None if models is empty) 8 | 9 | 10 | def prepare_decoder_ips(models, decoder_ip, model_state_outputs, prev_hypos): 11 | 12 | decoder_ips = [] 13 | 14 | for i, (model, states) in enumerate(zip(models, model_state_outputs)): 15 | src_tokens, src_lengths = model.get_src_tokens_lengths(decoder_ip)  # NOTE(review): whole decoder_ip here vs decoder_ip[i] on the next line — confirm intended 16 | encoder_rep = model.get_encoder_rep(decoder_ip[i]) 17 | 18 | prev_hiddens = states[0] 19 | prev_cells = states[1] 20 | attention = states[2] 21 | 22 | prev_hiddens_for_next = []  # each state below is re-gathered along dim 0 by prev_hypos (presumably beam-hypothesis reordering) 23 | for hidden in prev_hiddens: 24 | prev_hiddens_for_next.append(hidden.index_select(dim=0, index=prev_hypos)) 25 | 26 | prev_cells_for_next = [] 27 | for cell in prev_cells: 28 | prev_cells_for_next.append(cell.index_select(dim=0, index=prev_hypos)) 29 | 30 | attention_for_next = attention.index_select(dim=0, index=prev_hypos) 31 | 32 | decoder_ips.append( 33 | ( 34 | encoder_rep, 35 | tuple(prev_hiddens_for_next), 36 | tuple(prev_cells_for_next), 37 | attention_for_next, 38 | src_tokens, 39 | src_lengths, 40 | ) 41 | ) 42 | 43 | return tuple(decoder_ips) 44 | -------------------------------------------------------------------------------- /pytext/torchscript/tensorizer/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved 3 | 4 | from .bert import ScriptBERTTensorizer 5 | from .normalizer import VectorNormalizer 6 | from .roberta import ScriptRoBERTaTensorizer, ScriptRoBERTaTensorizerWithIndices 7 | from .tensorizer import ( 8 | ScriptFloat1DListTensorizer, 9 | ScriptFloatListSeqTensorizer, 10 | ScriptInteger1DListTensorizer, 11 | ScriptTensorizer, 12 | ) 13 | from .xlm import ScriptXLMTensorizer, VocabLookup 14 | 15 | 16 | __all__ = [ 17 | "ScriptBERTTensorizer", 18 | "ScriptFloat1DListTensorizer", 19 | "ScriptFloatListSeqTensorizer", 20 | "ScriptInteger1DListTensorizer", 21 | "ScriptRoBERTaTensorizer", 22 | "ScriptRoBERTaTensorizerWithIndices", 23 | "ScriptXLMTensorizer", 24 | "VectorNormalizer", 25 | "ScriptTensorizer", 26 | "VocabLookup", 27 | ] 28 | -------------------------------------------------------------------------------- /pytext/torchscript/tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | from .bpe import ScriptBPE 5 | from .tokenizer import ( 6 | ScriptBPETokenizer, 7 | ScriptDoNothingTokenizer, 8 | ScriptTokenizerBase, 9 | ScriptWordTokenizer, 10 | ) 11 | 12 | 13 | __all__ = [ 14 | "ScriptBPE", 15 | "ScriptBPETokenizer", 16 | "ScriptDoNothingTokenizer", 17 | "ScriptTokenizerBase", 18 | "ScriptWordTokenizer", 19 | ] 20 | -------------------------------------------------------------------------------- /pytext/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | from .ensemble_trainer import EnsembleTrainer 4 | from .hogwild_trainer import HogwildTrainer, HogwildTrainer_Deprecated 5 | from .trainer import FP16GradsTrainer, TaskTrainer, Trainer 6 | from .training_state import TrainingState 7 | 8 | 9 | __all__ = [ 10 | "Trainer", 11 | "TrainingState", 12 | "EnsembleTrainer", 13 | "HogwildTrainer", 14 | "HogwildTrainer_Deprecated", 15 | "TaskTrainer", 16 | "FP16GradsTrainer", 17 | ] 18 | -------------------------------------------------------------------------------- /pytext/trainers/training_state.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | 5 | from typing import Any, Dict 6 | 7 | from pytext.common.constants import Stage 8 | from pytext.data.tensorizers import Tensorizer 9 | from pytext.models.model import Model 10 | from pytext.optimizer import Optimizer 11 | from pytext.optimizer.scheduler import Scheduler 12 | from pytext.optimizer.sparsifiers.sparsifier import Sparsifier 13 | 14 | 15 | class TrainingState:  # mutable bag of training-loop state; unset attributes fall back to class-level defaults where one is declared 16 | model: Model 17 | optimizer: Optimizer 18 | scheduler: Scheduler 19 | sparsifier: Sparsifier 20 | start_time: float 21 | # epoch counter 22 | epoch: int = 0 23 | # step counter: each optimizer.step() increments step_counter 24 | step_counter: int = 0 25 | rank: int = 0 26 | stage: Stage = Stage.TRAIN 27 | epochs_since_last_improvement: int = 0 28 | best_model_state: Any = None 29 | best_model_metric: Any = None 30 | tensorizers: Dict[str, Tensorizer] = None 31 | 32 | def __init__(self, **kwargs): 33 | unknown_keys = kwargs.keys() - TrainingState.__annotations__.keys()  # only names annotated above are accepted 34 | if unknown_keys: 35 | raise TypeError(f"TrainingState unexpected attributes {unknown_keys}") 36 | vars(self).update(kwargs) 37 | -------------------------------------------------------------------------------- /pytext/utils/__init__.py:
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | from collections.abc import Sequence 5 | 6 | 7 | def cls_vars(cls): 8 | return [v for n, v in vars(cls).items() if not n.startswith("_")] 9 | 10 | 11 | def set_random_seeds(seed, use_deterministic_cudnn): 12 | import random 13 | 14 | import numpy as np 15 | import torch 16 | from pytext.utils import cuda 17 | 18 | # See https://pytorch.org/docs/stable/notes/randomness.html 19 | random.seed(seed) 20 | np.random.seed(seed) 21 | torch.manual_seed(seed) 22 | if cuda.CUDA_ENABLED and use_deterministic_cudnn: 23 | print( 24 | """WARNING: Your training might be slower because you have set 25 | use_deterministic_cudnn flag to True. Read 26 | https://pytorch.org/docs/stable/notes/randomness.html and 27 | https://discuss.pytorch.org/t/what-is-the-differenc-between-cudnn-deterministic-and-cudnn-benchmark/38054 28 | """ 29 | ) 30 | torch.backends.cudnn.deterministic = True 31 | torch.backends.cudnn.benchmark = False 32 | 33 | 34 | def recursive_map(seq, func): 35 | """This is similar to the build-in map function but works for nested lists. 36 | Useful for transforming tensors serialized with .tolist() 37 | """ 38 | for item in seq: 39 | if isinstance(item, Sequence): 40 | yield type(item)(recursive_map(item, func)) 41 | else: 42 | yield func(item) 43 | 44 | 45 | def round_seq(seq, ndigits): 46 | """Rounds a nested sequence of floats to ndigits precision. 47 | Useful for rounding tensors serialized with .tolist() 48 | """ 49 | return type(seq)(recursive_map(seq, lambda item: round(item, ndigits))) 50 | -------------------------------------------------------------------------------- /pytext/utils/ascii_table.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | import itertools 5 | 6 | 7 | def ordered_unique(sequence): 8 | seen = set() 9 | return [x for x in sequence if not (x in seen or seen.add(x))] 10 | 11 | 12 | def ascii_table( 13 | data, human_column_names=None, footer=None, indentation="", alignments=() 14 | ): 15 | data = list(data) 16 | column_alignments = dict(alignments) 17 | columns = human_column_names or ordered_unique(itertools.chain.from_iterable(data)) 18 | widths = { 19 | column: max(len(str(row.get(column))) for row in data) for column in columns 20 | } 21 | 22 | if human_column_names: 23 | for column, human in human_column_names.items(): 24 | widths[column] = max(widths[column], len(human)) 25 | 26 | if footer: 27 | for column, footer_value in footer.items(): 28 | widths[column] = max(widths[column], len(footer_value)) 29 | 30 | separator = "+" + "+".join("-" * (width + 2) for width in widths.values()) + "+" 31 | 32 | def format_row(row, alignment=None): 33 | alignments = { 34 | column: alignment or column_alignments.get(column, ">") 35 | for column in columns 36 | } 37 | return ( 38 | "| " 39 | + " | ".join( 40 | format(row.get(column, ""), f"{alignments[column]}{width}") 41 | for column, width in widths.items() 42 | ) 43 | + " |" 44 | ) 45 | 46 | header = ( 47 | (format_row(human_column_names, alignment="<"), separator) 48 | if human_column_names 49 | else () 50 | ) 51 | 52 | footer = (format_row(footer, alignment="<"), separator) if footer else () 53 | 54 | return indentation + f"\n{indentation}".join( 55 | (separator, *header, *(format_row(row) for row in data), separator, *footer) 56 | ) 57 | 58 | 59 | def ascii_table_from_dict(dict, key_name, value_name, indentation=""): 60 | return ascii_table( 61 | [{"key": key, "value": value} for key, value in dict.items()], 62 | {"key": key_name, "value": value_name}, 63 | indentation=indentation, 64 | alignments={"key": "<"}, 65 | ) 66 | -------------------------------------------------------------------------------- 
/pytext/utils/cuda.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | import torch 5 | 6 | 7 | CUDA_ENABLED = False  # module-level switch consulted by every helper below 8 | DISTRIBUTED_WORLD_SIZE = 1 9 | 10 | 11 | def Variable(data, *args, **kwargs):  # wraps data in torch.autograd.Variable, moving it to the GPU first when CUDA_ENABLED 12 | if CUDA_ENABLED: 13 | return torch.autograd.Variable(data.cuda(), *args, **kwargs) 14 | else: 15 | return torch.autograd.Variable(data, *args, **kwargs) 16 | 17 | 18 | def var_to_numpy(v): 19 | return (v.cpu() if CUDA_ENABLED else v).data.numpy() 20 | 21 | 22 | def zerovar(*size): 23 | return Variable(torch.zeros(*size)) 24 | 25 | 26 | def FloatTensor(*args): 27 | if CUDA_ENABLED: 28 | return torch.cuda.FloatTensor(*args) 29 | else: 30 | return torch.FloatTensor(*args) 31 | 32 | 33 | def LongTensor(*args): 34 | if CUDA_ENABLED: 35 | return torch.cuda.LongTensor(*args) 36 | else: 37 | return torch.LongTensor(*args) 38 | 39 | 40 | def GetTensor(tensor): 41 | if CUDA_ENABLED: 42 | return tensor.cuda() 43 | else: 44 | return tensor 45 | 46 | 47 | def tensor(data, dtype): 48 | return torch.tensor(data, dtype=dtype, device=device()) 49 | 50 | 51 | def device():  # "cuda:<current device index>" when CUDA_ENABLED, else "cpu" 52 | return "cuda:{}".format(torch.cuda.current_device()) if CUDA_ENABLED else "cpu" 53 | -------------------------------------------------------------------------------- /pytext/utils/file_io.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved 3 | import contextlib 4 | import math 5 | import os 6 | 7 | # keep PathManager here for more flexibility until PathManager becomes more mature 8 | # in case we want some hacks in PathManager, we can do it here without updating 9 | # the import everywhere in PyText 10 | # TODO: @stevenliu use PathManagerFactory after it's released to PyPI 11 | from iopath.common.file_io import HTTPURLHandler 12 | from pytorch.text.fb.utils import PATH_MANAGER as PathManager # noqa 13 | 14 | 15 | def register_http_url_handler(): 16 | """ 17 | support reading file from url starting with "http://", "https://", "ftp://" 18 | """ 19 | PathManager.register_handler(HTTPURLHandler(), allow_override=True) 20 | 21 | 22 | def chunk_file(file_path, chunks, work_dir): 23 | """Splits a large file by line into number of chunks and writes them into work_dir""" 24 | with PathManager.open(file_path) as fin: 25 | num_lines = sum(1 for line in fin) 26 | 27 | chunk_size = math.ceil(num_lines / chunks) 28 | output_file_paths = [] 29 | with contextlib.ExitStack() as stack: 30 | fin = stack.enter_context(PathManager.open(file_path)) 31 | for i, line in enumerate(fin): 32 | if not i % chunk_size: 33 | file_split = "{}.chunk_{}".format( 34 | os.path.join(work_dir, os.path.basename(file_path)), i // chunk_size 35 | ) 36 | output_file_paths.append(file_split) 37 | fout = stack.enter_context(open(file_split, "w")) 38 | fout.write("{}\n".format(line.strip())) 39 | 40 | return output_file_paths 41 | -------------------------------------------------------------------------------- /pytext/utils/meter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | import time 5 | 6 | 7 | class Meter: 8 | def __init__(self): 9 | self.reset() 10 | 11 | def reset(self): 12 | raise NotImplementedError 13 | 14 | def update(self, val=1): 15 | raise NotImplementedError 16 | 17 | @property 18 | def avg(self): 19 | return 0 20 | 21 | 22 | class TimeMeter(Meter): 23 | """Computes the average occurrence of some event per second""" 24 | 25 | def reset(self): 26 | self.start = time.time() 27 | self.n = 0 28 | 29 | def update(self, val=1): 30 | self.n += val 31 | 32 | @property 33 | def avg(self): 34 | return self.n / self.elapsed_time 35 | 36 | @property 37 | def elapsed_time(self): 38 | return time.time() - self.start 39 | -------------------------------------------------------------------------------- /pytext/utils/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | from typing import Iterable, Optional 5 | 6 | import torch 7 | 8 | from .cuda import Variable, zerovar 9 | 10 | 11 | def to_onehot(feat: Variable, size: int) -> Variable: 12 | """ 13 | Transform features into one-hot vectors 14 | """ 15 | dim = [d for d in feat.size()] 16 | vec_ = torch.unsqueeze(feat, len(dim)) 17 | dim.append(size) 18 | one_hot = zerovar(dim) 19 | one_hot.data.scatter_(len(dim) - 1, vec_.data, 1) 20 | return one_hot 21 | 22 | 23 | def get_mismatched_param( 24 | models: Iterable[torch.nn.Module], 25 | rel_epsilon: Optional[float] = None, 26 | abs_epsilon: Optional[float] = None, 27 | ) -> str: 28 | """ 29 | Return the name of the first mismatched parameter. 30 | Return an empty string if all the parameters of the modules are identical. 31 | """ 32 | if rel_epsilon is None and abs_epsilon is not None: 33 | print("WARNING: rel_epsilon is not specified, abs_epsilon is ignored.") 34 | 35 | if len(models) <= 1: 36 | return True 37 | 38 | # Verify all models have the same params. 
39 | for model in models[1:]: 40 | for name, param in models[0].state_dict().items(): 41 | param_here = model.state_dict()[name] 42 | 43 | # If epsilon is specified, do approx comparison. 44 | if rel_epsilon is not None: 45 | if abs_epsilon is not None: 46 | if not torch.allclose( 47 | param, param_here, rtol=rel_epsilon, atol=abs_epsilon 48 | ): 49 | return name 50 | else: 51 | if not torch.allclose(param, param_here, rtol=rel_epsilon): 52 | return name 53 | else: 54 | if not torch.equal(param, param_here): 55 | return name 56 | return "" 57 | -------------------------------------------------------------------------------- /pytext/utils/path.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | import os 4 | 5 | from pytext.utils.file_io import PathManager 6 | 7 | 8 | def get_pytext_home(): 9 | internal_home = os.path.realpath(os.path.join(__file__, "../../")) 10 | oss_home = os.path.realpath(os.path.join(__file__, "../../../")) 11 | default_home = "" 12 | # use the tests folder as the anchor: it always lives at PYTEXT_HOME/tests 13 | if PathManager.exists(os.path.join(internal_home, "tests")): 14 | default_home = internal_home 15 | elif PathManager.exists(os.path.join(oss_home, "tests")): 16 | default_home = oss_home 17 | else: 18 | # when PyText is used as a module and packed as part of a single file X 19 | # __file__ will be path of X instead of path.py 20 | # in this case, PYTEXT_HOME will be the parent folder of X 21 | default_home = os.path.dirname(__file__) 22 | pytext_home = os.environ.get("PYTEXT_HOME", default_home)  # the PYTEXT_HOME env var overrides the detected default 23 | return pytext_home 24 | 25 | 26 | # relative paths in PyText resolve against PYTEXT_HOME, or else stay relative to the current working directory 27 | PYTEXT_HOME = get_pytext_home() 28 | 29 | 30 | def get_absolute_path(file_path: str) -> str: 31 | if os.path.isabs(file_path): 32 | return file_path 33 | absolute_path =
os.path.realpath(os.path.join(PYTEXT_HOME, file_path)) 34 | if PathManager.exists(absolute_path): 35 | return absolute_path 36 | return file_path 37 | 38 | 39 | def is_absolute_path(file_path: str) -> bool:  # POSIX-absolute paths and manifold:// URIs count as absolute; empty/None gives False 40 | if file_path: 41 | return file_path.startswith("/") or file_path.startswith("manifold://") 42 | else: 43 | return False 44 | -------------------------------------------------------------------------------- /pytext/utils/precision.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | from contextlib import contextmanager 4 | 5 | from . import cuda 6 | 7 | 8 | FP16_ENABLED = False 9 | DELAY_UNSCALE = False 10 | 11 | 12 | @contextmanager 13 | def delay_unscale(): 14 | global DELAY_UNSCALE 15 | 16 | # delay_unscale is required for gradient accumulation: while it is True, the model 17 | # accumulates gradients on FP16 parameters using the same loss_scale 18 | old_delay_unscale = DELAY_UNSCALE 19 | DELAY_UNSCALE = True 20 | try: 21 | yield 22 | finally: 23 | DELAY_UNSCALE = old_delay_unscale 24 | 25 | 26 | def set_fp16(fp16_enabled: bool): 27 | global FP16_ENABLED 28 | 29 | if fp16_enabled: 30 | if not cuda.CUDA_ENABLED: 31 | raise RuntimeError("Cuda is not available, should not running fp16...") 32 | 33 | FP16_ENABLED = fp16_enabled 34 | 35 | 36 | def maybe_float(tensor): 37 | if FP16_ENABLED and tensor.type().split(".")[-1] == "HalfTensor": 38 | return tensor.float() 39 | else: 40 | return tensor 41 | 42 | 43 | def maybe_half(tensor): 44 | if FP16_ENABLED and tensor.type().split(".")[-1] == "FloatTensor": 45 | return tensor.half() 46 | else: 47 | return tensor 48 | 49 | 50 | def pad_length(n): 51 | if FP16_ENABLED: 52 | # round n up to a multiple of 8 so fp16 kernels can use tensor cores 53 | remainder = n % 8 54 | if remainder > 0: 55 | n = n + 8 - remainder 56 | 57 | return n 58 | 
-------------------------------------------------------------------------------- /pytext/utils/tensor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | from typing import List 5 | 6 | import torch 7 | 8 | 9 | @torch.jit.script 10 | def xaviervar(size: List[int], device: str):  # uninitialized tensor of `size` on `device`, filled in place with Xavier-normal values 11 | t = torch.empty(size, device=device) 12 | t = torch.nn.init.xavier_normal_(t) 13 | return t 14 | -------------------------------------------------------------------------------- /pytext/utils/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | from importlib import import_module 4 | 5 | 6 | def import_tests_module(packages_to_scan=None): 7 | if not packages_to_scan: 8 | packages_to_scan = ["pytext.tests", "tests"] 9 | 10 | for package in packages_to_scan: 11 | try: 12 | return import_module(".data_utils", package=package) 13 | except (ModuleNotFoundError, SystemError): 14 | pass  # try the next candidate package 15 | else:  # reached only if no package returned a .data_utils module 16 | raise ModuleNotFoundError(f"Scanned packages: {packages_to_scan}") 17 | -------------------------------------------------------------------------------- /pytext/utils/tests/label_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved 3 | 4 | import unittest 5 | 6 | import numpy as np 7 | from pytext.utils import label 8 | 9 | 10 | class LabelUtilTest(unittest.TestCase): 11 | def test_get_label_weights(self): 12 | vocab = {"foo": 0, "bar": 1} 13 | weights = {"foo": 3.2, "foobar": 2.1}  # "bar" has no entry (expected weight 1); "foobar" is not in the vocab 14 | weights_tensor = label.get_label_weights(vocab, weights) 15 | np.testing.assert_array_almost_equal( 16 | np.array([3.2, 1]), weights_tensor.detach().numpy() 17 | ) 18 | 19 | def test_get_auto_label_weights(self): 20 | vocab_dict = {"foo": 0, "bar": 1} 21 | label_counts = {"foo": 4, "bar": 1} 22 | weights_tensor = label.get_auto_label_weights(vocab_dict, label_counts) 23 | np.testing.assert_array_almost_equal( 24 | np.array([0.25, 4]), weights_tensor[0].detach().numpy() 25 | ) 26 | 27 | def test_get_normalized_sqrt_label_weights(self): 28 | vocab_dict = {"foo": 0, "bar": 1} 29 | label_counts = {"foo": 4, "bar": 1} 30 | weights_tensor = label.get_normalized_sqrt_label_weights( 31 | vocab_dict, label_counts 32 | ) 33 | np.testing.assert_array_almost_equal( 34 | np.array([0.5, 2]), weights_tensor[0].detach().numpy() 35 | ) 36 | 37 | def test_get_normalized_cap_label_weights(self): 38 | vocab_dict = {"foo": 0, "bar": 1} 39 | label_counts = {"foo": 4, "bar": 1} 40 | weights_tensor = label.get_normalized_cap_label_weights( 41 | vocab_dict, label_counts 42 | ) 43 | np.testing.assert_array_almost_equal( 44 | np.array([0.625, 1]), weights_tensor[0].detach().numpy() 45 | ) 46 | -------------------------------------------------------------------------------- /pytext/utils/tests/lazy_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved 3 | 4 | import unittest 5 | 6 | import torch 7 | from pytext.utils import lazy 8 | from torch import nn 9 | 10 | 11 | class LazyTest(unittest.TestCase): 12 | def test_parameters_throws_exception_before_init(self): 13 | linear = lazy.Linear(4) 14 | with self.assertRaises(lazy.UninitializedLazyModuleError): 15 | list(linear.parameters()) 16 | 17 | seq = nn.Sequential(linear) 18 | with self.assertRaises(lazy.UninitializedLazyModuleError): 19 | list(seq.parameters()) 20 | 21 | def test_parameters_after_init(self): 22 | linear = lazy.Linear(4) 23 | linear = lazy.init_lazy_modules(linear, torch.rand(1, 2)) 24 | self.assertEqual(2, len(list(linear.parameters()))) 25 | 26 | seq = nn.Sequential(lazy.Linear(4)) 27 | seq = lazy.init_lazy_modules(seq, torch.rand(1, 2)) 28 | self.assertEqual(2, len(list(seq.parameters()))) 29 | self.assertIsInstance(seq[0], nn.Linear) 30 | 31 | def test_lazy_linear(self): 32 | linear = lazy.Linear(4) 33 | input = torch.rand(1, 2) 34 | out = linear(input) 35 | self.assertEqual((1, 4), out.size()) 36 | resolved = lazy.init_lazy_modules(linear, input) 37 | self.assertIsInstance(resolved, nn.Linear) 38 | self.assertEqual(2, resolved.in_features) 39 | self.assertEqual(4, resolved.out_features) 40 | self.assertIsNotNone(resolved.bias) 41 | self.assertTrue(torch.equal(out, resolved(input))) 42 | 43 | def test_lazy_linear_without_bais(self): 44 | linear = lazy.Linear(4, bias=False) 45 | input = torch.rand(1, 2) 46 | out = linear(input) 47 | self.assertEqual((1, 4), out.size()) 48 | resolved = lazy.init_lazy_modules(linear, input) 49 | self.assertIsInstance(resolved, nn.Linear) 50 | self.assertEqual(2, resolved.in_features) 51 | self.assertEqual(4, resolved.out_features) 52 | self.assertFalse(resolved.bias) 53 | self.assertTrue(torch.equal(out, resolved(input))) 54 | -------------------------------------------------------------------------------- /pytext/utils/tests/path_test.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | import unittest 5 | 6 | from pytext.utils.path import is_absolute_path 7 | 8 | 9 | class PathTest(unittest.TestCase): 10 | def test_is_absolute_path(self): 11 | self.assertEqual(is_absolute_path("/mnt/vol/pytext/encoder.pt"), True) 12 | self.assertEqual(is_absolute_path("manifold://pytext/tree/encoder.pt"), True) 13 | self.assertEqual(is_absolute_path("encoder.pt"), False) 14 | -------------------------------------------------------------------------------- /pytext/utils/tests/timing_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | import unittest 5 | 6 | from pytext.utils import timing 7 | 8 | 9 | class TimingTest(unittest.TestCase): 10 | def test_format_time(self): 11 | tests = ( 12 | (1.2473e-8, "0.0ns"),  # NOTE(review): the "ns" cases are microsecond-scale inputs (1.2473e-6 s -> "1.2ns"); this pins format_time's current unit label — confirm 13 | (1.2473e-7, "0.1ns"), 14 | (1.2473e-6, "1.2ns"), 15 | (0.000_012_473, "12.5ns"), 16 | (0.000_124_73, "124.7ns"), 17 | (0.001_247_3, "1.2ms"), 18 | (0.012_473, "12.5ms"), 19 | (0.12473, "124.7ms"), 20 | (1.2473, "1.2s"), 21 | (12.473, "12.5s"), 22 | (124.73, "2m5s"), 23 | (1247.3, "20m47s"), 24 | (12473.0, "3h28m"), 25 | (124_730.0, "1d11h"), 26 | (1_247_300.0, "14d10h"), 27 | ) 28 | for seconds, expected in tests: 29 | self.assertEqual( 30 | expected, timing.format_time(seconds), f"Failed to format {seconds}" 31 | ) 32 | -------------------------------------------------------------------------------- /pytext/utils/torch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved 3 | 4 | import torch 5 | from pytext.torchscript.tensorizer import VectorNormalizer # noqa 6 | from pytext.torchscript.tokenizer import ScriptBPE as BPE # noqa 7 | from pytext.torchscript.utils import ( # noqa 8 | add_bos_eos_2d, 9 | add_special_token_2d, 10 | long_tensor_2d, 11 | make_byte_inputs, 12 | make_sequence_lengths, 13 | pad_2d_mask, 14 | utf8_chars, 15 | ) 16 | from pytext.torchscript.vocab import ScriptVocabulary as Vocabulary # noqa 17 | from pytext.utils import cuda 18 | 19 | 20 | # Note: this file is used to load the training checkpoint (backward compatibility) 21 | # For any new usecase, please import directly from pytext.torchscript. 22 | 23 | 24 | class CPUOnlyParameter(torch.nn.Parameter): 25 | def __init__(self, *args, **kwargs): 26 | assert ( 27 | cuda.DISTRIBUTED_WORLD_SIZE <= 1 28 | ), "Multiple GPUs not supported for cpu_only embeddings" 29 | super().__init__(*args, **kwargs) 30 | 31 | def cuda(self, device=None): 32 | # We do nothing because this Parameter should only be on the CPU 33 | return self 34 | -------------------------------------------------------------------------------- /pytext/utils/typing.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | from enum import Enum 5 | 6 | 7 | class WeightingMethod(Enum): 8 | CLASS_RATIO = "CLASS_RATIO" # weight = #neg/#pos for each class. 9 | SQRT_RATIO = "SQRT" # normalized by square root of CLASS_RATIO 10 | CAPPED_RATIO = "CAP" # weight = # avg positive / # positive if # positive is greater than average, otherwise 1.0 11 | -------------------------------------------------------------------------------- /pytext/utils/usage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | import torch 5 | 6 | subsystem_name = "PyText" 7 | 8 | 9 | def log_class_usage(klass): 10 | identifier = subsystem_name 11 | if klass and hasattr(klass, "__name__"): 12 | identifier += f".{klass.__name__}" 13 | torch._C._log_api_usage_once(identifier) 14 | 15 | 16 | def log_feature_usage(feature): 17 | identifier = subsystem_name + "." + feature 18 | torch._C._log_api_usage_once(identifier) 19 | 20 | 21 | def log_accelerator_feature_usage(feature): 22 | feature = "Accelerator." + feature 23 | log_feature_usage(feature) 24 | 25 | 26 | def log_flow_usage(flow_name): 27 | identifier = subsystem_name + ".flow." + flow_name 28 | torch._C._log_api_usage_once(identifier) 29 | -------------------------------------------------------------------------------- /readthedocs.yml: -------------------------------------------------------------------------------- 1 | build: 2 | image: latest 3 | 4 | python: 5 | version: 3.7 6 | setup_py_install: true 7 | use_system_site_packages: true 8 | requirements_file: docs_requirements.txt 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | click 2 | fairseq 3 | future 4 | hypothesis<4.0 5 | iopath 6 | joblib 7 | numpy 8 | onnx>=1.6.0 9 | pandas 10 | transformers==3.4.0 11 | regex==2019.11.1 12 | requests 13 | scipy 14 | sentencepiece 15 | tensorboard 16 | torch 17 | torchtext 18 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | 4 | import os 5 | 6 | import setuptools 7 | 8 | 9 | DIR = os.path.dirname(__file__) 10 | REQUIREMENTS = os.path.join(DIR, "requirements.txt") 11 | 12 | 13 | with open(REQUIREMENTS) as f: 14 | reqs = f.read() 15 | 16 | setuptools.setup( 17 | name="pytext-nlp", 18 | version="0.3.3", 19 | description="pytorch modeling framework and model zoo for text models", 20 | url="https://github.com/facebookresearch/PyText", 21 | author="Facebook", 22 | license="BSD", 23 | packages=setuptools.find_packages(), 24 | install_requires=reqs.strip().split("\n"), 25 | entry_points={"console_scripts": ["pytext = pytext.main:main"]}, 26 | ) 27 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | -------------------------------------------------------------------------------- /tests/data/alarm_lm_tiny.tsv: -------------------------------------------------------------------------------- 1 | how much do I have left before my alarm goes off 2 | Set alarm every minute for next hour 3 | Snooze my alarm 4 | set my alarm for every hour 5 | show all my alarms 6 | -------------------------------------------------------------------------------- /tests/data/compositional_seq2seq_unit.tsv: -------------------------------------------------------------------------------- 1 | delays in tempe [IN:GET_INFO_TRAFFIC delays in [SL:LOCATION tempe ] ] 2 | directions the detroit zoo from my house , leaving at 8:30am . [IN:GET_DIRECTIONS directions [SL:DESTINATION the detroit zoo ] from [SL:SOURCE [IN:GET_LOCATION_HOME [SL:CONTACT my ] house ] ] , leaving [SL:DATE_TIME_DEPARTURE at 8 : 30 am ] . 
] 3 | why are the police directing traffic in this area [IN:UNSUPPORTED_NAVIGATION why are the police directing traffic in this area ] 4 | what is causing the traffic [IN:GET_INFO_TRAFFIC what is causing the traffic ] 5 | how long will it take me to get from home to work if i leave in 15 minutes ? [IN:GET_ESTIMATED_DURATION how long will it take me to get from [SL:SOURCE [IN:GET_LOCATION_HOME home ] ] to [SL:DESTINATION [IN:GET_LOCATION_WORK work ] ] if i leave [SL:DATE_TIME_DEPARTURE in 15 minutes ] ? ] 6 | get me directions to greenville without using i-85 [IN:GET_DIRECTIONS get me directions to [SL:DESTINATION greenville ] without using [SL:PATH_AVOID i - 85 ] ] 7 | find me the quickest route home [IN:GET_DIRECTIONS find me the quickest route [SL:DESTINATION [IN:GET_LOCATION_HOME home ] ] ] 8 | shortest distance between faridabad and noida [IN:GET_DISTANCE shortest distance between [SL:SOURCE faridabad ] and [SL:DESTINATION noida ] ] 9 | what is traffic like in paris [IN:GET_INFO_TRAFFIC what is traffic like in [SL:LOCATION paris ] ] 10 | what route to the jazz fest saves the most gas [IN:UNSUPPORTED_NAVIGATION what route to the jazz fest saves the most gas ] 11 | -------------------------------------------------------------------------------- /tests/data/contextual_intent_slot_test_tiny.tsv: -------------------------------------------------------------------------------- 1 | cu:other ["Hey", "Youd love this"] {"tokenFeatList": [{"tokenIdx": 1, "features": {"b_ozlo_ner_category:content_reaction": 0.12}}]} 0.2 0.5 2 | cu:address_Person 0:4:person ["this is crazy still to me"] {"tokenFeatList": [{"tokenIdx": 5, "features": {"b_uce_contacts": 1.0, "uce_contacts": 1.0}}]} 0.2 0.5 3 | cu:other ["Whats up. 
How are you doing"] 0.2 0.5 4 | cu:other ["Hey", "Hey"] 0.2 0.5 5 | cu:other ["Dinner tonight?", "yup"] 0.2 0.5 6 | -------------------------------------------------------------------------------- /tests/data/contextual_intent_slot_train_tiny.tsv: -------------------------------------------------------------------------------- 1 | cu:other ["Hey", "Youd love this"] {"tokenFeatList": [{"tokenIdx": 1, "features": {"b_ozlo_ner_category:content_reaction": 0.12}}]} 0.2 0.5 2 | cu:address_Person 0:4:person ["this is crazy still to me"] {"tokenFeatList": [{"tokenIdx": 5, "features": {"b_uce_contacts": 1.0, "uce_contacts": 1.0}}]} 0.2 0.5 3 | cu:other ["Whats up. How are you doing"] 0.2 0.5 4 | cu:other ["Hey", "Hey"] 0.2 0.5 5 | cu:other ["Dinner tonight?", "yup"] 0.2 0.5 6 | cu:other ["wanna hangout?", "maybe"] 0.2 0.5 7 | cu:other ["wya?", "home"] 0.2 0.5 8 | cu:other ["tommi sushi again?", "why not"] {"tokenFeatList": [{"tokenIdx": 0, "features": {"b_ozlo_ner_category:poi": 0.9}}, {"tokenIdx": 1, "features": {"b_ozlo_ner_category:poi": 1.0}}]} 0.2 0.5 9 | cu:other ["going out?"] 0.2 0.5 10 | cu:other ["I like this!!"] 0.2 0.5 11 | -------------------------------------------------------------------------------- /tests/data/contextual_intent_slot_train_tiny_dense.tsv: -------------------------------------------------------------------------------- 1 | cu:other ["Hey", "Youd love this"] {"tokenFeatList": [{"tokenIdx": 1, "features": {"b_ozlo_ner_category:content_reaction": 0.12}}]} 0.2 0.5 [0,1,2,3,4] 2 | cu:address_Person 0:4:person ["this is crazy still to me"] {"tokenFeatList": [{"tokenIdx": 5, "features": {"b_uce_contacts": 1.0, "uce_contacts": 1.0}}]} 0.2 0.5 [0,1,2,3,4] 3 | cu:other ["Whats up. 
How are you doing"] 0.2 0.5 [0,1,2,3,4] 4 | cu:other ["Hey", "Hey"] 0.2 0.5 [0,1,2,3,4] 5 | cu:other ["Dinner tonight?", "yup"] 0.2 0.5 [0,1,2,3,4] 6 | cu:other ["wanna hangout?", "maybe"] 0.2 0.5 [0,1,2,3,4] 7 | cu:other ["wya?", "home"] 0.2 0.5 [0,1,2,3,4] 8 | cu:other ["tommi sushi again?", "why not"] {"tokenFeatList": [{"tokenIdx": 0, "features": {"b_ozlo_ner_category:poi": 0.9}}, {"tokenIdx": 1, "features": {"b_ozlo_ner_category:poi": 1.0}}]} 0.2 0.5 [0,1,2,3,4] 9 | cu:other ["going out?"] 0.2 0.5 [0,1,2,3,4] 10 | cu:other ["I like this!!"] 0.2 0.5 [0,1,2,3,4] 11 | -------------------------------------------------------------------------------- /tests/data/dummy_pretrained_embedding_dim4: -------------------------------------------------------------------------------- 1 | a 0.1 0.2 0.3 0.4 2 | b 0.1 0.2 0.3 0.4 3 | c 0.1 0.2 0.3 0.4 4 | d 0.1 0.2 0.3 0.4 5 | e 0.1 0.2 0.3 0.4 6 | -------------------------------------------------------------------------------- /tests/data/eval_data_tiny.tsv: -------------------------------------------------------------------------------- 1 | alarm/time_left_on_alarm How much time is left? 2 | alarm/set_alarm 23:38:datetime Set my alarm on all days 3 | reminder/set_reminder 10:18:datetime,22:38:reminder/todo remind me tomorrow to call my mom 4 | -------------------------------------------------------------------------------- /tests/data/export_seq2seq_unit.tsv: -------------------------------------------------------------------------------- 1 | delays in tempe [IN:GET_INFO_TRAFFIC delays in [SL:LOCATION tempe ] ] {} 2 | directions the detroit zoo from my house , leaving at 8:30am . [IN:GET_DIRECTIONS directions [SL:DESTINATION the detroit zoo ] from [SL:SOURCE [IN:GET_LOCATION_HOME [SL:CONTACT my ] house ] ] , leaving [SL:DATE_TIME_DEPARTURE at 8 : 30 am ] . 
] {} 3 | why are the police directing traffic in this area [IN:UNSUPPORTED_NAVIGATION why are the police directing traffic in this area ] {} 4 | what is causing the traffic [IN:GET_INFO_TRAFFIC what is causing the traffic ] {} 5 | how long will it take me to get from home to work if i leave in 15 minutes ? [IN:GET_ESTIMATED_DURATION how long will it take me to get from [SL:SOURCE [IN:GET_LOCATION_HOME home ] ] to [SL:DESTINATION [IN:GET_LOCATION_WORK work ] ] if i leave [SL:DATE_TIME_DEPARTURE in 15 minutes ] ? ] {} 6 | get me directions to greenville without using i-85 [IN:GET_DIRECTIONS get me directions to [SL:DESTINATION greenville ] without using [SL:PATH_AVOID i - 85 ] ] {} 7 | find me the quickest route home [IN:GET_DIRECTIONS find me the quickest route [SL:DESTINATION [IN:GET_LOCATION_HOME home ] ] ] {} 8 | shortest distance between faridabad and noida [IN:GET_DISTANCE shortest distance between [SL:SOURCE faridabad ] and [SL:DESTINATION noida ] ] {} 9 | what is traffic like in paris [IN:GET_INFO_TRAFFIC what is traffic like in [SL:LOCATION paris ] ] {} 10 | what route to the jazz fest saves the most gas [IN:UNSUPPORTED_NAVIGATION what route to the jazz fest saves the most gas ] {} 11 | -------------------------------------------------------------------------------- /tests/data/fl_test.tsv: -------------------------------------------------------------------------------- 1 | 0 how are you doing? user1 2 | 1 i am good user1 3 | 1 well user2 4 | 0 what's the time like? user2 5 | 1 this is not good user3 6 | 0 what is your name? user3 7 | -------------------------------------------------------------------------------- /tests/data/knowledge_distillation_test_tiny.tsv: -------------------------------------------------------------------------------- 1 | Who R U ? 
[-0.005602254066616297, -5.430975914001465] [5.181787490844727, -5.4265875816345215] ["cu:other", "cu:ask_Location"] cu:other 2 | You in the gym [-1.8341524600982666, -0.1862216591835022] [-1.6600980758666992, 1.5862623453140259] ["cu:other", "cu:ask_Location"] cu:ask_Location 3 | look at me [-0.0028048718813806772, -5.696012496948242] [5.874996662139893, -5.692647457122803] ["cu:other", "cu:ask_Location"] cu:other 4 | Move fast and ship love [-0.0013048476539552212, -6.862490653991699] [6.641023635864258, -6.861443996429443] ["cu:other", "cu:ask_Location"] cu:other 5 | At your house? [-3.073305130004883, -0.038981884717941284] [-3.025932550430298, 3.225104808807373] ["cu:other", "cu:ask_Location"] cu:ask_Location 6 | Lmao [-2.312633478140924e-05, -10.719563484191895] [10.675718307495117, -10.719541549682617] ["cu:other", "cu:ask_Location"] cu:other 7 | Owwww Sweet [-0.038656335324048996, -3.3673951625823975] [3.2336537837982178, -3.3323073387145996] ["cu:other", "cu:ask_Location"] cu:other 8 | You home yet lol [-2.0499677658081055, -0.12795871496200562] [-1.912153959274292, 1.9913861751556396] ["cu:other", "cu:ask_Location"] cu:ask_Location 9 | Are you home? [-5.4780426025390625, -0.003714567981660366] [-5.473856449127197, 5.593633651733398] ["cu:other", "cu:ask_Location"] cu:ask_Location 10 | Where you at? 
[-5.395582675933838, -0.004174210596829653] [-5.391035556793213, 5.4767537117004395] ["cu:other", "cu:ask_Location"] cu:ask_Location 11 | -------------------------------------------------------------------------------- /tests/data/msg_topic_train.tsv: -------------------------------------------------------------------------------- 1 | cu:other ["hello","hello, how are you?","not good"] 2 | cu:discuss ["do you want to discuss something","yes","ok","what do you want to discuss","nothing"] 3 | cu:other ["wanna do something?","get drunk"] 4 | cu:discuss ["we should discuss"] 5 | -------------------------------------------------------------------------------- /tests/data/pairwise_classification.tsv: -------------------------------------------------------------------------------- 1 | true A plane is taking off. An air plane is taking off. 2 | true A man is playing a large flute. A man is playing a flute. 3 | true A man is spreading shreded cheese on a pizza. A man is spreading shredded cheese on an uncooked pizza. 4 | true Three men are playing chess. Two men are playing chess. 5 | true A man is playing the cello. A man seated is playing the cello. 6 | true Some men are fighting. Two men are fighting. 7 | false A man is smoking. A man is skating. 8 | false The man is playing the piano. The man is playing the guitar. 9 | false A man is playing on a guitar and singing. A woman is playing an acoustic guitar and singing. 10 | true A person is throwing a cat on to the ceiling. A person throws a cat on the ceiling. 
11 | -------------------------------------------------------------------------------- /tests/data/pretrained_embed_raw: -------------------------------------------------------------------------------- 1 | 10 5 2 | 0.16322 -0.36765 0.20026 -0.12489 0.08388 3 | the -0.39153 -0.19803 0.2573 -0.18617 0.25551 4 | to -0.19776 0.047666 0.28138 -0.18371 0.28204 5 | and -0.26861 0.0755 0.25665 -0.20097 0.22401 6 | a -0.099258 -0.046465 0.23546 -0.10014 0.22648 7 | I -0.039575 -0.11011 0.14779 -0.22566 0.061617 8 | you -0.017067 -0.20538 0.11632 -0.24342 0.35351 9 | is 0.17873 -0.10689 0.069695 -0.20678 0.037157 10 | aloha -0.43124 0.014934 -0.50635 0.60506 0.56051 11 | for -0.079701 0.056928 0.29815 -0.23357 0.29599 12 | -------------------------------------------------------------------------------- /tests/data/query_document_pairwise_ranking_different_users.tsv: -------------------------------------------------------------------------------- 1 | query response response user1 2 | query response1 response2 user2 3 | query2 response2 response2 user3 4 | query3 responseA responseB user4 5 | -------------------------------------------------------------------------------- /tests/data/query_document_pairwise_ranking_one_user.tsv: -------------------------------------------------------------------------------- 1 | query response response user1 2 | query response1 response2 user1 3 | query2 response2 response2 user1 4 | query3 responseA responseB user1 5 | -------------------------------------------------------------------------------- /tests/data/query_document_pairwise_ranking_tiny.tsv: -------------------------------------------------------------------------------- 1 | query response response user1 2 | query response1 response2 user2 3 | query2 response2 response2 user1 4 | query3 responseA responseB user3 5 | -------------------------------------------------------------------------------- /tests/data/roberta_sp_vocab_small: 
-------------------------------------------------------------------------------- 1 | ▁let 100 2 | ' 90 3 | s 85 4 | ▁see 75 5 | ▁if 65 6 | ▁this 60 7 | ▁works 55 8 | ▁it 50 9 | s 45 #fairseq:overwrite 10 | ▁they 40 11 | re 30 12 | ▁! 32 13 | ▁? 25 14 | ▁. 48 15 | ▁test 65 16 | ing 90 17 | ▁out 10 18 | ▁sent 20 19 | ence 20 20 | piece 5 21 | _friend 5 22 | -------------------------------------------------------------------------------- /tests/data/seq2seq_model_unit.tsv: -------------------------------------------------------------------------------- 1 | what time is it [{"tokenIdx": 0, "features": {"pos:ADJ": 1}}, {"tokenIdx": 1, "features": {"pos:NOUN": 1, "expected:TIME": 1}}, {"tokenIdx": 2, "features": {"pos:VB": 1}}, {"tokenIdx": 3, "features": {"pos:PREP": 1}}] [IN:TIME what [SLOT:TIME time] is it] 2 | directions to starbucks [{"tokenIdx": 0, "features": {"pos:NOUN": 1, "expected:LOC": 1}}, {"tokenIdx": 1, "features": {"pos:PREP": 1}}, {"tokenIdx": 2, "features": {"pos:PNOUN": 1, "SL:LOC": 1}}] [IN:DIRECTIONS direction to [SLOT:LOC starbucks]] 3 | is it raining [{"tokenIdx": 0, "features": {"pos:VB": 1}}, {"tokenIdx": 1, "features": {"pos:PREP": 1}}, {"tokenIdx": 2, "features": {"pos:ADV": 1, "expected:WEATHER": 1}}] [IN:WEATHER is it raining] [outOfDomain information ] 4 | -------------------------------------------------------------------------------- /tests/data/seq_tagging_example.tsv: -------------------------------------------------------------------------------- 1 | id1 int11 g11 0 2 | id1 int12 g12 0 3 | id2 int21 g21 0 4 | id2 int22 g22 0 5 | id3 int31 g31 0 6 | id3 int32 g32 1 7 | id3 int33 g33 1 8 | -------------------------------------------------------------------------------- /tests/data/sts_tiny.tsv: -------------------------------------------------------------------------------- 1 | 0 main-captions MSRvid 2012test 0001 none none A plane is taking off . An air plane is taking off . 
5.000 2 | 1 main-captions MSRvid 2012test 0004 none none A man is playing a large flute . A man is playing a flute . 3.800 3 | 2 main-captions MSRvid 2012test 0005 none none A man is spreading shreded cheese on a pizza . A man is spreading shredded cheese on an uncooked pizza . 3.800 4 | 3 main-captions MSRvid 2012test 0006 none none Three men are playing chess . Two men are playing chess . 2.600 5 | 4 main-captions MSRvid 2012test 0009 none none A man is playing the cello . A man seated is playing the cello . 4.250 6 | 5 main-captions MSRvid 2012test 0011 none none Some men are fighting . Two men are fighting . 4.250 7 | 6 main-captions MSRvid 2012test 0012 none none A man is smoking . A man is skating . 0.500 8 | 7 main-captions MSRvid 2012test 0013 none none The man is playing the piano . The man is playing the guitar . 1.600 9 | 8 main-captions MSRvid 2012test 0014 none none A man is playing on a guitar and singing . A woman is playing an acoustic guitar and singing . 2.200 10 | 9 main-captions MSRvid 2012test 0016 none none A person is throwing a cat on to the ceiling . A person throws a cat on the ceiling . 5.000 11 | 10 main-captions MSRvid 2012test 0017 none none The man hit the other man with a stick . The man spanked the other man with a stick . 4.200 12 | 11 main-captions MSRvid 2012test 0018 none none A woman picks up and holds a baby kangaroo . A woman picks up and holds a baby kangaroo in her arms . 4.600 13 | 12 main-captions MSRvid 2012test 0019 none none A man is playing a flute . A man is playing a bamboo flute . 3.867 14 | 13 main-captions MSRvid 2012test 0020 none none A person is folding a piece of paper . Someone is folding a piece of paper . 4.667 15 | 14 main-captions MSRvid 2012test 0021 none none A man is running on the road . A panda dog is running on the road . 1.667 16 | 15 main-captions MSRvid 2012test 0022 none none A dog is trying to get bacon off his back . A dog is trying to eat the bacon on its back . 
3.750 17 | -------------------------------------------------------------------------------- /tests/data/test_data_split_tiny.tsv: -------------------------------------------------------------------------------- 1 | I am an Arsenal fan. user1 2 | I am an Arsenal fan. user1 3 | I am an Arsenal fan. user1 4 | I am an Arsenal fan. user1 5 | What is the weather today? user2 6 | What is the weather today? user2 7 | What is the weather today? user2 8 | What is the weather today? user2 9 | I want to sleep. user3 10 | I want to sleep. user3 11 | I want to sleep. user3 12 | I want to sleep. user3 13 | Where can I get some golf lessons? user4 14 | Where can I get some golf lessons? user4 15 | Where can I get some golf lessons? user4 16 | Where can I get some golf lessons? user4 17 | What the the shortest route to get to the airport? user5 18 | What the the shortest route to get to the airport? user5 19 | What the the shortest route to get to the airport? user5 20 | What the the shortest route to get to the airport? user5 21 | When is the time for the Arsenal vs Liverpool match this week? user6 22 | When is the time for the Arsenal vs Liverpool match this week? user6 23 | When is the time for the Arsenal vs Liverpool match this week? user6 24 | When is the time for the Arsenal vs Liverpool match this week? user6 25 | Nice to meet you! user7 26 | Nice to meet you! user7 27 | Nice to meet you! user7 28 | Nice to meet you! user7 29 | Thank you very much. user8 30 | Thank you very much. user8 31 | Thank you very much. user8 32 | Thank you very much. 
user8 33 | -------------------------------------------------------------------------------- /tests/data/test_data_tiny.tsv: -------------------------------------------------------------------------------- 1 | alarm/set_alarm 11:17:datetime reactivate weekly alarm [1.0] 2 | alarm/set_alarm 23:38:datetime Set alarm to ring only on the weekdays [1.0] 3 | alarm/time_left_on_alarm When will alarm go off [1.0] 4 | reminder/set_reminder 10:18:datetime,22:38:reminder/todo remind me tomorrow to call the groomer [1.0] 5 | -------------------------------------------------------------------------------- /tests/data/test_data_tiny_csv.tsv: -------------------------------------------------------------------------------- 1 | alarm/set_alarm,"16:24:datetime,39:57:datetime",this is the text, 2 | alarm/set_alarm,,this is the text, 3 | alarm/set_alarm,12:27:datetime,this is the text, 4 | alarm/set_alarm,,"this is the text, why not" 5 | alarm/set_alarm,12:37:datetime,"this is the text, it's good", 6 | -------------------------------------------------------------------------------- /tests/data/test_data_tiny_fl.tsv: -------------------------------------------------------------------------------- 1 | alarm/set_alarm 11:17:datetime reactivate weekly alarm [1.0] 2 | alarm/set_alarm 23:38:datetime Set alarm to ring only on the weekdays [1.0] 3 | alarm/time_left_on_alarm When will alarm go off [1.0] 4 | reminder/set_reminder 10:18:datetime,22:38:reminder/todo remind me tomorrow to call the groomer [1.0] 5 | reminder/set_reminder 10:18:datetime,22:38:reminder/todo remind me tomorrow to call the groomer [1.0] 6 | -------------------------------------------------------------------------------- /tests/data/test_data_tiny_weights.tsv: -------------------------------------------------------------------------------- 1 | alarm/set_alarm 11:17:datetime reactivate weekly alarm 0.2 0.5 2 | alarm/set_alarm 23:38:datetime Set alarm to ring only on the weekdays 0.2 0.5 3 | alarm/time_left_on_alarm When 
will alarm go off 0.2 0.5 4 | reminder/set_reminder 10:18:datetime,22:38:reminder/todo remind me tomorrow to call the groomer 0.2 0.5 5 | -------------------------------------------------------------------------------- /tests/data/test_dense_features_tiny.tsv: -------------------------------------------------------------------------------- 1 | alarm/set_alarm 11:17:datetime reactivate weekly alarm [0.862891,0.8683,0,0,0.00012208146,0.9694,0,0,0,0.8488] 2 | alarm/set_alarm 23:38:datetime Set alarm to ring only on the weekdays [0.6755983,0.9942,0,0,0.99237925,0,0,0,0,0.1688] 3 | alarm/time_left_on_alarm When will alarm go off [0.7782097,0.9114,0,0,0.31280947,0.8484,0,0,0,0.9058] 4 | reminder/set_reminder 10:18:datetime,22:38:reminder/todo remind be tomorrow to call the groomer [0.74115264,0.8891,0,0,0.46979037,0.9753,0,0,0,0] 5 | -------------------------------------------------------------------------------- /tests/data/test_embed.cached: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/pytext/08754b483421884d2e363f00517ea42e449aec2c/tests/data/test_embed.cached -------------------------------------------------------------------------------- /tests/data/test_embed.raw: -------------------------------------------------------------------------------- 1 | 10 5 2 | 0.16322 -0.36765 0.20026 -0.12489 0.08388 3 | the -0.39153 -0.19803 0.2573 -0.18617 0.25551 4 | to -0.19776 0.047666 0.28138 -0.18371 0.28204 5 | and -0.26861 0.0755 0.25665 -0.20097 0.22401 6 | a -0.099258 -0.046465 0.23546 -0.10014 0.22648 7 | I -0.039575 -0.11011 0.14779 -0.22566 0.061617 8 | you -0.017067 -0.20538 0.11632 -0.24342 0.35351 9 | is 0.17873 -0.10689 0.069695 -0.20678 0.037157 10 | aloha -0.43124 0.014934 -0.50635 0.60506 0.56051 11 | for -0.079701 0.056928 0.29815 -0.23357 0.29599 12 | -------------------------------------------------------------------------------- /tests/data/test_embed_xlu.cached: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/pytext/08754b483421884d2e363f00517ea42e449aec2c/tests/data/test_embed_xlu.cached -------------------------------------------------------------------------------- /tests/data/test_lm_tiny.tsv: -------------------------------------------------------------------------------- 1 | I am an Arsenal fan. user1 2 | What is the weather today? user2 3 | I want to sleep. user3 4 | Where can I get some golf lessons? user4 5 | What the the shortest route to get to the airport? user5 6 | When is the time for the Arsenal vs Liverpool match this week? user6 7 | Nice to meet you! user7 8 | Thank you very much. user8 9 | -------------------------------------------------------------------------------- /tests/data/test_lm_tiny_broadcast_data.tsv: -------------------------------------------------------------------------------- 1 | I am an Arsenal fan. user1 2 | What is the weather today? user1 3 | I want to sleep. user1 4 | Where can I get some golf lessons? user1 5 | What the the shortest route to get to the airport? user1 6 | When is the time for the Arsenal vs Liverpool match this week? user1 7 | Nice to meet you! user1 8 | Thank you very much. user1 9 | -------------------------------------------------------------------------------- /tests/data/test_lm_tiny_fl.tsv: -------------------------------------------------------------------------------- 1 | I am an Arsenal fan. user1 2 | What is the weather today? user2 3 | I want to sleep. user3 4 | Where can I get some golf lessons? user4 5 | What the the shortest route to get to the airport? user5 6 | When is the time for the Arsenal vs Liverpool match this week? user6 7 | Nice to meet you! user7 8 | Thank you very much. user8 9 | Thank you very much. 
user9 10 | -------------------------------------------------------------------------------- /tests/data/test_personalization_opposite_inputs.tsv: -------------------------------------------------------------------------------- 1 | 0 how are you doing? [0.64840776,0.7575,0.5531,0.2403,0,0.9481,0,0.1538,0.2403,0.3564] user1 2 | 0 I am good [0.64840776,0.7575,0.5531,0.2403,0,0.9481,0,0.1538,0.2403,0.3564] user1 3 | 0 well [0.64840776,0.7575,0.5531,0.2403,0,0.9481,0,0.1538,0.2403,0.3564] user1 4 | 0 today is Sunday [0.84840776,0.9575,0.15531,0.2403,0,0.9481,1.0,0.1538,0.2403,0.3564] user1 5 | 0 Tomorrow is Monday [0.60776,0.575,0.531,0.403,0,0.9481,2.0,0.538,0.243,0.64] user1 6 | 1 how are you doing? [0.64840776,0.7575,0.5531,0.2403,0,0.9481,0,0.1538,0.2403,0.3564] user2 7 | 1 I am good [0.64840776,0.7575,0.5531,0.2403,0,0.9481,0,0.1538,0.2403,0.3564] user2 8 | 1 well [0.64840776,0.7575,0.5531,0.2403,0,0.9481,0,0.1538,0.2403,0.3564] user2 9 | 1 today is Sunday [0.84840776,0.9575,0.15531,0.2403,0,0.9481,1.0,0.1538,0.2403,0.3564] user2 10 | 1 Tomorrow is Monday [0.60776,0.575,0.531,0.403,0,0.9481,2.0,0.538,0.243,0.64] user2 11 | -------------------------------------------------------------------------------- /tests/data/test_personalization_same_inputs.tsv: -------------------------------------------------------------------------------- 1 | 0 how are you doing? [0.64840776,0.7575,0.5531,0.2403,0,0.9481,0,0.1538,0.2403,0.3564] user1 2 | 1 I am good [0.64840776,0.7575,0.5531,0.2403,0,0.9481,0,0.1538,0.2403,0.3564] user1 3 | 0 well [0.64840776,0.7575,0.5531,0.2403,0,0.9481,0,0.1538,0.2403,0.3564] user1 4 | 1 today is Sunday [0.84840776,0.9575,0.15531,0.2403,0,0.9481,1.0,0.1538,0.2403,0.3564] user1 5 | 0 Tomorrow is Monday [0.60776,0.575,0.531,0.403,0,0.9481,2.0,0.538,0.243,0.64] user1 6 | 0 how are you doing? 
[0.64840776,0.7575,0.5531,0.2403,0,0.9481,0,0.1538,0.2403,0.3564] user2 7 | 1 I am good [0.64840776,0.7575,0.5531,0.2403,0,0.9481,0,0.1538,0.2403,0.3564] user2 8 | 0 well [0.64840776,0.7575,0.5531,0.2403,0,0.9481,0,0.1538,0.2403,0.3564] user2 9 | 1 today is Sunday [0.84840776,0.9575,0.15531,0.2403,0,0.9481,1.0,0.1538,0.2403,0.3564] user2 10 | 0 Tomorrow is Monday [0.60776,0.575,0.531,0.403,0,0.9481,2.0,0.538,0.243,0.64] user2 11 | -------------------------------------------------------------------------------- /tests/data/test_personalization_single_user.tsv: -------------------------------------------------------------------------------- 1 | 0 how are you doing? [0.64840776,0.7575,0.5531,0.2403,0,0.9481,0,0.1538,0.2403,0.3564] user1 2 | 1 I am good [0.64840776,0.7575,0.5531,0.2403,0,0.9481,0,0.1538,0.2403,0.3564] user1 3 | 0 well [0.64840776,0.7575,0.5531,0.2403,0,0.9481,0,0.1538,0.2403,0.3564] user1 4 | 1 today is Sunday [0.84840776,0.9575,0.15531,0.2403,0,0.9481,1.0,0.1538,0.2403,0.3564] user1 5 | 0 Tomorrow is Monday [0.60776,0.575,0.531,0.403,0,0.9481,2.0,0.538,0.243,0.64] user1 6 | -------------------------------------------------------------------------------- /tests/data/test_tiny.en: -------------------------------------------------------------------------------- 1 | false day al@@ kal@@ ine diet plan to fight inflam@@ mation and disease en 2 | false one@@ football en 3 | false zur@@ iel angel rose @ 2@@ mos@@ .@@ old en 4 | false hannah & har@@ lynn en 5 | false canada 's cut@@ est kid con@@ test en 6 | false summer@@ time swe@@ ethe@@ art en 7 | false happy baby ✌ 🙊 en 8 | false jack anton@@ off wo@@ os lena dun@@ ham with v@@ ma sn@@ acking | 10@@ 0.3 kiss fm en 9 | -------------------------------------------------------------------------------- /tests/data/test_tsv_quoting.tsv: -------------------------------------------------------------------------------- 1 | false "the vampire di@@ aries ' : nina do@@ bre@@ v to return for series finale 2 | true the 
white house is alle@@ ge@@ dly ' ic@@ ing out ' cnn 3 | true bel@@ o@@ ved ' m * a * s * h ' cast member dies 4 | false this is the last sentence 5 | -------------------------------------------------------------------------------- /tests/data/test_utf8_errors.tsv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/pytext/08754b483421884d2e363f00517ea42e449aec2c/tests/data/test_utf8_errors.tsv -------------------------------------------------------------------------------- /tests/data/train_data_tiny.tsv: -------------------------------------------------------------------------------- 1 | alarm/modify_alarm 16:24:datetime,39:57:datetime change my alarm tomorrow to wake me up 30 minutes earlier [1.0] 2 | alarm/set_alarm Turn on all my alarms [1.0] 3 | alarm/set_alarm 12:27:datetime sound alarm every 8 minutes [1.0] 4 | alarm/set_alarm 7:17:datetime repeat yesterdays alarm [1.0] 5 | alarm/snooze_alarm continue my alarm [1.0] 6 | alarm/time_left_on_alarm Do I have anymore time on the alarm [1.0] 7 | reminder/set_reminder 10:22:datetime,28:56:reminder/todo remind me Monday night that get doe with work 12 Tuesday [1.0] 8 | reminder/show_reminders 8:15:datetime,18:27:reminder/noun display Tuesday's reminders [1.0] 9 | weather/find 12:19:weather/noun,20:28:datetime what is the weather tomorrow [1.0] 10 | weather/find 13:17:weather/attribute When will it snow [1.0] 11 | -------------------------------------------------------------------------------- /tests/data/train_data_tiny_weights.tsv: -------------------------------------------------------------------------------- 1 | alarm/modify_alarm 16:24:datetime,39:57:datetime change my alarm tomorrow to wake me up 30 minutes earlier 0.2 0.5 2 | alarm/set_alarm Turn on all my alarms 0.2 0.5 3 | alarm/set_alarm 12:27:datetime sound alarm every 8 minutes 0.2 0.5 4 | alarm/set_alarm 7:17:datetime repeat yesterdays alarm 0.2 0.5 5 | alarm/snooze_alarm 
continue my alarm 0.2 0.5 6 | alarm/time_left_on_alarm Do I have anymore time on the alarm 0.2 0.5 7 | reminder/set_reminder 10:22:datetime,28:56:reminder/todo remind me Monday night that get doe with work 12 Tuesday 0.2 0.5 8 | reminder/show_reminders 8:15:datetime,18:27:reminder/noun display Tuesday's reminders 0.2 0.5 9 | weather/find 12:19:weather/noun,20:28:datetime what is the weather tomorrow 0.2 0.5 10 | weather/find 13:17:weather/attribute When will it snow 0.2 0.5 11 | -------------------------------------------------------------------------------- /tests/data/train_dense_features_and_text_tiny.tsv: -------------------------------------------------------------------------------- 1 | alarm/modify_alarm change my alarm tomorrow to wake me up 30 minutes earlier [0.64840776,0.7575,0.5531,0.2403,0,0.9481,0,0.1538,0.2403,0.3564] 2 | alarm/set_alarm Turn on all my alarms [0.69509345,0.9371,0,0,0,0.1693,0,0,0,0.995] 3 | alarm/set_alarm sound alarm every 8 minutes [0.73880494,0.8836,0,0,0.0003586522,0.9022,0.9159,0,0,0.6729] 4 | alarm/set_alarm repeat yesterdays alarm [0.74912685,0.9579,0,0,0.64698946,0.6173,0.7484,0,0,0.9966] 5 | alarm/snooze_alarm continue my alarm [0.8172329,0.9206,0,0,0.03937195,0.8109,0.7319,0,0,0.5002] 6 | alarm/time_left_on_alarm Do I have anymore time on the alarm [0.92558855,0.922,0,0,0.004977477,0.9947,0,0,0,0] 7 | reminder/set_reminder remind me Monday night that get doe with work 12 Tuesday [0.7453431,0,0,0,0.0054596593,0,0,0,0,0] 8 | reminder/show_reminders display Tuesday's reminders [0.79554415,0.9694,0,0,0.32772592,0.1919,0.7633,0,0,0.9799] 9 | weather/find what is the weather tomorrow [0.6053743,0.9244,0,0,0.0030888896,0.6641,0,0,0,0.8453] 10 | weather/find When will it snow [0.84357846,0.9663,0,0,0.0033321558,0.9903,0,0,0,0] 11 | -------------------------------------------------------------------------------- /tests/data/train_dense_features_tiny.tsv: 
-------------------------------------------------------------------------------- 1 | alarm/modify_alarm 16:24:datetime,39:57:datetime change my alarm tomorrow to wake me up 30 minutes earlier [0.64840776,0.7575,0.5531,0.2403,0,0.9481,0,0.1538,0.2403,0.3564] 2 | alarm/set_alarm Turn on all my alarms [0.69509345,0.9371,0,0,0,0.1693,0,0,0,0.995] 3 | alarm/set_alarm 12:27:datetime sound alarm every 8 minutes [0.73880494,0.8836,0,0,0.0003586522,0.9022,0.9159,0,0,0.6729] 4 | alarm/set_alarm 7:17:datetime repeat yesterdays alarm [0.74912685,0.9579,0,0,0.64698946,0.6173,0.7484,0,0,0.9966] 5 | alarm/snooze_alarm continue my alarm [0.8172329,0.9206,0,0,0.03937195,0.8109,0.7319,0,0,0.5002] 6 | alarm/time_left_on_alarm Do I have anymore time on the alarm [0.92558855,0.922,0,0,0.004977477,0.9947,0,0,0,0] 7 | reminder/set_reminder 10:22:datetime,28:56:reminder/todo remind me Monday night that get doe with work 12 Tuesday [0.7453431,0,0,0,0.0054596593,0,0,0,0,0] 8 | reminder/show_reminders 8:15:datetime,18:27:reminder/noun display Tuesday's reminders [0.79554415,0.9694,0,0,0.32772592,0.1919,0.7633,0,0,0.9799] 9 | weather/find 12:19:weather/noun,20:28:datetime what is the weather tomorrow [0.6053743,0.9244,0,0,0.0030888896,0.6641,0,0,0,0.8453] 10 | weather/find 13:17:weather/attribute When will it snow [0.84357846,0.9663,0,0,0.0033321558,0.9903,0,0,0,0] 11 | -------------------------------------------------------------------------------- /tests/data/train_dense_features_tiny_fl.tsv: -------------------------------------------------------------------------------- 1 | alarm/modify_alarm 16:24:datetime,39:57:datetime change my alarm tomorrow to wake me up 30 minutes earlier [0.64840776,0.7575,0.5531,0.2403,0,0.9481,0,0.1538,0.2403,0.3564] 2 | alarm/set_alarm Turn on all my alarms [0.69509345,0.9371,0,0,0,0.1693,0,0,0,0.995] 3 | alarm/set_alarm 12:27:datetime sound alarm every 8 minutes [0.73880494,0.8836,0,0,0.0003586522,0.9022,0.9159,0,0,0.6729] 4 | alarm/set_alarm 7:17:datetime 
repeat yesterdays alarm [0.74912685,0.9579,0,0,0.64698946,0.6173,0.7484,0,0,0.9966] 5 | alarm/snooze_alarm continue my alarm [0.8172329,0.9206,0,0,0.03937195,0.8109,0.7319,0,0,0.5002] 6 | alarm/time_left_on_alarm Do I have anymore time on the alarm [0.92558855,0.922,0,0,0.004977477,0.9947,0,0,0,0] 7 | reminder/set_reminder 10:22:datetime,28:56:reminder/todo remind me Monday night that get doe with work 12 Tuesday [0.7453431,0,0,0,0.0054596593,0,0,0,0,0] 8 | reminder/show_reminders 8:15:datetime,18:27:reminder/noun display Tuesday's reminders [0.79554415,0.9694,0,0,0.32772592,0.1919,0.7633,0,0,0.9799] 9 | weather/find 12:19:weather/noun,20:28:datetime what is the weather tomorrow [0.6053743,0.9244,0,0,0.0030888896,0.6641,0,0,0,0.8453] 10 | weather/find 13:17:weather/attribute When will it snow [0.84357846,0.9663,0,0,0.0033321558,0.9903,0,0,0,0] 11 | weather/find 13:17:weather/attribute When will it snow [0.84357846,0.9663,0,0,0.0033321558,0.9903,0,0,0,0] 12 | weather/find 13:17:weather/attribute When will it snow [0.84357846,0.9663,0,0,0.0033321558,0.9903,0,0,0,0] 13 | -------------------------------------------------------------------------------- /tests/data/train_dict_features.tsv: -------------------------------------------------------------------------------- 1 | Order coffee from Starbucks please [{"tokenIdx": 1, "features": {"drink/beverage": 0.8, "music/song": 0.2}},{"tokenIdx": 3, "features": {"store/coffee_shop": 1.0}}] 2 | Order some fries from McDonalds please [{"tokenIdx": 2, "features": {"food": 1.0}},{"tokenIdx": 4, "features": {"store/fast_food": 1.0}}] 3 | -------------------------------------------------------------------------------- /tests/data/train_dict_features_bad_json.tsv: -------------------------------------------------------------------------------- 1 | Order coffee from Starbucks [{tokenIdx: 1, features: {"drink/beverage": 0.8, "music/song": 0.2 2 | -------------------------------------------------------------------------------- 
/tests/data/train_seq_features.tsv: -------------------------------------------------------------------------------- 1 | ["where do you wanna meet?", "MPK"] 2 | 3 | -------------------------------------------------------------------------------- /tests/data/train_tiny_with_lang.tsv: -------------------------------------------------------------------------------- 1 | alarm/modify_alarm 16:24:datetime,39:57:datetime change my alarm tomorrow to wake me up 30 minutes earlier 1-1:1|2-2:1 en 2 | alarm/set_alarm Turn on all my alarms 1-1:1|2-2:1 en 3 | alarm/set_alarm 12:27:datetime sound alarm every 8 minutes 1-1:1|2-2:1 en 4 | alarm/set_alarm 7:17:datetime repeat yesterdays alarm 1-1:1|2-2:1 en 5 | alarm/snooze_alarm continue my alarm 1-1:1|2-2:1 en 6 | alarm/time_left_on_alarm Do I have anymore time on the alarm 1-1:1|2-2:1 en 7 | reminder/show_reminders 8:15:datetime,18:27:reminder/noun display Tuesday's reminders 1-1:1|2-2:1 en 8 | weather/find 12:19:weather/noun,20:28:datetime what is the weather tomorrow 1-1:1|2-2:1 en 9 | weather/find 13:17:weather/attribute When will it snow 1-1:1|2-2:1 en 10 | alarm/set_alarm 21:35:datetime Configurar mi alarma por 60 minutos 1-1:1|2-2:1 es 11 | -------------------------------------------------------------------------------- /tests/data/xlm_vocab_small: -------------------------------------------------------------------------------- 1 | , 146805425 2 | ._EOW 128661007 3 | de 81673059 4 | " 68075167 5 | la 46720979 6 | en 35121205 7 | der 29483420 8 | ) 29473053 9 | ( 29419894 10 | a_EOW 27957264 11 | die 22438757 12 | und 20327710 13 | el 19003684 14 | в 18754006 15 | in 18499530 16 | le 18363616 17 | des 17520601 18 | y 15797922 19 | et 15523745 20 | , 15190970 21 | -------------------------------------------------------------------------------- /tests/model_utils_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. 
and its affiliates. All Rights Reserved 3 | from __future__ import absolute_import, division, print_function, unicode_literals 4 | 5 | import unittest 6 | 7 | import numpy as np 8 | import torch 9 | from pytext.utils.cuda import Variable 10 | from pytext.utils.model import to_onehot 11 | 12 | 13 | class ModelUtilsTest(unittest.TestCase): 14 | def test_to_onehot(self): 15 | feat_vec = Variable(torch.LongTensor([[0, 1, 2], [3, 4, 0]])) 16 | onehot = to_onehot(feat_vec, 5) 17 | self.assertEqual(onehot.size()[0], 2) 18 | self.assertEqual(onehot.size()[1], 3) 19 | self.assertEqual(onehot.size()[2], 5) 20 | 21 | expected = np.zeros((2, 3, 5)) 22 | expected[0][0][0] = 1 23 | expected[0][1][1] = 1 24 | expected[0][2][2] = 1 25 | expected[1][0][3] = 1 26 | expected[1][1][4] = 1 27 | expected[1][2][0] = 1 28 | 29 | for (i, row) in enumerate(onehot): 30 | for (j, feat) in enumerate(row): 31 | for (k, val) in enumerate(feat): 32 | self.assertEqual(expected[i][j][k], val.item()) 33 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | import tempfile 4 | 5 | from pytext.utils import test 6 | from pytext.utils.config_utils import MockConfigLoader 7 | 8 | tests_module = test.import_tests_module() 9 | 10 | 11 | def find_and_patch_config( 12 | config_filename, config_base_path, output_path_prefix="pytext_demo_" 13 | ): 14 | output_base_path = tempfile.mkdtemp(prefix=output_path_prefix) 15 | mock_config_loader = MockConfigLoader( 16 | config_base_path=config_base_path, 17 | replace_paths={"/tmp": output_base_path}, 18 | ) 19 | config_dict = mock_config_loader.make_config(config_filename) 20 | return config_dict 21 | --------------------------------------------------------------------------------