├── .gitignore ├── LICENSE.md ├── README.md ├── environment.yml ├── inference.py ├── phases.md ├── quobert ├── __init__.py ├── dataprocessing │ ├── mturk │ │ ├── mturkify.py │ │ └── sample-mturk-context.py │ ├── postprocessing │ │ ├── process_res.py │ │ └── speakers_offset.py │ ├── preprocessing │ │ ├── bootstrap_EM.py │ │ ├── extract_entities.py │ │ ├── features.py │ │ ├── merge.py │ │ ├── sampling.py │ │ └── sampling_uncased.py │ ├── pygtrie.py │ └── run.sh ├── model │ ├── __init__.py │ ├── eval.py │ ├── fit.py │ └── model.py └── utils │ ├── __init__.py │ ├── _utils.py │ └── data │ ├── __init__.py │ ├── dataloader.py │ └── dataset.py ├── quootstrap ├── README.md ├── extract_quotations.sh ├── lib │ └── spinn3r-client-3.4.05-edit.jar ├── pom.xml ├── quootstrap.tar.gz ├── resources │ ├── config.properties │ └── seedPatterns.txt └── src │ └── main │ └── java │ └── ch │ └── epfl │ └── dlab │ ├── quootstrap │ ├── ConfigManager.java │ ├── ContextExtractor.java │ ├── DatasetLoader.java │ ├── Dawg.java │ ├── Exporter.java │ ├── ExporterArticle.java │ ├── ExporterContext.java │ ├── ExporterSpeakers.java │ ├── GroundTruthEvaluator.java │ ├── HashTrie.java │ ├── HashTriePatternMatcher.java │ ├── Hashed.java │ ├── LineageInfo.java │ ├── MultiCounter.java │ ├── NameDatabaseWikiData.java │ ├── ParquetDatasetLoader.java │ ├── Pattern.java │ ├── PatternExtractor.java │ ├── PatternMatcher.java │ ├── QuotationExtraction.java │ ├── Sentence.java │ ├── SimplePatternMatcher.java │ ├── SpeakerAlias.java │ ├── Spinn3rDatasetLoader.java │ ├── Spinn3rDocument.java │ ├── Spinn3rTextDatasetLoader.java │ ├── StaticRules.java │ ├── Token.java │ ├── Trie.java │ ├── TriePatternMatcher.java │ ├── TupleExtractor.java │ └── Utils.java │ └── spinn3r │ ├── EntryWrapper.java │ ├── Tokenizer.java │ ├── TokenizerImpl.java │ └── converter │ ├── AbstractDecoder.java │ ├── CombinedDecoder.java │ ├── Decompressor.java │ ├── EntryWrapperBuilder.java │ ├── ProtoToJson.java │ ├── README.md │ ├── SpinnerDecoder.java │ └── Stopwatch.java ├── setup.cfg ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | runs/ 132 | quobert/evaluation/ 133 | report/ 134 | *.ipynb 135 | 136 | quootstrap/target/ 137 | quootstrap/.* -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 DLAB 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: base 2 | channels: 3 | - defaults 4 | - pytorch 5 | - conda-forge 6 | dependencies: 7 | - python=3.7 8 | - cudatoolkit=10.1 9 | - pip 10 | - pip: 11 | - tensorboard 12 | - torch===1.4.0 13 | - pandas 14 | - pyarrow 15 | - transformers==2.6 -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import logging 4 | import os 5 | 6 | import torch 7 | 8 | from quobert.model import BertForQuotationAttribution, evaluate 9 | from quobert.utils.data import ParquetDataset 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | if __name__ == "__main__": 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument( 16 | "--model_dir", 17 | default=None, 18 | type=str, 19 | required=True, 20 | help="Path to pre-trained model", 21 | ) 22 | parser.add_argument( 23 | "--output_dir", 24 | default=None, 25 | type=str, 26 | required=True, 27 | help="The output directory where the model results will be written.", 28 | ) 29 | parser.add_argument( 30 | "--inference_dir", 31 | default=None, 32 | type=str, 33 | required=True, 34 | help="The input inference directory. Should contain (.gz.parquet) files", 35 | ) 36 | parser.add_argument( 37 | "--per_gpu_eval_batch_size", 38 | default=256, 39 | type=int, 40 | help="Batch size per GPU/CPU for Inference.", 41 | ) 42 | 43 | args = parser.parse_args() 44 | args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 45 | args.n_gpu = torch.cuda.device_count() 46 | 47 | logging.basicConfig( 48 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 49 | datefmt="%m/%d/%Y %H:%M:%S", 50 | level=logging.INFO, 51 | ) 52 | 53 | model = BertForQuotationAttribution.from_pretrained(args.model_dir) 54 | model.to(args.device) 55 | 56 | logger.info(f"Started loading the dataset from {args.inference_dir}") 57 | files = sorted(glob.glob(os.path.join(args.inference_dir, "**.gz.parquet"))) 58 | 59 | for i, f in enumerate(files): 60 | dataset = ParquetDataset(f) 61 | args.output_file = os.path.join(args.output_dir, f"results_{i:04d}.csv") 62 | evaluate(args, model, dataset, has_target=False) 63 | -------------------------------------------------------------------------------- /phases.md: -------------------------------------------------------------------------------- 1 | # Spinn3r character encoding: description of Phases A through E 2 | 3 | Quotebank was extracted from news articles that had been collected from the media aggregation service Spinn3r (now called [Datastreamer](https://www.datastreamer.io)). 4 | The Spinn3r data was collected over the course of over a decade. 5 | During this time, the client-side code used for collecting the data changed several times, and various character-encoding-related issues led to different representations of the original text at different times. Most issues relate to capital letters and non-ASCII characters. 6 | 7 | This document, although not entirely conclusive, is the result of an "archæological" endeavor to reconstruct the history of the various character encodings used at different times during Spinn3r data collection. 8 | Based on our insights, we divide the 12 years spanned by the Spinn3r corpus into five phases (Phases A through E), detailed below. 
9 | Non-ASCII characters are most relevant for non-English text; capitalization, however, matters for English as well. 10 | For most users, the key takeaways of this document are these: 11 | 12 | 1. Text was lowercased in Phases A, B, and C, whereas the original capitalization was maintained in Phases D and E. 13 | 2. Non-ASCII characters are properly represented only in Phase E. 14 | 15 | (This document is based on an initial write-up made on 6 June 2014.) 16 | 17 | 18 | ## Phase A (until 2010-07-13) 19 | 20 | Spinn3r's probably UTF-8-encoded data was read as Latin-1 (a.k.a. ISO-8859-1). UTF-8 has potentially several bytes per character, while Latin-1 always has one byte per character. That is, a single character from the original data now looks like two characters. For instance, Unicode code point U+00E4 ("Latin small letter a with diaeresis", a.k.a. "ä") is represented by the two-byte code C3A4 in UTF-8. Reading the bytes C3A4 as Latin-1 results in the two-character sequence "Ã¤", since C3 encodes "Ã" in Latin-1, and A4, "¤". 21 | Then, **lowercasing** was performed on the garbled text, making it even more garbled. For instance, "Ã¤" became "ã¤". 22 | Finally, the data was written to disk as UTF-8. 23 | 24 | **Approximate solution:** 25 | Take the debugging table from [http://www.i18nqa.com/debug/utf8-debug.html](https://web.archive.org/web/20210228174408/http://www.i18nqa.com/debug/utf8-debug.html), look for the garbled and lower-cased sequences and replace them by their original character. 26 | Note that the garbling is not bijective, but since most of the garbled sequences are highly unlikely (e.g., "ã¤"), this should be mostly fine. (A short sketch of this repair is given at the end of this document.) 27 | 28 | 29 | ## Phase B (2010-07-14 to 2010-07-26) 30 | 31 | For just about two weeks, the data seems to have been read as UTF-8 and written as Latin-1 (i.e., the other way round than in Phase A). 32 | Non-Latin-1 characters are printed as "?". However, there also seem to be a few cases as in Phase A. 33 | All text was **lowercased** in this phase. 34 | 35 | **Approximate solution:** 36 | Simply read the data as Latin-1. 37 | 38 | 39 | ## Phase C (2010-07-27 to 2013-04-28) 40 | 41 | The data was written to disk as ASCII, such that all non-ASCII characters (including Latin-1 characters) appear as "?". 42 | All text was **lowercased** in this phase. 43 | 44 | **Approximate solution:** 45 | None. We simply need to byte (haha...) the bullet and deal with the question marks. 46 | 47 | 48 | ## Phase D (2013-04-29 to 2014-05-21) 49 | 50 | Attempt 1 at fixing the above legacy issues: 51 | capitalization is kept as in the original text obtained from Spinn3r. 52 | However, due to a bad BASH environment variable, text was written to disk as ASCII, such that non-ASCII characters still appear as "?". 53 | 54 | 55 | ## Phase E (since 2014-05-22) 56 | 57 | Attempt 2 at fixing the above legacy issues: 58 | capitalization is kept as in the original text obtained from Spinn3r, and output is now finally written as proper UTF-8 Unicode.
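The Phase A and Phase B fixes described above can be illustrated in a few lines of Python. The snippet below is only an illustrative sketch, not part of the Quotebank pipeline: the function names (`repair_phase_a`, `repair_phase_b`) are hypothetical, and the mapping is rebuilt programmatically for the Latin-1 supplement (U+00A0 through U+00FF) only, rather than copied from the full i18nqa debugging table, so garbled three-byte UTF-8 characters (e.g., curly quotes) are not covered.

```python
import re

# Assumed approach: rebuild the "lowercased mojibake" -> "original character"
# table for the Latin-1 supplement. E.g., "ä" -> UTF-8 bytes C3 A4 -> read as
# Latin-1 "Ã¤" -> lowercased "ã¤", so "ã¤" maps back to "ä".
REPAIR_MAP = {}
for code_point in range(0xA0, 0x100):
    original = chr(code_point)
    mojibake = original.encode("utf-8").decode("latin-1").lower()
    # Phase A text is fully lowercased anyway, so map to the lowercase original.
    REPAIR_MAP[mojibake] = original.lower()

MOJIBAKE_RE = re.compile("|".join(re.escape(seq) for seq in REPAIR_MAP))


def repair_phase_a(text: str) -> str:
    """Replace garbled two-character sequences by their likely original.

    As noted above, this is only approximate: legitimate text that already
    contains sequences such as "ã¤" would be altered as well.
    """
    return MOJIBAKE_RE.sub(lambda m: REPAIR_MAP[m.group(0)], text)


def repair_phase_b(raw_bytes: bytes) -> str:
    """Phase B: simply read the data as Latin-1."""
    return raw_bytes.decode("latin-1")


# Example: "zã¼rich" (garbled and lowercased "Zürich") is repaired to "zürich".
assert repair_phase_a("zã¼rich and genã¨ve") == "zürich and genève"
```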
59 | -------------------------------------------------------------------------------- /quobert/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-dlab/Quotebank/a203acf5a05e34841671578d14d6a0be6a66a3ef/quobert/__init__.py -------------------------------------------------------------------------------- /quobert/dataprocessing/mturk/sample-mturk-context.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | from pyspark.sql import SparkSession 5 | import pyspark.sql.functions as F 6 | 7 | 8 | def extract( 9 | spark: SparkSession, 10 | *, 11 | mturk_path: str, 12 | output_path: str, 13 | nb_partition: int, 14 | compression: str = "gzip", 15 | ): 16 | df = spark.read.parquet(mturk_path) 17 | df_key = df.withColumn( 18 | "key", F.substring(df.articleUID, 1, 6).cast("int") 19 | ) 20 | key_proportions = df_key.groupBy("key").count().orderBy("key").collect() 21 | proportions = {r.key: 100 / r["count"] for r in key_proportions} 22 | selected_keys = df_key.sampleBy("key", proportions, seed=1).withColumn("source", F.lit("normal")).cache() 23 | print("number of normal quotation", selected_keys.count()) 24 | 25 | speaker_proportions = df_key.groupBy("nb_entities").count().orderBy("nb_entities").collect() 26 | proportions = {r.nb_entities: min(1., 100 / r["count"]) if r.nb_entities <= 20 else 0 for r in speaker_proportions} 27 | selected_speakers = df_key.sampleBy("nb_entities", proportions, seed=2).withColumn("source", F.lit("nb_entities")).cache() 28 | print("number of quotation by more speakers", selected_speakers.count()) 29 | 30 | df_filtered = df_key.filter(~df_key.in_quootstrap) 31 | nb_hard = df_filtered.count() 32 | selected_hard = df_filtered.sample(fraction=5000 / nb_hard, seed=3).withColumn("source", F.lit("not_quootstrap")).cache() 33 | print("number of quotation not in quootstrap", selected_hard.count()) 34 | to_evaluate = selected_keys.union(selected_speakers).union(selected_hard).dropDuplicates(['articleUID', 'articleOffset']).dropDuplicates(['context']).cache() 35 | print("total after dropDuplicates", to_evaluate.count()) 36 | to_evaluate.coalesce(nb_partition).write.parquet( 37 | os.path.join(output_path, "mturk"), 38 | mode="overwrite", 39 | compression=compression, 40 | ) 41 | 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument( 46 | "-m", 47 | "--mturk", 48 | type=str, 49 | help="Path to folder with all raw context for mturk (.gz.parquet)", 50 | required=True, 51 | ) 52 | parser.add_argument( 53 | "-o", "--output", type=str, help="Path to output folder", required=True 54 | ) 55 | parser.add_argument( 56 | "-n", 57 | "--nb_partition", 58 | type=int, 59 | help="Number of partition for the output (useful if used with unsplittable compression algorithm). Default=10", 60 | default=10, 61 | ) 62 | parser.add_argument( 63 | "--compression", 64 | type=str, 65 | help="Compression algorithm. Can be any compatible alogrithm with Spark Parquet. Default=gzip", 66 | default="gzip", 67 | ) 68 | parser.add_argument( 69 | "-l", 70 | "--local", 71 | help="Add if you want to execute locally. 
The code is expected to be run on a cluster if you run on big files", 72 | action="store_true", 73 | ) 74 | args = parser.parse_args() 75 | 76 | print("Starting the Spark Session") 77 | if args.local: 78 | import findspark 79 | 80 | findspark.init() 81 | 82 | spark = ( 83 | SparkSession.builder.master("local[24]") 84 | .appName("SampleMTurkLocal") 85 | .config("spark.driver.memory", "16g") 86 | .config("spark.executor.memory", "32g") 87 | .getOrCreate() 88 | ) 89 | else: 90 | spark = SparkSession.builder.appName("SampleMTurk").getOrCreate() 91 | 92 | print("Starting the merging process") 93 | extract( 94 | spark, 95 | mturk_path=args.mturk, 96 | output_path=args.output, 97 | nb_partition=args.nb_partition, 98 | compression=args.compression, 99 | ) 100 | -------------------------------------------------------------------------------- /quobert/dataprocessing/postprocessing/process_res.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import re 3 | import string 4 | from collections import Counter 5 | from os.path import join 6 | 7 | import pyspark.sql.functions as F 8 | from pyspark.sql import SparkSession, Window 9 | from pyspark.sql.types import ArrayType, DoubleType, IntegerType, StringType 10 | 11 | 12 | @F.udf 13 | def lower(text): 14 | no_punct = "".join(t for t in text if t not in string.punctuation) 15 | one_space = re.sub(r"\W+", " ", no_punct) 16 | return one_space.lower() 17 | 18 | 19 | @F.udf 20 | def longest(quotes): 21 | return sorted(quotes, key=len, reverse=True)[0] 22 | 23 | 24 | @F.udf 25 | def pad_int(idx): 26 | return f"{idx:06d}" 27 | 28 | 29 | def first(col): 30 | return F.first(f"quotes.{col}").alias(col) 31 | 32 | 33 | def get_col(col): 34 | return F.col(f"quotes.{col}").alias(col) 35 | 36 | 37 | @F.udf(returnType=DoubleType()) 38 | def normalize(x): 39 | return x / 100.0 if x is not None else None 40 | 41 | 42 | @F.udf 43 | def get_top1(probas): 44 | return probas[0][0] 45 | 46 | 47 | @F.udf(returnType=IntegerType()) 48 | def get_article_len(article): 49 | return len(article.split()) 50 | 51 | 52 | @F.udf 53 | def get_most_common_name(names): 54 | return sorted(Counter(names).items(), key=lambda x: (-x[1], x[0]))[0][0] 55 | 56 | 57 | @F.udf(returnType=ArrayType(StringType())) 58 | def get_website(array): 59 | return list(dict.fromkeys(x["website"] for x in array)) 60 | 61 | 62 | @F.udf(returnType=ArrayType(StringType())) 63 | def get_first_dedup(ids): 64 | return list({x[0] for x in ids}) 65 | 66 | 67 | NB_PARTITIONS = 2 ** 11 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument( 70 | "-a", 71 | "--articles", 72 | type=str, 73 | help="Path to folder with all articles (.json.gz)", 74 | required=True, 75 | ) 76 | parser.add_argument( 77 | "-c", 78 | "--context", 79 | type=str, 80 | help="Path to folder with all quotes with context (.json.gz)", 81 | required=True, 82 | ) 83 | parser.add_argument( 84 | "-s", 85 | "--speakers", 86 | type=str, 87 | help="Path to the augmented speakers folder (.parquet.gz)", 88 | required=True, 89 | ) 90 | parser.add_argument( 91 | "-r", 92 | "--results", 93 | type=str, 94 | help="Path to the results from the inference step (.csv)", 95 | required=True, 96 | ) 97 | parser.add_argument( 98 | "-o", "--output", type=str, help="Path to output folder", required=True 99 | ) 100 | parser.add_argument( 101 | "--partition", 102 | nargs="+", 103 | default=["year", "month"], 104 | help="Column to partition the out with. A choice of year and month. Order matters. 
Default 'year month'", 105 | ) 106 | 107 | args = parser.parse_args() 108 | 109 | spark = SparkSession.builder.appName("processRes").getOrCreate() 110 | qc = spark.read.json(args.context) 111 | articles = spark.read.json(args.articles) 112 | speakers = spark.read.parquet(args.speakers) 113 | res = spark.read.csv( 114 | args.results, 115 | header=True, 116 | schema="articleUID STRING, articleOffset LONG, rank INT, speaker STRING, proba DOUBLE", 117 | ) 118 | 119 | quote_article_link = ( 120 | qc.join(articles.select("articleUID", "date", "website", "phase"), on="articleUID") 121 | .groupBy(lower(F.col("quotation")).alias("canonicalQuotation")) 122 | .agg( 123 | F.collect_set("quotation").alias("quotations"), 124 | F.min("date").alias("earliest_date"), 125 | F.min("phase").alias("phase"), 126 | F.count("*").alias("numOccurrences"), 127 | F.collect_list( 128 | F.struct( 129 | F.col("articleUID"), 130 | F.col("articleOffset"), 131 | F.col("quotation"), 132 | F.col("leftContext"), 133 | F.col("rightContext"), 134 | F.col("quotationOffset"), 135 | F.col("leftOffset").alias("contextStart"), 136 | F.col("rightOffset").alias("contextEnd"), 137 | ) 138 | ).alias("quotes_link"), 139 | get_website(F.sort_array(F.collect_list(F.struct("date", "website")))).alias( 140 | "urls" 141 | ), 142 | ) 143 | .withColumn("quotation", longest(F.col("quotations"))) 144 | .withColumn( 145 | "row_nb", 146 | F.row_number().over( 147 | Window.partitionBy(F.to_date("earliest_date")).orderBy("canonicalQuotation") 148 | ), 149 | ) 150 | .withColumn( 151 | "quoteID", F.concat_ws("-", F.to_date("earliest_date"), pad_int("row_nb")), 152 | ) 153 | .withColumn("month", F.month("earliest_date")) 154 | .withColumn("year", F.year("earliest_date")) 155 | .drop("quotations", "row_nb") 156 | ) 157 | 158 | joined_df = qc.join(res, on=["articleUID", "articleOffset"]) 159 | 160 | w = Window.partitionBy("canonicalQuotation") 161 | rank_w = Window.partitionBy("canonicalQuotation").orderBy(F.desc("sum(proba)")) 162 | agg_proba = ( 163 | joined_df.groupBy(lower(F.col("quotation")).alias("canonicalQuotation"), "qids") 164 | .agg(F.sum("proba"), F.collect_list("speaker").alias("speakers")) 165 | .select( 166 | "*", 167 | F.sum("sum(proba)").over(w).alias("weight"), 168 | F.row_number().over(rank_w).alias("rank"), 169 | get_most_common_name("speakers").alias("speaker"), 170 | ) 171 | .withColumn("proba", F.round(F.col("sum(proba)") / F.col("weight"), 4)) 172 | .filter("proba >= 1e-4") 173 | .drop("sum(porba)", "weight") 174 | ) 175 | 176 | agg_proba.write.parquet( 177 | join(args.output, "quotebank-cache-proba"), 178 | mode="overwrite", 179 | compression="gzip", 180 | ) 181 | agg_proba = spark.read.parquet(join(args.output, "quotebank-cache-proba")) 182 | 183 | top_speaker = agg_proba.filter("rank = 1").select( 184 | "canonicalQuotation", 185 | F.col("speaker").alias("top_speaker"), 186 | F.col("qids").alias("top_speaker_qid"), 187 | F.col("speakers").alias("top_surface_forms"), 188 | ) 189 | 190 | probas = ( 191 | agg_proba.orderBy("canonicalQuotation", "rank") 192 | .groupBy("canonicalQuotation") 193 | .agg( 194 | F.collect_list(F.struct(F.col("speaker"), F.col("proba"), F.col("qids"))).alias( 195 | "probas" 196 | ) 197 | ) 198 | ) 199 | 200 | final = quote_article_link.join(top_speaker, on="canonicalQuotation").join( 201 | probas, on="canonicalQuotation" 202 | ) 203 | final.write.parquet( 204 | join(args.output, "quotebank-cache1"), mode="overwrite", compression="gzip" 205 | ) 206 | final = spark.read.parquet(join(args.output, 
"quotebank-cache1")) 207 | 208 | 209 | SMALL_COLS = [ 210 | "quoteID", 211 | "quotation", 212 | F.col("top_speaker").alias("speaker"), 213 | # F.col("top_speaker_qid").alias("qids"), 214 | F.col("earliest_date").alias("date"), 215 | "numOccurrences", 216 | "probas", 217 | "year", 218 | "month", 219 | "urls", 220 | "phase", 221 | ] 222 | 223 | final.select(*SMALL_COLS).repartition(*args.partition).write.partitionBy( 224 | *args.partition 225 | ).json( 226 | join(args.output, "quotes-df"), mode="overwrite", compression="bzip2", 227 | ) 228 | 229 | BIG_COLS = [ 230 | "quoteID", 231 | "numOccurrences", 232 | F.col("top_speaker").alias("globalTopSpeaker"), 233 | F.col("probas").alias("globalProbas"), 234 | F.explode("quotes_link").alias("quotes"), 235 | ] 236 | 237 | individual_probas = ( 238 | res.filter("proba > 0") 239 | .orderBy("articleUID", "articleOffset", "rank") 240 | .groupBy("articleUID", "articleOffset") 241 | .agg( 242 | F.collect_list( 243 | F.struct("speaker", F.round(normalize("proba"), 4).alias("proba"), "qids") 244 | ).alias("localProbas") 245 | ) 246 | .withColumn("speaker", get_top1("localProbas")) 247 | ) 248 | 249 | df = final.select(*BIG_COLS) 250 | df = df.join( 251 | individual_probas, 252 | on=[ 253 | df.quotes.articleUID == individual_probas.articleUID, 254 | df.quotes.articleOffset == individual_probas.articleOffset, 255 | ], 256 | ).drop("articleUID", "articleOffset") 257 | 258 | article_df = df.groupBy(F.col("quotes.articleUID").alias("articleUID")).agg( 259 | F.collect_list( 260 | F.struct( 261 | "quoteID", 262 | "numOccurrences", 263 | get_col("quotation"), 264 | get_col("quotationOffset"), 265 | get_col("contextStart"), 266 | get_col("contextEnd"), 267 | "globalTopSpeaker", 268 | "globalProbas", 269 | F.col("speaker").alias("localTopSpeaker"), 270 | "localProbas", 271 | ) 272 | ).alias("quotations"), 273 | ) 274 | 275 | article_df = ( 276 | article_df.join(articles, on="articleUID") 277 | .join(speakers, on="articleUID") 278 | .withColumn("articleLength", get_article_len("content")) 279 | .withColumnRenamed("articleUID", "articleID") 280 | .withColumnRenamed("website", "url") 281 | .withColumn("year", F.year("date")) 282 | .withColumn("month", F.month("date")) 283 | ) 284 | 285 | article_df.repartition(*args.partition).write.partitionBy(*args.partition).parquet( 286 | join(args.output, "article-df"), mode="overwrite", compression="gzip", 287 | ) 288 | -------------------------------------------------------------------------------- /quobert/dataprocessing/postprocessing/speakers_offset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import string 3 | 4 | from pyspark.sql import Row, SparkSession 5 | 6 | from pygtrie import StringTrie # type: ignore # Sideloaded in the spark-submit 7 | 8 | PUNCT = "".join(x for x in string.punctuation if x not in "[]") 9 | BLACK_LIST = {"president", "manager"} 10 | 11 | 12 | def create_trie(names): 13 | trie = StringTrie(delimiter="/") 14 | for name in names: 15 | processed_name = [ 16 | x for x in name["name"].split() if x.lower() not in BLACK_LIST 17 | ] 18 | for i in range(len(processed_name)): 19 | trie["/".join(processed_name[i:]).lower()] = (name["name"], name["ids"]) 20 | return trie 21 | 22 | 23 | def update_entity(name, qinfo, start, end, out): 24 | if name in out: 25 | out[name]["offsets"].append([start, end]) 26 | else: 27 | out[name] = { 28 | "name": name, 29 | "ids": sorted({x[0] for x in qinfo}), 30 | "offsets": [[start, end]], 31 | } 32 | 33 | 34 | def 
reduce_entities(entities): 35 | out = dict() 36 | for i, (value, start, end) in entities.items(): 37 | if isinstance(value, list): 38 | for name, qinfo in value: 39 | update_entity(name, qinfo, start, end, out) 40 | else: 41 | try: 42 | update_entity(value[0], value[1], start, end, out) 43 | except IndexError: # this happens when value = '/', which is rare 44 | print(value, "We had an error here") 45 | return list(out.values()) 46 | 47 | 48 | def get_partial_match(trie, key): 49 | match = list(trie[key:]) 50 | return match if len(match) > 1 else match[0] 51 | 52 | 53 | def get_entity(trie, key): 54 | entites = get_partial_match(trie, key) 55 | if not type(entites) == tuple: 56 | raise Exception(f"{entites} is not a tuple") 57 | return entites 58 | 59 | 60 | def find_entites(text: str, trie: StringTrie): 61 | tokens = text.split() 62 | start = 0 63 | count = 1 # start at 1, 0 is for the "NO_MATCH" 64 | entities = dict() 65 | for i in range(len(tokens)): 66 | key = "/".join(tokens[start : i + 1]).lower() 67 | if trie.has_subtrie(key): # Not done yet 68 | if i == len(tokens) - 1: # Reached the end of the string 69 | entities[count] = (get_entity(trie, key), start, i + 1) 70 | elif trie.has_key(key): # noqa: W601 # Find a perfect match 71 | entities[count] = (trie[key], start, i + 1) 72 | count += 1 73 | start = i + 1 74 | elif start < i: # Found partial prefix match before this token 75 | old_key = "/".join(tokens[start:i]).lower() 76 | entities[count] = (get_entity(trie, old_key), start, i) 77 | count += 1 78 | if trie.has_node( 79 | tokens[i].lower() 80 | ): # Need to verify that the current token isn't in the Trie 81 | start = i 82 | else: 83 | start = i + 1 84 | else: # No match 85 | start = i + 1 86 | return reduce_entities(entities) 87 | 88 | 89 | def transform(x: Row): 90 | trie = create_trie(x.names) 91 | try: 92 | entities = find_entites(x.content, trie) 93 | except Exception: 94 | return None 95 | 96 | return Row(articleUID=x.articleUID, names=entities,) 97 | 98 | 99 | def speakers_offset( 100 | spark: SparkSession, 101 | *, 102 | articles_path: str, 103 | speakers_path: str, 104 | output_path: str, 105 | nb_partition: int, 106 | compression: str = "gzip", 107 | ): 108 | df = ( 109 | spark.read.json(articles_path) 110 | .select("articleUID", "content") 111 | .repartition(nb_partition) 112 | ) 113 | speakers = spark.read.json(speakers_path) 114 | joined = df.join(speakers, on="articleUID") 115 | 116 | transformed = joined.rdd.map(transform).filter(lambda x: x is not None).toDF() 117 | transformed.write.parquet(output_path, "overwrite", compression=compression) 118 | 119 | 120 | if __name__ == "__main__": 121 | parser = argparse.ArgumentParser() 122 | parser.add_argument( 123 | "-a", 124 | "--articles", 125 | type=str, 126 | help="Path to the articles (.json)", 127 | required=True, 128 | ) 129 | parser.add_argument( 130 | "-s", 131 | "--speakers", 132 | type=str, 133 | help="Path to the speakers folder (.json)", 134 | required=True, 135 | ) 136 | parser.add_argument( 137 | "-o", 138 | "--output", 139 | type=str, 140 | help="Path to output folder for the transformed speaker offsets", 141 | required=True, 142 | ) 143 | parser.add_argument( 144 | "-l", 145 | "--local", 146 | help="Add if you want to execute locally. 
The code is expected to be run on a cluster if you run on big files", 147 | action="store_true", 148 | ) 149 | parser.add_argument( 150 | "-n", 151 | "--nb_partition", 152 | type=int, 153 | help="Number of partition for the output (useful if used with unsplittable compression algorithm). Default=50", 154 | default=200, 155 | ) 156 | parser.add_argument( 157 | "--compression", 158 | type=str, 159 | help="Compression algorithm. Can be any compatible alogrithm with Spark Parquet. Default=gzip", 160 | default="gzip", 161 | ) 162 | args = parser.parse_args() 163 | 164 | if args.local: 165 | # import findspark 166 | 167 | # findspark.init() 168 | 169 | spark = ( 170 | SparkSession.builder.master("local[24]") 171 | .appName("SpeakerOffsetsLocal") 172 | .config("spark.driver.memory", "16g") 173 | .config("spark.executor.memory", "32g") 174 | .getOrCreate() 175 | ) 176 | else: 177 | spark = SparkSession.builder.appName("SpeakerOffsets").getOrCreate() 178 | 179 | speakers_offset( 180 | spark, 181 | articles_path=args.articles, 182 | speakers_path=args.speakers, 183 | output_path=args.output, 184 | nb_partition=args.nb_partition, 185 | compression=args.compression, 186 | ) 187 | -------------------------------------------------------------------------------- /quobert/dataprocessing/preprocessing/bootstrap_EM.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from string import punctuation 3 | 4 | import pyspark.sql.functions as F 5 | from pyspark.sql import SparkSession 6 | 7 | 8 | @F.udf() 9 | def process_quote(quote): 10 | return "".join(x for x in quote.lower() if x not in punctuation) 11 | 12 | 13 | @F.udf() 14 | def most_frequent(x): 15 | return max(set(x), key=x.count) 16 | 17 | 18 | def find_exact_match( 19 | spark: SparkSession, 20 | quotes_context_path: str, 21 | quootstrap_path: str, 22 | output_path: str, 23 | nb_partition: int = 200, 24 | compression: str = "gzip", 25 | ): 26 | quootstrap_df = spark.read.json(quootstrap_path) 27 | quotes_context_df = spark.read.json(quotes_context_path) 28 | 29 | q2 = quootstrap_df.select(F.explode("occurrences").alias("occurrence")) 30 | fields_to_keep = [ 31 | q2.occurrence.articleUID.alias("articleUID"), 32 | q2.occurrence.articleOffset.alias("articleOffset"), 33 | ] 34 | 35 | attributed_quotes_df = q2.select(*fields_to_keep) 36 | 37 | new_quotes_context_df = quotes_context_df.join( 38 | attributed_quotes_df, on=["articleUID", "articleOffset"], how="left_anti", 39 | ) 40 | 41 | quootstrap_df.select( 42 | most_frequent("speaker").alias("speaker"), 43 | process_quote("quotation").alias("uncased_quote"), 44 | ).join( 45 | new_quotes_context_df.withColumn("uncased_quote", process_quote("quotation")), 46 | on="uncased_quote", 47 | ).drop( 48 | "uncased_quote" 49 | ).repartition( 50 | nb_partition 51 | ).write.parquet( 52 | output_path, "overwrite", compression=compression 53 | ) 54 | 55 | 56 | if __name__ == "__main__": 57 | parser = argparse.ArgumentParser() 58 | parser.add_argument( 59 | "-q", 60 | "--quootstrap", 61 | type=str, 62 | help="Path to Quoostrap output (.json)", 63 | required=True, 64 | ) 65 | parser.add_argument( 66 | "-c", 67 | "--context", 68 | type=str, 69 | help="Path to folder with all quotes with context (.json.gz)", 70 | required=True, 71 | ) 72 | parser.add_argument( 73 | "-o", "--output", type=str, help="Path to output folder", required=True 74 | ) 75 | parser.add_argument( 76 | "-l", 77 | "--local", 78 | help="Add if you want to execute locally.", 79 | action="store_true", 80 
| ) 81 | parser.add_argument( 82 | "-n", 83 | "--nb_partition", 84 | type=int, 85 | help="Number of partition for the output (useful if used with unsplittable compression algorithm). Default=200", 86 | default=200, 87 | ) 88 | parser.add_argument( 89 | "--compression", 90 | type=str, 91 | help="Compression algorithm. Can be any compatible alogrithm with Spark Parquet. Default=gzip", 92 | default="gzip", 93 | ) 94 | args = parser.parse_args() 95 | 96 | print("Starting the Spark Session") 97 | if args.local: 98 | # import findspark 99 | 100 | # findspark.init() 101 | 102 | spark = ( 103 | SparkSession.builder.master("local[24]") 104 | .appName("BootstrapLocal") 105 | .config("spark.driver.memory", "32g") 106 | .config("spark.executor.memory", "32g") 107 | .getOrCreate() 108 | ) 109 | else: 110 | spark = SparkSession.builder.appName("Bootstrap_EM").getOrCreate() 111 | 112 | find_exact_match( 113 | spark, 114 | quootstrap_path=args.quootstrap, 115 | quotes_context_path=args.context, 116 | output_path=args.output, 117 | nb_partition=args.nb_partition, 118 | compression=args.compression, 119 | ) 120 | -------------------------------------------------------------------------------- /quobert/dataprocessing/preprocessing/features.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import re 4 | import signal 5 | import string 6 | from typing import List, Optional, Union 7 | 8 | from pyspark.broadcast import Broadcast 9 | from pyspark.sql import Row, SparkSession 10 | from transformers import AutoTokenizer 11 | 12 | QUOTE_TOKEN = "[QUOTE]" 13 | QUOTE_TARGET_TOKEN = "[TARGET_QUOTE]" 14 | MASK_IDX = 103 # Index in BERT wordpiece 15 | 16 | PUNCT = re.escape("".join(x for x in string.punctuation if x not in "[]")) 17 | 18 | 19 | class TimeoutError(Exception): 20 | pass 21 | 22 | 23 | def handler(signum, frame): 24 | raise TimeoutError() 25 | 26 | 27 | def timeout(func, args=(), kwargs={}, timeout_duration=5, default=None): 28 | # set the timeout handler 29 | signal.signal(signal.SIGALRM, handler) 30 | signal.alarm(timeout_duration) 31 | try: 32 | result = func(*args, **kwargs) 33 | except TimeoutError: 34 | print(f"Timeout Error running {func} with {args} and {kwargs}") 35 | result = default 36 | finally: 37 | signal.alarm(0) 38 | 39 | return result 40 | 41 | 42 | def clean_text(masked_text): 43 | masked_text = re.sub( 44 | r"""(?i)\b((?:https?://|www\d{0,3}[.]|[a-z0-9.\-]+[.][a-z]{2,4}/)(?:[^\s()<>]+|\(([^\s()<>]+|(\([^\s()<>]+\)))*\))+(?:\(([^\s()<>]+|(\([^\s()<>]+\)))*\)|[^\s`!()\[\]{};:'".,<>?«»“”‘’]))""", 45 | " ", 46 | masked_text, 47 | ) 48 | masked_text = re.sub("[" + PUNCT + "]{4,}", " ", masked_text) 49 | return masked_text 50 | 51 | 52 | def example_to_features( 53 | articleUID: str, 54 | articleOffset: int, 55 | masked_text: str, 56 | speaker: str, 57 | targets: List[int], 58 | entities, 59 | tokenizer: Union[AutoTokenizer, Broadcast], 60 | max_seq_len: int = 320, 61 | pad_to_max_length: bool = True, 62 | ) -> Optional[Row]: 63 | """Transform examples to QuoteFeatures row. Given the context and the speaker, 64 | extract the start/end offset that match the best the speaker. 65 | Those offsets will be used as targets for the models. 
66 | 67 | Args: 68 | articleUID (str): The unique identifier of the associated article 69 | articleOffset (int): The offset (in number of quotes) in the associated article 70 | left_context (str): The text left of the quote 71 | quotation (str): The quote 72 | right_context (str): The text right of the quote 73 | speaker (str): The identified speaker 74 | tokenizer (BertTokenizer): The tokenizer to use to compute the features 75 | max_len_quotation (int, optional): Maximum length for the quotation in tokens. Defaults to 100. 76 | max_seq_len (int, optional): Maximum sequence length in tokens, extra tokens will be dropped. Defaults to 320. 77 | pad_to_max_length (bool, optional): Wheter to pad the tokens to `max_seq_len`. Pad on the right. Defaults to True. 78 | 79 | Returns: 80 | Optional[Row]: The features 81 | """ 82 | tokenizer = tokenizer.value 83 | 84 | masked_text = timeout(clean_text, args=(masked_text,), default="") 85 | tokenized = timeout(tokenizer.tokenize, args=(masked_text,)) 86 | if not tokenized or len(tokenized) > max_seq_len: 87 | # print(len(tokenized), masked_text) 88 | return None 89 | encoded = timeout(tokenizer.encode, args=(tokenized,)) 90 | 91 | mask_idx = [0] + [ 92 | i for i, idx in enumerate(encoded) if idx == MASK_IDX 93 | ] # indexes of [CLS] and [MASK] token 94 | 95 | # if len(mask_idx) < 2: # This should *NOT* happen 96 | # # print("No mask token in", masked_text) 97 | # return None 98 | 99 | return Row( 100 | uid=articleUID + " " + str(articleOffset), 101 | input_ids=encoded, 102 | mask_idx=mask_idx, 103 | target=targets[0], 104 | entities=json.dumps(entities), 105 | speaker=speaker, 106 | ) 107 | 108 | 109 | def transform_to_features( 110 | spark: SparkSession, 111 | *, 112 | transformed_path: str, 113 | tokenizer_model: str, 114 | output_path: str, 115 | nb_partition: int, 116 | compression: str, 117 | kind: str, 118 | ): 119 | """Entire transformation pipeline. Entry point to the process. 120 | Create a tokenizer from the model. Read the merged data and do the transformations. 121 | Finally write the resulting Dataset/Dataframe to the disk 122 | 123 | Args: 124 | merged_path (str): Path to the folder containing the merged data 125 | tokenizer_model (str): Model of the tokenizer. Must be supported by `transformers` 126 | output_path (str): Path to the output folder 127 | nb_partition (int): Number of partition for the output 128 | compression (str, optional): A parquet compatible compression algorithm. Defaults to 'gzip'. 
129 | """ 130 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_model) 131 | tokenizer.add_tokens([QUOTE_TOKEN, QUOTE_TARGET_TOKEN]) 132 | tokenizer_bc = spark.sparkContext.broadcast(tokenizer) 133 | 134 | def __example_to_features(row: Row) -> Optional[Row]: 135 | return example_to_features( 136 | row.articleUID, 137 | row.articleOffset, 138 | row.masked_text, 139 | row.speaker if kind == "train" else "", 140 | row.targets if kind == "train" else [-1], 141 | row.entities, 142 | tokenizer_bc, 143 | pad_to_max_length=False, 144 | ) 145 | 146 | transformed_df = spark.read.parquet(transformed_path) 147 | 148 | output_df = ( 149 | transformed_df.rdd.repartition(2 ** 7) 150 | .map(__example_to_features) 151 | .filter(lambda x: x is not None) 152 | .toDF() 153 | ) 154 | 155 | if kind == "test": 156 | output_df.drop("speaker", "target") # .coalesce(nb_partition) 157 | output_df.write.parquet(output_path, "overwrite", compression=compression) 158 | 159 | 160 | if __name__ == "__main__": 161 | parser = argparse.ArgumentParser() 162 | parser.add_argument( 163 | "-t", 164 | "--transformed", 165 | type=str, 166 | help="Path to the transformed (sampled) output folder (.parquet)", 167 | required=True, 168 | ) 169 | parser.add_argument( 170 | "--tokenizer", 171 | type=str, 172 | help="Name of the pretrained model, default: bert-base-cased", 173 | default="bert-base-cased", 174 | ) 175 | parser.add_argument( 176 | "-o", "--output", type=str, help="Path to output folder", required=True 177 | ) 178 | parser.add_argument( 179 | "-l", 180 | "--local", 181 | help="Add if you want to execute locally. The code is expected to be run on a cluster if you run on big files", 182 | action="store_true", 183 | ) 184 | parser.add_argument( 185 | "-n", 186 | "--nb_partition", 187 | type=int, 188 | help="Number of partition for the output (useful if used with unsplittable compression algorithm). Default=50", 189 | default=50, 190 | ) 191 | parser.add_argument( 192 | "--compression", 193 | type=str, 194 | help="Compression algorithm. Can be any compatible alogrithm with Spark Parquet. 
Default=gzip", 195 | default="gzip", 196 | ) 197 | parser.add_argument( 198 | "--kind", 199 | type=str, 200 | help="Which kind of data it is to transform (train = with labels, test = without labels)", 201 | required=True, 202 | choices=["train", "test"], 203 | ) 204 | args = parser.parse_args() 205 | 206 | print("Starting the Spark Session") 207 | if args.local: 208 | # import findspark 209 | 210 | # findspark.init() 211 | 212 | spark = ( 213 | SparkSession.builder.master("local[24]") 214 | .appName("FeaturesExtractorLocal") 215 | .config("spark.driver.memory", "16g") 216 | .config("spark.executor.memory", "32g") 217 | .getOrCreate() 218 | ) 219 | else: 220 | spark = SparkSession.builder.appName("FeaturesExtractor").getOrCreate() 221 | 222 | print("Starting the transformation to features") 223 | transform_to_features( 224 | spark, 225 | transformed_path=args.transformed, 226 | tokenizer_model=args.tokenizer, 227 | output_path=args.output, 228 | nb_partition=args.nb_partition, 229 | compression=args.compression, 230 | kind=args.kind, 231 | ) 232 | -------------------------------------------------------------------------------- /quobert/dataprocessing/preprocessing/merge.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from urllib.parse import urlparse 3 | 4 | from pyspark.sql import SparkSession 5 | from pyspark.sql.functions import explode, to_date, udf 6 | 7 | 8 | @udf 9 | def get_domain(x): 10 | return urlparse(x).netloc 11 | 12 | 13 | def merge( 14 | spark: SparkSession, 15 | *, 16 | quootstrap_path: str, 17 | quotes_context_path: str, 18 | output_path: str, 19 | nb_partition: int, 20 | compression: str = "gzip", 21 | ): 22 | """ Merge The output from Quootstrap together in order to create an input training set 23 | 24 | Args: 25 | spark (SparkSession): Current spark session 26 | quootstrap_path (str): HDFS path to the Quootstrap output (Q, S) pairs 27 | quotes_context_path (str): HDFS path to the quotes+context 28 | output_path (str): HDFS path to store the output of the merge 29 | nb_partition (int): Number of partition for the output 30 | compression (str, optional): A parquet compatible compression algorithm. Defaults to 'gzip'. 
31 | """ 32 | quootstrap_df = spark.read.json(quootstrap_path) 33 | quotes_context_df = spark.read.json(quotes_context_path) 34 | 35 | # Extract all quotes and speakers from Quootstrap data 36 | q2 = quootstrap_df.select("quotation", explode("occurrences").alias("occurrence")) 37 | 38 | fields_to_keep = [ 39 | q2.occurrence.articleUID.alias("articleUID"), 40 | q2.occurrence.articleOffset.alias("articleOffset"), 41 | q2.occurrence.matchedSpeakerTokens.alias("speaker"), 42 | q2.occurrence.extractedBy.alias("pattern"), 43 | get_domain(q2.occurrence.website).alias("domain"), 44 | to_date(q2.occurrence.date).alias("date"), 45 | ] 46 | 47 | attributed_quotes_df = q2.select(*fields_to_keep) 48 | 49 | # Merge df and write parquet files 50 | attributed_quotes_context_df = attributed_quotes_df.join( 51 | quotes_context_df, on=["articleUID", "articleOffset"], how="inner" 52 | ) 53 | 54 | attributed_quotes_context_df.repartition(nb_partition).write.parquet( 55 | output_path, "overwrite", compression=compression 56 | ) 57 | 58 | 59 | if __name__ == "__main__": 60 | parser = argparse.ArgumentParser() 61 | parser.add_argument( 62 | "-q", 63 | "--quootstrap", 64 | type=str, 65 | help="Path to Quoostrap output (.json)", 66 | required=True, 67 | ) 68 | parser.add_argument( 69 | "-c", 70 | "--context", 71 | type=str, 72 | help="Path to folder with all quotes with context (.json.gz)", 73 | required=True, 74 | ) 75 | parser.add_argument( 76 | "-o", "--output", type=str, help="Path to output folder", required=True 77 | ) 78 | parser.add_argument( 79 | "-n", 80 | "--nb_partition", 81 | type=int, 82 | help="Number of partition for the output (useful if used with unsplittable compression algorithm). Default=50", 83 | default=200, 84 | ) 85 | parser.add_argument( 86 | "--compression", 87 | type=str, 88 | help="Compression algorithm. Can be any compatible alogrithm with Spark Parquet. Default=gzip", 89 | default="gzip", 90 | ) 91 | parser.add_argument( 92 | "-l", 93 | "--local", 94 | help="Add if you want to execute locally. 
The code is expected to be run on a cluster if you run on big files", 95 | action="store_true", 96 | ) 97 | args = parser.parse_args() 98 | 99 | print("Starting the Spark Session") 100 | if args.local: 101 | import findspark 102 | 103 | findspark.init() 104 | 105 | spark = ( 106 | SparkSession.builder.master("local[24]") 107 | .appName("QuoteMergerLocal") 108 | .config("spark.driver.memory", "16g") 109 | .config("spark.executor.memory", "32g") 110 | .getOrCreate() 111 | ) 112 | else: 113 | spark = SparkSession.builder.appName("QuoteMerger").getOrCreate() 114 | 115 | print("Starting the merging process") 116 | merge( 117 | spark, 118 | quootstrap_path=args.quootstrap, 119 | quotes_context_path=args.context, 120 | output_path=args.output, 121 | nb_partition=args.nb_partition, 122 | compression=args.compression, 123 | ) 124 | -------------------------------------------------------------------------------- /quobert/dataprocessing/run.sh: -------------------------------------------------------------------------------- 1 | spark-submit \ 2 | --master yarn \ 3 | --num-executors 16 \ 4 | --executor-cores 32 \ 5 | --driver-memory 16g \ 6 | --executor-memory 64g \ 7 | --py-files pygtrie.py \ 8 | --conf spark.pyspark.python=python3 \ 9 | --conf spark.driver.maxResultSize=0 \ 10 | --conf spark.sql.shuffle.partitions=2048 \ 11 | --conf spark.executor.memoryOverhead=16g \ 12 | --conf spark.blacklist.enabled=true \ 13 | --conf spark.reducer.maxReqsInFlight=10 \ 14 | --conf spark.shuffle.io.retryWait=10s \ 15 | --conf spark.shuffle.io.maxRetries=10 \ 16 | --conf spark.shuffle.io.backLog=2048 \ 17 | "$@" 18 | -------------------------------------------------------------------------------- /quobert/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .fit import fit 2 | from .eval import evaluate 3 | from .model import BertForQuotationAttribution 4 | -------------------------------------------------------------------------------- /quobert/model/eval.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import logging 3 | import os 4 | from operator import itemgetter 5 | 6 | import torch 7 | from torch.utils.data import DataLoader, RandomSampler 8 | from tqdm.auto import tqdm 9 | 10 | from quobert.utils.data import collate_batch_eval 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def get_most_probable_entity(proba, entities): 16 | most_probable_entity_idx = proba.argmax().item() 17 | for entity, val in entities.items(): 18 | if most_probable_entity_idx in val[0]: 19 | return entity 20 | return "None" 21 | 22 | 23 | def evaluate(args, model, dataset, no_save=False, has_target=True, output_proba=True): 24 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 25 | sampler = RandomSampler(dataset) 26 | loader = DataLoader( 27 | dataset, 28 | sampler=sampler, 29 | collate_fn=collate_batch_eval, 30 | batch_size=args.eval_batch_size, 31 | ) 32 | 33 | # multi-gpu evaluate 34 | if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel): 35 | model = torch.nn.DataParallel(model) 36 | 37 | logger.info("*** Start the evaluation porcess ***") 38 | logger.info(f" Number of samples: {len(dataset)}") 39 | logger.info(f" Batch size: {args.eval_batch_size} using {args.n_gpu} GPU(s)") 40 | 41 | correct_pos, correct_neg = 0, 0 42 | total_pos, total_neg = 0, 0 43 | model.eval() 44 | out = [] 45 | 46 | for batch in tqdm(loader, desc="Evaluating"): 47 | batch = { 48 | k: 
v.to(args.device) if hasattr(v, "to") else v for k, v in batch.items() 49 | } 50 | with torch.no_grad(): 51 | scores = model( 52 | input_ids=batch["input_ids"], 53 | mask_ids=batch["mask_ids"], 54 | attention_mask=batch["attention_mask"], 55 | )[0].cpu() 56 | 57 | for proba, entities, speaker, uid in zip( 58 | scores, batch["entities"], batch["speakers"], batch["uid"] 59 | ): 60 | speakers_proba = { 61 | entity: proba[val[0]].sum().item() for entity, val in entities.items() 62 | } 63 | speakers_proba["None"] = proba[0].item() 64 | speakers_proba_sorted = sorted( 65 | speakers_proba.items(), key=itemgetter(1), reverse=True 66 | ) 67 | 68 | most_probable_speaker = speakers_proba_sorted[0][0] 69 | most_probable_entity = get_most_probable_entity(proba, entities) 70 | 71 | if has_target: 72 | target_speaker = speaker if speaker in entities else "None" 73 | is_correct = most_probable_speaker == target_speaker 74 | if target_speaker == "None": 75 | total_neg += 1 76 | if is_correct: 77 | correct_neg += 1 78 | else: 79 | total_pos += 1 80 | if is_correct: 81 | correct_pos += 1 82 | 83 | out.append( 84 | ( 85 | uid, 86 | speakers_proba_sorted, 87 | most_probable_speaker, 88 | most_probable_entity, 89 | ) 90 | ) 91 | 92 | if has_target: 93 | EM_neg = correct_neg / total_neg 94 | EM_pos = correct_pos / total_pos 95 | total = total_neg + total_pos 96 | EM = (correct_neg + correct_pos) / (total_neg + total_pos) 97 | logger.info(f"EM value: {EM:.2%}%, total: {total}") 98 | logger.info( 99 | f"EM pos: {EM_pos:.2%}%, total: {total_pos} ({total_pos / total:.2%})" 100 | ) 101 | logger.info( 102 | f"EM neg: {EM_neg:.2%}%, total: {total_neg} ({total_neg / total:.2%})" 103 | ) 104 | 105 | if not no_save: 106 | with open( 107 | os.path.join(args.output_file), "w", encoding="utf-8", newline="", 108 | ) as csvfile: 109 | csvwriter = csv.writer(csvfile) 110 | if output_proba: 111 | csvwriter.writerow( 112 | ["articleUID", "articleOffset", "rank", "speaker", "proba"] 113 | ) 114 | for uid, speakers_proba, _, _ in out: 115 | articleUID, articleOffset = uid.split() 116 | for i, (speaker, proba) in enumerate(speakers_proba): 117 | csvwriter.writerow( 118 | [ 119 | articleUID, 120 | articleOffset, 121 | i, 122 | speaker, 123 | round(proba * 100, 2), 124 | ] 125 | ) 126 | else: 127 | csvwriter.writerow( 128 | ["articleUID", "articleOffset", "sum_speaker", "max_speaker"] 129 | ) 130 | for uid, _, sum_speaker, max_speaker in out: 131 | articleUID, articleOffset = uid.split() 132 | csvwriter.writerow( 133 | [articleUID, articleOffset, sum_speaker, max_speaker] 134 | ) 135 | -------------------------------------------------------------------------------- /quobert/model/model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | from torch.nn import CrossEntropyLoss, Linear 5 | from torch.nn.functional import softmax 6 | from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence 7 | from transformers import BertModel, BertPreTrainedModel 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class BertForQuotationAttribution(BertPreTrainedModel): 13 | def __init__(self, config): 14 | super().__init__(config) 15 | self.num_labels = 1 16 | 17 | self.bert = BertModel(config) 18 | self.qa_outputs = Linear(config.hidden_size, self.num_labels) 19 | 20 | self.loss_fn = CrossEntropyLoss() 21 | self.init_weights() 22 | 23 | def forward( 24 | self, 25 | input_ids, 26 | mask_ids, 27 | *, 28 | attention_mask=None, 29 | token_type_ids=None, 30 | 
position_ids=None, 31 | head_mask=None, 32 | inputs_embeds=None, 33 | targets=None, 34 | ): 35 | outputs = self.bert( 36 | input_ids, 37 | attention_mask=attention_mask, 38 | token_type_ids=token_type_ids, 39 | position_ids=position_ids, 40 | head_mask=head_mask, 41 | inputs_embeds=inputs_embeds, 42 | ) 43 | 44 | sequence_output = outputs[0] 45 | logits = [ 46 | self.qa_outputs(output[mask[mask >= 0]]).squeeze(-1) 47 | for output, mask in zip(sequence_output, mask_ids) 48 | ] 49 | 50 | proba, _ = pad_packed_sequence( 51 | pack_sequence( 52 | [softmax(logit, dim=0) for logit in logits], enforce_sorted=False 53 | ), 54 | batch_first=True, 55 | total_length=100, 56 | ) 57 | 58 | # logger.info(f"logits: {logits},\ntargets: {targets}\nmask_ids: {mask_ids}") 59 | outputs = (proba,) + outputs[2:] 60 | if targets is not None: 61 | loss = torch.stack( 62 | [ 63 | self.loss_fn(logit[None, :], target[None]) 64 | for logit, target in zip(logits, targets) 65 | ] 66 | ).mean() 67 | correct = torch.tensor( 68 | [ 69 | 1 if logit.argmax().item() == target.item() else 0 70 | for logit, target in zip(logits, targets) 71 | ], 72 | device=loss.get_device(), 73 | ).sum() 74 | outputs = (loss, correct,) + outputs 75 | return outputs # (loss, correct), logits, (hidden_states), (attentions) 76 | -------------------------------------------------------------------------------- /quobert/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from ._utils import set_seed, get_device 2 | -------------------------------------------------------------------------------- /quobert/utils/_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import List, TypeVar 3 | 4 | import numpy as np 5 | import torch 6 | 7 | T = TypeVar("T") 8 | 9 | 10 | def set_seed(seed: int): 11 | """ 12 | Fix all possible seeds for reproducibility 13 | 14 | Args: 15 | seed (int): number used to set the seeds 16 | """ 17 | random.seed(seed) 18 | np.random.seed(seed) 19 | torch.manual_seed(seed) 20 | torch.cuda.manual_seed_all(seed) # type: ignore 21 | 22 | 23 | def get_device() -> torch.device: 24 | """ Check if CUDA is available """ 25 | return torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 26 | 27 | 28 | def to_list(tensor: torch.Tensor) -> List[T]: 29 | return tensor.detach().cpu().tolist() 30 | -------------------------------------------------------------------------------- /quobert/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataloader import collate_batch_eval, collate_batch_train 2 | from .dataset import ConcatParquetDataset, ParquetDataset 3 | -------------------------------------------------------------------------------- /quobert/utils/data/dataloader.py: -------------------------------------------------------------------------------- 1 | import json 2 | from functools import partial 3 | from typing import Dict, List 4 | 5 | import pandas as pd 6 | import torch 7 | from torch.nn.utils.rnn import pad_sequence 8 | 9 | 10 | def __collate_batch( 11 | data: List[pd.Series], *, train: bool, test: bool 12 | ) -> Dict[str, torch.Tensor]: 13 | """ 14 | Transform a list of `ConcatParquetDataset` entries into a Dict of tensors that can be fed to a BERT model. 
15 | 16 | This private method is intented to be used with a partial to set `train` and then be fed to `torch.utils.data.DataLoader` as `collate_fn` 17 | 18 | Args: 19 | data (List[pd.Series]): a list of `ConcatParquetDataset` or `ParquetDataset` entries 20 | train (bool): set to `True` if `start_offset` and `end_offset` will be returned 21 | 22 | Returns: 23 | Dict[str, torch.Tensor]: a dict containing the dataset / sample index, input_ids, mask and if `train` the start and end offset 24 | """ 25 | input_ids = pad_sequence( 26 | [torch.tensor(d.input_ids, dtype=torch.long) for d in data], batch_first=True, 27 | ) # (b_size, max_sentence_len) 28 | attention_mask = input_ids.where( 29 | input_ids == 0, torch.tensor(1) 30 | ) # (b_size, max_sentence_len) 31 | 32 | mask_ids = pad_sequence( 33 | [torch.tensor(d.mask_idx, dtype=torch.long) for d in data], 34 | batch_first=True, 35 | padding_value=-1, 36 | ) # mask for indexes of CLS/MASK tokens 37 | 38 | out = { 39 | "input_ids": input_ids, 40 | "attention_mask": attention_mask, 41 | "mask_ids": mask_ids, 42 | } 43 | 44 | if train: 45 | targets = torch.tensor([d.target for d in data], dtype=torch.long) # (b_size, ) 46 | out["targets"] = targets 47 | 48 | if test: 49 | entities = [json.loads(d.entities) for d in data] 50 | speakers = [d.speaker if "speaker" in d else "" for d in data] 51 | uid = [d.uid for d in data] 52 | out.update({"entities": entities, "speakers": speakers, "uid": uid}) 53 | 54 | return out 55 | 56 | 57 | collate_batch_train = partial(__collate_batch, train=True, test=False) 58 | collate_batch_eval = partial(__collate_batch, train=False, test=True) 59 | -------------------------------------------------------------------------------- /quobert/utils/data/dataset.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | from typing import Tuple, List 3 | 4 | import pandas as pd 5 | from torch.utils.data import ConcatDataset, Dataset 6 | 7 | 8 | class ParquetDataset(Dataset): 9 | """ 10 | Parquet Dataset is a wrapper around `torch.utils.data.Dataset`. It loads and serve a parquet DataFrame. 11 | 12 | Args: 13 | parquet_path (str): The path to a single parquet file, can be compressed. If using multiple file, check `ConcatParquetDataset` 14 | sample_n (int, optional): Set to the number of items from the data set to sample. If 0, use all items. Defaults to 0. 15 | """ 16 | 17 | def __init__( 18 | self, parquet_path: str, sample_n: int = 0 19 | ): 20 | super(ParquetDataset, self).__init__() 21 | self.df = pd.read_parquet(parquet_path) 22 | if sample_n > 0: 23 | self.df = self.df.sample(n=sample_n) 24 | 25 | def __len__(self) -> int: 26 | return len(self.df) 27 | 28 | def __getitem__(self, idx: int) -> pd.Series: 29 | return self.df.iloc[idx] 30 | 31 | 32 | class ConcatParquetDataset(ConcatDataset): 33 | """ 34 | Concat Parquet Dataset is a wrapper around `torch.utils.data.ConcatDataset`. 
It serves multiple `ParquetDataset` instances. 35 | 36 | Args: 37 | datasets (List[ParquetDataset]): a list of `ParquetDataset` 38 | """ 39 | 40 | def __init__(self, datasets: List[ParquetDataset]): 41 | super(ConcatParquetDataset, self).__init__(datasets) 42 | 43 | def __getitem__(self, idx: int) -> pd.Series: 44 | if idx < 0: 45 | if -idx > len(self): 46 | raise ValueError( 47 | "absolute value of index should not exceed dataset length" 48 | ) 49 | idx = len(self) + idx 50 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 51 | if dataset_idx == 0: 52 | sample_idx = idx 53 | else: 54 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 55 | return self.datasets[dataset_idx][sample_idx] 56 | -------------------------------------------------------------------------------- /quootstrap/README.md: -------------------------------------------------------------------------------- 1 | # Quootstrap 2 | This folder contains an adapted version of the reference implementation of Quootstrap, as described in the paper [[PDF]](https://dlab.epfl.ch/people/west/pub/Pavllo-Piccardi-West_ICWSM-18.pdf): 3 | > Dario Pavllo, Tiziano Piccardi, Robert West. *Quootstrap: Scalable Unsupervised Extraction of Quotation-Speaker Pairs from Large News Corpora via Bootstrapping*. In *Proceedings of the 12th International Conference on Web and Social Media (ICWSM),* 2018. 4 | 5 | #### Disclaimer 6 | 7 | *This folder contains an improved version of the code from the paper using Wikidata instead of Freebase for entity linking and with the option of saving all the quotes with their context and speakers, instead of only those extracted by the rules. For the code of the paper, please check [this branch]( https://github.com/epfl-dlab/quootstrap/tree/master )* 8 | 9 | ## How to run 10 | 11 | Go to the **Release** section and download the .zip archive, which contains the executable `quootstrap.jar` as well as all necessary dependencies and configuration files. You can also find a convenient script `extract_quotations.sh` that can be used to run the application on a Yarn cluster. The script runs this command: 12 | ```bash 13 | spark-submit --jars opennlp-tools-1.9.2.jar,spinn3r-client-3.4.05-edit.jar,stanford-corenlp-3.8.0.jar,jsoup-1.10.3.jar,guava-14.0.1.jar \ 14 | --num-executors 25 \ 15 | --executor-cores 16 \ 16 | --driver-memory 128g \ 17 | --executor-memory 128g \ 18 | --conf "spark.executor.memoryOverhead=32768" \ 19 | --class ch.epfl.dlab.quootstrap.QuotationExtraction \ 20 | --master yarn \ 21 | quootstrap.jar 22 | ``` 23 | After tuning the settings to suit your particular configuration, you can run the command as: 24 | ```bash 25 | ./extract_quotations.sh 26 | ``` 27 | 28 | ### Setup 29 | 30 | To run our code, you need: 31 | - Java 8 32 | - Spark 2.3 33 | - The entire Spinn3r dataset (available on the Hadoop cluster) or your own dataset 34 | - Our dataset of people extracted from Wikidata (available on `/data`) 35 | 36 | ### How to build 37 | 38 | Clone the repository and import it as an Eclipse project. All dependencies are downloaded through Maven. To build the application, generate a .jar file with all source files and run it as explained in the previous section. Alternatively, you can use Spark in local mode for experimenting. Additional instructions on how to extend the project with new functionalities (e.g. support for new datasets) are reported later. 39 | 40 | ## Configuration 41 | The first configuration file is `config.properties`.
The most important fields in order to get the application running are: 42 | - `NEWS_DATASET_PATH` specifies the HDFS path of the Spinn3r news dataset. 43 | - `PEOPLE_DATASET_PATH` specifies the HDFS path of the Wikidata people list. 44 | - `NEWS_DATASET_LOADER` specifies which loader to use for the dataset. 45 | - `EXPORT_PATH` specifies the HDFS path for the (quotation, speaker) pairs output. 46 | - `CONTEXT_PATH` specifies the HDFS path for the (quotation, context, lang) output. 47 | - `NUM_ITERATIONS=1` specifies the number of iterations of the extraction algorithm. Set it to 1 if you want to run the algorithm only on the seed patterns (iteration 0). A number of iterations between 3 and 5 is more than enough. **Note:** Currently, this feature is not supported anymore, and a fix to support Wikidata would be appreciated. 48 | 49 | Additionally, you can change the flow of the application with the following parameters: 50 | 51 | - `DO_QUOTE_ATTRIBUTION` & `EXPORT_RESULTS`: Whether to perform quote attribution and export the results. This is the original output from Quootstrap. 52 | - `EXPORT_CONTEXT`: Whether to export all the quotes and their context. Not present in the original version of Quootstrap. 53 | - `EXPORT_SPEAKERS`: Whether to export all full matches of candidate speakers in a given article. Partial matches are dealt with in Quobert. 54 | - `EXPORT_ARTICLE`: Whether to export all articles in an easier-to-read format (not recommended if your input data is already easily readable). 55 | 56 | The second configuration file is `seedPatterns.txt`, which, as the name suggests, contains the seed patterns that are used in the first iteration, one per line. 57 | 58 | 59 | ## Exporting results 60 | ### Quotation-speaker pairs 61 | 62 | You can save the results as an HDFS text file formatted in JSON, with one record per line. For each record, the full quotation is exported, as well as the full name of the speaker (as reported in the article), his/her Wikidata ID, the confidence value of the tuple, and the occurrences in which the quotation was found. As for the latter, we report the article ID, an incremental offset within the article (which is useful for linking together split quotations), the pattern that extracted the tuple along with its confidence, the website, and the date the article appeared. 63 | 64 | ```json 65 | { 66 | "canonicalQuotation":"action is easy it is very easy comedy is difficult", 67 | "confidence":1.0, 68 | "numOccurrences":6, 69 | "numSpeakers":1, 70 | "occurrences":[ 71 | {"articleOffset":0, 72 | "articleUID":"2012031008_00073468_W", 73 | "date":"2012-03-10 08:49:40", 74 | "extractedBy":"$Q , $S said", 75 | "matchedSpeakerTokens":"Akshay Kumar", 76 | "patternConfidence":1.0, 77 | "quotation":"action is easy, it is very easy. comedy is difficult,", 78 | "website":"http://deccanchronicle.com/channels/showbiz/bollywood/action-easy-comedy-toughest-akshay-kumar-887"}, 79 | ... 80 | {"articleOffset":0, 81 | "articleUID":"2012031012_00038989_W", 82 | "date":"2012-03-10 12:23:40", 83 | "extractedBy":"$Q , $S told", 84 | "matchedSpeakerTokens":"Akshay Kumar", 85 | "patternConfidence":0.7360332115513023, 86 | "quotation":"action is easy, it is very easy. comedy is difficult,", 87 | "website":"http://hindustantimes.com/Entertainment/Bollywood/Comedy-is-difficult-Akshay-Kumar/Article1-823315.aspx"} 88 | ], 89 | "quotation":"action is easy, it is very easy.
comedy is difficult,", 90 | "speaker":["Akshay Kumar"], 91 | "speakerID":["Q233748"] 92 | } 93 | ``` 94 | Remarks: 95 | - *articleOffset* can have gaps (but quotations are guaranteed to be ordered correctly). 96 | - *canonicalQuotation* is the internal representation of a particular quotation, used for pattern matching purposes. The string is converted to lower case and punctuation marks are removed. 97 | - As in the example above, the full quotation might differ from the one(s) found in *occurrences* due to the quotation merging mechanism. We always report the longest (and most likely useful) quotation when multiple choices are possible. 98 | 99 | ## Adding support for new datasets/formats 100 | 101 | If you want to add support for other datasets/formats, you can provide a concrete implementation of the Java interface `DatasetLoader` and specify its full class name in the `NEWS_DATASET_LOADER` field of the configuration. For each article, you must supply a unique article ID, the website in which it can be found, and its content in tokenized format, i.e. as a list of strings. We provide an implementation for our JSON Spinn3r dataset in `ch.epfl.dlab.quootstrap.Spinn3rDatasetLoader`, one for Parquet dataframes in `ch.epfl.dlab.quootstrap.ParquetDatasetLoader`, and one for the Stanford Spinn3r text format in `ch.epfl.dlab.quootstrap.Spinn3rTextDatasetLoader`. A minimal, hypothetical loader skeleton is sketched after the License section below. 102 | 103 | ## Replacing the tokenizer 104 | If, for any reason (e.g. license, language other than English), you do not want to depend on the Stanford PTBTokenizer, you can provide your own implementation of the `ch.epfl.dlab.spinn3r.Tokenizer` interface. You only have to implement two methods: `tokenize` and `untokenize`. Tokenization is one of the least critical steps in our pipeline, and does not impact the final result significantly. A minimal example is sketched after the License section below. 105 | 106 | ## License 107 | We release our work under the MIT license. Third-party components, such as Stanford CoreNLP, are subject to their respective licenses.
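As referenced in the "Adding support for new datasets/formats" section above, the following is a minimal, hypothetical sketch of a custom loader for a toy line-based text format. The class name `MyTextDatasetLoader`, the tab-separated input format, and the whitespace tokenization are illustrative assumptions only; the `DatasetLoader` and `DatasetLoader.Article` interfaces are the ones defined in this repository.

```java
package ch.epfl.dlab.quootstrap;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import java.util.Set;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

/** Hypothetical loader: one article per line, "uid<TAB>website<TAB>date<TAB>text". */
public class MyTextDatasetLoader implements DatasetLoader {

    @Override
    public JavaRDD<Article> loadArticles(JavaSparkContext sc, String datasetPath, Set<String> languageFilter) {
        // The language filter is ignored in this toy format
        return sc.textFile(datasetPath).map(line -> {
            String[] f = line.split("\t", 4);
            return (Article) new SimpleArticle(f[0], f[1], f[2], f[3]);
        });
    }

    /** Minimal Article implementation; tokenization here is naive whitespace splitting. */
    public static class SimpleArticle implements Article, Serializable {
        private static final long serialVersionUID = 1L;
        private final String uid, website, date, text;

        public SimpleArticle(String uid, String website, String date, String text) {
            this.uid = uid;
            this.website = website;
            this.date = date;
            this.text = text;
        }

        @Override public String getArticleUID() { return uid; }
        @Override public List<String> getArticleContent() { return Arrays.asList(text.split("\\s+")); }
        @Override public String getWebsite() { return website; }
        @Override public String getDate() { return date; }
        @Override public String getVersion() { return ""; } // no versioning in this toy format
        @Override public String getTitle() { return ""; }   // no title in this toy format
    }
}
```

To use such a loader, build it into the jar and point `NEWS_DATASET_LOADER` in `config.properties` at its fully qualified class name (here, `ch.epfl.dlab.quootstrap.MyTextDatasetLoader`).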
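Likewise, for the "Replacing the tokenizer" section above, a drop-in whitespace tokenizer could look like the sketch below. The method signatures are assumptions inferred from how `TokenizerImpl` is used elsewhere in this repository (`untokenize` takes a list of tokens and returns a string); check `ch.epfl.dlab.spinn3r.Tokenizer` for the authoritative interface before relying on this.

```java
package ch.epfl.dlab.spinn3r;

import java.util.Arrays;
import java.util.List;

/** Hypothetical whitespace-based replacement for the Stanford PTBTokenizer. */
public class WhitespaceTokenizer implements Tokenizer {

    @Override
    public List<String> tokenize(String text) {
        // Naive splitting: punctuation stays attached to words, unlike PTBTokenizer
        return Arrays.asList(text.trim().split("\\s+"));
    }

    @Override
    public String untokenize(List<String> tokens) {
        // Re-joining with single spaces loses the original spacing around punctuation
        return String.join(" ", tokens);
    }
}
```

As noted above, tokenization is not a critical step of the pipeline, so even a crude implementation like this should not change the extracted quotations dramatically.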
108 | 109 | If you use our code and/or data in your research, please cite our paper [[PDF]](https://dlab.epfl.ch/people/west/pub/Pavllo-Piccardi-West_ICWSM-18.pdf): 110 | ``` 111 | @inproceedings{quootstrap2018, 112 | title={Quootstrap: Scalable Unsupervised Extraction of Quotation-Speaker Pairs from Large News Corpora via Bootstrapping}, 113 | author={Pavllo, Dario and Piccardi, Tiziano and West, Robert}, 114 | booktitle={Proceedings of the 12th International Conference on Web and Social Media (ICWSM)}, 115 | year={2018} 116 | } 117 | ``` 118 | -------------------------------------------------------------------------------- /quootstrap/extract_quotations.sh: -------------------------------------------------------------------------------- 1 | spark-submit --jars spinn3r-client-3.4.05-edit.jar,stanford-corenlp-3.8.0.jar,jsoup-1.10.3.jar,guava-14.0.1.jar \ 2 | --num-executors 25 \ 3 | --executor-cores 16 \ 4 | --driver-memory 128g \ 5 | --executor-memory 128g \ 6 | --conf "spark.executor.memoryOverhead=32768" \ 7 | --class ch.epfl.dlab.quootstrap.QuotationExtraction \ 8 | --master yarn \ 9 | quootstrap.jar -------------------------------------------------------------------------------- /quootstrap/lib/spinn3r-client-3.4.05-edit.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-dlab/Quotebank/a203acf5a05e34841671578d14d6a0be6a66a3ef/quootstrap/lib/spinn3r-client-3.4.05-edit.jar -------------------------------------------------------------------------------- /quootstrap/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 4.0.0 3 | ch.epfl 4 | ProtoToJson 5 | 0.0.1-SNAPSHOT 6 | 7 | 8 | 1.8 9 | 1.8 10 | 11 | 12 | 13 | 14 | 15 | org.apache.spark 16 | spark-sql_2.11 17 | 2.3.1 18 | 19 | 20 | com.google.code.gson 21 | gson 22 | 2.3 23 | 24 | 25 | com.sun.jersey 26 | jersey-server 27 | 1.2 28 | 29 | 30 | 31 | org.jsoup 32 | jsoup 33 | 1.10.3 34 | 35 | 36 | edu.stanford.nlp 37 | stanford-corenlp 38 | 3.8.0 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /quootstrap/quootstrap.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-dlab/Quotebank/a203acf5a05e34841671578d14d6a0be6a66a3ef/quootstrap/quootstrap.tar.gz -------------------------------------------------------------------------------- /quootstrap/resources/config.properties: -------------------------------------------------------------------------------- 1 | # Basic settings (news dataset path, people list dataset, etc.) 
2 | NEWS_DATASET_PATH=/path/to/articles 3 | PEOPLE_DATASET_PATH=/path/to/wikidata_people_ALIVE_FILTERED-NAMES-CLEAN.tsv 4 | 5 | # Number of tokens to be considered as a quotation 6 | MIN_QUOTATION_SIZE=6 7 | MAX_QUOTATION_SIZE=500 8 | 9 | # Case sensitive true == bad idea (no match on broken cases articles) 10 | CASE_SENSITIVE=false 11 | 12 | # Provide the concrete implementation class name of DatasetLoader 13 | NEWS_DATASET_LOADER=ch.epfl.dlab.quootstrap.Spinn3rTextDatasetLoader 14 | 15 | # Settings for exporting results 16 | EXPORT_RESULTS=true 17 | EXPORT_PATH=/path/to/quootstrap 18 | DO_QUOTE_ATTRIBUTION=true 19 | 20 | # Settings for exporting Article / Speakers 21 | EXPORT_SPEAKERS=true 22 | SPEAKERS_PATH=/path/to/speakers 23 | 24 | # Settings for exporting Articles 25 | EXPORT_ARTICLE=false 26 | ARTICLE_PATH=/path/to/articles 27 | 28 | # Settings for exporting the quotes and context of the quotes 29 | EXPORT_CONTEXT=true 30 | CONTEXT_PATH=/path/to/quotes_context 31 | NUM_PARTITIONS=100 32 | 33 | 34 | ###### UNUSED PARAMS ###### 35 | 36 | # Note: Currently, only 1 iteration is supported 37 | NUM_ITERATIONS=1 38 | 39 | # Set to true if you want to use Spark in local mode 40 | LOCAL_MODE=false 41 | 42 | # Hyperparameters 43 | PATTERN_CONFIDENCE_THRESHOLD=0.7 44 | PATTERN_CLUSTERING_THRESHOLDS=0|0.0002|0.001|0.005 45 | 46 | # Quotation merging 47 | ENABLE_QUOTATION_MERGING=false 48 | ENABLE_DEDUPLICATION=true 49 | MERGING_SHINGLE_SIZE=10 50 | 51 | # Cache settings: some frequently used (and immutable) RDDs can be cached on disk 52 | # in order to speed up the execution of the algorithm after the first time. 53 | # Note that the cache must be invalidated manually (by deleting the files) 54 | # if the code or the internal parameters are changed. 55 | ENABLE_CACHE=false 56 | CACHE_PATH=/path/to/cache 57 | 58 | # Note: Currently Evaluation of the results is not supported anymore 59 | # Evaluation settings 60 | GROUND_TRUTH_PATH=ground_truth.json 61 | # Enable the evaluation on the last iteration 62 | ENABLE_FINAL_EVALUATION=false 63 | # Enable the evaluation on intermediate iterations (slower) 64 | ENABLE_INTERMEDIATE_EVALUATION=false 65 | 66 | # Debug settings 67 | # Set to true if you want to dump all new discovered patterns at each iteration 68 | DEBUG_DUMP_PATTERNS=false 69 | 70 | # Set to true if you want to convert the entire input data to lower case (not recommended) 71 | DEBUG_CASE_FOLDING=false 72 | 73 | # Deprecated 74 | LANGUAGE_FILTER=en|uk 75 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/quootstrap/ConfigManager.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.quootstrap; 2 | 3 | import java.io.FileInputStream; 4 | import java.io.IOException; 5 | import java.util.Arrays; 6 | import java.util.List; 7 | import java.util.Properties; 8 | import java.util.stream.Collectors; 9 | 10 | public final class ConfigManager { 11 | 12 | private static final String CONFIG_FILE = "resources/config.properties"; 13 | 14 | private static final ConfigManager INSTANCE = new ConfigManager(); 15 | 16 | private final String datasetPath; 17 | private final String namesPath; 18 | private final int numIterations; 19 | private final int minQuotationSize; 20 | private final int maxQuotationSize; 21 | private final boolean caseSensitive; 22 | private final List langFilter; 23 | 24 | private final String newsDatasetLoader; 25 | 26 | private final boolean exportEnabled; 27 | 
private final String exportPath; 28 | private final boolean doQuoteAttribution; 29 | 30 | private final boolean contextExportEnabled; 31 | private final String contextOutputPath; 32 | 33 | private final int numPartition; 34 | 35 | private final boolean localModeEnabled; 36 | 37 | private final double confidenceThreshold; 38 | private final List clusteringThresholds; 39 | 40 | private final boolean mergingEnabled; 41 | private final boolean doDedup; 42 | private final int mergingShingleSize; 43 | 44 | private final boolean cacheEnabled; 45 | private final String cachePath; 46 | 47 | private final String groundTruthPath; 48 | private final boolean finalEvaluationEnabled; 49 | private final boolean intermediateEvaluationEnabled; 50 | 51 | private final boolean dumpPatternsEnabled; 52 | private final boolean caseFoldingEnabled; 53 | 54 | private String outputPath; 55 | 56 | private final boolean speakersEnabled; 57 | private final String speakersPath; 58 | private final boolean articleEnabled; 59 | private final String articlePath; 60 | 61 | private ConfigManager() { 62 | final Properties prop = new Properties(); 63 | try { 64 | prop.load(new FileInputStream(CONFIG_FILE)); 65 | } catch (IOException e) { 66 | throw new IllegalArgumentException("Unable to read config file", e); 67 | } 68 | 69 | datasetPath = prop.getProperty("NEWS_DATASET_PATH"); 70 | namesPath = prop.getProperty("PEOPLE_DATASET_PATH"); 71 | numIterations = Integer.parseInt( 72 | prop.getProperty("NUM_ITERATIONS")); 73 | minQuotationSize = Integer.parseInt( 74 | prop.getProperty("MIN_QUOTATION_SIZE")); 75 | maxQuotationSize = Integer.parseInt( 76 | prop.getProperty("MAX_QUOTATION_SIZE")); 77 | caseSensitive = prop.getProperty("CASE_SENSITIVE").equals("true"); 78 | langFilter = Arrays.asList(prop.getProperty("LANGUAGE_FILTER").split("\\|")); 79 | 80 | newsDatasetLoader = prop.getProperty("NEWS_DATASET_LOADER"); 81 | 82 | doQuoteAttribution = prop.getProperty("DO_QUOTE_ATTRIBUTION").equals("true"); 83 | 84 | exportEnabled = prop.getProperty("EXPORT_RESULTS").equals("true"); 85 | exportPath = prop.getProperty("EXPORT_PATH"); 86 | 87 | contextExportEnabled = prop.getProperty("EXPORT_CONTEXT").equals("true"); 88 | contextOutputPath = prop.getProperty("CONTEXT_PATH"); 89 | 90 | speakersEnabled = prop.getProperty("EXPORT_SPEAKERS").equals("true"); 91 | speakersPath = prop.getProperty("SPEAKERS_PATH"); 92 | 93 | articleEnabled = prop.getProperty("EXPORT_ARTICLE").equals("true"); 94 | articlePath = prop.getProperty("ARTICLE_PATH"); 95 | 96 | numPartition = Integer.parseInt( 97 | prop.getProperty("NUM_PARTITIONS")); 98 | 99 | localModeEnabled = prop.getProperty("LOCAL_MODE").equals("true"); 100 | 101 | confidenceThreshold = Double.parseDouble( 102 | prop.getProperty("PATTERN_CONFIDENCE_THRESHOLD")); 103 | clusteringThresholds = Arrays.asList(prop.getProperty("PATTERN_CLUSTERING_THRESHOLDS").split("\\|")) 104 | .stream() 105 | .map(Double::parseDouble) 106 | .collect(Collectors.toList()); 107 | 108 | mergingEnabled = prop.getProperty("ENABLE_QUOTATION_MERGING").equals("true"); 109 | doDedup = prop.getProperty("ENABLE_DEDUPLICATION").equals("true"); 110 | mergingShingleSize = Integer.parseInt( 111 | prop.getProperty("MERGING_SHINGLE_SIZE")); 112 | 113 | cacheEnabled = prop.getProperty("ENABLE_CACHE").equals("true"); 114 | cachePath = prop.getProperty("CACHE_PATH"); 115 | 116 | groundTruthPath = prop.getProperty("GROUND_TRUTH_PATH"); 117 | finalEvaluationEnabled = prop.getProperty("ENABLE_FINAL_EVALUATION").equals("true"); 118 | 
intermediateEvaluationEnabled = prop.getProperty("ENABLE_INTERMEDIATE_EVALUATION").equals("true"); 119 | 120 | dumpPatternsEnabled = prop.getProperty("DEBUG_DUMP_PATTERNS").equals("true"); 121 | caseFoldingEnabled = prop.getProperty("DEBUG_CASE_FOLDING").equals("true"); 122 | 123 | outputPath = ""; 124 | } 125 | 126 | public static ConfigManager getInstance() { 127 | return INSTANCE; 128 | } 129 | 130 | public String getDatasetPath() { 131 | return datasetPath; 132 | } 133 | 134 | public String getNamesPath() { 135 | return namesPath; 136 | } 137 | 138 | public int getNumIterations() { 139 | return numIterations; 140 | } 141 | 142 | public boolean isCaseSensitive() { 143 | return caseSensitive; 144 | } 145 | 146 | public List getLangFilter() { 147 | return langFilter; 148 | } 149 | 150 | public String getLangSuffix() { 151 | return langFilter.stream().sorted().collect(Collectors.joining("-")); 152 | } 153 | 154 | public boolean isLocalModeEnabled() { 155 | return localModeEnabled; 156 | } 157 | 158 | public double getConfidenceThreshold() { 159 | return confidenceThreshold; 160 | } 161 | 162 | public List getClusteringThresholds() { 163 | return clusteringThresholds; 164 | } 165 | 166 | public boolean isCacheEnabled() { 167 | return cacheEnabled; 168 | } 169 | 170 | public String getCachePath() { 171 | return cachePath; 172 | } 173 | 174 | public String getGroundTruthPath() { 175 | return groundTruthPath; 176 | } 177 | 178 | public boolean isFinalEvaluationEnabled() { 179 | return finalEvaluationEnabled; 180 | } 181 | 182 | public boolean isIntermediateEvaluationEnabled() { 183 | return intermediateEvaluationEnabled; 184 | } 185 | 186 | public boolean isDumpPatternsEnabled() { 187 | return dumpPatternsEnabled; 188 | } 189 | 190 | public boolean isCaseFoldingEnabled() { 191 | return caseFoldingEnabled; 192 | } 193 | 194 | public boolean isMergingEnabled() { 195 | return mergingEnabled; 196 | } 197 | 198 | public int getMergingShingleSize() { 199 | return mergingShingleSize; 200 | } 201 | 202 | public String getOutputPath() { 203 | return outputPath; 204 | } 205 | 206 | public void setOutputPath(String outputPath) { 207 | this.outputPath = outputPath; 208 | } 209 | 210 | public boolean isExportEnabled() { 211 | return exportEnabled; 212 | } 213 | 214 | public String getExportPath() { 215 | return exportPath; 216 | } 217 | 218 | public boolean isSpeakersEnabled() { 219 | return speakersEnabled; 220 | } 221 | 222 | public String getSpeakersPath() { 223 | return speakersPath; 224 | } 225 | 226 | public boolean isArticleEnabled() { 227 | return articleEnabled; 228 | } 229 | 230 | public String getArticlePath() { 231 | return articlePath; 232 | } 233 | 234 | public String getNewsDatasetLoader() { 235 | return newsDatasetLoader; 236 | } 237 | 238 | public boolean isContextExportEnabled() { 239 | return contextExportEnabled; 240 | } 241 | 242 | public String getContextOutputPath() { 243 | return contextOutputPath; 244 | } 245 | 246 | public int getNumPartition() { 247 | return numPartition; 248 | } 249 | 250 | public boolean isDoQuoteAttribution() { 251 | return doQuoteAttribution; 252 | } 253 | 254 | public boolean isDoDedupEnabled() { 255 | return doDedup; 256 | } 257 | 258 | public int getMinQuotationSize() { 259 | return minQuotationSize; 260 | } 261 | 262 | public int getMaxQuotationSize() { 263 | return maxQuotationSize; 264 | } 265 | } 266 | -------------------------------------------------------------------------------- 
/quootstrap/src/main/java/ch/epfl/dlab/quootstrap/ContextExtractor.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.quootstrap; 2 | 3 | import java.util.ArrayList; 4 | import java.util.Collections; 5 | import java.util.List; 6 | import java.util.stream.Collectors; 7 | 8 | import ch.epfl.dlab.spinn3r.Tokenizer; 9 | import ch.epfl.dlab.spinn3r.TokenizerImpl; 10 | import scala.Tuple2; 11 | 12 | public class ContextExtractor { 13 | 14 | public static List extractQuotations(List tokenList, String articleUid) { 15 | String[] openTokens = {"``", "", ""}; 17 | 18 | final int minQuotationSize = 6; // ConfigManager.getInstance().getMinQuotationSize(); 19 | final int maxQuotationSize = 500; // ConfigManager.getInstance().getMaxQuotationSize(); 20 | final int minContextSize = 50; // At least N tokens in surrounding context 21 | final int maxContextSize = 100; // At most N tokens in surrounding context 22 | 23 | // Extract candidate quotations (without surrounding context) 24 | List> candidateQuotations = new ArrayList<>(); // First token, last token 25 | int expectedToken = -1; 26 | int startIndex = -1; 27 | for (int t = 0; t < tokenList.size(); t++) { 28 | for (int i = 0; i < openTokens.length; i++) { 29 | if (tokenList.get(t).startsWith(openTokens[i])) { 30 | expectedToken = i; 31 | startIndex = t; 32 | break; 33 | } 34 | } 35 | if (expectedToken != -1) { 36 | for (int i = 0; i < closeTokens.length; i++) { 37 | if (tokenList.get(t).equals(closeTokens[i])) { 38 | if (i == expectedToken && t - startIndex >= minQuotationSize && t - startIndex <= maxQuotationSize) { 39 | // Consider only well-formed quotations (e.g. must be followed by , not ) 40 | // In addition, when nested quotations are present, only the innermost one is extracted. 
41 | candidateQuotations.add(new Tuple2<>(startIndex, t)); 42 | } 43 | expectedToken = -1; 44 | startIndex = -1; 45 | break; 46 | } 47 | } 48 | } 49 | } 50 | 51 | Tokenizer tokenizer = new TokenizerImpl(); 52 | 53 | // Extract surrounding context from quotation 54 | List quotations = new ArrayList<>(); 55 | int articleIdx = 0; 56 | for (Tuple2 t : candidateQuotations) { 57 | 58 | List currentSentence = new ArrayList<>(); 59 | 60 | List prevContext = new ArrayList<>(); 61 | int leftOffset = scanBackward(tokenList, prevContext, t._1 - 1, minContextSize, maxContextSize); 62 | Collections.reverse(prevContext); 63 | prevContext.forEach(token -> { 64 | currentSentence.add(new Token(token, Token.Type.GENERIC)); 65 | }); 66 | 67 | List quotation = new ArrayList<>(); 68 | for (int i = t._1 + 1; i < t._2; i++) { 69 | quotation.add(tokenList.get(i)); 70 | } 71 | 72 | String quotationStr = tokenizer.untokenize(quotation.stream() 73 | .filter(x -> !StaticRules.isHtmlTag(x)) 74 | .collect(Collectors.toList())); 75 | 76 | currentSentence.add(new Token(quotationStr, Token.Type.QUOTATION)); 77 | 78 | List nextContext = new ArrayList<>(); 79 | int rightOffset = scanForward(tokenList, nextContext, t._2 + 1, minContextSize, maxContextSize); 80 | nextContext.forEach(token -> { 81 | currentSentence.add(new Token(token, Token.Type.GENERIC)); 82 | }); 83 | 84 | // Remove HTML tags 85 | // currentSentence.removeIf(x -> StaticRules.isHtmlTag(x.toString())); 86 | 87 | quotations.add(new Sentence(currentSentence, articleUid, articleIdx, t._1 + 1, leftOffset, rightOffset + 1)); 88 | articleIdx++; 89 | } 90 | 91 | return quotations; 92 | } 93 | 94 | private static int scanForward(List tokens, List context, int initialIndex, int minSize, int maxSize) { 95 | int finalIndexLower = Math.min(initialIndex + minSize, tokens.size() - 1); 96 | 97 | boolean anotherQuotationFound = false; 98 | for (int i = initialIndex; i < finalIndexLower; i++) { 99 | String token = tokens.get(i); 100 | if (token.equals("``")) { 101 | anotherQuotationFound = true; 102 | } else if (token.equals("''")) { 103 | anotherQuotationFound = false; 104 | } 105 | context.add(token); 106 | } 107 | 108 | int finalIndexUpper = Math.min(initialIndex + maxSize, tokens.size() - 1); 109 | for (int i = finalIndexLower; i < finalIndexUpper; i++) { 110 | String token = tokens.get(i); 111 | context.add(token); 112 | 113 | if (anotherQuotationFound) { 114 | if (token.equals("''")) { 115 | return i; 116 | } 117 | } else { 118 | if (token.equals(".") || token.equals("
") || token.equals("

")) { 119 | return i; 120 | } 121 | } 122 | } 123 | // No stopper found -> revert to short context 124 | while (context.size() > minSize) { 125 | context.remove(context.size() - 1); 126 | finalIndexUpper--; 127 | } 128 | 129 | return finalIndexUpper; 130 | } 131 | 132 | private static int scanBackward(List tokens, List context, int initialIndex, int minSize, int maxSize) { 133 | int finalIndexLower = Math.max(initialIndex - minSize, 0); 134 | 135 | boolean anotherQuotationFound = false; 136 | for (int i = initialIndex; i > finalIndexLower; i--) { 137 | String token = tokens.get(i); 138 | if (token.equals("''")) { 139 | anotherQuotationFound = true; 140 | } else if (token.equals("``")) { 141 | anotherQuotationFound = false; 142 | } 143 | context.add(token); 144 | } 145 | 146 | int finalIndexUpper = Math.max(initialIndex - maxSize, 0); 147 | for (int i = finalIndexLower; i >= finalIndexUpper; i--) { 148 | String token = tokens.get(i); 149 | context.add(token); 150 | 151 | if (anotherQuotationFound) { 152 | if (token.equals("``")) { 153 | return i; 154 | } 155 | } else { 156 | if (token.equals(".") || token.equals("
") || token.startsWith(" revert to short context 162 | while (context.size() > minSize) { 163 | context.remove(context.size() - 1); 164 | finalIndexUpper++; 165 | } 166 | 167 | return finalIndexUpper; 168 | } 169 | 170 | public static Sentence postProcess(Sentence s) { 171 | 172 | List tokens = new ArrayList<>(s.getTokens()); 173 | 174 | // Remove HTML tags 175 | tokens.removeIf(x -> x.getType() != Token.Type.QUOTATION && StaticRules.isHtmlTag(x.toString())); 176 | 177 | // Find other quotations other than the main quotation 178 | int startQuot = 0; 179 | int i = 0; 180 | while (i < tokens.size()) { 181 | String t = tokens.get(i).toString(); 182 | if (tokens.get(i).getType() == Token.Type.QUOTATION) { 183 | startQuot = -1; 184 | } 185 | if (t.equals("''") && startQuot != -1) { 186 | List sub = tokens.subList(startQuot, i + 1); 187 | sub.clear(); 188 | sub.add(new Token("[QUOTE]", Token.Type.GENERIC)); 189 | startQuot = -1; 190 | i = startQuot; 191 | } else if ((t.equals("``"))) { 192 | startQuot = i; 193 | } 194 | i++; 195 | } 196 | if (startQuot != -1) { 197 | List sub = tokens.subList(startQuot, tokens.size()); 198 | sub.clear(); 199 | sub.add(new Token("[QUOTE]", Token.Type.GENERIC)); 200 | } 201 | 202 | return new Sentence(tokens, s.getArticleUid(), s.getIndex(), s.getQuotationOffset(), s.getLeftOffset(), s.getRightOffset()); 203 | } 204 | 205 | public static Sentence canonicalizeQuotation(Sentence s) { 206 | List tokens = new ArrayList<>(s.getTokens()); 207 | 208 | // Canonicalize quotation, and push out final punctuation marks, if present inside quotation 209 | for (int i = 0; i < tokens.size(); i++) { 210 | if (tokens.get(i).getType() == Token.Type.QUOTATION) { 211 | String tokenStr = tokens.get(i).toString().trim(); 212 | final String[] patterns = {",", "."}; 213 | for (String p : patterns) { 214 | if (tokenStr.endsWith(p)) { 215 | if (i == tokens.size() - 1 || !StaticRules.isPunctuation(tokens.get(i + 1).toString())) { 216 | // Push out 217 | tokens.add(i + 1, new Token(p, Token.Type.GENERIC)); 218 | } 219 | } 220 | } 221 | tokens.set(i, new Token(StaticRules.canonicalizeQuotation(tokenStr), Token.Type.QUOTATION)); 222 | } 223 | } 224 | 225 | return new Sentence(tokens, s.getArticleUid(), s.getIndex(), s.getQuotationOffset(), s.getLeftOffset(), s.getRightOffset()); 226 | } 227 | 228 | } 229 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/quootstrap/DatasetLoader.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.quootstrap; 2 | 3 | import java.util.List; 4 | import java.util.Set; 5 | 6 | import org.apache.spark.api.java.JavaRDD; 7 | import org.apache.spark.api.java.JavaSparkContext; 8 | 9 | /** 10 | * This interface is used for loading a news dataset. 11 | * The handling of the format is up to the concrete implementation. 12 | */ 13 | public interface DatasetLoader { 14 | 15 | /** 16 | * Load all articles as a Java RDD. 17 | * @param sc the JavaSparkContext used for loading the RDD 18 | * @param datasetPath the path of the dataset. 19 | * @param languageFilter a set that contains the requested language codes. 20 | * @return a JavaRDD of Articles 21 | */ 22 | JavaRDD
loadArticles(JavaSparkContext sc, String datasetPath, Set languageFilter); 23 | 24 | /** 25 | * A news article, identified by a unique long identifier. 26 | */ 27 | public interface Article { 28 | 29 | /** Get the unique identifier of this article. */ 30 | String getArticleUID(); 31 | 32 | /** Get the content of this article in tokenized format. */ 33 | List getArticleContent(); 34 | 35 | /** Get the domain name in which this article was found. */ 36 | String getWebsite(); 37 | 38 | /** Get the date of this article. */ 39 | String getDate(); 40 | 41 | /** Get the version of the article */ 42 | String getVersion(); 43 | 44 | /** Get the title of this article */ 45 | String getTitle(); 46 | 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/quootstrap/Dawg.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.quootstrap; 2 | 3 | import java.util.ArrayList; 4 | import java.util.Collection; 5 | import java.util.Collections; 6 | import java.util.HashMap; 7 | import java.util.List; 8 | import java.util.Map; 9 | import java.util.Optional; 10 | 11 | import com.google.common.collect.HashMultimap; 12 | import com.google.common.collect.Multimap; 13 | 14 | import scala.Tuple2; 15 | 16 | /** 17 | * This class implements a Directed Acyclic Word Graph (DAWG), 18 | * also known as deterministic acyclic finite state automaton (DAFSA). 19 | */ 20 | public class Dawg { 21 | 22 | final Node root; 23 | final Multimap allNodes; 24 | 25 | public Dawg() { 26 | root = new Node(""); 27 | allNodes = HashMultimap.create(); 28 | } 29 | 30 | public Node getRoot() { 31 | return root; 32 | } 33 | 34 | public static Optional convert(List newPattern) { 35 | // Remove leading/trailing asterisks 36 | int startPos = 0; 37 | while (newPattern.size() > startPos && newPattern.get(startPos).getType() == Token.Type.ANY) { 38 | startPos++; 39 | } 40 | 41 | int endPos = newPattern.size(); 42 | while (endPos > 0 && newPattern.get(endPos - 1).getType() == Token.Type.ANY) { 43 | endPos--; 44 | } 45 | 46 | if (startPos >= endPos) { 47 | return Optional.empty(); 48 | } 49 | 50 | newPattern = newPattern.subList(startPos, endPos); 51 | 52 | newPattern.replaceAll(x -> { 53 | if (x.toString().equals(Pattern.QUOTATION_PLACEHOLDER)) { 54 | return new Token(null, Token.Type.QUOTATION); 55 | } else if (x.toString().equals(Pattern.SPEAKER_PLACEHOLDER)) { 56 | return new Token(null, Token.Type.SPEAKER); 57 | } 58 | return x; 59 | }); 60 | 61 | try { 62 | Pattern p = new Pattern(newPattern); 63 | 64 | // Add a constraint on the number of consecutive $* tokens 65 | int maxCount = 0; 66 | final int threshold = 5; 67 | for (Token t : p.getTokens()) { 68 | if (t.getType() == Token.Type.ANY) { 69 | maxCount++; 70 | } else { 71 | maxCount = 0; 72 | } 73 | if (maxCount > threshold) { 74 | return Optional.empty(); 75 | } 76 | } 77 | 78 | return Optional.of(p); 79 | } catch (Exception e) { 80 | // Discard invalid patterns 81 | return Optional.empty(); 82 | } 83 | } 84 | 85 | public void addAll(List> patterns) { 86 | patterns.sort((x, y) -> { 87 | for (int i = 0; i < Math.min(x.size(), y.size()); i++) { 88 | int comp = x.get(i).compareTo(y.get(i)); 89 | if (comp != 0) { 90 | return comp; 91 | } 92 | } 93 | return Integer.compare(x.size(), y.size()); 94 | }); 95 | patterns.forEach(this::add); 96 | if (root.children.size() > 0) { 97 | addOrReplace(root); 98 | } 99 | } 100 | 101 | private void add(List tokens) { 102 | 
Tuple2> tail = getInsertionPoint(tokens); 103 | if (tail._1 == null) { 104 | return; 105 | } 106 | if (tail._1.children.size() > 0) { 107 | addOrReplace(tail._1); 108 | } 109 | addSuffix(tail._1, tail._2); 110 | } 111 | 112 | private Tuple2> getInsertionPoint(List tokens) { 113 | Node node = root; 114 | for (int i = 0; i < tokens.size(); i++) { 115 | node.count++; 116 | 117 | String token = tokens.get(i); 118 | Node next = node.children.get(token); 119 | if (next == null) { 120 | return new Tuple2<>(node, tokens.subList(i, tokens.size())); 121 | } 122 | node = next; 123 | } 124 | return new Tuple2<>(null, Collections.emptyList()); 125 | } 126 | 127 | /** 128 | * Export this graph into a GraphViz description (for rendering). 129 | * @return a string in GraphViz format 130 | */ 131 | public String toDot() { 132 | StringBuilder builder = new StringBuilder(); 133 | builder.append("digraph g{\n\trankdir=LR\n"); 134 | 135 | allNodes.put("", root); 136 | allNodes.entries().forEach(x -> { 137 | builder.append("\t" + System.identityHashCode(x.getValue()) 138 | + " [label=\"" + x.getKey() + " ["+ x.getValue().count + "]\"]\n"); 139 | }); 140 | allNodes.values().forEach(x -> { 141 | x.children.values().forEach(y -> { 142 | builder.append("\t" + System.identityHashCode(x) + " -> " + System.identityHashCode(y) + "\n"); 143 | }); 144 | }); 145 | allNodes.remove("", root); 146 | builder.append("}\n"); 147 | return builder.toString(); 148 | } 149 | 150 | private void addSuffix(Node node, List tokens) { 151 | for (String token : tokens) { 152 | Node next = new Node(token); 153 | next.count++; 154 | node.children.put(token, next); 155 | node.lastAdded = next; 156 | node = next; 157 | } 158 | } 159 | 160 | private void addOrReplace(Node node) { 161 | Node child = node.lastAdded; 162 | if (child.children.size() > 0) { 163 | addOrReplace(child); 164 | } 165 | 166 | Collection candidateChildren = allNodes.get(child.word); 167 | for (Node existingChild : candidateChildren) { 168 | if (existingChild.equals(child)) { 169 | existingChild.count += child.count; 170 | node.lastAdded = existingChild; 171 | node.children.put(child.word, existingChild); 172 | return; 173 | } 174 | } 175 | 176 | allNodes.put(child.word, child); 177 | } 178 | 179 | public static class Node { 180 | private final String word; 181 | private final Map children; 182 | private int count; 183 | private Node lastAdded; 184 | 185 | public Node(String word) { 186 | this.word = word; 187 | this.children = new HashMap<>(); 188 | this.count = 0; 189 | } 190 | 191 | public String getWord() { 192 | return word; 193 | } 194 | 195 | public int getCount() { 196 | return count; 197 | } 198 | 199 | public Collection getNodes() { 200 | return children.values(); 201 | } 202 | 203 | @Override 204 | public int hashCode() { 205 | int hash = word.hashCode(); 206 | // Commutative hash function 207 | for (Map.Entry entry : children.entrySet()) { 208 | hash ^= entry.getKey().hashCode() + 31 * System.identityHashCode(entry.getValue()); 209 | } 210 | return hash; 211 | } 212 | 213 | @Override 214 | public boolean equals(final Object obj) { 215 | if (this == obj) 216 | return true; 217 | 218 | if (obj instanceof Node) { 219 | final Node node = (Node) obj; 220 | return hashCode() == node.hashCode() && word.equals(node.word) && equalsChildren(node); 221 | } 222 | 223 | return false; 224 | } 225 | 226 | @Override 227 | public String toString() { 228 | return Integer.toHexString(System.identityHashCode(this)) + "(" + word + ", " + count + ")"; 229 | } 230 | 231 | private 
boolean equalsChildren(final Node other) { 232 | if (children.size() != other.children.size()) { 233 | return false; 234 | } 235 | 236 | for (Map.Entry entry : children.entrySet()) { 237 | Node match = other.children.get(entry.getKey()); 238 | // The pointer comparison is intentional! 239 | if (match == null || match != entry.getValue()) { 240 | return false; 241 | } 242 | } 243 | 244 | return true; 245 | } 246 | 247 | /** 248 | * Extract all patterns stored in this graph. 249 | * @param all The output list. 250 | * @param current Temporary storage (must be empty). 251 | */ 252 | public void dump(List> all, List current) { 253 | current.add(this); 254 | if (children.isEmpty()) { 255 | all.add(new ArrayList<>(current)); 256 | } else { 257 | children.values().forEach(x -> x.dump(all, current)); 258 | } 259 | current.remove(current.size() - 1); 260 | } 261 | 262 | } 263 | } 264 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/quootstrap/Exporter.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.quootstrap; 2 | 3 | import java.util.ArrayList; 4 | import java.util.HashMap; 5 | import java.util.HashSet; 6 | import java.util.Iterator; 7 | import java.util.List; 8 | import java.util.Map; 9 | import java.util.Set; 10 | 11 | import org.apache.hadoop.io.compress.GzipCodec; 12 | import org.apache.spark.api.java.JavaPairRDD; 13 | import org.apache.spark.api.java.JavaRDD; 14 | import org.apache.spark.api.java.JavaSparkContext; 15 | 16 | import com.google.gson.GsonBuilder; 17 | import com.google.gson.JsonArray; 18 | import com.google.gson.JsonObject; 19 | import com.google.gson.JsonPrimitive; 20 | 21 | import scala.Tuple2; 22 | import scala.Tuple3; 23 | import scala.Tuple4; 24 | 25 | public class Exporter { 26 | 27 | private final JavaPairRDD quotationMap; 28 | 29 | /** Map article UID to a tuple (full quotation, website, date) */ 30 | private final JavaPairRDD, Tuple3> articles; 31 | 32 | private final int numPartition; 33 | 34 | public Exporter(JavaSparkContext sc, JavaRDD sentences, NameDatabaseWikiData people) { 35 | 36 | final Set langSet = new HashSet<>(ConfigManager.getInstance().getLangFilter()); 37 | this.articles = QuotationExtraction.getConcreteDatasetLoader().loadArticles(sc, 38 | ConfigManager.getInstance().getDatasetPath(), langSet) 39 | .mapToPair(x -> new Tuple2<>(new Tuple2<>(x.getWebsite(), x.getDate()), new Tuple2<>(x.getArticleContent(), x.getArticleUID()))) 40 | .flatMapValues(x -> ContextExtractor.extractQuotations(x._1(), x._2())) 41 | .mapToPair(x -> new Tuple2<>(x._2.getKey(), new Tuple3<>(x._2.getQuotation(), x._1._1, x._1._2))); 42 | 43 | this.quotationMap = computeQuotationMap(sc); 44 | this.numPartition = ConfigManager.getInstance().getNumPartition(); 45 | } 46 | 47 | private JavaPairRDD computeQuotationMap(JavaSparkContext sc) { 48 | Set langSet = new HashSet<>(ConfigManager.getInstance().getLangFilter()); 49 | 50 | // Reconstruct quotations (from the lower-case canonical form to the full form) 51 | return QuotationExtraction.getConcreteDatasetLoader().loadArticles(sc, 52 | ConfigManager.getInstance().getDatasetPath(), langSet) 53 | .flatMap(x -> ContextExtractor.extractQuotations(x.getArticleContent(), x.getArticleUID()).iterator()) 54 | .mapToPair(x -> new Tuple2<>(StaticRules.canonicalizeQuotation(x.getQuotation()), x.getQuotation())) 55 | .reduceByKey((x, y) -> { 56 | // Out of multiple possibilities, get the longest quotation 57 | if 
(x.length() > y.length()) { 58 | return x; 59 | } else if (x.length() < y.length()) { 60 | return y; 61 | } else { 62 | // Lexicographical comparison to ensure determinism 63 | return x.compareTo(y) == -1 ? x : y; 64 | } 65 | }); 66 | } 67 | 68 | public void exportResults(JavaPairRDD>, LineageInfo>> pairs) { 69 | String exportPath = ConfigManager.getInstance().getExportPath(); 70 | 71 | JavaPairRDD, String, String, String>> articleMap = pairs.mapToPair(x -> new Tuple2<>(x._1, x._2._2)) // (canonical quotation, lineage info) 72 | .flatMapValues(x -> { 73 | // (key) 74 | List> values = new ArrayList<>(); 75 | for (int i = 0; i < x.getPatterns().size(); i++) { 76 | values.add(x.getSentences().get(i).getKey()); 77 | } 78 | return values; 79 | }) // (canonical quotation, key) 80 | .mapToPair(Tuple2::swap) // (key, canonical quotation) 81 | .join(this.articles) // (key, (canonical quotation, (website, date))) 82 | .mapToPair(x -> new Tuple2<>(x._2._1, new Tuple4<>(x._1, x._2._2._1(), x._2._2._2(), x._2._2._3()))); // (canonical quotation, (key, full quotation, website, date)) 83 | 84 | pairs // (canonical quotation, (speaker, lineage info)) 85 | .join(quotationMap) 86 | .mapValues(x -> new Tuple3<>(x._1._1, x._1._2, x._2)) // (canonical quotation, (speakers, lineage info, full quotation)) 87 | .cogroup(articleMap) 88 | .map(t -> { 89 | 90 | String canonicalQuotation = t._1; 91 | Map, Tuple3> articles = new HashMap<>(); 92 | t._2._2.forEach(x -> { 93 | articles.put(x._1(), new Tuple3<>(x._2(), x._3(), x._4())); // (key, (full quotation, website, date)) 94 | }); 95 | 96 | Iterator>, LineageInfo, String>> it = t._2._1.iterator(); 97 | if (!it.hasNext()) { 98 | return null; 99 | } 100 | Tuple3>, LineageInfo, String> data = it.next(); 101 | 102 | if (data._2().getPatterns().size() != articles.size()) { 103 | return null; 104 | } 105 | 106 | JsonArray ids = new JsonArray(); 107 | data._1().stream().map(y -> new JsonPrimitive(y._1)).forEach(y -> ids.add(y)); 108 | 109 | JsonArray names = new JsonArray(); 110 | data._1().stream().map(y -> new JsonPrimitive(y._2)).forEach(y -> names.add(y)); 111 | 112 | 113 | JsonObject o = new JsonObject(); 114 | o.addProperty("quotation", data._3()); 115 | o.addProperty("canonicalQuotation", canonicalQuotation); 116 | o.addProperty("numSpeakers", data._1().size()); 117 | o.add("speaker", names); 118 | o.add("speakerID", ids); 119 | o.addProperty("confidence", data._2().getConfidence()); // Tuple confidence 120 | o.addProperty("numOccurrences", data._2().getPatterns().size()); 121 | 122 | JsonArray occurrences = new JsonArray(); 123 | for (int i = 0; i < data._2().getPatterns().size(); i++) { 124 | JsonObject occ = new JsonObject(); 125 | Tuple2 key = data._2().getSentences().get(i).getKey(); 126 | occ.addProperty("articleUID", key._1); 127 | occ.addProperty("articleOffset", data._2().getSentences().get(i).getIndex()); 128 | occ.addProperty("extractedBy", data._2().getPatterns().get(i).toString(false)); 129 | occ.addProperty("patternConfidence", data._2().getPatterns().get(i).getConfidenceMetric()); 130 | occ.addProperty("quotation", articles.get(key)._1()); 131 | 132 | 133 | String matchedTokens = String.join(" ", data._2().getAliases().get(i)); 134 | occ.addProperty("matchedSpeakerTokens", matchedTokens); 135 | occ.addProperty("website", articles.get(key)._2()); 136 | String date = articles.get(key)._3(); 137 | if (!date.isEmpty()) { 138 | occ.addProperty("date", date); 139 | } 140 | occurrences.add(occ); 141 | } 142 | o.add("occurrences", occurrences); 143 | 144 | 
return new GsonBuilder().disableHtmlEscaping().create().toJson(o); 145 | }) 146 | .filter(x -> x != null) 147 | .repartition(numPartition) 148 | .saveAsTextFile(exportPath, GzipCodec.class); 149 | } 150 | 151 | } 152 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/quootstrap/ExporterArticle.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.quootstrap; 2 | 3 | import java.io.IOException; 4 | 5 | import org.apache.hadoop.io.compress.GzipCodec; 6 | import org.apache.spark.api.java.JavaRDD; 7 | import org.apache.spark.api.java.JavaSparkContext; 8 | 9 | import com.google.gson.GsonBuilder; 10 | import com.google.gson.JsonObject; 11 | 12 | import ch.epfl.dlab.quootstrap.DatasetLoader.Article; 13 | 14 | public class ExporterArticle { 15 | 16 | private final JavaRDD
allArticles; 17 | // private final int numPartition; 18 | 19 | public ExporterArticle(JavaRDD
allArticles) { 20 | this.allArticles = allArticles; 21 | // this.numPartition = ConfigManager.getInstance().getNumPartition(); 22 | } 23 | 24 | public void exportResults(String exportPath, JavaSparkContext sc) throws IOException { 25 | allArticles.map(x -> { //repartition(numPartition). 26 | JsonObject o = new JsonObject(); 27 | o.addProperty("articleUID", x.getArticleUID()); 28 | o.addProperty("content", String.join(" ", x.getArticleContent())); 29 | o.addProperty("date", x.getDate()); 30 | o.addProperty("website", x.getWebsite()); 31 | o.addProperty("phase", x.getVersion()); 32 | o.addProperty("title", x.getTitle()); 33 | 34 | return new GsonBuilder().disableHtmlEscaping().create().toJson(o); 35 | }).filter(x -> x != null).saveAsTextFile(exportPath, GzipCodec.class); 36 | } 37 | 38 | } 39 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/quootstrap/ExporterContext.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.quootstrap; 2 | 3 | import java.io.IOException; 4 | import java.util.List; 5 | 6 | import org.apache.hadoop.io.compress.GzipCodec; 7 | import org.apache.spark.api.java.JavaRDD; 8 | import org.apache.spark.api.java.JavaSparkContext; 9 | 10 | import com.google.gson.GsonBuilder; 11 | import com.google.gson.JsonObject; 12 | 13 | import ch.epfl.dlab.spinn3r.Tokenizer; 14 | import ch.epfl.dlab.spinn3r.TokenizerImpl; 15 | 16 | public class ExporterContext { 17 | 18 | private final JavaRDD sentences; 19 | private final int numPartition; 20 | 21 | public ExporterContext(JavaRDD sentences) { 22 | this.sentences = sentences; 23 | this.numPartition = ConfigManager.getInstance().getNumPartition(); 24 | } 25 | 26 | public void exportResults(String exportPath, JavaSparkContext sc) throws IOException { 27 | sentences.repartition(numPartition).map(x -> { 28 | JsonObject o = new JsonObject(); 29 | o.addProperty("articleUID", x.getArticleUid()); 30 | o.addProperty("articleOffset", x.getIndex()); 31 | 32 | String leftContext = ""; 33 | String rightContext = ""; 34 | String quotation = ""; 35 | List tokens = x.getTokens(); 36 | Tokenizer tokenizer = new TokenizerImpl(); 37 | 38 | for (int i = 0; i < x.getTokenCount(); i++) { 39 | Token t = tokens.get(i); 40 | if (t.getType() == Token.Type.QUOTATION) { 41 | leftContext = tokenizer.untokenize(Token.getStrings(tokens.subList(0, i))); 42 | quotation = t.toString(); 43 | rightContext = tokenizer.untokenize(Token.getStrings(tokens.subList(i + 1, x.getTokenCount()))); 44 | break; 45 | } 46 | } 47 | 48 | o.addProperty("leftContext", leftContext); 49 | o.addProperty("rightContext", rightContext); 50 | o.addProperty("quotation", quotation); 51 | o.addProperty("quotationOffset", x.getQuotationOffset()); 52 | o.addProperty("leftOffset", x.getLeftOffset()); 53 | o.addProperty("rightOffset", x.getRightOffset()); 54 | 55 | return new GsonBuilder().disableHtmlEscaping().create().toJson(o); 56 | }).filter(x -> x != null).saveAsTextFile(exportPath, GzipCodec.class); 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/quootstrap/ExporterSpeakers.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.quootstrap; 2 | 3 | import java.io.IOException; 4 | import java.util.stream.Collectors; 5 | 6 | import org.apache.hadoop.io.compress.GzipCodec; 7 | import org.apache.spark.api.java.JavaPairRDD; 8 | 
import org.apache.spark.api.java.JavaSparkContext; 9 | 10 | import com.google.gson.GsonBuilder; 11 | import com.google.gson.JsonArray; 12 | import com.google.gson.JsonObject; 13 | import com.google.gson.JsonPrimitive; 14 | 15 | import scala.Tuple2; 16 | 17 | public class ExporterSpeakers { 18 | 19 | private final JavaPairRDD> allSpeakers; 20 | private final int numPartition; 21 | 22 | public ExporterSpeakers(JavaPairRDD> allSpeakers) { 23 | this.allSpeakers = allSpeakers; 24 | this.numPartition = ConfigManager.getInstance().getNumPartition(); 25 | } 26 | 27 | public void exportResults(String exportPath, JavaSparkContext sc) throws IOException { 28 | allSpeakers.repartition(numPartition).map(x -> { 29 | JsonObject o = new JsonObject(); 30 | o.addProperty("articleUID", x._1); 31 | JsonArray names = new JsonArray(); 32 | for (SpeakerAlias alias: x._2) { 33 | JsonObject current = new JsonObject(); 34 | JsonArray ids = new JsonArray(); 35 | JsonArray offsets = new JsonArray(); 36 | String current_alias = alias.getAlias().stream().collect(Collectors.joining(" ")); 37 | alias.getIds().stream().map(y -> new Tuple2<>(new JsonPrimitive(y._1), new JsonPrimitive(y._2))).forEach(y -> { 38 | JsonArray current_ids = new JsonArray(); 39 | current_ids.add(y._1); 40 | current_ids.add(y._2); 41 | ids.add(current_ids); 42 | }); 43 | alias.getOffsets().stream().map(y -> new Tuple2<>(new JsonPrimitive(y._1), new JsonPrimitive(y._2))).forEach(y -> { 44 | JsonArray current_offset = new JsonArray(); 45 | current_offset.add(y._1); 46 | current_offset.add(y._2); 47 | offsets.add(current_offset); 48 | }); 49 | current.addProperty("name", current_alias); 50 | current.add("ids", ids); 51 | current.add("offsets", offsets); 52 | names.add(current); 53 | } 54 | o.add("names", names); 55 | 56 | return new GsonBuilder().disableHtmlEscaping().create().toJson(o); 57 | }).filter(x -> x != null).saveAsTextFile(exportPath, GzipCodec.class); 58 | } 59 | 60 | } 61 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/quootstrap/HashTrie.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.quootstrap; 2 | 3 | import java.io.Serializable; 4 | import java.util.ArrayList; 5 | import java.util.Collection; 6 | import java.util.Collections; 7 | import java.util.HashMap; 8 | import java.util.Iterator; 9 | import java.util.List; 10 | import java.util.Locale; 11 | import java.util.Map; 12 | import java.util.Map.Entry; 13 | 14 | import scala.Tuple2; 15 | 16 | public class HashTrie implements Serializable { 17 | 18 | private static final long serialVersionUID = -9164247688163451454L; 19 | 20 | private final Node rootNode; 21 | private final boolean caseSensitive; 22 | 23 | public HashTrie(Iterable, List>>> substrings, boolean caseSensitive) { 24 | this.rootNode = new Node(null, caseSensitive, Collections.emptyList()); 25 | this.caseSensitive = caseSensitive; 26 | substrings.forEach(this::insertSubstring); 27 | } 28 | 29 | private void insertSubstring(Tuple2, List>> substring) { 30 | Node current = rootNode; 31 | 32 | for (int i = 0; i < substring._1.size(); i++) { 33 | String token = substring._1.get(i); 34 | String key = caseSensitive ? 
token : token.toLowerCase(Locale.ROOT); 35 | 36 | Node next = current.findChild(key); 37 | if (next == null) { 38 | next = new Node(token, caseSensitive, Collections.emptyList()); 39 | current.children.put(key, next); 40 | } 41 | current = next; 42 | } 43 | 44 | current.terminal = true; 45 | current.addIds(substring._2); 46 | } 47 | 48 | public List> getAllSubstrings() { 49 | List> allSubstrings = new ArrayList<>(); 50 | 51 | List currentSubstring = new ArrayList<>(); 52 | DFS(rootNode, currentSubstring, allSubstrings); 53 | return allSubstrings; 54 | } 55 | 56 | public Node getRootNode() { 57 | return rootNode; 58 | } 59 | 60 | private void DFS(Node current, List currentSubstring, List> allSubstrings) { 61 | if (current.isTerminal()) { 62 | allSubstrings.add(new ArrayList<>(currentSubstring)); 63 | } 64 | 65 | for (Map.Entry next : current) { 66 | currentSubstring.add(next.getKey()); 67 | DFS(next.getValue(), currentSubstring, allSubstrings); 68 | currentSubstring.remove(currentSubstring.size() - 1); 69 | } 70 | } 71 | 72 | public static class Node implements Iterable>, Serializable { 73 | 74 | private static final long serialVersionUID = -4344489198225825075L; 75 | 76 | private final Map children; 77 | private final boolean caseSensitive; 78 | private final String value; 79 | private boolean terminal; 80 | private List> ids; // (id, standard name) 81 | 82 | public Node(String value, boolean caseSensitive, Collection> ids) { 83 | this.children = new HashMap<>(); 84 | this.caseSensitive = caseSensitive; 85 | this.value = value; 86 | this.terminal = false; 87 | this.ids = new ArrayList<>(ids); 88 | } 89 | 90 | public void addIds(Collection> ids) { 91 | this.ids.addAll(ids); 92 | } 93 | 94 | public boolean hasChildren() { 95 | return !children.isEmpty(); 96 | } 97 | 98 | public Node findChild(String token) { 99 | return children.get(caseSensitive ? 
token : token.toLowerCase(Locale.ROOT)); 100 | } 101 | 102 | public boolean isTerminal() { 103 | return terminal; 104 | } 105 | 106 | public String getValue() { 107 | return value; 108 | } 109 | 110 | public List> getIds() { 111 | return ids; 112 | } 113 | 114 | @Override 115 | public Iterator> iterator() { 116 | return Collections.unmodifiableMap(children).entrySet().iterator(); 117 | } 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/quootstrap/HashTriePatternMatcher.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.quootstrap; 2 | 3 | import java.util.ArrayList; 4 | import java.util.Collections; 5 | import java.util.List; 6 | 7 | import scala.Tuple2; 8 | import scala.Tuple3; 9 | 10 | public class HashTriePatternMatcher { 11 | 12 | private final HashTrie trie; 13 | private final List currentMatch; 14 | private List> currentIds; 15 | 16 | public HashTriePatternMatcher(HashTrie trie) { 17 | this.trie = trie; 18 | currentMatch = new ArrayList<>(); 19 | } 20 | 21 | public boolean match(List tokens) { 22 | List, List>>> longestMatches = new ArrayList<>(); 23 | boolean result = false; 24 | for (int i = 0; i < tokens.size(); i++) { 25 | currentMatch.clear(); 26 | if (matchImpl(tokens, trie.getRootNode(), i)) { 27 | result = true; 28 | longestMatches.add(new Tuple3<>(i, new ArrayList<>(currentMatch), new ArrayList<>(currentIds))); 29 | } 30 | } 31 | 32 | if (result) { 33 | // Get longest match out of multiple matches 34 | int maxLen = Collections.max(longestMatches, (x, y) -> Integer.compare(x._2().size(), y._2().size()))._2() 35 | .size(); 36 | longestMatches.removeIf(x -> x._2().size() != maxLen); 37 | 38 | currentMatch.clear(); 39 | currentMatch.addAll(longestMatches.get(0)._2()); 40 | currentIds = longestMatches.get(0)._3(); 41 | } 42 | 43 | return result; 44 | } 45 | 46 | public List multiMatch(List tokens) { 47 | ArrayList matches = new ArrayList<>(); 48 | for (int i = 0; i < tokens.size(); i++) { 49 | currentMatch.clear(); 50 | if (matchImpl(tokens, trie.getRootNode(), i)) { 51 | Tuple2 offset = new Tuple2<>(i, i + currentMatch.size()); 52 | SpeakerAlias speaker = new SpeakerAlias(new ArrayList<>(currentMatch), new ArrayList<>(currentIds), 53 | offset); 54 | int idx = matches.indexOf(speaker); 55 | if (idx >= 0) { 56 | matches.get(idx).addOffset(offset); 57 | } else { 58 | matches.add(speaker); 59 | } 60 | i += currentMatch.size() - 1; // Avoid matching subsequences 61 | } 62 | } 63 | 64 | return new ArrayList<>(matches); 65 | } 66 | 67 | public boolean match(Sentence s) { 68 | List tokens = s.getTokens(); 69 | List, List>>> longestMatches = new ArrayList<>(); 70 | boolean result = false; 71 | for (int i = 0; i < tokens.size(); i++) { 72 | currentMatch.clear(); 73 | if (matchImpl(tokens, trie.getRootNode(), i)) { 74 | result = true; 75 | longestMatches.add(new Tuple3<>(i, new ArrayList<>(currentMatch), new ArrayList<>(currentIds))); 76 | } 77 | } 78 | 79 | if (result) { 80 | int maxLen = Collections.max(longestMatches, (x, y) -> Integer.compare(x._2().size(), y._2().size()))._2() 81 | .size(); 82 | longestMatches.removeIf(x -> x._2().size() != maxLen); 83 | 84 | // If there are multiple speakers with max length, select the one that is 85 | // nearest to the quotation 86 | currentMatch.clear(); 87 | if (longestMatches.size() > 1) { 88 | int quotationIdx = -1; 89 | for (int i = 0; i < tokens.size(); i++) { 90 | if (tokens.get(i).getType() 
== Token.Type.QUOTATION) { 91 | quotationIdx = i; 92 | break; 93 | } 94 | } 95 | final int qi = quotationIdx; 96 | Tuple3, List>> nearest = Collections.min(longestMatches, 97 | (x, y) -> { 98 | int delta1 = Math.abs(x._1() - qi); 99 | int delta2 = Math.abs(y._1() - qi); 100 | return Integer.compare(delta1, delta2); 101 | }); 102 | currentMatch.addAll(nearest._2()); 103 | currentIds = nearest._3(); 104 | } else { 105 | currentMatch.addAll(longestMatches.get(0)._2()); 106 | currentIds = longestMatches.get(0)._3(); 107 | } 108 | } 109 | 110 | return result; 111 | } 112 | 113 | public SpeakerAlias getLongestMatch() { 114 | return new SpeakerAlias(new ArrayList<>(currentMatch), new ArrayList<>(currentIds), new Tuple2<>(0, 0)); 115 | } 116 | 117 | private boolean matchImpl(List tokens, HashTrie.Node current, int i) { 118 | 119 | if (i == tokens.size()) { 120 | return false; 121 | } 122 | 123 | if (tokens.get(i).getType() != Token.Type.GENERIC) { 124 | return false; 125 | } 126 | 127 | String tokenStr = tokens.get(i).toString(); 128 | HashTrie.Node next = current.findChild(tokenStr); 129 | if (next != null) { 130 | currentMatch.add(tokenStr); // next.getValue()); 131 | boolean result = false; 132 | if (next.isTerminal()) { 133 | // Match found 134 | currentIds = next.getIds(); 135 | result = true; 136 | } 137 | 138 | // Even if a match is found, try to match a longer sequence 139 | if (matchImpl(tokens, next, i + 1) || result) { 140 | return true; 141 | } 142 | 143 | currentMatch.remove(currentMatch.size() - 1); 144 | } 145 | 146 | return false; 147 | } 148 | } 149 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/quootstrap/Hashed.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.quootstrap; 2 | 3 | import java.io.Serializable; 4 | 5 | /** 6 | * Implements a lightweight hashed object that contains 7 | * a 128-bit hash of a string (using MurmurHash3 128bit x64), 8 | * as well as its original length prior to hashing it. 
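 * <p>Illustrative usage (editorial sketch, not part of the original source):
 * two {@code Hashed} instances built from the same string compare equal and
 * report the length of the original string rather than that of the hash, e.g.
 * <pre>{@code
 * Hashed a = new Hashed("Barack Obama");
 * Hashed b = new Hashed("Barack Obama");
 * assert a.equals(b);
 * assert a.getLength() == 12; // length of "Barack Obama" before hashing
 * }</pre>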
9 | */ 10 | public final class Hashed implements Serializable, Comparable { 11 | 12 | private static final long serialVersionUID = -5391976195212991788L; 13 | 14 | private final long h1; 15 | private final long h2; 16 | private final int len; 17 | 18 | public Hashed(String s) { 19 | long h1 = 0; 20 | long h2 = 0; 21 | 22 | final byte[] key = s.getBytes(); 23 | final int len = key.length; 24 | final int offset = 0; 25 | 26 | final long c1 = 0x87c37b91114253d5L; 27 | final long c2 = 0x4cf5ad432745937fL; 28 | 29 | int roundedEnd = offset + (len & 0xFFFFFFF0); 30 | for (int i = offset; i < roundedEnd; i += 16) { 31 | long k1 = arrayToLong(key, i); 32 | long k2 = arrayToLong(key, i+8); 33 | k1 *= c1; 34 | k1 = Long.rotateLeft(k1,31); 35 | k1 *= c2; 36 | h1 ^= k1; 37 | h1 = Long.rotateLeft(h1,27); 38 | h1 += h2; 39 | h1 = h1*5 + 0x52dce729; 40 | k2 *= c2; 41 | k2 = Long.rotateLeft(k2,33); 42 | k2 *= c1; 43 | h2 ^= k2; 44 | h2 = Long.rotateLeft(h2,31); 45 | h2 += h1; 46 | h2 = h2*5 + 0x38495ab5; 47 | } 48 | 49 | long k1 = 0; 50 | long k2 = 0; 51 | 52 | switch (len & 15) { 53 | case 15: k2 = (key[roundedEnd+14] & 0xffL) << 48; 54 | case 14: k2 |= (key[roundedEnd+13] & 0xffL) << 40; 55 | case 13: k2 |= (key[roundedEnd+12] & 0xffL) << 32; 56 | case 12: k2 |= (key[roundedEnd+11] & 0xffL) << 24; 57 | case 11: k2 |= (key[roundedEnd+10] & 0xffL) << 16; 58 | case 10: k2 |= (key[roundedEnd+ 9] & 0xffL) << 8; 59 | case 9: k2 |= (key[roundedEnd+ 8] & 0xffL); 60 | k2 *= c2; 61 | k2 = Long.rotateLeft(k2, 33); 62 | k2 *= c1; 63 | h2 ^= k2; 64 | case 8: k1 = ((long)key[roundedEnd+7]) << 56; 65 | case 7: k1 |= (key[roundedEnd+6] & 0xffL) << 48; 66 | case 6: k1 |= (key[roundedEnd+5] & 0xffL) << 40; 67 | case 5: k1 |= (key[roundedEnd+4] & 0xffL) << 32; 68 | case 4: k1 |= (key[roundedEnd+3] & 0xffL) << 24; 69 | case 3: k1 |= (key[roundedEnd+2] & 0xffL) << 16; 70 | case 2: k1 |= (key[roundedEnd+1] & 0xffL) << 8; 71 | case 1: k1 |= (key[roundedEnd ] & 0xffL); 72 | k1 *= c1; k1 = Long.rotateLeft(k1,31); k1 *= c2; h1 ^= k1; 73 | } 74 | 75 | h1 ^= len; h2 ^= len; 76 | 77 | h1 += h2; 78 | h2 += h1; 79 | 80 | h1 = mix(h1); 81 | h2 = mix(h2); 82 | 83 | h1 += h2; 84 | h2 += h1; 85 | 86 | this.h1 = h1; 87 | this.h2 = h2; 88 | this.len = s.length(); 89 | } 90 | 91 | private static long mix(long k) { 92 | k ^= k >>> 33; 93 | k *= 0xff51afd7ed558ccdL; 94 | k ^= k >>> 33; 95 | k *= 0xc4ceb9fe1a85ec53L; 96 | k ^= k >>> 33; 97 | return k; 98 | } 99 | 100 | private static long arrayToLong(byte[] buf, int offset) { 101 | return ((long)buf[offset+7] << 56) 102 | | ((buf[offset+6] & 0xffL) << 48) 103 | | ((buf[offset+5] & 0xffL) << 40) 104 | | ((buf[offset+4] & 0xffL) << 32) 105 | | ((buf[offset+3] & 0xffL) << 24) 106 | | ((buf[offset+2] & 0xffL) << 16) 107 | | ((buf[offset+1] & 0xffL) << 8) 108 | | ((buf[offset ] & 0xffL)); 109 | } 110 | 111 | @Override 112 | public String toString() { 113 | return String.format("%016X", h1) + String.format("%016X", h2) + ":" + len; 114 | } 115 | 116 | @Override 117 | public boolean equals(Object o) { 118 | if (o instanceof Hashed) { 119 | Hashed h = (Hashed) o; 120 | return h1 == h.h1 && h2 == h.h2 && len == h.len; 121 | } 122 | return false; 123 | } 124 | 125 | @Override 126 | public int hashCode() { 127 | // Return the 32 least significant bits of the hash 128 | return (int) (h1 & 0x00000000ffffffff); 129 | } 130 | 131 | public long hashCode64() { 132 | return h1 ^ h2; 133 | } 134 | 135 | public int getLength() { 136 | return len; 137 | } 138 | 139 | @Override 140 | public int compareTo(Hashed o) 
{ 141 | // First, compare by the original string length... 142 | if (len != o.len) { 143 | return Integer.compare(len, o.len); 144 | } 145 | 146 | // In case of a tie, define a lexicographical order based on the hash 147 | if (h1 != o.h1) { 148 | return Long.compare(h1, o.h1); 149 | } 150 | return Long.compare(h2, o.h2); 151 | } 152 | 153 | } 154 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/quootstrap/LineageInfo.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.quootstrap; 2 | 3 | import java.io.Serializable; 4 | import java.util.List; 5 | import java.util.stream.Collectors; 6 | 7 | /** 8 | * Instances of this class contain the lineage information associated with a quotation-speaker pair. 9 | * 10 | */ 11 | public final class LineageInfo implements Serializable { 12 | 13 | private static final long serialVersionUID = -6380483748689471512L; 14 | 15 | private final List patterns; 16 | private final List sentences; 17 | private final List> aliases; 18 | private final double confidence; 19 | 20 | public LineageInfo(List patterns, List sentences, List> aliases, double confidence) { 21 | this.patterns = patterns; 22 | this.sentences = sentences; 23 | this.confidence = confidence; 24 | this.aliases = aliases; 25 | } 26 | 27 | public List getPatterns() { 28 | return patterns; 29 | } 30 | 31 | public List getSentences() { 32 | return sentences; 33 | } 34 | 35 | public List> getAliases() { 36 | return aliases; 37 | } 38 | 39 | public double getConfidence() { 40 | return confidence; 41 | } 42 | 43 | @Override 44 | public String toString() { 45 | return "{Confidence: " + confidence + ", Patterns: " + patterns + ", Sentences: " 46 | + sentences.stream().map(x -> "<" + x.getKey().toString() + "> " + x.toString()).collect(Collectors.toList()) + "}"; 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/quootstrap/MultiCounter.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.quootstrap; 2 | 3 | import java.io.Serializable; 4 | import java.util.HashMap; 5 | import java.util.Map; 6 | 7 | import org.apache.spark.util.LongAccumulator; 8 | import org.apache.spark.api.java.JavaSparkContext; 9 | 10 | public class MultiCounter implements Serializable { 11 | 12 | private static final long serialVersionUID = -5606791394577221239L; 13 | 14 | private final Map accumulators; 15 | 16 | public MultiCounter(JavaSparkContext sc, String... 
counters) { 17 | accumulators = new HashMap<>(); 18 | for (String counter : counters) { 19 | accumulators.put(counter, sc.sc().longAccumulator()); 20 | } 21 | } 22 | 23 | public void increment(String accumulator) { 24 | accumulators.get(accumulator).add(1); 25 | } 26 | 27 | public long getValue(String accumulator) { 28 | return accumulators.get(accumulator).value(); 29 | } 30 | 31 | public void dump() { 32 | accumulators.forEach((k, v) -> System.out.println(k + ": " + v.value())); 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/quootstrap/NameDatabaseWikiData.java: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-dlab/Quotebank/a203acf5a05e34841671578d14d6a0be6a66a3ef/quootstrap/src/main/java/ch/epfl/dlab/quootstrap/NameDatabaseWikiData.java -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/quootstrap/ParquetDatasetLoader.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.quootstrap; 2 | 3 | import java.util.Set; 4 | 5 | import org.apache.spark.api.java.JavaRDD; 6 | import org.apache.spark.api.java.JavaSparkContext; 7 | import org.apache.spark.sql.Dataset; 8 | import org.apache.spark.sql.Encoders; 9 | import org.apache.spark.sql.Row; 10 | import org.apache.spark.sql.SparkSession; 11 | 12 | import scala.Tuple4; 13 | 14 | public class ParquetDatasetLoader implements DatasetLoader { 15 | 16 | private static SparkSession session; 17 | 18 | private static void initialize(JavaSparkContext sc) { 19 | if (session == null) { 20 | session = new SparkSession(sc.sc()); 21 | } 22 | } 23 | 24 | @Override 25 | public JavaRDD
loadArticles(JavaSparkContext sc, String datasetPath, Set languageFilter) { 26 | ParquetDatasetLoader.initialize(sc); 27 | Dataset df = ParquetDatasetLoader.session.read().parquet(datasetPath); 28 | 29 | /* Expected schema: 30 | * articleUID: String 31 | * website: String 32 | * content: String 33 | * date: String 34 | */ 35 | 36 | return df.select(df.col("articleUID"), df.col("website"), df.col("content"), df.col("date")) 37 | .map(x -> { 38 | String uid = x.getString(0); 39 | String url = x.getString(1); 40 | String content = x.getString(2); 41 | String date = x.getString(3); 42 | return new Tuple4<>(uid, content, url, date); 43 | }, Encoders.tuple(Encoders.STRING(), Encoders.STRING(), Encoders.STRING(), Encoders.STRING())) 44 | .javaRDD() 45 | .map(x -> new Spinn3rTextDatasetLoader.Article(x._1(), x._2(), x._3(), x._4())); 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/quootstrap/Pattern.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.quootstrap; 2 | 3 | import java.io.Serializable; 4 | import java.util.ArrayList; 5 | import java.util.Collections; 6 | import java.util.Iterator; 7 | import java.util.List; 8 | import java.util.regex.Matcher; 9 | 10 | /** 11 | * This class defines a pattern used for representing regular expressions. 12 | * In this context, a valid regular expression is defined by the following set of rules: 13 | * - The matching is done on a per-token basis (and not on a per-character basis). 14 | * - A token usually corresponds to a full word or to a punctuation mark. 15 | * 16 | * The pattern should/could contain: 17 | * - Exactly one quotation token $Q 18 | * - Exactly one speaker token $S (which can match a variable number of tokens) 19 | * - Any number of text tokens (i.e. generic tokens), which are matched exactly 20 | * - Any number of "ANY" tokens $* (equivalent to "." 
in Perl), which are matched 1 time 21 | * 22 | * Additionally, a pattern is subject to the following constraints: 23 | * - It must not start or end with an ANY token or a speaker token 24 | */ 25 | public final class Pattern implements Serializable, Iterable, Comparable { 26 | 27 | private static final long serialVersionUID = -3877164912177196115L; 28 | 29 | public static final String QUOTATION_PLACEHOLDER = "$Q"; 30 | public static final String SPEAKER_PLACEHOLDER = "$S"; 31 | public static final String ANY_PLACEHOLDER = "$*"; 32 | 33 | private final List tokens; 34 | private final double confidenceMetric; 35 | 36 | public Pattern(String expression, double confidence) { 37 | tokens = new ArrayList<>(); 38 | String[] strTokens = expression.split(" "); 39 | for (String token : strTokens) { 40 | Token next; 41 | switch (token) { 42 | case QUOTATION_PLACEHOLDER: 43 | next = new Token(null, Token.Type.QUOTATION); 44 | break; 45 | case SPEAKER_PLACEHOLDER: 46 | next = new Token(null, Token.Type.SPEAKER); 47 | break; 48 | case ANY_PLACEHOLDER: 49 | next = new Token(null, Token.Type.ANY); 50 | break; 51 | default: 52 | next = new Token(token, Token.Type.GENERIC); 53 | break; 54 | } 55 | this.tokens.add(next); 56 | } 57 | confidenceMetric = confidence; 58 | sanityCheck(); 59 | } 60 | 61 | public Pattern(List tokens, double confidence) { 62 | this.tokens = new ArrayList<>(tokens); // Defensive copy 63 | this.confidenceMetric = confidence; 64 | sanityCheck(); 65 | } 66 | 67 | public Pattern(List tokens) { 68 | this(tokens, Double.NaN); 69 | } 70 | 71 | public Pattern(String expression) { 72 | this(expression, Double.NaN); 73 | } 74 | 75 | public List getTokens() { 76 | return Collections.unmodifiableList(tokens); 77 | } 78 | 79 | public int getTokenCount() { 80 | return tokens.size(); 81 | } 82 | 83 | /** 84 | * @return the number of text tokens 85 | */ 86 | public int getCardinality() { 87 | int count = 0; 88 | for (Token t : tokens) { 89 | if (t.getType() == Token.Type.GENERIC) { 90 | count++; 91 | } 92 | } 93 | return count; 94 | } 95 | 96 | public double getConfidenceMetric() { 97 | return confidenceMetric; 98 | } 99 | 100 | public boolean isSpeakerSurroundedByAny() { 101 | for (int i = 0; i < tokens.size(); i++) { 102 | if (tokens.get(i).getType() == Token.Type.SPEAKER) { 103 | return tokens.get(i - 1).getType() == Token.Type.ANY 104 | || tokens.get(i + 1).getType() == Token.Type.ANY; 105 | } 106 | } 107 | throw new IllegalStateException("Speaker not found in pattern"); 108 | } 109 | 110 | private void sanityCheck() { 111 | if (tokens.isEmpty()) { 112 | throw new IllegalArgumentException("Invalid pattern: the pattern is empty"); 113 | } 114 | boolean quotationFound = false; 115 | boolean speakerFound = false; 116 | for (Token t : tokens) { 117 | if (t.getType() == Token.Type.QUOTATION) { 118 | if (quotationFound) { 119 | throw new IllegalArgumentException("Invalid pattern: more than one quotation placeholder found"); 120 | } 121 | quotationFound = true; 122 | } else if (t.getType() == Token.Type.SPEAKER) { 123 | if (speakerFound) { 124 | throw new IllegalArgumentException("Invalid pattern: more than one speaker placeholder found"); 125 | } 126 | speakerFound = true; 127 | } 128 | } 129 | if (!quotationFound) { 130 | throw new IllegalArgumentException("Invalid pattern: no quotation placeholder found"); 131 | } 132 | if (!speakerFound) { 133 | throw new IllegalArgumentException("Invalid pattern: no speaker placeholder found"); 134 | } 135 | if (tokens.get(0).getType() == Token.Type.SPEAKER) { 136 
| throw new IllegalArgumentException("Invalid pattern: the pattern must not start with a speaker placeholder"); 137 | } 138 | if (tokens.get(tokens.size() - 1).getType() == Token.Type.SPEAKER) { 139 | throw new IllegalArgumentException("Invalid pattern: the pattern must not end with a speaker placeholder"); 140 | } 141 | if (tokens.get(0).getType() == Token.Type.ANY) { 142 | throw new IllegalArgumentException("Invalid pattern: the pattern must not start with an 'any' placeholder"); 143 | } 144 | if (tokens.get(tokens.size() - 1).getType() == Token.Type.ANY) { 145 | throw new IllegalArgumentException("Invalid pattern: the pattern must not end with an 'any' placeholder"); 146 | } 147 | } 148 | 149 | @Override 150 | public String toString() { 151 | return toString(true); 152 | } 153 | 154 | public String toString(boolean addConfidence) { 155 | StringBuilder str = new StringBuilder(); 156 | Iterator it = tokens.iterator(); 157 | while (it.hasNext()) { 158 | Token t = it.next(); 159 | switch (t.getType()) { 160 | case QUOTATION: 161 | str.append(QUOTATION_PLACEHOLDER); 162 | break; 163 | case SPEAKER: 164 | str.append(SPEAKER_PLACEHOLDER); 165 | break; 166 | case ANY: 167 | str.append(ANY_PLACEHOLDER); 168 | break; 169 | default: 170 | str.append(t.toString()); 171 | break; 172 | } 173 | if (it.hasNext()) { 174 | str.append(' '); 175 | } 176 | } 177 | 178 | String output = str.toString(); 179 | if (addConfidence && confidenceMetric >= 0) { 180 | output = output.replace("\"", "\\\""); // Escape character 181 | output = "[\"" + output + "\": " + confidenceMetric + "]"; 182 | } 183 | return output; 184 | } 185 | 186 | public static Pattern parse(String input) { 187 | java.util.regex.Pattern pattern = java.util.regex.Pattern.compile("^\\[\\\"(.*)\": ([0-9.eE]+)\\]$"); 188 | Matcher m = pattern.matcher(input); 189 | if (m.matches() && m.groupCount() == 2) { 190 | String p = m.group(1).replace("\\\"", "\""); 191 | double confidence = Double.parseDouble(m.group(2)); 192 | return new Pattern(p, confidence); 193 | } 194 | throw new IllegalArgumentException("Invalid pattern format: " + input); 195 | } 196 | 197 | @Override 198 | public int hashCode() { 199 | return tokens.hashCode(); 200 | } 201 | 202 | @Override 203 | public boolean equals(Object obj) { 204 | if (obj instanceof Pattern) { 205 | Pattern p = (Pattern) obj; 206 | // The comparison is done only on the content 207 | return p.tokens.equals(tokens); 208 | } 209 | return false; 210 | } 211 | 212 | @Override 213 | public Iterator iterator() { 214 | return Collections.unmodifiableList(tokens).iterator(); 215 | } 216 | 217 | @Override 218 | public int compareTo(Pattern other) { 219 | // Define a lexicographical order for patterns 220 | int n = Math.min(tokens.size(), other.tokens.size()); 221 | for (int i = 0; i < n; i++) { 222 | int comp = tokens.get(i).compareTo(other.tokens.get(i)); 223 | if (comp != 0) { 224 | return comp; 225 | } 226 | } 227 | return Integer.compare(tokens.size(), other.tokens.size()); 228 | } 229 | } 230 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/quootstrap/PatternExtractor.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.quootstrap; 2 | 3 | import java.util.ArrayList; 4 | import java.util.List; 5 | 6 | public class PatternExtractor { 7 | 8 | public static Pattern extractPattern(Sentence s, String quotation, 9 | List speaker, boolean caseSensitive) { 10 | 11 | int quotationPtr = 0; 12 
| 13 | int indexOfSpeaker = getIndexOfSpeaker(s, speaker, caseSensitive); 14 | if (indexOfSpeaker == -1 || indexOfSpeaker == 0) { 15 | // If match not found, or, if match found at position 0... 16 | return null; 17 | } 18 | 19 | boolean patternStarted = false; 20 | boolean hasQuotation = false; 21 | boolean hasSpeaker = false; 22 | List tokens = s.getTokens(); 23 | List patternTokens = new ArrayList<>(); 24 | for (int i = 0; i < tokens.size(); i++) { 25 | 26 | if (i == indexOfSpeaker) { 27 | if (patternTokens.isEmpty()) { 28 | // The pattern cannot start with the speaker 29 | patternTokens.add(tokens.get(i - 1)); 30 | patternStarted = true; 31 | } 32 | patternTokens.add(new Token(null, Token.Type.SPEAKER)); 33 | i += speaker.size() - 1; 34 | hasSpeaker = true; 35 | } else { 36 | switch (tokens.get(i).getType()) { 37 | case QUOTATION: 38 | if (quotationPtr < 1 && tokens.get(i).toString().equals(quotation)) { 39 | patternStarted = true; 40 | quotationPtr++; 41 | patternTokens.add(new Token(null, Token.Type.QUOTATION)); 42 | } else { 43 | return null; 44 | } 45 | if (quotationPtr == 1) { 46 | hasQuotation = true; 47 | } 48 | break; 49 | default: 50 | if (hasQuotation && hasSpeaker && patternTokens.get(patternTokens.size() - 1).getType() != Token.Type.SPEAKER) { 51 | patternStarted = false; 52 | } 53 | 54 | if (patternStarted) { 55 | patternTokens.add(tokens.get(i)); 56 | } 57 | if (hasQuotation && hasSpeaker) { 58 | patternStarted = false; 59 | } 60 | break; 61 | } 62 | } 63 | } 64 | 65 | if (quotationPtr < 1) { 66 | // The quotation has not been matched entirely 67 | return null; 68 | } 69 | 70 | // Ensure that the speaker token is not last token in the pattern 71 | if (patternTokens.get(patternTokens.size() - 1).getType() == Token.Type.SPEAKER) { 72 | return null; 73 | } 74 | 75 | return new Pattern(patternTokens); 76 | } 77 | 78 | private static int getIndexOfSpeaker(Sentence s, List speaker, boolean caseSensitive) { 79 | List tokens = s.getTokens(); 80 | 81 | if (!caseSensitive) { 82 | tokens = Token.caseFold(tokens); 83 | speaker = Token.caseFold(speaker); 84 | } 85 | 86 | for (Token t : speaker) { 87 | // Each token must be perfectly matched without duplicates (even if they are partial) 88 | int firstIndex = tokens.indexOf(t); 89 | int lastIndex = tokens.lastIndexOf(t); 90 | if (firstIndex == -1 || lastIndex != firstIndex) { 91 | return -1; 92 | } 93 | } 94 | 95 | for (int i = 0; i < tokens.size() - speaker.size() + 1; i++) { 96 | for (int j = 0; j < speaker.size(); j++) { 97 | if (!tokens.get(i + j).equals(speaker.get(j))) { 98 | break; 99 | } 100 | if (j == speaker.size() - 1) { 101 | // Match found! 
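 // (Editorial note, not in the original source) i is the offset of the first
 // speaker token inside the sentence. extractPattern() above replaces the
 // speaker span found at this offset with a single $S placeholder, so a context
 // such as "$Q , said John Smith ." (with speaker "John Smith") produces the
 // pattern "$Q , said $S ." once the quotation itself is abstracted into $Q.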
102 | return i; 103 | } 104 | } 105 | } 106 | 107 | return -1; 108 | } 109 | 110 | } 111 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/quootstrap/PatternMatcher.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.quootstrap; 2 | 3 | import java.util.ArrayList; 4 | import java.util.Collections; 5 | import java.util.List; 6 | import java.util.stream.Collectors; 7 | 8 | public abstract class PatternMatcher { 9 | 10 | protected List sentenceTokens; 11 | protected boolean speakerTokenFoundFlag; 12 | protected boolean quotationTokenFoundFlag; 13 | 14 | protected final List matchedQuotation; 15 | protected final List matchedSpeaker; 16 | protected final List matches; 17 | 18 | protected final int minSpeakerLength; 19 | protected final int maxSpeakerLength; 20 | protected final boolean caseSensitive; 21 | 22 | protected PatternMatcher(int speakerLengthMin, int speakerLengthMax, boolean caseSensitive) { 23 | matchedQuotation = new ArrayList<>(); 24 | matchedSpeaker = new ArrayList<>(); 25 | minSpeakerLength = speakerLengthMin; 26 | maxSpeakerLength = speakerLengthMax; 27 | this.caseSensitive = caseSensitive; 28 | matches = new ArrayList<>(); 29 | } 30 | 31 | public abstract boolean match(Sentence s); 32 | 33 | public List getMatches(boolean longest) { 34 | if (longest && !matches.isEmpty()) { 35 | // We want only the pattern with the highest cardinality (i.e. number of text tokens) 36 | final int longestLength = Collections.max(matches, 37 | (x, y) -> Integer.compare(x.getPattern().getCardinality(), y.getPattern().getCardinality())) 38 | .getPattern().getCardinality(); 39 | 40 | return matches.stream() 41 | .filter(x -> x.getPattern().getCardinality() == longestLength) 42 | .collect(Collectors.toList()); 43 | } else { 44 | return Collections.unmodifiableList(matches); 45 | } 46 | } 47 | 48 | protected final boolean matchTokens(Token patternToken, Token token, int speakerTokensLeft) { 49 | switch (patternToken.getType()) { 50 | case GENERIC: 51 | if (token.getType() == Token.Type.GENERIC) { 52 | String tokenStr = token.toString(); 53 | String patternTokenStr = patternToken.toString(); 54 | if (caseSensitive) { 55 | return tokenStr.equals(patternTokenStr); 56 | } else { 57 | return tokenStr.equalsIgnoreCase(patternTokenStr); 58 | } 59 | } 60 | return false; 61 | case SPEAKER: 62 | if (speakerTokensLeft > 0 && token.getType() == Token.Type.GENERIC) { 63 | speakerTokenFoundFlag = true; 64 | matchedSpeaker.add(token); 65 | return true; 66 | } 67 | return false; 68 | case QUOTATION: 69 | if (token.getType() == Token.Type.QUOTATION) { 70 | matchedQuotation.add(token); 71 | quotationTokenFoundFlag = true; 72 | return true; 73 | } 74 | return false; 75 | case ANY: 76 | return true; 77 | default: 78 | throw new IllegalStateException(); 79 | } 80 | } 81 | 82 | public final static class Match { 83 | private final String matchedQuotation; 84 | private final List matchedSpeaker; 85 | private final Pattern matchedPattern; 86 | 87 | public Match(List quotation, List speaker, Pattern pattern) { 88 | if (quotation.size() != 1) { 89 | throw new IllegalArgumentException("Invalid quotation token"); 90 | } 91 | matchedQuotation = quotation.get(0).toString(); 92 | matchedSpeaker = new ArrayList<>(speaker); 93 | matchedPattern = pattern; 94 | } 95 | 96 | public final String getQuotation() { 97 | return matchedQuotation; 98 | } 99 | 100 | public final List getSpeaker() { 101 | return 
matchedSpeaker; 102 | } 103 | 104 | public final Pattern getPattern() { 105 | return matchedPattern; 106 | } 107 | } 108 | 109 | } 110 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/quootstrap/Sentence.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.quootstrap; 2 | 3 | import java.io.Serializable; 4 | import java.util.ArrayList; 5 | import java.util.Collections; 6 | import java.util.Iterator; 7 | import java.util.List; 8 | import java.util.function.Predicate; 9 | import java.util.stream.Collectors; 10 | 11 | import ch.epfl.dlab.spinn3r.Tokenizer; 12 | import ch.epfl.dlab.spinn3r.TokenizerImpl; 13 | import scala.Tuple2; 14 | import scala.Tuple5; 15 | 16 | /** 17 | * A "Sentence" is an extracted context from an article. 18 | * It consists of: 19 | * - A *single* quotation between quotation marks (or equivalent HTML tags) 20 | * - The context surrounding the quotation (before and after) 21 | */ 22 | public final class Sentence implements Serializable, Iterable, Comparable { 23 | 24 | private static final long serialVersionUID = -4971549043386733522L; 25 | 26 | private final List tokens; 27 | private final String articleUid; 28 | private final int index; 29 | private final int quotationOffset; 30 | private final int leftOffset; 31 | private final int rightOffset; 32 | private transient int cachedHashCode; 33 | 34 | /** 35 | * Constructs a new Sentence from the given list of tokens. 36 | * @param tokens The list of tokens (must contain one quotation token) 37 | * @param articleUid The UID of the article from which this sentence has been extracted 38 | * @param index The index of this sentence within the article (starting from 0) 39 | * @param skipChecks Skip safety checks (for debug or test purposes) 40 | */ 41 | public Sentence(List tokens, String articleUid, int index, int quotationOffset, int leftOffset, int rightOffset, boolean skipChecks) { 42 | this.tokens = tokens; 43 | this.articleUid = articleUid; 44 | this.index = index; 45 | this.quotationOffset = quotationOffset; 46 | this.leftOffset = leftOffset; 47 | this.rightOffset = rightOffset; 48 | if (!skipChecks) { 49 | sanityCheck(); 50 | } 51 | } 52 | 53 | /** 54 | * Constructs a new Sentence from the given list of tokens. 55 | * @param tokens The list of tokens (must contain one quotation token) 56 | * @param articleUid The UID of the article from which this sentence has been extracted 57 | * @param index The index of this sentence within the article (starting from 0) 58 | */ 59 | public Sentence(List tokens, String articleUid, int index, int quotationOffset, int leftOffset, int rightOffset) { 60 | this(tokens, articleUid, index, quotationOffset, leftOffset, rightOffset, false); 61 | } 62 | 63 | /** 64 | * Constructs a new Sentence from the given list of tokens. 65 | * @param tokens The list of tokens (must contain one quotation token) 66 | */ 67 | public Sentence(List tokens) { 68 | this(tokens, "", -1, 0, 0, 0, false); 69 | } 70 | 71 | /** 72 | * Constructs a new Sentence from the given list of tokens. 
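 * <p>Minimal construction sketch (editorial addition, not in the original
 * source; the token strings are invented for illustration):
 * <pre>{@code
 * List<Token> tokens = Arrays.asList(
 *     new Token("she", Token.Type.GENERIC),
 *     new Token("said", Token.Type.GENERIC),
 *     new Token("we will appeal the decision", Token.Type.QUOTATION));
 * Sentence s = new Sentence(tokens);
 * s.getQuotation(); // returns "we will appeal the decision"
 * s.toString();     // "she said [we will appeal the decision]"
 * }</pre>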
73 | * @param tokens The list of tokens (must contain one quotation token) 74 | * @param skipChecks Skip safety checks (for debug or test purposes) 75 | */ 76 | public Sentence(List tokens, boolean skipChecks) { 77 | this(tokens, "", -1, 0, 0, 0, skipChecks); 78 | } 79 | 80 | public boolean matches(Predicate> predicate) { 81 | return predicate.test(tokens.iterator()); 82 | } 83 | 84 | public List getTokens() { 85 | return tokens; 86 | } 87 | 88 | public List getTokensByType(Token.Type type) { 89 | return tokens.stream() 90 | .filter(x -> x.getType() == type) 91 | .collect(Collectors.toCollection(ArrayList::new)); 92 | } 93 | 94 | public String getQuotation() { 95 | for (Token t : tokens) { 96 | if (t.getType() == Token.Type.QUOTATION) { 97 | return t.toString(); 98 | } 99 | } 100 | throw new IllegalStateException("No quotation found in this sentence"); 101 | } 102 | 103 | public String getArticleUid() { 104 | return articleUid; 105 | } 106 | 107 | public int getIndex() { 108 | return index; 109 | } 110 | 111 | public int getQuotationOffset() { 112 | return quotationOffset; 113 | } 114 | 115 | /** 116 | * @return the leftOffset 117 | */ 118 | public int getLeftOffset() { 119 | return leftOffset; 120 | } 121 | 122 | /** 123 | * @return the rightOffset 124 | */ 125 | public int getRightOffset() { 126 | return rightOffset; 127 | } 128 | 129 | /** 130 | * Gets the key that uniquely identifies this Sentence. 131 | * @return An (Article UID, index within article) pair 132 | */ 133 | public Tuple2 getKey() { 134 | return new Tuple2<>(articleUid, index); 135 | } 136 | 137 | /** 138 | * Gets all the attributes of this Sentence. 139 | * @return An (Article UID, index within article, left Offset, right Offset) pair 140 | */ 141 | public Tuple5 getInfo() { 142 | return new Tuple5<>(articleUid, index, quotationOffset, leftOffset, rightOffset); 143 | } 144 | 145 | public int getTokenCount() { 146 | return tokens.size(); 147 | } 148 | 149 | @Override 150 | public String toString() { 151 | Iterator it = tokens.iterator(); 152 | StringBuilder str = new StringBuilder(); 153 | while (it.hasNext()) { 154 | Token next = it.next(); 155 | if (next.getType() == Token.Type.QUOTATION) { 156 | str.append('['); 157 | } 158 | str.append(next.toString()); 159 | if (next.getType() == Token.Type.QUOTATION) { 160 | str.append(']'); 161 | } 162 | if (it.hasNext()) { 163 | str.append(' '); 164 | } 165 | 166 | } 167 | return str.toString(); 168 | } 169 | 170 | public String toHumanReadableString(boolean htmlQuotation) { 171 | if (tokens.isEmpty()) { 172 | return ""; 173 | } 174 | 175 | Tokenizer tokenizer = new TokenizerImpl(); 176 | List buffer = new ArrayList<>(); 177 | StringBuilder result = new StringBuilder(); 178 | for (int i = 0; i < tokens.size(); i++) { 179 | Token t = tokens.get(i); 180 | if (t.getType() == Token.Type.GENERIC && StaticRules.isHtmlTag(t.toString())) { 181 | continue; // Discard HTML tags 182 | } 183 | 184 | if (i == 0 && t.getType() == Token.Type.GENERIC && (t.toString().equals(".") || t.toString().equals(","))) { 185 | continue; // Discard leading punctuation 186 | } 187 | 188 | if (t.getType() != Token.Type.QUOTATION) { 189 | buffer.add(t.toString()); 190 | } else { 191 | if (!buffer.isEmpty()) { 192 | result.append(tokenizer.untokenize(buffer)); 193 | result.append(" "); 194 | } 195 | 196 | if (htmlQuotation) { 197 | result.append("" + t.toString() + ""); 198 | } else { 199 | result.append("\"" + t.toString() + "\""); 200 | } 201 | buffer.clear(); 202 | } 203 | } 204 | if (!buffer.isEmpty()) { 205 
| result.append(" "); 206 | result.append(tokenizer.untokenize(buffer)); 207 | } 208 | 209 | 210 | return result.toString(); 211 | } 212 | 213 | @Override 214 | public int hashCode() { 215 | // Racy single-check idiom 216 | int h = cachedHashCode; 217 | if (h == 0) { 218 | h = tokens.hashCode(); 219 | cachedHashCode = h; 220 | } 221 | return h; 222 | } 223 | 224 | @Override 225 | public boolean equals(Object obj) { 226 | if (obj instanceof Sentence) { 227 | Sentence p = (Sentence) obj; 228 | return p.tokens.equals(tokens); 229 | } 230 | return false; 231 | } 232 | 233 | @Override 234 | public Iterator iterator() { 235 | return Collections.unmodifiableList(tokens).iterator(); 236 | } 237 | 238 | @Override 239 | public int compareTo(Sentence other) { 240 | // Define a lexicographical order for sentences 241 | int n = Math.min(tokens.size(), other.tokens.size()); 242 | for (int i = 0; i < n; i++) { 243 | int comp = tokens.get(i).compareTo(other.tokens.get(i)); 244 | if (comp != 0) { 245 | return comp; 246 | } 247 | } 248 | return Integer.compare(tokens.size(), other.tokens.size()); 249 | } 250 | 251 | private void sanityCheck() { 252 | if (tokens.isEmpty()) { 253 | throw new IllegalArgumentException("Invalid sentence: the sentence is empty"); 254 | } 255 | 256 | boolean quotationFound = false; 257 | for (Token t : tokens) { 258 | if (t.getType() == Token.Type.QUOTATION) { 259 | if (quotationFound) { 260 | throw new IllegalArgumentException("Invalid sentence: more than one quotation placeholder found"); 261 | } 262 | quotationFound = true; 263 | } 264 | } 265 | if (!quotationFound) { 266 | throw new IllegalArgumentException("Invalid sentence: no quotation placeholder found"); 267 | } 268 | } 269 | } 270 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/quootstrap/SimplePatternMatcher.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.quootstrap; 2 | 3 | import java.util.ArrayList; 4 | import java.util.List; 5 | 6 | public class SimplePatternMatcher extends PatternMatcher { 7 | 8 | private final List patterns; 9 | private List patternTokens; 10 | private Pattern currentPattern; 11 | 12 | public SimplePatternMatcher(Pattern pattern, 13 | int speakerLengthMin, int speakerLengthMax, boolean caseSensitive) { 14 | super(speakerLengthMin, speakerLengthMax, caseSensitive); 15 | this.patterns = new ArrayList<>(); 16 | patterns.add(pattern); 17 | } 18 | 19 | public SimplePatternMatcher(List patterns, 20 | int speakerLengthMin, int speakerLengthMax, boolean caseSensitive) { 21 | super(speakerLengthMin, speakerLengthMax, caseSensitive); 22 | this.patterns = new ArrayList<>(patterns); 23 | } 24 | 25 | @Override 26 | public boolean match(Sentence s) { 27 | sentenceTokens = s.getTokens(); 28 | 29 | matches.clear(); 30 | boolean result = false; 31 | for (int i = 0; i < sentenceTokens.size(); i++) { 32 | for (Pattern currentPattern : patterns) { 33 | this.currentPattern = currentPattern; 34 | patternTokens = currentPattern.getTokens(); 35 | 36 | if (matchImpl(0, i, maxSpeakerLength)) { 37 | result = true; 38 | } 39 | 40 | assert matchedQuotation.isEmpty(); 41 | assert matchedSpeaker.isEmpty(); 42 | assert !speakerTokenFoundFlag; 43 | assert !quotationTokenFoundFlag; 44 | } 45 | } 46 | 47 | return result; 48 | } 49 | 50 | private boolean matchImpl(int i, int j, int speakerTokensLeft) { 51 | if (i == patternTokens.size()) { 52 | // End of pattern reached 53 | matches.add(new 
Match(matchedQuotation, matchedSpeaker, currentPattern)); 54 | return true; 55 | } 56 | 57 | if (j == sentenceTokens.size()) { 58 | // End of sentence reached: no match possible 59 | return false; 60 | } 61 | 62 | if (matchTokens(patternTokens.get(i), sentenceTokens.get(j), speakerTokensLeft)) { 63 | if (speakerTokenFoundFlag) { 64 | speakerTokenFoundFlag = false; 65 | boolean m1 = false; 66 | if (matchedSpeaker.size() >= minSpeakerLength) { 67 | m1 = matchImpl(i + 1, j + 1, 0); 68 | } 69 | boolean m2 = matchImpl(i, j + 1, speakerTokensLeft - 1); 70 | matchedSpeaker.remove(matchedSpeaker.size() - 1); 71 | return m1 || m2; 72 | 73 | } else if (quotationTokenFoundFlag) { 74 | quotationTokenFoundFlag = false; 75 | boolean m = matchImpl(i + 1, j + 1, speakerTokensLeft); 76 | matchedQuotation.remove(matchedQuotation.size() - 1); 77 | return m; 78 | 79 | } else { 80 | return matchImpl(i + 1, j + 1, speakerTokensLeft); 81 | } 82 | } 83 | 84 | return false; 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/quootstrap/SpeakerAlias.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.quootstrap; 2 | 3 | import java.io.Serializable; 4 | import java.util.ArrayList; 5 | import java.util.List; 6 | 7 | import scala.Tuple2; 8 | 9 | public class SpeakerAlias implements Serializable { 10 | 11 | private static final long serialVersionUID = -6037497222642880562L; 12 | 13 | private final List aliasTokens; 14 | private final List> ids; 15 | private final List> offsets; 16 | 17 | public SpeakerAlias(List aliasTokens, List> ids, Tuple2 offset) { 18 | this.aliasTokens = aliasTokens; 19 | this.ids = ids; 20 | this.offsets = new ArrayList<>(); 21 | this.offsets.add(offset); 22 | } 23 | 24 | public List getAlias() { 25 | return aliasTokens; 26 | } 27 | 28 | public List> getIds() { 29 | return ids; 30 | } 31 | 32 | public List> getOffsets() { 33 | return offsets; 34 | } 35 | 36 | public void addOffset(Tuple2 offset) { 37 | offsets.add(offset); 38 | } 39 | 40 | @Override 41 | public int hashCode() { 42 | return aliasTokens.hashCode() ^ ids.hashCode(); 43 | } 44 | 45 | @Override 46 | public boolean equals(Object o) { 47 | if (o instanceof SpeakerAlias) { 48 | SpeakerAlias oa = (SpeakerAlias) o; 49 | return oa.ids.equals(ids); //oa.aliasTokens.equals(aliasTokens) && 50 | } 51 | return false; 52 | } 53 | 54 | } 55 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/quootstrap/Spinn3rDatasetLoader.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.quootstrap; 2 | 3 | import java.io.Serializable; 4 | import java.net.MalformedURLException; 5 | import java.net.URL; 6 | import java.util.Arrays; 7 | import java.util.Collections; 8 | import java.util.List; 9 | import java.util.Locale; 10 | import java.util.Set; 11 | 12 | import org.apache.spark.api.java.JavaRDD; 13 | import org.apache.spark.api.java.JavaSparkContext; 14 | 15 | import com.google.gson.Gson; 16 | 17 | import ch.epfl.dlab.spinn3r.EntryWrapper; 18 | 19 | public class Spinn3rDatasetLoader implements DatasetLoader, Serializable { 20 | 21 | private static final long serialVersionUID = 7689367526740335445L; 22 | 23 | private final boolean caseFold; 24 | 25 | public Spinn3rDatasetLoader() { 26 | caseFold = ConfigManager.getInstance().isCaseFoldingEnabled(); 27 | } 28 | 29 | @Override 30 | public 
JavaRDD loadArticles(JavaSparkContext sc, String datasetPath, Set languageFilter) { 31 | return sc.textFile(datasetPath) 32 | .map(x -> new Gson().fromJson(x, EntryWrapper.class).getPermalinkEntry()) 33 | //.filter(x -> languageFilter.contains(x.getLanguage())) 34 | .map(x -> { 35 | String domain; 36 | try { 37 | final URL url = new URL(x.getUrl()); 38 | domain = url.getHost(); 39 | } catch (MalformedURLException e) { 40 | domain = ""; 41 | } 42 | 43 | String time; 44 | if (!x.getLastPublished().isEmpty() && !x.getDateFound().isEmpty()) { 45 | time = Collections.min(Arrays.asList(x.getLastPublished(), x.getDateFound())); 46 | } else if (!x.getLastPublished().isEmpty()) { 47 | time = x.getLastPublished(); 48 | } else { 49 | time = x.getDateFound(); 50 | } 51 | 52 | List tokenizedContent = x.getContent(); 53 | if (caseFold) { 54 | tokenizedContent.replaceAll(token -> token.toLowerCase(Locale.ROOT)); 55 | } 56 | 57 | return new Article(Long.parseLong(x.getIdentifier()), tokenizedContent, domain, time); 58 | }); 59 | } 60 | 61 | public static class Article implements DatasetLoader.Article { 62 | 63 | private final long articleUID; 64 | private final List articleContent; 65 | private final String website; 66 | private final String date; 67 | 68 | public Article(long articleUID, List articleContent, String website, String date) { 69 | this.articleUID = articleUID; 70 | this.articleContent = articleContent; 71 | this.website = website; 72 | this.date = date; 73 | } 74 | 75 | @Override 76 | public String getArticleUID() { 77 | return Long.toString(articleUID); 78 | } 79 | 80 | @Override 81 | public List getArticleContent() { 82 | return articleContent; 83 | } 84 | 85 | @Override 86 | public String getWebsite() { 87 | return website; 88 | } 89 | 90 | @Override 91 | public String getDate() { 92 | return date; 93 | } 94 | 95 | @Override 96 | public String getVersion() { 97 | return null; 98 | } 99 | 100 | @Override 101 | public String getTitle() { 102 | return null; 103 | } 104 | 105 | 106 | 107 | } 108 | 109 | } 110 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/quootstrap/Spinn3rTextDatasetLoader.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.quootstrap; 2 | 3 | import java.io.Serializable; 4 | import java.text.SimpleDateFormat; 5 | import java.util.ArrayList; 6 | import java.util.Arrays; 7 | import java.util.HashMap; 8 | import java.util.HashSet; 9 | import java.util.List; 10 | import java.util.Locale; 11 | import java.util.Set; 12 | 13 | import org.apache.hadoop.conf.Configuration; 14 | import org.apache.hadoop.io.LongWritable; 15 | import org.apache.hadoop.io.Text; 16 | import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; 17 | import org.apache.spark.SparkConf; 18 | import org.apache.spark.api.java.JavaRDD; 19 | import org.apache.spark.api.java.JavaSparkContext; 20 | 21 | 22 | public class Spinn3rTextDatasetLoader implements DatasetLoader, Serializable { 23 | 24 | private static final long serialVersionUID = 1432174520856283513L; 25 | 26 | private final boolean caseFold; 27 | private transient Configuration config = null; 28 | 29 | public Spinn3rTextDatasetLoader() { 30 | caseFold = ConfigManager.getInstance().isCaseFoldingEnabled(); 31 | } 32 | 33 | @Override 34 | public JavaRDD loadArticles(JavaSparkContext sc, String datasetPath, Set languageFilter) { 35 | if (config == null) { 36 | // Copy configuration 37 | config = new 
Configuration(sc.hadoopConfiguration()); 38 | config.set("textinputformat.record.delimiter", "\n\n"); 39 | } 40 | 41 | JavaRDD records = sc.newAPIHadoopFile(datasetPath, TextInputFormat.class, LongWritable.class, Text.class, config) 42 | .map(x -> new Spinn3rDocument(x._2.toString())) 43 | .filter(x -> !x.isGarbled 44 | && x.content != null && !x.content.isEmpty() 45 | && x.date != null 46 | && x.urlString != null 47 | //&& languageFilter.contains(x.getProbableLanguage()) 48 | ) 49 | .map(x -> { 50 | final String id = x.docId; 51 | final String domain = x.urlString; 52 | final String time = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(x.date); 53 | final String version = x.version.name(); 54 | final String title = x.title; 55 | 56 | String content = x.content; 57 | if (caseFold) { 58 | content = content.toLowerCase(Locale.ROOT); 59 | } 60 | 61 | return new Article(id, content, domain, time, version, title); 62 | }); 63 | //.mapToPair(x -> new Tuple2<>(x.getArticleUID(), x)) 64 | //.reduceByKey((x, y) -> x.getArticleUID() < y.getArticleUID() ? x : y) // Deterministic distinct 65 | //.map(x -> x._2); 66 | 67 | /*System.out.println(sc.newAPIHadoopFile(datasetPath, TextInputFormat.class, LongWritable.class, Text.class, config) 68 | .map(x -> new Spinn3rDocument(x._2.toString())) 69 | .take(2)); 70 | 71 | System.out.println(records.take(1)); 72 | System.exit(0);*/ 73 | 74 | //String suffix = ConfigManager.getInstance().getLangSuffix(); 75 | //records = Utils.loadCache(records, "documents-" + suffix); 76 | return records; 77 | } 78 | 79 | public static void main(String[] args) { 80 | final SparkConf conf = new SparkConf() 81 | .setAppName("QuotationExtraction") 82 | .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") 83 | .registerKryoClasses(new Class[] { ArrayList.class, Token.class, Token.Type.class, Sentence.class, Pattern.class, 84 | Trie.class, Trie.Node.class, String[].class, Object[].class, HashMap.class, Hashed.class }); 85 | 86 | if (ConfigManager.getInstance().isLocalModeEnabled()) { 87 | conf.setMaster("local[*]") 88 | .set("spark.executor.memory", "8g") 89 | .set("spark.driver.memory", "8g"); 90 | } 91 | 92 | try (JavaSparkContext sc = new JavaSparkContext(conf)) { 93 | Spinn3rTextDatasetLoader loader = new Spinn3rTextDatasetLoader(); 94 | 95 | Set langFilter = new HashSet<>(ConfigManager.getInstance().getLangFilter()); 96 | loader.loadArticles(sc, "C:\\Users\\Dario\\Documents\\EPFL\\part-r-00000.snappy", langFilter); 97 | } 98 | 99 | 100 | } 101 | 102 | public static class Article implements DatasetLoader.Article, Serializable { 103 | 104 | private static final long serialVersionUID = -5411421564171041258L; 105 | 106 | private final String articleUID; 107 | private final String articleContent; 108 | private final String website; 109 | private final String date; 110 | private final String version; 111 | private final String title; 112 | 113 | public Article(String articleUID, String articleContent, String website, String date, String version, String title) { 114 | this.articleUID = articleUID; 115 | this.articleContent = articleContent; 116 | this.website = website; 117 | this.date = date; 118 | this.version = version; 119 | this.title = title; 120 | } 121 | 122 | public Article(String articleUID, String articleContent, String website, String date) { 123 | this(articleUID, articleContent, website, date, "", ""); 124 | } 125 | 126 | @Override 127 | public String getArticleUID() { 128 | return articleUID; 129 | } 130 | 131 | @Override 132 | public List 
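 // (Editorial note, not part of the original source) After this second pass the
 // string is guaranteed to be ASCII-only, with every run of non-ASCII characters
 // or question marks collapsed to a single '?'. Together with the token
 // post-processing below, e.g. normalize("Café") yields "caf ?".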
getArticleContent() { 133 | // Construct on the fly 134 | return new ArrayList<>(Arrays.asList(articleContent.split(" "))); 135 | } 136 | 137 | @Override 138 | public String getWebsite() { 139 | return website; 140 | } 141 | 142 | @Override 143 | public String getDate() { 144 | return date; 145 | } 146 | 147 | @Override 148 | public String getVersion() { 149 | return version; 150 | } 151 | 152 | @Override 153 | public String getTitle() { 154 | return title; 155 | } 156 | 157 | @Override 158 | public String toString() { 159 | return articleUID + ": " + articleContent; 160 | } 161 | 162 | } 163 | 164 | } 165 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/quootstrap/StaticRules.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.quootstrap; 2 | 3 | import java.util.ArrayList; 4 | import java.util.Collections; 5 | import java.util.List; 6 | import java.util.Optional; 7 | 8 | import org.apache.commons.lang.StringEscapeUtils; 9 | 10 | public class StaticRules { 11 | 12 | private static final java.util.regex.Pattern NONASCII_REGEX = java.util.regex.Pattern.compile("[^\u0000-\u007F]+"); 13 | private static final java.util.regex.Pattern QUEST_REGEX = java.util.regex.Pattern.compile("\\?+"); 14 | 15 | public static boolean isHtmlTag(String token) { 16 | return token.startsWith("<") && token.endsWith(">"); 17 | } 18 | 19 | public static boolean isPunctuation(String token) { 20 | return token.equals(",") || token.equals("."); 21 | } 22 | 23 | public static String normalizePre(String str) { 24 | str = str.toLowerCase(); 25 | str = NONASCII_REGEX.matcher(str).replaceAll("?"); 26 | str = QUEST_REGEX.matcher(str).replaceAll("?"); 27 | str = StringEscapeUtils.unescapeHtml(str); 28 | // Need to do this again because HTML-decoding may reintroduce non-ASCII characters. 29 | str = NONASCII_REGEX.matcher(str).replaceAll("?"); 30 | str = QUEST_REGEX.matcher(str).replaceAll("?"); 31 | 32 | return str; 33 | } 34 | 35 | public static String normalize(String str) { 36 | str = str.toLowerCase(); 37 | str = NONASCII_REGEX.matcher(str).replaceAll("?"); 38 | str = QUEST_REGEX.matcher(str).replaceAll("?"); 39 | str = StringEscapeUtils.unescapeHtml(str); 40 | // Need to do this again because HTML-decoding may reintroduce non-ASCII characters. 
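 // (Editorial note, not part of the original source) A Token pairs an optional
 // text with a Type. Pattern placeholders use a null text, e.g.
 // new Token(null, Token.Type.QUOTATION).toString() renders as "$Q"
 // (Pattern.QUOTATION_PLACEHOLDER), whereas a sentence token of Type.QUOTATION
 // keeps the quoted text itself.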
41 | str = NONASCII_REGEX.matcher(str).replaceAll("?"); 42 | str = QUEST_REGEX.matcher(str).replaceAll("?"); 43 | 44 | // Post-process 45 | String[] inTokens = str.split(" "); 46 | List outTokens = new ArrayList<>(); 47 | for (int i = 0; i < inTokens.length; i++) { 48 | String token = inTokens[i]; 49 | 50 | boolean added = false; 51 | if (token.length() > 1) { 52 | if (token.startsWith("?")) { 53 | outTokens.add("?"); 54 | token = token.substring(1); 55 | } 56 | if (token.endsWith("?")) { 57 | token = token.substring(0, token.length() - 1); 58 | if (token.contains("?") && token.contains("-")) { 59 | token = String.join(" - ", token.split("-")); 60 | } 61 | outTokens.add(token); 62 | outTokens.add("?"); 63 | added = true; 64 | } 65 | } 66 | 67 | if (!added) { 68 | if (token.contains("?") && token.contains("-")) { 69 | token = String.join(" - ", token.split("-")); 70 | } 71 | outTokens.add(token); 72 | } 73 | } 74 | 75 | return String.join(" ", outTokens); 76 | } 77 | 78 | public static String canonicalizeQuotation(String str) { 79 | 80 | // Normalize 81 | str = normalize(str); 82 | 83 | StringBuilder sb = new StringBuilder(); 84 | str.codePoints() 85 | .filter(c -> Character.isWhitespace(c) || Character.isLetterOrDigit(c) || c == '?') 86 | .map(c -> Character.isWhitespace(c) ? ' ' : c) 87 | .mapToObj(c -> Character.isAlphabetic(c) ? Character.toLowerCase(c) : c) 88 | .forEach(sb::appendCodePoint); 89 | 90 | return sb.toString() 91 | .trim() 92 | .replaceAll(" +", " "); // Remove double (or more) spaces 93 | } 94 | 95 | public static boolean matchSpeakerApprox(List first, List second, 96 | boolean caseSensitive) { 97 | if (second == null) { 98 | return false; 99 | } 100 | if (!caseSensitive) { 101 | first = Token.caseFold(first); 102 | second = Token.caseFold(second); 103 | } 104 | // Return true if they have at least one token in common 105 | return !Collections.disjoint(first, second); 106 | } 107 | 108 | public static Optional> matchSpeakerApprox(List first, 109 | Iterable> choices, boolean caseSensitive) { 110 | // Return the match with the highest number of tokens in common 111 | Optional> bestMatch = Optional.empty(); 112 | int bestMatchLen = 0; 113 | boolean dirty = false; // Used to track conflicts 114 | for (List choice : choices) { 115 | int matches = 0; 116 | // O(n^2) loop, but it is fine since these lists are very small 117 | for (int i = 0; i < choice.size(); i++) { 118 | for (int j = 0; j < first.size(); j++) { 119 | boolean equals; 120 | if (caseSensitive) { 121 | equals = choice.get(i).equals(first.get(j)); 122 | } else { 123 | equals = choice.get(i).equalsIgnoreCase(first.get(j)); 124 | } 125 | if (equals) { 126 | matches++; 127 | break; 128 | } 129 | } 130 | } 131 | if (matches > bestMatchLen) { 132 | bestMatchLen = matches; 133 | bestMatch = Optional.of(choice); 134 | dirty = false; 135 | } else if (matches == bestMatchLen) { 136 | dirty = true; 137 | } 138 | } 139 | 140 | if (dirty && bestMatchLen > 1) { 141 | throw new IllegalStateException("Conflicting speakers during ground truth evaluation: " 142 | + first + " " + bestMatch.get()); 143 | } 144 | 145 | if (bestMatchLen >= 2) { 146 | return bestMatch; 147 | } else { 148 | return Optional.empty(); 149 | } 150 | } 151 | 152 | } -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/quootstrap/Token.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.quootstrap; 2 | 3 | import java.io.Serializable; 
4 | import java.util.ArrayList; 5 | import java.util.Collection; 6 | import java.util.List; 7 | import java.util.Locale; 8 | import java.util.Objects; 9 | import java.util.stream.Collectors; 10 | 11 | import org.apache.commons.lang3.ObjectUtils; 12 | 13 | public final class Token implements Serializable, Comparable { 14 | 15 | private static final long serialVersionUID = 160449932354710762L; 16 | 17 | public enum Type { 18 | GENERIC, QUOTATION, SPEAKER, ANY; 19 | } 20 | 21 | private final String text; 22 | private final Type type; 23 | 24 | public Token(String text, Type type) { 25 | this.text = text; 26 | this.type = type; 27 | } 28 | 29 | public final Type getType() { 30 | return this.type; 31 | } 32 | 33 | public static List getStrings(Collection tokens) { 34 | return tokens.stream() 35 | .map(x -> x.text) 36 | .collect(Collectors.toList()); 37 | } 38 | 39 | public static List getTokens(Collection tokens) { 40 | return tokens.stream() 41 | .map(x -> new Token(x, Type.GENERIC)) 42 | .collect(Collectors.toList()); 43 | } 44 | 45 | @Override 46 | public int hashCode() { 47 | int hashCode = Objects.hashCode(text); 48 | 49 | // Important, since the JVM uses the object hash code for enums 50 | switch (type) { 51 | case GENERIC: 52 | hashCode += 377280272; 53 | break; 54 | case QUOTATION: 55 | hashCode += 116811174; 56 | break; 57 | case SPEAKER: 58 | hashCode += 637353343; 59 | break; 60 | case ANY: 61 | hashCode += 265741433; 62 | break; 63 | } 64 | return hashCode; 65 | } 66 | 67 | @Override 68 | public boolean equals(Object obj) { 69 | if (obj instanceof Token) { 70 | Token t = (Token) obj; 71 | return type == t.type && Objects.equals(text, t.text); 72 | } 73 | return false; 74 | } 75 | 76 | public boolean equalsIgnoreCase(Token t) { 77 | if (t == null) { 78 | return false; 79 | } 80 | 81 | return type == t.type && text.equalsIgnoreCase(t.text); 82 | } 83 | 84 | public static List caseFold(List tokens) { 85 | return tokens.stream() 86 | .map(x -> { 87 | if (x.text != null) { 88 | String s = x.text.toLowerCase(Locale.ROOT); 89 | return new Token(s, x.type); 90 | } 91 | return x; 92 | }) 93 | .collect(Collectors.toCollection(ArrayList::new)); 94 | } 95 | 96 | @Override 97 | public String toString() { 98 | if (text == null && type == Type.QUOTATION) { 99 | return Pattern.QUOTATION_PLACEHOLDER; 100 | } 101 | if (text == null && type == Type.SPEAKER) { 102 | return Pattern.SPEAKER_PLACEHOLDER; 103 | } 104 | if (text == null && type == Type.ANY) { 105 | return Pattern.ANY_PLACEHOLDER; 106 | } 107 | return text; 108 | } 109 | 110 | @Override 111 | public int compareTo(Token other) { 112 | // Define a lexicographical order of tokens 113 | if (type == other.type) 114 | { 115 | return ObjectUtils.compare(text, other.text); // Handles nulls 116 | } 117 | return type.compareTo(other.type); 118 | } 119 | 120 | } 121 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/quootstrap/Trie.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.quootstrap; 2 | 3 | import java.io.Serializable; 4 | import java.util.ArrayList; 5 | import java.util.Collection; 6 | import java.util.Collections; 7 | import java.util.HashMap; 8 | import java.util.Iterator; 9 | import java.util.List; 10 | import java.util.Locale; 11 | import java.util.Map; 12 | 13 | public class Trie implements Serializable { 14 | 15 | private static final long serialVersionUID = -9164247688163451454L; 16 | 17 | private final 
Node rootNode; 18 | private final boolean caseSensitive; 19 | 20 | public Trie(Collection patterns, boolean caseSensitive) { 21 | this.rootNode = new RootNode(caseSensitive); 22 | this.caseSensitive = caseSensitive; 23 | List sortedPatterns = new ArrayList<>(patterns); 24 | sortedPatterns.sort(Collections.reverseOrder()); 25 | sortedPatterns.forEach(this::insertPattern); 26 | } 27 | 28 | private void insertPattern(Pattern pattern) { 29 | Node current = rootNode; 30 | 31 | Iterator it = pattern.iterator(); 32 | while (it.hasNext()) { 33 | Token token = it.next(); 34 | Node next = null; 35 | 36 | if (token.getType() == Token.Type.GENERIC) { 37 | // Text token 38 | String key = caseSensitive 39 | ? token.toString() 40 | : token.toString().toLowerCase(Locale.ROOT); 41 | next = current.getTextChild(key); 42 | if (next == null) { 43 | next = it.hasNext() 44 | ? new InnerNode(token, caseSensitive) 45 | : new TerminalNode(token, pattern.getConfidenceMetric()); 46 | ((RootNode) current).textChildren.put(key, next); 47 | } 48 | } else { 49 | // Special token 50 | for (Node candidateNext : current) { 51 | if (candidateNext.getToken().equals(token)) { 52 | next = candidateNext; 53 | break; 54 | } 55 | } 56 | if (next == null) { 57 | next = it.hasNext() 58 | ? new InnerNode(token, caseSensitive) 59 | : new TerminalNode(token, pattern.getConfidenceMetric()); 60 | ((RootNode) current).children.add(next); 61 | } 62 | } 63 | 64 | current = next; 65 | } 66 | } 67 | 68 | public boolean isCaseSensitive() { 69 | return caseSensitive; 70 | } 71 | 72 | public List getAllPatterns() { 73 | List allPatterns = new ArrayList<>(); 74 | 75 | List currentPattern = new ArrayList<>(); 76 | DFS(rootNode, currentPattern, allPatterns); 77 | return allPatterns; 78 | } 79 | 80 | public Node getRootNode() { 81 | return rootNode; 82 | } 83 | 84 | private void DFS(Node current, List currentPattern, List allPatterns) { 85 | if (current.hasChildren()) { 86 | for (Node next : current.getTextChildren()) { 87 | currentPattern.add(next.getToken()); 88 | DFS(next, currentPattern, allPatterns); 89 | currentPattern.remove(currentPattern.size() - 1); 90 | } 91 | 92 | for (Node next : current) { 93 | currentPattern.add(next.getToken()); 94 | DFS(next, currentPattern, allPatterns); 95 | currentPattern.remove(currentPattern.size() - 1); 96 | } 97 | } else { 98 | // Terminal node 99 | allPatterns.add(new Pattern(new ArrayList<>(currentPattern), current.getConfidenceFactor())); 100 | } 101 | } 102 | 103 | public interface Node extends Iterable, Serializable { 104 | Token getToken(); 105 | boolean hasChildren(); 106 | Node getTextChild(String key); 107 | Iterable getTextChildren(); 108 | double getConfidenceFactor(); 109 | } 110 | 111 | private static class RootNode implements Node { 112 | 113 | private static final long serialVersionUID = -3680161357395993244L; 114 | 115 | private final Map textChildren; 116 | private final List children; 117 | private final boolean caseSensitive; 118 | 119 | public RootNode(boolean caseSensitive) { 120 | this.children = new ArrayList<>(); 121 | this.textChildren = new HashMap<>(); 122 | this.caseSensitive = caseSensitive; 123 | } 124 | 125 | @Override 126 | public Iterator iterator() { 127 | return children.iterator(); 128 | } 129 | 130 | @Override 131 | public Token getToken() { 132 | throw new UnsupportedOperationException("No token in the root node"); 133 | } 134 | 135 | @Override 136 | public boolean hasChildren() { 137 | return !children.isEmpty() || !textChildren.isEmpty(); 138 | } 139 | 140 | @Override 
141 | public String toString() { 142 | return "*ROOT*"; 143 | } 144 | 145 | @Override 146 | public double getConfidenceFactor() { 147 | throw new UnsupportedOperationException("This is not a terminal node"); 148 | } 149 | 150 | @Override 151 | public Node getTextChild(String key) { 152 | if (!caseSensitive) { 153 | key = key.toLowerCase(Locale.ROOT); 154 | } 155 | return textChildren.get(key); 156 | } 157 | 158 | @Override 159 | public Iterable getTextChildren() { 160 | return textChildren.values(); 161 | } 162 | } 163 | 164 | private static class InnerNode extends RootNode { 165 | 166 | private static final long serialVersionUID = 8115648454995183879L; 167 | 168 | private final Token token; 169 | 170 | private InnerNode(Token t, boolean caseSensitive) { 171 | super(caseSensitive); 172 | token = t; 173 | } 174 | 175 | public Token getToken() { 176 | return token; 177 | } 178 | 179 | @Override 180 | public String toString() { 181 | return token.toString(); 182 | } 183 | } 184 | 185 | private static class TerminalNode implements Node { 186 | 187 | private static final long serialVersionUID = 6535993027251722667L; 188 | 189 | private final Token token; 190 | private final double confidenceFactor; 191 | 192 | public TerminalNode(Token t, double confidence) { 193 | token = t; 194 | confidenceFactor = confidence; 195 | } 196 | 197 | @Override 198 | public Iterator iterator() { 199 | return Collections.emptyIterator(); 200 | } 201 | 202 | @Override 203 | public Token getToken() { 204 | return token; 205 | } 206 | 207 | @Override 208 | public boolean hasChildren() { 209 | return false; 210 | } 211 | 212 | @Override 213 | public String toString() { 214 | return token.toString(); 215 | } 216 | 217 | @Override 218 | public double getConfidenceFactor() { 219 | return confidenceFactor; 220 | } 221 | 222 | @Override 223 | public Node getTextChild(String key) { 224 | return null; 225 | } 226 | 227 | @Override 228 | public Iterable getTextChildren() { 229 | return Collections.emptyList(); 230 | } 231 | } 232 | } 233 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/quootstrap/TriePatternMatcher.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.quootstrap; 2 | 3 | import java.util.ArrayList; 4 | import java.util.List; 5 | import java.util.stream.Collectors; 6 | 7 | public class TriePatternMatcher extends PatternMatcher { 8 | 9 | private final Trie trie; 10 | private final List currentPattern; 11 | private double matchedConfidenceFactor; 12 | 13 | public TriePatternMatcher(Trie trie, int speakerLengthMin, int speakerLengthMax) { 14 | super(speakerLengthMin, speakerLengthMax, trie.isCaseSensitive()); 15 | this.trie = trie; 16 | currentPattern = new ArrayList<>(); 17 | matchedConfidenceFactor = Double.NaN; 18 | } 19 | 20 | @Override 21 | public boolean match(Sentence s) { 22 | sentenceTokens = s.getTokens(); 23 | 24 | boolean result = false; 25 | matches.clear(); 26 | for (int i = 0; i < sentenceTokens.size(); i++) { 27 | 28 | matchedConfidenceFactor = Double.NaN; 29 | if (sentenceTokens.get(i).getType() == Token.Type.GENERIC) { 30 | Trie.Node node = trie.getRootNode().getTextChild(sentenceTokens.get(i).toString()); 31 | if (node != null) { 32 | currentPattern.add(node); 33 | if (matchImpl(node, i, maxSpeakerLength)) { 34 | result = true; 35 | } 36 | currentPattern.remove(currentPattern.size() - 1); 37 | } 38 | } 39 | for (Trie.Node node : trie.getRootNode()) { 40 | 
currentPattern.add(node); 41 | if (matchImpl(node, i, maxSpeakerLength)) { 42 | result = true; 43 | } 44 | currentPattern.remove(currentPattern.size() - 1); 45 | } 46 | } 47 | 48 | return result; 49 | } 50 | 51 | private Pattern getMatchedPattern() { 52 | return new Pattern(new ArrayList<>(currentPattern.stream() 53 | .map(x -> x.getToken()) 54 | .collect(Collectors.toList())), matchedConfidenceFactor); 55 | } 56 | 57 | private boolean matchImplNext(Trie.Node node, int j, int speakerTokensLeft) { 58 | if (node.hasChildren()) { 59 | boolean result = false; 60 | 61 | if (j + 1 < sentenceTokens.size() && sentenceTokens.get(j + 1).getType() == Token.Type.GENERIC) { 62 | Trie.Node next = node.getTextChild(sentenceTokens.get(j + 1).toString()); 63 | if (next != null) { 64 | currentPattern.add(next); 65 | if (matchImpl(next, j + 1, speakerTokensLeft)) { 66 | result = true; 67 | } 68 | currentPattern.remove(currentPattern.size() - 1); 69 | } 70 | } 71 | for (Trie.Node next : node) { 72 | currentPattern.add(next); 73 | 74 | if (matchImpl(next, j + 1, speakerTokensLeft)) { 75 | result = true; 76 | } 77 | currentPattern.remove(currentPattern.size() - 1); 78 | } 79 | return result; 80 | } else { 81 | // End of pattern reached (terminal node) 82 | matchedConfidenceFactor = node.getConfidenceFactor(); 83 | matches.add(new Match(matchedQuotation, matchedSpeaker, getMatchedPattern())); 84 | return true; 85 | } 86 | } 87 | 88 | private boolean matchImpl(Trie.Node node, int j, int speakerTokensLeft) { 89 | if (j == sentenceTokens.size()) { 90 | // End of text reached. No matches are possible. 91 | return false; 92 | } 93 | 94 | if (matchTokens(node.getToken(), sentenceTokens.get(j), speakerTokensLeft)) { 95 | if (speakerTokenFoundFlag) { 96 | speakerTokenFoundFlag = false; 97 | boolean m1 = false; 98 | if (matchedSpeaker.size() >= minSpeakerLength) { 99 | m1 = matchImplNext(node, j, 0); 100 | } 101 | boolean m2 = matchImpl(node, j + 1, speakerTokensLeft - 1); 102 | matchedSpeaker.remove(matchedSpeaker.size() - 1); 103 | return m1 || m2; 104 | 105 | } else if (quotationTokenFoundFlag) { 106 | quotationTokenFoundFlag = false; 107 | boolean m = matchImplNext(node, j, speakerTokensLeft); 108 | matchedQuotation.remove(matchedQuotation.size() - 1); 109 | return m; 110 | 111 | } else { 112 | return matchImplNext(node, j, speakerTokensLeft); 113 | } 114 | } 115 | 116 | return false; 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/quootstrap/TupleExtractor.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.quootstrap; 2 | 3 | import java.util.Iterator; 4 | import java.util.List; 5 | 6 | import ch.epfl.dlab.spinn3r.converter.AbstractDecoder; 7 | import scala.Tuple4; 8 | 9 | public class TupleExtractor 10 | // Tuple: (quotation, speaker, full context, pattern that extracted the quotation) 11 | extends AbstractDecoder, Sentence, Pattern>> { 12 | 13 | private final Iterator it; 14 | private final PatternMatcher patternMatcher; 15 | 16 | private Iterator currentMatch; 17 | private Sentence currentSentence; 18 | 19 | public TupleExtractor(Iterator input, PatternMatcher pm) { 20 | it = input; 21 | patternMatcher = pm; 22 | currentMatch = null; 23 | currentSentence = null; 24 | } 25 | 26 | @Override 27 | protected Tuple4, Sentence, Pattern> getNextImpl() { 28 | while (true) { 29 | if (currentMatch == null) { 30 | if (it.hasNext()) { 31 | currentSentence = it.next(); 32 | 
if (patternMatcher.match(currentSentence)) { 33 | currentMatch = patternMatcher.getMatches(false).iterator(); 34 | } 35 | } else { 36 | // No more results 37 | return null; 38 | } 39 | } 40 | 41 | if (currentMatch != null) { 42 | if (currentMatch.hasNext()) { 43 | PatternMatcher.Match m = currentMatch.next(); 44 | return new Tuple4<>(m.getQuotation(), m.getSpeaker(), currentSentence, m.getPattern()); 45 | } else { 46 | currentMatch = null; 47 | } 48 | } 49 | } 50 | } 51 | 52 | } 53 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/quootstrap/Utils.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.quootstrap; 2 | 3 | import java.io.File; 4 | import java.io.IOException; 5 | import java.nio.file.Files; 6 | import java.nio.file.Paths; 7 | import java.util.Collection; 8 | import java.util.HashMap; 9 | import java.util.List; 10 | import java.util.Map; 11 | import java.util.stream.Collectors; 12 | 13 | import org.apache.commons.io.FileUtils; 14 | import org.apache.hadoop.fs.FileSystem; 15 | import org.apache.hadoop.fs.Path; 16 | import org.apache.spark.api.java.JavaPairRDD; 17 | import org.apache.spark.api.java.JavaRDD; 18 | import org.apache.spark.api.java.JavaSparkContext; 19 | 20 | import ch.epfl.dlab.spinn3r.converter.ProtoToJson; 21 | import scala.Tuple2; 22 | 23 | public class Utils { 24 | 25 | /** 26 | * Load or generate a cached RDD. 27 | * If caching is disabled, the RDD is saved and returned as-is. 28 | * If caching is enabled, the RDD is loaded from disk if it is already cached, 29 | * otherwise it is generated, saved, and returned. 30 | * @param rdd the rdd to cache 31 | * @param fileName the file name of the cache file 32 | * @return the output RDD (cached or generated) 33 | */ 34 | public static JavaRDD loadCache(JavaRDD rdd, String fileName) { 35 | if (!ConfigManager.getInstance().isCacheEnabled()) { 36 | return rdd; 37 | } 38 | final String cacheDir = ConfigManager.getInstance().getCachePath(); 39 | final String path = cacheDir + "/" + fileName; 40 | try { 41 | FileSystem hdfs = org.apache.hadoop.fs.FileSystem.get(rdd.context().hadoopConfiguration()); 42 | Path hdfsPath = new Path(path); 43 | if (!hdfs.exists(hdfsPath)) { 44 | rdd.saveAsObjectFile(path); 45 | } 46 | return JavaSparkContext.fromSparkContext(rdd.context()).objectFile(path); 47 | } catch (IOException e) { 48 | throw new IllegalStateException(e); 49 | } 50 | 51 | } 52 | 53 | /** 54 | * Load or generate a cached PairRDD. 55 | * If caching is disabled, the RDD is saved and returned as-is. 56 | * If caching is enabled, the RDD is loaded from disk if it is already cached, 57 | * otherwise it is generated, saved, and returned. 
58 | * @param rdd the rdd to cache 59 | * @param fileName the file name of the cache file 60 | * @return the output RDD (cached or generated) 61 | */ 62 | public static JavaPairRDD loadCache(JavaPairRDD rdd, String fileName) { 63 | if (!ConfigManager.getInstance().isCacheEnabled()) { 64 | return rdd; 65 | } 66 | final String cacheDir = ConfigManager.getInstance().getCachePath(); 67 | final String path = cacheDir + "/" + fileName; 68 | try { 69 | FileSystem hdfs = org.apache.hadoop.fs.FileSystem.get(rdd.context().hadoopConfiguration()); 70 | Path hdfsPath = new Path(path); 71 | if (!hdfs.exists(hdfsPath)) { 72 | rdd.saveAsObjectFile(path); 73 | } 74 | return JavaPairRDD.fromJavaRDD(JavaSparkContext.fromSparkContext(rdd.context()).objectFile(path)); 75 | } catch (IOException e) { 76 | throw new IllegalStateException(e); 77 | } 78 | } 79 | 80 | public static void dumpRDD(JavaRDD rdd, String fileName) { 81 | if (!ConfigManager.getInstance().isLocalModeEnabled()) { 82 | throw new IllegalArgumentException("The method dumpRDD can be used only in local mode"); 83 | } 84 | fileName = ConfigManager.getInstance().getOutputPath() + fileName; 85 | FileUtils.deleteQuietly(new File(fileName)); 86 | rdd.saveAsTextFile(fileName); 87 | try { 88 | ProtoToJson.mergeHdfsFile(fileName); 89 | } catch (IOException e) { 90 | throw new IllegalStateException(e); 91 | } 92 | } 93 | 94 | public static void dumpRDDLocal(JavaRDD rdd, String fileName) { 95 | fileName = ConfigManager.getInstance().getOutputPath() + fileName; 96 | FileUtils.deleteQuietly(new File(fileName)); 97 | List lines = rdd.collect() 98 | .stream() 99 | .map(x -> x.toString()) 100 | .collect(Collectors.toList()); 101 | try { 102 | Files.write(Paths.get(fileName), lines); 103 | } catch (IOException e) { 104 | throw new IllegalStateException(e); 105 | } 106 | } 107 | 108 | public static void dumpCollection(Collection data, String fileName) { 109 | fileName = ConfigManager.getInstance().getOutputPath() + fileName; 110 | List lines = data.stream() 111 | .map(x -> x.toString()) 112 | .sorted() 113 | .collect(Collectors.toList()); 114 | 115 | try { 116 | Files.write(Paths.get(fileName), lines); 117 | } catch (IOException e) { 118 | throw new IllegalStateException(e); 119 | } 120 | } 121 | 122 | public static void dumpRDD(JavaPairRDD rdd, String fileName) { 123 | if (!ConfigManager.getInstance().isLocalModeEnabled()) { 124 | throw new IllegalArgumentException("The method dumpRDD can be used only in local mode"); 125 | } 126 | FileUtils.deleteQuietly(new File(fileName)); 127 | rdd.saveAsTextFile(fileName); 128 | try { 129 | ProtoToJson.mergeHdfsFile(fileName); 130 | } catch (IOException e) { 131 | throw new IllegalStateException(e); 132 | } 133 | } 134 | 135 | public static List findLongestSuperstring(List needle, Iterable> haystack) { 136 | List bestMatch = null; 137 | boolean dirty = false; 138 | for (List candidate : haystack) { 139 | if (bestMatch == null || candidate.size() >= bestMatch.size()) { 140 | for (int i = 0; i < candidate.size() - needle.size() + 1; i++) { 141 | List subCandidate = candidate.subList(i, i + needle.size()); 142 | if (subCandidate.equals(needle)) { 143 | if (bestMatch == null || candidate.size() > bestMatch.size()) { 144 | bestMatch = candidate; 145 | dirty = false; 146 | } else { 147 | dirty = true; 148 | } 149 | } 150 | } 151 | } 152 | } 153 | 154 | // If we have multiple superstrings of the same length, return no match to avoid conflicts 155 | // e.g. 
"John Doe" could be extended to either "John Doe Jr" or "John Doe Sr" 156 | if (dirty) { 157 | return null; 158 | } 159 | return bestMatch; 160 | } 161 | 162 | public static SpeakerAlias findUniqueSuperstring(List needle, 163 | Iterable haystack, boolean caseSensitive) { 164 | SpeakerAlias bestMatch = null; 165 | for (SpeakerAlias candidate : haystack) { 166 | List candTok = Token.getTokens(candidate.getAlias()); 167 | for (int i = 0; i < candTok.size() - needle.size() + 1; i++) { 168 | List subCandidate = candTok.subList(i, i + needle.size()); 169 | boolean equals; 170 | if (caseSensitive) { 171 | equals = subCandidate.equals(needle); 172 | } else { 173 | equals = equalsCaseInsensitive(subCandidate, needle); 174 | } 175 | if (equals) { 176 | if (bestMatch != null) { 177 | return null; // Conflict detected 178 | } 179 | bestMatch = candidate; 180 | break; // Break the outer loop 181 | } 182 | } 183 | } 184 | 185 | return bestMatch; 186 | } 187 | 188 | private static boolean equalsCaseInsensitive(List a, List b) { 189 | if (a.size() != b.size()) { 190 | return false; 191 | } 192 | 193 | for (int i = 0; i < a.size(); i++) { 194 | if (!a.get(i).equalsIgnoreCase(b.get(i))) { 195 | return false; 196 | } 197 | } 198 | 199 | return true; 200 | } 201 | 202 | /** 203 | * Returns the element that appears with the highest frequency, 204 | * along with its count. 205 | * If there are ties, this method returns false. 206 | */ 207 | public static Tuple2 maxFrequencyItem(Iterable it) { 208 | Map frequencies = new HashMap<>(); 209 | for (T elem : it) { 210 | frequencies.put(elem, frequencies.getOrDefault(elem, 0) + 1); 211 | } 212 | Map.Entry best = null; 213 | boolean dirty = true; // We want to avoid ties 214 | for (Map.Entry entry : frequencies.entrySet()) { 215 | if (best == null || entry.getValue() > best.getValue()) { 216 | dirty = false; 217 | best = entry; 218 | } else if (entry.getValue() == best.getValue()) { 219 | dirty = true; 220 | } 221 | } 222 | 223 | if (dirty) { 224 | return null; 225 | } else { 226 | return new Tuple2<>(best.getKey(), best.getValue()); 227 | } 228 | } 229 | 230 | public static boolean doubleEquals(double a, double b) { 231 | return Math.abs(a - b) < 1e-6; 232 | } 233 | } 234 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/spinn3r/EntryWrapper.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.spinn3r; 2 | import java.io.Serializable; 3 | import java.util.List; 4 | 5 | /** 6 | * This apparently purpose-less class wraps Spinn3r entries 7 | * and is used as input for a JSON serializer such as Gson. 
8 | */ 9 | public class EntryWrapper implements Serializable { 10 | 11 | private static final long serialVersionUID = 3319808914288405234L; 12 | 13 | private final SourceWrapper source; 14 | private final FeedWrapper feed; 15 | private final FeedEntryWrapper feedEntry; 16 | private final PermalinkEntryWrapper permalinkEntry; 17 | 18 | public EntryWrapper(SourceWrapper source, FeedWrapper feed, FeedEntryWrapper feedEntry, 19 | PermalinkEntryWrapper permalinkEntry) { 20 | this.source = source; 21 | this.feed = feed; 22 | this.feedEntry = feedEntry; 23 | this.permalinkEntry = permalinkEntry; 24 | } 25 | 26 | public SourceWrapper getSource() { 27 | return source; 28 | } 29 | 30 | public FeedWrapper getFeed() { 31 | return feed; 32 | } 33 | 34 | public FeedEntryWrapper getFeedEntry() { 35 | return feedEntry; 36 | } 37 | 38 | public PermalinkEntryWrapper getPermalinkEntry() { 39 | return permalinkEntry; 40 | } 41 | 42 | public static class SourceWrapper implements Serializable { 43 | 44 | private static final long serialVersionUID = 884787890709132785L; 45 | 46 | private final String url; 47 | private final String title; 48 | private final String language; 49 | private final String description; 50 | private final String lastPosted; 51 | private final String lastPublished; 52 | private final String dateFound; 53 | private final String publisherType; 54 | 55 | public SourceWrapper(String url, String title, String language, String description, String lastPosted, 56 | String lastPublished, String dateFound, String publisherType) { 57 | this.url = url; 58 | this.title = title; 59 | this.language = language; 60 | this.description = description; 61 | this.lastPosted = lastPosted; 62 | this.lastPublished = lastPublished; 63 | this.dateFound = dateFound; 64 | this.publisherType = publisherType; 65 | } 66 | 67 | public String getUrl() { 68 | return url; 69 | } 70 | 71 | public String getTitle() { 72 | return title; 73 | } 74 | 75 | public String getLanguage() { 76 | return language; 77 | } 78 | 79 | public String getDescription() { 80 | return description; 81 | } 82 | 83 | public String getLastPosted() { 84 | return lastPosted; 85 | } 86 | 87 | public String getLastPublished() { 88 | return lastPublished; 89 | } 90 | 91 | public String getDateFound() { 92 | return dateFound; 93 | } 94 | 95 | public String getPublisherType() { 96 | return publisherType; 97 | } 98 | 99 | } 100 | 101 | public static class FeedWrapper implements Serializable { 102 | 103 | private static final long serialVersionUID = 928980040260245901L; 104 | 105 | private final String url; 106 | private final String title; 107 | private final String language; 108 | private final String lastPosted; 109 | private final String lastPublished; 110 | private final String dateFound; 111 | private final String channelUrl; 112 | 113 | public FeedWrapper(String url, String title, String language, String lastPosted, String lastPublished, 114 | String dateFound, String channelUrl) { 115 | this.url = url; 116 | this.title = title; 117 | this.language = language; 118 | this.lastPosted = lastPosted; 119 | this.lastPublished = lastPublished; 120 | this.dateFound = dateFound; 121 | this.channelUrl = channelUrl; 122 | } 123 | 124 | public String getUrl() { 125 | return url; 126 | } 127 | 128 | public String getTitle() { 129 | return title; 130 | } 131 | 132 | public String getLanguage() { 133 | return language; 134 | } 135 | 136 | public String getLastPosted() { 137 | return lastPosted; 138 | } 139 | 140 | public String getLastPublished() { 141 | return 
lastPublished; 142 | } 143 | 144 | public String getDateFound() { 145 | return dateFound; 146 | } 147 | 148 | public String getChannelUrl() { 149 | return channelUrl; 150 | } 151 | } 152 | 153 | public static class FeedEntryWrapper implements Serializable { 154 | 155 | private static final long serialVersionUID = 5137190259805666093L; 156 | 157 | private final String identifier; 158 | private final String url; 159 | private final String title; 160 | private final String language; 161 | private final String authorName; 162 | private final String authorEmail; 163 | private final String lastPublished; 164 | private final String dateFound; 165 | private final List tokenizedContent; 166 | 167 | public FeedEntryWrapper(String identifier, String url, String title, String language, String authorName, 168 | String authorEmail, String lastPublished, String dateFound, List tokenizedContent) { 169 | super(); 170 | this.identifier = identifier; 171 | this.url = url; 172 | this.title = title; 173 | this.language = language; 174 | this.authorName = authorName; 175 | this.authorEmail = authorEmail; 176 | this.lastPublished = lastPublished; 177 | this.dateFound = dateFound; 178 | this.tokenizedContent = tokenizedContent; 179 | } 180 | 181 | public String getIdentifier() { 182 | return identifier; 183 | } 184 | 185 | public String getUrl() { 186 | return url; 187 | } 188 | 189 | public String getTitle() { 190 | return title; 191 | } 192 | 193 | public String getLanguage() { 194 | return language; 195 | } 196 | 197 | public String getAuthorName() { 198 | return authorName; 199 | } 200 | 201 | public String getAuthorEmail() { 202 | return authorEmail; 203 | } 204 | 205 | public String getLastPublished() { 206 | return lastPublished; 207 | } 208 | 209 | public String getDateFound() { 210 | return dateFound; 211 | } 212 | 213 | public List getContent() { 214 | return tokenizedContent; 215 | } 216 | } 217 | 218 | public static class PermalinkEntryWrapper implements Serializable { 219 | 220 | private static final long serialVersionUID = -1886188543863283140L; 221 | 222 | private final String identifier; 223 | private final String url; 224 | private final String title; 225 | private final String language; 226 | private final String authorName; 227 | private final String authorEmail; 228 | private final String lastPublished; 229 | private final String dateFound; 230 | private final List tokenizedContent; 231 | 232 | public PermalinkEntryWrapper(String identifier, String url, String title, String language, String authorName, 233 | String authorEmail, String lastPublished, String dateFound, List tokenizedContent) { 234 | this.identifier = identifier; 235 | this.url = url; 236 | this.title = title; 237 | this.language = language; 238 | this.authorName = authorName; 239 | this.authorEmail = authorEmail; 240 | this.lastPublished = lastPublished; 241 | this.dateFound = dateFound; 242 | this.tokenizedContent = tokenizedContent; 243 | } 244 | 245 | public String getIdentifier() { 246 | return identifier; 247 | } 248 | 249 | public String getUrl() { 250 | return url; 251 | } 252 | 253 | public String getTitle() { 254 | return title; 255 | } 256 | 257 | public String getLanguage() { 258 | return language; 259 | } 260 | 261 | public String getAuthorName() { 262 | return authorName; 263 | } 264 | 265 | public String getAuthorEmail() { 266 | return authorEmail; 267 | } 268 | 269 | public String getLastPublished() { 270 | return lastPublished; 271 | } 272 | 273 | public String getDateFound() { 274 | return dateFound; 275 | } 276 
| 277 | public List<String> getContent() { 278 | return tokenizedContent; 279 | } 280 | } 281 | } 282 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/spinn3r/Tokenizer.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.spinn3r; 2 | 3 | import java.util.List; 4 | 5 | public interface Tokenizer { 6 | 7 | List<String> tokenize(String sentence); 8 | String untokenize(List<String> tokens); 9 | 10 | } 11 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/spinn3r/TokenizerImpl.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.spinn3r; 2 | 3 | import java.io.StringReader; 4 | import java.util.ArrayList; 5 | import java.util.List; 6 | 7 | import edu.stanford.nlp.ling.CoreLabel; 8 | import edu.stanford.nlp.process.CoreLabelTokenFactory; 9 | import edu.stanford.nlp.process.PTBTokenizer; 10 | 11 | public class TokenizerImpl implements Tokenizer { 12 | 13 | /** 14 | * Configuration for Stanford PTBTokenizer. 15 | */ 16 | private static final String TOKENIZER_SETTINGS = "tokenizeNLs=false, americanize=false, " + 17 | "normalizeCurrency=false, normalizeParentheses=false," + 18 | "normalizeOtherBrackets=false, unicodeQuotes=false, ptb3Ellipsis=true," + 19 | "escapeForwardSlashAsterisk=false, untokenizable=noneKeep, normalizeSpace=false"; 20 | 21 | 22 | @Override 23 | public List<String> tokenize(String sentence) { 24 | PTBTokenizer<CoreLabel> ptbt = new PTBTokenizer<>(new StringReader(sentence), 25 | new CoreLabelTokenFactory(), TOKENIZER_SETTINGS); 26 | List<String> tokens = new ArrayList<>(); 27 | while (ptbt.hasNext()) { 28 | CoreLabel label = ptbt.next(); 29 | tokens.add(label.toString()); 30 | } 31 | return tokens; 32 | } 33 | 34 | @Override 35 | public String untokenize(List<String> tokens) { 36 | return PTBTokenizer.ptb2Text(tokens); 37 | } 38 | 39 | } 40 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/spinn3r/converter/AbstractDecoder.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.spinn3r.converter; 2 | import java.util.Iterator; 3 | 4 | /** 5 | * Provides a convenient interface to create custom iterators 6 | * without materializing intermediate results. 7 | */ 8 | public abstract class AbstractDecoder<T> implements Iterator<T> { 9 | 10 | private boolean dirty; 11 | private boolean finished; 12 | private T cachedNext; 13 | 14 | protected AbstractDecoder() { 15 | dirty = true; 16 | finished = false; 17 | } 18 | 19 | /** 20 | * Obtain the next element.
21 | */ 22 | protected abstract T getNextImpl(); 23 | 24 | private void extractNext() 25 | { 26 | if (!finished && dirty) { 27 | cachedNext = getNextImpl(); 28 | dirty = false; 29 | 30 | if (cachedNext == null) { 31 | finished = true; 32 | } 33 | } 34 | } 35 | 36 | @Override 37 | public boolean hasNext() { 38 | extractNext(); 39 | return !finished; 40 | } 41 | 42 | @Override 43 | public T next() { 44 | extractNext(); 45 | T next = cachedNext; 46 | cachedNext = null; 47 | dirty = true; 48 | return next; 49 | } 50 | 51 | } 52 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/spinn3r/converter/CombinedDecoder.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.spinn3r.converter; 2 | import java.util.Iterator; 3 | 4 | import org.apache.spark.input.PortableDataStream; 5 | 6 | import ch.epfl.dlab.spinn3r.EntryWrapper; 7 | import scala.Tuple2; 8 | 9 | public class CombinedDecoder extends AbstractDecoder<EntryWrapper> { 10 | 11 | private final Iterator<Tuple2<String, PortableDataStream>> fileInput; 12 | private Decompressor decompressor; 13 | private SpinnerDecoder spinnerDecoder; 14 | private final boolean clean; 15 | private final boolean tokenize; 16 | 17 | /** 18 | * Initialize this decoder. 19 | * @param in The streams to decode and convert. 20 | * @param clean Whether to clean or not useless HTML tags. 21 | * @param tokenize Whether to tokenize the output. 22 | */ 23 | public CombinedDecoder(Iterator<Tuple2<String, PortableDataStream>> in, boolean clean, boolean tokenize) { 24 | fileInput = in; 25 | this.clean = clean; 26 | this.tokenize = tokenize; 27 | } 28 | 29 | @Override 30 | protected EntryWrapper getNextImpl() { 31 | while (true) { 32 | 33 | if (decompressor == null) { 34 | if (fileInput.hasNext()) { 35 | // Decompress next archive 36 | decompressor = new Decompressor(fileInput.next()._2.open()); 37 | } else { 38 | return null; 39 | } 40 | } 41 | 42 | if (spinnerDecoder == null) { 43 | if (decompressor.hasNext()) { 44 | // Decode next file from archive 45 | spinnerDecoder = new SpinnerDecoder(decompressor.next()); 46 | } else { 47 | decompressor = null; 48 | } 49 | } 50 | 51 | if (spinnerDecoder != null) { 52 | if (spinnerDecoder.hasNext()) { 53 | // Decode next entry from file 54 | return EntryWrapperBuilder.buildFrom(spinnerDecoder.next(), clean, tokenize); 55 | } else { 56 | spinnerDecoder = null; 57 | } 58 | } 59 | } 60 | } 61 | 62 | } 63 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/spinn3r/converter/Decompressor.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.spinn3r.converter; 2 | import java.io.IOException; 3 | import java.io.InputStream; 4 | 5 | import org.apache.commons.compress.archivers.tar.TarArchiveEntry; 6 | import org.apache.commons.compress.archivers.tar.TarArchiveInputStream; 7 | import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream; 8 | 9 | /** 10 | * This class implements an on-the-fly decoder for .tar.gz archives. 11 | * The stream is decompressed in memory, and the result is returned as an iterator, 12 | * without materializing the entire result. 13 | */ 14 | public class Decompressor extends AbstractDecoder<byte[]> { 15 | 16 | private final TarArchiveInputStream tar; 17 | 18 | /** 19 | * Initializes this iterator using the given stream.
20 | * @param inStream the stream to decode 21 | */ 22 | public Decompressor(InputStream inStream) { 23 | try { 24 | tar = new TarArchiveInputStream(new GzipCompressorInputStream(inStream)); 25 | } catch (IOException e) { 26 | throw new IllegalStateException("Unable to open stream", e); 27 | } 28 | } 29 | 30 | @Override 31 | protected byte[] getNextImpl() { 32 | try { 33 | TarArchiveEntry entry; 34 | while ((entry = tar.getNextTarEntry()) != null) { 35 | if (entry.isDirectory()) { 36 | continue; 37 | } 38 | 39 | byte[] data = new byte[tar.available()]; 40 | tar.read(data); 41 | return data; 42 | } 43 | 44 | // Entry is null 45 | tar.close(); 46 | return null; 47 | } catch (IOException e) { 48 | throw new IllegalStateException("Unable to read stream", e); 49 | } 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/spinn3r/converter/ProtoToJson.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.spinn3r.converter; 2 | import java.io.BufferedOutputStream; 3 | import java.io.File; 4 | import java.io.FileInputStream; 5 | import java.io.FileOutputStream; 6 | import java.io.IOException; 7 | import java.io.InputStream; 8 | import java.nio.file.Files; 9 | import java.nio.file.Paths; 10 | import java.util.ArrayList; 11 | import java.util.Arrays; 12 | import java.util.List; 13 | import java.util.Optional; 14 | 15 | import org.apache.commons.io.FileUtils; 16 | import org.apache.commons.io.IOUtils; 17 | import org.apache.hadoop.io.compress.CompressionCodec; 18 | import org.apache.spark.SparkConf; 19 | import org.apache.spark.api.java.JavaRDD; 20 | import org.apache.spark.api.java.JavaSparkContext; 21 | import org.apache.spark.util.LongAccumulator; 22 | 23 | import com.google.gson.Gson; 24 | import com.google.gson.GsonBuilder; 25 | 26 | import ch.epfl.dlab.spinn3r.EntryWrapper; 27 | import scala.Tuple2; 28 | 29 | public class ProtoToJson { 30 | 31 | private static final Gson JSON_BUILDER = new GsonBuilder().disableHtmlEscaping().create(); 32 | 33 | public static void main(final String[] args) throws IOException { 34 | 35 | if (args.length < 2) { 36 | System.err.println("Usage: ProtoToJson [optional arguments]"); 37 | System.err.println("[--master=] specifies a master if spark-submit is not used."); 38 | System.err.println("[--sample=] specifies the fraction of data to sample."); 39 | System.err.println("[--merge] merges HDFS output into one single file (if the application is run locally)."); 40 | System.err.println("[--partitions=] specifies the number of partitions to save (useful if compression is used)"); 41 | System.err.println("[--compress=] compresses the output using the given codec (GzipCodec, Lz4Codec, BZip2Codec, SnappyCodec)"); 42 | System.err.println("[--source-type=] specifies the source type(s) to extract (e.g. MAINSTREAM_NEWS). 
Separate with ;"); 43 | System.err.println("[--clean=] clean useless HTML tags (default: true)"); 44 | System.err.println("[--tokenize=] tokenize the output using Stanford PTBTokenizer (default: true)"); 45 | System.err.println("[--remove-duplicates] try to remove articles with duplicate contents (costly operation)"); 46 | System.err.println("Examples:"); 47 | System.err.println("ProtoToJson dir/*.tar.gz out.json --master=local[*] --sample=0.1 --merge"); 48 | System.err.println("ProtoToJson dir/*.tar.gz out.json --master=local[*] --partitions=1000 --compress=GzipCodec --source-type=MAINSTREAM_NEWS;FORUM"); 49 | return; 50 | } 51 | 52 | final SparkConf conf = new SparkConf() 53 | .setAppName(ProtoToJson.class.getName()) 54 | .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") 55 | .registerKryoClasses(new Class[] { EntryWrapper.class, EntryWrapper.PermalinkEntryWrapper.class, 56 | EntryWrapper.FeedEntryWrapper.class, EntryWrapper.FeedWrapper.class, EntryWrapper.SourceWrapper.class, ArrayList.class, 57 | List.class }); 58 | 59 | 60 | double samplingRate = 1.0; 61 | boolean merge = false; 62 | Optional compressionCodec = Optional.empty(); 63 | int numPartitions = 0; 64 | final List allowedSources = new ArrayList<>(); 65 | boolean clean = true; 66 | boolean tokenize = true; 67 | boolean removeDuplicates = false; 68 | for (int i = 2; i < args.length; i++) { 69 | String[] arg = args[i].split("="); 70 | switch (arg[0]) { 71 | case "--master": 72 | conf.setMaster(arg[1]); 73 | break; 74 | case "--sample": 75 | samplingRate = Double.parseDouble(arg[1]); 76 | break; 77 | case "--merge": 78 | merge = true; 79 | break; 80 | case "--partitions": 81 | numPartitions = Integer.parseInt(arg[1]); 82 | break; 83 | case "--compress": 84 | compressionCodec = Optional.of(arg[1]); 85 | break; 86 | case "--source-type": 87 | allowedSources.addAll(Arrays.asList(arg[1].split(";"))); 88 | break; 89 | case "--clean": 90 | clean = Boolean.parseBoolean(arg[1]); 91 | break; 92 | case "--tokenize": 93 | tokenize = Boolean.parseBoolean(arg[1]); 94 | break; 95 | case "--remove-duplicates": 96 | removeDuplicates = true; 97 | break; 98 | } 99 | } 100 | 101 | // Clean-up 102 | if (merge) { 103 | FileUtils.deleteQuietly(new File(args[1])); 104 | } 105 | 106 | Stopwatch sw = new Stopwatch(); 107 | 108 | try (JavaSparkContext sc = new JavaSparkContext(conf)) { 109 | // The accumulator is used to count the number of documents efficiently 110 | final LongAccumulator counter = sc.sc().longAccumulator(); 111 | 112 | final boolean cleanFlag = clean; 113 | final boolean tokenizeFlag = tokenize; 114 | 115 | JavaRDD rdd = sc.binaryFiles(args[0]).mapPartitions(it -> { 116 | return new CombinedDecoder(it, cleanFlag, tokenizeFlag); 117 | }); 118 | 119 | if (!allowedSources.isEmpty()) { 120 | rdd = rdd.filter(x -> allowedSources.contains(x.getSource().getPublisherType())); 121 | } 122 | 123 | if (samplingRate < 1.0) { 124 | rdd = rdd.sample(false, samplingRate); 125 | } 126 | 127 | if (numPartitions > 0) { 128 | rdd = rdd.repartition(numPartitions); 129 | } 130 | 131 | if (removeDuplicates) { 132 | rdd = rdd.mapToPair(x -> new Tuple2<>(x.getPermalinkEntry().getContent(), x)) 133 | .reduceByKey((x, y) -> { 134 | // Deterministic "distinct": return the article with the lowest ID 135 | long id1 = Long.parseLong(x.getPermalinkEntry().getIdentifier()); 136 | long id2 = Long.parseLong(y.getPermalinkEntry().getIdentifier()); 137 | return id1 < id2 ? 
x : y; 138 | }) 139 | .map(x -> x._2); 140 | } 141 | 142 | JavaRDD out = rdd.map(x -> { 143 | counter.add(1); 144 | // Convert the document to a JSON object 145 | return JSON_BUILDER.toJson(x); 146 | }); 147 | 148 | if (compressionCodec.isPresent()) { 149 | try { 150 | out.saveAsTextFile(args[1], 151 | Class.forName("org.apache.hadoop.io.compress." + compressionCodec.get()) 152 | .asSubclass(CompressionCodec.class)); 153 | } catch (ClassNotFoundException e) { 154 | throw new IllegalArgumentException("Invalid compression codec", e); 155 | } 156 | } else { 157 | out.saveAsTextFile(args[1]); 158 | } 159 | 160 | long count = counter.value(); 161 | System.out.println("Processed " + count + " documents"); 162 | double time = sw.printTime(); 163 | System.out.println((time / count * 1000) + " ms per document"); 164 | System.out.println((count / time) + " documents per second"); 165 | } 166 | 167 | // If the destination is not a HDFS path (i.e. a local one), we want a single file 168 | if (merge) { 169 | mergeHdfsFile(args[1]); 170 | } 171 | } 172 | 173 | /** 174 | * Merges multiple HDFS chunks into one single file. 175 | * @param fileName the file (directory) to merge 176 | * @throws IOException 177 | */ 178 | public static void mergeHdfsFile(String fileName) throws IOException { 179 | System.out.println("Merging files..."); 180 | Stopwatch stopwatch = new Stopwatch(); 181 | 182 | BufferedOutputStream out = new BufferedOutputStream(new FileOutputStream(fileName + "_tmp")); 183 | File folder = new File(fileName); 184 | File[] files = folder.listFiles((dir, name) -> name.startsWith("part")); 185 | 186 | // Sort files by chunk ID 187 | Arrays.sort(files, (a, b) -> a.getPath().compareTo(b.getPath())); 188 | 189 | for (File file : files) { 190 | InputStream in = new FileInputStream(file.getPath()); 191 | IOUtils.copyLarge(in, out); 192 | in.close(); 193 | } 194 | out.close(); 195 | FileUtils.deleteQuietly(new File(fileName)); 196 | Files.move(Paths.get(fileName + "_tmp"), Paths.get(fileName)); 197 | System.out.println("Merged " + files.length + " files."); 198 | stopwatch.printTime(); 199 | } 200 | } 201 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/spinn3r/converter/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-dlab/Quotebank/a203acf5a05e34841671578d14d6a0be6a66a3ef/quootstrap/src/main/java/ch/epfl/dlab/spinn3r/converter/README.md -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/spinn3r/converter/SpinnerDecoder.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.spinn3r.converter; 2 | import java.io.ByteArrayInputStream; 3 | import java.io.IOException; 4 | 5 | import com.spinn3r.api.EntryDecoderFactory; 6 | import com.spinn3r.api.protobuf.ContentApi; 7 | import com.spinn3r.api.protobuf.ContentApi.Entry; 8 | import com.spinn3r.api.util.Decoder; 9 | 10 | public class SpinnerDecoder extends AbstractDecoder { 11 | 12 | private final Decoder decoder; 13 | 14 | public SpinnerDecoder(byte[] data) { 15 | EntryDecoderFactory factory = EntryDecoderFactory.newFactory(); 16 | decoder = factory.get(new ByteArrayInputStream(data)); 17 | } 18 | 19 | @Override 20 | protected Entry getNextImpl() { 21 | try { 22 | Entry decoded = decoder.read(); 23 | if (decoded == null) { 24 | decoder.close(); 25 | } 26 | return decoded; 27 | } 
catch (IOException e) { 28 | throw new IllegalStateException("Unable to read stream", e); 29 | } 30 | } 31 | 32 | } 33 | -------------------------------------------------------------------------------- /quootstrap/src/main/java/ch/epfl/dlab/spinn3r/converter/Stopwatch.java: -------------------------------------------------------------------------------- 1 | package ch.epfl.dlab.spinn3r.converter; 2 | 3 | /** 4 | * This class is used for measuring time. 5 | */ 6 | public class Stopwatch { 7 | 8 | private final long startTime; 9 | 10 | public Stopwatch() { 11 | startTime = System.currentTimeMillis(); 12 | } 13 | 14 | public double printTime() { 15 | long endTime = System.currentTimeMillis(); 16 | double elapsed = (endTime - startTime) / 1000.0; 17 | System.out.println("Elapsed time: " + elapsed + " seconds."); 18 | return elapsed; 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | per-file-ignores = __init__.py: F401 -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import logging 4 | import os 5 | 6 | import torch 7 | 8 | from quobert.model import BertForQuotationAttribution, evaluate 9 | from quobert.utils.data import ConcatParquetDataset, ParquetDataset 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | if __name__ == "__main__": 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument( 16 | "--model_dir", 17 | default=None, 18 | type=str, 19 | required=True, 20 | help="Path to pre-trained model", 21 | ) 22 | parser.add_argument( 23 | "--output_dir", 24 | default=None, 25 | type=str, 26 | required=True, 27 | help="The output directory where the model results will be written.", 28 | ) 29 | parser.add_argument( 30 | "--test_dir", 31 | default=None, 32 | type=str, 33 | required=True, 34 | help="The input test directory. Should contain (.gz.parquet) files", 35 | ) 36 | parser.add_argument( 37 | "--output_speakers_only", 38 | action="store_true", 39 | help="If set, only output the top1 speakers instead of the probabilities associated", 40 | ) 41 | parser.add_argument( 42 | "--per_gpu_eval_batch_size", 43 | default=128, 44 | type=int, 45 | help="Batch size per GPU/CPU for evaluation.", 46 | ) 47 | 48 | args = parser.parse_args() 49 | args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 50 | args.n_gpu = torch.cuda.device_count() 51 | 52 | logging.basicConfig( 53 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 54 | datefmt="%m/%d/%Y %H:%M:%S", 55 | level=logging.INFO, 56 | ) 57 | 58 | logger.info(f"Started loading the dataset from {args.test_dir}") 59 | files = glob.glob(os.path.join(args.test_dir, "**.gz.parquet")) 60 | datasets = [ParquetDataset(f) for f in files] 61 | concat_dataset = ConcatParquetDataset(datasets) 62 | 63 | model = BertForQuotationAttribution.from_pretrained(args.model_dir) 64 | model.to(args.device) 65 | args.output_file = os.path.join(args.output_dir, f"results.csv") 66 | evaluate(args, model, concat_dataset, output_proba=not args.output_speakers_only) 67 | # logger.info(f"EM: {result * 100:.2f}%") 68 | --------------------------------------------------------------------------------
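A minimal usage sketch for the evaluation entry point (test.py) above; the checkpoint directory, test-data directory, output directory, and batch size shown here are illustrative placeholders, not values taken from the repository:

    python test.py \
        --model_dir ./model \
        --test_dir ./data/test \
        --output_dir ./results \
        --per_gpu_eval_batch_size 64

The test directory is expected to contain *.gz.parquet files; results are written to results.csv inside --output_dir, and passing --output_speakers_only restricts the output to the top-1 speaker per quotation instead of the full probability distribution.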