├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── __pycache__ ├── __init__.cpython-36.pyc ├── ner_output.cpython-36.pyc └── style_utils.cpython-36.pyc ├── assets ├── assertion_viz.png ├── dp_viz.png ├── er_viz.png ├── ner_viz.png └── re_viz.png ├── build └── lib │ └── sparknlp_display │ ├── VERSION │ ├── __init__.py │ ├── assertion.py │ ├── dep_updates.py │ ├── dependency_parser.py │ ├── entity_resolution.py │ ├── fonts │ └── Lucida_Console.ttf │ ├── label_colors │ ├── ner.json │ └── relations.json │ ├── ner.py │ ├── re_updates.py │ ├── relation_extraction.py │ ├── retemp.py │ ├── style.css │ └── style_utils.py ├── dist ├── spark-nlp-display-5.0.tar.gz └── spark_nlp_display-5.0-py3-none-any.whl ├── setup.cfg ├── setup.py ├── spark_nlp_display.egg-info ├── PKG-INFO ├── SOURCES.txt ├── dependency_links.txt ├── requires.txt └── top_level.txt ├── sparknlp_display ├── VERSION ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── ner_output.cpython-36.pyc │ └── style_utils.cpython-36.pyc ├── assertion.py ├── dep_updates.py ├── dependency_parser.py ├── entity_resolution.py ├── fonts │ └── Lucida_Console.ttf ├── label_colors │ ├── ner.json │ └── relations.json ├── ner.py ├── re_updates.py ├── relation_extraction.py ├── retemp.py ├── style.css └── style_utils.py └── tutorials ├── .ipynb_checkpoints └── Spark_NLP_Display-checkpoint.ipynb └── Spark_NLP_Display.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | __pycache__ 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright (c) 2020 John Snow Labs 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include sparknlp_display/* 2 | include sparknlp_display/fonts/* 3 | include sparknlp_display/style.css 4 | include sparknlp_display/style_utils.py 5 | include sparknlp_display/label_colors/* 6 | include sparknlp_display/VERSION -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # spark-nlp-display 2 | A library for the simple visualization of different types of Spark NLP annotations. 
3 | 4 | ## Supported Visualizations: 5 | - Dependency Parser 6 | - Named Entity Recognition 7 | - Entity Resolution 8 | - Relation Extraction 9 | - Assertion Status 10 | 11 | ## Complete Tutorial 12 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-display/blob/main/tutorials/Spark_NLP_Display.ipynb) 13 | 14 | https://github.com/JohnSnowLabs/spark-nlp-display/blob/main/tutorials/Spark_NLP_Display.ipynb 15 | 16 | ### Requirements 17 | - spark-nlp 18 | - ipython 19 | - svgwrite 20 | - pandas 21 | - numpy 22 | 23 | ### Installation 24 | ```bash 25 | pip install spark-nlp-display 26 | ``` 27 | 28 | ### How to use 29 | 30 | ### Databricks 31 | #### For all modules, pass in the additional parameter "return_html=True" in the display function and use Databrick's function displayHTML() to render visualization as explained below: 32 | ```python 33 | from sparknlp_display import NerVisualizer 34 | 35 | ner_vis = NerVisualizer() 36 | 37 | ## To set custom label colors: 38 | ner_vis.set_label_colors({'LOC':'#800080', 'PER':'#77b5fe'}) #set label colors by specifying hex codes 39 | 40 | pipeline_result = ner_light_pipeline.fullAnnotate(text) ##light pipeline 41 | #pipeline_result = ner_full_pipeline.transform(df).collect()##full pipeline 42 | 43 | vis_html = ner_vis.display(pipeline_result[0], #should be the results of a single example, not the complete dataframe 44 | label_col='entities', #specify the entity column 45 | document_col='document', #specify the document column (default: 'document') 46 | labels=['PER'], #only allow these labels to be displayed. 
(default: [] - all labels will be displayed) 47 | return_html=True) 48 | 49 | displayHTML(vis_html) 50 | ``` 51 | ![title](https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-display/main/assets/ner_viz.png) 52 | 53 | ### Jupyter 54 | 55 | To save the visualization as html, provide the export file path: `save_path='./export.html'` for each visualizer. 56 | 57 | 58 | #### Dependency Parser 59 | ```python 60 | from sparknlp_display import DependencyParserVisualizer 61 | 62 | dependency_vis = DependencyParserVisualizer() 63 | 64 | pipeline_result = dp_pipeline.fullAnnotate(text) 65 | #pipeline_result = dp_full_pipeline.transform(df).collect()##full pipeline 66 | 67 | dependency_vis.display(pipeline_result[0], #should be the results of a single example, not the complete dataframe. 68 | pos_col = 'pos', #specify the pos column 69 | dependency_col = 'dependency', #specify the dependency column 70 | dependency_type_col = 'dependency_type', #specify the dependency type column 71 | save_path='./export.html' # optional - to save viz as html. (default: None) 72 | ) 73 | ``` 74 | 75 | ![title](https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-display/main/assets/dp_viz.png) 76 | 77 | #### Named Entity Recognition 78 | 79 | ```python 80 | from sparknlp_display import NerVisualizer 81 | 82 | ner_vis = NerVisualizer() 83 | 84 | pipeline_result = ner_light_pipeline.fullAnnotate(text) 85 | #pipeline_result = ner_full_pipeline.transform(df).collect()##full pipeline 86 | 87 | ner_vis.display(pipeline_result[0], #should be the results of a single example, not the complete dataframe 88 | label_col='entities', #specify the entity column 89 | document_col='document', #specify the document column (default: 'document') 90 | labels=['PER'], #only allow these labels to be displayed. (default: [] - all labels will be displayed) 91 | save_path='./export.html' # optional - to save viz as html. 
(default: None) 92 | ) 93 | 94 | ## To set custom label colors: 95 | ner_vis.set_label_colors({'LOC':'#800080', 'PER':'#77b5fe'}) #set label colors by specifying hex codes 96 | 97 | ``` 98 | 99 | ![title](https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-display/main/assets/ner_viz.png) 100 | 101 | #### Entity Resolution 102 | 103 | ```python 104 | from sparknlp_display import EntityResolverVisualizer 105 | 106 | er_vis = EntityResolverVisualizer() 107 | 108 | pipeline_result = er_light_pipeline.fullAnnotate(text) 109 | 110 | er_vis.display(pipeline_result[0], #should be the results of a single example, not the complete dataframe 111 | label_col='entities', #specify the ner result column 112 | resolution_col = 'resolution', 113 | document_col='document', #specify the document column (default: 'document') 114 | save_path='./export.html' # optional - to save viz as html. (default: None) 115 | ) 116 | 117 | ## To set custom label colors: 118 | er_vis.set_label_colors({'TREATMENT':'#800080', 'PROBLEM':'#77b5fe'}) #set label colors by specifying hex codes 119 | 120 | ``` 121 | 122 | ![title](https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-display/main/assets/er_viz.png) 123 | 124 | #### Relation Extraction 125 | ```python 126 | from sparknlp_display import RelationExtractionVisualizer 127 | 128 | re_vis = RelationExtractionVisualizer() 129 | 130 | pipeline_result = re_light_pipeline.fullAnnotate(text) 131 | 132 | re_vis.display(pipeline_result[0], #should be the results of a single example, not the complete dataframe 133 | relation_col = 'relations', #specify relations column 134 | document_col = 'document', #specify document column 135 | show_relations=True, #display relation names on arrows (default: True) 136 | save_path='./export.html' # optional - to save viz as html. 
(default: None) 137 | ) 138 | 139 | ``` 140 | 141 | ![title](https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-display/main/assets/re_viz.png) 142 | 143 | #### Assertion Status 144 | ```python 145 | from sparknlp_display import AssertionVisualizer 146 | 147 | assertion_vis = AssertionVisualizer() 148 | 149 | pipeline_result = ner_assertion_light_pipeline.fullAnnotate(text) 150 | 151 | assertion_vis.display(pipeline_result[0], 152 | label_col = 'entities', #specify the ner result column 153 | assertion_col = 'assertion', #specify assertion column 154 | document_col = 'document', #specify the document column (default: 'document') 155 | save_path='./export.html' # optional - to save viz as html. (default: None) 156 | ) 157 | 158 | ## To set custom label colors: 159 | assertion_vis.set_label_colors({'TREATMENT':'#008080', 'problem':'#800080'}) #set label colors by specifying hex codes 160 | 161 | ``` 162 | 163 | ![title](https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-display/main/assets/assertion_viz.png) 164 | -------------------------------------------------------------------------------- /__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-display/a6c5a6ac166c5f75ec4ccf64ffafb2fa6e1164a1/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/ner_output.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-display/a6c5a6ac166c5f75ec4ccf64ffafb2fa6e1164a1/__pycache__/ner_output.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/style_utils.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-display/a6c5a6ac166c5f75ec4ccf64ffafb2fa6e1164a1/__pycache__/style_utils.cpython-36.pyc -------------------------------------------------------------------------------- /assets/assertion_viz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-display/a6c5a6ac166c5f75ec4ccf64ffafb2fa6e1164a1/assets/assertion_viz.png -------------------------------------------------------------------------------- /assets/dp_viz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-display/a6c5a6ac166c5f75ec4ccf64ffafb2fa6e1164a1/assets/dp_viz.png -------------------------------------------------------------------------------- /assets/er_viz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-display/a6c5a6ac166c5f75ec4ccf64ffafb2fa6e1164a1/assets/er_viz.png -------------------------------------------------------------------------------- /assets/ner_viz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-display/a6c5a6ac166c5f75ec4ccf64ffafb2fa6e1164a1/assets/ner_viz.png -------------------------------------------------------------------------------- /assets/re_viz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-display/a6c5a6ac166c5f75ec4ccf64ffafb2fa6e1164a1/assets/re_viz.png -------------------------------------------------------------------------------- /build/lib/sparknlp_display/VERSION: -------------------------------------------------------------------------------- 1 | 5.0 -------------------------------------------------------------------------------- 
/build/lib/sparknlp_display/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from sparknlp_display.ner import NerVisualizer 3 | from sparknlp_display.dependency_parser import DependencyParserVisualizer 4 | from sparknlp_display.relation_extraction import RelationExtractionVisualizer 5 | from sparknlp_display.entity_resolution import EntityResolverVisualizer 6 | from sparknlp_display.assertion import AssertionVisualizer 7 | 8 | here = os.path.abspath(os.path.dirname(__file__)) 9 | 10 | def get_version(): 11 | version_path = os.path.abspath(os.path.dirname(__file__)) 12 | with open(os.path.join(here, "VERSION"), "r") as fh: 13 | app_version = fh.read().strip() 14 | return app_version 15 | 16 | __version__ = get_version() 17 | 18 | def version(): 19 | return get_version() -------------------------------------------------------------------------------- /build/lib/sparknlp_display/assertion.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import json 4 | import numpy as np 5 | from . import style_utils as style_config 6 | from IPython.display import display, HTML 7 | 8 | here = os.path.abspath(os.path.dirname(__file__)) 9 | 10 | class AssertionVisualizer: 11 | def __init__(self): 12 | with open(os.path.join(here, 'label_colors/ner.json'), 'r', encoding='utf-8') as f_: 13 | self.label_colors = json.load(f_) 14 | 15 | #public function to get color for a label 16 | def get_label_color(self, label): 17 | """Returns color of a particular label 18 | 19 | Input: entity label 20 | Output: Color or if not found 21 | """ 22 | 23 | if str(label).lower() in self.label_colors: 24 | return self.label_colors[label.lower()] 25 | else: 26 | return None 27 | 28 | # private function for colors for display 29 | def __get_label(self, label): 30 | """Set label colors. 
31 | 32 | Input: dictionary of entity labels and corresponding colors 33 | Output: self object - to allow chaining 34 | Note: Previous values of colors will be overwritten 35 | """ 36 | if str(label).lower() in self.label_colors: 37 | return self.label_colors[label.lower()] 38 | else: 39 | #update it to fetch from git new labels 40 | r = lambda: random.randint(0,200) 41 | return '#%02X%02X%02X' % (r(), r(), r()) 42 | 43 | def set_label_colors(self, color_dict): 44 | """Sets label colors. 45 | 46 | input: dictionary of entity labels and corresponding colors 47 | output: self object - to allow chaining 48 | note: Previous values of colors will be overwritten 49 | """ 50 | 51 | for key, value in color_dict.items(): 52 | self.label_colors[key.lower()] = value 53 | return self 54 | 55 | def __verify_structure(self, result, label_col, document_col, original_text): 56 | 57 | if original_text is None: 58 | basic_msg_1 = """Incorrect annotation structure of '{}'.""".format(document_col) 59 | if not hasattr(result[document_col][0], 'result'): 60 | raise AttributeError(basic_msg_1+""" 'result' attribute not found in the annotation. 61 | Make sure '"""+document_col+"""' is a list of objects having the following structure: 62 | Annotation(type='annotation', begin=0, end=10, result='This is a text') 63 | Or You can pass the text manually using 'raw_text' argument.""") 64 | 65 | basic_msg_1 = """Incorrect annotation structure of '{}'.""".format(label_col) 66 | basic_msg = """ 67 | In sparknlp please use 'LightPipeline.fullAnnotate' for LightPipeline or 'Pipeline.transform' for PipelineModel. 
68 | Or 69 | Make sure '"""+label_col+"""' is a list of objects having the following structure: 70 | Annotation(type='annotation', begin=0, end=0, result='Adam', metadata={'entity': 'PERSON'})""" 71 | 72 | for entity in result[label_col]: 73 | if not hasattr(entity, 'begin'): 74 | raise AttributeError( basic_msg_1 + """ 'begin' attribute not found in the annotation."""+basic_msg) 75 | if not hasattr(entity, 'end'): 76 | raise AttributeError(basic_msg_1 + """ 'end' attribute not found in the annotation."""+basic_msg) 77 | if not hasattr(entity, 'result'): 78 | raise AttributeError(basic_msg_1 + """ 'result' attribute not found in the annotation."""+basic_msg) 79 | if not hasattr(entity, 'metadata'): 80 | raise AttributeError(basic_msg_1 + """ 'metadata' attribute not found in the annotation."""+basic_msg) 81 | if 'entity' not in entity.metadata: 82 | raise AttributeError(basic_msg_1+""" KeyError: 'entity' not found in metadata."""+basic_msg) 83 | 84 | def __verify_input(self, result, label_col, document_col, original_text): 85 | # check if label colum in result 86 | if label_col not in result: 87 | raise AttributeError("""column/key '{}' not found in the provided dataframe/dictionary. 88 | Please specify the correct key/column using 'label_col' argument.""".format(label_col)) 89 | 90 | if original_text is not None: 91 | # check if provided text is correct data type 92 | if not isinstance(original_text, str): 93 | raise ValueError("Invalid value for argument 'raw_text' input. Text should be of type 'str'.") 94 | 95 | else: 96 | # check if document column in result 97 | if document_col not in result: 98 | raise AttributeError("""column/key '{}' not found in the provided dataframe/dictionary. 99 | Please specify the correct key/column using 'document_col' argument. 
100 | Or You can pass the text manually using 'raw_text' argument""".format(document_col)) 101 | 102 | self.__verify_structure( result, label_col, document_col, original_text) 103 | 104 | # main display function 105 | def __display_ner(self, result, label_col, resolution_col, document_col, original_text, labels_list = None): 106 | 107 | if original_text is None: 108 | original_text = result[document_col][0].result 109 | 110 | if labels_list is not None: 111 | labels_list = [v.lower() for v in labels_list] 112 | 113 | assertion_temp_dict = {} 114 | for resol in result[resolution_col]: 115 | assertion_temp_dict[int(resol.begin)] = resol.result 116 | 117 | label_color = {} 118 | html_output = "" 119 | pos = 0 120 | 121 | sorted_labs_idx = np.argsort([ int(i.begin) for i in result[label_col]]) 122 | sorted_labs = [ result[label_col][i] for i in sorted_labs_idx] 123 | 124 | for entity in sorted_labs: 125 | entity_type = entity.metadata['entity'].lower() 126 | if (entity_type not in label_color) and ((labels_list is None) or (entity_type in labels_list)) : 127 | label_color[entity_type] = self.__get_label(entity_type) 128 | 129 | begin = int(entity.begin) 130 | end = int(entity.end) 131 | if pos < begin and pos < len(original_text): 132 | white_text = original_text[pos:begin] 133 | html_output += '{}'.format(white_text) 134 | pos = end+1 135 | 136 | if entity_type in label_color: 137 | 138 | if begin in assertion_temp_dict: 139 | 140 | html_output += '{} {}{} '.format( 141 | label_color[entity_type] + 'B3', #color 142 | original_text[begin:end+1],#entity.result, 143 | entity.metadata['entity'], #entity - label 144 | label_color[entity_type] + 'FF', #color '#D2C8C6' 145 | assertion_temp_dict[begin] # res_assertion 146 | ) 147 | else: 148 | html_output += '{} {}'.format( 149 | label_color[entity_type] + 'B3', #color 150 | original_text[begin:end+1],#entity.result, 151 | entity.metadata['entity'] #entity - label 152 | ) 153 | 154 | else: 155 | html_output += 
'{}'.format(original_text[begin:end+1]) 156 | 157 | if pos < len(original_text): 158 | html_output += '{}'.format(original_text[pos:]) 159 | 160 | html_output += """""" 161 | 162 | html_output = html_output.replace("\n", "
") 163 | 164 | return html_output 165 | 166 | def display(self, result, label_col, assertion_col, document_col='document', raw_text=None, return_html=False, save_path=None): 167 | """Displays Assertion visualization. 168 | 169 | Inputs: 170 | result -- A Dataframe or dictionary. 171 | label_col -- Name of the column/key containing NER annotations. 172 | document_col -- Name of the column/key containing text document. 173 | original_text -- Original text of type 'str'. If specified, it will take precedence over 'document_col' and will be used as the reference text for display. 174 | labels_list -- A list of labels that should be highlighted in the output. Default: Display all labels. 175 | 176 | Output: Visualization 177 | """ 178 | 179 | #self.__verifyInput(result, label_col, document_col, raw_text) 180 | 181 | html_content = self.__display_ner(result, label_col, assertion_col, document_col, raw_text) 182 | html_content_save = style_config.STYLE_CONFIG_ENTITIES+ " "+html_content 183 | 184 | if save_path != None: 185 | with open(save_path, 'w') as f_: 186 | f_.write(html_content_save) 187 | 188 | if return_html: 189 | return html_content_save 190 | else: 191 | return display(HTML(html_content_save)) 192 | -------------------------------------------------------------------------------- /build/lib/sparknlp_display/dep_updates.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import json 4 | import pandas as pd 5 | import numpy as np 6 | import svgwrite 7 | from IPython.display import display, HTML 8 | 9 | here = os.path.abspath(os.path.dirname(__file__)) 10 | 11 | class DependencyParserVisualizer: 12 | 13 | def __get_color(self, l): 14 | r = lambda: random.randint(100,255) 15 | return '#%02X%02X%02X' % (r(), r(), r()) 16 | 17 | def __size(self, text): 18 | return ((len(text)+1)*15)-5 19 | 20 | def __draw_line(self, dwg, s_x , s_y, e_x, e_y, d_type, color): 21 | rx=ry=1 22 | 
dwg.add(dwg.path(d=f"M{min(s_x,e_x)},{s_y} A {rx}, {ry}, 0 1 1 {max(s_x,e_x)}, {s_y}", 23 | stroke=color, stroke_width = "3", fill='none')) 24 | dwg.add(dwg.polyline( 25 | [(e_x, s_y), (e_x+3, s_y), 26 | (e_x, s_y+5), 27 | (e_x-3, s_y), 28 | (e_x, s_y) 29 | ], stroke=color, stroke_width = "4", fill='none',)) 30 | ''' 31 | line = dwg.add(dwg.polyline( 32 | [(s_x, s_y+4), 33 | (s_x, e_y), 34 | (e_x, e_y), 35 | (e_x, s_y), 36 | (e_x+2, s_y), 37 | (e_x, s_y+4), 38 | (e_x-2, s_y), 39 | (e_x, s_y) 40 | ], 41 | stroke=color, stroke_width = "2", fill='none',)) 42 | ''' 43 | dwg.add(dwg.text(d_type, insert=(((s_x+e_x)/2)-(self.__size(d_type.strip())/2.75), e_y-4), 44 | fill=color, font_size='20', font_family='courier')) 45 | 46 | def __generate_graph(self, result_df): 47 | # returns an svg graph 48 | 49 | colors_dict = {} 50 | max_x = 50 51 | max_y = 100 52 | 53 | for i in result_df['dependency_type'].unique(): 54 | colors_dict[i] = self.__get_color(i) 55 | 56 | for i in result_df['pos'].unique(): 57 | colors_dict[i] = self.__get_color(i) 58 | 59 | for i, row in result_df.iterrows(): 60 | txt = row['chunk'].strip() 61 | max_x += (self.__size(txt) + 50) 62 | max_y += 30 63 | 64 | max_x += 50 65 | start_x = 50 66 | starty_y = max_y 67 | dp_dict={} 68 | tk_dict = {} 69 | dist_dict = {} 70 | main_text = [] 71 | main_pos = [] 72 | 73 | for i, row in result_df.iterrows(): 74 | txt = row['chunk'].strip() 75 | dt = row['dependency_type'].lower().strip() 76 | is_root = False 77 | if dt == 'root': 78 | is_root = True 79 | main_text.append((txt, start_x, starty_y, is_root)) 80 | main_pos.append((row['pos'].strip(), (start_x + int((self.__size(txt)/2) - int(self.__size(row['pos'])/2))), starty_y+30)) 81 | 82 | tk_dict[str(row['begin'])+str(row['end'])] = (start_x+int(self.__size(txt)/2), starty_y) 83 | start_x += (self.__size(txt) + 50) 84 | 85 | y_offset = starty_y-100 86 | dist_dict = {} 87 | e_dist_dict = {} 88 | direct_dict = {} 89 | left_side_dict = {} 90 | right_side_dict = {} 91 
| y_hist = {} 92 | root_list = [] 93 | main_lines = [] 94 | lines_dist = [] 95 | 96 | dist = [] 97 | for i, row in result_df.iterrows(): 98 | if row['dependency_type'].lower().strip() != 'root': 99 | lines_dist.append(abs(int(row['begin']) - int(row['dependency_start']['head.begin']))) 100 | else: 101 | lines_dist.append(0) 102 | 103 | result_df = result_df.iloc[np.argsort(lines_dist)] 104 | 105 | count_left = {} 106 | count_right = {} 107 | t_x_offset = {} 108 | for i, row in result_df.iterrows(): 109 | if row['dependency_type'].lower().strip() != 'root': 110 | sp = str(row['dependency_start']['head.begin'])+str(row['dependency_start']['head.end']) 111 | x_e, y_e = tk_dict[str(row['begin'])+str(row['end'])] 112 | x, y = tk_dict[sp] 113 | if int(row['begin']) < int(row['dependency_start']['head.begin']): 114 | if x in count_left: 115 | count_left[x] += 1 116 | t_x_offset[x] += 7 117 | else: 118 | count_left[x] = 1 119 | t_x_offset[x] = 7 120 | if x_e in count_right: 121 | count_right[x_e] += 1 122 | t_x_offset[x_e] -= 7 123 | else: 124 | count_right[x_e] = 0 125 | t_x_offset[x_e] = 0 126 | else: 127 | if x in count_right: 128 | count_right[x] += 1 129 | t_x_offset[x] -= 7 130 | else: 131 | count_right[x] = 0 132 | t_x_offset[x] = 0 133 | if x_e in count_left: 134 | count_left[x_e] += 1 135 | t_x_offset[x_e] += 7 136 | else: 137 | count_left[x_e] = 1 138 | t_x_offset[x_e] = 7 139 | 140 | for i, row in result_df.iterrows(): 141 | 142 | sp = str(row['dependency_start']['head.begin'])+str(row['dependency_start']['head.end']) 143 | ep = tk_dict[str(row['begin'])+str(row['end'])] 144 | 145 | if sp != '-1-1': 146 | x, y = tk_dict[sp] 147 | 148 | if int(row['begin']) > int(row['dependency_start']['head.begin']): 149 | dist_dict[x] = count_right[x] * 7 150 | count_right[x] -= 1 151 | e_dist_dict[ep[0]] = count_left[ep[0]] * -7 152 | count_left[ep[0]] -= 1 153 | else: 154 | dist_dict[x] = count_left[x] * -7 155 | count_left[x] -= 1 156 | e_dist_dict[ep[0]] = 
count_right[ep[0]] * 7 157 | count_right[ep[0]] -= 1 158 | #row['dependency'], x, t_x_offset[x], x+dist_dict[x], x+dist_dict[x]+t_x_offset[x] 159 | final_x_s = int(x+dist_dict[x]+(t_x_offset[x]/2)) 160 | final_x_e = int(ep[0]+ e_dist_dict[ep[0]]+(t_x_offset[ep[0]]/2)) 161 | 162 | x_inds = range(min(final_x_s, final_x_e), max(final_x_s, final_x_e)+1) 163 | common = set(y_hist.keys()).intersection(set(x_inds)) 164 | 165 | if common: 166 | y_fset = min([y_hist[c] for c in common]) 167 | y_fset -= 50 168 | y_hist.update(dict(zip(x_inds, [y_fset]*len(x_inds)))) 169 | 170 | else: 171 | y_hist.update(dict(zip(x_inds, [y_offset]*len(x_inds)))) 172 | 173 | main_lines.append((None, final_x_s, y-30, final_x_e, y_hist[final_x_s], row['dependency_type'])) 174 | 175 | else: 176 | x_x , y_y = tk_dict[str(row['begin'])+str(row['end'])] 177 | 178 | root_list.append((row['dependency_type'].upper(), x_x, y_y)) 179 | 180 | 181 | current_y = min(y_hist.values()) 182 | 183 | y_ff = (max_y - current_y) + 50 184 | y_f = (current_y - 50) 185 | current_y = 50 186 | 187 | dwg = svgwrite.Drawing("temp.svg", 188 | profile='tiny', size = (max_x, y_ff+100)) 189 | 190 | for mt, mp in zip(main_text, main_pos): 191 | dwg.add(dwg.text(mt[0], insert=(mt[1], mt[2]-y_f), fill='gray', 192 | font_size='25', font_family='courier')) 193 | 194 | if mt[3]: 195 | dwg.add(dwg.rect(insert=(mt[1]-5, mt[2]-y_f-25), size=(self.__size(mt[0]),35), stroke='orange', 196 | stroke_width='2', fill='none')) 197 | 198 | dwg.add(dwg.text(mp[0], insert=(mp[1], mp[2]-y_f), fill=colors_dict[mp[0]])) 199 | 200 | for ml in main_lines: 201 | self.__draw_line(dwg, ml[1], ml[2]-y_f, ml[3], ml[4]-y_f, ml[5], colors_dict[ml[5]]) 202 | 203 | return dwg.tostring() 204 | 205 | 206 | def display(self, res, pos_col, dependency_col, dependency_type_col): 207 | """Displays NER visualization. 208 | 209 | Inputs: 210 | result -- A Dataframe or dictionary. 211 | label_col -- Name of the column/key containing NER annotations. 
212 | document_col -- Name of the column/key containing text document. 213 | original_text -- Original text of type 'str'. If specified, it will take precedence over 'document_col' and will be used as the reference text for display. 214 | labels_list -- A list of labels that should be highlighted in the output. Default: Display all labels. 215 | 216 | Output: Visualization 217 | """ 218 | 219 | pos_res = [] 220 | for i in res[pos_col]: 221 | t_ = {'chunk': i.metadata['word'], 222 | 'begin': i.begin, 223 | 'end' : i.end, 224 | 'pos' : i.result} 225 | pos_res.append(t_) 226 | dep_res = [] 227 | dep_res_meta = [] 228 | for i in res[dependency_col]: 229 | dep_res.append(i.result) 230 | dep_res_meta.append(i.metadata) 231 | df = pd.DataFrame(pos_res) 232 | df['dependency'] = dep_res 233 | df['dependency_start'] = dep_res_meta 234 | 235 | dept_res = [] 236 | for i in res[dependency_type_col]: 237 | dept_res.append(i.result) 238 | df['dependency_type'] = dept_res 239 | 240 | return display(HTML(self.__generate_graph(df))) 241 | 242 | -------------------------------------------------------------------------------- /build/lib/sparknlp_display/dependency_parser.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import json 4 | import pandas as pd 5 | import numpy as np 6 | import svgwrite 7 | from . 
import style_utils as style_config 8 | from IPython.display import display, HTML 9 | 10 | here = os.path.abspath(os.path.dirname(__file__)) 11 | 12 | class DependencyParserVisualizer: 13 | 14 | def __init__(self): 15 | self.font_path = os.path.join(here, 'fonts/Lucida_Console.ttf') 16 | self.main_font = 'Lucida' 17 | 18 | def __get_color(self, l): 19 | r = lambda: random.randint(0,200) 20 | return '#%02X%02X%02X' % (r(), r(), r()) 21 | 22 | def __size(self, text): 23 | return ((len(text)+1)*12) 24 | 25 | def __draw_line(self, dwg, s_x , s_y, e_x, e_y, d_type, color): 26 | line = dwg.add(dwg.polyline( 27 | [ 28 | (e_x, s_y), 29 | (e_x+2, s_y), 30 | (e_x, s_y+4), 31 | (e_x-2, s_y), 32 | (e_x, s_y) 33 | ], 34 | stroke='black', stroke_width = "2", fill='none',)) 35 | 36 | #if e_x > s_x: 37 | rad=10 38 | height=abs(e_y-s_y-4)-rad 39 | sx = s_x 40 | sy=300 41 | distance=abs(e_x-s_x)-rad*2 42 | if e_x > s_x: 43 | dwg.add(dwg.path(d=f"M{s_x},{s_y+4} v-{height} a{rad},{rad} 0 0 1 {rad},-{rad} h{distance} a{rad},{rad} 0 0 1 {rad},{rad} v{height-4}", 44 | fill="none", 45 | stroke="black", stroke_width=1 46 | )) 47 | else: 48 | dwg.add(dwg.path(d=f"M{s_x},{s_y+4} v-{height} a{rad},{rad} 0 0 0 -{rad},-{rad} h-{distance} a{rad},{rad} 0 0 0 -{rad} {rad} v{height-4}", 49 | fill="none", 50 | stroke="black", stroke_width=1 51 | )) 52 | 53 | 54 | 55 | dwg.add(dwg.text(d_type, insert=(((s_x+e_x)/2)-(self.__size(d_type.strip())/3.0), e_y-4), 56 | fill=color, font_size='14', font_family=self.main_font)) 57 | 58 | def __generate_graph(self, result_df): 59 | # returns an svg graph 60 | 61 | colors_dict = {} 62 | max_x = 50 63 | max_y = 100 64 | 65 | for i in result_df['dependency_type'].unique(): 66 | colors_dict[i] = self.__get_color(i) 67 | 68 | for i in result_df['pos'].unique(): 69 | colors_dict[i] = self.__get_color(i) 70 | 71 | for i, row in result_df.iterrows(): 72 | txt = row['chunk'].strip() 73 | max_x += (self.__size(txt) + 50) 74 | max_y += 30 75 | 76 | max_x += 50 77 | start_x 
= 50 78 | starty_y = max_y 79 | dp_dict={} 80 | tk_dict = {} 81 | dist_dict = {} 82 | main_text = [] 83 | main_pos = [] 84 | 85 | for i, row in result_df.iterrows(): 86 | txt = row['chunk'].strip() 87 | dt = row['dependency'].lower().strip() 88 | is_root = False 89 | if dt == 'root': 90 | is_root = True 91 | main_text.append((txt, start_x, starty_y, is_root)) 92 | main_pos.append( 93 | (row['pos'].strip(), 94 | (start_x + int((self.__size(txt)/2) - int(self.__size(row['pos'])/2.25))), 95 | starty_y+30)) 96 | 97 | tk_dict[str(row['begin'])+str(row['end'])] = (start_x+int(self.__size(txt)/2), starty_y) 98 | start_x += (self.__size(txt) + 50) 99 | 100 | y_offset = starty_y-100 101 | dist_dict = {} 102 | e_dist_dict = {} 103 | direct_dict = {} 104 | left_side_dict = {} 105 | right_side_dict = {} 106 | y_hist = {} 107 | root_list = [] 108 | main_lines = [] 109 | lines_dist = [] 110 | 111 | dist = [] 112 | for i, row in result_df.iterrows(): 113 | if row['dependency'].lower().strip() != 'root': 114 | lines_dist.append(abs(int(row['begin']) - int(row['dependency_start']['head.begin']))) 115 | else: 116 | lines_dist.append(0) 117 | 118 | result_df = result_df.iloc[np.argsort(lines_dist)] 119 | 120 | count_left = {} 121 | count_right = {} 122 | t_x_offset = {} 123 | for i, row in result_df.iterrows(): 124 | if row['dependency'].lower().strip() != 'root': 125 | sp = str(row['dependency_start']['head.begin'])+str(row['dependency_start']['head.end']) 126 | x_e, y_e = tk_dict[str(row['begin'])+str(row['end'])] 127 | x, y = tk_dict[sp] 128 | if int(row['begin']) < int(row['dependency_start']['head.begin']): 129 | if x in count_left: 130 | count_left[x] += 1 131 | t_x_offset[x] += 7 132 | else: 133 | count_left[x] = 1 134 | t_x_offset[x] = 7 135 | if x_e in count_right: 136 | count_right[x_e] += 1 137 | t_x_offset[x_e] -= 7 138 | else: 139 | count_right[x_e] = 0 140 | t_x_offset[x_e] = 0 141 | else: 142 | if x in count_right: 143 | count_right[x] += 1 144 | t_x_offset[x] -= 7 145 
| else: 146 | count_right[x] = 0 147 | t_x_offset[x] = 0 148 | if x_e in count_left: 149 | count_left[x_e] += 1 150 | t_x_offset[x_e] += 7 151 | else: 152 | count_left[x_e] = 1 153 | t_x_offset[x_e] = 7 154 | 155 | for i, row in result_df.iterrows(): 156 | 157 | sp = str(row['dependency_start']['head.begin'])+str(row['dependency_start']['head.end']) 158 | ep = tk_dict[str(row['begin'])+str(row['end'])] 159 | 160 | if sp != '-1-1': 161 | x, y = tk_dict[sp] 162 | 163 | if int(row['begin']) > int(row['dependency_start']['head.begin']): 164 | dist_dict[x] = count_right[x] * 7 165 | count_right[x] -= 1 166 | e_dist_dict[ep[0]] = count_left[ep[0]] * -7 167 | count_left[ep[0]] -= 1 168 | else: 169 | dist_dict[x] = count_left[x] * -7 170 | count_left[x] -= 1 171 | e_dist_dict[ep[0]] = count_right[ep[0]] * 7 172 | count_right[ep[0]] -= 1 173 | #row['dependency'], x, t_x_offset[x], x+dist_dict[x], x+dist_dict[x]+t_x_offset[x] 174 | final_x_s = int(x+dist_dict[x]+(t_x_offset[x]/2)) 175 | final_x_e = int(ep[0]+ e_dist_dict[ep[0]]+(t_x_offset[ep[0]]/2)) 176 | 177 | x_inds = range(min(final_x_s, final_x_e), max(final_x_s, final_x_e)+1) 178 | common = set(y_hist.keys()).intersection(set(x_inds)) 179 | 180 | if common: 181 | y_fset = min([y_hist[c] for c in common]) 182 | y_fset -= 50 183 | y_hist.update(dict(zip(x_inds, [y_fset]*len(x_inds)))) 184 | 185 | else: 186 | y_hist.update(dict(zip(x_inds, [y_offset]*len(x_inds)))) 187 | 188 | main_lines.append((None, final_x_s, y-30, final_x_e, y_hist[final_x_s], row['dependency_type'])) 189 | 190 | else: 191 | x_x , y_y = tk_dict[str(row['begin'])+str(row['end'])] 192 | 193 | root_list.append((row['dependency_type'].upper(), x_x, y_y)) 194 | 195 | 196 | current_y = min(y_hist.values()) 197 | 198 | y_ff = (max_y - current_y) + 50 199 | y_f = (current_y - 50) 200 | current_y = 50 201 | 202 | dwg = svgwrite.Drawing("temp.svg", 203 | profile='full', size = (max_x, y_ff+100)) 204 | dwg.embed_font(self.main_font, self.font_path) 205 | 206 | 
for mt, mp in zip(main_text, main_pos): 207 | dwg.add(dwg.text(mt[0], insert=(mt[1], mt[2]-y_f), fill='gray', 208 | font_size='20', font_family=self.main_font)) 209 | 210 | if mt[3]: 211 | dwg.add(dwg.rect(insert=(mt[1]-5, mt[2]-y_f-25), rx=5,ry=5, size=(self.__size(mt[0]),35), stroke='#800080', 212 | stroke_width='1', fill='none')) 213 | 214 | dwg.add(dwg.text(mp[0], insert=(mp[1], mp[2]-y_f), font_size='14', fill=colors_dict[mp[0]])) 215 | 216 | for ml in main_lines: 217 | self.__draw_line(dwg, ml[1], ml[2]-y_f, ml[3], ml[4]-y_f, ml[5], colors_dict[ml[5]]) 218 | 219 | return dwg.tostring() 220 | 221 | 222 | def display(self, res, pos_col, dependency_col, dependency_type_col=None, return_html=False, save_path=None): 223 | """Displays NER visualization. 224 | 225 | Inputs: 226 | result -- A Dataframe or dictionary. 227 | label_col -- Name of the column/key containing NER annotations. 228 | document_col -- Name of the column/key containing text document. 229 | original_text -- Original text of type 'str'. If specified, it will take precedence over 'document_col' and will be used as the reference text for display. 230 | labels_list -- A list of labels that should be highlighted in the output. Default: Display all labels. 
231 | 232 | Output: Visualization 233 | """ 234 | 235 | pos_res = [] 236 | for i in res[pos_col]: 237 | t_ = {'chunk': i.metadata['word'], 238 | 'begin': i.begin, 239 | 'end' : i.end, 240 | 'pos' : i.result} 241 | pos_res.append(t_) 242 | dep_res = [] 243 | dep_res_meta = [] 244 | for i in res[dependency_col]: 245 | dep_res.append(i.result) 246 | dep_res_meta.append(i.metadata) 247 | df = pd.DataFrame(pos_res) 248 | df['dependency'] = dep_res 249 | df['dependency_start'] = dep_res_meta 250 | 251 | if dependency_type_col != None: 252 | df['dependency_type'] = [ i.result for i in res[dependency_type_col] ] 253 | else: 254 | df['dependency_type'] = '' 255 | 256 | html_content = self.__generate_graph(df) 257 | 258 | if save_path != None: 259 | with open(save_path, 'w') as f_: 260 | f_.write(html_content) 261 | 262 | if return_html: 263 | return html_content 264 | else: 265 | return display(HTML(html_content)) 266 | -------------------------------------------------------------------------------- /build/lib/sparknlp_display/entity_resolution.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import json 4 | import numpy as np 5 | from . 
import style_utils as style_config 6 | from IPython.display import display, HTML 7 | 8 | here = os.path.abspath(os.path.dirname(__file__)) 9 | 10 | class EntityResolverVisualizer: 11 | def __init__(self): 12 | with open(os.path.join(here, 'label_colors/ner.json'), 'r', encoding='utf-8') as f_: 13 | self.label_colors = json.load(f_) 14 | 15 | #public function to get color for a label 16 | def get_label_color(self, label): 17 | """Returns color of a particular label 18 | 19 | Input: entity label 20 | Output: Color or if not found 21 | """ 22 | 23 | if str(label).lower() in self.label_colors: 24 | return self.label_colors[label.lower()] 25 | else: 26 | return None 27 | 28 | # private function for colors for display 29 | def __get_label(self, label): 30 | """Set label colors. 31 | 32 | Input: dictionary of entity labels and corresponding colors 33 | Output: self object - to allow chaining 34 | Note: Previous values of colors will be overwritten 35 | """ 36 | if str(label).lower() in self.label_colors: 37 | return self.label_colors[label.lower()] 38 | else: 39 | #update it to fetch from git new labels 40 | r = lambda: random.randint(0,200) 41 | return '#%02X%02X%02X' % (r(), r(), r()) 42 | 43 | def set_label_colors(self, color_dict): 44 | """Sets label colors. 45 | 46 | input: dictionary of entity labels and corresponding colors 47 | output: self object - to allow chaining 48 | note: Previous values of colors will be overwritten 49 | """ 50 | 51 | for key, value in color_dict.items(): 52 | self.label_colors[key.lower()] = value 53 | return self 54 | 55 | def __verify_structure(self, result, label_col, document_col, original_text): 56 | 57 | if original_text is None: 58 | basic_msg_1 = """Incorrect annotation structure of '{}'.""".format(document_col) 59 | if not hasattr(result[document_col][0], 'result'): 60 | raise AttributeError(basic_msg_1+""" 'result' attribute not found in the annotation. 
61 | Make sure '"""+document_col+"""' is a list of objects having the following structure: 62 | Annotation(type='annotation', begin=0, end=10, result='This is a text') 63 | Or You can pass the text manually using 'raw_text' argument.""") 64 | 65 | basic_msg_1 = """Incorrect annotation structure of '{}'.""".format(label_col) 66 | basic_msg = """ 67 | In sparknlp please use 'LightPipeline.fullAnnotate' for LightPipeline or 'Pipeline.transform' for PipelineModel. 68 | Or 69 | Make sure '"""+label_col+"""' is a list of objects having the following structure: 70 | Annotation(type='annotation', begin=0, end=0, result='Adam', metadata={'entity': 'PERSON'})""" 71 | 72 | for entity in result[label_col]: 73 | if not hasattr(entity, 'begin'): 74 | raise AttributeError( basic_msg_1 + """ 'begin' attribute not found in the annotation."""+basic_msg) 75 | if not hasattr(entity, 'end'): 76 | raise AttributeError(basic_msg_1 + """ 'end' attribute not found in the annotation."""+basic_msg) 77 | if not hasattr(entity, 'result'): 78 | raise AttributeError(basic_msg_1 + """ 'result' attribute not found in the annotation."""+basic_msg) 79 | if not hasattr(entity, 'metadata'): 80 | raise AttributeError(basic_msg_1 + """ 'metadata' attribute not found in the annotation."""+basic_msg) 81 | if 'entity' not in entity.metadata: 82 | raise AttributeError(basic_msg_1+""" KeyError: 'entity' not found in metadata."""+basic_msg) 83 | 84 | def __verify_input(self, result, label_col, document_col, original_text): 85 | # check if label colum in result 86 | if label_col not in result: 87 | raise AttributeError("""column/key '{}' not found in the provided dataframe/dictionary. 88 | Please specify the correct key/column using 'label_col' argument.""".format(label_col)) 89 | 90 | if original_text is not None: 91 | # check if provided text is correct data type 92 | if not isinstance(original_text, str): 93 | raise ValueError("Invalid value for argument 'raw_text' input. 
Text should be of type 'str'.") 94 | 95 | else: 96 | # check if document column in result 97 | if document_col not in result: 98 | raise AttributeError("""column/key '{}' not found in the provided dataframe/dictionary. 99 | Please specify the correct key/column using 'document_col' argument. 100 | Or You can pass the text manually using 'raw_text' argument""".format(document_col)) 101 | 102 | self.__verify_structure( result, label_col, document_col, original_text) 103 | 104 | # main display function 105 | def __display_ner(self, result, label_col, resolution_col, document_col, original_text, labels_list = None): 106 | 107 | if original_text is None: 108 | original_text = result[document_col][0].result 109 | 110 | if labels_list is not None: 111 | labels_list = [v.lower() for v in labels_list] 112 | 113 | resolution_temp_dict = {} 114 | for resol in result[resolution_col]: 115 | resolution_temp_dict[int(resol.begin)] = [resol.result, resol.metadata['resolved_text']] 116 | 117 | 118 | label_color = {} 119 | html_output = "" 120 | pos = 0 121 | 122 | sorted_labs_idx = np.argsort([ int(i.begin) for i in result[label_col]]) 123 | sorted_labs = [ result[label_col][i] for i in sorted_labs_idx] 124 | 125 | for entity in sorted_labs: 126 | entity_type = entity.metadata['entity'].lower() 127 | if (entity_type not in label_color) and ((labels_list is None) or (entity_type in labels_list)) : 128 | label_color[entity_type] = self.__get_label(entity_type) 129 | 130 | begin = int(entity.begin) 131 | end = int(entity.end) 132 | if pos < begin and pos < len(original_text): 133 | white_text = original_text[pos:begin] 134 | html_output += '{}'.format(white_text) 135 | pos = end+1 136 | 137 | if entity_type in label_color: 138 | if begin in resolution_temp_dict: 139 | html_output += '{} {}{} {}'.format( 140 | label_color[entity_type] + 'B3', #color 141 | entity.result, #entity - chunk 142 | entity.metadata['entity'], #entity - label 143 | label_color[entity_type] + 'FF', #color 
'#D2C8C6' 144 | resolution_temp_dict[begin][0], # res_code 145 | label_color[entity_type] + 'CC', # res_color '#DDD2D0' 146 | resolution_temp_dict[begin][1] # res_text 147 | ) 148 | else: 149 | html_output += '{} {}'.format( 150 | label_color[entity_type] + 'B3', #color 151 | entity.result, #entity - chunk 152 | entity.metadata['entity'] #entity - label 153 | ) 154 | else: 155 | html_output += '{}'.format(entity.result) 156 | 157 | if pos < len(original_text): 158 | html_output += '{}'.format(original_text[pos:]) 159 | 160 | html_output += """""" 161 | 162 | html_output = html_output.replace("\n", "
") 163 | 164 | return html_output 165 | 166 | def display(self, result, label_col, resolution_col, document_col='document', raw_text=None, return_html=False, save_path=None): 167 | """Displays NER visualization. 168 | 169 | Inputs: 170 | result -- A Dataframe or dictionary. 171 | label_col -- Name of the column/key containing NER annotations. 172 | document_col -- Name of the column/key containing text document. 173 | original_text -- Original text of type 'str'. If specified, it will take precedence over 'document_col' and will be used as the reference text for display. 174 | labels_list -- A list of labels that should be highlighted in the output. Default: Display all labels. 175 | 176 | Output: Visualization 177 | """ 178 | 179 | #self.__verifyInput(result, label_col, document_col, raw_text) 180 | 181 | html_content = self.__display_ner(result, label_col, resolution_col, document_col, raw_text) 182 | 183 | html_content_save = style_config.STYLE_CONFIG_ENTITIES+ " "+html_content 184 | 185 | if save_path != None: 186 | with open(save_path, 'w') as f_: 187 | f_.write(html_content_save) 188 | 189 | if return_html: 190 | return html_content_save 191 | else: 192 | return display(HTML(html_content_save)) 193 | -------------------------------------------------------------------------------- /build/lib/sparknlp_display/fonts/Lucida_Console.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-display/a6c5a6ac166c5f75ec4ccf64ffafb2fa6e1164a1/build/lib/sparknlp_display/fonts/Lucida_Console.ttf -------------------------------------------------------------------------------- /build/lib/sparknlp_display/label_colors/ner.json: -------------------------------------------------------------------------------- 1 | { 2 | "problem": "#800080", 3 | "test": "#77b5fe", 4 | "treatment": "#8b6673", 5 | "multi": "#494ca3", 6 | "multi-tissue_structure": "#8dd8b4", 7 | "cell": "#ffe6cc", 8 | "organism": 
"#ffddcc", 9 | "gene_or_gene_product": "#fff0b3", 10 | "organ": "#e6e600", 11 | "simple_chemical": "#ffd699", 12 | "drug": "#8B668B", 13 | "diagnosis": "#b5a1c9", 14 | "maybe": "#FFB5C5", 15 | "lab_result": "#3abd80", 16 | "negated": "#CD3700", 17 | "name": "#C0FF3E", 18 | "lab_name": "#698B22", 19 | "modifier": "#8B475D", 20 | "symptom_name": "#CDB7B5", 21 | "section_name": "#8B7D7B", 22 | "drug_name": "#a3496c", 23 | "procedure_name": "#48D1CC", 24 | "grading": "#8c61e8", 25 | "size": "#746b87", 26 | "organism_substance": "#ffaa80", 27 | "gender": "#ffacb7", 28 | "age": "#ffe0ac", 29 | "date": "#a6b1e1" 30 | } -------------------------------------------------------------------------------- /build/lib/sparknlp_display/label_colors/relations.json: -------------------------------------------------------------------------------- 1 | { 2 | "overlap" : "#ffb345", 3 | "before" : "#0398da", 4 | "after" : "#39bf7f", 5 | 6 | "trip": "#e4815e", 7 | "trwp": "#0398da", 8 | "trcp": "#39bf7f", 9 | "trap": "#ffb345", 10 | "trnap": "#0059b3", 11 | "terp": "#8c35cd", 12 | "tecp": "#fa3e74", 13 | "pip" : "#6e5772", 14 | 15 | "drug-strength" : "purple", 16 | "drug-frequency": "slategray", 17 | "drug-form" : "deepskyblue", 18 | "dosage-drug" : "springgreen", 19 | "strength-drug": "maroon", 20 | "drug-dosage" : "gold", 21 | 22 | "0" : "#e4815e", 23 | "1" : "#6e5772" 24 | } -------------------------------------------------------------------------------- /build/lib/sparknlp_display/ner.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import json 4 | import numpy as np 5 | from . 
import style_utils as style_config 6 | from IPython.display import display, HTML 7 | 8 | here = os.path.abspath(os.path.dirname(__file__)) 9 | 10 | class NerVisualizer: 11 | def __init__(self): 12 | with open(os.path.join(here, 'label_colors/ner.json'), 'r', encoding='utf-8') as f_: 13 | self.label_colors = json.load(f_) 14 | 15 | #public function to get color for a label 16 | def get_label_color(self, label): 17 | """Returns color of a particular label 18 | 19 | Input: entity label 20 | Output: Color or if not found 21 | """ 22 | 23 | if str(label).lower() in self.label_colors: 24 | return self.label_colors[label.lower()] 25 | else: 26 | return None 27 | 28 | # private function for colors for display 29 | def __get_label(self, label): 30 | """Internal function to generate random color codes for missing colors 31 | 32 | Input: dictionary of entity labels and corresponding colors 33 | Output: color code (Hex) 34 | """ 35 | if str(label).lower() in self.label_colors: 36 | return self.label_colors[label.lower()] 37 | else: 38 | #update it to fetch from git new labels 39 | r = lambda: random.randint(0,200) 40 | return '#%02X%02X%02X' % (r(), r(), r()) 41 | 42 | def set_label_colors(self, color_dict): 43 | """Sets label colors. 44 | input: dictionary of entity labels and corresponding colors 45 | output: self object - to allow chaining 46 | note: Previous values of colors will be overwritten 47 | """ 48 | 49 | for key, value in color_dict.items(): 50 | self.label_colors[key.lower()] = value 51 | return self 52 | 53 | def __verify_structure(self, result, label_col, document_col, original_text): 54 | 55 | if original_text is None: 56 | basic_msg_1 = """Incorrect annotation structure of '{}'.""".format(document_col) 57 | if not hasattr(result[document_col][0], 'result'): 58 | raise AttributeError(basic_msg_1+""" 'result' attribute not found in the annotation. 
59 | Make sure '"""+document_col+"""' is a list of objects having the following structure: 60 | Annotation(type='annotation', begin=0, end=10, result='This is a text') 61 | Or You can pass the text manually using 'raw_text' argument.""") 62 | 63 | basic_msg_1 = """Incorrect annotation structure of '{}'.""".format(label_col) 64 | basic_msg = """ 65 | In sparknlp please use 'LightPipeline.fullAnnotate' for LightPipeline or 'Pipeline.transform' for PipelineModel. 66 | Or 67 | Make sure '"""+label_col+"""' is a list of objects having the following structure: 68 | Annotation(type='annotation', begin=0, end=0, result='Adam', metadata={'entity': 'PERSON'})""" 69 | 70 | for entity in result[label_col]: 71 | if not hasattr(entity, 'begin'): 72 | raise AttributeError( basic_msg_1 + """ 'begin' attribute not found in the annotation."""+basic_msg) 73 | if not hasattr(entity, 'end'): 74 | raise AttributeError(basic_msg_1 + """ 'end' attribute not found in the annotation."""+basic_msg) 75 | if not hasattr(entity, 'result'): 76 | raise AttributeError(basic_msg_1 + """ 'result' attribute not found in the annotation."""+basic_msg) 77 | if not hasattr(entity, 'metadata'): 78 | raise AttributeError(basic_msg_1 + """ 'metadata' attribute not found in the annotation."""+basic_msg) 79 | if 'entity' not in entity.metadata: 80 | raise AttributeError(basic_msg_1+""" KeyError: 'entity' not found in metadata."""+basic_msg) 81 | 82 | def __verify_input(self, result, label_col, document_col, original_text): 83 | # check if label colum in result 84 | if label_col not in result: 85 | raise AttributeError("""column/key '{}' not found in the provided dataframe/dictionary. 86 | Please specify the correct key/column using 'label_col' argument.""".format(label_col)) 87 | 88 | if original_text is not None: 89 | # check if provided text is correct data type 90 | if not isinstance(original_text, str): 91 | raise ValueError("Invalid value for argument 'raw_text' input. 
Text should be of type 'str'.") 92 | 93 | else: 94 | # check if document column in result 95 | if document_col not in result: 96 | raise AttributeError("""column/key '{}' not found in the provided dataframe/dictionary. 97 | Please specify the correct key/column using 'document_col' argument. 98 | Or You can pass the text manually using 'raw_text' argument""".format(document_col)) 99 | 100 | self.__verify_structure( result, label_col, document_col, original_text) 101 | 102 | # main display function 103 | def __display_ner(self, result, label_col, document_col, original_text, labels_list = None): 104 | 105 | if original_text is None: 106 | original_text = result[document_col][0].result 107 | 108 | if labels_list is not None: 109 | labels_list = [v.lower() for v in labels_list] 110 | label_color = {} 111 | html_output = "" 112 | pos = 0 113 | 114 | sorted_labs_idx = np.argsort([ int(i.begin) for i in result[label_col]]) 115 | sorted_labs = [ result[label_col][i] for i in sorted_labs_idx] 116 | 117 | for entity in sorted_labs: 118 | entity_type = entity.metadata['entity'].lower() 119 | if (entity_type not in label_color) and ((labels_list is None) or (entity_type in labels_list)) : 120 | label_color[entity_type] = self.__get_label(entity_type) 121 | 122 | begin = int(entity.begin) 123 | end = int(entity.end) 124 | if pos < begin and pos < len(original_text): 125 | white_text = original_text[pos:begin] 126 | html_output += '{}'.format(white_text) 127 | pos = end+1 128 | 129 | if entity_type in label_color: 130 | html_output += '{} {}'.format( 131 | label_color[entity_type], 132 | original_text[begin:end+1],#entity.result, 133 | entity.metadata['entity']) 134 | else: 135 | html_output += '{}'.format(original_text[begin:end+1]) 136 | 137 | if pos < len(original_text): 138 | html_output += '{}'.format(original_text[pos:]) 139 | 140 | html_output += """""" 141 | 142 | html_output = html_output.replace("\n", "
") 143 | 144 | return html_output 145 | 146 | def display(self, result, label_col, document_col='document', raw_text=None, labels=None, return_html=False, save_path=None): 147 | """Displays NER visualization. 148 | Inputs: 149 | result -- A Dataframe or dictionary. 150 | label_col -- Name of the column/key containing NER annotations. 151 | document_col -- Name of the column/key containing text document. 152 | original_text -- Original text of type 'str'. If specified, it will take precedence over 'document_col' and will be used as the reference text for display. 153 | labels_list -- A list of labels that should be highlighted in the output. Default: Display all labels. 154 | Output: Visualization 155 | """ 156 | 157 | self.__verify_input(result, label_col, document_col, raw_text) 158 | 159 | html_content = self.__display_ner(result, label_col, document_col, raw_text, labels) 160 | 161 | html_content_save = style_config.STYLE_CONFIG_ENTITIES+ " "+html_content 162 | 163 | if save_path != None: 164 | with open(save_path, 'w') as f_: 165 | f_.write(html_content_save) 166 | 167 | if return_html: 168 | return html_content_save 169 | else: 170 | return display(HTML(html_content_save)) 171 | 172 | -------------------------------------------------------------------------------- /build/lib/sparknlp_display/re_updates.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import json 4 | import pandas as pd 5 | import numpy as np 6 | import svgwrite 7 | from IPython.display import display, HTML 8 | 9 | here = os.path.abspath(os.path.dirname(__file__)) 10 | 11 | class RelationExtractionVisualizer: 12 | 13 | def __init__(self): 14 | self.color_dict = { 15 | "overlap" : "lightsalmon", 16 | "before" : "deepskyblue", 17 | "after" : "springgreen", 18 | 19 | "trip": "lightsalmon", 20 | "trwp": "deepskyblue", 21 | "trcp": "springgreen", 22 | "trap": "gold", 23 | "trnap": "maroon", 24 | "terp": "purple", 25 | "tecp": 
"tomato", 26 | "pip" : "slategray", 27 | 28 | "drug-strength" : "purple", 29 | "drug-frequency": "slategray", 30 | "drug-form" : "deepskyblue", 31 | "dosage-drug" : "springgreen", 32 | "strength-drug": "maroon", 33 | "drug-dosage" : "gold" 34 | } 35 | 36 | def __get_color(self, l): 37 | r = lambda: random.randint(100,255) 38 | return '#%02X%02X%02X' % (r(), r(), r()) 39 | 40 | def __size(self, text): 41 | return ((len(text)+1)*9.7)-5 42 | 43 | def __draw_line(self, dwg, s_x , s_y, e_x, e_y, d_type, color, show_relations): 44 | # find the a & b points 45 | def get_bezier_coef(points): 46 | # since the formulas work given that we have n+1 points 47 | # then n must be this: 48 | n = len(points) - 1 49 | 50 | # build coefficents matrix 51 | C = 4 * np.identity(n) 52 | np.fill_diagonal(C[1:], 1) 53 | np.fill_diagonal(C[:, 1:], 1) 54 | C[0, 0] = 2 55 | C[n - 1, n - 1] = 7 56 | C[n - 1, n - 2] = 2 57 | 58 | # build points vector 59 | P = [2 * (2 * points[i] + points[i + 1]) for i in range(n)] 60 | P[0] = points[0] + 2 * points[1] 61 | P[n - 1] = 8 * points[n - 1] + points[n] 62 | 63 | # solve system, find a & b 64 | A = np.linalg.solve(C, P) 65 | B = [0] * n 66 | for i in range(n - 1): 67 | B[i] = 2 * points[i + 1] - A[i + 1] 68 | B[n - 1] = (A[n - 1] + points[n]) / 2 69 | 70 | return A, B 71 | 72 | # returns the general Bezier cubic formula given 4 control points 73 | def get_cubic(a, b, c, d): 74 | return lambda t: np.power(1 - t, 3) * a + 3 * np.power(1 - t, 2) * t * b + 3 * (1 - t) * np.power(t, 2) * c + np.power(t, 3) * d 75 | 76 | # return one cubic curve for each consecutive points 77 | def get_bezier_cubic(points): 78 | A, B = get_bezier_coef(points) 79 | return [ 80 | get_cubic(points[i], A[i], B[i], points[i + 1]) 81 | for i in range(len(points) - 1) 82 | ] 83 | 84 | # evalute each cubic curve on the range [0, 1] sliced in n points 85 | def evaluate_bezier(points, n): 86 | curves = get_bezier_cubic(points) 87 | return np.array([fun(t) for fun in curves for t in 
np.linspace(0, 1, n)]) 88 | 89 | 90 | def draw_pointer(dwg, s_x, s_y, e_x, e_y): 91 | size = 8 92 | ratio = 2 93 | fullness1 = 2 94 | fullness2 = 3 95 | bx = e_x 96 | ax = s_x 97 | by = e_y 98 | ay = s_y 99 | abx = bx - ax 100 | aby = by - ay 101 | ab = np.sqrt(abx * abx + aby * aby) 102 | 103 | cx = bx - size * abx / ab 104 | cy = by - size * aby / ab 105 | dx = cx + (by - cy) / ratio 106 | dy = cy + (cx - bx) / ratio 107 | ex = cx - (by - cy) / ratio 108 | ey = cy - (cx - bx) / ratio 109 | fx = (fullness1 * cx + bx) / fullness2 110 | fy = (fullness1 * cy + by) / fullness2 111 | 112 | text_place_y = s_y-(abs(s_y-e_y)/2) 113 | line = dwg.add(dwg.polyline( 114 | [ 115 | (bx, by), 116 | (dx, dy), 117 | (fx, fy), 118 | (ex, ey), 119 | (bx, by) 120 | ], 121 | stroke=color, stroke_width = "2", fill='none',)) 122 | return text_place_y 123 | 124 | if s_x > e_x: 125 | #s_x -= 5 126 | e_x += 10 127 | else: 128 | #s_x += 5 129 | e_x -= 2 130 | if s_y == e_y: 131 | s_y -= 20 132 | e_y = s_y-4#55 133 | text_place_y = s_y-45 134 | 135 | pth = evaluate_bezier(np.array([[s_x, s_y], 136 | [(s_x+e_x)/2.0, s_y-40], 137 | [e_x,e_y]]), 50) 138 | dwg.add(dwg.polyline(pth, 139 | stroke=color, stroke_width = "2", fill='none',)) 140 | 141 | draw_pointer(dwg, (s_x+e_x)/2.0, s_y-50, e_x, e_y) 142 | 143 | elif s_y >= e_y: 144 | e_y +=15 145 | s_y-=20 146 | dwg.add(dwg.polyline([(s_x,s_y), (e_x, e_y)], 147 | stroke=color, stroke_width = "2", fill='none',)) 148 | text_place_y = draw_pointer(dwg, s_x, s_y, e_x, e_y) 149 | else: 150 | s_y-=5 151 | e_y -= 40 152 | dwg.add(dwg.polyline([(s_x,s_y), (e_x, e_y)], 153 | stroke=color, stroke_width = "2", fill='none',)) 154 | text_place_y = draw_pointer(dwg, s_x, s_y, e_x, e_y) 155 | if show_relations: 156 | dwg.add(dwg.text(d_type, insert=(((s_x+e_x)/2)-(self.__size(d_type)/2.75), text_place_y), 157 | fill=color, font_size='12', font_family='courier')) 158 | 159 | def __gen_graph(self, rdf, selected_text, show_relations): 160 | 161 | done_ent1 = {} 162 
| done_ent2 = {} 163 | all_done = {} 164 | 165 | start_y = 75 166 | x_limit = 920 167 | y_offset = 100 168 | dwg = svgwrite.Drawing("temp.svg",profile='tiny', size = (x_limit, len(selected_text) * 1.1 + len(rdf)*20)) 169 | 170 | begin_index = 0 171 | start_x = 10 172 | this_line = 0 173 | 174 | all_entities_index = set() 175 | all_entities_1_index = [] 176 | basic_dict = {} 177 | relation_dict = {} 178 | for t in rdf: 179 | if t.result.lower().strip() != 'o': 180 | all_entities_index.add(int(t.metadata['entity1_begin'])) 181 | all_entities_index.add(int(t.metadata['entity2_begin'])) 182 | basic_dict[int(t.metadata['entity1_begin'])] = [t.metadata['entity1_begin'], 183 | t.metadata['entity1_end'], 184 | t.metadata['chunk1'], 185 | t.metadata['entity1']] 186 | 187 | basic_dict[int(t.metadata['entity2_begin'])] = [t.metadata['entity2_begin'], 188 | t.metadata['entity2_end'], 189 | t.metadata['chunk2'], 190 | t.metadata['entity2']] 191 | 192 | #all_entities_1_index.append(t[4]['entity1_begin']) 193 | all_entities_index = np.asarray(list(all_entities_index)) 194 | all_entities_index = all_entities_index[np.argsort(all_entities_index)] 195 | for ent_start_ind in all_entities_index: 196 | e_start_now, e_end_now, e_chunk_now, e_entity_now = basic_dict[ent_start_ind] 197 | prev_text = selected_text[begin_index:int(e_start_now)] 198 | begin_index = int(e_end_now)+1 199 | for word_ in prev_text.split(' '): 200 | this_size = self.__size(word_) 201 | if (start_x + this_size + 10) >= x_limit: 202 | start_y += y_offset 203 | start_x = 10 204 | this_line = 0 205 | dwg.add(dwg.text(word_, insert=(start_x, start_y ), fill='gray', font_size='16', font_family='courier')) 206 | start_x += this_size + 5 207 | 208 | this_size = self.__size(e_chunk_now) 209 | if (start_x + this_size + 10)>= x_limit:# or this_line >= 2: 210 | start_y += y_offset 211 | start_x = 10 212 | this_line = 0 213 | #chunk1 214 | dwg.add(dwg.text(e_chunk_now, insert=(start_x, start_y ), fill='gray', font_size='16', 
font_family='courier')) 215 | #rectange chunk 1 216 | dwg.add(dwg.rect(insert=(start_x-3, start_y-18), size=(this_size,25), 217 | rx=2, ry=2, stroke='orange', 218 | stroke_width='2', fill='none')) 219 | #entity 1 220 | central_point_x = start_x+(this_size/2) 221 | 222 | dwg.add(dwg.text(e_entity_now, 223 | insert=(central_point_x-(self.__size(e_entity_now)/2.75), start_y+20), 224 | fill='slateblue', font_size='12', font_family='courier')) 225 | 226 | all_done[int(e_start_now)] = [central_point_x-(self.__size(e_entity_now)/2.75), start_y] 227 | start_x += this_size + 10 228 | this_line += 1 229 | 230 | #all_done[ent_start_ind] = 231 | 232 | prev_text = selected_text[begin_index:] 233 | for word_ in prev_text.split(' '): 234 | this_size = self.__size(word_) 235 | if (start_x + this_size)>= x_limit: 236 | start_y += y_offset 237 | start_x = 10 238 | dwg.add(dwg.text(word_, insert=(start_x, start_y ), fill='gray', font_size='16', font_family='courier')) 239 | start_x += this_size 240 | 241 | for row in rdf: 242 | if row.result.lower().strip() != 'o': 243 | if row.result.lower().strip() not in self.color_dict: 244 | self.color_dict[row.result.lower().strip()] = self.__get_color(row.result.lower().strip()) 245 | d_key2 = all_done[int(row.metadata['entity2_begin'])] 246 | d_key1 = all_done[int(row.metadata['entity1_begin'])] 247 | self.__draw_line(dwg, d_key2[0] , d_key2[1], d_key1[0], d_key1[1], 248 | row.result,self.color_dict[row.result.lower().strip()], show_relations) 249 | 250 | return dwg.tostring() 251 | 252 | def display(self, result, relation_col, document_col='document', show_relations=True): 253 | 254 | original_text = result[document_col][0].result 255 | res = result[relation_col] 256 | return display(HTML(self.__gen_graph(res, original_text, show_relations))) 257 | -------------------------------------------------------------------------------- /build/lib/sparknlp_display/relation_extraction.py: 
-------------------------------------------------------------------------------- 1 | 2 | import random 3 | import os 4 | import json 5 | import pandas as pd 6 | import numpy as np 7 | import svgwrite 8 | import math 9 | import re 10 | from IPython.display import display, HTML 11 | 12 | here = os.path.abspath(os.path.dirname(__file__)) 13 | #overlap_hist = [] 14 | #y_hist_dict = {} 15 | x_i_diff_dict = {} 16 | x_o_diff_dict = {} 17 | class RelationExtractionVisualizer: 18 | 19 | def __init__(self): 20 | with open(os.path.join(here, 'label_colors/relations.json'), 'r', encoding='utf-8') as f_: 21 | self.color_dict = json.load(f_) 22 | with open(os.path.join(here, 'label_colors/ner.json'), 'r', encoding='utf-8') as f_: 23 | self.entity_color_dict = json.load(f_) 24 | self.entity_color_dict = dict((k.lower(), v) for k, v in self.entity_color_dict.items()) 25 | self.font_path = os.path.join(here, 'fonts/Lucida_Console.ttf') 26 | self.main_font = 'Lucida' 27 | def __get_color(self, l): 28 | r = lambda: random.randint(0,200) 29 | return '#%02X%02X%02X' % (r(), r(), r()) 30 | 31 | def __size(self, text): 32 | return ((len(text)+1)*9.7)-5 33 | 34 | def __draw_line(self, dwg, s_x , s_y, e_x, e_y, d_type, color, show_relations, size_of_entity_label): 35 | eps = 0.0000001 36 | def get_bezier_coef(points): 37 | # since the formulas work given that we have n+1 points 38 | # then n must be this: 39 | n = len(points) - 1 40 | 41 | # build coefficents matrix 42 | C = 4 * np.identity(n) 43 | np.fill_diagonal(C[1:], 1) 44 | np.fill_diagonal(C[:, 1:], 1) 45 | C[0, 0] = 2 46 | C[n - 1, n - 1] = 7 47 | C[n - 1, n - 2] = 2 48 | 49 | # build points vector 50 | P = [2 * (2 * points[i] + points[i + 1]) for i in range(n)] 51 | P[0] = points[0] + 2 * points[1] 52 | P[n - 1] = 8 * points[n - 1] + points[n] 53 | 54 | # solve system, find a & b 55 | A = np.linalg.solve(C, P) 56 | B = [0] * n 57 | for i in range(n - 1): 58 | B[i] = 2 * points[i + 1] - A[i + 1] 59 | B[n - 1] = (A[n - 1] + 
points[n]) / 2 60 | 61 | return A, B 62 | 63 | # returns the general Bezier cubic formula given 4 control points 64 | def get_cubic(a, b, c, d): 65 | return lambda t: np.power(1 - t, 3) * a + 3 * np.power(1 - t, 2) * t * b + 3 * (1 - t) * np.power(t, 2) * c + np.power(t, 3) * d 66 | 67 | # return one cubic curve for each consecutive points 68 | def get_bezier_cubic(points): 69 | A, B = get_bezier_coef(points) 70 | return [ 71 | get_cubic(points[i], A[i], B[i], points[i + 1]) 72 | for i in range(len(points) - 1) 73 | ] 74 | 75 | # evalute each cubic curve on the range [0, 1] sliced in n points 76 | def evaluate_bezier(points, n): 77 | curves = get_bezier_cubic(points) 78 | return np.array([fun(t) for fun in curves for t in np.linspace(0, 1, n)]) 79 | 80 | 81 | def draw_pointer(dwg, s_x, s_y, e_x, e_y): 82 | size = 5 83 | ratio = 1 84 | fullness1 = 2 85 | fullness2 = 3 86 | bx = e_x 87 | ax = s_x 88 | by = e_y 89 | ay = s_y 90 | abx = bx - ax 91 | aby = by - ay 92 | ab = np.sqrt(abx * abx + aby * aby) + eps 93 | 94 | cx = bx - size * abx / ab 95 | cy = by - size * aby / ab 96 | dx = cx + (by - cy) / ratio 97 | dy = cy + (cx - bx) / ratio 98 | ex = cx - (by - cy) / ratio 99 | ey = cy - (cx - bx) / ratio 100 | fx = (fullness1 * cx + bx) / fullness2 101 | fy = (fullness1 * cy + by) / fullness2 102 | 103 | text_place_y = s_y-(abs(s_y-e_y)/2) 104 | ''' 105 | line = dwg.add(dwg.polyline( 106 | [ 107 | (bx, by), 108 | (dx, dy), 109 | (fx, fy), 110 | (ex, ey), 111 | (bx, by) 112 | ], 113 | stroke=color, stroke_width = "1", fill='none',)) 114 | ''' 115 | line = dwg.add(dwg.polyline( 116 | [ 117 | (dx, dy), 118 | (bx, by), 119 | (ex, ey), 120 | (bx, by) 121 | ], 122 | stroke=color, stroke_width = "1", fill='none',)) 123 | return text_place_y 124 | unique_o_index = str(s_x)+str(s_y) 125 | unique_i_index = str(e_x)+str(e_y) 126 | if s_x > e_x: 127 | if unique_o_index in x_o_diff_dict: 128 | s_x -= 5 129 | else: 130 | s_x -= 10 131 | x_o_diff_dict[unique_o_index] = 5 132 | if s_y 
> e_y: 133 | e_x += size_of_entity_label 134 | elif s_y < e_y: 135 | s_x -= size_of_entity_label 136 | 137 | if unique_i_index in x_i_diff_dict: 138 | e_x += 5 139 | else: 140 | e_x += 10 141 | x_i_diff_dict[unique_i_index] = 5 142 | else: 143 | if unique_o_index in x_o_diff_dict: 144 | s_x += 5 145 | else: 146 | s_x += 10 147 | x_o_diff_dict[unique_o_index] = 5 148 | if s_y > e_y: 149 | e_x -= size_of_entity_label 150 | elif s_y < e_y: 151 | s_x += size_of_entity_label 152 | 153 | if unique_i_index in x_i_diff_dict: 154 | e_x -= 5 155 | else: 156 | e_x -= 10 157 | x_i_diff_dict[unique_i_index] = 5 158 | #this_y_vals = list(range(min(s_x,e_x), max(s_x,e_x)+1)) 159 | #this_y_vals = [ str(s_y)+'|'+str(i) for i in this_y_vals] 160 | #common = set(this_y_vals) & set(overlap_hist) 161 | #overlap_hist.extend(this_y_vals) 162 | #if s_y not in y_hist_dict: 163 | # y_hist_dict[s_y] = 20 164 | #if common: 165 | # y_hist_dict[s_y] += 20 166 | #y_increase = y_hist_dict[s_y] 167 | angle = -1 168 | if s_y == e_y: 169 | angle = 0 170 | s_y -= 20 171 | e_y = s_y-4#55 172 | 173 | text_place_y = s_y-35 174 | 175 | pth = evaluate_bezier(np.array([[s_x, s_y], 176 | [(s_x+e_x)/2.0, s_y-40], 177 | [e_x,e_y]]), 50) 178 | dwg.add(dwg.polyline(pth, 179 | stroke=color, stroke_width = "1", fill='none',)) 180 | draw_pointer(dwg, (s_x+e_x)/2.0, s_y-50, e_x, e_y) 181 | elif s_y >= e_y: 182 | 183 | e_y +=15 184 | s_y-=20 185 | text_place_y = s_y-(abs(s_y-e_y)/2) 186 | 187 | pth = evaluate_bezier(np.array([[s_x, s_y], 188 | #[((3*s_x)+e_x)/4.0, (s_y+e_y)/2.0], 189 | [(s_x+e_x)/2.0, (s_y+e_y)/2.0], 190 | #[(s_x+(3*e_x))/4.0,(s_y+e_y)/2.0], 191 | [e_x,e_y]]), 50) 192 | dwg.add(dwg.polyline(pth, 193 | stroke=color, stroke_width = "1", fill='none',)) 194 | draw_pointer(dwg, s_x, s_y, e_x, e_y) 195 | 196 | ''' 197 | line = dwg.add(dwg.polyline( 198 | [(s_x, s_y),(s_x, s_y-y_increase), (e_x, s_y-y_increase), 199 | (e_x, e_y), 200 | (e_x+2, e_y), 201 | (e_x, e_y-4), 202 | (e_x-2, e_y), 203 | (e_x, e_y) 
204 | ], 205 | stroke=color, stroke_width = "2", fill='none',)) 206 | ''' 207 | else: 208 | 209 | s_y+=15 210 | e_y -= 20 211 | text_place_y = s_y+(abs(s_y-e_y)/2) 212 | 213 | pth = evaluate_bezier(np.array([[s_x, s_y], 214 | #[((3*s_x)+e_x)/4.0, (s_y+e_y)/2.0], 215 | [(s_x+e_x)/2.0, (s_y+e_y)/2.0], 216 | #[(s_x+(3*e_x))/4.0,(s_y+e_y)/2.0], 217 | [e_x,e_y]]), 50) 218 | dwg.add(dwg.polyline(pth, 219 | stroke=color, stroke_width = "1", fill='none',)) 220 | draw_pointer(dwg, s_x, s_y, e_x, e_y) 221 | 222 | if show_relations: 223 | if angle == -1: angle = math.degrees(math.atan((s_y-e_y)/((s_x-e_x)+eps))) 224 | rel_temp_size = self.__size(d_type)/1.35 225 | rect_x, rect_y = (((s_x+e_x)/2.0)-(rel_temp_size/2.0)-3, text_place_y-10) 226 | rect_w, rect_h = (rel_temp_size+3,13) 227 | dwg.add(dwg.rect(insert=(rect_x, rect_y), rx=2,ry=2, 228 | size=(rect_w, rect_h), 229 | fill='white', stroke=color , stroke_width='1', 230 | transform = f"rotate({angle} {rect_x+rect_w/2} {rect_y+rect_h/2})")) 231 | 232 | dwg.add(dwg.text(d_type, insert=(((s_x+e_x)/2)-(rel_temp_size/2.0), text_place_y), 233 | fill=color, font_size='12', font_family='courier', 234 | transform = f"rotate({angle} {rect_x+rect_w/2} {rect_y+rect_h/2})")) 235 | 236 | def __gen_graph(self, rdf, selected_text, exclude_relations, show_relations): 237 | exclude_relations = [ i.lower().strip() for i in exclude_relations] 238 | rdf = [ i for i in rdf if i.result.lower().strip() not in exclude_relations] 239 | 240 | done_ent1 = {} 241 | done_ent2 = {} 242 | all_done = {} 243 | 244 | start_y = 75 245 | x_limit = 1000 246 | y_offset = 100 247 | #dwg = svgwrite.Drawing("temp.svg",profile='full', size = (x_limit, len(selected_text) * 1.1 + len(rdf)*20)) 248 | 249 | begin_index = 0 250 | start_x = 10 251 | this_line = 0 252 | 253 | all_entities_index = set() 254 | all_entities_1_index = [] 255 | basic_dict = {} 256 | relation_dict = {} 257 | for t in rdf: 258 | 259 | all_entities_index.add(int(t.metadata['entity1_begin'])) 260 | 
all_entities_index.add(int(t.metadata['entity2_begin'])) 261 | basic_dict[int(t.metadata['entity1_begin'])] = [t.metadata['entity1_begin'], 262 | t.metadata['entity1_end'], 263 | t.metadata['chunk1'], 264 | t.metadata['entity1']] 265 | 266 | basic_dict[int(t.metadata['entity2_begin'])] = [t.metadata['entity2_begin'], 267 | t.metadata['entity2_end'], 268 | t.metadata['chunk2'], 269 | t.metadata['entity2']] 270 | if t.metadata['entity1'].lower().strip() not in self.entity_color_dict: 271 | self.entity_color_dict[t.metadata['entity1'].lower().strip()] = self.__get_color(t.metadata['entity1'].lower().strip()) 272 | if t.metadata['entity2'].lower().strip() not in self.entity_color_dict: 273 | self.entity_color_dict[t.metadata['entity2'].lower().strip()] = self.__get_color(t.metadata['entity2'].lower().strip()) 274 | 275 | 276 | #all_entities_1_index.append(t[4]['entity1_begin']) 277 | all_entities_index = np.asarray(list(all_entities_index)) 278 | all_entities_index = all_entities_index[np.argsort(all_entities_index)] 279 | dwg_rects, dwg_texts = [], [] 280 | for ent_start_ind in all_entities_index: 281 | e_start_now, e_end_now, e_chunk_now, e_entity_now = basic_dict[ent_start_ind] 282 | prev_text = selected_text[begin_index:int(e_start_now)] 283 | prev_text = re.sub(r'\s*(\n)+', r'\1', prev_text.strip(), re.MULTILINE) 284 | begin_index = int(e_end_now)+1 285 | for line_num, line in enumerate(prev_text.split('\n')): 286 | if line_num != 0: 287 | start_y += y_offset 288 | start_x = 10 289 | this_line = 0 290 | for word_ in line.split(' '): 291 | this_size = self.__size(word_) 292 | if (start_x + this_size + 10) >= x_limit: 293 | start_y += y_offset 294 | start_x = 10 295 | this_line = 0 296 | dwg_texts.append([word_, (start_x, start_y ), '#546c74', '16', self.main_font, 'font-weight:100']) 297 | #dwg.add(dwg.text(word_, insert=(start_x, start_y ), fill='#546c77', font_size='16', 298 | # font_family='Monaco', style='font-weight:lighter')) 299 | start_x += this_size + 10 
300 | 301 | this_size = self.__size(e_chunk_now) 302 | if (start_x + this_size + 10)>= x_limit:# or this_line >= 2: 303 | start_y += y_offset 304 | start_x = 10 305 | this_line = 0 306 | 307 | #rectange chunk 1 308 | dwg_rects.append([(start_x-3, start_y-18), (this_size,25), self.entity_color_dict[e_entity_now.lower().strip()]]) 309 | #dwg.add(dwg.rect(insert=(start_x-3, start_y-18),rx=2,ry=2, size=(this_size,25), stroke=self.entity_color_dict[e_entity_now.lower()], 310 | #stroke_width='1', fill=self.entity_color_dict[e_entity_now.lower()], fill_opacity='0.2')) 311 | #chunk1 312 | dwg_texts.append([e_chunk_now, (start_x, start_y ), '#546c74', '16', self.main_font, 'font-weight:100']) 313 | #dwg.add(dwg.text(e_chunk_now, insert=(start_x, start_y ), fill='#546c77', font_size='16', 314 | # font_family='Monaco', style='font-weight:lighter')) 315 | #entity 1 316 | central_point_x = start_x+(this_size/2) 317 | temp_size = self.__size(e_entity_now)/2.75 318 | dwg_texts.append([e_entity_now.upper(), (central_point_x-temp_size, start_y+20), '#1f77b7', '12', self.main_font, 'font-weight:lighter']) 319 | #dwg.add(dwg.text(e_entity_now.upper(), 320 | # insert=(central_point_x-temp_size, start_y+20), 321 | # fill='#1f77b7', font_size='12', font_family='Monaco', 322 | # style='font-weight:lighter')) 323 | 324 | all_done[int(e_start_now)] = [central_point_x, start_y, temp_size] 325 | start_x += this_size + 20 326 | this_line += 1 327 | 328 | 329 | prev_text = selected_text[begin_index:] 330 | prev_text = re.sub(r'\s*(\n)+', r'\1', prev_text.strip(), re.MULTILINE) 331 | for line_num, line in enumerate(prev_text.split('\n')): 332 | if line_num != 0: 333 | start_y += y_offset 334 | start_x = 10 335 | for word_ in line.split(' '): 336 | this_size = self.__size(word_) 337 | if (start_x + this_size)>= x_limit: 338 | start_y += y_offset 339 | start_x = 10 340 | dwg_texts.append([word_, (start_x, start_y ), '#546c77', '16', self.main_font, 'font-weight:100']) 341 | 
#dwg.add(dwg.text(word_, insert=(start_x, start_y ), fill='#546c77', font_size='16', 342 | # font_family='Monaco', style='font-weight:lighter')) 343 | start_x += this_size + 10 344 | 345 | 346 | dwg = svgwrite.Drawing("temp.svg",profile='full', size = (x_limit, start_y+y_offset)) 347 | dwg.embed_font(self.main_font, self.font_path) 348 | 349 | for crect_ in dwg_rects: 350 | dwg.add(dwg.rect(insert=crect_[0],rx=2,ry=2, size=crect_[1], stroke=crect_[2], 351 | stroke_width='1', fill=crect_[2], fill_opacity='0.2')) 352 | 353 | for ctext_ in dwg_texts: 354 | dwg.add(dwg.text(ctext_[0], insert=ctext_[1], fill=ctext_[2], font_size=ctext_[3], 355 | font_family=ctext_[4], style=ctext_[5])) 356 | 357 | 358 | relation_distances = [] 359 | relation_coordinates = [] 360 | for row in rdf: 361 | if row.result.lower().strip() not in self.color_dict: 362 | self.color_dict[row.result.lower().strip()] = self.__get_color(row.result.lower().strip()) 363 | 364 | d_key2 = all_done[int(row.metadata['entity1_begin'])] 365 | d_key1 = all_done[int(row.metadata['entity2_begin'])] 366 | this_dist = abs(d_key2[0] - d_key1[0]) + abs (d_key2[1]-d_key1[1]) 367 | relation_distances.append(this_dist) 368 | relation_coordinates.append((d_key2, d_key1, row.result)) 369 | 370 | relation_distances = np.array(relation_distances) 371 | relation_coordinates = np.array(relation_coordinates, dtype=object) 372 | temp_ind = np.argsort(relation_distances) 373 | relation_distances = relation_distances[temp_ind] 374 | relation_coordinates = relation_coordinates[temp_ind] 375 | for row in relation_coordinates: 376 | #if int(row[0][1]) == int(row[1][1]): 377 | size_of_entity_label = int(row[1][2]) 378 | self.__draw_line(dwg, int(row[0][0]) , int(row[0][1]), int(row[1][0]), int(row[1][1]), 379 | row[2],self.color_dict[row[2].lower().strip()], show_relations, size_of_entity_label) 380 | 381 | return dwg.tostring() 382 | 383 | def display(self, result, relation_col, document_col='document', exclude_relations=['O'], 
show_relations=True, return_html=False, save_path=None): 384 | """Displays Relation Extraction visualization. 385 | Inputs: 386 | result -- A Dataframe or dictionary. 387 | relation_col -- Name of the column/key containing relationships. 388 | document_col -- Name of the column/key containing text document. 389 | exclude_relations -- list of relations that don't need to be displayed. Default: ["O"] 390 | show_relations -- Display relation types on arrows. Default: True 391 | return_html -- If true, returns raw html code instead of displaying. Default: False 392 | Output: Visualization 393 | """ 394 | 395 | original_text = result[document_col][0].result 396 | res = result[relation_col] 397 | 398 | html_content = self.__gen_graph(res, original_text, exclude_relations, show_relations) 399 | 400 | if save_path != None: 401 | with open(save_path, 'w') as f_: 402 | f_.write(html_content) 403 | 404 | if return_html: 405 | return html_content 406 | else: 407 | return display(HTML(html_content)) 408 | -------------------------------------------------------------------------------- /build/lib/sparknlp_display/retemp.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import json 4 | import pandas as pd 5 | import numpy as np 6 | import svgwrite 7 | from IPython.display import display, HTML 8 | 9 | here = os.path.abspath(os.path.dirname(__file__)) 10 | overlap_hist = [] 11 | y_hist_dict = {} 12 | x_i_diff_dict = {} 13 | x_o_diff_dict = {} 14 | class RelationExtractionVisualizer: 15 | 16 | def __init__(self): 17 | self.color_dict = { 18 | "overlap" : "lightsalmon", 19 | "before" : "deepskyblue", 20 | "after" : "springgreen", 21 | 22 | "trip": "lightsalmon", 23 | "trwp": "deepskyblue", 24 | "trcp": "springgreen", 25 | "trap": "gold", 26 | "trnap": "maroon", 27 | "terp": "purple", 28 | "tecp": "tomato", 29 | "pip" : "slategray", 30 | 31 | "drug-strength" : "purple", 32 | "drug-frequency": "slategray", 33 | 
"drug-form" : "deepskyblue", 34 | "dosage-drug" : "springgreen", 35 | "strength-drug": "maroon", 36 | "drug-dosage" : "gold" 37 | } 38 | 39 | def __get_color(self, l): 40 | r = lambda: random.randint(100,255) 41 | return '#%02X%02X%02X' % (r(), r(), r()) 42 | 43 | def __size(self, text): 44 | return ((len(text)+1)*9.7)-5 45 | 46 | def __draw_line(self, dwg, s_x , s_y, e_x, e_y, d_type, color, show_relations): 47 | if s_x > e_x: 48 | if s_x in x_o_diff_dict: 49 | x_o_diff_dict[s_x] -= 10 50 | else: 51 | x_o_diff_dict[s_x] = 10 52 | if e_x in x_i_diff_dict: 53 | x_i_diff_dict[e_x] += 10 54 | else: 55 | x_i_diff_dict[e_x] = 10 56 | s_x -= x_o_diff_dict[s_x] 57 | e_x += x_i_diff_dict[e_x] 58 | else: 59 | if s_x in x_o_diff_dict: 60 | x_o_diff_dict[s_x] += 10 61 | else: 62 | x_o_diff_dict[s_x] = 10 63 | if e_x in x_i_diff_dict: 64 | x_i_diff_dict[e_x] -= 10 65 | else: 66 | x_i_diff_dict[e_x] = 10 67 | s_x += x_o_diff_dict[s_x] 68 | e_x -= x_i_diff_dict[e_x] 69 | this_y_vals = list(range(min(s_x,e_x), max(s_x,e_x)+1)) 70 | this_y_vals = [ str(s_y)+'|'+str(i) for i in this_y_vals] 71 | common = set(this_y_vals) & set(overlap_hist) 72 | overlap_hist.extend(this_y_vals) 73 | if s_y not in y_hist_dict: 74 | y_hist_dict[s_y] = 20 75 | if common: 76 | y_hist_dict[s_y] += 20 77 | y_increase = y_hist_dict[s_y] 78 | if s_y == e_y: 79 | s_y -= 20 80 | e_y = s_y-4#55 81 | 82 | text_place_y = s_y-45 83 | 84 | line = dwg.add(dwg.polyline( 85 | [(s_x, s_y), (s_x, s_y-y_increase), (e_x, s_y-y_increase), 86 | (e_x, e_y), 87 | (e_x+2, e_y), 88 | (e_x, e_y+4), 89 | (e_x-2, e_y), 90 | (e_x, e_y) 91 | ], 92 | stroke=color, stroke_width = "2", fill='none',)) 93 | elif s_y >= e_y: 94 | e_y +=30 95 | s_y-=20 96 | text_place_y = s_y-(abs(s_y-e_y)/2) 97 | line = dwg.add(dwg.polyline( 98 | [(s_x, s_y),(s_x, s_y-y_increase), (e_x, s_y-y_increase), 99 | (e_x, e_y), 100 | (e_x+2, e_y), 101 | (e_x, e_y-4), 102 | (e_x-2, e_y), 103 | (e_x, e_y) 104 | ], 105 | stroke=color, stroke_width = "2", 
fill='none',)) 106 | else: 107 | s_y-=5 108 | e_y -= 40 109 | text_place_y = s_y+(abs(s_y-e_y)/2) 110 | line = dwg.add(dwg.polyline( 111 | [(s_x, s_y), 112 | (e_x, e_y-40), 113 | (e_x+2, e_y), 114 | (e_x, e_y+4), 115 | (e_x-2, e_y), 116 | (e_x, e_y) 117 | ], 118 | stroke=color, stroke_width = "2", fill='none',)) 119 | if show_relations: 120 | dwg.add(dwg.text(d_type, insert=(((s_x+e_x)/2)-(self.__size(d_type)/2.75), text_place_y), 121 | fill=color, font_size='12', font_family='courier')) 122 | 123 | def __gen_graph(self, rdf, selected_text, show_relations): 124 | rdf = [ i for i in rdf if i.result.lower().strip()!='o'] 125 | 126 | done_ent1 = {} 127 | done_ent2 = {} 128 | all_done = {} 129 | 130 | start_y = 75 131 | x_limit = 920 132 | y_offset = 100 133 | dwg = svgwrite.Drawing("temp.svg",profile='tiny', size = (x_limit, len(selected_text) * 1.1 + len(rdf)*20)) 134 | 135 | begin_index = 0 136 | start_x = 10 137 | this_line = 0 138 | 139 | all_entities_index = set() 140 | all_entities_1_index = [] 141 | basic_dict = {} 142 | relation_dict = {} 143 | for t in rdf: 144 | 145 | all_entities_index.add(int(t.metadata['entity1_begin'])) 146 | all_entities_index.add(int(t.metadata['entity2_begin'])) 147 | basic_dict[int(t.metadata['entity1_begin'])] = [t.metadata['entity1_begin'], 148 | t.metadata['entity1_end'], 149 | t.metadata['chunk1'], 150 | t.metadata['entity1']] 151 | 152 | basic_dict[int(t.metadata['entity2_begin'])] = [t.metadata['entity2_begin'], 153 | t.metadata['entity2_end'], 154 | t.metadata['chunk2'], 155 | t.metadata['entity2']] 156 | 157 | #all_entities_1_index.append(t[4]['entity1_begin']) 158 | all_entities_index = np.asarray(list(all_entities_index)) 159 | all_entities_index = all_entities_index[np.argsort(all_entities_index)] 160 | for ent_start_ind in all_entities_index: 161 | e_start_now, e_end_now, e_chunk_now, e_entity_now = basic_dict[ent_start_ind] 162 | prev_text = selected_text[begin_index:int(e_start_now)] 163 | begin_index = int(e_end_now)+1 
164 | for word_ in prev_text.split(' '): 165 | this_size = self.__size(word_) 166 | if (start_x + this_size + 10) >= x_limit: 167 | start_y += y_offset 168 | start_x = 10 169 | this_line = 0 170 | dwg.add(dwg.text(word_, insert=(start_x, start_y ), fill='gray', font_size='16', font_family='courier')) 171 | start_x += this_size + 5 172 | 173 | this_size = self.__size(e_chunk_now) 174 | if (start_x + this_size + 10)>= x_limit:# or this_line >= 2: 175 | start_y += y_offset 176 | start_x = 10 177 | this_line = 0 178 | #chunk1 179 | dwg.add(dwg.text(e_chunk_now, insert=(start_x, start_y ), fill='gray', font_size='16', font_family='courier')) 180 | #rectange chunk 1 181 | dwg.add(dwg.rect(insert=(start_x-3, start_y-18),rx=3,ry=3, size=(this_size,25), stroke='orange', 182 | stroke_width='2', fill='none')) 183 | #entity 1 184 | central_point_x = start_x+(this_size/2) 185 | 186 | dwg.add(dwg.text(e_entity_now, 187 | insert=(central_point_x-(self.__size(e_entity_now)/2.75), start_y+20), 188 | fill='mediumseagreen', font_size='12', font_family='courier')) 189 | 190 | all_done[int(e_start_now)] = [central_point_x, start_y] 191 | start_x += this_size + 10 192 | this_line += 1 193 | 194 | #all_done[ent_start_ind] = 195 | 196 | prev_text = selected_text[begin_index:] 197 | for word_ in prev_text.split(' '): 198 | this_size = self.__size(word_) 199 | if (start_x + this_size)>= x_limit: 200 | start_y += y_offset 201 | start_x = 10 202 | dwg.add(dwg.text(word_, insert=(start_x, start_y ), fill='gray', font_size='16', font_family='courier')) 203 | start_x += this_size 204 | 205 | relation_distances = [] 206 | relation_coordinates = [] 207 | for row in rdf: 208 | if row.result.lower().strip() not in self.color_dict: 209 | self.color_dict[row.result.lower().strip()] = self.__get_color(row.result.lower().strip()) 210 | d_key2 = all_done[int(row.metadata['entity2_begin'])] 211 | d_key1 = all_done[int(row.metadata['entity1_begin'])] 212 | this_dist = abs(d_key2[0] - d_key1[0]) + abs 
(d_key2[1]-d_key1[1]) 213 | relation_distances.append(this_dist) 214 | relation_coordinates.append((d_key2, d_key1, row.result)) 215 | 216 | relation_distances = np.array(relation_distances) 217 | relation_coordinates = np.array(relation_coordinates) 218 | temp_ind = np.argsort(relation_distances) 219 | relation_distances = relation_distances[temp_ind] 220 | relation_coordinates = relation_coordinates[temp_ind] 221 | for row in relation_coordinates: 222 | self.__draw_line(dwg, int(row[0][0]) , int(row[0][1]), int(row[1][0]), int(row[1][1]), 223 | row[2],self.color_dict[row[2].lower().strip()], show_relations) 224 | 225 | return dwg.tostring() 226 | 227 | def display(self, result, relation_col, document_col='document', show_relations=True): 228 | 229 | original_text = result[document_col][0].result 230 | res = result[relation_col] 231 | return display(HTML(self.__gen_graph(res, original_text, show_relations))) 232 | -------------------------------------------------------------------------------- /build/lib/sparknlp_display/style.css: -------------------------------------------------------------------------------- 1 | @import url('https://fonts.googleapis.com/css2?family=Montserrat:wght@300;400;500;600;700&display=swap'); 2 | 3 | .scroll.entities { 4 | border: 1px solid #E7EDF0; 5 | border-radius: 3px; 6 | text-align: justify; 7 | } 8 | 9 | .scroll.entities span { 10 | font-size: 14px; 11 | line-height: 24px; 12 | color: #536B76; 13 | font-family: Montserrat, sans-serif !important; 14 | } 15 | .entity-wrapper { 16 | border-radius: 3px; 17 | padding: 1px; 18 | margin: 0 2px 5px 2px; 19 | } 20 | 21 | .scroll.entities span .entity-type { 22 | font-weight: 500; 23 | color: #ffffff; 24 | display: block; 25 | padding: 3px 5px; 26 | } 27 | 28 | .scroll.entities span .entity-name { 29 | border-radius: 3px; 30 | padding: 2px 5px; 31 | display: block; 32 | margin: 3px 2px; 33 | } 34 | 35 | 36 | div.scroll { 37 | line-height: 24px; 38 | } 39 | 40 | 
-------------------------------------------------------------------------------- /build/lib/sparknlp_display/style_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | HTML_WRAPPER = """
{}
""" 4 | HTML_INDEX_WRAPPER = """
{}
""" 5 | 6 | STYLE_CONFIG_ENTITIES = f""" 7 | 86 | """ 87 | -------------------------------------------------------------------------------- /dist/spark-nlp-display-5.0.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-display/a6c5a6ac166c5f75ec4ccf64ffafb2fa6e1164a1/dist/spark-nlp-display-5.0.tar.gz -------------------------------------------------------------------------------- /dist/spark_nlp_display-5.0-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-display/a6c5a6ac166c5f75ec4ccf64ffafb2fa6e1164a1/dist/spark_nlp_display-5.0-py3-none-any.whl -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # Inside of setup.cfg 2 | [metadata] 3 | description-file = README.md -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | import os 3 | 4 | here = os.path.abspath(os.path.dirname(__file__)) 5 | with open(os.path.join(here, "README.md"), "r") as fh: 6 | long_description = fh.read() 7 | 8 | with open(os.path.join(here, "sparknlp_display/VERSION"), "r") as fh: 9 | app_version = fh.read().strip() 10 | 11 | setuptools.setup( 12 | name="spark-nlp-display", 13 | version=app_version, 14 | author="John Snow Labs", 15 | author_email="john@johnsnowlabs.com", 16 | description="Visualization package for Spark NLP", 17 | long_description=long_description, 18 | long_description_content_type="text/markdown", 19 | url="http://nlp.johnsnowlabs.com", 20 | packages=setuptools.find_packages(), 21 | classifiers=[ 22 | "Programming Language :: Python :: 3", 23 | "Programming Language :: Python :: 2", 24 | "License :: OSI 
Approved :: Apache Software License", 25 | "Operating System :: OS Independent", 26 | ], 27 | python_requires='>=2.7', 28 | include_package_data=True, 29 | install_requires=[ 30 | 'spark-nlp', 31 | 'ipython', 32 | 'svgwrite==1.4', 33 | 'pandas', 34 | 'numpy' 35 | ] 36 | ) -------------------------------------------------------------------------------- /spark_nlp_display.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: spark-nlp-display 3 | Version: 5.0 4 | Summary: Visualization package for Spark NLP 5 | Home-page: http://nlp.johnsnowlabs.com 6 | Author: John Snow Labs 7 | Author-email: john@johnsnowlabs.com 8 | Classifier: Programming Language :: Python :: 3 9 | Classifier: Programming Language :: Python :: 2 10 | Classifier: License :: OSI Approved :: Apache Software License 11 | Classifier: Operating System :: OS Independent 12 | Requires-Python: >=2.7 13 | Description-Content-Type: text/markdown 14 | License-File: LICENSE 15 | 16 | # spark-nlp-display 17 | A library for the simple visualization of different types of Spark NLP annotations. 
18 | 19 | ## Supported Visualizations: 20 | - Dependency Parser 21 | - Named Entity Recognition 22 | - Entity Resolution 23 | - Relation Extraction 24 | - Assertion Status 25 | 26 | ## Complete Tutorial 27 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-display/blob/main/tutorials/Spark_NLP_Display.ipynb) 28 | 29 | https://github.com/JohnSnowLabs/spark-nlp-display/blob/main/tutorials/Spark_NLP_Display.ipynb 30 | 31 | ### Requirements 32 | - spark-nlp 33 | - ipython 34 | - svgwrite 35 | - pandas 36 | - numpy 37 | 38 | ### Installation 39 | ```bash 40 | pip install spark-nlp-display 41 | ``` 42 | 43 | ### How to use 44 | 45 | ### Databricks 46 | #### For all modules, pass in the additional parameter "return_html=True" in the display function and use Databrick's function displayHTML() to render visualization as explained below: 47 | ```python 48 | from sparknlp_display import NerVisualizer 49 | 50 | ner_vis = NerVisualizer() 51 | 52 | ## To set custom label colors: 53 | ner_vis.set_label_colors({'LOC':'#800080', 'PER':'#77b5fe'}) #set label colors by specifying hex codes 54 | 55 | pipeline_result = ner_light_pipeline.fullAnnotate(text) ##light pipeline 56 | #pipeline_result = ner_full_pipeline.transform(df).collect()##full pipeline 57 | 58 | vis_html = ner_vis.display(pipeline_result[0], #should be the results of a single example, not the complete dataframe 59 | label_col='entities', #specify the entity column 60 | document_col='document', #specify the document column (default: 'document') 61 | labels=['PER'], #only allow these labels to be displayed. 
(default: [] - all labels will be displayed) 62 | return_html=True) 63 | 64 | displayHTML(vis_html) 65 | ``` 66 | ![title](https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-display/main/assets/ner_viz.png) 67 | 68 | ### Jupyter 69 | 70 | To save the visualization as html, provide the export file path: `save_path='./export.html'` for each visualizer. 71 | 72 | 73 | #### Dependency Parser 74 | ```python 75 | from sparknlp_display import DependencyParserVisualizer 76 | 77 | dependency_vis = DependencyParserVisualizer() 78 | 79 | pipeline_result = dp_pipeline.fullAnnotate(text) 80 | #pipeline_result = dp_full_pipeline.transform(df).collect()##full pipeline 81 | 82 | dependency_vis.display(pipeline_result[0], #should be the results of a single example, not the complete dataframe. 83 | pos_col = 'pos', #specify the pos column 84 | dependency_col = 'dependency', #specify the dependency column 85 | dependency_type_col = 'dependency_type', #specify the dependency type column 86 | save_path='./export.html' # optional - to save viz as html. (default: None) 87 | ) 88 | ``` 89 | 90 | ![title](https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-display/main/assets/dp_viz.png) 91 | 92 | #### Named Entity Recognition 93 | 94 | ```python 95 | from sparknlp_display import NerVisualizer 96 | 97 | ner_vis = NerVisualizer() 98 | 99 | pipeline_result = ner_light_pipeline.fullAnnotate(text) 100 | #pipeline_result = ner_full_pipeline.transform(df).collect()##full pipeline 101 | 102 | ner_vis.display(pipeline_result[0], #should be the results of a single example, not the complete dataframe 103 | label_col='entities', #specify the entity column 104 | document_col='document', #specify the document column (default: 'document') 105 | labels=['PER'], #only allow these labels to be displayed. (default: [] - all labels will be displayed) 106 | save_path='./export.html' # optional - to save viz as html. 
(default: None) 107 | ) 108 | 109 | ## To set custom label colors: 110 | ner_vis.set_label_colors({'LOC':'#800080', 'PER':'#77b5fe'}) #set label colors by specifying hex codes 111 | 112 | ``` 113 | 114 | ![title](https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-display/main/assets/ner_viz.png) 115 | 116 | #### Entity Resolution 117 | 118 | ```python 119 | from sparknlp_display import EntityResolverVisualizer 120 | 121 | er_vis = EntityResolverVisualizer() 122 | 123 | pipeline_result = er_light_pipeline.fullAnnotate(text) 124 | 125 | er_vis.display(pipeline_result[0], #should be the results of a single example, not the complete dataframe 126 | label_col='entities', #specify the ner result column 127 | resolution_col = 'resolution', 128 | document_col='document', #specify the document column (default: 'document') 129 | save_path='./export.html' # optional - to save viz as html. (default: None) 130 | ) 131 | 132 | ## To set custom label colors: 133 | er_vis.set_label_colors({'TREATMENT':'#800080', 'PROBLEM':'#77b5fe'}) #set label colors by specifying hex codes 134 | 135 | ``` 136 | 137 | ![title](https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-display/main/assets/er_viz.png) 138 | 139 | #### Relation Extraction 140 | ```python 141 | from sparknlp_display import RelationExtractionVisualizer 142 | 143 | re_vis = RelationExtractionVisualizer() 144 | 145 | pipeline_result = re_light_pipeline.fullAnnotate(text) 146 | 147 | re_vis.display(pipeline_result[0], #should be the results of a single example, not the complete dataframe 148 | relation_col = 'relations', #specify relations column 149 | document_col = 'document', #specify document column 150 | show_relations=True, #display relation names on arrows (default: True) 151 | save_path='./export.html' # optional - to save viz as html. 
(default: None) 152 | ) 153 | 154 | ``` 155 | 156 | ![title](https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-display/main/assets/re_viz.png) 157 | 158 | #### Assertion Status 159 | ```python 160 | from sparknlp_display import AssertionVisualizer 161 | 162 | assertion_vis = AssertionVisualizer() 163 | 164 | pipeline_result = ner_assertion_light_pipeline.fullAnnotate(text) 165 | 166 | assertion_vis.display(pipeline_result[0], 167 | label_col = 'entities', #specify the ner result column 168 | assertion_col = 'assertion', #specify assertion column 169 | document_col = 'document', #specify the document column (default: 'document') 170 | save_path='./export.html' # optional - to save viz as html. (default: None) 171 | ) 172 | 173 | ## To set custom label colors: 174 | assertion_vis.set_label_colors({'TREATMENT':'#008080', 'problem':'#800080'}) #set label colors by specifying hex codes 175 | 176 | ``` 177 | 178 | ![title](https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-display/main/assets/assertion_viz.png) 179 | -------------------------------------------------------------------------------- /spark_nlp_display.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | LICENSE 2 | MANIFEST.in 3 | README.md 4 | setup.cfg 5 | setup.py 6 | spark_nlp_display.egg-info/PKG-INFO 7 | spark_nlp_display.egg-info/SOURCES.txt 8 | spark_nlp_display.egg-info/dependency_links.txt 9 | spark_nlp_display.egg-info/requires.txt 10 | spark_nlp_display.egg-info/top_level.txt 11 | sparknlp_display/VERSION 12 | sparknlp_display/__init__.py 13 | sparknlp_display/assertion.py 14 | sparknlp_display/dep_updates.py 15 | sparknlp_display/dependency_parser.py 16 | sparknlp_display/entity_resolution.py 17 | sparknlp_display/ner.py 18 | sparknlp_display/re_updates.py 19 | sparknlp_display/relation_extraction.py 20 | sparknlp_display/retemp.py 21 | sparknlp_display/style.css 22 | sparknlp_display/style_utils.py 23 | 
import os
from sparknlp_display.ner import NerVisualizer
from sparknlp_display.dependency_parser import DependencyParserVisualizer
from sparknlp_display.relation_extraction import RelationExtractionVisualizer
from sparknlp_display.entity_resolution import EntityResolverVisualizer
from sparknlp_display.assertion import AssertionVisualizer

# Package directory; the VERSION data file lives alongside this module.
here = os.path.abspath(os.path.dirname(__file__))

def get_version():
    """Return the package version string read from the bundled VERSION file."""
    # (Fixed: previous revision computed an unused `version_path` local that
    # duplicated the module-level `here`.)
    with open(os.path.join(here, "VERSION"), "r") as fh:
        return fh.read().strip()

__version__ = get_version()

def version():
    """Alias for get_version(), kept for backward compatibility."""
    return get_version()
-------------------------------------------------------------------------------- /sparknlp_display/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-display/a6c5a6ac166c5f75ec4ccf64ffafb2fa6e1164a1/sparknlp_display/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /sparknlp_display/__pycache__/ner_output.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-display/a6c5a6ac166c5f75ec4ccf64ffafb2fa6e1164a1/sparknlp_display/__pycache__/ner_output.cpython-36.pyc -------------------------------------------------------------------------------- /sparknlp_display/__pycache__/style_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-display/a6c5a6ac166c5f75ec4ccf64ffafb2fa6e1164a1/sparknlp_display/__pycache__/style_utils.cpython-36.pyc -------------------------------------------------------------------------------- /sparknlp_display/assertion.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import json 4 | import numpy as np 5 | from . 
import style_utils as style_config 6 | from IPython.display import display, HTML 7 | 8 | here = os.path.abspath(os.path.dirname(__file__)) 9 | 10 | class AssertionVisualizer: 11 | def __init__(self): 12 | with open(os.path.join(here, 'label_colors/ner.json'), 'r', encoding='utf-8') as f_: 13 | self.label_colors = json.load(f_) 14 | 15 | #public function to get color for a label 16 | def get_label_color(self, label): 17 | """Returns color of a particular label 18 | 19 | Input: entity label 20 | Output: Color or if not found 21 | """ 22 | 23 | if str(label).lower() in self.label_colors: 24 | return self.label_colors[label.lower()] 25 | else: 26 | return None 27 | 28 | # private function for colors for display 29 | def __get_label(self, label): 30 | """Set label colors. 31 | 32 | Input: dictionary of entity labels and corresponding colors 33 | Output: self object - to allow chaining 34 | Note: Previous values of colors will be overwritten 35 | """ 36 | if str(label).lower() in self.label_colors: 37 | return self.label_colors[label.lower()] 38 | else: 39 | #update it to fetch from git new labels 40 | r = lambda: random.randint(0,200) 41 | return '#%02X%02X%02X' % (r(), r(), r()) 42 | 43 | def set_label_colors(self, color_dict): 44 | """Sets label colors. 45 | 46 | input: dictionary of entity labels and corresponding colors 47 | output: self object - to allow chaining 48 | note: Previous values of colors will be overwritten 49 | """ 50 | 51 | for key, value in color_dict.items(): 52 | self.label_colors[key.lower()] = value 53 | return self 54 | 55 | def __verify_structure(self, result, label_col, document_col, original_text): 56 | 57 | if original_text is None: 58 | basic_msg_1 = """Incorrect annotation structure of '{}'.""".format(document_col) 59 | if not hasattr(result[document_col][0], 'result'): 60 | raise AttributeError(basic_msg_1+""" 'result' attribute not found in the annotation. 
61 | Make sure '"""+document_col+"""' is a list of objects having the following structure: 62 | Annotation(type='annotation', begin=0, end=10, result='This is a text') 63 | Or You can pass the text manually using 'raw_text' argument.""") 64 | 65 | basic_msg_1 = """Incorrect annotation structure of '{}'.""".format(label_col) 66 | basic_msg = """ 67 | In sparknlp please use 'LightPipeline.fullAnnotate' for LightPipeline or 'Pipeline.transform' for PipelineModel. 68 | Or 69 | Make sure '"""+label_col+"""' is a list of objects having the following structure: 70 | Annotation(type='annotation', begin=0, end=0, result='Adam', metadata={'entity': 'PERSON'})""" 71 | 72 | for entity in result[label_col]: 73 | if not hasattr(entity, 'begin'): 74 | raise AttributeError( basic_msg_1 + """ 'begin' attribute not found in the annotation."""+basic_msg) 75 | if not hasattr(entity, 'end'): 76 | raise AttributeError(basic_msg_1 + """ 'end' attribute not found in the annotation."""+basic_msg) 77 | if not hasattr(entity, 'result'): 78 | raise AttributeError(basic_msg_1 + """ 'result' attribute not found in the annotation."""+basic_msg) 79 | if not hasattr(entity, 'metadata'): 80 | raise AttributeError(basic_msg_1 + """ 'metadata' attribute not found in the annotation."""+basic_msg) 81 | if 'entity' not in entity.metadata: 82 | raise AttributeError(basic_msg_1+""" KeyError: 'entity' not found in metadata."""+basic_msg) 83 | 84 | def __verify_input(self, result, label_col, document_col, original_text): 85 | # check if label colum in result 86 | if label_col not in result: 87 | raise AttributeError("""column/key '{}' not found in the provided dataframe/dictionary. 88 | Please specify the correct key/column using 'label_col' argument.""".format(label_col)) 89 | 90 | if original_text is not None: 91 | # check if provided text is correct data type 92 | if not isinstance(original_text, str): 93 | raise ValueError("Invalid value for argument 'raw_text' input. 
Text should be of type 'str'.") 94 | 95 | else: 96 | # check if document column in result 97 | if document_col not in result: 98 | raise AttributeError("""column/key '{}' not found in the provided dataframe/dictionary. 99 | Please specify the correct key/column using 'document_col' argument. 100 | Or You can pass the text manually using 'raw_text' argument""".format(document_col)) 101 | 102 | self.__verify_structure( result, label_col, document_col, original_text) 103 | 104 | # main display function 105 | def __display_ner(self, result, label_col, resolution_col, document_col, original_text, labels_list = None): 106 | 107 | if original_text is None: 108 | original_text = result[document_col][0].result 109 | 110 | if labels_list is not None: 111 | labels_list = [v.lower() for v in labels_list] 112 | 113 | assertion_temp_dict = {} 114 | for resol in result[resolution_col]: 115 | assertion_temp_dict[int(resol.begin)] = resol.result 116 | 117 | label_color = {} 118 | html_output = "" 119 | pos = 0 120 | 121 | sorted_labs_idx = np.argsort([ int(i.begin) for i in result[label_col]]) 122 | sorted_labs = [ result[label_col][i] for i in sorted_labs_idx] 123 | 124 | for entity in sorted_labs: 125 | entity_type = entity.metadata['entity'].lower() 126 | if (entity_type not in label_color) and ((labels_list is None) or (entity_type in labels_list)) : 127 | label_color[entity_type] = self.__get_label(entity_type) 128 | 129 | begin = int(entity.begin) 130 | end = int(entity.end) 131 | if pos < begin and pos < len(original_text): 132 | white_text = original_text[pos:begin] 133 | html_output += '{}'.format(white_text) 134 | pos = end+1 135 | 136 | if entity_type in label_color: 137 | 138 | if begin in assertion_temp_dict: 139 | 140 | html_output += '{} {}{} '.format( 141 | label_color[entity_type] + 'B3', #color 142 | original_text[begin:end+1],#entity.result, 143 | entity.metadata['entity'], #entity - label 144 | label_color[entity_type] + 'FF', #color '#D2C8C6' 145 | 
assertion_temp_dict[begin] # res_assertion 146 | ) 147 | else: 148 | html_output += '{} {}'.format( 149 | label_color[entity_type] + 'B3', #color 150 | original_text[begin:end+1],#entity.result, 151 | entity.metadata['entity'] #entity - label 152 | ) 153 | 154 | else: 155 | html_output += '{}'.format(original_text[begin:end+1]) 156 | 157 | if pos < len(original_text): 158 | html_output += '{}'.format(original_text[pos:]) 159 | 160 | html_output += """""" 161 | 162 | html_output = html_output.replace("\n", "
") 163 | 164 | return html_output 165 | 166 | def display(self, result, label_col, assertion_col, document_col='document', raw_text=None, return_html=False, save_path=None): 167 | """Displays Assertion visualization. 168 | 169 | Inputs: 170 | result -- A Dataframe or dictionary. 171 | label_col -- Name of the column/key containing NER annotations. 172 | document_col -- Name of the column/key containing text document. 173 | original_text -- Original text of type 'str'. If specified, it will take precedence over 'document_col' and will be used as the reference text for display. 174 | labels_list -- A list of labels that should be highlighted in the output. Default: Display all labels. 175 | 176 | Output: Visualization 177 | """ 178 | 179 | #self.__verifyInput(result, label_col, document_col, raw_text) 180 | 181 | html_content = self.__display_ner(result, label_col, assertion_col, document_col, raw_text) 182 | html_content_save = style_config.STYLE_CONFIG_ENTITIES+ " "+html_content 183 | 184 | if save_path != None: 185 | with open(save_path, 'w') as f_: 186 | f_.write(html_content_save) 187 | 188 | if return_html: 189 | return html_content_save 190 | else: 191 | return display(HTML(html_content_save)) 192 | -------------------------------------------------------------------------------- /sparknlp_display/dep_updates.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import json 4 | import pandas as pd 5 | import numpy as np 6 | import svgwrite 7 | from IPython.display import display, HTML 8 | 9 | here = os.path.abspath(os.path.dirname(__file__)) 10 | 11 | class DependencyParserVisualizer: 12 | 13 | def __get_color(self, l): 14 | r = lambda: random.randint(100,255) 15 | return '#%02X%02X%02X' % (r(), r(), r()) 16 | 17 | def __size(self, text): 18 | return ((len(text)+1)*15)-5 19 | 20 | def __draw_line(self, dwg, s_x , s_y, e_x, e_y, d_type, color): 21 | rx=ry=1 22 | 
dwg.add(dwg.path(d=f"M{min(s_x,e_x)},{s_y} A {rx}, {ry}, 0 1 1 {max(s_x,e_x)}, {s_y}", 23 | stroke=color, stroke_width = "3", fill='none')) 24 | dwg.add(dwg.polyline( 25 | [(e_x, s_y), (e_x+3, s_y), 26 | (e_x, s_y+5), 27 | (e_x-3, s_y), 28 | (e_x, s_y) 29 | ], stroke=color, stroke_width = "4", fill='none',)) 30 | ''' 31 | line = dwg.add(dwg.polyline( 32 | [(s_x, s_y+4), 33 | (s_x, e_y), 34 | (e_x, e_y), 35 | (e_x, s_y), 36 | (e_x+2, s_y), 37 | (e_x, s_y+4), 38 | (e_x-2, s_y), 39 | (e_x, s_y) 40 | ], 41 | stroke=color, stroke_width = "2", fill='none',)) 42 | ''' 43 | dwg.add(dwg.text(d_type, insert=(((s_x+e_x)/2)-(self.__size(d_type.strip())/2.75), e_y-4), 44 | fill=color, font_size='20', font_family='courier')) 45 | 46 | def __generate_graph(self, result_df): 47 | # returns an svg graph 48 | 49 | colors_dict = {} 50 | max_x = 50 51 | max_y = 100 52 | 53 | for i in result_df['dependency_type'].unique(): 54 | colors_dict[i] = self.__get_color(i) 55 | 56 | for i in result_df['pos'].unique(): 57 | colors_dict[i] = self.__get_color(i) 58 | 59 | for i, row in result_df.iterrows(): 60 | txt = row['chunk'].strip() 61 | max_x += (self.__size(txt) + 50) 62 | max_y += 30 63 | 64 | max_x += 50 65 | start_x = 50 66 | starty_y = max_y 67 | dp_dict={} 68 | tk_dict = {} 69 | dist_dict = {} 70 | main_text = [] 71 | main_pos = [] 72 | 73 | for i, row in result_df.iterrows(): 74 | txt = row['chunk'].strip() 75 | dt = row['dependency_type'].lower().strip() 76 | is_root = False 77 | if dt == 'root': 78 | is_root = True 79 | main_text.append((txt, start_x, starty_y, is_root)) 80 | main_pos.append((row['pos'].strip(), (start_x + int((self.__size(txt)/2) - int(self.__size(row['pos'])/2))), starty_y+30)) 81 | 82 | tk_dict[str(row['begin'])+str(row['end'])] = (start_x+int(self.__size(txt)/2), starty_y) 83 | start_x += (self.__size(txt) + 50) 84 | 85 | y_offset = starty_y-100 86 | dist_dict = {} 87 | e_dist_dict = {} 88 | direct_dict = {} 89 | left_side_dict = {} 90 | right_side_dict = {} 91 
| y_hist = {} 92 | root_list = [] 93 | main_lines = [] 94 | lines_dist = [] 95 | 96 | dist = [] 97 | for i, row in result_df.iterrows(): 98 | if row['dependency_type'].lower().strip() != 'root': 99 | lines_dist.append(abs(int(row['begin']) - int(row['dependency_start']['head.begin']))) 100 | else: 101 | lines_dist.append(0) 102 | 103 | result_df = result_df.iloc[np.argsort(lines_dist)] 104 | 105 | count_left = {} 106 | count_right = {} 107 | t_x_offset = {} 108 | for i, row in result_df.iterrows(): 109 | if row['dependency_type'].lower().strip() != 'root': 110 | sp = str(row['dependency_start']['head.begin'])+str(row['dependency_start']['head.end']) 111 | x_e, y_e = tk_dict[str(row['begin'])+str(row['end'])] 112 | x, y = tk_dict[sp] 113 | if int(row['begin']) < int(row['dependency_start']['head.begin']): 114 | if x in count_left: 115 | count_left[x] += 1 116 | t_x_offset[x] += 7 117 | else: 118 | count_left[x] = 1 119 | t_x_offset[x] = 7 120 | if x_e in count_right: 121 | count_right[x_e] += 1 122 | t_x_offset[x_e] -= 7 123 | else: 124 | count_right[x_e] = 0 125 | t_x_offset[x_e] = 0 126 | else: 127 | if x in count_right: 128 | count_right[x] += 1 129 | t_x_offset[x] -= 7 130 | else: 131 | count_right[x] = 0 132 | t_x_offset[x] = 0 133 | if x_e in count_left: 134 | count_left[x_e] += 1 135 | t_x_offset[x_e] += 7 136 | else: 137 | count_left[x_e] = 1 138 | t_x_offset[x_e] = 7 139 | 140 | for i, row in result_df.iterrows(): 141 | 142 | sp = str(row['dependency_start']['head.begin'])+str(row['dependency_start']['head.end']) 143 | ep = tk_dict[str(row['begin'])+str(row['end'])] 144 | 145 | if sp != '-1-1': 146 | x, y = tk_dict[sp] 147 | 148 | if int(row['begin']) > int(row['dependency_start']['head.begin']): 149 | dist_dict[x] = count_right[x] * 7 150 | count_right[x] -= 1 151 | e_dist_dict[ep[0]] = count_left[ep[0]] * -7 152 | count_left[ep[0]] -= 1 153 | else: 154 | dist_dict[x] = count_left[x] * -7 155 | count_left[x] -= 1 156 | e_dist_dict[ep[0]] = 
count_right[ep[0]] * 7 157 | count_right[ep[0]] -= 1 158 | #row['dependency'], x, t_x_offset[x], x+dist_dict[x], x+dist_dict[x]+t_x_offset[x] 159 | final_x_s = int(x+dist_dict[x]+(t_x_offset[x]/2)) 160 | final_x_e = int(ep[0]+ e_dist_dict[ep[0]]+(t_x_offset[ep[0]]/2)) 161 | 162 | x_inds = range(min(final_x_s, final_x_e), max(final_x_s, final_x_e)+1) 163 | common = set(y_hist.keys()).intersection(set(x_inds)) 164 | 165 | if common: 166 | y_fset = min([y_hist[c] for c in common]) 167 | y_fset -= 50 168 | y_hist.update(dict(zip(x_inds, [y_fset]*len(x_inds)))) 169 | 170 | else: 171 | y_hist.update(dict(zip(x_inds, [y_offset]*len(x_inds)))) 172 | 173 | main_lines.append((None, final_x_s, y-30, final_x_e, y_hist[final_x_s], row['dependency_type'])) 174 | 175 | else: 176 | x_x , y_y = tk_dict[str(row['begin'])+str(row['end'])] 177 | 178 | root_list.append((row['dependency_type'].upper(), x_x, y_y)) 179 | 180 | 181 | current_y = min(y_hist.values()) 182 | 183 | y_ff = (max_y - current_y) + 50 184 | y_f = (current_y - 50) 185 | current_y = 50 186 | 187 | dwg = svgwrite.Drawing("temp.svg", 188 | profile='tiny', size = (max_x, y_ff+100)) 189 | 190 | for mt, mp in zip(main_text, main_pos): 191 | dwg.add(dwg.text(mt[0], insert=(mt[1], mt[2]-y_f), fill='gray', 192 | font_size='25', font_family='courier')) 193 | 194 | if mt[3]: 195 | dwg.add(dwg.rect(insert=(mt[1]-5, mt[2]-y_f-25), size=(self.__size(mt[0]),35), stroke='orange', 196 | stroke_width='2', fill='none')) 197 | 198 | dwg.add(dwg.text(mp[0], insert=(mp[1], mp[2]-y_f), fill=colors_dict[mp[0]])) 199 | 200 | for ml in main_lines: 201 | self.__draw_line(dwg, ml[1], ml[2]-y_f, ml[3], ml[4]-y_f, ml[5], colors_dict[ml[5]]) 202 | 203 | return dwg.tostring() 204 | 205 | 206 | def display(self, res, pos_col, dependency_col, dependency_type_col): 207 | """Displays NER visualization. 208 | 209 | Inputs: 210 | result -- A Dataframe or dictionary. 211 | label_col -- Name of the column/key containing NER annotations. 
212 | document_col -- Name of the column/key containing text document. 213 | original_text -- Original text of type 'str'. If specified, it will take precedence over 'document_col' and will be used as the reference text for display. 214 | labels_list -- A list of labels that should be highlighted in the output. Default: Display all labels. 215 | 216 | Output: Visualization 217 | """ 218 | 219 | pos_res = [] 220 | for i in res[pos_col]: 221 | t_ = {'chunk': i.metadata['word'], 222 | 'begin': i.begin, 223 | 'end' : i.end, 224 | 'pos' : i.result} 225 | pos_res.append(t_) 226 | dep_res = [] 227 | dep_res_meta = [] 228 | for i in res[dependency_col]: 229 | dep_res.append(i.result) 230 | dep_res_meta.append(i.metadata) 231 | df = pd.DataFrame(pos_res) 232 | df['dependency'] = dep_res 233 | df['dependency_start'] = dep_res_meta 234 | 235 | dept_res = [] 236 | for i in res[dependency_type_col]: 237 | dept_res.append(i.result) 238 | df['dependency_type'] = dept_res 239 | 240 | return display(HTML(self.__generate_graph(df))) 241 | 242 | -------------------------------------------------------------------------------- /sparknlp_display/dependency_parser.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import json 4 | import pandas as pd 5 | import numpy as np 6 | import svgwrite 7 | from . 
import style_utils as style_config
from IPython.display import display, HTML

here = os.path.abspath(os.path.dirname(__file__))

class DependencyParserVisualizer:
    """Renders POS tags and dependency arcs for one sentence as an SVG graph."""

    def __init__(self):
        # Font embedded into the SVG so the fixed-width layout math stays valid.
        self.font_path = os.path.join(here, 'fonts/Lucida_Console.ttf')
        self.main_font = 'Lucida'

    def __get_color(self, l):
        """Return a random dark-ish hex color (channels capped at 200 for contrast on white)."""
        # NOTE(review): the `l` (label) argument is unused — the color is purely random.
        r = lambda: random.randint(0,200)
        return '#%02X%02X%02X' % (r(), r(), r())

    def __size(self, text):
        """Approximate pixel width of `text` in the fixed-width display font."""
        return ((len(text)+1)*12)

    def __draw_line(self, dwg, s_x , s_y, e_x, e_y, d_type, color):
        """Draw one dependency arc from (s_x, s_y) down to (e_x, e_y), labelled `d_type`."""
        # Small diamond arrow-head at the arc's end point.
        line = dwg.add(dwg.polyline(
            [
                (e_x, s_y),
                (e_x+2, s_y),
                (e_x, s_y+4),
                (e_x-2, s_y),
                (e_x, s_y)
            ],
            stroke='black', stroke_width = "2", fill='none',))

        #if e_x > s_x:
        rad=10
        height=abs(e_y-s_y-4)-rad
        sx = s_x  # NOTE(review): unused local
        sy=300    # NOTE(review): unused local
        distance=abs(e_x-s_x)-rad*2
        # Rounded rectangular path (up, rounded corner, across, rounded corner, down).
        # The arc sweep flag differs depending on whether the head is right or left
        # of the dependent token.
        if e_x > s_x:
            dwg.add(dwg.path(d=f"M{s_x},{s_y+4} v-{height} a{rad},{rad} 0 0 1 {rad},-{rad} h{distance} a{rad},{rad} 0 0 1 {rad},{rad} v{height-4}",
                             fill="none",
                             stroke="black", stroke_width=1
                             ))
        else:
            dwg.add(dwg.path(d=f"M{s_x},{s_y+4} v-{height} a{rad},{rad} 0 0 0 -{rad},-{rad} h-{distance} a{rad},{rad} 0 0 0 -{rad} {rad} v{height-4}",
                             fill="none",
                             stroke="black", stroke_width=1
                             ))



        # Dependency-type label, roughly centered between the two arc ends.
        dwg.add(dwg.text(d_type, insert=(((s_x+e_x)/2)-(self.__size(d_type.strip())/3.0), e_y-4),
                         fill=color, font_size='14', font_family=self.main_font))

    def __generate_graph(self, result_df):
        # returns an svg graph
        """Lay out tokens + POS tags on a baseline, stack non-crossing arcs above them,
        and return the finished SVG as a string."""

        colors_dict = {}
        max_x = 50
        max_y = 100

        # One random color per dependency type and per POS tag.
        for i in result_df['dependency_type'].unique():
            colors_dict[i] = self.__get_color(i)

        for i in result_df['pos'].unique():
            colors_dict[i] = self.__get_color(i)

        # Canvas width grows with the total token text; height with the token count.
        for i, row in result_df.iterrows():
            txt = row['chunk'].strip()
            max_x += (self.__size(txt) + 50)
            max_y += 30

        max_x += 50
        start_x = 50
        starty_y = max_y
        dp_dict={}  # NOTE(review): never used
        tk_dict = {}
        dist_dict = {}
        main_text = []
        main_pos = []

        # First pass: place each token on the baseline and record its center point
        # keyed by "<begin><end>" so arcs can find their anchors later.
        for i, row in result_df.iterrows():
            txt = row['chunk'].strip()
            dt = row['dependency'].lower().strip()
            is_root = False
            if dt == 'root':
                is_root = True
            main_text.append((txt, start_x, starty_y, is_root))
            main_pos.append(
                (row['pos'].strip(),
                 (start_x + int((self.__size(txt)/2) - int(self.__size(row['pos'])/2.25))),
                 starty_y+30))

            tk_dict[str(row['begin'])+str(row['end'])] = (start_x+int(self.__size(txt)/2), starty_y)
            start_x += (self.__size(txt) + 50)

        y_offset = starty_y-100
        dist_dict = {}
        e_dist_dict = {}
        direct_dict = {}      # NOTE(review): never used
        left_side_dict = {}   # NOTE(review): never used
        right_side_dict = {}  # NOTE(review): never used
        y_hist = {}           # x position -> lowest y already taken by an arc
        root_list = []
        main_lines = []
        lines_dist = []

        dist = []  # NOTE(review): never used
        # Sort rows by head-dependent distance so short arcs are drawn lowest.
        for i, row in result_df.iterrows():
            if row['dependency'].lower().strip() != 'root':
                lines_dist.append(abs(int(row['begin']) - int(row['dependency_start']['head.begin'])))
            else:
                lines_dist.append(0)

        result_df = result_df.iloc[np.argsort(lines_dist)]

        # Second pass: count how many arcs leave/enter each anchor on each side,
        # so multiple arcs at one token can be fanned out by 7px steps.
        count_left = {}
        count_right = {}
        t_x_offset = {}
        for i, row in result_df.iterrows():
            if row['dependency'].lower().strip() != 'root':
                sp = str(row['dependency_start']['head.begin'])+str(row['dependency_start']['head.end'])
                x_e, y_e = tk_dict[str(row['begin'])+str(row['end'])]
                x, y = tk_dict[sp]
                if int(row['begin']) < int(row['dependency_start']['head.begin']):
                    if x in count_left:
                        count_left[x] += 1
                        t_x_offset[x] += 7
                    else:
                        count_left[x] = 1
                        t_x_offset[x] = 7
                    if x_e in count_right:
                        count_right[x_e] += 1
                        t_x_offset[x_e] -= 7
                    else:
                        count_right[x_e] = 0
                        t_x_offset[x_e] = 0
                else:
                    if x in count_right:
                        count_right[x] += 1
                        t_x_offset[x] -= 7
                    else:
                        count_right[x] = 0
                        t_x_offset[x] = 0
                    if x_e in count_left:
                        count_left[x_e] += 1
                        t_x_offset[x_e] += 7
                    else:
                        count_left[x_e] = 1
                        t_x_offset[x_e] = 7

        # Third pass: compute each arc's final endpoints and stack it above any
        # arc it would overlap horizontally (y_hist tracks occupied heights).
        for i, row in result_df.iterrows():

            sp = str(row['dependency_start']['head.begin'])+str(row['dependency_start']['head.end'])
            ep = tk_dict[str(row['begin'])+str(row['end'])]

            if sp != '-1-1':  # '-1-1' marks the sentence root (no head)
                x, y = tk_dict[sp]

                if int(row['begin']) > int(row['dependency_start']['head.begin']):
                    dist_dict[x] = count_right[x] * 7
                    count_right[x] -= 1
                    e_dist_dict[ep[0]] = count_left[ep[0]] * -7
                    count_left[ep[0]] -= 1
                else:
                    dist_dict[x] = count_left[x] * -7
                    count_left[x] -= 1
                    e_dist_dict[ep[0]] = count_right[ep[0]] * 7
                    count_right[ep[0]] -= 1
                #row['dependency'], x, t_x_offset[x], x+dist_dict[x], x+dist_dict[x]+t_x_offset[x]
                final_x_s = int(x+dist_dict[x]+(t_x_offset[x]/2))
                final_x_e = int(ep[0]+ e_dist_dict[ep[0]]+(t_x_offset[ep[0]]/2))

                x_inds = range(min(final_x_s, final_x_e), max(final_x_s, final_x_e)+1)
                common = set(y_hist.keys()).intersection(set(x_inds))

                if common:
                    # Overlap: place this arc 50px above the lowest conflicting arc.
                    y_fset = min([y_hist[c] for c in common])
                    y_fset -= 50
                    y_hist.update(dict(zip(x_inds, [y_fset]*len(x_inds))))

                else:
                    y_hist.update(dict(zip(x_inds, [y_offset]*len(x_inds))))

                main_lines.append((None, final_x_s, y-30, final_x_e, y_hist[final_x_s], row['dependency_type']))

            else:
                x_x , y_y = tk_dict[str(row['begin'])+str(row['end'])]

                root_list.append((row['dependency_type'].upper(), x_x, y_y))


        # Shift everything so the highest arc still fits inside the canvas.
        current_y = min(y_hist.values())

        y_ff = (max_y - current_y) + 50
        y_f = (current_y - 50)
        current_y = 50

        dwg = svgwrite.Drawing("temp.svg",
                               profile='full', size = (max_x, y_ff+100))
        dwg.embed_font(self.main_font, self.font_path)

        for mt, mp in zip(main_text, main_pos):
            dwg.add(dwg.text(mt[0], insert=(mt[1], mt[2]-y_f), fill='gray',
                             font_size='20', font_family=self.main_font))

            if mt[3]:
                # The root token gets a rounded purple box.
                dwg.add(dwg.rect(insert=(mt[1]-5, mt[2]-y_f-25), rx=5,ry=5, size=(self.__size(mt[0]),35), stroke='#800080',
                                 stroke_width='1', fill='none'))

            dwg.add(dwg.text(mp[0], insert=(mp[1], mp[2]-y_f), font_size='14', fill=colors_dict[mp[0]]))

        for ml in main_lines:
            self.__draw_line(dwg, ml[1], ml[2]-y_f, ml[3], ml[4]-y_f, ml[5], colors_dict[ml[5]])

        return dwg.tostring()


    def display(self, res, pos_col, dependency_col, dependency_type_col=None, return_html=False, save_path=None):
        """Displays dependency-parse visualization.

        Inputs:
        res -- A Dataframe or dictionary (e.g. a LightPipeline.fullAnnotate result).
        pos_col -- Name of the column/key containing POS annotations.
        dependency_col -- Name of the column/key containing dependency annotations.
        dependency_type_col -- Optional name of the column/key containing dependency-type annotations.
        return_html -- If True, return the generated markup instead of rendering it.
        save_path -- If set, also write the generated markup to this file path.

        Output: Visualization
        """

        pos_res = []
        for i in res[pos_col]:
            t_ = {'chunk': i.metadata['word'],
                  'begin': i.begin,
                  'end' : i.end,
                  'pos' : i.result}
            pos_res.append(t_)
        dep_res = []
        dep_res_meta = []
        for i in res[dependency_col]:
            dep_res.append(i.result)
            dep_res_meta.append(i.metadata)
        df = pd.DataFrame(pos_res)
        df['dependency'] = dep_res
        df['dependency_start'] = dep_res_meta

        if dependency_type_col != None:
            df['dependency_type'] = [ i.result for i in res[dependency_type_col] ]
        else:
            df['dependency_type'] = ''

        html_content = self.__generate_graph(df)

        if save_path != None:
            with open(save_path, 'w') as f_:
                f_.write(html_content)

        if return_html:
            return html_content
        else:
            return display(HTML(html_content))
-------------------------------------------------------------------------------- /sparknlp_display/entity_resolution.py: --------------------------------------------------------------------------------
import random
import os
import json
import numpy as np
from . import style_utils as style_config
from IPython.display import display, HTML

here = os.path.abspath(os.path.dirname(__file__))

class EntityResolverVisualizer:
    """Highlights NER chunks together with their resolved codes/texts as styled HTML."""

    def __init__(self):
        # Preconfigured colors for common entity labels (keys are lower-case).
        with open(os.path.join(here, 'label_colors/ner.json'), 'r', encoding='utf-8') as f_:
            self.label_colors = json.load(f_)

    #public function to get color for a label
    def get_label_color(self, label):
        """Returns color of a particular label

        Input: entity label
        Output: Color (hex string), or None if the label has no configured color
        """

        if str(label).lower() in self.label_colors:
            return self.label_colors[label.lower()]
        else:
            return None

    # private function for colors for display
    def __get_label(self, label):
        """Return the color configured for a label, or a random one if unknown.
        Input: entity label
        Output: color code (Hex); a random color is generated for unknown labels
        """
        if str(label).lower() in self.label_colors:
            return self.label_colors[label.lower()]
        else:
            #update it to fetch from git new labels
            r = lambda: random.randint(0,200)
            return '#%02X%02X%02X' % (r(), r(), r())

    def set_label_colors(self, color_dict):
        """Sets label colors.

        input: dictionary of entity labels and corresponding colors
        output: self object - to allow chaining
        note: Previous values of colors will be overwritten
        """

        for key, value in color_dict.items():
            self.label_colors[key.lower()] = value
        return self

    def __verify_structure(self, result, label_col, document_col, original_text):
        """Validate that annotations expose begin/end/result/metadata (with an 'entity'
        key); raise AttributeError with a usage hint otherwise."""

        if original_text is None:
            basic_msg_1 = """Incorrect annotation structure of '{}'.""".format(document_col)
            if not hasattr(result[document_col][0], 'result'):
                raise AttributeError(basic_msg_1+""" 'result' attribute not found in the annotation.
                Make sure '"""+document_col+"""' is a list of objects having the following structure:
                    Annotation(type='annotation', begin=0, end=10, result='This is a text')
                Or You can pass the text manually using 'raw_text' argument.""")

        basic_msg_1 = """Incorrect annotation structure of '{}'.""".format(label_col)
        basic_msg = """
        In sparknlp please use 'LightPipeline.fullAnnotate' for LightPipeline or 'Pipeline.transform' for PipelineModel.
        Or
        Make sure '"""+label_col+"""' is a list of objects having the following structure:
            Annotation(type='annotation', begin=0, end=0, result='Adam', metadata={'entity': 'PERSON'})"""

        for entity in result[label_col]:
            if not hasattr(entity, 'begin'):
                raise AttributeError( basic_msg_1 + """ 'begin' attribute not found in the annotation."""+basic_msg)
            if not hasattr(entity, 'end'):
                raise AttributeError(basic_msg_1 + """ 'end' attribute not found in the annotation."""+basic_msg)
            if not hasattr(entity, 'result'):
                raise AttributeError(basic_msg_1 + """ 'result' attribute not found in the annotation."""+basic_msg)
            if not hasattr(entity, 'metadata'):
                raise AttributeError(basic_msg_1 + """ 'metadata' attribute not found in the annotation."""+basic_msg)
            if 'entity' not in entity.metadata:
                raise AttributeError(basic_msg_1+""" KeyError: 'entity' not found in metadata."""+basic_msg)

    def __verify_input(self, result, label_col, document_col, original_text):
        """Validate the user-supplied columns/text before rendering; raise on missing
        keys or a non-str raw text."""
        # check if label colum in result
        if label_col not in result:
            raise AttributeError("""column/key '{}' not found in the provided dataframe/dictionary.
            Please specify the correct key/column using 'label_col' argument.""".format(label_col))

        if original_text is not None:
            # check if provided text is correct data type
            if not isinstance(original_text, str):
                raise ValueError("Invalid value for argument 'raw_text' input. Text should be of type 'str'.")

        else:
            # check if document column in result
            if document_col not in result:
                raise AttributeError("""column/key '{}' not found in the provided dataframe/dictionary.
                Please specify the correct key/column using 'document_col' argument.
                Or You can pass the text manually using 'raw_text' argument""".format(document_col))

        self.__verify_structure( result, label_col, document_col, original_text)

    # main display function
    def __display_ner(self, result, label_col, resolution_col, document_col, original_text, labels_list = None):
        """Build the HTML string: plain text with highlighted entity chunks plus their
        resolution code and resolved text."""

        if original_text is None:
            original_text = result[document_col][0].result

        if labels_list is not None:
            labels_list = [v.lower() for v in labels_list]

        # Map chunk begin-offset -> [resolution code, resolved text].
        resolution_temp_dict = {}
        for resol in result[resolution_col]:
            resolution_temp_dict[int(resol.begin)] = [resol.result, resol.metadata['resolved_text']]


        label_color = {}
        html_output = ""
        pos = 0

        # Process entities in order of appearance in the text.
        sorted_labs_idx = np.argsort([ int(i.begin) for i in result[label_col]])
        sorted_labs = [ result[label_col][i] for i in sorted_labs_idx]

        for entity in sorted_labs:
            entity_type = entity.metadata['entity'].lower()
            if (entity_type not in label_color) and ((labels_list is None) or (entity_type in labels_list)) :
                label_color[entity_type] = self.__get_label(entity_type)

            begin = int(entity.begin)
            end = int(entity.end)
            # Emit the plain text between the previous chunk and this one.
            if pos < begin and pos < len(original_text):
                white_text = original_text[pos:begin]
                html_output += '{}'.format(white_text)
            pos = end+1

            if entity_type in label_color:
                if begin in resolution_temp_dict:
                    # NOTE(review): the span markup around these placeholders appears to
                    # have been stripped by the source export — confirm against the repo.
                    html_output += '{} {}{} {}'.format(
                        label_color[entity_type] + 'B3', #color
                        entity.result, #entity - chunk
                        entity.metadata['entity'], #entity - label
                        label_color[entity_type] + 'FF', #color '#D2C8C6'
                        resolution_temp_dict[begin][0], # res_code
                        label_color[entity_type] + 'CC', # res_color '#DDD2D0'
                        resolution_temp_dict[begin][1] # res_text
                        )
                else:
                    html_output += '{} {}'.format(
                        label_color[entity_type] + 'B3', #color
                        entity.result, #entity - chunk
                        entity.metadata['entity'] #entity - label
                        )
            else:
                html_output += '{}'.format(entity.result)

        if pos < len(original_text):
            html_output += '{}'.format(original_text[pos:])

        html_output += """"""

        # NOTE(review): the replacement text (likely "<br>") looks stripped by the export.
        html_output = html_output.replace("\n", "")

        return html_output

    def display(self, result, label_col, resolution_col, document_col='document', raw_text=None, return_html=False, save_path=None):
        """Displays entity-resolution visualization.

        Inputs:
        result -- A Dataframe or dictionary.
        label_col -- Name of the column/key containing NER annotations.
        resolution_col -- Name of the column/key containing resolution annotations (code + resolved text).
        document_col -- Name of the column/key containing text document.
        raw_text -- Original text of type 'str'. If specified, it will take precedence over 'document_col' and will be used as the reference text for display.
        return_html -- If True, return the HTML string instead of rendering it.
        save_path -- If set, also write the HTML to this file path.

        Output: Visualization
        """

        #self.__verifyInput(result, label_col, document_col, raw_text)

        html_content = self.__display_ner(result, label_col, resolution_col, document_col, raw_text)

        html_content_save = style_config.STYLE_CONFIG_ENTITIES+ " "+html_content

        if save_path != None:
            with open(save_path, 'w') as f_:
                f_.write(html_content_save)

        if return_html:
            return html_content_save
        else:
            return display(HTML(html_content_save))
-------------------------------------------------------------------------------- /sparknlp_display/fonts/Lucida_Console.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-display/a6c5a6ac166c5f75ec4ccf64ffafb2fa6e1164a1/sparknlp_display/fonts/Lucida_Console.ttf -------------------------------------------------------------------------------- /sparknlp_display/label_colors/ner.json: --------------------------------------------------------------------------------
{
    "problem": "#800080",
    "test": "#77b5fe",
    "treatment": "#8b6673",
    "multi": "#494ca3",
    "multi-tissue_structure": "#8dd8b4",
    "cell": "#ffe6cc",
    "organism": "#ffddcc",
"gene_or_gene_product": "#fff0b3", 10 | "organ": "#e6e600", 11 | "simple_chemical": "#ffd699", 12 | "drug": "#8B668B", 13 | "diagnosis": "#b5a1c9", 14 | "maybe": "#FFB5C5", 15 | "lab_result": "#3abd80", 16 | "negated": "#CD3700", 17 | "name": "#C0FF3E", 18 | "lab_name": "#698B22", 19 | "modifier": "#8B475D", 20 | "symptom_name": "#CDB7B5", 21 | "section_name": "#8B7D7B", 22 | "drug_name": "#a3496c", 23 | "procedure_name": "#48D1CC", 24 | "grading": "#8c61e8", 25 | "size": "#746b87", 26 | "organism_substance": "#ffaa80", 27 | "gender": "#ffacb7", 28 | "age": "#ffe0ac", 29 | "date": "#a6b1e1" 30 | } -------------------------------------------------------------------------------- /sparknlp_display/label_colors/relations.json: -------------------------------------------------------------------------------- 1 | { 2 | "overlap" : "#ffb345", 3 | "before" : "#0398da", 4 | "after" : "#39bf7f", 5 | 6 | "trip": "#e4815e", 7 | "trwp": "#0398da", 8 | "trcp": "#39bf7f", 9 | "trap": "#ffb345", 10 | "trnap": "#0059b3", 11 | "terp": "#8c35cd", 12 | "tecp": "#fa3e74", 13 | "pip" : "#6e5772", 14 | 15 | "drug-strength" : "purple", 16 | "drug-frequency": "slategray", 17 | "drug-form" : "deepskyblue", 18 | "dosage-drug" : "springgreen", 19 | "strength-drug": "maroon", 20 | "drug-dosage" : "gold", 21 | 22 | "0" : "#e4815e", 23 | "1" : "#6e5772" 24 | } -------------------------------------------------------------------------------- /sparknlp_display/ner.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import json 4 | import numpy as np 5 | from . 
import style_utils as style_config
from IPython.display import display, HTML

here = os.path.abspath(os.path.dirname(__file__))

class NerVisualizer:
    """Highlights NER chunks inside the original text as styled HTML."""

    def __init__(self):
        # Preconfigured colors for common entity labels (keys are lower-case).
        with open(os.path.join(here, 'label_colors/ner.json'), 'r', encoding='utf-8') as f_:
            self.label_colors = json.load(f_)

    #public function to get color for a label
    def get_label_color(self, label):
        """Returns color of a particular label

        Input: entity label
        Output: Color (hex string), or None if the label has no configured color
        """

        if str(label).lower() in self.label_colors:
            return self.label_colors[label.lower()]
        else:
            return None

    # private function for colors for display
    def __get_label(self, label):
        """Internal function to generate random color codes for missing colors

        Input: entity label
        Output: color code (Hex)
        """
        if str(label).lower() in self.label_colors:
            return self.label_colors[label.lower()]
        else:
            #update it to fetch from git new labels
            r = lambda: random.randint(0,200)
            return '#%02X%02X%02X' % (r(), r(), r())

    def set_label_colors(self, color_dict):
        """Sets label colors.
        input: dictionary of entity labels and corresponding colors
        output: self object - to allow chaining
        note: Previous values of colors will be overwritten
        """

        for key, value in color_dict.items():
            self.label_colors[key.lower()] = value
        return self

    def __verify_structure(self, result, label_col, document_col, original_text):
        """Validate that annotations expose begin/end/result/metadata (with an 'entity'
        key); raise AttributeError with a usage hint otherwise."""

        if original_text is None:
            basic_msg_1 = """Incorrect annotation structure of '{}'.""".format(document_col)
            if not hasattr(result[document_col][0], 'result'):
                raise AttributeError(basic_msg_1+""" 'result' attribute not found in the annotation.
                Make sure '"""+document_col+"""' is a list of objects having the following structure:
                    Annotation(type='annotation', begin=0, end=10, result='This is a text')
                Or You can pass the text manually using 'raw_text' argument.""")

        basic_msg_1 = """Incorrect annotation structure of '{}'.""".format(label_col)
        basic_msg = """
        In sparknlp please use 'LightPipeline.fullAnnotate' for LightPipeline or 'Pipeline.transform' for PipelineModel.
        Or
        Make sure '"""+label_col+"""' is a list of objects having the following structure:
            Annotation(type='annotation', begin=0, end=0, result='Adam', metadata={'entity': 'PERSON'})"""

        for entity in result[label_col]:
            if not hasattr(entity, 'begin'):
                raise AttributeError( basic_msg_1 + """ 'begin' attribute not found in the annotation."""+basic_msg)
            if not hasattr(entity, 'end'):
                raise AttributeError(basic_msg_1 + """ 'end' attribute not found in the annotation."""+basic_msg)
            if not hasattr(entity, 'result'):
                raise AttributeError(basic_msg_1 + """ 'result' attribute not found in the annotation."""+basic_msg)
            if not hasattr(entity, 'metadata'):
                raise AttributeError(basic_msg_1 + """ 'metadata' attribute not found in the annotation."""+basic_msg)
            if 'entity' not in entity.metadata:
                raise AttributeError(basic_msg_1+""" KeyError: 'entity' not found in metadata."""+basic_msg)

    def __verify_input(self, result, label_col, document_col, original_text):
        """Validate the user-supplied columns/text before rendering; raise on missing
        keys or a non-str raw text."""
        # check if label colum in result
        if label_col not in result:
            raise AttributeError("""column/key '{}' not found in the provided dataframe/dictionary.
            Please specify the correct key/column using 'label_col' argument.""".format(label_col))

        if original_text is not None:
            # check if provided text is correct data type
            if not isinstance(original_text, str):
                raise ValueError("Invalid value for argument 'raw_text' input. Text should be of type 'str'.")

        else:
            # check if document column in result
            if document_col not in result:
                raise AttributeError("""column/key '{}' not found in the provided dataframe/dictionary.
                Please specify the correct key/column using 'document_col' argument.
                Or You can pass the text manually using 'raw_text' argument""".format(document_col))

        self.__verify_structure( result, label_col, document_col, original_text)

    # main display function
    def __display_ner(self, result, label_col, document_col, original_text, labels_list = None):
        """Build the HTML visualization string: plain text with highlighted entity chunks."""

        if original_text is None:
            original_text = result[document_col][0].result

        if labels_list is not None:
            labels_list = [v.lower() for v in labels_list]
        label_color = {}
        html_output = ""
        pos = 0

        # Process entities in order of appearance in the text.
        sorted_labs_idx = np.argsort([ int(i.begin) for i in result[label_col]])
        sorted_labs = [ result[label_col][i] for i in sorted_labs_idx]

        for entity in sorted_labs:
            entity_type = entity.metadata['entity'].lower()
            if (entity_type not in label_color) and ((labels_list is None) or (entity_type in labels_list)) :
                label_color[entity_type] = self.__get_label(entity_type)

            begin = int(entity.begin)
            end = int(entity.end)
            # Emit the plain text between the previous chunk and this one.
            if pos < begin and pos < len(original_text):
                white_text = original_text[pos:begin]
                html_output += '{}'.format(white_text)
            pos = end+1

            if entity_type in label_color:
                # NOTE(review): the span markup around these placeholders appears to
                # have been stripped by the source export — confirm against the repo.
                html_output += '{} {}'.format(
                    label_color[entity_type],
                    original_text[begin:end+1],#entity.result,
                    entity.metadata['entity'])
            else:
                html_output += '{}'.format(original_text[begin:end+1])

        if pos < len(original_text):
            html_output += '{}'.format(original_text[pos:])

        html_output += """"""

        # NOTE(review): the replacement text (likely "<br>") looks stripped by the export.
        html_output = html_output.replace("\n", "")

        return html_output

    def display(self, result, label_col, document_col='document', raw_text=None, labels=None, return_html=False, save_path=None):
        """Displays NER visualization.
        Inputs:
        result -- A Dataframe or dictionary.
        label_col -- Name of the column/key containing NER annotations.
        document_col -- Name of the column/key containing text document.
        raw_text -- Original text of type 'str'. If specified, it will take precedence over 'document_col' and will be used as the reference text for display.
        labels -- A list of labels that should be highlighted in the output. Default: Display all labels.
        return_html -- If True, return the HTML string instead of rendering it.
        save_path -- If set, also write the HTML to this file path.
        Output: Visualization
        """

        self.__verify_input(result, label_col, document_col, raw_text)

        html_content = self.__display_ner(result, label_col, document_col, raw_text, labels)

        html_content_save = style_config.STYLE_CONFIG_ENTITIES+ " "+html_content

        if save_path != None:
            with open(save_path, 'w') as f_:
                f_.write(html_content_save)

        if return_html:
            return html_content_save
        else:
            return display(HTML(html_content_save))

-------------------------------------------------------------------------------- /sparknlp_display/re_updates.py: --------------------------------------------------------------------------------
import random
import os
import json
import pandas as pd
import numpy as np
import svgwrite
from IPython.display import display, HTML

here = os.path.abspath(os.path.dirname(__file__))

class RelationExtractionVisualizer:
    """Renders relation-extraction results as an SVG graph (wrapped text + arcs)."""

    def __init__(self):
        # Default colors for known relation types; unknown types get a random color.
        self.color_dict = {
            "overlap" : "lightsalmon",
            "before" : "deepskyblue",
            "after" : "springgreen",

            "trip": "lightsalmon",
            "trwp": "deepskyblue",
            "trcp": "springgreen",
            "trap": "gold",
            "trnap": "maroon",
            "terp": "purple",
            "tecp": "tomato",
| "pip" : "slategray", 27 | 28 | "drug-strength" : "purple", 29 | "drug-frequency": "slategray", 30 | "drug-form" : "deepskyblue", 31 | "dosage-drug" : "springgreen", 32 | "strength-drug": "maroon", 33 | "drug-dosage" : "gold" 34 | } 35 | 36 | def __get_color(self, l): 37 | r = lambda: random.randint(100,255) 38 | return '#%02X%02X%02X' % (r(), r(), r()) 39 | 40 | def __size(self, text): 41 | return ((len(text)+1)*9.7)-5 42 | 43 | def __draw_line(self, dwg, s_x , s_y, e_x, e_y, d_type, color, show_relations): 44 | # find the a & b points 45 | def get_bezier_coef(points): 46 | # since the formulas work given that we have n+1 points 47 | # then n must be this: 48 | n = len(points) - 1 49 | 50 | # build coefficents matrix 51 | C = 4 * np.identity(n) 52 | np.fill_diagonal(C[1:], 1) 53 | np.fill_diagonal(C[:, 1:], 1) 54 | C[0, 0] = 2 55 | C[n - 1, n - 1] = 7 56 | C[n - 1, n - 2] = 2 57 | 58 | # build points vector 59 | P = [2 * (2 * points[i] + points[i + 1]) for i in range(n)] 60 | P[0] = points[0] + 2 * points[1] 61 | P[n - 1] = 8 * points[n - 1] + points[n] 62 | 63 | # solve system, find a & b 64 | A = np.linalg.solve(C, P) 65 | B = [0] * n 66 | for i in range(n - 1): 67 | B[i] = 2 * points[i + 1] - A[i + 1] 68 | B[n - 1] = (A[n - 1] + points[n]) / 2 69 | 70 | return A, B 71 | 72 | # returns the general Bezier cubic formula given 4 control points 73 | def get_cubic(a, b, c, d): 74 | return lambda t: np.power(1 - t, 3) * a + 3 * np.power(1 - t, 2) * t * b + 3 * (1 - t) * np.power(t, 2) * c + np.power(t, 3) * d 75 | 76 | # return one cubic curve for each consecutive points 77 | def get_bezier_cubic(points): 78 | A, B = get_bezier_coef(points) 79 | return [ 80 | get_cubic(points[i], A[i], B[i], points[i + 1]) 81 | for i in range(len(points) - 1) 82 | ] 83 | 84 | # evalute each cubic curve on the range [0, 1] sliced in n points 85 | def evaluate_bezier(points, n): 86 | curves = get_bezier_cubic(points) 87 | return np.array([fun(t) for fun in curves for t in 
np.linspace(0, 1, n)]) 88 | 89 | 90 | def draw_pointer(dwg, s_x, s_y, e_x, e_y): 91 | size = 8 92 | ratio = 2 93 | fullness1 = 2 94 | fullness2 = 3 95 | bx = e_x 96 | ax = s_x 97 | by = e_y 98 | ay = s_y 99 | abx = bx - ax 100 | aby = by - ay 101 | ab = np.sqrt(abx * abx + aby * aby) 102 | 103 | cx = bx - size * abx / ab 104 | cy = by - size * aby / ab 105 | dx = cx + (by - cy) / ratio 106 | dy = cy + (cx - bx) / ratio 107 | ex = cx - (by - cy) / ratio 108 | ey = cy - (cx - bx) / ratio 109 | fx = (fullness1 * cx + bx) / fullness2 110 | fy = (fullness1 * cy + by) / fullness2 111 | 112 | text_place_y = s_y-(abs(s_y-e_y)/2) 113 | line = dwg.add(dwg.polyline( 114 | [ 115 | (bx, by), 116 | (dx, dy), 117 | (fx, fy), 118 | (ex, ey), 119 | (bx, by) 120 | ], 121 | stroke=color, stroke_width = "2", fill='none',)) 122 | return text_place_y 123 | 124 | if s_x > e_x: 125 | #s_x -= 5 126 | e_x += 10 127 | else: 128 | #s_x += 5 129 | e_x -= 2 130 | if s_y == e_y: 131 | s_y -= 20 132 | e_y = s_y-4#55 133 | text_place_y = s_y-45 134 | 135 | pth = evaluate_bezier(np.array([[s_x, s_y], 136 | [(s_x+e_x)/2.0, s_y-40], 137 | [e_x,e_y]]), 50) 138 | dwg.add(dwg.polyline(pth, 139 | stroke=color, stroke_width = "2", fill='none',)) 140 | 141 | draw_pointer(dwg, (s_x+e_x)/2.0, s_y-50, e_x, e_y) 142 | 143 | elif s_y >= e_y: 144 | e_y +=15 145 | s_y-=20 146 | dwg.add(dwg.polyline([(s_x,s_y), (e_x, e_y)], 147 | stroke=color, stroke_width = "2", fill='none',)) 148 | text_place_y = draw_pointer(dwg, s_x, s_y, e_x, e_y) 149 | else: 150 | s_y-=5 151 | e_y -= 40 152 | dwg.add(dwg.polyline([(s_x,s_y), (e_x, e_y)], 153 | stroke=color, stroke_width = "2", fill='none',)) 154 | text_place_y = draw_pointer(dwg, s_x, s_y, e_x, e_y) 155 | if show_relations: 156 | dwg.add(dwg.text(d_type, insert=(((s_x+e_x)/2)-(self.__size(d_type)/2.75), text_place_y), 157 | fill=color, font_size='12', font_family='courier')) 158 | 159 | def __gen_graph(self, rdf, selected_text, show_relations): 160 | 161 | done_ent1 = {} 162 
| done_ent2 = {} 163 | all_done = {} 164 | 165 | start_y = 75 166 | x_limit = 920 167 | y_offset = 100 168 | dwg = svgwrite.Drawing("temp.svg",profile='tiny', size = (x_limit, len(selected_text) * 1.1 + len(rdf)*20)) 169 | 170 | begin_index = 0 171 | start_x = 10 172 | this_line = 0 173 | 174 | all_entities_index = set() 175 | all_entities_1_index = [] 176 | basic_dict = {} 177 | relation_dict = {} 178 | for t in rdf: 179 | if t.result.lower().strip() != 'o': 180 | all_entities_index.add(int(t.metadata['entity1_begin'])) 181 | all_entities_index.add(int(t.metadata['entity2_begin'])) 182 | basic_dict[int(t.metadata['entity1_begin'])] = [t.metadata['entity1_begin'], 183 | t.metadata['entity1_end'], 184 | t.metadata['chunk1'], 185 | t.metadata['entity1']] 186 | 187 | basic_dict[int(t.metadata['entity2_begin'])] = [t.metadata['entity2_begin'], 188 | t.metadata['entity2_end'], 189 | t.metadata['chunk2'], 190 | t.metadata['entity2']] 191 | 192 | #all_entities_1_index.append(t[4]['entity1_begin']) 193 | all_entities_index = np.asarray(list(all_entities_index)) 194 | all_entities_index = all_entities_index[np.argsort(all_entities_index)] 195 | for ent_start_ind in all_entities_index: 196 | e_start_now, e_end_now, e_chunk_now, e_entity_now = basic_dict[ent_start_ind] 197 | prev_text = selected_text[begin_index:int(e_start_now)] 198 | begin_index = int(e_end_now)+1 199 | for word_ in prev_text.split(' '): 200 | this_size = self.__size(word_) 201 | if (start_x + this_size + 10) >= x_limit: 202 | start_y += y_offset 203 | start_x = 10 204 | this_line = 0 205 | dwg.add(dwg.text(word_, insert=(start_x, start_y ), fill='gray', font_size='16', font_family='courier')) 206 | start_x += this_size + 5 207 | 208 | this_size = self.__size(e_chunk_now) 209 | if (start_x + this_size + 10)>= x_limit:# or this_line >= 2: 210 | start_y += y_offset 211 | start_x = 10 212 | this_line = 0 213 | #chunk1 214 | dwg.add(dwg.text(e_chunk_now, insert=(start_x, start_y ), fill='gray', font_size='16', 
font_family='courier')) 215 | #rectange chunk 1 216 | dwg.add(dwg.rect(insert=(start_x-3, start_y-18), size=(this_size,25), 217 | rx=2, ry=2, stroke='orange', 218 | stroke_width='2', fill='none')) 219 | #entity 1 220 | central_point_x = start_x+(this_size/2) 221 | 222 | dwg.add(dwg.text(e_entity_now, 223 | insert=(central_point_x-(self.__size(e_entity_now)/2.75), start_y+20), 224 | fill='slateblue', font_size='12', font_family='courier')) 225 | 226 | all_done[int(e_start_now)] = [central_point_x-(self.__size(e_entity_now)/2.75), start_y] 227 | start_x += this_size + 10 228 | this_line += 1 229 | 230 | #all_done[ent_start_ind] = 231 | 232 | prev_text = selected_text[begin_index:] 233 | for word_ in prev_text.split(' '): 234 | this_size = self.__size(word_) 235 | if (start_x + this_size)>= x_limit: 236 | start_y += y_offset 237 | start_x = 10 238 | dwg.add(dwg.text(word_, insert=(start_x, start_y ), fill='gray', font_size='16', font_family='courier')) 239 | start_x += this_size 240 | 241 | for row in rdf: 242 | if row.result.lower().strip() != 'o': 243 | if row.result.lower().strip() not in self.color_dict: 244 | self.color_dict[row.result.lower().strip()] = self.__get_color(row.result.lower().strip()) 245 | d_key2 = all_done[int(row.metadata['entity2_begin'])] 246 | d_key1 = all_done[int(row.metadata['entity1_begin'])] 247 | self.__draw_line(dwg, d_key2[0] , d_key2[1], d_key1[0], d_key1[1], 248 | row.result,self.color_dict[row.result.lower().strip()], show_relations) 249 | 250 | return dwg.tostring() 251 | 252 | def display(self, result, relation_col, document_col='document', show_relations=True): 253 | 254 | original_text = result[document_col][0].result 255 | res = result[relation_col] 256 | return display(HTML(self.__gen_graph(res, original_text, show_relations))) 257 | -------------------------------------------------------------------------------- /sparknlp_display/relation_extraction.py: 
import random
import os
import json
import pandas as pd
import numpy as np
import svgwrite
import math
import re
from IPython.display import display, HTML

here = os.path.abspath(os.path.dirname(__file__))

# Per-anchor offset bookkeeping so several arrows sharing an endpoint fan out
# instead of overlapping exactly.  Cleared at the start of every render in
# __gen_graph (previously they accumulated stale state across display() calls).
x_i_diff_dict = {}
x_o_diff_dict = {}


class RelationExtractionVisualizer:
    """Renders Spark NLP relation-extraction results as an inline SVG graph.

    Annotation rows are assumed to expose `.result` (the relation label) and
    `.metadata` with keys entity1_begin/entity1_end/chunk1/entity1 and the
    entity2 counterparts — TODO confirm against the Spark NLP annotation schema.
    """

    def __init__(self):
        # Seed color maps from the packaged JSON files; unknown labels get
        # random colors later via __get_color.
        with open(os.path.join(here, 'label_colors/relations.json'), 'r', encoding='utf-8') as f_:
            self.color_dict = json.load(f_)
        with open(os.path.join(here, 'label_colors/ner.json'), 'r', encoding='utf-8') as f_:
            self.entity_color_dict = json.load(f_)
        self.entity_color_dict = dict((k.lower(), v) for k, v in self.entity_color_dict.items())
        self.font_path = os.path.join(here, 'fonts/Lucida_Console.ttf')
        self.main_font = 'Lucida'

    def __get_color(self, l):
        # NOTE(review): `l` is unused — every call returns a random dark color.
        # Kept in the signature for caller compatibility.
        r = lambda: random.randint(0, 200)
        return '#%02X%02X%02X' % (r(), r(), r())

    def __size(self, text):
        # Approximate pixel width of `text` in the embedded monospace font.
        return ((len(text) + 1) * 9.7) - 5

    def __draw_line(self, dwg, s_x, s_y, e_x, e_y, d_type, color, show_relations, size_of_entity_label):
        """Draw one relation arrow (a cubic Bezier polyline plus arrowhead) and,
        optionally, its rotated label box.

        (s_x, s_y) / (e_x, e_y) -- source/target anchor coordinates.
        d_type -- relation label text.
        size_of_entity_label -- half-width of the target entity label, used to
        shift anchors so arrows do not pierce the label.
        """
        eps = 0.0000001

        def get_bezier_coef(points):
            # Solve the tridiagonal system for cubic-Bezier control points
            # interpolating `points` (n segments for n+1 points).
            n = len(points) - 1
            C = 4 * np.identity(n)
            np.fill_diagonal(C[1:], 1)
            np.fill_diagonal(C[:, 1:], 1)
            C[0, 0] = 2
            C[n - 1, n - 1] = 7
            C[n - 1, n - 2] = 2
            P = [2 * (2 * points[i] + points[i + 1]) for i in range(n)]
            P[0] = points[0] + 2 * points[1]
            P[n - 1] = 8 * points[n - 1] + points[n]
            A = np.linalg.solve(C, P)
            B = [0] * n
            for i in range(n - 1):
                B[i] = 2 * points[i + 1] - A[i + 1]
            B[n - 1] = (A[n - 1] + points[n]) / 2
            return A, B

        def get_cubic(a, b, c, d):
            # General cubic Bezier for control points a, b, c, d.
            return lambda t: np.power(1 - t, 3) * a + 3 * np.power(1 - t, 2) * t * b \
                + 3 * (1 - t) * np.power(t, 2) * c + np.power(t, 3) * d

        def get_bezier_cubic(points):
            # One cubic curve per consecutive point pair.
            A, B = get_bezier_coef(points)
            return [
                get_cubic(points[i], A[i], B[i], points[i + 1])
                for i in range(len(points) - 1)
            ]

        def evaluate_bezier(points, n):
            # Evaluate each segment at n parameter values in [0, 1].
            curves = get_bezier_cubic(points)
            return np.array([fun(t) for fun in curves for t in np.linspace(0, 1, n)])

        def draw_pointer(dwg, s_x, s_y, e_x, e_y):
            # Arrowhead: two short strokes meeting at the target point,
            # oriented along the (s -> e) direction.
            size = 5
            ratio = 1
            bx, by = e_x, e_y
            abx = bx - s_x
            aby = by - s_y
            ab = np.sqrt(abx * abx + aby * aby) + eps  # eps guards zero-length arrows
            cx = bx - size * abx / ab
            cy = by - size * aby / ab
            dx = cx + (by - cy) / ratio
            dy = cy + (cx - bx) / ratio
            ex = cx - (by - cy) / ratio
            ey = cy - (cx - bx) / ratio
            dwg.add(dwg.polyline(
                [(dx, dy), (bx, by), (ex, ey), (bx, by)],
                stroke=color, stroke_width="1", fill='none',))

        # Nudge anchors so coincident arrows separate; first arrow through an
        # anchor moves 10px and registers it, later ones move 5px.
        unique_o_index = str(s_x) + str(s_y)
        unique_i_index = str(e_x) + str(e_y)
        if s_x > e_x:
            if unique_o_index in x_o_diff_dict:
                s_x -= 5
            else:
                s_x -= 10
                x_o_diff_dict[unique_o_index] = 5
            if s_y > e_y:
                e_x += size_of_entity_label
            elif s_y < e_y:
                s_x -= size_of_entity_label
            if unique_i_index in x_i_diff_dict:
                e_x += 5
            else:
                e_x += 10
                x_i_diff_dict[unique_i_index] = 5
        else:
            if unique_o_index in x_o_diff_dict:
                s_x += 5
            else:
                s_x += 10
                x_o_diff_dict[unique_o_index] = 5
            if s_y > e_y:
                e_x -= size_of_entity_label
            elif s_y < e_y:
                s_x += size_of_entity_label
            if unique_i_index in x_i_diff_dict:
                e_x -= 5
            else:
                e_x -= 10
                x_i_diff_dict[unique_i_index] = 5

        angle = -1  # -1 means "compute from slope later"; 0 for same-line arcs
        if s_y == e_y:
            # Both entities on the same text line: arc over the text.
            angle = 0
            s_y -= 20
            e_y = s_y - 4
            text_place_y = s_y - 35
            pth = evaluate_bezier(np.array([[s_x, s_y],
                                            [(s_x + e_x) / 2.0, s_y - 40],
                                            [e_x, e_y]]), 50)
            dwg.add(dwg.polyline(pth, stroke=color, stroke_width="1", fill='none',))
            draw_pointer(dwg, (s_x + e_x) / 2.0, s_y - 50, e_x, e_y)
        elif s_y >= e_y:
            # Target on an earlier line: curve upward.
            e_y += 15
            s_y -= 20
            text_place_y = s_y - (abs(s_y - e_y) / 2)
            pth = evaluate_bezier(np.array([[s_x, s_y],
                                            [(s_x + e_x) / 2.0, (s_y + e_y) / 2.0],
                                            [e_x, e_y]]), 50)
            dwg.add(dwg.polyline(pth, stroke=color, stroke_width="1", fill='none',))
            draw_pointer(dwg, s_x, s_y, e_x, e_y)
        else:
            # Target on a later line: curve downward.
            s_y += 15
            e_y -= 20
            text_place_y = s_y + (abs(s_y - e_y) / 2)
            pth = evaluate_bezier(np.array([[s_x, s_y],
                                            [(s_x + e_x) / 2.0, (s_y + e_y) / 2.0],
                                            [e_x, e_y]]), 50)
            dwg.add(dwg.polyline(pth, stroke=color, stroke_width="1", fill='none',))
            draw_pointer(dwg, s_x, s_y, e_x, e_y)

        if show_relations:
            # Label box rotated to follow the arrow's slope.
            if angle == -1:
                angle = math.degrees(math.atan((s_y - e_y) / ((s_x - e_x) + eps)))
            rel_temp_size = self.__size(d_type) / 1.35
            rect_x, rect_y = (((s_x + e_x) / 2.0) - (rel_temp_size / 2.0) - 3, text_place_y - 10)
            rect_w, rect_h = (rel_temp_size + 3, 13)
            dwg.add(dwg.rect(insert=(rect_x, rect_y), rx=2, ry=2,
                             size=(rect_w, rect_h),
                             fill='white', stroke=color, stroke_width='1',
                             transform=f"rotate({angle} {rect_x + rect_w / 2} {rect_y + rect_h / 2})"))
            dwg.add(dwg.text(d_type, insert=(((s_x + e_x) / 2) - (rel_temp_size / 2.0), text_place_y),
                             fill=color, font_size='12', font_family='courier',
                             transform=f"rotate({angle} {rect_x + rect_w / 2} {rect_y + rect_h / 2})"))

    def __gen_graph(self, rdf, selected_text, exclude_relations, show_relations):
        """Lay out the document text with entity boxes, then draw relation
        arrows between entity anchor points.  Returns the SVG as a string."""
        # BUGFIX: reset module-level arrow-offset state so a second call to
        # display() does not inherit offsets from the previous drawing.
        x_i_diff_dict.clear()
        x_o_diff_dict.clear()

        exclude_relations = [i.lower().strip() for i in exclude_relations]
        rdf = [i for i in rdf if i.result.lower().strip() not in exclude_relations]

        all_done = {}  # entity begin offset -> [anchor x, anchor y, half label width]

        start_y = 75
        x_limit = 1000
        y_offset = 100

        begin_index = 0
        start_x = 10

        # Collect every entity span (keyed by its begin offset) from both ends
        # of each relation, and make sure each entity type has a color.
        all_entities_index = set()
        basic_dict = {}
        for t in rdf:
            all_entities_index.add(int(t.metadata['entity1_begin']))
            all_entities_index.add(int(t.metadata['entity2_begin']))
            basic_dict[int(t.metadata['entity1_begin'])] = [t.metadata['entity1_begin'],
                                                            t.metadata['entity1_end'],
                                                            t.metadata['chunk1'],
                                                            t.metadata['entity1']]
            basic_dict[int(t.metadata['entity2_begin'])] = [t.metadata['entity2_begin'],
                                                            t.metadata['entity2_end'],
                                                            t.metadata['chunk2'],
                                                            t.metadata['entity2']]
            if t.metadata['entity1'].lower().strip() not in self.entity_color_dict:
                self.entity_color_dict[t.metadata['entity1'].lower().strip()] = self.__get_color(t.metadata['entity1'].lower().strip())
            if t.metadata['entity2'].lower().strip() not in self.entity_color_dict:
                self.entity_color_dict[t.metadata['entity2'].lower().strip()] = self.__get_color(t.metadata['entity2'].lower().strip())

        all_entities_index = np.asarray(list(all_entities_index))
        all_entities_index = all_entities_index[np.argsort(all_entities_index)]

        # First pass: compute positions for all rectangles and text runs;
        # actual drawing happens after the final canvas height is known.
        dwg_rects, dwg_texts = [], []
        for ent_start_ind in all_entities_index:
            e_start_now, e_end_now, e_chunk_now, e_entity_now = basic_dict[ent_start_ind]
            prev_text = selected_text[begin_index:int(e_start_now)]
            # BUGFIX: re.MULTILINE was previously passed as the positional
            # `count` argument of re.sub (re.MULTILINE == 8, capping the number
            # of substitutions at 8); it must be passed as `flags`.
            prev_text = re.sub(r'\s*(\n)+', r'\1', prev_text.strip(), flags=re.MULTILINE)
            begin_index = int(e_end_now) + 1
            for line_num, line in enumerate(prev_text.split('\n')):
                if line_num != 0:
                    # Explicit newline in the source text starts a fresh row.
                    start_y += y_offset
                    start_x = 10
                for word_ in line.split(' '):
                    this_size = self.__size(word_)
                    if (start_x + this_size + 10) >= x_limit:
                        start_y += y_offset
                        start_x = 10
                    # NOTE(review): plain words use '#546c74' here but '#546c77'
                    # in the trailing-text loop below — kept as in the original.
                    dwg_texts.append([word_, (start_x, start_y), '#546c74', '16', self.main_font, 'font-weight:100'])
                    start_x += this_size + 10

            # Entity chunk: background rectangle + chunk text + entity label.
            this_size = self.__size(e_chunk_now)
            if (start_x + this_size + 10) >= x_limit:
                start_y += y_offset
                start_x = 10
            dwg_rects.append([(start_x - 3, start_y - 18), (this_size, 25),
                              self.entity_color_dict[e_entity_now.lower().strip()]])
            dwg_texts.append([e_chunk_now, (start_x, start_y), '#546c74', '16', self.main_font, 'font-weight:100'])
            central_point_x = start_x + (this_size / 2)
            temp_size = self.__size(e_entity_now) / 2.75
            dwg_texts.append([e_entity_now.upper(), (central_point_x - temp_size, start_y + 20),
                              '#1f77b7', '12', self.main_font, 'font-weight:lighter'])
            all_done[int(e_start_now)] = [central_point_x, start_y, temp_size]
            start_x += this_size + 20

        # Trailing text after the last entity.
        prev_text = selected_text[begin_index:]
        prev_text = re.sub(r'\s*(\n)+', r'\1', prev_text.strip(), flags=re.MULTILINE)
        for line_num, line in enumerate(prev_text.split('\n')):
            if line_num != 0:
                start_y += y_offset
                start_x = 10
            for word_ in line.split(' '):
                this_size = self.__size(word_)
                if (start_x + this_size) >= x_limit:
                    start_y += y_offset
                    start_x = 10
                dwg_texts.append([word_, (start_x, start_y), '#546c77', '16', self.main_font, 'font-weight:100'])
                start_x += this_size + 10

        # Second pass: size the canvas to the laid-out height, then draw.
        dwg = svgwrite.Drawing("temp.svg", profile='full', size=(x_limit, start_y + y_offset))
        dwg.embed_font(self.main_font, self.font_path)

        for crect_ in dwg_rects:
            dwg.add(dwg.rect(insert=crect_[0], rx=2, ry=2, size=crect_[1], stroke=crect_[2],
                             stroke_width='1', fill=crect_[2], fill_opacity='0.2'))
        for ctext_ in dwg_texts:
            dwg.add(dwg.text(ctext_[0], insert=ctext_[1], fill=ctext_[2], font_size=ctext_[3],
                             font_family=ctext_[4], style=ctext_[5]))

        # Draw relation arrows shortest-first so long arcs overlay short ones.
        relation_distances = []
        relation_coordinates = []
        for row in rdf:
            if row.result.lower().strip() not in self.color_dict:
                self.color_dict[row.result.lower().strip()] = self.__get_color(row.result.lower().strip())
            d_key2 = all_done[int(row.metadata['entity1_begin'])]
            d_key1 = all_done[int(row.metadata['entity2_begin'])]
            this_dist = abs(d_key2[0] - d_key1[0]) + abs(d_key2[1] - d_key1[1])  # Manhattan distance
            relation_distances.append(this_dist)
            relation_coordinates.append((d_key2, d_key1, row.result))

        relation_distances = np.array(relation_distances)
        # dtype=object: elements are ragged (list, list, str) triples.
        relation_coordinates = np.array(relation_coordinates, dtype=object)
        temp_ind = np.argsort(relation_distances)
        relation_coordinates = relation_coordinates[temp_ind]
        for row in relation_coordinates:
            size_of_entity_label = int(row[1][2])
            self.__draw_line(dwg, int(row[0][0]), int(row[0][1]), int(row[1][0]), int(row[1][1]),
                             row[2], self.color_dict[row[2].lower().strip()], show_relations,
                             size_of_entity_label)

        return dwg.tostring()

    def display(self, result, relation_col, document_col='document', exclude_relations=['O'],
                show_relations=True, return_html=False, save_path=None):
        """Displays Relation Extraction visualization.
        Inputs:
        result -- A Dataframe or dictionary.
        relation_col -- Name of the column/key containing relationships.
        document_col -- Name of the column/key containing text document.
        exclude_relations -- list of relations that don't need to be displayed. Default: ["O"]
        show_relations -- Display relation types on arrows. Default: True
        return_html -- If true, returns raw html code instead of displaying. Default: False
        save_path -- Optional file path; when given, the SVG markup is also written there.
        Output: Visualization
        """
        # NOTE: the ['O'] default is never mutated (a new list is built in
        # __gen_graph), so the shared default is safe here.
        original_text = result[document_col][0].result
        res = result[relation_col]

        html_content = self.__gen_graph(res, original_text, exclude_relations, show_relations)

        if save_path is not None:
            with open(save_path, 'w') as f_:
                f_.write(html_content)

        if return_html:
            return html_content
        else:
            return display(HTML(html_content))
"deepskyblue", 34 | "dosage-drug" : "springgreen", 35 | "strength-drug": "maroon", 36 | "drug-dosage" : "gold" 37 | } 38 | 39 | def __get_color(self, l): 40 | r = lambda: random.randint(100,255) 41 | return '#%02X%02X%02X' % (r(), r(), r()) 42 | 43 | def __size(self, text): 44 | return ((len(text)+1)*9.7)-5 45 | 46 | def __draw_line(self, dwg, s_x , s_y, e_x, e_y, d_type, color, show_relations): 47 | if s_x > e_x: 48 | if s_x in x_o_diff_dict: 49 | x_o_diff_dict[s_x] -= 10 50 | else: 51 | x_o_diff_dict[s_x] = 10 52 | if e_x in x_i_diff_dict: 53 | x_i_diff_dict[e_x] += 10 54 | else: 55 | x_i_diff_dict[e_x] = 10 56 | s_x -= x_o_diff_dict[s_x] 57 | e_x += x_i_diff_dict[e_x] 58 | else: 59 | if s_x in x_o_diff_dict: 60 | x_o_diff_dict[s_x] += 10 61 | else: 62 | x_o_diff_dict[s_x] = 10 63 | if e_x in x_i_diff_dict: 64 | x_i_diff_dict[e_x] -= 10 65 | else: 66 | x_i_diff_dict[e_x] = 10 67 | s_x += x_o_diff_dict[s_x] 68 | e_x -= x_i_diff_dict[e_x] 69 | this_y_vals = list(range(min(s_x,e_x), max(s_x,e_x)+1)) 70 | this_y_vals = [ str(s_y)+'|'+str(i) for i in this_y_vals] 71 | common = set(this_y_vals) & set(overlap_hist) 72 | overlap_hist.extend(this_y_vals) 73 | if s_y not in y_hist_dict: 74 | y_hist_dict[s_y] = 20 75 | if common: 76 | y_hist_dict[s_y] += 20 77 | y_increase = y_hist_dict[s_y] 78 | if s_y == e_y: 79 | s_y -= 20 80 | e_y = s_y-4#55 81 | 82 | text_place_y = s_y-45 83 | 84 | line = dwg.add(dwg.polyline( 85 | [(s_x, s_y), (s_x, s_y-y_increase), (e_x, s_y-y_increase), 86 | (e_x, e_y), 87 | (e_x+2, e_y), 88 | (e_x, e_y+4), 89 | (e_x-2, e_y), 90 | (e_x, e_y) 91 | ], 92 | stroke=color, stroke_width = "2", fill='none',)) 93 | elif s_y >= e_y: 94 | e_y +=30 95 | s_y-=20 96 | text_place_y = s_y-(abs(s_y-e_y)/2) 97 | line = dwg.add(dwg.polyline( 98 | [(s_x, s_y),(s_x, s_y-y_increase), (e_x, s_y-y_increase), 99 | (e_x, e_y), 100 | (e_x+2, e_y), 101 | (e_x, e_y-4), 102 | (e_x-2, e_y), 103 | (e_x, e_y) 104 | ], 105 | stroke=color, stroke_width = "2", fill='none',)) 106 | 
else: 107 | s_y-=5 108 | e_y -= 40 109 | text_place_y = s_y+(abs(s_y-e_y)/2) 110 | line = dwg.add(dwg.polyline( 111 | [(s_x, s_y), 112 | (e_x, e_y-40), 113 | (e_x+2, e_y), 114 | (e_x, e_y+4), 115 | (e_x-2, e_y), 116 | (e_x, e_y) 117 | ], 118 | stroke=color, stroke_width = "2", fill='none',)) 119 | if show_relations: 120 | dwg.add(dwg.text(d_type, insert=(((s_x+e_x)/2)-(self.__size(d_type)/2.75), text_place_y), 121 | fill=color, font_size='12', font_family='courier')) 122 | 123 | def __gen_graph(self, rdf, selected_text, show_relations): 124 | rdf = [ i for i in rdf if i.result.lower().strip()!='o'] 125 | 126 | done_ent1 = {} 127 | done_ent2 = {} 128 | all_done = {} 129 | 130 | start_y = 75 131 | x_limit = 920 132 | y_offset = 100 133 | dwg = svgwrite.Drawing("temp.svg",profile='tiny', size = (x_limit, len(selected_text) * 1.1 + len(rdf)*20)) 134 | 135 | begin_index = 0 136 | start_x = 10 137 | this_line = 0 138 | 139 | all_entities_index = set() 140 | all_entities_1_index = [] 141 | basic_dict = {} 142 | relation_dict = {} 143 | for t in rdf: 144 | 145 | all_entities_index.add(int(t.metadata['entity1_begin'])) 146 | all_entities_index.add(int(t.metadata['entity2_begin'])) 147 | basic_dict[int(t.metadata['entity1_begin'])] = [t.metadata['entity1_begin'], 148 | t.metadata['entity1_end'], 149 | t.metadata['chunk1'], 150 | t.metadata['entity1']] 151 | 152 | basic_dict[int(t.metadata['entity2_begin'])] = [t.metadata['entity2_begin'], 153 | t.metadata['entity2_end'], 154 | t.metadata['chunk2'], 155 | t.metadata['entity2']] 156 | 157 | #all_entities_1_index.append(t[4]['entity1_begin']) 158 | all_entities_index = np.asarray(list(all_entities_index)) 159 | all_entities_index = all_entities_index[np.argsort(all_entities_index)] 160 | for ent_start_ind in all_entities_index: 161 | e_start_now, e_end_now, e_chunk_now, e_entity_now = basic_dict[ent_start_ind] 162 | prev_text = selected_text[begin_index:int(e_start_now)] 163 | begin_index = int(e_end_now)+1 164 | for word_ in 
prev_text.split(' '): 165 | this_size = self.__size(word_) 166 | if (start_x + this_size + 10) >= x_limit: 167 | start_y += y_offset 168 | start_x = 10 169 | this_line = 0 170 | dwg.add(dwg.text(word_, insert=(start_x, start_y ), fill='gray', font_size='16', font_family='courier')) 171 | start_x += this_size + 5 172 | 173 | this_size = self.__size(e_chunk_now) 174 | if (start_x + this_size + 10)>= x_limit:# or this_line >= 2: 175 | start_y += y_offset 176 | start_x = 10 177 | this_line = 0 178 | #chunk1 179 | dwg.add(dwg.text(e_chunk_now, insert=(start_x, start_y ), fill='gray', font_size='16', font_family='courier')) 180 | #rectange chunk 1 181 | dwg.add(dwg.rect(insert=(start_x-3, start_y-18),rx=3,ry=3, size=(this_size,25), stroke='orange', 182 | stroke_width='2', fill='none')) 183 | #entity 1 184 | central_point_x = start_x+(this_size/2) 185 | 186 | dwg.add(dwg.text(e_entity_now, 187 | insert=(central_point_x-(self.__size(e_entity_now)/2.75), start_y+20), 188 | fill='mediumseagreen', font_size='12', font_family='courier')) 189 | 190 | all_done[int(e_start_now)] = [central_point_x, start_y] 191 | start_x += this_size + 10 192 | this_line += 1 193 | 194 | #all_done[ent_start_ind] = 195 | 196 | prev_text = selected_text[begin_index:] 197 | for word_ in prev_text.split(' '): 198 | this_size = self.__size(word_) 199 | if (start_x + this_size)>= x_limit: 200 | start_y += y_offset 201 | start_x = 10 202 | dwg.add(dwg.text(word_, insert=(start_x, start_y ), fill='gray', font_size='16', font_family='courier')) 203 | start_x += this_size 204 | 205 | relation_distances = [] 206 | relation_coordinates = [] 207 | for row in rdf: 208 | if row.result.lower().strip() not in self.color_dict: 209 | self.color_dict[row.result.lower().strip()] = self.__get_color(row.result.lower().strip()) 210 | d_key2 = all_done[int(row.metadata['entity2_begin'])] 211 | d_key1 = all_done[int(row.metadata['entity1_begin'])] 212 | this_dist = abs(d_key2[0] - d_key1[0]) + abs (d_key2[1]-d_key1[1]) 
213 | relation_distances.append(this_dist) 214 | relation_coordinates.append((d_key2, d_key1, row.result)) 215 | 216 | relation_distances = np.array(relation_distances) 217 | relation_coordinates = np.array(relation_coordinates) 218 | temp_ind = np.argsort(relation_distances) 219 | relation_distances = relation_distances[temp_ind] 220 | relation_coordinates = relation_coordinates[temp_ind] 221 | for row in relation_coordinates: 222 | self.__draw_line(dwg, int(row[0][0]) , int(row[0][1]), int(row[1][0]), int(row[1][1]), 223 | row[2],self.color_dict[row[2].lower().strip()], show_relations) 224 | 225 | return dwg.tostring() 226 | 227 | def display(self, result, relation_col, document_col='document', show_relations=True): 228 | 229 | original_text = result[document_col][0].result 230 | res = result[relation_col] 231 | return display(HTML(self.__gen_graph(res, original_text, show_relations))) 232 | -------------------------------------------------------------------------------- /sparknlp_display/style.css: -------------------------------------------------------------------------------- 1 | @import url('https://fonts.googleapis.com/css2?family=Montserrat:wght@300;400;500;600;700&display=swap'); 2 | 3 | .scroll.entities { 4 | border: 1px solid #E7EDF0; 5 | border-radius: 3px; 6 | text-align: justify; 7 | } 8 | 9 | .scroll.entities span { 10 | font-size: 14px; 11 | line-height: 24px; 12 | color: #536B76; 13 | font-family: Montserrat, sans-serif !important; 14 | } 15 | .entity-wrapper { 16 | border-radius: 3px; 17 | padding: 1px; 18 | margin: 0 2px 5px 2px; 19 | } 20 | 21 | .scroll.entities span .entity-type { 22 | font-weight: 500; 23 | color: #ffffff; 24 | display: block; 25 | padding: 3px 5px; 26 | } 27 | 28 | .scroll.entities span .entity-name { 29 | border-radius: 3px; 30 | padding: 2px 5px; 31 | display: block; 32 | margin: 3px 2px; 33 | } 34 | 35 | 36 | div.scroll { 37 | line-height: 24px; 38 | } 39 | 40 | 
-------------------------------------------------------------------------------- /sparknlp_display/style_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | HTML_WRAPPER = """
{}
""" 4 | HTML_INDEX_WRAPPER = """
{}
""" 5 | 6 | STYLE_CONFIG_ENTITIES = f""" 7 | 86 | """ 87 | --------------------------------------------------------------------------------