├── .gitignore
├── EXTRACTORS.md
├── LICENSE.md
├── README.md
├── doc
│   ├── extraction_overview.png
│   ├── overview_horizontal.png
│   ├── wannadb_inspection.png
│   ├── wannadb_main.png
│   └── wannadb_matching.png
├── experiments
│   ├── automatic_feedback.py
│   ├── experiment_runner.py
│   └── util.py
├── main.py
├── pyproject.toml
├── requirements.txt
├── scripts
│   └── preprocess.py
├── tests
│   └── test_data.py
├── wannadb
│   ├── __init__.py
│   ├── configuration.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── data.py
│   │   └── signals.py
│   ├── interaction.py
│   ├── matching
│   │   ├── __init__.py
│   │   ├── custom_match_extraction.py
│   │   ├── distance.py
│   │   └── matching.py
│   ├── preprocessing
│   │   ├── __init__.py
│   │   ├── embedding.py
│   │   ├── extraction.py
│   │   ├── label_paraphrasing.py
│   │   ├── normalization.py
│   │   └── other_processing.py
│   ├── querying
│   │   ├── __init__.py
│   │   └── grouping.py
│   ├── resources.py
│   ├── statistics.py
│   └── status.py
├── wannadb_parsql
│   ├── __init__.py
│   ├── cache_db.py
│   ├── parsql.py
│   ├── rewrite.py
│   └── sql_tokens.py
└── wannadb_ui
    ├── __init__.py
    ├── common.py
    ├── document_base.py
    ├── interactive_matching.py
    ├── main_window.py
    ├── resources
    │   ├── confidence_high.svg
    │   ├── confidence_low.svg
    │   ├── correct.svg
    │   ├── folder.svg
    │   ├── idea.svg
    │   ├── incorrect.svg
    │   ├── info.svg
    │   ├── leave.svg
    │   ├── locate.svg
    │   ├── magnifier.svg
    │   ├── pencil.svg
    │   ├── plus.svg
    │   ├── redo.svg
    │   ├── run.svg
    │   ├── run_run.svg
    │   ├── save.svg
    │   ├── statistics.svg
    │   ├── statistics_folder.svg
    │   ├── statistics_incorrect.svg
    │   ├── statistics_save.svg
    │   ├── table.svg
    │   ├── text_cursor.svg
    │   ├── tools.svg
    │   ├── trash.svg
    │   └── two_documents.svg
    ├── start_menu.py
    └── wannadb_api.py

/.gitignore:
--------------------------------------------------------------------------------
1 | /venv*/
2 | /.idea/
3 | **/__pycache__/
4 | /.pytest_cache/
5 | 
6 | /models/
7 | /cache/
8 | *.json
9 | *.pdf
10 | *.bson
11 | 
--------------------------------------------------------------------------------
/EXTRACTORS.md:
--------------------------------------------------------------------------------
1 | ## WannaDB - Interactive Extraction
2 | 
3 | In previous research, we proposed an interactive system for extracting structured representations, but mainly concentrated on the alignment of information nuggets and target cells while assuming that there are off-the-shelf extractors suitable for finding the required nuggets in the text.
4 | With the latest additions to this repository, we extend that work, presenting an approach for going beyond these fixed sets of nuggets while leveraging the same feedback for both extraction and alignment.
5 | 
6 | ![Overview of the interactive extraction and matching as part of the overall system flow](doc/extraction_overview.png)
7 | 
8 | User feedback can be used to provide ad-hoc domain adaptation and find additional relevant extractions that were previously missed by the generic extractors.
9 | Once a user points the system to a missing extraction (and provides the relevant value by selecting it from the source document), the system should not only consider that custom span as a valid nugget but additionally try to find similar text spans in different documents and add these missing extractions to the vector space.
10 | Those additional extractions are then new candidates for creating the final structured representation, which in turn supports the exploration of the vector space.
11 | At the same time, the vector space is leveraged to steer the additional extraction process.
12 | 
13 | More details can be found in our paper *Benjamin Hättasch and Carsten Binnig. 2024. 
More of that, please: Domain Adaptation of Information Extraction through Examples & Feedback. In Workshop on Human-In-the-Loop Data Analytics (HILDA 24), June 14, 2024, Santiago, AA, Chile.*
14 | 
15 | We implemented a range of custom extractors that can be used to find similar text spans in different documents. The extractors are based on different approaches, such as exact matching, question answering, semantic similarity, and syntactic similarity.
16 | The implementations can be found in `matching/custom_match_extraction.py`, where the abstract base class `BaseCustomMatchExtractor` provides the structure that all extractors build upon. See below for a full list of all implemented extractors. In `wannadb_api.py`, the extractor to be used can be changed via the `find_additional_nuggets` attribute of the matching pipeline (see the sketch below). For all extractors except the `FAISS`-based ones, a `ParallelWrapper` is provided in `matching/custom_match_extraction.py`. This class can be wrapped around the extractor initialization, which makes the extractor invocation data-parallel by distributing the remaining documents over a team of threads.
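A minimal sketch of what this selection could look like (the class names are taken from the list below, but the exact constructor arguments and pipeline wiring live in `wannadb_api.py` and may differ):

```
# Sketch only: constructor arguments are omitted and the actual wiring into
# the matching pipeline is done in wannadb_api.py.
from wannadb.matching.custom_match_extraction import (
    FaissSentenceSimilarityExtractor,
    ParallelWrapper,
    SpacySimilarityExtractor,
)

# FAISS-based extractors are used directly ...
find_additional_nuggets = FaissSentenceSimilarityExtractor()

# ... while the other extractors can be wrapped to distribute the remaining
# documents over a team of threads:
find_additional_nuggets = ParallelWrapper(SpacySimilarityExtractor())
```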
17 | 
18 | ## List of extractors
19 | 
20 | 1. `ExactCustomMatchExtractor`: Extracts exact matches of the annotated span from the other documents. Corresponds to the status quo of WannaDB.
21 | 2. `VarianceExtractor`: Finds syntactic variants of the given text span.
22 | 3. `QuestionAnsweringCustomMatchExtractor`: Prompts the pretrained question-answering model `deepset/roberta-base-squad2` to extract a phrase similar to the selected span. With this, one match for each remaining document is retrieved and classified as a match if the extraction score exceeds a threshold.
23 | 4. `WordNetSimilarityCustomMatchExtractor`: Leverages `WordNet`, a semantic and lexical network which captures relationships between concepts, in order to extract words semantically similar to the selected span. To this end, the Wu-Palmer similarity between the match and each token of the remaining documents is computed, which quantifies the depth of the first common predecessor of the two concepts. If a high similarity is found, a span corresponding to the ngram structure of the input span is extracted around the match.
24 | 5. `FaissSemanticSimilarityExtractor`: Extracts semantically and syntactically similar spans to the match using the [FAISS](https://github.com/facebookresearch/faiss) library, allowing for high runtime efficiency even with a large number of documents and tokens. To this end, the embedding of every token is computed once and indexed using `FAISS`. If the embedding of a single token is found to be similar to the whole query, it is further examined by matching it to the ngram structure of the query. A threshold is used to determine whether a candidate ngram is sufficiently similar to classify it as a match.
25 | 6. `SpacySimilarityExtractor`: Similar to the FAISS extractor, this extractor computes the cosine similarity between the custom match and all tokens of the remaining documents, and extracts a similar span corresponding to the ngram structure of the query. The main distinction is that a spaCy corpus is used to embed all tokens. **Important**: If this extractor is to be used, the spaCy corpus `en_core_web_md` needs to be loaded beforehand, since the kernel requires a restart.
26 | 7. `NgramCustomMatchExtractor`: An older approach to custom extraction that works similarly to the `SpacySimilarityExtractor`. The main difference is that SBERT is used as the embedding model. However, the inference times are too high to be considered practical.
27 | 8. `ForestCustomMatchExtractor`: This extractor is based on the task of regex synthesis, where positive and negative examples of an attribute are used to produce a regex string which can be used to extract syntactically similar spans to the custom span. To this end, [FOREST](https://github.com/marghrid/FOREST) is used and integrated into WannaDB. However, since multiple examples are required for each attribute, and since many attributes lack close syntactic similarity, the synthesizer might not terminate. For this reason, the extractor has been removed from the main branch, but is still preserved on the `sb/forest-extractor` branch.
28 | 9. `FaissSentenceSimilarityExtractor`: Semantic similarity extraction using a two-stage approach (first sentence, then phrase level), sped up by FAISS.
29 | 10. `VarianceSemanticExtractor`: Combination of `VarianceExtractor` and `FaissSentenceSimilarityExtractor` to find both syntactic and semantic variants.
30 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # WannaDB: Ad-hoc SQL Queries over Text Collections
2 | 
3 | ![System overview of WannaDB, consisting of offline preprocessing and online table-filling phase](doc/overview_horizontal.png)
4 | 
5 | WannaDB allows users to explore unstructured text collections by automatically organizing the relevant information nuggets in a table. It supports ad-hoc SQL queries over text collections using a novel two-phased approach: First, a superset of information nuggets is extracted from the texts using existing extractors such as named entity recognizers. The extractions are then interactively matched to a structured table definition as requested by the user.
6 | 
7 | Watch our [demo video](https://link.tuda.systems/aset-video) or [read our paper](https://doi.org/10.18420/BTW2023-08) to learn more about the usage and underlying concepts.
8 | 
9 | ## GUI
10 | 
11 | Our tool is provided ready-to-use with a graphical user interface that allows users to load their own text collections and execute the table extraction on them.
12 | 
13 | ![Screenshot of the WannaDB GUI: Main Window](doc/wannadb_main.png)
14 | *Main Window: specify the query or attribute list, check the current table-filling state, and inspect the result table.*
15 | 
16 | ![Screenshot of the WannaDB GUI: Matching process](doc/wannadb_matching.png)
17 | *Interactive matching: inspect the current list of feedback requests and confirm matches or request detailed inspection.*
18 | 
19 | ![Screenshot of the WannaDB GUI: Inspecting a single document to confirm or fix a match](doc/wannadb_inspection.png)
20 | *Inspect a single document, see the extractions in context and choose the correct one (including custom text spans), or state that the document does not contain a relevant value.*
21 | 
22 | ## Usage
23 | 
24 | Run `main.py` to start the WannaDB GUI.
25 | 
26 | To run batch experiments instead, use `experiments/experiment_runner.py`.
27 | 
28 | To run the preprocessing stand-alone (e.g., on a server with GPUs), you can use `scripts/preprocess.py`.
29 | 
30 | See [EXTRACTORS.md](EXTRACTORS.md) for more information on the interactive extraction component.
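For example, a stand-alone preprocessing run could look like this (the input and output paths are placeholders; see `scripts/preprocess.py` for the full argument list):

```
python scripts/preprocess.py datasets/my_texts out --name my_document_base
```

This reads all text files from the input directory, runs the preprocessing pipeline, and writes `my_document_base.bson` to the output directory.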
31 | 32 | ## Installation 33 | 34 | This project requires Python 3.10 or newer. 35 | 36 | ##### 1. Create a virtual environment. 37 | 38 | ``` 39 | python -m venv venv 40 | source venv/bin/activate 41 | export PYTHONPATH="." 42 | ``` 43 | 44 | ##### 2. Install the dependencies. 45 | 46 | ``` 47 | pip install --upgrade pip 48 | pip install --use-pep517 -r requirements.txt 49 | pip install --use-pep517 pytest 50 | ``` 51 | 52 | You may have to install `torch` by hand if you want to use CUDA: 53 | 54 | https://pytorch.org/get-started/locally/ 55 | 56 | ##### 3. Run the tests. 57 | 58 | ``` 59 | pytest 60 | ``` 61 | 62 | ## Citing WannaDB 63 | 64 | The code in this repository is the result of several scientific publications. If you build upon WannaDB, please cite: 65 | 66 | ``` 67 | @inproceedings{wannadb@BTW23, 68 | author = {Hättasch, Benjamin AND Bodensohn, Jan-Micha AND Vogel, Liane AND Urban, Matthias AND Binnig, Carsten}, 69 | title = {WannaDB: Ad-hoc SQL Queries over Text Collections}, 70 | booktitle = {BTW 2023}, 71 | year = {2023}, 72 | editor = {König-Ries, Birgitta AND Scherzinger, Stefanie AND Lehner, Wolfgang AND Vossen, Gottfried} , 73 | doi = { 10.18420/BTW2023-08 }, 74 | publisher = {Gesellschaft für Informatik e.V.}, 75 | address = {} 76 | } 77 | ``` 78 | 79 | If you want to reference specific features/parts, our further publications might be relevant: 80 | 81 | ``` 82 | @inproceedings{aset@SIGMOD22, 83 | author = {H\"{a}ttasch, Benjamin and Bodensohn, Jan-Micha and Binnig, Carsten}, 84 | title = {Demonstrating ASET: Ad-Hoc Structured Exploration of Text Collections}, 85 | year = {2022}, 86 | isbn = {9781450392495}, 87 | publisher = {Association for Computing Machinery}, 88 | address = {New York, NY, USA}, 89 | url = {https://doi.org/10.1145/3514221.3520174}, 90 | doi = {10.1145/3514221.3520174}, 91 | abstract = {In this demo, we present ASET, a novel tool to explore the contents of unstructured data (text) by automatically transforming relevant parts into tabular form. ASET works in an ad-hoc manner without the need to curate extraction pipelines for the (unseen) text collection or to annotate large amounts of training data. The main idea is to use a new two-phased approach that first extracts a superset of information nuggets from the texts using existing extractors such as named entity recognizers. In a second step, it leverages embeddings and a novel matching strategy to match the extractions to a structured table definition as requested by the user. This demo features the ASET system with a graphical user interface that allows people without machine learning or programming expertise to explore text collections efficiently. This can be done in a self-directed and flexible manner, and ASET provides an intuitive impression of the result quality.}, 92 | booktitle = {Proceedings of the 2022 International Conference on Management of Data}, 93 | pages = {2393–2396}, 94 | numpages = {4}, 95 | keywords = {matching embeddings, text to table, interactive text exploration}, 96 | location = {Philadelphia, PA, USA}, 97 | series = {SIGMOD '22} 98 | } 99 | ``` 100 | 101 | ``` 102 | @inproceedings{aset@AIDB21, 103 | author = {H{\"a}ttasch, Benjamin and Bodensohn, Jan-Micha and Binnig, Carsten}, 104 | year = "2021", 105 | title = "ASET: Ad-hoc Structured Exploration of Text Collections", 106 | eventdate = "16.-20.08.2021", 107 | language = "en", 108 | booktitle = "3rd International Workshop on Applied AI for Database Systems and Applications (AIDB21). 
In conjunction with the 47th International Conference on Very Large Data Bases, Copenhagen, Denmark, August 16 - 20, 2021.",
109 |     location = "Copenhagen, Denmark"
110 | }
111 | ```
112 | 
113 | ```
114 | @inproceedings{wannadb@DESIRES21,
115 |     author = {H{\"{a}}ttasch, Benjamin},
116 |     title = "WannaDB: Ad-hoc Structured Exploration of Text Collections Using Queries",
117 |     booktitle = "Proceedings of the Second International Conference on Design of Experimental Search Information REtrieval Systems, Padova, Italy, September 15-18, 2021",
118 |     series = "{CEUR} Workshop Proceedings",
119 |     volume = "2950",
120 |     pages = "179--180",
121 |     publisher = "CEUR-WS.org",
122 |     year = "2021",
123 |     url = "http://ceur-ws.org/Vol-2950/paper-23.pdf",
124 |     timestamp = "Mon, 25 Oct 2021 15:03:55 +0200",
125 |     biburl = "https://dblp.org/rec/conf/desires/Hattasch21.bib",
126 |     bibsource = "dblp computer science bibliography, https://dblp.org"
127 | }
128 | ```
129 | 
130 | ## License
131 | 
132 | WannaDB is dually licensed under both AGPLv3 for the free usage by end users or the embedding in Open Source projects, and a commercial license for the integration in industrial projects and closed-source tool chains. More details can be found in [our license agreement](LICENSE.md).
133 | 
134 | 
135 | ## Availability of Code & Datasets
136 | 
137 | We publish the source code for our system as discussed in the papers here. Additionally, we publish code to reproduce our experiments in a separate repository (coming soon).
138 | 
139 | Unfortunately, we cannot publish the datasets online due to copyright issues. We will send them via email on request to everyone interested and hope they can be of benefit for other research, too.
140 | 
141 | 
142 | ## Implementation details
143 | 
144 | The core of WannaDB (extraction and matching) was previously developed by us under the name [ASET (Ad-hoc Structured Exploration of Text Collections)](https://link.tuda.systems/aset). To better reflect the whole application cycle vision we present with this paper, we switched the name to WannaDB.
145 | 
146 | ### Repository structure
147 | 
148 | This repository is structured as follows:
149 | 
150 | * `wannadb`, `wannadb_parsql`, and `wannadb_ui` contain the implementation of ASET and the GUI.
151 | * `scripts` contains helpers, like a stand-alone preprocessing script.
152 | * `tests` contains pytest tests.
153 | 
154 | ### Architecture: Core
155 | 
156 | The core implementation of WannaDB is in the `wannadb` package and implemented as a library. The implementation allows you to construct pipelines of different data processors that work with the data model and may involve user feedback.
157 | 
158 | **Data model**
159 | 
160 | `data` contains WannaDB's data model. The entities are `InformationNugget`s, `Attribute`s, `Document`s, and the `DocumentBase`.
161 | 
162 | A nugget is an information piece obtained from a document. An attribute is a table column that gets
163 | populated with information from the documents. A document is a textual document, and the document base is a collection of documents and provides facilities for `BSON` serialization, consistency checks, and data access.
164 | 
165 | `InformationNugget`s, `Attribute`s, and `Document`s can have `BaseSignal`s, which provide a way to easily store additional information with them. Each signal is identified by a unique identifier and implements serialization and deserialization. Furthermore, some signals may not be serialized. There are base implementations for different data types like floats or numpy arrays.
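A minimal sketch of working with the data model and signals (the access patterns mirror those exercised in `tests/test_data.py`):

```
from wannadb.data.data import Document, InformationNugget
from wannadb.data.signals import CachedDistanceSignal, LabelSignal

# a nugget is a character span [start_char, end_char) within a document
document = Document("document-0", "Wilhelm Conrad Röntgen was a German physicist.")
nugget = InformationNugget(document, 0, 22)  # covers "Wilhelm Conrad Röntgen"
document.nuggets.append(nugget)

# signals can be read and written via the signal class or its string identifier
nugget[LabelSignal] = "my-label"
nugget[CachedDistanceSignal] = CachedDistanceSignal(0.23)
assert nugget[LabelSignal.identifier] == "my-label"
assert nugget[CachedDistanceSignal] == 0.23
```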
166 | 
167 | **Configurations**
168 | 
169 | `configuration.py` contains the abstract pipeline code. A `Pipeline` allows you to execute multiple
170 | `BasePipelineElement`s one after the other. These pipeline elements work on a `DocumentBase` and receive a
171 | `BaseInteractionCallback` and `BaseStatusCallback` to facilitate user interactions and convey status updates.
172 | Furthermore, they receive a `Statistics` object that allows them to record information during runtime.
173 | 
174 | Both `BasePipelineElement`s and the `Pipeline` are `BaseConfigurableElement`s. This means that they come with a unique identifier and provide methods to instantiate them from a given configuration dictionary and to serialize their configuration as a dictionary.
175 | 
176 | Each `BasePipelineElement` specifies which `BaseSignal`s it requires and generates for the nuggets, attributes, and documents. This ensures the consistency of the pipeline. In other words, when a pipeline element is executed, all signals it requires must be set.
177 | 
178 | **Callbacks**
179 | 
180 | `interaction.py` and `status.py` contain `BaseInteractionCallback` and `BaseStatusCallback`, which allow the pipeline elements to request user interactions and convey status updates. They come with default implementations `InteractionCallback` and `StatusCallback` that receive a callback function when initialized, and `EmptyInteractionCallback` and `EmptyStatusCallback` that simply do nothing.
181 | 
182 | **Resources**
183 | 
184 | `resources.py` contains a resource manager that allows different parts of WannaDB to share resources like embeddings or transformer models. The module implements the singleton pattern, so there is always only one `ResourceManager` accessed via `resources.MANAGER`, which handles the loading, access, and unloading of `BaseResource`s. You should use a context manager (`with ResourceManager() as resource_manager:`) to ensure that all resources are properly closed when the program stops/crashes.
185 | 
186 | Each `BaseResource` comes with a unique identifier and implements methods for loading, unloading, and access.
187 | 
188 | **Statistics**
189 | 
190 | The `Statistics` object allows you to easily record information during runtime. It is handed from the `Pipeline` to the `BasePipelineElement`s, and from the `BasePipelineElement`s to other components like distance functions.
191 | 
192 | ### Architecture: GUI
193 | 
194 | The GUI implementation can be found in the `wannadb_ui` package. `wannadb_api.py` provides an asynchronous API for the `wannadb` library using PyQt's slots and signals mechanism. `main_window.py`, `document_base.py`, and `interactive_matching.py` contain different parts of the user interface, and `common.py` contains base classes for some recurring user interface elements.
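To see how the core pieces described above fit together, here is a condensed sketch distilled from `scripts/preprocess.py` (the real preprocessing pipeline registers several more pipeline elements, such as normalizers and embedders):

```
from wannadb.configuration import Pipeline
from wannadb.data.data import Document, DocumentBase
from wannadb.interaction import EmptyInteractionCallback
from wannadb.preprocessing.extraction import StanzaNERExtractor
from wannadb.resources import ResourceManager
from wannadb.statistics import Statistics
from wannadb.status import EmptyStatusCallback

with ResourceManager():  # ensures shared resources are loaded and closed properly
    document_base = DocumentBase(
        documents=[Document("document-0", "Wilhelm Conrad Röntgen was a German physicist.")],
        attributes=[]
    )
    statistics = Statistics(do_collect=True)
    pipeline = Pipeline([StanzaNERExtractor()])
    pipeline(
        document_base=document_base,
        interaction_callback=EmptyInteractionCallback(),
        status_callback=EmptyStatusCallback(),
        statistics=statistics
    )
```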
195 | 
--------------------------------------------------------------------------------
/doc/extraction_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataManagementLab/wannadb/05b36dbf87a66d0e022b4638d764cbdbc8e2c69d/doc/extraction_overview.png
--------------------------------------------------------------------------------
/doc/overview_horizontal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataManagementLab/wannadb/05b36dbf87a66d0e022b4638d764cbdbc8e2c69d/doc/overview_horizontal.png
--------------------------------------------------------------------------------
/doc/wannadb_inspection.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataManagementLab/wannadb/05b36dbf87a66d0e022b4638d764cbdbc8e2c69d/doc/wannadb_inspection.png
--------------------------------------------------------------------------------
/doc/wannadb_main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataManagementLab/wannadb/05b36dbf87a66d0e022b4638d764cbdbc8e2c69d/doc/wannadb_main.png
--------------------------------------------------------------------------------
/doc/wannadb_matching.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataManagementLab/wannadb/05b36dbf87a66d0e022b4638d764cbdbc8e2c69d/doc/wannadb_matching.png
--------------------------------------------------------------------------------
/experiments/automatic_feedback.py:
--------------------------------------------------------------------------------
1 | import random
2 | from typing import Dict, Any, List
3 | import logging
4 | 
5 | from wannadb.interaction import BaseInteractionCallback
6 | from util import consider_overlap_as_match, get_document_by_name
7 | 
8 | logger: logging.Logger = logging.getLogger(__name__)
9 | logger.setLevel(logging.INFO)
10 | 
11 | 
12 | class AutomaticRandomRankingBasedMatchingFeedback(BaseInteractionCallback):
13 |     """Interaction callback that gives feedback on a random nugget of the ranked list."""
14 | 
15 |     def __init__(self, documents: List[Dict[str, Any]], user_attribute_name2dataset_attribute_name: Dict[str, str]):
16 |         self._documents: List[Dict[str, Any]] = documents
17 |         self._user_attribute_name2dataset_attribute_name: Dict[str, str] = user_attribute_name2dataset_attribute_name
18 | 
19 |     def _call(self, pipeline_element_identifier: str, data: Dict[str, Any]) -> Dict[str, Any]:
20 |         if "do-attribute-request" in data.keys():
21 |             return {
22 |                 "do-attribute": True
23 |             }
24 |         nuggets = data["nuggets"]
25 |         attribute = data["attribute"]
26 | 
27 |         attribute_name = self._user_attribute_name2dataset_attribute_name[attribute.name]
28 | 
29 |         # randomly select a nugget to give feedback on
30 |         nugget = random.choice(nuggets)
31 |         document = None
32 |         for doc in self._documents:
33 |             # look up the dataset document that corresponds to the nugget's document
34 |             logger.debug(f"{doc['id']} vs. {nugget.document.name}")
35 |             if doc["id"] == nugget.document.name:
36 |                 document = doc
37 | 
38 |         # check whether nugget matches attribute
39 |         for mention in document["mentions"][attribute_name]:
40 |             if consider_overlap_as_match(mention["start_char"], mention["end_char"],
41 |                                          nugget.start_char, nugget.end_char):
42 |                 # the nugget matches the attribute
43 |                 logger.debug(f"{data['max-distance']:.2f} {attribute_name}: '{nugget.text}' ==> IS MATCH")
44 |                 return
{ 45 | "message": "is-match", 46 | "nugget": nugget, 47 | "not-a-match": None 48 | } 49 | 50 | # check if any other nugget matches attribute 51 | for nug in nugget.document.nuggets: 52 | for mention in document["mentions"][attribute_name]: 53 | if consider_overlap_as_match(mention["start_char"], mention["end_char"], 54 | nug.start_char, nug.end_char): 55 | # there is a matching nugget in nugget's document 56 | logger.debug( 57 | f"{data['max-distance']:.2f} {attribute_name}: '{nugget.text}' ==> RETURN OTHER MATCHING NUGGET '{nug.text}'") 58 | return { 59 | "message": "is-match", 60 | "nugget": nug, 61 | "not-a-match": nugget 62 | } 63 | 64 | # there is no matching nugget in nugget's document 65 | logger.debug(f"{data['max-distance']:.2f} {attribute_name}: '{nugget.text}' ==> NO MATCH IN DOCUMENT") 66 | return { 67 | "message": "no-match-in-document", 68 | "nugget": nugget, 69 | "not-a-match": nugget 70 | } 71 | 72 | 73 | class AutomaticFixFirstRankingBasedMatchingFeedback(BaseInteractionCallback): 74 | """Interaction callback that gives feedback on the first incorrect nugget of the ranked list.""" 75 | 76 | def __init__(self, documents: List[Dict[str, Any]], user_attribute_name2dataset_attribute_name: Dict[str, str]): 77 | self._documents: List[Dict[str, Any]] = documents 78 | self._user_attribute_name2dataset_attribute_name: Dict[str, str] = user_attribute_name2dataset_attribute_name 79 | 80 | def _call(self, pipeline_element_identifier: str, data: Dict[str, Any]) -> Dict[str, Any]: 81 | if "do-attribute-request" in data.keys(): 82 | return { 83 | "do-attribute": True 84 | } 85 | nuggets = data["nuggets"] 86 | attribute = data["attribute"] 87 | 88 | attribute_name = self._user_attribute_name2dataset_attribute_name[attribute.name] 89 | 90 | # iterate through nuggets of ranked list and give feedback on first incorrect one 91 | for nugget in nuggets: 92 | document = None 93 | for doc in self._documents: 94 | if doc["id"] == nugget.document.name: 95 | document = doc 96 | 97 | for mention in document["mentions"][attribute_name]: 98 | if consider_overlap_as_match(mention["start_char"], mention["end_char"], 99 | nugget.start_char, nugget.end_char): 100 | break 101 | else: 102 | # nugget is an incorrect guess 103 | for nug in nugget.document.nuggets: 104 | for men in document["mentions"][attribute_name]: 105 | if consider_overlap_as_match(men["start_char"], men["end_char"], 106 | nug.start_char, nug.end_char): 107 | # there is a matching nugget in nugget's document 108 | logger.debug( 109 | f"{data['max-distance']:.2f} {attribute_name}: '{nugget.text}' ==> RETURN OTHER MATCHING NUGGET '{nug.text}'") 110 | return { 111 | "message": "is-match", 112 | "nugget": nug, 113 | "not-a-match": nugget 114 | } 115 | 116 | # there is no matching nugget in nugget's document 117 | logger.debug(f"{data['max-distance']:.2f} {attribute_name}: '{nugget.text}' ==> NO MATCH IN DOCUMENT") 118 | return { 119 | "message": "no-match-in-document", 120 | "nugget": nugget, 121 | "not-a-match": nugget 122 | } 123 | 124 | # all nuggets are matches 125 | logger.debug(f"{data['max-distance']:.2f} {attribute_name}: '{nuggets[0].text}' ==> IS MATCH") 126 | return { 127 | "message": "is-match", 128 | "nugget": nuggets[0], 129 | "not-a-match": None 130 | } 131 | 132 | 133 | class AutomaticCustomMatchesRandomRankingBasedMatchingFeedback(BaseInteractionCallback): 134 | """Interaction callback that gives feedback on a random nugget of the ranked list and creates custom matches.""" 135 | 136 | def __init__(self, documents: 
List[Dict[str, Any]], user_attribute_name2dataset_attribute_name: Dict[str, str]): 137 | self._documents: List[Dict[str, Any]] = documents 138 | self._user_attribute_name2dataset_attribute_name: Dict[str, str] = user_attribute_name2dataset_attribute_name 139 | 140 | def _call(self, pipeline_element_identifier: str, data: Dict[str, Any]) -> Dict[str, Any]: 141 | if "do-attribute-request" in data.keys(): 142 | return { 143 | "do-attribute": True 144 | } 145 | nuggets = data["nuggets"] 146 | logger.debug(f"Nuggets: {', '.join(f'{n.document.name}: {n.text}' for n in data['nuggets'])}") 147 | attribute = data["attribute"] 148 | 149 | attribute_name = self._user_attribute_name2dataset_attribute_name[attribute.name] 150 | 151 | # randomly select a nugget to give feedback on 152 | nugget = random.choice(nuggets) 153 | document = get_document_by_name(self._documents, nugget.document.name) 154 | if document is None: 155 | logger.warning(f"Document {nugget.document.name} not found in documents.") 156 | 157 | # check whether nugget matches attribute 158 | for mention in document["mentions"][attribute_name]: 159 | if consider_overlap_as_match(mention["start_char"], mention["end_char"], 160 | nugget.start_char, nugget.end_char): 161 | # the nugget matches the attribute 162 | logger.debug(f"{data['max-distance']:.2f} {attribute_name}: '{nugget.text}' ==> IS MATCH") 163 | return { 164 | "message": "is-match", 165 | "nugget": nugget, 166 | "not-a-match": None 167 | } 168 | 169 | # check if any other nugget matches attribute 170 | for nug in nugget.document.nuggets: 171 | for mention in document["mentions"][attribute_name]: 172 | if consider_overlap_as_match(mention["start_char"], mention["end_char"], 173 | nug.start_char, nug.end_char): 174 | # there is a matching nugget in nugget's document 175 | logger.debug(f"{data['max-distance']:.2f} {attribute_name}: '{nugget.text}' ==> RETURN OTHER MATCHING NUGGET '{nug.text}'") 176 | return { 177 | "message": "is-match", 178 | "nugget": nug, 179 | "not-a-match": nugget 180 | } 181 | 182 | # there is no matching nugget in nugget's document 183 | 184 | # check if the value is mentioned in the document 185 | if document["mentions"][attribute_name] != []: 186 | # the value is mentioned in the document 187 | start_char = document["mentions"][attribute_name][0]["start_char"] 188 | end_char = document["mentions"][attribute_name][0]["end_char"] 189 | text = nugget.document.text[start_char:end_char] 190 | logger.debug(f"{data['max-distance']:.2f} {attribute_name}: '{nugget.text}' ==> RETURN CUSTOM MATCH '{text}'") 191 | return { 192 | "message": "custom-match", 193 | "document": nugget.document, 194 | "start": start_char, 195 | "end": end_char 196 | } 197 | 198 | # the value is not mentioned in the document 199 | logger.debug(f"{data['max-distance']:.2f} {attribute_name}: '{nugget.text}' ==> NO MATCH IN DOCUMENT") 200 | return { 201 | "message": "no-match-in-document", 202 | "nugget": nugget, 203 | "not-a-match": nugget 204 | } 205 | -------------------------------------------------------------------------------- /experiments/util.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | from wannadb.data.data import DocumentBase 4 | from wannadb.data.signals import NaturalLanguageLabelSignal, LabelSignal, CachedContextSentenceSignal 5 | from wannadb.statistics import Statistics 6 | 7 | 8 | def consider_overlap_as_match(true_start, true_end, pred_start, pred_end): 9 | """Determines whether the predicted span is 
considered a match of the true span.""" 10 | # considered as overlap if at least half of the larger span 11 | pred_length = pred_end - pred_start 12 | true_length = true_end - true_start 13 | 14 | valid_overlap = max(pred_length // 2, true_length // 2, 1) 15 | 16 | if pred_start <= true_start: 17 | actual_overlap = min(pred_end - true_start, true_length) 18 | else: 19 | actual_overlap = min(true_end - pred_start, pred_length) 20 | 21 | return actual_overlap >= valid_overlap 22 | 23 | 24 | def create_dataframes_attributes_nuggets(document_base: DocumentBase): 25 | for document in document_base.documents: 26 | attributes_and_matches_df = pd.DataFrame({ 27 | "attribute": document_base.attributes, # object ==> cannot be written to csv 28 | "raw_attribute_name": [attribute.name for attribute in document_base.attributes], 29 | "nl_attribute_name": [attribute[NaturalLanguageLabelSignal] for attribute in document_base.attributes], 30 | "matching_nuggets": [document.attribute_mappings[attribute.name] for attribute in 31 | document_base.attributes], # objects ==> cannot be written to csv 32 | "matching_nugget_texts": [[n.text for n in document.attribute_mappings[attribute.name]] for attribute in 33 | document_base.attributes] 34 | }) 35 | 36 | pd.set_option('display.max_columns', 500) 37 | pd.set_option('display.width', 1000) 38 | #print(attributes_and_matches_df) 39 | 40 | nuggets_df = pd.DataFrame({ 41 | "nugget": document.nuggets, # object ==> cannot be written to csv 42 | "raw_nugget_label": [nugget[LabelSignal] for nugget in document.nuggets], 43 | "nl_nugget_label": [nugget[NaturalLanguageLabelSignal] for nugget in document.nuggets], 44 | "nugget_text": [nugget.text for nugget in document.nuggets], 45 | "context_sentence": [nugget[CachedContextSentenceSignal]["text"] for nugget in document.nuggets], 46 | "start_char_in_context": [nugget[CachedContextSentenceSignal]["start_char"] for nugget in 47 | document.nuggets], 48 | "end_char_in_context": [nugget[CachedContextSentenceSignal]["end_char"] for nugget in document.nuggets] 49 | }) 50 | 51 | pd.set_option('display.max_columns', 500) 52 | pd.set_option('display.width', 1000) 53 | #print(nuggets_df) 54 | return attributes_and_matches_df, nuggets_df 55 | 56 | 57 | def calculate_f1_scores(results: Statistics): 58 | # compute the evaluation metrics per attribute 59 | 60 | # recall 61 | if (results["num_should_be_filled_is_correct"] + results["num_should_be_filled_is_incorrect"] + results["num_should_be_filled_is_empty"]) == 0: 62 | results["recall"] = 1 63 | else: 64 | results["recall"] = results["num_should_be_filled_is_correct"] / ( 65 | results["num_should_be_filled_is_correct"] + results["num_should_be_filled_is_incorrect"] + 66 | results["num_should_be_filled_is_empty"]) 67 | 68 | # precision 69 | if (results["num_should_be_filled_is_correct"] + results["num_should_be_filled_is_incorrect"] + results["num_should_be_empty_is_full"]) == 0: 70 | results["precision"] = 1 71 | else: 72 | results["precision"] = results["num_should_be_filled_is_correct"] / ( 73 | results["num_should_be_filled_is_correct"] + results["num_should_be_filled_is_incorrect"] + results["num_should_be_empty_is_full"]) 74 | 75 | # f1 score 76 | if results["precision"] + results["recall"] == 0: 77 | results["f1_score"] = 0 78 | else: 79 | results["f1_score"] = ( 80 | 2 * results["precision"] * results["recall"] / (results["precision"] + results["recall"])) 81 | 82 | # true negative rate 83 | if results["num_should_be_empty_is_empty"] + results["num_should_be_empty_is_full"] == 0: 
84 |         results["true_negative_rate"] = 1
85 |     else:
86 |         results["true_negative_rate"] = results["num_should_be_empty_is_empty"] / (results["num_should_be_empty_is_empty"] + results["num_should_be_empty_is_full"])
87 | 
88 |     # true positive rate
89 |     if results["num_should_be_filled_is_correct"] + results["num_should_be_filled_is_incorrect"] + results["num_should_be_filled_is_empty"] == 0:
90 |         results["true_positive_rate"] = 1
91 |     else:
92 |         results["true_positive_rate"] = results["num_should_be_filled_is_correct"] / (results["num_should_be_filled_is_correct"] + results["num_should_be_filled_is_incorrect"] + results["num_should_be_filled_is_empty"])
93 | 
94 | 
95 | def get_document_by_name(documents, doc_name):
96 |     doc_name = doc_name.split("\\")[-1]
97 |     # if the doc name has a file extension, remove it
98 |     if "." in doc_name:
99 |         doc_name = doc_name.split(".")[0]
100 |     for doc in documents:
101 |         if doc["id"] == doc_name:
102 |             return doc
103 |     return None
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import sys
3 | 
4 | from PyQt6.QtWidgets import QApplication
5 | 
6 | from wannadb.resources import ResourceManager
7 | from wannadb_ui.main_window import MainWindow
8 | 
9 | logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
10 | logger = logging.getLogger()
11 | 
12 | if __name__ == "__main__":
13 |     logger.info("Starting wannadb_ui.")
14 | 
15 |     with ResourceManager() as resource_manager:
16 |         # set up PyQt application
17 |         app = QApplication(sys.argv)
18 | 
19 |         window = MainWindow()
20 | 
21 |         sys.exit(app.exec())
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=42"]
3 | build-backend = "setuptools.build_meta"
4 | 
5 | [project]
6 | name = "wannadb"
7 | version = "0.0.1"
8 | authors = [
9 |     { name = "Benjamin Hättasch" },
10 |     { name = "Jan-Micha Bodensohn" },
11 |     { name = "Liane Vogel" },
12 | ]
13 | description = "WannaDB: Ad-hoc SQL Queries over Text Collections"
14 | readme = "README.md"
15 | license = { file = "LICENSE.md" }
16 | requires-python = ">=3.10"
17 | classifiers = [
18 |     "Programming Language :: Python :: 3",
19 | ]
20 | dependencies = [
21 |     "pymongo==4.8.0",
22 |     "torch==2.4.1",
23 |     "numpy==1.26.4",
24 |     "pandas==2.2.2",
25 |     "scipy==1.14.0",
26 |     "stanza==1.8.2",
27 |     "spacy==3.7.5",
28 |     "sentence-transformers[train]==3.0.1",
29 |     "matplotlib==3.9.2",
30 |     "seaborn==0.13.2",
31 |     "scikit-learn==1.5.1",
32 |     "transformers==4.43.3",
33 |     "PyQt6==6.7.1",
34 |     "sqlparse==0.5.1",
35 |     "faiss-cpu==1.8.0.post1",
36 |     "nltk==3.8.1",
37 | ]
38 | 
39 | [project.urls]
40 | "Homepage" = "https://github.com/DataManagementLab/wannadb"
41 | 
42 | [tool.setuptools]
43 | packages = [
44 |     "wannadb",
45 |     "wannadb_parsql",
46 |     "wannadb_ui",
47 | ]
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | #
2 | # This file is autogenerated by pip-compile with Python 3.12
3 | # by the following command:
4 | #
5 | #    pip-compile --output-file=requirements.txt
6 | #
7 | accelerate==0.34.2
8 |     # via sentence-transformers
9 | aiohappyeyeballs==2.4.0
10 |     # via aiohttp
11 | aiohttp==3.10.5
12 |     # via
13 | # datasets 14 | # fsspec 15 | aiosignal==1.3.1 16 | # via aiohttp 17 | annotated-types==0.7.0 18 | # via pydantic 19 | attrs==24.2.0 20 | # via aiohttp 21 | blis==0.7.11 22 | # via thinc 23 | catalogue==2.0.10 24 | # via 25 | # spacy 26 | # srsly 27 | # thinc 28 | certifi==2024.7.4 29 | # via requests 30 | charset-normalizer==3.3.2 31 | # via requests 32 | click==8.1.7 33 | # via 34 | # nltk 35 | # typer 36 | cloudpathlib==0.18.1 37 | # via weasel 38 | colorama==0.4.6 39 | # via 40 | # click 41 | # tqdm 42 | # wasabi 43 | confection==0.1.5 44 | # via 45 | # thinc 46 | # weasel 47 | contourpy==1.2.1 48 | # via matplotlib 49 | cycler==0.12.1 50 | # via matplotlib 51 | cymem==2.0.8 52 | # via 53 | # preshed 54 | # spacy 55 | # thinc 56 | datasets==3.0.0 57 | # via sentence-transformers 58 | dill==0.3.8 59 | # via 60 | # datasets 61 | # multiprocess 62 | dnspython==2.6.1 63 | # via pymongo 64 | emoji==2.12.1 65 | # via stanza 66 | faiss-cpu==1.8.0.post1 67 | # via wannadb (pyproject.toml) 68 | filelock==3.15.4 69 | # via 70 | # datasets 71 | # huggingface-hub 72 | # torch 73 | # transformers 74 | fonttools==4.53.1 75 | # via matplotlib 76 | frozenlist==1.4.1 77 | # via 78 | # aiohttp 79 | # aiosignal 80 | fsspec[http]==2024.6.1 81 | # via 82 | # datasets 83 | # huggingface-hub 84 | # torch 85 | huggingface-hub==0.24.5 86 | # via 87 | # accelerate 88 | # datasets 89 | # sentence-transformers 90 | # tokenizers 91 | # transformers 92 | idna==3.7 93 | # via 94 | # requests 95 | # yarl 96 | jinja2==3.1.4 97 | # via 98 | # spacy 99 | # torch 100 | joblib==1.4.2 101 | # via 102 | # nltk 103 | # scikit-learn 104 | kiwisolver==1.4.5 105 | # via matplotlib 106 | langcodes==3.4.0 107 | # via spacy 108 | language-data==1.2.0 109 | # via langcodes 110 | marisa-trie==1.2.0 111 | # via language-data 112 | markdown-it-py==3.0.0 113 | # via rich 114 | markupsafe==2.1.5 115 | # via jinja2 116 | matplotlib==3.9.2 117 | # via 118 | # seaborn 119 | # wannadb (pyproject.toml) 120 | mdurl==0.1.2 121 | # via markdown-it-py 122 | mpmath==1.3.0 123 | # via sympy 124 | multidict==6.1.0 125 | # via 126 | # aiohttp 127 | # yarl 128 | multiprocess==0.70.16 129 | # via datasets 130 | murmurhash==1.0.10 131 | # via 132 | # preshed 133 | # spacy 134 | # thinc 135 | networkx==3.3 136 | # via 137 | # stanza 138 | # torch 139 | nltk==3.8.1 140 | # via wannadb (pyproject.toml) 141 | numpy==1.26.4 142 | # via 143 | # accelerate 144 | # blis 145 | # contourpy 146 | # datasets 147 | # faiss-cpu (requires numpy<2.0,>=1.0) 148 | # matplotlib 149 | # pandas 150 | # pyarrow 151 | # scikit-learn 152 | # scipy 153 | # seaborn 154 | # sentence-transformers 155 | # spacy 156 | # stanza 157 | # thinc (requires numpy<2.0,>=1.19.0) 158 | # transformers 159 | # wannadb (pyproject.toml) 160 | packaging==24.1 161 | # via 162 | # accelerate 163 | # datasets 164 | # faiss-cpu 165 | # huggingface-hub 166 | # matplotlib 167 | # spacy 168 | # thinc 169 | # transformers 170 | # weasel 171 | pandas==2.2.2 172 | # via 173 | # datasets 174 | # seaborn 175 | # wannadb (pyproject.toml) 176 | pillow==10.4.0 177 | # via 178 | # matplotlib 179 | # sentence-transformers 180 | preshed==3.0.9 181 | # via 182 | # spacy 183 | # thinc 184 | protobuf==5.27.3 185 | # via stanza 186 | psutil==6.0.0 187 | # via accelerate 188 | pyarrow==17.0.0 189 | # via datasets 190 | pydantic==2.8.2 191 | # via 192 | # confection 193 | # spacy 194 | # thinc 195 | # weasel 196 | pydantic-core==2.20.1 197 | # via pydantic 198 | pygments==2.18.0 199 | # via rich 200 | 
pymongo==4.8.0 201 | # via wannadb (pyproject.toml) 202 | pyparsing==3.1.2 203 | # via matplotlib 204 | pyqt6==6.7.1 205 | # via wannadb (pyproject.toml) 206 | pyqt6-qt6==6.7.2 207 | # via pyqt6 208 | pyqt6-sip==13.8.0 209 | # via pyqt6 210 | python-dateutil==2.9.0.post0 211 | # via 212 | # matplotlib 213 | # pandas 214 | pytz==2024.1 215 | # via pandas 216 | pyyaml==6.0.1 217 | # via 218 | # accelerate 219 | # datasets 220 | # huggingface-hub 221 | # transformers 222 | regex==2024.7.24 223 | # via 224 | # nltk 225 | # transformers 226 | requests==2.32.3 227 | # via 228 | # datasets 229 | # huggingface-hub 230 | # spacy 231 | # stanza 232 | # transformers 233 | # weasel 234 | rich==13.7.1 235 | # via typer 236 | safetensors==0.4.3 237 | # via 238 | # accelerate 239 | # transformers 240 | scikit-learn==1.5.1 241 | # via 242 | # sentence-transformers 243 | # wannadb (pyproject.toml) 244 | scipy==1.14.0 245 | # via 246 | # scikit-learn 247 | # sentence-transformers 248 | # wannadb (pyproject.toml) 249 | seaborn==0.13.2 250 | # via wannadb (pyproject.toml) 251 | sentence-transformers[train]==3.0.1 252 | # via wannadb (pyproject.toml) 253 | shellingham==1.5.4 254 | # via typer 255 | six==1.16.0 256 | # via python-dateutil 257 | smart-open==6.4.0 258 | # via weasel 259 | spacy==3.7.5 260 | # via wannadb (pyproject.toml) 261 | spacy-legacy==3.0.12 262 | # via spacy 263 | spacy-loggers==1.0.5 264 | # via spacy 265 | sqlparse==0.5.1 266 | # via wannadb (pyproject.toml) 267 | srsly==2.4.8 268 | # via 269 | # confection 270 | # spacy 271 | # thinc 272 | # weasel 273 | stanza==1.8.2 274 | # via wannadb (pyproject.toml) 275 | sympy==1.13.1 276 | # via torch 277 | thinc==8.2.5 278 | # via spacy 279 | threadpoolctl==3.5.0 280 | # via scikit-learn 281 | tokenizers==0.19.1 282 | # via transformers 283 | toml==0.10.2 284 | # via stanza 285 | torch==2.4.1 286 | # via 287 | # accelerate 288 | # sentence-transformers 289 | # stanza 290 | # wannadb (pyproject.toml) 291 | tqdm==4.66.4 292 | # via 293 | # datasets 294 | # huggingface-hub 295 | # nltk 296 | # sentence-transformers 297 | # spacy 298 | # stanza 299 | # transformers 300 | transformers==4.43.3 301 | # via 302 | # sentence-transformers 303 | # wannadb (pyproject.toml) 304 | typer==0.12.3 305 | # via 306 | # spacy 307 | # weasel 308 | typing-extensions==4.12.2 309 | # via 310 | # emoji 311 | # huggingface-hub 312 | # pydantic 313 | # pydantic-core 314 | # torch 315 | # typer 316 | tzdata==2024.1 317 | # via pandas 318 | urllib3==2.2.2 319 | # via requests 320 | wasabi==1.1.3 321 | # via 322 | # spacy 323 | # thinc 324 | # weasel 325 | weasel==0.4.1 326 | # via spacy 327 | xxhash==3.5.0 328 | # via datasets 329 | yarl==1.12.1 330 | # via aiohttp 331 | 332 | # The following packages are considered to be unsafe in a requirements file: 333 | # setuptools 334 | -------------------------------------------------------------------------------- /scripts/preprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging.config 3 | import os 4 | from pathlib import Path 5 | 6 | from wannadb.configuration import Pipeline 7 | from wannadb.data.data import Document, DocumentBase 8 | from wannadb.interaction import EmptyInteractionCallback 9 | from wannadb.preprocessing.embedding import BERTContextSentenceEmbedder, RelativePositionEmbedder, SBERTTextEmbedder, SBERTLabelEmbedder 10 | from wannadb.preprocessing.extraction import StanzaNERExtractor, SpacyNERExtractor 11 | from 
wannadb.preprocessing.label_paraphrasing import OntoNotesLabelParaphraser, SplitAttributeNameLabelParaphraser 12 | from wannadb.preprocessing.normalization import CopyNormalizer 13 | from wannadb.preprocessing.other_processing import ContextSentenceCacher 14 | from wannadb.resources import ResourceManager 15 | from wannadb.statistics import Statistics 16 | from wannadb.status import EmptyStatusCallback 17 | 18 | 19 | def init_argparse() -> argparse.ArgumentParser: 20 | parser = argparse.ArgumentParser( 21 | usage="preprocess.py input_path output_path [OPTIONS]", 22 | description="Preprocess a collection of textual documents into a document base.", 23 | prog="WannaDB Preprocessing CLI", 24 | ) 25 | parser.add_argument( 26 | "-v", "--version", action="version", 27 | version=f"{parser.prog} version 0.9.0" 28 | ) 29 | parser.add_argument("input_path", help="Path containing the input files") 30 | parser.add_argument("output_path", 31 | help="Path where the resulting bson file should be placed. " 32 | "Will be created if it does not exist yet.") 33 | parser.add_argument('-n', '--name', required=False, 34 | help="Name of the serialized document base. " 35 | "Optional, if not specified 'document_base' will be used.") 36 | return parser 37 | 38 | 39 | def main() -> None: 40 | parser = init_argparse() 41 | args = parser.parse_args() 42 | print(args.input_path) 43 | print(args.output_path) 44 | 45 | logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") 46 | logger = logging.getLogger() 47 | 48 | dataset_name = args.name if args.name else "document_base" 49 | input_path = args.input_path 50 | output_path = args.output_path 51 | 52 | with ResourceManager(): 53 | documents = [] 54 | for filename in os.listdir(input_path): 55 | with open(os.path.join(input_path, filename), "r", encoding='utf-8') as infile: 56 | text = infile.read() 57 | documents.append(Document(filename.split(".")[0], text)) 58 | 59 | logger.info(f"Loaded {len(documents)} documents") 60 | 61 | wannadb_pipeline = Pipeline([ 62 | StanzaNERExtractor(), 63 | SpacyNERExtractor("SpacyEnCoreWebLg"), 64 | ContextSentenceCacher(), 65 | CopyNormalizer(), 66 | OntoNotesLabelParaphraser(), 67 | SplitAttributeNameLabelParaphraser(do_lowercase=True, splitters=[" ", "_"]), 68 | SBERTLabelEmbedder("SBERTBertLargeNliMeanTokensResource"), 69 | SBERTTextEmbedder("SBERTBertLargeNliMeanTokensResource"), 70 | BERTContextSentenceEmbedder("BertLargeCasedResource"), 71 | RelativePositionEmbedder() 72 | ]) 73 | 74 | document_base = DocumentBase(documents, []) 75 | 76 | statistics = Statistics(do_collect=True) 77 | statistics["preprocessing"]["config"] = wannadb_pipeline.to_config() 78 | 79 | wannadb_pipeline( 80 | document_base=document_base, 81 | interaction_callback=EmptyInteractionCallback(), 82 | status_callback=EmptyStatusCallback(), 83 | statistics=statistics["preprocessing"] 84 | ) 85 | 86 | Path(output_path).mkdir(parents=True, exist_ok=True) 87 | with open(os.path.join(output_path, f"{dataset_name}.bson"), "wb") as file: 88 | file.write(document_base.to_bson()) 89 | 90 | 91 | if __name__ == "__main__": 92 | main() 93 | -------------------------------------------------------------------------------- /tests/test_data.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import pytest 4 | 5 | from wannadb.data.data import Attribute, Document, DocumentBase, InformationNugget 6 | from wannadb.data.signals import CachedDistanceSignal, 
LabelSignal, SentenceStartCharsSignal, CurrentMatchIndexSignal 7 | 8 | 9 | @pytest.fixture 10 | def documents() -> List[Document]: 11 | return [ 12 | Document( 13 | "document-0", 14 | "Wilhelm Conrad Röntgen (/ˈrɛntɡən, -dʒən, ˈrʌnt-/; [ˈvɪlhɛlm ˈʁœntɡən]; 27 March 1845 – 10 " 15 | "February 1923) was a German physicist, who, on 8 November 1895, produced and detected " 16 | "electromagnetic radiation in a wavelength range known as X-rays or Röntgen rays, an achievement " 17 | "that earned him the first Nobel Prize in Physics in 1901. In honour of his accomplishments, in " 18 | "2004 the International Union of Pure and Applied Chemistry (IUPAC) named element 111, " 19 | "roentgenium, a radioactive element with multiple unstable isotopes, after him." 20 | ), 21 | Document( 22 | "document-1", 23 | "Wilhelm Carl Werner Otto Fritz Franz Wien ([ˈviːn]; 13 January 1864 – 30 August 1928) was a " 24 | "German physicist who, in 1893, used theories about heat and electromagnetism to deduce Wien's " 25 | "displacement law, which calculates the emission of a blackbody at any temperature from the " 26 | "emission at any one reference temperature. He also formulated an expression for the black-body " 27 | "radiation which is correct in the photon-gas limit. His arguments were based on the notion of " 28 | "adiabatic invariance, and were instrumental for the formulation of quantum mechanics. Wien " 29 | "received the 1911 Nobel Prize for his work on heat radiation. He was a cousin of Max Wien, " 30 | "inventor of the Wien bridge." 31 | ), 32 | Document( 33 | "document-2", 34 | "Heike Kamerlingh Onnes ([ˈɔnəs]; 21 September 1853 – 21 February 1926) was a Dutch physicist and " 35 | "Nobel laureate. He exploited the Hampson-Linde cycle to investigate how materials behave when " 36 | "cooled to nearly absolute zero and later to liquefy helium for the first time. His production " 37 | "of extreme cryogenic temperatures led to his discovery of superconductivity in 1911: for " 38 | "certain materials, electrical resistance abruptly vanishes at very low temperatures." 39 | ), 40 | Document( 41 | "document-2", 42 | "Heike Kamerlingh Onnes ([ˈɔnəs]; 21 September 1853 – 21 February 1926) was a Dutch physicist and " 43 | "Nobel laureate. He exploited the Hampson-Linde cycle to investigate how materials behave when " 44 | "cooled to nearly absolute zero and later to liquefy helium for the first time. His production " 45 | "of extreme cryogenic temperatures led to his discovery of superconductivity in 1911: for " 46 | "certain materials, electrical resistance abruptly vanishes at very low temperatures." 
47 | ) 48 | ] 49 | 50 | 51 | @pytest.fixture 52 | def information_nuggets(documents) -> List[InformationNugget]: 53 | return [ 54 | InformationNugget(documents[0], 0, 22), 55 | InformationNugget(documents[0], 56, 123), 56 | InformationNugget(documents[1], 165, 176), 57 | InformationNugget(documents[1], 234, 246), 58 | InformationNugget(documents[2], 434, 456), 59 | InformationNugget(documents[2], 123, 234), 60 | InformationNugget(documents[2], 123, 234) 61 | ] 62 | 63 | 64 | @pytest.fixture 65 | def attributes() -> List[Attribute]: 66 | return [ 67 | Attribute("name"), 68 | Attribute("field"), 69 | Attribute("field") 70 | ] 71 | 72 | 73 | @pytest.fixture 74 | def document_base(documents, information_nuggets, attributes) -> DocumentBase: 75 | # link nuggets to documents 76 | for nugget in information_nuggets: 77 | nugget.document.nuggets.append(nugget) 78 | 79 | # set up some dummy attribute mappings for documents 0 and 1 80 | documents[0].attribute_mappings[attributes[0].name] = [information_nuggets[0]] 81 | documents[0].attribute_mappings[attributes[1].name] = [] 82 | 83 | return DocumentBase( 84 | documents=documents[:-1], 85 | attributes=attributes[:-1] 86 | ) 87 | 88 | 89 | def test_information_nugget(documents, information_nuggets, attributes, document_base) -> None: 90 | # test __eq__ 91 | assert information_nuggets[0] == information_nuggets[0] 92 | assert information_nuggets[0] != information_nuggets[1] 93 | assert information_nuggets[1] != information_nuggets[0] 94 | assert information_nuggets[0] != object() 95 | assert object() != information_nuggets[0] 96 | 97 | # test __str__ and __repr__ and __hash__ 98 | for nugget in information_nuggets: 99 | assert str(nugget) == f"'{nugget.text}'" 100 | assert repr(nugget) == f"InformationNugget({repr(nugget.document)}, {nugget.start_char}, {nugget.end_char})" 101 | assert hash(nugget) == hash((nugget.document, nugget.start_char, nugget.end_char)) 102 | 103 | # test document 104 | assert information_nuggets[0].document is documents[0] 105 | 106 | # test start_char and end_char 107 | assert information_nuggets[0].start_char == 0 108 | assert information_nuggets[0].end_char == 22 109 | 110 | # test text 111 | assert information_nuggets[0].text == "Wilhelm Conrad Röntgen" 112 | 113 | # test signals 114 | information_nuggets[5][LabelSignal] = "my-label-signal" 115 | assert information_nuggets[5].signals[LabelSignal.identifier].value == "my-label-signal" 116 | assert information_nuggets[5][LabelSignal.identifier] == "my-label-signal" 117 | assert information_nuggets[5][LabelSignal] == "my-label-signal" 118 | assert information_nuggets[5] != information_nuggets[6] 119 | assert information_nuggets[6] != information_nuggets[5] 120 | 121 | information_nuggets[5][LabelSignal] = "new-value" 122 | assert information_nuggets[5][LabelSignal] == "new-value" 123 | 124 | information_nuggets[5][LabelSignal.identifier] = "new-new-value" 125 | assert information_nuggets[5][LabelSignal] == "new-new-value" 126 | 127 | information_nuggets[5][LabelSignal] = LabelSignal("another-value") 128 | assert information_nuggets[5][LabelSignal] == "another-value" 129 | 130 | information_nuggets[5][CachedDistanceSignal] = CachedDistanceSignal(0.23) 131 | assert information_nuggets[5][CachedDistanceSignal] == 0.23 132 | 133 | 134 | def test_attribute(documents, information_nuggets, attributes, document_base) -> None: 135 | # test __eq__ 136 | assert attributes[0] == attributes[0] 137 | assert attributes[0] != attributes[1] 138 | assert attributes[1] != attributes[0] 139 | 
assert attributes[0] != object() 140 | assert object() != attributes[0] 141 | 142 | # test __str__ and __repr__ and __hash__ 143 | for attribute in attributes: 144 | assert str(attribute) == f"'{attribute.name}'" 145 | assert repr(attribute) == f"Attribute('{attribute.name}')" 146 | assert hash(attribute) == hash(attribute.name) 147 | 148 | # test name 149 | assert attributes[0].name == "name" 150 | 151 | # test signals 152 | attributes[1][LabelSignal] = "my-label-signal" 153 | assert attributes[1].signals[LabelSignal.identifier].value == "my-label-signal" 154 | assert attributes[1][LabelSignal.identifier] == "my-label-signal" 155 | assert attributes[1][LabelSignal] == "my-label-signal" 156 | assert attributes[1] != attributes[2] 157 | assert attributes[2] != attributes[1] 158 | 159 | attributes[1][LabelSignal] = "new-value" 160 | assert attributes[1][LabelSignal] == "new-value" 161 | 162 | attributes[1][LabelSignal.identifier] = "new-new-value" 163 | assert attributes[1][LabelSignal] == "new-new-value" 164 | 165 | attributes[1][LabelSignal] = LabelSignal("another-value") 166 | assert attributes[1][LabelSignal] == "another-value" 167 | 168 | attributes[1][CachedDistanceSignal] = CachedDistanceSignal(0.23) 169 | assert attributes[1][CachedDistanceSignal] == 0.23 170 | 171 | 172 | def test_document(documents, information_nuggets, attributes, document_base) -> None: 173 | # test __eq__ 174 | assert documents[0] == documents[0] 175 | assert documents[0] != documents[1] 176 | assert documents[1] != documents[0] 177 | assert documents[0] != object() 178 | assert object() != documents[0] 179 | 180 | # test __str__ and __repr__ and __hash__ 181 | for document in documents: 182 | assert str(document) == f"'{document.text}'" 183 | assert repr(document) == f"Document('{document.name}', '{document.text}')" 184 | assert hash(document) == hash(document.name) 185 | 186 | # test name 187 | assert documents[0].name == "document-0" 188 | 189 | # test text 190 | assert documents[0].text[:40] == "Wilhelm Conrad Röntgen (/ˈrɛntɡən, -dʒən" 191 | 192 | # test nuggets 193 | assert documents[0].nuggets == [information_nuggets[0], information_nuggets[1]] 194 | 195 | # test attribute mappings 196 | assert documents[0].attribute_mappings[attributes[0].name] == [information_nuggets[0]] 197 | assert documents[0].attribute_mappings[attributes[1].name] == [] 198 | 199 | # test signals 200 | documents[2][SentenceStartCharsSignal] = [0, 10, 20] 201 | assert documents[2].signals[SentenceStartCharsSignal.identifier].value == [0, 10, 20] 202 | assert documents[2][SentenceStartCharsSignal.identifier] == [0, 10, 20] 203 | assert documents[2][SentenceStartCharsSignal] == [0, 10, 20] 204 | assert documents[2] != documents[3] 205 | assert documents[3] != documents[2] 206 | 207 | documents[2][SentenceStartCharsSignal] = [1, 2, 3] 208 | assert documents[2][SentenceStartCharsSignal] == [1, 2, 3] 209 | 210 | documents[2][SentenceStartCharsSignal.identifier] = [3, 4, 5] 211 | assert documents[2][SentenceStartCharsSignal] == [3, 4, 5] 212 | 213 | documents[2][SentenceStartCharsSignal.identifier] = SentenceStartCharsSignal([6, 7]) 214 | assert documents[2][SentenceStartCharsSignal] == [6, 7] 215 | 216 | documents[2][CurrentMatchIndexSignal] = CurrentMatchIndexSignal(2) 217 | assert documents[2][CurrentMatchIndexSignal] == 2 218 | 219 | 220 | def test_document_base(documents, information_nuggets, attributes, document_base) -> None: 221 | # test __eq__ 222 | assert document_base == document_base 223 | assert document_base != 
DocumentBase(documents, attributes[:1]) 224 | assert DocumentBase(documents, attributes[:1]) != document_base 225 | assert document_base != object() 226 | assert object() != document_base 227 | 228 | # test __str__ 229 | assert str(document_base) == "(3 documents, 7 nuggets, 2 attributes)" 230 | 231 | # test __repr__ 232 | assert repr(document_base) == "DocumentBase([{}], [{}])".format( 233 | ", ".join(repr(document) for document in document_base.documents), 234 | ", ".join(repr(attribute) for attribute in document_base.attributes) 235 | ) 236 | 237 | # test documents 238 | assert document_base.documents == documents[:-1] 239 | 240 | # test attributes 241 | assert document_base.attributes == attributes[:-1] 242 | 243 | # test nuggets 244 | assert document_base.nuggets == information_nuggets 245 | 246 | # test to_table_dict 247 | assert document_base.to_table_dict() == { 248 | "document-name": ["document-0", "document-1", "document-2"], 249 | "name": [[information_nuggets[0]], None, None], 250 | "field": [[], None, None] 251 | } 252 | 253 | assert document_base.to_table_dict("text") == { 254 | "document-name": ["document-0", "document-1", "document-2"], 255 | "name": [["Wilhelm Conrad Röntgen"], None, None], 256 | "field": [[], None, None] 257 | } 258 | 259 | assert document_base.to_table_dict("value") == { 260 | "document-name": ["document-0", "document-1", "document-2"], 261 | "name": [[None], None, None], 262 | "field": [[], None, None] 263 | } 264 | 265 | # test get_nuggets_for_attribute 266 | assert document_base.get_nuggets_for_attribute(attributes[0]) == [information_nuggets[0]] 267 | 268 | # test get_column_for_attribute 269 | assert document_base.get_column_for_attribute(attributes[0]) == [[information_nuggets[0]], None, None] 270 | 271 | # test validate_consistency 272 | assert document_base.validate_consistency() 273 | 274 | # test to_bson and from_bson 275 | bson_bytes: bytes = document_base.to_bson() 276 | copied_document_base: DocumentBase = DocumentBase.from_bson(bson_bytes) 277 | assert document_base == copied_document_base 278 | assert copied_document_base == document_base 279 | -------------------------------------------------------------------------------- /wannadb/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataManagementLab/wannadb/05b36dbf87a66d0e022b4638d764cbdbc8e2c69d/wannadb/__init__.py -------------------------------------------------------------------------------- /wannadb/configuration.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import logging 3 | import time 4 | from typing import Any, Dict, List, Type 5 | 6 | from wannadb.data.data import DocumentBase 7 | from wannadb.interaction import BaseInteractionCallback 8 | from wannadb.statistics import Statistics 9 | from wannadb.status import BaseStatusCallback 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | CONFIGURABLE_ELEMENTS: Dict[str, Type["BaseConfigurableElement"]] = {} 14 | 15 | 16 | def register_configurable_element( 17 | configurable_element: Type["BaseConfigurableElement"] 18 | ) -> Type["BaseConfigurableElement"]: 19 | """Register the given configurable element class.""" 20 | CONFIGURABLE_ELEMENTS[configurable_element.identifier] = configurable_element 21 | return configurable_element 22 | 23 | 24 | class BaseConfigurableElement(abc.ABC): 25 | """ 26 | Base class for all configurable elements. 27 | 28 | A configurable element is a class (e.g. 
pipeline element or pipeline) that can be configured. The element's 29 | configuration must be serializable ('to_config'), and the exact same element must be reproducible from its 30 | serialized configuration ('from_config'). 31 | 32 | Furthermore, each kind of configurable element must be identifiable by a unique identifier. 33 | """ 34 | identifier: str = "BaseConfigurableElement" 35 | 36 | def __repr__(self): 37 | return f"{self.__class__.__name__}()" 38 | 39 | def __str__(self): 40 | return self.identifier 41 | 42 | def __eq__(self, other): 43 | return isinstance(other, self.__class__) and self.identifier == other.identifier 44 | 45 | def __hash__(self) -> int: 46 | return hash(self.identifier) 47 | 48 | @abc.abstractmethod 49 | def to_config(self) -> Dict[str, Any]: 50 | """ 51 | Obtain a JSON-serializable representation of the element. 52 | 53 | :return: JSON-serializable representation of the element 54 | """ 55 | raise NotImplementedError 56 | 57 | @classmethod 58 | def from_config(cls, config: Dict[str, Any]) -> "BaseConfigurableElement": 59 | """ 60 | Create the element from its JSON-serializable representation. 61 | 62 | :param config: JSON-serializable representation of the element 63 | :return: element created from the JSON-serializable representation 64 | """ 65 | return CONFIGURABLE_ELEMENTS[config["identifier"]].from_config(config) 66 | 67 | 68 | class BasePipelineElement(BaseConfigurableElement, abc.ABC): 69 | """ 70 | Base class for all pipeline elements. 71 | 72 | A pipeline element is a class (e.g. an extractor, embedder, or matcher) that can be applied ('__call__') to a 73 | document base as part of a pipeline. As such, it works with the data elements' signals. Each pipeline element 74 | specifies the signals it consumes and the signals it produces for each kind of data element. 75 | 76 | A pipeline element is a configurable element. 77 | """ 78 | 79 | # identifiers of the signals that the pipeline element requires for nuggets, attributes, and documents 80 | # signals the pipeline element may use if they exist but does not necessarily require are not part of this list 81 | required_signal_identifiers: Dict[str, List[str]] = { 82 | "nuggets": [], 83 | "attributes": [], 84 | "documents": [] 85 | } 86 | 87 | # identifiers of the signals that the pipeline element generates for all nuggets, attributes, and documents 88 | generated_signal_identifiers: Dict[str, List[str]] = { 89 | "nuggets": [], 90 | "attributes": [], 91 | "documents": [] 92 | } 93 | 94 | def _add_required_signal_identifiers(self, required_signal_identifiers: Dict[str, List[str]]) -> None: 95 | """ 96 | Helper method that adds the dictionary of required signal identifiers to this pipeline element's dictionary of 97 | required signal identifiers. 98 | 99 | :param required_signal_identifiers: dictionary of required signal identifiers 100 | """ 101 | for data_element in ["nuggets", "attributes", "documents"]: 102 | ids: List[str] = self.required_signal_identifiers[data_element] + required_signal_identifiers[data_element] 103 | self.required_signal_identifiers[data_element] = list(set(ids)) 104 | 105 | def __call__( 106 | self, 107 | document_base: DocumentBase, 108 | interaction_callback: BaseInteractionCallback, 109 | status_callback: BaseStatusCallback, 110 | statistics: Statistics 111 | ) -> None: 112 | """ 113 | Apply the pipeline element to the document base. 114 | 115 | This method is called by the pipeline and calls the _call method that contains the actual implementation of the 116 | pipeline element. 
Furthermore, it ensures that the proper status is communicated before and after the pipeline 117 | element's execution and tracks the execution time. 118 | 119 | :param document_base: document base to work on 120 | :param interaction_callback: callback to allow for user interaction 121 | :param status_callback: callback to communicate current status (message and progress) 122 | :param statistics: statistics object to collect statistics 123 | """ 124 | logger.info(f"Execute {self.identifier}.") 125 | tick: float = time.time() 126 | status_callback(f"Running {self.identifier}...", -1) 127 | 128 | statistics["identifier"] = self.identifier 129 | 130 | self._call(document_base, interaction_callback, status_callback, statistics) 131 | 132 | status_callback(f"Running {self.identifier}...", 1) 133 | tack: float = time.time() 134 | logger.info(f"Executed {self.identifier} in {tack - tick} seconds.") 135 | statistics["runtime"] = tack - tick 136 | 137 | @abc.abstractmethod 138 | def _call( 139 | self, 140 | document_base: DocumentBase, 141 | interaction_callback: BaseInteractionCallback, 142 | status_callback: BaseStatusCallback, 143 | statistics: Statistics 144 | ) -> None: 145 | """ 146 | Apply the pipeline element to the document base. 147 | 148 | This method is overwritten by the actual pipeline elements and contains their implementation. 149 | 150 | :param document_base: document base to work on 151 | :param interaction_callback: callback to allow for user interaction 152 | :param status_callback: callback to communicate current status (message and progress) 153 | :param statistics: statistics object to collect statistics 154 | """ 155 | raise NotImplementedError 156 | 157 | def _use_status_callback(self, status_callback: BaseStatusCallback, ix: int, total: int) -> None: 158 | """ 159 | Helper method that calls the status callback at regular intervals. 160 | 161 | :param status_callback: callback to communicate current status (message and progress) 162 | :param ix: index of the current element 163 | :param total: total number of elements 164 | """ 165 | if total == 0: 166 | status_callback(f"Running {self.identifier}...", -1) 167 | elif ix == 0: 168 | status_callback(f"Running {self.identifier}...", 0) 169 | else: 170 | interval: int = total // 20 171 | if interval != 0 and ix % interval == 0: 172 | status_callback(f"Running {self.identifier}...", ix / total) 173 | 174 | 175 | class Pipeline(BaseConfigurableElement): 176 | """ 177 | Pipeline that applies pipeline elements to a document base. 178 | 179 | The pipeline can be applied ('__call__') to a document base. 180 | 181 | A pipeline is a configurable element. 182 | """ 183 | identifier: str = "Pipeline" 184 | 185 | def __init__(self, pipeline_elements: List[BasePipelineElement]) -> None: 186 | """ 187 | Initialize the Pipeline. 188 | 189 | :param pipeline_elements: list of pipeline elements that make up the pipeline 190 | """ 191 | super(Pipeline, self).__init__() 192 | self._pipeline_elements: List[BasePipelineElement] = pipeline_elements 193 | 194 | logger.debug("Initialized the pipeline.") 195 | 196 | def validate_consistency(self, initial_signals: Dict[str, List[str]]) -> bool: 197 | """ 198 | Validate the consistency of the pipeline regarding required and generated signals. 199 | 200 | This method checks for each pipeline element whether the signals it requires are actually present in the 201 | document base. 
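A minimal usage sketch of this check (illustrative only; `extractor` and `embedder` stand in for concrete `BasePipelineElement` instances and are not defined in this module):

    pipeline = Pipeline([extractor, embedder])
    ok = pipeline.validate_consistency({"nuggets": [], "attributes": [], "documents": []})
    # ok is False if, e.g., 'embedder' requires a nugget signal that neither
    # the initial signals nor 'extractor' provides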
202 | 203 | :param initial_signals: signals that exist in the document base before the pipeline is executed 204 | :return: True if the pipeline is consistent, else False 205 | """ 206 | current_signals: Dict[str, List[str]] = initial_signals 207 | 208 | for pipeline_element in self._pipeline_elements: 209 | for data_element in ["nuggets", "documents", "attributes"]: 210 | # check that all required signals exist before the pipeline element is executed 211 | for signal_identifier in pipeline_element.required_signal_identifiers[data_element]: 212 | if signal_identifier not in current_signals[data_element]: 213 | return False 214 | 215 | # add the newly generated signals to the current signals 216 | for signal_identifier in pipeline_element.generated_signal_identifiers[data_element]: 217 | if signal_identifier not in current_signals[data_element]: 218 | current_signals[data_element].append(signal_identifier) 219 | 220 | return True 221 | 222 | @property 223 | def pipeline_elements(self) -> List[BasePipelineElement]: 224 | return self._pipeline_elements 225 | 226 | def __str__(self) -> str: 227 | return f"({', '.join(str(pipeline_element) for pipeline_element in self._pipeline_elements)})" 228 | 229 | def __eq__(self, other) -> bool: 230 | return isinstance(other, Pipeline) and self._pipeline_elements == other._pipeline_elements 231 | 232 | def __call__( 233 | self, 234 | document_base: DocumentBase, 235 | interaction_callback: BaseInteractionCallback, 236 | status_callback: BaseStatusCallback, 237 | statistics: Statistics 238 | ) -> None: 239 | """ 240 | Apply the pipeline to the document base. 241 | 242 | :param document_base: document base to work on 243 | :param interaction_callback: callback to allow for user interaction 244 | :param status_callback: callback to communicate current status (message and progress) 245 | :param statistics: statistics object to collect statistics 246 | """ 247 | logger.info("Execute the pipeline.") 248 | tick: float = time.time() 249 | status_callback("Running the pipeline...", -1) 250 | 251 | for ix, pipeline_element in enumerate(self._pipeline_elements): 252 | pipeline_element(document_base, interaction_callback, status_callback, statistics[f"pipeline-element-{ix}"]) 253 | 254 | status_callback("Running the pipeline...", 1) 255 | tack: float = time.time() 256 | logger.info(f"Executed the pipeline in {tack - tick} seconds.") 257 | statistics["runtime"] = tack - tick 258 | 259 | def to_config(self) -> Dict[str, Any]: 260 | """ 261 | Obtain a JSON-serializable representation of the pipeline. 262 | 263 | :return: JSON-serializable representation of the pipeline 264 | """ 265 | return { 266 | "identifier": self.identifier, 267 | "pipeline_elements": [pipeline_element.to_config() for pipeline_element in self._pipeline_elements] 268 | } 269 | 270 | @classmethod 271 | def from_config(cls, config: Dict[str, Any]) -> "Pipeline": 272 | """ 273 | Create the pipeline from its JSON-serializable representation. 
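A hedged sketch of the round trip this enables (`some_pipeline` is assumed to be an already constructed `Pipeline` instance):

    config = some_pipeline.to_config()       # JSON-serializable dict
    restored = Pipeline.from_config(config)  # elements are rebuilt via the CONFIGURABLE_ELEMENTS registry
    assert restored == some_pipeline         # __eq__ compares the pipeline element lists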
274 | 275 | :param config: JSON-serializable representation of the pipeline 276 | :return: pipeline created from the JSON-serializable representation 277 | """ 278 | return cls( 279 | [BasePipelineElement.from_config(element_config) for element_config in config["pipeline_elements"]] 280 | ) 281 | -------------------------------------------------------------------------------- /wannadb/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataManagementLab/wannadb/05b36dbf87a66d0e022b4638d764cbdbc8e2c69d/wannadb/data/__init__.py -------------------------------------------------------------------------------- /wannadb/data/signals.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import io 3 | import logging 4 | from typing import Any, Dict, List, Type 5 | 6 | import numpy as np 7 | 8 | logger: logging.Logger = logging.getLogger(__name__) 9 | 10 | SIGNALS: Dict[str, Type["BaseSignal"]] = {} 11 | 12 | 13 | def register_signal(signal: Type["BaseSignal"]) -> Type["BaseSignal"]: 14 | """Register the given signal class.""" 15 | SIGNALS[signal.identifier] = signal 16 | return signal 17 | 18 | 19 | class BaseSignal(abc.ABC): 20 | """ 21 | Signals for InformationNuggets, Attributes, and Documents. 22 | 23 | Signals are values that can be set for a given data object (e.g. InformationNugget), but do not necessarily need to be set 24 | for every data object. Each pipeline element specifies which signals it requires for nuggets, attributes, and 25 | documents and which signals it generates. 26 | 27 | Signals also specify whether they should be serialized ('do_serialize') and provide their own serialization 28 | ('to_serializable') and deserialization ('from_serializable') implementation. 29 | 30 | Furthermore, each kind of signal must be identifiable by a unique identifier. 31 | """ 32 | identifier: str = "BaseSignal" 33 | do_serialize: bool = False 34 | 35 | def __init__(self, value: Any) -> None: 36 | """ 37 | Initialize the signal. 38 | 39 | :param value: value of the signal 40 | """ 41 | super(BaseSignal, self).__init__() 42 | self._value: Any = value 43 | 44 | def __str__(self) -> str: 45 | return str(self._value) 46 | 47 | def __repr__(self) -> str: 48 | return f"{self.__class__.__name__}({repr(self._value)})" 49 | 50 | def __eq__(self, other) -> bool: 51 | return isinstance(other, self.__class__) and self._value == other._value 52 | 53 | def __hash__(self) -> int: 54 | return hash(self.identifier) 55 | 56 | @property 57 | def value(self) -> Any: 58 | """Value of the signal.""" 59 | return self._value 60 | 61 | @value.setter 62 | def value(self, value: Any) -> None: 63 | """Set the value of the signal.""" 64 | self._value = value 65 | 66 | @abc.abstractmethod 67 | def to_serializable(self) -> Any: 68 | """ 69 | Convert the signal to a BSON-serializable representation. 70 | 71 | :return: BSON-serializable representation of the signal 72 | """ 73 | raise NotImplementedError 74 | 75 | @classmethod 76 | def from_serializable(cls, serialized_signal: Any, identifier: str) -> "BaseSignal": 77 | """ 78 | Create a signal from the BSON-serializable representation. 
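A minimal sketch of the intended extension pattern (the `ConfidenceSignal` class below is hypothetical and not part of this module):

    @register_signal
    class ConfidenceSignal(BaseFloatSignal):
        """Hypothetical confidence score attached to a nugget."""
        identifier: str = "ConfidenceSignal"
        do_serialize: bool = True

    signal = ConfidenceSignal(0.87)
    serialized = signal.to_serializable()  # -> 0.87
    restored = BaseSignal.from_serializable(serialized, "ConfidenceSignal")
    assert restored == signal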
79 | 80 | :param serialized_signal: BSON-serializable representation of the signal 81 | :param identifier: identifier of the signal kind 82 | :return: deserialized signal 83 | """ 84 | return SIGNALS[identifier].from_serializable(serialized_signal, identifier) 85 | 86 | 87 | class BaseUntypedSignal(BaseSignal, abc.ABC): 88 | """Base class for all untyped signals.""" 89 | identifier: str = "BaseUntypedSignal" 90 | do_serialize: bool = False 91 | 92 | @property 93 | def value(self) -> Any: 94 | return self._value 95 | 96 | @value.setter 97 | def value(self, value: Any) -> None: 98 | self._value = value 99 | 100 | def to_serializable(self) -> Any: 101 | return self.value 102 | 103 | @classmethod 104 | def from_serializable(cls, serialized_signal: Any, identifier: str) -> "BaseUntypedSignal": 105 | return cls(serialized_signal) 106 | 107 | 108 | class BaseIntSignal(BaseSignal, abc.ABC): 109 | """Base class for all integer signals.""" 110 | identifier: str = "BaseIntSignal" 111 | do_serialize: bool = False 112 | 113 | @property 114 | def value(self) -> int: 115 | return self._value 116 | 117 | @value.setter 118 | def value(self, value: int) -> None: 119 | self._value = value 120 | 121 | def to_serializable(self) -> int: 122 | return self.value 123 | 124 | @classmethod 125 | def from_serializable(cls, serialized_signal: int, identifier: str) -> "BaseIntSignal": 126 | return cls(serialized_signal) 127 | 128 | 129 | class BaseFloatSignal(BaseSignal, abc.ABC): 130 | """Base class for all float signals.""" 131 | identifier: str = "BaseFloatSignal" 132 | do_serialize: bool = False 133 | 134 | @property 135 | def value(self) -> float: 136 | return self._value 137 | 138 | @value.setter 139 | def value(self, value: float) -> None: 140 | self._value = value 141 | 142 | def to_serializable(self) -> float: 143 | return self.value 144 | 145 | @classmethod 146 | def from_serializable(cls, serialized_signal: float, identifier: str) -> "BaseFloatSignal": 147 | return cls(serialized_signal) 148 | 149 | 150 | class BaseStringSignal(BaseSignal, abc.ABC): 151 | """Base class for all string signals.""" 152 | identifier: str = "BaseStringSignal" 153 | do_serialize: bool = False 154 | 155 | @property 156 | def value(self) -> str: 157 | return self._value 158 | 159 | @value.setter 160 | def value(self, value: str) -> None: 161 | self._value = value 162 | 163 | def to_serializable(self) -> str: 164 | return self.value 165 | 166 | @classmethod 167 | def from_serializable(cls, serialized_signal: str, identifier: str) -> "BaseStringSignal": 168 | return cls(serialized_signal) 169 | 170 | 171 | class BaseIntListSignal(BaseSignal, abc.ABC): 172 | """Base class for all integer list signals.""" 173 | identifier: str = "BaseIntListSignal" 174 | do_serialize: bool = False 175 | 176 | @property 177 | def value(self) -> List[int]: 178 | return self._value 179 | 180 | @value.setter 181 | def value(self, value: List[int]) -> None: 182 | self._value = value 183 | 184 | def to_serializable(self) -> List[int]: 185 | return self.value 186 | 187 | @classmethod 188 | def from_serializable(cls, serialized_signal: List[int], identifier: str) -> "BaseIntListSignal": 189 | return cls(serialized_signal) 190 | 191 | 192 | class BaseFloatListSignal(BaseSignal, abc.ABC): 193 | """Base class for all float list signals.""" 194 | identifier: str = "BaseFloatListSignal" 195 | do_serialize: bool = False 196 | 197 | @property 198 | def value(self) -> List[float]: 199 | return self._value 200 | 201 | @value.setter 202 | def value(self, value: List[float])
-> None: 203 | self._value = value 204 | 205 | def to_serializable(self) -> List[float]: 206 | return self.value 207 | 208 | @classmethod 209 | def from_serializable(cls, serialized_signal: List[float], identifier: str) -> "BaseFloatListSignal": 210 | return cls(serialized_signal) 211 | 212 | 213 | class BaseStringListSignal(BaseSignal, abc.ABC): 214 | """Base class for all string list signals.""" 215 | identifier: str = "BaseStringListSignal" 216 | do_serialize: bool = False 217 | 218 | @property 219 | def value(self) -> List[str]: 220 | return self._value 221 | 222 | @value.setter 223 | def value(self, value: List[str]) -> None: 224 | self._value = value 225 | 226 | def to_serializable(self) -> List[str]: 227 | return self.value 228 | 229 | @classmethod 230 | def from_serializable(cls, serialized_signal: List[str], identifier: str) -> "BaseStringListSignal": 231 | return cls(serialized_signal) 232 | 233 | 234 | class BaseNumpyArraySignal(BaseSignal, abc.ABC): 235 | """Base class for all numpy array signals.""" 236 | identifier: str = "BaseNumpyArraySignal" 237 | do_serialize: bool = False 238 | 239 | def __eq__(self, other) -> bool: 240 | return isinstance(other, self.__class__) and np.array_equal(self._value, other._value) 241 | 242 | @property 243 | def value(self) -> np.ndarray: 244 | return self._value 245 | 246 | @value.setter 247 | def value(self, value: np.ndarray) -> None: 248 | self._value = value 249 | 250 | def to_serializable(self) -> bytes: 251 | save_bytes: io.BytesIO = io.BytesIO() 252 | # noinspection PyTypeChecker 253 | np.save(save_bytes, self._value, allow_pickle=True) 254 | return save_bytes.getvalue() 255 | 256 | @classmethod 257 | def from_serializable(cls, serialized_signal: bytes, identifier: str) -> "BaseNumpyArraySignal": 258 | load_bytes: io.BytesIO = io.BytesIO(serialized_signal) 259 | # noinspection PyTypeChecker 260 | return cls(np.load(load_bytes, allow_pickle=True)) 261 | 262 | 263 | ######################################################################################################################## 264 | # actual signals 265 | ######################################################################################################################## 266 | 267 | 268 | @register_signal 269 | class ValueSignal(BaseStringSignal): 270 | """Value of the nugget.""" 271 | identifier: str = "ValueSignal" 272 | do_serialize: bool = True 273 | 274 | 275 | @register_signal 276 | class TypeSignal(BaseStringSignal): 277 | """Type identifier of the nugget's value type.""" 278 | identifier: str = "TypeSignal" 279 | do_serialize: bool = True 280 | 281 | 282 | @register_signal 283 | class ExtractorNameSignal(BaseStringSignal): 284 | """Name of the extractor that generated the nugget.""" 285 | identifier: str = "ExtractorNameSignal" 286 | do_serialize: bool = True 287 | 288 | 289 | @register_signal 290 | class LabelSignal(BaseStringSignal): 291 | """Label of the nugget as determined by the extractors.""" 292 | identifier: str = "LabelSignal" 293 | do_serialize: bool = True 294 | 295 | 296 | @register_signal 297 | class NaturalLanguageLabelSignal(BaseStringSignal): 298 | """Natural language version of the nugget's label that works well with natural language embeddings.""" 299 | identifier: str = "NaturalLanguageLabelSignal" 300 | do_serialize: bool = True 301 | 302 | 303 | @register_signal 304 | class RelativePositionSignal(BaseFloatSignal): 305 | """Relative position of the nugget based on the total length of the document.""" 306 | identifier: str = "RelativePositionSignal" 307 |
do_serialize: bool = True 308 | 309 | 310 | @register_signal 311 | class CachedContextSentenceSignal(BaseStringSignal): 312 | """Context sentence and position in context for caching.""" 313 | identifier: str = "CachedContextSentenceSignal" 314 | do_serialize: bool = False 315 | 316 | 317 | @register_signal 318 | class CachedDistanceSignal(BaseFloatSignal): 319 | """Cached distance of the nugget or attribute.""" 320 | identifier: str = "CachedDistanceSignal" 321 | do_serialize: bool = False 322 | 323 | 324 | @register_signal 325 | class CurrentMatchIndexSignal(BaseIntSignal): 326 | """Index of the nugget that is currently considered as the match.""" 327 | identifier: str = "CurrentMatchIndexSignal" 328 | do_serialize: bool = False 329 | 330 | 331 | @register_signal 332 | class POSTagsSignal(BaseStringListSignal): 333 | """POS tags of the nugget's words as determined by extractors.""" 334 | identifier: str = "POSTagsSignal" 335 | do_serialize: bool = True 336 | 337 | 338 | @register_signal 339 | class UserProvidedExamplesSignal(BaseStringListSignal): 340 | """User-provided example values/texts for an attribute.""" 341 | identifier: str = "UserProvidedExamplesSignal" 342 | do_serialize: bool = True 343 | 344 | 345 | @register_signal 346 | class SentenceStartCharsSignal(BaseIntListSignal): 347 | """Sentence boundaries as a list of indices of the first characters in each sentence.""" 348 | identifier: str = "SentenceStartCharsSignal" 349 | do_serialize: bool = True 350 | 351 | 352 | @register_signal 353 | class LabelEmbeddingSignal(BaseNumpyArraySignal): 354 | """Embedding of the nugget's label or attribute's name.""" 355 | identifier: str = "LabelEmbeddingSignal" 356 | do_serialize: bool = True 357 | 358 | 359 | @register_signal 360 | class TextEmbeddingSignal(BaseNumpyArraySignal): 361 | """Embedding of the nugget's text.""" 362 | identifier: str = "TextEmbeddingSignal" 363 | do_serialize: bool = True 364 | 365 | 366 | @register_signal 367 | class ContextSentenceEmbeddingSignal(BaseNumpyArraySignal): 368 | """Embedding of the nugget's textual context sentence.""" 369 | identifier: str = "ContextSentenceEmbeddingSignal" 370 | do_serialize: bool = True 371 | 372 | 373 | @register_signal 374 | class DocumentSentenceEmbeddingSignal(BaseNumpyArraySignal): 375 | """Embedding of the sentences of a document.""" 376 | identifier: str = "DocumentSentenceEmbeddingSignal" 377 | do_serialize: bool = True 378 | -------------------------------------------------------------------------------- /wannadb/interaction.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import logging 3 | from typing import Dict, Any, Callable 4 | 5 | logger: logging.Logger = logging.getLogger(__name__) 6 | 7 | 8 | class BaseInteractionCallback(abc.ABC): 9 | """ 10 | Base class for all interaction callbacks. 11 | 12 | An interaction callback allows pipeline elements to interact with the user. The interaction callback is provided to 13 | the pipeline element when applying it to the document base. The pipeline element calls ('__call__') the interaction 14 | callback to interact with the user. The return value of the interaction callback provides the user's feedback to the 15 | pipeline element. 16 | 17 | Both the parameters and the return values of the interaction callbacks are generic dictionaries, the content of 18 | which may differ between different implementations. 
19 | """ 20 | 21 | def __call__(self, pipeline_element_identifier: str, data: Dict[str, Any]) -> Dict[str, Any]: 22 | """ 23 | Interaction between the pipeline element and the user. 24 | 25 | This method is called by the pipeline element and calls the _call method that contains the actual implementation 26 | of the interaction callback. 27 | 28 | :param pipeline_element_identifier: identifier of the calling pipeline element 29 | :param data: parameters of the feedback request provided to the user interface 30 | :return: result of the feedback request provided to the pipeline element 31 | """ 32 | logger.info(f"{pipeline_element_identifier} called the interaction callback.") 33 | return self._call(pipeline_element_identifier, data) 34 | 35 | @abc.abstractmethod 36 | def _call(self, pipeline_element_identifier: str, data: Dict[str, Any]) -> Dict[str, Any]: 37 | """ 38 | Interaction between the pipeline element and the user. 39 | 40 | This method is overwritten by the actual interaction callbacks and contains their implementation. 41 | 42 | :param pipeline_element_identifier: identifier of the calling pipeline element 43 | :param data: parameters of the feedback request provided to the user interface 44 | :return: result of the feedback request provided to the pipeline element 45 | """ 46 | raise NotImplementedError 47 | 48 | 49 | class InteractionCallback(BaseInteractionCallback): 50 | """Interaction callback that is initialized with a callback function.""" 51 | 52 | def __init__(self, callback_fn: Callable[[str, Dict[str, Any]], Dict[str, Any]]): 53 | """ 54 | Initialize the interaction callback. 55 | 56 | :param callback_fn: callback function that is called whenever the interaction callback is called 57 | """ 58 | self._callback_fn: Callable[[str, Dict[str, Any]], Dict[str, Any]] = callback_fn 59 | 60 | def _call(self, pipeline_element_identifier: str, data: Dict[str, Any]) -> Dict[str, Any]: 61 | return self._callback_fn(pipeline_element_identifier, data) 62 | 63 | 64 | class EmptyInteractionCallback(BaseInteractionCallback): 65 | """Interaction callback that does nothing whenever it is called.""" 66 | 67 | def _call(self, pipeline_element_identifier: str, data: Dict[str, Any]) -> Dict[str, Any]: 68 | pass 69 | -------------------------------------------------------------------------------- /wannadb/matching/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataManagementLab/wannadb/05b36dbf87a66d0e022b4638d764cbdbc8e2c69d/wannadb/matching/__init__.py -------------------------------------------------------------------------------- /wannadb/matching/distance.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import logging 3 | from collections.abc import Collection 4 | from typing import Any, Union 5 | 6 | import numpy as np 7 | from scipy.spatial.distance import cosine 8 | from sklearn.metrics.pairwise import cosine_distances 9 | 10 | from wannadb.configuration import BaseConfigurableElement, register_configurable_element 11 | from wannadb.data.data import Attribute, InformationNugget 12 | from wannadb.data.signals import ContextSentenceEmbeddingSignal, LabelEmbeddingSignal, \ 13 | POSTagsSignal, RelativePositionSignal, TextEmbeddingSignal 14 | from wannadb.statistics import Statistics 15 | 16 | logger: logging.Logger = logging.getLogger(__name__) 17 | 18 | 19 | class BaseDistance(BaseConfigurableElement, abc.ABC): 20 | """ 21 | Base class for all distance 
functions. 22 | 23 | Distance functions compute distances between InformationNuggets and Attributes. They must be able to compute distances 24 | between pairs of InformationNuggets, pairs of Attributes, or mixed pairs. 25 | """ 26 | identifier: str = "BaseDistance" 27 | 28 | # identifiers of the signals that the distance function requires for nuggets, attributes, and documents 29 | # signals the distance function may use if they exist but does not necessarily require are not part of this list 30 | required_signal_identifiers: dict[str, list[str]] = { 31 | "nuggets": [], 32 | "attributes": [], 33 | "documents": [] 34 | } 35 | 36 | @abc.abstractmethod 37 | def compute_distance( 38 | self, 39 | x: Union[InformationNugget, Attribute], 40 | y: Union[InformationNugget, Attribute], 41 | statistics: Statistics 42 | ) -> float: 43 | """ 44 | Compute distance between the two given InformationNuggets/Attributes. 45 | 46 | :param x: first InformationNugget/Attribute 47 | :param y: second InformationNugget/Attribute 48 | :param statistics: statistics object to collect statistics 49 | :return: computed distance 50 | """ 51 | raise NotImplementedError 52 | 53 | def compute_distances( 54 | self, 55 | xs: Collection[Union[InformationNugget, Attribute]], 56 | ys: Collection[Union[InformationNugget, Attribute]], 57 | statistics: Statistics 58 | ) -> np.ndarray: 59 | """ 60 | Compute distances between all pairs from two collections of InformationNuggets/Attributes. 61 | 62 | This method exists to speed up the calculation using batching. The default implementation works by calling the 63 | 'compute_distance' method. 64 | 65 | :param xs: first list of InformationNuggets/Attributes 66 | :param ys: second list of InformationNuggets/Attributes 67 | :param statistics: statistics object to collect statistics 68 | :return: matrix of computed distances (row corresponds to xs, column corresponds to ys) 69 | """ 70 | statistics["num_multi_call"] += 1 71 | 72 | assert len(xs) > 0 and len(ys) > 0, "Cannot compute distances for an empty collection!" 73 | 74 | res: np.ndarray = np.zeros((len(xs), len(ys))) 75 | for x_ix, x in enumerate(xs): 76 | for y_ix, y in enumerate(ys): 77 | res[x_ix, y_ix] = self.compute_distance(x, y, statistics) 78 | return res 79 | 80 | 81 | ######################################################################################################################## 82 | # actual distance functions 83 | ######################################################################################################################## 84 | 85 | 86 | @register_configurable_element 87 | class SignalsMeanDistance(BaseDistance): 88 | """Compute the distance as the mean of the distances between the available signals.""" 89 | 90 | identifier: str = "SignalsMeanDistance" 91 | 92 | required_signal_identifiers: dict[str, list[str]] = { 93 | "nuggets": [ 94 | LabelEmbeddingSignal.identifier 95 | # can use (but not required): TextEmbeddingSignal.identifier 96 | # can use (but not required): ContextSentenceEmbeddingSignal.identifier 97 | # can use (but not required): RelativePositionSignal.identifier 98 | # can use (but not required): POSTagsSignal.identifier 99 | ], 100 | "attributes": [LabelEmbeddingSignal.identifier], 101 | "documents": [] 102 | } 103 | 104 | def __init__(self, signal_identifiers: list[str]) -> None: 105 | """ 106 | Initialize the SignalsMeanDistance. 
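A hedged usage sketch (`nugget` and `attribute` are assumed to already carry a LabelEmbeddingSignal; the Statistics constructor arguments shown are an assumption as well):

    distance = SignalsMeanDistance(signal_identifiers=[LabelEmbeddingSignal.identifier])
    stats = Statistics(do_collect=False)  # assumed constructor signature
    d = distance.compute_distance(nugget, attribute, stats)  # float in [0, 1]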
107 | 108 | :param signal_identifiers: identifiers of the signals to include 109 | """ 110 | super(SignalsMeanDistance, self).__init__() 111 | self._signal_identifiers: list[str] = list(set(signal_identifiers + [LabelEmbeddingSignal.identifier])) 112 | logger.debug(f"Initialized '{self.identifier}'.") 113 | 114 | def compute_distance( 115 | self, 116 | x: Union[InformationNugget, Attribute], 117 | y: Union[InformationNugget, Attribute], 118 | statistics: Statistics 119 | ) -> float: 120 | statistics["num_calls"] += 1 121 | 122 | distances: np.ndarray = np.zeros(5) 123 | is_present: np.ndarray = np.zeros(5) 124 | 125 | label_embedding_signal_identifier: str = LabelEmbeddingSignal.identifier 126 | if ( 127 | label_embedding_signal_identifier in self._signal_identifiers 128 | and label_embedding_signal_identifier in x.signals.keys() 129 | and label_embedding_signal_identifier in y.signals.keys() 130 | ): 131 | cosine_distance: float = float( 132 | cosine(x[label_embedding_signal_identifier], y[label_embedding_signal_identifier])) 133 | distances[0] = min(abs(cosine_distance), 1) 134 | is_present[0] = 1 135 | 136 | text_embedding_signal_identifier: str = TextEmbeddingSignal.identifier 137 | if ( 138 | text_embedding_signal_identifier in self._signal_identifiers 139 | and text_embedding_signal_identifier in x.signals.keys() 140 | and text_embedding_signal_identifier in y.signals.keys() 141 | ): 142 | cosine_distance: float = float( 143 | cosine(x[text_embedding_signal_identifier], y[text_embedding_signal_identifier])) 144 | distances[1] = min(abs(cosine_distance), 1) 145 | is_present[1] = 1 146 | 147 | context_sentence_embedding_signal_identifier: str = ContextSentenceEmbeddingSignal.identifier 148 | if ( 149 | context_sentence_embedding_signal_identifier in self._signal_identifiers 150 | and context_sentence_embedding_signal_identifier in x.signals.keys() 151 | and context_sentence_embedding_signal_identifier in y.signals.keys() 152 | ): 153 | cosine_distance: float = float( 154 | cosine(x[context_sentence_embedding_signal_identifier], y[context_sentence_embedding_signal_identifier]) 155 | ) 156 | distances[2] = min(abs(cosine_distance), 1) 157 | is_present[2] = 1 158 | 159 | relative_position_signal_identifier: str = RelativePositionSignal.identifier 160 | if ( 161 | relative_position_signal_identifier in self._signal_identifiers 162 | and relative_position_signal_identifier in x.signals.keys() 163 | and relative_position_signal_identifier in y.signals.keys() 164 | ): 165 | relative_distance: float = (x[relative_position_signal_identifier] - y[relative_position_signal_identifier]) 166 | distances[3] = min(abs(relative_distance), 1) 167 | is_present[3] = 1 168 | 169 | pos_tags_signal_identifier: str = POSTagsSignal.identifier 170 | if ( 171 | pos_tags_signal_identifier in self._signal_identifiers 172 | and pos_tags_signal_identifier in x.signals.keys() 173 | and pos_tags_signal_identifier in y.signals.keys() 174 | ): 175 | if x[pos_tags_signal_identifier] == y[pos_tags_signal_identifier]: 176 | distances[4] = 0 177 | else: 178 | distances[4] = 1 # TODO: magic float, measure "distance" 179 | is_present[4] = 1 180 | 181 | return 1 if np.sum(is_present) == 0 else np.sum(distances) / np.sum(is_present) 182 | 183 | def compute_distances( 184 | self, 185 | xs: Collection[Union[InformationNugget, Attribute]], 186 | ys: Collection[Union[InformationNugget, Attribute]], 187 | statistics: Statistics 188 | ) -> np.ndarray: 189 | statistics["num_multi_calls"] += 1 190 | 191 | assert
len(xs) > 0 and len(ys) > 0, "Cannot compute distances for an empty collection!" 192 | if not isinstance(xs, list): 193 | xs = list(xs) 194 | if not isinstance(ys, list): 195 | ys = list(ys) 196 | 197 | signal_identifiers: list[str] = [ 198 | LabelEmbeddingSignal.identifier, 199 | TextEmbeddingSignal.identifier, 200 | ContextSentenceEmbeddingSignal.identifier, 201 | RelativePositionSignal.identifier, 202 | POSTagsSignal.identifier 203 | ] 204 | 205 | # check that all xs and all ys contain the same signals 206 | xs_is_present: np.ndarray = np.zeros(5) 207 | for idx in range(5): 208 | if signal_identifiers[idx] in self._signal_identifiers and signal_identifiers[idx] in xs[0].signals.keys(): 209 | xs_is_present[idx] = 1 210 | for x in xs: 211 | for idx in range(5): 212 | if signal_identifiers[idx] in self._signal_identifiers: 213 | if ( 214 | xs_is_present[idx] == 1 215 | and signal_identifiers[idx] not in x.signals.keys() 216 | or xs_is_present[idx] == 0 217 | and signal_identifiers[idx] in x.signals.keys() 218 | ): 219 | assert False, "All xs must have the same signals!" 220 | 221 | ys_is_present: np.ndarray = np.zeros(5) 222 | for idx in range(5): 223 | if signal_identifiers[idx] in self._signal_identifiers and signal_identifiers[idx] in ys[0].signals.keys(): 224 | ys_is_present[idx] = 1 225 | for y in ys: 226 | for idx in range(5): 227 | if signal_identifiers[idx] in self._signal_identifiers: 228 | if ( 229 | ys_is_present[idx] == 1 230 | and signal_identifiers[idx] not in y.signals.keys() 231 | or ys_is_present[idx] == 0 232 | and signal_identifiers[idx] in y.signals.keys() 233 | ): 234 | assert False, "All ys must have the same signals!" 235 | 236 | # compute distances signal by signal 237 | distances: np.ndarray = np.zeros((len(xs), len(ys))) 238 | for idx in range(3): 239 | if xs_is_present[idx] == 1 and ys_is_present[idx] == 1: 240 | x_embeddings: np.ndarray = np.array([x[signal_identifiers[idx]] for x in xs]) 241 | y_embeddings: np.ndarray = np.array([y[signal_identifiers[idx]] for y in ys]) 242 | tmp: np.ndarray = cosine_distances(x_embeddings, y_embeddings) 243 | distances = np.add(distances, tmp) 244 | 245 | if xs_is_present[3] == 1 and ys_is_present[3] == 1: 246 | x_positions: np.ndarray = np.array([x[signal_identifiers[3]] for x in xs]) 247 | y_positions: np.ndarray = np.array([y[signal_identifiers[3]] for y in ys]) 248 | tmp: np.ndarray = np.zeros((len(x_positions), len(y_positions))) 249 | for x_ix, x_value in enumerate(x_positions): 250 | for y_ix, y_value in enumerate(y_positions): 251 | tmp[x_ix, y_ix] = np.abs(x_value - y_value) 252 | distances = np.add(distances, tmp) 253 | 254 | if xs_is_present[4] == 1 and ys_is_present[4] == 1: 255 | x_values: list[list[str]] = [x[signal_identifiers[4]] for x in xs] 256 | y_values: list[list[str]] = [y[signal_identifiers[4]] for y in ys] 257 | tmp: np.ndarray = np.ones((len(x_values), len(y_values))) 258 | for x_ix, x_value in enumerate(x_values): 259 | for y_ix, y_value in enumerate(y_values): 260 | if x_value == y_value: 261 | tmp[x_ix, y_ix] = 0 262 | distances = np.add(distances, tmp) 263 | 264 | actually_present: np.ndarray = xs_is_present * ys_is_present 265 | if np.sum(actually_present) == 0: 266 | return np.ones_like(distances) 267 | else: 268 | return np.divide(distances, np.sum(actually_present)) 269 | 270 | @classmethod 271 | def from_config(cls, config: dict[str, Any]) -> "SignalsMeanDistance": 272 | return cls(config["signal_identifiers"]) 273 | 274 | def to_config(self) -> dict[str, Any]: 275 | return { 276 |
"identifier": self.identifier, 277 | "signal_identifiers": self._signal_identifiers 278 | } 279 | -------------------------------------------------------------------------------- /wannadb/preprocessing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataManagementLab/wannadb/05b36dbf87a66d0e022b4638d764cbdbc8e2c69d/wannadb/preprocessing/__init__.py -------------------------------------------------------------------------------- /wannadb/preprocessing/extraction.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import json 3 | import logging 4 | from typing import Any, Dict, List 5 | 6 | import requests 7 | from spacy.tokens import Doc 8 | 9 | from wannadb import resources 10 | from wannadb.configuration import register_configurable_element, BasePipelineElement 11 | from wannadb.data.data import DocumentBase, InformationNugget 12 | from wannadb.data.signals import LabelSignal, POSTagsSignal, SentenceStartCharsSignal, ExtractorNameSignal 13 | from wannadb.interaction import BaseInteractionCallback 14 | from wannadb.resources import StanzaNERPipeline, FigerNERPipeline 15 | from wannadb.statistics import Statistics 16 | from wannadb.status import BaseStatusCallback 17 | 18 | logger: logging.Logger = logging.getLogger(__name__) 19 | 20 | 21 | class BaseExtractor(BasePipelineElement, abc.ABC): 22 | """ 23 | Base class for all extractors. 24 | 25 | Extractors derive the information nuggets from the documents. 26 | """ 27 | identifier: str = "BaseExtractor" 28 | 29 | 30 | ######################################################################################################################## 31 | # actual extractors 32 | ######################################################################################################################## 33 | 34 | 35 | @register_configurable_element 36 | class SpacyNERExtractor(BaseExtractor): 37 | """Extractor based on spacy's NER models.""" 38 | 39 | identifier: str = "SpacyNERExtractor" 40 | 41 | required_signal_identifiers: Dict[str, List[str]] = { 42 | "nuggets": [], 43 | "attributes": [], 44 | "documents": [] 45 | } 46 | 47 | generated_signal_identifiers: Dict[str, List[str]] = { 48 | "nuggets": [LabelSignal.identifier, POSTagsSignal.identifier, ExtractorNameSignal.identifier], 49 | "attributes": [], 50 | "documents": [SentenceStartCharsSignal.identifier] 51 | } 52 | 53 | def __init__(self, spacy_resource_identifier: str) -> None: 54 | """ 55 | Initialize the SpacyNERExtractor. 
56 | 57 | :param spacy_resource_identifier: identifier of the spacy model resource 58 | """ 59 | super(SpacyNERExtractor, self).__init__() 60 | self._spacy_resource_identifier: str = spacy_resource_identifier 61 | 62 | # preload required resources 63 | resources.MANAGER.load(self._spacy_resource_identifier) 64 | logger.debug(f"Initialized '{self.identifier}'.") 65 | 66 | def _call( 67 | self, 68 | document_base: DocumentBase, 69 | interaction_callback: BaseInteractionCallback, 70 | status_callback: BaseStatusCallback, 71 | statistics: Statistics 72 | ) -> None: 73 | statistics["num_documents"] = len(document_base.documents) 74 | 75 | for ix, document in enumerate(document_base.documents): 76 | self._use_status_callback(status_callback, ix, len(document_base.documents)) 77 | 78 | spacy_output: Doc = resources.MANAGER[self._spacy_resource_identifier](document.text) 79 | sentence_start_chars: List[int] = [] 80 | 81 | # transform the spacy output into the document and nuggets 82 | for sentence in spacy_output.sents: 83 | sentence_start_chars.append(sentence.start_char) 84 | 85 | document[SentenceStartCharsSignal] = SentenceStartCharsSignal(sentence_start_chars) 86 | 87 | for entity in spacy_output.ents: 88 | nugget: InformationNugget = InformationNugget( 89 | document=document, 90 | start_char=entity.start_char, 91 | end_char=entity.end_char 92 | ) 93 | 94 | nugget[POSTagsSignal] = POSTagsSignal([]) # TODO: gather pos tags 95 | nugget[LabelSignal] = LabelSignal(entity.label_) 96 | nugget[ExtractorNameSignal] = ExtractorNameSignal(self.identifier) 97 | 98 | document.nuggets.append(nugget) 99 | 100 | statistics["num_nuggets"] += 1 101 | statistics["spacy_entity_type_dist"][entity.label_] += 1 102 | 103 | def to_config(self) -> Dict[str, Any]: 104 | return { 105 | "identifier": self.identifier, 106 | "spacy_resource_identifier": self._spacy_resource_identifier 107 | } 108 | 109 | @classmethod 110 | def from_config(cls, config: Dict[str, Any]) -> "SpacyNERExtractor": 111 | return cls(config["spacy_resource_identifier"]) 112 | 113 | 114 | @register_configurable_element 115 | class StanzaNERExtractor(BaseExtractor): 116 | """Extractor based on Stanza's NER model.""" 117 | 118 | identifier: str = "StanzaNERExtractor" 119 | 120 | required_signal_identifiers: Dict[str, List[str]] = { 121 | "nuggets": [], 122 | "attributes": [], 123 | "documents": [] 124 | } 125 | 126 | generated_signal_identifiers: Dict[str, List[str]] = { 127 | "nuggets": [LabelSignal.identifier, POSTagsSignal.identifier, ExtractorNameSignal.identifier], 128 | "attributes": [], 129 | "documents": [SentenceStartCharsSignal.identifier] 130 | } 131 | 132 | def __init__(self) -> None: 133 | """Initialize the StanzaNERExtractor.""" 134 | super(StanzaNERExtractor, self).__init__() 135 | 136 | # preload required resources 137 | resources.MANAGER.load(StanzaNERPipeline) 138 | logger.debug(f"Initialized '{self.identifier}'.") 139 | 140 | def _call( 141 | self, 142 | document_base: DocumentBase, 143 | interaction_callback: BaseInteractionCallback, 144 | status_callback: BaseStatusCallback, 145 | statistics: Statistics 146 | ) -> None: 147 | statistics["num_documents"] = len(document_base.documents) 148 | 149 | for ix, document in enumerate(document_base.documents): 150 | self._use_status_callback(status_callback, ix, len(document_base.documents)) 151 | 152 | stanza_output = resources.MANAGER[StanzaNERPipeline](document.text) 153 | 154 | sentence_start_chars: List[int] = [] 155 | 156 | # transform the stanza output into the document and nuggets 
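        # note: the nugget's end offset below is derived as start_char + len(entity.text),
        # i.e. from the entity's surface text rather than from a separate end offset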
157 | for sentence in stanza_output.sentences: 158 | sentence_start_chars.append(sentence.tokens[0].start_char) 159 | 160 | for entity in sentence.entities: 161 | nugget: InformationNugget = InformationNugget( 162 | document=document, 163 | start_char=entity.start_char, 164 | end_char=entity.start_char + len(entity.text) 165 | ) 166 | 167 | nugget[POSTagsSignal] = POSTagsSignal([word.xpos for word in entity.words]) 168 | nugget[LabelSignal] = LabelSignal(entity.type) 169 | nugget[ExtractorNameSignal] = ExtractorNameSignal(self.identifier) 170 | 171 | document.nuggets.append(nugget) 172 | 173 | statistics["num_nuggets"] += 1 174 | statistics["stanza_entity_type_dist"][entity.type] += 1 175 | 176 | document[SentenceStartCharsSignal] = SentenceStartCharsSignal(sentence_start_chars) 177 | 178 | def to_config(self) -> Dict[str, Any]: 179 | return { 180 | "identifier": self.identifier 181 | } 182 | 183 | @classmethod 184 | def from_config(cls, config: Dict[str, Any]) -> "StanzaNERExtractor": 185 | return cls() 186 | 187 | 188 | @register_configurable_element 189 | class FigerNERExtractor(BaseExtractor): 190 | """ 191 | Extractor based on Figer's NER model 192 | (using CoreNLP for basic extraction and fine-grained labeling on top). 193 | """ 194 | 195 | identifier: str = "FigerNERExtractor" 196 | 197 | required_signal_identifiers: Dict[str, List[str]] = { 198 | "nuggets": [], 199 | "attributes": [], 200 | "documents": [] 201 | } 202 | 203 | generated_signal_identifiers: Dict[str, List[str]] = { 204 | "nuggets": [LabelSignal.identifier], 205 | "attributes": [], 206 | "documents": [SentenceStartCharsSignal.identifier] 207 | } 208 | 209 | def __init__(self) -> None: 210 | """Initialize the FigerNERExtractor.""" 211 | super().__init__() 212 | 213 | # preload required resources 214 | resources.MANAGER.load(FigerNERPipeline) 215 | logger.debug(f"Initialized '{self.identifier}'.") 216 | 217 | def _call( 218 | self, 219 | document_base: DocumentBase, 220 | interaction_callback: BaseInteractionCallback, 221 | status_callback: BaseStatusCallback, 222 | statistics: Statistics 223 | ) -> None: 224 | statistics["num_documents"] = len(document_base.documents) 225 | 226 | base_url = resources.MANAGER[FigerNERPipeline] 227 | 228 | for ix, document in enumerate(document_base.documents): 229 | self._use_status_callback(status_callback, ix, len(document_base.documents)) 230 | 231 | # Run FIGER on document (truncate to first 8000 chars if necessary due to server limitations) 232 | r = requests.get(base_url, params={'text': document.text[:8000]}) 233 | if r.status_code == 200: 234 | answer = json.loads(r.text) 235 | if answer["status"] == 200: 236 | sentence_start_chars: List[int] = answer["sentence_offsets"] 237 | 238 | for raw_nugget in answer["data"]: 239 | nugget: InformationNugget = InformationNugget( 240 | document=document, 241 | start_char=raw_nugget["start_char"], 242 | end_char=raw_nugget["end_char"] 243 | ) 244 | 245 | # nugget[POSTagsSignal] = POSTagsSignal([word.xpos for word in entity.words]) 246 | 247 | # Label format from FIGER is e.g.
248 | # "/location@1.4898770776826524,/organization/company@0.17639383484191654,/location/country@0.25034040521054085" 249 | # Extract first label (without numeric value) 250 | label_string = raw_nugget['label'].split(',')[0].split('@')[0][1:].replace("/", " ") 251 | nugget[LabelSignal] = LabelSignal(label_string) 252 | document.nuggets.append(nugget) 253 | 254 | statistics["num_nuggets"] += 1 255 | statistics["figer_label_dist"][label_string] += 1 256 | 257 | document[SentenceStartCharsSignal] = SentenceStartCharsSignal(sentence_start_chars) 258 | else: 259 | print(logger.warning(f"Failed to run FIGER on document '{document.name}' with error '{answer['error']}'")) 260 | else: 261 | logger.warning(f"Failed to run FIGER on document '{document.name}'") 262 | 263 | def to_config(self) -> Dict[str, Any]: 264 | return { 265 | "identifier": self.identifier 266 | } 267 | 268 | @classmethod 269 | def from_config(cls, config: Dict[str, Any]) -> "FigerNERExtractor": 270 | return cls() 271 | -------------------------------------------------------------------------------- /wannadb/preprocessing/label_paraphrasing.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import logging 3 | import time 4 | from typing import Dict, List, Any 5 | 6 | from wannadb.configuration import BasePipelineElement, register_configurable_element 7 | from wannadb.data.data import DocumentBase, InformationNugget, Attribute 8 | from wannadb.data.signals import LabelSignal, NaturalLanguageLabelSignal 9 | from wannadb.interaction import BaseInteractionCallback 10 | from wannadb.statistics import Statistics 11 | from wannadb.status import BaseStatusCallback 12 | 13 | logger: logging.Logger = logging.getLogger(__name__) 14 | 15 | 16 | class BaseLabelParaphraser(BasePipelineElement, abc.ABC): 17 | """ 18 | Base class for all label paraphrasers. 19 | 20 | Label paraphrasers translate NER-tags (in the case of information nuggets) or column titles (in the case of 21 | attributes) into a natural language string that works well with natural language embeddings. 22 | """ 23 | identifier: str = "BaseLabelParaphraser" 24 | 25 | def _use_status_callback_for_label_paraphrasers( 26 | self, 27 | status_callback: BaseStatusCallback, 28 | element: str, 29 | ix: int, 30 | total: int 31 | ) -> None: 32 | """ 33 | Helper method that calls the status callback at regular intervals. 
34 | 35 | :param status_callback: status callback to call 36 | :param element: 'nugget labels' or 'attribute names' 37 | :param ix: index of the current element 38 | :param total: total number of elements 39 | """ 40 | if total == 0: 41 | status_callback(f"Paraphrasing {element} with {self.identifier}...", -1) 42 | elif ix == 0: 43 | status_callback(f"Paraphrasing {element} with {self.identifier}...", 0) 44 | else: 45 | interval: int = total // 10 46 | if interval != 0 and ix % interval == 0: 47 | status_callback(f"Paraphrasing {element} with {self.identifier}...", ix / total) 48 | 49 | def _call( 50 | self, 51 | document_base: DocumentBase, 52 | interaction_callback: BaseInteractionCallback, 53 | status_callback: BaseStatusCallback, 54 | statistics: Statistics 55 | ) -> None: 56 | # paraphrasing nugget labels 57 | nuggets: List[InformationNugget] = document_base.nuggets 58 | logger.info(f"Paraphrase {len(nuggets)} nugget labels with {self.identifier}.") 59 | tick: float = time.time() 60 | status_callback(f"Paraphrasing nugget labels with {self.identifier}...", -1) 61 | statistics["nuggets"]["num_nuggets"] = len(nuggets) 62 | self._paraphrase_nugget_labels(nuggets, interaction_callback, status_callback, statistics["nuggets"]) 63 | status_callback(f"Paraphrasing nugget labels with {self.identifier}...", 1) 64 | tack: float = time.time() 65 | logger.info(f"Paraphrased {len(nuggets)} nugget labels with {self.identifier} in {tack - tick} seconds.") 66 | statistics["nuggets"]["runtime"] = tack - tick 67 | 68 | # paraphrasing attribute names 69 | attributes: List[Attribute] = document_base.attributes 70 | logger.info(f"Paraphrase {len(attributes)} attribute names with {self.identifier}.") 71 | tick: float = time.time() 72 | status_callback(f"Paraphrasing attribute names with {self.identifier}...", -1) 73 | statistics["attributes"]["num_attributes"] = len(attributes) 74 | self._paraphrase_attribute_names(attributes, interaction_callback, status_callback, statistics["attributes"]) 75 | status_callback(f"Paraphrasing attribute names with {self.identifier}...", 1) 76 | tack: float = time.time() 77 | logger.info(f"Paraphrased {len(attributes)} attribute names with {self.identifier} in {tack - tick} seconds.") 78 | statistics["attributes"]["runtime"] = tack - tick 79 | 80 | def _paraphrase_nugget_labels( 81 | self, 82 | nuggets: List[InformationNugget], 83 | interaction_callback: BaseInteractionCallback, 84 | status_callback: BaseStatusCallback, 85 | statistics: Statistics 86 | ) -> None: 87 | """ 88 | Paraphrase nugget labels for the given list of nuggets. 89 | 90 | :param nuggets: list of nuggets to work on 91 | :param interaction_callback: callback to allow for user interaction 92 | :param status_callback: callback to communicate current status (message and progress) 93 | :param statistics: statistics object to collect statistics 94 | """ 95 | pass # default behavior: do nothing 96 | 97 | def _paraphrase_attribute_names( 98 | self, 99 | attributes: List[Attribute], 100 | interaction_callback: BaseInteractionCallback, 101 | status_callback: BaseStatusCallback, 102 | statistics: Statistics 103 | ) -> None: 104 | """ 105 | Paraphrase attribute names for the given list of Attributes.
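A minimal sketch of an override of this hook (the subclass below is hypothetical and not part of this module):

    @register_configurable_element
    class LowercaseAttributeNameLabelParaphraser(BaseLabelParaphraser):
        identifier: str = "LowercaseAttributeNameLabelParaphraser"

        def _paraphrase_attribute_names(self, attributes, interaction_callback, status_callback, statistics) -> None:
            for attribute in attributes:
                attribute[NaturalLanguageLabelSignal] = NaturalLanguageLabelSignal(attribute.name.lower())

        def to_config(self):
            return {"identifier": self.identifier}

        @classmethod
        def from_config(cls, config):
            return cls()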
106 | 107 | :param attributes: list of Attributes to work on 108 | :param interaction_callback: callback to allow for user interaction 109 | :param status_callback: callback to communicate current status (message and progress) 110 | :param statistics: statistics object to collect statistics 111 | """ 112 | pass # default behavior: do nothing 113 | 114 | 115 | ######################################################################################################################## 116 | # actual label paraphrasers 117 | ######################################################################################################################## 118 | 119 | 120 | @register_configurable_element 121 | class OntoNotesLabelParaphraser(BaseLabelParaphraser): 122 | """Label paraphraser for OntoNotes NER tags based on their definition.""" 123 | identifier: str = "OntoNotesLabelParaphraser" 124 | 125 | required_signal_identifiers: Dict[str, List[str]] = { 126 | "nuggets": [LabelSignal.identifier], 127 | "attributes": [], 128 | "documents": [] 129 | } 130 | 131 | generated_signal_identifiers: Dict[str, List[str]] = { 132 | "nuggets": [NaturalLanguageLabelSignal.identifier], 133 | "attributes": [], 134 | "documents": [] 135 | } 136 | 137 | def __init__(self): 138 | """Initialize the OntoNotesLabelParaphraser.""" 139 | super(OntoNotesLabelParaphraser, self).__init__() 140 | logger.debug(f"Initialized '{self.identifier}'.") 141 | 142 | def _paraphrase_nugget_labels( 143 | self, 144 | nuggets: List[InformationNugget], 145 | interaction_callback: BaseInteractionCallback, 146 | status_callback: BaseStatusCallback, 147 | statistics: Statistics 148 | ) -> None: 149 | statistics["num_nuggets"] = len(nuggets) 150 | statistics["copied_labels"] = set() 151 | 152 | for ix, nugget in enumerate(nuggets): 153 | self._use_status_callback_for_label_paraphrasers(status_callback, "nugget labels", ix, len(nuggets)) 154 | label_mappings: Dict[str, str] = { 155 | "QUANTITY": "quantity measurement weight distance", 156 | "CARDINAL": "cardinal numeral", 157 | "NORP": "nationality religion political group", 158 | "FAC": "building airport highway bridge", 159 | "ORG": "organization", 160 | "GPE": "country city state", 161 | "LOC": "location mountain range body of water", 162 | "PRODUCT": "product vehicle weapon food", 163 | "EVENT": "event hurricane battle war sports", 164 | "WORK_OF_ART": "work of art title of book song", 165 | "LAW": "law document", 166 | "LANGUAGE": "language", 167 | "ORDINAL": "ordinal", 168 | "MONEY": "money", 169 | "PERCENT": "percentage", 170 | "DATE": "date period", 171 | "TIME": "time", 172 | "PERSON": "person", 173 | } 174 | 175 | if nugget[LabelSignal] in label_mappings.keys(): 176 | natural_language_label: str = label_mappings[nugget[LabelSignal]] 177 | nugget[NaturalLanguageLabelSignal] = NaturalLanguageLabelSignal(natural_language_label) 178 | statistics["num_label_changed"] += 1 179 | else: 180 | nugget[NaturalLanguageLabelSignal] = NaturalLanguageLabelSignal(nugget[LabelSignal]) 181 | statistics["num_label_copied"] += 1 182 | statistics["copied_labels"].add(nugget[LabelSignal]) 183 | 184 | def to_config(self) -> Dict[str, Any]: 185 | return { 186 | "identifier": self.identifier 187 | } 188 | 189 | @classmethod 190 | def from_config(cls, config: Dict[str, Any]) -> "OntoNotesLabelParaphraser": 191 | return cls() 192 | 193 | 194 | @register_configurable_element 195 | class CopyAttributeNameLabelParaphraser(BaseLabelParaphraser): 196 | """Label paraphraser that simply copies the attribute name as the 
natural language label signal.""" 197 | identifier: str = "CopyAttributeNameLabelParaphraser" 198 | 199 | def __init__(self): 200 | """Initialize the CopyAttributeNameLabelParaphraser.""" 201 | super(CopyAttributeNameLabelParaphraser, self).__init__() 202 | logger.debug(f"Initialized '{self.identifier}'.") 203 | 204 | required_signal_identifiers: Dict[str, List[str]] = { 205 | "nuggets": [], 206 | "attributes": [], 207 | "documents": [] 208 | } 209 | 210 | generated_signal_identifiers: Dict[str, List[str]] = { 211 | "nuggets": [], 212 | "attributes": [NaturalLanguageLabelSignal.identifier], 213 | "documents": [] 214 | } 215 | 216 | def _paraphrase_attribute_names( 217 | self, 218 | attributes: List[Attribute], 219 | interaction_callback: BaseInteractionCallback, 220 | status_callback: BaseStatusCallback, 221 | statistics: Statistics 222 | ) -> None: 223 | statistics["num_attributes"] = len(attributes) 224 | 225 | for ix, attribute in enumerate(attributes): 226 | self._use_status_callback_for_label_paraphrasers(status_callback, "attribute names", ix, len(attributes)) 227 | attribute[NaturalLanguageLabelSignal] = NaturalLanguageLabelSignal(attribute.name) 228 | statistics["num_label_copied"] += 1 229 | 230 | def to_config(self) -> Dict[str, Any]: 231 | return { 232 | "identifier": self.identifier 233 | } 234 | 235 | @classmethod 236 | def from_config(cls, config: Dict[str, Any]) -> "CopyAttributeNameLabelParaphraser": 237 | return cls() 238 | 239 | 240 | @register_configurable_element 241 | class SplitAttributeNameLabelParaphraser(BaseLabelParaphraser): 242 | """Label paraphraser that splits the attribute name to generate the natural language label signal.""" 243 | identifier: str = "SplitAttributeNameLabelParaphraser" 244 | 245 | required_signal_identifiers: Dict[str, List[str]] = { 246 | "nuggets": [], 247 | "attributes": [], 248 | "documents": [] 249 | } 250 | 251 | generated_signal_identifiers: Dict[str, List[str]] = { 252 | "nuggets": [], 253 | "attributes": [NaturalLanguageLabelSignal.identifier], 254 | "documents": [] 255 | } 256 | 257 | def __init__(self, do_lowercase: bool, splitters: List[str]) -> None: 258 | """ 259 | Initialize the SplitAttributeNameLabelParaphraser. 
260 | 
261 |         :param do_lowercase: whether to lowercase the attribute names
262 |         :param splitters: characters at which the attribute name should be split
263 |         """
264 |         super(SplitAttributeNameLabelParaphraser, self).__init__()
265 |         self._do_lowercase: bool = do_lowercase
266 |         self._splitters: List[str] = splitters
267 |         logger.debug(f"Initialized '{self.identifier}'.")
268 | 
269 |     def _paraphrase_attribute_names(
270 |             self,
271 |             attributes: List[Attribute],
272 |             interaction_callback: BaseInteractionCallback,
273 |             status_callback: BaseStatusCallback,
274 |             statistics: Statistics
275 |     ) -> None:
276 |         statistics["num_attributes"] = len(attributes)
277 | 
278 |         for ix, attribute in enumerate(attributes):
279 |             self._use_status_callback_for_label_paraphrasers(status_callback, "attribute names", ix, len(attributes))
280 | 
281 |             # tokenize the label
282 |             tokens: List[str] = [attribute.name]
283 |             for splitter in self._splitters:
284 |                 new_tokens: List[str] = []
285 |                 for token in tokens:
286 |                     new_tokens += token.split(splitter)
287 |                 tokens: List[str] = new_tokens
288 | 
289 |             # lowercase the tokens
290 |             if self._do_lowercase:
291 |                 tokens: List[str] = [token.lower() for token in tokens]
292 | 
293 |             attribute[NaturalLanguageLabelSignal] = NaturalLanguageLabelSignal(" ".join(tokens))
294 |             if " ".join(tokens) == attribute.name:
295 |                 statistics["num_label_unchanged"] += 1
296 |             else:
297 |                 statistics["num_label_changed"] += 1
298 | 
299 |     def to_config(self) -> Dict[str, Any]:
300 |         return {
301 |             "identifier": self.identifier,
302 |             "do_lowercase": self._do_lowercase,
303 |             "splitters": self._splitters
304 |         }
305 | 
306 |     @classmethod
307 |     def from_config(cls, config: Dict[str, Any]) -> "SplitAttributeNameLabelParaphraser":
308 |         return cls(config["do_lowercase"], config["splitters"])
309 | 
--------------------------------------------------------------------------------
/wannadb/preprocessing/normalization.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import logging
3 | from typing import Any, Dict, List
4 | 
5 | from wannadb.configuration import register_configurable_element, BasePipelineElement
6 | from wannadb.data.data import DocumentBase, InformationNugget
7 | from wannadb.data.signals import LabelSignal, ValueSignal
8 | from wannadb.interaction import BaseInteractionCallback
9 | from wannadb.statistics import Statistics
10 | from wannadb.status import BaseStatusCallback
11 | 
12 | logger: logging.Logger = logging.getLogger(__name__)
13 | 
14 | 
15 | class BaseNormalizer(BasePipelineElement, abc.ABC):
16 |     """
17 |     Base class for all normalizers.
18 | 
19 |     Normalizers derive the value of an information nugget from its mention text.
20 |     """
21 |     identifier: str = "BaseNormalizer"
22 | 
23 | 
24 | ########################################################################################################################
25 | # actual normalizers
26 | ########################################################################################################################
27 | 
28 | 
29 | @register_configurable_element
30 | class CopyNormalizer(BaseNormalizer):
31 |     """
32 |     Normalizer that simply uses the nuggets' mention texts as their values if the value signal does not already exist.
33 | """ 34 | identifier: str = "CopyNormalizer" 35 | 36 | required_signal_identifiers: Dict[str, List[str]] = { 37 | "nuggets": [], 38 | "attributes": [], 39 | "documents": [] 40 | } 41 | 42 | generated_signal_identifiers: Dict[str, List[str]] = { 43 | "nuggets": [ValueSignal.identifier], 44 | "attributes": [], 45 | "documents": [] 46 | } 47 | 48 | def __init__(self) -> None: 49 | super(CopyNormalizer, self).__init__() 50 | 51 | logger.debug(f"Initialized '{self.identifier}'.") 52 | 53 | def _call( 54 | self, 55 | document_base: DocumentBase, 56 | interaction_callback: BaseInteractionCallback, 57 | status_callback: BaseStatusCallback, 58 | statistics: Statistics 59 | ) -> None: 60 | nuggets: List[InformationNugget] = document_base.nuggets # document_base.nuggets has overhead 61 | statistics["num_nuggets"] = len(nuggets) 62 | 63 | for ix, nugget in enumerate(nuggets): 64 | self._use_status_callback(status_callback, ix, len(nuggets)) 65 | 66 | if ValueSignal.identifier not in nugget.signals.keys(): 67 | statistics["num_value_set"] += 1 68 | nugget[ValueSignal] = ValueSignal(nugget.text) 69 | else: 70 | statistics["num_value_already_exists"] += 1 71 | 72 | def to_config(self) -> Dict[str, Any]: 73 | return { 74 | "identifier": self.identifier 75 | } 76 | 77 | @classmethod 78 | def from_config(cls, config: Dict[str, Any]) -> "CopyNormalizer": 79 | return cls() 80 | 81 | 82 | @register_configurable_element 83 | class VerySimpleDateNormalizer(BaseNormalizer): 84 | identifier: str = "VerySimpleDateNormalizer" 85 | 86 | required_signal_identifiers: Dict[str, List[str]] = { 87 | "nuggets": [LabelSignal.identifier], 88 | "attributes": [], 89 | "documents": [] 90 | } 91 | 92 | generated_signal_identifiers: Dict[str, List[str]] = { 93 | "nuggets": [ValueSignal.identifier], 94 | "attributes": [], 95 | "documents": [] 96 | } 97 | 98 | def __init__(self) -> None: 99 | super(VerySimpleDateNormalizer, self).__init__() 100 | 101 | logger.debug(f"Initialized '{self.identifier}'.") 102 | 103 | def _call( 104 | self, 105 | document_base: DocumentBase, 106 | interaction_callback: BaseInteractionCallback, 107 | status_callback: BaseStatusCallback, 108 | statistics: Statistics 109 | ) -> None: 110 | nuggets: List[InformationNugget] = document_base.nuggets # document_base.nuggets has overhead 111 | statistics["num_nuggets"] = len(nuggets) 112 | statistics["date_value_failed"] = set() 113 | 114 | for ix, nugget in enumerate(nuggets): 115 | self._use_status_callback(status_callback, ix, len(nuggets)) 116 | 117 | if nugget[LabelSignal] == "DATE": 118 | year = nugget.text[-4:] 119 | 120 | month_mapping = { 121 | "January": "01", 122 | "February": "02", 123 | "March": "03", 124 | "April": "04", 125 | "May": "05", 126 | "June": "06", 127 | "July": "07", 128 | "August": "08", 129 | "September": "09", 130 | "October": "10", 131 | "November": "11", 132 | "December": "12" 133 | } 134 | if " " in nugget.text: 135 | month = nugget.text[:nugget.text.index(" ")] 136 | if month in month_mapping.keys() and " " in nugget.text and "," in nugget.text: 137 | month = month_mapping[month] 138 | day = nugget.text[nugget.text.index(" ") + 1:nugget.text.index(",")] 139 | day = day.rjust(2, "0") 140 | nugget[ValueSignal] = ValueSignal(f"{year}-{month}-{day}") 141 | statistics["num_date_value_set"] += 1 142 | continue 143 | 144 | nugget[ValueSignal] = ValueSignal(nugget.text) 145 | statistics["num_date_value_failed"] += 1 146 | statistics["date_value_failed"].add(nugget.text) 147 | else: 148 | nugget[ValueSignal] = ValueSignal(nugget.text) 
149 | statistics["num_other_value_copied"] += 1 150 | 151 | def to_config(self) -> Dict[str, Any]: 152 | return { 153 | "identifier": self.identifier 154 | } 155 | 156 | @classmethod 157 | def from_config(cls, config: Dict[str, Any]) -> "VerySimpleDateNormalizer": 158 | return cls() 159 | -------------------------------------------------------------------------------- /wannadb/preprocessing/other_processing.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict, List, Any 3 | 4 | from wannadb.configuration import BasePipelineElement, register_configurable_element 5 | from wannadb.data.data import DocumentBase, InformationNugget 6 | from wannadb.data.signals import CachedContextSentenceSignal, \ 7 | SentenceStartCharsSignal 8 | from wannadb.interaction import BaseInteractionCallback 9 | from wannadb.statistics import Statistics 10 | from wannadb.status import BaseStatusCallback 11 | 12 | logger: logging.Logger = logging.getLogger(__name__) 13 | 14 | 15 | @register_configurable_element 16 | class ContextSentenceCacher(BasePipelineElement): 17 | """Caches a nugget's context sentence.""" 18 | 19 | identifier: str = "ContextSentenceCacher" 20 | 21 | required_signal_identifiers: Dict[str, List[str]] = { 22 | "nuggets": [], 23 | "attributes": [], 24 | "documents": [SentenceStartCharsSignal.identifier] 25 | } 26 | 27 | generated_signal_identifiers: Dict[str, List[str]] = { 28 | "nuggets": [CachedContextSentenceSignal.identifier], 29 | "attributes": [], 30 | "documents": [] 31 | } 32 | 33 | def __init__(self): 34 | """Initialize the ContextSentenceCacher.""" 35 | super(ContextSentenceCacher, self).__init__() 36 | logger.debug(f"Initialized '{self.identifier}'.") 37 | 38 | def _call(self, document_base: DocumentBase, interaction_callback: BaseInteractionCallback, 39 | status_callback: BaseStatusCallback, statistics: Statistics) -> None: 40 | nuggets: List[InformationNugget] = document_base.nuggets 41 | statistics["num_nuggets"] = len(nuggets) 42 | 43 | for nugget in nuggets: 44 | sent_start_chars: List[int] = nugget.document[SentenceStartCharsSignal] 45 | context_start_char: int = 0 46 | context_end_char: int = 0 47 | for ix, sent_start_char in enumerate(sent_start_chars): 48 | if sent_start_char > nugget.start_char: 49 | if ix == 0: 50 | context_start_char: int = 0 51 | context_end_char: int = sent_start_char 52 | statistics["num_context_sentence_before_first_sentence"] += 1 53 | break 54 | else: 55 | context_start_char: int = sent_start_chars[ix - 1] 56 | context_end_char: int = sent_start_char 57 | statistics["num_context_sentence_is_first_or_inner_sentence"] += 1 58 | break 59 | else: 60 | if sent_start_chars != []: 61 | context_start_char: int = sent_start_chars[-1] 62 | context_end_char: int = len(nugget.document.text) 63 | statistics["num_context_sentence_is_final_sentence"] += 1 64 | 65 | context_sentence: str = nugget.document.text[context_start_char:context_end_char] 66 | start_in_context: int = nugget.start_char - context_start_char 67 | end_in_context: int = nugget.end_char - context_start_char 68 | 69 | nugget[CachedContextSentenceSignal] = CachedContextSentenceSignal({ 70 | "text": context_sentence, 71 | "start_char": start_in_context, 72 | "end_char": end_in_context 73 | }) 74 | 75 | def to_config(self) -> Dict[str, Any]: 76 | return { 77 | "identifier": self.identifier 78 | } 79 | 80 | @classmethod 81 | def from_config(cls, config: Dict[str, Any]) -> "ContextSentenceCacher": 82 | return cls() 83 | 
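84 | 
85 | # Illustrative usage sketch (added commentary, not part of the original module):
86 | # how a pipeline element such as the ContextSentenceCacher is typically applied
87 | # to a document base. The Document/DocumentBase construction below and the
88 | # EmptyInteractionCallback import are assumptions based on this repository's
89 | # conventions, not verbatim project code.
90 | if __name__ == "__main__":
91 |     from wannadb.data.data import Attribute, Document
92 |     from wannadb.interaction import EmptyInteractionCallback
93 |     from wannadb.status import EmptyStatusCallback
94 | 
95 |     document = Document("example", "WannaDB was demoed in 2022. It runs locally.")
96 |     # sentence start offsets are normally produced by the preprocessing pipeline
97 |     document[SentenceStartCharsSignal] = SentenceStartCharsSignal([0, 28])
98 |     document.nuggets.append(InformationNugget(document, 22, 26))  # the span "2022"
99 | 
100 |     # pipeline elements are applied by calling them on a document base
101 |     ContextSentenceCacher()(DocumentBase([document], [Attribute("year")]),
102 |                             EmptyInteractionCallback(), EmptyStatusCallback(),
103 |                             Statistics(do_collect=True))
104 |     # the nugget now carries its surrounding sentence as a cached signal
105 |     print(document.nuggets[0][CachedContextSentenceSignal])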
-------------------------------------------------------------------------------- /wannadb/querying/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataManagementLab/wannadb/05b36dbf87a66d0e022b4638d764cbdbc8e2c69d/wannadb/querying/__init__.py -------------------------------------------------------------------------------- /wannadb/querying/grouping.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import itertools 3 | import logging 4 | from typing import Dict, Any, List, Tuple, Set 5 | 6 | import numpy as np 7 | 8 | from wannadb.configuration import BasePipelineElement, register_configurable_element 9 | from wannadb.data.data import DocumentBase, Attribute, InformationNugget 10 | from wannadb.interaction import BaseInteractionCallback 11 | from wannadb.matching.distance import BaseDistance 12 | from wannadb.statistics import Statistics 13 | from wannadb.status import BaseStatusCallback 14 | 15 | logger: logging.Logger = logging.getLogger(__name__) 16 | 17 | 18 | class BaseGrouper(BasePipelineElement, abc.ABC): 19 | """ 20 | Base class for all groupers. 21 | 22 | A grouper can group rows based on their matches for an attribute. 23 | """ 24 | identifier: str = "BaseGrouper" 25 | 26 | 27 | @register_configurable_element 28 | class MergeGrouper(BaseGrouper): 29 | """Grouper that works by interactively merging groups based on the distances between the nuggets.""" 30 | identifier: str = "MergeGrouper" 31 | 32 | required_signal_identifiers: Dict[str, List[str]] = { # TODO 33 | "nuggets": [], 34 | "attributes": [], 35 | "documents": [] 36 | } 37 | 38 | generated_signal_identifiers: Dict[str, List[str]] = { 39 | "nuggets": [], 40 | "attributes": [], 41 | "documents": [] 42 | } 43 | 44 | def __init__( 45 | self, 46 | distance: BaseDistance, 47 | max_tries_no_merge: int, 48 | skip: int, 49 | automatically_merge_same_surface_form: bool 50 | ) -> None: 51 | """ 52 | Initialize the MergeGrouper. 
53 | 
54 |         :param distance: distance function
55 |         :param max_tries_no_merge: number of consecutive 'not the same cluster' feedback answers before the merging stops
56 |         :param skip: number of closest cluster pairs to skip over (merging them by guessing) when choosing the pair that is presented for feedback
57 |         :param automatically_merge_same_surface_form: whether to automatically merge nuggets with the same surface form
58 |         """
59 |         super(MergeGrouper, self).__init__()
60 |         self._distance: BaseDistance = distance
61 |         self._max_tries_no_merge: int = max_tries_no_merge
62 |         self._skip: int = skip
63 |         self._automatically_merge_same_surface_form: bool = automatically_merge_same_surface_form
64 | 
65 |         # add signals required by the distance function to the signals required by the matcher
66 |         self._add_required_signal_identifiers(self._distance.required_signal_identifiers)
67 | 
68 |         logger.debug(f"Initialized {self.identifier}.")
69 | 
70 |     def _call(
71 |             self,
72 |             document_base: DocumentBase,
73 |             interaction_callback: BaseInteractionCallback,
74 |             status_callback: BaseStatusCallback,
75 |             statistics: Statistics
76 |     ) -> None:
77 | 
78 |         # decide on attribute to group
79 |         feedback_result: Dict[str, Any] = interaction_callback(
80 |             self.identifier,
81 |             {
82 |                 "request-name": "get-attribute",
83 |                 "attributes": document_base.attributes
84 |             }
85 |         )
86 | 
87 |         attribute: Attribute = feedback_result["attribute"]
88 |         statistics["attribute_name"] = attribute.name
89 | 
90 |         # start clustering: every nugget is in its own cluster
91 |         nuggets: List[InformationNugget] = []
92 |         for matching_nuggets in document_base.get_column_for_attribute(attribute):
93 |             if matching_nuggets is None:
94 |                 logger.error(f"Document does not know of attribute '{attribute.name}'!")
95 |                 assert False, f"Document does not know of attribute '{attribute.name}'!"
96 |             else:
97 |                 if matching_nuggets != []:
98 |                     nuggets.append(matching_nuggets[0])  # only consider the first match
99 | 
100 |         clusters: Dict[int, List[InformationNugget]] = {
101 |             idx: [nugget] for idx, nugget in enumerate(nuggets)
102 |         }
103 |         confirmed_as_distinct: Set[Tuple[int, int]] = set()
104 | 
105 |         inter_cluster_distances: np.ndarray = self._distance.compute_distances(nuggets, nuggets, statistics["distance"])
106 | 
107 |         def merge_clusters(
108 |                 index_a: int,
109 |                 index_b: int,
110 |                 clusters: Dict[int, List[InformationNugget]],
111 |                 inter_cluster_distances: np.ndarray,
112 |                 confirmed_as_distinct: Set[Tuple[int, int]]
113 |         ):
114 |             """Merges the two clusters into the cluster with index_a."""
115 |             clusters[index_a] += clusters[index_b]
116 |             del clusters[index_b]
117 | 
118 |             # replace index_b with index_a in confirmed_as_distinct
119 |             new_confirmed_as_distinct: Set[Tuple[int, int]] = set()
120 |             for ix_a, ix_b in confirmed_as_distinct:
121 |                 if ix_a == index_b:
122 |                     new_confirmed_as_distinct.add((index_a, ix_b))
123 |                 elif ix_b == index_b:
124 |                     new_confirmed_as_distinct.add((ix_a, index_a))
125 |                 else:
126 |                     new_confirmed_as_distinct.add((ix_a, ix_b))
127 |             confirmed_as_distinct = new_confirmed_as_distinct
128 | 
129 |             # choose lower distances as distances for index_a
130 |             for ix_b in clusters.keys():
131 |                 min_val = min(inter_cluster_distances[index_a, ix_b], inter_cluster_distances[index_b, ix_b])
132 |                 inter_cluster_distances[index_a, ix_b] = min_val
133 |                 inter_cluster_distances[ix_b, index_a] = min_val
134 | 
135 |             return clusters, inter_cluster_distances, confirmed_as_distinct
136 | 
137 |         # merge by surface form
138 |         if self._automatically_merge_same_surface_form:
139 |             for ix_a, ix_b in itertools.product(clusters.keys(), clusters.keys()):
140 | 
                if ix_a != ix_b and ix_a in clusters.keys() and ix_b in clusters.keys():
141 |                     if clusters[ix_a][0].text == clusters[ix_b][0].text:
142 |                         clusters, inter_cluster_distances, confirmed_as_distinct = merge_clusters(
143 |                             ix_a, ix_b, clusters, inter_cluster_distances, confirmed_as_distinct
144 |                         )
145 |                         statistics["num_merges_same_surface_form"] += 1
146 | 
147 |         # merge interactively
148 |         num_not_same_cluster: int = 0
149 |         current_skip: int = self._skip
150 |         while len(clusters.keys()) > 1 and num_not_same_cluster < self._max_tries_no_merge:
151 | 
152 |             # determine the pair to present to the user for feedback
153 |             pairs_and_distances: List[Tuple[Tuple[int, int], float]] = []
154 |             for ix_a, ix_b in itertools.product(clusters.keys(), clusters.keys()):
155 |                 if ix_a < ix_b and (ix_a, ix_b) not in confirmed_as_distinct and \
156 |                         (ix_b, ix_a) not in confirmed_as_distinct:
157 |                     pairs_and_distances.append(((ix_a, ix_b), inter_cluster_distances[ix_a, ix_b]))
158 | 
159 |             if pairs_and_distances == []:
160 |                 logger.info("No more clusters can be merged!")
161 |                 break
162 | 
163 |             pairs_and_distances = list(sorted(pairs_and_distances, key=lambda x: x[1]))
164 |             right: int = min(current_skip + 1, len(pairs_and_distances))
165 |             pairs: List[Tuple[int, int]] = [pair_and_distance[0] for pair_and_distance in pairs_and_distances[:right]]
166 |             idx_a, idx_b = pairs[-1]
167 | 
168 |             # ask the user for feedback
169 |             statistics["num_feedback"] += 1
170 |             statistics[f"num_feedback_at_skip_{current_skip}"] += 1
171 |             feedback_result: Dict[str, Any] = interaction_callback(
172 |                 self.identifier,
173 |                 {
174 |                     "request-name": "same-cluster-feedback",
175 |                     "cluster-1": clusters[idx_a],
176 |                     "cluster-2": clusters[idx_b],
177 |                     "inter-cluster-distance": inter_cluster_distances[idx_a, idx_b],
178 |                     "clusters": list(clusters.values())
179 |                 }
180 |             )
181 | 
182 |             num_merged: int = 0
183 |             if feedback_result["feedback"]:  # feedback ==> the two clusters are the same
184 |                 statistics["num_feedback_same_cluster"] += 1
185 |                 statistics[f"num_feedback_same_cluster_at_skip_{current_skip}"] += 1
186 |                 statistics["num_confirmed_merges"] += 1
187 |                 num_not_same_cluster = 0
188 |                 current_skip = self._skip
189 | 
190 |                 confirmed = True
191 |                 while pairs != []:
192 |                     idx_a, idx_b = pairs[-1]
193 | 
194 |                     if (idx_a, idx_b) not in confirmed_as_distinct and (idx_b, idx_a) not in confirmed_as_distinct:
195 |                         # first merge is confirmed, rest is guessed
196 |                         if not confirmed:
197 |                             statistics["num_guessed_merges"] += 1
198 |                         confirmed = False
199 | 
200 |                         # merge the two clusters into idx_a
201 |                         num_merged += 1
202 |                         clusters, inter_cluster_distances, confirmed_as_distinct = merge_clusters(
203 |                             idx_a, idx_b, clusters, inter_cluster_distances, confirmed_as_distinct
204 |                         )
205 | 
206 |                         # replace idx_b with idx_a in pairs and remove current pair
207 |                         new_pairs: List[Tuple[int, int]] = []
208 |                         for ix_a, ix_b in pairs[:-1]:
209 |                             if ix_a == idx_b:
210 |                                 if ix_b != idx_a:
211 |                                     new_pairs.append((idx_a, ix_b))
212 |                             elif ix_b == idx_b:
213 |                                 if ix_a != idx_a:
214 |                                     new_pairs.append((ix_a, idx_a))
215 |                             else:
216 |                                 new_pairs.append((ix_a, ix_b))
217 |                         pairs = new_pairs
218 |                     else:  # guessed merge blocked because the clusters were already confirmed as distinct
219 |                         statistics["num_blocked_guessed_merges"] += 1
220 |                         pairs = pairs[:-1]  # discard the blocked pair so the loop can terminate
221 |             else:
222 |                 statistics["num_feedback_not_same_cluster"] += 1
223 |                 statistics[f"num_feedback_not_same_cluster_at_skip_{current_skip}"] += 1
224 |                 num_not_same_cluster += 1
225 |                 current_skip = current_skip // 2
226 | 
confirmed_as_distinct.add((idx_a, idx_b)) 227 | 228 | logger.info(f"Number of clusters merged in this step: {num_merged}") 229 | logger.info(f"Number of remaining clusters: {len(clusters.keys())}") 230 | 231 | feedback_result: Dict[str, Any] = interaction_callback( 232 | self.identifier, 233 | { 234 | "request-name": "output-clusters", 235 | "clusters": clusters 236 | } 237 | ) 238 | 239 | def to_config(self) -> Dict[str, Any]: 240 | return { 241 | "identifier": self.identifier, 242 | "distance": self._distance.to_config(), 243 | "max_tries_no_merge": self._max_tries_no_merge, 244 | "skip": self._skip, 245 | "automatically_merge_same_surface_form": self._automatically_merge_same_surface_form 246 | } 247 | 248 | @classmethod 249 | def from_config(cls, config: Dict[str, Any]) -> "MergeGrouper": 250 | distance: BaseDistance = BaseDistance.from_config(config["distance"]) 251 | return cls(distance, config["max_tries_no_merge"], config["skip"], 252 | config["automatically_merge_same_surface_form"]) 253 | -------------------------------------------------------------------------------- /wannadb/statistics.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from typing import Any, Dict, List, Optional, Union 4 | 5 | logger: logging.Logger = logging.getLogger(__name__) 6 | 7 | 8 | class Statistics: 9 | """ 10 | Statistics to collect information during execution. 11 | 12 | A statistics object allows the pipeline and its pipeline elements to record information during their execution. The 13 | statistics object is provided to the pipeline or pipeline element when applying it to the document base. 14 | 15 | In contrast to a basic Python dictionary, the statistics object can easily be configured as to whether it actually 16 | should record any information. This unclutters the code required in the pipeline / pipeline element implementation. 17 | Furthermore, it can be used as a counter for integer and float values without requiring initialization. 18 | """ 19 | 20 | def __init__(self, do_collect: bool = True) -> None: 21 | """ 22 | Initialize the Statistics. 
23 | 24 | :param do_collect: whether to actually collect statistics 25 | """ 26 | super(Statistics, self).__init__() 27 | self._do_collect: bool = do_collect 28 | if self._do_collect: 29 | self._entries: Optional[Dict[str, Union[Statistics, Any]]] = {} 30 | else: 31 | self._entries: Optional[Dict[str, Union[Statistics, Any]]] = None 32 | 33 | def __str__(self) -> str: 34 | if self._do_collect: 35 | return json.dumps(self.to_serializable(), indent=4) 36 | else: 37 | return "not-collecting-statistics" 38 | 39 | def __repr__(self) -> str: 40 | return f"Statistics({self._do_collect})" 41 | 42 | def __eq__(self, other: Any) -> bool: 43 | return isinstance(other, Statistics) and \ 44 | self._do_collect == other._do_collect and self._entries == other._entries 45 | 46 | def __getitem__(self, item: str) -> Union["Statistics", Any]: 47 | if self._do_collect: 48 | if item not in self._entries.keys(): 49 | self._entries[item] = Statistics(True) 50 | return self._entries[item] 51 | else: 52 | return Statistics(False) 53 | 54 | def __setitem__(self, key: str, value: Union["Statistics", Any]) -> None: 55 | if self._do_collect: 56 | self._entries[key] = value 57 | 58 | def __iadd__(self, other: Union[int, float]) -> Union[int, float]: 59 | return other 60 | 61 | def __isub__(self, other: Union[int, float]) -> Union[int, float]: 62 | return -other 63 | 64 | def add(self, other): 65 | # dummy method in case this is a no-collect statistics object that replaces a set 66 | pass 67 | 68 | def append(self, other): 69 | # dummy method in case this is a no-collect statistics object that replaces a list 70 | pass 71 | 72 | def all_keys(self) -> List[str]: 73 | return list(self._entries.keys()) 74 | 75 | def all_values(self) -> List[Union["Statistics", Any]]: 76 | return list(self._entries.values()) 77 | 78 | def to_serializable(self) -> Dict[str, Any]: 79 | if self._do_collect: 80 | d: Dict[str, Union[Dict, Any]] = {} 81 | for key, entry in self._entries.items(): 82 | if isinstance(entry, Statistics): 83 | d[key] = entry.to_serializable() 84 | elif isinstance(entry, set): 85 | d[key] = list(entry) 86 | else: 87 | d[key] = entry 88 | return d 89 | else: 90 | return {"message": "not-collecting-statistics"} 91 | -------------------------------------------------------------------------------- /wannadb/status.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import logging 3 | from typing import Callable 4 | 5 | logger: logging.Logger = logging.getLogger(__name__) 6 | 7 | 8 | class BaseStatusCallback(abc.ABC): 9 | """ 10 | Base class for all status callbacks. 11 | 12 | A status callback allows the pipeline and its pipeline elements to convey status updates to the user interface. The 13 | status callback is provided to the pipeline or pipeline element when applying it to the document base. The pipeline 14 | or pipeline element calls ('__call__') the status callback to convey a status update. 15 | 16 | The status information comprises a message string and a float progress indicator. 17 | """ 18 | 19 | def __call__(self, message: str, progress: float) -> None: 20 | """ 21 | Convey a status update from the pipeline or pipeline element to the user interface 22 | 23 | This method is called by the pipeline element and calls the _call method that contains the actual implementation 24 | of the status callback. 
25 | 
26 |         :param message: status message
27 |         :param progress: progress indicator (either between 0.0 and 1.0 or -1 if the progress is unclear)
28 |         """
29 |         if progress == -1:
30 |             logger.info(f"{message} ~%")
31 |         else:
32 |             logger.info(f"{message} {round(progress * 100)}%")
33 | 
34 |         self._call(message, progress)
35 | 
36 |     @abc.abstractmethod
37 |     def _call(self, message: str, progress: float) -> None:
38 |         """
39 |         Convey a status update from the pipeline or pipeline element to the user interface
40 | 
41 |         This method is overwritten by the actual status callbacks and contains their implementation.
42 | 
43 |         :param message: status message
44 |         :param progress: progress indicator (either between 0.0 and 1.0 or -1 if the progress is unclear)
45 |         """
46 |         raise NotImplementedError
47 | 
48 | 
49 | class StatusCallback(BaseStatusCallback):
50 |     """Status callback that is initialized with a callback function."""
51 | 
52 |     def __init__(self, callback_fn: Callable[[str, float], None]):
53 |         """
54 |         Initialize the status callback.
55 | 
56 |         :param callback_fn: callback function that is called whenever the status callback is called
57 |         """
58 |         self._callback_fn: Callable[[str, float], None] = callback_fn
59 | 
60 |     def _call(self, message: str, progress: float) -> None:
61 |         return self._callback_fn(message, progress)
62 | 
63 | 
64 | class EmptyStatusCallback(BaseStatusCallback):
65 |     """Status callback that does nothing whenever it is called."""
66 | 
67 |     def _call(self, message: str, progress: float) -> None:
68 |         pass
69 | 
--------------------------------------------------------------------------------
/wannadb_parsql/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataManagementLab/wannadb/05b36dbf87a66d0e022b4638d764cbdbc8e2c69d/wannadb_parsql/__init__.py
--------------------------------------------------------------------------------
/wannadb_parsql/cache_db.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sqlite3
3 | from pathlib import Path
4 | from sqlite3 import Error
5 | from typing import List, Any, Generator
6 | 
7 | import pandas as pd
8 | 
9 | from .parsql import ColumnToken
10 | from .rewrite import DOCUMENT_ID
11 | from .sql_tokens import UNKNOWN_TYPE
12 | 
13 | 
14 | def create_connection(db_file):
15 |     """ create a database connection to the SQLite database
16 |         specified by db_file
17 |     :param db_file: database file
18 |     :return: Connection object or None
19 |     """
20 |     conn = None
21 |     try:
22 |         conn = sqlite3.connect(db_file, check_same_thread=False)
23 |         conn.row_factory = sqlite3.Row
24 |         return conn
25 |     except Error as e:
26 |         print(e)
27 | 
28 |     return conn
29 | 
30 | 
31 | def alter_table(conn, entry):
32 |     if entry["type"] is None:
33 |         entry["type"] = 'text'
34 |     sql = ''' ALTER TABLE {} ADD COLUMN {} {}'''.format(entry["table"], entry["attribute"], entry["type"])
35 |     cur = conn.cursor()
36 |     cur.execute(sql)
37 |     conn.commit()
38 |     return cur.lastrowid
39 | 
40 | 
41 | class SQLiteCacheDB:
42 | 
43 |     def __init__(self, db_file="wannadb_cache.db"):
44 |         self.db_file = db_file
45 |         self.conn = create_connection(self.db_file)
46 | 
47 |     def existing_tables(self):
48 |         c = self.conn.cursor()
49 |         c.execute("select name from sqlite_master where type = 'table';")
50 |         # name of all existing tables
51 |         return [str(row[0]).lower() for row in c.fetchall()]
52 | 
53 |     def table_empty(self, attribute_name):
54 |         c = self.conn.cursor()
55 | 
c.execute(f'SELECT * FROM {attribute_name}') 56 | return c.fetchone() is None 57 | 58 | def create_tables(self, attributes: List[ColumnToken]): 59 | for attribute in attributes: 60 | self.create_table(attribute) 61 | 62 | def create_table_by_name(self, attribute_name): 63 | self.create_table(ColumnToken(attribute_name, UNKNOWN_TYPE)) 64 | 65 | def create_table(self, attribute: ColumnToken): 66 | if self.conn is None: 67 | raise EnvironmentError("No database connection found.") 68 | data = [] 69 | 70 | data.append({"table": attribute.name, "attribute": "value", "type": attribute.datatype}) 71 | data.append({"table": attribute.name, "attribute": DOCUMENT_ID, "type": "integer"}) 72 | 73 | tableName = attribute.name 74 | 75 | try: 76 | c = self.conn.cursor() 77 | c.execute( 78 | ''' SELECT count(name) FROM sqlite_master WHERE type='table' AND name='{}' '''.format(tableName)) 79 | except Error as e: 80 | print(e) 81 | 82 | # if the count is 1, then table exists and does not need to be created again 83 | if c.fetchone()[0] != 1: 84 | 85 | sql_create_table = """ CREATE TABLE IF NOT EXISTS {} ( 86 | id integer PRIMARY KEY 87 | )""".format(tableName) 88 | 89 | try: 90 | c = self.conn.cursor() 91 | c.execute(sql_create_table) 92 | except Error as e: 93 | print(e) 94 | 95 | with self.conn: 96 | for entry in data: 97 | # Insert the attributes 98 | alter_table(self.conn, entry) 99 | 100 | def create_input_docs_table(self, table_name, documents): 101 | self.create_table_by_name(table_name) 102 | self.store_many(table_name, ((i, Path(doc.name).name) for i, doc in enumerate(documents))) 103 | 104 | def delete_tables(self, attributes: List[ColumnToken]): 105 | c = self.conn.cursor() 106 | for attribute in attributes: 107 | c.execute(''' DELETE FROM {} '''.format(attribute.name)) 108 | 109 | def delete_table(self, attribute): 110 | self.conn.execute(f"DROP TABLE IF EXISTS {attribute}") 111 | 112 | def execute_queries(self, *queries) -> List[pd.DataFrame]: 113 | res = [] 114 | 115 | for query in queries: 116 | cur = self.conn.cursor() 117 | cur.execute(query) 118 | 119 | data = [] 120 | for row in cur.fetchall(): 121 | entry = {} 122 | for key in row.keys(): 123 | entry[key] = row[key] 124 | data.append(entry) 125 | res.append(pd.DataFrame(data)) 126 | 127 | return res 128 | 129 | def store_many(self, attr, iter: Generator[tuple[int, Any], None, None]): 130 | self.conn.executemany(f"INSERT INTO {attr}({DOCUMENT_ID}, value) VALUES (?, ?)", iter) 131 | 132 | def store_and_split_entry(self, data): 133 | for doc_idx, item in enumerate(data): 134 | 135 | question_marks = ",".join(["?"] * 2) 136 | attrs = [] 137 | val = [] 138 | for attribute, value in item.items(): 139 | attrs = [DOCUMENT_ID, "value"] 140 | val = [doc_idx, value] 141 | sql = ' INSERT INTO ' + attribute + "(" + ",".join(attrs) + ")" + ' VALUES(' + question_marks + ')' 142 | 143 | cur = self.conn.cursor() 144 | cur.execute(sql, val) 145 | self.conn.commit() 146 | 147 | def drop_all_and_reconnect(self): 148 | self.conn.close() 149 | os.remove(self.db_file) 150 | # creating a new connection will recreate the DB file 151 | self.conn = create_connection(self.db_file) 152 | -------------------------------------------------------------------------------- /wannadb_parsql/rewrite.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union, Tuple 2 | 3 | from wannadb_parsql.parsql import ColumnToken, SQLGroupType, SQLStatement, SQLToken, SQLTokenGroup, Parser 4 | 5 | DOCUMENT_ID = 
"doc_id" 6 | 7 | 8 | def rewrite_query(columns: List[ColumnToken], parsed: SQLStatement) -> Tuple[List[ColumnToken], str]: 9 | """Rewrites a given user-specified SQL query to a valid SQL query that can be executed on the WannaDB cache DB. 10 | Note that this method rewrites the given query in-place. 11 | 12 | :param columns: List of tokens that represents the attributes in the query. 13 | :type columns: List[ColumnToken] 14 | :param parsed: SQL query as returned by Parser.parse(). 15 | :type parsed: SQLStatement 16 | :return: Tuple containing the original (not rewritten) columns and the rewritten SQL query. 17 | :rtype: Union[List[ColumnToken], str] 18 | """ 19 | # Preserve original column names 20 | attributes = list(map(lambda column: column.name, columns)) 21 | # Rewrite FROM clause 22 | _rewrite_from_clause(parsed, attributes) 23 | # Rewrite column names: "SELECT name ..." -> "SELECT name.value as name ..." 24 | _rewrite_columns(parsed) 25 | # Add a ";" at the very end if not present 26 | if ";" not in parsed.groups[-1].tokens[-1].name: 27 | parsed.groups[-1].tokens.append(SQLToken(";")) 28 | 29 | return columns, str(parsed) 30 | 31 | 32 | def _rewrite_from_clause(parsed, attributes): 33 | if not isinstance(parsed, SQLStatement): 34 | raise ValueError("Given argument is not a SQLStatement.") 35 | 36 | found_from = False 37 | for group in parsed.groups: 38 | if group.group_type == SQLGroupType.FROM: 39 | found_from = True 40 | subquery = _get_subquery(group) 41 | if subquery is not None: 42 | _rewrite_from_clause(subquery, attributes) 43 | else: 44 | group.tokens = _build_new_from_clause(attributes) 45 | 46 | if not found_from: 47 | # In case there was only a select clause, it might end with a ";". Make sure to remove it. 48 | if len(parsed.groups) == 1 and ";" in parsed.groups[0].tokens[-1].name: 49 | parsed.groups[0].tokens = parsed.groups[0].tokens[:-1] 50 | tokens = _build_new_from_clause(attributes) 51 | parsed.groups.insert(1, SQLTokenGroup(tokens, SQLGroupType.FROM)) 52 | 53 | 54 | def _rewrite_columns(parsed): 55 | if not isinstance(parsed, SQLStatement): 56 | raise ValueError("Given argument is not a SQLStatement.") 57 | 58 | # First look if we have a FROM clause that might hold a subquery. 59 | from_clause = next(filter(lambda group: group.group_type == SQLGroupType.FROM, parsed.groups), None) 60 | if from_clause is not None: 61 | subquery = _get_subquery(from_clause) 62 | else: 63 | subquery = None 64 | 65 | # We only want to rewrite the subquery furthest down in the hierarchy. 
66 |     if subquery is not None:
67 |         _rewrite_columns(subquery)
68 |     else:
69 |         for group in parsed.groups:
70 |             new_tokens = []
71 |             for token in group.tokens:
72 |                 if isinstance(token, ColumnToken):
73 |                     # If the column appears within a SELECT clause, we need to preserve the original name via "as "
74 |                     if group.group_type == SQLGroupType.SELECT:
75 |                         new_name = f"{token.name}.value as {token.name}"
76 |                     else:
77 |                         new_name = f"{token.name}.value"
78 |                     new_tokens.append(ColumnToken(new_name, token.datatype))
79 |                 else:
80 |                     new_tokens.append(token)
81 |             group.tokens = new_tokens
82 | 
83 | 
84 | def _get_subquery(group: SQLTokenGroup) -> Optional[SQLStatement]:
85 |     for token in group.tokens:
86 |         if isinstance(token, SQLStatement):
87 |             return token
88 |     return None
89 | 
90 | 
91 | def _build_new_from_clause(attributes) -> List[SQLToken]:
92 |     tokens = [SQLToken("FROM")]
93 |     for i, attribute in enumerate(attributes):
94 |         if i != 0:
95 |             # "... LEFT JOIN <attribute> USING (<DOCUMENT_ID>)"
96 |             tokens += [
97 |                 SQLToken("LEFT JOIN"),
98 |                 SQLToken(attribute),
99 |                 SQLToken("USING"),
100 |                 SQLToken("("),
101 |                 SQLToken(DOCUMENT_ID),
102 |                 SQLToken(")"),
103 |             ]
104 |         else:
105 |             tokens.append(SQLToken(attribute))
106 |     return tokens
107 | 
108 | 
109 | def update_query_attribute_list(parsed_query, new_attributes_list: List[str]) -> str:
110 |     _, parsed_attrs_only = Parser().parse(f"SELECT {', '.join(new_attributes_list)}")
111 | 
112 |     for sql_token_group in parsed_query.groups:
113 |         if sql_token_group.group_type == SQLGroupType.SELECT:
114 |             sql_token_group.tokens = parsed_attrs_only.groups[0].tokens
115 | 
116 |     return str(parsed_query).replace(" ,", ",")
117 | 
--------------------------------------------------------------------------------
/wannadb_parsql/sql_tokens.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from typing import List
3 | 
4 | import sqlparse.tokens as T
5 | 
6 | STRING_TYPE = "string"
7 | NUMERIC_TYPE = "numeric"
8 | UNKNOWN_TYPE = "unknown"
9 | FUNC_TYPE_MAP = {
10 |     "AVG": NUMERIC_TYPE,
11 |     "COUNT": NUMERIC_TYPE,
12 |     "MIN": NUMERIC_TYPE,
13 |     "MAX": NUMERIC_TYPE,
14 |     "SUM": NUMERIC_TYPE,
15 | }
16 | TTYPE_MAP = {
17 |     T.String.Single: STRING_TYPE,
18 |     T.Number.Integer: NUMERIC_TYPE,
19 | }
20 | 
21 | 
22 | def translate_datatype(ttype):
23 |     if not ttype:
24 |         return UNKNOWN_TYPE
25 |     if ttype in TTYPE_MAP:
26 |         return TTYPE_MAP[ttype]
27 |     return UNKNOWN_TYPE
28 | 
29 | 
30 | def get_func_type(function_name: str):
31 |     if not function_name:
32 |         return UNKNOWN_TYPE
33 |     function_name = function_name.upper()
34 |     if function_name in FUNC_TYPE_MAP:
35 |         return FUNC_TYPE_MAP[function_name]
36 |     return UNKNOWN_TYPE
37 | 
38 | 
39 | class SQLToken:
40 |     def __init__(self, name: str):
41 |         self._name = name
42 | 
43 |     def __str__(self):
44 |         return self._name
45 | 
46 |     def __repr__(self):
47 |         return f'<SQLToken, {self._name}>'
48 | 
49 |     @staticmethod
50 |     def is_column():
51 |         return False
52 | 
53 |     @property
54 |     def name(self):
55 |         return self._name
56 | 
57 |     @name.setter
58 |     def name(self, new_name):
59 |         self._name = new_name
60 | 
61 | 
62 | class ColumnToken(SQLToken):
63 |     def __init__(self, value: str, datatype: str):
64 |         super(ColumnToken, self).__init__(value)
65 |         self.datatype = datatype
66 | 
67 |     def __str__(self):
68 |         return self._name
69 | 
70 |     def __repr__(self):
71 |         return f"<ColumnToken, {self._name}, {self.datatype}>"
72 | 
73 |     @staticmethod
74 |     def is_column():
75 |         return True
76 | 
77 | 
78 | class SQLGroupType(Enum):
79 |     SELECT = 'SELECT-Group'
80 |     FROM = 'FROM-Group'
81 |     WHERE = 
'WHERE-Group' 82 | GROUP_BY = 'GROUP_BY-Group' 83 | HAVING = 'HAVING-Group' 84 | ORDER_BY = 'ORDER_BY-Group' 85 | LIMIT = 'LIMIT-Group' 86 | SEMICOLON = 'SEMICOLON-Group' 87 | 88 | 89 | class SQLTokenGroup: 90 | def __init__(self, tokens, group_type: SQLGroupType): 91 | self.tokens = tokens 92 | self.group_type = group_type 93 | 94 | def __str__(self): 95 | return ' '.join([str(token) for token in self.tokens]) 96 | 97 | def __repr__(self): 98 | return f'<{self.group_type.value}, {self.tokens}>' 99 | 100 | 101 | class SQLStatement: 102 | def __init__(self, name='Query'): 103 | self.groups: List[SQLTokenGroup] = [] 104 | self.name = name 105 | 106 | def __str__(self): 107 | return ' '.join([str(group) for group in self.groups]) 108 | 109 | def __repr__(self): 110 | return f'<{self.name}, {self.groups}>' 111 | 112 | def __iter__(self): 113 | return iter(self.groups) 114 | 115 | def append(self, group: SQLTokenGroup): 116 | self.groups.append(group) 117 | 118 | def empty(self): 119 | return len(self.groups) == 0 120 | -------------------------------------------------------------------------------- /wannadb_ui/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataManagementLab/wannadb/05b36dbf87a66d0e022b4638d764cbdbc8e2c69d/wannadb_ui/__init__.py -------------------------------------------------------------------------------- /wannadb_ui/common.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | from PyQt6.QtCore import Qt 4 | from PyQt6.QtGui import QFont 5 | from PyQt6.QtWidgets import QWidget, QVBoxLayout, QLabel, QScrollArea, QFrame, QHBoxLayout, QDialog, QPushButton 6 | 7 | # fonts 8 | HEADER_FONT = QFont("Segoe UI", pointSize=20, weight=QFont.Weight.Bold) 9 | SUBHEADER_FONT = QFont("Segoe UI", pointSize=14, weight=QFont.Weight.DemiBold) 10 | LABEL_FONT = QFont("Segoe UI", pointSize=11) 11 | LABEL_FONT_BOLD = QFont("Segoe UI", pointSize=11, weight=QFont.Weight.Bold) 12 | LABEL_FONT_ITALIC = QFont("Segoe UI", pointSize=11, italic=True) 13 | CODE_FONT = QFont("Consolas", pointSize=12) 14 | CODE_FONT_SMALLER = QFont("Consolas", pointSize=10) 15 | CODE_FONT_BOLD = QFont("Consolas", pointSize=12, weight=QFont.Weight.Bold) 16 | MENU_FONT = QFont("Segoe UI", pointSize=11) 17 | STATUS_BAR_FONT = QFont("Segoe UI", pointSize=11) 18 | STATUS_BAR_FONT_BOLD = QFont("Segoe UI", pointSize=11, weight=QFont.Weight.Bold) 19 | BUTTON_FONT = QFont("Segoe UI", pointSize=11) 20 | 21 | # colors 22 | WHITE = "#FFFFFF" 23 | BLACK = "#000000" 24 | 25 | YELLOW = "#FEC306" 26 | LIGHT_YELLOW = "#ffefca" 27 | ORANGE = "#F69200" 28 | LIGHT_ORANGE = "#ffe3c6" 29 | RED = "#DF5327" 30 | LIGHT_RED = "#ffd5c6" 31 | BLUE = "#418AB3" 32 | LIGHT_BLUE = "#d3e1ec" 33 | GREEN = "#A6B727" 34 | LIGHT_GREEN = "#ececcb" 35 | 36 | INPUT_DOCS_COLUMN_NAME = "input_document" 37 | 38 | 39 | class MainWindowContent(QWidget): 40 | 41 | def __init__(self, main_window, header_text): 42 | super(MainWindowContent, self).__init__() 43 | self.main_window = main_window 44 | 45 | self.layout = QVBoxLayout(self) 46 | self.layout.setContentsMargins(10, 5, 10, 5) 47 | self.layout.setAlignment(Qt.AlignmentFlag.AlignTop) 48 | self.layout.setSpacing(20) 49 | 50 | self.top_widget = QWidget() 51 | self.top_widget_layout = QHBoxLayout(self.top_widget) 52 | self.top_widget_layout.setContentsMargins(0, 0, 0, 0) 53 | self.layout.addWidget(self.top_widget) 54 | 55 | self.header = QLabel(header_text) 56 | 
self.header.setFont(HEADER_FONT) 57 | self.top_widget_layout.addWidget(self.header, alignment=Qt.AlignmentFlag.AlignLeft) 58 | 59 | self.controls_widget = QWidget() 60 | self.controls_widget_layout = QHBoxLayout(self.controls_widget) 61 | self.controls_widget_layout.setContentsMargins(0, 0, 0, 0) 62 | self.top_widget_layout.addWidget(self.controls_widget, alignment=Qt.AlignmentFlag.AlignRight) 63 | 64 | @abc.abstractmethod 65 | def enable_input(self): 66 | raise NotImplementedError 67 | 68 | @abc.abstractmethod 69 | def disable_input(self): 70 | raise NotImplementedError 71 | 72 | 73 | class MainWindowContentSection(QWidget): 74 | 75 | def __init__(self, main_window_content, sub_header_text): 76 | super(MainWindowContentSection, self).__init__() 77 | self.main_window_content = main_window_content 78 | 79 | self.layout = QVBoxLayout(self) 80 | self.layout.setContentsMargins(0, 0, 0, 0) 81 | self.layout.setAlignment(Qt.AlignmentFlag.AlignTop) 82 | self.layout.setSpacing(10) 83 | 84 | self.sub_header = QLabel(sub_header_text) 85 | self.sub_header.setFont(SUBHEADER_FONT) 86 | self.layout.addWidget(self.sub_header) 87 | 88 | 89 | class CustomScrollableList(QWidget): 90 | 91 | def __init__(self, parent, item_type, floating_widget=None, orientation="vertical", above_widget=None): 92 | super(CustomScrollableList, self).__init__() 93 | self.parent = parent 94 | self.item_type = item_type 95 | self.floating_widget = floating_widget 96 | self.above_widget = above_widget 97 | 98 | self.layout = QVBoxLayout(self) 99 | self.layout.setContentsMargins(0, 0, 0, 0) 100 | 101 | self.list_widget = QWidget() 102 | if orientation == "vertical": 103 | self.list_layout = QVBoxLayout(self.list_widget) 104 | self.list_layout.setAlignment(Qt.AlignmentFlag.AlignTop) 105 | elif orientation == "horizontal": 106 | self.list_layout = QHBoxLayout(self.list_widget) 107 | self.list_layout.setAlignment(Qt.AlignmentFlag.AlignLeft) 108 | else: 109 | assert False, f"Unknown mode '{orientation}'!" 
110 | self.list_layout.setContentsMargins(0, 0, 0, 0) 111 | self.list_layout.setSpacing(10) 112 | 113 | self.scroll_area = QScrollArea() 114 | self.scroll_area.setWidgetResizable(True) 115 | self.scroll_area.setFrameStyle(0) 116 | self.scroll_area.setWidget(self.list_widget) 117 | self.layout.addWidget(self.scroll_area) 118 | 119 | if self.above_widget is not None: 120 | self.list_layout.addWidget(self.above_widget) 121 | 122 | if self.floating_widget is not None: 123 | self.list_layout.addWidget(self.floating_widget) 124 | 125 | self.item_widgets = [] 126 | self.num_visible_item_widgets = 0 127 | 128 | def last_item_widget(self): 129 | return self.item_widgets[self.num_visible_item_widgets - 1] 130 | 131 | def update_item_list(self, item_list, params=None): 132 | 133 | if self.floating_widget is not None: 134 | self.list_layout.removeWidget(self.floating_widget) 135 | 136 | # make sure that there are enough item widgets 137 | while len(item_list) > len(self.item_widgets): 138 | self.item_widgets.append(self.item_type(self.parent)) 139 | 140 | # make sure that the correct number of item widgets is shown 141 | while len(item_list) > self.num_visible_item_widgets: 142 | widget = self.item_widgets[self.num_visible_item_widgets] 143 | self.list_layout.addWidget(widget) 144 | widget.show() 145 | self.num_visible_item_widgets += 1 146 | 147 | while len(item_list) < self.num_visible_item_widgets: 148 | widget = self.item_widgets[self.num_visible_item_widgets - 1] 149 | widget.hide() 150 | self.list_layout.removeWidget(widget) 151 | self.num_visible_item_widgets -= 1 152 | 153 | if self.floating_widget is not None: 154 | self.list_layout.addWidget(self.floating_widget) 155 | 156 | # update item widgets 157 | for item, item_widget in zip(item_list, self.item_widgets[:len(item_list)]): 158 | item_widget.update_item(item, params) 159 | 160 | def enable_input(self): 161 | for item_widget in self.item_widgets: 162 | item_widget.enable_input() 163 | 164 | def disable_input(self): 165 | for item_widget in self.item_widgets: 166 | item_widget.disable_input() 167 | 168 | 169 | class CustomScrollableListItem(QFrame): 170 | 171 | def __init__(self, parent): 172 | super(CustomScrollableListItem, self).__init__() 173 | self.parent = parent 174 | 175 | @abc.abstractmethod 176 | def update_item(self, item, params=None): 177 | raise NotImplementedError 178 | 179 | @abc.abstractmethod 180 | def enable_input(self): 181 | raise NotImplementedError 182 | 183 | @abc.abstractmethod 184 | def disable_input(self): 185 | raise NotImplementedError 186 | 187 | 188 | def show_confirmation_dialog(parent, title_text, explanation_text, accept_text, reject_text): 189 | dialog = QDialog(parent) 190 | dialog.setWindowTitle(title_text) 191 | dialog_layout = QVBoxLayout(dialog) 192 | 193 | explanation = QLabel(explanation_text) 194 | explanation.setFont(LABEL_FONT) 195 | dialog_layout.addWidget(explanation) 196 | 197 | buttons_widget = QWidget(dialog) 198 | buttons_layout = QHBoxLayout(buttons_widget) 199 | buttons_layout.setContentsMargins(0, 10, 0, 0) 200 | dialog_layout.addWidget(buttons_widget) 201 | yes_button = QPushButton(accept_text) 202 | yes_button.setFont(BUTTON_FONT) 203 | yes_button.clicked.connect(dialog.accept) 204 | buttons_layout.addWidget(yes_button) 205 | no_button = QPushButton(reject_text) 206 | no_button.setFont(BUTTON_FONT) 207 | no_button.clicked.connect(dialog.reject) 208 | buttons_layout.addWidget(no_button) 209 | 210 | no_button.setFocus() 211 | 212 | return dialog.exec() 213 | 
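214 | 
215 | # Illustrative usage sketch (added commentary, not part of the original module):
216 | # QDialog.exec() returns a non-zero code when the dialog is accepted, so callers
217 | # can branch directly on the return value of show_confirmation_dialog. The
218 | # delete_table handler below is a hypothetical example, not project code.
219 | #
220 | #     if show_confirmation_dialog(self.main_window, "Delete table?",
221 | #                                 "This cannot be undone.", "Delete", "Cancel"):
222 | #         self.delete_table()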
--------------------------------------------------------------------------------
/wannadb_ui/resources/confidence_high.svg, confidence_low.svg, correct.svg, folder.svg, idea.svg, incorrect.svg, info.svg, leave.svg, locate.svg, magnifier.svg, pencil.svg, plus.svg, redo.svg, run.svg, run_run.svg:
--------------------------------------------------------------------------------
[icon sources omitted: the SVG markup of these resource files was stripped during extraction and only bare line numbers remained]
--------------------------------------------------------------------------------
/wannadb_ui/resources/save.svg, statistics.svg, statistics_folder.svg, statistics_incorrect.svg, statistics_save.svg, table.svg, text_cursor.svg, tools.svg, trash.svg, two_documents.svg:
--------------------------------------------------------------------------------
[icon sources omitted: likewise stripped during extraction]
--------------------------------------------------------------------------------
/wannadb_ui/start_menu.py:
--------------------------------------------------------------------------------
1 | import logging
2 | 
3 | from PyQt6.QtCore import Qt, QSize
4 | from PyQt6.QtGui import QIcon
5 | from PyQt6.QtWidgets import QHBoxLayout, QLabel, QPushButton, QVBoxLayout, QWidget
6 | 
7 | from wannadb_ui.common import LABEL_FONT, MainWindowContent, \
8 |     SUBHEADER_FONT
9 | 
10 | logger = logging.getLogger(__name__)
11 | 
12 | 
13 | class StartMenuWidget(MainWindowContent):
14 |     def __init__(self, 
main_window): 15 | super(StartMenuWidget, self).__init__(main_window, "Welcome to WannaDB!") 16 | 17 | self.setMaximumWidth(400) 18 | self.layout.setContentsMargins(0, 0, 0, 0) 19 | self.layout.setSpacing(30) 20 | 21 | self.create_new_document_base_widget = QWidget() 22 | self.create_new_document_base_layout = QVBoxLayout(self.create_new_document_base_widget) 23 | self.create_new_document_base_layout.setContentsMargins(0, 0, 0, 0) 24 | self.create_new_document_base_layout.setSpacing(10) 25 | self.layout.addWidget(self.create_new_document_base_widget) 26 | 27 | self.create_new_document_base_subheader = QLabel("Create a new document base.") 28 | self.create_new_document_base_subheader.setFont(SUBHEADER_FONT) 29 | self.create_new_document_base_layout.addWidget(self.create_new_document_base_subheader) 30 | 31 | self.create_new_document_base_wrapper_widget = QWidget() 32 | self.create_new_document_base_wrapper_layout = QHBoxLayout(self.create_new_document_base_wrapper_widget) 33 | self.create_new_document_base_wrapper_layout.setContentsMargins(0, 0, 0, 0) 34 | self.create_new_document_base_wrapper_layout.setSpacing(20) 35 | self.create_new_document_base_wrapper_layout.setAlignment(Qt.AlignmentFlag.AlignLeft) 36 | self.create_new_document_base_layout.addWidget(self.create_new_document_base_wrapper_widget) 37 | 38 | self.create_document_base_button = QPushButton() 39 | self.create_document_base_button.setFixedHeight(45) 40 | self.create_document_base_button.setFixedWidth(45) 41 | self.create_document_base_button.setIcon(QIcon("wannadb_ui/resources/two_documents.svg")) 42 | self.create_document_base_button.setIconSize(QSize(25, 25)) 43 | self.create_document_base_button.clicked.connect(self.main_window.show_document_base_creator_widget_task) 44 | self.create_new_document_base_wrapper_layout.addWidget(self.create_document_base_button) 45 | 46 | self.create_document_base_label = QLabel( 47 | "Create a new document base from a directory\nof .txt files and a list of attribute names.") 48 | self.create_document_base_label.setFont(LABEL_FONT) 49 | self.create_new_document_base_wrapper_layout.addWidget(self.create_document_base_label) 50 | 51 | self.load_document_base_widget = QWidget() 52 | self.load_document_base_layout = QVBoxLayout(self.load_document_base_widget) 53 | self.load_document_base_layout.setContentsMargins(0, 0, 0, 0) 54 | self.load_document_base_layout.setSpacing(10) 55 | self.layout.addWidget(self.load_document_base_widget) 56 | 57 | self.load_document_base_subheader = QLabel("Load an existing document base.") 58 | self.load_document_base_subheader.setFont(SUBHEADER_FONT) 59 | self.load_document_base_layout.addWidget(self.load_document_base_subheader) 60 | 61 | self.load_document_base_wrapper_widget = QWidget() 62 | self.load_document_base_wrapper_layout = QHBoxLayout(self.load_document_base_wrapper_widget) 63 | self.load_document_base_wrapper_layout.setContentsMargins(0, 0, 0, 0) 64 | self.load_document_base_wrapper_layout.setSpacing(20) 65 | self.load_document_base_wrapper_layout.setAlignment(Qt.AlignmentFlag.AlignLeft) 66 | self.load_document_base_layout.addWidget(self.load_document_base_wrapper_widget) 67 | 68 | self.load_document_base_button = QPushButton() 69 | self.load_document_base_button.setFixedHeight(45) 70 | self.load_document_base_button.setFixedWidth(45) 71 | self.load_document_base_button.setIcon(QIcon("wannadb_ui/resources/folder.svg")) 72 | self.load_document_base_button.setIconSize(QSize(25, 25)) 73 | 
self.load_document_base_button.clicked.connect(self.main_window.load_document_base_from_bson_task) 74 | self.load_document_base_wrapper_layout.addWidget(self.load_document_base_button) 75 | 76 | self.load_document_base_label = QLabel("Load an existing document base\nfrom a .bson file.") 77 | self.load_document_base_label.setFont(LABEL_FONT) 78 | self.load_document_base_wrapper_layout.addWidget(self.load_document_base_label) 79 | 80 | def enable_input(self): 81 | self.create_document_base_button.setEnabled(True) 82 | self.load_document_base_button.setEnabled(True) 83 | 84 | def disable_input(self): 85 | self.create_document_base_button.setDisabled(True) 86 | self.load_document_base_button.setDisabled(True) 87 | --------------------------------------------------------------------------------